In [2]:
import sys
sys.path.append('../')
from models import VGAE, LatentMLP 
from utils import BrainGraphDataset, project_root, get_data_labels
import torch
import torch.nn as nn
import torch.optim as optim
import os
from tqdm import tqdm
import copy

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
root = project_root()

# instantiate the VGAE model
lr = 0.001
batch_size = 64

nf = 1
ef = 1
num_nodes = 100
hidden_dim = 128
latent_size = 8

criterion = nn.L1Loss(reduction='sum')

categories = ['patient_n','condition','bdi_before']

data_labels = get_data_labels()
data_labels = data_labels[categories]

annotations = 'annotations.csv'

data_labels.loc[data_labels["condition"] == "P", "condition"] = 1
data_labels.loc[data_labels["condition"] == "E", "condition"] = -1
data_labels['condition'] = data_labels['condition'].astype('float64')

dataroot = 'fc_matrices/psilo_aal_before/'

dataset = BrainGraphDataset(img_dir=os.path.join(root, dataroot),
                            annotations_file=os.path.join(root, annotations),
                            transform=None, extra_data=data_labels, setting='lz')

# Get the number of samples in the dataset
num_samples = len(dataset)

import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import KFold

# Assuming you have your dataset defined as 'dataset'
num_folds = 5  # Specify the number of folds
batch_size = 8  # Specify your desired batch size
random_seed = 42  # Specify the random seed

# Create indices for k-fold cross-validation with seeded random number generator
kf = KFold(n_splits=num_folds, shuffle=True, random_state=random_seed)

# Create empty lists to store train and validation loaders
train_loaders = []
val_loaders = []

for train_index, val_index in kf.split(dataset):
    # Split dataset into train and validation sets for the current fold
    train_set = torch.utils.data.Subset(dataset, train_index)
    val_set = torch.utils.data.Subset(dataset, val_index)

    # Define the dataloaders for the current fold
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)

    # Append the loaders to the respective lists
    train_loaders.append(train_loader)
    val_loaders.append(val_loader)


num_epochs = 500

import json

# Dictionary to store training and validation curves
loss_curves = {}
best_set = [0] * num_folds
dropout_list = [0]

for i, train_loader in enumerate(train_loaders):
    val_loader = val_loaders[i]
    for dropout in dropout_list:
        
        vgae = VGAE(1, 1, 116, 64, 4, device, dropout=0, l2_strength=0.001, use_nf=False).to(device)
        # load the trained VGAE weights
        vgae.load_state_dict(torch.load(os.path.join(root, 'vgae_weights/vgae_no_nf_aal.pt'), map_location=device))

        # define the optimizer and the loss function


        # Convert the model to the device
        vgae.to(device)

        
        best_val_loss = float('inf')  # set to infinity to start
        best_model_state = None
        train_losses = []
        val_losses = []

        model = LatentMLP(64, 256, 1, dropout=dropout)
        optimizer = optim.Adam(model.parameters(), lr=lr)

        src, dest = vgae.edge_index

        for epoch in tqdm(range(num_epochs)):
            train_loss = 0.0
            val_loss = 0.0

            # training
            model.train()
            for batch_idx, ((graph, _, baseline_bdi), label) in enumerate(train_loader):
                graph = graph.to(device)  # move data to device
                label = label.to(device)
                optimizer.zero_grad()

                rcn_edges, z, _, _ = vgae(None, graph)
                graph = graph[:, src, dest]

                output_bdi = model(z.view(z.shape[0], -1), baseline_bdi)

                l1_loss, l2_loss = model.loss(output_bdi, label.view(output_bdi.shape))
                loss = l1_loss + l2_loss
                loss.backward()
                optimizer.step()
                train_loss += loss.item()

            # validation
            model.eval()
            val_label = []
            val_output = []
            val_base = []
            with torch.no_grad():
                for batch_idx, ((graph, _, baseline_bdi), label) in enumerate(val_loader):
                    graph = graph.to(device)  # move data to device
                    label = label.to(device)

                    rcn_edges, z, _, _ = vgae(None, graph)
                    graph = graph[:, src, dest]

                    output_bdi = model(z.view(z.shape[0], -1), baseline_bdi)
                    val_label.extend(label)
                    val_output.extend(output_bdi)
                    val_base.extend(baseline_bdi)
                    
                    l1_loss, l2_loss = model.loss(output_bdi, label.view(output_bdi.shape))
                    loss = l1_loss + l2_loss
                    val_loss += loss.item()
            # append losses to lists
            train_losses.append(train_loss/len(train_set))
            val_losses.append(val_loss/len(val_set))

            # save the model if the validation loss is at its minimum
            if val_losses[-1] < best_val_loss:
                best_val_loss = val_losses[-1]

                best_model_state = (copy.deepcopy(vgae.state_dict()), copy.deepcopy(model.state_dict()))
                best_set[i] = (val_label, val_output, val_base)
            # print the losses
            with open('dropout_train.txt', 'a') as f:
                if (epoch + 1) % 20 == 0:
                    print(f'Epoch {epoch+1}/{num_epochs} - Train Loss: {train_losses[-1]:.4f} - Val Loss: {val_losses[-1]:.4f}\n')
                f.write(f'Epoch {epoch+1}/{num_epochs} - Train Loss: {train_losses[-1]:.4f} - Val Loss: {val_losses[-1]:.4f}\n')


  4%|▍         | 20/500 [00:05<02:07,  3.76it/s]

