Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Brace/bump cohere #3263

Merged
merged 6 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions docs/core_docs/docs/integrations/text_embedding/cohere.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,7 @@ The `CohereEmbeddings` class uses the Cohere API to generate embeddings for a gi
npm install cohere-ai
```

```typescript
import { CohereEmbeddings } from "langchain/embeddings/cohere";
import CodeBlock from "@theme/CodeBlock";
import BasicExample from "@examples/embeddings/cohere.ts";

const embeddings = new CohereEmbeddings({
apiKey: "YOUR-API-KEY", // In Node.js defaults to process.env.COHERE_API_KEY
batchSize: 48, // Default value if omitted is 48. Max value is 96
});
```
<CodeBlock language="typescript">{BasicExample}</CodeBlock>
41 changes: 34 additions & 7 deletions examples/src/embeddings/cohere.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,36 @@
import { CohereEmbeddings } from "langchain/embeddings/cohere";
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR adds code that accesses an environment variable via process.env to retrieve the API key for the CohereEmbeddings instance. Please review this change to ensure the environment variable is properly set and handled.


export const run = async () => {
const model = new CohereEmbeddings();
const res = await model.embedQuery(
"What would be a good company name a company that makes colorful socks?"
);
console.log({ res });
};
const cohere = new CohereEmbeddings({
apiKey: "YOUR-API-KEY", // In Node.js defaults to process.env.COHERE_API_KEY
batchSize: 48, // Default value if omitted is 48. Max value is 96
modelName: "embed-english-v3.0", // Default value if omitted is "small".
inputType: "classification", // Optional parameter unless using a v3 model.
});

const texts = [
"I love Cohere!",
"I hate Cohere!",
"I feel neutral about Cohere.",
];

const embeddings = await cohere.embedDocuments(texts);
console.log(embeddings);
/**
* [
[
-0.007194519, -0.009376526, -0.10015869, -0.06750488,
-0.011001587, -0.034454346, -0.074523926, 0.03756714,
... 943 more items
],
[
-0.009613037, -0.022705078, -0.07318115, -0.02255249, -0.019729614,
0.009689331, -0.024749756, -0.06665039, -0.02128601, -0.010520935,
... 943 more items
],
[
0.009689331, -0.024749756, -0.06665039, -0.02128601, -0.010520935,
-0.007194519, -0.009376526, -0.10015869, -0.06750488, -0.02128601,
... 943 more items
]
]
*/
14 changes: 0 additions & 14 deletions examples/src/models/embeddings/cohere.ts

This file was deleted.

4 changes: 2 additions & 2 deletions langchain/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -902,7 +902,7 @@
"closevector-common": "0.1.0-alpha.1",
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work on the PR! I've noticed a change in the dev dependency "cohere-ai" from ">=6.0.0" to "^7.2.0". This comment is to flag the change for maintainers to review.

"closevector-node": "0.1.0-alpha.10",
"closevector-web": "0.1.0-alpha.15",
"cohere-ai": ">=6.0.0",
"cohere-ai": "^7.2.0",
"convex": "^1.3.1",
"d3-dsv": "^2.0.0",
"dotenv": "^16.0.3",
Expand Down Expand Up @@ -1019,7 +1019,7 @@
"closevector-common": "0.1.0-alpha.1",
"closevector-node": "0.1.0-alpha.10",
"closevector-web": "0.1.0-alpha.16",
"cohere-ai": ">=6.0.0",
"cohere-ai": "^7.2.0",
"convex": "^1.3.1",
"d3-dsv": "^2.0.0",
"epub2": "^3.0.1",
Expand Down
39 changes: 28 additions & 11 deletions langchain/src/embeddings/cohere.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@ export interface CohereEmbeddingsParams extends EmbeddingsParams {
* limited by the Cohere API to a maximum of 96.
*/
batchSize?: number;

/**
* Specifies the type of input you're giving to the model.
* Not required for older versions of the embedding models (i.e. anything lower than v3),
* but is required for more recent versions (i.e. anything bigger than v2).
*
* * `search_document` - Use this when you encode documents for embeddings that you store in a vector database for search use-cases.
* * `search_query` - Use this when you query your vector DB to find relevant documents.
* * `classification` - Use this when you use the embeddings as an input to a text classifier.
* * `clustering` - Use this when you want to cluster the embeddings.
*/
inputType?: string;
}

/**
Expand All @@ -27,9 +39,11 @@ export class CohereEmbeddings

batchSize = 48;

inputType: string | undefined;

private apiKey: string;

private client: typeof import("cohere-ai");
private client: import("cohere-ai").CohereClient;

/**
* Constructor for the CohereEmbeddings class.
Expand All @@ -54,6 +68,7 @@ export class CohereEmbeddings

this.modelName = fieldsWithDefaults?.modelName ?? this.modelName;
this.batchSize = fieldsWithDefaults?.batchSize ?? this.batchSize;
this.inputType = fieldsWithDefaults?.inputType;
this.apiKey = apiKey;
}

Expand All @@ -71,6 +86,7 @@ export class CohereEmbeddings
this.embeddingWithRetry({
model: this.modelName,
texts: batch,
inputType: this.inputType,
})
);

Expand All @@ -80,9 +96,9 @@ export class CohereEmbeddings

for (let i = 0; i < batchResponses.length; i += 1) {
const batch = batches[i];
const { body: batchResponse } = batchResponses[i];
const { embeddings: batchResponse } = batchResponses[i];
for (let j = 0; j < batch.length; j += 1) {
embeddings.push(batchResponse.embeddings[j]);
embeddings.push(batchResponse[j]);
}
}

Expand All @@ -97,11 +113,11 @@ export class CohereEmbeddings
async embedQuery(text: string): Promise<number[]> {
await this.maybeInitClient();

const { body } = await this.embeddingWithRetry({
const { embeddings } = await this.embeddingWithRetry({
model: this.modelName,
texts: [text],
});
return body.embeddings[0];
return embeddings[0];
}

/**
Expand All @@ -122,20 +138,21 @@ export class CohereEmbeddings
*/
private async maybeInitClient() {
if (!this.client) {
const { cohere } = await CohereEmbeddings.imports();
const { CohereClient } = await CohereEmbeddings.imports();

this.client = cohere;
this.client.init(this.apiKey);
this.client = new CohereClient({
token: this.apiKey,
});
}
}

/** @ignore */
static async imports(): Promise<{
cohere: typeof import("cohere-ai");
CohereClient: typeof import("cohere-ai").CohereClient;
}> {
try {
const { default: cohere } = await import("cohere-ai");
return { cohere };
const { CohereClient } = await import("cohere-ai");
return { CohereClient };
} catch (e) {
throw new Error(
"Please install cohere-ai as a dependency with, e.g. `yarn add cohere-ai`"
Expand Down
18 changes: 10 additions & 8 deletions langchain/src/llms/cohere.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,11 @@ export class Cohere extends LLM implements CohereInput {
prompt: string,
options: this["ParsedCallOptions"]
): Promise<string> {
const { cohere } = await Cohere.imports();
const { CohereClient } = await Cohere.imports();

cohere.init(this.apiKey);
const cohere = new CohereClient({
token: this.apiKey,
});

// Hit the `generate` endpoint on the `large` model
const generateResponse = await this.caller.callWithOptions(
Expand All @@ -87,13 +89,13 @@ export class Cohere extends LLM implements CohereInput {
{
prompt,
model: this.model,
max_tokens: this.maxTokens,
maxTokens: this.maxTokens,
temperature: this.temperature,
end_sequences: options.stop,
endSequences: options.stop,
}
);
try {
return generateResponse.body.generations[0].text;
return generateResponse.generations[0].text;
} catch {
console.log(generateResponse);
throw new Error("Could not parse response.");
Expand All @@ -102,11 +104,11 @@ export class Cohere extends LLM implements CohereInput {

/** @ignore */
static async imports(): Promise<{
cohere: typeof import("cohere-ai");
CohereClient: typeof import("cohere-ai").CohereClient;
}> {
try {
const { default: cohere } = await import("cohere-ai");
return { cohere };
const { CohereClient } = await import("cohere-ai");
return { CohereClient };
} catch (e) {
throw new Error(
"Please install cohere-ai as a dependency with, e.g. `yarn add cohere-ai`"
Expand Down
68 changes: 53 additions & 15 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -11580,6 +11580,13 @@ __metadata:
languageName: node
linkType: hard

"@types/qs@npm:6.9.8":
version: 6.9.8
resolution: "@types/qs@npm:6.9.8"
checksum: c28e07d00d07970e5134c6eed184a0189b8a4649e28fdf36d9117fe671c067a44820890de6bdecef18217647a95e9c6aebdaaae69f5fe4b0bec9345db885f77e
languageName: node
linkType: hard

"@types/range-parser@npm:*":
version: 1.2.4
resolution: "@types/range-parser@npm:1.2.4"
Expand Down Expand Up @@ -11804,6 +11811,13 @@ __metadata:
languageName: node
linkType: hard

"@types/url-join@npm:4.0.1":
version: 4.0.1
resolution: "@types/url-join@npm:4.0.1"
checksum: 29444b90e165b3970c8ad3fcef1e2de092e72b67e1f3aaff3016eea1697f98a01a28fb4f71f8044e3338a95bd2262327e9299be416efe5297018658760a331b4
languageName: node
linkType: hard

"@types/uuid@npm:^9":
version: 9.0.1
resolution: "@types/uuid@npm:9.0.1"
Expand Down Expand Up @@ -13265,6 +13279,16 @@ __metadata:
languageName: node
linkType: hard

"axios@npm:0.27.2":
version: 0.27.2
resolution: "axios@npm:0.27.2"
dependencies:
follow-redirects: ^1.14.9
form-data: ^4.0.0
checksum: 38cb7540465fe8c4102850c4368053c21683af85c5fdf0ea619f9628abbcb59415d1e22ebc8a6390d2bbc9b58a9806c874f139767389c862ec9b772235f06854
languageName: node
linkType: hard

"axios@npm:^0.21.1":
version: 0.21.4
resolution: "axios@npm:0.21.4"
Expand Down Expand Up @@ -14680,10 +14704,17 @@ __metadata:
languageName: node
linkType: hard

"cohere-ai@npm:>=6.0.0":
version: 6.2.2
resolution: "cohere-ai@npm:6.2.2"
checksum: 5a7ea2bb6f2a6b83de7ec90612a38f0b215bfc199b5fe749842c6e771a1b183681a3a0e351d75c22eb6d929ea45cf03b1cdf4ed3a907c401186e73c392749310
"cohere-ai@npm:^7.2.0":
version: 7.2.0
resolution: "cohere-ai@npm:7.2.0"
dependencies:
"@types/qs": 6.9.8
"@types/url-join": 4.0.1
axios: 0.27.2
js-base64: 3.7.2
qs: 6.11.2
url-join: 4.0.1
checksum: 4ea91185bacb343d0d7cb22d2c256075ea57c878f3b9398dbf2d5ccc4149a8cec1c8074373147ea8608469fba885439bca967d31dd068df0b3ddba12fd74bc7d
languageName: node
linkType: hard

Expand Down Expand Up @@ -18334,7 +18365,7 @@ __metadata:
languageName: node
linkType: hard

"follow-redirects@npm:^1.0.0, follow-redirects@npm:^1.14.7":
"follow-redirects@npm:^1.0.0, follow-redirects@npm:^1.14.7, follow-redirects@npm:^1.14.9":
version: 1.15.3
resolution: "follow-redirects@npm:1.15.3"
peerDependenciesMeta:
Expand Down Expand Up @@ -21706,6 +21737,13 @@ __metadata:
languageName: node
linkType: hard

"js-base64@npm:3.7.2":
version: 3.7.2
resolution: "js-base64@npm:3.7.2"
checksum: 573f28e9a27c3df60096d4d3f551bcb4fcb6d49161cf83396e9bad9b76f94736a70bb70b8808fe834dff2a388f76604ba09d6e153bbf181646e407720139fa5b
languageName: node
linkType: hard

"js-sdsl@npm:^4.1.4":
version: 4.3.0
resolution: "js-sdsl@npm:4.3.0"
Expand Down Expand Up @@ -22229,7 +22267,7 @@ __metadata:
closevector-common: 0.1.0-alpha.1
closevector-node: 0.1.0-alpha.10
closevector-web: 0.1.0-alpha.15
cohere-ai: ">=6.0.0"
cohere-ai: ^7.2.0
convex: ^1.3.1
d3-dsv: ^2.0.0
decamelize: ^1.2.0
Expand Down Expand Up @@ -22362,7 +22400,7 @@ __metadata:
closevector-common: 0.1.0-alpha.1
closevector-node: 0.1.0-alpha.10
closevector-web: 0.1.0-alpha.16
cohere-ai: ">=6.0.0"
cohere-ai: ^7.2.0
convex: ^1.3.1
d3-dsv: ^2.0.0
epub2: ^3.0.1
Expand Down Expand Up @@ -26856,7 +26894,7 @@ __metadata:
languageName: node
linkType: hard

"qs@npm:^6.7.0":
"qs@npm:6.11.2, qs@npm:^6.7.0":
version: 6.11.2
resolution: "qs@npm:6.11.2"
dependencies:
Expand Down Expand Up @@ -30856,20 +30894,20 @@ __metadata:
languageName: node
linkType: hard

"url-join@npm:4.0.1, url-join@npm:^4.0.1":
version: 4.0.1
resolution: "url-join@npm:4.0.1"
checksum: f74e868bf25dbc8be6a8d7237d4c36bb5b6c62c72e594d5ab1347fe91d6af7ccd9eb5d621e30152e4da45c2e9a26bec21390e911ab54a62d4d82e76028374ee5
languageName: node
linkType: hard

"url-join@npm:5.0.0":
version: 5.0.0
resolution: "url-join@npm:5.0.0"
checksum: 5921384a8ad4395b49ce4b50aa26efbc429cebe0bc8b3660ad693dd12fd859747b5369be0443e60e53a7850b2bc9d7d0687bcb94386662b40e743596bbf38101
languageName: node
linkType: hard

"url-join@npm:^4.0.1":
version: 4.0.1
resolution: "url-join@npm:4.0.1"
checksum: f74e868bf25dbc8be6a8d7237d4c36bb5b6c62c72e594d5ab1347fe91d6af7ccd9eb5d621e30152e4da45c2e9a26bec21390e911ab54a62d4d82e76028374ee5
languageName: node
linkType: hard

"url-loader@npm:^4.1.1":
version: 4.1.1
resolution: "url-loader@npm:4.1.1"
Expand Down