In [1]:
import sys
sys.path.append('../')
from models import VGAE, LatentMLP 
from utils import BrainGraphDataset, project_root, get_data_labels
import torch
import torch.nn as nn
import torch.optim as optim
import os
from tqdm import tqdm
import copy

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
root = project_root()

# Learning rate for the downstream MLP optimiser (used in the training loop).
# The loader batch size is configured together with the cross-validation folds.
lr = 0.001

# Columns pulled from the label table and handed to the dataset as `extra_data`.
categories = ['patient_n', 'condition', 'bdi_before']

# Take an explicit copy so the in-place recoding below edits an independent
# frame rather than a column-slice of get_data_labels()'s result (avoids
# pandas' SettingWithCopyWarning / chained-assignment ambiguity).
data_labels = get_data_labels()[categories].copy()

annotations = 'annotations.csv'

# Encode the treatment arm numerically: "P" -> 1, "E" -> -1, then make the
# column float64 (astype raises if any other, non-numeric code is present).
data_labels.loc[data_labels["condition"] == "P", "condition"] = 1
data_labels.loc[data_labels["condition"] == "E", "condition"] = -1
data_labels['condition'] = data_labels['condition'].astype('float64')

dataroot = 'fc_matrices/psilo_aal_before/'

dataset = BrainGraphDataset(img_dir=os.path.join(root, dataroot),
                            annotations_file=os.path.join(root, annotations),
                            transform=None, extra_data=data_labels, setting='lz')

# Train/validation splitting is done with K-fold cross-validation in the next
# section; seed torch's global RNG so model init and loader shuffling repeat.
torch.manual_seed(0)


import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import KFold

# ---- K-fold cross-validation splits ----
# Builds one (train, validation) DataLoader pair per fold over `dataset`.
# Fold membership is fixed by `random_seed`; training loaders shuffle their
# samples (shuffle=True), validation loaders keep dataset order.
num_folds = 5
batch_size = 8   # this is the batch size actually used by every loader
random_seed = 42

kf = KFold(n_splits=num_folds, shuffle=True, random_state=random_seed)

train_loaders, val_loaders = [], []

for train_index, val_index in kf.split(dataset):
    # Index views of the full dataset for this fold.
    train_set = torch.utils.data.Subset(dataset, train_index)
    val_set = torch.utils.data.Subset(dataset, val_index)

    train_loader, val_loader = (
        DataLoader(train_set, batch_size=batch_size, shuffle=True),
        DataLoader(val_set, batch_size=batch_size, shuffle=False),
    )

    train_loaders.append(train_loader)
    val_loaders.append(val_loader)


num_epochs = 500

import json

# Loss curves for every (fold, dropout) configuration, appended after training.
results = []
dropout_list = [0]
# best_set[i] = (labels, predictions, baseline covariates) collected on fold i's
# validation split at the epoch with the lowest validation loss for that fold
# (across all dropout values).
best_set = [None] * num_folds

# Per-epoch training log; opened in append mode, so it accumulates across runs.
log_path = 'dropout_train.txt'

# The pretrained VGAE checkpoint is the same for every fold: read it from disk once.
# (load_state_dict copies values into the module, so reusing this dict is safe.)
vgae_weights = torch.load(os.path.join(root, 'vgae_weights/vgae_nf_aal.pt'), map_location=device)


def _train_epoch(vgae, model, optimizer, loader):
    """Run one optimisation pass of `model` over `loader`.

    `vgae` is used only as an encoder producing the latent `z`; only the
    parameters registered in `optimizer` (the MLP's) are updated.
    Returns the loss (model.loss l1 + l2 terms) accumulated over all batches.
    """
    model.train()
    total_loss = 0.0
    for (graph, lz, baseline_bdi), label in loader:
        graph = graph.to(device)
        lz = lz.to(device)
        baseline_bdi = baseline_bdi.to(device)  # must share a device with z
        label = label.to(device)
        optimizer.zero_grad()

        _, _, z, _, _ = vgae(lz, graph)
        output_bdi = model(z.view(z.shape[0], -1), baseline_bdi)

        l1_loss, l2_loss = model.loss(output_bdi, label.view(output_bdi.shape))
        loss = l1_loss + l2_loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss


def _evaluate(vgae, model, loader):
    """Score `model` on `loader` with gradients disabled.

    Returns (accumulated loss, (labels, predictions, baseline covariates)),
    where each list holds one tensor per sample in loader order.
    """
    model.eval()
    total_loss = 0.0
    labels, outputs, covariates = [], [], []
    with torch.no_grad():
        for (graph, lz, baseline_bdi), label in loader:
            graph = graph.to(device)
            lz = lz.to(device)
            baseline_bdi = baseline_bdi.to(device)
            label = label.to(device)

            _, _, z, _, _ = vgae(lz, graph)
            output_bdi = model(z.view(z.shape[0], -1), baseline_bdi)

            labels.extend(label)
            outputs.extend(output_bdi)
            covariates.extend(baseline_bdi)

            l1_loss, l2_loss = model.loss(output_bdi, label.view(output_bdi.shape))
            total_loss += (l1_loss + l2_loss).item()
    return total_loss, (labels, outputs, covariates)


def _log_epoch(message):
    """Append one line to the training log (re-opened per call so the file stays current)."""
    with open(log_path, 'a') as f:
        f.write(message + '\n')


for i, train_loader in enumerate(train_loaders):
    val_loader = val_loaders[i]
    # Normalise losses by THIS fold's split sizes (module-level train_set /
    # val_set from the K-fold loop only describe the last fold).
    n_train = len(train_loader.dataset)
    n_val = len(val_loader.dataset)
    # Lowest validation loss seen for this fold over all dropout values, so a
    # worse later configuration cannot overwrite best_set[i].
    fold_best_val_loss = float('inf')

    for dropout in dropout_list:
        # Positional architecture args must match the architecture the
        # checkpoint was saved from, or load_state_dict will fail.
        vgae = VGAE(1, 1, 116, 32, 8, device, dropout=0, l2_strength=0.001).to(device)
        vgae.load_state_dict(vgae_weights)
        # The VGAE is a frozen encoder (its params are not in the optimiser):
        # stop autograd from tracking it and accumulating unused .grad buffers.
        vgae.requires_grad_(False)
        # NOTE(review): vgae is left in its default (training) mode, as before.
        # If its forward samples z stochastically in training mode, vgae.eval()
        # would give deterministic features — confirm against models.VGAE.

        best_val_loss = float('inf')  # set to infinity to start
        best_model_state = None
        train_losses = []
        val_losses = []

        # The MLP must live on the same device as the VGAE latents it consumes.
        model = LatentMLP(64, 64, 1, dropout=dropout).to(device)
        optimizer = optim.Adam(model.parameters(), lr=lr)

        for epoch in tqdm(range(num_epochs)):
            train_loss = _train_epoch(vgae, model, optimizer, train_loader)
            val_loss, val_snapshot = _evaluate(vgae, model, val_loader)

            train_losses.append(train_loss / n_train)
            val_losses.append(val_loss / n_val)

            # Keep the checkpoint / predictions at the minimum validation loss.
            # NOTE(review): the same validation fold drives this selection and
            # the metrics reported later, so those metrics are optimistic.
            if val_losses[-1] < best_val_loss:
                best_val_loss = val_losses[-1]
                best_model_state = (copy.deepcopy(vgae.state_dict()), copy.deepcopy(model.state_dict()))
                if best_val_loss < fold_best_val_loss:
                    fold_best_val_loss = best_val_loss
                    best_set[i] = val_snapshot

            message = f'Epoch {epoch+1}/{num_epochs} - Train Loss: {train_losses[-1]:.4f} - Val Loss: {val_losses[-1]:.4f}'
            _log_epoch(message)
            if (epoch + 1) % 20 == 0:
                print(message)

        results.append({"fold": i, "dropout": dropout,
                        "train_loss": train_losses, "val_loss": val_losses})

        # Optional: persist the best weights for this configuration.
        # NOTE(review): these fixed file names are overwritten on every fold and
        # say 'schaefer' although this run uses the AAL parcellation.
        # torch.save(best_model_state[1], os.path.join(root, 'mlp_weights/vgae_model_weight_schaefer_final.pt'))
        # torch.save(best_model_state[0], os.path.join(root, 'mlp_weights/vgae_vgae_weight_schaefer_final.pt'))

