# Model and Training

## Load the model

In [1]:
import os
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from diffusers import UNet2DModel, DDPMScheduler, DDIMScheduler
ON_COLAB = False  # Set to True if running on Google Colab
if ON_COLAB:
    !pip install dotenv
    
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()


# Configuration
target_size_for_training = (128, 128)
data_root = f"./raw_data"  # root directory containing 'train' and 'test' subfolders
if ON_COLAB:
    from google.colab import drive

    drive.mount(os.getenv("GOOGLE_DRIVE_CONTENT_PATH", "/content/drive"))
    data_root = os.getenv("GOOGLE_DRIVE_PATH_RESIZED", data_root)

train_dir = os.path.join(data_root, "train")
test_dir = os.path.join(data_root, "test")
model_save_dir = os.getenv("MODEL_SAVE_DIR", "./checkpoints")
os.makedirs(model_save_dir, exist_ok=True)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 1e-4
IMAGE_SIZE = target_size_for_training[0]

# 2. Define the UNet diffusion model
from model_enrico import get_unet_model

model = get_unet_model(
    sample_size=IMAGE_SIZE,
    in_channels=1,
    out_channels=1,
    layers_per_block=2,
    block_out_channels=(64, 128, 256, 512),
    down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D"),
    dropout=0.1,
).to(DEVICE)

# Print model summary
def print_model_summary(model):
    print("Model Summary:")
    print(f"Model Type: {type(model).__name__}")
    print(f"Number of Parameters: {sum(p.numel() for p in model.parameters())}")
    print(f"Device: {next(model.parameters()).device}")

print_model_summary(model)


# 5. Load the model checkpoint if available
def load_checkpoint(ckpt_path, model, optimizer=None, device=torch.device("cpu")):
    loaded = False
    if not os.path.exists(ckpt_path):
        print(f"No checkpoint found at {ckpt_path}, starting fresh.")
        return loaded, model, optimizer, 0

    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint["model"])
    model.to(device)
    loaded = True
    start_epoch = checkpoint.get("epoch", 0)
    if optimizer is not None and "optimizer" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer"])
        print(f"Loaded optimizer state from checkpoint '{ckpt_path}'")
    else:
        print("Optimizer state not found in checkpoint, starting with a new optimizer.")
    print(f"Loaded checkpoint '{ckpt_path}' (epoch {start_epoch})")
    return loaded, model, optimizer, start_epoch

WEIGHT_DECAY = 1e-5  # weight decay for regularization
optimizer = torch.optim.Adam(
    model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)

# Load the model checkpoint
ckpt = "checkpoints/ddim_unet_epoch25.pth"
isModelLoadedFromCheckpoint, model, optimizer, start_epoch = load_checkpoint(
    ckpt, model, optimizer, device=DEVICE
)
model.eval()
if isModelLoadedFromCheckpoint:
    print(f"Model {ckpt.split('/')[-1]} loaded and moved to {DEVICE}, starting from epoch {start_epoch}.")
else:
    print(f"Model {ckpt.split('/')[-1]} not found. Starting from scratch, loaded on {DEVICE}, starting from epoch {start_epoch}.")

Model Summary:
Model Type: UNet2DModel
Number of Parameters: 61830529
Device: cuda:0
No checkpoint found at checkpoints/ddim_unet_epoch25.pth, starting fresh.
Model ddim_unet_epoch25.pth not found. Starting from scratch, loaded on cuda, starting from epoch 0.


## Training and validation

In [None]:
from utils import NUM_TRAIN_TIMESTEPS, sample_images
from utils import AugmentedDataset
from matplotlib import pyplot as plt
import torch
from torch.amp import autocast, GradScaler

BATCH_SIZE = 16
NUM_EPOCHS = 20

# 4. Optimizer

from torch.optim.lr_scheduler import CosineAnnealingLR

# Cosine Annealing scheduler
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-6)


# Print the number of training and validation images before data augmentation
def count_images(root_dir, extensions=(".png", ".jpg", ".jpeg")):
    count = 0
    for dirpath, dirnames, filenames in os.walk(root_dir):
        for fname in filenames:
            if fname.lower().endswith(extensions):
                count += 1
    return count


print(
    f"Number of training images before data augmentation: " f"{count_images(train_dir)}"
)
print(
    f"Number of validation images before data augmentation: "
    f"{count_images(test_dir)}"
)

import matplotlib.pyplot as p  # 1. Data transforms and datasets - this performs data augmentation

train_dataset = AugmentedDataset(root_dir=train_dir, image_size=IMAGE_SIZE)
test_dataset = AugmentedDataset(root_dir=test_dir, image_size=IMAGE_SIZE)

train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True
)
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True
)

# Print the number of training and validation images
print(f"Number of training images: {len(train_loader.dataset)}")
print(f"Number of validation images: {len(test_loader.dataset)}")

# 3. Schedulers
noise_scheduler = DDPMScheduler(num_train_timesteps=NUM_TRAIN_TIMESTEPS)
ddim_scheduler = DDIMScheduler(
    beta_start=noise_scheduler.config.beta_start,
    beta_end=noise_scheduler.config.beta_end,
    beta_schedule=noise_scheduler.config.beta_schedule,
    clip_sample=False,
)

