In [None]:
import torch
import torch.nn as nn
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
%matplotlib inline
import argparse
import os

# Options are parsed from an explicit argv list: inside a notebook the real
# sys.argv presumably holds kernel launcher flags, not these options.
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=str, default='0')
parser.add_argument('--epoch', type=int, default=64)
config = parser.parse_args(['--gpu', '0', '--epoch', '64'])

# Restrict visible GPUs before CUDA is first queried below; CUDA reads this
# variable when it initialises.
os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu)
# is_available must be *called*: the bare function object is always truthy,
# which would select 'cuda' even on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Seed torch so weight init, latent noise and DataLoader shuffling repeat run-to-run.
SEED = 0
torch.manual_seed(SEED)

latent_size = 64   # length of the generator's input noise vector
hidden_size = 256  # width of the hidden layers in both networks
image_size = 784   # flattened 28 * 28 MNIST image
num_epochs = 200   # NOTE(review): unused by the visible training loop, which runs config.epoch epochs
batch_size = 100

In [None]:
class Distinguisher(nn.Module):
    """GAN discriminator D (the name is kept for compatibility).

    Maps flattened images (last dimension ``image_size``) to a single value in
    (0, 1) per input (Sigmoid output). Training uses it as the probability
    that the input is a real image.
    """

    def __init__(self):
        super().__init__()
        # Two hidden Linear + LeakyReLU(0.2) stages, then a 1-unit Sigmoid head.
        # The resulting layer order (and thus the state_dict keys net.0/net.2/net.4)
        # is the same as listing the six modules by hand.
        layers = []
        width_in = image_size
        for _ in range(2):
            layers += [nn.Linear(width_in, hidden_size), nn.LeakyReLU(0.2)]
            width_in = hidden_size
        layers += [nn.Linear(hidden_size, 1), nn.Sigmoid()]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return D(x): one (0, 1) score per row of ``x``."""
        scores = self.net(x)
        return scores

class Generator(nn.Module):
    """GAN generator G.

    Maps latent noise (last dimension ``latent_size``) to a flattened image of
    length ``image_size``; the final Tanh keeps every value in (-1, 1).
    """

    def __init__(self):
        super().__init__()
        # Two hidden Linear + ReLU stages, then a Linear + Tanh output layer.
        # The resulting layer order (and thus the state_dict keys net.0/net.2/net.4)
        # is the same as listing the six modules by hand.
        layers = []
        width_in = latent_size
        for _ in range(2):
            layers += [nn.Linear(width_in, hidden_size), nn.ReLU()]
            width_in = hidden_size
        layers += [nn.Linear(hidden_size, image_size), nn.Tanh()]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return G(x): one flattened image per row of noise ``x``."""
        images = self.net(x)
        return images


In [None]:
# MNIST is single-channel, so Normalize takes 1-tuples; 3-channel statistics do
# not broadcast in-place onto a (1, 28, 28) tensor. (x - 0.5) / 0.5 maps pixels
# from [0, 1] to [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])
mnist_dataset = datasets.MNIST(root="../../data", train=True,
                               download=True, transform=transform)
data_loader = torch.utils.data.DataLoader(dataset=mnist_dataset,
                                          batch_size=batch_size,
                                          shuffle=True)

# Both networks go on the configured `device`; a hard-coded .cuda() crashes on
# CPU-only machines and bypasses the device chosen in the config cell.
G_net = Generator().to(device)
D_net = Distinguisher().to(device)
G_optim = torch.optim.Adam(G_net.parameters(), lr=2e-4)
D_optim = torch.optim.Adam(D_net.parameters(), lr=2e-4)
# D already ends in Sigmoid, so plain BCE on probabilities is the matching loss.
criterion = nn.BCELoss()

print("preparation done")
for epoch in range(config.epoch):
    for idx, (images, _) in enumerate(data_loader):
        # Use the actual batch length: a short final batch would break a
        # reshape to the fixed `batch_size`.
        n = images.size(0)
        images = images.reshape(n, -1).to(device)
        real_labels = torch.ones(n, 1, device=device)
        fake_labels = torch.zeros(n, 1, device=device)

        # Train D_net on even steps only (G trains every step):
        # push D(real) -> 1 and D(G(z)) -> 0.
        if idx % 2 == 0:
            outputs = D_net(images)
            d_loss = criterion(outputs, real_labels)
            real_score = outputs.mean().item()

            gaussian_noise = torch.randn(n, latent_size, device=device)
            # detach(): the D update needs no gradient through G.
            fake_images = G_net(gaussian_noise).detach()
            outputs = D_net(fake_images)
            d_loss = d_loss + criterion(outputs, fake_labels)

            # Also clears the D grads left over from the previous g_loss.backward().
            D_optim.zero_grad()
            d_loss.backward()
            D_optim.step()

        # Train G_net: make D label fresh fakes as real (non-saturating loss).
        gaussian_noise = torch.randn(n, latent_size, device=device)
        fake_images = G_net(gaussian_noise)
        outputs = D_net(fake_images)
        fake_score = outputs.mean().item()
        g_loss = criterion(outputs, real_labels)
        G_optim.zero_grad()
        g_loss.backward()
        G_optim.step()

        if (idx + 1) % 100 == 0:
            # d_loss / real_score come from the most recent even (D) step.
            print('epoch: {}, step: {}, d_loss: {:.4f}, g_loss:{:.4f}, real_score: {:.2f},fake_score: {:.2f}'
                  .format(epoch, idx + 1, d_loss.item(), g_loss.item(),
                          real_score, fake_score))

    # After each epoch, show the sample from the last G step that D scored
    # highest ("most real").
    fake_images = fake_images.detach().reshape(-1, 28, 28)
    fake_image_id = torch.argmax(outputs).item()  # outputs is (n, 1): flat argmax = row index
    best_score = outputs[fake_image_id].item()
    print(fake_image_id)
    print(best_score)
    plt.imshow(fake_images[fake_image_id].cpu().numpy(), cmap='gray', vmin=-1, vmax=1)
    plt.title('epoch {}: D(G(z)) = {:.2f}'.format(epoch, best_score))
    plt.axis('off')
    plt.show()
    