In [1]:
import numpy as np
import torch

from tqdm.auto import tqdm

import itertools

In [2]:
import os
from dotenv import find_dotenv, load_dotenv

load_dotenv(find_dotenv(), verbose=True)

True

In [3]:
from pathlib import Path

results_dir = Path(os.getenv('RESULTS_DIR'))
print(results_dir)

/data/huze/ray_results/algonauts2021


In [4]:
finished_runs = [path.parent for path in results_dir.glob('**/prediction.npy')]
print(len(finished_runs))

197


In [5]:
from pprint import pprint
exapmle_files = list(path.name for path in finished_runs[0].iterdir())
pprint(exapmle_files)

['params.json',
 'events.out.tfevents.1645454949.yfwu-guslab.2544995.0',
 'progress.csv',
 'result.json',
 'hparams.yaml',
 'events.out.tfevents.1645454934.yfwu-guslab',
 'prediction.npy',
 'voxel_embedding.npy',
 'params.pkl']


In [6]:
import yaml
from yaml import CLoader

from src.config.config import flatten

run_meta_infos = []
for run_dir in finished_runs:
    hparams = yaml.load(run_dir.joinpath('hparams.yaml').open(), Loader=CLoader)
    run_meta_info = flatten(hparams)
    run_meta_info['path'] = run_dir
    run_meta_infos.append(run_meta_info)

import pandas as pd

run_df = pd.DataFrame(run_meta_infos)

# hierarchical ensemble

In [7]:
from src.utils.ensemble import optimize_val_correlation
from src.config.config import combine_cfgs, get_cfg_defaults
from src.data.datamodule import MyDataModule

In [8]:
# prepare validation data
cfg = combine_cfgs('../src/config/experiments/algonauts2021_i3d_flow.yml')
dm = MyDataModule(cfg)
dm.prepare_data()
dm.setup()

val_indices = dm.val_dataset.indices

cache_path = Path('/home/huze/.cache/val_fmris.pt')

if cache_path.exists():
    val_fmris = torch.load(cache_path)
else:
    val_fmris = [dm.dataset_train_val.__getitem__(i) for i in tqdm(val_indices)]
    val_fmris = np.stack(val_fmris, 0)
    val_fmris = torch.tensor(val_fmris)
    torch.save(val_fmris, cache_path)

In [9]:
# analysis grid search space
rois = run_df['DATASET.ROI'].unique()
backbones = run_df['MODEL.BACKBONE.NAME'].unique()
configs = list(itertools.product(*[run_df[k].unique() for k in ['MODEL.BACKBONE.LAYERS', 'MODEL.NECK.SPP_LEVELS']]))

In [None]:
# hierarchical ensemble
verbose = True
opt_verbose = False

roi_predictoin_dict = {}
roi_voxel_indices_dict = {}

for roi in rois:
    if roi == 'WB':
        # WB model is already in `level1` below
        continue
    
    # load flattened voxel masks
    roi_voxel_indices = torch.load(os.path.join(cfg.DATASET.VOXEL_INDEX_DIR, f'{roi}.pt'))
    roi_voxel_indices_dict[roi] = roi_voxel_indices
    roi_val_fmris = val_fmris[..., roi_voxel_indices]
    
    l3_predictions = []
    # level3 combine different models
    for backbone in backbones:
        l2_predictions = []
        # level2 combine all layers and pooling sche
        for layers, spp_levels in configs:
            # level1 roi_model + wb_model
            _l1_df = run_df.loc[
                (run_df['MODEL.BACKBONE.NAME'] == backbone) &
                (run_df['MODEL.BACKBONE.LAYERS'] == layers) &
                (run_df['MODEL.NECK.SPP_LEVELS'] == spp_levels)
            ]
            roi_l1_df = _l1_df.loc[_l1_df['DATASET.ROI'] == roi]
            wb_l1_df = _l1_df.loc[_l1_df['DATASET.ROI'] == 'WB']
            
            # 1 ROI model pair with 1 WB model, with the same hyperparameter
            # if not paired (grid run is not complete yet), will skip
            if not (len(roi_l1_df) == 1 and len(wb_l1_df) == 1):
                print('skipped...\t', roi, backbone, layers, spp_levels, '\t',
                      f'roi={len(roi_l1_df)}', f'wb={len(wb_l1_df)}')
                continue
            if verbose:
                print('Level 1...\t', roi, backbone, layers, spp_levels, '\t')
            l1_predictions = np.stack([
                np.load(roi_l1_df['path'].item().joinpath('prediction.npy')),
                np.load(wb_l1_df['path'].item().joinpath('prediction.npy'))[..., roi_voxel_indices],
            ], -1)
            l1_predictions = torch.tensor(l1_predictions).float()
            l1_ensemble_weight = optimize_val_correlation(l1_predictions[val_indices].clone(), 
                                                          roi_val_fmris.clone(), 
                                                          verbose=opt_verbose,
                                                         device=0)
            new_predictions = l1_predictions @ l1_ensemble_weight
            l2_predictions.append(new_predictions)
        
        if verbose:
            print('Level 2...\t', roi, backbone, '\t')
        l2_predictions = torch.stack(l2_predictions, -1)
        l2_ensemble_weight = optimize_val_correlation(l2_predictions[val_indices].clone(), 
                                                          roi_val_fmris.clone(), 
                                                          verbose=opt_verbose,
                                                         device=0)
        new_predictions = l2_predictions @ l2_ensemble_weight
        l3_predictions.append(new_predictions)
    
    if verbose:
        print('Level 3...\t', roi, '\t')
    l3_predictions = torch.stack(l3_predictions, -1)
    l3_ensemble_weight = optimize_val_correlation(l3_predictions[val_indices].clone(), 
                                                      roi_val_fmris.clone(), 
                                                      verbose=opt_verbose,
                                                     device=0)
    
    roi_prediction = l3_predictions @ l3_ensemble_weight
    
    roi_predictoin_dict[roi] = roi_prediction
    
            
            
            



