In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt

In [2]:
# --- Configuration -----------------------------------------------------------
batch_size = 64   # number of images per mini-batch fed to the DataLoader
z_dim = 100       # length of the latent noise vector passed to the generator

# Seed PyTorch's global RNG so that weight initialisation, DataLoader shuffling
# and the noise drawn with torch.randn repeat across "Restart & Run All".
# (Bit-exact determinism on accelerators may still need extra flags.)
seed = 42
torch.manual_seed(seed)

# --- Data --------------------------------------------------------------------
# ToTensor converts each image to a float tensor; Normalize with mean 0.5 and
# std 0.5 then applies (x - 0.5) / 0.5 on the single channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# MNIST training split, downloaded into `root` on first use.
mnist = datasets.MNIST(root='/content/sample_data', train=True, download=True, transform=transform)
# Reshuffled every epoch; the final batch may hold fewer than `batch_size` items.
dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)

100%|██████████| 9.91M/9.91M [00:11<00:00, 895kB/s] 
100%|██████████| 28.9k/28.9k [00:00<00:00, 134kB/s]
100%|██████████| 1.65M/1.65M [00:06<00:00, 242kB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.04MB/s]


In [3]:
class Generator(nn.Module):
  """Fully connected generator that turns latent noise into 28x28 images.

  The network is a stack of Linear+ReLU layers of widths 256, 512 and 1028,
  followed by a Linear projection to 28*28 values squashed by Tanh, so every
  output pixel lies in [-1, 1].

  Args:
    z_dim: size of the last dimension of the noise tensor given to `forward`.
  """

  def __init__(self, z_dim):
    super().__init__()
    # Feature sizes from the latent input through each hidden layer.
    widths = [z_dim, 256, 512, 1028]

    layers = []
    for n_in, n_out in zip(widths[:-1], widths[1:]):
      layers.append(nn.Linear(n_in, n_out))
      layers.append(nn.ReLU())

    # Output head: one value per pixel, bounded by Tanh.
    layers.append(nn.Linear(widths[-1], 28 * 28))
    layers.append(nn.Tanh())

    self.model = nn.Sequential(*layers)

  def forward(self, z):
    """Map noise `z` to a batch of images shaped (-1, 1, 28, 28)."""
    flat_pixels = self.model(z)
    return flat_pixels.view(-1, 1, 28, 28)

In [4]:
class Discriminator(nn.Module):
  """Fully connected discriminator scoring 28x28 images.

  Three Linear -> LeakyReLU(0.2) -> Dropout(0.3) stages of widths 1024, 512
  and 256 are followed by a single-unit Linear layer and a Sigmoid, so each
  input image yields one value in (0, 1). Dropout is only active while the
  module is in training mode.
  """

  def __init__(self):
    super(Discriminator, self).__init__()
    self.model = nn.Sequential(
        nn.Linear(28 * 28, 1024),
        nn.LeakyReLU(0.2),
        nn.Dropout(0.3),
        nn.Linear(1024, 512),
        nn.LeakyReLU(0.2),
        nn.Dropout(0.3),
        nn.Linear(512, 256),
        nn.LeakyReLU(0.2),
        nn.Dropout(0.3),
        nn.Linear(256, 1),
        nn.Sigmoid()
    )

  def forward(self, x):
    """Flatten `x` to rows of 28*28 values and return one score per row.

    Args:
      x: tensor whose element count is a multiple of 28*28, e.g. a batch
        shaped (N, 1, 28, 28).

    Returns:
      Tensor of shape (N, 1) with values in (0, 1).
    """
    # reshape (rather than view) gives the same result for contiguous input
    # but also accepts non-contiguous tensors, e.g. a transposed batch, on
    # which view raises a RuntimeError.
    return self.model(x.reshape(-1, 28 * 28))

In [5]:
# Instantiate both networks (generator first, then discriminator):
# G consumes z_dim-sized noise vectors, D takes no constructor arguments.
G, D = Generator(z_dim), Discriminator()

In [6]:
# Binary cross-entropy between predicted probabilities and 0/1 targets.
criterion = nn.BCELoss()

