In [1]:
!pip install ipywidgets

[0m

In [2]:
!pip install torch torchvision numpy scipy matplotlib tqdm

[0m

In [3]:
import torch
from torch import nn, optim
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from math import log2
import numpy as np
import os
from tqdm import tqdm 
import matplotlib.pyplot as plt  
from scipy import linalg

In [4]:
DATASET = "STYLEGAN/train"

In [5]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {DEVICE}")

Using device: cuda


In [6]:
START_TRAIN_IMG_SIZE = 4
LR = 1e-3
BATCH_SIZES = [256,256,128,64,32,16]
CHANNELS_IMG = 3
Z_DIm = 512
W_DIM = 512
IN_CHANNELS = 512
LAMBDA_GP = 10
PROGRESSIVE_EPOCHS = [40] * len(BATCH_SIZES)

In [7]:
def get_loader(image_size):
    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(
                [0.5 for _ in range(CHANNELS_IMG)],
                [0.5 for _ in range(CHANNELS_IMG)],
            ),
        ]
    )
    batch_size = BATCH_SIZES[int(log2(image_size / 4))]
    dataset = datasets.ImageFolder(root=DATASET, transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,  # Adjust as per your system
        pin_memory=True,
    )
    return loader, dataset

In [8]:
# Use this function to visualize some training images
def check_loader():
    loader, _ = get_loader(128)
    cloth, _ = next(iter(loader))
    _, ax = plt.subplots(3, 3, figsize=(8, 8))
    plt.suptitle('Some real samples')
    ind = 0
    for k in range(3):
        for kk in range(3):
            ax[k][kk].imshow((cloth[ind].permute(1, 2, 0) * 0.5 + 0.5).cpu().numpy())
            ind += 1
    plt.show()

In [9]:
factors = [1,1,1,1/2,1/4,1/8,1/16,1/32]

In [10]:
class WSLinear(nn.Module):
    def __init__(
        self, in_features, out_features
    ):
        super(WSLinear,self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.scale  = (2/in_features) ** 0.5
        self.bias   = self.linear.bias
        self.linear.bias = None

        nn.init.normal_(self.linear.weight)
        nn.init.zeros_(self.bias)

    def forward(self,x):
        return self.linear(x * self.scale) + self.bias

In [11]:
class PixelNorm(nn.Module):
    def __init__(self):
        super(PixelNorm, self).__init__()
        self.epsilon = 1e-8
    def forward(self,x ):
        return x / torch.sqrt(torch.mean(x**2, dim=1, keepdim=True)+  self.epsilon)

In [12]:
class MappingNetwork(nn.Module):
    def __init__(self, z_dim, w_dim):
        super().__init__()
        self.mapping = nn.Sequential(
            PixelNorm(),
            WSLinear(z_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
        )
    
    def forward(self,x):
        return self.mapping(x)

In [13]:
class injectNoise(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1,channels,1,1))

    def forward(self, x):
        noise = torch.randn((x.shape[0], 1, x.shape[2], x.shape[3]), device = x.device)
        return x + self.weight * noise

In [14]:
class AdaIN(nn.Module):
    def __init__(self, channels, w_dim):
        super().__init__()
        self.instance_norm = nn.InstanceNorm2d(channels)
        self.style_scale   = WSLinear(w_dim, channels)
        self.style_bias    = WSLinear(w_dim, channels)

    def forward(self,x,w):
        x = self.instance_norm(x)
        style_scale = self.style_scale(w).unsqueeze(2).unsqueeze(3)
        style_bias  = self.style_bias(w).unsqueeze(2).unsqueeze(3)
        return style_scale * x + style_bias

In [15]:
class GenBlock(nn.Module):
    def __init__(self, in_channel, out_channel, w_dim):
        super(GenBlock, self).__init__()
        self.conv1 = WSConv2d(in_channel, out_channel)
        self.conv2 = WSConv2d(out_channel, out_channel)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)
        self.inject_noise1 = injectNoise(out_channel)
        self.inject_noise2 = injectNoise(out_channel)
        self.adain1 = AdaIN(out_channel, w_dim)
        self.adain2 = AdaIN(out_channel, w_dim)
    def forward(self, x,w):
        x = self.adain1(self.leaky(self.inject_noise1(self.conv1(x))), w)
        x = self.adain2(self.leaky(self.inject_noise2(self.conv2(x))), w)
        return x

In [16]:
class Generator(nn.Module):
    def __init__(self, z_dim, w_dim, in_channels, img_channels=3):
        super().__init__()
        self.starting_cte = nn.Parameter(torch.ones(1, in_channels, 4,4))
        self.map = MappingNetwork(z_dim, w_dim)
        self.initial_adain1 = AdaIN(in_channels, w_dim)
        self.initial_adain2 = AdaIN(in_channels, w_dim)
        self.initial_noise1 = injectNoise(in_channels)
        self.initial_noise2 = injectNoise(in_channels)
        self.initial_conv   = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.leaky          = nn.LeakyReLU(0.2, inplace=True)

        self.initial_rgb    = WSConv2d(
            in_channels, img_channels, kernel_size = 1, stride=1, padding=0
        )
        self.prog_blocks, self.rgb_layers = (
            nn.ModuleList([]),
            nn.ModuleList([self.initial_rgb])
        )

        for i in range(len(factors)-1):
            conv_in_c  = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i+1])
            self.prog_blocks.append(GenBlock(conv_in_c, conv_out_c, w_dim))
            self.rgb_layers.append(WSConv2d(conv_out_c, img_channels, kernel_size = 1, stride=1, padding=0))
        
    def fade_in(self, alpha, upscaled, generated):
        return torch.tanh(alpha * generated + (1-alpha ) * upscaled)

    def forward(self, noise, alpha, steps):
        w = self.map(noise)
        x = self.initial_adain1(self.initial_noise1(self.starting_cte),w)
        x = self.initial_conv(x)
        out = self.initial_adain2(self.leaky(self.initial_noise2(x)), w)

        if steps == 0:
            return self.initial_rgb(x)
        
        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode = 'bilinear', align_corners=False)
            out      = self.prog_blocks[step](upscaled,w)

        final_upscaled = self.rgb_layers[steps-1](upscaled)
        final_out      = self.rgb_layers[steps](out)

        return self.fade_in(alpha, final_upscaled, final_out)

