In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.transforms import transforms
from torch.utils.data import Dataset
from typing import List, Tuple
import numpy as np
import os
from sklearn.mixture import GaussianMixture
from tqdm import tqdm
import matplotlib.pyplot as plt
import json
from nilearn import plotting

import sys
import random
sys.path.append('..')
from utils import BrainGraphDataset, project_root, make_edge_index
from models import VAE

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(device)

# Seed the global torch RNG; random_split below draws from it, so the
# train/val/test split is identical on every run.
torch.manual_seed(0)

# define the hyperparameters
# Only input_dim, hidden_dim and latent_dim are used by this evaluation;
# lr, batch_size and num_epochs are kept to mirror the training configuration.
input_dim = 4950  # presumably the strict upper triangle of a 100x100 FC matrix (100*99/2) - confirm
hidden_dim = 128
latent_dim = 64
lr = 1e-3
batch_size = 128
num_epochs = 200

annotations = 'annotations.csv'

dataroot = 'fc_matrices/hcp_100_ica/'
root = project_root()

dataset = BrainGraphDataset(img_dir=os.path.join(root, dataroot),
                            annotations_file=os.path.join(root, dataroot, annotations),
                            transform=None, extra_data=None, setting='upper_triangular')

# Split the dataset 80/10/10 into training, validation and test sets;
# any rounding remainder ends up in the test set.
num_samples = len(dataset)
train_size = int(0.8 * num_samples)
val_size = int(0.1 * num_samples)
test_size = num_samples - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# Each evaluation loader yields its whole split as one batch.
# max(1, ...) stops DataLoader from rejecting batch_size=0 when a split is empty.
val_loader = DataLoader(val_dataset, batch_size=max(1, len(val_dataset)), shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=max(1, len(test_dataset)), shuffle=False)


def _mean_mse(model, loader, n_samples, input_dim, device):
    """Return the reconstruction MSE summed over ``loader`` divided by ``n_samples``.

    Each batch is moved to ``device``, flattened to ``(-1, input_dim)``, run
    through ``model`` and scored with ``model.loss``; only the first (MSE)
    term of the loss tuple is accumulated. Returns ``nan`` when
    ``n_samples`` is 0 instead of raising ZeroDivisionError.
    """
    # NOTE(review): this sums per-batch mse_loss.item() and divides by the
    # sample count, as the original notebook did. If VAE.loss returns a
    # *mean* MSE per batch, the result is under-scaled - confirm the reduction.
    total = 0.
    with torch.no_grad():
        for data, _labels in loader:
            data = data.to(device).view(-1, input_dim)
            recon, mu, logvar, _z = model(data)
            mse_loss, _gmm_loss, _l2_reg = model.loss(recon, data, mu, logvar, n_components=3)
            total += mse_loss.item()
    return total / n_samples if n_samples else float('nan')


for dropout in [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]:
    # Weight files are named with str(dropout) (e.g. 'vae_dropout_0.pt'), so
    # these literals must stay as written to match the saved checkpoints.
    model = VAE(input_dim, [hidden_dim] * 2, latent_dim).to(device)
    weights_path = os.path.join(root, f'vae_weights/vae_dropout_{dropout}.pt')
    model.load_state_dict(torch.load(weights_path, map_location=device))
    # Eval mode turns off training-only behaviour such as dropout.
    model.eval()

    # Validation and test are scored separately, each normalized by its own
    # split size, so the two numbers are directly comparable across runs.
    val_loss = _mean_mse(model, val_loader, len(val_dataset), input_dim, device)
    test_loss = _mean_mse(model, test_loader, len(test_dataset), input_dim, device)
    print(f'Dropout {dropout} - Val Loss: {val_loss:.4f} - Test Loss: {test_loss:.4f}\n')

NameError: name 'torch' is not defined