In [1]:
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from modelAlpha import _netlocalWD,_netWG
import utils

In [2]:
class opt():
    def __init__(self):
        self.pngdataroot = 'dataset/pngset'
        self.dataroot='dataset/animation'
        self.workers=2
        self.batchSize=128 #'input batch size')
        self.imageSize=256 #the height / width of the input image to network')
        self.nz=100
        self.ngf=128
        self.ndf=128 # center image size
        self.nc=4
        self.niter=4000
        self.lr=0.0002
        self.beta1=0.5
        self.cuda=True
        self.ngpu=1
        #self.netG=''
        #self.netD=''
        self.netG='model/netG_streetview_1.pth'
        self.netD='model/netlocalD_1.pth'
        self.outf='.'
        self.manualSeed = 0
        self.nBottleneck=4000  # 'of dim for bottleneck of encoder')
        self.overlapPred = 0 # 'overlapping edges')
        self.nef = 64#'of encoder filters in first conv layer')
        self.wtl2 = 0.998 #'0 means do not use else use with this weight')
        self.wtlD =0.001# means do not use else use with this weight')
        self.jittering=False


In [3]:
opt = opt()

In [4]:
if opt.manualSeed is None:
        opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
if opt.cuda:
    torch.cuda.manual_seed_all(opt.manualSeed)

cudnn.benchmark = True


Random Seed:  0


In [5]:
if torch.cuda.is_available() and not opt.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

transform = transforms.Compose([transforms.Resize(opt.imageSize+64),
                                transforms.RandomCrop(opt.imageSize, padding=0),
                                transforms.ToTensor()])
transform_png = transforms.Compose([transforms.Resize(opt.ndf),
                                    transforms.RandomRotation((-20,20)),
                                    transforms.RandomVerticalFlip(),
                                    transforms.RandomCrop(opt.ndf, padding=0),
                                    transforms.ToTensor()])
dataset = dset.ImageFolder(root=opt.dataroot, transform=transform)
pngdataset = dset.ImageFolder(root=opt.pngdataroot, transform=transform_png)
assert dataset
assert pngdataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))
pngdataloader = torch.utils.data.DataLoader(pngdataset, batch_size=opt.batchSize,
                                          shuffle=True, num_workers=int(opt.workers))

print("data loading has been done")

data loading has been done


In [6]:
ngpu = int(opt.ngpu)
ngf = int(opt.ngf)
ndf = int(opt.ndf)
nc = 3
nef = int(opt.nef)
nBottleneck = int(opt.nBottleneck)
wtl2 = float(opt.wtl2)
overlapL2Weight = 10

print("Setting options has been done")

Setting options has been done


In [7]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


resume_epoch=0

netG = _netWG(opt)
netG.apply(weights_init)
if opt.netG != '':
    netG.load_state_dict(torch.load(opt.netG,map_location=lambda storage, location: storage)['state_dict'])
    resume_epoch = torch.load(opt.netG)['epoch']
print(netG)


netD = _netlocalWD(opt)
netD.apply(weights_init)
if opt.netD != '':
    netD.load_state_dict(torch.load(opt.netD,map_location=lambda storage, location: storage)['state_dict'])
    resume_epoch = torch.load(opt.netD)['epoch']

criterion = nn.BCELoss()
criterionMSE = nn.MSELoss()

input_real = torch.FloatTensor(opt.batchSize, 4, opt.imageSize, opt.imageSize)
input_cropped = torch.FloatTensor(opt.batchSize, 4, opt.imageSize, opt.imageSize)
input_png = torch.FloatTensor(opt.batchSize,1, opt.ndf, opt.ndf)
input_pngReverse = torch.FloatTensor(opt.batchSize,4, opt.ndf, opt.ndf)
label = torch.FloatTensor(opt.batchSize)
real_label = 1
fake_label = 0

image_margin = int((opt.imageSize - opt.ndf)/2)

print(opt.batchSize)
print(opt.imageSize)

