# Summary

Experimenting with making a Prompt object. Goals:
- simplify the call to resolve template + arg(s)
- allow computed values? (e.g. accept arg x and then fill another field with x+3 or x.upper())
- maybe define postprocessing/completion validation steps?

In [1]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [74]:
import importlib
import keyword
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from pathlib import Path
import string
import warnings

from jabberwocky.config import C
from jabberwocky.openai_utils import load_prompt, load_openai_api_key, GPT
from jabberwocky.utils import load_yaml
from htools import *

In [4]:
cd_root()

Current directory: /Users/hmamin/jabberwocky


In [84]:
def import_function(qualname):
    """Import and return an object given its dotted, module-qualified name.

    Parameters
    ----------
    qualname: str
        Dotted path such as 'os.path.join'. Everything before the last dot
        is imported as a module and the last component is looked up on it.

    Returns
    -------
    object: The attribute named by the last component of `qualname`
        (typically a function).

    Raises
    ------
    ValueError: If `qualname` has no module part (e.g. 'join' or '.join').
    ImportError: If the module part cannot be imported.
    AttributeError: If the module has no attribute with the given name.
    """
    module_name, _, func_name = qualname.rpartition('.')
    if not module_name:
        # importlib would raise an opaque "Empty module name" ValueError
        # here; raise the same exception type with an actionable message.
        raise ValueError(
            f'Expected a module-qualified name like "package.module.func" '
            f'but got {qualname!r}.'
        )
    module = importlib.import_module(module_name)
    return getattr(module, func_name)

