In [0]:
# Install PyTorch and torchvision into the runtime (versions are not pinned here).
# NOTE(review): presumably intended for a hosted runtime such as Colab — confirm before keeping in a shared notebook.
!pip3 install torch torchvision

In [0]:
import os,time,itertools,pickle,imageio
import matplotlib.pyplot as plt

In [0]:
import torch as t
from torch import cuda,nn,optim,utils
from torch.nn import functional as F
from torch.autograd import Variable
from torchvision import datasets,transforms

# DCGAN Model (Radford et al., 2015)

In [0]:
#ThePolice

class discriminator(nn.Module):
    """DCGAN discriminator: scores single-channel images as real (1) or fake (0).

    Four 4x4 stride-2 convolutions halve the spatial size at each step and a
    final 4x4 convolution without padding collapses what is left. For a
    64x64 input this gives one probability per image, shape ``(N, 1, 1, 1)``.

    Args:
        d: Base channel width; the hidden layers use ``d``, ``2d``, ``4d``
            and ``8d`` channels.
    """

    # initializers
    def __init__(self, d):
        super(discriminator, self).__init__()
        self.conv1 = nn.Conv2d(1, d, 4, 2, 1)          # no batch norm on the first layer
        self.conv2 = nn.Conv2d(d, d*2, 4, 2, 1)
        self.conv2_bn = nn.BatchNorm2d(d*2)
        self.conv3 = nn.Conv2d(d*2, d*4, 4, 2, 1)
        self.conv3_bn = nn.BatchNorm2d(d*4)
        self.conv4 = nn.Conv2d(d*4, d*8, 4, 2, 1)
        self.conv4_bn = nn.BatchNorm2d(d*8)
        self.conv5 = nn.Conv2d(d*8, 1, 4, 1, 0)        # collapses the remaining 4x4 map to 1x1

    # weight_init... for stable gradient flow across layers
    def weight_init(self, mean, std):
        """Re-initialise every direct child layer through ``normal_init``.

        Args:
            mean: Mean of the normal distribution used for the weights.
            std: Standard deviation of that distribution.
        """
        # children() walks the same direct sub-modules as the private
        # ``_modules`` dict, without reaching into nn.Module internals.
        for layer in self.children():
            normal_init(layer, mean, std)

    # forward method
    def forward(self, input):
        """Return the per-image "real" probability for a batch of images."""
        # LeakyReLU(0.2) in all layers except the last
        x = F.leaky_relu(self.conv1(input), 0.2)
        x = F.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2)
        x = F.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2)
        x = F.leaky_relu(self.conv4_bn(self.conv4(x)), 0.2)
        # torch.sigmoid replaces the deprecated F.sigmoid alias.
        x = t.sigmoid(self.conv5(x))

        return x
      
#TheCulprit

class generator(nn.Module):
    """DCGAN generator: maps 100-d latent vectors to single-channel images.

    The latent input is expected as ``(N, 100, 1, 1)``. The first transposed
    convolution expands it to 4x4 and each of the four following stride-2
    layers doubles the spatial size, so the output is ``(N, 1, 64, 64)`` with
    values in ``[-1, 1]`` (tanh).

    Args:
        d: Base channel width; the hidden layers use ``8d``, ``4d``, ``2d``
            and ``d`` channels.
    """

    # initializers
    def __init__(self, d):
        super(generator, self).__init__()
        self.deconv1 = nn.ConvTranspose2d(100, d*8, 4, 1, 0)     #100-d latent space vector
        self.deconv1_bn = nn.BatchNorm2d(d*8)
        self.deconv2 = nn.ConvTranspose2d(d*8, d*4, 4, 2, 1)
        self.deconv2_bn = nn.BatchNorm2d(d*4)
        self.deconv3 = nn.ConvTranspose2d(d*4, d*2, 4, 2, 1)
        self.deconv3_bn = nn.BatchNorm2d(d*2)
        self.deconv4 = nn.ConvTranspose2d(d*2, d, 4, 2, 1)
        self.deconv4_bn = nn.BatchNorm2d(d)
        self.deconv5 = nn.ConvTranspose2d(d, 1, 4, 2, 1)         # no batch norm on the output layer

    # weight_init
    def weight_init(self, mean, std):
        """Re-initialise every direct child layer through ``normal_init``.

        Args:
            mean: Mean of the normal distribution used for the weights.
            std: Standard deviation of that distribution.
        """
        # children() walks the same direct sub-modules as the private
        # ``_modules`` dict, without reaching into nn.Module internals.
        for layer in self.children():
            normal_init(layer, mean, std)

    # forward method
    def forward(self, input):
        """Generate a batch of images from a batch of latent vectors."""
        # ReLU in all layers except the last, where tanh is used
        x = F.relu(self.deconv1_bn(self.deconv1(input)))
        x = F.relu(self.deconv2_bn(self.deconv2(x)))
        x = F.relu(self.deconv3_bn(self.deconv3(x)))
        x = F.relu(self.deconv4_bn(self.deconv4(x)))
        # torch.tanh replaces the deprecated F.tanh alias.
        x = t.tanh(self.deconv5(x))

        return x
      
