In [1]:
import torch
import torchvision
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.optim as optim
import torchvision.utils as vutils
import matplotlib.pyplot as plt

In [2]:
# Establish Parameters
learning_rate = 0.00005
sizeOfBatch = 64
numEpochs = 135
criticIter = 5
noiseDim = 100
outputFolder="./wganOutput"

In [3]:
# Grab CIFAR10 Dataset
dataset = datasets.CIFAR10(root="./data", download=False, transform=transforms.Compose([
    transforms.Resize(64),transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]))
dataloader = torch.utils.data.DataLoader(dataset, batch_size = 128, shuffle=True, num_workers=2)

In [4]:
# initialize weights for discriminator and generator
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [5]:
# Create Discriminator and Generator Classes
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.main=nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, 1, 0),
        )
    def forward(self, passedInput):
        return self.main(passedInput).mean(0).view(1)

class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        self.main=nn.Sequential(
            nn.ConvTranspose2d(noiseDim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )
    def forward(self, passedInput):
        return self.main(passedInput)

In [6]:
# Create Generator and Discriminator and apply initial weights
discriminator = Discriminator()
generator = Generator()
discriminator.apply(weights_init)
generator.apply(weights_init)
discriminator.cuda()
generator.cuda()

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)

In [7]:
# Set up optimizers
disOptim = optim.RMSprop(discriminator.parameters(),lr=learning_rate)
genOptim = optim.RMSprop(generator.parameters(),lr=learning_rate)
criterion = nn.BCELoss()

# Establish parameters for training
inputTensor = torch.FloatTensor(sizeOfBatch, 3, 64, 64)
noise = torch.FloatTensor(sizeOfBatch, noiseDim, 1, 1)
normNoise = torch.FloatTensor(sizeOfBatch, noiseDim, 1, 1).normal_(0, 1)
normNoise = Variable(normNoise)
onesTensor = torch.FloatTensor([1])
nOnesTensor=onesTensor * (-1)
realValue = 1
fakeValue = 0

# Make variables available for cuda
criterion.cuda()
inputTensor = inputTensor.cuda()
noise = noise.cuda()
normNoise = normNoise.cuda()
onesTensor = onesTensor.cuda()
nOnesTensor = nOnesTensor.cuda()

# Variables for tracking training progress
dLossList = []
gLossList = []
countList = []
count = 0 

In [None]:
# Training algorithm for discriminator and generator
for epoch in range(numEpochs):
    epochData = iter(dataloader)
    dataCounter = 0
    
    # Iterate for all batches of data
    while dataCounter < len(dataloader):
        for param in discriminator.parameters():
            param.requires_grad=True
        
        # Iterate until critic requirement is satisfied
        criticCounter = 0
        while dataCounter < len(dataloader) and criticCounter < criticIter:
            data = epochData.next()
            criticCounter += 1
            for param in discriminator.parameters():
                param.data.clamp_(-1e-2, 1e-2)
            dataCounter += 1
            
            # Train Discriminator with real data
            disOptim.zero_grad()
            realData, _ = data
            sizeOfBatch = realData.size(0)
            realData = realData.cuda()
            inputTensor.resize_as_(realData).copy_(realData)
            inputVar = Variable(inputTensor)
            disRealError = discriminator(inputVar)
            disRealError.backward(onesTensor)

            # Train Discriminator on fake data
            noise.resize_(sizeOfBatch, noiseDim, 1,1).normal_(0,1)
            noiseVar = Variable(noise)
            fakeData = generator(noiseVar)
            disFakeError = discriminator(fakeData.detach())
            disFakeError.backward(nOnesTensor)
            disOptim.step()
            finalDisError = -disFakeError + disRealError
            
        # Train Generator
        for param in discriminator.parameters():
            param.requires_grad = False
        genOptim.zero_grad()
        genError = discriminator(fakeData)
        genError.backward(onesTensor)
        genOptim.step()
        
        
        print('[%d/%d][%d/%d] DLoss: %.4f GLoss: %.4f' % 
             (epoch, numEpochs, dataCounter, len(dataloader), finalDisError.data[0], genError.data[0]))

        count += 1
        countList.append(count)
        gLossList.append(genError.data.cpu().numpy()[0])
        dLossList.append(finalDisError.data.cpu().numpy()[0])
    
    # Store fake images
    fakeData = generator(normNoise)
    fakeData.data = fakeData.data.mul(0.5).add(0.5)
    vutils.save_image(fakeData.data, '%s/fake_samples_epoch_%03d.png' % 
                      (outputFolder, epoch), normalize=True)

[0/135][5/391] DLoss: -0.2537 GLoss: 0.2394
[0/135][10/391] DLoss: -0.3725 GLoss: 0.2644
[0/135][15/391] DLoss: -0.4881 GLoss: 0.2966
[0/135][20/391] DLoss: -0.5949 GLoss: 0.3350
[0/135][25/391] DLoss: -0.6885 GLoss: 0.3722
[0/135][30/391] DLoss: -0.7643 GLoss: 0.4040
[0/135][35/391] DLoss: -0.8285 GLoss: 0.4326
[0/135][40/391] DLoss: -0.8941 GLoss: 0.4582
[0/135][45/391] DLoss: -0.9447 GLoss: 0.4809
[0/135][50/391] DLoss: -0.9916 GLoss: 0.5013
[0/135][55/391] DLoss: -1.0244 GLoss: 0.5173
[0/135][60/391] DLoss: -1.0665 GLoss: 0.5356
[0/135][65/391] DLoss: -1.1035 GLoss: 0.5500
[0/135][70/391] DLoss: -1.1303 GLoss: 0.5598
[0/135][75/391] DLoss: -1.1552 GLoss: 0.5731
[0/135][80/391] DLoss: -1.1816 GLoss: 0.5835
[0/135][85/391] DLoss: -1.2077 GLoss: 0.5929
[0/135][90/391] DLoss: -1.2177 GLoss: 0.5950
[0/135][95/391] DLoss: -1.2350 GLoss: 0.6058
[0/135][100/391] DLoss: -1.2491 GLoss: 0.6118
[0/135][105/391] DLoss: -1.2703 GLoss: 0.6194
[0/135][110/391] DLoss: -1.2400 GLoss: 0.6150
[0/135][

In [None]:
# Plot the loss of the generator and the descriminator
# plot predictions for arcsinh(x) and compate to ground truth
plt.plot(countList, gLossList, 'r.', label='Generator')
plt.plot(countList, dLossList, 'g.', label='Discriminator')
plt.title("WGAN Loss of Discriminator and Generator")
plt.xlabel("Batch Number")
plt.ylabel("Loss (Binary Cross Entropy)")
plt.legend(loc = "best")
plt.show()