In [1]:
import numpy as np
import torch
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
import re
import os
from itertools import permutations
import pandas as pd

from fmri_reconstruction_with_dmvae.deeprecon.datasets.load import load_data
from fmri_reconstruction_with_dmvae.deeprecon.datasets.align import get_train_data, get_test_data
from fmri_reconstruction_with_dmvae.deeprecon.datasets.dataset import get_dataset

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

### Data

In [2]:
# Root of the DeepRecon data. Set the DEEPRECON_DATA_PATH environment variable
# to run on another machine; the original path is kept as the default.
data_path = os.environ.get(
    "DEEPRECON_DATA_PATH",
    "/home/acg17270jl/projects/fmri-reconstruction-with-dmvae/data/deeprecon/",
)
subj_list = [1, 2, 3, 4, 5]

num_samples = 2400

# Repetition count forwarded to get_train_data: floor(num_samples / 1200),
# never less than 1. NOTE(review): 1200 is presumably the number of unique
# training samples per subject — confirm against the dataset.
rep = max(1, int(num_samples / 1200))

In [3]:
### Load
# Load the per-subject train/test brain data for every subject in subj_list.
# all_subj_num_voxels is indexed later with "subjXX" keys to size each
# converter's input and output. NOTE(review): the exact structure of the two
# data dicts is defined by deeprecon.datasets.load.load_data — not visible here.
train_brain_data_dict, test_brain_data_dict, all_subj_num_voxels = load_data(data_path, subj_list)

In [4]:
### Align
# Build the aligned training and test data for all subjects.
# - `rep` (config cell) is forwarded to get_train_data; presumably it controls
#   how many times the training samples are repeated — confirm in
#   deeprecon.datasets.align.
# - The training mean/norm dicts are passed to get_test_data, presumably so the
#   test data is normalized with training-set statistics — confirm.
train_data, train_mean_dict, train_norm_dict = get_train_data(train_brain_data_dict, subj_list, rep, is_normalized=True)
test_data = get_test_data(test_brain_data_dict, subj_list, train_mean_dict, train_norm_dict, is_normalized=True)

In [5]:
### Dataset
# Wrap the aligned data in datasets. train()/test() index each batch with
# "subjXX" keys, and test() also reads "image_index", so items are expected to
# be dicts containing those keys.
train_dataset, test_dataset = get_dataset(train_data, test_data)

In [6]:
### DataLoader
batch_size = 128
# Dedicated RNG for the training shuffle order; the experiment cells re-seed it
# (g.manual_seed) before training each subject pair.
g = torch.Generator()

# Pinned host memory only helps host->GPU copies; without CUDA, PyTorch warns
# that pin_memory has no effect, so enable it only when training on CUDA.
use_pin_memory = device == "cuda"

train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=use_pin_memory, generator=g)
test_dl = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=use_pin_memory)

### Model & Train & Evaluation

#### Functions (training, evaluation, metrics, saving)

In [None]:
def train(epoch):
    """Run one training epoch of ``model`` over ``train_dl`` and return the mean loss.

    Relies on notebook-level globals set by the experiment cells: ``model``,
    ``optimizer``, ``loss_fn``, ``train_dl``, ``device`` and the zero-padded
    subject ids ``s_s`` (source) and ``s_t`` (target).

    Args:
        epoch: 1-based epoch number, used only in the progress message.

    Returns:
        Sample-weighted mean training loss over the epoch.
    """
    # test() puts the model in eval mode; switch back so that any
    # dropout / batch-norm layers behave as in training for this epoch.
    model.train()

    train_loss = 0
    total_samples = 0

    for data in tqdm(train_dl):
        optimizer.zero_grad()

        source_batch = data[f"subj{s_s}"].to(torch.float32).to(device)
        target_batch = data[f"subj{s_t}"].to(torch.float32).to(device)
        recon_batch = model(source_batch)

        loss = loss_fn(recon_batch, target_batch)

        loss.backward()
        optimizer.step()

        # Weight each (mean-reduced) batch loss by its size so the epoch
        # average is per sample.
        train_loss += loss.item() * recon_batch.size(0)
        total_samples += recon_batch.size(0)

    train_loss = train_loss / total_samples

    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss

def test(epoch):
    model.eval()
    test_loss = 0
    total_samples = 0

    target_brain = []
    recon_brain = []
    label = []

    with torch.no_grad():
        for data in tqdm(test_dl):
            source_batch = data[f"subj{s_s}"].to(torch.float32).to(device)
            target_batch = data[f"subj{s_t}"].to(torch.float32).to(device)
            recon_batch = model(source_batch)

            loss = loss_fn(recon_batch, target_batch)

            test_loss += loss.item() * recon_batch.size(0)
            total_samples += recon_batch.size(0)

            if epoch == epochs:
                target_brain.append(target_batch.cpu())
                recon_brain.append(recon_batch.cpu())
                label.append(data["image_index"].cpu())
            
    test_loss = test_loss / total_samples

    print('Epoch: {} Test loss: {:.4f}'.format(epoch, test_loss))
    return test_loss, target_brain, recon_brain, label

In [None]:
def calculate_pattern_correlation(target_brain, recon_brain, label, rep=24):
    """Per-image pattern correlation between measured and converted responses.

    For each unique image label, every measured (target) trial pattern is
    correlated with every converted (recon) trial pattern across columns, and
    the resulting cross-correlation block is averaged.

    Args:
        target_brain: 2-D array (n_samples, n_voxels) of measured responses.
        recon_brain: 2-D array (n_samples, n_voxels) of converted responses,
            row-aligned with ``target_brain``.
        label: Image label per sample; any shape that flattens to
            ``n_samples`` (e.g. ``(n_samples, 1)``).
        rep: Expected number of trials per image. Kept for backward
            compatibility; the actual per-image trial count is used, so images
            with a different number of trials are handled correctly.

    Returns:
        List of mean correlations, one per unique label, in ascending label order.
    """
    flat_label = np.asarray(label).ravel()

    pattern_corr = []
    for image_idx in np.unique(flat_label):
        mask = flat_label == image_idx
        target_pattern = target_brain[mask, :]
        recon_pattern = recon_brain[mask, :]
        n_trials = target_pattern.shape[0]

        # np.corrcoef stacks both inputs row-wise; the top-right block holds
        # corr(target trial i, recon trial j) for every pair (i, j).
        corrs = np.corrcoef(target_pattern, recon_pattern)[:n_trials, n_trials:]
        pattern_corr.append(np.mean(corrs))

    return pattern_corr

def calculate_profile_correlation(target_brain, recon_brain, label, rep=24):
    """Per-voxel profile correlation between measured and converted responses.

    Samples are grouped by image label. For each voxel (column), the response
    profile across images is built separately for each trial; every measured
    trial profile is correlated with every converted trial profile, and the
    cross-correlations are averaged.

    Args:
        target_brain: 2-D array (n_samples, n_voxels) of measured responses.
        recon_brain: 2-D array (n_samples, n_voxels) of converted responses,
            row-aligned with ``target_brain``.
        label: Image label per sample; any shape that flattens to ``n_samples``.
        rep: Number of trials per image; every label must occur exactly
            ``rep`` times.

    Returns:
        List with one mean correlation per voxel, in column order.

    Raises:
        ValueError: if some label does not occur exactly ``rep`` times.
    """
    flat_label = np.asarray(label).ravel()
    _, counts = np.unique(flat_label, return_counts=True)
    if not np.all(counts == rep):
        raise ValueError(
            f"every label must occur exactly rep={rep} times; "
            f"got counts {sorted(set(counts.tolist()))}"
        )

    # Stable sort: rows become grouped by image and, within an image, keep
    # their original order, so row r of each reshaped profile below is the
    # r-th trial of every image (the default quicksort is not stable).
    sort_idx = np.argsort(flat_label, kind="stable")
    target_brain = target_brain[sort_idx]
    recon_brain = recon_brain[sort_idx]

    profile_corr = []
    for voxel_idx in range(target_brain.shape[1]):
        # Column-major reshape -> (rep, n_images): column k holds the rep
        # trials of the k-th image.
        target_profile = target_brain[:, voxel_idx].reshape(rep, -1, order="F")
        recon_profile = recon_brain[:, voxel_idx].reshape(rep, -1, order="F")

        # Top-right block: corr(target trial i profile, recon trial j profile).
        corrs = np.corrcoef(target_profile, recon_profile)[:rep, rep:]
        profile_corr.append(np.mean(corrs))

    return profile_corr

