In [7]:
import torch
import torch.nn as nn
import torch.utils.data
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

In [8]:
# --- Training schedule -------------------------------------------------------
num_epoch = 200          # full passes over the training set
batch_size = 100         # samples per mini-batch
learning_rate = 2e-4     # step size passed to both Adam optimizers

# --- Data geometry -----------------------------------------------------------
num_channel = 1          # single-channel (grayscale) images
img_size = 28 * 28       # number of pixels in one flattened 28x28 image

# --- Network widths ----------------------------------------------------------
noise_size = 100         # dimensionality of the generator's latent input
hidden_size1, hidden_size2, hidden_size3 = 256, 512, 1024

# --- Output ------------------------------------------------------------------
dir_name = "GAN_results"  # directory that receives the per-epoch sample grids

In [10]:
import os

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create the results directory. exist_ok=True makes the cell safe to re-run
# and avoids the check-then-create race of a separate os.path.exists() test.
os.makedirs(dir_name, exist_ok=True)

In [11]:
# Preprocessing applied to every image: convert to a tensor, then normalize
# with mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5),
    ]
)

# MNIST training split; downloaded into root on first use.
MNIST_dataset = torchvision.datasets.MNIST(
    root='../../data/',
    train=True,
    transform=transform,
    download=True,
)

# Shuffled mini-batch iterator over the training split.
data_loader = torch.utils.data.DataLoader(
    dataset=MNIST_dataset,
    batch_size=batch_size,
    shuffle=True,
)



Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../../data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 29581947.35it/s]


Extracting ../../data/MNIST/raw/train-images-idx3-ubyte.gz to ../../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../../data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 26805862.76it/s]


Extracting ../../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../../data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 7734693.82it/s]


Extracting ../../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 10738742.26it/s]


Extracting ../../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../../data/MNIST/raw



In [12]:
class Discriminator(nn.Module):
    """Fully connected binary classifier over flattened images.

    Maps an input of ``img_size`` features through two LeakyReLU(0.2) hidden
    layers (``hidden_size2`` then ``hidden_size1`` units) to a single
    sigmoid-activated output in [0, 1].
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as they are applied.
        self.linear1 = nn.Linear(img_size, hidden_size2)
        self.linear2 = nn.Linear(hidden_size2, hidden_size1)
        self.linear3 = nn.Linear(hidden_size1, 1)
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid score for each row of ``x``.

        Args:
            x: tensor whose last dimension is ``img_size``.

        Returns:
            Tensor with a last dimension of 1 holding values in [0, 1].
        """
        hidden = self.leaky_relu(self.linear1(x))
        hidden = self.leaky_relu(self.linear2(hidden))
        logits = self.linear3(hidden)
        return self.sigmoid(logits)

In [13]:
class Generator(nn.Module):
    """Fully connected network that maps a noise vector to a flattened image.

    Takes an input of ``noise_size`` features, passes it through two ReLU
    hidden layers (``hidden_size1`` then ``hidden_size2`` units) and emits
    ``img_size`` tanh-activated values in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as they are applied.
        self.linear1 = nn.Linear(noise_size, hidden_size1)
        self.linear2 = nn.Linear(hidden_size1, hidden_size2)
        self.linear3 = nn.Linear(hidden_size2, img_size)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Generate one flattened image per row of ``x``.

        Args:
            x: tensor whose last dimension is ``noise_size``.

        Returns:
            Tensor whose last dimension is ``img_size``, values in [-1, 1].
        """
        hidden = self.relu(self.linear1(x))
        hidden = self.relu(self.linear2(hidden))
        pixels = self.linear3(hidden)
        return self.tanh(pixels)

In [15]:
# Build both networks and place them on the selected device in one step
# (nn.Module.to returns the module itself). The discriminator is constructed
# first, then the generator.
discriminator = Discriminator().to(device)
generator = Generator().to(device)

In [16]:
# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# One Adam optimizer per network, both using the shared learning rate.
d_optimizer, g_optimizer = (
    torch.optim.Adam(net.parameters(), lr=learning_rate)
    for net in (discriminator, generator)
)

