Skip to content

Commit

Permalink
Draft: LLM observability
Browse files Browse the repository at this point in the history
  • Loading branch information
langchain4j committed May 7, 2024
1 parent fb29898 commit 83294cb
Show file tree
Hide file tree
Showing 7 changed files with 298 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package dev.langchain4j.model.chat.observability;

import dev.langchain4j.Experimental;

/**
 * A listener that observes individual interactions with a chat language model.
 * <p>
 * Implementations receive a callback before a request is sent ({@link #onRequest}),
 * after a response is received ({@link #onResponse}), and when the call fails
 * ({@link #onError}). The same correlation {@code id} is passed to all callbacks
 * belonging to one model invocation, so a request can be matched with its
 * response (or error).
 * <p>
 * All methods have no-op default implementations, so implementations only need
 * to override the callbacks they are interested in.
 * <p>
 * NOTE(review): thread-safety is not guaranteed by this contract — a model may
 * invoke the listener from multiple threads concurrently; implementations should
 * be thread-safe.
 */
// TODO name
// TODO package
@Experimental
public interface ChatLanguageModelListener {

    /**
     * Invoked right before the request is sent to the model provider.
     *
     * @param id      a unique identifier correlating this request with its
     *                subsequent {@link #onResponse}/{@link #onError} callback
     * @param request the request about to be sent
     */
    // TODO names
    @Experimental
    default void onRequest(String id, ChatLanguageModelRequest request) {

    }

    /**
     * Invoked after a response has been received from the model provider.
     *
     * @param id       the identifier previously passed to {@link #onRequest}
     *                 for the same invocation
     * @param response the received response
     */
    // TODO names
    // TODO accept Response<AiMessage> ?
    @Experimental
    default void onResponse(String id, ChatLanguageModelResponse response) {

    }

