In [1]:
import os
import json
from random import random
from random import seed as r_seed
import math
from math import log2, floor
from functools import lru_cache, partial
from pathlib import Path
from shutil import rmtree
from contextlib import contextmanager

import torch
from torch.cuda.amp import autocast, GradScaler
from torch.optim import Adam
from torch import nn, einsum
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.autograd import grad as torch_grad

from PIL import Image
import torchvision
from torchvision import transforms
from kornia.filters import filter2d

from tqdm import tqdm
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange

from adabelief_pytorch import AdaBelief

from retry.api import retry_call
from datetime import datetime

import numpy as np

In [2]:
# asserts

# LightweightGAN moves itself to a CUDA device unconditionally, so fail fast
# at import time when no GPU is present
assert torch.cuda.is_available(), (
    'You need to have an Nvidia GPU with CUDA installed.'
)

# constants

# file extensions ImageDataset searches for (recursively) under its folder
EXTS = 'jpg jpeg png tiff'.split()

# helpers

def exists(val):
    """Return True unless ``val`` is None (falsy values such as 0 or '' still exist)."""
    if val is None:
        return False
    return True

@contextmanager
def null_context():
    """Context manager that does nothing; the ``with ... as`` target is None."""
    yield None

def is_power_of_two(val):
    """Return True if ``val`` is a positive integral power of two.

    Integers are checked exactly with a bit test. ``math.log2`` rounds large ints to
    float, which wrongly accepts values like ``2**60 + 1``. Other numeric types fall
    back to the ``log2`` check.

    Non-positive values return False instead of raising ``ValueError`` from ``log2``.
    """
    if val <= 0:
        return False
    if isinstance(val, int):
        # a power of two has exactly one set bit
        return (val & (val - 1)) == 0
    return log2(val).is_integer()

def default(val, d):
    """Return ``val``, or the fallback ``d`` when ``val`` is None."""
    if exists(val):
        return val
    return d

def set_requires_grad(model, bool):
    """Set ``requires_grad`` to the given flag on every parameter of ``model``.

    The parameter is named ``bool`` (shadowing the builtin) for caller compatibility.
    """
    flag = bool
    params = model.parameters()
    for param in params:
        param.requires_grad = flag

def cycle(iterable):
    """Yield the items of ``iterable`` repeatedly, re-iterating it after every pass.

    Each pass runs a fresh ``for`` loop over ``iterable``. If a whole pass produces no
    items (an empty iterable, or a one-shot iterator that is already exhausted),
    iteration stops instead of looping forever without yielding.
    """
    while True:
        produced = False
        for item in iterable:
            produced = True
            yield item
        if not produced:
            # nothing can ever be yielded again; without this the generator would
            # spin in a busy loop and the consumer's next() would hang
            return

def raise_if_nan(t):
    """Raise ``NanException`` if any element of tensor ``t`` is NaN.

    Works for 0-d (scalar) and multi-element tensors alike.
    """
    # ``.any()`` reduces to a single bool; a bare ``if torch.isnan(t)`` raises
    # RuntimeError ("Boolean value of Tensor ... is ambiguous") for >1 element
    if torch.isnan(t).any():
        raise NanException

def evaluate_in_chunks(max_batch_size, model, *args):
    """Run ``model`` on dim-0 slices of ``args`` of at most ``max_batch_size`` rows.

    All ``args`` are split in lockstep. A single chunk's output is returned as-is;
    otherwise the per-chunk outputs are concatenated along dim 0.
    """
    per_arg_chunks = (arg.split(max_batch_size, dim=0) for arg in args)
    grouped = zip(*per_arg_chunks)
    outputs = [model(*group) for group in grouped]
    return outputs[0] if len(outputs) == 1 else torch.cat(outputs, dim=0)

def slerp(val, low, high):
    """Spherical linear interpolation between the rows of ``low`` and ``high``.

    ``val`` is the interpolation fraction (0 -> ``low``, 1 -> ``high``). The angle is
    measured between the L2-normalised rows (norm over dim 1). The result mixes
    the *unnormalised* inputs. Parallel rows (sin(omega) == 0) give NaN.
    """
    unit_low = low / low.norm(dim=1, keepdim=True)
    unit_high = high / high.norm(dim=1, keepdim=True)
    omega = torch.acos((unit_low * unit_high).sum(1))
    sin_omega = torch.sin(omega)
    weight_low = (torch.sin((1.0 - val) * omega) / sin_omega).unsqueeze(1)
    weight_high = (torch.sin(val * omega) / sin_omega).unsqueeze(1)
    return weight_low * low + weight_high * high

def safe_div(n, d):
    """Return ``n / d``; on ZeroDivisionError return +inf (``n >= 0``) or -inf."""
    try:
        return n / d
    except ZeroDivisionError:
        return float('inf') if n >= 0 else float('-inf')

# loss functions

def gen_hinge_loss(fake, real):
    """Generator hinge loss: the mean of the discriminator's fake logits (``real`` is unused)."""
    loss = torch.mean(fake)
    return loss

def hinge_loss(real, fake):
    """Discriminator hinge loss: ``mean(relu(1 + real) + relu(1 - fake))``.

    Under this sign convention the discriminator pushes real logits to <= -1
    and fake logits to >= +1.
    """
    real_term = F.relu(1 + real)
    fake_term = F.relu(1 - fake)
    return (real_term + fake_term).mean()

def dual_contrastive_loss(real_logits, fake_logits):
    """Dual contrastive GAN loss over flattened real and fake logits.

    For each real logit, cross entropy is taken over ``[real_i, fake_0, ..., fake_m]``
    with target index 0. The same is done for each negated fake logit against all
    negated real logits. The two mean losses are summed.
    """
    device = real_logits.device
    real_flat = real_logits.reshape(-1)
    fake_flat = fake_logits.reshape(-1)

    def loss_half(anchor, others):
        count = anchor.shape[0]
        # row i: anchor_i followed by every element of ``others``
        logits = torch.cat((anchor[:, None], others[None, :].expand(count, -1)), dim = -1)
        targets = torch.zeros(count, device = device, dtype = torch.long)
        return F.cross_entropy(logits, targets)

    return loss_half(real_flat, fake_flat) + loss_half(-fake_flat, -real_flat)

@lru_cache(maxsize=10)
def det_randn(*args):
    """Return ``torch.randn(*args)``, memoised per argument tuple (at most 10 entries).

    Repeated calls with the same arguments return the very same tensor object. This
    keeps the same latent vectors (and thus images) across training steps, which
    helps to visualise the same image over time. The tensor is shared between
    callers, so do not modify it in place.
    """
    noise = torch.randn(*args)
    return noise

def interpolate_between(a, b, *, num_samples, dim):
    """Return ``num_samples`` evenly spaced linear interpolations from ``a`` to ``b``.

    Sample ``i`` is ``torch.lerp(a, b, i / (num_samples - 1))``, so the first sample
    is ``a`` and the last is exactly ``b``. The samples are stacked along ``dim``.

    Raises:
        AssertionError: if ``num_samples <= 2``.
    """
    assert num_samples > 2
    last = num_samples - 1
    # derive each weight from its index: summing 1 / (n - 1) repeatedly drifts
    # (e.g. ten 0.1 steps give 0.9999999999999999), so the end point missed ``b``
    samples = [torch.lerp(a, b, i / last) for i in range(num_samples)]
    return torch.stack(samples, dim=dim)

# helper classes

class NanException(Exception):
    """Raised by ``raise_if_nan`` when a checked tensor contains a NaN."""