In [17]:
class WSConv2d(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1
    ):
        super(WSConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.scale = (2 / (in_channels * (kernel_size ** 2))) ** 0.5
        self.bias = self.conv.bias
        self.conv.bias = None

        # initialize conv layer
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)

In [18]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.leaky(self.conv1(x))
        x = self.leaky(self.conv2(x))
        return x

In [19]:
class Discriminator(nn.Module):
    def __init__(self, in_channels, img_channels=3):
        super(Discriminator, self).__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        # here we work back ways from factors because the discriminator
        # should be mirrored from the generator. So the first prog_block and
        # rgb layer we append will work for input size 1024x1024, then 512->256-> etc
        for i in range(len(factors) - 1, 0, -1):
            conv_in = int(in_channels * factors[i])
            conv_out = int(in_channels * factors[i - 1])
            self.prog_blocks.append(ConvBlock(conv_in, conv_out))
            self.rgb_layers.append(
                WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0)
            )

        # perhaps confusing name "initial_rgb" this is just the RGB layer for 4x4 input size
        # did this to "mirror" the generator initial_rgb
        self.initial_rgb = WSConv2d(
            img_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(
            kernel_size=2, stride=2
        )  # down sampling using avg pool

        # this is the block for 4x4 input size
        self.final_block = nn.Sequential(
            # +1 to in_channels because we concatenate from MiniBatch std
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d(
                in_channels, 1, kernel_size=1, padding=0, stride=1
            ),  # we use this instead of linear layer
        )

    def fade_in(self, alpha, downscaled, out):
        """Used to fade in downscaled using avg pooling and output from CNN"""
        # alpha should be scalar within [0, 1], and upscale.shape == generated.shape
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        batch_statistics = (
            torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        )
        # we take the std for each example (across all channels, and pixels) then we repeat it
        # for a single channel and concatenate it with the image. In this way the discriminator
        # will get information about the variation in the batch/image
        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, alpha, steps):
        # where we should start in the list of prog_blocks, maybe a bit confusing but
        # the last is for the 4x4. So example let's say steps=1, then we should start
        # at the second to last because input_size will be 8x8. If steps==0 we just
        # use the final block
        cur_step = len(self.prog_blocks) - steps

        # convert from rgb as initial step, this will depend on
        # the image size (each will have it's on rgb layer)
        out = self.leaky(self.rgb_layers[cur_step](x))

        if steps == 0:  # i.e, image is 4x4
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)

        # because prog_blocks might change the channels, for down scale we use rgb_layer
        # from previous/smaller size which in our case correlates to +1 in the indexing
        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))

        # the fade_in is done first between the downscaled and the input
        # this is opposite from the generator
        out = self.fade_in(alpha, downscaled, out)

        for step in range(cur_step + 1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)

In [20]:
def generate_examples(gen, steps, n=100):

    gen.eval()
    alpha = 1.0
    for i in range(n):
        with torch.no_grad():
            noise = torch.randn(1, Z_DIm).to(DEVICE)
            img = gen(noise, alpha, steps)
            if not os.path.exists(f'saved_examples/step{steps}'):
                os.makedirs(f'saved_examples/step{steps}')
            save_image(img*0.5+0.5, f"saved_examples/step{steps}/img_{i}.png")
    gen.train()

In [21]:
def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * beta + fake.detach() * (1 - beta)
    interpolated_images.requires_grad_(True)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images, alpha, train_step)
 
    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty

In [22]:
def train_fn(
    critic,
    gen,
    loader,
    dataset,
    step,
    alpha,
    opt_critic,
    opt_gen
):
    loop = tqdm(loader, leave=True, mininterval=1.0)  

    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(DEVICE)
        cur_batch_size = real.shape[0]
        noise = torch.randn(cur_batch_size, Z_DIm).to(DEVICE)
        fake = gen(noise, alpha, step)
        critic_real = critic(real, alpha, step)
        critic_fake = critic(fake.detach(), alpha, step)
        gp = gradient_penalty(critic, real, fake, alpha, step, DEVICE)
        loss_critic = (
            -(torch.mean(critic_real) - torch.mean(critic_fake))
            + LAMBDA_GP * gp
            + (0.001) * torch.mean(critic_real ** 2)
        )

        critic.zero_grad()
        loss_critic.backward()
        opt_critic.step()

        gen_fake = critic(fake, alpha, step)
        loss_gen = -torch.mean(gen_fake)

        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        alpha += cur_batch_size / (
            PROGRESSIVE_EPOCHS[step] * len(loader.dataset)
        )
        alpha = min(alpha, 1)

        # Update postfix every 10 iterations
        if batch_idx % 10 == 0:
            loop.set_postfix(
                gp=gp.item(),
                loss_critic=loss_critic.item()
            )
    return alpha

In [23]:
# Define FID Calculation
def calculate_fid(gen, loader, device, steps, num_samples=5000):
    gen.eval()
    real_images = []
    fake_images = []

    with torch.no_grad():
        # Collect real images
        for real, _ in tqdm(loader, desc='Collecting Real Images', mininterval=1.0):
            real_images.append(real.to(device))
            if len(real_images) * real.size(0) >= num_samples:
                break
        real_images = torch.cat(real_images, dim=0)[:num_samples]

        # Collect fake images
        num_fake_batches = num_samples // BATCH_SIZES[steps] + 1
        for _ in tqdm(range(num_fake_batches), desc='Generating Fake Images', mininterval=1.0):
            noise = torch.randn(BATCH_SIZES[steps], Z_DIm).to(device)
            fake = gen(noise, alpha=1.0, steps=steps).detach()
            fake_images.append(fake)
            if len(fake_images) * fake.size(0) >= num_samples:
                break
        fake_images = torch.cat(fake_images, dim=0)[:num_samples]

    # Flatten images
    real = real_images.view(real_images.size(0), -1).cpu().numpy()
    fake = fake_images.view(fake_images.size(0), -1).cpu().numpy()

    # Compute statistics
    mu1 = np.mean(real, axis=0)
    sigma1 = np.cov(real, rowvar=False)
    mu2 = np.mean(fake, axis=0)
    sigma2 = np.cov(fake, rowvar=False)

    # Compute FID
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    fid = diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean)
    print(f'FID: {fid}')
    gen.train()
    return fid

