In [2]:
import sys
sys.path.append('../..')
sys.path.append('../../lib/src/')
import torch
from torch import nn
import torch.nn.functional as F
import os
from matplotlib import pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import numpy as np
import logging




from diffusion.stable_diffusion.latent_diffusion import MyLatentDiffusion, LitLDM
from diffusion.stable_diffusion.model.unet import UNetModel
from diffusion.stable_diffusion.sampler.ddim import DDIMSampler

from lib.src.pythae.models import VAE
from lib.src.pythae.models.vae import VAEConfig
from lib.src.pythae.models import LLDM_IAF, LVAE_IAF_Config, LVAE_IAF
from lib.src.pythae.trainers import BaseTrainerConfig, BaseTrainer
from lib.scripts.utils import Encoder_ADNI, Decoder_ADNI, My_MaskedDataset
from lib.src.pythae.trainers.training_callbacks import WandbCallback

from geometric_perspective_on_vaes.sampling import hmc_sampling


def load_config_unet(config):
    """Instantiate a ``UNetModel`` from a hyper-parameter mapping.

    Only the seven keys below are read, in this order; any extra entries in
    ``config`` are ignored and a missing key raises ``KeyError``.

    Args:
        config: mapping providing ``in_channels``, ``out_channels``,
            ``channels``, ``n_res_blocks``, ``attention_levels``,
            ``channel_multipliers`` and ``n_heads``.

    Returns:
        The ``UNetModel`` built with those values as keyword arguments.
    """
    wanted_keys = (
        'in_channels',
        'out_channels',
        'channels',
        'n_res_blocks',
        'attention_levels',
        'channel_multipliers',
        'n_heads',
    )
    unet_kwargs = {key: config[key] for key in wanted_keys}
    return UNetModel(**unet_kwargs)



# Re-import edited project modules automatically before each cell executes.
%reload_ext autoreload
%autoreload 2

# Report the visible GPU(s) and their current memory usage.
!nvidia-smi

Sun Sep 15 22:04:23 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.107.02             Driver Version: 550.107.02     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX A2000 12GB          Off |   00000000:01:00.0 Off |                  Off |
| 30%   37C    P8             10W /   70W |     114MiB /  12282MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [3]:
# Load the ADNI splits. Each `ADNI_<split>.pt` holds a (N, T, D) tensor and each
# `ADNI_<split>_seq_mask.pt` a (N, T) sequence mask.
def load_adni_split(split):
    """Load one ADNI split and wrap it in a ``My_MaskedDataset``.

    Args:
        split: file-name stem, e.g. ``'train'``, ``'eval'`` or ``'test'``;
            reads ``ADNI_{split}.pt`` and ``ADNI_{split}_seq_mask.pt``.

    Returns:
        ``(data, seq_mask, pix_mask, dataset)`` where ``pix_mask`` is an
        all-True boolean tensor shaped like ``data``.

    Raises:
        AssertionError: if the sequence mask does not match the first two
            axes of the data (boolean indexing ``data[seq_mask == 1]`` in a
            later cell requires this).
    """
    data = torch.load(f'ADNI_{split}.pt')  # (N, T, D)
    seq_mask = torch.load(f'ADNI_{split}_seq_mask.pt')  # (N, T)
    assert seq_mask.shape == data.shape[:2], (split, seq_mask.shape, data.shape)
    # No feature-level masking: every entry is marked as present. Bool tensors
    # cannot require grad, so no requires_grad flag is needed.
    pix_mask = torch.ones_like(data, dtype=torch.bool)
    return data, seq_mask, pix_mask, My_MaskedDataset(data, seq_mask, pix_mask)


train_data, train_seq_mask, train_pix_mask, train_dataset = load_adni_split('train')
eval_data, eval_seq_mask, eval_pix_mask, eval_dataset = load_adni_split('eval')
test_data, test_seq_mask, test_pix_mask, test_dataset = load_adni_split('test')
print(train_data.shape)

# Wrap the masked datasets (not the raw tensors) so batches carry their masks.
# Only training is shuffled; evaluation order stays deterministic.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=200, shuffle=True)
val_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=200, shuffle=False)

torch.Size([8000, 8, 120])


In [None]:
# Locations of the pre-trained models this notebook builds on.
PATH_VAE_FOLDER = 'data-models/adni/trained_models/pre-trained_vae/final_model'
PATH_DIFFUSION_CKPT = 'data-models/adni/trained_models/pre-trained_ldm/checkpoints/epoch=49-step=1250.ckpt'

device = 'cuda'

# Restore the pre-trained VAE, move it to the GPU and switch it to eval mode.
vae = VAE.load_from_folder(PATH_VAE_FOLDER)
vae = vae.to(device)
vae.eval()

# Run retrieveG on every training time point where train_seq_mask == 1. Its three
# return values are discarded; presumably the call is made for its effect on
# `vae` (TODO confirm).
_, _, _ = vae.retrieveG(
    train_data[train_seq_mask == 1],
    verbose=True,
    T_multiplier=5,
    device=device,
    addStdNorm=True,
)


# Hyper-parameters of the UNet used inside the latent diffusion model; they are
# gathered into `unet_config` and passed to `load_config_unet`.
in_channels = 1
out_channels = 1
channels = 32
n_res_blocks = 2
attention_levels = [0]
channel_multipliers = [1]
n_heads = 2

