In [0]:
%load_ext autoreload
%matplotlib inline

#install TensorFlow 2.0 
!pip install tensorboardX

# Load the TensorBoard notebook extension
%load_ext tensorboard



In [0]:
%autoreload 2

from IPython import display
# Logger is imported from a project-local ``utils`` module (not shown here); its
# log / log_images / display_status / save_models methods are used in training.
# NOTE(review): a PyPI package named ``utils`` can shadow the local module —
# uninstall it (pip uninstall utils) if this import picks up the wrong one.
from utils import Logger

import torch
from torch import nn, optim
# Variable is a deprecated no-op wrapper in modern PyTorch; kept only because
# some cells below still call it.
from torch.autograd.variable import Variable
from torchvision import transforms, datasets
import numpy as np
# True when a CUDA device is visible; used later to pick the tensor-type aliases.
cuda=torch.cuda.is_available()

In [0]:
# Root directory for this notebook's data; mnist_data() downloads MNIST into
# '<DATA_FOLDER>/dataset'.
DATA_FOLDER = './torch_data/VGAN/MNIST'

## Load Data

In [0]:
def mnist_data():
    """Return the MNIST training split, downloading it under DATA_FOLDER if needed.

    Each image is converted to a tensor and normalized with mean 0.5 / std 0.5,
    so pixel values land in [-1, 1] — the same range as the generator's Tanh output.
    """
    pipeline = [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
    root_dir = f'{DATA_FOLDER}/dataset'
    return datasets.MNIST(
        root=root_dir,
        train=True,
        transform=transforms.Compose(pipeline),
        download=True,
    )

In [0]:
# Number of label classes the conditional networks embed (MNIST digits 0-9).
n_classes = 10

# MNIST training split wrapped in a shuffling loader of 100-image batches.
data = mnist_data()
data_loader = torch.utils.data.DataLoader(dataset=data, batch_size=100, shuffle=True)

# Batches per epoch; passed to the logger in the training loop.
num_batches = len(data_loader)

## Networks

In [0]:
class DiscriminatorNet(torch.nn.Module):
    """
    Conditional discriminator: a three hidden-layer MLP that scores
    (flattened image, class label) pairs.

    Each integer label is looked up in an ``nn.Embedding`` that yields a
    vector of length ``num_classes``. That vector is concatenated in front of
    the image vector. The result passes through three LeakyReLU(0.2) +
    Dropout(0.3) hidden layers (1024 -> 512 -> 256 units) and a single
    sigmoid output unit.

    Args:
        num_classes: number of label classes. When omitted, the notebook-level
            ``n_classes`` global is used (keeps ``DiscriminatorNet()`` working).
        n_features: length of each flattened image (784 = 28 * 28 for MNIST).
    """
    def __init__(self, num_classes=None, n_features=784):
        super(DiscriminatorNet, self).__init__()
        if num_classes is None:
            num_classes = n_classes
        n_out = 1

        # Registered first so parameter-initialisation order and state_dict
        # layout stay the same as before.
        self.label_embedding = nn.Embedding(num_classes, num_classes)

        self.hidden0 = nn.Sequential(
            nn.Linear(n_features + num_classes, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.hidden1 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.out = nn.Sequential(
            torch.nn.Linear(256, n_out),
            torch.nn.Sigmoid()
        )

    def forward(self, x, labels):
        """
        Score image/label pairs.

        Args:
            x: flattened images, presumably (batch, n_features); the label
                embedding is concatenated along the last dim.
            labels: integer class indices in [0, num_classes), one per row of x.
        Returns:
            Sigmoid scores in [0, 1] that each pair is real; (batch, 1) for 2-D x.
        """
        inp = torch.cat((self.label_embedding(labels), x), -1)
        x = self.hidden0(inp)
        x = self.hidden1(x)
        x = self.hidden2(x)
        return self.out(x)
    
def images_to_vectors(images):
    """
    Flatten a batch of images into one row vector per sample.

    Args:
        images: tensor whose first dim is the batch, e.g. (batch, 1, 28, 28)
            for MNIST.
    Returns:
        Tensor of shape (batch, C*H*W) — (batch, 784) for MNIST.
    """
    # flatten(1) handles any image size, empty batches and non-contiguous
    # inputs, whereas view(batch, 784) only works for contiguous 28x28 data.
    return images.flatten(start_dim=1)

def vectors_to_images(vectors, channels=1, height=28, width=28):
    """
    Reshape a batch of flat vectors back into image tensors.

    Args:
        vectors: tensor whose first dim is the batch, with
            channels * height * width elements per sample.
        channels, height, width: target image shape (defaults: 1x28x28 MNIST).
    Returns:
        Tensor of shape (batch, channels, height, width).
    """
    # reshape returns a view when possible and copies otherwise, so it never
    # fails just because the input is non-contiguous.
    return vectors.reshape(vectors.size(0), channels, height, width)

In [0]:
class GeneratorNet(torch.nn.Module):
    """
    Conditional generator: a three hidden-layer MLP mapping
    (noise vector, class label) to a flattened image.

    Each integer label is looked up in an ``nn.Embedding`` that yields a
    vector of length ``num_classes``. That vector is concatenated in front of
    the noise vector. The result passes through three LeakyReLU(0.2) hidden
    layers (256 -> 512 -> 1024 units) and a Tanh output layer, so pixels lie
    in [-1, 1].

    Args:
        num_classes: number of label classes. When omitted, the notebook-level
            ``n_classes`` global is used (keeps ``GeneratorNet()`` working).
        n_features: length of each input noise vector (100 by default).
        n_out: length of each flattened output image (784 = 28 * 28 for MNIST).
    """
    def __init__(self, num_classes=None, n_features=100, n_out=784):
        super(GeneratorNet, self).__init__()
        if num_classes is None:
            num_classes = n_classes

        # Registered first so parameter-initialisation order and state_dict
        # layout stay the same as before.
        self.label_emb = nn.Embedding(num_classes, num_classes)

        self.hidden0 = nn.Sequential(
            nn.Linear(num_classes + n_features, 256),
            nn.LeakyReLU(0.2)
        )
        self.hidden1 = nn.Sequential(
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2)
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2)
        )

        self.out = nn.Sequential(
            nn.Linear(1024, n_out),
            nn.Tanh()
        )

    def forward(self, x, labels):
        """
        Generate images conditioned on labels.

        Args:
            x: noise vectors, presumably (batch, n_features); the label
                embedding is concatenated along the last dim.
            labels: integer class indices in [0, num_classes), one per row of x.
        Returns:
            Flattened images in [-1, 1]; (batch, n_out) for 2-D x.
        """
        inp = torch.cat((self.label_emb(labels), x), -1)
        x = self.hidden0(inp)
        x = self.hidden1(x)
        x = self.hidden2(x)
        return self.out(x)
    
# Noise
def noise(size, latent_dim=100):
    """
    Sample a batch of standard-normal latent vectors for the generator.

    Args:
        size: number of latent vectors (batch size).
        latent_dim: length of each vector; must equal the generator's noise
            input size (100 in GeneratorNet).
    Returns:
        Float tensor of shape (size, latent_dim), on the GPU when CUDA is
        available.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Sample directly on the target device rather than on the CPU followed by a
    # copy. The old Variable wrapper was a no-op and is no longer needed.
    return torch.randn(size, latent_dim, device=device)

In [0]:
# Instantiate the discriminator, then the generator, and move both to the GPU
# when one is available (cuda is set in the imports cell).
discriminator, generator = DiscriminatorNet(), GeneratorNet()
if cuda:
    discriminator.cuda()
    generator.cuda()

## Optimization

In [0]:
# Both networks use Adam with the same learning rate.
LEARNING_RATE = 0.0001
d_optimizer = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE)
g_optimizer = optim.Adam(generator.parameters(), lr=LEARNING_RATE)

# Binary cross-entropy on the discriminator's sigmoid output.
loss = nn.BCELoss()

# NOTE(review): d_steps is not referenced anywhere in the visible notebook; the
# training loop always takes exactly one discriminator step per batch.
d_steps = 1
# Passes over the full training set.
num_epochs = 200

## Training

In [0]:
def _target_tensor(size, fill_value):
    """Return a (size, 1) float tensor filled with fill_value, on the GPU when available."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.full((size, 1), fill_value, device=device)

def real_data_target(size):
    '''
    BCE targets for samples that should be classified as real.

    Args:
        size: batch size.
    Returns:
        Tensor of ones with shape (size, 1), matching the discriminator's
        single-unit output, on the GPU when CUDA is available.
    '''
    return _target_tensor(size, 1.0)

def fake_data_target(size):
    '''
    BCE targets for samples that should be classified as fake.

    Args:
        size: batch size.
    Returns:
        Tensor of zeros with shape (size, 1), matching the discriminator's
        single-unit output, on the GPU when CUDA is available.
    '''
    return _target_tensor(size, 0.0)

In [0]:
def train_discriminator(optimizer, real_data, fake_data,real_labels,fake_labels):
    """
    Run one optimisation step of the global ``discriminator``.

    Real pairs are pushed towards target 1 and fake pairs towards target 0
    using the global ``loss`` criterion. Gradients from both halves are
    accumulated before a single ``optimizer.step()``.

    Args:
        optimizer: optimiser over the discriminator's parameters.
        real_data: flattened real images.
        fake_data: flattened generated images; should be detached from the
            generator's graph so this step does not backpropagate into it.
        real_labels: class indices for each row of real_data.
        fake_labels: class indices for each row of fake_data.
    Returns:
        (real + fake error, predictions on real data, predictions on fake data).
    """
    optimizer.zero_grad()

    # 1.1 Real data -> target 1
    prediction_real = discriminator(real_data, real_labels)
    error_real = loss(prediction_real, real_data_target(real_data.size(0)))
    error_real.backward()

    # 1.2 Fake data -> target 0. The target is sized from the fake batch itself,
    # so real and fake batches may differ in size.
    prediction_fake = discriminator(fake_data, fake_labels)
    error_fake = loss(prediction_fake, fake_data_target(fake_data.size(0)))
    error_fake.backward()

    # 1.3 Update weights with the accumulated gradients
    optimizer.step()

    return error_real + error_fake, prediction_real, prediction_fake

def train_generator(optimizer, fake_data,fake_labels):
    """
    Run one optimisation step of the generator.

    The global ``discriminator`` scores the already-generated batch. The
    generator is trained so that the discriminator labels it "real"
    (target 1), which is the non-saturating GAN generator loss.

    Args:
        optimizer: optimiser over the generator's parameters.
        fake_data: generator output that is still attached to the generator's
            autograd graph (not detached), so gradients can reach it.
        fake_labels: class indices the batch was conditioned on.
    Returns:
        The generator loss tensor.
    """
    optimizer.zero_grad()
    # Score the batch passed in; no noise is sampled here.
    prediction = discriminator(fake_data, fake_labels)
    # Generated samples get the "real" target.
    error = loss(prediction, real_data_target(prediction.size(0)))
    # backward() also writes gradients into the discriminator's parameters.
    # They are not stepped here, and train_discriminator clears them with
    # zero_grad() before its own update.
    error.backward()
    optimizer.step()
    return error

### Fixed noise and labels for monitoring generated samples

In [0]:
# Legacy tensor-type aliases that follow the runtime device: CUDA types when a
# GPU is available (cuda flag from the imports cell), CPU types otherwise.
if cuda:
    FloatTensor, LongTensor = torch.cuda.FloatTensor, torch.cuda.LongTensor
else:
    FloatTensor, LongTensor = torch.FloatTensor, torch.LongTensor

In [0]:
# Fixed noise and labels, reused at every progress display so successive
# sample grids are directly comparable.
n_row = 10
# Derived from n_row so the logged image count always matches the noise/labels.
num_test_samples = n_row ** 2
# Labels index the networks' class embeddings, so they must stay < n_classes.
assert n_row <= n_classes, "n_row must not exceed n_classes"

test_noise = noise(num_test_samples)
# Label sequence 0, 1, ..., n_row-1 repeated n_row times.
test_labels = LongTensor(np.tile(np.arange(n_row), n_row))
test_labels.size()

torch.Size([100])

### Start training

In [0]:
logger = Logger(model_name='VGAN', data_name='MNIST')

for epoch in range(num_epochs):
    for n_batch, (real_batch, label_batch) in enumerate(data_loader):
        batch_size = real_batch.size(0)

        # 1. Train Discriminator
        real_data = images_to_vectors(real_batch)
        # Use the device-aware LongTensor alias. Hard-coding torch.cuda.LongTensor
        # crashes on CPU-only runtimes.
        real_labels = label_batch.type(LongTensor)
        if torch.cuda.is_available(): real_data = real_data.cuda()

        # Fake batch for D, detached so D's update does not backpropagate into G.
        z = noise(batch_size)
        fake_labels = LongTensor(np.random.randint(0, n_classes, batch_size))
        fake_data = generator(z, fake_labels).detach()
        d_error, d_pred_real, d_pred_fake = train_discriminator(
            d_optimizer, real_data, fake_data, real_labels, fake_labels)

        # 2. Train Generator on a fresh, non-detached fake batch so gradients reach G.
        z = noise(batch_size)
        fake_labels = LongTensor(np.random.randint(0, n_classes, batch_size))
        fake_data = generator(z, fake_labels)
        g_error = train_generator(g_optimizer, fake_data, fake_labels)

        # Log error
        logger.log(d_error, g_error, epoch, n_batch, num_batches)

        # Display progress every 100 batches
        if n_batch % 100 == 0:
            display.clear_output(True)
            # Sampling for display needs no autograd graph.
            with torch.no_grad():
                test_images = vectors_to_images(generator(test_noise, test_labels)).cpu()
            logger.log_images(test_images, num_test_samples, epoch, n_batch, num_batches)
            logger.display_status(
                epoch, num_epochs, n_batch, num_batches,
                d_error, g_error, d_pred_real, d_pred_fake
            )

    # Model checkpoint once per epoch rather than after every batch. The saved
    # state matches the epoch's last batch either way.
    logger.save_models(generator, discriminator, epoch)

Buffered data was truncated after reaching the output size limit.

In [0]:
# Colab only: zip the generated sample images and the run logs, then download
# both archives to the local machine via google.colab.files.
# NOTE(review): paths assume the Logger writes images under
# /content/data/images/VGAN/MNIST and (presumably tensorboardX) event files
# under /content/runs — confirm against utils.Logger.
!zip -r /content/imgcgan.zip /content/data/images/VGAN/MNIST
!zip -r /content/runcgan.zip /content/runs
from google.colab import files
files.download('/content/imgcgan.zip')
files.download('/content/runcgan.zip')

  adding: content/data/images/VGAN/MNIST/ (stored 0%)
  adding: content/data/images/VGAN/MNIST/hori_epoch_99_batch_100.png (deflated 17%)
  adding: content/data/images/VGAN/MNIST/_epoch_118_batch_500.png (deflated 8%)
  adding: content/data/images/VGAN/MNIST/_epoch_100_batch_400.png (deflated 8%)
  adding: content/data/images/VGAN/MNIST/hori_epoch_123_batch_500.png (deflated 16%)
  adding: content/data/images/VGAN/MNIST/_epoch_55_batch_500.png (deflated 8%)
  adding: content/data/images/VGAN/MNIST/_epoch_28_batch_0.png (deflated 7%)
  adding: content/data/images/VGAN/MNIST/hori_epoch_128_batch_500.png (deflated 16%)
  adding: content/data/images/VGAN/MNIST/hori_epoch_86_batch_400.png (deflated 16%)
  adding: content/data/images/VGAN/MNIST/_epoch_5_batch_400.png (deflated 9%)
  adding: content/data/images/VGAN/MNIST/hori_epoch_123_batch_400.png (deflated 15%)
  adding: content/data/images/VGAN/MNIST/hori_epoch_55_batch_200.png (deflated 17%)
  adding: content/data/images/VGAN/MNIST/hori