In [24]:
# Define PPL Calculation with Corrections
def calculate_ppl(gen, device, steps, num_samples=2000, epsilon=1e-4):
    gen.eval()
    ppl = 0.0

    # Disable gradient computation for generator parameters
    for param in gen.parameters():
        param.requires_grad = False

    for _ in tqdm(range(num_samples), desc='Calculating PPL', mininterval=1.0):
        z = torch.randn(1, Z_DIm).to(device)
        z_perturbed = z + epsilon * torch.randn_like(z).to(device)
        z.requires_grad = True

        # Generate image and compute loss
        img = gen(z, alpha=1.0, steps=steps)
        loss = img.mean()
        loss.backward()

        grad = z.grad
        ppl += grad.norm().item()

    ppl /= num_samples
    print(f'PPL: {ppl}')

    # Re-enable gradient computation for generator parameters
    for param in gen.parameters():
        param.requires_grad = True

    gen.train()
    return ppl

In [25]:
# After Training: Compute FID and PPL
def evaluate_metrics(gen, critic, step):
    # Define the image size based on the current step
    image_size = 4 * 2**step
    loader, dataset = get_loader(image_size)
    
    print("Calculating FID...")
    fid = calculate_fid(gen, loader, DEVICE, step, num_samples=5000)
    
    print("Calculating PPL...")
    ppl = calculate_ppl(gen, DEVICE, step, num_samples=2000)
    
    return fid, ppl

In [None]:
# Training Loop with Metrics Evaluation
gen = Generator(
    Z_DIm, W_DIM, IN_CHANNELS, CHANNELS_IMG
).to(DEVICE)
critic = Discriminator(IN_CHANNELS, CHANNELS_IMG).to(DEVICE)
opt_gen = optim.Adam(
    [
        {'params': [param for name, param in gen.named_parameters() if 'map' not in name]},
        {'params': gen.map.parameters(), 'lr': 1e-5}
    ],
    lr=LR,
    betas=(0.0, 0.99)
)
opt_critic = optim.Adam(
    critic.parameters(), lr=LR, betas=(0.0, 0.99)
)

gen.train()
critic.train()
step = int(log2(START_TRAIN_IMG_SIZE / 4))
for idx, num_epochs in enumerate(PROGRESSIVE_EPOCHS[step:]):
    alpha = 0  # Start from 0 for fade-in
    current_step = step + idx
    loader, dataset = get_loader(4 * 2 ** current_step)
    print(f'Current image size: {4 * 2 ** current_step}')
    
    for epoch in range(num_epochs):
        print(f'Epoch [{epoch + 1}/{num_epochs}]')
        alpha = train_fn(
            critic, gen, loader, dataset, current_step, alpha, opt_critic, opt_gen
        )
    
    generate_examples(gen, current_step)
    
    # After completing training for the current step, evaluate metrics
    fid, ppl = evaluate_metrics(gen, critic, current_step)
    print(f'After Step {current_step}: FID = {fid}, PPL = {ppl}')
    
    # Save model checkpoints
    torch.save(gen.state_dict(), f'generator_step_{current_step}.pth')
    torch.save(critic.state_dict(), f'critic_step_{current_step}.pth')

# Optionally, save the final model
torch.save(gen.state_dict(), 'generator_final.pth')
torch.save(critic.state_dict(), 'critic_final.pth')

Current image size: 4
Epoch [1/40]


100%|██████████| 40/40 [00:10<00:00,  3.81it/s, gp=3.32, loss_critic=-191]  


Epoch [2/40]


100%|██████████| 40/40 [00:10<00:00,  3.92it/s, gp=3.85, loss_critic=-24.8]


Epoch [3/40]


100%|██████████| 40/40 [00:10<00:00,  3.88it/s, gp=1.68, loss_critic=-12.4]


Epoch [4/40]


100%|██████████| 40/40 [00:11<00:00,  3.63it/s, gp=0.99, loss_critic=-18.8]


Epoch [5/40]


100%|██████████| 40/40 [00:10<00:00,  3.89it/s, gp=0.784, loss_critic=-10.2]


Epoch [6/40]


100%|██████████| 40/40 [00:10<00:00,  3.89it/s, gp=0.738, loss_critic=-14]  


Epoch [7/40]


100%|██████████| 40/40 [00:10<00:00,  3.85it/s, gp=0.671, loss_critic=-15.7]


Epoch [8/40]


100%|██████████| 40/40 [00:10<00:00,  3.93it/s, gp=0.684, loss_critic=-17.4]


Epoch [9/40]


100%|██████████| 40/40 [00:10<00:00,  3.92it/s, gp=0.651, loss_critic=-17.8]


Epoch [10/40]


100%|██████████| 40/40 [00:10<00:00,  3.86it/s, gp=0.495, loss_critic=-14.7]


Epoch [11/40]


100%|██████████| 40/40 [00:10<00:00,  3.94it/s, gp=0.532, loss_critic=-16.4]


Epoch [12/40]


100%|██████████| 40/40 [00:10<00:00,  3.94it/s, gp=0.468, loss_critic=-15.3]


Epoch [13/40]


100%|██████████| 40/40 [00:10<00:00,  3.66it/s, gp=0.493, loss_critic=-18.4]


Epoch [14/40]


100%|██████████| 40/40 [00:10<00:00,  3.77it/s, gp=0.461, loss_critic=-14]  


Epoch [15/40]


100%|██████████| 40/40 [00:10<00:00,  3.88it/s, gp=0.48, loss_critic=-15.1] 


Epoch [16/40]


100%|██████████| 40/40 [00:10<00:00,  3.71it/s, gp=0.364, loss_critic=-15.6]


Epoch [17/40]


100%|██████████| 40/40 [00:10<00:00,  3.82it/s, gp=0.362, loss_critic=-13.5]


Epoch [18/40]


100%|██████████| 40/40 [00:10<00:00,  3.83it/s, gp=0.282, loss_critic=-11.8]


Epoch [19/40]


100%|██████████| 40/40 [00:10<00:00,  3.74it/s, gp=0.293, loss_critic=-12.3]


Epoch [20/40]


100%|██████████| 40/40 [00:10<00:00,  3.83it/s, gp=0.298, loss_critic=-13]  


Epoch [21/40]


100%|██████████| 40/40 [00:10<00:00,  3.85it/s, gp=0.243, loss_critic=-11.9]


Epoch [22/40]


100%|██████████| 40/40 [00:10<00:00,  3.93it/s, gp=0.319, loss_critic=-13.1]