# Optional: save all loss curves.
# with open("loss_curves_mlp_vgae.json", "w") as f:
#     json.dump(results, f)

  4%|▍         | 20/500 [00:05<02:11,  3.66it/s]

Epoch 20/500 - Train Loss: 5.8568 - Val Loss: 12.4591



  8%|▊         | 40/500 [00:10<02:03,  3.73it/s]

Epoch 40/500 - Train Loss: 5.6169 - Val Loss: 12.1241



 12%|█▏        | 60/500 [00:16<01:56,  3.77it/s]

Epoch 60/500 - Train Loss: 5.1025 - Val Loss: 11.8188



 16%|█▌        | 80/500 [00:21<01:55,  3.64it/s]

Epoch 80/500 - Train Loss: 4.0288 - Val Loss: 10.4679



 20%|██        | 100/500 [00:27<01:53,  3.51it/s]

Epoch 100/500 - Train Loss: 3.9438 - Val Loss: 9.6353



 24%|██▍       | 120/500 [00:32<01:44,  3.63it/s]

Epoch 120/500 - Train Loss: 3.8010 - Val Loss: 11.1678



 28%|██▊       | 140/500 [00:38<01:36,  3.72it/s]

Epoch 140/500 - Train Loss: 4.0115 - Val Loss: 11.0155



 32%|███▏      | 160/500 [00:43<01:31,  3.71it/s]

Epoch 160/500 - Train Loss: 3.8012 - Val Loss: 11.1917



 36%|███▌      | 180/500 [00:49<01:25,  3.74it/s]

Epoch 180/500 - Train Loss: 3.6098 - Val Loss: 11.1047



 40%|████      | 200/500 [00:54<01:22,  3.65it/s]

Epoch 200/500 - Train Loss: 3.3182 - Val Loss: 10.6711



 44%|████▍     | 220/500 [01:00<01:18,  3.58it/s]

Epoch 220/500 - Train Loss: 3.3623 - Val Loss: 9.6967



 48%|████▊     | 240/500 [01:05<01:09,  3.73it/s]

Epoch 240/500 - Train Loss: 3.3592 - Val Loss: 11.7700



 52%|█████▏    | 260/500 [01:10<01:04,  3.75it/s]

Epoch 260/500 - Train Loss: 3.3318 - Val Loss: 10.9953



 56%|█████▌    | 280/500 [01:16<00:58,  3.77it/s]

Epoch 280/500 - Train Loss: 2.6887 - Val Loss: 12.0653



 60%|██████    | 300/500 [01:21<00:53,  3.76it/s]

Epoch 300/500 - Train Loss: 2.7687 - Val Loss: 12.3190



 64%|██████▍   | 320/500 [01:26<00:47,  3.76it/s]

Epoch 320/500 - Train Loss: 2.6241 - Val Loss: 11.6445



 68%|██████▊   | 340/500 [01:32<00:42,  3.74it/s]

Epoch 340/500 - Train Loss: 2.6956 - Val Loss: 11.9878



 72%|███████▏  | 360/500 [01:37<00:37,  3.75it/s]

Epoch 360/500 - Train Loss: 2.6300 - Val Loss: 11.9267



 76%|███████▌  | 380/500 [01:42<00:31,  3.77it/s]

Epoch 380/500 - Train Loss: 2.5728 - Val Loss: 12.3276



 80%|████████  | 400/500 [01:48<00:26,  3.78it/s]

Epoch 400/500 - Train Loss: 2.2865 - Val Loss: 12.3386



 84%|████████▍ | 420/500 [01:53<00:21,  3.77it/s]

