In [None]:
# change working dir
%cd ..

In [None]:
# %pip install --upgrade torch diffusers

In [None]:
# !pip install -r requirements.txt

In [None]:
# dbutils.library.restartPython()

In [1]:
%cd .. 

d:\Projects\Github\UNet Diffusion


In [2]:
import os
import torch
import torch.nn.functional as F
from tqdm import tqdm
import logging

from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPTokenizer, CLIPTextModel
from utils.ema import create_ema_model
from utils.checkpoint import save_training_state, load_training_state
from utils.celeba_with_caption import CelebAloader
# from utils.celeba_dataset_databricks import CelebAloader
from utils.metrics.gpu import init_nvml, gpu_info
from omegaconf import OmegaConf
import lpips
from utils.loss.lpips import safe_lpips
from utils.generate_samples import generate_sample

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    torch.cuda.empty_cache()
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("training.log")
    ]
)

In [4]:
def safe_lpips(pred_rgb, target_rgb, lpips_model, device):
    """
    Computes LPIPS loss safely, avoiding NaNs or Infs.

    Args:
        pred_rgb (torch.Tensor): Predicted images in [-1, 1], shape [B, 3, H, W]
        target_rgb (torch.Tensor): Ground-truth images in [-1, 1], shape [B, 3, H, W]
        lpips_model (lpips.LPIPS): An instance of the LPIPS model
        device (str or torch.device): Device for returning zero loss if invalid

    Returns:
        torch.Tensor: Scalar LPIPS loss or 0.0 if NaN/Inf
    """
    with torch.no_grad():
        val = lpips_model(pred_rgb, target_rgb).mean()

    if torch.isnan(val) or torch.isinf(val):
        print("⚠️ LPIPS loss returned NaN or Inf. Skipping this batch.")
        return torch.tensor(0.0, device=device)
    return val



In [5]:
torch.manual_seed(1)
handle = init_nvml()

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Enable mixed precision training
scaler = torch.amp.GradScaler("cuda") if device == "cuda" else None
print("Mixed precision training enabled" if device == "cuda" else "Mixed precision training disabled")

# Load configuration
# config = OmegaConf.load("configs/train_config_256.yaml")
config = OmegaConf.load("configs/temp.yaml")
print(f"Configuration loaded: {OmegaConf.to_yaml(config)}")

Using device: cuda
Mixed precision training enabled
Configuration loaded: data:
  path: data/CelebA-HQ/images_512
  parquet_path: /dbfs/mnt/ds-space/Hitesh/Datasets/CelebA-HQ/parquet_files/CelebA-HQ.parquet
  image_size: 256
  normalize: true
  caption_path: data/CelebA-HQ/captions.csv
checkpoint:
  path: checkpoints/test/
  ckpt_name: UNet_ckpt_test.pth
  ema_ckpt_name: UNet_ema_ckpt_test.pth
training:
  batch_size: 4
  validation_split: 0.95
  epochs: 6
  warmup_epochs: 1
  lr: 0.0001
  grad_accum_steps: 2
  use_ema: true
  ema_beta: 0.995
  step_start_ema: 0
  num_workers: 4
losses:
  lpips:
    net: alex
sampling:
  dir: output/test_samples
  num_samples: 25
  steps: 50
model:
  type: unet
  sample_size: 32
  in_channels: 4
  out_channels: 4
  block_out_channels:
  - 256
  - 512
  - 1024
  - 1024
  down_block_types:
  - CrossAttnDownBlock2D
  - CrossAttnDownBlock2D
  - DownBlock2D
  - DownBlock2D
  up_block_types:
  - UpBlock2D
  - UpBlock2D
  - CrossAttnUpBlock2D
  - CrossAttnUpBl

In [6]:
# Load VAE
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(device).eval()

