In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor, Resize, Compose, ToPILImage, RandomCrop, Lambda
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import os
from tqdm import tqdm
import numpy as np

from diffusion_model import UNet

2025-10-03 07:14:39.747356: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1759475679.760103   55456 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1759475679.764020   55456 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-10-03 07:14:39.776278: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# --- Configuration ---
IMG_SIZE = 64  # Target HR image size 64x64
UPSCALE_FACTOR = 2  # Double resolution (LR inputs are IMG_SIZE // UPSCALE_FACTOR)
BATCH_SIZE = 16
NUM_EPOCHS = 20
LEARNING_RATE = 1e-4
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
TRAIN_DATA_DIR = 'data/train'
MODEL_SAVE_PATH = 'diffusion_upscaler.pth'
LOG_IMAGE_EPOCHS = 10  # Log sample images every N epochs
SEED = 42  # Seed for torch's global RNG so Restart & Run All is reproducible

# Downsample-then-upsample must land back on exactly IMG_SIZE.
assert IMG_SIZE % UPSCALE_FACTOR == 0, "IMG_SIZE must be divisible by UPSCALE_FACTOR"

# Seeds torch's global generator; crop positions, shuffling, timesteps and
# noise in this notebook are all drawn through torch.
torch.manual_seed(SEED)

# Soft check only: the dataset cell will fail if there is nothing to load.
if not os.path.exists(TRAIN_DATA_DIR) or not os.listdir(TRAIN_DATA_DIR):
    print(f"Data directory '{TRAIN_DATA_DIR}' is empty or does not exist.")

# --- Hyperparameters: linear DDPM noise schedule ---
TIMESTEPS = 1000
BETA_START = 0.0001
BETA_END = 0.02
betas = torch.linspace(BETA_START, BETA_END, TIMESTEPS, device=DEVICE)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)  # alpha_bar_t
# alpha_bar_{t-1}, with alpha_bar_{-1} defined as 1.
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
# Variance of q(x_{t-1} | x_t, x_0); exactly 0 at t = 0.
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

print(f"Using device: {DEVICE}")

Using device: cuda


In [3]:
# --- Helper Functions ---
def extract(a, t, x_shape):
    """Gather per-sample schedule values and reshape them for broadcasting.

    Args:
        a: 1-D schedule tensor indexed by timestep (e.g. ``betas``).
        t: 1-D integer tensor of timesteps, one entry per batch element.
        x_shape: shape of the tensor the result will be multiplied with.

    Returns:
        Tensor of shape ``(len(t), 1, ..., 1)`` with ``len(x_shape) - 1``
        trailing singleton dims, on the same device as ``t``.
    """
    batch_size = t.shape[0]
    # Index on the schedule's own device instead of a global DEVICE, so this
    # works wherever `a` lives; hand the result back on the caller's device.
    out = a.gather(-1, t.to(a.device))
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

def q_sample(x_start, t, noise=None):
    """Diffuse clean images to timestep ``t`` in closed form.

    Computes x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps.

    Args:
        x_start: clean images x_0.
        t: per-sample timestep tensor.
        noise: optional eps; fresh standard-normal noise shaped like
            ``x_start`` is drawn when omitted.

    Returns:
        The noised images x_t.
    """
    eps = torch.randn_like(x_start) if noise is None else noise
    signal_scale = extract(sqrt_alphas_cumprod, t, x_start.shape)
    noise_scale = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)
    return signal_scale * x_start + noise_scale * eps

def p_losses(denoise_model, x_start, t, low_res_img, noise=None):
    """Noise-prediction (epsilon) MSE loss for one batch.

    Args:
        denoise_model: network called as ``denoise_model(x_t, t, low_res_img)``
            that predicts the noise added to ``x_start``.
        x_start: clean high-resolution images x_0.
        t: per-sample timestep tensor.
        low_res_img: conditioning image forwarded to the model unchanged.
        noise: optional target noise; drawn from N(0, I) when omitted.

    Returns:
        Scalar mean-squared error between predicted and true noise.
    """
    target_noise = noise if noise is not None else torch.randn_like(x_start)
    noisy_images = q_sample(x_start=x_start, t=t, noise=target_noise)
    predicted = denoise_model(noisy_images, t, low_res_img)
    # MSE is symmetric, so (input, target) order does not affect the value.
    return F.mse_loss(predicted, target_noise)

