# **Import Libraries**

In [1]:
import warnings

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import FashionMNIST

from tqdm import tqdm

# **Ignoring Any Warning Messages**

In [2]:
warnings.filterwarnings('ignore')

# **Define Important Hyperparameters**

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 0.0005
num_epochs = 100
batch_size = 256
img_dim = 64
z_dim = 256
img_channels = 1
features = 64
fixed_noise = torch.rand(32, z_dim, 1, 1).to(device)

# **Create Discriminator**

In [4]:
class Discriminator(nn.Module):
  def __init__(self, img_channels, features):
    super().__init__()

    self.net = nn.Sequential(
        nn.Conv2d(img_channels, features * 1, 4, 2, 1),
        nn.LeakyReLU(.2),

        nn.Conv2d(features * 1, features * 2, 4, 2, 1),
        nn.BatchNorm2d(features * 2),
        nn.LeakyReLU(.2),

        nn.Conv2d(features * 2, features * 4, 4, 2, 1),
        nn.BatchNorm2d(features * 4),
        nn.LeakyReLU(.2),

        nn.Conv2d(features * 4, features * 8, 4, 2, 1),
        nn.BatchNorm2d(features * 8),
        nn.LeakyReLU(.2),

        nn.Conv2d(features * 8,  1, 4, 2, 0),
        nn.Sigmoid(),
    )

  def forward(self, x):
    return self.net(x)

# **Create Generator**

In [5]:
class Generator(nn.Module):
  def __init__(self, noise_channels, img_channels, features):
    super().__init__()

    self.net = nn.Sequential(
        nn.ConvTranspose2d(noise_channels, features * 16, 4, 1, 0),
        nn.ReLU(),

        nn.ConvTranspose2d(features * 16, features * 8, 4, 2, 1),
        nn.BatchNorm2d(features * 8),
        nn.ReLU(),

        nn.ConvTranspose2d(features * 8, features * 4, 4, 2, 1),
        nn.BatchNorm2d(features * 4),
        nn.ReLU(),

        nn.ConvTranspose2d(features * 4, features * 2, 4, 2, 1),
        nn.BatchNorm2d(features * 2),
        nn.ReLU(),

        nn.ConvTranspose2d(features * 2, img_channels, 4, 2, 1),
        nn.Tanh(),
    )

  def forward(self, x):
    return self.net(x)

# **Define Weights Intialization Function**

In [6]:
def initialize_weights(model):
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d,  nn.BatchNorm2d)):
      nn.init.normal_(m.weight.data, 0, .02)

# **Generate Dataloader**

### **Define Transformation**

In [7]:
transforms = transforms.Compose(
    [
        transforms.Resize(img_dim),
        transforms.ToTensor(),
        transforms.Normalize((.5, ), (.5, ))
    ]
)

### **Create Dataloader**

In [8]:
dataloader = DataLoader(FashionMNIST('.', download = True, transform = transforms), batch_size = batch_size, shuffle = True)

# **Training Loop**

In [9]:
gen = Generator(z_dim, img_channels, features).to(device)
initialize_weights(gen)

disc = Discriminator(img_channels, features).to(device)
initialize_weights(disc)

opt_gen  = Adam(gen.parameters(),  lr=lr, betas=(0.5, 0.999))
opt_disc = Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))

criterion = nn.BCELoss()

for epoch in range(num_epochs):
    gen.train()
    disc.train()

    for real, _ in tqdm(dataloader):
        real = real.to(device)

        noise = torch.randn((batch_size, z_dim, 1, 1)).to(device)
        fake = gen(noise)

        disc_real = disc(real).view(-1)
        disc_fake = disc(fake.detach()).view(-1)

        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2

        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        output = disc(fake)
        lossG = criterion(output, torch.ones_like(output))

        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

    print(f"epoch {epoch+1}: lossD={lossD:.4f}, lossG={lossG:.4f}")

    with torch.no_grad():
        fake = gen(fixed_noise.to(device)).reshape(-1, 1, 64, 64)
        real, _ = next(iter(dataloader))
        real = real.to(device)

        img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
        img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)

        torchvision.utils.save_image(img_grid_fake, "fake_grid.png")
        torchvision.utils.save_image(img_grid_real, "real_grid.png")

100%|██████████| 235/235 [02:22<00:00,  1.65it/s]


epoch 1: lossD=0.5676, lossG=1.1875


100%|██████████| 235/235 [02:21<00:00,  1.66it/s]


