In [5]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import math
import pandas as pd
import random
import wandb

import deepinv
from torchvision import datasets, transforms

from torch import nn, optim
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
from torch.utils.data import DataLoader, Dataset, ConcatDataset

from datasets import load_dataset
from datetime import datetime
from tqdm import tqdm
import matplotlib.pyplot as plt

from models.UNet import UNet
from data.ImageDataset import ImageDataset

from pathlib import Path

In [4]:
# dynamically select device
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [None]:
# model  = UNet(
#     in_channels=1,
#     out_channels=1,
#     channels=[64, 128, 256, 512, 512, 384, 256],
#     scales=[-1, -1, -1, 1, 1, 1, 0],
#     attentions=[False, True, False, False, False, True, False],
#     time_steps=1000
# ).to(device)

In [6]:
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.0,), (1.0,)),
])

dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(root="./data", train=True, download=True, transform=transform),
    batch_size=1,
    shuffle=True,
)

In [12]:
example = next(iter(dataloader))[0]

In [13]:
# load trained model
model_path = "/nfs/stak/users/negreanj/hpc-share/image-diffusion/checkpoints/diffusion-image-model/dim-2025_12_28_21_12/dim-2025_12_28_21_12_epoch99_end.pth"
chkpt = torch.load(model_path, weights_only=False, map_location=torch.device(device))

In [16]:
model = deepinv.models.DiffUNet(
    in_channels=1, out_channels=1, pretrained=None
).to(device)

model.load_state_dict(chkpt)
model.to(device)

