# Pix2Pix
https://arxiv.org/pdf/1611.07004v1.pdf
* Download base dataset from http://cmp.felk.cvut.cz/%7Etylecr1/facade/, unzip and put them to ../datasets/facade/

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from progressbar import ProgressBar
import time, os
from sklearn.utils import shuffle

In [2]:
from lasagne.layers import InputLayer, DenseLayer, batch_norm, Conv2DLayer, concat, Deconv2DLayer, dropout
from lasagne.init import HeUniform
from lasagne.nonlinearities import rectify, sigmoid, leaky_rectify, elu, tanh
from lasagne.updates import adam

 https://github.com/Theano/Theano/wiki/Converting-to-the-new-gpu-back-end%28gpuarray%29

Using gpu device 0: GeForce GTX TITAN X (CNMeM is disabled, cuDNN 5005)


In [3]:
from Tars.models import GAN
from Tars.distributions import Bernoulli, Deterministic
from Tars.load_data import facade



In [4]:
load,plot = facade('../datasets/')
train_x,train_y,test_x,test_y = load(test=True)

seed = 1234
np.random.seed(seed)

n_epoch = 100
n_batch = 1

optimizer = adam
optimizer_params={"learning_rate":2e-4, "beta1":0.5}

In [5]:
def CD_layer(input_layer, num_filters, filter_size, enc_dec="enc", concat_layer=None, nonlinearity=rectify, W=HeUniform(gain="relu"), dropout_layer=False):
    if enc_dec=="enc":
        output_layer = batch_norm(Conv2DLayer(input_layer, num_filters, filter_size, nonlinearity=nonlinearity, stride=2, pad=1, W=W))
    elif enc_dec=="dec":
        if concat_layer is not None:
            input_layer=concat([input_layer,concat_layer])
        output_layer = batch_norm(Deconv2DLayer(input_layer, num_filters, filter_size, nonlinearity=nonlinearity, stride=2, crop=1, W=W))

    if dropout_layer is True:
        output_layer = dropout(output_layer)
                                  
    return output_layer

In [6]:
x = InputLayer((None, 3, 256, 256)) # Target image
y = InputLayer((None, 12, 256, 256)) # Source image
z = InputLayer((None, 10, 2, 2)) # random noise (diffrent from the original paper)

# generator
enc_0 = Conv2DLayer(y, 64, 3, nonlinearity=leaky_rectify, pad=1, W=HeUniform(gain="relu"))
enc_1 = CD_layer(enc_0, 128, 4, nonlinearity=leaky_rectify, enc_dec="enc")
enc_2 = CD_layer(enc_1, 256, 4, nonlinearity=leaky_rectify, enc_dec="enc")
enc_3 = CD_layer(enc_2, 512, 4, nonlinearity=leaky_rectify, enc_dec="enc")
enc_4 = CD_layer(enc_3, 512, 4, nonlinearity=leaky_rectify, enc_dec="enc")
enc_5 = CD_layer(enc_4, 512, 4, nonlinearity=leaky_rectify, enc_dec="enc")
enc_6 = CD_layer(enc_5, 512, 4, nonlinearity=leaky_rectify, enc_dec="enc")
enc_7 = CD_layer(enc_6, 512, 4, nonlinearity=leaky_rectify, enc_dec="enc")

_dec_0 = concat([enc_7,z])
dec_0 = CD_layer(_dec_0, 512, 4, nonlinearity=rectify, enc_dec="dec")
dec_1 = CD_layer(dec_0, 512, 4, concat_layer=enc_6, nonlinearity=rectify, enc_dec="dec")
dec_2 = CD_layer(dec_1, 512, 4, concat_layer=enc_5, nonlinearity=rectify, enc_dec="dec")
dec_3 = CD_layer(dec_2, 512, 4, concat_layer=enc_4, nonlinearity=rectify, enc_dec="dec")
dec_4 = CD_layer(dec_3, 256, 4, concat_layer=enc_3, nonlinearity=rectify, enc_dec="dec")
dec_5 = CD_layer(dec_4, 128, 4, concat_layer=enc_2, nonlinearity=rectify, enc_dec="dec")
dec_6 = CD_layer(dec_5, 64, 4, concat_layer=enc_1, nonlinearity=rectify, enc_dec="dec")
dec_7 = Deconv2DLayer(dec_6, 3, 3, nonlinearity=tanh, crop=1)
p = Deterministic(dec_7, given=[z,y]) #p(x|z,y)
                                