Epoch 20/500 - Train Loss: 5.8589 - Val Loss: 13.5989



  8%|▊         | 40/500 [00:10<02:02,  3.77it/s]

Epoch 40/500 - Train Loss: 5.6463 - Val Loss: 13.2267



 12%|█▏        | 60/500 [00:15<01:56,  3.76it/s]

Epoch 60/500 - Train Loss: 5.0909 - Val Loss: 13.1947



 16%|█▌        | 80/500 [00:21<01:51,  3.76it/s]

Epoch 80/500 - Train Loss: 5.1333 - Val Loss: 12.9343



 20%|██        | 100/500 [00:26<01:46,  3.76it/s]

Epoch 100/500 - Train Loss: 4.7024 - Val Loss: 13.8122



 24%|██▍       | 120/500 [00:31<01:41,  3.75it/s]

Epoch 120/500 - Train Loss: 4.5668 - Val Loss: 14.1776



 28%|██▊       | 140/500 [00:37<01:36,  3.74it/s]

Epoch 140/500 - Train Loss: 4.8656 - Val Loss: 14.5033



 32%|███▏      | 160/500 [00:42<01:30,  3.75it/s]

Epoch 160/500 - Train Loss: 4.3574 - Val Loss: 14.1017



 36%|███▌      | 180/500 [00:47<01:25,  3.73it/s]

Epoch 180/500 - Train Loss: 4.4693 - Val Loss: 14.7199



 40%|████      | 200/500 [00:53<01:20,  3.75it/s]

Epoch 200/500 - Train Loss: 4.4725 - Val Loss: 13.9408



 44%|████▍     | 220/500 [00:58<01:15,  3.73it/s]

Epoch 220/500 - Train Loss: 4.2291 - Val Loss: 13.7728



 48%|████▊     | 240/500 [01:03<01:08,  3.78it/s]

Epoch 240/500 - Train Loss: 4.1392 - Val Loss: 15.2306



 52%|█████▏    | 260/500 [01:09<01:03,  3.78it/s]

Epoch 260/500 - Train Loss: 4.3000 - Val Loss: 16.1075



 56%|█████▌    | 280/500 [01:14<00:58,  3.78it/s]

Epoch 280/500 - Train Loss: 3.9244 - Val Loss: 15.7764



 60%|██████    | 300/500 [01:19<00:53,  3.77it/s]

Epoch 300/500 - Train Loss: 4.2023 - Val Loss: 15.4120



 64%|██████▍   | 320/500 [01:25<00:47,  3.77it/s]

