# Deep Learning Project, Part 1

This notebook is the first part of the project: it trains a segmentation-mask-conditioned diffusion model (DDPM) on the SHHQ-1.0 dataset.

In [None]:
# For displaying progress bars during training loops
!pip -q install tqdm

# Standard Python utilities
import copy
import os, math, random
from glob import glob
from dataclasses import dataclass

# PyTorch core library
import torch
import torch.nn as nn
import torch.nn.functional as F

# Dataset and DataLoader utilities
from torch.utils.data import Dataset, DataLoader

# Image preprocessing and transformations
from torchvision import transforms
from torchvision.utils import make_grid, save_image

# Image loading and handling
from PIL import Image

# Progress bar
from tqdm import tqdm

# Object-oriented file system paths
from pathlib import Path


In [None]:
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

'cuda'

In [None]:
# Mount Google Drive so files under /content/drive/MyDrive (the SHHQ zip archive,
# and later the checkpoint/sample folders) are reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
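# Disabled alternative: unzip the archive directly on Google Drive (kept for reference; the cells below copy the zip to the local Colab disk and unzip it there instead, which is much faster)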
#!unzip -P StylisH-HumanS-hq_1.0 "/content/drive/MyDrive/SHHQ-1.0.zip" -d "/content/drive/MyDrive/SHHQ_1.0"


In [None]:
# Image pipeline: resize to 128 x 64 (height x width), convert to a float tensor
# in [0, 1], then normalize every RGB channel with mean 0.5 / std 0.5 so pixel
# values end up in [-1, 1], the range the diffusion model works in.
img_transform = transforms.Compose(
    [
        transforms.Resize((128, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Mask pipeline: same spatial size, but no normalization -- the mask is a
# conditioning signal kept in [0, 1] (the dataset binarizes it afterwards).
mask_transform = transforms.Compose(
    [
        transforms.Resize((128, 64)),
        transforms.ToTensor(),
    ]
)

In [None]:
class ImageMaskDataset(Dataset):
    """
    Custom PyTorch Dataset for loading paired human images and segmentation masks.

    Each sample is an (image, mask) pair: an RGB image and the mask file at the
    same position in the sorted mask listing. Pairing is purely positional, so
    both directories must hold the same number of image files whose sorted
    orders correspond to each other.
    """

    # File extensions (matched case-insensitively) that count as image files.
    IMG_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, img_dir, mask_dir, img_transform=None, mask_transform=None,
                 mask_threshold=0.5):
        """
        Args:
            img_dir (str): Directory containing RGB input images.
            mask_dir (str): Directory containing corresponding segmentation masks.
            img_transform (callable, optional): Transformations applied to images.
            mask_transform (callable, optional): Transformations applied to masks.
                Its output must support ``>`` comparison and ``.float()`` (e.g. a
                tensor), because the mask is binarized after the transform.
            mask_threshold (float): Mask values strictly greater than this become 1,
                everything else 0. Defaults to 0.5, the midpoint of the [0, 1]
                range produced by ``ToTensor``.

        Raises:
            ValueError: If the two directories contain a different number of image files.
        """
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.img_transform = img_transform
        self.mask_transform = mask_transform
        self.mask_threshold = mask_threshold

        # Sorted listings give a deterministic, positional image/mask pairing.
        self.img_files = self._list_images(img_dir)
        self.mask_files = self._list_images(mask_dir)

        # Print dataset statistics
        print("Images:", len(self.img_files))
        print("Masks :", len(self.mask_files))

        # An explicit exception (unlike `assert`) survives `python -O` and says what is wrong.
        if len(self.img_files) != len(self.mask_files):
            raise ValueError(
                f"Image/mask count mismatch: {len(self.img_files)} images in "
                f"{img_dir!r} vs {len(self.mask_files)} masks in {mask_dir!r}"
            )

    @classmethod
    def _list_images(cls, directory):
        """Return the sorted names of the files in `directory` that have an image extension."""
        return sorted(
            name for name in os.listdir(directory)
            if name.lower().endswith(cls.IMG_EXTENSIONS)
        )

    def __len__(self):
        """
        Returns:
            int: Total number of samples in the dataset.
        """
        return len(self.img_files)

    def __getitem__(self, idx):
        """
        Loads and returns a single (image, mask) pair.

        Args:
            idx (int): Index of the sample.

        Returns:
            img (Tensor): Transformed RGB image tensor.
            mask (Tensor): Binary (0/1) float segmentation mask tensor.
        """
        # `with` releases each file handle as soon as the converted copy exists.
        with Image.open(os.path.join(self.img_dir, self.img_files[idx])) as raw_img:
            img = raw_img.convert("RGB")
        with Image.open(os.path.join(self.mask_dir, self.mask_files[idx])) as raw_mask:
            mask = raw_mask.convert("L")  # single-channel grayscale

        if self.img_transform:
            img = self.img_transform(img)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        # Strict 0/1 mask gives a clean silhouette for conditioning.
        mask = (mask > self.mask_threshold).float()

        return img, mask


In [None]:
# Reading thousands of small files from Google Drive during training is slow, so we first copy the zip archive to the Colab VM's local disk
!cp /content/drive/MyDrive/SHHQ-1.0.zip /content/


In [None]:
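# Unzip the local copy into /content/data on the Colab VM's disk
# NOTE(review): the archive password is hard-coded on the next line; consider reading it from an environment variable or a Colab secret before sharing this notebook.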
!unzip -P StylisH-HumanS-hq_1.0 /content/SHHQ-1.0.zip -d /content/data/


In [None]:
# Extracted SHHQ-1.0 folders: the raw photos and their segmentation masks.
mask_dir = "/content/data/SHHQ-1.0/segments"
img_dir = "/content/data/SHHQ-1.0/no_segment"

# Pair every image with its mask (the constructor prints both file counts).
dataset = ImageMaskDataset(img_dir, mask_dir, img_transform, mask_transform)

# Shuffled mini-batches of 32, produced by 2 worker processes into pinned host
# memory so host-to-GPU copies are faster.
loader = DataLoader(
    dataset,
    pin_memory=True,
    num_workers=2,
    shuffle=True,
    batch_size=32,
)

Images: 40000
Masks : 40000


In [None]:
'''
This class defines the forward diffusion process in a DDPM.

In diffusion models, there are two main processes:
1) Forward diffusion  : Gradually add Gaussian noise to a clean image
2) Reverse diffusion  : Learn to remove the noise step by step to recover the image

This scheduler only handles the forward diffusion process.

The task of this class: "To mathematically define how to distort an image in a controlled manner."
'''
class DDPMScheduler:
    """
    Linear-beta DDPM noise schedule plus the closed-form forward process q(x_t | x_0).

    All schedule tables (betas, alphas, alpha_bar and their square roots) are 1-D
    tensors of length T created on `device`.
    """

    def __init__(
        self,
        T=1000,             # Total number of diffusion steps
        beta_start=1e-4,    # Noise variance added at the first step
        beta_end=0.015,     # Noise variance added at the last step
        device="cuda"
    ):
        # Number of diffusion timesteps
        self.T = T

        # beta_t: linearly spaced per-step noise variances
        self.betas = torch.linspace(beta_start, beta_end, T, device=device)

        # alpha_t = 1 - beta_t: fraction of signal variance kept by step t
        self.alphas = 1.0 - self.betas

        # alpha_bar_t = prod_{s<=t} alpha_s: signal kept after t steps from x_0
        self.alpha_bar = torch.cumprod(self.alphas, dim=0)

        # Square roots precomputed once, since q_sample needs them every call
        self.sqrt_alpha_bar = torch.sqrt(self.alpha_bar)
        self.sqrt_one_minus_alpha_bar = torch.sqrt(1 - self.alpha_bar)

    def q_sample(self, x0, t, noise):
        """
        Forward diffusion equation:
        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise

        Args:
            x0 (Tensor): Clean samples with the batch on dim 0 (any rank, e.g. (B, C, H, W)).
            t (Tensor): Integer timesteps of shape (B,), one per sample.
            noise (Tensor): Gaussian noise with the same shape as x0.

        Returns:
            Tensor: Noisy samples x_t with the same shape as x0.
        """
        # (B,) -> (B, 1, ..., 1) so the per-sample coefficients broadcast over
        # every non-batch dimension of x0, whatever its rank.
        coef_shape = (x0.size(0),) + (1,) * (x0.dim() - 1)
        a = self.sqrt_alpha_bar[t].view(coef_shape)
        b = self.sqrt_one_minus_alpha_bar[t].view(coef_shape)

        # Larger t -> smaller a and larger b, i.e. a noisier sample
        return a * x0 + b * noise

In [None]:
"""
This function encodes the diffusion timestep (t) into a high-dimensional vector representation. It allows the model to know which diffusion step
it is currently processing.

In diffusion models, the same network is used for all timesteps, so explicit time conditioning is required.
"""
def sinusoidal_embedding(t, dim):
    """
    Args:
        t (Tensor): Diffusion timestep tensor of shape (B,)
        dim (int): Embedding dimension used to represent the timestep. Larger dimensions
                   provide richer temporal information. The U-Net in this notebook uses 256.

    Returns:
        Tensor: Sinusoidal time embeddings of shape (B, dim): dim // 2 sine features
        followed by dim // 2 cosine features, plus one trailing zero column when dim is odd.
    """
    # Half of the embedding is used for sine, half for cosine
    half = dim // 2

    # Frequencies decay geometrically from 1 down to 1/10000 across the `half` slots,
    # covering both coarse (low-frequency) and fine (high-frequency) time information.
    # max(..., 1) avoids a 0/0 -> NaN when half == 1 (dim of 2 or 3).
    freqs = torch.exp(
        -math.log(10000) * torch.arange(half, device=t.device) / max(half - 1, 1)
    )

    # (B, 1) * (1, half) -> (B, half): every timestep scaled by every frequency
    args = t.float().unsqueeze(1) * freqs.unsqueeze(0)

    # Concatenate sine and cosine features
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)

    # For odd dim, pad one zero column so the result really has `dim` columns.
    if dim % 2 == 1:
        emb = torch.cat([emb, emb.new_zeros(emb.size(0), 1)], dim=1)

    return emb


In [None]:
'''
This block is the fundamental building unit of the U-Net architecture.
It processes spatial features while explicitly conditioning on the diffusion timestep through time embeddings.

Residual connections help stabilize training and preserve information, which is especially important in deep diffusion models.
'''
class ResBlock(nn.Module):
    """
    Time-conditioned residual block.

    Two GroupNorm -> SiLU -> 3x3 conv stages, with a per-channel bias derived from
    the time embedding added between them, and a skip path that is the identity
    when the widths match or a 1x1 conv otherwise.
    """

    def __init__(self, in_ch, out_ch, t_dim):
        """
        Args:
            in_ch (int): Input channels (must be divisible by 8 for GroupNorm(8, .)).
            out_ch (int): Output channels (must be divisible by 8 as well).
            t_dim (int): Width of the incoming time embedding.
        """
        super().__init__()

        # Stage 1: normalize the input, then convolve up/down to out_ch channels.
        self.norm1 = nn.GroupNorm(8, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)

        # Stage 2: same recipe at out_ch width.
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)

        # Maps the (B, t_dim) time embedding to one bias value per output channel.
        self.time_mlp = nn.Linear(t_dim, out_ch)

        # The residual sum needs matching widths; project only when they differ.
        if in_ch == out_ch:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x, t_emb):
        """
        Args:
            x (Tensor): Feature map of shape (B, in_ch, H, W).
            t_emb (Tensor): Time embedding of shape (B, t_dim).

        Returns:
            Tensor: Feature map of shape (B, out_ch, H, W).
        """
        batch = x.size(0)
        h = self.conv1(F.silu(self.norm1(x)))

        # (B, out_ch) -> (B, out_ch, 1, 1): broadcast the time bias over H and W.
        time_bias = self.time_mlp(t_emb).view(batch, -1, 1, 1)
        h = self.conv2(F.silu(self.norm2(h + time_bias)))

        return self.skip(x) + h

In [None]:
'''
This block reduces the spatial resolution of the feature maps.
It allows the network to capture global structure by increasing the receptive field.

In diffusion U-Nets, downsampling is crucial for modeling long range spatial dependencies efficiently.
'''
class Down(nn.Module):
    """
    Learned 2x spatial downsampling: a 4x4 convolution with stride 2 and padding 1.
    The channel count is unchanged.
    """

    def __init__(self, ch):
        """
        Args:
            ch (int): Number of feature channels (same for input and output).
        """
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=ch, out_channels=ch, kernel_size=4, stride=2, padding=1
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Feature map of shape (B, C, H, W).

        Returns:
            Tensor: Feature map of shape (B, C, H // 2, W // 2).
        """
        return self.conv(x)


'''
This block increases the spatial resolution of the feature maps.
It reconstructs fine-grained spatial details using information from earlier layers via skip connections.

Upsampling is essential for recovering image details after aggressive downsampling in the encoder.
'''
class Up(nn.Module):
    """
    Learned 2x spatial upsampling: a 4x4 transposed convolution with stride 2 and
    padding 1. The channel count is unchanged; the caller concatenates the result
    with an encoder skip connection.
    """

    def __init__(self, ch):
        """
        Args:
            ch (int): Number of feature channels (same for input and output).
        """
        super().__init__()
        self.tconv = nn.ConvTranspose2d(
            in_channels=ch, out_channels=ch, kernel_size=4, stride=2, padding=1
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Feature map of shape (B, C, H, W).

        Returns:
            Tensor: Feature map of shape (B, C, 2H, 2W).
        """
        return self.tconv(x)


In [None]:
'''
This block enables the model to capture long-range spatial dependencies.
Unlike convolution, which is inherently local, self-attention allows every spatial location to directly interact with every other location.

In diffusion models, self-attention is especially useful for maintaining global coherence, such as consistent body structure and symmetry.
'''
class SelfAttention(nn.Module):
    """
    Single-head spatial self-attention with a residual connection.

    Every one of the N = H * W positions attends to all N positions, so the
    attention matrix is (N, N) per sample.
    """

    def __init__(self, channels):
        """
        Args:
            channels (int): Number of feature channels; must be divisible by 32
                because of the 32-group GroupNorm.
        """
        super().__init__()
        self.norm = nn.GroupNorm(32, channels)

        # 1x1 convolutions act as per-position linear projections (Q, K, V, output).
        self.q = nn.Conv2d(channels, channels, 1)
        self.k = nn.Conv2d(channels, channels, 1)
        self.v = nn.Conv2d(channels, channels, 1)
        self.proj = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        """
        Args:
            x (Tensor): Feature map of shape (B, C, H, W).

        Returns:
            Tensor: x plus the projected attention output, shape (B, C, H, W).
        """
        B, C, H, W = x.shape
        normed = self.norm(x)

        # Each projection: (B, C, H, W) -> (B, C, N) with N = H * W.
        query, key, value = (
            layer(normed).flatten(2) for layer in (self.q, self.k, self.v)
        )

        # scores[b, i, j] = <query_i, key_j> / sqrt(C); softmax over j (attended positions).
        scores = torch.bmm(query.transpose(1, 2), key) / (C ** 0.5)
        weights = scores.softmax(dim=-1)

        # mixed[b, :, i] = sum_j weights[b, i, j] * value[b, :, j]
        mixed = torch.bmm(value, weights.transpose(1, 2)).reshape(B, C, H, W)

        return x + self.proj(mixed)

In [None]:
'''
U-Net is well-suited for diffusion models because it can capture both global structure (e.g., full human silhouette)
and local details (e.g., arms, legs, edges) simultaneously.

The encoder extracts hierarchical features at multiple resolutions, while the decoder reconstructs fine details using skip connections.
'''
class UNet(nn.Module):
    """
    Three-level, time-conditioned U-Net that predicts diffusion noise.

    Input: the noisy RGB image concatenated with the conditioning mask along the
    channel axis. Output: `out_ch` channels at the input resolution. Because of
    three 2x down/up-sampling stages, H and W must be divisible by 8.
    """

    def __init__(self, in_ch=4, base=64, t_dim=256, out_ch=3):
        """
        Args:
            in_ch (int): Number of input channels (3 RGB channels + 1 segmentation mask channel).
            base (int): Base number of feature channels (multiple of 8 for the GroupNorms).
            t_dim (int): Dimensionality of the time embedding.
            out_ch (int): Number of predicted output channels (3 for RGB noise).
        """
        super().__init__()

        # forward() must build sinusoidal embeddings of exactly this width, otherwise
        # the first Linear of time_mlp rejects them.
        self.t_dim = t_dim

        # Time embedding MLP: refines the sinusoidal embedding before it is injected into ResBlocks
        self.time_mlp = nn.Sequential(
            nn.Linear(t_dim, t_dim),
            nn.SiLU(),
            nn.Linear(t_dim, t_dim)
        )

        # Lift the in_ch input channels (RGB image + mask by default) to `base` feature channels
        self.in_conv = nn.Conv2d(in_ch, base, 3, padding=1)

        # Encoder: each level halves the resolution and (after level 1) doubles the width
        self.rb1 = ResBlock(base, base, t_dim)
        self.down1 = Down(base)

        self.rb2 = ResBlock(base, base*2, t_dim)
        self.down2 = Down(base*2)

        self.rb3 = ResBlock(base*2, base*4, t_dim)
        self.down3 = Down(base*4)

        # Bottleneck at 1/8 resolution
        self.mid1 = ResBlock(base*4, base*4, t_dim)

        # Self-attention is disabled here (Identity); it was tested but removed due to its cost
        self.mid_attn = nn.Identity()

        self.mid2 = ResBlock(base*4, base*4, t_dim)

        # Decoder: upsample, concatenate the matching encoder features, then a ResBlock.
        # Input widths are (upsampled + skip) channels, e.g. base*4 + base*4 = base*8.
        self.up3 = Up(base*4)
        self.urb3 = ResBlock(base*8, base*2, t_dim)

        self.up2 = Up(base*2)
        self.urb2 = ResBlock(base*4, base, t_dim)

        self.up1 = Up(base)
        self.urb1 = ResBlock(base*2, base, t_dim)

        # Output projection to the predicted noise
        self.out = nn.Conv2d(base, out_ch, 3, padding=1)

    def forward(self, x, t):
        """
        Forward pass of the U-Net.

        Args:
            x (Tensor): Input tensor of shape (B, in_ch, H, W)
                        (noisy RGB image + segmentation mask by default).
            t (Tensor): Diffusion timestep tensor of shape (B,)

        Returns:
            Tensor: Predicted noise tensor of shape (B, out_ch, H, W)

        Raises:
            ValueError: If H or W is not divisible by 8 (the skip connections
                would not line up after three 2x down/up-samplings).
        """
        H, W = x.shape[-2:]
        if H % 8 or W % 8:
            raise ValueError(f"UNet input height and width must be divisible by 8, got {H}x{W}")

        # Time embedding sized to match time_mlp (was hard-coded to 256)
        t_emb = self.time_mlp(sinusoidal_embedding(t, self.t_dim))

        x0 = self.in_conv(x)

        # Encoder; h1..h3 are kept for the skip connections
        h1 = self.rb1(x0, t_emb)
        d1 = self.down1(h1)

        h2 = self.rb2(d1, t_emb)
        d2 = self.down2(h2)

        h3 = self.rb3(d2, t_emb)
        d3 = self.down3(h3)

        # Bottleneck
        m = self.mid1(d3, t_emb)
        m = self.mid_attn(m)
        m = self.mid2(m, t_emb)

        # Decoder with channel-wise concatenation of skip features
        u3 = self.up3(m)
        u3 = self.urb3(torch.cat([u3, h3], 1), t_emb)

        u2 = self.up2(u3)
        u2 = self.urb2(torch.cat([u2, h2], 1), t_emb)

        u1 = self.up1(u2)
        u1 = self.urb1(torch.cat([u1, h1], 1), t_emb)

        # Noise prediction
        return self.out(u1)

In [None]:
# Forward-process noise schedule: 1000 linearly spaced betas from 1e-4 to 0.015.
scheduler = DDPMScheduler(1000, 1e-4, 0.015, device)

# Noise-prediction U-Net with its default configuration (4 input channels,
# base width 64, 256-d time embedding), trained with Adam at lr 2e-4.
model = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), 2e-4)

In [None]:
# EMA (exponential moving average) copy of `model`, used for sampling.
# deepcopy reproduces the architecture, current weights and device of `model`
# exactly, so the EMA network cannot drift out of sync with the model's
# hyper-parameters (they no longer have to be repeated here).
# The optimizer only holds `model.parameters()`; the EMA weights are changed
# solely in-place by `ema_update`, so autograd tracking is switched off, and
# eval() puts the copy in inference mode.
ema_model = copy.deepcopy(model).eval().requires_grad_(False)

In [None]:
@torch.no_grad()
def ema_update(ema_model, model, decay=0.999):
    '''
    In-place exponential-moving-average update of `ema_model` toward `model`.

    For every parameter pair: ema = decay * ema + (1 - decay) * current, so old EMA
    weights keep weight `decay` and the new weights contribute `1 - decay`.
    Buffers (e.g. normalization running statistics or integer counters) are not
    averaged; they are copied verbatim, so the EMA model always evaluates with the
    current statistics and integer buffers never see a float blend.

    Both models must share the same architecture: parameters and buffers are
    paired by their registration order.
    '''
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.mul_(decay).add_(param, alpha=1 - decay)
    for ema_buf, buf in zip(ema_model.buffers(), model.buffers()):
        ema_buf.copy_(buf)

In [None]:
@torch.no_grad()
def sample_cond(model, sched, mask, n=None):
    '''
    Ancestral DDPM sampling conditioned on a segmentation mask.

    Starts from standard Gaussian noise and runs the reverse chain from
    t = sched.T - 1 down to 0. At every step the mask is concatenated to the
    current sample along the channel axis as the model's conditioning input.

    Args:
        model: Noise-prediction network, called as model(cat([x_t, mask], 1), t),
            whose output has the shape of x_t (n, 3, H, W).
        sched: Scheduler exposing `T`, `alphas`, `alpha_bar` and `betas`.
        mask (Tensor): Conditioning masks of shape (B, C_mask, H, W). Samples are
            generated on the same device as `mask`.
        n (int, optional): Number of samples; the first `n` masks are used.
            Defaults to all B masks.

    Returns:
        Tensor: Generated images of shape (n, 3, H, W), nominally in [-1, 1]
        (not clamped).

    Raises:
        ValueError: If `n` is negative or larger than the number of masks.
    '''
    # Evaluation mode for stable inference
    model.eval()

    if n is None:
        n = mask.size(0)
    if n < 0 or n > mask.size(0):
        raise ValueError(f"n must be in [0, {mask.size(0)}], got {n}")
    mask = mask[:n]

    # Work on the mask's device rather than a notebook-global one.
    dev = mask.device
    B, _, H, W = mask.shape
    x = torch.randn(B, 3, H, W, device=dev)

    # Reverse diffusion: iterate from T-1 down to 0
    for t in reversed(range(sched.T)):
        t_batch = torch.full((B,), t, device=dev, dtype=torch.long)

        # Concatenate noisy image with the silhouette mask for conditioning
        eps = model(torch.cat([x, mask], dim=1), t_batch)

        alpha = sched.alphas[t]
        alpha_bar = sched.alpha_bar[t]
        beta = sched.betas[t]

        # Posterior mean: (x_t - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
        mean = (1 / torch.sqrt(alpha)) * (
            x - (1 - alpha) / torch.sqrt(1 - alpha_bar) * eps
        )

        # Add noise with variance beta_t on every step except the last
        if t > 0:
            x = mean + torch.sqrt(beta) * torch.randn_like(x)
        else:
            x = mean

    return x

In [None]:
@torch.no_grad()
def ddim_sample_cond(
    model,
    scheduler,
    mask,
    device=None,
    ddim_steps=250,
    eta=0.0
):
    """
    DDIM sampling conditioned on a segmentation mask, using a reduced set of timesteps.

    Args:
        model: Noise-prediction network, called as model(cat([x_t, mask], 1), t).
        scheduler: Scheduler exposing `T` and `alpha_bar`.
        mask (Tensor): Conditioning masks of shape (B, C_mask, H, W).
        device: Device to sample on. Defaults to `mask.device`.
        ddim_steps (int): Number of model evaluations (distinct timesteps are
            taken evenly from T-1 down to 0; at most T are used).
        eta (float): Stochasticity; 0 gives deterministic DDIM, 1 matches DDPM-like noise.

    Returns:
        Tensor: Predicted clean images of shape (B, 3, H, W), clamped to [-1, 1].

    Raises:
        ValueError: If `ddim_steps` < 1.
    """
    if ddim_steps < 1:
        raise ValueError(f"ddim_steps must be >= 1, got {ddim_steps}")
    if device is None:
        device = mask.device

    # Use evaluation mode for inference
    model.eval()

    B, _, H, W = mask.shape
    mask = mask.to(device)
    x = torch.randn(B, 3, H, W, device=device)

    # Evenly spaced timesteps from T-1 down to 0, computed in float and rounded
    # (integer-dtype linspace is avoided); duplicates, which appear when
    # ddim_steps > T, are dropped while keeping descending order.
    raw_steps = torch.linspace(scheduler.T - 1, 0, ddim_steps).round().long().tolist()
    timesteps = list(dict.fromkeys(raw_steps))

    # Pair each timestep with the next one; the final step jumps to the clean image
    # (alpha_bar = 1), i.e. it returns the model's x0 prediction instead of stopping
    # one step short at the t = 0 noise level.
    for t, t_next in zip(timesteps, timesteps[1:] + [None]):
        t_batch = torch.full((B,), t, device=device, dtype=torch.long)

        # Predict noise at timestep t, conditioned on the mask
        eps = model(torch.cat([x, mask], dim=1), t_batch)

        alpha_bar_t = scheduler.alpha_bar[t]

        # Predicted clean image x0
        x0_pred = (
            (x - torch.sqrt(1 - alpha_bar_t) * eps)
            / torch.sqrt(alpha_bar_t)
        ).clamp(-1, 1)

        if t_next is None:
            x = x0_pred
            break

        alpha_bar_next = scheduler.alpha_bar[t_next]

        # DDIM noise scale (eta = 0 -> deterministic)
        sigma = eta * torch.sqrt(
            (1 - alpha_bar_next)
            / (1 - alpha_bar_t)
            * (1 - alpha_bar_t / alpha_bar_next)
        )

        # Clamp guards against a tiny negative argument from rounding (sqrt -> NaN).
        dir_coef = torch.sqrt((1 - alpha_bar_next - sigma ** 2).clamp(min=0))

        # DDIM update step
        x = torch.sqrt(alpha_bar_next) * x0_pred + dir_coef * eps
        if eta > 0:
            x = x + sigma * torch.randn_like(x)

    return x

In [None]:
# Directories for saving checkpoints and generated samples
checkpoint_dir = Path("/content/drive/MyDrive/checkpoints")
sample_dir = Path("/content/drive/MyDrive/samples1")
checkpoint_dir.mkdir(parents=True, exist_ok=True)
sample_dir.mkdir(parents=True, exist_ok=True)

# Training configuration
NUM_EPOCHS = 50      # Total number of training epochs
SAVE_EVERY = 10      # Save a model checkpoint every N epochs
SAMPLE_EVERY = 5     # Generate preview samples every N epochs
N_PREVIEW = 8        # Number of masks used for each preview grid


def _face_weighted_loss(pred_noise, noise, mask, face_frac=0.3, face_weight=2.5,
                        mse_w=0.6, l1_w=0.4):
    """
    Weighted mix of MSE and L1 between predicted and true noise.

    Per-pixel weight = 1 + mask * w, where w = face_weight in the top `face_frac`
    of the rows and 1 elsewhere: background pixels weigh 1, masked pixels 2, and
    masked pixels in the top band 1 + face_weight (3.5 by default). The top band is
    a heuristic for the face region and assumes upright, head-up crops.
    """
    H = mask.shape[-2]
    weight_map = torch.ones_like(mask)
    weight_map[..., :int(face_frac * H), :] = face_weight
    weight_map = 1.0 + mask * weight_map

    diff = pred_noise - noise
    return (
        mse_w * (diff**2 * weight_map).mean() +
        l1_w * (diff.abs() * weight_map).mean()
    )


# Draw one fixed batch of masks up front so every preview grid uses the same
# silhouettes (epochs are directly comparable) and no fresh multi-worker
# DataLoader iterator is spun up at every preview.
_, preview_masks = next(iter(loader))
preview_masks = preview_masks[:N_PREVIEW].to(device)

# Track average loss per epoch
epoch_losses = []

for epoch in range(NUM_EPOCHS):

    print(f"\n Epoch {epoch+1}/{NUM_EPOCHS}")

    # Enable training mode
    model.train()
    epoch_loss = 0.0

    # Progress bar for batches
    pbar = tqdm(
        loader,
        total=len(loader),
        desc=f"Epoch {epoch+1}",
        leave=True
    )

    for batch_idx, (x0, mask) in enumerate(pbar):
        # Loader uses pinned memory, so the copies can be asynchronous
        x0 = x0.to(device, non_blocking=True)
        mask = mask.to(device, non_blocking=True)

        # Randomly sample one diffusion timestep per image
        t = torch.randint(0, scheduler.T, (x0.size(0),), device=device)

        # Forward diffusion: add noise to the clean image
        noise = torch.randn_like(x0)
        xt = scheduler.q_sample(x0, t, noise)

        # Predict the noise with the mask-conditioned U-Net
        pred_noise = model(torch.cat([xt, mask], dim=1), t)
        loss = _face_weighted_loss(pred_noise, noise, mask)

        # Backpropagation with gradient-norm clipping
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Update EMA model
        ema_update(ema_model, model, decay=0.9997)

        # Fetch the scalar once: each .item() forces a host/GPU sync.
        loss_value = loss.item()
        epoch_loss += loss_value

        pbar.set_postfix({
            "loss": f"{loss_value:.4f}",
            "avg": f"{epoch_loss/(batch_idx+1):.4f}"
        })

    # Compute average epoch loss
    avg_epoch_loss = epoch_loss / len(loader)
    epoch_losses.append(avg_epoch_loss)

    print(f" Epoch {epoch+1} - Avg Loss: {avg_epoch_loss:.4f}")

    # Save model checkpoint periodically
    if (epoch + 1) % SAVE_EVERY == 0:
        checkpoint_path = checkpoint_dir / f"model_epoch_{epoch+1}.pt"
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'ema_model_state_dict': ema_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_epoch_loss,
            'epoch_losses': epoch_losses
        }, checkpoint_path)
        print(f"Checkpoint saved: {checkpoint_path}")

    # Generate and save preview samples periodically with the EMA model.
    # sample_cond is already no_grad and puts ema_model in eval mode; the
    # training model is not involved, so no train/eval toggling is needed.
    if (epoch + 1) % SAMPLE_EVERY == 0:
        samples = sample_cond(ema_model, scheduler, mask=preview_masks)
        samples = (samples.clamp(-1, 1) + 1) / 2  # [-1, 1] -> [0, 1]
        save_image(
            samples,
            sample_dir / f"samples_epoch_{epoch+1}.png",
            nrow=4
        )
        print(f"Samples saved: samples_epoch_{epoch+1}.png")

# Save final trained model
final_path = checkpoint_dir / "model_final.pt"
torch.save({
    'epoch': NUM_EPOCHS,
    'model_state_dict': model.state_dict(),
    'ema_model_state_dict': ema_model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch_losses': epoch_losses
}, final_path)
print(f"Final model saved: {final_path}")


In [None]:
'''
This block generates 400 human images using the trained EMA diffusion model.
We take masks from the DataLoader and generate one image per mask.
Each generated image is saved as a separate PNG file to Google Drive.
'''

# Output directory on Drive
out_dir = Path("/content/drive/MyDrive/SHHQ_1.0/SHHQ-1.0/model_samples")
out_dir.mkdir(parents=True, exist_ok=True)

@torch.no_grad()
def generate_and_save_samples(
    ema_model,
    scheduler,
    loader,
    out_dir,
    total_samples=400
):
    """
    Generates `total_samples` mask-conditioned images and saves them one PNG each.

    Masks are taken from `loader` batch by batch (restarting it if it runs out),
    and images are written as sample_0000.png, sample_0001.png, ... in `out_dir`.

    Args:
        ema_model: EMA version of the diffusion model; sampling runs on its device.
        scheduler: DDPM scheduler used for sampling.
        loader: DataLoader that returns (image, mask) batches.
        out_dir (str or Path): Directory to save generated PNG files.
        total_samples (int): Number of images to generate.

    Raises:
        ValueError: If `loader` yields no batches or an empty batch.
    """
    ema_model.eval()

    # Sample on the model's own device instead of relying on a notebook global.
    model_device = next(ema_model.parameters()).device
    out_dir = Path(out_dir)

    saved = 0
    loader_iter = iter(loader)

    while saved < total_samples:
        try:
            _, mask = next(loader_iter)  # only the mask is needed
        except StopIteration:
            # Restart the loader if it ends before total_samples is reached
            loader_iter = iter(loader)
            try:
                _, mask = next(loader_iter)
            except StopIteration:
                raise ValueError("loader yielded no batches; cannot draw conditioning masks") from None

        # How many samples to generate from this batch
        b = min(mask.size(0), total_samples - saved)
        if b == 0:
            # Otherwise `saved` would never advance and the loop would spin forever.
            raise ValueError("loader yielded an empty batch")

        # Slice before the transfer so only the masks actually used are copied.
        mask_batch = mask[:b].to(model_device)

        # DDPM sampling (mask-conditioned), then [-1, 1] -> [0, 1] for saving
        samples = sample_cond(ema_model, scheduler, mask=mask_batch)
        samples = (samples.clamp(-1, 1) + 1) / 2

        # Save each image separately
        for img in samples:
            save_image(img, out_dir / f"sample_{saved:04d}.png")
            saved += 1

        print(f"Saved {saved}/{total_samples} images...", end="\r")

    print(f"\n Done! Saved {total_samples} images to: {out_dir}")



In [None]:
# Run generation: 400 mask-conditioned samples from the EMA model, one PNG per image.
generate_and_save_samples(ema_model, scheduler, loader, out_dir, total_samples=400)