In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import tqdm as tqdm
from torchmetrics.image.fid import FrechetInceptionDistance
import diffusers
import matplotlib.pyplot as plt


In [None]:
# Shared preprocessing for both splits:
#   * ToTensor converts the PIL digit to a float tensor in [0, 1].
#   * Resize upsamples 28x28 -> 32x32 (presumably so the UNet's repeated 2x
#     downsampling divides evenly).
#   * Normalize((0.5,), (0.5,)) maps [0, 1] -> [-1, 1]. DDIMPipeline converts
#     samples back to images with (x / 2 + 0.5).clamp(0, 1), so the model must
#     be trained on [-1, 1] data for generated images to have the right range.
# To display a batch from these loaders directly, undo the scaling with x / 2 + 0.5.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.Normalize((0.5,), (0.5,)),
])

BATCH_SIZE = 64

# download=True only fetches the files when they are not already present under
# data/, so a fresh environment can Restart & Run All.
train_dataset = torchvision.datasets.MNIST('data', train=True, download=True, transform=mnist_transform)
test_dataset = torchvision.datasets.MNIST('data', train=False, download=True, transform=mnist_transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
# Preview one batch of test digits (the labels in `y` are not used here).
x, y = next(iter(test_loader))

# normalize=True min-max rescales the batch into [0, 1] before tiling, so the
# preview renders correctly whether the loader yields [0, 1] or [-1, 1] pixels
# (imshow would otherwise clip out-of-range values).
grid = torchvision.utils.make_grid(x, nrow=16, normalize=True)

fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(grid.permute(1, 2, 0))
ax.set_title("Test-set samples")
ax.axis("off")
plt.show()


In [None]:
# Denoising UNet for single-channel 32x32 images. The first (highest
# resolution) down stage and the last up stage are plain conv blocks; the
# remaining stages add self-attention.
model = diffusers.UNet2DModel(
    sample_size=(32, 32),
    in_channels=1,
    out_channels=1,
    down_block_types=(
        "DownBlock2D",
        "AttnDownBlock2D",
        "AttnDownBlock2D",
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
    ),
    block_out_channels=(32, 64, 128, 128),
    layers_per_block=2,
)

# A freshly built model is untrained, so restart the global step counter that
# the training loop increments and the sample plots display. Any optimizer
# created before this cell still references the old parameters; recreate it.
num_train_steps = 0
print(f"Model has {model.num_parameters():,} parameters")


In [None]:
# DDIM noise schedule used both to corrupt training images (add_noise) and,
# via DDIMPipeline, to sample new ones.
#   * num_train_timesteps: length of the training noise schedule.
#   * rescale_betas_zero_snr: rescales the betas so the last training timestep
#     has zero signal-to-noise ratio (pure noise).
# NOTE(review): with zero terminal SNR the last alphas_cumprod is 0. If the
# model predicts epsilon (the library default, presumably in effect here),
# reconstructing x0 at that exact step divides by zero. Confirm the sampling
# timestep spacing never starts at the final step before changing
# timestep_spacing / prediction_type.
scheduler = diffusers.DDIMScheduler(
    rescale_betas_zero_snr=True,
    num_train_timesteps=4000,
)


In [None]:
# Mean-squared error between the UNet output and its regression target.
criterion = nn.MSELoss()

# AdamW over every UNet parameter, with a small decoupled weight decay.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=2e-4,
    weight_decay=4e-5,
)


In [None]:
@torch.no_grad()
def generate(batch_size, eta, num_inference_steps, generator=None):
    """Sample images from the global ``model`` with DDIM and display them as a grid.

    Args:
        batch_size: Number of images to sample.
        eta: DDIM ``eta`` forwarded to the pipeline (0.0 gives deterministic DDIM).
        num_inference_steps: Number of denoising steps the sampler runs.
        generator: Optional ``torch.Generator`` (or list of generators)
            forwarded to the pipeline so a sample grid can be reproduced.
            ``None`` keeps the previous unseeded behaviour.

    Reads the notebook globals ``model``, ``scheduler`` and ``num_train_steps``.
    Returns nothing; the grid is shown with matplotlib.
    """
    pipeline = diffusers.DDIMPipeline(model, scheduler)

    # Sample in eval mode, then restore whatever mode the caller had, so calling
    # this between training epochs does not silently leave the model in eval mode.
    was_training = model.training
    model.eval()
    try:
        imgs = pipeline(
            batch_size,
            generator=generator,
            eta=eta,
            num_inference_steps=num_inference_steps,
            output_type="np",
        ).images
    finally:
        model.train(was_training)

    # The NumPy output is channels-last (batch, height, width, channels);
    # make_grid expects channels-first.
    imgs = torch.as_tensor(imgs).permute(0, 3, 1, 2)
    grid = torchvision.utils.make_grid(imgs)

    fig, ax = plt.subplots(figsize=(20, 20))
    ax.set_title(f"Generated Images after {num_train_steps:,} steps")
    ax.imshow(grid.permute(1, 2, 0))
    ax.axis("off")
    plt.show()


In [None]:
# Number of passes over the training set.
NUM_EPOCHS = 10

# What the UNet is trained to predict must match what the scheduler assumes
# at sampling time; validate it once before training starts.
prediction_type = scheduler.config.prediction_type
if prediction_type not in ("epsilon", "v_prediction", "sample"):
    raise ValueError(f"Unsupported prediction_type: {prediction_type!r}")

# Make sure dropout/normalization layers are in training mode even if an
# earlier cell switched the model to eval.
model.train()

pbar = tqdm.tqdm(range(NUM_EPOCHS), desc="Epochs")

for epoch in pbar:
    for x, _ in tqdm.tqdm(train_loader, desc="Batches", unit="batches", leave=False):
        # One random timestep per image, drawn over the scheduler's full
        # training schedule rather than a hard-coded copy of its length.
        t = torch.randint(0, scheduler.config.num_train_timesteps, (x.size(0),))
        noise = torch.randn_like(x)

        noisy_x = scheduler.add_noise(x, noise, t)

        if prediction_type == "epsilon":
            target = noise
        elif prediction_type == "v_prediction":
            target = scheduler.get_velocity(x, noise, t)
        else:  # "sample": predict the clean image directly
            target = x

        # Call the module (not .forward) so registered hooks run.
        pred = model(noisy_x, t).sample

        loss = criterion(pred, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        num_train_steps += 1

        pbar.set_postfix_str(f"Total Steps: {num_train_steps:,}, Loss: {loss.item():.4f}")