Epoch 320/500 - Train Loss: 4.1652 - Val Loss: 16.6914



 68%|██████▊   | 340/500 [01:30<00:42,  3.76it/s]

Epoch 340/500 - Train Loss: 4.0582 - Val Loss: 15.5009



 72%|███████▏  | 360/500 [01:35<00:37,  3.77it/s]

Epoch 360/500 - Train Loss: 4.0887 - Val Loss: 15.0967



 76%|███████▌  | 380/500 [01:41<00:31,  3.77it/s]

Epoch 380/500 - Train Loss: 3.8154 - Val Loss: 16.2527



 80%|████████  | 400/500 [01:46<00:26,  3.77it/s]

Epoch 400/500 - Train Loss: 3.7731 - Val Loss: 17.4305



 84%|████████▍ | 420/500 [01:51<00:21,  3.77it/s]

Epoch 420/500 - Train Loss: 3.5412 - Val Loss: 15.5062



 88%|████████▊ | 440/500 [01:57<00:15,  3.77it/s]

Epoch 440/500 - Train Loss: 3.5183 - Val Loss: 17.1098



 92%|█████████▏| 460/500 [02:02<00:10,  3.77it/s]

Epoch 460/500 - Train Loss: 3.4467 - Val Loss: 16.3404



 96%|█████████▌| 480/500 [02:07<00:05,  3.75it/s]

Epoch 480/500 - Train Loss: 3.6136 - Val Loss: 15.1915



100%|██████████| 500/500 [02:12<00:00,  3.76it/s]


Epoch 500/500 - Train Loss: 3.4817 - Val Loss: 17.1081



  4%|▍         | 20/500 [00:05<02:07,  3.76it/s]

Epoch 20/500 - Train Loss: 7.0115 - Val Loss: 7.3116



  8%|▊         | 40/500 [00:10<02:02,  3.77it/s]

Epoch 40/500 - Train Loss: 6.2646 - Val Loss: 8.8400



 12%|█▏        | 60/500 [00:16<01:58,  3.73it/s]

Epoch 60/500 - Train Loss: 6.2266 - Val Loss: 10.4294



 16%|█▌        | 80/500 [00:21<01:51,  3.78it/s]

Epoch 80/500 - Train Loss: 6.0398 - Val Loss: 9.5736



 20%|██        | 100/500 [00:26<01:46,  3.77it/s]

Epoch 100/500 - Train Loss: 6.0397 - Val Loss: 9.3212



 24%|██▍       | 120/500 [00:32<01:40,  3.77it/s]

Epoch 120/500 - Train Loss: 5.7004 - Val Loss: 7.8161



 28%|██▊       | 140/500 [00:37<01:38,  3.65it/s]

Epoch 140/500 - Train Loss: 5.4759 - Val Loss: 8.6891



 32%|███▏      | 160/500 [00:43<01:34,  3.60it/s]

Epoch 160/500 - Train Loss: 5.9516 - Val Loss: 7.6732



 36%|███▌      | 180/500 [00:48<01:26,  3.71it/s]

Epoch 180/500 - Train Loss: 5.8856 - Val Loss: 8.7271



 40%|████      | 200/500 [00:54<01:22,  3.65it/s]

Epoch 200/500 - Train Loss: 5.4696 - Val Loss: 8.4664



 44%|████▍     | 220/500 [00:59<01:17,  3.60it/s]

Epoch 220/500 - Train Loss: 5.2855 - Val Loss: 7.9650



 48%|████▊     | 240/500 [01:05<01:12,  3.59it/s]

Epoch 240/500 - Train Loss: 5.0495 - Val Loss: 8.9317



 52%|█████▏    | 260/500 [01:10<01:06,  3.59it/s]

Epoch 260/500 - Train Loss: 4.8104 - Val Loss: 10.1911



 56%|█████▌    | 280/500 [01:16<01:01,  3.56it/s]

Epoch 280/500 - Train Loss: 4.8856 - Val Loss: 10.0369



 60%|██████    | 300/500 [01:21<00:54,  3.68it/s]

