# **Import Libraries**

In [1]:
import warnings

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
from torchvision import transforms
from torchvision.datasets import MNIST

from torch.utils.data import DataLoader

from tqdm.auto import tqdm

# **Suppress Warning Messages**

In [2]:
warnings.filterwarnings('ignore')

# **Define Important Hyperparameters**

In [3]:
# Training configuration.
# Seeding here makes weight initialisation, DataLoader shuffling and noise draws
# in later cells repeatable. Re-running this cell resets the RNG.
seed = 0
torch.manual_seed(seed)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 32      # minibatch size used by the DataLoader
lr = 3e-4            # Adam learning rate for both networks
z_dim = 64           # length of the generator's latent noise vector
img_dim = 28 * 28    # flattened MNIST image size (28x28 pixels)
num_epochs = 80

# Fixed latent batch used to visualise generator progress across epochs.
# It must come from the same distribution as the training noise, which is
# standard normal (`torch.randn`). Uniform U[0, 1) noise (`torch.rand`) would
# show the generator inputs it never sees during training.
fixed_noise = torch.randn(batch_size, z_dim, device=device)

# **Create Discriminator**

In [4]:
class Discriminator(nn.Module):
  """MLP discriminator that scores flattened images as real vs. fake.

  Args:
    in_features: length of each flattened input vector.

  Architecture:
    - Three hidden Linear + LeakyReLU(0.01) layers of width 512 -> 256 -> 128.
    - A single Linear unit followed by Sigmoid, so each sample's score lies in [0, 1].
  """

  def __init__(self, in_features):
    super().__init__()

    widths = [in_features, 512, 256, 128]
    layers = []
    for fan_in, fan_out in zip(widths, widths[1:]):
      layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(.01)]
    # Scoring head: one output unit squashed to a probability.
    layers += [nn.Linear(widths[-1], 1), nn.Sigmoid()]

    self.net = nn.Sequential(*layers)

  def forward(self, x):
    """Map inputs of shape (..., in_features) to scores of shape (..., 1)."""
    return self.net(x)

# **Create Generator**

In [5]:
class Generator(nn.Module):
  """MLP generator that maps a latent noise vector to a flattened image.

  Args:
    z_dim: length of the input noise vector.
    img_dim: length of the flattened output image.

  Architecture:
    - Three hidden Linear + LeakyReLU(0.01) layers of width 128 -> 256 -> 512.
    - A final Linear layer followed by Tanh, so every output value lies in [-1, 1].
  """

  def __init__(self, z_dim, img_dim):
    super().__init__()

    hidden_sizes = (128, 256, 512)
    blocks = []
    prev = z_dim
    for size in hidden_sizes:
      blocks.extend((nn.Linear(prev, size), nn.LeakyReLU(.01)))
      prev = size
    # Output layer: Tanh bounds pixels to the same [-1, 1] range as normalised data.
    blocks.extend((nn.Linear(prev, img_dim), nn.Tanh()))

    self.net = nn.Sequential(*blocks)

  def forward(self, x):
    """Map noise of shape (..., z_dim) to images of shape (..., img_dim)."""
    return self.net(x)

# **Generate Dataloader**

### **Define Transforms**

In [6]:
# Preprocessing for MNIST.
# ToTensor scales pixels to [0, 1]. Normalize with mean 0.5 / std 0.5 then
# maps them to [-1, 1], matching the generator's Tanh output range, so real
# and generated images share the same scale.
#
# NOTE: the result is bound to the name `transforms`, which shadows the
# `torchvision.transforms` module imported by `from torchvision import transforms`.
# The fully qualified `torchvision.transforms.*` references keep this cell safe
# to re-run. With the bare module name, a second run fails because `transforms`
# is then a Compose object with no `.Compose` attribute.
transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((.5, ), (.5, ))
    ]
)

### **Create Dataloader**

In [7]:
# Iterate over the MNIST dataset (downloaded into the working directory if
# absent). Samples are preprocessed by `transforms` and reshuffled each time
# a new iterator is created (i.e. every epoch).
dataloader = DataLoader(
    MNIST('.', download=True, transform=transforms),
    batch_size=batch_size,
    shuffle=True,
)

