In [1]:
import torch

from brain_multimodal_vae.dataset.deeprecon import load_data, DeepReconDataset
from brain_multimodal_vae.utils import collate_batch_dict_list

# Run tensor work on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
import os

# Directory passed to load_data as the data location. The default is the
# original hard-coded location; set the DEEPRECON_DATA_DIR environment variable
# to run the notebook on another machine without editing this cell.
# Kept as a plain string (with trailing slash) because that is what the rest of
# the notebook was written against.
data_dir = os.environ.get(
    "DEEPRECON_DATA_DIR",
    "/home/acg17270jl/projects/brain-multimodal-vae/data/deeprecon/",
)

# Subject identifiers used for loading, dataset construction and the analysis.
all_subject_list = ["x01", "x02", "x03", "x04", "x05"]

In [3]:
# Only the last two return values of load_data are used in this notebook; the
# first two (presumably the training split -- TODO confirm) are discarded.
_, _, test_brain_dict, test_label_dict = load_data(
    data_dir,
    all_subject_list,
    n_train_repetitions=5,
    normalize=True,
)

# NOTE(review): the positional `False, "and", False` arguments are passed through
# unchanged; their meaning is defined by DeepReconDataset -- consider switching
# to keyword arguments once the parameter names are confirmed.
test_ds = DeepReconDataset(test_brain_dict, test_label_dict, all_subject_list, False, "and", False)

test_dl = torch.utils.data.DataLoader(
    test_ds,
    batch_size=128,
    shuffle=False,
    drop_last=False,
    # Pinned host memory only benefits host-to-GPU copies; requesting it on a
    # machine without CUDA merely triggers a DataLoader warning.
    pin_memory=torch.cuda.is_available(),
)

In [4]:
# Walk the test loader once, move every subject's tensors to `device`, and keep
# the per-batch dictionaries so they can be merged into one dictionary each.
result_brain_batch_dict_list, result_label_batch_dict_list = [], []

for brain_batch_dict, label_batch_dict, _ in test_dl:
    for subject in test_dl.dataset.subject_list:
        key = f"{subject}"
        # Brain data is cast to float32 before the transfer; labels keep their dtype.
        brain_batch = brain_batch_dict[key].to(torch.float32)
        brain_batch_dict[key] = brain_batch.to(device)
        label_batch_dict[key] = label_batch_dict[key].to(device)

    # The batch dictionaries were updated in place above, so store them as they are.
    result_brain_batch_dict_list.append(brain_batch_dict)
    result_label_batch_dict_list.append(label_batch_dict)

# Merge the per-batch dictionaries (brain first, then labels).
result_brain_dict, result_label_dict = (
    collate_batch_dict_list(batch_dict_list)
    for batch_dict_list in (result_brain_batch_dict_list, result_label_batch_dict_list)
)

In [5]:
from brain_multimodal_vae.dataset.deeprecon import get_label_set

def get_pattern_corr_nc_result(subject_list, result_brain_dict, result_label_dict):
    """Collect the per-label pattern noise ceiling of every subject as flat records.

    Args:
        subject_list: Subject identifiers. Each one, formatted as a string, is
            used as the key into both dictionaries.
        result_brain_dict: Maps subject key -> brain tensor handed to
            ``calc_pattern_corr_nc``.
        result_label_dict: Maps subject key -> label tensor handed to
            ``calc_pattern_corr_nc`` together with the brain tensor.

    Returns:
        list[dict]: One record per (subject, shared label) with the keys
        ``"subject"``, ``"noise_ceiling"`` (Python number from ``.item()``) and
        ``"label"`` (the label value the noise ceiling was computed for).
    """
    # Labels obtained with set_mode="and" (presumably those common to all
    # subjects -- verify against get_label_set).
    shared_label_set = get_label_set(result_label_dict, subject_list, set_mode="and")

    # A set has no guaranteed iteration order; sorting makes the row order
    # reproducible and lets every value be paired with its real label below.
    sorted_shared_labels = sorted(shared_label_set)
    shared_label = torch.tensor(sorted_shared_labels)

    pattern_corr_nc_result = []
    for s in subject_list:
        brain = result_brain_dict[f"{s}"]
        label = result_label_dict[f"{s}"]

        # One value per entry of `shared_label`, in the same order.
        pattern_corr_nc_list = calc_pattern_corr_nc(brain, label, shared_label)

        # Record the actual label value instead of the 1-based position; the two
        # coincide only when the labels happen to be exactly 1..N.
        for label_value, corr_nc in zip(sorted_shared_labels, pattern_corr_nc_list):
            pattern_corr_nc_result.append({
                "subject": s,
                "noise_ceiling": corr_nc.item(),
                "label": label_value,
            })

    return pattern_corr_nc_result

In [6]:
def calc_pattern_corr_nc(brain, label, shared_label):
    """Compute, per label, the mean correlation between the rows sharing that label.

    For every entry of ``shared_label`` the rows of ``brain`` whose ``label``
    equals it are selected, their row-by-row correlation matrix is computed with
    ``torch.corrcoef``, and the mean of the off-diagonal entries is taken.

    Args:
        brain: 2-D tensor; its rows are selected with the mask ``label == value``.
        label: 1-D tensor holding one label per row of ``brain``.
        shared_label: Iterable of label values to evaluate (e.g. a 1-D tensor).

    Returns:
        list[torch.Tensor]: One 0-dim tensor per entry of ``shared_label``, in
        the same order. The value is NaN when fewer than two rows carry the
        label, because a mean pairwise correlation is undefined in that case.
    """
    pattern_corr_nc_list = []
    for target_label in shared_label:
        pattern = brain[(label == target_label), :]
        n_rows = pattern.shape[0]

        if n_rows < 2:
            # With a single row torch.corrcoef returns a 0-dim tensor (so the
            # matrix arithmetic below would raise), and with no rows there is
            # nothing to correlate. Report "undefined" instead of crashing.
            pattern_corr_nc_list.append(torch.tensor(float("nan"), device=pattern.device))
            continue

        corrs = torch.corrcoef(pattern)
        # The n_rows diagonal entries are self-correlations (1 each); removing
        # them leaves the n_rows * (n_rows - 1) off-diagonal pairs to average.
        corr_nc = (corrs.sum() - n_rows) / (n_rows * (n_rows - 1))

        pattern_corr_nc_list.append(corr_nc)

    return pattern_corr_nc_list

In [7]:
# Build the flat list of per-subject, per-label noise-ceiling records.
pattern_corr_nc_result = get_pattern_corr_nc_result(
    subject_list=all_subject_list,
    result_brain_dict=result_brain_dict,
    result_label_dict=result_label_dict,
)

In [9]:
import pandas as pd
from brain_multimodal_vae.utils import save_df

import os

# Directory handed to save_df. The default is the original hard-coded location;
# set the DEEPRECON_OUTPUT_DIR environment variable to redirect the output
# without editing this cell. Kept as a plain string with its trailing slash.
output_dir = os.environ.get(
    "DEEPRECON_OUTPUT_DIR",
    "/home/acg17270jl/projects/brain-multimodal-vae/results/deeprecon/conversion/",
)

# One row per record produced by get_pattern_corr_nc_result.
save_df(pd.DataFrame(pattern_corr_nc_result), output_dir, "pattern_noise_ceiling.csv")