In [1]:
import os
import numpy as np
import math

from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.utils import save_image

import torch.nn as nn
import torch.nn.functional as F
import torch

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '2'

torch.manual_seed(10)

<torch._C.Generator at 0x7fc77e55d530>

In [2]:
# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(28), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=512,
    shuffle=True,
)

In [3]:
class Generator(nn.Module): 
    def __init__(self):
        super().__init__()
        
        channels, img_size = 1, 28
        latent_dim = 100
        self.image_shape = (channels, img_size, img_size)
        
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),   # * means unpacking (list, tuple)
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, np.prod(self.image_shape)),   # 784
            nn.Tanh()
        )
    
    def forward(self, x):
        img = self.model(x)
        img = img.view(img.size(0), *self.image_shape)
        return img
    
    

In [4]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity

In [5]:
cuda = True if torch.cuda.is_available() else False
n_epochs = 128
lr = 0.0002
b1 = 0.5
b2 = 0.999


G = Generator()
D = Discriminator()

loss_function = nn.BCELoss()


if cuda:
    G.cuda()
    D.cuda()
    device = torch.device('cuda')
    
optim_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(b1, b2))
optim_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(b1, b2))


In [None]:
for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # Adversarial ground truths
        valid = torch.ones(imgs.size(0),1, requires_grad=False, device=device)
        fake = torch.zeros(imgs.size(0),1, requires_grad=False, device=device)
        
        real_imgs = imgs.type(torch.cuda.FloatTensor)
        
        # Train Generator
        optim_G.zero_grad()
        z = torch.randn([imgs.shape[0], 100], device=device)   # 100: latent_dim
        gen_imgs = G(z)
        g_loss = loss_function(D(gen_imgs), valid)   # Discriminator를 진짜라고 속이는 방향으로 학습
        
        g_loss.backward()
        optim_G.step()
        
        # Train Discriminator
        optim_D.zero_grad()
        
        real_loss = loss_function(D(real_imgs), valid)
        fake_loss = loss_function(D(gen_imgs.detach()), fake)
        
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optim_D.step()
        
        if i % 100 == 0 :
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
            )

        batches_done = epoch * len(dataloader) + i
        if batches_done % 1000 == 0:
            save_image(gen_imgs.data[:25], "../../data/gan/%d.png" % batches_done, nrow=5, normalize=True)

[Epoch 0/128] [Batch 0/118] [D loss: 0.654661] [G loss: 0.693346]
[Epoch 0/128] [Batch 100/118] [D loss: 0.474566] [G loss: 0.819877]
[Epoch 1/128] [Batch 0/118] [D loss: 0.551927] [G loss: 0.705024]
[Epoch 1/128] [Batch 100/118] [D loss: 0.502806] [G loss: 0.672970]
[Epoch 2/128] [Batch 0/118] [D loss: 0.420660] [G loss: 0.897984]
[Epoch 2/128] [Batch 100/118] [D loss: 0.483977] [G loss: 0.663058]
[Epoch 3/128] [Batch 0/118] [D loss: 0.587980] [G loss: 0.496331]
[Epoch 3/128] [Batch 100/118] [D loss: 0.567325] [G loss: 0.634318]
[Epoch 4/128] [Batch 0/118] [D loss: 0.481992] [G loss: 0.710637]
[Epoch 4/128] [Batch 100/118] [D loss: 0.509074] [G loss: 1.401447]
[Epoch 5/128] [Batch 0/118] [D loss: 0.539858] [G loss: 0.732140]
[Epoch 5/128] [Batch 100/118] [D loss: 0.649973] [G loss: 0.417167]
[Epoch 6/128] [Batch 0/118] [D loss: 0.429105] [G loss: 1.194613]
[Epoch 6/128] [Batch 100/118] [D loss: 0.612132] [G loss: 0.394393]
[Epoch 7/128] [Batch 0/118] [D loss: 0.455231] [G loss: 0.7078