In [81]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('..')
sys.path.append('../experiments')
import seaborn as sns
import os
import pandas as pd
from copy import deepcopy
from matplotlib import pyplot as plt
from os.path import join
import numpy as np
import imodelsx.process_results
from neuro.features import qa_questions, feature_spaces
from neuro.data import story_names
from neuro.features.stim_utils import load_story_wordseqs, load_story_wordseqs_huge
import random
import json
import neuro.config
import joblib
from tqdm import tqdm
from collections import defaultdict
fit_encoding = __import__('02_fit_encoding')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [84]:
# Load every story's word sequence, then flatten all per-word ngrams (10 words
# of context) into one list. `wordseq_idxs[story]` records the [start, end)
# slice of `ngrams_list_total` that belongs to each story.
story_names_list = sorted(story_names.get_story_names(all=True))
print('loaded', len(story_names_list), 'stories')
wordseqs = load_story_wordseqs_huge(story_names_list)

wordseq_idxs = {}
ngrams_list_total = []
for story in story_names_list:
    story_ngrams = feature_spaces.get_ngrams_list_main(
        wordseqs[story], num_ngrams_context=10)
    # the story's slice begins wherever the flat list currently ends
    start = len(ngrams_list_total)
    ngrams_list_total.extend(story_ngrams)
    # one ngram per word, so slices line up with wordseqs[story].data
    assert len(story_ngrams) == len(wordseqs[story].data)
    wordseq_idxs[story] = (start, len(ngrams_list_total))
print(f'{len(ngrams_list_total)=} ngrams')

metadata_path = os.path.join(
    neuro.config.root_dir, 'qa/cache_gpt/ngrams_metadata.joblib')
metadata = {'ngrams_list_total': ngrams_list_total,
            'wordseq_idxs': wordseq_idxs}
joblib.dump(metadata, metadata_path)


# Yes/no questions posed to GPT-4 about each ngram.
questions = [
    'Does the sentence describe a personal reflection or thought?',
    'Does the sentence contain a proper noun?',
    'Does the sentence describe a physical action?',
    'Does the sentence describe a personal or social interaction that leads to a change or revelation?',
    'Does the sentence involve the mention of a specific object or item?',  # completed
    'Does the sentence involve a description of physical environment or setting?',
    'Does the sentence describe a relationship between people?',
    'Does the sentence mention a specific location?',
    'Is time mentioned in the input?',  # completed
    'Is the sentence abstract rather than concrete?',
    "Does the sentence express the narrator's opinion or judgment about an event or character?",
    'Is the input related to a specific industry or profession?',
    'Does the sentence include dialogue?',
    'Does the sentence describe a visual experience or scene?',
    'Does the input involve planning or organizing?',
    'Does the sentence involve spatial reasoning?',
    'Does the sentence involve an expression of personal values or beliefs?',
    'Does the sentence contain a negation?',
    'Does the sentence describe a sensory experience?',
    'Does the sentence include technical or specialized terminology?',
    'Does the input contain a number?',
    'Does the sentence contain a cultural reference?',
    'Does the text describe a mode of communication?',
    'Does the input include a comparison or metaphor?',
    'Does the sentence express a sense of belonging or connection to a place or community?',
    'Does the sentence describe a specific sensation or feeling?',
    'Does the text include a planning or decision-making process?',
    'Does the sentence include a personal anecdote or story?',
    'Does the sentence involve a discussion about personal or social values?',
    'Does the text describe a journey?',
    'Does the input contain a measurement?',
    'Does the sentence describe a physical sensation?',
    'Does the sentence include a direct speech quotation?',
    'Is the sentence reflective, involving self-analysis or introspection?',
    'Does the input describe a specific texture or sensation?',
]

loaded 103 stories
len(ngrams_list_total)=195190 ngrams


In [None]:
# Example story to inspect, e.g. via get_gpt4_qa_embs(questions, story_name).
story_name = "wheretheressmoke"

