# Summary

Experimenting with making a Prompt object. Goals:
- simplify the call to resolve template + arg(s)
- allow computed values? (e.g. accept arg x and then fill another field with x+3 or x.upper()}
- maybe define postprocessing/completion validation steps?

In [1]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [228]:
from copy import deepcopy
import importlib
from inspect import Parameter, Signature
import keyword
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from pathlib import Path
import string
import warnings

from jabberwocky.config import C
from jabberwocky.openai_utils import load_prompt, load_openai_api_key, GPT
from jabberwocky.utils import load_yaml
from htools import *

In [4]:
cd_root()

Current directory: /Users/hmamin/jabberwocky


In [84]:
def import_function(qualname):
    module_name, _, func_name = qualname.rpartition('.')
    module = importlib.import_module(module_name)
    return getattr(module, func_name)

In [435]:
# @delegate('template', iter_magics=True) # TODO rm?
class Prompt:
    
    def __init__(self, template, name='', **kwargs):
        self.name = name
        self.defaults = kwargs.pop('defaults', {})
        if not isinstance(self.defaults, dict):
            raise TypeError('`defaults` must be a dict. If your prompt only '
                            'has one arg, name it "prompt".')
            
        self.postprocessors = kwargs.pop('postprocessors', [])
        self.validators = kwargs.pop('validators', [])
        self._kwargs = kwargs
        self.template, self.fields = self.parse_fields(template)
        self.signature = self.make_signature(self.fields, self.defaults)
        self.n_fields = len(self.fields)
        self.fully_resolved = ''
            
    def make_signature(self, fields, defaults):
        if set(defaults) - set(fields):
            raise ValueError(f'You provided defaults for fields that do not '
                             f'exist. Defaults={defaults}, fields={fields}.')
        if len(fields) == 1:
            kind = Parameter.POSITIONAL_OR_KEYWORD
        else:
            kind = Parameter.KEYWORD_ONLY
        parameters = [Parameter(name=field, kind=kind,
                                default=defaults.get(field, Parameter.empty))
                      for field in fields]
        return Signature(parameters)
    
    def _resolve_kwargs(self, *args, **kwargs):
        error_base = f'Prompt {self.name} has {self.n_fields} fields: ' \
            f'{self.fields}.'
        if args and self.n_fields > 1:
            raise RuntimeError(
                'You must use keyword arguments when n_fields > 1 but you '
                f'specified 1 or more positional args: {args}. ' + error_base
            )
        try:
            sig = self.signature.bind(*args, **kwargs)
        except TypeError as e:
            raise RuntimeError(
                'You provided 1 or more unexpected kwargs: '
                f'{set(kwargs) - self.fields}. {error_base}'
            ) from e

        sig.apply_defaults()
        return sig.arguments
            
    def postprocess(self, text):
        for func in self.postprocessors:
            text = func(text)
        return text
    
    def validate(self, text):
        return all(func(text) for func in self.validators)
            
    @classmethod
    def load(cls, name, prompt_dir=C.root/'data/prompts', **kwargs):
        config = load_prompt(name, prompt_dir=prompt_dir, **kwargs)
        template = config.pop('prompt')
        config['validators'] = cls.functions_from_config(config, 
                                                         'validators')
        config['postprocessors'] = cls.functions_from_config(
            config, 'postprocessors'
        )
        return cls(template, name=name, **config)

    @staticmethod
    def functions_from_config(config, key):
        return [partial(import_function(row.pop('name')), **row) 
                for row in config.get(key, [])]
    
    def parse_fields(self, template):
        parser = string.Formatter()
        fields = [x[1] for x in parser.parse(template) if x[1] is not None]
        unique_fields = set(fields)
        default = 'prompt'
        if '' in fields:
            if len(unique_fields) > 1:
                raise RuntimeError(
                    'Found unnamed arg in prompt template. This is only '
                    'valid when a prompt expects <=1 unique fields but yours '
                    f'expects {len(unique_fields)}.'
                )
            template = template.replace('{}', '{' + default + '}')
            unique_fields.discard('')
            unique_fields.add(default)
        if len(fields) > 1 and all(field == default 
                                   for field in unique_fields):
            warnings.warn('Found multiple fields to fill but all have the '
                          'same name. Are you sure that\'s intentional?')
        return template, unique_fields
    
    def kwargs(self, extra_kwargs=None, **kwargs):
        base_kwargs = {**self._kwargs, **kwargs}
        for k, v in (extra_kwargs or {}).items():
            try:
                base_val = deepcopy(base_kwargs[k])
            except KeyError as e:
                raise KeyError(f'Extra kwarg {k} not present in default '
                               'kwargs.')
            
            if isinstance(base_val, Iterable):
                assert listlike(v), f'Extra kwargs for {k} should be ' \
                    f'list-like, not {type(v)}.'
                base_val = list(base_val)
                base_val.extend(v)
            elif isinstance(base_val, Mapping):
                assert isinstance(v, Mapping), f'Extra kwargs for {k} should'\
                    f' be dict-like, not {type(v)}.'
                base_val.update(v)
            else:
                raise TypeError(
                    'Extra_kwargs are only available for args that are lists '
                    f'or dicts, but you tried to use them with `{k}` which '
                    f'has type {type(base_val)}'
                )
            base_kwargs[k] = base_val
        return base_kwargs

