Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
boazwasserman committed Jun 8, 2023
1 parent 511c12d commit 05aefc6
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 0 deletions.
46 changes: 46 additions & 0 deletions langchain/chains/pal/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import warnings
import ast
from typing import Any, Dict, List, Optional

from pydantic import Extra, root_validator
Expand All @@ -18,6 +19,7 @@
from langchain.prompts.base import BasePromptTemplate
from langchain.utilities import PythonREPL

# Default rule set consumed by PALChain.validate_code.
DEFAULT_CODE_VALIDATIONS = {
    # Name of the function the generated code must define at the top level.
    "solution_function": "solution",
    # False: generated code containing import statements is rejected.
    "allow_imports": False,
    # False: generated code with more than one top-level node is rejected.
    "allow_non_solution_root_scope_expressions": False,
    # Read by validate_code, which does not enforce anything for it yet.
    "allow_non_math_operations": True,
}

class PALChain(Chain):
"""Implements Program-Aided Language Models."""
Expand All @@ -33,6 +35,7 @@ class PALChain(Chain):
python_locals: Optional[Dict[str, Any]] = None
output_key: str = "result" #: :meta private:
return_intermediate_steps: bool = False
code_validations: Optional[Dict[str, Any]] = DEFAULT_CODE_VALIDATIONS

class Config:
"""Configuration for this pydantic object."""
Expand Down Expand Up @@ -82,21 +85,64 @@ def _call(
stop=[self.stop], callbacks=_run_manager.get_child(), **inputs
)
_run_manager.on_text(code, color="green", end="\n", verbose=self.verbose)
PALChain.validate_code(code, self.code_validations)
repl = PythonREPL(_globals=self.python_globals, _locals=self.python_locals)
res = repl.run(code + f"\n{self.get_answer_expr}")
output = {self.output_key: res.strip()}
if self.return_intermediate_steps:
output["intermediate_steps"] = code
return output

@classmethod
def validate_code(cls, code: str, code_validations: Dict[str, Any]) -> None:
    """Statically check generated code against a set of validation rules.

    The code is parsed with ``ast`` and never executed here.

    Args:
        code: Source text produced by the LLM.
        code_validations: Rule mapping. Recognised keys:
            ``solution_function`` (str, required): name of a function that
            must be defined at the top level of ``code``.
            ``allow_non_solution_root_scope_expressions``: when exactly
            ``False``, ``code`` may contain at most one top-level node.
            ``allow_imports``: when exactly ``False``, ``import`` and
            ``from ... import`` statements are rejected anywhere in the
            code, including inside function bodies.
            ``allow_non_math_operations``: accepted but NOT enforced yet.
            TODO: define the permitted node set and implement the check.

    Raises:
        ValueError: If ``code`` cannot be parsed, if ``solution_function``
            is not a string, or if any enabled rule is violated.
    """
    try:
        code_tree = ast.parse(code)
    except (SyntaxError, UnicodeDecodeError) as err:
        raise ValueError(
            f"Generated code is not valid python code: {code}"
        ) from err
    except TypeError as err:
        raise ValueError(
            f"Generated code is expected to be a string, instead found {type(code)}"
        ) from err
    except OverflowError as err:
        raise ValueError(
            f"Generated code too long / complex to be parsed by ast: {code}"
        ) from err

    top_level_nodes = list(ast.iter_child_nodes(code_tree))
    if (
        code_validations.get("allow_non_solution_root_scope_expressions") is False
        and len(top_level_nodes) > 1
    ):
        raise ValueError(
            f"Generated code has more than 1 root scope expressions: {code}"
        )

    solution_func_name = code_validations.get("solution_function")
    if not isinstance(solution_func_name, str):
        raise ValueError(
            "solution_function code validation parameter should be str, "
            f"instead found {type(solution_func_name)}"
        )

    found_solution_func = any(
        isinstance(node, ast.FunctionDef) and node.name == solution_func_name
        for node in top_level_nodes
    )
    if not found_solution_func:
        raise ValueError(f"Generated code is missing the solution function: {code}")

    if code_validations.get("allow_imports") is False:
        # Walk the whole tree: an import nested inside the solution function
        # is just as effective as a top-level one and must not slip through.
        has_imports = any(
            isinstance(node, (ast.Import, ast.ImportFrom))
            for node in ast.walk(code_tree)
        )
        if has_imports:
            raise ValueError(f"Generated code has disallowed imports: {code}")

