# Summary

ConversationManager refactor attempt. Trying to change its interface so it can more effectively:
1. Support longer conversations via prompting with a subset of past responseses,
2. Support longer conversations via summarizing past conv, and
3. Still work with my GUI.

In [1]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
from itertools import zip_longest
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from pathlib import Path

from jabberwocky.config import C
from jabberwocky.openai_utils import load_prompt, load_openai_api_key
from htools import *

In [3]:
cd_root()

Current directory: /Users/hmamin/jabberwocky


In [24]:
class ProtoConversationManager:
    """Similar to PromptManager but designed for ongoing conversations. This
    currently references just a single prompt: conversation.
    """

    img_exts = {'.jpg', '.jpeg', '.png'}

    def __init__(self, *names, verbose=True, data_dir='./data',
                 backup_image='data/misc/unknown_person.png', 
                 turn_window=4):
        assert 1 <= turn_window <= 20, 'turn_window should be in [1, 20].'
        
        # User window is adjusted because we'll be adding in the user's new
        # turn separately from accessing their historical turns.
        self.user_turn_window = int(np.ceil(turn_window / 2)) - 1
        self.gpt3_turn_window = turn_window - self.user_turn_window
        self.verbose = verbose
        
        # Set directories for data storage, logging, etc.
        self.backup_image = Path(backup_image)
        self.data_dir = Path(data_dir)
        self.persona_dir = self.data_dir/'conversation_personas'
        self.conversation_dir = self.data_dir/'conversations'
        self.log_dir = self.data_dir/'logs'
        self.log_path = Path(self.log_dir)/'conversation_query_kwargs.json'
        for dir_ in (self.persona_dir, self.conversation_dir, self.log_dir):
            os.makedirs(dir_, exist_ok=True)

        # These attributes will be updated when we load a persona and cleared
        # when we end a conversation. current_persona is the processed name
        # (i.e. lowercase w/ underscores).
        self.current_persona = ''
        self.current_summary = ''
        self.current_img_path = ''
        self.current_gender = ''
        self.full_conv = ''
        self.cached_query = ''
        self.user_turns = []
        self.gpt3_turns = []

        # Load prompt, default query kwargs, and existing personas.
        self._kwargs = load_prompt('conversation')
        self._base_prompt = self._kwargs.pop('prompt')

        # Populated by _load_personas().
        self.name2img_path = {}
        self.name2base = {}
        self.name2gender = {}
        self._load_personas(names)

    def _load_personas(self, names):
        names = names or [path.stem for path in self.persona_dir.iterdir()]
        for name in names:
            print(name)
            try:
                self.update_persona_dicts(self.process_name(name))
            except:
                print('exc')
                warnings.warn(f'Could not load files for {name}.')

    def start_conversation(self, name, download_if_necessary=False):
        if name not in self:
            if not download_if_necessary:
                raise KeyError(f'{name} persona not available. You can set '
                               'download_if_necessary=True if you wish to '
                               'construct a new persona.')
            _ = self.add_persona(name, return_data=True)
        self.end_conversation()

        processed_name = self.process_name(name)
        self.current_persona = processed_name
        self.full_conv = self.name2base[processed_name]
        self.current_img_path = self.name2img_path[processed_name]
        self.current_gender = self.name2gender[processed_name]
        # This one is not returned. Info would be a bit repetitive.
        self.current_summary = self._name2summary(processed_name)
        return (self.current_persona,
                self.full_conv,
                self.current_img_path,
                self.current_gender)

    def _name2summary(self, name):
        if '_' not in name: name = self.process_name(name)
        base = self.name2base[name]
        intro = sent_tokenize(base)[0]
        return base.replace(intro, '').strip()

    def end_conversation(self, fname=None):
        if fname: self.save_conversation(fname)
        self.full_conv = ''
        self.current_summary = ''
        self.current_persona = ''
        self.current_img_path = ''
        self.current_gender = ''
        self.cached_query = ''
        self.user_turns.clear()
        self.gpt3_turns.clear()

    def save_conversation(self, fname):
        if not self.full_conv:
            raise RuntimeError('No conversation to save.')
        save(self.full_conv, self.conversation_dir/fname)

    def add_persona(self, name, return_data=False):
        processed_name = self.process_name(name)
        dir_ = self.persona_dir/processed_name
        if dir_.exists():
            summary, img_path, gender = self.update_persona_dicts(
                processed_name, return_values=True
            )
        else:
            summary, _, img_path, gender = wiki_data(
                name, img_dir=self.persona_dir/processed_name, fname='profile'
            )
            save(summary, dir_/'summary.txt')
            save(gender, dir_/'gender.json')

            # Otherwise it's an empty string if we fail to download an image.
            if not img_path:
                img_path = dir_/f'profile{self.backup_image.suffix}'
                shutil.copy2(self.backup_image, img_path)
            self.update_persona_dicts(processed_name)
        if return_data: return summary, img_path, gender

    def update_persona_dicts(self, processed_name, return_values=False):
        dir_ = self.persona_dir/processed_name
        summary = load(dir_/'summary.txt')
        self.name2gender[processed_name] = load(dir_/'gender.json')
        self.name2img_path[processed_name] = [p for p in dir_.iterdir()
                                              if p.stem == 'profile'][0]
        self.name2base[processed_name] = self._base_prompt.format(
            name=self.process_name(processed_name, inverse=True),
            summary=summary
        )
        if return_values:
            return Results(summary=summary,
                           img_path=self.name2img_path[processed_name],
                           gender=self.name2gender[processed_name])

    def process_name(self, name, inverse=False):
        if inverse:
            return name.replace('_', ' ').title()
        return name.lower().replace(' ', '_').replace('.', '')

    def personas(self, pretty=True, sort=True):
        names = list(self.name2base)
        if pretty: names = [self.process_name(name, True) for name in names]
        if sort: names = sorted(names)
        return names

    def kwargs(self, name='', fully_resolved=True, return_prompt=False,
               extra_kwargs=None, **kwargs):
        # Name param should be pretty version, i.e. no underscores. Only
        # needed when return_prompt is True.
        if 'prompt' in kwargs:
            raise RuntimeError(
                'Arg "prompt" should not be in query kwargs. It will be '
                'constructed within this method and passing it in will '
                'override the new version.'
            )
        kwargs = {**self._kwargs, **kwargs}
        for k, v in (extra_kwargs or {}).items():
            v_cls = type(v)
            # Make a new object instead of just using get() or setdefault
            # since the latter two methods both mutate our default kwargs.
            curr_val = v_cls(kwargs.get(k, v_cls()))
            if isinstance(v, Iterable):
                curr_val.extend(v)
            elif isinstance(v, Mapping):
                curr_val.update(v)
            else:
                raise TypeError(f'Key {k} has unrecognized type {v_cls} in '
                                '`extra_kwargs`.')
            kwargs[k] = curr_val

        if fully_resolved: kwargs = dict(bound_args(query_gpt3, [], kwargs))
        if name and return_prompt:
            kwargs['prompt'] = self.name2base[self.process_name(name)]
        return kwargs

    def query_later(self, text):
        self.cached_query = text.strip()
        
    def query(self, text=None, debug=False, extra_kwargs=None, **kwargs):
        if not self.current_persona:
            raise RuntimeError('You must call the `start_conversation` '
                               'method before making a query.')
        
        # In the same spirit as our handling of kwargs here, passing in a text
        # arg will override a cached query if one exists.
        text = text or self.cached_query
        self.cached_query = ''
        kwargs = self.kwargs(fully_resolved=False, return_prompt=False,
                             extra_kwargs=extra_kwargs, **kwargs)
        prompt = self.format_prompt(user_text=text)
        if debug:
            print('prompt:\n' + prompt)
            print(spacer())
            print('kwargs:\n', kwargs)
            print(spacer())
            print('fully resolved kwargs:\n',
                  dict(bound_args(query_gpt3, [], kwargs)))
            return
        
        # Update this after format_prompt() call and debug check.
        self.user_turns.append(text.strip())
        save({'prompt': prompt, **kwargs}, self.log_path, verbose=False)
        prompt, resp = query_gpt3(prompt, **kwargs)
        self.gpt3_turns.append(text.strip())
        # GPT3 prefers prompts that don't end with spaces and query_gpt3()
        # strips output, but we want a space after the colon.
        self.full_conv = prompt + ' ' + resp
        return prompt, resp

    def format_prompt(self, user_text, exclude_trailing_name=False):
        if not self.full_conv:
            raise RuntimeError('Conversation history is empty. Have you '
                               'started a conversation?')
        user_turns = (self.user_turns[-self.user_turn_window:]
                      + [user_text.strip()])
        gpt3_turns = self.gpt3_turns[-self.gpt3_turn_window]
        user_turns = [f'Me: {turn}' for turn in user_turns]
        gpt3_turns = [f'{self.current_persona}: {turn}'
                      for turn in gpt3_turns]
        prompt = self.current_summary + \
            '\n\n'.join(flatten(zip_longest(user_turns, gpt3_turns)))
        # TODO: still need to check if this works as expected.
        if exclude_trailing_name: return prompt
        return f'{prompt}\n\n{self.process_name(self.current_persona, True)}:'

    @contextmanager
    def converse(self, name, fname='', download_if_necessary=False):
        try:
            _ = self.start_conversation(name, download_if_necessary)
            yield
        finally:
            self.end_conversation(fname=fname)

    @staticmethod
    def format_conversation(text, gpt_color='black'):
        def _format(line, color='black'):
            if not line: return line
            name, _, line = line.partition(':')
            # Bold's stop character also resets color so we need to color the
            # chunks separately.
            return colored(bold(name + ':'), color) + colored(line, color)

        if listlike(text): text = ' '.join(text)
        summary, *lines = text.splitlines()
        name = [name for name, n in
                Counter(line.split(':')[0]
                        for line in lines if ':' in line).most_common(2)
                if name != 'Me'][0]
        formatted_lines = [bold(summary)]
        prev_is_me = True
        for line in lines:
            if line.startswith(name + ':'):
                line = _format(line, gpt_color)
                prev_is_me = False
            elif line.startswith('Me: ') or prev_is_me:
                line = _format(line)
                prev_is_me = True
            formatted_lines.append(line)
        return '\n'.join(formatted_lines)

    def __contains__(self, name):
        return self.process_name(name) in self.name2base

    def __len__(self):
        return len(self.name2base)

In [22]:
a = ['one', 'two', 'three']
b = ['a', 'b']
tmp = '\n'.join(filter(None, flatten(zip_longest(a, b))))
print(tmp)

one
a
two
b
three


In [23]:
'\n'.join(filter(None, flatten(zip_longest(a, b + ['c']))))
print(tmp)

one
a
two
b
three