# Load UNet
model = UNet2DConditionModel(
    sample_size=config.model.sample_size,
    in_channels=config.model.in_channels,
    out_channels=config.model.out_channels,
    down_block_types=config.model.down_block_types,
    up_block_types=config.model.up_block_types,
    block_out_channels=config.model.block_out_channels,
    layers_per_block=config.model.layers_per_block,
    cross_attention_dim=config.model.cross_attention_dim,
).to(device)

# Noise scheduler
scheduler = DDPMScheduler(
    num_train_timesteps=config.scheduler.timesteps,
    beta_start=config.scheduler.beta_start,
    beta_end=config.scheduler.beta_end,
    beta_schedule=config.scheduler.type,
)

# CLIP tokenizer & encoder
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device).eval()

# EMA
unet_ema_model, ema = create_ema_model(
    model,
    beta=config.training.ema_beta,
    step_start_ema=config.training.step_start_ema
)

optimizer = torch.optim.AdamW(model.parameters(), lr=config.training.lr)
MSE_LOSS = torch.nn.MSELoss()
LPIPS_LOSS   = lpips.LPIPS(net=config.losses.lpips.net).to(device).eval()

print("Models, optimizers, losses initialized successfully.")

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: c:\Users\Incognito-R\miniconda3\envs\ml_env\Lib\site-packages\lpips\weights\v0.1\alex.pth
Models, optimizers, losses initialized successfully.


In [7]:
# === Load data ===
dataloader, _ = CelebAloader(data_config=config.data, train_config=config.training)
print(f"Dataset size: {len(dataloader.dataset)} images, batch size: {dataloader.batch_size}")

# batch = next(iter(dataloader))
# print(f"Batch image shape: {batch['image'].shape}, Batch captions: {len(batch['caption'])}, Batch images path: {len(batch['img_path'])}")
# Dataset size: 30000 images
# Batch image shape: torch.Size([12, 3, 256, 256]), Batch captions: 12, Batch images path: 12

# === Load checkpoint ===
os.makedirs(config.checkpoint.path, exist_ok=True)
ckpt_path = os.path.join(config.checkpoint.path, config.checkpoint.ckpt_name)
start_epoch, best_loss = load_training_state(ckpt_path, model, optimizer, device)
print(f"Resuming from epoch {start_epoch}, best_loss {best_loss:.4f}")

Dataset size: 1500 images, batch size: 4
✅ Resuming from checkpoint: checkpoints/test/UNet_ckpt_test.pth
Loaded base model and optimizer
Resuming at epoch 2, previous loss: 0.1587
Resuming from epoch 2, best_loss 0.0949


In [8]:
warmup_ep = config.training.warmup_epochs
# Baseline visual before training
generate_sample(0, vae, unet_ema_model, scheduler, clip_tokenizer, clip_encoder, config, device)

  with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
Sampling epoch 0: 100%|██████████| 50/50 [00:02<00:00, 19.61it/s]


✅ Saved sample grid for epoch 0


In [None]:
warmup_ep = config.training.warmup_epochs
# Baseline visual before training

