Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add llm match chain #8

Merged
merged 1 commit into from Oct 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Expand Up @@ -4,11 +4,11 @@ format:
black .
isort .

lint:
lint:
mypy .
black . --check
isort . --check
flake8 .
mypy .

tests:
pytest tests/unit_tests
Expand Down
4 changes: 4 additions & 0 deletions langchain/chains/llm_math/__init__.py
@@ -0,0 +1,4 @@
"""Chain that interprets a prompt and executes python code to do math.

Heavily borrowed from https://replit.com/@amasad/gptpy?v=1#main.py
"""
57 changes: 57 additions & 0 deletions langchain/chains/llm_math/base.py
@@ -0,0 +1,57 @@
"""Chain that interprets a prompt and executes python code to do math."""
from typing import Dict, List

from pydantic import BaseModel, Extra

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_math.prompt import PROMPT
from langchain.chains.python import PythonChain
from langchain.llms.base import LLM


class LLMMathChain(Chain, BaseModel):
    """Chain that interprets a prompt and executes python code to do math."""

    # Pydantic-declared fields; names are part of the public interface.
    llm: LLM  # model used to answer directly or emit a python snippet
    verbose: bool = False  # when True, print the snippet before executing it
    input_key: str = "question"
    output_key: str = "answer"

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Single input key this chain consumes."""
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Single output key this chain produces."""
        return [self.output_key]

    def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Query the LLM; if it replies with code, execute it for the answer.

        Raises:
            ValueError: if the completion is neither a ```python fence nor an
                ``Answer:`` line.
        """
        llm_executor = LLMChain(prompt=PROMPT, llm=self.llm)
        python_executor = PythonChain()
        question = inputs[self.input_key]
        # Stop at the ```output fence so the model cannot invent the result.
        completion = llm_executor.predict(
            question=question, stop=["```output"]
        ).strip()
        if completion.startswith("```python"):
            # Drop the leading ```python fence and the trailing "\n```".
            snippet = completion[len("```python") : -len("\n```")]
            if self.verbose:
                print("[DEBUG] evaluating code")
                print(snippet)
            answer = "Answer: " + python_executor.run(snippet)
        elif completion.startswith("Answer:"):
            # The model answered directly; pass its line through unchanged.
            answer = completion
        else:
            raise ValueError(f"unknown format from LLM: {completion}")
        return {self.output_key: answer}

    def run(self, question: str) -> str:
        """More user-friendly interface for interfacing with LLM math."""
        return self({self.input_key: question})[self.output_key]
38 changes: 38 additions & 0 deletions langchain/chains/llm_math/prompt.py
@@ -0,0 +1,38 @@
# flake8: noqa
"""Prompt for the LLM math chain.

Few-shot template instructing the model to either answer directly
(``Answer: ...``) or, for hard calculations, emit a fenced python block whose
printed output is fed back to it after the ``output`` fence.
"""
from langchain.prompt import Prompt

_PROMPT_TEMPLATE = """You are GPT-3, and you can't do math.

You can do basic math, and your memorization abilities are impressive, but you can't do any complex calculations that a human could not do in their head. You also have an annoying tendency to just make up highly specific, but wrong, answers.

So we hooked you up to a Python 3 kernel, and now you can execute code. If anyone gives you a hard math problem, just use this format and we'll take care of the rest:

Question: ${{Question with hard calculation.}}
```python
${{Code that prints what you need to know}}
```
```output
${{Output of your code}}
```
Answer: ${{Answer}}

Otherwise, use this simpler format:

Question: ${{Question without hard calculation}}
Answer: ${{Answer}}

Begin.

Question: What is 37593 * 67?

```python
print(37593 * 67)
```
```output
2518731
```
Answer: 2518731

Question: {question}"""

# Double braces above escape str.format; only {question} is substituted.
PROMPT = Prompt(input_variables=["question"], template=_PROMPT_TEMPLATE)
4 changes: 2 additions & 2 deletions langchain/chains/python.py
Expand Up @@ -19,12 +19,12 @@ class PythonChain(Chain, BaseModel):

@property
def input_keys(self) -> List[str]:
    """Expect input key."""
    # Fix: the merge left two docstring lines here (the old one was a stray
    # bare-string statement); keep only the current docstring.
    return [self.input_key]

@property
def output_keys(self) -> List[str]:
    """Return output key."""
    # Fix: the merge left two docstring lines here (the old one was a stray
    # bare-string statement); keep only the current docstring.
    return [self.output_key]

def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
Expand Down
40 changes: 40 additions & 0 deletions tests/unit_tests/chains/test_llm_math.py
@@ -0,0 +1,40 @@
"""Test LLM Math functionality."""

import pytest

from langchain.chains.llm_math.base import LLMMathChain
from langchain.chains.llm_math.prompt import _PROMPT_TEMPLATE
from tests.unit_tests.llms.fake_llm import FakeLLM


@pytest.fixture
def fake_llm_math_chain() -> LLMMathChain:
    """Fake LLM Math chain for testing."""
    # Canned prompt -> completion pairs covering all three _run branches:
    # direct answer, python snippet, and unknown format.
    responses = {
        _PROMPT_TEMPLATE.format(question="What is 1 plus 1?"): "Answer: 2",
        _PROMPT_TEMPLATE.format(
            question="What is the square root of 2?"
        ): "```python\nprint(2**.5)\n```",
        _PROMPT_TEMPLATE.format(question="foo"): "foo",
    }
    return LLMMathChain(
        llm=FakeLLM(queries=responses), input_key="q", output_key="a"
    )


def test_simple_question(fake_llm_math_chain: LLMMathChain) -> None:
    """Test simple question that should not need python."""
    # The fake LLM replies with a ready-made "Answer:" line for this prompt.
    assert fake_llm_math_chain.run("What is 1 plus 1?") == "Answer: 2"


def test_complex_question(fake_llm_math_chain: LLMMathChain) -> None:
    """Test complex question that should need python."""
    # print(2**.5) adds a trailing newline, hence the "\n" in the expectation.
    expected = f"Answer: {2**.5}\n"
    assert fake_llm_math_chain.run("What is the square root of 2?") == expected


def test_error(fake_llm_math_chain: LLMMathChain) -> None:
    """Test question that raises error."""
    # The fixture maps "foo" to the completion "foo", which starts with
    # neither "```python" nor "Answer:", so _run raises ValueError.
    with pytest.raises(ValueError):
        fake_llm_math_chain.run("foo")
8 changes: 7 additions & 1 deletion tests/unit_tests/llms/fake_llm.py
@@ -1,14 +1,20 @@
"""Fake LLM wrapper for testing purposes."""
from typing import List, Optional
from typing import List, Mapping, Optional

from langchain.llms.base import LLM


class FakeLLM(LLM):
"""Fake LLM wrapper for testing purposes."""

def __init__(self, queries: Optional[Mapping] = None) -> None:
    """Initialize with optional lookup of queries.

    Args:
        queries: optional mapping from exact prompt text to the canned
            response ``__call__`` should return for it.
    """
    # When None, __call__ falls back to its stop-word-based fake responses.
    self._queries = queries

def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Return `foo` if no stop words, otherwise `bar`."""
if self._queries is not None:
return self._queries[prompt]
if stop is None:
return "foo"
else:
Expand Down