_netWG(
  (main): Sequential(
    (0): Conv2d(4, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(0.2, inplace)
    (2): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (4): LeakyReLU(0.2, inplace)
    (5): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (7): LeakyReLU(0.2, inplace)
    (8): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (10): LeakyReLU(0.2, inplace)
    (11): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (12): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
    (13): LeakyReLU(0.2, inplace)
    (14): Conv2d(512, 4000, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (15): BatchNorm2d(4000, eps=

In [8]:
real_center = torch.FloatTensor(opt.batchSize, 4, opt.ndf, opt.ndf)
#real_center = torch.FloatTensor(64, 3, 64,64)

if opt.cuda:
    netD.cuda()
    netG.cuda()
    criterion.cuda()
    criterionMSE.cuda()
    input_real, input_cropped,label = input_real.cuda(),input_cropped.cuda(), label.cuda()
    real_center = real_center.cuda()
    input_png = input_png.cuda()


input_real = Variable(input_real)
input_cropped = Variable(input_cropped)
label = Variable(label)
input_png = Variable(input_png)
input_pngReverse = Variable(input_pngReverse)


real_center = Variable(real_center)

# setup optimizer
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

print("Setup Optimizer has been done")

Setup Optimizer has been done


In [None]:
pngmx = 93
for epoch in range(resume_epoch,opt.niter):      
    # jittering add
    randwf = random.uniform(-1.0,1.0)
    randhf = random.uniform(-1.0,1.0)
    if opt.jittering:
        jitterSizeW = int(opt.imageSize/5*randwf)
        jitterSizeH = int(opt.imageSize/5*randhf)
        print("jittering : W > ",jitterSizeW," H >",jitterSizeH)
    else :
        jitterSizeW = 0
        jitterSizeH = 0

    pngdata = list(enumerate(pngdataloader, 0))[0][1][0]
    for i, data in enumerate(dataloader, 0):
        png_image = pngdata
        #print(png_image.size())
        real_cpu, _ = data
        real_cpuV =  Variable(real_cpu)
        real_center_cpu = real_cpu[:,:,
                                   int(image_margin+jitterSizeW):int(image_margin+opt.ndf+jitterSizeW),
                                   int(image_margin+jitterSizeH):int(image_margin+opt.ndf+jitterSizeH)]
        batch_size = real_cpu.size(0)
        input_real.data.resize_(real_cpu.size()).copy_(real_cpu)
        #input_cropped.data.resize_(real_cpu.size()).copy_(real_cpu)
        #real_center.data.resize_(real_center_cpu.size()).copy_(real_center_cpu)
        
        input_cropped.data.resize_(torch.Size([batch_size, opt.nc, opt.imageSize, opt.imageSize]))
        real_center.data.resize_(torch.Size([batch_size, opt.nc, opt.ndf, opt.ndf]))
        #print("batch_size : ",batch_size , " input_cropped : ",input_cropped.data.size())
        #print("real_cetner", real_center.data[:,0].size(), " real_cetner_cpu", real_center_cpu[:,0].size())
        for j in range(0,batch_size):
            real_center.data[j,0] = real_center_cpu[j,0]
            real_center.data[j,1] = real_center_cpu[j,1]
            real_center.data[j,2] = real_center_cpu[j,2]
            real_center.data[j,3] = 1
            input_cropped.data[j,0] = real_cpuV.data[j,0]
            input_cropped.data[j,1] = real_cpuV.data[j,1]
            input_cropped.data[j,2] = real_cpuV.data[j,2]
            input_cropped.data[j,3] = 1
            input_png.data[j%pngmx,0] = torch.abs(png_image[j%pngmx,0] - 1)
            input_pngReverse.data[j%pngmx,0] = png_image[j%pngmx,0]
            input_pngReverse.data[j%pngmx,1] = png_image[j%pngmx,0]
            input_pngReverse.data[j%pngmx,2] = png_image[j%pngmx,0]
            input_pngReverse.data[j%pngmx,3] = png_image[j%pngmx,0]
            
            input_cropped.data[j,0,
                   int(image_margin+opt.overlapPred+jitterSizeW):int(image_margin+opt.ndf-opt.overlapPred+jitterSizeW),
                   int(image_margin+opt.overlapPred+jitterSizeH):int(image_margin+opt.ndf-opt.overlapPred+jitterSizeH)] = input_png.data[j%pngmx,0] * input_cropped.data[j,0,
                   int(image_margin+opt.overlapPred+jitterSizeW):int(image_margin+opt.ndf-opt.overlapPred+jitterSizeW),
                   int(image_margin+opt.overlapPred+jitterSizeH):int(image_margin+opt.ndf-opt.overlapPred+jitterSizeH)]
            input_cropped.data[j,1,
                   int(image_margin+opt.overlapPred+jitterSizeW):int(image_margin+opt.ndf-opt.overlapPred+jitterSizeW),
                   int(image_margin+opt.overlapPred+jitterSizeH):int(image_margin+opt.ndf-opt.overlapPred+jitterSizeH)] = input_png.data[j%pngmx,0] * input_cropped.data[j,1,
                   int(image_margin+opt.overlapPred+jitterSizeW):int(image_margin+opt.ndf-opt.overlapPred+jitterSizeW),
                   int(image_margin+opt.overlapPred+jitterSizeH):int(image_margin+opt.ndf-opt.overlapPred+jitterSizeH)]
            input_cropped.data[j,2,
                   int(image_margin+opt.overlapPred+jitterSizeW):int(image_margin+opt.ndf-opt.overlapPred+jitterSizeW),
                   int(image_margin+opt.overlapPred+jitterSizeH):int(image_margin+opt.ndf-opt.overlapPred+jitterSizeH)] = input_png.data[j%pngmx,0] * input_cropped.data[j,2,
                   int(image_margin+opt.overlapPred+jitterSizeW):int(image_margin+opt.ndf-opt.overlapPred+jitterSizeW),
                   int(image_margin+opt.overlapPred+jitterSizeH):int(image_margin+opt.ndf-opt.overlapPred+jitterSizeH)]
            input_cropped.data[j,3,
                   int(image_margin+opt.overlapPred+jitterSizeW):int(image_margin+opt.ndf-opt.overlapPred+jitterSizeW),
                   int(image_margin+opt.overlapPred+jitterSizeH):int(image_margin+opt.ndf-opt.overlapPred+jitterSizeH)] = input_png.data[j%pngmx,0]
        
        # train with real
        netD.zero_grad()
        label.data.resize_(batch_size).fill_(real_label)
        
        #print("real_center size :",real_center.size(),", label size:",label.data.size())
        output = netD(real_center)
        #print("output size:",output.size())
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.data.mean()

        # train with fake
        fake = netG(input_cropped)
        label.data.fill_(fake_label)
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.data.mean()
        errD = errD_real + errD_fake
        optimizerD.step()


        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.data.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG_D = criterion(output, label)
        # errG_D.backward(retain_variables=True)

        # errG_l2 = criterionMSE(fake,real_center)
        wtl2Matrix = real_center.clone()
        wtl2Matrix.data.fill_(wtl2*overlapL2Weight)
        
        '''
        for j in range(0,batch_size):
            wtl2Matrix.data[j,:,
                            int(opt.overlapPred):int(opt.ndf - opt.overlapPred),
                            int(opt.overlapPred):int(opt.ndf - opt.overlapPred)] = wtl2 * input_pngReverse.data
        '''

        errG_l2 = (fake-real_center).pow(2)
        errG_l2 = errG_l2 * wtl2Matrix
        errG_l2 = errG_l2.mean()

        errG = (1-wtl2) * errG_D + wtl2 * errG_l2

        errG.backward()

        D_G_z2 = output.data.mean()
        optimizerG.step()
        '''
        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f / %.4f l_D(x): %.4f l_D(G(z)): %.4f'
              % (epoch, opt.niter, i, len(dataloader),
                 errD.data[0], errG_D.data[0],errG_l2.data[0], D_x,D_G_z1, ))
        '''
        if i % 50 == 0:
            vutils.save_image(real_cpu,'result/train/real/real_samples_epoch_%03d.png' % (epoch))            
            tmpsave = torch.FloatTensor(batch_size, 3, opt.imageSize, opt.imageSize)
            tmpsave[:,0] = input_cropped.data[:,0]
            tmpsave[:,1] = input_cropped.data[:,1]
            tmpsave[:,2] = input_cropped.data[:,2]
            vutils.save_image(tmpsave,'result/train/cropped/cropped_samples_epoch_%03d.png' % (epoch))
            '''
            tmpsave[:,0] = input_cropped.data[:,0]
            tmpsave[:,1] = input_cropped.data[:,1]
            tmpsave[:,2] = input_cropped.data[:,2]
            tmpfake = torch.FloatTensor(batch_size, 3, opt.ndf, opt.ndf)
            tmpReverse = torch.FloatTensor(3, opt.ndf, opt.ndf)
            tmpfake[:,0]= fake.data[:,0]
            tmpfake[:,1]= fake.data[:,1]
            tmpfake[:,2]= fake.data[:,2]
            tmpReverse[0] = input_pngReverse.data[0]
            #print(tmpfake.size(), " ",fake.size())
            
            #print(fake.data[0,2].size(), input_pngReverse.data[2].size())
            #print((fake.data[0,2]).size())
            tmpsave[:,0,
                 int(opt.imageSize/4+jitterSizeW):int(opt.imageSize/4+opt.imageSize/2+jitterSizeW),
                 int(opt.imageSize/4+jitterSizeH):int(opt.imageSize/4+opt.imageSize/2+jitterSizeH)] += tmpfake[:,0] * tmpReverse[0]
            tmpsave[:,1,
                 int(opt.imageSize/4+jitterSizeW):int(opt.imageSize/4+opt.imageSize/2+jitterSizeW),
                 int(opt.imageSize/4+jitterSizeH):int(opt.imageSize/4+opt.imageSize/2+jitterSizeH)] += tmpfake[:,1] * tmpReverse[0]
            tmpsave[:,2,
                 int(opt.imageSize/4+jitterSizeW):int(opt.imageSize/4+opt.imageSize/2+jitterSizeW),
                 int(opt.imageSize/4+jitterSizeH):int(opt.imageSize/4+opt.imageSize/2+jitterSizeH)] += tmpfake[:,2] * tmpReverse[0]  
            #vutils.save_image(tmpsave,'result/train/masked/masked_center_samples_epoch_%03d.png' % (epoch))
            '''
            tmpsave[:,0,
                 int(opt.imageSize/4+jitterSizeW):int(opt.imageSize/4+opt.imageSize/2+jitterSizeW),
                 int(opt.imageSize/4+jitterSizeH):int(opt.imageSize/4+opt.imageSize/2+jitterSizeH)] = fake.data[:,0]
            tmpsave[:,1,
                 int(opt.imageSize/4+jitterSizeW):int(opt.imageSize/4+opt.imageSize/2+jitterSizeW),
                 int(opt.imageSize/4+jitterSizeH):int(opt.imageSize/4+opt.imageSize/2+jitterSizeH)] = fake.data[:,1]
            tmpsave[:,2,
                 int(opt.imageSize/4+jitterSizeW):int(opt.imageSize/4+opt.imageSize/2+jitterSizeW),
                 int(opt.imageSize/4+jitterSizeH):int(opt.imageSize/4+opt.imageSize/2+jitterSizeH)] = fake.data[:,2]
            vutils.save_image(tmpsave,'result/train/recon/recon_center_samples_epoch_%03d.png' % (epoch))

    # do checkpointing
    print('%d/%d' % (epoch, opt.niter))
    torch.save({'epoch':epoch+1,
                'state_dict':netG.state_dict()},
                'model/netG_streetview_{0}.pth'.format(epoch%2) )
    torch.save({'epoch':epoch+1,
                'state_dict':netD.state_dict()},
                'model/netlocalD_{0}.pth'.format(epoch%2))

  "Please ensure they have the same size.".format(target.size(), input.size()))
  "Please ensure they have the same size.".format(target.size(), input.size()))


3002/4000
3003/4000
3004/4000
3005/4000
3006/4000
3007/4000
3008/4000
3009/4000
3010/4000
3011/4000
3012/4000
3013/4000
3014/4000
3015/4000
3016/4000
3017/4000
3018/4000
3019/4000
3020/4000
3021/4000
3022/4000
3023/4000
3024/4000
3025/4000
3026/4000
3027/4000
3028/4000
3029/4000
3030/4000
3031/4000
3032/4000
3033/4000
3034/4000
3035/4000
3036/4000
3037/4000
3038/4000
3039/4000
3040/4000
3041/4000
3042/4000
3043/4000
3044/4000
3045/4000
3046/4000
3047/4000
3048/4000
3049/4000
3050/4000
3051/4000