ddim_scheduler.set_timesteps(NUM_TRAIN_TIMESTEPS)


# Mixed Precision Training
model = torch.compile(model)

# Create a GradScaler for mixed precision training
scaler = GradScaler(device=DEVICE)


# 5. Training + Validation Loop
train_losses = []
val_losses = []
print(
    f"Training on {DEVICE} with batch size {BATCH_SIZE} for {NUM_EPOCHS} epochs. Starting from epoch {start_epoch}."
)

for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    train_loss = 0.0

    for step, (images, _) in enumerate(train_loader, 1):
        images = images.to(DEVICE)
        batch_size = images.size(0)

        # sample random noise and timesteps
        noise = torch.randn_like(images)
        timesteps = torch.randint(0, NUM_TRAIN_TIMESTEPS, (batch_size,), device=DEVICE)

        # ---- forward + backward con AMP ----
        optimizer.zero_grad()
        with autocast(device_type=DEVICE.type, enabled=True):
            noisy_images = noise_scheduler.add_noise(images, noise, timesteps)
            noise_pred = model(noisy_images, timesteps).sample
            loss = F.mse_loss(noise_pred, noise)

        # scalpatura del gradiente e step optimizer
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        scheduler.step()  # aggiornamento LR

        train_loss += loss.item()
        if step % 50 == 0 or step == len(train_loader):
            print(
                f"[Epoch {epoch}/{NUM_EPOCHS} | Step {step}/{len(train_loader)}] "
                f"Train Loss: {loss.item():.6f}"
            )

    # Sampling di esempio
    sample_images(
        output_path=f"{model_save_dir}/epoch_{epoch + start_epoch}.png",
        num_steps=NUM_TRAIN_TIMESTEPS,
        DEVICE=DEVICE,
        IMAGE_SIZE=IMAGE_SIZE,
        model=model,
        ddim_scheduler=ddim_scheduler,
    )

    avg_train_loss = train_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    print(f"==> Epoch {epoch} Done. Avg Train Loss: {avg_train_loss:.6f}")

    # Validation (anche qui AMP ma senza backward)
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, _ in test_loader:
            images = images.to(DEVICE)
            batch_size = images.size(0)
            noise = torch.randn_like(images)
            timesteps = torch.randint(
                0, NUM_TRAIN_TIMESTEPS, (batch_size,), device=DEVICE
            )

            with autocast(device_type=DEVICE.type, enabled=True):
                noisy_images = noise_scheduler.add_noise(images, noise, timesteps)
                noise_pred = model(noisy_images, timesteps).sample
                val_loss += F.mse_loss(noise_pred, noise).item()

    avg_val_loss = val_loss / len(test_loader)
    val_losses.append(avg_val_loss)
    print(f"==> Epoch {epoch} Done. Avg Validation Loss: {avg_val_loss:.6f}")

    # Save checkpoint
    ckpt_path = os.path.join(
        model_save_dir, f"ddim_unet_epoch{epoch + start_epoch}.pth"
    )
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch + start_epoch,
        },
        ckpt_path,
    )
    print(f"Saved checkpoint: {ckpt_path}\n")

# Plotting the training and validation loss
epochs_plt = list(range(1, NUM_EPOCHS + 1))
plt.figure(figsize=(8, 5))
plt.plot(epochs_plt, train_losses, label="Train Loss")
plt.plot(epochs_plt, val_losses, label="Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("MSE Loss")
plt.title("Andamento del Train vs Validation Loss")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

Number of training images before data augmentation: 3306
Number of validation images before data augmentation: 327
Number of training images: 23142
Number of validation images: 2289
Training on cuda with batch size 16 for 20 epochs. Starting from epoch 0.


W0514 10:49:07.342000 5356 Lib\site-packages\torch\_inductor\utils.py:1361] [0/0] Not enough SMs to use max_autotune_gemm mode


TritonMissing: Cannot find a working triton installation. Either the package is not installed or it is too old. More information on installing Triton can be found at: https://github.com/triton-lang/triton

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"


In [6]:

from diffusers import DDPMScheduler, DDIMScheduler
from tqdm import tqdm
from utils import sample_images
# Imposta il numero di step di inferenza per DDIM (uguale a NUM_TRAIN_TIMESTEPS di default)
# 6. Sampling example with DDIM
noise_scheduler = DDPMScheduler(num_train_timesteps=NUM_TRAIN_TIMESTEPS)

ddim_scheduler = DDIMScheduler(
    beta_start=noise_scheduler.config.beta_start,
    beta_end=noise_scheduler.config.beta_end,
    beta_schedule=noise_scheduler.config.beta_schedule,
    clip_sample=False
)

ddim_scheduler.set_timesteps(NUM_TRAIN_TIMESTEPS)
sample_images(
    output_path="result/ddim_sample.png",
    num_steps=NUM_TRAIN_TIMESTEPS,
    DEVICE=DEVICE,
    IMAGE_SIZE=IMAGE_SIZE,
    model=model,
    ddim_scheduler=ddim_scheduler
)

Sampling DDIM:   1%|          | 10/1000 [00:08<14:23,  1.15it/s]


KeyboardInterrupt: 