In [None]:
import math, random
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils as vutils
from dataclasses import dataclass
from tqdm import tqdm
import os
from torchvision.models import inception_v3, Inception_V3_Weights

In [None]:
# Let cuDNN autotune conv algorithms for speed. Per PyTorch's reproducibility
# notes this can select non-deterministic kernels, so seeding alone does not
# guarantee bit-identical runs.
torch.backends.cudnn.benchmark = True


def set_seed(seed):
    """Seed Python's `random` module and torch's CPU and CUDA generators with `seed`."""
    for seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)


SEED = 42
set_seed(SEED)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"DEVICE : {device}")

DEVICE : cuda


In [None]:
class DiffusionSchedule:
    """Linear-beta DDPM noise schedule with precomputed per-timestep coefficients.

    Every coefficient tensor is 1-D of length ``T`` and is indexed by an integer
    timestep ``t`` in ``[0, T)``.

    Args:
        T: number of diffusion steps.
        beta_start: beta at t = 0.
        beta_end: beta at t = T - 1.
        use_elbo_wts: stored as an attribute; not read by this class.
        group_norm_batch_size: passed by the model blocks as the first argument
            (``num_groups``) of ``nn.GroupNorm``. Despite the name it is a group
            count, not a batch size. ``None`` falls back to 8.
    """

    def __init__(self,
                 T : int = 1000,
                 beta_start : float = 1e-4,
                 beta_end : float = 2e-2,
                 use_elbo_wts : bool = False,
                 group_norm_batch_size : int = 8):
        self.T = T
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.use_elbo_wts = use_elbo_wts
        self.group_norm_batch_size = 8 if group_norm_batch_size is None else group_norm_batch_size

        betas = torch.linspace(beta_start, beta_end, T, dtype = torch.float32)
        alphas = 1 - betas
        a_bar = torch.cumprod(alphas, dim = 0)
        # a_bar_prev[t] = a_bar[t - 1], with the convention a_bar[-1] = 1.
        a_bar_prev = torch.cat([torch.tensor([1.0]), a_bar[: -1]], dim = 0)

        self.betas = betas
        self.alphas = alphas
        self.a_bar = a_bar
        self.sqrt_a_bar = a_bar.sqrt()
        self.sqrt_one_minus_a_bar = (1 - a_bar).sqrt()
        self.a_bar_prev = a_bar_prev
        # beta_tilde_t = beta_t * (1 - a_bar_{t-1}) / (1 - a_bar_t); exactly 0 at t = 0.
        self.post_var = betas * ((1.0 - a_bar_prev) / (1.0 - a_bar))

    def to(self, device):
        """Move every tensor attribute to ``device`` in place and return ``self`` for chaining."""
        tensor_names = [name for name, value in vars(self).items() if torch.is_tensor(value)]
        for name in tensor_names:
            setattr(self, name, getattr(self, name).to(device))
        return self

In [None]:
# Module-level schedule read by ConvResBlock/UNET, q_sample, loss_eps and sample_grid.
sched = DiffusionSchedule()
sched = sched.to(device)

In [None]:
class SineTimeEmbedding(nn.Module):
    """Sinusoidal timestep embedding followed by a two-layer SiLU MLP.

    Args:
        dim: size of the sinusoidal features and of the output. Odd values are
            supported: the single feature that sin/cos cannot fill is zero-padded.
    """

    def __init__(self, dim : int):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(nn.Linear(dim, 2 * dim), nn.SiLU(), nn.Linear(2 * dim, dim))

    def forward(self, time_vals : torch.Tensor, base : float = 10000.0):
        """Embed a batch of timesteps.

        Args:
            time_vals: 1-D tensor ``[B]`` of timestep values.
            base: geometric base of the frequency ladder.

        Returns:
            ``[B, dim]`` embedding.
        """
        half = self.dim // 2
        # Frequencies base ** (-k / half) for k = 0 .. half - 1.
        exponents = torch.arange(0, half, dtype = torch.float32, device = time_vals.device) / half
        freqs = torch.exp(-exponents * math.log(base))
        args = time_vals[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim = -1)  # [B, 2 * half]
        if self.dim % 2:
            # For odd dim, sin/cos fill only dim - 1 features. Zero-pad the last one
            # so the MLP input width matches nn.Linear(dim, ...).
            emb = F.pad(emb, (0, 1))
        return self.mlp(emb)

In [None]:
class ClassEmbedding(nn.Module):
    """Learned label embedding with one extra row for the classifier-free-guidance null label.

    Labels ``0 .. num_classes - 1`` are real classes. Label ``num_classes`` is the
    "no label" token that training and sampling use for the unconditional branch.
    """

    def __init__(self, num_classes : int = 10, dim : int = 128):
        super().__init__()
        table_rows = num_classes + 1  # extra row reserved for the CFG null label
        self.embed = nn.Embedding(table_rows, dim)

    def forward(self, y):
        """Look up integer labels ``y`` of shape ``[B]`` and return ``[B, dim]`` embeddings."""
        return self.embed(y)

In [None]:
def count_params(m):
    """Return the total number of scalar elements across all of ``m``'s parameters."""
    total = 0
    for p in m.parameters():
        total += p.numel()
    return total

