In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.transforms import transforms
from torch.utils.data import Dataset
from typing import List, Tuple
import numpy as np

import os
from sklearn.mixture import GaussianMixture
from tqdm import tqdm
import matplotlib.pyplot as plt
import json
import sys

sys.path.append('..')
from utils import BrainGraphDataset, project_root
from models import VAE
from torch.utils.data import ConcatDataset

In [6]:
# Pick the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Seed torch's global RNG so the random train/val split below is reproducible.
torch.manual_seed(0)

# ---- hyperparameters ----
# 4950 == 100 * 99 / 2: presumably the strict upper triangle of a 100x100 FC matrix,
# matching setting='upper_triangular' below — TODO confirm against utils.BrainGraphDataset.
input_dim = 4950
hidden_dim = 128
latent_dim = 64
lr = 1e-3
batch_size = 128
num_epochs = 200
root = project_root()

# ---- data: the "before" and "after" psilocybin FC-matrix directories, pooled ----
# The loop variables `annotations` / `dataroot` deliberately stay bound to the last
# ("after") pair once the loop finishes; later cells read them from the kernel.
psilo_datasets = []
for annotations, dataroot in [
    ('annotations-before.csv', 'fc_matrices/psilo_ica_100_before/'),
    ('annotations-after.csv', 'fc_matrices/psilo_ica_100_after/'),
]:
    psilo_datasets.append(
        BrainGraphDataset(
            img_dir=os.path.join(root, dataroot),
            annotations_file=os.path.join(root, annotations),
            transform=None,
            extra_data=None,
            setting='upper_triangular',
        )
    )
before_dataset, after_dataset = psilo_datasets
dataset = ConcatDataset([before_dataset, after_dataset])

# ---- 80/20 random train/validation split ----
num_samples = len(dataset)
train_size = int(num_samples * 0.8)
val_size = num_samples - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Only the training loader is shuffled; validation order is fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# ---- Training: VAE with a GMM term, keeping the weights of the best validation epoch ----
best_val_loss = float('inf')  # lowest per-sample validation MSE seen so far
best_model_state = None       # detached copy of the weights from that epoch

# Number of mixture components passed to model.loss for its GMM term.
n_gmm_components = 3

train_losses = []
val_losses = []
model = VAE(input_dim, [hidden_dim] * 2, latent_dim).to(device)  # move model to device
optimizer = optim.Adam(model.parameters(), lr=lr)

for epoch in tqdm(range(num_epochs)):
    train_loss = 0.0
    val_loss = 0.0

    # training: optimise MSE + GMM term, but only the MSE part is logged/compared
    model.train()
    for data, _ in train_loader:
        data = data.to(device)  # move data to device
        flat = data.view(-1, input_dim)
        optimizer.zero_grad()

        recon, mu, logvar, z = model(flat)
        # l2_reg is returned by model.loss but is deliberately not part of the objective here
        mse_loss, gmm_loss, l2_reg = model.loss(recon, flat, mu, logvar, n_components=n_gmm_components)
        loss = mse_loss + gmm_loss
        loss.backward()
        optimizer.step()
        train_loss += mse_loss.item()

    # validation
    model.eval()
    with torch.no_grad():
        for data, _ in val_loader:
            data = data.to(device)  # move data to device
            flat = data.view(-1, input_dim)
            recon, mu, logvar, z = model(flat)
            mse_loss, gmm_loss, l2_reg = model.loss(recon, flat, mu, logvar, n_components=n_gmm_components)
            val_loss += mse_loss.item()

    # Normalise by sample count.
    # NOTE(review): assumes model.loss returns a per-batch *sum*, not a mean — TODO confirm.
    train_losses.append(train_loss / len(train_dataset))
    val_losses.append(val_loss / len(val_dataset))

    # Snapshot the weights when validation loss hits a new minimum. model.state_dict()
    # returns references to the live parameter tensors, which keep changing as training
    # continues, so they must be cloned or the "best" checkpoint ends up being the last one.
    if val_losses[-1] < best_val_loss:
        best_val_loss = val_losses[-1]
        best_model_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}

    # tqdm.write keeps the progress bar intact instead of interleaving raw prints.
    tqdm.write(f'Epoch {epoch+1}/{num_epochs} - Train Loss: {train_losses[-1]:.4f} - Val Loss: {val_losses[-1]:.4f}')

# save the loss curves first so they survive even if no checkpoint can be written
loss_curves = {"train_loss": train_losses, "val_loss": val_losses}
loss_curves_dir = os.path.join(root, 'loss_curves')
os.makedirs(loss_curves_dir, exist_ok=True)
with open(os.path.join(loss_curves_dir, "loss_curves-vae.json"), "w") as f:
    json.dump(loss_curves, f)

# save the best model for this configuration
if best_model_state is None:
    # Only happens when every validation loss was NaN/inf (or num_epochs == 0);
    # previously this silently saved `None` as the checkpoint.
    raise RuntimeError('No finite validation loss was recorded; not saving a checkpoint.')
os.makedirs('vgae_weights', exist_ok=True)
torch.save(best_model_state, 'vgae_weights/vae_best.pt')

cuda


  0%|          | 1/200 [00:03<10:47,  3.25s/it]

Epoch 1/200 - Train Loss: 378.7067 - Val Loss: 330.3306



  1%|          | 2/200 [00:05<08:59,  2.72s/it]

Epoch 2/200 - Train Loss: 348.2430 - Val Loss: 310.8991



  2%|▏         | 3/200 [00:07<08:22,  2.55s/it]

Epoch 3/200 - Train Loss: 330.3147 - Val Loss: 301.2161



  2%|▏         | 4/200 [00:10<08:03,  2.47s/it]

Epoch 4/200 - Train Loss: 318.5898 - Val Loss: 292.7309



  2%|▎         | 5/200 [00:12<07:53,  2.43s/it]

Epoch 5/200 - Train Loss: 311.6522 - Val Loss: 285.6180



  3%|▎         | 6/200 [00:14<07:45,  2.40s/it]

Epoch 6/200 - Train Loss: 305.1527 - Val Loss: 279.1438



  4%|▎         | 7/200 [00:17<07:40,  2.38s/it]

Epoch 7/200 - Train Loss: 297.9146 - Val Loss: 270.3188



  4%|▍         | 8/200 [00:19<07:35,  2.37s/it]

