In [1]:
import numpy as np
import scipy.stats as stats
import nibabel.freesurfer.mghformat as mgh
import h5py

import sys

sys.path.append("..")

from paths import *
from constants import *

sys.path.append(CODE_PATH)

from utils.general_utils import make_iterable

In [2]:
# Show every float in numpy reprs in fixed-point notation ("{:f}") instead of
# numpy's default, which may switch to scientific notation.
np.set_printoptions(formatter=dict(float_kind="{:f}".format))

In [3]:
# SlowFast layer names in network order. blocks.0 is the stem (one entry per
# pathway). blocks.1-4 are residual stages: for each stage, all res_blocks of
# multipathway_blocks.0 (slow) come first, then those of
# multipathway_blocks.1 (fast). Each (stage, depth) pair gives the number of
# res_blocks in that stage.
model_layer_strings = (
    [f"blocks.0.multipathway_blocks.{pathway}" for pathway in (0, 1)]
    + [
        f"blocks.{stage}.multipathway_blocks.{pathway}.res_blocks.{block_idx}"
        for stage, depth in ((1, 3), (2, 4), (3, 6), (4, 3))
        for pathway in (0, 1)  # 0 = slow, 1 = fast
        for block_idx in range(depth)
    ]
    + ["blocks.5", "blocks.6.proj"]
)

In [4]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
from matplotlib import colors as mcolors

In [5]:
# --- Analysis configuration (read by the cells below) ---
# Which data: subject ids to process, hemisphere label ("rh" gets special
# handling for the first two layers in the merge cell), ROI set name.
subjid = ["01"]
hemi = "lh"
roi = "streams_shrink10"

# Which model fits: these values are spliced into the fit-file names.
model_name = "slowfast_full"
pretrained = 1  # appears as "<pretrained>pretraining" in file names
reduce_temporal_dims = 1  # appended to the model name for slowfast models
mapping_func = "PLS"
subsample = 2  # appears as "_subsample_<subsample>" in file names
CV = 0  # appears as "<CV>CV" in file names

# ROI labels; presumably the display names for the `roi` set — TODO confirm.
ROI_NAMES = [
    "Early", "Midventral", "Midlateral", "Midparietal",
    "Ventral", "Lateral", "Parietal",
]

In [6]:
# Zip layers and save: merge each subject's per-layer fit files
# (RESULTS_PATH/fits_by_layer/...) into a single HDF5 file per subject
# (RESULTS_PATH/fits/...). Every dataset key found in the per-layer files is
# copied across.
#
# Config-cell values are only read here, never reassigned. In particular, the
# per-layer temporal-reduction flag lives in `layer_reduce_dims`, so the
# configured `reduce_temporal_dims` survives re-runs of this cell.
is_slowfast = model_name in ("slowfast", "slowfast_full")
# The part of the file name after the (optional) layer name is shared by the
# per-layer inputs and the merged output.
fit_file_tail = (
    f"{mapping_func}_subsample_{subsample}_{CV}CV_{pretrained}pretraining_fits.hdf5"
)

# Flatten model_layer_strings, expanding any entry that is itself a list of
# layer names. Loop-invariant, so it is computed once for all subjects.
layer_keys = [
    name
    for entry in model_layer_strings
    for name in (entry if isinstance(entry, list) else [entry])
]

for sid in subjid:
    fit_file_head = f"{sid}_{hemi}_{roi}_{model_name}"

    # get model fits
    rsquared_array = {}
    for lidx, layer in enumerate(layer_keys):
        # For rh, the fit files of the first two layers are named with
        # temporal flag 0. Every other layer uses the configured flag.
        layer_reduce_dims = 0 if hemi == "rh" and lidx < 2 else reduce_temporal_dims
        temporal_tag = str(layer_reduce_dims) if is_slowfast else ""
        load_path = (
            f"{RESULTS_PATH}fits_by_layer/subj{fit_file_head}{temporal_tag}_"
            f"{layer}_{fit_file_tail}"
        )
        with h5py.File(load_path, "r") as f:
            for k in f.keys():
                # Later files overwrite earlier ones if they share a key.
                rsquared_array[k] = f[k][:]

    # save to local data folder
    save_temporal_tag = str(reduce_temporal_dims) if is_slowfast else ""
    save_path = (
        f"{RESULTS_PATH}fits/by_layer_subj{fit_file_head}{save_temporal_tag}_"
        f"{fit_file_tail}"
    )
    # `with` guarantees the output file is closed even if a write fails.
    with h5py.File(save_path, "w") as h5f:
        for k, v in rsquared_array.items():
            print(str(k))
            h5f.create_dataset(str(k), data=v)
    del rsquared_array

blocks.0.multipathway_blocks.0
blocks.0.multipathway_blocks.1
blocks.1.multipathway_blocks.0.res_blocks.0
blocks.1.multipathway_blocks.0.res_blocks.1
blocks.1.multipathway_blocks.0.res_blocks.2
blocks.1.multipathway_blocks.1.res_blocks.0
blocks.1.multipathway_blocks.1.res_blocks.1
blocks.1.multipathway_blocks.1.res_blocks.2
blocks.2.multipathway_blocks.0.res_blocks.0
blocks.2.multipathway_blocks.0.res_blocks.1
blocks.2.multipathway_blocks.0.res_blocks.2
blocks.2.multipathway_blocks.0.res_blocks.3
blocks.2.multipathway_blocks.1.res_blocks.0
blocks.2.multipathway_blocks.1.res_blocks.1
blocks.2.multipathway_blocks.1.res_blocks.2
blocks.2.multipathway_blocks.1.res_blocks.3
blocks.3.multipathway_blocks.0.res_blocks.0
blocks.3.multipathway_blocks.0.res_blocks.1
blocks.3.multipathway_blocks.0.res_blocks.2
blocks.3.multipathway_blocks.0.res_blocks.3
blocks.3.multipathway_blocks.0.res_blocks.4
blocks.3.multipathway_blocks.0.res_blocks.5
blocks.3.multipathway_blocks.1.res_blocks.0
blocks.3.multi