In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.transforms import transforms
from torch.utils.data import Dataset
from typing import List, Tuple
import numpy as np
import copy
import os
from sklearn.mixture import GaussianMixture
from tqdm import tqdm
import matplotlib.pyplot as plt
import json
import sys
import csv
from torch.utils.data import ConcatDataset

sys.path.append('..')
from utils import BrainGraphDataset, project_root
from models import VAE

In [4]:
# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# Hyperparameters.
# Model size: width of each hidden layer and of the latent code.
hidden_dim, latent_dim = 128, 64
# Optimisation settings: mini-batch size, epoch count and learning rate.
batch_size = 128
num_epochs = 200
lr = 1e-3

root = project_root()
annotations = 'annotations.csv'


def _open_dataset(relative_dir, annotations_path):
    """Build a BrainGraphDataset over ``root/relative_dir``.

    Every dataset in this notebook is opened with the same options
    (no transform, no extra data, 'upper_triangular' setting); only the
    image directory and the annotations file differ.
    """
    return BrainGraphDataset(img_dir=os.path.join(root, relative_dir),
                             annotations_file=annotations_path,
                             transform=None, extra_data=None, setting='upper_triangular')


# HCP: the annotations file is looked up inside the dataset's own directory.
dataroot = 'fc_matrices/hcp_100_ica/'
hcp_dataset = _open_dataset(dataroot, os.path.join(root, dataroot, annotations))

# Psilocybin datasets: one directory per parcellation and session, all
# sharing the annotations file located directly under the project root.
_psilo_dirs = {
    'ica_before': 'fc_matrices/psilo_ica_100_before/',
    'ica_after': 'fc_matrices/psilo_ica_100_after/',
    'schaefer_before': 'fc_matrices/psilo_schaefer_before/',
    'schaefer_after': 'fc_matrices/psilo_schaefer_after/',
    'aal_before': 'fc_matrices/psilo_aal_before/',
    'aal_after': 'fc_matrices/psilo_aal_after/',
}
_psilo_annotations = os.path.join(root, annotations)
_psilo = {}
for _key, dataroot in _psilo_dirs.items():
    _psilo[_key] = _open_dataset(dataroot, _psilo_annotations)

psilo_ica_before_dataset = _psilo['ica_before']
psilo_ica_after_dataset = _psilo['ica_after']
psilo_schaefer_before_dataset = _psilo['schaefer_before']
psilo_schaefer_after_dataset = _psilo['schaefer_after']
psilo_aal_before_dataset = _psilo['aal_before']
psilo_aal_after_dataset = _psilo['aal_after']

# "combined" = before-session samples followed by after-session samples.
psilo_ica_combined_dataset = ConcatDataset(
    [psilo_ica_before_dataset, psilo_ica_after_dataset])
psilo_schaefer_combined_dataset = ConcatDataset(
    [psilo_schaefer_before_dataset, psilo_schaefer_after_dataset])
psilo_aal_combined_dataset = ConcatDataset(
    [psilo_aal_before_dataset, psilo_aal_after_dataset])

# (dataset, name) pairs to evaluate; the name is also used to locate the
# matching checkpoint file and to pick the input dimension.
configs = [
    (hcp_dataset, 'hcp'),
    (psilo_ica_before_dataset, 'psilo_ica_before'),
    (psilo_ica_combined_dataset, 'psilo_ica_combined'),
    (psilo_schaefer_before_dataset, 'psilo_schaefer_before'),
    (psilo_schaefer_combined_dataset, 'psilo_schaefer_combined'),
    (psilo_aal_before_dataset, 'psilo_aal_before'),
    (psilo_aal_combined_dataset, 'psilo_aal_combined'),
]
# Dropout rates to evaluate. The values are interpolated verbatim into the
# checkpoint file names, so the first entry must stay the int 0 (not 0.0).
dropout_list = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]


def ensure_csv_header(path, header):
    """Write *header* as the first CSV row of *path* if the file is new.

    The file is only touched when it does not exist yet or is empty, so
    re-running the notebook no longer appends a duplicate header row in
    the middle of previously written results.

    Returns True when the header was written, False when the file already
    had content and was left alone.
    """
    if os.path.exists(path) and os.path.getsize(path) > 0:
        return False
    with open(path, 'a', newline='') as csvfile:
        csv.writer(csvfile).writerow(header)
    return True


