In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import numpy as np

import torchvision
import torchvision.transforms as transforms


import os
from torchvision.utils import save_image

In [29]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

In [30]:
# ===============================
# Configuration
# ===============================

batch_size = 128   # mini-batch size passed to the DataLoader
latent_dim = 50    # dimensionality of the latent code z
epochs = 30        # number of passes over the training set
lr = 2e-4          # Adam learning rate shared by both optimizers

# Seed PyTorch's RNGs (CPU and CUDA) so weight initialisation, latent sampling
# and DataLoader shuffling are reproducible under Restart Kernel -> Run All.
seed = 42
torch.manual_seed(seed)

In [31]:
# ================================
# Dataset (MNIST)
# ================================

# Preprocessing pipeline:
#   1. ToTensor   -> float tensor with pixel values in [0, 1]
#   2. Normalize  -> rescale to [-1, 1] so real images share the output range
#                    of the Generator's final Tanh (and match the
#                    `samples * 0.5 + 0.5` de-normalisation used when saving)
#   3. Lambda     -> flatten each image into a 1-D vector for the MLP networks
# Normalize must run before the flatten because it expects a (C, H, W) tensor.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
    transforms.Lambda(lambda x: x.view(-1))
])

# MNIST training split; downloaded into ./data on first use.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [32]:
# ==========================
# Generator
# ==========================

class Generator(nn.Module):
  """MLP that maps a latent code of size `latent_dim` to a 784-dim vector.

  The final Tanh bounds every output element to [-1, 1].
  """

  def __init__(self):
    super().__init__()
    hidden = 1024
    layers = [
        nn.Linear(latent_dim, hidden),
        nn.ReLU(),
        nn.Linear(hidden, hidden),
        nn.BatchNorm1d(hidden),
        nn.ReLU(),
        nn.Linear(hidden, 784),
        nn.Tanh(),
    ]
    self.net = nn.Sequential(*layers)

  def forward(self, z):
    """Return the generated vector for a batch of latent codes `z`."""
    generated = self.net(z)
    return generated

In [33]:
# ==========================
# Encoder
# ==========================

class Encoder(nn.Module):
  """MLP that maps a 784-dim input vector to a latent code of size `latent_dim`.

  The output layer is linear (no activation), so codes are unbounded.
  """

  def __init__(self):
    super().__init__()
    hidden = 1024
    layers = [
        nn.Linear(784, hidden),
        nn.ReLU(),
        nn.Linear(hidden, hidden),
        nn.BatchNorm1d(hidden),
        nn.ReLU(),
        nn.Linear(hidden, latent_dim),
    ]
    self.net = nn.Sequential(*layers)

  def forward(self, x):
    """Return the latent code inferred for a batch of flattened inputs `x`."""
    code = self.net(x)
    return code


In [None]:
# ==============================
# Discriminator (Joint: x, z)
# ==============================

class Discriminator(nn.Module):
  """Joint critic over (x, z) pairs.

  Concatenates a 784-dim data vector with a `latent_dim`-dim code and returns
  one raw logit per pair (no sigmoid is applied here).
  """

  def __init__(self):
    super().__init__()
    hidden = 1024
    layers = [
        nn.Linear(784 + latent_dim, hidden),
        nn.ReLU(),
        nn.Linear(hidden, hidden),
        nn.BatchNorm1d(hidden),
        nn.ReLU(),
        nn.Linear(hidden, 1),
    ]
    self.net = nn.Sequential(*layers)

  def forward(self, x, z):
    """Score a batch of (x, z) pairs; returns a (batch, 1) tensor of logits."""
    joint = torch.cat((x, z), dim=1)  # join along the feature axis
    return self.net(joint)

In [35]:
# ======================
# Initialize models
# ======================

# Build the three networks (generator, encoder, joint discriminator) ...
G = Generator()
E = Encoder()
D = Discriminator()

