-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
yandex.ts
128 lines (109 loc) · 3.41 KB
/
yandex.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import { CallbackManagerForLLMRun } from "../callbacks/manager.js";
import { YandexGPTInputs } from "../llms/yandex.js";
import {
AIMessage,
BaseMessage,
ChatResult,
ChatGeneration,
} from "../schema/index.js";
import { getEnvironmentVariable } from "../util/env.js";
import { BaseChatModel } from "./base.js";
/** Endpoint of the YandexGPT v1alpha chat API that `ChatYandexGPT` posts to. */
const apiUrl = "https://llm.api.cloud.yandex.net/llm/v1alpha/chat";

/** One conversation turn in the `messages` array of a YandexGPT chat request. */
interface ParsedMessage {
  /** Speaker of the turn; this module only ever produces these two roles. */
  role: "user" | "assistant";
  /** Plain-text content of the turn. */
  text: string;
}
/**
 * Converts LangChain chat messages into the YandexGPT request format.
 *
 * Human messages become `user` turns and AI messages become `assistant`
 * turns, preserving their order. System messages are not sent as turns;
 * their text is returned separately as the instruction.
 *
 * @param history - Messages to convert; every message must have string content.
 * @returns A `[chatHistory, instruction]` tuple. `instruction` is the text of
 *   the last system message in `history`, or `""` if there is none.
 * @throws Error if any message has non-string (e.g. multimodal) content.
 */
function _parseChatHistory(history: BaseMessage[]): [ParsedMessage[], string] {
  const chatHistory: ParsedMessage[] = [];
  let instruction = "";
  for (const message of history) {
    const { content } = message;
    // Validating first also guarantees `content` exists, so no further
    // presence check is needed below.
    if (typeof content !== "string") {
      throw new Error(
        "ChatYandexGPT does not support non-string message content."
      );
    }
    switch (message._getType()) {
      case "human":
        chatHistory.push({ role: "user", text: content });
        break;
      case "ai":
        chatHistory.push({ role: "assistant", text: content });
        break;
      case "system":
        // Only one instruction is sent, so a later system message replaces
        // an earlier one.
        instruction = content;
        break;
      default:
        // Any other message type has no mapping here and is skipped.
        break;
    }
  }
  return [chatHistory, instruction];
}
/**
 * LangChain chat model backed by the YandexGPT chat API.
 *
 * Credentials come from the constructor fields or, when those are absent,
 * from the `YC_API_KEY` / `YC_IAM_TOKEN` environment variables. When both an
 * API key and an IAM token are available, the API key is used.
 */
export class ChatYandexGPT extends BaseChatModel {
  /** Yandex Cloud API key; takes precedence over `iamToken` when set. */
  apiKey?: string;

  /** Yandex Cloud IAM token; used only when `apiKey` is undefined. */
  iamToken?: string;

  /** Value sent as `generationOptions.temperature`. */
  temperature = 0.6;

  /** Value sent as `generationOptions.maxTokens`. */
  maxTokens = 1700;

  /** Model identifier sent as `model` in the request body. */
  model = "general";

  /**
   * @param fields - Optional credentials and generation settings; unset
   *   settings keep the defaults declared above.
   * @throws Error if neither an API key nor an IAM token is provided or
   *   found in the environment.
   */
  constructor(fields?: YandexGPTInputs) {
    super(fields ?? {});
    const apiKey = fields?.apiKey ?? getEnvironmentVariable("YC_API_KEY");
    const iamToken = fields?.iamToken ?? getEnvironmentVariable("YC_IAM_TOKEN");
    if (apiKey === undefined && iamToken === undefined) {
      throw new Error(
        "Please set the YC_API_KEY or YC_IAM_TOKEN environment variable or pass it to the constructor as the apiKey or iamToken field."
      );
    }
    this.apiKey = apiKey;
    this.iamToken = iamToken;
    this.maxTokens = fields?.maxTokens ?? this.maxTokens;
    this.temperature = fields?.temperature ?? this.temperature;
    this.model = fields?.model ?? this.model;
  }

  _llmType() {
    return "yandexgpt";
  }

  _combineLLMOutput?() {
    return {};
  }

  /** @ignore */
  async _generate(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    _?: CallbackManagerForLLMRun | undefined
  ): Promise<ChatResult> {
    const [messageHistory, instruction] = _parseChatHistory(messages);

    // The API key wins when both credentials are set; the constructor
    // ensures at least one of them was present at construction time.
    const authorization =
      this.apiKey !== undefined
        ? `Api-Key ${this.apiKey}`
        : `Bearer ${this.iamToken}`;
    const headers = {
      "Content-Type": "application/json",
      Authorization: authorization,
    };

    const bodyData = {
      model: this.model,
      generationOptions: {
        temperature: this.temperature,
        maxTokens: this.maxTokens,
      },
      messages: messageHistory,
      instructionText: instruction,
    };

    const response = await fetch(apiUrl, {
      method: "POST",
      headers,
      body: JSON.stringify(bodyData),
      signal: options?.signal,
    });

    if (!response.ok) {
      // Attach the response body (best-effort) so API-side errors are
      // diagnosable instead of surfacing only a bare status code.
      let details = "";
      try {
        details = await response.text();
      } catch {
        // Reading the body failed; the status code is still reported.
      }
      throw new Error(
        `Failed to fetch ${apiUrl} from YandexGPT: ${response.status}${
          details ? ` ${details}` : ""
        }`
      );
    }

    // Only the fields read below are declared; the payload is otherwise untrusted.
    const responseData: {
      result?: { message?: { text?: unknown }; num_tokens?: unknown };
    } = await response.json();
    const text = responseData?.result?.message?.text;
    if (typeof text !== "string") {
      throw new Error(
        `Unexpected response from YandexGPT: ${JSON.stringify(responseData)}`
      );
    }
    // NOTE(review): passed through exactly as returned by the API; it may be
    // serialized as a string rather than a number — confirm against API docs.
    const totalTokens = responseData.result?.num_tokens;

    const generations: ChatGeneration[] = [
      { text, message: new AIMessage(text) },
    ];
    return {
      generations,
      llmOutput: { totalTokens },
    };
  }
}