for epoch in range(start_epoch, config.training.epochs):
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    model.train()

    cumm_loss = 0.0
    cumm_mse = 0.0
    cumm_lpips = 0.0


    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch+1}/{config.training.epochs}")
    for batch_idx, batch in pbar:
        if batch_idx % config.training.grad_accum_steps == 0:
            optimizer.zero_grad(set_to_none=True)

        images = batch['image'].to(device).float()
        captions = batch['caption']
        text_inputs = clip_tokenizer(
            captions,
            padding="max_length",
            truncation=True,
            max_length=77,
            return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            text_embeddings = clip_encoder(**text_inputs).last_hidden_state
            latents = vae.encode(images).latent_dist.sample() * 0.18215

        t = torch.randint(0, scheduler.config.num_train_timesteps, (latents.size(0),), device=device)
        noise = torch.randn_like(latents)
        x_t = scheduler.add_noise(latents, noise, t)

        with torch.amp.autocast(device_type='cuda', enabled=(device == 'cuda')):
            noise_pred = model(x_t, timestep=t, encoder_hidden_states=text_embeddings).sample
            mse_loss = MSE_LOSS(noise_pred, noise) / config.training.grad_accum_steps

            if epoch+1 <= warmup_ep:
                lpips_weight = 0.0
            else:
                # ramp-to-0.05 over [warmup_ep+1 .. 30], then hold
                frac = min((epoch+1 - warmup_ep) / float(30 - warmup_ep), 1.0)
                lpips_weight = 0.05 * frac

            if lpips_weight > 0:
                alpha_t = scheduler.alphas_cumprod[t].view(-1, 1, 1, 1)
                pred_x0 = (x_t - (1 - alpha_t).sqrt() * noise_pred) / alpha_t.sqrt()
                pred_rgb = vae.decode(pred_x0 / 0.18215).sample.clamp(-1, 1)

                # Safe LPIPS computation
                lpips_loss = safe_lpips(pred_rgb, images, LPIPS_LOSS, device)
            else:
                lpips_loss = torch.tensor(0.0, device=device)

            total_loss = mse_loss + lpips_weight * lpips_loss

        scaler.scale(total_loss).backward()

        if (batch_idx + 1) % config.training.grad_accum_steps == 0:
            if device == 'cuda':
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                ema.step_ema(unet_ema_model, model)
                
        #--- for logging only ----
        cumm_mse += mse_loss.item()
        cumm_lpips += lpips_loss if isinstance(lpips_loss, float) else lpips_loss 
        cumm_loss += total_loss.item() # main
        avg_mse = cumm_mse / (batch_idx + 1)
        avg_lpips = cumm_lpips / (batch_idx + 1)
        avg_loss = cumm_loss / (batch_idx + 1) # Total loss. main log

        best_loss = min(best_loss, avg_loss)
        if (batch_idx+1) % 50 == 0:
            logging.info(f"Epoch {epoch+1} AVG MSE: {avg_mse:.4f}, AVG LPIPS: {avg_lpips:.4f}, AVG Total: {avg_loss:.4f}")

        pbar.set_postfix(avg_loss=avg_loss, mem=gpu_info(handle))
            
    # Epoch summary logging - for each epoch
    avg_loss = cumm_loss / (batch_idx + 1)
    logging.info(f"Epoch {epoch+1} AVG MSE: {avg_mse:.4f}, AVG LPIPS: {avg_lpips:.4f}, AVG Total: {avg_loss:.4f}")

    print("SAVING MODEL STATES...")
    # Save checkpoint & EMA weights
    save_training_state(
        checkpoint_path=ckpt_path, epoch=epoch,
        model=model, optimizer=optimizer,
        avg_loss=avg_loss, best_loss=best_loss
    )
    print("UNET MODEL SAVED!")

    ema_path = os.path.join(config.checkpoint.path, config.checkpoint.ema_ckpt_name)
    torch.save(unet_ema_model.state_dict(), ema_path)
    print("EMA_UNET MODEL SAVED!")

    print(f"Epoch {epoch+1} done. Avg loss: {avg_loss:.4f}")

    # Generate visual sample for this epoch
    generate_sample(
        epoch=epoch+1, vae=vae, ema_model=unet_ema_model,
        scheduler=scheduler, tokenizer=clip_tokenizer, text_encoder=clip_encoder,
        config=config, device=device
    )

print("Training completed.")

Epoch 3/6:   2%|▏         | 8/375 [01:11<54:46,  8.96s/it, avg_loss=0.0939, mem=🚨GPU usage:12233 > 11900 Mib)] 


KeyboardInterrupt: 

In [None]:
import os
import torch
from torchvision.utils import save_image, make_grid
from tqdm import tqdm

from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPTokenizer, CLIPTextModel
from omegaconf import OmegaConf
from utils.ema import create_ema_model
from utils.metrics.gpu import init_nvml, gpu_info