# --- Sampling functions for logging ---
@torch.no_grad()
def p_sample(model, x, t, t_index, low_res_img):
    """Take one DDPM reverse step from x_t to x_{t-1}.

    Args:
        model: noise-prediction network called as ``model(x, t, low_res_img)``.
        x: current sample x_t.
        t: per-sample timestep tensor (every entry equals ``t_index`` at the
            call site in this notebook).
        t_index: Python int timestep; at 0 the mean is returned with no noise.
        low_res_img: conditioning image forwarded to the model unchanged.

    Returns:
        The next sample x_{t-1}.
    """
    eps_pred = model(x, t, low_res_img)
    beta = extract(betas, t, x.shape)
    sigma_bar = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    inv_sqrt_alpha = extract(sqrt_recip_alphas, t, x.shape)

    # mu = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta)
    mean = inv_sqrt_alpha * (x - beta * eps_pred / sigma_bar)

    if t_index == 0:
        return mean  # last step is deterministic

    variance = extract(posterior_variance, t, x.shape)
    gaussian = torch.randn_like(x)
    return mean + torch.sqrt(variance) * gaussian

@torch.no_grad()
def sample_and_log_images(model, low_res_upscaled, hr_image, epoch, writer,
                          tag='samples/lowres_generated_highres'):
    """Run the full reverse diffusion chain on a fixed batch and log a comparison.

    Starts from Gaussian noise shaped like ``hr_image`` and applies ``p_sample``
    for t = TIMESTEPS-1 down to 0, conditioned on ``low_res_upscaled``. The
    first image of the batch is logged as a 3-panel grid in the order
    low-res | generated | high-res.

    Args:
        model: conditional noise-prediction network.
        low_res_upscaled: conditioning batch in [-1, 1] (assumed to already be
            on the model's device — TODO confirm at other call sites).
        hr_image: ground-truth batch in [-1, 1]; its shape sets the sample shape.
        epoch: used as the TensorBoard global step.
        writer: a ``SummaryWriter``.
        tag: TensorBoard image tag. It is kept constant across epochs so the
            step slider browses one card instead of creating a card per epoch.
    """
    # Restore whatever mode the caller had, even if sampling raises.
    was_training = model.training
    model.eval()
    try:
        shape = hr_image.shape
        device = next(model.parameters()).device

        img = torch.randn(shape, device=device)
        for i in reversed(range(TIMESTEPS)):
            t = torch.full((shape[0],), i, device=device, dtype=torch.long)
            img = p_sample(model, img, t, i, low_res_upscaled)

        # Map [-1, 1] -> [0, 1] for image display.
        generated_img = (img.clamp(-1, 1) + 1) / 2
        hr_image_grid = (hr_image.clamp(-1, 1) + 1) / 2
        low_res_grid = (low_res_upscaled.clamp(-1, 1) + 1) / 2

        # Log the first image of the batch.
        grid = make_grid([low_res_grid[0], generated_img[0], hr_image_grid[0]], nrow=3)
        writer.add_image(tag, grid, epoch)
    finally:
        model.train(was_training)


# --- Dataset ---
class SuperResDataset(Dataset):
    """Random high-resolution crops from a flat folder of images, scaled to [-1, 1].

    Intended for the DIV2K "Train Data (HR images)" set
    (https://data.vision.ee.ethz.ch/cvl/DIV2K/). Each item is only the HR crop;
    the training loop builds the low-resolution conditioning image from it.
    """

    # Only files with these suffixes are loaded. Everything else in the folder
    # (hidden files, `.ipynb_checkpoints/`, sub-folders, notes) is skipped, so
    # it cannot crash Image.open.
    IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff', '.webp')

    def __init__(self, image_dir, img_size):
        """Index the image files in ``image_dir``.

        Args:
            image_dir: directory containing the HR training images.
            img_size: side length of the square random crop. RandomCrop is used
                without padding, so every image must be at least this large.

        Raises:
            ValueError: if ``image_dir`` contains no image files.
        """
        # Sorted so the dataset's index order does not depend on the filesystem.
        self.image_filenames = sorted(
            os.path.join(image_dir, name)
            for name in os.listdir(image_dir)
            if name.lower().endswith(self.IMG_EXTENSIONS)
            and os.path.isfile(os.path.join(image_dir, name))
        )
        if not self.image_filenames:
            raise ValueError(f"No image files found in '{image_dir}'.")
        self.transform = Compose([
            RandomCrop(img_size),
            ToTensor(),  # Scales to [0, 1]
            Lambda(lambda t: (t * 2) - 1)  # Scale to [-1, 1]
        ])

    def __getitem__(self, index):
        """Return one random RGB crop as a (3, img_size, img_size) tensor in [-1, 1]."""
        with Image.open(self.image_filenames[index]) as img:
            image = img.convert('RGB')
        return self.transform(image)

    def __len__(self):
        return len(self.image_filenames)


In [4]:
# --- Training ---
LOG_DIR = "/home/jovyan/logs/fit/diffusion_super_res"  # TensorBoard run directory
writer = SummaryWriter(LOG_DIR)


