In [None]:
%load_ext autoreload
%autoreload 2
import os
import re
import sys
from copy import deepcopy
from os.path import join
from pprint import pprint
from typing import List

import imodelsx.util
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import openai
import pandas as pd
import scipy.special
import seaborn as sns
from IPython.display import display, HTML
from spacy.lang.en import English
from spacy.tokenizer import Tokenizer
from tqdm import tqdm

import mprompt.llm
import mprompt.viz
import notebook_helper
from mprompt.methods.m4_evaluate import D5_Validator
from mprompt.modules.emb_diff_module import EmbDiffModule
from mprompt.modules.fmri_module import fMRIModule
openai.api_key_path = os.path.expanduser('~/.OPENAI_KEY')


# Load the per-voxel fMRI results, best synthetic-story score first, and tag each
# row with a readable id of the form
#   ('<first 20 chars of the explanation, spaces -> _>', '<subject>', <module_num>)
r = (
    pd.read_pickle('../results/results_fmri.pkl')
    .sort_values(by=['top_score_synthetic'], ascending=False)
    .assign(id=lambda df: (
        "('"
        + df['top_explanation_init_strs'].str.replace(' ', '_').str.slice(stop=20)
        + "', '"
        + df['subject']
        + "', "
        + df['module_num'].astype(str)
        + ")"
    ))
)

### Select voxels

In [None]:
# Manually inspect the best-scoring voxels (by synthetic-story score) to pick a few.
with pd.option_context('display.max_rows', None, 'display.max_colwidth', 200):
    display(
        r.sort_values(by=['top_score_synthetic'], ascending=False)
        .loc[:, ['top_explanation_init_strs', 'subject', 'module_num', 'top_score_synthetic',
                 'frac_top_ngrams_module_correct', 'id', 'top_ngrams_module_correct']]
        .round(3)
        .reset_index(drop=True)
        .head(50)
    )

# Earlier candidate selections, kept for reference:
# expls = ['baseball', 'animals', 'water', 'movement', 'religion', 'time', 'technology']
# interesting_expls = ['food', 'numbers', 'physical contact', 'time', 'laughter', 'age', 'clothing']
# voxels = [('movement', 'UTS01', 7), ('numbers', 'UTS03', 55), ('time', 'UTS03', 19),
#           ('relationships', 'UTS01', 21), ('sounds', 'UTS03', 35), ('emotion', 'UTS03', 23),
#           ('food', 'UTS03', 46)]
# voxels = [('numbers', 'UTS03', 55), ('time', 'UTS03', 19),
#           ('sounds', 'UTS03', 35), ('emotion', 'UTS03', 23), ('food', 'UTS03', 46)]
# voxels = [('movement', 'UTS01', 7), ('relationships', 'UTS01', 21), ('passing of time', 'UTS02', 4)]

# Voxels used below, each as (short explanation label, subject, module_num).
voxels = [('relationships', 'UTS02', 9), ('time', 'UTS02', 4), ('looking or staring', 'UTS03', 57)]

In [7]:
# Put all selected-voxel data into the `rows` DataFrame: one results row per voxel
# in `voxels`, tagged with the short label from the voxel tuple in column 'expl'.
# Voxels that have no matching (subject, module_num) in `r` are reported and skipped.
selected_rows = []
expls = []
for vox in voxels:
    expl, subj, vox_num = vox
    matches = r[(r.subject == subj) & (r.module_num == vox_num)]
    if matches.empty:
        # Only "voxel not found" is skipped. Any other error (e.g. a missing column)
        # propagates instead of being swallowed by a bare `except:`.
        print('skipping', vox)
        continue
    selected_rows.append(matches.iloc[0])
    expls.append(expl)
rows = pd.DataFrame(selected_rows)
rows['expl'] = expls

with pd.option_context('display.max_rows', None, 'display.max_columns', None, 'display.max_colwidth', 200):
    display(rows[['subject', 'module_num', 'expl', 'top_explanation_init_strs', 'top_ngrams_module_correct']])

