From 4d3e622f322b7b227559dd7fd41ab1901c8d6638 Mon Sep 17 00:00:00 2001 From: tenpigs267 <126336487+tenpigs267@users.noreply.github.com> Date: Thu, 8 Feb 2024 12:16:05 +0100 Subject: [PATCH] nested POJO extract (#575) This PR enhances ServiceOutputParser, allowing outputFormatInstructions to document nested objects in jsonStructure. Integration tests have been modified to add a nested address object. Maybe a dedicated test would be better? --- .../service/ServiceOutputParser.java | 19 +- .../dev/langchain4j/service/AiServicesIT.java | 35 +++- .../service/ServiceOutputParserTest.java | 182 ++++++++++++++++++ 3 files changed, 229 insertions(+), 7 deletions(-) create mode 100644 langchain4j/src/test/java/dev/langchain4j/service/ServiceOutputParserTest.java diff --git a/langchain4j/src/main/java/dev/langchain4j/service/ServiceOutputParser.java b/langchain4j/src/main/java/dev/langchain4j/service/ServiceOutputParser.java index 1fe533dcbf..4d6e0ed5f6 100644 --- a/langchain4j/src/main/java/dev/langchain4j/service/ServiceOutputParser.java +++ b/langchain4j/src/main/java/dev/langchain4j/service/ServiceOutputParser.java @@ -123,7 +123,9 @@ private static String jsonStructure(Class structured) { jsonSchema.append("{\n"); for (Field field : structured.getDeclaredFields()) { String name = field.getName(); - if (name.equals("__$hits$__")) { + if (name.equals("__$hits$__") + || java.lang.reflect.Modifier.isStatic(field.getModifiers()) + || java.lang.reflect.Modifier.isFinal(field.getModifiers())) { // Skip coverage instrumentation field. 
continue; } @@ -151,15 +153,24 @@ private static String typeOf(Field field) { if (parameterizedType.getRawType().equals(List.class) || parameterizedType.getRawType().equals(Set.class)) { - return format("array of %s", simpleTypeName(typeArguments[0])); + if (((Class) typeArguments[0]).getPackage() == null || ((Class) typeArguments[0]).getPackage().getName().startsWith("java.")) + return format("array of %s", simpleTypeName(typeArguments[0])); + else + return format("array of %s", jsonStructure((Class) typeArguments[0])); } } else if (field.getType().isArray()) { - return format("array of %s", simpleTypeName(field.getType().getComponentType())); + if (field.getType().getComponentType().getPackage() == null || field.getType().getComponentType().getPackage().getName().startsWith("java.")) + return format("array of %s", simpleTypeName(field.getType().getComponentType())); + else + return format("array of %s", jsonStructure(field.getType().getComponentType())); } else if (((Class) type).isEnum()) { return "enum, must be one of " + Arrays.toString(((Class) type).getEnumConstants()); } - return simpleTypeName(type); + if (field.getType().getPackage() == null || field.getType().getPackage().getName().startsWith("java.")) + return simpleTypeName(type); + else + return jsonStructure(field.getType()); } private static String simpleTypeName(Type type) { diff --git a/langchain4j/src/test/java/dev/langchain4j/service/AiServicesIT.java b/langchain4j/src/test/java/dev/langchain4j/service/AiServicesIT.java index e85d366afa..adb13292a0 100644 --- a/langchain4j/src/test/java/dev/langchain4j/service/AiServicesIT.java +++ b/langchain4j/src/test/java/dev/langchain4j/service/AiServicesIT.java @@ -200,11 +200,18 @@ void test_extract_enum() { @ToString - static class Person { + static class Address { + private Integer streetNumber; + private String street; + private String city; + } + @ToString + static class Person { private String firstName; private String lastName; private LocalDate 
birthDate; + private Address address; } interface PersonExtractor { @@ -220,7 +227,10 @@ void should_extract_custom_POJO() { String text = "In 1968, amidst the fading echoes of Independence Day, " + "a child named John arrived under the calm evening sky. " - + "This newborn, bearing the surname Doe, marked the start of a new journey."; + + "This newborn, bearing the surname Doe, marked the start of a new journey." + + "He was welcomed into the world at 345 Whispering Pines Avenue," + + "a quaint street nestled in the heart of Springfield," + + "an abode that echoed with the gentle hum of suburban dreams and aspirations."; Person person = personExtractor.extractPersonFrom(text); System.out.println(person); @@ -228,6 +238,9 @@ void should_extract_custom_POJO() { assertThat(person.firstName).isEqualTo("John"); assertThat(person.lastName).isEqualTo("Doe"); assertThat(person.birthDate).isEqualTo(LocalDate.of(1968, JULY, 4)); + assertThat(person.address.streetNumber).isEqualTo(345); + assertThat(person.address.street).isEqualTo("Whispering Pines Avenue"); + assertThat(person.address.city).isEqualTo("Springfield"); verify(chatLanguageModel).generate(singletonList(userMessage( "Extract information about a person from " + text + "\n" + @@ -235,6 +248,11 @@ void should_extract_custom_POJO() { "\"firstName\": (type: string),\n" + "\"lastName\": (type: string),\n" + "\"birthDate\": (type: date string (2023-12-31)),\n" + + "\"address\": (type: {\n" + + "\"streetNumber\": (type: integer),\n" + + "\"street\": (type: string),\n" + + "\"city\": (type: string),\n" + + "}),\n" + "}"))); } @@ -256,7 +274,10 @@ void should_extract_custom_POJO_with_explicit_json_response_format() { String text = "In 1968, amidst the fading echoes of Independence Day, " + "a child named John arrived under the calm evening sky. " - + "This newborn, bearing the surname Doe, marked the start of a new journey."; + + "This newborn, bearing the surname Doe, marked the start of a new journey." 
+ + "He was welcomed into the world at 345 Whispering Pines Avenue," + + "a quaint street nestled in the heart of Springfield," + + "an abode that echoed with the gentle hum of suburban dreams and aspirations."; Person person = personExtractor.extractPersonFrom(text); System.out.println(person); @@ -264,6 +285,9 @@ void should_extract_custom_POJO_with_explicit_json_response_format() { assertThat(person.firstName).isEqualTo("John"); assertThat(person.lastName).isEqualTo("Doe"); assertThat(person.birthDate).isEqualTo(LocalDate.of(1968, JULY, 4)); + assertThat(person.address.streetNumber).isEqualTo(345); + assertThat(person.address.street).isEqualTo("Whispering Pines Avenue"); + assertThat(person.address.city).isEqualTo("Springfield"); verify(chatLanguageModel).generate(singletonList(userMessage( "Extract information about a person from " + text + "\n" + @@ -271,6 +295,11 @@ void should_extract_custom_POJO_with_explicit_json_response_format() { "\"firstName\": (type: string),\n" + "\"lastName\": (type: string),\n" + "\"birthDate\": (type: date string (2023-12-31)),\n" + + "\"address\": (type: {\n" + + "\"streetNumber\": (type: integer),\n" + + "\"street\": (type: string),\n" + + "\"city\": (type: string),\n" + + "}),\n" + "}"))); } diff --git a/langchain4j/src/test/java/dev/langchain4j/service/ServiceOutputParserTest.java b/langchain4j/src/test/java/dev/langchain4j/service/ServiceOutputParserTest.java new file mode 100644 index 0000000000..6950e2bd88 --- /dev/null +++ b/langchain4j/src/test/java/dev/langchain4j/service/ServiceOutputParserTest.java @@ -0,0 +1,182 @@ +package dev.langchain4j.service; + +import org.junit.jupiter.api.Test; + +import java.io.Serializable; +import java.time.LocalDate; +import java.util.Calendar; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +class ServiceOutputParserTest { + + static class Person { + private String firstName; + private String lastName; + private LocalDate birthDate; + } + + @Test + 
void outputFormatInstructions_SimplePerson() { + String formatInstructions = ServiceOutputParser.outputFormatInstructions(Person.class); + + assertThat(formatInstructions).isEqualTo( + "\nYou must answer strictly in the following JSON format: {\n" + + "\"firstName\": (type: string),\n" + + "\"lastName\": (type: string),\n" + + "\"birthDate\": (type: date string (2023-12-31)),\n" + + "}"); + } + + static class PersonWithFirstNameList { + private List firstName; + private String lastName; + private LocalDate birthDate; + } + + @Test + void outputFormatInstructions_PersonWithFirstNameList() { + String formatInstructions = ServiceOutputParser.outputFormatInstructions(PersonWithFirstNameList.class); + + assertThat(formatInstructions).isEqualTo( + "\nYou must answer strictly in the following JSON format: {\n" + + "\"firstName\": (type: array of string),\n" + + "\"lastName\": (type: string),\n" + + "\"birthDate\": (type: date string (2023-12-31)),\n" + + "}"); + } + + static class PersonWithFirstNameArray { + private String[] firstName; + private String lastName; + private LocalDate birthDate; + } + + @Test + void outputFormatInstructions_PersonWithFirstNameArray() { + String formatInstructions = ServiceOutputParser.outputFormatInstructions(PersonWithFirstNameArray.class); + + assertThat(formatInstructions).isEqualTo( + "\nYou must answer strictly in the following JSON format: {\n" + + "\"firstName\": (type: array of string),\n" + + "\"lastName\": (type: string),\n" + + "\"birthDate\": (type: date string (2023-12-31)),\n" + + "}"); + } + + static class PersonWithCalendarDate { + private String firstName; + private String lastName; + private Calendar birthDate; + } + + @Test + void outputFormatInstructions_PersonWithJavaType() { + String formatInstructions = ServiceOutputParser.outputFormatInstructions(PersonWithCalendarDate.class); + + assertThat(formatInstructions).isEqualTo( + "\nYou must answer strictly in the following JSON format: {\n" + + "\"firstName\": (type: 
string),\n" + + "\"lastName\": (type: string),\n" + + "\"birthDate\": (type: java.util.Calendar),\n" + + "}"); + } + + static class PersonWithStaticField implements Serializable { + private static final long serialVersionUID = 1234567L; + private String firstName; + private String lastName; + private LocalDate birthDate; + } + + @Test + void outputFormatInstructions_PersonWithStaticFinalField() { + String formatInstructions = ServiceOutputParser.outputFormatInstructions(PersonWithStaticField.class); + + assertThat(formatInstructions).isEqualTo( + "\nYou must answer strictly in the following JSON format: {\n" + + "\"firstName\": (type: string),\n" + + "\"lastName\": (type: string),\n" + + "\"birthDate\": (type: date string (2023-12-31)),\n" + + "}"); + } + + static class Address { + private Integer streetNumber; + private String street; + private String city; + } + + static class PersonAndAddress { + private String firstName; + private String lastName; + private LocalDate birthDate; + private Address address; + } + + @Test + void outputFormatInstructions_PersonWithNestedObject() { + String formatInstructions = ServiceOutputParser.outputFormatInstructions(PersonAndAddress.class); + + assertThat(formatInstructions).isEqualTo( + "\nYou must answer strictly in the following JSON format: {\n" + + "\"firstName\": (type: string),\n" + + "\"lastName\": (type: string),\n" + + "\"birthDate\": (type: date string (2023-12-31)),\n" + + "\"address\": (type: {\n" + + "\"streetNumber\": (type: integer),\n" + + "\"street\": (type: string),\n" + + "\"city\": (type: string),\n" + + "}),\n" + + "}"); + } + + static class PersonAndAddressList { + private String firstName; + private String lastName; + private LocalDate birthDate; + private List
address; + } + + @Test + void outputFormatInstructions_PersonWithNestedObjectList() { + String formatInstructions = ServiceOutputParser.outputFormatInstructions(PersonAndAddressList.class); + + assertThat(formatInstructions).isEqualTo( + "\nYou must answer strictly in the following JSON format: {\n" + + "\"firstName\": (type: string),\n" + + "\"lastName\": (type: string),\n" + + "\"birthDate\": (type: date string (2023-12-31)),\n" + + "\"address\": (type: array of {\n" + + "\"streetNumber\": (type: integer),\n" + + "\"street\": (type: string),\n" + + "\"city\": (type: string),\n" + + "}),\n" + + "}"); + } + + static class PersonAndAddressArray { + private String firstName; + private String lastName; + private LocalDate birthDate; + private List
address; + } + + @Test + void outputFormatInstructions_PersonWithNestedObjectArray() { + String formatInstructions = ServiceOutputParser.outputFormatInstructions(PersonAndAddressList.class); + + assertThat(formatInstructions).isEqualTo( + "\nYou must answer strictly in the following JSON format: {\n" + + "\"firstName\": (type: string),\n" + + "\"lastName\": (type: string),\n" + + "\"birthDate\": (type: date string (2023-12-31)),\n" + + "\"address\": (type: array of {\n" + + "\"streetNumber\": (type: integer),\n" + + "\"street\": (type: string),\n" + + "\"city\": (type: string),\n" + + "}),\n" + + "}"); + } +} \ No newline at end of file