In [None]:
# Install the Hugging Face `datasets` package; `load_dataset` from it is used by the CelebA dataset class below.
!pip install -q datasets

In [None]:
import math, random
import os
from tqdm import tqdm
from dataclasses import dataclass
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from torchvision import transforms, utils as vutils
from torchvision.models import inception_v3, Inception_V3_Weights

In [None]:
# Enable cuDNN's convolution auto-tuner.
# NOTE(review): per the PyTorch reproducibility notes, benchmark mode may select different
# algorithms between runs, so the seed set below does not guarantee bit-identical results.
torch.backends.cudnn.benchmark = True

def set_seed(seed):
    """Seed Python's `random` module and PyTorch's CPU / CUDA generators with `seed`."""
    # Same three calls, in the same order, as a straight-line version would make.
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

# Fixed seed for the whole notebook.
SEED = 42
set_seed(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# `device` is a plain string that later cells read as a module-level global.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"DEVICE : {device}")

DEVICE : cuda


In [None]:
@dataclass
class DDPMConfig:
    """Hyper-parameters for the diffusion schedule and the U-Net normalisation layers."""
    T : int = 1000 # number of diffusion timesteps
    schedule : str = "cosine" # cosine / linear (DiffusionSchedule treats every value other than "linear" as cosine)
    beta_start : float = 1e-4 # first beta; read only by the linear schedule
    beta_end : float = 2e-2 # last beta; read only by the linear schedule
    group_norm_batch_size : int = 32 # passed as the group count to every nn.GroupNorm (despite the name, not a batch size)
    use_elbo_weights : bool = False # NOTE(review): not read by any code visible in this notebook

In [None]:
class DiffusionSchedule():
    """Per-timestep constants of the forward noising process, indexed by t in [0, T).

    Attributes (all float32 tensors of shape [T]):
        betas, alphas         : beta_t and alpha_t = 1 - beta_t
        a_bar                 : cumulative product of alphas
        a_bar_prev            : a_bar shifted right by one step, with a leading 1
        alpha_prev            : alphas shifted right by one step, with a leading 1
        sqrt_a_bar            : sqrt(a_bar), the signal coefficient of q(x_t | x_0)
        sqrt_one_minus_a_bar  : sqrt(1 - a_bar), the noise coefficient of q(x_t | x_0)
        post_var              : posterior variance beta_tilde, floored at 1e-20

    Args:
        config   : DDPMConfig; `schedule == "linear"` selects the linear beta ramp,
                   every other value selects the cosine schedule.
        max_beta : upper bound applied to every beta. Without it the cosine schedule
                   ends with alpha_T ~ 0, and the 1 / sqrt(alpha_t) factor of the
                   reverse step becomes enormous on the very first sampling step.
                   Pass 1.0 to disable clipping.
    """
    def __init__(self, config : DDPMConfig, max_beta : float = 0.999):
        self.config = config
        self.T = config.T

        # Everything is computed in float64 and cast once at the end; the cosine
        # curve in particular loses precision near t = T in float32.
        if config.schedule == "linear":
            betas = torch.linspace(config.beta_start, config.beta_end, config.T, dtype = torch.float64)
        else:
            # cosine schedule: a_bar(t) = f(t) / f(0), f(t) = cos^2(((t / T + s) / (1 + s)) * pi / 2)
            s = 0.008
            grid = torch.linspace(0, config.T, config.T + 1, dtype = torch.float64)
            f = torch.cos((((grid / config.T) + s) / (1 + s)) * (math.pi / 2)) ** 2
            # beta_t = 1 - a_bar_t / a_bar_{t-1}; the common f(0) factor cancels.
            betas = 1 - (f[1: ] / f[: -1])

        betas = betas.clamp(max = max_beta)
        alphas = 1 - betas
        # Rebuild a_bar from the (possibly clipped) betas so all tensors stay mutually consistent.
        a_bar = torch.cumprod(alphas, dim = 0)

        one = torch.ones(1, dtype = torch.float64)
        a_bar_prev = torch.cat([one, a_bar[: -1]], dim = 0)
        alpha_prev = torch.cat([one, alphas[: -1]], dim = 0)
        # beta_tilde is exactly 0 at t = 0 (a_bar_prev == 1); floor it so sqrt / log stay finite.
        post_var = (betas * ((1 - a_bar_prev) / (1 - a_bar))).clamp_min(1e-20)

        self.betas = betas.float()
        self.alphas = alphas.float()
        self.a_bar = a_bar.float()
        self.sqrt_a_bar = torch.sqrt(a_bar).float()
        self.sqrt_one_minus_a_bar = torch.sqrt(1 - a_bar).float()
        self.alpha_prev = alpha_prev.float()
        self.a_bar_prev = a_bar_prev.float()
        self.post_var = post_var.float() # beta_tilde

    def to(self, device : str = device):
        """Move every tensor attribute to `device` in place and return `self`."""
        # Only existing keys are reassigned, so the dict does not change size while iterating.
        for k, v in self.__dict__.items():
            if torch.is_tensor(v):
                setattr(self, k, v.to(device))
        return self

In [None]:
# Build the config and the schedule once. Both are module-level globals:
# the model classes read `config`, and the noising / loss / sampling functions read `sched`.
config = DDPMConfig()
sched = DiffusionSchedule(config).to(device)

In [None]:
class TimeEmbedding(nn.Module):
    """Sinusoidal timestep features followed by a two-layer SiLU MLP.

    Args:
        time_dim : width of both the sinusoidal features and the MLP output.
                   Must be even, because the features are sin and cos halves.

    Raises:
        ValueError : if `time_dim` is odd (the concatenated sin/cos features would
                     be one short of what the first Linear layer expects).
    """
    def __init__(self, time_dim):
        super().__init__()
        if time_dim % 2 != 0:
            raise ValueError(f"time_dim must be even, got {time_dim}")
        self.T = sched.T # read from the module-level schedule; not used by forward()
        self.time_dim = time_dim
        self.mlp = nn.Sequential(nn.Linear(time_dim, 2 * time_dim), nn.SiLU(), nn.Linear(2 * time_dim, time_dim))

    def forward(self, time_vals : torch.Tensor, base : float = 10000.0, device = None):
        """Embed timesteps.

        Args:
            time_vals : [B, ] tensor of timesteps.
            base      : base of the geometric frequency progression.
            device    : device for the frequency table; defaults to `time_vals.device`
                        so the embedding follows its input instead of a global.

        Returns:
            [B, time_dim] tensor.
        """
        # time_vals : [B, ]
        if device is None:
            device = time_vals.device
        half = self.time_dim // 2
        freqs = torch.exp(-1 * (torch.arange(0, half, dtype = torch.float32, device = device) / half) * math.log(base)) # [half, ]
        args = time_vals[:, None] * freqs[None, :] # [B, half]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim = -1) # [B, time_dim]
        return self.mlp(emb)

In [None]:
def count_params(m):
    """Return the total number of scalar elements over all parameters of module `m`."""
    total = 0
    for param in m.parameters():
        total += param.numel()
    return total

In [None]:
class ConvResBlock(nn.Module):
    """Residual block of two 3x3 convolutions, modulated by a conditioning vector.

    The conditioning vector is projected to a per-channel scale and shift that are
    applied after the second GroupNorm. The group count for both norms comes from the
    module-level `config.group_norm_batch_size`.
    """
    def __init__(self, in_ch, out_ch, cond_dim, dropout : float = 0.0):
        super().__init__()
        groups = config.group_norm_batch_size
        self.norm_1 = nn.GroupNorm(groups, in_ch)
        self.conv_1 = nn.Conv2d(in_ch, out_ch, kernel_size = 3, padding = 1)
        self.norm_2 = nn.GroupNorm(groups, out_ch)
        self.conv_2 = nn.Conv2d(out_ch, out_ch, kernel_size = 3, padding = 1)
        # A 1x1 projection is only needed when the residual path changes channel count.
        if in_ch == out_ch:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(in_ch, out_ch, kernel_size = 1, padding = 0)
        self.cond = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, 2 * out_ch))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x : torch.Tensor, cond_vec : torch.Tensor):
        # x : [B, c, h, w], cond_vec : [B, cond_dim]
        feats = self.norm_1(x)
        feats = self.conv_1(F.silu(feats))
        feats = self.norm_2(feats)

        # Split the projected conditioning into scale / shift, one value per output channel.
        scale, shift = torch.chunk(self.cond(cond_vec), 2, dim = 1)
        scale = scale[:, :, None, None]
        shift = shift[:, :, None, None]
        feats = (feats * (1 + scale)) + shift

        feats = self.dropout(feats)
        feats = self.conv_2(F.silu(feats))
        return feats + self.skip(x)

