In [1]:
import sys
sys.path.append('../')
from models import VGAE, LatentMLP 
from utils import BrainGraphDataset, project_root, get_data_labels
import torch
import torch.nn as nn
import torch.optim as optim
import os
from tqdm import tqdm
import copy

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
root = project_root()

# instantiate the VGAE model
lr = 0.001
batch_size = 64

nf = 1
ef = 1
num_nodes = 100
hidden_dim = 128
latent_size = 8



criterion = nn.L1Loss(reduction='sum')


categories = ['patient_n','condition','bdi_before']

data_labels = get_data_labels()
data_labels = data_labels[categories]

annotations = 'annotations.csv'

data_labels.loc[data_labels["condition"] == "P", "condition"] = 1
data_labels.loc[data_labels["condition"] == "E", "condition"] = -1
data_labels['condition'] = data_labels['condition'].astype('float64')

dataroot = 'fc_matrices/psilo_schaefer_before/'

dataset = BrainGraphDataset(img_dir=os.path.join(root, dataroot),
                            annotations_file=os.path.join(root, annotations),
                            transform=None, extra_data=data_labels, setting='lz')

# Get the number of samples in the dataset
num_samples = len(dataset)

import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import KFold

# Assuming you have your dataset defined as 'dataset'
num_folds = 5  # Specify the number of folds
batch_size = 8  # Specify your desired batch size
random_seed = 42  # Specify the random seed

# Create indices for k-fold cross-validation with seeded random number generator
kf = KFold(n_splits=num_folds, shuffle=True, random_state=random_seed)

# Create empty lists to store train and validation loaders
train_loaders = []
val_loaders = []

for train_index, val_index in kf.split(dataset):
    # Split dataset into train and validation sets for the current fold
    train_set = torch.utils.data.Subset(dataset, train_index)
    val_set = torch.utils.data.Subset(dataset, val_index)

    # Define the dataloaders for the current fold
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)

    # Append the loaders to the respective lists
    train_loaders.append(train_loader)
    val_loaders.append(val_loader)


num_epochs = 500

import json

# Dictionary to store training and validation curves
loss_curves = {}
best_set = [0] * num_folds
dropout_list = [0]

for i, train_loader in enumerate(train_loaders):
    val_loader = val_loaders[i]
    for dropout in dropout_list:
        
        vgae = VGAE(1, 1, 100, 64, 4, device, dropout=0, l2_strength=0.001, use_nf=False).to(device)
        # load the trained VGAE weights
        vgae.load_state_dict(torch.load(os.path.join(root, 'vgae_weights/vgae_no_nf_ica.pt'), map_location=device))

        # define the optimizer and the loss function


        # Convert the model to the device
        vgae.to(device)

        
        best_val_loss = float('inf')  # set to infinity to start
        best_model_state = None
        train_losses = []
        val_losses = []

        model = LatentMLP(64, 256, 1, dropout=dropout)
        optimizer = optim.Adam(model.parameters(), lr=lr)

        src, dest = vgae.edge_index

        for epoch in tqdm(range(num_epochs)):
            train_loss = 0.0
            val_loss = 0.0

            # training
            model.train()
            for batch_idx, ((graph, _, baseline_bdi), label) in enumerate(train_loader):
                graph = graph.to(device)  # move data to device
                label = label.to(device)
                optimizer.zero_grad()

                rcn_edges, z, _, _ = vgae(None, graph)
                graph = graph[:, src, dest]

                output_bdi = model(z.view(z.shape[0], -1), baseline_bdi)

                l1_loss, l2_loss = model.loss(output_bdi, label.view(output_bdi.shape))
                loss = l1_loss + l2_loss
                loss.backward()
                optimizer.step()
                train_loss += loss.item()

            # validation
            model.eval()
            val_label = []
            val_output = []
            val_base = []
            with torch.no_grad():
                for batch_idx, ((graph, _, baseline_bdi), label) in enumerate(val_loader):
                    graph = graph.to(device)  # move data to device
                    label = label.to(device)

                    rcn_edges, z, _, _ = vgae(None, graph)
                    graph = graph[:, src, dest]

                    output_bdi = model(z.view(z.shape[0], -1), baseline_bdi)
                    val_label.extend(label)
                    val_output.extend(output_bdi)
                    val_base.extend(baseline_bdi)
                    
                    l1_loss, l2_loss = model.loss(output_bdi, label.view(output_bdi.shape))
                    loss = l1_loss + l2_loss
                    val_loss += loss.item()
            # append losses to lists
            train_losses.append(train_loss/len(train_set))
            val_losses.append(val_loss/len(val_set))

            # save the model if the validation loss is at its minimum
            if val_losses[-1] < best_val_loss:
                best_val_loss = val_losses[-1]

                best_model_state = (copy.deepcopy(vgae.state_dict()), copy.deepcopy(model.state_dict()))
                best_set[i] = (val_label, val_output, val_base)
            # print the losses
            with open('dropout_train.txt', 'a') as f:
                if (epoch + 1) % 20 == 0:
                    print(f'Epoch {epoch+1}/{num_epochs} - Train Loss: {train_losses[-1]:.4f} - Val Loss: {val_losses[-1]:.4f}\n')
                f.write(f'Epoch {epoch+1}/{num_epochs} - Train Loss: {train_losses[-1]:.4f} - Val Loss: {val_losses[-1]:.4f}\n')

