Skip to content

Commit

Permalink
Include placeholder value for all secrets used, not just those in kwargs (#1696)
Browse files Browse the repository at this point in the history

* Include placeholder value for all secrets used, not just those in kwargs

* Fix test

* Fix test

* Add error handling for openai streaming when errors are sent as events in the SSE stream (#1698)

* Add error handling for openai streaming when errors are sent as events in the SSE stream

* Lint
  • Loading branch information
nfcampos committed Jun 19, 2023
1 parent 542c2f1 commit 276fd1c
Show file tree
Hide file tree
Showing 9 changed files with 122 additions and 63 deletions.
12 changes: 6 additions & 6 deletions langchain/src/chat_models/anthropic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ export class ChatAnthropic extends BaseChatModel implements AnthropicInput {

lc_serializable = true;

apiKey?: string;
anthropicApiKey?: string;

apiUrl?: string;

Expand Down Expand Up @@ -153,9 +153,9 @@ export class ChatAnthropic extends BaseChatModel implements AnthropicInput {
constructor(fields?: Partial<AnthropicInput> & BaseChatModelParams) {
super(fields ?? {});

this.apiKey =
this.anthropicApiKey =
fields?.anthropicApiKey ?? getEnvironmentVariable("ANTHROPIC_API_KEY");
if (!this.apiKey) {
if (!this.anthropicApiKey) {
throw new Error("Anthropic API key not found");
}

Expand Down Expand Up @@ -266,14 +266,14 @@ export class ChatAnthropic extends BaseChatModel implements AnthropicInput {
options: { signal?: AbortSignal },
runManager?: CallbackManagerForLLMRun
): Promise<CompletionResponse> {
if (!this.apiKey) {
if (!this.anthropicApiKey) {
throw new Error("Missing Anthropic API key.");
}
let makeCompletionRequest;
if (request.stream) {
if (!this.streamingClient) {
const options = this.apiUrl ? { apiUrl: this.apiUrl } : undefined;
this.streamingClient = new AnthropicApi(this.apiKey, options);
this.streamingClient = new AnthropicApi(this.anthropicApiKey, options);
}
makeCompletionRequest = async () => {
let currentCompletion = "";
Expand Down Expand Up @@ -308,7 +308,7 @@ export class ChatAnthropic extends BaseChatModel implements AnthropicInput {
} else {
if (!this.batchClient) {
const options = this.apiUrl ? { apiUrl: this.apiUrl } : undefined;
this.batchClient = new AnthropicApi(this.apiKey, options);
this.batchClient = new AnthropicApi(this.anthropicApiKey, options);
}
makeCompletionRequest = async () =>
this.batchClient
Expand Down
47 changes: 27 additions & 20 deletions langchain/src/chat_models/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ export class ChatOpenAI

maxTokens?: number;

openAIApiKey?: string;

azureOpenAIApiVersion?: string;

azureOpenAIApiKey?: string;
Expand All @@ -175,35 +177,33 @@ export class ChatOpenAI
fields?: Partial<OpenAIChatInput> &
Partial<AzureOpenAIInput> &
BaseChatModelParams & {
concurrency?: number;
cache?: boolean;
openAIApiKey?: string;
configuration?: ConfigurationParameters;
},
/** @deprecated */
configuration?: ConfigurationParameters
) {
super(fields ?? {});

const apiKey =
this.openAIApiKey =
fields?.openAIApiKey ?? getEnvironmentVariable("OPENAI_API_KEY");

const azureApiKey =
this.azureOpenAIApiKey =
fields?.azureOpenAIApiKey ??
getEnvironmentVariable("AZURE_OPENAI_API_KEY");
if (!azureApiKey && !apiKey) {

if (!this.azureOpenAIApiKey && !this.openAIApiKey) {
throw new Error("(Azure) OpenAI API key not found");
}

const azureApiInstanceName =
this.azureOpenAIApiInstanceName =
fields?.azureOpenAIApiInstanceName ??
getEnvironmentVariable("AZURE_OPENAI_API_INSTANCE_NAME");

const azureApiDeploymentName =
this.azureOpenAIApiDeploymentName =
fields?.azureOpenAIApiDeploymentName ??
getEnvironmentVariable("AZURE_OPENAI_API_DEPLOYMENT_NAME");

const azureApiVersion =
this.azureOpenAIApiVersion =
fields?.azureOpenAIApiVersion ??
getEnvironmentVariable("AZURE_OPENAI_API_VERSION");

Expand All @@ -222,11 +222,6 @@ export class ChatOpenAI

this.streaming = fields?.streaming ?? false;

this.azureOpenAIApiVersion = azureApiVersion;
this.azureOpenAIApiKey = azureApiKey;
this.azureOpenAIApiInstanceName = azureApiInstanceName;
this.azureOpenAIApiDeploymentName = azureApiDeploymentName;

if (this.streaming && this.n > 1) {
throw new Error("Cannot stream results when n > 1");
}
Expand All @@ -244,7 +239,7 @@ export class ChatOpenAI
}

this.clientConfig = {
apiKey,
apiKey: this.openAIApiKey,
...configuration,
...fields?.configuration,
};
Expand Down Expand Up @@ -327,18 +322,29 @@ export class ChatOpenAI
responseType: "stream",
onmessage: (event) => {
if (event.data?.trim?.() === "[DONE]") {
if (resolved) {
if (resolved || rejected) {
return;
}
resolved = true;
resolve(response);
} else {
const message = JSON.parse(event.data) as {
const data = JSON.parse(event.data);

if (data?.error) {
if (rejected) {
return;
}
rejected = true;
reject(data.error);
return;
}

const message = data as {
id: string;
object: string;
created: number;
model: string;
choices: Array<{
choices?: Array<{
index: number;
finish_reason: string | null;
delta: {
Expand All @@ -361,7 +367,7 @@ export class ChatOpenAI
}

// on all messages, update choice
for (const part of message.choices) {
for (const part of message.choices ?? []) {
if (part != null) {
let choice = response.choices.find(
(c) => c.index === part.index
Expand Down Expand Up @@ -414,7 +420,8 @@ export class ChatOpenAI
// when all messages are finished, resolve
if (
!resolved &&
message.choices.every((c) => c.finish_reason != null)
!rejected &&
message.choices?.every((c) => c.finish_reason != null)
) {
resolved = true;
resolve(response);
Expand Down
44 changes: 27 additions & 17 deletions langchain/src/llms/openai-chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ export class OpenAIChat

streaming = false;

openAIApiKey?: string;

azureOpenAIApiVersion?: string;

azureOpenAIApiKey?: string;
Expand All @@ -116,34 +118,35 @@ export class OpenAIChat
fields?: Partial<OpenAIChatInput> &
Partial<AzureOpenAIInput> &
BaseLLMParams & {
openAIApiKey?: string;
configuration?: ConfigurationParameters;
},
/** @deprecated */
configuration?: ConfigurationParameters
) {
super(fields ?? {});

const apiKey =
this.openAIApiKey =
fields?.openAIApiKey ?? getEnvironmentVariable("OPENAI_API_KEY");

const azureApiKey =
this.azureOpenAIApiKey =
fields?.azureOpenAIApiKey ??
getEnvironmentVariable("AZURE_OPENAI_API_KEY");

if (!azureApiKey && !apiKey) {
if (!this.azureOpenAIApiKey && !this.openAIApiKey) {
throw new Error("(Azure) OpenAI API key not found");
}

const azureApiInstanceName =
this.azureOpenAIApiInstanceName =
fields?.azureOpenAIApiInstanceName ??
getEnvironmentVariable("AZURE_OPENAI_API_INSTANCE_NAME");

const azureApiDeploymentName =
fields?.azureOpenAIApiDeploymentName ??
getEnvironmentVariable("AZURE_OPENAI_API_DEPLOYMENT_NAME");
this.azureOpenAIApiDeploymentName =
(fields?.azureOpenAIApiCompletionsDeploymentName ||
fields?.azureOpenAIApiDeploymentName) ??
(getEnvironmentVariable("AZURE_OPENAI_API_COMPLETIONS_DEPLOYMENT_NAME") ||
getEnvironmentVariable("AZURE_OPENAI_API_DEPLOYMENT_NAME"));

const azureApiVersion =
this.azureOpenAIApiVersion =
fields?.azureOpenAIApiVersion ??
getEnvironmentVariable("AZURE_OPENAI_API_VERSION");

Expand All @@ -163,11 +166,6 @@ export class OpenAIChat

this.streaming = fields?.streaming ?? false;

this.azureOpenAIApiVersion = azureApiVersion;
this.azureOpenAIApiKey = azureApiKey;
this.azureOpenAIApiInstanceName = azureApiInstanceName;
this.azureOpenAIApiDeploymentName = azureApiDeploymentName;

if (this.streaming && this.n > 1) {
throw new Error("Cannot stream results when n > 1");
}
Expand All @@ -185,7 +183,7 @@ export class OpenAIChat
}

this.clientConfig = {
apiKey,
apiKey: this.openAIApiKey,
...configuration,
...fields?.configuration,
};
Expand Down Expand Up @@ -266,13 +264,24 @@ export class OpenAIChat
responseType: "stream",
onmessage: (event) => {
if (event.data?.trim?.() === "[DONE]") {
if (resolved) {
if (resolved || rejected) {
return;
}
resolved = true;
resolve(response);
} else {
const message = JSON.parse(event.data) as {
const data = JSON.parse(event.data);

if (data?.error) {
if (rejected) {
return;
}
rejected = true;
reject(data.error);
return;
}

const message = data as {
id: string;
object: string;
created: number;
Expand Down Expand Up @@ -329,6 +338,7 @@ export class OpenAIChat
// when all messages are finished, resolve
if (
!resolved &&
!rejected &&
message.choices.every((c) => c.finish_reason != null)
) {
resolved = true;
Expand Down
38 changes: 23 additions & 15 deletions langchain/src/llms/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ export class OpenAI extends BaseLLM implements OpenAIInput, AzureOpenAIInput {

streaming = false;

openAIApiKey?: string;

azureOpenAIApiVersion?: string;

azureOpenAIApiKey?: string;
Expand All @@ -120,7 +122,6 @@ export class OpenAI extends BaseLLM implements OpenAIInput, AzureOpenAIInput {
fields?: Partial<OpenAIInput> &
Partial<AzureOpenAIInput> &
BaseLLMParams & {
openAIApiKey?: string;
configuration?: ConfigurationParameters;
},
/** @deprecated */
Expand All @@ -136,28 +137,28 @@ export class OpenAI extends BaseLLM implements OpenAIInput, AzureOpenAIInput {
}
super(fields ?? {});

const apiKey =
this.openAIApiKey =
fields?.openAIApiKey ?? getEnvironmentVariable("OPENAI_API_KEY");

const azureApiKey =
this.azureOpenAIApiKey =
fields?.azureOpenAIApiKey ??
getEnvironmentVariable("AZURE_OPENAI_API_KEY");

if (!azureApiKey && !apiKey) {
if (!this.azureOpenAIApiKey && !this.openAIApiKey) {
throw new Error("(Azure) OpenAI API key not found");
}

const azureApiInstanceName =
this.azureOpenAIApiInstanceName =
fields?.azureOpenAIApiInstanceName ??
getEnvironmentVariable("AZURE_OPENAI_API_INSTANCE_NAME");

const azureApiDeploymentName =
this.azureOpenAIApiDeploymentName =
(fields?.azureOpenAIApiCompletionsDeploymentName ||
fields?.azureOpenAIApiDeploymentName) ??
(getEnvironmentVariable("AZURE_OPENAI_API_COMPLETIONS_DEPLOYMENT_NAME") ||
getEnvironmentVariable("AZURE_OPENAI_API_DEPLOYMENT_NAME"));

const azureApiVersion =
this.azureOpenAIApiVersion =
fields?.azureOpenAIApiVersion ??
getEnvironmentVariable("AZURE_OPENAI_API_VERSION");

Expand All @@ -178,11 +179,6 @@ export class OpenAI extends BaseLLM implements OpenAIInput, AzureOpenAIInput {

this.streaming = fields?.streaming ?? false;

this.azureOpenAIApiVersion = azureApiVersion;
this.azureOpenAIApiKey = azureApiKey;
this.azureOpenAIApiInstanceName = azureApiInstanceName;
this.azureOpenAIApiDeploymentName = azureApiDeploymentName;

if (this.streaming && this.n > 1) {
throw new Error("Cannot stream results when n > 1");
}
Expand All @@ -204,7 +200,7 @@ export class OpenAI extends BaseLLM implements OpenAIInput, AzureOpenAIInput {
}

this.clientConfig = {
apiKey,
apiKey: this.openAIApiKey,
...configuration,
...fields?.configuration,
};
Expand Down Expand Up @@ -310,7 +306,7 @@ export class OpenAI extends BaseLLM implements OpenAIInput, AzureOpenAIInput {
responseType: "stream",
onmessage: (event) => {
if (event.data?.trim?.() === "[DONE]") {
if (resolved) {
if (resolved || rejected) {
return;
}
resolved = true;
Expand All @@ -319,7 +315,18 @@ export class OpenAI extends BaseLLM implements OpenAIInput, AzureOpenAIInput {
choices,
});
} else {
const message = JSON.parse(event.data) as Omit<
const data = JSON.parse(event.data);

if (data?.error) {
if (rejected) {
return;
}
rejected = true;
reject(data.error);
return;
}

const message = data as Omit<
CreateCompletionResponse,
"usage"
>;
Expand Down Expand Up @@ -352,6 +359,7 @@ export class OpenAI extends BaseLLM implements OpenAIInput, AzureOpenAIInput {
// when all messages are finished, resolve
if (
!resolved &&
!rejected &&
choices.every((c) => c.finish_reason != null)
) {
resolved = true;
Expand Down

1 comment on commit 276fd1c

@vercel
Copy link

@vercel vercel bot commented on 276fd1c Jun 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.