In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
from matplotlib import pyplot as plt
import torchvision.utils as utils

device = "cuda" if torch.cuda.is_available() else "cpu"

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

train_dataset = datasets.MNIST(
    root="./mnist", train=True, download=True, transform=transform
)

test_dataset = datasets.MNIST(
    root="./mnist", train=False, download=True, transform=transform
)

train_loader = DataLoader(
    dataset=train_dataset, batch_size=64, shuffle=True, drop_last=True
)
test_loader = DataLoader(
    dataset=test_dataset, batch_size=1000, shuffle=False, drop_last=True
)

In [4]:
class Discriminator(nn.Module):
    def __init__(self, width, height, filter_size, input_channel):
        super(Discriminator, self).__init__()

        # (i + 2p - k)/s + 1
        self.nets = nn.Sequential(
            nn.Conv2d(input_channel, filter_size, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(filter_size, filter_size * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(filter_size * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(filter_size * 2, filter_size * 4, 3, 2, 1, bias=False),
            nn.BatchNorm2d(filter_size * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Conv2d(filter_size * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.nets(x)
        x = x.view(-1, 1)

        return x

In [5]:
class Generator(nn.Module):
    def __init__(self, width, height, latent_size, filter_size, output_channel):
        super(Generator, self).__init__()

        self.latent_size = latent_size

        self.nets = nn.Sequential(
            # Kernel, Stride, Padding o=(i−1)∗s+k−2p
            nn.ConvTranspose2d(latent_size, filter_size * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(filter_size * 4),
            nn.ReLU(True),
            # 4x4
            nn.ConvTranspose2d(filter_size * 4, filter_size * 2, 3, 2, 1, bias=False),
            nn.BatchNorm2d(filter_size * 2),
            nn.ReLU(True),
            # 7x7
            nn.ConvTranspose2d(filter_size * 2, filter_size, 4, 2, 1, bias=False),
            nn.BatchNorm2d(filter_size),
            nn.ReLU(True),
            # 14x14
            nn.ConvTranspose2d(filter_size, output_channel, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z):
        x = self.nets(z)

        return x

In [6]:
class GAN(nn.Module):
    def __init__(
        self,
        width,
        height,
        latent_size,
        device="cuda",
    ):
        super(GAN, self).__init__()
        self.generator = Generator(width, height, latent_size, 64, 1).to(device)
        self.discriminator = Discriminator(width, height, 32, 1).to(device)

    def generate(self, batch_size):
        z = torch.randn(batch_size, self.generator.latent_size, 1, 1).to(device)
        return self.generator(z)

    def discriminate(self, x):
        return self.discriminator(x)

    def forward(self, batch_size):
        fake_data = self.generate(batch_size)
        disc_fake_data = self.discriminate(fake_data)
        return fake_data, disc_fake_data

In [7]:
LATENT_SIZE = 100
LR = 3e-4

gan = GAN(28, 28, LATENT_SIZE).to(device)

loss = nn.BCELoss()

optim_gen = optim.Adam(gan.generator.parameters(), lr=LR, betas=(0.5, 0.999))
optim_dis = optim.Adam(gan.discriminator.parameters(), lr=LR, betas=(0.5, 0.999))

In [8]:
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

run_id = datetime.now().strftime("%Y%m%d-%H%M%S")
writer = SummaryWriter(f"runs/gan_training_{run_id}")

EPOCHS = 100
BATCH_SIZE = 64

REAL_LABELS = torch.ones(BATCH_SIZE, 1).to(device)
FAKE_LABELS = torch.zeros(BATCH_SIZE, 1).to(device)


for epoch in range(1, EPOCHS + 1):
    total_loss_dis = 0
    total_accuracy = 0
    total_loss_gen = 0
    batches = 0

    for i, (real_images, _) in enumerate(train_loader):
        batches += 1

        # 장치 이동
        real_images = real_images.to(device)

        # 판별기 학습
        optim_dis.zero_grad()
        outputs = gan.discriminate(real_images)
        loss_real = loss(outputs, REAL_LABELS)
        accuracy_real = torch.where(outputs > 0.5, 1.0, 0.0).mean()

        fake_images = gan.generate(BATCH_SIZE)
        outputs = gan.discriminate(fake_images.detach())
        loss_fake = loss(outputs, FAKE_LABELS)
        accuracy_fake = torch.where(outputs < 0.5, 1.0, 0.0).mean()

        loss_dis = loss_real + loss_fake

        loss_dis.backward()
        optim_dis.step()
        accuracy = (accuracy_real + accuracy_fake) / 2

        # 생성기 학습
        optim_gen.zero_grad()
        fake_images = gan.generate(BATCH_SIZE)
        outputs = gan.discriminate(fake_images)
        loss_gen = loss(outputs, REAL_LABELS)
        loss_gen.backward()
        optim_gen.step()

        # 손실과 정확도 누적
        total_loss_dis += loss_dis.item()
        total_accuracy += accuracy.item()
        total_loss_gen += loss_gen.item()

    # 에폭 평균 계산
    avg_loss_dis = total_loss_dis / batches
    avg_accuracy = total_accuracy / batches
    avg_loss_gen = total_loss_gen / batches

    print(
        f"EPOCH #{epoch}/{EPOCHS}, LOSS(D): {avg_loss_dis:.4f}, ACCURACY(D): {avg_accuracy:.4f}, LOSS(G): {avg_loss_gen:.4f}"
    )

    # 손실, 정확도 기록
    writer.add_scalar("Loss/Discriminator", avg_loss_dis, epoch)
    writer.add_scalar("Loss/Generator", avg_loss_gen, epoch)
    writer.add_scalar("Accuracy/Discriminator", avg_accuracy, epoch)

    # 이미지 기록
    with torch.no_grad():
        fake_images = gan.generate(32)
        img_grid = torchvision.utils.make_grid(fake_images, normalize=True)
        writer.add_image("Generated Images", img_grid, epoch)

    if epoch % 10 == 0:
        torch.save(gan.state_dict(), f"./checkpoints/checkpoint_epoch_{epoch}.pth")

writer.close()

EPOCH #1/100, LOSS(D): 0.1391, ACCURACY(D): 0.9871, LOSS(G): 4.9357
EPOCH #2/100, LOSS(D): 0.1076, ACCURACY(D): 0.9885, LOSS(G): 5.1475
EPOCH #3/100, LOSS(D): 0.1574, ACCURACY(D): 0.9817, LOSS(G): 4.9762
EPOCH #4/100, LOSS(D): 0.1869, ACCURACY(D): 0.9746, LOSS(G): 4.6546
EPOCH #5/100, LOSS(D): 0.2162, ACCURACY(D): 0.9698, LOSS(G): 4.6126
EPOCH #6/100, LOSS(D): 0.2300, ACCURACY(D): 0.9682, LOSS(G): 4.5312
EPOCH #7/100, LOSS(D): 0.3167, ACCURACY(D): 0.9457, LOSS(G): 4.0596
EPOCH #8/100, LOSS(D): 0.2669, ACCURACY(D): 0.9544, LOSS(G): 4.3297
EPOCH #9/100, LOSS(D): 0.3344, ACCURACY(D): 0.9381, LOSS(G): 4.0162
EPOCH #10/100, LOSS(D): 0.3526, ACCURACY(D): 0.9345, LOSS(G): 3.8670
EPOCH #11/100, LOSS(D): 0.4554, ACCURACY(D): 0.9103, LOSS(G): 3.4568
EPOCH #12/100, LOSS(D): 0.4345, ACCURACY(D): 0.9125, LOSS(G): 3.4684
EPOCH #13/100, LOSS(D): 0.5522, ACCURACY(D): 0.8848, LOSS(G): 3.0438
EPOCH #14/100, LOSS(D): 0.5661, ACCURACY(D): 0.8793, LOSS(G): 2.9363
EPOCH #15/100, LOSS(D): 0.5984, ACCURACY(D)

In [8]:
gan.load_state_dict(torch.load("./checkpoints/checkpoint_epoch_100.pth"))

<All keys matched successfully>