In [1]:
import numpy as np
import torch
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
import re
import os
from itertools import permutations
import pandas as pd
import argparse

from fmri_reconstruction_with_dmvae.deeprecon.datasets.load import load_data
from fmri_reconstruction_with_dmvae.deeprecon.datasets.align import get_train_data, get_test_data
from fmri_reconstruction_with_dmvae.deeprecon.datasets.combine import conbine_cross_occurrences
from fmri_reconstruction_with_dmvae.deeprecon.datasets.dataset import get_dataset


from fmri_reconstruction_with_dmvae.models.dmvae import DMVAE

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

### Configuration

In [None]:
# Command-line configuration for the experiment.
parser = argparse.ArgumentParser()

# Input data root and the directory where result CSVs are written.
parser.add_argument("--data_path", type=str, default="/home/acg17270jl/projects/fmri-reconstruction-with-dmvae/data/deeprecon/")
parser.add_argument("--output_path", type=str, default="/home/acg17270jl/projects/fmri-reconstruction-with-dmvae/results/deeprecon/conversion/")
# parser.add_argument("--ckpt_path", type=str, default="")

# Data selection.
parser.add_argument("--subj_list", nargs="+", type=int, default=[1, 2])
parser.add_argument("--n_samples", type=int, default=2400)
# BooleanOptionalAction must be given as `action=` (not `type=`) so that both
# --is_normalized and --no-is_normalized are generated and parsed as booleans.
parser.add_argument("--is_normalized", action=argparse.BooleanOptionalAction, default=True)

# Model / optimisation hyper-parameters.
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--zp_dim", type=int, default=256)
parser.add_argument("--zs_dim", type=int, default=768)
parser.add_argument("--hidden_dim", type=int, default=1024)
# Learning rate and weight decay are fractional; type=int would reject values
# such as "--lr 0.001" on the command line.
parser.add_argument("--lr", type=float, default=5e-4)
parser.add_argument("--weight_decay", type=float, default=1e-1)
parser.add_argument("--epochs", type=int, default=20)
parser.add_argument("--all_subj_lambdas", nargs="+", type=int, default=[1, 1, 1, 1, 1, 1, 1, 1])
parser.add_argument("--seed", type=int, default=42)

In [None]:
# A Jupyter kernel is started with its own argv entries (e.g. "-f <connection file>"),
# which a strict parse_args() rejects by exiting. parse_known_args() parses every
# declared option and returns the unrecognised leftovers instead of failing.
# Trade-off: a misspelled option on a script command line ends up in
# `_unknown_args` rather than raising an error.
args, _unknown_args = parser.parse_known_args()

### Data

In [None]:
# Data root and the subjects used in this run.
data_path = "/home/acg17270jl/projects/fmri-reconstruction-with-dmvae/data/deeprecon/"
subj_list = [1, 3, 4, 5]   # [1, 2, 3, 4, 5]

n_samples = 2400

# Number of repetitions needed: one per full 1200 samples, never fewer than one.
rep = max(1, n_samples // 1200)

In [3]:
### Load
# load_data receives the data root and the subject list. Besides the train/test
# brain-data objects it returns all_subj_num_voxels, which is later passed to the
# DMVAE constructor.
train_brain_data_dict, test_brain_data_dict, all_subj_num_voxels = load_data(data_path, subj_list)

In [4]:
### Align
# Build the training data with `rep` repetitions and normalization switched on. The
# call also returns train_mean_dict / train_norm_dict, which are reused for the test data.
train_data, train_mean_dict, train_norm_dict = get_train_data(train_brain_data_dict, subj_list, rep, is_normalized=True)
# The "Combine" step below is currently disabled.
# ### Combine
# train_data = conbine_cross_occurrences(train_data, subj_list)

# The test data call is given the train-derived mean/norm dicts, presumably so the test
# set is normalized with training statistics (TODO confirm in datasets/align.py).
test_data = get_test_data(test_brain_data_dict, subj_list, train_mean_dict, train_norm_dict, is_normalized=True)

In [5]:
### Dataset
# Turn the aligned train/test data into the two dataset objects that the DataLoaders consume.
train_dataset, test_dataset = get_dataset(train_data, test_data)

In [6]:
### DataLoader
batch_size = 128
# Dedicated generator for the training shuffle; it is seeded right before training starts.
g = torch.Generator()

# Training batches: shuffled, incomplete last batch dropped.
train_dl = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    generator=g,
)
# Test batches: fixed order, every sample kept.
test_dl = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    pin_memory=True,
)