Unnamed: 0,subject,module_num,expl,top_explanation_init_strs,top_ngrams_module_correct
80,UTS02,9,relationships,relationships and milestones in life,"[boyfriend of six, dating for months, boyfriend a year, married for fifteen, we got engaged, we were engaged, twenty he retired, a divorce twenty, virginity at twenty, am turning forty, daughter i..."
53,UTS02,4,time,passing of time,"[weeks became months, weekends became weeks, five years four, the moment passed, twenty minutes thirty, more time passed, replied age thirty, moment passed, later came the, days later diagnosed, m..."
41,UTS03,57,looking or staring,looking or staring in some way,"[eyed her suspiciously, at him incredulously, wink at, at me shyly, locks eyes with, staring at me, turned and saw, and mimed crying, incredulously like look, staring right at, point at a, leered ..."


# Generate story

In [22]:
# Prompt templates, keyed by version. Each version has:
#   prefix_first - opening instruction for the first paragraph
#   prefix_next  - instruction for every following paragraph
#   suffix       - appended to either prefix; formatted with {expl}
#                  (and {examples} for 'v2')
PROMPTS = {
    # plain story
    'v0': dict(
        prefix_first='Write the beginning paragraph of a story about',
        prefix_next='Write the next paragraph of the story, but now make it about',
        suffix=' "{expl}". Make sure it contains several references to "{expl}".',
    ),
    # first-person story
    'v1': dict(
        prefix_first='Write the beginning paragraph of a story told in first person. The story should be about',
        prefix_next='Write the next paragraph of the story, but now make it about',
        suffix=' "{expl}". Make sure it contains several references to "{expl}".',
    ),
    # first-person story, with example n-grams appended
    'v2': dict(
        prefix_first='Write the beginning paragraph of a story told in first person. The story should be about',
        prefix_next='Write the next paragraph of the story, but now make it about',
        suffix=' "{expl}". Make sure it contains several references to "{expl}", such as {examples}.',
    ),
}


# Prompt template version to use (key into PROMPTS).
version = 'v2'
PV = PROMPTS[version]

# Per-voxel inputs for the prompts: the short explanation label, and the first 3
# top n-grams rendered as a comma-separated list of double-quoted strings.
expls = rows.expl.values
examples = rows.top_ngrams_module_correct.apply(
    lambda ngrams: ', '.join(f'"{ngram}"' for ngram in ngrams[:3])
).values

def get_prompts_basic(expls, ngram_examples=None, prompt_template=None):
    """Build one prompt per explanation, chained into a multi-paragraph story.

    The first prompt is ``prefix_first + suffix`` and every later prompt is
    ``prefix_next + suffix``. Each prompt is formatted with ``expl`` (the
    explanation) and ``examples`` (the example n-gram string for that
    explanation). ``str.format`` ignores unused keyword arguments, so templates
    without ``{examples}`` (e.g. 'v0', 'v1') work without per-version branching.

    Args:
        expls: sequence of explanation strings, one per paragraph.
        ngram_examples: sequence of example strings aligned with ``expls``. It is
            only read when the template contains ``{examples}``. Defaults to the
            notebook-level ``examples``.
        prompt_template: dict with 'prefix_first', 'prefix_next' and 'suffix'
            keys. Defaults to the notebook-level ``PV``.

    Returns:
        list of prompt strings, one per explanation (empty if ``expls`` is empty).

    Raises:
        ValueError: if the template uses ``{examples}`` and ``ngram_examples``
            does not have exactly one entry per explanation.
    """
    if prompt_template is None:
        prompt_template = PV
    prompt_init = prompt_template['prefix_first'] + prompt_template['suffix']
    prompt_continue = prompt_template['prefix_next'] + prompt_template['suffix']
    expls = list(expls)

    if '{examples}' in prompt_init + prompt_continue:
        if ngram_examples is None:
            ngram_examples = examples
        ngram_examples = list(ngram_examples)
        # zip() would silently drop prompts on a length mismatch and misalign
        # every downstream heatmap, so fail loudly instead.
        if len(ngram_examples) != len(expls):
            raise ValueError(
                f'got {len(expls)} explanations but {len(ngram_examples)} example strings'
            )
    else:
        ngram_examples = [None] * len(expls)

    return [
        (prompt_init if i == 0 else prompt_continue).format(expl=expl, examples=ex)
        for i, (expl, ex) in enumerate(zip(expls, ngram_examples))
    ]
