In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
from scipy.misc import toimage
import skimage.color

# Prefer the GPU when one is available, otherwise fall back to the CPU so the
# notebook can still be executed on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
# Image side length in pixels (the layer comments in the models assume 512x512).
SIZE = 512

# Adam momentum coefficients; handed to the optimizers together as `betas`.
BETA1, BETA2 = 0.5, 0.999

# Weight of the gradient-penalty term in the discriminator loss.
LAMBDA = 10

# Weight of the MSE term in the generator loss.
ALPHA = 1000

# Number of passes over the training loader.
numEpochs = 5

# Dimensionality of the latent noise vector.
latentDim = 100

# Modulus applied to the batch index to gate the generator step.
trainNum = 50

In [2]:
class Generator(nn.Module):
    """Image-to-image generator: conv encoder, global-feature branch, skip-connected decoder.

    The encoder halves the resolution four times (512 -> 32). A side branch
    reduces the 32x32 map to a single 1x1x128 "global" feature vector, which is
    broadcast over the 32x32 map and concatenated with the local features.
    The decoder doubles the resolution four times, concatenating the encoder
    activation of matching resolution before every step.

    Input:  float tensor of shape (N, 3, 512, 512).
    Output: float tensor of shape (N, 3, 512, 512) (raw conv output, no final
            activation).

    All 5x5 convolutions use padding=2 so that stride 1 keeps the spatial size
    and stride 2 exactly halves it; the transposed convolutions additionally use
    output_padding=1 so that they exactly double it. Without this the skip
    connections cannot be concatenated.
    """

    # Spatial size the layers are wired for: the 8x8 kernel of `conv531` only
    # collapses the global branch to 1x1 when the input is 512x512.
    INPUT_SIZE = 512

    def __init__(self):
        super(Generator, self).__init__()
        # ---- Encoder -------------------------------------------------------

        # input 512x512x3  output 512x512x16
        self.conv1 = nn.Conv2d(3, 16, 5, stride=1, padding=2)
        self.conv1_bn = nn.BatchNorm2d(16)

        # input 512x512x16  output 256x256x32
        self.conv2 = nn.Conv2d(16, 32, 5, stride=2, padding=2)
        self.conv2_bn = nn.BatchNorm2d(32)

        # input 256x256x32  output 128x128x64
        self.conv3 = nn.Conv2d(32, 64, 5, stride=2, padding=2)
        self.conv3_bn = nn.BatchNorm2d(64)

        # input 128x128x64  output 64x64x128
        self.conv4 = nn.Conv2d(64, 128, 5, stride=2, padding=2)
        self.conv4_bn = nn.BatchNorm2d(128)

        # input 64x64x128  output 32x32x128
        # The output of this layer also feeds the global-feature branch.
        self.conv5 = nn.Conv2d(128, 128, 5, stride=2, padding=2)
        self.conv5_bn = nn.BatchNorm2d(128)

        # ---- Global-feature branch ----------------------------------------

        # input 32x32x128 output 16x16x128
        self.conv51 = nn.Conv2d(128, 128, 5, stride=2, padding=2)

        # input 16x16x128 output 8x8x128
        self.conv52 = nn.Conv2d(128, 128, 5, stride=2, padding=2)

        # input 8x8x128 output 1x1x128
        # An 8x8 kernel without padding collapses the whole map to one vector.
        self.conv531 = nn.Conv2d(128, 128, 8, stride=1, padding=0)

        # input 1x1x128 output 1x1x128
        # A 1x1 kernel is the only kernel size that fits a 1x1 input.
        self.conv532 = nn.Conv2d(128, 128, 1, stride=1, padding=0)

        # ---- Fusion of local and global features ---------------------------

        # input 32x32x128 output 32x32x128
        # The global features are concatenated to the feature map after this
        # layer; the result of the concatenation is 32x32x256.
        self.conv6 = nn.Conv2d(128, 128, 5, stride=1, padding=2)

        # input 32x32x256 output 32x32x128
        self.conv7 = nn.Conv2d(256, 128, 5, stride=1, padding=2)

        # ---- Decoder -------------------------------------------------------
        # forward() applies SELU -> BatchNorm -> (de)conv, so every BatchNorm
        # below is sized for the channel count of the tensor ENTERING the layer
        # (i.e. after the skip concatenation), not for the layer's output.

        # input 32x32x128 output 64x64x128
        self.dconv1 = nn.ConvTranspose2d(128, 128, 5, stride=2, padding=2, output_padding=1)
        self.dconv1_bn = nn.BatchNorm2d(128)

        # input 64x64x256 output 128x128x128
        self.dconv2 = nn.ConvTranspose2d(256, 128, 5, stride=2, padding=2, output_padding=1)
        self.dconv2_bn = nn.BatchNorm2d(256)

        # input 128x128x192 output 256x256x64
        self.dconv3 = nn.ConvTranspose2d(192, 64, 5, stride=2, padding=2, output_padding=1)
        self.dconv3_bn = nn.BatchNorm2d(192)

        # input 256x256x96 output 512x512x32
        self.dconv4 = nn.ConvTranspose2d(96, 32, 5, stride=2, padding=2, output_padding=1)
        self.dconv4_bn = nn.BatchNorm2d(96)

        # ---- Final convolutions -------------------------------------------

        # input 512x512x48 output 512x512x16
        self.conv8 = nn.Conv2d(48, 16, 5, stride=1, padding=2)
        self.conv8_bn = nn.BatchNorm2d(48)

        # input 512x512x16 output 512x512x3
        self.conv9 = nn.Conv2d(16, 3, 5, stride=1, padding=2)
        self.conv9_bn = nn.BatchNorm2d(16)

    def forward(self, x):
        """Map a batch of images to a batch of images of the same shape.

        Args:
            x: tensor of shape (N, 3, INPUT_SIZE, INPUT_SIZE).

        Returns:
            Tensor of shape (N, 3, INPUT_SIZE, INPUT_SIZE).

        Raises:
            ValueError: if `x` is not 4-D or its spatial size is not
                INPUT_SIZE x INPUT_SIZE (the global branch would not reduce to
                1x1 and the broadcast below would be invalid).
        """
        expected = (self.INPUT_SIZE, self.INPUT_SIZE)
        if x.dim() != 4 or tuple(x.shape[-2:]) != expected:
            raise ValueError(
                "Generator expects input of shape (N, 3, %d, %d), got %s"
                % (expected[0], expected[1], tuple(x.shape)))

        # Encoder: conv -> SELU -> BatchNorm. Each activation is kept for the
        # matching decoder skip connection.
        # 512x512x3 -> 512x512x16
        x = self.conv1_bn(F.selu(self.conv1(x)))
        # 512x512x16 -> 256x256x32
        x1 = self.conv2_bn(F.selu(self.conv2(x)))
        # 256x256x32 -> 128x128x64
        x2 = self.conv3_bn(F.selu(self.conv3(x1)))
        # 128x128x64 -> 64x64x128
        x3 = self.conv4_bn(F.selu(self.conv4(x2)))
        # 64x64x128 -> 32x32x128
        x4 = self.conv5_bn(F.selu(self.conv5(x3)))

        # Global-feature branch.
        # 32x32x128 -> 16x16x128
        x51 = self.conv51(x4)
        # 16x16x128 -> 8x8x128
        x52 = self.conv52(x51)
        # 8x8x128 -> 1x1x128 -> 1x1x128
        x53 = self.conv532(F.selu(self.conv531(x52)))

        # 32x32x128 -> 32x32x128 (local features)
        x5 = self.conv6(x4)

        # Repeat the single global vector at every spatial location (H and W
        # axes, not batch/channel) so it can be stacked on the local features.
        x53_temp = x53.expand(-1, -1, x5.size(2), x5.size(3))

        # 32x32x256 -> 32x32x128
        x5 = self.conv7(torch.cat((x5, x53_temp), dim=1))

        # Decoder: SELU -> BatchNorm -> transposed conv, with skip connections.
        # 32x32x128 -> 64x64x128
        xd = self.dconv1(self.dconv1_bn(F.selu(x5)))
        # 64x64x256 -> 128x128x128
        xd = self.dconv2(self.dconv2_bn(F.selu(torch.cat((xd, x3), dim=1))))
        # 128x128x192 -> 256x256x64
        xd = self.dconv3(self.dconv3_bn(F.selu(torch.cat((xd, x2), dim=1))))
        # 256x256x96 -> 512x512x32
        xd = self.dconv4(self.dconv4_bn(F.selu(torch.cat((xd, x1), dim=1))))

        # 512x512x48 -> 512x512x16
        xd = self.conv8(self.conv8_bn(F.selu(torch.cat((xd, x), dim=1))))
        # 512x512x16 -> 512x512x3
        xd = self.conv9(self.conv9_bn(F.selu(xd)))
        return xd

