<font face='monospace'>

## <b>FULL DIFFUSION</b>

In [None]:
%pip install -qU fastai fastcore datasets torcheval diffusers

In [None]:
import math
import torch
import logging
import fastcore.all as fc
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.transforms.functional as TF

from glob import glob
from pathlib import Path
from torch.nn import init
from torch import nn, optim
from scipy import integrate
from functools import partial, wraps
from datasets import load_dataset
from fastcore.foundation import L
from diffusers import AutoencoderKL
from torch.optim import lr_scheduler
from fastprogress import progress_bar
from torcheval.metrics import Mean, Metric
from torch.utils.data import DataLoader,default_collate

from diffusion_ai import *

# Set seeds for reproducibility.
# `set_seed` comes from the diffusion_ai star import; presumably it seeds Python/NumPy/torch — confirm.
set_seed(42)
# Re-seeds torch's global RNG with a different value.
# NOTE(review): if set_seed(42) already seeds torch, this call overrides that seed — confirm which is intended.
torch.manual_seed(1)

# Configure logging and torch settings
# Disable all log records of severity WARNING and below.
logging.disable(logging.WARNING)
# Tensor repr: 5 decimals, 140-character lines, no scientific notation.
torch.set_printoptions(precision=5, linewidth=140, sci_mode=False)

In [None]:
# Constants
IMG_KEY, LABEL_KEY = 'image', 'label'  # column names in the Hugging Face dataset
NAME = "fashion_mnist"  # Hugging Face Hub dataset id
NUM_STEPS = 1000  # NOTE(review): not referenced anywhere else in this notebook
BATCH_SIZE = 512
SIGMA_DATA = 0.66  # standard deviation of our training dataset (used by compute_scalings)
# Load the train/test splits (downloaded on first use, then read from the local cache).
dataset = load_dataset(NAME)

<font face='monospace'>

#### <b>Sampling</b>

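Here we use FID (Fréchet Inception Distance) and KID (Kernel Inception Distance). They compare the distribution of features extracted from real images with the distribution of features extracted from the model's generated samples. The features come from a pretrained classifier (`cnn_model`).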

In [None]:
# evaluation model
# torch.load unpickles arbitrary Python objects — only load checkpoint files you trust.
# The file holds a whole pickled nn.Module, not a state_dict. PyTorch >= 2.6 defaults to
# weights_only=True, which rejects such files, so opt out explicitly.
cnn_model = torch.load('models/inference.pkl', weights_only=False)
# Remove layers 8 and 7 so the network emits intermediate features for FID/KID
# (presumably the classifier head — confirm against the saved model).
# Delete the higher index first so index 7 still refers to the same layer.
del cnn_model[8]
del cnn_model[7]

# pre-process
@inplace
def transformi(b):
    """Pad every image by 2 px per side and shift pixels from [0, 1] to [-0.5, 0.5]; mutates `b`."""
    padded = []
    for img in b['image']:
        as_tensor = TF.to_tensor(img)
        padded.append(F.pad(as_tensor, (2, 2, 2, 2)) - 0.5)
    b['image'] = padded
# Apply `transformi` lazily on access; DataLoaders.from_dd (diffusion_ai) presumably builds train/valid loaders from the split dict.
tds = dataset.with_transform(transformi)
dls = DataLoaders.from_dd(tds, BATCH_SIZE, num_workers=4)

# first batch of our original data
dt = dls.train
xb,yb = next(iter(dt))

# Batch shape used when generating samples: 512 single-channel 32x32 images (28x28 padded by 2 per side).
sample_size = (512,1,32,32)
# FID/KID evaluator built on the truncated classifier; `ie.fid` / `ie.kid` are used later.
# NOTE(review): assumes ImageEval computes reference features from `dls` — see diffusion_ai.
ie = ImageEval(cnn_model, dls, cbs=[DeviceCB()])

<font face='monospace'>

### 1️⃣

Let's implement our own `unet` model architecture and start training our diffusion model. This UNet has no embeddings, no class conditioning and no attention mechanisms. It's a simple model.

In [None]:
@inplace
def transformi(batch):
    """Pad every image by 2 px per side and rescale pixels from [0, 1] to [-1, 1]; mutates `batch`."""
    def _prepare(image):
        padded = F.pad(TF.to_tensor(image), (2, 2, 2, 2))
        return padded * 2 - 1
    batch[IMG_KEY] = list(map(_prepare, batch[IMG_KEY]))

def compute_scalings(sigma, sigma_data=None):
    """Karras et al. (EDM) preconditioning coefficients for noise level `sigma`.

    With v = sigma**2 + sigma_data**2:
        c_skip = sigma_data**2 / v
        c_out  = sigma * sigma_data / sqrt(v)
        c_in   = 1 / sqrt(v)

    Args:
        sigma: noise level(s); a Python float or a tensor (computed elementwise).
        sigma_data: std of the clean data; defaults to the module constant SIGMA_DATA.

    Returns:
        (c_skip, c_out, c_in), each a float or a tensor broadcast like `sigma`.
    """
    if sigma_data is None:
        sigma_data = SIGMA_DATA
    total_variance = sigma**2 + sigma_data**2
    # `** 0.5` works for Python floats and tensors alike (`.sqrt()` is tensor-only).
    total_std = total_variance ** 0.5
    c_skip = sigma_data**2 / total_variance
    c_out = sigma * sigma_data / total_std
    c_in = 1 / total_std
    return c_skip, c_out, c_in

def noisify(images, p_mean=-1.2, p_std=1.2):
    """Add Karras (EDM) noise to a batch and build the preconditioned training pair.

    One noise level per image is drawn log-normally: ln(sigma) ~ N(p_mean, p_std**2).

    Args:
        images: clean batch; the first dim is the batch. sigma is reshaped to (B, 1, 1, 1),
            so 4-D (B, C, H, W) input is presumably expected.
        p_mean, p_std: mean / std of ln(sigma). The defaults reproduce the original
            hard-coded -1.2 / 1.2.

    Returns:
        ((noised_input * c_in, sigma), target), where sigma has shape (B,) and target is
        the EDM regression target (images - c_skip * noised_input) / c_out.
    """
    n = len(images)
    sigma = (torch.randn([n]) * p_std + p_mean).exp().to(images).reshape(-1, 1, 1, 1)
    noise = torch.randn_like(images)  # already on images' device/dtype
    c_skip, c_out, c_in = compute_scalings(sigma)
    noised_input = images + noise * sigma
    target = (images - c_skip * noised_input) / c_out
    # reshape(-1), not squeeze(): squeeze() turns a batch of one into a 0-d tensor.
    return (noised_input * c_in, sigma.reshape(-1)), target

def collate(batch):
    """Collate raw samples, then noisify the image tensor (labels are discarded)."""
    collated = default_collate(batch)
    images = collated[IMG_KEY]
    return noisify(images)

def create_dataloader(dataset):
    """Wrap `dataset` in a DataLoader whose batches are noisified by `collate`.

    Uses the module-level BATCH_SIZE and 4 worker processes; no shuffling (DataLoader default).
    """
    loader_options = dict(batch_size=BATCH_SIZE, collate_fn=collate, num_workers=4)
    return DataLoader(dataset, **loader_options)

In [None]:
# Pre-process the data, add noise and create DataLoaders
# Uses the [-1, 1] `transformi` from the previous cell; noise is added per batch by `collate`.
transformed_dataset = dataset.with_transform(transformi)
# train split -> training loader, test split -> validation loader
dataloaders = DataLoaders(create_dataloader(transformed_dataset['train']), create_dataloader(transformed_dataset['test']))

<font face='monospace'>

<b>---UNET ARCHITECTURE---</b>

In [None]:
def unet_conv(ni, nf, ks=3, stride=1, act=nn.SiLU, norm=None, bias=True):
    """Pre-activation conv layer: [norm(ni)] -> [act()] -> Conv2d(ni, nf).

    Padding is ks // 2, so with stride 1 and an odd kernel the spatial size is preserved.
    Pass norm=None / act=None to omit that stage.
    """
    modules = []
    if norm:
        modules.append(norm(ni))
    if act:
        modules.append(act())
    modules.append(nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks // 2, bias=bias))
    return nn.Sequential(*modules)

In [None]:
class UNetResBlock(nn.Module):
    """Residual block: two pre-activation `unet_conv` layers plus a skip path.

    The skip path is the identity when channel counts match, otherwise a 1x1 conv
    that projects in_channels -> out_channels.
    """
    def __init__(self, in_channels, out_channels=None, kernel_size=3, activation=nn.SiLU, normalization=nn.BatchNorm2d):
        super().__init__()
        out_channels = in_channels if out_channels is None else out_channels
        conv_opts = dict(act=activation, norm=normalization)
        first = unet_conv(in_channels, out_channels, kernel_size, **conv_opts)
        second = unet_conv(out_channels, out_channels, kernel_size, **conv_opts)
        self.convs = nn.Sequential(first, second)
        if in_channels == out_channels:
            self.identity_conv = fc.noop
        else:
            self.identity_conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        main_path = self.convs(x)
        return main_path + self.identity_conv(x)

In [None]:
class SaveModule:
    """Mixin that records the output of `forward` on `self.saved`.

    List it first in the bases (e.g. `class X(SaveModule, nn.Conv2d)`) so that
    `super().forward` resolves, via the MRO, to the wrapped layer's own forward.
    """
    def forward(self, x, *args, **kwargs):
        out = super().forward(x, *args, **kwargs)
        self.saved = out  # collected by UNet.forward as a skip connection
        return out

class SavedResBlock(SaveModule, UNetResBlock):
    """UNetResBlock that keeps its latest output in `.saved`."""

class SavedConv(SaveModule, nn.Conv2d):
    """nn.Conv2d that keeps its latest output in `.saved`."""

In [None]:
def down_block(in_channels, out_channels, add_downsample=True, num_layers=1):
    """Encoder stage for the UNet.

    Builds `num_layers` SavedResBlocks (the first maps in_channels -> out_channels, the
    rest keep out_channels) and, if `add_downsample`, a stride-2 3x3 SavedConv that halves
    an even spatial size. Every layer stores its output in `.saved` for the skip connections.
    """
    layers = []
    for idx in range(num_layers):
        block_in = out_channels if idx > 0 else in_channels
        layers.append(SavedResBlock(block_in, out_channels))
    if add_downsample:
        layers.append(SavedConv(out_channels, out_channels, 3, stride=2, padding=1))
    return nn.Sequential(*layers)

In [None]:
def upsample(nf):
    """2x nearest-neighbour upsample followed by a 3x3 conv that keeps `nf` channels."""
    enlarge = nn.Upsample(scale_factor=2.)
    refine = nn.Conv2d(nf, nf, 3, padding=1)
    return nn.Sequential(enlarge, refine)

In [None]:
class UpBlock(nn.Module):
    """Decoder stage for the plain UNet.

    Runs `num_layers` UNetResBlocks. Each one consumes the running activation concatenated
    (on channels) with one skip popped from the end of `skip_connections`. An optional
    2x upsample follows.
    """
    def __init__(self, in_channels, prev_out_channels, out_channels, add_upsample=True, num_layers=2):
        super().__init__()
        blocks = []
        for i in range(num_layers):
            x_channels = prev_out_channels if i == 0 else out_channels
            skip_channels = in_channels if i == num_layers - 1 else out_channels
            blocks.append(UNetResBlock(x_channels + skip_channels, out_channels))
        self.resnets = nn.ModuleList(blocks)
        self.up_sample = upsample(out_channels) if add_upsample else nn.Identity()

    def forward(self, x, skip_connections):
        # NOTE: mutates `skip_connections` (pops one entry per res block).
        for res_block in self.resnets:
            skip = skip_connections.pop()
            x = res_block(torch.cat([x, skip], dim=1))
        return self.up_sample(x)

In [None]:
# Down blocks sit in an nn.Sequential (called in one shot). Up blocks also need the skip
# list as an argument, so they sit in an nn.ModuleList and are called in a for loop.
class UNet(nn.Module):
    """Plain UNet (no time/class embeddings, no attention) for image denoising.

    forward takes a tuple `inp` and uses only `inp[0]` (the scaled noisy images);
    other entries, such as the sigma produced by `noisify`, are ignored.
    """
    def __init__(self, in_channels=1, out_channels=1, feature_sizes=(224,448,672,896), num_layers=1):
        super().__init__()
        self.conv_in = nn.Conv2d(in_channels, feature_sizes[0], kernel_size=3, padding=1)
        self.down_blocks = nn.Sequential()
        nf = feature_sizes[0]
        for i in range(len(feature_sizes)):
            ni = nf
            nf = feature_sizes[i]
            # every stage except the last ends with a stride-2 downsampling conv
            self.down_blocks.append(down_block(ni, nf, add_downsample=i!=len(feature_sizes)-1, num_layers=num_layers))
        self.middle_block = UNetResBlock(feature_sizes[-1])

        rev_feature = list(reversed(feature_sizes))
        nf = rev_feature[0]
        self.up_blocks = nn.ModuleList()
        for i in range(len(feature_sizes)):
            prev_nf = nf
            nf = rev_feature[i]
            ni = rev_feature[min(i+1, len(feature_sizes)-1)]
            # num_layers+1 res blocks per up stage so the decoder pops exactly the skips
            # saved on the way down (conv_in output + every down-block layer).
            self.up_blocks.append(UpBlock(ni, prev_nf, nf, add_upsample=i!=len(feature_sizes)-1, num_layers=num_layers+1))

        # unet_conv's keyword arguments are `act` and `norm`.
        self.conv_out = unet_conv(feature_sizes[0], out_channels, act=nn.SiLU, norm=nn.BatchNorm2d)

    def forward(self, inp):
        x = self.conv_in(inp[0])
        skip_connections = [x]
        x = self.down_blocks(x)
        # every layer of every down block is a Saved* module exposing `.saved`
        skip_connections += [layer.saved for block in self.down_blocks for layer in block]
        x = self.middle_block(x)
        for block in self.up_blocks:
            x = block(x, skip_connections)  # pops skips from the end (deepest first)
        return self.conv_out(x)

In [None]:
# Define model, optimizer, scheduler, and learner
LR = 3e-3
EPOCHS = 25
opt_func = partial(optim.Adam, eps=1e-5)
# Total number of training batches; the one-cycle schedule spans all of them.
tmax = EPOCHS * len(dataloaders.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=LR, total_steps=tmax)
# BatchSchedCB presumably steps the scheduler once per batch (diffusion_ai callback).
cbs = [DeviceCB(), MixedPrecision(), ProgressCB(plot=True), MetricsCB(), BatchSchedCB(sched)]
model = UNet(in_channels=1, out_channels=1, feature_sizes=(32,64,128,256), num_layers=2)
# MSE between the model output and the EDM target produced by `noisify`.
learn = Learner(model, dataloaders, nn.MSELoss(), lr=LR, cbs=cbs, opt_func=opt_func)

In [None]:
# Train the model
# Runs EPOCHS epochs of training (plus validation on the test split).
learn.fit(EPOCHS)

<font face='monospace'>
    
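After training, we have a diffusion model built on our own UNet architecture. Note that it was trained with Karras-style (EDM) preconditioning. The inputs are scaled by `c_in`, and the target is the `c_skip`/`c_out`-weighted one produced by `noisify`, not the raw noise. The `UNet` above also only sees the scaled noisy image, not the noise level. So it needs a sampler that matches this parameterisation (e.g. Euler/Heun steps over a decreasing sigma schedule), rather than the ε-prediction `ddpm`/`ddim` samplers used later in this notebook.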

---

<font face='monospace'>

### 2️⃣

#### <b>Timestep Model</b>
This time let's create a model that is also told the timestep of each noisy input. The timestep is encoded with a **sinusoidal embedding**. This technique is borrowed from the positional encodings of transformer models, where sines and cosines of different frequencies tell the model the position of each token in a sequence. Here the same trick tells the model *where in the diffusion process* (i.e. at which timestep / noise level) the input is.

In diffusion models, this type of embedding can provide information about which timestep (or noise level) an image is at during the denoising process. It helps the model to adapt its behavior based on how far along the diffusion process is, which is crucial for generating coherent outputs at each step.

<font face='monospace'>
    
The `EmbUNetModel` is a U-Net architecture that incorporates timestep embeddings into its structure.

- The main reason for using such a model in diffusion models is to condition the generation process on both local and global information about the image at different scales and timesteps. This allows for more controlled and coherent image generation as noise is progressively removed from an image over time.
- In diffusion models, timestep embeddings are crucial because they allow the model to condition its predictions on the specific point in time during the denoising process. This is important because the amount and type of noise added to the images vary at each timestep.

In [None]:
# Same values as the constants cell at the top of the notebook (re-declared for this section).
IMG_KEY, LABEL_KEY = 'image','label'
BATCH_SIZE = 512

In [None]:
# Cosine noise schedule: alpha_bar(t) = cos(pi * t / 2) ** 2 for t in [0, 1].
def abar(t):
    """Signal fraction alpha_bar at time(s) `t` (a tensor) under the cosine schedule."""
    return torch.cos(t * math.pi / 2) ** 2

def inv_abar(x):
    """Inverse of `abar`: the time t whose alpha_bar equals `x` (a tensor)."""
    return torch.acos(torch.sqrt(x)) * 2 / math.pi

def noisify(x0):
    """Noise a clean batch under the cosine schedule (DDPM forward process).

    Draws one t ~ U[0, 1) per image (clamped to <= 0.999) and returns ((x_t, t), eps),
    where x_t = sqrt(abar(t)) * x0 + sqrt(1 - abar(t)) * eps and eps ~ N(0, I).
    """
    dev = x0.device
    t = torch.rand(len(x0)).to(x0).clamp(0, 0.999)
    noise = torch.randn(x0.shape, device=dev)
    alpha_bar = abar(t).reshape(-1, 1, 1, 1).to(dev)
    x_t = alpha_bar.sqrt() * x0 + (1 - alpha_bar).sqrt() * noise
    return (x_t, t.to(dev)), noise

def collate_ddpm(b):
    """Collate samples and noisify the image batch; labels are discarded."""
    batch = default_collate(b)
    return noisify(batch[IMG_KEY])

def dl_ddpm(ds):
    """DataLoader over `ds` that noisifies each batch via the global `collate_ddpm` (looked up at call time)."""
    return DataLoader(ds, batch_size=BATCH_SIZE, collate_fn=collate_ddpm, num_workers=4)

In [None]:
@inplace
def transformi(b):
    """Pad images by 2 px per side and shift pixels from [0, 1] to [-0.5, 0.5]; mutates `b`."""
    b[IMG_KEY] = [F.pad(TF.to_tensor(img), (2, 2, 2, 2)) - 0.5
                  for img in b[IMG_KEY]]

# Images in [-0.5, 0.5]; each batch is noised by collate_ddpm -> ((x_t, t), eps).
tds = dataset.with_transform(transformi)
dls = DataLoaders(dl_ddpm(tds['train']), dl_ddpm(tds['test']))

# Grab one batch to inspect its structure.
dl = dls.train
(xt,t),eps = b = next(iter(dl))

In [None]:
# See the EmbUNetModel in diffusion_ai.diffusion.py file. This Unet has timestep embeddings and attention implemented in it
# In UNET model, we have also implemented attention channels in the mid blocks.
# NOTE(review): attention may depend on EmbUNetModel's attn_* arguments; defaults are used here.
lr = 1e-2
epochs = 1
opt_func = partial(optim.Adam, eps=1e-5)
# Total number of training batches covered by the one-cycle schedule.
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
cbs = [DeviceCB(), ProgressCB(plot=True), MetricsCB(), BatchSchedCB(sched), MixedPrecision()]
model = EmbUNetModel(in_channels=1, out_channels=1, nfs=(32,64,128,256), num_layers=2)
# The model predicts the added noise eps (the target returned by noisify); trained with MSE.
learn = Learner(model, dls, nn.MSELoss(), lr=lr, cbs=cbs, opt_func=opt_func)

In [None]:
# Train the timestep-conditioned UNet.
learn.fit(epochs)

<font face='monospace'>

This code may fail with an `OutOfMemoryError`. If it does, reduce the batch size to `32` or `64`. Training will take longer, but it's a small fix!

<font face='monospace'>

#### <b>Sampling from the above model using ddim step</b>

In [None]:
def ddim_step(x_t, noise, abar_t, abar_t1, bbar_t, bbar_t1, eta, sig, clamp=True):
    """One DDIM update from the current timestep t to the next (less noisy) timestep t1.

    Args:
        x_t: current noisy batch.
        noise: the model's predicted noise (epsilon) for x_t.
        abar_t, abar_t1: alpha_bar at the current / next timestep (tensors).
        bbar_t, bbar_t1: 1 - alpha_bar at the current / next timestep.
        eta: stochasticity scale; eta=0 adds no fresh noise (deterministic DDIM).
        sig: ignored — recomputed below from eta; kept so existing callers still work.
        clamp: clip the predicted clean image to [-1, 1].

    Returns:
        (x_0_hat, x_t1): the predicted clean batch and the batch at the next timestep.
    """
    # DDIM sigma scaled by eta (overrides the `sig` argument).
    sig = ((bbar_t1/bbar_t).sqrt() * (1-abar_t/abar_t1).sqrt()) * eta
    # Predicted clean image, recovered from the noise prediction.
    x_0_hat = (x_t - (1-abar_t).sqrt()*noise) / abar_t.sqrt()
    if clamp:
        x_0_hat = x_0_hat.clamp(-1, 1)
    # Zero sigma when too little variance would remain (bbar_t1 - sig**2 <= 0.01) or when
    # it is NaN. At the last step of `sample`, bbar_t == bbar_t1 == 0, so sig = 0/0 = NaN.
    # A plain `bbar_t1 <= ...` is False for NaN; the negated `>` test is True.
    if not (bbar_t1 > sig**2 + 0.01):
        sig = 0.
    x_t = abar_t1.sqrt()*x_0_hat + (bbar_t1 - sig**2).sqrt()*noise
    x_t += sig * torch.randn(x_t.shape).to(x_t)
    return x_0_hat, x_t

@torch.no_grad()
def sample(f, model, sz, steps, eta=1., clamp=True, device='cuda'):
    """Generate a batch by running the step function `f` for `steps` steps from pure noise.

    Args:
        f: step function with the `ddim_step` signature.
        model: eps-prediction model called as model((x_t, t)).
        sz: shape of the batch to generate.
        steps: number of denoising steps; t runs linearly from 1 - 1/steps down to 0.
        eta, clamp: forwarded to `f`.
        device: where the noise, timesteps and model computation live (default 'cuda',
            as before; the model must already be on this device).

    Returns:
        List of `steps` CPU float tensors (each step's predicted clean batch);
        preds[-1] is the final sample.
    """
    model.eval()
    ts = torch.linspace(1-1/steps,0,steps)
    x_t = torch.randn(sz).to(device)
    preds = []
    for i,t in enumerate(progress_bar(ts)):
        t = t[None].to(device)
        abar_t = abar(t)
        noise = model((x_t, t))
        # At the last step (t < 1/steps) the next alpha_bar is 1: no noise remains.
        abar_t1 = abar(t-1/steps) if t>=1/steps else torch.tensor(1)
        # The 8th positional argument (`sig`) is ignored by `ddim_step`, which recomputes it.
        x_0_hat,x_t = f(x_t, noise, abar_t, abar_t1, 1-abar_t, 1-abar_t1, eta, 1-((i+1)/100), clamp=clamp)
        preds.append(x_0_hat.float().cpu())
    return preds

# Batch shape for generation: 512 single-channel 32x32 images.
sample_size = (512, 1, 32, 32)

In [None]:
# `sample` and `ddim_step` are the versions defined in the previous cell
# (the original note pointed to diffusion_ai.diffusion.py for the reference implementation).
preds = sample(ddim_step, model, sample_size, steps=100, eta=1.)
# Training images were shifted to [-0.5, 0.5] (`transformi` subtracts 0.5), so x2 rescales to roughly [-1, 1].
s = (preds[-1]*2)
# Quick range / shape sanity check.
s.min(),s.max(),s.shape

In [None]:
# Show the first 25 generated images, clipped to [-1, 1].
show_images(s[:25].clamp(-1,1), imsize=1.5)

In [None]:
# FID and KID of the generated batch against the real data (lower is better).
ie.fid(s),ie.kid(s),s.shape

In [None]:
# Draw a fresh batch and report its FID only.
preds = sample(ddim_step, model, sample_size, steps=100, eta=1.)
ie.fid(preds[-1]*2)

---

<font face='monospace'>

### 3️⃣

<b>Conditional model</b><br>

Here we try to add class embedding to the `unet` model.

The `CondUNetModel` is similar to the `EmbUNetModel` in structure but includes an additional conditioning mechanism on class labels.

While both models use timestep embeddings, the `CondUNetModel` extends this concept by also conditioning on class labels, making it suitable for a wider range of tasks that require fine-grained control over the output based on categorical information.

In [None]:
def collate_ddpm(b):
    """Collate samples, noisify the images, and keep the class labels for conditioning.

    Returns ((x_t, t, labels), eps).
    """
    collated = default_collate(b)
    noised, eps = noisify(collated[IMG_KEY])
    xt, t = noised
    return (xt, t, collated[LABEL_KEY]), eps

In [None]:
@inplace
def transformi(b):
    """Pad images by 2 px per side and shift pixels from [0, 1] to [-0.5, 0.5]; mutates `b`."""
    def _shifted(img):
        return F.pad(TF.to_tensor(img), (2, 2, 2, 2)) - 0.5
    b[IMG_KEY] = [_shifted(img) for img in b[IMG_KEY]]

# dl_ddpm now picks up the label-keeping collate_ddpm defined above.
tds = dataset.with_transform(transformi)
dls = DataLoaders(dl_ddpm(tds['train']), dl_ddpm(tds['test']))

# Inspect one batch: ((x_t, t, class ids), eps).
dl = dls.train
(xt,t,c),eps = b = next(iter(dl))

In [None]:
class CondUNetModel(nn.Module):
    """UNet conditioned on both the diffusion timestep and a class label.

    The integer class id is embedded with nn.Embedding and added to the MLP-projected
    sinusoidal timestep embedding. The sum (`emb`) is passed to every down, mid and up block.

    forward expects `inp = (x, t, c)`: image batch, timesteps, integer class ids.
    """
    def __init__( self, n_classes, in_channels=3, out_channels=3, nfs=(224,448,672,896), num_layers=1):
        super().__init__()
        # This notebook defines its own embedding-free `UpBlock` earlier, which shadows the
        # star-imported one, so `UpBlock(n_emb, ..., add_up=...)` would raise TypeError.
        # Import the embedding-aware block from diffusion_ai explicitly.
        # NOTE(review): assumes diffusion_ai exports UpBlock(n_emb, ni, prev_nf, nf, add_up=..., num_layers=...).
        from diffusion_ai import UpBlock as EmbUpBlock

        self.conv_in = nn.Conv2d(in_channels, nfs[0], kernel_size=3, padding=1)
        self.n_temb = nf = nfs[0]
        n_emb = nf*4  # conditioning width: 4x the first feature size
        self.cond_emb = nn.Embedding(n_classes, n_emb)
        self.emb_mlp = nn.Sequential(lin(self.n_temb, n_emb, norm=nn.BatchNorm1d),
                                     lin(n_emb, n_emb))
        self.downs = nn.ModuleList()
        for i in range(len(nfs)):
            ni = nf
            nf = nfs[i]
            self.downs.append(DownBlock(n_emb, ni, nf, add_down=i!=len(nfs)-1, num_layers=num_layers))
        self.mid_block = EmbResBlock(n_emb, nfs[-1])

        rev_nfs = list(reversed(nfs))
        nf = rev_nfs[0]
        self.ups = nn.ModuleList()
        for i in range(len(nfs)):
            prev_nf = nf
            nf = rev_nfs[i]
            ni = rev_nfs[min(i+1, len(nfs)-1)]
            self.ups.append(EmbUpBlock(n_emb, ni, prev_nf, nf, add_up=i!=len(nfs)-1, num_layers=num_layers+1))
        self.conv_out = pre_conv(nfs[0], out_channels, act=nn.SiLU, norm=nn.BatchNorm2d, bias=False)

    def forward(self, inp):
        x,t,c = inp
        temb = timestep_embedding(t, self.n_temb)
        cemb = self.cond_emb(c)
        emb = self.emb_mlp(temb) + cemb  # joint timestep + class conditioning
        x = self.conv_in(x)
        saved = [x]
        for block in self.downs: x = block(x, emb)
        # skip activations recorded by each down block, consumed by the up blocks
        saved += [p for o in self.downs for p in o.saved]
        x = self.mid_block(x, emb)
        for block in self.ups: x = block(x, emb, saved)
        return self.conv_out(x)

In [None]:
lr = 1e-2
epochs = 25
opt_func = partial(optim.Adam, eps=1e-5)
# Total number of training batches covered by the one-cycle schedule.
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
cbs = [DeviceCB(), ProgressCB(plot=True), MetricsCB(), BatchSchedCB(sched), MixedPrecision()]
# 10 classes (Fashion-MNIST), single-channel input/output.
model = CondUNetModel(10, in_channels=1, out_channels=1, nfs=(32,64,128,256), num_layers=2)
learn = Learner(model, dls, nn.MSELoss(), lr=lr, cbs=cbs, opt_func=opt_func)

In [None]:
# Train the class-conditioned UNet.
learn.fit(epochs)

In [None]:
# Batch shape for conditional sampling: 256 single-channel 32x32 images.
sz = (256,1,32,32)

In [None]:
# Human-readable class names (index = integer label id).
# Use LABEL_KEY: `yl` is only defined later, in the autoencoder section.
lbls = dataset['train'].features[LABEL_KEY].names
lbls

In [None]:
set_seed(42)
cid = 0  # class id to generate (index into lbls)
# see in diffusion_ai.diffusion.py file.
# NOTE(review): cond_sample comes from the diffusion_ai star import; presumably like `sample` but also passes the class id.
preds = cond_sample(cid, ddim_step, model, sz, steps=100, eta=1.)
s = (preds[-1]*2)
show_images(s[:25].clamp(-1,1), imsize=1.5, suptitle=lbls[cid])

In [None]:
# Same seed and class as above, but eta=0: ddim_step adds no fresh noise per step.
set_seed(42)
cid = 0
preds = cond_sample(cid, ddim_step, model, sz, steps=100, eta=0.)
s = (preds[-1]*2)
show_images(s[:25].clamp(-1,1), imsize=1.5, suptitle=lbls[cid])

---

<font face='monospace'>

### 4️⃣

<b>Variational Auto Encoder</b>

_NOTE_: The goal of an autoencoder (and of a VAE) is to encode a high-dimensional image into a lower-dimensional latent. The decoder then reconstructs the input image from that latent. That is why we set both the inputs and the targets to the same flattened images.

Here we just see how latents can be used as inputs to the unet model instead of noisy images. **WHY?** Because we can have input images of any resolution, and training UNET models for higher resolution is very costly and time consuming. So, we reduce the dimension of the image and use them to save time, cost, space, etc. Just make life simpler!


In [None]:
# Column names (same values as IMG_KEY / LABEL_KEY) and settings for the autoencoder section.
xl,yl = 'image','label'
name = "fashion_mnist"
bs = 64
# Loads the same dataset again, under a new name (cached after the first download).
dsd = load_dataset(name)

@inplace
def transformi(b):
    """Flatten each image to a 1-D float tensor and use it as both input and target.

    The autoencoder reconstructs its own input, so b[yl] (the labels) is replaced with
    the same flattened images as b[xl]. Mutates `b`.
    """
    flattened = [TF.to_tensor(img).flatten() for img in b[xl]]
    b[yl] = flattened
    b[xl] = flattened

# Batches are (flattened images, same flattened images).
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs, num_workers=4)

# One validation batch kept for visual checks after training.
dl = dls.valid
xb,yb = b = next(iter(dl))

<font face='monospace'>

- **ni**: Input dimension (number of features)
- **nh**: Hidden layer dimension (number of neurons in each hidden layer)
- **nl**: Latent space dimension (dimensionality of the compressed representation)

These are the dimension variables used in a simple autoencoder.

In [None]:
# 784 = 28*28 flattened pixels; hidden width 400; latent size 200.
ni,nh,nl = 784,400,200

In [None]:
def lin(ni, nf, act=nn.SiLU, norm=nn.BatchNorm1d, bias=True):
    """Fully-connected layer: Linear(ni, nf) -> [act()] -> [norm(nf)].

    Pass act=None / norm=None to omit that stage.
    """
    parts = [nn.Linear(ni, nf, bias=bias)]
    if act:
        parts.append(act())
    if norm:
        parts.append(norm(nf))
    return nn.Sequential(*parts)

In [None]:
def init_weights(m, leaky=0.):
    """Kaiming-normal init of the weight of a conv/linear module; intended for `module.apply`.

    Other module types (norm layers, activations, containers) are left unchanged, and
    biases are never touched. `leaky` is passed to kaiming_normal_ as the negative slope `a`.
    """
    weighted_layers = (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)
    if not isinstance(m, weighted_layers):
        return
    init.kaiming_normal_(m.weight, a=leaky)

In [None]:
# init_weights with negative slope 0.2; apply it with `module.apply(iw)` so every submodule is visited.
iw = partial(init_weights, leaky=0.2)

In [None]:
class Autoenc(nn.Module):
    """Fully-connected autoencoder: ni -> nh -> nh -> nl -> nh -> nh -> ni.

    Uses the module-level sizes `ni`, `nh`, `nl`. The decoder's last layer has no
    activation, so its output is used as logits (training uses BCEWithLogitsLoss).
    """
    def __init__(self):
        super().__init__()
        self.enc = nn.Sequential(lin(ni, nh), lin(nh, nh), lin(nh, nl))
        self.dec = nn.Sequential(lin(nl, nh), lin(nh, nh), lin(nh, ni, act=None))
        # `apply` runs `iw` on every submodule. Calling `iw(self)` directly would only
        # inspect this container (not a Linear/Conv) and initialise nothing.
        self.apply(iw)

    def forward(self, x):
        x = self.enc(x)
        return self.dec(x)

In [None]:
lr = 3e-2
epochs = 20
# Total number of training batches covered by the one-cycle schedule.
tmax = epochs * len(dls.train)
opt_func = partial(optim.Adam, eps=1e-5)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
cbs = [DeviceCB(), ProgressCB(plot=True), MetricsCB(), BatchSchedCB(sched), MixedPrecision()]
model = Autoenc()
# Pixel-wise BCE between decoder logits and the original (flattened) image.
learn = Learner(model, dls, nn.BCEWithLogitsLoss(), lr=lr, cbs=cbs, opt_func=opt_func)

In [None]:
# Train the autoencoder.
learn.fit(epochs)

<font face='monospace'>
After training, the encoder half of this model can be used to compress our dataset, and the compressed version can then be used for training.
Let's look at the autoencoder's reconstructions.

In [None]:
# Reconstruct the validation batch. DeviceCB moved the model (presumably to the GPU) during
# fit, while `xb` still lives on the CPU, so send it to the model's device first.
device = next(model.parameters()).device
with torch.no_grad(): t = to_cpu(model(xb.to(device)).float())

In [None]:
# original
# First 9 validation images, unflattened back to 1x28x28.
show_images(xb[:9].reshape(-1,1,28,28), imsize=1.5, title='Original');

In [None]:
# generated
# Reconstructions: sigmoid maps the decoder logits to [0, 1] pixel intensities.
show_images(t[:9].reshape(-1,1,28,28).sigmoid(), imsize=1.5, title='Autoenc');

<font face='monospace' color='#800080'>
<div class="alert alert-info">
  <i class="fas fa-lightbulb"></i> Latents are not images themselves but compressed tensor representations of them. They are valuable for many computer vision tasks because they capture the important aspects of the data in a more compact form.
</div>

<font face='monospace'>

**Variational Autoencoder (VAE)**: An extension of the autoencoder architecture that introduces probabilistic elements into the latent space. Here are the key differences:

1. **Encoder**:
   - The VAE's encoder maps the input data to a lower-dimensional latent space.
   - Instead of directly producing a fixed latent representation, it generates two vectors:
     - **Mean (`mu`)**: Represents the center of a Gaussian distribution in the latent space.
     - **Log Variance (`lv`)**: Determines the spread or uncertainty of the distribution.
   - By combining `mu` and `lv`, the encoder creates a probabilistic representation of the input data.

2. **Latent Space**:
   - The latent space in a VAE is continuous and probabilistic.
   - It allows for sampling from the learned distribution, enabling generative capabilities.
   - Each point in the latent space corresponds to a potential data point.

3. **Decoder**:
   - The decoder takes a sampled latent vector as input.
   - It reconstructs the original input data from this probabilistic representation.
   - The decoder learns to generate realistic data points by sampling from the latent space.

**Use Cases**:
  - Image generation, Data denoising, Learning meaningful representations for downstream tasks.

In [None]:
class VAE(nn.Module):
    """Fully-connected variational autoencoder using the module-level sizes ni, nh, nl.

    forward returns (logits, mu, lv): decoder logits for the reconstruction, plus the mean
    and log-variance of the diagonal Gaussian from which the latent z is sampled.
    """
    def __init__(self):
        super().__init__()
        self.enc = nn.Sequential(lin(ni, nh), lin(nh, nh))
        # separate heads for the latent mean and log-variance (no activation)
        self.mu,self.lv = lin(nh, nl, act=None),lin(nh, nl, act=None)
        self.dec = nn.Sequential(lin(nl, nh), lin(nh, nh), lin(nh, ni, act=None))
        # `apply` runs `iw` on every submodule. Calling `iw(self)` directly would only
        # inspect this container (not a Linear/Conv) and initialise nothing.
        self.apply(iw)

    def forward(self, x):
        x = self.enc(x)
        mu,lv = self.mu(x),self.lv(x)
        # reparameterisation: z = mu + exp(lv / 2) * eps, eps ~ N(0, I)
        z = mu + (0.5*lv).exp()*torch.randn_like(lv)
        return self.dec(z),mu,lv

In [None]:
# KL divergence between the encoder's diagonal Gaussian N(mu, exp(lv)) and the standard
# normal prior N(0, I), averaged over all elements. It pulls the latent distribution toward
# the prior, so random N(0, I) latents decode to plausible images.

def kld_loss(inp, x):
    """KL term of the VAE loss. `inp` is the model output (x_hat, mu, lv); `x` is unused."""
    _, mu, lv = inp
    per_element = 1 + lv - mu.pow(2) - lv.exp()
    return -0.5 * per_element.mean()

def bce_loss(inp, x):
    """Reconstruction term: BCE between the decoder logits `inp[0]` and the target pixels `x`."""
    logits = inp[0]
    return F.binary_cross_entropy_with_logits(logits, x)

def vae_loss(inp, x):
    """Total VAE loss: KL term plus reconstruction term."""
    return kld_loss(inp, x) + bce_loss(inp, x)

In [None]:
class FuncMetric(Mean):
    """torcheval `Mean` that tracks an arbitrary `fn(inp, targets)` across batches.

    Each `update` adds one value of `fn` to the running sum with weight 1; presumably
    `compute()` then returns weighted_sum / weights, the plain mean over batches.
    """
    def __init__(self, fn, device=None):
        super().__init__(device=device)
        self.fn = fn

    def update(self, inp, targets):
        batch_value = self.fn(inp, targets)
        self.weighted_sum += batch_value
        self.weights += 1

In [None]:
# Report the two VAE loss components separately during training.
metrics = MetricsCB(kld=FuncMetric(kld_loss), bce=FuncMetric(bce_loss))
opt_func = partial(optim.Adam, eps=1e-5)

In [None]:
lr = 3e-2
epochs = 20
# Total number of training batches covered by the one-cycle schedule.
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
cbs = [DeviceCB(), ProgressCB(plot=True), metrics, BatchSchedCB(sched), MixedPrecision()]
model = VAE()
# Loss = KL term + reconstruction BCE (see vae_loss).
learn = Learner(model, dls, vae_loss, lr=lr, cbs=cbs, opt_func=opt_func)

In [None]:
# Train the VAE.
learn.fit(epochs)

<font face='monospace'>**Sampling** from our **VAE**.

In [None]:
# Reconstruct the validation batch. The model was moved (presumably to the GPU) by DeviceCB
# during fit, while `xb` is still on the CPU, so send it to the model's device first.
device = next(model.parameters()).device
with torch.no_grad(): t,mu,lv = to_cpu(model(xb.to(device)))
t = t.float()

In [None]:
# First 9 validation images, unflattened back to 1x28x28.
show_images(xb[:9].reshape(-1,1,28,28), imsize=1.5, title='Original');

In [None]:
# VAE reconstructions: sigmoid maps decoder logits to [0, 1] pixel intensities.
show_images(t[:9].reshape(-1,1,28,28).sigmoid(), imsize=1.5, title='VAE');

In [None]:
# using normal distribution noise to sample new data: decode random latents z ~ N(0, I)

noise = torch.randn(16, nl)
# The decoder lives on the model's device (moved there by DeviceCB); noise is created on the CPU.
device = next(model.parameters()).device
with torch.no_grad():
    ims = model.dec(noise.to(device)).sigmoid()
show_images(ims.reshape(-1, 1, 28, 28).cpu(), imsize=1.5)

<font face='monospace'> Now all we have to do is use these latents as the input data and train our model on them. This makes our model more flexible. Below we just show how to feed these latents into the model; the rest of the procedure remains the same.

<font face='monospace'>

### 5️⃣

Using latents as training data for a diffusion model. This is normally done for colour images and not black and white. The code looks something like this.

```python

# Directory setup, data download, and data extraction
path_data = Path('data')
path_data.mkdir(exist_ok=True)
path = path_data / 'bedroom'
url = 'https://s3.amazonaws.com/fast-ai-imageclas/bedroom.tgz'
if not path.exists():
    path_zip = fc.urlsave(url, path_data)
    shutil.unpack_archive('data/bedroom.tgz', 'data')

# Create a dataset
def to_img(f): 
    return read_image(f, mode=ImageReadMode.RGB) / 255

class ImagesDS:
    def __init__(self, spec):
        self.path = Path(path)
        self.files = glob(str(spec), recursive=True)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        return to_img(self.files[i])[:, :256, :256]
        
ds = ImagesDS(path / '**/*.jpg')
dl = DataLoader(ds, batch_size=64, num_workers=defaults.cpus)
xb = next(iter(dl))
show_images(xb[:16], imsize=2)

# Using Hugging Face model (Variational Autoencoder)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").cuda().requires_grad_(False)
xe = vae.encode(xb.cuda())
xs = xe.latent_dist.mean[:16]
show_images(((xs[:16, :3]) / 4).sigmoid(), imsize=2)
xd = to_cpu(vae.decode(xs))
show_images(xd['sample'].clamp(0, 1), imsize=2)

# Create a memory-mapped array for latents
mmpath = Path('data/bedroom/data.npmm')
if not mmpath.exists():
    a = np.memmap(mmpath, np.float32, mode='w+', shape=(303125, 4, 32, 32))
    i = 0
    for b in progress_bar(dl):
        n = len(b)
        a[i:i + n] = to_cpu(vae.encode(b.cuda()).latent_dist.mean).numpy()
        i += n
    a.flush()
    del a

# collate function
def collate(b):
    return noisify(default_collate(b)*0.2)

# Split latents into training and validation sets
lats = np.memmap(mmpath, dtype=np.float32, mode='r', shape=(303125, 4, 32, 32))
tds = lats[:len(lats) // 10 * 9]
vds = lats[len(lats) // 10 * 9:]
bs = 128
dls = DataLoaders(*get_dls(tds, vds, bs=bs, num_workers=defaults.cpus, collate_fn=collate))
(xt, t), eps = b = next(iter(dls.train))

# Initialize DDPM model
def init_ddpm(model):
    for o in model.downs:
        for p in o.resnets:
            p.conv2[-1].weight.data.zero_()

    for o in model.ups:
        for p in o.resnets:
            p.conv2[-1].weight.data.zero_()

lr = 3e-3
epochs = 25
opt_func = partial(optim.AdamW, eps=1e-5)
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
cbs = [DeviceCB(), ProgressCB(plot=True), MetricsCB(), BatchSchedCB(sched), MixedPrecision()]

# Create DDPM model using the embedded unet architecture
model = EmbUNetModel(in_channels=4, out_channels=4, nfs=(128, 256, 512, 768), num_layers=2, attn_start=1, attn_chans=16)
init_ddpm(model)
learn = Learner(model, dls, MSELossFlat(), lr=lr, cbs=cbs, opt_func=opt_func)
learn.fit(epochs)

# Generate samples
sample_size = (16, 4, 32, 32)
preds = sample(ddim_step, model, sample_size, steps=100, eta=1.0, clamp=False)
S = preds[-1]

# Reconstruct images from latents
with torch.no_grad():
    pd = to_cpu(vae.decode(S.cuda()))
show_images(pd['sample'][:9].clamp(0, 1), imsize=2)
```

<br>
Note that this is again normally done for color images and not needed for grayscale images.

<font face='monospace'>
<div class="alert alert-warning">
  <i class="fas fa-exclamation-circle"></i>&nbsp;<strong>Warning</strong><br>
  The minimum GPU memory required depends on your specific model and dataset.
  For most of the above models, GPUs with at least 4GB VRAM are recommended.
  NVIDIA T4 - A single <code>T4 GPU</code>, <code>16GB VRAM</code>, <code>8 CPUs</code> having a speed of <code>65 TFLOPs</code> would still crash if used for above models without proper strategy.  

  Try to train the above models if the below requirements are satisfied. Or else just understand how the code works.
  - GPUs capable of running the above models
      - <code>NVIDIA P100</code> ~16GB VRAM
      - <code>NVIDIA V100</code>: ~16GB to 32GB VRAM
      - <code>NVIDIA A100</code>: ~40GB to 80GB VRAM
  - Using Cloud GPUs
</div>

In [None]:
# Run Python's garbage collector to release unreferenced objects before clearing the namespace.
import gc
gc.collect()

In [None]:
%reset -f