diff --git a/docs/core_docs/docs/integrations/text_embedding/cohere.mdx b/docs/core_docs/docs/integrations/text_embedding/cohere.mdx index 4a11080a64c..38727886a31 100644 --- a/docs/core_docs/docs/integrations/text_embedding/cohere.mdx +++ b/docs/core_docs/docs/integrations/text_embedding/cohere.mdx @@ -6,11 +6,7 @@ The `CohereEmbeddings` class uses the Cohere API to generate embeddings for a gi npm install cohere-ai ``` -```typescript -import { CohereEmbeddings } from "langchain/embeddings/cohere"; +import CodeBlock from "@theme/CodeBlock"; +import BasicExample from "@examples/embeddings/cohere.ts"; -const embeddings = new CohereEmbeddings({ - apiKey: "YOUR-API-KEY", // In Node.js defaults to process.env.COHERE_API_KEY - batchSize: 48, // Default value if omitted is 48. Max value is 96 -}); -``` +{BasicExample} diff --git a/examples/src/embeddings/cohere.ts b/examples/src/embeddings/cohere.ts index 3499703a8e9..d99f4620938 100644 --- a/examples/src/embeddings/cohere.ts +++ b/examples/src/embeddings/cohere.ts @@ -1,9 +1,36 @@ import { CohereEmbeddings } from "langchain/embeddings/cohere"; -export const run = async () => { - const model = new CohereEmbeddings(); - const res = await model.embedQuery( - "What would be a good company name a company that makes colorful socks?" - ); - console.log({ res }); -}; +const cohere = new CohereEmbeddings({ + apiKey: "YOUR-API-KEY", // In Node.js defaults to process.env.COHERE_API_KEY + batchSize: 48, // Default value if omitted is 48. Max value is 96 + modelName: "embed-english-v3.0", // Default value if omitted is "small". + inputType: "classification", // Optional parameter unless using a v3 model. +}); + +const texts = [ + "I love Cohere!", + "I hate Cohere!", + "I feel neutral about Cohere.", +]; + +const embeddings = await cohere.embedDocuments(texts); +console.log(embeddings); +/** + * [ + [ + -0.007194519, -0.009376526, -0.10015869, -0.06750488, + -0.011001587, -0.034454346, -0.074523926, 0.03756714, + ... 943 more items + ], + [ + -0.009613037, -0.022705078, -0.07318115, -0.02255249, -0.019729614, + 0.009689331, -0.024749756, -0.06665039, -0.02128601, -0.010520935, + ... 943 more items + ], + [ + 0.009689331, -0.024749756, -0.06665039, -0.02128601, -0.010520935, + -0.007194519, -0.009376526, -0.10015869, -0.06750488, -0.02128601, + ... 943 more items + ] +] + */ diff --git a/examples/src/models/embeddings/cohere.ts b/examples/src/models/embeddings/cohere.ts deleted file mode 100644 index 8bb558d50be..00000000000 --- a/examples/src/models/embeddings/cohere.ts +++ /dev/null @@ -1,14 +0,0 @@ -import { CohereEmbeddings } from "langchain/embeddings/cohere"; - -export const run = async () => { - /* Embed queries */ - const embeddings = new CohereEmbeddings(); - const res = await embeddings.embedQuery("Hello world"); - console.log(res); - /* Embed documents */ - const documentRes = await embeddings.embedDocuments([ - "Hello world", - "Bye bye", - ]); - console.log({ documentRes }); -}; diff --git a/langchain/package.json b/langchain/package.json index 0905ae50f29..171f6aa9442 100644 --- a/langchain/package.json +++ b/langchain/package.json @@ -902,7 +902,7 @@ "closevector-common": "0.1.0-alpha.1", "closevector-node": "0.1.0-alpha.10", "closevector-web": "0.1.0-alpha.15", - "cohere-ai": ">=6.0.0", + "cohere-ai": "^7.2.0", "convex": "^1.3.1", "d3-dsv": "^2.0.0", "dotenv": "^16.0.3", @@ -1019,7 +1019,7 @@ "closevector-common": "0.1.0-alpha.1", "closevector-node": "0.1.0-alpha.10", "closevector-web": "0.1.0-alpha.16", - "cohere-ai": ">=6.0.0", + "cohere-ai": "^7.2.0", "convex": "^1.3.1", "d3-dsv": "^2.0.0", "epub2": "^3.0.1", diff --git a/langchain/src/embeddings/cohere.ts b/langchain/src/embeddings/cohere.ts index 3c547e769f6..1382e158468 100644 --- a/langchain/src/embeddings/cohere.ts +++ b/langchain/src/embeddings/cohere.ts @@ -14,6 +14,18 @@ export interface CohereEmbeddingsParams extends EmbeddingsParams { * limited by the Cohere API to a maximum of 96. */ batchSize?: number; + + /** + * Specifies the type of input you're giving to the model. + * Not required for older versions of the embedding models (i.e. anything lower than v3), + * but is required for more recent versions (i.e. anything bigger than v2). + * + * * `search_document` - Use this when you encode documents for embeddings that you store in a vector database for search use-cases. + * * `search_query` - Use this when you query your vector DB to find relevant documents. + * * `classification` - Use this when you use the embeddings as an input to a text classifier. + * * `clustering` - Use this when you want to cluster the embeddings. + */ + inputType?: string; } /** @@ -27,9 +39,11 @@ export class CohereEmbeddings batchSize = 48; + inputType: string | undefined; + private apiKey: string; - private client: typeof import("cohere-ai"); + private client: import("cohere-ai").CohereClient; /** * Constructor for the CohereEmbeddings class. @@ -54,6 +68,7 @@ export class CohereEmbeddings this.modelName = fieldsWithDefaults?.modelName ?? this.modelName; this.batchSize = fieldsWithDefaults?.batchSize ?? this.batchSize; + this.inputType = fieldsWithDefaults?.inputType; this.apiKey = apiKey; } @@ -71,6 +86,7 @@ export class CohereEmbeddings this.embeddingWithRetry({ model: this.modelName, texts: batch, + inputType: this.inputType, }) ); @@ -80,9 +96,9 @@ export class CohereEmbeddings for (let i = 0; i < batchResponses.length; i += 1) { const batch = batches[i]; - const { body: batchResponse } = batchResponses[i]; + const { embeddings: batchResponse } = batchResponses[i]; for (let j = 0; j < batch.length; j += 1) { - embeddings.push(batchResponse.embeddings[j]); + embeddings.push(batchResponse[j]); } } @@ -97,11 +113,11 @@ export class CohereEmbeddings async embedQuery(text: string): Promise { await this.maybeInitClient(); - const { body } = await this.embeddingWithRetry({ + const { embeddings } = await this.embeddingWithRetry({ model: this.modelName, texts: [text], }); - return body.embeddings[0]; + return embeddings[0]; } /** @@ -122,20 +138,21 @@ export class CohereEmbeddings */ private async maybeInitClient() { if (!this.client) { - const { cohere } = await CohereEmbeddings.imports(); + const { CohereClient } = await CohereEmbeddings.imports(); - this.client = cohere; - this.client.init(this.apiKey); + this.client = new CohereClient({ + token: this.apiKey, + }); } } /** @ignore */ static async imports(): Promise<{ - cohere: typeof import("cohere-ai"); + CohereClient: typeof import("cohere-ai").CohereClient; }> { try { - const { default: cohere } = await import("cohere-ai"); - return { cohere }; + const { CohereClient } = await import("cohere-ai"); + return { CohereClient }; } catch (e) { throw new Error( "Please install cohere-ai as a dependency with, e.g. `yarn add cohere-ai`" diff --git a/langchain/src/llms/cohere.ts b/langchain/src/llms/cohere.ts index 32a1e1f746b..c33c71ad563 100644 --- a/langchain/src/llms/cohere.ts +++ b/langchain/src/llms/cohere.ts @@ -76,9 +76,11 @@ export class Cohere extends LLM implements CohereInput { prompt: string, options: this["ParsedCallOptions"] ): Promise { - const { cohere } = await Cohere.imports(); + const { CohereClient } = await Cohere.imports(); - cohere.init(this.apiKey); + const cohere = new CohereClient({ + token: this.apiKey, + }); // Hit the `generate` endpoint on the `large` model const generateResponse = await this.caller.callWithOptions( @@ -87,13 +89,13 @@ export class Cohere extends LLM implements CohereInput { { prompt, model: this.model, - max_tokens: this.maxTokens, + maxTokens: this.maxTokens, temperature: this.temperature, - end_sequences: options.stop, + endSequences: options.stop, } ); try { - return generateResponse.body.generations[0].text; + return generateResponse.generations[0].text; } catch { console.log(generateResponse); throw new Error("Could not parse response."); @@ -102,11 +104,11 @@ export class Cohere extends LLM implements CohereInput { /** @ignore */ static async imports(): Promise<{ - cohere: typeof import("cohere-ai"); + CohereClient: typeof import("cohere-ai").CohereClient; }> { try { - const { default: cohere } = await import("cohere-ai"); - return { cohere }; + const { CohereClient } = await import("cohere-ai"); + return { CohereClient }; } catch (e) { throw new Error( "Please install cohere-ai as a dependency with, e.g. `yarn add cohere-ai`" diff --git a/yarn.lock b/yarn.lock index 396b4a9d8ab..3d395fadd09 100644 --- a/yarn.lock +++ b/yarn.lock @@ -11580,6 +11580,13 @@ __metadata: languageName: node linkType: hard +"@types/qs@npm:6.9.8": + version: 6.9.8 + resolution: "@types/qs@npm:6.9.8" + checksum: c28e07d00d07970e5134c6eed184a0189b8a4649e28fdf36d9117fe671c067a44820890de6bdecef18217647a95e9c6aebdaaae69f5fe4b0bec9345db885f77e + languageName: node + linkType: hard + "@types/range-parser@npm:*": version: 1.2.4 resolution: "@types/range-parser@npm:1.2.4" @@ -11804,6 +11811,13 @@ __metadata: languageName: node linkType: hard +"@types/url-join@npm:4.0.1": + version: 4.0.1 + resolution: "@types/url-join@npm:4.0.1" + checksum: 29444b90e165b3970c8ad3fcef1e2de092e72b67e1f3aaff3016eea1697f98a01a28fb4f71f8044e3338a95bd2262327e9299be416efe5297018658760a331b4 + languageName: node + linkType: hard + "@types/uuid@npm:^9": version: 9.0.1 resolution: "@types/uuid@npm:9.0.1" @@ -13265,6 +13279,16 @@ __metadata: languageName: node linkType: hard +"axios@npm:0.27.2": + version: 0.27.2 + resolution: "axios@npm:0.27.2" + dependencies: + follow-redirects: ^1.14.9 + form-data: ^4.0.0 + checksum: 38cb7540465fe8c4102850c4368053c21683af85c5fdf0ea619f9628abbcb59415d1e22ebc8a6390d2bbc9b58a9806c874f139767389c862ec9b772235f06854 + languageName: node + linkType: hard + "axios@npm:^0.21.1": version: 0.21.4 resolution: "axios@npm:0.21.4" @@ -14680,10 +14704,17 @@ __metadata: languageName: node linkType: hard -"cohere-ai@npm:>=6.0.0": - version: 6.2.2 - resolution: "cohere-ai@npm:6.2.2" - checksum: 5a7ea2bb6f2a6b83de7ec90612a38f0b215bfc199b5fe749842c6e771a1b183681a3a0e351d75c22eb6d929ea45cf03b1cdf4ed3a907c401186e73c392749310 +"cohere-ai@npm:^7.2.0": + version: 7.2.0 + resolution: "cohere-ai@npm:7.2.0" + dependencies: + "@types/qs": 6.9.8 + "@types/url-join": 4.0.1 + axios: 0.27.2 + js-base64: 3.7.2 + qs: 6.11.2 + url-join: 4.0.1 + checksum: 4ea91185bacb343d0d7cb22d2c256075ea57c878f3b9398dbf2d5ccc4149a8cec1c8074373147ea8608469fba885439bca967d31dd068df0b3ddba12fd74bc7d languageName: node linkType: hard @@ -18334,7 +18365,7 @@ __metadata: languageName: node linkType: hard -"follow-redirects@npm:^1.0.0, follow-redirects@npm:^1.14.7": +"follow-redirects@npm:^1.0.0, follow-redirects@npm:^1.14.7, follow-redirects@npm:^1.14.9": version: 1.15.3 resolution: "follow-redirects@npm:1.15.3" peerDependenciesMeta: @@ -21706,6 +21737,13 @@ __metadata: languageName: node linkType: hard +"js-base64@npm:3.7.2": + version: 3.7.2 + resolution: "js-base64@npm:3.7.2" + checksum: 573f28e9a27c3df60096d4d3f551bcb4fcb6d49161cf83396e9bad9b76f94736a70bb70b8808fe834dff2a388f76604ba09d6e153bbf181646e407720139fa5b + languageName: node + linkType: hard + "js-sdsl@npm:^4.1.4": version: 4.3.0 resolution: "js-sdsl@npm:4.3.0" @@ -22229,7 +22267,7 @@ __metadata: closevector-common: 0.1.0-alpha.1 closevector-node: 0.1.0-alpha.10 closevector-web: 0.1.0-alpha.15 - cohere-ai: ">=6.0.0" + cohere-ai: ^7.2.0 convex: ^1.3.1 d3-dsv: ^2.0.0 decamelize: ^1.2.0 @@ -22362,7 +22400,7 @@ __metadata: closevector-common: 0.1.0-alpha.1 closevector-node: 0.1.0-alpha.10 closevector-web: 0.1.0-alpha.16 - cohere-ai: ">=6.0.0" + cohere-ai: ^7.2.0 convex: ^1.3.1 d3-dsv: ^2.0.0 epub2: ^3.0.1 @@ -26856,7 +26894,7 @@ __metadata: languageName: node linkType: hard -"qs@npm:^6.7.0": +"qs@npm:6.11.2, qs@npm:^6.7.0": version: 6.11.2 resolution: "qs@npm:6.11.2" dependencies: @@ -30856,6 +30894,13 @@ __metadata: languageName: node linkType: hard +"url-join@npm:4.0.1, url-join@npm:^4.0.1": + version: 4.0.1 + resolution: "url-join@npm:4.0.1" + checksum: f74e868bf25dbc8be6a8d7237d4c36bb5b6c62c72e594d5ab1347fe91d6af7ccd9eb5d621e30152e4da45c2e9a26bec21390e911ab54a62d4d82e76028374ee5 + languageName: node + linkType: hard + "url-join@npm:5.0.0": version: 5.0.0 resolution: "url-join@npm:5.0.0" @@ -30863,13 +30908,6 @@ __metadata: languageName: node linkType: hard -"url-join@npm:^4.0.1": - version: 4.0.1 - resolution: "url-join@npm:4.0.1" - checksum: f74e868bf25dbc8be6a8d7237d4c36bb5b6c62c72e594d5ab1347fe91d6af7ccd9eb5d621e30152e4da45c2e9a26bec21390e911ab54a62d4d82e76028374ee5 - languageName: node - linkType: hard - "url-loader@npm:^4.1.1": version: 4.1.1 resolution: "url-loader@npm:4.1.1"