#         # save the best model for this configuration
#         torch.save(best_model_state[1], os.path.join(root, f'mlp_weights/vgae_model_weight_{i}_no_nf.pt'))
#         torch.save(best_model_state[0], os.path.join(root, f'mlp_weights/vgae_vgae_weight_{i}_no_nf.pt'))

#         # add the loss curves to the dictionary
#         loss_curves[f"dropout_{dropout}"] = {"train_loss": train_losses, "val_loss": val_losses}

#     # save the loss curves to a file
#     with open("loss_curves_mlp_vgae.json", "w") as f:
#         json.dump(loss_curves, f)

  4%|▍         | 20/500 [00:04<01:47,  4.48it/s]

Epoch 20/500 - Train Loss: 5.8772 - Val Loss: 12.0721



  8%|▊         | 40/500 [00:08<01:41,  4.53it/s]

Epoch 40/500 - Train Loss: 5.6721 - Val Loss: 11.5092



 12%|█▏        | 60/500 [00:13<01:37,  4.52it/s]

Epoch 60/500 - Train Loss: 5.6863 - Val Loss: 10.9273



 16%|█▌        | 80/500 [00:17<01:33,  4.48it/s]

Epoch 80/500 - Train Loss: 5.4969 - Val Loss: 10.8755



 20%|██        | 100/500 [00:22<01:29,  4.47it/s]

Epoch 100/500 - Train Loss: 5.6077 - Val Loss: 11.1965



 24%|██▍       | 120/500 [00:26<01:24,  4.48it/s]

Epoch 120/500 - Train Loss: 5.6854 - Val Loss: 11.1380



 28%|██▊       | 140/500 [00:31<01:20,  4.50it/s]

Epoch 140/500 - Train Loss: 5.9038 - Val Loss: 10.9963



 32%|███▏      | 160/500 [00:35<01:15,  4.50it/s]

Epoch 160/500 - Train Loss: 5.7067 - Val Loss: 11.0446



 36%|███▌      | 180/500 [00:39<01:10,  4.54it/s]

Epoch 180/500 - Train Loss: 5.7177 - Val Loss: 11.4061



 40%|████      | 200/500 [00:44<01:06,  4.54it/s]

Epoch 200/500 - Train Loss: 5.6668 - Val Loss: 10.9280



 44%|████▍     | 220/500 [00:48<01:01,  4.58it/s]

Epoch 220/500 - Train Loss: 5.4407 - Val Loss: 10.8432



 48%|████▊     | 240/500 [00:53<00:57,  4.49it/s]

Epoch 240/500 - Train Loss: 5.3985 - Val Loss: 10.8329



 52%|█████▏    | 260/500 [00:57<00:53,  4.51it/s]

Epoch 260/500 - Train Loss: 5.4447 - Val Loss: 11.5045



 56%|█████▌    | 280/500 [01:01<00:48,  4.53it/s]

