In [None]:
# Automatically reload edited modules (e.g. jsputils) before each cell executes,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from jsputils import paths, classes, nsdorg, plotting, encoding, nnutils, selectivity, eiganalysis, gpu_encoding
import os
from os.path import exists
import cortex
import numpy as np
import matplotlib.pyplot as plt
from fastprogress import progress_bar
import scipy.stats as stats
from scipy.spatial.distance import pdist, squareform
import copy
import time
import gc
from IPython.core.debugger import set_trace
import pandas as pd
import seaborn as sns
from sklearn.decomposition import PCA
from PIL import Image

import sys
sys.path.append('/home/jovyan/work/DropboxSandbox/GSN')
import gsn
from gsn.rsa_noise_ceiling import rsa_noise_ceiling

import torch
import torchvision
import torchlens as tl

from torchvision.transforms._presets import ImageClassification

from sklearn.linear_model import Lasso, LinearRegression

In [None]:
# ---- Analysis configuration -------------------------------------------------
nsddir = paths.nsd()

# Encoding outputs are written under the current working directory.
savedir = f'{os.getcwd()}/analysis_outputs/3-Encoding/'

# NSD subjects subj01 .. subj08.
subjs = [f'subj0{s}' for s in range(1,9)]
# Category-selective ROIs to model. Each label must be unique: results are keyed
# by ROI name, so a duplicate would silently recompute and overwrite an entry.
roi_list = ['FFA-1','FFA-2','OFA','PPA','OPA','EBA','FBA-1','FBA-2','VWFA-1','VWFA-2','OWFA']

# Image sets used to fit (train), tune (val) and evaluate (test) the encoders.
train_imageset = 'nonshared1000-3rep-batch0'
val_imageset = 'nonshared1000-3rep-batch1'
test_imageset = 'special515'

space = 'nativesurface'
beta_version = 'betas_fithrf_GLMdenoise_RR'
# Passed to ROI.get_ncsnr_mask(threshold=...) when selecting voxels.
ncsnr_threshold = 0.3

model_name = 'alexnet-barlow-twins'
floc_imageset_name = 'vpnl-floc'

# Candidate layers (input -> output order) for each supported model family.
if 'alexnet-barlow-twins' in model_name:

    layer_list = ['conv3',
                  'groupnorm3',
                  'relu3',
                  'conv4',
                  'groupnorm4',
                  'relu4',
                  'conv5',
                  'groupnorm5',
                  'relu5',
                  'maxpool5',
                  'fc6',
                  'batchnorm6',
                  'relu6',
                  'fc7',
                  'batchnorm7',
                  'relu7',
                  'fc8',
                  'batchnorm8']

elif 'alexnet-vggface' in model_name:

    layer_list = ['conv3',
                  'relu3',
                  'conv4',
                  'relu4',
                  'conv5',
                  'relu5',
                  'maxpool5',
                  'fc6',
                  'relu6',
                  'fc7',
                  'relu7',
                  'fc8']

else:
    # Fail fast here instead of with a NameError on `layer_list` in a later cell.
    raise ValueError(f'No layer_list defined for model_name={model_name!r}')

# Default selectivity domains to encode (may be overridden per ROI/model later).
domain_list = ['faces','scenes','bodies','characters','objects']

In [None]:
# Wrap the chosen network. For every model except the vggface variant, also
# build the fLoc image set with the model's transforms and locate selective
# units on it (FDR_p=0.05; overwrite=False presumably reuses saved results —
# confirm in classes.DNNModel.find_selective_units).
DNN = classes.DNNModel(model_name)

if 'vggface' not in model_name:
    floc = classes.ImageSet(
        floc_imageset_name,
        transforms=DNN.transforms,
    )
    DNN.find_selective_units(
        floc_imageset_name,
        overwrite=False,
        verbose=False,
        FDR_p=0.05,
    )

In [None]:
def domains_for_roi(roi, model_name, default_domains):
    """Return the list of domains to pass to `encode_layers` for one ROI.

    Parameters
    ----------
    roi : str
        ROI label, e.g. 'FFA-1' or 'VWFA-2'.
    model_name : str
        Name of the DNN being evaluated.
    default_domains : list of str
        Domains to use when no model-specific rule applies.

    Returns
    -------
    list of str
        ['layer'] for vggface models. For 'random' models, a single domain
        chosen from the ROI name (FFA/OFA -> faces, OPA/PPA -> scenes,
        EBA/FBA -> bodies, VWFA/OWFA -> characters). Otherwise a copy of
        `default_domains`.
    """
    if 'vggface' in model_name:
        return ['layer']
    if 'random' in model_name:
        # Checked in order; the first matching tag wins.
        roi_domain_rules = [
            (('FFA', 'OFA'), 'faces'),
            (('OPA', 'PPA'), 'scenes'),
            (('EBA', 'FBA'), 'bodies'),
            (('VWFA', 'OWFA'), 'characters'),
        ]
        for roi_tags, domain in roi_domain_rules:
            if any(tag in roi for tag in roi_tags):
                return [domain]
    # No model-specific rule matched: use the configured defaults instead of
    # whatever a previously processed ROI happened to leave behind.
    return list(default_domains)


results = {}

for roi in roi_list:

    results[roi] = {}
    # Depends only on ROI and model, so compute once per ROI. Kept in a local
    # name so the configured `domain_list` is never overwritten.
    roi_domains = domains_for_roi(roi, model_name, domain_list)

    for subj in progress_bar(subjs):
        NSDsubj = classes.fMRISubject(subj, space, beta_version)
        ROI = classes.BrainRegion(NSDsubj, roi)

        ROI.load_betas()
        ROI.get_ncsnr_mask(threshold = ncsnr_threshold)
        ROI.load_encoding_data(train_imageset, val_imageset, test_imageset)

        encoder = classes.EncodingProcedure(ROI, DNN,
                                            method = 'lasso',
                                            positive = True,
                                            alphas = [0.1])  # 0.001 for untrained. 0.1 for trained

        # Layers are passed in reverse of the order listed in the config cell.
        encoder.encode_layers(savedir,
                              layers = np.flip(layer_list),
                              domains = roi_domains,
                              overwrite = False)

        results[roi][subj] = encoder.results_df

In [None]:
# Inspect the encoding results for the last ROI / subject processed above.
# (`results_` was undefined; results are stored as results[roi][subj].)
results[roi_list[-1]][subjs[-1]]