In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.transforms import transforms
from torch.utils.data import Dataset
from typing import List, Tuple
import numpy as np

import os
from sklearn.mixture import GaussianMixture
from tqdm import tqdm
import matplotlib.pyplot as plt
import json
import sys

sys.path.append('..')
from utils import BrainGraphDataset, project_root
from models import VAE
from torch.utils.data import ConcatDataset

In [2]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Seed PyTorch's global RNG before anything below draws from it
# (the random split and the model construction both come after this call).
torch.manual_seed(0)

# Hyperparameters.
# input_dim is the flattened length of one sample; 4950 == 100 * 99 / 2, which
# presumably matches setting='upper_triangular' on 100-node matrices -- TODO confirm.
input_dim, hidden_dim, latent_dim = 4950, 128, 64
lr, batch_size, num_epochs = 1e-3, 128, 200

# Data locations, expressed relative to the project root.
root = project_root()
annotations = 'annotations-before.csv'
dataroot = 'fc_matrices/psilo_ica_100_before/'

dataset = BrainGraphDataset(
    img_dir=os.path.join(root, dataroot),
    annotations_file=os.path.join(root, annotations),
    transform=None,
    extra_data=None,
    setting='upper_triangular',
)

# 80/20 train/validation split; the validation set takes whatever is left
# after truncating the training size to an integer.
num_samples = len(dataset)
train_size = int(0.8 * num_samples)
val_size = num_samples - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

# Bookkeeping used by the training loop: best checkpoint so far and loss histories.
best_val_loss, best_model_state = float('inf'), None
loss_curves = {}
train_losses, val_losses = [], []