Epoch 280/500 - Train Loss: 5.5441 - Val Loss: 10.2013



 60%|██████    | 300/500 [01:06<00:44,  4.52it/s]

Epoch 300/500 - Train Loss: 5.3874 - Val Loss: 10.5960



 64%|██████▍   | 320/500 [01:10<00:39,  4.53it/s]

Epoch 320/500 - Train Loss: 5.5772 - Val Loss: 10.4330



 68%|██████▊   | 340/500 [01:15<00:35,  4.54it/s]

Epoch 340/500 - Train Loss: 5.6068 - Val Loss: 10.9534



 72%|███████▏  | 360/500 [01:19<00:30,  4.55it/s]

Epoch 360/500 - Train Loss: 5.6613 - Val Loss: 11.3116



 76%|███████▌  | 380/500 [01:24<00:26,  4.55it/s]

Epoch 380/500 - Train Loss: 5.5536 - Val Loss: 10.2106



 80%|████████  | 400/500 [01:28<00:22,  4.54it/s]

Epoch 400/500 - Train Loss: 5.7176 - Val Loss: 10.8875



 84%|████████▍ | 420/500 [01:32<00:17,  4.54it/s]

Epoch 420/500 - Train Loss: 5.4843 - Val Loss: 10.4675



 88%|████████▊ | 440/500 [01:37<00:13,  4.54it/s]

Epoch 440/500 - Train Loss: 5.3070 - Val Loss: 10.3570



 92%|█████████▏| 460/500 [01:41<00:08,  4.54it/s]

Epoch 460/500 - Train Loss: 5.4470 - Val Loss: 11.0934



 96%|█████████▌| 480/500 [01:46<00:04,  4.53it/s]

Epoch 480/500 - Train Loss: 5.3616 - Val Loss: 10.9126



100%|██████████| 500/500 [01:50<00:00,  4.52it/s]


Epoch 500/500 - Train Loss: 5.4303 - Val Loss: 11.2667



  4%|▍         | 20/500 [00:04<01:46,  4.51it/s]

Epoch 20/500 - Train Loss: 7.0181 - Val Loss: 7.5588



  8%|▊         | 40/500 [00:08<01:41,  4.53it/s]

Epoch 40/500 - Train Loss: 6.9035 - Val Loss: 8.0941



 12%|█▏        | 60/500 [00:13<01:36,  4.56it/s]

Epoch 60/500 - Train Loss: 6.2825 - Val Loss: 6.7916



 16%|█▌        | 80/500 [00:17<01:32,  4.55it/s]

Epoch 80/500 - Train Loss: 6.6274 - Val Loss: 7.7160



 20%|██        | 100/500 [00:22<01:27,  4.55it/s]

Epoch 100/500 - Train Loss: 6.5954 - Val Loss: 6.7447



 24%|██▍       | 120/500 [00:26<01:23,  4.55it/s]

Epoch 120/500 - Train Loss: 6.7352 - Val Loss: 8.5110



 28%|██▊       | 140/500 [00:30<01:17,  4.63it/s]

Epoch 140/500 - Train Loss: 6.4183 - Val Loss: 7.3239



 32%|███▏      | 160/500 [00:35<01:14,  4.56it/s]

Epoch 160/500 - Train Loss: 6.5459 - Val Loss: 7.0715



 36%|███▌      | 180/500 [00:39<01:09,  4.63it/s]

Epoch 180/500 - Train Loss: 6.5583 - Val Loss: 8.5789



 40%|████      | 200/500 [00:43<01:05,  4.61it/s]

Epoch 200/500 - Train Loss: 6.3413 - Val Loss: 6.5864



 44%|████▍     | 220/500 [00:48<01:00,  4.60it/s]

Epoch 220/500 - Train Loss: 6.1935 - Val Loss: 7.1667



 48%|████▊     | 240/500 [00:52<00:56,  4.64it/s]

Epoch 240/500 - Train Loss: 6.1504 - Val Loss: 6.9653



 52%|█████▏    | 260/500 [00:56<00:51,  4.65it/s]

