In [2]:
import sys
sys.path.append('../')
from scipy.stats import pearsonr
from models import LatentMLP, VAE
from utils import BrainGraphDataset, project_root
import torch
import torch.optim as optim
import os
import torch.nn as nn
import copy
import numpy as np
from sklearn.metrics import mean_absolute_error, r2_score


root = project_root()
annotations = 'annotations.csv'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def _psilo_before_dataset(subdir):
    """Return a BrainGraphDataset for one FC-matrix directory under ``root``.

    All datasets built here share the same annotations file and the
    'upper_triangular_and_baseline' setting; only the image directory differs.
    """
    return BrainGraphDataset(
        img_dir=os.path.join(root, subdir),
        annotations_file=os.path.join(root, annotations),
        transform=None,
        extra_data=None,
        setting='upper_triangular_and_baseline',
    )


psilo_ica_before_dataset = _psilo_before_dataset('fc_matrices/psilo_ica_100_before/')
psilo_schaefer_before_dataset = _psilo_before_dataset('fc_matrices/psilo_schaefer_before/')
# `dataroot` stays defined at module level, holding the last directory loaded.
dataroot = 'fc_matrices/psilo_aal_before/'
psilo_aal_before_dataset = _psilo_before_dataset(dataroot)

# Each entry: (evaluation dataset, '.pt' filename string, run tag).
# The run tag selects which saved VAE/MLP weights are evaluated.
configs = [
    (psilo_ica_before_dataset, 'vae_fine_tune_before_dropout_0.pt', 'fine_tune_before'),
    (psilo_ica_before_dataset, 'vae_fine_tune_combined_dropout_0.pt', 'fine_tune_combined'),
    (psilo_ica_before_dataset, 'vae_dropout_psilo_ica_before_0.pt', 'ica_before'),
    (psilo_ica_before_dataset, 'vae_dropout_psilo_ica_combined_0.pt', 'ica_combined'),
    (psilo_schaefer_before_dataset, 'vae_dropout_psilo_schaefer_before_0.pt', 'schaefer_before'),
    (psilo_schaefer_before_dataset, 'vae_dropout_psilo_schaefer_combined_0.pt', 'schaefer_combined'),
    (psilo_aal_before_dataset, 'vae_dropout_psilo_aal_before_0.pt', 'aal_before'),
    (psilo_aal_before_dataset, 'vae_dropout_psilo_aal_combined_0.pt', 'aal_combined'),
]

# Per-(config, dropout) test metrics:
# (tag, dropout, MAE, Pearson r, Pearson p, R2).
results = []
# Raw test-set labels/predictions keyed by f'{tag}_{dropout}'.
values = {}

# Hyper-parameters shared by every config (hoisted out of the loop).
hidden_dim = 256
latent_dim = 64
output_dim = 1
lr = 0.001  # not used by the evaluation below
batch_size = 8  # not used by the evaluation below

# Define the train, validation, and test ratios
train_ratio = 0.6
val_ratio = 0.2
test_ratio = 0.2

# Every config is evaluated at each of these MLP dropout rates;
# each rate has its own saved weight files.
dropout_list = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]

