Skip to content

Commit

Permalink
Brace/bump cohere (#3263)
Browse files Browse the repository at this point in the history
* bump cohere to support v3 embeddings

* chore: lint files

* nit

* Remove unused example file, add cohere example

* wut

* chore: lint files
  • Loading branch information
bracesproul authored Nov 14, 2023
1 parent 3983e5c commit 95f8563
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 64 deletions.
10 changes: 3 additions & 7 deletions docs/core_docs/docs/integrations/text_embedding/cohere.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,7 @@ The `CohereEmbeddings` class uses the Cohere API to generate embeddings for a gi
npm install cohere-ai
```

```typescript
import { CohereEmbeddings } from "langchain/embeddings/cohere";
import CodeBlock from "@theme/CodeBlock";
import BasicExample from "@examples/embeddings/cohere.ts";

const embeddings = new CohereEmbeddings({
apiKey: "YOUR-API-KEY", // In Node.js defaults to process.env.COHERE_API_KEY
batchSize: 48, // Default value if omitted is 48. Max value is 96
});
```
<CodeBlock language="typescript">{BasicExample}</CodeBlock>
41 changes: 34 additions & 7 deletions examples/src/embeddings/cohere.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,36 @@
import { CohereEmbeddings } from "langchain/embeddings/cohere";

export const run = async () => {
const model = new CohereEmbeddings();
const res = await model.embedQuery(
"What would be a good company name a company that makes colorful socks?"
);
console.log({ res });
};
const cohere = new CohereEmbeddings({
apiKey: "YOUR-API-KEY", // In Node.js defaults to process.env.COHERE_API_KEY
batchSize: 48, // Default value if omitted is 48. Max value is 96
modelName: "embed-english-v3.0", // Default value if omitted is "small".
inputType: "classification", // Optional parameter unless using a v3 model.
});

const texts = [
"I love Cohere!",
"I hate Cohere!",
"I feel neutral about Cohere.",
];

const embeddings = await cohere.embedDocuments(texts);
console.log(embeddings);
/**
* [
[
-0.007194519, -0.009376526, -0.10015869, -0.06750488,
-0.011001587, -0.034454346, -0.074523926, 0.03756714,
... 943 more items
],
[
-0.009613037, -0.022705078, -0.07318115, -0.02255249, -0.019729614,
0.009689331, -0.024749756, -0.06665039, -0.02128601, -0.010520935,
... 943 more items
],
[
0.009689331, -0.024749756, -0.06665039, -0.02128601, -0.010520935,
-0.007194519, -0.009376526, -0.10015869, -0.06750488, -0.02128601,
... 943 more items
]
]
*/
14 changes: 0 additions & 14 deletions examples/src/models/embeddings/cohere.ts

This file was deleted.

4 changes: 2 additions & 2 deletions langchain/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -902,7 +902,7 @@
"closevector-common": "0.1.0-alpha.1",
"closevector-node": "0.1.0-alpha.10",
"closevector-web": "0.1.0-alpha.15",
"cohere-ai": ">=6.0.0",
"cohere-ai": "^7.2.0",
"convex": "^1.3.1",
"d3-dsv": "^2.0.0",
"dotenv": "^16.0.3",
Expand Down Expand Up @@ -1019,7 +1019,7 @@
"closevector-common": "0.1.0-alpha.1",
"closevector-node": "0.1.0-alpha.10",
"closevector-web": "0.1.0-alpha.16",
"cohere-ai": ">=6.0.0",
"cohere-ai": "^7.2.0",
"convex": "^1.3.1",
"d3-dsv": "^2.0.0",
"epub2": "^3.0.1",
Expand Down
39 changes: 28 additions & 11 deletions langchain/src/embeddings/cohere.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@ export interface CohereEmbeddingsParams extends EmbeddingsParams {
* limited by the Cohere API to a maximum of 96.
*/
batchSize?: number;

/**
* Specifies the type of input you're giving to the model.
* Not required for older versions of the embedding models (i.e. anything lower than v3),
* but is required for more recent versions (i.e. anything bigger than v2).
*
* * `search_document` - Use this when you encode documents for embeddings that you store in a vector database for search use-cases.
* * `search_query` - Use this when you query your vector DB to find relevant documents.
* * `classification` - Use this when you use the embeddings as an input to a text classifier.
* * `clustering` - Use this when you want to cluster the embeddings.
*/
inputType?: string;
}

/**
Expand All @@ -27,9 +39,11 @@ export class CohereEmbeddings

batchSize = 48;

inputType: string | undefined;

private apiKey: string;

private client: typeof import("cohere-ai");
private client: import("cohere-ai").CohereClient;

/**
* Constructor for the CohereEmbeddings class.
Expand All @@ -54,6 +68,7 @@ export class CohereEmbeddings

this.modelName = fieldsWithDefaults?.modelName ?? this.modelName;
this.batchSize = fieldsWithDefaults?.batchSize ?? this.batchSize;
this.inputType = fieldsWithDefaults?.inputType;
this.apiKey = apiKey;
}

Expand All @@ -71,6 +86,7 @@ export class CohereEmbeddings
this.embeddingWithRetry({
model: this.modelName,
texts: batch,
inputType: this.inputType,
})
);

Expand All @@ -80,9 +96,9 @@ export class CohereEmbeddings

for (let i = 0; i < batchResponses.length; i += 1) {
const batch = batches[i];
const { body: batchResponse } = batchResponses[i];
const { embeddings: batchResponse } = batchResponses[i];
for (let j = 0; j < batch.length; j += 1) {
embeddings.push(batchResponse.embeddings[j]);
embeddings.push(batchResponse[j]);
}
}

Expand All @@ -97,11 +113,11 @@ export class CohereEmbeddings
async embedQuery(text: string): Promise<number[]> {
await this.maybeInitClient();

const { body } = await this.embeddingWithRetry({
const { embeddings } = await this.embeddingWithRetry({
model: this.modelName,
texts: [text],
});
return body.embeddings[0];
return embeddings[0];
}

/**
Expand All @@ -122,20 +138,21 @@ export class CohereEmbeddings
*/
private async maybeInitClient() {
if (!this.client) {
const { cohere } = await CohereEmbeddings.imports();
const { CohereClient } = await CohereEmbeddings.imports();

this.client = cohere;
this.client.init(this.apiKey);
this.client = new CohereClient({
token: this.apiKey,
});
}
}

/** @ignore */
static async imports(): Promise<{
cohere: typeof import("cohere-ai");
CohereClient: typeof import("cohere-ai").CohereClient;
}> {
try {
const { default: cohere } = await import("cohere-ai");
return { cohere };
const { CohereClient } = await import("cohere-ai");
return { CohereClient };
} catch (e) {
throw new Error(
"Please install cohere-ai as a dependency with, e.g. `yarn add cohere-ai`"
Expand Down
18 changes: 10 additions & 8 deletions langchain/src/llms/cohere.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,11 @@ export class Cohere extends LLM implements CohereInput {
prompt: string,
options: this["ParsedCallOptions"]
): Promise<string> {
const { cohere } = await Cohere.imports();
const { CohereClient } = await Cohere.imports();

cohere.init(this.apiKey);
const cohere = new CohereClient({
token: this.apiKey,
});

// Hit the `generate` endpoint on the `large` model
const generateResponse = await this.caller.callWithOptions(
Expand All @@ -87,13 +89,13 @@ export class Cohere extends LLM implements CohereInput {
{
prompt,
model: this.model,
max_tokens: this.maxTokens,
maxTokens: this.maxTokens,
temperature: this.temperature,
end_sequences: options.stop,
endSequences: options.stop,
}
);
try {
return generateResponse.body.generations[0].text;
return generateResponse.generations[0].text;
} catch {
console.log(generateResponse);
throw new Error("Could not parse response.");
Expand All @@ -102,11 +104,11 @@ export class Cohere extends LLM implements CohereInput {

/** @ignore */
static async imports(): Promise<{
cohere: typeof import("cohere-ai");
CohereClient: typeof import("cohere-ai").CohereClient;
}> {
try {
const { default: cohere } = await import("cohere-ai");
return { cohere };
const { CohereClient } = await import("cohere-ai");
return { CohereClient };
} catch (e) {
throw new Error(
"Please install cohere-ai as a dependency with, e.g. `yarn add cohere-ai`"
Expand Down
68 changes: 53 additions & 15 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -11580,6 +11580,13 @@ __metadata:
languageName: node
linkType: hard

"@types/qs@npm:6.9.8":
version: 6.9.8
resolution: "@types/qs@npm:6.9.8"
checksum: c28e07d00d07970e5134c6eed184a0189b8a4649e28fdf36d9117fe671c067a44820890de6bdecef18217647a95e9c6aebdaaae69f5fe4b0bec9345db885f77e
languageName: node
linkType: hard

"@types/range-parser@npm:*":
version: 1.2.4
resolution: "@types/range-parser@npm:1.2.4"
Expand Down Expand Up @@ -11804,6 +11811,13 @@ __metadata:
languageName: node
linkType: hard

"@types/url-join@npm:4.0.1":
version: 4.0.1
resolution: "@types/url-join@npm:4.0.1"
checksum: 29444b90e165b3970c8ad3fcef1e2de092e72b67e1f3aaff3016eea1697f98a01a28fb4f71f8044e3338a95bd2262327e9299be416efe5297018658760a331b4
languageName: node
linkType: hard

"@types/uuid@npm:^9":
version: 9.0.1
resolution: "@types/uuid@npm:9.0.1"
Expand Down Expand Up @@ -13265,6 +13279,16 @@ __metadata:
languageName: node
linkType: hard

"axios@npm:0.27.2":
version: 0.27.2
resolution: "axios@npm:0.27.2"
dependencies:
follow-redirects: ^1.14.9
form-data: ^4.0.0
checksum: 38cb7540465fe8c4102850c4368053c21683af85c5fdf0ea619f9628abbcb59415d1e22ebc8a6390d2bbc9b58a9806c874f139767389c862ec9b772235f06854
languageName: node
linkType: hard

"axios@npm:^0.21.1":
version: 0.21.4
resolution: "axios@npm:0.21.4"
Expand Down Expand Up @@ -14680,10 +14704,17 @@ __metadata:
languageName: node
linkType: hard

"cohere-ai@npm:>=6.0.0":
version: 6.2.2
resolution: "cohere-ai@npm:6.2.2"
checksum: 5a7ea2bb6f2a6b83de7ec90612a38f0b215bfc199b5fe749842c6e771a1b183681a3a0e351d75c22eb6d929ea45cf03b1cdf4ed3a907c401186e73c392749310
"cohere-ai@npm:^7.2.0":
version: 7.2.0
resolution: "cohere-ai@npm:7.2.0"
dependencies:
"@types/qs": 6.9.8
"@types/url-join": 4.0.1
axios: 0.27.2
js-base64: 3.7.2
qs: 6.11.2
url-join: 4.0.1
checksum: 4ea91185bacb343d0d7cb22d2c256075ea57c878f3b9398dbf2d5ccc4149a8cec1c8074373147ea8608469fba885439bca967d31dd068df0b3ddba12fd74bc7d
languageName: node
linkType: hard

Expand Down Expand Up @@ -18334,7 +18365,7 @@ __metadata:
languageName: node
linkType: hard

"follow-redirects@npm:^1.0.0, follow-redirects@npm:^1.14.7":
"follow-redirects@npm:^1.0.0, follow-redirects@npm:^1.14.7, follow-redirects@npm:^1.14.9":
version: 1.15.3
resolution: "follow-redirects@npm:1.15.3"
peerDependenciesMeta:
Expand Down Expand Up @@ -21706,6 +21737,13 @@ __metadata:
languageName: node
linkType: hard

"js-base64@npm:3.7.2":
version: 3.7.2
resolution: "js-base64@npm:3.7.2"
checksum: 573f28e9a27c3df60096d4d3f551bcb4fcb6d49161cf83396e9bad9b76f94736a70bb70b8808fe834dff2a388f76604ba09d6e153bbf181646e407720139fa5b
languageName: node
linkType: hard

"js-sdsl@npm:^4.1.4":
version: 4.3.0
resolution: "js-sdsl@npm:4.3.0"
Expand Down Expand Up @@ -22229,7 +22267,7 @@ __metadata:
closevector-common: 0.1.0-alpha.1
closevector-node: 0.1.0-alpha.10
closevector-web: 0.1.0-alpha.15
cohere-ai: ">=6.0.0"
cohere-ai: ^7.2.0
convex: ^1.3.1
d3-dsv: ^2.0.0
decamelize: ^1.2.0
Expand Down Expand Up @@ -22362,7 +22400,7 @@ __metadata:
closevector-common: 0.1.0-alpha.1
closevector-node: 0.1.0-alpha.10
closevector-web: 0.1.0-alpha.16
cohere-ai: ">=6.0.0"
cohere-ai: ^7.2.0
convex: ^1.3.1
d3-dsv: ^2.0.0
epub2: ^3.0.1
Expand Down Expand Up @@ -26856,7 +26894,7 @@ __metadata:
languageName: node
linkType: hard

"qs@npm:^6.7.0":
"qs@npm:6.11.2, qs@npm:^6.7.0":
version: 6.11.2
resolution: "qs@npm:6.11.2"
dependencies:
Expand Down Expand Up @@ -30856,20 +30894,20 @@ __metadata:
languageName: node
linkType: hard

"url-join@npm:4.0.1, url-join@npm:^4.0.1":
version: 4.0.1
resolution: "url-join@npm:4.0.1"
checksum: f74e868bf25dbc8be6a8d7237d4c36bb5b6c62c72e594d5ab1347fe91d6af7ccd9eb5d621e30152e4da45c2e9a26bec21390e911ab54a62d4d82e76028374ee5
languageName: node
linkType: hard

"url-join@npm:5.0.0":
version: 5.0.0
resolution: "url-join@npm:5.0.0"
checksum: 5921384a8ad4395b49ce4b50aa26efbc429cebe0bc8b3660ad693dd12fd859747b5369be0443e60e53a7850b2bc9d7d0687bcb94386662b40e743596bbf38101
languageName: node
linkType: hard

"url-join@npm:^4.0.1":
version: 4.0.1
resolution: "url-join@npm:4.0.1"
checksum: f74e868bf25dbc8be6a8d7237d4c36bb5b6c62c72e594d5ab1347fe91d6af7ccd9eb5d621e30152e4da45c2e9a26bec21390e911ab54a62d4d82e76028374ee5
languageName: node
linkType: hard

"url-loader@npm:^4.1.1":
version: 4.1.1
resolution: "url-loader@npm:4.1.1"
Expand Down

1 comment on commit 95f8563

@vercel
Copy link

@vercel vercel bot commented on 95f8563 Nov 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Successfully deployed to the following URLs:

langchainjs-docs – ./docs/core_docs/

langchainjs-docs-langchain.vercel.app
langchainjs-docs-git-main-langchain.vercel.app
langchainjs-docs-ruddy.vercel.app
js.langchain.com

Please sign in to comment.