Epoch 420/500 - Train Loss: 1.6605 - Val Loss: 12.3946



 88%|████████▊ | 440/500 [01:58<00:15,  3.77it/s]

Epoch 440/500 - Train Loss: 1.9380 - Val Loss: 12.4556



 92%|█████████▏| 460/500 [02:03<00:10,  3.77it/s]

Epoch 460/500 - Train Loss: 2.8429 - Val Loss: 11.8904



 96%|█████████▌| 480/500 [02:09<00:05,  3.77it/s]

Epoch 480/500 - Train Loss: 1.9439 - Val Loss: 13.0414



100%|██████████| 500/500 [02:22<00:00,  3.51it/s]


Epoch 500/500 - Train Loss: 1.7830 - Val Loss: 12.3804



  4%|▍         | 20/500 [00:05<02:07,  3.76it/s]

Epoch 20/500 - Train Loss: 7.1416 - Val Loss: 7.0467



  8%|▊         | 40/500 [00:10<02:02,  3.75it/s]

Epoch 40/500 - Train Loss: 6.7270 - Val Loss: 7.3524



 12%|█▏        | 60/500 [00:15<01:56,  3.79it/s]

Epoch 60/500 - Train Loss: 5.2612 - Val Loss: 10.6795



 16%|█▌        | 80/500 [00:21<01:52,  3.73it/s]

Epoch 80/500 - Train Loss: 5.4703 - Val Loss: 10.6322



 20%|██        | 100/500 [00:26<01:46,  3.77it/s]

Epoch 100/500 - Train Loss: 5.3141 - Val Loss: 10.0828



 24%|██▍       | 120/500 [00:31<01:40,  3.77it/s]

Epoch 120/500 - Train Loss: 4.8327 - Val Loss: 10.0831



 28%|██▊       | 140/500 [00:37<01:35,  3.75it/s]

Epoch 140/500 - Train Loss: 5.1394 - Val Loss: 9.2629



 32%|███▏      | 160/500 [00:42<01:30,  3.76it/s]

Epoch 160/500 - Train Loss: 5.2179 - Val Loss: 8.7171



 36%|███▌      | 180/500 [00:48<01:35,  3.35it/s]

Epoch 180/500 - Train Loss: 4.8038 - Val Loss: 8.1430



 40%|████      | 200/500 [00:54<01:28,  3.41it/s]

Epoch 200/500 - Train Loss: 4.5311 - Val Loss: 8.6489



 44%|████▍     | 220/500 [00:59<01:15,  3.72it/s]

Epoch 220/500 - Train Loss: 4.6587 - Val Loss: 8.2561



 48%|████▊     | 240/500 [01:05<01:09,  3.76it/s]

Epoch 240/500 - Train Loss: 4.5927 - Val Loss: 8.0216



 52%|█████▏    | 260/500 [01:10<01:03,  3.76it/s]

Epoch 260/500 - Train Loss: 4.4266 - Val Loss: 6.9194



 56%|█████▌    | 280/500 [01:15<00:58,  3.76it/s]

Epoch 280/500 - Train Loss: 4.7926 - Val Loss: 6.5141



 60%|██████    | 300/500 [01:21<00:53,  3.73it/s]

Epoch 300/500 - Train Loss: 4.3528 - Val Loss: 6.9893



 64%|██████▍   | 320/500 [01:26<00:49,  3.67it/s]

Epoch 320/500 - Train Loss: 4.4123 - Val Loss: 7.7589



 68%|██████▊   | 340/500 [01:32<00:44,  3.56it/s]

Epoch 340/500 - Train Loss: 4.3904 - Val Loss: 7.0729



 72%|███████▏  | 360/500 [01:37<00:38,  3.59it/s]

Epoch 360/500 - Train Loss: 4.4115 - Val Loss: 5.7577



 76%|███████▌  | 380/500 [01:43<00:35,  3.41it/s]

Epoch 380/500 - Train Loss: 4.4352 - Val Loss: 9.4032



 80%|████████  | 400/500 [01:49<00:27,  3.58it/s]

