In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn as nn
from torchvision.transforms import transforms
plt.rcParams['image.cmap']='gray'
from torchsummary import summary

In [2]:
train_data=datasets.FashionMNIST(root='../data/',train=True,download=True,transform=transforms.ToTensor())
test_data=datasets.FashionMNIST(root='../data/',train= False,download=True,transform=transforms.ToTensor())

In [3]:
train_iter=DataLoader(train_data,batch_size=100,shuffle=True)
test_iter=DataLoader(test_data,batch_size=100)

In [4]:
for x,y in train_iter:
    print(x.size())
    break

torch.Size([100, 1, 28, 28])


In [5]:
class Reshape(nn.Module):
    def __init__(self,*args):
        super().__init__()
        self.shape=args
    def forward(self,x):
        return x.view(self.shape)

In [6]:
latent_dim=100

In [13]:
class GAN(nn.Module):
    def __init__(self,latent_dim):
        super().__init__()
        self.discriminator=nn.Sequential(
                            Reshape(-1,1,28,28),
                            #64*28*28
                            nn.Conv2d(1,64,kernel_size=2,padding=1,bias=False),
                            nn.BatchNorm2d(64),
                            nn.LeakyReLU(0.2),
                            #100*15*15
                            nn.Conv2d(64,100,kernel_size=2,stride=2,padding=1,bias=False),
                            nn.BatchNorm2d(100),
                            nn.LeakyReLU(0.2),
                            #120*8*8
                            nn.Conv2d(100,120,kernel_size=2,stride=2,padding=1,bias=False),
                            nn.BatchNorm2d(120),
                            nn.LeakyReLU(0.2),
                            nn.Flatten(),
                            nn.Linear(7680,1,bias=False),
                            nn.Sigmoid                          
                                        )
        self.generator=nn.Sequential(
                        nn.Linear(latent_dim,490,bias=False),
                        Reshape(-1,10,7,7),
                        #120*14*14
                        nn.ConvTranspose2d(10,120,kernel_size=4,stride=2,padding=1,bias=False),
                        nn.BatchNorm2d(120),
                        nn.ReLU(),
                        nn.ConvTranspose2d(120,64,kernel_size=3,stride=1,padding=1,bias=False),
                        nn.BatchNorm2d(64),
                        nn.ReLU(),
                        nn.ConvTranspose2d(64,72,kernel_size=4,stride=2,padding=1,bias=False),
                        nn.BatchNorm2d(72),
                        nn.ReLU(0.02),
                        nn.Conv2d(72,1,kernel_size=3,padding=1,bias=False),
                        nn.Tanh()
)
        self.latent_dim=latent_dim
        
    def discriminator_forward(self,x):
        return self.discriminator(x)
    
    def generator_forward(self,x):
        return self.generator(x)

In [14]:
def discriminator_loss(real_output,fake_output):
    real_loss=nn.functional.binary_cross_entropy_with_logits(torch.ones_like(real_output),real_output)
    fake_loss=nn.functional.binary_cross_entropy_with_logits(torch.zeros_like(fake_output), fake_output)
    loss=real_loss + fake_loss
    return loss

def generator_loss( fake_output):
    loss=nn.functional.binary_cross_entropy_with_logits(torch.ones_like(fake_output), fake_output)
    return loss

In [15]:
gan=GAN(latent_dim=latent_dim)
x=torch.rand((1,1,28,28))
summary(gan.discriminator,x)

Layer (type:depth-idx)                   Output Shape              Param #
├─Reshape: 1-1                           [-1, 1, 28, 28]           --
├─Conv2d: 1-2                            [-1, 64, 29, 29]          256
├─BatchNorm2d: 1-3                       [-1, 64, 29, 29]          128
├─LeakyReLU: 1-4                         [-1, 64, 29, 29]          --
├─Conv2d: 1-5                            [-1, 100, 15, 15]         25,600
├─BatchNorm2d: 1-6                       [-1, 100, 15, 15]         200
├─LeakyReLU: 1-7                         [-1, 100, 15, 15]         --
├─Conv2d: 1-8                            [-1, 120, 8, 8]           48,000
├─BatchNorm2d: 1-9                       [-1, 120, 8, 8]           240
├─LeakyReLU: 1-10                        [-1, 120, 8, 8]           --
├─Flatten: 1-11                          [-1, 7680]                --
├─Linear: 1-12                           [-1, 1]                   7,680
├─Tanh: 1-13                             [-1, 1]                   --