### Model

In [7]:
# Hyper-parameters handed to the DMVAE constructor. zp_dim / zs_dim are the sizes of
# the `zp` and `zs` latents (names follow the keys used in model.dist_dict).
zp_dim = 256
zs_dim = 768
hidden_dim = 1024
optimizer = optim.Adam
lr = 5e-4
weight_decay = 1e-1

# Seed torch's global RNG *before* the model is built. Without this, parameter
# initialisation and every later latent sample draw from an unseeded generator,
# so runs are not repeatable even though the DataLoader shuffle is seeded.
seed = 42
torch.manual_seed(seed)

model = DMVAE(
    subj_list,
    all_subj_num_voxels,
    zp_dim,
    zs_dim,
    hidden_dim,
    device=device,
    optimizer=optimizer,
    lr=lr,
    weight_decay=weight_decay,
)

### Train

In [8]:
def get_recon_dict_batch(input_dict, model):
    """Decode joint, self and cross-subject reconstructions for one batch.

    For every subject in the module-level ``subj_list`` the encoders stored in
    ``model.dist_dict`` are sampled (``q_zp{s}__x{s}``, ``q_zs__x{s}``), plus the
    ``q_zs__x`` encoder once. Each target subject's decoder
    ``p_x{t}__zp{t}_zs`` is then evaluated with ``sample_mean`` on the target's
    own ``zp`` sample merged (``|``) with one of the ``zs`` samples.

    Args:
        input_dict: batch inputs passed unchanged to every ``sample`` call.
        model: object exposing ``dist_dict`` with the distributions named above.

    Returns:
        dict with CPU results under the keys ``joint_recon_x{t}``,
        ``self_recon_x{t}`` and ``cross_recon_x{t}__x{s}`` (``s != t``).
    """
    latents = {}
    recons = {}
    tags = [f"{int(subj):02d}" for subj in subj_list]

    with torch.no_grad():
        # Per-subject encoders first (zp, then zs), in subject order.
        for tag in tags:
            latents[f"zp{tag}"] = model.dist_dict[f"q_zp{tag}__x{tag}"].sample(input_dict, return_all=False)
            latents[f"zs__x{tag}"] = model.dist_dict[f"q_zs__x{tag}"].sample(input_dict, return_all=False)

        latents["zs__x"] = model.dist_dict["q_zs__x"].sample(input_dict, return_all=False)

        for target in tags:
            decoder = model.dist_dict[f"p_x{target}__zp{target}_zs"]
            zp_target = latents[f"zp{target}"]

            recons[f"joint_recon_x{target}"] = decoder.sample_mean(zp_target | latents["zs__x"]).cpu()

            # Same decoding for every source subject; only the result key differs.
            for source in tags:
                if source == target:
                    key = f"self_recon_x{target}"
                else:
                    key = f"cross_recon_x{target}__x{source}"
                recons[key] = decoder.sample_mean(zp_target | latents[f"zs__x{source}"]).cpu()

    return recons

In [9]:
def train(epoch):
    """Run one training pass over ``train_dl`` and return the mean loss.

    Relies on the module-level ``train_dl``, ``subj_list``, ``all_subj_lambdas``,
    ``model`` and ``device``. Each batch is turned into ``x{ss}`` float32 inputs
    on ``device`` plus ``lambda_{ss}`` weights (``all_subj_lambdas[subject - 1]``)
    and passed to ``model.train`` as a single merged dict.

    Args:
        epoch: epoch number, used only in the printed progress line.

    Returns:
        The per-batch losses averaged with the batch sizes as weights.
    """
    tags = [f"{int(subj):02d}" for subj in subj_list]
    loss_sum = 0
    n_seen = 0

    for batch in tqdm(train_dl):
        inputs = {f"x{tag}": batch[f"subj{tag}"].to(torch.float32).to(device) for tag in tags}
        weights = {f"lambda_{tag}": all_subj_lambdas[int(tag) - 1] for tag in tags}

        batch_loss = model.train(inputs | weights)

        # The batch size is read from the last subject's tensor.
        batch_len = inputs[f"x{tags[-1]}"].size(0)
        loss_sum += batch_loss * batch_len
        n_seen += batch_len

    train_loss = loss_sum / n_seen
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss

def test(epoch):
    test_loss = 0
    total_samples = 0

    all_input_dict_list = []
    all_recon_dict_list = []
    all_image_index_list = []

    with torch.no_grad():
        for data in tqdm(test_dl):
            input_dict = {}
            lambda_dict = {}
            
            for subj in subj_list:
                s = f"{int(subj):02d}"
                input_dict[f"x{s}"] = data[f"subj{s}"].to(torch.float32).to(device)
                lambda_dict[f"lambda_{s}"] = all_subj_lambdas[int(s)-1]

            loss = model.test(input_dict | lambda_dict)
            test_loss += loss * input_dict[f"x{s}"].size(0)
            total_samples += input_dict[f"x{s}"].size(0)

            if epoch == epochs:
                all_input_dict_list.append(input_dict)

                recon_dict_batch = get_recon_dict_batch(input_dict, model)
                all_recon_dict_list.append(recon_dict_batch)

                all_image_index_list.append(data["image_index"])

    test_loss = test_loss / total_samples
    print('Epoch: {} Test loss: {:.4f}'.format(epoch, test_loss))
    return test_loss, all_input_dict_list, all_recon_dict_list, all_image_index_list

In [None]:
# Number of epochs and the per-subject loss weights; train()/test() look the
# weights up as all_subj_lambdas[subject_number - 1].
epochs = 20
all_subj_lambdas = [1] * 8

# Seed the generator that drives the training DataLoader's shuffle.
g.manual_seed(42)

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    # test() fills the three lists only when epoch == epochs, so after the loop
    # they hold the inputs, reconstructions and image indices of the last epoch.
    (
        test_loss,
        all_input_dict_list,
        all_recon_dict_list,
        all_image_index_list,
    ) = test(epoch)

# model.save(ckpt_path)

100%|██████████| 18/18 [00:07<00:00,  2.53it/s]


Epoch: 1 Train loss: 404740.4688


100%|██████████| 10/10 [00:01<00:00,  8.67it/s]


Epoch: 1 Test loss: 396098.5312


100%|██████████| 18/18 [00:03<00:00,  4.63it/s]


Epoch: 2 Train loss: 394544.2812


100%|██████████| 10/10 [00:00<00:00, 10.11it/s]


Epoch: 2 Test loss: 391311.4062


100%|██████████| 18/18 [00:04<00:00,  4.32it/s]


Epoch: 3 Train loss: 389816.4375


100%|██████████| 10/10 [00:01<00:00,  9.65it/s]


Epoch: 3 Test loss: 389000.2812


100%|██████████| 18/18 [00:03<00:00,  4.60it/s]


Epoch: 4 Train loss: 385818.2188


100%|██████████| 10/10 [00:01<00:00,  9.72it/s]


Epoch: 4 Test loss: 386405.2500


100%|██████████| 18/18 [00:03<00:00,  4.50it/s]


Epoch: 5 Train loss: 381745.8125


100%|██████████| 10/10 [00:00<00:00, 10.54it/s]


Epoch: 5 Test loss: 383988.9375


100%|██████████| 18/18 [00:03<00:00,  4.62it/s]


Epoch: 6 Train loss: 378253.2812


100%|██████████| 10/10 [00:00<00:00, 10.78it/s]


Epoch: 6 Test loss: 382253.0625


100%|██████████| 18/18 [00:04<00:00,  4.49it/s]


Epoch: 7 Train loss: 375026.2812


100%|██████████| 10/10 [00:00<00:00, 10.44it/s]


Epoch: 7 Test loss: 380793.4688


100%|██████████| 18/18 [00:03<00:00,  4.56it/s]


Epoch: 8 Train loss: 372489.2500


100%|██████████| 10/10 [00:01<00:00,  9.18it/s]


Epoch: 8 Test loss: 379645.1562


100%|██████████| 18/18 [00:04<00:00,  4.49it/s]


Epoch: 9 Train loss: 370043.1875


100%|██████████| 10/10 [00:01<00:00,  8.42it/s]


Epoch: 9 Test loss: 378581.7812


100%|██████████| 18/18 [00:03<00:00,  4.53it/s]


Epoch: 10 Train loss: 367809.2500


100%|██████████| 10/10 [00:01<00:00,  8.20it/s]


Epoch: 10 Test loss: 377924.6562


100%|██████████| 18/18 [00:03<00:00,  4.73it/s]


Epoch: 11 Train loss: 365822.4062


100%|██████████| 10/10 [00:01<00:00,  8.94it/s]


Epoch: 11 Test loss: 377343.3750


100%|██████████| 18/18 [00:03<00:00,  4.59it/s]


Epoch: 12 Train loss: 363988.3125


