In [3]:
import sys
sys.path.append('../')
from models import VGAE, LatentMLP 
from utils import BrainGraphDataset, project_root, get_data_labels
import torch
import torch.nn as nn
import torch.optim as optim
import os
from tqdm import tqdm
import copy

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
root = project_root()

# Optimiser hyper-parameters.
# NOTE: batch_size is assigned again right before the cross-validation
# loaders are built, so the value set here is not the one they use.
lr = 0.001
batch_size = 64

# Model-size settings.
# NOTE(review): the VGAE / LatentMLP constructors further down are called with
# literal values rather than these names -- confirm which set is authoritative.
nf = 1
ef = 1
num_nodes = 100
hidden_dim = 128
latent_size = 8

criterion = nn.L1Loss(reduction='sum')


categories = ['patient_n','condition','bdi_before']

# Take an explicit copy of the selected columns: the rows are modified in
# place below, and writing through a slice of the original frame triggers
# pandas' SettingWithCopyWarning (and is not guaranteed to take effect).
data_labels = get_data_labels()
data_labels = data_labels[categories].copy()

annotations = 'annotations.csv'

# Encode the treatment arm numerically: "P" -> 1, "E" -> -1.
data_labels.loc[data_labels["condition"] == "P", "condition"] = 1
data_labels.loc[data_labels["condition"] == "E", "condition"] = -1
# Any other value left in the column makes this cast fail loudly.
data_labels['condition'] = data_labels['condition'].astype('float64')

dataroot = 'fc_matrices/psilo_schaefer_before/'

dataset = BrainGraphDataset(img_dir=os.path.join(root, dataroot),
                            annotations_file=os.path.join(root, annotations),
                            transform=None, extra_data=data_labels, setting='lz')

# Get the number of samples in the dataset
num_samples = len(dataset)

import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import KFold

# Assuming you have your dataset defined as 'dataset'
# Cross-validation settings. batch_size here replaces the value assigned
# earlier in the script.
num_folds = 5
batch_size = 8
random_seed = 42

# Shuffled k-fold splitter; the fixed seed makes the fold membership repeatable.
kf = KFold(n_splits=num_folds, shuffle=True, random_state=random_seed)

# One training loader and one validation loader per fold, in fold order.
train_loaders, val_loaders = [], []

for train_index, val_index in kf.split(dataset):
    # Restrict the full dataset to this fold's training / validation indices.
    train_set, val_set = (
        torch.utils.data.Subset(dataset, fold_indices)
        for fold_indices in (train_index, val_index)
    )

    # Only the training loader reshuffles its samples.
    train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size)
    val_loader = DataLoader(val_set, shuffle=False, batch_size=batch_size)

    train_loaders += [train_loader]
    val_loaders += [val_loader]


num_epochs = 500

import json

# Dictionary to store training and validation curves
loss_curves = {}
# best_set[i] is replaced by (labels, predictions, extra MLP inputs) collected
# over the validation set at the epoch with the lowest validation loss of fold i.
best_set = [0] * num_folds
dropout_list = [0]

