We estimate the number of parameters in EDM diffusion models
and estimate the inference cost of these models.

### Imports

In [1]:
import pickle

import numpy as np
import PIL.Image
import torch
from tqdm.auto import tqdm

from torch import nn

import dnnlib

### Load network

In [2]:
MODEL_ROOT = "https://nvlabs-fi-cdn.nvidia.com/edm/pretrained"


def _load_ema_network(url: str):
    """Fetch the pickle at *url* and return the object stored under its ``"ema"`` key.

    NOTE: ``pickle.load`` can execute arbitrary code while unpickling, so only
    point this at a trusted source such as the NVIDIA CDN in ``MODEL_ROOT``.
    """
    with dnnlib.util.open_url(url) as stream:
        return pickle.load(stream)["ema"]


# Conditional and unconditional CIFAR-10 32x32 VP checkpoints ("ema" entry of each).
net_cond = _load_ema_network(f"{MODEL_ROOT}/edm-cifar10-32x32-cond-vp.pkl")
net_uncond = _load_ema_network(f"{MODEL_ROOT}/edm-cifar10-32x32-uncond-vp.pkl")

In [3]:
def get_num_params(net: nn.Module) -> int:
    """Return the total number of scalar entries across all parameters of *net*.

    Every parameter yielded by ``net.parameters()`` is counted, regardless of
    ``requires_grad``. ``Module.parameters()`` yields a parameter object shared
    by several submodules only once, so shared weights are not double-counted.
    """
    total = 0
    for param in net.parameters():
        total += param.numel()
    return total

# Report the parameter count of each loaded network (prefixes pad labels to one width).
for _prefix, _model in (("cond:   ", net_cond), ("uncond: ", net_uncond)):
    print(f"{_prefix}{get_num_params(_model):,d} params")

cond:   55,735,299 params
uncond: 55,733,891 params


### Measure inference cost

In [4]:
# Sampler settings. The sampling cell below reads `net`, `num_steps`, `S_churn`,
# `S_min`, `S_max`, `S_noise`, `latents`, `class_labels`, and `t_steps` from here,
# so these module-level names must stay as they are.
device = torch.device("cuda")
net = net_uncond.cuda()

# Output grid / batch layout.
seed = 0
gridw, gridh = 10, 10
batch_size = gridw * gridh

# Discretization and stochasticity ("churn") parameters.
num_steps = 200
rho = 7
S_churn, S_min, S_max, S_noise = 0, 0, float("inf"), 1

# Requested noise range, clamped to the range the network reports supporting.
sigma_min = max(0.002, net.sigma_min)
sigma_max = min(80, net.sigma_max)

# Initial noise: one latent per grid cell, drawn right after seeding the global RNG.
torch.manual_seed(seed)
latent_shape = [batch_size, net.img_channels, net.img_resolution, net.img_resolution]
latents = torch.randn(latent_shape, device=device)
class_labels = None  # sampling from the unconditional network

# Time step discretization: interpolate linearly between sigma_max**(1/rho) and
# sigma_min**(1/rho) over num_steps points, then raise the result to the power rho.
step_indices = torch.arange(num_steps, dtype=torch.float64, device=device)
sigma_max_root = sigma_max ** (1 / rho)
sigma_min_root = sigma_min ** (1 / rho)
t_steps = (
    sigma_max_root + step_indices / (num_steps - 1) * (sigma_min_root - sigma_max_root)
) ** rho
# Round through the network's sigma representation and append the terminal t_N = 0.
t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])])

In [5]:
import time

# Run the 2nd-order (Heun) stochastic sampler and report its inference cost.
# The sampling math is unchanged; this cell additionally counts denoiser
# evaluations (NFEs, each over the full batch of `latents`) and measures
# wall-clock time. CUDA kernel launches are asynchronous, so tqdm's per-step
# rate alone is not a reliable cost measure -- the clock is bracketed by
# explicit synchronizations instead. The first network call (and any one-time
# warm-up it triggers) is included in the total.
num_evals = 0

with torch.no_grad():
    torch.cuda.synchronize(device)  # start timing from an idle GPU
    start_time = time.perf_counter()

    x_next = latents.to(torch.float64) * t_steps[0]
    for i, (t_cur, t_next) in tqdm(
        list(enumerate(zip(t_steps[:-1], t_steps[1:]))), unit="step"
    ):  # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        gamma = (
            min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
        )
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * torch.randn_like(
            x_cur
        )

        # Euler step.
        denoised = net(x_hat, t_hat, class_labels).to(torch.float64)
        num_evals += 1
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction; skipped on the last step, where t_next = 0
        # (the terminal entry appended to t_steps) and dividing by it would fail.
        if i < num_steps - 1:
            denoised = net(x_next, t_next, class_labels).to(torch.float64)
            num_evals += 1
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

    torch.cuda.synchronize(device)  # wait for all queued kernels before stopping
    elapsed = time.perf_counter() - start_time

print(
    f"{num_steps} steps, {num_evals} NFEs, batch of {batch_size}: "
    f"{elapsed:.2f} s total, {1e3 * elapsed / num_evals:.1f} ms/NFE, "
    f"{elapsed / batch_size:.3f} s/image"
)

  0%|          | 0/200 [00:00<?, ?step/s]