Epoch 8/200 - Train Loss: 287.0884 - Val Loss: 255.8003



  4%|▍         | 9/200 [00:22<07:31,  2.37s/it]

Epoch 9/200 - Train Loss: 274.7560 - Val Loss: 245.2633



  5%|▌         | 10/200 [00:24<07:28,  2.36s/it]

Epoch 10/200 - Train Loss: 264.1646 - Val Loss: 238.1963



  6%|▌         | 11/200 [00:26<07:25,  2.36s/it]

Epoch 11/200 - Train Loss: 254.4588 - Val Loss: 225.4854



  6%|▌         | 12/200 [00:29<07:22,  2.35s/it]

Epoch 12/200 - Train Loss: 243.7744 - Val Loss: 217.4939



  6%|▋         | 13/200 [00:31<07:19,  2.35s/it]

Epoch 13/200 - Train Loss: 232.7893 - Val Loss: 208.6123



  7%|▋         | 14/200 [00:33<07:17,  2.35s/it]

Epoch 14/200 - Train Loss: 224.8804 - Val Loss: 200.1286



  8%|▊         | 15/200 [00:36<07:14,  2.35s/it]

Epoch 15/200 - Train Loss: 217.2894 - Val Loss: 194.7773



  8%|▊         | 16/200 [00:38<07:12,  2.35s/it]

Epoch 16/200 - Train Loss: 211.3931 - Val Loss: 189.9685



  8%|▊         | 17/200 [00:40<07:09,  2.35s/it]

Epoch 17/200 - Train Loss: 205.4159 - Val Loss: 183.4078



  9%|▉         | 18/200 [00:43<07:06,  2.35s/it]

Epoch 18/200 - Train Loss: 199.5382 - Val Loss: 179.6830



 10%|▉         | 19/200 [00:45<07:04,  2.34s/it]

Epoch 19/200 - Train Loss: 195.2032 - Val Loss: 177.4412



 10%|█         | 20/200 [00:47<07:02,  2.34s/it]

Epoch 20/200 - Train Loss: 191.9385 - Val Loss: 174.5671



 10%|█         | 21/200 [00:50<07:00,  2.35s/it]

Epoch 21/200 - Train Loss: 188.5853 - Val Loss: 171.7923



 11%|█         | 22/200 [00:52<06:57,  2.35s/it]

Epoch 22/200 - Train Loss: 186.3854 - Val Loss: 169.3657



 12%|█▏        | 23/200 [00:54<06:55,  2.35s/it]

Epoch 23/200 - Train Loss: 184.4509 - Val Loss: 167.4625



 12%|█▏        | 24/200 [00:57<06:52,  2.35s/it]

Epoch 24/200 - Train Loss: 182.5862 - Val Loss: 166.9327



 12%|█▎        | 25/200 [00:59<06:50,  2.35s/it]

Epoch 25/200 - Train Loss: 181.6237 - Val Loss: 166.0001



 13%|█▎        | 26/200 [01:01<06:48,  2.35s/it]

Epoch 26/200 - Train Loss: 180.2914 - Val Loss: 165.1861



 14%|█▎        | 27/200 [01:04<06:45,  2.35s/it]

Epoch 27/200 - Train Loss: 179.2685 - Val Loss: 164.3412



 14%|█▍        | 28/200 [01:06<06:43,  2.35s/it]

Epoch 28/200 - Train Loss: 178.4888 - Val Loss: 163.7707



 14%|█▍        | 29/200 [01:08<06:41,  2.35s/it]

Epoch 29/200 - Train Loss: 177.7969 - Val Loss: 163.0696



 15%|█▌        | 30/200 [01:11<06:39,  2.35s/it]

Epoch 30/200 - Train Loss: 177.1544 - Val Loss: 162.7588



 16%|█▌        | 31/200 [01:13<06:36,  2.35s/it]

Epoch 31/200 - Train Loss: 176.5071 - Val Loss: 161.9033



 16%|█▌        | 32/200 [01:16<06:34,  2.35s/it]

Epoch 32/200 - Train Loss: 176.0720 - Val Loss: 161.4921



 16%|█▋        | 33/200 [01:18<06:32,  2.35s/it]

Epoch 33/200 - Train Loss: 175.4700 - Val Loss: 161.1956



 17%|█▋        | 34/200 [01:20<06:29,  2.35s/it]

Epoch 34/200 - Train Loss: 175.0282 - Val Loss: 160.7242



 18%|█▊        | 35/200 [01:23<06:27,  2.35s/it]

Epoch 35/200 - Train Loss: 174.4862 - Val Loss: 160.5048



 18%|█▊        | 36/200 [01:25<06:25,  2.35s/it]

Epoch 36/200 - Train Loss: 174.2241 - Val Loss: 159.9672



 18%|█▊        | 37/200 [01:27<06:22,  2.35s/it]

Epoch 37/200 - Train Loss: 173.6045 - Val Loss: 159.8565



 19%|█▉        | 38/200 [01:30<06:20,  2.35s/it]

Epoch 38/200 - Train Loss: 173.3790 - Val Loss: 159.2983



 20%|█▉        | 39/200 [01:32<06:18,  2.35s/it]

Epoch 39/200 - Train Loss: 172.8242 - Val Loss: 159.0558



 20%|██        | 40/200 [01:34<06:15,  2.35s/it]

Epoch 40/200 - Train Loss: 172.5081 - Val Loss: 158.4878



 20%|██        | 41/200 [01:37<06:13,  2.35s/it]

Epoch 41/200 - Train Loss: 172.0042 - Val Loss: 158.2834



 21%|██        | 42/200 [01:39<06:10,  2.35s/it]

Epoch 42/200 - Train Loss: 171.5522 - Val Loss: 157.8813



 22%|██▏       | 43/200 [01:41<06:08,  2.35s/it]

Epoch 43/200 - Train Loss: 171.0885 - Val Loss: 157.4772



 22%|██▏       | 44/200 [01:44<06:06,  2.35s/it]

Epoch 44/200 - Train Loss: 170.5117 - Val Loss: 157.0290



 22%|██▎       | 45/200 [01:46<06:03,  2.35s/it]

