In [None]:
import numpy as np
import torch

from tqdm.auto import tqdm
from pathlib import Path

import itertools

import os
from dotenv import find_dotenv, load_dotenv

load_dotenv(find_dotenv(), verbose=True)

In [None]:
DEVICE = torch.device('cuda:1')

In [None]:
results_dir = Path('/data/huze/ray_results/algonauts2021')
print('results_dir', results_dir)

In [None]:
from src.utils.runs import load_run_df

In [None]:
run_df = load_run_df(results_dir)

# prepare data

In [60]:
from src.utils.ensemble import optimize_val_correlation
from src.config.config import combine_cfgs, get_cfg_defaults
from src.data.datamodule import MyDataModule
from pathlib import Path
import torch
import numpy as np

In [61]:
# prepare train and validation data
cfg = get_cfg_defaults()
cfg.DATASET.TRANSFORM = 'i3d_flow'
dm = MyDataModule(cfg)
dm.prepare_data()
dm.setup()

val_indices = dm.val_dataset.indices
fmris_cache_path = Path('/data/huze/.cache/trainval_fmris.pt')

if fmris_cache_path.exists():
    fmris = torch.load(fmris_cache_path)
else:
    fmris = [dm.dataset_train_val.__getitem__(i)[1]
             for i in tqdm(range(dm.dataset_train_val.__len__()))]
    fmris = np.stack(fmris, 0)
    fmris = torch.tensor(fmris)
    torch.save(fmris, fmris_cache_path)

val_fmris = fmris[val_indices]

In [62]:
def get_ensemble_prediction_from_tensor_list(predicions_list, roi_val_fmris, val_indices, opt_verbose=False, tol=1e-4):
    predictions = torch.stack(predicions_list, -1)
    ws = optimize_val_correlation(predictions[val_indices],
                                  roi_val_fmris,
                                  verbose=opt_verbose,
                                  device=DEVICE,
                                  tol=tol)
    new_predictions = predictions @ ws
    return new_predictions


def get_ensemble_prediction_from_df(roi_df, val_indices, roi_val_fmris, roi_voxel_indices, opt_verbose=False, tol=1e-4):
    predictions = np.stack([
        np.load(path.joinpath('prediction.npy'))
        for path in roi_df['path'].values
    ], -1)

    predictions = torch.tensor(predictions).float()
    ws = optimize_val_correlation(predictions[val_indices],
                                  roi_val_fmris,
                                  verbose=opt_verbose,
                                  device=DEVICE,
                                  tol=tol)
    new_predictions = predictions @ ws
    return new_predictions

# 3 level of hierarchical ensemble

In [63]:
from src.utils.rigistry import Registry
from src.utils.misc import my_query_df

ORDERED_HIERACHY_KEYS = ['MODEL.BACKBONE.NAME', 'MODEL.BACKBONE.LAYERS', 'MODEL.NECK.SPP_LEVELS']

HEFN_REGISTRY = Registry()

@HEFN_REGISTRY.register('H1')
def H1_ens_roi(run_df, roi, verbose=True, opt_verbose=False, he_keys=ORDERED_HIERACHY_KEYS):
    run_df = run_df.loc[run_df['DATASET.ROI'] == roi]
    roi_voxel_indices = torch.load(os.path.join(cfg.DATASET.VOXEL_INDEX_DIR, f'{roi}.pt'))
    roi_val_fmris = val_fmris[..., roi_voxel_indices]

    if verbose:
        print('Level 1...\t', roi, '\t')

    new_predictions = get_ensemble_prediction_from_df(run_df, val_indices,
                                                      roi_val_fmris, roi_voxel_indices, opt_verbose=opt_verbose)

    return new_predictions


@HEFN_REGISTRY.register('H2')
def H2_ens_roi(run_df, roi, verbose=True, opt_verbose=False, he_keys=ORDERED_HIERACHY_KEYS):
    run_df = run_df.loc[run_df['DATASET.ROI'] == roi]    
    roi_voxel_indices = torch.load(os.path.join(cfg.DATASET.VOXEL_INDEX_DIR, f'{roi}.pt'))
    roi_val_fmris = val_fmris[..., roi_voxel_indices]

    o_predictions_list = []
    for v1 in run_df[he_keys[0]].unique():
        _l1_df = my_query_df(run_df, equal_dict={he_keys[0]: v1})

        if verbose:
            print('Level 1...\t', roi, v1, '\t')
        new_predictions = get_ensemble_prediction_from_df(_l1_df, val_indices,
                                                          roi_val_fmris, roi_voxel_indices, opt_verbose=opt_verbose)
        o_predictions_list.append(new_predictions)

    if verbose:
        print('Level 2...\t', roi, '\t')
    new_predictions = get_ensemble_prediction_from_tensor_list(o_predictions_list, roi_val_fmris, val_indices,
                                                               opt_verbose=opt_verbose)
    return new_predictions


@HEFN_REGISTRY.register('H3')
def H3_ens_roi(run_df, roi, verbose=True, opt_verbose=False, he_keys=ORDERED_HIERACHY_KEYS):
    run_df = run_df.loc[run_df['DATASET.ROI'] == roi]    
    roi_voxel_indices = torch.load(os.path.join(cfg.DATASET.VOXEL_INDEX_DIR, f'{roi}.pt'))
    roi_val_fmris = val_fmris[..., roi_voxel_indices]

    oo_predictions_list = []
    for v2 in run_df[he_keys[0]].unique():
        _l2_df = my_query_df(run_df, equal_dict={he_keys[0]: v2})
        o_predictions_list = []
        for v1 in _l2_df[he_keys[1]].unique():
            _l1_df = my_query_df(_l2_df, equal_dict={he_keys[1]: v1})
            if verbose:
                print('Level 1...\t', roi, v2, v1, '\t')
            new_predictions = get_ensemble_prediction_from_df(_l1_df, val_indices,
                                                              roi_val_fmris, roi_voxel_indices, opt_verbose=opt_verbose)
            o_predictions_list.append(new_predictions)

        if verbose:
            print('Level 2...\t', roi, v2, '\t')
        new_predictions = get_ensemble_prediction_from_tensor_list(o_predictions_list, roi_val_fmris, val_indices,
                                                                   opt_verbose=opt_verbose)
        oo_predictions_list.append(new_predictions)
    if verbose:
        print('Level 3...\t', roi, '\t')
    new_predictions = get_ensemble_prediction_from_tensor_list(oo_predictions_list, roi_val_fmris, val_indices,
                                                               opt_verbose=opt_verbose)
    return new_predictions

In [64]:
def assemble_rois_to_full_brain(roi_sch_dict, roi_prediction_dict, shape):
    # combine rois to full brain
    sch_prediction_dict = {}
    for sch_name, sch_rois in roi_sch_dict.items():
        prediction = torch.zeros(shape)
        for roi in sch_rois:
            voxel_indices = torch.load(os.path.join(cfg.DATASET.VOXEL_INDEX_DIR, f'{roi}.pt'))
            prediction[..., voxel_indices] = roi_prediction_dict[roi]
        sch_prediction_dict[sch_name] = prediction
    return sch_prediction_dict

In [65]:
run_df.keys()