Epoch 300/500 - Train Loss: 4.5848 - Val Loss: 11.6255



 64%|██████▍   | 320/500 [01:27<00:49,  3.65it/s]

Epoch 320/500 - Train Loss: 3.6227 - Val Loss: 13.1629



 68%|██████▊   | 340/500 [01:32<00:44,  3.63it/s]

Epoch 340/500 - Train Loss: 5.3982 - Val Loss: 12.1667



 72%|███████▏  | 360/500 [01:38<00:38,  3.65it/s]

Epoch 360/500 - Train Loss: 4.2854 - Val Loss: 11.1097



 76%|███████▌  | 380/500 [01:43<00:32,  3.69it/s]

Epoch 380/500 - Train Loss: 4.1502 - Val Loss: 11.9882



 80%|████████  | 400/500 [01:49<00:27,  3.68it/s]

Epoch 400/500 - Train Loss: 3.8847 - Val Loss: 9.8856



 84%|████████▍ | 420/500 [01:54<00:21,  3.70it/s]

Epoch 420/500 - Train Loss: 3.7079 - Val Loss: 12.2358



 88%|████████▊ | 440/500 [02:00<00:16,  3.64it/s]

Epoch 440/500 - Train Loss: 4.6173 - Val Loss: 11.2438



 92%|█████████▏| 460/500 [02:05<00:10,  3.67it/s]

Epoch 460/500 - Train Loss: 4.2326 - Val Loss: 12.0352



 96%|█████████▌| 480/500 [02:10<00:05,  3.60it/s]

Epoch 480/500 - Train Loss: 3.2652 - Val Loss: 10.5848



100%|██████████| 500/500 [02:16<00:00,  3.66it/s]


Epoch 500/500 - Train Loss: 3.9005 - Val Loss: 8.0789



  4%|▍         | 20/500 [00:05<02:11,  3.64it/s]

Epoch 20/500 - Train Loss: 7.1688 - Val Loss: 6.1229



  8%|▊         | 40/500 [00:11<02:08,  3.59it/s]

Epoch 40/500 - Train Loss: 7.4301 - Val Loss: 5.4209



 12%|█▏        | 60/500 [00:16<01:56,  3.78it/s]

Epoch 60/500 - Train Loss: 6.1433 - Val Loss: 5.5139



 16%|█▌        | 80/500 [00:21<01:50,  3.79it/s]

Epoch 80/500 - Train Loss: 6.6027 - Val Loss: 8.3903



 20%|██        | 100/500 [00:27<01:45,  3.79it/s]

Epoch 100/500 - Train Loss: 5.7001 - Val Loss: 6.2686



 24%|██▍       | 120/500 [00:32<01:40,  3.79it/s]

Epoch 120/500 - Train Loss: 5.4554 - Val Loss: 7.1588



 28%|██▊       | 140/500 [00:37<01:34,  3.79it/s]

Epoch 140/500 - Train Loss: 5.3422 - Val Loss: 8.9068



 32%|███▏      | 160/500 [00:42<01:29,  3.80it/s]

Epoch 160/500 - Train Loss: 5.0762 - Val Loss: 7.6567



 36%|███▌      | 180/500 [00:48<01:24,  3.79it/s]

Epoch 180/500 - Train Loss: 5.4939 - Val Loss: 7.4777



 40%|████      | 200/500 [00:53<01:19,  3.79it/s]

Epoch 200/500 - Train Loss: 5.0486 - Val Loss: 8.2365



 44%|████▍     | 220/500 [00:58<01:14,  3.75it/s]

Epoch 220/500 - Train Loss: 5.0668 - Val Loss: 9.0581



 48%|████▊     | 240/500 [01:04<01:09,  3.74it/s]

Epoch 240/500 - Train Loss: 4.7860 - Val Loss: 6.8870



 52%|█████▏    | 260/500 [01:09<01:03,  3.77it/s]

Epoch 260/500 - Train Loss: 5.0160 - Val Loss: 8.0051



 56%|█████▌    | 280/500 [01:14<00:58,  3.75it/s]