In [134]:
# @delegate('template', iter_magics=True) # TODO rm?
class Prompt:
    """A str.format-style prompt template bundled with default argument
    values, optional postprocessing/validation callables, and any remaining
    keyword arguments (e.g. the kwargs forwarded to a query function).

    Attributes
    ----------
    name: str
    template: str
        The template, with any anonymous `{}` field renamed to `{arg}`.
    fields: set[str]
        Names of the arguments the template needs in order to be resolved.
    n_fields: int
        Number of distinct arguments in `fields`.
    defaults: dict
        Fallback values used by `resolve` for args the caller omits.
    postprocessors: list
        Callables applied in order by `postprocess`.
    validators: list
        Callables that must all return a truthy value in `validate`.
    kwargs: dict
        Every other keyword argument passed to the constructor.
    """

    def __init__(self, template, name='', **kwargs):
        """
        Parameters
        ----------
        template: str
            Template using str.format syntax. If it only has one
            (anonymous) field, that field is named "arg".
        name: str
            Optional name for the prompt.
        kwargs: any
            `defaults` (dict), `postprocessors` (list of callables) and
            `validators` (list of callables) are extracted and stored as
            attributes of the same name. None is treated as "not provided".
            All remaining items are stored in `self.kwargs`.

        Raises
        ------
        TypeError: If `defaults` is provided and is not a dict.
        """
        self.name = name
        defaults = kwargs.pop('defaults', None)
        if defaults is None:
            defaults = {}
        # Validate before storing anything derived from it.
        if not isinstance(defaults, dict):
            raise TypeError('`defaults` must be a dict. If your prompt only '
                            'has one arg, name it "arg".')
        self.defaults = defaults
        self.postprocessors = kwargs.pop('postprocessors', None) or []
        self.validators = kwargs.pop('validators', None) or []
        self.kwargs = kwargs
        self.template, self.fields = self.parse_fields(template)
        self.n_fields = len(self.fields)

    def postprocess(self, text):
        """Pass `text` through each postprocessor in order and return the
        final result.
        """
        for func in self.postprocessors:
            text = func(text)
        return text

    def validate(self, text):
        """Return True if every validator returns a truthy value for `text`
        (vacuously True when there are no validators).
        """
        return all(func(text) for func in self.validators)

    @classmethod
    def load(cls, name, prompt_dir=None, **kwargs):
        """Build a Prompt from a stored prompt config.

        Parameters
        ----------
        name: str
            Name of the prompt to load; also used as the Prompt's name.
        prompt_dir: str or Path or None
            Directory passed to `load_prompt`. None (the default) means
            C.root/'data/prompts', looked up at call time rather than when
            the class is defined.
        kwargs: any
            Forwarded to `load_prompt`.

        Returns
        -------
        Prompt
        """
        if prompt_dir is None:
            prompt_dir = C.root/'data/prompts'
        config = load_prompt(name, prompt_dir=prompt_dir, **kwargs)
        template = config.pop('prompt')
        for key in ('validators', 'postprocessors'):
            config[key] = cls.functions_from_config(config, key)
        return cls(template, name=name, **config)

    @staticmethod
    def functions_from_config(config, key):
        """Turn a list of function specs into a list of partials.

        Parameters
        ----------
        config: dict
            Prompt config. `config[key]`, if present, should be a list of
            dicts that each have a "name" item (a dotted name accepted by
            `import_function`) plus optional keyword arguments to bind.
        key: str
            Which item of `config` to read, e.g. 'validators'.

        Returns
        -------
        list[functools.partial]: Empty if `key` is missing or empty.
        """
        funcs = []
        for row in config.get(key) or []:
            # Work on a copy so the caller's config rows keep their "name".
            row = dict(row)
            func = import_function(row.pop('name'))
            funcs.append(partial(func, **row))
        return funcs

    def parse_fields(self, template):
        """Find the argument names a template requires.

        An anonymous field (`{}`, `{!r}`, `{:>5}`, ...) is renamed to "arg".
        Attribute/index access such as `{user.name}` or `{user[0]}` counts
        as the single argument "user".

        Parameters
        ----------
        template: str

        Returns
        -------
        tuple[str, set[str]]: The (possibly rewritten) template and the set
            of argument names it needs. The template is returned unchanged
            unless it contains an anonymous field.
        """
        parser = string.Formatter()
        fields = set()
        pieces = []
        has_anonymous = False
        for literal, field, spec, conversion in parser.parse(template):
            # parse() unescapes '{{' and '}}', so escape them again.
            pieces.append(literal.replace('{', '{{').replace('}', '}}'))
            if field is None:
                continue
            # Base argument name = everything before the first '.' or '['.
            base = field.replace('[', '.', 1).partition('.')[0]
            if not base:
                has_anonymous = True
                base = 'arg'
                field = 'arg' + field
            fields.add(base)
            suffix = ''
            if conversion:
                suffix += '!' + conversion
            if spec:
                suffix += ':' + spec
            pieces.append('{' + field + suffix + '}')
        if has_anonymous:
            template = ''.join(pieces)
        return template, fields

    def resolve(self, data=None):
        """Fill in the template.

        Parameters
        ----------
        data: str or dict or None
            A str is treated as the value of the single field "arg". A dict
            maps field names to values and takes priority over
            `self.defaults`. None means "use defaults only".

        Returns
        -------
        str: The formatted prompt.

        Raises
        ------
        TypeError: If `data` is truthy and not a str or dict.
        ValueError: If some template field has neither a default nor a
            value in `data`.

        Warns
        -----
        UserWarning: If values are supplied for names the template does not
            use.
        """
        if data and not isinstance(data, (str, dict)):
            raise TypeError(f'data should have type str, dict, or NoneType. '
                            f' Got type {type(data)} instead.')
        if isinstance(data, str):
            data = {'arg': data}
        data = {**self.defaults, **(data or {})}
        # Compare names rather than counts: the same number of wrong keys
        # must not slip through to a bare KeyError from str.format.
        missing = self.fields - set(data)
        if missing:
            raise ValueError(
                f'Your prompt takes args: {self.fields}. '
                'Between default values and arguments to resolve(), we were '
                f'only able to resolve args: {data}. Missing: {missing}.'
            )
        if set(data) - self.fields:
            warnings.warn(
                'You provided more args than your template expects. '
                f'Expected to see args: {self.fields}. Got args: {data}.'
            )
        return self.template.format(**data)

    def __repr__(self):
        return f'{type(self).__name__}(\n\tname={repr(self.name)},'\
               f'\n\tprompt={repr(self.template)},'\
               f'\n\tdefaults={self.defaults}'\
               f'\n\t{repr(self.kwargs)}\n)'

    def __str__(self):
        """Return the (unresolved) template."""
        return str(self.template)

    def __getitem__(self, key):
        """Look up one of the extra keyword arguments in `self.kwargs`."""
        return self.kwargs[key]

