In [1]:
import os
import json
import argparse
import numpy as np
import config, consts, paths
from decoding.GPT import GPT
from decoding.utils_stim import get_stim
from decoding.utils_resp import get_resp
from decoding.StimulusModel import LMFeatures
from encoding.ridge import ridge, bootstrap_ridge
from utils import flatten_list, save_data

  _C._set_default_tensor_type(t)


In [7]:
# Subject identifier; the cells below use it to look up the voxel count,
# load responses, and build the output path.
subject = "UTS03"

In [8]:
# load gpt
# `stories` is the flattened list of story names from consts.STORIES; it is
# reused by every later cell.
stories = flatten_list(consts.STORIES)

# Vocabulary and model weights are both read from the "perceived"
# sub-directory of config.DATA_LM_DIR.
lm_dir = os.path.join(config.DATA_LM_DIR, "perceived")

# JSON is UTF-8 by specification; pin the encoding so the vocabulary is decoded
# the same way regardless of the platform's locale default.
with open(os.path.join(lm_dir, "vocab.json"), "r", encoding = "utf-8") as f:
    gpt_vocab = json.load(f)

gpt = GPT(path = os.path.join(lm_dir, "model"), vocab = gpt_vocab)

# Stimulus feature extractor built on the language model, configured with the
# layer and context-word settings from config.
features = LMFeatures(model = gpt, layer = config.GPT_LAYER, context_words = config.GPT_WORDS)


In [9]:
# estimate encoding model
# The voxels are divided into two index groups and bootstrapped ridge
# regression is run on each group in turn (presumably to bound peak memory --
# TODO confirm); per-group results are written into preallocated arrays.
num_voxels = consts.NUM_VOXELS[subject]
rstim, tr_stats, word_stats = get_stim(stories, "story", features)

num_features = rstim.shape[1]
weights = np.zeros([num_features, num_voxels])
alphas = np.zeros(num_voxels)
bscorrs = np.zeros([len(config.ALPHAS), num_voxels, config.NBOOTS])

splits = np.array_split(range(num_voxels), 2)
for split in splits:
    # Responses for this group of voxels only, stacked across stories.
    rresp = get_resp(subject, stories, "story", voxels = split, stack = True)
    # The same seed is passed for every group.
    split_weights, split_alphas, split_bscorrs = bootstrap_ridge(
        rstim, rresp,
        alphas = config.ALPHAS,
        nboots = config.NBOOTS,
        chunklen = config.CHUNKLEN,
        use_corr = False,
        seed = 42,
    )
    weights[:, split] = split_weights
    alphas[split] = split_alphas
    bscorrs[:, split, :] = split_bscorrs
    # Release per-group arrays before loading the next group.
    del rresp, split_weights, split_alphas, split_bscorrs

# Collapse (alpha, voxel, bootstrap) scores to one value per voxel:
# average over bootstraps, then take the best value across alphas.
bscorrs = bscorrs.mean(axis = 2).max(axis = 0)

# Indices of the config.VOXELS highest-scoring voxels, in ascending index order.
top_voxels = np.argsort(bscorrs)[-config.VOXELS:]
voxels = np.sort(top_voxels)

In [10]:
# estimate noise model
# Leave-one-story-out: for each held-out story, fit ridge weights on all other
# stories (using the per-voxel alphas chosen above), compute residuals on the
# held-out story, and accumulate the voxel-by-voxel residual cross-product.
stim_dict = {
    story : get_stim([story], "story", features, tr_stats = tr_stats)
    for story in stories
}
resp_dict = get_resp(subject, stories, "story", voxels = voxels, stack = False)

num_selected = len(voxels)
noise_model = np.zeros([num_selected, num_selected])
for hstory in stories:
    # Every story except the held-out one forms the training set.
    train_stories = [tstory for tstory in stories if tstory != hstory]
    tstim = np.vstack([stim_dict[tstory] for tstory in train_stories])
    tresp = np.vstack([resp_dict[tstory] for tstory in train_stories])
    hstim = stim_dict[hstory]
    hresp = resp_dict[hstory]

    bs_weights = ridge(tstim, tresp, alphas[voxels])
    resids = hresp - hstim.dot(bs_weights)
    bs_noise_model = resids.T.dot(resids)
    # Scale each fold by its mean diagonal, then average over stories.
    noise_model += bs_noise_model / np.diag(bs_noise_model).mean() / len(stories)

# The per-story dictionaries are no longer needed once the loop finishes.
del stim_dict, resp_dict

In [11]:
# Bundle everything produced by the fitting cells and write it to the
# subject-specific location given by the paths.EM template.
save_location = paths.EM % subject
em = {
    "bscorrs": bscorrs,
    "voxels": voxels,
    "tr_stats": tr_stats,
    "word_stats": word_stats,
    "stories": stories,
    # Only the weights of the selected voxels are kept.
    "weights": weights[:, voxels],
    "noise_model": noise_model,
}
save_data(save_location, em)