In [56]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
import seaborn as sns
from os.path import join
from tqdm import tqdm
import pandas as pd
import sys
from IPython.display import display, HTML
from typing import List
from mprompt.modules.emb_diff_module import EmbDiffModule
import numpy as np
import matplotlib
import imodelsx.util
from copy import deepcopy
import re
import notebook_helper
import mprompt.viz
import scipy.special
from spacy.tokenizer import Tokenizer
from spacy.lang.en import English
from mprompt.methods.m4_evaluate import D5_Validator
import openai
from mprompt.modules.fmri_module import fMRIModule
from pprint import pprint
import joblib
import viz
from mprompt.config import RESULTS_DIR
import torch.cuda
import json
# Read the OpenAI key from a file in the home directory rather than hardcoding it in the notebook.
openai.api_key_path = os.path.expanduser('~/.OPENAI_KEY')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Option 1 - select rows corresponding to the Huth 2016 categories (`huth2016clusters.json`)

In [57]:
def get_rows_huth(path='huth2016clusters.json'):
    """Build one row per category from the Huth 2016 cluster file.

    Params
    ------
    path: str
        JSON file mapping each category name to its examples
        (presumably a list of example words/ngrams -- TODO confirm against the file).

    Returns
    -------
    rows: pd.DataFrame
        Columns 'expl' (category name) and 'top_ngrams_module_correct' (that
        category's examples), in file order. These are the same column names
        get_rows_voxels provides, so the story cells can consume either.
    """
    # context manager so the file handle is closed deterministically
    with open(path, 'r') as f:
        huth2016_categories = json.load(f)
    # materialize keys/values as lists instead of handing dict views to pandas
    return pd.DataFrame({
        'expl': list(huth2016_categories.keys()),
        'top_ngrams_module_correct': list(huth2016_categories.values()),
    })

### Option 2 - select rows from manually picked fitted voxels

In [58]:
def get_rows_voxels(results_path='../results/results_fmri.pkl', voxels=None, verbose=True):
    """Load fitted fMRI voxel results and select a manually picked set of voxels.

    Params
    ------
    results_path: str
        Pickled results DataFrame. It must contain the columns 'subject', 'module_num',
        'top_score_synthetic', 'top_explanation_init_strs',
        'frac_top_ngrams_module_correct' and 'top_ngrams_module_correct'.
    voxels: list of (expl, subject, module_num) tuples, optional
        Voxels to select. `expl` becomes the 'expl' value of the selected row.
        Defaults to the manually picked set below.
    verbose: bool
        If True, display the 50 top-scoring voxels and the selected rows.

    Returns
    -------
    rows: pd.DataFrame
        One row per voxel found in the results, in `voxels` order, plus an 'expl'
        column. Voxels with no matching row are reported and skipped. If nothing
        matches, the result is empty but keeps all columns.
    """
    if voxels is None:
        # manually picked voxels
        voxels = [('relationships', 'UTS02', 9), ('time', 'UTS02', 4), ('looking or staring', 'UTS03', 57),
                  ('food and drinks', 'UTS01', 52), ('hands and arms', 'UTS01', 46)]

    r = (pd.read_pickle(results_path)
         .sort_values(by=['top_score_synthetic'], ascending=False))
    # readable identifier in the same (expl, subject, module_num) form as `voxels`
    r['id'] = ("('" + r['top_explanation_init_strs'].str.replace(' ', '_').str.slice(stop=20)
               + "', '" + r['subject'] + "', " + r['module_num'].astype(str) + ")")

    if verbose:
        # r is already sorted by top_score_synthetic; shown to help pick voxels manually
        with pd.option_context('display.max_rows', None, 'display.max_colwidth', 200):
            display(r[
                ['top_explanation_init_strs', 'subject', 'module_num', 'top_score_synthetic',
                 'frac_top_ngrams_module_correct', 'id', 'top_ngrams_module_correct']
            ].round(3).reset_index(drop=True).head(50))

    # take the first (i.e. highest-scoring, since r is sorted) matching row per voxel
    positions = []
    expls = []
    for vox in voxels:
        expl, subj, vox_num = vox
        hits = ((r.subject == subj) & (r.module_num == vox_num)).to_numpy().nonzero()[0]
        if len(hits) == 0:
            # explicit check instead of a bare `except:` that would hide unrelated errors
            print('skipping', vox)
            continue
        positions.append(hits[0])
        expls.append(expl)

    # positional selection keeps the original dtypes, and all columns even when nothing matched
    rows = r.iloc[positions].assign(expl=expls)

    if verbose:
        with pd.option_context('display.max_rows', None, 'display.max_columns', None, 'display.max_colwidth', 200):
            display(rows[['subject', 'module_num', 'expl', 'top_explanation_init_strs', 'top_ngrams_module_correct']])

    return rows

