From 0d47d15a9f6ea9880404595fe86ac19300aec348 Mon Sep 17 00:00:00 2001 From: cxumol Date: Mon, 4 Dec 2023 19:25:05 -0800 Subject: [PATCH] add(feat): Text Embeddings by Cloudflare Workers AI (#14220) Add [Text Embeddings by Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/models/text-embeddings/). It's a new integration. It is intended to align with its langchain-js counterpart [here](https://api.js.langchain.com/classes/embeddings_cloudflare_workersai.CloudflareWorkersAIEmbeddings.html). - Dependencies: N/A - Ran `make format`, `make lint`, `make spell_check`, and `make integration_tests`; all checks passed --- .../text_embedding/cloudflare_workersai.ipynb | 125 ++++++++++++++++++ .../embeddings/cloudflare_workersai.py | 94 +++++++++++++ .../embeddings/test_cloudflare_workersai.py | 53 ++++++++ 3 files changed, 272 insertions(+) create mode 100644 docs/docs/integrations/text_embedding/cloudflare_workersai.ipynb create mode 100644 libs/langchain/langchain/embeddings/cloudflare_workersai.py create mode 100644 libs/langchain/tests/integration_tests/embeddings/test_cloudflare_workersai.py diff --git a/docs/docs/integrations/text_embedding/cloudflare_workersai.ipynb b/docs/docs/integrations/text_embedding/cloudflare_workersai.ipynb new file mode 100644 index 00000000000..25c6ce96010 --- /dev/null +++ b/docs/docs/integrations/text_embedding/cloudflare_workersai.ipynb @@ -0,0 +1,125 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "59428e05", + "metadata": {}, + "source": [ + "# Text Embeddings on Cloudflare Workers AI\n", + "\n", + "The [Cloudflare AI documentation](https://developers.cloudflare.com/workers-ai/models/text-embeddings/) lists all available text embedding models.\n", + "\n", + "Both a Cloudflare account ID and an API token are required. 
Find how to obtain them from [this document](https://developers.cloudflare.com/workers-ai/get-started/rest-api/).\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "92c5b61e", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.embeddings.cloudflare_workersai import CloudflareWorkersAIEmbeddings" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "f60023b8", + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "\n", + "my_account_id = getpass.getpass(\"Enter your Cloudflare account ID:\\n\\n\")\n", + "my_api_token = getpass.getpass(\"Enter your Cloudflare API token:\\n\\n\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "062547b9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(384, [-0.033627357333898544, 0.03982774540781975, 0.03559349477291107])" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "embeddings = CloudflareWorkersAIEmbeddings(\n", + " account_id=my_account_id,\n", + " api_token=my_api_token,\n", + " model_name=\"@cf/baai/bge-small-en-v1.5\",\n", + ")\n", + "# single string embeddings\n", + "query_result = embeddings.embed_query(\"test\")\n", + "len(query_result), query_result[:3]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "e1dcc4bd", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(3, 384)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# string embeddings in batches\n", + "batch_query_result = embeddings.embed_documents([\"test1\", \"test2\", \"test3\"])\n", + "len(batch_query_result), len(batch_query_result[0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "52de8b88", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": 
"python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + }, + "vscode": { + "interpreter": { + "hash": "7377c2ccc78bc62c2683122d48c8cd1fb85a53850a1b1fc29736ed39852c9885" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/langchain/langchain/embeddings/cloudflare_workersai.py b/libs/langchain/langchain/embeddings/cloudflare_workersai.py new file mode 100644 index 00000000000..8d7e4103401 --- /dev/null +++ b/libs/langchain/langchain/embeddings/cloudflare_workersai.py @@ -0,0 +1,94 @@ +from typing import Any, Dict, List + +import requests +from langchain_core.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, Extra + +DEFAULT_MODEL_NAME = "@cf/baai/bge-base-en-v1.5" + + +class CloudflareWorkersAIEmbeddings(BaseModel, Embeddings): + """Cloudflare Workers AI embedding model. + + To use, you need to provide an API token and + account ID to access Cloudflare Workers AI. + + Example: + .. 
code-block:: python + + from langchain.embeddings import CloudflareWorkersAIEmbeddings + + account_id = "my_account_id" + api_token = "my_secret_api_token" + model_name = "@cf/baai/bge-small-en-v1.5" + + cf = CloudflareWorkersAIEmbeddings( + account_id=account_id, + api_token=api_token, + model_name=model_name + ) + """ + + api_base_url: str = "https://api.cloudflare.com/client/v4/accounts" + account_id: str + api_token: str + model_name: str = DEFAULT_MODEL_NAME + batch_size: int = 50 + strip_new_lines: bool = True + headers: Dict[str, str] = {"Authorization": "Bearer "} + + def __init__(self, **kwargs: Any): + """Initialize the Cloudflare Workers AI client.""" + super().__init__(**kwargs) + + self.headers = {"Authorization": f"Bearer {self.api_token}"} + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Compute doc embeddings using Cloudflare Workers AI. + + Args: + texts: The list of texts to embed. + + Returns: + List of embeddings, one for each text. + """ + if self.strip_new_lines: + texts = [text.replace("\n", " ") for text in texts] + + batches = [ + texts[i : i + self.batch_size] + for i in range(0, len(texts), self.batch_size) + ] + embeddings = [] + + for batch in batches: + response = requests.post( + f"{self.api_base_url}/{self.account_id}/ai/run/{self.model_name}", + headers=self.headers, + json={"text": batch}, + ) + embeddings.extend(response.json()["result"]["data"]) + + return embeddings + + def embed_query(self, text: str) -> List[float]: + """Compute query embeddings using Cloudflare Workers AI. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text. 
+ """ + text = text.replace("\n", " ") if self.strip_new_lines else text + response = requests.post( + f"{self.api_base_url}/{self.account_id}/ai/run/{self.model_name}", + headers=self.headers, + json={"text": [text]}, + ) + return response.json()["result"]["data"][0] diff --git a/libs/langchain/tests/integration_tests/embeddings/test_cloudflare_workersai.py b/libs/langchain/tests/integration_tests/embeddings/test_cloudflare_workersai.py new file mode 100644 index 00000000000..24ac0313717 --- /dev/null +++ b/libs/langchain/tests/integration_tests/embeddings/test_cloudflare_workersai.py @@ -0,0 +1,53 @@ +"""Test Cloudflare Workers AI embeddings.""" + +import responses + +from langchain.embeddings.cloudflare_workersai import CloudflareWorkersAIEmbeddings + + +@responses.activate +def test_cloudflare_workers_ai_embedding_documents() -> None: + """Test Cloudflare Workers AI embeddings.""" + documents = ["foo bar", "foo bar", "foo bar"] + + responses.add( + responses.POST, + "https://api.cloudflare.com/client/v4/accounts/123/ai/run/@cf/baai/bge-base-en-v1.5", + json={ + "result": { + "shape": [3, 768], + "data": [[0.0] * 768, [0.0] * 768, [0.0] * 768], + }, + "success": "true", + "errors": [], + "messages": [], + }, + ) + + embeddings = CloudflareWorkersAIEmbeddings(account_id="123", api_token="abc") + output = embeddings.embed_documents(documents) + + assert len(output) == 3 + assert len(output[0]) == 768 + + +@responses.activate +def test_cloudflare_workers_ai_embedding_query() -> None: + """Test Cloudflare Workers AI embeddings.""" + + responses.add( + responses.POST, + "https://api.cloudflare.com/client/v4/accounts/123/ai/run/@cf/baai/bge-base-en-v1.5", + json={ + "result": {"shape": [1, 768], "data": [[0.0] * 768]}, + "success": "true", + "errors": [], + "messages": [], + }, + ) + + document = "foo bar" + embeddings = CloudflareWorkersAIEmbeddings(account_id="123", api_token="abc") + output = embeddings.embed_query(document) + + assert len(output) == 768