# Build the prompt chain for the selected explanations and print each prompt.
prompts = get_prompts_basic(expls)
for p in prompts:
    print(p)

Write the beginning paragraph of a story told in first person. The story should be about "relationships". Make sure it contains several references to "relationships", such as "boyfriend of six", "dating for months", "boyfriend a year".
Write the next paragraph of the story, but now make it about "time". Make sure it contains several references to "time", such as "weeks became months", "weekends became weeks", "five years four".
Write the next paragraph of the story, but now make it about "looking or staring". Make sure it contains several references to "looking or staring", such as "eyed her suspiciously", "at him incredulously", "wink at".


In [28]:
# Generate one story paragraph per prompt with the LLM helper.
# prefix_first / prefix_next are passed so get_paragraphs can replace them in the
# prompts once the story gets long (presumably to keep the context short — TODO confirm).
# NOTE: `mprompt.llm` is now imported explicitly in the import cell instead of
# relying on another mprompt module having imported it as a side effect.
paragraphs = mprompt.llm.get_paragraphs(
    prompts,
    prefix_first=PV['prefix_first'],
    prefix_next=PV['prefix_next'],
)
for para in paragraphs:
    pprint(para)

cached!
cached!
cached!
("I've always been fascinated by relationships. The way two people can come "
 "together and form a bond that lasts a lifetime is truly remarkable. I've had "
 "my fair share of relationships, some good and some bad. Currently, I'm "
 "dating my boyfriend of six months and things are going great. We've been "
 "through a lot together in such a short amount of time, but we've managed to "
 'come out stronger on the other side. Before him, I was in a relationship for '
 'a year that ended in heartbreak. But I learned so much from that experience '
 "and it's helped me appreciate the relationship I have now even more.")