#WeightInitialiser
def normal_init(m, mean, std):
    """Initialise a (transposed) convolution with normal weights and zero bias.

    Modules that are neither ``nn.Conv2d`` nor ``nn.ConvTranspose2d`` are left
    untouched.

    Args:
        m: Module to initialise in place.
        mean: Mean of the normal distribution for the weights.
        std: Standard deviation of that distribution.
    """
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        nn.init.normal_(m.weight, mean, std)
        # Layers built with bias=False have ``bias is None``; skip them
        # instead of failing on the attribute access.
        if m.bias is not None:
            nn.init.zeros_(m.bias)

#Utilities

In [0]:
#Create Folder
# Make sure the output tree exists: the parent first, then its two sub-folders.
for _folder in (
        'MNIST_DCGAN_results',
        'MNIST_DCGAN_results/Random_results',
        'MNIST_DCGAN_results/Fixed_results',
):
    if not os.path.isdir(_folder):
        os.mkdir(_folder)


In [0]:
#Stackoverflow Stuff :P

# Fixed batch of 5*5 latent vectors, reused by show_result(isFix=True) so the
# grids of successive epochs are drawn from the same noise.
# `t` is the alias this notebook imports torch under; the tensor is moved to
# the GPU only when one is available, matching how the models are placed.
fixed_z = t.randn((5 * 5, 100)).view(-1, 100, 1, 1)
if cuda.is_available():
    fixed_z = fixed_z.cuda()
def show_result(num_epoch, show = False, save = False, path = 'result.png', isFix=False):
    """Render a 5x5 grid of generator samples, optionally saving and/or showing it.

    Relies on the notebook-level generator ``g`` and, when ``isFix`` is true,
    on the notebook-level ``fixed_z`` noise batch.

    Args:
        num_epoch: Epoch number written in the figure caption.
        show: Display the figure when true; otherwise the figure is closed.
        save: Write the figure to ``path`` when true.
        path: Destination image file, used only when ``save`` is true.
        isFix: Use ``fixed_z`` instead of freshly sampled noise.
    """
    size_figure_grid = 5

    # Sample on whatever device the generator lives on, so this works with
    # or without a GPU.
    device = next(g.parameters()).device
    if isFix:
        z = fixed_z.to(device)
    else:
        z = t.randn(size_figure_grid * size_figure_grid, 100, 1, 1, device=device)

    # Generate in eval mode without building an autograd graph, then put the
    # generator back into the mode it was in.
    was_training = g.training
    g.eval() #EvalMode
    with t.no_grad():
        test_images = g(z).cpu()
    g.train(was_training)

    fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5))
    for k in range(size_figure_grid * size_figure_grid):
        i, j = divmod(k, size_figure_grid)
        ax[i, j].imshow(test_images[k, 0].numpy(), cmap='gray')
        ax[i, j].get_xaxis().set_visible(False)
        ax[i, j].get_yaxis().set_visible(False)

    label = 'Epoch {0}'.format(num_epoch)
    fig.text(0.5, 0.04, label, ha='center')

    # Honour the `save` flag (it used to be ignored and the file was always written).
    if save:
        fig.savefig(path)

    if show:
        plt.show()
    else:
        # Close this specific figure so repeated per-epoch calls do not pile up figures.
        plt.close(fig)

