# Objective

This notebook trains a bare-bones diffusion model ("Dino Diffusion") to unconditionally generate images.

To do this, it trains a simple neural network that denoises patches of images. This network is then reused across space and time, to generate new samples starting from pure noise.

Compared to the base Dino Diffusion notebook, this notebook adds two main sources of complexity:
1. Sketch conditioning (consuming a sketch that shows what to draw)
2. Cascaded upsampling (upsampling low-resolution generations into higher-resolution ones)

I recommend understanding the base notebook first.

TODO: clean up this notebook; it is messier and harder to read than the base notebook, sorry :)

# Config

In [None]:
%matplotlib inline
%config InlineBackend.figure_format = "retina"

import random
from collections import namedtuple
from pathlib import Path

import matplotlib.pyplot as plt
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as tv
from PIL import Image
from tqdm import tqdm

# Let cuDNN benchmark convolution algorithms and reuse the fastest one it finds.
th.backends.cudnn.benchmark = True

In [None]:
class Config:
    """Notebook-wide settings, read as class attributes (the class is never instantiated)."""

    # Number of image channels the model consumes and produces.
    channels = 3
    # Side length, in pixels, of the square patches used for training and sampling.
    patch_hw = 64
    # Factor by which each cascaded sampling stage enlarges the previous result.
    upscale_factor = 2
    # Name of the dataset folder / Hugging Face "huggan" dataset.
    # Alternative: "public_domain_plants_tiny"
    dataset = "pokemon"
    # Run on the GPU when one is available, otherwise on the CPU.
    device = "cpu"
    if th.cuda.is_available():
        device = "cuda"

def show(x):
    """Display an image tensor inline; batches and iterables are tiled side by side.

    A 4-D tensor, or any non-tensor iterable of image tensors, is concatenated
    along the last (width) axis before being converted to a PIL image.
    """
    is_single_image = isinstance(x, th.Tensor) and x.ndim != 4
    if not is_single_image:
        x = th.cat(tuple(x), dim=-1)
    display(tv.to_pil_image(x))

# Dataset


In [None]:
def get_dataset(name):
    """Download the ``huggan/{name}`` dataset from Hugging Face and unpack it into JPEG files.

    The repository is cloned into a folder called ``name`` in the current
    directory, its LFS files are pulled, and every entry of the first column of
    each ``data/*.parquet`` table is decoded and saved as ``{name}/NNNN.jpg``.
    If a path called ``name`` already exists, nothing is done.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` or ``git lfs pull`` fails.
    """
    import subprocess
    from io import BytesIO

    if Path(name).exists():
        print(f"dataset '{name}' already exists; skipping...")
        return
    # Argument lists (no shell) keep `name` from being interpreted by a shell,
    # and check=True surfaces a failed download instead of continuing silently.
    subprocess.run(["git", "clone", f"https://huggingface.co/datasets/huggan/{name}"], check=True)
    subprocess.run(["git", "lfs", "pull"], cwd=name, check=True)
    import pyarrow.parquet as pq

    i = 0
    # Sorted so that the numbering of the extracted files is reproducible.
    for table in sorted(Path(f"{name}/data").glob("*.parquet")):
        for row in tqdm(pq.read_table(table)[0]):
            # JPEG cannot store alpha or palette images, so normalize to RGB first.
            Image.open(BytesIO(row["bytes"].as_py())).convert("RGB").save(f"{name}/{i:04d}.jpg")
            i += 1

# Fetch the configured dataset (skipped when its folder already exists).
get_dataset(Config.dataset)

In [None]:
import numpy as np
from torchvision.transforms import InterpolationMode
from functools import lru_cache
from PIL import ImageFilter

# One dataset sample: the image patch, its per-pixel coordinate map, and the sketch conditioning image.
Patch = namedtuple("Patch", ("patch", "coords", "cond"))

def coords_for_patch(x, y, width, height):
    """Build a (3, patch_hw, patch_hw) coordinate map for a patch.

    Channel 0 runs linearly from ``x`` to ``x + width`` across columns,
    channel 1 from ``y`` to ``y + height`` across rows, and channel 2 is all ones.
    """
    hw = Config.patch_hw
    column_values = th.linspace(x, x + width, hw)
    row_values = th.linspace(y, y + height, hw)
    cx = column_values.view(1, -1).expand(hw, hw)
    cy = row_values.view(-1, 1).expand(hw, hw)
    return th.stack((cx, cy, th.ones_like(cx)))

def threshold_cond(img):
    """Binarize a PIL image: values above a random cutoff in [0, 192] become 255, the rest 0.

    The image is converted to RGB first and the cutoff is applied per channel value.
    """
    pixels = np.array(img.convert("RGB"))
    cutoff = random.randint(0, 192)
    binary = (pixels > cutoff).astype(np.uint8) * 255
    return Image.fromarray(binary)

def box_draw(img_th):
    """Paint between 0 and 3 random rectangles with value 255 onto a (C, H, W) tensor.

    The tensor is modified in place and also returned. Every channel receives
    the same rectangles.
    """
    _, height, width = img_th.shape
    n_boxes = random.randint(0, 3)
    for _ in range(n_boxes):
        box_h = random.randrange(1, height)
        box_w = random.randrange(1, width)
        top = random.randrange(0, height - box_h) if box_h < height else 0
        left = random.randrange(0, width - box_w) if box_w < width else 0
        img_th[:, top:top + box_h, left:left + box_w] = 255
    return img_th

