This notebook trains a fully connected GAN on CIFAR-10 with PyTorch. The code comes from:  
https://medium.com/ai-society/gans-from-scratch-1-a-deep-introduction-with-code-in-pytorch-and-tensorflow-cb03cdcdba0f

In [1]:
!pip install tensorboardx



In [2]:
!git clone https://github.com/diegoalejogm/gans.git

fatal: destination path 'gans' already exists and is not an empty directory.


In [0]:
import torch
from torch import nn, optim
from torch.autograd.variable import Variable
from torchvision import transforms, datasets

In [0]:
def cifar10_data():
  """Return the CIFAR-10 training split, downloading it to ./dataset if needed.

  Every image is converted to a tensor and normalized per channel with
  mean 0.5 and std 0.5 (which maps ToTensor's [0, 1] range to [-1, 1]).
  """
  channel_mean = (.5, .5, .5)
  channel_std = (.5, .5, .5)
  preprocessing = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize(channel_mean, channel_std),
  ])
  return datasets.CIFAR10(
      root='./dataset',
      train=True,
      transform=preprocessing,
      download=True,
  )

In [0]:
# Build the CIFAR-10 training dataset (downloads on first use).
data = cifar10_data()
# Serve it in shuffled mini-batches of 100 images.
data_loader = torch.utils.data.DataLoader(data, batch_size=100, shuffle=True)
# Number of mini-batches per epoch; passed to the logger calls further down.
num_batches = len(data_loader)

In [0]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Turn on cuDNN benchmark mode.
# NOTE(review): the networks in this notebook are built from nn.Linear layers only,
# so this flag probably has no effect here - confirm before relying on it.
torch.backends.cudnn.benchmark = True
# Bare expression: shows the selected device as the cell output.
device

In [0]:
class DiscriminatorNet(torch.nn.Module):
  """Fully connected discriminator for flattened 3x32x32 images.

  Three hidden stages (Linear -> LeakyReLU(0.2) -> Dropout(0.3)) narrow the
  3072 input features to 1024, 512 and 256 units. The output stage is a
  Linear layer followed by a Sigmoid, giving one value in [0, 1] per sample.
  """

  def __init__(self):
    super().__init__()
    in_features = 3 * 32 * 32
    widths = (1024, 512, 256)

    # The attribute names hidden0..hidden2 and out are the state_dict prefixes,
    # and the layers are created in input-to-output order.
    self.hidden0 = self._hidden_stage(in_features, widths[0])
    self.hidden1 = self._hidden_stage(widths[0], widths[1])
    self.hidden2 = self._hidden_stage(widths[1], widths[2])
    self.out = nn.Sequential(
        nn.Linear(widths[2], 1),
        nn.Sigmoid(),
    )

  @staticmethod
  def _hidden_stage(n_in, n_out):
    """Build one hidden stage: Linear(n_in, n_out) -> LeakyReLU(0.2) -> Dropout(0.3)."""
    return nn.Sequential(
        nn.Linear(n_in, n_out),
        nn.LeakyReLU(0.2),
        nn.Dropout(0.3),
    )

  def forward(self, x):
    """Run x (last dimension 3072) through all four stages and return the result."""
    for stage in (self.hidden0, self.hidden1, self.hidden2, self.out):
      x = stage(x)
    return x

# Instantiate the discriminator and move its parameters to the selected device.
discriminator = DiscriminatorNet()
# As the last expression of the cell, this call's return value is also displayed.
discriminator.to(device)

In [0]:
def images_to_vectors(images):
  """Flatten a batch of 3x32x32 images into rows of 3072 values.

  Returns a view of `images` with shape (batch, 3072); no data is copied.
  """
  batch_size = images.shape[0]
  pixels_per_image = 3 * 32 * 32
  return images.view(batch_size, pixels_per_image)