# ... then move each one onto the selected device.
G, E, D = G.to(device), E.to(device), D.to(device)

In [36]:
# ==========================
# Optimizers and Loss
# ==========================

# The discriminator emits raw logits, so use the loss that applies the sigmoid itself.
bce_loss = nn.BCEWithLogitsLoss()

# One Adam optimizer for D, and one shared optimizer that updates G and E together.
optimizer_D = optim.Adam(
    D.parameters(),
    lr=lr,
    betas=(0.5, 0.999),
)
optimizer_GE = optim.Adam(
    [*G.parameters(), *E.parameters()],
    lr=lr,
    betas=(0.5, 0.999),
)

In [37]:
p = 1  # discriminator updates per batch
k = 3  # generator + encoder updates per batch

# Directory that receives the generated sample images; created if missing.
os.makedirs(name='bigan', exist_ok=True)

In [None]:
# ============================
# Training Loop
# ============================

# BiGAN training: D learns to tell (x, E(x)) pairs ("real") from (G(z), z)
# pairs ("fake"); G and E are trained jointly to fool D by swapping the labels.
for epoch in range(1, epochs + 1):
  # Train mode so the BatchNorm layers use batch statistics and keep their
  # running estimates up to date (eval mode would freeze them at their
  # initial values).
  G.train()
  E.train()
  D.train()

  for i, (x_real, _) in enumerate(train_loader):
    x_real = x_real.to(device)
    # Size of *this* batch (the last one may be smaller). Kept in a local name
    # so the configured global `batch_size` is not overwritten.
    n_samples = x_real.size(0)

    # Sample z from prior
    z_real = torch.randn(n_samples, latent_dim, device=device)

    # Targets matching D's (n_samples, 1) logit output.
    label_real = torch.ones(n_samples, 1, device=device)
    label_fake = torch.zeros(n_samples, 1, device=device)

    # ===========================
    # 1. Train Discriminator
    # ===========================

    # `_` as loop variable: must not clobber the batch index `i` used for logging.
    for _ in range(p):
      # No gradients are needed through G/E while updating D.
      with torch.no_grad():
        x_fake = G(z_real)  # x_fake == x_generated
        z_fake = E(x_real)  # z_fake == z_generated

      D_real = D(x_real, z_fake)
      D_fake = D(x_fake, z_real)

      loss_D = bce_loss(D_real, label_real) + bce_loss(D_fake, label_fake)

      optimizer_D.zero_grad()
      loss_D.backward()
      optimizer_D.step()

    # ================================
    # 2. Train Generator + Encoder
    # ================================

    for _ in range(k):
      # Recompute both outputs WITH gradients so the loss reaches G and E.
      x_fake = G(z_real)
      z_fake = E(x_real)

      D_real = D(x_real, z_fake)
      D_fake = D(x_fake, z_real)

      # Swapped labels: G and E are rewarded when D is fooled.
      loss_GE = bce_loss(D_real, label_fake) + bce_loss(D_fake, label_real)

      optimizer_GE.zero_grad()
      loss_GE.backward()
      optimizer_GE.step()

    if i%100 == 0:
      print(f'Epoch [{epoch}/{epochs}], Batch [{i}/{len(train_loader)}], D_loss: {loss_D.item():.4f}, G_loss: {loss_GE.item():.4f}')

  # Save a grid of samples once per epoch (after the last batch).
  G.eval()

  with torch.no_grad():
    z = torch.randn(64, latent_dim, device=device)
    # Reshape flat 784-vectors to (N, 1, 28, 28) so save_image tiles them
    # into an 8x8 grid instead of drawing one 64x784 image.
    samples = G(z).view(-1, 1, 28, 28)
    samples = samples * 0.5 + 0.5  # denormalize: Tanh range [-1, 1] -> [0, 1]
    save_image(samples, f'bigan/epoch_{epoch}.png', nrow=8)

  G.train()
