# 1. Import

In [1]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST # Training dataset
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0) # Set for testing purposes, please do not change change!

<torch._C.Generator at 0x7fba06a9b8b0>

# 2. Generator

In [2]:
def generator_block(input_dim, output_dim):
  return nn.Sequential(
      nn.Linear(input_dim, output_dim),
      nn.BatchNormld(output_dim),
      nn.ReLU(inplace=True),
  )

In [3]:
class Generator(nn.Module):
  def __init__(self, z_dim=10, im_dim=784, hidden_dim=128):
    super(Generator, self).__init__()
    # Build the neural network
    self.gen = nn.Sequential(
        get_generator_block(z_dim, hidden_dim),
        get_generator_block(hidden_dim, hidden_dim * 2),
        get_generator_block(hidden_dim, hidden_dim * 2),
        get_generator_block(hidden_dim * 2, hidden_dim * 4),
        get_generator_block(hidden_dim * 4, hidden_dim * 8),

        nn.Linear (hidden_dim * 8, im_dim),
        nn.Sigmoid()
    )

# 3. Discriminator

Structure of Discriminator: Structure of Discriminator block

In [4]:
def discriminator_block(input_dim, output_dim):
  return nn.Sequential(
      nn.Linear(input_dim, output_dim),
      nn.LeakyReLu(0.2, inplace=True)
  )

In [5]:
class Discriminator(nn.Module):
  def __init__(self, im_dim=784, hidden_dim=128):
    super(Discriminator, self).__init__()
    self.disc = nn.Squential(
        discriminator_block(im_dim, hidden_dim * 4),
        discriminator_block(hidden_dim * 4, hidden_dim * 2),
        discriminator_block(hidden_dim * 2, hidden_dim),
        nn.Linear(hidden_dim, 1)
    )

# 3. Loss Function

5-2. Disc Loss

In [6]:
def get_disc_loss (gen, disc, criterion, real, num_images, z_dim, device):
    fake_noise = get_noise ( num_images, z_dim, device=device) # z
    fake = gen(fake_noise) # G(z)
    disc_fake_pred = disc(fake.detatch()) # D(G(z))
    # compare fake_pred & zeros
    disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
    disc_real_pred = disc (real) # D(x)
    # compare real_pred & ones
    disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
    disc_loss = (disc_fake_loss + disc_real_loss)/2

    return disc_loss

5-3. Gen Loss

In [7]:
def get_gen_loss (gen, disc, criterion, num_images, z_dim, device):
    fake_noise = get_noise (num_images, z_dim, device=device) # z
    fake = gen(fake_noise) # G(z)
    disc_fake_pred = disc(fake) # D(G(z))
    gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred)) # compare

    return gen_loss

# 4. Others

4-1. Noise Function

In [8]:
def get_noise (n_samples, z_dim, device='cpu'):
  return torch.randn (n_sample, z_dim, device=device)

4-2. Parameter Setup

In [9]:
criterion = nn.BCEWithLogitsLoss()
n_epochs = 200
z_dim = 64
display_step = 500
batch_size = 128
lr = 0.00001
# device = 'cuda' # 학습하기 원하는 리소스
device = 'cpu'

4-3. Data Loading

In [10]:
Dataloader = DataLoader(
    MNIST('.', download=True, transform=transforms.ToTensor()),
    batch_size = batch_size,
    shuffle = True
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


# 5. Optimization

5-1. Optimizer

In [11]:
gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc = Discriminator().to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)

NameError: ignored

# 6. Image Display

In [None]:
def show_tensor_images(image_tensor, num_images=25, size(1, 28, 28)):
image_unflat = image_tensor.detach().cpu().view(-1), *size)
image_grid = make_grid(image_unflat[:_images], nrow=5)
plt.imshow(image_grid.permute(1, 2, 0).squeeze())
plt.show()

# 7. Training

In [None]:
cur_step = 0
mean_generator_loss =
mean_discriminator_loss = 0
test_generator = Trure
gen_loss = False
error = False

image_tensor: the image