# **Training Loop**

In [8]:
# Build both networks and their optimizers, then alternate one discriminator
# update and one generator update per minibatch.
disc = Discriminator(img_dim).to(device)
gen  = Generator(z_dim, img_dim).to(device)

opt_disc = optim.Adam(disc.parameters(), lr = lr)
opt_gen  = optim.Adam(gen.parameters(),  lr = lr)

# The discriminator ends in Sigmoid, so BCE on probabilities is the matching loss.
criterion = nn.BCELoss()

for epoch in range(num_epochs):
  for real, _ in tqdm(dataloader):

    # Use a local name for the current batch size rather than reassigning the
    # global `batch_size` hyperparameter. Clobbering it would leak into any
    # cell re-run after training.
    cur_batch_size = real.shape[0]
    # Flatten each image to a vector of img_dim features for the MLP.
    real = real.view(cur_batch_size, -1).to(device)

    noise = torch.randn(cur_batch_size, z_dim, device=device)
    fake  = gen(noise)

    # --- Discriminator step: push D(real) -> 1 and D(G(z)) -> 0 ---
    # `fake.detach()` keeps this backward pass out of the generator's graph.
    # No gradients are wasted on `gen`, and `retain_graph` is not needed.
    disc_real = disc(real).view(-1)
    disc_fake = disc(fake.detach()).view(-1)

    lossD_real = criterion(disc_real, torch.ones_like(disc_real))
    lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
    lossD = (lossD_real + lossD_fake) / 2

    disc.zero_grad()
    lossD.backward()
    opt_disc.step()

    # --- Generator step: non-saturating loss, -log D(G(z)) ---
    # Re-score the same (non-detached) fakes with the just-updated discriminator.
    output = disc(fake).view(-1)
    lossG = criterion(output, torch.ones_like(output))

    # The gradients lossG leaves on `disc` are cleared by disc.zero_grad()
    # at the start of the next discriminator step.
    gen.zero_grad()
    lossG.backward()
    opt_gen.step()

  # Losses from the last minibatch of the epoch.
  print(f"epoch{epoch + 1}:  lossD:{lossD.item() : .4f}   lossG:{lossG.item() : .4f}")

  # Snapshot generator samples on the fixed noise, plus one real batch for comparison.
  with torch.no_grad():
    fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
    real = next(iter(dataloader))[0].reshape(-1, 1, 28, 28).to(device)

    img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
    img_grid_real = torchvision.utils.make_grid(real, normalize=True)

    # Fixed file names: each epoch overwrites the previous epoch's grids.
    torchvision.utils.save_image(img_grid_fake, "fake_grid.png")
    torchvision.utils.save_image(img_grid_real, "real_grid.png")

  0%|          | 0/1875 [00:00<?, ?it/s]