# Generate story

In [62]:
# Option 1 is active: prompts are built from the Huth 2016 categories.
# To use the manually picked voxels instead (Option 2), swap in:
#     EXPT_NAME = 'relationships_mar9'
#     rows = get_rows_voxels()
EXPT_NAME = 'huth2016clusters_mar21_i_time_traveled'
version = 'v4'
rows = get_rows_huth()

# one prompt per explanation; n_examples=4 example ngrams are quoted in each (see output)
expls = rows['expl'].values
examples_list = rows['top_ngrams_module_correct']
prompts = notebook_helper.get_prompts(
    expls,
    examples_list,
    version,
    n_examples=4,
)
for p in prompts:
    print(p)
PV = notebook_helper.get_prompt_templates(version)

Write the beginning paragraph of an interesting story told in first person. The story should place a heavy focus on temporal words. Make sure it contains several references to temporal words, such as "travel", "minute", "leave", "date".
Write the next paragraph of the story, but now make it emphasize abstract words. Make sure it contains several references to abstract words, such as "natural", "roots", "delicate", "exaggerated".
Write the next paragraph of the story, but now make it emphasize professional words. Make sure it contains several references to professional words, such as "meetings", "owner", "worker", "office".
Write the next paragraph of the story, but now make it emphasize visual words. Make sure it contains several references to visual words, such as "yellow", "fur", "silver", "badge".
Write the next paragraph of the story, but now make it emphasize violent words. Make sure it contains several references to violent words, such as "lethal", "instantly", "breath", "kill".


In [63]:
# Generate one paragraph per prompt. The first prompt asks for the beginning of the
# story and the rest for the next paragraphs; PV supplies the matching prefixes.
# Explicit import: the import cell only loads mprompt.viz/modules/methods/config,
# so `mprompt.llm` would otherwise resolve only through some transitive import.
import mprompt.llm

# list(...) so the result can be both stored in `rows` and printed, even if it is an iterator
paragraphs = list(mprompt.llm.get_paragraphs(prompts, prefix_first=PV['prefix_first'], prefix_next=PV['prefix_next']))
assert len(paragraphs) == len(prompts), f'expected {len(prompts)} paragraphs, got {len(paragraphs)}'
rows['prompt'] = prompts
rows['paragraph'] = paragraphs
for para in paragraphs:
    pprint(para)

cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
("I've always been fascinated by the concept of time. It's a strange and "
 'elusive thing, always moving forward, never stopping or slowing down. As a '
 'child, I would spend hours staring at the clock on my bedroom wall, watching '
 'the seconds tick by. As I grew older, my fascination with time only '
 'deepened. So when I received an invitation to travel back in time for just '
 "one minute, I couldn't resist. The date was set, and as the moment "
 'approached for me to leave my present life behind and step into the past, my '
 'heart raced with excitement and anticipation.')
("As I stepped into the time machine, I couldn't help but feel a sense of "
 'natural curiosity and wonder. The roots of my fascination with time ran '
 'deep, and I knew that this delicate moment would be one that I would never '
 'forget. The exaggerated stories of time travel that I had heard as a child '
 'sudd

In [64]:
# Persist this experiment's rows (prompts + generated paragraphs) under RESULTS_DIR/stories.
STORIES_DIR = os.path.join(RESULTS_DIR, 'stories')
os.makedirs(STORIES_DIR, exist_ok=True)
joblib.dump(rows, os.path.join(STORIES_DIR, EXPT_NAME + '_rows.pkl'))

['/home/chansingh/mprompt/results/stories/huth2016clusters_mar21_i_time_traveled_rows.pkl']