Skip to content

Commit

Permalink
add changes needed for SAP Generative AI Hub
Browse files Browse the repository at this point in the history
  • Loading branch information
mbrner committed Jan 9, 2024
1 parent e0ff9b9 commit dc80fda
Show file tree
Hide file tree
Showing 7 changed files with 265 additions and 18 deletions.
16 changes: 14 additions & 2 deletions core/config/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ declare global {
apiType?: string;
region?: string;
projectId?: string;
resourceGroup?: string;
authURL?: string;
clientID?: string;
clientSecret?: string;
_fetch?: (input: any, init?: any) => Promise<any>;
Expand Down Expand Up @@ -198,6 +202,12 @@ declare global {
// GCP Options
region?: string;
projectId?: string;
// SAP Gen AI Core options
resourceGroup?: string;
authURL?: string;
clientID?: string;
clientSecret?: string;
}
type RequireAtLeastOne<T, Keys extends keyof T = keyof T> = Pick<
T,
Expand Down Expand Up @@ -231,6 +241,7 @@ declare global {
// IDE
export interface DiffLine {
type: "new" | "old" | "same";
line: string;
Expand Down Expand Up @@ -342,6 +353,7 @@ declare global {
| "gemini"
| "mistral"
| "bedrock"
| "sap-gen-ai-hub"
| "deepinfra";
export type ModelName =
Expand Down Expand Up @@ -509,8 +521,8 @@ declare global {
disableIndexing?: boolean;
userToken?: string;
}
}
export {};
Expand Down
12 changes: 12 additions & 0 deletions core/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ export interface ILLM extends LLMOptions {
apiType?: string;
region?: string;
projectId?: string;
// SAP Gen AI Core options
resourceGroup?: string;
authURL?: string;
clientID?: string;
clientSecret?: string;

_fetch?: (input: any, init?: any) => Promise<any>;

Expand Down Expand Up @@ -196,6 +201,12 @@ export interface LLMOptions {
// GCP Options
region?: string;
projectId?: string;

// SAP Gen AI Core options
resourceGroup?: string;
authURL?: string;
clientID?: string;
clientSecret?: string;
}
type RequireAtLeastOne<T, Keys extends keyof T = keyof T> = Pick<
T,
Expand Down Expand Up @@ -340,6 +351,7 @@ type ModelProvider =
| "gemini"
| "mistral"
| "bedrock"
| "sap-gen-ai-hub"
| "deepinfra";

export type ModelName =
Expand Down
15 changes: 14 additions & 1 deletion core/llm/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ function autodetectTemplateType(model: string): TemplateType | undefined {
if (lower.includes("mistral")) {
return "llama2";
}

if (lower.includes("deepseek")) {
return "deepseek";
}
Expand Down Expand Up @@ -202,6 +201,12 @@ export abstract class BaseLLM implements ILLM {
region?: string;
projectId?: string;

// SAP Gen AI Core options
resourceGroup?: string;
authURL?: string;
clientID?: string;
clientSecret?: string;

private _llmOptions: LLMOptions;

constructor(options: LLMOptions) {
Expand Down Expand Up @@ -252,6 +257,14 @@ export abstract class BaseLLM implements ILLM {
this.apiType = options.apiType;
this.region = options.region;
this.projectId = options.projectId;

// SAP Gen AI Core options
this.resourceGroup = options.resourceGroup;
this.authURL = options.authURL;
this.clientID = options.clientID;
this.clientSecret = options.clientSecret;


}

private _compileChatMessages(
Expand Down
33 changes: 19 additions & 14 deletions core/llm/llms/OpenAI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class OpenAI extends BaseLLM {

return completion;
}
private _getCompletionUrl() {
protected _getCompletionUrl() {
if (this.apiType === "azure") {
return `${this.apiBase}/openai/deployments/${this.engine}/completions?api-version=${this.apiVersion}`;
} else {
Expand All @@ -79,6 +79,15 @@ class OpenAI extends BaseLLM {
}
}

protected async _getRequestHeaders(): Promise<Record<string, string>> {
return {
"Content-Type": "application/json",
Authorization: `Bearer ${this.apiKey}`,
"api-key": this.apiKey || "", // For Azure
};
}


protected async *_streamComplete(
prompt: string,
options: CompletionOptions
Expand All @@ -95,13 +104,11 @@ class OpenAI extends BaseLLM {
prompt: string,
options: CompletionOptions
): AsyncGenerator<string> {
const response = await this.fetch(this._getCompletionUrl(), {
const header = await this._getRequestHeaders();
const url = this._getCompletionUrl();
const response = await this.fetch(url, {
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${this.apiKey}`,
"api-key": this.apiKey || "", // For Azure
},
headers: header,
body: JSON.stringify({
...{
prompt,
Expand All @@ -124,7 +131,7 @@ class OpenAI extends BaseLLM {
}
}

private _getChatUrl() {
protected _getChatUrl() {
if (this.apiType === "azure") {
return `${this.apiBase}/openai/deployments/${this.engine}/chat/completions?api-version=${this.apiVersion}`;
} else {
Expand Down Expand Up @@ -163,13 +170,11 @@ class OpenAI extends BaseLLM {
return;
}

const response = await this.fetch(this._getChatUrl(), {
const header = await this._getRequestHeaders();
const url = this._getChatUrl();
const response = await this.fetch(url, {
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${this.apiKey}`,
"api-key": this.apiKey || "", // For Azure
},
headers: header,
body: JSON.stringify({
...this._convertArgs(options, messages),
stream: true,
Expand Down
114 changes: 114 additions & 0 deletions core/llm/llms/SAPGenAIHub.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import { ModelProvider } from "../..";
import OpenAI from "./OpenAI";

class SAPGenAIHub extends OpenAI {
private tokenCache: { token: string; expiry: number } | null = null;

static providerName: ModelProvider = "sap-gen-ai-hub";

private _getTokenParams(): { authURL: string; clientID: string; clientSecret: string } {
if (!this.authURL || !this.clientID || !this.clientSecret) {
throw new Error("Authentication parameters (authURL, clientID, clientSecret) are undefined");
}
return {
authURL: this.authURL.endsWith("/oauth/token") ? this.authURL : `${this.authURL}/oauth/token`,
clientID: this.clientID,
clientSecret: this.clientSecret,
};
}

private async fetchWithTimeout(url: string, options: RequestInit, timeout: number): Promise<Response> {
return new Promise((resolve, reject) => {
const timer = setTimeout(() => reject(new Error('Request timed out')), timeout);

fetch(url, options)
.then(response => {
clearTimeout(timer);
resolve(response);
})
.catch(err => {
clearTimeout(timer);
reject(err);
});
});
}

private async fetchOAuthToken(): Promise<string> {
const params = this._getTokenParams();
const response = await this.fetchWithTimeout(params.authURL, {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: new URLSearchParams({
client_id: params.clientID,
client_secret: params.clientSecret,
grant_type: 'client_credentials',
}).toString(),
}, 10000); // 10-second timeout

const data = await response.json();
return data.access_token;
}

private async ensureToken(): Promise<void> {
if (!this.tokenCache || Date.now() >= this.tokenCache.expiry) {
const token = await this.fetchOAuthToken();
const expiry = Date.now() + 3600 * 1000; // Consider making this configurable
this.tokenCache = { token, expiry };
}
}

protected async _getRequestHeaders(): Promise<Record<string, string>> {
await this.ensureToken();
const header: Record<string, string> = {
"Content-Type": "application/json",
Authorization: `Bearer ${this.tokenCache?.token}`,
"api-key": this.apiKey || "",
"AI-Resource-Group": this.resourceGroup || "default",
};
return header;
}

protected _getChatUrl() {
if (this.apiType === "azure") {
return `${this.apiBase}/chat/completions?api-version=${this.apiVersion}`;
} else {
let url = this.apiBase;
if (!url) {
throw new Error(
"No API base URL provided. Please set the 'apiBase' option in config.json"
);
}
if (url.endsWith("/")) {
url = url.slice(0, -1);
}

if (!url.endsWith("/v1")) {
url += "/v1";
}
return url + "/chat/completions";
}
}

protected _getCompletionUrl() {
if (this.apiType === "azure") {
return `${this.apiBase}/completions?api-version=${this.apiVersion}`;
} else {
let url = this.apiBase;
if (!url) {
throw new Error(
"No API base URL provided. Please set the 'apiBase' option in config.json"
);
}
if (url.endsWith("/")) {
url = url.slice(0, -1);
}
if (!url.endsWith("/v1")) {
url += "/v1";
}
return url + "/completions";
}
}

}

export default SAPGenAIHub;
2 changes: 2 additions & 0 deletions core/llm/llms/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import OpenAI from "./OpenAI";
import Replicate from "./Replicate";
import TextGenWebUI from "./TextGenWebUI";
import Together from "./Together";
import SAPGenAIHub from "./SAPGenAIHub";

function convertToLetter(num: number): string {
let result = "";
Expand Down Expand Up @@ -88,6 +89,7 @@ const LLMs = [
Gemini,
Mistral,
Bedrock,
SAPGenAIHub,
DeepInfra,
];

Expand Down

0 comments on commit dc80fda

Please sign in to comment.