In [None]:
from IPython import display
from utils import Logger
import torch
from torch import nn, optim
from torch.autograd.variable import Variable
from torchvision import transforms, datasets
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
from matplotlib import pyplot
import MNISTtools
import torch.utils.data as data_utils
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)
import os

##### Loading MNIST dataset

In [None]:
# Root directory of the MNIST files. Override it with the MNIST_PATH environment
# variable instead of editing the notebook; the default is the original location.
MNIST_PATH = os.environ.get('MNIST_PATH', '/datasets/MNIST')
xtrain, ltrain = MNISTtools.load(dataset='training', path=MNIST_PATH)
# Transpose so that samples lie along axis 0, as the per-sample reshape that
# follows expects. NOTE(review): assumes MNISTtools.load returns a
# (pixels, samples) array -- confirm against MNISTtools.
xtrain = np.transpose(xtrain)

def normalize_MNIST_images(x):
    """Convert pixel values to float32 and map them affinely onto [-1, 1].

    Assumes raw intensities in [0, 255], so 0 -> -1.0 and 255 -> 1.0. That range
    matches the Tanh output of the generator.

    Args:
        x: numpy array of pixel values (any shape).

    Returns:
        float32 array with the same shape as ``x``.
    """
    midpoint = 127.5
    as_float = x.astype(np.float32)
    return (as_float - midpoint) / midpoint
# Normalise pixels and add a channel axis, giving (N, 1, 28, 28) for Conv2d.
# -1 lets numpy infer N instead of hard-coding 60000, so a different number of
# images (e.g. a subset) also works.
xtrain = normalize_MNIST_images(xtrain).reshape(-1, 1, 28, 28)
assert xtrain.shape[0] == np.size(ltrain), 'image/label count mismatch'

def label2onehot(lbl, num_classes=None):
    """Encode integer class labels as a one-hot matrix (classes x samples).

    Args:
        lbl: array-like of non-negative integer labels. It is flattened, so a
            column vector of shape (N, 1) is treated like a flat (N,) array.
        num_classes: number of rows in the result. Defaults to
            ``lbl.max() + 1``, which was the previous behaviour. Pass it
            explicitly (e.g. 10 for MNIST) to get a fixed width even when the
            highest classes are absent, and to accept an empty ``lbl``.

    Returns:
        float64 array ``d`` of shape ``(num_classes, lbl.size)``, with
        ``d[lbl[j], j] == 1`` and zeros elsewhere.

    Raises:
        ValueError: if ``lbl`` is empty and ``num_classes`` is not given.
        IndexError: if a label is >= ``num_classes``.
    """
    flat = np.asarray(lbl).ravel()
    if num_classes is None:
        if flat.size == 0:
            raise ValueError('num_classes is required when lbl is empty')
        num_classes = int(flat.max()) + 1
    d = np.zeros((num_classes, flat.size))
    # One fancy-index assignment: row = class label, column = sample index.
    d[flat, np.arange(flat.size)] = 1
    return d
# One-hot encode the labels, then transpose to (num_samples, num_classes) so
# each row lines up with the matching image in xtrain.
dtrain = np.transpose(label2onehot(ltrain))
print(xtrain.shape)

# Move the whole training set onto `device` once. The DataLoader then serves
# shuffled mini-batches of 100 samples straight from these tensors.
img = torch.from_numpy(xtrain).float().to(device)
lbl = torch.from_numpy(dtrain).float().to(device)
train = data_utils.TensorDataset(img, lbl)
data_loader = data_utils.DataLoader(train, batch_size=100, shuffle=True)

print(img.size(), lbl.size(), sep='\n')
print(type(img), type(lbl), sep='\n')

num_batches = len(data_loader)  # mini-batches per epoch
print(num_batches)

In [None]:
class DiscriminativeNet(torch.nn.Module):
    """DCGAN-style discriminator that scores 1x28x28 images as real or fake.

    Four stride-2 convolutions (kernel 4, padding 1) shrink a 28x28 input
    through 14 -> 7 -> 3 -> 1 pixels while widening the channels through
    1 -> 128 -> 256 -> 512 -> 1024. A linear layer and a sigmoid then give one
    probability in (0, 1) per image, with output shape (N, 1).
    """

    @staticmethod
    def _down_block(in_ch, out_ch, batch_norm):
        """Return Conv2d(k=4, s=2, p=1, no bias), then BatchNorm2d if requested, then LeakyReLU(0.2)."""
        layers = [
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=4,
                      stride=2, padding=1, bias=False)
        ]
        if batch_norm:
            layers.append(nn.BatchNorm2d(out_ch))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*layers)

    def __init__(self):
        super().__init__()
        # Attribute names, Sequential indices and construction order are what
        # define the state_dict keys and the init RNG order. They must stay
        # stable so saved checkpoints (netD.pth) keep loading.
        self.conv1 = self._down_block(1, 128, batch_norm=False)
        self.conv2 = self._down_block(128, 256, batch_norm=True)
        self.conv3 = self._down_block(256, 512, batch_norm=True)
        self.conv4 = self._down_block(512, 1024, batch_norm=True)
        self.out = nn.Sequential(nn.Linear(1024 * 1 * 1, 1), nn.Sigmoid())

    def forward(self, x):
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)
        flat = x.view(-1, 1024 * 1 * 1)
        return self.out(flat)

