In [1]:
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
import diffusers
from tqdm import tqdm
from nn_zoo.datamodules import MNISTDataModule

In [2]:
dm = MNISTDataModule(
    data_dir="data",
    dataset_params={
        "download": True,
        "transform": torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize((32, 32)),
                torchvision.transforms.ToTensor(),
            ]
        ),
    },
    loader_params={
        "batch_size": 64,
        "num_workers": 2,
    },
)
dm.prepare_data()
dm.setup()
train_loader = dm.train_dataloader()
test_loader = dm.test_dataloader()

In [28]:
def down_block(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2),
    )
def up_block(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.Upsample(scale_factor=2),
    )

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.time_embeddings = nn.Embedding(1000, 32 * 32)
        
        self.down = nn.ModuleList([
            down_block(1, 16),
            down_block(16, 32),
            down_block(32, 64),
            down_block(64, 128),
        ])
        self.up = nn.ModuleList([
            up_block(128, 64),
            up_block(64, 32),
            up_block(32, 16),
            up_block(16, 1),
        ])
        
    def forward(self, x, t):
        if isinstance(t, int or float):
            t = torch.tensor([t], device=x.device)
        
        t = self.time_embeddings(t)
        x = x + t.view(-1, 1, 32, 32)
        skips = [x]
        for layer in self.down:
            x = layer(x)
            skips.append(x)
        for layer, skip in zip(self.up, reversed(skips)):
            x = layer(x + skip * (1 / (len(skips))))
        return x
    
model = Model().to("mps")
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
print(model(torch.randn(1, 1, 32, 32, device="mps"), 1).shape)

torch.Size([1, 1, 32, 32])


In [29]:
ddim = diffusers.DDIMScheduler(rescale_betas_zero_snr=True, num_train_timesteps=10)

In [30]:
x, y = next(iter(train_loader))
x, y = x.to("mps"), y.to("mps")
x, y = x[:1], y[:1]
x, y = x.repeat(512, 1, 1, 1), y.repeat(512, 1, 1, 1)

In [31]:
pbar = tqdm(range(1000))
for i in pbar:
    t = torch.randint(0, 10, (x.shape[0],), device=x.device)
    noise = torch.randn_like(x, device=x.device)
    
    x_noisy = ddim.add_noise(x, noise, t)
    
    y_pred = model(x_noisy, t)
    
    loss = F.mse_loss(y_pred, noise) 
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    pbar.set_postfix_str(f"loss={loss.item():.3f}")

100%|██████████| 1000/1000 [01:23<00:00, 11.91it/s, loss=0.945]


In [None]:
plt.hist(noise.cpu().numpy().flatten(), bins=100, alpha=0.5)
plt.hist(y_pred.cpu().detach().numpy().flatten(), bins=100, alpha=0.5)
plt.show()

In [None]:
@torch.no_grad()
def generate(eta, steps, model):
    ddim.num_inference_steps = steps
    x = torch.randn((64, 1, 32, 32), device="mps")
    
    for i in reversed(range(0, 1000, 1000 // steps)):
        noise_pred = model(x, i)
        x = ddim.step(noise_pred, i, x, eta=eta).prev_sample
        # print(x.keys())
        
    return x

import matplotlib.pyplot as plt

grid = torchvision.utils.make_grid(generate(1, 10, model).cpu(), nrow=8)
plt.imshow(grid.permute(1, 2, 0))
