In [None]:
import torch
import torch.utils.data
import torch.nn as nn
import numpy as np
from torch.utils.data.sampler import SubsetRandomSampler
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import torchvision.utils as vutils



def weight_init(m):
  """DCGAN-style parameter initializer, intended for `nn.Module.apply`.

  Dispatches on the module's class name:
    * name contains 'Conv'      -> weight ~ N(0, 0.02); bias left untouched
    * name contains 'BatchNorm' -> weight ~ N(1, 0.02); bias = 0
    * name contains 'Linear'    -> weight ~ N(0, 0.02); bias = 0
  Any other module is left unchanged.

  Args:
    m: the module currently being visited by `apply`.
  """
  kind = type(m).__name__
  if 'Conv' in kind:
    nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
  elif 'BatchNorm' in kind:
    # Scale starts near 1 so the normalization is close to identity at init.
    nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
    nn.init.zeros_(m.bias.data)
  elif 'Linear' in kind:
    nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
    nn.init.zeros_(m.bias.data)

#========================================================
### Class definitions

### debugging in progress...

class Generator(nn.Module):
  """DCGAN-style generator mapping latent noise to 1x28x28 images.

  Input is noise of shape (B, nz, 1, 1); output is (B, 1, 28, 28) with values
  in [-1, 1] (Tanh), matching MNIST images normalized with
  Normalize((0.5,), (0.5,)).

  Transposed convolutions are used to *upsample*: plain strided Conv2d layers
  shrink the input and cannot be applied to a 1x1 noise map at all.
  Spatial size per layer (out = (in - 1) * stride - 2 * padding + kernel):
  1 -> 7 -> 14 -> 28 -> 28.

  Args:
    nz: number of latent channels in the input noise (default 2).
  """

  def __init__(self, nz=2):
    super().__init__()

    self.network = nn.Sequential(
        # (B, nz, 1, 1) -> (B, 256, 7, 7)
        nn.ConvTranspose2d(nz, 256, 7, stride=1, padding=0),
        nn.ReLU(),

        # -> (B, 128, 14, 14)
        nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
        nn.BatchNorm2d(128),
        nn.ReLU(),

        # -> (B, 64, 28, 28)
        nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
        nn.BatchNorm2d(64),
        nn.ReLU(),

        # -> (B, 1, 28, 28); Tanh keeps pixels in [-1, 1]
        nn.ConvTranspose2d(64, 1, 3, stride=1, padding=1),
        nn.Tanh()
    )

  def forward(self, x):
    """Generate images from noise `x` of shape (B, nz, 1, 1)."""
    G = self.network(x)
    return G

class Discriminator(nn.Module):
  """DCGAN-style discriminator scoring 1-channel images as real/fake.

  For a 28x28 input (MNIST) the spatial size shrinks 28 -> 13 -> 6 -> 3 -> 1,
  so the output is (B, 1, 1, 1) holding Sigmoid probabilities.
  """

  def __init__(self):
    super().__init__()

    # First stage: no BatchNorm on the raw image input.
    layers = [
        nn.Conv2d(1, 64, 5, stride=2, padding=1),
        nn.LeakyReLU(0.2),
    ]
    # Two downsampling stages of Conv -> BatchNorm -> LeakyReLU.
    for c_in, c_out in ((64, 128), (128, 256)):
      layers.extend([
          nn.Conv2d(c_in, c_out, 4, stride=2, padding=1),
          nn.BatchNorm2d(c_out),
          nn.LeakyReLU(0.2),
      ])
    # Collapse to a single probability per image.
    layers.extend([
        nn.Conv2d(256, 1, 3, stride=1, padding=0),
        nn.Sigmoid(),
    ])

    self.network = nn.Sequential(*layers)

  def forward(self, x):
    """Return the per-image probability that `x` is real."""
    return self.network(x)

# Binary cross-entropy on the discriminator's Sigmoid output.
loss_function = nn.BCELoss()
batch_size = 100
nz = 2  # latent channels: the generator consumes noise of shape (batch, nz, 1, 1)

# Fixed seed so the train/valid split, weight init and noise draws are reproducible.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

# Fall back to CPU instead of crashing when no GPU is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Map pixels from [0, 1] to [-1, 1] to match the generator's Tanh output range.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

