In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image

In [2]:
latent_dim = 100

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            *self.block(latent_dim, 128, normalize=False),
            *self.block(128, 256),
            *self.block(256, 512),
            *self.block(512, 1024),
            nn.Linear(1024, 1 * 28 * 28),
            nn.Tanh()
        )

    def block(self, input_dim, output_dim, normalize=True):
        layers = [nn.Linear(input_dim, output_dim)]
        if normalize:
            layers.append(nn.BatchNorm1d(output_dim, 0.8))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)
        return img

In [3]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(1 * 28 * 28, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        flatten_img = img.view(img.size(0), -1)
        output = self.model(flatten_img)
        return output

In [4]:
transforms_train = transforms.Compose([
    transforms.Resize(28),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

train_dataset = datasets.MNIST(root="./dataset", train=True, download=True, transform=transforms_train)
data_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./dataset/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./dataset/MNIST/raw/train-images-idx3-ubyte.gz to ./dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./dataset/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./dataset/MNIST/raw/train-labels-idx1-ubyte.gz to ./dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to ./dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./dataset/MNIST/raw



  cpuset_checked))


In [5]:
generator = Generator()
discriminator = Discriminator()

generator.cuda()
discriminator.cuda()

Discriminator(
  (model): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): LeakyReLU(negative_slope=0.2, inplace=True)
    (4): Linear(in_features=256, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

In [10]:
adversarial_loss = nn.BCELoss()
adversarial_loss.cuda()

BCELoss()

In [15]:
lr = 0.0002

gen_optimizer = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
dis_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

In [None]:
import time

n_epochs = 200
sample_interval = 2000
start_time = time.time()

for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(data_loader):
        real = torch.cuda.FloatTensor(imgs.size(0), 1).fill_(1.0)  # real : 1
        fake = torch.cuda.FloatTensor(imgs.size(0), 1).fill_(0.0)  # fake : 0

        real_imgs = imgs.cuda()

        gen_optimizer.zero_grad()

        # random noise sampling
        z = torch.normal(mean=0, std=1, size=(imgs.shape[0], latent_dim)).cuda()

        # image generation
        generated_imgs = generator(z)

        gen_loss = adversarial_loss(discriminator(generated_imgs), real)
        gen_loss.backward()
        gen_optimizer.step()

        dis_optimizer.zero_grad()

        real_loss = adversarial_loss(discriminator(real_imgs), real)
        fake_loss = adversarial_loss(discriminator(generated_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        dis_optimizer.step()

        done = epoch * len(data_loader) + i
        if done % sample_interval == 0:
            save_image(generated_imgs.data[:25], f"{done}.png", nrow=5, normalize=True)

    print(f"[Epoch {epoch}/{n_epochs}] [D loss: {d_loss.item():.6f}] [G loss: {gen_loss.item():.6f}] [Elapsed time: {time.time() - start_time:.2f}s]")


  cpuset_checked))


[Epoch 0/200] [D loss: 0.390655] [G loss: 1.726265] [Elapsed time: 16.14s]
[Epoch 1/200] [D loss: 0.389596] [G loss: 1.203094] [Elapsed time: 32.20s]
[Epoch 2/200] [D loss: 0.358418] [G loss: 0.826851] [Elapsed time: 48.03s]
[Epoch 3/200] [D loss: 0.282546] [G loss: 1.854411] [Elapsed time: 64.22s]
[Epoch 4/200] [D loss: 0.424276] [G loss: 3.553478] [Elapsed time: 80.38s]
[Epoch 5/200] [D loss: 0.366568] [G loss: 0.916532] [Elapsed time: 96.60s]
[Epoch 6/200] [D loss: 0.351186] [G loss: 0.985455] [Elapsed time: 112.72s]
[Epoch 7/200] [D loss: 0.513755] [G loss: 2.893015] [Elapsed time: 128.54s]
[Epoch 8/200] [D loss: 1.155723] [G loss: 0.158064] [Elapsed time: 144.55s]
[Epoch 9/200] [D loss: 0.400728] [G loss: 4.347172] [Elapsed time: 160.53s]
[Epoch 10/200] [D loss: 0.195249] [G loss: 1.574443] [Elapsed time: 176.66s]
[Epoch 11/200] [D loss: 0.167089] [G loss: 3.049140] [Elapsed time: 192.69s]
[Epoch 12/200] [D loss: 0.385837] [G loss: 0.760226] [Elapsed time: 208.76s]
[Epoch 13/200] 