Epoch 400/500 - Train Loss: 3.6311 - Val Loss: 7.5316



 84%|████████▍ | 420/500 [01:55<00:22,  3.61it/s]

Epoch 420/500 - Train Loss: 4.0633 - Val Loss: 7.1908



 88%|████████▊ | 440/500 [02:00<00:17,  3.36it/s]

Epoch 440/500 - Train Loss: 4.3510 - Val Loss: 7.2239



 92%|█████████▏| 460/500 [02:06<00:11,  3.44it/s]

Epoch 460/500 - Train Loss: 3.5551 - Val Loss: 8.1156



 96%|█████████▌| 480/500 [02:12<00:06,  3.21it/s]

Epoch 480/500 - Train Loss: 3.8959 - Val Loss: 8.3316



100%|██████████| 500/500 [02:18<00:00,  3.61it/s]


Epoch 500/500 - Train Loss: 3.9755 - Val Loss: 7.3545



  4%|▍         | 20/500 [00:05<02:11,  3.66it/s]

Epoch 20/500 - Train Loss: 7.3503 - Val Loss: 6.3186



  8%|▊         | 40/500 [00:11<02:15,  3.39it/s]

Epoch 40/500 - Train Loss: 6.9957 - Val Loss: 6.0705



 12%|█▏        | 60/500 [00:16<02:01,  3.63it/s]

Epoch 60/500 - Train Loss: 6.0093 - Val Loss: 6.4123



 16%|█▌        | 80/500 [00:22<01:55,  3.63it/s]

Epoch 80/500 - Train Loss: 5.5310 - Val Loss: 6.1484



 20%|██        | 100/500 [00:27<01:49,  3.65it/s]

Epoch 100/500 - Train Loss: 4.7149 - Val Loss: 7.1265



 24%|██▍       | 120/500 [00:33<01:49,  3.47it/s]

Epoch 120/500 - Train Loss: 4.3384 - Val Loss: 6.1028



 28%|██▊       | 140/500 [00:39<01:37,  3.70it/s]

Epoch 140/500 - Train Loss: 4.2285 - Val Loss: 7.2690



 32%|███▏      | 160/500 [00:44<01:32,  3.69it/s]

Epoch 160/500 - Train Loss: 4.0474 - Val Loss: 6.3446



 36%|███▌      | 180/500 [00:50<01:27,  3.66it/s]

Epoch 180/500 - Train Loss: 3.7320 - Val Loss: 9.1505



 40%|████      | 200/500 [00:56<01:25,  3.50it/s]

Epoch 200/500 - Train Loss: 4.2856 - Val Loss: 7.3955



 44%|████▍     | 220/500 [01:01<01:23,  3.37it/s]

Epoch 220/500 - Train Loss: 3.6232 - Val Loss: 6.7976



 48%|████▊     | 240/500 [01:07<01:12,  3.58it/s]

Epoch 240/500 - Train Loss: 4.5199 - Val Loss: 9.1926



 52%|█████▏    | 260/500 [01:12<01:05,  3.67it/s]

Epoch 260/500 - Train Loss: 3.3847 - Val Loss: 6.2151



 56%|█████▌    | 280/500 [01:18<01:00,  3.63it/s]

Epoch 280/500 - Train Loss: 3.2561 - Val Loss: 6.1823



 60%|██████    | 300/500 [01:23<00:53,  3.75it/s]

Epoch 300/500 - Train Loss: 3.7205 - Val Loss: 7.3881



 64%|██████▍   | 320/500 [01:29<00:50,  3.54it/s]

Epoch 320/500 - Train Loss: 3.2346 - Val Loss: 7.5298



 68%|██████▊   | 340/500 [01:35<00:48,  3.31it/s]

Epoch 340/500 - Train Loss: 3.0425 - Val Loss: 7.2150



 72%|███████▏  | 360/500 [01:40<00:36,  3.82it/s]

Epoch 360/500 - Train Loss: 3.1148 - Val Loss: 9.0476



 76%|███████▌  | 380/500 [01:45<00:31,  3.80it/s]

