In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('..')
import seaborn as sns
import os
import pandas as pd
from copy import deepcopy
from matplotlib import pyplot as plt
from os.path import join
import numpy as np
import imodelsx.process_results
import qa_questions
import random
import json
import joblib
from tqdm import tqdm
from collections import defaultdict
import feature_spaces
fit_encoding = __import__('01_fit_encoding')
import encoding_utils

### Get data

In [None]:
class A:
    """Minimal stand-in for the parsed CLI args passed to `fit_encoding.get_story_names`."""
    subject = 'UTS03'
    # NOTE(review): presumably -1 means "do not limit the number of stories" — confirm in 01_fit_encoding
    num_stories = -1
    use_test_setup = False


args = A()
story_names_train, story_names_test = fit_encoding.get_story_names(args)
# word sequences are only needed for the training stories (used by the cells below)
wordseqs = feature_spaces.get_story_wordseqs(story_names_train)

### Load the model to boost

In [None]:
# NOTE(review): `analyze_helper` is not imported anywhere in this notebook, so this
# line raises NameError on a fresh kernel unless it is imported in another cell —
# presumably it lives in the parent directory added to sys.path; TODO add the import.
results_dir = analyze_helper.best_results_dir
r = imodelsx.process_results.get_results_df(results_dir)

# Remap saved paths that are not already under /home onto the /home/chansingh/mntv1 mount.
for k in ['save_dir', 'save_dir_unique']:
    r[k] = r[k].map(lambda x: x if x.startswith('/home')
                    else x.replace('/mntv1', '/home/chansingh/mntv1'))

# Select the qa_embedder run to boost (this is the second boosting round).
# Earlier/alternative settings: the first boost used ndelays == 4 with
# qa_questions_version == 'v2'; 'v4' and 'v5' were other question versions tried.
# na=False keeps the mask strictly boolean if feature_space has missing values.
mask_boost = (
    r.feature_space.str.contains('qa_embedder', na=False)
    & (r.pc_components == 100)
    & (r.ndelays == 8)
    & (r.qa_questions_version == 'v3_boostexamples')
)
assert mask_boost.any(), 'no results match the boosting filter'

# best run by mean tuning correlation
args_top = r[mask_boost].sort_values(
    by='corrs_tune_pc_mean',
    ascending=False).iloc[0]
print(f'{args_top.feature_space=} {args_top.ndelays=}')
print(f'{args_top.corrs_test_mean=:.3f} {args_top.corrs_tune_pc_mean=:.3f}')

model_params_to_save = joblib.load(
    join(args_top.save_dir_unique, 'model_params.pkl'))

### Boost based on errors or on deviation from the LLaMA model
- If `use_distill=True`, boost based on the LLaMA model's predictions
- If `use_distill=False`, boost based on voxel errors
  - Generate examples for boosted questions based on model errors (v4, v5, v6)
  - note: v4 wasn't actually boosted because the model we used was basically random
  - v5 settings were:
    - args_top.feature_space='qa_embedder-10' args_top.ndelays=4
    - args_top.corrs_test_mean=0.126 args_top.corrs_tune_pc_mean=0.134110

In [None]:
use_distill = True
if use_distill:
    # Previously fitted model whose predictions serve as the boosting target
    # (used in the error-analysis cell below instead of the measured responses).
    folder_id_distill = '68936a10a548e2b4ce895d14047ac49e7a56c3217e50365134f78f990036c5f7'
    results_dir = '/home/chansingh/mntv1/deep-fMRI/encoding/results_apr7'
    run_dir_distill = join(results_dir, folder_id_distill)

    args_distill = pd.Series(joblib.load(join(run_dir_distill, 'results.pkl')))
    summary_cols = ['feature_space', 'ndelays',
                    'corrs_test_mean', 'num_stories', 'subject']
    print(args_distill[summary_cols])

    model_params = joblib.load(join(run_dir_distill, 'model_params.pkl'))