Epoch [23/40]


100%|██████████| 40/40 [00:11<00:00,  3.64it/s, gp=0.285, loss_critic=-13.5]


Epoch [24/40]


100%|██████████| 40/40 [00:10<00:00,  3.91it/s, gp=0.253, loss_critic=-11.4]


Epoch [25/40]


100%|██████████| 40/40 [00:10<00:00,  3.93it/s, gp=0.256, loss_critic=-10.7]


Epoch [26/40]


100%|██████████| 40/40 [00:10<00:00,  3.93it/s, gp=0.219, loss_critic=-9.14]


Epoch [27/40]


100%|██████████| 40/40 [00:10<00:00,  3.91it/s, gp=0.117, loss_critic=-6.18]


Epoch [28/40]


100%|██████████| 40/40 [00:09<00:00,  4.01it/s, gp=0.115, loss_critic=-7.02]


Epoch [29/40]


100%|██████████| 40/40 [00:10<00:00,  3.84it/s, gp=0.12, loss_critic=-7.82] 


Epoch [30/40]


100%|██████████| 40/40 [00:10<00:00,  3.85it/s, gp=0.121, loss_critic=-6.59]


Epoch [31/40]


100%|██████████| 40/40 [00:10<00:00,  3.85it/s, gp=0.133, loss_critic=-7.66]


Epoch [32/40]


100%|██████████| 40/40 [00:10<00:00,  3.96it/s, gp=0.137, loss_critic=-9.8]


Epoch [33/40]


100%|██████████| 40/40 [00:10<00:00,  3.95it/s, gp=0.117, loss_critic=-6.89]


Epoch [34/40]


100%|██████████| 40/40 [00:09<00:00,  4.01it/s, gp=0.151, loss_critic=-7.39]


Epoch [35/40]


100%|██████████| 40/40 [00:10<00:00,  3.97it/s, gp=0.148, loss_critic=-7.11]


Epoch [36/40]


100%|██████████| 40/40 [00:10<00:00,  3.96it/s, gp=0.11, loss_critic=-7.24] 


Epoch [37/40]


100%|██████████| 40/40 [00:10<00:00,  3.89it/s, gp=0.0861, loss_critic=-5.85]


Epoch [38/40]


100%|██████████| 40/40 [00:10<00:00,  3.90it/s, gp=0.101, loss_critic=-5.86]


Epoch [39/40]


100%|██████████| 40/40 [00:10<00:00,  3.86it/s, gp=0.0796, loss_critic=-6.25]


Epoch [40/40]


100%|██████████| 40/40 [00:10<00:00,  3.76it/s, gp=0.111, loss_critic=-5.76]


Calculating FID...


Collecting Real Images:  48%|████▊     | 19/40 [00:06<00:06,  3.01it/s]
Generating Fake Images:  95%|█████████▌| 19/20 [00:00<00:00, 72.70it/s]


FID: 1124.6609230304512
Calculating PPL...


Calculating PPL: 100%|██████████| 2000/2000 [00:05<00:00, 374.33it/s]


PPL: 9.806638026952744
After Step 0: FID = 1124.6609230304512, PPL = 9.806638026952744
Current image size: 8
Epoch [1/40]


100%|██████████| 40/40 [00:10<00:00,  3.71it/s, gp=0.103, loss_critic=-6.74] 


Epoch [2/40]


100%|██████████| 40/40 [00:10<00:00,  3.77it/s, gp=0.0366, loss_critic=-4.04]


Epoch [3/40]


100%|██████████| 40/40 [00:10<00:00,  3.84it/s, gp=0.0288, loss_critic=-3.19]


Epoch [4/40]


100%|██████████| 40/40 [00:10<00:00,  3.70it/s, gp=0.0452, loss_critic=-2.71]


Epoch [5/40]


100%|██████████| 40/40 [00:10<00:00,  3.81it/s, gp=0.02, loss_critic=-1.94]  


Epoch [6/40]


100%|██████████| 40/40 [00:10<00:00,  3.69it/s, gp=0.0392, loss_critic=-2.4] 


Epoch [7/40]


100%|██████████| 40/40 [00:10<00:00,  3.73it/s, gp=0.0327, loss_critic=-2.62]


Epoch [8/40]


100%|██████████| 40/40 [00:10<00:00,  3.74it/s, gp=0.175, loss_critic=-2.55] 


Epoch [9/40]


100%|██████████| 40/40 [00:10<00:00,  3.83it/s, gp=0.0256, loss_critic=-1.82]


Epoch [10/40]


100%|██████████| 40/40 [00:10<00:00,  3.77it/s, gp=0.0372, loss_critic=-1.64]


Epoch [11/40]


100%|██████████| 40/40 [00:10<00:00,  3.82it/s, gp=0.131, loss_critic=-.72]  


Epoch [12/40]


100%|██████████| 40/40 [00:10<00:00,  3.73it/s, gp=0.0415, loss_critic=-1.89]


Epoch [13/40]


100%|██████████| 40/40 [00:10<00:00,  3.77it/s, gp=0.0272, loss_critic=-1.3] 


Epoch [14/40]


100%|██████████| 40/40 [00:10<00:00,  3.82it/s, gp=0.0353, loss_critic=-2.29]


Epoch [15/40]


100%|██████████| 40/40 [00:10<00:00,  3.73it/s, gp=0.0365, loss_critic=-2.21]


Epoch [16/40]


100%|██████████| 40/40 [00:10<00:00,  3.76it/s, gp=0.0523, loss_critic=-3.5] 


Epoch [17/40]


100%|██████████| 40/40 [00:10<00:00,  3.84it/s, gp=0.0311, loss_critic=-1.81]


Epoch [18/40]


100%|██████████| 40/40 [00:10<00:00,  3.76it/s, gp=0.0364, loss_critic=-2.08]


Epoch [19/40]


100%|██████████| 40/40 [00:10<00:00,  3.79it/s, gp=0.026, loss_critic=-2.19] 


Epoch [20/40]


100%|██████████| 40/40 [00:10<00:00,  3.79it/s, gp=0.00952, loss_critic=-3]  


Epoch [21/40]


100%|██████████| 40/40 [00:10<00:00,  3.74it/s, gp=0.0229, loss_critic=-2.18] 


Epoch [22/40]


