# Conditional GAN implementation

[Original video](https://youtu.be/Hp-jWm2SzR8)

[Article](https://arxiv.org/abs/1411.1784)

## Download and prepare dataset. Import libraries

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import multiprocessing

from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter

In [None]:
# Side length (pixels) and channel count of the images fed to both networks.
IMG_SIZE = 64
CHANNELS_IMG = 1

# Preprocessing pipeline: resize, crop to a square, convert to a tensor and
# normalize every channel with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),      # resize proportionally (shorter side -> IMG_SIZE)
    transforms.RandomCrop(IMG_SIZE),  # cut an IMG_SIZE x IMG_SIZE square out of it
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * CHANNELS_IMG, std=[0.5] * CHANNELS_IMG),
])

def save_checkpoint(state, filename):
    """Write ``state`` to ``filename`` via ``torch.save`` and log the action.

    Args:
        state: Picklable object to persist (here: a dict with model/optimizer
            state dicts and the TensorBoard step).
        filename: Destination path or writable file-like object.
    """
    print('=> Saving checkpoint')
    torch.save(obj=state, f=filename)


def load_checkpoint(checkpoint, model, optimizer):
    """Restore ``model`` and ``optimizer`` in place from a checkpoint dict.

    Args:
        checkpoint: Mapping with the keys ``'state_dict'``, ``'optimizer'``
            and ``'step'``.
        model: Module whose ``load_state_dict`` receives ``checkpoint['state_dict']``.
        optimizer: Optimizer whose ``load_state_dict`` receives ``checkpoint['optimizer']``.

    Returns:
        The value stored under ``checkpoint['step']``.
    """
    print('=> Loading checkpoint')
    # Model first, optimizer second - same order as they are listed in the checkpoint.
    for target, key in ((model, 'state_dict'), (optimizer, 'optimizer')):
        target.load_state_dict(checkpoint[key])
    return checkpoint['step']

In [None]:
# Get MNIST dataset

# For the error: HTTPError: HTTP Error 503: Service Unavailable
# Use this instead
data_dir = '/content/dataset/MNIST/raw/'
if os.path.exists(data_dir):
    !rm -rf $data_dir

!mkdir $data_dir
!wget --directory-prefix=$data_dir https://github.com/golbin/TensorFlow-MNIST/raw/master/mnist/data/t10k-images-idx3-ubyte.gz
!wget --directory-prefix=$data_dir https://github.com/golbin/TensorFlow-MNIST/raw/master/mnist/data/t10k-labels-idx1-ubyte.gz
!wget --directory-prefix=$data_dir https://github.com/golbin/TensorFlow-MNIST/raw/master/mnist/data/train-images-idx3-ubyte.gz
!wget --directory-prefix=$data_dir https://github.com/golbin/TensorFlow-MNIST/raw/master/mnist/data/train-labels-idx1-ubyte.gz

# For the error: HTTPError: HTTP Error 403: Forbidden
# StackOverflow: https://stackoverflow.com/a/66461122/7550928
from six.moves import urllib    
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)

dataset = datasets.MNIST(root='dataset/', train=True, transform=transform, download=True)

--2021-03-16 13:10:47--  https://github.com/golbin/TensorFlow-MNIST/raw/master/mnist/data/t10k-images-idx3-ubyte.gz
Resolving github.com (github.com)... 140.82.114.4
Connecting to github.com (github.com)|140.82.114.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/golbin/TensorFlow-MNIST/master/mnist/data/t10k-images-idx3-ubyte.gz [following]
--2021-03-16 13:10:47--  https://raw.githubusercontent.com/golbin/TensorFlow-MNIST/master/mnist/data/t10k-images-idx3-ubyte.gz
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1648877 (1.6M) [application/octet-stream]
Saving to: ‘/content/dataset/MNIST/raw/t10k-images-idx3-ubyte.gz’


2021-03-16 13:10:48 (58.0 MB/s) - ‘/content/dataset/MNIST/raw/t10k-images

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


## Set and test the model