epoch 2: lossD=0.6074, lossG=1.1143


100%|██████████| 235/235 [02:21<00:00,  1.66it/s]


epoch 3: lossD=0.2699, lossG=2.2751


100%|██████████| 235/235 [02:20<00:00,  1.67it/s]


epoch 4: lossD=0.2305, lossG=2.8271


100%|██████████| 235/235 [02:20<00:00,  1.67it/s]


epoch 5: lossD=0.3717, lossG=2.5387


100%|██████████| 235/235 [02:20<00:00,  1.67it/s]


epoch 6: lossD=0.3331, lossG=1.8059


100%|██████████| 235/235 [02:20<00:00,  1.68it/s]


epoch 7: lossD=0.0589, lossG=3.2603


100%|██████████| 235/235 [02:20<00:00,  1.68it/s]


epoch 8: lossD=0.0368, lossG=4.1796


100%|██████████| 235/235 [02:19<00:00,  1.68it/s]


epoch 9: lossD=0.0416, lossG=3.6637


100%|██████████| 235/235 [02:19<00:00,  1.68it/s]


epoch 10: lossD=0.2665, lossG=3.3597


100%|██████████| 235/235 [02:19<00:00,  1.68it/s]


epoch 11: lossD=0.1346, lossG=5.5099


100%|██████████| 235/235 [02:18<00:00,  1.69it/s]


epoch 12: lossD=0.1332, lossG=3.9036


100%|██████████| 235/235 [02:19<00:00,  1.68it/s]


epoch 13: lossD=0.1787, lossG=2.6948


100%|██████████| 235/235 [02:19<00:00,  1.68it/s]


epoch 14: lossD=0.7398, lossG=0.6428


100%|██████████| 235/235 [02:19<00:00,  1.68it/s]


epoch 15: lossD=0.0406, lossG=3.6687


100%|██████████| 235/235 [02:19<00:00,  1.68it/s]


epoch 16: lossD=0.6265, lossG=1.6400


100%|██████████| 235/235 [02:19<00:00,  1.69it/s]


epoch 17: lossD=0.3189, lossG=5.8306


100%|██████████| 235/235 [02:19<00:00,  1.68it/s]


epoch 18: lossD=0.0214, lossG=4.8354


100%|██████████| 235/235 [02:20<00:00,  1.68it/s]


epoch 19: lossD=0.0291, lossG=4.5707


100%|██████████| 235/235 [02:20<00:00,  1.68it/s]


epoch 20: lossD=0.1547, lossG=4.2184


100%|██████████| 235/235 [02:18<00:00,  1.69it/s]


epoch 21: lossD=0.1899, lossG=0.8147


100%|██████████| 235/235 [02:19<00:00,  1.68it/s]


epoch 22: lossD=0.1622, lossG=3.4212


100%|██████████| 235/235 [02:19<00:00,  1.68it/s]


epoch 23: lossD=0.0861, lossG=3.2367


100%|██████████| 235/235 [02:19<00:00,  1.68it/s]


epoch 24: lossD=0.1989, lossG=3.6462


100%|██████████| 235/235 [02:19<00:00,  1.68it/s]


epoch 25: lossD=0.1276, lossG=2.7103


100%|██████████| 235/235 [02:19<00:00,  1.68it/s]


epoch 26: lossD=0.0473, lossG=4.4507


100%|██████████| 235/235 [02:19<00:00,  1.69it/s]


epoch 27: lossD=0.0101, lossG=5.6258


100%|██████████| 235/235 [02:19<00:00,  1.68it/s]


epoch 28: lossD=0.1054, lossG=2.2047


100%|██████████| 235/235 [02:19<00:00,  1.68it/s]


epoch 29: lossD=0.0468, lossG=4.4328


100%|██████████| 235/235 [02:19<00:00,  1.68it/s]


epoch 30: lossD=0.1201, lossG=2.8644


100%|██████████| 235/235 [02:19<00:00,  1.68it/s]


epoch 31: lossD=0.1038, lossG=2.6789


100%|██████████| 235/235 [02:19<00:00,  1.69it/s]


epoch 32: lossD=0.2749, lossG=2.5273


100%|██████████| 235/235 [02:19<00:00,  1.69it/s]


epoch 33: lossD=0.0176, lossG=4.7959


100%|██████████| 235/235 [02:19<00:00,  1.69it/s]


epoch 34: lossD=0.0635, lossG=3.5817


100%|██████████| 235/235 [02:19<00:00,  1.69it/s]