# discriminator
dis_0_0 = CD_layer(x, 32, 4, nonlinearity=leaky_rectify, enc_dec="enc")
dis_0_1 = CD_layer(y, 32, 4, nonlinearity=leaky_rectify, enc_dec="enc")
dis_0 = concat([dis_0_0, dis_0_1])
dis_1 = CD_layer(x, 128, 4, nonlinearity=leaky_rectify, enc_dec="enc")
dis_2 = CD_layer(x, 256, 4, nonlinearity=leaky_rectify, enc_dec="enc")
dis_3 = CD_layer(x, 512, 4, nonlinearity=leaky_rectify, enc_dec="enc")
dis_4 = Conv2DLayer(x, 1, 3, nonlinearity=sigmoid, pad=1)

d = Bernoulli(dis_4,given=[x,y]) #d(t|x,y)

In [7]:
model = GAN(p, d, n_batch,
            optimizer, optimizer,
            optimizer_params, optimizer_params,
            l1_lambda=100)

In [None]:
sample_z  = np.random.standard_normal((1, 10, 2, 2)).astype(np.float32)

def plot_image(t,i,sample_id=0):
    sample_x = p.np_sample_mean_given_x(sample_z, test_y[sample_id][np.newaxis])
    X = plot(sample_x)
    plt.imshow(X[0])
    plt.savefig('../plot/%d/%04d_%02d_generate.jpg'%(t,i,sample_id))
    
    X = plot(test_x[sample_id][np.newaxis])
    plt.imshow(X[0])
    plt.savefig('../plot/%d/%04d_%02d_img.jpg'%(t,i,sample_id))

    X = plot(test_y[sample_id][np.newaxis])
    plt.imshow(X[0])
    plt.savefig('../plot/%d/%04d_%02d_label.jpg'%(t,i,sample_id))    
    
    plt.close()

In [None]:
t = int(time.time())
os.mkdir('../plot/%d' % t)

model.set_seed(seed)
pbar = ProgressBar(maxval=n_epoch).start()
for i in range(1, n_epoch+1):
    train_x, train_y = shuffle(train_x, train_y)
    loss_train = model.train([train_x,train_y])
    if (i%10 == 0) or (i == 1):
        loss_test = model.gan_test([test_x, test_y])
        lw = "epoch = %d, loss (train) = %lf %lf loss (test) = %lf %lf\n" %(i,loss_train[0],loss_train[1],loss_test[0],loss_test[1])
        f = open("../plot/%d/temp.txt" % t, "a")
        f.write(lw)
        f.close()
        print lw
        plot_image(t,i,0)
        
    pbar.update(i)

  0% (  0 of 100) |                                                                              | Elapsed Time: 0:00:00 ETA:  --:--:--

epoch = 1, loss (train) = 2248154.250000 93641.296875 loss (test) = 48610.371094 90600.414062



  9% (  9 of 100) |#######                                                                         | Elapsed Time: 0:07:05 ETA: 1:11:45

epoch = 10, loss (train) = 1753739.625000 91018.023438 loss (test) = 45887.765625 88833.812500



 19% ( 19 of 100) |###############                                                                 | Elapsed Time: 0:14:58 ETA: 1:03:29

epoch = 20, loss (train) = 1307852.000000 90961.648438 loss (test) = 45229.792969 90073.757812



 20% ( 20 of 100) |################                                                                | Elapsed Time: 0:15:47 ETA: 1:03:03