In [68]:
!pip install torch torchvision matplotlib imageio

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


In [69]:
import os

import imageio
import matplotlib
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torchvision.utils import make_grid, save_image
from tqdm import tqdm

matplotlib.style.use('ggplot')

In [70]:
# learning parameters
batch_size = 512   # samples per training batch
epochs = 200       # number of passes over the training set
sample_size = 64   # number of latent vectors in the fixed batch used for per-epoch sample images
nz = 128           # latent vector size
k = 1              # number of steps to apply to the discriminator per generator step
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed PyTorch's random number generators so weight initialisation, data shuffling
# and noise sampling start from the same state on every "Restart & Run All".
seed = 42
torch.manual_seed(seed)

In [71]:
# Preprocessing applied to every dataset image: convert to a tensor, then
# normalise the single channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Helper for turning image tensors back into PIL images.
to_pil_image = transforms.ToPILImage()

In [72]:
# MNIST training split, downloaded into the data root if it is not already there;
# each image is passed through `transform`.
train_data = datasets.MNIST(root='../input/data', train=True, download=True, transform=transform)

# Shuffled mini-batches of `batch_size` images.
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)


In [73]:
class Generator(nn.Module):
    """Fully connected generator: latent vectors of size `nz` -> 1x28x28 images.

    The network widens through 256, 512 and 1024 hidden units (LeakyReLU with
    slope 0.2 after each) and ends in a 784-unit layer followed by Tanh.
    """

    def __init__(self, nz):
        """Build the layer stack for latent vectors of length `nz`."""
        super().__init__()
        self.nz = nz

        widths = [self.nz, 256, 512, 1024]
        layers = []
        # Hidden stages: Linear followed by LeakyReLU(0.2).
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.LeakyReLU(0.2))
        # Output stage: 784 values squashed by Tanh.
        layers.append(nn.Linear(widths[-1], 784))
        layers.append(nn.Tanh())

        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Run `x` through the stack and reshape the result to (-1, 1, 28, 28)."""
        flat = self.main(x)
        return flat.view(-1, 1, 28, 28)

In [74]:
class Discriminator(nn.Module):
    """Fully connected discriminator: 784-pixel images -> one sigmoid score each.

    Three hidden stages (1024, 512, 256 units), each Linear -> LeakyReLU(0.2) ->
    Dropout(0.3), followed by a single-unit Linear layer and Sigmoid.
    """

    def __init__(self):
        """Build the layer stack."""
        super().__init__()
        self.n_input = 784

        layers = []
        n_in = self.n_input
        for n_out in (1024, 512, 256):
            layers += [nn.Linear(n_in, n_out), nn.LeakyReLU(0.2), nn.Dropout(0.3)]
            n_in = n_out
        layers += [nn.Linear(n_in, 1), nn.Sigmoid()]

        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten `x` to rows of 784 values and return the per-row score."""
        flat = x.view(-1, 784)
        return self.main(flat)

In [75]:
# Build both networks (generator first) and move them to the selected device.
generator = Generator(nz).to(device)
discriminator = Discriminator().to(device)

# Print each architecture between banner lines.
print('##### GENERATOR #####', generator, '######################', sep='\n')
print('\n##### DISCRIMINATOR #####', discriminator, '######################', sep='\n')

##### GENERATOR #####
Generator(
  (main): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Linear(in_features=256, out_features=512, bias=True)
    (3): LeakyReLU(negative_slope=0.2)
    (4): Linear(in_features=512, out_features=1024, bias=True)
    (5): LeakyReLU(negative_slope=0.2)
    (6): Linear(in_features=1024, out_features=784, bias=True)
    (7): Tanh()
  )
)
######################