In [None]:
class SelfAttention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map, with a residual add.

    The channel dimension is split into `num_heads` heads of `ch // num_heads` channels;
    every pixel is one token. The GroupNorm group count comes from the module-level
    `config.group_norm_batch_size`.
    """
    def __init__(self, ch, num_heads = 4):
        super().__init__()
        assert ch % num_heads == 0

        self.num_heads = num_heads
        self.head_dim = ch // num_heads
        self.norm = nn.GroupNorm(config.group_norm_batch_size, ch)
        self.qkv = nn.Conv2d(ch, 3 * ch, kernel_size = 1, padding = 0)
        self.proj = nn.Conv2d(ch, ch, kernel_size = 1, padding = 0)

    def _split_heads(self, t : torch.Tensor, batch : int, tokens : int):
        # [B, c, h, w] -> [B, num_heads, tokens, head_dim]
        return t.view(batch, self.num_heads, self.head_dim, tokens).transpose(2, 3)

    def forward(self, x : torch.Tensor):
        # x : [B, c, h, w]
        batch, channels, height, width = x.shape
        tokens = height * width

        packed = self.qkv(self.norm(x)) # [B, 3c, h, w]
        query, key, value = (self._split_heads(part, batch, tokens) for part in torch.chunk(packed, 3, dim = 1))

        logits = (query @ key.transpose(-2, -1)) / math.sqrt(self.head_dim) # [B, num_heads, tokens, tokens]
        weights = torch.softmax(logits, dim = -1)
        mixed = weights @ value # [B, num_heads, tokens, head_dim]

        # Undo the head split: back to channel-major layout, then to the image grid.
        mixed = mixed.transpose(2, 3).contiguous().view(batch, channels, height, width)
        return self.proj(mixed) + x

In [None]:
class Upsample(nn.Module):
    """Double the spatial size with nearest-neighbour interpolation, then apply a 3x3 conv."""
    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, kernel_size = 3, padding = 1)

    def forward(self, x : torch.Tensor):
        # x : [B, c, h, w] -> [B, c, 2h, 2w]
        enlarged = F.interpolate(x, scale_factor = 2, mode = "nearest")
        return self.conv(enlarged)

class DownSample(nn.Module):
    """Halve the spatial size with a single stride-2 3x3 convolution (channel count unchanged)."""
    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, kernel_size = 3, stride = 2, padding = 1)

    def forward(self, x : torch.Tensor):
        # x : [B, c, h, w]
        reduced = self.conv(x)
        return reduced

In [None]:
class UNET(nn.Module):
    """Time-conditioned U-Net that maps a noisy image and its timestep to a same-shaped output.

    Args:
        in_ch          : channels of the input image (also the output channel count).
        base           : channel width after the input convolution.
        ch_mult        : per-level channel multipliers; level i uses base * ch_mult[i] channels.
        num_res_blocks : residual blocks per level on the way down (one more per level on the way up).
        attn_res       : spatial resolutions at which a SelfAttention layer follows each residual block.
        cond_dim       : width of the time-embedding / conditioning vector.
        dropout        : dropout probability inside every residual block.
        num_heads      : attention heads for every SelfAttention layer.
        img_size       : side length of the input images. It is only used to track the current
                         resolution while building the network so `attn_res` is matched correctly;
                         the default of 64 reproduces the previously hard-coded value.
    """
    def __init__(self, in_ch = 3, base = 64, ch_mult = (1, 2, 2, 4), num_res_blocks = 2, attn_res = (16, 8), cond_dim = 256, dropout = 0.0, num_heads = 4, img_size = 64):
        super().__init__()
        self.time_emb  = TimeEmbedding(time_dim = cond_dim)
        self.in_conv = nn.Conv2d(in_ch, base, kernel_size = 3, padding = 1)

        # DOWN
        downs = []
        ch = base
        # Channel count of every tensor that forward() pushes onto its skip stack, in push order.
        hidden_sizes = [ch]
        res = img_size

        for i, m in enumerate(ch_mult):
            out_c = base * m
            for _ in range(num_res_blocks):
                downs.append(ConvResBlock(ch, out_c, cond_dim, dropout))
                ch = out_c
                hidden_sizes.append(ch)

                if res in attn_res:
                    downs.append(SelfAttention(ch, num_heads))

            # Every level except the last ends with a stride-2 downsample.
            if i != len(ch_mult) - 1:
                downs.append(DownSample(ch))
                res = res // 2 # Downed to Half size
                hidden_sizes.append(ch)

        self.down = nn.ModuleList(downs)


        # MID
        mids = [ConvResBlock(ch, ch, cond_dim, dropout),
                SelfAttention(ch, num_heads),
                ConvResBlock(ch, ch, cond_dim, dropout)]
        self.mid = nn.ModuleList(mids)

        # UP
        ups = []
        for i, m in list(enumerate(ch_mult))[::-1]:
            out_c = base * m
            for _ in range(num_res_blocks + 1): # +1 to account for extra DownSample layer and in_conv layer
                # Each up block consumes one skip tensor, so its input is current + skip channels.
                skip_ch = hidden_sizes.pop()
                ups.append(ConvResBlock(ch + skip_ch, out_c, cond_dim, dropout))
                ch = out_c

                if res in attn_res:
                    ups.append(SelfAttention(ch, num_heads))

            if i != 0:
                ups.append(Upsample(ch))
                res *= 2

        self.up = nn.ModuleList(ups)
        self.norm_out = nn.GroupNorm(config.group_norm_batch_size, ch)
        self.conv_out = nn.Conv2d(ch, in_ch, kernel_size = 3, padding = 1)


    def forward(self, x : torch.Tensor, t : torch.Tensor):
        """Run the network.

        Args:
            x : [B, c, h, w] input images.
            t : [B, ] timesteps.

        Returns:
            Tensor with the same shape as `x`.
        """
        # x : [B, c, h, w], t : [B, ]
        cond_vec = self.time_emb(t) # [B, cond_dim]
        h = self.in_conv(x)

        # Skip stack; pushes here mirror the `hidden_sizes` bookkeeping in __init__.
        feats = [h]

        # DOWN
        for m in self.down:
            if isinstance(m, ConvResBlock):
                h = m(h, cond_vec)
                feats.append(h)
            elif isinstance(m, SelfAttention):
                h = m(h)
            else:
                # Downsample
                h = m(h)
                feats.append(h)


        # MID
        for m in self.mid:
            if isinstance(m, ConvResBlock):
                h = m(h, cond_vec)
            else:
                # SelfAttention
                h = m(h)

        # UP
        for m in self.up:
            if isinstance(m, ConvResBlock):
                # Concatenate the most recent unused skip tensor along the channel axis.
                d_ = feats.pop()
                h_ = torch.cat([h, d_], dim = 1)
                h = m(h_, cond_vec)
            else:
                # SelfAttention, UpSample
                h = m(h)


        return self.conv_out(F.silu(self.norm_out(h)))

In [None]:
def q_sample(x_0 : torch.Tensor, t : torch.Tensor, noise = None):
    """Draw x_t from the forward process q(x_t | x_0).

    q(x_t | x_0) = N(sqrt(a_bar[t]) * x_0, (1 - a_bar[t]) * I), with both coefficients
    read from the module-level `sched`.

    Args:
        x_0   : [B, c, h, w] clean images.
        t     : [B, ] timestep indices.
        noise : optional noise to use; fresh standard-normal noise is drawn when omitted.

    Returns:
        (x_t, noise) - the noised images and the noise that produced them.
    """
    eps = torch.randn_like(x_0) if noise is None else noise

    # Per-sample coefficients, shaped [B, 1, 1, 1] so they broadcast over c, h, w.
    signal_scale = sched.sqrt_a_bar[t].view(-1, 1, 1, 1)
    noise_scale = sched.sqrt_one_minus_a_bar[t].view(-1, 1, 1, 1)

    x_t = (signal_scale * x_0) + (noise_scale * eps)
    return x_t, eps

In [None]:
def loss_eps(model, x_0, t, weighted = False):
    """Noise-prediction loss for one batch.

    Args:
        model    : callable `model(x_t, t)` returning a tensor shaped like `x_0`.
        x_0      : [B, c, h, w] clean images.
        t        : [B, ] timestep indices.
        weighted : when True, scale each sample's MSE by the variational-bound weight
                   beta_t / (2 * alpha_t * (1 - a_bar_{t-1})).

    Returns:
        Scalar tensor: the (optionally weighted) per-sample MSE averaged over the batch.
    """
    x_t, eps = q_sample(x_0, t)
    eps_pred = model(x_t, t)
    mse = pow(eps_pred - eps, 2).mean(dim = (1, 2, 3))
    if weighted:
        # 1 - a_bar_{t-1} is exactly 0 at t == 0 (a_bar_prev[0] == 1), which would make the
        # weight infinite. There the posterior variance degenerates, so fall back to the
        # sigma^2 = beta_t form of the weight, whose denominator uses 1 - a_bar_t instead.
        one_minus_prev = 1.0 - sched.a_bar_prev[t]
        var_term = torch.where(one_minus_prev > 0, one_minus_prev, 1.0 - sched.a_bar[t])
        wt = (sched.betas[t]) / (2.0 * sched.alphas[t] * var_term)
        # Out-of-place multiply: avoids mutating a tensor that autograd is tracking.
        mse = mse * wt
    return mse.mean()

In [None]:
class CelebAHFDataset(Dataset):
    """Map-style dataset over one split of the `eurecom-ds/celeba` Hugging Face dataset.

    Each item is the record's "image" entry converted to RGB and, when a transform
    is set, passed through it.
    """
    def __init__(self, split, transform):
        super().__init__()
        self.transform = transform
        self.ds = load_dataset("eurecom-ds/celeba", split = split)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        image = self.ds[idx]["image"].convert("RGB")
        # A falsy transform (e.g. None) means: hand back the converted image untouched.
        if not self.transform:
            return image
        return self.transform(image)

In [None]:
def celeba_loader(root = "./data", batch_size = 128, num_workers = 4):
    """Build the train / validation / test DataLoaders for 64x64 CelebA.

    Images are centre-cropped to 140 px, resized to 64 px, converted to tensors and
    rescaled from [0, 1] to [-1, 1].

    Args:
        root        : unused; kept so existing calls that pass it keep working
                      (the dataset is fetched through `load_dataset`, not read from `root`).
        batch_size  : batch size shared by all three loaders.
        num_workers : worker processes per loader.

    Returns:
        (train_loader, valid_loader, test_loader). Only the training loader shuffles;
        validation and test keep a fixed order so evaluation batches are repeatable.
    """
    tfm = transforms.Compose([transforms.CenterCrop(140),
                              transforms.Resize(64),
                              transforms.ToTensor(),
                              transforms.Lambda(lambda x : x * 2.0 - 1.0)])

    # Pinned host memory only helps (and only avoids a warning) when a GPU is present.
    pin = torch.cuda.is_available()

    def make_loader(split, shuffle):
        dataset = CelebAHFDataset(split = split, transform = tfm)
        return DataLoader(dataset, batch_size = batch_size, shuffle = shuffle, num_workers = num_workers, pin_memory = pin)

    return (make_loader("train", True),
            make_loader("validation", False),
            make_loader("test", False))

In [None]:
# Build all three loaders with the defaults. `train_loader` is read as a global by train();
# the validation and test loaders are not used by any code visible in this notebook.
train_loader, val_loader, test_loader = celeba_loader()

Resolving data files:   0%|          | 0/18 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/18 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/18 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/18 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/18 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/18 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/18 [00:00<?, ?it/s]

In [None]:
class EMA:
    """Exponential moving average of a model's parameters.

    `shadow` holds one averaged tensor per parameter, in `model.parameters()` order.
    Typical use: call `update()` after every optimizer step; to evaluate with the
    averaged weights call `store()`, `copy_to()`, run the model, then `restore()`.

    Args:
        model : module whose parameters are tracked.
        decay : weight given to the running average at each update.
    """
    def __init__(self, model, decay : float = 0.999):
        self.model = model
        self.decay = decay
        self.shadow = [p.detach().clone() for p in model.parameters()]
        self.backup = None # filled by store(), consumed by restore()

    @torch.no_grad()
    def update(self):
        """Blend the current parameters into the shadow: shadow = decay * shadow + (1 - decay) * p."""
        # In place, so no new tensor is allocated per parameter on every training step.
        for s, p in zip(self.shadow, self.model.parameters()):
            s.mul_(self.decay).add_(p.detach(), alpha = 1 - self.decay)

    @torch.no_grad()
    def store(self):
        """Snapshot the model's current (raw) parameters so restore() can bring them back."""
        self.backup = [p.detach().clone() for p in self.model.parameters()]

    # copy_to / restore write into leaf tensors that require grad; that is only legal with
    # gradient tracking disabled, so the methods disable it themselves instead of relying
    # on the caller to wrap them in torch.no_grad().
    @torch.no_grad()
    def copy_to(self):
        """Overwrite the model's parameters with the averaged (shadow) values."""
        for p, s in zip(self.model.parameters(), self.shadow):
            p.copy_(s)

    @torch.no_grad()
    def restore(self):
        """Write the parameters saved by store() back into the model.

        Raises:
            RuntimeError : if store() has not been called yet.
        """
        if self.backup is None:
            raise RuntimeError("EMA.restore() called before EMA.store()")
        for p, b in zip(self.model.parameters(), self.backup):
            p.copy_(b)