for i, train_loader in enumerate(train_loaders):
    val_loader = val_loaders[i]

    # Sample counts of THIS fold. Folds may differ in size, so the counts
    # must come from the fold's own loaders rather than from whatever subset
    # happened to be built last.
    n_train = len(train_loader.dataset)
    n_val = len(val_loader.dataset)

    for dropout in dropout_list:

        vgae = VGAE(1, 1, 100, 64, 4, device, dropout=0, l2_strength=0.001, use_nf=False).to(device)
        # load the trained VGAE weights
        vgae.load_state_dict(torch.load(os.path.join(root, 'vgae_weights/vgae_no_nf_schaefer.pt'), map_location=device))

        # Convert the model to the device
        vgae.to(device)
        # NOTE(review): the VGAE is never updated here (the optimizer below
        # only holds the MLP parameters), yet it is left in training mode.
        # Consider vgae.eval() if its forward pass differs between modes.

        best_val_loss = float('inf')  # set to infinity to start
        best_model_state = None
        train_losses = []
        val_losses = []

        # The MLP has to live on the same device as the latent codes it consumes.
        model = LatentMLP(64, 256, 1, dropout=dropout).to(device)
        optimizer = optim.Adam(model.parameters(), lr=lr)

        # Not used inside this loop; kept in case later cells rely on these names.
        src, dest = vgae.edge_index

        for epoch in tqdm(range(num_epochs)):
            train_loss = 0.0
            val_loss = 0.0

            # training
            model.train()
            for batch_idx, ((graph, _, baseline_bdi), label) in enumerate(train_loader):
                graph = graph.to(device)  # move data to device
                baseline_bdi = baseline_bdi.to(device)
                label = label.to(device)
                optimizer.zero_grad()

                # The encoder is frozen, so skip building its autograd graph;
                # gradients are only needed from z onwards.
                with torch.no_grad():
                    _, z, _, _ = vgae(None, graph)

                output_bdi = model(z.view(z.shape[0], -1), baseline_bdi)

                l1_loss, l2_loss = model.loss(output_bdi, label.view(output_bdi.shape))
                loss = l1_loss + l2_loss
                loss.backward()
                optimizer.step()
                train_loss += loss.item()

            # validation
            model.eval()
            val_label = []
            val_output = []
            val_base = []
            with torch.no_grad():
                for batch_idx, ((graph, _, baseline_bdi), label) in enumerate(val_loader):
                    graph = graph.to(device)  # move data to device
                    baseline_bdi = baseline_bdi.to(device)
                    label = label.to(device)

                    _, z, _, _ = vgae(None, graph)

                    output_bdi = model(z.view(z.shape[0], -1), baseline_bdi)
                    val_label.extend(label)
                    val_output.extend(output_bdi)
                    val_base.extend(baseline_bdi)

                    l1_loss, l2_loss = model.loss(output_bdi, label.view(output_bdi.shape))
                    loss = l1_loss + l2_loss
                    val_loss += loss.item()

            # Per-sample losses for this epoch, normalised by this fold's sizes.
            train_losses.append(train_loss / n_train)
            val_losses.append(val_loss / n_val)

            # save the model if the validation loss is at its minimum
            if val_losses[-1] < best_val_loss:
                best_val_loss = val_losses[-1]

                best_model_state = (copy.deepcopy(vgae.state_dict()), copy.deepcopy(model.state_dict()))
                best_set[i] = (val_label, val_output, val_base)

            # Log every epoch to the file; echo every 20th epoch to stdout.
            epoch_summary = f'Epoch {epoch+1}/{num_epochs} - Train Loss: {train_losses[-1]:.4f} - Val Loss: {val_losses[-1]:.4f}\n'
            if (epoch + 1) % 20 == 0:
                print(epoch_summary)
            with open('dropout_train.txt', 'a') as f:
                f.write(epoch_summary)

#         # save the best model for this configuration
#         torch.save(best_model_state[1], os.path.join(root, f'mlp_weights/vgae_model_weight_{i}_no_nf.pt'))
#         torch.save(best_model_state[0], os.path.join(root, f'mlp_weights/vgae_vgae_weight_{i}_no_nf.pt'))

#         # add the loss curves to the dictionary
#         loss_curves[f"dropout_{dropout}"] = {"train_loss": train_losses, "val_loss": val_losses}

#     # save the loss curves to a file
#     with open("loss_curves_mlp_vgae.json", "w") as f:
#         json.dump(loss_curves, f)

  4%|▍         | 20/500 [00:04<01:45,  4.53it/s]

Epoch 20/500 - Train Loss: 5.9570 - Val Loss: 12.5374



  8%|▊         | 40/500 [00:08<01:41,  4.54it/s]

Epoch 40/500 - Train Loss: 5.4337 - Val Loss: 11.5747



 12%|█▏        | 60/500 [00:13<01:36,  4.56it/s]

Epoch 60/500 - Train Loss: 5.6681 - Val Loss: 10.7687



 16%|█▌        | 80/500 [00:17<01:31,  4.57it/s]

