In [2]:
import numpy as np
import torch

from tqdm.auto import tqdm
from pathlib import Path

import itertools

import os
from dotenv import find_dotenv, load_dotenv

# Load environment variables from the .env file located by find_dotenv().
load_dotenv(dotenv_path=find_dotenv(), verbose=True)

True

# prepare data

In [3]:
from src.utils.ensemble import optimize_val_correlation
from src.config.config import combine_cfgs, get_cfg_defaults
from src.data.datamodule import MyDataModule

In [4]:
# Prepare validation data: build the datamodule from the experiment config and
# collect the ground-truth fMRI responses for the validation split.
# NOTE(review): this loads the i3d_flow config while the predictions below come from an
# i3d_rgb run; presumably only dataset/ROI settings are used here — confirm.
cfg = combine_cfgs('../src/config/experiments/algonauts2021_i3d_flow.yml')
dm = MyDataModule(cfg)
dm.prepare_data()
dm.setup()

val_indices = dm.val_dataset.indices

# Loading every validation item is slow, so the stacked responses are cached on disk.
# NOTE(review): the cache is not keyed on the config or the split; delete it if either changes.
cache_path = Path('/data_smr/huze/.cache/val_fmris.pt')

if cache_path.exists():
    # torch.load unpickles the file — only load caches you created yourself.
    val_fmris = torch.load(cache_path)
else:
    # Element [1] of each dataset item is taken as the fMRI target.
    val_fmris = [dm.dataset_train_val[i][1] for i in tqdm(val_indices)]
    val_fmris = torch.tensor(np.stack(val_fmris, 0))
    # torch.save does not create missing directories.
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(val_fmris, cache_path)

# Guard against a stale cache: targets must line up row-for-row with prediction[val_indices].
assert len(val_fmris) == len(val_indices), (
    f'cached val_fmris has {len(val_fmris)} rows but the validation split has '
    f'{len(val_indices)}; delete {cache_path} and re-run'
)

# prepare model prediction

In [8]:
# Multi-layer `single-run` baseline (an i3d_rgb backbone, per the directory name).
# Produced by notebook 001 (cross-notebook reference).
run_dir = (
    Path("/data/huze/ray_results/algonauts2021")
    / "algonauts2021_i3d_rgb-multi_layer"
    / "run_single_tune_config_575a8_00066_66_DATASET.ROI=WB,MODEL.BACKBONE.LAYERS=_'x1', 'x2', 'x3', 'x4'_,MODEL.BACKBONE.LAYER_PATHWAYS=_2022-02-26_14-18-59"
)
path = run_dir / "prediction.npy"

In [12]:
# Load the saved predictions and convert them to a float32 tensor.
prediction = torch.as_tensor(np.load(path), dtype=torch.float32)

# Compute the mean validation score for each ROI

In [41]:
# ROI names to score; each needs a matching `<roi>.pt` voxel-index file
# in cfg.DATASET.VOXEL_INDEX_DIR.
rois = (
    'WB V1 V2 V3 V4 LOC EBA FFA STS PPA REST '
    'SMC1 SMC2 MC2 MC1 SC3 SC4 LC1 LC2 LC3 LC4 LC5'
).split()

In [42]:
from src.utils.metrics import vectorized_correlation

In [43]:
# Per-voxel correlation between predicted and measured validation responses.
scores = vectorized_correlation(prediction[val_indices], val_fmris)

# Average the per-voxel scores inside each ROI; `<roi>.pt` holds that ROI's voxel indices.
voxel_index_dir = Path(cfg.DATASET.VOXEL_INDEX_DIR)
roi_score_dict = {
    roi: scores[torch.load(voxel_index_dir / f'{roi}.pt')].mean().item()
    for roi in rois
}

In [44]:
# Mean validation score per ROI for the multi-layer i3d_rgb baseline.
roi_score_dict