    /**
     * Invoked when the model invocation fails.
     *
     * @param id    the identifier previously passed to {@link #onRequest}
     *              for the same invocation
     * @param error the failure cause
     */
    @Experimental
    default void onError(String id, Throwable error) {

    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package dev.langchain4j.model.chat.observability;

import dev.langchain4j.Experimental;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.ChatMessage;
import lombok.Builder;

import java.util.List;

/**
 * An immutable snapshot of a request sent to a chat language model, delivered to
 * {@link ChatLanguageModelListener#onRequest(String, ChatLanguageModelRequest)}.
 * <p>
 * Field names map to the OpenTelemetry GenAI semantic-convention attributes
 * noted next to each field (e.g. {@code gen_ai.request.model}).
 */
@Builder
@Experimental
public class ChatLanguageModelRequest {

    private final String system; // gen_ai.system
    private final String modelName; // gen_ai.request.model

    // TODO group into "Parameters" POJO and re-use in ChatLanguageModel.generate()?
    private final Double temperature; // gen_ai.request.temperature
    private final Double topP; // gen_ai.request.top_p
    private final Integer maxTokens; // gen_ai.request.max_tokens

    // event
    private final List<ChatMessage> messages; // gen_ai.prompt // TODO copy
    private final List<ToolSpecification> toolSpecifications; // TODO copy

    /**
     * @return the provider/system name ({@code gen_ai.system}), or {@code null} if unknown
     */
    public String system() {
        return system;
    }

    /**
     * @return the requested model name ({@code gen_ai.request.model})
     */
    public String modelName() {
        return modelName;
    }

    /**
     * @return the sampling temperature, or {@code null} if not set
     */
    public Double temperature() {
        return temperature;
    }

    /**
     * @return the nucleus-sampling parameter, or {@code null} if not set
     */
    public Double topP() {
        return topP;
    }

    /**
     * @return the maximum number of tokens to generate, or {@code null} if not set
     */
    public Integer maxTokens() {
        return maxTokens;
    }

    /**
     * @return the messages sent to the model ({@code gen_ai.prompt})
     */
    public List<ChatMessage> messages() {
        return messages;
    }

    /**
     * @return the tool specifications sent to the model, or {@code null} if none
     */
    public List<ToolSpecification> toolSpecifications() {
        return toolSpecifications;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package dev.langchain4j.model.chat.observability;

import dev.langchain4j.Experimental;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.TokenUsage;
import lombok.Builder;

/**
 * An immutable snapshot of a response received from a chat language model, delivered to
 * {@link ChatLanguageModelListener#onResponse(String, ChatLanguageModelResponse)}.
 * <p>
 * Field names map to the OpenTelemetry GenAI semantic-convention attributes
 * noted next to each field (e.g. {@code gen_ai.response.model}).
 */
@Builder
@Experimental
public class ChatLanguageModelResponse {

    private final String id; // gen_ai.response.id
    private final String modelName; // gen_ai.response.model
    private final TokenUsage tokenUsage; // gen_ai.usage.completion_tokens + gen_ai.usage.prompt_tokens
    private final FinishReason finishReason; // gen_ai.response.finish_reasons

    // event
    private final AiMessage aiMessage; // gen_ai.completion

    /**
     * @return the provider-assigned response id ({@code gen_ai.response.id}),
     * or {@code null} if the provider did not supply one
     */
    public String id() {
        return id;
    }

    /**
     * @return the model that actually produced the response ({@code gen_ai.response.model}),
     * or {@code null} if unknown
     */
    public String modelName() {
        return modelName;
    }

    /**
     * @return token-usage statistics for this call, or {@code null} if unavailable
     */
    public TokenUsage tokenUsage() {
        return tokenUsage;
    }

    /**
     * @return the reason generation stopped, or {@code null} if unknown
     */
    public FinishReason finishReason() {
        return finishReason;
    }

    /**
     * @return the generated message ({@code gen_ai.completion})
     */
    public AiMessage aiMessage() {
        return aiMessage;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,27 @@
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.chat.observability.ChatLanguageModelListener;
import dev.langchain4j.model.chat.observability.ChatLanguageModelRequest;
import dev.langchain4j.model.chat.observability.ChatLanguageModelResponse;
import dev.langchain4j.model.openai.spi.OpenAiChatModelBuilderFactory;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import lombok.Builder;

import java.net.Proxy;
import java.time.Duration;
import java.util.HashMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.copyIfNotNull;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.*;
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
Expand Down Expand Up @@ -47,6 +54,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
private final String user;
private final Integer maxRetries;
private final Tokenizer tokenizer;
private final ChatLanguageModelListener listener;

@Builder
public OpenAiChatModel(String baseUrl,
Expand All @@ -69,7 +77,8 @@ public OpenAiChatModel(String baseUrl,
Boolean logRequests,
Boolean logResponses,
Tokenizer tokenizer,
Map<String, String> customHeaders) {
Map<String, String> customHeaders,
ChatLanguageModelListener listener) {

baseUrl = getOrDefault(baseUrl, OPENAI_URL);
if (OPENAI_DEMO_API_KEY.equals(apiKey)) {
Expand Down Expand Up @@ -105,6 +114,7 @@ public OpenAiChatModel(String baseUrl,
this.user = user;
this.maxRetries = getOrDefault(maxRetries, 3);
this.tokenizer = getOrDefault(tokenizer, OpenAiTokenizer::new);
this.listener = listener;
}

public String modelName() {
Expand Down Expand Up @@ -153,12 +163,39 @@ private Response<AiMessage> generate(List<ChatMessage> messages,

ChatCompletionRequest request = requestBuilder.build();

String id = Utils.randomUUID();
if (listener != null) {
listener.onRequest(id, ChatLanguageModelRequest.builder()
.system(null) // TODO
.modelName(request.model())
.temperature(request.temperature())
.topP(request.topP())
.maxTokens(request.maxTokens())
.messages(new ArrayList<>(messages))
.toolSpecifications(copyIfNotNull(toolSpecifications))
.build());
}

ChatCompletionResponse response = withRetry(() -> client.chatCompletion(request).execute(), maxRetries);

AiMessage aiMessage = aiMessageFrom(response);
TokenUsage tokenUsage = tokenUsageFrom(response.usage());
FinishReason finishReason = finishReasonFrom(response.choices().get(0).finishReason());

if (listener != null) {
listener.onResponse(id, ChatLanguageModelResponse.builder()
.id(response.id())
.modelName(response.model())
.tokenUsage(tokenUsage)
.finishReason(finishReason)
.aiMessage(aiMessage)
.build());
}

return Response.from(
aiMessageFrom(response),
tokenUsageFrom(response.usage()),
finishReasonFrom(response.choices().get(0).finishReason())
aiMessage,
tokenUsage,
finishReason
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,25 @@
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.chat.observability.ChatLanguageModelListener;
import dev.langchain4j.model.chat.observability.ChatLanguageModelRequest;
import dev.langchain4j.model.chat.observability.ChatLanguageModelResponse;
import dev.langchain4j.model.openai.spi.OpenAiStreamingChatModelBuilderFactory;
import dev.langchain4j.model.output.Response;
import lombok.Builder;

import java.net.Proxy;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.Utils.*;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.*;
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
Expand Down Expand Up @@ -50,6 +54,7 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
private final String user;
private final Tokenizer tokenizer;
private final boolean isOpenAiModel;
private final ChatLanguageModelListener listener;

@Builder
public OpenAiStreamingChatModel(String baseUrl,
Expand All @@ -71,7 +76,8 @@ public OpenAiStreamingChatModel(String baseUrl,
Boolean logRequests,
Boolean logResponses,
Tokenizer tokenizer,
Map<String, String> customHeaders) {
Map<String, String> customHeaders,
ChatLanguageModelListener listener) {

timeout = getOrDefault(timeout, ofSeconds(60));

Expand Down Expand Up @@ -102,6 +108,7 @@ public OpenAiStreamingChatModel(String baseUrl,
this.user = user;
this.tokenizer = getOrDefault(tokenizer, OpenAiTokenizer::new);
this.isOpenAiModel = isOpenAiModel(this.modelName);
this.listener = listener;
}

public String modelName() {
Expand Down Expand Up @@ -152,6 +159,19 @@ private void generate(List<ChatMessage> messages,

ChatCompletionRequest request = requestBuilder.build();

String id = Utils.randomUUID();
if (listener != null) {
listener.onRequest(id, ChatLanguageModelRequest.builder()
.system(null) // TODO
.modelName(request.model())
.temperature(request.temperature())
.topP(request.topP())
.maxTokens(request.maxTokens())
.messages(new ArrayList<>(messages))
.toolSpecifications(copyIfNotNull(toolSpecifications))
.build());
}

int inputTokenCount = countInputTokens(messages, toolSpecifications, toolThatMustBeExecuted);
OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder(inputTokenCount);

Expand All @@ -165,6 +185,17 @@ private void generate(List<ChatMessage> messages,
if (!isOpenAiModel) {
response = removeTokenUsage(response);
}

if (listener != null) {
listener.onResponse(id, ChatLanguageModelResponse.builder()
.id(null) // TODO
.modelName(null) // TODO
.tokenUsage(response.tokenUsage())
.finishReason(response.finishReason())
.aiMessage(response.content())
.build());
}

handler.onComplete(response);
})
.onError(handler::onError)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@
import dev.langchain4j.data.message.*;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.observability.ChatLanguageModelListener;
import dev.langchain4j.model.chat.observability.ChatLanguageModelRequest;
import dev.langchain4j.model.chat.observability.ChatLanguageModelResponse;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.junit.jupiter.api.Test;

import java.util.Base64;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;

import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
import static dev.langchain4j.data.message.ToolExecutionResultMessage.from;
Expand Down Expand Up @@ -216,6 +220,8 @@ void should_execute_multiple_tools_in_parallel_then_answer() {
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.modelName(GPT_3_5_TURBO_1106) // supports parallel function calling
.temperature(0.0)
.logRequests(true)
.logResponses(true)
.build();

UserMessage userMessage = userMessage("2+2=? 3+3=?");
Expand Down Expand Up @@ -463,4 +469,58 @@ public int estimateTokenCountInToolExecutionRequests(Iterable<ToolExecutionReque
// then
assertThat(tokenCount).isEqualTo(42);
}

@Test
void should_listen_request_and_response() {

    // given: holders capturing everything the listener observes
    AtomicReference<String> observedRequestId = new AtomicReference<>();
    AtomicReference<ChatLanguageModelRequest> observedRequest = new AtomicReference<>();

    AtomicReference<String> observedResponseId = new AtomicReference<>();
    AtomicReference<ChatLanguageModelResponse> observedResponse = new AtomicReference<>();

    ChatLanguageModelListener listener = new ChatLanguageModelListener() {

        @Override
        public void onRequest(String id, ChatLanguageModelRequest request) {
            observedRequestId.set(id);
            observedRequest.set(request);
        }

        @Override
        public void onResponse(String id, ChatLanguageModelResponse response) {
            observedResponseId.set(id);
            observedResponse.set(response);
        }
    };

    OpenAiChatModel model = OpenAiChatModel.builder()
            .baseUrl(System.getenv("OPENAI_BASE_URL"))
            .apiKey(System.getenv("OPENAI_API_KEY"))
            .organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
            .logRequests(true)
            .logResponses(true)
            .listener(listener)
            // TODO add other params
            .build();

    String userMessage = "hello";

    // when
    String answer = model.generate(userMessage);

    // then: the listener saw the outgoing request with our message
    ChatLanguageModelRequest request = observedRequest.get();
    assertThat(request.messages()).containsExactly(UserMessage.from(userMessage));
    // TODO assert all params

    // and: the listener saw a response consistent with the returned answer
    ChatLanguageModelResponse response = observedResponse.get();
    assertThat(response.aiMessage().text()).isEqualTo(answer);
    // TODO assert all params

    // and: request and response were correlated by the same non-blank id
    assertThat(observedRequestId.get())
            .isNotBlank()
            .isEqualTo(observedResponseId.get());
}
}

0 comments on commit 83294cb

Please sign in to comment.