In [1]:
!pip install pytorch_lightning
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt

import pytorch_lightning as pl

Collecting pytorch_lightning
  Downloading pytorch_lightning-1.5.3-py3-none-any.whl (523 kB)
[K     |████████████████████████████████| 523 kB 5.7 MB/s 
[?25hCollecting fsspec[http]!=2021.06.0,>=2021.05.0
  Downloading fsspec-2021.11.1-py3-none-any.whl (132 kB)
[K     |████████████████████████████████| 132 kB 42.9 MB/s 
[?25hCollecting future>=0.17.1
  Downloading future-0.18.2.tar.gz (829 kB)
[K     |████████████████████████████████| 829 kB 35.2 MB/s 
[?25hCollecting torchmetrics>=0.4.1
  Downloading torchmetrics-0.6.0-py3-none-any.whl (329 kB)
[K     |████████████████████████████████| 329 kB 49.3 MB/s 
[?25hCollecting pyDeprecate==0.3.1
  Downloading pyDeprecate-0.3.1-py3-none-any.whl (10 kB)
Collecting PyYAML>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 43.7 MB/s 
Collecting aiohttp
  Downloading aiohttp-3.8.1-cp37-cp37m-manylinux_2_5_x8

In [2]:
# Hyper-parameters read by the model, loss and training cells below.
parameter = dict(
    z_size=64,        # width of the latent noise vector fed to the Generator
    img_size=784,     # flattened 28x28 image
    hidden_size=32,   # base hidden width; layers use 1x, 2x and 4x this
    output_size=1,
    lr=0.0002,
    n_epoch=100,
    batch_size=100,
    num_workers=0,
)

# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cpu


In [3]:
class Discriminator(nn.Module):
  """MLP that maps a flattened image to a single probability in (0, 1).

  Widths: img_size -> 4h -> 2h -> h -> 1, where h = parameter['hidden_size'].
  Every hidden Linear is followed by LeakyReLU(0.2) and Dropout(0.3).
  """

  def __init__(self):
    super().__init__()
    h = parameter['hidden_size']
    widths = (parameter['img_size'], h * 4, h * 2, h)
    layers = []
    for fan_in, fan_out in zip(widths, widths[1:]):
      layers.extend([nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2), nn.Dropout(0.3)])
    # One sigmoid unit, so the output can be fed straight into nn.BCELoss.
    layers.extend([nn.Linear(h, 1), nn.Sigmoid()])
    self.model = nn.Sequential(*layers)

  def forward(self, x):
    """Score flat inputs x (assumed (N, img_size)); returns an (N, 1) tensor."""
    return self.model(x)

class Generator(nn.Module):
  """MLP that maps a latent noise vector to a flattened image in [-1, 1].

  Widths: z_size -> h -> 2h -> 4h -> img_size, where h = parameter['hidden_size'].
  Every hidden Linear is followed by LeakyReLU(0.2) and Dropout(0.3).
  """

  def __init__(self):
    super().__init__()
    h = parameter['hidden_size']
    hidden_widths = (parameter['z_size'], h, h * 2, h * 4)
    blocks = []
    for fan_in, fan_out in zip(hidden_widths, hidden_widths[1:]):
      blocks += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2), nn.Dropout(0.3)]
    # Tanh bounds outputs to [-1, 1], the range of the Normalize(0.5, 0.5) MNIST images.
    blocks += [nn.Linear(h * 4, parameter['img_size']), nn.Tanh()]
    self.model = nn.Sequential(*blocks)

  def forward(self, x):
    """Map noise x (assumed (N, z_size)) to flat images of shape (N, img_size)."""
    return self.model(x)

In [4]:
# Binary cross-entropy on probabilities; GAN.training_step uses it for both
# the generator and the discriminator losses.
criterion = nn.BCELoss()
# One batch of latent noise, drawn once on `device`.
# NOTE(review): no visible cell reads fixed_z; presumably intended for plotting
# samples from the same noise across epochs — confirm or remove.
fixed_z = torch.randn((parameter['batch_size'], parameter['z_size'])).to(device)

In [5]:
class GAN(pl.LightningModule):
  """Fully-connected GAN on MNIST, trained with Lightning's two-optimizer loop.

  Optimizer 0 (see ``configure_optimizers``) updates the generator ``G`` and
  optimizer 1 updates the discriminator ``D``. Hyper-parameters come from the
  module-level ``parameter`` dict; both losses use the global ``criterion``.

  NOTE(review): ``training_step(..., optimizer_idx)`` relies on automatic
  optimization with multiple optimizers, which pytorch_lightning 2.0 removed,
  so this class needs a 1.x release.
  """

  # Plot a grid of generated samples every this many training batches.
  # 60,000 MNIST training images in batches of 100 makes this once per epoch.
  sample_every_n_batches = 600

  def __init__(self):
    super().__init__()
    # Lightning moves the module (and each batch) to the trainer's device, so
    # the sub-networks are not pinned to the global `device` here; pinning them
    # breaks when CUDA is visible but the Trainer runs on CPU.
    self.G = Generator()
    self.D = Discriminator()

  def forward(self, z):
    """Generate flattened images from latent noise ``z`` (assumed (N, z_size))."""
    return self.G(z)

  def _show_samples(self, z):
    """Plot ``G(z)`` as a grid of 28x28 greyscale images."""
    # Sampling for display needs no autograd graph, and .numpy() raises on
    # grad-tracking or CUDA tensors, so run under no_grad and move to CPU.
    with torch.no_grad():
      fake_images = self.G(z).reshape(-1, 1, 28, 28)
    # Undo Normalize(mean=0.5, std=0.5): Tanh outputs lie in [-1, 1], but
    # make_grid yields 3-channel floats that imshow clips to [0, 1], which
    # would render every negative pixel as black.
    fake_images = fake_images * 0.5 + 0.5
    grid = torchvision.utils.make_grid(fake_images).cpu().numpy()
    plt.figure(figsize=(8, 8))
    plt.imshow(np.transpose(grid, (1, 2, 0)), cmap='Greys_r')
    plt.xticks([])
    plt.yticks([])
    plt.show()

  def training_step(self, batch, batch_idx, optimizer_idx):
    """Return the generator loss for optimizer 0, the discriminator loss for 1.

    Args:
      batch: ``(images, labels)`` from ``train_dataloader``; labels are unused.
      batch_idx: index of the batch within the epoch.
      optimizer_idx: 0 for the generator step, 1 for the discriminator step.
    """
    real_imgs, _ = batch
    # Use the actual batch size, not parameter['batch_size'], so a smaller
    # final batch cannot break the reshape or the label tensors.
    n = real_imgs.size(0)
    real_imgs = real_imgs.view(n, -1)
    # Create noise and targets on whatever device Lightning put the batch on.
    dev = real_imgs.device
    z = torch.randn(n, parameter['z_size'], device=dev)
    real_labels = torch.ones(n, 1, device=dev)

    if optimizer_idx == 0:
      if batch_idx % self.sample_every_n_batches == 0:
        self._show_samples(z)
      # Generator step: push D to label generated images as real
      # (non-saturating generator loss).
      return criterion(self.D(self.G(z)), real_labels)
    elif optimizer_idx == 1:
      # Discriminator step: real -> 1, fake -> 0. The fakes are detached
      # because this optimizer only updates D; no graph back into G is needed.
      fake_imgs = self(z).detach()
      D_real_loss = criterion(self.D(real_imgs), real_labels)
      D_fake_loss = criterion(self.D(fake_imgs), torch.zeros(n, 1, device=dev))
      return D_real_loss + D_fake_loss

  def configure_optimizers(self):
    """Adam for G (optimizer 0) and D (optimizer 1), lr from ``parameter['lr']``."""
    betas = (0.5, 0.999)
    opt_g = optim.Adam(self.G.parameters(), lr=parameter['lr'], betas=betas)
    opt_d = optim.Adam(self.D.parameters(), lr=parameter['lr'], betas=betas)
    return [opt_g, opt_d]

  @staticmethod
  def _mnist_loader(train):
    """Build an MNIST DataLoader whose pixels are scaled to [-1, 1].

    Batch size and worker count come from ``parameter``. The loader shuffles
    only the training split.
    """
    dataset = torchvision.datasets.MNIST(
        root='data', train=train,
        # download=True skips work when the files already exist, so neither
        # split depends on the other loader having been built first.
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            # (x - 0.5) / 0.5 maps ToTensor's [0, 1] pixels to the Tanh range [-1, 1].
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ]))
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=parameter['batch_size'],
        num_workers=parameter['num_workers'],
        shuffle=train)

  def train_dataloader(self):
    """MNIST training split, shuffled."""
    return self._mnist_loader(train=True)

  def val_dataloader(self):
    """MNIST test split, in order.

    NOTE(review): this class defines no validation_step, so Lightning will
    most likely skip (and warn about) this loader — confirm whether it is needed.
    """
    return self._mnist_loader(train=False)




In [6]:
# Seed python/numpy/torch RNGs *before* the model is built so weight init,
# shuffling and noise draws reproduce across Restart & Run All.
pl.seed_everything(42)

trainer = pl.Trainer(
    max_epochs=parameter['n_epoch'],
    # Use the GPU when one is visible, matching the `device` chosen in the config cell.
    gpus=1 if torch.cuda.is_available() else 0,
)
model = GAN()
trainer.fit(model)

Output hidden; open in https://colab.research.google.com to view.