#     def resolve(self, *args, **kwargs):
#         resolved_kwargs = self._resolve_kwargs(*args, **kwargs)
#         resolved = self.template.format(**resolved_kwargs)
#         self.last_resolved = resolved
#         return resolved

    def resolve_template(self, *args, **kwargs):
        resolved_kwargs = self._resolve_kwargs(*args, **kwargs)
        resolved = self.template.format(**resolved_kwargs)
        self.fully_resolved = resolved
        return resolved
    
    def resolve(self, *args, **kwargs):
        _ = self.resolve_template(*args, **kwargs)
        return self
        
    def __repr__(self):
        return f'{type(self).__name__}(\n\tname={repr(self.name)},'\
               f'\n\tprompt={repr(self.template)},'\
               f'\n\tdefaults={self.defaults},'\
               f'\n\t_kwargs={repr(self._kwargs)}\n)'

    def __str__(self):
        return str(self.template)
    
    def __getattr__(self, key):
        return self._kwargs[key]

In [436]:
class ResolvedPrompt(Prompt):
    
    def __init__(self, prompt, query_kwargs):
        self.prompt = prompt
        self.query_kwargs = query_kwargs
        
    def postprocess(self, text):
        return self.prompt.postprocess(text)
    
    def validate(self, text):
        return self.prompt.validate(text)
    

In [437]:
Prompt('Who are {}? {}')



Prompt(
	name='',
	prompt='Who are {prompt}? {prompt}',
	defaults={},
	_kwargs={}
)

In [438]:
Prompt('Who are {}?')

Prompt(
	name='',
	prompt='Who are {prompt}?',
	defaults={},
	_kwargs={}
)

In [439]:
with assert_raises(Exception):
    Prompt('Who are {}?', defaults={'fake': 3})

As expected, got Exception(You provided defaults for fields that do not exist. Defaults={'fake': 3}, fields={'prompt'}.).


In [440]:
with assert_raises(Exception):
    Prompt('Who are {}? {name}')

As expected, got Exception(Found unnamed arg in prompt template. This is only valid when a prompt expects <=1 unique fields but yours expects 2.).


In [441]:
tmp = Prompt.load('tmp')
tmp

tmp: Max_tokens probably should change as a result of the prompt args but this is just for testing purposes. Should probably delete eventually.
-------------------------------------------------------------------------------



Prompt(
	name='tmp',
	prompt='Make a list of {n} names that sound {quality}:\n\n1.',
	defaults={'n': 3, 'quality': 'old'},
	_kwargs={'engine': 0, 'max_tokens': 15, 'stop': ['1.', '\n1.', '\n1. ']}
)

In [442]:
tmp.kwargs()

{'engine': 0, 'max_tokens': 15, 'stop': ['1.', '\n1.', '\n1. ']}

In [443]:
tmp.kwargs(stop=['3', '4'])

{'engine': 0, 'max_tokens': 15, 'stop': ['3', '4']}

In [444]:
tmp.kwargs(extra_kwargs={'stop': ['\n\n\n333']})

{'engine': 0, 'max_tokens': 15, 'stop': ['1.', '\n1.', '\n1. ', '\n\n\n333']}

In [445]:
with assert_raises(Exception):
    tmp.kwargs(extra_kwargs={'stop': {'wrong_type': '\n\n\n333'}})

As expected, got Exception(Extra kwargs for stop should be list-like, not <class 'dict'>.).


In [446]:
with assert_raises(KeyError):
    tmp.kwargs(temperature=.99, extra_kwargs={'logprobs': True})

As expected, got KeyError('Extra kwarg logprobs not present in default kwargs.').


In [447]:
tmp.kwargs(temperature=.99, logit_bias={200: 1})

{'engine': 0,
 'max_tokens': 15,
 'stop': ['1.', '\n1.', '\n1. '],
 'temperature': 0.99,
 'logit_bias': {200: 1}}

In [448]:
tmp.kwargs(temperature=.3)

{'engine': 0,
 'max_tokens': 15,
 'stop': ['1.', '\n1.', '\n1. '],
 'temperature': 0.3}

In [449]:
with assert_raises(RuntimeError):
    tmp.resolve(4)

As expected, got RuntimeError(You must use keyword arguments when n_fields > 1 but you specified 1 or more positional args: (4,). Prompt tmp has 2 fields: {'quality', 'n'}.).