@torch.no_grad()
def main():
    # Load configuration
    config = OmegaConf.load("configs/train_config_256.yaml")
    model_cfg = config.model
    sample_cfg = config.sampling

    device = "cuda" if torch.cuda.is_available() else "cpu"
    handle = init_nvml()

    # Prepare output
    os.makedirs(sample_cfg.dir, exist_ok=True)
    output_path = os.path.join(sample_cfg.dir, 'test_sample_grid.png')

    # Load VAE
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(device).eval()

    # Load conditional U-Net
    model = UNet2DConditionModel(
        sample_size=model_cfg.sample_size,
        in_channels=model_cfg.in_channels,
        out_channels=model_cfg.out_channels,
        down_block_types=tuple(model_cfg.down_block_types),
        up_block_types=tuple(model_cfg.up_block_types),
        block_out_channels=tuple(model_cfg.block_out_channels),
        layers_per_block=model_cfg.layers_per_block,
        cross_attention_dim=model_cfg.cross_attention_dim,
    ).to(device)

    # EMA wrapper
    ema_model, _ = create_ema_model(
        model,
        beta=config.training.ema_beta,
        step_start_ema=config.training.step_start_ema
    )
    ema_ckpt = os.path.join(config.checkpoint.path, config.checkpoint.ema_ckpt_name)
    ema_model.load_state_dict(torch.load(ema_ckpt, map_location=device))
    ema_model.eval()

    # Scheduler aligned with training
    scheduler = DDPMScheduler(
        num_train_timesteps=config.scheduler.timesteps,
        beta_start=config.scheduler.beta_start,
        beta_end=config.scheduler.beta_end,
        beta_schedule=config.scheduler.type
    )
    scheduler.set_timesteps(sample_cfg.steps)

    # Load CLIP
    tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
    text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14').to(device).eval()

    # Prepare guidance embeddings
    prompt = "a beautiful woman in a red dress"
    text_inputs = tokenizer(
        [prompt] * sample_cfg.num_samples,
        padding="max_length",
        truncation=True,
        max_length=77,
        return_tensors="pt"
    ).to(device)
    text_emb = text_encoder(**text_inputs).last_hidden_state
    # Unconditional (empty) embeddings
    uncond_inputs = tokenizer(
        [""] * sample_cfg.num_samples,
        padding="max_length",
        truncation=True,
        max_length=77,
        return_tensors="pt"
    ).to(device)
    uncond_emb = text_encoder(**uncond_inputs).last_hidden_state

    # Sampling
    num_samples = sample_cfg.num_samples
    shape = (num_samples, model_cfg.in_channels, model_cfg.sample_size, model_cfg.sample_size)
    latents = torch.randn(shape, device=device)
    guidance_scale = sample_cfg.guidance_scale if hasattr(sample_cfg, 'guidance_scale') else 7.5

    for t in tqdm(scheduler.timesteps, desc="Sampling"):  # timesteps descends
        t_batch = torch.full((num_samples,), t, device=device, dtype=torch.long)

        # Predict noise for both conditional and unconditional
        eps_uncond = ema_model(latents, timestep=t_batch, encoder_hidden_states=uncond_emb).sample
        eps_cond   = ema_model(latents, timestep=t_batch, encoder_hidden_states=text_emb).sample
        # Classifier-free guidance
        eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)

        # Step
        latents = scheduler.step(eps, t, latents).prev_sample

    # Decode latents to images
    images = vae.decode(latents / 0.18215).sample
    images = (images.clamp(-1, 1) + 1) / 2

    # Save grid
    grid = make_grid(images, nrow=int(num_samples**0.5))
    save_image(grid, output_path)
    print(f"✅ Samples saved to {output_path}")

if __name__ == "__main__":
    torch.cuda.empty_cache()
    main()


In [None]:
#=======================================================================================

# Inference

In [None]:
import os
import torch
from torchvision.utils import save_image, make_grid
from tqdm import tqdm

