In [17]:
import os
import re

import torch
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from fmri_reconstruction_with_dmvae.mindeye2_nsd.datasets.align import align_subject_trials
from fmri_reconstruction_with_dmvae.mindeye2_nsd.datasets.dataset import get_dataset
from fmri_reconstruction_with_dmvae.mindeye2_nsd.datasets.load import load_all_subj_data, load_all_subj_voxels
from fmri_reconstruction_with_dmvae.mindeye2_nsd.datasets.split import split_aligned_data

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed PyTorch's global generator so that everything drawing from it
# (e.g. DataLoader shuffling) repeats from one run of the notebook to the next.
seed = 0
torch.manual_seed(seed)

# Root directory of the MindEye2 NSD data. Set the MINDEYE2_NSD_DATA_PATH
# environment variable to point somewhere else; when it is unset the path below
# is used. os.path.join(..., "") guarantees a trailing separator
# (presumably the loaders append file names to this string -- TODO confirm).
data_path = os.path.join(
    os.environ.get(
        "MINDEYE2_NSD_DATA_PATH",
        "/home/acg17270jl/projects/fmri-reconstruction-with-dmvae/data/mindeye2_nsd/",
    ),
    "",
)

# Subject ids used in this run; the trailing comment keeps the full list for reference.
subj_list = [1, 2, 5, 7] # [1, 2, 3, 4, 5, 6, 7, 8]

### Data

In [2]:
### Load
# Per-subject data restricted to data_range="shared1000".
all_subj_shared1000_data = load_all_subj_data(
    data_path,
    subj_list,
    data_range="shared1000",
)
# Alternative: load the full range instead. Note that all_subj_num_ssessions_list
# is not defined in any visible cell and must be provided before enabling this.
# all_subj_all_data = load_all_subj_data(data_path, subj_list, data_range="all", subj_num_ssessions_list=all_subj_num_ssessions_list)

# Voxel arrays keyed per subject plus the per-subject voxel counts
# (the counts are later passed to the DMVAE constructor).
all_subj_voxels, all_subj_num_voxels = load_all_subj_voxels(
    data_path,
    subj_list,
)

Loaded all subj data

Loaded all subj voxels



In [3]:
### Align
# Align the subjects' trials against subj01 as the anchor subject
# (the alignment rule itself is implemented in align_subject_trials).
aligned_all_subj_shared1000_data = align_subject_trials(
    all_subj_shared1000_data,
    anchor_subject="subj01",
)
# Alternative for the full data range (requires all_subj_all_data from the load cell):
# aligned_all_subj_data = align_subject_trials(all_subj_all_data, anchor_subject="subj01")

In [4]:
### Split
# Split the aligned shared1000 data into train and test parts.
# train_occurrence_max=2 is forwarded as-is; presumably it caps how many
# repetitions of a stimulus go to the training side -- TODO confirm in split_aligned_data.
train_data, test_data = split_aligned_data(
    aligned_all_subj_shared1000_data,
    subj_list,
    train_occurrence_max=2,
)
# Alternative for the full data range (requires aligned_all_subj_data from the align cell):
# train_data, test_data = split_aligned_data(aligned_all_subj_data, subj_list, train_occurrence_max=2)

In [5]:
### Dataset
# Wrap both splits into dataset objects that the DataLoaders below can consume.
train_dataset, test_dataset = get_dataset(
    train_data,
    test_data,
)

In [6]:
### DataLoader
batch_size = 128

# Training batches are reshuffled every epoch; incomplete final batches are dropped.
train_dl = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
)
# NOTE(review): drop_last=True also applies here, so test samples that do not
# fill a whole batch are never evaluated. update_metrics averages per-batch
# means, which is only exact for equal batch sizes -- confirm this trade-off is intended.
test_dl = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True,
    pin_memory=True,
)

### Model

In [23]:
from fmri_reconstruction_with_dmvae.models.dmvae import DMVAE

In [24]:
# Sizes handed to the DMVAE constructor. Presumably zp is the per-subject
# (private) latent and zs the shared latent -- TODO confirm against DMVAE.
zp_dim, zs_dim, hidden_dim = 256, 768, 1024

# Positional arguments follow the order used by the DMVAE constructor;
# the optimizer class and learning rate are passed through by keyword.
model = DMVAE(
    subj_list,
    all_subj_num_voxels,
    zp_dim,
    zs_dim,
    hidden_dim,
    optimizer=optim.Adam,
    lr=1e-3,
    device=device,
)