Epoch 260/500 - Train Loss: 5.9240 - Val Loss: 7.0798



 56%|█████▌    | 280/500 [01:01<00:47,  4.62it/s]

Epoch 280/500 - Train Loss: 6.1665 - Val Loss: 7.2444



 60%|██████    | 300/500 [01:05<00:43,  4.64it/s]

Epoch 300/500 - Train Loss: 6.6607 - Val Loss: 6.6967



 64%|██████▍   | 320/500 [01:09<00:39,  4.61it/s]

Epoch 320/500 - Train Loss: 5.9208 - Val Loss: 6.8592



 68%|██████▊   | 340/500 [01:14<00:34,  4.63it/s]

Epoch 340/500 - Train Loss: 6.6908 - Val Loss: 7.6360



 72%|███████▏  | 360/500 [01:18<00:30,  4.64it/s]

Epoch 360/500 - Train Loss: 6.0839 - Val Loss: 7.3392



 76%|███████▌  | 380/500 [01:22<00:25,  4.63it/s]

Epoch 380/500 - Train Loss: 6.1858 - Val Loss: 8.4341



 80%|████████  | 400/500 [01:26<00:21,  4.66it/s]

Epoch 400/500 - Train Loss: 6.5551 - Val Loss: 8.4514



 84%|████████▍ | 420/500 [01:31<00:17,  4.70it/s]

Epoch 420/500 - Train Loss: 6.2325 - Val Loss: 7.9209



 88%|████████▊ | 440/500 [01:35<00:12,  4.66it/s]

Epoch 440/500 - Train Loss: 6.4320 - Val Loss: 6.2085



 92%|█████████▏| 460/500 [01:39<00:08,  4.61it/s]

Epoch 460/500 - Train Loss: 6.4976 - Val Loss: 7.6918



 96%|█████████▌| 480/500 [01:44<00:04,  4.65it/s]

Epoch 480/500 - Train Loss: 6.2836 - Val Loss: 9.0014



100%|██████████| 500/500 [01:48<00:00,  4.60it/s]


Epoch 500/500 - Train Loss: 6.3811 - Val Loss: 7.2515



  4%|▍         | 20/500 [00:04<01:42,  4.70it/s]

Epoch 20/500 - Train Loss: 7.3620 - Val Loss: 6.2717



  8%|▊         | 40/500 [00:08<01:40,  4.57it/s]

Epoch 40/500 - Train Loss: 7.0958 - Val Loss: 5.5676



 12%|█▏        | 60/500 [00:13<01:34,  4.66it/s]

Epoch 60/500 - Train Loss: 7.1093 - Val Loss: 6.2746



 16%|█▌        | 80/500 [00:17<01:29,  4.71it/s]

Epoch 80/500 - Train Loss: 6.6029 - Val Loss: 5.6950



 20%|██        | 100/500 [00:21<01:25,  4.69it/s]

Epoch 100/500 - Train Loss: 6.8001 - Val Loss: 5.8758



 24%|██▍       | 120/500 [00:25<01:20,  4.71it/s]

Epoch 120/500 - Train Loss: 6.7889 - Val Loss: 5.5121



 28%|██▊       | 140/500 [00:30<01:16,  4.72it/s]

Epoch 140/500 - Train Loss: 6.8020 - Val Loss: 5.5888



 32%|███▏      | 160/500 [00:34<01:11,  4.73it/s]

Epoch 160/500 - Train Loss: 6.8763 - Val Loss: 5.3885



 36%|███▌      | 180/500 [00:38<01:08,  4.67it/s]

Epoch 180/500 - Train Loss: 6.7123 - Val Loss: 5.7116



 40%|████      | 200/500 [00:42<01:03,  4.74it/s]

Epoch 200/500 - Train Loss: 6.3102 - Val Loss: 5.8515



 44%|████▍     | 220/500 [00:47<00:59,  4.73it/s]

Epoch 220/500 - Train Loss: 6.8320 - Val Loss: 5.7943



 48%|████▊     | 240/500 [00:51<00:55,  4.69it/s]

Epoch 240/500 - Train Loss: 6.8846 - Val Loss: 4.9843



 52%|█████▏    | 260/500 [00:55<00:51,  4.63it/s]

