In [1]:
# change working dir
%cd ..

w:\Projects\Diffusion Model\UNet-Diffusion


In [2]:
# %pip install --upgrade torch diffusers

In [3]:
# !pip install -r requirements.txt

In [None]:
# dbutils.library.restartPython()

In [2]:
# %cd .. 

In [None]:
# # download dataset
# %sh
# gdown --folder https://drive.google.com/drive/u/1/folders/1Ak8sYUszWhoLWvVzv-s9HNxMbdI0sAiV -O /dbfs/mnt/ds-space/Hitesh/Datasets/ImageDataset

In [5]:
import os
import torch
import torch.nn.functional as F
from tqdm import tqdm
import logging

from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPTokenizer, CLIPTextModel

from utils.ema import create_ema_model
from utils.checkpoint import save_training_state, load_training_state
from utils.celeba_parquet_dataset import DatasetLoader
from utils.metrics.gpu import init_nvml, gpu_info
from omegaconf import OmegaConf
from utils.generate_samples import generate_sample

In [None]:
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    torch.cuda.empty_cache()

torch.manual_seed(1)
handle = init_nvml()

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Enable mixed precision training
scaler = torch.amp.GradScaler("cuda") if device == "cuda" else None
print("Mixed precision training enabled" if device == "cuda" else "Mixed precision training disabled")

# Load configuration
config = OmegaConf.load("configs/config_databricks_256.yaml")
# config = OmegaConf.load("configs/train_config_256.yaml")
print(f"Configuration loaded: {OmegaConf.to_yaml(config)}")

Using device: cuda
Mixed precision training enabled
Configuration loaded: data:
  path: data/images/
  parquet_path: data/celebA_high.parquet
  image_size: 256
  normalize: true
  caption_path: data/captions.jsonl
checkpoint:
  path: Checkpoints/
  ckpt_name: UNet_ckpt_256.pth
  ema_ckpt_name: UNet_ema_ckpt_256.pth
training:
  batch_size: 4
  validation_split: 0
  epochs: 10
  lr: 0.0001
  grad_accum_steps: 2
  use_ema: true
  ema_beta: 0.995
  step_start_ema: 2000
  num_workers: 4
  prefetch_factor: 2
  use_lpips: false
  warmup_epochs: 5
losses:
  lpips:
    net: vgg
sampling:
  dir: output/samples/
  num_samples: 25
  steps: 50
  guidance_scale: 7.5
model:
  type: unet
  sample_size: 32
  in_channels: 4
  out_channels: 4
  block_out_channels:
  - 256
  - 512
  - 1024
  - 1024
  down_block_types:
  - CrossAttnDownBlock2D
  - CrossAttnDownBlock2D
  - CrossAttnDownBlock2D
  - DownBlock2D
  up_block_types:
  - UpBlock2D
  - CrossAttnUpBlock2D
  - CrossAttnUpBlock2D
  - CrossAttnUpBlock2

In [7]:
# Load VAE
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(device).eval()
for p in vae.parameters():
    p.requires_grad = False

# Load UNet
model = UNet2DConditionModel(
    sample_size=config.model.sample_size,
    in_channels=config.model.in_channels,
    out_channels=config.model.out_channels,
    down_block_types=config.model.down_block_types,
    up_block_types=config.model.up_block_types,
    block_out_channels=config.model.block_out_channels,
    layers_per_block=config.model.layers_per_block,
    cross_attention_dim=config.model.cross_attention_dim,
).to(device)

# Noise scheduler
scheduler = DDPMScheduler(
    num_train_timesteps=config.scheduler.timesteps,
    beta_start=config.scheduler.beta_start,
    beta_end=config.scheduler.beta_end,
    beta_schedule=config.scheduler.type,
)

# CLIP tokenizer & encoder
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device).eval()

# EMA
unet_ema_model, ema = create_ema_model(model, beta=config.training.ema_beta, step_start_ema=config.training.step_start_ema)
optimizer = torch.optim.AdamW(model.parameters(), lr=config.training.lr)
MSE_LOSS = torch.nn.MSELoss()
use_lpips = config.training.use_lpips
if use_lpips:
    import lpips
    LPIPS_LOSS   = lpips.LPIPS(net=config.losses.lpips).to(device).eval() # net=vgg or alex

print("=== Models, optimizers, losses initialized successfully ===")

=== Models, optimizers, losses initialized successfully ===


In [None]:
logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler("logs/training.log")
        ]
    )

# Data
print(f"Loading Dataset")
dataloader, _ = DatasetLoader(data_config=config.data, train_config=config.training, device=device)
print(f"Total Images: {len(dataloader.dataset)}, batch size: {dataloader.batch_size}")

# batch = next(iter(dataloader))
# print(f"Batch images shape: {batch['image'].shape}, Batch captions: {len(batch['text'])}, Batch images path: {len(batch['image_id'])}")

# Checkpoint
os.makedirs(config.checkpoint.path, exist_ok=True)
ckpt_path = os.path.join(config.checkpoint.path, config.checkpoint.ckpt_name)
start_epoch, best_loss = load_training_state(ckpt_path, model, optimizer, device)