{'WB': 0.14779090881347656,
 'V1': 0.19442307949066162,
 'V2': 0.20365500450134277,
 'V3': 0.21832409501075745,
 'V4': 0.1992104947566986,
 'LOC': 0.3296410143375397,
 'EBA': 0.3657059073448181,
 'FFA': 0.28206267952919006,
 'STS': 0.19279032945632935,
 'PPA': 0.20913344621658325,
 'REST': 0.11902685463428497,
 'SMC1': 0.08702672272920609,
 'SMC2': 0.02951086312532425,
 'MC2': 0.24660779535770416,
 'MC1': 0.03918087109923363,
 'SC3': 0.25927734375,
 'SC4': 0.23413535952568054,
 'LC1': 0.06566636264324188,
 'LC2': 0.3013060390949249,
 'LC3': 0.22845622897148132,
 'LC4': 0.43582409620285034,
 'LC5': 0.28519466519355774}

# We defrost (unfreeze) the backbone once the model reaches roughly 1/2 of this score

In [52]:
from pprint import pprint
import json

In [53]:
# Allowed threshold values, ascending; every scaled ROI score is snapped to one of these.
# NOTE(review): where these specific values come from is not shown in this notebook.
magic_lst = [0.009, 0.0105, 0.012, 0.021, 0.027, 0.03, 0.036, 0.045, 0.06, 0.09, 0.105, 0.12, 0.135, 0.15, 0.18, 0.21]


def closest(magic_lst, K):
    """Return the element of ``magic_lst`` nearest to ``K``.

    Ties resolve to the earliest element in list order; an empty list raises ValueError.
    """
    return min(magic_lst, key=lambda candidate: abs(candidate - K))

In [54]:
def get_and_print(roi_score_dict: dict, multiply_ratio: float, save_path=None):
    """Scale each ROI score, snap it to the nearest value in ``magic_lst``, and print as JSON.

    Args:
        roi_score_dict: mapping of ROI name -> score.
        multiply_ratio: factor applied to every score before snapping.
        save_path: optional file path (str or Path). When given, the same JSON text that
            is printed is also written there, e.g. an experiment config under
            src/config/experiments/. Defaults to None, which only prints.

    Returns:
        None. The result is printed (and optionally saved).
    """
    new_dict = {roi: closest(magic_lst, score * multiply_ratio) for roi, score in roi_score_dict.items()}
    text = json.dumps(new_dict, indent=2)
    print(text)
    if save_path is not None:
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        # Trailing newline so the saved file ends cleanly.
        save_path.write_text(text + '\n')

In [57]:
# i3d_rgb
# defrost the backbone once a ROI reaches roughly 1/2 of the baseline score
get_and_print(roi_score_dict, multiply_ratio=0.5)

{
  "WB": 0.06,
  "V1": 0.09,
  "V2": 0.105,
  "V3": 0.105,
  "V4": 0.105,
  "LOC": 0.15,
  "EBA": 0.18,
  "FFA": 0.135,
  "STS": 0.09,
  "PPA": 0.105,
  "REST": 0.06,
  "SMC1": 0.045,
  "SMC2": 0.012,
  "MC2": 0.12,
  "MC1": 0.021,
  "SC3": 0.135,
  "SC4": 0.12,
  "LC1": 0.03,
  "LC2": 0.15,
  "LC3": 0.12,
  "LC4": 0.21,
  "LC5": 0.15
}


In [58]:
# i3d_flow
# the flow model's final score is roughly 81% of i3d_rgb's, so shrink the
# 1/2-score defrost thresholds by the same factor
flow_to_rgb_ratio = 0.81
get_and_print(roi_score_dict, multiply_ratio=0.5 * flow_to_rgb_ratio)

{
  "WB": 0.06,
  "V1": 0.09,
  "V2": 0.09,
  "V3": 0.09,
  "V4": 0.09,
  "LOC": 0.135,
  "EBA": 0.15,
  "FFA": 0.12,
  "STS": 0.09,
  "PPA": 0.09,
  "REST": 0.045,
  "SMC1": 0.036,
  "SMC2": 0.012,
  "MC2": 0.105,
  "MC1": 0.012,
  "SC3": 0.105,
  "SC4": 0.09,
  "LC1": 0.027,
  "LC2": 0.12,
  "LC3": 0.09,
  "LC4": 0.18,
  "LC5": 0.12
}


In [51]:
# TODO: save the i3d_rgb thresholds printed above to src/config/experiments/algonauts2021_i3d_rgb_defrost_score.json — no cell in this notebook writes that file.