In [450]:
with assert_raises(RuntimeError):
    tmp.resolve(z=4)

As expected, got RuntimeError(You provided 1 or more unexpected kwargs: {'z'}. Prompt tmp has 2 fields: {'quality', 'n'}.).


In [451]:
tmp.postprocessors

[functools.partial(<function reattach_1 at 0x12f191bf8>),
 functools.partial(<function indent at 0x12f191ea0>, broken=False),
 functools.partial(<function swapcase at 0x12f191c80>)]

In [452]:
tmp.validators

[functools.partial(<function test_indented at 0x12f191048>, mode='all')]

In [453]:
tmp.kwargs()

{'engine': 0, 'max_tokens': 15, 'stop': ['1.', '\n1.', '\n1. ']}

In [454]:
tmp.resolve()

Prompt(
	name='tmp',
	prompt='Make a list of {n} names that sound {quality}:\n\n1.',
	defaults={'n': 3, 'quality': 'old'},
	_kwargs={'engine': 0, 'max_tokens': 15, 'stop': ['1.', '\n1.', '\n1. ']}
)

In [455]:
tmp.resolve(n=9)

Prompt(
	name='tmp',
	prompt='Make a list of {n} names that sound {quality}:\n\n1.',
	defaults={'n': 3, 'quality': 'old'},
	_kwargs={'engine': 0, 'max_tokens': 15, 'stop': ['1.', '\n1.', '\n1. ']}
)

In [457]:
tmp.resolve(quality='happy')

Prompt(
	name='tmp',
	prompt='Make a list of {n} names that sound {quality}:\n\n1.',
	defaults={'n': 3, 'quality': 'old'},
	_kwargs={'engine': 0, 'max_tokens': 15, 'stop': ['1.', '\n1.', '\n1. ']}
)

In [458]:
tmp.resolve(n=2, quality='happy')

Prompt(
	name='tmp',
	prompt='Make a list of {n} names that sound {quality}:\n\n1.',
	defaults={'n': 3, 'quality': 'old'},
	_kwargs={'engine': 0, 'max_tokens': 15, 'stop': ['1.', '\n1.', '\n1. ']}
)

In [396]:
with assert_raises(Exception):
    tmp.resolve(n=2, quality='happy', dense=True)

As expected, got Exception(You provided 1 or more unexpected kwargs: {'dense'}. Prompt tmp has 2 fields: {'quality', 'n'}.).


In [406]:
with GPT('banana'):
    res = GPT.query(tmp.resolve(n=5, quality='spacy'), 
                    **tmp.kwargs(), 
                    strip_output=False)

Switching openai backend to "banana".
{'engine': 0, 'max_tokens': 15, 'stop': ['1.', '\n1.', '\n1. '], 'prompt': 'Make a list of 5 names that sound spacy:\n\n1.', 'meta': {'backend_name': 'banana', 'query_func': 'query_gpt_banana', 'datetime': 'Sun May 15 13:23:03 2022'}}
Switching  backend back to "openai".


In [405]:
print(tmp.last_resolved + res[0][0])

Make a list of 5 names that sound spacy:

1.Brine

2. Clay

3. Kaya


In [407]:
print(tmp.last_resolved + res[0][0])

Make a list of 5 names that sound spacy:

1. Ayn Rand

2. JK Rowling

3. Stephen


In [408]:
bound_args(GPT.query, [], {})

OrderedDict([('strip_output', True),
             ('log', True),
             ('optimize_cost', False),
             ('subwords', True),
             ('drop_fragment', False),
             ('engine', 0),
             ('temperature', 0.7),
             ('top_p', 1.0),
             ('frequency_penalty', 0.0),
             ('presence_penalty', 0.0),
             ('max_tokens', 50),
             ('logprobs', None),
             ('n', 1),
             ('stream', False),
             ('logit_bias', None)])

In [371]:
postprocessed = tmp.postprocess(res[0][0])
print(postprocessed)

	1. wHAT DO YOU DO FOR A LIVING?
	
	2.


In [219]:
tmp.validate(postprocessed)

True

In [37]:
config = load_prompt('test_professional')
config

test_professional: This doesn't actually work very well but the settings are a good example of what is needed for my envisioned natural language tests.
-------------------------------------------------------------------------------



{'engine': 1,
 'temperature': 0.0,
 'max_tokens': 1,
 'presence_penalty': 2,
 'frequency_penalty': 2,
 'logprobs': 2,
 'stop': ['\n'],
 'logit_bias': {3763: 100, 645: 100},
 'validation': [{'name': 'tmp.is_yes_or_no', 'strip': True, 'lower': False},
  {'name': 'requests.get'}],
 'prompt': 'Does the following email maintain a professional tone? (Yes/No)\n\nEmail:\n{} \nAnswer:'}

In [45]:
template = '{age} is younger than {age+1}'

In [47]:
template.format(age=9)

KeyError: 'age+1'