Skip to content

Commit

Permalink
Update language models and tests for Azure OpenAI (#573)
Browse files Browse the repository at this point in the history
This PR updates the language models for Azure OpenAI and improves some
tests:

- It updates the current models available on Azure OpenAI, removing
those which are no longer available, such as `gpt-4-0314` and `gpt-3.5-turbo-0301`
- It creates some upcoming models as described in
https://openai.com/blog/new-embedding-models-and-api-updates
- It passes those updated model names to the integration tests and fixes
some issues there: the tests previously checked only for the "length"
finish reason and not for "stop", so now both are tested correctly.
- Also the Dall E 3 test was "disabled", so I enabled it again.
  • Loading branch information
jdubois committed Feb 8, 2024
1 parent df21030 commit 3236162
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,31 @@ public class AzureOpenAiModelName {

// Use with AzureOpenAiChatModel and AzureOpenAiStreamingChatModel
public static final String GPT_3_5_TURBO = "gpt-3.5-turbo"; // alias for the latest model
public static final String GPT_3_5_TURBO_0301 = "gpt-3.5-turbo-0301"; // 4k context
public static final String GPT_3_5_TURBO_0613 = "gpt-3.5-turbo-0613"; // 4k context, functions
public static final String GPT_3_5_TURBO_0125 = "gpt-3.5-turbo-0125"; // 16k context, functions
public static final String GPT_3_5_TURBO_1106 = "gpt-3.5-turbo-1106"; // 16k context, functions

public static final String GPT_3_5_TURBO_16K = "gpt-3.5-turbo-16k"; // alias for the latest model
public static final String GPT_3_5_TURBO_16K_0613 = "gpt-3.5-turbo-16k-0613"; // 16k context, functions

public static final String GPT_4 = "gpt-4"; // alias for the latest model
public static final String GPT_4_0314 = "gpt-4-0314"; // 8k context
public static final String GPT_4_1106_PREVIEW = "gpt-4-1106-preview"; // 128k context
public static final String GPT_4_0613 = "gpt-4-0613"; // 8k context, functions

public static final String GPT_4_32K = "gpt-4-32k"; // alias for the latest model
public static final String GPT_4_32K_0314 = "gpt-4-32k-0314"; // 32k context
public static final String GPT_4_32K_0613 = "gpt-4-32k-0613"; // 32k context, functions


// Use with AzureOpenAiLanguageModel and AzureOpenAiStreamingLanguageModel
public static final String TEXT_DAVINCI_003 = "text-davinci-003";
public static final String TEXT_DAVINCI_002 = "text-davinci-002";

public static final String GPT_3_5_TURBO_INSTRUCT = "gpt-3.5-turbo-instruct";


// Use with AzureOpenAiEmbeddingModel
public static final String TEXT_EMBEDDING_ADA_002 = "text-embedding-ada-002";
public static final String TEXT_EMBEDDING_3_SMALL = "text-embedding-3-small";
public static final String TEXT_EMBEDDING_3_LARGE = "text-embedding-3-large";

// Use with AzureOpenAiImageModel
// Lowercase hyphenated form is the model name the API expects; it also matches
// the deployment name used by AzureOpenAiImageModelIT ("dall-e-3"). The previous
// value "DALL_E_3" was the Java constant name, not a valid model name.
public static final String DALL_E_3 = "dall-e-3";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public class AzureOpenAiEmbeddingModelIT {
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.serviceVersion(System.getenv("AZURE_OPENAI_SERVICE_VERSION"))
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
.deploymentName(System.getenv("AZURE_OPENAI_DEPLOYMENT_NAME"))
.deploymentName("text-embedding-ada-002")
.logRequestsAndResponses(true)
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import com.azure.ai.openai.models.ImageGenerationResponseFormat;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.model.output.Response;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -16,7 +15,6 @@
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;

@Disabled
public class AzureOpenAiImageModelIT {

Logger logger = LoggerFactory.getLogger(AzureOpenAiImageModelIT.class);
Expand All @@ -26,8 +24,8 @@ void should_generate_image_with_url() {

AzureOpenAiImageModel model = AzureOpenAiImageModel.builder()
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.deploymentName(System.getenv("AZURE_OPENAI_DEPLOYMENT_NAME"))
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
.deploymentName("dall-e-3")
.logRequestsAndResponses(true)
.build();

Expand All @@ -50,7 +48,7 @@ void should_generate_image_with_url() {
void should_generate_image_in_base64() throws IOException {
AzureOpenAiImageModel model = AzureOpenAiImageModel.builder()
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.deploymentName(System.getenv("AZURE_OPENAI_DEPLOYMENT_NAME"))
.deploymentName("dall-e-3")
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
.logRequestsAndResponses(false) // The image is big, so we don't want to log it by default
.responseFormat(ImageGenerationResponseFormat.BASE64.toString())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,25 @@
import org.slf4j.LoggerFactory;

import static dev.langchain4j.model.output.FinishReason.LENGTH;
import static dev.langchain4j.model.output.FinishReason.STOP;
import static org.assertj.core.api.Assertions.assertThat;

class AzureOpenAiLanguageModelIT {

Logger logger = LoggerFactory.getLogger(AzureOpenAiLanguageModelIT.class);

// Shared completion model under test. Endpoint, service version and API key are
// read from environment variables, so this IT only runs where Azure OpenAI
// credentials are configured. maxTokens(20) is deliberately small: it lets the
// "length" finish-reason test truncate a long answer cheaply while still leaving
// room for the short "stop" answers.
LanguageModel model = AzureOpenAiLanguageModel.builder()
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.serviceVersion(System.getenv("AZURE_OPENAI_SERVICE_VERSION"))
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
.deploymentName("gpt-35-turbo-instruct")
.temperature(0.0)
.maxTokens(20)
.logRequestsAndResponses(true)
.build();

@Test
void should_generate_answer_and_return_token_usage_and_finish_reason_stop() {

LanguageModel model = AzureOpenAiLanguageModel.builder()
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.serviceVersion(System.getenv("AZURE_OPENAI_SERVICE_VERSION"))
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
.deploymentName("davinci-002")
.temperature(0.0)
.maxTokens(10)
.logRequestsAndResponses(true)
.build();

String prompt = "The capital of France is: ";

Response<String> response = model.generate(prompt);
Expand All @@ -40,6 +40,16 @@ void should_generate_answer_and_return_token_usage_and_finish_reason_stop() {
assertThat(tokenUsage.totalTokenCount())
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());

assertThat(response.finishReason()).isEqualTo(STOP);
}

@Test
void should_generate_answer_and_finish_reason_length() {
// Asking for a 100-word description while the shared model caps output at
// maxTokens(20) forces truncation, so the service must report LENGTH.
Response<String> response = model.generate("Describe the capital of France in 100 words: ");

logger.info(response.toString());

assertThat(response.finishReason()).isEqualTo(LENGTH);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.concurrent.CompletableFuture;

import static dev.langchain4j.model.output.FinishReason.LENGTH;
import static dev.langchain4j.model.output.FinishReason.STOP;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.assertj.core.api.Assertions.assertThat;

Expand All @@ -21,12 +22,14 @@ class AzureOpenAiStreamingLanguageModelIT {
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.serviceVersion(System.getenv("AZURE_OPENAI_SERVICE_VERSION"))
.apiKey(System.getenv("AZURE_OPENAI_KEY"))
.deploymentName("davinci-002")
.deploymentName("gpt-35-turbo-instruct")
.temperature(0.0)
.maxTokens(20)
.logRequestsAndResponses(true)
.build();

@Test
void should_stream_answer() throws Exception {
void should_stream_answer_and_finish_reason_stop() throws Exception {

CompletableFuture<String> futureAnswer = new CompletableFuture<>();
CompletableFuture<Response<String>> futureResponse = new CompletableFuture<>();
Expand Down Expand Up @@ -65,6 +68,42 @@ public void onError(Throwable error) {
assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(1);
assertThat(response.tokenUsage().totalTokenCount()).isGreaterThan(8);

assertThat(response.finishReason()).isEqualTo(STOP);
}

@Test
void should_stream_answer_and_finish_reason_length() throws Exception {

// Futures bridge the async streaming callbacks back to the test thread.
CompletableFuture<String> futureAnswer = new CompletableFuture<>();
CompletableFuture<Response<String>> futureResponse = new CompletableFuture<>();

// A 100-word request against a model capped at maxTokens(20) must be truncated,
// which the service reports as a LENGTH finish reason.
model.generate("Describe the capital of France in 100 words: ", new StreamingResponseHandler<String>() {

private final StringBuilder answerBuilder = new StringBuilder();

@Override
public void onNext(String token) {
logger.info("onNext: '" + token + "'");
answerBuilder.append(token);
}

@Override
public void onComplete(Response<String> response) {
logger.info("onComplete: '" + response + "'");
futureAnswer.complete(answerBuilder.toString());
futureResponse.complete(response);
}

@Override
public void onError(Throwable error) {
futureAnswer.completeExceptionally(error);
futureResponse.completeExceptionally(error);
}
});

String answer = futureAnswer.get(30, SECONDS);
Response<String> response = futureResponse.get(30, SECONDS);

// Previously `answer` was collected but never checked; assert the streamed
// tokens actually produced some content before truncation kicked in.
assertThat(answer).isNotBlank();

assertThat(response.finishReason()).isEqualTo(LENGTH);
}
}

0 comments on commit 3236162

Please sign in to comment.