# Image to Image (PyTorch)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
import torchvision.utils as vutils

In [23]:
from easydict import EasyDict as edict
opt = edict()
# data params
opt.dataroot = '/hd3/yekui/facades/'
opt.workers = 2

# network params
opt.input_nc = 3
opt.output_nc = 3
opt.ngf = 64
opt.ndf = 64
opt.netG = 'experiment/netG_epoch_19.pth'
opt.netD = 'experiment/netD_epoch_19.pth'

# training params
opt.adam = False
opt.cuda = True
opt.niter = 100     # number of epochs
opt.Diters = 25    # train the discriminator Diters times
opt.experiment = 'experiment'
opt.clamp_lower = -0.01
opt.clamp_upper =  0.01
opt.batch_size = 1
opt.lr = 0.0002
opt.beta1 = 0.8

In [3]:
# prepare data
dataset = datasets.ImageFolder(root=opt.dataroot,
                           transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size,
                                         shuffle=True, num_workers=int(opt.workers))

In [4]:
for i, data in enumerate(dataloader):
    print(i, data[0].size(),data[0].max(), data[1].max())

(0, torch.Size([64, 3, 256, 512]), 1.0, 2)
(1, torch.Size([64, 3, 256, 512]), 1.0, 2)
(2, torch.Size([64, 3, 256, 512]), 1.0, 2)
(3, torch.Size([64, 3, 256, 512]), 1.0, 2)
(4, torch.Size([64, 3, 256, 512]), 1.0, 2)
(5, torch.Size([64, 3, 256, 512]), 1.0, 2)
(6, torch.Size([64, 3, 256, 512]), 1.0, 2)
(7, torch.Size([64, 3, 256, 512]), 1.0, 2)
(8, torch.Size([64, 3, 256, 512]), 1.0, 2)
(9, torch.Size([30, 3, 256, 512]), 1.0, 2)