Epoch 280/500 - Train Loss: 5.9083 - Val Loss: 10.8193



 60%|██████    | 300/500 [01:20<00:53,  3.75it/s]

Epoch 300/500 - Train Loss: 4.5823 - Val Loss: 9.6794



 64%|██████▍   | 320/500 [01:25<00:48,  3.75it/s]

Epoch 320/500 - Train Loss: 4.4582 - Val Loss: 6.2911



 68%|██████▊   | 340/500 [01:30<00:42,  3.76it/s]

Epoch 340/500 - Train Loss: 4.5077 - Val Loss: 9.6057



 72%|███████▏  | 360/500 [01:36<00:37,  3.76it/s]

Epoch 360/500 - Train Loss: 4.1641 - Val Loss: 8.3885



 76%|███████▌  | 380/500 [01:41<00:32,  3.75it/s]

Epoch 380/500 - Train Loss: 3.6257 - Val Loss: 7.9329



 80%|████████  | 400/500 [01:46<00:26,  3.72it/s]

Epoch 400/500 - Train Loss: 4.3473 - Val Loss: 7.4643



 84%|████████▍ | 420/500 [01:52<00:21,  3.68it/s]

Epoch 420/500 - Train Loss: 4.2757 - Val Loss: 8.5884



 88%|████████▊ | 440/500 [01:57<00:15,  3.76it/s]

Epoch 440/500 - Train Loss: 4.5231 - Val Loss: 7.8250



 92%|█████████▏| 460/500 [02:02<00:10,  3.76it/s]

Epoch 460/500 - Train Loss: 2.7000 - Val Loss: 6.8498



 96%|█████████▌| 480/500 [02:08<00:05,  3.66it/s]

Epoch 480/500 - Train Loss: 4.2235 - Val Loss: 10.8174



100%|██████████| 500/500 [02:13<00:00,  3.74it/s]


Epoch 500/500 - Train Loss: 2.9289 - Val Loss: 11.7609



  4%|▍         | 20/500 [00:05<02:13,  3.61it/s]

Epoch 20/500 - Train Loss: 7.2169 - Val Loss: 5.5745



  8%|▊         | 40/500 [00:10<02:03,  3.72it/s]

Epoch 40/500 - Train Loss: 6.5594 - Val Loss: 8.6478



 12%|█▏        | 60/500 [00:16<01:58,  3.72it/s]

Epoch 60/500 - Train Loss: 6.2108 - Val Loss: 5.9307



 16%|█▌        | 80/500 [00:21<01:52,  3.72it/s]

Epoch 80/500 - Train Loss: 6.3682 - Val Loss: 6.9024



 20%|██        | 100/500 [00:27<01:47,  3.72it/s]

Epoch 100/500 - Train Loss: 6.0949 - Val Loss: 7.0091



 24%|██▍       | 120/500 [00:32<01:41,  3.73it/s]

Epoch 120/500 - Train Loss: 6.0181 - Val Loss: 7.1083



 28%|██▊       | 140/500 [00:37<01:36,  3.72it/s]

Epoch 140/500 - Train Loss: 5.8917 - Val Loss: 8.4766



 32%|███▏      | 160/500 [00:43<01:30,  3.74it/s]

Epoch 160/500 - Train Loss: 5.9752 - Val Loss: 8.0973



 36%|███▌      | 180/500 [00:48<01:25,  3.74it/s]

Epoch 180/500 - Train Loss: 5.9019 - Val Loss: 7.6664



 40%|████      | 200/500 [00:53<01:21,  3.70it/s]

Epoch 200/500 - Train Loss: 5.7751 - Val Loss: 7.2695



 44%|████▍     | 220/500 [00:59<01:16,  3.66it/s]

Epoch 220/500 - Train Loss: 6.2484 - Val Loss: 9.7361



 48%|████▊     | 240/500 [01:04<01:10,  3.68it/s]

Epoch 240/500 - Train Loss: 5.3404 - Val Loss: 7.5119



 52%|█████▏    | 260/500 [01:10<01:05,  3.64it/s]

