In [1]:
import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable as V


In [2]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.layer1 = nn.Linear(100,256)
        self.batchnorm1 = nn.BatchNorm1d(256, momentum=0.8)
        
        self.layer2 = nn.Linear(256,512)
        self.batchnorm2 = nn.BatchNorm1d(512, momentum=0.8)
        
        self.layer3 = nn.Linear(512,1024)
        self.batchnorm3 = nn.BatchNorm1d(1024, momentum=0.8)
        
        self.layer4 = nn.Linear(1024,28*28)
    
    def forward(self, x):
        x = F.leaky_relu(self.layer1(x))
        x = self.batchnorm1(x)
        x = self.batchnorm2(F.leaky_relu(self.layer2(x)))
        x = self.batchnorm3(F.leaky_relu(self.layer3(x)))
        x = F.tanh(self.layer4(x))
        x = x.reshape(-1,28,28)
        return x

     
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.layer1 = nn.Linear(28*28,512)
        self.layer2 = nn.Linear(512,256)
        self.layer3 = nn.Linear(256,1)
    
    def forward(self, x):
        x = x.reshape(-1, 28*28)
        x = F.leaky_relu(self.layer1(x))
        x = F.leaky_relu(self.layer2(x))
        x = F.sigmoid(self.layer3(x))
        return x

In [3]:
from keras.datasets import mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()

BATCH_SIZE = 32

torch_X_train = torch.from_numpy(X_train).type(torch.FloatTensor)
torch_y_train = torch.from_numpy(y_train).type(torch.LongTensor) # data type is long

# create feature and targets tensor for test set.
torch_X_test = torch.from_numpy(X_test).type(torch.FloatTensor)
torch_y_test = torch.from_numpy(y_test).type(torch.LongTensor) # data type is long

# Pytorch train and test sets
train = torch.utils.data.TensorDataset(torch_X_train,torch_y_train)
test = torch.utils.data.TensorDataset(torch_X_test,torch_y_test)
# train = torch.utils.data.TensorDataset(torch_X_train,torch_X_train)
# test = torch.utils.data.TensorDataset(torch_X_test,torch_X_test)

# data loader
train_loader = torch.utils.data.DataLoader(train, batch_size = BATCH_SIZE, shuffle = False)
test_loader = torch.utils.data.DataLoader(test, batch_size = BATCH_SIZE, shuffle = False)

Using TensorFlow backend.


In [104]:
generator = Generator()
discriminator = Discriminator()
loss_dis = nn.BCELoss()
loss_gen = nn.BCELoss()
optimizer_dis = optim.Adam(discriminator.parameters(), lr=0.001)
optimizer_gen = optim.Adam(generator.parameters(), lr=0.001)


for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in (enumerate(train_loader)):
        # get the inputs; data is a list of [inputs, labels]
        inputs_real, _ = data
        inputs_real = inputs_real.reshape(-1,28*28)
        inputs_real = V(inputs_real)
        inputs_noise = V(torch.randn(BATCH_SIZE,100))
        label_real = V(torch.ones(BATCH_SIZE,1))
        label_fake = V(torch.zeros(BATCH_SIZE,1))
        
        ## GENERATOR TRAIN
        optimizer_gen.zero_grad()
        image_gen = generator(inputs_noise)
        output_loss_dis = loss_dis(discriminator(image_gen), label_real)
        output_loss_dis.backward()
        optimizer_gen.step()

        ## DISCRIMINATOR TRAIN                              
        # zero the parameter gradients
        optimizer_dis.zero_grad()

        # forward + backward + optimize
        real_loss = loss_dis(discriminator(inputs_real), label_real)
        fake_loss = loss_dis(discriminator(image_gen.detach()), label_fake)
        d_loss = real_loss/2 +fake_loss/2
        d_loss.backward()
        optimizer_dis.step()
        
        # print statistics
        if i%100 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, 2, i, len(train_loader), d_loss.item(), output_loss_dis.item())
            )
print('Finished Training')

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

[Epoch 0/2] [Batch 0/1875] [D loss: 0.374473] [G loss: 0.643144]
[Epoch 0/2] [Batch 100/1875] [D loss: 0.074850] [G loss: 18.498960]
[Epoch 0/2] [Batch 200/1875] [D loss: 1.726944] [G loss: 28.591345]
[Epoch 0/2] [Batch 300/1875] [D loss: 0.431735] [G loss: 38.783363]
[Epoch 0/2] [Batch 400/1875] [D loss: 0.863469] [G loss: 38.346790]
[Epoch 0/2] [Batch 500/1875] [D loss: 6.044286] [G loss: 25.258535]
[Epoch 0/2] [Batch 600/1875] [D loss: 3.453878] [G loss: 20.723269]
[Epoch 0/2] [Batch 700/1875] [D loss: 3.242294] [G loss: 27.631029]
[Epoch 0/2] [Batch 800/1875] [D loss: 0.000000] [G loss: 27.631029]
[Epoch 0/2] [Batch 900/1875] [D loss: 0.000000] [G loss: 27.631029]
[Epoch 0/2] [Batch 1000/1875] [D loss: 0.000000] [G loss: 27.631029]
[Epoch 0/2] [Batch 1100/1875] [D loss: 6.554146] [G loss: 15.106862]
[Epoch 0/2] [Batch 1200/1875] [D loss: 4.317348] [G loss: 22.328619]
[Epoch 0/2] [Batch 1300/1875] [D loss: 0.000000] [G loss: 27.631029]
[Epoch 0/2] [Batch 1400/1875] [D loss: 0.863469

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

[Epoch 1/2] [Batch 0/1875] [D loss: 0.000000] [G loss: 27.631029]
[Epoch 1/2] [Batch 100/1875] [D loss: 0.000000] [G loss: 27.631029]
[Epoch 1/2] [Batch 200/1875] [D loss: 0.000000] [G loss: 27.631029]
[Epoch 1/2] [Batch 300/1875] [D loss: 0.000000] [G loss: 27.631029]


KeyboardInterrupt: 