In [3]:
class Discriminator(nn.Module):
    """Convolutional critic that maps a 512x512 RGB image to a single score.

    Six 5x5 convolutions (each followed by leaky ReLU and instance
    normalisation) reduce the input from 512x512x3 to 16x16x128; a final 16x16
    convolution collapses that map to one unbounded scalar per image.

    Input:  float tensor of shape (N, 3, 512, 512).
    Output: float tensor of shape (N, 1).

    All 5x5 convolutions use padding=2 so that stride 1 keeps the spatial size
    and stride 2 exactly halves it, matching the sizes in the layer comments.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        # input 512x512x3  output 512x512x16
        self.conv1 = nn.Conv2d(3, 16, 5, stride=1, padding=2)
        self.conv1_in = nn.InstanceNorm2d(16)

        # input 512x512x16  output 256x256x32
        self.conv2 = nn.Conv2d(16, 32, 5, stride=2, padding=2)
        self.conv2_in = nn.InstanceNorm2d(32)

        # input 256x256x32  output 128x128x64
        self.conv3 = nn.Conv2d(32, 64, 5, stride=2, padding=2)
        self.conv3_in = nn.InstanceNorm2d(64)

        # input 128x128x64  output 64x64x128
        self.conv4 = nn.Conv2d(64, 128, 5, stride=2, padding=2)
        self.conv4_in = nn.InstanceNorm2d(128)

        # input 64x64x128  output 32x32x128
        self.conv5 = nn.Conv2d(128, 128, 5, stride=2, padding=2)
        self.conv5_in = nn.InstanceNorm2d(128)

        # input 32x32x128  output 16x16x128
        self.conv6 = nn.Conv2d(128, 128, 5, stride=2, padding=2)
        self.conv6_in = nn.InstanceNorm2d(128)

        # input 16x16x128  output 1x1x1
        # A 16x16 kernel sees the entire feature map (a strided 5x5 kernel
        # would only read one corner of it). No normalisation or activation
        # follows: instance norm is undefined on a single value and the critic
        # score must stay unbounded.
        self.conv7 = nn.Conv2d(128, 1, 16, stride=1, padding=0)

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: tensor of shape (N, 3, 512, 512).

        Returns:
            Tensor of shape (N, 1) holding one raw critic score per image.
        """
        # 512x512x3 -> 512x512x16
        x = self.conv1_in(F.leaky_relu(self.conv1(x)))
        # 512x512x16 -> 256x256x32
        x = self.conv2_in(F.leaky_relu(self.conv2(x)))
        # 256x256x32 -> 128x128x64
        x = self.conv3_in(F.leaky_relu(self.conv3(x)))
        # 128x128x64 -> 64x64x128
        x = self.conv4_in(F.leaky_relu(self.conv4(x)))
        # 64x64x128 -> 32x32x128
        x = self.conv5_in(F.leaky_relu(self.conv5(x)))
        # 32x32x128 -> 16x16x128
        x = self.conv6_in(F.leaky_relu(self.conv6(x)))
        # 16x16x128 -> 1x1x1
        x = self.conv7(x)

        # Flatten (N, 1, 1, 1) to (N, 1).
        return x.view(x.size(0), -1)

