In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.transforms import transforms
from torch.utils.data import Dataset
from typing import List, Tuple
import numpy as np
import os
from sklearn.mixture import GaussianMixture
from tqdm import tqdm
import matplotlib.pyplot as plt
import json
from nilearn import plotting
import csv
import sys
import random
sys.path.append('..')
from utils import BrainGraphDataset, project_root, make_edge_index
from models import VAE

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(device)

# set the random seed for reproducibility
torch.manual_seed(0)

# define the hyperparameters
input_dim = 4950 # size of the graph adjacency matrix
hidden_dim = 128
latent_dim = 64
lr = 1e-3
batch_size = 128
num_epochs = 200

annotations = 'annotations.csv'

dataroot = 'fc_matrices/hcp_100_ica/'
root = project_root()

dataset = BrainGraphDataset(img_dir=os.path.join(root, dataroot),
                            annotations_file=os.path.join(root, dataroot, annotations),
                            transform=None, extra_data=None, setting='upper_triangular')

# split the dataset into training and validation sets
num_samples = len(dataset)
train_size = int(0.8 * num_samples)
val_size = int(0.1 * num_samples)
test_size =  num_samples - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# define the data loaders
val_loader = DataLoader(val_dataset, batch_size=len(val_dataset), shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False)

for config in ['combined', 'before']:
    for dropout in [0, 0.05 ,0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]:

        model = VAE(input_dim, [hidden_dim] * 2, latent_dim).to(device)
        model.load_state_dict(torch.load(os.path.join(root, f'vae_weights/vae_fine_tune_{config}_dropout_{dropout}.pt'), map_location=device))
        # validation
        model.eval()
        val_loss = 0.
        with torch.no_grad():
            for batch_idx, (data, _) in enumerate(val_loader):
                data = data.to(device)  # move data to device
                recon, mu, logvar, z = model(data.view(-1, input_dim))
                mse_loss, gmm_loss, l2_reg = model.loss(recon, data.view(-1, input_dim), mu, logvar, n_components=3)
                val_loss += mse_loss.item()
            for batch_idx, (data, _) in enumerate(test_loader):
                data = data.to(device)  # move data to device
                recon, mu, logvar, z = model(data.view(-1, input_dim))
                mse_loss, gmm_loss, l2_reg = model.loss(recon, data.view(-1, input_dim), mu, logvar, n_components=3)
                val_loss += mse_loss.item()
        val_loss /= (len(val_dataset)  + len(val_dataset))
        print(f'Dropout {dropout} {config} - Val Loss: {val_loss:.4f}\n')
        
        
        with open('vae_full_config.csv', 'a', newline='') as csvfile:
            writer = csv.writer(csvfile)
            # Write the data row
            writer.writerow([f'fine_tune_{config}', dropout, val_loss])

cpu
Dropout 0 combined - Val Loss: 103.4081

Dropout 0.05 combined - Val Loss: 127.9083

Dropout 0.1 combined - Val Loss: 126.1251

Dropout 0.15 combined - Val Loss: 119.1079

Dropout 0.2 combined - Val Loss: 111.8325

Dropout 0.25 combined - Val Loss: 112.8385

Dropout 0.3 combined - Val Loss: 111.2883

Dropout 0.35 combined - Val Loss: 107.7048

Dropout 0.4 combined - Val Loss: 115.0061

Dropout 0.45 combined - Val Loss: 111.3288

Dropout 0.5 combined - Val Loss: 114.6170

Dropout 0 before - Val Loss: 100.5748

Dropout 0.05 before - Val Loss: 125.7059

Dropout 0.1 before - Val Loss: 124.6112

Dropout 0.15 before - Val Loss: 120.5714

Dropout 0.2 before - Val Loss: 119.0313

Dropout 0.25 before - Val Loss: 122.3387

Dropout 0.3 before - Val Loss: 110.0571

Dropout 0.35 before - Val Loss: 112.5300

Dropout 0.4 before - Val Loss: 113.4967

Dropout 0.45 before - Val Loss: 112.1724

Dropout 0.5 before - Val Loss: 118.1505

