# Summary

Experimenting with making a Prompt object. Goals:
- simplify the call to resolve template + arg(s)
- allow computed values? (e.g. accept arg x and then fill another field with x+3 or x.upper())
- maybe define postprocessing/completion validation steps?

In [1]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [228]:
from copy import deepcopy
import importlib
from inspect import Parameter, Signature
import keyword
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from pathlib import Path
import string
import warnings

from jabberwocky.config import C
from jabberwocky.openai_utils import load_prompt, load_openai_api_key, GPT
from jabberwocky.utils import load_yaml
from htools import *

In [4]:
cd_root()

Current directory: /Users/hmamin/jabberwocky


In [84]:
def import_function(qualname):
    module_name, _, func_name = qualname.rpartition('.')
    module = importlib.import_module(module_name)
    return getattr(module, func_name)
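
A quick usage sketch for the helper above, using a standard-library target (`math.sqrt` is just an illustration). One thing worth knowing: a name with no dot leaves `module_name` empty, and `importlib.import_module('')` raises `ValueError`, so bare names need a fully qualified path.

```python
import importlib


def import_function(qualname):
    """Import an object from a dotted path like 'package.module.func'."""
    module_name, _, func_name = qualname.rpartition('.')
    module = importlib.import_module(module_name)
    return getattr(module, func_name)


sqrt = import_function('math.sqrt')
print(sqrt(9.0))   # 3.0

# A bare name has no module part, so rpartition yields module_name == ''.
try:
    import_function('sqrt')
except ValueError as e:
    print(f'ValueError: {e}')
```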

WIP: prompt class that lets us specify postprocessing and validation steps in our prompt.yaml. Issues:

- using this to create multiple queries (i.e. equivalent to passing a list of strings to GPT.query, which then passes them to `_query_batch`) was a little clunky. Still figuring out the best way to handle this.
- right now, resolving a prompt (passing in fields to fill in the template) always appends to self.cache, which GPT.query would then use to retrieve the fully resolved prompts. But it's not clear when we should clear the cache, when we should add to it (currently always, but what if we just want to test out a resolution for debugging or sanity checks?), etc.
- Should strip_output option still be in gpt.query? Or should that be specified in prompt.yaml when we want it?
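
One possible answer to the caching question, as a minimal sketch rather than a decision: make caching opt-out per call, so a debugging resolution leaves the cache untouched. `CachingTemplate`, `pop_cache`, and the `cache` flag are hypothetical names, not part of jabberwocky.

```python
class CachingTemplate:
    """Toy stand-in for Prompt: str.format plus an optional cache."""

    def __init__(self, template):
        self.template = template
        self.cache = []

    def resolve(self, cache=True, **fields):
        # Caveat: a template field literally named "cache" would collide
        # with this flag. A real version would need a safer name.
        resolved = self.template.format(**fields)
        if cache:
            self.cache.append(resolved)
        return resolved

    def pop_cache(self):
        """Return cached prompts and clear them (one-time handoff)."""
        res, self.cache = self.cache, []
        return res


t = CachingTemplate('Who is {name}?')
t.resolve(name='casey')
t.resolve(cache=False, name='debug')   # sanity check only, not queued
queued = t.pop_cache()
print(queued)     # ['Who is casey?']
print(t.cache)    # []
```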

TLDR: I think in the interest of wrapping up the Alexa portion of the project in a timely manner, I should put this on hold. There's no immediate need for it, and rewriting my openai wrapper library is honestly a whole project in itself (admittedly, one I've already done a lot on: all that gpt.query refactoring falls into that category IMO).

In [514]:
class Prompt:
    """
    WIP. Was still figuring out how to incorporate this into GPT.query. This
    was what I had for a start but it would need more tweaking.
    
    ```
    if isinstance(prompt, Prompt):
        kwargs = prompt.kwargs(**kwargs)
        kwargs['prompt'] = prompt.get_cache()
    else:
        kwargs['prompt'] = prompt
    if listlike(kwargs['prompt']) and len(kwargs['prompt']) == 1:
        kwargs['prompt'] = kwargs['prompt'][0]
    ```
    """
    
    def __init__(self, template, name='', **kwargs):
        self.name = name
        self.defaults = kwargs.pop('defaults', {})
        if not isinstance(self.defaults, dict):
            raise TypeError('`defaults` must be a dict. If your prompt only '
                            'has one arg, name it "prompt".')
            
        self.postprocessors = kwargs.pop('postprocessors', [])
        self.validators = kwargs.pop('validators', [])
        self._kwargs = kwargs
        self.template, self.fields = self.parse_fields(template)
        self.signature = self.make_signature(self.fields, self.defaults)
        self.n_fields = len(self.fields)
        
        # TODO testing
        self.cache = []
            
    def make_signature(self, fields, defaults):
        if set(defaults) - set(fields):
            raise ValueError(f'You provided defaults for fields that do not '
                             f'exist. Defaults={defaults}, fields={fields}.')
        if len(fields) == 1:
            kind = Parameter.POSITIONAL_OR_KEYWORD
        else:
            kind = Parameter.KEYWORD_ONLY
        parameters = [Parameter(name=field, kind=kind,
                                default=defaults.get(field, Parameter.empty))
                      for field in fields]
        return Signature(parameters)
    
    def _resolve_fields(self, *args, **kwargs):
        """Combine user-specified args/kwargs with default field values and
        return dict of resolved key-value pairs.
        """
        error_base = f'Prompt {self.name} has {self.n_fields} fields: ' \
            f'{self.fields}.'
        if args and self.n_fields > 1:
            raise RuntimeError(
                'You must use keyword arguments when n_fields > 1 but you '
                f'specified 1 or more positional args: {args}. ' + error_base
            )
        try:
            sig = self.signature.bind(*args, **kwargs)
        except TypeError as e:
            raise RuntimeError(
                'You provided 1 or more unexpected kwargs: '
                f'{set(kwargs) - self.fields}. {error_base}'
            ) from e

        sig.apply_defaults()
        return sig.arguments
            
    def postprocess(self, text):
        for func in self.postprocessors:
            text = func(text)
        return text
    
    def validate(self, text):
        return all(func(text) for func in self.validators)
            
    @classmethod
    def load(cls, name, prompt_dir=C.root/'data/prompts', **kwargs):
        config = load_prompt(name, prompt_dir=prompt_dir, **kwargs)
        template = config.pop('prompt')
        config['validators'] = cls.functions_from_config(config, 
                                                         'validators')
        config['postprocessors'] = cls.functions_from_config(
            config, 'postprocessors'
        )
        return cls(template, name=name, **config)

    @staticmethod
    def functions_from_config(config, key):
        return [partial(import_function(row.pop('name')), **row) 
                for row in config.get(key, [])]
    
    def parse_fields(self, template):
        parser = string.Formatter()
        fields = [x[1] for x in parser.parse(template) if x[1] is not None]
        unique_fields = set(fields)
        default = 'prompt'
        if '' in fields:
            if len(unique_fields) > 1:
                raise RuntimeError(
                    'Found unnamed arg in prompt template. This is only '
                    'valid when a prompt expects <=1 unique fields but yours '
                    f'expects {len(unique_fields)}.'
                )
            template = template.replace('{}', '{' + default + '}')
            unique_fields.discard('')
            unique_fields.add(default)
        if len(fields) > 1 and all(field == default 
                                   for field in unique_fields):
            warnings.warn('Found multiple fields to fill but all have the '
                          'same name. Are you sure that\'s intentional?')
        return template, unique_fields
    
    def kwargs(self, extra_kwargs=None, 
#                return_prompt=False, # TODO
               **kwargs):
        base_kwargs = {**self._kwargs, **kwargs}
        for k, v in (extra_kwargs or {}).items():
            try:
                base_val = deepcopy(base_kwargs[k])
            except KeyError as e:
                raise KeyError(f'Extra kwarg {k} not present in default '
                               'kwargs.') from e
            
            # Check Mapping first: dicts are also Iterable, so the reverse
            # order would never reach the dict branch.
            if isinstance(base_val, Mapping):
                assert isinstance(v, Mapping), f'Extra kwargs for {k} should'\
                    f' be dict-like, not {type(v)}.'
                base_val.update(v)
            elif isinstance(base_val, Iterable):
                assert listlike(v), f'Extra kwargs for {k} should be ' \
                    f'list-like, not {type(v)}.'
                base_val = list(base_val)
                base_val.extend(v)
            else:
                raise TypeError(
                    'Extra_kwargs are only available for args that are lists '
                    f'or dicts, but you tried to use them with `{k}` which '
                    f'has type {type(base_val)}'
                )
            base_kwargs[k] = base_val
            
        # TODO
#         if return_prompt:
#             assert self.ever_resolved, \
#                 'Return_prompt=True is only valid if you\'ve previously ' \
#                 'resolved your prompt, i.e. prompt.resolve().'
#             base_kwargs['prompt'] = self.fully_resolved
        return base_kwargs

    def resolve_template(self, *args, **kwargs):
        """Fill template fields with user-specified args/kwargs and return 
        string containing fully resolved template.
        """
        resolved_fields = self._resolve_fields(*args, **kwargs)
        resolved = self.template.format(**resolved_fields)
        self.cache.append(resolved)
        return resolved
    
    def resolve_templates(self, rows=(None,)):
        for row in rows:
            if listlike(row):
                raise TypeError('Each row must be a dict or primitive.')
            if isinstance(row, Mapping):
                _ = self.resolve_template(**row)
            else:
                _ = self.resolve_template(row)
        return list(self.cache)
    
    def resolve(self, *args, **kwargs):
        """Resolve template and return the prompt instance itself. This
        retains access to query kwargs, postprocessors, and validators.
        
        For a sample prompt with one field ("name"), you can create a single
        resolution like this:
        prompt.resolve("casey")
        prompt.resolve(name="casey")
        
        Or multiple resolutions like this:
        prompt.resolve(["casey", "jamie"])
        prompt.resolve([{"name": "casey"}, {"name": "jamie"}])
        
        
        For a sample prompt with two fields ("name" and "age"), you can create
        a single resolution like this:
        prompt.resolve(name="casey", age=40)
        
        Or multiple resolutions like this:
        prompt.resolve([{"name": "casey", "age": 40}, 
                        {"name": "jamie", "age": 10}])
        """
        # TODO: what if user wants to cache a query 1 at a time but not use
        # immediately, sort of like in conv manager's query_later? Might want
        # clearing to be optional.
        self.cache.clear()
        # Guard against IndexError when called with keyword args only.
        if args and listlike(args[0]):
            self.resolve_templates(args[0])
        else:
            self.resolve_template(*args, **kwargs)
        return self
        
    def get_cache(self):
        res = list(self.cache)
        self.cache.clear()
        return res
    
    def __repr__(self):
        return f'{type(self).__name__}(\n\tname={repr(self.name)},'\
               f'\n\tprompt={repr(self.template)},'\
               f'\n\tdefaults={self.defaults},'\
               f'\n\t_kwargs={repr(self._kwargs)}\n)'

    def __str__(self):
        return str(self.template)
    
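    def __getattr__(self, key):
        # Only called when normal lookup fails. Raise AttributeError (not
        # KeyError) so hasattr(), copy, and pickle behave as expected.
        try:
            return self.__dict__['_kwargs'][key]
        except KeyError:
            raise AttributeError(key) from None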

Object loaded from /Users/hmamin/jabberwocky/data/misc/sample_response.pkl.
Object loaded from /Users/hmamin/jabberwocky/data/misc/sample_stream_response.pkl.
Object loaded from /Users/hmamin/jabberwocky/data/misc/gooseai_sample_responses.pkl.
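
The field-resolution step in `_resolve_fields` leans on `inspect.Signature.bind` plus `apply_defaults`. Here is a stripped-down sketch of that mechanism outside the class. `resolve_fields` is a hypothetical helper, and the field names mirror the `tmp` prompt used later in this notebook.

```python
from inspect import Parameter, Signature


def resolve_fields(fields, defaults, *args, **kwargs):
    """Bind user args to template fields, then fill in defaults."""
    # A single field may be passed positionally; otherwise keywords only.
    kind = (Parameter.POSITIONAL_OR_KEYWORD if len(fields) == 1
            else Parameter.KEYWORD_ONLY)
    sig = Signature([
        Parameter(name, kind, default=defaults.get(name, Parameter.empty))
        for name in fields
    ])
    bound = sig.bind(*args, **kwargs)   # TypeError on missing/unknown args
    bound.apply_defaults()
    return dict(bound.arguments)


template = 'Make a list of {n} names that sound {quality}:'
fields = ['n', 'quality']
defaults = {'n': 3, 'quality': 'old'}

print(template.format(**resolve_fields(fields, defaults, n=9)))
# Make a list of 9 names that sound old:
```

Letting `Signature` do the binding means missing and unexpected arguments are caught the same way Python catches them for a normal function call.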


In [503]:
Prompt('Who are {}? {}')



Prompt(
	name='',
	prompt='Who are {prompt}? {prompt}',
	defaults={},
	_kwargs={}
)

In [504]:
Prompt('Who are {}?')

Prompt(
	name='',
	prompt='Who are {prompt}?',
	defaults={},
	_kwargs={}
)

In [505]:
with assert_raises(Exception):
    Prompt('Who are {}?', defaults={'fake': 3})

As expected, got Exception(You provided defaults for fields that do not exist. Defaults={'fake': 3}, fields={'prompt'}.).


In [506]:
with assert_raises(Exception):
    Prompt('Who are {}? {name}')

As expected, got Exception(Found unnamed arg in prompt template. This is only valid when a prompt expects <=1 unique fields but yours expects 2.).


In [507]:
tmp = Prompt.load('tmp')
tmp

tmp: Max_tokens probably should change as a result of the prompt args but this is just for testing purposes. Should probably delete eventually.
-------------------------------------------------------------------------------



Prompt(
	name='tmp',
	prompt='Make a list of {n} names that sound {quality}:\n\n1.',
	defaults={'n': 3, 'quality': 'old'},
	_kwargs={'engine': 0, 'max_tokens': 15, 'stop': ['1.', '\n1.', '\n1. ']}
)

In [508]:
tmp.resolve_templates([{'n': 4}, {'n': 9, 'quality': 'zany'}])

['Make a list of 4 names that sound old:\n\n1.',
 'Make a list of 9 names that sound zany:\n\n1.']

In [509]:
tmp.kwargs()

{'engine': 0, 'max_tokens': 15, 'stop': ['1.', '\n1.', '\n1. ']}

In [510]:
tmp.kwargs(stop=['3', '4'])

{'engine': 0, 'max_tokens': 15, 'stop': ['3', '4']}

In [511]:
tmp.kwargs(extra_kwargs={'stop': ['\n\n\n333']})

{'engine': 0, 'max_tokens': 15, 'stop': ['1.', '\n1.', '\n1. ', '\n\n\n333']}
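
The `extra_kwargs` behavior shown above (extend a list default instead of replacing it) boils down to a small merge rule. A standalone sketch, assuming list values are extended and dict values are updated; `merge_extra` is a hypothetical name:

```python
from collections.abc import Mapping
from copy import deepcopy


def merge_extra(base, extra):
    """Return a copy of `base` where list values are extended and dict
    values are updated by the matching entries in `extra`."""
    merged = deepcopy(base)
    for key, val in extra.items():
        if key not in merged:
            raise KeyError(f'Extra kwarg {key} not present in base kwargs.')
        current = merged[key]
        # Check Mapping first: a dict is also iterable, so testing for
        # iterables first would never reach the dict branch.
        if isinstance(current, Mapping):
            current.update(val)
        elif isinstance(current, (list, tuple)):
            merged[key] = [*current, *val]
        else:
            raise TypeError(f'Cannot extend {key!r} of type {type(current)}.')
    return merged


base = {'max_tokens': 15, 'stop': ['1.'], 'logit_bias': {200: 1}}
merged = merge_extra(base, {'stop': ['\n\n'], 'logit_bias': {645: 100}})
print(merged)
```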

In [472]:
with assert_raises(Exception):
    tmp.kwargs(extra_kwargs={'stop': ['\n\n\n333']}, return_prompt=True)

As expected, got Exception(Return_prompt=True is only valid if you've previously resolved your prompt, i.e. prompt.resolve().).


In [473]:
with assert_raises(Exception):
    tmp.kwargs(extra_kwargs={'stop': {'wrong_type': '\n\n\n333'}})

As expected, got Exception(Extra kwargs for stop should be list-like, not <class 'dict'>.).


In [474]:
with assert_raises(KeyError):
    tmp.kwargs(temperature=.99, extra_kwargs={'logprobs': True})

As expected, got KeyError('Extra kwarg logprobs not present in default kwargs.').


In [475]:
tmp.kwargs(temperature=.99, logit_bias={200: 1})

{'engine': 0,
 'max_tokens': 15,
 'stop': ['1.', '\n1.', '\n1. '],
 'temperature': 0.99,
 'logit_bias': {200: 1}}

In [476]:
tmp.kwargs(temperature=.3)

{'engine': 0,
 'max_tokens': 15,
 'stop': ['1.', '\n1.', '\n1. '],
 'temperature': 0.3}

In [477]:
with assert_raises(RuntimeError):
    tmp.resolve(4)

As expected, got RuntimeError(You must use keyword arguments when n_fields > 1 but you specified 1 or more positional args: (4,). Prompt tmp has 2 fields: {'quality', 'n'}.).


In [478]:
with assert_raises(RuntimeError):
    tmp.resolve(z=4)

As expected, got RuntimeError(You provided 1 or more unexpected kwargs: {'z'}. Prompt tmp has 2 fields: {'quality', 'n'}.).


In [480]:
tmp.kwargs()

{'engine': 0, 'max_tokens': 15, 'stop': ['1.', '\n1.', '\n1. ']}

In [481]:
tmp.resolve()

Prompt(
	name='tmp',
	prompt='Make a list of {n} names that sound {quality}:\n\n1.',
	defaults={'n': 3, 'quality': 'old'},
	_kwargs={'engine': 0, 'max_tokens': 15, 'stop': ['1.', '\n1.', '\n1. ']}
)

In [482]:
tmp.kwargs(return_prompt=True)

{'engine': 0,
 'max_tokens': 15,
 'stop': ['1.', '\n1.', '\n1. '],
 'prompt': 'Make a list of 3 names that sound old:\n\n1.'}

In [455]:
tmp.resolve(n=9)

Prompt(
	name='tmp',
	prompt='Make a list of {n} names that sound {quality}:\n\n1.',
	defaults={'n': 3, 'quality': 'old'},
	_kwargs={'engine': 0, 'max_tokens': 15, 'stop': ['1.', '\n1.', '\n1. ']}
)

In [457]:
tmp.resolve(quality='happy')

Prompt(
	name='tmp',
	prompt='Make a list of {n} names that sound {quality}:\n\n1.',
	defaults={'n': 3, 'quality': 'old'},
	_kwargs={'engine': 0, 'max_tokens': 15, 'stop': ['1.', '\n1.', '\n1. ']}
)

In [458]:
tmp.resolve(n=2, quality='happy')

Prompt(
	name='tmp',
	prompt='Make a list of {n} names that sound {quality}:\n\n1.',
	defaults={'n': 3, 'quality': 'old'},
	_kwargs={'engine': 0, 'max_tokens': 15, 'stop': ['1.', '\n1.', '\n1. ']}
)

In [396]:
with assert_raises(Exception):
    tmp.resolve(n=2, quality='happy', dense=True)

As expected, got Exception(You provided 1 or more unexpected kwargs: {'dense'}. Prompt tmp has 2 fields: {'quality', 'n'}.).


In [479]:
tmp.postprocessors

[functools.partial(<function reattach_1 at 0x12f191bf8>),
 functools.partial(<function indent at 0x12f191ea0>, broken=False),
 functools.partial(<function swapcase at 0x12f191c80>)]

In [452]:
tmp.validators

[functools.partial(<function test_indented at 0x12f191048>, mode='all')]

In [406]:
with GPT('banana'):
    res = GPT.query(tmp.resolve(n=5, quality='spacy'), 
                    **tmp.kwargs(), 
                    strip_output=False)

Switching openai backend to "banana".
{'engine': 0, 'max_tokens': 15, 'stop': ['1.', '\n1.', '\n1. '], 'prompt': 'Make a list of 5 names that sound spacy:\n\n1.', 'meta': {'backend_name': 'banana', 'query_func': 'query_gpt_banana', 'datetime': 'Sun May 15 13:23:03 2022'}}
Switching  backend back to "openai".


In [405]:
print(tmp.last_resolved + res[0][0])

Make a list of 5 names that sound spacy:

1.Brine

2. Clay

3. Kaya


In [407]:
print(tmp.last_resolved + res[0][0])

Make a list of 5 names that sound spacy:

1. Ayn Rand

2. JK Rowling

3. Stephen


In [408]:
bound_args(GPT.query, [], {})

OrderedDict([('strip_output', True),
             ('log', True),
             ('optimize_cost', False),
             ('subwords', True),
             ('drop_fragment', False),
             ('engine', 0),
             ('temperature', 0.7),
             ('top_p', 1.0),
             ('frequency_penalty', 0.0),
             ('presence_penalty', 0.0),
             ('max_tokens', 50),
             ('logprobs', None),
             ('n', 1),
             ('stream', False),
             ('logit_bias', None)])

In [371]:
postprocessed = tmp.postprocess(res[0][0])
print(postprocessed)

	1. wHAT DO YOU DO FOR A LIVING?
	
	2.


In [219]:
tmp.validate(postprocessed)

True
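
For reference, the postprocessing pipeline is just a list of partials applied in order. A self-contained sketch using only standard-library functions as the steps (the config dict is hypothetical and stands in for what the yaml file would provide; `functions_from_rows` copies each row so the config is not mutated by `pop`):

```python
from functools import partial
import importlib


def import_function(qualname):
    module_name, _, func_name = qualname.rpartition('.')
    return getattr(importlib.import_module(module_name), func_name)


def functions_from_rows(rows):
    """Turn [{'name': 'pkg.func', **kwargs}, ...] into partials."""
    funcs = []
    for row in rows:
        row = dict(row)   # copy so pop() leaves the config intact
        funcs.append(partial(import_function(row.pop('name')), **row))
    return funcs


# Hypothetical config, as it might look after loading the yaml file.
config = {
    'postprocessors': [
        {'name': 'textwrap.dedent'},
        {'name': 'textwrap.indent', 'prefix': '\t'},
    ],
}

text = '  1. Ann\n  2. Bob'
for func in functions_from_rows(config['postprocessors']):
    text = func(text)
print(repr(text))   # '\t1. Ann\n\t2. Bob'
```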

In [37]:
config = load_prompt('test_professional')
config

test_professional: This doesn't actually work very well but the settings are a good example of what is needed for my envisioned natural language tests.
-------------------------------------------------------------------------------



{'engine': 1,
 'temperature': 0.0,
 'max_tokens': 1,
 'presence_penalty': 2,
 'frequency_penalty': 2,
 'logprobs': 2,
 'stop': ['\n'],
 'logit_bias': {3763: 100, 645: 100},
 'validation': [{'name': 'tmp.is_yes_or_no', 'strip': True, 'lower': False},
  {'name': 'requests.get'}],
 'prompt': 'Does the following email maintain a professional tone? (Yes/No)\n\nEmail:\n{} \nAnswer:'}

In [45]:
template = '{age} is younger than {age+1}'

In [47]:
template.format(age=9)

KeyError: 'age+1'
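
`str.format` treats the text inside the braces as a field name (plus attribute or index lookups), so `age+1` is looked up as a literal key rather than evaluated. One way to get computed values without a new template syntax, sketched under the assumption that derived fields are declared as plain Python callables (`format_with_computed` and `computed` are hypothetical names):

```python
def format_with_computed(template, computed=None, **fields):
    """Fill `template`, deriving extra fields from user-provided ones.

    `computed` maps a new field name to a function of the user fields.
    """
    values = dict(fields)
    for name, func in (computed or {}).items():
        # Each function receives all user fields as keyword arguments.
        values[name] = func(**fields)
    return template.format(**values)


template = '{age} is younger than {next_age}. Hi, {shout}!'
res = format_with_computed(
    template,
    computed={
        'next_age': lambda age, **_: age + 1,
        'shout': lambda name, **_: name.upper(),
    },
    age=9,
    name='casey',
)
print(res)   # 9 is younger than 10. Hi, CASEY!
```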

## V2

Trying to allow us to specify defaults directly inside the curly braces rather than having a separate defaults section in the yaml file.

In [519]:
"""WIP: revised attempt at allowing a string to provide default values to fallback on
when str.format() is called. For gpt prompts I think a real DSL would be useful though -
e.g. something like
{name} - a regular str provided by user
{age=3} - optional arg, can be user provided or not
{bio=<jabberwocky.external_data.wiki_data(name, age)>} - syntax is a wip but here the slot is filled with a python function output.
{examples=[how_to(task)]} - syntax also wip but slot is filled with another gpt completion
"""

'WIP: revised attempt at allowing a string to provide default values to fallback on\nwhen str.format() is called. For gpt prompts I think a real DSL would be useful though -\ne.g. something like\n{name} - a regular str provided by user\n{age=3} - optional arg, can be user provided or not\n{bio=<jabberwocky.external_data.wiki_data(name, age)>} - syntax is a wip but here the slot is filled with a python function output.\n{examples=[how_to(task)]} - syntax also wip but slot is filled with another gpt completion\n'

In [521]:
import inspect

In [523]:
class Prompt:
    def __init__(self, template, default_key='arg'):
        self.raw_template = template
        self.template, self.sig = self._parse(template, 
                                              default_key=default_key)
        
    def format(self, *args, **kwargs):
        resolved_kwargs = self._resolve_args(*args, **kwargs)
        return self.template.format(**resolved_kwargs)
        
    def _resolve_args(self, *args, **kwargs):
        sig = self.sig.bind(*args, **kwargs)
        sig.apply_defaults()
        return sig.arguments
        
    def _parse(self, fmt, default_key='arg'):
        fields = {}
        new_fmt = fmt
        used_unnamed = False
        used_default = False
        error_msg = 'You have both an unnamed field and ' \
                    f'a field with the default name {default_key}. ' \
                    'Either provide a name for the unnamed field or ' \
                    'rename the default one to avoid a collision.'
        # `F` was never defined in this notebook; use the stdlib parser.
        for _, field, *_ in string.Formatter().parse(fmt):
            if field is None: continue
            parts = field.split('=')
            key, val, *_ = parts + [inspect.Parameter.empty]
            if val != inspect.Parameter.empty:
                new_fmt = new_fmt.replace('='.join(parts), key)
            if key == default_key:
                if used_unnamed:
                    raise ValueError(error_msg)
                used_default = True
            if not key:
                if used_default:
                    raise ValueError(error_msg)
                else:
                    used_unnamed = True
                    key = default_key
                    new_fmt = new_fmt.replace('{}', '{' + key + '}')
            if key in fields:
                assert val == fields[key], \
                    'Each field can only have 1 default value. '\
                    f'Field "{key}" has "{val}" and "{fields[key]}".'
            fields[key] = val
        if len(fields) <= 1:
            kind = inspect.Parameter.POSITIONAL_OR_KEYWORD
        else:
            kind = inspect.Parameter.KEYWORD_ONLY        
        return new_fmt, inspect.Signature(inspect.Parameter(name=k, kind=kind, default=v)
                                          for k, v in fields.items())
                                      
    def __str__(self):
        return self.raw_template
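
The core trick in V2 is that `string.Formatter().parse` hands back `age=3` as the field name, so the default can be split off before formatting. A compact standalone sketch of just that step (`split_defaults` is a hypothetical helper; defaults stay as strings, and format specs and conversions are ignored here):

```python
import string


def split_defaults(template):
    """Return (clean_template, defaults) for fields written as {key=val}."""
    defaults = {}
    pieces = []
    for literal, field, _spec, _conv in string.Formatter().parse(template):
        # Re-escape literal braces that parse() has already unescaped.
        pieces.append(literal.replace('{', '{{').replace('}', '}}'))
        if field is None:
            continue
        key, sep, val = field.partition('=')
        if sep:
            defaults[key] = val
        pieces.append('{' + key + '}')
    return ''.join(pieces), defaults


clean, defaults = split_defaults('{name} is {age=3} years old.')
print(clean)      # {name} is {age} years old.
print(defaults)   # {'age': '3'}
print(clean.format(**{**defaults, 'name': 'casey'}))
# casey is 3 years old.
```

Rebuilding the template piece by piece avoids the global `str.replace` used in the class, which could also rewrite matching text outside the braces.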

## V3

Experimenting with jinja to see if this could be a good way to do this.

In [582]:
from jinja2 import Template, StrictUndefined, meta, Environment

In [550]:
fmt = """
This is a conversation with {{ name }}. {{ name }} is {{ age }} years old.
In 2000 {{ pronoun }} was {{ age - 22 }} years old.
""".strip()

temp = Template(fmt, undefined=StrictUndefined)
temp

<Template memory:13ab5e198>

In [539]:
with assert_raises(Exception):
    temp.render()

As expected, got Exception('age' is undefined).


In [540]:
temp.render(name='Julia', age=23, pronoun='she')

'This is a conversation with Julia. Julia is 23 years old.\nIn 2000 she was 1 years old.'

In [570]:
def generate_synonyms(n, word):
    # Mock/placeholder for something like
    # prompt_manager.query(task='generate_synonyms', 
    #                      prompt={'n': n, 'word': word})
    return '\n'.join('- ' + word[i:] for i in range(1, n + 1))

In [571]:
print(generate_synonyms(4, 'fluid'))

- luid
- uid
- id
- d


In [644]:
# Seems we need to provide default value every time an arg appears rather than
# defining it once - not ideal.
fmt = """
This is a list of {{ n | default(5) }} synonyms for the word "{{ word }}":

{{ generate_synonyms(n, word) }}

Reorder the list by similarity to the original word, with the most similar 
word coming first. {% if n > top_n|default(3) %} Keep only the {{ top_n|default(3) }} most similar synonyms. {% endif %}
""".strip()

temp = Template(fmt, undefined=StrictUndefined)
temp

<Template memory:13ab71f28>

In [645]:
temp.render()

UndefinedError: 'generate_synonyms' is undefined

In [647]:
res = temp.render(n=5, word='scrupulous',
                  generate_synonyms=generate_synonyms)
print(res)

This is a list of 5 synonyms for the word "scrupulous":

- crupulous
- rupulous
- upulous
- pulous
- ulous

Reorder the list by similarity to the original word, with the most similar 
word coming first.  Keep only the 3 most similar synonyms. 


In [648]:
# Notice there's no instruction to keep only the top n because n < top_n (2 < 3).
res = temp.render(n=2, word='scrupulous', generate_synonyms=generate_synonyms)
print(res)

This is a list of 2 synonyms for the word "scrupulous":

- crupulous
- rupulous

Reorder the list by similarity to the original word, with the most similar 
word coming first. 


In [587]:
def jinja_template_fields(fmt):
    ast = Environment().parse(fmt)
    return list(meta.find_undeclared_variables(ast))

In [588]:
jinja_template_fields(fmt)

['word', 'generate_synonyms', 'n']

**Takeaways**

Not bad, but default values work poorly (they have to be repeated every time a variable appears) and error messages could be better. This might be resolved with some custom subclasses.

## V4

Mako - another templating engine that seems to natively support more logic than jinja.

In [589]:
from mako.template import Template as MTemplate

In [685]:
mfmt = """
<%page args="n, word, top_n=3"/>
This is a list of ${n} synonyms for the word "${word}":

${generate_synonyms(n, word)}

Reorder the list by similarity to the original word, with the most similar 
word coming first. 
% if n > top_n:
Keep only the ${top_n} most similar synonyms.
% endif
""".strip()

In [686]:
mtemp = MTemplate(mfmt)

In [687]:
# Error message here is much more vague than jinja's.
res = mtemp.render()

TypeError: render_body() missing 2 required positional arguments: 'n' and 'word'

In [689]:
res = mtemp.render(word='alacrity', n=4,
                   generate_synonyms=generate_synonyms)
print(res)


This is a list of 4 synonyms for the word "alacrity":

- lacrity
- acrity
- crity
- rity

Reorder the list by similarity to the original word, with the most similar 
word coming first. 
Keep only the 3 most similar synonyms.



**Takeaways**

Having more logic support is cool but it's hard to get spacing right, which can be important for prompts, and error messages are much less informative than I'd like.