Epoch 45/200 - Train Loss: 169.9094 - Val Loss: 156.4506



 23%|██▎       | 46/200 [01:48<06:01,  2.35s/it]

Epoch 46/200 - Train Loss: 169.3016 - Val Loss: 156.0271



 24%|██▎       | 47/200 [01:51<05:58,  2.35s/it]

Epoch 47/200 - Train Loss: 168.5689 - Val Loss: 155.3836



 24%|██▍       | 48/200 [01:53<05:56,  2.35s/it]

Epoch 48/200 - Train Loss: 167.7355 - Val Loss: 154.7198



 24%|██▍       | 49/200 [01:55<05:54,  2.35s/it]

Epoch 49/200 - Train Loss: 166.8918 - Val Loss: 154.0389



 25%|██▌       | 50/200 [01:58<05:51,  2.35s/it]

Epoch 50/200 - Train Loss: 165.9451 - Val Loss: 153.2040



 26%|██▌       | 51/200 [02:00<05:49,  2.35s/it]

Epoch 51/200 - Train Loss: 164.9526 - Val Loss: 152.4620



 26%|██▌       | 52/200 [02:02<05:47,  2.35s/it]

Epoch 52/200 - Train Loss: 163.8668 - Val Loss: 151.6946



 26%|██▋       | 53/200 [02:05<05:45,  2.35s/it]

Epoch 53/200 - Train Loss: 162.7961 - Val Loss: 150.9216



 27%|██▋       | 54/200 [02:07<05:43,  2.35s/it]

Epoch 54/200 - Train Loss: 161.6985 - Val Loss: 150.0893



 28%|██▊       | 55/200 [02:10<05:40,  2.35s/it]

Epoch 55/200 - Train Loss: 160.5231 - Val Loss: 149.2893



 28%|██▊       | 56/200 [02:12<05:38,  2.35s/it]

Epoch 56/200 - Train Loss: 159.4032 - Val Loss: 148.3340



 28%|██▊       | 57/200 [02:14<05:36,  2.35s/it]

Epoch 57/200 - Train Loss: 158.2311 - Val Loss: 147.5503



 29%|██▉       | 58/200 [02:17<05:34,  2.35s/it]

Epoch 58/200 - Train Loss: 157.1123 - Val Loss: 146.7922



 30%|██▉       | 59/200 [02:19<05:31,  2.35s/it]

Epoch 59/200 - Train Loss: 156.0518 - Val Loss: 145.9505



 30%|███       | 60/200 [02:21<05:29,  2.35s/it]

Epoch 60/200 - Train Loss: 154.9775 - Val Loss: 145.2609



 30%|███       | 61/200 [02:24<05:26,  2.35s/it]

Epoch 61/200 - Train Loss: 154.0183 - Val Loss: 144.5611



 31%|███       | 62/200 [02:26<05:24,  2.35s/it]

Epoch 62/200 - Train Loss: 153.0334 - Val Loss: 143.9107



 32%|███▏      | 63/200 [02:28<05:22,  2.35s/it]

Epoch 63/200 - Train Loss: 152.1212 - Val Loss: 143.2952



 32%|███▏      | 64/200 [02:31<05:19,  2.35s/it]

Epoch 64/200 - Train Loss: 151.2589 - Val Loss: 142.7145



 32%|███▎      | 65/200 [02:33<05:17,  2.35s/it]

Epoch 65/200 - Train Loss: 150.3765 - Val Loss: 142.1269



 33%|███▎      | 66/200 [02:35<05:14,  2.35s/it]

Epoch 66/200 - Train Loss: 149.5372 - Val Loss: 141.5546



 34%|███▎      | 67/200 [02:38<05:12,  2.35s/it]

Epoch 67/200 - Train Loss: 148.6915 - Val Loss: 140.9922



 34%|███▍      | 68/200 [02:40<05:09,  2.35s/it]

Epoch 68/200 - Train Loss: 147.8916 - Val Loss: 140.6727



 34%|███▍      | 69/200 [02:42<05:07,  2.35s/it]

Epoch 69/200 - Train Loss: 147.1964 - Val Loss: 140.9095



 35%|███▌      | 70/200 [02:45<05:05,  2.35s/it]

Epoch 70/200 - Train Loss: 147.2756 - Val Loss: 142.6289



 36%|███▌      | 71/200 [02:47<05:03,  2.35s/it]

Epoch 71/200 - Train Loss: 148.3777 - Val Loss: 139.3871



 36%|███▌      | 72/200 [02:49<05:00,  2.35s/it]

Epoch 72/200 - Train Loss: 144.8607 - Val Loss: 140.1113



 36%|███▋      | 73/200 [02:52<04:58,  2.35s/it]

Epoch 73/200 - Train Loss: 145.2918 - Val Loss: 139.5713



 37%|███▋      | 74/200 [02:54<04:55,  2.35s/it]

Epoch 74/200 - Train Loss: 144.2114 - Val Loss: 138.6219



 38%|███▊      | 75/200 [02:57<04:53,  2.35s/it]

Epoch 75/200 - Train Loss: 142.7769 - Val Loss: 138.6236



 38%|███▊      | 76/200 [02:59<04:50,  2.35s/it]

Epoch 76/200 - Train Loss: 142.4269 - Val Loss: 137.4442



 38%|███▊      | 77/200 [03:01<04:48,  2.35s/it]

Epoch 77/200 - Train Loss: 140.7774 - Val Loss: 137.7500



 39%|███▉      | 78/200 [03:04<04:45,  2.34s/it]

Epoch 78/200 - Train Loss: 140.6085 - Val Loss: 136.9438



 40%|███▉      | 79/200 [03:06<04:43,  2.34s/it]

Epoch 79/200 - Train Loss: 139.2324 - Val Loss: 136.8599



 40%|████      | 80/200 [03:08<04:41,  2.34s/it]

Epoch 80/200 - Train Loss: 138.7132 - Val Loss: 136.3385



 40%|████      | 81/200 [03:11<04:39,  2.35s/it]

Epoch 81/200 - Train Loss: 137.6716 - Val Loss: 136.1496



 41%|████      | 82/200 [03:13<04:36,  2.34s/it]

Epoch 82/200 - Train Loss: 136.9748 - Val Loss: 135.9375



 42%|████▏     | 83/200 [03:15<04:34,  2.35s/it]

