In [None]:
import torch
import torch.utils.data
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data.sampler import SubsetRandomSampler


def weight_init(m):
  """Initialize one module's parameters in place, DCGAN style.

  Meant to be passed to ``nn.Module.apply`` so it is called on every
  submodule. Modules are dispatched on a substring of their class name:

  * ``Conv`` (e.g. ``Conv2d``, ``ConvTranspose2d``): weight ~ N(0, 0.02);
    the bias is left untouched.
  * ``BatchNorm``: weight ~ N(1, 0.02); bias set to 0.
  * ``Linear``: weight ~ N(0, 0.02); bias set to 0.

  Any other module is ignored. A missing parameter (``bias=False``,
  ``affine=False``, or a module whose name matches but which owns no
  ``weight``/``bias`` attribute) is skipped instead of raising.

  Args:
    m: the module to initialize.
  """
  classname = m.__class__.__name__
  # getattr with a default: the parameter may be absent or None.
  weight = getattr(m, 'weight', None)
  bias = getattr(m, 'bias', None)

  if 'Conv' in classname:
    weight_mean, zero_bias = 0.0, False
  elif 'BatchNorm' in classname:
    weight_mean, zero_bias = 1.0, True
  elif 'Linear' in classname:
    weight_mean, zero_bias = 0.0, True
  else:
    return

  # nn.init functions run under no_grad, so parameters can be filled in place.
  if weight is not None:
    nn.init.normal_(weight, weight_mean, 0.02)
  if zero_bias and bias is not None:
    nn.init.zeros_(bias)

import torchvision.utils as vutils

def show_generated_data(real_data, fake_data):
  """Draw real and generated image grids side by side, then show the figure.

  Each panel shows at most the first 64 images of its batch, tiled with
  ``vutils.make_grid`` (padding 5, ``normalize=True``).

  Args:
    real_data: batch of real images; plotted as given.
    fake_data: batch of generated images; detached from the autograd graph
      before plotting.
  """
  def draw_panel(position, caption, images):
    # One of the two side-by-side subplots of a 1x2 layout.
    plt.subplot(1, 2, position)
    plt.axis("off")
    plt.title(caption)
    grid = vutils.make_grid(images[:64], padding=5, normalize=True).cpu()
    # The (1, 2, 0) transpose moves the channel axis last for imshow.
    plt.imshow(np.transpose(grid, (1, 2, 0)))

  plt.figure(figsize=(15, 5))
  draw_panel(1, "Real Images", real_data)
  draw_panel(2, "Fake images", fake_data.detach())
  plt.show()


class Generator(nn.Module):
  """DCGAN-style generator: latent tensor -> single-channel 28x28 image.

  Four transposed convolutions upsample a (B, nz, 1, 1) latent input to
  (B, 1, 28, 28). Hidden layers use BatchNorm + ReLU; the output passes
  through Tanh, so pixel values lie in [-1, 1].

  Args:
    nz: number of latent input channels. Defaults to 2, the value this
      class previously hard-coded, so ``Generator()`` is unchanged.
  """
  def __init__(self, nz=2):
    super().__init__()
    self.nz = nz
    # Spatial sizes below assume a 1x1 latent input; for each transposed
    # conv, out = (in - 1) * stride - 2 * padding + kernel_size.
    self.decoder = nn.Sequential(
        # 1x1 -> 3x3
        nn.ConvTranspose2d(nz, 256, 5, stride=1, padding=1),
        nn.BatchNorm2d(256),
        nn.ReLU(),

        # 3x3 -> 5x5
        nn.ConvTranspose2d(256, 128, 5, stride=1, padding=1),
        nn.BatchNorm2d(128),
        nn.ReLU(),

        # 5x5 -> 13x13
        nn.ConvTranspose2d(128, 64, 5, stride=2, padding=0),
        nn.BatchNorm2d(64),
        nn.ReLU(),

        # 13x13 -> 28x28, single output channel
        nn.ConvTranspose2d(64, 1, 4, stride=2, padding=0),
        nn.Tanh()
    )

  def forward(self, z):
    """Decode latent ``z`` of shape (B, nz, 1, 1) into images of shape (B, 1, 28, 28)."""
    return self.decoder(z)


