<a href="https://colab.research.google.com/github/muyuuuu/colab/blob/main/GAN/WGAN-GP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as np

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
bs = 100

# MNIST Dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5), std=(0.5))])

train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transform, download=True)

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [3]:
class Generator(nn.Module):
    def __init__(self, g_input_dim, g_output_dim):
        super(Generator, self).__init__()       
        self.fc1 = nn.Linear(g_input_dim, 256)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features*2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features*2)
        self.fc4 = nn.Linear(self.fc3.out_features, g_output_dim)

    # forward method
    def forward(self, x): 
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        return torch.tanh(self.fc4(x))
    
class Discriminator(nn.Module):
    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)
    
    # forward method
    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3)
        return self.fc4(x)

In [4]:
z_dim = 100
lambda_gp = 10
mnist_dim = train_dataset.train_data.size(1) * train_dataset.train_data.size(2)

G = Generator(g_input_dim = z_dim, g_output_dim = mnist_dim).to(device)
D = Discriminator(mnist_dim).to(device)

# optimizer
lr = 0.0002 
G_optimizer = optim.Adam(G.parameters(), lr = lr)
D_optimizer = optim.Adam(D.parameters(), lr = lr)



In [5]:
def G_train(x):
    #=======================Train the generator=======================#
    G_optimizer.zero_grad()

    z = torch.randn(bs, z_dim).to(device)
    y = torch.ones(bs, 1).to(device)

    G_output = G(z)
    D_output = D(G_output)
    g_loss = -torch.mean(D_output)

    # gradient backprop & optimize ONLY G's parameters
    g_loss.backward()
    G_optimizer.step()

    return g_loss.data.item()

In [6]:
def compute_gradient_penalty(D, real_samples, fake_samples):

    alpha = torch.rand(1).to(device)

    x = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True).to(device)

    interpolates = D(x)

    fake = torch.ones(bs, 1).to(device)

    gradients = torch.autograd.grad(
        outputs=interpolates,
        inputs=x,
        grad_outputs=fake,
        create_graph=True,
    )[0]

    gradients = gradients.view(gradients.size(0), -1)

    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty


def D_train(x):
    #=======================Train the discriminator=======================#
    D_optimizer.zero_grad()

    # train discriminator on real
    x_real, y_real = x.view(-1, mnist_dim).to(device), torch.ones(bs, 1).to(device)

    z = torch.randn(bs, z_dim).to(device)
    x_fake, y_fake = G(z), torch.zeros(bs, 1).to(device)

    real_validity = D(x_real)
    fake_validity = D(x_fake)

    gradient_penalty = compute_gradient_penalty(D, x_real.data, x_fake.data)

    # gradient backprop & optimize ONLY D's parameters
    D_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
    # D_loss = -torch.mean(real_validity) + torch.mean(fake_validity)
    D_loss.backward()
    D_optimizer.step()

    return  D_loss.data.item()

In [7]:
n_epoch = 500
for epoch in range(1, n_epoch+1):           
    D_losses, G_losses = [], []
    for batch_idx, (x, _) in enumerate(train_loader):
        D_losses.append(D_train(x))
        G_losses.append(G_train(x))

    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
            (epoch), n_epoch, torch.mean(torch.FloatTensor(D_losses)), torch.mean(torch.FloatTensor(G_losses))))
    
    if (epoch % 100 == 99):
        with torch.no_grad():
            test_z = torch.randn(bs, z_dim).to(device)
            generated = G(test_z)

            save_image(generated.view(generated.size(0), 1, 28, 28), 'sample_WGAN_{}'.format(epoch) + '.png')

[1/500]: loss_d: -36.102, loss_g: 35.722
[2/500]: loss_d: -30.222, loss_g: 30.825
[3/500]: loss_d: -33.425, loss_g: 34.292
[4/500]: loss_d: -34.882, loss_g: 35.230
[5/500]: loss_d: -38.325, loss_g: 37.114
[6/500]: loss_d: -41.931, loss_g: 39.303
[7/500]: loss_d: -44.077, loss_g: 40.485
[8/500]: loss_d: -46.132, loss_g: 41.632
[9/500]: loss_d: -47.372, loss_g: 42.142
[10/500]: loss_d: -19.978, loss_g: 21.792
[11/500]: loss_d: -15.929, loss_g: 19.840
[12/500]: loss_d: -9.969, loss_g: 12.519
[13/500]: loss_d: -9.843, loss_g: 13.572
[14/500]: loss_d: -6.676, loss_g: 11.509
[15/500]: loss_d: -6.404, loss_g: 12.013
[16/500]: loss_d: -6.186, loss_g: 11.827
[17/500]: loss_d: -6.249, loss_g: 11.369
[18/500]: loss_d: -6.014, loss_g: 9.937
[19/500]: loss_d: -5.122, loss_g: 9.522
[20/500]: loss_d: -4.683, loss_g: 8.598
[21/500]: loss_d: -4.101, loss_g: 8.019
[22/500]: loss_d: -3.915, loss_g: 7.681
[23/500]: loss_d: -3.692, loss_g: 6.416
[24/500]: loss_d: -3.384, loss_g: 5.840
[25/500]: loss_d: -3.

In [8]:
!zip image.zip *.png

  adding: sample_WGAN_199.png (deflated 3%)
  adding: sample_WGAN_299.png (deflated 3%)
  adding: sample_WGAN_399.png (deflated 3%)
  adding: sample_WGAN_499.png (deflated 3%)
  adding: sample_WGAN_99.png (deflated 1%)


In [9]:
!ls

image.zip   sample_data		 sample_WGAN_299.png  sample_WGAN_499.png
mnist_data  sample_WGAN_199.png  sample_WGAN_399.png  sample_WGAN_99.png
