In [1]:
import sys
#%cd variational_inference_for_longitudinal_data/
sys.path.append('diffusion/stable_diffusion/')
sys.path.append('diffusion/stable_diffusion/model/')
sys.path.append('lib/src/')
import torch
from torch import nn
import os
from matplotlib import pyplot as plt

import numpy as np


from diffusion.stable_diffusion.latent_diffusion import MyLatentDiffusion, LitLDM
from diffusion.stable_diffusion.model.unet import UNetModel
from diffusion.stable_diffusion.sampler.ddim import DDIMSampler

from lib.src.pythae.models import VAE
from lib.src.pythae.models import LLDM_IAF, LVAE_IAF_Config
from lib.src.pythae.trainers import BaseTrainerConfig, BaseTrainer
from lib.scripts.utils import Encoder_Chairs,Decoder_Chairs
from lib.scripts.utils import My_MaskedDataset, make_batched_masks


def load_config_unet(config):
    """Build a `UNetModel` from a plain configuration mapping.

    Args:
        config: Mapping that provides the keys 'in_channels', 'out_channels',
            'channels', 'n_res_blocks', 'attention_levels',
            'channel_multipliers' and 'n_heads'. Other keys are ignored.

    Returns:
        The object returned by `UNetModel` when called with those seven
        entries as keyword arguments.

    Raises:
        KeyError: If one of the expected keys is missing from `config`.
    """
    # Keys are read in this order, each one forwarded under the same name.
    expected_keys = (
        'in_channels',
        'out_channels',
        'channels',
        'n_res_blocks',
        'attention_levels',
        'channel_multipliers',
        'n_heads',
    )
    unet_kwargs = {key: config[key] for key in expected_keys}
    return UNetModel(**unet_kwargs)

# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

# Shell escape: print the GPU status reported by the NVIDIA driver.
!nvidia-smi

Tue Jun  4 20:22:30 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.78                 Driver Version: 550.78         CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX A2000 12GB          Off |   00000000:01:00.0 Off |                  Off |
| 40%   63C    P8             13W /   70W |    2118MiB /  12282MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
PATH_DATA = 'lib/my_data/sprites/Sprites_train.pt'
N_EVAL = 1000  # number of trailing sequences held out for evaluation

# Deserialize the file a single time, onto the CPU, and split it afterwards.
# Loading it twice (once per split) doubles the I/O and the peak memory, and
# the two loads previously disagreed on `map_location`.
# Note: `train_data` and `eval_data` are non-overlapping views of one tensor.
sprites = torch.load(PATH_DATA, map_location="cpu")
train_data = sprites[:-N_EVAL]
eval_data = sprites[-N_EVAL:]
#test_data = torch.load(os.path.join('lib/my_data/sprites/Sprites_test.pt'), map_location="cpu")

print(train_data.shape)
# Move the last axis (size 3 in the printed shape) in front of the two
# 64-sized axes: (N, T, 64, 64, 3) -> (N, T, 3, 64, 64).
train_data = train_data.permute(0, 1, 4, 2, 3)
eval_data = eval_data.permute(0, 1, 4, 2, 3)
#test_data = test_data.permute(0, 1, 4, 2, 3)
print(train_data.shape)

# All-True masks: one entry per (sequence, time step) and one per element of
# the data. Created directly as bool to avoid a temporary float tensor of the
# same size as the dataset.
train_seq_mask = torch.ones(train_data.shape[:2], dtype=torch.bool)
eval_seq_mask = torch.ones(eval_data.shape[:2], dtype=torch.bool)
#test_seq_mask = torch.ones(test_data.shape[:2], dtype=torch.bool)
train_pix_mask = torch.ones_like(train_data, dtype=torch.bool)
eval_pix_mask = torch.ones_like(eval_data, dtype=torch.bool)
#test_pix_mask = torch.ones_like(test_data, dtype=torch.bool)

train_dataset = My_MaskedDataset(train_data, train_seq_mask, train_pix_mask)
eval_dataset = My_MaskedDataset(eval_data, eval_seq_mask, eval_pix_mask)
#test_dataset = My_MaskedDataset(test_data, test_seq_mask, test_pix_mask)


# NOTE(review): NUM_WORKERS is defined but not passed to the loaders below,
# and the loaders wrap the raw tensors rather than the masked datasets.
# Both are left as they were; confirm against the cells that consume them.
NUM_WORKERS = 12
train_loader = torch.utils.data.DataLoader(train_data, batch_size=200, shuffle=True)
val_loader = torch.utils.data.DataLoader(eval_data, batch_size=200, shuffle=True)

torch.Size([8000, 8, 64, 64, 3])
torch.Size([8000, 8, 3, 64, 64])


