In [1]:
import sys
sys.path.append('../..')
sys.path.append('../../lib/src/')
import torch
from torch import nn
import os
import numpy as np
from matplotlib import pyplot as plt
import math

from lib.src.pythae.models import VAE
from lib.scripts.utils import Encoder_ADNI, Decoder_ADNI, My_Dataset
from lib.src.pythae.models.vae import VAEConfig
from lib.src.pythae.trainers import BaseTrainerConfig
from lib.src.pythae.pipelines.training import TrainingPipeline
from lib.src.pythae.samplers.normal_sampling import NormalSampler
from lib.src.pythae.samplers.manifold_sampler import RHVAESampler
from lib.src.pythae.trainers.training_callbacks import WandbCallback
from lib.src.pythae.models.nn import BaseEncoder, BaseDecoder
from lib.src.pythae.models.base.base_utils import ModelOutput

from geometric_perspective_on_vaes.sampling import build_metrics, hmc_sampling

# Enable IPython's autoreload in mode 2: modules whose source has changed are
# reloaded before each cell executes, so edits to the project code under lib/
# are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2
# Diagnostic only: show the GPU(s) visible to this kernel and their memory usage.
!nvidia-smi

Sun Sep 15 13:43:04 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.107.02             Driver Version: 550.107.02     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX A2000 12GB          Off |   00000000:01:00.0 Off |                  Off |
| 30%   36C    P8             10W /   70W |     114MiB /  12282MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [25]:
# Load the ADNI splits. Data tensors are stored as (N, T, D) and sequence
# masks as (N, T); all files live in a single directory.
DATA_DIR = 'data-models/adni/data'


def load_adni(name):
    """Load and return the tensor saved as ``DATA_DIR/ADNI_<name>.pt``."""
    return torch.load(os.path.join(DATA_DIR, f'ADNI_{name}.pt'))


train_data = load_adni('train')  # (N, T, D)
eval_data = load_adni('eval')
test_data = load_adni('test')
print(train_data.shape)

train_seq_mask = load_adni('train_seq_mask')  # (N, T)
eval_seq_mask = load_adni('eval_seq_mask')
test_seq_mask = load_adni('test_seq_mask')

# Keep only the time steps flagged 1 in the sequence mask (presumably the
# remaining steps are NaN padding). Boolean-indexing the first two dims with an
# (N, T) mask flattens (N, T, D) -> (n_kept, D).
train_data = train_data[train_seq_mask == 1]
eval_data = eval_data[eval_seq_mask == 1]
test_data = test_data[test_seq_mask == 1]
print(train_data.shape)

# After filtering, every entry is kept, so the pixel masks are all True.
# Creating them directly as bool avoids allocating a float tensor and casting it.
train_pix_mask = torch.ones_like(train_data, dtype=torch.bool)
eval_pix_mask = torch.ones_like(eval_data, dtype=torch.bool)
test_pix_mask = torch.ones_like(test_data, dtype=torch.bool)

train_dataset = My_Dataset(train_data)
eval_dataset = My_Dataset(eval_data)
test_dataset = My_Dataset(test_data)

torch.Size([8000, 8, 120])
torch.Size([56800, 120])


In [27]:
# Model and training configuration.
SEED = 8  # seeds torch's RNG so encoder/decoder weight initialisation is reproducible

latent_dim = 9
input_dim = (1, 120)  # presumably (channels, D) with D = 120 features per time step — confirm against Encoder_ADNI

# Seed before constructing the networks: their random initial weights are drawn
# here, before the trainer is created, so this is the point that must be seeded
# for a fresh-kernel re-run to reproduce the same model.
torch.manual_seed(SEED)
encoder = Encoder_ADNI(input_dim, latent_dim)
decoder = Decoder_ADNI(input_dim, latent_dim)

model_config = VAEConfig(
    input_dim=input_dim,
    latent_dim=latent_dim,
    uses_default_encoder=False,
    uses_default_decoder=False,
    reconstruction_loss='mse',
)
vae = VAE(model_config=model_config, encoder=encoder, decoder=decoder)

training_config = BaseTrainerConfig(
    output_dir='pre-trained_vae',
    num_epochs=50,
    learning_rate=5e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    train_dataloader_num_workers=2,
    eval_dataloader_num_workers=2,
    steps_saving=25,  # checkpoint interval (presumably in epochs — see pythae BaseTrainerConfig)
    optimizer_cls="AdamW",
    optimizer_params={"weight_decay": 0.05, "betas": (0.91, 0.995)},
    scheduler_cls="ReduceLROnPlateau",
    # torch ReduceLROnPlateau: LR *= factor after `patience` scheduler steps without improvement
    scheduler_params={"patience": 3, "factor": 0.8},
)


In [None]:
# Weights & Biases logging: build the callback and configure it with the
# training and model configs defined above.
wandb_cb = WandbCallback()
wandb_cb.setup(
    training_config=training_config,
    model_config=model_config,
    # Derive the latent size from the model config instead of hard-coding "9",
    # so changing latent_dim cannot log runs under a misleading project name.
    project_name=f"pre_training_VAE_latdim{model_config.latent_dim}_fulldataset",
)
callbacks = [wandb_cb]  # TrainingPipeline expects a list of callbacks

In [None]:
# Train on the GPU when one is available; fall back to the CPU so the cell does
# not crash on a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
vae = vae.to(device)

pipeline = TrainingPipeline(
    training_config=training_config,
    model=vae,
)
# Pass the callback list built in the W&B cell; otherwise the configured
# WandbCallback is set up but never used during training.
pipeline(
    train_data=train_dataset,
    eval_data=eval_dataset,
    callbacks=callbacks,
)