Epoch 260/500 - Train Loss: 5.5823 - Val Loss: 8.1653



 56%|█████▌    | 280/500 [01:15<00:59,  3.67it/s]

Epoch 280/500 - Train Loss: 4.6932 - Val Loss: 8.1665



 60%|██████    | 300/500 [01:21<00:54,  3.69it/s]

Epoch 300/500 - Train Loss: 4.3333 - Val Loss: 8.4377



 64%|██████▍   | 320/500 [01:26<00:48,  3.68it/s]

Epoch 320/500 - Train Loss: 4.4864 - Val Loss: 8.7350



 68%|██████▊   | 340/500 [01:32<00:43,  3.69it/s]

Epoch 340/500 - Train Loss: 5.1645 - Val Loss: 9.6020



 72%|███████▏  | 360/500 [01:37<00:37,  3.70it/s]

Epoch 360/500 - Train Loss: 4.7001 - Val Loss: 9.1906



 76%|███████▌  | 380/500 [01:42<00:32,  3.70it/s]

Epoch 380/500 - Train Loss: 4.3083 - Val Loss: 8.1348



 80%|████████  | 400/500 [01:48<00:27,  3.70it/s]

Epoch 400/500 - Train Loss: 3.9547 - Val Loss: 11.6078



 84%|████████▍ | 420/500 [01:53<00:21,  3.71it/s]

Epoch 420/500 - Train Loss: 4.1625 - Val Loss: 9.6563



 88%|████████▊ | 440/500 [01:59<00:16,  3.74it/s]

Epoch 440/500 - Train Loss: 4.7492 - Val Loss: 6.3464



 92%|█████████▏| 460/500 [02:04<00:10,  3.75it/s]

Epoch 460/500 - Train Loss: 3.5216 - Val Loss: 9.2171



 96%|█████████▌| 480/500 [02:09<00:05,  3.74it/s]

Epoch 480/500 - Train Loss: 4.3328 - Val Loss: 8.9443



100%|██████████| 500/500 [02:15<00:00,  3.70it/s]


Epoch 500/500 - Train Loss: 4.3535 - Val Loss: 7.9925



  4%|▍         | 20/500 [00:05<02:08,  3.74it/s]

Epoch 20/500 - Train Loss: 7.4616 - Val Loss: 6.4120



  8%|▊         | 40/500 [00:10<02:03,  3.74it/s]

Epoch 40/500 - Train Loss: 6.8240 - Val Loss: 5.4791



 12%|█▏        | 60/500 [00:16<01:57,  3.75it/s]

Epoch 60/500 - Train Loss: 6.3830 - Val Loss: 5.4115



 16%|█▌        | 80/500 [00:21<01:51,  3.77it/s]

Epoch 80/500 - Train Loss: 5.9227 - Val Loss: 5.8748



 20%|██        | 100/500 [00:26<01:46,  3.75it/s]

Epoch 100/500 - Train Loss: 5.6633 - Val Loss: 5.8984



 24%|██▍       | 120/500 [00:32<01:41,  3.76it/s]

Epoch 120/500 - Train Loss: 6.0302 - Val Loss: 5.8701



 28%|██▊       | 140/500 [00:37<01:36,  3.75it/s]

Epoch 140/500 - Train Loss: 5.8672 - Val Loss: 7.3527



 32%|███▏      | 160/500 [00:42<01:30,  3.75it/s]

Epoch 160/500 - Train Loss: 5.8230 - Val Loss: 6.5607



 36%|███▌      | 180/500 [00:47<01:25,  3.76it/s]

Epoch 180/500 - Train Loss: 4.7380 - Val Loss: 8.2277



 40%|████      | 200/500 [00:53<01:19,  3.76it/s]

Epoch 200/500 - Train Loss: 5.2596 - Val Loss: 6.5938



 44%|████▍     | 220/500 [00:58<01:14,  3.76it/s]

Epoch 220/500 - Train Loss: 5.1478 - Val Loss: 8.1795



 48%|████▊     | 240/500 [01:03<01:09,  3.77it/s]