### Train

In [25]:
def get_recon_dict_batch(input_dict, model):
    """Reconstruct every subject's input from several latent combinations.

    For each subject in the notebook-level ``subj_list`` the function samples
    the ``q_zp{NN}__x{NN}`` and ``q_zs__x{NN}`` distributions from
    ``model.dist_dict``, plus the joint ``q_zs__x`` distribution once. Each
    target subject's decoder ``p_x{NN}__zp{NN}_zs`` is then evaluated with
    ``sample_mean`` on that subject's own ``zp`` sample merged with:

    * the joint ``zs`` sample          -> ``joint_recon_x{NN}``
    * its own ``zs`` sample            -> ``self_recon_x{NN}``
    * another subject's ``zs`` sample  -> ``cross_recon_x{NN}__x{MM}``

    Args:
        input_dict: Mapping passed unchanged to every ``sample`` call.
        model: Object exposing ``dist_dict`` with the distributions named above.

    Returns:
        dict mapping the reconstruction names above to the decoder outputs
        moved to the CPU. Everything runs under ``torch.no_grad()``.
    """
    subj_tags = [f"{int(subj):02d}" for subj in subj_list]
    latents = {}
    recons = {}

    with torch.no_grad():
        # Encoder side: one zp sample and one zs sample per subject ...
        for tag in subj_tags:
            latents[f"zp{tag}"] = model.dist_dict[f"q_zp{tag}__x{tag}"].sample(input_dict, return_all=False)
            latents[f"zs__x{tag}"] = model.dist_dict[f"q_zs__x{tag}"].sample(input_dict, return_all=False)

        # ... and a single zs sample from the joint distribution over all inputs.
        latents["zs__x"] = model.dist_dict["q_zs__x"].sample(input_dict, return_all=False)

        # Decoder side: the sampled mappings are merged with `|` before decoding.
        for target in subj_tags:
            decoder_name = f"p_x{target}__zp{target}_zs"
            own_zp = latents[f"zp{target}"]

            recons[f"joint_recon_x{target}"] = model.dist_dict[decoder_name].sample_mean(own_zp | latents["zs__x"]).cpu()

            for source in subj_tags:
                # Same computation either way; only the result name differs.
                if source == target:
                    recon_name = f"self_recon_x{target}"
                else:
                    recon_name = f"cross_recon_x{target}__x{source}"
                recons[recon_name] = model.dist_dict[decoder_name].sample_mean(own_zp | latents[f"zs__x{source}"]).cpu()

    return recons

def calc_cosine_dict_batch(input_dict, recon_dict_batch):
    """Cosine similarity between each reconstruction and its target input.

    The target subject is read from the two digits following ``recon_x`` in
    each key of ``recon_dict_batch``; the matching input is
    ``input_dict["x{NN}"]``, moved to the CPU before comparison.

    Args:
        input_dict: Mapping with ``"x{NN}"`` tensors.
        recon_dict_batch: Mapping of reconstruction name -> tensor
            (assumed to live on the CPU already, as get_recon_dict_batch returns).

    Returns:
        dict with the same keys as ``recon_dict_batch``; each value is the
        similarity along dim 1 (the ``F.cosine_similarity`` default).
    """
    def _target_for(recon_name):
        # Keys look like "..._recon_x01..." -> "01" selects input_dict["x01"].
        subj_tag = re.search(r'recon_x(\d{2})', recon_name).group(1)
        return input_dict[f"x{subj_tag}"].cpu()

    return {
        recon_name: F.cosine_similarity(_target_for(recon_name), recon)
        for recon_name, recon in recon_dict_batch.items()
    }

def calc_pearson_dict_batch(input_dict, recon_dict_batch):
    """Pearson correlation between each reconstruction and its target input.

    Both tensors are mean-centred along dim 1 and compared with cosine
    similarity along dim 1, which equals the Pearson correlation of the rows.
    The target subject is read from the two digits following ``recon_x`` in
    each key of ``recon_dict_batch``.

    Args:
        input_dict: Mapping with ``"x{NN}"`` tensors.
        recon_dict_batch: Mapping of reconstruction name -> tensor
            (assumed to live on the CPU already, as get_recon_dict_batch returns).

    Returns:
        dict with the same keys as ``recon_dict_batch`` holding one
        correlation value per row.
    """
    def _demean(rows):
        return rows - rows.mean(dim=1, keepdim=True)

    scores = {}
    for recon_name, recon in recon_dict_batch.items():
        subj_tag = re.search(r'recon_x(\d{2})', recon_name).group(1)
        # The input is centred on its own device first and only then moved to the CPU.
        target_centered = _demean(input_dict[f"x{subj_tag}"]).cpu()
        scores[recon_name] = F.cosine_similarity(target_centered, _demean(recon), dim=1)

    return scores