Epoch 80/500 - Train Loss: 5.4516 - Val Loss: 10.6780



 20%|██        | 100/500 [00:21<01:26,  4.64it/s]

Epoch 100/500 - Train Loss: 5.2727 - Val Loss: 11.0040



 24%|██▍       | 120/500 [00:26<01:21,  4.64it/s]

Epoch 120/500 - Train Loss: 5.2290 - Val Loss: 11.0004



 28%|██▊       | 140/500 [00:30<01:17,  4.63it/s]

Epoch 140/500 - Train Loss: 5.4822 - Val Loss: 11.4652



 32%|███▏      | 160/500 [00:34<01:14,  4.57it/s]

Epoch 160/500 - Train Loss: 5.2486 - Val Loss: 11.4032



 36%|███▌      | 180/500 [00:39<01:09,  4.58it/s]

Epoch 180/500 - Train Loss: 5.6204 - Val Loss: 12.3314



 40%|████      | 200/500 [00:43<01:05,  4.60it/s]

Epoch 200/500 - Train Loss: 5.7027 - Val Loss: 11.5381



 44%|████▍     | 220/500 [00:48<01:02,  4.49it/s]

Epoch 220/500 - Train Loss: 5.8190 - Val Loss: 11.3112



 48%|████▊     | 240/500 [00:52<00:58,  4.48it/s]

Epoch 240/500 - Train Loss: 5.3141 - Val Loss: 11.8890



 52%|█████▏    | 260/500 [00:57<00:53,  4.48it/s]

Epoch 260/500 - Train Loss: 5.0176 - Val Loss: 11.8572



 56%|█████▌    | 280/500 [01:01<00:48,  4.51it/s]

Epoch 280/500 - Train Loss: 5.4493 - Val Loss: 11.3419



 60%|██████    | 300/500 [01:05<00:44,  4.50it/s]

Epoch 300/500 - Train Loss: 5.1849 - Val Loss: 11.4481



 64%|██████▍   | 320/500 [01:10<00:40,  4.48it/s]

Epoch 320/500 - Train Loss: 5.2702 - Val Loss: 11.5000



 68%|██████▊   | 340/500 [01:14<00:35,  4.49it/s]

Epoch 340/500 - Train Loss: 4.9322 - Val Loss: 11.4171



 72%|███████▏  | 360/500 [01:19<00:31,  4.50it/s]

Epoch 360/500 - Train Loss: 5.2220 - Val Loss: 11.4953



 76%|███████▌  | 380/500 [01:23<00:26,  4.51it/s]

Epoch 380/500 - Train Loss: 4.6183 - Val Loss: 11.1647



 80%|████████  | 400/500 [01:28<00:22,  4.49it/s]

Epoch 400/500 - Train Loss: 5.1151 - Val Loss: 11.3828



 84%|████████▍ | 420/500 [01:32<00:17,  4.49it/s]

Epoch 420/500 - Train Loss: 4.9643 - Val Loss: 11.4208



 88%|████████▊ | 440/500 [01:37<00:13,  4.50it/s]

Epoch 440/500 - Train Loss: 4.9925 - Val Loss: 12.1590



 92%|█████████▏| 460/500 [01:41<00:08,  4.51it/s]

Epoch 460/500 - Train Loss: 4.7204 - Val Loss: 11.5570



 96%|█████████▌| 480/500 [01:45<00:04,  4.49it/s]

Epoch 480/500 - Train Loss: 5.1708 - Val Loss: 11.8302



100%|██████████| 500/500 [01:50<00:00,  4.53it/s]


Epoch 500/500 - Train Loss: 4.9459 - Val Loss: 11.7085



  4%|▍         | 20/500 [00:04<01:46,  4.49it/s]

Epoch 20/500 - Train Loss: 7.1313 - Val Loss: 7.4063



  8%|▊         | 40/500 [00:08<01:42,  4.50it/s]

Epoch 40/500 - Train Loss: 6.8374 - Val Loss: 7.4198



 12%|█▏        | 60/500 [00:13<01:37,  4.53it/s]

Epoch 60/500 - Train Loss: 6.3160 - Val Loss: 7.3092



 16%|█▌        | 80/500 [00:17<01:32,  4.53it/s]