RuntimeError: Error(s) in loading state_dict for DiffUNet:
	Missing key(s) in state_dict: "time_embed.0.weight", "time_embed.0.bias", "time_embed.2.weight", "time_embed.2.bias", "input_blocks.0.0.weight", "input_blocks.0.0.bias", "input_blocks.1.0.in_layers.0.weight", "input_blocks.1.0.in_layers.0.bias", "input_blocks.1.0.in_layers.2.weight", "input_blocks.1.0.in_layers.2.bias", "input_blocks.1.0.emb_layers.1.weight", "input_blocks.1.0.emb_layers.1.bias", "input_blocks.1.0.out_layers.0.weight", "input_blocks.1.0.out_layers.0.bias", "input_blocks.1.0.out_layers.3.weight", "input_blocks.1.0.out_layers.3.bias", "input_blocks.2.0.in_layers.0.weight", "input_blocks.2.0.in_layers.0.bias", "input_blocks.2.0.in_layers.2.weight", "input_blocks.2.0.in_layers.2.bias", "input_blocks.2.0.emb_layers.1.weight", "input_blocks.2.0.emb_layers.1.bias", "input_blocks.2.0.out_layers.0.weight", "input_blocks.2.0.out_layers.0.bias", "input_blocks.2.0.out_layers.3.weight", "input_blocks.2.0.out_layers.3.bias", "input_blocks.3.0.in_layers.0.weight", "input_blocks.3.0.in_layers.0.bias", "input_blocks.3.0.in_layers.2.weight", "input_blocks.3.0.in_layers.2.bias", "input_blocks.3.0.emb_layers.1.weight", "input_blocks.3.0.emb_layers.1.bias", "input_blocks.3.0.out_layers.0.weight", "input_blocks.3.0.out_layers.0.bias", "input_blocks.3.0.out_layers.3.weight", "input_blocks.3.0.out_layers.3.bias", "input_blocks.4.0.in_layers.0.weight", "input_blocks.4.0.in_layers.0.bias", "input_blocks.4.0.in_layers.2.weight", "input_blocks.4.0.in_layers.2.bias", "input_blocks.4.0.emb_layers.1.weight", "input_blocks.4.0.emb_layers.1.bias", "input_blocks.4.0.out_layers.0.weight", "input_blocks.4.0.out_layers.0.bias", "input_blocks.4.0.out_layers.3.weight", "input_blocks.4.0.out_layers.3.bias", "input_blocks.5.0.in_layers.0.weight", "input_blocks.5.0.in_layers.0.bias", "input_blocks.5.0.in_layers.2.weight", "input_blocks.5.0.in_layers.2.bias", "input_blocks.5.0.emb_layers.1.weight", "input_blocks.5.0.emb_layers.1.bias", "input_blocks.5.0.out_layers.0.weight", "input_blocks.5.0.out_layers.0.bias", "input_blocks.5.0.out_layers.3.weight", "input_blocks.5.0.out_layers.3.bias", "input_blocks.5.0.skip_connection.weight", "input_blocks.5.0.skip_connection.bias", "input_blocks.6.0.in_layers.0.weight", "input_blocks.6.0.in_layers.0.bias", "input_blocks.6.0.in_layers.2.weight", "input_blocks.6.0.in_layers.2.bias", "input_blocks.6.0.emb_layers.1.weight", "input_blocks.6.0.emb_layers.1.bias", "input_blocks.6.0.out_layers.0.weight", "input_blocks.6.0.out_layers.0.bias", "input_blocks.6.0.out_layers.3.weight", "input_blocks.6.0.out_layers.3.bias", "input_blocks.7.0.in_layers.0.weight", "input_blocks.7.0.in_layers.0.bias", "input_blocks.7.0.in_layers.2.weight", "input_blocks.7.0.in_layers.2.bias", "input_blocks.7.0.emb_layers.1.weight", "input_blocks.7.0.emb_layers.1.bias", "input_blocks.7.0.out_layers.0.weight", "input_blocks.7.0.out_layers.0.bias", "input_blocks.7.0.out_layers.3.weight", "input_blocks.7.0.out_layers.3.bias", "input_blocks.8.0.in_layers.0.weight", "input_blocks.8.0.in_layers.0.bias", "input_blocks.8.0.in_layers.2.weight", "input_blocks.8.0.in_layers.2.bias", "input_blocks.8.0.emb_layers.1.weight", "input_blocks.8.0.emb_layers.1.bias", "input_blocks.8.0.out_layers.0.weight", "input_blocks.8.0.out_layers.0.bias", "input_blocks.8.0.out_layers.3.weight", "input_blocks.8.0.out_layers.3.bias", "input_blocks.9.0.in_layers.0.weight", "input_blocks.9.0.in_layers.0.bias", "input_blocks.9.0.in_layers.2.weight", "input_blocks.9.0.in_layers.2.bias", "input_blocks.9.0.emb_layers.1.weight", "input_blocks.9.0.emb_layers.1.bias", "input_blocks.9.0.out_layers.0.weight", "input_blocks.9.0.out_layers.0.bias", "input_blocks.9.0.out_layers.3.weight", "input_blocks.9.0.out_layers.3.bias", "input_blocks.9.0.skip_connection.weight", "input_blocks.9.0.skip_connection.bias", "input_blocks.9.1.norm.weight", "input_blocks.9.1.norm.bias", "input_blocks.9.1.qkv.weight", "input_blocks.9.1.qkv.bias", "input_blocks.9.1.proj_out.weight", "input_blocks.9.1.proj_out.bias", "input_blocks.10.0.in_layers.0.weight", "input_blocks.10.0.in_layers.0.bias", "input_blocks.10.0.in_layers.2.weight", "input_blocks.10.0.in_layers.2.bias", "input_blocks.10.0.emb_layers.1.weight", "input_blocks.10.0.emb_layers.1.bias", "input_blocks.10.0.out_layers.0.weight", "input_blocks.10.0.out_layers.0.bias", "input_blocks.10.0.out_layers.3.weight", "input_blocks.10.0.out_layers.3.bias", "input_blocks.11.0.in_layers.0.weight", "input_blocks.11.0.in_layers.0.bias", "input_blocks.11.0.in_layers.2.weight", "input_blocks.11.0.in_layers.2.bias", "input_blocks.11.0.emb_layers.1.weight", "input_blocks.11.0.emb_layers.1.bias", "input_blocks.11.0.out_layers.0.weight", "input_blocks.11.0.out_layers.0.bias", "input_blocks.11.0.out_layers.3.weight", "input_blocks.11.0.out_layers.3.bias", "middle_block.0.in_layers.0.weight", "middle_block.0.in_layers.0.bias", "middle_block.0.in_layers.2.weight", "middle_block.0.in_layers.2.bias", "middle_block.0.emb_layers.1.weight", "middle_block.0.emb_layers.1.bias", "middle_block.0.out_layers.0.weight", "middle_block.0.out_layers.0.bias", "middle_block.0.out_layers.3.weight", "middle_block.0.out_layers.3.bias", "middle_block.1.norm.weight", "middle_block.1.norm.bias", "middle_block.1.qkv.weight", "middle_block.1.qkv.bias", "middle_block.1.proj_out.weight", "middle_block.1.proj_out.bias", "middle_block.2.in_layers.0.weight", "middle_block.2.in_layers.0.bias", "middle_block.2.in_layers.2.weight", "middle_block.2.in_layers.2.bias", "middle_block.2.emb_layers.1.weight", "middle_block.2.emb_layers.1.bias", "middle_block.2.out_layers.0.weight", "middle_block.2.out_layers.0.bias", "middle_block.2.out_layers.3.weight", "middle_block.2.out_layers.3.bias", "output_blocks.0.0.in_layers.0.weight", "output_blocks.0.0.in_layers.0.bias", "output_blocks.0.0.in_layers.2.weight", "output_blocks.0.0.in_layers.2.bias", "output_blocks.0.0.emb_layers.1.weight", "output_blocks.0.0.emb_layers.1.bias", "output_blocks.0.0.out_layers.0.weight", "output_blocks.0.0.out_layers.0.bias", "output_blocks.0.0.out_layers.3.weight", "output_blocks.0.0.out_layers.3.bias", "output_blocks.0.0.skip_connection.weight", "output_blocks.0.0.skip_connection.bias", "output_blocks.1.0.in_layers.0.weight", "output_blocks.1.0.in_layers.0.bias", "output_blocks.1.0.in_layers.2.weight", "output_blocks.1.0.in_layers.2.bias", "output_blocks.1.0.emb_layers.1.weight", "output_blocks.1.0.emb_layers.1.bias", "output_blocks.1.0.out_layers.0.weight", "output_blocks.1.0.out_layers.0.bias", "output_blocks.1.0.out_layers.3.weight", "output_blocks.1.0.out_layers.3.bias", "output_blocks.1.0.skip_connection.weight", "output_blocks.1.0.skip_connection.bias", "output_blocks.1.1.in_layers.0.weight", "output_blocks.1.1.in_layers.0.bias", "output_blocks.1.1.in_layers.2.weight", "output_blocks.1.1.in_layers.2.bias", "output_blocks.1.1.emb_layers.1.weight", "output_blocks.1.1.emb_layers.1.bias", "output_blocks.1.1.out_layers.0.weight", "output_blocks.1.1.out_layers.0.bias", "output_blocks.1.1.out_layers.3.weight", "output_blocks.1.1.out_layers.3.bias", "output_blocks.2.0.in_layers.0.weight", "output_blocks.2.0.in_layers.0.bias", "output_blocks.2.0.in_layers.2.weight", "output_blocks.2.0.in_layers.2.bias", "output_blocks.2.0.emb_layers.1.weight", "output_blocks.2.0.emb_layers.1.bias", "output_blocks.2.0.out_layers.0.weight", "output_blocks.2.0.out_layers.0.bias", "output_blocks.2.0.out_layers.3.weight", "output_blocks.2.0.out_layers.3.bias", "output_blocks.2.0.skip_connection.weight", "output_blocks.2.0.skip_connection.bias", "output_blocks.2.1.norm.weight", "output_blocks.2.1.norm.bias", "output_blocks.2.1.qkv.weight", "output_blocks.2.1.qkv.bias", "output_blocks.2.1.proj_out.weight", "output_blocks.2.1.proj_out.bias", "output_blocks.3.0.in_layers.0.weight", "output_blocks.3.0.in_layers.0.bias", "output_blocks.3.0.in_layers.2.weight", "output_blocks.3.0.in_layers.2.bias", "output_blocks.3.0.emb_layers.1.weight", "output_blocks.3.0.emb_layers.1.bias", "output_blocks.3.0.out_layers.0.weight", "output_blocks.3.0.out_layers.0.bias", "output_blocks.3.0.out_layers.3.weight", "output_blocks.3.0.out_layers.3.bias", "output_blocks.3.0.skip_connection.weight", "output_blocks.3.0.skip_connection.bias", "output_blocks.3.1.norm.weight", "output_blocks.3.1.norm.bias", "output_blocks.3.1.qkv.weight", "output_blocks.3.1.qkv.bias", "output_blocks.3.1.proj_out.weight", "output_blocks.3.1.proj_out.bias", "output_blocks.3.2.in_layers.0.weight", "output_blocks.3.2.in_layers.0.bias", "output_blocks.3.2.in_layers.2.weight", "output_blocks.3.2.in_layers.2.bias", "output_blocks.3.2.emb_layers.1.weight", "output_blocks.3.2.emb_layers.1.bias", "output_blocks.3.2.out_layers.0.weight", "output_blocks.3.2.out_layers.0.bias", "output_blocks.3.2.out_layers.3.weight", "output_blocks.3.2.out_layers.3.bias", "output_blocks.4.0.in_layers.0.weight", "output_blocks.4.0.in_layers.0.bias", "output_blocks.4.0.in_layers.2.weight", "output_blocks.4.0.in_layers.2.bias", "output_blocks.4.0.emb_layers.1.weight", "output_blocks.4.0.emb_layers.1.bias", "output_blocks.4.0.out_layers.0.weight", "output_blocks.4.0.out_layers.0.bias", "output_blocks.4.0.out_layers.3.weight", "output_blocks.4.0.out_layers.3.bias", "output_blocks.4.0.skip_connection.weight", "output_blocks.4.0.skip_connection.bias", "output_blocks.5.0.in_layers.0.weight", "output_blocks.5.0.in_layers.0.bias", "output_blocks.5.0.in_layers.2.weight", "output_blocks.5.0.in_layers.2.bias", "output_blocks.5.0.emb_layers.1.weight", "output_blocks.5.0.emb_layers.1.bias", "output_blocks.5.0.out_layers.0.weight", "output_blocks.5.0.out_layers.0.bias", "output_blocks.5.0.out_layers.3.weight", "output_blocks.5.0.out_layers.3.bias", "output_blocks.5.0.skip_connection.weight", "output_blocks.5.0.skip_connection.bias", "output_blocks.5.1.in_layers.0.weight", "output_blocks.5.1.in_layers.0.bias", "output_blocks.5.1.in_layers.2.weight", "output_blocks.5.1.in_layers.2.bias", "output_blocks.5.1.emb_layers.1.weight", "output_blocks.5.1.emb_layers.1.bias", "output_blocks.5.1.out_layers.0.weight", "output_blocks.5.1.out_layers.0.bias", "output_blocks.5.1.out_layers.3.weight", "output_blocks.5.1.out_layers.3.bias", "output_blocks.6.0.in_layers.0.weight", "output_blocks.6.0.in_layers.0.bias", "output_blocks.6.0.in_layers.2.weight", "output_blocks.6.0.in_layers.2.bias", "output_blocks.6.0.emb_layers.1.weight", "output_blocks.6.0.emb_layers.1.bias", "output_blocks.6.0.out_layers.0.weight", "output_blocks.6.0.out_layers.0.bias", "output_blocks.6.0.out_layers.3.weight", "output_blocks.6.0.out_layers.3.bias", "output_blocks.6.0.skip_connection.weight", "output_blocks.6.0.skip_connection.bias", "output_blocks.7.0.in_layers.0.weight", "output_blocks.7.0.in_layers.0.bias", "output_blocks.7.0.in_layers.2.weight", "output_blocks.7.0.in_layers.2.bias", "output_blocks.7.0.emb_layers.1.weight", "output_blocks.7.0.emb_layers.1.bias", "output_blocks.7.0.out_layers.0.weight", "output_blocks.7.0.out_layers.0.bias", "output_blocks.7.0.out_layers.3.weight", "output_blocks.7.0.out_layers.3.bias", "output_blocks.7.0.skip_connection.weight", "output_blocks.7.0.skip_connection.bias", "output_blocks.7.1.in_layers.0.weight", "output_blocks.7.1.in_layers.0.bias", "output_blocks.7.1.in_layers.2.weight", "output_blocks.7.1.in_layers.2.bias", "output_blocks.7.1.emb_layers.1.weight", "output_blocks.7.1.emb_layers.1.bias", "output_blocks.7.1.out_layers.0.weight", "output_blocks.7.1.out_layers.0.bias", "output_blocks.7.1.out_layers.3.weight", "output_blocks.7.1.out_layers.3.bias", "output_blocks.8.0.in_layers.0.weight", "output_blocks.8.0.in_layers.0.bias", "output_blocks.8.0.in_layers.2.weight", "output_blocks.8.0.in_layers.2.bias", "output_blocks.8.0.emb_layers.1.weight", "output_blocks.8.0.emb_layers.1.bias", "output_blocks.8.0.out_layers.0.weight", "output_blocks.8.0.out_layers.0.bias", "output_blocks.8.0.out_layers.3.weight", "output_blocks.8.0.out_layers.3.bias", "output_blocks.8.0.skip_connection.weight", "output_blocks.8.0.skip_connection.bias", "output_blocks.9.0.in_layers.0.weight", "output_blocks.9.0.in_layers.0.bias", "output_blocks.9.0.in_layers.2.weight", "output_blocks.9.0.in_layers.2.bias", "output_blocks.9.0.emb_layers.1.weight", "output_blocks.9.0.emb_layers.1.bias", "output_blocks.9.0.out_layers.0.weight", "output_blocks.9.0.out_layers.0.bias", "output_blocks.9.0.out_layers.3.weight", "output_blocks.9.0.out_layers.3.bias", "output_blocks.9.0.skip_connection.weight", "output_blocks.9.0.skip_connection.bias", "output_blocks.9.1.in_layers.0.weight", "output_blocks.9.1.in_layers.0.bias", "output_blocks.9.1.in_layers.2.weight", "output_blocks.9.1.in_layers.2.bias", "output_blocks.9.1.emb_layers.1.weight", "output_blocks.9.1.emb_layers.1.bias", "output_blocks.9.1.out_layers.0.weight", "output_blocks.9.1.out_layers.0.bias", "output_blocks.9.1.out_layers.3.weight", "output_blocks.9.1.out_layers.3.bias", "output_blocks.10.0.in_layers.0.weight", "output_blocks.10.0.in_layers.0.bias", "output_blocks.10.0.in_layers.2.weight", "output_blocks.10.0.in_layers.2.bias", "output_blocks.10.0.emb_layers.1.weight", "output_blocks.10.0.emb_layers.1.bias", "output_blocks.10.0.out_layers.0.weight", "output_blocks.10.0.out_layers.0.bias", "output_blocks.10.0.out_layers.3.weight", "output_blocks.10.0.out_layers.3.bias", "output_blocks.10.0.skip_connection.weight", "output_blocks.10.0.skip_connection.bias", "output_blocks.11.0.in_layers.0.weight", "output_blocks.11.0.in_layers.0.bias", "output_blocks.11.0.in_layers.2.weight", "output_blocks.11.0.in_layers.2.bias", "output_blocks.11.0.emb_layers.1.weight", "output_blocks.11.0.emb_layers.1.bias", "output_blocks.11.0.out_layers.0.weight", "output_blocks.11.0.out_layers.0.bias", "output_blocks.11.0.out_layers.3.weight", "output_blocks.11.0.out_layers.3.bias", "output_blocks.11.0.skip_connection.weight", "output_blocks.11.0.skip_connection.bias", "out.0.weight", "out.0.bias", "out.2.weight", "out.2.bias". 
	Unexpected key(s) in state_dict: "epoch", "loss", "model_state_dict", "scheduler_state_dict", "optimizer_state_dict", "train_config", "model_config". 

