Skip to content

Commit

Permalink
Merge pull request #4 from hwchase17/harrison/add_llms
Browse files Browse the repository at this point in the history
add llm objects
  • Loading branch information
hwchase17 committed Oct 17, 2022
2 parents 97ba020 + f1d60b9 commit 4cc39aa
Show file tree
Hide file tree
Showing 14 changed files with 199 additions and 2 deletions.
7 changes: 5 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.PHONY: format lint tests
.PHONY: format lint tests integration_tests

format:
black .
Expand All @@ -11,4 +11,7 @@ lint:
mypy .

tests:
pytest tests
pytest tests/unit_tests

integration_tests:
pytest tests/integration_tests
1 change: 1 addition & 0 deletions langchain/llms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Wrappers on top of large language models."""
11 changes: 11 additions & 0 deletions langchain/llms/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Base interface for large language models to expose."""
from abc import ABC, abstractmethod
from typing import List, Optional


class LLM(ABC):
"""LLM wrapper should take in a prompt and return a string."""

@abstractmethod
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Run the LLM on the given prompt and input."""
72 changes: 72 additions & 0 deletions langchain/llms/cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Wrapper around Cohere APIs."""
import os
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Extra, root_validator

from langchain.llms.base import LLM


def remove_stop_tokens(text: str, stop: List[str]) -> str:
"""Remove stop tokens, should they occur at end."""
for s in stop:
if text.endswith(s):
return text[: -len(s)]
return text


class Cohere(BaseModel, LLM):
"""Wrapper around Cohere large language models."""

client: Any
model: str = "gptd-instruct-tft"
max_tokens: int = 256
temperature: float = 0.6
k: int = 0
p: int = 1
frequency_penalty: int = 0
presence_penalty: int = 0

class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid

@root_validator()
def template_is_valid(cls, values: Dict) -> Dict:
"""Validate that api key python package exists in environment."""
if "COHERE_API_KEY" not in os.environ:
raise ValueError(
"Did not find Cohere API key, please add an environment variable"
" `COHERE_API_KEY` which contains it."
)
try:
import cohere

values["client"] = cohere.Client(os.environ["COHERE_API_KEY"])
except ImportError:
raise ValueError(
"Could not import cohere python package. "
"Please it install it with `pip install cohere`."
)
return values

def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Call out to Cohere's generate endpoint."""
response = self.client.generate(
model=self.model,
prompt=prompt,
max_tokens=self.max_tokens,
temperature=self.temperature,
k=self.k,
p=self.p,
frequency_penalty=self.frequency_penalty,
presence_penalty=self.presence_penalty,
stop_sequences=stop,
)
text = response.generations[0].text
# If stop tokens are provided, Cohere's endpoint returns them.
# In order to make this consistent with other endpoints, we strip them.
if stop is not None:
text = remove_stop_tokens(text, stop)
return text
65 changes: 65 additions & 0 deletions langchain/llms/openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Wrapper around OpenAI APIs."""
import os
from typing import Any, Dict, List, Mapping, Optional

from pydantic import BaseModel, Extra, root_validator

from langchain.llms.base import LLM


class OpenAI(BaseModel, LLM):
"""Wrapper around OpenAI large language models."""

client: Any
model_name: str = "text-davinci-002"
temperature: float = 0.7
max_tokens: int = 256
top_p: int = 1
frequency_penalty: int = 0
presence_penalty: int = 0
n: int = 1
best_of: int = 1

class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key python package exists in environment."""
if "OPENAI_API_KEY" not in os.environ:
raise ValueError(
"Did not find OpenAI API key, please add an environment variable"
" `OPENAI_API_KEY` which contains it."
)
try:
import openai

values["client"] = openai.Completion
except ImportError:
raise ValueError(
"Could not import openai python package. "
"Please it install it with `pip install openai`."
)
return values

@property
def default_params(self) -> Mapping[str, Any]:
"""Get the default parameters for calling OpenAI API."""
return {
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"n": self.n,
"best_of": self.best_of,
}

def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Call out to OpenAI's create endpoint."""
response = self.client.create(
model=self.model_name, prompt=prompt, stop=stop, **self.default_params
)
return response["choices"][0]["text"]
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
-e .
pytest
pytest-dotenv
black
isort
mypy
flake8
flake8-docstrings
cohere
openai
1 change: 1 addition & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""All tests for this package."""
1 change: 1 addition & 0 deletions tests/integration_tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""All integration tests (tests that call out to an external API)."""
1 change: 1 addition & 0 deletions tests/integration_tests/llms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""All integration tests for LLM objects."""
10 changes: 10 additions & 0 deletions tests/integration_tests/llms/test_cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Test Cohere API wrapper."""

from langchain.llms.cohere import Cohere


def test_cohere_call() -> None:
"""Test valid call to cohere."""
llm = Cohere(max_tokens=10)
output = llm("Say foo:")
assert isinstance(output, str)
10 changes: 10 additions & 0 deletions tests/integration_tests/llms/test_openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Test OpenAI API wrapper."""

from langchain.llms.openai import OpenAI


def test_cohere_call() -> None:
"""Test valid call to cohere."""
llm = OpenAI(max_tokens=10)
output = llm("Say foo:")
assert isinstance(output, str)
1 change: 1 addition & 0 deletions tests/unit_tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""All unit tests (lightweight tests)."""
1 change: 1 addition & 0 deletions tests/unit_tests/llms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""All unit tests for LLM objects."""
17 changes: 17 additions & 0 deletions tests/unit_tests/llms/test_cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Test helper functions for Cohere API."""

from langchain.llms.cohere import remove_stop_tokens


def test_remove_stop_tokens() -> None:
"""Test removing stop tokens when they occur."""
text = "foo bar baz"
output = remove_stop_tokens(text, ["moo", "baz"])
assert output == "foo bar "


def test_remove_stop_tokens_none() -> None:
"""Test removing stop tokens when they do not occur."""
text = "foo bar baz"
output = remove_stop_tokens(text, ["moo"])
assert output == "foo bar baz"

0 comments on commit 4cc39aa

Please sign in to comment.