From f2309ad477fb6bd1179dd1ba3d165a36ca06b8ba Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Fri, 15 Dec 2023 17:33:19 +0100 Subject: [PATCH 01/30] Add image generation support with Azure OpenAI --- .../model/azure/AzureOpenAiChatModel.java | 5 +- .../model/azure/AzureOpenAiImageModel.java | 255 ++++++++++++++++++ .../azure/InternalAzureOpenAiHelper.java | 7 + .../model/azure/AzureOpenAiImageModelIT.java | 34 +++ .../dev/langchain4j/data/image/Image.java | 36 +++ .../langchain4j/model/image/ImageModel.java | 8 + 6 files changed, 342 insertions(+), 3 deletions(-) create mode 100644 langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java create mode 100644 langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java create mode 100644 langchain4j-core/src/main/java/dev/langchain4j/data/image/Image.java create mode 100644 langchain4j-core/src/main/java/dev/langchain4j/model/image/ImageModel.java diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiChatModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiChatModel.java index 13b2b243f2..4c26f1e2e1 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiChatModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiChatModel.java @@ -19,6 +19,7 @@ import static dev.langchain4j.internal.Utils.getOrDefault; import static dev.langchain4j.model.azure.AzureOpenAiModelName.GPT_3_5_TURBO; +import static dev.langchain4j.model.azure.AzureOpenAiModelName.GPT_3_5_TURBO_0613; import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.*; import static java.util.Collections.singletonList; @@ -84,8 +85,6 @@ public AzureOpenAiChatModel(String endpoint, this.client = setupOpenAIClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, proxyOptions, logRequestsAndResponses); } - - private 
AzureOpenAiChatModel(String deploymentName, Tokenizer tokenizer, Double temperature, @@ -94,7 +93,7 @@ private AzureOpenAiChatModel(String deploymentName, Double presencePenalty, Double frequencyPenalty) { - this.deploymentName = getOrDefault(deploymentName, "gpt-35-turbo-0613"); + this.deploymentName = getOrDefault(deploymentName, GPT_3_5_TURBO_0613); this.tokenizer = getOrDefault(tokenizer, new OpenAiTokenizer(GPT_3_5_TURBO)); this.temperature = getOrDefault(temperature, 0.7); this.topP = topP; diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java new file mode 100644 index 0000000000..062a50fde3 --- /dev/null +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java @@ -0,0 +1,255 @@ +package dev.langchain4j.model.azure; + +import com.azure.ai.openai.OpenAIClient; +import com.azure.ai.openai.models.*; +import com.azure.core.http.ProxyOptions; +import dev.langchain4j.data.image.Image; +import dev.langchain4j.model.image.ImageModel; +import dev.langchain4j.model.output.Response; + +import java.time.Duration; + +import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.model.azure.AzureOpenAiModelName.GPT_3_5_TURBO; +import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.setupOpenAIClient; +import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.*; + +public class AzureOpenAiImageModel implements ImageModel { + + private OpenAIClient client; + private final String deploymentName; + private final ImageGenerationQuality quality; + private final ImageSize size; + private final String user; + private final ImageGenerationStyle style; + private final ImageGenerationResponseFormat responseFormat; + + public AzureOpenAiImageModel(OpenAIClient client, + String deploymentName, + String quality, + String size, + String user, + 
String style, + String responseFormat) { + + this(deploymentName, quality, size, user, style, responseFormat); + this.client = client; + } + + public AzureOpenAiImageModel(String endpoint, + String serviceVersion, + String apiKey, + String deploymentName, + String quality, + String size, + String user, + String style, + String responseFormat, + Duration timeout, + Integer maxRetries, + ProxyOptions proxyOptions, + boolean logRequestsAndResponses) { + + this(deploymentName, quality, size, user, style, responseFormat); + this.client = setupOpenAIClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, proxyOptions, logRequestsAndResponses); + } + + private AzureOpenAiImageModel(String deploymentName, String quality, String size, String user, String style, String responseFormat) { + this.deploymentName = getOrDefault(deploymentName, GPT_3_5_TURBO); + this.quality = getOrDefault(ImageGenerationQuality.fromString(quality), ImageGenerationQuality.STANDARD); + this.size = getOrDefault(ImageSize.fromString(size), ImageSize.SIZE256X256); + this.user = getOrDefault(user, ""); + this.style = getOrDefault(ImageGenerationStyle.fromString(style), ImageGenerationStyle.NATURAL); + this.responseFormat = getOrDefault(ImageGenerationResponseFormat.fromString(responseFormat), ImageGenerationResponseFormat.URL); + } + + @Override + public Response generate(String prompt) { + + ImageGenerationOptions options = new ImageGenerationOptions(prompt) + .setModel(deploymentName) + .setN(1) + .setQuality(quality) + .setSize(size) + .setUser(user) + .setStyle(style) + .setResponseFormat(responseFormat); + + ImageGenerations imageGenerations = client.getImageGenerations(deploymentName, options); + + return Response.from(imageFrom(imageGenerations.getData().get(0))); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String endpoint; + private String serviceVersion; + private String apiKey; + private String deploymentName; + 
private String quality; + private String size; + private String user; + private String style; + private String responseFormat; + private Duration timeout; + private Integer maxRetries; + private ProxyOptions proxyOptions; + private boolean logRequestsAndResponses; + private OpenAIClient openAIClient; + + /** + * Sets the Azure OpenAI endpoint. This is a mandatory parameter. + * + * @param endpoint The Azure OpenAI endpoint in the format: https://{resource}.openai.azure.com/ + * @return builder + */ + public Builder endpoint(String endpoint) { + this.endpoint = endpoint; + return this; + } + + /** + * Sets the Azure OpenAI API service version. This is a mandatory parameter. + * + * @param serviceVersion The Azure OpenAI API service version in the format: 2023-05-15 + * @return builder + */ + public Builder serviceVersion(String serviceVersion) { + this.serviceVersion = serviceVersion; + return this; + } + + /** + * Sets the Azure OpenAI API key. This is a mandatory parameter. + * + * @param apiKey The Azure OpenAI API key. + * @return builder + */ + public Builder apiKey(String apiKey) { + this.apiKey = apiKey; + return this; + } + + /** + * Sets the deployment name in Azure OpenAI. This is a mandatory parameter. + * + * @param deploymentName The Deployment name. + * @return builder + */ + public Builder deploymentName(String deploymentName) { + this.deploymentName = deploymentName; + return this; + } + + /** + * Sets the quality of the image. This is an optional parameter. + * + * @param quality The quality of the image. + * @return builder + */ + public Builder quality(String quality) { + this.quality = quality; + return this; + } + + /** + * Sets the size of the image. This is an optional parameter. + * + * @param size The size of the image. + * @return builder + */ + public Builder size(String size) { + this.size = size; + return this; + } + + /** + * Sets the user of the image. This is an optional parameter. + * + * @param user The user of the image. 
+ * @return builder + */ + public Builder user(String user) { + this.user = user; + return this; + } + + /** + * Sets the style of the image. This is an optional parameter. + * + * @param style The style of the image. + * @return builder + */ + public Builder style(String style) { + this.style = style; + return this; + } + + /** + * Sets the response format of the image. This is an optional parameter. + * + * @param responseFormat The response format of the image. + * @return builder + */ + public Builder responseFormat(String responseFormat) { + this.responseFormat = responseFormat; + return this; + } + + public Builder timeout(Duration timeout) { + this.timeout = timeout; + return this; + } + + public Builder maxRetries(Integer maxRetries) { + this.maxRetries = maxRetries; + return this; + } + + public Builder proxyOptions(ProxyOptions proxyOptions) { + this.proxyOptions = proxyOptions; + return this; + } + + public Builder logRequestsAndResponses(Boolean logRequestsAndResponses) { + this.logRequestsAndResponses = logRequestsAndResponses; + return this; + } + + public Builder openAIClient(OpenAIClient openAIClient) { + this.openAIClient = openAIClient; + return this; + } + + public AzureOpenAiImageModel build() { + if (openAIClient != null) { + return new AzureOpenAiImageModel( + openAIClient, + deploymentName, + quality, + size, + user, + style, + responseFormat); + } + return new AzureOpenAiImageModel( + endpoint, + serviceVersion, + apiKey, + deploymentName, + quality, + size, + user, + style, + responseFormat, + timeout, + maxRetries, + proxyOptions, + logRequestsAndResponses); + } + } +} diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java index f5331b27ce..dfcf46e2ef 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java +++ 
b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java @@ -17,6 +17,7 @@ import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolParameters; import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.image.Image; import dev.langchain4j.data.message.*; import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.TokenUsage; @@ -187,6 +188,12 @@ public static AiMessage aiMessageFrom(com.azure.ai.openai.models.ChatResponseMes } } + public static Image imageFrom(com.azure.ai.openai.models.ImageGenerationData imageGenerationData) { + return Image.builder() + .url(imageGenerationData.getUrl()) + .build(); + } + public static TokenUsage tokenUsageFrom(CompletionsUsage openAiUsage) { if (openAiUsage == null) { return null; diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java new file mode 100644 index 0000000000..059f2f1f89 --- /dev/null +++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java @@ -0,0 +1,34 @@ +package dev.langchain4j.model.azure; + +import dev.langchain4j.data.image.Image; +import dev.langchain4j.model.output.Response; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.assertj.core.api.Assertions.assertThat; + +public class AzureOpenAiImageModelIT { + + Logger logger = LoggerFactory.getLogger(AzureOpenAiImageModelIT.class); + + @Test + void should_generate_image() { + + AzureOpenAiImageModel model = AzureOpenAiImageModel.builder() + .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) + .serviceVersion("2023-06-01-preview") + .apiKey(System.getenv("AZURE_OPENAI_KEY")) + .deploymentName(System.getenv("AZURE_OPENAI_DEPLOYMENT_NAME")) + .logRequestsAndResponses(true) + 
.size("1024x1024") + .build(); + + Response response = model.generate("An image of a Java developer in Paris, France"); + + logger.info(response.toString()); + + assertThat(response.content().url()).isNotBlank(); + + } +} diff --git a/langchain4j-core/src/main/java/dev/langchain4j/data/image/Image.java b/langchain4j-core/src/main/java/dev/langchain4j/data/image/Image.java new file mode 100644 index 0000000000..ea00f96d4e --- /dev/null +++ b/langchain4j-core/src/main/java/dev/langchain4j/data/image/Image.java @@ -0,0 +1,36 @@ +package dev.langchain4j.data.image; + +public class Image { + + private String url; + + public String url() { + return url; + } + + public void url(String url) { + this.url = url; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String url; + + public Builder url(String url) { + this.url = url; + return this; + } + + public Image build() { + return new Image(this); + } + } + + private Image(Builder builder) { + this.url = builder.url; + } +} diff --git a/langchain4j-core/src/main/java/dev/langchain4j/model/image/ImageModel.java b/langchain4j-core/src/main/java/dev/langchain4j/model/image/ImageModel.java new file mode 100644 index 0000000000..8ef464de77 --- /dev/null +++ b/langchain4j-core/src/main/java/dev/langchain4j/model/image/ImageModel.java @@ -0,0 +1,8 @@ +package dev.langchain4j.model.image; + +import dev.langchain4j.data.image.Image; +import dev.langchain4j.model.output.Response; + +public interface ImageModel { + Response generate(String prompt); +} From 4c24646228443b4b6bcfdab1887989f69ecd1b3c Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Sat, 16 Dec 2023 15:50:06 +0100 Subject: [PATCH 02/30] Fix default deployment names --- .../dev/langchain4j/model/azure/AzureOpenAiChatModel.java | 3 +-- .../dev/langchain4j/model/azure/AzureOpenAiModelName.java | 4 ++++ .../model/azure/AzureOpenAiStreamingChatModel.java | 2 +- 
.../dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java | 2 -- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiChatModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiChatModel.java index 4c26f1e2e1..bfc227eacb 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiChatModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiChatModel.java @@ -19,7 +19,6 @@ import static dev.langchain4j.internal.Utils.getOrDefault; import static dev.langchain4j.model.azure.AzureOpenAiModelName.GPT_3_5_TURBO; -import static dev.langchain4j.model.azure.AzureOpenAiModelName.GPT_3_5_TURBO_0613; import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.*; import static java.util.Collections.singletonList; @@ -93,7 +92,7 @@ private AzureOpenAiChatModel(String deploymentName, Double presencePenalty, Double frequencyPenalty) { - this.deploymentName = getOrDefault(deploymentName, GPT_3_5_TURBO_0613); + this.deploymentName = getOrDefault(deploymentName, "gpt-35-turbo"); this.tokenizer = getOrDefault(tokenizer, new OpenAiTokenizer(GPT_3_5_TURBO)); this.temperature = getOrDefault(temperature, 0.7); this.topP = topP; diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiModelName.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiModelName.java index a7f99d558b..f99032bf80 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiModelName.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiModelName.java @@ -6,6 +6,7 @@ public class AzureOpenAiModelName { public static final String GPT_3_5_TURBO = "gpt-3.5-turbo"; // alias for the latest model public static final String GPT_3_5_TURBO_0301 = "gpt-3.5-turbo-0301"; // 4k context public 
static final String GPT_3_5_TURBO_0613 = "gpt-3.5-turbo-0613"; // 4k context, functions + public static final String GPT_3_5_TURBO_1106 = "gpt-3.5-turbo-1106"; // 16k context, functions public static final String GPT_3_5_TURBO_16K = "gpt-3.5-turbo-16k"; // alias for the latest model public static final String GPT_3_5_TURBO_16K_0613 = "gpt-3.5-turbo-16k-0613"; // 16k context, functions @@ -28,4 +29,7 @@ public class AzureOpenAiModelName { // Use with AzureOpenAiEmbeddingModel public static final String TEXT_EMBEDDING_ADA_002 = "text-embedding-ada-002"; + // Use with AzureOpenAiImageModel + public static final String DALL_E_3 = "DALL_E_3"; + } diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModel.java index 79be8a6dd7..ce6a032c91 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModel.java @@ -97,7 +97,7 @@ private AzureOpenAiStreamingChatModel(String deploymentName, Double presencePenalty, Double frequencyPenalty) { - this.deploymentName = getOrDefault(deploymentName, "gpt-35-turbo-0613"); + this.deploymentName = getOrDefault(deploymentName, "gpt-35-turbo"); this.tokenizer = getOrDefault(tokenizer, new OpenAiTokenizer(GPT_3_5_TURBO)); this.temperature = getOrDefault(temperature, 0.7); this.topP = topP; diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java index 68d3a155d3..941d1bcc0b 100644 --- a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java +++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java @@ 
-120,7 +120,6 @@ void should_call_function_with_argument() { String weather = String.format("The weather in %s is %d degrees %s.", weatherLocation.getLocation(), currentWeather, weatherLocation.getUnit()); - assertThat(weather).isEqualTo("The weather in Paris, France is 35 degrees celsius."); // Now that we know the function's result, we can call the model again with the result as input. @@ -130,7 +129,6 @@ void should_call_function_with_argument() { List chatMessages = new ArrayList<>(); chatMessages.add(systemMessage); chatMessages.add(userMessage); - chatMessages.add(aiMessage); chatMessages.add(toolExecutionResultMessage); Response response2 = model.generate(chatMessages); From a8dd95894b10f5ce58ffbd21f10282a441344af6 Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Sat, 16 Dec 2023 15:50:42 +0100 Subject: [PATCH 03/30] Finalize integration tests --- .../langchain4j/model/azure/AzureOpenAiImageModel.java | 5 ++--- .../langchain4j/model/azure/AzureOpenAiImageModelIT.java | 8 ++------ 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java index 062a50fde3..7b623ad096 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java @@ -10,7 +10,6 @@ import java.time.Duration; import static dev.langchain4j.internal.Utils.getOrDefault; -import static dev.langchain4j.model.azure.AzureOpenAiModelName.GPT_3_5_TURBO; import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.setupOpenAIClient; import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.*; @@ -55,9 +54,9 @@ public AzureOpenAiImageModel(String endpoint, } private AzureOpenAiImageModel(String deploymentName, String quality, String size, String user, 
String style, String responseFormat) { - this.deploymentName = getOrDefault(deploymentName, GPT_3_5_TURBO); + this.deploymentName = getOrDefault(deploymentName, "dall-e-3"); this.quality = getOrDefault(ImageGenerationQuality.fromString(quality), ImageGenerationQuality.STANDARD); - this.size = getOrDefault(ImageSize.fromString(size), ImageSize.SIZE256X256); + this.size = getOrDefault(ImageSize.fromString(size), ImageSize.SIZE1024X1024); this.user = getOrDefault(user, ""); this.style = getOrDefault(ImageGenerationStyle.fromString(style), ImageGenerationStyle.NATURAL); this.responseFormat = getOrDefault(ImageGenerationResponseFormat.fromString(responseFormat), ImageGenerationResponseFormat.URL); diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java index 059f2f1f89..de13f2aacd 100644 --- a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java +++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java @@ -17,18 +17,14 @@ void should_generate_image() { AzureOpenAiImageModel model = AzureOpenAiImageModel.builder() .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) - .serviceVersion("2023-06-01-preview") - .apiKey(System.getenv("AZURE_OPENAI_KEY")) .deploymentName(System.getenv("AZURE_OPENAI_DEPLOYMENT_NAME")) + .apiKey(System.getenv("AZURE_OPENAI_KEY")) .logRequestsAndResponses(true) - .size("1024x1024") .build(); - Response response = model.generate("An image of a Java developer in Paris, France"); + Response response = model.generate("A coffee mug in Paris, France"); logger.info(response.toString()); - assertThat(response.content().url()).isNotBlank(); - } } From c8aabbfd2079c2b494fd29b079a95ec56bfa6852 Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Mon, 18 Dec 2023 10:15:53 +0100 Subject: [PATCH 04/30] Update the API usage 
following code review from @Heezer --- .../model/azure/AzureOpenAiImageModel.java | 37 +++++++++++--- .../azure/InternalAzureOpenAiHelper.java | 21 ++++++-- .../model/azure/AzureOpenAiImageModelIT.java | 11 ++++- .../dev/langchain4j/data/image/Image.java | 49 ++++++++++++++++--- .../langchain4j/model/image/ImageModel.java | 3 +- 5 files changed, 100 insertions(+), 21 deletions(-) diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java index 7b623ad096..dba89764f7 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java @@ -8,15 +8,17 @@ import dev.langchain4j.model.output.Response; import java.time.Duration; +import java.util.List; import static dev.langchain4j.internal.Utils.getOrDefault; import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.setupOpenAIClient; -import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.*; +import static java.util.stream.Collectors.toList; public class AzureOpenAiImageModel implements ImageModel { private OpenAIClient client; private final String deploymentName; + private final Integer n; private final ImageGenerationQuality quality; private final ImageSize size; private final String user; @@ -25,13 +27,14 @@ public class AzureOpenAiImageModel implements ImageModel { public AzureOpenAiImageModel(OpenAIClient client, String deploymentName, + Integer n, String quality, String size, String user, String style, String responseFormat) { - this(deploymentName, quality, size, user, style, responseFormat); + this(deploymentName, n, quality, size, user, style, responseFormat); this.client = client; } @@ -39,6 +42,7 @@ public AzureOpenAiImageModel(String endpoint, String serviceVersion, String apiKey, String deploymentName, 
+ Integer n, String quality, String size, String user, @@ -49,12 +53,13 @@ public AzureOpenAiImageModel(String endpoint, ProxyOptions proxyOptions, boolean logRequestsAndResponses) { - this(deploymentName, quality, size, user, style, responseFormat); + this(deploymentName, n, quality, size, user, style, responseFormat); this.client = setupOpenAIClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, proxyOptions, logRequestsAndResponses); } - private AzureOpenAiImageModel(String deploymentName, String quality, String size, String user, String style, String responseFormat) { + private AzureOpenAiImageModel(String deploymentName, Integer n, String quality, String size, String user, String style, String responseFormat) { this.deploymentName = getOrDefault(deploymentName, "dall-e-3"); + this.n = getOrDefault(n, 1); this.quality = getOrDefault(ImageGenerationQuality.fromString(quality), ImageGenerationQuality.STANDARD); this.size = getOrDefault(ImageSize.fromString(size), ImageSize.SIZE1024X1024); this.user = getOrDefault(user, ""); @@ -63,11 +68,11 @@ private AzureOpenAiImageModel(String deploymentName, String quality, String size } @Override - public Response generate(String prompt) { + public Response> generate(String prompt) { ImageGenerationOptions options = new ImageGenerationOptions(prompt) .setModel(deploymentName) - .setN(1) + .setN(n) .setQuality(quality) .setSize(size) .setUser(user) @@ -76,7 +81,11 @@ public Response generate(String prompt) { ImageGenerations imageGenerations = client.getImageGenerations(deploymentName, options); - return Response.from(imageFrom(imageGenerations.getData().get(0))); + List images = imageGenerations.getData().stream() + .map(InternalAzureOpenAiHelper::imageFrom) + .collect(toList()); + + return Response.from(images); } public static Builder builder() { @@ -89,6 +98,7 @@ public static class Builder { private String serviceVersion; private String apiKey; private String deploymentName; + private Integer n; private String 
quality; private String size; private String user; @@ -144,6 +154,17 @@ public Builder deploymentName(String deploymentName) { return this; } + /** + * Sets the number of images to generate. This is an optional parameter. + * + * @param n The number of images to generate. + * @return builder + */ + public Builder setN(Integer n) { + this.n = n; + return this; + } + /** * Sets the quality of the image. This is an optional parameter. * @@ -229,6 +250,7 @@ public AzureOpenAiImageModel build() { return new AzureOpenAiImageModel( openAIClient, deploymentName, + n, quality, size, user, @@ -240,6 +262,7 @@ public AzureOpenAiImageModel build() { serviceVersion, apiKey, deploymentName, + n, quality, size, user, diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java index dfcf46e2ef..12a4eb344a 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java @@ -22,6 +22,10 @@ import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.TokenUsage; +import java.net.MalformedURLException; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.URL; import java.time.Duration; import java.util.*; @@ -189,9 +193,20 @@ public static AiMessage aiMessageFrom(com.azure.ai.openai.models.ChatResponseMes } public static Image imageFrom(com.azure.ai.openai.models.ImageGenerationData imageGenerationData) { - return Image.builder() - .url(imageGenerationData.getUrl()) - .build(); + Image.Builder imageBuilder = Image.builder() + .revisedPrompt(imageGenerationData.getRevisedPrompt()); + + String urlString = imageGenerationData.getUrl(); + if (urlString != null) { + try { + URI uri = new URI(urlString); + imageBuilder.url(uri); + } catch 
(URISyntaxException e) { + throw new RuntimeException(e); + } + } + + return imageBuilder.build(); } public static TokenUsage tokenUsageFrom(CompletionsUsage openAiUsage) { diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java index de13f2aacd..4f0b85f578 100644 --- a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java +++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java @@ -6,7 +6,10 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.List; + import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertNotNull; public class AzureOpenAiImageModelIT { @@ -20,11 +23,15 @@ void should_generate_image() { .deploymentName(System.getenv("AZURE_OPENAI_DEPLOYMENT_NAME")) .apiKey(System.getenv("AZURE_OPENAI_KEY")) .logRequestsAndResponses(true) + .setN(1) .build(); - Response response = model.generate("A coffee mug in Paris, France"); + Response> response = model.generate("A coffee mug in Paris, France"); logger.info(response.toString()); - assertThat(response.content().url()).isNotBlank(); + + Image image = response.content().get(0); + assertNotNull(image.revisedPrompt()); + assertNotNull(image.url()); } } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/data/image/Image.java b/langchain4j-core/src/main/java/dev/langchain4j/data/image/Image.java index ea00f96d4e..6622b3c7fa 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/data/image/Image.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/data/image/Image.java @@ -1,36 +1,69 @@ package dev.langchain4j.data.image; -public class Image { +import java.net.URI; +import java.util.Objects; - private String url; +public final class Image { - public String url() { + private URI url; 
+ private String revisedPrompt; + + public URI url() { return url; } - public void url(String url) { + public void url(URI url) { this.url = url; } + private Image(Builder builder) { + this.url = builder.url; + this.revisedPrompt = builder.revisedPrompt; + } + + public String revisedPrompt() { + return revisedPrompt; + } + public static Builder builder() { return new Builder(); } public static class Builder { - private String url; + private URI url; + private String revisedPrompt; - public Builder url(String url) { + public Builder url(URI url) { this.url = url; return this; } + public Builder revisedPrompt(String revisedPrompt) { + this.revisedPrompt = revisedPrompt; + return this; + } + public Image build() { return new Image(this); } } - private Image(Builder builder) { - this.url = builder.url; + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Image image = (Image) o; + return Objects.equals(url, image.url) && Objects.equals(revisedPrompt, image.revisedPrompt); + } + + @Override + public int hashCode() { + return Objects.hash(url, revisedPrompt); + } + + @Override + public String toString() { + return "Image{" + "url='" + url + '\'' + ", revisedPrompt='" + revisedPrompt + '\'' + '}'; } } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/model/image/ImageModel.java b/langchain4j-core/src/main/java/dev/langchain4j/model/image/ImageModel.java index 8ef464de77..3347c06ddd 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/model/image/ImageModel.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/model/image/ImageModel.java @@ -2,7 +2,8 @@ import dev.langchain4j.data.image.Image; import dev.langchain4j.model.output.Response; +import java.util.List; public interface ImageModel { - Response generate(String prompt); + Response> generate(String prompt); } From 1f28f0c0e8be182888b3ebfb6dbdb94fec9438f5 Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Mon, 
18 Dec 2023 13:58:11 +0100 Subject: [PATCH 05/30] Generate only 1 image, and support base64 data --- .../model/azure/AzureOpenAiImageModel.java | 44 +++++----------- .../azure/InternalAzureOpenAiHelper.java | 5 +- .../model/azure/AzureOpenAiImageModelIT.java | 52 ++++++++++++++++--- .../dev/langchain4j/data/image/Image.java | 34 +++++++++--- .../langchain4j/model/image/ImageModel.java | 3 +- 5 files changed, 90 insertions(+), 48 deletions(-) diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java index dba89764f7..499b0a5afd 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java @@ -7,18 +7,17 @@ import dev.langchain4j.model.image.ImageModel; import dev.langchain4j.model.output.Response; +import java.nio.file.Path; import java.time.Duration; -import java.util.List; import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.imageFrom; import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.setupOpenAIClient; -import static java.util.stream.Collectors.toList; public class AzureOpenAiImageModel implements ImageModel { private OpenAIClient client; private final String deploymentName; - private final Integer n; private final ImageGenerationQuality quality; private final ImageSize size; private final String user; @@ -27,14 +26,13 @@ public class AzureOpenAiImageModel implements ImageModel { public AzureOpenAiImageModel(OpenAIClient client, String deploymentName, - Integer n, String quality, String size, String user, String style, String responseFormat) { - this(deploymentName, n, quality, size, user, style, responseFormat); + this(deploymentName, quality, size, user, style, responseFormat); 
this.client = client; } @@ -42,7 +40,6 @@ public AzureOpenAiImageModel(String endpoint, String serviceVersion, String apiKey, String deploymentName, - Integer n, String quality, String size, String user, @@ -53,13 +50,12 @@ public AzureOpenAiImageModel(String endpoint, ProxyOptions proxyOptions, boolean logRequestsAndResponses) { - this(deploymentName, n, quality, size, user, style, responseFormat); + this(deploymentName, quality, size, user, style, responseFormat); this.client = setupOpenAIClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, proxyOptions, logRequestsAndResponses); } - private AzureOpenAiImageModel(String deploymentName, Integer n, String quality, String size, String user, String style, String responseFormat) { + private AzureOpenAiImageModel(String deploymentName, String quality, String size, String user, String style, String responseFormat) { this.deploymentName = getOrDefault(deploymentName, "dall-e-3"); - this.n = getOrDefault(n, 1); this.quality = getOrDefault(ImageGenerationQuality.fromString(quality), ImageGenerationQuality.STANDARD); this.size = getOrDefault(ImageSize.fromString(size), ImageSize.SIZE1024X1024); this.user = getOrDefault(user, ""); @@ -68,11 +64,10 @@ private AzureOpenAiImageModel(String deploymentName, Integer n, String quality, } @Override - public Response> generate(String prompt) { - + public Response generate(String prompt) { ImageGenerationOptions options = new ImageGenerationOptions(prompt) .setModel(deploymentName) - .setN(n) + .setN(1) .setQuality(quality) .setSize(size) .setUser(user) @@ -80,12 +75,8 @@ public Response> generate(String prompt) { .setResponseFormat(responseFormat); ImageGenerations imageGenerations = client.getImageGenerations(deploymentName, options); - - List images = imageGenerations.getData().stream() - .map(InternalAzureOpenAiHelper::imageFrom) - .collect(toList()); - - return Response.from(images); + Image image = imageFrom(imageGenerations.getData().get(0)); + return 
Response.from(image); } public static Builder builder() { @@ -98,7 +89,6 @@ public static class Builder { private String serviceVersion; private String apiKey; private String deploymentName; - private Integer n; private String quality; private String size; private String user; @@ -108,6 +98,9 @@ public static class Builder { private Integer maxRetries; private ProxyOptions proxyOptions; private boolean logRequestsAndResponses; + private boolean withPersisting; + + private Path persistTo; private OpenAIClient openAIClient; /** @@ -154,17 +147,6 @@ public Builder deploymentName(String deploymentName) { return this; } - /** - * Sets the number of images to generate. This is an optional parameter. - * - * @param n The number of images to generate. - * @return builder - */ - public Builder setN(Integer n) { - this.n = n; - return this; - } - /** * Sets the quality of the image. This is an optional parameter. * @@ -250,7 +232,6 @@ public AzureOpenAiImageModel build() { return new AzureOpenAiImageModel( openAIClient, deploymentName, - n, quality, size, user, @@ -262,7 +243,6 @@ public AzureOpenAiImageModel build() { serviceVersion, apiKey, deploymentName, - n, quality, size, user, diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java index 12a4eb344a..c602923a72 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java @@ -22,10 +22,8 @@ import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.TokenUsage; -import java.net.MalformedURLException; import java.net.URI; import java.net.URISyntaxException; -import java.net.URL; import java.time.Duration; import java.util.*; @@ -197,6 +195,7 @@ public static Image 
imageFrom(com.azure.ai.openai.models.ImageGenerationData ima .revisedPrompt(imageGenerationData.getRevisedPrompt()); String urlString = imageGenerationData.getUrl(); + String imageData = imageGenerationData.getBase64Data(); if (urlString != null) { try { URI uri = new URI(urlString); @@ -204,6 +203,8 @@ public static Image imageFrom(com.azure.ai.openai.models.ImageGenerationData ima } catch (URISyntaxException e) { throw new RuntimeException(e); } + } else if (imageData != null) { + imageBuilder.base64Data(imageData); } return imageBuilder.build(); diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java index 4f0b85f578..e6112e0d10 100644 --- a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java +++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java @@ -1,37 +1,75 @@ package dev.langchain4j.model.azure; +import com.azure.ai.openai.models.ImageGenerationResponseFormat; import dev.langchain4j.data.image.Image; import dev.langchain4j.model.output.Response; import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.List; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Base64; -import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; public class AzureOpenAiImageModelIT { Logger logger = LoggerFactory.getLogger(AzureOpenAiImageModelIT.class); @Test - void should_generate_image() { + void should_generate_image_with_url() { AzureOpenAiImageModel model = AzureOpenAiImageModel.builder() .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) .deploymentName(System.getenv("AZURE_OPENAI_DEPLOYMENT_NAME")) 
.apiKey(System.getenv("AZURE_OPENAI_KEY")) .logRequestsAndResponses(true) - .setN(1) .build(); - Response> response = model.generate("A coffee mug in Paris, France"); + Response response = model.generate("A coffee mug in Paris, France"); logger.info(response.toString()); - Image image = response.content().get(0); - assertNotNull(image.revisedPrompt()); + Image image = response.content(); + assertNotNull(image); assertNotNull(image.url()); + logger.info("The remote image is here: {}", image.url()); + + assertNull(image.base64Data()); + + assertNotNull(image.revisedPrompt()); + logger.info("The revised prompt is: {}", image.revisedPrompt()); + } + + @Test + void should_generate_image_in_base64() throws IOException { + AzureOpenAiImageModel model = AzureOpenAiImageModel.builder() + .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) + .deploymentName(System.getenv("AZURE_OPENAI_DEPLOYMENT_NAME")) + .apiKey(System.getenv("AZURE_OPENAI_KEY")) + .logRequestsAndResponses(false) // The image is big, so we don't want to log it by default + .responseFormat(ImageGenerationResponseFormat.BASE64.toString()) + .build(); + + Response response = model.generate("A croissant in Paris, France"); + + Image image = response.content(); + assertNotNull(image); + assertNull(image.url()); + assertNotNull(image.base64Data()); + logger.info("The image data is: {} characters", image.base64Data().length()); + + if (logger.isDebugEnabled()) { + byte[] decodedBytes = Base64.getDecoder().decode(response.content().base64Data()); + Path temp = Files.createTempFile("langchain4j", ".png"); + Files.write(temp, decodedBytes); + logger.debug("The image is here: {}", temp.toAbsolutePath()); + } + + assertNotNull(image.revisedPrompt()); + logger.info("The revised prompt is: {}", image.revisedPrompt()); } } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/data/image/Image.java b/langchain4j-core/src/main/java/dev/langchain4j/data/image/Image.java index 6622b3c7fa..817af68ff9 100644 --- 
a/langchain4j-core/src/main/java/dev/langchain4j/data/image/Image.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/data/image/Image.java @@ -6,6 +6,7 @@ public final class Image { private URI url; + private String base64Data; private String revisedPrompt; public URI url() { @@ -16,15 +17,28 @@ public void url(URI url) { this.url = url; } - private Image(Builder builder) { - this.url = builder.url; - this.revisedPrompt = builder.revisedPrompt; + public String base64Data() { + return base64Data; + } + + public void base64Data(String base64Data) { + this.base64Data = base64Data; } public String revisedPrompt() { return revisedPrompt; } + public void revisedPrompt(String revisedPrompt) { + this.revisedPrompt = revisedPrompt; + } + + private Image(Builder builder) { + this.url = builder.url; + this.base64Data = builder.base64Data; + this.revisedPrompt = builder.revisedPrompt; + } + public static Builder builder() { return new Builder(); } @@ -32,6 +46,7 @@ public static Builder builder() { public static class Builder { private URI url; + private String base64Data; private String revisedPrompt; public Builder url(URI url) { @@ -39,6 +54,11 @@ public Builder url(URI url) { return this; } + public Builder base64Data(String base64Data) { + this.base64Data = base64Data; + return this; + } + public Builder revisedPrompt(String revisedPrompt) { this.revisedPrompt = revisedPrompt; return this; @@ -54,16 +74,18 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Image image = (Image) o; - return Objects.equals(url, image.url) && Objects.equals(revisedPrompt, image.revisedPrompt); + return Objects.equals(url, image.url) && Objects.equals(base64Data, image.base64Data) && Objects.equals(revisedPrompt, image.revisedPrompt); } @Override public int hashCode() { - return Objects.hash(url, revisedPrompt); + return Objects.hash(url, base64Data, revisedPrompt); } @Override public String toString() { - return 
"Image{" + "url='" + url + '\'' + ", revisedPrompt='" + revisedPrompt + '\'' + '}'; + return "Image{" + "url='" + url + '\'' + + ", base64Data='" + base64Data + '\'' + + ", revisedPrompt='" + revisedPrompt + '\'' + '}'; } } diff --git a/langchain4j-core/src/main/java/dev/langchain4j/model/image/ImageModel.java b/langchain4j-core/src/main/java/dev/langchain4j/model/image/ImageModel.java index 3347c06ddd..dcccd29e1f 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/model/image/ImageModel.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/model/image/ImageModel.java @@ -5,5 +5,6 @@ import java.util.List; public interface ImageModel { - Response> generate(String prompt); + + Response generate(String prompt); } From 36bf7fae960605780ec2cc67b91cb381e5df6c49 Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Mon, 18 Dec 2023 16:16:33 +0100 Subject: [PATCH 06/30] Use default parameters from Azure OpenAI --- .../model/azure/AzureOpenAiImageModel.java | 30 ++++++++++++------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java index 499b0a5afd..8bad9339f4 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java @@ -18,11 +18,11 @@ public class AzureOpenAiImageModel implements ImageModel { private OpenAIClient client; private final String deploymentName; - private final ImageGenerationQuality quality; - private final ImageSize size; - private final String user; - private final ImageGenerationStyle style; - private final ImageGenerationResponseFormat responseFormat; + private ImageGenerationQuality quality = null; + private ImageSize size = null; + private String user = null; + private ImageGenerationStyle style 
= null; + private ImageGenerationResponseFormat responseFormat = null; public AzureOpenAiImageModel(OpenAIClient client, String deploymentName, @@ -56,11 +56,21 @@ public AzureOpenAiImageModel(String endpoint, private AzureOpenAiImageModel(String deploymentName, String quality, String size, String user, String style, String responseFormat) { this.deploymentName = getOrDefault(deploymentName, "dall-e-3"); - this.quality = getOrDefault(ImageGenerationQuality.fromString(quality), ImageGenerationQuality.STANDARD); - this.size = getOrDefault(ImageSize.fromString(size), ImageSize.SIZE1024X1024); - this.user = getOrDefault(user, ""); - this.style = getOrDefault(ImageGenerationStyle.fromString(style), ImageGenerationStyle.NATURAL); - this.responseFormat = getOrDefault(ImageGenerationResponseFormat.fromString(responseFormat), ImageGenerationResponseFormat.URL); + if (quality != null) { + this.quality = ImageGenerationQuality.fromString(quality); + } + if (size != null) { + this.size = ImageSize.fromString(size); + } + if (user != null) { + this.user = user; + } + if (style != null) { + this.style = ImageGenerationStyle.fromString(style); + } + if (responseFormat != null) { + this.responseFormat = ImageGenerationResponseFormat.fromString(responseFormat); + } } @Override From 9e8a32c9e90883d36d0446bba342d6ced5dceac5 Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Mon, 18 Dec 2023 20:23:54 +0100 Subject: [PATCH 07/30] Add documentation for using OpenAI Image generation --- .../model/azure/AzureOpenAiImageModel.java | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java index 8bad9339f4..e57c4f982b 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java +++ 
b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java @@ -14,6 +14,27 @@ import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.imageFrom; import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.setupOpenAIClient; +/** + * Represents an OpenAI image model, hosted on Azure, such as dall-e-3. + *

+ * You can find a tutorial on using Azure OpenAI to generate images at: https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=programming-language-java + *

+ * Mandatory parameters for initialization are: endpoint, serviceVersion, apiKey and deploymentName. + * You can also provide your own OpenAIClient instance, if you need more flexibility. + *

+ * There are two primary authentication methods to access Azure OpenAI: + *

+ * 1. API Key Authentication: For this type of authentication, HTTP requests must include the + * API Key in the "api-key" HTTP header as follows: `api-key: OPENAI_API_KEY` + *

+ * 2. Azure Active Directory Authentication: For this type of authentication, HTTP requests must include the + * authentication/access token in the "Authorization" HTTP header. + *

+ * More information + *

+ * Please note, that currently, only API Key authentication is supported by this class, + * second authentication option will be supported later. + */ public class AzureOpenAiImageModel implements ImageModel { private OpenAIClient client; From 08cbbf1fc0ae3fac0dd66a9dbef9828441ec6813 Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Mon, 18 Dec 2023 20:29:42 +0100 Subject: [PATCH 08/30] Remove withPersisting and persistTo --- .../dev/langchain4j/model/azure/AzureOpenAiImageModel.java | 3 --- 1 file changed, 3 deletions(-) diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java index e57c4f982b..46ea7bc61c 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java @@ -129,9 +129,6 @@ public static class Builder { private Integer maxRetries; private ProxyOptions proxyOptions; private boolean logRequestsAndResponses; - private boolean withPersisting; - - private Path persistTo; private OpenAIClient openAIClient; /** From 072bdb63f987cb9d39dca9a6f9fb8be90140649d Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Mon, 18 Dec 2023 20:35:57 +0100 Subject: [PATCH 09/30] Add enums in the Builder methods --- .../model/azure/AzureOpenAiImageModel.java | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java index 46ea7bc61c..510c919a5e 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java @@ -186,6 +186,17 @@ public Builder 
quality(String quality) { return this; } + /** + * Sets the quality of the image, using the ImageGenerationQuality enum. This is an optional parameter. + * + * @param imageGenerationQuality The quality of the image. + * @return builder + */ + public Builder quality(ImageGenerationQuality imageGenerationQuality) { + this.quality = imageGenerationQuality.toString(); + return this; + } + /** * Sets the size of the image. This is an optional parameter. * @@ -197,6 +208,17 @@ public Builder size(String size) { return this; } + /** + * Sets the size of the image, using the ImageSize enum. This is an optional parameter. + * + * @param imageSize The size of the image. + * @return builder + */ + public Builder size(ImageSize imageSize) { + this.size = imageSize.toString(); + return this; + } + /** * Sets the user of the image. This is an optional parameter. * @@ -219,6 +241,17 @@ public Builder style(String style) { return this; } + /** + * Sets the style of the image, using the ImageGenerationStyle enum. This is an optional parameter. + * + * @param imageGenerationStyle The style of the image. + * @return builder + */ + public Builder style(ImageGenerationStyle imageGenerationStyle) { + this.style = imageGenerationStyle.toString(); + return this; + } + /** * Sets the response format of the image. This is an optional parameter. * @@ -230,6 +263,17 @@ public Builder responseFormat(String responseFormat) { return this; } + /** + * Sets the response format of the image, using the ImageGenerationResponseFormat enum. This is an optional parameter. + * + * @param imageGenerationResponseFormat The response format of the image. 
+ * @return builder + */ + public Builder responseFormat(ImageGenerationResponseFormat imageGenerationResponseFormat) { + this.responseFormat = imageGenerationResponseFormat.toString(); + return this; + } + public Builder timeout(Duration timeout) { this.timeout = timeout; return this; From 7ed0f7c40839a86b1008c5ee4d444f88ee9769fb Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Mon, 18 Dec 2023 20:40:07 +0100 Subject: [PATCH 10/30] Use Utils.quoted --- .../main/java/dev/langchain4j/data/image/Image.java | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/langchain4j-core/src/main/java/dev/langchain4j/data/image/Image.java b/langchain4j-core/src/main/java/dev/langchain4j/data/image/Image.java index 817af68ff9..26fb1209ba 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/data/image/Image.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/data/image/Image.java @@ -1,5 +1,7 @@ package dev.langchain4j.data.image; +import static dev.langchain4j.internal.Utils.quoted; + import java.net.URI; import java.util.Objects; @@ -84,8 +86,10 @@ public int hashCode() { @Override public String toString() { - return "Image{" + "url='" + url + '\'' + - ", base64Data='" + base64Data + '\'' + - ", revisedPrompt='" + revisedPrompt + '\'' + '}'; + return "Image {" + + "url=" + quoted(url.toString()) + + ", base64Data=" + quoted(base64Data) + + ", revisedPrompt=" + quoted(revisedPrompt) + + " }"; } } From 9602f37de334a0e053aa3a7f2a57caf5902a43d8 Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Mon, 18 Dec 2023 20:43:58 +0100 Subject: [PATCH 11/30] Try to fix JavaDoc --- .../java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java index 510c919a5e..c9bdc184bf 100644 --- 
a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java @@ -17,7 +17,7 @@ /** * Represents an OpenAI image model, hosted on Azure, such as dall-e-3. *

- * You can find a tutorial on using Azure OpenAI to generate images at: https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=programming-language-java + * You can find a tutorial on using Azure OpenAI to generate images at: https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?pivots=programming-language-java *

* Mandatory parameters for initialization are: endpoint, serviceVersion, apiKey and deploymentName. * You can also provide your own OpenAIClient instance, if you need more flexibility. From 1b97059c359dddb932f15ced0a76e9384b899ea0 Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Tue, 19 Dec 2023 12:36:02 +0100 Subject: [PATCH 12/30] AiMessage cannot have a "null" text, so putting "" by default. --- .../dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java | 2 +- .../dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java index c602923a72..df83d8fcea 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java @@ -186,7 +186,7 @@ public static AiMessage aiMessageFrom(com.azure.ai.openai.models.ChatResponseMes .arguments(functionCall.getArguments()) .build(); - return aiMessage(toolExecutionRequest); + return new AiMessage("", Collections.singletonList(toolExecutionRequest)); } } diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java index 941d1bcc0b..a91e1eab5a 100644 --- a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java +++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java @@ -128,6 +128,7 @@ void should_call_function_with_argument() { List chatMessages = new ArrayList<>(); chatMessages.add(systemMessage); + chatMessages.add(aiMessage); chatMessages.add(userMessage); 
chatMessages.add(toolExecutionResultMessage); From cec63f867ce06e71030bd29174e5fe811818fef1 Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Fri, 15 Dec 2023 17:33:19 +0100 Subject: [PATCH 13/30] Add image generation support with Azure OpenAI --- .../model/azure/AzureOpenAiChatModel.java | 5 +- .../model/azure/AzureOpenAiImageModel.java | 255 ++++++++++++++++++ .../azure/InternalAzureOpenAiHelper.java | 7 + .../model/azure/AzureOpenAiImageModelIT.java | 34 +++ 4 files changed, 298 insertions(+), 3 deletions(-) create mode 100644 langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java create mode 100644 langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiChatModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiChatModel.java index 13b2b243f2..4c26f1e2e1 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiChatModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiChatModel.java @@ -19,6 +19,7 @@ import static dev.langchain4j.internal.Utils.getOrDefault; import static dev.langchain4j.model.azure.AzureOpenAiModelName.GPT_3_5_TURBO; +import static dev.langchain4j.model.azure.AzureOpenAiModelName.GPT_3_5_TURBO_0613; import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.*; import static java.util.Collections.singletonList; @@ -84,8 +85,6 @@ public AzureOpenAiChatModel(String endpoint, this.client = setupOpenAIClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, proxyOptions, logRequestsAndResponses); } - - private AzureOpenAiChatModel(String deploymentName, Tokenizer tokenizer, Double temperature, @@ -94,7 +93,7 @@ private AzureOpenAiChatModel(String deploymentName, Double presencePenalty, Double frequencyPenalty) { - this.deploymentName = 
getOrDefault(deploymentName, "gpt-35-turbo-0613"); + this.deploymentName = getOrDefault(deploymentName, GPT_3_5_TURBO_0613); this.tokenizer = getOrDefault(tokenizer, new OpenAiTokenizer(GPT_3_5_TURBO)); this.temperature = getOrDefault(temperature, 0.7); this.topP = topP; diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java new file mode 100644 index 0000000000..062a50fde3 --- /dev/null +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java @@ -0,0 +1,255 @@ +package dev.langchain4j.model.azure; + +import com.azure.ai.openai.OpenAIClient; +import com.azure.ai.openai.models.*; +import com.azure.core.http.ProxyOptions; +import dev.langchain4j.data.image.Image; +import dev.langchain4j.model.image.ImageModel; +import dev.langchain4j.model.output.Response; + +import java.time.Duration; + +import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.model.azure.AzureOpenAiModelName.GPT_3_5_TURBO; +import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.setupOpenAIClient; +import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.*; + +public class AzureOpenAiImageModel implements ImageModel { + + private OpenAIClient client; + private final String deploymentName; + private final ImageGenerationQuality quality; + private final ImageSize size; + private final String user; + private final ImageGenerationStyle style; + private final ImageGenerationResponseFormat responseFormat; + + public AzureOpenAiImageModel(OpenAIClient client, + String deploymentName, + String quality, + String size, + String user, + String style, + String responseFormat) { + + this(deploymentName, quality, size, user, style, responseFormat); + this.client = client; + } + + public AzureOpenAiImageModel(String endpoint, + String serviceVersion, + String apiKey, 
+ String deploymentName, + String quality, + String size, + String user, + String style, + String responseFormat, + Duration timeout, + Integer maxRetries, + ProxyOptions proxyOptions, + boolean logRequestsAndResponses) { + + this(deploymentName, quality, size, user, style, responseFormat); + this.client = setupOpenAIClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, proxyOptions, logRequestsAndResponses); + } + + private AzureOpenAiImageModel(String deploymentName, String quality, String size, String user, String style, String responseFormat) { + this.deploymentName = getOrDefault(deploymentName, GPT_3_5_TURBO); + this.quality = getOrDefault(ImageGenerationQuality.fromString(quality), ImageGenerationQuality.STANDARD); + this.size = getOrDefault(ImageSize.fromString(size), ImageSize.SIZE256X256); + this.user = getOrDefault(user, ""); + this.style = getOrDefault(ImageGenerationStyle.fromString(style), ImageGenerationStyle.NATURAL); + this.responseFormat = getOrDefault(ImageGenerationResponseFormat.fromString(responseFormat), ImageGenerationResponseFormat.URL); + } + + @Override + public Response generate(String prompt) { + + ImageGenerationOptions options = new ImageGenerationOptions(prompt) + .setModel(deploymentName) + .setN(1) + .setQuality(quality) + .setSize(size) + .setUser(user) + .setStyle(style) + .setResponseFormat(responseFormat); + + ImageGenerations imageGenerations = client.getImageGenerations(deploymentName, options); + + return Response.from(imageFrom(imageGenerations.getData().get(0))); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String endpoint; + private String serviceVersion; + private String apiKey; + private String deploymentName; + private String quality; + private String size; + private String user; + private String style; + private String responseFormat; + private Duration timeout; + private Integer maxRetries; + private ProxyOptions proxyOptions; + private 
boolean logRequestsAndResponses; + private OpenAIClient openAIClient; + + /** + * Sets the Azure OpenAI endpoint. This is a mandatory parameter. + * + * @param endpoint The Azure OpenAI endpoint in the format: https://{resource}.openai.azure.com/ + * @return builder + */ + public Builder endpoint(String endpoint) { + this.endpoint = endpoint; + return this; + } + + /** + * Sets the Azure OpenAI API service version. This is a mandatory parameter. + * + * @param serviceVersion The Azure OpenAI API service version in the format: 2023-05-15 + * @return builder + */ + public Builder serviceVersion(String serviceVersion) { + this.serviceVersion = serviceVersion; + return this; + } + + /** + * Sets the Azure OpenAI API key. This is a mandatory parameter. + * + * @param apiKey The Azure OpenAI API key. + * @return builder + */ + public Builder apiKey(String apiKey) { + this.apiKey = apiKey; + return this; + } + + /** + * Sets the deployment name in Azure OpenAI. This is a mandatory parameter. + * + * @param deploymentName The Deployment name. + * @return builder + */ + public Builder deploymentName(String deploymentName) { + this.deploymentName = deploymentName; + return this; + } + + /** + * Sets the quality of the image. This is an optional parameter. + * + * @param quality The quality of the image. + * @return builder + */ + public Builder quality(String quality) { + this.quality = quality; + return this; + } + + /** + * Sets the size of the image. This is an optional parameter. + * + * @param size The size of the image. + * @return builder + */ + public Builder size(String size) { + this.size = size; + return this; + } + + /** + * Sets the user of the image. This is an optional parameter. + * + * @param user The user of the image. + * @return builder + */ + public Builder user(String user) { + this.user = user; + return this; + } + + /** + * Sets the style of the image. This is an optional parameter. + * + * @param style The style of the image. 
+ * @return builder + */ + public Builder style(String style) { + this.style = style; + return this; + } + + /** + * Sets the response format of the image. This is an optional parameter. + * + * @param responseFormat The response format of the image. + * @return builder + */ + public Builder responseFormat(String responseFormat) { + this.responseFormat = responseFormat; + return this; + } + + public Builder timeout(Duration timeout) { + this.timeout = timeout; + return this; + } + + public Builder maxRetries(Integer maxRetries) { + this.maxRetries = maxRetries; + return this; + } + + public Builder proxyOptions(ProxyOptions proxyOptions) { + this.proxyOptions = proxyOptions; + return this; + } + + public Builder logRequestsAndResponses(Boolean logRequestsAndResponses) { + this.logRequestsAndResponses = logRequestsAndResponses; + return this; + } + + public Builder openAIClient(OpenAIClient openAIClient) { + this.openAIClient = openAIClient; + return this; + } + + public AzureOpenAiImageModel build() { + if (openAIClient != null) { + return new AzureOpenAiImageModel( + openAIClient, + deploymentName, + quality, + size, + user, + style, + responseFormat); + } + return new AzureOpenAiImageModel( + endpoint, + serviceVersion, + apiKey, + deploymentName, + quality, + size, + user, + style, + responseFormat, + timeout, + maxRetries, + proxyOptions, + logRequestsAndResponses); + } + } +} diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java index f5331b27ce..dfcf46e2ef 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java @@ -17,6 +17,7 @@ import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolParameters; import 
dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.image.Image; import dev.langchain4j.data.message.*; import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.TokenUsage; @@ -187,6 +188,12 @@ public static AiMessage aiMessageFrom(com.azure.ai.openai.models.ChatResponseMes } } + public static Image imageFrom(com.azure.ai.openai.models.ImageGenerationData imageGenerationData) { + return Image.builder() + .url(imageGenerationData.getUrl()) + .build(); + } + public static TokenUsage tokenUsageFrom(CompletionsUsage openAiUsage) { if (openAiUsage == null) { return null; diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java new file mode 100644 index 0000000000..059f2f1f89 --- /dev/null +++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java @@ -0,0 +1,34 @@ +package dev.langchain4j.model.azure; + +import dev.langchain4j.data.image.Image; +import dev.langchain4j.model.output.Response; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.assertj.core.api.Assertions.assertThat; + +public class AzureOpenAiImageModelIT { + + Logger logger = LoggerFactory.getLogger(AzureOpenAiImageModelIT.class); + + @Test + void should_generate_image() { + + AzureOpenAiImageModel model = AzureOpenAiImageModel.builder() + .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) + .serviceVersion("2023-06-01-preview") + .apiKey(System.getenv("AZURE_OPENAI_KEY")) + .deploymentName(System.getenv("AZURE_OPENAI_DEPLOYMENT_NAME")) + .logRequestsAndResponses(true) + .size("1024x1024") + .build(); + + Response response = model.generate("An image of a Java developer in Paris, France"); + + logger.info(response.toString()); + + assertThat(response.content().url()).isNotBlank(); + + } +} From 
6997c05e389308d9c1c6770b87e1a492ce0d0344 Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Sat, 16 Dec 2023 15:50:06 +0100 Subject: [PATCH 14/30] Fix default deployment names --- .../dev/langchain4j/model/azure/AzureOpenAiChatModel.java | 3 +-- .../dev/langchain4j/model/azure/AzureOpenAiModelName.java | 4 ++++ .../model/azure/AzureOpenAiStreamingChatModel.java | 2 +- .../dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java | 2 -- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiChatModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiChatModel.java index 4c26f1e2e1..bfc227eacb 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiChatModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiChatModel.java @@ -19,7 +19,6 @@ import static dev.langchain4j.internal.Utils.getOrDefault; import static dev.langchain4j.model.azure.AzureOpenAiModelName.GPT_3_5_TURBO; -import static dev.langchain4j.model.azure.AzureOpenAiModelName.GPT_3_5_TURBO_0613; import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.*; import static java.util.Collections.singletonList; @@ -93,7 +92,7 @@ private AzureOpenAiChatModel(String deploymentName, Double presencePenalty, Double frequencyPenalty) { - this.deploymentName = getOrDefault(deploymentName, GPT_3_5_TURBO_0613); + this.deploymentName = getOrDefault(deploymentName, "gpt-35-turbo"); this.tokenizer = getOrDefault(tokenizer, new OpenAiTokenizer(GPT_3_5_TURBO)); this.temperature = getOrDefault(temperature, 0.7); this.topP = topP; diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiModelName.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiModelName.java index a7f99d558b..f99032bf80 100644 --- 
a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiModelName.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiModelName.java @@ -6,6 +6,7 @@ public class AzureOpenAiModelName { public static final String GPT_3_5_TURBO = "gpt-3.5-turbo"; // alias for the latest model public static final String GPT_3_5_TURBO_0301 = "gpt-3.5-turbo-0301"; // 4k context public static final String GPT_3_5_TURBO_0613 = "gpt-3.5-turbo-0613"; // 4k context, functions + public static final String GPT_3_5_TURBO_1106 = "gpt-3.5-turbo-1106"; // 16k context, functions public static final String GPT_3_5_TURBO_16K = "gpt-3.5-turbo-16k"; // alias for the latest model public static final String GPT_3_5_TURBO_16K_0613 = "gpt-3.5-turbo-16k-0613"; // 16k context, functions @@ -28,4 +29,7 @@ public class AzureOpenAiModelName { // Use with AzureOpenAiEmbeddingModel public static final String TEXT_EMBEDDING_ADA_002 = "text-embedding-ada-002"; + // Use with AzureOpenAiImageModel + public static final String DALL_E_3 = "DALL_E_3"; + } diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModel.java index 79be8a6dd7..ce6a032c91 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModel.java @@ -97,7 +97,7 @@ private AzureOpenAiStreamingChatModel(String deploymentName, Double presencePenalty, Double frequencyPenalty) { - this.deploymentName = getOrDefault(deploymentName, "gpt-35-turbo-0613"); + this.deploymentName = getOrDefault(deploymentName, "gpt-35-turbo"); this.tokenizer = getOrDefault(tokenizer, new OpenAiTokenizer(GPT_3_5_TURBO)); this.temperature = getOrDefault(temperature, 0.7); this.topP = topP; diff --git 
a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java index 68d3a155d3..941d1bcc0b 100644 --- a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java +++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java @@ -120,7 +120,6 @@ void should_call_function_with_argument() { String weather = String.format("The weather in %s is %d degrees %s.", weatherLocation.getLocation(), currentWeather, weatherLocation.getUnit()); - assertThat(weather).isEqualTo("The weather in Paris, France is 35 degrees celsius."); // Now that we know the function's result, we can call the model again with the result as input. @@ -130,7 +129,6 @@ void should_call_function_with_argument() { List chatMessages = new ArrayList<>(); chatMessages.add(systemMessage); chatMessages.add(userMessage); - chatMessages.add(aiMessage); chatMessages.add(toolExecutionResultMessage); Response response2 = model.generate(chatMessages); From fc8fa212c57faf61b235ca84f8148749f51465ff Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Sat, 16 Dec 2023 15:50:42 +0100 Subject: [PATCH 15/30] Finalize integration tests --- .../langchain4j/model/azure/AzureOpenAiImageModel.java | 5 ++--- .../langchain4j/model/azure/AzureOpenAiImageModelIT.java | 8 ++------ 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java index 062a50fde3..7b623ad096 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java @@ -10,7 +10,6 @@ import java.time.Duration; import static 
dev.langchain4j.internal.Utils.getOrDefault; -import static dev.langchain4j.model.azure.AzureOpenAiModelName.GPT_3_5_TURBO; import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.setupOpenAIClient; import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.*; @@ -55,9 +54,9 @@ public AzureOpenAiImageModel(String endpoint, } private AzureOpenAiImageModel(String deploymentName, String quality, String size, String user, String style, String responseFormat) { - this.deploymentName = getOrDefault(deploymentName, GPT_3_5_TURBO); + this.deploymentName = getOrDefault(deploymentName, "dall-e-3"); this.quality = getOrDefault(ImageGenerationQuality.fromString(quality), ImageGenerationQuality.STANDARD); - this.size = getOrDefault(ImageSize.fromString(size), ImageSize.SIZE256X256); + this.size = getOrDefault(ImageSize.fromString(size), ImageSize.SIZE1024X1024); this.user = getOrDefault(user, ""); this.style = getOrDefault(ImageGenerationStyle.fromString(style), ImageGenerationStyle.NATURAL); this.responseFormat = getOrDefault(ImageGenerationResponseFormat.fromString(responseFormat), ImageGenerationResponseFormat.URL); diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java index 059f2f1f89..de13f2aacd 100644 --- a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java +++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java @@ -17,18 +17,14 @@ void should_generate_image() { AzureOpenAiImageModel model = AzureOpenAiImageModel.builder() .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) - .serviceVersion("2023-06-01-preview") - .apiKey(System.getenv("AZURE_OPENAI_KEY")) .deploymentName(System.getenv("AZURE_OPENAI_DEPLOYMENT_NAME")) + .apiKey(System.getenv("AZURE_OPENAI_KEY")) .logRequestsAndResponses(true) - 
.size("1024x1024") .build(); - Response response = model.generate("An image of a Java developer in Paris, France"); + Response response = model.generate("A coffee mug in Paris, France"); logger.info(response.toString()); - assertThat(response.content().url()).isNotBlank(); - } } From 77d69b266135c3f8fe9850c6f0d327ac1cb657be Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Mon, 18 Dec 2023 10:15:53 +0100 Subject: [PATCH 16/30] Update the API usage following code review from @Heezer --- .../model/azure/AzureOpenAiImageModel.java | 37 +++++++++++++++---- .../azure/InternalAzureOpenAiHelper.java | 21 +++++++++-- .../model/azure/AzureOpenAiImageModelIT.java | 11 +++++- 3 files changed, 57 insertions(+), 12 deletions(-) diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java index 7b623ad096..dba89764f7 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java @@ -8,15 +8,17 @@ import dev.langchain4j.model.output.Response; import java.time.Duration; +import java.util.List; import static dev.langchain4j.internal.Utils.getOrDefault; import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.setupOpenAIClient; -import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.*; +import static java.util.stream.Collectors.toList; public class AzureOpenAiImageModel implements ImageModel { private OpenAIClient client; private final String deploymentName; + private final Integer n; private final ImageGenerationQuality quality; private final ImageSize size; private final String user; @@ -25,13 +27,14 @@ public class AzureOpenAiImageModel implements ImageModel { public AzureOpenAiImageModel(OpenAIClient client, String deploymentName, + Integer n, String quality, String 
size, String user, String style, String responseFormat) { - this(deploymentName, quality, size, user, style, responseFormat); + this(deploymentName, n, quality, size, user, style, responseFormat); this.client = client; } @@ -39,6 +42,7 @@ public AzureOpenAiImageModel(String endpoint, String serviceVersion, String apiKey, String deploymentName, + Integer n, String quality, String size, String user, @@ -49,12 +53,13 @@ public AzureOpenAiImageModel(String endpoint, ProxyOptions proxyOptions, boolean logRequestsAndResponses) { - this(deploymentName, quality, size, user, style, responseFormat); + this(deploymentName, n, quality, size, user, style, responseFormat); this.client = setupOpenAIClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, proxyOptions, logRequestsAndResponses); } - private AzureOpenAiImageModel(String deploymentName, String quality, String size, String user, String style, String responseFormat) { + private AzureOpenAiImageModel(String deploymentName, Integer n, String quality, String size, String user, String style, String responseFormat) { this.deploymentName = getOrDefault(deploymentName, "dall-e-3"); + this.n = getOrDefault(n, 1); this.quality = getOrDefault(ImageGenerationQuality.fromString(quality), ImageGenerationQuality.STANDARD); this.size = getOrDefault(ImageSize.fromString(size), ImageSize.SIZE1024X1024); this.user = getOrDefault(user, ""); @@ -63,11 +68,11 @@ private AzureOpenAiImageModel(String deploymentName, String quality, String size } @Override - public Response generate(String prompt) { + public Response> generate(String prompt) { ImageGenerationOptions options = new ImageGenerationOptions(prompt) .setModel(deploymentName) - .setN(1) + .setN(n) .setQuality(quality) .setSize(size) .setUser(user) @@ -76,7 +81,11 @@ public Response generate(String prompt) { ImageGenerations imageGenerations = client.getImageGenerations(deploymentName, options); - return Response.from(imageFrom(imageGenerations.getData().get(0))); + List images 
= imageGenerations.getData().stream() + .map(InternalAzureOpenAiHelper::imageFrom) + .collect(toList()); + + return Response.from(images); } public static Builder builder() { @@ -89,6 +98,7 @@ public static class Builder { private String serviceVersion; private String apiKey; private String deploymentName; + private Integer n; private String quality; private String size; private String user; @@ -144,6 +154,17 @@ public Builder deploymentName(String deploymentName) { return this; } + /** + * Sets the number of images to generate. This is an optional parameter. + * + * @param n The number of images to generate. + * @return builder + */ + public Builder setN(Integer n) { + this.n = n; + return this; + } + /** * Sets the quality of the image. This is an optional parameter. * @@ -229,6 +250,7 @@ public AzureOpenAiImageModel build() { return new AzureOpenAiImageModel( openAIClient, deploymentName, + n, quality, size, user, @@ -240,6 +262,7 @@ public AzureOpenAiImageModel build() { serviceVersion, apiKey, deploymentName, + n, quality, size, user, diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java index dfcf46e2ef..12a4eb344a 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java @@ -22,6 +22,10 @@ import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.TokenUsage; +import java.net.MalformedURLException; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.URL; import java.time.Duration; import java.util.*; @@ -189,9 +193,20 @@ public static AiMessage aiMessageFrom(com.azure.ai.openai.models.ChatResponseMes } public static Image imageFrom(com.azure.ai.openai.models.ImageGenerationData 
imageGenerationData) { - return Image.builder() - .url(imageGenerationData.getUrl()) - .build(); + Image.Builder imageBuilder = Image.builder() + .revisedPrompt(imageGenerationData.getRevisedPrompt()); + + String urlString = imageGenerationData.getUrl(); + if (urlString != null) { + try { + URI uri = new URI(urlString); + imageBuilder.url(uri); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + return imageBuilder.build(); } public static TokenUsage tokenUsageFrom(CompletionsUsage openAiUsage) { diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java index de13f2aacd..4f0b85f578 100644 --- a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java +++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java @@ -6,7 +6,10 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.List; + import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertNotNull; public class AzureOpenAiImageModelIT { @@ -20,11 +23,15 @@ void should_generate_image() { .deploymentName(System.getenv("AZURE_OPENAI_DEPLOYMENT_NAME")) .apiKey(System.getenv("AZURE_OPENAI_KEY")) .logRequestsAndResponses(true) + .setN(1) .build(); - Response response = model.generate("A coffee mug in Paris, France"); + Response> response = model.generate("A coffee mug in Paris, France"); logger.info(response.toString()); - assertThat(response.content().url()).isNotBlank(); + + Image image = response.content().get(0); + assertNotNull(image.revisedPrompt()); + assertNotNull(image.url()); } } From 4d5611a98aee5cb3f8b2e416b90837dcc17b81b0 Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Mon, 18 Dec 2023 13:58:11 +0100 Subject: [PATCH 17/30] Generate only 1 image, and support base64 data 
--- .../model/azure/AzureOpenAiImageModel.java | 44 +++++----------- .../azure/InternalAzureOpenAiHelper.java | 5 +- .../model/azure/AzureOpenAiImageModelIT.java | 52 ++++++++++++++++--- 3 files changed, 60 insertions(+), 41 deletions(-) diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java index dba89764f7..499b0a5afd 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java @@ -7,18 +7,17 @@ import dev.langchain4j.model.image.ImageModel; import dev.langchain4j.model.output.Response; +import java.nio.file.Path; import java.time.Duration; -import java.util.List; import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.imageFrom; import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.setupOpenAIClient; -import static java.util.stream.Collectors.toList; public class AzureOpenAiImageModel implements ImageModel { private OpenAIClient client; private final String deploymentName; - private final Integer n; private final ImageGenerationQuality quality; private final ImageSize size; private final String user; @@ -27,14 +26,13 @@ public class AzureOpenAiImageModel implements ImageModel { public AzureOpenAiImageModel(OpenAIClient client, String deploymentName, - Integer n, String quality, String size, String user, String style, String responseFormat) { - this(deploymentName, n, quality, size, user, style, responseFormat); + this(deploymentName, quality, size, user, style, responseFormat); this.client = client; } @@ -42,7 +40,6 @@ public AzureOpenAiImageModel(String endpoint, String serviceVersion, String apiKey, String deploymentName, - Integer n, String quality, String size, String user, @@ 
-53,13 +50,12 @@ public AzureOpenAiImageModel(String endpoint, ProxyOptions proxyOptions, boolean logRequestsAndResponses) { - this(deploymentName, n, quality, size, user, style, responseFormat); + this(deploymentName, quality, size, user, style, responseFormat); this.client = setupOpenAIClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, proxyOptions, logRequestsAndResponses); } - private AzureOpenAiImageModel(String deploymentName, Integer n, String quality, String size, String user, String style, String responseFormat) { + private AzureOpenAiImageModel(String deploymentName, String quality, String size, String user, String style, String responseFormat) { this.deploymentName = getOrDefault(deploymentName, "dall-e-3"); - this.n = getOrDefault(n, 1); this.quality = getOrDefault(ImageGenerationQuality.fromString(quality), ImageGenerationQuality.STANDARD); this.size = getOrDefault(ImageSize.fromString(size), ImageSize.SIZE1024X1024); this.user = getOrDefault(user, ""); @@ -68,11 +64,10 @@ private AzureOpenAiImageModel(String deploymentName, Integer n, String quality, } @Override - public Response> generate(String prompt) { - + public Response generate(String prompt) { ImageGenerationOptions options = new ImageGenerationOptions(prompt) .setModel(deploymentName) - .setN(n) + .setN(1) .setQuality(quality) .setSize(size) .setUser(user) @@ -80,12 +75,8 @@ public Response> generate(String prompt) { .setResponseFormat(responseFormat); ImageGenerations imageGenerations = client.getImageGenerations(deploymentName, options); - - List images = imageGenerations.getData().stream() - .map(InternalAzureOpenAiHelper::imageFrom) - .collect(toList()); - - return Response.from(images); + Image image = imageFrom(imageGenerations.getData().get(0)); + return Response.from(image); } public static Builder builder() { @@ -98,7 +89,6 @@ public static class Builder { private String serviceVersion; private String apiKey; private String deploymentName; - private Integer n; private 
String quality; private String size; private String user; @@ -108,6 +98,9 @@ public static class Builder { private Integer maxRetries; private ProxyOptions proxyOptions; private boolean logRequestsAndResponses; + private boolean withPersisting; + + private Path persistTo; private OpenAIClient openAIClient; /** @@ -154,17 +147,6 @@ public Builder deploymentName(String deploymentName) { return this; } - /** - * Sets the number of images to generate. This is an optional parameter. - * - * @param n The number of images to generate. - * @return builder - */ - public Builder setN(Integer n) { - this.n = n; - return this; - } - /** * Sets the quality of the image. This is an optional parameter. * @@ -250,7 +232,6 @@ public AzureOpenAiImageModel build() { return new AzureOpenAiImageModel( openAIClient, deploymentName, - n, quality, size, user, @@ -262,7 +243,6 @@ public AzureOpenAiImageModel build() { serviceVersion, apiKey, deploymentName, - n, quality, size, user, diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java index 12a4eb344a..c602923a72 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java @@ -22,10 +22,8 @@ import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.TokenUsage; -import java.net.MalformedURLException; import java.net.URI; import java.net.URISyntaxException; -import java.net.URL; import java.time.Duration; import java.util.*; @@ -197,6 +195,7 @@ public static Image imageFrom(com.azure.ai.openai.models.ImageGenerationData ima .revisedPrompt(imageGenerationData.getRevisedPrompt()); String urlString = imageGenerationData.getUrl(); + String imageData = imageGenerationData.getBase64Data(); if (urlString != 
null) { try { URI uri = new URI(urlString); @@ -204,6 +203,8 @@ public static Image imageFrom(com.azure.ai.openai.models.ImageGenerationData ima } catch (URISyntaxException e) { throw new RuntimeException(e); } + } else if (imageData != null) { + imageBuilder.base64Data(imageData); } return imageBuilder.build(); diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java index 4f0b85f578..e6112e0d10 100644 --- a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java +++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java @@ -1,37 +1,75 @@ package dev.langchain4j.model.azure; +import com.azure.ai.openai.models.ImageGenerationResponseFormat; import dev.langchain4j.data.image.Image; import dev.langchain4j.model.output.Response; import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.List; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Base64; -import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; public class AzureOpenAiImageModelIT { Logger logger = LoggerFactory.getLogger(AzureOpenAiImageModelIT.class); @Test - void should_generate_image() { + void should_generate_image_with_url() { AzureOpenAiImageModel model = AzureOpenAiImageModel.builder() .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) .deploymentName(System.getenv("AZURE_OPENAI_DEPLOYMENT_NAME")) .apiKey(System.getenv("AZURE_OPENAI_KEY")) .logRequestsAndResponses(true) - .setN(1) .build(); - Response> response = model.generate("A coffee mug in Paris, France"); + Response response = model.generate("A coffee mug in Paris, France"); 
logger.info(response.toString()); - Image image = response.content().get(0); - assertNotNull(image.revisedPrompt()); + Image image = response.content(); + assertNotNull(image); assertNotNull(image.url()); + logger.info("The remote image is here: {}", image.url()); + + assertNull(image.base64Data()); + + assertNotNull(image.revisedPrompt()); + logger.info("The revised prompt is: {}", image.revisedPrompt()); + } + + @Test + void should_generate_image_in_base64() throws IOException { + AzureOpenAiImageModel model = AzureOpenAiImageModel.builder() + .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) + .deploymentName(System.getenv("AZURE_OPENAI_DEPLOYMENT_NAME")) + .apiKey(System.getenv("AZURE_OPENAI_KEY")) + .logRequestsAndResponses(false) // The image is big, so we don't want to log it by default + .responseFormat(ImageGenerationResponseFormat.BASE64.toString()) + .build(); + + Response response = model.generate("A croissant in Paris, France"); + + Image image = response.content(); + assertNotNull(image); + assertNull(image.url()); + assertNotNull(image.base64Data()); + logger.info("The image data is: {} characters", image.base64Data().length()); + + if (logger.isDebugEnabled()) { + byte[] decodedBytes = Base64.getDecoder().decode(response.content().base64Data()); + Path temp = Files.createTempFile("langchain4j", ".png"); + Files.write(temp, decodedBytes); + logger.debug("The image is here: {}", temp.toAbsolutePath()); + } + + assertNotNull(image.revisedPrompt()); + logger.info("The revised prompt is: {}", image.revisedPrompt()); } } From da89c29b4c08f71a865adc05d398b99ab7c0f5b6 Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Mon, 18 Dec 2023 16:16:33 +0100 Subject: [PATCH 18/30] Use default parameters from Azure OpenAI --- .../model/azure/AzureOpenAiImageModel.java | 30 ++++++++++++------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java 
b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java index 499b0a5afd..8bad9339f4 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java @@ -18,11 +18,11 @@ public class AzureOpenAiImageModel implements ImageModel { private OpenAIClient client; private final String deploymentName; - private final ImageGenerationQuality quality; - private final ImageSize size; - private final String user; - private final ImageGenerationStyle style; - private final ImageGenerationResponseFormat responseFormat; + private ImageGenerationQuality quality = null; + private ImageSize size = null; + private String user = null; + private ImageGenerationStyle style = null; + private ImageGenerationResponseFormat responseFormat = null; public AzureOpenAiImageModel(OpenAIClient client, String deploymentName, @@ -56,11 +56,21 @@ public AzureOpenAiImageModel(String endpoint, private AzureOpenAiImageModel(String deploymentName, String quality, String size, String user, String style, String responseFormat) { this.deploymentName = getOrDefault(deploymentName, "dall-e-3"); - this.quality = getOrDefault(ImageGenerationQuality.fromString(quality), ImageGenerationQuality.STANDARD); - this.size = getOrDefault(ImageSize.fromString(size), ImageSize.SIZE1024X1024); - this.user = getOrDefault(user, ""); - this.style = getOrDefault(ImageGenerationStyle.fromString(style), ImageGenerationStyle.NATURAL); - this.responseFormat = getOrDefault(ImageGenerationResponseFormat.fromString(responseFormat), ImageGenerationResponseFormat.URL); + if (quality != null) { + this.quality = ImageGenerationQuality.fromString(quality); + } + if (size != null) { + this.size = ImageSize.fromString(size); + } + if (user != null) { + this.user = user; + } + if (style != null) { + this.style = ImageGenerationStyle.fromString(style); + } + if 
(responseFormat != null) { + this.responseFormat = ImageGenerationResponseFormat.fromString(responseFormat); + } } @Override From 2753367665328e59935db7418634d22fc70a6fa8 Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Mon, 18 Dec 2023 20:23:54 +0100 Subject: [PATCH 19/30] Add documentation for using OpenAI Image generation --- .../model/azure/AzureOpenAiImageModel.java | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java index 8bad9339f4..e57c4f982b 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java @@ -14,6 +14,27 @@ import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.imageFrom; import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.setupOpenAIClient; +/** + * Represents an OpenAI image model, hosted on Azure, such as dall-e-3. + *

+ * You can find a tutorial on using Azure OpenAI to generate images at: https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=programming-language-java + *

+ * Mandatory parameters for initialization are: endpoint, serviceVersion, apiKey and deploymentName. + * You can also provide your own OpenAIClient instance, if you need more flexibility. + *

+ * There are two primary authentication methods to access Azure OpenAI: + *

+ * 1. API Key Authentication: For this type of authentication, HTTP requests must include the + * API Key in the "api-key" HTTP header as follows: `api-key: OPENAI_API_KEY` + *

+ * 2. Azure Active Directory Authentication: For this type of authentication, HTTP requests must include the + * authentication/access token in the "Authorization" HTTP header. + *

+ * More information + *

+ * Please note, that currently, only API Key authentication is supported by this class, + * second authentication option will be supported later. + */ public class AzureOpenAiImageModel implements ImageModel { private OpenAIClient client; From 3d03312a211e80d704c6308a1bbf02b3ad88dbec Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Mon, 18 Dec 2023 20:29:42 +0100 Subject: [PATCH 20/30] Remove withPersisting and persistTo --- .../dev/langchain4j/model/azure/AzureOpenAiImageModel.java | 3 --- 1 file changed, 3 deletions(-) diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java index e57c4f982b..46ea7bc61c 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java @@ -129,9 +129,6 @@ public static class Builder { private Integer maxRetries; private ProxyOptions proxyOptions; private boolean logRequestsAndResponses; - private boolean withPersisting; - - private Path persistTo; private OpenAIClient openAIClient; /** From 788263725259072e4710d9756febf94242aea619 Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Mon, 18 Dec 2023 20:35:57 +0100 Subject: [PATCH 21/30] Add enums in the Builder methods --- .../model/azure/AzureOpenAiImageModel.java | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java index 46ea7bc61c..510c919a5e 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java @@ -186,6 +186,17 @@ public Builder 
quality(String quality) { return this; } + /** + * Sets the quality of the image, using the ImageGenerationQuality enum. This is an optional parameter. + * + * @param imageGenerationQuality The quality of the image. + * @return builder + */ + public Builder quality(ImageGenerationQuality imageGenerationQuality) { + this.quality = imageGenerationQuality.toString(); + return this; + } + /** * Sets the size of the image. This is an optional parameter. * @@ -197,6 +208,17 @@ public Builder size(String size) { return this; } + /** + * Sets the size of the image, using the ImageSize enum. This is an optional parameter. + * + * @param imageSize The size of the image. + * @return builder + */ + public Builder size(ImageSize imageSize) { + this.size = imageSize.toString(); + return this; + } + /** * Sets the user of the image. This is an optional parameter. * @@ -219,6 +241,17 @@ public Builder style(String style) { return this; } + /** + * Sets the style of the image, using the ImageGenerationStyle enum. This is an optional parameter. + * + * @param imageGenerationStyle The style of the image. + * @return builder + */ + public Builder style(ImageGenerationStyle imageGenerationStyle) { + this.style = imageGenerationStyle.toString(); + return this; + } + /** * Sets the response format of the image. This is an optional parameter. * @@ -230,6 +263,17 @@ public Builder responseFormat(String responseFormat) { return this; } + /** + * Sets the response format of the image, using the ImageGenerationResponseFormat enum. This is an optional parameter. + * + * @param imageGenerationResponseFormat The response format of the image. 
+ * @return builder + */ + public Builder responseFormat(ImageGenerationResponseFormat imageGenerationResponseFormat) { + this.responseFormat = imageGenerationResponseFormat.toString(); + return this; + } + public Builder timeout(Duration timeout) { this.timeout = timeout; return this; From ce90c156f321acc632408f65abd6a48d3dcc4845 Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Mon, 18 Dec 2023 20:43:58 +0100 Subject: [PATCH 22/30] Try to fix JavaDoc --- .../java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java index 510c919a5e..c9bdc184bf 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiImageModel.java @@ -17,7 +17,7 @@ /** * Represents an OpenAI image model, hosted on Azure, such as dall-e-3. *

- * You can find a tutorial on using Azure OpenAI to generate images at: https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=programming-language-java + * You can find a tutorial on using Azure OpenAI to generate images at: https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?pivots=programming-language-java *

* Mandatory parameters for initialization are: endpoint, serviceVersion, apiKey and deploymentName. * You can also provide your own OpenAIClient instance, if you need more flexibility. From 2d12aa90d38fc188e2153cfb4254abf517b5eb74 Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Tue, 19 Dec 2023 12:36:02 +0100 Subject: [PATCH 23/30] AiMessage cannot have a "null" text, so putting "" by default. --- .../dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java | 2 +- .../dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java index c602923a72..df83d8fcea 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java @@ -186,7 +186,7 @@ public static AiMessage aiMessageFrom(com.azure.ai.openai.models.ChatResponseMes .arguments(functionCall.getArguments()) .build(); - return aiMessage(toolExecutionRequest); + return new AiMessage("", Collections.singletonList(toolExecutionRequest)); } } diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java index 941d1bcc0b..a91e1eab5a 100644 --- a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java +++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java @@ -128,6 +128,7 @@ void should_call_function_with_argument() { List chatMessages = new ArrayList<>(); chatMessages.add(systemMessage); + chatMessages.add(aiMessage); chatMessages.add(userMessage); 
chatMessages.add(toolExecutionResultMessage); From 502f9bd8e9c907ff9357730c724ac97d3ef54525 Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Tue, 19 Dec 2023 12:48:13 +0100 Subject: [PATCH 24/30] Sync up with the main branch, to use the new Image API. --- .../model/azure/InternalAzureOpenAiHelper.java | 2 +- .../langchain4j/model/azure/AzureOpenAiImageModelIT.java | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java index df83d8fcea..c47a7b2018 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java @@ -204,7 +204,7 @@ public static Image imageFrom(com.azure.ai.openai.models.ImageGenerationData ima throw new RuntimeException(e); } } else if (imageData != null) { - imageBuilder.base64Data(imageData); + imageBuilder.base64(imageData); } return imageBuilder.build(); diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java index e6112e0d10..6481614d42 100644 --- a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java +++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java @@ -38,7 +38,7 @@ void should_generate_image_with_url() { assertNotNull(image.url()); logger.info("The remote image is here: {}", image.url()); - assertNull(image.base64Data()); + assertNull(image.base64()); assertNotNull(image.revisedPrompt()); logger.info("The revised prompt is: {}", image.revisedPrompt()); @@ -59,11 +59,11 @@ void should_generate_image_in_base64() throws IOException 
{ Image image = response.content(); assertNotNull(image); assertNull(image.url()); - assertNotNull(image.base64Data()); - logger.info("The image data is: {} characters", image.base64Data().length()); + assertNotNull(image.base64()); + logger.info("The image data is: {} characters", image.base64().length()); if (logger.isDebugEnabled()) { - byte[] decodedBytes = Base64.getDecoder().decode(response.content().base64Data()); + byte[] decodedBytes = Base64.getDecoder().decode(response.content().base64()); Path temp = Files.createTempFile("langchain4j", ".png"); Files.write(temp, decodedBytes); logger.debug("The image is here: {}", temp.toAbsolutePath()); From 4049448fe4a5f59a23cb3b7e6cb5fbbfd7fd0a53 Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Tue, 19 Dec 2023 14:10:39 +0100 Subject: [PATCH 25/30] Fix messages order --- .../dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java index a91e1eab5a..6266256056 100644 --- a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java +++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiChatModelIT.java @@ -128,8 +128,8 @@ void should_call_function_with_argument() { List chatMessages = new ArrayList<>(); chatMessages.add(systemMessage); - chatMessages.add(aiMessage); chatMessages.add(userMessage); + chatMessages.add(aiMessage); chatMessages.add(toolExecutionResultMessage); Response response2 = model.generate(chatMessages); From 39207e89d0768bf59b82d8b064be600fe7a5b91f Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Tue, 19 Dec 2023 14:16:11 +0100 Subject: [PATCH 26/30] Use aiMessage(toolExecutionRequest) --- .../dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java index c47a7b2018..71efad0232 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java @@ -186,7 +186,7 @@ public static AiMessage aiMessageFrom(com.azure.ai.openai.models.ChatResponseMes .arguments(functionCall.getArguments()) .build(); - return new AiMessage("", Collections.singletonList(toolExecutionRequest)); + return aiMessage(toolExecutionRequest); } } From 5ffc013d0e1b707097b2f8401a35bc1c2cff4ce5 Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Tue, 19 Dec 2023 14:42:11 +0100 Subject: [PATCH 27/30] Create a specific AiMessage implementation for Azure OpenAI --- .../model/azure/AzureAiMessage.java | 27 +++++++++++++++++++ .../azure/InternalAzureOpenAiHelper.java | 25 ++++++++++++++++- 2 files changed, 51 insertions(+), 1 deletion(-) create mode 100644 langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureAiMessage.java diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureAiMessage.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureAiMessage.java new file mode 100644 index 0000000000..cf2f7b464c --- /dev/null +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureAiMessage.java @@ -0,0 +1,27 @@ +package dev.langchain4j.model.azure; + +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.data.message.AiMessage; + +import java.util.List; + +import static java.util.Arrays.asList; + +public class AzureAiMessage extends AiMessage { + + public AzureAiMessage(List toolExecutionRequests) { + super(toolExecutionRequests); + } + + public static 
AiMessage aiMessage(ToolExecutionRequest... toolExecutionRequests) { + return aiMessage(asList(toolExecutionRequests)); + } + + /** + * Returns an empty String as Azure OpenAI requires a non-Null object. + */ + @Override + public String text() { + return ""; + } +} diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java index 71efad0232..992486ac2f 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java @@ -32,6 +32,7 @@ import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; import static dev.langchain4j.model.output.FinishReason.*; import static java.time.Duration.ofSeconds; +import static java.util.Arrays.asList; import static java.util.stream.Collectors.toList; class InternalAzureOpenAiHelper { @@ -186,7 +187,7 @@ public static AiMessage aiMessageFrom(com.azure.ai.openai.models.ChatResponseMes .arguments(functionCall.getArguments()) .build(); - return aiMessage(toolExecutionRequest); + return AzureAiMessage.azureAiMessage(toolExecutionRequest); } } @@ -236,4 +237,26 @@ public static FinishReason finishReasonFrom(CompletionsFinishReason openAiFinish return null; } } + + /** + * Specific AiMessage implementation for Azure OpenAI, as it requires a non-Null "text" field. + */ + public static class AzureAiMessage extends AiMessage { + + public AzureAiMessage(List toolExecutionRequests) { + super(toolExecutionRequests); + } + + public static AiMessage azureAiMessage(ToolExecutionRequest... toolExecutionRequests) { + return new AzureAiMessage(asList(toolExecutionRequests)); + } + + /** + * Returns an empty String as Azure OpenAI requires a non-Null object. 
+ */ + @Override + public String text() { + return ""; + } + } } From e4959d5ef3161e62a3d3a305a1da413da5cf0cb0 Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Tue, 19 Dec 2023 15:00:15 +0100 Subject: [PATCH 28/30] Refactor base64 to base64Data --- .../azure/InternalAzureOpenAiHelper.java | 2 +- .../model/azure/AzureOpenAiImageModelIT.java | 8 ++++---- .../dev/langchain4j/data/image/Image.java | 20 +++++++++---------- .../model/openai/OpenAiImageModel.java | 2 +- .../model/openai/OpenAiImageModelIT.java | 4 ++-- 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java index 992486ac2f..17b0992462 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java @@ -205,7 +205,7 @@ public static Image imageFrom(com.azure.ai.openai.models.ImageGenerationData ima throw new RuntimeException(e); } } else if (imageData != null) { - imageBuilder.base64(imageData); + imageBuilder.base64Data(imageData); } return imageBuilder.build(); diff --git a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java index 6481614d42..e6112e0d10 100644 --- a/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java +++ b/langchain4j-azure-open-ai/src/test/java/dev/langchain4j/model/azure/AzureOpenAiImageModelIT.java @@ -38,7 +38,7 @@ void should_generate_image_with_url() { assertNotNull(image.url()); logger.info("The remote image is here: {}", image.url()); - assertNull(image.base64()); + assertNull(image.base64Data()); assertNotNull(image.revisedPrompt()); 
logger.info("The revised prompt is: {}", image.revisedPrompt()); @@ -59,11 +59,11 @@ void should_generate_image_in_base64() throws IOException { Image image = response.content(); assertNotNull(image); assertNull(image.url()); - assertNotNull(image.base64()); - logger.info("The image data is: {} characters", image.base64().length()); + assertNotNull(image.base64Data()); + logger.info("The image data is: {} characters", image.base64Data().length()); if (logger.isDebugEnabled()) { - byte[] decodedBytes = Base64.getDecoder().decode(response.content().base64()); + byte[] decodedBytes = Base64.getDecoder().decode(response.content().base64Data()); Path temp = Files.createTempFile("langchain4j", ".png"); Files.write(temp, decodedBytes); logger.debug("The image is here: {}", temp.toAbsolutePath()); diff --git a/langchain4j-core/src/main/java/dev/langchain4j/data/image/Image.java b/langchain4j-core/src/main/java/dev/langchain4j/data/image/Image.java index ea597426d0..5c64a6fa84 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/data/image/Image.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/data/image/Image.java @@ -8,12 +8,12 @@ public final class Image { private URI url; - private String base64; + private String base64Data; private String revisedPrompt; private Image(Builder builder) { this.url = builder.url; - this.base64 = builder.base64; + this.base64Data = builder.base64Data; this.revisedPrompt = builder.revisedPrompt; } @@ -25,8 +25,8 @@ public URI url() { return url; } - public String base64() { - return base64; + public String base64Data() { + return base64Data; } public String revisedPrompt() { @@ -40,7 +40,7 @@ public boolean equals(Object o) { Image image = (Image) o; return ( Objects.equals(url, image.url) && - Objects.equals(base64, image.base64) && + Objects.equals(base64Data, image.base64Data) && Objects.equals(revisedPrompt, image.revisedPrompt) ); } @@ -56,8 +56,8 @@ public String toString() { "Image{" + " url=" + quoted(url.toString()) 
+ - ", base64=" + - quoted(base64) + + ", base64Data=" + + quoted(base64Data) + ", revisedPrompt=" + quoted(revisedPrompt) + '}' @@ -67,7 +67,7 @@ public String toString() { public static class Builder { private URI url; - private String base64; + private String base64Data; private String revisedPrompt; public Builder url(URI url) { @@ -75,8 +75,8 @@ public Builder url(URI url) { return this; } - public Builder base64(String base64) { - this.base64 = base64; + public Builder base64Data(String base64Data) { + this.base64Data = base64Data; return this; } diff --git a/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiImageModel.java b/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiImageModel.java index 6978740e1a..f7c999782e 100644 --- a/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiImageModel.java +++ b/langchain4j-open-ai/src/main/java/dev/langchain4j/model/openai/OpenAiImageModel.java @@ -128,7 +128,7 @@ public OpenAiImageModelBuilder withApiKey(String apiKey) { } private static Image fromImageData(GenerateImagesResponse.ImageData data) { - return Image.builder().url(data.url()).base64(data.b64Json()).revisedPrompt(data.revisedPrompt()).build(); + return Image.builder().url(data.url()).base64Data(data.b64Json()).revisedPrompt(data.revisedPrompt()).build(); } private GenerateImagesRequest.Builder requestBuilder(String prompt) { diff --git a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiImageModelIT.java b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiImageModelIT.java index 16f36dc80c..bde5777891 100644 --- a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiImageModelIT.java +++ b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiImageModelIT.java @@ -60,12 +60,12 @@ void multiple_images_generation_with_base64_works() { Image localImage1 = response.content().get(0); log.info("Your first local image is here: {}", 
localImage1.url()); assertThat(new File(localImage1.url())).exists(); - assertThat(localImage1.base64()).isNotNull().isBase64(); + assertThat(localImage1.base64Data()).isNotNull().isBase64(); Image localImage2 = response.content().get(1); log.info("Your second local image is here: {}", localImage2.url()); assertThat(new File(localImage2.url())).exists(); - assertThat(localImage2.base64()).isNotNull().isBase64(); + assertThat(localImage2.base64Data()).isNotNull().isBase64(); } @Test From fc9e3bad4e45903fd8590641d771be9e7e757a57 Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Tue, 19 Dec 2023 15:08:04 +0100 Subject: [PATCH 29/30] Remove the specific Azure OpenAI AiMessage implementation --- .../azure/InternalAzureOpenAiHelper.java | 27 ++----------------- 1 file changed, 2 insertions(+), 25 deletions(-) diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java index 17b0992462..e8e81b15f9 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/InternalAzureOpenAiHelper.java @@ -32,7 +32,6 @@ import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; import static dev.langchain4j.model.output.FinishReason.*; import static java.time.Duration.ofSeconds; -import static java.util.Arrays.asList; import static java.util.stream.Collectors.toList; class InternalAzureOpenAiHelper { @@ -85,7 +84,7 @@ public static List toOpenAiMessag public static com.azure.ai.openai.models.ChatRequestMessage toOpenAiMessage(ChatMessage message) { if (message instanceof AiMessage) { - ChatRequestAssistantMessage chatRequestAssistantMessage = new ChatRequestAssistantMessage(message.text()); + ChatRequestAssistantMessage chatRequestAssistantMessage = new 
ChatRequestAssistantMessage(getOrDefault(message.text(), "")); chatRequestAssistantMessage.setFunctionCall(functionCallFrom(message)); return chatRequestAssistantMessage; } else if (message instanceof ToolExecutionResultMessage) { @@ -187,7 +186,7 @@ public static AiMessage aiMessageFrom(com.azure.ai.openai.models.ChatResponseMes .arguments(functionCall.getArguments()) .build(); - return AzureAiMessage.azureAiMessage(toolExecutionRequest); + return aiMessage(toolExecutionRequest); } } @@ -237,26 +236,4 @@ public static FinishReason finishReasonFrom(CompletionsFinishReason openAiFinish return null; } } - - /** - * Specific AiMessage implementation for Azure OpenAI, as it requires a non-Null "text" field. - */ - public static class AzureAiMessage extends AiMessage { - - public AzureAiMessage(List toolExecutionRequests) { - super(toolExecutionRequests); - } - - public static AiMessage azureAiMessage(ToolExecutionRequest... toolExecutionRequests) { - return new AzureAiMessage(asList(toolExecutionRequests)); - } - - /** - * Returns an empty String as Azure OpenAI requires a non-Null object. 
- */ - @Override - public String text() { - return ""; - } - } } From 29cdef9a4525725534dbc146542519fb39951ad5 Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Tue, 19 Dec 2023 15:24:45 +0100 Subject: [PATCH 30/30] Remove the specific Azure OpenAI AiMessage implementation --- .../model/azure/AzureAiMessage.java | 27 ------------------- 1 file changed, 27 deletions(-) delete mode 100644 langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureAiMessage.java diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureAiMessage.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureAiMessage.java deleted file mode 100644 index cf2f7b464c..0000000000 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureAiMessage.java +++ /dev/null @@ -1,27 +0,0 @@ -package dev.langchain4j.model.azure; - -import dev.langchain4j.agent.tool.ToolExecutionRequest; -import dev.langchain4j.data.message.AiMessage; - -import java.util.List; - -import static java.util.Arrays.asList; - -public class AzureAiMessage extends AiMessage { - - public AzureAiMessage(List toolExecutionRequests) { - super(toolExecutionRequests); - } - - public static AiMessage aiMessage(ToolExecutionRequest... toolExecutionRequests) { - return aiMessage(asList(toolExecutionRequests)); - } - - /** - * Returns an empty String as Azure OpenAI requires a non-Null object. - */ - @Override - public String text() { - return ""; - } -}