In [1]:
# Cell 1
%matplotlib inline
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets
import matplotlib.pyplot as plt
from tqdm import trange

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Cell 2
def color_mnist_digit(digit, color, root='./data', train=True):
    mnist = datasets.MNIST(root=root, train=train, download=True)
    idx = (mnist.targets == digit)
    images = mnist.data[idx].float() / 255.0
    images = images.unsqueeze(1)  # [N,1,28,28]
    zeros = torch.zeros_like(images)
    if color == 'red':
        rgb = torch.cat([images, zeros, zeros], 1)
    elif color == 'green':
        rgb = torch.cat([zeros, images, zeros], 1)
    else:
        raise ValueError("color must be 'red' or 'green'")
    return rgb

red_6_imgs = color_mnist_digit(6, 'red')
green_2_imgs = color_mnist_digit(2, 'green')

In [5]:
# Cell 3
beta_0 = 0.1
beta_1 = 20.0

def log_alpha(t):
    return -0.5 * t * beta_0 - 0.25 * t ** 2 * (beta_1 - beta_0)

def log_sigma(t):
    return torch.log(torch.clamp(t, min=1e-6))

def dlog_alphadt(t):
    t = t.detach().requires_grad_(True)
    out = log_alpha(t)
    grad_out = torch.autograd.grad(out.sum(), t, create_graph=True)[0]
    return grad_out

def beta(t):
    return (1 + 0.5 * t * beta_0 + 0.5 * t ** 2 * (beta_1 - beta_0))

In [6]:
# Cell 4
class SimpleUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, base_channels=32):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels+1, base_channels, 3, padding=1), nn.ReLU(),
            nn.Conv2d(base_channels, base_channels*2, 3, padding=1), nn.ReLU(),
        )
        self.middle = nn.Sequential(
            nn.Conv2d(base_channels*2, base_channels*2, 3, padding=1), nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(base_channels*2, base_channels, 3, padding=1), nn.ReLU(),
            nn.Conv2d(base_channels, out_channels, 3, padding=1),
        )
    def forward(self, x, t):
        t_img = t.view(-1,1,1,1).expand(-1,1,28,28)
        x = torch.cat([x, t_img], 1)
        x = self.encoder(x)
        x = self.middle(x)
        x = self.decoder(x)
        return x

In [7]:
# Cell 5
def train_score_model(model, images, epochs=5, batch_size=64, device='cpu'):
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
    dataset = TensorDataset(images)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for epoch in range(epochs):
        for (x,) in loader:
            x = x.to(device)
            bs = x.size(0)
            t = torch.rand(bs, 1, device=device)
            eps = torch.randn_like(x)
            x_t = torch.exp(log_alpha(t))[:, None, None, None] * x \
                + torch.exp(log_sigma(t))[:, None, None, None] * eps
            score = model(x_t, t)
            loss = ((score + eps)**2).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return model

In [8]:
# Cell 6
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_red = SimpleUNet().to(device)
model_green = SimpleUNet().to(device)

model_red = train_score_model(model_red, red_6_imgs.to(device), epochs=5, device=device)
model_green = train_score_model(model_green, green_2_imgs.to(device), epochs=5, device=device)

Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/home/molef/micromamba/envs/cxr_superdiff/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3579, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_352110/1798967990.py", line 6, in <module>
    model_red = train_score_model(model_red, red_6_imgs.to(device), epochs=5, device=device)
  File "/tmp/ipykernel_352110/2268791827.py", line 15, in train_score_model
    score = model(x_t, t)
  File "/home/molef/micromamba/envs/cxr_superdiff/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/molef/micromamba/envs/cxr_superdiff/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/ipykernel_352110/1671143704.py", line 18, in forward
    x = torch.cat([x, t_img], 1)
RuntimeError: Tensors must have same number of dimensio