In [None]:
import torch
from torch import optim, nn
import torch.nn.functional as F 
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

In [None]:
# Seed PyTorch's global RNG so weight init and noise sampling repeat across runs.
torch.manual_seed(0)

# Run on the first CUDA device when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Hyperparameters (constants from the literature)
input_size: int = 100         # channel count of the noise tensor fed to the generator
learning_rate: float = 2e-4   # Adam step size shared by both optimizers
beta1: float = 0.5            # first Adam beta coefficient
batch_size: int = 5           # images per DataLoader batch

In [None]:
# Preprocessing: resize each digit to 32x32, convert to a tensor, then normalise
# the single channel with mean 0.5 and std 0.5.
mnist_transforms = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Training split (downloaded on first use) and its shuffled loader.
train_dataset = datasets.MNIST(
    root='/data/MNIST',
    train=True,
    download=True,
    transform=mnist_transforms,
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Held-out split, prepared with the same preprocessing and loader settings.
test_dataset = datasets.MNIST(
    root='/data/MNIST',
    train=False,
    download=True,
    transform=mnist_transforms,
)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [None]:
class Generator(nn.Module):
  """Maps a noise tensor to a single-channel image via transposed convolutions.

  A ``(N, input_size, 1, 1)`` input is up-sampled 1 -> 4 -> 8 -> 16 -> 32
  pixels per side, so the result has shape ``(N, 1, 32, 32)``. The final
  ``tanh`` keeps every output value inside ``[-1, 1]``.
  """

  def __init__(self):
    super().__init__()

    # Projection of the 1x1 noise "pixel" onto a 4x4 feature map.
    self.fsconv1 = nn.ConvTranspose2d(
        input_size, 512, kernel_size=4, stride=1, padding=0, bias=False)
    self.norm1 = nn.BatchNorm2d(512)

    # Each of the following layers doubles the spatial resolution.
    self.fsconv2 = nn.ConvTranspose2d(
        512, 256, kernel_size=4, stride=2, padding=1, bias=False)
    self.norm2 = nn.BatchNorm2d(256)

    self.fsconv3 = nn.ConvTranspose2d(
        256, 128, kernel_size=4, stride=2, padding=1, bias=False)
    self.norm3 = nn.BatchNorm2d(128)

    # Output layer: collapses to one image channel, no batch norm.
    self.fsconv4 = nn.ConvTranspose2d(
        128, 1, kernel_size=4, stride=2, padding=1, bias=False)

  def forward(self, x):
    """Return generated images for the noise batch ``x``."""
    hidden_stages = (
        (self.fsconv1, self.norm1),
        (self.fsconv2, self.norm2),
        (self.fsconv3, self.norm3),
    )
    # Hidden stages all share the same pattern: upsample -> batch norm -> ReLU.
    for upsample, normalise in hidden_stages:
      x = F.relu(normalise(upsample(x)))
    return torch.tanh(self.fsconv4(x))

In [None]:
class Discriminator(nn.Module):
  """Scores single-channel images with a stack of strided convolutions.

  The layers are ordinary (down-sampling) ``nn.Conv2d`` modules. Transposed
  convolutions would *enlarge* the feature maps at every layer instead of
  reducing the image to a score.

  For a ``(N, 1, 32, 32)`` input the spatial size shrinks
  32 -> 16 -> 8 -> 4 -> 2 -> 1, so ``forward`` returns a tensor of shape
  ``(N,)`` holding one sigmoid probability per image. For other input sizes
  the final reshape flattens every remaining spatial position into the
  returned 1-D tensor as well.
  """

  def __init__(self):
    super(Discriminator, self).__init__()

    # First layer: no batch norm on the raw image.
    self.sconv1 = nn.Conv2d(1, 64, 4, 2, 1, bias=False)      # 32x32 -> 16x16

    self.sconv2 = nn.Conv2d(64, 128, 4, 2, 1, bias=False)    # 16x16 -> 8x8
    self.norm1 = nn.BatchNorm2d(128)

    self.sconv3 = nn.Conv2d(128, 256, 4, 2, 1, bias=False)   # 8x8 -> 4x4
    self.norm2 = nn.BatchNorm2d(256)

    self.sconv4 = nn.Conv2d(256, 512, 4, 2, 1, bias=False)   # 4x4 -> 2x2
    self.norm3 = nn.BatchNorm2d(512)

    # The feature map is only 2x2 here, so the scoring kernel is 2x2
    # (a 4x4 kernel would not fit) and yields a single 1x1 output.
    self.sconv5 = nn.Conv2d(512, 1, 2, 1, 0, bias=False)     # 2x2 -> 1x1

  def forward(self, x):
    """Return a flat tensor of probabilities in ``[0, 1]`` for the batch ``x``."""
    x = F.leaky_relu(self.sconv1(x), 0.2)
    x = F.leaky_relu(self.norm1(self.sconv2(x)), 0.2)
    x = F.leaky_relu(self.norm2(self.sconv3(x)), 0.2)
    x = F.leaky_relu(self.norm3(self.sconv4(x)), 0.2)
    x = torch.sigmoid(self.sconv5(x))
    # Flatten (N, 1, 1, 1) to (N,) so it lines up with a 1-D label tensor.
    x = x.reshape(-1, 1).squeeze(1)
    return x

In [None]:
def init_weights(m):
    """Initialise a layer's weights from a normal distribution N(0, 0.02).

    Intended for ``module.apply(init_weights)``, which calls this once per
    sub-module. Convolution, transposed-convolution and linear layers are
    re-initialised in place; every other module type is left untouched.

    Args:
        m: the ``nn.Module`` currently visited by ``apply``.
    """
    # The networks in this notebook are built from (transposed) convolutions,
    # so those types must be matched for the initialisation to have any effect.
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        # `normal_` is the in-place initialiser; the un-suffixed name is deprecated.
        torch.nn.init.normal_(m.weight, 0.0, 0.02)

In [None]:
# Build each network, move it to the target device, then run the custom
# weight initialiser over all of its sub-modules (`apply` returns the module).
gen = Generator().to(device).apply(init_weights)
disc = Discriminator().to(device).apply(init_weights)
disc  # keep the discriminator summary as the cell's displayed output

In [None]:
# One Adam optimiser per network, both using the same learning rate and betas.
optimizer_gen = optim.Adam(
    params=gen.parameters(),
    lr=learning_rate,
    betas=(beta1, 0.999),
)
optimizer_disc = optim.Adam(
    params=disc.parameters(),
    lr=learning_rate,
    betas=(beta1, 0.999),
)

In [None]:
# Target values for the binary cross-entropy loss: 1 marks real images, 0 marks generated ones.
real_label, fake_label = 1, 0

In [None]:
num_epochs = 10

# Alternating GAN training: every batch first updates the discriminator on a
# real and a generated batch, then updates the generator so that its samples
# are scored as real.
for epoch in range(num_epochs):
  for i, (real_images, _) in enumerate(train_loader):
    real_images = real_images.to(device)
    # Size the noise from the batch actually received, so a shorter final
    # batch cannot produce mismatched output/label shapes.
    current_batch_size = real_images.size(0)

    # ---- Discriminator update ----
    optimizer_disc.zero_grad()

    real_outputs = disc(real_images)
    # Create the targets as floats once instead of re-casting for every loss call.
    labels = torch.full(real_outputs.shape, real_label, dtype=torch.float, device=device)
    real_loss = F.binary_cross_entropy(real_outputs, labels)
    real_loss.backward()

    noise = torch.randn(current_batch_size, input_size, 1, 1, device=device)
    fake_images = gen(noise)
    labels.fill_(fake_label)
    # Detach so this backward pass stops at the discriminator. Without it the
    # generator's graph is consumed here and the generator update below fails
    # with "Trying to backward through the graph a second time".
    fake_outputs = disc(fake_images.detach())
    fake_loss = F.binary_cross_entropy(fake_outputs, labels)
    fake_loss.backward()

    optimizer_disc.step()

    # ---- Generator update ----
    optimizer_gen.zero_grad()
    # The generator is rewarded when the freshly updated discriminator labels
    # its images as real.
    labels.fill_(real_label)
    outputs = disc(fake_images)
    loss = F.binary_cross_entropy(outputs, labels)
    loss.backward()

    optimizer_gen.step()