diff --git a/src/langchain_contrib/.dockerignore b/src/langchain_contrib/.dockerignore new file mode 100644 index 0000000000..6788f45700 --- /dev/null +++ b/src/langchain_contrib/.dockerignore @@ -0,0 +1,6 @@ +.venv +.github +.git +.mypy_cache +.pytest_cache +Dockerfile \ No newline at end of file diff --git a/src/langchain_contrib/.flake8 b/src/langchain_contrib/.flake8 new file mode 100644 index 0000000000..d3ac343b3b --- /dev/null +++ b/src/langchain_contrib/.flake8 @@ -0,0 +1,12 @@ +[flake8] +exclude = + venv + .venv + __pycache__ + notebooks +# Recommend matching the black line length (default 88), +# rather than using the flake8 default of 79: +max-line-length = 88 +extend-ignore = + # See https://github.com/PyCQA/pycodestyle/issues/373 + E203, diff --git a/src/langchain_contrib/.gitattributes b/src/langchain_contrib/.gitattributes new file mode 100644 index 0000000000..5dc46e6b38 --- /dev/null +++ b/src/langchain_contrib/.gitattributes @@ -0,0 +1,3 @@ +* text=auto eol=lf +*.{cmd,[cC][mM][dD]} text eol=crlf +*.{bat,[bB][aA][tT]} text eol=crlf \ No newline at end of file diff --git a/src/langchain_contrib/.gitignore b/src/langchain_contrib/.gitignore new file mode 100644 index 0000000000..260a6d0c7c --- /dev/null +++ b/src/langchain_contrib/.gitignore @@ -0,0 +1,169 @@ +.vs/ +.vscode/ +.idea/ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ +docs/docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints +notebooks/ + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.envrc +.venv +.venvs +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# macOS display setting files +.DS_Store + +# Wandb directory +wandb/ + +# asdf tool versions +.tool-versions +/.ruff_cache/ + +*.pkl +*.bin + +# integration test artifacts +data_map* +\[('_type', 'fake'), ('stop', None)] + +# Replit files +*replit* + +node_modules +docs/.yarn/ +docs/node_modules/ +docs/.docusaurus/ +docs/.cache-loader/ +docs/_dist +docs/api_reference/_build +docs/docs_skeleton/build +docs/docs_skeleton/node_modules +docs/docs_skeleton/yarn.lock + +sftp-config.json diff --git a/src/langchain_contrib/.gitmodules b/src/langchain_contrib/.gitmodules new file mode 100644 index 0000000000..855d367568 --- /dev/null +++ b/src/langchain_contrib/.gitmodules @@ -0,0 +1,4 @@ +[submodule "docs/_docs_skeleton"] + path = docs/_docs_skeleton + url = https://github.com/langchain-ai/langchain-shared-docs + branch = main diff --git a/src/langchain_contrib/.readthedocs.yaml b/src/langchain_contrib/.readthedocs.yaml new file mode 100644 index 0000000000..15bface37f --- /dev/null +++ b/src/langchain_contrib/.readthedocs.yaml @@ -0,0 +1,29 @@ +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Set the version of Python and other tools you might need +build: + os: ubuntu-22.04 + tools: + python: "3.11" + jobs: + pre_build: + - python docs/api_reference/create_api_rst.py + +# Build documentation in the docs/ directory with Sphinx +sphinx: + configuration: docs/api_reference/conf.py + +# If using Sphinx, optionally build your docs in additional formats such as PDF +# formats: +# - pdf + +# Optionally declare the Python requirements required to build your docs +python: + install: + - requirements: docs/requirements.txt + - method: pip + path: . diff --git a/src/langchain_contrib/Dockerfile b/src/langchain_contrib/Dockerfile new file mode 100644 index 0000000000..b950527bb0 --- /dev/null +++ b/src/langchain_contrib/Dockerfile @@ -0,0 +1,48 @@ +# This is a Dockerfile for running unit tests + +ARG POETRY_HOME=/opt/poetry + +# Use the Python base image +FROM python:3.11.2-bullseye AS builder + +# Define the version of Poetry to install (default is 1.4.2) +ARG POETRY_VERSION=1.4.2 + +# Define the directory to install Poetry to (default is /opt/poetry) +ARG POETRY_HOME + +# Create a Python virtual environment for Poetry and install it +RUN python3 -m venv ${POETRY_HOME} && \ + $POETRY_HOME/bin/pip install --upgrade pip && \ + $POETRY_HOME/bin/pip install poetry==${POETRY_VERSION} + +# Test if Poetry is installed in the expected path +RUN echo "Poetry version:" && $POETRY_HOME/bin/poetry --version + +# Set the working directory for the app +WORKDIR /app + +# Use a multi-stage build to install dependencies +FROM builder AS dependencies + +ARG POETRY_HOME + +# Copy only the dependency files for installation +COPY pyproject.toml poetry.lock poetry.toml ./ + +# Install the Poetry dependencies (this layer will be cached as long as the dependencies don't change) +RUN $POETRY_HOME/bin/poetry install --no-interaction --no-ansi --with test + +# Use a multi-stage build to run tests +FROM dependencies AS tests + +# Copy the rest of the app source code (this layer will be invalidated and rebuilt whenever the source code changes) +COPY . . + +RUN /opt/poetry/bin/poetry install --no-interaction --no-ansi --with test + +# Set the entrypoint to run tests using Poetry +ENTRYPOINT ["/opt/poetry/bin/poetry", "run", "pytest"] + +# Set the default command to run all unit tests +CMD ["tests/unit_tests"] diff --git a/src/langchain_contrib/Makefile b/src/langchain_contrib/Makefile new file mode 100644 index 0000000000..1786c77771 --- /dev/null +++ b/src/langchain_contrib/Makefile @@ -0,0 +1,73 @@ +.PHONY: all clean format lint test tests test_watch integration_tests docker_tests help extended_tests + +all: help + +coverage: + poetry run pytest --cov \ + --cov-config=.coveragerc \ + --cov-report xml \ + --cov-report term-missing:skip-covered + +clean: docs_clean + +docs_compile: + poetry run nbdoc_build --srcdir $(srcdir) + +docs_build: + cd docs && poetry run make html + +docs_clean: + cd docs && poetry run make clean + +docs_linkcheck: + poetry run linkchecker docs/_build/html/index.html + +format: + poetry run black . + poetry run ruff --select I --fix . + +PYTHON_FILES=. +lint: PYTHON_FILES=. +lint_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d master | grep -E '\.py$$') + +lint lint_diff: + poetry run mypy $(PYTHON_FILES) + poetry run black $(PYTHON_FILES) --check + poetry run ruff . + +TEST_FILE ?= tests/unit_tests/ + +test: + poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE) + +tests: + poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE) + +extended_tests: + poetry run pytest --disable-socket --allow-unix-socket --only-extended tests/unit_tests + +test_watch: + poetry run ptw --now . -- tests/unit_tests + +integration_tests: + poetry run pytest tests/integration_tests + +docker_tests: + docker build -t my-langchain-image:test . + docker run --rm my-langchain-image:test + +help: + @echo '----' + @echo 'coverage - run unit tests and generate coverage report' + @echo 'docs_build - build the documentation' + @echo 'docs_clean - clean the documentation build artifacts' + @echo 'docs_linkcheck - run linkchecker on the documentation' + @echo 'format - run code formatters' + @echo 'lint - run linters' + @echo 'test - run unit tests' + @echo 'tests - run unit tests' + @echo 'test TEST_FILE= - run all tests in file' + @echo 'extended_tests - run only extended unit tests' + @echo 'test_watch - run unit tests in watch mode' + @echo 'integration_tests - run integration tests' + @echo 'docker_tests - run unit tests in docker' diff --git a/src/langchain_contrib/README.md b/src/langchain_contrib/README.md new file mode 100644 index 0000000000..6a5144db29 --- /dev/null +++ b/src/langchain_contrib/README.md @@ -0,0 +1,20 @@ +## Repository for langchain's extra modules + +This repository is intended for the development of so-called "extra" modules, +contributed functionality. New modules quite often do not have stable API, +and they are not well-tested. Thus, they shouldn't be released as a part of the +official langchain distribution, since the library maintains binary compatibility, +and tries to provide decent performance and stability. + +So, all the new modules should be developed separately, and published in the +`langchain_contrib` repository at first. Later, when the module matures and gains +popularity, it will create a pr for langchain. + + +### Update the repository documentation + +In order to keep a clean overview containing all contributed modules, the following files need to be created/adapted: + +1. Update the README.md file under the modules folder. Here, you add your model with a single-line description. + +2. Add a README.md inside your own module folder. This README explains which functionality (separate functions) is available, links to the corresponding samples, and explains in somewhat more detail what the module is expected to do. If any extra requirements are needed to build the module without problems, add them here also. diff --git a/src/langchain_contrib/dev.Dockerfile b/src/langchain_contrib/dev.Dockerfile new file mode 100644 index 0000000000..4383ecc895 --- /dev/null +++ b/src/langchain_contrib/dev.Dockerfile @@ -0,0 +1,41 @@ +# This is a Dockerfile for the Development Container + +# Use the Python base image +ARG VARIANT="3.11-bullseye" +FROM mcr.microsoft.com/devcontainers/python:0-${VARIANT} AS langchain-dev-base + +USER vscode + +# Define the version of Poetry to install (default is 1.4.2) +# Define the directory of python virtual environment +ARG PYTHON_VIRTUALENV_HOME=/home/vscode/langchain-py-env \ + POETRY_VERSION=1.3.2 + +ENV POETRY_VIRTUALENVS_IN_PROJECT=false \ + POETRY_NO_INTERACTION=true + +# Create a Python virtual environment for Poetry and install it +RUN python3 -m venv ${PYTHON_VIRTUALENV_HOME} && \ + $PYTHON_VIRTUALENV_HOME/bin/pip install --upgrade pip && \ + $PYTHON_VIRTUALENV_HOME/bin/pip install poetry==${POETRY_VERSION} + +ENV PATH="$PYTHON_VIRTUALENV_HOME/bin:$PATH" \ + VIRTUAL_ENV=$PYTHON_VIRTUALENV_HOME + +# Setup for bash +RUN poetry completions bash >> /home/vscode/.bash_completion && \ + echo "export PATH=$PYTHON_VIRTUALENV_HOME/bin:$PATH" >> ~/.bashrc + +# Set the working directory for the app +WORKDIR /workspaces/langchain + +# Use a multi-stage build to install dependencies +FROM langchain-dev-base AS langchain-dev-dependencies + +ARG PYTHON_VIRTUALENV_HOME + +# Copy only the dependency files for installation +COPY pyproject.toml poetry.toml ./ + +# Install the Poetry dependencies (this layer will be cached as long as the dependencies don't change) +RUN poetry install --no-interaction --no-ansi --with dev,test,docs \ No newline at end of file diff --git a/src/langchain_contrib/langchain_contrib/__init__.py b/src/langchain_contrib/langchain_contrib/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/langchain_contrib/langchain_contrib/chains/__init__.py b/src/langchain_contrib/langchain_contrib/chains/__init__.py new file mode 100644 index 0000000000..59e4fcb506 --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/chains/__init__.py @@ -0,0 +1,5 @@ +from langchain_contrib.chains.combine_documents.stuff import StuffDocumentsChain + +__all__ = [ + "StuffDocumentsChain", +] \ No newline at end of file diff --git a/src/langchain_contrib/langchain_contrib/chains/combine_documents/__init__.py b/src/langchain_contrib/langchain_contrib/chains/combine_documents/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/langchain_contrib/langchain_contrib/chains/combine_documents/stuff.py b/src/langchain_contrib/langchain_contrib/chains/combine_documents/stuff.py new file mode 100644 index 0000000000..f3bbfbbf00 --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/chains/combine_documents/stuff.py @@ -0,0 +1,52 @@ +from typing import Any, Dict, List, Optional, Tuple +from langchain.chains.combine_documents.stuff import StuffDocumentsChain as StuffDocumentsChainOld +from langchain.callbacks.manager import Callbacks +from langchain.docstore.document import Document + + +class StuffDocumentsChain(StuffDocumentsChainOld): + + token_max: int = -1 + + def combine_docs( + self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any + ) -> Tuple[str, dict]: + """Stuff all documents into one prompt and pass to LLM. + + Args: + docs: List of documents to join together into one variable + callbacks: Optional callbacks to pass along + **kwargs: additional parameters to use to get inputs to LLMChain. + + Returns: + The first element returned is the single string output. The second + element returned is a dictionary of other keys to return. + """ + inputs = self._get_inputs(docs, **kwargs) + # print('inputs:', len(inputs['context'])) + # print('prompt_length:', self.prompt_length(docs, **kwargs)) + if self.token_max > 0: + inputs[self.document_variable_name] = inputs[self.document_variable_name][:self.token_max] + # Call predict on the LLM. + return self.llm_chain.predict(callbacks=callbacks, **inputs), {} + + async def acombine_docs( + self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any + ) -> Tuple[str, dict]: + """Stuff all documents into one prompt and pass to LLM. + + Args: + docs: List of documents to join together into one variable + callbacks: Optional callbacks to pass along + **kwargs: additional parameters to use to get inputs to LLMChain. + + Returns: + The first element returned is the single string output. The second + element returned is a dictionary of other keys to return. + """ + inputs = self._get_inputs(docs, **kwargs) + if self.token_max > 0: + inputs[self.document_variable_name] = inputs[self.document_variable_name][:self.token_max] + # Call predict on the LLM. + return await self.llm_chain.apredict(callbacks=callbacks, **inputs), {} + diff --git a/src/langchain_contrib/langchain_contrib/chains/question_answering/__init__.py b/src/langchain_contrib/langchain_contrib/chains/question_answering/__init__.py new file mode 100644 index 0000000000..c63f654efc --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/chains/question_answering/__init__.py @@ -0,0 +1,250 @@ +"""Load question answering chains.""" +from typing import Any, Mapping, Optional, Protocol + +from langchain.callbacks.base import BaseCallbackManager +from langchain.callbacks.manager import Callbacks +from langchain.chains import ReduceDocumentsChain +from langchain.chains.combine_documents.base import BaseCombineDocumentsChain +from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain +from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain +from langchain.chains.combine_documents.refine import RefineDocumentsChain +# from langchain.chains.combine_documents.stuff import StuffDocumentsChain +from langchain_contrib.chains.combine_documents.stuff import StuffDocumentsChain +from langchain.chains.llm import LLMChain +from langchain.chains.question_answering import ( + map_reduce_prompt, + refine_prompts, + stuff_prompt, +) +from langchain.chains.question_answering.map_rerank_prompt import ( + PROMPT as MAP_RERANK_PROMPT, +) +from langchain.schema.language_model import BaseLanguageModel +from langchain.schema.prompt_template import BasePromptTemplate + + +class LoadingCallable(Protocol): + """Interface for loading the combine documents chain.""" + + def __call__( + self, llm: BaseLanguageModel, **kwargs: Any + ) -> BaseCombineDocumentsChain: + """Callable to load the combine documents chain.""" + + +def _load_map_rerank_chain( + llm: BaseLanguageModel, + prompt: BasePromptTemplate = MAP_RERANK_PROMPT, + verbose: bool = False, + document_variable_name: str = "context", + rank_key: str = "score", + answer_key: str = "answer", + callback_manager: Optional[BaseCallbackManager] = None, + callbacks: Callbacks = None, + **kwargs: Any, +) -> MapRerankDocumentsChain: + llm_chain = LLMChain( + llm=llm, + prompt=prompt, + verbose=verbose, + callback_manager=callback_manager, + callbacks=callbacks, + ) + return MapRerankDocumentsChain( + llm_chain=llm_chain, + rank_key=rank_key, + answer_key=answer_key, + document_variable_name=document_variable_name, + verbose=verbose, + callback_manager=callback_manager, + **kwargs, + ) + + +def _load_stuff_chain( + llm: BaseLanguageModel, + prompt: Optional[BasePromptTemplate] = None, + document_variable_name: str = "context", + verbose: Optional[bool] = None, + callback_manager: Optional[BaseCallbackManager] = None, + callbacks: Callbacks = None, + **kwargs: Any, +) -> StuffDocumentsChain: + _prompt = prompt or stuff_prompt.PROMPT_SELECTOR.get_prompt(llm) + llm_chain = LLMChain( + llm=llm, + prompt=_prompt, + verbose=verbose, + callback_manager=callback_manager, + callbacks=callbacks, + ) + # TODO: document prompt + return StuffDocumentsChain( + llm_chain=llm_chain, + document_variable_name=document_variable_name, + verbose=verbose, + callback_manager=callback_manager, + **kwargs, + ) + + +def _load_map_reduce_chain( + llm: BaseLanguageModel, + question_prompt: Optional[BasePromptTemplate] = None, + combine_prompt: Optional[BasePromptTemplate] = None, + combine_document_variable_name: str = "summaries", + map_reduce_document_variable_name: str = "context", + collapse_prompt: Optional[BasePromptTemplate] = None, + reduce_llm: Optional[BaseLanguageModel] = None, + collapse_llm: Optional[BaseLanguageModel] = None, + verbose: Optional[bool] = None, + callback_manager: Optional[BaseCallbackManager] = None, + callbacks: Callbacks = None, + token_max: int = 3000, + **kwargs: Any, +) -> MapReduceDocumentsChain: + _question_prompt = ( + question_prompt or map_reduce_prompt.QUESTION_PROMPT_SELECTOR.get_prompt(llm) + ) + _combine_prompt = ( + combine_prompt or map_reduce_prompt.COMBINE_PROMPT_SELECTOR.get_prompt(llm) + ) + map_chain = LLMChain( + llm=llm, + prompt=_question_prompt, + verbose=verbose, + callback_manager=callback_manager, + callbacks=callbacks, + ) + _reduce_llm = reduce_llm or llm + reduce_chain = LLMChain( + llm=_reduce_llm, + prompt=_combine_prompt, + verbose=verbose, + callback_manager=callback_manager, + callbacks=callbacks, + ) + # TODO: document prompt + combine_documents_chain = StuffDocumentsChain( + llm_chain=reduce_chain, + document_variable_name=combine_document_variable_name, + verbose=verbose, + callback_manager=callback_manager, + callbacks=callbacks, + ) + if collapse_prompt is None: + collapse_chain = None + if collapse_llm is not None: + raise ValueError( + "collapse_llm provided, but collapse_prompt was not: please " + "provide one or stop providing collapse_llm." + ) + else: + _collapse_llm = collapse_llm or llm + collapse_chain = StuffDocumentsChain( + llm_chain=LLMChain( + llm=_collapse_llm, + prompt=collapse_prompt, + verbose=verbose, + callback_manager=callback_manager, + callbacks=callbacks, + ), + document_variable_name=combine_document_variable_name, + verbose=verbose, + callback_manager=callback_manager, + ) + reduce_documents_chain = ReduceDocumentsChain( + combine_documents_chain=combine_documents_chain, + collapse_documents_chain=collapse_chain, + token_max=token_max, + verbose=verbose, + ) + return MapReduceDocumentsChain( + llm_chain=map_chain, + document_variable_name=map_reduce_document_variable_name, + reduce_documents_chain=reduce_documents_chain, + verbose=verbose, + callback_manager=callback_manager, + callbacks=callbacks, + **kwargs, + ) + + +def _load_refine_chain( + llm: BaseLanguageModel, + question_prompt: Optional[BasePromptTemplate] = None, + refine_prompt: Optional[BasePromptTemplate] = None, + document_variable_name: str = "context_str", + initial_response_name: str = "existing_answer", + refine_llm: Optional[BaseLanguageModel] = None, + verbose: Optional[bool] = None, + callback_manager: Optional[BaseCallbackManager] = None, + callbacks: Callbacks = None, + **kwargs: Any, +) -> RefineDocumentsChain: + _question_prompt = ( + question_prompt or refine_prompts.QUESTION_PROMPT_SELECTOR.get_prompt(llm) + ) + _refine_prompt = refine_prompt or refine_prompts.REFINE_PROMPT_SELECTOR.get_prompt( + llm + ) + initial_chain = LLMChain( + llm=llm, + prompt=_question_prompt, + verbose=verbose, + callback_manager=callback_manager, + callbacks=callbacks, + ) + _refine_llm = refine_llm or llm + refine_chain = LLMChain( + llm=_refine_llm, + prompt=_refine_prompt, + verbose=verbose, + callback_manager=callback_manager, + callbacks=callbacks, + ) + return RefineDocumentsChain( + initial_llm_chain=initial_chain, + refine_llm_chain=refine_chain, + document_variable_name=document_variable_name, + initial_response_name=initial_response_name, + verbose=verbose, + callback_manager=callback_manager, + **kwargs, + ) + + +def load_qa_chain( + llm: BaseLanguageModel, + chain_type: str = "stuff", + verbose: Optional[bool] = None, + callback_manager: Optional[BaseCallbackManager] = None, + **kwargs: Any, +) -> BaseCombineDocumentsChain: + """Load question answering chain. + + Args: + llm: Language Model to use in the chain. + chain_type: Type of document combining chain to use. Should be one of "stuff", + "map_reduce", "map_rerank", and "refine". + verbose: Whether chains should be run in verbose mode or not. Note that this + applies to all chains that make up the final chain. + callback_manager: Callback manager to use for the chain. + + Returns: + A chain to use for question answering. + """ + loader_mapping: Mapping[str, LoadingCallable] = { + "stuff": _load_stuff_chain, + "map_reduce": _load_map_reduce_chain, + "refine": _load_refine_chain, + "map_rerank": _load_map_rerank_chain, + } + if chain_type not in loader_mapping: + raise ValueError( + f"Got unsupported chain type: {chain_type}. " + f"Should be one of {loader_mapping.keys()}" + ) + return loader_mapping[chain_type]( + llm, verbose=verbose, callback_manager=callback_manager, **kwargs + ) diff --git a/src/langchain_contrib/langchain_contrib/chains/retrieval_qa/__init__.py b/src/langchain_contrib/langchain_contrib/chains/retrieval_qa/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/langchain_contrib/langchain_contrib/chains/retrieval_qa/base.py b/src/langchain_contrib/langchain_contrib/chains/retrieval_qa/base.py new file mode 100644 index 0000000000..5bc86ee380 --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/chains/retrieval_qa/base.py @@ -0,0 +1,100 @@ +from typing import Any, Dict, List, Optional +from langchain.chains.retrieval_qa.base import BaseRetrievalQA +from langchain.schema import BaseRetriever, Document +from langchain.pydantic_v1 import Extra, Field, root_validator +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, + Callbacks, +) + + +class MultiRetrievalQA(BaseRetrievalQA): + """Chain for question-answering against an index. + + Example: + .. code-block:: python + + from langchain.llms import OpenAI + from langchain.chains import RetrievalQA + from langchain.faiss import FAISS + from langchain.vectorstores.base import VectorStoreRetriever + retriever = VectorStoreRetriever(vectorstore=FAISS(...)) + retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=retriever) + + """ + + vector_retriever: BaseRetriever = Field(exclude=True) + keyword_retriever: BaseRetriever = Field(exclude=True) + combine_strategy: str = "keyword_front" # "keyword_front, vector_front, mix" + + def _get_docs( + self, + question: str, + *, + run_manager: CallbackManagerForChainRun, + ) -> List[Document]: + """Get docs.""" + vector_docs = self.vector_retriever.get_relevant_documents( + question, callbacks=run_manager.get_child() + ) + keyword_docs = self.keyword_retriever.get_relevant_documents( + question, callbacks=run_manager.get_child() + ) + if self.combine_strategy == "keyword_front": + return keyword_docs + vector_docs + elif self.combine_strategy == "vector_front": + return vector_docs + keyword_docs + elif self.combine_strategy == "mix": + combine_docs = [] + min_len = min(len(keyword_docs), len(vector_docs)) + for i in range(min_len): + combine_docs.append(keyword_docs[i]) + combine_docs.append(vector_docs[i]) + combine_docs.extend(keyword_docs[min_len:]) + combine_docs.extend(vector_docs[min_len:]) + return combine_docs + else: + raise ValueError( + f"Expected combine_strategy to be one of " + f"(keyword_front, vector_front, mix)," + f"instead found {self.combine_strategy}" + ) + + async def _aget_docs( + self, + question: str, + *, + run_manager: AsyncCallbackManagerForChainRun, + ) -> List[Document]: + """Get docs.""" + vector_docs = await self.vector_retriever.get_relevant_documents( + question, callbacks=run_manager.get_child() + ) + keyword_docs = await self.keyword_retriever.get_relevant_documents( + question, callbacks=run_manager.get_child() + ) + if self.combine_strategy == "keyword_front": + return keyword_docs + vector_docs + elif self.combine_strategy == "vector_front": + return vector_docs + keyword_docs + elif self.combine_strategy == "mix": + combine_docs = [] + min_len = min(len(keyword_docs), len(vector_docs)) + for i in range(min_len): + combine_docs.append(keyword_docs[i]) + combine_docs.append(vector_docs[i]) + combine_docs.extend(keyword_docs[min_len:]) + combine_docs.extend(vector_docs[min_len:]) + return combine_docs + else: + raise ValueError( + f"Expected combine_strategy to be one of " + f"(keyword_front, vector_front, mix)," + f"instead found {self.combine_strategy}" + ) + + @property + def _chain_type(self) -> str: + """Return the chain type.""" + return "multi_retrieval_qa" diff --git a/src/langchain_contrib/langchain_contrib/chat_models/__init__.py b/src/langchain_contrib/langchain_contrib/chat_models/__init__.py new file mode 100644 index 0000000000..5f1fac9140 --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/chat_models/__init__.py @@ -0,0 +1,11 @@ +from .proxy_llm import ProxyChatLLM +from .minimax import ChatMinimaxAI +from .wenxin import ChatWenxin +from .zhipuai import ChatZhipuAI +from .xunfeiai import ChatXunfeiAI +from .host_llm import Llama2Chat, ChatGLM2Host, BaichuanChat, QwenChat + +__all__ = [ + 'ProxyChatLLM', 'ChatMinimaxAI', 'ChatWenxin', 'ChatZhipuAI', + 'ChatXunfeiAI', 'Llama2Chat', 'ChatGLM2Host', 'BaichuanChat','QwenChat' +] diff --git a/src/langchain_contrib/langchain_contrib/chat_models/host_llm.py b/src/langchain_contrib/langchain_contrib/chat_models/host_llm.py new file mode 100644 index 0000000000..9831ffba14 --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/chat_models/host_llm.py @@ -0,0 +1,437 @@ +"""proxy llm chat wrapper.""" +from __future__ import annotations + +import requests +import logging +import sys +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Mapping, + Optional, + Tuple, + Union, +) + +from pydantic import Field, root_validator +from requests.exceptions import HTTPError +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain.chat_models.base import BaseChatModel +from langchain.schema import ( + ChatGeneration, + ChatResult, +) +from langchain.schema.messages import ( + AIMessage, + BaseMessage, + ChatMessage, + FunctionMessage, + HumanMessage, + SystemMessage, +) +from langchain.utils import get_from_dict_or_env + +if TYPE_CHECKING: + import tiktoken + +from .interface import MinimaxChatCompletion +from .interface.types import ChatInput + +logger = logging.getLogger(__name__) + + +def _import_tiktoken() -> Any: + try: + import tiktoken + except ImportError: + raise ValueError( + "Could not import tiktoken python package. " + "This is needed in order to calculate get_token_ids. " + "Please install it with `pip install tiktoken`." + ) + return tiktoken + + +def _create_retry_decorator(llm: BaseHostChatLLM) -> Callable[[Any], Any]: + + min_seconds = 1 + max_seconds = 20 + # Wait 2^x * 1 second between each retry starting with + # 4 seconds, then up to 10 seconds, then 10 seconds afterwards + return retry( + reraise=True, + stop=stop_after_attempt(llm.max_retries), + wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), + retry=( + retry_if_exception_type(Exception) + ), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + + +def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: + role = _dict["role"] + if role == "user": + return HumanMessage(content=_dict["content"]) + elif role == "assistant": + content = _dict["content"] or "" # OpenAI returns None for tool invocations + if _dict.get("function_call"): + additional_kwargs = {"function_call": dict(_dict["function_call"])} + else: + additional_kwargs = {} + return AIMessage(content=content, additional_kwargs=additional_kwargs) + elif role == "system": + return SystemMessage(content=_dict["content"]) + elif role == "function": + return FunctionMessage(content=_dict["content"], name=_dict["name"]) + else: + return ChatMessage(content=_dict["content"], role=role) + + +def _convert_message_to_dict(message: BaseMessage) -> dict: + if isinstance(message, ChatMessage): + message_dict = {"role": message.role, "content": message.content} + elif isinstance(message, HumanMessage): + message_dict = {"role": "user", "content": message.content} + elif isinstance(message, AIMessage): + message_dict = {"role": "assistant", "content": message.content} + if "function_call" in message.additional_kwargs: + message_dict["function_call"] = message.additional_kwargs["function_call"] + elif isinstance(message, SystemMessage): + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, FunctionMessage): + message_dict = { + "role": "function", + "content": message.content, + "name": message.name, + } + else: + raise ValueError(f"Got unknown type {message}") + if "name" in message.additional_kwargs: + message_dict["name"] = message.additional_kwargs["name"] + return message_dict + + +class BaseHostChatLLM(BaseChatModel): + """Wrapper around base host Chat large language models. + """ + + client: Optional[Any] #: :meta private: + + """Model name to use.""" + model_name: str = Field("", alias="model") + + temperature: float = 0.9 + top_p: float = 0.95 + do_sample: bool = False + + """Number of chat completions to generate for each prompt.""" + max_tokens: int = 4096 + + """What sampling temperature to use.""" + model_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict) + """Holds any model parameters valid for `create` call not explicitly specified.""" + host_base_url: Optional[str] = None + + headers: Optional[Dict[str, str]] = Field(default_factory=dict) + + request_timeout: Optional[Union[float, Tuple[float, float]]] = None + """Timeout for requests to OpenAI completion API. Default is 600 seconds.""" + max_retries: Optional[int] = 6 + """Maximum number of retries to make when generating.""" + streaming: Optional[bool] = False + """Whether to stream the results or not.""" + n: Optional[int] = 1 + + """Maximum number of tokens to generate.""" + tiktoken_model_name: Optional[str] = None + """The model name to pass to tiktoken when using this class. + Tiktoken is used to count the number of tokens in documents to constrain + them to be under a certain limit. By default, when set to None, this will + be the same as the embedding model name. However, there are some cases + where you may want to use this Embedding class with a model name not + supported by tiktoken. This can include when using Azure embeddings or + when using one of the many model providers that expose an OpenAI-like + API but with different models. In those cases, in order to avoid erroring + when tiktoken is called, you can specify a model name to use here.""" + + verbose: Optional[bool] = False + + class Config: + """Configuration for this pydantic object.""" + + allow_population_by_field_name = True + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + values["host_base_url"] = get_from_dict_or_env( + values, "host_base_url", "HostBaseUrl" + ) + try: + values["client"] = requests.post + except AttributeError: + raise ValueError( + "Try upgrading it with `pip install --upgrade requests`." + ) + return values + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling ChatMinimaxAI API.""" + return { + "model": self.model_name, + "temperature": self.temperature, + "top_p": self.top_p, + "max_tokens": self.max_tokens, + "do_sample": self.do_sample, + **self.model_kwargs, + } + + def completion_with_retry(self, **kwargs: Any) -> Any: + retry_decorator = _create_retry_decorator(self) + + @retry_decorator + def _completion_with_retry(**kwargs: Any) -> Any: + messages = kwargs.get('messages') + temperature = kwargs.get('temperature') + top_p = kwargs.get('top_p') + max_tokens = kwargs.get('max_tokens') + do_sample = kwargs.get('do_sample') + params = { + 'messages': messages, + 'model': self.model_name, + 'top_p': top_p, + 'temperature': temperature, + "max_tokens": max_tokens, + 'do_sample': do_sample} + + if self.verbose: + print('payload', params) + + url = f"{self.host_base_url}/{self.model_name}/infer" + resp = self.client(url=url, json=params).json() + if resp['status_code'] != 200: + raise ValueError( + f"API returned an error: {resp['status_message']}" + ) + resp["usage"] = {} + return resp + return _completion_with_retry(**kwargs) + + def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: + overall_token_usage: dict = {} + for output in llm_outputs: + if output is None: + # Happens in streaming + continue + token_usage = output["token_usage"] + if token_usage is None: continue + for k, v in token_usage.items(): + if k in overall_token_usage: + overall_token_usage[k] += v + else: + overall_token_usage[k] = v + return {"token_usage": overall_token_usage, "model_name": self.model_name} + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + message_dicts, params = self._create_message_dicts(messages, stop) + params = {**params, **kwargs} + response = self.completion_with_retry(messages=message_dicts, **params) + return self._create_chat_result(response) + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + return self._generate(messages, stop, run_manager, kwargs) + + def _create_message_dicts( + self, messages: List[BaseMessage], stop: Optional[List[str]] + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + params = dict(self._client_params) + if stop is not None: + if "stop" in params: + raise ValueError("`stop` found in both the input and default params.") + params["stop"] = stop + + message_dicts = [_convert_message_to_dict(m) for m in messages] + + return message_dicts, params + + def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: + generations = [] + for res in response["choices"]: + message = _convert_dict_to_message(res['message']) + gen = ChatGeneration(message=message) + generations.append(gen) + + llm_output = { + "token_usage": response['usage'], + "model_name": self.model_name} + return ChatResult(generations=generations, llm_output=llm_output) + + @property + def _identifying_params(self) -> Mapping[str, Any]: + """Get the identifying parameters.""" + return {**{"model_name": self.model_name}, **self._default_params} + + @property + def _client_params(self) -> Mapping[str, Any]: + """Get the parameters used for the client.""" + minimaxai_creds: Dict[str, Any] = { + "model": self.model_name, + } + return {**minimaxai_creds, **self._default_params} + + def _get_invocation_params( + self, stop: Optional[List[str]] = None, **kwargs: Any + ) -> Dict[str, Any]: + """Get the parameters used to invoke the model FOR THE CALLBACKS.""" + return { + **super()._get_invocation_params(stop=stop, **kwargs), + **self._default_params, + "model": self.model_name, + "function": kwargs.get("functions"), + } + + @property + def _llm_type(self) -> str: + """Return type of chat model.""" + return "host_chat_llm" + + def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]: + tiktoken_ = _import_tiktoken() + if self.tiktoken_model_name is not None: + model = self.tiktoken_model_name + else: + model = self.model_name + # model chatglm-std, chatglm-lite + # Returns the number of tokens used by a list of messages. + try: + encoding = tiktoken_.encoding_for_model(model) + except KeyError: + logger.warning("Warning: model not found. Using cl100k_base encoding.") + model = "cl100k_base" + encoding = tiktoken_.get_encoding(model) + return model, encoding + + def get_token_ids(self, text: str) -> List[int]: + """Get the tokens present in the text with tiktoken package.""" + # tiktoken NOT supported for Python 3.7 or below + if sys.version_info[1] <= 7: + return super().get_token_ids(text) + _, encoding_model = self._get_encoding_model() + return encoding_model.encode(text) + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + """Calculate num tokens for chatglm with tiktoken package. + + todo: read chatglm document + Official documentation: https://github.com/openai/openai-cookbook/blob/ + main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" + if sys.version_info[1] <= 7: + return super().get_num_tokens_from_messages(messages) + model, encoding = self._get_encoding_model() + if model.startswith("chatglm"): + # every message follows {role/name}\n{content}\n + tokens_per_message = 4 + # if there's a name, the role is omitted + tokens_per_name = -1 + else: + raise NotImplementedError( + f"get_num_tokens_from_messages() is not presently implemented " + f"for model {model}." + "See https://github.com/openai/openai-python/blob/main/chatml.md for " + "information on how messages are converted to tokens." + ) + num_tokens = 0 + messages_dict = [_convert_message_to_dict(m) for m in messages] + for message in messages_dict: + num_tokens += tokens_per_message + for key, value in message.items(): + num_tokens += len(encoding.encode(value)) + if key == "name": + num_tokens += tokens_per_name + # every reply is primed with assistant + num_tokens += 3 + return num_tokens + + +class ChatGLM2Host(BaseHostChatLLM): + # chatglm2-12b, chatglm2-6b + model_name: str = Field("chatglm2-6b", alias="model") + + temperature: float = 0.95 + top_p: float = 0.7 + max_tokens: int = 4096 + + @property + def _llm_type(self) -> str: + """Return type of chat model.""" + return "chatglm2" + + +class BaichuanChat(BaseHostChatLLM): + # Baichuan-7B-Chat, Baichuan-13B-Chat + model_name: str = Field("Baichuan-13B-Chat", alias="model") + + temperature: float = 0.3 + top_p: float = 0.85 + max_tokens: int = 8192 + + @property + def _llm_type(self) -> str: + """Return type of chat model.""" + return "baichang_chat" + + +class QwenChat(BaseHostChatLLM): + # Qwen-7B-Chat + model_name: str = Field("Qwen-7B-Chat", alias="model") + + temperature: float = 0 + top_p: float = 0.5 + max_tokens: int = 8192 + + @property + def _llm_type(self) -> str: + """Return type of chat model.""" + return "qwen_chat" + +class Llama2Chat(BaseHostChatLLM): + # Llama-2-7b-chat-hf, Llama-2-13b-chat-hf, Llama-2-70b-chat-hf + model_name: str = Field("Llama-2-7b-chat-hf", alias="model") + + temperature: float = 0.9 + top_p: float = 0.6 + max_tokens: int = 8192 + + @property + def _llm_type(self) -> str: + """Return type of chat model.""" + return "llama2_chat" + diff --git a/src/langchain_contrib/langchain_contrib/chat_models/interface/__init__.py b/src/langchain_contrib/langchain_contrib/chat_models/interface/__init__.py new file mode 100644 index 0000000000..5b7be49c31 --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/chat_models/interface/__init__.py @@ -0,0 +1,6 @@ + +from .minimax import ChatCompletion as MinimaxChatCompletion +from .openai import ChatCompletion as OpenaiChatCompletion +from .wenxin import ChatCompletion as WenxinChatCompletion +from .xunfei import ChatCompletion as XunfeiChatCompletion +from .zhipuai import ChatCompletion as ZhipuaiChatCompletion diff --git a/src/langchain_contrib/langchain_contrib/chat_models/interface/minimax.py b/src/langchain_contrib/langchain_contrib/chat_models/interface/minimax.py new file mode 100644 index 0000000000..8bb44d63e1 --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/chat_models/interface/minimax.py @@ -0,0 +1,118 @@ +import requests +import json + +from .types import (ChatInput, ChatOutput, Message, Choice, Usage) +from .utils import get_ts + + +class ChatCompletion(object): + def __init__(self, group_id, api_key, **kwargs): + ep_url = "https://api.minimax.chat/v1/text/chatcompletion" + self.endpoint = f"{ep_url}?GroupId={group_id}" + self.headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } + + def parseChunkDelta(self, chunk) : + decoded_data = chunk.decode('utf-8') + parsed_data = json.loads(decoded_data[6:]) + delta_content = parsed_data['choices'][0]['delta'] + return delta_content + + def __call__(self, inp: ChatInput, verbose=False): + messages = inp.messages + model = inp.model + top_p = 0.95 if inp.top_p is None else inp.top_p + temperature = 0.9 if inp.temperature is None else inp.temperature + stream = False if inp.stream is None else inp.stream + max_tokens = 1024 if inp.max_tokens is None else inp.max_tokens + if abs(temperature) <= 1e-6: + temperature = 1e-6 + + chat_messages = messages + system_prompt = ( + 'MM智能助理是一款由MinMax自研的,没有调用其他产品接口的大型语言' + '模型。MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。\n----\n') + + if messages[0].role == 'system': + system_prompt = messages[0].content + chat_messages = messages[1:] + + new_messages = [] + for m in chat_messages: + role = 'USER' + if m.role == 'system' or m.role == 'assistant': + role = 'BOT' + + new_messages.append({'sender_type': role, 'text': m.content}) + + # role_meta is given, prompt must is not empty + system_info = {} + if system_prompt: + system_info = { + "prompt": system_prompt, + "role_meta": { + "user_name": "用户", + "bot_name": "MM智能助理" + } + } + + payload = { + "model": model, + "stream": stream, + "use_standard_sse": True, + "messages": new_messages, + "temperature": temperature, + "top_p": top_p, + "tokens_to_generate": max_tokens + } + payload.update(system_info) + + if verbose: + print('payload', payload) + + response = requests.post(self.endpoint, + headers=self.headers, json=payload) + + req_type = 'chat.completion' + status_message = 'success' + status_code = response.status_code + created = get_ts() + choices = [] + usage = None + if status_code == 200: + try: + info = json.loads(response.text) + if info['base_resp']['status_code'] == 0: + created = info['created'] + reply = info['reply'] + choices = [] + for s in info['choices']: + index = s['index'] + finish_reason = s['finish_reason'] + msg = Message(role='assistant', content=s['text']) + cho = Choice(index=index, message=msg, + finish_reason=finish_reason) + choices.append(cho) + total_tokens = info['usage']['total_tokens'] + usage = Usage(total_tokens=total_tokens) + else: + status_code = info['base_resp']['status_code'] + status_message = info['base_resp']['status_msg'] + + except Exception as e: + status_code = 401 + status_message = str(e) + else: + status_code = 400 + status_message = "requests error" + + if status_code != 200: + raise Exception(status_message) + + return ChatOutput( + status_code=status_code, + status_message=status_message, + model=model, object=req_type, created=created, + choices=choices, usage=usage) diff --git a/src/langchain_contrib/langchain_contrib/chat_models/interface/openai.py b/src/langchain_contrib/langchain_contrib/chat_models/interface/openai.py new file mode 100644 index 0000000000..c689cf752b --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/chat_models/interface/openai.py @@ -0,0 +1,67 @@ +import openai +import json + +from .types import (ChatInput, ChatOutput, Message, Choice, Usage) +from .utils import get_ts + + +class ChatCompletion(object): + def __init__(self, api_key, proxy=None, **kwargs): + openai.api_key = api_key + openai.proxy = proxy + + def __call__(self, inp: ChatInput, verbose=False): + messages = inp.messages + model = inp.model + top_p = 0.7 if inp.top_p is None else inp.top_p + temperature = 0.97 if inp.temperature is None else inp.temperature + stream = False if inp.stream is None else inp.stream + max_tokens = 1024 if inp.max_tokens is None else inp.max_tokens + stop = None + if inp.stop is not None: + stop = inp.stop.split('||') + + new_messages = [m.dict() for m in messages] + created = get_ts() + payload = { + 'model': model, + "messages": new_messages, + "temperature": temperature, + "top_p": top_p, + "stop": stop, + "max_tokens": max_tokens, + } + if inp.functions: + payload.update({'functions': inp.functions}) + + if verbose: + print('payload', payload) + + req_type = 'chat.completion' + status_message = 'success' + choices = [] + usage = None + try: + resp = openai.ChatCompletion.create(**payload) + status_code = 200 + choices = [] + for choice in resp['choices']: + cho = Choice(**choice) + choices.append(cho) + usage = Usage(**resp['usage']) + + except Exception as e: + status_code = 400 + status_message = str(e) + + if status_code != 200: + raise Exception(status_message) + + return ChatOutput( + status_code=status_code, + status_message=status_message, + model=model, + object=req_type, + created=created, + choices=choices, + usage=usage) diff --git a/src/langchain_contrib/langchain_contrib/chat_models/interface/types.py b/src/langchain_contrib/langchain_contrib/chat_models/interface/types.py new file mode 100644 index 0000000000..da923f329f --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/chat_models/interface/types.py @@ -0,0 +1,54 @@ +from typing import Union +from pydantic import BaseModel + + +class Message(BaseModel): + role: str + content: str + +class Function(BaseModel): + name: str + description: str + parameters: dict + +class ChatInput(BaseModel): + model: str + messages: list[Message] = [] + top_p: float = None + temperature: float = None + n: int = 1 + stream: bool = False + stop: str = None + max_tokens: int = 256 + functions: list[Function] = [] + function_call: str = None + +class Choice(BaseModel): + index: int + message: Message = None + finish_reason: str = "stop" + +class Usage(BaseModel): + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int + +class ChatOutput(BaseModel): + status_code: int + status_message: str = 'success' + id: str = None + object: str = None + model:str = None + created: int = None + choices: list[Choice] = [] + usage: Usage = None + +class CompletionsInput(BaseModel): + model: str + prompt: str + top_p: float = None + temperature: float = None + n: int = 1 + stream: bool = True + stop: str = None + max_tokens: int = 256 diff --git a/src/langchain_contrib/langchain_contrib/chat_models/interface/utils.py b/src/langchain_contrib/langchain_contrib/chat_models/interface/utils.py new file mode 100644 index 0000000000..b12635ca52 --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/chat_models/interface/utils.py @@ -0,0 +1,4 @@ +import time + +def get_ts(): + return round(time.time() * 1000) diff --git a/src/langchain_contrib/langchain_contrib/chat_models/interface/wenxin.py b/src/langchain_contrib/langchain_contrib/chat_models/interface/wenxin.py new file mode 100644 index 0000000000..946a0f3e86 --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/chat_models/interface/wenxin.py @@ -0,0 +1,109 @@ +import requests +import json + +from .types import (ChatInput, ChatOutput, Message, Choice, Usage) +from .utils import get_ts + + +def get_access_token(api_key, sec_key): + url = (f"https://aip.baidubce.com/oauth/2.0/token?" + f"grant_type=client_credentials" + f"&client_id={api_key}&client_secret={sec_key}") + + payload = json.dumps("") + headers = { + 'Content-Type': 'application/json', + 'Accept': 'application/json' + } + + response = requests.request("POST", url, headers=headers, data=payload) + return response.json().get("access_token") + + +class ChatCompletion(object): + def __init__(self, api_key, sec_key, **kwargs): + self.api_key = api_key + self.sec_key = sec_key + self.ep_url = ("https://aip.baidubce.com/rpc/2.0/ai_custom/v1/" + "wenxinworkshop/chat/completions") + self.ep_url_turbo = ("https://aip.baidubce.com/rpc/2.0/ai_custom/v1/" + "wenxinworkshop/chat/eb-instant") + + # token = get_access_token(api_key, sec_key) + # self.endpoint = f"{self.ep_url}?access_token={token}" + self.headers = { + 'Content-Type': 'application/json' + } + + def __call__(self, inp: ChatInput, verbose=False): + messages = inp.messages + model = inp.model + top_p = 0.8 if inp.top_p is None else inp.top_p + temperature = 0.95 if inp.temperature is None else inp.temperature + stream = False if inp.stream is None else inp.stream + max_tokens = 1024 if inp.max_tokens is None else inp.max_tokens + + system_content = '' + new_messages = [] + for m in messages: + role = m.role + if role == 'system': + system_content = m.content + continue + new_messages.append({'role': role, 'content': m.content}) + + if system_content: + new_messages[-1]['content'] = system_content + '\n' + new_messages[-1]['content'] + + payload = { + "stream": stream, + "messages": new_messages, + "temperature": temperature, + "top_p": top_p + } + + if verbose: + print('payload', payload) + + token = get_access_token(self.api_key, self.sec_key) + endpoint = f"{self.ep_url}?access_token={token}" + if model == "ernie-bot-turbo": + endpoint = f"{self.ep_url_turbo}?access_token={token}" + + response = requests.post(endpoint, + headers=self.headers, json=payload) + + req_type = 'chat.completion' + status_message = 'success' + status_code = response.status_code + created = get_ts() + choices = [] + usage = None + if status_code == 200: + try: + info = json.loads(response.text) + status_code = info.get('error_code', 200) + status_message = info.get('error_msg', status_message) + if status_code == 200: + created = info['created'] + result = info['result'] + finish_reason = 'default' + msg = Message(role='assistant', content=result) + choices = [Choice(index=0, message=msg, + finish_reason=finish_reason)] + usage = Usage(**info['usage']) + except Exception as e: + status_code = 401 + status_message = str(e) + else: + status_code = 400 + status_message = "requests error" + + if status_code != 200: + raise Exception(status_message) + + return ChatOutput( + status_code=status_code, + status_message=status_message, + model=model, object=req_type, created=created, + choices=choices, usage=usage) diff --git a/src/langchain_contrib/langchain_contrib/chat_models/interface/xunfei.py b/src/langchain_contrib/langchain_contrib/chat_models/interface/xunfei.py new file mode 100644 index 0000000000..cfe1ec7aa3 --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/chat_models/interface/xunfei.py @@ -0,0 +1,227 @@ +import _thread as thread +import threading +import base64 +import datetime +import hashlib +import hmac +import json +from urllib.parse import urlparse +import ssl +from datetime import datetime +from time import mktime +from urllib.parse import urlencode +from wsgiref.handlers import format_date_time + +import websocket +from websocket import create_connection + +from .types import (ChatInput, ChatOutput, Message, Choice, Usage) +from .utils import get_ts + + +class Ws_Param(object): + # 初始化 + def __init__(self, APPID, APIKey, APISecret, gpt_url): + self.APPID = APPID + self.APIKey = APIKey + self.APISecret = APISecret + self.host = urlparse(gpt_url).netloc + self.path = urlparse(gpt_url).path + self.gpt_url = gpt_url + + # 生成url + def create_url(self): + # 生成RFC1123格式的时间戳 + now = datetime.now() + date = format_date_time(mktime(now.timetuple())) + + # 拼接字符串 + signature_origin = "host: " + self.host + "\n" + signature_origin += "date: " + date + "\n" + signature_origin += "GET " + self.path + " HTTP/1.1" + + # 进行hmac-sha256进行加密 + signature_sha = hmac.new( + self.APISecret.encode('utf-8'), + signature_origin.encode('utf-8'), + digestmod=hashlib.sha256).digest() + + signature_sha_base64 = base64.b64encode( + signature_sha).decode(encoding='utf-8') + + authorization_origin = ( + f'api_key="{self.APIKey}", ' + f'algorithm="hmac-sha256", headers="host date request-line",' + f' signature="{signature_sha_base64}"') + + authorization = base64.b64encode( + authorization_origin.encode('utf-8')).decode(encoding='utf-8') + + # 将请求的鉴权参数组合为字典 + v = { + "authorization": authorization, + "date": date, + "host": self.host + } + # 拼接鉴权参数,生成url + url = self.gpt_url + '?' + urlencode(v) + # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释, + # 比对相同参数时生成的url与自己代码生成的url是否一致 + return url + + +# 收到websocket错误的处理 +def on_error(ws, error): + print("### error:", error) + + +# 收到websocket关闭的处理 +def on_close(ws): + print("### closed ###") + + +# 收到websocket连接建立的处理 +def on_open(ws): + thread.start_new_thread(run, (ws,)) + + +def run(ws, *args): + data = json.dumps(gen_params(appid=ws.appid, question=ws.question)) + ws.send(data) + + +# 收到websocket消息的处理 +def on_message(ws, message): + print(message) + data = json.loads(message) + code = data['header']['code'] + if code != 0: + print(f'请求错误: {code}, {data}') + ws.close() + else: + choices = data["payload"]["choices"] + status = choices["status"] + content = choices["text"][0]["content"] + print(content, end='') + if status == 2: + ws.close() + + +def gen_params(appid, question): + data = { + "header": { + "app_id": appid, + "uid": "1234" + }, + "parameter": { + "chat": { + "domain": "general", + "random_threshold": 0.5, + "max_tokens": 2048, + "auditing": "default" + } + }, + "payload": { + "message": { + "text": [ + {"role": "user", "content": question} + ] + } + } + } + return data + + +class ChatCompletion(object): + def __init__(self, appid, api_key, api_secret, **kwargs): + gpt_url = "ws://spark-api.xf-yun.com/v1.1/chat" + self.wsParam = Ws_Param(appid, api_key, api_secret, gpt_url) + websocket.enableTrace(False) + # wsUrl = wsParam.create_url() + + # todo: modify to the ws pool + # self.mutex = threading.Lock() + # self.ws = websocket.WebSocket() + # self.ws.connect(wsUrl) + + self.header = {'app_id': appid, "uid": "elem"} + + def __call__(self, inp: ChatInput, verbose=False): + messages = inp.messages + model = inp.model + top_p = 0.7 if inp.top_p is None else inp.top_p + temperature = 0.5 if inp.temperature is None else inp.temperature + stream = False if inp.stream is None else inp.stream + max_tokens = 1024 if inp.max_tokens is None else inp.max_tokens + stop = None + if inp.stop is not None: + stop = inp.stop.split('||') + + new_messages = [] + for m in messages: + role = m.role + if role == 'system': + role = 'user' + new_messages.append({'role': role, 'content': m.content}) + + created = get_ts() + payload = { + 'header': self.header, + "payload": {"message": {"text": new_messages}}, + "parameter": { + "chat": { + "domain": "general", + "temperature": temperature, + "max_tokens": max_tokens, + "auditing": "default" + } + } + } + + if verbose: + print('payload', payload) + + req_type = 'chat.completion' + status_code = 200 + status_message = 'success' + choices = [] + usage = None + texts = [] + ws = None + try: + # self.mutex.acquire() + wsUrl = self.wsParam.create_url() + ws = create_connection(wsUrl) + ws.send(json.dumps(payload)) + while True: + raw_data = ws.recv() + if not raw_data: break + resp = json.loads(raw_data) + if resp['header']['code'] == 0: + texts.append( + resp['payload']['choices']['text'][0]['content']) + if resp['header']['code'] == 0 and resp['header']['status'] == 2: + usage_dict = resp['payload']['usage']['text'] + usage_dict.pop('question_tokens') + usage = Usage(**usage_dict) + except Exception as e: + status_code = 401 + status_message = str(e) + finally: + if ws: ws.close() + + if texts: + finish_reason = 'default' + msg = Message(role='assistant', content=''.join(texts)) + cho = Choice(index=0, message=msg, + finish_reason=finish_reason) + choices.append(cho) + + if status_code != 200: + raise Exception(status_message) + + return ChatOutput( + status_code=status_code, + status_message=status_message, + model=model, object=req_type, created=created, + choices=choices, usage=usage) diff --git a/src/langchain_contrib/langchain_contrib/chat_models/interface/zhipuai.py b/src/langchain_contrib/langchain_contrib/chat_models/interface/zhipuai.py new file mode 100644 index 0000000000..ee856b96fd --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/chat_models/interface/zhipuai.py @@ -0,0 +1,77 @@ +import zhipuai +import json + +from .types import (ChatInput, ChatOutput, Message, Choice, Usage) +from .utils import get_ts + + +class ChatCompletion(object): + def __init__(self, api_key, **kwargs): + zhipuai.api_key = api_key + + def __call__(self, inp: ChatInput, verbose=False): + messages = inp.messages + model = inp.model + top_p = 0.7 if inp.top_p is None else inp.top_p + temperature = 0.95 if inp.temperature is None else inp.temperature + stream = False if inp.stream is None else inp.stream + max_tokens = 1024 if inp.max_tokens is None else inp.max_tokens + + + new_messages = [] + system_content= '' + for m in messages: + content = m.content + role = m.role + if role == 'system': + system_content += content + continue + new_messages.append({'role': role, 'content': content}) + + if system_content: + new_messages[-1]['content'] = ( + system_content + new_messages[-1]['content']) + + created = get_ts() + payload = { + 'model': model, + "prompt": new_messages, + "temperature": temperature, + "top_p": top_p, + "request_id": str(created), + "incremental": False + } + + if verbose: + print('payload', payload) + + req_type = 'chat.completion' + status_message = 'success' + choices = [] + usage = None + try: + resp = zhipuai.model_api.invoke(**payload) + status_code = resp['code'] + status_message = resp['msg'] + if status_code == 200: + choices = [] + for index, choice in enumerate(resp['data']['choices']): + finish_reason = 'default' + msg = Message(**choice) + cho = Choice(index=index, message=msg, + finish_reason=finish_reason) + choices.append(cho) + usage = Usage(**resp['data']['usage']) + + except Exception as e: + status_code = 400 + status_message = str(e) + + if status_code != 200: + raise Exception(status_message) + + return ChatOutput( + status_code=status_code, + status_message=status_message, + model=model, object=req_type, created=created, + choices=choices, usage=usage) diff --git a/src/langchain_contrib/langchain_contrib/chat_models/minimax.py b/src/langchain_contrib/langchain_contrib/chat_models/minimax.py new file mode 100644 index 0000000000..3985332898 --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/chat_models/minimax.py @@ -0,0 +1,380 @@ +"""proxy llm chat wrapper.""" +from __future__ import annotations + +import requests +import logging +import sys +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Mapping, + Optional, + Tuple, + Union, +) + +from pydantic import Field, root_validator +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain.chat_models.base import BaseChatModel +from langchain.schema import ( + ChatGeneration, + ChatResult, +) +from langchain.schema.messages import ( + AIMessage, + BaseMessage, + ChatMessage, + FunctionMessage, + HumanMessage, + SystemMessage, +) +from langchain.utils import get_from_dict_or_env + +if TYPE_CHECKING: + import tiktoken + +from .interface import MinimaxChatCompletion +from .interface.types import ChatInput + +logger = logging.getLogger(__name__) + + +def _import_tiktoken() -> Any: + try: + import tiktoken + except ImportError: + raise ValueError( + "Could not import tiktoken python package. " + "This is needed in order to calculate get_token_ids. " + "Please install it with `pip install tiktoken`." + ) + return tiktoken + + +def _create_retry_decorator(llm: ChatMinimaxAI) -> Callable[[Any], Any]: + + min_seconds = 1 + max_seconds = 20 + # Wait 2^x * 1 second between each retry starting with + # 4 seconds, then up to 10 seconds, then 10 seconds afterwards + return retry( + reraise=True, + stop=stop_after_attempt(llm.max_retries), + wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), + retry=( + retry_if_exception_type(Exception) + ), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + + +def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: + role = _dict["role"] + if role == "user": + return HumanMessage(content=_dict["content"]) + elif role == "assistant": + content = _dict["content"] or "" # OpenAI returns None for tool invocations + if _dict.get("function_call"): + additional_kwargs = {"function_call": dict(_dict["function_call"])} + else: + additional_kwargs = {} + return AIMessage(content=content, additional_kwargs=additional_kwargs) + elif role == "system": + return SystemMessage(content=_dict["content"]) + elif role == "function": + return FunctionMessage(content=_dict["content"], name=_dict["name"]) + else: + return ChatMessage(content=_dict["content"], role=role) + + +def _convert_message_to_dict(message: BaseMessage) -> dict: + if isinstance(message, ChatMessage): + message_dict = {"role": message.role, "content": message.content} + elif isinstance(message, HumanMessage): + message_dict = {"role": "user", "content": message.content} + elif isinstance(message, AIMessage): + message_dict = {"role": "assistant", "content": message.content} + if "function_call" in message.additional_kwargs: + message_dict["function_call"] = message.additional_kwargs["function_call"] + elif isinstance(message, SystemMessage): + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, FunctionMessage): + message_dict = { + "role": "function", + "content": message.content, + "name": message.name, + } + else: + raise ValueError(f"Got unknown type {message}") + if "name" in message.additional_kwargs: + message_dict["name"] = message.additional_kwargs["name"] + return message_dict + + +class ChatMinimaxAI(BaseChatModel): + """Wrapper around proxy Chat large language models. + + To use, the environment variable ``ELEMAI_API_KEY`` set with your API key. + + Example: + .. code-block:: python + + from langchain_contrib.chat_models import ChatMinimaxAI + chat_miniamaxai = ChatMinimaxAI(model_name="abab5.5-chat") + """ + + client: Optional[Any] #: :meta private: + + """Model name to use.""" + model_name: str = Field("abab5.5-chat", alias="model") + + temperature: float = 0.9 + top_p: float = 0.95 + """What sampling temperature to use.""" + model_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict) + """Holds any model parameters valid for `create` call not explicitly specified.""" + minimaxai_api_key: Optional[str] = None + minimaxai_group_id: Optional[str] = None + + headers: Optional[Dict[str, str]] = Field(default_factory=dict) + + request_timeout: Optional[Union[float, Tuple[float, float]]] = None + """Timeout for requests to OpenAI completion API. Default is 600 seconds.""" + max_retries: Optional[int] = 6 + """Maximum number of retries to make when generating.""" + streaming: Optional[bool] = False + """Whether to stream the results or not.""" + n: Optional[int] = 1 + """Number of chat completions to generate for each prompt.""" + max_tokens: Optional[int] = None + """Maximum number of tokens to generate.""" + tiktoken_model_name: Optional[str] = None + """The model name to pass to tiktoken when using this class. + Tiktoken is used to count the number of tokens in documents to constrain + them to be under a certain limit. By default, when set to None, this will + be the same as the embedding model name. However, there are some cases + where you may want to use this Embedding class with a model name not + supported by tiktoken. This can include when using Azure embeddings or + when using one of the many model providers that expose an OpenAI-like + API but with different models. In those cases, in order to avoid erroring + when tiktoken is called, you can specify a model name to use here.""" + + verbose: Optional[bool] = False + + class Config: + """Configuration for this pydantic object.""" + + allow_population_by_field_name = True + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + values["minimaxai_api_key"] = get_from_dict_or_env( + values, "minimaxai_api_key", "MINIMAXAI_API_KEY" + ) + + values["minimaxai_group_id"] = get_from_dict_or_env( + values, "minimaxai_group_id", "MINIMAXAI_GROUP_ID" + ) + + api_key = values["minimaxai_api_key"] + group_id = values["minimaxai_group_id"] + try: + values["client"] = MinimaxChatCompletion(group_id, api_key) + except AttributeError: + raise ValueError( + "Try upgrading it with `pip install --upgrade requests`." + ) + return values + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling ChatMinimaxAI API.""" + return { + "model": self.model_name, + "temperature": self.temperature, + "top_p": self.top_p, + "max_tokens": self.max_tokens, + **self.model_kwargs, + } + + def completion_with_retry(self, **kwargs: Any) -> Any: + retry_decorator = _create_retry_decorator(self) + + @retry_decorator + def _completion_with_retry(**kwargs: Any) -> Any: + messages = kwargs.get('messages') + temperature = kwargs.get('temperature') + top_p = kwargs.get('top_p') + max_tokens = kwargs.get('max_tokens') + params = { + 'messages': messages, + 'model': self.model_name, + 'top_p': top_p, + 'temperature': temperature, + "max_tokens": max_tokens} + return self.client(ChatInput.parse_obj(params), self.verbose).dict() + + return _completion_with_retry(**kwargs) + + + def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: + overall_token_usage: dict = {} + for output in llm_outputs: + if output is None: + # Happens in streaming + continue + token_usage = output["token_usage"] + if token_usage is None: continue + for k, v in token_usage.items(): + if k in overall_token_usage: + overall_token_usage[k] += v + else: + overall_token_usage[k] = v + return {"token_usage": overall_token_usage, "model_name": self.model_name} + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + message_dicts, params = self._create_message_dicts(messages, stop) + params = {**params, **kwargs} + response = self.completion_with_retry(messages=message_dicts, **params) + return self._create_chat_result(response) + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + return self._generate(messages, stop, run_manager, kwargs) + + def _create_message_dicts( + self, messages: List[BaseMessage], stop: Optional[List[str]] + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + params = dict(self._client_params) + if stop is not None: + if "stop" in params: + raise ValueError("`stop` found in both the input and default params.") + params["stop"] = stop + + message_dicts = [_convert_message_to_dict(m) for m in messages] + + return message_dicts, params + + def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: + generations = [] + for res in response["choices"]: + message = _convert_dict_to_message(res['message']) + gen = ChatGeneration(message=message) + generations.append(gen) + + llm_output = { + "token_usage": response['usage'], + "model_name": self.model_name} + return ChatResult(generations=generations, llm_output=llm_output) + + @property + def _identifying_params(self) -> Mapping[str, Any]: + """Get the identifying parameters.""" + return {**{"model_name": self.model_name}, **self._default_params} + + @property + def _client_params(self) -> Mapping[str, Any]: + """Get the parameters used for the client.""" + minimaxai_creds: Dict[str, Any] = { + "model": self.model_name, + } + return {**minimaxai_creds, **self._default_params} + + def _get_invocation_params( + self, stop: Optional[List[str]] = None, **kwargs: Any + ) -> Dict[str, Any]: + """Get the parameters used to invoke the model FOR THE CALLBACKS.""" + return { + **super()._get_invocation_params(stop=stop, **kwargs), + **self._default_params, + "model": self.model_name, + "function": kwargs.get("functions"), + } + + @property + def _llm_type(self) -> str: + """Return type of chat model.""" + return "minimaxai_chat" + + def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]: + tiktoken_ = _import_tiktoken() + if self.tiktoken_model_name is not None: + model = self.tiktoken_model_name + else: + model = self.model_name + # model chatglm-std, chatglm-lite + # Returns the number of tokens used by a list of messages. + try: + encoding = tiktoken_.encoding_for_model(model) + except KeyError: + logger.warning("Warning: model not found. Using cl100k_base encoding.") + model = "cl100k_base" + encoding = tiktoken_.get_encoding(model) + return model, encoding + + def get_token_ids(self, text: str) -> List[int]: + """Get the tokens present in the text with tiktoken package.""" + # tiktoken NOT supported for Python 3.7 or below + if sys.version_info[1] <= 7: + return super().get_token_ids(text) + _, encoding_model = self._get_encoding_model() + return encoding_model.encode(text) + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + """Calculate num tokens for chatglm with tiktoken package. + + todo: read chatglm document + Official documentation: https://github.com/openai/openai-cookbook/blob/ + main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" + if sys.version_info[1] <= 7: + return super().get_num_tokens_from_messages(messages) + model, encoding = self._get_encoding_model() + if model.startswith("chatglm"): + # every message follows {role/name}\n{content}\n + tokens_per_message = 4 + # if there's a name, the role is omitted + tokens_per_name = -1 + else: + raise NotImplementedError( + f"get_num_tokens_from_messages() is not presently implemented " + f"for model {model}." + "See https://github.com/openai/openai-python/blob/main/chatml.md for " + "information on how messages are converted to tokens." + ) + num_tokens = 0 + messages_dict = [_convert_message_to_dict(m) for m in messages] + for message in messages_dict: + num_tokens += tokens_per_message + for key, value in message.items(): + num_tokens += len(encoding.encode(value)) + if key == "name": + num_tokens += tokens_per_name + # every reply is primed with assistant + num_tokens += 3 + return num_tokens diff --git a/src/langchain_contrib/langchain_contrib/chat_models/proxy_llm.py b/src/langchain_contrib/langchain_contrib/chat_models/proxy_llm.py new file mode 100644 index 0000000000..22a78465c0 --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/chat_models/proxy_llm.py @@ -0,0 +1,381 @@ +"""proxy llm chat wrapper.""" +from __future__ import annotations + +import requests +import logging +import sys +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Mapping, + Optional, + Tuple, + Union, +) + +from pydantic import Field, root_validator +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain.chat_models.base import BaseChatModel +from langchain.schema import ( + ChatGeneration, + ChatResult, +) +from langchain.schema.messages import ( + AIMessage, + BaseMessage, + ChatMessage, + FunctionMessage, + HumanMessage, + SystemMessage, +) +from langchain.utils import get_from_dict_or_env + +if TYPE_CHECKING: + import tiktoken + +logger = logging.getLogger(__name__) + + +def _import_tiktoken() -> Any: + try: + import tiktoken + except ImportError: + raise ValueError( + "Could not import tiktoken python package. " + "This is needed in order to calculate get_token_ids. " + "Please install it with `pip install tiktoken`." + ) + return tiktoken + + +def _create_retry_decorator(llm: ProxyChatLLM) -> Callable[[Any], Any]: + + min_seconds = 1 + max_seconds = 20 + # Wait 2^x * 1 second between each retry starting with + # 4 seconds, then up to 10 seconds, then 10 seconds afterwards + return retry( + reraise=True, + stop=stop_after_attempt(llm.max_retries), + wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), + retry=( + retry_if_exception_type(Exception) + ), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + + +def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: + role = _dict["role"] + if role == "user": + return HumanMessage(content=_dict["content"]) + elif role == "assistant": + content = _dict["content"] or "" # OpenAI returns None for tool invocations + if _dict.get("function_call"): + additional_kwargs = {"function_call": dict(_dict["function_call"])} + else: + additional_kwargs = {} + return AIMessage(content=content, additional_kwargs=additional_kwargs) + elif role == "system": + return SystemMessage(content=_dict["content"]) + elif role == "function": + return FunctionMessage(content=_dict["content"], name=_dict["name"]) + else: + return ChatMessage(content=_dict["content"], role=role) + + +def _convert_message_to_dict(message: BaseMessage) -> dict: + if isinstance(message, ChatMessage): + message_dict = {"role": message.role, "content": message.content} + elif isinstance(message, HumanMessage): + message_dict = {"role": "user", "content": message.content} + elif isinstance(message, AIMessage): + message_dict = {"role": "assistant", "content": message.content} + if "function_call" in message.additional_kwargs: + message_dict["function_call"] = message.additional_kwargs["function_call"] + elif isinstance(message, SystemMessage): + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, FunctionMessage): + message_dict = { + "role": "function", + "content": message.content, + "name": message.name, + } + else: + raise ValueError(f"Got unknown type {message}") + if "name" in message.additional_kwargs: + message_dict["name"] = message.additional_kwargs["name"] + return message_dict + + +class ProxyChatLLM(BaseChatModel): + """Wrapper around proxy Chat large language models. + + To use, the environment variable ``ELEMAI_API_KEY`` set with your API key. + + Example: + .. code-block:: python + + from langchain_contrib.chat_models import ProxyChatLLM + proxy_chat_llm = ProxyChatLLM(model_name="chatglm_std") + """ + + client: Optional[Any] #: :meta private: + + """Model name to use.""" + model_name: str = Field("chatglm_std", alias="model") + + temperature: float = 0.7 + top_p: float = 0.9 + """What sampling temperature to use.""" + model_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict) + """Holds any model parameters valid for `create` call not explicitly specified.""" + elemai_api_key: Optional[str] = None + elemai_base_url: Optional[str] = None + + headers: Optional[Dict[str, str]] = Field(default_factory=dict) + + request_timeout: Optional[Union[float, Tuple[float, float]]] = None + """Timeout for requests to OpenAI completion API. Default is 600 seconds.""" + max_retries: Optional[int] = 6 + """Maximum number of retries to make when generating.""" + streaming: Optional[bool] = False + """Whether to stream the results or not.""" + n: Optional[int] = 1 + """Number of chat completions to generate for each prompt.""" + max_tokens: Optional[int] = None + """Maximum number of tokens to generate.""" + tiktoken_model_name: Optional[str] = None + """The model name to pass to tiktoken when using this class. + Tiktoken is used to count the number of tokens in documents to constrain + them to be under a certain limit. By default, when set to None, this will + be the same as the embedding model name. However, there are some cases + where you may want to use this Embedding class with a model name not + supported by tiktoken. This can include when using Azure embeddings or + when using one of the many model providers that expose an OpenAI-like + API but with different models. In those cases, in order to avoid erroring + when tiktoken is called, you can specify a model name to use here.""" + + class Config: + """Configuration for this pydantic object.""" + + allow_population_by_field_name = True + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + values["elemai_api_key"] = get_from_dict_or_env( + values, "elemai_api_key", "ELEMAI_API_KEY" + ) + + values["elemai_base_url"] = get_from_dict_or_env( + values, "elemai_base_url", "ELEMAI_BASE_URL" + ) + + elemai_api_key = values["elemai_api_key"] + values["headers"] = { + "Authorization": f"Bearer {elemai_api_key}", + "Content-Type": "application/json" + } + + try: + values["client"] = requests.post + except AttributeError: + raise ValueError( + "Try upgrading it with `pip install --upgrade requests`." + ) + return values + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling ProxyChatLLM API.""" + return { + "model": self.model_name, + "temperature": self.temperature, + "top_p": self.top_p, + "max_tokens": self.max_tokens, + **self.model_kwargs, + } + + def completion_with_retry(self, **kwargs: Any) -> Any: + retry_decorator = _create_retry_decorator(self) + + @retry_decorator + def _completion_with_retry(**kwargs: Any) -> Any: + messages = kwargs.get('messages') + temperature = kwargs.get('temperature') + top_p = kwargs.get('top_p') + max_tokens = kwargs.get('max_tokens') + params = { + 'messages': messages, + 'model': self.model_name, + 'top_p': top_p, + 'temperature': temperature, + "max_tokens": max_tokens} + return self.client( + self.elemai_base_url, headers=self.headers, json=params).json() + + return _completion_with_retry(**kwargs) + + def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: + overall_token_usage: dict = {} + for output in llm_outputs: + if output is None: + # Happens in streaming + continue + token_usage = output["token_usage"] + for k, v in token_usage.items(): + if k in overall_token_usage: + overall_token_usage[k] += v + else: + overall_token_usage[k] = v + return {"token_usage": overall_token_usage, "model_name": self.model_name} + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + message_dicts, params = self._create_message_dicts(messages, stop) + params = {**params, **kwargs} + + response = self.completion_with_retry(messages=message_dicts, **params) + return self._create_chat_result(response) + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + return self._generate(messages, stop, run_manager, kwargs) + + def _create_message_dicts( + self, messages: List[BaseMessage], stop: Optional[List[str]] + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + params = dict(self._client_params) + if stop is not None: + if "stop" in params: + raise ValueError("`stop` found in both the input and default params.") + params["stop"] = stop + + message_dicts = [_convert_message_to_dict(m) for m in messages] + + return message_dicts, params + + def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: + generations = [] + for res in response["choices"]: + message = _convert_dict_to_message(res['message']) + gen = ChatGeneration(message=message) + generations.append(gen) + + llm_output = { + "token_usage": response['usage'], + "model_name": self.model_name} + return ChatResult(generations=generations, llm_output=llm_output) + + @property + def _identifying_params(self) -> Mapping[str, Any]: + """Get the identifying parameters.""" + return {**{"model_name": self.model_name}, **self._default_params} + + @property + def _client_params(self) -> Mapping[str, Any]: + """Get the parameters used for the elemai client.""" + elemai_creds: Dict[str, Any] = { + "api_key": self.elemai_api_key, + "base_url": self.elemai_base_url, + "model": self.model_name, + } + return {**elemai_creds, **self._default_params} + + def _get_invocation_params( + self, stop: Optional[List[str]] = None, **kwargs: Any + ) -> Dict[str, Any]: + """Get the parameters used to invoke the model FOR THE CALLBACKS.""" + return { + **super()._get_invocation_params(stop=stop, **kwargs), + **self._default_params, + "model": self.model_name, + "function": kwargs.get("functions"), + } + + @property + def _llm_type(self) -> str: + """Return type of chat model.""" + return "proxy-chat" + + def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]: + tiktoken_ = _import_tiktoken() + if self.tiktoken_model_name is not None: + model = self.tiktoken_model_name + else: + model = self.model_name + # model chatglm-std, chatglm-lite + # Returns the number of tokens used by a list of messages. + try: + encoding = tiktoken_.encoding_for_model(model) + except KeyError: + logger.warning("Warning: model not found. Using cl100k_base encoding.") + model = "cl100k_base" + encoding = tiktoken_.get_encoding(model) + return model, encoding + + def get_token_ids(self, text: str) -> List[int]: + """Get the tokens present in the text with tiktoken package.""" + # tiktoken NOT supported for Python 3.7 or below + if sys.version_info[1] <= 7: + return super().get_token_ids(text) + _, encoding_model = self._get_encoding_model() + return encoding_model.encode(text) + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + """Calculate num tokens for chatglm with tiktoken package. + + todo: read chatglm document + Official documentation: https://github.com/openai/openai-cookbook/blob/ + main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" + if sys.version_info[1] <= 7: + return super().get_num_tokens_from_messages(messages) + model, encoding = self._get_encoding_model() + if model.startswith("chatglm"): + # every message follows {role/name}\n{content}\n + tokens_per_message = 4 + # if there's a name, the role is omitted + tokens_per_name = -1 + else: + raise NotImplementedError( + f"get_num_tokens_from_messages() is not presently implemented " + f"for model {model}." + "See https://github.com/openai/openai-python/blob/main/chatml.md for " + "information on how messages are converted to tokens." + ) + num_tokens = 0 + messages_dict = [_convert_message_to_dict(m) for m in messages] + for message in messages_dict: + num_tokens += tokens_per_message + for key, value in message.items(): + num_tokens += len(encoding.encode(value)) + if key == "name": + num_tokens += tokens_per_name + # every reply is primed with assistant + num_tokens += 3 + return num_tokens diff --git a/src/langchain_contrib/langchain_contrib/chat_models/wenxin.py b/src/langchain_contrib/langchain_contrib/chat_models/wenxin.py new file mode 100644 index 0000000000..fac1e15644 --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/chat_models/wenxin.py @@ -0,0 +1,377 @@ +"""proxy llm chat wrapper.""" +from __future__ import annotations + +import requests +import logging +import sys +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Mapping, + Optional, + Tuple, + Union, +) + +from pydantic import Field, root_validator +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain.chat_models.base import BaseChatModel +from langchain.schema import ( + ChatGeneration, + ChatResult, +) +from langchain.schema.messages import ( + AIMessage, + BaseMessage, + ChatMessage, + FunctionMessage, + HumanMessage, + SystemMessage, +) +from langchain.utils import get_from_dict_or_env + +if TYPE_CHECKING: + import tiktoken + +from .interface import WenxinChatCompletion +from .interface.types import ChatInput + +logger = logging.getLogger(__name__) + + +def _import_tiktoken() -> Any: + try: + import tiktoken + except ImportError: + raise ValueError( + "Could not import tiktoken python package. " + "This is needed in order to calculate get_token_ids. " + "Please install it with `pip install tiktoken`." + ) + return tiktoken + + +def _create_retry_decorator(llm: ChatWenxin) -> Callable[[Any], Any]: + + min_seconds = 1 + max_seconds = 20 + # Wait 2^x * 1 second between each retry starting with + # 4 seconds, then up to 10 seconds, then 10 seconds afterwards + return retry( + reraise=True, + stop=stop_after_attempt(llm.max_retries), + wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), + retry=( + retry_if_exception_type(Exception) + ), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + + +def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: + role = _dict["role"] + if role == "user": + return HumanMessage(content=_dict["content"]) + elif role == "assistant": + content = _dict["content"] or "" # OpenAI returns None for tool invocations + if _dict.get("function_call"): + additional_kwargs = {"function_call": dict(_dict["function_call"])} + else: + additional_kwargs = {} + return AIMessage(content=content, additional_kwargs=additional_kwargs) + elif role == "system": + return SystemMessage(content=_dict["content"]) + elif role == "function": + return FunctionMessage(content=_dict["content"], name=_dict["name"]) + else: + return ChatMessage(content=_dict["content"], role=role) + + +def _convert_message_to_dict(message: BaseMessage) -> dict: + if isinstance(message, ChatMessage): + message_dict = {"role": message.role, "content": message.content} + elif isinstance(message, HumanMessage): + message_dict = {"role": "user", "content": message.content} + elif isinstance(message, AIMessage): + message_dict = {"role": "assistant", "content": message.content} + if "function_call" in message.additional_kwargs: + message_dict["function_call"] = message.additional_kwargs["function_call"] + elif isinstance(message, SystemMessage): + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, FunctionMessage): + message_dict = { + "role": "function", + "content": message.content, + "name": message.name, + } + else: + raise ValueError(f"Got unknown type {message}") + if "name" in message.additional_kwargs: + message_dict["name"] = message.additional_kwargs["name"] + return message_dict + + +class ChatWenxin(BaseChatModel): + """Wrapper around proxy Chat large language models. + + To use, the environment variable ``ELEMAI_API_KEY`` set with your API key. + + Example: + .. code-block:: python + + from langchain_contrib.chat_models import ChatWenxin + chat_miniamaxai = ChatWenxin(model_name="ernie-bot") + """ + + client: Optional[Any] #: :meta private: + + """Model name to use.""" + model_name: str = Field("ernie-bot", alias="model") + + temperature: float = 0.95 + top_p: float = 0.8 + """What sampling temperature to use.""" + model_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict) + """Holds any model parameters valid for `create` call not explicitly specified.""" + wenxin_api_key: Optional[str] = None + wenxin_secret_key: Optional[str] = None + + request_timeout: Optional[Union[float, Tuple[float, float]]] = None + """Timeout for requests to OpenAI completion API. Default is 600 seconds.""" + max_retries: Optional[int] = 6 + """Maximum number of retries to make when generating.""" + streaming: Optional[bool] = False + """Whether to stream the results or not.""" + n: Optional[int] = 1 + """Number of chat completions to generate for each prompt.""" + max_tokens: Optional[int] = None + """Maximum number of tokens to generate.""" + tiktoken_model_name: Optional[str] = None + """The model name to pass to tiktoken when using this class. + Tiktoken is used to count the number of tokens in documents to constrain + them to be under a certain limit. By default, when set to None, this will + be the same as the embedding model name. However, there are some cases + where you may want to use this Embedding class with a model name not + supported by tiktoken. This can include when using Azure embeddings or + when using one of the many model providers that expose an OpenAI-like + API but with different models. In those cases, in order to avoid erroring + when tiktoken is called, you can specify a model name to use here.""" + verbose: Optional[bool] = False + + class Config: + """Configuration for this pydantic object.""" + + allow_population_by_field_name = True + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + values["wenxin_api_key"] = get_from_dict_or_env( + values, "wenxin_api_key", "WENXIN_API_KEY" + ) + + values["wenxin_secret_key"] = get_from_dict_or_env( + values, "wenxin_secret_key", "WENXIN_SECRET_KEY" + ) + + api_key = values["wenxin_api_key"] + secret_key = values["wenxin_secret_key"] + try: + values["client"] = WenxinChatCompletion(api_key, secret_key) + except AttributeError: + raise ValueError( + "Try upgrading it with `pip install --upgrade requests`." + ) + return values + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling ChatWenxin API.""" + return { + "model": self.model_name, + "temperature": self.temperature, + "top_p": self.top_p, + "max_tokens": self.max_tokens, + **self.model_kwargs, + } + + def completion_with_retry(self, **kwargs: Any) -> Any: + retry_decorator = _create_retry_decorator(self) + + @retry_decorator + def _completion_with_retry(**kwargs: Any) -> Any: + messages = kwargs.get('messages') + temperature = kwargs.get('temperature') + top_p = kwargs.get('top_p') + max_tokens = kwargs.get('max_tokens') + params = { + 'messages': messages, + 'model': self.model_name, + 'top_p': top_p, + 'temperature': temperature, + "max_tokens": max_tokens} + return self.client(ChatInput.parse_obj(params), self.verbose).dict() + + return _completion_with_retry(**kwargs) + + def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: + overall_token_usage: dict = {} + for output in llm_outputs: + if output is None: + # Happens in streaming + continue + token_usage = output["token_usage"] + for k, v in token_usage.items(): + if k in overall_token_usage: + overall_token_usage[k] += v + else: + overall_token_usage[k] = v + return {"token_usage": overall_token_usage, "model_name": self.model_name} + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + message_dicts, params = self._create_message_dicts(messages, stop) + params = {**params, **kwargs} + + response = self.completion_with_retry(messages=message_dicts, **params) + return self._create_chat_result(response) + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + return self._generate(messages, stop, run_manager, kwargs) + + def _create_message_dicts( + self, messages: List[BaseMessage], stop: Optional[List[str]] + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + params = dict(self._client_params) + if stop is not None: + if "stop" in params: + raise ValueError("`stop` found in both the input and default params.") + params["stop"] = stop + + message_dicts = [_convert_message_to_dict(m) for m in messages] + + return message_dicts, params + + def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: + generations = [] + for res in response["choices"]: + message = _convert_dict_to_message(res['message']) + gen = ChatGeneration(message=message) + generations.append(gen) + + llm_output = { + "token_usage": response['usage'], + "model_name": self.model_name} + return ChatResult(generations=generations, llm_output=llm_output) + + @property + def _identifying_params(self) -> Mapping[str, Any]: + """Get the identifying parameters.""" + return {**{"model_name": self.model_name}, **self._default_params} + + @property + def _client_params(self) -> Mapping[str, Any]: + """Get the parameters used for the client.""" + minimaxai_creds: Dict[str, Any] = { + "model": self.model_name, + } + return {**minimaxai_creds, **self._default_params} + + def _get_invocation_params( + self, stop: Optional[List[str]] = None, **kwargs: Any + ) -> Dict[str, Any]: + """Get the parameters used to invoke the model FOR THE CALLBACKS.""" + return { + **super()._get_invocation_params(stop=stop, **kwargs), + **self._default_params, + "model": self.model_name, + "function": kwargs.get("functions"), + } + + @property + def _llm_type(self) -> str: + """Return type of chat model.""" + return "ernie-bot-chat" + + def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]: + tiktoken_ = _import_tiktoken() + if self.tiktoken_model_name is not None: + model = self.tiktoken_model_name + else: + model = self.model_name + # model chatglm-std, chatglm-lite + # Returns the number of tokens used by a list of messages. + try: + encoding = tiktoken_.encoding_for_model(model) + except KeyError: + logger.warning("Warning: model not found. Using cl100k_base encoding.") + model = "cl100k_base" + encoding = tiktoken_.get_encoding(model) + return model, encoding + + def get_token_ids(self, text: str) -> List[int]: + """Get the tokens present in the text with tiktoken package.""" + # tiktoken NOT supported for Python 3.7 or below + if sys.version_info[1] <= 7: + return super().get_token_ids(text) + _, encoding_model = self._get_encoding_model() + return encoding_model.encode(text) + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + """Calculate num tokens for chatglm with tiktoken package. + + todo: read chatglm document + Official documentation: https://github.com/openai/openai-cookbook/blob/ + main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" + if sys.version_info[1] <= 7: + return super().get_num_tokens_from_messages(messages) + model, encoding = self._get_encoding_model() + if model.startswith("chatglm"): + # every message follows {role/name}\n{content}\n + tokens_per_message = 4 + # if there's a name, the role is omitted + tokens_per_name = -1 + else: + raise NotImplementedError( + f"get_num_tokens_from_messages() is not presently implemented " + f"for model {model}." + "See https://github.com/openai/openai-python/blob/main/chatml.md for " + "information on how messages are converted to tokens." + ) + num_tokens = 0 + messages_dict = [_convert_message_to_dict(m) for m in messages] + for message in messages_dict: + num_tokens += tokens_per_message + for key, value in message.items(): + num_tokens += len(encoding.encode(value)) + if key == "name": + num_tokens += tokens_per_name + # every reply is primed with assistant + num_tokens += 3 + return num_tokens + diff --git a/src/langchain_contrib/langchain_contrib/chat_models/xunfeiai.py b/src/langchain_contrib/langchain_contrib/chat_models/xunfeiai.py new file mode 100644 index 0000000000..3faaf0d2db --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/chat_models/xunfeiai.py @@ -0,0 +1,385 @@ +"""proxy llm chat wrapper.""" +from __future__ import annotations + +import requests +import logging +import sys +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Mapping, + Optional, + Tuple, + Union, +) + +from pydantic import Field, root_validator +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain.chat_models.base import BaseChatModel +from langchain.schema import ( + ChatGeneration, + ChatResult, +) +from langchain.schema.messages import ( + AIMessage, + BaseMessage, + ChatMessage, + FunctionMessage, + HumanMessage, + SystemMessage, +) +from langchain.utils import get_from_dict_or_env + +if TYPE_CHECKING: + import tiktoken + +from .interface import XunfeiChatCompletion +from .interface.types import ChatInput + +logger = logging.getLogger(__name__) + + +def _import_tiktoken() -> Any: + try: + import tiktoken + except ImportError: + raise ValueError( + "Could not import tiktoken python package. " + "This is needed in order to calculate get_token_ids. " + "Please install it with `pip install tiktoken`." + ) + return tiktoken + + +def _create_retry_decorator(llm: ChatXunfeiAI) -> Callable[[Any], Any]: + + min_seconds = 1 + max_seconds = 20 + # Wait 2^x * 1 second between each retry starting with + # 4 seconds, then up to 10 seconds, then 10 seconds afterwards + return retry( + reraise=True, + stop=stop_after_attempt(llm.max_retries), + wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), + retry=( + retry_if_exception_type(Exception) + ), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + + +def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: + role = _dict["role"] + if role == "user": + return HumanMessage(content=_dict["content"]) + elif role == "assistant": + content = _dict["content"] or "" # OpenAI returns None for tool invocations + if _dict.get("function_call"): + additional_kwargs = {"function_call": dict(_dict["function_call"])} + else: + additional_kwargs = {} + return AIMessage(content=content, additional_kwargs=additional_kwargs) + elif role == "system": + return SystemMessage(content=_dict["content"]) + elif role == "function": + return FunctionMessage(content=_dict["content"], name=_dict["name"]) + else: + return ChatMessage(content=_dict["content"], role=role) + + +def _convert_message_to_dict(message: BaseMessage) -> dict: + if isinstance(message, ChatMessage): + message_dict = {"role": message.role, "content": message.content} + elif isinstance(message, HumanMessage): + message_dict = {"role": "user", "content": message.content} + elif isinstance(message, AIMessage): + message_dict = {"role": "assistant", "content": message.content} + if "function_call" in message.additional_kwargs: + message_dict["function_call"] = message.additional_kwargs["function_call"] + elif isinstance(message, SystemMessage): + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, FunctionMessage): + message_dict = { + "role": "function", + "content": message.content, + "name": message.name, + } + else: + raise ValueError(f"Got unknown type {message}") + if "name" in message.additional_kwargs: + message_dict["name"] = message.additional_kwargs["name"] + return message_dict + + +class ChatXunfeiAI(BaseChatModel): + """Wrapper around proxy Chat large language models. + + To use, the environment variable ``ELEMAI_API_KEY`` set with your API key. + + Example: + .. code-block:: python + + from langchain_contrib.chat_models import ChatXunfeiAI + chat_miniamaxai = ChatXunfeiAI(model_name="spark") + """ + + client: Optional[Any] #: :meta private: + + """Model name to use.""" + model_name: str = Field("spark", alias="model") + + temperature: float = 0.5 + top_p: float = 0.7 + """What sampling temperature to use.""" + model_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict) + """Holds any model parameters valid for `create` call not explicitly specified.""" + xunfeiai_appid: Optional[str] = None + xunfeiai_api_key: Optional[str] = None + xunfeiai_api_secret: Optional[str] = None + + request_timeout: Optional[Union[float, Tuple[float, float]]] = None + """Timeout for requests to OpenAI completion API. Default is 600 seconds.""" + max_retries: Optional[int] = 6 + """Maximum number of retries to make when generating.""" + streaming: Optional[bool] = False + """Whether to stream the results or not.""" + n: Optional[int] = 1 + """Number of chat completions to generate for each prompt.""" + max_tokens: Optional[int] = None + """Maximum number of tokens to generate.""" + tiktoken_model_name: Optional[str] = None + """The model name to pass to tiktoken when using this class. + Tiktoken is used to count the number of tokens in documents to constrain + them to be under a certain limit. By default, when set to None, this will + be the same as the embedding model name. However, there are some cases + where you may want to use this Embedding class with a model name not + supported by tiktoken. This can include when using Azure embeddings or + when using one of the many model providers that expose an OpenAI-like + API but with different models. In those cases, in order to avoid erroring + when tiktoken is called, you can specify a model name to use here.""" + verbose: Optional[bool] = False + + class Config: + """Configuration for this pydantic object.""" + + allow_population_by_field_name = True + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + values["xunfeiai_appid"] = get_from_dict_or_env( + values, "xunfeiai_appid", "XUNFEIAI_APPID" + ) + + values["xunfeiai_api_key"] = get_from_dict_or_env( + values, "xunfeiai_api_key", "XUNFEIAI_API_KEY" + ) + + values["xunfeiai_api_secret"] = get_from_dict_or_env( + values, "xunfeiai_api_secret", "XUNFEIAI_API_SECRET" + ) + + appid = values["xunfeiai_appid"] + api_key = values["xunfeiai_api_key"] + api_secret = values["xunfeiai_api_secret"] + + try: + values["client"] = XunfeiChatCompletion(appid, api_key, api_secret) + except AttributeError: + raise ValueError( + "Try upgrading it with `pip install --upgrade requests`." + ) + return values + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling ChatXunfeiAI API.""" + return { + "model": self.model_name, + "temperature": self.temperature, + "top_p": self.top_p, + "max_tokens": self.max_tokens, + **self.model_kwargs, + } + + + def completion_with_retry(self, **kwargs: Any) -> Any: + retry_decorator = _create_retry_decorator(self) + + @retry_decorator + def _completion_with_retry(**kwargs: Any) -> Any: + messages = kwargs.get('messages') + temperature = kwargs.get('temperature') + top_p = kwargs.get('top_p') + max_tokens = kwargs.get('max_tokens') + params = { + 'messages': messages, + 'model': self.model_name, + 'top_p': top_p, + 'temperature': temperature, + "max_tokens": max_tokens} + return self.client(ChatInput.parse_obj(params), self.verbose).dict() + + return _completion_with_retry(**kwargs) + + def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: + overall_token_usage: dict = {} + for output in llm_outputs: + if output is None: + # Happens in streaming + continue + token_usage = output["token_usage"] + for k, v in token_usage.items(): + if k in overall_token_usage: + overall_token_usage[k] += v + else: + overall_token_usage[k] = v + return {"token_usage": overall_token_usage, "model_name": self.model_name} + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + message_dicts, params = self._create_message_dicts(messages, stop) + params = {**params, **kwargs} + + response = self.completion_with_retry(messages=message_dicts, **params) + return self._create_chat_result(response) + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + return self._generate(messages, stop, run_manager, kwargs) + + def _create_message_dicts( + self, messages: List[BaseMessage], stop: Optional[List[str]] + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + params = dict(self._client_params) + if stop is not None: + if "stop" in params: + raise ValueError("`stop` found in both the input and default params.") + params["stop"] = stop + + message_dicts = [_convert_message_to_dict(m) for m in messages] + + return message_dicts, params + + def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: + generations = [] + for res in response["choices"]: + message = _convert_dict_to_message(res['message']) + gen = ChatGeneration(message=message) + generations.append(gen) + + llm_output = { + "token_usage": response['usage'], + "model_name": self.model_name} + return ChatResult(generations=generations, llm_output=llm_output) + + @property + def _identifying_params(self) -> Mapping[str, Any]: + """Get the identifying parameters.""" + return {**{"model_name": self.model_name}, **self._default_params} + + @property + def _client_params(self) -> Mapping[str, Any]: + """Get the parameters used for the client.""" + minimaxai_creds: Dict[str, Any] = { + "model": self.model_name, + } + return {**minimaxai_creds, **self._default_params} + + def _get_invocation_params( + self, stop: Optional[List[str]] = None, **kwargs: Any + ) -> Dict[str, Any]: + """Get the parameters used to invoke the model FOR THE CALLBACKS.""" + return { + **super()._get_invocation_params(stop=stop, **kwargs), + **self._default_params, + "model": self.model_name, + "function": kwargs.get("functions"), + } + + @property + def _llm_type(self) -> str: + """Return type of chat model.""" + return "xunfeiai_chat" + + def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]: + tiktoken_ = _import_tiktoken() + if self.tiktoken_model_name is not None: + model = self.tiktoken_model_name + else: + model = self.model_name + # model chatglm-std, chatglm-lite + # Returns the number of tokens used by a list of messages. + try: + encoding = tiktoken_.encoding_for_model(model) + except KeyError: + logger.warning("Warning: model not found. Using cl100k_base encoding.") + model = "cl100k_base" + encoding = tiktoken_.get_encoding(model) + return model, encoding + + def get_token_ids(self, text: str) -> List[int]: + """Get the tokens present in the text with tiktoken package.""" + # tiktoken NOT supported for Python 3.7 or below + if sys.version_info[1] <= 7: + return super().get_token_ids(text) + _, encoding_model = self._get_encoding_model() + return encoding_model.encode(text) + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + """Calculate num tokens for chatglm with tiktoken package. + + todo: read chatglm document + Official documentation: https://github.com/openai/openai-cookbook/blob/ + main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" + if sys.version_info[1] <= 7: + return super().get_num_tokens_from_messages(messages) + model, encoding = self._get_encoding_model() + if model.startswith("spark"): + # every message follows {role/name}\n{content}\n + tokens_per_message = 4 + # if there's a name, the role is omitted + tokens_per_name = -1 + else: + raise NotImplementedError( + f"get_num_tokens_from_messages() is not presently implemented " + f"for model {model}." + "See https://github.com/openai/openai-python/blob/main/chatml.md for " + "information on how messages are converted to tokens." + ) + num_tokens = 0 + messages_dict = [_convert_message_to_dict(m) for m in messages] + for message in messages_dict: + num_tokens += tokens_per_message + for key, value in message.items(): + num_tokens += len(encoding.encode(value)) + if key == "name": + num_tokens += tokens_per_name + # every reply is primed with assistant + num_tokens += 3 + return num_tokens + diff --git a/src/langchain_contrib/langchain_contrib/chat_models/zhipuai.py b/src/langchain_contrib/langchain_contrib/chat_models/zhipuai.py new file mode 100644 index 0000000000..c8012784e9 --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/chat_models/zhipuai.py @@ -0,0 +1,401 @@ +"""ZhipuAI chat wrapper.""" +from __future__ import annotations + +import logging +import sys +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Mapping, + Optional, + Tuple, + Union, +) + +from pydantic import Field, root_validator +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain.chat_models.base import BaseChatModel +from langchain.schema import ( + ChatGeneration, + ChatResult, +) +from langchain.schema.messages import ( + AIMessage, + BaseMessage, + ChatMessage, + FunctionMessage, + HumanMessage, + SystemMessage, +) +from langchain.utils import get_from_dict_or_env + +if TYPE_CHECKING: + import tiktoken + +logger = logging.getLogger(__name__) + + +def _import_tiktoken() -> Any: + try: + import tiktoken + except ImportError: + raise ValueError( + "Could not import tiktoken python package. " + "This is needed in order to calculate get_token_ids. " + "Please install it with `pip install tiktoken`." + ) + return tiktoken + + +def _create_retry_decorator(llm: ChatZhipuAI) -> Callable[[Any], Any]: + + min_seconds = 1 + max_seconds = 20 + # Wait 2^x * 1 second between each retry starting with + # 4 seconds, then up to 10 seconds, then 10 seconds afterwards + return retry( + reraise=True, + stop=stop_after_attempt(llm.max_retries), + wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), + retry=( + retry_if_exception_type(Exception) + ), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + +def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: + role = _dict["role"] + if role == "user": + return HumanMessage(content=_dict["content"]) + elif role == "assistant": + content = _dict["content"] or "" # OpenAI returns None for tool invocations + if _dict.get("function_call"): + additional_kwargs = {"function_call": dict(_dict["function_call"])} + else: + additional_kwargs = {} + return AIMessage(content=content, additional_kwargs=additional_kwargs) + elif role == "system": + return SystemMessage(content=_dict["content"]) + elif role == "function": + return FunctionMessage(content=_dict["content"], name=_dict["name"]) + else: + return ChatMessage(content=_dict["content"], role=role) + + +def _convert_message_to_dict(message: BaseMessage) -> dict: + if isinstance(message, ChatMessage): + message_dict = {"role": message.role, "content": message.content} + elif isinstance(message, HumanMessage): + message_dict = {"role": "user", "content": message.content} + elif isinstance(message, AIMessage): + message_dict = {"role": "assistant", "content": message.content} + elif isinstance(message, SystemMessage): + message_dict = {"role": "system", "content": message.content} + # raise ValueError(f"not support system role {message}") + + elif isinstance(message, FunctionMessage): + raise ValueError(f"not support funciton {message}") + else: + raise ValueError(f"Got unknown type {message}") + + # if "name" in message.additional_kwargs: + # message_dict["name"] = message.additional_kwargs["name"] + return message_dict + + +def _convert_message_to_dict2(message: BaseMessage) -> List[dict]: + if isinstance(message, ChatMessage): + message_dict = {"role": message.role, "content": message.content} + elif isinstance(message, HumanMessage): + message_dict = {"role": "user", "content": message.content} + elif isinstance(message, AIMessage): + message_dict = {"role": "assistant", "content": message.content} + elif isinstance(message, SystemMessage): + raise ValueError(f"not support system role {message}") + + elif isinstance(message, FunctionMessage): + raise ValueError(f"not support funciton {message}") + else: + raise ValueError(f"Got unknown type {message}") + + return [message_dict] + + +class ChatZhipuAI(BaseChatModel): + """Wrapper around ZhipuAI Chat large language models. + + To use, you should have the ``zhipuai`` python package installed, and the + environment variable ``ZHIPU_API_KEY`` set with your API key. + + Example: + .. code-block:: python + + from lib.zhipuai import ChatZhipuAI + chat_zhipu = ChatZhipu(model_name="chatglm-std") + """ + + client: Optional[Any] #: :meta private: + model_name: str = Field(default="chatglm_std", alias="model") + + """Model name to use.""" + temperature: float = 0.95 + top_p: float = 0.7 + """What sampling temperature to use.""" + model_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict) + """Holds any model parameters valid for `create` call not explicitly specified.""" + zhipuai_api_key: Optional[str] = None + + request_timeout: Optional[Union[float, Tuple[float, float]]] = None + """Timeout for requests to OpenAI completion API. Default is 600 seconds.""" + max_retries: Optional[int] = 6 + """Maximum number of retries to make when generating.""" + streaming: Optional[bool] = False + """Whether to stream the results or not.""" + n: Optional[int] = 1 + """Number of chat completions to generate for each prompt.""" + max_tokens: Optional[int] = None + """Maximum number of tokens to generate.""" + tiktoken_model_name: Optional[str] = None + """The model name to pass to tiktoken when using this class. + Tiktoken is used to count the number of tokens in documents to constrain + them to be under a certain limit. By default, when set to None, this will + be the same as the embedding model name. However, there are some cases + where you may want to use this Embedding class with a model name not + supported by tiktoken. This can include when using Azure embeddings or + when using one of the many model providers that expose an OpenAI-like + API but with different models. In those cases, in order to avoid erroring + when tiktoken is called, you can specify a model name to use here.""" + verbose: Optional[bool] = False + + class Config: + """Configuration for this pydantic object.""" + + allow_population_by_field_name = True + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + values["zhipuai_api_key"] = get_from_dict_or_env( + values, "zhipuai_api_key", "ZHIPUAI_API_KEY" + ) + try: + import zhipuai + zhipuai.api_key = values["zhipuai_api_key"] + except ImportError: + raise ValueError( + "Could not import openai python package. " + "Please install it with `pip install zhipuai`." + ) + try: + values["client"] = zhipuai.model_api.invoke + except AttributeError: + raise ValueError( + "`openai` has no `ChatCompletion` attribute, this is likely " + "due to an old version of the zhipuai package. Try upgrading it " + "with `pip install --upgrade zhipuai`." + ) + return values + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling ZhipuAI API.""" + return { + "model": self.model_name, + "temperature": self.temperature, + "top_p": self.top_p, + "max_tokens": self.max_tokens, + **self.model_kwargs, + } + + def completion_with_retry(self, **kwargs: Any) -> Any: + retry_decorator = _create_retry_decorator(self) + + @retry_decorator + def _completion_with_retry(**kwargs: Any) -> Any: + messages = kwargs.get('messages') + temperature = kwargs.get('temperature') + top_p = kwargs.get('top_p') + params = {'prompt': messages, 'model': self.model_name, + 'top_p': top_p, 'temperature': temperature, + "incremental": False} + return self.client(**params) + + return _completion_with_retry(**kwargs) + + def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: + overall_token_usage: dict = {} + for output in llm_outputs: + if output is None: + # Happens in streaming + continue + token_usage = output["token_usage"] + for k, v in token_usage.items(): + if k in overall_token_usage: + overall_token_usage[k] += v + else: + overall_token_usage[k] = v + return {"token_usage": overall_token_usage, "model_name": self.model_name} + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + message_dicts, params = self._create_message_dicts(messages, stop) + params = {**params, **kwargs} + + response = self.completion_with_retry(messages=message_dicts, **params) + return self._create_chat_result(response) + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + return self._generate(messages, stop, run_manager, kwargs) + + def _create_message_dicts( + self, messages: List[BaseMessage], stop: Optional[List[str]] + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + params = dict(self._client_params) + if stop is not None: + if "stop" in params: + raise ValueError("`stop` found in both the input and default params.") + params["stop"] = stop + + system_content = '' + message_dicts = [] + for m in messages: + if m.type == 'system': + system_content += m.content + continue + message_dicts.extend(_convert_message_to_dict2(m)) + + if system_content: + message_dicts[-1]['content'] = system_content + message_dicts[-1]['content'] + + return message_dicts, params + + def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: + generations = [] + # print('response', response) + def _norm_text(text): + if text[0] == '"' and text[-1] == '"': + out = eval(text) + else: + out = text + return out + + for res in response['data']["choices"]: + res['content'] = _norm_text(res['content']) + message = _convert_dict_to_message(res) + gen = ChatGeneration(message=message) + generations.append(gen) + + llm_output = {"token_usage": response['data']["usage"], "model_name": self.model_name} + return ChatResult(generations=generations, llm_output=llm_output) + + @property + def _identifying_params(self) -> Mapping[str, Any]: + """Get the identifying parameters.""" + return {**{"model_name": self.model_name}, **self._default_params} + + @property + def _client_params(self) -> Mapping[str, Any]: + """Get the parameters used for the openai client.""" + zhipu_creds: Dict[str, Any] = { + "api_key": self.zhipuai_api_key, + "model": self.model_name, + } + return {**zhipu_creds, **self._default_params} + + def _get_invocation_params( + self, stop: Optional[List[str]] = None, **kwargs: Any + ) -> Dict[str, Any]: + """Get the parameters used to invoke the model FOR THE CALLBACKS.""" + return { + **super()._get_invocation_params(stop=stop, **kwargs), + **self._default_params, + "model": self.model_name, + "function": kwargs.get("functions"), + } + + @property + def _llm_type(self) -> str: + """Return type of chat model.""" + return "zhipu-chat" + + def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]: + tiktoken_ = _import_tiktoken() + if self.tiktoken_model_name is not None: + model = self.tiktoken_model_name + else: + model = self.model_name + # model chatglm-std, chatglm-lite + # Returns the number of tokens used by a list of messages. + try: + encoding = tiktoken_.encoding_for_model(model) + except KeyError: + logger.warning("Warning: model not found. Using cl100k_base encoding.") + model = "cl100k_base" + encoding = tiktoken_.get_encoding(model) + return model, encoding + + def get_token_ids(self, text: str) -> List[int]: + """Get the tokens present in the text with tiktoken package.""" + # tiktoken NOT supported for Python 3.7 or below + if sys.version_info[1] <= 7: + return super().get_token_ids(text) + _, encoding_model = self._get_encoding_model() + return encoding_model.encode(text) + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + """Calculate num tokens for chatglm with tiktoken package. + + todo: read chatglm document + Official documentation: https://github.com/openai/openai-cookbook/blob/ + main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" + if sys.version_info[1] <= 7: + return super().get_num_tokens_from_messages(messages) + model, encoding = self._get_encoding_model() + if model.startswith("chatglm"): + # every message follows {role/name}\n{content}\n + tokens_per_message = 4 + # if there's a name, the role is omitted + tokens_per_name = -1 + else: + raise NotImplementedError( + f"get_num_tokens_from_messages() is not presently implemented " + f"for model {model}." + "See https://github.com/openai/openai-python/blob/main/chatml.md for " + "information on how messages are converted to tokens." + ) + num_tokens = 0 + messages_dict = [_convert_message_to_dict(m) for m in messages] + for message in messages_dict: + num_tokens += tokens_per_message + for key, value in message.items(): + num_tokens += len(encoding.encode(value)) + if key == "name": + num_tokens += tokens_per_name + # every reply is primed with assistant + num_tokens += 3 + return num_tokens diff --git a/src/langchain_contrib/langchain_contrib/document_loaders/__init__.py b/src/langchain_contrib/langchain_contrib/document_loaders/__init__.py new file mode 100644 index 0000000000..3fbc88873f --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/document_loaders/__init__.py @@ -0,0 +1,3 @@ +from .elem_pdf import PDFWithSemanticLoader + +__all__ = ['PDFWithSemanticLoader'] \ No newline at end of file diff --git a/src/langchain_contrib/langchain_contrib/document_loaders/elem_html.py b/src/langchain_contrib/langchain_contrib/document_loaders/elem_html.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/langchain_contrib/langchain_contrib/document_loaders/elem_image.py b/src/langchain_contrib/langchain_contrib/document_loaders/elem_image.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/langchain_contrib/langchain_contrib/document_loaders/elem_pdf.py b/src/langchain_contrib/langchain_contrib/document_loaders/elem_pdf.py new file mode 100644 index 0000000000..9e0829888f --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/document_loaders/elem_pdf.py @@ -0,0 +1,664 @@ +"""Loads PDF with semantic splilter.""" +import json +import logging +import os +import tempfile +import time +from abc import ABC +import io +from pathlib import Path +from typing import Any, Iterator, List, Mapping, Optional, Union +from urllib.parse import urlparse +import re +import requests +from collections import Counter +from copy import deepcopy + +import numpy as np +from shapely import box as Rect +from shapely import Polygon +import pypdfium2 +import fitz + +from langchain.docstore.document import Document +from langchain.document_loaders.pdf import BasePDFLoader +from langchain.document_loaders.blob_loaders import Blob +from langchain_contrib.document_loaders.parsers import LayoutParser + + +RE_MULTISPACE_INCLUDING_NEWLINES = re.compile(pattern=r"\s+", flags=re.DOTALL) + + +def merge_rects(bboxes): + x0 = np.min(bboxes[:, 0]) + y0 = np.min(bboxes[:, 1]) + x1 = np.max(bboxes[:, 2]) + y1 = np.max(bboxes[:, 3]) + return [x0, y0, x1, y1] + + +def norm_rect(bbox): + x0 = np.min([bbox[0], bbox[2]]) + x1 = np.max([bbox[0], bbox[2]]) + y0 = np.min([bbox[1], bbox[3]]) + y1 = np.max([bbox[1], bbox[3]]) + return np.asarray([x0, y0, x1, y1]) + + +def find_max_continuous_seq(arr): + n = len(arr) + max_info = (0, 1) + for i in range(n): + m = 1 + for j in range(i + 1, n): + if arr[j] - arr[j - 1] == 1: + m += 1 + else: + break + + if m > max_info[1]: + max_info = (i, m) + + max_info = (max_info[0] + arr[0], max_info[1]) + return max_info + + +def order_by_tbyx(block_info, th=10): + """ + block_info: [(b0, b1, b2, b3, text, x, y)+] + th: threshold of the position threshold + """ + # sort using y1 first and then x1 + res = sorted(block_info, key=lambda b: (b[1], b[0])) + for i in range(len(res) - 1): + for j in range(i, 0, -1): + # restore the order using the + if (abs(res[j + 1][1] - res[j][1]) < th and + (res[j + 1][0] < res[j][0])): + tmp = deepcopy(res[j]) + res[j] = deepcopy(res[j + 1]) + res[j + 1] = deepcopy(tmp) + else: + break + return res + + +def join_lines(texts, is_table=False): + if is_table: + return '\n'.join(texts) + + flags = [] + PUNC_SET = set(['.', ',', ';', '?', '!']) + for text in texts: + flags.append(np.all([t.isalnum() for t in text.rsplit(' ', 5)])) + + if np.all(flags): + t0 = texts[0] + for t in texts[1:]: + if t0[-1] == '-': + t0 = t0[:-1] + t + elif t0[-1].isalnum() and t[0].isalnum(): + t0 += ' ' + t + elif t0[-1] in PUNC_SET or t[0] in PUNC_SET: + t0 += ' ' + t + else: + t0 += t + return t0 + else: + return ''.join(texts) + + +class Segment: + def __init__(self, seg): + self.whole = seg + self.segs = [] + + @staticmethod + def is_align(seg0, seg1, delta=5, mode=0): + # mode=0 edge align + # mode=1, edge align or center align + res = Segment.contain(seg0, seg1) + if not res: + return False + else: + if mode == 1: + r1 = seg1[0] - seg0[0] <= delta or seg0[1] - seg1[1] <= delta + c0 = (seg0[0] + seg0[1]) / 2 + c1 = (seg1[0] + seg1[1]) / 2 + r2 = abs(c1 - c0) <= delta + return r1 or r2 + else: + return seg1[0] - seg0[0] <= delta or seg0[1] - seg1[1] <= delta + + @staticmethod + def contain(seg0, seg1): + return seg0[0] <= seg1[0] and seg0[1] >= seg1[0] + + @staticmethod + def overlap(seg0, seg1): + max_x0 = max(seg0[0], seg1[0]) + min_x1 = min(seg0[1], seg1[1]) + return max_x0 < min_x1 + + def _merge(self, segs): + x0s = [s[0] for s in segs] + x1s = [s[1] for s in segs] + return (np.min(x0s), np.max(x1s)) + + def add(self, seg): + if not self.segs: + self.segs.append(seg) + else: + overlaps = [] + non_overlaps = [] + for seg0 in self.segs: + if Segment.overlap(seg0, seg): + overlaps.append(seg0) + else: + non_overlaps.append(seg0) + + if not overlaps: + self.segs.append(seg) + else: + overlaps.append(seg) + new_seg = self._merge(overlaps) + non_overlaps.append(new_seg) + self.segs = non_overlaps + + def get_free_segment(self, incr_margin=True, margin_threshold=10): + sorted_segs = sorted(self.segs, key=lambda x: x[0]) + n = len(sorted_segs) + free_segs = [] + if incr_margin: + if n > 0: + seg_1st = sorted_segs[0] + if (seg_1st[0] - self.whole[0]) > margin_threshold: + free_segs.append((self.whole[0], seg_1st[0])) + + seg_last = sorted_segs[-1] + if (self.whole[1] - seg_last[1]) > margin_threshold: + free_segs.append((seg_last[1], self.whole[1])) + + for i in range(n - 1): + x0 = sorted_segs[i][1] + x1 = sorted_segs[i + 1][0] + free_segs.append((x0, x1)) + + return free_segs + + +class PDFWithSemanticLoader(BasePDFLoader): + """Loads a PDF with pypdf and chunks at character level. + + Loader also stores page numbers in metadata. + """ + + def __init__( + self, + file_path: str, + password: Optional[Union[str, bytes]] = None, + layout_api_key: str = None, + layout_api_url: str = None, + is_join_table: bool = True, + with_columns: bool = False, + support_rotate: bool = False, + text_elem_sep: str = '\n', + start: int = 0, + n: int = None, + html_output_file: str = None, + verbose: bool = False + ) -> None: + """Initialize with a file path.""" + self.layout_parser = LayoutParser( + api_key=layout_api_key, api_base_url=layout_api_url) + self.with_columns = with_columns + self.is_join_table = is_join_table + self.support_rotate = support_rotate + self.start = start + self.n = n + self.html_output_file = html_output_file + self.verbose = verbose + self.text_elem_sep = text_elem_sep + super().__init__(file_path) + + def _get_image_blobs(self, fitz_doc, pdf_reader, n=None, start=0): + blobs = [] + pages = [] + if not n: + n = fitz_doc.page_count + for pg in range(start, start + n): + bytes_img = None + page = fitz_doc.load_page(pg) + pages.append(page) + mat = fitz.Matrix(1, 1) + try: + pm = page.get_pixmap(matrix=mat, alpha=False) + bytes_img = pm.getPNGData() + except Exception: + # some pdf input cannot get render image from fitz + page = pdf_reader.get_page(pg) + pil_image = page.render().to_pil() + img_byte_arr = io.BytesIO() + pil_image.save(img_byte_arr, format='PNG') + bytes_img = img_byte_arr.getvalue() + + blobs.append(Blob(data=bytes_img)) + return blobs, pages + + def _allocate_semantic(self, page, layout): + class_name = [ + '印章', '图片', '标题', '段落', '表格', '页眉', '页码', '页脚' + ] + effective_class_inds = [3, 4, 5, 999] + non_conti_class_ids = [6, 7, 8] + TEXT_ID = 4 + TABLE_ID = 5 + + textpage = page.get_textpage() + blocks = textpage.extractBLOCKS() + + if self.support_rotate: + rotation_matrix = np.asarray(page.rotation_matrix).reshape((3, 2)) + c1 = (rotation_matrix[0, 0] - 1) <= 1e-6 + c2 = (rotation_matrix[1, 1] - 1) <= 1e-6 + is_rotated = c1 and c2 + # print('c1/c2', c1, c2) + if is_rotated: + new_blocks = [] + for b in blocks: + bbox = np.asarray([b[0], b[1], b[2], b[3]]) + aug_bbox = bbox.reshape((-1, 2)) + padding = np.ones((len(aug_bbox), 1)) + aug_bbox = np.hstack([aug_bbox, padding]) + bb = np.dot(aug_bbox, rotation_matrix).reshape(-1) + bb = norm_rect(bb) + info = (bb[0], bb[1], bb[2], bb[3], b[4], b[5], b[6]) + new_blocks.append(info) + + blocks = new_blocks + + if not self.with_columns: + blocks = order_by_tbyx(blocks) + + # print('---ori blocks---') + # for b in blocks: + # print(b) + + + IMG_BLOCK_TYPE = 1 + text_ploys = [] + text_rects = [] + texts = [] + for b in blocks: + if b[-1] != IMG_BLOCK_TYPE: + text = re.sub( + RE_MULTISPACE_INCLUDING_NEWLINES, ' ', + b[4] or "").strip() + if text: + texts.append(text) + text_ploys.append(Rect(b[0], b[1], b[2], b[3])) + text_rects.append([b[0], b[1], b[2], b[3]]) + text_rects = np.asarray(text_rects) + texts = np.asarray(texts) + + semantic_polys = [] + semantic_labels = [] + + layout_info = json.loads(layout.page_content) + for info in layout_info: + bbs = info['bbox'] + coords = ((bbs[0], bbs[1]), (bbs[2], bbs[3]), + (bbs[4], bbs[5]), (bbs[6], bbs[7])) + semantic_polys.append(Polygon(coords)) + semantic_labels.append(info['category_id']) + + # caculate containing overlap + sem_cnt = len(semantic_polys) + texts_cnt = len(text_ploys) + contain_matrix = np.zeros((sem_cnt, texts_cnt)) + for i in range(sem_cnt): + for j in range(texts_cnt): + inter = semantic_polys[i].intersection(text_ploys[j]).area + contain_matrix[i, j] = inter * 1.0 / text_ploys[j].area + + # print('----------------containing matrix--------') + # for r in contain_matrix.tolist(): + # print([round(r_, 2) for r_ in r]) + + # print('---text---') + # for t in texts: + # print(t) + + # merge continuous text block by the containing matrix + CONTRAIN_THRESHOLD = 0.70 + contain_info = [] + for i in range(sem_cnt): + ind = np.argwhere(contain_matrix[i, :] > CONTRAIN_THRESHOLD)[:, 0] + if len(ind) == 0: continue + label = semantic_labels[i] + if label in non_conti_class_ids: + n = len(ind) + contain_info.append((None, None, n, label, ind)) + else: + start, n = find_max_continuous_seq(ind) + if n >= 1: + contain_info.append((start, start + n, n, label, None)) + + contain_info = sorted(contain_info, key=lambda x: x[2], reverse=True) + mask = np.zeros(texts_cnt) + new_block_info = [] + for info in contain_info: + start, end, n, label, ind = info + if label in non_conti_class_ids and np.all(mask[ind] == 0): + rect = merge_rects(text_rects[ind]) + ori_orders = [blocks[i][-2] for i in ind] + ts = texts[ind] + rs = text_rects[ind] + ord_ind = np.min(ori_orders) + mask[ind] = 1 + new_block_info.append( + (rect[0], rect[1], rect[2], rect[3], ts, rs, ord_ind)) + + elif np.all(mask[start:end] == 0): + rect = merge_rects(text_rects[start:end]) + ori_orders = [blocks[i][-2] for i in range(start, end)] + arg_ind = np.argsort(ori_orders) + # print('ori_orders', ori_orders, arg_ind) + ord_ind = np.min(ori_orders) + + ts = texts[start:end] + rs = text_rects[start:end] + if label == TABLE_ID: + ts = ts[arg_ind] + rs = rs[arg_ind] + + mask[start:end] = 1 + new_block_info.append( + (rect[0], rect[1], rect[2], rect[3], ts, rs, ord_ind)) + + for i in range(texts_cnt): + if mask[i] == 0: + b = blocks[i] + r = np.asarray([b[0], b[1], b[2], b[3]]) + ord_ind = b[-2] + new_block_info.append( + (b[0], b[1], b[2], b[3], [texts[i]], [r], ord_ind)) + + if self.with_columns: + new_blocks = sorted(new_block_info, key=lambda x: x[-1]) + else: + new_blocks = order_by_tbyx(new_block_info) + + # print('\n\n---new blocks---') + # for idx, b in enumerate(new_blocks): + # print(idx, b) + + text_ploys = [] + texts = [] + for b in new_blocks: + texts.append(b[4]) + text_ploys.append(Rect(b[0], b[1], b[2], b[3])) + + + # caculate overlap + sem_cnt = len(semantic_polys) + texts_cnt = len(text_ploys) + overlap_matrix = np.zeros((sem_cnt, texts_cnt)) + for i in range(sem_cnt): + for j in range(texts_cnt): + inter = semantic_polys[i].intersection(text_ploys[j]).area + union = semantic_polys[i].union(text_ploys[j]).area + overlap_matrix[i, j] = (inter * 1.0) / union + + # print('---overlap_matrix---') + # for r in overlap_matrix: + # print([round(r_, 3) for r_ in r]) + + # allocate label + OVERLAP_THRESHOLD = 0.2 + texts_labels = [] + DEF_SEM_LABEL = 999 + for j in range(texts_cnt): + ind = np.argwhere(overlap_matrix[:, j] > OVERLAP_THRESHOLD)[:, 0] + if len(ind) == 0: + sem_label = DEF_SEM_LABEL + else: + c = Counter([semantic_labels[i] for i in ind]) + items = c.most_common() + sem_label = items[0][0] + if len(items) > 1 and TEXT_ID in dict(items): + sem_label = TEXT_ID + + texts_labels.append(sem_label) + + # print(texts_labels) + # filter the unused element + filtered_blocks = [] + for label, b in zip(texts_labels, new_blocks): + if label in effective_class_inds: + text = join_lines(b[4], label == TABLE_ID) + filtered_blocks.append( + (b[0], b[1], b[2], b[3], text, b[5], label)) + + # print('---filtered_blocks---') + # for b in filtered_blocks: + # print(b) + + return filtered_blocks + + def _divide_blocks_into_groups(self, blocks): + # support only pure two columns layout, each has same width + rects = np.asarray([[b[0], b[1], b[2], b[3]] for b in blocks]) + min_x0 = np.min(rects[:, 0]) + max_x1 = np.max(rects[:, 2]) + root_seg = (min_x0, max_x1) + root_pc = (min_x0 + max_x1) / 2 + root_offset = 20 + center_seg = (root_pc - root_offset, root_pc + root_offset) + + segment = Segment(root_seg) + for r in rects: + segment.add((r[0], r[2])) + + COLUMN_THRESHOLD = 0.90 + CENTER_GAP_THRESHOLD = 0.90 + free_segs = segment.get_free_segment() + columns = [] + if len(free_segs) == 1 and len(segment.segs) == 2: + free_seg = free_segs[0] + seg0 = segment.segs[0] + seg1 = segment.segs[1] + cover = seg0[1] - seg0[0] + seg1[1] - seg1[0] + c0 = cover / (root_seg[1] - root_seg[0]) + c1 = Segment.contain(center_seg, free_seg) + if c0 > COLUMN_THRESHOLD and c1: + # two columns + columns.extend([seg0, seg1]) + + groups = [blocks] + if columns: + groups = [[] for _ in columns] + for b, r in zip(blocks, rects): + column_ind = 0 + cand_seg = (r[0], r[2]) + for i, seg in enumerate(columns): + if Segment.contain(seg, cand_seg): + column_ind = i + break + groups[i].append(b) + + return groups + + def _allocate_continuous(self, groups): + g_bound = [] + groups = [g for g in groups if g] + for blocks in groups: + arr =[[b[0], b[1], b[2], b[3]] for b in blocks] + bboxes = np.asarray(arr) + g_bound.append(np.asarray(merge_rects(bboxes))) + + LINE_FULL_THRESHOLD = 0.80 + START_THRESHOLD = 0.8 + SIMI_HEIGHT_THRESHOLD = 0.3 + SIMI_WIDTH_THRESHOLD = 0.3 + + TEXT_ID = 4 + TABLE_ID = 5 + + def _get_elem(blocks, is_first=True): + if not blocks: + return (None, None, None, None, None) + if is_first: + b1 = blocks[0] + b1_label = b1[-1] + r1 = b1[5][0] + r1_w = r1[2] - r1[0] + r1_h = r1[3] - r1[1] + return (b1, b1_label, r1, r1_w, r1_h) + else: + b0 = blocks[-1] + b0_label = b0[-1] + r0 = b0[5][-1] + r0_w = r0[2] - r0[0] + r0_h = r0[3] - r0[1] + return (b0, b0_label, r0, r0_w, r0_h) + + b0, b0_label, r0, r0_w, r0_h = _get_elem(groups[0], False) + g0 = g_bound[0] + + for i in range(1, len(groups)): + b1, b1_label, r1, r1_w, r1_h = _get_elem(groups[i], True) + g1 = g_bound[i] + + # print('\n_allocate_continuous:') + # print(b0, b0_label, b1, b1_label) + + if b0_label and b0_label == b1_label and b0_label == TEXT_ID: + c0 = r0_w / (g0[2] - g0[0]) + c1 = (r1[0] - g1[0]) / r1_h + c2 = np.abs(r0_h - r1_h) / r1_h + + # print('\n\n---conti texts---') + # print(b0_label, c0, c1, c2, + # b0, b0_label, r0, r0_w, r0_h, + # b1, b1_label, r1, r1_w, r1_h) + + if (c0 > LINE_FULL_THRESHOLD and c1 < START_THRESHOLD and + c2 < SIMI_HEIGHT_THRESHOLD): + new_text = join_lines([b0[4], b1[4]]) + new_block = ( + b0[0], b0[1], b0[2], b0[3], new_text, b0[5], b0[6]) + groups[i - 1][-1] = new_block + groups[i].pop(0) + + elif (self.is_join_table and b0_label and + b0_label == b1_label and b0_label == TABLE_ID): + c0 = (r1_w - r0_w) / r1_h + if c0 < SIMI_WIDTH_THRESHOLD: + new_text = join_lines([b0[4], b1[4]], True) + new_block = ( + b0[0], b0[1], b0[2], b0[3], new_text, b0[5], b0[6]) + groups[i - 1][-1] = new_block + groups[i].pop(0) + + b0, b0_label, r0, r0_w, r0_h = _get_elem(groups[i], False) + + return groups + + + def save_to_html(self, groups, output_file): + styles = [ + 'style="background-color: #EBEBEB;"', + 'style="background-color: #ABBAEA;"' + ] + idx = 0 + table_style = 'style="border:1px solid black;"' + + with open(output_file, 'w') as fout: + for blocks in groups: + for b in blocks: + if b[-1] == 3: + text = f'