In [3]:
# Load the pre-trained VAE and the pre-trained latent diffusion model (LDM).
#
# Earlier construction path, kept for reference (not executed):
# encoder = Encoder_Chairs(config)
# decoder = Decoder_Chairs(config)
# vae = LVAE_IAF(config, encoder, decoder)
PATH_VAE_FOLDER = 'pre-trained_vae/VAE_training_2024-05-30_17-33-43-latdim12/final_model'
PATH_DIFFUSION_CKPT = 'ldm/lightning_logs/version_30/checkpoints/epoch=44-step=900.ckpt'




device = 'cpu'
vae = VAE.load_from_folder(PATH_VAE_FOLDER).to(device)
vae.eval()
# Both return values are discarded, so the call is made for its side effects
# only. The recorded output shows it running k-medoids, searching for a
# temperature and "Building metric".
# NOTE(review): presumably this attaches the metric G to `vae` - confirm.
_, _ = vae.retrieveG(train_data, verbose = True)


# Previous UNet hyper-parameters, kept for reference (not executed):
# in_channels = 3
# out_channels = 3
# channels = 32
# n_res_blocks = 2
# attention_levels = [2]
# channel_multipliers = (1, 2, 4)
# n_heads = 16


# UNet hyper-parameters in use. They are kept as separate globals and also
# gathered into `unet_config` for `load_config_unet`.
in_channels = 3
out_channels = 3
channels = 64
n_res_blocks = 4
attention_levels = [0]
channel_multipliers = [1]
n_heads = 4

unet_config = {
    'in_channels': in_channels,
    'out_channels': out_channels,
    'channels': channels,
    'n_res_blocks': n_res_blocks,
    'attention_levels': attention_levels,
    'channel_multipliers': channel_multipliers,
    'n_heads': n_heads,
}

unet = load_config_unet(unet_config)

# Diffusion settings, passed positionally to MyLatentDiffusion below.
latent_scaling_factor = 1
n_steps = 1000
linear_start =  0.00085
linear_end = 0.012

# `input_dim` is not referenced again in this cell; the latent size below
# repeats the literals 3 and 64 instead of reading them from it.
input_dim = (3, 64, 64)
f = 32 #subsampling factor
# 3 * (64 // 32) * (64 // 32) = 12, matching the "Latent dim: 12" output.
latent_dim = 3* (64 // f) * (64 // f)
print('Latent dim:', latent_dim)


latent_diffusion = MyLatentDiffusion(unet, latent_scaling_factor, latent_dim, n_steps, linear_start, linear_end)
# Count only parameters that currently require gradients.
print('Number of parameters in the diffusion model: ', sum(p.numel() for p in latent_diffusion.parameters() if p.requires_grad))

# Restore the Lightning wrapper from the checkpoint, passing the freshly built
# LDM and the loaded VAE as constructor arguments, then keep a handle on the
# wrapped LDM for the next cells.
model = LitLDM.load_from_checkpoint(PATH_DIFFUSION_CKPT, ldm = latent_diffusion, vae = vae, latent_dim = latent_dim, lr = 6e-4).to(device)
diffusion = model.ldm




Running Kmedoids
Finding temperature
Best temperature found:  1.8156673908233643
Building metric
Latent dim: 12
Number of parameters in the diffusion model:  2223043


In [8]:
# Configuration of the longitudinal model. The sequence length is read from
# the data and the latent size reuses `latent_dim` from the LDM cell.
config = LVAE_IAF_Config(
    input_dim=(3, 64, 64),
    n_obs_per_ind=train_data.shape[1], #8 for Sprites
    latent_dim=latent_dim,
    beta=1,
    n_hidden_in_made=6,
    n_made_blocks=4,
    warmup=2,
    context_dim=None,
    prior='standard',
    posterior='iaf',
    linear_scheduling_steps=10
)

# NOTE(review): the recorded output prints "Running on  cuda:0" although the
# encoder and decoder are placed on 'cpu' here. Presumably LLDM_IAF selects
# its own device - confirm.
device = 'cpu'
encoder = Encoder_Chairs(config).to(device)
decoder = Decoder_Chairs(config).to(device)
# The sampler uses one step fewer than the number of observations per
# sequence (8 - 1 = 7 for Sprites).
ddim_sampler = DDIMSampler(diffusion, n_steps = train_data.shape[1]-1, ddim_eta = 1)
# Pre-computed samples read from a file in the working directory.
zT_samples = torch.load('zT_samples.pt') #shape (1000, 12)


# Assemble the model from the new encoder/decoder and the pre-trained VAE and
# LDM. The recorded output reports that the pre-trained parts are frozen.
lldm = LLDM_IAF(model_config=config, encoder=encoder, decoder=decoder, 
                pretrained_vae=vae, pretrained_ldm=diffusion, ddim_sampler=ddim_sampler,
                precomputed_zT_samples=zT_samples, verbose = True)



Running on  cuda:0
Freezing pre-trained VAE and pre-trained LDM...
Freezing done.
Number of trainable parameters: 3.6e+06
Number of total parameters: 6.9e+06


In [9]:
# Settings handed to the trainer.
training_config = BaseTrainerConfig(
    num_epochs=100,
    batch_size=512,
    learning_rate=1e-6,
    shuffle=True,
    steps_saving=None,
    steps_predict=100,
)

# Adam over the parameters exposed by the model; the learning rate is read
# back from the training configuration so the two stay in sync.
optimizer = torch.optim.Adam(
    lldm.parameters(),
    lr=training_config.learning_rate,
    eps=1e-5,
)

# Halve the learning rate at each milestone.
# NOTE(review): if the trainer steps the scheduler once per epoch, milestones
# 125 and 150 lie beyond num_epochs=100 and are never reached - confirm.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    gamma=0.5,
    milestones=[50, 100, 125, 150],
    verbose=True,
)