In [None]:
class GenerativeNet(torch.nn.Module):
    """DCGAN-style generator that maps 100-d noise to 1x28x28 images in [-1, 1].

    The noise goes through a linear projection to a 1024x1x1 feature map. Five
    stride-2 transposed convolutions (padding 1) then grow the spatial size
    through 1 -> 2 -> 4 -> 7 -> 14 -> 28. Kernel size 3 in conv3 is what turns
    4 into 7 rather than 8, so two more doublings land on 28x28. A final Tanh
    bounds the pixels to [-1, 1].
    """

    @staticmethod
    def _up_block(in_ch, out_ch, kernel_size=4, bn_relu=True):
        """Return ConvTranspose2d(s=2, p=1, no bias), followed by BatchNorm2d + ReLU if requested."""
        layers = [
            nn.ConvTranspose2d(in_channels=in_ch, out_channels=out_ch,
                               kernel_size=kernel_size, stride=2,
                               padding=1, bias=False)
        ]
        if bn_relu:
            layers += [nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True)]
        return nn.Sequential(*layers)

    def __init__(self):
        super().__init__()
        # Attribute names, Sequential indices and construction order are what
        # define the state_dict keys and the init RNG order. They must stay
        # stable so saved checkpoints (netG.pth) keep loading.
        self.linear = torch.nn.Linear(100, 1024 * 1 * 1)
        self.conv1 = self._up_block(1024, 512)
        self.conv2 = self._up_block(512, 256)
        self.conv3 = self._up_block(256, 128, kernel_size=3)
        # NOTE(review): conv4 squeezes to a single channel and feeds conv5 with no
        # normalisation or nonlinearity in between, so the last two layers act as
        # one linear map through a 1-channel bottleneck. Kept as-is for checkpoint
        # compatibility; consider a wider conv4 with BN+ReLU when retraining.
        self.conv4 = self._up_block(128, 1, bn_relu=False)
        self.conv5 = self._up_block(1, 1, bn_relu=False)
        self.out = torch.nn.Tanh()

    def forward(self, x):
        features = self.linear(x).view(x.shape[0], 1024, 1, 1)
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            features = stage(features)
        return self.out(features)
    

In [None]:
# Seed the global torch RNG (CPU and CUDA) before building the models. This makes
# weight init, the fixed preview noise and DataLoader shuffling reproducible
# across Restart & Run All, up to nondeterministic GPU kernels.
SEED = 0
torch.manual_seed(SEED)

# Adam hyperparameters shared by both networks.
LEARNING_RATE = 0.0002
ADAM_BETAS = (0.5, 0.999)

discriminator = DiscriminativeNet().to(device)
generator = GenerativeNet().to(device)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)

# Binary cross-entropy on the discriminator's sigmoid output. It serves as both
# the discriminator loss (real -> 1, fake -> 0) and the generator loss
# (fake -> 1).
DiscriminatorLoss = nn.BCELoss()
num_epochs = 63

# Fixed 16 x 100 latent batch, reused for every preview so the saved image grids
# are comparable over training. A plain tensor is enough; the `Variable` wrapper
# is deprecated and has no effect in modern PyTorch.
generated_images = torch.randn(16, 100).to(device)

##### Training Loop

In [None]:
import matplotlib.pyplot as plt
from torchvision.utils import save_image

# Output locations for preview grids and checkpoints. Create them up front so
# save_image / torch.save do not raise FileNotFoundError on a fresh checkout.
IMAGE_DIR = 'models/images'
CHECKPOINT_DIR = 'models'
LOG_EVERY = 100  # batches between preview image, log line and checkpoint
os.makedirs(IMAGE_DIR, exist_ok=True)

for epoch in range(num_epochs):
    for i, (images, _labels) in enumerate(data_loader):  # labels unused by the GAN
        batch_size = images.size(0)
        real = torch.ones(batch_size, 1, device=device)
        fake = torch.zeros(batch_size, 1, device=device)

        # --- Discriminator step: push real images -> 1, generated images -> 0 ---
        output_real = discriminator(images)
        loss_real = DiscriminatorLoss(output_real, real)

        z = torch.randn(batch_size, 100).to(device)
        fake_images = generator(z)
        # detach() keeps the discriminator update from backpropagating into the generator.
        output_fake = discriminator(fake_images.detach())
        loss_fake = DiscriminatorLoss(output_fake, fake)

        e = loss_real + loss_fake
        # Also clears the discriminator grads left over from the previous
        # generator step's backward pass.
        d_optimizer.zero_grad()
        e.backward()
        d_optimizer.step()

        # --- Generator step: fresh samples, trained to be classified as real ---
        z = torch.randn(batch_size, 100).to(device)
        fake_data = generator(z)
        prediction = discriminator(fake_data)
        g_error = DiscriminatorLoss(prediction, real)

        g_optimizer.zero_grad()
        g_error.backward()
        g_optimizer.step()

        if i % LOG_EVERY == 0:
            display.clear_output(True)
            # The preview needs no autograd graph. NOTE(review): the generator
            # stays in train mode here, so BatchNorm normalises with (and updates
            # running stats from) this preview batch, as before.
            with torch.no_grad():
                test_images = generator(generated_images).cpu()
            save_image(test_images, os.path.join(IMAGE_DIR, '%d_%d.png' % (epoch, i)), normalize=True)
            print('Epoch: {}, Generator_Loss: {}, Discriminator Loss: {}'.format(epoch + 1, g_error.item(), e.item()))
            torch.save(generator.state_dict(), os.path.join(CHECKPOINT_DIR, 'netG.pth'))
            torch.save(discriminator.state_dict(), os.path.join(CHECKPOINT_DIR, 'netD.pth'))