# One Adam optimizer per network, both with learning rate 0.0002 and every
# other Adam setting left at its default.
optim_g, optim_d = (optim.Adam(net.parameters(), lr=0.0002) for net in (G, D))

In [None]:
epochs = 50

# Make sure both networks are in training mode: the sampling cell calls
# G.eval(), so re-running this cell afterwards would otherwise inherit it,
# and D relies on train mode for its Dropout layers.
G.train()
D.train()

# Create all tensors on the device the generator's weights live on, so the
# loop keeps working if the models are moved off the CPU.
# NOTE(review): assumes G and D are on the same device.
device = next(G.parameters()).device

for epoch in range(epochs):
    for real_images, _ in dataloader:
        real_images = real_images.to(device)
        # Size of *this* batch (the last one of an epoch can be smaller).
        # Kept under its own name so the config-cell `batch_size` is not
        # overwritten as a side effect of training.
        cur_batch = real_images.size(0)

        # Targets: 1 for real images, 0 for generated ones.
        real_labels = torch.ones(cur_batch, 1, device=device)
        fake_labels = torch.zeros(cur_batch, 1, device=device)

        # ---- Discriminator step ----
        z = torch.randn(cur_batch, z_dim, device=device)
        fake_images = G(z)

        D_real = D(real_images)
        # detach() cuts the graph so this backward pass does not reach G.
        D_fake = D(fake_images.detach())

        D_loss = criterion(D_real, real_labels) + criterion(D_fake, fake_labels)

        optim_d.zero_grad()
        D_loss.backward()
        optim_d.step()

        # ---- Generator step ----
        # Fresh noise, scored by the just-updated discriminator.
        z = torch.randn(cur_batch, z_dim, device=device)
        fake_images = G(z)
        D_fake_for_G = D(fake_images)

        # G is trained against the "real" target: its loss falls as D is fooled.
        g_loss = criterion(D_fake_for_G, real_labels)

        # Gradients that this backward leaves on D's parameters are cleared by
        # optim_d.zero_grad() before the next discriminator update.
        optim_g.zero_grad()
        g_loss.backward()
        optim_g.step()

    # Reported values are the losses of the final batch of the epoch.
    print(f'Epoch [{epoch + 1}/{epochs}] - D loss: {D_loss.item():.3f} - G loss: {g_loss.item():.3f}')

Epoch [1/50] - D loss: 0.552 - G loss: 1.982
Epoch [2/50] - D loss: 0.945 - G loss: 0.931
Epoch [3/50] - D loss: 0.194 - G loss: 3.215
Epoch [4/50] - D loss: 0.262 - G loss: 1.942
Epoch [5/50] - D loss: 0.388 - G loss: 3.259
Epoch [6/50] - D loss: 0.166 - G loss: 4.921
Epoch [7/50] - D loss: 0.299 - G loss: 2.763
Epoch [8/50] - D loss: 0.362 - G loss: 1.995
Epoch [9/50] - D loss: 0.161 - G loss: 3.156
Epoch [10/50] - D loss: 0.453 - G loss: 3.585
Epoch [11/50] - D loss: 0.539 - G loss: 2.770
Epoch [12/50] - D loss: 0.498 - G loss: 2.767
Epoch [13/50] - D loss: 0.656 - G loss: 1.886


In [None]:
# Draw 16 samples from the generator and show them in a 4x4 grid.
G.eval()
# Sample on whatever device the generator's weights are on.
device = next(G.parameters()).device

with torch.no_grad():
  z = torch.randn(16, z_dim, device=device)
  samples = G(z)
  # Rescale from the Tanh range [-1, 1] to [0, 1] and move to the CPU so
  # matplotlib can read the pixel data.
  samples = ((samples + 1) / 2).cpu()

fig, axes = plt.subplots(4, 4, figsize=(6, 6))
for i, ax in enumerate(axes.flat):
  # Fix the colour scale to [0, 1]; without vmin/vmax imshow stretches each
  # tile to its own min/max, so tiles would not share a common brightness.
  ax.imshow(samples[i, 0], cmap='gray', vmin=0, vmax=1)
  ax.axis('off')

plt.show()