In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import math
import pandas as pd
import random
import wandb

# import deepinv
from torchvision import datasets, transforms

from torch import nn, optim
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
from torch.utils.data import DataLoader, Dataset, ConcatDataset

from datasets import load_dataset
from datetime import datetime
from tqdm import tqdm
import matplotlib.pyplot as plt

from models.UNet import UNet
from models.VisionTransformer import VisionTransformer

from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pick the compute backend in order of preference: CUDA GPU, then Apple
# Metal (MPS), then plain CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if (torch.backends.mps.is_available() and torch.backends.mps.is_built())
    else "cpu"
)

In [None]:
# Preprocessing: resize every image to 32x32, convert to a tensor in [0, 1],
# then Normalize(0.5, 0.5) per channel maps it to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# CelebA training split, read from ./data (no download attempted).
# Other datasets tried previously: datasets.MNIST(train=True, download=True)
# and datasets.StanfordCars(split="train") — swap them in here if needed.
celeba_train = datasets.CelebA(root="./data", split="train", download=False, transform=transform)
dataloader = DataLoader(celeba_train, batch_size=1, shuffle=True)

In [None]:
# Draw one random training batch and keep only its first element (the image
# batch of size 1); the remaining elements (CelebA targets) are discarded.
example, *_ = next(iter(dataloader))

In [None]:
# Load the trained UNet checkpoint.
# The default path is relative to the kernel's working directory, presumably
# the repository root (the same assumption that lets `from models.UNet import UNet`
# resolve). Set the DIM_CHECKPOINT environment variable to load another file.
model_path = Path(
    os.environ.get(
        "DIM_CHECKPOINT",
        "checkpoints/diffusion-image-model/dim-2026_01_05_15_01_epoch129_end.pth",
    )
)
# weights_only=False lets torch.load unpickle arbitrary objects stored next to
# the weights (e.g. the config entries read below). Unpickling can execute
# code, so only load checkpoints from a trusted source.
chkpt = torch.load(model_path, weights_only=False, map_location=device)

# Architecture hyper-parameters and training settings saved with the weights.
model_config = chkpt['model_config']
train_config = chkpt['train_config']

# Rebuild the network from the saved config. eval() puts any dropout /
# normalisation layers into inference mode — this notebook never trains.
model = UNet(
    in_channels=model_config['in_channels'],
    out_channels=model_config['out_channels'],
    channels=model_config['channels'],
    scales=model_config['scales'],
    attentions=model_config['attentions'],
    time_steps=model_config['time_steps'],
).to(device).eval()
model.load_state_dict(chkpt['model_state_dict'])

In [7]:
# NOTE(review): this rebinds `model` to a freshly initialised, untrained
# VisionTransformer (no weights are loaded here), discarding the UNet restored
# from the checkpoint above. Every later cell that uses `model` — including
# the sampling cells — then runs with these random weights. Skip this cell to
# sample from the trained UNet.
vit_hparams = dict(
    patch_size=4,
    in_channels=3,
    out_channels=3,
    embed_dim=256,
    num_layers=4,
    num_heads=4,
)
model = VisionTransformer(**vit_hparams).to(device)

In [None]:
# Show the model and training hyper-parameters stored in the checkpoint,
# one per line.
print(model_config, train_config, sep="\n")

In [None]:
# Disabled alternative: build deepinv's DiffUNet instead of the project UNet
# and load the same checkpoint into it. Re-enabling it also requires the
# commented-out `import deepinv` in the imports cell.
# NOTE(review): in_channels/out_channels=1 below does not match the 3-channel
# CelebA pipeline this notebook currently uses.
# model = deepinv.models.DiffUNet(
#     in_channels=1, out_channels=1, pretrained=None
# ).to(device)

# model.load_state_dict(chkpt['model_state_dict'])
# model.to(device)

In [8]:
# Dry run: push a random batch through the model and check that the output
# has the same (batch, channels, height, width) shape as the input.
# no_grad avoids building an autograd graph for this throwaway forward pass.
with torch.no_grad():
    dummy_batch = torch.randn(32, 3, 32, 32, device=device)
    dummy_out = model(dummy_batch)

expected_shape = torch.Size([32, 3, 32, 32])
assert dummy_out.shape == expected_shape, f"unexpected output shape: {tuple(dummy_out.shape)}"
print(dummy_out.shape)

torch.Size([32, 3, 32, 32])


In [None]:
# Linear (DDPM-style) noise schedule over the model's time steps:
#   beta_t rises linearly from 1e-4 to 0.02,
#   alpha_t = 1 - beta_t,
#   alpha_hat_t = prod_{s <= t} alpha_s   (cumulative product).
# NOTE(review): the 1e-4 / 0.02 endpoints are hard-coded here — confirm they
# match the schedule used during training.
num_steps = model_config['time_steps']
beta = torch.linspace(1e-4, 0.02, num_steps).to(device)
alpha = 1.0 - beta
alpha_hat = alpha.cumprod(dim=0)