def vectors_to_images(vectors):
  """Reshape rows of 3072 values back into a batch of 3x32x32 images.

  Returns a view of `vectors` with shape (batch, 3, 32, 32); no data is copied.
  """
  batch_size = vectors.shape[0]
  image_shape = (3, 32, 32)
  return vectors.view(batch_size, *image_shape)

In [0]:
class GeneratorNet(nn.Module):
  """Fully connected generator mapping 100-d latent vectors to flattened images.

  Three hidden stages (Linear -> LeakyReLU(0.2)) widen the 100 input features
  to 256, 512 and 1024 units. The output stage is a Linear layer to
  3*32*32 = 3072 values followed by Tanh, so outputs lie in [-1, 1].
  """

  def __init__(self):
    super().__init__()
    latent_dim = 100
    image_values = 3 * 32 * 32
    widths = (256, 512, 1024)

    # The attribute names hidden0..hidden2 and out are the state_dict prefixes,
    # and the layers are created in input-to-output order.
    self.hidden0 = self._hidden_stage(latent_dim, widths[0])
    self.hidden1 = self._hidden_stage(widths[0], widths[1])
    self.hidden2 = self._hidden_stage(widths[1], widths[2])
    self.out = nn.Sequential(
        nn.Linear(widths[2], image_values),
        nn.Tanh(),
    )

  @staticmethod
  def _hidden_stage(n_in, n_out):
    """Build one hidden stage: Linear(n_in, n_out) -> LeakyReLU(0.2)."""
    return nn.Sequential(
        nn.Linear(n_in, n_out),
        nn.LeakyReLU(0.2),
    )

  def forward(self, x):
    """Run x (last dimension 100) through all four stages and return the result."""
    for stage in (self.hidden0, self.hidden1, self.hidden2, self.out):
      x = stage(x)
    return x

# Instantiate the generator and move its parameters to the selected device.
generator = GeneratorNet()
# As the last expression of the cell, this call's return value is also displayed.
generator.to(device)

In [0]:
def noise(size):
  """Sample a batch of latent vectors for the generator.

  Args:
    size: number of latent vectors (batch size).

  Returns:
    A float tensor of shape (size, 100) drawn from a standard normal
    distribution, allocated on the notebook-level `device`.
  """
  # The shape must be (size, 100): GeneratorNet starts with nn.Linear(100, 256),
  # which consumes the LAST dimension. The previous (size, 100, 1, 1) shape had
  # a last dimension of 1 and made the generator's forward pass fail.
  # The deprecated Variable wrapper is not needed; tensors carry autograd state.
  return torch.randn(size, 100, device=device)

In [0]:
# One Adam optimizer per network, both with learning rate 2e-4.
# Each optimizer only updates the parameters of its own network.
d_optimizer = optim.Adam(discriminator.parameters(),lr = 0.0002)
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002)

In [0]:
# Binary cross-entropy, applied to the discriminator's sigmoid output
# against the all-ones / all-zeros targets built by ones_target / zeros_target.
loss = nn.BCELoss()

In [0]:
def ones_target(size):
  """Return a (size, 1) float tensor of ones on the notebook-level `device`.

  Used as the BCE target for samples the discriminator should score as real.
  """
  # Allocate directly on the target device instead of building on the CPU and
  # copying. The deprecated Variable wrapper is dropped: it just returns a tensor.
  return torch.ones(size, 1, device=device)

def zeros_target(size):
  """Return a (size, 1) float tensor of zeros on the notebook-level `device`.

  Used as the BCE target for samples the discriminator should score as fake.
  """
  # Allocate directly on the target device instead of building on the CPU and
  # copying. The deprecated Variable wrapper is dropped: it just returns a tensor.
  return torch.zeros(size, 1, device=device)