Epoch 260/500 - Train Loss: 6.6478 - Val Loss: 5.7630



 56%|█████▌    | 280/500 [01:00<00:47,  4.60it/s]

Epoch 280/500 - Train Loss: 6.5025 - Val Loss: 5.4398



 60%|██████    | 300/500 [01:04<00:43,  4.60it/s]

Epoch 300/500 - Train Loss: 6.8039 - Val Loss: 5.8324



 64%|██████▍   | 320/500 [01:08<00:38,  4.63it/s]

Epoch 320/500 - Train Loss: 6.6403 - Val Loss: 5.5882



 68%|██████▊   | 340/500 [01:12<00:34,  4.63it/s]

Epoch 340/500 - Train Loss: 6.6012 - Val Loss: 5.6852



 72%|███████▏  | 360/500 [01:17<00:30,  4.62it/s]

Epoch 360/500 - Train Loss: 6.7956 - Val Loss: 5.3962



 76%|███████▌  | 380/500 [01:21<00:25,  4.64it/s]

Epoch 380/500 - Train Loss: 6.8371 - Val Loss: 6.1333



 80%|████████  | 400/500 [01:25<00:21,  4.62it/s]

Epoch 400/500 - Train Loss: 7.1609 - Val Loss: 4.9900



 84%|████████▍ | 420/500 [01:30<00:17,  4.64it/s]

Epoch 420/500 - Train Loss: 6.5821 - Val Loss: 5.2604



 88%|████████▊ | 440/500 [01:34<00:12,  4.63it/s]

Epoch 440/500 - Train Loss: 7.1194 - Val Loss: 5.7465



 92%|█████████▏| 460/500 [01:38<00:08,  4.66it/s]

Epoch 460/500 - Train Loss: 6.9958 - Val Loss: 5.8803



 96%|█████████▌| 480/500 [01:43<00:04,  4.62it/s]

Epoch 480/500 - Train Loss: 6.6164 - Val Loss: 5.4025



100%|██████████| 500/500 [01:47<00:00,  4.65it/s]


Epoch 500/500 - Train Loss: 6.8209 - Val Loss: 4.8780



  4%|▍         | 20/500 [00:04<01:43,  4.64it/s]

Epoch 20/500 - Train Loss: 7.3788 - Val Loss: 6.2314



  8%|▊         | 40/500 [00:08<01:38,  4.65it/s]

Epoch 40/500 - Train Loss: 7.0091 - Val Loss: 7.7852



 12%|█▏        | 60/500 [00:12<01:34,  4.65it/s]

Epoch 60/500 - Train Loss: 6.8175 - Val Loss: 7.8219



 16%|█▌        | 80/500 [00:17<01:29,  4.68it/s]

Epoch 80/500 - Train Loss: 6.4386 - Val Loss: 8.1071



 20%|██        | 100/500 [00:21<01:25,  4.66it/s]

Epoch 100/500 - Train Loss: 6.8883 - Val Loss: 5.1547



 24%|██▍       | 120/500 [00:25<01:22,  4.59it/s]

Epoch 120/500 - Train Loss: 6.3157 - Val Loss: 6.8684



 28%|██▊       | 140/500 [00:30<01:17,  4.66it/s]

Epoch 140/500 - Train Loss: 6.4890 - Val Loss: 6.7091



 32%|███▏      | 160/500 [00:34<01:12,  4.67it/s]

Epoch 160/500 - Train Loss: 6.3170 - Val Loss: 8.2259



 36%|███▌      | 180/500 [00:38<01:08,  4.67it/s]

Epoch 180/500 - Train Loss: 6.1670 - Val Loss: 8.4867



 40%|████      | 200/500 [00:42<01:04,  4.63it/s]

Epoch 200/500 - Train Loss: 6.3225 - Val Loss: 8.6280



 44%|████▍     | 220/500 [00:47<01:00,  4.65it/s]

Epoch 220/500 - Train Loss: 5.9348 - Val Loss: 5.4364



 48%|████▊     | 240/500 [00:51<00:55,  4.67it/s]