Epoch 80/500 - Train Loss: 6.7338 - Val Loss: 7.7136



 20%|██        | 100/500 [00:22<01:28,  4.51it/s]

Epoch 100/500 - Train Loss: 6.2553 - Val Loss: 7.4122



 24%|██▍       | 120/500 [00:26<01:24,  4.52it/s]

Epoch 120/500 - Train Loss: 6.4093 - Val Loss: 7.8241



 28%|██▊       | 140/500 [00:31<01:19,  4.51it/s]

Epoch 140/500 - Train Loss: 5.6184 - Val Loss: 7.3314



 32%|███▏      | 160/500 [00:35<01:15,  4.53it/s]

Epoch 160/500 - Train Loss: 5.9069 - Val Loss: 7.9450



 36%|███▌      | 180/500 [00:39<01:10,  4.51it/s]

Epoch 180/500 - Train Loss: 6.3001 - Val Loss: 7.0514



 40%|████      | 200/500 [00:44<01:06,  4.49it/s]

Epoch 200/500 - Train Loss: 6.0498 - Val Loss: 7.3819



 44%|████▍     | 220/500 [00:48<01:02,  4.50it/s]

Epoch 220/500 - Train Loss: 5.9446 - Val Loss: 7.9123



 48%|████▊     | 240/500 [00:53<00:58,  4.47it/s]

Epoch 240/500 - Train Loss: 6.1590 - Val Loss: 9.1796



 52%|█████▏    | 260/500 [00:57<00:53,  4.49it/s]

Epoch 260/500 - Train Loss: 6.0157 - Val Loss: 7.9657



 56%|█████▌    | 280/500 [01:02<00:48,  4.50it/s]

Epoch 280/500 - Train Loss: 6.1839 - Val Loss: 9.0609



 60%|██████    | 300/500 [01:06<00:44,  4.50it/s]

Epoch 300/500 - Train Loss: 6.3756 - Val Loss: 7.7007



 64%|██████▍   | 320/500 [01:11<00:40,  4.47it/s]

Epoch 320/500 - Train Loss: 5.9242 - Val Loss: 8.7054



 68%|██████▊   | 340/500 [01:15<00:35,  4.51it/s]

Epoch 340/500 - Train Loss: 5.6783 - Val Loss: 8.1376



 72%|███████▏  | 360/500 [01:19<00:31,  4.49it/s]

Epoch 360/500 - Train Loss: 6.0476 - Val Loss: 7.9634



 76%|███████▌  | 380/500 [01:24<00:26,  4.48it/s]

Epoch 380/500 - Train Loss: 6.3894 - Val Loss: 8.8588



 80%|████████  | 400/500 [01:28<00:22,  4.51it/s]

Epoch 400/500 - Train Loss: 5.3840 - Val Loss: 9.1227



 84%|████████▍ | 420/500 [01:33<00:18,  4.35it/s]

Epoch 420/500 - Train Loss: 5.1787 - Val Loss: 7.8648



 88%|████████▊ | 440/500 [01:38<00:14,  4.07it/s]

Epoch 440/500 - Train Loss: 5.5351 - Val Loss: 7.9512



 92%|█████████▏| 460/500 [01:43<00:10,  3.98it/s]

Epoch 460/500 - Train Loss: 5.2715 - Val Loss: 9.1008



 96%|█████████▌| 480/500 [01:48<00:04,  4.08it/s]

Epoch 480/500 - Train Loss: 5.7945 - Val Loss: 7.7730



100%|██████████| 500/500 [01:52<00:00,  4.43it/s]


Epoch 500/500 - Train Loss: 5.5174 - Val Loss: 8.5649



  4%|▍         | 20/500 [00:05<01:59,  4.02it/s]

Epoch 20/500 - Train Loss: 7.4825 - Val Loss: 6.0558



  8%|▊         | 40/500 [00:09<01:54,  4.02it/s]

Epoch 40/500 - Train Loss: 6.9931 - Val Loss: 5.0643



 12%|█▏        | 60/500 [00:14<01:50,  3.98it/s]