In [None]:
@torch.no_grad()
def sample_grid(model, num_imgs = 16, img_shape = (3, 64, 64), steps = None, sigma_choice = "beta_tilde", device = device):
    model.eval()
    x = torch.randn(num_imgs, *img_shape, device = device)
    steps = steps if steps is not None else sched.T
    per_step = max(1, sched.T // steps)
    t_seq = torch.arange(sched.T - 1, -1, -per_step, device = device, dtype = torch.int64)

    for ti in t_seq:
        t = torch.full((num_imgs, ), ti, device = device, dtype = torch.int64)
        eps_pred = model(x, t)

        alpha_t = sched.alphas[t].view(-1, 1, 1, 1)
        beta_t = sched.betas[t].view(-1, 1, 1, 1)
        a_bar_t = sched.a_bar[t].view(-1, 1, 1, 1)

        mean = (1 / torch.sqrt(alpha_t)) * (x - ((beta_t / (torch.sqrt(1 - a_bar_t))) * eps_pred))

        if ti > 0:
            if sigma_choice == "beta_tilde":
                var = sched.post_var[t].view(-1, 1, 1, 1)
            else:
                var = beta_t

            x = mean + (torch.sqrt(var) * torch.randn_like(x))
        else:
            x = mean

    return x.clamp(-1, 1)

In [None]:
# Instantiate the network on the selected device and report its parameter count.
unet = UNET().to(device)
print(f"{count_params(unet):,} params")
# The EMA tracker is a module-level global: train() calls ema.update() / store() / copy_to() / restore() on it directly.
ema = EMA(unet)

19,588,547 params


In [None]:
def train(model, epochs = 100, lr = 2e-4, checkpoint = "celeba_ddpm.pt", weighted = False):
    """Train `model` on the global `train_loader`, sampling a preview grid after every epoch.

    Reads the module-level globals `device`, `sched`, `train_loader` and `ema`.

    Args:
        model      : noise-prediction network; expected to be the model `ema` was built on.
        epochs     : number of passes over `train_loader`.
        lr         : Adam learning rate.
        checkpoint : path the model's state_dict is written to. It is rewritten after
                     every epoch, so an interrupted run keeps its latest weights.
        weighted   : forwarded to `loss_eps`.

    Side effects:
        Writes `samples/{epoch:03d}.png` (a 4-column grid sampled with the EMA weights)
        after each epoch, and the checkpoint file.
    """
    optim = torch.optim.Adam(model.parameters(), lr = lr, weight_decay = 1e-4)

    # Mixed precision only on CUDA. torch.amp.* replaces the deprecated torch.cuda.amp.* entry points.
    use_amp = (device == "cuda")
    amp_device = "cuda" if use_amp else "cpu"
    scaler = torch.amp.GradScaler("cuda", enabled = use_amp)

    os.makedirs("samples", exist_ok = True)
    itr = 0

    for epoch in range(1, epochs + 1):
        # Sampling at the end of the previous epoch may have left the model in eval mode.
        model.train()
        print(f"Epoch : {epoch}")
        for x_0 in tqdm(train_loader, total = len(train_loader)):
            x_0 = x_0.to(device)
            t = torch.randint(0, sched.T, size = (x_0.shape[0],), dtype = torch.int64, device = device)
            with torch.autocast(device_type = amp_device, enabled = use_amp):
                loss = loss_eps(model, x_0, t, weighted)

            optim.zero_grad(set_to_none = True)
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
            ema.update()

            if itr % 500 == 0:
                print(f"itr : {itr:06d}, Loss : {loss.item():.4f}")
            itr += 1


        with torch.no_grad():
            # Temporarily swap in the EMA weights for sampling; always swap the raw weights
            # back, even if sampling is interrupted, so training never continues on the EMA copy.
            ema.store()
            ema.copy_to()
            try:
                imgs = sample_grid(model, device = device)
            finally:
                ema.restore()

            grid = vutils.make_grid((imgs + 1) / 2, nrow = 4)
            vutils.save_image(grid, f"samples/{epoch:03d}.png")

        # Per-epoch checkpoint of the raw (non-EMA) weights.
        torch.save(model.state_dict(), checkpoint)

    # Kept so a checkpoint is written even when `epochs` is 0.
    torch.save(model.state_dict(), checkpoint)

In [None]:
# Launch training with train()'s defaults: 100 epochs, lr 2e-4, unweighted loss, checkpoint "celeba_ddpm.pt".
train(unet)

  scaler = torch.cuda.amp.GradScaler(enabled = (device == "cuda"))


Epoch : 1


  with torch.cuda.amp.autocast(enabled = (device == "cuda")):
  0%|          | 2/1272 [00:01<09:44,  2.17it/s]

itr : 000000, Loss : 1.1602


 39%|███▉      | 502/1272 [01:26<02:11,  5.87it/s]

itr : 000500, Loss : 0.0472


 79%|███████▉  | 1002/1272 [02:51<00:46,  5.86it/s]

itr : 001000, Loss : 0.0365


100%|██████████| 1272/1272 [04:19<00:00,  4.91it/s]


Epoch : 2


 18%|█▊        | 230/1272 [00:39<02:57,  5.86it/s]

itr : 001500, Loss : 0.0423


 57%|█████▋    | 730/1272 [02:05<01:32,  5.87it/s]

itr : 002000, Loss : 0.0343


 97%|█████████▋| 1230/1272 [04:31<00:07,  5.86it/s]

itr : 002500, Loss : 0.0461


100%|██████████| 1272/1272 [04:38<00:00,  4.57it/s]


Epoch : 3


 36%|███▌      | 458/1272 [01:18<02:18,  5.87it/s]

itr : 003000, Loss : 0.0336


 75%|███████▌  | 958/1272 [02:44<00:53,  5.86it/s]

itr : 003500, Loss : 0.0295


100%|██████████| 1272/1272 [03:38<00:00,  5.83it/s]


Epoch : 4


 15%|█▍        | 186/1272 [00:32<03:07,  5.79it/s]

itr : 004000, Loss : 0.0377


 54%|█████▍    | 686/1272 [01:57<01:39,  5.87it/s]

itr : 004500, Loss : 0.0255


 93%|█████████▎| 1186/1272 [03:23<00:14,  5.86it/s]

itr : 005000, Loss : 0.0292


100%|██████████| 1272/1272 [03:38<00:00,  5.83it/s]


Epoch : 5


 33%|███▎      | 414/1272 [01:11<02:26,  5.86it/s]

itr : 005500, Loss : 0.0308


 72%|███████▏  | 914/1272 [02:36<01:01,  5.86it/s]

itr : 006000, Loss : 0.0354


100%|██████████| 1272/1272 [03:38<00:00,  5.83it/s]


Epoch : 6


 11%|█         | 142/1272 [00:24<03:13,  5.83it/s]

itr : 006500, Loss : 0.0510


 50%|█████     | 642/1272 [01:50<01:47,  5.87it/s]

itr : 007000, Loss : 0.0320


 90%|████████▉ | 1142/1272 [03:15<00:22,  5.85it/s]

itr : 007500, Loss : 0.0347


100%|██████████| 1272/1272 [03:38<00:00,  5.83it/s]


Epoch : 7


 29%|██▉       | 370/1272 [01:03<02:34,  5.84it/s]

itr : 008000, Loss : 0.0276


 68%|██████▊   | 870/1272 [02:29<01:08,  5.85it/s]

itr : 008500, Loss : 0.0397


100%|██████████| 1272/1272 [03:38<00:00,  5.82it/s]


Epoch : 8


  8%|▊         | 98/1272 [00:17<03:21,  5.83it/s]

itr : 009000, Loss : 0.0294


 47%|████▋     | 598/1272 [01:42<01:54,  5.87it/s]

itr : 009500, Loss : 0.0393


 86%|████████▋ | 1098/1272 [03:08<00:29,  5.83it/s]

itr : 010000, Loss : 0.0367


100%|██████████| 1272/1272 [03:38<00:00,  5.83it/s]


Epoch : 9


 26%|██▌       | 326/1272 [00:56<02:41,  5.86it/s]

itr : 010500, Loss : 0.0255


 65%|██████▍   | 826/1272 [02:21<01:16,  5.83it/s]

itr : 011000, Loss : 0.0409


100%|██████████| 1272/1272 [03:38<00:00,  5.83it/s]


Epoch : 10


  4%|▍         | 54/1272 [00:09<03:28,  5.84it/s]

itr : 011500, Loss : 0.0277


 44%|████▎     | 554/1272 [01:35<02:02,  5.86it/s]

itr : 012000, Loss : 0.0319


 83%|████████▎ | 1054/1272 [03:00<00:37,  5.78it/s]

itr : 012500, Loss : 0.0272


100%|██████████| 1272/1272 [03:38<00:00,  5.83it/s]


Epoch : 11


 22%|██▏       | 282/1272 [00:48<02:49,  5.85it/s]

itr : 013000, Loss : 0.0404


 61%|██████▏   | 782/1272 [02:14<01:24,  5.82it/s]

itr : 013500, Loss : 0.0393


100%|██████████| 1272/1272 [03:38<00:00,  5.82it/s]


Epoch : 12


  1%|          | 10/1272 [00:02<03:44,  5.61it/s]

itr : 014000, Loss : 0.0339


 40%|████      | 510/1272 [01:27<02:10,  5.83it/s]

itr : 014500, Loss : 0.0408


 79%|███████▉  | 1010/1272 [02:53<00:44,  5.84it/s]

itr : 015000, Loss : 0.0331


100%|██████████| 1272/1272 [03:38<00:00,  5.81it/s]


Epoch : 13


 19%|█▊        | 238/1272 [00:41<02:57,  5.83it/s]

itr : 015500, Loss : 0.0280


 58%|█████▊    | 738/1272 [02:07<01:31,  5.85it/s]

itr : 016000, Loss : 0.0332


 97%|█████████▋| 1238/1272 [03:32<00:05,  5.84it/s]

itr : 016500, Loss : 0.0285


100%|██████████| 1272/1272 [03:38<00:00,  5.82it/s]


Epoch : 14


 37%|███▋      | 466/1272 [01:20<02:18,  5.84it/s]

itr : 017000, Loss : 0.0251


 76%|███████▌  | 966/1272 [02:46<00:52,  5.86it/s]

itr : 017500, Loss : 0.0418


100%|██████████| 1272/1272 [03:38<00:00,  5.82it/s]


Epoch : 15


 15%|█▌        | 194/1272 [00:33<03:04,  5.84it/s]

itr : 018000, Loss : 0.0312


 55%|█████▍    | 694/1272 [01:59<01:40,  5.78it/s]

itr : 018500, Loss : 0.0350


 94%|█████████▍| 1194/1272 [03:25<00:13,  5.81it/s]

itr : 019000, Loss : 0.0348


100%|██████████| 1272/1272 [03:38<00:00,  5.82it/s]


Epoch : 16


 33%|███▎      | 422/1272 [01:12<02:25,  5.84it/s]

itr : 019500, Loss : 0.0415


 72%|███████▏  | 922/1272 [02:38<00:59,  5.84it/s]

itr : 020000, Loss : 0.0337


100%|██████████| 1272/1272 [03:38<00:00,  5.82it/s]


Epoch : 17


 12%|█▏        | 150/1272 [00:26<03:12,  5.84it/s]

itr : 020500, Loss : 0.0444


 51%|█████     | 650/1272 [01:52<01:47,  5.80it/s]

itr : 021000, Loss : 0.0351


 90%|█████████ | 1150/1272 [03:17<00:20,  5.84it/s]

itr : 021500, Loss : 0.0289


100%|██████████| 1272/1272 [03:38<00:00,  5.82it/s]


Epoch : 18


 30%|██▉       | 378/1272 [01:05<02:32,  5.85it/s]

itr : 022000, Loss : 0.0297


 69%|██████▉   | 878/1272 [02:31<01:07,  5.83it/s]

itr : 022500, Loss : 0.0379


100%|██████████| 1272/1272 [03:38<00:00,  5.82it/s]


Epoch : 19


  8%|▊         | 106/1272 [00:18<03:19,  5.84it/s]

itr : 023000, Loss : 0.0307


 48%|████▊     | 606/1272 [01:44<01:53,  5.85it/s]

itr : 023500, Loss : 0.0308


 87%|████████▋ | 1106/1272 [03:10<00:28,  5.83it/s]

itr : 024000, Loss : 0.0442


100%|██████████| 1272/1272 [03:38<00:00,  5.81it/s]


Epoch : 20


 26%|██▋       | 334/1272 [00:57<02:40,  5.84it/s]

itr : 024500, Loss : 0.0283


 66%|██████▌   | 834/1272 [02:23<01:14,  5.85it/s]

itr : 025000, Loss : 0.0344


100%|██████████| 1272/1272 [03:38<00:00,  5.82it/s]


Epoch : 21


  5%|▍         | 62/1272 [00:11<03:27,  5.82it/s]

itr : 025500, Loss : 0.0375


 44%|████▍     | 562/1272 [01:36<02:01,  5.83it/s]

itr : 026000, Loss : 0.0364


 83%|████████▎ | 1062/1272 [03:02<00:35,  5.85it/s]

itr : 026500, Loss : 0.0383


100%|██████████| 1272/1272 [03:38<00:00,  5.82it/s]


Epoch : 22


 23%|██▎       | 290/1272 [00:50<02:47,  5.85it/s]

itr : 027000, Loss : 0.0425


 62%|██████▏   | 790/1272 [02:16<01:22,  5.84it/s]

itr : 027500, Loss : 0.0371


100%|██████████| 1272/1272 [03:39<00:00,  5.80it/s]


Epoch : 23


  1%|▏         | 18/1272 [00:03<03:35,  5.82it/s]

itr : 028000, Loss : 0.0286


 41%|████      | 518/1272 [01:29<02:08,  5.85it/s]

itr : 028500, Loss : 0.0348


 80%|████████  | 1018/1272 [02:55<00:43,  5.81it/s]

itr : 029000, Loss : 0.0299


100%|██████████| 1272/1272 [03:38<00:00,  5.82it/s]


Epoch : 24


 19%|█▉        | 246/1272 [00:42<02:55,  5.84it/s]

itr : 029500, Loss : 0.0329


 59%|█████▊    | 746/1272 [02:08<01:29,  5.85it/s]

itr : 030000, Loss : 0.0292


 98%|█████████▊| 1246/1272 [03:34<00:04,  5.84it/s]

itr : 030500, Loss : 0.0325


100%|██████████| 1272/1272 [03:38<00:00,  5.81it/s]


Epoch : 25


 37%|███▋      | 474/1272 [01:21<02:16,  5.85it/s]

itr : 031000, Loss : 0.0285


 77%|███████▋  | 974/1272 [02:47<00:51,  5.84it/s]

itr : 031500, Loss : 0.0256


100%|██████████| 1272/1272 [03:38<00:00,  5.82it/s]


Epoch : 26


 16%|█▌        | 202/1272 [00:35<03:03,  5.85it/s]

itr : 032000, Loss : 0.0421


 55%|█████▌    | 702/1272 [02:00<01:37,  5.83it/s]

itr : 032500, Loss : 0.0290


 94%|█████████▍| 1202/1272 [03:26<00:12,  5.83it/s]

itr : 033000, Loss : 0.0385


100%|██████████| 1272/1272 [03:38<00:00,  5.81it/s]


Epoch : 27


 34%|███▍      | 430/1272 [01:14<02:25,  5.79it/s]

itr : 033500, Loss : 0.0307


 73%|███████▎  | 930/1272 [02:40<00:58,  5.84it/s]

itr : 034000, Loss : 0.0432


100%|██████████| 1272/1272 [03:38<00:00,  5.81it/s]


Epoch : 28


 12%|█▏        | 158/1272 [00:27<03:11,  5.80it/s]

itr : 034500, Loss : 0.0238


 52%|█████▏    | 658/1272 [01:53<01:45,  5.84it/s]

itr : 035000, Loss : 0.0252


 91%|█████████ | 1158/1272 [03:19<00:19,  5.87it/s]

itr : 035500, Loss : 0.0360


100%|██████████| 1272/1272 [03:38<00:00,  5.81it/s]


Epoch : 29


 30%|███       | 386/1272 [01:06<02:31,  5.83it/s]

itr : 036000, Loss : 0.0379


 70%|██████▉   | 886/1272 [02:32<01:06,  5.85it/s]

itr : 036500, Loss : 0.0472


100%|██████████| 1272/1272 [03:38<00:00,  5.82it/s]


Epoch : 30


  9%|▉         | 114/1272 [00:20<03:18,  5.83it/s]

itr : 037000, Loss : 0.0291


 48%|████▊     | 614/1272 [01:45<01:52,  5.84it/s]

itr : 037500, Loss : 0.0268


 88%|████████▊ | 1114/1272 [03:11<00:27,  5.83it/s]

itr : 038000, Loss : 0.0263


100%|██████████| 1272/1272 [03:38<00:00,  5.81it/s]


Epoch : 31


 27%|██▋       | 342/1272 [00:59<02:39,  5.84it/s]

itr : 038500, Loss : 0.0328


 66%|██████▌   | 842/1272 [02:24<01:13,  5.84it/s]

itr : 039000, Loss : 0.0376


100%|██████████| 1272/1272 [03:38<00:00,  5.81it/s]


Epoch : 32


  6%|▌         | 70/1272 [00:12<03:27,  5.80it/s]

itr : 039500, Loss : 0.0333


 45%|████▍     | 570/1272 [01:38<02:00,  5.84it/s]

itr : 040000, Loss : 0.0313


 84%|████████▍ | 1070/1272 [03:03<00:34,  5.84it/s]

itr : 040500, Loss : 0.0331


100%|██████████| 1272/1272 [03:38<00:00,  5.82it/s]


Epoch : 33


 23%|██▎       | 298/1272 [00:51<02:46,  5.84it/s]

itr : 041000, Loss : 0.0310


 63%|██████▎   | 798/1272 [02:17<01:21,  5.81it/s]

itr : 041500, Loss : 0.0316


100%|██████████| 1272/1272 [03:38<00:00,  5.81it/s]


Epoch : 34


  2%|▏         | 26/1272 [00:05<03:33,  5.85it/s]

itr : 042000, Loss : 0.0347


 41%|████▏     | 526/1272 [01:30<02:08,  5.81it/s]

itr : 042500, Loss : 0.0385


 81%|████████  | 1026/1272 [02:56<00:42,  5.83it/s]

itr : 043000, Loss : 0.0294


100%|██████████| 1272/1272 [03:38<00:00,  5.81it/s]


Epoch : 35


 20%|█▉        | 254/1272 [00:44<02:54,  5.83it/s]

itr : 043500, Loss : 0.0250


 59%|█████▉    | 754/1272 [02:09<01:29,  5.81it/s]

itr : 044000, Loss : 0.0274


 99%|█████████▊| 1254/1272 [03:35<00:03,  5.84it/s]

itr : 044500, Loss : 0.0327


100%|██████████| 1272/1272 [03:38<00:00,  5.82it/s]


Epoch : 36


 38%|███▊      | 482/1272 [01:23<02:15,  5.82it/s]

itr : 045000, Loss : 0.0325


 77%|███████▋  | 982/1272 [02:48<00:49,  5.84it/s]

itr : 045500, Loss : 0.0346


 97%|█████████▋| 1237/1272 [03:32<00:05,  5.85it/s]

In [None]:
from google.colab import files

In [None]:
# Colab-only: download the grid that train() writes after epoch 100 (samples/{epoch:03d}.png).
# NOTE(review): the captured training log above stops during epoch 36, so this file
# exists only if training has actually run through all 100 epochs.
files.download("samples/100.png")