Epoch 240/500 - Train Loss: 6.3671 - Val Loss: 7.2660



 52%|█████▏    | 260/500 [00:55<00:51,  4.67it/s]

Epoch 260/500 - Train Loss: 6.2471 - Val Loss: 9.2610



 56%|█████▌    | 280/500 [01:00<00:47,  4.67it/s]

Epoch 280/500 - Train Loss: 6.8300 - Val Loss: 9.2769



 60%|██████    | 300/500 [01:04<00:43,  4.56it/s]

Epoch 300/500 - Train Loss: 5.2814 - Val Loss: 7.3870



 64%|██████▍   | 320/500 [01:08<00:38,  4.63it/s]

Epoch 320/500 - Train Loss: 6.0598 - Val Loss: 7.8946



 68%|██████▊   | 340/500 [01:13<00:34,  4.65it/s]

Epoch 340/500 - Train Loss: 5.9818 - Val Loss: 7.0641



 72%|███████▏  | 360/500 [01:17<00:30,  4.64it/s]

Epoch 360/500 - Train Loss: 6.2171 - Val Loss: 9.2442



 76%|███████▌  | 380/500 [01:21<00:25,  4.65it/s]

Epoch 380/500 - Train Loss: 6.1144 - Val Loss: 6.6379



 80%|████████  | 400/500 [01:26<00:21,  4.60it/s]

Epoch 400/500 - Train Loss: 6.8169 - Val Loss: 7.5851



 84%|████████▍ | 420/500 [01:30<00:17,  4.59it/s]

Epoch 420/500 - Train Loss: 6.1533 - Val Loss: 7.7859



 88%|████████▊ | 440/500 [01:34<00:12,  4.62it/s]

Epoch 440/500 - Train Loss: 6.4835 - Val Loss: 8.1651



 92%|█████████▏| 460/500 [01:39<00:08,  4.63it/s]

Epoch 460/500 - Train Loss: 6.5291 - Val Loss: 8.3439



 96%|█████████▌| 480/500 [01:43<00:04,  4.65it/s]

Epoch 480/500 - Train Loss: 6.2214 - Val Loss: 9.8206



100%|██████████| 500/500 [01:47<00:00,  4.64it/s]


Epoch 500/500 - Train Loss: 5.5794 - Val Loss: 8.3491



  4%|▍         | 20/500 [00:04<01:43,  4.63it/s]

Epoch 20/500 - Train Loss: 7.5279 - Val Loss: 6.1291



  8%|▊         | 40/500 [00:08<01:40,  4.60it/s]

Epoch 40/500 - Train Loss: 7.0749 - Val Loss: 5.8014



 12%|█▏        | 60/500 [00:12<01:34,  4.63it/s]

Epoch 60/500 - Train Loss: 7.0472 - Val Loss: 5.7983



 16%|█▌        | 80/500 [00:17<01:30,  4.64it/s]

Epoch 80/500 - Train Loss: 7.0895 - Val Loss: 5.2782



 20%|██        | 100/500 [00:21<01:26,  4.64it/s]

Epoch 100/500 - Train Loss: 6.4839 - Val Loss: 4.9570



 24%|██▍       | 120/500 [00:25<01:21,  4.64it/s]

Epoch 120/500 - Train Loss: 7.5182 - Val Loss: 4.8773



 28%|██▊       | 140/500 [00:30<01:18,  4.61it/s]

Epoch 140/500 - Train Loss: 6.7625 - Val Loss: 6.4271



 32%|███▏      | 160/500 [00:34<01:13,  4.60it/s]

Epoch 160/500 - Train Loss: 7.0571 - Val Loss: 5.7445



 36%|███▌      | 180/500 [00:39<01:09,  4.59it/s]

Epoch 180/500 - Train Loss: 6.7558 - Val Loss: 5.2506



 40%|████      | 200/500 [00:43<01:04,  4.62it/s]

Epoch 200/500 - Train Loss: 6.4699 - Val Loss: 6.4222



 44%|████▍     | 220/500 [00:47<01:00,  4.60it/s]

Epoch 220/500 - Train Loss: 6.6380 - Val Loss: 6.2476



 48%|████▊     | 240/500 [00:52<00:56,  4.61it/s]