Epoch 240/500 - Train Loss: 5.4184 - Val Loss: 6.9910



 52%|█████▏    | 260/500 [01:09<01:03,  3.77it/s]

Epoch 260/500 - Train Loss: 5.4780 - Val Loss: 6.3161



 56%|█████▌    | 280/500 [01:14<00:58,  3.78it/s]

Epoch 280/500 - Train Loss: 4.5876 - Val Loss: 7.8684



 60%|██████    | 300/500 [01:19<00:53,  3.77it/s]

Epoch 300/500 - Train Loss: 5.6971 - Val Loss: 5.7358



 64%|██████▍   | 320/500 [01:25<00:47,  3.77it/s]

Epoch 320/500 - Train Loss: 5.0795 - Val Loss: 6.7670



 68%|██████▊   | 340/500 [01:30<00:42,  3.77it/s]

Epoch 340/500 - Train Loss: 4.9168 - Val Loss: 7.6568



 72%|███████▏  | 360/500 [01:35<00:37,  3.77it/s]

Epoch 360/500 - Train Loss: 5.3668 - Val Loss: 7.4544



 76%|███████▌  | 380/500 [01:41<00:31,  3.78it/s]

Epoch 380/500 - Train Loss: 4.5334 - Val Loss: 9.4526



 80%|████████  | 400/500 [01:46<00:26,  3.78it/s]

Epoch 400/500 - Train Loss: 4.1469 - Val Loss: 7.7111



 84%|████████▍ | 420/500 [01:51<00:21,  3.78it/s]

Epoch 420/500 - Train Loss: 4.9667 - Val Loss: 6.9118



 88%|████████▊ | 440/500 [01:56<00:15,  3.77it/s]

Epoch 440/500 - Train Loss: 3.9586 - Val Loss: 8.8874



 92%|█████████▏| 460/500 [02:02<00:10,  3.79it/s]

Epoch 460/500 - Train Loss: 4.2583 - Val Loss: 8.8122



 96%|█████████▌| 480/500 [02:07<00:05,  3.78it/s]

Epoch 480/500 - Train Loss: 3.4159 - Val Loss: 9.6784



100%|██████████| 500/500 [02:12<00:00,  3.76it/s]

Epoch 500/500 - Train Loss: 4.5342 - Val Loss: 8.0352






In [4]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
from sklearn.metrics import r2_score, mean_absolute_error

r2_scores = []
mae_scores = []
r_scores = []

l1 = []
l2 = []

for fold in best_set:
    l1.extend([t.item() for t in fold[0]])
    l2.extend([t.item() for t in fold[1]])

print(r2_score(l1, l2))
print(pearsonr(l1, l2))
print(mean_absolute_error(l1,l2))

print()

total_true = []
total_pred = []
total_drug = []
total_base = []

for i, (true, pred, base) in enumerate(best_set):
    true = [t.item() for t in true]
    pred = [p.item() for p in pred]
    drug = [d[0].item() for d in base]
    base = [d[1].item() for d in base]
    
    total_true.extend(true)
    total_pred.extend(pred)
    total_drug.extend(drug)
    total_base.extend(base)
    # Calculate R-squared (Pearson correlation coefficient)
    r2 = r2_score(true, pred)
    r2_scores.append(r2)

    r_scores.append(pearsonr(true, pred))
    
    # Calculate Mean Absolute Error (MAE)
    mae = mean_absolute_error(true, pred)
    mae_scores.append(mae)
    
        
# Specify the filename for the CSV file
filename = 'featureless-vgae-full-val-results-aal.csv'

# Create a list of rows with headers
rows = [['true_post_bdi', 'predicted_post_bdi', 'drug (1 for psilo)', 'base_bdi']]
for true, pred, drug, base in zip(total_true, total_pred, total_drug, total_base):
    rows.append([true, pred, drug, base])

import csv
# Write the rows to the CSV file
with open(filename, 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerows(rows)
        

0.3226828678182607
PearsonRResult(statistic=0.5744138214246384, pvalue=6.951804004983459e-05)
6.109218472526187