trainer = BaseTrainer(
    model=lldm,
    training_config=training_config,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    optimizer=optimizer,
    scheduler=scheduler,
)

In [10]:
# Launch training with the trainer configured above.
# NOTE(review): in the recorded run the train and eval losses swing between
# roughly 380 and 1900 from epoch to epoch instead of decreasing, and the run
# ends with a KeyboardInterrupt during epoch 10.
trainer.train()

Model passed sanity check !

Created dummy_output_dir/LLDM_IAF_training_2024-06-04_20-28-12. 
Training config, checkpoints and final model will be saved here.

Successfully launched training !



Training of epoch 1/100:   0%|          | 0/16 [00:00<?, ?batch/s]

Eval of epoch 1/100:   0%|          | 0/2 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 385.7482
Eval loss: 386.5037
--------------------------------------------------------------------------


Training of epoch 2/100:   0%|          | 0/16 [00:00<?, ?batch/s]

Eval of epoch 2/100:   0%|          | 0/2 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 680.4539
Eval loss: 1865.1428
--------------------------------------------------------------------------


Training of epoch 3/100:   0%|          | 0/16 [00:00<?, ?batch/s]

Eval of epoch 3/100:   0%|          | 0/2 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 776.7308
Eval loss: 958.6192
--------------------------------------------------------------------------


Training of epoch 4/100:   0%|          | 0/16 [00:00<?, ?batch/s]

Eval of epoch 4/100:   0%|          | 0/2 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 1053.5114
Eval loss: 925.493
--------------------------------------------------------------------------


Training of epoch 5/100:   0%|          | 0/16 [00:00<?, ?batch/s]

Eval of epoch 5/100:   0%|          | 0/2 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 655.3809
Eval loss: 405.1302
--------------------------------------------------------------------------


Training of epoch 6/100:   0%|          | 0/16 [00:00<?, ?batch/s]

Eval of epoch 6/100:   0%|          | 0/2 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 1002.5866
Eval loss: 1554.6194
--------------------------------------------------------------------------


Training of epoch 7/100:   0%|          | 0/16 [00:00<?, ?batch/s]

Eval of epoch 7/100:   0%|          | 0/2 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 1651.2015
Eval loss: 382.2829
--------------------------------------------------------------------------


Training of epoch 8/100:   0%|          | 0/16 [00:00<?, ?batch/s]

Eval of epoch 8/100:   0%|          | 0/2 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 1228.4372
Eval loss: 1534.7603
--------------------------------------------------------------------------


Training of epoch 9/100:   0%|          | 0/16 [00:00<?, ?batch/s]

Eval of epoch 9/100:   0%|          | 0/2 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 1132.664
Eval loss: 659.9601
--------------------------------------------------------------------------


Training of epoch 10/100:   0%|          | 0/16 [00:00<?, ?batch/s]

KeyboardInterrupt: 

## TO DO

#ergonomiser le Riemannian sampling
- inspecter LVIAF
- train !

Possible pain points : le diffusion backward très bof (at least dans les intermédiaires). Après c'est du sampling... wait and see

30/05
Le problème du vanishing G a été résolu en passant à une dim latente de 12.
Il faut améliorer le diffusion model !
Éventuellement passer à du 16, puis 1, 4, 4 pour avoir des étages, ce qui n'est pas le cas now. Checker si on n'a pas de vanishing G.


04/06 : Prêt à train (je crois ?). Il faudrait probablement tester les priors (à voir comment).