def update_metrics(metrics_dict, input_dict, recon_dict_batch, n_batches):
    """Fold one batch's similarity scores into running per-epoch means.

    For every reconstruction in ``recon_dict_batch`` the batch mean of the
    cosine and Pearson scores is divided by ``n_batches`` and added to the
    running totals, so after ``n_batches`` calls each total is the mean of
    the per-batch means.

    Args:
        metrics_dict: Accumulator, updated in place; missing entries start
            at ``{"cosine_mean": 0.0, "pearson_mean": 0.0}``.
        input_dict: Inputs the reconstructions are compared against.
        recon_dict_batch: Reconstructions for the current batch.
        n_batches: Number of batches the totals are averaged over.

    Returns:
        The same ``metrics_dict`` object that was passed in.
    """
    batch_scores = {
        "cosine_mean": calc_cosine_dict_batch(input_dict, recon_dict_batch),
        "pearson_mean": calc_pearson_dict_batch(input_dict, recon_dict_batch),
    }

    for recon_name in recon_dict_batch:
        if recon_name not in metrics_dict:
            metrics_dict[recon_name] = {"cosine_mean": 0.0, "pearson_mean": 0.0}
        running = metrics_dict[recon_name]

        for field, scores in batch_scores.items():
            running[field] += scores[recon_name].mean() / n_batches

    return metrics_dict

In [26]:
def train(epoch):
    """Run one training epoch over ``train_dl`` and report the mean loss.

    Reads the notebook-level names ``train_dl``, ``subj_list``,
    ``all_subj_voxels``, ``all_subj_lambdas``, ``model``, ``device`` and
    ``epochs``; all of them must exist before the first call.

    Each batch's element ``data[1]`` is indexed by ``"subj{NN}"`` and the
    result selects rows of ``all_subj_voxels["subj{NN}"]`` (presumably
    per-subject trial indices -- TODO confirm against the dataset).

    Args:
        epoch: 1-based epoch number. Reconstruction metrics are collected
            only when it equals the global ``epochs`` (the last epoch).

    Returns:
        ``(train_loss, train_metrics_dict)``: the loss averaged over the
        batches, and the metrics accumulator (empty except on the last epoch).
    """
    train_loss = 0
    n_batches = len(train_dl)
    train_metrics_dict = {}

    subj_tags = [f"{int(subj):02d}" for subj in subj_list]
    # The per-subject loss weights do not depend on the batch: build them once.
    lambda_dict = {f"lambda_{s}": all_subj_lambdas[int(s) - 1] for s in subj_tags}

    for data in tqdm(train_dl):
        input_dict = {
            f"x{s}": all_subj_voxels[f"subj{s}"][data[1][f"subj{s}"]].to(device)
            for s in subj_tags
        }

        loss = model.train(input_dict | lambda_dict)
        # If the model hands back a tensor, drop its autograd history before
        # accumulating; otherwise the running sum chains the history of every
        # batch for the whole epoch.
        train_loss += loss.detach() if torch.is_tensor(loss) else loss

        if epoch == epochs:
            # NOTE(review): these reconstructions are taken right after
            # model.train() on the same batch; whether the model is still in
            # training mode here depends on DMVAE -- confirm this is intended.
            recon_dict_batch = get_recon_dict_batch(input_dict, model)
            train_metrics_dict = update_metrics(train_metrics_dict, input_dict, recon_dict_batch, n_batches)

    train_loss = train_loss / n_batches
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss, train_metrics_dict

