Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add image generation support with Azure OpenAI #359

Merged
merged 37 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
f2309ad
Add image generation support with Azure OpenAI
jdubois Dec 15, 2023
20b4e43
Merge branch 'main' into azure-openai-image
jdubois Dec 15, 2023
4c24646
Fix default deployment names
jdubois Dec 16, 2023
a8dd958
Finalize integration tests
jdubois Dec 16, 2023
92621da
Merge branch 'main' into azure-openai-image
jdubois Dec 18, 2023
c8aabbf
Update the API usage following code review from @Heezer
jdubois Dec 18, 2023
1f28f0c
Generate only 1 image, and support base64 data
jdubois Dec 18, 2023
00bf37a
Merge branch 'main' into azure-openai-image
jdubois Dec 18, 2023
36bf7fa
Use default parameters from Azure OpenAI
jdubois Dec 18, 2023
9e8a32c
Add documentation for using OpenAI Image generation
jdubois Dec 18, 2023
6de6f68
Merge branch 'main' into azure-openai-image
jdubois Dec 18, 2023
08cbbf1
Remove withPersisting and persistTo
jdubois Dec 18, 2023
072bdb6
Add enums in the Builder methods
jdubois Dec 18, 2023
7ed0f7c
Use Utils.quoted
jdubois Dec 18, 2023
9602f37
Try to fix JavaDoc
jdubois Dec 18, 2023
1b97059
AiMessage cannot have a "null" text, so putting "" by default.
jdubois Dec 19, 2023
cec63f8
Add image generation support with Azure OpenAI
jdubois Dec 15, 2023
6997c05
Fix default deployment names
jdubois Dec 16, 2023
fc8fa21
Finalize integration tests
jdubois Dec 16, 2023
77d69b2
Update the API usage following code review from @Heezer
jdubois Dec 18, 2023
4d5611a
Generate only 1 image, and support base64 data
jdubois Dec 18, 2023
da89c29
Use default parameters from Azure OpenAI
jdubois Dec 18, 2023
2753367
Add documentation for using OpenAI Image generation
jdubois Dec 18, 2023
3d03312
Remove withPersisting and persistTo
jdubois Dec 18, 2023
7882637
Add enums in the Builder methods
jdubois Dec 18, 2023
ce90c15
Try to fix JavaDoc
jdubois Dec 18, 2023
2d12aa9
AiMessage cannot have a "null" text, so putting "" by default.
jdubois Dec 19, 2023
502f9bd
Sync up with the main branch, to use the new Image API.
jdubois Dec 19, 2023
b13b357
Merge remote-tracking branch 'origin/azure-openai-image' into azure-o…
jdubois Dec 19, 2023
4399dab
Merge branch 'main' into azure-openai-image
langchain4j Dec 19, 2023
4049448
Fix messages order
jdubois Dec 19, 2023
39207e8
Use aiMessage(toolExecutionRequest)
jdubois Dec 19, 2023
5ffc013
Create a specific AiMessage implementation for Azure OpenAI
jdubois Dec 19, 2023
e4959d5
Refactor base64 to base64Data
jdubois Dec 19, 2023
fc9e3ba
Remove the specific Azure OpenAI AiMessage implementation
jdubois Dec 19, 2023
a2283d8
Merge branch 'main' into azure-openai-image
jdubois Dec 19, 2023
29cdef9
Remove the specific Azure OpenAI AiMessage implementation
jdubois Dec 19, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,6 @@ public AzureOpenAiChatModel(String endpoint,
this.client = setupOpenAIClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, proxyOptions, logRequestsAndResponses);
}