{b[4]}

' + elif b[-1] == 4: + text = f'

{b[4]}

' + idx += 1 + elif b[-1] == 5: + rows = b[4].split('\n') + content = [] + for r in rows: + content.append( + f'{r}') + elem_text = '\n'.join(content) + text = f'{elem_text}
' + else: + text = f'

{b[4]}

' + idx += 1 + + fout.write(text + '\n') + + def _save_to_document(self, groups): + TITLE_ID = 3 + TEXT_ID = 4 + TABLE_ID = 5 + content_page = [] + is_first_elem = True + for blocks in groups: + for b in blocks: + if is_first_elem: + content_page.append(b[4]) + is_first_elem = False + else: + label, text = b[-1], b[4] + if label == TITLE_ID: + content_page.append('\n\n' + text) + else: + content_page.append(self.text_elem_sep + text) + + return ''.join(content_page) + + + def load(self) -> List[Document]: + """Load given path as pages.""" + blob = Blob.from_path(self.file_path) + start = self.start + groups = [] + with blob.as_bytes_io() as file_path: + fitz_doc = fitz.open(file_path) + pdf_doc = pypdfium2.PdfDocument(file_path, autoclose=True) + max_page = fitz_doc.page_count - start + n = self.n if self.n else max_page + n = min(n, max_page) + + tic = time.time() + if self.verbose: + print(f'{n} pages need be processed...') + + for idx in range(start, start + n): + blobs, pages = self._get_image_blobs(fitz_doc, pdf_doc, 1, idx) + layout = self.layout_parser.parse(blobs[0])[0] + blocks = self._allocate_semantic(pages[0], layout) + if not blocks: continue + + if self.with_columns: + sub_groups = self._divide_blocks_into_groups(blocks) + groups.extend(sub_groups) + else: + groups.append(blocks) + + if self.verbose: + count = idx - start + 1 + if count % 50 == 0: + elapse = round(time.time() - tic, 2) + tic = time.time() + print(f'process {count} pages used {elapse}sec...') + + groups = self._allocate_continuous(groups) + + if self.html_output_file: + self.save_to_html(groups, self.html_output_file) + return [] + + page_content = self._save_to_document(groups) + meta = {'source': os.path.basename(self.file_path)} + doc = Document(page_content=page_content, metadata=meta) + return [doc] diff --git a/src/langchain_contrib/langchain_contrib/document_loaders/parsers/__init__.py b/src/langchain_contrib/langchain_contrib/document_loaders/parsers/__init__.py new file mode 100644 index 0000000000..2b238f469a --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/document_loaders/parsers/__init__.py @@ -0,0 +1,5 @@ +from .image import LayoutParser + +__all__ = [ + "LayoutParser", +] diff --git a/src/langchain_contrib/langchain_contrib/document_loaders/parsers/image.py b/src/langchain_contrib/langchain_contrib/document_loaders/parsers/image.py new file mode 100644 index 0000000000..1f6369ff49 --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/document_loaders/parsers/image.py @@ -0,0 +1,31 @@ +import time +from typing import Iterator, Optional, List +import json + +from langchain.document_loaders.blob_loaders import Blob +from langchain.schema import Document + +import requests +import base64 + + +class LayoutParser(object): + """Parse image layout structure. + """ + + def __init__(self, + api_key: Optional[str] = None, + api_base_url: Optional[str] = None): + self.api_key = api_key + self.api_base_url = "http://192.168.106.20:14569/predict" + self.class_name = [ + '印章', '图片', '标题', '段落', '表格', '页眉', '页码', '页脚' + ] + + def parse(self, blob: Blob) -> List[Document]: + b64_data = base64.b64encode(blob.as_bytes()).decode() + data = {'img': b64_data} + resp = requests.post("http://192.168.106.20:14569/predict", data=data) + content = resp.json() + doc = Document(page_content=json.dumps(content), metadata={}) + return [doc] \ No newline at end of file diff --git a/src/langchain_contrib/langchain_contrib/document_loaders/parsers/test_image.py b/src/langchain_contrib/langchain_contrib/document_loaders/parsers/test_image.py new file mode 100644 index 0000000000..b99924e9a4 --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/document_loaders/parsers/test_image.py @@ -0,0 +1,290 @@ +import json +import logging +import os +import tempfile +import time +from abc import ABC +import io +from pathlib import Path +from typing import Any, Iterator, List, Mapping, Optional, Union +from urllib.parse import urlparse +import requests +import random +from copy import deepcopy + +import pypdfium2 +import fitz +import cv2 +import numpy as np + +from langchain.document_loaders.blob_loaders import Blob + +from image import LayoutParser + + +def norm_rect(bbox): + x0 = np.min([bbox[0], bbox[2]]) + x1 = np.max([bbox[0], bbox[2]]) + y0 = np.min([bbox[1], bbox[3]]) + y1 = np.max([bbox[1], bbox[3]]) + return np.asarray([x0, y0, x1, y1]) + + +def merge_rects(bboxes): + x0 = np.min(bboxes[:, 0]) + y0 = np.min(bboxes[:, 1]) + x1 = np.max(bboxes[:, 2]) + y1 = np.max(bboxes[:, 3]) + return [x0, y0, x1, y1] + + +def get_image_blobs(pages, pdf_reader, n, start=0): + blobs = [] + for pg in range(start, start + n): + bytes_img = None + page = pages.load_page(pg) + mat = fitz.Matrix(1, 1) + try: + pm = page.get_pixmap(matrix=mat, alpha=False) + bytes_img = pm.getPNGData() + except Exception: + # some pdf input cannot get render image from fitz + page = pdf_reader.get_page(pg) + pil_image = page.render().to_pil() + img_byte_arr = io.BytesIO() + pil_image.save(img_byte_arr, format='PNG') + bytes_img = img_byte_arr.getvalue() + + blobs.append(Blob(data=bytes_img)) + return blobs + + +def test(): + file_path = './data/达梦数据库招股说明书_test_v1.pdf' + blob = Blob.from_path(file_path) + pages = None + image_blobs = [] + with blob.as_bytes_io() as file_path: + pages = fitz.open(file_path) + pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True) + image_blobs = get_image_blobs(pages, pdf_reader) + + assert len(image_blobs) == pages.page_count + layout = LayoutParser() + res = layout.parse(image_blobs[0]) + + +def draw_polygon(image, bbox, text=None, color=(255, 0, 0), thickness=1): + bbox = bbox.astype(np.int32) + is_rect = bbox.shape[0] == 4 + if is_rect: + start_point = (bbox[0], bbox[1]) + end_point = (bbox[2], bbox[3]) + image = cv2.rectangle(image, start_point, end_point, color, thickness) + else: + polys = [bbox.astype(np.int32).reshape((-1, 1, 2))] + cv2.polylines(image, polys, True, color=color, thickness=thickness) + start_point = (polys[0][0, 0, 0], polys[0][0, 0, 1]) + + if text: + fontFace = cv2.FONT_HERSHEY_SIMPLEX + fontScale = 0.5 + color = (0, 0, 255) + image = cv2.putText(image, text, start_point, fontFace, fontScale, + color, 1) + + return image + + +def test_vis(): + # file_path = './data/达梦数据库招股说明书_test_v1.pdf' + file_path = './data/pdf_input/《中国药典》2020年版 一部.pdf' + output_prefix = 'zhongguoyaodian_2020_v1' + start, end, n = 70, 80, 10 + blob = Blob.from_path(file_path) + pages = None + image_blobs = [] + with blob.as_bytes_io() as file_path: + pages = fitz.open(file_path) + pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True) + image_blobs = get_image_blobs(pages, pdf_reader, n, start) + + assert len(image_blobs) == n + + for i, blob in enumerate(image_blobs): + idx = i + start + # blob = image_blobs[2] + layout = LayoutParser() + out = layout.parse(blob) + res = json.loads(out[0].page_content) + bboxes = [] + labels = [] + for r in res: + bboxes.append(r['bbox']) + labels.append(str(r['category_id'])) + + bboxes = np.asarray(bboxes) + + bytes_arr = np.frombuffer(blob.as_bytes(), dtype=np.uint8) + image = cv2.imdecode(bytes_arr, flags=1) + for bbox, text in zip(bboxes, labels): + image = draw_polygon(image, bbox, text) + + outf = f'./data/{output_prefix}_layout_p{idx+1}_vis.png' + cv2.imwrite(outf, image) + + +def order_by_tbyx(block_info, th=10): + """ + block_info: [(b0, b1, b2, b3, text, x, y)+] + th: threshold of the position threshold + """ + # sort using y1 first and then x1 + res = sorted(block_info, key=lambda b: (b[1], b[0])) + for i in range(len(res) - 1): + for j in range(i, 0, -1): + # restore the order using the + if (abs(res[j + 1][1] - res[j][1]) < th and + (res[j + 1][0] < res[j][0])): + tmp = deepcopy(res[j]) + res[j] = deepcopy(res[j + 1]) + res[j + 1] = deepcopy(tmp) + else: + break + return res + + +def test_vis2(): + # file_path = './data/达梦数据库招股说明书_test_v1.pdf' + file_path = './data/pdf_input/达梦数据库招股说明书.pdf' + output_prefix = 'dameng_pageblock' + + start = 0 + end = 10 + n = end - start + blob = Blob.from_path(file_path) + pages = None + image_blobs = [] + with blob.as_bytes_io() as file_path: + pages = fitz.open(file_path) + pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True) + image_blobs = get_image_blobs(pages, pdf_reader, n, start) + + assert len(image_blobs) == pages.page_count + + for i, blob in enumerate(image_blobs): + idx = i + start + page = pages.load_page(idx) + + rect = page.rect + print('rect', rect) + o = 10 + b0 = np.asarray([rect.x0 + o, rect.y0 + o, + rect.x1 - o, rect.y1 - o]) + + bytes_arr = np.frombuffer(blob.as_bytes(), dtype=np.uint8) + image = cv2.imdecode(bytes_arr, flags=1) + + image = draw_polygon(image, b0, '0.0') + + textpage = page.get_textpage() + blocks = textpage.extractBLOCKS() + IMG_BLOCK_TYPE = 1 + + # blocks = order_by_tbyx(blocks) + bboxes = [] + for off, b in enumerate(blocks): + label = 'text' if b[-1] != IMG_BLOCK_TYPE else 'image' + label = f'{label}-{off}' + print('block', b, label) + bbox = np.asarray([b[0], b[1], b[2], b[3]]) + bboxes.append(bbox) + + image = draw_polygon(image, bbox, label) + + if bboxes: + b1 = merge_rects(np.asarray(bboxes)) + b1 = np.asarray(b1) + image = draw_polygon(image, b1, '0.1') + + outf = f'./data/{output_prefix}_p{idx}_vis.png' + cv2.imwrite(outf, image) + + + +def test_vis3(): + file_path = './data/pdf_input/《中国药典》2020年版 一部.pdf' + + start = 50 + end = 60 + n = end - start + output_prefix = 'zhongguoyaodian_2020_v1' + + blob = Blob.from_path(file_path) + pages = None + image_blobs = [] + with blob.as_bytes_io() as file_path: + pages = fitz.open(file_path) + pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True) + image_blobs = get_image_blobs(pages, pdf_reader, n, start=50) + + assert len(image_blobs) == n + + for i, blob in enumerate(image_blobs): + idx = i + start + page = pages.load_page(idx) + + rect = page.rect + print('rect', rect) + o = 10 + b0 = np.asarray([rect.x0 + o, rect.y0 + o, + rect.x1 - o, rect.y1 - o]) + + bytes_arr = np.frombuffer(blob.as_bytes(), dtype=np.uint8) + image = cv2.imdecode(bytes_arr, flags=1) + + image = draw_polygon(image, b0, '0.0') + + rotation_matrix = np.asarray(page.rotation_matrix).reshape((3, 2)) + c1 = (rotation_matrix[0, 0] - 1) <= 1e-6 + c2 = (rotation_matrix[1, 1] - 1) <= 1e-6 + is_rotated = c1 and c2 + + textpage = page.get_textpage() + blocks = textpage.extractBLOCKS() + IMG_BLOCK_TYPE = 1 + + # blocks = order_by_tbyx(blocks) + bboxes = [] + for off, b in enumerate(blocks): + label = 'text' if b[-1] != IMG_BLOCK_TYPE else 'image' + label = f'{label}-{off}' + print('block', b, label) + bbox = np.asarray([b[0], b[1], b[2], b[3]]) + + aug_bbox = bbox.reshape((-1, 2)) + padding = np.ones((len(aug_bbox), 1)) + aug_bbox = np.hstack([aug_bbox, padding]) + new_bbox = np.dot(aug_bbox, rotation_matrix).reshape(-1) + + new_bbox = norm_rect(new_bbox) + + print('new_bboxes', new_bbox) + bboxes.append(new_bbox) + + image = draw_polygon(image, new_bbox, label) + + print(bboxes) + if bboxes: + b1 = merge_rects(np.asarray(bboxes)) + b1 = np.asarray(b1) + image = draw_polygon(image, b1, '0.1') + + outf = f'./data/{output_prefix}_p{idx}_vis.png' + cv2.imwrite(outf, image) + + +# test_vis3() +# test_vis2() +test_vis() +# test() \ No newline at end of file diff --git a/src/langchain_contrib/langchain_contrib/embeddings/__init__.py b/src/langchain_contrib/langchain_contrib/embeddings/__init__.py new file mode 100644 index 0000000000..ca36abf5e9 --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/embeddings/__init__.py @@ -0,0 +1,7 @@ +from .wenxin import WenxinEmbeddings +from .host_embedding import ME5Embedding, BGEZhEmbedding, GTEEmbedding, HostEmbeddings + +__all__ = [ + 'WenxinEmbeddings', 'ME5Embedding', 'BGEZhEmbedding', 'GTEEmbedding', + 'HostEmbeddings' +] \ No newline at end of file diff --git a/src/langchain_contrib/langchain_contrib/embeddings/host_embedding.py b/src/langchain_contrib/langchain_contrib/embeddings/host_embedding.py new file mode 100644 index 0000000000..d004c0bee7 --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/embeddings/host_embedding.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +import logging +import warnings +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Sequence, + Set, + Tuple, + Union, +) + +import requests +from requests.exceptions import HTTPError +from pydantic import BaseModel, Extra, Field, root_validator + +from requests.exceptions import HTTPError + +from tenacity import ( + AsyncRetrying, + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from langchain.embeddings.base import Embeddings +from langchain.utils import get_from_dict_or_env + +logger = logging.getLogger(__name__) + + +def _create_retry_decorator(embeddings: HostEmbeddings) -> Callable[[Any], Any]: + min_seconds = 4 + max_seconds = 10 + # Wait 2^x * 1 second between each retry starting with + # 4 seconds, then up to 10 seconds, then 10 seconds afterwards + return retry( + reraise=True, + stop=stop_after_attempt(embeddings.max_retries), + wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), + retry=( + retry_if_exception_type(Exception) + ), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + + +def embed_with_retry(embeddings: HostEmbeddings, **kwargs: Any) -> Any: + """Use tenacity to retry the embedding call.""" + retry_decorator = _create_retry_decorator(embeddings) + + @retry_decorator + def _embed_with_retry(**kwargs: Any) -> Any: + return embeddings.embed(**kwargs) + return _embed_with_retry(**kwargs) + + +class HostEmbeddings(BaseModel, Embeddings): + """host embedding models. + """ + + client: Optional[Any] #: :meta private: + + """Model name to use.""" + model: str = "embedding-host" + host_base_url: str = None + + deployment: Optional[str] = 'default' + + embedding_ctx_length: Optional[int] = 6144 + """The maximum number of tokens to embed at once.""" + + """Maximum number of texts to embed in each batch""" + max_retries: Optional[int] = 6 + """Maximum number of retries to make when generating.""" + request_timeout: Optional[Union[float, Tuple[float, float]]] = None + """Timeout in seconds for the OpenAPI request.""" + + model_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict) + """Holds any model parameters valid for `create` call not explicitly specified.""" + + verbose: Optional[bool] = False + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + values["host_base_url"] = get_from_dict_or_env( + values, "host_base_url", "HostBaseUrl" + ) + + try: + values["client"] = requests.post + except AttributeError: + raise ValueError( + "Try upgrading it with `pip install --upgrade requests`." + ) + return values + + @property + def _invocation_params(self) -> Dict: + api_args = { + "model": self.model, + "request_timeout": self.request_timeout, + **self.model_kwargs, + } + return api_args + + def embed(self, texts: List[str], **kwargs) -> List[List[float]]: + emb_type = kwargs.get('type', 'raw') + inp = {'texts': texts, 'model': self.model, 'type': emb_type} + if self.verbose: + print('payload', inp) + + url = f"{self.host_base_url}/{self.model}/infer" + outp = self.client(url=url, json=inp).json() + if outp['status_code'] != 200: + raise ValueError( + f"API returned an error: {outp['status_message']}" + ) + return outp['embeddings'] + + def embed_documents( + self, texts: List[str], chunk_size: Optional[int] = 0 + ) -> List[List[float]]: + embeddings = embed_with_retry(self, texts=texts, type='doc') + return embeddings + + def embed_query(self, text: str) -> List[float]: + embeddings = embed_with_retry(self, texts=[text], type='query') + return embeddings[0] + + +class ME5Embedding(HostEmbeddings): + model: str = "multi-e5" + embedding_ctx_length: int = 512 + + +class BGEZhEmbedding(HostEmbeddings): + model: str = "bge-zh" + embedding_ctx_length: int = 512 + + +class GTEEmbedding(HostEmbeddings): + model: str = "gte" + embedding_ctx_length: int = 512 diff --git a/src/langchain_contrib/langchain_contrib/embeddings/interface/__init__.py b/src/langchain_contrib/langchain_contrib/embeddings/interface/__init__.py new file mode 100644 index 0000000000..e3ed81dabd --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/embeddings/interface/__init__.py @@ -0,0 +1,3 @@ +from .wenxin import EmbeddingClient as WenxinEmbeddingClient + +__all__ = ['WenxinEmbeddingClient'] diff --git a/src/langchain_contrib/langchain_contrib/embeddings/interface/types.py b/src/langchain_contrib/langchain_contrib/embeddings/interface/types.py new file mode 100644 index 0000000000..80829e8718 --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/embeddings/interface/types.py @@ -0,0 +1,22 @@ +from pydantic import BaseModel +from typing import List, Union, Any, Dict + + +class EmbeddingInput(BaseModel): + model: str + input: Union[str, List[str]] + + +class Embedding(BaseModel): + object: str = 'embedding' + embedding: List[float] + index: int + + +class EmbeddingOutput(BaseModel): + status_code: int + status_message: str = 'success' + object: str = None + data: List[Embedding] = [] + model: str = None + usage: Dict[str, Any] = None diff --git a/src/langchain_contrib/langchain_contrib/embeddings/interface/wenxin.py b/src/langchain_contrib/langchain_contrib/embeddings/interface/wenxin.py new file mode 100644 index 0000000000..2e896b3962 --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/embeddings/interface/wenxin.py @@ -0,0 +1,90 @@ +import requests +from requests.exceptions import HTTPError +import json +import numpy as np + +def get_access_token(api_key, sec_key): + url = (f"https://aip.baidubce.com/oauth/2.0/token?" + f"grant_type=client_credentials" + f"&client_id={api_key}&client_secret={sec_key}") + + payload = json.dumps("") + headers = { + 'Content-Type': 'application/json', + 'Accept': 'application/json' + } + + response = requests.request("POST", url, headers=headers, data=payload) + return response.json().get("access_token") + + +class EmbeddingClient(object): + def __init__(self, api_key, sec_key, **kwargs): + self.api_key = api_key + self.sec_key = sec_key + self.ep_url = ( + "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/" + "wenxinworkshop/embeddings") + self.headers = { + 'Content-Type': 'application/json' + } + self.max_text_tokens = 384 + self.max_text_num = 16 + self.drop_exceed_token = kwargs.get('drop_exceed_token', True) + + def create(self, model, input, verbose=False, **kwargs): + texts = input + if isinstance(texts, str): + texts = [texts] + + if self.drop_exceed_token: + texts = [t[:self.max_text_tokens] for t in texts] + + cond = np.all([len(text) <= self.max_text_tokens for text in texts]) + if not cond: + raise HTTPError('text exceed max token size 384') + + token = get_access_token(self.api_key, self.sec_key) + endpoint = f"{self.ep_url}/{model}?access_token={token}" + + def _call(sub_texts): + payload = json.dumps({ + "input": sub_texts + }) + response = requests.post(endpoint, headers=self.headers, data=payload) + status_message = 'success' + status_code = response.status_code + usage = {'prompt_tokens': 0, 'total_tokens': 0} + data = [] + if status_code == 200: + try: + info = json.loads(response.text) + status_code = info.get('error_code', 200) + status_message = info.get('error_msg', status_message) + if status_code == 200: + data = info['data'] + usage = info['usage'] + else: + raise HTTPError(status_message) + except Exception as e: + raise HTTPError(str(e)) + else: + raise HTTPError('requests error') + return data, usage + + data = [] + usage = {'prompt_tokens': 0, 'total_tokens': 0} + + for i in range(0, len(texts), self.max_text_num): + sub_texts = texts[i: (i + self.max_text_num)] + sub_data, sub_usage = _call(sub_texts) + data.extend(sub_data) + usage['prompt_tokens'] += sub_usage['prompt_tokens'] + usage['total_tokens'] += sub_usage['total_tokens'] + + outp = dict( + status_code=200, + model=model, + data=data, + usage=usage) + return outp diff --git a/src/langchain_contrib/langchain_contrib/embeddings/wenxin.py b/src/langchain_contrib/langchain_contrib/embeddings/wenxin.py new file mode 100644 index 0000000000..bc0e305f3f --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/embeddings/wenxin.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +import logging +import warnings +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Sequence, + Set, + Tuple, + Union, +) + +import numpy as np +from pydantic import BaseModel, Extra, Field, root_validator +from requests.exceptions import HTTPError +from tenacity import ( + AsyncRetrying, + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from langchain.embeddings.base import Embeddings +from langchain.utils import get_from_dict_or_env + +logger = logging.getLogger(__name__) + + +def _create_retry_decorator(embeddings: WenxinEmbeddings) -> Callable[[Any], Any]: + min_seconds = 4 + max_seconds = 10 + # Wait 2^x * 1 second between each retry starting with + # 4 seconds, then up to 10 seconds, then 10 seconds afterwards + return retry( + reraise=True, + stop=stop_after_attempt(embeddings.max_retries), + wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), + retry=(retry_if_exception_type(HTTPError)), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + + +def embed_with_retry(embeddings: WenxinEmbeddings, **kwargs: Any) -> Any: + """Use tenacity to retry the embedding call.""" + retry_decorator = _create_retry_decorator(embeddings) + + @retry_decorator + def _embed_with_retry(**kwargs: Any) -> Any: + return embeddings.embed(**kwargs) + return _embed_with_retry(**kwargs) + + +class WenxinEmbeddings(BaseModel, Embeddings): + """Wenxin embedding models. + + To use, the environment variable ``WENXIN_API_KEY`` and ``WENXIN_SECRET_KEY`` + set with your API key or pass it as a named parameter to the constructor. + + Example: + .. code-block:: python + from langchain_contrib.embeddings import WenxinEmbeddings + wenxin_embeddings = WenxinEmbeddings( + wenxin_api_key="my-api-key", + wenxin_secret_key='xxx') + + """ + + client: Optional[Any] #: :meta private: + model: str = "embedding-v1" + + deployment: Optional[str] = 'default' + wenxin_api_key: Optional[str] = None + wenxin_secret_key: Optional[str] = None + + embedding_ctx_length: Optional[int] = 6144 + """The maximum number of tokens to embed at once.""" + + """Maximum number of texts to embed in each batch""" + max_retries: Optional[int] = 6 + """Maximum number of retries to make when generating.""" + request_timeout: Optional[Union[float, Tuple[float, float]]] = None + """Timeout in seconds for the OpenAPI request.""" + + model_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict) + """Holds any model parameters valid for `create` call not explicitly specified.""" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + values["wenxin_api_key"] = get_from_dict_or_env( + values, "wenxin_api_key", "WENXIN_API_KEY" + ) + values["wenxin_secret_key"] = get_from_dict_or_env( + values, + "wenxin_secret_key", + "WENXIN_SECRET_KEY", + ) + + api_key = values["wenxin_api_key"] + sec_key = values["wenxin_secret_key"] + try: + from .interface import WenxinEmbeddingClient + values["client"] = WenxinEmbeddingClient( + api_key=api_key, sec_key=sec_key) + except AttributeError: + raise ValueError( + "Try upgrading it with `pip install --upgrade requests`." + ) + return values + + @property + def _invocation_params(self) -> Dict: + wenxin_args = { + "model": self.model, + "request_timeout": self.request_timeout, + **self.model_kwargs, + } + + return wenxin_args + + + def embed(self, texts: List[str]) -> List[List[float]]: + inp = {'input': texts, 'model': self.model} + outp = self.client.create(**inp) + if outp['status_code'] != 200: + raise ValueError( + f"Wenxin API returned an error: {outp['status_message']}" + ) + return [e['embedding'] for e in outp['data']] + + def embed_documents( + self, texts: List[str], chunk_size: Optional[int] = 0 + ) -> List[List[float]]: + embeddings = embed_with_retry(self, texts=texts) + return embeddings + + def embed_query(self, text: str) -> List[float]: + """Call out to OpenAI's embedding endpoint for embedding query text. + + Args: + text: The text to embed. + + Returns: + Embedding for the text. + """ + + embeddings = embed_with_retry(self, texts=[text]) + return embeddings[0] + diff --git a/src/langchain_contrib/langchain_contrib/vectorstores/__init__.py b/src/langchain_contrib/langchain_contrib/vectorstores/__init__.py new file mode 100644 index 0000000000..8d415501ed --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/vectorstores/__init__.py @@ -0,0 +1,3 @@ +from .elastic_keywords_search import ElasticKeywordsSearch + +__all__=['ElasticKeywordsSearch'] diff --git a/src/langchain_contrib/langchain_contrib/vectorstores/elastic_keywords_search.py b/src/langchain_contrib/langchain_contrib/vectorstores/elastic_keywords_search.py new file mode 100644 index 0000000000..30af51a2e5 --- /dev/null +++ b/src/langchain_contrib/langchain_contrib/vectorstores/elastic_keywords_search.py @@ -0,0 +1,292 @@ +"""Wrapper around Elasticsearch vector database.""" +from __future__ import annotations + +import uuid +import jieba.analyse +from abc import ABC +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Mapping, + Optional, + Tuple, + Union, +) + +from langchain.docstore.document import Document +from langchain.utils import get_from_dict_or_env +from langchain.vectorstores.base import VectorStore + +if TYPE_CHECKING: + from elasticsearch import Elasticsearch + + +def _default_text_mapping() -> Dict: + return { + "properties": { + "text": {"type": "text"} + } + } + + +# ElasticKeywordsSearch is a concrete implementation of the abstract base class +# VectorStore, which defines a common interface for all vector database +# implementations. By inheriting from the ABC class, ElasticKeywordsSearch can be +# defined as an abstract base class itself, allowing the creation of subclasses with +# their own specific implementations. If you plan to subclass ElasticKeywordsSearch, +# you can inherit from it and define your own implementation of the necessary methods +# and attributes. +class ElasticKeywordsSearch(VectorStore, ABC): + """Wrapper around Elasticsearch as a vector database. + + To connect to an Elasticsearch instance that does not require + login credentials, pass the Elasticsearch URL and index name along with the + + Example: + .. code-block:: python + + from langchain import ElasticKeywordsSearch + + elastic_vector_search = ElasticKeywordsSearch( + elasticsearch_url="http://localhost:9200", + index_name="test_index", + ) + + + To connect to an Elasticsearch instance that requires login credentials, + including Elastic Cloud, use the Elasticsearch URL format + https://username:password@es_host:9243. For example, to connect to Elastic + Cloud, create the Elasticsearch URL with the required authentication details and + pass it to the ElasticKeywordsSearch constructor as the named parameter + elasticsearch_url. + + You can obtain your Elastic Cloud URL and login credentials by logging in to the + Elastic Cloud console at https://cloud.elastic.co, selecting your deployment, and + navigating to the "Deployments" page. + + To obtain your Elastic Cloud password for the default "elastic" user: + + 1. Log in to the Elastic Cloud console at https://cloud.elastic.co + 2. Go to "Security" > "Users" + 3. Locate the "elastic" user and click "Edit" + 4. Click "Reset password" + 5. Follow the prompts to reset the password + + The format for Elastic Cloud URLs is + https://username:password@cluster_id.region_id.gcp.cloud.es.io:9243. + + Example: + .. code-block:: python + + from langchain import ElasticKeywordsSearch + elastic_host = "cluster_id.region_id.gcp.cloud.es.io" + elasticsearch_url = f"https://username:password@{elastic_host}:9243" + elastic_keywords_search = ElasticKeywordsSearch( + elasticsearch_url=elasticsearch_url, + index_name="test_index" + ) + + Args: + elasticsearch_url (str): The URL for the Elasticsearch instance. + index_name (str): The name of the Elasticsearch index for the keywords. + + Raises: + ValueError: If the elasticsearch python package is not installed. + """ + + def __init__( + self, + elasticsearch_url: str, + index_name: str, + *, + ssl_verify: Optional[Dict[str, Any]] = None, + ): + """Initialize with necessary components.""" + try: + import elasticsearch + except ImportError: + raise ImportError( + "Could not import elasticsearch python package. " + "Please install it with `pip install elasticsearch`." + ) + self.index_name = index_name + _ssl_verify = ssl_verify or {} + try: + self.client = elasticsearch.Elasticsearch(elasticsearch_url, **_ssl_verify) + except ValueError as e: + raise ValueError( + f"Your elasticsearch client string is mis-formatted. Got error: {e} " + ) + + def add_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + refresh_indices: bool = True, + **kwargs: Any, + ) -> List[str]: + """Run more texts through the keywords and add to the vectorstore. + + Args: + texts: Iterable of strings to add to the vectorstore. + metadatas: Optional list of metadatas associated with the texts. + ids: Optional list of unique IDs. + refresh_indices: bool to refresh ElasticSearch indices + + Returns: + List of ids from adding the texts into the vectorstore. + """ + try: + from elasticsearch.exceptions import NotFoundError + from elasticsearch.helpers import bulk + except ImportError: + raise ImportError( + "Could not import elasticsearch python package. " + "Please install it with `pip install elasticsearch`." + ) + requests = [] + ids = ids or [str(uuid.uuid4()) for _ in texts] + mapping = _default_text_mapping() + + # check to see if the index already exists + try: + self.client.indices.get(index=self.index_name) + except NotFoundError: + # TODO would be nice to create index before embedding, + # just to save expensive steps for last + self.create_index(self.client, self.index_name, mapping) + + for i, text in enumerate(texts): + metadata = metadatas[i] if metadatas else {} + request = { + "_op_type": "index", + "_index": self.index_name, + "text": text, + "metadata": metadata, + "_id": ids[i], + } + requests.append(request) + bulk(self.client, requests) + + if refresh_indices: + self.client.indices.refresh(index=self.index_name) + return ids + + def similarity_search( + self, + query: str, + k: int = 4, + query_strategy: str = 'match_phrase', + must_or_should: str = 'should', + **kwargs: Any + ) -> List[Document]: + assert must_or_should in ['must', 'should'], 'only support must and should.' + keywords = jieba.analyse.extract_tags(query, topK=10, withWeight=False) + match_query = {'bool': {must_or_should: []}} + for key in keywords: + match_query['bool'][must_or_should].append({query_strategy: {'text': key}}) + docs_and_scores = self.similarity_search_with_score(match_query, k) + documents = [d[0] for d in docs_and_scores] + return documents + + def similarity_search_with_score( + self, query: str, k: int = 4, **kwargs: Any + ) -> List[Tuple[Document, float]]: + response = self.client_search( + self.client, self.index_name, query, size=k + ) + hits = [hit for hit in response["hits"]["hits"]] + docs_and_scores = [ + ( + Document( + page_content=hit["_source"]["text"], + metadata=hit["_source"]["metadata"], + ), + hit["_score"], + ) + for hit in hits + ] + return docs_and_scores + + @classmethod + def from_texts( + cls, + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + index_name: Optional[str] = None, + refresh_indices: bool = True, + **kwargs: Any, + ) -> ElasticKeywordsSearch: + """Construct ElasticKeywordsSearch wrapper from raw documents. + + This is a user-friendly interface that: + 1. Embeds documents. + 2. Creates a new index for the embeddings in the Elasticsearch instance. + 3. Adds the documents to the newly created Elasticsearch index. + + This is intended to be a quick way to get started. + + Example: + .. code-block:: python + + from langchain import ElasticKeywordsSearch + from langchain.embeddings import OpenAIEmbeddings + embeddings = OpenAIEmbeddings() + elastic_vector_search = ElasticKeywordsSearch.from_texts( + texts, + embeddings, + elasticsearch_url="http://localhost:9200" + ) + """ + elasticsearch_url = get_from_dict_or_env( + kwargs, "elasticsearch_url", "ELASTICSEARCH_URL" + ) + if "elasticsearch_url" in kwargs: + del kwargs["elasticsearch_url"] + index_name = index_name or uuid.uuid4().hex + vectorsearch = cls(elasticsearch_url, index_name, **kwargs) + vectorsearch.add_texts( + texts, metadatas=metadatas, ids=ids, refresh_indices=refresh_indices + ) + return vectorsearch + + def create_index(self, client: Any, index_name: str, mapping: Dict) -> None: + version_num = client.info()["version"]["number"][0] + version_num = int(version_num) + if version_num >= 8: + client.indices.create(index=index_name, mappings=mapping) + else: + client.indices.create(index=index_name, body={"mappings": mapping}) + + def client_search( + self, client: Any, index_name: str, script_query: Dict, size: int + ) -> Any: + version_num = client.info()["version"]["number"][0] + version_num = int(version_num) + if version_num >= 8: + response = client.search(index=index_name, query=script_query, size=size) + else: + response = client.search( + index=index_name, body={"query": script_query, "size": size} + ) + return response + + def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None: + """Delete by vector IDs. + + Args: + ids: List of ids to delete. + """ + + if ids is None: + raise ValueError("No ids provided to delete.") + + # TODO: Check if this can be done in bulk + for id in ids: + self.client.delete(index=self.index_name, id=id) diff --git a/src/langchain_contrib/requirements.txt b/src/langchain_contrib/requirements.txt new file mode 100644 index 0000000000..e7a9b76399 --- /dev/null +++ b/src/langchain_contrib/requirements.txt @@ -0,0 +1,5 @@ +langchain +openai +zhipuai +websocket-client +elasticsearch \ No newline at end of file diff --git a/src/langchain_contrib/setup.py b/src/langchain_contrib/setup.py new file mode 100644 index 0000000000..01a9c2a465 --- /dev/null +++ b/src/langchain_contrib/setup.py @@ -0,0 +1,87 @@ +# Copyright (c) 2020 langchain_contrib Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import io +import os + +import setuptools + + +def read_requirements_file(filepath): + with open(filepath) as fin: + requirements = fin.read() + return requirements + + +extras = {} +REQUIRED_PACKAGES = read_requirements_file("requirements.txt") + + +def read(*names, **kwargs): + with io.open(os.path.join(os.path.dirname(__file__), *names), encoding=kwargs.get("encoding", "utf8")) as fp: + return fp.read() + + +def get_package_data_files(package, data, package_dir=None): + """ + Helps to list all specified files in package including files in directories + since `package_data` ignores directories. + """ + if package_dir is None: + package_dir = os.path.join(*package.split(".")) + all_files = [] + for f in data: + path = os.path.join(package_dir, f) + if os.path.isfile(path): + all_files.append(f) + continue + for root, _dirs, files in os.walk(path, followlinks=True): + root = os.path.relpath(root, package_dir) + for file in files: + file = os.path.join(root, file) + if file not in all_files: + all_files.append(file) + return all_files + + +setuptools.setup( + name="langchain_contrib", + version="0.0.1", + author="DataElem", + author_email="contact@dataelem.com", + description="langchain's extra modules", + long_description=read("README.md"), + long_description_content_type="text/markdown", + url="https://github.com/dataelement/langchain_contrib", + packages=setuptools.find_packages( + exclude=("examples*", "tests*", "applications*", "model_zoo*"), + ), + package_data={ + }, + setup_requires=[], + install_requires=REQUIRED_PACKAGES, + entry_points={ + }, + extras_require=extras, + python_requires=">=3.6", + classifiers=[ + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + ], + license="Apache 2.0", +)