##### DISCRIMINATOR #####
Discriminator(
  (main): Sequential(
    (0): Linear(in_features=784, out_features=1024, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=1024, out_features=512, bias=True)
    (4): LeakyReLU(negative_slope=0.2)
    (5): Dropout(p=0.3, inplace=False)
    (6): Linear(in_features=512, out_features=256, bias=True)
    (7): LeakyReLU(negative_slope=0.2)
    (8): Dropout(p=0.3, inplace=False)
    (9): Linear(in_features=256, out_features=1

In [76]:
# Binary cross-entropy between discriminator scores and 0/1 targets.
criterion = nn.BCELoss()

# One Adam optimizer per network, both with the same learning rate.
optim_g = optim.Adam(generator.parameters(), lr=2e-4)
optim_d = optim.Adam(discriminator.parameters(), lr=2e-4)

# Per-epoch history: generator loss, discriminator loss, and the image grid
# produced by the generator.
losses_g, losses_d, images = [], [], []

In [77]:
# targets for samples that should be scored as real (all ones)
def label_real(size):
    """Return a (size, 1) tensor of ones located on `device`."""
    return torch.ones(size, 1, device=device)
# targets for samples that should be scored as fake (all zeros)
def label_fake(size):
    """Return a (size, 1) tensor of zeros located on `device`."""
    return torch.zeros(size, 1, device=device)

# latent input for the generator
def create_noise(sample_size, nz):
    """Return a (sample_size, nz) tensor of standard-normal noise on `device`."""
    latent = torch.randn(sample_size, nz)
    return latent.to(device)

# to save the images generated by the generator
def save_generator_image(image, path):
    """Write `image` to `path` by delegating to torchvision's `save_image`.

    Args:
        image: image tensor (in this notebook, a grid built with `make_grid`).
        path: destination file path.
    """
    save_image(image, path)

In [78]:
# function to train the discriminator network
def train_discriminator(optimizer, data_real, data_fake):
    """Run one optimisation step of the discriminator.

    Real samples are scored against all-ones targets and fake samples against
    all-zeros targets; both losses are back-propagated before a single
    optimizer step.

    Args:
        optimizer: optimizer holding the discriminator's parameters.
        data_real: batch of real images.
        data_fake: batch of generated images.

    Returns:
        The sum of the real and fake losses as a tensor detached from the
        autograd graph, so callers can accumulate or store it without keeping
        graph nodes alive or ending up with values that require grad.
    """
    # Size each target from the batch it is compared with, so the two batches
    # are not required to have the same length.
    real_label = label_real(data_real.size(0))
    fake_label = label_fake(data_fake.size(0))

    optimizer.zero_grad()

    output_real = discriminator(data_real)
    loss_real = criterion(output_real, real_label)

    # detach(): the discriminator step must never back-propagate into the
    # generator, even if the caller passes a batch still attached to it.
    output_fake = discriminator(data_fake.detach())
    loss_fake = criterion(output_fake, fake_label)

    loss_real.backward()
    loss_fake.backward()
    optimizer.step()

    return (loss_real + loss_fake).detach()

In [79]:
# function to train the generator network
def train_generator(optimizer, data_fake):
    """Run one optimisation step of the generator.

    The generated batch is scored by the discriminator and compared against
    all-ones ("real") targets, so the loss falls as the discriminator is fooled.

    Args:
        optimizer: optimizer holding the generator's parameters.
        data_fake: generated batch, still attached to the generator's graph so
            that gradients can flow back into the generator.

    Returns:
        The generator loss as a tensor detached from the autograd graph, so
        callers can accumulate or store it without keeping graph nodes alive
        or ending up with values that require grad.
    """
    b_size = data_fake.size(0)
    real_label = label_real(b_size)

    optimizer.zero_grad()
    output = discriminator(data_fake)
    loss = criterion(output, real_label)
    loss.backward()
    optimizer.step()

    return loss.detach()

In [80]:
# create the noise vector
# A single fixed batch of `sample_size` latent vectors, created once so the sample
# grid rendered after every epoch comes from the same inputs.
noise = create_noise(sample_size, nz)
# Put both networks in training mode (this enables the discriminator's Dropout layers).
generator.train()
# Last expression of the cell: its return value is what the cell displays.
discriminator.train()

Discriminator(
  (main): Sequential(
    (0): Linear(in_features=784, out_features=1024, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=1024, out_features=512, bias=True)
    (4): LeakyReLU(negative_slope=0.2)
    (5): Dropout(p=0.3, inplace=False)
    (6): Linear(in_features=512, out_features=256, bias=True)
    (7): LeakyReLU(negative_slope=0.2)
    (8): Dropout(p=0.3, inplace=False)
    (9): Linear(in_features=256, out_features=1, bias=True)
    (10): Sigmoid()
  )
)

In [81]:
!mkdir -p mnist_gans/outputs
!ls -la

total 4169
drwxr-xr-x 5 root root       8 Sep 25 21:55 .
drwxr-xr-x 8 root root      13 Sep 25 21:44 ..
drwxr-xr-x 2 root root       2 Sep 25 21:28 .ipynb_checkpoints
-rw-r--r-- 1 root root   33026 Sep 23 16:02 MNIST-Digit-Classification_using_run_img_cls.ipynb
-rw-r--r-- 1 root root  169272 Sep 23 18:10 MNIST-Full-Data-Analysis.ipynb
-rw-r--r-- 1 root root   21706 Sep 25 21:55 MNIST-Generative_Adversarial_Networks.ipynb
-rw-r--r-- 1 root root   62198 Sep 23 16:02 Roman-Numeral-Classification_using_run_img_cls.ipynb
-rw-r--r-- 1 root root 3980882 Sep 13 23:52 TSNE_On_MNIST.ipynb
drwxr-xr-x 3 root root       1 Sep 25 21:39 mnist_data
drwxr-xr-x 4 root root       2 Sep 25 21:55 mnist_gans


In [None]:
# Number of batches per epoch. len(train_loader) counts the final partial batch,
# which int(len(train_data) / batch_size) drops.
n_batches = len(train_loader)

for epoch in range(epochs):
    # Running sums kept as Python floats so no autograd graph or device tensor
    # is carried from batch to batch.
    loss_g = 0.0
    loss_d = 0.0
    for image, _ in tqdm(train_loader, total=n_batches):
        image = image.to(device)
        b_size = len(image)
        # run the discriminator for k number of steps
        for _step in range(k):
            # detach(): the discriminator update must not reach the generator
            data_fake = generator(create_noise(b_size, nz)).detach()
            data_real = image
            # train the discriminator network
            loss_d += float(train_discriminator(optim_d, data_real, data_fake))
        # fresh fake batch, left attached so gradients reach the generator
        data_fake = generator(create_noise(b_size, nz))
        # train the generator network
        loss_g += float(train_generator(optim_g, data_fake))

    # create the final fake image for the epoch from the fixed noise batch;
    # no gradients are needed for sampling
    with torch.no_grad():
        generated_img = generator(noise).cpu()
    # The generator ends in Tanh, so pixels lie in [-1, 1]. Map them to [0, 1]
    # (the inverse of the dataset normalisation with mean 0.5 / std 0.5) before
    # building the grid, so dark pixels are not clipped when the image is saved
    # or converted to PIL.
    generated_img = make_grid(generated_img * 0.5 + 0.5)
    # save the generated grid to disk
    save_generator_image(generated_img, f"mnist_gans/outputs/gen_img{epoch}.png")
    images.append(generated_img)

    # Mean loss per batch: divide by the batch count, not the last batch index.
    epoch_loss_g = loss_g / n_batches
    epoch_loss_d = loss_d / n_batches
    losses_g.append(epoch_loss_g)
    losses_d.append(epoch_loss_d)

    print(f"Epoch {epoch} of {epochs}")
    print(f"Generator loss: {epoch_loss_g:.8f}, Discriminator loss: {epoch_loss_d:.8f}")

118it [00:15,  7.60it/s]                         


Epoch 0 of 200
Generator loss: 1.32530153, Discriminator loss: 0.93339187


118it [00:15,  7.43it/s]                         


Epoch 1 of 200
Generator loss: 2.10442781, Discriminator loss: 1.24310565


118it [00:15,  7.50it/s]                         


Epoch 2 of 200
Generator loss: 3.32916188, Discriminator loss: 0.84761894


118it [00:15,  7.53it/s]                         


Epoch 3 of 200
Generator loss: 2.71348262, Discriminator loss: 1.01204181


118it [00:15,  7.67it/s]                         


Epoch 4 of 200
Generator loss: 2.28597641, Discriminator loss: 1.26122856


118it [00:16,  7.37it/s]                         


Epoch 5 of 200
Generator loss: 1.37975395, Discriminator loss: 1.04801834


118it [00:16,  7.35it/s]                         


Epoch 6 of 200
Generator loss: 1.22216094, Discriminator loss: 1.09545517


118it [00:15,  7.51it/s]                         


Epoch 7 of 200
Generator loss: 1.34024346, Discriminator loss: 1.28265405


118it [00:16,  7.31it/s]                         


Epoch 8 of 200
Generator loss: 1.25383067, Discriminator loss: 1.13575017


118it [00:16,  7.30it/s]                         


Epoch 9 of 200
Generator loss: 1.17950213, Discriminator loss: 1.07224154


118it [00:16,  7.22it/s]                         


Epoch 10 of 200
Generator loss: 1.52006304, Discriminator loss: 1.24278867


118it [00:16,  7.26it/s]                         


Epoch 11 of 200
Generator loss: 1.21575439, Discriminator loss: 1.21784937


118it [00:16,  7.37it/s]                         


Epoch 12 of 200
Generator loss: 1.58386827, Discriminator loss: 1.17667413


118it [00:16,  7.24it/s]                         


Epoch 13 of 200
Generator loss: 1.92653215, Discriminator loss: 1.29030228


118it [00:16,  7.14it/s]                         


Epoch 14 of 200
Generator loss: 2.03188300, Discriminator loss: 1.09004402


118it [00:16,  7.21it/s]                         


Epoch 15 of 200
Generator loss: 1.07607806, Discriminator loss: 1.14402461


118it [00:16,  7.29it/s]                         


Epoch 16 of 200
Generator loss: 1.11665690, Discriminator loss: 1.16225696


118it [00:16,  7.33it/s]                         


Epoch 17 of 200
Generator loss: 2.14706659, Discriminator loss: 0.82187754


118it [00:16,  7.10it/s]                         


Epoch 18 of 200
Generator loss: 1.44348812, Discriminator loss: 1.13707221


118it [00:16,  7.29it/s]                         


Epoch 19 of 200
Generator loss: 1.55810404, Discriminator loss: 1.26635599


118it [00:16,  7.28it/s]                         


Epoch 20 of 200
Generator loss: 1.37957549, Discriminator loss: 1.02946413


118it [00:16,  7.17it/s]                         


Epoch 21 of 200
Generator loss: 1.36409819, Discriminator loss: 1.10331619


118it [00:16,  7.30it/s]                         


Epoch 22 of 200
Generator loss: 1.32062542, Discriminator loss: 1.00502944


118it [00:16,  7.19it/s]                         


Epoch 23 of 200
Generator loss: 2.09034538, Discriminator loss: 0.84948105


118it [00:16,  7.22it/s]                         


Epoch 24 of 200
Generator loss: 1.55064034, Discriminator loss: 1.09854949


118it [00:16,  7.12it/s]                         


Epoch 25 of 200
Generator loss: 2.19465947, Discriminator loss: 0.68284982


118it [00:16,  7.07it/s]                         


Epoch 26 of 200
Generator loss: 2.37272167, Discriminator loss: 0.82395840


118it [00:16,  7.17it/s]                         


Epoch 27 of 200
Generator loss: 1.91850281, Discriminator loss: 0.70356065


118it [00:16,  7.11it/s]                         


Epoch 28 of 200
Generator loss: 2.80933928, Discriminator loss: 0.68774289


118it [00:16,  7.15it/s]                         


Epoch 29 of 200
Generator loss: 2.31559014, Discriminator loss: 0.57359570


 77%|███████▋  | 90/117 [00:12<00:03,  6.91it/s]

In [None]:
print('DONE TRAINING')
# Persist the trained generator weights.
torch.save(generator.state_dict(), 'mnist_gans/outputs/generator.pth')

# save the generated images as GIF file
# clamp(0, 1): ToPILImage scales float tensors by 255 and casts to uint8, so any
# value outside [0, 1] would wrap around instead of saturating.
imgs = [np.array(to_pil_image(img.clamp(0, 1))) for img in images]
imageio.mimsave('mnist_gans/outputs/generator_images.gif', imgs)

# plot and save the generator and discriminator loss
# float(): the history entries may be tensors (possibly on the GPU or requiring
# grad), which matplotlib cannot convert to NumPy; plain floats always work.
history_g = [float(loss) for loss in losses_g]
history_d = [float(loss) for loss in losses_d]

fig, ax = plt.subplots()
ax.plot(history_g, label='Generator loss')
ax.plot(history_d, label='Discriminator Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()
fig.savefig('mnist_gans/outputs/loss.png')