class EMA:
    """Exponential-moving-average helper with decay factor ``beta``."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        """Return ``old * beta + (1 - beta) * new``, or ``new`` itself when ``old`` is None."""
        if exists(old):
            decay = self.beta
            return old * decay + (1 - decay) * new
        return new

class ChanNorm(nn.Module):
    """Normalise across the channel axis (dim 1) at every spatial position, then apply a learned affine.

    Uses the biased variance and adds ``eps`` before the square root. ``g`` (scale)
    and ``b`` (shift) have shape ``(1, dim, 1, 1)``.
    """
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        mu = torch.mean(x, dim = 1, keepdim = True)
        sigma_sq = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        normed = (x - mu) / (sigma_sq + self.eps).sqrt()
        return normed * self.g + self.b

class PreNorm(nn.Module):
    """Apply ``ChanNorm(dim)`` to the input, then call ``fn`` on the normalised tensor."""
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = ChanNorm(dim)

    def forward(self, x):
        normed = self.norm(x)
        return self.fn(normed)

class Residual(nn.Module):
    """Skip connection around ``fn``: returns ``fn(x) + x``."""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        branch = self.fn(x)
        return branch + x

class SumBranches(nn.Module):
    """Feed the same input to every branch module and return the sum of their outputs.

    With no branches the result is the integer 0.
    """
    def __init__(self, branches):
        super().__init__()
        self.branches = nn.ModuleList(branches)

    def forward(self, x):
        total = 0
        for branch in self.branches:
            total = total + branch(x)
        return total

class Blur(nn.Module):
    """Blur with the normalised 3x3 binomial kernel built from [1, 2, 1] (via kornia ``filter2d``)."""
    def __init__(self):
        super().__init__()
        self.register_buffer('f', torch.tensor([1., 2., 1.]))

    def forward(self, x):
        taps = self.f
        # outer product of the taps -> (1, 3, 3) kernel (symmetric, so orientation is irrelevant)
        kernel = (taps[:, None] * taps[None, :])[None]
        return filter2d(x, kernel, normalized=True)

class Noise(nn.Module):
    """Add per-pixel noise scaled by a learned scalar ``weight`` (initialised to 0).

    Expects a 4-D input ``(b, c, h, w)``. When ``noise`` is not given, fresh standard
    normal noise of shape ``(b, 1, h, w)`` is sampled, shared across channels.
    """
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, x, noise = None):
        batch, _, height, width = x.shape
        if noise is None:
            noise = torch.randn(batch, 1, height, width, device = x.device)
        return x + self.weight * noise

def Conv2dSame(dim_in, dim_out, kernel_size, bias = True):
    """Stride-1 conv with explicit zero padding so the output keeps the input's spatial size.

    Even kernels are padded asymmetrically (one less on the right and bottom).
    """
    half = kernel_size // 2
    is_even = kernel_size % 2 == 0
    before = half
    after = half - 1 if is_even else half
    padding = nn.ZeroPad2d((before, after, before, after))
    conv = nn.Conv2d(dim_in, dim_out, kernel_size, bias = bias)
    return nn.Sequential(padding, conv)

# attention

class DepthWiseConv2d(nn.Module):
    """Depthwise-separable conv.

    A per-channel spatial conv (``groups = dim_in``) is followed by a 1x1
    pointwise conv to ``dim_out`` channels.
    """
    def __init__(self, dim_in, dim_out, kernel_size, padding = 0, stride = 1, bias = True):
        super().__init__()
        depthwise = nn.Conv2d(
            dim_in,
            dim_in,
            kernel_size = kernel_size,
            padding = padding,
            groups = dim_in,
            stride = stride,
            bias = bias,
        )
        pointwise = nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
        self.net = nn.Sequential(depthwise, pointwise)

    def forward(self, x):
        return self.net(x)

class LinearAttention(nn.Module):
    """Hybrid 2-D self-attention with a global linear branch and a local softmax branch.

    Both branches read the same feature map. Their outputs are concatenated on the
    channel axis and projected back to ``dim`` channels by a 1x1 conv.

    Args:
        dim: channels of the input and of the output feature map.
        dim_head: channels per head in each branch.
        heads: number of attention heads.
        kernel_size: side of the square neighbourhood each query attends to in the
            local (full-softmax) branch. The depthwise conv that produces the linear
            branch's keys/values always uses a fixed 3x3 kernel.
    """
    def __init__(self, dim, dim_head = 32, heads = 4, kernel_size = 3):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.dim_head = dim_head
        inner_dim = dim_head * heads

        self.kernel_size = kernel_size
        # applied only to the linear-attention branch output (see forward)
        self.nonlin = nn.GELU()

        # linear-attention branch: 1x1 queries, depthwise-separable 3x3 keys/values
        self.to_lin_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
        self.to_lin_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding = 1, bias = False)

        # local full-attention branch: 1x1 queries, keys and values
        self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
        self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias = False)

        # consumes both branch outputs concatenated, hence inner_dim * 2 input channels
        self.to_out = nn.Conv2d(inner_dim * 2, dim, 1)

    def forward(self, fmap):
        """Attend over ``fmap`` of shape (b, dim, x, y); returns a tensor of the same shape.

        The rearrange patterns below require a 4-D input.
        """
        h, x, y = self.heads, *fmap.shape[-2:]

        # linear attention

        lin_q, lin_k, lin_v = (self.to_lin_q(fmap), *self.to_lin_kv(fmap).chunk(2, dim = 1))
        # fold heads into the batch: (b*h, x*y positions, dim_head channels)
        lin_q, lin_k, lin_v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (lin_q, lin_k, lin_v))

        # queries are softmaxed over channels, keys over spatial positions
        lin_q = lin_q.softmax(dim = -1)
        lin_k = lin_k.softmax(dim = -2)

        lin_q = lin_q * self.scale

        # (d x e) context summed over all positions, then read out per query position
        context = einsum('b n d, b n e -> b d e', lin_k, lin_v)
        lin_out = einsum('b n d, b d e -> b n e', lin_q, context)
        lin_out = rearrange(lin_out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)

        # conv-like full attention

        q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim = 1))
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) c x y', h = h), (q, k, v))

        # gather each position's kernel_size x kernel_size neighbourhood (zero-padded at
        # the borders). NOTE(review): presumably kernel_size must be odd; with padding
        # kernel_size // 2 an even size yields (x + 1) * (y + 1) windows, which would
        # not line up with the x * y queries
        k = F.unfold(k, kernel_size = self.kernel_size, padding = self.kernel_size // 2)
        v = F.unfold(v, kernel_size = self.kernel_size, padding = self.kernel_size // 2)

        # unfold yields (b*h, dim_head * k*k, positions); split into (positions, window j, dim_head)
        k, v = map(lambda t: rearrange(t, 'b (d j) n -> b n j d', d = self.dim_head), (k, v))

        q = rearrange(q, 'b c ... -> b (...) c') * self.scale

        # each query position i scores only its own window entries j
        sim = einsum('b i d, b i j d -> b i j', q, k)
        # subtract the (detached) row max for numerical stability; softmax is shift-invariant
        sim = sim - sim.amax(dim = -1, keepdim = True).detach()

        attn = sim.softmax(dim = -1)

        full_out = einsum('b i j, b i j d -> b i d', attn, v)
        full_out = rearrange(full_out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)

        # add outputs of linear attention + conv like full attention
        # (in practice they are concatenated on channels and mixed by to_out)

        lin_out = self.nonlin(lin_out)
        out = torch.cat((lin_out, full_out), dim = 1)
        return self.to_out(out)

# dataset

def convert_image_to(img_type, image):
    """Return ``image`` converted to PIL mode ``img_type``, or unchanged if already in that mode."""
    already_converted = image.mode == img_type
    return image if already_converted else image.convert(img_type)

class identity:
    """Callable transform that returns its argument unchanged."""

    def __call__(self, tensor):
        result = tensor
        return result

class ImageDataset(Dataset):
    """Dataset of every image under ``folder`` (searched recursively) with an extension in ``EXTS``.

    Each item is opened with PIL, converted to RGB, resized so its shorter side is
    ``image_size``, center-cropped to ``image_size x image_size`` and turned into a
    tensor with ``transforms.ToTensor``.

    Raises:
        AssertionError: if no matching image files are found.
    """
    def __init__(
        self,
        folder,
        image_size,
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in EXTS for p in Path(f'{folder}').glob(f'**/*.{ext}')]
        assert len(self.paths) > 0, f'No images were found in {folder} for training'

        pillow_mode = 'RGB'

        convert_image_fn = partial(convert_image_to, pillow_mode)

        self.transform = transforms.Compose([
            transforms.Lambda(convert_image_fn),
            transforms.Resize(image_size),
            # Resize(int) only fixes the shorter side; crop so every sample is square
            # and samples from non-square images can be batched together
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        # the tensor is fully materialised inside the block, so closing the file afterwards is safe
        with Image.open(path) as img:
            return self.transform(img)

# augmentations

def random_hflip(tensor, prob):
    """Flip ``tensor`` along dim 3 (width, for (b, c, h, w) input) with probability ``prob``.

    ``prob = 0`` never flips and ``prob = 1`` always flips.
    """
    # the check must select the flip branch with probability prob; the old
    # `prob > random()` guard returned the unflipped tensor in that case (inverted)
    if random() < prob:
        return torch.flip(tensor, dims=(3,))
    return tensor

class AugWrapper(nn.Module):
    """Wrap a discriminator so its input batch may be randomly augmented first.

    Args:
        D: discriminator module, called as ``D(images, **kwargs)``.
        image_size: accepted for call compatibility; not used here.
    """
    def __init__(self, D, image_size):
        super().__init__()
        self.D = D

    def forward(self, images, prob = 0., types = None, detach = False, **kwargs):
        """Optionally augment ``images``, then return ``D(images, **kwargs)``.

        Args:
            images: batch passed to the discriminator.
            prob: probability that augmentation (a random horizontal flip) is applied.
            types: currently unused; kept for call compatibility. Defaults to None
                rather than a shared mutable list.
            detach: run the augmentation step under ``torch.no_grad``.
            **kwargs: forwarded to ``D``.
        """
        context = torch.no_grad if detach else null_context

        with context():
            if random() < prob:
                images = random_hflip(images, prob=0.5)

        return self.D(images, **kwargs)

# modifiable global variables

# normalisation layer class used by Generator (its initial_conv and every
# upsampling layer). The name is looked up when a Generator is constructed, so
# reassigning it beforehand swaps the normalisation for newly built generators.
norm_class = nn.BatchNorm2d

class PixelShuffleUpsample(nn.Module):
    """2x spatial upsampling: 1x1 conv to ``4 * dim_out`` channels, SiLU, then ``PixelShuffle(2)``.

    ``dim_out`` defaults to ``dim``. The conv is initialised by ``init_conv_``.
    """
    def __init__(self, dim, dim_out = None):
        super().__init__()
        out_channels = default(dim_out, dim)
        proj = nn.Conv2d(dim, out_channels * 4, 1)
        self.net = nn.Sequential(proj, nn.SiLU(), nn.PixelShuffle(2))
        self.init_conv_(proj)

    def init_conv_(self, conv):
        """Initialise ``conv`` in place and zero its bias.

        Kaiming-uniform weights are drawn for a quarter of the filters, and each one
        is repeated 4 times consecutively. Each group of 4 consecutive channels is
        what PixelShuffle(2) spreads over one 2x2 output patch, so every patch
        starts out uniform.
        """
        out_ch, in_ch, k_h, k_w = conv.weight.shape
        base = torch.empty(out_ch // 4, in_ch, k_h, k_w)
        nn.init.kaiming_uniform_(base)
        tiled = repeat(base, 'o ... -> (o 4) ...')
        conv.weight.data.copy_(tiled)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        return self.net(x)

def SPConvDownsample(dim, dim_out = None):
    """Halve spatial resolution by space-to-depth (factor 2) followed by a 1x1 conv.

    This is the "SP-conv" of https://arxiv.org/abs/2208.03641 (essentially a pixel
    unshuffle). The paper reports it as the best of the downsampling methods it
    compared. ``dim_out`` defaults to ``dim``.
    """
    target_dim = default(dim_out, dim)
    space_to_depth = Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2)
    mix = nn.Conv2d(dim * 4, target_dim, 1)
    return nn.Sequential(space_to_depth, mix)

# squeeze excitation classes

# global context network
# https://arxiv.org/abs/2012.13375
# similar to squeeze-excite, but with a simplified attention pooling and a subsequent layer norm

class GlobalContext(nn.Module):
    """Global-context gating (https://arxiv.org/abs/2012.13375).

    A 1x1 conv scores every spatial position. Softmax over positions gives
    attention-pooling weights, and the weighted sum of ``x`` over space yields one
    ``chan_in`` vector per sample. A conv bottleneck (``max(3, chan_out // 2)``
    channels) then maps it to ``chan_out`` sigmoid gates of shape (b, chan_out, 1, 1).
    """
    def __init__(
        self,
        *,
        chan_in,
        chan_out
    ):
        super().__init__()
        self.to_k = nn.Conv2d(chan_in, 1, 1)
        hidden = max(3, chan_out // 2)

        self.net = nn.Sequential(
            nn.Conv2d(chan_in, hidden, 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(hidden, chan_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        weights = self.to_k(x).flatten(2).softmax(dim = -1)
        pooled = einsum('b i n, b c n -> b c i', weights, x.flatten(2))
        return self.net(pooled.unsqueeze(-1))

# frequency channel attention
# https://arxiv.org/abs/2012.11879

def get_1d_dct(i, freq, L):
    """Orthonormal DCT-II basis value for sample ``i`` at frequency ``freq`` over length ``L``."""
    scale = 1 if freq == 0 else math.sqrt(2)
    basis = math.cos(math.pi * freq * (i + 0.5) / L) / math.sqrt(L)
    return basis * scale

def get_dct_weights(width, channel, fidx_u, fidx_v):
    """Build a (1, channel, width, width) tensor of 2-D DCT basis images.

    The channels are split into ``len(fidx_u)`` consecutive groups of
    ``channel // len(fidx_u)`` channels each. Group ``i`` is filled with the
    separable basis for the frequency pair ``(fidx_u[i], fidx_v[i])``. Leftover
    channels (when ``channel`` is not divisible) stay zero.
    """
    weights = torch.zeros(1, channel, width, width)
    group = channel // len(fidx_u)

    for idx, (u_x, v_y) in enumerate(zip(fidx_u, fidx_v)):
        row_basis = [get_1d_dct(x, u_x, width) for x in range(width)]
        col_basis = [get_1d_dct(y, v_y, width) for y in range(width)]
        plane = torch.tensor([[bx * by for by in col_basis] for bx in row_basis])
        start = idx * group
        weights[:, start: start + group] = plane

    return weights

class FCANet(nn.Module):
    """Frequency channel attention (https://arxiv.org/abs/2012.11879).

    Each channel is pooled by weighting it with a fixed 2-D DCT basis image (the
    ``dct_weights`` buffer) and summing over space. The pooled ``chan_in`` vector is
    then mapped through a conv bottleneck to ``chan_out`` sigmoid gates of shape
    (b, chan_out, 1, 1).

    Args:
        chan_in: channels of the input feature map.
        chan_out: number of gates produced.
        reduction: bottleneck divisor; the bottleneck has ``max(3, chan_out // reduction)`` channels.
        width: spatial side of the (square) input; the DCT buffer is ``width x width``.
    """
    def __init__(
        self,
        *,
        chan_in,
        chan_out,
        reduction = 4,
        width
    ):
        super().__init__()

        # 16 frequency pairs (u, v): (0, 0..7) followed by (0..7, 0);
        # the original note says 16 frequencies seemed ideal in the paper
        low_to_high = list(range(8))
        zeros = [0] * 8
        fidx_u = zeros + low_to_high
        fidx_v = low_to_high + zeros
        self.register_buffer('dct_weights', get_dct_weights(width, chan_in, fidx_u, fidx_v))

        hidden = max(3, chan_out // reduction)

        self.net = nn.Sequential(
            nn.Conv2d(chan_in, hidden, 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(hidden, chan_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # assumes x is (b, chan_in, width, width) so it lines up with dct_weights
        weighted = x * self.dct_weights
        pooled = reduce(weighted, 'b c (h h1) (w w1) -> b c h1 w1', 'sum', h1 = 1, w1 = 1)
        return self.net(pooled)

# generative adversarial network

class Generator(nn.Module):
    """Upsampling generator mapping a (b, latent_dim) latent batch to 3-channel images.

    A transposed conv turns each latent vector into a 4x4 map. One layer per
    ``res`` in ``2 .. log2(image_size) - 1`` then doubles the spatial size, ending at
    ``image_size``. Each layer may carry an attention block (applied to its input)
    and a skip-layer-excitation (SLE) module. The SLE output gates the feature map
    a few layers later (see ``self.sle_map``).

    Args:
        image_size: output side length; must be a power of 2.
        latent_dim: size of the input latent vector.
        fmap_max: upper bound on channels per layer (``None`` falls back to ``latent_dim``).
        fmap_inverse_coef: layer ``res`` outputs ``min(2 ** (fmap_inverse_coef - res), fmap_max)``
            channels, or 3 channels once ``res >= 8``.
        attn_res_layers: spatial widths (e.g. 32) at whose layer input a LinearAttention block is added.
        freq_chan_attn: use FCANet instead of GlobalContext for the SLE modules.
    """
    def __init__(
        self,
        *,
        image_size,
        latent_dim = 256,
        fmap_max = 256,
        fmap_inverse_coef = 10,
        attn_res_layers = [],
        freq_chan_attn = False
    ):
        super().__init__()
        resolution = log2(image_size)
        assert is_power_of_two(image_size), 'image size must be a power of 2'

        init_channel = 3

        fmap_max = default(fmap_max, latent_dim)

        # 1x1 latent -> 4x4 map (kernel 4, no padding); GLU halves 2*latent_dim back to latent_dim
        self.initial_conv = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, latent_dim * 2, 4),
            norm_class(latent_dim * 2),
            nn.GLU(dim = 1)
        )

        # output channels for each layer res = 2 .. num_layers + 1
        num_layers = int(resolution) - 2
        features = list(map(lambda n: (n,  2 ** (fmap_inverse_coef - n)), range(2, num_layers + 2)))
        features = list(map(lambda n: (n[0], min(n[1], fmap_max)), features))
        features = list(map(lambda n: 3 if n[0] >= 8 else n[1], features))
        features = [latent_dim, *features]

        in_out_features = list(zip(features[:-1], features[1:]))

        self.res_layers = range(2, num_layers + 2)
        self.layers = nn.ModuleList([])
        # res -> (chan_in, chan_out) of that layer
        self.res_to_feature_map = dict(zip(self.res_layers, in_out_features))

        # SLE wiring: the gate computed from layer `src`'s output multiplies the output
        # of layer `dst - 1` (see forward); pairs that don't fit this resolution are dropped
        self.sle_map = ((2, 5), (3, 6), (4, 7))
        self.sle_map = list(filter(lambda t: t[0] <= resolution and t[1] <= resolution, self.sle_map))
        self.sle_map = dict(self.sle_map)

        # NOTE(review): not read anywhere in the visible code
        self.num_layers_spatial_res = 1

        for (res, (chan_in, chan_out)) in zip(self.res_layers, in_out_features):
            # spatial width of this layer's input
            image_width = 2 ** res

            attn = None
            if image_width in attn_res_layers:
                attn = PreNorm(chan_in, LinearAttention(chan_in))

            sle = None
            if res in self.sle_map:
                residual_layer = self.sle_map[res]
                # the gate multiplies layer (residual_layer - 1)'s output, so match its channels
                sle_chan_out = self.res_to_feature_map[residual_layer - 1][-1]

                if freq_chan_attn:
                    sle = FCANet(
                        chan_in = chan_out,
                        chan_out = sle_chan_out,
                        # this layer's output is 2 ** (res + 1) wide
                        width = 2 ** (res + 1)
                    )
                else:
                    sle = GlobalContext(
                        chan_in = chan_out,
                        chan_out = sle_chan_out
                    )

            # [upsampling block, optional SLE module, optional attention]
            layer = nn.ModuleList([
                nn.Sequential(
                    PixelShuffleUpsample(chan_in),
                    Blur(),
                    Conv2dSame(chan_in, chan_out * 2, 4),
                    Noise(),
                    norm_class(chan_out * 2),
                    nn.GLU(dim = 1)
                ),
                sle,
                attn
            ])
            self.layers.append(layer)

        self.out_conv = nn.Conv2d(features[-1], init_channel, 3, padding = 1)

    def forward(self, x):
        """Map latents ``x`` of shape (b, latent_dim) to images of shape (b, 3, image_size, image_size).

        The output is the raw ``out_conv`` result; no output activation is applied here.
        """
        x = rearrange(x, 'b c -> b c () ()')
        x = self.initial_conv(x)
        # L2-normalise the channel vector at every spatial position
        x = F.normalize(x, dim = 1)

        # destination res -> pending SLE gate
        residuals = dict()

        for (res, (up, sle, attn)) in zip(self.res_layers, self.layers):
            if exists(attn):
                x = attn(x) + x

            x = up(x)

            if exists(sle):
                out_res = self.sle_map[res]
                residual = sle(x)
                residuals[out_res] = residual

            # apply a gate whose destination is the next layer index
            next_res = res + 1
            if next_res in residuals:
                x = x * residuals[next_res]

        return self.out_conv(x)

class SimpleDecoder(nn.Module):
    """Small upsampling decoder used for the discriminator's auto-encoding loss.

    Each of ``num_upsamples`` stages doubles the spatial size
    (PixelShuffleUpsample), applies a 3x3 conv and a channel GLU. The GLU halves the
    channel count at every stage. The last stage's conv emits ``2 * chan_out``
    channels, so the GLU leaves exactly ``chan_out``.
    """
    def __init__(
        self,
        *,
        chan_in,
        chan_out = 3,
        num_upsamples = 4,
    ):
        super().__init__()

        self.layers = nn.ModuleList([])
        final_chan = chan_out
        chans = chan_in

        for stage in range(num_upsamples):
            is_last = stage == num_upsamples - 1
            conv_out = final_chan * 2 if is_last else chans
            self.layers.append(nn.Sequential(
                PixelShuffleUpsample(chans),
                nn.Conv2d(chans, conv_out, 3, padding = 1),
                nn.GLU(dim = 1)
            ))
            chans //= 2

    def forward(self, x):
        out = x
        for stage in self.layers:
            out = stage(out)
        return out

class Discriminator(nn.Module):
    """Discriminator with a main logit head, a 32x32 shape head and auto-encoding aux losses.

    Images wider than 256 are first halved by plain strided convs
    (``non_residual_layers``) down to ``2 ** min(8, log2(image_size))``. Residual
    blocks (two summed downsampling branches each) then halve the resolution down
    to 8x8. ``to_logits`` maps the 8x8 features to a 5x5 or 1x1 logit map.

    Args:
        image_size: input side length; must be a power of 2.
        fmap_max: upper bound on feature channels.
        fmap_inverse_coef: feature maps of width ``2 ** n`` carry
            ``min(2 ** (fmap_inverse_coef - n), fmap_max)`` channels, except that the
            first residual block takes the raw 3-channel image when there are no
            non-residual layers.
        disc_output_size: side of the logit map, 5 or 1.
        attn_res_layers: widths at whose residual-block input a LinearAttention block is added.
    """
    def __init__(
        self,
        *,
        image_size,
        fmap_max = 256,
        fmap_inverse_coef = 10,
        disc_output_size = 5,
        attn_res_layers = []
    ):
        super().__init__()
        resolution = log2(image_size)
        assert is_power_of_two(image_size), 'image size must be a power of 2'
        assert disc_output_size in {1, 5}, 'discriminator output dimensions can only be 5x5 or 1x1'

        resolution = int(resolution)

        init_channel = 3

        # one stride-2 conv per doubling above 256x256
        num_non_residual_layers = max(0, int(resolution) - 8)

        # log2 widths handled by the residual stack, from min(8, resolution) down to 3 (8x8)
        non_residual_resolutions = range(min(8, resolution), 2, -1)
        features = list(map(lambda n: (n,  2 ** (fmap_inverse_coef - n)), non_residual_resolutions))
        features = list(map(lambda n: (n[0], min(n[1], fmap_max)), features))

        # without a non-residual stem the first block sees the RGB image directly
        if num_non_residual_layers == 0:
            res, _ = features[0]
            features[0] = (res, init_channel)

        chan_in_out = list(zip(features[:-1], features[1:]))

        self.non_residual_layers = nn.ModuleList([])
        for ind in range(num_non_residual_layers):
            # stay at 3 channels until the last stem layer, which expands to the first block's width
            last_layer = ind == (num_non_residual_layers - 1)
            chan_out = features[0][-1] if last_layer else init_channel

            self.non_residual_layers.append(nn.Sequential(
                Blur(),
                nn.Conv2d(init_channel, chan_out, 4, stride = 2, padding = 1),
                nn.LeakyReLU(0.1)
            ))

        self.residual_layers = nn.ModuleList([])

        for (res, ((_, chan_in), (_, chan_out))) in zip(non_residual_resolutions, chan_in_out):
            # spatial width of this block's input
            image_width = 2 ** res

            attn = None
            if image_width in attn_res_layers:
                attn = PreNorm(chan_in, LinearAttention(chan_in))

            # each block halves the resolution; the two branches' outputs are summed
            self.residual_layers.append(nn.ModuleList([
                SumBranches([
                    nn.Sequential(
                        Blur(),
                        SPConvDownsample(chan_in, chan_out),
                        nn.LeakyReLU(0.1),
                        nn.Conv2d(chan_out, chan_out, 3, padding = 1),
                        nn.LeakyReLU(0.1)
                    ),
                    nn.Sequential(
                        Blur(),
                        nn.AvgPool2d(2),
                        nn.Conv2d(chan_in, chan_out, 1),
                        nn.LeakyReLU(0.1),
                    )
                ]),
                attn
            ]))

        last_chan = features[-1][-1]
        if disc_output_size == 5:
            # 8x8 -> 5x5 via an unpadded 4x4 conv
            self.to_logits = nn.Sequential(
                nn.Conv2d(last_chan, last_chan, 1),
                nn.LeakyReLU(0.1),
                nn.Conv2d(last_chan, 1, 4)
            )
        elif disc_output_size == 1:
            # 8x8 -> 4x4 (stride-2 conv) -> 1x1 (unpadded 4x4 conv)
            self.to_logits = nn.Sequential(
                Blur(),
                nn.Conv2d(last_chan, last_chan, 3, stride = 2, padding = 1),
                nn.LeakyReLU(0.1),
                nn.Conv2d(last_chan, 1, 4)
            )

        # auxiliary head on a 32x32 downsample of the input; ends in a (b, 1, 1, 1) output
        self.to_shape_disc_out = nn.Sequential(
            nn.Conv2d(init_channel, 64, 3, padding = 1),
            Residual(PreNorm(64, LinearAttention(64))),
            SumBranches([
                nn.Sequential(
                    Blur(),
                    SPConvDownsample(64, 32),
                    nn.LeakyReLU(0.1),
                    nn.Conv2d(32, 32, 3, padding = 1),
                    nn.LeakyReLU(0.1)
                ),
                nn.Sequential(
                    Blur(),
                    nn.AvgPool2d(2),
                    nn.Conv2d(64, 32, 1),
                    nn.LeakyReLU(0.1),
                )
            ]),
            Residual(PreNorm(32, LinearAttention(32))),
            nn.AdaptiveAvgPool2d((4, 4)),
            nn.Conv2d(32, 1, 4)
        )

        # decoders for the self-supervised reconstruction loss: decoder1 reads the final
        # 8x8 features, decoder2 a quadrant of the 16x16 features (only for >= 512px input)
        self.decoder1 = SimpleDecoder(chan_in = last_chan, chan_out = init_channel)
        self.decoder2 = SimpleDecoder(chan_in = features[-2][-1], chan_out = init_channel) if resolution >= 9 else None

    def forward(self, x, calc_aux_loss = False):
        """Score images ``x``.

        Returns:
            ``(out, out_32x32, aux_loss)`` where ``out`` is the ``to_logits`` map
            flattened to (b, -1), ``out_32x32`` is the shape-head output, and
            ``aux_loss`` is the reconstruction MSE (``None`` unless ``calc_aux_loss``).
        """
        orig_img = x

        for layer in self.non_residual_layers:
            x = layer(x)

        # output of every residual block, kept for the reconstruction loss
        layer_outputs = []

        for (net, attn) in self.residual_layers:
            if exists(attn):
                x = attn(x) + x

            x = net(x)
            layer_outputs.append(x)

        out = self.to_logits(x).flatten(1)

        img_32x32 = F.interpolate(orig_img, size = (32, 32))
        out_32x32 = self.to_shape_disc_out(img_32x32)

        if not calc_aux_loss:
            return out, out_32x32, None

        # self-supervised auto-encoding loss

        layer_8x8 = layer_outputs[-1]
        layer_16x16 = layer_outputs[-2]

        recon_img_8x8 = self.decoder1(layer_8x8)

        # compare against the input resized to the decoder's output size
        aux_loss = F.mse_loss(
            recon_img_8x8,
            F.interpolate(orig_img, size = recon_img_8x8.shape[2:])
        )

        if exists(self.decoder2):
            # pick the same random quadrant (index 0..3) from the image and the 16x16 features
            select_random_quadrant = lambda rand_quadrant, img: rearrange(img, 'b c (m h) (n w) -> (m n) b c h w', m = 2, n = 2)[rand_quadrant]
            crop_image_fn = partial(select_random_quadrant, floor(random() * 4))
            img_part, layer_16x16_part = map(crop_image_fn, (orig_img, layer_16x16))

            recon_img_16x16 = self.decoder2(layer_16x16_part)

            aux_loss_16x16 = F.mse_loss(
                recon_img_16x16,
                F.interpolate(img_part, size = recon_img_16x16.shape[2:])
            )

            aux_loss = aux_loss + aux_loss_16x16

        return out, out_32x32, aux_loss

class LightweightGAN(nn.Module):
    """Holder for generator ``G``, discriminator ``D``, EMA generator ``GE`` and their optimizers.

    ``GE`` is a second Generator with gradients disabled. Its weights start as a copy
    of ``G`` and then track ``G`` through :meth:`EMA`. ``D_aug`` wraps ``D`` in an
    ``AugWrapper``. The module moves itself to CUDA device ``rank`` on construction.
    Calling the module directly is not supported.

    Args:
        latent_dim, image_size, fmap_max, fmap_inverse_coef, attn_res_layers,
        freq_chan_attn: forwarded to both Generators. ``image_size``, ``fmap_max``,
            ``fmap_inverse_coef`` and ``attn_res_layers`` also go to the Discriminator.
        optimizer: ``"adam"`` or ``"adabelief"``; both are built with ``betas=(0.5, 0.9)``.
        disc_output_size: forwarded to the Discriminator.
        ttur_mult: the discriminator learning rate is ``lr * ttur_mult``.
        lr: generator learning rate.
        rank: CUDA device index the module is moved to.

    Raises:
        ValueError: if ``optimizer`` is not a supported name.
    """
    def __init__(
        self,
        *,
        latent_dim,
        image_size,
        optimizer = "adam",
        fmap_max = 256,
        fmap_inverse_coef = 10,
        disc_output_size = 5,
        attn_res_layers = [],
        freq_chan_attn = False,
        ttur_mult = 1.,
        lr = 2e-4,
        rank = 0
    ):
        super().__init__()
        self.latent_dim = latent_dim
        self.image_size = image_size

        G_kwargs = dict(
            image_size = image_size,
            latent_dim = latent_dim,
            fmap_max = fmap_max,
            fmap_inverse_coef = fmap_inverse_coef,
            attn_res_layers = attn_res_layers,
            freq_chan_attn = freq_chan_attn
        )

        self.G = Generator(**G_kwargs)

        self.D = Discriminator(
            image_size = image_size,
            fmap_max = fmap_max,
            fmap_inverse_coef = fmap_inverse_coef,
            attn_res_layers = attn_res_layers,
            disc_output_size = disc_output_size
        )

        self.ema_updater = EMA(0.995)
        self.GE = Generator(**G_kwargs)
        set_requires_grad(self.GE, False)

        optimizer_classes = {
            "adam": Adam,
            "adabelief": AdaBelief,
        }
        if optimizer not in optimizer_classes:
            # explicit raise instead of `assert False`: asserts are stripped under
            # `python -O`, which would leave G_opt / D_opt silently undefined
            raise ValueError(f'No valid optimizer is given: {optimizer!r} (expected one of {sorted(optimizer_classes)})')

        optimizer_class = optimizer_classes[optimizer]
        self.G_opt = optimizer_class(self.G.parameters(), lr = lr, betas=(0.5, 0.9))
        self.D_opt = optimizer_class(self.D.parameters(), lr = lr * ttur_mult, betas=(0.5, 0.9))

        self.apply(self._init_weights)
        # start GE as an exact copy of the freshly initialised G
        self.reset_parameter_averaging()

        self.cuda(rank)
        self.D_aug = AugWrapper(self.D, image_size)

    def _init_weights(self, m):
        """Kaiming-normal init for the weights of exact ``nn.Conv2d`` / ``nn.Linear`` modules."""
        if type(m) in {nn.Conv2d, nn.Linear}:
            nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

    def EMA(self):
        """Blend ``G``'s parameters and floating-point buffers into ``GE`` via ``self.ema_updater``.

        Non-floating-point buffers are copied from ``G`` unchanged.
        """
        def update_moving_average(ma_model, current_model):
            for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
                old_weight, up_weight = ma_params.data, current_params.data
                ma_params.data = self.ema_updater.update_average(old_weight, up_weight)

            for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
                if not ma_buffer.is_floating_point():
                    # integer buffers (e.g. counters) cannot be averaged: the blended
                    # float would be truncated by copy_ and fall behind, so copy instead
                    ma_buffer.copy_(current_buffer)
                    continue
                new_buffer_value = self.ema_updater.update_average(ma_buffer, current_buffer)
                ma_buffer.copy_(new_buffer_value)

        update_moving_average(self.GE, self.G)

    def reset_parameter_averaging(self):
        """Overwrite ``GE``'s state with an exact copy of ``G``'s."""
        self.GE.load_state_dict(self.G.state_dict())

    def forward(self, x):
        # `raise NotImplemented` would itself fail with TypeError (NotImplemented is
        # not an exception); NotImplementedError is the intended signal
        raise NotImplementedError

# trainer

class Trainer():
    """Drives training, sampling and checkpointing of a ``LightweightGAN``.

    On-disk layout:
      * ``<base_dir>/<results_dir>/<name>``: sample images.
      * ``<base_dir>/<models_dir>/<name>``: checkpoints and ``.config.json``.

    The trainer also owns the data loader and one AMP ``GradScaler`` each for
    the generator and the discriminator. The GAN itself is built lazily by
    ``init_GAN``, either on the first ``train()`` call or whenever ``load`` runs.
    """

    def __init__(
        self,
        name = 'default',
        results_dir = 'results',
        models_dir = 'models',
        base_dir = './',
        optimizer = 'adam',
        num_workers = None,
        latent_dim = 256,
        image_size = 128,
        num_image_tiles = 8,
        fmap_max = 512,
        batch_size = 64,
        gp_weight = 10,
        gradient_accumulate_every = 1,
        attn_res_layers = [],
        freq_chan_attn = False,
        disc_output_size = 5,
        dual_contrast_loss = False,
        antialias = False,
        lr = 2e-4,
        ttur_mult = 1.,
        save_every = 1000,
        evaluate_every = 1000,
        aug_prob = None,
        aug_types = ['translation', 'cutout'],
        dataset_aug_prob = 0.,
        rank = 0,
        amp = False,
        hparams = None,
        load_strict = True,
        *args,
        **kwargs
    ):
        # Extra positional / keyword arguments are forwarded verbatim to LightweightGAN.
        self.GAN_params = [args, kwargs]
        self.GAN = None

        self.name = name

        base_dir = Path(base_dir)
        self.base_dir = base_dir
        self.results_dir = base_dir / results_dir
        self.models_dir = base_dir / models_dir

        self.config_path = self.models_dir / name / '.config.json'

        assert is_power_of_two(image_size), 'image size must be a power of 2 (64, 128, 256, 512, 1024)'
        assert all(map(is_power_of_two, attn_res_layers)), 'resolution layers of attention must all be powers of 2 (16, 32, 64, 128, 256, 512)'

        assert not (dual_contrast_loss and disc_output_size > 1), 'discriminator output size cannot be greater than 1 if using dual contrastive loss'

        self.image_size = image_size
        self.num_image_tiles = num_image_tiles

        self.latent_dim = latent_dim
        self.fmap_max = fmap_max

        self.aug_prob = aug_prob
        self.aug_types = aug_types

        self.lr = lr
        self.optimizer = optimizer
        self.num_workers = num_workers
        self.ttur_mult = ttur_mult
        self.batch_size = batch_size
        self.gradient_accumulate_every = gradient_accumulate_every

        self.gp_weight = gp_weight

        self.evaluate_every = evaluate_every
        self.save_every = save_every
        self.steps = 0

        self.attn_res_layers = attn_res_layers
        self.freq_chan_attn = freq_chan_attn

        self.disc_output_size = disc_output_size
        self.antialias = antialias

        self.dual_contrast_loss = dual_contrast_loss

        self.d_loss = 0
        self.g_loss = 0
        self.last_gp_loss = None
        self.last_recon_loss = None

        self.init_folders()

        self.loader = None
        self.dataset_aug_prob = dataset_aug_prob

        self.is_main = rank == 0
        self.rank = rank

        self.load_strict = load_strict

        self.amp = amp
        self.G_scaler = GradScaler(enabled = self.amp)
        self.D_scaler = GradScaler(enabled = self.amp)

        self.hparams = hparams

        self.image_extension = 'jpg'

    @property
    def checkpoint_num(self):
        """Index of the checkpoint slot the current step belongs to."""
        return floor(self.steps // self.save_every)

    def init_GAN(self):
        """(Re)build ``self.GAN`` from the trainer's current architecture settings."""
        args, kwargs = self.GAN_params

        # set some global variables before instantiating GAN

        global norm_class
        global Blur

        # NOTE(review): once antialias is False this rebinds the module-level
        # Blur to nn.Identity for the rest of the process, so a later Trainer
        # with antialias=True in the same session would also get nn.Identity.
        Blur = nn.Identity if not self.antialias else Blur

        # instantiate GAN

        self.GAN = LightweightGAN(
            optimizer=self.optimizer,
            lr = self.lr,
            latent_dim = self.latent_dim,
            attn_res_layers = self.attn_res_layers,
            freq_chan_attn = self.freq_chan_attn,
            image_size = self.image_size,
            ttur_mult = self.ttur_mult,
            fmap_max = self.fmap_max,
            disc_output_size = self.disc_output_size,
            rank = self.rank,
            *args,
            **kwargs
        )

    def write_config(self):
        """Persist the architecture settings returned by ``config()`` as JSON."""
        self.config_path.write_text(json.dumps(self.config()))

    def load_config(self):
        """Restore architecture settings from ``.config.json`` and rebuild the GAN.

        When no config file exists, the trainer's current settings are used.
        ``fmap_max`` falls back to the value this Trainer was constructed with
        when an older config file lacks it. Previously it fell back to a
        hard-coded 256 that need not match the checkpoint's layer widths.
        """
        config = self.config() if not self.config_path.exists() else json.loads(self.config_path.read_text())
        self.image_size = config['image_size']
        self.disc_output_size = config['disc_output_size']
        self.attn_res_layers = config.pop('attn_res_layers', [])
        self.freq_chan_attn = config.pop('freq_chan_attn', False)
        self.optimizer = config.pop('optimizer', 'adam')
        self.fmap_max = config.pop('fmap_max', self.fmap_max)
        # Free the old model before building the replacement.
        del self.GAN
        self.init_GAN()

    def config(self):
        """Architecture settings that must match a checkpoint for it to load."""
        return {
            'image_size': self.image_size,
            'disc_output_size': self.disc_output_size,
            'optimizer': self.optimizer,
            'attn_res_layers': self.attn_res_layers,
            'freq_chan_attn': self.freq_chan_attn,
            # fmap_max changes layer widths, so it must round-trip with the checkpoint.
            'fmap_max': self.fmap_max
        }

    def set_data_src(self, folder):
        """Build the image dataset and an endlessly cycling loader over it.

        ``num_workers`` from the constructor is now passed to the DataLoader; it
        used to be accepted and ignored. ``None`` keeps DataLoader's default of 0,
        which loads in the main process.

        If ``aug_prob`` was not given and the dataset has fewer than 100k images,
        ``aug_prob`` is set automatically, capped at 0.5.
        """
        self.dataset = ImageDataset(folder, self.image_size)
        dataloader = DataLoader(
            self.dataset,
            batch_size = self.batch_size,
            num_workers = default(self.num_workers, 0),
            shuffle = True,
            drop_last = True,
            pin_memory = True
        )
        self.loader = cycle(dataloader)

        # auto set augmentation prob for user if dataset is detected to be low
        num_samples = len(self.dataset)
        if not exists(self.aug_prob) and num_samples < 1e5:
            self.aug_prob = min(0.5, (1e5 - num_samples) * 3e-6)
            print(f'autosetting augmentation probability to {round(self.aug_prob * 100)}%')

    def train(self):
        """Run one training step and its periodic bookkeeping.

        The step updates the discriminator, then the generator. It then:
          * updates the EMA generator and periodically resets it;
          * checks for NaN losses;
          * periodically saves checkpoints and sample images.

        Raises:
            NanException: when a loss becomes NaN, after reloading the current
                checkpoint slot.
        """
        assert exists(self.loader), 'You must first initialize the data source with `.set_data_src(<folder of images>)`'
        device = torch.device(f'cuda:{self.rank}')

        if not exists(self.GAN):
            self.init_GAN()

        self.GAN.train()
        total_disc_loss = torch.zeros([], device=device)
        total_gen_loss = torch.zeros([], device=device)

        batch_size = self.batch_size

        latent_dim = self.GAN.latent_dim

        aug_prob   = default(self.aug_prob, 0)
        aug_types  = self.aug_types
        aug_kwargs = {'prob': aug_prob, 'types': aug_types}

        G = self.GAN.G
        D = self.GAN.D
        D_aug = self.GAN.D_aug

        # The gradient penalty is only computed every 4th step.
        apply_gradient_penalty = self.steps % 4 == 0

        # amp related contexts and functions

        amp_context = autocast if self.amp else null_context

        # discriminator loss fn

        if self.dual_contrast_loss:
            D_loss_fn = dual_contrastive_loss
        else:
            D_loss_fn = hinge_loss

        # train discriminator

        self.GAN.D_opt.zero_grad()

        latents = torch.randn(batch_size, latent_dim).cuda(self.rank)
        image_batch = next(self.loader).cuda(self.rank)
        # Real images need grad so the penalty can differentiate D's output w.r.t. them.
        image_batch.requires_grad_()

        with amp_context():
            with torch.no_grad():
                generated_images = G(latents)

            fake_output, fake_output_32x32, _ = D_aug(generated_images, detach = True, **aug_kwargs)

            real_output, real_output_32x32, real_aux_loss = D_aug(image_batch,  calc_aux_loss = True, **aug_kwargs)

            real_output_loss = real_output
            fake_output_loss = fake_output

            divergence = D_loss_fn(real_output_loss, fake_output_loss)
            divergence_32x32 = D_loss_fn(real_output_32x32, fake_output_32x32)
            disc_loss = divergence + divergence_32x32

            aux_loss = real_aux_loss
            disc_loss = disc_loss + aux_loss

        if apply_gradient_penalty:
            outputs = [real_output, real_output_32x32]
            outputs = list(map(self.D_scaler.scale, outputs)) if self.amp else outputs

            scaled_gradients = torch_grad(outputs=outputs, inputs=image_batch,
                                    grad_outputs=list(map(lambda t: torch.ones(t.size(), device = image_batch.device), outputs)),
                                    create_graph=True, retain_graph=True, only_inputs=True)[0]

            # Undo AMP loss scaling on the gradients before computing the penalty.
            # NOTE(review): safe_div is defined elsewhere. The inf check below
            # suggests it returns inf for a zero scale; confirm.
            inv_scale = safe_div(1., self.D_scaler.get_scale()) if self.amp else 1.

            if inv_scale != float('inf'):
                gradients = scaled_gradients * inv_scale

                with amp_context():
                    gradients = gradients.reshape(batch_size, -1)
                    gp =  self.gp_weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()

                    if not torch.isnan(gp):
                        disc_loss = disc_loss + gp
                        self.last_gp_loss = gp.clone().detach().item()

        with amp_context():
            disc_loss = disc_loss / self.gradient_accumulate_every

        disc_loss.register_hook(raise_if_nan)
        self.D_scaler.scale(disc_loss).backward()
        total_disc_loss += divergence

        self.last_recon_loss = aux_loss.item()
        self.d_loss = float(total_disc_loss.item() / self.gradient_accumulate_every)
        self.D_scaler.step(self.GAN.D_opt)
        self.D_scaler.update()

        # generator loss fn

        if self.dual_contrast_loss:
            G_loss_fn = dual_contrastive_loss
            G_requires_calc_real = True
        else:
            G_loss_fn = gen_hinge_loss
            G_requires_calc_real = False

        # train generator

        self.GAN.G_opt.zero_grad()

        latents = torch.randn(batch_size, latent_dim).cuda(self.rank)

        if G_requires_calc_real:
            image_batch = next(self.loader).cuda(self.rank)
            image_batch.requires_grad_()

        with amp_context():
            generated_images = G(latents)

            fake_output, fake_output_32x32, _ = D_aug(generated_images, **aug_kwargs)
            real_output, real_output_32x32, _ = D_aug(image_batch, **aug_kwargs) if G_requires_calc_real else (None, None, None)

            loss = G_loss_fn(fake_output, real_output)
            loss_32x32 = G_loss_fn(fake_output_32x32, real_output_32x32)

            gen_loss = loss + loss_32x32

            gen_loss = gen_loss / self.gradient_accumulate_every

        gen_loss.register_hook(raise_if_nan)
        self.G_scaler.scale(gen_loss).backward()
        total_gen_loss += loss

        self.g_loss = float(total_gen_loss.item() / self.gradient_accumulate_every)
        self.G_scaler.step(self.GAN.G_opt)
        self.G_scaler.update()

        # calculate moving averages

        if self.is_main and self.steps % 10 == 0 and self.steps > 20000:
            self.GAN.EMA()

        if self.is_main and self.steps <= 25000 and self.steps % 1000 == 2:
            self.GAN.reset_parameter_averaging()

        # save from NaN errors

        if any(torch.isnan(l) for l in (total_gen_loss, total_disc_loss)):
            print(f'NaN detected for generator or discriminator. Loading from checkpoint #{self.checkpoint_num}')
            self.load(self.checkpoint_num)
            raise NanException

        del total_disc_loss
        del total_gen_loss

        # periodically save results

        if self.is_main:
            if self.steps % self.save_every == 0:
                self.save(self.checkpoint_num)

            if self.steps % self.evaluate_every == 0 or (self.steps % 100 == 0 and self.steps < 20000):
                self.evaluate(floor(self.steps / self.evaluate_every), num_image_tiles = self.num_image_tiles)

        self.steps += 1

    @torch.no_grad()
    def evaluate(self, num = 0, num_image_tiles = 4):
        """Save three sample grids to ``results_dir/name`` for evaluation index ``num``.

        The grids are an interpolation grid from G, a regular grid from G and a
        regular grid from the EMA generator GE.
        """
        self.GAN.eval()

        ext = self.image_extension
        num_rows = num_image_tiles

        latent_dim = self.GAN.latent_dim

        # NOTE(review): det_randn is defined elsewhere. Presumably it returns
        # fixed latents so grids are comparable across evaluations; confirm.
        latents = det_randn((num_rows ** 2, latent_dim)).cuda(self.rank)
        interpolate_latents = interpolate_between(latents[:num_rows], latents[-num_rows:],
                                                  num_samples=num_rows,
                                                  dim=0).flatten(end_dim=1)

        generate_interpolations = self.generate_(self.GAN.G, interpolate_latents)
        torchvision.utils.save_image(generate_interpolations, str(self.results_dir / self.name / f'{str(num)}-interp.{ext}'), nrow=num_rows)
        # regular

        generated_images = self.generate_(self.GAN.G, latents)

        torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}.{ext}'), nrow=num_rows)

        # moving averages

        generated_images = self.generate_(self.GAN.GE, latents)
        torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}-ema.{ext}'), nrow=num_rows)

    @torch.no_grad()
    def generate(self, num=0, num_image_tiles=4, checkpoint=None, types=['default', 'ema']):
        """Save ``num_image_tiles`` single samples per requested generator type.

        ``types`` may contain ``'default'`` (G) and/or ``'ema'`` (GE). Images go to
        ``results_dir/<name>-generated-<checkpoint>``, and that absolute
        directory is returned.
        """
        self.GAN.eval()

        latent_dim = self.GAN.latent_dim
        dir_name = self.name + str('-generated-') + str(checkpoint)
        dir_full = Path().absolute() / self.results_dir / dir_name
        ext = self.image_extension

        dir_full.mkdir(parents=True, exist_ok=True)

        # regular
        if 'default' in types:
            for i in tqdm(range(num_image_tiles), desc='Saving generated default images'):
                latents = torch.randn((1, latent_dim)).cuda(self.rank)
                generated_image = self.generate_(self.GAN.G, latents)
                path = str(self.results_dir / dir_name / f'{str(num)}-{str(i)}.{ext}')
                torchvision.utils.save_image(generated_image[0], path, nrow=1)

        # moving averages
        if 'ema' in types:
            for i in tqdm(range(num_image_tiles), desc='Saving generated EMA images'):
                latents = torch.randn((1, latent_dim)).cuda(self.rank)
                generated_image = self.generate_(self.GAN.GE, latents)
                path = str(self.results_dir / dir_name / f'{str(num)}-{str(i)}-ema.{ext}')
                torchvision.utils.save_image(generated_image[0], path, nrow=1)

        return dir_full

    @torch.no_grad()
    def show_progress(self, num_images=4, types=['default', 'ema']):
        """Render the same latents through every saved checkpoint, one image row each.

        Output goes to ``results_dir/<name>-progress``. File names are the
        checkpoint numbers zero-padded to the width of the largest one, so they
        sort in training order.
        """
        checkpoints = self.get_checkpoints()
        assert exists(checkpoints), 'cannot find any checkpoints to create a training progress video for'

        dir_name = self.name + str('-progress')
        dir_full = Path().absolute() / self.results_dir / dir_name
        ext = self.image_extension
        latents = None

        # Width of the largest checkpoint number. The old count-based width
        # broke ordering when checkpoint numbers were sparse.
        zfill_length = len(str(max(checkpoints)))

        dir_full.mkdir(parents=True, exist_ok=True)

        for checkpoint in tqdm(checkpoints, desc='Generating progress images'):
            self.load(checkpoint, print_version=False)
            self.GAN.eval()

            # Sample once, at the first checkpoint, and reuse the latents for
            # every later checkpoint. The old code only sampled when checkpoint
            # 0 existed and otherwise passed None to the generator.
            if latents is None:
                latents = torch.randn((num_images, self.GAN.latent_dim)).cuda(self.rank)

            # regular
            if 'default' in types:
                generated_image = self.generate_(self.GAN.G, latents)
                path = str(self.results_dir / dir_name / f'{str(checkpoint).zfill(zfill_length)}.{ext}')
                torchvision.utils.save_image(generated_image, path, nrow=num_images)

            # moving averages
            if 'ema' in types:
                generated_image = self.generate_(self.GAN.GE, latents)
                path = str(self.results_dir / dir_name / f'{str(checkpoint).zfill(zfill_length)}-ema.{ext}')
                torchvision.utils.save_image(generated_image, path, nrow=num_images)

    @torch.no_grad()
    def generate_(self, G, style, num_image_tiles = 8):
        """Run ``G`` on ``style`` in ``batch_size`` chunks and clamp the output to [0, 1] in place.

        ``num_image_tiles`` is unused and kept for call compatibility.
        """
        generated_images = evaluate_in_chunks(self.batch_size, G, style)
        return generated_images.clamp_(0., 1.)

    @torch.no_grad()
    def generate_interpolation(self, num = 0, num_image_tiles = 8, num_steps = 100, save_frames = False):
        """Write a GIF of EMA-generator grids that slerp between two random latent sets.

        The GIF is saved to ``results_dir/name/<num>.gif``. With ``save_frames``,
        each frame is also written to ``results_dir/name/<num>/``.
        """
        self.GAN.eval()
        ext = self.image_extension
        num_rows = num_image_tiles

        latent_dim = self.GAN.latent_dim

        # latents and noise

        latents_low = torch.randn(num_rows ** 2, latent_dim).cuda(self.rank)
        latents_high = torch.randn(num_rows ** 2, latent_dim).cuda(self.rank)

        ratios = torch.linspace(0., 8., num_steps)

        frames = []
        for ratio in tqdm(ratios):
            interp_latents = slerp(ratio, latents_low, latents_high)
            generated_images = self.generate_(self.GAN.GE, interp_latents)
            images_grid = torchvision.utils.make_grid(generated_images, nrow = num_rows)
            pil_image = transforms.ToPILImage()(images_grid.cpu())

            frames.append(pil_image)

        frames[0].save(str(self.results_dir / self.name / f'{str(num)}.gif'), save_all=True, append_images=frames[1:], duration=80, loop=0, optimize=True)

        if save_frames:
            folder_path = (self.results_dir / self.name / f'{str(num)}')
            folder_path.mkdir(parents=True, exist_ok=True)
            for ind, frame in enumerate(frames):
                frame.save(str(folder_path / f'{str(ind)}.{ext}'))

    def print_log(self):
        """Print the latest losses (skipping ones not yet computed) and return them as pairs."""
        data = [
            ('G', self.g_loss),
            ('D', self.d_loss),
            ('GP', self.last_gp_loss),
            ('SS', self.last_recon_loss),
        ]

        data = [d for d in data if exists(d[1])]
        log = ' | '.join(map(lambda n: f'{n[0]}: {n[1]:.2f}', data))
        print(log)

        return data

    def model_name(self, num):
        """Path of checkpoint file ``num`` for this run."""
        return str(self.models_dir / self.name / f'model_{num}.pt')

    def init_folders(self):
        """Create this run's results and models directories if missing."""
        (self.results_dir / self.name).mkdir(parents=True, exist_ok=True)
        (self.models_dir / self.name).mkdir(parents=True, exist_ok=True)

    def clear(self):
        """Delete this run's results, checkpoints and config, then recreate empty folders."""
        rmtree(str(self.models_dir / self.name), True)
        rmtree(str(self.results_dir / self.name), True)
        # config_path is a file, which rmtree cannot remove. It normally lives in
        # models_dir/name and is already gone, so only unlink it if it survived.
        if self.config_path.exists():
            self.config_path.unlink()
        self.init_folders()

    def save(self, num):
        """Save GAN weights and both AMP scaler states as checkpoint ``num``, plus the config."""
        save_data = {
            'GAN': self.GAN.state_dict(),
            'G_scaler': self.G_scaler.state_dict(),
            'D_scaler': self.D_scaler.state_dict()
        }

        torch.save(save_data, self.model_name(num))
        self.write_config()

    def load(self, num=-1, print_version=True):
        """Rebuild the GAN from the saved config and load checkpoint ``num``.

        With ``num == -1`` the latest checkpoint is used. If none exists, the
        method returns after rebuilding a freshly initialised GAN.
        ``self.steps`` is set to ``num * save_every``.
        """
        self.load_config()

        name = num
        if num == -1:
            checkpoints = self.get_checkpoints()

            if not exists(checkpoints):
                return

            name = checkpoints[-1]
            print(f'continuing from previous epoch - {name}')

        self.steps = name * self.save_every

        load_data = torch.load(self.model_name(name))

        if print_version and 'version' in load_data and self.is_main:
            print(f"loading from version {load_data['version']}")

        try:
            self.GAN.load_state_dict(load_data['GAN'], strict = self.load_strict)
        except Exception:
            print('unable to load save model.')
            raise

        if 'G_scaler' in load_data:
            self.G_scaler.load_state_dict(load_data['G_scaler'])
        if 'D_scaler' in load_data:
            self.D_scaler.load_state_dict(load_data['D_scaler'])

    def get_checkpoints(self):
        """Sorted checkpoint numbers found in this run's models folder, or None if there are none."""
        file_paths = [p for p in Path(self.models_dir / self.name).glob('model_*.pt')]
        saved_nums = sorted(map(lambda x: int(x.stem.split('_')[1]), file_paths))

        if len(saved_nums) == 0:
            return None

        return saved_nums

In [3]:
def exists(val):
    """Return True for any value other than None (falsy values still count)."""
    if val is None:
        return False
    return True

def default(val, d):
    """Return ``val`` unless it is missing (per ``exists``), in which case return ``d``."""
    if exists(val):
        return val
    return d

def cast_list(el):
    """Return ``el`` unchanged if it is a list, otherwise a one-element list holding it."""
    if isinstance(el, list):
        return el
    return [el]

def timestamped_filename(prefix = 'generated-'):
    """Return ``prefix`` followed by the current local time as ``MM-DD-YYYY_HH-MM-SS``."""
    stamp = datetime.now().strftime("%m-%d-%Y_%H-%M-%S")
    return f'{prefix}{stamp}'

def set_seed(seed):
    """Seed torch, NumPy and Python's ``random``, and force deterministic cuDNN.

    cuDNN autotuning (``benchmark``) is switched off so kernel choice, and
    therefore numerics, are reproducible between runs.
    """
    torch.manual_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    np.random.seed(seed)
    r_seed(seed)

def run_training(rank, model_args, data, load_from, new, num_train_steps, name, seed):
    """Train a ``Trainer`` until ``num_train_steps`` and write a final checkpoint.

    Args:
        rank: CUDA device index. Rank 0 is the main process and prints logs.
        model_args: keyword arguments for ``Trainer``. The caller's dict is not
            modified; a copy with ``rank`` added is used and also passed as
            ``hparams``.
        data: folder of training images.
        load_from: checkpoint number to resume from (-1 means the latest).
            Ignored when ``new`` is True.
        new: when True, wipe this run's existing results and checkpoints first.
        num_train_steps: step count to train up to.
        name: run name, used here only for the progress-bar label.
        seed: not applied in this single-process path.
            NOTE(review): ``set_seed`` exists in this file but is not called
            here; calling it would also disable cuDNN benchmarking.

    Each step is retried up to 3 times on ``NanException``.
    """
    is_main = rank == 0

    # Copy rather than update() so the caller's dict is left untouched.
    model_args = {**model_args, 'rank': rank}

    model = Trainer(**model_args, hparams=model_args)

    if not new:
        model.load(load_from)
    else:
        model.clear()

    model.set_data_src(data)

    progress_bar = tqdm(initial = model.steps, total = num_train_steps, mininterval=10., desc=f'{name}<{data}>')
    try:
        while model.steps < num_train_steps:
            retry_call(model.train, tries=3, exceptions=NanException)
            progress_bar.n = model.steps
            progress_bar.refresh()
            if is_main and model.steps % 50 == 0:
                model.print_log()
    finally:
        # Close the bar even when training is interrupted (e.g. KeyboardInterrupt)
        # or NaNs persist, so tqdm does not keep a dangling instance around.
        progress_bar.close()

    model.save(model.checkpoint_num)


def train_from_folder(
    data = './images',
    results_dir = './results',
    models_dir = './models',
    name = 'default',
    new = False,
    load_from = -1,
    image_size = 128,
    optimizer = 'adam',
    fmap_max = 256,
    batch_size = 64,
    gradient_accumulate_every = 4,
    num_train_steps = 150000,
    learning_rate = 2e-4,
    save_every = 1000,
    evaluate_every = 1000,
    generate = False,
    generate_types = ['default', 'ema'],
    generate_interpolation = False,
    aug_prob=None,
    aug_types=['cutout', 'translation'],
    dataset_aug_prob=0.,
    attn_res_layers = [],
    freq_chan_attn = False,
    disc_output_size = 1,
    dual_contrast_loss = False,
    antialias = False,
    interpolation_num_steps = 100,
    save_frames = False,
    num_image_tiles = None,
    num_workers = None,
    seed = 42,
    amp = False,
    show_progress = False,
    load_strict = True
):
    """Top-level entry point: train on a folder of images, or sample from a run.

    Exactly one action runs, chosen in this order:
      1. ``generate``: save individual samples from the loaded checkpoint.
      2. ``generate_interpolation``: save a latent-interpolation GIF.
      3. ``show_progress``: render one image per saved checkpoint.
      4. Otherwise: train on ``data`` for ``num_train_steps`` steps.

    ``num_image_tiles`` defaults to 4 when ``image_size`` exceeds 512, else 8.
    """
    if num_image_tiles is None:
        num_image_tiles = 4 if image_size > 512 else 8

    model_args = {
        'name': name,
        'results_dir': results_dir,
        'models_dir': models_dir,
        'batch_size': batch_size,
        'gradient_accumulate_every': gradient_accumulate_every,
        'attn_res_layers': cast_list(attn_res_layers),
        'freq_chan_attn': freq_chan_attn,
        'disc_output_size': disc_output_size,
        'dual_contrast_loss': dual_contrast_loss,
        'antialias': antialias,
        'image_size': image_size,
        'num_image_tiles': num_image_tiles,
        'optimizer': optimizer,
        'num_workers': num_workers,
        'fmap_max': fmap_max,
        'lr': learning_rate,
        'save_every': save_every,
        'evaluate_every': evaluate_every,
        'aug_prob': aug_prob,
        'aug_types': cast_list(aug_types),
        'dataset_aug_prob': dataset_aug_prob,
        'amp': amp,
        'load_strict': load_strict,
    }

    # Both sampling modes load the requested checkpoint and pick a timestamped name.
    if generate or generate_interpolation:
        model = Trainer(**model_args)
        model.load(load_from)
        samples_name = timestamped_filename()
        if generate:
            dir_result = model.generate(samples_name, num_image_tiles, model.checkpoint_num, generate_types)
            print(f'sample images generated at {dir_result}')
        else:
            model.generate_interpolation(samples_name, num_image_tiles, num_steps = interpolation_num_steps, save_frames = save_frames)
            print(f'interpolation generated at {results_dir}/{name}/{samples_name}')
        return

    if show_progress:
        progress_model = Trainer(**model_args)
        progress_model.show_progress(num_images=num_image_tiles, types=generate_types)
        return

    run_training(0, model_args, data, load_from, new, num_train_steps, name, seed)

In [4]:
# Train with every option at its default:
#   - images are read from ./images;
#   - samples go to ./results/default and checkpoints to ./models/default;
#   - because new=False and load_from=-1, training resumes from the latest
#     checkpoint if one exists.
train_from_folder()

autosetting augmentation probability to 9%


default<./images>:   0%|          | 50/150000 [00:27<22:52:01,  1.82it/s]

G: 0.64 | D: 0.00 | GP: 1.86 | SS: 0.20


default<./images>:   0%|          | 100/150000 [00:52<21:56:23,  1.90it/s]

G: 0.50 | D: 0.08 | GP: 1.39 | SS: 0.06


default<./images>:   0%|          | 150/150000 [01:17<21:33:49,  1.93it/s]

G: 0.35 | D: 0.43 | GP: 0.60 | SS: 0.03


default<./images>:   0%|          | 200/150000 [01:42<21:24:41,  1.94it/s]

G: 0.25 | D: 0.55 | GP: 0.52 | SS: 0.02


default<./images>:   0%|          | 250/150000 [02:07<21:17:18,  1.95it/s]

G: 0.28 | D: 0.58 | GP: 0.79 | SS: 0.02


default<./images>:   0%|          | 300/150000 [02:33<21:14:43,  1.96it/s]

G: 0.30 | D: 0.44 | GP: 0.45 | SS: 0.02


default<./images>:   0%|          | 350/150000 [02:58<21:12:14,  1.96it/s]

G: 0.16 | D: 0.43 | GP: 0.30 | SS: 0.02


default<./images>:   0%|          | 400/150000 [03:23<21:09:43,  1.96it/s]

G: 0.14 | D: 0.50 | GP: 0.42 | SS: 0.01


default<./images>:   0%|          | 450/150000 [03:49<21:11:17,  1.96it/s]

G: 0.44 | D: 0.44 | GP: 0.21 | SS: 0.01


default<./images>:   0%|          | 500/150000 [04:14<21:10:36,  1.96it/s]

G: 0.10 | D: 0.51 | GP: 0.10 | SS: 0.01


default<./images>:   0%|          | 550/150000 [04:40<21:11:12,  1.96it/s]

G: 0.03 | D: 0.45 | GP: 0.32 | SS: 0.01


default<./images>:   0%|          | 600/150000 [05:05<21:09:35,  1.96it/s]

G: 0.30 | D: 0.62 | GP: 0.27 | SS: 0.01


default<./images>:   0%|          | 650/150000 [05:31<21:09:57,  1.96it/s]

G: 0.06 | D: 0.60 | GP: 0.25 | SS: 0.01


default<./images>:   0%|          | 700/150000 [05:57<21:09:03,  1.96it/s]

G: 0.23 | D: 0.29 | GP: 0.34 | SS: 0.01


default<./images>:   0%|          | 750/150000 [06:22<21:09:30,  1.96it/s]

G: -0.41 | D: 0.45 | GP: 0.50 | SS: 0.01


default<./images>:   1%|          | 800/150000 [06:47<21:06:30,  1.96it/s]

G: 0.20 | D: 0.39 | GP: 0.30 | SS: 0.01


default<./images>:   1%|          | 850/150000 [07:12<21:03:57,  1.97it/s]

G: 0.06 | D: 0.66 | GP: 0.26 | SS: 0.01


default<./images>:   1%|          | 900/150000 [07:36<21:01:46,  1.97it/s]

G: 0.17 | D: 0.43 | GP: 0.37 | SS: 0.01


default<./images>:   1%|          | 950/150000 [08:02<21:02:14,  1.97it/s]

G: -0.03 | D: 0.55 | GP: 0.10 | SS: 0.01


default<./images>:   1%|          | 1000/150000 [08:27<21:01:20,  1.97it/s]

G: 0.10 | D: 0.51 | GP: 0.34 | SS: 0.01


default<./images>:   1%|          | 1050/150000 [08:53<21:01:36,  1.97it/s]

G: 0.16 | D: 0.45 | GP: 0.38 | SS: 0.01


default<./images>:   1%|          | 1100/150000 [09:18<21:00:27,  1.97it/s]

G: -0.04 | D: 0.54 | GP: 0.25 | SS: 0.01


default<./images>:   1%|          | 1150/150000 [09:42<20:56:02,  1.98it/s]

G: 0.12 | D: 0.51 | GP: 0.34 | SS: 0.01


default<./images>:   1%|          | 1200/150000 [10:05<20:51:33,  1.98it/s]

G: 0.13 | D: 0.42 | GP: 0.40 | SS: 0.01


default<./images>:   1%|          | 1250/150000 [10:29<20:48:06,  1.99it/s]

G: 0.14 | D: 0.49 | GP: 0.20 | SS: 0.01


default<./images>:   1%|          | 1300/150000 [10:53<20:44:57,  1.99it/s]

G: -0.04 | D: 0.52 | GP: 0.71 | SS: 0.01


default<./images>:   1%|          | 1350/150000 [11:17<20:43:05,  1.99it/s]

G: -0.04 | D: 0.54 | GP: 0.73 | SS: 0.01


default<./images>:   1%|          | 1400/150000 [11:40<20:39:18,  2.00it/s]

G: 0.16 | D: 0.57 | GP: 0.50 | SS: 0.01


default<./images>:   1%|          | 1450/150000 [12:04<20:36:30,  2.00it/s]

G: 0.11 | D: 0.49 | GP: 0.58 | SS: 0.01


default<./images>:   1%|          | 1500/150000 [12:27<20:33:19,  2.01it/s]

G: 0.00 | D: 0.57 | GP: 0.83 | SS: 0.01


default<./images>:   1%|          | 1550/150000 [12:51<20:30:43,  2.01it/s]

G: 0.12 | D: 0.53 | GP: 0.90 | SS: 0.01


default<./images>:   1%|          | 1600/150000 [13:14<20:28:04,  2.01it/s]

G: 0.06 | D: 0.48 | GP: 0.70 | SS: 0.01


default<./images>:   1%|          | 1650/150000 [13:38<20:26:03,  2.02it/s]

G: 0.14 | D: 0.60 | GP: 0.37 | SS: 0.01


default<./images>:   1%|          | 1700/150000 [14:01<20:23:12,  2.02it/s]

G: 0.29 | D: 0.54 | GP: 0.38 | SS: 0.01


default<./images>:   1%|          | 1750/150000 [14:25<20:21:19,  2.02it/s]

G: 0.02 | D: 0.52 | GP: 0.64 | SS: 0.01


default<./images>:   1%|          | 1800/150000 [14:48<20:19:22,  2.03it/s]

G: -0.09 | D: 0.39 | GP: 0.70 | SS: 0.01


default<./images>:   1%|          | 1850/150000 [15:12<20:17:46,  2.03it/s]

G: 0.18 | D: 0.56 | GP: 0.94 | SS: 0.01


default<./images>:   1%|▏         | 1900/150000 [15:35<20:15:32,  2.03it/s]

G: 0.05 | D: 0.50 | GP: 0.33 | SS: 0.01


default<./images>:   1%|▏         | 1950/150000 [15:59<20:13:49,  2.03it/s]

G: 0.07 | D: 0.51 | GP: 0.65 | SS: 0.01


default<./images>:   1%|▏         | 2000/150000 [16:22<20:11:47,  2.04it/s]

G: 0.15 | D: 0.46 | GP: 0.51 | SS: 0.01


default<./images>:   1%|▏         | 2050/150000 [16:46<20:10:22,  2.04it/s]

G: -0.04 | D: 0.54 | GP: 0.27 | SS: 0.01


default<./images>:   1%|▏         | 2100/150000 [17:09<20:08:15,  2.04it/s]

G: 0.12 | D: 0.50 | GP: 0.57 | SS: 0.01


default<./images>:   1%|▏         | 2150/150000 [17:33<20:07:18,  2.04it/s]

G: -0.08 | D: 0.62 | GP: 0.55 | SS: 0.01


default<./images>:   1%|▏         | 2200/150000 [17:57<20:05:59,  2.04it/s]

G: 0.16 | D: 0.52 | GP: 0.31 | SS: 0.01


default<./images>:   2%|▏         | 2250/150000 [18:21<20:05:13,  2.04it/s]

G: 0.18 | D: 0.49 | GP: 0.34 | SS: 0.01


default<./images>:   2%|▏         | 2300/150000 [18:44<20:04:00,  2.04it/s]

G: 0.09 | D: 0.57 | GP: 0.25 | SS: 0.01


default<./images>:   2%|▏         | 2350/150000 [19:09<20:03:13,  2.05it/s]

G: 0.19 | D: 0.53 | GP: 0.34 | SS: 0.01


default<./images>:   2%|▏         | 2400/150000 [19:32<20:02:04,  2.05it/s]

G: 0.17 | D: 0.52 | GP: 0.30 | SS: 0.01


default<./images>:   2%|▏         | 2450/150000 [19:57<20:01:44,  2.05it/s]

G: 0.20 | D: 0.50 | GP: 0.46 | SS: 0.01


default<./images>:   2%|▏         | 2500/150000 [20:21<20:01:02,  2.05it/s]

G: -0.05 | D: 0.48 | GP: 0.29 | SS: 0.01


default<./images>:   2%|▏         | 2550/150000 [20:45<20:00:47,  2.05it/s]

G: 0.30 | D: 0.53 | GP: 0.18 | SS: 0.01


default<./images>:   2%|▏         | 2600/150000 [21:10<20:00:01,  2.05it/s]

G: 0.09 | D: 0.46 | GP: 0.46 | SS: 0.01


default<./images>:   2%|▏         | 2650/150000 [21:34<19:59:42,  2.05it/s]

G: 0.11 | D: 0.46 | GP: 1.46 | SS: 0.01


default<./images>:   2%|▏         | 2700/150000 [21:58<19:59:01,  2.05it/s]

G: 0.14 | D: 0.53 | GP: 0.17 | SS: 0.01


default<./images>:   2%|▏         | 2750/150000 [22:23<19:58:48,  2.05it/s]

G: 0.11 | D: 0.36 | GP: 0.17 | SS: 0.01


default<./images>:   2%|▏         | 2800/150000 [22:47<19:58:04,  2.05it/s]

G: 0.11 | D: 0.52 | GP: 0.16 | SS: 0.01


default<./images>:   2%|▏         | 2850/150000 [23:11<19:57:46,  2.05it/s]

G: 0.04 | D: 0.47 | GP: 0.17 | SS: 0.01


default<./images>:   2%|▏         | 2900/150000 [23:36<19:57:12,  2.05it/s]

G: 0.04 | D: 0.49 | GP: 0.12 | SS: 0.01


default<./images>:   2%|▏         | 2950/150000 [24:00<19:56:50,  2.05it/s]

G: 0.20 | D: 0.52 | GP: 0.28 | SS: 0.01


default<./images>:   2%|▏         | 3000/150000 [24:24<19:56:13,  2.05it/s]

G: 0.05 | D: 0.42 | GP: 0.27 | SS: 0.01


default<./images>:   2%|▏         | 3050/150000 [24:49<19:55:49,  2.05it/s]

G: 0.03 | D: 0.48 | GP: 0.13 | SS: 0.01


default<./images>:   2%|▏         | 3100/150000 [25:13<19:55:22,  2.05it/s]

G: 0.14 | D: 0.44 | GP: 0.53 | SS: 0.01


default<./images>:   2%|▏         | 3111/150000 [25:19<19:55:22,  2.05it/s]

KeyboardInterrupt: 