In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as ag

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
from tqdm import tqdm
import os

In [44]:
# --- Hyper-parameters --------------------------------------------------------
lrg, lrd = 1e-3, 1e-3   # learning rates: generator, discriminator
optim_algo = 'adam'     # compared against 'adam' by train_original_gan
nepochs = 20
zdim = 64               # size of the noise vector fed to the generator
batch_size = 128

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Forward preprocessing: to tensor, then (x - 0.5) / 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# Inverse of the Normalize step above, expressed as two Normalize calls:
# first x / 2 (mean 0, std 1/0.5), then x + 0.5 (mean -0.5, std 1).
denormalize = transforms.Compose(
    [
        transforms.Normalize((0,), (1/0.5,)),
        transforms.Normalize((-0.5,), (1.,)),
    ]
)

# MNIST train / test splits, stored under ./data/ (download=True), both using
# the same preprocessing pipeline.
trainset, testset = [
    datasets.MNIST(root="./data/", download=True, train=is_train, transform=transform)
    for is_train in (True, False)
]

In [3]:
# Mini-batch iterators over the two splits. Only the training loader shuffles;
# both pass drop_last=True, so an incomplete final batch is discarded.
train_loader, test_loader = [
    torch.utils.data.DataLoader(
        split, batch_size=batch_size, shuffle=reshuffle, drop_last=True
    )
    for split, reshuffle in ((trainset, True), (testset, False))
]

In [4]:
# gen: MLP generator. Input is a zdim-dimensional (64) Gaussian noise vector,
# output is a flattened MNIST-sized image (28 * 28 = 784 values).
# Layer widths: zdim -> 128 -> 512 -> 512 -> 784, with a LeakyReLU after every
# Linear layer except the last one.
gen = nn.Sequential(
    *[
        module
        for n_in, n_out in ((zdim, 128), (128, 512), (512, 512))
        for module in (nn.Linear(n_in, n_out), nn.LeakyReLU())
    ],
    nn.Linear(512, 784),
)

# dis: MLP discriminator. Input is a flattened image (28 * 28 = 784 values),
# output is a single probability: close to 1 for real, close to 0 for fake.
# Layer widths: 784 -> 512 -> 512 -> 128 -> zdim -> 1, with LeakyReLU between
# the Linear layers and a Sigmoid on the final scalar.
dis = nn.Sequential(
    *[
        module
        for n_in, n_out in ((784, 512), (512, 512), (512, 128), (128, zdim))
        for module in (nn.Linear(n_in, n_out), nn.LeakyReLU())
    ],
    nn.Linear(zdim, 1),
    nn.Sigmoid(),
)

In [45]:
def train_original_gan(gen, dis, lrg, lrd, optim_algo, nepochs, zdim, device='cuda', loader=None):
  """Train a generator/discriminator pair with the original (BCE) GAN objective.

  Every mini-batch performs one discriminator update followed by one generator
  update. At the end of each epoch a grid of 64 generated samples is written to
  ``results/epoch_XX.png``.

  Args:
    gen: generator module; receives noise of shape (batch, zdim) and must
      produce 784 values per sample (the output is reshaped to 1x28x28).
    dis: discriminator module; receives flattened images and must return
      probabilities of shape (batch, 1), as required by ``nn.BCELoss``.
    lrg: learning rate for the generator optimizer.
    lrd: learning rate for the discriminator optimizer.
    optim_algo: ``'adam'`` selects Adam with betas (0.5, 0.99); any other
      value selects plain SGD.
    nepochs: number of passes over ``loader``.
    zdim: dimensionality of the Gaussian noise vector.
    device: device the models and all tensors are placed on.
    loader: iterable yielding ``(images, labels)`` batches (labels are
      ignored). Defaults to the notebook-level ``train_loader``.

  Returns:
    None. ``gen`` and ``dis`` are updated in place.
  """
  if loader is None:
    # Backward-compatible default: the DataLoader defined earlier in the notebook.
    loader = train_loader

  gen.to(device)
  dis.to(device)

  if optim_algo == 'adam':
    opg = optim.Adam(gen.parameters(), lr=lrg, betas=(0.5, 0.99))
    opd = optim.Adam(dis.parameters(), lr=lrd, betas=(0.5, 0.99))
  else:
    opg = optim.SGD(gen.parameters(), lr=lrg)
    opd = optim.SGD(dis.parameters(), lr=lrd)

  loss = nn.BCELoss()

  # Create the output directory once, up front, instead of checking every epoch.
  save_folder = "results"
  os.makedirs(save_folder, exist_ok=True)

  for ei in tqdm(range(nepochs)):
    for bx, _ in loader:
      # Size everything from the actual batch so a smaller final batch
      # (loader without drop_last) is handled correctly.
      n = bx.size(0)
      bx = bx.reshape(n, -1).to(device)
      ones = torch.ones(n, 1, device=device)    # target for "real"
      zeros = torch.zeros(n, 1, device=device)  # target for "fake"

      # ---- discriminator step: push dis(real) -> 1 and dis(fake) -> 0 ----
      # zero_grad also clears the discriminator gradients left behind by the
      # previous generator step's backward pass.
      opd.zero_grad()
      px = dis(bx)
      bz = torch.randn(n, zdim, device=device)
      with torch.no_grad():
        # No graph through the generator: only dis is updated here.
        bfx = gen(bz)
      pfx = dis(bfx)
      loss_d = loss(px, ones) + loss(pfx, zeros)
      loss_d.backward()
      opd.step()

      # ---- generator step: push dis(gen(z)) -> 1 ----
      opg.zero_grad()
      bz = torch.randn(n, zdim, device=device)
      pfx = dis(gen(bz))
      loss_g = loss(pfx, ones)
      loss_g.backward()
      opg.step()

    # ---- end of epoch: save a grid of generated samples ----
    with torch.no_grad():
      # Sample from the same standard-normal prior used during training
      # (the previous torch.rand drew uniform [0, 1) noise instead).
      bz = torch.randn(64, zdim, device=device)
      samples = gen(bz).reshape(64, 1, 28, 28).cpu()
    samples = denormalize(samples)
    torchvision.utils.save_image(samples, "%s/epoch_%02d.png" % (save_folder, ei))

In [47]:
# Launch training with the notebook-level models and hyper-parameters.
train_original_gan(
    gen=gen,
    dis=dis,
    lrg=lrg,
    lrd=lrd,
    optim_algo=optim_algo,
    nepochs=nepochs,
    zdim=zdim,
    device=device,
)

100%|██████████| 20/20 [05:07<00:00, 15.39s/it]


In [46]:
# Housekeeping: deletes the `results` directory that train_original_gan writes
# its per-epoch sample grids into.
# NOTE(review): the execution counts show this cell ran (In [46]) before the
# training call (In [47]); executed top-to-bottom it would remove the freshly
# saved images -- confirm the intended placement.
!rm -rf results