Epoch 83/200 - Train Loss: 136.1899 - Val Loss: 135.6158



 42%|████▏     | 84/200 [03:18<04:32,  2.35s/it]

Epoch 84/200 - Train Loss: 135.3116 - Val Loss: 135.6261



 42%|████▎     | 85/200 [03:20<04:29,  2.35s/it]

Epoch 85/200 - Train Loss: 134.7654 - Val Loss: 135.1789



 43%|████▎     | 86/200 [03:22<04:27,  2.35s/it]

Epoch 86/200 - Train Loss: 133.7726 - Val Loss: 135.3178



 44%|████▎     | 87/200 [03:25<04:25,  2.35s/it]

Epoch 87/200 - Train Loss: 133.3599 - Val Loss: 134.8001



 44%|████▍     | 88/200 [03:27<04:22,  2.35s/it]

Epoch 88/200 - Train Loss: 132.2467 - Val Loss: 135.1539



 44%|████▍     | 89/200 [03:29<04:20,  2.35s/it]

Epoch 89/200 - Train Loss: 131.9222 - Val Loss: 134.6382



 45%|████▌     | 90/200 [03:32<04:17,  2.34s/it]

Epoch 90/200 - Train Loss: 130.8397 - Val Loss: 134.9245



 46%|████▌     | 91/200 [03:34<04:15,  2.34s/it]

Epoch 91/200 - Train Loss: 130.5048 - Val Loss: 134.3608



 46%|████▌     | 92/200 [03:36<04:13,  2.34s/it]

Epoch 92/200 - Train Loss: 129.5072 - Val Loss: 134.5719



 46%|████▋     | 93/200 [03:39<04:10,  2.34s/it]

Epoch 93/200 - Train Loss: 129.0709 - Val Loss: 134.2677



 47%|████▋     | 94/200 [03:41<04:08,  2.35s/it]

Epoch 94/200 - Train Loss: 128.1418 - Val Loss: 134.3256



 48%|████▊     | 95/200 [03:43<04:06,  2.35s/it]

Epoch 95/200 - Train Loss: 127.6339 - Val Loss: 134.0571



 48%|████▊     | 96/200 [03:46<04:04,  2.35s/it]

Epoch 96/200 - Train Loss: 126.7982 - Val Loss: 134.0593



 48%|████▊     | 97/200 [03:48<04:02,  2.35s/it]

Epoch 97/200 - Train Loss: 126.2150 - Val Loss: 133.8424



 49%|████▉     | 98/200 [03:50<03:59,  2.35s/it]

Epoch 98/200 - Train Loss: 125.4562 - Val Loss: 133.9347



 50%|████▉     | 99/200 [03:53<03:57,  2.35s/it]

Epoch 99/200 - Train Loss: 124.8717 - Val Loss: 133.7547



 50%|█████     | 100/200 [03:55<03:55,  2.35s/it]

Epoch 100/200 - Train Loss: 124.1707 - Val Loss: 133.8988



 50%|█████     | 101/200 [03:58<03:52,  2.35s/it]

Epoch 101/200 - Train Loss: 123.5622 - Val Loss: 133.5489



 51%|█████     | 102/200 [04:00<03:50,  2.35s/it]

Epoch 102/200 - Train Loss: 122.7882 - Val Loss: 133.5772



 52%|█████▏    | 103/200 [04:02<03:48,  2.35s/it]

Epoch 103/200 - Train Loss: 122.0926 - Val Loss: 133.3012



 52%|█████▏    | 104/200 [04:05<03:45,  2.35s/it]

Epoch 104/200 - Train Loss: 121.3061 - Val Loss: 133.3160



 52%|█████▎    | 105/200 [04:07<03:43,  2.35s/it]

Epoch 105/200 - Train Loss: 120.6481 - Val Loss: 133.2493



 53%|█████▎    | 106/200 [04:09<03:41,  2.35s/it]

Epoch 106/200 - Train Loss: 119.9861 - Val Loss: 133.0660



 54%|█████▎    | 107/200 [04:12<03:38,  2.35s/it]

Epoch 107/200 - Train Loss: 119.3533 - Val Loss: 133.1626



 54%|█████▍    | 108/200 [04:14<03:36,  2.35s/it]

Epoch 108/200 - Train Loss: 118.6805 - Val Loss: 133.0029



 55%|█████▍    | 109/200 [04:16<03:33,  2.35s/it]

Epoch 109/200 - Train Loss: 118.0076 - Val Loss: 132.9485



 55%|█████▌    | 110/200 [04:19<03:31,  2.35s/it]

Epoch 110/200 - Train Loss: 117.2683 - Val Loss: 132.7106



 56%|█████▌    | 111/200 [04:21<03:29,  2.35s/it]

Epoch 111/200 - Train Loss: 116.5914 - Val Loss: 132.6061



 56%|█████▌    | 112/200 [04:23<03:26,  2.35s/it]

Epoch 112/200 - Train Loss: 115.8563 - Val Loss: 132.7150



 56%|█████▋    | 113/200 [04:26<03:24,  2.35s/it]

Epoch 113/200 - Train Loss: 115.2611 - Val Loss: 132.4651



 57%|█████▋    | 114/200 [04:28<03:22,  2.35s/it]

Epoch 114/200 - Train Loss: 114.5447 - Val Loss: 132.5524



 57%|█████▊    | 115/200 [04:30<03:19,  2.35s/it]

Epoch 115/200 - Train Loss: 113.9742 - Val Loss: 132.3824



 58%|█████▊    | 116/200 [04:33<03:17,  2.35s/it]

Epoch 116/200 - Train Loss: 113.3023 - Val Loss: 132.3715



 58%|█████▊    | 117/200 [04:35<03:15,  2.35s/it]

Epoch 117/200 - Train Loss: 112.6552 - Val Loss: 132.1787



 59%|█████▉    | 118/200 [04:38<03:12,  2.35s/it]

Epoch 118/200 - Train Loss: 112.0002 - Val Loss: 132.1052



 60%|█████▉    | 119/200 [04:40<03:10,  2.35s/it]

Epoch 119/200 - Train Loss: 111.3586 - Val Loss: 132.1104



 60%|██████    | 120/200 [04:42<03:07,  2.35s/it]