epoch 35: lossD=0.0252, lossG=4.8964


100%|██████████| 235/235 [02:19<00:00,  1.69it/s]


epoch 36: lossD=0.0031, lossG=6.8528


100%|██████████| 235/235 [02:19<00:00,  1.69it/s]


epoch 37: lossD=0.0027, lossG=6.8998


100%|██████████| 235/235 [02:19<00:00,  1.69it/s]


epoch 38: lossD=0.0097, lossG=5.1632


100%|██████████| 235/235 [02:19<00:00,  1.68it/s]


epoch 39: lossD=0.0037, lossG=8.1339


100%|██████████| 235/235 [02:19<00:00,  1.69it/s]


epoch 40: lossD=0.0028, lossG=13.3422


100%|██████████| 235/235 [02:19<00:00,  1.69it/s]


epoch 41: lossD=0.1741, lossG=4.7230


100%|██████████| 235/235 [02:19<00:00,  1.69it/s]


epoch 42: lossD=0.1406, lossG=4.8379


100%|██████████| 235/235 [02:19<00:00,  1.69it/s]


epoch 43: lossD=0.0347, lossG=5.5927


100%|██████████| 235/235 [02:18<00:00,  1.69it/s]


epoch 44: lossD=0.0184, lossG=4.4317


100%|██████████| 235/235 [02:18<00:00,  1.70it/s]


epoch 45: lossD=0.0217, lossG=4.7166


100%|██████████| 235/235 [02:18<00:00,  1.70it/s]


epoch 46: lossD=0.0023, lossG=6.8286


100%|██████████| 235/235 [02:19<00:00,  1.69it/s]


epoch 47: lossD=0.0226, lossG=5.9515


100%|██████████| 235/235 [02:19<00:00,  1.69it/s]


epoch 48: lossD=0.0106, lossG=5.9291


100%|██████████| 235/235 [02:18<00:00,  1.70it/s]


epoch 49: lossD=0.0714, lossG=5.5093


100%|██████████| 235/235 [02:19<00:00,  1.69it/s]


epoch 50: lossD=0.1088, lossG=4.3230


100%|██████████| 235/235 [02:19<00:00,  1.69it/s]


epoch 51: lossD=0.0063, lossG=6.1411


100%|██████████| 235/235 [02:18<00:00,  1.70it/s]


epoch 52: lossD=0.0056, lossG=6.3255


100%|██████████| 235/235 [02:19<00:00,  1.69it/s]


epoch 53: lossD=0.0004, lossG=7.8934


100%|██████████| 235/235 [02:19<00:00,  1.68it/s]


epoch 54: lossD=0.0004, lossG=7.8091


100%|██████████| 235/235 [02:18<00:00,  1.69it/s]


epoch 55: lossD=0.0012, lossG=9.1794


100%|██████████| 235/235 [02:18<00:00,  1.70it/s]


epoch 56: lossD=0.0010, lossG=7.2935


100%|██████████| 235/235 [02:18<00:00,  1.69it/s]


epoch 57: lossD=0.0034, lossG=5.2641


100%|██████████| 235/235 [02:18<00:00,  1.69it/s]


epoch 58: lossD=1.4250, lossG=11.3403


100%|██████████| 235/235 [02:19<00:00,  1.69it/s]


epoch 59: lossD=0.0513, lossG=5.2526


100%|██████████| 235/235 [02:19<00:00,  1.68it/s]


epoch 60: lossD=0.2866, lossG=8.9776


100%|██████████| 235/235 [02:19<00:00,  1.69it/s]


epoch 61: lossD=0.0752, lossG=4.2841


100%|██████████| 235/235 [02:18<00:00,  1.69it/s]


epoch 62: lossD=0.0508, lossG=6.1939


100%|██████████| 235/235 [02:18<00:00,  1.69it/s]


epoch 63: lossD=0.0108, lossG=7.4939


100%|██████████| 235/235 [02:18<00:00,  1.70it/s]


epoch 64: lossD=0.1617, lossG=3.9255


100%|██████████| 235/235 [02:18<00:00,  1.69it/s]


epoch 65: lossD=0.0129, lossG=5.5948


100%|██████████| 235/235 [02:19<00:00,  1.68it/s]


epoch 66: lossD=0.0081, lossG=5.6652


100%|██████████| 235/235 [02:19<00:00,  1.68it/s]


epoch 67: lossD=0.0073, lossG=5.6164