def test(epoch):
    """Evaluate the model on ``test_dl`` for one epoch and report the mean loss.

    Reads the notebook-level names ``test_dl``, ``subj_list``,
    ``all_subj_voxels``, ``all_subj_lambdas``, ``model``, ``device`` and
    ``epochs``. The whole pass runs under ``torch.no_grad()``.

    Args:
        epoch: 1-based epoch number. Reconstruction metrics are collected
            only when it equals the global ``epochs`` (the last epoch).

    Returns:
        ``(test_loss, test_metrics_dict)``: the loss averaged over the
        batches, and the metrics accumulator (empty except on the last epoch).
    """
    running_loss = 0
    n_batches = len(test_dl)
    test_metrics_dict = {}
    subj_tags = [f"{int(subj):02d}" for subj in subj_list]

    with torch.no_grad():
        for batch in tqdm(test_dl):
            # batch[1] is indexed by "subj{NN}"; the result selects rows of that
            # subject's voxel array (presumably trial indices -- TODO confirm).
            row_index = batch[1]
            input_dict = {
                f"x{tag}": all_subj_voxels[f"subj{tag}"][row_index[f"subj{tag}"]].to(device)
                for tag in subj_tags
            }
            # Loss weight of subject NN sits at position NN - 1.
            lambda_dict = {
                f"lambda_{tag}": all_subj_lambdas[int(tag) - 1]
                for tag in subj_tags
            }

            running_loss += model.test(input_dict | lambda_dict)

            if epoch == epochs:
                recon_dict_batch = get_recon_dict_batch(input_dict, model)
                test_metrics_dict = update_metrics(test_metrics_dict, input_dict, recon_dict_batch, n_batches)

    test_loss = running_loss / n_batches
    print('Epoch: {} Test loss: {:.4f}'.format(epoch, test_loss))
    return test_loss, test_metrics_dict

In [27]:
# Number of passes over the data. train() and test() also read this global to
# decide when to collect reconstruction metrics (only when epoch == epochs).
epochs = 20
# One loss weight per subject id 1..8; train() and test() look a subject's
# weight up at position (subject id - 1).
all_subj_lambdas = [1] * 8

for epoch in range(1, epochs + 1):
    # Each iteration overwrites the previous results, so after the loop these
    # names hold the values of the final epoch only.
    train_loss, train_metrics_dict = train(epoch)
    test_loss, test_metrics_dict = test(epoch)

100%|██████████| 15/15 [00:03<00:00,  4.48it/s]


Epoch: 1 Train loss: 418959.4062


100%|██████████| 7/7 [00:01<00:00,  6.50it/s]


Epoch: 1 Test loss: 420662.0625


100%|██████████| 15/15 [00:03<00:00,  4.89it/s]


Epoch: 2 Train loss: 410366.1562


100%|██████████| 7/7 [00:00<00:00,  8.03it/s]


Epoch: 2 Test loss: 407685.4375


100%|██████████| 15/15 [00:03<00:00,  4.45it/s]


Epoch: 3 Train loss: 401057.0938


100%|██████████| 7/7 [00:00<00:00, 10.64it/s]


Epoch: 3 Test loss: 400462.4688


100%|██████████| 15/15 [00:03<00:00,  4.26it/s]


Epoch: 4 Train loss: 394952.9688


100%|██████████| 7/7 [00:00<00:00, 10.51it/s]


Epoch: 4 Test loss: 396740.4062


100%|██████████| 15/15 [00:03<00:00,  3.98it/s]


Epoch: 5 Train loss: 391423.0312


100%|██████████| 7/7 [00:00<00:00,  9.81it/s]


Epoch: 5 Test loss: 393323.7812


100%|██████████| 15/15 [00:03<00:00,  4.24it/s]


Epoch: 6 Train loss: 387764.3438


100%|██████████| 7/7 [00:00<00:00,  7.16it/s]


Epoch: 6 Test loss: 390680.5938


100%|██████████| 15/15 [00:03<00:00,  4.44it/s]


Epoch: 7 Train loss: 384258.0312


100%|██████████| 7/7 [00:01<00:00,  6.99it/s]


Epoch: 7 Test loss: 388546.1250


100%|██████████| 15/15 [00:03<00:00,  4.33it/s]


Epoch: 8 Train loss: 381155.6875


100%|██████████| 7/7 [00:00<00:00,  9.02it/s]


Epoch: 8 Test loss: 386986.2188


100%|██████████| 15/15 [00:03<00:00,  4.67it/s]


Epoch: 9 Train loss: 378390.8125


100%|██████████| 7/7 [00:00<00:00,  7.27it/s]


Epoch: 9 Test loss: 385660.5938