Epoch 380/500 - Train Loss: 2.6064 - Val Loss: 7.7548



 80%|████████  | 400/500 [01:51<00:26,  3.81it/s]

Epoch 400/500 - Train Loss: 2.9629 - Val Loss: 9.0858



 84%|████████▍ | 420/500 [01:56<00:21,  3.81it/s]

Epoch 420/500 - Train Loss: 2.8264 - Val Loss: 8.5331



 88%|████████▊ | 440/500 [02:01<00:15,  3.80it/s]

Epoch 440/500 - Train Loss: 2.5464 - Val Loss: 7.9421



 92%|█████████▏| 460/500 [02:06<00:10,  3.81it/s]

Epoch 460/500 - Train Loss: 2.6830 - Val Loss: 8.1412



 96%|█████████▌| 480/500 [02:12<00:05,  3.80it/s]

Epoch 480/500 - Train Loss: 2.7592 - Val Loss: 8.8047



100%|██████████| 500/500 [02:17<00:00,  3.64it/s]


Epoch 500/500 - Train Loss: 2.5210 - Val Loss: 8.4141



  4%|▍         | 20/500 [00:05<02:05,  3.82it/s]

Epoch 20/500 - Train Loss: 7.3229 - Val Loss: 6.5493



  8%|▊         | 40/500 [00:10<02:00,  3.82it/s]

Epoch 40/500 - Train Loss: 6.9703 - Val Loss: 6.8369



 12%|█▏        | 60/500 [00:15<01:55,  3.82it/s]

Epoch 60/500 - Train Loss: 6.2366 - Val Loss: 7.2263



 16%|█▌        | 80/500 [00:20<01:50,  3.81it/s]

Epoch 80/500 - Train Loss: 5.4144 - Val Loss: 6.8485



 20%|██        | 100/500 [00:26<01:44,  3.82it/s]

Epoch 100/500 - Train Loss: 4.5323 - Val Loss: 8.5054



 24%|██▍       | 120/500 [00:31<01:39,  3.81it/s]

Epoch 120/500 - Train Loss: 4.9110 - Val Loss: 8.8683



 28%|██▊       | 140/500 [00:36<01:34,  3.82it/s]

Epoch 140/500 - Train Loss: 3.8344 - Val Loss: 7.8494



 32%|███▏      | 160/500 [00:41<01:29,  3.82it/s]

Epoch 160/500 - Train Loss: 4.6099 - Val Loss: 6.3796



 36%|███▌      | 180/500 [00:47<01:23,  3.82it/s]

Epoch 180/500 - Train Loss: 4.1886 - Val Loss: 6.7171



 40%|████      | 200/500 [00:52<01:18,  3.81it/s]

Epoch 200/500 - Train Loss: 4.0722 - Val Loss: 8.4733



 44%|████▍     | 220/500 [00:57<01:13,  3.82it/s]

Epoch 220/500 - Train Loss: 4.1595 - Val Loss: 7.3915



 48%|████▊     | 240/500 [01:02<01:08,  3.80it/s]

Epoch 240/500 - Train Loss: 3.6835 - Val Loss: 10.0892



 52%|█████▏    | 260/500 [01:08<01:02,  3.81it/s]

Epoch 260/500 - Train Loss: 3.6207 - Val Loss: 8.5231



 56%|█████▌    | 280/500 [01:13<00:57,  3.81it/s]

Epoch 280/500 - Train Loss: 3.4796 - Val Loss: 8.9677



 60%|██████    | 300/500 [01:18<00:52,  3.78it/s]

Epoch 300/500 - Train Loss: 3.2166 - Val Loss: 9.9994



 64%|██████▍   | 320/500 [01:23<00:47,  3.81it/s]

Epoch 320/500 - Train Loss: 3.2158 - Val Loss: 9.6997



 68%|██████▊   | 340/500 [01:29<00:42,  3.81it/s]

Epoch 340/500 - Train Loss: 3.0287 - Val Loss: 11.0064



 72%|███████▏  | 360/500 [01:34<00:36,  3.81it/s]

