In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

import torch.utils.data
from torch import nn, optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os
import matplotlib.pyplot as plt

In [2]:
project_root = '/Users/JSen/Documents/pytorch-GAN-CGAN/'
os.chdir(project_root)

no_cuda = False
cuda_available = not no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
BATCH_SIZE = 64
EPOCH = 100
SEED = 8

kwargs = {'num_workers': 1, 'pin_memory': True} if cuda_available else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./MNIST_data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=BATCH_SIZE, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./MNIST_data', train=False, transform=transforms.ToTensor()),
    batch_size=BATCH_SIZE, shuffle=True, **kwargs)

In [3]:
# Model params
g_input_size = 100     # Random noise dimension coming into generator, per output vector
g_hidden_size = 256   # Generator complexity
g_output_size = 784    # size of generated output vector

d_input_size = 784   # Minibatch size - cardinality of distributions
d_hidden_size = 256   # Discriminator complexity
d_output_size = 1    # Single dimension for 'real' vs. 'fake'


d_learning_rate = 2e-4  # 2e-4
g_learning_rate = 2e-4
optim_betas = (0.9, 0.999)

print_interval = 200

d_steps = 1  # 'k' steps in the original GAN paper. Can put the discriminator on higher training freq than generator
g_steps = 1

In [4]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(g_input_size, g_hidden_size),
            nn.LeakyReLU(0.2),
            
            nn.Linear(g_hidden_size, g_output_size),
            nn.Tanh()
        )
    
    def forward(self, x):
        x = x.view(x.size(0), g_input_size)
        out = self.model(x)
        return out

In [5]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(d_input_size, d_hidden_size),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Linear(d_hidden_size, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        out = self.model(x.view(x.size(0), d_input_size))
        out = out.view(out.size(0), -1)
        return out

In [6]:

D  = Discriminator().to(device)
G = Generator().to(device)
print(D)
print(G)

Discriminator(
  (model): Sequential(
    (0): Linear(in_features=784, out_features=256, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Dropout(p=0.3)
    (3): Linear(in_features=256, out_features=1, bias=True)
    (4): Sigmoid()
  )
)
Generator(
  (model): Sequential(
    (0): Linear(in_features=100, out_features=256, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Linear(in_features=256, out_features=784, bias=True)
    (3): Tanh()
  )
)


In [7]:
criterion = nn.BCELoss()
lr = 0.0002
d_optimizer = torch.optim.Adam(D.parameters(), lr=lr)
g_optimizer = torch.optim.Adam(G.parameters(), lr=lr)

In [8]:
batch_idx, (data, label) =enumerate(train_loader).__next__()
print(data.shape, label.shape)
real_data = data.to(device)
d_real_decision = D(Variable(real_data))
print(d_real_decision.shape)

torch.Size([64, 1, 28, 28]) torch.Size([64])
torch.Size([64, 1])


In [None]:

for epoch in range(EPOCH):
    D_losses = []
    G_losses = []
    for batch_idx, (data, label) in enumerate(train_loader):
        real_data = data.to(device) 
        D.zero_grad()
        
        for d_index in range(d_steps):
            # 1. Train D on real+fake
            #  1A: Train D on real
            
            d_real_decision = D(Variable(real_data))
#             print(d_real_decision.shape, flush=True, end='')
            if d_real_decision.shape[0] != BATCH_SIZE:
                print(d_real_decision.shape)
            d_real_error = criterion(d_real_decision, Variable(torch.ones(d_real_decision.shape[0],1)))  # ones = true
            d_real_error.backward() # compute/store gradients, but don't change params

            #  1B: Train D on fake
            d_gen_input = Variable(torch.rand(BATCH_SIZE, g_input_size))
            d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
            
            d_fake_decision = D(d_fake_data)
            d_fake_error = criterion(d_fake_decision, Variable(torch.zeros(BATCH_SIZE, 1)))  # zeros = fake
            d_fake_error.backward()
            d_optimizer.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()
            d_error = d_real_error + d_fake_error
            D_losses.append(d_error.detach().numpy())
            print(f'd:{d_error}', end=' ')
        for g_index in range(g_steps):
        # 2. Train G on D's response (but DO NOT train D on these labels)
            G.zero_grad()

            gen_input = Variable(torch.rand(BATCH_SIZE, g_input_size))
            g_fake_data = G(gen_input)
            dg_fake_decision = D(g_fake_data)
            g_error = criterion(dg_fake_decision, Variable(torch.ones(BATCH_SIZE, 1)))  # we want to fool, so pretend it's all genuine

            g_error.backward()
            g_optimizer.step()  # Only optimizes G's parameters
            
            G_losses.append(g_error.detach().numpy())
            print(f'g:{g_error}')
    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
        (epoch + 1), EPOCH, np.mean(D_losses), np.mean(G_losses)))

d:1.330461025238037 g:0.8089956045150757
d:1.2803194522857666 g:0.8091118335723877
d:1.2372429370880127 g:0.8155979514122009
d:1.1912767887115479 g:0.8114206194877625
d:1.1650160551071167 g:0.8133252263069153
d:1.1336820125579834 g:0.8121735453605652
d:1.1049927473068237 g:0.8034640550613403
d:1.0789728164672852 g:0.8010315895080566
d:1.0518661737442017 g:0.791846513748169
d:1.040086269378662 g:0.7824530601501465
d:1.0275896787643433 g:0.7652156949043274
d:1.015452265739441 g:0.7525621056556702
d:1.0039626359939575 g:0.7402397394180298
d:1.01898193359375 g:0.7174564003944397
d:1.022367000579834 g:0.699038565158844
d:1.0257172584533691 g:0.6807978749275208
d:1.0346004962921143 g:0.6565701365470886
d:1.0433909893035889 g:0.6364479064941406
d:1.078594446182251 g:0.6190317869186401
d:1.0859136581420898 g:0.5851039290428162
d:1.1170021295547485 g:0.5556296706199646
d:1.1421096324920654 g:0.5337837338447571
d:1.1793311834335327 g:0.5150040984153748
d:1.2243237495422363 g:0.4898017644882202
d