In [None]:
# dry run
x = torch.randn(32, 1, 32, 32).to(device)
y = model(x, 100)
print(y.shape) # should be [32, 1, 32, 32]

In [None]:
# diffusion scheduler
beta = torch.linspace(1e-4, 0.02, 1000, requires_grad=False)
alpha = 1 - beta
alpha_hat = torch.cumprod(alpha, dim=0).requires_grad_(False)

In [None]:
batch = example.unsqueeze(0).to(device)
t = 50
noise = torch.randn(batch.size(), device=batch.device, dtype=batch.dtype).to(device)
diffuse_batch = math.sqrt(alpha_hat[t]) * batch + math.sqrt(1 - alpha_hat[t]) * noise
plt.imshow(diffuse_batch.squeeze(0).permute(1, 2, 0).cpu().numpy(), vmin=0, vmax=1)

In [None]:
pred_noise = model(diffuse_batch, t)
gen_image = (1 / torch.sqrt(alpha[t])) * (
            diffuse_batch - (beta[t] / torch.sqrt(1 - alpha_hat[t])) * pred_noise
        )
plt.imshow(gen_image.squeeze(0).permute(1, 2, 0).cpu().detach().numpy(), vmin=0, vmax=1)
print(gen_image - diffuse_batch)

In [None]:
image = diffuse_batch - math.sqrt(1 - alpha_hat[t]) * model(diffuse_batch, t)
image /= math.sqrt(alpha_hat[t])
plt.imshow(image.squeeze(0).permute(1, 2, 0).cpu().detach().numpy(), vmin=0, vmax=1)

