Skip to content

Commit

Permalink
Token usage calculation bug (#280): slight refactoring + added a test
Browse files: browse the repository at this point in the history
  • Loading branch information
deep-learning-dynamo committed Nov 24, 2023
1 parent 7cfa1a5 commit c671ebe
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 30 deletions.
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,5 @@
package dev.langchain4j.service;

import static dev.langchain4j.data.message.ToolExecutionResultMessage.toolExecutionResultMessage;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.exception.IllegalConfigurationException.illegalConfiguration;
import static dev.langchain4j.internal.Exceptions.illegalArgument;
import static dev.langchain4j.service.ServiceOutputParser.outputFormatInstructions;
import static java.util.Collections.singletonMap;
import static java.util.stream.Collectors.joining;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolExecutor;
import dev.langchain4j.data.message.AiMessage;
Expand All @@ -21,22 +14,22 @@
import dev.langchain4j.model.moderation.Moderation;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import java.lang.reflect.Array;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.*;
import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static dev.langchain4j.data.message.ToolExecutionResultMessage.toolExecutionResultMessage;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.exception.IllegalConfigurationException.illegalConfiguration;
import static dev.langchain4j.internal.Exceptions.illegalArgument;
import static dev.langchain4j.service.ServiceOutputParser.outputFormatInstructions;
import static java.util.Collections.singletonMap;
import static java.util.stream.Collectors.joining;

class DefaultAiServices<T> extends AiServices<T> {

Expand Down Expand Up @@ -138,21 +131,20 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Exceptio
return new AiServiceTokenStream(messages, context, memoryId); // TODO moderation
}

Response<AiMessage> response = context.toolSpecifications != null ?
context.chatModel.generate(messages, context.toolSpecifications) :
context.chatModel.generate(messages);
Response<AiMessage> response = context.toolSpecifications == null
? context.chatModel.generate(messages)
: context.chatModel.generate(messages, context.toolSpecifications);
TokenUsage tokenUsageAccumulator = response.tokenUsage();

verifyModerationIfNeeded(moderationFuture);

TokenUsage tokenUsage = new TokenUsage();
ToolExecutionRequest toolExecutionRequest;
while (true) { // TODO limit number of cycles

if (context.hasChatMemory()) {
context.chatMemory(memoryId).add(response.content());
}

tokenUsage = tokenUsage.add(response.tokenUsage());
toolExecutionRequest = response.content().toolExecutionRequest();
if (toolExecutionRequest == null) {
break;
Expand All @@ -167,9 +159,10 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Exceptio
chatMemory.add(toolExecutionResultMessage);

response = context.chatModel.generate(chatMemory.messages(), context.toolSpecifications);
tokenUsageAccumulator = tokenUsageAccumulator.add(response.tokenUsage());
}

response = Response.from(response.content(), tokenUsage, response.finishReason());
response = Response.from(response.content(), tokenUsageAccumulator, response.finishReason());
return ServiceOutputParser.parse(response, method.getReturnType());
}

Expand Down Expand Up @@ -252,7 +245,6 @@ private static ChatMessage prepareUserMessage(Method method, Object[] args) {
}



private Optional<Object> memoryId(Method method, Object[] args) {
Parameter[] parameters = method.getParameters();
for (int i = 0; i < parameters.length; i++) {
Expand All @@ -269,7 +261,6 @@ private Optional<Object> memoryId(Method method, Object[] args) {
}



private static String getUserName(Parameter[] parameters, Object[] args) {
for (int i = 0; i < parameters.length; i++) {
if (parameters[i].isAnnotationPresent(UserName.class)) {
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -15,6 +15,7 @@
import dev.langchain4j.model.moderation.ModerationModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiModerationModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.structured.Description;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import lombok.Builder;
Expand Down Expand Up @@ -657,7 +658,7 @@ public void deleteMessages(Object memoryId) {

interface Assistant {

String chat(String userMessage);
Response<AiMessage> chat(String userMessage);
}

static class Calculator {
Expand All @@ -683,9 +684,12 @@ void should_execute_tool_then_answer() {

String userMessage = "What is the square root of 485906798473894056 in scientific notation?";

String answer = assistant.chat(userMessage);
Response<AiMessage> answer = assistant.chat(userMessage);

assertThat(answer).contains("6.97");
assertThat(answer.content().text()).contains("6.97");
assertThat(answer.tokenUsage().inputTokenCount()).isEqualTo(72 + 110);
assertThat(answer.tokenUsage().outputTokenCount()).isEqualTo(21 + 28);
assertThat(answer.tokenUsage().totalTokenCount()).isEqualTo(93 + 138);


verify(calculator).squareRoot(485906798473894056.0);
Expand Down

0 comments on commit c671ebe

Please sign in to comment.