unet_config = dict(
    in_channels=in_channels,
    out_channels=out_channels,
    channels=channels,
    n_res_blocks=n_res_blocks,
    attention_levels=attention_levels,
    channel_multipliers=channel_multipliers,
    n_heads=n_heads,
)
unet = load_config_unet(unet_config)

# Latent diffusion hyper-parameters.
latent_scaling_factor = 1
n_steps = 1000

# TODO: revisit linear_start / linear_end (original note: "don't forget to modify").
linear_start = 0.00085
linear_end = 0.012

input_dim = (1, 120)
f = 40  # subsampling factor
# Derive the latent size from input_dim rather than repeating the literal 120, so
# editing input_dim cannot silently desync latent_dim (1 * 3 * 3 = 9 here).
latent_dim = 1 * (input_dim[-1] // f) * (input_dim[-1] // f)
print('Latent dim:', latent_dim)


latent_diffusion = MyLatentDiffusion(unet, latent_scaling_factor, latent_dim, n_steps, linear_start, linear_end, channels = 1)
print('Number of parameters in the diffusion model: ', sum(p.numel() for p in latent_diffusion.parameters() if p.requires_grad))

# Restore the pre-trained LDM from its Lightning checkpoint; the freshly built
# `latent_diffusion` and the loaded `vae` are supplied as constructor kwargs.
model = LitLDM.load_from_checkpoint(PATH_DIFFUSION_CKPT, ldm = latent_diffusion, vae = vae, latent_dim = latent_dim, lr = 6e-4, channels = 1).to(device)
diffusion = model.ldm

In [5]:
# Configuration shared by the LLDM-IAF model and the LVAE-IAF baseline.
model_config = LVAE_IAF_Config(
    input_dim=input_dim,
    n_obs_per_ind=train_data.shape[1],  # sequence length T (8 in the printed train shape)
    latent_dim=latent_dim,
    beta=0.2,
    n_hidden_in_made=2,
    n_made_blocks=4,
    warmup=5,
    context_dim=None,
    prior='standard',
    posterior='gaussian',
    linear_scheduling_steps=10,
)


encoder = Encoder_ADNI(model_config.input_dim, model_config.latent_dim).to(device)
decoder = Decoder_ADNI(model_config.input_dim, model_config.latent_dim).to(device)
# np.flip yields a negative-stride view; .copy() makes it contiguous (presumably
# so the sampler can turn it into a tensor).
ddim_sampler = DDIMSampler(diffusion, time_steps=np.flip([997, 831, 665, 499, 333, 167, 83, 1]).copy(), ddim_eta = 0.25)
temperature = 0.25


lldm = LLDM_IAF(model_config=model_config, encoder=encoder, decoder=decoder,
                pretrained_vae=vae, pretrained_ldm=diffusion, ddim_sampler=ddim_sampler,
                verbose = True, temp = temperature)

# The LVAE-IAF baseline gets its OWN encoder/decoder. Reusing `encoder`/`decoder`
# would make it share weights with `lldm`. Training one model would then silently
# change the other, and results would depend on which training cell ran first.
lvae = LVAE_IAF(
    model_config,
    Encoder_ADNI(model_config.input_dim, model_config.latent_dim).to(device),
    Decoder_ADNI(model_config.input_dim, model_config.latent_dim).to(device),
).to(device)

Diffusion time steps  [997 831 665 499 333 167  83   1]
Running on  cuda:0
Freezing pre-trained VAE and pre-trained LDM...
Freezing done.
Number of trainable parameters: 2.0e+04
Number of total parameters: 4.0e+05


In [None]:
# Train the LVAE-IAF baseline for 200 epochs with Adam and plateau LR decay.
LVAE_LR = 1e-4  # shared by the optimizer and the trainer config so they cannot drift apart

lvae = lvae.train()
lvae = lvae.to(device)

optimizer = torch.optim.Adam(lvae.parameters(), lr=LVAE_LR)

training_config = BaseTrainerConfig(
    num_epochs=200,
    learning_rate=LVAE_LR,
    batch_size=200,
    steps_saving=50,
    steps_predict=100,
    shuffle=False,
    # Keep the baseline's outputs apart from the LLDM runs, which write to 'lldm'.
    output_dir='lvae',
)

# Halve the LR when the value given to scheduler.step() plateaus (patience=4).
# NOTE(review): `verbose` is deprecated in recent PyTorch releases; drop it if it warns.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=4, verbose=True)
trainer = BaseTrainer(
    model=lvae,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    training_config=training_config,
    optimizer=optimizer,
    scheduler=scheduler,
)
trainer.train()

In [None]:
# Train the LLDM in successive stages with a decreasing base learning rate.
# Every stage creates a fresh Adam optimizer and plateau scheduler at that LR and
# trains the same `lldm` object for 50 epochs, so each stage picks up the weights
# the previous one left in `lldm`.
lrs = [4*1e-4, 2*1e-4, 1e-4, 5*1e-5]

for LR in lrs:
    training_config = BaseTrainerConfig(
        num_epochs=50,
        learning_rate=LR,
        batch_size=200,
        steps_saving=50,
        steps_predict=100,
        shuffle=False,
        output_dir='lldm',
    )
    optimizer = torch.optim.Adam(lldm.parameters(), lr=LR)
    # Halve the LR when the value given to scheduler.step() plateaus (patience=4).
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.5, patience=4, verbose=True
    )
    trainer = BaseTrainer(
        model=lldm,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        training_config=training_config,
        optimizer=optimizer,
        scheduler=scheduler,
    )
    trainer.train()