Layer (type:depth-idx)                   Output Shape              Param #
├─Reshape: 1-1                           [-1, 1, 28, 28]           --
├─Conv2d: 1-2                            [-1, 64, 29, 29]          256
├─BatchNorm2d: 1-3                       [-1, 64, 29, 29]          128
├─LeakyReLU: 1-4                         [-1, 64, 29, 29]          --
├─Conv2d: 1-5                            [-1, 100, 15, 15]         25,600
├─BatchNorm2d: 1-6                       [-1, 100, 15, 15]         200
├─LeakyReLU: 1-7                         [-1, 100, 15, 15]         --
├─Conv2d: 1-8                            [-1, 120, 8, 8]           48,000
├─BatchNorm2d: 1-9                       [-1, 120, 8, 8]           240
├─LeakyReLU: 1-10                        [-1, 120, 8, 8]           --
├─Flatten: 1-11                          [-1, 7680]                --
├─Linear: 1-12                           [-1, 1]                   7,680
├─Tanh: 1-13                             [-1, 1]                   --


In [16]:
x=torch.rand((1,latent_dim))
summary(gan.generator,x)

Layer (type:depth-idx)                   Output Shape              Param #
├─Linear: 1-1                            [-1, 490]                 49,000
├─Reshape: 1-2                           [-1, 10, 7, 7]            --
├─ConvTranspose2d: 1-3                   [-1, 120, 14, 14]         19,200
├─BatchNorm2d: 1-4                       [-1, 120, 14, 14]         240
├─ReLU: 1-5                              [-1, 120, 14, 14]         --
├─ConvTranspose2d: 1-6                   [-1, 64, 14, 14]          69,120
├─BatchNorm2d: 1-7                       [-1, 64, 14, 14]          128
├─ReLU: 1-8                              [-1, 64, 14, 14]          --
├─ConvTranspose2d: 1-9                   [-1, 72, 28, 28]          73,728
├─BatchNorm2d: 1-10                      [-1, 72, 28, 28]          144
├─ReLU: 1-11                             [-1, 72, 28, 28]          --
├─Conv2d: 1-12                           [-1, 1, 28, 28]           648
├─Tanh: 1-13                             [-1, 1, 28, 28]         

Layer (type:depth-idx)                   Output Shape              Param #
├─Linear: 1-1                            [-1, 490]                 49,000
├─Reshape: 1-2                           [-1, 10, 7, 7]            --
├─ConvTranspose2d: 1-3                   [-1, 120, 14, 14]         19,200
├─BatchNorm2d: 1-4                       [-1, 120, 14, 14]         240
├─ReLU: 1-5                              [-1, 120, 14, 14]         --
├─ConvTranspose2d: 1-6                   [-1, 64, 14, 14]          69,120
├─BatchNorm2d: 1-7                       [-1, 64, 14, 14]          128
├─ReLU: 1-8                              [-1, 64, 14, 14]          --
├─ConvTranspose2d: 1-9                   [-1, 72, 28, 28]          73,728
├─BatchNorm2d: 1-10                      [-1, 72, 28, 28]          144
├─ReLU: 1-11                             [-1, 72, 28, 28]          --
├─Conv2d: 1-12                           [-1, 1, 28, 28]           648
├─Tanh: 1-13                             [-1, 1, 28, 28]         

In [17]:

g_optimizer=torch.optim.Adam(gan.generator.parameters())
d_optimizer=torch.optim.Adam(gan.discriminator.parameters())

In [None]:
epochs=1
for epoch in range(epochs):
    gan=gan.train()
    for x,_ in train_iter:
        x=(x-0.5)*2
        batch_size=x.size()[0]
        
        z_random_latent_vectors=torch.randn((batch_size,latent_dim))
        generated_imgs=gan.generator_forward(z_random_latent_vectors)
        fake_output=gan.discriminator_forward(generated_imgs)
        real_output=gan.discriminator_forward(x)
        
        d_optimizer.zero_grad()
        d_loss=discriminator_loss(real_output=real_output,fake_output= fake_output)
        d_loss.backward()
        d_optimizer.step()
        
        g_optimizer.zero_grad()
        #z_random_latent_vectors=torch.randn((batch_size,latent_dim))
        generated_imgs=gan.generator_forward(z_random_latent_vectors)
        fake_output=gan.discriminator_forward(generated_imgs)
        g_loss=generator_loss(fake_output)
        g_loss.backward()
        g_optimizer.step()
        
        print("Epoch %d: d_loss %.3f,  g_loss %.3f" % (epoch+1, d_loss.detach().numpy().mean(),
                                                       g_loss.detach().numpy().mean()))
        

  **REFERENCE**

- [Generative Adversarial Networks](http://d2l.ai/chapter_generative-adversarial-networks/gan.html)

- [Generative Adversarial Nets](https://arxiv.org/pdf/1406.2661.pdf)

- [NIPS 2016 Tutorial: Generative Adversarial Networks](https://arxiv.org/pdf/1701.00160.pdf)