In [4]:
# Build the networks and place their parameters on the selected device so that
# they live on the same device as the batches fed to them.
generator1 = Generator().to(device)
# generator2 = Generator().to(device)
discriminator = Discriminator().to(device)

In [5]:
### Loading the training and test sets
# Resize every image to SIZE x SIZE and convert it from a PIL image to a tensor,
# so it can be accepted as the input to the network.
# Resize uses bilinear interpolation by default, which is the filter the legacy
# integer code `interpolation=2` selected.
transform = transforms.Compose([
    transforms.Resize((SIZE, SIZE)),
    transforms.ToTensor(),
])

# NOTE: if the download is interrupted, the partially written archive under
# ./data makes the next run fail with
# "EOFError: Compressed file ended before the end-of-stream marker was reached";
# delete ./data/stl10_binary.tar.gz and run this cell again.
trainset = torchvision.datasets.STL10(root='./data', split='unlabeled', transform=transform, target_transform=None, download=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=5, shuffle=True)

Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to ./data\stl10_binary.tar.gz


EOFError: Compressed file ended before the end-of-stream marker was reached

In [None]:
# Held-out STL10 split, preprocessed with the same transform as the training
# data. The loader keeps the dataset order (no shuffling).
testset = torchvision.datasets.STL10(
    root='./data',
    split='test',
    transform=transform,
    target_transform=None,
    download=True,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=5,
    shuffle=False,
)