# Header for the four-column results file.
ensure_csv_header('vae_full_config.csv', ['Config', 'Dropout', 'Val Loss', 'Test Loss'])
# Header for the file the evaluation loop appends its
# (config, dropout, val loss) rows to.
ensure_csv_header('vae_full_config_val.csv', ['Config', 'Dropout', 'Val Loss'])


# Evaluate one saved VAE checkpoint per (dataset config, dropout rate) pair:
# rebuild the seeded 80/10/10 split, load the weights, accumulate the MSE
# term of the model loss over the held-out samples, then print the result
# and append it to 'vae_full_config_val.csv'.
for dataset, config_name in configs:
    # size of the flattened graph input: 6670 for the AAL configs, 4950 otherwise
    input_dim = 6670 if 'aal' in config_name else 4950

    for dropout in dropout_list:
        # set the random seed for reproducibility (makes random_split deterministic)
        torch.manual_seed(0)

        # split the dataset into training, validation and test sets
        num_samples = len(dataset)
        train_size = int(0.8 * num_samples)
        val_size = int(0.1 * num_samples)
        test_size = num_samples - train_size - val_size
        train_dataset, val_dataset, test_dataset = random_split(
            dataset, [train_size, val_size, test_size])

        # define the data loaders; each held-out split is served as one batch.
        # max(1, ...) is needed because DataLoader rejects batch_size=0, which
        # is what an empty split (e.g. val_size == 0 for < 10 samples) would give.
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=max(1, len(val_dataset)))
        test_loader = DataLoader(test_dataset, batch_size=max(1, len(test_dataset)))

        model = VAE(input_dim, [hidden_dim] * 2, latent_dim, dropout=dropout).to(device)
        weights_path = os.path.join(root, f'vae_weights/vae_dropout_{config_name}_{dropout}.pt')
        model.load_state_dict(torch.load(weights_path, map_location=device))

        # NOTE(review): the test split is pooled with the validation split and
        # the result is reported as "Val Loss" -- confirm this is intended.
        val_loss = 0.
        model.eval()
        with torch.no_grad():
            for loader in (val_loader, test_loader):
                for data, _ in loader:
                    flat = data.to(device).view(-1, input_dim)  # move data to device
                    recon, mu, logvar, _z = model(flat)
                    # only the first returned term (the MSE) is accumulated
                    mse_loss, _gmm_loss, _l2_reg = model.loss(
                        recon, flat, mu, logvar, n_components=3)
                    val_loss += mse_loss.item()

        # Normalise by the number of held-out samples.
        # assumes model.loss returns a sum-reduced MSE -- TODO confirm
        held_out = len(val_dataset) + len(test_dataset)
        mean_loss = val_loss / held_out
        print(f'{config_name} {dropout} Val Loss = {mean_loss}')

        # Append per iteration so partial results survive an interrupted run.
        with open('vae_full_config_val.csv', 'a', newline='') as csvfile:
            csv.writer(csvfile).writerow([config_name, dropout, mean_loss])

cpu
hcp 0 Val Loss = 28.842549945584576
hcp 0.05 Val Loss = 30.77044829563122
hcp 0.1 Val Loss = 32.49013841923197
hcp 0.15 Val Loss = 33.12696162740983
hcp 0.2 Val Loss = 34.09783844449627
hcp 0.25 Val Loss = 35.27780409476057
hcp 0.3 Val Loss = 36.14284048507463
hcp 0.35 Val Loss = 36.86121394978234
hcp 0.4 Val Loss = 37.0195458255597
hcp 0.45 Val Loss = 37.531225707400495
hcp 0.5 Val Loss = 37.82115909709266
psilo_ica_before 0 Val Loss = 145.5535685221354
psilo_ica_before 0.05 Val Loss = 150.42519802517361
psilo_ica_before 0.1 Val Loss = 149.37181939019098
psilo_ica_before 0.15 Val Loss = 147.18073187934027
psilo_ica_before 0.2 Val Loss = 148.12685818142361
psilo_ica_before 0.25 Val Loss = 150.64151340060764
psilo_ica_before 0.3 Val Loss = 151.74075656467014
psilo_ica_before 0.35 Val Loss = 152.97360568576389
psilo_ica_before 0.4 Val Loss = 153.4437730577257
psilo_ica_before 0.45 Val Loss = 153.74240451388889
psilo_ica_before 0.5 Val Loss = 158.21116807725696
psilo_ica_combined 0 Va