# 라이브러리

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
import matplotlib.pyplot as plt


# 모델 정의

In [3]:
# Hyperparameters shared by the model definitions and the data pipeline.
batch_size = 64        # number of samples per training step
image_size = 28 * 28   # length of one flattened 28x28 image (784)
latent_size = 64       # dimensionality of the generator's noise input
hidden_size = 256      # width of the hidden layers in both networks

In [4]:
class Generator(nn.Module):
    """Fully connected generator: latent vector -> flattened image.

    The network is three linear layers (latent_size -> hidden_size ->
    hidden_size -> image_size) with ReLU between them. The final Tanh bounds
    every output element to [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # Layers are created in forward order so parameter initialisation
        # (and the ``fc.<index>`` state-dict keys) stay the same.
        layers = [
            nn.Linear(latent_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, image_size),
            nn.Tanh(),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Pass ``x`` through the layer stack and return the generated output."""
        generated = self.fc(x)
        return generated

In [6]:
class Discriminator(nn.Module):
    """Fully connected discriminator: flattened image -> single score.

    The network is three linear layers (image_size -> hidden_size ->
    hidden_size -> 1) with ReLU between them. The final Sigmoid squashes the
    score into (0, 1).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in forward order so parameter initialisation
        # (and the ``fc.<index>`` state-dict keys) stay the same.
        layers = [
            nn.Linear(image_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Pass ``x`` through the layer stack and return one score per row."""
        score = self.fc(x)
        return score

# 데이터 로딩

In [8]:
# Preprocessing: convert each image to a tensor, then apply (x - 0.5) / 0.5.
# NOTE(review): with ToTensor producing values in [0, 1] this should land in
# [-1, 1], matching the Tanh output range of the generator -- confirm.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])]
)

# MNIST training split, stored under data/ and downloaded if it is missing.
mnist = MNIST(root='data/', train=True, download=True, transform=transform)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 120477529.82it/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 43324640.14it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 32222088.96it/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 22952444.30it/s]


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



In [9]:
# Shuffled mini-batches over the MNIST dataset.
# drop_last=True discards a trailing incomplete batch so every batch has
# exactly `batch_size` samples.
data_loader = DataLoader(
    dataset=mnist,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=2,
)

# 초기 설정

In [12]:
num_epochs = 100

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build the generator first, then the discriminator, and move each to `device`.
G = Generator()
G = G.to(device)
D = Discriminator()
D = D.to(device)

# Binary cross-entropy between the discriminator's outputs and 0/1 targets.
criterion = nn.BCELoss()

# Adam takes two beta values; both optimizers use the same settings.
d_optimizer = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
g_optimizer = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [13]:
# Helper for displaying a grid of result images
def show_result_image(result_image, num_images = 25, size = (1, 28, 28), value_range = (-1.0, 1.0)):
    """Show the first ``num_images`` images of a batch as a 5-column grid.

    Args:
        result_image: Tensor that can be viewed as ``(-1, *size)``, e.g. a
            batch of flattened images.
        num_images: Maximum number of images placed in the grid.
        size: Per-image shape ``(channels, height, width)``.
        value_range: ``(low, high)`` range the pixel values are expected to
            lie in. Defaults to ``(-1, 1)``, the range produced by the
            generator's Tanh and by ``Normalize(mean=0.5, std=0.5)``.

    The tensor shape is printed before plotting. Values are rescaled from
    ``value_range`` to [0, 1] because ``imshow`` clips floating-point RGB
    data outside [0, 1]; without this, every negative pixel is drawn black.
    """
    print(result_image.shape)
    image_flat = result_image.detach().cpu().view(-1, *size)

    # Map [low, high] -> [0, 1] and clamp any stray values into range.
    low, high = value_range
    selected = image_flat[:num_images]
    selected = ((selected - low) / (high - low)).clamp(0, 1)

    image_grid = make_grid(selected, nrow=5)
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

# 훈련

In [15]:
# GAN training: for every batch, update the discriminator once and then the
# generator once; after each epoch, display the latest fake and real batches.
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):  # labels are not needed, only the images
        # Flatten each image to a vector. Take the batch length from the
        # tensor itself so the loop does not depend on drop_last=True.
        cur_batch_size = images.size(0)
        images = images.reshape(cur_batch_size, -1).to(device)

        # Targets for the BCE loss (real --> 1, fake --> 0), created on `device`.
        real_labels = torch.ones(cur_batch_size, 1, device=device)
        fake_labels = torch.zeros(cur_batch_size, 1, device=device)

        # ---- Train the discriminator ----
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)

        z = torch.randn(cur_batch_size, latent_size).to(device)  # sample latent vectors
        fake_images = G(z)
        # detach() stops gradients from this loss flowing into the generator.
        outputs = D(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)

        # Backprop and optimize.
        d_loss = d_loss_real + d_loss_fake
        # zero_grad must be *called*. The bare attribute `d_optimizer.zero_grad`
        # is a no-op, which let discriminator gradients accumulate across
        # steps, including those left behind by g_loss.backward() below.
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ---- Train the generator ----
        z = torch.randn(cur_batch_size, latent_size).to(device)  # sample latent vectors
        fake_images = G(z)
        outputs = D(fake_images)
        # The generator is rewarded when the discriminator labels fakes as real.
        g_loss = criterion(outputs, real_labels)

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

    # End of epoch: show the last generated batch next to the last real batch.
    show_result_image(fake_images)
    show_result_image(images)

Output hidden; open in https://colab.research.google.com to view.