Index(['DESCRIPTION', 'DATAMODULE.NUM_CV_SPLITS', 'DATAMODULE.I_CV_FOLD',
       'DATASET.NAME', 'DATASET.ROOT_DIR', 'DATASET.TRANSFORM',
       'DATASET.RESOLUTION', 'DATASET.FRAMES', 'DATASET.VOXEL_INDEX_DIR',
       'DATASET.ROI', 'MODEL.BACKBONE.NAME', 'MODEL.BACKBONE.PRETRAINED',
       'MODEL.BACKBONE.PRETRAINED_WEIGHT_DIR', 'MODEL.BACKBONE.DISABLE_BN',
       'MODEL.BACKBONE.LAYERS', 'MODEL.BACKBONE.LAYER_PATHWAYS',
       'MODEL.NECK.NECK_TYPE', 'MODEL.NECK.FIRST_CONV_SIZE',
       'MODEL.NECK.POOLING_MODE', 'MODEL.NECK.SPP_LEVELS',
       'MODEL.NECK.FC_ACTIVATION', 'MODEL.NECK.FC_HIDDEN_DIM',
       'MODEL.NECK.FC_NUM_LAYERS', 'MODEL.NECK.FC_BATCH_NORM',
       'MODEL.NECK.FC_DROPOUT', 'MODEL.NECK.LSTM.HIDDEN_SIZE',
       'MODEL.NECK.LSTM.NUM_LAYERS', 'MODEL.NECK.LSTM.BIDIRECTIONAL',
       'OPTIMIZER.NAME', 'OPTIMIZER.LR', 'OPTIMIZER.WEIGHT_DECAY',
       'SCHEDULER.NAME', 'TRAINER.GPUS', 'TRAINER.FP16', 'TRAINER.MAX_EPOCHS',
       'TRAINER.ACCUMULATE_GRAD_BATCHES', 'TRAIN

In [66]:
run_df = run_df[run_df['TRAINER.CALLBACKS.BACKBONE.DEFROST_SCORE'] < 1.]

# define ensemble sche gird space

In [67]:
import pandas as pd

In [68]:
# this dict defines how to combine rois to full brain
roi_sch_dict = {
    'WB': ['WB'],
    'aROI': ['V1', 'V2', 'V3', 'V4', 'EBA', 'LOC', 'PPA', 'FFA', 'STS', 'REST'],
    'S-htROI': [f'S-htROI{i+1}' for i in range(8)],
}

# ensemble configuration

In [69]:
results_dir

PosixPath('/data/huze/ray_results/algonauts2021')

In [70]:
save_dir = results_dir.joinpath('notebook200')
save_dir.mkdir(parents=False, exist_ok=True)

In [71]:
save_dir

PosixPath('/data/huze/ray_results/algonauts2021/notebook200')

In [72]:
rois = sum(list(roi_sch_dict.values()), [])

In [73]:
rois

['WB',
 'V1',
 'V2',
 'V3',
 'V4',
 'EBA',
 'LOC',
 'PPA',
 'FFA',
 'STS',
 'REST',
 'S-htROI1',
 'S-htROI2',
 'S-htROI3',
 'S-htROI4',
 'S-htROI5',
 'S-htROI6',
 'S-htROI7',
 'S-htROI8']

# heavy lift for each roi

In [77]:
roi_save_dir = save_dir.joinpath(Path('roi'))

In [78]:
roi_save_dir

PosixPath('/data/huze/ray_results/algonauts2021/notebook200/roi')

In [79]:
roi_save_dir.mkdir(parents=False, exist_ok=True)

# for main results

In [26]:
he_sch = 'H3'

In [27]:
# X model
for b in run_df['MODEL.BACKBONE.NAME'].unique():
    b_run_df = run_df[run_df['MODEL.BACKBONE.NAME'] == b]
    
    # Disk I/O is the bottle neck

    verbose = True
    opt_verbose = False
    skip_existing = True

    for roi in tqdm(rois[:]):
        # skip existing files
        path = roi_save_dir.joinpath(Path(f'backbone={b},he_sch={he_sch},roi={roi}-prediction.pt'))
        if path.exists():
            if skip_existing:
                print('skipped existing ...', he_sch, roi)
                continue

        roi_prediction = HEFN_REGISTRY[he_sch](b_run_df, roi, verbose=verbose, opt_verbose=opt_verbose)

        torch.save(roi_prediction, path)

  0%|          | 0/19 [00:00<?, ?it/s]

skipped existing ... H3 WB
skipped existing ... H3 V1
skipped existing ... H3 V2
skipped existing ... H3 V3
skipped existing ... H3 V4
skipped existing ... H3 EBA
skipped existing ... H3 LOC
skipped existing ... H3 PPA
skipped existing ... H3 FFA
skipped existing ... H3 STS
skipped existing ... H3 REST
skipped existing ... H3 S-htROI1
skipped existing ... H3 S-htROI2
skipped existing ... H3 S-htROI3
skipped existing ... H3 S-htROI4
skipped existing ... H3 S-htROI5
skipped existing ... H3 S-htROI6
skipped existing ... H3 S-htROI7
skipped existing ... H3 S-htROI8


  0%|          | 0/19 [00:00<?, ?it/s]

skipped existing ... H3 WB
skipped existing ... H3 V1
skipped existing ... H3 V2
skipped existing ... H3 V3
skipped existing ... H3 V4
skipped existing ... H3 EBA
skipped existing ... H3 LOC
skipped existing ... H3 PPA
skipped existing ... H3 FFA
skipped existing ... H3 STS
skipped existing ... H3 REST
skipped existing ... H3 S-htROI1
skipped existing ... H3 S-htROI2
skipped existing ... H3 S-htROI3
skipped existing ... H3 S-htROI4
skipped existing ... H3 S-htROI5
skipped existing ... H3 S-htROI6
skipped existing ... H3 S-htROI7
skipped existing ... H3 S-htROI8


  0%|          | 0/19 [00:00<?, ?it/s]

skipped existing ... H3 WB
skipped existing ... H3 V1
skipped existing ... H3 V2
skipped existing ... H3 V3
skipped existing ... H3 V4
skipped existing ... H3 EBA
skipped existing ... H3 LOC
skipped existing ... H3 PPA
skipped existing ... H3 FFA
skipped existing ... H3 STS
skipped existing ... H3 REST
skipped existing ... H3 S-htROI1
skipped existing ... H3 S-htROI2
skipped existing ... H3 S-htROI3
skipped existing ... H3 S-htROI4
skipped existing ... H3 S-htROI5
skipped existing ... H3 S-htROI6
skipped existing ... H3 S-htROI7
skipped existing ... H3 S-htROI8


  0%|          | 0/19 [00:00<?, ?it/s]

skipped existing ... H3 WB
skipped existing ... H3 V1
skipped existing ... H3 V2
skipped existing ... H3 V3
skipped existing ... H3 V4
skipped existing ... H3 EBA
skipped existing ... H3 LOC
skipped existing ... H3 PPA
skipped existing ... H3 FFA
skipped existing ... H3 STS
skipped existing ... H3 REST
skipped existing ... H3 S-htROI1
skipped existing ... H3 S-htROI2
skipped existing ... H3 S-htROI3
skipped existing ... H3 S-htROI4
skipped existing ... H3 S-htROI5
skipped existing ... H3 S-htROI6
skipped existing ... H3 S-htROI7
skipped existing ... H3 S-htROI8


  0%|          | 0/19 [00:00<?, ?it/s]

skipped existing ... H3 WB
skipped existing ... H3 V1
skipped existing ... H3 V2
skipped existing ... H3 V3
skipped existing ... H3 V4
skipped existing ... H3 EBA
skipped existing ... H3 LOC
skipped existing ... H3 PPA
skipped existing ... H3 FFA
skipped existing ... H3 STS
skipped existing ... H3 REST
skipped existing ... H3 S-htROI1
skipped existing ... H3 S-htROI2
skipped existing ... H3 S-htROI3
skipped existing ... H3 S-htROI4
skipped existing ... H3 S-htROI5
skipped existing ... H3 S-htROI6
skipped existing ... H3 S-htROI7
skipped existing ... H3 S-htROI8


  0%|          | 0/19 [00:00<?, ?it/s]

skipped existing ... H3 WB
skipped existing ... H3 V1
skipped existing ... H3 V2
skipped existing ... H3 V3
skipped existing ... H3 V4
skipped existing ... H3 EBA
skipped existing ... H3 LOC
skipped existing ... H3 PPA
skipped existing ... H3 FFA
skipped existing ... H3 STS
skipped existing ... H3 REST
skipped existing ... H3 S-htROI1
skipped existing ... H3 S-htROI2
skipped existing ... H3 S-htROI3
skipped existing ... H3 S-htROI4
skipped existing ... H3 S-htROI5
skipped existing ... H3 S-htROI6
skipped existing ... H3 S-htROI7
skipped existing ... H3 S-htROI8


  0%|          | 0/19 [00:00<?, ?it/s]

skipped existing ... H3 WB
skipped existing ... H3 V1
skipped existing ... H3 V2
skipped existing ... H3 V3
skipped existing ... H3 V4
skipped existing ... H3 EBA
skipped existing ... H3 LOC
skipped existing ... H3 PPA
skipped existing ... H3 FFA
skipped existing ... H3 STS
skipped existing ... H3 REST
skipped existing ... H3 S-htROI1
skipped existing ... H3 S-htROI2
skipped existing ... H3 S-htROI3
skipped existing ... H3 S-htROI4
skipped existing ... H3 S-htROI5
skipped existing ... H3 S-htROI6
skipped existing ... H3 S-htROI7
skipped existing ... H3 S-htROI8


  0%|          | 0/19 [00:00<?, ?it/s]

skipped existing ... H3 WB
skipped existing ... H3 V1
skipped existing ... H3 V2
skipped existing ... H3 V3
skipped existing ... H3 V4
skipped existing ... H3 EBA
skipped existing ... H3 LOC
skipped existing ... H3 PPA
skipped existing ... H3 FFA
skipped existing ... H3 STS
skipped existing ... H3 REST
skipped existing ... H3 S-htROI1
skipped existing ... H3 S-htROI2
skipped existing ... H3 S-htROI3
skipped existing ... H3 S-htROI4
skipped existing ... H3 S-htROI5
skipped existing ... H3 S-htROI6
skipped existing ... H3 S-htROI7
skipped existing ... H3 S-htROI8


  0%|          | 0/19 [00:00<?, ?it/s]

skipped existing ... H3 WB
skipped existing ... H3 V1
skipped existing ... H3 V2
skipped existing ... H3 V3
skipped existing ... H3 V4
skipped existing ... H3 EBA
skipped existing ... H3 LOC
skipped existing ... H3 PPA
skipped existing ... H3 FFA
skipped existing ... H3 STS
skipped existing ... H3 REST
skipped existing ... H3 S-htROI1
skipped existing ... H3 S-htROI2
skipped existing ... H3 S-htROI3
skipped existing ... H3 S-htROI4
skipped existing ... H3 S-htROI5
skipped existing ... H3 S-htROI6
skipped existing ... H3 S-htROI7
skipped existing ... H3 S-htROI8


  0%|          | 0/19 [00:00<?, ?it/s]

skipped existing ... H3 WB
skipped existing ... H3 V1
skipped existing ... H3 V2
skipped existing ... H3 V3
skipped existing ... H3 V4
skipped existing ... H3 EBA
skipped existing ... H3 LOC
skipped existing ... H3 PPA
skipped existing ... H3 FFA
skipped existing ... H3 STS
skipped existing ... H3 REST
skipped existing ... H3 S-htROI1
skipped existing ... H3 S-htROI2
skipped existing ... H3 S-htROI3
skipped existing ... H3 S-htROI4
skipped existing ... H3 S-htROI5
skipped existing ... H3 S-htROI6
skipped existing ... H3 S-htROI7
skipped existing ... H3 S-htROI8


  0%|          | 0/19 [00:00<?, ?it/s]

skipped existing ... H3 WB
skipped existing ... H3 V1
skipped existing ... H3 V2
skipped existing ... H3 V3
skipped existing ... H3 V4
skipped existing ... H3 EBA
skipped existing ... H3 LOC
skipped existing ... H3 PPA
skipped existing ... H3 FFA
skipped existing ... H3 STS
skipped existing ... H3 REST
skipped existing ... H3 S-htROI1
skipped existing ... H3 S-htROI2
skipped existing ... H3 S-htROI3
skipped existing ... H3 S-htROI4
skipped existing ... H3 S-htROI5
skipped existing ... H3 S-htROI6
skipped existing ... H3 S-htROI7
skipped existing ... H3 S-htROI8


In [83]:
# for supplementary
for roi in tqdm(rois[:]):
    b = '3d_swin'
    b_run_df = run_df[run_df['MODEL.BACKBONE.NAME'] == b]
    verbose = True
    opt_verbose = False
    skip_existing = True

    for he_sch in ['H3', 'H2']:
        # skip existing files
        path = roi_save_dir.joinpath(Path(f'backbone={b},he_sch={he_sch},roi={roi}-prediction.pt'))
        if path.exists():
            if skip_existing:
                print('skipped existing ...', he_sch, roi)
                continue

        roi_prediction = HEFN_REGISTRY[he_sch](b_run_df, roi, verbose=verbose, opt_verbose=opt_verbose)

        torch.save(roi_prediction, path)

  0%|          | 0/19 [00:00<?, ?it/s]

skipped existing ... H3 WB
skipped existing ... H2 WB
skipped existing ... H3 V1
Level 1...	 V1 3d_swin 	
Level 2...	 V1 	
skipped existing ... H3 V2
Level 1...	 V2 3d_swin 	
Level 2...	 V2 	
skipped existing ... H3 V3
Level 1...	 V3 3d_swin 	
Level 2...	 V3 	
skipped existing ... H3 V4
Level 1...	 V4 3d_swin 	
Level 2...	 V4 	
skipped existing ... H3 EBA
Level 1...	 EBA 3d_swin 	
Level 2...	 EBA 	
skipped existing ... H3 LOC
Level 1...	 LOC 3d_swin 	
Level 2...	 LOC 	
skipped existing ... H3 PPA
Level 1...	 PPA 3d_swin 	
Level 2...	 PPA 	
skipped existing ... H3 FFA
Level 1...	 FFA 3d_swin 	
Level 2...	 FFA 	
skipped existing ... H3 STS
Level 1...	 STS 3d_swin 	
Level 2...	 STS 	
skipped existing ... H3 REST
Level 1...	 REST 3d_swin 	
Level 2...	 REST 	
skipped existing ... H3 S-htROI1
Level 1...	 S-htROI1 3d_swin 	
Level 2...	 S-htROI1 	
skipped existing ... H3 S-htROI2
Level 1...	 S-htROI2 3d_swin 	
Level 2...	 S-htROI2 	
skipped existing ... H3 S-htROI3
Level 1...	 S-htROI3 3d_swin

In [36]:
# all models
for roi in tqdm(rois[:]):
    prediction_list = []
    for b in run_df['MODEL.BACKBONE.NAME'].unique():
        path = roi_save_dir.joinpath(Path(f'backbone={b},he_sch={he_sch},roi={roi}-prediction.pt'))
        prediction = torch.load(path)
        prediction_list.append(prediction)
    
    roi_voxel_indices = torch.load(os.path.join(cfg.DATASET.VOXEL_INDEX_DIR, f'{roi}.pt'))
    roi_val_fmris = val_fmris[..., roi_voxel_indices]

    roi_prediction = get_ensemble_prediction_from_tensor_list(prediction_list, roi_val_fmris, val_indices)
    
    b = 'all'
    path = roi_save_dir.joinpath(Path(f'backbone={b},he_sch={he_sch},roi={roi}-prediction.pt'))
    torch.save(roi_prediction, path)

  0%|          | 0/19 [00:00<?, ?it/s]

In [40]:
# all - X models
for o_b in run_df['MODEL.BACKBONE.NAME'].unique():
    for roi in tqdm(rois[:]):
        prediction_list = []
        for b in run_df['MODEL.BACKBONE.NAME'].unique():
            if b == o_b: continue
            path = roi_save_dir.joinpath(Path(f'backbone={b},he_sch={he_sch},roi={roi}-prediction.pt'))
            prediction = torch.load(path)
            prediction_list.append(prediction)

        roi_voxel_indices = torch.load(os.path.join(cfg.DATASET.VOXEL_INDEX_DIR, f'{roi}.pt'))
        roi_val_fmris = val_fmris[..., roi_voxel_indices]

        roi_prediction = get_ensemble_prediction_from_tensor_list(prediction_list, roi_val_fmris, val_indices)

        path = roi_save_dir.joinpath(Path(f'backbone=minus_{o_b},he_sch={he_sch},roi={roi}-prediction.pt'))
        torch.save(roi_prediction, path)

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

In [41]:
# 1 - 11 models
sorted_names = ['3d_swin', 'i3d_rgb', '2d_densnet_warp_3d',
       '2d_simclr_warp_3d', '2d_moby_swin_warp_3d',
       '2d_pyconvsegnet_warp_3d', '2d_seg_swin_warp_3d', 'i3d_flow',
       '2d_colorizer_warp_3d', '2d_bdcnvgg_warp_3d', 'audio_vggish']

for roi in tqdm(rois[:]):
    roi_voxel_indices = torch.load(os.path.join(cfg.DATASET.VOXEL_INDEX_DIR, f'{roi}.pt'))
    roi_val_fmris = val_fmris[..., roi_voxel_indices]
    
    prediction_list = []
    for b in sorted_names:
        path = roi_save_dir.joinpath(Path(f'backbone={b},he_sch={he_sch},roi={roi}-prediction.pt'))
        prediction = torch.load(path)
        prediction_list.append(prediction)

        roi_prediction = get_ensemble_prediction_from_tensor_list(prediction_list, roi_val_fmris, val_indices)
    
        l = len(prediction_list)
        path = roi_save_dir.joinpath(Path(f'backbone={l},he_sch={he_sch},roi={roi}-prediction.pt'))
        torch.save(roi_prediction, path)

  0%|          | 0/19 [00:00<?, ?it/s]

In [51]:
from src.utils.metrics import vectorized_correlation

for roi in ['WB']:
    roi_voxel_indices = torch.load(os.path.join(cfg.DATASET.VOXEL_INDEX_DIR, f'{roi}.pt'))
    roi_val_fmris = val_fmris[..., roi_voxel_indices]
    
    for l in range(1, 12):
        path = roi_save_dir.joinpath(Path(f'backbone={l},he_sch={he_sch},roi={roi}-prediction.pt'))
        prediction = torch.load(path)
        
        score = vectorized_correlation(prediction[val_indices], roi_val_fmris)
        score /= noise_ceiling
        score = score.mean().item()
        print(l, score)

# 1 - 11 models (greedy)
sorted_names = ['3d_swin', 'i3d_rgb', '2d_densnet_warp_3d',
       '2d_simclr_warp_3d', '2d_moby_swin_warp_3d',
       '2d_pyconvsegnet_warp_3d', '2d_seg_swin_warp_3d', 'i3d_flow',
       '2d_colorizer_warp_3d', '2d_bdcnvgg_warp_3d', 'audio_vggish']

for roi in ['WB']:
    roi_voxel_indices = torch.load(os.path.join(cfg.DATASET.VOXEL_INDEX_DIR, f'{roi}.pt'))
    roi_val_fmris = val_fmris[..., roi_voxel_indices]
    
    best_score = 0.
    
    prediction_list = []
    for b in sorted_names:
        path = roi_save_dir.joinpath(Path(f'backbone={b},he_sch={he_sch},roi={roi}-prediction.pt'))
        prediction = torch.load(path)
        tmp_prediction_list = prediction_list + [prediction]

        roi_prediction = get_ensemble_prediction_from_tensor_list(tmp_prediction_list, roi_val_fmris, val_indices)
        
        score = vectorized_correlation(roi_prediction[val_indices], roi_val_fmris)
        score /= noise_ceiling
        score = score.mean().item()
        print(b, score)
        if score > best_score:
            beest_score = score
            prediction_list = tmp_prediction_list
        
        # l = len(prediction_list)
        # path = roi_save_dir.joinpath(Path(f'backbone={l},he_sch={he_sch},roi={roi}-prediction.pt'))
        # torch.save(roi_prediction, path)

# "assemble" ROIs to full brain and save to file


In [28]:
backbones = run_df['MODEL.BACKBONE.NAME'].unique().tolist() + ['all'] + [f'minus_{o_b}' for o_b in run_df['MODEL.BACKBONE.NAME'].unique()] + list(range(1, 12))

In [29]:
backbones

['2d_pyconvsegnet_warp_3d',
 'i3d_flow',
 '2d_seg_swin_warp_3d',
 'i3d_rgb',
 '2d_moby_swin_warp_3d',
 '2d_colorizer_warp_3d',
 'audio_vggish',
 '2d_densnet_warp_3d',
 '3d_swin',
 '2d_bdcnvgg_warp_3d',
 '2d_simclr_warp_3d',
 'all',
 'minus_2d_pyconvsegnet_warp_3d',
 'minus_i3d_flow',
 'minus_2d_seg_swin_warp_3d',
 'minus_i3d_rgb',
 'minus_2d_moby_swin_warp_3d',
 'minus_2d_colorizer_warp_3d',
 'minus_audio_vggish',
 'minus_2d_densnet_warp_3d',
 'minus_3d_swin',
 'minus_2d_bdcnvgg_warp_3d',
 'minus_2d_simclr_warp_3d',
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11]

In [30]:
roi_sch_dict

{'WB': ['WB'],
 'aROI': ['V1', 'V2', 'V3', 'V4', 'EBA', 'LOC', 'PPA', 'FFA', 'STS', 'REST'],
 'S-htROI': ['S-htROI1',
  'S-htROI2',
  'S-htROI3',
  'S-htROI4',
  'S-htROI5',
  'S-htROI6',
  'S-htROI7',
  'S-htROI8']}

In [31]:
from src.utils.submission import algonauts2021_submission_from_whole_brain_prediction

In [50]:
wb_shape = np.load(run_df.loc[run_df['DATASET.ROI'] == 'WB'].path.values[0].joinpath('prediction.npy')).shape

for b in tqdm(backbones):

    for roi_sch, rois in roi_sch_dict.items():

        prediction = torch.zeros(wb_shape)

        for roi in rois:
            voxel_indices = torch.load(os.path.join(cfg.DATASET.VOXEL_INDEX_DIR, f'{roi}.pt'))
            path = roi_save_dir.joinpath(Path(f'backbone={b},he_sch={he_sch},roi={roi}-prediction.pt'))
            prediction[..., voxel_indices] = torch.load(path)

        path = Path(os.path.join(save_dir, f'backbone={b},he_sch={he_sch},roi_sch={roi_sch}-prediction.pt'))
        torch.save(prediction, path)

        # algonauts2021_submission_from_whole_brain_prediction('../../src/config/dataset',
        #                                                 path.name.replace('-prediction.pt', ''),
        #                                                 prediction[1000:].numpy(),
        #                                                 output_dir='./submissions/full/',
        #                                                 mini_track=False)

  0%|          | 0/34 [00:00<?, ?it/s]

In [84]:
# for supplementary
wb_shape = np.load(run_df.loc[run_df['DATASET.ROI'] == 'WB'].path.values[0].joinpath('prediction.npy')).shape

b = '3d_swin'

for he_sch in ['H3', 'H2']:

    for roi_sch, rois in roi_sch_dict.items():

        prediction = torch.zeros(wb_shape)

        for roi in rois:
            voxel_indices = torch.load(os.path.join(cfg.DATASET.VOXEL_INDEX_DIR, f'{roi}.pt'))
            path = roi_save_dir.joinpath(Path(f'backbone={b},he_sch={he_sch},roi={roi}-prediction.pt'))
            prediction[..., voxel_indices] = torch.load(path)

        path = Path(os.path.join(save_dir, f'backbone={b},he_sch={he_sch},roi_sch={roi_sch}-prediction.pt'))
        torch.save(prediction, path)

# ROI intersection

In [20]:
def intersect1d_multi_arr(arr_list):
    intersect = arr_list[0]
    for arr in arr_list[1:]:
        intersect = np.intersect1d(intersect, arr)
    return intersect

In [87]:
from src.utils.rigistry import Registry
from src.utils.misc import my_query_df
from src.utils.metrics import vectorized_correlation

INTFN_REGISTRY = Registry()

@INTFN_REGISTRY.register('grouped ensemble')
def grouped_ensemble(prediction_list, val_indices, indiced_val_fmris, opt_verbose=False):
    predictions = torch.stack(prediction_list, -1)
    ws = optimize_val_correlation(predictions[val_indices],
                                  indiced_val_fmris,
                                  verbose=opt_verbose,
                                  device=DEVICE)
    new_predictions = predictions @ ws
    return new_predictions


@INTFN_REGISTRY.register('grouped swap')
def grouped_swap(prediction_list, val_indices, indiced_val_fmris):
    voxel_scores = np.asarray([
        vectorized_correlation(p[val_indices], indiced_val_fmris).numpy()
        for p in prediction_list
    ])
    croi_scores = voxel_scores.mean(1)

    new_predictions = prediction_list[croi_scores.argmax()]
    return new_predictions

@INTFN_REGISTRY.register('voxel-wise swap')
def voxel_wise_swap(prediction_list, val_indices, indiced_val_fmris):
    voxel_scores = np.asarray([
        vectorized_correlation(p[val_indices], indiced_val_fmris).numpy()
        for p in prediction_list
    ])
        
    vsargmax = voxel_scores.argmax(0)
    new_predictions = torch.stack([prediction_list[idx][:, i] for i, idx in enumerate(vsargmax)], -1)
    return new_predictions

In [88]:
def all_combinations(any_list):
    return itertools.chain.from_iterable(
        itertools.combinations(any_list, i + 1)
        for i in range(len(any_list)))

In [89]:
subset_roi_sch_dict_keys = list(all_combinations(roi_sch_dict.keys()))

In [90]:
subset_roi_sch_dict_keys

[('WB',),
 ('aROI',),
 ('S-htROI',),
 ('WB', 'aROI'),
 ('WB', 'S-htROI'),
 ('aROI', 'S-htROI'),
 ('WB', 'aROI', 'S-htROI')]

In [91]:
subset_roi_sch_dicts = []
for keys in subset_roi_sch_dict_keys:
    d = {k: roi_sch_dict[k] for k in keys}
    subset_roi_sch_dicts.append(d)

In [92]:
subset_roi_sch_dicts

[{'WB': ['WB']},
 {'aROI': ['V1', 'V2', 'V3', 'V4', 'EBA', 'LOC', 'PPA', 'FFA', 'STS', 'REST']},
 {'S-htROI': ['S-htROI1',
   'S-htROI2',
   'S-htROI3',
   'S-htROI4',
   'S-htROI5',
   'S-htROI6',
   'S-htROI7',
   'S-htROI8']},
 {'WB': ['WB'],
  'aROI': ['V1', 'V2', 'V3', 'V4', 'EBA', 'LOC', 'PPA', 'FFA', 'STS', 'REST']},
 {'WB': ['WB'],
  'S-htROI': ['S-htROI1',
   'S-htROI2',
   'S-htROI3',
   'S-htROI4',
   'S-htROI5',
   'S-htROI6',
   'S-htROI7',
   'S-htROI8']},
 {'aROI': ['V1', 'V2', 'V3', 'V4', 'EBA', 'LOC', 'PPA', 'FFA', 'STS', 'REST'],
  'S-htROI': ['S-htROI1',
   'S-htROI2',
   'S-htROI3',
   'S-htROI4',
   'S-htROI5',
   'S-htROI6',
   'S-htROI7',
   'S-htROI8']},
 {'WB': ['WB'],
  'aROI': ['V1', 'V2', 'V3', 'V4', 'EBA', 'LOC', 'PPA', 'FFA', 'STS', 'REST'],
  'S-htROI': ['S-htROI1',
   'S-htROI2',
   'S-htROI3',
   'S-htROI4',
   'S-htROI5',
   'S-htROI6',
   'S-htROI7',
   'S-htROI8']}]

In [39]:
for backbone in backbones:
    for subset_roi_sch_dict in tqdm(subset_roi_sch_dicts,
                                   desc=f'backbone={backbone}'):
        subset_roi_sch_text = '+'.join(subset_roi_sch_dict.keys())
        intersection_sch = 'grouped ensemble'

        wb_shape = None
        new_prediction = None
        
        prediction_list = [torch.load(os.path.join(save_dir,
                         f'backbone={backbone},he_sch={he_sch},roi_sch={roi_sch}-prediction.pt'))
                   for roi_sch in subset_roi_sch_dict.keys()]
        
        if len(prediction_list) > 1:
        
            for rois in list(itertools.product(*subset_roi_sch_dict.values())):
                voxel_indices_list = [torch.load(os.path.join(cfg.DATASET.VOXEL_INDEX_DIR, f'{roi}.pt')) for roi in rois]
                intersect = intersect1d_multi_arr(voxel_indices_list)

                if wb_shape is None: wb_shape = prediction_list[0].shape
                if new_prediction is None: new_prediction = torch.zeros(wb_shape)

                intersect_prediction_list = [p[:, intersect] for p in prediction_list]

                if len(intersect) > 0:
                    new_prediction[:, intersect] = INTFN_REGISTRY[intersection_sch](intersect_prediction_list, val_indices, val_fmris[:, intersect])
        
        else:
            
            new_prediction = prediction_list[0]

        torch.save(new_prediction, os.path.join(save_dir,
                                     f'backbone={backbone},he_sch={he_sch},intersection_sch={intersection_sch},subset_roi_sch={subset_roi_sch_text}'
                                     f'-prediction.pt'))

backbone=i3d_rgb:   0%|          | 0/7 [00:00<?, ?it/s]

FileNotFoundError: [Errno 2] No such file or directory: '/data/huze/ray_results/algonauts2021/notebook200/backbone=i3d_rgb,he_sch=H3,roi_sch=i3d_rgb-htROI-prediction.pt'

In [164]:
save_dir

PosixPath('/data/huze/ray_results/algonauts2021/notebook200')

In [93]:
# for supplementary
for he_sch in ['H3', 'H2']:
    for backbone in ['3d_swin']:
        for subset_roi_sch_dict in tqdm(subset_roi_sch_dicts,
                                       desc=f'backbone={backbone}'):
            subset_roi_sch_text = '+'.join(subset_roi_sch_dict.keys())
            for intersection_sch in INTFN_REGISTRY.keys():

                wb_shape = None
                new_prediction = None

                prediction_list = [torch.load(os.path.join(save_dir,
                                 f'backbone={backbone},he_sch={he_sch},roi_sch={roi_sch}-prediction.pt'))
                           for roi_sch in subset_roi_sch_dict.keys()]

                if len(prediction_list) > 1:

                    for rois in list(itertools.product(*subset_roi_sch_dict.values())):
                        voxel_indices_list = [torch.load(os.path.join(cfg.DATASET.VOXEL_INDEX_DIR, f'{roi}.pt')) for roi in rois]
                        intersect = intersect1d_multi_arr(voxel_indices_list)

                        if wb_shape is None: wb_shape = prediction_list[0].shape
                        if new_prediction is None: new_prediction = torch.zeros(wb_shape)

                        intersect_prediction_list = [p[:, intersect] for p in prediction_list]

                        if len(intersect) > 0:
                            new_prediction[:, intersect] = INTFN_REGISTRY[intersection_sch](intersect_prediction_list, val_indices, val_fmris[:, intersect])

                else:

                    new_prediction = prediction_list[0]

                torch.save(new_prediction, os.path.join(save_dir,
                                             f'backbone={backbone},he_sch={he_sch},intersection_sch={intersection_sch},subset_roi_sch={subset_roi_sch_text}'
                                             f'-prediction.pt'))

backbone=3d_swin:   0%|          | 0/7 [00:00<?, ?it/s]

backbone=3d_swin:   0%|          | 0/7 [00:00<?, ?it/s]

# save WB+aROI+htROI weights

In [42]:
def grouped_ensemble_weights(prediction_list, val_indices, indiced_val_fmris, opt_verbose=False):
    predictions = torch.stack(prediction_list, -1)
    ws = optimize_val_correlation(predictions[val_indices],
                                  indiced_val_fmris,
                                  verbose=opt_verbose,
                                  device=DEVICE)
    return ws

In [27]:
prediction_list = [torch.load(os.path.join(save_dir,
                 f'backbone=1,he_sch=H3,roi_sch={roi_sch}-prediction.pt'))
           for roi_sch in ('WB', 'aROI', 'S-htROI')]

In [40]:
weights = torch.zeros(val_fmris.shape[1], 3)

In [47]:
for rois in tqdm(list(itertools.product(*subset_roi_sch_dicts[-1].values()))):
    voxel_indices_list = [torch.load(os.path.join(cfg.DATASET.VOXEL_INDEX_DIR, f'{roi}.pt')) for roi in rois]
    intersect = intersect1d_multi_arr(voxel_indices_list)
    
    intersect_prediction_list = [p[:, intersect] for p in prediction_list]
    
    if len(intersect) > 0:
        weights[intersect, :] = grouped_ensemble_weights(intersect_prediction_list, val_indices, val_fmris[:, intersect])


  0%|          | 0/80 [00:00<?, ?it/s]

In [48]:
torch.save(weights, 'tmp/notebook200/WB+aROI+htROI weights.pt')

# save to score

In [86]:
from src.utils.metrics import vectorized_correlation

In [95]:
noise_ceiling = np.load('config/noise_ceiling.npy')

In [96]:
save_dir

PosixPath('/data/huze/ray_results/algonauts2021/notebook200')

In [97]:
score_save_dir = save_dir.joinpath(Path('score'))

In [98]:
score_save_dir.mkdir(exist_ok=True)

In [46]:
for backbone in tqdm(backbones):
    for subset_roi_sch_dict in subset_roi_sch_dicts:
        subset_roi_sch_text = '+'.join(subset_roi_sch_dict.keys())
        intersection_sch = 'grouped ensemble'
        

        prediction = torch.load(os.path.join(save_dir,
                                     f'backbone={backbone},he_sch={he_sch},intersection_sch={intersection_sch},subset_roi_sch={subset_roi_sch_text}'
                                     f'-prediction.pt'))
        scores = vectorized_correlation(prediction[val_indices], val_fmris) / noise_ceiling
        
        
        torch.save(scores, os.path.join(score_save_dir,
                                     f'backbone={backbone},he_sch={he_sch},intersection_sch={intersection_sch},subset_roi_sch={subset_roi_sch_text}'
                                     f'-score.pt'))

  0%|          | 0/34 [00:00<?, ?it/s]

In [99]:
# for supplementary
for he_sch in ['H3', 'H2']:
    for backbone in ['3d_swin']:
        for subset_roi_sch_dict in tqdm(subset_roi_sch_dicts,
                                       desc=f'backbone={backbone}'):
            subset_roi_sch_text = '+'.join(subset_roi_sch_dict.keys())
            for intersection_sch in INTFN_REGISTRY.keys():

                prediction = torch.load(os.path.join(save_dir,
                                             f'backbone={backbone},he_sch={he_sch},intersection_sch={intersection_sch},subset_roi_sch={subset_roi_sch_text}'
                                             f'-prediction.pt'))
                scores = vectorized_correlation(prediction[val_indices], val_fmris) / noise_ceiling


                torch.save(scores, os.path.join(score_save_dir,
                                             f'backbone={backbone},he_sch={he_sch},intersection_sch={intersection_sch},subset_roi_sch={subset_roi_sch_text}'
                                             f'-score.pt'))

backbone=3d_swin:   0%|          | 0/7 [00:00<?, ?it/s]

backbone=3d_swin:   0%|          | 0/7 [00:00<?, ?it/s]

In [100]:
ls = []
for subset_roi_sch_dict in subset_roi_sch_dicts:
    subset_roi_sch_text = '+'.join(subset_roi_sch_dict.keys())
    ls.append(subset_roi_sch_text)
print(ls)

['WB', 'aROI', 'S-htROI', 'WB+aROI', 'WB+S-htROI', 'aROI+S-htROI', 'WB+aROI+S-htROI']


In [101]:
INTFN_REGISTRY.keys()

dict_keys(['grouped ensemble', 'grouped swap', 'voxel-wise swap'])

# match ROI ablation

In [119]:
# this dict defines how to combine rois to full brain
roi_sch_dict = {
    'i3d_rgb-htROI': ['i3d_rgb_htROI5',
 'i3d_rgb_htROI2',
 'i3d_rgb_htROI6',
 'i3d_rgb_htROI4',
 'i3d_rgb_htROI3',
 'i3d_rgb_htROI1'],
    '3d_swin-htROI': ['3d_swin_htROI6',
 '3d_swin_htROI3',
 '3d_swin_htROI2',
 '3d_swin_htROI4',
 '3d_swin_htROI1',
 '3d_swin_htROI5'],
}

In [120]:
he_sch = 'H3'

In [126]:
# X model
for b in roi_sch_dict.keys():
    backbone = b.replace('-htROI', '')
    b_run_df = run_df[run_df['MODEL.BACKBONE.NAME'] == backbone]
    
    # Disk I/O is the bottle neck

    verbose = True
    opt_verbose = False
    skip_existing = True
    
    rois = roi_sch_dict[b]
    for roi in tqdm(rois[:]):
        # skip existing files
        path = roi_save_dir.joinpath(Path(f'backbone={backbone},he_sch={he_sch},roi={roi}-prediction.pt'))
        if path.exists():
            if skip_existing:
                print('skipped existing ...', he_sch, roi)
                continue

        roi_prediction = HEFN_REGISTRY[he_sch](b_run_df, roi, verbose=verbose, opt_verbose=opt_verbose)

        torch.save(roi_prediction, path)

  0%|          | 0/6 [00:00<?, ?it/s]

skipped existing ... H3 i3d_rgb_htROI5
skipped existing ... H3 i3d_rgb_htROI2
skipped existing ... H3 i3d_rgb_htROI6
skipped existing ... H3 i3d_rgb_htROI4
skipped existing ... H3 i3d_rgb_htROI3
skipped existing ... H3 i3d_rgb_htROI1


  0%|          | 0/6 [00:00<?, ?it/s]

skipped existing ... H3 3d_swin_htROI6
skipped existing ... H3 3d_swin_htROI3
skipped existing ... H3 3d_swin_htROI2
skipped existing ... H3 3d_swin_htROI4
skipped existing ... H3 3d_swin_htROI1
skipped existing ... H3 3d_swin_htROI5


In [127]:
backbones = list(roi_sch_dict.keys())

In [128]:
backbones

['i3d_rgb-htROI', '3d_swin-htROI']

In [129]:
roi_sch_dict

{'i3d_rgb-htROI': ['i3d_rgb_htROI5',
  'i3d_rgb_htROI2',
  'i3d_rgb_htROI6',
  'i3d_rgb_htROI4',
  'i3d_rgb_htROI3',
  'i3d_rgb_htROI1'],
 '3d_swin-htROI': ['3d_swin_htROI6',
  '3d_swin_htROI3',
  '3d_swin_htROI2',
  '3d_swin_htROI4',
  '3d_swin_htROI1',
  '3d_swin_htROI5']}

In [130]:
from src.utils.submission import algonauts2021_submission_from_whole_brain_prediction

In [132]:
wb_shape = np.load(run_df.loc[run_df['DATASET.ROI'] == 'WB'].path.values[0].joinpath('prediction.npy')).shape

for b in tqdm(backbones):
    backbone = b.replace('-htROI', '')

    for roi_sch, rois in roi_sch_dict.items():

        prediction = torch.zeros(wb_shape)
        
        rois = roi_sch_dict[b]
        
        for roi in rois:
            voxel_indices = torch.load(os.path.join(cfg.DATASET.VOXEL_INDEX_DIR, f'{roi}.pt'))
            path = roi_save_dir.joinpath(Path(f'backbone={backbone},he_sch={he_sch},roi={roi}-prediction.pt'))
            prediction[..., voxel_indices] = torch.load(path)

        path = Path(os.path.join(save_dir, f'backbone={backbone},he_sch={he_sch},roi_sch={roi_sch}-prediction.pt'))
        torch.save(prediction, path)

        # algonauts2021_submission_from_whole_brain_prediction('../../src/config/dataset',
        #                                                 path.name.replace('-prediction.pt', ''),
        #                                                 prediction[1000:].numpy(),
        #                                                 output_dir='./submissions/full/',
        #                                                 mini_track=False)

  0%|          | 0/2 [00:00<?, ?it/s]

In [134]:
subset_roi_sch_dict_keys = list(all_combinations(['WB', 'aROI', 'htROI']))

In [135]:
subset_roi_sch_dict_keys

[('WB',),
 ('aROI',),
 ('htROI',),
 ('WB', 'aROI'),
 ('WB', 'htROI'),
 ('aROI', 'htROI'),
 ('WB', 'aROI', 'htROI')]

In [136]:
roi_sch_dict

{'i3d_rgb-htROI': ['i3d_rgb_htROI5',
  'i3d_rgb_htROI2',
  'i3d_rgb_htROI6',
  'i3d_rgb_htROI4',
  'i3d_rgb_htROI3',
  'i3d_rgb_htROI1'],
 '3d_swin-htROI': ['3d_swin_htROI6',
  '3d_swin_htROI3',
  '3d_swin_htROI2',
  '3d_swin_htROI4',
  '3d_swin_htROI1',
  '3d_swin_htROI5']}

In [137]:
roi_sch_dict_full = roi_sch_dict
roi_sch_dict_full.update({
    'WB': ['WB'],
    'aROI': ['V1', 'V2', 'V3', 'V4', 'EBA', 'LOC', 'PPA', 'FFA', 'STS', 'REST'],
})

In [141]:
for backbone in backbones:
    subset_roi_sch_dicts = []
    for keys in subset_roi_sch_dict_keys:
        d = {k if k != 'htROI' else backbone: roi_sch_dict_full[k if k != 'htROI' else backbone] for k in keys}
        subset_roi_sch_dicts.append(d)
    print(subset_roi_sch_dicts)

[{'WB': ['WB']}, {'aROI': ['V1', 'V2', 'V3', 'V4', 'EBA', 'LOC', 'PPA', 'FFA', 'STS', 'REST']}, {'i3d_rgb-htROI': ['i3d_rgb_htROI5', 'i3d_rgb_htROI2', 'i3d_rgb_htROI6', 'i3d_rgb_htROI4', 'i3d_rgb_htROI3', 'i3d_rgb_htROI1']}, {'WB': ['WB'], 'aROI': ['V1', 'V2', 'V3', 'V4', 'EBA', 'LOC', 'PPA', 'FFA', 'STS', 'REST']}, {'WB': ['WB'], 'i3d_rgb-htROI': ['i3d_rgb_htROI5', 'i3d_rgb_htROI2', 'i3d_rgb_htROI6', 'i3d_rgb_htROI4', 'i3d_rgb_htROI3', 'i3d_rgb_htROI1']}, {'aROI': ['V1', 'V2', 'V3', 'V4', 'EBA', 'LOC', 'PPA', 'FFA', 'STS', 'REST'], 'i3d_rgb-htROI': ['i3d_rgb_htROI5', 'i3d_rgb_htROI2', 'i3d_rgb_htROI6', 'i3d_rgb_htROI4', 'i3d_rgb_htROI3', 'i3d_rgb_htROI1']}, {'WB': ['WB'], 'aROI': ['V1', 'V2', 'V3', 'V4', 'EBA', 'LOC', 'PPA', 'FFA', 'STS', 'REST'], 'i3d_rgb-htROI': ['i3d_rgb_htROI5', 'i3d_rgb_htROI2', 'i3d_rgb_htROI6', 'i3d_rgb_htROI4', 'i3d_rgb_htROI3', 'i3d_rgb_htROI1']}]
[{'WB': ['WB']}, {'aROI': ['V1', 'V2', 'V3', 'V4', 'EBA', 'LOC', 'PPA', 'FFA', 'STS', 'REST']}, {'3d_swin-htROI':

In [142]:
subset_roi_sch_dicts

[{'WB': ['WB']},
 {'aROI': ['V1', 'V2', 'V3', 'V4', 'EBA', 'LOC', 'PPA', 'FFA', 'STS', 'REST']},
 {'3d_swin-htROI': ['3d_swin_htROI6',
   '3d_swin_htROI3',
   '3d_swin_htROI2',
   '3d_swin_htROI4',
   '3d_swin_htROI1',
   '3d_swin_htROI5']},
 {'WB': ['WB'],
  'aROI': ['V1', 'V2', 'V3', 'V4', 'EBA', 'LOC', 'PPA', 'FFA', 'STS', 'REST']},
 {'WB': ['WB'],
  '3d_swin-htROI': ['3d_swin_htROI6',
   '3d_swin_htROI3',
   '3d_swin_htROI2',
   '3d_swin_htROI4',
   '3d_swin_htROI1',
   '3d_swin_htROI5']},
 {'aROI': ['V1', 'V2', 'V3', 'V4', 'EBA', 'LOC', 'PPA', 'FFA', 'STS', 'REST'],
  '3d_swin-htROI': ['3d_swin_htROI6',
   '3d_swin_htROI3',
   '3d_swin_htROI2',
   '3d_swin_htROI4',
   '3d_swin_htROI1',
   '3d_swin_htROI5']},
 {'WB': ['WB'],
  'aROI': ['V1', 'V2', 'V3', 'V4', 'EBA', 'LOC', 'PPA', 'FFA', 'STS', 'REST'],
  '3d_swin-htROI': ['3d_swin_htROI6',
   '3d_swin_htROI3',
   '3d_swin_htROI2',
   '3d_swin_htROI4',
   '3d_swin_htROI1',
   '3d_swin_htROI5']}]

In [149]:
for backbone in backbones:
    
    backbone = backbone.replace('-htROI', '')
    
    subset_roi_sch_dicts = []
    for keys in subset_roi_sch_dict_keys:
        d = {k if k != 'htROI' else backbone+'-htROI': roi_sch_dict_full[k if k != 'htROI' else backbone+'-htROI'] for k in keys}
        subset_roi_sch_dicts.append(d)

    for subset_roi_sch_dict in tqdm(subset_roi_sch_dicts,
                                   desc=f'backbone={backbone}'):
        subset_roi_sch_text = '+'.join(subset_roi_sch_dict.keys())
        intersection_sch = 'grouped ensemble'

        wb_shape = None
        new_prediction = None
        
        prediction_list = [torch.load(os.path.join(save_dir,
                         f'backbone={backbone},he_sch={he_sch},roi_sch={roi_sch}-prediction.pt'))
                   for roi_sch in subset_roi_sch_dict.keys()]
        
        if len(prediction_list) > 1:
        
            for rois in list(itertools.product(*subset_roi_sch_dict.values())):
                voxel_indices_list = [torch.load(os.path.join(cfg.DATASET.VOXEL_INDEX_DIR, f'{roi}.pt')) for roi in rois]
                intersect = intersect1d_multi_arr(voxel_indices_list)

                if wb_shape is None: wb_shape = prediction_list[0].shape
                if new_prediction is None: new_prediction = torch.zeros(wb_shape)

                intersect_prediction_list = [p[:, intersect] for p in prediction_list]

                if len(intersect) > 0:
                    new_prediction[:, intersect] = INTFN_REGISTRY[intersection_sch](intersect_prediction_list, val_indices, val_fmris[:, intersect])
        
        else:
            
            new_prediction = prediction_list[0]
        
        path = Path(os.path.join(save_dir,
                                     f'backbone={backbone},he_sch={he_sch},intersection_sch={intersection_sch},subset_roi_sch={subset_roi_sch_text}'
                                     f'-prediction.pt'))
        
        torch.save(new_prediction, path)
        
        # algonauts2021_submission_from_whole_brain_prediction('../../src/config/dataset',
        #                                                 path.name.replace('-prediction.pt', ''),
        #                                                 new_prediction[1000:].numpy(),
        #                                                 output_dir='./submissions/ablation/',
        #                                                 mini_track=False)
        scores = vectorized_correlation(new_prediction[val_indices], val_fmris) / noise_ceiling
        score = scores.mean().item()
        print(backbone, subset_roi_sch_text, score)

backbone=i3d_rgb:   0%|          | 0/7 [00:00<?, ?it/s]

i3d_rgb WB 0.45329281106712116
i3d_rgb aROI 0.45329281106712116
i3d_rgb i3d_rgb-htROI 0.45329281106712116
i3d_rgb WB+aROI 0.45329281106712116
i3d_rgb WB+i3d_rgb-htROI 0.45329281106712116
i3d_rgb aROI+i3d_rgb-htROI 0.45329281106712116
i3d_rgb WB+aROI+i3d_rgb-htROI 0.45329281106712116


backbone=3d_swin:   0%|          | 0/7 [00:00<?, ?it/s]

3d_swin WB 0.45329281106712116
3d_swin aROI 0.45329281106712116
3d_swin 3d_swin-htROI 0.45329281106712116
3d_swin WB+aROI 0.45329281106712116
3d_swin WB+3d_swin-htROI 0.45329281106712116
3d_swin aROI+3d_swin-htROI 0.45329281106712116
3d_swin WB+aROI+3d_swin-htROI 0.45329281106712116


In [161]:
for backbone in backbones:
    
    backbone = backbone.replace('-htROI', '')
    
    subset_roi_sch_dicts = []
    for keys in subset_roi_sch_dict_keys:
        d = {k if k != 'htROI' else backbone+'-htROI': roi_sch_dict_full[k if k != 'htROI' else backbone+'-htROI'] for k in keys}
        subset_roi_sch_dicts.append(d)

    for subset_roi_sch_dict in tqdm(subset_roi_sch_dicts,
                                   desc=f'backbone={backbone}'):
        subset_roi_sch_text = '+'.join(subset_roi_sch_dict.keys())
        intersection_sch = 'grouped ensemble'
        
        path = Path(os.path.join(save_dir,
                                     f'backbone={backbone},he_sch={he_sch},intersection_sch={intersection_sch},subset_roi_sch={subset_roi_sch_text}'
                                     f'-prediction.pt'))
        
        new_prediction = torch.load(path)
        
        # algonauts2021_submission_from_whole_brain_prediction('../../src/config/dataset',
        #                                                 path.name.replace('-prediction.pt', ''),
        #                                                 new_prediction[1000:].numpy(),
        #                                                 output_dir='./submissions/ablation/',
        #                                                 mini_track=False)
        scores = vectorized_correlation(new_prediction[val_indices], val_fmris) / noise_ceiling
        sem = scores.std() / np.sqrt(len(scores))
        sem = sem.item()
        score = scores.mean().item()
        print(backbone, subset_roi_sch_text, '\t', score, '\t', sem)

backbone=i3d_rgb:   0%|          | 0/7 [00:00<?, ?it/s]

i3d_rgb WB 	 0.42381963992639043 	 0.0012802009652976766
i3d_rgb aROI 	 0.44165807653587275 	 0.001294845655740069
i3d_rgb i3d_rgb-htROI 	 0.4587858653969395 	 0.0012912800939623194
i3d_rgb WB+aROI 	 0.44480154283773893 	 0.0012980887036060397
i3d_rgb WB+i3d_rgb-htROI 	 0.4672036128227344 	 0.001295391605606082
i3d_rgb aROI+i3d_rgb-htROI 	 0.47562842784751 	 0.001301372850536858
i3d_rgb WB+aROI+i3d_rgb-htROI 	 0.4759140821714593 	 0.00130205356500629


backbone=3d_swin:   0%|          | 0/7 [00:00<?, ?it/s]

3d_swin WB 	 0.4258591710305511 	 0.001284256608850363
3d_swin aROI 	 0.43825530491334425 	 0.001299097228986688
3d_swin 3d_swin-htROI 	 0.4500680052811477 	 0.0012835251018083217
3d_swin WB+aROI 	 0.4418752167942003 	 0.0012999066120577252
3d_swin WB+3d_swin-htROI 	 0.4575968953670858 	 0.0012909982856360834
3d_swin aROI+3d_swin-htROI 	 0.4624340316515358 	 0.0012973117182602093
3d_swin WB+aROI+3d_swin-htROI 	 0.46343147739444684 	 0.0012975150455865516


In [162]:
for backbone in backbones:
    backbone = backbone.replace('-htROI', '')
    
    he_sch = 'H3'
    subset_roi_sch_text = 'WB+aROI+S-htROI'
    intersection_sch = 'grouped ensemble'

        
    path = Path(os.path.join(save_dir,
                                 f'backbone={backbone},he_sch={he_sch},intersection_sch={intersection_sch},subset_roi_sch={subset_roi_sch_text}'
                                 f'-prediction.pt'))
        
    new_prediction = torch.load(path)
        
    # algonauts2021_submission_from_whole_brain_prediction('../../src/config/dataset',
    #                                                 path.name.replace('-prediction.pt', ''),
    #                                                 new_prediction[1000:].numpy(),
    #                                                 output_dir='./submissions/ablation/',
    #                                                 mini_track=False)
    scores = vectorized_correlation(new_prediction[val_indices], val_fmris) / noise_ceiling
    sem = scores.std() / np.sqrt(len(scores))
    sem = sem.item()
    score = scores.mean().item()
    print(backbone, subset_roi_sch_text, '\t', score, '\t', sem)

i3d_rgb WB+aROI+S-htROI 	 0.4776067037549395 	 0.0013163660471815686
3d_swin WB+aROI+S-htROI 	 0.4685314330434128 	 0.0013100570368334515