Epoch 240/500 - Train Loss: 6.6179 - Val Loss: 7.1835



 52%|█████▏    | 260/500 [00:56<00:52,  4.61it/s]

Epoch 260/500 - Train Loss: 6.7768 - Val Loss: 5.7401



 56%|█████▌    | 280/500 [01:00<00:47,  4.59it/s]

Epoch 280/500 - Train Loss: 6.4058 - Val Loss: 6.7713



 60%|██████    | 300/500 [01:05<00:43,  4.61it/s]

Epoch 300/500 - Train Loss: 6.9823 - Val Loss: 6.0045



 64%|██████▍   | 320/500 [01:09<00:39,  4.60it/s]

Epoch 320/500 - Train Loss: 6.5427 - Val Loss: 6.4765



 68%|██████▊   | 340/500 [01:13<00:34,  4.60it/s]

Epoch 340/500 - Train Loss: 6.5501 - Val Loss: 7.5683



 72%|███████▏  | 360/500 [01:18<00:30,  4.61it/s]

Epoch 360/500 - Train Loss: 6.0510 - Val Loss: 7.0642



 76%|███████▌  | 380/500 [01:22<00:26,  4.61it/s]

Epoch 380/500 - Train Loss: 6.6362 - Val Loss: 6.7765



 80%|████████  | 400/500 [01:26<00:21,  4.60it/s]

Epoch 400/500 - Train Loss: 6.9217 - Val Loss: 7.6055



 84%|████████▍ | 420/500 [01:31<00:17,  4.60it/s]

Epoch 420/500 - Train Loss: 6.1364 - Val Loss: 7.6880



 88%|████████▊ | 440/500 [01:35<00:13,  4.60it/s]

Epoch 440/500 - Train Loss: 6.1690 - Val Loss: 7.0134



 92%|█████████▏| 460/500 [01:39<00:08,  4.61it/s]

Epoch 460/500 - Train Loss: 5.8628 - Val Loss: 7.2096



 96%|█████████▌| 480/500 [01:44<00:04,  4.61it/s]

Epoch 480/500 - Train Loss: 5.9276 - Val Loss: 5.4657



100%|██████████| 500/500 [01:48<00:00,  4.61it/s]

Epoch 500/500 - Train Loss: 7.6101 - Val Loss: 6.3678






In [2]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
from sklearn.metrics import r2_score, mean_absolute_error

r2_scores = []
mae_scores = []
r_scores = []

l1 = []
l2 = []

colors = ['red', 'blue', 'green', 'orange', 'purple']

for fold in best_set:
    l1.extend([t.item() for t in fold[0]])
    l2.extend([t.item() for t in fold[1]])


print(r2_score(l1, l2))
print(pearsonr(l1, l2))
print(mean_absolute_error(l1,l2))

print()

total_true = []
total_pred = []
total_drug = []
total_base = []

for i, (true, pred, base) in enumerate(best_set):
    true = [t.item() for t in true]
    pred = [p.item() for p in pred]
    drug = [d[0].item() for d in base]
    base = [d[1].item() for d in base]
    
    total_true.extend(true)
    total_pred.extend(pred)
    total_drug.extend(drug)
    total_base.extend(base)
    # Calculate R-squared (Pearson correlation coefficient)
    r2 = r2_score(true, pred)
    r2_scores.append(r2)

    r_scores.append(pearsonr(true, pred))
    
    # Calculate Mean Absolute Error (MAE)
    mae = mean_absolute_error(true, pred)
    mae_scores.append(mae)
    


import csv
        
# Specify the filename for the CSV file
filename = 'featureless-vgae-full-val-results-ica.csv'

# Create a list of rows with headers
rows = [['true_post_bdi', 'predicted_post_bdi', 'drug (1 for psilo)', 'base_bdi']]
for true, pred, drug, base in zip(total_true, total_pred, total_drug, total_base):
    rows.append([true, pred, drug, base])

# Write the rows to the CSV file
with open(filename, 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerows(rows)

0.41926838596560756
PearsonRResult(statistic=0.6721994598237655, pvalue=1.0886068061251415e-06)
5.391959922654288