Epoch 60/500 - Train Loss: 6.5367 - Val Loss: 4.7058



 16%|█▌        | 80/500 [00:19<01:44,  4.03it/s]

Epoch 80/500 - Train Loss: 6.5340 - Val Loss: 5.7325



 20%|██        | 100/500 [00:24<01:29,  4.46it/s]

Epoch 100/500 - Train Loss: 6.4776 - Val Loss: 8.1316



 24%|██▍       | 120/500 [00:29<01:23,  4.54it/s]

Epoch 120/500 - Train Loss: 6.1413 - Val Loss: 4.8177



 28%|██▊       | 140/500 [00:33<01:19,  4.54it/s]

Epoch 140/500 - Train Loss: 6.0171 - Val Loss: 6.4207



 32%|███▏      | 160/500 [00:37<01:14,  4.54it/s]

Epoch 160/500 - Train Loss: 6.5250 - Val Loss: 7.1291



 36%|███▌      | 180/500 [00:42<01:10,  4.54it/s]

Epoch 180/500 - Train Loss: 6.0360 - Val Loss: 5.6544



 40%|████      | 200/500 [00:46<01:06,  4.54it/s]

Epoch 200/500 - Train Loss: 6.0248 - Val Loss: 7.5446



 44%|████▍     | 220/500 [00:51<01:01,  4.56it/s]

Epoch 220/500 - Train Loss: 6.0598 - Val Loss: 7.5569



 48%|████▊     | 240/500 [00:55<00:57,  4.53it/s]

Epoch 240/500 - Train Loss: 5.8670 - Val Loss: 7.6064



 52%|█████▏    | 260/500 [00:59<00:53,  4.53it/s]

Epoch 260/500 - Train Loss: 6.0760 - Val Loss: 6.4991



 56%|█████▌    | 280/500 [01:04<00:48,  4.57it/s]

Epoch 280/500 - Train Loss: 5.7885 - Val Loss: 7.5949



 60%|██████    | 300/500 [01:08<00:44,  4.53it/s]

Epoch 300/500 - Train Loss: 5.9392 - Val Loss: 7.9036



 64%|██████▍   | 320/500 [01:13<00:39,  4.55it/s]

Epoch 320/500 - Train Loss: 6.2737 - Val Loss: 5.8851



 68%|██████▊   | 340/500 [01:17<00:35,  4.54it/s]

Epoch 340/500 - Train Loss: 6.1995 - Val Loss: 6.2325



 72%|███████▏  | 360/500 [01:21<00:30,  4.52it/s]

Epoch 360/500 - Train Loss: 6.1969 - Val Loss: 6.4805



 76%|███████▌  | 380/500 [01:26<00:26,  4.57it/s]

Epoch 380/500 - Train Loss: 6.2558 - Val Loss: 7.7446



 80%|████████  | 400/500 [01:30<00:22,  4.53it/s]

Epoch 400/500 - Train Loss: 5.7126 - Val Loss: 7.6664



 84%|████████▍ | 420/500 [01:35<00:17,  4.54it/s]

Epoch 420/500 - Train Loss: 6.0938 - Val Loss: 5.6523



 88%|████████▊ | 440/500 [01:39<00:13,  4.57it/s]

Epoch 440/500 - Train Loss: 5.5606 - Val Loss: 6.2038



 92%|█████████▏| 460/500 [01:43<00:08,  4.54it/s]

Epoch 460/500 - Train Loss: 5.9605 - Val Loss: 6.4663



 96%|█████████▌| 480/500 [01:48<00:04,  4.55it/s]

Epoch 480/500 - Train Loss: 5.9085 - Val Loss: 6.8357



100%|██████████| 500/500 [01:52<00:00,  4.43it/s]


Epoch 500/500 - Train Loss: 5.8417 - Val Loss: 6.6281



  4%|▍         | 20/500 [00:04<01:45,  4.54it/s]

Epoch 20/500 - Train Loss: 7.5317 - Val Loss: 5.7427



  8%|▊         | 40/500 [00:08<01:40,  4.56it/s]