Epoch 120/200 - Train Loss: 110.6942 - Val Loss: 132.0441



 60%|██████    | 121/200 [04:45<03:05,  2.35s/it]

Epoch 121/200 - Train Loss: 110.0537 - Val Loss: 131.9229



 61%|██████    | 122/200 [04:47<03:03,  2.35s/it]

Epoch 122/200 - Train Loss: 109.4395 - Val Loss: 131.9663



 62%|██████▏   | 123/200 [04:49<03:00,  2.35s/it]

Epoch 123/200 - Train Loss: 108.8021 - Val Loss: 131.9356



 62%|██████▏   | 124/200 [04:52<02:58,  2.35s/it]

Epoch 124/200 - Train Loss: 108.1789 - Val Loss: 131.8698



 62%|██████▎   | 125/200 [04:54<02:56,  2.35s/it]

Epoch 125/200 - Train Loss: 107.5565 - Val Loss: 131.7571



 63%|██████▎   | 126/200 [04:56<02:53,  2.35s/it]

Epoch 126/200 - Train Loss: 106.9358 - Val Loss: 131.8237



 64%|██████▎   | 127/200 [04:59<02:51,  2.35s/it]

Epoch 127/200 - Train Loss: 106.3260 - Val Loss: 131.7763



 64%|██████▍   | 128/200 [05:01<02:48,  2.35s/it]

Epoch 128/200 - Train Loss: 105.7510 - Val Loss: 131.7555



 64%|██████▍   | 129/200 [05:03<02:46,  2.35s/it]

Epoch 129/200 - Train Loss: 105.1785 - Val Loss: 131.9419



 65%|██████▌   | 130/200 [05:06<02:44,  2.35s/it]

Epoch 130/200 - Train Loss: 104.7632 - Val Loss: 132.2675



 66%|██████▌   | 131/200 [05:08<02:42,  2.35s/it]

Epoch 131/200 - Train Loss: 104.4730 - Val Loss: 132.6796



 66%|██████▌   | 132/200 [05:10<02:39,  2.35s/it]

Epoch 132/200 - Train Loss: 104.6870 - Val Loss: 133.6407



 66%|██████▋   | 133/200 [05:13<02:37,  2.35s/it]

Epoch 133/200 - Train Loss: 104.5320 - Val Loss: 132.3213



 67%|██████▋   | 134/200 [05:15<02:35,  2.35s/it]

Epoch 134/200 - Train Loss: 103.2536 - Val Loss: 132.0357



 68%|██████▊   | 135/200 [05:17<02:32,  2.35s/it]

Epoch 135/200 - Train Loss: 102.0469 - Val Loss: 132.8027



 68%|██████▊   | 136/200 [05:20<02:30,  2.35s/it]

Epoch 136/200 - Train Loss: 102.1577 - Val Loss: 132.1733



 68%|██████▊   | 137/200 [05:22<02:28,  2.35s/it]

Epoch 137/200 - Train Loss: 101.3422 - Val Loss: 132.0868



 69%|██████▉   | 138/200 [05:24<02:25,  2.35s/it]

Epoch 138/200 - Train Loss: 100.3263 - Val Loss: 132.6235



 70%|██████▉   | 139/200 [05:27<02:23,  2.35s/it]

Epoch 139/200 - Train Loss: 100.3799 - Val Loss: 131.7464



 70%|███████   | 140/200 [05:29<02:21,  2.35s/it]

Epoch 140/200 - Train Loss: 99.2164 - Val Loss: 132.0685



 70%|███████   | 141/200 [05:32<02:18,  2.35s/it]

Epoch 141/200 - Train Loss: 98.9320 - Val Loss: 132.2402



 71%|███████   | 142/200 [05:34<02:16,  2.35s/it]

Epoch 142/200 - Train Loss: 98.3577 - Val Loss: 131.8363



 72%|███████▏  | 143/200 [05:36<02:14,  2.35s/it]

Epoch 143/200 - Train Loss: 97.5643 - Val Loss: 132.0211



 72%|███████▏  | 144/200 [05:39<02:11,  2.35s/it]

Epoch 144/200 - Train Loss: 97.3628 - Val Loss: 131.9340



 72%|███████▎  | 145/200 [05:41<02:09,  2.35s/it]

Epoch 145/200 - Train Loss: 96.4120 - Val Loss: 132.2835



 73%|███████▎  | 146/200 [05:43<02:06,  2.35s/it]

Epoch 146/200 - Train Loss: 96.2334 - Val Loss: 131.8551



 74%|███████▎  | 147/200 [05:46<02:04,  2.35s/it]

Epoch 147/200 - Train Loss: 95.4905 - Val Loss: 132.0408



 74%|███████▍  | 148/200 [05:48<02:02,  2.35s/it]

Epoch 148/200 - Train Loss: 94.9997 - Val Loss: 132.2108



 74%|███████▍  | 149/200 [05:50<01:59,  2.35s/it]

Epoch 149/200 - Train Loss: 94.5274 - Val Loss: 131.8613



 75%|███████▌  | 150/200 [05:53<01:57,  2.35s/it]

Epoch 150/200 - Train Loss: 93.9059 - Val Loss: 132.0007



 75%|███████▌  | 150/200 [05:53<01:57,  2.36s/it]


Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/rds/general/user/ljn19/home/anaconda3/envs/fyp/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3433, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/tmp/pbs.7686689.pbs/ipykernel_2461273/1536454020.py", line 62, in <module>
    for batch_idx, (data, _) in enumerate(train_loader):
  File "/rds/general/user/ljn19/home/anaconda3/envs/fyp/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 634, in __next__
    data = self._next_data()
  File "/rds/general/user/ljn19/home/anaconda3/envs/fyp/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 678, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/rds/general/user/ljn19/home/anaconda3/envs/fyp/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/rds/general/user/ljn19/home

In [None]:
import json
import matplotlib.pyplot as plt

# Validation-loss curves from an earlier GMM sweep, one entry per configuration
# (the legend labels treat each key as a number of GMM components).
with open("loss_curves_gmm.json", "r") as f:
    loss_curves = json.load(f)

fig, ax = plt.subplots(figsize=(8, 6))
for n_comp, loss_dict in loss_curves.items():
    val_losses = loss_dict["val_loss"]
    epochs = range(1, len(val_losses) + 1)
    ax.plot(epochs, val_losses, label=f"{n_comp}")