dataset = datasets.MNIST('../data', train=True, download=True,
                         transform=mnist_transform)


epochs = 20
learning_rate = 0.001     # generator learning rate
learning_rate_D = 0.0002  # discriminator learning rate

#---------------------------
# Hold out `valid_size` random training images as a validation split.
num_train = len(dataset)
valid_size = 500

indices = list(range(num_train))
split = num_train - valid_size
np.random.shuffle(indices)
train_idx, valid_idx = indices[:split], indices[split:]
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
valid_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=valid_sampler)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, download=True,
                   transform=mnist_transform),
    batch_size=batch_size, shuffle=True)
real_label = 1
fake_label = 0


net_discriminator = Discriminator().to(device)
net_discriminator.apply(weight_init)

net_generator = Generator().to(device)
net_generator.apply(weight_init)

optimizer_D = optim.Adam(net_discriminator.parameters(), lr=learning_rate_D)
optimizer_G = optim.Adam(net_generator.parameters(), lr=learning_rate)


train_loss_list = []  # one (mean D loss, mean G loss) tuple per epoch
val_loss_list = []    # TODO: not populated yet


for epoch in range(epochs):
  sum_loss_D = 0.0
  sum_loss_G = 0.0
  n_batches = 0
  for i, (X, _) in enumerate(train_loader):
    X = X.to(device)
    n = X.size(0)  # the last batch may be smaller than batch_size
    # Targets are built per batch on the same device as the model outputs.
    labels_real = torch.full((n,), real_label, dtype=torch.float, device=device)
    labels_fake = torch.full((n,), fake_label, dtype=torch.float, device=device)

    # --- Discriminator step: real images -> 1, generated images -> 0 ---
    optimizer_D.zero_grad()
    # D returns (n, 1, 1, 1); flatten to (n,) so it matches the BCELoss target shape.
    D = net_discriminator(X).view(-1)
    lossR = loss_function(D, labels_real)
    lossR.backward()

    noise = torch.randn(n, nz, 1, 1, device=device)
    fake_data = net_generator(noise)

    # detach(): the discriminator loss must not backpropagate into the generator.
    lossF = loss_function(net_discriminator(fake_data.detach()).view(-1), labels_fake)
    lossF.backward()

    optimizer_D.step()

    # --- Generator step: make the (updated) discriminator call fakes real ---
    optimizer_G.zero_grad()
    lossG = loss_function(net_discriminator(fake_data).view(-1), labels_real)
    lossG.backward()
    optimizer_G.step()

    sum_loss_D += (lossR + lossF).item()
    sum_loss_G += lossG.item()
    n_batches += 1

  mean_loss_D = sum_loss_D / max(n_batches, 1)
  mean_loss_G = sum_loss_G / max(n_batches, 1)
  train_loss_list.append((mean_loss_D, mean_loss_G))
  print(f"epoch {epoch + 1}/{epochs}  loss_D={mean_loss_D:.4f}  loss_G={mean_loss_G:.4f}")



In [None]:
def show_generated_data(real_data, fake_data, n_images=64):
  """Show a grid of real images next to a grid of generated images.

  Args:
    real_data: batch of real images as a tensor; may live on the GPU.
      Presumably (B, 1, H, W) in [-1, 1] given the Normalize((0.5,), (0.5,))
      transform. TODO: confirm at the call site.
    fake_data: batch of generator outputs; detached before plotting.
    n_images: maximum number of images drawn in each grid (default 64).
  """
  fig, (ax_real, ax_fake) = plt.subplots(1, 2, figsize=(15, 5))
  panels = ((ax_real, "Real images", real_data.detach()),
            (ax_fake, "Fake images", fake_data.detach()))
  for ax, title, batch in panels:
    # normalize=True rescales pixel values into [0, 1] so imshow displays
    # [-1, 1] data correctly instead of clipping it.
    grid = vutils.make_grid(batch[:n_images], padding=5, normalize=True)
    # make_grid returns (C, H, W); imshow expects (H, W, C) on the CPU.
    ax.imshow(grid.cpu().permute(1, 2, 0).numpy())
    ax.set_axis_off()
    ax.set_title(title)
  plt.show()