Epoch 40/500 - Train Loss: 7.1484 - Val Loss: 5.8731



 12%|█▏        | 60/500 [00:13<01:36,  4.54it/s]

Epoch 60/500 - Train Loss: 7.1749 - Val Loss: 6.4256



 16%|█▌        | 80/500 [00:17<01:32,  4.53it/s]

Epoch 80/500 - Train Loss: 6.4819 - Val Loss: 6.5603



 20%|██        | 100/500 [00:22<01:28,  4.53it/s]

Epoch 100/500 - Train Loss: 6.1276 - Val Loss: 8.2616



 24%|██▍       | 120/500 [00:26<01:23,  4.53it/s]

Epoch 120/500 - Train Loss: 5.9045 - Val Loss: 9.1812



 28%|██▊       | 140/500 [00:30<01:19,  4.55it/s]

Epoch 140/500 - Train Loss: 6.4624 - Val Loss: 8.0681



 32%|███▏      | 160/500 [00:35<01:15,  4.52it/s]

Epoch 160/500 - Train Loss: 6.3506 - Val Loss: 11.3812



 36%|███▌      | 180/500 [00:39<01:10,  4.54it/s]

Epoch 180/500 - Train Loss: 5.7298 - Val Loss: 8.0571



 40%|████      | 200/500 [00:44<01:06,  4.54it/s]

Epoch 200/500 - Train Loss: 5.6387 - Val Loss: 8.8667



 44%|████▍     | 220/500 [00:48<01:01,  4.56it/s]

Epoch 220/500 - Train Loss: 5.7126 - Val Loss: 8.9034



 48%|████▊     | 240/500 [00:52<00:57,  4.55it/s]

Epoch 240/500 - Train Loss: 6.0347 - Val Loss: 9.3005



 52%|█████▏    | 260/500 [00:57<00:52,  4.56it/s]

Epoch 260/500 - Train Loss: 5.4680 - Val Loss: 7.4644



 56%|█████▌    | 280/500 [01:01<00:48,  4.54it/s]

Epoch 280/500 - Train Loss: 5.8504 - Val Loss: 8.2088



 60%|██████    | 300/500 [01:05<00:43,  4.58it/s]

Epoch 300/500 - Train Loss: 5.7526 - Val Loss: 12.6968



 64%|██████▍   | 320/500 [01:10<00:38,  4.70it/s]

Epoch 320/500 - Train Loss: 5.3830 - Val Loss: 8.2768



 68%|██████▊   | 340/500 [01:15<00:39,  4.02it/s]

Epoch 340/500 - Train Loss: 5.4531 - Val Loss: 9.6851



 72%|███████▏  | 360/500 [01:19<00:30,  4.55it/s]

Epoch 360/500 - Train Loss: 6.0820 - Val Loss: 9.6060



 76%|███████▌  | 380/500 [01:23<00:26,  4.59it/s]

Epoch 380/500 - Train Loss: 5.2244 - Val Loss: 12.2424



 80%|████████  | 400/500 [01:28<00:21,  4.64it/s]

Epoch 400/500 - Train Loss: 5.7268 - Val Loss: 8.2060



 84%|████████▍ | 420/500 [01:32<00:16,  4.72it/s]

Epoch 420/500 - Train Loss: 5.1644 - Val Loss: 8.1613



 88%|████████▊ | 440/500 [01:36<00:12,  4.63it/s]

Epoch 440/500 - Train Loss: 5.4980 - Val Loss: 9.5244



 92%|█████████▏| 460/500 [01:41<00:08,  4.66it/s]

Epoch 460/500 - Train Loss: 5.8926 - Val Loss: 11.4152



 96%|█████████▌| 480/500 [01:45<00:04,  4.61it/s]

Epoch 480/500 - Train Loss: 5.1657 - Val Loss: 7.8880



100%|██████████| 500/500 [01:49<00:00,  4.56it/s]


Epoch 500/500 - Train Loss: 4.0025 - Val Loss: 10.4363



  4%|▍         | 20/500 [00:04<01:42,  4.67it/s]