In [None]:
class ConvResBlock(nn.Module):
    """Residual conv block with a conditioning scale-and-shift.

    Main path: GroupNorm -> SiLU -> 3x3 conv -> GroupNorm -> (1 + gamma) * h + beta
    -> SiLU -> 3x3 conv. A skip path (1x1 conv only when channels change) is added
    to the result. ``gamma`` and ``beta`` come from a SiLU + Linear projection of
    the conditioning vector.

    Args:
        in_ch: input channels.
        out_ch: output channels.
        cond_dim: size of the conditioning vector given to ``forward``.
        num_groups: group count for both GroupNorm layers. ``None`` (default) uses
            the module-level ``sched.group_norm_batch_size``. ``in_ch`` and ``out_ch``
            must both be divisible by it.
    """

    def __init__(self, in_ch, out_ch, cond_dim, num_groups = None):
        super().__init__()
        if num_groups is None:
            num_groups = sched.group_norm_batch_size
        self.norm_1 = nn.GroupNorm(num_groups, in_ch)
        self.conv_1 = nn.Conv2d(in_ch, out_ch, kernel_size = 3, padding = 1)
        self.norm_2 = nn.GroupNorm(num_groups, out_ch)
        self.conv_2 = nn.Conv2d(out_ch, out_ch, kernel_size = 3, padding = 1)
        # 1x1 projection only when the channel count changes.
        self.skip = nn.Conv2d(in_ch, out_ch, kernel_size = 1) if in_ch != out_ch else nn.Identity()
        # Projects the conditioning vector to per-channel (gamma, beta).
        self.cond = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, 2 * out_ch))

    def forward(self, x : torch.Tensor, cond_vec : torch.Tensor):
        """Apply the block.

        Args:
            x: ``[B, in_ch, H, W]`` feature map.
            cond_vec: ``[B, cond_dim]`` conditioning vector.

        Returns:
            ``[B, out_ch, H, W]`` feature map.
        """
        h = self.conv_1(F.silu(self.norm_1(x)))
        h = self.norm_2(h)
        gamma, beta = self.cond(cond_vec).chunk(2, dim = -1)
        # (1 + gamma): a zero projection leaves the normalized features unscaled.
        h = ((1 + gamma[:, :, None, None]) * h) + beta[:, :, None, None]
        h = self.conv_2(F.silu(h))
        return h + self.skip(x)

In [None]:
class UpSample(nn.Module):
    """Double the spatial size with nearest-neighbour upsampling, then a 3x3 conv (channels unchanged)."""

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, kernel_size = 3, padding = 1)

    def forward(self, x : torch.Tensor):
        upsampled = F.interpolate(x, scale_factor = 2, mode = "nearest")
        return self.conv(upsampled)

class DownSample(nn.Module):
    """Halve the spatial size (rounding up) with a stride-2 3x3 conv; channels are unchanged."""

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, 3, stride = 2, padding = 1)

    def forward(self, x : torch.Tensor):
        downsampled = self.conv(x)
        return downsampled

In [None]:
class UNET(nn.Module):
    """Small class- and time-conditioned U-Net that predicts the noise in ``x_t``.

    Resolutions go H -> H/2 -> H/4 and back, so H and W must be divisible by 4
    (28 -> 14 -> 7 for MNIST). The timestep and class embeddings are summed into
    one conditioning vector that is fed to every ConvResBlock.

    Args:
        num_classes: number of real classes; label ``num_classes`` is the CFG null label.
        in_channels: image channels, for both the input and the predicted noise.
        base: channel width at full resolution (x2 at H/2, x4 at H/4).
        cond_dim: size of the conditioning vector.
    """

    def __init__(self, num_classes : int = 10, in_channels : int = 1, base : int = 32, cond_dim : int = 128):
        super().__init__()
        c1, c2, c4 = base, 2 * base, 4 * base  # channel widths at H, H/2, H/4

        # Conditioning embeddings (summed in forward).
        self.time_emb = SineTimeEmbedding(cond_dim)
        self.class_emb = ClassEmbedding(num_classes, cond_dim)

        self.in_conv = nn.Conv2d(in_channels, c1, kernel_size = 3, padding = 1)

        # Encoder.
        self.d1 = ConvResBlock(c1, c1, cond_dim)
        self.d2 = ConvResBlock(c1, c2, cond_dim)
        self.down_1 = DownSample(c2)
        self.d3 = ConvResBlock(c2, c2, cond_dim)
        self.d4 = ConvResBlock(c2, c4, cond_dim)
        self.down_2 = DownSample(c4)

        # Bottleneck.
        self.mid = ConvResBlock(c4, c4, cond_dim)

        # Decoder: u2 / u4 also receive the concatenated encoder skip features.
        self.u1 = ConvResBlock(c4, c4, cond_dim)
        self.up_1 = UpSample(c4)
        self.u2 = ConvResBlock(c4 + c4, c2, cond_dim)
        self.u3 = ConvResBlock(c2, c2, cond_dim)
        self.up_2 = UpSample(c2)
        self.u4 = ConvResBlock(c2 + c2, c1, cond_dim)

        self.out_norm = nn.GroupNorm(sched.group_norm_batch_size, c1)
        self.out_conv = nn.Conv2d(c1, in_channels, kernel_size = 3, padding = 1)

    def forward(self, x : torch.Tensor, t : torch.Tensor, y : torch.Tensor):
        """Predict noise.

        Args:
            x: ``[B, in_channels, H, W]`` noisy images.
            t: ``[B]`` integer timesteps.
            y: ``[B]`` integer labels (``num_classes`` = unconditional).

        Returns:
            ``[B, in_channels, H, W]`` noise prediction.
        """
        cond = self.time_emb(t) + self.class_emb(y)

        # Encoder; skip_full / skip_half are reused by the decoder.
        h = self.d1(self.in_conv(x), cond)          # [B, base, H, W]
        skip_full = self.d2(h, cond)                # [B, 2 * base, H, W]
        h = self.d3(self.down_1(skip_full), cond)   # [B, 2 * base, H / 2, W / 2]
        skip_half = self.d4(h, cond)                # [B, 4 * base, H / 2, W / 2]

        h = self.mid(self.down_2(skip_half), cond)  # [B, 4 * base, H / 4, W / 4]

        # Decoder.
        h = self.up_1(self.u1(h, cond))                        # [B, 4 * base, H / 2, W / 2]
        h = self.u2(torch.cat([h, skip_half], dim = 1), cond)  # [B, 2 * base, H / 2, W / 2]
        h = self.up_2(self.u3(h, cond))                        # [B, 2 * base, H, W]
        h = self.u4(torch.cat([h, skip_full], dim = 1), cond)  # [B, base, H, W]

        return self.out_conv(F.silu(self.out_norm(h)))

