In [None]:
import os

import pytorch_lightning as pl
import torch
from torchvision.utils import make_grid, save_image

from dataset import CelebADataModule
from vae import DisentangleVAE

In [None]:
# Paths, relative to the notebook's working directory.
RESULTS_DIR = '../results'
DATA_DIR = '../data'

# Data / model hyperparameters (passed to CelebADataModule and DisentangleVAE below).
IMG_SIZE = 64
NB_CHANNELS = 3
Z_DIM = 64          # latent size; also used to draw random latent vectors at inference time
BETA = 1            # passed as `beta` to DisentangleVAE — presumably the KL weight; confirm in vae.py
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
EPOCHS = 100

# Seed Python/NumPy/PyTorch RNGs (and dataloader workers) so that training and
# sampling are reproducible on Restart & Run All.
SEED = 42
pl.seed_everything(SEED, workers=True);

In [None]:
# CelebA datamodule configured from the constants in the config cell.
celeba_dm = CelebADataModule(data_dir=DATA_DIR, img_size=IMG_SIZE, batch_size=BATCH_SIZE)

# Run the datamodule hooks eagerly so its dataloaders are usable outside the
# trainer (trainer.fit() will invoke these hooks again on its own).
celeba_dm.prepare_data()
celeba_dm.setup()

In [None]:
# Model hyperparameters, collected once so they are easy to inspect or reuse.
vae_kwargs = dict(
    img_size=IMG_SIZE,
    nb_channels=NB_CHANNELS,
    z_dim=Z_DIM,
    beta=BETA,
    learning_rate=LEARNING_RATE,
)
VAE = DisentangleVAE(**vae_kwargs)

In [None]:
checkpooint_callback = pl.callbacks.ModelCheckpoint(
    dirpath = os.path.join(RESULTS_DIR, 'checkpoints'),
    filename = 'disentangle_vae-{epoch:02d}-{val_loss:.2f}',
    monitor = 'val_loss',
    mode = 'min',
    save_top_k = 1
)

In [None]:
trainer = pl.Trainer(
    max_epochs = EPOCHS,
    accelerator = 'auto',
    devices = 'auto',
    callbacks = [checkpooint_callback],
)

In [None]:
# Train, then additionally save the final-epoch weights (independent of the
# best checkpoint kept by the ModelCheckpoint callback).
final_checkpoint_path = os.path.join(RESULTS_DIR, 'disentangle_vae.ckpt')
trainer.fit(VAE, datamodule=celeba_dm)
trainer.save_checkpoint(final_checkpoint_path)

In [None]:
# Reload the best checkpoint, then (1) decode random latent vectors into new
# images and (2) save a side-by-side grid of validation images and their reconstructions.
print("\nLoading best model for inference...")
# Run inference on the same device the trainer trained on.
device = trainer.strategy.root_device

# Ask the trainer for its ModelCheckpoint instead of relying on a notebook variable name.
ckpt_callback = trainer.checkpoint_callback
best_model_path = ckpt_callback.best_model_path if ckpt_callback is not None else ''
if best_model_path:
    # load_from_checkpoint is a classmethod: call it on the class, not on the `VAE` instance.
    # Constructor args are passed explicitly so loading also works if the model does not
    # call save_hyperparameters() (otherwise they simply override the saved values).
    loaded_model = DisentangleVAE.load_from_checkpoint(
        best_model_path,
        map_location=device,
        img_size=IMG_SIZE,
        nb_channels=NB_CHANNELS,
        z_dim=Z_DIM,
        beta=BETA,
        learning_rate=LEARNING_RATE,
    )
else:
    # No best checkpoint recorded: fall back to the in-memory trained model.
    loaded_model = VAE

loaded_model.eval()
loaded_model.to(device)

os.makedirs(RESULTS_DIR, exist_ok=True)
nb_images = 16  # number of images generated and reconstructed below

# Generate new images.
samples_path = os.path.join(RESULTS_DIR, 'final_generated_samples.png')
print("Generating new images with the trained model...")
with torch.no_grad():
    # Draw latent vectors from a standard normal distribution and decode them.
    latents = torch.randn(nb_images, loaded_model.z_dim, device=device)
    generated_images = loaded_model._decode(latents).cpu()
grid = make_grid(generated_images, nrow=4, padding=2, normalize=True)
save_image(grid, samples_path)
print(f"Final generated samples saved in {samples_path}")

# Reconstruct example images from the validation set.
recon_path = os.path.join(RESULTS_DIR, 'final_reconstructions_comparison.png')
print("\nReconstructing example images from the dataset...")
with torch.no_grad():
    # celeba_dm.setup() has already run, so the validation dataloader is available.
    # Assumes each batch is an (images, targets) pair — TODO confirm against CelebADataModule.
    sample_batch, _ = next(iter(celeba_dm.val_dataloader()))
    sample_batch = sample_batch[:nb_images].to(device)
    # Assumes forward() returns a 3-tuple whose first element is the reconstruction.
    recon_sample_batch, _, _ = loaded_model(sample_batch)

# One row of originals followed by one row of reconstructions, whatever the batch size.
comparison = torch.cat([sample_batch.cpu(), recon_sample_batch.cpu()])
grid_comparison = make_grid(comparison, nrow=sample_batch.size(0), padding=2, normalize=True)
save_image(grid_comparison, recon_path)
print(f"Reconstruction comparison saved in {recon_path}")