from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPTokenizer, CLIPTextModel
from omegaconf import OmegaConf
from utils.ema import create_ema_model
from utils.metrics.gpu import init_nvml, gpu_info

@torch.no_grad()
def main():
    # Load configuration
    config = OmegaConf.load("configs/train_config_256.yaml")
    model_cfg = config.model
    sample_cfg = config.sampling

    device = "cuda" if torch.cuda.is_available() else "cpu"
    handle = init_nvml()

    # Prepare output
    os.makedirs(sample_cfg.dir, exist_ok=True)
    output_path = os.path.join(sample_cfg.dir, 'test_sample_grid.png')

    # Load VAE
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(device).eval()

    # Load conditional U-Net
    model = UNet2DConditionModel(
        sample_size=model_cfg.sample_size,
        in_channels=model_cfg.in_channels,
        out_channels=model_cfg.out_channels,
        down_block_types=tuple(model_cfg.down_block_types),
        up_block_types=tuple(model_cfg.up_block_types),
        block_out_channels=tuple(model_cfg.block_out_channels),
        layers_per_block=model_cfg.layers_per_block,
        cross_attention_dim=model_cfg.cross_attention_dim,
    ).to(device)

    # EMA wrapper
    ema_model, _ = create_ema_model(
        model,
        beta=config.training.ema_beta,
        step_start_ema=config.training.step_start_ema
    )
    ema_ckpt = os.path.join(config.checkpoint.path, config.checkpoint.ema_ckpt_name)
    ema_model.load_state_dict(torch.load(ema_ckpt, map_location=device))
    ema_model.eval()

    # Scheduler aligned with training
    scheduler = DDPMScheduler(
        num_train_timesteps=config.scheduler.timesteps,
        beta_start=config.scheduler.beta_start,
        beta_end=config.scheduler.beta_end,
        beta_schedule=config.scheduler.type
    )
    scheduler.set_timesteps(sample_cfg.steps)

    # Load CLIP
    tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
    text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14').to(device).eval()

    # Prepare guidance embeddings
    prompt = "a beautiful woman in a red dress"
    text_inputs = tokenizer(
        [prompt] * sample_cfg.num_samples,
        padding="max_length",
        truncation=True,
        max_length=77,
        return_tensors="pt"
    ).to(device)
    text_emb = text_encoder(**text_inputs).last_hidden_state
    # Unconditional (empty) embeddings
    uncond_inputs = tokenizer(
        [""] * sample_cfg.num_samples,
        padding="max_length",
        truncation=True,
        max_length=77,
        return_tensors="pt"
    ).to(device)
    uncond_emb = text_encoder(**uncond_inputs).last_hidden_state

    # Sampling
    num_samples = sample_cfg.num_samples
    shape = (num_samples, model_cfg.in_channels, model_cfg.sample_size, model_cfg.sample_size)
    latents = torch.randn(shape, device=device)
    guidance_scale = sample_cfg.guidance_scale if hasattr(sample_cfg, 'guidance_scale') else 7.5

    for t in tqdm(scheduler.timesteps, desc="Sampling"):  # timesteps descends
        t_batch = torch.full((num_samples,), t, device=device, dtype=torch.long)

        # Predict noise for both conditional and unconditional
        eps_uncond = ema_model(latents, timestep=t_batch, encoder_hidden_states=uncond_emb).sample
        eps_cond   = ema_model(latents, timestep=t_batch, encoder_hidden_states=text_emb).sample
        # Classifier-free guidance
        eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)

        # Step
        latents = scheduler.step(eps, t, latents).prev_sample

    # Decode latents to images
    images = vae.decode(latents / 0.18215).sample
    images = (images.clamp(-1, 1) + 1) / 2

    # Save grid
    grid = make_grid(images, nrow=int(num_samples**0.5))
    save_image(grid, output_path)
    print(f"✅ Samples saved to {output_path}")

if __name__ == "__main__":
    torch.cuda.empty_cache()
    main()
