In [None]:
%load_ext autoreload
%autoreload 2
import sys
from ridge_utils.DataSequence import DataSequence
import pandas as pd
import os
from os.path import join
from collections import defaultdict
import numpy as np
import joblib
from tqdm import tqdm
from sasc.config import FMRI_DIR, STORIES_DIR, RESULTS_DIR
from neuro.config import brain_drive_resps_dir
from ridge_utils import DataSequence
import ridge_utils.DataSequence import DataSequence
import DataSequence

In [None]:
import ridge_utils.DataSequence  # binds the `ridge_utils` name used below

wordseqs_dir = join(RESULTS_DIR, 'processed', 'wordseqs')

# Sort the listing: os.listdir order is arbitrary, and when two files contain
# the same story key the later file wins in the merge below, so a fixed order
# keeps the result deterministic across machines and runs.
wordseqs = [
    joblib.load(join(wordseqs_dir, fname))
    for fname in sorted(os.listdir(wordseqs_dir))
]

# Merge the list of {story_name: sequence} dicts into a single dict, re-wrapping
# each entry in the currently imported DataSequence class.
# NOTE(review): presumably this is so pickles written with an older version of
# the class pick up the current (autoreloaded) methods — confirm.
wordseqs_dict = {}
for wordseq_file_dict in wordseqs:
    for story_name, seq in wordseq_file_dict.items():
        wordseqs_dict[story_name] = ridge_utils.DataSequence.DataSequence(
            data=seq.data,
            split_inds=seq.split_inds,
            data_times=seq.data_times,
            tr_times=seq.tr_times,
        )

In [None]:
# Show which stories are available instead of dumping every DataSequence repr
sorted(wordseqs_dict)

In [None]:
# Response files for the generated stories, grouped by fMRI session:
#   RESPS_DICT[session][story_name] = response file name
# - session: subdirectory of brain_drive_resps_dir that holds the responses
#   (keys look like YYYYMMDD scan dates — presumably; confirm).
# - story_name: path under STORIES_DIR containing the story's timing CSVs; its
#   second path component starts with the subject id (e.g. 'uts02').
# - file name: any '_resps' suffix is stripped before loading, and the digits in
#   the name select the 'GenStory<N>' entry of wordseqs_dict.
# Story numbers 11 and 18 do not appear in this mapping.
RESPS_DICT = {
    '20230504': {
        "default/uts02_pilot_gpt4_mar28___ver=v4_noun___seed=1": "GenStory1_resps.npy",
        "default/uts02_pilot_gpt4_mar28___ver=v4_noun___seed=3": "GenStory2_resps.npy",
        "default/uts02_pilot_gpt4_mar28___ver=v4_noun___seed=4": "GenStory3_resps.npy",
        "default/uts02_pilot_gpt4_mar28___ver=v5_noun___seed=1": "GenStory4_resps.npy",
        "default/uts02_pilot_gpt4_mar28___ver=v5_noun___seed=2": "GenStory5_resps.npy",
        "default/uts02_pilot_gpt4_mar28___ver=v5_noun___seed=4": "GenStory6_resps.npy",
    },
    '20230702': {
        "interactions/uts02___jun14___seed=1": "GenStory7_resps.npy",
        "interactions/uts02___jun14___seed=4": "GenStory8_resps.npy",
        "polysemantic/uts02___jun14___seed=6": "GenStory9_resps.npy",
        "polysemantic/uts02___jun14___seed=1": "GenStory10_resps.npy",
    },
    '20231106': {
        'default/uts03___jun14___seed=5': 'GenStory12_resps.npy',
        'default/uts03___jun14___seed=1': 'GenStory13_resps.npy',
        'interactions/uts03___jun14___seed=5': 'GenStory14_resps.npy',
        'interactions/uts03___jun14___seed=6': 'GenStory15_resps.npy',
        'polysemantic/uts03___jun14___seed=3': 'GenStory16_resps.npy',
        'polysemantic/uts03___jun14___seed=7': 'GenStory17_resps.npy',
    },
    # Timings for this session are read from 'timings_processed_slowed.csv'
    # rather than 'timings_processed.csv'.
    '20240509': {
        'default/uts01___may9___seed=5_top1': 'deeptune-story19.npy',
        'default/uts01___may9___seed=2_top2': 'deeptune-story20.npy',
        'interactions/uts01___may9___seed=3_top1': 'deeptune-story21.npy',
        'interactions/uts01___may9___seed=6_top2': 'deeptune-story22.npy',
    },

    # Mixes file names with and without the '_resps' suffix.
    '20240604': {
        'roi/uts02___roi_may31___seed=5_best1': 'GenStory23.npy',
        'roi/uts02___roi_may31___seed=2_best2': 'GenStory24.npy',
        'roi/uts02___roi_may31___seed=7_best3': 'GenStory25.npy',
        'roi/uts02___roi_may31___seed=6_best4': 'GenStory26.npy',

        'qa/uts02___qa_may31___seed=1': 'GenStory27_resps.npy',
        'qa/uts02___qa_may31___seed=2': 'GenStory28_resps.npy',
        'qa/uts02___qa_may31___seed=3': 'GenStory29_resps.npy',
    },
}