ax.set_xlabel("Epoch")
ax.set_ylabel("Validation Loss")
ax.set_title("Validation Loss Curves for Different Numbers of GMM Components")
ax.legend()
ax.set_ylim((30, 60))  # zoom into the converged region of the curves

plt.show()


In [None]:
import json
import matplotlib.pyplot as plt

# Load the loss curves from the GMM sweep (keys label each configuration,
# presumably its number of GMM components).
with open("loss_curves_gmm.json", "r") as f:
    loss_curves = json.load(f)

# Plot the *training* loss curve for each number of GMM components.
fig, ax = plt.subplots(figsize=(8, 6))
for n_comp, loss_dict in loss_curves.items():
    # Use a dedicated name: the original stored these training losses in `val_losses`,
    # which was misleading and clobbered the global of that name.
    train_curve = loss_dict["train_loss"]
    epochs = range(1, len(train_curve) + 1)
    ax.plot(epochs, train_curve, label=f"{n_comp}")

ax.set_xlabel("Epoch")
ax.set_ylabel("Training Loss")
ax.set_title("Training Loss Curves for Different Numbers of GMM Components")
ax.legend()
ax.set_ylim((20, 40))  # zoom into the converged region of the curves

plt.show()


In [None]:
# Re-load the best checkpoint for each GMM component count (2..10) and score it on val_loader.
# NOTE(review): this uses `VAE(input_dim, hidden_dim, latent_dim)` and `model.loss_function`,
# whereas the training cell uses `VAE(input_dim, [hidden_dim] * 2, latent_dim)` and
# `model.loss(...)`; confirm these checkpoints and APIs still match models.VAE.
models = []
val_losses = []
for n_comp in range(2, 11):
    model = VAE(input_dim, hidden_dim, latent_dim)
    state = torch.load(f'vgae_weights/gmm{n_comp}_best.pt', map_location=torch.device('cpu'))
    model.load_state_dict(state)
    model.eval()

    running = 0.0
    with torch.no_grad():
        for data, _ in val_loader:
            flat = data.view(-1, input_dim)
            recon, mu, logvar, _ = model(flat)
            running += model.loss_function(recon, flat, mu, logvar, n_components=n_comp).item()
    val_losses.append(running / len(val_dataset))
    models.append(model)

# Report each model's loss, labelled by its component count.
for n_comp, val_loss in zip(range(2, 11), val_losses):
    print(f'Model GMM {n_comp}: Validation Loss = {val_loss:.4f}')

In [None]:
import matplotlib.pyplot as plt

# Component counts scored in the checkpoint-loading cell (must match its range(2, 11)).
n_components_list = [n for n in range(2, 11)]

fig, ax = plt.subplots()
ax.plot(n_components_list, val_losses)
ax.set_xlabel('Number of Components')
ax.set_ylabel('Validation Loss')
ax.set_title('Validation Loss vs. Number of GMM Components')
fig.savefig('gmm_component_testing.jpg')
plt.show()


In [None]:
# Load the 8-component GMM checkpoint used by the reconstruction plots below.
# NOTE(review): the training cell builds VAE(input_dim, [hidden_dim] * 2, latent_dim);
# confirm this older 3-int signature still matches models.VAE for these weights.
model = VAE(input_dim, hidden_dim, latent_dim)
# The model is only used for inference below, so put it in eval mode.
model.eval()
# The model stays on the CPU (the plotting cells feed it CPU tensors), so load straight to CPU.
model.load_state_dict(torch.load('vgae_weights/gmm8_best.pt', map_location=torch.device('cpu')))

In [None]:
import matplotlib.pyplot as plt
from nilearn import plotting

N_NODES = 100  # nodes per FC matrix (the "ica_100" parcellation in the data paths)


def _as_fc_matrices(batch: torch.Tensor, n_nodes: int = N_NODES) -> torch.Tensor:
    """Reshape a batch of flattened connectivity data to (batch, n_nodes, n_nodes).

    Accepts either full flattened matrices (n_nodes * n_nodes values per sample) or
    strict upper-triangle vectors (n_nodes * (n_nodes - 1) / 2 values, e.g. the
    input_dim = 4950 used with setting='upper_triangular'). Triangle vectors are
    mirrored into a symmetric matrix with a zero diagonal.
    NOTE(review): assumes the dataset flattens the triangle in row-major
    np.triu_indices(n_nodes, k=1) order — TODO confirm against utils.BrainGraphDataset.

    Raises:
        ValueError: if the per-sample length matches neither layout.
    """
    flat = batch.reshape(batch.shape[0], -1)
    n_values = flat.shape[1]
    if n_values == n_nodes * n_nodes:
        return flat.reshape(-1, n_nodes, n_nodes)
    if n_values == n_nodes * (n_nodes - 1) // 2:
        rows, cols = torch.triu_indices(n_nodes, n_nodes, offset=1)
        square = flat.new_zeros(flat.shape[0], n_nodes, n_nodes)
        square[:, rows, cols] = flat
        square[:, cols, rows] = flat
        return square
    raise ValueError(f'cannot reshape {n_values} values per sample into a {n_nodes}x{n_nodes} matrix')


# select a batch from the validation data loader
data, _ = next(iter(val_loader))

# reconstruct it with the trained model (inference only: eval mode, no autograd graph)
model.eval()
with torch.no_grad():
    recon, _, _, _ = model(data.view(-1, input_dim))

# reshape BOTH the input and the reconstruction to square matrices for plotting
originals = _as_fc_matrices(data)
recon = _as_fc_matrices(recon)

# plot original vs reconstruction for the first few samples in the batch
for i in range(min(3, originals.shape[0])):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
    plotting.plot_matrix(originals[i].numpy(), colorbar=True, vmax=0.8, vmin=-0.8, axes=ax1)
    ax1.set_title('Original')
    plotting.plot_matrix(recon[i].numpy(), colorbar=True, vmax=0.8, vmin=-0.8, axes=ax2)
    ax2.set_title('Reconstructed')
    plt.show()

In [None]:
# Psilocybin "before" FC matrices, used below as an extra evaluation set.
# The original cell reused whatever `annotations` was left in the kernel
# ('annotations-after.csv' from the data-loading cell) with the *before* image directory,
# so the matching annotations file is pinned explicitly here.
annotations = 'annotations-before.csv'
dataroot = 'fc_matrices/psilo_ica_100_before'

