## **Generative Adversarial Network (GAN)**

In [16]:
!pip install -q wandb
import wandb
wandb.login()



True

In [2]:
lr = 0.0002
batch_size = 128
config = {
          "dataset": "MNIST",
          "gpu": "colab",
          "model": "GAN",
          "learning_rate": lr,
          "batch_size": batch_size,
}
wandb.init(project="week13_gan", config=config)
wandb.run.name = "mnist gan"

[34m[1mwandb[0m: Currently logged in as: [33mganshuyi[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

In [4]:
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)

In [5]:
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

In [6]:
import torchvision
import torchvision.transforms as transforms
transform = transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize((0.5),(0.5))
                                ])
train_dataset = torchvision.datasets.MNIST(root="MNIST_data/",
                                           train=True,
                                           transform=transform,
                                           download=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz to MNIST_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz to MNIST_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw



In [7]:
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [8]:
class Generator(nn.Module):
  def __init__(self):
    super(Generator, self).__init__()
    self.fc1 = nn.Linear(100, 256)
    self.fc2 = nn.Linear(256, 512)
    self.fc3 = nn.Linear(512, 1024)
    self.fc4 = nn.Linear(1024, 784)
    self.leakyrelu = nn.LeakyReLU(0.2)
    self.tanh = nn.Tanh()
  
  def forward(self, z):
    x = self.fc1(z)
    z = self.leakyrelu(x)

    x = self.fc2(x)
    x = self.leakyrelu(x)

    x = self.fc3(x)
    x = self.leakyrelu(x)

    x = self.fc4(x)
    x = self.tanh(x)

    return x

In [9]:
class Discriminator(nn.Module):
  def __init__(self):
    super(Discriminator, self).__init__()
    self.fc1 = nn.Linear(784, 1024)
    self.fc2 = nn.Linear(1024, 512)
    self.fc3 = nn.Linear(512, 256)
    self.fc4 = nn.Linear(256, 1)
    self.leakyrelu = nn.LeakyReLU(0.2)
    self.dropout = nn.Dropout(0.3)
    self.sigmoid = nn.Sigmoid()
  
  def forward(self, x):
    x = self.fc1(x)
    x = self.leakyrelu(x)
    x = self.dropout(x)

    x = self.fc2(x)
    x = self.leakyrelu(x)
    x = self.dropout(x)

    x = self.fc3(x)
    x = self.leakyrelu(x)
    x = self.dropout(x)

    x = self.fc4(x)
    x = self.sigmoid(x)

    return x

In [10]:
generator = Generator().to(device)
discriminator = Discriminator().to(device)

In [11]:
criterion = torch.nn.BCELoss()
g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)

In [12]:
epochs = 200
total_batch_num = len(train_dataloader)

for epoch in range(epochs):
  generator.train()
  discriminator.train()

  avg_g_cost = 0
  avg_d_cost = 0

  for step, batch in enumerate(train_dataloader):
    b_x, _ = batch
    b_x = b_x.view(-1, 784).to(device)
    num_img = len(b_x)
    real_label = torch.ones((num_img, 1)).to(device)
    fake_label = torch.zeros((num_img, 1)).to(device)

    real_logit = discriminator(b_x)
    d_real_loss = criterion(real_logit, real_label)

    z = torch.randn((num_img, 100), requires_grad=False).to(device)
    fake_data = generator(z)
    fake_logit = discriminator(fake_data)
    d_fake_loss = criterion(fake_logit, fake_label)

    d_loss = d_real_loss + d_fake_loss
    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()
    
    z = torch.randn((num_img, 100), requires_grad=False).to(device)
    fake_data = generator(z)
    fake_logit = discriminator(fake_data)
    g_loss = criterion(fake_logit, real_label)

    g_optimizer.zero_grad()
    g_loss.backward()
    g_optimizer.step()

    avg_d_cost += d_loss
    avg_g_cost += g_loss

  avg_d_cost /= total_batch_num
  avg_g_cost /= total_batch_num

  # observe fake images
  generator.eval()
  with torch.no_grad():
    z = torch.randn((64,100), requires_grad=False).to(device)
    fake_data = generator(z)

    fake_img = fake_data.detach().cpu().numpy().reshape(64,28,28)
    wandb.log({"discriminator loss": avg_d_cost, "generator loss": avg_g_cost, "fake image": [wandb.Image(i) for i in fake_img]})
    

**FID-score**

In [13]:
!mkdir fake_img
!mkdir real_img

In [14]:
from PIL import Image

num_img = 1000
test_noise =torch.randn(num_img, 100, device=device)
with torch.no_grad():
  test_fake = generator(test_noise).detach().cpu().numpy()

  for index, img in enumerate(test_fake):
    fake = np.reshape(img, (28,28))
    fake = (fake * 127.5 + 127.5).astype(np.uint8)
    fake = np.expand_dims(fake, axis=2)
    fake = np.repeat(fake, 3, axis=2)
    im = Image.fromarray(fake)
    im.save("./fake_img/fake_sample{}.jpeg".format(index))

In [15]:
for i in range(num_img):
  real = np.reshape(train_dataset[i][0].detach().cpu().numpy(), (28,28))
  real = (real * 127.5 + 127.5).astype(np.uint8)
  real = np.expand_dims(real, axis=2)
  real = np.repear(real, 3, axis=2)
  im = Image.fromarray(real)
  im.save("./real_img/real_sample{}.jpeg".format(i))

AttributeError: ignored

In [None]:
import os
import torch

from pytorch_fid.fid_score import *

os.environ['KMP_DUPLICATE_LIB_OB']='True'

real_img_path = 'real_img/'
fake_img_path = 'fake_img/'

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)

if __name__ == "__main__":
  fid = calculate_fid_given_paths(
      paths=[real_img_path, fake_img_path],
      batch_size=128,
      device=device,
      dims=2048
  )

  print("fid score : {}".format(fid))