def build_wordseq(timings, tr_time=2, start_time=-9):
    """Build a DataSequence of words binned into TRs from a word-timing table.

    Parameters
    ----------
    timings : pd.DataFrame
        Must contain the columns ``'word'`` (the word), ``'timing'`` (the word's
        length in time) and ``'time_running'`` (the time at which the word
        ends), in the same time unit as ``tr_time``.
    tr_time : float, default 2
        Spacing between consecutive TR times.
    start_time : float, default -9
        Time of the first TR.

    Returns
    -------
    DataSequence
        ``DataSequence(words, split_inds, word_avgtimes, tr_times)``, where
        ``word_avgtimes`` are the word midpoints and ``split_inds[i]`` is the
        number of words whose midpoint is before ``tr_times[i] + tr_time``
        (one entry fewer than ``tr_times``).

    Raises
    ------
    ValueError
        If ``timings`` has no rows.
    """
    if len(timings) == 0:
        raise ValueError('build_wordseq: `timings` has no rows')

    # Imported locally so the class is used even if the module-level name
    # `DataSequence` was rebound to a module by another import.
    from ridge_utils.DataSequence import DataSequence as _DataSequence

    words = timings['word'].values
    word_lengths = timings['timing'].values
    end_times = timings['time_running'].values

    # Midpoint of each word (halfway through it)
    word_avgtimes = end_times - (word_lengths / 2.0)

    # TR times from start_time up to (and just past) the last word midpoint
    tr_times = np.arange(start_time, word_avgtimes[-1] + tr_time, tr_time)

    # Cumulative count of words ending before each TR window closes; the count
    # for the final TR is dropped.
    split_inds = [(word_avgtimes < (t + tr_time)).sum() for t in tr_times][:-1]

    return _DataSequence(words, split_inds, word_avgtimes, tr_times)

In [None]:
def get_num_from_string(s):
    """Return the first run of decimal digits in ``s`` as an int.

    Examples: ``'GenStory12_resps.npy' -> 12``, ``'deeptune-story19.npy' -> 19``.
    Only the first digit run is used, so ``'GenStory12_v2.npy' -> 12`` rather
    than the concatenation ``122``.

    Raises
    ------
    ValueError
        If ``s`` contains no digits.
    """
    import re  # local import keeps this cell self-contained

    match = re.search(r'\d+', s)
    if match is None:
        raise ValueError(f'no digits found in {s!r}')
    return int(match.group())


import copy  # local import: used to avoid mutating entries of wordseqs_dict

# One metadata row per story: response file, TR counts from three sources
# (response array, timings CSV, word sequence) and the word sequence itself.
dset = defaultdict(list)
for session, story_to_resp in tqdm(RESPS_DICT.items()):
    for story_name, resp_name in story_to_resp.items():
        # Response files are loaded without the '_resps' suffix
        resp_file = resp_name.replace('_resps', '')

        # Only the number of TRs is needed: mmap_mode='r' reads the .npy header
        # instead of loading the whole response matrix (this load used to
        # dominate the cell's runtime).
        resp = np.load(join(brain_drive_resps_dir, session, resp_file), mmap_mode='r')
        trs_resp = resp.shape[0]

        timings_fname = (
            'timings_processed_slowed.csv' if session == '20240509'
            else 'timings_processed.csv'
        )
        timings = pd.read_csv(join(STORIES_DIR, story_name, timings_fname))

        # Shallow copy: trimming tr_times below then rebinds the attribute on
        # the copy only, so wordseqs_dict is left untouched and re-running this
        # cell gives the same result.
        wordseq = copy.copy(wordseqs_dict[f'GenStory{get_num_from_string(resp_name)}'])
        # NOTE(review): a surplus of exactly 16 TRs is treated as one extra
        # trailing TR; why 16 specifically is not documented here — confirm.
        if len(wordseq.tr_times) - trs_resp == 16:
            wordseq.tr_times = wordseq.tr_times[:-1]
        trs_wordseq = len(wordseq.tr_times)

        dset['session'].append(session)
        dset['story_name'].append(story_name)
        dset['resp_file'].append(resp_file)
        dset['trs_resp'].append(trs_resp)
        # assumes time_running is in seconds with a 2 s TR — TODO confirm
        dset['trs_timings'].append(int(max(timings['time_running']) // 2))
        dset['trs_wordseq'].append(trs_wordseq)
        dset['trs_diff'].append(trs_wordseq - trs_resp)
        # e.g. 'default/uts02___jun14___seed=1' -> 'UTS02'
        dset['subject'].append(story_name.split('/')[1].split('_')[0].upper())
        dset['wordseq'].append(wordseq)

df = pd.DataFrame(dset)
# 'GenStory12.npy' / 'deeptune-story19.npy' -> 'Story12' / 'Story19'.
# regex=False: patterns are literal ('.' in '.npy' must not act as a wildcard).
df['story_id'] = 'Story' + (
    df['resp_file']
    .str.replace('.npy', '', regex=False)
    .str.replace('GenStory', '', regex=False)
    .str.replace('deeptune-story', '', regex=False)
)
assert df['story_id'].is_unique, 'story ids must be unique to serve as the index'
df = df.set_index('story_id')

In [None]:
# Hide the object column of DataSequence instances so the table stays readable
df.drop(columns='wordseq')

In [36]:
# Persist the metadata frame (including word sequences) next to the responses
metadata_path = join(brain_drive_resps_dir, 'metadata.pkl')
joblib.dump(df, metadata_path)

['/home/chansingh/mntv1/deep-fMRI/brain_tune/story_data/metadata.pkl']

In [None]:
# story_id is the index after set_index('story_id'), not a column
df[df['subject'] == 'UTS03'].index.values

In [None]:
# Index labels are 'Story<N>' (derived from the response file names), not 'GenStory<N>'
df.loc[['Story5', 'Story7'], 'wordseq'].to_dict()