def save_result(result_data, output_dir, output_filename):
    """
    Save the results to a CSV file.

    Args:
        result_data: Anything accepted by ``pd.DataFrame``; here a list of
            per-row dicts, whose keys become the CSV columns.
        output_dir: Directory to write into; created (with parents) if missing.
        output_filename: Name of the CSV file inside ``output_dir``; an
            existing file is overwritten.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)
    df = pd.DataFrame(result_data)
    # Do not write the DataFrame's integer index as an extra column.
    df.to_csv(os.path.join(output_dir, output_filename), index=False)

#### Converter1

In [26]:
from fmri_reconstruction_with_dmvae.models.converter import Converter1

lr = 1e-4
weight_decay = 1e-1

epochs = 7

# Seed for both the model's weight initialisation and the training shuffle
# order (generator `g`), so each subject pair is trained reproducibly.
seed = 42

conversion_output_path = "/home/acg17270jl/projects/fmri-reconstruction-with-dmvae/results/deeprecon/conversion/"

pattern_corr_result = []
profile_corr_result = []

# Train one converter per ordered (target, source) pair of distinct subjects.
for subj_target, subj_source in permutations(subj_list, 2):
    # Zero-padded ids ("01", "02", ...); train()/test() read s_s / s_t as
    # globals to select the "subjXX" entries of each batch.
    s_t = f"{int(subj_target):02d}"
    s_s = f"{int(subj_source):02d}"

    # Seed the global torch RNG before building the model; otherwise the
    # weight initialisation (and every result below) changes between runs.
    torch.manual_seed(seed)
    model = Converter1(all_subj_num_voxels[f"subj{s_s}"], all_subj_num_voxels[f"subj{s_t}"]).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)  # L2 penalty added to the gradient
    loss_fn = torch.nn.MSELoss()

    g.manual_seed(seed)
    for epoch in range(1, epochs + 1):
        train_loss = train(epoch)
        test_loss, target_brain, recon_brain, label = test(epoch)

    # test() collects per-batch outputs only on the final epoch (epoch == epochs).
    target_brain = torch.cat(target_brain, dim=0).cpu().numpy()
    recon_brain = torch.cat(recon_brain, dim=0).cpu().numpy()
    label = torch.cat(label, dim=0).cpu().numpy()

    # NOTE(review): Image_idx is the 1-based position among the sorted unique
    # labels; it equals the label itself only if labels are exactly 1..N — confirm.
    pattern_corr = calculate_pattern_correlation(target_brain, recon_brain, label)
    pattern_corr_result.extend(
        {
            'Subject_target': subj_target,
            'Subject_source': subj_source,
            'Correlation': corr,
            'Image_idx': i + 1,
        }
        for i, corr in enumerate(pattern_corr)
    )

    profile_corr = calculate_profile_correlation(target_brain, recon_brain, label)
    profile_corr_result.extend(
        {
            'Subject_target': subj_target,
            'Subject_source': subj_source,
            'Correlation': corr,
            'Voxel_idx': i,
        }
        for i, corr in enumerate(profile_corr)
    )

save_result(pattern_corr_result, conversion_output_path, "pattern_correlation_converter1.csv")
save_result(profile_corr_result, conversion_output_path, "profile_correlation_converter1.csv")

100%|██████████| 18/18 [00:00<00:00, 30.61it/s]


Epoch: 1 Train loss: 1.2342


100%|██████████| 10/10 [00:00<00:00, 56.13it/s]


Epoch: 1 Test loss: 1.2247


100%|██████████| 18/18 [00:00<00:00, 27.92it/s]


Epoch: 2 Train loss: 1.1007


100%|██████████| 10/10 [00:00<00:00, 53.90it/s]


Epoch: 2 Test loss: 1.0849


100%|██████████| 18/18 [00:00<00:00, 39.78it/s]


Epoch: 3 Train loss: 1.0349


100%|██████████| 10/10 [00:00<00:00, 25.05it/s]


Epoch: 3 Test loss: 1.0225


100%|██████████| 18/18 [00:00<00:00, 34.81it/s]


Epoch: 4 Train loss: 1.0098


100%|██████████| 10/10 [00:00<00:00, 36.34it/s]


Epoch: 4 Test loss: 0.9970


100%|██████████| 18/18 [00:00<00:00, 33.64it/s]


Epoch: 5 Train loss: 0.9984


100%|██████████| 10/10 [00:00<00:00, 56.95it/s]


Epoch: 5 Test loss: 0.9878


100%|██████████| 18/18 [00:00<00:00, 32.51it/s]


Epoch: 6 Train loss: 0.9985


100%|██████████| 10/10 [00:00<00:00, 52.90it/s]


Epoch: 6 Test loss: 0.9846


100%|██████████| 18/18 [00:00<00:00, 30.21it/s]


Epoch: 7 Train loss: 0.9868


100%|██████████| 10/10 [00:00<00:00, 28.46it/s]


Epoch: 7 Test loss: 0.9840


100%|██████████| 18/18 [00:00<00:00, 25.96it/s]


Epoch: 1 Train loss: 1.2370


100%|██████████| 10/10 [00:00<00:00, 43.10it/s]


Epoch: 1 Test loss: 1.1168


100%|██████████| 18/18 [00:00<00:00, 29.44it/s]


Epoch: 2 Train loss: 1.1072


100%|██████████| 10/10 [00:00<00:00, 46.11it/s]


Epoch: 2 Test loss: 1.0437


100%|██████████| 18/18 [00:00<00:00, 35.89it/s]


Epoch: 3 Train loss: 1.0406


100%|██████████| 10/10 [00:00<00:00, 59.75it/s]


Epoch: 3 Test loss: 1.0087


100%|██████████| 18/18 [00:00<00:00, 39.70it/s]


Epoch: 4 Train loss: 1.0137


100%|██████████| 10/10 [00:00<00:00, 51.40it/s]


Epoch: 4 Test loss: 0.9931


100%|██████████| 18/18 [00:00<00:00, 30.25it/s]


Epoch: 5 Train loss: 1.0008


100%|██████████| 10/10 [00:00<00:00, 54.12it/s]


Epoch: 5 Test loss: 0.9868


100%|██████████| 18/18 [00:00<00:00, 35.88it/s]


Epoch: 6 Train loss: 0.9997


100%|██████████| 10/10 [00:00<00:00, 56.57it/s]


Epoch: 6 Test loss: 0.9845


100%|██████████| 18/18 [00:00<00:00, 40.91it/s]


Epoch: 7 Train loss: 0.9874


100%|██████████| 10/10 [00:00<00:00, 48.56it/s]


Epoch: 7 Test loss: 0.9839


100%|██████████| 18/18 [00:00<00:00, 26.59it/s]


Epoch: 1 Train loss: 1.2371


100%|██████████| 10/10 [00:00<00:00, 55.71it/s]


Epoch: 1 Test loss: 1.1264


100%|██████████| 18/18 [00:00<00:00, 33.83it/s]


Epoch: 2 Train loss: 1.1059


100%|██████████| 10/10 [00:00<00:00, 48.36it/s]


Epoch: 2 Test loss: 1.0458


100%|██████████| 18/18 [00:00<00:00, 35.70it/s]


Epoch: 3 Train loss: 1.0390


100%|██████████| 10/10 [00:00<00:00, 59.16it/s]


Epoch: 3 Test loss: 1.0083


100%|██████████| 18/18 [00:00<00:00, 35.15it/s]


Epoch: 4 Train loss: 1.0124


100%|██████████| 10/10 [00:00<00:00, 59.00it/s]


Epoch: 4 Test loss: 0.9925


100%|██████████| 18/18 [00:00<00:00, 41.29it/s]


Epoch: 5 Train loss: 0.9999


100%|██████████| 10/10 [00:00<00:00, 35.82it/s]


Epoch: 5 Test loss: 0.9861


100%|██████████| 18/18 [00:00<00:00, 30.60it/s]


Epoch: 6 Train loss: 0.9993


100%|██████████| 10/10 [00:00<00:00, 58.69it/s]


Epoch: 6 Test loss: 0.9841


100%|██████████| 18/18 [00:00<00:00, 34.03it/s]


Epoch: 7 Train loss: 0.9870


100%|██████████| 10/10 [00:00<00:00, 45.53it/s]


Epoch: 7 Test loss: 0.9829


100%|██████████| 18/18 [00:00<00:00, 34.18it/s]


Epoch: 1 Train loss: 1.2399


100%|██████████| 10/10 [00:00<00:00, 56.31it/s]


Epoch: 1 Test loss: 1.1893


100%|██████████| 18/18 [00:00<00:00, 41.00it/s]


Epoch: 2 Train loss: 1.1101


100%|██████████| 10/10 [00:00<00:00, 45.21it/s]


Epoch: 2 Test loss: 1.0757


100%|██████████| 18/18 [00:00<00:00, 40.12it/s]


Epoch: 3 Train loss: 1.0419


100%|██████████| 10/10 [00:00<00:00, 46.36it/s]


Epoch: 3 Test loss: 1.0218


100%|██████████| 18/18 [00:00<00:00, 41.44it/s]


Epoch: 4 Train loss: 1.0150


100%|██████████| 10/10 [00:00<00:00, 51.75it/s]


Epoch: 4 Test loss: 0.9983


100%|██████████| 18/18 [00:00<00:00, 40.61it/s]


Epoch: 5 Train loss: 1.0021


100%|██████████| 10/10 [00:00<00:00, 56.04it/s]


Epoch: 5 Test loss: 0.9891


100%|██████████| 18/18 [00:00<00:00, 35.60it/s]


Epoch: 6 Train loss: 1.0012


100%|██████████| 10/10 [00:00<00:00, 47.87it/s]


Epoch: 6 Test loss: 0.9860


100%|██████████| 18/18 [00:00<00:00, 34.14it/s]


Epoch: 7 Train loss: 0.9891


100%|██████████| 10/10 [00:00<00:00, 47.51it/s]


Epoch: 7 Test loss: 0.9851


100%|██████████| 18/18 [00:00<00:00, 27.19it/s]


Epoch: 1 Train loss: 1.2207


100%|██████████| 10/10 [00:00<00:00, 54.31it/s]


Epoch: 1 Test loss: 1.7321


100%|██████████| 18/18 [00:00<00:00, 37.27it/s]


Epoch: 2 Train loss: 1.0834


100%|██████████| 10/10 [00:00<00:00, 55.74it/s]


Epoch: 2 Test loss: 1.6542


100%|██████████| 18/18 [00:00<00:00, 29.53it/s]


Epoch: 3 Train loss: 1.0303


100%|██████████| 10/10 [00:00<00:00, 59.39it/s]


Epoch: 3 Test loss: 1.6216


100%|██████████| 18/18 [00:00<00:00, 36.80it/s]


Epoch: 4 Train loss: 1.0062


100%|██████████| 10/10 [00:00<00:00, 59.00it/s]


Epoch: 4 Test loss: 1.6089


100%|██████████| 18/18 [00:00<00:00, 38.62it/s]


Epoch: 5 Train loss: 0.9991


100%|██████████| 10/10 [00:00<00:00, 59.15it/s]


Epoch: 5 Test loss: 1.6047


100%|██████████| 18/18 [00:00<00:00, 40.45it/s]


Epoch: 6 Train loss: 0.9938


100%|██████████| 10/10 [00:00<00:00, 57.67it/s]


Epoch: 6 Test loss: 1.6031


100%|██████████| 18/18 [00:00<00:00, 36.75it/s]


Epoch: 7 Train loss: 0.9930


100%|██████████| 10/10 [00:00<00:00, 49.17it/s]


Epoch: 7 Test loss: 1.6032


100%|██████████| 18/18 [00:00<00:00, 30.12it/s]


Epoch: 1 Train loss: 1.2399


100%|██████████| 10/10 [00:00<00:00, 47.48it/s]


Epoch: 1 Test loss: 1.7353


100%|██████████| 18/18 [00:00<00:00, 34.01it/s]


Epoch: 2 Train loss: 1.1048


100%|██████████| 10/10 [00:00<00:00, 59.36it/s]


Epoch: 2 Test loss: 1.6620


100%|██████████| 18/18 [00:00<00:00, 34.22it/s]


Epoch: 3 Train loss: 1.0430


100%|██████████| 10/10 [00:00<00:00, 58.62it/s]


Epoch: 3 Test loss: 1.6265


100%|██████████| 18/18 [00:00<00:00, 34.63it/s]


Epoch: 4 Train loss: 1.0117


100%|██████████| 10/10 [00:00<00:00, 59.66it/s]


Epoch: 4 Test loss: 1.6108


100%|██████████| 18/18 [00:00<00:00, 41.86it/s]


Epoch: 5 Train loss: 1.0002


100%|██████████| 10/10 [00:00<00:00, 41.85it/s]


Epoch: 5 Test loss: 1.6045


100%|██████████| 18/18 [00:00<00:00, 34.45it/s]


Epoch: 6 Train loss: 0.9925


100%|██████████| 10/10 [00:00<00:00, 52.69it/s]


Epoch: 6 Test loss: 1.6021


100%|██████████| 18/18 [00:00<00:00, 29.66it/s]


Epoch: 7 Train loss: 0.9907


100%|██████████| 10/10 [00:00<00:00, 49.43it/s]


Epoch: 7 Test loss: 1.6012


100%|██████████| 18/18 [00:00<00:00, 38.46it/s]


Epoch: 1 Train loss: 1.2387


100%|██████████| 10/10 [00:00<00:00, 55.77it/s]


Epoch: 1 Test loss: 1.7455


100%|██████████| 18/18 [00:00<00:00, 40.66it/s]


Epoch: 2 Train loss: 1.1023


100%|██████████| 10/10 [00:00<00:00, 51.92it/s]


Epoch: 2 Test loss: 1.6647


100%|██████████| 18/18 [00:00<00:00, 41.53it/s]


Epoch: 3 Train loss: 1.0403


100%|██████████| 10/10 [00:00<00:00, 59.08it/s]


Epoch: 3 Test loss: 1.6271


100%|██████████| 18/18 [00:00<00:00, 33.85it/s]


Epoch: 4 Train loss: 1.0096


100%|██████████| 10/10 [00:00<00:00, 41.89it/s]


Epoch: 4 Test loss: 1.6109


100%|██████████| 18/18 [00:00<00:00, 26.63it/s]


Epoch: 5 Train loss: 0.9984


100%|██████████| 10/10 [00:00<00:00, 43.98it/s]


Epoch: 5 Test loss: 1.6046


100%|██████████| 18/18 [00:00<00:00, 41.63it/s]


Epoch: 6 Train loss: 0.9911


100%|██████████| 10/10 [00:00<00:00, 51.03it/s]


Epoch: 6 Test loss: 1.6024


100%|██████████| 18/18 [00:00<00:00, 40.97it/s]


Epoch: 7 Train loss: 0.9893


100%|██████████| 10/10 [00:00<00:00, 39.45it/s]


Epoch: 7 Test loss: 1.6017


100%|██████████| 18/18 [00:00<00:00, 37.33it/s]


Epoch: 1 Train loss: 1.2424


100%|██████████| 10/10 [00:00<00:00, 44.45it/s]


Epoch: 1 Test loss: 1.8083


100%|██████████| 18/18 [00:00<00:00, 29.48it/s]


Epoch: 2 Train loss: 1.1077


100%|██████████| 10/10 [00:00<00:00, 52.51it/s]


Epoch: 2 Test loss: 1.6946


100%|██████████| 18/18 [00:00<00:00, 35.39it/s]


Epoch: 3 Train loss: 1.0444


100%|██████████| 10/10 [00:00<00:00, 59.09it/s]


Epoch: 3 Test loss: 1.6409


100%|██████████| 18/18 [00:00<00:00, 36.11it/s]


Epoch: 4 Train loss: 1.0138


100%|██████████| 10/10 [00:00<00:00, 57.55it/s]


Epoch: 4 Test loss: 1.6174


100%|██████████| 18/18 [00:00<00:00, 29.61it/s]


Epoch: 5 Train loss: 1.0026


100%|██████████| 10/10 [00:00<00:00, 59.24it/s]


Epoch: 5 Test loss: 1.6080


100%|██████████| 18/18 [00:00<00:00, 36.27it/s]


Epoch: 6 Train loss: 0.9951


100%|██████████| 10/10 [00:00<00:00, 61.23it/s]


Epoch: 6 Test loss: 1.6046


100%|██████████| 18/18 [00:00<00:00, 42.23it/s]


Epoch: 7 Train loss: 0.9935


100%|██████████| 10/10 [00:00<00:00, 49.00it/s]


Epoch: 7 Test loss: 1.6035


100%|██████████| 18/18 [00:00<00:00, 28.32it/s]


Epoch: 1 Train loss: 1.2181


100%|██████████| 10/10 [00:00<00:00, 55.54it/s]


Epoch: 1 Test loss: 0.9612


100%|██████████| 18/18 [00:00<00:00, 34.54it/s]


Epoch: 2 Train loss: 1.0834


100%|██████████| 10/10 [00:00<00:00, 57.99it/s]


Epoch: 2 Test loss: 0.8840


100%|██████████| 18/18 [00:00<00:00, 30.24it/s]


Epoch: 3 Train loss: 1.0296


100%|██████████| 10/10 [00:00<00:00, 52.50it/s]


Epoch: 3 Test loss: 0.8520


100%|██████████| 18/18 [00:00<00:00, 36.01it/s]


Epoch: 4 Train loss: 1.0076


100%|██████████| 10/10 [00:00<00:00, 42.91it/s]


Epoch: 4 Test loss: 0.8394


100%|██████████| 18/18 [00:00<00:00, 36.53it/s]


Epoch: 5 Train loss: 0.9929


100%|██████████| 10/10 [00:00<00:00, 56.90it/s]


Epoch: 5 Test loss: 0.8351


100%|██████████| 18/18 [00:00<00:00, 40.92it/s]


Epoch: 6 Train loss: 0.9919


100%|██████████| 10/10 [00:00<00:00, 37.76it/s]


Epoch: 6 Test loss: 0.8335


100%|██████████| 18/18 [00:00<00:00, 34.62it/s]


Epoch: 7 Train loss: 0.9897


100%|██████████| 10/10 [00:00<00:00, 51.92it/s]


Epoch: 7 Test loss: 0.8333


100%|██████████| 18/18 [00:00<00:00, 28.30it/s]


Epoch: 1 Train loss: 1.2337


100%|██████████| 10/10 [00:00<00:00, 54.74it/s]


Epoch: 1 Test loss: 1.0707


100%|██████████| 18/18 [00:00<00:00, 39.63it/s]


Epoch: 2 Train loss: 1.0982


100%|██████████| 10/10 [00:00<00:00, 45.62it/s]


Epoch: 2 Test loss: 0.9322


100%|██████████| 18/18 [00:00<00:00, 40.89it/s]


Epoch: 3 Train loss: 1.0364


100%|██████████| 10/10 [00:00<00:00, 39.38it/s]


Epoch: 3 Test loss: 0.8701


100%|██████████| 18/18 [00:00<00:00, 31.89it/s]


Epoch: 4 Train loss: 1.0086


100%|██████████| 10/10 [00:00<00:00, 45.93it/s]


Epoch: 4 Test loss: 0.8446


100%|██████████| 18/18 [00:00<00:00, 30.08it/s]


Epoch: 5 Train loss: 0.9910


100%|██████████| 10/10 [00:00<00:00, 45.56it/s]


Epoch: 5 Test loss: 0.8353


100%|██████████| 18/18 [00:00<00:00, 40.86it/s]


Epoch: 6 Train loss: 0.9888


100%|██████████| 10/10 [00:00<00:00, 46.64it/s]


Epoch: 6 Test loss: 0.8324


100%|██████████| 18/18 [00:00<00:00, 38.84it/s]


Epoch: 7 Train loss: 0.9862


100%|██████████| 10/10 [00:00<00:00, 53.30it/s]


Epoch: 7 Test loss: 0.8315


100%|██████████| 18/18 [00:00<00:00, 34.83it/s]


Epoch: 1 Train loss: 1.2357


100%|██████████| 10/10 [00:00<00:00, 58.92it/s]


Epoch: 1 Test loss: 0.9737


100%|██████████| 18/18 [00:00<00:00, 34.42it/s]


Epoch: 2 Train loss: 1.1026


100%|██████████| 10/10 [00:00<00:00, 37.87it/s]


Epoch: 2 Test loss: 0.8936


100%|██████████| 18/18 [00:00<00:00, 37.42it/s]


Epoch: 3 Train loss: 1.0404


100%|██████████| 10/10 [00:00<00:00, 55.78it/s]


Epoch: 3 Test loss: 0.8564


100%|██████████| 18/18 [00:00<00:00, 30.10it/s]


Epoch: 4 Train loss: 1.0112


100%|██████████| 10/10 [00:00<00:00, 56.90it/s]


Epoch: 4 Test loss: 0.8403


100%|██████████| 18/18 [00:00<00:00, 41.09it/s]


Epoch: 5 Train loss: 0.9925


100%|██████████| 10/10 [00:00<00:00, 59.26it/s]


Epoch: 5 Test loss: 0.8340


100%|██████████| 18/18 [00:00<00:00, 34.36it/s]


Epoch: 6 Train loss: 0.9897


100%|██████████| 10/10 [00:00<00:00, 58.52it/s]


Epoch: 6 Test loss: 0.8317


100%|██████████| 18/18 [00:00<00:00, 39.89it/s]


Epoch: 7 Train loss: 0.9867


100%|██████████| 10/10 [00:00<00:00, 50.98it/s]


Epoch: 7 Test loss: 0.8310


100%|██████████| 18/18 [00:00<00:00, 39.87it/s]


Epoch: 1 Train loss: 1.2394


100%|██████████| 10/10 [00:00<00:00, 57.00it/s]


Epoch: 1 Test loss: 1.0357


100%|██████████| 18/18 [00:00<00:00, 26.91it/s]


Epoch: 2 Train loss: 1.1078


100%|██████████| 10/10 [00:00<00:00, 53.44it/s]


Epoch: 2 Test loss: 0.9231


100%|██████████| 18/18 [00:00<00:00, 42.99it/s]


Epoch: 3 Train loss: 1.0442


100%|██████████| 10/10 [00:00<00:00, 59.08it/s]


Epoch: 3 Test loss: 0.8698


100%|██████████| 18/18 [00:00<00:00, 22.99it/s]


Epoch: 4 Train loss: 1.0153


100%|██████████| 10/10 [00:00<00:00, 45.83it/s]


Epoch: 4 Test loss: 0.8467


100%|██████████| 18/18 [00:00<00:00, 36.73it/s]


Epoch: 5 Train loss: 0.9963


100%|██████████| 10/10 [00:00<00:00, 59.47it/s]


Epoch: 5 Test loss: 0.8373


100%|██████████| 18/18 [00:00<00:00, 31.63it/s]


Epoch: 6 Train loss: 0.9936


100%|██████████| 10/10 [00:00<00:00, 59.14it/s]


Epoch: 6 Test loss: 0.8340


100%|██████████| 18/18 [00:00<00:00, 28.71it/s]


Epoch: 7 Train loss: 0.9906


100%|██████████| 10/10 [00:00<00:00, 49.43it/s]


Epoch: 7 Test loss: 0.8330


100%|██████████| 18/18 [00:00<00:00, 37.83it/s]


Epoch: 1 Train loss: 1.2150


100%|██████████| 10/10 [00:00<00:00, 36.15it/s]


Epoch: 1 Test loss: 1.0412


100%|██████████| 18/18 [00:00<00:00, 25.76it/s]


Epoch: 2 Train loss: 1.0806


100%|██████████| 10/10 [00:00<00:00, 36.94it/s]


Epoch: 2 Test loss: 0.9639


100%|██████████| 18/18 [00:00<00:00, 30.67it/s]


Epoch: 3 Train loss: 1.0304


100%|██████████| 10/10 [00:00<00:00, 32.51it/s]


Epoch: 3 Test loss: 0.9314


100%|██████████| 18/18 [00:00<00:00, 26.02it/s]


Epoch: 4 Train loss: 1.0049


100%|██████████| 10/10 [00:00<00:00, 47.01it/s]


Epoch: 4 Test loss: 0.9192


100%|██████████| 18/18 [00:00<00:00, 36.03it/s]


Epoch: 5 Train loss: 0.9952


100%|██████████| 10/10 [00:00<00:00, 28.89it/s]


Epoch: 5 Test loss: 0.9147


100%|██████████| 18/18 [00:00<00:00, 35.45it/s]


Epoch: 6 Train loss: 0.9929


100%|██████████| 10/10 [00:00<00:00, 52.01it/s]


Epoch: 6 Test loss: 0.9135


100%|██████████| 18/18 [00:00<00:00, 24.17it/s]


Epoch: 7 Train loss: 0.9917


100%|██████████| 10/10 [00:00<00:00, 27.24it/s]


Epoch: 7 Test loss: 0.9125


100%|██████████| 18/18 [00:00<00:00, 28.70it/s]


Epoch: 1 Train loss: 1.2314


100%|██████████| 10/10 [00:00<00:00, 52.39it/s]


Epoch: 1 Test loss: 1.1515


100%|██████████| 18/18 [00:00<00:00, 40.50it/s]


Epoch: 2 Train loss: 1.0949


100%|██████████| 10/10 [00:00<00:00, 55.94it/s]


Epoch: 2 Test loss: 1.0126


100%|██████████| 18/18 [00:00<00:00, 39.46it/s]


Epoch: 3 Train loss: 1.0364


100%|██████████| 10/10 [00:00<00:00, 55.53it/s]


Epoch: 3 Test loss: 0.9506


100%|██████████| 18/18 [00:00<00:00, 39.65it/s]


Epoch: 4 Train loss: 1.0054


100%|██████████| 10/10 [00:00<00:00, 37.80it/s]


Epoch: 4 Test loss: 0.9252


100%|██████████| 18/18 [00:00<00:00, 32.75it/s]


Epoch: 5 Train loss: 0.9928


100%|██████████| 10/10 [00:00<00:00, 52.12it/s]


Epoch: 5 Test loss: 0.9159


100%|██████████| 18/18 [00:00<00:00, 27.47it/s]


Epoch: 6 Train loss: 0.9891


100%|██████████| 10/10 [00:00<00:00, 30.61it/s]


Epoch: 6 Test loss: 0.9132


100%|██████████| 18/18 [00:00<00:00, 28.98it/s]


Epoch: 7 Train loss: 0.9876


100%|██████████| 10/10 [00:00<00:00, 40.57it/s]


Epoch: 7 Test loss: 0.9121


100%|██████████| 18/18 [00:00<00:00, 25.78it/s]


Epoch: 1 Train loss: 1.2348


100%|██████████| 10/10 [00:00<00:00, 41.57it/s]


Epoch: 1 Test loss: 1.0444


100%|██████████| 18/18 [00:00<00:00, 21.95it/s]


Epoch: 2 Train loss: 1.1019


100%|██████████| 10/10 [00:00<00:00, 20.01it/s]


Epoch: 2 Test loss: 0.9713


100%|██████████| 18/18 [00:00<00:00, 39.16it/s]


Epoch: 3 Train loss: 1.0433


100%|██████████| 10/10 [00:00<00:00, 49.18it/s]


Epoch: 3 Test loss: 0.9361


100%|██████████| 18/18 [00:00<00:00, 34.36it/s]


Epoch: 4 Train loss: 1.0101


100%|██████████| 10/10 [00:00<00:00, 42.56it/s]


Epoch: 4 Test loss: 0.9206


100%|██████████| 18/18 [00:00<00:00, 31.68it/s]


Epoch: 5 Train loss: 0.9958


100%|██████████| 10/10 [00:00<00:00, 41.82it/s]


Epoch: 5 Test loss: 0.9144


100%|██████████| 18/18 [00:00<00:00, 28.19it/s]


Epoch: 6 Train loss: 0.9914


100%|██████████| 10/10 [00:00<00:00, 30.67it/s]


Epoch: 6 Test loss: 0.9121


100%|██████████| 18/18 [00:00<00:00, 28.35it/s]


Epoch: 7 Train loss: 0.9894


100%|██████████| 10/10 [00:00<00:00, 27.81it/s]


Epoch: 7 Test loss: 0.9114


100%|██████████| 18/18 [00:00<00:00, 27.67it/s]


Epoch: 1 Train loss: 1.2366


100%|██████████| 10/10 [00:00<00:00, 32.17it/s]


Epoch: 1 Test loss: 1.1167


100%|██████████| 18/18 [00:00<00:00, 33.70it/s]


Epoch: 2 Train loss: 1.1047


100%|██████████| 10/10 [00:00<00:00, 52.78it/s]


Epoch: 2 Test loss: 1.0036


100%|██████████| 18/18 [00:00<00:00, 31.21it/s]


Epoch: 3 Train loss: 1.0448


100%|██████████| 10/10 [00:00<00:00, 53.22it/s]


Epoch: 3 Test loss: 0.9504


100%|██████████| 18/18 [00:00<00:00, 38.70it/s]


Epoch: 4 Train loss: 1.0124


100%|██████████| 10/10 [00:00<00:00, 38.61it/s]


Epoch: 4 Test loss: 0.9273


100%|██████████| 18/18 [00:00<00:00, 39.32it/s]


Epoch: 5 Train loss: 0.9987


100%|██████████| 10/10 [00:00<00:00, 40.78it/s]


Epoch: 5 Test loss: 0.9180


100%|██████████| 18/18 [00:00<00:00, 40.89it/s]


Epoch: 6 Train loss: 0.9943


100%|██████████| 10/10 [00:00<00:00, 51.56it/s]


Epoch: 6 Test loss: 0.9147


100%|██████████| 18/18 [00:00<00:00, 38.38it/s]


Epoch: 7 Train loss: 0.9926


100%|██████████| 10/10 [00:00<00:00, 35.16it/s]


Epoch: 7 Test loss: 0.9136


100%|██████████| 18/18 [00:00<00:00, 28.19it/s]


Epoch: 1 Train loss: 1.2171


100%|██████████| 10/10 [00:00<00:00, 30.45it/s]


Epoch: 1 Test loss: 1.4085


100%|██████████| 18/18 [00:00<00:00, 27.30it/s]


Epoch: 2 Train loss: 1.0851


100%|██████████| 10/10 [00:00<00:00, 39.61it/s]


Epoch: 2 Test loss: 1.3310


100%|██████████| 18/18 [00:00<00:00, 25.43it/s]


Epoch: 3 Train loss: 1.0273


100%|██████████| 10/10 [00:00<00:00, 52.66it/s]


Epoch: 3 Test loss: 1.2989


100%|██████████| 18/18 [00:00<00:00, 33.03it/s]


Epoch: 4 Train loss: 1.0107


100%|██████████| 10/10 [00:00<00:00, 54.86it/s]


Epoch: 4 Test loss: 1.2867


100%|██████████| 18/18 [00:00<00:00, 27.79it/s]


Epoch: 5 Train loss: 0.9994


100%|██████████| 10/10 [00:00<00:00, 52.24it/s]


Epoch: 5 Test loss: 1.2823


100%|██████████| 18/18 [00:00<00:00, 36.97it/s]


Epoch: 6 Train loss: 0.9946


100%|██████████| 10/10 [00:00<00:00, 29.25it/s]


Epoch: 6 Test loss: 1.2814


100%|██████████| 18/18 [00:00<00:00, 30.70it/s]


Epoch: 7 Train loss: 0.9913


100%|██████████| 10/10 [00:00<00:00, 43.95it/s]


Epoch: 7 Test loss: 1.2818


100%|██████████| 18/18 [00:00<00:00, 24.79it/s]


Epoch: 1 Train loss: 1.2332


100%|██████████| 10/10 [00:00<00:00, 34.38it/s]


Epoch: 1 Test loss: 1.5187


100%|██████████| 18/18 [00:00<00:00, 29.44it/s]


Epoch: 2 Train loss: 1.1002


100%|██████████| 10/10 [00:00<00:00, 34.59it/s]


Epoch: 2 Test loss: 1.3804


100%|██████████| 18/18 [00:00<00:00, 30.79it/s]


Epoch: 3 Train loss: 1.0343


100%|██████████| 10/10 [00:00<00:00, 39.49it/s]


Epoch: 3 Test loss: 1.3186


100%|██████████| 18/18 [00:00<00:00, 27.75it/s]


Epoch: 4 Train loss: 1.0127


100%|██████████| 10/10 [00:00<00:00, 41.56it/s]


Epoch: 4 Test loss: 1.2933


100%|██████████| 18/18 [00:00<00:00, 38.16it/s]


Epoch: 5 Train loss: 0.9984


100%|██████████| 10/10 [00:00<00:00, 52.48it/s]


Epoch: 5 Test loss: 1.2842


100%|██████████| 18/18 [00:00<00:00, 35.99it/s]


Epoch: 6 Train loss: 0.9924


100%|██████████| 10/10 [00:00<00:00, 34.59it/s]


Epoch: 6 Test loss: 1.2809


100%|██████████| 18/18 [00:00<00:00, 31.75it/s]


Epoch: 7 Train loss: 0.9884


100%|██████████| 10/10 [00:00<00:00, 44.35it/s]


Epoch: 7 Test loss: 1.2803


100%|██████████| 18/18 [00:00<00:00, 32.10it/s]


Epoch: 1 Train loss: 1.2360


100%|██████████| 10/10 [00:00<00:00, 53.33it/s]


Epoch: 1 Test loss: 1.4110


100%|██████████| 18/18 [00:00<00:00, 39.29it/s]


Epoch: 2 Train loss: 1.1066


100%|██████████| 10/10 [00:00<00:00, 35.82it/s]


Epoch: 2 Test loss: 1.3384


100%|██████████| 18/18 [00:00<00:00, 34.97it/s]


Epoch: 3 Train loss: 1.0403


100%|██████████| 10/10 [00:00<00:00, 52.66it/s]


Epoch: 3 Test loss: 1.3036


100%|██████████| 18/18 [00:00<00:00, 32.81it/s]


Epoch: 4 Train loss: 1.0168


100%|██████████| 10/10 [00:00<00:00, 53.21it/s]


Epoch: 4 Test loss: 1.2884


100%|██████████| 18/18 [00:00<00:00, 19.92it/s]


Epoch: 5 Train loss: 1.0008


100%|██████████| 10/10 [00:00<00:00, 47.82it/s]


Epoch: 5 Test loss: 1.2823


100%|██████████| 18/18 [00:00<00:00, 33.92it/s]


Epoch: 6 Train loss: 0.9943


100%|██████████| 10/10 [00:00<00:00, 40.30it/s]


Epoch: 6 Test loss: 1.2801


100%|██████████| 18/18 [00:00<00:00, 34.75it/s]


Epoch: 7 Train loss: 0.9897


100%|██████████| 10/10 [00:00<00:00, 38.85it/s]


Epoch: 7 Test loss: 1.2794


100%|██████████| 18/18 [00:00<00:00, 27.30it/s]


Epoch: 1 Train loss: 1.2356


100%|██████████| 10/10 [00:00<00:00, 29.96it/s]


Epoch: 1 Test loss: 1.4213


100%|██████████| 18/18 [00:00<00:00, 39.30it/s]


Epoch: 2 Train loss: 1.1048


100%|██████████| 10/10 [00:00<00:00, 45.78it/s]


Epoch: 2 Test loss: 1.3414


100%|██████████| 18/18 [00:00<00:00, 36.92it/s]


Epoch: 3 Train loss: 1.0386


100%|██████████| 10/10 [00:00<00:00, 36.76it/s]


Epoch: 3 Test loss: 1.3046


100%|██████████| 18/18 [00:00<00:00, 31.44it/s]


Epoch: 4 Train loss: 1.0155


100%|██████████| 10/10 [00:00<00:00, 53.30it/s]


Epoch: 4 Test loss: 1.2888


100%|██████████| 18/18 [00:00<00:00, 39.01it/s]


Epoch: 5 Train loss: 1.0002


100%|██████████| 10/10 [00:00<00:00, 51.68it/s]


Epoch: 5 Test loss: 1.2827


100%|██████████| 18/18 [00:00<00:00, 33.87it/s]


Epoch: 6 Train loss: 0.9937


100%|██████████| 10/10 [00:00<00:00, 28.58it/s]


Epoch: 6 Test loss: 1.2806


100%|██████████| 18/18 [00:00<00:00, 31.99it/s]


Epoch: 7 Train loss: 0.9893


100%|██████████| 10/10 [00:00<00:00, 44.17it/s]


Epoch: 7 Test loss: 1.2800


#### Converter2

In [30]:
from fmri_reconstruction_with_dmvae.models.converter import Converter2

# Hyperparameters for the Converter2 subject-to-subject mapping.
hidden_dim = 4096

lr = 1e-4
# torch.optim.Adam applies weight_decay as a coupled L2 term on the gradients.
# NOTE(review): in the recorded outputs of this cell the final test loss is identical
# for every source subject of a given target (0.9840 for target 1, 1.6033 for target 2,
# 0.8327 for target 3, 0.9128 for target 4, 1.2795 for target 5), which suggests the
# converter collapses to a source-independent prediction under this penalty.
# Consider a much smaller value before trusting the correlations saved below.
weight_decay = 1e-1

epochs = 7

conversion_output_path = "/home/acg17270jl/projects/fmri-reconstruction-with-dmvae/results/deeprecon/conversion/"

pattern_corr_result = []
profile_corr_result = []

# Train and evaluate one converter per ordered (target, source) pair of distinct subjects.
# `train` reads `model`, `optimizer`, `loss_fn`, `s_s` and `s_t` as notebook globals
# (and `test` presumably does too), so they are bound here at cell scope on purpose.
for subj_target, subj_source in permutations(subj_list, 2):
    s_t = f"{int(subj_target):02d}"
    s_s = f"{int(subj_source):02d}"

    # Seed the global torch RNG so the converter's weight initialisation is reproducible
    # for each pair (assuming Converter2 initialises its layers from the default torch RNG);
    # the DataLoader generator `g` below only controls the shuffle order.
    torch.manual_seed(42)
    model = Converter2(all_subj_num_voxels[f"subj{s_s}"], all_subj_num_voxels[f"subj{s_t}"], hidden_dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)  # L2
    loss_fn = torch.nn.MSELoss()

    g.manual_seed(42)
    for epoch in range(1, epochs + 1):
        train_loss = train(epoch)
        # Overwritten every epoch: only the final epoch's test outputs are evaluated.
        test_loss, target_brain, recon_brain, label = test(epoch)

    # `test` returns per-batch tensors; stack them into full arrays on the CPU.
    target_brain = torch.cat(target_brain, dim=0).cpu().numpy()
    recon_brain = torch.cat(recon_brain, dim=0).cpu().numpy()
    label = torch.cat(label, dim=0).cpu().numpy()

    # NOTE(review): Image_idx is 1-based while Voxel_idx below is 0-based; kept as-is so
    # the saved CSV layout does not change.
    pattern_corr = calculate_pattern_correlation(target_brain, recon_brain, label)
    for i, corr in enumerate(pattern_corr):
        pattern_corr_result.append({
            'Subject_target': subj_target,
            'Subject_source': subj_source,
            'Correlation': corr,
            'Image_idx': i+1}
        )

    profile_corr = calculate_profile_correlation(target_brain, recon_brain, label)
    for i, corr in enumerate(profile_corr):
        profile_corr_result.append({
            'Subject_target': subj_target,
            'Subject_source': subj_source,
            'Correlation': corr,
            'Voxel_idx': i}
        )

save_result(pattern_corr_result, conversion_output_path, "pattern_correlation_converter2.csv")
save_result(profile_corr_result, conversion_output_path, "profile_correlation_converter2.csv")

100%|██████████| 18/18 [00:00<00:00, 39.99it/s]


Epoch: 1 Train loss: 1.0336


100%|██████████| 10/10 [00:00<00:00, 54.31it/s]


Epoch: 1 Test loss: 1.0126


100%|██████████| 18/18 [00:00<00:00, 36.23it/s]


Epoch: 2 Train loss: 1.0105


100%|██████████| 10/10 [00:00<00:00, 55.37it/s]


Epoch: 2 Test loss: 0.9925


100%|██████████| 18/18 [00:00<00:00, 35.63it/s]


Epoch: 3 Train loss: 1.0001


100%|██████████| 10/10 [00:00<00:00, 53.76it/s]


Epoch: 3 Test loss: 0.9863


100%|██████████| 18/18 [00:00<00:00, 44.41it/s]


Epoch: 4 Train loss: 1.0007


100%|██████████| 10/10 [00:00<00:00, 56.55it/s]


Epoch: 4 Test loss: 0.9846


100%|██████████| 18/18 [00:00<00:00, 45.43it/s]


Epoch: 5 Train loss: 1.0004


100%|██████████| 10/10 [00:00<00:00, 57.80it/s]


Epoch: 5 Test loss: 0.9841


100%|██████████| 18/18 [00:00<00:00, 33.45it/s]


Epoch: 6 Train loss: 1.0050


100%|██████████| 10/10 [00:00<00:00, 43.27it/s]


Epoch: 6 Test loss: 0.9840


100%|██████████| 18/18 [00:00<00:00, 39.31it/s]


Epoch: 7 Train loss: 0.9949


100%|██████████| 10/10 [00:00<00:00, 46.45it/s]


Epoch: 7 Test loss: 0.9840


100%|██████████| 18/18 [00:00<00:00, 34.57it/s]


Epoch: 1 Train loss: 1.0342


100%|██████████| 10/10 [00:00<00:00, 55.62it/s]


Epoch: 1 Test loss: 0.9997


100%|██████████| 18/18 [00:00<00:00, 36.26it/s]


Epoch: 2 Train loss: 1.0113


100%|██████████| 10/10 [00:00<00:00, 55.08it/s]


Epoch: 2 Test loss: 0.9890


100%|██████████| 18/18 [00:00<00:00, 44.65it/s]


Epoch: 3 Train loss: 1.0005


100%|██████████| 10/10 [00:00<00:00, 58.90it/s]


Epoch: 3 Test loss: 0.9855


100%|██████████| 18/18 [00:00<00:00, 44.71it/s]


Epoch: 4 Train loss: 1.0009


100%|██████████| 10/10 [00:00<00:00, 56.01it/s]


Epoch: 4 Test loss: 0.9844


100%|██████████| 18/18 [00:00<00:00, 44.68it/s]


Epoch: 5 Train loss: 1.0005


100%|██████████| 10/10 [00:00<00:00, 30.82it/s]


Epoch: 5 Test loss: 0.9841


100%|██████████| 18/18 [00:00<00:00, 28.81it/s]


Epoch: 6 Train loss: 1.0050


100%|██████████| 10/10 [00:00<00:00, 31.12it/s]


Epoch: 6 Test loss: 0.9840


100%|██████████| 18/18 [00:00<00:00, 28.95it/s]


Epoch: 7 Train loss: 0.9949


100%|██████████| 10/10 [00:00<00:00, 53.30it/s]


Epoch: 7 Test loss: 0.9840


100%|██████████| 18/18 [00:00<00:00, 33.45it/s]


Epoch: 1 Train loss: 1.0340


100%|██████████| 10/10 [00:00<00:00, 49.53it/s]


Epoch: 1 Test loss: 1.0009


100%|██████████| 18/18 [00:00<00:00, 38.27it/s]


Epoch: 2 Train loss: 1.0111


100%|██████████| 10/10 [00:00<00:00, 56.76it/s]


Epoch: 2 Test loss: 0.9892


100%|██████████| 18/18 [00:00<00:00, 45.29it/s]


Epoch: 3 Train loss: 1.0004


100%|██████████| 10/10 [00:00<00:00, 59.38it/s]


Epoch: 3 Test loss: 0.9855


100%|██████████| 18/18 [00:00<00:00, 45.11it/s]


Epoch: 4 Train loss: 1.0009


100%|██████████| 10/10 [00:00<00:00, 59.12it/s]


Epoch: 4 Test loss: 0.9844


100%|██████████| 18/18 [00:00<00:00, 38.66it/s]


Epoch: 5 Train loss: 1.0005


100%|██████████| 10/10 [00:00<00:00, 55.78it/s]


Epoch: 5 Test loss: 0.9841


100%|██████████| 18/18 [00:00<00:00, 43.73it/s]


Epoch: 6 Train loss: 1.0050


100%|██████████| 10/10 [00:00<00:00, 56.46it/s]


Epoch: 6 Test loss: 0.9840


100%|██████████| 18/18 [00:00<00:00, 37.40it/s]


Epoch: 7 Train loss: 0.9949


100%|██████████| 10/10 [00:00<00:00, 52.51it/s]


Epoch: 7 Test loss: 0.9840


100%|██████████| 18/18 [00:00<00:00, 29.99it/s]


Epoch: 1 Train loss: 1.0344


100%|██████████| 10/10 [00:00<00:00, 56.28it/s]


Epoch: 1 Test loss: 1.0081


100%|██████████| 18/18 [00:00<00:00, 38.95it/s]


Epoch: 2 Train loss: 1.0115


100%|██████████| 10/10 [00:00<00:00, 57.25it/s]


Epoch: 2 Test loss: 0.9916


100%|██████████| 18/18 [00:00<00:00, 42.44it/s]


Epoch: 3 Train loss: 1.0006


100%|██████████| 10/10 [00:00<00:00, 58.96it/s]


Epoch: 3 Test loss: 0.9862


100%|██████████| 18/18 [00:00<00:00, 35.82it/s]


Epoch: 4 Train loss: 1.0010


100%|██████████| 10/10 [00:00<00:00, 30.86it/s]


Epoch: 4 Test loss: 0.9846


100%|██████████| 18/18 [00:00<00:00, 31.78it/s]


Epoch: 5 Train loss: 1.0005


100%|██████████| 10/10 [00:00<00:00, 41.23it/s]


Epoch: 5 Test loss: 0.9842


100%|██████████| 18/18 [00:00<00:00, 30.47it/s]


Epoch: 6 Train loss: 1.0050


100%|██████████| 10/10 [00:00<00:00, 39.85it/s]


Epoch: 6 Test loss: 0.9841


100%|██████████| 18/18 [00:00<00:00, 28.51it/s]


Epoch: 7 Train loss: 0.9949


100%|██████████| 10/10 [00:00<00:00, 26.96it/s]


Epoch: 7 Test loss: 0.9840


100%|██████████| 18/18 [00:00<00:00, 21.67it/s]


Epoch: 1 Train loss: 1.0347


100%|██████████| 10/10 [00:00<00:00, 30.02it/s]


Epoch: 1 Test loss: 1.6192


100%|██████████| 18/18 [00:00<00:00, 36.64it/s]


Epoch: 2 Train loss: 1.0077


100%|██████████| 10/10 [00:00<00:00, 32.57it/s]


Epoch: 2 Test loss: 1.6079


100%|██████████| 18/18 [00:00<00:00, 40.54it/s]


Epoch: 3 Train loss: 1.0032


100%|██████████| 10/10 [00:00<00:00, 42.50it/s]


Epoch: 3 Test loss: 1.6045


100%|██████████| 18/18 [00:00<00:00, 39.19it/s]


Epoch: 4 Train loss: 1.0003


100%|██████████| 10/10 [00:00<00:00, 52.64it/s]


Epoch: 4 Test loss: 1.6036


100%|██████████| 18/18 [00:00<00:00, 41.49it/s]


Epoch: 5 Train loss: 1.0018


100%|██████████| 10/10 [00:00<00:00, 51.88it/s]


Epoch: 5 Test loss: 1.6034


100%|██████████| 18/18 [00:00<00:00, 39.11it/s]


Epoch: 6 Train loss: 0.9999


100%|██████████| 10/10 [00:00<00:00, 52.40it/s]


Epoch: 6 Test loss: 1.6033


100%|██████████| 18/18 [00:00<00:00, 31.47it/s]


Epoch: 7 Train loss: 1.0004


100%|██████████| 10/10 [00:00<00:00, 46.66it/s]


Epoch: 7 Test loss: 1.6033


100%|██████████| 18/18 [00:00<00:00, 22.13it/s]


Epoch: 1 Train loss: 1.0366


100%|██████████| 10/10 [00:00<00:00, 40.99it/s]


Epoch: 1 Test loss: 1.6189


100%|██████████| 18/18 [00:00<00:00, 28.47it/s]


Epoch: 2 Train loss: 1.0094


100%|██████████| 10/10 [00:00<00:00, 41.24it/s]


Epoch: 2 Test loss: 1.6082


100%|██████████| 18/18 [00:00<00:00, 32.78it/s]


Epoch: 3 Train loss: 1.0039


100%|██████████| 10/10 [00:00<00:00, 41.42it/s]


Epoch: 3 Test loss: 1.6047


100%|██████████| 18/18 [00:00<00:00, 28.38it/s]


Epoch: 4 Train loss: 1.0006


100%|██████████| 10/10 [00:00<00:00, 39.20it/s]


Epoch: 4 Test loss: 1.6037


100%|██████████| 18/18 [00:00<00:00, 32.27it/s]


Epoch: 5 Train loss: 1.0019


100%|██████████| 10/10 [00:00<00:00, 32.67it/s]


Epoch: 5 Test loss: 1.6034


100%|██████████| 18/18 [00:00<00:00, 32.73it/s]


Epoch: 6 Train loss: 0.9999


100%|██████████| 10/10 [00:00<00:00, 50.55it/s]


Epoch: 6 Test loss: 1.6033


100%|██████████| 18/18 [00:00<00:00, 38.35it/s]


Epoch: 7 Train loss: 1.0004


100%|██████████| 10/10 [00:00<00:00, 42.86it/s]


Epoch: 7 Test loss: 1.6033


100%|██████████| 18/18 [00:00<00:00, 25.25it/s]


Epoch: 1 Train loss: 1.0364


100%|██████████| 10/10 [00:00<00:00, 46.45it/s]


Epoch: 1 Test loss: 1.6201


100%|██████████| 18/18 [00:00<00:00, 41.94it/s]


Epoch: 2 Train loss: 1.0092


100%|██████████| 10/10 [00:00<00:00, 54.74it/s]


Epoch: 2 Test loss: 1.6085


100%|██████████| 18/18 [00:00<00:00, 40.70it/s]


Epoch: 3 Train loss: 1.0038


100%|██████████| 10/10 [00:00<00:00, 37.38it/s]


Epoch: 3 Test loss: 1.6048


100%|██████████| 18/18 [00:00<00:00, 32.71it/s]


Epoch: 4 Train loss: 1.0005


100%|██████████| 10/10 [00:00<00:00, 34.18it/s]


Epoch: 4 Test loss: 1.6037


100%|██████████| 18/18 [00:00<00:00, 25.19it/s]


Epoch: 5 Train loss: 1.0018


100%|██████████| 10/10 [00:00<00:00, 41.45it/s]


Epoch: 5 Test loss: 1.6034


100%|██████████| 18/18 [00:00<00:00, 34.02it/s]


Epoch: 6 Train loss: 0.9999


100%|██████████| 10/10 [00:00<00:00, 31.12it/s]


Epoch: 6 Test loss: 1.6033


100%|██████████| 18/18 [00:00<00:00, 27.84it/s]


Epoch: 7 Train loss: 1.0004


100%|██████████| 10/10 [00:00<00:00, 28.18it/s]


Epoch: 7 Test loss: 1.6033


100%|██████████| 18/18 [00:00<00:00, 23.86it/s]


Epoch: 1 Train loss: 1.0369


100%|██████████| 10/10 [00:00<00:00, 37.10it/s]


Epoch: 1 Test loss: 1.6276


100%|██████████| 18/18 [00:00<00:00, 31.63it/s]


Epoch: 2 Train loss: 1.0097


100%|██████████| 10/10 [00:00<00:00, 38.07it/s]


Epoch: 2 Test loss: 1.6109


100%|██████████| 18/18 [00:00<00:00, 18.68it/s]


Epoch: 3 Train loss: 1.0040


100%|██████████| 10/10 [00:00<00:00, 48.47it/s]


Epoch: 3 Test loss: 1.6055


100%|██████████| 18/18 [00:00<00:00, 32.25it/s]


Epoch: 4 Train loss: 1.0006


100%|██████████| 10/10 [00:00<00:00, 39.31it/s]


Epoch: 4 Test loss: 1.6039


100%|██████████| 18/18 [00:00<00:00, 38.38it/s]


Epoch: 5 Train loss: 1.0019


100%|██████████| 10/10 [00:00<00:00, 50.15it/s]


Epoch: 5 Test loss: 1.6035


100%|██████████| 18/18 [00:00<00:00, 33.33it/s]


Epoch: 6 Train loss: 1.0000


100%|██████████| 10/10 [00:00<00:00, 42.58it/s]


Epoch: 6 Test loss: 1.6034


100%|██████████| 18/18 [00:00<00:00, 30.89it/s]


Epoch: 7 Train loss: 1.0004


100%|██████████| 10/10 [00:00<00:00, 29.55it/s]


Epoch: 7 Test loss: 1.6033


100%|██████████| 18/18 [00:00<00:00, 35.70it/s]


Epoch: 1 Train loss: 1.0331


100%|██████████| 10/10 [00:00<00:00, 34.69it/s]


Epoch: 1 Test loss: 0.8488


100%|██████████| 18/18 [00:00<00:00, 31.65it/s]


Epoch: 2 Train loss: 1.0091


100%|██████████| 10/10 [00:00<00:00, 50.34it/s]


Epoch: 2 Test loss: 0.8374


100%|██████████| 18/18 [00:00<00:00, 33.53it/s]


Epoch: 3 Train loss: 1.0036


100%|██████████| 10/10 [00:00<00:00, 50.98it/s]


Epoch: 3 Test loss: 0.8340


100%|██████████| 18/18 [00:00<00:00, 40.29it/s]


Epoch: 4 Train loss: 1.0024


100%|██████████| 10/10 [00:00<00:00, 48.51it/s]


Epoch: 4 Test loss: 0.8330


100%|██████████| 18/18 [00:00<00:00, 40.28it/s]


Epoch: 5 Train loss: 0.9963


100%|██████████| 10/10 [00:00<00:00, 43.14it/s]


Epoch: 5 Test loss: 0.8328


100%|██████████| 18/18 [00:00<00:00, 25.00it/s]


Epoch: 6 Train loss: 0.9988


100%|██████████| 10/10 [00:00<00:00, 45.03it/s]


Epoch: 6 Test loss: 0.8327


100%|██████████| 18/18 [00:00<00:00, 40.24it/s]


Epoch: 7 Train loss: 0.9982


100%|██████████| 10/10 [00:00<00:00, 42.29it/s]


Epoch: 7 Test loss: 0.8327


100%|██████████| 18/18 [00:00<00:00, 36.17it/s]


Epoch: 1 Train loss: 1.0343


100%|██████████| 10/10 [00:00<00:00, 24.88it/s]


Epoch: 1 Test loss: 0.8611


100%|██████████| 18/18 [00:00<00:00, 31.56it/s]


Epoch: 2 Train loss: 1.0100


100%|██████████| 10/10 [00:00<00:00, 51.74it/s]


Epoch: 2 Test loss: 0.8411


100%|██████████| 18/18 [00:00<00:00, 31.23it/s]


Epoch: 3 Train loss: 1.0039


100%|██████████| 10/10 [00:00<00:00, 42.29it/s]


Epoch: 3 Test loss: 0.8349


100%|██████████| 18/18 [00:00<00:00, 41.77it/s]


Epoch: 4 Train loss: 1.0024


100%|██████████| 10/10 [00:00<00:00, 53.44it/s]


Epoch: 4 Test loss: 0.8332


100%|██████████| 18/18 [00:00<00:00, 40.52it/s]


Epoch: 5 Train loss: 0.9962


100%|██████████| 10/10 [00:00<00:00, 39.47it/s]


Epoch: 5 Test loss: 0.8328


100%|██████████| 18/18 [00:00<00:00, 33.72it/s]


Epoch: 6 Train loss: 0.9988


100%|██████████| 10/10 [00:00<00:00, 54.74it/s]


Epoch: 6 Test loss: 0.8327


100%|██████████| 18/18 [00:00<00:00, 39.39it/s]


Epoch: 7 Train loss: 0.9981


100%|██████████| 10/10 [00:00<00:00, 45.73it/s]


Epoch: 7 Test loss: 0.8327


100%|██████████| 18/18 [00:00<00:00, 22.51it/s]


Epoch: 1 Train loss: 1.0347


100%|██████████| 10/10 [00:00<00:00, 36.39it/s]


Epoch: 1 Test loss: 0.8496


100%|██████████| 18/18 [00:00<00:00, 39.89it/s]


Epoch: 2 Train loss: 1.0105


100%|██████████| 10/10 [00:00<00:00, 54.12it/s]


Epoch: 2 Test loss: 0.8379


100%|██████████| 18/18 [00:00<00:00, 34.86it/s]


Epoch: 3 Train loss: 1.0043


100%|██████████| 10/10 [00:00<00:00, 51.33it/s]


Epoch: 3 Test loss: 0.8341


100%|██████████| 18/18 [00:00<00:00, 42.00it/s]


Epoch: 4 Train loss: 1.0026


100%|██████████| 10/10 [00:00<00:00, 43.38it/s]


Epoch: 4 Test loss: 0.8330


100%|██████████| 18/18 [00:00<00:00, 37.17it/s]


Epoch: 5 Train loss: 0.9963


100%|██████████| 10/10 [00:00<00:00, 58.22it/s]


Epoch: 5 Test loss: 0.8328


100%|██████████| 18/18 [00:00<00:00, 38.23it/s]


Epoch: 6 Train loss: 0.9989


100%|██████████| 10/10 [00:00<00:00, 52.64it/s]


Epoch: 6 Test loss: 0.8327


100%|██████████| 18/18 [00:00<00:00, 36.43it/s]


Epoch: 7 Train loss: 0.9982


100%|██████████| 10/10 [00:00<00:00, 27.00it/s]


Epoch: 7 Test loss: 0.8327


100%|██████████| 18/18 [00:00<00:00, 37.09it/s]


Epoch: 1 Train loss: 1.0349


100%|██████████| 10/10 [00:00<00:00, 34.40it/s]


Epoch: 1 Test loss: 0.8569


100%|██████████| 18/18 [00:00<00:00, 34.32it/s]


Epoch: 2 Train loss: 1.0109


100%|██████████| 10/10 [00:00<00:00, 50.17it/s]


Epoch: 2 Test loss: 0.8403


100%|██████████| 18/18 [00:00<00:00, 36.40it/s]


Epoch: 3 Train loss: 1.0044


100%|██████████| 10/10 [00:00<00:00, 53.24it/s]


Epoch: 3 Test loss: 0.8349


100%|██████████| 18/18 [00:00<00:00, 33.81it/s]


Epoch: 4 Train loss: 1.0027


100%|██████████| 10/10 [00:00<00:00, 49.79it/s]


Epoch: 4 Test loss: 0.8333


100%|██████████| 18/18 [00:00<00:00, 35.95it/s]


Epoch: 5 Train loss: 0.9964


100%|██████████| 10/10 [00:00<00:00, 52.95it/s]


Epoch: 5 Test loss: 0.8328


100%|██████████| 18/18 [00:00<00:00, 24.40it/s]


Epoch: 6 Train loss: 0.9989


100%|██████████| 10/10 [00:00<00:00, 54.78it/s]


Epoch: 6 Test loss: 0.8327


100%|██████████| 18/18 [00:00<00:00, 36.04it/s]


Epoch: 7 Train loss: 0.9982


100%|██████████| 10/10 [00:00<00:00, 32.47it/s]


Epoch: 7 Test loss: 0.8327


100%|██████████| 18/18 [00:00<00:00, 22.51it/s]


Epoch: 1 Train loss: 1.0302


100%|██████████| 10/10 [00:00<00:00, 49.42it/s]


Epoch: 1 Test loss: 0.9290


100%|██████████| 18/18 [00:00<00:00, 39.11it/s]


Epoch: 2 Train loss: 1.0057


100%|██████████| 10/10 [00:00<00:00, 37.68it/s]


Epoch: 2 Test loss: 0.9175


100%|██████████| 18/18 [00:00<00:00, 35.22it/s]


Epoch: 3 Train loss: 1.0040


100%|██████████| 10/10 [00:00<00:00, 51.87it/s]


Epoch: 3 Test loss: 0.9140


100%|██████████| 18/18 [00:00<00:00, 33.04it/s]


Epoch: 4 Train loss: 0.9994


100%|██████████| 10/10 [00:00<00:00, 51.35it/s]


Epoch: 4 Test loss: 0.9131


100%|██████████| 18/18 [00:00<00:00, 35.69it/s]


Epoch: 5 Train loss: 0.9982


100%|██████████| 10/10 [00:00<00:00, 49.42it/s]


Epoch: 5 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 26.95it/s]


Epoch: 6 Train loss: 0.9994


100%|██████████| 10/10 [00:00<00:00, 40.05it/s]


Epoch: 6 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 42.01it/s]


Epoch: 7 Train loss: 0.9999


100%|██████████| 10/10 [00:00<00:00, 38.30it/s]


Epoch: 7 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 21.83it/s]


Epoch: 1 Train loss: 1.0316


100%|██████████| 10/10 [00:00<00:00, 29.02it/s]


Epoch: 1 Test loss: 0.9413


100%|██████████| 18/18 [00:00<00:00, 29.45it/s]


Epoch: 2 Train loss: 1.0067


100%|██████████| 10/10 [00:00<00:00, 30.58it/s]


Epoch: 2 Test loss: 0.9212


100%|██████████| 18/18 [00:00<00:00, 28.02it/s]


Epoch: 3 Train loss: 1.0042


100%|██████████| 10/10 [00:00<00:00, 38.69it/s]


Epoch: 3 Test loss: 0.9150


100%|██████████| 18/18 [00:00<00:00, 30.17it/s]


Epoch: 4 Train loss: 0.9994


100%|██████████| 10/10 [00:00<00:00, 28.72it/s]


Epoch: 4 Test loss: 0.9133


100%|██████████| 18/18 [00:00<00:00, 29.16it/s]


Epoch: 5 Train loss: 0.9982


100%|██████████| 10/10 [00:00<00:00, 35.24it/s]


Epoch: 5 Test loss: 0.9129


100%|██████████| 18/18 [00:00<00:00, 42.90it/s]


Epoch: 6 Train loss: 0.9994


100%|██████████| 10/10 [00:00<00:00, 54.69it/s]


Epoch: 6 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 36.62it/s]


Epoch: 7 Train loss: 0.9998


100%|██████████| 10/10 [00:00<00:00, 47.43it/s]


Epoch: 7 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 36.48it/s]


Epoch: 1 Train loss: 1.0323


100%|██████████| 10/10 [00:00<00:00, 36.49it/s]


Epoch: 1 Test loss: 0.9285


100%|██████████| 18/18 [00:00<00:00, 33.77it/s]


Epoch: 2 Train loss: 1.0075


100%|██████████| 10/10 [00:00<00:00, 44.26it/s]


Epoch: 2 Test loss: 0.9177


100%|██████████| 18/18 [00:00<00:00, 35.05it/s]


Epoch: 3 Train loss: 1.0047


100%|██████████| 10/10 [00:00<00:00, 40.52it/s]


Epoch: 3 Test loss: 0.9142


100%|██████████| 18/18 [00:00<00:00, 42.56it/s]


Epoch: 4 Train loss: 0.9997


100%|██████████| 10/10 [00:00<00:00, 48.87it/s]


Epoch: 4 Test loss: 0.9131


100%|██████████| 18/18 [00:00<00:00, 32.12it/s]


Epoch: 5 Train loss: 0.9983


100%|██████████| 10/10 [00:00<00:00, 49.31it/s]


Epoch: 5 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 36.23it/s]


Epoch: 6 Train loss: 0.9994


100%|██████████| 10/10 [00:00<00:00, 54.90it/s]


Epoch: 6 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 34.73it/s]


Epoch: 7 Train loss: 0.9999


100%|██████████| 10/10 [00:00<00:00, 48.15it/s]


Epoch: 7 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 38.26it/s]


Epoch: 1 Train loss: 1.0325


100%|██████████| 10/10 [00:00<00:00, 31.10it/s]


Epoch: 1 Test loss: 0.9370


100%|██████████| 18/18 [00:00<00:00, 35.21it/s]


Epoch: 2 Train loss: 1.0077


100%|██████████| 10/10 [00:00<00:00, 53.23it/s]


Epoch: 2 Test loss: 0.9204


100%|██████████| 18/18 [00:00<00:00, 42.85it/s]


Epoch: 3 Train loss: 1.0048


100%|██████████| 10/10 [00:00<00:00, 33.59it/s]


Epoch: 3 Test loss: 0.9150


100%|██████████| 18/18 [00:00<00:00, 41.88it/s]


Epoch: 4 Train loss: 0.9997


100%|██████████| 10/10 [00:00<00:00, 38.50it/s]


Epoch: 4 Test loss: 0.9133


100%|██████████| 18/18 [00:00<00:00, 42.28it/s]


Epoch: 5 Train loss: 0.9984


100%|██████████| 10/10 [00:00<00:00, 43.93it/s]


Epoch: 5 Test loss: 0.9129


100%|██████████| 18/18 [00:00<00:00, 35.12it/s]


Epoch: 6 Train loss: 0.9995


100%|██████████| 10/10 [00:00<00:00, 45.71it/s]


Epoch: 6 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 41.75it/s]


Epoch: 7 Train loss: 0.9999


100%|██████████| 10/10 [00:00<00:00, 26.97it/s]


Epoch: 7 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 38.12it/s]


Epoch: 1 Train loss: 1.0321


100%|██████████| 10/10 [00:00<00:00, 32.18it/s]


Epoch: 1 Test loss: 1.2958


100%|██████████| 18/18 [00:00<00:00, 34.76it/s]


Epoch: 2 Train loss: 1.0097


100%|██████████| 10/10 [00:00<00:00, 53.24it/s]


Epoch: 2 Test loss: 1.2842


100%|██████████| 18/18 [00:00<00:00, 43.30it/s]


Epoch: 3 Train loss: 0.9998


100%|██████████| 10/10 [00:00<00:00, 55.16it/s]


Epoch: 3 Test loss: 1.2808


100%|██████████| 18/18 [00:00<00:00, 43.40it/s]


Epoch: 4 Train loss: 1.0040


100%|██████████| 10/10 [00:00<00:00, 34.14it/s]


Epoch: 4 Test loss: 1.2798


100%|██████████| 18/18 [00:00<00:00, 33.84it/s]


Epoch: 5 Train loss: 1.0006


100%|██████████| 10/10 [00:00<00:00, 56.07it/s]


Epoch: 5 Test loss: 1.2795


100%|██████████| 18/18 [00:01<00:00, 16.55it/s]


Epoch: 6 Train loss: 0.9994


100%|██████████| 10/10 [00:00<00:00, 32.28it/s]


Epoch: 6 Test loss: 1.2795


100%|██████████| 18/18 [00:00<00:00, 42.96it/s]


Epoch: 7 Train loss: 0.9971


100%|██████████| 10/10 [00:00<00:00, 25.04it/s]


Epoch: 7 Test loss: 1.2795


100%|██████████| 18/18 [00:00<00:00, 38.01it/s]


Epoch: 1 Train loss: 1.0335


100%|██████████| 10/10 [00:00<00:00, 29.02it/s]


Epoch: 1 Test loss: 1.3080


100%|██████████| 18/18 [00:00<00:00, 38.87it/s]


Epoch: 2 Train loss: 1.0107


100%|██████████| 10/10 [00:00<00:00, 36.94it/s]


Epoch: 2 Test loss: 1.2880


100%|██████████| 18/18 [00:00<00:00, 32.66it/s]


Epoch: 3 Train loss: 1.0002


100%|██████████| 10/10 [00:00<00:00, 28.81it/s]


Epoch: 3 Test loss: 1.2818


100%|██████████| 18/18 [00:00<00:00, 38.99it/s]


Epoch: 4 Train loss: 1.0040


100%|██████████| 10/10 [00:00<00:00, 47.75it/s]


Epoch: 4 Test loss: 1.2800


100%|██████████| 18/18 [00:00<00:00, 34.78it/s]


Epoch: 5 Train loss: 1.0006


100%|██████████| 10/10 [00:00<00:00, 27.84it/s]


Epoch: 5 Test loss: 1.2796


100%|██████████| 18/18 [00:00<00:00, 33.88it/s]


Epoch: 6 Train loss: 0.9994


100%|██████████| 10/10 [00:00<00:00, 55.29it/s]


Epoch: 6 Test loss: 1.2795


100%|██████████| 18/18 [00:00<00:00, 34.41it/s]


Epoch: 7 Train loss: 0.9971


100%|██████████| 10/10 [00:00<00:00, 43.87it/s]


Epoch: 7 Test loss: 1.2795


100%|██████████| 18/18 [00:00<00:00, 27.73it/s]


Epoch: 1 Train loss: 1.0336


100%|██████████| 10/10 [00:00<00:00, 38.61it/s]


Epoch: 1 Test loss: 1.2950


100%|██████████| 18/18 [00:00<00:00, 31.83it/s]


Epoch: 2 Train loss: 1.0113


100%|██████████| 10/10 [00:00<00:00, 39.11it/s]


Epoch: 2 Test loss: 1.2844


100%|██████████| 18/18 [00:00<00:00, 27.57it/s]


Epoch: 3 Train loss: 1.0005


100%|██████████| 10/10 [00:00<00:00, 38.90it/s]


Epoch: 3 Test loss: 1.2809


100%|██████████| 18/18 [00:00<00:00, 31.74it/s]


Epoch: 4 Train loss: 1.0042


100%|██████████| 10/10 [00:00<00:00, 29.67it/s]


Epoch: 4 Test loss: 1.2798


100%|██████████| 18/18 [00:00<00:00, 38.74it/s]


Epoch: 5 Train loss: 1.0007


100%|██████████| 10/10 [00:00<00:00, 41.03it/s]


Epoch: 5 Test loss: 1.2795


100%|██████████| 18/18 [00:00<00:00, 42.90it/s]


Epoch: 6 Train loss: 0.9994


100%|██████████| 10/10 [00:00<00:00, 43.30it/s]


Epoch: 6 Test loss: 1.2795


100%|██████████| 18/18 [00:00<00:00, 42.52it/s]


Epoch: 7 Train loss: 0.9971


100%|██████████| 10/10 [00:00<00:00, 41.77it/s]


Epoch: 7 Test loss: 1.2795


100%|██████████| 18/18 [00:00<00:00, 25.49it/s]


Epoch: 1 Train loss: 1.0335


100%|██████████| 10/10 [00:00<00:00, 49.66it/s]


Epoch: 1 Test loss: 1.2962


100%|██████████| 18/18 [00:00<00:00, 42.51it/s]


Epoch: 2 Train loss: 1.0111


100%|██████████| 10/10 [00:00<00:00, 38.40it/s]


Epoch: 2 Test loss: 1.2846


100%|██████████| 18/18 [00:00<00:00, 41.07it/s]


Epoch: 3 Train loss: 1.0004


100%|██████████| 10/10 [00:00<00:00, 53.94it/s]


Epoch: 3 Test loss: 1.2809


100%|██████████| 18/18 [00:00<00:00, 42.92it/s]


Epoch: 4 Train loss: 1.0042


100%|██████████| 10/10 [00:00<00:00, 55.36it/s]


Epoch: 4 Test loss: 1.2799


100%|██████████| 18/18 [00:00<00:00, 43.67it/s]


Epoch: 5 Train loss: 1.0007


100%|██████████| 10/10 [00:00<00:00, 48.55it/s]


Epoch: 5 Test loss: 1.2796


100%|██████████| 18/18 [00:00<00:00, 34.10it/s]


Epoch: 6 Train loss: 0.9994


100%|██████████| 10/10 [00:00<00:00, 54.18it/s]


Epoch: 6 Test loss: 1.2795


100%|██████████| 18/18 [00:00<00:00, 41.61it/s]


Epoch: 7 Train loss: 0.9971


100%|██████████| 10/10 [00:00<00:00, 49.71it/s]


Epoch: 7 Test loss: 1.2795


#### Converter3

In [31]:
from fmri_reconstruction_with_dmvae.models.converter import Converter3

# Hyperparameters for the Converter3 subject-to-subject mapping.
hidden_dim = 4096

lr = 1e-4
# torch.optim.Adam applies weight_decay as a coupled L2 term on the gradients.
# NOTE(review): in the recorded outputs of this cell the final test loss is identical
# for every source subject of a given target (e.g. 0.9840 for target 1, 1.6033 for
# target 2), which suggests the converter collapses to a source-independent prediction
# under this penalty. Consider a much smaller value before trusting the saved correlations.
weight_decay = 1e-1

epochs = 7

conversion_output_path = "/home/acg17270jl/projects/fmri-reconstruction-with-dmvae/results/deeprecon/conversion/"

pattern_corr_result = []
profile_corr_result = []

# Train and evaluate one converter per ordered (target, source) pair of distinct subjects.
# `train` reads `model`, `optimizer`, `loss_fn`, `s_s` and `s_t` as notebook globals
# (and `test` presumably does too), so they are bound here at cell scope on purpose.
for subj_target, subj_source in permutations(subj_list, 2):
    s_t = f"{int(subj_target):02d}"
    s_s = f"{int(subj_source):02d}"

    # Seed the global torch RNG so the converter's weight initialisation is reproducible
    # for each pair (assuming Converter3 initialises its layers from the default torch RNG);
    # the DataLoader generator `g` below only controls the shuffle order.
    torch.manual_seed(42)
    model = Converter3(all_subj_num_voxels[f"subj{s_s}"], all_subj_num_voxels[f"subj{s_t}"], hidden_dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)  # L2
    loss_fn = torch.nn.MSELoss()

    g.manual_seed(42)
    for epoch in range(1, epochs + 1):
        train_loss = train(epoch)
        # Overwritten every epoch: only the final epoch's test outputs are evaluated.
        test_loss, target_brain, recon_brain, label = test(epoch)

    # `test` returns per-batch tensors; stack them into full arrays on the CPU.
    target_brain = torch.cat(target_brain, dim=0).cpu().numpy()
    recon_brain = torch.cat(recon_brain, dim=0).cpu().numpy()
    label = torch.cat(label, dim=0).cpu().numpy()

    # NOTE(review): Image_idx is 1-based while Voxel_idx below is 0-based; kept as-is so
    # the saved CSV layout does not change.
    pattern_corr = calculate_pattern_correlation(target_brain, recon_brain, label)
    for i, corr in enumerate(pattern_corr):
        pattern_corr_result.append({
            'Subject_target': subj_target,
            'Subject_source': subj_source,
            'Correlation': corr,
            'Image_idx': i+1}
        )

    profile_corr = calculate_profile_correlation(target_brain, recon_brain, label)
    for i, corr in enumerate(profile_corr):
        profile_corr_result.append({
            'Subject_target': subj_target,
            'Subject_source': subj_source,
            'Correlation': corr,
            'Voxel_idx': i}
        )

save_result(pattern_corr_result, conversion_output_path, "pattern_correlation_converter3.csv")
save_result(profile_corr_result, conversion_output_path, "profile_correlation_converter3.csv")

100%|██████████| 18/18 [00:00<00:00, 28.26it/s]


Epoch: 1 Train loss: 1.0045


100%|██████████| 10/10 [00:00<00:00, 55.52it/s]


Epoch: 1 Test loss: 0.9875


100%|██████████| 18/18 [00:00<00:00, 37.77it/s]


Epoch: 2 Train loss: 1.0012


100%|██████████| 10/10 [00:00<00:00, 48.35it/s]


Epoch: 2 Test loss: 0.9848


100%|██████████| 18/18 [00:00<00:00, 34.78it/s]


Epoch: 3 Train loss: 0.9975


100%|██████████| 10/10 [00:00<00:00, 56.12it/s]


Epoch: 3 Test loss: 0.9842


100%|██████████| 18/18 [00:00<00:00, 41.08it/s]


Epoch: 4 Train loss: 1.0002


100%|██████████| 10/10 [00:00<00:00, 54.31it/s]


Epoch: 4 Test loss: 0.9841


100%|██████████| 18/18 [00:00<00:00, 40.26it/s]


Epoch: 5 Train loss: 1.0005


100%|██████████| 10/10 [00:00<00:00, 54.86it/s]


Epoch: 5 Test loss: 0.9840


100%|██████████| 18/18 [00:00<00:00, 40.76it/s]


Epoch: 6 Train loss: 1.0051


100%|██████████| 10/10 [00:00<00:00, 53.82it/s]


Epoch: 6 Test loss: 0.9840


100%|██████████| 18/18 [00:00<00:00, 40.16it/s]


Epoch: 7 Train loss: 0.9950


100%|██████████| 10/10 [00:00<00:00, 46.68it/s]


Epoch: 7 Test loss: 0.9840


100%|██████████| 18/18 [00:00<00:00, 18.23it/s]


Epoch: 1 Train loss: 1.0045


100%|██████████| 10/10 [00:00<00:00, 39.01it/s]


Epoch: 1 Test loss: 0.9860


100%|██████████| 18/18 [00:00<00:00, 27.76it/s]


Epoch: 2 Train loss: 1.0013


100%|██████████| 10/10 [00:00<00:00, 29.99it/s]


Epoch: 2 Test loss: 0.9845


100%|██████████| 18/18 [00:00<00:00, 31.46it/s]


Epoch: 3 Train loss: 0.9976


100%|██████████| 10/10 [00:00<00:00, 37.60it/s]


Epoch: 3 Test loss: 0.9841


100%|██████████| 18/18 [00:00<00:00, 31.86it/s]


Epoch: 4 Train loss: 1.0002


100%|██████████| 10/10 [00:00<00:00, 40.09it/s]


Epoch: 4 Test loss: 0.9841


100%|██████████| 18/18 [00:00<00:00, 28.34it/s]


Epoch: 5 Train loss: 1.0005


100%|██████████| 10/10 [00:00<00:00, 45.91it/s]


Epoch: 5 Test loss: 0.9840


100%|██████████| 18/18 [00:00<00:00, 35.63it/s]


Epoch: 6 Train loss: 1.0051


100%|██████████| 10/10 [00:00<00:00, 53.39it/s]


Epoch: 6 Test loss: 0.9840


100%|██████████| 18/18 [00:00<00:00, 40.31it/s]


Epoch: 7 Train loss: 0.9950


100%|██████████| 10/10 [00:00<00:00, 41.55it/s]


Epoch: 7 Test loss: 0.9840


100%|██████████| 18/18 [00:00<00:00, 27.67it/s]


Epoch: 1 Train loss: 1.0045


100%|██████████| 10/10 [00:00<00:00, 51.71it/s]


Epoch: 1 Test loss: 0.9861


100%|██████████| 18/18 [00:00<00:00, 40.48it/s]


Epoch: 2 Train loss: 1.0012


100%|██████████| 10/10 [00:00<00:00, 52.19it/s]


Epoch: 2 Test loss: 0.9845


100%|██████████| 18/18 [00:00<00:00, 41.39it/s]


Epoch: 3 Train loss: 0.9976


100%|██████████| 10/10 [00:00<00:00, 51.67it/s]


Epoch: 3 Test loss: 0.9841


100%|██████████| 18/18 [00:00<00:00, 41.20it/s]


Epoch: 4 Train loss: 1.0002


100%|██████████| 10/10 [00:00<00:00, 52.72it/s]


Epoch: 4 Test loss: 0.9841


100%|██████████| 18/18 [00:00<00:00, 35.18it/s]


Epoch: 5 Train loss: 1.0005


100%|██████████| 10/10 [00:00<00:00, 56.28it/s]


Epoch: 5 Test loss: 0.9840


100%|██████████| 18/18 [00:00<00:00, 42.02it/s]


Epoch: 6 Train loss: 1.0051


100%|██████████| 10/10 [00:00<00:00, 55.16it/s]


Epoch: 6 Test loss: 0.9840


100%|██████████| 18/18 [00:00<00:00, 41.00it/s]


Epoch: 7 Train loss: 0.9950


100%|██████████| 10/10 [00:00<00:00, 47.54it/s]


Epoch: 7 Test loss: 0.9840


100%|██████████| 18/18 [00:00<00:00, 20.35it/s]


Epoch: 1 Train loss: 1.0045


100%|██████████| 10/10 [00:00<00:00, 48.02it/s]


Epoch: 1 Test loss: 0.9869


100%|██████████| 18/18 [00:00<00:00, 36.74it/s]


Epoch: 2 Train loss: 1.0013


100%|██████████| 10/10 [00:00<00:00, 51.52it/s]


Epoch: 2 Test loss: 0.9847


100%|██████████| 18/18 [00:00<00:00, 31.72it/s]


Epoch: 3 Train loss: 0.9976


100%|██████████| 10/10 [00:00<00:00, 53.07it/s]


Epoch: 3 Test loss: 0.9842


100%|██████████| 18/18 [00:00<00:00, 34.21it/s]


Epoch: 4 Train loss: 1.0002


100%|██████████| 10/10 [00:00<00:00, 47.68it/s]


Epoch: 4 Test loss: 0.9841


100%|██████████| 18/18 [00:00<00:00, 39.07it/s]


Epoch: 5 Train loss: 1.0005


100%|██████████| 10/10 [00:00<00:00, 52.75it/s]


Epoch: 5 Test loss: 0.9840


100%|██████████| 18/18 [00:00<00:00, 39.99it/s]


Epoch: 6 Train loss: 1.0051


100%|██████████| 10/10 [00:00<00:00, 52.19it/s]


Epoch: 6 Test loss: 0.9840


100%|██████████| 18/18 [00:00<00:00, 40.48it/s]


Epoch: 7 Train loss: 0.9950


100%|██████████| 10/10 [00:00<00:00, 43.18it/s]


Epoch: 7 Test loss: 0.9840


100%|██████████| 18/18 [00:00<00:00, 24.52it/s]


Epoch: 1 Train loss: 1.0069


100%|██████████| 10/10 [00:00<00:00, 49.70it/s]


Epoch: 1 Test loss: 1.6054


100%|██████████| 18/18 [00:00<00:00, 38.35it/s]


Epoch: 2 Train loss: 0.9993


100%|██████████| 10/10 [00:00<00:00, 50.36it/s]


Epoch: 2 Test loss: 1.6038


100%|██████████| 18/18 [00:00<00:00, 39.27it/s]


Epoch: 3 Train loss: 1.0010


100%|██████████| 10/10 [00:00<00:00, 47.38it/s]


Epoch: 3 Test loss: 1.6034


100%|██████████| 18/18 [00:00<00:00, 35.22it/s]


Epoch: 4 Train loss: 0.9999


100%|██████████| 10/10 [00:00<00:00, 51.72it/s]


Epoch: 4 Test loss: 1.6034


100%|██████████| 18/18 [00:00<00:00, 39.49it/s]


Epoch: 5 Train loss: 1.0019


100%|██████████| 10/10 [00:00<00:00, 51.22it/s]


Epoch: 5 Test loss: 1.6033


100%|██████████| 18/18 [00:00<00:00, 29.03it/s]


Epoch: 6 Train loss: 1.0001


100%|██████████| 10/10 [00:00<00:00, 45.10it/s]


Epoch: 6 Test loss: 1.6033


100%|██████████| 18/18 [00:00<00:00, 32.20it/s]


Epoch: 7 Train loss: 1.0005


100%|██████████| 10/10 [00:00<00:00, 43.70it/s]


Epoch: 7 Test loss: 1.6033


100%|██████████| 18/18 [00:00<00:00, 28.07it/s]


Epoch: 1 Train loss: 1.0071


100%|██████████| 10/10 [00:00<00:00, 51.15it/s]


Epoch: 1 Test loss: 1.6052


100%|██████████| 18/18 [00:00<00:00, 39.76it/s]


Epoch: 2 Train loss: 0.9995


100%|██████████| 10/10 [00:00<00:00, 44.68it/s]


Epoch: 2 Test loss: 1.6038


100%|██████████| 18/18 [00:00<00:00, 25.89it/s]


Epoch: 3 Train loss: 1.0010


100%|██████████| 10/10 [00:00<00:00, 40.66it/s]


Epoch: 3 Test loss: 1.6034


100%|██████████| 18/18 [00:00<00:00, 29.98it/s]


Epoch: 4 Train loss: 1.0000


100%|██████████| 10/10 [00:00<00:00, 14.82it/s]


Epoch: 4 Test loss: 1.6034


100%|██████████| 18/18 [00:00<00:00, 29.36it/s]


Epoch: 5 Train loss: 1.0019


100%|██████████| 10/10 [00:00<00:00, 41.38it/s]


Epoch: 5 Test loss: 1.6033


100%|██████████| 18/18 [00:00<00:00, 32.25it/s]


Epoch: 6 Train loss: 1.0001


100%|██████████| 10/10 [00:00<00:00, 47.04it/s]


Epoch: 6 Test loss: 1.6033


100%|██████████| 18/18 [00:00<00:00, 37.22it/s]


Epoch: 7 Train loss: 1.0005


100%|██████████| 10/10 [00:00<00:00, 43.26it/s]


Epoch: 7 Test loss: 1.6033


100%|██████████| 18/18 [00:00<00:00, 28.54it/s]


Epoch: 1 Train loss: 1.0069


100%|██████████| 10/10 [00:00<00:00, 38.44it/s]


Epoch: 1 Test loss: 1.6053


100%|██████████| 18/18 [00:00<00:00, 32.06it/s]


Epoch: 2 Train loss: 0.9994


100%|██████████| 10/10 [00:00<00:00, 35.21it/s]


Epoch: 2 Test loss: 1.6038


100%|██████████| 18/18 [00:00<00:00, 37.51it/s]


Epoch: 3 Train loss: 1.0010


100%|██████████| 10/10 [00:00<00:00, 51.51it/s]


Epoch: 3 Test loss: 1.6034


100%|██████████| 18/18 [00:00<00:00, 32.44it/s]


Epoch: 4 Train loss: 0.9999


100%|██████████| 10/10 [00:00<00:00, 51.71it/s]


Epoch: 4 Test loss: 1.6033


100%|██████████| 18/18 [00:00<00:00, 39.50it/s]


Epoch: 5 Train loss: 1.0019


100%|██████████| 10/10 [00:00<00:00, 50.30it/s]


Epoch: 5 Test loss: 1.6033


100%|██████████| 18/18 [00:00<00:00, 39.18it/s]


Epoch: 6 Train loss: 1.0001


100%|██████████| 10/10 [00:00<00:00, 51.84it/s]


Epoch: 6 Test loss: 1.6033


100%|██████████| 18/18 [00:00<00:00, 38.98it/s]


Epoch: 7 Train loss: 1.0005


100%|██████████| 10/10 [00:00<00:00, 45.27it/s]


Epoch: 7 Test loss: 1.6033


100%|██████████| 18/18 [00:00<00:00, 29.76it/s]


Epoch: 1 Train loss: 1.0071


100%|██████████| 10/10 [00:00<00:00, 52.21it/s]


Epoch: 1 Test loss: 1.6063


100%|██████████| 18/18 [00:00<00:00, 38.74it/s]


Epoch: 2 Train loss: 0.9995


100%|██████████| 10/10 [00:00<00:00, 54.23it/s]


Epoch: 2 Test loss: 1.6040


100%|██████████| 18/18 [00:00<00:00, 40.26it/s]


Epoch: 3 Train loss: 1.0010


100%|██████████| 10/10 [00:00<00:00, 53.80it/s]


Epoch: 3 Test loss: 1.6035


100%|██████████| 18/18 [00:00<00:00, 33.30it/s]


Epoch: 4 Train loss: 1.0000


100%|██████████| 10/10 [00:00<00:00, 55.02it/s]


Epoch: 4 Test loss: 1.6034


100%|██████████| 18/18 [00:00<00:00, 39.08it/s]


Epoch: 5 Train loss: 1.0019


100%|██████████| 10/10 [00:00<00:00, 53.52it/s]


Epoch: 5 Test loss: 1.6033


100%|██████████| 18/18 [00:00<00:00, 40.53it/s]


Epoch: 6 Train loss: 1.0001


100%|██████████| 10/10 [00:00<00:00, 54.64it/s]


Epoch: 6 Test loss: 1.6033


100%|██████████| 18/18 [00:00<00:00, 39.84it/s]


Epoch: 7 Train loss: 1.0005


100%|██████████| 10/10 [00:00<00:00, 48.40it/s]


Epoch: 7 Test loss: 1.6033


100%|██████████| 18/18 [00:00<00:00, 23.82it/s]


Epoch: 1 Train loss: 1.0051


100%|██████████| 10/10 [00:00<00:00, 49.85it/s]


Epoch: 1 Test loss: 0.8348


100%|██████████| 18/18 [00:00<00:00, 40.21it/s]


Epoch: 2 Train loss: 1.0007


100%|██████████| 10/10 [00:00<00:00, 54.34it/s]


Epoch: 2 Test loss: 0.8332


100%|██████████| 18/18 [00:00<00:00, 41.36it/s]


Epoch: 3 Train loss: 1.0014


100%|██████████| 10/10 [00:00<00:00, 54.52it/s]


Epoch: 3 Test loss: 0.8328


100%|██████████| 18/18 [00:00<00:00, 34.89it/s]


Epoch: 4 Train loss: 1.0020


100%|██████████| 10/10 [00:00<00:00, 37.52it/s]


Epoch: 4 Test loss: 0.8327


100%|██████████| 18/18 [00:00<00:00, 34.83it/s]


Epoch: 5 Train loss: 0.9964


100%|██████████| 10/10 [00:00<00:00, 54.23it/s]


Epoch: 5 Test loss: 0.8327


100%|██████████| 18/18 [00:00<00:00, 37.16it/s]


Epoch: 6 Train loss: 0.9990


100%|██████████| 10/10 [00:00<00:00, 50.72it/s]


Epoch: 6 Test loss: 0.8327


100%|██████████| 18/18 [00:00<00:00, 38.95it/s]


Epoch: 7 Train loss: 0.9983


100%|██████████| 10/10 [00:00<00:00, 49.57it/s]


Epoch: 7 Test loss: 0.8327


100%|██████████| 18/18 [00:00<00:00, 21.47it/s]


Epoch: 1 Train loss: 1.0051


100%|██████████| 10/10 [00:00<00:00, 49.28it/s]


Epoch: 1 Test loss: 0.8362


100%|██████████| 18/18 [00:00<00:00, 40.38it/s]


Epoch: 2 Train loss: 1.0007


100%|██████████| 10/10 [00:00<00:00, 53.76it/s]


Epoch: 2 Test loss: 0.8334


100%|██████████| 18/18 [00:00<00:00, 33.05it/s]


Epoch: 3 Train loss: 1.0014


100%|██████████| 10/10 [00:00<00:00, 52.92it/s]


Epoch: 3 Test loss: 0.8329


100%|██████████| 18/18 [00:00<00:00, 41.06it/s]


Epoch: 4 Train loss: 1.0020


100%|██████████| 10/10 [00:00<00:00, 54.33it/s]


Epoch: 4 Test loss: 0.8327


100%|██████████| 18/18 [00:00<00:00, 40.11it/s]


Epoch: 5 Train loss: 0.9964


100%|██████████| 10/10 [00:00<00:00, 54.38it/s]


Epoch: 5 Test loss: 0.8327


100%|██████████| 18/18 [00:00<00:00, 32.05it/s]


Epoch: 6 Train loss: 0.9990


100%|██████████| 10/10 [00:00<00:00, 54.54it/s]


Epoch: 6 Test loss: 0.8327


100%|██████████| 18/18 [00:00<00:00, 40.63it/s]


Epoch: 7 Train loss: 0.9983


100%|██████████| 10/10 [00:00<00:00, 50.35it/s]


Epoch: 7 Test loss: 0.8327


100%|██████████| 18/18 [00:00<00:00, 18.23it/s]


Epoch: 1 Train loss: 1.0052


100%|██████████| 10/10 [00:00<00:00, 49.24it/s]


Epoch: 1 Test loss: 0.8348


100%|██████████| 18/18 [00:00<00:00, 33.51it/s]


Epoch: 2 Train loss: 1.0008


100%|██████████| 10/10 [00:00<00:00, 52.84it/s]


Epoch: 2 Test loss: 0.8332


100%|██████████| 18/18 [00:00<00:00, 39.60it/s]


Epoch: 3 Train loss: 1.0015


100%|██████████| 10/10 [00:00<00:00, 40.47it/s]


Epoch: 3 Test loss: 0.8328


100%|██████████| 18/18 [00:00<00:00, 33.79it/s]


Epoch: 4 Train loss: 1.0020


100%|██████████| 10/10 [00:00<00:00, 35.18it/s]


Epoch: 4 Test loss: 0.8327


100%|██████████| 18/18 [00:00<00:00, 39.03it/s]


Epoch: 5 Train loss: 0.9964


100%|██████████| 10/10 [00:00<00:00, 53.16it/s]


Epoch: 5 Test loss: 0.8327


100%|██████████| 18/18 [00:00<00:00, 40.77it/s]


Epoch: 6 Train loss: 0.9990


100%|██████████| 10/10 [00:00<00:00, 53.34it/s]


Epoch: 6 Test loss: 0.8327


100%|██████████| 18/18 [00:00<00:00, 34.23it/s]


Epoch: 7 Train loss: 0.9983


100%|██████████| 10/10 [00:00<00:00, 46.87it/s]


Epoch: 7 Test loss: 0.8327


100%|██████████| 18/18 [00:00<00:00, 28.46it/s]


Epoch: 1 Train loss: 1.0053


100%|██████████| 10/10 [00:00<00:00, 35.13it/s]


Epoch: 1 Test loss: 0.8357


100%|██████████| 18/18 [00:00<00:00, 34.78it/s]


Epoch: 2 Train loss: 1.0008


100%|██████████| 10/10 [00:00<00:00, 49.41it/s]


Epoch: 2 Test loss: 0.8334


100%|██████████| 18/18 [00:00<00:00, 37.43it/s]


Epoch: 3 Train loss: 1.0015


100%|██████████| 10/10 [00:00<00:00, 50.03it/s]


Epoch: 3 Test loss: 0.8328


100%|██████████| 18/18 [00:00<00:00, 41.55it/s]


Epoch: 4 Train loss: 1.0020


100%|██████████| 10/10 [00:00<00:00, 54.02it/s]


Epoch: 4 Test loss: 0.8327


100%|██████████| 18/18 [00:00<00:00, 35.67it/s]


Epoch: 5 Train loss: 0.9964


100%|██████████| 10/10 [00:00<00:00, 34.42it/s]


Epoch: 5 Test loss: 0.8327


100%|██████████| 18/18 [00:00<00:00, 32.83it/s]


Epoch: 6 Train loss: 0.9990


100%|██████████| 10/10 [00:00<00:00, 43.24it/s]


Epoch: 6 Test loss: 0.8327


100%|██████████| 18/18 [00:00<00:00, 26.13it/s]


Epoch: 7 Train loss: 0.9983


100%|██████████| 10/10 [00:00<00:00, 37.89it/s]


Epoch: 7 Test loss: 0.8327


100%|██████████| 18/18 [00:00<00:00, 24.03it/s]


Epoch: 1 Train loss: 1.0024


100%|██████████| 10/10 [00:00<00:00, 24.68it/s]


Epoch: 1 Test loss: 0.9149


100%|██████████| 18/18 [00:00<00:00, 36.13it/s]


Epoch: 2 Train loss: 0.9974


100%|██████████| 10/10 [00:00<00:00, 46.75it/s]


Epoch: 2 Test loss: 0.9132


100%|██████████| 18/18 [00:00<00:00, 25.83it/s]


Epoch: 3 Train loss: 1.0018


100%|██████████| 10/10 [00:00<00:00, 46.57it/s]


Epoch: 3 Test loss: 0.9129


100%|██████████| 18/18 [00:00<00:00, 36.40it/s]


Epoch: 4 Train loss: 0.9990


100%|██████████| 10/10 [00:00<00:00, 40.41it/s]


Epoch: 4 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 35.77it/s]


Epoch: 5 Train loss: 0.9983


100%|██████████| 10/10 [00:00<00:00, 55.20it/s]


Epoch: 5 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 31.36it/s]


Epoch: 6 Train loss: 0.9996


100%|██████████| 10/10 [00:00<00:00, 45.39it/s]


Epoch: 6 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 31.45it/s]


Epoch: 7 Train loss: 1.0000


100%|██████████| 10/10 [00:00<00:00, 39.66it/s]


Epoch: 7 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 21.47it/s]


Epoch: 1 Train loss: 1.0025


100%|██████████| 10/10 [00:00<00:00, 49.84it/s]


Epoch: 1 Test loss: 0.9163


100%|██████████| 18/18 [00:00<00:00, 34.28it/s]


Epoch: 2 Train loss: 0.9974


100%|██████████| 10/10 [00:00<00:00, 54.19it/s]


Epoch: 2 Test loss: 0.9135


100%|██████████| 18/18 [00:00<00:00, 41.65it/s]


Epoch: 3 Train loss: 1.0018


100%|██████████| 10/10 [00:00<00:00, 54.84it/s]


Epoch: 3 Test loss: 0.9129


100%|██████████| 18/18 [00:00<00:00, 41.97it/s]


Epoch: 4 Train loss: 0.9990


100%|██████████| 10/10 [00:00<00:00, 54.81it/s]


Epoch: 4 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 37.93it/s]


Epoch: 5 Train loss: 0.9983


100%|██████████| 10/10 [00:00<00:00, 50.33it/s]


Epoch: 5 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 39.59it/s]


Epoch: 6 Train loss: 0.9996


100%|██████████| 10/10 [00:00<00:00, 40.14it/s]


Epoch: 6 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 29.10it/s]


Epoch: 7 Train loss: 1.0000


100%|██████████| 10/10 [00:00<00:00, 26.75it/s]


Epoch: 7 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 25.67it/s]


Epoch: 1 Train loss: 1.0022


100%|██████████| 10/10 [00:00<00:00, 50.08it/s]


Epoch: 1 Test loss: 0.9146


100%|██████████| 18/18 [00:00<00:00, 29.18it/s]


Epoch: 2 Train loss: 0.9974


100%|██████████| 10/10 [00:00<00:00, 33.77it/s]


Epoch: 2 Test loss: 0.9132


100%|██████████| 18/18 [00:00<00:00, 32.58it/s]


Epoch: 3 Train loss: 1.0018


100%|██████████| 10/10 [00:00<00:00, 52.99it/s]


Epoch: 3 Test loss: 0.9129


100%|██████████| 18/18 [00:00<00:00, 23.40it/s]


Epoch: 4 Train loss: 0.9990


100%|██████████| 10/10 [00:00<00:00, 31.02it/s]


Epoch: 4 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 27.15it/s]


Epoch: 5 Train loss: 0.9983


100%|██████████| 10/10 [00:00<00:00, 33.19it/s]


Epoch: 5 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 37.28it/s]


Epoch: 6 Train loss: 0.9996


100%|██████████| 10/10 [00:00<00:00, 49.39it/s]


Epoch: 6 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 36.51it/s]


Epoch: 7 Train loss: 1.0000


100%|██████████| 10/10 [00:00<00:00, 41.81it/s]


Epoch: 7 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 22.19it/s]


Epoch: 1 Train loss: 1.0025


100%|██████████| 10/10 [00:00<00:00, 42.82it/s]


Epoch: 1 Test loss: 0.9158


100%|██████████| 18/18 [00:00<00:00, 38.37it/s]


Epoch: 2 Train loss: 0.9975


100%|██████████| 10/10 [00:00<00:00, 54.84it/s]


Epoch: 2 Test loss: 0.9134


100%|██████████| 18/18 [00:00<00:00, 41.37it/s]


Epoch: 3 Train loss: 1.0018


100%|██████████| 10/10 [00:00<00:00, 53.82it/s]


Epoch: 3 Test loss: 0.9129


100%|██████████| 18/18 [00:00<00:00, 42.30it/s]


Epoch: 4 Train loss: 0.9991


100%|██████████| 10/10 [00:00<00:00, 53.74it/s]


Epoch: 4 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 35.12it/s]


Epoch: 5 Train loss: 0.9983


100%|██████████| 10/10 [00:00<00:00, 52.74it/s]


Epoch: 5 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 37.11it/s]


Epoch: 6 Train loss: 0.9996


100%|██████████| 10/10 [00:00<00:00, 51.99it/s]


Epoch: 6 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 41.92it/s]


Epoch: 7 Train loss: 1.0000


100%|██████████| 10/10 [00:00<00:00, 41.15it/s]


Epoch: 7 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 30.50it/s]


Epoch: 1 Train loss: 1.0041


100%|██████████| 10/10 [00:00<00:00, 42.22it/s]


Epoch: 1 Test loss: 1.2816


100%|██████████| 18/18 [00:00<00:00, 28.31it/s]


Epoch: 2 Train loss: 1.0012


100%|██████████| 10/10 [00:00<00:00, 51.05it/s]


Epoch: 2 Test loss: 1.2800


100%|██████████| 18/18 [00:00<00:00, 39.92it/s]


Epoch: 3 Train loss: 0.9976


100%|██████████| 10/10 [00:00<00:00, 53.95it/s]


Epoch: 3 Test loss: 1.2796


100%|██████████| 18/18 [00:00<00:00, 41.77it/s]


Epoch: 4 Train loss: 1.0036


100%|██████████| 10/10 [00:00<00:00, 50.28it/s]


Epoch: 4 Test loss: 1.2795


100%|██████████| 18/18 [00:00<00:00, 23.60it/s]


Epoch: 5 Train loss: 1.0007


100%|██████████| 10/10 [00:00<00:00, 37.93it/s]


Epoch: 5 Test loss: 1.2795


100%|██████████| 18/18 [00:00<00:00, 28.96it/s]


Epoch: 6 Train loss: 0.9995


100%|██████████| 10/10 [00:00<00:00, 36.73it/s]


Epoch: 6 Test loss: 1.2795


100%|██████████| 18/18 [00:00<00:00, 33.17it/s]


Epoch: 7 Train loss: 0.9972


100%|██████████| 10/10 [00:00<00:00, 35.59it/s]


Epoch: 7 Test loss: 1.2795


100%|██████████| 18/18 [00:00<00:00, 24.37it/s]


Epoch: 1 Train loss: 1.0040


100%|██████████| 10/10 [00:00<00:00, 48.94it/s]


Epoch: 1 Test loss: 1.2830


100%|██████████| 18/18 [00:00<00:00, 39.19it/s]


Epoch: 2 Train loss: 1.0012


100%|██████████| 10/10 [00:00<00:00, 51.25it/s]


Epoch: 2 Test loss: 1.2802


100%|██████████| 18/18 [00:00<00:00, 38.08it/s]


Epoch: 3 Train loss: 0.9976


100%|██████████| 10/10 [00:00<00:00, 52.11it/s]


Epoch: 3 Test loss: 1.2796


100%|██████████| 18/18 [00:00<00:00, 40.51it/s]


Epoch: 4 Train loss: 1.0036


100%|██████████| 10/10 [00:00<00:00, 54.71it/s]


Epoch: 4 Test loss: 1.2795


100%|██████████| 18/18 [00:00<00:00, 40.73it/s]


Epoch: 5 Train loss: 1.0007


100%|██████████| 10/10 [00:00<00:00, 51.57it/s]


Epoch: 5 Test loss: 1.2795


100%|██████████| 18/18 [00:00<00:00, 35.07it/s]


Epoch: 6 Train loss: 0.9995


100%|██████████| 10/10 [00:00<00:00, 47.82it/s]


Epoch: 6 Test loss: 1.2795


100%|██████████| 18/18 [00:00<00:00, 24.89it/s]


Epoch: 7 Train loss: 0.9972


100%|██████████| 10/10 [00:00<00:00, 34.27it/s]


Epoch: 7 Test loss: 1.2795


100%|██████████| 18/18 [00:00<00:00, 23.06it/s]


Epoch: 1 Train loss: 1.0040


100%|██████████| 10/10 [00:00<00:00, 37.64it/s]


Epoch: 1 Test loss: 1.2813


100%|██████████| 18/18 [00:00<00:00, 28.18it/s]


Epoch: 2 Train loss: 1.0013


100%|██████████| 10/10 [00:00<00:00, 36.81it/s]


Epoch: 2 Test loss: 1.2799


100%|██████████| 18/18 [00:00<00:00, 27.61it/s]


Epoch: 3 Train loss: 0.9976


100%|██████████| 10/10 [00:00<00:00, 47.81it/s]


Epoch: 3 Test loss: 1.2796


100%|██████████| 18/18 [00:00<00:00, 38.37it/s]


Epoch: 4 Train loss: 1.0036


100%|██████████| 10/10 [00:00<00:00, 50.38it/s]


Epoch: 4 Test loss: 1.2795


100%|██████████| 18/18 [00:00<00:00, 26.81it/s]


Epoch: 5 Train loss: 1.0007


100%|██████████| 10/10 [00:00<00:00, 45.79it/s]


Epoch: 5 Test loss: 1.2795


100%|██████████| 18/18 [00:00<00:00, 37.15it/s]


Epoch: 6 Train loss: 0.9995


100%|██████████| 10/10 [00:00<00:00, 48.89it/s]


Epoch: 6 Test loss: 1.2795


100%|██████████| 18/18 [00:00<00:00, 38.28it/s]


Epoch: 7 Train loss: 0.9972


100%|██████████| 10/10 [00:00<00:00, 22.93it/s]


Epoch: 7 Test loss: 1.2795


100%|██████████| 18/18 [00:00<00:00, 28.33it/s]


Epoch: 1 Train loss: 1.0039


100%|██████████| 10/10 [00:00<00:00, 29.39it/s]


Epoch: 1 Test loss: 1.2815


100%|██████████| 18/18 [00:00<00:00, 24.65it/s]


Epoch: 2 Train loss: 1.0012


100%|██████████| 10/10 [00:00<00:00, 37.83it/s]


Epoch: 2 Test loss: 1.2799


100%|██████████| 18/18 [00:00<00:00, 32.45it/s]


Epoch: 3 Train loss: 0.9976


100%|██████████| 10/10 [00:00<00:00, 39.10it/s]


Epoch: 3 Test loss: 1.2796


100%|██████████| 18/18 [00:00<00:00, 23.16it/s]


Epoch: 4 Train loss: 1.0036


100%|██████████| 10/10 [00:00<00:00, 41.23it/s]


Epoch: 4 Test loss: 1.2795


100%|██████████| 18/18 [00:00<00:00, 30.12it/s]


Epoch: 5 Train loss: 1.0007


100%|██████████| 10/10 [00:00<00:00, 31.45it/s]


Epoch: 5 Test loss: 1.2795


100%|██████████| 18/18 [00:00<00:00, 36.27it/s]


Epoch: 6 Train loss: 0.9995


100%|██████████| 10/10 [00:00<00:00, 53.02it/s]


Epoch: 6 Test loss: 1.2795


100%|██████████| 18/18 [00:00<00:00, 39.69it/s]


Epoch: 7 Train loss: 0.9972


100%|██████████| 10/10 [00:00<00:00, 44.56it/s]


Epoch: 7 Test loss: 1.2795


#### Converter4

In [32]:
from fmri_reconstruction_with_dmvae.models.converter import Converter4

# --- Converter4 hyperparameters ---------------------------------------------
zp_dim = 256
zs_dim = 768
hidden_dim = 1024

lr = 1e-4
# NOTE(review): torch.optim.Adam applies `weight_decay` as an L2 term added to the
# gradient (coupled decay), so 1e-1 is a strong penalty. In this cell's logged
# output the train loss stays ~1.0 and the test loss plateaus after a few epochs
# for every subject pair, which may mean the converter is being shrunk towards a
# near-constant prediction. Consider torch.optim.AdamW or a smaller value — TODO confirm.
weight_decay = 1e-1

epochs = 7
# Used for both the torch global RNG (weight init) and the DataLoader generator `g`
# (shuffling), so every (target, source) run is reproducible.
seed = 42

conversion_output_path = "/home/acg17270jl/projects/fmri-reconstruction-with-dmvae/results/deeprecon/conversion/"


def _correlation_records(corrs, subj_target, subj_source, idx_key, idx_offset):
    """Build one result row per correlation value for `save_result`.

    Args:
        corrs: iterable of correlation values (one per image or per voxel).
        subj_target: target subject id, stored under 'Subject_target'.
        subj_source: source subject id, stored under 'Subject_source'.
        idx_key: name of the index column ('Image_idx' or 'Voxel_idx').
        idx_offset: added to the 0-based position (1 for images, 0 for voxels).

    Returns:
        A list of dicts whose key order matches the original CSV column order.
    """
    return [
        {
            'Subject_target': subj_target,
            'Subject_source': subj_source,
            'Correlation': corr,
            idx_key: i + idx_offset,
        }
        for i, corr in enumerate(corrs)
    ]


pattern_corr_result = []
profile_corr_result = []

# Train and evaluate one converter per ordered (target, source) subject pair.
# train()/test() read the module-level names model, optimizer, loss_fn, s_s and s_t,
# so these are deliberately rebound at notebook scope on every iteration.
for subj_target, subj_source in permutations(subj_list, 2):
    s_t = f"{int(subj_target):02d}"
    s_s = f"{int(subj_source):02d}"

    # Seed before construction so weight initialisation is reproducible per pair.
    torch.manual_seed(seed)
    model = Converter4(all_subj_num_voxels[f"subj{s_s}"], all_subj_num_voxels[f"subj{s_t}"], zp_dim, zs_dim, hidden_dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = torch.nn.MSELoss()

    # Identical shuffling order for every pair.
    g.manual_seed(seed)
    for epoch in range(1, epochs + 1):
        # The visible start of train() does not call model.train() before the
        # forward pass; set it explicitly in case test() left the module in eval mode.
        model.train()
        train_loss = train(epoch)
        test_loss, target_brain, recon_brain, label = test(epoch)

    # Only the last epoch's test outputs are evaluated; test() returns per-batch lists.
    target_brain = torch.cat(target_brain, dim=0).cpu().numpy()
    recon_brain = torch.cat(recon_brain, dim=0).cpu().numpy()
    label = torch.cat(label, dim=0).cpu().numpy()

    pattern_corr = calculate_pattern_correlation(target_brain, recon_brain, label)
    pattern_corr_result.extend(
        _correlation_records(pattern_corr, subj_target, subj_source, 'Image_idx', 1)
    )

    profile_corr = calculate_profile_correlation(target_brain, recon_brain, label)
    profile_corr_result.extend(
        _correlation_records(profile_corr, subj_target, subj_source, 'Voxel_idx', 0)
    )

save_result(pattern_corr_result, conversion_output_path, "pattern_correlation_converter4.csv")
save_result(profile_corr_result, conversion_output_path, "profile_correlation_converter4.csv")

100%|██████████| 18/18 [00:00<00:00, 34.51it/s]


Epoch: 1 Train loss: 1.0005


100%|██████████| 10/10 [00:00<00:00, 45.55it/s]


Epoch: 1 Test loss: 0.9850


100%|██████████| 18/18 [00:00<00:00, 28.53it/s]


Epoch: 2 Train loss: 1.0005


100%|██████████| 10/10 [00:00<00:00, 39.37it/s]


Epoch: 2 Test loss: 0.9844


100%|██████████| 18/18 [00:00<00:00, 27.60it/s]


Epoch: 3 Train loss: 0.9976


100%|██████████| 10/10 [00:00<00:00, 47.48it/s]


Epoch: 3 Test loss: 0.9843


100%|██████████| 18/18 [00:00<00:00, 40.39it/s]


Epoch: 4 Train loss: 1.0004


100%|██████████| 10/10 [00:00<00:00, 47.64it/s]


Epoch: 4 Test loss: 0.9842


100%|██████████| 18/18 [00:00<00:00, 41.54it/s]


Epoch: 5 Train loss: 1.0006


100%|██████████| 10/10 [00:00<00:00, 47.71it/s]


Epoch: 5 Test loss: 0.9842


100%|██████████| 18/18 [00:00<00:00, 41.77it/s]


Epoch: 6 Train loss: 1.0052


100%|██████████| 10/10 [00:00<00:00, 44.50it/s]


Epoch: 6 Test loss: 0.9841


100%|██████████| 18/18 [00:00<00:00, 46.07it/s]


Epoch: 7 Train loss: 0.9951


100%|██████████| 10/10 [00:00<00:00, 47.33it/s]


Epoch: 7 Test loss: 0.9841


100%|██████████| 18/18 [00:00<00:00, 38.85it/s]


Epoch: 1 Train loss: 1.0005


100%|██████████| 10/10 [00:00<00:00, 53.88it/s]


Epoch: 1 Test loss: 0.9847


100%|██████████| 18/18 [00:00<00:00, 32.80it/s]


Epoch: 2 Train loss: 1.0005


100%|██████████| 10/10 [00:00<00:00, 54.20it/s]


Epoch: 2 Test loss: 0.9844


100%|██████████| 18/18 [00:00<00:00, 46.66it/s]


Epoch: 3 Train loss: 0.9976


100%|██████████| 10/10 [00:00<00:00, 57.08it/s]


Epoch: 3 Test loss: 0.9842


100%|██████████| 18/18 [00:00<00:00, 45.54it/s]


Epoch: 4 Train loss: 1.0004


100%|██████████| 10/10 [00:00<00:00, 55.12it/s]


Epoch: 4 Test loss: 0.9842


100%|██████████| 18/18 [00:00<00:00, 41.60it/s]


Epoch: 5 Train loss: 1.0006


100%|██████████| 10/10 [00:00<00:00, 50.38it/s]


Epoch: 5 Test loss: 0.9842


100%|██████████| 18/18 [00:00<00:00, 46.73it/s]


Epoch: 6 Train loss: 1.0052


100%|██████████| 10/10 [00:00<00:00, 43.96it/s]


Epoch: 6 Test loss: 0.9841


100%|██████████| 18/18 [00:00<00:00, 34.37it/s]


Epoch: 7 Train loss: 0.9951


100%|██████████| 10/10 [00:00<00:00, 44.97it/s]


Epoch: 7 Test loss: 0.9841


100%|██████████| 18/18 [00:00<00:00, 31.02it/s]


Epoch: 1 Train loss: 1.0005


100%|██████████| 10/10 [00:00<00:00, 53.06it/s]


Epoch: 1 Test loss: 0.9847


100%|██████████| 18/18 [00:00<00:00, 36.99it/s]


Epoch: 2 Train loss: 1.0005


100%|██████████| 10/10 [00:00<00:00, 52.79it/s]


Epoch: 2 Test loss: 0.9844


100%|██████████| 18/18 [00:00<00:00, 47.26it/s]


Epoch: 3 Train loss: 0.9976


100%|██████████| 10/10 [00:00<00:00, 53.81it/s]


Epoch: 3 Test loss: 0.9843


100%|██████████| 18/18 [00:00<00:00, 48.06it/s]


Epoch: 4 Train loss: 1.0004


100%|██████████| 10/10 [00:00<00:00, 56.23it/s]


Epoch: 4 Test loss: 0.9842


100%|██████████| 18/18 [00:00<00:00, 48.80it/s]


Epoch: 5 Train loss: 1.0006


100%|██████████| 10/10 [00:00<00:00, 56.55it/s]


Epoch: 5 Test loss: 0.9842


100%|██████████| 18/18 [00:00<00:00, 45.57it/s]


Epoch: 6 Train loss: 1.0052


100%|██████████| 10/10 [00:00<00:00, 55.63it/s]


Epoch: 6 Test loss: 0.9841


100%|██████████| 18/18 [00:00<00:00, 48.24it/s]


Epoch: 7 Train loss: 0.9951


100%|██████████| 10/10 [00:00<00:00, 48.76it/s]


Epoch: 7 Test loss: 0.9841


100%|██████████| 18/18 [00:00<00:00, 38.51it/s]


Epoch: 1 Train loss: 1.0005


100%|██████████| 10/10 [00:00<00:00, 52.67it/s]


Epoch: 1 Test loss: 0.9849


100%|██████████| 18/18 [00:00<00:00, 36.18it/s]


Epoch: 2 Train loss: 1.0005


100%|██████████| 10/10 [00:00<00:00, 47.67it/s]


Epoch: 2 Test loss: 0.9844


100%|██████████| 18/18 [00:00<00:00, 37.86it/s]


Epoch: 3 Train loss: 0.9976


100%|██████████| 10/10 [00:00<00:00, 51.43it/s]


Epoch: 3 Test loss: 0.9843


100%|██████████| 18/18 [00:00<00:00, 43.55it/s]


Epoch: 4 Train loss: 1.0004


100%|██████████| 10/10 [00:00<00:00, 51.91it/s]


Epoch: 4 Test loss: 0.9842


100%|██████████| 18/18 [00:00<00:00, 18.74it/s]


Epoch: 5 Train loss: 1.0006


100%|██████████| 10/10 [00:00<00:00, 32.23it/s]


Epoch: 5 Test loss: 0.9842


100%|██████████| 18/18 [00:00<00:00, 32.88it/s]


Epoch: 6 Train loss: 1.0052


100%|██████████| 10/10 [00:00<00:00, 30.90it/s]


Epoch: 6 Test loss: 0.9841


100%|██████████| 18/18 [00:00<00:00, 33.24it/s]


Epoch: 7 Train loss: 0.9951


100%|██████████| 10/10 [00:00<00:00, 35.88it/s]


Epoch: 7 Test loss: 0.9841


100%|██████████| 18/18 [00:00<00:00, 27.63it/s]


Epoch: 1 Train loss: 1.0030


100%|██████████| 10/10 [00:00<00:00, 29.54it/s]


Epoch: 1 Test loss: 1.6040


100%|██████████| 18/18 [00:00<00:00, 35.06it/s]


Epoch: 2 Train loss: 0.9987


100%|██████████| 10/10 [00:00<00:00, 47.02it/s]


Epoch: 2 Test loss: 1.6037


100%|██████████| 18/18 [00:00<00:00, 39.72it/s]


Epoch: 3 Train loss: 1.0010


100%|██████████| 10/10 [00:00<00:00, 50.16it/s]


Epoch: 3 Test loss: 1.6035


100%|██████████| 18/18 [00:00<00:00, 39.45it/s]


Epoch: 4 Train loss: 1.0001


100%|██████████| 10/10 [00:00<00:00, 51.75it/s]


Epoch: 4 Test loss: 1.6035


100%|██████████| 18/18 [00:00<00:00, 43.55it/s]


Epoch: 5 Train loss: 1.0020


100%|██████████| 10/10 [00:00<00:00, 51.30it/s]


Epoch: 5 Test loss: 1.6035


100%|██████████| 18/18 [00:00<00:00, 43.91it/s]


Epoch: 6 Train loss: 1.0002


100%|██████████| 10/10 [00:00<00:00, 52.63it/s]


Epoch: 6 Test loss: 1.6034


100%|██████████| 18/18 [00:00<00:00, 46.06it/s]


Epoch: 7 Train loss: 1.0006


100%|██████████| 10/10 [00:00<00:00, 47.91it/s]


Epoch: 7 Test loss: 1.6034


100%|██████████| 18/18 [00:00<00:00, 37.01it/s]


Epoch: 1 Train loss: 1.0031


100%|██████████| 10/10 [00:00<00:00, 43.78it/s]


Epoch: 1 Test loss: 1.6040


100%|██████████| 18/18 [00:00<00:00, 28.82it/s]


Epoch: 2 Train loss: 0.9988


100%|██████████| 10/10 [00:00<00:00, 46.14it/s]


Epoch: 2 Test loss: 1.6037


100%|██████████| 18/18 [00:00<00:00, 38.68it/s]


Epoch: 3 Train loss: 1.0010


100%|██████████| 10/10 [00:00<00:00, 53.75it/s]


Epoch: 3 Test loss: 1.6035


100%|██████████| 18/18 [00:00<00:00, 46.19it/s]


Epoch: 4 Train loss: 1.0001


100%|██████████| 10/10 [00:00<00:00, 55.53it/s]


Epoch: 4 Test loss: 1.6035


100%|██████████| 18/18 [00:00<00:00, 43.85it/s]


Epoch: 5 Train loss: 1.0020


100%|██████████| 10/10 [00:00<00:00, 57.21it/s]


Epoch: 5 Test loss: 1.6035


100%|██████████| 18/18 [00:00<00:00, 35.75it/s]


Epoch: 6 Train loss: 1.0002


100%|██████████| 10/10 [00:00<00:00, 55.27it/s]


Epoch: 6 Test loss: 1.6034


100%|██████████| 18/18 [00:00<00:00, 44.36it/s]


Epoch: 7 Train loss: 1.0006


100%|██████████| 10/10 [00:00<00:00, 50.33it/s]


Epoch: 7 Test loss: 1.6034


100%|██████████| 18/18 [00:00<00:00, 37.29it/s]


Epoch: 1 Train loss: 1.0030


100%|██████████| 10/10 [00:00<00:00, 50.53it/s]


Epoch: 1 Test loss: 1.6040


100%|██████████| 18/18 [00:00<00:00, 31.13it/s]


Epoch: 2 Train loss: 0.9987


100%|██████████| 10/10 [00:00<00:00, 49.87it/s]


Epoch: 2 Test loss: 1.6037


100%|██████████| 18/18 [00:00<00:00, 30.55it/s]


Epoch: 3 Train loss: 1.0010


100%|██████████| 10/10 [00:00<00:00, 51.34it/s]


Epoch: 3 Test loss: 1.6036


100%|██████████| 18/18 [00:00<00:00, 43.13it/s]


Epoch: 4 Train loss: 1.0001


100%|██████████| 10/10 [00:00<00:00, 51.73it/s]


Epoch: 4 Test loss: 1.6035


100%|██████████| 18/18 [00:00<00:00, 36.56it/s]


Epoch: 5 Train loss: 1.0020


100%|██████████| 10/10 [00:00<00:00, 39.37it/s]


Epoch: 5 Test loss: 1.6035


100%|██████████| 18/18 [00:00<00:00, 41.63it/s]


Epoch: 6 Train loss: 1.0002


100%|██████████| 10/10 [00:00<00:00, 49.46it/s]


Epoch: 6 Test loss: 1.6034


100%|██████████| 18/18 [00:00<00:00, 41.81it/s]


Epoch: 7 Train loss: 1.0006


100%|██████████| 10/10 [00:00<00:00, 29.83it/s]


Epoch: 7 Test loss: 1.6034


100%|██████████| 18/18 [00:00<00:00, 29.82it/s]


Epoch: 1 Train loss: 1.0030


100%|██████████| 10/10 [00:00<00:00, 32.97it/s]


Epoch: 1 Test loss: 1.6041


100%|██████████| 18/18 [00:00<00:00, 32.44it/s]


Epoch: 2 Train loss: 0.9987


100%|██████████| 10/10 [00:00<00:00, 44.04it/s]


Epoch: 2 Test loss: 1.6037


100%|██████████| 18/18 [00:00<00:00, 42.19it/s]


Epoch: 3 Train loss: 1.0010


100%|██████████| 10/10 [00:00<00:00, 49.46it/s]


Epoch: 3 Test loss: 1.6035


100%|██████████| 18/18 [00:00<00:00, 42.43it/s]


Epoch: 4 Train loss: 1.0001


100%|██████████| 10/10 [00:00<00:00, 48.61it/s]


Epoch: 4 Test loss: 1.6035


100%|██████████| 18/18 [00:00<00:00, 42.11it/s]


Epoch: 5 Train loss: 1.0020


100%|██████████| 10/10 [00:00<00:00, 55.86it/s]


Epoch: 5 Test loss: 1.6035


100%|██████████| 18/18 [00:00<00:00, 46.05it/s]


Epoch: 6 Train loss: 1.0002


100%|██████████| 10/10 [00:00<00:00, 55.84it/s]


Epoch: 6 Test loss: 1.6034


100%|██████████| 18/18 [00:00<00:00, 46.64it/s]


Epoch: 7 Train loss: 1.0006


100%|██████████| 10/10 [00:00<00:00, 35.58it/s]


Epoch: 7 Test loss: 1.6034


100%|██████████| 18/18 [00:00<00:00, 31.06it/s]


Epoch: 1 Train loss: 1.0011


100%|██████████| 10/10 [00:00<00:00, 48.18it/s]


Epoch: 1 Test loss: 0.8334


100%|██████████| 18/18 [00:00<00:00, 33.54it/s]


Epoch: 2 Train loss: 1.0000


100%|██████████| 10/10 [00:00<00:00, 50.69it/s]


Epoch: 2 Test loss: 0.8330


100%|██████████| 18/18 [00:00<00:00, 42.25it/s]


Epoch: 3 Train loss: 1.0015


100%|██████████| 10/10 [00:00<00:00, 48.65it/s]


Epoch: 3 Test loss: 0.8329


100%|██████████| 18/18 [00:00<00:00, 40.03it/s]


Epoch: 4 Train loss: 1.0022


100%|██████████| 10/10 [00:00<00:00, 18.26it/s]


Epoch: 4 Test loss: 0.8329


100%|██████████| 18/18 [00:00<00:00, 41.93it/s]


Epoch: 5 Train loss: 0.9965


100%|██████████| 10/10 [00:00<00:00, 56.10it/s]


Epoch: 5 Test loss: 0.8328


100%|██████████| 18/18 [00:00<00:00, 45.54it/s]


Epoch: 6 Train loss: 0.9991


100%|██████████| 10/10 [00:00<00:00, 56.47it/s]


Epoch: 6 Test loss: 0.8328


100%|██████████| 18/18 [00:00<00:00, 39.18it/s]


Epoch: 7 Train loss: 0.9984


100%|██████████| 10/10 [00:00<00:00, 50.71it/s]


Epoch: 7 Test loss: 0.8328


100%|██████████| 18/18 [00:00<00:00, 39.12it/s]


Epoch: 1 Train loss: 1.0012


100%|██████████| 10/10 [00:00<00:00, 52.64it/s]


Epoch: 1 Test loss: 0.8336


100%|██████████| 18/18 [00:00<00:00, 41.61it/s]


Epoch: 2 Train loss: 1.0000


100%|██████████| 10/10 [00:00<00:00, 51.07it/s]


Epoch: 2 Test loss: 0.8331


100%|██████████| 18/18 [00:00<00:00, 40.14it/s]


Epoch: 3 Train loss: 1.0015


100%|██████████| 10/10 [00:00<00:00, 52.04it/s]


Epoch: 3 Test loss: 0.8329


100%|██████████| 18/18 [00:00<00:00, 34.68it/s]


Epoch: 4 Train loss: 1.0022


100%|██████████| 10/10 [00:00<00:00, 47.04it/s]


Epoch: 4 Test loss: 0.8329


100%|██████████| 18/18 [00:00<00:00, 33.47it/s]


Epoch: 5 Train loss: 0.9965


100%|██████████| 10/10 [00:00<00:00, 29.84it/s]


Epoch: 5 Test loss: 0.8328


100%|██████████| 18/18 [00:00<00:00, 33.58it/s]


Epoch: 6 Train loss: 0.9991


100%|██████████| 10/10 [00:00<00:00, 24.06it/s]


Epoch: 6 Test loss: 0.8328


100%|██████████| 18/18 [00:00<00:00, 34.18it/s]


Epoch: 7 Train loss: 0.9984


100%|██████████| 10/10 [00:00<00:00, 27.20it/s]


Epoch: 7 Test loss: 0.8328


100%|██████████| 18/18 [00:00<00:00, 30.57it/s]


Epoch: 1 Train loss: 1.0011


100%|██████████| 10/10 [00:00<00:00, 29.21it/s]


Epoch: 1 Test loss: 0.8333


100%|██████████| 18/18 [00:00<00:00, 26.40it/s]


Epoch: 2 Train loss: 1.0000


100%|██████████| 10/10 [00:00<00:00, 39.38it/s]


Epoch: 2 Test loss: 0.8330


100%|██████████| 18/18 [00:00<00:00, 35.43it/s]


Epoch: 3 Train loss: 1.0015


100%|██████████| 10/10 [00:00<00:00, 34.44it/s]


Epoch: 3 Test loss: 0.8329


100%|██████████| 18/18 [00:00<00:00, 43.40it/s]


Epoch: 4 Train loss: 1.0022


100%|██████████| 10/10 [00:00<00:00, 40.15it/s]


Epoch: 4 Test loss: 0.8329


100%|██████████| 18/18 [00:00<00:00, 29.76it/s]


Epoch: 5 Train loss: 0.9965


100%|██████████| 10/10 [00:00<00:00, 42.67it/s]


Epoch: 5 Test loss: 0.8328


100%|██████████| 18/18 [00:00<00:00, 30.21it/s]


Epoch: 6 Train loss: 0.9991


100%|██████████| 10/10 [00:00<00:00, 35.93it/s]


Epoch: 6 Test loss: 0.8328


100%|██████████| 18/18 [00:00<00:00, 38.59it/s]


Epoch: 7 Train loss: 0.9984


100%|██████████| 10/10 [00:00<00:00, 32.83it/s]


Epoch: 7 Test loss: 0.8328


100%|██████████| 18/18 [00:00<00:00, 33.17it/s]


Epoch: 1 Train loss: 1.0012


100%|██████████| 10/10 [00:00<00:00, 39.83it/s]


Epoch: 1 Test loss: 0.8335


100%|██████████| 18/18 [00:00<00:00, 28.92it/s]


Epoch: 2 Train loss: 1.0001


100%|██████████| 10/10 [00:00<00:00, 28.17it/s]


Epoch: 2 Test loss: 0.8331


100%|██████████| 18/18 [00:00<00:00, 30.73it/s]


Epoch: 3 Train loss: 1.0015


100%|██████████| 10/10 [00:00<00:00, 35.38it/s]


Epoch: 3 Test loss: 0.8329


100%|██████████| 18/18 [00:00<00:00, 41.43it/s]


Epoch: 4 Train loss: 1.0022


100%|██████████| 10/10 [00:00<00:00, 54.16it/s]


Epoch: 4 Test loss: 0.8329


100%|██████████| 18/18 [00:00<00:00, 43.60it/s]


Epoch: 5 Train loss: 0.9965


100%|██████████| 10/10 [00:00<00:00, 54.17it/s]


Epoch: 5 Test loss: 0.8328


100%|██████████| 18/18 [00:00<00:00, 44.66it/s]


Epoch: 6 Train loss: 0.9991


100%|██████████| 10/10 [00:00<00:00, 37.35it/s]


Epoch: 6 Test loss: 0.8328


100%|██████████| 18/18 [00:00<00:00, 43.78it/s]


Epoch: 7 Train loss: 0.9984


100%|██████████| 10/10 [00:00<00:00, 49.13it/s]


Epoch: 7 Test loss: 0.8328


100%|██████████| 18/18 [00:00<00:00, 36.74it/s]


Epoch: 1 Train loss: 0.9983


100%|██████████| 10/10 [00:00<00:00, 39.17it/s]


Epoch: 1 Test loss: 0.9134


100%|██████████| 18/18 [00:00<00:00, 27.91it/s]


Epoch: 2 Train loss: 0.9967


100%|██████████| 10/10 [00:00<00:00, 41.34it/s]


Epoch: 2 Test loss: 0.9131


100%|██████████| 18/18 [00:00<00:00, 45.54it/s]


Epoch: 3 Train loss: 1.0018


100%|██████████| 10/10 [00:00<00:00, 51.24it/s]


Epoch: 3 Test loss: 0.9130


100%|██████████| 18/18 [00:00<00:00, 45.23it/s]


Epoch: 4 Train loss: 0.9992


100%|██████████| 10/10 [00:00<00:00, 56.58it/s]


Epoch: 4 Test loss: 0.9129


100%|██████████| 18/18 [00:00<00:00, 46.87it/s]


Epoch: 5 Train loss: 0.9985


100%|██████████| 10/10 [00:00<00:00, 55.55it/s]


Epoch: 5 Test loss: 0.9129


100%|██████████| 18/18 [00:00<00:00, 39.59it/s]


Epoch: 6 Train loss: 0.9997


100%|██████████| 10/10 [00:00<00:00, 48.64it/s]


Epoch: 6 Test loss: 0.9129


100%|██████████| 18/18 [00:00<00:00, 44.92it/s]


Epoch: 7 Train loss: 1.0001


100%|██████████| 10/10 [00:00<00:00, 48.85it/s]


Epoch: 7 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 34.30it/s]


Epoch: 1 Train loss: 0.9984


100%|██████████| 10/10 [00:00<00:00, 50.63it/s]


Epoch: 1 Test loss: 0.9137


100%|██████████| 18/18 [00:00<00:00, 26.54it/s]


Epoch: 2 Train loss: 0.9967


100%|██████████| 10/10 [00:00<00:00, 43.64it/s]


Epoch: 2 Test loss: 0.9132


100%|██████████| 18/18 [00:00<00:00, 20.21it/s]


Epoch: 3 Train loss: 1.0018


100%|██████████| 10/10 [00:00<00:00, 52.51it/s]


Epoch: 3 Test loss: 0.9130


100%|██████████| 18/18 [00:00<00:00, 41.50it/s]


Epoch: 4 Train loss: 0.9992


100%|██████████| 10/10 [00:00<00:00, 51.86it/s]


Epoch: 4 Test loss: 0.9129


100%|██████████| 18/18 [00:00<00:00, 47.77it/s]


Epoch: 5 Train loss: 0.9985


100%|██████████| 10/10 [00:00<00:00, 40.16it/s]


Epoch: 5 Test loss: 0.9129


100%|██████████| 18/18 [00:00<00:00, 37.58it/s]


Epoch: 6 Train loss: 0.9997


100%|██████████| 10/10 [00:00<00:00, 55.33it/s]


Epoch: 6 Test loss: 0.9129


100%|██████████| 18/18 [00:00<00:00, 34.08it/s]


Epoch: 7 Train loss: 1.0001


100%|██████████| 10/10 [00:00<00:00, 47.27it/s]


Epoch: 7 Test loss: 0.9129


100%|██████████| 18/18 [00:00<00:00, 37.03it/s]


Epoch: 1 Train loss: 0.9984


100%|██████████| 10/10 [00:00<00:00, 51.01it/s]


Epoch: 1 Test loss: 0.9134


100%|██████████| 18/18 [00:00<00:00, 35.64it/s]


Epoch: 2 Train loss: 0.9967


100%|██████████| 10/10 [00:00<00:00, 53.79it/s]


Epoch: 2 Test loss: 0.9131


100%|██████████| 18/18 [00:00<00:00, 45.72it/s]


Epoch: 3 Train loss: 1.0018


100%|██████████| 10/10 [00:00<00:00, 54.80it/s]


Epoch: 3 Test loss: 0.9130


100%|██████████| 18/18 [00:00<00:00, 37.29it/s]


Epoch: 4 Train loss: 0.9992


100%|██████████| 10/10 [00:00<00:00, 31.99it/s]


Epoch: 4 Test loss: 0.9129


100%|██████████| 18/18 [00:00<00:00, 36.36it/s]


Epoch: 5 Train loss: 0.9985


100%|██████████| 10/10 [00:00<00:00, 57.19it/s]


Epoch: 5 Test loss: 0.9129


100%|██████████| 18/18 [00:00<00:00, 48.53it/s]


Epoch: 6 Train loss: 0.9997


100%|██████████| 10/10 [00:00<00:00, 59.18it/s]


Epoch: 6 Test loss: 0.9129


100%|██████████| 18/18 [00:00<00:00, 49.71it/s]


Epoch: 7 Train loss: 1.0001


100%|██████████| 10/10 [00:00<00:00, 52.77it/s]


Epoch: 7 Test loss: 0.9129


100%|██████████| 18/18 [00:00<00:00, 32.56it/s]


Epoch: 1 Train loss: 0.9984


100%|██████████| 10/10 [00:00<00:00, 51.16it/s]


Epoch: 1 Test loss: 0.9136


100%|██████████| 18/18 [00:00<00:00, 25.90it/s]


Epoch: 2 Train loss: 0.9967


100%|██████████| 10/10 [00:00<00:00, 31.74it/s]


Epoch: 2 Test loss: 0.9131


100%|██████████| 18/18 [00:00<00:00, 31.89it/s]


Epoch: 3 Train loss: 1.0018


100%|██████████| 10/10 [00:00<00:00, 53.63it/s]


Epoch: 3 Test loss: 0.9130


100%|██████████| 18/18 [00:00<00:00, 43.50it/s]


Epoch: 4 Train loss: 0.9992


100%|██████████| 10/10 [00:00<00:00, 55.75it/s]


Epoch: 4 Test loss: 0.9129


100%|██████████| 18/18 [00:00<00:00, 44.14it/s]


Epoch: 5 Train loss: 0.9985


100%|██████████| 10/10 [00:00<00:00, 51.77it/s]


Epoch: 5 Test loss: 0.9129


100%|██████████| 18/18 [00:00<00:00, 43.61it/s]


Epoch: 6 Train loss: 0.9997


100%|██████████| 10/10 [00:00<00:00, 54.22it/s]


Epoch: 6 Test loss: 0.9129


100%|██████████| 18/18 [00:00<00:00, 36.39it/s]


Epoch: 7 Train loss: 1.0001


100%|██████████| 10/10 [00:00<00:00, 25.23it/s]


Epoch: 7 Test loss: 0.9128


100%|██████████| 18/18 [00:00<00:00, 27.56it/s]


Epoch: 1 Train loss: 1.0000


100%|██████████| 10/10 [00:00<00:00, 38.54it/s]


Epoch: 1 Test loss: 1.2801


100%|██████████| 18/18 [00:00<00:00, 26.69it/s]


Epoch: 2 Train loss: 1.0005


100%|██████████| 10/10 [00:00<00:00, 51.51it/s]


Epoch: 2 Test loss: 1.2798


100%|██████████| 18/18 [00:00<00:00, 48.74it/s]


Epoch: 3 Train loss: 0.9976


100%|██████████| 10/10 [00:00<00:00, 46.98it/s]


Epoch: 3 Test loss: 1.2797


100%|██████████| 18/18 [00:00<00:00, 44.75it/s]


Epoch: 4 Train loss: 1.0037


100%|██████████| 10/10 [00:00<00:00, 54.88it/s]


Epoch: 4 Test loss: 1.2796


100%|██████████| 18/18 [00:00<00:00, 46.46it/s]


Epoch: 5 Train loss: 1.0008


100%|██████████| 10/10 [00:00<00:00, 53.32it/s]


Epoch: 5 Test loss: 1.2796


100%|██████████| 18/18 [00:00<00:00, 37.02it/s]


Epoch: 6 Train loss: 0.9996


100%|██████████| 10/10 [00:00<00:00, 52.91it/s]


Epoch: 6 Test loss: 1.2796


100%|██████████| 18/18 [00:00<00:00, 39.74it/s]


Epoch: 7 Train loss: 0.9973


100%|██████████| 10/10 [00:00<00:00, 34.70it/s]


Epoch: 7 Test loss: 1.2795


100%|██████████| 18/18 [00:00<00:00, 33.36it/s]


Epoch: 1 Train loss: 1.0001


100%|██████████| 10/10 [00:00<00:00, 49.87it/s]


Epoch: 1 Test loss: 1.2804


100%|██████████| 18/18 [00:00<00:00, 33.61it/s]


Epoch: 2 Train loss: 1.0006


100%|██████████| 10/10 [00:00<00:00, 33.21it/s]


Epoch: 2 Test loss: 1.2799


100%|██████████| 18/18 [00:00<00:00, 40.73it/s]


Epoch: 3 Train loss: 0.9976


100%|██████████| 10/10 [00:00<00:00, 52.82it/s]


Epoch: 3 Test loss: 1.2797


100%|██████████| 18/18 [00:00<00:00, 38.43it/s]


Epoch: 4 Train loss: 1.0037


100%|██████████| 10/10 [00:00<00:00, 53.00it/s]


Epoch: 4 Test loss: 1.2796


100%|██████████| 18/18 [00:00<00:00, 45.71it/s]


Epoch: 5 Train loss: 1.0008


100%|██████████| 10/10 [00:00<00:00, 55.18it/s]


Epoch: 5 Test loss: 1.2796


100%|██████████| 18/18 [00:00<00:00, 31.79it/s]


Epoch: 6 Train loss: 0.9996


100%|██████████| 10/10 [00:00<00:00, 40.00it/s]


Epoch: 6 Test loss: 1.2796


100%|██████████| 18/18 [00:00<00:00, 26.84it/s]


Epoch: 7 Train loss: 0.9973


100%|██████████| 10/10 [00:00<00:00, 34.27it/s]


Epoch: 7 Test loss: 1.2795


100%|██████████| 18/18 [00:00<00:00, 26.87it/s]


Epoch: 1 Train loss: 1.0000


100%|██████████| 10/10 [00:00<00:00, 38.69it/s]


Epoch: 1 Test loss: 1.2801


100%|██████████| 18/18 [00:00<00:00, 29.01it/s]


Epoch: 2 Train loss: 1.0006


100%|██████████| 10/10 [00:00<00:00, 16.91it/s]


Epoch: 2 Test loss: 1.2798


100%|██████████| 18/18 [00:00<00:00, 33.47it/s]


Epoch: 3 Train loss: 0.9976


100%|██████████| 10/10 [00:00<00:00, 40.53it/s]


Epoch: 3 Test loss: 1.2797


100%|██████████| 18/18 [00:00<00:00, 33.89it/s]


Epoch: 4 Train loss: 1.0037


100%|██████████| 10/10 [00:00<00:00, 42.49it/s]


Epoch: 4 Test loss: 1.2796


100%|██████████| 18/18 [00:00<00:00, 30.23it/s]


Epoch: 5 Train loss: 1.0008


100%|██████████| 10/10 [00:00<00:00, 39.59it/s]


Epoch: 5 Test loss: 1.2796


100%|██████████| 18/18 [00:00<00:00, 34.87it/s]


Epoch: 6 Train loss: 0.9996


100%|██████████| 10/10 [00:00<00:00, 30.67it/s]


Epoch: 6 Test loss: 1.2796


100%|██████████| 18/18 [00:00<00:00, 30.39it/s]


Epoch: 7 Train loss: 0.9973


100%|██████████| 10/10 [00:00<00:00, 28.22it/s]


Epoch: 7 Test loss: 1.2796


100%|██████████| 18/18 [00:00<00:00, 30.93it/s]


Epoch: 1 Train loss: 1.0000


100%|██████████| 10/10 [00:00<00:00, 38.25it/s]


Epoch: 1 Test loss: 1.2801


100%|██████████| 18/18 [00:00<00:00, 21.18it/s]


Epoch: 2 Train loss: 1.0006


100%|██████████| 10/10 [00:00<00:00, 26.99it/s]


Epoch: 2 Test loss: 1.2798


100%|██████████| 18/18 [00:00<00:00, 34.59it/s]


Epoch: 3 Train loss: 0.9976


100%|██████████| 10/10 [00:00<00:00, 39.20it/s]


Epoch: 3 Test loss: 1.2797


100%|██████████| 18/18 [00:00<00:00, 36.89it/s]


Epoch: 4 Train loss: 1.0037


100%|██████████| 10/10 [00:00<00:00, 37.55it/s]


Epoch: 4 Test loss: 1.2796


100%|██████████| 18/18 [00:00<00:00, 25.58it/s]


Epoch: 5 Train loss: 1.0008


100%|██████████| 10/10 [00:00<00:00, 39.84it/s]


Epoch: 5 Test loss: 1.2796


100%|██████████| 18/18 [00:00<00:00, 33.71it/s]


Epoch: 6 Train loss: 0.9996


100%|██████████| 10/10 [00:00<00:00, 23.41it/s]


Epoch: 6 Test loss: 1.2796


100%|██████████| 18/18 [00:00<00:00, 39.03it/s]


Epoch: 7 Train loss: 0.9973


100%|██████████| 10/10 [00:00<00:00, 48.19it/s]


Epoch: 7 Test loss: 1.2796


### Single subject pair (source → target)

In [None]:
from fmri_reconstruction_with_dmvae.models.converter import Converter

# Hyperparameters for converting one source subject's responses into one
# target subject's voxel space.
lr = 1e-4
model_seed = 42  # seed for torch's global RNG, set right before model construction

subj_target = 1
subj_source = 2

# Zero-padded subject ids, used to build the "subjXX" keys of the data dicts.
s_t = f"{int(subj_target):02d}"
s_s = f"{int(subj_source):02d}"

zp_dim = 256
zs_dim = 768
hidden_dim = 1024

# Seed torch's global RNG so the model's initial weights are reproducible
# (assuming Converter uses standard torch layer initialisation — confirm).
# The training cell only seeds the DataLoader generator `g`, not this RNG.
torch.manual_seed(model_seed)
model = Converter(all_subj_num_voxels[f"subj{s_s}"], all_subj_num_voxels[f"subj{s_t}"], zp_dim, zs_dim, hidden_dim).to(device)

optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = torch.nn.MSELoss()

In [None]:
def train(epoch):
    """Run one training epoch of the source -> target converter.

    Reads the notebook-level globals ``model``, ``optimizer``, ``loss_fn``,
    ``train_dl``, ``device`` and the subject ids ``s_s`` / ``s_t``.

    Args:
        epoch: Epoch number, used only for the progress printout.

    Returns:
        float: Sample-weighted mean of the per-batch loss over the epoch.
    """
    # Explicitly enter training mode: the evaluation code calls model.eval(),
    # so do not rely on whatever mode the model was left in.
    model.train()

    train_loss = 0
    total_samples = 0

    for data in tqdm(train_dl):
        optimizer.zero_grad()

        source_batch = data[f"subj{s_s}"].to(torch.float32).to(device)
        target_batch = data[f"subj{s_t}"].to(torch.float32).to(device)
        recon_batch = model(source_batch)

        loss = loss_fn(recon_batch, target_batch)

        loss.backward()
        optimizer.step()

        # loss_fn (MSELoss, mean reduction) returns a per-batch mean; weight it
        # by batch size so the epoch average is per-sample.
        train_loss += loss.item() * recon_batch.size(0)
        total_samples += recon_batch.size(0)

    train_loss = train_loss / total_samples

    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss

In [None]:
def test(epoch):
    model.eval()
    test_loss = 0
    total_samples = 0

    target_brain = []
    recon_brain = []
    label = []

    with torch.no_grad():
        for data in tqdm(test_dl):
            source_batch = data[f"subj{s_s}"].to(torch.float32).to(device)
            target_batch = data[f"subj{s_t}"].to(torch.float32).to(device)
            recon_batch = model(source_batch)

            loss = loss_fn(recon_batch, target_batch)

            test_loss += loss.item() * recon_batch.size(0)
            total_samples += recon_batch.size(0)

            if epoch == epochs:
                target_brain.append(target_batch.cpu())
                recon_brain.append(recon_batch.cpu())
                label.append(data["image_index"].cpu())
            
    test_loss = test_loss / total_samples

    print('Epoch: {} Test loss: {:.4f}'.format(epoch, test_loss))
    return test_loss, target_brain, recon_brain, label

In [70]:
# Train for a fixed number of epochs. Reseeding the DataLoader generator makes
# the shuffle order reproducible each time this cell is re-run.
epochs = 10
g.manual_seed(42)

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    # test() collects per-batch outputs only when epoch == epochs, so the lists
    # returned on the last iteration cover the whole test set.
    test_loss, target_brain, recon_brain, label = test(epoch)

# Stack the per-batch tensors into single numpy arrays along the sample axis.
target_brain, recon_brain, label = [
    torch.cat(batches, dim=0).cpu().numpy()
    for batches in (target_brain, recon_brain, label)
]

  0%|          | 0/18 [00:00<?, ?it/s]

100%|██████████| 18/18 [00:00<00:00, 24.00it/s]


Epoch: 1 Train loss: 114.3821


100%|██████████| 10/10 [00:00<00:00, 29.89it/s]


Epoch: 1 Test loss: 223.3641


100%|██████████| 18/18 [00:00<00:00, 38.98it/s]


Epoch: 2 Train loss: 110.4861


100%|██████████| 10/10 [00:00<00:00, 30.62it/s]


Epoch: 2 Test loss: 229.5863


100%|██████████| 18/18 [00:00<00:00, 25.80it/s]


Epoch: 3 Train loss: 104.1034


100%|██████████| 10/10 [00:00<00:00, 41.64it/s]


Epoch: 3 Test loss: 221.7700


100%|██████████| 18/18 [00:00<00:00, 34.47it/s]


Epoch: 4 Train loss: 100.7712


100%|██████████| 10/10 [00:00<00:00, 35.68it/s]


Epoch: 4 Test loss: 222.9147


100%|██████████| 18/18 [00:00<00:00, 27.56it/s]


Epoch: 5 Train loss: 98.7620


100%|██████████| 10/10 [00:00<00:00, 47.02it/s]


Epoch: 5 Test loss: 222.6112


100%|██████████| 18/18 [00:00<00:00, 39.83it/s]


Epoch: 6 Train loss: 97.5344


100%|██████████| 10/10 [00:00<00:00, 34.71it/s]


Epoch: 6 Test loss: 222.7832


100%|██████████| 18/18 [00:00<00:00, 30.84it/s]


Epoch: 7 Train loss: 95.3180


100%|██████████| 10/10 [00:00<00:00, 48.16it/s]


Epoch: 7 Test loss: 223.3032


100%|██████████| 18/18 [00:00<00:00, 33.00it/s]


Epoch: 8 Train loss: 94.8216


100%|██████████| 10/10 [00:00<00:00, 48.53it/s]


Epoch: 8 Test loss: 223.6443


100%|██████████| 18/18 [00:00<00:00, 31.07it/s]


Epoch: 9 Train loss: 93.3659


100%|██████████| 10/10 [00:00<00:00, 48.21it/s]


Epoch: 9 Test loss: 223.7538


100%|██████████| 18/18 [00:00<00:00, 32.30it/s]


Epoch: 10 Train loss: 92.0130


100%|██████████| 10/10 [00:00<00:00, 43.51it/s]

Epoch: 10 Test loss: 224.2601





In [71]:
def calculate_pattern_correlation(target_brain, recon_brain, label, rep=24):
    """Per-image pattern correlation between measured and converted responses.

    For every unique image label, the ``rep`` target rows and the ``rep``
    reconstructed rows of that image are correlated row-against-row
    (``np.corrcoef`` treats each row as a variable observed across the voxel
    columns). This gives a ``rep x rep`` target-vs-recon matrix, and the
    image's score is the mean of its upper triangle, diagonal included.

    Args:
        target_brain: 2-D array with one row per sample (columns are voxels).
        recon_brain: Array of the same shape as ``target_brain``.
        label: Image label per sample; any shape that flattens to
            ``n_samples`` (e.g. ``(n,)`` or ``(n, 1)``).
        rep: Number of samples each image is expected to have.

    Returns:
        list: One correlation per unique label, in ascending label order.

    Raises:
        ValueError: If some image does not have exactly ``rep`` samples. Such
            an image would otherwise silently mix target and reconstruction
            rows in the ``[:rep, rep:]`` slice.
    """
    label = np.asarray(label).flatten()
    # A stable sort keeps the original sample order within each image. The
    # averaged triangle of the (non-symmetric) cross-correlation matrix depends
    # on that order, so it must not depend on how quicksort breaks ties.
    sort_idx = np.argsort(label, kind="stable")
    target_brain = target_brain[sort_idx]
    recon_brain = recon_brain[sort_idx]
    label = label[sort_idx]
    unique_label, counts = np.unique(label, return_counts=True)

    bad_labels = unique_label[counts != rep]
    if bad_labels.size:
        raise ValueError(
            f"expected {rep} samples per image, but labels {bad_labels.tolist()} "
            f"have counts {counts[counts != rep].tolist()}"
        )

    # NOTE(review): corrs is target-rep x recon-rep and not symmetric, so the
    # upper triangle drops the (target_j, recon_i) pairs with j > i — confirm
    # this (rather than the full-matrix mean) is the intended metric.
    upper = np.triu_indices(rep, k=0)  # including the diagonal
    pattern_corr = []
    for image_idx in unique_label:
        mask = label == image_idx
        corrs = np.corrcoef(target_brain[mask], recon_brain[mask])[:rep, rep:]
        pattern_corr.append(np.mean(corrs[upper]))

    return pattern_corr

In [72]:
def calculate_profile_correlation(target_brain, recon_brain, label, rep=24):
    """Per-voxel profile correlation between measured and converted responses.

    Samples are grouped by image label using a stable sort. For each voxel,
    the sorted column is reshaped to ``(rep, n_images)`` in Fortran order, so
    row ``r`` holds that voxel's response to every image at the image's
    ``r``-th occurrence (in original sample order). Target and reconstruction
    rows are correlated across images, which gives a ``rep x rep`` matrix. The
    upper triangle of that matrix (diagonal included) is averaged.

    Args:
        target_brain: 2-D array with one row per sample and one column per voxel.
        recon_brain: Array of the same shape as ``target_brain``.
        label: Image label per sample; any shape that flattens to
            ``n_samples``.
        rep: Number of samples each image is expected to have.

    Returns:
        list: One correlation per voxel (column), in column order.

    Raises:
        ValueError: If some image does not have exactly ``rep`` samples. The
            Fortran-order reshape would otherwise misalign images silently.
    """
    label = np.asarray(label).flatten()
    # A stable sort keeps each image's samples in their original order, so
    # row r pairs the same occurrence index across all images. Ties are then
    # not broken arbitrarily.
    sort_idx = np.argsort(label, kind="stable")
    target_brain = target_brain[sort_idx]
    recon_brain = recon_brain[sort_idx]

    unique_label, counts = np.unique(label, return_counts=True)
    bad_labels = unique_label[counts != rep]
    if bad_labels.size:
        raise ValueError(
            f"expected {rep} samples per image, but labels {bad_labels.tolist()} "
            f"have counts {counts[counts != rep].tolist()}"
        )

    upper = np.triu_indices(rep, k=0)  # including the diagonal
    profile_corr = []
    for voxel_idx in range(target_brain.shape[1]):
        # Column j of the reshaped matrix is image j's block of `rep` samples.
        target_profile = target_brain[:, voxel_idx].reshape(rep, -1, order="F")
        recon_profile = recon_brain[:, voxel_idx].reshape(rep, -1, order="F")

        corrs = np.corrcoef(target_profile, recon_profile)[:rep, rep:]
        profile_corr.append(np.mean(corrs[upper]))

    return profile_corr

In [73]:
def save_result(result_data, output_dir, output_filename):
    """
    Save the results to a CSV file.

    Args:
        result_data: Anything accepted by ``pd.DataFrame`` (in this notebook,
            a list of dicts with one dict per row).
        output_dir: Directory to write into; it is created (with parents) if
            missing. An empty string means the current working directory.
        output_filename: Name of the CSV file inside ``output_dir``.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    # os.makedirs("") would raise, so skip it for the current directory.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    df = pd.DataFrame(result_data)
    df.to_csv(os.path.join(output_dir, output_filename), index=False)

In [76]:
conversion_output_path = "/home/acg17270jl/projects/fmri-reconstruction-with-dmvae/results/deeprecon/conversion/"

# Pattern correlation: one CSV row per image.
# NOTE(review): Image_idx is the 1-based position among the sorted unique
# labels. It equals the label value only if labels are exactly 1..n_images.
pattern_corr = calculate_pattern_correlation(target_brain, recon_brain, label)
pattern_corr_result = [
    {
        'Subject_target': subj_target,
        'Subject_source': subj_source,
        'Correlation': corr,
        'Image_idx': image_pos,
    }
    for image_pos, corr in enumerate(pattern_corr, start=1)
]

# Profile correlation: one CSV row per voxel (0-based column index).
profile_corr = calculate_profile_correlation(target_brain, recon_brain, label)
profile_corr_result = [
    {
        'Subject_target': subj_target,
        'Subject_source': subj_source,
        'Correlation': corr,
        'Voxel_idx': voxel_idx,
    }
    for voxel_idx, corr in enumerate(profile_corr)
]

save_result(pattern_corr_result, conversion_output_path, "pattern_correlation_converter3.csv")
save_result(profile_corr_result, conversion_output_path, "profile_correlation_converter3.csv")