In [1]:
from pathlib import Path
import sys

In [39]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils

In [None]:
# Fetch the project repo (metrics branch); presumably it provides the utils, datasets,
# models and metrics packages imported below once ./GANs is on sys.path.
# NOTE(review): `git clone` errors if ./GANs already exists, so this cell is not re-runnable as-is.
!git clone -b metrics https://github.com/azfarkhoja305/GANs.git

In [3]:
# Make the cloned repo importable (utils, datasets, models, metrics live there).
# Guarded so that re-running this cell does not prepend duplicate sys.path entries.
gans_repo = './GANs'
if Path(gans_repo).exists() and gans_repo not in sys.path:
    sys.path.insert(0, gans_repo)

In [None]:
# Presumably precomputes the CIFAR-10 FID reference statistics that are loaded later
# from fid_stats/cifar_10_valid_fid_stats.npz.
# NOTE(review): if create_fid_stats.py parses -t with argparse `type=bool`, the string
# "False" is truthy — confirm this really produces the *valid*-split file used below.
!python ./GANs/create_fid_stats.py -d cifar_10 -t False

In [7]:
# Automatically reload edited project modules (utils, datasets, models, ...) before each cell runs.
%load_ext autoreload
%autoreload 2

### 1) Added a seed function (`set_seed`) for reproducibility

In [10]:
from utils.utils import set_seed, check_gpu, display_images

# Seed the RNGs through the project helper before any random tensors are drawn
# (which libraries it seeds is defined in utils.utils.set_seed — not visible here).
SEED = 123
set_seed(seed=SEED)

In [8]:
# Pick the compute device via the project helper and report which one is used.
device = check_gpu()
print('Using device:', device)

Using device: cuda


### 2) Changes to the transforms in `datasets.py`

In [28]:
from datasets import ImageDataset

# Extra augmentations applied to the training split only. The dataset adds
# ToTensor + Normalize to both splits itself (see the printed pipelines below).
tfms = [
    transforms.RandomHorizontalFlip(),
]
dataset = ImageDataset('cifar_10', batch_sz=256, tfms=tfms)

Files already downloaded and verified


In [31]:
# Show the transform pipeline the dataset built for each split.
train_tfms, valid_tfms = dataset.train_transforms, dataset.valid_transforms
print(f"Train set Transforms:\n {train_tfms}")
print(f"\nValid set Transforms:\n {valid_tfms}")

Train set Transforms:
 [RandomHorizontalFlip(p=0.5), ToTensor(), Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]

Valid set Transforms:
 [ToTensor(), Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]


In [32]:
# Demo GAN training: build the generator / discriminator pair on `device`.

from models.generator import Generator
from models.discriminator import Discriminator

# Size of the latent vector z, named once so the generator's input size is not a
# separate magic number. NOTE: the hyperparameter cell below also sets
# `latent_dims` for noise sampling — the two values must match.
latent_dims = 128

Gen = Generator(z_sz=latent_dims).to(device)
Dis = Discriminator().to(device)

In [37]:
# --- Loss and target labels ---------------------------------------------------
# BCEWithLogitsLoss applies the sigmoid itself, so D's raw outputs are fed to it
# directly (the training loop only applies torch.sigmoid for logging).
loss_fn = nn.BCEWithLogitsLoss()
real_label, fake_label = 1., 0.

# --- Latent space -------------------------------------------------------------
latent_dims = 128
# A fixed batch of 64 latent vectors, reused to visualise G's progress.
fixed_noise = torch.randn(64, latent_dims, device=device)

# --- Optimisers (AdamW with its library-default weight decay) -----------------
lr = 1e-4
beta1 = 0
adam_betas = (beta1, 0.999)
optG = optim.AdamW(Gen.parameters(), lr=lr, betas=adam_betas)
optD = optim.AdamW(Dis.parameters(), lr=lr, betas=adam_betas)

# --- Training bookkeeping -----------------------------------------------------
img_list, G_losses, D_losses = [], [], []
iters = 0

# Number of training epochs
num_epochs = 10

