# Diffusion-based Image Colorization

In [None]:
# Experiment configuration shared by every cell below.
CFG = dict(
    dataset="CIFAR10",
    img_size=32,
    batch_size=128,  # reduce to 64 or 32 if running on CPU
    epochs=60,  # small number for quick iteration
    lr=2e-4,
    # torch is imported in a later cell, hence the inline __import__.
    device="cuda" if __import__('torch').cuda.is_available() else "cpu",
    timesteps=200,  # number of diffusion steps (DDPM)
    sample_steps=25,  # steps used for sampling (DDIM/accelerated)
    channels=3,
    # Experiment toggles (for ablation study)
    use_learnable_t_emb=True,
    use_attention=False,
    use_ddim=True,
    max_train_batches=None,  # set to an int to limit batches per epoch for faster runs
)


print("Device:", CFG['device'])

In [None]:
import math
import random
import time
from functools import partial


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
seed = 42
random.seed(seed)
torch.manual_seed(seed)


def show_tensor_images(images, nrow=8, figsize=(8,8), title=None):
    """Display a batch of images as one grid using matplotlib.

    Args:
        images: batch of images accepted by ``torchvision.utils.make_grid``
            (e.g. ``[N, C, H, W]``). The grid is normalized with the fixed
            range (0, 1), so values outside [0, 1] are clipped for display.
        nrow: number of images per grid row.
        figsize: matplotlib figure size in inches.
        title: optional figure title.
    """
    images = images.clone().detach().cpu()
    grid = make_grid(images, nrow=nrow, normalize=True, value_range=(0,1))
    fig = plt.figure(figsize=figsize)
    plt.axis('off')
    if title:
        plt.title(title)
    # make_grid returns [C, H, W]; imshow wants [H, W, C].
    plt.imshow(grid.permute(1,2,0))
    plt.show()
    # This is called once per epoch; close explicitly so figures do not accumulate
    # under non-inline backends (where plt.show() does not release them).
    plt.close(fig)

In [None]:
# PIL image -> float tensor in [0, 1]; no augmentation or normalization.
transform = transforms.Compose([transforms.ToTensor()])

# CIFAR-10 train / test splits, downloaded into ./data on first use.
train_ds = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
val_ds = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    transform=transform,
    download=True,
)

# helper: make 1-channel grayscale from 3-channel image
def rgb_to_gray(x):
    """Convert RGB image tensor(s) to single-channel luminance.

    Uses the ITU-R BT.601 weights 0.2989 R + 0.5870 G + 0.1140 B. The channel
    axis is the third-from-last dimension, so both a single image ``[C, H, W]``
    (-> ``[1, H, W]``) and a batch ``[B, C, H, W]`` (-> ``[B, 1, H, W]``) work.
    Only the first three channels are used (e.g. alpha is ignored).

    Args:
        x: tensor of shape ``[..., C, H, W]`` with ``C >= 3``.

    Returns:
        Tensor of shape ``[..., 1, H, W]``.

    Raises:
        ValueError: if ``x`` has fewer than 3 dims or fewer than 3 channels.
    """
    if x.dim() < 3 or x.shape[-3] < 3:
        raise ValueError(f"rgb_to_gray expects [..., C>=3, H, W], got shape {tuple(x.shape)}")
    r, g, b = x[..., 0, :, :], x[..., 1, :, :], x[..., 2, :, :]
    gray = 0.2989*r + 0.5870*g + 0.1140*b
    return gray.unsqueeze(-3)


# Wrap dataset to return (grayscale, color)
class ColorizationDataset(torch.utils.data.Dataset):
    """Turn an ``(image, label)`` dataset into ``(grayscale, color)`` training pairs.

    The label is discarded. The grayscale input is derived from the color image
    with ``rgb_to_gray``; the color image itself is returned unchanged as the target.
    """

    def __init__(self, base_ds):
        self.base = base_ds

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        color, _label = self.base[idx]
        gray = rgb_to_gray(color)
        return gray, color


train_ds_col = ColorizationDataset(train_ds)
val_ds_col = ColorizationDataset(val_ds)


