In [None]:
import torch
import torch.nn as nn


In [24]:
class Discriminator(nn.Module):
  """DCGAN-style discriminator: image batch -> per-sample probability of "real".

  Four stride-2 convolutions (kernel 4, padding 1) each halve the spatial size,
  then a kernel-4 / stride-2 / padding-0 convolution followed by a Sigmoid
  yields one score per sample. For 64x64 inputs the output shape is
  (N, 1, 1, 1).

  Args:
    img_channels: number of channels in the input images.
    feature_dims: channel width of the first conv; each block doubles it
      (f -> 2f -> 4f -> 8f).
  """

  def __init__(self, img_channels, feature_dims):
    super().__init__()
    widths = [feature_dims * 2 ** k for k in range(4)]  # f, 2f, 4f, 8f

    # Stem: plain conv (with bias) + LeakyReLU; no normalization on the input layer.
    layers = [
        nn.Conv2d(img_channels, widths[0], 4, 2, 1),
        nn.LeakyReLU(0.2),
    ]
    # Three conv -> BatchNorm -> LeakyReLU downsampling blocks.
    for c_in, c_out in zip(widths, widths[1:]):
      layers.append(self.conv_block(c_in, c_out, 4, 2, 1))
    # Head: collapse to a single channel and squash into (0, 1).
    layers.extend([nn.Conv2d(widths[-1], 1, 4, 2, 0), nn.Sigmoid()])

    # The attribute name and layer order define the state_dict keys
    # ("disc.0.weight", "disc.2.0.weight", ...); keep them stable for checkpoints.
    self.disc = nn.Sequential(*layers)

  def conv_block(self, in_channels, out_channels, kernel, stride, padding):
    """Return a Conv2d (no bias) -> BatchNorm2d -> LeakyReLU(0.2) unit.

    The conv bias is omitted because the BatchNorm that follows applies its own
    learnable shift.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel, stride, padding, bias=False)
    norm = nn.BatchNorm2d(out_channels)
    act = nn.LeakyReLU(0.2)
    return nn.Sequential(conv, norm, act)

  def forward(self, x):
    """Score a batch of images; returns a (N, 1, H', W') tensor of probabilities."""
    return self.disc(x)
    

In [32]:
class Generator(nn.Module):
  """DCGAN-style generator: latent vectors -> images squashed by Tanh into (-1, 1).

  For an input of shape (N, z_dims, 1, 1), the first transposed conv
  (kernel 4, stride 1, padding 0) produces a 4x4 map. Each of the next four
  layers (kernel 4, stride 2, padding 1) doubles the spatial size, so the
  output is (N, img_channels, 64, 64).

  Args:
    z_dims: length of the latent vector (channel dim of the input).
    img_channels: number of channels in the generated images.
    feature_g: base width; hidden layers use 16f -> 8f -> 4f -> 2f channels.
  """

  def __init__(self, z_dims, img_channels, feature_g):
    # nn.Module.__init__ must run before any attribute is set on the module
    # (assigning a submodule or Parameter earlier raises), so call it first.
    super().__init__()
    self.f = feature_g  # kept for backward compatibility; not read by this class
    self.gen = nn.Sequential(
        self.transpose_block(z_dims, feature_g * 16, 4, 1, 0),            # 1 -> 4
        self.transpose_block(feature_g * 16, feature_g * 8, 4, 2, 1),     # 4 -> 8
        self.transpose_block(feature_g * 8, feature_g * 4, 4, 2, 1),      # 8 -> 16
        self.transpose_block(feature_g * 4, feature_g * 2, 4, 2, 1),      # 16 -> 32
        nn.ConvTranspose2d(feature_g * 2, img_channels, 4, 2, 1),         # 32 -> 64
        nn.Tanh(),
    )

  def transpose_block(self, in_channels, out_channels, kernel, stride, padding):
    """Return a ConvTranspose2d (no bias) -> BatchNorm2d -> ReLU upsampling unit.

    The bias is omitted because the following BatchNorm adds its own shift.
    """
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels, kernel,
                           stride, padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU()
    )

  def forward(self, x):
    """Map latent noise of shape (N, z_dims, 1, 1) to generated images."""
    return self.gen(x)


def initialize_weights(model, std=0.02):
  """Apply DCGAN-style initialization to every conv / batch-norm layer in ``model``.

  * Conv2d / ConvTranspose2d weights are drawn from N(0, std); their biases
    (if any) keep PyTorch's default initialization.
  * BatchNorm2d scale (weight) is drawn from N(1, std) and its shift (bias) is
    set to 0, so each normalization layer starts close to the identity.

  Args:
    model: any nn.Module; every submodule returned by ``model.modules()`` is visited.
    std: standard deviation of the normal draws (0.02 in DCGAN).

  Modifies ``model`` in place and returns None.
  """
  for module in model.modules():
    if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
      nn.init.normal_(module.weight, 0.0, std)
    elif isinstance(module, nn.BatchNorm2d) and module.weight is not None:
      # Centering the BN scale on 0 (instead of 1) would shrink every normalized
      # activation to roughly `std` magnitude and starve the next layer of signal.
      # `weight is None` when the layer was built with affine=False.
      nn.init.normal_(module.weight, 1.0, std)
      nn.init.constant_(module.bias, 0.0)

def test():
  """Smoke-test that both networks produce the expected output shapes.

  A 64x64 RGB batch must map to one score per image in the discriminator, and
  100-d latent noise must map back to 64x64 RGB images in the generator.
  Prints "test passed" when both asserts hold.
  """
  batch, channels, height, width = 8, 3, 64, 64
  latent = 100

  images = torch.randn((batch, channels, height, width))
  critic = Discriminator(channels, 8)
  initialize_weights(critic)
  assert critic(images).shape == (batch, 1, 1, 1)

  generator = Generator(latent, channels, 8)
  initialize_weights(generator)
  noise = torch.randn((batch, latent, 1, 1))
  assert generator(noise).shape == (batch, channels, height, width)

  print("test passed")


test()



test passed


In [33]:
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


In [34]:
# Reproducibility: seed torch's global RNG so weight init, noise sampling and the
# default DataLoader shuffle order are identical across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

Lr_rate = 2e-4       # Adam learning rate, shared by generator and discriminator
BATCH_SIZE = 128     # images per training step
IMG_SIZE = 64        # images are resized to this; both networks are built for 64x64
IMG_CHANNELS = 1     # MNIST is single-channel (grayscale)
Z_DIMS = 100         # length of the generator's latent vector
N_EPOCHS = 5
DISC_FEATURES = 64   # base conv width of the discriminator
GEN_FEATURES = 64    # base conv width of the generator




In [38]:
# Preprocessing for MNIST: resize to IMG_SIZE, convert to a float tensor in [0, 1],
# then Normalize(mean=0.5, std=0.5) per channel maps pixels to [-1, 1], matching
# the generator's Tanh output range.
#
# The pipeline is still bound to the name ``transforms`` because the dataset cell
# passes ``transform=transforms``. That binding shadows the ``torchvision.transforms``
# module, so the module is reached through ``torchvision`` here; this keeps the
# cell re-runnable (``transforms.Compose`` would fail on a second run).
_tv_transforms = torchvision.transforms
transforms = _tv_transforms.Compose([
    _tv_transforms.Resize(IMG_SIZE),
    _tv_transforms.ToTensor(),
    _tv_transforms.Normalize(
        [0.5] * IMG_CHANNELS, [0.5] * IMG_CHANNELS
    ),
])

In [39]:
# MNIST training split, downloaded under ./dataset/ on first run and passed
# through the `transforms` pipeline; shuffle=True reshuffles every epoch.
dataset = datasets.MNIST(
    root="dataset/",
    train=True,
    download=True,
    transform=transforms,
)
dataloader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw



In [42]:
# Build both networks (generator first), then re-initialise their conv /
# batch-norm weights with initialize_weights.
gen, disc = (
    Generator(Z_DIMS, IMG_CHANNELS, GEN_FEATURES),
    Discriminator(IMG_CHANNELS, DISC_FEATURES),
)
initialize_weights(gen)
initialize_weights(disc)

In [45]:
# One Adam optimizer per network, with beta1 lowered to 0.5 as in DCGAN.
# The discriminator ends in a Sigmoid, so BCELoss is applied to probabilities.
gen_optim = optim.Adam(
    gen.parameters(),
    betas=(0.5, 0.999),
    lr=Lr_rate,
)
disc_optim = optim.Adam(
    disc.parameters(),
    betas=(0.5, 0.999),
    lr=Lr_rate,
)
criterion = nn.BCELoss()

In [44]:
# 32 latent vectors drawn once, presumably so generator progress can be viewed on
# the same inputs over training (not consumed by the cells shown here).
fixed_noise = torch.randn((32, Z_DIMS, 1, 1))


In [None]:
for epoch in range(N_EPOCHS):
  for batch_idx, (real, _) in enumerate(dataloader):
    # Size the fake batch like the real one: the DataLoader keeps the smaller
    # final batch (drop_last defaults to False), so it is not always BATCH_SIZE.
    noise = torch.randn(real.size(0), Z_DIMS, 1, 1)
    fake = gen(noise)

    # --- Train discriminator: push D(real) -> 1 and D(fake) -> 0 ---
    d_real = disc(real).reshape(-1)
    loss_d_real = criterion(d_real, torch.ones_like(d_real))
    # detach(): the discriminator update must not backprop into the generator.
    # This avoids wasted generator gradients and the need for retain_graph=True.
    d_fake = disc(fake.detach()).reshape(-1)
    loss_d_fake = criterion(d_fake, torch.zeros_like(d_fake))
    lossDisc = (loss_d_real + loss_d_fake) / 2
    disc.zero_grad()
    lossDisc.backward()
    disc_optim.step()

    # --- Train generator: push D(G(z)) -> 1 (non-saturating loss) ---
    # Fresh forward pass through the just-updated discriminator, with gradients
    # flowing back into the generator this time.
    op = disc(fake).reshape(-1)
    g_loss = criterion(op, torch.ones_like(op))
    gen.zero_grad()
    g_loss.backward()
    gen_optim.step()

    if batch_idx % 100 == 0:
      print(f"Epoch: {epoch+1}/{N_EPOCHS}")


                            

