 # Conditional Diffusion MNIST Image Generation

In [None]:
import torch
import torchvision
import torchvision.transforms.functional as F

from torch import nn
from tqdm import tqdm
from torchmultimodal.diffusion_labs.models.adm_unet.adm import adm_unet
from torchmultimodal.diffusion_labs.modules.adapters.cfguidance import CFGuidance
from torchmultimodal.diffusion_labs.modules.losses.diffusion_hybrid_loss import DiffusionHybridLoss
from torchmultimodal.diffusion_labs.samplers.ddpm import DDPModule
from torchmultimodal.diffusion_labs.predictors.noise_predictor import NoisePredictor
from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import linear_beta_schedule, DiscreteGaussianSchedule
from torchmultimodal.diffusion_labs.transforms.diffusion_transform import RandomDiffusionSteps

# Prefer the GPU, but fall back to the CPU so the notebook still runs (slowly)
# on machines without CUDA instead of failing when modules are moved to `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define Model

To define a diffusion model you need to define four primary components:

1. Network and Adapters
2. Diffusion Schedule
3. Predictor
4. Sampler

The network typically used with image diffusion models is a [U-Net](https://paperswithcode.com/method/u-net). A U-Net is a convolutional network that maps the input space directly to an equal-sized output space. This makes it ideal for image segmentation and transformation tasks. Since we are denoising an image, a U-Net works very well here. [ADMUnet](https://arxiv.org/abs/2105.05233) is a specific implementation shown to work well for image generation.

The default values for adm_unet are a bit over-kill for the tiny MNIST dataset so we'll choose some custom smaller values here.

In [None]:
unet = adm_unet(
    # Architecture, scaled down from the defaults for single-channel MNIST images
    image_channels=1,             # MNIST images are grayscale
    depth=128,                    # U-Net layer depth
    num_resize=3,                 # number of upsample/downsample blocks
    num_res_per_layer=3,          # residual blocks per layer
    predict_variance_value=True,  # also learn per-step variance values used for sampling
    # Conditioning inputs
    time_embed_dim=128,           # the diffusion timestep is fed in as a conditional input
    embed_name="digit",           # name of the conditional input
    embed_dim=768,                # size of the incoming conditional embedding
    cond_embed_dim=128,           # size the conditional embedding is projected to
)

Apart from the core network, we can add adapters to the network design to allow it to handle different tasks common in diffusion training. Here we'll use [classifier-free guidance](https://arxiv.org/abs/2207.12598), a technique for conditional generative models that improves image-prompt alignment.

In [None]:
decoder = CFGuidance(
    unet,            # model being adapted
    {"digit": 768},  # name and size of each conditional input
    guidance=2.0,    # how strongly to increase image-prompt alignment at the expense of image diversity
)

Step 2 is to define the schedule. The schedule is the [diffusion process](https://arxiv.org/abs/2006.11239) which describes the amount of noise added to the image at each diffusion step. Using a Gaussian schedule, we sample Gaussian noise at every step defined as $\mathcal{N}(0, \beta)$. Here we define the schedule with a linearly increasing schedule of variance ($\beta$) values.

In [None]:
# Gaussian schedule over linearly increasing variance (beta) values,
# generated by the linear_beta_schedule helper
schedule = DiscreteGaussianSchedule(
    linear_beta_schedule(1000)
)

Step 3 is the Predictor which determines the output of the denoising network. Here we train the model to output the noise to be removed at each step. The predictor contains the methods to convert the model output into the cleaned image.

In [None]:
predictor = NoisePredictor(
    schedule,                                     # scale of the noise at each step
    lambda img: torch.clamp(img, min=-1, max=1),  # keep image values within [-1, 1]
)

Step 4 is to define the Sampler. The **Sampler** applies the denoising **Network** for each step of the diffusion **Schedule** using the **Predictor** to fully denoise an image. Here we use the [Denoising Diffusion Probabilistic Models (DDPM)](https://arxiv.org/abs/2006.11239) sampler, which is the original diffusion sampler.

In [None]:
# 250 evenly spaced integer diffusion steps between 0 and 999 to visit at inference
eval_steps = torch.linspace(start=0, end=999, steps=250, dtype=torch.int)
# Wrap the guided network in the sampler
decoder = DDPModule(decoder, schedule, predictor, eval_steps)

Finally, to condition this model on MNIST digits, let's define a simple conditional encoder to convert digits to conditional embeddings:

In [None]:
# Lookup table holding one learnable 768-dimensional embedding per digit (10 digits)
encoder = nn.Embedding(num_embeddings=10, embedding_dim=768)

# Training

For data, we need to define the transforms, a dataset, and a dataloader. For training a diffusion model, you sample each data point from the diffusion process. The RandomDiffusionSteps transform takes in your data point, samples a random diffusion step, and applies noise to the data accordingly.

In [None]:
from torchvision.transforms import Compose, Resize, ToTensor, Lambda

# Diffusion transform driven by the schedule; applied to one sample at a time (batched=False)
diffusion_transform = RandomDiffusionSteps(schedule, batched=False)

transform = Compose(
    [
        Resize(32),                                           # resize MNIST image for the network
        ToTensor(),
        Lambda(lambda img: 2 * img - 1),                      # scale image to [-1, 1]
        Lambda(lambda img: diffusion_transform({"x": img})),  # apply the diffusion transform
    ]
)

Load Dataset

In [None]:
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

# MNIST training split stored under the "mnist" directory (download enabled),
# with the diffusion transform pipeline applied to every image
train_dataset = MNIST("mnist", train=True, download=True, transform=transform)

# Shuffled batches of 192 samples, loaded by 2 worker processes into pinned memory
train_dataloader = DataLoader(
    train_dataset,
    batch_size=192,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

For DDPM we'll train using [diffusion hybrid loss](https://arxiv.org/abs/2102.09672) between the model output and the added noise. This loss measures the distance between the model output and the target, as well as the KL divergence between the predicted noise variance and the actual variance.

In [None]:
# Hybrid loss constructed from the same schedule object passed to the predictor and sampler
h_loss = DiffusionHybridLoss(schedule)

We can then choose our favorite optimizer and optionally use a scaler for mixed precision training.

In [None]:
# Move both trainable modules to the target device
encoder.to(device)
decoder.to(device)

# One parameter group per module (encoder first, then decoder), sharing a learning rate
optimizer = torch.optim.AdamW(
    [{"params": module.parameters()} for module in (encoder, decoder)],
    lr=1e-4,
)
# Gradient scaler used for mixed precision training
scaler = torch.cuda.amp.GradScaler()

# Train

Here is a simple, standard PyTorch training loop, just with mixed precision added in for faster training. The diffusion model has a fixed signature:

$model(x_t, t, cond_{dict})$

When training, the model only computes a single denoising step per input. The model is also given a dictionary of conditional inputs that the Adapters and underlying network have access to for conditional generation.

In [None]:
epochs = 5

# Put both modules in training mode
encoder.train()
decoder.train()

for e in range(epochs):
    pbar = tqdm(train_dataloader)
    for batch in pbar:
        # Each batch pairs the diffusion transform's output dict with the digit labels
        diffused, labels = batch
        x0 = diffused["x"].to(device)
        xt = diffused["xt"].to(device)
        noise = diffused["noise"].to(device)
        t = diffused["t"].to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # Forward pass and loss are computed under autocast for mixed precision
        with torch.autocast(device):
            digit_embed = encoder(labels)
            out = decoder(xt, t, {"digit": digit_embed})
            loss = h_loss(out.prediction, noise, out.mean, out.log_variance, x0, xt, t)

        # Backward pass on the scaled loss, optimizer step, then scaler update
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        pbar.set_description(f'{e+1}| Loss: {loss.item()}')

# Eval

Likewise, eval is done as a standard PyTorch eval-mode model call. While in train mode the model computes a single denoising step and outputs the raw model output, in eval mode the model computes steps $t, ..., 0$ and returns the denoised data.

In eval, if no timestep is provided, it's assumed to be the largest timestep $T$ and the input is $x_T$. Since $x_T$ is equivalent to random noise, you sample the input from torch.randn. 

In [None]:
# Switch both modules to eval mode before sampling
encoder.eval()
decoder.eval()

# Digits 1 through 9: nine classes, one per cell of the 3-column grid built below
digit = torch.arange(1, 10).to(device)
# One random-noise 1x32x32 starting image per requested digit
noise = torch.randn(size=(len(digit), 1, 32, 32)).to(device)

# No gradients are needed for sampling
with torch.no_grad():
    d = encoder(digit)
    imgs = decoder(noise, conditional_inputs={"digit": d})

img_grid = torchvision.utils.make_grid(imgs, 3)  # 3 images per row
img = F.to_pil_image((img_grid + 1) / 2)         # shift values from [-1, 1] to [0, 1] for PIL
img.resize((288, 288))                           # enlarged copy; as the cell's last expression it is displayed