In [2]:
#  Computes gradient penalty loss for A-WGAN
# LAMBDA = 10
# NOTE: a later cell of this notebook defines a function with the same name;
# once that cell is executed its definition takes precedence over this one.
def computeGradientPenalty(D, realSample, fakeSample):
    """One-sided gradient penalty of the critic on random real/fake blends.

    For every sample a point is drawn uniformly on the segment between the real
    and the fake sample, the critic's gradient with respect to that point is
    computed, and the amount by which its L2 norm exceeds 1 is penalised:
    mean(max(0, ||grad|| - 1)).

    Args:
        D: callable mapping a batch tensor to one score (or a tensor of scores)
            per sample.
        realSample: tensor of real samples, batch on dim 0.
        fakeSample: tensor of generated samples, same shape as `realSample`.

    Returns:
        Scalar tensor. It stays attached to the autograd graph of `D`
        (create_graph=True) so it can be back-propagated as part of a loss.
    """
    # One blend factor per sample, broadcast over all remaining dimensions.
    alphaShape = (realSample.size(0),) + (1,) * (realSample.dim() - 1)
    alpha = torch.rand(alphaShape, device=realSample.device, dtype=realSample.dtype)

    # Detach from whatever produced the samples: only the gradient of D with
    # respect to the blended input itself is needed.
    interpolate = (alpha * realSample + (1 - alpha) * fakeSample).detach()
    interpolate.requires_grad_(True)
    dInterpolation = D(interpolate)

    gradients = torch.autograd.grad(
        outputs=dInterpolation,
        inputs=interpolate,
        # Shaped like the critic output, whatever that shape is.
        grad_outputs=torch.ones_like(dInterpolation),
        create_graph=True,
        retain_graph=True,
        only_inputs=True)[0]

    gradients = gradients.reshape(gradients.size(0), -1)
    # Only norms above 1 are penalised.
    gradientPenalty = torch.clamp(gradients.norm(2, dim=1) - 1, min=0).mean()
    return gradientPenalty

In [9]:
# generator loss
# alpha = 1000

# Mean-squared-error criterion, averaged over all elements.
criterion = nn.MSELoss(reduction='mean')

def generatorAdversarialLoss(output_images):
    """Return the mean score the module-level `discriminator` gives `output_images`."""
    return discriminator(output_images).mean()

# Only one generator is used for now; a two-generator variant
# (computeGeneratorLoss(inputs, outputs_g1, outputs_g2)) is currently disabled.
def computeGeneratorLoss(inputs, outputs_g1):
    """Total generator loss: negated adversarial score plus weighted MSE to the input.

    Args:
        inputs: the images that were fed to the generator.
        outputs_g1: the images the generator produced for `inputs`.

    Returns:
        Scalar tensor `-mean(D(outputs_g1)) + ALPHA * MSE(outputs_g1, inputs)`.
    """
    # generator 1
    gen_adv_loss1 = generatorAdversarialLoss(outputs_g1)

    # The criterion takes (prediction, target). The generator output is the
    # prediction that needs gradients; older PyTorch releases reject a target
    # that requires grad. The MSE value itself is symmetric in its arguments.
    i_loss = criterion(outputs_g1, inputs)

    gen_loss = -gen_adv_loss1 + ALPHA * i_loss

    return gen_loss


In [8]:
# One Adam optimizer per network; both use the same learning rate and the
# (BETA1, BETA2) momentum coefficients.
optimizer_g1 = optim.Adam(
    generator1.parameters(),
    lr=0.001,
    betas=(BETA1, BETA2),
)
# optimizer_g2 = optim.Adam(generator2.parameters(), lr=0.001, betas=(BETA1, BETA2))
optimizer_d = optim.Adam(
    discriminator.parameters(),
    lr=0.001,
    betas=(BETA1, BETA2),
)

