In [1]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
import seaborn as sns
from os.path import join
from tqdm import tqdm
import pandas as pd
import joblib
import imodelsx.process_results
import sys
import numpy as np
# Script and output locations used throughout this notebook.
experiment_filename = '../experiments/04_fit_single_question.py'
results_dir = '/home/chansingh/mntv1/deep-fMRI/encoding/single_question_may6'
# Make modules under ../experiments/ importable from the notebook.
sys.path += ['../experiments/']

In [2]:
# Deserialize ../data/story_data/grids_huge.jbl into `grids`.
grids = joblib.load(
    os.path.join('../data/', "story_data", "grids_huge.jbl")
)

In [5]:
# Deserialize ../data/story_data/trfiles_huge.jbl into `trfiles`.
trfiles = joblib.load(os.path.join('../data/', "story_data", "trfiles_huge.jbl"))

In [None]:
# Show how many keys the loaded `grids` object has
# (presumably one per story -- TODO confirm against the data file).
len(grids.keys())

In [None]:
import os
import joblib
import torch
from torch.utils.data import Dataset
import h5py
from ridge_utils.dsutils import make_word_ds, make_phoneme_ds


class HuthLabDataset(Dataset):
    """Per-story word chunks paired with rows of the stored response data.

    For every story listed in ``<data_dir>/<subject>/storylist.jbl`` the
    constructor opens ``<data_dir>/<subject>/<story>.hf5`` and keeps the
    trimmed list returned by ``make_word_ds(grids, trfiles)[story].chunks()``.

    Args:
        data_dir: Root directory holding ``story_data/`` and one
            sub-directory per subject.
        subject: Name of the subject sub-directory.
        num_trs: Stored as ``self.lookback``; not otherwise used by this class.
        trim_start: Number of leading chunks / TR times dropped per story.
        trim_end: Number of trailing chunks / TR times dropped per story.
            ``0`` keeps everything up to the end.

    Raises:
        AssertionError: If, for some story, the number of trimmed TR times
            differs from the first dimension of the file's ``'data'`` entry.

    Note:
        The HDF5 files stay open after construction because ``__getitem__``
        reads from them; call :meth:`close` once the dataset is no longer
        needed.
    """

    def __init__(self, data_dir: str, subject: str, num_trs: int, trim_start: int = 5, trim_end: int = 10):
        self.subject = subject
        self.stories = joblib.load(os.path.join(
            data_dir, subject, "storylist.jbl"))
        self.grids = joblib.load(os.path.join(
            data_dir, "story_data", "grids.jbl"))
        self.trfiles = joblib.load(os.path.join(
            data_dir, "story_data", "trfiles.jbl"))
        self.wordseqs = make_word_ds(self.grids, self.trfiles)
        self.trim_start = trim_start
        self.trim_end = trim_end
        self.lookback = num_trs  # num_trs = num_seconds / 2
        self.resp_dict = {}
        self.chunk_dict = {}
        try:
            for story in self.stories:
                hf5_path = os.path.join(data_dir, subject, story + ".hf5")
                self.resp_dict[story] = h5py.File(hf5_path, 'r')
                self.chunk_dict[story] = self._trim(
                    self.wordseqs[story].chunks())
                # Confirm trimming dimensions match
                num_trs_stim = len(self._trim(self.wordseqs[story].tr_times))
                num_trs_resp = self.resp_dict[story]['data'].shape[0]
                assert num_trs_stim == num_trs_resp, (
                    f"story {story!r}: {num_trs_stim} trimmed stimulus TRs "
                    f"but {num_trs_resp} response rows")
        except Exception:
            # Do not leave already-opened HDF5 handles dangling when
            # construction fails part-way through the story list.
            self.close()
            raise

    def _trim(self, seq):
        """Drop ``trim_start`` leading and ``trim_end`` trailing items of *seq*."""
        # ``seq[a:-0]`` is ``seq[a:0]`` (always empty), so a zero tail trim
        # must use an open-ended slice instead of a negated stop index.
        stop = -self.trim_end if self.trim_end else None
        return seq[self.trim_start:stop]

    def close(self):
        """Close every HDF5 file opened by the constructor."""
        for handle in self.resp_dict.values():
            handle.close()
        self.resp_dict.clear()

    def __len__(self):
        """Return the number of stories (not the number of TRs)."""
        return len(self.stories)

    def __getitem__(self, story, idx=None, delays=0):
        """Return ``(stimulus, response)`` for position ``idx`` of ``story``.

        May be called directly as ``ds.__getitem__(story, idx, delays)`` or
        through subscripting with a tuple key: ``ds[story, idx]`` or
        ``ds[story, idx, delays]``.

        Args:
            story: Story name, or a ``(story, idx[, delays])`` tuple.
            idx: Index into the story's trimmed chunk list and into the
                ``'data'`` entry of its response file.
            delays: Number of preceding chunks to include as context.

        Returns:
            ``(chunk, response)`` when ``delays == 0``. Otherwise
            ``(chunks, response)`` where ``chunks`` is a list of
            ``delays + 1`` chunks, oldest first and ending at ``idx``;
            positions before the start of the story are filled with an
            empty ``'<U13'`` array.

        Raises:
            TypeError: If no ``idx`` is supplied or a tuple key has the
                wrong number of items.
            AssertionError: If ``delays`` is negative.
        """
        if isinstance(story, tuple):
            # Subscript syntax delivers all indices packed into one tuple.
            if not 2 <= len(story) <= 3:
                raise TypeError(
                    "expected a (story, idx) or (story, idx, delays) key, "
                    f"got {len(story)} items")
            story, idx, *rest = story
            if rest:
                delays = rest[0]
        if idx is None:
            raise TypeError("__getitem__ requires both a story and an idx")
        assert delays >= 0, "delays must be non-negative"

        chunks = self.chunk_dict[story]
        if delays == 0:
            return (chunks[idx], self.resp_dict[story]['data'][idx])

        # Window covers idx - delays .. idx inclusive; anything that would
        # fall before the first chunk is padded with an empty array.
        context = [
            chunks[pos] if pos >= 0 else np.array([], dtype='<U13')
            for pos in range(idx - delays, idx + 1)
        ]
        return (context, self.resp_dict[story]['data'][idx])