-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
fireworks.ts
127 lines (107 loc) · 3.68 KB
/
fireworks.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import type { OpenAI as OpenAIClient } from "openai";
import type { ChatOpenAICallOptions, OpenAIChatInput } from "./openai.js";
import type { OpenAICoreRequestOptions } from "../types/openai-types.js";
import type { BaseChatModelParams } from "./base.js";
import { ChatOpenAI } from "./openai.js";
import { getEnvironmentVariable } from "../util/env.js";
/**
 * ChatOpenAI constructor fields that ChatFireworks treats as unsupported by
 * Fireworks; they are omitted from the ChatFireworks constructor's field type.
 */
type FireworksUnsupportedArgs =
  | "functions"
  | "logitBias"
  | "presencePenalty"
  | "frequencyPenalty";

/**
 * ChatOpenAI per-call options that ChatFireworks treats as unsupported by
 * Fireworks; they are omitted from {@link ChatFireworksCallOptions}.
 */
type FireworksUnsupportedCallOptions = "tools" | "function_call" | "functions";

/**
 * Call options accepted by ChatFireworks: the ChatOpenAI call options with the
 * unsupported keys removed and every remaining key made optional.
 */
export type ChatFireworksCallOptions = Partial<
  Omit<ChatOpenAICallOptions, FireworksUnsupportedCallOptions>
>;
/**
 * Wrapper around Fireworks API for large language models fine-tuned for chat
 *
 * Fireworks API is compatible with the OpenAI API with some limitations described in
 * https://readme.fireworks.ai/docs/openai-compatibility.
 *
 * To use, you should have the `openai` package installed and either set the
 * `FIREWORKS_API_KEY` environment variable or pass `fireworksApiKey` to the
 * constructor.
 */
export class ChatFireworks extends ChatOpenAI<ChatFireworksCallOptions> {
  /** Class name reported by this model ("ChatFireworks"). */
  static lc_name() {
    return "ChatFireworks";
  }

  /** Model-type identifier for this wrapper ("fireworks"). */
  _llmType() {
    return "fireworks";
  }

  /**
   * Maps the `fireworksApiKey` field to the `FIREWORKS_API_KEY` environment
   * variable name (presumably so serialization treats it as a secret).
   */
  get lc_secrets(): { [key: string]: string } | undefined {
    return {
      fireworksApiKey: "FIREWORKS_API_KEY",
    };
  }

  lc_serializable = true;

  /** API key used to authenticate against the Fireworks endpoint. */
  fireworksApiKey?: string;

  /**
   * @param fields ChatOpenAI fields (minus `openAIApiKey` and the options
   *   Fireworks does not support) plus an optional `fireworksApiKey`.
   * @throws Error if no API key is given in `fields.fireworksApiKey` and the
   *   `FIREWORKS_API_KEY` environment variable is not set.
   */
  constructor(
    fields?: Partial<
      Omit<OpenAIChatInput, "openAIApiKey" | FireworksUnsupportedArgs>
    > &
      BaseChatModelParams & { fireworksApiKey?: string }
  ) {
    const fireworksApiKey =
      fields?.fireworksApiKey || getEnvironmentVariable("FIREWORKS_API_KEY");

    if (!fireworksApiKey) {
      throw new Error(
        `Fireworks API key not found. Please set the FIREWORKS_API_KEY environment variable or provide the key into "fireworksApiKey"`
      );
    }

    super({
      ...fields,
      // `||` (not `??`) so an empty-string model name also falls back.
      modelName:
        fields?.modelName || "accounts/fireworks/models/llama-v2-13b-chat",
      // The OpenAI-compatible client is authenticated with the Fireworks key
      // and pointed at the Fireworks inference endpoint.
      openAIApiKey: fireworksApiKey,
      configuration: {
        baseURL: "https://api.fireworks.ai/inference/v1",
      },
    });

    this.fireworksApiKey = fireworksApiKey;
  }

  /**
   * Serializes the model, dropping the inherited `openai_api_key` and
   * `configuration` kwargs: the former holds the Fireworks key, and the latter
   * is always set by the constructor.
   */
  toJSON() {
    const result = super.toJSON();

    if (
      "kwargs" in result &&
      typeof result.kwargs === "object" &&
      result.kwargs != null
    ) {
      delete result.kwargs.openai_api_key;
      delete result.kwargs.configuration;
    }

    return result;
  }

  /**
   * Calls the Fireworks API with retry logic in case of failures.
   *
   * Before delegating to the OpenAI implementation, the request parameters
   * that Fireworks does not support (`frequency_penalty`, `presence_penalty`,
   * `logit_bias`, `functions`) are removed. The caller's `request` object is
   * left untouched; the cleanup is applied to a shallow copy.
   *
   * @param request The request to send to the Fireworks API.
   * @param options Optional configuration for the API call.
   * @returns The response from the Fireworks API: a chunk stream when
   *   `request.stream` is `true`, otherwise a single completion.
   */
  async completionWithRetry(
    request: OpenAIClient.Chat.ChatCompletionCreateParamsStreaming,
    options?: OpenAICoreRequestOptions
  ): Promise<AsyncIterable<OpenAIClient.Chat.Completions.ChatCompletionChunk>>;

  async completionWithRetry(
    request: OpenAIClient.Chat.ChatCompletionCreateParamsNonStreaming,
    options?: OpenAICoreRequestOptions
  ): Promise<OpenAIClient.Chat.Completions.ChatCompletion>;

  async completionWithRetry(
    request:
      | OpenAIClient.Chat.ChatCompletionCreateParamsStreaming
      | OpenAIClient.Chat.ChatCompletionCreateParamsNonStreaming,
    options?: OpenAICoreRequestOptions
  ): Promise<
    | AsyncIterable<OpenAIClient.Chat.Completions.ChatCompletionChunk>
    | OpenAIClient.Chat.Completions.ChatCompletion
  > {
    // Copy first: deleting keys directly on `request` would mutate an object
    // the caller still owns.
    const sanitized = { ...request };
    delete sanitized.frequency_penalty;
    delete sanitized.presence_penalty;
    delete sanitized.logit_bias;
    delete sanitized.functions;

    // Both branches make the same call. The `stream === true` check narrows the
    // union so TypeScript can select the matching `super` overload.
    if (sanitized.stream === true) {
      return super.completionWithRetry(sanitized, options);
    }
    return super.completionWithRetry(sanitized, options);
  }
}