# === Training loop ===
warmup_ep = config.training.warmup_epochs
for epoch in range(start_epoch, config.training.epochs + 1):
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    model.train()

    cumm_loss = 0.0
    cumm_mse = 0.0
    cumm_lpips = 0.0

    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch}/{config.training.epochs}")

    for batch_idx, batch in pbar:
        if batch_idx % config.training.grad_accum_steps == 0:
            optimizer.zero_grad(set_to_none=True)

        images = batch['image'].to(device).float()
        captions = batch['text']
        text_inputs = clip_tokenizer( captions, padding="max_length", truncation=True, max_length=35, return_tensors="pt").to(device)

        with torch.no_grad():
            # ---- VAE Encoding ----
            latents = vae.encode(images).latent_dist.sample() * 0.18215 
            text_embeddings = clip_encoder(**text_inputs).last_hidden_state

        t = torch.randint(0, scheduler.config.num_train_timesteps, (latents.size(0),), device=device)
        noise = torch.randn_like(latents)
        x_t = scheduler.add_noise(latents, noise, t)

        # ---- Noise Prediction ----
        with torch.amp.autocast(device_type='cuda', enabled=(device == 'cuda')):
            noise_pred = model(x_t, timestep=t, encoder_hidden_states=text_embeddings).sample
            mse_loss = MSE_LOSS(noise_pred, noise) / config.training.grad_accum_steps

            if use_lpips:
                # lpips weight
                if epoch <= warmup_ep:
                    lpips_weight = 0.0
                else:
                    # ramp-to-0.05 over [warmup_ep+1 .. 30], then hold
                    frac = min((epoch - warmup_ep) / float(30 - warmup_ep), 1.0)
                    lpips_weight = 0.05 * frac
                # lpips loss
                if lpips_weight > 0:
                    alpha_t = scheduler.alphas_cumprod[t].view(-1, 1, 1, 1).clamp(min=1e-7)
                    pred_x0 = (x_t - (1 - alpha_t).sqrt() * noise_pred) / alpha_t.sqrt()
                    normed_x0 = torch.nan_to_num(pred_x0 / 0.18215) # root cause
                    # normed_x0 = normed_x0.clamp(-6, 6)  # This solves the issue if nan in lpips but need optimum range for normed_x0 if any issue

                    with torch.no_grad():
                        with torch.amp.autocast(device_type='cuda', enabled=False):
                            pred_rgb = vae.decode(normed_x0).sample.clamp(-1, 1)

                    if torch.isnan(pred_rgb).any() or torch.isinf(pred_rgb).any():
                        print(f"[FATAL] pred_rgb exploded at epoch {epoch}, step {batch_idx+1}")
                        print(f"normed_x0 stats: min={normed_x0.min():.2f}, max={normed_x0.max():.2f}, std={normed_x0.std():.2f}")
                        raise ValueError("pred_rgb contains NaNs or Infs")

                    lpips_loss =  LPIPS_LOSS(pred_rgb, images).mean()
                else:
                    lpips_loss = torch.tensor(0.0, device=device)

                total_loss = mse_loss + lpips_weight * lpips_loss
            
            else:
                total_loss = mse_loss

        # ---- Backward Pass ----
        scaler.scale(total_loss).backward() # Overall loss

        # ---- Gradient Accumulation ----
        if (batch_idx + 1) % config.training.grad_accum_steps == 0:
            if device == 'cuda':
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                ema.step_ema(unet_ema_model, model)
                
        # ---- Progress Tracking ----
        if use_lpips:
            cumm_lpips += lpips_loss.item() if isinstance(lpips_loss, float) else lpips_loss.item()
            avg_lpips = cumm_lpips / (batch_idx + 1)
        
        cumm_mse += mse_loss.item()
        avg_mse = cumm_mse / (batch_idx + 1)
        cumm_loss += total_loss.item() # main
        avg_loss = cumm_loss / (batch_idx + 1) # Total average loss
        best_loss = min(best_loss, avg_loss)
        avg_lpips = avg_lpips if use_lpips else 0.0

        pbar.set_postfix(avg_loss = avg_loss, avg_lpips = avg_lpips, GPU=gpu_info(handle))
    
        if (batch_idx+1) % 1 == 0:
            # normed_x0 min/max: {normed_x0.min().item():.3f}/{normed_x0.max().item():.3f} # Add logging for debugging
            logging.info(f"Epoch: {epoch} Batch: {batch_idx+1} | AVG MSE: {avg_mse:.4f} | AVG lpips: {avg_lpips:.4f} | AVG Total: {avg_loss:.4f}")

        # Epoch summary logging
    avg_loss = cumm_loss / (batch_idx + 1)
    print(f"Epoch {epoch} done. Avg loss: {avg_loss:.4f}")

    print("SAVING MODEL STATES...")
    # Save checkpoint & EMA weights
    save_training_state(
        checkpoint_path=ckpt_path, epoch=epoch,
        model=model, optimizer=optimizer,
        avg_loss=avg_loss, best_loss=best_loss
    )
    print("UNET MODEL SAVED!")

    ema_path = os.path.join(config.checkpoint.path, config.checkpoint.ema_ckpt_name)
    torch.save(unet_ema_model.state_dict(), ema_path)
    print("EMA_UNET MODEL SAVED!")

    # Generate visual sample for this epoch
    generate_sample(
        epoch=epoch, vae=vae, ema_model=unet_ema_model,
        scheduler=scheduler, tokenizer=clip_tokenizer, text_encoder=clip_encoder,
        config=config, device=device
    )