def process_conditioning(img):
    """Turn a PIL image into a sketch-like conditioning image.

    Pixels whose R, G and B values are all below a random cutoff in [20, 100]
    count as linework and become 0; everything else becomes 255. The result is
    cleaned up with a 3x3 median filter.
    """
    rgb = np.array(img.convert("RGB"))[..., :3]
    darkness_cutoff = random.randint(20, 100)
    is_line = (rgb < darkness_cutoff).all(-1)
    sketch = Image.fromarray((~is_line * 255).astype(np.uint8))
    return sketch.filter(ImageFilter.MedianFilter(3))

class Dataset:
    """Map-style dataset that returns one randomly cropped, augmented patch per lookup.

    Every ``*.png`` and ``*.jpg`` found recursively under the given path(s) is
    one item. Lookups are randomized (crop scale and position, optional
    rotation, horizontal flip, optional sketch conditioning), so indexing the
    same item twice generally yields different patches.
    """
    def __init__(self, p):
        """Collect image paths from ``p``: one directory, or a list/tuple of directories."""
        self.ims = []
        if not isinstance(p, (tuple, list)):
            p = [p]
        for pi in p:
            self.ims.extend(Path(pi).rglob("*.png"))
            self.ims.extend(Path(pi).rglob("*.jpg"))
    def __len__(self):
        """Number of image files found."""
        return len(self.ims)
    def __getitem__(self, i, aug_rotation=True):
        """Return a `Patch` (patch, coords, cond) of uint8 tensors for image ``i``.

        Args:
            i: index into the list of image files.
            aug_rotation: when True, sufficiently large crops are rotated by a
                random angle in [-10, 10) degrees half of the time.
        """
        img = Image.open(self.ims[i])
        # Half of the samples get a sketch derived from the image; the rest get a blank (white) one.
        cond = None
        if random.random() < 0.5:
            cond = process_conditioning(img)
        # Random square crop whose side lies between patch_hw and the image's shorter side.
        patch_scale = random.random()
        crop_hw = Config.patch_hw + int(patch_scale * (min(img.size) - Config.patch_hw))
        crop_x = random.randrange(0, img.width - crop_hw) if img.width > crop_hw else 0
        crop_y = random.randrange(0, img.height - crop_hw) if img.height > crop_hw else 0
        crop = img.crop((crop_x, crop_y, crop_x + crop_hw, crop_y + crop_hw))
        if cond is not None:
            cond = cond.crop((crop_x, crop_y, crop_x + crop_hw, crop_y + crop_hw))
        # Coordinates are expressed as fractions of the full image, then quantized to bytes.
        coords = coords_for_patch(crop_x / img.width, crop_y / img.height, crop_hw / img.width, crop_hw / img.height)
        coords = (coords * 255).round().byte()
        if aug_rotation and crop.width > 3 * Config.patch_hw and crop.height > 3 * Config.patch_hw and random.random() < 0.5:
            # Rotate at 3x resolution, shrink to 1.5x, then center-crop to patch size.
            patch = crop.resize((Config.patch_hw * 3, Config.patch_hw * 3)).convert("RGB")
            if cond is not None:
                cond = cond.resize(patch.size).convert("RGB")
            coords = tv.resize(coords, patch.size)
            angle = random.random() * (20) - 10
            uncropped_size = (int(Config.patch_hw * 1.5), int(Config.patch_hw * 1.5))
            patch = tv.center_crop(tv.resize(tv.rotate(patch, angle, InterpolationMode.BILINEAR), uncropped_size, InterpolationMode.NEAREST), (Config.patch_hw, Config.patch_hw))
            coords = tv.center_crop(tv.resize(tv.rotate(coords, angle, InterpolationMode.BILINEAR), uncropped_size), (Config.patch_hw, Config.patch_hw))
            if cond is not None:
                cond = tv.center_crop(tv.resize(tv.rotate(cond, angle, InterpolationMode.BILINEAR), uncropped_size), (Config.patch_hw, Config.patch_hw))
        else:
            # Convert to RGB here as well, so that palette, grayscale or RGBA files
            # yield 3-channel patches just like the rotation branch does.
            patch = crop.resize((Config.patch_hw, Config.patch_hw)).convert("RGB")
            if cond is not None:
                cond = cond.resize((Config.patch_hw, Config.patch_hw))
        if random.random() < 0.5:
            patch, coords = tv.hflip(patch), tv.hflip(coords)
            if cond is not None:
                cond = tv.hflip(cond)
        # No sketch -> all-white conditioning; otherwise binarize the sketch and paint random white boxes on it.
        cond = 255 * th.ones(3, Config.patch_hw, Config.patch_hw, dtype=th.uint8) if cond is None else box_draw(tv.pil_to_tensor(threshold_cond(cond)))
        return Patch(tv.pil_to_tensor(patch), coords, cond)

# Training set: every .png/.jpg found under the configured dataset folder.
d_train = Dataset(Config.dataset)

