Skip to content

Commit

Permalink
Fix #601: Do not restrict Map key/value types when deserializing from…
Browse files Browse the repository at this point in the history
… JSON
  • Loading branch information
langchain4j committed Apr 11, 2024
1 parent 9c69d96 commit 002f65d
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 19 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
package dev.langchain4j.internal;

import com.google.gson.*;
import com.google.gson.reflect.TypeToken;
import com.google.gson.stream.JsonWriter;

import java.io.*;
import java.lang.reflect.Type;
import java.nio.charset.StandardCharsets;
import java.time.LocalDate;
import java.time.LocalDateTime;
Expand Down Expand Up @@ -40,8 +38,6 @@ class GsonJsonCodec implements Json.JsonCodec {
)
.create();

public static final Type MAP_TYPE = new TypeToken<Map<String, String>>() {}.getType();

@Override
public String toJson(Object o) {
return GSON.toJson(o);
Expand All @@ -60,9 +56,6 @@ public String toJson(Object o) {
*/
@Override
public <T> T fromJson(String json, Class<T> type) {
if (type == Map.class) {
return GSON.fromJson(json, MAP_TYPE);
}
return GSON.fromJson(json, type);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package dev.langchain4j.internal;

import com.google.gson.JsonSyntaxException;
import com.google.gson.reflect.TypeToken;
import org.assertj.core.api.WithAssertions;
import org.assertj.core.data.MapEntry;
import org.junit.jupiter.api.Test;

import java.io.IOException;
Expand All @@ -12,6 +12,8 @@
import java.util.HashMap;
import java.util.Map;

import static java.util.Arrays.asList;

class GsonJsonCodecTest implements WithAssertions {
public static class Example {
public String name;
Expand Down Expand Up @@ -61,26 +63,20 @@ public void test() throws Exception {
public void test_map() {
GsonJsonCodec codec = new GsonJsonCodec();
{
TypeToken<Map<String, String>> tt = new TypeToken<Map<String, String>>() {
};
assertThat(GsonJsonCodec.MAP_TYPE).isEqualTo(tt.getType());
}

{
Map<String, String> expectedMap = new HashMap<>();
Map<Object, Object> expectedMap = new HashMap<>();
expectedMap.put("a", "b");

assertThat(codec.toJson(expectedMap))
.isEqualTo("{\n \"a\": \"b\"\n}");

assertThat(codec.fromJson("{\"a\": \"b\"}", (Class<?>)expectedMap.getClass()))
assertThat(codec.fromJson("{\"a\": \"b\"}", (Class<?>) expectedMap.getClass()))
.isEqualTo(expectedMap);
}
{
assertThatExceptionOfType(JsonSyntaxException.class)
.isThrownBy(() -> codec.fromJson("{\"a\": [1, 2]}", Map.class));
}
Map<Object, Object> map = codec.fromJson("{\"a\": [1, 2]}", Map.class);

assertThat(map).containsExactly(MapEntry.entry("a", asList(1.0, 2.0)));
}
}

public static class DateExample {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package dev.langchain4j.service;

import dev.langchain4j.agent.tool.P;
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.mistralai.MistralAiStreamingChatModel;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
import dev.langchain4j.model.output.Response;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;

import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static dev.langchain4j.agent.tool.JsonSchemaProperty.description;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.*;
import static dev.langchain4j.service.StreamingAiServicesWithToolsIT.TransactionService.EXPECTED_SPECIFICATION;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.*;

class StreamingAiServicesWithToolsIT {

static Stream<StreamingChatLanguageModel> models() {
return Stream.of(
OpenAiStreamingChatModel.builder()
.baseUrl(System.getenv("OPENAI_BASE_URL"))
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.temperature(0.0)
.logRequests(true)
.logResponses(true)
.build(),
MistralAiStreamingChatModel.builder()
.apiKey(System.getenv("MISTRAL_AI_API_KEY"))
.modelName("mistral-large-latest")
.logRequests(true)
.logResponses(true)
.build()
// Add your AzureOpenAiChatModel instance here...
// Add your GeminiChatModel instance here...
);
}

interface Assistant {

TokenStream chat(String userMessage);
}

static class TransactionService {

static ToolSpecification EXPECTED_SPECIFICATION = ToolSpecification.builder()
.name("getTransactionAmounts")
.description("returns amounts of transactions")
.addParameter("arg0", ARRAY, items(STRING), description("IDs of transactions"))
.build();

@Tool("returns amounts of transactions")
List<Double> getTransactionAmounts(@P("IDs of transactions") List<String> ids) {
System.out.printf("called getTransactionAmounts(%s)%n", ids);
return ids.stream().map(id -> {
switch (id) {
case "T001":
return 42.0;
case "T002":
return 57.0;
default:
throw new IllegalArgumentException("Unknown transaction ID: " + id);
}
}).collect(Collectors.toList());
}
}

@ParameterizedTest
@MethodSource("models")
void should_use_tool_with_List_of_Strings_parameter(StreamingChatLanguageModel model) throws Exception {

// given
TransactionService transactionService = spy(new TransactionService());

ChatMemory chatMemory = MessageWindowChatMemory.withMaxMessages(10);

StreamingChatLanguageModel spyModel = spy(model);

Assistant assistant = AiServices.builder(Assistant.class)
.streamingChatLanguageModel(spyModel)
.chatMemory(chatMemory)
.tools(transactionService)
.build();

String userMessage = "What are the amounts of transactions T001 and T002?";

// when
CompletableFuture<Response<AiMessage>> future = new CompletableFuture<>();
assistant.chat(userMessage)
.onNext(token -> {
})
.onComplete(future::complete)
.onError(future::completeExceptionally)
.start();
Response<AiMessage> response = future.get(60, TimeUnit.SECONDS);

// then
assertThat(response.content().text()).contains("42", "57");

// then
verify(transactionService).getTransactionAmounts(asList("T001", "T002"));
verifyNoMoreInteractions(transactionService);

// then
List<ChatMessage> messages = chatMemory.messages();
verify(spyModel).generate(
eq(singletonList(messages.get(0))),
eq(singletonList(EXPECTED_SPECIFICATION)),
any()
);
verify(spyModel).generate(
eq(asList(messages.get(0), messages.get(1), messages.get(2))),
eq(singletonList(EXPECTED_SPECIFICATION)),
any()
);
}

// TODO all other tests from sync version
}

0 comments on commit 002f65d

Please sign in to comment.