In [40]:
# Train the GAN with alternating updates per batch:
#   (1) D step: maximise log(D(x)) + log(1 - D(G(z)))
#   (2) G step: maximise log(D(G(z))) by labelling the fakes as real
# Reads: Gen, Dis, optG, optD, loss_fn, dataset, latent_dims, fixed_noise,
#        real_label, fake_label, num_epochs, device.
# Appends one entry per batch to G_losses / D_losses, fixed-noise snapshots to
# img_list, and advances the global iteration counter `iters`.

# With batch_sz=256 a CIFAR-10 epoch has at most ~196 batches, so the old fixed
# `(i+1) % 700` check never fired. Print periodically AND on each epoch's last batch.
print_every = 50     # batches between progress prints (within an epoch)
sample_every = 1000  # global iterations between fixed-noise snapshots
num_batches = len(dataset.train_loader)

# Ensure both nets are in training mode, e.g. when this cell is re-run after
# evaluation code that may have switched them to eval().
Gen.train()
Dis.train()

for epoch in range(num_epochs):
    for i, data in enumerate(dataset.train_loader):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        Dis.zero_grad()
        real = data[0].to(device)
        b_size = real.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Flatten D's output to (b_size,) so it lines up with the label vector.
        output = Dis(real).view(-1)
        errD_real = loss_fn(output, label)
        errD_real.backward()
        D_x = torch.sigmoid(output).mean().item()

        ## Train with all-fake batch
        noise = torch.randn(b_size, latent_dims, device=device)
        fake = Gen(noise)
        label = torch.full_like(label, fake_label)
        # detach(): the D step must not backpropagate into G.
        output = Dis(fake.detach()).view(-1)
        errD_fake = loss_fn(output, label)
        # Accumulates on top of the real-batch gradients (no zero_grad in between).
        errD_fake.backward()
        D_G_z1 = torch.sigmoid(output).mean().item()
        errD = errD_real + errD_fake
        optD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        Gen.zero_grad()
        label = torch.full_like(label, real_label)  # fake labels are real for generator cost
        # Re-score the same fakes with the just-updated D, keeping the graph back to G.
        output = Dis(fake).view(-1)
        errG = loss_fn(output, label)
        # Also writes grads into Dis; they are cleared by Dis.zero_grad() next iteration.
        errG.backward()
        D_G_z2 = torch.sigmoid(output).mean().item()
        optG.step()

        # Output training stats
        if (i + 1) % print_every == 0 or i == num_batches - 1:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, num_batches,
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise.
        # NOTE(review): G is still in train mode here; if Generator contains
        # BatchNorm, this forward pass also updates its running stats — confirm intended.
        if (iters % sample_every == 0) or ((epoch == num_epochs - 1) and (i == num_batches - 1)):
            with torch.no_grad():
                fake = Gen(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

### 3) Inception Score (IS) and Fréchet Inception Distance (FID)

In [41]:
from metrics.torch_is_fid_score import is_fid_from_generator

In [42]:
# Precomputed FID reference statistics for CIFAR-10 (validation split, per the
# filename); presumably produced by the create_fid_stats.py run near the top.
stat_path = Path('fid_stats') / 'cifar_10_valid_fid_stats.npz'

In [44]:
# Sample 10,000 images from G in batches of 256 and score them: Inception Score
# plus FID against the reference statistics at stat_path.
# NOTE(review): whether is_fid_from_generator switches Gen to eval() is not visible here.
n_eval_imgs, eval_batch_sz = 10000, 256
inception_score, fid = is_fid_from_generator(
    generator=Gen,
    latent_dims=latent_dims,
    num_imgs=n_eval_imgs,
    batch_sz=eval_batch_sz,
    fid_stat_path=stat_path,
)

HBox(children=(FloatProgress(value=0.0, description='generating images', max=40.0, style=ProgressStyle(descrip…



HBox(children=(FloatProgress(value=0.0, description='inception_score_and_fid', layout=Layout(flex='2'), max=10…



In [45]:
# Inception Score as returned by is_fid_from_generator: a pair, presumably
# (mean, std) — confirm the ordering in metrics.torch_is_fid_score.
inception_score

(3.6151833534240723, 0.0741254985332489)

In [46]:
# FID of the generated samples against the statistics at stat_path (lower is better).
fid

133.73007202148438