# pin_memory only speeds up host->CUDA copies; without a CUDA device it does nothing
# useful and DataLoader may warn about it, so enable it only when training on CUDA.
train_loader = DataLoader(train_ds_col, batch_size=CFG['batch_size'], shuffle=True, num_workers=2, pin_memory=(CFG['device'] == 'cuda'))
val_loader = DataLoader(val_ds_col, batch_size=CFG['batch_size'], shuffle=False, num_workers=2, pin_memory=(CFG['device'] == 'cuda'))


# Peek at one training batch: first 16 grayscale inputs (replicated to 3 channels),
# followed by the matching 16 ground-truth color images.
gr, col = next(iter(train_loader))
show_tensor_images(torch.cat([gr.repeat(1,3,1,1)[:16], col[:16]], dim=0), nrow=8, title='Grayscale (replicated) | Ground Truth Color')

In [None]:
def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Return ``timesteps`` betas spaced linearly from ``beta_start`` to ``beta_end`` (both inclusive)."""
    schedule = torch.linspace(start=beta_start, end=beta_end, steps=timesteps)
    return schedule

# Forward-process noise schedule, kept on the training device.
T = CFG['timesteps']
betas = linear_beta_schedule(T).to(CFG['device'])
alphas = 1.0 - betas
alphas_cumprod = alphas.cumprod(dim=0)  # \bar{alpha}_t

# Precomputed terms reused by training (q_sample) and sampling.
betas_t = betas
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1 - alphas_cumprod).sqrt()

# utility: extract a tensor of shape [batch_size, 1,1,1] for a timestep t

def extract(a, t, x_shape):
    """Pick ``a[t]`` per batch element and reshape it for broadcasting.

    Args:
        a: 1-D schedule tensor (e.g. length T).
        t: 1-D integer tensor of indices, shape ``[B]``; any integer dtype and any device.
        x_shape: shape of the tensor the result will be broadcast against; only its
            rank (``len(x_shape)``) is used.

    Returns:
        Tensor of shape ``[B, 1, ..., 1]`` with ``len(x_shape)`` dims, on ``t``'s device.
    """
    # gather requires the index on the same device as the source and of dtype int64.
    idx = t.to(device=a.device, dtype=torch.long)
    out = a.gather(-1, idx).to(t.device)
    return out.view(-1, *((1,)*(len(x_shape)-1)))

In [None]:
class SmallConvBlock(nn.Module):
    """Conv2d -> GroupNorm(8 groups) -> SiLU.

    With the default ``kernel=3, padding=1`` the spatial size is preserved.
    ``c_out`` must be divisible by 8 for the GroupNorm.
    """

    def __init__(self, c_in, c_out, kernel=3, padding=1):
        super().__init__()
        layers = [
            nn.Conv2d(c_in, c_out, kernel_size=kernel, padding=padding),
            nn.GroupNorm(8, c_out),
            nn.SiLU(),
        ]
        # Keep a Sequential named ``block`` so state_dict keys stay "block.<i>.*".
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        return out

class SinusoidalPosEmb(nn.Module):
    """Fixed sinusoidal timestep embedding (Transformer-style).

    For ``half = dim // 2`` geometric frequencies ``f_k = 10000^(-k/(half-1))``
    the output is ``[sin(t*f_0..f_{half-1}), cos(t*f_0..f_{half-1})]``, zero-padded
    by one column when ``dim`` is odd.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Embed timesteps.

        Args:
            t: tensor of timesteps, shape ``[B]`` (a 0-dim tensor is treated as ``[1]``).

        Returns:
            Float tensor of shape ``[B, dim]``.
        """
        device = t.device
        half = self.dim // 2
        # Guard half == 1 (dim 2/3), where the original exponent would divide by zero.
        scale = math.log(10000.0) / max(half - 1, 1)
        freqs = torch.exp(torch.arange(half, device=device, dtype=torch.float32) * -scale)
        # The timestep must actually enter the embedding: outer product t x freqs -> [B, half].
        args = t.float().reshape(-1, 1) * freqs.unsqueeze(0)
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        if self.dim % 2 == 1:
            # Keep the output width equal to ``dim`` for odd dims.
            emb = torch.cat([emb, emb.new_zeros(emb.shape[0], 1)], dim=-1)
        return emb


class LearnableTimestepEmbedding(nn.Module):
    """Lookup-table timestep embedding: one trainable ``dim``-vector per step index."""

    def __init__(self, dim, max_steps=1000):
        super().__init__()
        # The attribute name ``emb`` is part of the state_dict key ("emb.weight").
        self.emb = nn.Embedding(num_embeddings=max_steps, embedding_dim=dim)

    def forward(self, t):
        """Return the rows ``emb.weight[t]``; ``t`` holds integer indices in ``[0, max_steps)``."""
        table = self.emb
        return table(t)

class TinyUNet(nn.Module):
    """Small two-level U-Net that predicts the noise added to a color image.

    The noisy color image and the grayscale condition are concatenated on the
    channel axis (hence ``in_ch=4`` = 3 + 1). Two AvgPool2d(2) downsamplings and
    two nearest-neighbour x2 upsamplings with skip connections are used, so H and
    W must be divisible by 4 for the skip concatenations to line up. The timestep
    embedding is injected only once, as a per-channel bias at the bottleneck.

    Args:
        in_ch: input channels after concatenating ``x`` and ``cond``.
        base_ch: channel width of the first level; should be divisible by 8
            (GroupNorm in SmallConvBlock) and ``base_ch*2`` by 4 (attention heads).
        t_emb_dim: width of the timestep embedding fed to ``time_mlp``.
        use_learnable_t_emb: use an ``nn.Embedding`` table sized by the global
            ``T`` instead of the fixed sinusoidal embedding.
        use_attention: add a 4-head self-attention residual at the bottleneck.
    """
    def __init__(self, in_ch=4, base_ch=64, t_emb_dim=128, use_learnable_t_emb=False, use_attention=False):
        super().__init__()
        self.use_attention = use_attention
        self.t_emb_dim = t_emb_dim
        if use_learnable_t_emb:
            # Table size is the global T, so timesteps must lie in [0, T).
            self.t_emb = LearnableTimestepEmbedding(t_emb_dim, max_steps=T)
        else:
            self.t_emb = SinusoidalPosEmb(t_emb_dim)
        # map timestep emb to channels: [B, t_emb_dim] -> [B, base_ch*2] (bottleneck width)
        self.time_mlp = nn.Sequential(nn.Linear(t_emb_dim, t_emb_dim*2), nn.SiLU(), nn.Linear(t_emb_dim*2, base_ch*2))


        # Downsampling path (shape comments assume 32x32 inputs)
        self.conv1 = SmallConvBlock(in_ch, base_ch) # Output: (B, base_ch, 32, 32)
        self.conv2 = SmallConvBlock(base_ch, base_ch*2) # Output: (B, base_ch*2, 16, 16) after pooling
        self.pool = nn.AvgPool2d(2)
        # Bottleneck
        self.conv3 = SmallConvBlock(base_ch*2, base_ch*2) # Output: (B, base_ch*2, 8, 8) after pooling
        # Upsampling path
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        # First decoder stage: consumes upsampled bottleneck (base_ch*2) + skip h2 (base_ch*2)
        self.conv_up1 = SmallConvBlock(base_ch*4, base_ch*2) # Output: (B, base_ch*2, 16, 16)
        # Second decoder stage: consumes upsampled out1 (base_ch*2) + skip h1 (base_ch)
        self.conv_up2 = SmallConvBlock(base_ch*3, base_ch) # Output: (B, base_ch, 32, 32)

        self.final = nn.Conv2d(base_ch, 3, kernel_size=1) # Predict noise for 3 color channels (Output: B, 3, 32, 32)


        if use_attention:
            # simple self-attention at bottleneck (no norm; added residually in forward)
            self.attn = nn.MultiheadAttention(embed_dim=base_ch*2, num_heads=4, batch_first=True)
        else:
            self.attn = None

    def forward(self, x, t, cond):
        """Predict the noise in ``x`` at timestep ``t`` given the grayscale ``cond``.

        Args:
            x: noisy color image [B,3,H,W].
            t: integer timesteps, one per batch element [B].
            cond: grayscale conditioning image [B,1,H,W].

        Returns:
            Predicted noise with shape [B,3,H,W].
        """
        # x: noisy color image [B,3,H,W]
        # cond: grayscale [B,1,H,W]
        # t: tensor of timesteps [B]
        B = x.shape[0]
        # concat conditioning
        h = torch.cat([x, cond], dim=1) # (B, 4, 32, 32)

        # Downsampling
        h1 = self.conv1(h) # (B, base_ch, 32, 32)
        h2 = self.conv2(self.pool(h1)) # (B, base_ch*2, 16, 16)

        # Bottleneck
        b = self.conv3(self.pool(h2)) # (B, base_ch*2, 8, 8)

        # Optional attention
        if self.attn is not None:
            # flatten spatial dims so each bottleneck pixel is a token of width C
            B, C, H, W = b.shape
            flat = b.view(B, C, H*W).permute(0,2,1) # [B, HW, C]
            attn_out, _ = self.attn(flat, flat, flat)
            attn_out = attn_out.permute(0,2,1).view(B, C, H, W)
            b = b + attn_out

        # Time embedding
        # NOTE: t_emb must return [B, t_emb_dim] so this view yields [B, base_ch*2, 1, 1];
        # a batch-independent embedding would be silently reshaped into the wrong layout.
        t_emb = self.t_emb(t)
        t_m = self.time_mlp(t_emb).view(B, -1, 1, 1)
        # Broadcast-add to bottleneck
        b = b + t_m

        # Upsampling path
        up1 = self.up(b) # (B, base_ch*2, 16, 16)
        cat1 = torch.cat([up1, h2], dim=1) # (B, base_ch*4, 16, 16)
        out1 = self.conv_up1(cat1) # (B, base_ch*2, 16, 16)

        up2 = self.up(out1) # (B, base_ch*2, 32, 32)
        cat2 = torch.cat([up2, h1], dim=1) # (B, base_ch*3, 32, 32)
        out2 = self.conv_up2(cat2) # (B, base_ch, 32, 32)

        # Final output
        out = self.final(out2) # (B, 3, 32, 32)
        return out

In [None]:
model = TinyUNet(
    in_ch=4,
    base_ch=64,
    t_emb_dim=128,
    use_learnable_t_emb=CFG['use_learnable_t_emb'],
    use_attention=CFG['use_attention'],
).to(CFG['device'])
print(model)


def count_params(m):
    """Return the number of trainable (``requires_grad``) scalar parameters of module ``m``."""
    total = 0
    for param in m.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


print('Trainable params:', count_params(model))

In [None]:
def q_sample(x_start, t, noise=None):
    """Sample x_t ~ q(x_t | x_0) in closed form.

    Computes ``sqrt(abar_t) * x_start + sqrt(1 - abar_t) * noise``, where the
    coefficients are looked up per batch element from the global schedule.

    Args:
        x_start: clean images, ``[B, ...]``.
        t: integer timesteps, shape ``[B]``.
        noise: optional noise tensor shaped like ``x_start``; drawn from N(0, I) if omitted.

    Returns:
        Noised images with the same shape as ``x_start``.
    """
    noise = torch.randn_like(x_start) if noise is None else noise
    dev, shp = x_start.device, x_start.shape
    signal_scale = extract(sqrt_alphas_cumprod, t, shp).to(dev)
    noise_scale = extract(sqrt_one_minus_alphas_cumprod, t, shp).to(dev)
    return signal_scale * x_start + noise_scale * noise


# Training objective: mean squared error between the true and the predicted noise.
mse = nn.MSELoss(reduction='mean')

In [None]:
optim = torch.optim.AdamW(model.parameters(), lr=CFG['lr'], weight_decay=1e-6)


def train_one_epoch(epoch):
    """Train ``model`` for one pass over ``train_loader`` with the noise-prediction loss.

    For each batch a random timestep per sample is drawn, the color target is
    noised with ``q_sample``, and the model is regressed (MSE) onto the injected
    noise given the grayscale condition. If ``CFG['max_train_batches']`` is truthy,
    at most that many batches are processed.

    Args:
        epoch: epoch number, used only in the progress-bar description.

    Returns:
        Mean loss over the batches actually processed (NaN if none were).
    """
    model.train()
    limit = CFG['max_train_batches']
    total = len(train_loader) if not limit else min(limit, len(train_loader))
    running_loss = 0.0
    n_batches = 0
    # Context manager closes the bar even when we break out early.
    with tqdm(train_loader, total=total) as pbar:
        for gray, color in pbar:
            gray = gray.to(CFG['device'])
            color = color.to(CFG['device'])
            batch_size = color.shape[0]
            t = torch.randint(0, T, (batch_size,), device=CFG['device']).long()
            noise = torch.randn_like(color)
            x_t = q_sample(color, t, noise=noise)
            # model predicts noise given x_t and cond
            pred_noise = model(x_t, t, gray)
            loss = mse(noise, pred_noise)
            optim.zero_grad()
            loss.backward()
            optim.step()
            running_loss += loss.item()
            n_batches += 1
            # Refresh on batches 0, 50, 100, ... with the running mean so far.
            if (n_batches - 1) % 50 == 0:
                pbar.set_description(f"Epoch {epoch} loss {running_loss/n_batches:.4f}")
            # Stop right after the last allowed batch, so no extra batch is loaded
            # and the average below divides by the true number of processed batches.
            if limit and n_batches >= limit:
                break
    return running_loss / n_batches if n_batches else float('nan')

In [None]:
def p_sample(model, x_t, t, cond):
    """One ancestral DDPM reverse step x_t -> x_{t-1} with a noise-predicting model.

    Uses the posterior mean of q(x_{t-1} | x_t, x_0) (Ho et al. 2020, eq. 7) with
    x_0 estimated from the predicted noise, and variance sigma_t^2 = beta_t.

    Args:
        model: network called as ``model(x_t, t, cond)`` returning predicted noise
            shaped like ``x_t``.
        x_t: current noisy sample, ``[B, C, H, W]``.
        t: integer timestep tensor, either one per batch element (shape ``[B]``) or
            a single element that is broadcast to the whole batch.
        cond: conditioning tensor passed straight to ``model``.

    Returns:
        x_{t-1} with the same shape as ``x_t``; batch elements with ``t == 0`` get
        the mean without added noise.
    """
    device = x_t.device
    batch = x_t.shape[0]
    t = t.to(device=device, dtype=torch.long).reshape(-1)
    if t.numel() == 1:
        t = t.expand(batch)
    bshape = (-1,) + (1,) * (x_t.dim() - 1)

    def _coef(schedule):
        # schedule[t] reshaped to [B, 1, 1, 1] for broadcasting against x_t.
        return schedule.to(device)[t].view(bshape)

    # \bar{alpha}_{t-1}, with \bar{alpha}_{-1} := 1 so that t == 0 is well defined.
    alphas_cumprod_prev = torch.cat([alphas_cumprod.new_ones(1), alphas_cumprod[:-1]])
    beta_t = _coef(betas_t)
    alpha_t = _coef(alphas)
    alpha_cum_t = _coef(alphas_cumprod)
    alpha_cum_prev = _coef(alphas_cumprod_prev)

    pred_noise = model(x_t, t, cond)
    x0_pred = (x_t - torch.sqrt(1 - alpha_cum_t) * pred_noise) / torch.sqrt(alpha_cum_t)

    # Posterior mean: coef_x0 * x0 + coef_xt * x_t.
    coef_x0 = torch.sqrt(alpha_cum_prev) * beta_t / (1 - alpha_cum_t)
    coef_xt = torch.sqrt(alpha_t) * (1 - alpha_cum_prev) / (1 - alpha_cum_t)
    mean = coef_x0 * x0_pred + coef_xt * x_t

    # Per-sample mask instead of `if t > 0`, which is ambiguous for a [B] tensor.
    nonzero = (t > 0)
    if not bool(nonzero.any()):
        return mean
    noise = torch.randn_like(x_t)
    return mean + nonzero.to(x_t.dtype).view(bshape) * torch.sqrt(beta_t) * noise

In [None]:
def ddim_sample(model, cond, shape, steps=50, eta=0.0):
    """Draw samples with DDIM using ``steps`` timesteps from ``T - 1`` down to ``0``.

    Args:
        model: noise-prediction network called as ``model(x, t, cond)``.
        cond: grayscale conditioning batch ``[B, 1, H, W]``; its device is used for sampling.
        shape: shape of the sample ``(B, 3, H, W)``; ``shape[0]`` should match ``cond``.
        steps: number of (evenly spaced) timesteps; must be >= 1.
        eta: DDIM stochasticity; 0.0 gives deterministic sampling.

    Returns:
        The final x0 prediction, shaped ``shape`` (not clamped).

    Raises:
        ValueError: if ``steps < 1``.
    """
    if steps < 1:
        raise ValueError(f"ddim_sample needs steps >= 1, got {steps}")
    device = cond.device
    B = shape[0]
    x = torch.randn(shape, device=device)
    # Timestep schedule T-1 -> 0 (inclusive); kept on CPU and read back with .item().
    seq = torch.linspace(T-1, 0, steps, dtype=torch.long)
    alphas_cum = alphas_cumprod.to(device)
    for i in range(steps):
        t = torch.full((B,), int(seq[i].item()), device=device, dtype=torch.long)
        with torch.no_grad():
            pred_noise = model(x, t, cond)
        alpha_t = extract(alphas_cum, t, x.shape).to(device)
        x0_pred = (x - torch.sqrt(1 - alpha_t) * pred_noise) / torch.sqrt(alpha_t)
        if i == steps-1:
            # Last step: return the clean-image estimate directly.
            x = x0_pred
            continue
        t_next = torch.full((B,), int(seq[i+1].item()), device=device, dtype=torch.long)
        alpha_next = extract(alphas_cum, t_next, x.shape).to(device)
        sigma = eta * torch.sqrt((1 - alpha_next) / (1 - alpha_t) * (1 - alpha_t/alpha_next))
        # Mathematically >= 0 for eta <= 1, but rounding (or eta > 1) can make it
        # slightly negative, which would turn the whole sample into NaN.
        dir_coef = torch.sqrt((1.0 - alpha_next - sigma**2).clamp(min=0.0))
        noise = sigma * torch.randn_like(x)
        x = torch.sqrt(alpha_next) * x0_pred + dir_coef * pred_noise + noise
    return x

In [None]:
# Main loop: one training epoch, then a qualitative check that colorizes the first
# (up to) 8 images of the validation set and displays gray | ground truth | sample.
best_val = 1e9  # NOTE(review): not read anywhere in the visible code.
for epoch in range(1, CFG['epochs']+1):
    t0 = time.time()
    loss = train_one_epoch(epoch)
    t1 = time.time()
    print(f"Epoch {epoch} completed in {t1-t0:.1f}s, loss {loss:.4f}")
    # quick validation visualization: take a batch from val set and sample
    # (eval mode here; train_one_epoch switches back to train mode next epoch)
    model.eval()
    with torch.no_grad():
        # Fresh iterator each epoch; val_loader has shuffle=False, so this is the same batch every time.
        gray, color = next(iter(val_loader))
        gray = gray.to(CFG['device'])
        color = color.to(CFG['device'])
        B = min(8, color.shape[0])
        cond = gray[:B]
        if CFG['use_ddim']:
            samples = ddim_sample(model, cond, shape=(B,3,CFG['img_size'],CFG['img_size']), steps=CFG['sample_steps'], eta=0.0)
        else:
            # naive full ancestral sampling (slow): T reverse steps from pure noise
            # NOTE(review): p_sample is given a per-sample t of shape [B] here.
            x = torch.randn((B,3,CFG['img_size'],CFG['img_size']), device=CFG['device'])
            for t_ in reversed(range(T)):
                t = torch.full((B,), t_, device=CFG['device'], dtype=torch.long)
                x = p_sample(model, x, t, cond)
            samples = x
        # clamp to [0,1] (the data range produced by ToTensor)
        samples = samples.clamp(0,1)
        # show grayscale, ground truth, and sample
        cat = torch.cat([cond.repeat(1,3,1,1)[:B], color[:B], samples[:B]], dim=0)
        show_tensor_images(cat, nrow=B, title=f'Epoch {epoch}: Input Gray | GT Color | Sampled Color')

In [None]:
# Persist the trained weights together with the config used to build the model.
torch.save(
    obj={'model_state_dict': model.state_dict(), 'cfg': CFG},
    f='colorization_diffusion_small.pth',
)
print('Saved model to colorization_diffusion_small.pth')

In [None]:
# Reload the saved checkpoint and colorize one validation batch with it.
# All tensors/modules are placed on the *current* CFG['device'] (not the device
# recorded in the checkpoint's cfg), so a GPU-trained checkpoint also loads on CPU.
checkpoint = torch.load('colorization_diffusion_small.pth', map_location=CFG['device'])

loaded_cfg = checkpoint['cfg']
# NOTE: base_ch / t_emb_dim must match training; a learnable timestep table is
# sized by the global T, which must equal loaded_cfg['timesteps'].
loaded_model = TinyUNet(
    in_ch=4,
    base_ch=64,
    t_emb_dim=128,
    use_learnable_t_emb=loaded_cfg['use_learnable_t_emb'],
    use_attention=loaded_cfg['use_attention']
).to(CFG['device'])
loaded_model.load_state_dict(checkpoint['model_state_dict'])
loaded_model.eval()
print('Loaded model from checkpoint.')

with torch.no_grad():
    gray_test, color_test = next(iter(val_loader))
    gray_test = gray_test.to(CFG['device'])
    color_test = color_test.to(CFG['device'])

    B_test = min(8, color_test.shape[0])
    cond_test = gray_test[:B_test]
    gt_color_test = color_test[:B_test]

    if loaded_cfg['use_ddim']:
        generated_samples = ddim_sample(
            loaded_model,
            cond_test,
            shape=(B_test, 3, loaded_cfg['img_size'], loaded_cfg['img_size']),
            steps=loaded_cfg['sample_steps'],
            eta=0.0
        )
    else:
        generated_samples = torch.randn((B_test,3,loaded_cfg['img_size'],loaded_cfg['img_size']), device=CFG['device'])
        for t_step in reversed(range(T)):
            t = torch.full((B_test,), t_step, device=CFG['device'], dtype=torch.long)
            generated_samples = p_sample(loaded_model, generated_samples, t, cond_test)

    # Same [0, 1] display/metric range as the ToTensor() inputs.
    generated_samples = generated_samples.clamp(0,1)

    display_images = torch.cat([
        cond_test.repeat(1,3,1,1)[:B_test],
        gt_color_test,
        generated_samples
    ], dim=0)
    show_tensor_images(display_images, nrow=B_test, title='Loaded Model: Input Gray | GT Color | Generated Color')

In [None]:
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import numpy as np

# skimage expects channel-last arrays: (B, C, H, W) -> (B, H, W, C).
gt_color_np = gt_color_test.permute(0, 2, 3, 1).cpu().numpy()
generated_samples_np = generated_samples.permute(0, 2, 3, 1).cpu().numpy()

# Per-image scores; images are in [0, 1], hence data_range=1.0.
psnr_values = [
    psnr(gt_color_np[k], generated_samples_np[k], data_range=1.0)
    for k in range(B_test)
]
ssim_values = [
    ssim(gt_color_np[k], generated_samples_np[k], data_range=1.0, channel_axis=2)
    for k in range(B_test)
]

print(f'Average PSNR: {np.mean(psnr_values):.2f} dB')
print(f'Average SSIM: {np.mean(ssim_values):.4f}')