In [15]:
from pathlib import Path

import numpy as np
import torch

from tqdm.auto import tqdm
from tqdm.notebook import tqdm_notebook, trange

import itertools

import os
from dotenv import find_dotenv, load_dotenv

# Locate the dotenv file first, then load it; the call stays the last
# expression of the cell so its return value is still displayed.
dotenv_file = find_dotenv()
load_dotenv(dotenv_file, verbose=True)

True

# prepare data

In [16]:
from src.config.config import combine_cfgs, get_cfg_defaults
from src.data.datamodule import MyDataModule

In [17]:
# Prepare the validation data: build the data module from the i3d-flow
# experiment config, then collect element [1] of every validation item
# (presumably the fMRI target -- confirm against the dataset class).
cfg = combine_cfgs('../src/config/experiments/algonauts2021_i3d_flow.yml')
dm = MyDataModule(cfg)
dm.prepare_data()
dm.setup()

val_indices = dm.val_dataset.indices

cache_path = Path('/data_smr/huze/.cache/val_fmris.pt')

if cache_path.exists():
    # Reuse the stacked tensor written by an earlier run.
    val_fmris = torch.load(cache_path)
else:
    val_fmris = [dm.dataset_train_val[i][1] for i in tqdm(val_indices)]
    val_fmris = np.stack(val_fmris, 0)
    val_fmris = torch.tensor(val_fmris)
    # Make sure the cache directory exists before saving; otherwise the
    # full pass over the dataset above is lost when torch.save fails.
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(val_fmris, cache_path)

In [18]:
# Load the voxel masks saved under the dataset root's "processed" folder.
voxel_masks_file = dm.dataset_train_val.root_dir.joinpath('processed', 'voxel_masks.npy')
voxel_masks = np.load(voxel_masks_file)

In [19]:
from src.utils.visulization import save_as_nii

# Collapse the masks along axis 0 and clamp every positive count to 1, so a
# voxel is kept when at least one mask along that axis selects it.
voxel_hits = voxel_masks.sum(0)
reliable_voxel_mask_mean_sub = np.where(voxel_hits > 0, 1, voxel_hits).astype(np.int8)

# Write the combined mask to a NIfTI file; './tmp/example.nii' is passed as
# the first argument (presumably the reference image -- confirm in save_as_nii).
reliable_voxel_mask_path = Path('./tmp/reliable_voxel_mask.nii')
save_as_nii(
    './tmp/example.nii',
    reliable_voxel_mask_path,
    reliable_voxel_mask_mean_sub,
)

In [20]:
# Prediction files to compare, keyed by a descriptive label. The label is
# reused verbatim in the name of the figure saved for that prediction.
predictions_dict = {
    label: Path(location)
    for label, location in (
        # from previous project, notebook 'plot mc fi score, maybe final'
        (
            'Baseline,VGG16 2D + LSTM,layer=(x1,x2,x3,x4,x5),pathway=cascade,roi=WB',
            '/data_smr/huze/projects/my_algonauts/predictions/bc4bd42bc70d409daebe27f2c0fae255/WB.pt',
        ),
        # from notebook 1.
        (
            'Baseline,Inflated 3D Resnet,RGB stream,layer=(x1,x2,x3,x4),pathway=topdown,roi=WB',
            "/data/huze/ray_results/algonauts2021/algonauts2021_i3d_rgb-multi_layer/run_single_tune_config_575a8_00066_66_DATASET.ROI=WB,MODEL.BACKBONE.LAYERS=_'x1', 'x2', 'x3', 'x4'_,MODEL.BACKBONE.LAYER_PATHWAYS=_2022-02-26_14-18-59/prediction.npy",
        ),
        # from notebook 999.
        (
            'Best,Inflated3D,RGB steram + FLOW stream,HierarchicalEnsemble,ROIxkROI',
            '/data/huze/ray_results/algonauts2021/ensemble_outputs/he_sch=H3,model_sch=single_layer&i3d_rgb+i3d_flow,roi_sch=ROIxSMC,cross_roi_sch=croi_ensemble-prediction.pt',
        ),
    )
}

In [21]:
# Drop every entry whose prediction file does not exist, then require that
# more than one prediction is left.
predictions_dict = dict(
    (label, location)
    for label, location in predictions_dict.items()
    if location.exists()
)
assert len(predictions_dict) > 1

In [23]:
from matplotlib import pyplot as plt
from nilearn import plotting
from src.utils.visulization import get_nii, nice_plot
from src.utils.metrics import vectorized_correlation

# For every prediction file: score each voxel with vectorized_correlation
# against val_fmris, write the scores to a NIfTI file, render them on the
# cortical surface, and save the composed image under ./figures.
surf_mesh = 'fsaverage'  # more vertices but 20x slower
# surf_mesh = 'fsaverage5'

# Loop-invariant locations, built once instead of on every iteration.
tmp_nii_path = Path('tmp/score.nii')
example_nii_path = Path('tmp/example.nii')

# Create the output directory up front so a missing folder does not make
# img.save fail after the slow surface rendering has already run.
figure_dir = 'figures'
os.makedirs(figure_dir, exist_ok=True)

for name, path in predictions_dict.items():
    if path.name.endswith('.pt'):
        # map_location keeps the tensor on the CPU even if it was saved from
        # a GPU; the .numpy() call below requires a CPU tensor.
        prediction = torch.load(path, map_location='cpu').float()
    else:
        prediction = torch.tensor(np.load(path)).float()

    # Only the validation rows of the prediction are compared.
    voxel_score = vectorized_correlation(prediction[val_indices], val_fmris)
    voxel_score = voxel_score.numpy()
    mean_score = voxel_score.mean()

    get_nii(voxel_masks, voxel_score, example_nii_path, tmp_nii_path)

    fig, axes = plotting.plot_img_on_surf(str(tmp_nii_path),
                                          # threshold=0.025,
                                          bg_on_data=True,
                                          surf_mesh=surf_mesh,
                                          mask_img=str(reliable_voxel_mask_path),
                                          colorbar=True,
                                          alpha=1.,
                                          darkness=.5,
                                          vmax=0.5,
                                          views=['lateral', 'medial', 'dorsal', 'ventral', 'anterior', 'posterior'],
                                          hemispheres=['left', 'right'],
                                          inflate=True,
                                          cmap='PuOr',
                                          )
    # plt.show()
    img = nice_plot(fig, axes, dpi=360, tmp_dir='tmp/')
    # Close this iteration's figure explicitly rather than whichever figure
    # happens to be current, so figures do not pile up across iterations.
    plt.close(fig)

    # The mean score and the prediction label form the file name.
    save_path = os.path.join(figure_dir, f'score={mean_score:.3f},{name}.jpg')
    img.save(save_path, quality=95)

  texture = np.nanmean(all_samples, axis=2)
  texture = np.nanmean(all_samples, axis=2)


axes:   0%|          | 0/12 [00:00<?, ?it/s]

  texture = np.nanmean(all_samples, axis=2)
  texture = np.nanmean(all_samples, axis=2)


axes:   0%|          | 0/12 [00:00<?, ?it/s]