# Resolve paths against the project root, like the data-loading cell, rather than os.getcwd().
psilo_dataset = BrainGraphDataset(img_dir=os.path.join(root, dataroot),
                                  annotations_file=os.path.join(root, annotations),
                                  transform=None, extra_data=None, setting='no_label')

psilo_train_loader = DataLoader(psilo_dataset, batch_size=batch_size)

# select the first batch of the psilocybin loader
data, _ = next(iter(psilo_train_loader))

# reconstruct it with the trained model (inference only: eval mode, no autograd graph)
model.eval()
with torch.no_grad():
    recon, _, _, _ = model(data.view(-1, input_dim))

# Reshape BOTH the input and the reconstruction to 100x100 for plotting.
# NOTE(review): assumes setting='no_label' yields full flattened 100x100 matrices — TODO confirm.
n_nodes = 100
originals = data.reshape(-1, n_nodes, n_nodes)
recon = recon.reshape(-1, n_nodes, n_nodes)

# plot original vs reconstruction for the first few samples in the batch
for i in range(min(3, originals.shape[0])):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
    plotting.plot_matrix(originals[i].numpy(), colorbar=True, vmax=0.8, vmin=-0.8, axes=ax1)
    ax1.set_title('Original')
    plotting.plot_matrix(recon[i].numpy(), colorbar=True, vmax=0.8, vmin=-0.8, axes=ax2)
    ax2.set_title('Reconstructed')
    plt.show()

In [None]:
# Score every saved GMM checkpoint (2..10 components) on the psilocybin "before" data.
# NOTE(review): older VAE(input_dim, hidden_dim, latent_dim) / model.loss_function API —
# confirm it still matches models.VAE (the training cell uses a list of hidden dims and model.loss).
model = VAE(input_dim, hidden_dim, latent_dim)
model.eval()  # eval mode persists across the load_state_dict calls below

val_losses = []
with torch.no_grad():
    for n_comp in range(2, 11):
        checkpoint = torch.load(f'vgae_weights/gmm{n_comp}_best.pt', map_location=device)
        model.load_state_dict(checkpoint)

        total = 0.0
        for data, _ in psilo_train_loader:
            recon, mu, logvar, _ = model(data.view(-1, input_dim))
            total += model.loss_function(recon, data.view(-1, input_dim), mu, logvar, n_components=n_comp).item()

        val_loss = total / len(psilo_dataset)  # per-sample loss
        val_losses.append(val_loss)
        print(f'gmm_{n_comp}: {val_loss} loss')

In [None]:
# Component counts scored in the previous cell (must match its range(2, 11)).
n_components_list = list(range(2, 11))

# Plot the per-sample loss on the psilocybin data for each n_components value.
fig, ax = plt.subplots()
ax.plot(n_components_list, val_losses)
ax.set_xlabel('Number of Components')
ax.set_ylabel('Validation Loss')
ax.set_title('Validation Loss (psilocybin data) vs. Number of GMM Components')
# Distinct filename: this previously reused 'gmm_component_testing.jpg' and silently
# overwrote the val_loader figure saved by the earlier plotting cell.
fig.savefig('gmm_component_testing_psilo.jpg')
plt.show()

In [None]:
# ---- Hyperparameters for the HCP experiments that follow ----
# NOTE: this redefines `input_dim` (4950 in the training cell); the cells below
# assume full flattened 100x100 matrices.
input_dim = 100 * 100  # number of entries in a flattened 100x100 FC matrix
lr = 1e-3
batch_size = 128
num_epochs = 300

annotations = 'annotations.csv'
dataroot = 'fc_matrices/hcp_100_ica/'

# Resolve paths against the project root (as the data-loading cell does) rather than
# os.getcwd(), which depends on where the kernel was started.
# NOTE(review): confirm the HCP data lives under project_root().
dataset = BrainGraphDataset(img_dir=os.path.join(root, dataroot),
                            annotations_file=os.path.join(root, dataroot, annotations),
                            transform=None, extra_data=None, setting='no_label')

# NOTE(review): this rebinds `train_loader` (previously the psilocybin training loader)
# to the HCP data; the t-SNE cell below relies on it being the HCP loader.
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# State used by the (commented-out) architecture sweep below.
loss_curves = {}
best_val_losses = {}  # best validation loss for each (hidden_dim, latent_dim) configuration
best_n = 3  # number of GMM components used by that sweep

# for hidden_dim in [256, 512]:
#     for latent_dim in [64, 128]:
#         train_losses = []
#         val_losses = []
#         model = VAE(input_dim, hidden_dim, latent_dim).to(device)  # move model to device
#         optimizer = optim.Adam(model.parameters(), lr=lr)
#         best_val_loss = float('inf')  # initialize the best validation loss to infinity
        
#         with open('gmm_train_overfit.txt', 'a') as f:
#             f.write(f'Hidden dim: {hidden_dim}, latent_dim: {latent_dim}\n')
        
#         for epoch in range(num_epochs):
#             train_loss = 0.0
#             val_loss = 0.0

#             # training
#             model.train()
#             # define the optimizer and the loss function

#             for batch_idx, (data, _) in tqdm(enumerate(train_loader), total=len(train_loader)):
#                 data = data.to(device)  # move input data to device
#                 optimizer.zero_grad()

#                 recon, mu, logvar, z = model(data.view(-1, input_dim))
#                 loss = model.loss_function(recon, data.view(-1, input_dim), mu, logvar, n_components=best_n)
#                 loss.backward()
#                 optimizer.step()
#                 train_loss += loss.item()

#             # validation
#             model.eval()
#             with torch.no_grad():
#                 for batch_idx, (data, _) in tqdm(enumerate(psilo_train_loader), total=len(psilo_train_loader)):
#                     data = data.to(device)  # move input data to device
#                     recon, mu, logvar, z = model(data.view(-1, input_dim))
#                     loss = loss_function_gmm(recon, data.view(-1, input_dim), mu, logvar, n_components=best_n)
#                     val_loss += loss.item()

#             # append losses to lists
#             train_losses.append(train_loss/len(train_dataset))
#             val_losses.append(val_loss/len(psilo_dataset))