100%|██████████| 40/40 [00:10<00:00,  3.72it/s, gp=0.0159, loss_critic=-2.19]


Epoch [23/40]


100%|██████████| 40/40 [00:10<00:00,  3.66it/s, gp=0.0279, loss_critic=-2.44]


Epoch [24/40]


100%|██████████| 40/40 [00:10<00:00,  3.84it/s, gp=0.0342, loss_critic=-2.13]


Epoch [25/40]


100%|██████████| 40/40 [00:10<00:00,  3.81it/s, gp=0.0286, loss_critic=-2.02]


Epoch [26/40]


100%|██████████| 40/40 [00:10<00:00,  3.77it/s, gp=0.0395, loss_critic=-2.27]


Epoch [27/40]


100%|██████████| 40/40 [00:10<00:00,  3.76it/s, gp=0.0295, loss_critic=-2.35]


Epoch [28/40]


100%|██████████| 40/40 [00:10<00:00,  3.78it/s, gp=0.0516, loss_critic=-2.36]


Epoch [29/40]


100%|██████████| 40/40 [00:10<00:00,  3.81it/s, gp=0.0221, loss_critic=-2.61]


Epoch [30/40]


100%|██████████| 40/40 [00:10<00:00,  3.77it/s, gp=0.0417, loss_critic=-2.29]


Epoch [31/40]


100%|██████████| 40/40 [00:10<00:00,  3.75it/s, gp=0.0351, loss_critic=-3.02]


Epoch [32/40]


100%|██████████| 40/40 [00:10<00:00,  3.73it/s, gp=0.0315, loss_critic=-2.82]


Epoch [33/40]


100%|██████████| 40/40 [00:10<00:00,  3.76it/s, gp=0.0264, loss_critic=-2.05]


Epoch [34/40]


100%|██████████| 40/40 [00:10<00:00,  3.81it/s, gp=0.0292, loss_critic=-2.84]


Epoch [35/40]


100%|██████████| 40/40 [00:10<00:00,  3.75it/s, gp=0.0282, loss_critic=-2.6] 


Epoch [36/40]


100%|██████████| 40/40 [00:10<00:00,  3.76it/s, gp=0.0371, loss_critic=-2.82]


Epoch [37/40]


100%|██████████| 40/40 [00:10<00:00,  3.82it/s, gp=0.0453, loss_critic=-2.63]


Epoch [38/40]


100%|██████████| 40/40 [00:10<00:00,  3.82it/s, gp=0.0237, loss_critic=-2.85]


Epoch [39/40]


100%|██████████| 40/40 [00:10<00:00,  3.78it/s, gp=0.0266, loss_critic=-2.58]


Epoch [40/40]


100%|██████████| 40/40 [00:10<00:00,  3.88it/s, gp=0.0335, loss_critic=-3.44]


Calculating FID...


Collecting Real Images:  48%|████▊     | 19/40 [00:06<00:06,  3.07it/s]
Generating Fake Images:  95%|█████████▌| 19/20 [00:00<00:00, 26.94it/s]


FID: 12.89887002835085
Calculating PPL...


Calculating PPL: 100%|██████████| 2000/2000 [00:07<00:00, 251.47it/s]


PPL: 0.6368109695240856
After Step 1: FID = 12.89887002835085, PPL = 0.6368109695240856
Current image size: 16
Epoch [1/40]


100%|██████████| 79/79 [00:14<00:00,  5.33it/s, gp=0.0596, loss_critic=-4.42]


Epoch [2/40]


100%|██████████| 79/79 [00:14<00:00,  5.41it/s, gp=0.0662, loss_critic=-5.66]


Epoch [3/40]


100%|██████████| 79/79 [00:14<00:00,  5.43it/s, gp=0.0461, loss_critic=-4.76]


Epoch [4/40]


100%|██████████| 79/79 [00:14<00:00,  5.43it/s, gp=0.0765, loss_critic=-3.99]


Epoch [5/40]


100%|██████████| 79/79 [00:14<00:00,  5.42it/s, gp=0.0727, loss_critic=-3.78]


Epoch [6/40]


100%|██████████| 79/79 [00:14<00:00,  5.44it/s, gp=0.0196, loss_critic=-4.03]


Epoch [7/40]


100%|██████████| 79/79 [00:14<00:00,  5.40it/s, gp=0.0301, loss_critic=-2.92]


Epoch [8/40]


100%|██████████| 79/79 [00:14<00:00,  5.34it/s, gp=0.262, loss_critic=-7.52] 


Epoch [9/40]


100%|██████████| 79/79 [00:14<00:00,  5.39it/s, gp=0.0457, loss_critic=-3.51]


Epoch [10/40]


100%|██████████| 79/79 [00:14<00:00,  5.38it/s, gp=0.0377, loss_critic=-2.3] 


Epoch [11/40]


100%|██████████| 79/79 [00:14<00:00,  5.39it/s, gp=0.0377, loss_critic=-3.14]


Epoch [12/40]


100%|██████████| 79/79 [00:14<00:00,  5.40it/s, gp=0.0496, loss_critic=-3.87]


Epoch [13/40]


100%|██████████| 79/79 [00:14<00:00,  5.39it/s, gp=0.0329, loss_critic=-3.37]


Epoch [14/40]


100%|██████████| 79/79 [00:14<00:00,  5.45it/s, gp=0.0401, loss_critic=-3.32]


Epoch [15/40]


100%|██████████| 79/79 [00:14<00:00,  5.46it/s, gp=0.0407, loss_critic=-3.41]


Epoch [16/40]


100%|██████████| 79/79 [00:14<00:00,  5.47it/s, gp=0.0727, loss_critic=-3.19]


Epoch [17/40]


100%|██████████| 79/79 [00:14<00:00,  5.45it/s, gp=0.0373, loss_critic=-2.74]


Epoch [18/40]


100%|██████████| 79/79 [00:14<00:00,  5.46it/s, gp=0.0516, loss_critic=-2.02]


Epoch [19/40]


100%|██████████| 79/79 [00:14<00:00,  5.46it/s, gp=0.0274, loss_critic=-2.68]


Epoch [20/40]


100%|██████████| 79/79 [00:14<00:00,  5.46it/s, gp=0.0263, loss_critic=-2.11]


Epoch [21/40]


100%|██████████| 79/79 [00:14<00:00,  5.47it/s, gp=0.0215, loss_critic=-3.1] 


