diff --git a/docs/core_docs/docs/integrations/text_embedding/cohere.mdx b/docs/core_docs/docs/integrations/text_embedding/cohere.mdx
index 4a11080a64c..38727886a31 100644
--- a/docs/core_docs/docs/integrations/text_embedding/cohere.mdx
+++ b/docs/core_docs/docs/integrations/text_embedding/cohere.mdx
@@ -6,11 +6,7 @@ The `CohereEmbeddings` class uses the Cohere API to generate embeddings for a gi
npm install cohere-ai
```
-```typescript
-import { CohereEmbeddings } from "langchain/embeddings/cohere";
+import CodeBlock from "@theme/CodeBlock";
+import BasicExample from "@examples/embeddings/cohere.ts";
-const embeddings = new CohereEmbeddings({
- apiKey: "YOUR-API-KEY", // In Node.js defaults to process.env.COHERE_API_KEY
- batchSize: 48, // Default value if omitted is 48. Max value is 96
-});
-```
+{BasicExample}
diff --git a/examples/src/embeddings/cohere.ts b/examples/src/embeddings/cohere.ts
index 3499703a8e9..d99f4620938 100644
--- a/examples/src/embeddings/cohere.ts
+++ b/examples/src/embeddings/cohere.ts
@@ -1,9 +1,36 @@
import { CohereEmbeddings } from "langchain/embeddings/cohere";
-export const run = async () => {
- const model = new CohereEmbeddings();
- const res = await model.embedQuery(
- "What would be a good company name a company that makes colorful socks?"
- );
- console.log({ res });
-};
+const cohere = new CohereEmbeddings({
+ apiKey: "YOUR-API-KEY", // In Node.js defaults to process.env.COHERE_API_KEY
+ batchSize: 48, // Default value if omitted is 48. Max value is 96
+ modelName: "embed-english-v3.0", // Default value if omitted is "small".
+ inputType: "classification", // Optional parameter unless using a v3 model.
+});
+
+const texts = [
+ "I love Cohere!",
+ "I hate Cohere!",
+ "I feel neutral about Cohere.",
+];
+
+const embeddings = await cohere.embedDocuments(texts);
+console.log(embeddings);
+/**
+ * [
+ [
+ -0.007194519, -0.009376526, -0.10015869, -0.06750488,
+ -0.011001587, -0.034454346, -0.074523926, 0.03756714,
+ ... 943 more items
+ ],
+ [
+ -0.009613037, -0.022705078, -0.07318115, -0.02255249, -0.019729614,
+ 0.009689331, -0.024749756, -0.06665039, -0.02128601, -0.010520935,
+ ... 943 more items
+ ],
+ [
+ 0.009689331, -0.024749756, -0.06665039, -0.02128601, -0.010520935,
+ -0.007194519, -0.009376526, -0.10015869, -0.06750488, -0.02128601,
+ ... 943 more items
+ ]
+]
+ */
diff --git a/examples/src/models/embeddings/cohere.ts b/examples/src/models/embeddings/cohere.ts
deleted file mode 100644
index 8bb558d50be..00000000000
--- a/examples/src/models/embeddings/cohere.ts
+++ /dev/null
@@ -1,14 +0,0 @@
-import { CohereEmbeddings } from "langchain/embeddings/cohere";
-
-export const run = async () => {
- /* Embed queries */
- const embeddings = new CohereEmbeddings();
- const res = await embeddings.embedQuery("Hello world");
- console.log(res);
- /* Embed documents */
- const documentRes = await embeddings.embedDocuments([
- "Hello world",
- "Bye bye",
- ]);
- console.log({ documentRes });
-};
diff --git a/langchain/package.json b/langchain/package.json
index 0905ae50f29..171f6aa9442 100644
--- a/langchain/package.json
+++ b/langchain/package.json
@@ -902,7 +902,7 @@
"closevector-common": "0.1.0-alpha.1",
"closevector-node": "0.1.0-alpha.10",
"closevector-web": "0.1.0-alpha.15",
- "cohere-ai": ">=6.0.0",
+ "cohere-ai": "^7.2.0",
"convex": "^1.3.1",
"d3-dsv": "^2.0.0",
"dotenv": "^16.0.3",
@@ -1019,7 +1019,7 @@
"closevector-common": "0.1.0-alpha.1",
"closevector-node": "0.1.0-alpha.10",
"closevector-web": "0.1.0-alpha.16",
- "cohere-ai": ">=6.0.0",
+ "cohere-ai": "^7.2.0",
"convex": "^1.3.1",
"d3-dsv": "^2.0.0",
"epub2": "^3.0.1",
diff --git a/langchain/src/embeddings/cohere.ts b/langchain/src/embeddings/cohere.ts
index 3c547e769f6..1382e158468 100644
--- a/langchain/src/embeddings/cohere.ts
+++ b/langchain/src/embeddings/cohere.ts
@@ -14,6 +14,18 @@ export interface CohereEmbeddingsParams extends EmbeddingsParams {
* limited by the Cohere API to a maximum of 96.
*/
batchSize?: number;
+
+ /**
+ * Specifies the type of input you're giving to the model.
+ * Not required for older versions of the embedding models (i.e. anything lower than v3),
+ * but is required for more recent versions (i.e. anything bigger than v2).
+ *
+ * * `search_document` - Use this when you encode documents for embeddings that you store in a vector database for search use-cases.
+ * * `search_query` - Use this when you query your vector DB to find relevant documents.
+ * * `classification` - Use this when you use the embeddings as an input to a text classifier.
+ * * `clustering` - Use this when you want to cluster the embeddings.
+ */
+ inputType?: string;
}
/**
@@ -27,9 +39,11 @@ export class CohereEmbeddings
batchSize = 48;
+ inputType: string | undefined;
+
private apiKey: string;
- private client: typeof import("cohere-ai");
+ private client: import("cohere-ai").CohereClient;
/**
* Constructor for the CohereEmbeddings class.
@@ -54,6 +68,7 @@ export class CohereEmbeddings
this.modelName = fieldsWithDefaults?.modelName ?? this.modelName;
this.batchSize = fieldsWithDefaults?.batchSize ?? this.batchSize;
+ this.inputType = fieldsWithDefaults?.inputType;
this.apiKey = apiKey;
}
@@ -71,6 +86,7 @@ export class CohereEmbeddings
this.embeddingWithRetry({
model: this.modelName,
texts: batch,
+ inputType: this.inputType,
})
);
@@ -80,9 +96,9 @@ export class CohereEmbeddings
for (let i = 0; i < batchResponses.length; i += 1) {
const batch = batches[i];
- const { body: batchResponse } = batchResponses[i];
+ const { embeddings: batchResponse } = batchResponses[i];
for (let j = 0; j < batch.length; j += 1) {
- embeddings.push(batchResponse.embeddings[j]);
+ embeddings.push(batchResponse[j]);
}
}
@@ -97,11 +113,11 @@ export class CohereEmbeddings
async embedQuery(text: string): Promise {
await this.maybeInitClient();
- const { body } = await this.embeddingWithRetry({
+ const { embeddings } = await this.embeddingWithRetry({
model: this.modelName,
texts: [text],
});
- return body.embeddings[0];
+ return embeddings[0];
}
/**
@@ -122,20 +138,21 @@ export class CohereEmbeddings
*/
private async maybeInitClient() {
if (!this.client) {
- const { cohere } = await CohereEmbeddings.imports();
+ const { CohereClient } = await CohereEmbeddings.imports();
- this.client = cohere;
- this.client.init(this.apiKey);
+ this.client = new CohereClient({
+ token: this.apiKey,
+ });
}
}
/** @ignore */
static async imports(): Promise<{
- cohere: typeof import("cohere-ai");
+ CohereClient: typeof import("cohere-ai").CohereClient;
}> {
try {
- const { default: cohere } = await import("cohere-ai");
- return { cohere };
+ const { CohereClient } = await import("cohere-ai");
+ return { CohereClient };
} catch (e) {
throw new Error(
"Please install cohere-ai as a dependency with, e.g. `yarn add cohere-ai`"
diff --git a/langchain/src/llms/cohere.ts b/langchain/src/llms/cohere.ts
index 32a1e1f746b..c33c71ad563 100644
--- a/langchain/src/llms/cohere.ts
+++ b/langchain/src/llms/cohere.ts
@@ -76,9 +76,11 @@ export class Cohere extends LLM implements CohereInput {
prompt: string,
options: this["ParsedCallOptions"]
): Promise {
- const { cohere } = await Cohere.imports();
+ const { CohereClient } = await Cohere.imports();
- cohere.init(this.apiKey);
+ const cohere = new CohereClient({
+ token: this.apiKey,
+ });
// Hit the `generate` endpoint on the `large` model
const generateResponse = await this.caller.callWithOptions(
@@ -87,13 +89,13 @@ export class Cohere extends LLM implements CohereInput {
{
prompt,
model: this.model,
- max_tokens: this.maxTokens,
+ maxTokens: this.maxTokens,
temperature: this.temperature,
- end_sequences: options.stop,
+ endSequences: options.stop,
}
);
try {
- return generateResponse.body.generations[0].text;
+ return generateResponse.generations[0].text;
} catch {
console.log(generateResponse);
throw new Error("Could not parse response.");
@@ -102,11 +104,11 @@ export class Cohere extends LLM implements CohereInput {
/** @ignore */
static async imports(): Promise<{
- cohere: typeof import("cohere-ai");
+ CohereClient: typeof import("cohere-ai").CohereClient;
}> {
try {
- const { default: cohere } = await import("cohere-ai");
- return { cohere };
+ const { CohereClient } = await import("cohere-ai");
+ return { CohereClient };
} catch (e) {
throw new Error(
"Please install cohere-ai as a dependency with, e.g. `yarn add cohere-ai`"
diff --git a/yarn.lock b/yarn.lock
index 396b4a9d8ab..3d395fadd09 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -11580,6 +11580,13 @@ __metadata:
languageName: node
linkType: hard
+"@types/qs@npm:6.9.8":
+ version: 6.9.8
+ resolution: "@types/qs@npm:6.9.8"
+ checksum: c28e07d00d07970e5134c6eed184a0189b8a4649e28fdf36d9117fe671c067a44820890de6bdecef18217647a95e9c6aebdaaae69f5fe4b0bec9345db885f77e
+ languageName: node
+ linkType: hard
+
"@types/range-parser@npm:*":
version: 1.2.4
resolution: "@types/range-parser@npm:1.2.4"
@@ -11804,6 +11811,13 @@ __metadata:
languageName: node
linkType: hard
+"@types/url-join@npm:4.0.1":
+ version: 4.0.1
+ resolution: "@types/url-join@npm:4.0.1"
+ checksum: 29444b90e165b3970c8ad3fcef1e2de092e72b67e1f3aaff3016eea1697f98a01a28fb4f71f8044e3338a95bd2262327e9299be416efe5297018658760a331b4
+ languageName: node
+ linkType: hard
+
"@types/uuid@npm:^9":
version: 9.0.1
resolution: "@types/uuid@npm:9.0.1"
@@ -13265,6 +13279,16 @@ __metadata:
languageName: node
linkType: hard
+"axios@npm:0.27.2":
+ version: 0.27.2
+ resolution: "axios@npm:0.27.2"
+ dependencies:
+ follow-redirects: ^1.14.9
+ form-data: ^4.0.0
+ checksum: 38cb7540465fe8c4102850c4368053c21683af85c5fdf0ea619f9628abbcb59415d1e22ebc8a6390d2bbc9b58a9806c874f139767389c862ec9b772235f06854
+ languageName: node
+ linkType: hard
+
"axios@npm:^0.21.1":
version: 0.21.4
resolution: "axios@npm:0.21.4"
@@ -14680,10 +14704,17 @@ __metadata:
languageName: node
linkType: hard
-"cohere-ai@npm:>=6.0.0":
- version: 6.2.2
- resolution: "cohere-ai@npm:6.2.2"
- checksum: 5a7ea2bb6f2a6b83de7ec90612a38f0b215bfc199b5fe749842c6e771a1b183681a3a0e351d75c22eb6d929ea45cf03b1cdf4ed3a907c401186e73c392749310
+"cohere-ai@npm:^7.2.0":
+ version: 7.2.0
+ resolution: "cohere-ai@npm:7.2.0"
+ dependencies:
+ "@types/qs": 6.9.8
+ "@types/url-join": 4.0.1
+ axios: 0.27.2
+ js-base64: 3.7.2
+ qs: 6.11.2
+ url-join: 4.0.1
+ checksum: 4ea91185bacb343d0d7cb22d2c256075ea57c878f3b9398dbf2d5ccc4149a8cec1c8074373147ea8608469fba885439bca967d31dd068df0b3ddba12fd74bc7d
languageName: node
linkType: hard
@@ -18334,7 +18365,7 @@ __metadata:
languageName: node
linkType: hard
-"follow-redirects@npm:^1.0.0, follow-redirects@npm:^1.14.7":
+"follow-redirects@npm:^1.0.0, follow-redirects@npm:^1.14.7, follow-redirects@npm:^1.14.9":
version: 1.15.3
resolution: "follow-redirects@npm:1.15.3"
peerDependenciesMeta:
@@ -21706,6 +21737,13 @@ __metadata:
languageName: node
linkType: hard
+"js-base64@npm:3.7.2":
+ version: 3.7.2
+ resolution: "js-base64@npm:3.7.2"
+ checksum: 573f28e9a27c3df60096d4d3f551bcb4fcb6d49161cf83396e9bad9b76f94736a70bb70b8808fe834dff2a388f76604ba09d6e153bbf181646e407720139fa5b
+ languageName: node
+ linkType: hard
+
"js-sdsl@npm:^4.1.4":
version: 4.3.0
resolution: "js-sdsl@npm:4.3.0"
@@ -22229,7 +22267,7 @@ __metadata:
closevector-common: 0.1.0-alpha.1
closevector-node: 0.1.0-alpha.10
closevector-web: 0.1.0-alpha.15
- cohere-ai: ">=6.0.0"
+ cohere-ai: ^7.2.0
convex: ^1.3.1
d3-dsv: ^2.0.0
decamelize: ^1.2.0
@@ -22362,7 +22400,7 @@ __metadata:
closevector-common: 0.1.0-alpha.1
closevector-node: 0.1.0-alpha.10
closevector-web: 0.1.0-alpha.16
- cohere-ai: ">=6.0.0"
+ cohere-ai: ^7.2.0
convex: ^1.3.1
d3-dsv: ^2.0.0
epub2: ^3.0.1
@@ -26856,7 +26894,7 @@ __metadata:
languageName: node
linkType: hard
-"qs@npm:^6.7.0":
+"qs@npm:6.11.2, qs@npm:^6.7.0":
version: 6.11.2
resolution: "qs@npm:6.11.2"
dependencies:
@@ -30856,6 +30894,13 @@ __metadata:
languageName: node
linkType: hard
+"url-join@npm:4.0.1, url-join@npm:^4.0.1":
+ version: 4.0.1
+ resolution: "url-join@npm:4.0.1"
+ checksum: f74e868bf25dbc8be6a8d7237d4c36bb5b6c62c72e594d5ab1347fe91d6af7ccd9eb5d621e30152e4da45c2e9a26bec21390e911ab54a62d4d82e76028374ee5
+ languageName: node
+ linkType: hard
+
"url-join@npm:5.0.0":
version: 5.0.0
resolution: "url-join@npm:5.0.0"
@@ -30863,13 +30908,6 @@ __metadata:
languageName: node
linkType: hard
-"url-join@npm:^4.0.1":
- version: 4.0.1
- resolution: "url-join@npm:4.0.1"
- checksum: f74e868bf25dbc8be6a8d7237d4c36bb5b6c62c72e594d5ab1347fe91d6af7ccd9eb5d621e30152e4da45c2e9a26bec21390e911ab54a62d4d82e76028374ee5
- languageName: node
- linkType: hard
-
"url-loader@npm:^4.1.1":
version: 4.1.1
resolution: "url-loader@npm:4.1.1"