In [None]:
class Critic(nn.Module):
    """Label-conditioned WGAN critic.

    The class label is embedded into an ``img_size x img_size`` plane and
    concatenated to the image as one extra channel. The network ends without a
    sigmoid, so it returns an unbounded score of shape ``N x 1`` (hence
    "critic" rather than "discriminator").
    """

    def __init__(self, channels_img, features_d, num_classes, img_size):
        super().__init__()
        self.img_size = img_size

        # Input: N x (C+1) x 64 x 64. Each stride-2 block halves the spatial
        # size (64 -> 32 -> 16 -> 8 -> 4) while the width grows
        # features_d -> 2x -> 4x -> 8x.
        layers = []
        in_ch = channels_img + 1
        for out_ch in (features_d, features_d * 2, features_d * 4, features_d * 8):
            layers.append(self._block(in_ch, out_ch, kernel_size=4, stride=2, padding=1))
            in_ch = out_ch
        # 4 x 4 feature map -> single score per sample: N x 1 x 1 x 1
        layers.append(nn.Conv2d(in_ch, 1, kernel_size=4, stride=1, padding=0))
        layers.append(nn.Flatten())  # N x 1
        self.critic = nn.Sequential(*layers)

        # One learnable img_size*img_size vector per class.
        self.embed = nn.Embedding(num_classes, img_size * img_size)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Conv -> InstanceNorm (affine) -> LeakyReLU(0.2)."""
        # The convolution bias is dropped; the affine norm layer has its own bias.
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        # Normalize per sample (instance) instead of across the batch.
        norm = nn.InstanceNorm2d(out_channels, affine=True)
        activation = nn.LeakyReLU(0.2, inplace=True)
        return nn.Sequential(conv, norm, activation)

    def forward(self, x, labels):
        """Score images ``x`` (N x C x H x W) given integer ``labels`` (N)."""
        batch = labels.size(0)
        # Label plane: N x 1 x H x W
        label_plane = self.embed(labels).view(batch, 1, self.img_size, self.img_size)
        stacked = torch.cat([x, label_plane], dim=1)  # N x (C+1) x H x W
        return self.critic(stacked)  # N x 1


class Generator(nn.Module):
    """Label-conditioned generator.

    The class label is embedded into an ``embed_size`` vector that is appended
    to the noise vector along the channel axis; transposed convolutions then
    upsample ``1 x 1`` to ``64 x 64``. The output passes through ``Tanh``.
    """

    def __init__(self, noise_dim, channels_img, features_g, num_classes, img_size, embed_size):
        super().__init__()
        self.img_size = img_size

        # Input: N x (noise_dim + embed_size) x 1 x 1 -> N x (features_g*8) x 4 x 4
        stem = self._block(noise_dim + embed_size, features_g * 8,
                           kernel_size=4, stride=1, padding=0)
        # Each block doubles the spatial size and halves the width:
        # 4 -> 8 -> 16 -> 32.
        upsampling = [
            self._block(features_g * wide, features_g * narrow,
                        kernel_size=4, stride=2, padding=1)
            for wide, narrow in ((8, 4), (4, 2), (2, 1))
        ]
        # 32 -> 64, mapping to the image channel count: N x C x 64 x 64
        to_image = nn.ConvTranspose2d(features_g, channels_img,
                                      kernel_size=4, stride=2, padding=1)
        self.gen = nn.Sequential(stem, *upsampling, to_image, nn.Tanh())

        # Label embedding that is concatenated to the noise.
        self.embed = nn.Embedding(num_classes, embed_size)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """ConvTranspose -> BatchNorm -> ReLU."""
        deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding,
                                    bias=False)
        norm = nn.BatchNorm2d(out_channels)
        return nn.Sequential(deconv, norm, nn.ReLU())

    def forward(self, x, labels):
        """Generate images from noise ``x`` (N x noise_dim x 1 x 1) and ``labels`` (N)."""
        # N x embed_size -> N x embed_size x 1 x 1
        label_vec = self.embed(labels)[:, :, None, None]
        conditioned = torch.cat([x, label_vec], dim=1)
        return self.gen(conditioned)  # N x C x 64 x 64


def init_weights(model):
    """Re-initialize weights in place from a normal distribution N(0.0, 0.02).

    Only ``Conv2d``, ``ConvTranspose2d`` and ``BatchNorm2d`` sub-modules are
    touched; every other layer keeps its existing initialization.

    NOTE(review): the BatchNorm scale is also centred at 0.0 here; the PyTorch
    DCGAN tutorial centres it at 1.0 - confirm this is intended.
    """
    reinit_types = (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)
    for layer in filter(lambda mod: isinstance(mod, reinit_types), model.modules()):
        nn.init.normal_(layer.weight.data, mean=0.0, std=0.02)


def test():
    """Shape smoke test for ``Critic`` and ``Generator`` on random inputs.

    Raises:
        AssertionError: if either network returns an unexpected shape.
    """
    batch, channels, side = 8, 1, 64
    num_classes, embed_size, noise_dim = 10, 100, 128
    features = 16  # shared by critic and generator

    images = torch.randn((batch, channels, side, side))
    labels = torch.ones(batch).int()  # *.to(torch.int64) or *.long() work as well

    # The critic maps a batch of images to one score per sample.
    critic = Critic(channels, features, num_classes, side)
    init_weights(critic)
    assert critic(images, labels).shape == (batch, 1)

    # The generator maps noise to images of the critic's input shape.
    gen = Generator(noise_dim, channels, features, num_classes, side, embed_size)
    init_weights(gen)
    noise = torch.randn((batch, noise_dim, 1, 1))
    assert gen(noise, labels).shape == (batch, channels, side, side)

    print('Test is OK')

In [None]:
# Sanity-check the output shapes of both networks before training.
test()

Test is OK


## Run TensorBoard

In [None]:
# Run TensorBoard

# Delete the previous logs directory so old runs do not show up in the dashboard
logs_dir = 'logs_dir'
if os.path.exists(logs_dir):
    !rm -rf $logs_dir

# Workaround needed when both PyTorch and TensorFlow are installed:
# AttributeError: module 'tensorflow._api.v2.io.gfile' has no attribute 'get_filesystem'
# Point tf.io.gfile at TensorBoard's own stub implementation.
import tensorflow as tf
import tensorboard as tb
tf.io.gfile = tb.compat.tensorflow_stub.io.gfile

# Load the TensorBoard notebook extension
%load_ext tensorboard

# Start TensorBoard before training to monitor it in progress
%tensorboard --logdir $logs_dir

# Reload the TensorBoard extension
%reload_ext tensorboard

## Prepare the model

In [None]:
# Hyperparameters etc.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
LEARNING_RATE = 1e-4  # could also use two lrs, one for gen and one for critic
BATCH_SIZE = 64
NUM_CLASSES = 10  # MNIST dataset
GEN_EMBEDDING = 100  # size of the generator's label embedding
NOISE_DIM = 128  # was 100
NUM_EPOCHS = 10
FEATURES_CRITIC = FEATURES_GEN = 16
CRITIC_ITERATIONS = 5  # critic updates per generator update
LAMBDA_GP = 10  # weight of the gradient penalty in the critic loss
logs_dir = 'logs_dir'

loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True,
                    num_workers=multiprocessing.cpu_count(), pin_memory=True)

gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN, NUM_CLASSES, IMG_SIZE, GEN_EMBEDDING).to(device)
critic = Critic(CHANNELS_IMG, FEATURES_CRITIC, NUM_CLASSES, IMG_SIZE).to(device)
init_weights(gen)
init_weights(critic)

opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))

# Separate writers so real and generated image grids appear as separate runs.
writer_real = SummaryWriter(os.path.join(logs_dir, 'real'))
writer_fake = SummaryWriter(os.path.join(logs_dir, 'fake'))
step = 0

# Load models
gen_checkpoint_name = 'generator_mnist.pth.tar'
critic_checkpoint_name = 'critic_mnist.pth.tar'

# Resume only when both checkpoints exist. map_location remaps the stored
# tensors onto the current device, so a checkpoint written on a GPU runtime
# can still be loaded on a CPU-only one.
if os.path.exists(gen_checkpoint_name) and os.path.exists(critic_checkpoint_name):
    step = load_checkpoint(torch.load(gen_checkpoint_name, map_location=device), gen, opt_gen)
    step = load_checkpoint(torch.load(critic_checkpoint_name, map_location=device), critic, opt_critic)


def gradient_penalty(critic, labels, real, fake, device):
    """Compute the WGAN-GP gradient penalty on random real/fake interpolations.

    For every sample a coefficient ``epsilon ~ U[0, 1)`` mixes the real and the
    fake image; the penalty is the batch mean of ``(||grad||_2 - 1) ** 2``,
    where ``grad`` is the gradient of the critic score with respect to the
    mixed image.

    Args:
        critic: Callable ``critic(images, labels)`` returning one score per sample.
        labels: Class labels forwarded unchanged to ``critic``.
        real: Real images, ``N x C x H x W``.
        fake: Generated images with the same shape as ``real``. They may or may
            not be attached to an autograd graph.
        device: Device on which the mixing coefficients are created.

    Returns:
        Scalar tensor with the penalty; it is differentiable with respect to
        the critic parameters (``create_graph=True``).
    """
    batch_size = real.shape[0]
    # One coefficient per sample, created directly on the target device and
    # broadcast over C, H and W.
    epsilon = torch.rand((batch_size, 1, 1, 1), device=device)
    interpolated_images = epsilon * real + (1 - epsilon) * fake
    # autograd.grad needs its input inside the graph. When `fake` is detached
    # (or was produced under no_grad) the interpolation is a plain leaf tensor,
    # so gradient tracking has to be enabled explicitly.
    if not interpolated_images.requires_grad:
        interpolated_images.requires_grad_(True)

    # Critic scores of the interpolated images
    mixed_scores = critic(interpolated_images, labels)
    # Gradient of the scores w.r.t. the interpolated images: N x C x H x W
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.reshape(gradient.shape[0], -1)  # N x (C*H*W)
    gradient_norm = gradient.norm(2, dim=1)  # per-sample L2 norm
    return torch.mean((gradient_norm - 1) ** 2)

## Train the model

In [None]:
def train(step=step):
    """Train the conditional WGAN-GP and checkpoint both networks every epoch.

    Uses the notebook-level ``gen``, ``critic``, ``loader``, optimizers,
    TensorBoard writers and hyperparameters. For each batch the critic is
    updated ``CRITIC_ITERATIONS`` times, then the generator once.

    Args:
        step: TensorBoard global step to start from. The default is the value
            of the module-level ``step`` at the time this function is defined.
    """
    gen.train()
    critic.train()

    for epoch in range(NUM_EPOCHS):
        for batch_idx, (real, labels) in enumerate(loader):
            real = real.to(device)
            curr_batch_size = real.shape[0]
            labels = labels.to(device)

            # Train Critic: max (E(critic(real)) - E(critic(fake)))
            # or min (-1) * (E(critic(real)) - E(critic(fake)))
            for _ in range(CRITIC_ITERATIONS):
                noise = torch.randn(curr_batch_size, NOISE_DIM, 1, 1).to(device)
                fake = gen(noise, labels)
                # The critic update must not backpropagate into the generator,
                # so it works on a detached copy. Gradient tracking is
                # re-enabled on that copy because gradient_penalty
                # differentiates the critic w.r.t. the interpolated images.
                fake_for_critic = fake.detach().requires_grad_(True)
                critic_real = critic(real, labels)
                critic_fake = critic(fake_for_critic, labels)
                gp = gradient_penalty(critic, labels, real, fake_for_critic, device)

                loss_critic = torch.mean(critic_fake) - torch.mean(critic_real) + LAMBDA_GP*gp

                opt_critic.zero_grad()
                # The generator graph is untouched here, so no retain_graph is needed.
                loss_critic.backward()
                opt_critic.step()

            # Train Generator: min (E(critic(real)) - E(critic(fake)))
            # which is the same as: min (-1) * E(critic(fake))
            # because the generator cannot influence E(critic(real)).
            # `fake` is the batch from the last critic iteration and is still
            # attached to the generator graph.
            output = critic(fake, labels)
            loss_gen = -torch.mean(output)

            opt_gen.zero_grad()
            loss_gen.backward()
            opt_gen.step()

            # Print losses occasionally and log image grids to TensorBoard
            if batch_idx % 100 == 0:
                print(f'Epoch [{epoch+1}/{NUM_EPOCHS}] '
                    f'Batch {batch_idx}/{len(loader)} '
                    f'Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}')

                with torch.no_grad():
                    fake_preview = gen(noise, labels)
                    # take out (up to) 32 examples
                    img_grid_real = torchvision.utils.make_grid(
                        real[:32], normalize=True
                    )
                    img_grid_fake = torchvision.utils.make_grid(
                        fake_preview[:32], normalize=True
                    )

                    writer_real.add_image('Real', img_grid_real, global_step=step)
                    writer_fake.add_image('Fake', img_grid_fake, global_step=step)

                # The step counts logging events, not batches.
                step += 1

        # Save models at the end of every epoch
        gen_checkpoint = {
            'state_dict': gen.state_dict(),
            'optimizer': opt_gen.state_dict(),
            'step': step,
        }
        critic_checkpoint = {
            'state_dict': critic.state_dict(),
            'optimizer': opt_critic.state_dict(),
            'step': step,
        }
        save_checkpoint(gen_checkpoint, gen_checkpoint_name)
        save_checkpoint(critic_checkpoint, critic_checkpoint_name)

In [None]:
# Start (or resume) training from the current TensorBoard step.
train(step=step)

Epoch [1/10] Batch 0/938 Loss D: -2.1588, loss G: 42.3064
Epoch [1/10] Batch 100/938 Loss D: -2.7083, loss G: 42.9608
Epoch [1/10] Batch 200/938 Loss D: -3.9034, loss G: 35.5546
Epoch [1/10] Batch 300/938 Loss D: -4.8857, loss G: 26.6659
Epoch [1/10] Batch 400/938 Loss D: -2.2113, loss G: 38.4966
Epoch [1/10] Batch 500/938 Loss D: -3.8873, loss G: 30.7572
Epoch [1/10] Batch 600/938 Loss D: -1.8106, loss G: 32.3862
Epoch [1/10] Batch 700/938 Loss D: -2.3506, loss G: 30.3085
Epoch [1/10] Batch 800/938 Loss D: -3.4277, loss G: 23.3088
Epoch [1/10] Batch 900/938 Loss D: -3.3534, loss G: 30.7002
=> Saving checkpoint
=> Saving checkpoint
Epoch [2/10] Batch 0/938 Loss D: -3.7547, loss G: 24.1971
Epoch [2/10] Batch 100/938 Loss D: -2.2234, loss G: 36.9975
Epoch [2/10] Batch 200/938 Loss D: -1.3992, loss G: 27.1289
Epoch [2/10] Batch 300/938 Loss D: -2.2041, loss G: 18.3049
Epoch [2/10] Batch 400/938 Loss D: -2.2271, loss G: 18.0371
Epoch [2/10] Batch 500/938 Loss D: -2.6134, loss G: 26.1062
Ep

## Save models if necessary

In [None]:
# Mount Google Drive so the checkpoints outlive the Colab runtime.
from google.colab import drive
drive.mount('/content/gdrive')

# List the target folder, then copy both checkpoint files into it.
!ls -hal '/content/gdrive/MyDrive/Colab Notebooks/PyTorch tutorial'
!cp $gen_checkpoint_name    '/content/gdrive/MyDrive/Colab Notebooks/PyTorch tutorial'
!cp $critic_checkpoint_name '/content/gdrive/MyDrive/Colab Notebooks/PyTorch tutorial'

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
total 217M
-rw------- 1 root root   32K Mar  1 12:56 '2021.02.26 basics.ipynb'
-rw------- 1 root root   33K Mar  1 12:59 '2021.03.01-1 Pytorch Neural Network example.ipynb'
-rw------- 1 root root   33K Mar  5 08:04 '2021.03.01-2 Convolutional Neural Network example.ipynb'
-rw------- 1 root root   20K Mar 10 10:20 '2021.03.01-3 Recurrent Neural Network example.ipynb'
-rw------- 1 root root   34K Mar  4 22:21 '2021.03.01-4 Bidirectional LSTM example.ipynb'
-rw------- 1 root root   11K Mar  5 14:30 '2021.03.02-1 How to save and load models in Pytorch.ipynb'
-rw------- 1 root root   16K Mar  2 14:21 '2021.03.02-2 Transfer Learning and Fine Tuning.ipynb'
-rw------- 1 root root   49K Mar  5 08:09 '2021.03.02-3 Build custom dataset.ipynb'
-rw------- 1 root root  1.5M Mar 10 08:06 '2021.03.03-1 How to build custom Datasets for Text in Pytorch.ipynb'
-rw------- 1 ro

## Conditional GAN for the cats vs. dogs dataset (did NOT converge)

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

# !ls -hal '/content/gdrive/MyDrive/Colab Notebooks/PyTorch tutorial'
!cp '/content/gdrive/MyDrive/Colab Notebooks/PyTorch tutorial/generator_cats-dogs.pth.tar' '.'
!cp $critic_checkpoint_name '/content/gdrive/MyDrive/Colab Notebooks/PyTorch tutorial/critic_cats-dogs.pth.tar' '.'

Mounted at /content/gdrive


In [None]:
# Download dataset from Kaggle

# Info on how to get your API key (kaggle.json) here:
# https://github.com/Kaggle/kaggle-api#api-credentials

# Install kaggle packages if necessary. Not necessary for Colab
# !pip install -q kaggle
# !pip install -q kaggle-cli

# Colab's file access feature
from google.colab import files

# Upload `kaggle.json` file (opens an interactive upload prompt)
uploaded = files.upload()

# Report the name and size of every uploaded file
for fn in uploaded.keys():
  print('User uploaded file "{name}" with length {length} bytes'.format(
      name=fn, length=len(uploaded[fn])))


# Then copy kaggle.json into the folder where the API expects to find it.
# chmod 600 restricts the key file to the current user.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!ls ~/.kaggle

# Download the dataset
!kaggle competitions download -c dogs-vs-cats
#!kaggle datasets download -d aladdinpersson/cats-dogs-example-with-csv  # private? not visible
#!kaggle datasets list -s aladdinpersson  # show all visible datasets

Saving kaggle.json to kaggle.json
User uploaded file "kaggle.json" with length 65 bytes
kaggle.json
Downloading train.zip to /content
 96% 519M/543M [00:03<00:00, 99.1MB/s]
100% 543M/543M [00:03<00:00, 156MB/s] 
Downloading test1.zip to /content
 94% 254M/271M [00:02<00:00, 132MB/s]
100% 271M/271M [00:02<00:00, 102MB/s]
Downloading sampleSubmission.csv to /content
  0% 0.00/86.8k [00:00<?, ?B/s]
100% 86.8k/86.8k [00:00<00:00, 90.2MB/s]


In [None]:
# Unzip the training archive into the working directory
import zipfile

source_dir = './train'

with zipfile.ZipFile('train.zip', mode='r') as archive:
    archive.extractall(path='.')


# Check it: count the extracted files
train_files = os.listdir(source_dir)
print(f'images number: {len(train_files)}')

images number: 25000


In [None]:
# Create CSV file for labels
import pandas as pd

csv_file = 'cats_dogs.csv'

# The class is the part of the file name before the first dot.
label_of = {'cat': 0, 'dog': 1}

rows = []
for filename in train_files:
    prefix = filename.split('.')[0]
    if prefix in label_of:
        rows.append([filename, label_of[prefix]])
    else:
        print('Error: wrong file name')

cats_dogs = pd.DataFrame(rows, columns=['Filename', 'Label'])
cats_dogs.to_csv(csv_file, index=False)

# Summary: table shape, images per class, and a preview of the first rows
print(cats_dogs.shape)
print(cats_dogs.groupby(by='Label').count())
cats_dogs.head(n=10)

(25000, 2)
       Filename
Label          
0         12500
1         12500


Unnamed: 0,Filename,Label
0,cat.10160.jpg,0
1,cat.10499.jpg,0
2,dog.65.jpg,1
3,dog.3247.jpg,1
4,dog.9660.jpg,1
5,dog.7944.jpg,1
6,cat.6223.jpg,0
7,dog.921.jpg,1
8,cat.9374.jpg,0
9,cat.3559.jpg,0


In [None]:
# Hyperparameters etc.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
LEARNING_RATE = 1e-4  # could also use two lrs, one for gen and one for critic
BATCH_SIZE = 64  # samples per batch
GEN_EMBEDDING = 100  # size of the generator's label embedding
NOISE_DIM = 128  # was 100
CRITIC_ITERATIONS = 5  # critic updates per generator update
LAMBDA_GP = 10  # weight of the gradient penalty in the critic loss

# Hyperparameters that differ from the MNIST run
IMG_SIZE = 64
CHANNELS_IMG = 3  # RGB image
NUM_CLASSES = 2  # Cats vs. Dogs dataset
NUM_EPOCHS = 5
FEATURES_CRITIC = FEATURES_GEN = 64

# Delete previous logs dir
logs_dir = 'logs_animals'
if os.path.exists(logs_dir):
    !rm -rf $logs_dir

# Same preprocessing as for MNIST, rebuilt for the new CHANNELS_IMG
transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),  # resize proportionally to rectangular image
    transforms.RandomCrop(IMG_SIZE),  # crop rectangular image to square
    transforms.ToTensor(),
    transforms.Normalize([0.5 for _ in range(CHANNELS_IMG)],
                         [0.5 for _ in range(CHANNELS_IMG)]),
])


# Create dataset class
class CatsAndDogsDataset(Dataset):
    """Image dataset described by a CSV file of ``(filename, label)`` rows.

    Args:
        csv_file: Path to a CSV whose first column is the image file name and
            whose second column is the integer class label.
        root: Directory that contains the image files.
        transform: Callable applied to the loaded PIL image.
    """

    def __init__(self, csv_file, root, transform):
        self.annotations = pd.read_csv(csv_file)
        self.root = root
        self.transform = transform

    def __len__(self):
        """Return the number of annotated images (one per CSV row)."""
        return len(self.annotations)

    def __getitem__(self, index):
        """Return ``(transformed_image, label_tensor)`` for row ``index``."""
        img_path = os.path.join(self.root, self.annotations.iloc[index, 0])
        # Always hand a 3-channel image to the transform: a grayscale, palette
        # or CMYK file would otherwise yield a tensor whose channel count does
        # not match the RGB pipeline. The context manager closes the file
        # handle as soon as the pixels have been read.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = torch.tensor(int(self.annotations.iloc[index, 1]))
        image = self.transform(image)
        return (image, label)


# Load data
dataset = CatsAndDogsDataset(csv_file=csv_file, root=source_dir, transform=transform)

loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True,
                    num_workers=multiprocessing.cpu_count(), pin_memory=True)

# Fresh networks and optimizers for the 3-channel, 2-class setting
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN, NUM_CLASSES, IMG_SIZE, GEN_EMBEDDING).to(device)
critic = Critic(CHANNELS_IMG, FEATURES_CRITIC, NUM_CLASSES, IMG_SIZE).to(device)
init_weights(gen)
init_weights(critic)

opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))

writer_real = SummaryWriter(os.path.join(logs_dir, 'real'))
writer_fake = SummaryWriter(os.path.join(logs_dir, 'fake'))
step = 0

# Load models
gen_checkpoint_name = 'generator_cats-dogs.pth.tar'
critic_checkpoint_name = 'critic_cats-dogs.pth.tar'

# Resume only when both checkpoints exist. map_location remaps the stored
# tensors onto the current device, so a checkpoint written on a GPU runtime
# can still be loaded on a CPU-only one.
if os.path.exists(gen_checkpoint_name) and os.path.exists(critic_checkpoint_name):
    step = load_checkpoint(torch.load(gen_checkpoint_name, map_location=device), gen, opt_gen)
    step = load_checkpoint(torch.load(critic_checkpoint_name, map_location=device), critic, opt_critic)


def gradient_penalty(critic, labels, real, fake, device):
    """Compute the WGAN-GP gradient penalty on random real/fake interpolations.

    For every sample a coefficient ``epsilon ~ U[0, 1)`` mixes the real and the
    fake image; the penalty is the batch mean of ``(||grad||_2 - 1) ** 2``,
    where ``grad`` is the gradient of the critic score with respect to the
    mixed image.

    Args:
        critic: Callable ``critic(images, labels)`` returning one score per sample.
        labels: Class labels forwarded unchanged to ``critic``.
        real: Real images, ``N x C x H x W``.
        fake: Generated images with the same shape as ``real``. They may or may
            not be attached to an autograd graph.
        device: Device on which the mixing coefficients are created.

    Returns:
        Scalar tensor with the penalty; it is differentiable with respect to
        the critic parameters (``create_graph=True``).
    """
    batch_size = real.shape[0]
    # One coefficient per sample, created directly on the target device and
    # broadcast over C, H and W.
    epsilon = torch.rand((batch_size, 1, 1, 1), device=device)
    interpolated_images = epsilon * real + (1 - epsilon) * fake
    # autograd.grad needs its input inside the graph. When `fake` is detached
    # (or was produced under no_grad) the interpolation is a plain leaf tensor,
    # so gradient tracking has to be enabled explicitly.
    if not interpolated_images.requires_grad:
        interpolated_images.requires_grad_(True)

    # Critic scores of the interpolated images
    mixed_scores = critic(interpolated_images, labels)
    # Gradient of the scores w.r.t. the interpolated images: N x C x H x W
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.reshape(gradient.shape[0], -1)  # N x (C*H*W)
    gradient_norm = gradient.norm(2, dim=1)  # per-sample L2 norm
    return torch.mean((gradient_norm - 1) ** 2)


def train(step=step):
    """Train the conditional WGAN-GP and checkpoint both networks every epoch.

    Uses the notebook-level ``gen``, ``critic``, ``loader``, optimizers,
    TensorBoard writers and hyperparameters. For each batch the critic is
    updated ``CRITIC_ITERATIONS`` times, then the generator once.

    Args:
        step: TensorBoard global step to start from. The default is the value
            of the module-level ``step`` at the time this function is defined.
    """
    gen.train()
    critic.train()

    for epoch in range(NUM_EPOCHS):
        for batch_idx, (real, labels) in enumerate(loader):
            real = real.to(device)
            curr_batch_size = real.shape[0]
            labels = labels.to(device)

            # Train Critic: max (E(critic(real)) - E(critic(fake)))
            # or min (-1) * (E(critic(real)) - E(critic(fake)))
            for _ in range(CRITIC_ITERATIONS):
                noise = torch.randn(curr_batch_size, NOISE_DIM, 1, 1).to(device)
                fake = gen(noise, labels)
                # The critic update must not backpropagate into the generator,
                # so it works on a detached copy. Gradient tracking is
                # re-enabled on that copy because gradient_penalty
                # differentiates the critic w.r.t. the interpolated images.
                fake_for_critic = fake.detach().requires_grad_(True)
                critic_real = critic(real, labels)
                critic_fake = critic(fake_for_critic, labels)
                gp = gradient_penalty(critic, labels, real, fake_for_critic, device)

                loss_critic = torch.mean(critic_fake) - torch.mean(critic_real) + LAMBDA_GP*gp

                opt_critic.zero_grad()
                # The generator graph is untouched here, so no retain_graph is needed.
                loss_critic.backward()
                opt_critic.step()

            # Train Generator: min (E(critic(real)) - E(critic(fake)))
            # which is the same as: min (-1) * E(critic(fake))
            # because the generator cannot influence E(critic(real)).
            # `fake` is the batch from the last critic iteration and is still
            # attached to the generator graph.
            output = critic(fake, labels)
            loss_gen = -torch.mean(output)

            opt_gen.zero_grad()
            loss_gen.backward()
            opt_gen.step()

            # Print losses occasionally and log image grids to TensorBoard
            if batch_idx % 100 == 0:
                print(f'Epoch [{epoch+1}/{NUM_EPOCHS}] '
                    f'Batch {batch_idx}/{len(loader)} '
                    f'Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}')

                with torch.no_grad():
                    fake_preview = gen(noise, labels)
                    # take out (up to) 32 examples
                    img_grid_real = torchvision.utils.make_grid(
                        real[:32], normalize=True
                    )
                    img_grid_fake = torchvision.utils.make_grid(
                        fake_preview[:32], normalize=True
                    )

                    writer_real.add_image('Real', img_grid_real, global_step=step)
                    writer_fake.add_image('Fake', img_grid_fake, global_step=step)

                # The step counts logging events, not batches.
                step += 1

        # Save models at the end of every epoch
        gen_checkpoint = {
            'state_dict': gen.state_dict(),
            'optimizer': opt_gen.state_dict(),
            'step': step,
        }
        critic_checkpoint = {
            'state_dict': critic.state_dict(),
            'optimizer': opt_critic.state_dict(),
            'step': step,
        }
        save_checkpoint(gen_checkpoint, gen_checkpoint_name)
        save_checkpoint(critic_checkpoint, critic_checkpoint_name)

In [None]:
# Run TensorBoard

# Workaround needed when both PyTorch and TensorFlow are installed:
# AttributeError: module 'tensorflow._api.v2.io.gfile' has no attribute 'get_filesystem'
# Point tf.io.gfile at TensorBoard's own stub implementation.
import tensorflow as tf
import tensorboard as tb
tf.io.gfile = tb.compat.tensorflow_stub.io.gfile

# Load the TensorBoard notebook extension
%load_ext tensorboard

# Start TensorBoard before training to monitor it in progress
%tensorboard --logdir $logs_dir

# Reload the TensorBoard extension
%reload_ext tensorboard

In [None]:
# Start (or resume) training from the current TensorBoard step.
train(step=step)

Epoch [1/20] Batch 0/391 Loss D: -12.4652, loss G: 189.9851
Epoch [1/20] Batch 100/391 Loss D: -13.3562, loss G: 196.8530
Epoch [1/20] Batch 200/391 Loss D: -11.8560, loss G: 190.7151
Epoch [1/20] Batch 300/391 Loss D: -11.9301, loss G: 207.8896
=> Saving checkpoint
=> Saving checkpoint
Epoch [2/20] Batch 0/391 Loss D: -11.2401, loss G: 185.9312
Epoch [2/20] Batch 100/391 Loss D: -12.6043, loss G: 182.5322
Epoch [2/20] Batch 200/391 Loss D: -11.4958, loss G: 184.7975
Epoch [2/20] Batch 300/391 Loss D: -6.3614, loss G: 191.2304
=> Saving checkpoint
=> Saving checkpoint
Epoch [3/20] Batch 0/391 Loss D: -7.5468, loss G: 194.0101
Epoch [3/20] Batch 100/391 Loss D: -14.1988, loss G: 193.7846
Epoch [3/20] Batch 200/391 Loss D: -10.7508, loss G: 181.9749
Epoch [3/20] Batch 300/391 Loss D: -7.2492, loss G: 195.2432
=> Saving checkpoint
=> Saving checkpoint
Epoch [4/20] Batch 0/391 Loss D: -10.1171, loss G: 186.9228
Epoch [4/20] Batch 100/391 Loss D: -11.9274, loss G: 197.6188
Epoch [4/20] Batc

In [None]:
# Mount Google Drive so the checkpoints outlive the Colab runtime.
from google.colab import drive
drive.mount('/content/gdrive')

# Checkpoint names as set in the setup cell (kept here for reference):
# gen_checkpoint_name = 'generator_cats-dogs.pth.tar'
# critic_checkpoint_name = 'critic_cats-dogs.pth.tar'

# List the target folder, then copy both checkpoint files into it.
!ls -hal '/content/gdrive/MyDrive/Colab Notebooks/PyTorch tutorial'
!cp $gen_checkpoint_name    '/content/gdrive/MyDrive/Colab Notebooks/PyTorch tutorial'
!cp $critic_checkpoint_name '/content/gdrive/MyDrive/Colab Notebooks/PyTorch tutorial'

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
total 243M
-rw------- 1 root root   32K Mar  1 12:56 '2021.02.26 basics.ipynb'
-rw------- 1 root root   33K Mar  1 12:59 '2021.03.01-1 Pytorch Neural Network example.ipynb'
-rw------- 1 root root   33K Mar  5 08:04 '2021.03.01-2 Convolutional Neural Network example.ipynb'
-rw------- 1 root root   20K Mar 10 10:20 '2021.03.01-3 Recurrent Neural Network example.ipynb'
-rw------- 1 root root   34K Mar  4 22:21 '2021.03.01-4 Bidirectional LSTM example.ipynb'
-rw------- 1 root root   11K Mar  5 14:30 '2021.03.02-1 How to save and load models in Pytorch.ipynb'
-rw------- 1 root root   16K Mar  2 14:21 '2021.03.02-2 Transfer Learning and Fine Tuning.ipynb'
-rw------- 1 root root   49K Mar  5 08:09 '2021.03.02-3 Build custom dataset.ipynb'
-rw------- 1 root root  1.5M Mar 10 08:06 '2021.03.03-1 How to build custom Datasets for Text in Pytorch.ipynb'
-rw------- 1 ro