private AzureOpenAiChatModel(String deploymentName,
Tokenizer tokenizer,
Double temperature,
Expand All @@ -94,7 +92,7 @@ private AzureOpenAiChatModel(String deploymentName,
Double presencePenalty,
Double frequencyPenalty) {

this.deploymentName = getOrDefault(deploymentName, "gpt-35-turbo-0613");
this.deploymentName = getOrDefault(deploymentName, "gpt-35-turbo");
this.tokenizer = getOrDefault(tokenizer, new OpenAiTokenizer(GPT_3_5_TURBO));
this.temperature = getOrDefault(temperature, 0.7);
this.topP = topP;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
package dev.langchain4j.model.azure;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.models.*;
import com.azure.core.http.ProxyOptions;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.model.image.ImageModel;
import dev.langchain4j.model.output.Response;

import java.nio.file.Path;
import java.time.Duration;

import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.imageFrom;
import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.setupOpenAIClient;

public class AzureOpenAiImageModel implements ImageModel {
jdubois marked this conversation as resolved.
Show resolved Hide resolved

private OpenAIClient client;
private final String deploymentName;
private final ImageGenerationQuality quality;
private final ImageSize size;
private final String user;
private final ImageGenerationStyle style;
private final ImageGenerationResponseFormat responseFormat;

/**
 * Creates an image model that talks to the service through a caller-supplied client.
 * <p>
 * All option strings are forwarded to the private constructor of this class, which converts
 * them with the SDK {@code fromString} factories and resolves fallbacks via {@code getOrDefault}.
 *
 * @param client         client used by {@link #generate(String)} to request images
 * @param deploymentName name of the deployment to call
 * @param quality        image quality, as a string understood by {@code ImageGenerationQuality.fromString}
 * @param size           image size, as a string understood by {@code ImageSize.fromString}
 * @param user           user value attached to each generation request
 * @param style          image style, as a string understood by {@code ImageGenerationStyle.fromString}
 * @param responseFormat response format, as a string understood by
 *                       {@code ImageGenerationResponseFormat.fromString}
 */
public AzureOpenAiImageModel(
        OpenAIClient client,
        String deploymentName,
        String quality,
        String size,
        String user,
        String style,
        String responseFormat) {

    // Option handling lives in the shared private constructor; only the client differs here.
    this(deploymentName, quality, size, user, style, responseFormat);
    this.client = client;
}

/**
 * Creates an image model that builds its own client from connection settings.
 * <p>
 * The image options are forwarded to the private constructor of this class; the connection
 * settings are handed unchanged to {@code setupOpenAIClient}, whose result becomes the client
 * used by {@link #generate(String)}.
 *
 * @param endpoint                connection setting passed to {@code setupOpenAIClient}
 * @param serviceVersion          connection setting passed to {@code setupOpenAIClient}
 * @param apiKey                  connection setting passed to {@code setupOpenAIClient}
 * @param deploymentName          name of the deployment to call
 * @param quality                 image quality string
 * @param size                    image size string
 * @param user                    user value attached to each generation request
 * @param style                   image style string
 * @param responseFormat          response format string
 * @param timeout                 connection setting passed to {@code setupOpenAIClient}
 * @param maxRetries              connection setting passed to {@code setupOpenAIClient}
 * @param proxyOptions            connection setting passed to {@code setupOpenAIClient}
 * @param logRequestsAndResponses connection setting passed to {@code setupOpenAIClient}
 */
public AzureOpenAiImageModel(
        String endpoint,
        String serviceVersion,
        String apiKey,
        String deploymentName,
        String quality,
        String size,
        String user,
        String style,
        String responseFormat,
        Duration timeout,
        Integer maxRetries,
        ProxyOptions proxyOptions,
        boolean logRequestsAndResponses) {

    // Resolve the image options first (must be the first statement), then create the client.
    this(deploymentName, quality, size, user, style, responseFormat);
    this.client = setupOpenAIClient(
            endpoint, serviceVersion, apiKey, timeout, maxRetries, proxyOptions, logRequestsAndResponses);
}

private AzureOpenAiImageModel(String deploymentName, String quality, String size, String user, String style, String responseFormat) {
this.deploymentName = getOrDefault(deploymentName, "dall-e-3");
this.quality = getOrDefault(ImageGenerationQuality.fromString(quality), ImageGenerationQuality.STANDARD);
jdubois marked this conversation as resolved.
Show resolved Hide resolved
this.size = getOrDefault(ImageSize.fromString(size), ImageSize.SIZE1024X1024);
this.user = getOrDefault(user, "");
this.style = getOrDefault(ImageGenerationStyle.fromString(style), ImageGenerationStyle.NATURAL);
this.responseFormat = getOrDefault(ImageGenerationResponseFormat.fromString(responseFormat), ImageGenerationResponseFormat.URL);
}

/**
 * Generates a single image for the given prompt.
 * <p>
 * The request always asks for exactly one image ({@code n = 1}) and carries the quality,
 * size, user, style and response format configured on this model. The first entry of the
 * returned data list is converted with {@code imageFrom} and wrapped in a {@link Response}.
 *
 * @param prompt text passed to the {@code ImageGenerationOptions} constructor
 * @return a response wrapping the image built from the first generation result
 */
@Override
public Response<Image> generate(String prompt) {
    ImageGenerationOptions request = new ImageGenerationOptions(prompt);
    request.setModel(deploymentName);
    request.setN(1);
    request.setQuality(quality);
    request.setSize(size);
    request.setUser(user);
    request.setStyle(style);
    request.setResponseFormat(responseFormat);

    ImageGenerations result = client.getImageGenerations(deploymentName, request);
    // Only one image is requested above, so only the first entry is read.
    return Response.from(imageFrom(result.getData().get(0)));
}

/**
* Creates a new, empty {@link Builder} for this model.
*
* @return a fresh builder instance
*/
public static Builder builder() {
return new Builder();
}

public static class Builder {
jdubois marked this conversation as resolved.
Show resolved Hide resolved

private String endpoint;
private String serviceVersion;
private String apiKey;
private String deploymentName;
private String quality;
private String size;
private String user;
private String style;
private String responseFormat;
private Duration timeout;
private Integer maxRetries;
private ProxyOptions proxyOptions;
private boolean logRequestsAndResponses;
private boolean withPersisting;
jdubois marked this conversation as resolved.
Show resolved Hide resolved

private Path persistTo;
jdubois marked this conversation as resolved.
Show resolved Hide resolved
private OpenAIClient openAIClient;

/**
* Sets the Azure OpenAI endpoint.
* The value is only used when no client has been supplied through
* {@link #openAIClient(OpenAIClient)}; in that case {@link #build()} passes it on
* so that the model can create its own client.
*
* @param endpoint The Azure OpenAI endpoint in the format: https://{resource}.openai.azure.com/
* @return builder
*/
public Builder endpoint(String endpoint) {
this.endpoint = endpoint;
return this;
}

/**
* Sets the Azure OpenAI API service version.
* The value is only used when no client has been supplied through
* {@link #openAIClient(OpenAIClient)}.
*
* @param serviceVersion The Azure OpenAI API service version in the format: 2023-05-15
* @return builder
*/
public Builder serviceVersion(String serviceVersion) {
this.serviceVersion = serviceVersion;
return this;
}

/**
* Sets the Azure OpenAI API key.
* The value is only used when no client has been supplied through
* {@link #openAIClient(OpenAIClient)}.
*
* @param apiKey The Azure OpenAI API key.
* @return builder
*/
public Builder apiKey(String apiKey) {
this.apiKey = apiKey;
return this;
}

/**
* Sets the deployment name in Azure OpenAI.
* The model constructor passes this value through {@code getOrDefault} with
* {@code "dall-e-3"} as the fallback, so leaving it unset is allowed.
*
* @param deploymentName The Deployment name.
* @return builder
*/
public Builder deploymentName(String deploymentName) {
this.deploymentName = deploymentName;
return this;
}

/**
* Sets the quality of the image. This is an optional parameter.
*
* @param quality The quality of the image.
* @return builder
*/
public Builder quality(String quality) {
jdubois marked this conversation as resolved.
Show resolved Hide resolved
this.quality = quality;
return this;
}

/**
* Sets the size of the image. This is an optional parameter.
*
* @param size The size of the image.
* @return builder
*/
public Builder size(String size) {
jdubois marked this conversation as resolved.
Show resolved Hide resolved
this.size = size;
return this;
}

/**
* Sets the user value that is attached to every image generation request
* (it is passed to {@code ImageGenerationOptions.setUser}). This is an optional parameter;
* the model constructor passes it through {@code getOrDefault} with an empty string as the fallback.
*
* @param user The user value to send with each request.
* @return builder
*/
public Builder user(String user) {
this.user = user;
return this;
}

/**
* Sets the style of the image. This is an optional parameter.
*
* @param style The style of the image.
* @return builder
*/
public Builder style(String style) {
jdubois marked this conversation as resolved.
Show resolved Hide resolved
this.style = style;
return this;
}

/**
* Sets the response format of the image. This is an optional parameter.
*
* @param responseFormat The response format of the image.
* @return builder
*/
public Builder responseFormat(String responseFormat) {
jdubois marked this conversation as resolved.
Show resolved Hide resolved
this.responseFormat = responseFormat;
return this;
}

/**
* Sets the timeout handed to {@code setupOpenAIClient} when the model creates its own client.
* Not used when a client is supplied through {@link #openAIClient(OpenAIClient)}.
*
* @param timeout The timeout to use for the created client.
* @return builder
*/
public Builder timeout(Duration timeout) {
this.timeout = timeout;
return this;
}

/**
* Sets the retry count handed to {@code setupOpenAIClient} when the model creates its own client.
* Not used when a client is supplied through {@link #openAIClient(OpenAIClient)}.
*
* @param maxRetries The maximum number of retries for the created client.
* @return builder
*/
public Builder maxRetries(Integer maxRetries) {
this.maxRetries = maxRetries;
return this;
}

/**
* Sets the proxy options handed to {@code setupOpenAIClient} when the model creates its own client.
* Not used when a client is supplied through {@link #openAIClient(OpenAIClient)}.
*
* @param proxyOptions The proxy options for the created client.
* @return builder
*/
public Builder proxyOptions(ProxyOptions proxyOptions) {
this.proxyOptions = proxyOptions;
return this;
}

/**
 * Sets the logging flag handed to {@code setupOpenAIClient} when the model creates its own
 * client. Not used when a client is supplied through {@link #openAIClient(OpenAIClient)}.
 *
 * @param logRequestsAndResponses the flag value; {@code null} is treated as {@code false}
 * @return builder
 */
public Builder logRequestsAndResponses(Boolean logRequestsAndResponses) {
    // The backing field is a primitive boolean: assigning a null Boolean directly would
    // throw a NullPointerException on unboxing, so null is mapped to false instead.
    this.logRequestsAndResponses = Boolean.TRUE.equals(logRequestsAndResponses);
    return this;
}

/**
* Sets a ready-made client for the model to use.
* When a non-null client is set, {@link #build()} uses it and ignores the endpoint,
* service version, API key, timeout, retry, proxy and logging settings.
*
* @param openAIClient The client the model should use.
* @return builder
*/
public Builder openAIClient(OpenAIClient openAIClient) {
this.openAIClient = openAIClient;
return this;
}

/**
 * Builds the model from the values collected so far.
 * <p>
 * If a client was supplied it is used as-is together with the image options; otherwise the
 * connection settings are passed along so that the model creates its own client.
 *
 * @return a new {@link AzureOpenAiImageModel}
 */
public AzureOpenAiImageModel build() {
    if (openAIClient == null) {
        // No client supplied: the model creates one from the connection settings.
        return new AzureOpenAiImageModel(
                endpoint, serviceVersion, apiKey,
                deploymentName, quality, size, user, style, responseFormat,
                timeout, maxRetries, proxyOptions, logRequestsAndResponses);
    }

    return new AzureOpenAiImageModel(
            openAIClient, deploymentName, quality, size, user, style, responseFormat);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ public class AzureOpenAiModelName {
public static final String GPT_3_5_TURBO = "gpt-3.5-turbo"; // alias for the latest model
public static final String GPT_3_5_TURBO_0301 = "gpt-3.5-turbo-0301"; // 4k context
public static final String GPT_3_5_TURBO_0613 = "gpt-3.5-turbo-0613"; // 4k context, functions
public static final String GPT_3_5_TURBO_1106 = "gpt-3.5-turbo-1106"; // 16k context, functions

public static final String GPT_3_5_TURBO_16K = "gpt-3.5-turbo-16k"; // alias for the latest model
public static final String GPT_3_5_TURBO_16K_0613 = "gpt-3.5-turbo-16k-0613"; // 16k context, functions
Expand All @@ -28,4 +29,7 @@ public class AzureOpenAiModelName {
// Use with AzureOpenAiEmbeddingModel
public static final String TEXT_EMBEDDING_ADA_002 = "text-embedding-ada-002";

// Use with AzureOpenAiImageModel
public static final String DALL_E_3 = "DALL_E_3";

}
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ private AzureOpenAiStreamingChatModel(String deploymentName,
Double presencePenalty,
Double frequencyPenalty) {

this.deploymentName = getOrDefault(deploymentName, "gpt-35-turbo-0613");
this.deploymentName = getOrDefault(deploymentName, "gpt-35-turbo");
this.tokenizer = getOrDefault(tokenizer, new OpenAiTokenizer(GPT_3_5_TURBO));
this.temperature = getOrDefault(temperature, 0.7);
this.topP = topP;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolParameters;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.data.message.*;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.TokenUsage;

import java.net.URI;
import java.net.URISyntaxException;
import java.time.Duration;
import java.util.*;

Expand Down Expand Up @@ -187,6 +190,26 @@ public static AiMessage aiMessageFrom(com.azure.ai.openai.models.ChatResponseMes
}
}

/**
 * Converts one Azure OpenAI generation result into an {@link Image}.
 * <p>
 * The revised prompt is always copied. If the result carries a URL, it is parsed and set on
 * the image; otherwise, if it carries base64 data, that data is set instead. When neither is
 * present the image is built with the revised prompt only.
 *
 * @param imageGenerationData a single entry of the generation response
 * @return the image built from that entry
 * @throws RuntimeException wrapping the {@link URISyntaxException} if the URL cannot be parsed
 */
public static Image imageFrom(com.azure.ai.openai.models.ImageGenerationData imageGenerationData) {
    Image.Builder builder = Image.builder()
            .revisedPrompt(imageGenerationData.getRevisedPrompt());

    String location = imageGenerationData.getUrl();
    String encoded = imageGenerationData.getBase64Data();

    if (location == null) {
        // No URL: fall back to inline base64 data when the service returned it.
        if (encoded != null) {
            builder.base64Data(encoded);
        }
        return builder.build();
    }

    try {
        builder.url(new URI(location));
    } catch (URISyntaxException e) {
        throw new RuntimeException(e);
    }
    return builder.build();
}

public static TokenUsage tokenUsageFrom(CompletionsUsage openAiUsage) {
if (openAiUsage == null) {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ void should_call_function_with_argument() {
String weather = String.format("The weather in %s is %d degrees %s.",
weatherLocation.getLocation(), currentWeather, weatherLocation.getUnit());


assertThat(weather).isEqualTo("The weather in Paris, France is 35 degrees celsius.");

// Now that we know the function's result, we can call the model again with the result as input.
Expand All @@ -130,7 +129,6 @@ void should_call_function_with_argument() {
List<ChatMessage> chatMessages = new ArrayList<>();
chatMessages.add(systemMessage);
chatMessages.add(userMessage);
jdubois marked this conversation as resolved.
Show resolved Hide resolved
chatMessages.add(aiMessage);
jdubois marked this conversation as resolved.
Show resolved Hide resolved
chatMessages.add(toolExecutionResultMessage);

Response<AiMessage> response2 = model.generate(chatMessages);
Expand Down