epoch1:  lossD: 0.4002   lossG: 1.7359


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch2:  lossD: 0.0880   lossG: 4.9819


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch3:  lossD: 0.1354   lossG: 3.4863


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch4:  lossD: 0.2355   lossG: 4.7896


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch5:  lossD: 0.1889   lossG: 2.9721


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch6:  lossD: 0.1143   lossG: 3.3036


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch7:  lossD: 0.2507   lossG: 2.1241


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch8:  lossD: 0.3011   lossG: 2.9819


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch9:  lossD: 0.2400   lossG: 1.5062


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch10:  lossD: 0.3310   lossG: 1.7306


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch11:  lossD: 0.3893   lossG: 2.2779


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch12:  lossD: 0.3012   lossG: 2.6690


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch13:  lossD: 0.4817   lossG: 1.3637


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch14:  lossD: 0.3233   lossG: 1.5443


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch15:  lossD: 0.3344   lossG: 1.6006


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch16:  lossD: 0.3493   lossG: 1.6248


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch17:  lossD: 0.3648   lossG: 1.5701


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch18:  lossD: 0.3210   lossG: 2.0041


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch19:  lossD: 0.4279   lossG: 1.5673


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch20:  lossD: 0.4844   lossG: 1.6313


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch21:  lossD: 0.5581   lossG: 1.3653


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch22:  lossD: 0.4218   lossG: 1.4906


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch23:  lossD: 0.3604   lossG: 1.5289


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch24:  lossD: 0.6130   lossG: 0.9807


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch25:  lossD: 0.4410   lossG: 1.5409


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch26:  lossD: 0.4296   lossG: 1.3694


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch27:  lossD: 0.4539   lossG: 1.6923


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch28:  lossD: 0.4531   lossG: 1.5547


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch29:  lossD: 0.5237   lossG: 1.3482


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch30:  lossD: 0.4817   lossG: 1.4752


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch31:  lossD: 0.3787   lossG: 1.4406


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch32:  lossD: 0.4336   lossG: 1.7492


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch33:  lossD: 0.5035   lossG: 1.4718


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch34:  lossD: 0.5649   lossG: 1.4187


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch35:  lossD: 0.4325   lossG: 1.4811


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch36:  lossD: 0.4636   lossG: 1.3772


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch37:  lossD: 0.4152   lossG: 1.3216


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch38:  lossD: 0.3744   lossG: 1.5346


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch39:  lossD: 0.4127   lossG: 1.5888


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch40:  lossD: 0.3769   lossG: 1.5025


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch41:  lossD: 0.4284   lossG: 1.6088


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch42:  lossD: 0.4977   lossG: 1.5794


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch43:  lossD: 0.4926   lossG: 1.1737


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch44:  lossD: 0.5183   lossG: 1.3605


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch45:  lossD: 0.4680   lossG: 1.3374


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch46:  lossD: 0.4482   lossG: 1.4099


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch47:  lossD: 0.5158   lossG: 1.3424


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch48:  lossD: 0.3730   lossG: 1.5950


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch49:  lossD: 0.3753   lossG: 1.5430


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch50:  lossD: 0.3974   lossG: 1.3161


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch51:  lossD: 0.4214   lossG: 1.5359


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch52:  lossD: 0.4596   lossG: 1.6602


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch53:  lossD: 0.3619   lossG: 1.9528


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch54:  lossD: 0.4667   lossG: 0.9834


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch55:  lossD: 0.3610   lossG: 1.4581


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch56:  lossD: 0.4772   lossG: 1.3829


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch57:  lossD: 0.4740   lossG: 1.5397


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch58:  lossD: 0.5510   lossG: 1.4935


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch59:  lossD: 0.3714   lossG: 1.2924


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch60:  lossD: 0.5173   lossG: 1.5713


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch61:  lossD: 0.3530   lossG: 1.5857


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch62:  lossD: 0.3763   lossG: 1.6711


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch63:  lossD: 0.5436   lossG: 1.3032


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch64:  lossD: 0.4200   lossG: 1.4028


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch65:  lossD: 0.4504   lossG: 1.4931


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch66:  lossD: 0.4253   lossG: 1.3755


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch67:  lossD: 0.4397   lossG: 1.4410


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch68:  lossD: 0.4606   lossG: 1.3091


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch69:  lossD: 0.3995   lossG: 1.4349


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch70:  lossD: 0.3425   lossG: 1.4964


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch71:  lossD: 0.5517   lossG: 1.5909


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch72:  lossD: 0.4382   lossG: 1.6279


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch73:  lossD: 0.4950   lossG: 1.7286


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch74:  lossD: 0.4175   lossG: 1.3949


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch75:  lossD: 0.5234   lossG: 1.5204


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch76:  lossD: 0.4313   lossG: 1.7984


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch77:  lossD: 0.3712   lossG: 1.6356


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch78:  lossD: 0.3636   lossG: 1.2525


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch79:  lossD: 0.4977   lossG: 1.5100


  0%|          | 0/1875 [00:00<?, ?it/s]

epoch80:  lossD: 0.4464   lossG: 1.7153
