In [78]:
import os
import torch
import argparse

from brain_multimodal_vae.dataset.deeprecon import load_data, DeepReconDataset
from brain_multimodal_vae.models import DMVAE, MMVAE, MVAE
from brain_multimodal_vae.evaluation import get_prediction_dict, get_pattern_corr_result
import brain_multimodal_vae.utils as utils

# Pick the compute device: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Stand-alone random generator; it is created unseeded here.
g = torch.Generator()

In [79]:
# Subjects known to this notebook and the per-subject voxel counts that are
# handed to the model constructor.
all_subject_list = ["x01", "x02", "x03", "x04", "x05"]
n_voxels_dict = {'x01': 15316, 'x02': 14597, 'x03': 13135, 'x04': 13596, 'x05': 13149}
# Model name (as accepted by --model_name) -> model class.
models = {"MVAE": MVAE, "MMVAE": MMVAE, "DMVAE": DMVAE}

parser = argparse.ArgumentParser()

# Path setting
parser.add_argument("--data_dir", type=str, default="/home/acg17270jl/projects/brain-multimodal-vae/data/deeprecon/")
parser.add_argument("--ckpt_dir", type=str, default="/home/acg17270jl/projects/brain-multimodal-vae/checkpoints/deeprecon/")
# Data setting
# Only subjects listed in `all_subject_list` have data and a voxel count, so
# reject anything else at parse time instead of failing later on a lookup.
# The default is a copy so that `params["subject_list"]` never aliases the
# module-level list.
parser.add_argument("--subject_list", nargs="+", choices=all_subject_list, default=list(all_subject_list))
parser.add_argument("--n_train_repetitions", type=int, default=5)
parser.add_argument("--normalize", action=argparse.BooleanOptionalAction, default=True)
parser.add_argument("--n_shared_labels", type=int, default=600)
parser.add_argument("--n_unique_labels", type=int, default=120)
parser.add_argument("--select_seed", type=int, default=42)
# Dataset setting
parser.add_argument("--train_group", action=argparse.BooleanOptionalAction, default=False)
parser.add_argument("--test_group", action=argparse.BooleanOptionalAction, default=False)
parser.add_argument("--set_mode", choices=["and", "or"], default="or")
parser.add_argument("--include_missing", action=argparse.BooleanOptionalAction, default=True)
# DataLoader setting
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--batch_size", type=int, default=128)
# Model setting
parser.add_argument('--model_name', choices=list(models), default="MVAE")
parser.add_argument("--z_dim", type=int, default=128)
parser.add_argument("--zp_dim", type=int, default=64)
parser.add_argument("--zs_dim", type=int, default=128)
parser.add_argument("--hidden_dim", type=int, default=4096)
# Training setting
parser.add_argument("--optimizer_name", choices=["adam"], default="adam")
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--weight_decay", type=float, default=1e-1)
parser.add_argument("--n_epochs", type=int, default=15)
parser.add_argument("--eval", action=argparse.BooleanOptionalAction, default=False)

# Command-line tokens used when the code runs interactively (no real argv).
# Options not listed here keep the parser defaults declared above.
jupyter_args = [
    "--subject_list", "x01", "x02",
    "--n_train_repetitions", "5",
    "--n_shared_labels", "0",
    "--n_unique_labels", "600",
    "--select_seed", "42",
    "--no-train_group",
    "--include_missing",
    "--model_name", "MVAE",
    "--z_dim", "128",
    "--eval",
]

# In an interactive session parse the tokens above; otherwise passing None
# makes argparse read the real command line (sys.argv[1:]).
args = parser.parse_args(jupyter_args if utils.is_interactive() else None)

# Plain-dict view of the parsed options, used by the rest of the notebook.
params = vars(args)

In [80]:
# Load the data for every known subject; only the last two returned values
# (the test brain data and test labels) are kept.
_, _, test_brain_dict, test_label_dict = load_data(
    params["data_dir"],
    all_subject_list,
    params["n_train_repetitions"],
    params["normalize"],
)

# Narrow both dictionaries down to the subjects selected for this run.
test_brain_dict, test_label_dict = (
    utils.get_sub_dict(full_dict, params["subject_list"])
    for full_dict in (test_brain_dict, test_label_dict)
)