In [4]:
# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [18]:
# define generator
class generator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf):
        super(generator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf     
        
        self.conv1 = nn.Conv2d(input_nc, ngf, 4, stride=2, padding=1, bias=False)      
        self.conv2 = nn.Conv2d(ngf, ngf*2, 4, stride=2, padding=1, bias=False)
        self.batchnorm2 = nn.BatchNorm2d(ngf*2)        
        self.conv3 = nn.Conv2d(ngf*2, ngf*4, 4, stride=2, padding=1, bias=False)
        self.batchnorm3 = nn.BatchNorm2d(ngf*4)        
        self.conv4 = nn.Conv2d(ngf*4, ngf*8, 4, stride=2, padding=1, bias=False)
        self.batchnorm4 = nn.BatchNorm2d(ngf*8)        
        self.conv5 = nn.Conv2d(ngf*8, ngf*8, 4, stride=2, padding=1, bias=False)
        self.batchnorm5 = nn.BatchNorm2d(ngf*8)        
        self.conv6 = nn.Conv2d(ngf*8, ngf*8, 4, stride=2, padding=1, bias=False)
        self.batchnorm6 = nn.BatchNorm2d(ngf*8)        
        self.conv7 = nn.Conv2d(ngf*8, ngf*8, 4, stride=2, padding=1, bias=False)
        self.batchnorm7 = nn.BatchNorm2d(ngf*8)        
        self.conv8 = nn.Conv2d(ngf*8, ngf*8, 4, stride=2, padding=1, bias=False)
        self.batchnorm8 = nn.BatchNorm2d(ngf*8)
        
        self.conv1d = nn.ConvTranspose2d(ngf*8,ngf*8,kernel_size=4, stride=2, padding=1, bias=False)
        self.batchnorm1d = nn.BatchNorm2d(ngf*8)        
        self.conv2d = nn.ConvTranspose2d(ngf*8*2,ngf*8,kernel_size=4, stride=2, padding=1, bias=False)
        self.batchnorm2d = nn.BatchNorm2d(ngf*8)        
        self.conv3d = nn.ConvTranspose2d(ngf*8*2,ngf*8,kernel_size=4, stride=2, padding=1, bias=False)
        self.batchnorm3d = nn.BatchNorm2d(ngf*8)        
        self.conv4d = nn.ConvTranspose2d(ngf*8*2,ngf*8,kernel_size=4, stride=2, padding=1, bias=False)
        self.batchnorm4d = nn.BatchNorm2d(ngf*8)        
        self.conv5d = nn.ConvTranspose2d(ngf*8*2,ngf*4,kernel_size=4, stride=2, padding=1, bias=False)
        self.batchnorm5d = nn.BatchNorm2d(ngf*4)        
        self.conv6d = nn.ConvTranspose2d(ngf*4*2,ngf*2,kernel_size=4, stride=2, padding=1, bias=False)
        self.batchnorm6d = nn.BatchNorm2d(ngf*2)        
        self.conv7d = nn.ConvTranspose2d(ngf*2*2,ngf,kernel_size=4, stride=2, padding=1, bias=False)
        self.batchnorm7d = nn.BatchNorm2d(ngf)        
        self.conv8d = nn.ConvTranspose2d(ngf*2,output_nc,kernel_size=4, stride=2, padding=1, bias=False)
        

    def forward(self, x):                                       # x : batch * 3  * 256 * 256                                                                
        x1 = self.conv1(x)                                      # x1: batch * 64 * 128 * 128
        x2 = self.batchnorm2(self.conv2(F.leaky_relu(x1, 0.2))) # x2: batch * 128 * 64 * 64
        x3 = self.batchnorm3(self.conv3(F.leaky_relu(x2, 0.2))) # x3: batch * 256 * 32 * 32
        x4 = self.batchnorm4(self.conv4(F.leaky_relu(x3, 0.2))) # x4: batch * 512 * 16 * 16
        x5 = self.batchnorm5(self.conv5(F.leaky_relu(x4, 0.2))) # x5: batch * 512 *  8 *  8
        x6 = self.batchnorm6(self.conv6(F.leaky_relu(x5, 0.2))) # x6: batch * 512 *  4 *  4
        x7 = self.batchnorm7(self.conv7(F.leaky_relu(x6, 0.2))) # x7: batch * 512 *  2 *  2
        x8 = self.batchnorm8(self.conv8(F.leaky_relu(x7, 0.2))) # x8: batch * 512 *  1 *  1
        
        d1_ = F.dropout(self.batchnorm1d(self.conv1d(F.relu(x8))), 0.5, training=True) # d1_: batch * 512 *  2 *  2
        d1 = torch.cat((d1_, x7), 1)        # d1:  batch * 1024 *  2 *  2
        d2_ = F.dropout(self.batchnorm2d(self.conv2d(F.relu(d1))), 0.5, training=True) # d2_: batch * 512 *  4 *  4
        d2 = torch.cat((d2_, x6), 1)        # d2:  batch * 1024 *  4 *  4
        d3_ = F.dropout(self.batchnorm3d(self.conv3d(F.relu(d2))), 0.5, training=True) # d3_: batch * 512 *  8 *  8
        d3 = torch.cat((d3_, x5), 1)        # d3:  batch * 1024 *  8 *  8
        d4_ = self.batchnorm4d(self.conv4d(F.relu(d3))) # d4_: batch * 512 *  16 *  16
        d4 = torch.cat((d4_, x4), 1)        # d4:  batch * 1024 *  16 *  16
        d5_ = self.batchnorm5d(self.conv5d(F.relu(d4))) # d5_: batch * 256 *  32 *  32
        d5 = torch.cat((d5_, x3), 1)        # d5:  batch * 512 *  32 *  32
        d6_ = self.batchnorm6d(self.conv6d(F.relu(d5))) # d6_: batch * 128 *  64 *  64
        d6 = torch.cat((d6_, x2), 1)        # d6:  batch * 256 *  64 *  64
        d7_ = self.batchnorm7d(self.conv7d(F.relu(d6))) # d6_: batch * 64 *  128 *  128
        d7 = torch.cat((d7_, x1), 1)        # d7:  batch * 128 *  128 *  128
        d8 = self.conv8d(F.relu(d7))        # d8:  batch * 3   *  256 *  256
        
        out = F.tanh(d8)
        return out
    
netG = generator(opt.input_nc, opt.output_nc, opt.ngf)
netG.apply(weights_init)
if opt.netG != '': # load checkpoint if needed
    netG.load_state_dict(torch.load(opt.netG))
print(netG)
'''
input = torch.randn(1,3,256,256)
input = Variable(input)   
input.cuda()
for each in netG.parameters():
    each.cuda()
res = netG.forward(input)
input.size() , res.size()'''

generator (
  (conv1): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (batchnorm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
  (conv3): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (batchnorm3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
  (conv4): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (batchnorm4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
  (conv5): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (batchnorm5): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
  (conv6): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (batchnorm6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
  (conv7): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)


'\ninput = torch.randn(1,3,256,256)\ninput = Variable(input)   \ninput.cuda()\nfor each in netG.parameters():\n    each.cuda()\nres = netG.forward(input)\ninput.size() , res.size()'

In [19]:
# define discriminator
class disciminator(nn.Module):
    def __init__(self, input_nc, output_nc, ndf):
        super(disciminator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ndf = ndf   
        
        main = nn.Sequential()
        main.add_module('initial.conv', nn.Conv2d(input_nc+output_nc, ndf, 4, stride=2, padding=1, bias=False)) # ndf * 128*128
        main.add_module('initial.relu', nn.LeakyReLU(0.2, inplace=True))
                
        main.add_module('extra1.conv',nn.Conv2d(ndf, ndf, 3, 1, 1, bias=False)) # ndf * 128*128
        main.add_module('extra1.batchnorm', nn.BatchNorm2d(ndf))
        main.add_module('extra1.relu', nn.LeakyReLU(0.2, inplace=True))
        
        main.add_module('pyramid1.conv', nn.Conv2d(ndf, ndf*2, 4, 2, 1, bias=False)) # 2ndf * 64 * 64 
        main.add_module('pyramid1.batchnorm', nn.BatchNorm2d(ndf*2))
        main.add_module('pyramid1.relu', nn.LeakyReLU(0.2, inplace=True))
        
        main.add_module('pyramid2.conv', nn.Conv2d(ndf*2, ndf*4, 4, 2, 1, bias=False)) # 4ndf * 32 * 32 
        main.add_module('pyramid2.batchnorm', nn.BatchNorm2d(ndf*4))
        main.add_module('pyramid2.relu', nn.LeakyReLU(0.2, inplace=True))
        
        main.add_module('pyramid3.conv', nn.Conv2d(ndf*4, ndf*8, 4, 2, 1, bias=False)) # 8ndf * 16 * 16 
        main.add_module('pyramid3.batchnorm', nn.BatchNorm2d(ndf*8))
        main.add_module('pyramid3.relu', nn.LeakyReLU(0.2, inplace=True))
        
        main.add_module('pyramid4.conv', nn.Conv2d(ndf*8, ndf*8, 4, 2, 1, bias=False)) # 8ndf * 8 * 8 
        main.add_module('pyramid4.batchnorm', nn.BatchNorm2d(ndf*8))
        main.add_module('pyramid4.relu', nn.LeakyReLU(0.2, inplace=True))
        
        main.add_module('pyramid5.conv', nn.Conv2d(ndf*8, ndf*8, 4, 2, 1, bias=False)) # 8ndf * 4 * 4 
        main.add_module('pyramid5.batchnorm', nn.BatchNorm2d(ndf*8))
        main.add_module('pyramid5.relu', nn.LeakyReLU(0.2, inplace=True))
        
        main.add_module('final.conv', nn.Conv2d(ndf*8, 1, 4, 1, 0, bias=False)) # 1 * 1 * 1
        self.main = main
        
    def forward(self, x): 
        #output = self.main(x)
        #return output
        output = nn.parallel.data_parallel(self.main, x, [0])
        return output.view(-1, 1).mean()

netD = disciminator(opt.input_nc, opt.output_nc, opt.ndf)
netD.apply(weights_init)
if opt.netD != '': # load checkpoint if needed
    netD.load_state_dict(torch.load(opt.netD))
print(netD)
'''
input = torch.randn(5,6,256,256)
input = Variable(input)
netD = disciminator(opt.input_nc, opt.output_nc, opt.ndf, 3)
netD.cuda()
input.cuda()
res = netD.forward(input)
res.backward(one)
input.size(), res.size()'''

disciminator (
  (main): Sequential (
    (initial.conv): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (initial.relu): LeakyReLU (0.2, inplace)
    (extra1.conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (extra1.batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (extra1.relu): LeakyReLU (0.2, inplace)
    (pyramid1.conv): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (pyramid1.batchnorm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (pyramid1.relu): LeakyReLU (0.2, inplace)
    (pyramid2.conv): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (pyramid2.batchnorm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (pyramid2.relu): LeakyReLU (0.2, inplace)
    (pyramid3.conv): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (pyramid3.batchnorm): BatchNorm2d(512, ep

'\ninput = torch.randn(5,6,256,256)\ninput = Variable(input)\nnetD = disciminator(opt.input_nc, opt.output_nc, opt.ndf, 3)\nnetD.cuda()\ninput.cuda()\nres = netD.forward(input)\nres.backward(one)\ninput.size(), res.size()'

In [20]:
realA = torch.FloatTensor(opt.batch_size, 3, 256, 256)
realB = torch.FloatTensor(opt.batch_size, 3, 256, 256)

one = torch.FloatTensor([1])
mone = one * -1

if opt.cuda:
    netD.cuda()
    netG.cuda()
    realA, realB = realA.cuda(), realB.cuda()
    one, mone = one.cuda(), mone.cuda()

In [21]:
realA = Variable(realA)
realB = Variable(realB)

In [22]:
# setup optimizer
if opt.adam:
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
else:
    optimizerD = optim.RMSprop(netD.parameters(), lr = opt.lr)
    optimizerG = optim.RMSprop(netG.parameters(), lr = opt.lr)

In [24]:
gen_iterations = 0
for epoch in range(opt.niter):
    data_iter = iter(dataloader)
    i = 0
    while i < len(dataloader):
        ############################
        # (1) Update D network
        ############################
        for p in netD.parameters(): # reset requires_grad
            p.requires_grad = True  # they are set to False below in netG update
            
        # train the discriminator Diters times    
        if gen_iterations < 25 or gen_iterations % 500 == 0:
            Diters = 100
        else:
            Diters = opt.Diters
        
        j = 0
        while j < Diters and i < len(dataloader):
            j += 1
            
            # clamp parameters to a cube
            for p in netD.parameters():
                p.data.clamp_(opt.clamp_lower, opt.clamp_upper)
            
            data, _ = data_iter.next()
            _realA, _realB = torch.chunk(data, 2, 3)
            realA.data.resize_(_realA.size()).copy_(_realA)
            realB.data.resize_(_realB.size()).copy_(_realB)
            i += 1
            
            # train with real
            netD.zero_grad()
            inputD = torch.cat((realA,realB), 1)
            errD_real = netD(inputD)
            errD_real.backward(one)
            
            # train with fake
            fakeA = netG(realB)
            inputD = torch.cat((fakeA, realB), 1)
            errD_fake = netD(inputD)
            errD_fake.backward(mone)
            errD = errD_real - errD_fake
            optimizerD.step()
            
        ############################
        # (2) Update G network
        ############################
        for p in netD.parameters():
            p.requires_grad = False # to avoid computation
        netG.zero_grad()
        fakeA = netG(realB)
        inputD = torch.cat((fakeA, realB), 1)
        errG = netD(inputD)
        errG.backward(one)
        optimizerG.step()
        gen_iterations += 1
        
        print('[%d/%d][%d/%d] Loss_D: %f Loss_G: %f Loss_D_real: %f Loss_D_fake %f'
            % (epoch, opt.niter, gen_iterations, len(dataloader),
            errD.data[0], errG.data[0], errD_real.data[0], errD_fake.data[0]))
        
        if gen_iterations % 10 == 0:
            vutils.save_image(_realA, '{0}/real_samples.png'.format(opt.experiment))
            fakeA = netG(realB)
            vutils.save_image(fakeA.data, '{0}/fake_samples_{1}.png'.format(opt.experiment, gen_iterations))
    
    # do checkpointing
    torch.save(netG.state_dict(), '{0}/netG_epoch_{1}.pth'.format(opt.experiment, epoch))
    torch.save(netD.state_dict(), '{0}/netD_epoch_{1}.pth'.format(opt.experiment, epoch))

[0/100][1/606] Loss_D: -0.397517 Loss_G: 0.526464 Loss_D_real: -0.051497 Loss_D_fake 0.346020
[0/100][2/606] Loss_D: -1.321397 Loss_G: 0.844468 Loss_D_real: -0.718435 Loss_D_fake 0.602962
[0/100][3/606] Loss_D: -1.543132 Loss_G: 0.855380 Loss_D_real: -0.737275 Loss_D_fake 0.805857
[0/100][4/606] Loss_D: -1.405402 Loss_G: 0.823460 Loss_D_real: -0.668265 Loss_D_fake 0.737137
[0/100][5/606] Loss_D: -1.519274 Loss_G: 0.851752 Loss_D_real: -0.733670 Loss_D_fake 0.785604
[0/100][6/606] Loss_D: -1.569249 Loss_G: 0.857055 Loss_D_real: -0.756072 Loss_D_fake 0.813177
[0/100][7/606] Loss_D: -1.581297 Loss_G: 0.856457 Loss_D_real: -0.758984 Loss_D_fake 0.822313
[1/100][8/606] Loss_D: -1.499021 Loss_G: 0.839157 Loss_D_real: -0.720316 Loss_D_fake 0.778704
[1/100][9/606] Loss_D: -1.583631 Loss_G: 0.859476 Loss_D_real: -0.760348 Loss_D_fake 0.823283
[1/100][10/606] Loss_D: -1.581945 Loss_G: 0.855775 Loss_D_real: -0.761385 Loss_D_fake 0.820560
[1/100][11/606] Loss_D: -1.589665 Loss_G: 0.859210 Loss_D_r