In [0]:
#Training Plotter

def train_plot(hist, show = False, save = False, path = 'Train_plot.png'):
    """Plot the discriminator and generator loss curves stored in ``hist``.

    Args:
        hist: Mapping holding the ``'D_loss'`` and ``'G_loss'`` sequences.
        show: Display the figure when true; otherwise close it.
        save: Write the figure to ``path`` when true.
        path: Destination image file, used only when ``save`` is true.
    """
    steps = range(len(hist['D_loss']))
    # Look both series up before drawing anything.
    curves = [(key, hist[key]) for key in ('D_loss', 'G_loss')]

    for key, values in curves:
        plt.plot(steps, values, label=key)

    # NOTE(review): the training cell in this notebook appends one value per
    # epoch, so the 'Iterations' axis label may be a misnomer — confirm.
    plt.xlabel('Iterations')
    plt.ylabel('Loss')
    plt.legend(loc='best')
    plt.grid(True)
    plt.tight_layout()

    if save:
        plt.savefig(path)

    if not show:
        plt.close()
    else:
        plt.show()

In [0]:
#As mentioned in the paper
param = {
    "batch_sz" : 128,     # mini-batch size
    "lr" : 0.0002,        # Adam learning rate
    "beta" : 0.5,         # Adam beta1
    "img_size" : 64}      # images are resized to img_size before entering the networks

tr_epch = 15              # number of training epochs

#data transforms
# MNIST images have a single channel (the discriminator's first conv takes one
# input channel), so Normalize must receive one mean/std value. Three-channel
# statistics fail to broadcast against a 1-channel tensor.
# (x - 0.5) / 0.5 rescales ToTensor's [0, 1] output to [-1, 1], the range of
# the generator's tanh output.
transform = transforms.Compose([
        transforms.Resize(param["img_size"]),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,))
])

#Load MNIST
train_loader = utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transform),
    batch_size=param["batch_sz"], shuffle=True)


In [0]:
#Model_init
# Both networks are built with a base width of 128 channels, then their
# convolution weights are re-drawn from N(0.0, 0.02).
g = generator(128)
d = discriminator(128)
for _net in (g, d):
    _net.weight_init(mean=0.0, std=0.02)

#cost function
criterion = nn.BCELoss()

# Move the networks and the loss module to the GPU when one is present.
if cuda.is_available():
    for _module in (g, d, criterion):
        _module.cuda()

In [0]:
# One Adam optimiser per network, both configured from `param`
# (learning rate and beta1; beta2 is fixed at 0.999).
g_optim = optim.Adam(
    params=g.parameters(),
    lr=param["lr"],
    betas=(param["beta"], 0.999),
)
d_optim = optim.Adam(
    params=d.parameters(),
    lr=param["lr"],
    betas=(param["beta"], 0.999),
)

#Train Iterator

In [0]:
#DejaVu
# Training history: one independent, initially empty list per tracked quantity.
hist = {
    key: []
    for key in ('D_loss', 'G_loss', 'per_epoch_time', 'tot_time')
}
# Running count of optimisation steps.
num_iter = 0

In [0]:
print("Training Starts Now")

start_time = time.time()

# Keep every batch on the same device as the networks (GPU if they were moved there).
device = next(d.parameters()).device