class Discriminator(nn.Module):
  """DCGAN-style discriminator scoring single-channel images.

  Three stride-2 convolutions (BatchNorm on all but the first, LeakyReLU 0.2
  after each) followed by a 3x3 convolution down to one channel and a
  Sigmoid. For 28x28 inputs the spatial size goes 28 -> 13 -> 6 -> 3 -> 1,
  so the output has shape (B, 1, 1, 1) with values in (0, 1).

  Args:
    class_num: stored as ``self.class_num``; the layers defined here do not
      use it.
  """
  def __init__(self, class_num):
    super().__init__()
    self.class_num = class_num

    layers = []
    channels_in = 1
    # (out_channels, kernel_size, add BatchNorm) for each stride-2 stage.
    for channels_out, kernel, normalize in ((64, 5, False), (128, 4, True), (256, 4, True)):
      layers.append(nn.Conv2d(in_channels=channels_in, out_channels=channels_out,
                              kernel_size=kernel, stride=2, padding=1))
      if normalize:
        layers.append(nn.BatchNorm2d(channels_out))
      layers.append(nn.LeakyReLU(0.2))
      channels_in = channels_out

    # Collapse the remaining 3x3 map to a single probability per image.
    layers.append(nn.Conv2d(in_channels=channels_in, out_channels=1,
                            kernel_size=3, stride=1, padding=0))
    layers.append(nn.Sigmoid())

    self.conv_net = nn.Sequential(*layers)

  def forward(self, x):
    """Return the Sigmoid score map produced by ``conv_net`` for batch ``x``."""
    return self.conv_net(x)

# Hyperparameters.
epochs = 12
lr_G = 0.001    # generator learning rate
lr_D = 0.0002   # discriminator learning rate

nz = 2  # latent channels fed to Generator (its first layer takes 2 channels)
batch_size = 100
loss_function = nn.BCELoss()

real_label = 1
fake_label = 0

# Fall back to CPU so the cell still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ToTensor + Normalize(0.5, 0.5) maps pixels into [-1, 1], the generator's Tanh range.
data_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.5,), (0.5,)),
                   ])),
    batch_size=batch_size, shuffle=True)

net_G = Generator().to(device)
net_G.apply(weight_init)
net_D = Discriminator(class_num=10).to(device)
net_D.apply(weight_init)

optimizer_G = optim.Adam(net_G.parameters(), betas=(0.5, 0.999), lr=lr_G)
optimizer_D = optim.Adam(net_D.parameters(), betas=(0.5, 0.999), lr=lr_D)

# Per-iteration losses stored as Python floats. Appending the loss tensors
# themselves would keep device memory and autograd nodes alive every step.
D_train_loss_list = []
G_train_loss_list = []

for epoch in range(epochs):
  for i, (X, _) in enumerate(data_loader):
    X = X.to(device)
    # Size targets from the actual batch so a short final batch also works.
    b = X.size(0)

    # Discriminator step on real images (target = real_label).
    t = torch.full((b,), real_label, dtype=torch.float, device=device)
    # D returns (b, 1, 1, 1) for 28x28 inputs; BCELoss requires the input
    # and target to have the same shape, so flatten to (b,).
    Y = net_D(X).view(-1)
    loss_real = loss_function(Y, t)
    optimizer_D.zero_grad()
    loss_real.backward()
    optimizer_D.step()

    # Discriminator step on generated images (target = fake_label).
    noise = torch.randn(b, nz, 1, 1, device=device)
    X_fake = net_G(noise)
    t_fake = torch.full((b,), fake_label, dtype=torch.float, device=device)
    # detach(): this step must not backpropagate into the generator.
    Y_fake = net_D(X_fake.detach()).view(-1)
    loss_fake = loss_function(Y_fake, t_fake)
    optimizer_D.zero_grad()
    loss_fake.backward()
    optimizer_D.step()

    loss_D = loss_real + loss_fake
    D_train_loss_list.append(loss_D.item())

    # Generator step: train G so that D labels its samples as real.
    t_fake_G = torch.full((b,), real_label, dtype=torch.float, device=device)
    Y_G = net_D(X_fake).view(-1)
    loss_G = loss_function(Y_G, t_fake_G)
    optimizer_G.zero_grad()
    loss_G.backward()
    optimizer_G.step()
    G_train_loss_list.append(loss_G.item())

    if i % 100 == 1:
      print(f"[{epoch + 1}/{epochs}][{i}/{len(data_loader)}] "
            f"loss_D: {D_train_loss_list[-1]:.4f} loss_G: {G_train_loss_list[-1]:.4f}")
      show_generated_data(X, X_fake)

  