In [18]:
# Adversarial training loop: every mini-batch first updates the generator
# (to make the discriminator label its fakes as real), then updates the
# discriminator (to separate real images from fresh fakes). After each epoch
# the discriminator's mean score on the last real/fake batch is printed and a
# grid of the last generated batch is written to dir_name.
for epoch in range(num_epoch):
    for i, (images, _) in enumerate(data_loader):
        # Size everything from the batch actually received, so a short final
        # batch (dataset length not divisible by batch_size) does not break
        # the reshape or mismatch the label shape.
        cur_batch = images.size(0)

        # Targets: 1 for real, 0 for fake.
        real_label = torch.ones(cur_batch, 1, device=device)
        fake_label = torch.zeros(cur_batch, 1, device=device)

        # Flatten the MNIST images to (batch, pixels).
        real_images = images.reshape(cur_batch, -1).to(device)

        # ---------------- Generator step ----------------
        g_optimizer.zero_grad()

        # Produce fake images from a noise vector z.
        z = torch.randn(cur_batch, noise_size, device=device)
        fake_images = generator(z)

        # Score the fakes against the *real* label: the better the generator
        # fools the discriminator, the smaller g_loss becomes.
        g_loss = criterion(discriminator(fake_images), real_label)
        g_loss.backward()
        g_optimizer.step()

        # ---------------- Discriminator step ----------------
        # Also clears the discriminator gradients left over from g_loss.
        d_optimizer.zero_grad()

        # Fresh fakes; detach so this step neither builds nor backpropagates
        # through the generator's graph (the generator is not trained here).
        z = torch.randn(cur_batch, noise_size, device=device)
        fake_images = generator(z).detach()

        # Fakes are scored against the fake label, real images against the
        # real label; the two losses are averaged.
        fake_loss = criterion(discriminator(fake_images), fake_label)
        real_loss = criterion(discriminator(real_images), real_label)
        d_loss = (fake_loss + real_loss) / 2

        d_loss.backward()
        d_optimizer.step()

        if (i + 1) % 150 == 0:
            print("Epoch [ {}/{} ]  Step [ {}/{} ]  d_loss : {:.5f}  g_loss : {:.5f}"
                  .format(epoch, num_epoch, i+1, len(data_loader), d_loss.item(), g_loss.item()))

    # Mean discriminator output on the last real and fake batch of the epoch.
    # Only the final-iteration value was ever reported, so it is computed once
    # here, without tracking gradients.
    with torch.no_grad():
        d_performance = discriminator(real_images).mean()
        g_performance = discriminator(fake_images).mean()

    # print discriminator & generator's performance
    print(" Epock {}'s discriminator performance : {:.2f}  generator performance : {:.2f}"
          .format(epoch, d_performance, g_performance))

    # Save fake images in each epoch. The generator ends in tanh, so pixels
    # lie in [-1, 1]; map them to [0, 1] first, because save_image clamps to
    # [0, 1] and would otherwise turn every negative pixel black.
    samples = fake_images.reshape(-1, 1, 28, 28)
    save_image((samples + 1) / 2, os.path.join(dir_name, 'GAN_fake_samples{}.png'.format(epoch + 1)))

Epoch [ 0/200 ]  Step [ 150/600 ]  d_loss : 0.03776  g_loss : 3.87854
Epoch [ 0/200 ]  Step [ 300/600 ]  d_loss : 0.19851  g_loss : 4.57201
Epoch [ 0/200 ]  Step [ 450/600 ]  d_loss : 0.02942  g_loss : 7.08624
Epoch [ 0/200 ]  Step [ 600/600 ]  d_loss : 0.02842  g_loss : 5.68782
 Epock 0's discriminator performance : 0.97  generator performance : 0.01
Epoch [ 1/200 ]  Step [ 150/600 ]  d_loss : 0.04292  g_loss : 5.41871
Epoch [ 1/200 ]  Step [ 300/600 ]  d_loss : 0.20588  g_loss : 4.53287
Epoch [ 1/200 ]  Step [ 450/600 ]  d_loss : 0.05058  g_loss : 6.53792
Epoch [ 1/200 ]  Step [ 600/600 ]  d_loss : 0.08754  g_loss : 4.89664
 Epock 1's discriminator performance : 0.93  generator performance : 0.02
Epoch [ 2/200 ]  Step [ 150/600 ]  d_loss : 0.38313  g_loss : 2.61285
Epoch [ 2/200 ]  Step [ 300/600 ]  d_loss : 0.18694  g_loss : 5.89945
Epoch [ 2/200 ]  Step [ 450/600 ]  d_loss : 0.14291  g_loss : 6.24403
Epoch [ 2/200 ]  Step [ 600/600 ]  d_loss : 0.21531  g_loss : 3.75441
 Epock 2's d

KeyboardInterrupt: 