In [None]:
#  Computes gradient penalty loss for A-WGAN
# NOTE: this notebook defines computeGradientPenalty in two cells; this later
# definition is the one in effect once the notebook has run top to bottom.
# Consider keeping a single definition.
def computeGradientPenalty(D, realSample, fakeSample):
    """Penalise critic gradients whose norm exceeds 1 on real/fake interpolations.

    Each real sample is blended with its fake counterpart using a uniformly
    drawn factor; the critic is differentiated with respect to the blend and
    `max(0, ||grad||_2 - 1)` is averaged over the batch.

    Args:
        D: callable mapping a batch tensor to one score (or a tensor of scores)
            per sample.
        realSample: tensor of real samples, batch on dim 0.
        fakeSample: tensor of generated samples, same shape as `realSample`.

    Returns:
        Scalar tensor, differentiable with respect to the parameters of `D`.
    """
    batchSize = realSample.size(0)
    # Shape (batch, 1, 1, ...) so the factor broadcasts over each sample.
    alpha = torch.rand(
        (batchSize,) + (1,) * (realSample.dim() - 1),
        device=realSample.device,
        dtype=realSample.dtype,
    )

    # The blended point is a fresh leaf: gradients are taken with respect to it,
    # not with respect to whatever produced the real/fake samples.
    interpolate = (alpha * realSample + (1 - alpha) * fakeSample).detach()
    interpolate.requires_grad_(True)
    dInterpolation = D(interpolate)

    gradients = torch.autograd.grad(
        outputs=dInterpolation,
        inputs=interpolate,
        # Must match the critic output's shape exactly.
        grad_outputs=torch.ones_like(dInterpolation),
        # Keep the graph so the penalty itself can be back-propagated.
        create_graph=True,
        retain_graph=True,
        only_inputs=True)[0]

    gradients = gradients.reshape(batchSize, -1)
    gradientPenalty = torch.clamp(gradients.norm(2, dim=1) - 1, min=0).mean()
    return gradientPenalty

In [None]:
def discriminatorLoss(d1Real, d1Fake, gradPenalty):
    """Critic loss: mean fake score minus mean real score plus weighted penalty.

    Args:
        d1Real: critic scores for real samples.
        d1Fake: critic scores for generated samples.
        gradPenalty: scalar gradient-penalty term.

    Returns:
        Scalar tensor `mean(d1Fake) - mean(d1Real) + LAMBDA * gradPenalty`.
    """
    wassersteinTerm = torch.mean(d1Fake) - torch.mean(d1Real)
    return wassersteinTerm + (LAMBDA * gradPenalty)


# The function was originally defined under this misspelled name; keep it as an
# alias so both spellings resolve to the same function.
discrminatorLoss = discriminatorLoss

In [None]:

# Batches are moved to whichever device the critic's parameters live on, so the
# loop works for both CPU and GPU models.
modelDevice = next(discriminator.parameters()).device

for epoch in range(numEpochs):
    for i, data in enumerate(trainloader, 0):
        # The dataset yields (image, label) pairs; the labels are not used.
        images, _ = data
        realImgs = images.to(modelDevice)

        ### TRAIN DISCRIMINATOR
        optimizer_d.zero_grad()

        # The generator is an image-to-image network (its first layer expects a
        # 3-channel image), so it is fed the image batch rather than a latent
        # noise vector. No gradient is needed through the generator while the
        # critic is being updated.
        with torch.no_grad():
            fakeImgs = generator1(realImgs)

        # Real Images
        realValid = discriminator(realImgs)
        # Fake Images
        fakeValid = discriminator(fakeImgs)

        gradientPenalty = computeGradientPenalty(discriminator, realImgs, fakeImgs)
        # `discrminatorLoss` (sic) is the spelling under which the helper is
        # defined in this notebook.
        dLoss = discrminatorLoss(realValid, fakeValid, gradientPenalty)
        dLoss.backward()
        optimizer_d.step()
        optimizer_g1.zero_grad()

        ### TRAIN GENERATOR
        # Runs on every `trainNum`-th batch. (Comparing the remainder against 50
        # could never match, because a remainder modulo 50 is always below 50.)
        if i % trainNum == 0:
            # TODO: the generator update (forward pass, computeGeneratorLoss,
            # backward, optimizer_g1.step()) is not implemented yet.
            print("Hello" + str(i))
            