In [None]:
def get_gpt4_qa_embs(questions, story_name):
    """Build a (n_words, n_questions) matrix of cached GPT-4 answers for one story.

    Answers are read from ``<root_dir>/qa/cache_gpt/<question>.pkl``, the same
    cache layout the answer-loading cell in this notebook uses. Each file is
    assumed to hold one answer per ngram, positionally aligned with
    ``ngrams_list_total`` -- TODO confirm against the script that writes them.

    Relies on the module-level ``wordseqs``, ``wordseq_idxs`` and
    ``ngrams_list_total`` built earlier in the notebook.

    Args:
        questions: list of question strings; column ``i`` of the result holds
            the answers to ``questions[i]``.
        story_name: key into ``wordseqs`` / ``wordseq_idxs``.

    Returns:
        np.ndarray of shape ``(len(wordseqs[story_name].data), len(questions))``.
        Columns for questions that have no cache file are left as zeros.

    Raises:
        AssertionError: if a cache file's length does not match
            ``ngrams_list_total``.
    """
    cache_dir = os.path.join(neuro.config.root_dir, 'qa/cache_gpt')
    start, end = wordseq_idxs[story_name]
    embs = np.zeros((len(wordseqs[story_name].data), len(questions)))
    for idx, question in enumerate(questions):
        # Load only the requested questions. Scanning the directory does not work:
        # it also holds non-answer files (ngrams_metadata.joblib), and answer
        # filenames carry a '.pkl' suffix, so they never match `questions`.
        out_file = os.path.join(cache_dir, f'{question}.pkl')
        if not os.path.exists(out_file):
            continue  # not answered yet: leave this column as zeros
        answers = np.asarray(joblib.load(out_file))
        assert len(answers) == len(ngrams_list_total), (
            f'{out_file} has {len(answers)} answers, '
            f'expected {len(ngrams_list_total)}')
        embs[:, idx] = answers[start:end]
    return embs

# Compare questions

In [None]:
# Load cached GPT-4 answers for the first questions into one DataFrame:
# one row per ngram (index = ngrams_list_total) and one column per question.
# Uses the same root_dir-relative cache directory as the rest of the notebook
# instead of a machine-specific absolute path.
N_QUESTIONS_CACHED = 17  # number of leading `questions` whose answers are cached
cache_dir = os.path.join(neuro.config.root_dir, 'qa/cache_gpt')
answers_dict = {}
for question in tqdm(questions[:N_QUESTIONS_CACHED]):
    out_file = os.path.join(cache_dir, f'{question}.pkl')
    answers_dict[question] = joblib.load(out_file)
out = pd.DataFrame(answers_dict, index=ngrams_list_total)


def abbrev_question(q):
    """Shorten a question for plot labels by replacing common stems with '...'.

    Every occurrence of each stem is replaced, not only a leading one. Other
    stems such as 'Does the text' are left unchanged.
    """
    stems = ('Does the sentence', 'Is the sentence',
             'Does the input', 'Is the input')
    shortened = q
    for stem in stems:
        shortened = shortened.replace(stem, '...')
    return shortened

In [None]:
# Pairwise (Pearson) correlation between questions' answers across all ngrams,
# shown as a clustered heatmap with abbreviated question labels.
out_abbrev = out.rename(columns=abbrev_question)
corrs = out_abbrev.corr()

# The diagonal is trivially 1. Mask it in the plot, and take the symmetric color
# limit from the absolute off-diagonal values only. Otherwise the limit is
# always 1 and strong negative correlations are ignored.
diag_mask = np.eye(len(corrs), dtype=bool)
off_diag = np.abs(corrs.values[~diag_mask])
off_diag = off_diag[~np.isnan(off_diag)]
vabs = off_diag.max() if off_diag.size else 1.0

# NOTE(review): cbar_pos is (left, bottom, width, height) in figure coords;
# (0.5, 0.3) may overlap the heatmap -- adjust if the colorbar is obscured.
cg = sns.clustermap(corrs, mask=diag_mask, center=0, cmap='RdBu_r',
                    vmin=-vabs, vmax=vabs, cbar_pos=(0.5, 0.3, 0.03, 0.2))
cg.ax_cbar.set_ylabel('Correlation')
plt.show()

In [None]:
# Peek at the answer table (one row per ngram, one column per question)
# instead of rendering all ~195k rows.
out.head()