In [None]:
def q_sample(x_0 : torch.Tensor, t : torch.Tensor, noise = None):
    """Draw ``x_t ~ q(x_t | x_0) = N(sqrt(a_bar_t) * x_0, (1 - a_bar_t) * I)``.

    Args:
        x_0: ``[B, C, H, W]`` clean images.
        t: ``[B]`` integer timesteps indexing the global ``sched``.
        noise: optional Gaussian noise shaped like ``x_0``; sampled when ``None``.

    Returns:
        ``(x_t, noise)``, where ``noise`` is the exact noise that was mixed in.
    """
    noise = torch.randn_like(x_0) if noise is None else noise
    signal_scale = sched.sqrt_a_bar[t].view(-1, 1, 1, 1)
    noise_scale = sched.sqrt_one_minus_a_bar[t].view(-1, 1, 1, 1)
    x_t = (x_0 * signal_scale) + (noise * noise_scale)
    return x_t, noise

In [None]:
def loss_eps(model, x_0, t, y, weighted = False):
    """Noise-prediction MSE loss for one batch.

    Args:
        model: network called as ``model(x_t, t, y)`` whose output is compared
            with the noise used to build ``x_t``.
        x_0: ``[B, C, H, W]`` clean images.
        t: ``[B]`` integer timesteps.
        y: ``[B]`` integer labels.
        weighted: if True, scale each sample's MSE by the DDPM ELBO weight
            ``beta_t / (2 * alpha_t * (1 - a_bar_{t-1}))``, i.e. the KL weight with
            ``sigma_t^2 = beta_tilde_t``.

    Returns:
        Scalar mean loss over the batch.
    """
    x_t, eps = q_sample(x_0, t)
    eps_pred = model(x_t, t, y)
    mse = (eps_pred - eps).pow(2).mean(dim = (1, 2, 3))
    if weighted:
        # At t = 0, a_bar_prev = 1, so beta_tilde_0 = 0 and the weight above is
        # infinite. There, use sigma_t^2 = beta_t instead, which swaps
        # (1 - a_bar_{t-1}) for (1 - a_bar_t) and gives a weight of 1 / (2 * alpha_0).
        one_minus_prev = torch.where(t == 0, 1 - sched.a_bar[t], 1 - sched.a_bar_prev[t])
        wt = sched.betas[t] / (sched.alphas[t] * one_minus_prev * 2.0)
        mse = mse * wt
    return mse.mean()

In [None]:
def mnist_loader(batch_size = 128, root = "./data/"):
    """Return ``(train_loader, test_loader)`` for MNIST with pixels scaled to [-1, 1].

    The datasets are downloaded into ``root`` if missing. ``ToTensor`` yields
    [0, 1] and ``Normalize(0.5, 0.5)`` maps that to [-1, 1]. A torchvision
    transform is used instead of a lambda so the pipeline can be pickled for
    DataLoader worker processes on platforms that spawn workers (e.g. Windows,
    macOS), where a lambda fails to pickle.
    """
    tfm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5 == 2x - 1
    ])

    train = datasets.MNIST(root, train = True, transform = tfm, download = True)
    test = datasets.MNIST(root, train = False, transform = tfm, download = True)

    return (DataLoader(train, batch_size = batch_size, shuffle = True, num_workers = 2, pin_memory = True),
            DataLoader(test, batch_size = batch_size, shuffle = False, num_workers = 2, pin_memory = True))

In [None]:
# 128-image batches read from ./data/ (downloaded on the first run).
train_loader, test_loader = mnist_loader(batch_size = 128, root = "./data/")

100%|██████████| 9.91M/9.91M [00:00<00:00, 57.0MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.64MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.5MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.86MB/s]


In [None]:
class EMA:
    """Exponential moving average of a model's parameters.

    Keeps one shadow tensor per entry of ``model.parameters()``; buffers are not
    tracked. Typical evaluation pattern: ``store()`` -> ``copy_to()`` -> evaluate
    -> ``restore()``. Every method runs under ``torch.no_grad()``, so the
    in-place parameter copies are legal outside a no-grad block too.

    Args:
        model: module whose parameters are tracked.
        decay: EMA decay; ``shadow <- decay * shadow + (1 - decay) * param``.
    """

    def __init__(self, model, decay : float = 0.999):
        self.model = model
        self.decay = decay
        # Shadow weights, swapped into the model for inference via copy_to().
        self.shadow = [p.detach().clone() for p in model.parameters()]

    @torch.no_grad()
    def update(self):
        """Fold the current parameters into the shadow weights (in place, no new allocations)."""
        for s, p in zip(self.shadow, self.model.parameters()):
            s.mul_(self.decay).add_(p.detach(), alpha = 1 - self.decay)

    @torch.no_grad()
    def store(self):
        """Snapshot the live parameters so ``restore()`` can put them back."""
        self.backup = [p.detach().clone() for p in self.model.parameters()]

    @torch.no_grad()
    def copy_to(self):
        """Overwrite the model's parameters with the shadow (EMA) weights."""
        for p, s in zip(self.model.parameters(), self.shadow):
            p.copy_(s)

    @torch.no_grad()
    def restore(self):
        """Put back the parameters captured by the most recent ``store()``.

        Raises:
            RuntimeError: if ``store()`` was never called.
        """
        if not hasattr(self, "backup"):
            raise RuntimeError("EMA.restore() called before EMA.store()")
        for p, b in zip(self.model.parameters(), self.backup):
            p.copy_(b)