In [None]:
def demo_dataset(dataset, n=16):
    """Print the dataset size and show ``n`` random samples.

    Three image strips are shown: the patches, their coordinate maps, and
    their conditioning images. Samples are drawn with replacement.
    """
    print(f"Dataset has {len(dataset)} images.")
    print("Here are some sample patches from the dataset:")
    samples = random.choices(dataset, k=n)
    for field in ("patch", "coords", "cond"):
        show(getattr(sample, field) for sample in samples)

# Preview a few random training samples.
demo_dataset(d_train)

# Model

Next, we define the neural network.

It will take in a noisy patch, the % of noise, the patch coordinates, an optional low-resolution patch, and an optional sketch (conditioning) image, and try to predict the corresponding (denoised) high-resolution patch.

In [None]:
# Model output container with a single field, `denoised` (note the trailing comma: a real 1-tuple).
Prediction = namedtuple("Prediction", ("denoised",))

class Blocks(nn.Module):
    """A chain of conv blocks whose input and all intermediate outputs are fused by a 1x1 conv.

    Args:
        n_in: input channels.
        n_f: channels produced by every conv block.
        n_out: output channels of the final 1x1 conv.
        n_l: number of conv blocks in the chain.
        bias: whether the final 1x1 conv has a bias term.
    """
    def __init__(self, n_in, n_f, n_out, n_l, bias=False):
        super().__init__()
        self.cs = nn.ModuleList([
            self._conv_pair(n_in if depth == 0 else n_f, n_f)
            for depth in range(n_l)
        ])
        # The 1x1 conv sees the block input plus the output of every conv block.
        self.s = nn.Conv2d(n_in + n_f * n_l, n_out, 1, bias=bias)

    @staticmethod
    def _conv_pair(c_in, c_out):
        """Two 3x3 conv + ReLU layers that keep the spatial size."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, 3, padding=1), nn.ReLU(),
            nn.Conv2d(c_out, c_out, 3, padding=1), nn.ReLU(),
        )

    def forward(self, x):
        features = [x]
        hidden = x
        for block in self.cs:
            hidden = block(hidden)
            features.append(hidden)
        return self.s(th.cat(features, dim=1))
def EncBlock(n_in, n_out, n_b):
    """Encoder stage: a stride-2 4x4 conv (halves height and width) followed by `Blocks`."""
    downsample = nn.Conv2d(n_in, n_out, 4, stride=2, padding=1, bias=False)
    refine = Blocks(n_out, n_out, n_out, n_b)
    return nn.Sequential(downsample, refine)

def DecBlock(n_in, n_out, n_b):
    """Decoder stage: `Blocks` producing 4x the channels, then PixelShuffle(2) to double height and width."""
    expand = Blocks(n_in, n_out, n_out * 4, n_b, bias=False)
    upsample = nn.PixelShuffle(2)
    return nn.Sequential(expand, upsample)

class UNet(nn.Module):
    """U-Net denoiser with skip connections between encoder and decoder stages.

    Args:
        n_in_out: channels of the noisy input and of the denoised output.
        n_f: feature width per resolution level (first entry is full resolution).
        n_b: number of conv blocks per level, parallel to ``n_f``.

    The defaults are tuples rather than lists so that no mutable object is
    shared between calls; lists are still accepted.
    """
    def __init__(self, n_in_out=Config.channels, n_f=(64, 64, 64, 64, 64), n_b=(1, 2, 2, 2, 2)):
        super().__init__()
        # Input = noisy patch + low-res patch + coords + sketch (n_in_out channels each) + 1 noise-level plane.
        self.c_cat = Blocks(n_in_out * 4 + 1, n_f[0], n_f[0], n_b[0])
        self.c_enc = nn.ModuleList([EncBlock(n_f[i-1], n_f[i], n_b[i]) for i in range(1, len(n_f))])
        # All decoders except the deepest one receive a skip connection, which doubles their input width.
        self.c_dec = nn.ModuleList([DecBlock(n_f[-i] * (1 if i == 1 else 2), n_f[-i-1], n_b[-i]) for i in range(1, len(n_f))])
        self.c_out = nn.Sequential(Blocks(n_f[0] * 2, n_f[0], n_f[0], n_b[0]), nn.Conv2d(n_f[0], n_in_out, 3, padding=1))
        # Start the output bias at 0.5 (mid-gray for images scaled to [0, 1]).
        nn.init.constant_(self.c_out[-1].bias, 0.5)
    def forward(self, x, noise_level, x_lowres, x_coords, x_cond):
        """Predict the denoised patch; returns a `Prediction`.

        ``noise_level`` is broadcast to a full single-channel plane by adding
        it to a zeroed slice of ``x``.
        """
        x = self.c_cat(th.cat([x, x_lowres, x_coords, x_cond, noise_level + 0 * x[:, :1]], 1))
        skips = []
        for c_enc in self.c_enc:
            skips.append(x)
            x = c_enc(x)
        for c_dec in self.c_dec:
            x = th.cat([c_dec(x), skips.pop()], 1)
        return Prediction(self.c_out(x))

# Instantiate the denoiser on the configured device.
model = UNet().to(Config.device)

# Our final model will use a smoothed average of recent weights
# (this is a ~free way to get a higher-quality final model)
def weight_average(w_prev, w_new, _):
    """Averaging rule: keep 90% of the running average and take 10% of the new weights."""
    return 0.9 * w_prev + 0.1 * w_new
# Wrapper that holds the averaged weights; `weight_average` is passed as its averaging function.
avg_model = th.optim.swa_utils.AveragedModel(model, avg_fn=weight_average)

In [None]:
import time
@th.no_grad()
def demo_model(model, n=16):
    """Print the parameter count, time the forward pass on random inputs, and show the outputs.

    Args:
        model: the denoiser to exercise.
        n: batch size of the random inputs.

    The reported time is the average over 10 forward passes. The model is put
    in eval mode for the demo and switched back to train mode afterwards.
    """
    def wait_for_device():
        # CUDA kernels are launched asynchronously; without synchronizing, the
        # timer would measure launch overhead rather than the actual compute.
        if th.cuda.is_available() and str(Config.device).startswith("cuda"):
            th.cuda.synchronize()

    model.eval()
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model has {n_parameters / 1e6:.1f} million trainable parameters ({4 * n_parameters / 1e6:.1f} MB).")
    x, x_lowres, x_coords, x_cond = (th.rand(n, Config.channels, Config.patch_hw, Config.patch_hw, device=Config.device) for _ in range(4))
    noise_level = th.rand(n, 1, 1, 1, device=Config.device)
    n_runs = 10
    wait_for_device()
    tick = time.perf_counter()
    for _ in range(n_runs):
        y = model(x, noise_level, x_lowres, x_coords, x_cond)
    wait_for_device()
    tock = time.perf_counter()
    print(f"Here are some model outputs on random noise in {1000 * (tock - tick) / n_runs:.1f}ms:")
    show(y.denoised.clamp(0, 1))
    model.train()

# Sanity-check the freshly initialized model.
demo_model(model)

# Diffusion Training Logic

To train our diffusion model, we need some code that adds noise to patches.

In [None]:
# One training batch: the clean target, its noised version, the noise and noise level used,
# the model's side inputs (lowres, coords, cond) and a flag marking samples that kept their low-res input.
NoisyPatch = namedtuple("NoisyPatch", ("patch", "noisy_patch", "noise", "noise_level", "lowres", "coords", "cond", "has_lowres"))

def alpha_blend(a, b, alpha):
    """Linear blend: ``alpha`` parts of ``a`` and ``1 - alpha`` parts of ``b``."""
    weight_a = alpha
    weight_b = 1 - alpha
    return weight_a * a + weight_b * b

def mix_blur(x, alpha):
    """Blend ``x`` with a 3x3 box-blurred copy of itself; ``alpha`` is the weight of the blurred copy."""
    blurred = F.avg_pool2d(x, 3, stride=1, padding=1, count_include_pad=False)
    return alpha_blend(blurred, x, alpha)

def augmented_lowres(x, ds=2):
    """Make a degraded low-resolution version of ``x``, returned at the size of ``x``.

    ``x`` is average-pooled by ``ds``, blurred and noised with spatially
    varying random strengths, and then upsampled again by ``ds``.
    """
    x_avg = F.avg_pool2d(x, ds, stride=ds)

    def random_strength_map():
        # Coarse random field (1/4 resolution) smoothly upsampled to the size of x_avg.
        coarse = th.rand(x_avg.shape[0], 1, x_avg.shape[-2]//4, x_avg.shape[-1]//4, device=x_avg.device)
        return F.interpolate(coarse, scale_factor=4, mode="bilinear")

    blur_alpha = random_strength_map()
    noise_alpha = random_strength_map()
    x_down = mix_blur(x_avg, blur_alpha)
    # One extra random scale per sample, on top of the per-pixel noise strength.
    per_sample_scale = th.rand_like(x_down[:, :1, :1, :1])
    x_down = x_down + 0.02 * noise_alpha * per_sample_scale * th.randn_like(x_down)
    return F.interpolate(x_down, scale_factor=ds)

@th.no_grad()
def add_noise_to_patches(patches, lowres_dropout=0.75):
    """Turn a batch of uint8 dataset patches into a `NoisyPatch` training batch.

    Args:
        patches: batch of (patch, coords, cond) tensors with values in 0..255.
        lowres_dropout: probability that a sample's low-res input is replaced by -1.

    All tensors are moved to the configured device and scaled to [0, 1]. Each
    sample gets its own random noise level, which sets how much uniform noise
    is blended into the clean patch.
    """
    patch, coords, cond = (t.to(Config.device) / 255.0 for t in patches)
    noise_level = th.rand_like(patch[:, :1, :1, :1])
    dropout_mask = th.rand_like(noise_level) < lowres_dropout
    has_lowres = ~dropout_mask
    lowres = dropout_mask * -1 + has_lowres * augmented_lowres(patch)
    noise = th.rand_like(patch)
    noisy_patch = alpha_blend(noise, patch, noise_level)
    return NoisyPatch(
        patch=patch,
        noisy_patch=noisy_patch,
        noise=noise,
        noise_level=noise_level,
        lowres=lowres,
        coords=coords,
        cond=cond,
        has_lowres=has_lowres,
    )

In [None]:
def demo_data_generation(dataset, n_demo=16):
    """Show one noised training batch: first the targets, then every model input.

    Args:
        dataset: dataset to draw a shuffled batch from.
        n_demo: requested number of samples (the batch is smaller if the dataset is).
    """
    patches = next(iter(th.utils.data.DataLoader(dataset, batch_size=n_demo, shuffle=True)))
    patches = add_noise_to_patches(patches)
    # Each caption is printed right before the images it describes.
    print("Here's what the targets look like during training")
    show(patches.patch)
    print("Here's what the inputs look like during training")
    show(patches.noisy_patch)
    # -1 keeps the real batch size, so a dataset smaller than n_demo still works.
    show(patches.noise_level.expand(-1, 3, 16, Config.patch_hw))
    show(patches.lowres.clamp(0, 1))
    show(patches.coords)
    show(patches.cond)

# Preview one noised training batch.
demo_data_generation(d_train)

# Diffusion Sampling Logic

To sample from our diffusion model, we need some code that iteratively removes noise from a given patch using the model.

In [None]:
@th.no_grad()
def generate_images_at_resolution(model, lowres_patches, coords, n_steps, generator=None):
    """Denoise uniform random noise into images over ``n_steps`` steps at one resolution.

    Args:
        model: denoiser returning an object with a ``denoised`` field.
        lowres_patches: low-res guidance; also fixes the shape of the output.
        coords: coordinate maps passed to the model.
        n_steps: number of denoising steps; the noise level goes linearly from 1 to 0.
        generator: optional RNG for the initial noise.

    The sketch conditioning is all ones here. Returns images clamped to [0, 1].
    """
    x = th.rand(*lowres_patches.shape, device=Config.device, generator=generator)
    schedule = th.linspace(1, 0, n_steps + 1, device=Config.device)
    blank_cond = th.ones_like(lowres_patches)
    for step in range(n_steps):
        nl_in, nl_out = schedule[step], schedule[step + 1]
        level = nl_in.expand(x[:, :1, :1, :1].shape)
        denoised_patches = model(x, level, lowres_patches, coords, blank_cond).denoised
        # Keep a fraction nl_out / nl_in of the current sample, the rest comes from the prediction.
        x = alpha_blend(x, denoised_patches, nl_out / nl_in)
    return x.clamp(0, 1)

@th.no_grad()
def generate_images(model, n_images=4, n_steps_per_resolution=(50,), generator=None):
    """Sample images with a cascade: generate at patch resolution, then upscale and refine.

    Args:
        model: denoiser to sample from.
        n_images: number of images to generate.
        n_steps_per_resolution: denoising steps for each stage. The first stage
            runs at ``Config.patch_hw`` without low-res guidance (input of -1);
            every later stage is ``Config.upscale_factor`` times larger and is
            guided by the upsampled result of the previous stage.
        generator: optional RNG for the initial noise of every stage.

    Returns:
        The images produced by the last stage.

    Raises:
        ValueError: if ``n_steps_per_resolution`` is empty.

    The model is put in eval mode while sampling and returned to whatever mode
    it was in before, even if sampling fails.
    """
    if len(n_steps_per_resolution) == 0:
        raise ValueError("n_steps_per_resolution must contain at least one entry")
    was_training = model.training
    model.eval()
    try:
        lowres_patches = -th.ones(n_images, Config.channels, Config.patch_hw, Config.patch_hw, device=Config.device)
        coords = coords_for_patch(0, 0, 1, 1).unsqueeze(0).to(Config.device).expand(lowres_patches.shape)
        for n_steps in n_steps_per_resolution:
            patches = generate_images_at_resolution(model, lowres_patches, coords, n_steps, generator=generator)
            # Prepare the guidance and coordinates for the next, larger stage.
            lowres_patches = F.interpolate(patches, scale_factor=Config.upscale_factor)
            coords = F.interpolate(coords, scale_factor=Config.upscale_factor, mode="bilinear", align_corners=False)
    finally:
        model.train(was_training)
    return patches

In [None]:
def demo_image_generation(model, n_demo=4, seed=9):
    """Show samples from four cascade setups, each started from the same seed.

    The setups use 1 to 4 resolution stages with 16, 8, 4 and 2 images
    respectively. ``n_demo`` is accepted but not used.
    """
    print("Here are some samples from the model")
    setups = (
        (16, [50]),
        (8, [50, 5]),
        (4, [50, 20, 10]),
        (2, [50, 10, 5, 1]),
    )
    for n_images, steps in setups:
        rng = th.Generator(device=Config.device).manual_seed(seed)
        samples = generate_images(model, n_images=n_images, n_steps_per_resolution=steps, generator=rng)
        show(samples.clamp(0, 1))
# Sample from the (still untrained) averaged model.
demo_image_generation(avg_model)

# Training

Finally, we write a training loop that loads patches from the dataset, adds noise to them, and trains our neural network to remove the noise.

In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import matplotlib.pyplot as plt
import random
from collections import defaultdict
from IPython.display import clear_output
import time
import datetime
def _mean(x):
    return sum(x) / len(x)

class Repeat(th.utils.data.Dataset):
    """Dataset view with a fixed virtual length ``n`` that cycles through the wrapped dataset."""

    def __init__(self, dataset, n=1<<20):
        self.dataset = dataset
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        # Indices past the end of the wrapped dataset wrap around to its start.
        wrapped_index = i % len(self.dataset)
        return self.dataset[wrapped_index]

class TinyProfiler:
    """Minimal wall-clock profiler based on named tick/tock pairs.

    Underscores in a key express nesting: the report indents a key once per
    underscore and labels it with the part after the last underscore.
    """

    def __init__(self):
        self.t = {}
        self.r = defaultdict(list)

    def tick(self, k):
        """Start (or restart) the timer called ``k``."""
        self.t[k] = time.time()

    def tock(self, k):
        """Record the time since the last ``tick(k)``; raises ValueError if ``k`` was never started."""
        if k not in self.t:
            raise ValueError(k)
        self.r[k].append(time.time() - self.t[k])

    def __repr__(self):
        rows = []
        for k in sorted(self.r):
            indent = "  " * k.count("_")
            label = k.split("_")[-1].ljust(16)
            # Average over the 10 most recent measurements, in milliseconds.
            avg_ms = 1000 * _mean(self.r[k][-10:])
            rows.append(f"{indent}{label}: \033[34m{avg_ms:0.1f}\033[0m\033[37mms\033[0m")
        return "Timing:\n" + "\n".join(rows)

def init_log_path():
    """Pick, create and return today's log directory, ``YYYY_MM_DD_dino_diffusion_<i>``.

    ``i`` is the smallest index whose directory either does not exist yet or
    holds at most one ``.jpg`` file, so a run that produced almost nothing is
    reused rather than skipped. The chosen path is printed and returned as a string.
    """
    i = 0
    log_path = datetime.datetime.now().strftime("%Y_%m_%d") + "_dino_diffusion"
    while Path(log_path + f"_{i}").exists() and len(list(Path(log_path + f"_{i}").glob("*.jpg"))) > 1:
        i += 1
    log_path += f"_{i}"
    print(log_path)
    # Equivalent of `mkdir -p`, without depending on a shell.
    Path(log_path).mkdir(parents=True, exist_ok=True)
    return log_path

class DinoTrainer:
    """Time-budgeted training loop for the denoiser with live demos, stats and checkpoints.

    Args:
        model: the network to train.
        dataset: dataset of patches; wrapped in `Repeat` and served by a shuffling DataLoader.
        ema_model: optional weight-averaged copy of ``model``. When given it is
            updated periodically, demoed and checkpointed; when None, those
            steps are skipped and the high-resolution demo uses ``model``.
        batch_size: training batch size.
        n_demo: maximum number of batch samples shown in the training preview.
        lr: learning rate for both optimizers.
    """
    def __init__(self, model, dataset, ema_model=None, batch_size=16, n_demo=16, lr=3e-4):
        self.stats = defaultdict(list)
        self.stat_steps = defaultdict(list)
        self.model = model
        self.ema_model = ema_model
        self.opt = th.optim.AdamW(self.model.parameters(), lr, amsgrad=True)
        self.dl = th.utils.data.DataLoader(Repeat(dataset), batch_size=batch_size, shuffle=True, drop_last=True, num_workers=8)
        self.batch_size = batch_size
        self.n_demo = n_demo
        self.log_path = init_log_path()
        self.prof = TinyProfiler()
        self.step = 0
        self.generator = th.Generator(device=Config.device)
        # adversarial loss to try and make upsampling model generate images faster
        self.adv_model = nn.Sequential(
            nn.Conv2d(3, 8, 4, stride=2), nn.ReLU(),
            nn.Conv2d(8, 16, 4, stride=2, groups=2), nn.ReLU(),
            nn.Conv2d(16, 32, 3, groups=4), nn.ReLU(),
            nn.Conv2d(32, 64, 3, groups=4), nn.ReLU(), nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(64, 1, 1)
        ).to(Config.device)
        self.adv_opt = th.optim.AdamW(self.adv_model.parameters(), lr)

    def loss(self, pred, gt, noise_level, has_lowres):
        """MSE between prediction and target, plus an adversarial term for upsampling samples.

        The adversarial term penalizes the squared difference between the
        adversary's score of the prediction and of the target, and is masked
        by ``has_lowres``. ``noise_level`` is accepted but not used.
        """
        a = self.adv_model(pred)
        with th.no_grad():
            b = self.adv_model(gt)
        # adversarial loss is only applied when upsampling
        return F.mse_loss(pred, gt) + 0.25 * th.pow(has_lowres * (a - b), 2).mean()

    def train(self, n_steps=1000_000, render_interval_s=30, demo_interval_s=60, save_interval_s=300, runtime_s=15*60, weight_avg_interval_steps=100):
        """Train until ``n_steps`` steps have run or ``runtime_s`` seconds have passed.

        Args:
            n_steps: maximum number of optimizer steps for this call.
            render_interval_s: seconds between refreshes of the live demo and stats plot.
            demo_interval_s: minimum seconds between demo images written to the log directory.
            save_interval_s: minimum seconds between checkpoints.
            runtime_s: wall-clock budget; when exceeded the model is saved and training stops.
            weight_avg_interval_steps: update ``ema_model`` every this many steps.

        Demo images and checkpoints are only written on render steps, so their
        intervals are effectively rounded up to the render interval.
        """
        start = time.time()
        last_disp_time = 0
        # Large negative values make the first step render, write a demo and save a checkpoint.
        last_demo_time = -1000
        last_save_time = -1000
        last_render_time = -1000

        dl_gen = iter(self.dl)
        loss_acc = []
        for rel_step in range(1, 1 + n_steps):
            self.step += 1
            seconds_elapsed = (time.time() - start)
            render_step = seconds_elapsed > last_render_time + render_interval_s

            self.prof.tick("step")

            self.prof.tick("dl")
            self.prof.tick("dl_next")
            try:
                xb = next(dl_gen)
            except StopIteration:
                dl_gen = iter(self.dl)
                xb = next(dl_gen)
            self.prof.tock("dl_next")
            self.prof.tick("dl_copy")
            xb = add_noise_to_patches(xb)
            self.prof.tock("dl_copy")
            self.prof.tock("dl")

            self.prof.tick("model")

            self.prof.tick("model_grad")
            self.model.train()
            xpb = self.model(xb.noisy_patch, xb.noise_level, xb.lowres, xb.coords, xb.cond).denoised
            if render_step:
                # Adding a zero tensor that requires grad lets us read the loss
                # gradient with respect to the prediction for the preview below.
                grad_catcher = th.zeros_like(xpb).requires_grad_(True)
                xpb = xpb + grad_catcher
            loss = self.loss(xpb, xb.patch, xb.noise_level, xb.has_lowres)
            loss_acc.append(loss.item())
            self.prof.tock("model_grad")
            self.prof.tick("model_bwd")
            self.opt.zero_grad()
            loss.backward()
            self.opt.step()
            self.prof.tock("model_bwd")
            self.prof.tock("model")

            # update adversarial trainer:
            # it learns to output the noise level for predictions and 0 for real patches
            adv_loss = (
                F.mse_loss(self.adv_model(xpb.detach()), xb.noise_level) +
                F.mse_loss(self.adv_model(xb.patch), 0 * xb.noise_level)
            )
            self.adv_opt.zero_grad()
            adv_loss.backward()
            self.adv_opt.step()

            self.prof.tock("step")

            if self.ema_model is not None and rel_step % weight_avg_interval_steps == 0:
                self.ema_model.update_parameters(self.model)

            minutes_elapsed = seconds_elapsed / 60
            if render_step:
                # losses as usual
                self.stats["model_loss"].append(sum(loss_acc) / len(loss_acc))
                loss_acc = []
                self.stat_steps["model_loss"].append(self.step)
                last_render_time = seconds_elapsed
                # Re-seed before every demo so the demos start from the same noise.
                self.generator.manual_seed(1337)
                demo_img = tv.to_pil_image(th.cat(tuple(generate_images(self.model, n_images=16, n_steps_per_resolution=[32], generator=self.generator)), -1).clamp(0, 1))
                if self.ema_model is not None:
                    self.generator.manual_seed(1337)
                    demo_img_ema = tv.to_pil_image(th.cat(tuple(generate_images(self.ema_model, n_images=16, n_steps_per_resolution=[32], generator=self.generator)), -1).clamp(0, 1))
                self.generator.manual_seed(1337)
                # The high-resolution demo prefers the averaged weights but falls back
                # to the raw model when no averaged model was supplied.
                hr_model = self.model if self.ema_model is None else self.ema_model
                demo_img_hr = tv.to_pil_image(th.cat(tuple(generate_images(hr_model, n_images=4, n_steps_per_resolution=[32, 16, 8], generator=self.generator)), -1).clamp(0, 1))
                clear_output(wait=True)
                print("demo on fixed noise")
                display(demo_img)
                if self.ema_model is not None:
                    display(demo_img_ema)
                display(demo_img_hr)
                with th.no_grad():
                    print("model training")
                    mask = th.randperm(len(xb.patch))[:self.n_demo].sort(0).values
                    print("input (input_nl, input_lr, input_coords, input)")
                    display(tv.to_pil_image(th.cat(tuple(xb.noise_level.detach().expand(xb.patch.shape)[mask, :3, :16]), -1)))
                    display(tv.to_pil_image(th.cat(tuple(xb.lowres.detach()[mask, :3]), -1).clamp(0, 1)))
                    display(tv.to_pil_image(th.cat(tuple(xb.coords.detach()[mask, :3]), -1).clamp(0, 1)))
                    display(tv.to_pil_image(th.cat(tuple(xb.cond.detach()[mask, :3]), -1).clamp(0, 1)))
                    display(tv.to_pil_image(th.cat(tuple(xb.noisy_patch.detach()[mask, :3]), -1).clamp(0, 1)))
                    print("pred")
                    display(tv.to_pil_image(th.cat(tuple(xpb.detach()[mask]), -1).clamp(0, 1)))
                    print("target")
                    display(tv.to_pil_image(th.cat(tuple(xb.patch.detach()[mask, :3]), -1)))
                    grad_im = grad_catcher.grad.detach()
                    print(f"grad (min {grad_im.min().item():.6f} mean {grad_im.mean().item():.6f} max {grad_im.max().item():.6f})")
                    # show flipped gradient
                    grad_im = 0.5 - grad_im / (th.abs(grad_im).max() + 1e-6)
                    display(tv.to_pil_image(th.cat(tuple(grad_im[mask]), -1).clamp(0, 1)))
                if seconds_elapsed > last_demo_time + demo_interval_s:
                    demo_img.save(f"{self.log_path}/{self.step:06d}.jpg", quality=95)
                    if self.ema_model is not None:
                        demo_img_ema.save(f"{self.log_path}/ema_{self.step:06d}.jpg", quality=95)
                    demo_img_hr.save(f"{self.log_path}/hr_{self.step:06d}.jpg", quality=95)
                    th.save((self.stats, self.stat_steps), f"{self.log_path}/stats.pth")
                    last_demo_time = seconds_elapsed
                if seconds_elapsed > last_save_time + save_interval_s:
                    print("saving checkpoints...")
                    th.save(self.model.state_dict(), f"{self.log_path}/model_checkpoint_last.pth")
                    if self.ema_model is not None:
                        th.save(self.ema_model.state_dict(), f"{self.log_path}/ema_model_checkpoint_last.pth")
                        # export_model lives in a later notebook cell, so it may not be
                        # defined yet when training starts; skip the export in that case.
                        exporter = globals().get("export_model")
                        if callable(exporter):
                            exporter(self.ema_model, f"{self.log_path}/ema_model_checkpoint_last.onnx")
                        else:
                            print("export_model() is not defined yet; skipping ONNX export")
                    print("saved")
                    last_save_time = seconds_elapsed
                plt.title("Stats")
                for k, vs in self.stats.items():
                    plt.plot(self.stat_steps[k], vs, label=f"{k} ({vs[-1]:.5f})", linewidth=2, alpha=0.75)
                plt.ylim(0, 2 * max(self.stats["model_loss"][-3:]))
                plt.gcf().set_size_inches(12, 4)
                plt.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left')
                plt.gcf().savefig(f"{self.log_path}/stats.jpg", bbox_inches='tight')
                plt.show()
                print(self.prof)
                print()
            if seconds_elapsed > last_disp_time + 1:
                # Refresh the one-line progress display at most once per second.
                last_disp_time = seconds_elapsed
                def make_pbar(prop, w=20):
                    p = -int(-w * prop)
                    return f"\033[42m" + " " * p + "\033[40m" + " " * (w - p) + "\033[0m"
                render_interval_s_pos = seconds_elapsed - last_render_time
                print(
                    "\r"
                    f"{time.time() - start:.0f}s / {runtime_s:.0f}s. Step {self.step: 5d}. Demo in {render_interval_s - render_interval_s_pos:.1f}s "
                    f"{make_pbar(render_interval_s_pos / render_interval_s)}; 🦔 \033[34m{rel_step / minutes_elapsed:.3f}\033[0m Steps / Min. BS {self.batch_size}. "
                    f"\033[35m{self.batch_size * rel_step / minutes_elapsed:.3f}\033[0m Images / Min. "
                    f"Loss {loss.item():0.4f} "
                    f"{th.cuda.memory_reserved() / 1e9:.2f}GB"
                , end="")

            if time.time() - start > runtime_s:
                print("\n ✅ Training completed")
                self.model.eval()
                th.save(self.model.state_dict(), f"{self.log_path}/model_checkpoint_last.pth")
                break

# Build the trainer; the averaged model is passed in as `ema_model`.
trainer = DinoTrainer(model, d_train, ema_model=avg_model)

In [None]:
# Train for a wall-clock budget of 960 seconds (16 minutes).
trainer.train(runtime_s=960) #1*60*60)

In [None]:
# Samples from the averaged model with one, two and three resolution stages.
show(generate_images(avg_model, n_images=16, n_steps_per_resolution=[50]).clamp(0, 1))
show(generate_images(avg_model, n_images=8, n_steps_per_resolution=[32, 8]).clamp(0, 1))
show(generate_images(avg_model, n_images=4, n_steps_per_resolution=[32, 10, 1]).clamp(0, 1))

In [None]:
!pip install onnx

In [None]:
def export_model(model, name="model_test.onnx"):
    """Export ``model`` to an ONNX file at ``name``, traced with random single-image inputs.

    The exported graph has the inputs ``x``, ``noise_level``, ``x_lowres``,
    ``x_coords`` and ``x_cond`` and the output ``denoised``; opset 9 is used.
    """
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model has {n_parameters / 1e6:.1f} million trainable parameters.")

    def random_image():
        return th.rand(1, Config.channels, Config.patch_hw, Config.patch_hw, device=Config.device)

    x = random_image()
    x_lowres = random_image()
    x_coords = random_image()
    x_cond = random_image()
    noise_level = th.rand(1, 1, 1, 1, device=Config.device)
    example_inputs = (x, noise_level, x_lowres, x_coords, x_cond)
    th.onnx.export(
        model,
        example_inputs,
        name,
        input_names=["x", "noise_level", "x_lowres", "x_coords", "x_cond"],
        output_names=["denoised"],
        opset_version=9,
    )
    print("Exported model to", name)

# Export the (non-averaged) model to the default file, model_test.onnx.
export_model(model)