Epoch [22/40]


100%|██████████| 79/79 [00:14<00:00,  5.48it/s, gp=0.0264, loss_critic=-2.05]


Epoch [23/40]


100%|██████████| 79/79 [00:14<00:00,  5.47it/s, gp=0.0198, loss_critic=-2.41]


Epoch [24/40]


100%|██████████| 79/79 [00:14<00:00,  5.49it/s, gp=0.0276, loss_critic=-2.27]


Epoch [25/40]


100%|██████████| 79/79 [00:14<00:00,  5.48it/s, gp=0.0328, loss_critic=-2.45]


Epoch [26/40]


100%|██████████| 79/79 [00:14<00:00,  5.48it/s, gp=0.042, loss_critic=-2.21]  


Epoch [27/40]


100%|██████████| 79/79 [00:14<00:00,  5.47it/s, gp=0.0492, loss_critic=-2.75]


Epoch [28/40]


100%|██████████| 79/79 [00:14<00:00,  5.47it/s, gp=0.046, loss_critic=-3.53] 


Epoch [29/40]


100%|██████████| 79/79 [00:14<00:00,  5.47it/s, gp=0.015, loss_critic=-2]    


Epoch [30/40]


100%|██████████| 79/79 [00:14<00:00,  5.44it/s, gp=0.0558, loss_critic=-3.1] 


Epoch [31/40]


100%|██████████| 79/79 [00:14<00:00,  5.47it/s, gp=0.0292, loss_critic=-2.14] 


Epoch [32/40]


100%|██████████| 79/79 [00:14<00:00,  5.48it/s, gp=0.0282, loss_critic=-2.02]


Epoch [33/40]


100%|██████████| 79/79 [00:14<00:00,  5.48it/s, gp=0.0188, loss_critic=-1.83]


Epoch [34/40]


100%|██████████| 79/79 [00:14<00:00,  5.45it/s, gp=0.058, loss_critic=-2.63] 


Epoch [35/40]


100%|██████████| 79/79 [00:14<00:00,  5.46it/s, gp=0.0231, loss_critic=-1.86]


Epoch [36/40]


100%|██████████| 79/79 [00:14<00:00,  5.47it/s, gp=0.0197, loss_critic=-1.55] 


Epoch [37/40]


100%|██████████| 79/79 [00:14<00:00,  5.48it/s, gp=0.0184, loss_critic=-1.38]


Epoch [38/40]


100%|██████████| 79/79 [00:14<00:00,  5.46it/s, gp=0.035, loss_critic=-2.26] 


Epoch [39/40]


100%|██████████| 79/79 [00:14<00:00,  5.47it/s, gp=0.0225, loss_critic=-2.06]


Epoch [40/40]


100%|██████████| 79/79 [00:14<00:00,  5.46it/s, gp=0.03, loss_critic=-1.43]  


Calculating FID...


Collecting Real Images:  49%|████▉     | 39/79 [00:05<00:06,  6.64it/s]
Generating Fake Images:  98%|█████████▊| 39/40 [00:01<00:00, 30.97it/s]


FID: 26.39013978979188
Calculating PPL...


Calculating PPL: 100%|██████████| 2000/2000 [00:12<00:00, 163.21it/s]


PPL: 0.503183580763638
After Step 2: FID = 26.39013978979188, PPL = 0.503183580763638
Current image size: 32
Epoch [1/40]


100%|██████████| 157/157 [00:43<00:00,  3.60it/s, gp=0.125, loss_critic=-4.53] 


Epoch [2/40]


100%|██████████| 157/157 [00:43<00:00,  3.63it/s, gp=0.0571, loss_critic=-2.38]


Epoch [3/40]


100%|██████████| 157/157 [00:43<00:00,  3.62it/s, gp=0.0722, loss_critic=-3.59]


Epoch [4/40]


100%|██████████| 157/157 [00:43<00:00,  3.62it/s, gp=0.0344, loss_critic=-2.76]


Epoch [5/40]


100%|██████████| 157/157 [00:43<00:00,  3.63it/s, gp=0.0803, loss_critic=-3.93]


Epoch [6/40]


100%|██████████| 157/157 [00:43<00:00,  3.63it/s, gp=0.00955, loss_critic=-4.31]


Epoch [7/40]


100%|██████████| 157/157 [00:43<00:00,  3.62it/s, gp=0.0675, loss_critic=-2.52]


Epoch [8/40]


100%|██████████| 157/157 [00:43<00:00,  3.62it/s, gp=0.0236, loss_critic=-2.44]


Epoch [9/40]


100%|██████████| 157/157 [00:43<00:00,  3.62it/s, gp=0.0485, loss_critic=-2.63]


Epoch [10/40]


100%|██████████| 157/157 [00:43<00:00,  3.62it/s, gp=0.0196, loss_critic=-1.67]


Epoch [11/40]


100%|██████████| 157/157 [00:43<00:00,  3.62it/s, gp=0.0382, loss_critic=-2.2] 


Epoch [12/40]


100%|██████████| 157/157 [00:43<00:00,  3.63it/s, gp=0.106, loss_critic=-3.02] 


Epoch [13/40]


100%|██████████| 157/157 [00:43<00:00,  3.62it/s, gp=0.0145, loss_critic=-1.58]


Epoch [14/40]


100%|██████████| 157/157 [00:43<00:00,  3.63it/s, gp=0.0116, loss_critic=-1.38]


Epoch [15/40]


100%|██████████| 157/157 [00:43<00:00,  3.62it/s, gp=0.0479, loss_critic=-2.18]


Epoch [16/40]


100%|██████████| 157/157 [00:43<00:00,  3.62it/s, gp=0.0346, loss_critic=-1.8] 


Epoch [17/40]


100%|██████████| 157/157 [00:43<00:00,  3.62it/s, gp=0.018, loss_critic=-1.2]  


Epoch [18/40]


100%|██████████| 157/157 [00:43<00:00,  3.63it/s, gp=0.0124, loss_critic=-1.45]


Epoch [19/40]


100%|██████████| 157/157 [00:43<00:00,  3.62it/s, gp=0.0488, loss_critic=-2.16]


Epoch [20/40]


100%|██████████| 157/157 [00:43<00:00,  3.62it/s, gp=0.0154, loss_critic=-1.55] 


Epoch [21/40]


100%|██████████| 157/157 [00:43<00:00,  3.63it/s, gp=0.0199, loss_critic=-1.31]


