Skip to content

Commit

Permalink
Allow for specifying the organization id in the configuration
Browse files Browse the repository at this point in the history
Fixes #344
  • Loading branch information
edeandrea committed Dec 18, 2023
1 parent 99faffe commit 1fa9188
Show file tree
Hide file tree
Showing 13 changed files with 34 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
@Builder
public OpenAiChatModel(String baseUrl,
String apiKey,
String organizationId,
String modelName,
Double temperature,
Double topP,
Expand All @@ -67,6 +68,7 @@ public OpenAiChatModel(String baseUrl,
this.client = OpenAiClient.builder()
.openAiApiKey(apiKey)
.baseUrl(baseUrl)
.organizationId(organizationId)
.callTimeout(timeout)
.connectTimeout(timeout)
.readTimeout(timeout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public class OpenAiEmbeddingModel implements EmbeddingModel, TokenCountEstimator
@Builder
public OpenAiEmbeddingModel(String baseUrl,
String apiKey,
String organizationId,
String modelName,
Duration timeout,
Integer maxRetries,
Expand All @@ -53,6 +54,7 @@ public OpenAiEmbeddingModel(String baseUrl,
this.client = OpenAiClient.builder()
.openAiApiKey(apiKey)
.baseUrl(baseUrl)
.organizationId(organizationId)
.callTimeout(timeout)
.connectTimeout(timeout)
.readTimeout(timeout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public class OpenAiLanguageModel implements LanguageModel, TokenCountEstimator {
@Builder
public OpenAiLanguageModel(String baseUrl,
String apiKey,
String organizationId,
String modelName,
Double temperature,
Duration timeout,
Expand All @@ -49,6 +50,7 @@ public OpenAiLanguageModel(String baseUrl,
this.client = OpenAiClient.builder()
.baseUrl(getOrDefault(baseUrl, OPENAI_URL))
.openAiApiKey(apiKey)
.organizationId(organizationId)
.callTimeout(timeout)
.connectTimeout(timeout)
.readTimeout(timeout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ public class OpenAiModerationModel implements ModerationModel {
@Builder
public OpenAiModerationModel(String baseUrl,
String apiKey,
String organizationId,
String modelName,
Duration timeout,
Integer maxRetries,
Expand All @@ -53,6 +54,7 @@ public OpenAiModerationModel(String baseUrl,
this.client = OpenAiClient.builder()
.openAiApiKey(apiKey)
.baseUrl(baseUrl)
.organizationId(organizationId)
.callTimeout(timeout)
.connectTimeout(timeout)
.readTimeout(timeout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
@Builder
public OpenAiStreamingChatModel(String baseUrl,
String apiKey,
String organizationId,
String modelName,
Double temperature,
Double topP,
Expand All @@ -64,6 +65,7 @@ public OpenAiStreamingChatModel(String baseUrl,
this.client = OpenAiClient.builder()
.baseUrl(getOrDefault(baseUrl, OPENAI_URL))
.openAiApiKey(apiKey)
.organizationId(organizationId)
.callTimeout(timeout)
.connectTimeout(timeout)
.readTimeout(timeout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public class OpenAiStreamingLanguageModel implements StreamingLanguageModel, Tok
@Builder
public OpenAiStreamingLanguageModel(String baseUrl,
String apiKey,
String organizationId,
String modelName,
Double temperature,
Duration timeout,
Expand All @@ -47,6 +48,7 @@ public OpenAiStreamingLanguageModel(String baseUrl,
this.client = OpenAiClient.builder()
.baseUrl(getOrDefault(baseUrl, OPENAI_URL))
.openAiApiKey(apiKey)
.organizationId(organizationId)
.callTimeout(timeout)
.connectTimeout(timeout)
.readTimeout(timeout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class OpenAiChatModelIT {

ChatLanguageModel model = OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.temperature(0.0)
.logRequests(true)
.logResponses(true)
Expand Down Expand Up @@ -65,6 +66,7 @@ void should_generate_answer_and_return_token_usage_and_finish_reason_length() {
// given
ChatLanguageModel model = OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.maxTokens(3)
.build();

Expand Down Expand Up @@ -188,6 +190,7 @@ void should_execute_multiple_tools_in_parallel_then_answer() {
// given
ChatLanguageModel model = OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.modelName(GPT_3_5_TURBO_1106.toString()) // supports parallel function calling
.temperature(0.0)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class OpenAiLanguageModelIT {

LanguageModel model = OpenAiLanguageModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.logRequests(true)
.logResponses(true)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class OpenAiModerationModelIT {

ModerationModel model = OpenAiModerationModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.build();

@Test
Expand Down
2 changes: 1 addition & 1 deletion langchain4j-parent/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<openai4j.version>0.11.1</openai4j.version>
<openai4j.version>0.12.1</openai4j.version>
<azure-ai-openai.version>1.0.0-beta.6</azure-ai-openai.version>
<retrofit.version>2.9.0</retrofit.version>
<okhttp.version>4.10.0</okhttp.version>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ ChatLanguageModel chatLanguageModel(LangChain4jProperties properties) {
return OpenAiChatModel.builder()
.baseUrl(openAi.getBaseUrl())
.apiKey(openAi.getApiKey())
.organizationId(openAi.getOrganizationId())
.modelName(openAi.getModelName())
.temperature(openAi.getTemperature())
.topP(openAi.getTopP())
Expand Down Expand Up @@ -142,6 +143,7 @@ LanguageModel languageModel(LangChain4jProperties properties) {
}
return OpenAiLanguageModel.builder()
.apiKey(openAi.getApiKey())
.organizationId(openAi.getOrganizationId())
.modelName(openAi.getModelName())
.temperature(openAi.getTemperature())
.timeout(openAi.getTimeout())
Expand Down Expand Up @@ -229,6 +231,7 @@ EmbeddingModel embeddingModel(LangChain4jProperties properties) {
}
return OpenAiEmbeddingModel.builder()
.apiKey(openAi.getApiKey())
.organizationId(openAi.getOrganizationId())
.modelName(openAi.getModelName())
.timeout(openAi.getTimeout())
.maxRetries(openAi.getMaxRetries())
Expand Down Expand Up @@ -306,6 +309,7 @@ ModerationModel moderationModel(LangChain4jProperties properties) {

return OpenAiModerationModel.builder()
.apiKey(openAi.getApiKey())
.organizationId(openAi.getOrganizationId())
.modelName(openAi.getModelName())
.timeout(openAi.getTimeout())
.maxRetries(openAi.getMaxRetries())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ class OpenAi {

private String baseUrl;
private String apiKey;
private String organizationId;
private String modelName;
private Double temperature;
private Double topP;
Expand Down Expand Up @@ -33,6 +34,14 @@ public void setApiKey(String apiKey) {
this.apiKey = apiKey;
}

public String getOrganizationId() {
return organizationId;
}

public void setOrganizationId(String organizationId) {
this.organizationId = organizationId;
}

public String getModelName() {
return modelName;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public class AiServicesIT {
@Spy
ChatLanguageModel chatLanguageModel = OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.temperature(0.0)
.logRequests(true)
.logResponses(true)
Expand All @@ -67,6 +68,7 @@ public class AiServicesIT {
@Spy
ModerationModel moderationModel = OpenAiModerationModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.build();

ToolSpecification calculatorSpecification = ToolSpecification.builder()
Expand Down Expand Up @@ -845,6 +847,7 @@ void should_execute_multiple_tools_in_parallel_then_answer() {

ChatLanguageModel chatLanguageModel = spy(OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.modelName(GPT_3_5_TURBO_1106)
.temperature(0.0)
.logRequests(true)
Expand Down

0 comments on commit 1fa9188

Please sign in to comment.