100%|██████████| 10/10 [00:00<00:00, 10.54it/s]


Epoch: 12 Test loss: 376697.6562


100%|██████████| 18/18 [00:03<00:00,  4.67it/s]


Epoch: 13 Train loss: 362094.9688


100%|██████████| 10/10 [00:00<00:00, 10.80it/s]


Epoch: 13 Test loss: 376228.7812


100%|██████████| 18/18 [00:03<00:00,  4.68it/s]


Epoch: 14 Train loss: 360523.9375


100%|██████████| 10/10 [00:01<00:00,  9.63it/s]


Epoch: 14 Test loss: 375783.5312


100%|██████████| 18/18 [00:04<00:00,  4.33it/s]


Epoch: 15 Train loss: 358984.8750


100%|██████████| 10/10 [00:00<00:00, 10.84it/s]


Epoch: 15 Test loss: 375454.3125


100%|██████████| 18/18 [00:03<00:00,  4.64it/s]


Epoch: 16 Train loss: 357414.8125


100%|██████████| 10/10 [00:00<00:00, 10.72it/s]


Epoch: 16 Test loss: 375359.9375


100%|██████████| 18/18 [00:03<00:00,  4.71it/s]


Epoch: 17 Train loss: 355911.8125


100%|██████████| 10/10 [00:00<00:00, 11.02it/s]


Epoch: 17 Test loss: 375104.1250


100%|██████████| 18/18 [00:03<00:00,  4.60it/s]


Epoch: 18 Train loss: 354548.7188


100%|██████████| 10/10 [00:00<00:00, 10.73it/s]


Epoch: 18 Test loss: 374900.0625


100%|██████████| 18/18 [00:03<00:00,  4.59it/s]


Epoch: 19 Train loss: 353243.5312


100%|██████████| 10/10 [00:01<00:00,  8.99it/s]


Epoch: 19 Test loss: 374787.6875


100%|██████████| 18/18 [00:03<00:00,  4.72it/s]


Epoch: 20 Train loss: 351941.6250


100%|██████████| 10/10 [00:01<00:00,  5.21it/s]

Epoch: 20 Test loss: 374651.5000





### Evaluation

In [11]:
def collate_result_dict_list(all_result_dict_list):
    """Merge a list of per-batch tensor dicts into one dict of NumPy arrays.

    The keys are taken from the first dict. For each key the tensors of all
    batches are cast to float32, concatenated along dimension 0, moved to the
    CPU and converted with ``.numpy()``.

    Args:
        all_result_dict_list: non-empty list of dicts mapping the same keys to tensors.

    Returns:
        dict mapping each key to the concatenated float32 array.
    """
    collated_all_result_dict = {}
    for name in all_result_dict_list[0].keys():
        pieces = [part[name].to(dtype=torch.float32) for part in all_result_dict_list]
        collated_all_result_dict[name] = torch.cat(pieces, dim=0).cpu().numpy()

    return collated_all_result_dict

def calculate_pattern_correlation(target_brain, recon_brain, label, rep=24):
    """Mean target-vs-reconstruction pattern correlation for every image label.

    Rows of ``target_brain`` / ``recon_brain`` are samples and columns are
    voxels. For each distinct label, the rows carrying that label are
    correlated across voxels (``np.corrcoef``) and the block of coefficients
    between target rows and reconstruction rows is averaged.

    Args:
        target_brain: 2-D array (samples x voxels) of measured responses.
        recon_brain: 2-D array of the same shape holding the reconstructions.
        label: per-sample image labels, shape ``(n,)`` or ``(n, 1)``.
        rep: kept for backward compatibility. The cross block is now sliced
            with the actual number of rows found for each label, which gives
            the same result whenever that number equals ``rep``.

    Returns:
        list with one mean correlation per distinct label, in ascending label order.
    """
    flat_label = np.asarray(label).flatten()
    sort_idx = np.argsort(flat_label)
    target_brain = target_brain[sort_idx]
    recon_brain = recon_brain[sort_idx]
    flat_label = flat_label[sort_idx]

    pattern_corr = []
    for image_idx in np.unique(flat_label):
        mask = flat_label == image_idx
        target_pattern = target_brain[mask, :]
        recon_pattern = recon_brain[mask, :]

        # corrcoef stacks target rows on top of recon rows, so the target-vs-recon
        # block starts after the target rows. Slicing with the real row count
        # (not a fixed `rep`) avoids an empty or misaligned block, i.e. NaN or a
        # wrong mean, when a label has a different number of samples.
        n_rows = target_pattern.shape[0]
        corrs = np.corrcoef(target_pattern, recon_pattern)[:n_rows, n_rows:]
        pattern_corr.append(np.mean(corrs))

    return pattern_corr