In [None]:
gen_image = torch.randn(batch.size(), device=batch.device, dtype=batch.dtype).to(device)

for i in range(100):
    t = 99 - i
    t *= 10

    gen_image = gen_image - math.sqrt(1 - alpha_hat[t]) * model(gen_image, t)
    gen_image /= math.sqrt(alpha_hat[t])
    gen_image += torch.randn(batch.size(), device=batch.device, dtype=batch.dtype).to(device) * math.sqrt(beta[t])

plt.imshow(image.squeeze(0).permute(1, 2, 0).cpu().detach().numpy(), vmin=0, vmax=1)

In [None]:
plt.imshow(model(diffuse_batch, t).squeeze(0).permute(1, 2, 0).cpu().detach().numpy(), vmin=-1, vmax=1)

In [None]:
train_config = {
    'max_examples': 1000,
    'max_len': 1000,
    'bs': 32,
    'lr': 0.0001,
    'weight_decay': 0.000001,
    'max_epochs': 10
}

In [None]:
# optimizer and criterion
optimizer = optim.AdamW(model.parameters(), lr=train_config['lr'], weight_decay=train_config['weight_decay'])
criterion = nn.MSELoss()

# construct linear warmup and cosine annealing cooldown
warmup_epochs = int(train_config['max_epochs'] / 10)
cooldown_epochs = train_config['max_epochs'] - warmup_epochs
epoch_len = len(dataloader)

linear = LinearLR(optimizer, start_factor=0.25, end_factor=1.0, total_iters=warmup_epochs*epoch_len)
cosine = CosineAnnealingLR(optimizer, T_max=cooldown_epochs*epoch_len, eta_min=1e-6)
scheduler = SequentialLR(optimizer, schedulers=[linear, cosine], milestones=[warmup_epochs*epoch_len])

model.train()

# main training loop
pbar = tqdm(total=(train_config['max_epochs'])*epoch_len, desc="Training Iterations", unit="batch")
iteration = 0
for epoch in range(train_config['max_epochs']):
    # minibatch gradient descent
    for batch_idx, batch in enumerate(dataloader):
        # pick noising rate
        t = random.uniform(0.01, 0.99)

        # run batch through diffusion
        batch = batch.to(device)
        noise = torch.randn(batch.size(), device=x.device, dtype=x.dtype).to(device)
        diffuse_batch = math.sqrt(t) * batch + math.sqrt(1 - t) * noise

        # forward pass
        noise_pred = model(diffuse_batch)

        # compute L2 loss between predicted noise and true noise
        loss = criterion(noise_pred, noise)
        print(loss)

        # optimization
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        pbar.update(1)
        iteration += 1
        scheduler.step()


pbar.close()