In [81]:
# Wrap the selected test data in a dataset. The last three positional
# arguments are fixed for evaluation rather than read from `params`.
# NOTE(review): presumably they map to the group / set-mode / include-missing
# options — confirm against DeepReconDataset's signature.
test_ds = DeepReconDataset(
    test_brain_dict,
    test_label_dict,
    params["subject_list"],
    False,
    "and",
    False,
)

# Deterministic, complete pass over the test set: no shuffling and the final
# partial batch is kept.
test_dl = torch.utils.data.DataLoader(
    test_ds,
    batch_size=params["batch_size"],
    shuffle=False,
    drop_last=False,
    pin_memory=True,
)

In [82]:
# Instantiate the selected model class. Every parsed option is forwarded as a
# keyword argument, together with the voxel counts and the compute device.
model_cls = models[params["model_name"]]
model = model_cls(**params, n_voxels_dict=n_voxels_dict, device=device)

# Checkpoints are stored under <ckpt_dir>/<model>/include_missing_<bool>/train_group_<bool>/.
saved_dir = os.path.join(
    params["ckpt_dir"],
    f"{params['model_name']}",
    f"include_missing_{params['include_missing']}",
    f"train_group_{params['train_group']}",
)

# File name: concatenated subject ids followed by the data-selection settings.
ckpt_file_name = f"{''.join(params['subject_list'])}_{params['n_train_repetitions']}_{params['n_shared_labels']}_{params['n_unique_labels']}_{params['select_seed']}.pt"
model.load(os.path.join(saved_dir, ckpt_file_name))

In [83]:
# Run the model over the test loader; the fourth returned value is not used.
(
    result_brain_dict,
    result_recon_dict,
    result_label_dict,
    _,
) = get_prediction_dict(model, test_dl)

# Compute the pattern-correlation results for the selected subjects from the
# brain data, the reconstructions and the labels collected above.
pattern_corr_result = get_pattern_corr_result(
    params["subject_list"],
    result_brain_dict,
    result_recon_dict,
    result_label_dict,
)

In [84]:
import pandas as pd
import numpy as np

# Tabular form of the pattern-correlation results computed for this run.
pattern_corr_df = pd.DataFrame(pattern_corr_result)

# Location of the noise-ceiling table that the correlations are compared against.
result_path = "/home/acg17270jl/projects/brain-multimodal-vae/results/deeprecon/conversion/"
pattern_nc_file_name = "pattern_noise_ceiling.csv"
pattern_nc_path = os.path.join(result_path, pattern_nc_file_name)
pattern_nc_df = pd.read_csv(pattern_nc_path)

In [85]:
# Build the join key "<subject><label>" on both tables. The concatenation is
# done on pandas Series rather than on NumPy unicode arrays: `+` between
# fixed-width string arrays is only supported from NumPy 2.0 on and raises on
# older versions, whereas Series concatenation yields the same keys everywhere.
pattern_nc_df['identifier'] = pattern_nc_df['subject'].astype(str) + pattern_nc_df['label'].astype(str)
pattern_corr_df['identifier'] = pattern_corr_df['subject_target'].astype(str) + pattern_corr_df['label'].astype(str)

# Attach the noise-ceiling columns to every correlation row (inner join on the
# key); columns that exist in both tables get the '_nc' suffix on the
# noise-ceiling side.
pattern_corr_df = pattern_corr_df.merge(pattern_nc_df, on='identifier', suffixes=('', '_nc'))

In [86]:
# Ratio of each correlation to its noise ceiling, computed on the raw arrays.
raw_ratio = pattern_corr_df['correlation'].values / pattern_corr_df['noise_ceiling'].values

# NaN ratios become 0 and infinite ones become large finite numbers, after
# which everything is limited to the interval [-1, 1].
pattern_corr_df['normalized correlation'] = np.clip(np.nan_to_num(raw_ratio), a_min=-1, a_max=1)

# Displayed cell result: mean normalized correlation over all rows.
pattern_corr_df['normalized correlation'].mean()

np.float64(0.14174371790143073)