# Model (two hidden sizes, both equal to hidden_dim) placed on the chosen device,
# and an Adam optimizer over its parameters.
model = VAE(input_dim, [hidden_dim, hidden_dim], latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

# for epoch in tqdm(range(num_epochs)):
#     train_loss = 0.0
#     val_loss = 0.0

#     # training
#     model.train()
#     for batch_idx, (data, _) in enumerate(train_loader):
#         data = data.to(device)  # move data to device
#         optimizer.zero_grad()

#         recon, mu, logvar, z = model(data.view(-1, input_dim))
#         (mse_loss, gmm_loss, l2_reg) = model.loss(recon, data.view(-1, input_dim), mu, logvar, n_components=3)
#         loss = mse_loss + gmm_loss
#         loss.backward()
#         optimizer.step()
#         train_loss += mse_loss.item()

#     # validation
#     model.eval()
#     with torch.no_grad():
#         for batch_idx, (data, _) in enumerate(val_loader):
#             data = data.to(device)  # move data to device
#             recon, mu, logvar, z = model(data.view(-1, input_dim))
#             mse_loss, gmm_loss, l2_reg = model.loss(recon, data.view(-1, input_dim), mu, logvar, n_components=3)
#             val_loss += mse_loss.item()
#     # append losses to lists
#     train_losses.append(train_loss/len(train_dataset))
#     val_losses.append(val_loss/len(val_dataset))

#     # save the model if the validation loss is at its minimum
#     if val_losses[-1] < best_val_loss:
#         best_val_loss = val_losses[-1]
#         best_model_state = model.state_dict()

#     print(f'Epoch {epoch+1}/{num_epochs} - Train Loss: {train_losses[-1]:.4f} - Val Loss: {val_losses[-1]:.4f}\n')

# # save the best model for this configuration
# torch.save(best_model_state, f'vgae_weights/vae_best.pt')

# # add the loss curves to the dictionary
# loss_curves = {"train_loss": train_losses, "val_loss": val_losses}

# # save the loss curves to a file
# with open(os.path.join(root, 'loss_curves', "loss_curves-vae.json"), "w") as f:
#     json.dump(loss_curves, f)

cpu


In [None]:
import json
import matplotlib.pyplot as plt

# Read the saved loss histories: a mapping from a configuration key to a dict
# that contains (at least) a "val_loss" list.
with open("loss_curves_gmm.json", "r") as curves_file:
    loss_curves = json.load(curves_file)

# Draw one validation curve per key; the key itself is the legend label.
fig, ax = plt.subplots(figsize=(8, 6))
for n_comp, history in loss_curves.items():
    val_curve = history["val_loss"]
    ax.plot(range(1, len(val_curve) + 1), val_curve, label=f"{n_comp}")

# Axis labels, title, legend, and a fixed y-window of 30..60.
ax.set_xlabel("Epoch")
ax.set_ylabel("Validation Loss")
ax.set_title("Validation Loss Curves for Different Numbers of GMM Components")
ax.legend()
ax.set_ylim((30, 60))

plt.show()


In [None]:
import json
import matplotlib.pyplot as plt

# Read the saved loss histories: a mapping from a configuration key to a dict
# that contains (at least) a "train_loss" list.
with open("loss_curves_gmm.json", "r") as curves_file:
    loss_curves = json.load(curves_file)

# Draw one *training* curve per key; the key itself is the legend label.
fig, ax = plt.subplots(figsize=(8, 6))
for n_comp, history in loss_curves.items():
    train_curve = history["train_loss"]
    ax.plot(range(1, len(train_curve) + 1), train_curve, label=f"{n_comp}")

# Axis labels, title, legend, and a fixed y-window of 20..40.
ax.set_xlabel("Epoch")
ax.set_ylabel("Training Loss")
ax.set_title("Training Loss Curves for Different Numbers of GMM Components")
ax.legend()
ax.set_ylim((20, 40))

plt.show()


In [None]:
# Evaluate every saved GMM checkpoint (component counts 2 through 10) on the
# validation loader, keeping both the per-sample average loss and the model.
# NOTE(review): VAE is built here with a scalar hidden_dim, while the setup cell
# passes a list of hidden sizes -- confirm which constructor signature is current.
models, val_losses = [], []
for n_comp in range(2, 11):
    model = VAE(input_dim, hidden_dim, latent_dim)

    # Checkpoints are always mapped to the CPU in this cell.
    checkpoint = torch.load(f'vgae_weights/gmm{n_comp}_best.pt', map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint)
    model.eval()

    # Sum the batch losses without tracking gradients.
    running_loss = 0.0
    with torch.no_grad():
        for data, _ in val_loader:
            flat_batch = data.view(-1, input_dim)
            recon, mu, logvar, _ = model(flat_batch)
            batch_loss = model.loss_function(recon, flat_batch, mu, logvar, n_components=n_comp)
            running_loss += batch_loss.item()

    # Normalise by the number of validation samples, not the number of batches.
    val_losses.append(running_loss / len(val_dataset))
    models.append(model)

# Report the results; entry i corresponds to i + 2 components.
for i, val_loss in enumerate(val_losses):
    print(f'Model GMM {i+2}: Validation Loss = {val_loss:.4f}')

In [None]:
import matplotlib.pyplot as plt

# x-axis values: the component counts 2..10, in the same order in which
# val_losses was filled by the evaluation cell.
n_components_list = [*range(2, 11)]

# Single line plot of loss against component count, saved to disk and then shown.
fig, ax = plt.subplots()
ax.plot(n_components_list, val_losses)
ax.set_xlabel('Number of Components')
ax.set_ylabel('Validation Loss')
ax.set_title('Validation Loss vs. Number of GMM Components')
fig.savefig('gmm_component_testing.jpg')
plt.show()


In [None]:
# Build a fresh VAE for the reconstruction plots that follow.
model = VAE(input_dim, hidden_dim, latent_dim)

# Put the model in evaluation mode, as the other evaluation cells in this
# notebook do before running inference. This is done before loading so that the
# value displayed by the cell is still the result of load_state_dict.
model.eval()

# Restore the weights from the gmm8 checkpoint file.
model.load_state_dict(torch.load('vgae_weights/gmm8_best.pt', map_location=device))

In [None]:
import matplotlib.pyplot as plt
from nilearn import plotting

# Take the first batch from the validation loader; the second item is not used.
data, _ = next(iter(val_loader))

# Forward pass through the currently loaded model; only the reconstruction is kept.
flat_batch = data.view(-1, input_dim)
recon, _, _, _ = model(flat_batch)

# View the reconstructions as a stack of 100x100 matrices.
# NOTE(review): this needs 100 * 100 values per sample, which does not match
# input_dim = 4950 from the setup cell -- confirm which input_dim / dataset
# setting is in effect when this cell runs.
recon = recon.view(-1, 100, 100)

# Side-by-side original vs. reconstruction for the first three samples of the batch.
for sample_idx in range(3):
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    panels = (
        ('Original', data[sample_idx]),
        ('Reconstructed', recon[sample_idx].detach()),
    )
    for ax, (panel_title, matrix) in zip(axes, panels):
        plotting.plot_matrix(matrix, colorbar=True, vmax=0.8, vmin=-0.8, axes=ax)
        ax.set_title(panel_title)
    plt.show()

In [None]:
# Paths are built from the current working directory in this cell.
dataroot = 'fc_matrices/psilo_ica_100_before'
cwd = os.getcwd() + '/'

# cwd already ends with '/', so os.path.join yields exactly cwd + <relative path>.
psilo_dataset = BrainGraphDataset(
    img_dir=os.path.join(cwd, dataroot),
    annotations_file=os.path.join(cwd, annotations),
    transform=None,
    extra_data=None,
    setting='no_label',
)

# Unshuffled loader over the whole psilo dataset.
psilo_train_loader = DataLoader(psilo_dataset, batch_size=batch_size)

# Take the first batch from the psilo loader; the second item is not used.
data, _ = next(iter(psilo_train_loader))

# Forward pass through the currently loaded model; only the reconstruction is kept.
flat_batch = data.view(-1, input_dim)
recon, _, _, _ = model(flat_batch)

# View the reconstructions as a stack of 100x100 matrices
# (requires 100 * 100 values per sample -- TODO confirm against the active input_dim).
recon = recon.view(-1, 100, 100)

# Side-by-side original vs. reconstruction for the first three samples of the batch.
# `plt` and `plotting` are the names imported by earlier cells.
for sample_idx in range(3):
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    panels = (
        ('Original', data[sample_idx]),
        ('Reconstructed', recon[sample_idx].detach()),
    )
    for ax, (panel_title, matrix) in zip(axes, panels):
        plotting.plot_matrix(matrix, colorbar=True, vmax=0.8, vmin=-0.8, axes=ax)
        ax.set_title(panel_title)
    plt.show()

In [None]:
# One model instance is reused for every checkpoint; it is switched to
# evaluation mode once, up front.
model = VAE(input_dim, hidden_dim, latent_dim)
model.eval()

# Per-sample average loss on the psilo dataset for each checkpoint
# (component counts 2 through 10), computed without gradient tracking.
val_losses = []
with torch.no_grad():
    for n_comp in range(2, 11):
        checkpoint = torch.load(f'vgae_weights/gmm{n_comp}_best.pt', map_location=device)
        model.load_state_dict(checkpoint)

        val_loss = 0.0
        for data, _ in psilo_train_loader:
            flat_batch = data.view(-1, input_dim)
            recon, mu, logvar, _ = model(flat_batch)
            batch_loss = model.loss_function(recon, flat_batch, mu, logvar, n_components=n_comp)
            val_loss += batch_loss.item()

        # Normalise by the number of samples, not the number of batches.
        val_loss = val_loss / len(psilo_dataset)
        val_losses.append(val_loss)
        print(f'gmm_{n_comp}: {val_loss} loss')

In [None]:
# x-axis values: the component counts 2..10, matching the order of val_losses.
n_components_list = [*range(2, 11)]

# Line plot of loss against component count, saved to disk and then shown.
# NOTE(review): this writes to the same file name as the earlier
# loss-vs-components plot, so it replaces that image -- confirm this is intended.
fig, ax = plt.subplots()
ax.plot(n_components_list, val_losses)
ax.set_xlabel('Number of Components')
ax.set_ylabel('Validation Loss')
ax.set_title('Validation Loss vs. Number of GMM Components')
fig.savefig('gmm_component_testing.jpg')
plt.show()

In [None]:
from tqdm import tqdm

import matplotlib.pyplot as plt

# Hyperparameters for the runs on full (flattened 100 x 100) matrices.
input_dim = 100 * 100
lr, batch_size, num_epochs = 1e-3, 128, 300

# Data locations, relative to the current working directory.
annotations = 'annotations.csv'
dataroot = 'fc_matrices/hcp_100_ica/'
cwd = os.getcwd() + '/'

# cwd and dataroot both end with '/', so os.path.join reproduces plain
# concatenation; the annotations file sits inside the data directory.
dataset = BrainGraphDataset(
    img_dir=os.path.join(cwd, dataroot),
    annotations_file=os.path.join(cwd, dataroot, annotations),
    transform=None,
    extra_data=None,
    setting='no_label',
)

# Shuffled loader over the whole dataset (no train/validation split here).
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Containers for the architecture sweep: loss histories per configuration and
# the best validation loss reached by each configuration.
loss_curves, best_val_losses = {}, {}

# Number of GMM components used by the sweep.
best_n = 3

# for hidden_dim in [256, 512]:
#     for latent_dim in [64, 128]:
#         train_losses = []
#         val_losses = []
#         model = VAE(input_dim, hidden_dim, latent_dim).to(device)  # move model to device
#         optimizer = optim.Adam(model.parameters(), lr=lr)
#         best_val_loss = float('inf')  # initialize the best validation loss to infinity
        
#         with open('gmm_train_overfit.txt', 'a') as f:
#             f.write(f'Hidden dim: {hidden_dim}, latent_dim: {latent_dim}\n')
        
#         for epoch in range(num_epochs):
#             train_loss = 0.0
#             val_loss = 0.0

#             # training
#             model.train()
#             # define the optimizer and the loss function

#             for batch_idx, (data, _) in tqdm(enumerate(train_loader), total=len(train_loader)):
#                 data = data.to(device)  # move input data to device
#                 optimizer.zero_grad()

#                 recon, mu, logvar, z = model(data.view(-1, input_dim))
#                 loss = model.loss_function(recon, data.view(-1, input_dim), mu, logvar, n_components=best_n)
#                 loss.backward()
#                 optimizer.step()
#                 train_loss += loss.item()

#             # validation
#             model.eval()
#             with torch.no_grad():
#                 for batch_idx, (data, _) in tqdm(enumerate(psilo_train_loader), total=len(psilo_train_loader)):
#                     data = data.to(device)  # move input data to device
#                     recon, mu, logvar, z = model(data.view(-1, input_dim))
#                     loss = loss_function_gmm(recon, data.view(-1, input_dim), mu, logvar, n_components=best_n)
#                     val_loss += loss.item()

#             # append losses to lists
#             train_losses.append(train_loss/len(train_dataset))
#             val_losses.append(val_loss/len(psilo_dataset))

#             with open('gmm_train_overfit.txt', 'a') as f:
#                 f.write(f'Epoch {epoch+1}/{num_epochs} - Train Loss: {train_losses[-1]:.4f} - Val Loss: {val_losses[-1]:.4f}\n')
                
#             # update the best validation loss and save the model weights if it's the best so far for this configuration
#             if val_losses[-1] < best_val_loss:
#                 best_val_loss = val_losses[-1]
#                 best_val_losses[(hidden_dim, latent_dim)] = best_val_loss
#                 torch.save(model.state_dict(), f'vgae_weights/gmm_{best_n}_hidden{hidden_dim}_latent{latent_dim}.pt')

#         # plot the losses
#         plt.plot(val_losses, label=f'Validation Loss (hidden_dim={hidden_dim}, latent_dim={latent_dim})')
        
#                 # add the loss curves to the dictionary
#         loss_curves[f"hidden{hidden_dim}_latent_dim{latent_dim}"] = {"train_loss": train_losses, "val_loss": val_losses}

# # save the loss curves to a file
# with open("loss_curves_overfit_new.json", "w") as f:
#     json.dump(loss_curves, f)

# plt.xlabel('Epoch')
# plt.ylabel('Loss')
# plt.legend()
# plt.show()


In [None]:
import json
import matplotlib.pyplot as plt

# Read the saved loss histories: a mapping from a configuration key to a dict
# that contains (at least) a "val_loss" list.
with open("loss_curves_overfit.json", "r") as curves_file:
    loss_curves = json.load(curves_file)

# Draw one validation curve per configuration key; the key is the legend label.
fig, ax = plt.subplots(figsize=(10, 8))
for n_comp, history in loss_curves.items():
    val_curve = history["val_loss"]
    ax.plot(range(1, len(val_curve) + 1), val_curve, label=f"{n_comp}")

# Axis labels, title, legend, and a fixed y-window of 350..500.
ax.set_xlabel("Epoch")
ax.set_ylabel("Val Loss")
ax.set_title("Validation Loss Curves for Different Net Architectures")
ax.legend()
ax.set_ylim((350, 500))

plt.show()


In [None]:
from tqdm import tqdm

# Evaluation-only cell: score every saved (hidden_dim, latent_dim) checkpoint on
# the psilo dataset. No training happens here, so no optimizer is created.

# Grid of architectures to evaluate.
input_dim = 100 * 100  # length of one flattened 100 x 100 matrix
hidden_dims = [256, 128, 64]
latent_dims = [64, 32, 16]

# Component count that appears in the checkpoint file names and is passed to the loss.
n_components = 5

# Not used by this cell; kept so the notebook namespace is unchanged.
lr = 1e-3
batch_size = 128
num_epochs = 300
annotations = 'annotations.csv'
dataroot = 'fc_matrices/hcp_100_ica/'
cwd = os.getcwd() + '/'

# Per-sample average loss for each architecture, keyed by (hidden_dim, latent_dim).
arch_val_losses = {}

for hidden_dim in hidden_dims:
    for latent_dim in latent_dims:
        # Fresh model for this architecture, restored from its checkpoint.
        model = VAE(input_dim, hidden_dim, latent_dim)
        weights_path = f'vgae_weights/gmm_{n_components}_hidden{hidden_dim}_latent{latent_dim}.pt'
        model.load_state_dict(torch.load(weights_path, map_location=device))
        model.eval()

        # Sum the batch losses without tracking gradients.
        val_loss = 0.0
        with torch.no_grad():
            for data, _ in tqdm(psilo_train_loader):
                flat_batch = data.view(-1, input_dim)
                recon, mu, logvar, _ = model(flat_batch)
                batch_loss = model.loss_function(recon, flat_batch, mu, logvar, n_components=n_components)
                val_loss += batch_loss.item()

        # Normalise by the number of samples, not the number of batches.
        val_loss /= len(psilo_dataset)
        arch_val_losses[(hidden_dim, latent_dim)] = val_loss

        print(f'Hidden Dim: {hidden_dim}, Latent Dim: {latent_dim}, Validation Loss: {val_loss:.4f}')


In [None]:
from sklearn.manifold import TSNE

# Architecture of the checkpoint that is embedded below.
hidden_dim = 256
latent_dim = 64
input_dim = 100 * 100

model = VAE(input_dim, hidden_dim, latent_dim)
model.load_state_dict(torch.load('vgae_weights/gmm3_best.pt', map_location=device))
# Evaluation mode before inference, as in the other evaluation cells of this notebook.
model.eval()


def _encode_loader(loader):
    """Return the latent codes `z` (fourth model output) for every batch in `loader`.

    The batches are flattened to (-1, input_dim), passed through the model
    without gradient tracking, and the codes are concatenated along dim 0.
    """
    codes = []
    with torch.no_grad():
        for data, _ in loader:
            _, _, _, z = model(data.view(-1, input_dim))
            codes.append(z)
    return torch.cat(codes, dim=0)


# Encode the psilo data first, then the data served by train_loader.
psilo_zs = _encode_loader(psilo_train_loader)
hcp_zs = _encode_loader(train_loader)

# Stack both sets of codes; label psilo samples 0 and train_loader samples 1.
x = torch.cat((psilo_zs, hcp_zs), dim=0)
labels = torch.cat((torch.zeros(psilo_zs.shape[0]), torch.ones(hcp_zs.shape[0])), dim=0)

# Hand plain NumPy arrays to scikit-learn and matplotlib instead of relying on
# implicit tensor-to-array conversion.
x_np = x.cpu().numpy()
labels_np = labels.cpu().numpy()

for per in [30, 40, 50]:
    # 2-D t-SNE embedding of the latent codes at this perplexity.
    # The iteration count is left at the library default (1000), which avoids
    # the deprecated `n_iter` keyword; random_state makes re-runs repeatable.
    tsne = TSNE(n_components=2, perplexity=per, random_state=0)
    x_tsne = tsne.fit_transform(x_np)

    # Scatter plot coloured by the 0/1 source label.
    fig, ax = plt.subplots()
    points = ax.scatter(x_tsne[:, 0], x_tsne[:, 1], c=labels_np, cmap='coolwarm')
    fig.colorbar(points, ax=ax)
    ax.set_title(f't-SNE of latent codes (perplexity={per})')
    plt.show()