100%|██████████| 15/15 [00:03<00:00,  4.57it/s]


Epoch: 10 Train loss: 375512.2500


100%|██████████| 7/7 [00:00<00:00,  8.17it/s]


Epoch: 10 Test loss: 384165.8438


100%|██████████| 15/15 [00:03<00:00,  4.76it/s]


Epoch: 11 Train loss: 372899.8125


100%|██████████| 7/7 [00:00<00:00,  8.29it/s]


Epoch: 11 Test loss: 383209.2812


100%|██████████| 15/15 [00:03<00:00,  4.56it/s]


Epoch: 12 Train loss: 370380.8750


100%|██████████| 7/7 [00:00<00:00,  7.18it/s]


Epoch: 12 Test loss: 382151.5625


100%|██████████| 15/15 [00:03<00:00,  4.36it/s]


Epoch: 13 Train loss: 368124.2500


100%|██████████| 7/7 [00:00<00:00,  7.50it/s]


Epoch: 13 Test loss: 381471.0312


100%|██████████| 15/15 [00:03<00:00,  4.43it/s]


Epoch: 14 Train loss: 365762.1562


100%|██████████| 7/7 [00:00<00:00,  8.15it/s]


Epoch: 14 Test loss: 380803.9375


100%|██████████| 15/15 [00:03<00:00,  4.78it/s]


Epoch: 15 Train loss: 363588.9375


100%|██████████| 7/7 [00:00<00:00,  7.68it/s]


Epoch: 15 Test loss: 380278.1562


100%|██████████| 15/15 [00:03<00:00,  4.39it/s]


Epoch: 16 Train loss: 361628.6562


100%|██████████| 7/7 [00:00<00:00,  9.37it/s]


Epoch: 16 Test loss: 379995.5625


100%|██████████| 15/15 [00:03<00:00,  4.69it/s]


Epoch: 17 Train loss: 359618.0312


100%|██████████| 7/7 [00:00<00:00,  7.97it/s]


Epoch: 17 Test loss: 380415.4062


100%|██████████| 15/15 [00:03<00:00,  4.38it/s]


Epoch: 18 Train loss: 357999.6250


100%|██████████| 7/7 [00:00<00:00,  7.67it/s]


Epoch: 18 Test loss: 380400.1562


100%|██████████| 15/15 [00:03<00:00,  4.25it/s]


Epoch: 19 Train loss: 356734.5625


100%|██████████| 7/7 [00:00<00:00,  7.21it/s]


Epoch: 19 Test loss: 381882.0938


100%|██████████| 15/15 [00:04<00:00,  3.18it/s]


Epoch: 20 Train loss: 355454.5625


100%|██████████| 7/7 [00:01<00:00,  4.62it/s]

Epoch: 20 Test loss: 382027.5938





In [28]:
# Show the reconstruction metrics that test() accumulated during the final epoch
# (per reconstruction name: mean cosine similarity and mean Pearson correlation).
test_metrics_dict

{'joint_recon_x01': {'cosine_mean': tensor(0.5460),
  'pearson_mean': tensor(0.5091)},
 'self_recon_x01': {'cosine_mean': tensor(0.5381),
  'pearson_mean': tensor(0.5010)},
 'cross_recon_x01__x02': {'cosine_mean': tensor(0.5396),
  'pearson_mean': tensor(0.5028)},
 'cross_recon_x01__x05': {'cosine_mean': tensor(0.5426),
  'pearson_mean': tensor(0.5057)},
 'cross_recon_x01__x07': {'cosine_mean': tensor(0.5422),
  'pearson_mean': tensor(0.5051)},
 'joint_recon_x02': {'cosine_mean': tensor(0.5621),
  'pearson_mean': tensor(0.5295)},
 'cross_recon_x02__x01': {'cosine_mean': tensor(0.5547),
  'pearson_mean': tensor(0.5216)},
 'self_recon_x02': {'cosine_mean': tensor(0.5571),
  'pearson_mean': tensor(0.5242)},
 'cross_recon_x02__x05': {'cosine_mean': tensor(0.5576),
  'pearson_mean': tensor(0.5249)},
 'cross_recon_x02__x07': {'cosine_mean': tensor(0.5567),
  'pearson_mean': tensor(0.5238)},
 'joint_recon_x05': {'cosine_mean': tensor(0.6178),
  'pearson_mean': tensor(0.5807)},
 'cross_recon_x