@classmethod
def from_math_prompt(cls, llm: BaseLanguageModel, **kwargs: Any) -> PALChain:
    """Load PAL from math prompt.

    Args:
        llm: Language model used to generate the solution code.
        **kwargs: Extra keyword arguments forwarded to the chain
            constructor. Two keys are consumed here and not forwarded:
            ``code_validations`` overrides the default rule set, and
            ``disallow_non_math_operations=True`` sets
            ``allow_non_math_operations`` to ``False`` in the rules.

    Returns:
        A chain using ``MATH_PROMPT`` that obtains its answer by running
        ``print(solution())`` after the generated code.
    """
    llm_chain = LLMChain(llm=llm, prompt=MATH_PROMPT)

    # Work on a private copy: updating the module-level defaults in place
    # would silently change the rules of every chain created afterwards.
    # Popping also avoids passing ``code_validations`` to ``cls`` twice.
    base_validations = kwargs.pop("code_validations", None)
    if base_validations is None:
        base_validations = DEFAULT_CODE_VALIDATIONS
    code_validations = dict(base_validations)

    # This flag is an option of this factory, so it is popped and not passed
    # on. NOTE(review): no such field appears among the visible model fields,
    # but part of the class is hidden. Confirm none exists.
    if kwargs.pop("disallow_non_math_operations", None) is True:
        code_validations["allow_non_math_operations"] = False

    return cls(
        llm_chain=llm_chain,
        stop="\n\n",
        get_answer_expr="print(solution())",
        code_validations=code_validations,
        **kwargs,
    )

Expand Down
50 changes: 50 additions & 0 deletions tests/unit_tests/chains/test_pal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Test LLM PAL functionality."""
import sys
from pathlib import Path

# Make the repository root importable without hard-coding one developer's
# checkout location. This file lives three directories below the root
# (tests/unit_tests/chains/), so climb three levels from its directory.
_REPO_ROOT = (Path(__file__).resolve().parent / ".." / ".." / "..").resolve()
sys.path.append(str(_REPO_ROOT))

import pytest

from langchain.chains.pal.base import PALChain
from langchain.chains.pal.math_prompt import MATH_PROMPT
from langchain.schema import OutputParserException
from tests.unit_tests.llms.fake_llm import FakeLLM

# Canned LLM completion: a single top-level ``solution()`` function. The body
# must stay indented, because this text is parsed and executed as Python.
_SAMPLE_CODE = """
def solution():
    \"\"\"Olivia has $23. She bought five bagels for $3 each. How much money does she have left?\"\"\"
    money_initial = 23
    bagels = 5
    bagel_cost = 3
    money_spent = bagels * bagel_cost
    money_left = money_initial - money_spent
    result = money_left
    return result
"""


# Free text wrapped around a fenced two-line bash snippet.
# NOTE(review): no test visible in this file references this constant;
# confirm whether it is still needed.
_SAMPLE_CODE_2_LINES = """
Unrelated text
```bash
echo hello
echo world
```
Unrelated text
"""

# Rule set with every restriction switched on (all allow_* flags are False).
FULL_CODE_VALIDATIONS = {
    "solution_function": "solution",
    "allow_imports": False,
    "allow_non_solution_root_scope_expressions": False,
    "allow_non_math_operations": False,
}

def test_simple_question() -> None:
    """Run the math PAL chain against a fake LLM and check the printed answer."""
    question = "Olivia has $23. She bought five bagels for $3 each. How much money does she have left?"
    # Map the fully rendered prompt to the canned completion for the fake LLM.
    canned_responses = {MATH_PROMPT.format(question=question): _SAMPLE_CODE}
    chain = PALChain.from_math_prompt(FakeLLM(queries=canned_responses))
    assert chain.run(question) == "8"


def test_get_code() -> None:
    """Check that the sample solution passes the strictest validation rules."""
    # Passing means no exception is raised.
    PALChain.validate_code(
        code=_SAMPLE_CODE,
        code_validations=FULL_CODE_VALIDATIONS,
    )

0 comments on commit 05aefc6

Please sign in to comment.