def calculate_profile_correlation(target_brain, recon_brain, label, rep=24):
    """Mean target-vs-reconstruction profile correlation for every voxel.

    Samples are first ordered by label. Each voxel column is then reshaped
    (column-major) to ``(rep, n_samples / rep)``, so row ``r`` holds every
    ``rep``-th sample starting at ``r``. The rows of the target and of the
    reconstruction are correlated with ``np.corrcoef`` and the target-vs-recon
    block is averaged.

    Assumes every label occurs exactly ``rep`` times, so that after sorting each
    reshaped row contains one sample per image. This is not checked here.

    Args:
        target_brain: 2-D array (samples x voxels) of measured responses.
        recon_brain: 2-D array of the same shape holding the reconstructions.
        label: per-sample image labels, used only for ordering.
        rep: number of rows of the reshaped per-voxel profile.

    Returns:
        list with one mean correlation per voxel, in column order.
    """
    order = np.argsort(label.flatten())
    target_sorted = target_brain[order]
    recon_sorted = recon_brain[order]

    profile_corr = []
    # Iterating over the transposes walks the voxel columns one by one.
    for target_col, recon_col in zip(target_sorted.T, recon_sorted.T):
        target_profile = target_col.reshape(rep, -1, order="F")
        recon_profile = recon_col.reshape(rep, -1, order="F")

        cross_block = np.corrcoef(target_profile, recon_profile)[:rep, rep:]
        profile_corr.append(np.mean(cross_block))

    return profile_corr

def save_result(result_data, output_dir, output_filename):
    """Save the results to a CSV file.

    Args:
        result_data: anything ``pd.DataFrame`` accepts, e.g. a list of row dicts.
        output_dir: directory for the file; created (with parents) if missing.
        output_filename: name of the CSV file inside ``output_dir``.

    The CSV is written without the DataFrame index column.
    """
    # exist_ok=True replaces the separate "exists?" check, which could fail if
    # the directory was created between the check and the makedirs call.
    os.makedirs(output_dir, exist_ok=True)
    df = pd.DataFrame(result_data)
    df.to_csv(os.path.join(output_dir, output_filename), index=False)

In [None]:
# Stack the per-batch inputs, reconstructions and image indices collected during
# the final test epoch into one array per key.
input_brain_dict = collate_result_dict_list(all_input_dict_list)
recon_brain_dict = collate_result_dict_list(all_recon_dict_list)
label = torch.cat(all_image_index_list, dim=0).cpu().numpy()

output_path = "/home/acg17270jl/projects/fmri-reconstruction-with-dmvae/results/deeprecon/conversion/"

pattern_corr_result = []
profile_corr_result = []

# Every ordered (target, source) pair of distinct subjects: compare the target's
# measured data with its reconstruction decoded from the source subject.
for subj_target, subj_source in permutations(subj_list, 2):
    s_t = f"{int(subj_target):02d}"
    s_s = f"{int(subj_source):02d}"

    target_brain = input_brain_dict[f"x{s_t}"]
    recon_brain = recon_brain_dict[f"cross_recon_x{s_t}__x{s_s}"]

    # One row per image; Image_idx is the 1-based position in the returned list.
    pattern_corr = calculate_pattern_correlation(target_brain, recon_brain, label)
    pattern_corr_result.extend(
        {
            'Subject_target': subj_target,
            'Subject_source': subj_source,
            'Correlation': corr,
            'Image_idx': image_number,
        }
        for image_number, corr in enumerate(pattern_corr, start=1)
    )

    # One row per voxel; Voxel_idx is the 0-based column index.
    profile_corr = calculate_profile_correlation(target_brain, recon_brain, label)
    profile_corr_result.extend(
        {
            'Subject_target': subj_target,
            'Subject_source': subj_source,
            'Correlation': corr,
            'Voxel_idx': voxel_number,
        }
        for voxel_number, corr in enumerate(profile_corr)
    )

# The file names carry the concatenated subject numbers of this run.
subj_suffix = ''.join(map(str, subj_list))
save_result(pattern_corr_result, output_path, f"pattern_correlation_dmvae_subj{subj_suffix}.csv")
save_result(profile_corr_result, output_path, f"profile_correlation_dmvae_subj{subj_suffix}.csv")