Epoch 20/500 - Train Loss: 7.4355 - Val Loss: 6.1045



  8%|▊         | 40/500 [00:08<01:39,  4.63it/s]

Epoch 40/500 - Train Loss: 6.9936 - Val Loss: 5.3580



 12%|█▏        | 60/500 [00:12<01:34,  4.65it/s]

Epoch 60/500 - Train Loss: 7.3149 - Val Loss: 5.7769



 16%|█▌        | 80/500 [00:17<01:30,  4.63it/s]

Epoch 80/500 - Train Loss: 6.8148 - Val Loss: 6.3956



 20%|██        | 100/500 [00:21<01:26,  4.64it/s]

Epoch 100/500 - Train Loss: 7.0445 - Val Loss: 7.9261



 24%|██▍       | 120/500 [00:25<01:22,  4.62it/s]

Epoch 120/500 - Train Loss: 6.0657 - Val Loss: 7.0850



 28%|██▊       | 140/500 [00:30<01:18,  4.58it/s]

Epoch 140/500 - Train Loss: 6.6905 - Val Loss: 7.1615



 32%|███▏      | 160/500 [00:34<01:12,  4.66it/s]

Epoch 160/500 - Train Loss: 6.3069 - Val Loss: 6.8960



 36%|███▌      | 180/500 [00:38<01:08,  4.65it/s]

Epoch 180/500 - Train Loss: 5.9237 - Val Loss: 5.9710



 40%|████      | 200/500 [00:43<01:03,  4.70it/s]

Epoch 200/500 - Train Loss: 5.6789 - Val Loss: 6.1874



 44%|████▍     | 220/500 [00:47<00:59,  4.70it/s]

Epoch 220/500 - Train Loss: 5.7386 - Val Loss: 6.8231



 48%|████▊     | 240/500 [00:51<00:55,  4.67it/s]

Epoch 240/500 - Train Loss: 5.8666 - Val Loss: 5.7016



 52%|█████▏    | 260/500 [00:55<00:50,  4.76it/s]

Epoch 260/500 - Train Loss: 5.5018 - Val Loss: 5.9300



 56%|█████▌    | 280/500 [01:00<00:46,  4.75it/s]

Epoch 280/500 - Train Loss: 5.3982 - Val Loss: 8.6107



 60%|██████    | 300/500 [01:04<00:42,  4.70it/s]

Epoch 300/500 - Train Loss: 5.5383 - Val Loss: 8.5765



 64%|██████▍   | 320/500 [01:08<00:38,  4.68it/s]

Epoch 320/500 - Train Loss: 5.2135 - Val Loss: 7.6817



 68%|██████▊   | 340/500 [01:12<00:33,  4.71it/s]

Epoch 340/500 - Train Loss: 5.5145 - Val Loss: 9.3881



 72%|███████▏  | 360/500 [01:17<00:29,  4.72it/s]

Epoch 360/500 - Train Loss: 5.0749 - Val Loss: 6.0135



 76%|███████▌  | 380/500 [01:21<00:25,  4.70it/s]

Epoch 380/500 - Train Loss: 4.9859 - Val Loss: 8.8618



 80%|████████  | 400/500 [01:25<00:21,  4.69it/s]

Epoch 400/500 - Train Loss: 5.7856 - Val Loss: 7.7645



 84%|████████▍ | 420/500 [01:29<00:17,  4.65it/s]

Epoch 420/500 - Train Loss: 5.0672 - Val Loss: 6.5851



 88%|████████▊ | 440/500 [01:34<00:12,  4.65it/s]

Epoch 440/500 - Train Loss: 4.7485 - Val Loss: 10.2516



 92%|█████████▏| 460/500 [01:38<00:08,  4.67it/s]

Epoch 460/500 - Train Loss: 4.9950 - Val Loss: 7.9866



 96%|█████████▌| 480/500 [01:42<00:04,  4.63it/s]

Epoch 480/500 - Train Loss: 5.0925 - Val Loss: 9.5367



100%|██████████| 500/500 [01:46<00:00,  4.68it/s]