def make_low_res_condition(hr_images):
    """Degrade HR images and resize them back to HR size for conditioning.

    Returns ``(low_res, low_res_upscaled)``:
    - ``low_res`` is the bicubic downsample by ``UPSCALE_FACTOR``.
    - ``low_res_upscaled`` is its bicubic upsample back to ``IMG_SIZE`` x ``IMG_SIZE``.
    """
    low_res = F.interpolate(hr_images, scale_factor=1 / UPSCALE_FACTOR, mode='bicubic', antialias=True)
    low_res_upscaled = F.interpolate(low_res, size=(IMG_SIZE, IMG_SIZE), mode='bicubic', antialias=True)
    return low_res, low_res_upscaled


dataset = SuperResDataset(TRAIN_DATA_DIR, IMG_SIZE)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# Fixed batch so the logged samples are comparable across epochs.
fixed_batch = next(iter(dataloader)).to(DEVICE)
fixed_lr, fixed_lr_upscaled = make_low_res_condition(fixed_batch)

model = UNet().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

print("Starting training...")
global_step = 0
try:
    for epoch in range(NUM_EPOCHS):
        # Image logging switches the model to eval, so re-enable training each epoch.
        model.train()
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")

        for batch in progress_bar:
            optimizer.zero_grad()

            hr_images = batch.to(DEVICE)  # High-res crops in [-1, 1]
            _, low_res_upscaled = make_low_res_condition(hr_images)

            # randint already returns int64, the dtype gather() needs for indices.
            t = torch.randint(0, TIMESTEPS, (hr_images.shape[0],), device=DEVICE)

            loss = p_losses(model, hr_images, t, low_res_upscaled)
            loss.backward()
            optimizer.step()

            # Single host sync per step for both the progress bar and TensorBoard.
            loss_value = loss.item()
            progress_bar.set_postfix(loss=f'{loss_value:.4f}')
            writer.add_scalar('Training Loss', loss_value, global_step)
            global_step += 1

        # Log images every N epochs
        if (epoch + 1) % LOG_IMAGE_EPOCHS == 0:
            print(f"Logging images for epoch {epoch+1}...")
            sample_and_log_images(model, fixed_lr_upscaled, fixed_batch, epoch + 1, writer)

    print("Finished Training")
    torch.save(model.state_dict(), MODEL_SAVE_PATH)
    print(f"Model saved to {MODEL_SAVE_PATH}")
finally:
    # Flush TensorBoard events even if training is interrupted or fails.
    writer.close()

Starting training...


Epoch 1/20: 100%|██████████| 50/50 [00:51<00:00,  1.02s/it, loss=0.4589]
Epoch 2/20: 100%|██████████| 50/50 [00:51<00:00,  1.03s/it, loss=0.2355]
Epoch 3/20: 100%|██████████| 50/50 [00:50<00:00,  1.02s/it, loss=0.1939]
Epoch 4/20: 100%|██████████| 50/50 [00:50<00:00,  1.01s/it, loss=0.0991]
Epoch 5/20: 100%|██████████| 50/50 [00:50<00:00,  1.02s/it, loss=0.1046]
Epoch 6/20: 100%|██████████| 50/50 [00:50<00:00,  1.02s/it, loss=0.0694]
Epoch 7/20: 100%|██████████| 50/50 [00:50<00:00,  1.02s/it, loss=0.1004]
Epoch 8/20: 100%|██████████| 50/50 [00:50<00:00,  1.01s/it, loss=0.0835]
Epoch 9/20: 100%|██████████| 50/50 [00:51<00:00,  1.03s/it, loss=0.0622]
Epoch 10/20: 100%|██████████| 50/50 [00:51<00:00,  1.03s/it, loss=0.0524]


Logging images for epoch 10...


Epoch 11/20: 100%|██████████| 50/50 [00:51<00:00,  1.03s/it, loss=0.1007]
Epoch 12/20: 100%|██████████| 50/50 [00:50<00:00,  1.02s/it, loss=0.1008]
Epoch 13/20: 100%|██████████| 50/50 [00:50<00:00,  1.01s/it, loss=0.0921]
Epoch 14/20: 100%|██████████| 50/50 [00:51<00:00,  1.02s/it, loss=0.0315]
Epoch 15/20: 100%|██████████| 50/50 [00:50<00:00,  1.02s/it, loss=0.0479]
Epoch 16/20: 100%|██████████| 50/50 [00:50<00:00,  1.02s/it, loss=0.0285]
Epoch 17/20: 100%|██████████| 50/50 [00:50<00:00,  1.02s/it, loss=0.0278]
Epoch 18/20: 100%|██████████| 50/50 [00:51<00:00,  1.02s/it, loss=0.0311]
Epoch 19/20: 100%|██████████| 50/50 [00:51<00:00,  1.02s/it, loss=0.0783]
Epoch 20/20: 100%|██████████| 50/50 [00:51<00:00,  1.02s/it, loss=0.0253]


Logging images for epoch 20...
Finished Training
Model saved to diffusion_upscaler.pth