# "emsemble" ROIs to full brain

In [15]:
combine_dict = {
    'ROI': ['V1', 'V2', 'V3', 'V4', 'EBA', 'LOC', 'PPA', 'FFA', 'STS', 'REST'],
    'LC': ['LC1', 'LC2', 'LC3', 'LC4', 'LC5'],
    'MC': ['MC1', 'MC2', 'LC2', 'LC3', 'LC4', 'LC5'],
    'SMC': ['SMC1', 'SMC2', 'MC2', 'LC2', 'LC3', 'LC4', 'LC5'],
    'SC': ['SMC1', 'SMC2', 'SC3', 'SC4', 'LC2', 'LC3', 'LC4', 'LC5'],
}

In [22]:
shape = np.load(run_df.loc[run_df['DATASET.ROI'] == 'WB'].path.values[0].joinpath('prediction.npy')).shape
print(shape)

(1102, 161326)


In [28]:
# combine rois to full brain
sch_prediction_dict = {}
for sch_name, sch_rois in combine_dict.items():
    prediction = torch.zeros(shape)
    for roi in sch_rois:
        voxel_indices = roi_voxel_indices_dict[roi]
        prediction[..., voxel_indices] = roi_predictoin_dict[roi]
    sch_prediction_dict[sch_name] = prediction
    

In [29]:
sch_prediction_dict

{'ROI': tensor([[-0.0479,  0.0272,  0.0302,  ..., -0.0647, -0.1650, -0.0747],
         [ 0.0115, -0.0449, -0.0573,  ..., -0.1238, -0.1307,  0.0253],
         [-0.0059,  0.0199, -0.0389,  ...,  0.0397,  0.1095,  0.0897],
         ...,
         [ 0.0478,  0.0042, -0.0398,  ...,  0.0993,  0.2024,  0.1934],
         [-0.0636,  0.0524,  0.0418,  ..., -0.0578, -0.1556, -0.1591],
         [-0.0236,  0.0208,  0.0106,  ...,  0.0794,  0.1082,  0.0487]]),
 'LC': tensor([[-0.0454,  0.0734,  0.0509,  ..., -0.0268, -0.1024, -0.0553],
         [ 0.0129, -0.0925, -0.1032,  ..., -0.0210, -0.0207,  0.0889],
         [ 0.0127,  0.0100, -0.0492,  ...,  0.0776,  0.1567,  0.1396],
         ...,
         [ 0.0224, -0.0102, -0.0404,  ...,  0.1298,  0.2577,  0.2563],
         [-0.0426,  0.0466,  0.0378,  ..., -0.0542, -0.1925, -0.2289],
         [-0.0333, -0.0183, -0.0183,  ...,  0.0332,  0.0778,  0.0940]]),
 'MC': tensor([[-0.0412,  0.0834,  0.0535,  ..., -0.0268, -0.1024, -0.0553],
         [ 0.0378, -0.0595