Epoch 500/500 - Train Loss: 4.6097 - Val Loss: 8.1343






In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
from sklearn.metrics import r2_score, mean_absolute_error

# Evaluate the best-epoch validation predictions stored in best_set:
# pooled and per-fold metrics, a true-vs-predicted scatter plot, and two CSV
# exports. Requires best_set from the training cell to be defined.
r2_scores = []
mae_scores = []
r_scores = []

colors = ['red', 'blue', 'green', 'orange', 'purple']

# Convert every fold to plain Python floats once: (true, pred, drug, base).
# Column 0 of the extra input is the drug code, column 1 the baseline score.
fold_values = []
for fold_true, fold_pred, fold_extra in best_set:
    fold_values.append((
        [t.item() for t in fold_true],
        [p.item() for p in fold_pred],
        [e[0].item() for e in fold_extra],
        [e[1].item() for e in fold_extra],
    ))

# Values pooled over all folds.
total_true = []
total_pred = []
total_drug = []
total_base = []
for true, pred, drug, base in fold_values:
    total_true.extend(true)
    total_pred.extend(pred)
    total_drug.extend(drug)
    total_base.extend(base)

# Metrics over the pooled out-of-fold predictions.
print(r2_score(total_true, total_pred))
print(pearsonr(total_true, total_pred))
print(mean_absolute_error(total_true, total_pred))

print()

for i, (true, pred, drug, base) in enumerate(fold_values):
    # Coefficient of determination (R-squared) for this fold
    r2_scores.append(r2_score(true, pred))

    # Pearson correlation; the result is indexed below as [0]=r, [1]=p-value
    r_scores.append(pearsonr(true, pred))

    # Mean Absolute Error (MAE) for this fold
    mae_scores.append(mean_absolute_error(true, pred))

    # One colour per fold (wrapping if there are more folds than colours);
    # drug code -1 is drawn as 'x', everything else as 'o'.
    fold_color = colors[i % len(colors)]
    for marker, wants_minus_one in (('o', False), ('x', True)):
        xs = [t for t, d in zip(true, drug) if (d == -1) == wants_minus_one]
        ys = [p for p, d in zip(pred, drug) if (d == -1) == wants_minus_one]
        plt.scatter(xs, ys, color=fold_color, marker=marker)

# Identity line spanning the range of ALL folds, not just the last one.
min_val = min(min(total_true), min(total_pred))
max_val = max(max(total_true), max(total_pred))
plt.plot([min_val, max_val], [min_val, max_val], 'k--')

# Empty proxy artists so the legend explains the marker shapes instead of
# labelling whichever two point sets happened to be drawn first.
plt.scatter([], [], color='black', marker='o', label='Psilocybin')
plt.scatter([], [], color='black', marker='x', label='Escitalopram')

plt.xlabel('True BDI')
plt.ylabel('Predicted BDI')
plt.title('MLP trained with featureless VGAE')
plt.legend()
plt.savefig('mlp_featureless_vgae.png', bbox_inches='tight')
plt.show()

# Calculate average R-squared and MAE across all folds
avg_r2 = np.mean(r2_scores)
avg_mae = np.mean(mae_scores)

import csv
head = ['r2', 'pearson_r', 'pval']
with open('no_nf.csv', 'w', newline='') as f:
    writer = csv.writer(f)

    # Write the header
    writer.writerow(head)

    # One row per fold, however many folds were evaluated
    for fold_r2, fold_r in zip(r2_scores, r_scores):
        writer.writerow([fold_r2, fold_r[0], fold_r[1]])

# Specify the filename for the CSV file
filename = 'featureless-vgae-full-val-results.csv'

# Create a list of rows with headers
rows = [['true_post_bdi', 'predicted_post_bdi', 'drug (1 for psilo)', 'base_bdi']]
rows.extend([true, pred, drug, base]
            for true, pred, drug, base in zip(total_true, total_pred, total_drug, total_base))

# Write the rows to the CSV file
with open(filename, 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerows(rows)

print('Average R-squared:', avg_r2)
print('Average MAE:', avg_mae)

NameError: name 'best_set' is not defined