('Time is a funny thing when it comes to relationships. It can feel like weeks '
 "become months and weekends become weeks when you're with the right person. "
 "But when you're in a toxic relationship, time can feel like it's standing "
 "still. I've been in both situations and I can say with certainty that time "
 'is a crucial factor in any relat

# Test whether the synthetic story contains each concept

In [None]:
# Validator passed to the story visualization and the explanation/data match heatmap below.
# NOTE(review): D5_Validator comes from mprompt.methods.m4_evaluate; presumably it scores
# how well a text matches an explanation — TODO confirm.
val = D5_Validator()

In [None]:
# Visualize the single generated story as HTML (paragraphs, prompts, explanations,
# scored with `val`) and show it inline. `fname` presumably sets where a copy of the
# HTML is saved — TODO confirm in mprompt.viz.
s = mprompt.viz.visualize_story_html(
    val,
    expls,
    paragraphs,
    prompts,
    fname='../results/story_running.html',
)
display(HTML(s))

In [None]:
# compute scores heatmap: score each explanation in `expls` against each generated
# paragraph using the validator `val`.
# NOTE(review): presumably a 2-D array (paragraph x explanation) — TODO confirm in notebook_helper.
scores = notebook_helper.compute_expl_data_match_heatmap(val, expls, paragraphs)

In [None]:
# Heatmap of the explanation-vs-paragraph match scores from the cell above.
# Set `normalize` to 'softmax' (per row, axis=1) or 'minmax' (global) to rescale
# before plotting (minmax assumes `scores` is a NumPy array). The plotted array is
# `s`, so the chosen option actually takes effect.
normalize = None
if normalize is None:
    s = scores
elif normalize == 'softmax':
    s = scipy.special.softmax(scores, axis=1)
elif normalize == 'minmax':
    s = (scores - scores.min()) / (scores.max() - scores.min())
else:
    raise ValueError(f'unknown normalize option: {normalize!r}')
mprompt.viz.heatmap(s, expls)

# Test modules on the generated stories

In [None]:
# Evaluate each selected voxel's module on each generated paragraph.
# NOTE: this rebinds `scores`, replacing the explanation/data match scores above.
# Voxels that were skipped when building `rows` must be dropped here too. Otherwise
# `voxels` would be longer than, and misaligned with, `expls` / `paragraphs`.
kept_keys = set(zip(rows['subject'], rows['module_num']))
voxels_kept = [vox for vox in voxels if (vox[1], vox[2]) in kept_keys]
assert len(voxels_kept) == len(expls), 'selected voxels and explanations are misaligned'
scores, scores_max, all_scores, all_ngrams = notebook_helper.compute_expl_module_match_heatmap(expls, paragraphs, voxels_kept)

In [None]:
# Heatmap of voxel module responses to the generated paragraphs.
# Set `normalize` to 'softmax' (per row, axis=1) or 'minmax' (global) to rescale
# before plotting (minmax assumes `scores` is a NumPy array). The plotted array is
# `s`, so the chosen option actually takes effect. The default (None) plots the raw
# mean voxel responses, matching the colorbar label.
normalize = None
if normalize is None:
    s = scores
elif normalize == 'softmax':
    s = scipy.special.softmax(scores, axis=1)
elif normalize == 'minmax':
    s = (scores - scores.min()) / (scores.max() - scores.min())
else:
    raise ValueError(f'unknown normalize option: {normalize!r}')
clab = 'Mean voxel response' if normalize is None else f'Mean voxel response ({normalize}-normalized)'
mprompt.viz.heatmap(s, expls, xlab='Explanation of voxel used for evaluation', clab=clab)

### Module responses for a single story

In [None]:
pd.set_option('display.max_rows', 120)

# Inspect a single selected voxel: `i` indexes into `rows`. The globals `row` and `mod`
# set here are reused by the cells below. To inspect every voxel, loop over
# range(len(rows)). Use len(rows), not len(voxels), since skipped voxels are not in `rows`.
i = 0
row = rows.iloc[i]
display(row[['subject', 'module_num', 'top_explanation_init_strs', 'explanation_init_ngrams', 'top_ngrams_module_correct']])
mod = fMRIModule(voxel_num_best=row.module_num, subject=row.subject)
# Optional: top-scoring n-grams for this voxel (all_scores / all_ngrams come from
# compute_expl_module_match_heatmap above):
# display(
#     pd.DataFrame.from_dict({
#         'score': all_scores[i][i],
#         'ngram': all_ngrams[i][i],
#     }).sort_values('score', ascending=False).head(10)
# )

In [None]:
# Run the selected voxel's module (`mod`, built from `row` above) on this row's
# `explanation_init_ngrams`.
# NOTE(review): presumably returns one predicted response per n-gram — TODO confirm.
x = row['explanation_init_ngrams']
p = mod(x)


In [None]:
# Run the module again on the same n-grams with return_all=True, and pull this row's
# stored `explanation_init_outputs` (presumably for comparison — TODO confirm).
# NOTE(review): this rebinds the global `scores`, clobbering the module-match heatmap
# scores. Re-running the heatmap cell after this one would plot the wrong data;
# consider a distinct name.
out = mod(x, return_all=True)
scores = row['explanation_init_outputs']

In [None]:
# Take column `row.module_num` of the return_all output. This assumes `out` is 2-D
# with one column per voxel — TODO confirm.
# NOTE(review): this rebinds `p` from the single-voxel call above. If the intent is a
# consistency check, compare the two explicitly (e.g. np.testing.assert_allclose).
p = out[:, row.module_num]