Epoch 360/500 - Train Loss: 2.8679 - Val Loss: 10.9721



 76%|███████▌  | 380/500 [01:39<00:31,  3.80it/s]

Epoch 380/500 - Train Loss: 3.0730 - Val Loss: 11.3387



 80%|████████  | 400/500 [01:44<00:26,  3.82it/s]

Epoch 400/500 - Train Loss: 3.4520 - Val Loss: 9.1631



 84%|████████▍ | 420/500 [01:50<00:21,  3.81it/s]

Epoch 420/500 - Train Loss: 2.2160 - Val Loss: 10.5867



 88%|████████▊ | 440/500 [01:55<00:15,  3.82it/s]

Epoch 440/500 - Train Loss: 2.3233 - Val Loss: 9.8233



 92%|█████████▏| 460/500 [02:00<00:10,  3.83it/s]

Epoch 460/500 - Train Loss: 2.6420 - Val Loss: 10.8327



 96%|█████████▌| 480/500 [02:05<00:05,  3.83it/s]

Epoch 480/500 - Train Loss: 2.1833 - Val Loss: 9.2025



100%|██████████| 500/500 [02:11<00:00,  3.81it/s]


Epoch 500/500 - Train Loss: 2.2880 - Val Loss: 11.0169



  4%|▍         | 20/500 [00:05<02:05,  3.82it/s]

Epoch 20/500 - Train Loss: 7.4095 - Val Loss: 6.0335



  8%|▊         | 40/500 [00:10<02:00,  3.83it/s]

Epoch 40/500 - Train Loss: 7.0223 - Val Loss: 5.5486



 12%|█▏        | 60/500 [00:15<01:55,  3.82it/s]

Epoch 60/500 - Train Loss: 6.1681 - Val Loss: 4.6637



 16%|█▌        | 80/500 [00:20<01:49,  3.83it/s]

Epoch 80/500 - Train Loss: 5.7313 - Val Loss: 4.4895



 20%|██        | 100/500 [00:26<01:44,  3.82it/s]

Epoch 100/500 - Train Loss: 5.4349 - Val Loss: 6.0385



 24%|██▍       | 120/500 [00:31<01:39,  3.83it/s]

Epoch 120/500 - Train Loss: 5.2077 - Val Loss: 5.2062



 28%|██▊       | 140/500 [00:36<01:33,  3.83it/s]

Epoch 140/500 - Train Loss: 4.9958 - Val Loss: 6.1073



 32%|███▏      | 160/500 [00:41<01:28,  3.83it/s]

Epoch 160/500 - Train Loss: 4.5646 - Val Loss: 6.5821



 36%|███▌      | 180/500 [00:47<01:23,  3.83it/s]

Epoch 180/500 - Train Loss: 4.7720 - Val Loss: 6.2003



 40%|████      | 200/500 [00:52<01:18,  3.83it/s]

Epoch 200/500 - Train Loss: 4.6902 - Val Loss: 7.3654



 44%|████▍     | 220/500 [00:57<01:13,  3.83it/s]

Epoch 220/500 - Train Loss: 4.1076 - Val Loss: 6.1591



 48%|████▊     | 240/500 [01:02<01:07,  3.83it/s]

Epoch 240/500 - Train Loss: 4.5124 - Val Loss: 7.8360



 52%|█████▏    | 260/500 [01:07<01:02,  3.83it/s]

Epoch 260/500 - Train Loss: 4.2274 - Val Loss: 6.1815



 56%|█████▌    | 280/500 [01:13<00:57,  3.84it/s]

Epoch 280/500 - Train Loss: 4.9036 - Val Loss: 7.4560



 60%|██████    | 300/500 [01:18<00:52,  3.83it/s]

Epoch 300/500 - Train Loss: 4.2009 - Val Loss: 7.9784



 64%|██████▍   | 320/500 [01:23<00:46,  3.83it/s]