In [None]:
# For each training story, find the TRs where the boosted model's predictions
# correlate worst with the target, and record the n-gram text preceding them.
num_worst_trs = 3  # TRs recorded per story
num_skip_trs = 10  # leading TRs (of the trimmed story) that are never selected

r = defaultdict(list)
for story_name in tqdm(story_names_train):
    # ngram for 3 trs preceding the current TR
    chunks = wordseqs[story_name].chunks()
    ngrams_list = feature_spaces._get_ngrams_list_from_chunks(
        chunks, num_trs=3)
    # NOTE(review): assumes this [10:-5] trim matches the TR trimming inside
    # fit_encoding.get_data, so row i of resp_target lines up with ngrams_list[i] — TODO confirm.
    ngrams_list = np.array(ngrams_list[10:-5])

    stim_train_delayed, resp_target = fit_encoding.get_data(
        args_top, [story_name])

    preds_test = stim_train_delayed @ model_params_to_save['weights'] + \
        model_params_to_save['bias']

    # compare to distilled predictions instead of actual response
    if use_distill:
        stim_train_delayed_distill, _ = fit_encoding.get_data(
            args_distill, [story_name])
        resp_target = stim_train_delayed_distill @ model_params['weights'] + \
            model_params['bias']

    # correlation across columns (presumably voxels) at each timepoint (row)
    corrs_time = np.array([np.corrcoef(resp_target[i, :], preds_test[i, :])[0, 1]
                           for i in range(resp_target.shape[0])])
    # sentinel above any valid correlation so these TRs sort last
    corrs_time[:num_skip_trs] = np.inf
    # np.argsort places NaN (e.g. from constant rows) after every number
    corrs_worst_idxs = np.argsort(corrs_time)[:num_worst_trs]

    # iterate over what was actually selected (robust to very short stories)
    for idx in corrs_worst_idxs:
        r['story_name'].append(story_name)
        r['corrs'].append(corrs_time[idx])
        r['ngram'].append(ngrams_list[idx])
        r['tr'].append(idx)

joblib.dump(r, '../questions/ngrams_boost_v4_llama.pkl')

In [None]:
ngrams_boost_list = pd.DataFrame(joblib.load(
    '../questions/ngrams_boost_v4_llama.pkl'))

# Remove any n-gram that is a substring of another distinct n-gram, plus n-grams
# that are empty or a single character after stripping whitespace.
# Checking against every distinct n-gram (not only those kept so far) makes the
# result independent of row order; first-occurrence order is preserved.
ngrams_unique = list(dict.fromkeys(ngrams_boost_list['ngram']))
ngrams_boost_list_clean = [
    ngram for ngram in ngrams_unique
    if len(ngram.strip()) > 1
    and not any(ngram != other and ngram in other for other in ngrams_unique)
]
print('lens', len(ngrams_boost_list), len(ngrams_boost_list_clean))

# show every other cleaned n-gram as a bulleted list
print('\n'.join(
    ['- ' + x for x in ngrams_boost_list_clean[1::2]]))

In [None]:
# Previous question set; close the file deterministically instead of leaking the handle.
with open('../questions/v3_boostexamples.json') as questions_file:
    questions_prev = json.load(questions_file)
# show every other previous question as a bulleted list
print('\n'.join(['- ' + x for x in questions_prev[1::2]]))

### Generate random examples for prompting new questions (v3)

In [None]:
# Sampling settings (seeds 42 and 43 were both used)
seed = 43  # 42, 43
ngram_size = 10
num_examples_per_story = 1
random.seed(seed)
np.random.seed(seed)

ngrams_examples = []
for story_name in story_names_train:
    words_list = wordseqs[story_name].data
    # drop the first ngram_size + 2 entries of this story's n-gram list
    ngrams_list = feature_spaces._get_ngrams_list_from_words_list(
        words_list, ngram_size=ngram_size)[ngram_size + 2:]
    # global numpy RNG, sampling with replacement (np.random.choice default)
    story_samples = np.random.choice(ngrams_list, num_examples_per_story)
    ngrams_examples.extend(story_samples.tolist())

print('\n'.join('- ' + ngram for ngram in ngrams_examples))