for config in configs:
    dataset, vae_file, config_name = config

    # 6670 = 116*115/2 and 4950 = 100*99/2. Presumably these are the
    # strict-upper-triangle lengths of a 116-region (AAL) vs 100-region
    # FC matrix. TODO confirm against BrainGraphDataset's
    # 'upper_triangular_and_baseline' setting.
    input_dim = 6670 if 'aal' in vae_file else 4950

    # Split sizes; the test split absorbs any rounding remainder.
    num_samples = len(dataset)
    train_size = int(train_ratio * num_samples)
    val_size = int(val_ratio * num_samples)
    test_size = num_samples - train_size - val_size
    if test_size < 2:
        # DataLoader rejects batch_size=0 and pearsonr needs >= 2 points.
        raise ValueError(
            f"config {config_name!r}: test split has {test_size} sample(s); "
            "need at least 2 to compute metrics"
        )

    # Re-seed before every split so the permutation is deterministic.
    # NOTE(review): presumably this must match the split used at training
    # time so no test sample was trained on -- confirm.
    torch.manual_seed(0)
    train_set, val_set, test_set = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])

    # The whole test split fits in a single batch.
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=test_size, shuffle=False)

    for dropout in dropout_list:
        # NOTE: torch.load unpickles the checkpoint; only load trusted files.
        vae = VAE(input_dim, [128] * 2, latent_dim)
        vae.load_state_dict(torch.load(os.path.join(root, f'mlp_weights/vae_unfrozen_dropout_{dropout}_{config_name}.pt'),
                                       map_location=device))
        vae.to(device)

        mlp = LatentMLP(latent_dim, hidden_dim, output_dim, dropout=dropout)
        mlp.load_state_dict(torch.load(os.path.join(root, f'mlp_weights/mlp_weight_dropout_{dropout}_{config_name}.pt'),
                                       map_location=device))
        mlp.to(device)

        mlp.eval()
        vae.eval()

        # Collect predictions over ALL batches before scoring, so the
        # metrics cover the full test split even if it spans several batches.
        label_chunks = []
        output_chunks = []
        with torch.no_grad():
            for (graphs, base_bdis), batch_labels in test_loader:
                graphs = graphs.to(device)
                base_bdis = base_bdis.to(device)

                # Only the fourth VAE output (the latent code) is used.
                _, _, _, zs = vae(graphs.view(-1, input_dim))
                batch_outputs = mlp(zs, base_bdis)

                label_chunks.append(batch_labels.float().view(-1).cpu().numpy())
                output_chunks.append(batch_outputs.view(-1).cpu().numpy())

        labels = np.concatenate(label_chunks)
        outputs = np.concatenate(output_chunks)

        mae = mean_absolute_error(labels, outputs)
        corr_coeffs = pearsonr(labels, outputs)
        r2 = r2_score(labels, outputs)

        results.append((config_name, dropout, mae, corr_coeffs[0], corr_coeffs[1], r2))
        if r2 > 0:
            print(results[-1])
        values[f'{config_name}_{dropout}'] = {'ground': labels.tolist(), 'predicted': outputs.tolist()}
               
import csv
import json

# Raw test-set labels/predictions, keyed by f'{tag}_{dropout}'.
# Explicit UTF-8 keeps the output independent of the platform locale.
with open(os.path.join(root, 'mlp_weights', 'test_results.json'), 'w', encoding='utf-8') as f:
    json.dump(values, f)

csv_filename = os.path.join(root, 'mlp_weights', 'mlp_full_config_results.csv')
# Write one metrics row per (config, dropout) to the CSV file.
# newline='' is required by the csv module to avoid blank lines on Windows.
with open(csv_filename, mode='w', newline='', encoding='utf-8') as file:
    writer = csv.writer(file)
    writer.writerow(["Config", "Dropout", "MAE", "Pearson stat", "Pearson p-value", "R2 Score"])
    writer.writerows(results)

('fine_tune_before', 0.45, 6.2285314, 0.5320751703089233, 0.14033960456124286, 0.06733461775119665)
('fine_tune_before', 0.5, 6.3827367, 0.5830962714496242, 0.09935662996610997, 0.05312920055723558)
('fine_tune_combined', 0.45, 6.24588, 0.48251874643327575, 0.18832422899704301, 0.09181052314527094)
('fine_tune_combined', 0.5, 6.2818613, 0.49236157777153045, 0.1781474676598453, 0.08599706078719171)
('ica_before', 0.45, 6.092506, 0.2848418323804986, 0.4575402602606016, 0.03977892605166378)
('ica_before', 0.5, 5.646469, 0.6835828352689789, 0.04233690914899893, 0.1680122220903334)
('ica_combined', 0.5, 5.8798594, 0.175862014080338, 0.6508422867853887, 0.017317424335925646)
('schaefer_before', 0.45, 6.2384586, 0.6815831229165346, 0.04318611989783615, 0.07273276718176991)
('schaefer_before', 0.5, 5.165992, 0.48808625478486123, 0.18252849924380182, 0.1338191377015293)
('aal_before', 0.45, 5.2959704, 0.28148357584602346, 0.46309161465124954, 0.07108623714590145)
('aal_combined', 0.45, 6.008455