#             with open('gmm_train_overfit.txt', 'a') as f:
#                 f.write(f'Epoch {epoch+1}/{num_epochs} - Train Loss: {train_losses[-1]:.4f} - Val Loss: {val_losses[-1]:.4f}\n')
                
#             # update the best validation loss and save the model weights if it's the best so far for this configuration
#             if val_losses[-1] < best_val_loss:
#                 best_val_loss = val_losses[-1]
#                 best_val_losses[(hidden_dim, latent_dim)] = best_val_loss
#                 torch.save(model.state_dict(), f'vgae_weights/gmm_{best_n}_hidden{hidden_dim}_latent{latent_dim}.pt')

#         # plot the losses
#         plt.plot(val_losses, label=f'Validation Loss (hidden_dim={hidden_dim}, latent_dim={latent_dim})')
        
#                 # add the loss curves to the dictionary
#         loss_curves[f"hidden{hidden_dim}_latent_dim{latent_dim}"] = {"train_loss": train_losses, "val_loss": val_losses}

# # save the loss curves to a file
# with open("loss_curves_overfit_new.json", "w") as f:
#     json.dump(loss_curves, f)

# plt.xlabel('Epoch')
# plt.ylabel('Loss')
# plt.legend()
# plt.show()


In [None]:
import json
import matplotlib.pyplot as plt

# Load the loss curves from the architecture sweep. Keys name a network configuration
# (the commented-out sweep writes keys like "hidden256_latent_dim64" — TODO confirm for this file).
with open("loss_curves_overfit.json", "r") as f:
    loss_curves = json.load(f)

# Plot one validation-loss curve per network architecture.
fig, ax = plt.subplots(figsize=(10, 8))
for config_name, loss_dict in loss_curves.items():
    # Dedicated name so this cell does not clobber the global `val_losses`.
    config_val_losses = loss_dict["val_loss"]
    epochs = range(1, len(config_val_losses) + 1)
    ax.plot(epochs, config_val_losses, label=f"{config_name}")

ax.set_xlabel("Epoch")
ax.set_ylabel("Val Loss")
ax.set_title("Validation Loss Curves for Different Net Architectures")
ax.legend()
ax.set_ylim((350, 500))  # zoom into the converged region of the curves

plt.show()


In [None]:
# ---- Evaluate the saved checkpoints of each (hidden_dim, latent_dim) architecture ----
# on the psilocybin "before" data. Evaluation only: no optimizer is needed (the original
# built an unused Adam optimizer on whatever `model` happened to be in the kernel).
input_dim = 100 * 100  # number of entries in a flattened 100x100 FC matrix
hidden_dims = [256, 128, 64]
latent_dims = [64, 32, 16]
n_gmm_components = 5  # the checkpoints below were saved as gmm_5_* (per their filenames)

# Per-sample loss for every configuration, keyed by (hidden_dim, latent_dim).
# The original reset its lists inside the loop, so results were lost between configs.
arch_val_losses = {}

for hidden_dim in hidden_dims:
    for latent_dim in latent_dims:
        model = VAE(input_dim, hidden_dim, latent_dim)

        # load in the model weights
        model.load_state_dict(torch.load(
            f'vgae_weights/gmm_{n_gmm_components}_hidden{hidden_dim}_latent{latent_dim}.pt',
            map_location=device))
        model.eval()

        val_loss = 0.0
        with torch.no_grad():
            for data, _ in tqdm(psilo_train_loader):
                recon, mu, logvar, _ = model(data.view(-1, input_dim))
                loss = model.loss_function(recon, data.view(-1, input_dim), mu, logvar,
                                           n_components=n_gmm_components)
                val_loss += loss.item()
        arch_val_losses[(hidden_dim, latent_dim)] = val_loss / len(psilo_dataset)

        # print the validation loss for this configuration
        print(f'Hidden Dim: {hidden_dim}, Latent Dim: {latent_dim}, '
              f'Validation Loss: {arch_val_losses[(hidden_dim, latent_dim)]:.4f}')


In [None]:
from sklearn.manifold import TSNE

hidden_dim = 256
latent_dim = 64
input_dim = 100 * 100

model = VAE(input_dim, hidden_dim, latent_dim)
model.load_state_dict(torch.load('vgae_weights/gmm3_best.pt', map_location=device))
model.eval()  # inference only

# Encode both datasets.
# NOTE(review): if `z` is a reparameterised sample rather than the mean, consider using
# `mu` for a deterministic embedding.
psilo_zs = []
hcp_zs = []
with torch.no_grad():
    for data, _ in psilo_train_loader:
        _, _, _, z = model(data.view(-1, input_dim))
        psilo_zs.append(z)

    # `train_loader` is the HCP loader defined in the HCP hyperparameter cell.
    for data, _ in train_loader:
        _, _, _, z = model(data.view(-1, input_dim))
        hcp_zs.append(z)

psilo_zs = torch.cat(psilo_zs, dim=0)
hcp_zs = torch.cat(hcp_zs, dim=0)

# Concatenate the encoded representations and create labels (0 = psilocybin, 1 = HCP).
x = torch.cat((psilo_zs, hcp_zs), dim=0)
labels = torch.cat((torch.zeros(psilo_zs.shape[0]), torch.ones(hcp_zs.shape[0])), dim=0)

# sklearn works on NumPy arrays; convert explicitly (and off the GPU, if needed).
x_np = x.cpu().numpy()
labels_np = labels.numpy()

for per in [30, 40, 50]:
    # Omit n_iter: 1000 iterations is the default, and `n_iter` is deprecated in favour
    # of `max_iter` in recent scikit-learn. random_state makes the embedding reproducible.
    tsne = TSNE(n_components=2, perplexity=per, random_state=0)
    x_tsne = tsne.fit_transform(x_np)

    # Binary labels read better as a legend than as a colorbar.
    fig, ax = plt.subplots()
    for label_value, group_name, color in ((0.0, 'psilocybin', 'tab:blue'), (1.0, 'HCP', 'tab:red')):
        mask = labels_np == label_value
        ax.scatter(x_tsne[mask, 0], x_tsne[mask, 1], c=color, label=group_name, s=10)
    ax.set_title(f't-SNE of VAE latent codes (perplexity={per})')
    ax.legend()
    plt.show()