In [0]:
def train_discriminator(optimizer, real_data, fake_data):
  """Run one discriminator update on a real batch and a fake batch.

  Gradients from both batches are accumulated before a single optimizer step.
  Targets for both batches are sized from the real batch (real_data.size(0)).

  Args:
    optimizer: optimizer to zero and step.
    real_data: batch the discriminator should score as real (target 1).
    fake_data: batch the discriminator should score as fake (target 0).

  Returns:
    (real loss + fake loss, predictions on real_data, predictions on fake_data)
  """
  batch_size = real_data.size(0)

  def backprop(batch, make_target):
    # Forward pass, BCE against the requested target, then accumulate gradients.
    scores = discriminator(batch)
    batch_loss = loss(scores, make_target(batch_size))
    batch_loss.backward()
    return batch_loss, scores

  optimizer.zero_grad()
  real_loss, real_scores = backprop(real_data, ones_target)
  fake_loss, fake_scores = backprop(fake_data, zeros_target)
  optimizer.step()

  return real_loss + fake_loss, real_scores, fake_scores


In [0]:
def train_generator(optimizer, fake_data):
  """Run one generator update.

  The discriminator scores `fake_data` and the loss is BCE against an all-ones
  target. Only `optimizer` is zeroed and stepped here.

  Args:
    optimizer: optimizer to zero and step.
    fake_data: generated batch; it must still be attached to the autograd
      graph for the step to change anything.

  Returns:
    The loss tensor for this step.
  """
  batch_size = fake_data.size(0)

  optimizer.zero_grad()
  verdict = discriminator(fake_data)
  g_loss = loss(verdict, ones_target(batch_size))
  g_loss.backward()
  optimizer.step()

  return g_loss

In [0]:
# Latent batch sampled once here and reused for every image log below,
# so logged samples from different epochs come from the same inputs.
num_test_samples = 16
test_noise = noise(num_test_samples)

In [0]:
from gans.utils import Logger

In [0]:
# Logger comes from the cloned `gans` repository (gans.utils).
# NOTE(review): its exact output (console / TensorBoard / files) is not visible here.
logger = Logger(model_name='VGAN', data_name='CIFAR10')

num_epochs =1000

for epoch in range(1, num_epochs+1):
  # Class labels from the dataset are ignored (`_`).
  for n_batch, (real_batch, _) in enumerate(data_loader):
    # Actual size of this batch.
    N = real_batch.size(0)

    # Flatten the real images to (N, 3072) rows for the fully connected discriminator.
    real_data = Variable(images_to_vectors(real_batch))

    # Fakes for the discriminator step are detached, so this step
    # does not backpropagate into the generator.
    fake_data = generator(noise(N)).detach()
    
    real_data = real_data.to(device)
    fake_data = fake_data.to(device)

    # 1. Discriminator update on the real batch and the detached fake batch.
    d_error, d_pred_real, d_pred_fake = train_discriminator(d_optimizer, real_data, fake_data)

    # 2. Generator update on a fresh fake batch that stays attached to the graph,
    # so the loss gradient reaches the generator's parameters.
    fake_data = generator(noise(N))
    fake_data = fake_data.to(device)

    g_error = train_generator(g_optimizer, fake_data)

  # Logging runs once every 10 epochs, after the inner loop has finished, so
  # n_batch, d_error, g_error and the predictions all come from the epoch's last batch.
  if epoch % 10 == 0:
    # Generate images from the fixed test noise and move them to the CPU for logging.
    test_images = vectors_to_images(generator(test_noise))
    test_images = test_images.data.cpu()

    logger.display_status(
    epoch, num_epochs, n_batch, num_batches,
    d_error, g_error, d_pred_real, d_pred_fake
    )

    logger.log_images(
        test_images, num_test_samples, 
        epoch, n_batch, num_batches
    )

In [0]:
# Debug check: display the shape of a latent batch.
# N is the global left over from the last iteration of the training loop,
# so this cell only works after that loop has run.
noise(N).size()

In [0]:
# Generate and log one more set of samples from the fixed test noise.
# Relies on `epoch` and `n_batch` left over from the training loop,
# so this cell must run after that loop.
test_images = vectors_to_images(generator(test_noise))
test_images = test_images.data.cpu()

logger.log_images(
    test_images, num_test_samples, 
    epoch, n_batch, num_batches
)