100%|██████████| 235/235 [02:18<00:00,  1.69it/s]


epoch 68: lossD=0.1320, lossG=3.9062


100%|██████████| 235/235 [02:18<00:00,  1.70it/s]


epoch 69: lossD=0.0070, lossG=6.0782


100%|██████████| 235/235 [02:18<00:00,  1.70it/s]


epoch 70: lossD=0.0461, lossG=4.9870


100%|██████████| 235/235 [02:19<00:00,  1.68it/s]


epoch 71: lossD=0.0401, lossG=3.1114


100%|██████████| 235/235 [02:19<00:00,  1.68it/s]


epoch 72: lossD=0.0118, lossG=5.6338


100%|██████████| 235/235 [02:19<00:00,  1.68it/s]


epoch 73: lossD=0.1119, lossG=4.2151


100%|██████████| 235/235 [02:19<00:00,  1.69it/s]


epoch 74: lossD=0.0365, lossG=9.0973


100%|██████████| 235/235 [02:19<00:00,  1.69it/s]


epoch 75: lossD=0.0354, lossG=4.4361


100%|██████████| 235/235 [02:17<00:00,  1.71it/s]


epoch 76: lossD=0.0971, lossG=3.2645


100%|██████████| 235/235 [02:13<00:00,  1.77it/s]


epoch 77: lossD=0.0421, lossG=4.2946


100%|██████████| 235/235 [02:12<00:00,  1.77it/s]


epoch 78: lossD=0.2411, lossG=10.4901


100%|██████████| 235/235 [02:12<00:00,  1.77it/s]


epoch 79: lossD=0.2659, lossG=2.8955


100%|██████████| 235/235 [02:12<00:00,  1.77it/s]


epoch 80: lossD=0.0113, lossG=6.2108


100%|██████████| 235/235 [02:15<00:00,  1.74it/s]


epoch 81: lossD=0.0152, lossG=5.0647


100%|██████████| 235/235 [02:13<00:00,  1.76it/s]


epoch 82: lossD=0.0149, lossG=6.1264


100%|██████████| 235/235 [02:12<00:00,  1.77it/s]


epoch 83: lossD=0.0424, lossG=4.1707


100%|██████████| 235/235 [02:12<00:00,  1.77it/s]


epoch 84: lossD=0.0125, lossG=5.5311


100%|██████████| 235/235 [02:12<00:00,  1.77it/s]


epoch 85: lossD=0.0059, lossG=6.0297


100%|██████████| 235/235 [02:12<00:00,  1.77it/s]


epoch 86: lossD=0.1939, lossG=6.1906


100%|██████████| 235/235 [02:12<00:00,  1.77it/s]


epoch 87: lossD=0.0102, lossG=5.5451


100%|██████████| 235/235 [02:12<00:00,  1.77it/s]


epoch 88: lossD=0.0620, lossG=3.7463


100%|██████████| 235/235 [02:15<00:00,  1.74it/s]


epoch 89: lossD=0.0163, lossG=5.2693


100%|██████████| 235/235 [02:12<00:00,  1.77it/s]


epoch 90: lossD=0.0230, lossG=3.0186


100%|██████████| 235/235 [02:12<00:00,  1.77it/s]


epoch 91: lossD=0.0227, lossG=4.8403


100%|██████████| 235/235 [02:12<00:00,  1.77it/s]


epoch 92: lossD=0.0904, lossG=3.7114


100%|██████████| 235/235 [02:12<00:00,  1.77it/s]


epoch 93: lossD=0.0515, lossG=5.5749


100%|██████████| 235/235 [02:12<00:00,  1.77it/s]


epoch 94: lossD=0.1761, lossG=4.1615


100%|██████████| 235/235 [02:12<00:00,  1.77it/s]


epoch 95: lossD=0.0093, lossG=5.3661


100%|██████████| 235/235 [02:17<00:00,  1.71it/s]


epoch 96: lossD=0.0583, lossG=4.0130


100%|██████████| 235/235 [02:18<00:00,  1.69it/s]


epoch 97: lossD=0.0514, lossG=10.1462


100%|██████████| 235/235 [02:18<00:00,  1.70it/s]


epoch 98: lossD=0.0049, lossG=6.4346


100%|██████████| 235/235 [02:19<00:00,  1.69it/s]


epoch 99: lossD=0.0019, lossG=7.7282


100%|██████████| 235/235 [02:19<00:00,  1.69it/s]


epoch 100: lossD=0.0034, lossG=6.3365