Epoch [22/40]


100%|██████████| 157/157 [00:43<00:00,  3.62it/s, gp=0.0153, loss_critic=-1.23]


Epoch [23/40]


100%|██████████| 157/157 [00:43<00:00,  3.63it/s, gp=0.0174, loss_critic=-1.31]


Epoch [24/40]


100%|██████████| 157/157 [00:43<00:00,  3.63it/s, gp=0.0184, loss_critic=-.907]


Epoch [25/40]


100%|██████████| 157/157 [00:43<00:00,  3.63it/s, gp=0.0134, loss_critic=-1.03]


Epoch [26/40]


100%|██████████| 157/157 [00:43<00:00,  3.62it/s, gp=0.0177, loss_critic=-1.19]


Epoch [27/40]


100%|██████████| 157/157 [00:43<00:00,  3.62it/s, gp=0.0704, loss_critic=-1.49]


Epoch [28/40]


100%|██████████| 157/157 [00:43<00:00,  3.62it/s, gp=0.018, loss_critic=-1.11] 


Epoch [29/40]


100%|██████████| 157/157 [00:43<00:00,  3.63it/s, gp=0.0133, loss_critic=-2.06] 


Epoch [30/40]


100%|██████████| 157/157 [00:43<00:00,  3.63it/s, gp=0.0254, loss_critic=-1.52]


Epoch [31/40]


100%|██████████| 157/157 [00:43<00:00,  3.62it/s, gp=0.032, loss_critic=-.951]  


Epoch [32/40]


100%|██████████| 157/157 [00:43<00:00,  3.62it/s, gp=0.0291, loss_critic=-1.61] 


Epoch [33/40]


100%|██████████| 157/157 [00:43<00:00,  3.62it/s, gp=0.0183, loss_critic=-1.08]


Epoch [34/40]


100%|██████████| 157/157 [00:43<00:00,  3.63it/s, gp=0.0296, loss_critic=-1.36] 


Epoch [35/40]


100%|██████████| 157/157 [00:43<00:00,  3.63it/s, gp=0.0242, loss_critic=-1.2] 


Epoch [36/40]


100%|██████████| 157/157 [00:43<00:00,  3.63it/s, gp=0.0391, loss_critic=-1.49] 


Epoch [37/40]


100%|██████████| 157/157 [00:43<00:00,  3.63it/s, gp=0.0222, loss_critic=-.708]


Epoch [38/40]


100%|██████████| 157/157 [00:43<00:00,  3.63it/s, gp=0.0214, loss_critic=-.534]


Epoch [39/40]


100%|██████████| 157/157 [00:43<00:00,  3.63it/s, gp=0.0128, loss_critic=-.857] 


Epoch [40/40]


100%|██████████| 157/157 [00:43<00:00,  3.63it/s, gp=0.00987, loss_critic=-1.24]


Calculating FID...


Collecting Real Images:  50%|████▉     | 78/157 [00:05<00:05, 14.15it/s]
Generating Fake Images:  99%|█████████▊| 78/79 [00:02<00:00, 30.89it/s]


FID: 43.38151100777037
Calculating PPL...


Calculating PPL: 100%|██████████| 2000/2000 [00:13<00:00, 145.05it/s]


PPL: 0.47898850885033606
After Step 3: FID = 43.38151100777037, PPL = 0.47898850885033606
Current image size: 64
Epoch [1/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.137, loss_critic=-2.55]  


Epoch [2/40]


100%|██████████| 313/313 [01:37<00:00,  3.23it/s, gp=0.0175, loss_critic=-1.65]


Epoch [3/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.0611, loss_critic=-2.98] 


Epoch [4/40]


100%|██████████| 313/313 [01:37<00:00,  3.23it/s, gp=0.129, loss_critic=-2.55]  


Epoch [5/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.0181, loss_critic=-2.4]  


Epoch [6/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.0386, loss_critic=-3.36] 


Epoch [7/40]


100%|██████████| 313/313 [01:37<00:00,  3.23it/s, gp=0.046, loss_critic=-3.63] 


Epoch [8/40]


100%|██████████| 313/313 [01:37<00:00,  3.23it/s, gp=0.0295, loss_critic=-2.59]


Epoch [9/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.0355, loss_critic=-2.16]


Epoch [10/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.0203, loss_critic=-3.81]


Epoch [11/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.0219, loss_critic=-3.13]


Epoch [12/40]


100%|██████████| 313/313 [01:37<00:00,  3.23it/s, gp=0.0248, loss_critic=-2.15]


Epoch [13/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.0226, loss_critic=-2.83]


Epoch [14/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.0224, loss_critic=-2.87]


Epoch [15/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.0212, loss_critic=-2.59]


Epoch [16/40]


100%|██████████| 313/313 [01:37<00:00,  3.23it/s, gp=0.0159, loss_critic=-2.92]


Epoch [17/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.0138, loss_critic=-3.21] 


Epoch [18/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.0461, loss_critic=-1.95] 


Epoch [19/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.0339, loss_critic=-3.57] 


Epoch [20/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.0205, loss_critic=-2.85] 


Epoch [21/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.0131, loss_critic=-2.26]


Epoch [22/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.019, loss_critic=-3.55] 


Epoch [23/40]


100%|██████████| 313/313 [01:37<00:00,  3.23it/s, gp=0.0212, loss_critic=-2.55] 


Epoch [24/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.0289, loss_critic=-1.89] 


Epoch [25/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.0184, loss_critic=-1.77]


Epoch [26/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.0227, loss_critic=-2.2]  


Epoch [27/40]


100%|██████████| 313/313 [01:37<00:00,  3.23it/s, gp=0.0235, loss_critic=-2.13]


Epoch [28/40]


100%|██████████| 313/313 [01:37<00:00,  3.23it/s, gp=0.0559, loss_critic=-2.88]


Epoch [29/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.0279, loss_critic=-2.05] 


Epoch [30/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.0169, loss_critic=-2.96] 


Epoch [31/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.0129, loss_critic=-2.22] 


Epoch [32/40]


100%|██████████| 313/313 [01:37<00:00,  3.23it/s, gp=0.0097, loss_critic=-1.35]


Epoch [33/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.00627, loss_critic=-2.55]


Epoch [34/40]


100%|██████████| 313/313 [01:37<00:00,  3.23it/s, gp=0.0147, loss_critic=-1.76] 