In [None]:
# Show the sampled training image. `* 0.5 + 0.5` undoes Normalize(0.5, 0.5),
# bringing pixels back to [0, 1]. vmin/vmax are not passed because imshow
# ignores them for RGB input.
fig, ax = plt.subplots()
ax.imshow((example * 0.5 + 0.5).clamp(0, 1).squeeze(0).permute(1, 2, 0).cpu().numpy())
ax.set_title("CelebA training example")
ax.axis("off");

In [None]:
# Forward-diffuse the example straight to step t in closed form:
#   x_t = sqrt(alpha_hat_t) * x_0 + sqrt(1 - alpha_hat_t) * eps,   eps ~ N(0, I)
batch = example.to(device)
t = 10  # diffusion step to inspect; the next two cells reuse this value
noise = torch.randn_like(batch)  # same shape/device/dtype as batch
diffuse_batch = alpha_hat[t].sqrt() * batch + (1 - alpha_hat[t]).sqrt() * noise

# Un-normalise to [0, 1]; clamp explicitly because the added noise can push
# values outside the range imshow accepts for float RGB images.
fig, ax = plt.subplots()
ax.imshow((diffuse_batch * 0.5 + 0.5).clamp(0, 1).squeeze(0).permute(1, 2, 0).cpu().numpy())
ax.set_title(f"Forward-diffused example (t = {t})")
ax.axis("off");

In [None]:
# One DDPM reverse step from x_t to x_{t-1}:
#   x_{t-1} = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_hat_t) * eps_theta(x_t, t))
#             + sqrt(beta_t) * z,
# where z ~ N(0, I) for t > 0 and z = 0 at the final step.
model.eval()
x = diffuse_batch
with torch.no_grad():  # inference only — don't build an autograd graph
    pred_noise = model(x, t)

noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
x = (1 / torch.sqrt(alpha[t])) * (
    x - (beta[t] / torch.sqrt(1 - alpha_hat[t])) * pred_noise
) + torch.sqrt(beta[t]) * noise

fig, ax = plt.subplots()
ax.imshow((x * 0.5 + 0.5).clamp(0, 1).squeeze(0).permute(1, 2, 0).cpu().numpy())
ax.set_title(f"After one reverse step from t = {t}")
ax.axis("off");

In [None]:
# Estimate the clean image in one shot by inverting the forward process with
# the predicted noise:
#   x_0 ~ (x_t - sqrt(1 - alpha_hat_t) * eps_theta(x_t, t)) / sqrt(alpha_hat_t)
model.eval()
with torch.no_grad():  # inference only — don't build an autograd graph
    eps_hat = model(diffuse_batch, t)
image = (diffuse_batch - (1 - alpha_hat[t]).sqrt() * eps_hat) / alpha_hat[t].sqrt()

fig, ax = plt.subplots()
ax.imshow((image * 0.5 + 0.5).clamp(0, 1).squeeze(0).permute(1, 2, 0).cpu().numpy())
ax.set_title(f"Direct x_0 estimate from t = {t}")
ax.axis("off");

In [None]:
# Generate a new image by running the full DDPM reverse process from pure
# Gaussian noise x_T down to x_0.
model.eval()

# Release cached allocator memory left by earlier cells. The cache API is
# backend-specific, so only call the one that matches the selected device.
if device.type == "mps":
    torch.mps.empty_cache()
elif device.type == "cuda":
    torch.cuda.empty_cache()

num_steps = model_config['time_steps']
image_size = train_config['image_size']
x = torch.randn(1, model_config['in_channels'], image_size, image_size, device=device)

# no_grad is essential here. Without it every model call stays attached to a
# single autograd graph through `x`, so memory grows with each of the
# num_steps iterations.
with torch.no_grad():
    for t in tqdm(reversed(range(num_steps)), total=num_steps, desc="sampling"):
        # predict the noise component of x_t
        pred_noise = model(x, t)
        # mean of x_{t-1} given x_t and the predicted noise
        x = (1 / torch.sqrt(alpha[t])) * (x - (beta[t] / torch.sqrt(1 - alpha_hat[t])) * pred_noise)
        # add fresh noise (variance beta_t) at every step except the last
        if t > 0:
            x = x + torch.sqrt(beta[t]) * torch.randn_like(x)

# Un-normalise to [0, 1]; clamp because samples can overshoot the valid range.
fig, ax = plt.subplots()
ax.imshow((x * 0.5 + 0.5).clamp(0, 1).squeeze(0).permute(1, 2, 0).cpu().numpy())
ax.set_title("Generated sample")
ax.axis("off");