In [None]:
import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import make_grid

from IPython.display import display, Image
from IPython.core.debugger import set_trace
import os


In [None]:
# Training configuration shared by the cells below.
NOISE_DIM = 100                   # length of each noise vector fed to the generator
BATCH_SIZE = 64                   # samples per mini-batch in both data loaders
EPOCHES = 10                      # number of passes over the training set
DATA_PATH = '../datasets/MNIST/'  # root directory handed to datasets.MNIST

In [None]:
# Only ask torchvision to download MNIST when the dataset directory is missing.
DOWNLOAD = not os.path.exists(DATA_PATH)

trans = transforms.Compose([
    transforms.ToTensor(),
    # MNIST images have a single channel, so Normalize takes exactly one
    # mean/std value. (x - 0.5) / 0.5 maps ToTensor's [0, 1] pixels to [-1, 1],
    # the same range as the generator's Tanh output.
    transforms.Normalize([0.5], [0.5]),
])

train_data = datasets.MNIST(root=DATA_PATH, train=True, transform=trans, download=DOWNLOAD)
test_data = datasets.MNIST(root=DATA_PATH, train=False, transform=trans)

# Shuffle only the training split; the test split keeps its stored order.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

# Each sample produced by the loaders is an image tensor of shape [1, 28, 28].

In [None]:
class Discriminator(nn.Module):
    """Fully connected network scoring flattened 28x28 single-channel images.

    Layout: 784 -> 256 -> 256 -> 1, with LeakyReLU(0.2) after the first two
    linear layers and a sigmoid on the output, so every input row produces
    one value in [0, 1].
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        layers = [
            nn.Linear(1 * 28 * 28, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        # Registered as `dis`, so parameter names are `dis.<layer index>.*`.
        self.dis = nn.Sequential(*layers)

    def forward(self, x):
        """Return the sigmoid score for each row of ``x``.

        ``x`` must already be flattened to 784 features in its last
        dimension; no reshaping happens here.
        """
        return self.dis(x)

class Generator(nn.Module):
    """Fully connected network mapping noise vectors to flattened 28x28 images.

    Layout: input_size -> 256 -> 256 -> 784, with ReLU after the first two
    linear layers and Tanh on the output, so every generated value lies in
    [-1, 1].

    Args:
        input_size: number of features in each input noise vector.
    """

    def __init__(self, input_size=100):
        super(Generator, self).__init__()
        layers = [
            nn.Linear(input_size, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1 * 28 * 28),
            nn.Tanh(),
        ]
        # Registered as `gen`, so parameter names are `gen.<layer index>.*`.
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Return one 784-value row per input row of ``x``."""
        return self.gen(x)

# Build both networks (discriminator first, then generator) and move them to
# the GPU when one is available.
d_model, g_model = Discriminator(), Generator(NOISE_DIM)
if torch.cuda.is_available():
    d_model, g_model = d_model.cuda(), g_model.cuda()

In [None]:
# Binary cross-entropy on the discriminator's sigmoid output, shared by both
# training steps.
criterion = nn.BCELoss()

# One Adam optimizer per network, both using the same learning rate.
d_optim = optim.Adam(params=d_model.parameters(), lr=3e-4)
g_optim = optim.Adam(params=g_model.parameters(), lr=3e-4)

In [None]:
def check_generator():
    """Display a grid of 64 images produced by ``g_model`` from fresh noise.

    Draws 64 standard-normal noise vectors of length ``NOISE_DIM``, runs them
    through the generator (on the GPU when available), rescales the result to
    [0, 1] and shows it inline as a single PIL image.
    """
    noise = Variable(torch.randn(64, NOISE_DIM))
    if torch.cuda.is_available():
        noise = noise.cuda()

    generated = g_model(noise).cpu().data

    # Map the Tanh range [-1, 1] onto [0, 1], then restore the image layout
    # (N, 1, 28, 28) expected by make_grid.
    images = (0.5 * (generated + 1)).clamp(0, 1).view(-1, 1, 28, 28)

    to_pil = transforms.ToPILImage()
    display(to_pil(make_grid(images)))

In [None]:
# Training
# Every mini-batch performs one discriminator update followed by one generator
# update. After each epoch the summed losses are printed and a grid of
# generated samples is displayed.
for epoch in range(EPOCHES):
    print('*'*10)
    print('Epoch: {}'.format(epoch))
    # Running sums (not averages) of the per-batch losses for this epoch.
    d_tloss = 0.0
    g_tloss = 0.0
    for img, _ in train_loader:
        data_num = img.size(0)

        # Flatten each image into a row vector for the fully connected nets.
        x_real = img.view(data_num, -1)
        x_fake = torch.randn(data_num, NOISE_DIM)
        # Targets are shaped (batch, 1) to match the discriminator output;
        # BCELoss requires input and target to have the same size.
        y_real = torch.ones(data_num, 1)
        y_fake = torch.zeros(data_num, 1)
        if torch.cuda.is_available():
            x_real = x_real.cuda()
            x_fake = x_fake.cuda()
            y_real = y_real.cuda()
            y_fake = y_fake.cuda()

        # --- Discriminator step: real images -> 1, generated images -> 0 ---
        d_out_real = d_model(x_real)
        d_loss_real = criterion(d_out_real, y_real)
        g_out_fake = g_model(x_fake)
        # detach() stops this loss from back-propagating into the generator.
        d_out_fake = d_model(g_out_fake.detach())
        d_loss_fake = criterion(d_out_fake, y_fake)
        d_loss = d_loss_real + d_loss_fake
        # .item() extracts the Python float from the 0-dim loss tensor.
        d_tloss += d_loss.item()

        d_optim.zero_grad()
        d_loss.backward()
        d_optim.step()

        # --- Generator step: make the updated discriminator output 1 ---
        # The generator has not changed since g_out_fake was computed, so its
        # (still attached) output is reused instead of running a second
        # forward pass. Gradients this leaves on the discriminator are cleared
        # by d_optim.zero_grad() in the next iteration.
        d_out_fake = d_model(g_out_fake)
        g_loss = criterion(d_out_fake, y_real)
        g_tloss += g_loss.item()
        g_optim.zero_grad()
        g_loss.backward()
        g_optim.step()
    print('Generator_Loss: {}'.format(g_tloss))
    print('Discrimitor_Loss: {}'.format(d_tloss))
    check_generator()