Epoch [35/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.0122, loss_critic=-2.42] 


Epoch [36/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.0227, loss_critic=-1.53]


Epoch [37/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.0387, loss_critic=-3.7] 


Epoch [38/40]


100%|██████████| 313/313 [01:37<00:00,  3.23it/s, gp=0.0234, loss_critic=-3.23] 


Epoch [39/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.0253, loss_critic=-1.46] 


Epoch [40/40]


100%|██████████| 313/313 [01:37<00:00,  3.22it/s, gp=0.0183, loss_critic=-1.92] 


Calculating FID...


Collecting Real Images:  50%|████▉     | 156/313 [00:05<00:05, 26.78it/s]
Generating Fake Images:  99%|█████████▉| 156/157 [00:03<00:00, 44.15it/s]


FID: 190.86005055362304
Calculating PPL...


Calculating PPL: 100%|██████████| 2000/2000 [00:18<00:00, 110.62it/s]


PPL: 0.3673036408424377
After Step 4: FID = 190.86005055362304, PPL = 0.3673036408424377
Current image size: 128
Epoch [1/40]


100%|██████████| 625/625 [02:52<00:00,  3.63it/s, gp=0.0519, loss_critic=-4.48]


Epoch [2/40]


 93%|█████████▎| 584/625 [02:41<00:11,  3.63it/s, gp=0.0902, loss_critic=-6.57]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

100%|██████████| 625/625 [02:52<00:00,  3.62it/s, gp=0.105, loss_critic=-3.27] 


Epoch [10/40]


100%|██████████| 625/625 [02:53<00:00,  3.61it/s, gp=0.0615, loss_critic=-7.29]


Epoch [11/40]


100%|██████████| 625/625 [02:52<00:00,  3.62it/s, gp=0.0472, loss_critic=-5.43]


Epoch [12/40]


100%|██████████| 625/625 [02:52<00:00,  3.62it/s, gp=0.0415, loss_critic=-7.3] 


Epoch [13/40]


100%|██████████| 625/625 [02:52<00:00,  3.62it/s, gp=0.0949, loss_critic=-7.43]


Epoch [14/40]


100%|██████████| 625/625 [02:52<00:00,  3.62it/s, gp=0.0612, loss_critic=-6.53]


Epoch [15/40]


100%|██████████| 625/625 [02:52<00:00,  3.62it/s, gp=0.0987, loss_critic=-6.74]


Epoch [16/40]


100%|██████████| 625/625 [02:52<00:00,  3.61it/s, gp=0.0781, loss_critic=-2.85]


Epoch [17/40]


100%|██████████| 625/625 [02:52<00:00,  3.61it/s, gp=0.0671, loss_critic=-5.76]


Epoch [18/40]


100%|██████████| 625/625 [02:52<00:00,  3.62it/s, gp=0.0646, loss_critic=-5.91]


Epoch [19/40]


100%|██████████| 625/625 [02:52<00:00,  3.63it/s, gp=0.053, loss_critic=-2.12] 


Epoch [20/40]


100%|██████████| 625/625 [02:52<00:00,  3.63it/s, gp=0.0831, loss_critic=-4.24]


Epoch [21/40]


100%|██████████| 625/625 [02:52<00:00,  3.62it/s, gp=0.044, loss_critic=-8.34] 


Epoch [22/40]


100%|██████████| 625/625 [02:52<00:00,  3.62it/s, gp=0.033, loss_critic=-6.12] 


Epoch [23/40]


100%|██████████| 625/625 [02:53<00:00,  3.61it/s, gp=0.0682, loss_critic=-8.18]


Epoch [24/40]


100%|██████████| 625/625 [02:52<00:00,  3.62it/s, gp=0.0475, loss_critic=-4.02]


Epoch [25/40]


100%|██████████| 625/625 [02:52<00:00,  3.62it/s, gp=0.146, loss_critic=-4.71] 


Epoch [26/40]


100%|██████████| 625/625 [02:53<00:00,  3.61it/s, gp=0.0542, loss_critic=-11.3]


Epoch [27/40]


100%|██████████| 625/625 [02:52<00:00,  3.61it/s, gp=0.0887, loss_critic=-7.79]


Epoch [28/40]


100%|██████████| 625/625 [02:53<00:00,  3.60it/s, gp=0.0776, loss_critic=-2.64]


Epoch [29/40]


100%|██████████| 625/625 [02:53<00:00,  3.61it/s, gp=0.0765, loss_critic=-4.75]


Epoch [30/40]


100%|██████████| 625/625 [02:53<00:00,  3.60it/s, gp=0.0732, loss_critic=-3.66]


Epoch [31/40]


100%|██████████| 625/625 [02:53<00:00,  3.61it/s, gp=0.0632, loss_critic=-5.33]


Epoch [32/40]


100%|██████████| 625/625 [02:52<00:00,  3.62it/s, gp=0.0395, loss_critic=-7.7] 


Epoch [33/40]


100%|██████████| 625/625 [02:52<00:00,  3.61it/s, gp=0.0708, loss_critic=-6.41]


Epoch [34/40]


100%|██████████| 625/625 [02:53<00:00,  3.61it/s, gp=0.0455, loss_critic=-5.14]


Epoch [35/40]


100%|██████████| 625/625 [02:54<00:00,  3.58it/s, gp=0.0357, loss_critic=-3.87]


Epoch [36/40]


100%|██████████| 625/625 [02:53<00:00,  3.60it/s, gp=0.0788, loss_critic=-5.42]


Epoch [37/40]


100%|██████████| 625/625 [02:54<00:00,  3.59it/s, gp=0.0724, loss_critic=-6.84]


Epoch [38/40]


100%|██████████| 625/625 [02:53<00:00,  3.61it/s, gp=0.0652, loss_critic=-3.84] 


Epoch [39/40]


100%|██████████| 625/625 [02:52<00:00,  3.62it/s, gp=0.103, loss_critic=-7.98] 


Epoch [40/40]


100%|██████████| 625/625 [02:53<00:00,  3.61it/s, gp=0.0483, loss_critic=-5.43]


Calculating FID...


Collecting Real Images:  50%|████▉     | 312/625 [00:07<00:07, 43.80it/s]
Generating Fake Images: 100%|█████████▉| 312/313 [00:04<00:00, 68.05it/s]