for epoch in range(tr_epch) :

  epoch_start_time = time.time()
  d_loss = []   # per-step discriminator losses of this epoch
  g_loss = []   # per-step generator losses of this epoch

  for data in train_loader :

    x_real = data[0].to(device)

    # Size labels and noise from the actual batch: the last batch of an epoch
    # can be smaller than param["batch_sz"], and labels of the wrong length
    # do not match the discriminator output.
    batch_sz = x_real.size(0)
    y_real = t.ones(batch_sz, device=device)
    y_fake = t.zeros(batch_sz, device=device)
    z_d = t.randn(batch_sz, 100, 1, 1, device=device)   #Standard normal latent noise
    z_g = t.randn(batch_sz, 100, 1, 1, device=device)   #Standard normal latent noise

    #Police Training
    d.zero_grad()

    #Discriminator Loss A (classifies real samples as real)
    # view(-1) keeps a 1-d output even for a batch of one (squeeze() would drop it to 0-d)
    D_result = d(x_real).view(-1)
    D_loss_A = criterion(D_result, y_real)

    #Discriminator loss B (classifies fake samples as fake)
    # detach(): the discriminator step must not back-propagate into the generator
    D_result = d(g(z_d).detach()).view(-1)
    D_loss_B = criterion(D_result, y_fake)

    #Total Discriminator Loss = Loss A + Loss B
    D_loss_total = D_loss_A + D_loss_B
    d_loss.append(D_loss_total.item())   # .item() replaces .data[0], which fails on 0-dim tensors

    D_loss_total.backward()
    d_optim.step()

    #Culprit Training
    g.zero_grad()

    #Generator Loss (generates real looking fake samples)
    G_result = d(g(z_g)).view(-1)
    G_loss_total = criterion(G_result, y_real)
    g_loss.append(G_loss_total.item())

    G_loss_total.backward()
    g_optim.step()

    num_iter += 1

  epoch_end_time = time.time()
  per_epoch_time = epoch_end_time - epoch_start_time

  # Store plain floats so the history can be pickled and plotted directly.
  mean_d_loss = t.tensor(d_loss).mean().item()
  mean_g_loss = t.tensor(g_loss).mean().item()
  print('[%d/%d] - ptime:%.2f, loss_d: %.3f, loss_g: %.3f' % ((epoch + 1), tr_epch, per_epoch_time, mean_d_loss,
                                                              mean_g_loss))
  hist['D_loss'].append(mean_d_loss)
  hist['G_loss'].append(mean_g_loss)
  # 'per_epoch_time' is the key created with the history dict
  hist['per_epoch_time'].append(per_epoch_time)

  #Showing generator output per epoch
  p = 'MNIST_DCGAN_results/Random_results/MNIST_DCGAN_' + str(epoch + 1) + '.png'
  fixed_p = 'MNIST_DCGAN_results/Fixed_results/MNIST_DCGAN_' + str(epoch + 1) + '.png'
  show_result((epoch+1), save=True, path=p, isFix=False)
  show_result((epoch+1), save=True, path=fixed_p, isFix=True)
  
#Iteration Info
end_time = time.time()
total_time = end_time - start_time
hist['tot_time'].append(total_time)

# Read the key the history dict was created with ('per_epoch_time'); the
# previous 'per_epoch_ptimes' lookup raised KeyError. Guard the empty case so
# a zero-epoch run still prints a summary.
epoch_times = hist['per_epoch_time']
avg_epoch_time = sum(epoch_times) / len(epoch_times) if epoch_times else 0.0
print("Avg per epoch time: %.2f, total %d epochs time: %.2f" % (avg_epoch_time, tr_epch, total_time))
print("Training over!..")

#Saving model state
# `t` is the alias this notebook imports torch under (`import torch as t`);
# the bare name `torch` is not defined here.
t.save(g.state_dict(), "MNIST_DCGAN_results/g_param.pkl")
t.save(d.state_dict(), "MNIST_DCGAN_results/d_param.pkl")
# Persist the training history next to the model weights.
with open('MNIST_DCGAN_results/train_hist.pkl', 'wb') as f:
    pickle.dump(hist, f)



#Results

In [0]:
# Plot the recorded loss history and write the figure to the given path.
train_plot(
    hist,
    save=True,
    path='MNIST_DCGAN_results/MNIST_DCGAN_train_hist.png',
)

In [0]:
#Fun Stuff
# Read the fixed-noise grid saved for each epoch, in epoch order, and stitch
# the frames into one animated GIF.
images = [
    imageio.imread('MNIST_DCGAN_results/Fixed_results/MNIST_DCGAN_' + str(e + 1) + '.png')
    for e in range(tr_epch)
]

imageio.mimsave('MNIST_DCGAN_results/generation_animation.gif', images, fps=5)
