Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH]: Added cohere-ai 7.0.0 support in package.json #1460

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion clients/js/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
},
"peerDependencies": {
"@google/generative-ai": "^0.1.1",
"cohere-ai": "^5.0.0 || ^6.0.0",
"cohere-ai": "^5.0.0 || ^6.0.0 || ^7.0.0",
"openai": "^3.0.0 || ^4.0.0"
},
"peerDependenciesMeta": {
Expand Down
157 changes: 109 additions & 48 deletions clients/js/src/embeddings/CohereEmbeddingFunction.ts
Original file line number Diff line number Diff line change
@@ -1,61 +1,122 @@
import { IEmbeddingFunction } from "./IEmbeddingFunction";

let CohereAiApi: any;

export class CohereEmbeddingFunction implements IEmbeddingFunction {
private api_key: string;
private model: string;
private cohereAiApi?: any;

constructor({ cohere_api_key, model }: { cohere_api_key: string, model?: string }) {
// we used to construct the client here, but we need to async import the types
// for the openai npm package, and the constructor can not be async
this.api_key = cohere_api_key;
this.model = model || "large";
}
interface CohereAIAPI {
createEmbedding: (params: {
model: string;
input: string[];
}) => Promise<number[][]>;
}

private async loadClient() {
if(this.cohereAiApi) return;
try {
// eslint-disable-next-line global-require,import/no-extraneous-dependencies
const { cohere } = await CohereEmbeddingFunction.import();
CohereAiApi = cohere;
CohereAiApi.init(this.api_key);
} catch (_a) {
// @ts-ignore
if (_a.code === 'MODULE_NOT_FOUND') {
throw new Error("Please install the cohere-ai package to use the CohereEmbeddingFunction, `npm install -S cohere-ai`");
}
throw _a; // Re-throw other errors
}
this.cohereAiApi = CohereAiApi;
}
class CohereAISDK56 implements CohereAIAPI {
private cohereClient: any;
private apiKey: string;

public async generate(texts: string[]) {
constructor(configuration: { apiKey: string }) {
this.apiKey = configuration.apiKey;
}

await this.loadClient();
private async loadClient() {
if (this.cohereClient) return;
//@ts-ignore
const { default: cohere } = await import("cohere-ai");
// @ts-ignore
cohere.init(this.apiKey);
this.cohereClient = cohere;
}

const response = await this.cohereAiApi.embed({
texts: texts,
model: this.model,
});
public async createEmbedding(params: {
model: string;
input: string[];
}): Promise<number[][]> {
await this.loadClient();
return await this.cohereClient
.embed({
texts: params.input,
model: params.model,
})
.then((response: any) => {
return response.body.embeddings;
}
});
}
}

class CohereAISDK7 implements CohereAIAPI {
private cohereClient: any;
private apiKey: string;

constructor(configuration: { apiKey: string }) {
this.apiKey = configuration.apiKey;
}

private async loadClient() {
if (this.cohereClient) return;
//@ts-ignore
const cohere = await import("cohere-ai").then((cohere) => {
return cohere;
});
// @ts-ignore
this.cohereClient = new cohere.CohereClient({
token: this.apiKey,
});
}

public async createEmbedding(params: {
model: string;
input: string[];
}): Promise<number[][]> {
await this.loadClient();
return await this.cohereClient
.embed({ texts: params.input, model: params.model })
.then((response: any) => {
return response.embeddings;
});
}
}

export class CohereEmbeddingFunction implements IEmbeddingFunction {
private cohereAiApi?: CohereAIAPI;
private model: string;
private apiKey: string;
constructor({
cohere_api_key,
model,
}: {
cohere_api_key: string;
model?: string;
}) {
this.model = model || "large";
this.apiKey = cohere_api_key;
}

/** @ignore */
static async import(): Promise<{
private async initCohereClient() {
if (this.cohereAiApi) return;
try {
// @ts-ignore
this.cohereAiApi = await import("cohere-ai").then((cohere) => {
// @ts-ignore
cohere: typeof import("cohere-ai");
}> {
try {
// @ts-ignore
const { default: cohere } = await import("cohere-ai");
return { cohere };
} catch (e) {
throw new Error(
"Please install cohere-ai as a dependency with, e.g. `yarn add cohere-ai`"
);
if (cohere.CohereClient) {
return new CohereAISDK7({ apiKey: this.apiKey });
} else {
return new CohereAISDK56({ apiKey: this.apiKey });
}
});
} catch (e) {
// @ts-ignore
if (e.code === "MODULE_NOT_FOUND") {
throw new Error(
"Please install the cohere-ai package to use the CohereEmbeddingFunction, `npm install -S cohere-ai`"
);
}
throw e;
}
}

public async generate(texts: string[]): Promise<number[][]> {
await this.initCohereClient();
// @ts-ignore
return await this.cohereAiApi.createEmbedding({
model: this.model,
input: texts,
});
}
}