In [None]:
@torch.no_grad()
def sample_grid(model, y, num_imgs = 16, img_shape = (1, 28, 28), steps = None, cfg_scale = 3.0, sigma_choice = "beta_tilde", num_classes = 10, device = device):
    """Ancestral DDPM sampling with classifier-free guidance.

    Args:
        model: noise predictor called as ``model(x_t, t, y)``.
        y: ``[num_imgs]`` integer labels; label ``num_classes`` means unconditional.
        num_imgs: number of images to draw; must equal ``len(y)``.
        img_shape: ``(C, H, W)`` of each sample.
        steps: number of denoising steps; ``None`` means ``sched.T``. When fewer
            than ``T``, timesteps are visited with stride ``T // steps``, and each
            jump ``t -> t_prev`` uses coefficients derived from ``a_bar`` at both
            endpoints, so the strided chain stays a valid DDPM reverse process.
        cfg_scale: guidance weight ``w`` in ``eps_u + w * (eps_c - eps_u)``.
        sigma_choice: ``"beta_tilde"`` uses the posterior variance; anything else uses beta.
        num_classes: index of the null (unconditional) label.
        device: device to sample on.

    Returns:
        ``[num_imgs, *img_shape]`` samples clamped to [-1, 1].

    Raises:
        ValueError: if ``len(y) != num_imgs`` or ``steps < 1``.

    The model is put in eval mode while sampling, and its previous train/eval
    mode is restored afterwards.
    """
    if y.shape[0] != num_imgs:
        raise ValueError(f"len(y) = {y.shape[0]} does not match num_imgs = {num_imgs}")
    steps = steps if steps is not None else sched.T
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")
    per_step = max(1, sched.T // steps)
    # Visited timesteps, high to low, as Python ints (avoids a GPU sync per step).
    t_seq = list(range(sched.T - 1, -1, -per_step))

    was_training = model.training
    model.eval()
    try:
        x = torch.randn([num_imgs, *img_shape], device = device)
        y = y.to(device)
        null_id = num_classes  # label reserved for the unconditional branch
        y_in = torch.cat([torch.full_like(y, null_id), y], dim = 0)

        for i, ti in enumerate(t_seq):
            # Timestep the chain moves to next; -1 means this step produces x_0.
            t_prev = t_seq[i + 1] if i + 1 < len(t_seq) else -1

            # One batched forward pass: first half unconditional, second half conditional.
            t_in = torch.full((2 * num_imgs, ), ti, dtype = torch.int64, device = device)
            eps_uncond, eps_cond = model(torch.cat([x, x], dim = 0), t_in, y_in).chunk(2, dim = 0)
            eps_pred = eps_uncond + cfg_scale * (eps_cond - eps_uncond)

            a_bar_t = sched.a_bar[ti]
            if t_prev == ti - 1:
                # Consecutive timesteps: use the schedule's exact one-step coefficients.
                alpha_t = sched.alphas[ti]
                beta_t = sched.betas[ti]
                post_var = sched.post_var[ti]
            else:
                # Strided jump ti -> t_prev acts like one step with
                # alpha = a_bar_t / a_bar_prev (a_bar_prev = 1 when jumping to x_0).
                a_bar_prev = sched.a_bar[t_prev] if t_prev >= 0 else torch.ones_like(a_bar_t)
                alpha_t = a_bar_t / a_bar_prev
                beta_t = 1 - alpha_t
                post_var = beta_t * (1.0 - a_bar_prev) / (1.0 - a_bar_t)

            mean = (1 / torch.sqrt(alpha_t)) * (x - ((beta_t / torch.sqrt(1 - a_bar_t)) * eps_pred))

            if t_prev >= 0:
                var = post_var.clamp_min(1e-20) if sigma_choice == "beta_tilde" else beta_t
                x = mean + (torch.sqrt(var) * torch.randn_like(x))
            else:
                # Final step: return the mean, with no added noise.
                x = mean
    finally:
        model.train(was_training)

    return x.clamp(-1, 1)

In [None]:
def preprocess_for_inception(imgs : torch.Tensor):
    """Turn ``[B, C, H, W]`` images in [-1, 1] into ImageNet-normalized 299x299 inputs.

    Single-channel images are replicated to 3 channels. Images are rescaled to
    [0, 1], bilinearly resized to 299x299 and normalized with the ImageNet
    per-channel mean/std.
    """
    if imgs.shape[1] == 1:
        imgs = imgs.repeat(1, 3, 1, 1)  # the Inception network expects 3 channels

    unit_range = (imgs + 1.0) / 2.0  # [-1, 1] -> [0, 1]
    resized = F.interpolate(unit_range, size = (299, 299), mode = "bilinear", align_corners = False)
    mean = torch.tensor([0.485, 0.456, 0.406], device = resized.device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device = resized.device).view(1, 3, 1, 1)
    return (resized - mean) / std

In [None]:
def load_inception():
    """Load torchvision's ImageNet Inception-v3, put it in eval mode, and move it to ``device``.

    In eval mode the torchvision model returns only the final logits (no aux
    output). NOTE(review): these torchvision weights differ from the TF-Inception
    weights used by reference FID implementations, so scores are only comparable
    within this notebook.
    """
    weights = Inception_V3_Weights.IMAGENET1K_V1
    net = inception_v3(weights = weights)
    net.eval()
    net.to(device)
    return net

In [None]:
def inception_outputs(inception : nn.Module, imgs : torch.Tensor, device = device):
    """Run Inception on [-1, 1] images and return ``(pool_features, logits)``.

    Pool features are captured by a temporary forward hook on
    ``inception.avgpool`` and flattened to ``[B, D]``. The hook is removed even if
    the forward pass raises, so failed calls do not leave extra hooks behind.
    """
    imgs = preprocess_for_inception(imgs.to(device))
    collected = {}

    def hook_fn(module, inputs, output):
        collected["feats"] = torch.flatten(output, 1)

    handle = inception.avgpool.register_forward_hook(hook_fn)
    try:
        logits = inception(imgs)
    finally:
        handle.remove()

    return collected["feats"], logits

In [None]:
@torch.no_grad()
def inception_score_from_logits(logits : torch.Tensor, splits : int = 10):
    """Inception Score ``exp(E_x KL(p(y|x) || p(y)))``, averaged over disjoint chunks.

    Args:
        logits: ``[N, K]`` classifier logits (K = 1000 for ImageNet Inception).
        splits: number of chunks. Clamped to ``N`` so no chunk is empty; when
            ``N`` is not divisible, the remainder is spread over the first chunks
            instead of being dropped.

    Returns:
        Mean per-chunk score as a Python float.

    Raises:
        ValueError: if ``logits`` has no rows.
    """
    n = logits.shape[0]
    if n == 0:
        raise ValueError("inception_score_from_logits needs at least one sample")

    probs = F.softmax(logits, dim = -1)
    n_splits = max(1, min(splits, n))
    scores = []

    for part in torch.tensor_split(probs, n_splits, dim = 0):
        p_y = part.mean(dim = 0, keepdim = True)  # marginal p(y) over the chunk, [1, K]
        kl = (part * (torch.log(part + 1e-12) - torch.log(p_y + 1e-12))).sum(dim = 1)  # [chunk]
        scores.append(torch.exp(kl.mean()))

    return torch.stack(scores).mean().item()

In [None]:
def _cov(x : torch.Tensor):
    # x : [B, D]
    mu = x.mean(dim = 0, keepdim = True)
    x_ = x - mu
    return torch.matmul(x_.transpose(0, 1), x_) / (x.shape[0] - 1)

In [None]:
def mean_and_cov(feats):
    """Return ``(mean, covariance)`` of feature rows: ``[N, D] -> ([D], [D, D])``."""
    return feats.mean(dim = 0), _cov(feats)

In [None]:
def sqrt_trace_cov(sigma_1 : torch.Tensor, sigma_2 : torch.Tensor):
    """Return ``Tr( sqrt( sqrt(sigma_1) @ sigma_2 @ sqrt(sigma_1) ) )``.

    For symmetric PSD inputs this equals ``Tr(sqrt(sigma_1 @ sigma_2))``, the
    cross term of the Frechet distance.

    Args:
        sigma_1, sigma_2: ``[D, D]`` symmetric PSD matrices.

    Returns:
        0-d tensor in ``sigma_1``'s dtype. The work is done in float64 because
        eigendecompositions of large (e.g. 2048x2048) covariances lose noticeable
        precision in float32.
    """
    out_dtype = sigma_1.dtype
    s1 = sigma_1.to(torch.float64)
    s2 = sigma_2.to(torch.float64)

    evals_1, evecs_1 = torch.linalg.eigh(s1)
    # Round-off can make eigenvalues of a PSD matrix slightly negative; clamp so the sqrt stays real.
    sqrt_s1 = (evecs_1 * evals_1.clamp_min(0.0).sqrt().unsqueeze(0)) @ evecs_1.T

    a = sqrt_s1 @ s2 @ sqrt_s1
    # Symmetrize to remove round-off asymmetry; only the eigenvalues are needed.
    evals_a = torch.linalg.eigvalsh((a + a.T) / 2)
    return evals_a.clamp_min(0.0).sqrt().sum().to(out_dtype)

In [None]:
def fid_from_stats(mu_1, sigma_1, mu_2, sigma_2):
    """Frechet distance between ``N(mu_1, sigma_1)`` and ``N(mu_2, sigma_2)``, as a Python float.

    ``FID = ||mu_1 - mu_2||^2 + Tr(sigma_1 + sigma_2) - 2 * Tr(sqrt(sigma_1 sigma_2))``.
    Inputs are promoted to float64 first. FID is a small difference of large
    traces, so float32 round-off on high-dimensional covariances can noticeably
    shift the result.
    """
    mu_1, mu_2 = mu_1.double(), mu_2.double()
    sigma_1, sigma_2 = sigma_1.double(), sigma_2.double()
    diff = (mu_1 - mu_2).pow(2).sum()
    tr = torch.trace(sigma_1 + sigma_2) - 2 * sqrt_trace_cov(sigma_1, sigma_2)
    return (diff + tr).item()

In [None]:
@torch.no_grad()
def get_real_stats(dataloader, inception, device = device, max_images = 5000):
    """Inception statistics for up to ``max_images`` images drawn from ``dataloader``.

    Returns:
        ``(mu, sigma, logits)``: mean ``[D]`` and covariance ``[D, D]`` of the
        Inception pool features, plus the stacked logits ``[N, K]`` with
        ``N <= max_images``.
    """
    feat_chunks = []
    logit_chunks = []
    seen = 0

    for batch, _labels in dataloader:
        batch = batch.to(device)
        feats, logits = inception_outputs(inception, batch, device)
        feat_chunks.append(feats)
        logit_chunks.append(logits)
        seen += len(batch)
        if seen >= max_images:
            break

    feats = torch.cat(feat_chunks, dim = 0)[: max_images]
    logits = torch.cat(logit_chunks, dim = 0)[: max_images]
    mu, sigma = mean_and_cov(feats)
    return mu, sigma, logits

In [None]:
@torch.no_grad()
def get_fake_stats(unet, inception, device = device, num_imgs = 5000, cfg_scale = 3.0, steps = None, num_classes = 10):
    """Generate exactly ``num_imgs`` class-balanced samples and return their Inception statistics.

    Labels cycle ``0, 1, ..., num_classes - 1, 0, 1, ...``, so each class gets
    ``num_imgs // num_classes`` or one more sample, and every generation batch
    mixes all classes. Only the requested number of images is generated; nothing
    is produced and then thrown away.

    Returns:
        ``(mu, sigma, logits)``: feature mean ``[D]``, covariance ``[D, D]`` and
        logits ``[num_imgs, K]``.

    Raises:
        ValueError: if ``num_imgs`` is not positive.
    """
    if num_imgs <= 0:
        raise ValueError(f"num_imgs must be positive, got {num_imgs}")

    B = 512  # generation batch size (CFG doubles the model batch internally)
    all_labels = torch.arange(num_imgs, dtype = torch.int64, device = device) % num_classes
    feats_list, logits_list = [], []

    with tqdm(total = num_imgs, desc = "Fake_Stats") as pbar:
        for y in all_labels.split(B):
            imgs = sample_grid(unet, y, y.shape[0], steps = steps, cfg_scale = cfg_scale, num_classes = num_classes, device = device)
            feats, logits = inception_outputs(inception, imgs, device)
            feats_list.append(feats)
            logits_list.append(logits)
            pbar.update(y.shape[0])

    feats = torch.cat(feats_list, dim = 0)
    logits = torch.cat(logits_list, dim = 0)

    mu, sigma = mean_and_cov(feats)
    return mu, sigma, logits

In [None]:
@torch.no_grad()
def compute_IS_FID(unet, train_loader, device = device, num_fakes = 5000, steps = None, cfg_scale = 3.0, num_classes = 10, cache_real_stats = True, real_stats_path = None):
    """Compute ``(Inception Score, FID)`` of ``unet``'s samples against real training images.

    Real-image statistics are read from ``real_stats_path`` when
    ``cache_real_stats`` is true and the file exists. Otherwise they are
    recomputed and written to that path, creating its directory if needed.
    NOTE(review): the cache key is the path only, so changing the real-image
    subset (e.g. ``max_images`` in get_real_stats) needs a new path or
    ``cache_real_stats=False``. The IS is computed on the generated samples
    only; the cached real logits are not used here.

    Returns:
        ``(IS, FID)`` as Python floats.
    """
    inception = load_inception()

    if real_stats_path is None:
        real_stats_path = "samples/real_stats.pt" # Keep it in "samples" folder

    if cache_real_stats and os.path.isfile(real_stats_path):
        ckpt = torch.load(real_stats_path, map_location = device)
        mu_r, sigma_r, logits_r = ckpt["mu"], ckpt["sigma"], ckpt["logits"]
    else:
        mu_r, sigma_r, logits_r = get_real_stats(train_loader, inception, device)
        # Create the directory the stats file actually lives in (not always "samples").
        stats_dir = os.path.dirname(real_stats_path)
        if stats_dir:
            os.makedirs(stats_dir, exist_ok = True)
        torch.save({"mu" : mu_r, "sigma" : sigma_r, "logits" : logits_r}, real_stats_path)

    mu_f, sigma_f, logits_f = get_fake_stats(unet, inception, device, num_fakes, cfg_scale, steps, num_classes)
    IS = inception_score_from_logits(logits_f)
    FID = fid_from_stats(mu_r, sigma_r, mu_f, sigma_f)
    return IS, FID

In [None]:
unet = UNET()
unet = unet.to(device)  # Module.to moves the parameters and returns the same module
print(f"{count_params(unet):,} params")
ema = EMA(unet)  # shadow copy of the freshly initialised weights

1,916,865 params


In [None]:
def train(model, epochs = 20, lr = 2e-4, checkpoint = "mnist_ddpm.pt", weighted = False, p_uncond = 0.1, num_classes = 10):
    """Train ``model`` as a class-conditional DDPM with classifier-free-guidance label dropout.

    Reads the module-level ``train_loader``, ``sched``, ``ema`` (which must track
    ``model``'s parameters) and ``device``. After each epoch, a 4x4 grid sampled
    with EMA weights is saved to ``samples/{epoch:03d}.png``. After epoch 1 and
    every 10th epoch, IS/FID are also computed with the same EMA weights. The
    live (non-EMA) weights are saved to ``checkpoint`` at the end.

    Args:
        model: noise-prediction network ``model(x_t, t, y)``.
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        checkpoint: path for the final ``state_dict``.
        weighted: forwarded to ``loss_eps`` (ELBO-weighted loss).
        p_uncond: probability of replacing a label with the null label (CFG).
        num_classes: number of real classes; ``num_classes`` is the null label.
    """
    optim = torch.optim.Adam(model.parameters(), lr = lr,  weight_decay = 1e-4) # L2 Reg.
    use_amp = device == "cuda"
    # torch.amp replaces the deprecated torch.cuda.amp entry points.
    scaler = torch.amp.GradScaler("cuda", enabled = use_amp)
    null_id = num_classes
    itr = 0

    for epoch in range(1, epochs + 1):
        print(f"Epoch : {epoch}")
        # Sampling at the end of each epoch switches the model to eval mode,
        # so switch back to train mode at the start of every epoch.
        model.train()
        for x_0, y in tqdm(train_loader, total = len(train_loader)):
            x_0 = x_0.to(device)
            y = y.to(device)

            # CFG: replace a random p_uncond fraction of labels with the null label.
            drop_mask = torch.rand(y.shape, device = y.device) < p_uncond
            y_train = y.masked_fill(drop_mask, null_id)

            t = torch.randint(0, sched.T, size = (x_0.shape[0], ), dtype = torch.int64, device = device)

            with torch.autocast(device_type = device, enabled = use_amp):
                loss = loss_eps(model, x_0, t, y_train, weighted)

            optim.zero_grad(set_to_none = True)
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
            ema.update()

            if itr % 200 == 0:
                print(f"itr : {itr:06d}, Loss : {loss.item():.4f}")

            itr += 1

        with torch.no_grad():
            # 16 labels cycling through every class (0..9, then 0..5 for 10 classes).
            y_sample = torch.arange(16) % num_classes

            # Evaluate the EMA weights, and always put the live weights back afterwards.
            ema.store()
            ema.copy_to()
            try:
                imgs = sample_grid(model, y_sample, num_imgs = len(y_sample), num_classes = num_classes)

                grid = vutils.make_grid((imgs + 1) / 2, nrow = 4)
                os.makedirs("samples", exist_ok = True)
                vutils.save_image(grid, f"samples/{epoch:03d}.png")

                if epoch % 10 == 0 or epoch == 1:
                    # Slow (minutes): generates thousands of samples for the fake statistics.
                    IS, FID = compute_IS_FID(model, train_loader, num_classes = num_classes)
                    print(f"\nEpoch : {epoch:03d}, FID : {FID:.2f}, IS : {IS:.2f}\n")
            finally:
                ema.restore()

    torch.save(model.state_dict(), checkpoint)

In [None]:
# Train with the default hyperparameters (20 epochs, lr 2e-4, 10% label dropout).
train(unet, epochs = 20)

  scaler = torch.cuda.amp.GradScaler(enabled = (device == "cuda"))


Epoch : 1


  with torch.cuda.amp.autocast(enabled = (device == "cuda")):
  1%|          | 4/469 [00:31<45:56,  5.93s/it]  

itr : 000000, Loss : 0.9882


 44%|████▎     | 205/469 [00:39<00:10, 24.62it/s]

itr : 000200, Loss : 0.0623


 86%|████████▌ | 403/469 [00:47<00:02, 24.81it/s]

itr : 000400, Loss : 0.0328


100%|██████████| 469/469 [01:17<00:00,  6.03it/s]


Downloading: "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth" to /root/.cache/torch/hub/checkpoints/inception_v3_google-0cc3c7bd.pth


100%|██████████| 104M/104M [00:00<00:00, 194MB/s] 


torch.Size([2048]) torch.Size([2048, 2048]) torch.Size([5000, 1000])


Fake_Stats: 5100it [05:22, 15.80it/s]



Epoch : 1, FID : 79.78657531738281, IS : 2.708458423614502

Epoch : 2


 29%|██▉       | 136/469 [00:05<00:15, 22.20it/s]

itr : 000600, Loss : 0.0297


 71%|███████   | 334/469 [00:14<00:05, 24.41it/s]

itr : 000800, Loss : 0.0295


100%|██████████| 469/469 [00:19<00:00, 23.77it/s]


Epoch : 3


 14%|█▍        | 66/469 [00:03<00:18, 22.25it/s]

itr : 001000, Loss : 0.0317


 57%|█████▋    | 267/469 [00:12<00:09, 22.30it/s]

itr : 001200, Loss : 0.0278


 99%|█████████▉| 465/469 [00:21<00:00, 22.95it/s]

itr : 001400, Loss : 0.0263


100%|██████████| 469/469 [00:21<00:00, 21.96it/s]


Epoch : 4


 42%|████▏     | 196/469 [00:08<00:11, 24.00it/s]

itr : 001600, Loss : 0.0240


 85%|████████▍ | 397/469 [00:16<00:02, 24.99it/s]

itr : 001800, Loss : 0.0225


100%|██████████| 469/469 [00:19<00:00, 24.03it/s]


Epoch : 5


 27%|██▋       | 127/469 [00:05<00:14, 22.86it/s]

itr : 002000, Loss : 0.0221


 70%|██████▉   | 328/469 [00:14<00:06, 21.82it/s]

itr : 002200, Loss : 0.0271


100%|██████████| 469/469 [00:20<00:00, 23.03it/s]


Epoch : 6


 12%|█▏        | 58/469 [00:02<00:18, 21.80it/s]

itr : 002400, Loss : 0.0283


 55%|█████▌    | 259/469 [00:11<00:08, 24.01it/s]

itr : 002600, Loss : 0.0227


 98%|█████████▊| 460/469 [00:19<00:00, 24.82it/s]

itr : 002800, Loss : 0.0276


100%|██████████| 469/469 [00:20<00:00, 23.34it/s]


Epoch : 7


 41%|████      | 190/469 [00:08<00:13, 20.46it/s]

itr : 003000, Loss : 0.0242


 83%|████████▎ | 390/469 [00:16<00:03, 25.18it/s]

itr : 003200, Loss : 0.0352


100%|██████████| 469/469 [00:20<00:00, 23.22it/s]


Epoch : 8


 26%|██▌       | 121/469 [00:05<00:14, 24.01it/s]

itr : 003400, Loss : 0.0245


 69%|██████▊   | 322/469 [00:13<00:05, 24.90it/s]

itr : 003600, Loss : 0.0263


100%|██████████| 469/469 [00:19<00:00, 23.96it/s]


Epoch : 9


 11%|█         | 52/469 [00:02<00:16, 24.92it/s]

itr : 003800, Loss : 0.0269


 54%|█████▍    | 253/469 [00:10<00:09, 22.94it/s]

itr : 004000, Loss : 0.0228


 96%|█████████▌| 451/469 [00:19<00:00, 21.90it/s]

itr : 004200, Loss : 0.0223


100%|██████████| 469/469 [00:20<00:00, 23.17it/s]


Epoch : 10


 39%|███▉      | 184/469 [00:07<00:12, 22.36it/s]

itr : 004400, Loss : 0.0243


 81%|████████▏ | 382/469 [00:16<00:03, 24.70it/s]

itr : 004600, Loss : 0.0245


100%|██████████| 469/469 [00:20<00:00, 23.14it/s]
Fake_Stats: 5100it [05:18, 16.03it/s]



Epoch : 10, FID : 15.860332489013672, IS : 2.275853395462036

Epoch : 11


 25%|██▍       | 115/469 [00:04<00:14, 24.39it/s]

itr : 004800, Loss : 0.0230


 67%|██████▋   | 313/469 [00:13<00:07, 21.95it/s]

itr : 005000, Loss : 0.0241


100%|██████████| 469/469 [00:21<00:00, 22.27it/s]


Epoch : 12


 10%|▉         | 46/469 [00:01<00:17, 24.15it/s]

itr : 005200, Loss : 0.0219


 52%|█████▏    | 244/469 [00:10<00:09, 24.22it/s]

itr : 005400, Loss : 0.0208


 95%|█████████▍| 445/469 [00:19<00:01, 21.75it/s]

itr : 005600, Loss : 0.0231


100%|██████████| 469/469 [00:20<00:00, 23.02it/s]


Epoch : 13


 38%|███▊      | 178/469 [00:07<00:11, 24.90it/s]

itr : 005800, Loss : 0.0228


 80%|████████  | 376/469 [00:15<00:03, 25.21it/s]

itr : 006000, Loss : 0.0190


100%|██████████| 469/469 [00:19<00:00, 24.39it/s]


Epoch : 14


 23%|██▎       | 106/469 [00:05<00:17, 21.27it/s]

itr : 006200, Loss : 0.0240


 65%|██████▌   | 307/469 [00:13<00:06, 23.74it/s]

itr : 006400, Loss : 0.0260


100%|██████████| 469/469 [00:20<00:00, 23.07it/s]


Epoch : 15


  9%|▊         | 40/469 [00:01<00:17, 24.67it/s]

itr : 006600, Loss : 0.0231


 51%|█████     | 238/469 [00:09<00:10, 22.69it/s]

itr : 006800, Loss : 0.0233


 94%|█████████▎| 439/469 [00:18<00:01, 25.08it/s]

itr : 007000, Loss : 0.0317


100%|██████████| 469/469 [00:19<00:00, 23.89it/s]


Epoch : 16


 36%|███▌      | 169/469 [00:07<00:12, 24.33it/s]

itr : 007200, Loss : 0.0193


 79%|███████▉  | 370/469 [00:15<00:03, 24.89it/s]

itr : 007400, Loss : 0.0239


100%|██████████| 469/469 [00:19<00:00, 23.97it/s]


Epoch : 17


 21%|██▏       | 100/469 [00:04<00:16, 22.83it/s]

itr : 007600, Loss : 0.0243


 64%|██████▍   | 301/469 [00:13<00:07, 23.61it/s]

itr : 007800, Loss : 0.0197


100%|██████████| 469/469 [00:20<00:00, 23.15it/s]


Epoch : 18


  7%|▋         | 31/469 [00:01<00:19, 22.65it/s]

itr : 008000, Loss : 0.0255


 49%|████▉     | 232/469 [00:10<00:11, 20.97it/s]

itr : 008200, Loss : 0.0258


 92%|█████████▏| 430/469 [00:18<00:01, 23.67it/s]

itr : 008400, Loss : 0.0214


100%|██████████| 469/469 [00:20<00:00, 23.04it/s]


Epoch : 19


 35%|███▍      | 163/469 [00:07<00:13, 23.49it/s]

itr : 008600, Loss : 0.0254


 78%|███████▊  | 364/469 [00:15<00:04, 24.94it/s]

itr : 008800, Loss : 0.0222


100%|██████████| 469/469 [00:19<00:00, 23.80it/s]


Epoch : 20


 20%|██        | 94/469 [00:03<00:15, 24.55it/s]

itr : 009000, Loss : 0.0216


 63%|██████▎   | 295/469 [00:12<00:07, 24.68it/s]

itr : 009200, Loss : 0.0265


100%|██████████| 469/469 [00:19<00:00, 23.75it/s]
Fake_Stats: 5100it [05:18, 16.02it/s]



Epoch : 20, FID : 21.35757064819336, IS : 2.4000580310821533



In [None]:
def generate_dig(unet : nn.Module, dig = None, num_imgs = 16, steps = 1000, ckpt = None, sigma_choice = "beta_tilde", cfg_scale = 3.0, num_classes = 10, return_imgs = False):
    """Sample ``num_imgs`` images of digit ``dig`` and save them as a 4-column grid under ``samples/``.

    Args:
        unet: noise-prediction model.
        dig: class label to generate; ``None`` uses the null label
            (``num_classes``), i.e. unconditional samples.
        num_imgs: number of images.
        steps: number of sampling steps.
        ckpt: optional path to a ``state_dict`` to load into ``unet`` first.
        sigma_choice: reverse-process variance choice passed to ``sample_grid``.
        cfg_scale: classifier-free guidance weight.
        num_classes: number of real classes (also the null label index).
        return_imgs: if True, return the ``[-1, 1]`` image tensor.
    """
    if ckpt is not None:
        # map_location is an argument of torch.load, not of load_state_dict.
        unet.load_state_dict(torch.load(ckpt, map_location = device))

    if dig is None:
        dig = num_classes # null_id

    y_label = torch.full((num_imgs, ), int(dig), device = device, dtype = torch.int64)
    imgs = sample_grid(unet, y_label, num_imgs, sigma_choice = sigma_choice, steps = steps, cfg_scale = cfg_scale, num_classes = num_classes)
    os.makedirs("samples", exist_ok = True)
    save_path = f"samples/sample_digit_{dig}_steps_{steps}_cfg_scale_{cfg_scale}.png"
    grid = vutils.make_grid((imgs + 1) / 2, nrow = 4)
    vutils.save_image(grid, save_path)
    print(f"Saved at : {save_path}")

    if return_imgs:
        return imgs

In [None]:
# 16 samples of the digit 6 with the default 1000-step sampler.
generate_dig(unet, dig = 6, num_imgs = 16)

Saved at : samples/sample_digit_6_steps_1000_cfg_scale_3.0.png


In [None]:
# dig=None -> null label -> unconditional samples; keep the tensor for inspection.
imgs = generate_dig(unet, dig = None, return_imgs = True)

Saved at : samples/sample_digit_10_steps_1000_cfg_scale_3.0.png
