In [0]:
import torch
from torch import optim
from torch import nn as nn
from torch.autograd.variable import Variable
from torchvision import transforms,datasets


In [2]:
def mnist_data(out_dir='./dataset', train=False):
    """Return the MNIST dataset with pixel values scaled to [-1, 1].

    Args:
        out_dir: Directory where torchvision downloads / caches the raw files.
        train: Which split to load. Defaults to ``False`` (the test split),
            preserving this notebook's original behavior.
            NOTE(review): GANs are usually trained on the larger training
            split (``train=True``); confirm the test split is intentional.

    Returns:
        A ``torchvision.datasets.MNIST`` instance. ``download=True`` fetches
        the files on first use.
    """
    compose = transforms.Compose([
        # PIL image -> float tensor in [0, 1].
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1], the range of the generator's Tanh output.
        transforms.Normalize((.5,), (.5,)),
    ])
    return datasets.MNIST(root=out_dir, train=train, transform=compose, download=True)
# Images per training batch (also the number of fake samples generated per step).
BATCH_SIZE = 100

data = mnist_data()

# shuffle=True so every epoch visits the batches in a different order.
data_loader = torch.utils.data.DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)

# Batches per epoch; passed to the logger for progress reporting.
num_batches = len(data_loader)


  0%|          | 0/9912422 [00:00<?, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./dataset/MNIST/raw/train-images-idx3-ubyte.gz


100%|█████████▉| 9863168/9912422 [00:16<00:00, 554385.36it/s]

Extracting ./dataset/MNIST/raw/train-images-idx3-ubyte.gz to ./dataset/MNIST/raw



0it [00:00, ?it/s][A
32768it [00:00, 326902.47it/s][A
0it [00:00, ?it/s][A

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./dataset/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./dataset/MNIST/raw/train-labels-idx1-ubyte.gz to ./dataset/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./dataset/MNIST/raw/t10k-images-idx3-ubyte.gz



  1%|          | 16384/1648877 [00:00<00:12, 125936.89it/s][A
  5%|▍         | 81920/1648877 [00:00<00:09, 163139.03it/s][A
  6%|▋         | 106496/1648877 [00:00<00:08, 175685.01it/s][A
 10%|▉         | 163840/1648877 [00:00<00:06, 217692.19it/s][A
 12%|█▏        | 196608/1648877 [00:00<00:06, 238598.21it/s][A
 15%|█▍        | 245760/1648877 [00:00<00:05, 274929.09it/s][A
 19%|█▉        | 311296/1648877 [00:00<00:04, 319704.73it/s][A
 23%|██▎       | 376832/1648877 [00:00<00:03, 375875.32it/s][A
 27%|██▋       | 442368/1648877 [00:01<00:02, 430004.05it/s][A
 32%|███▏      | 532480/1648877 [00:01<00:02, 499279.60it/s][A
 38%|███▊      | 630784/1648877 [00:01<00:01, 579058.57it/s][A
 44%|████▎     | 720896/1648877 [00:01<00:01, 645310.08it/s][A
 50%|█████     | 827392/1648877 [00:01<00:01, 729470.07it/s][A
 57%|█████▋    | 933888/1648877 [00:01<00:00, 805568.78it/s][A
 64%|██████▎   | 1048576/1648877 [00:01<00:00, 848554.83it/s][A
 70%|██████▉   | 1146880/1648877 [00:01<

Extracting ./dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to ./dataset/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./dataset/MNIST/raw
Processing...
Done!


In [0]:
class Discriminator(nn.Module):
    """Three-hidden-layer MLP that scores flattened images as real or fake.

    Maps a ``(batch, n_features)`` tensor to a ``(batch, n_out)`` tensor of
    probabilities in ``[0, 1]`` (final ``Sigmoid``), as expected by ``nn.BCELoss``.

    Args:
        n_features: Width of each input vector (default 784, a flattened 28x28 image).
        n_out: Number of output units (default 1: a single real/fake score).
        negative_slope: Slope of the LeakyReLU activations.
        dropout: Dropout probability applied after every hidden layer.
    """

    def __init__(self, n_features=784, n_out=1, negative_slope=0.2, dropout=0.3):
        super().__init__()

        def _block(in_dim, out_dim):
            # Linear -> LeakyReLU -> Dropout, the same layout for every hidden layer.
            return nn.Sequential(
                nn.Linear(in_dim, out_dim),
                nn.LeakyReLU(negative_slope),
                nn.Dropout(dropout),
            )

        # Attribute names (and therefore state_dict keys) are kept unchanged
        # so previously saved weights still load.
        self.hidden0 = _block(n_features, 1024)
        self.hidden1 = _block(1024, 512)
        self.hidden2 = _block(512, 256)
        self.out = nn.Sequential(
            nn.Linear(256, n_out),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the probability that each row of ``x`` is a real sample."""
        for layer in (self.hidden0, self.hidden1, self.hidden2, self.out):
            x = layer(x)
        return x


discriminator = Discriminator()

In [0]:
def images_to_vector(images):
    """Flatten a batch of images into one row vector per sample.

    ``(N, C, H, W)`` -> ``(N, C*H*W)``; for MNIST ``(N, 1, 28, 28)`` -> ``(N, 784)``.
    ``flatten`` (unlike ``view``) also accepts non-contiguous inputs, copying
    only when it must, and handles an empty batch.
    """
    return images.flatten(start_dim=1)


def vectors_to_images(vector, channels=1, height=28, width=28):
    """Reshape a batch of flat vectors back into ``(N, channels, height, width)`` images.

    Defaults match single-channel 28x28 MNIST images.
    """
    return vector.reshape(vector.size(0), channels, height, width)

In [6]:
class Generator(nn.Module):
    """Three-hidden-layer MLP mapping latent noise vectors to flattened images.

    Maps a ``(batch, n_features)`` noise tensor to a ``(batch, n_out)`` tensor
    in ``[-1, 1]`` (final ``Tanh``), the same range as the normalized data.

    Args:
        n_features: Latent dimension (default 100).
        n_out: Width of each generated vector (default 784, a flattened 28x28 image).
        negative_slope: Slope of the LeakyReLU activations.
    """

    def __init__(self, n_features=100, n_out=784, negative_slope=0.2):
        super().__init__()

        def _block(in_dim, out_dim):
            # Linear -> LeakyReLU, the same layout for every hidden layer.
            return nn.Sequential(
                nn.Linear(in_dim, out_dim),
                nn.LeakyReLU(negative_slope),
            )

        # Attribute names (and therefore state_dict keys) are kept unchanged
        # so previously saved weights still load.
        self.hidden0 = _block(n_features, 256)
        self.hidden1 = _block(256, 512)
        self.hidden2 = _block(512, 1024)
        self.out = nn.Sequential(
            nn.Linear(1024, n_out),
            nn.Tanh(),
        )

    def forward(self, x):
        """Generate one flattened sample per row of latent noise ``x``."""
        for layer in (self.hidden0, self.hidden1, self.hidden2, self.out):
            x = layer(x)
        return x


generator = Generator()

9920512it [00:29, 554385.36it/s]                             

In [0]:
def noise(size, latent_dim=100):
    """Sample a ``(size, latent_dim)`` batch of standard-normal latent vectors.

    ``latent_dim`` defaults to 100, the input width of ``Generator``.
    The deprecated ``Variable`` wrapper is unnecessary: plain tensors are
    autograd-aware, and the sampled noise does not require grad.
    """
    return torch.randn(size, latent_dim)

In [0]:
# Both networks use Adam with one shared learning rate (Adam's default betas).
LEARNING_RATE = 0.0002
d_optimizer = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE)
g_optimizer = optim.Adam(generator.parameters(), lr=LEARNING_RATE)


In [0]:
# Binary cross-entropy on the discriminator's [0, 1] outputs, averaged over the batch.
loss = nn.BCELoss(reduction='mean')

In [0]:
def ones_target(size, device=None):
    """Return a ``(size, 1)`` tensor of ones: BCE targets labelling samples "real".

    ``device`` optionally places the targets next to the model's outputs.
    """
    return torch.ones(size, 1, device=device)


def zeros_target(size, device=None):
    """Return a ``(size, 1)`` tensor of zeros: BCE targets labelling samples "fake".

    ``device`` optionally places the targets next to the model's outputs.
    """
    return torch.zeros(size, 1, device=device)

In [0]:
def train_discriminator(optimizer, real_data, fake_data):
    """Run one optimization step of the (global) discriminator.

    Args:
        optimizer: Optimizer over the discriminator's parameters.
        real_data: Batch of real, flattened images; labelled 1.
        fake_data: Batch of generated, flattened images; labelled 0. It is
            detached here so no gradient flows back into the generator, even
            if the caller forgot to detach it.

    Returns:
        ``(total_error, prediction_real, prediction_fake)``: the sum of both
        BCE losses plus the discriminator's outputs on each batch.
    """
    optimizer.zero_grad()

    # Real batch: push predictions toward 1.
    prediction_real = discriminator(real_data)
    error_real = loss(prediction_real, ones_target(real_data.size(0)))
    error_real.backward()

    # Fake batch: push predictions toward 0. Targets are sized from the fake
    # batch itself, so real and fake batches of different sizes still work.
    fake_data = fake_data.detach()
    prediction_fake = discriminator(fake_data)
    error_fake = loss(prediction_fake, zeros_target(fake_data.size(0)))
    error_fake.backward()

    # Gradients from both backward() calls have accumulated; apply them once.
    optimizer.step()

    return error_real + error_fake, prediction_real, prediction_fake
    

In [0]:
def train_genrator(optimizer, fake_data):
    """Run one optimization step of the generator.

    The generator is trained to fool the (global) discriminator: its samples
    are scored against "real" (ones) targets, so minimizing the BCE pushes
    the discriminator's predictions on them toward 1.

    Args:
        optimizer: Optimizer over the generator's parameters.
        fake_data: Generated batch. It must NOT be detached, because gradients
            have to flow back through the discriminator into the generator.

    Returns:
        The generator's BCE loss for this batch.
    """
    optimizer.zero_grad()

    prediction = discriminator(fake_data)
    # Compute the loss against "real" labels.
    error = loss(prediction, ones_target(fake_data.size(0)))
    # Compute gradients.
    error.backward()
    # Update weights. Only the generator's parameters are stepped. The
    # discriminator's .grad buffers also receive gradients here, but
    # train_discriminator clears them with zero_grad() before its own step.
    optimizer.step()
    return error


# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
train_generator = train_genrator

In [0]:
# A fixed latent batch, reused at every logging step so that generated
# samples are comparable as training progresses.
num_test_samples = 16
test_noise = noise(size=num_test_samples)

In [15]:
!pip install tensorboardX

Collecting tensorboardX
[?25l  Downloading https://files.pythonhosted.org/packages/a6/5c/e918d9f190baab8d55bad52840d8091dd5114cc99f03eaa6d72d404503cc/tensorboardX-1.9-py2.py3-none-any.whl (190kB)
[K     |████████████████████████████████| 194kB 3.5MB/s 
Installing collected packages: tensorboardX
Successfully installed tensorboardX-1.9


In [0]:
from utils import Logger

In [17]:
# Training loop: one discriminator step, then one generator step, per batch.
logger = Logger(model_name='VGAN', data_name='MNIST')

num_epochs = 200
log_every = 100  # batches between sample snapshots / status lines

for epoch in range(num_epochs):
    for n_batch, (real_batch, _) in enumerate(data_loader):
        N = real_batch.size(0)

        # 1. Train the discriminator.
        # Plain tensors are autograd-aware; no Variable wrapper is needed.
        real_data = images_to_vector(real_batch)
        # Detach the fake batch so the discriminator step does not backprop
        # into the generator.
        fake_data = generator(noise(N)).detach()
        d_error, d_pred_real, d_pred_fake = train_discriminator(d_optimizer, real_data, fake_data)

        # 2. Train the generator on a fresh, non-detached batch so gradients
        # reach its parameters.
        fake_data = generator(noise(N))
        g_error = train_genrator(g_optimizer, fake_data)

        # Log batch error
        logger.log(d_error, g_error, epoch, n_batch, num_batches)

        # Every `log_every` batches, snapshot samples from the fixed test
        # noise and display status.
        if n_batch % log_every == 0:
            # Sampling only: no_grad avoids building an autograd graph.
            with torch.no_grad():
                test_images = vectors_to_images(generator(test_noise))
            logger.log_images(
                test_images, num_test_samples,
                epoch, n_batch, num_batches
            )
            logger.display_status(
                epoch, num_epochs, n_batch, num_batches,
                d_error, g_error, d_pred_real, d_pred_fake
            )

Output hidden; open in https://colab.research.google.com to view.