gradient_ai.ts
import { Gradient } from "@gradientai/nodejs-sdk";
import {
  type BaseLLMCallOptions,
  type BaseLLMParams,
  LLM,
} from "@langchain/core/language_models/llms";
import { getEnvironmentVariable } from "@langchain/core/utils/env";

/**
 * The GradientLLMParams interface defines the input parameters for
 * the GradientLLM class.
 */
export interface GradientLLMParams extends BaseLLMParams {
  /**
   * Gradient AI access token.
   * Provide the access token here if you do not wish to pull it
   * automatically from the GRADIENT_ACCESS_TOKEN environment variable.
   */
  gradientAccessKey?: string;

  /**
   * Gradient workspace id.
   * Provide the workspace id here if you do not wish to pull it
   * automatically from the GRADIENT_WORKSPACE_ID environment variable.
   */
  workspaceId?: string;

  /**
   * Inference parameters accepted by the Gradient npm package,
   * spread into each completion call.
   */
  inferenceParameters?: Record<string, unknown>;

  /**
   * Gradient AI model slug identifying the base model to use.
   */
  modelSlug?: string;

  /**
   * Gradient adapter id for custom fine-tuned models. When set, the
   * adapter is used instead of the base model named by `modelSlug`.
   */
  adapterId?: string;
}
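
// Illustrative sketch (not part of the original module): a params object
// with every field set explicitly. The values are hypothetical placeholders;
// in practice the token and workspace id are usually read from the
// environment, and the inference parameter names shown here are assumptions
// about what the Gradient npm package accepts.
//
// const params: GradientLLMParams = {
//   gradientAccessKey: "my-access-token",
//   workspaceId: "my-workspace-id",
//   modelSlug: "llama2-7b-chat",
//   inferenceParameters: { maxGeneratedTokenCount: 20, temperature: 0 },
// };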

/**
 * The GradientLLM class is used to interact with Gradient AI inference
 * endpoint models. It requires a Gradient AI access token, which is
 * loaded automatically from the environment if not specified.
 */
export class GradientLLM extends LLM<BaseLLMCallOptions> {
  static lc_name() {
    return "GradientLLM";
  }

  /**
   * Maps constructor fields to the environment variable names that hold
   * the corresponding secrets, so serialization can redact them.
   */
  get lc_secrets(): { [key: string]: string } | undefined {
    return {
      gradientAccessKey: "GRADIENT_ACCESS_TOKEN",
      workspaceId: "GRADIENT_WORKSPACE_ID",
    };
  }

  modelSlug = "llama2-7b-chat";

  adapterId?: string;

  gradientAccessKey?: string;

  workspaceId?: string;

  inferenceParameters?: Record<string, unknown>;

  lc_serializable = true;

  // Gradient AI does not export the BaseModel type. Once it does, we can use it here.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  model: any;

  constructor(fields: GradientLLMParams) {
    super(fields);

    this.modelSlug = fields?.modelSlug ?? this.modelSlug;
    this.adapterId = fields?.adapterId;
    this.gradientAccessKey =
      fields?.gradientAccessKey ??
      getEnvironmentVariable("GRADIENT_ACCESS_TOKEN");
    this.workspaceId =
      fields?.workspaceId ?? getEnvironmentVariable("GRADIENT_WORKSPACE_ID");
    this.inferenceParameters = fields.inferenceParameters;

    if (!this.gradientAccessKey) {
      throw new Error("Missing Gradient AI Access Token");
    }
    if (!this.workspaceId) {
      throw new Error("Missing Gradient AI Workspace ID");
    }
  }

  _llmType() {
    return "gradient_ai";
  }

  /**
   * Calls the Gradient AI endpoint and retrieves the result.
   * @param {string} prompt The input prompt.
   * @returns {Promise<string>} A promise that resolves to the generated string.
   */
  /** @ignore */
  async _call(
    prompt: string,
    _options: this["ParsedCallOptions"]
  ): Promise<string> {
    await this.setModel();

    // Gradient AI does not export the CompleteResponse type. Once it does, we can use it here.
    interface CompleteResponse {
      finishReason: string;
      generatedOutput: string;
    }

    const response = (await this.caller.call(async () =>
      this.model.complete({
        query: prompt,
        ...this.inferenceParameters,
      })
    )) as CompleteResponse;

    return response.generatedOutput;
  }

  /**
   * Lazily instantiates the Gradient client and resolves either the
   * model adapter (when `adapterId` is set) or the base model.
   */
  async setModel() {
    if (this.model) return;

    const gradient = new Gradient({
      accessToken: this.gradientAccessKey,
      workspaceId: this.workspaceId,
    });

    if (this.adapterId) {
      this.model = await gradient.getModelAdapter({
        modelAdapterId: this.adapterId,
      });
    } else {
      this.model = await gradient.getBaseModel({
        baseModelSlug: this.modelSlug,
      });
    }
  }
}
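
// Usage sketch (illustrative, not part of the original module). It assumes
// GRADIENT_ACCESS_TOKEN and GRADIENT_WORKSPACE_ID are set in the environment,
// that `invoke` inherited from the base LLM class is the public entry point,
// and that the inference parameter names match the Gradient npm package.
//
// const llm = new GradientLLM({
//   modelSlug: "llama2-7b-chat",
//   inferenceParameters: { maxGeneratedTokenCount: 20, temperature: 0 },
// });
// const output = await llm.invoke("What is the capital of France?");
// console.log(output);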