print("Training completed.")

Loading Dataset
Total Images: 12288, batch size: 4
⚠️ Checkpoint not found at Checkpoints/UNet_ckpt_256.pth. Starting from scratch.


Epoch 1/10:   0%|          | 2/3072 [00:11<4:52:27,  5.72s/it, GPU=🚨GPU usage:12170 > 11900 Mib), avg_loss=0.543, avg_lpips=0]


KeyboardInterrupt: 

# Sample-Inference

In [None]:
import os
import torch
from torchvision.utils import save_image, make_grid
from tqdm import tqdm

from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPTokenizer, CLIPTextModel
from omegaconf import OmegaConf
from utils.ema import create_ema_model
from utils.metrics.gpu import init_nvml, gpu_info

@torch.no_grad()
def main():
    # Load configuration
    config = OmegaConf.load("configs/train_config_256.yaml")
    model_cfg = config.model
    sample_cfg = config.sampling

    device = "cuda" if torch.cuda.is_available() else "cpu"
    handle = init_nvml()

    # Prepare output
    os.makedirs(sample_cfg.dir, exist_ok=True)
    output_path = os.path.join(sample_cfg.dir, 'test_sample_grid.png')

    # Load VAE
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(device).eval()

    # Load conditional U-Net
    model = UNet2DConditionModel(
        sample_size=model_cfg.sample_size,
        in_channels=model_cfg.in_channels,
        out_channels=model_cfg.out_channels,
        down_block_types=tuple(model_cfg.down_block_types),
        up_block_types=tuple(model_cfg.up_block_types),
        block_out_channels=tuple(model_cfg.block_out_channels),
        layers_per_block=model_cfg.layers_per_block,
        cross_attention_dim=model_cfg.cross_attention_dim,
    ).to(device)

    # EMA wrapper
    ema_model, _ = create_ema_model(
        model,
        beta=config.training.ema_beta,
        step_start_ema=config.training.step_start_ema
    )
    ema_ckpt = os.path.join(config.checkpoint.path, config.checkpoint.ema_ckpt_name)
    ema_model.load_state_dict(torch.load(ema_ckpt, map_location=device))
    ema_model.eval()

    # Scheduler aligned with training
    scheduler = DDPMScheduler(
        num_train_timesteps=config.scheduler.timesteps,
        beta_start=config.scheduler.beta_start,
        beta_end=config.scheduler.beta_end,
        beta_schedule=config.scheduler.type
    )
    scheduler.set_timesteps(sample_cfg.steps)

    # Load CLIP
    tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
    text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14').to(device).eval()

    # Prepare guidance embeddings
    prompt = "a beautiful woman in a red dress"
    text_inputs = tokenizer(
        [prompt] * sample_cfg.num_samples,
        padding="max_length",
        truncation=True,
        max_length=77,
        return_tensors="pt"
    ).to(device)
    text_emb = text_encoder(**text_inputs).last_hidden_state
    # Unconditional (empty) embeddings
    uncond_inputs = tokenizer(
        [""] * sample_cfg.num_samples,
        padding="max_length",
        truncation=True,
        max_length=77,
        return_tensors="pt"
    ).to(device)
    uncond_emb = text_encoder(**uncond_inputs).last_hidden_state

    # Sampling
    num_samples = sample_cfg.num_samples
    shape = (num_samples, model_cfg.in_channels, model_cfg.sample_size, model_cfg.sample_size)
    latents = torch.randn(shape, device=device)
    guidance_scale = sample_cfg.guidance_scale if hasattr(sample_cfg, 'guidance_scale') else 7.5

    for t in tqdm(scheduler.timesteps, desc="Sampling"):  # timesteps descends
        t_batch = torch.full((num_samples,), t, device=device, dtype=torch.long)

        # Predict noise for both conditional and unconditional
        eps_uncond = ema_model(latents, timestep=t_batch, encoder_hidden_states=uncond_emb).sample
        eps_cond   = ema_model(latents, timestep=t_batch, encoder_hidden_states=text_emb).sample
        # Classifier-free guidance
        eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)

        # Step
        latents = scheduler.step(eps, t, latents).prev_sample

    # Decode latents to images
    images = vae.decode(latents / 0.18215).sample
    images = (images.clamp(-1, 1) + 1) / 2

    # Save grid
    grid = make_grid(images, nrow=int(num_samples**0.5))
    save_image(grid, output_path)
    print(f"✅ Samples saved to {output_path}")

if __name__ == "__main__":
    torch.cuda.empty_cache()
    main()