In [135]:
tmp = Prompt.load('tmp')
tmp

tmp: Max_tokens probably should change as a result of the prompt args but this is just for testing purposes. Should probably delete eventually.
-------------------------------------------------------------------------------



Prompt(
	name='tmp',
	prompt='Make a list of {n} names that sound {quality}:\n\n1.',
	defaults={'n': 3, 'quality': 'old'}
	{'engine': 0, 'max_tokens': 15, 'stop': ['1.', '\n1.', '\n1. ']}
)

In [136]:
tmp.postprocessors

[functools.partial(<function reattach_1 at 0x12f15c950>),
 functools.partial(<function indent at 0x12f15c6a8>, broken=False),
 functools.partial(<function swapcase at 0x12f15c8c8>)]

In [137]:
tmp.validators

[functools.partial(<function test_indented at 0x12f15c840>, mode='all')]

In [138]:
tmp.kwargs

{'engine': 0, 'max_tokens': 15, 'stop': ['1.', '\n1.', '\n1. ']}

In [139]:
tmp.resolve()

'Make a list of 3 names that sound old:\n\n1.'

In [141]:
tmp.resolve({'n': 9})

'Make a list of 9 names that sound old:\n\n1.'

In [142]:
tmp.resolve({'quality': 'happy'})

'Make a list of 3 names that sound happy:\n\n1.'

In [99]:
# Query the "openai" backend with the resolved prompt, forwarding the extra
# kwargs stored on the Prompt (engine, max_tokens, stop) to GPT.query.
with GPT('openai'):
    res = GPT.query(tmp.resolve({'n': 3, 'quality': 'spacy'}), 
                    **tmp.kwargs)
print(res)

Switching openai backend to "openai".
{'engine': 0, 'max_tokens': 15, 'stop': ['1.', '\n1.', '\n1. '], 'prompt': 'Make a list of 3 names that sound spacy:\n\n1.', 'meta': {'backend_name': 'openai', 'query_func': 'query_gpt3', 'datetime': 'Fri May 13 22:03:46 2022'}}
Switching  backend back to "openai".
(['invincibility\n2. invincibility\n3. invincibility'], [{'text': ' invincibility\n2. invincibility\n3. invincibility', 'index': 0, 'logprobs': None, 'finish_reason': 'length', 'prompt_index': 0}])


In [127]:
postprocessed = tmp.postprocess(res[0][0])
print(postprocessed)

	1. INVINCIBILITY
	2. INVINCIBILITY
	3. INVINCIBILITY


In [128]:
tmp.validate(postprocessed)

True

In [37]:
config = load_prompt('test_professional')
config

test_professional: This doesn't actually work very well but the settings are a good example of what is needed for my envisioned natural language tests.
-------------------------------------------------------------------------------



{'engine': 1,
 'temperature': 0.0,
 'max_tokens': 1,
 'presence_penalty': 2,
 'frequency_penalty': 2,
 'logprobs': 2,
 'stop': ['\n'],
 'logit_bias': {3763: 100, 645: 100},
 'validation': [{'name': 'tmp.is_yes_or_no', 'strip': True, 'lower': False},
  {'name': 'requests.get'}],
 'prompt': 'Does the following email maintain a professional tone? (Yes/No)\n\nEmail:\n{} \nAnswer:'}

In [45]:
template = '{age} is younger than {age+1}'

In [47]:
# str.format() does not evaluate expressions inside braces: 'age+1' is looked
# up as a field name, which is why this raises KeyError.
template.format(age=9)

KeyError: 'age+1'