Epoch 320/500 - Train Loss: 3.8685 - Val Loss: 8.3995



 68%|██████▊   | 340/500 [01:28<00:41,  3.83it/s]

Epoch 340/500 - Train Loss: 3.8524 - Val Loss: 7.1084



 72%|███████▏  | 360/500 [01:34<00:36,  3.83it/s]

Epoch 360/500 - Train Loss: 3.7255 - Val Loss: 6.8892



 76%|███████▌  | 380/500 [01:39<00:31,  3.83it/s]

Epoch 380/500 - Train Loss: 3.6995 - Val Loss: 6.8171



 80%|████████  | 400/500 [01:44<00:26,  3.82it/s]

Epoch 400/500 - Train Loss: 3.9766 - Val Loss: 6.9530



 84%|████████▍ | 420/500 [01:49<00:20,  3.83it/s]

Epoch 420/500 - Train Loss: 3.5545 - Val Loss: 6.2671



 88%|████████▊ | 440/500 [01:55<00:15,  3.83it/s]

Epoch 440/500 - Train Loss: 3.6329 - Val Loss: 7.3142



 92%|█████████▏| 460/500 [02:00<00:10,  3.83it/s]

Epoch 460/500 - Train Loss: 3.8044 - Val Loss: 7.0250



 96%|█████████▌| 480/500 [02:05<00:05,  3.83it/s]

Epoch 480/500 - Train Loss: 2.5665 - Val Loss: 6.9748



100%|██████████| 500/500 [02:10<00:00,  3.82it/s]

Epoch 500/500 - Train Loss: 2.9594 - Val Loss: 8.0865






In [2]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
from sklearn.metrics import r2_score, mean_absolute_error

import csv

# Per-fold metrics on each fold's validation predictions (from the epoch the
# training cell selected by minimum validation loss).
# NOTE(review): model selection used these same validation folds, so these
# scores are optimistically biased relative to a held-out test set.
r2_scores = []
mae_scores = []
r_scores = []

# Pooled predictions across all folds, in fold order.
total_true = []
total_pred = []
total_drug = []
total_base = []

for true_tensors, pred_tensors, covariates in best_set:
    true = [t.item() for t in true_tensors]
    pred = [p.item() for p in pred_tensors]
    # Each covariate row is unpacked as (drug, baseline BDI), matching the CSV header below.
    drug = [row[0].item() for row in covariates]
    base = [row[1].item() for row in covariates]

    total_true.extend(true)
    total_pred.extend(pred)
    total_drug.extend(drug)
    total_base.extend(base)

    # Coefficient of determination (R^2) — distinct from Pearson's r below.
    r2_scores.append(r2_score(true, pred))
    # Pearson correlation result (statistic, p-value).
    r_scores.append(pearsonr(true, pred))
    # Mean Absolute Error (MAE)
    mae_scores.append(mean_absolute_error(true, pred))

# Pooled metrics over all folds' validation predictions.
print(r2_score(total_true, total_pred))
print(pearsonr(total_true, total_pred))
print(mean_absolute_error(total_true, total_pred))

print()

# Per-fold summary (mean +/- std across folds).
fold_r = [result[0] for result in r_scores]
print(f'Per-fold R2:        {np.mean(r2_scores):.4f} +/- {np.std(r2_scores):.4f}')
print(f'Per-fold Pearson r: {np.mean(fold_r):.4f} +/- {np.std(fold_r):.4f}')
print(f'Per-fold MAE:       {np.mean(mae_scores):.4f} +/- {np.std(mae_scores):.4f}')

# Specify the filename for the CSV file
filename = 'feature-vgae-full-val-results-aal.csv'

# Header row followed by one row per pooled sample.
rows = [['true_post_bdi', 'predicted_post_bdi', 'drug (1 for psilo)', 'base_bdi']]
rows.extend([list(row) for row in zip(total_true, total_pred, total_drug, total_base)])

# Write the rows to the CSV file
with open(filename, 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerows(rows)
        

0.4068533464099622
PearsonRResult(statistic=0.6459862123140767, pvalue=3.832264490075833e-06)
5.496987628794852

