Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for specifying the organization id in the configuration #364

Merged
merged 2 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
@Builder
public OpenAiChatModel(String baseUrl,
String apiKey,
String organizationId,
String modelName,
Double temperature,
Double topP,
Expand All @@ -67,6 +68,7 @@ public OpenAiChatModel(String baseUrl,
this.client = OpenAiClient.builder()
.openAiApiKey(apiKey)
.baseUrl(baseUrl)
.organizationId(organizationId)
.callTimeout(timeout)
.connectTimeout(timeout)
.readTimeout(timeout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public class OpenAiEmbeddingModel implements EmbeddingModel, TokenCountEstimator
@Builder
public OpenAiEmbeddingModel(String baseUrl,
String apiKey,
String organizationId,
String modelName,
Duration timeout,
Integer maxRetries,
Expand All @@ -53,6 +54,7 @@ public OpenAiEmbeddingModel(String baseUrl,
this.client = OpenAiClient.builder()
.openAiApiKey(apiKey)
.baseUrl(baseUrl)
.organizationId(organizationId)
.callTimeout(timeout)
.connectTimeout(timeout)
.readTimeout(timeout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public class OpenAiLanguageModel implements LanguageModel, TokenCountEstimator {
@Builder
public OpenAiLanguageModel(String baseUrl,
String apiKey,
String organizationId,
String modelName,
Double temperature,
Duration timeout,
Expand All @@ -49,6 +50,7 @@ public OpenAiLanguageModel(String baseUrl,
this.client = OpenAiClient.builder()
.baseUrl(getOrDefault(baseUrl, OPENAI_URL))
.openAiApiKey(apiKey)
.organizationId(organizationId)
.callTimeout(timeout)
.connectTimeout(timeout)
.readTimeout(timeout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ public class OpenAiModerationModel implements ModerationModel {
@Builder
public OpenAiModerationModel(String baseUrl,
String apiKey,
String organizationId,
String modelName,
Duration timeout,
Integer maxRetries,
Expand All @@ -53,6 +54,7 @@ public OpenAiModerationModel(String baseUrl,
this.client = OpenAiClient.builder()
.openAiApiKey(apiKey)
.baseUrl(baseUrl)
.organizationId(organizationId)
.callTimeout(timeout)
.connectTimeout(timeout)
.readTimeout(timeout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, Tok
@Builder
public OpenAiStreamingChatModel(String baseUrl,
String apiKey,
String organizationId,
String modelName,
Double temperature,
Double topP,
Expand All @@ -64,6 +65,7 @@ public OpenAiStreamingChatModel(String baseUrl,
this.client = OpenAiClient.builder()
.baseUrl(getOrDefault(baseUrl, OPENAI_URL))
.openAiApiKey(apiKey)
.organizationId(organizationId)
.callTimeout(timeout)
.connectTimeout(timeout)
.readTimeout(timeout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public class OpenAiStreamingLanguageModel implements StreamingLanguageModel, Tok
@Builder
public OpenAiStreamingLanguageModel(String baseUrl,
String apiKey,
String organizationId,
String modelName,
Double temperature,
Duration timeout,
Expand All @@ -47,6 +48,7 @@ public OpenAiStreamingLanguageModel(String baseUrl,
this.client = OpenAiClient.builder()
.baseUrl(getOrDefault(baseUrl, OPENAI_URL))
.openAiApiKey(apiKey)
.organizationId(organizationId)
.callTimeout(timeout)
.connectTimeout(timeout)
.readTimeout(timeout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class OpenAiChatModelIT {

ChatLanguageModel model = OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.temperature(0.0)
.logRequests(true)
.logResponses(true)
Expand Down Expand Up @@ -65,6 +66,7 @@ void should_generate_answer_and_return_token_usage_and_finish_reason_length() {
// given
ChatLanguageModel model = OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.maxTokens(3)
.build();

Expand Down Expand Up @@ -188,6 +190,7 @@ void should_execute_multiple_tools_in_parallel_then_answer() {
// given
ChatLanguageModel model = OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.modelName(GPT_3_5_TURBO_1106.toString()) // supports parallel function calling
.temperature(0.0)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class OpenAiLanguageModelIT {

LanguageModel model = OpenAiLanguageModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.logRequests(true)
.logResponses(true)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class OpenAiModerationModelIT {

ModerationModel model = OpenAiModerationModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.build();

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class OpenAiStreamingChatModelIT {

StreamingChatLanguageModel model = OpenAiStreamingChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.temperature(0.0)
.logRequests(true)
.logResponses(true)
Expand Down Expand Up @@ -288,6 +289,7 @@ void should_execute_multiple_tools_in_parallel_then_stream_answer() throws Excep
// given
StreamingChatLanguageModel model = OpenAiStreamingChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.modelName(GPT_3_5_TURBO_1106.toString()) // supports parallel function calling
.temperature(0.0)
.logRequests(true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class OpenAiStreamingLanguageModelIT {

StreamingLanguageModel model = OpenAiStreamingLanguageModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.logRequests(true)
.logResponses(true)
.build();
Expand Down
2 changes: 1 addition & 1 deletion langchain4j-parent/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<openai4j.version>0.11.1</openai4j.version>
<openai4j.version>0.12.1</openai4j.version>
<azure-ai-openai.version>1.0.0-beta.6</azure-ai-openai.version>
<retrofit.version>2.9.0</retrofit.version>
<okhttp.version>4.10.0</okhttp.version>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ ChatLanguageModel chatLanguageModel(LangChain4jProperties properties) {
return OpenAiChatModel.builder()
.baseUrl(openAi.getBaseUrl())
.apiKey(openAi.getApiKey())
.organizationId(openAi.getOrganizationId())
.modelName(openAi.getModelName())
.temperature(openAi.getTemperature())
.topP(openAi.getTopP())
Expand Down Expand Up @@ -142,6 +143,7 @@ LanguageModel languageModel(LangChain4jProperties properties) {
}
return OpenAiLanguageModel.builder()
.apiKey(openAi.getApiKey())
.organizationId(openAi.getOrganizationId())
.modelName(openAi.getModelName())
.temperature(openAi.getTemperature())
.timeout(openAi.getTimeout())
Expand Down Expand Up @@ -229,6 +231,7 @@ EmbeddingModel embeddingModel(LangChain4jProperties properties) {
}
return OpenAiEmbeddingModel.builder()
.apiKey(openAi.getApiKey())
.organizationId(openAi.getOrganizationId())
.modelName(openAi.getModelName())
.timeout(openAi.getTimeout())
.maxRetries(openAi.getMaxRetries())
Expand Down Expand Up @@ -306,6 +309,7 @@ ModerationModel moderationModel(LangChain4jProperties properties) {

return OpenAiModerationModel.builder()
.apiKey(openAi.getApiKey())
.organizationId(openAi.getOrganizationId())
.modelName(openAi.getModelName())
.timeout(openAi.getTimeout())
.maxRetries(openAi.getMaxRetries())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ class OpenAi {

private String baseUrl;
private String apiKey;
private String organizationId;
private String modelName;
private Double temperature;
private Double topP;
Expand Down Expand Up @@ -33,6 +34,14 @@ public void setApiKey(String apiKey) {
this.apiKey = apiKey;
}

public String getOrganizationId() {
return organizationId;
}

public void setOrganizationId(String organizationId) {
this.organizationId = organizationId;
}

public String getModelName() {
return modelName;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public class AiServicesIT {
@Spy
ChatLanguageModel chatLanguageModel = OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.temperature(0.0)
.logRequests(true)
.logResponses(true)
Expand All @@ -67,6 +68,7 @@ public class AiServicesIT {
@Spy
ModerationModel moderationModel = OpenAiModerationModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.build();

ToolSpecification calculatorSpecification = ToolSpecification.builder()
Expand Down Expand Up @@ -845,6 +847,7 @@ void should_execute_multiple_tools_in_parallel_then_answer() {

ChatLanguageModel chatLanguageModel = spy(OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.modelName(GPT_3_5_TURBO_1106)
.temperature(0.0)
.logRequests(true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public class StreamingAiServicesIT {

StreamingChatLanguageModel streamingChatModel = OpenAiStreamingChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.temperature(0.0)
.logRequests(true)
.logResponses(true)
Expand Down Expand Up @@ -298,6 +299,7 @@ void should_execute_multiple_tools_in_parallel_then_answer() throws Exception {

StreamingChatLanguageModel streamingChatModel = OpenAiStreamingChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
.modelName(GPT_3_5_TURBO_1106)
.temperature(0.0)
.logRequests(true)
Expand Down