In [1]:
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
# Configure the TF1 session that Keras will use on the GPU:
# - per_process_gpu_memory_fraction caps this process at a fraction of GPU memory;
# - allow_growth makes TF allocate memory as needed instead of all up front.
GPU_MEMORY_FRACTION = 0.4

config = tf.ConfigProto()
gpu_options = config.gpu_options
gpu_options.allow_growth = True
gpu_options.per_process_gpu_memory_fraction = GPU_MEMORY_FRACTION
set_session(tf.Session(config=config))

Using TensorFlow backend.


In [2]:
import os
import matplotlib.pyplot as plt
import numpy as np

from utils.loaders import DataLoader

In [3]:
from models.CycleGenerativeAdversarialNetwork import CycleGenerativeAdversarialNetwork

---

## Setup run

In [4]:
# Run parameters. All artefacts of a run live under
# run/<SECTION>/<RUN_ID>_<DATA_NAME>, with viz/, images/ and weights/ sub-folders.
SECTION = 'paint'
RUN_ID = '0002'
DATA_NAME = 'monet2photo'
RUN_FOLDER = 'run/{}/'.format(SECTION) + '_'.join([RUN_ID, DATA_NAME])

# os.makedirs creates any missing parents (e.g. 'run/' on a fresh checkout,
# where os.mkdir would raise FileNotFoundError), and exist_ok=True makes this
# cell safe to re-run. Each sub-folder is ensured individually so a partially
# created run folder is repaired instead of being left without e.g. weights/.
for subfolder in ('viz', 'images', 'weights'):
    os.makedirs(os.path.join(RUN_FOLDER, subfolder), exist_ok=True)

# 'build' -> save the freshly constructed model into RUN_FOLDER;
# 'load'  -> load weights from RUN_FOLDER/weights/adversarial_model.h5.
mode = 'build'

---
## Data

In [5]:
# Height and width of the square images, shared by the data loader (img_res)
# and the model (input_dim).
IMAGE_SIZE = 2 ** 8

In [6]:
# Loader for the DATA_NAME dataset at an IMAGE_SIZE x IMAGE_SIZE resolution.
loader_img_res = (IMAGE_SIZE, IMAGE_SIZE)
data_loader = DataLoader(img_res=loader_img_res, dataset_name=DATA_NAME)

---
## Model architecture

In [7]:
# CycleGAN hyper-parameters, gathered in one dict so they can be inspected or
# logged together before the model is built.
gan_hyperparams = dict(
    input_dim=(IMAGE_SIZE, IMAGE_SIZE, 3),  # square, 3-channel inputs
    learning_rate=0.0002,
    # Presumably the weights of the adversarial, cycle-reconstruction and
    # identity loss terms (by name) -- confirm in CycleGenerativeAdversarialNetwork.
    lambda_discriminator=1.,
    lambda_reconstruction=10.,
    lambda_identity=5.,
    translator_model_type='resnet',
    translator_first_layer_filters=32,
    discriminator_first_layer_filters=64,
    # 'mse' gives a least-squares adversarial loss. This follows the book, and
    # the CycleGAN paper also trains with a least-squares objective.
    discriminator_loss='mse',
)
gan = CycleGenerativeAdversarialNetwork(**gan_hyperparams)


W0617 23:00:46.018660 140691014612736 deprecation.py:506] From /home/comadan/.venv/gdl/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.


In [8]:
# Print the Keras layer-by-layer summary of translator_BA
# (by naming convention presumably the B -> A translator; confirm in the model class).
gan.translator_BA.summary()

Model: "model_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            (None, 256, 256, 3)  0                                            
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 256, 256, 32) 4736        input_3[0][0]                    
__________________________________________________________________________________________________
instance_normalization_7 (Insta (None, 256, 256, 32) 0           conv2d_11[0][0]                  
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 256, 256, 32) 0           instance_normalization_7[0][0]   
____________________________________________________________________________________________

In [9]:
# Print the Keras layer-by-layer summary of translator_AB
# (by naming convention presumably the A -> B translator; confirm in the model class).
gan.translator_AB.summary()

Model: "model_4"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            (None, 256, 256, 3)  0                                            
__________________________________________________________________________________________________
conv2d_35 (Conv2D)              (None, 256, 256, 32) 4736        input_4[0][0]                    
__________________________________________________________________________________________________
instance_normalization_30 (Inst (None, 256, 256, 32) 0           conv2d_35[0][0]                  
__________________________________________________________________________________________________
activation_16 (Activation)      (None, 256, 256, 32) 0           instance_normalization_30[0][0]  
____________________________________________________________________________________________

In [10]:
# Print the Keras layer-by-layer summary of discriminator_A
# (presumably the discriminator for domain-A images; confirm in the model class).
gan.discriminator_A.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 256, 256, 3)       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 128, 128, 64)      3136      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 128, 128, 64)      0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 64, 64, 128)       131200    
_________________________________________________________________
instance_normalization_1 (In (None, 64, 64, 128)       0         
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 64, 64, 128)       0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 32, 32, 256)       5245

In [11]:
# Print the Keras layer-by-layer summary of discriminator_B
# (presumably the discriminator for domain-B images; confirm in the model class).
gan.discriminator_B.summary()

Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         (None, 256, 256, 3)       0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 128, 128, 64)      3136      
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 128, 128, 64)      0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 64, 64, 128)       131200    
_________________________________________________________________
instance_normalization_4 (In (None, 64, 64, 128)       0         
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 64, 64, 128)       0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 32, 32, 256)       5245

In [12]:
# Print the Keras summary of adversarial_model. Its output shows the translator
# sub-models nested as layers; presumably this is the combined model used to
# train the translators (confirm in the model class).
gan.adversarial_model.summary()

Model: "model_5"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_6 (InputLayer)            (None, 256, 256, 3)  0                                            
__________________________________________________________________________________________________
input_5 (InputLayer)            (None, 256, 256, 3)  0                                            
__________________________________________________________________________________________________
model_3 (Model)                 (None, 256, 256, 3)  2850563     input_6[0][0]                    
                                                                 model_4[1][0]                    
                                                                 input_5[0][0]                    
____________________________________________________________________________________________

  'Discrepancy between trainable weights and collected trainable'


In [13]:
# Either save the newly built model into RUN_FOLDER ('build') or restore
# weights from RUN_FOLDER/weights/adversarial_model.h5 ('load'). Unknown values
# are rejected so a typo in `mode` cannot silently fall through to loading.
if mode == 'build':
    gan.save(RUN_FOLDER)
elif mode == 'load':
    gan.load_weights(os.path.join(RUN_FOLDER, 'weights', 'adversarial_model.h5'))
else:
    raise ValueError("mode must be 'build' or 'load', got {!r}".format(mode))

## Train

In [None]:
# Training schedule.
EPOCHS = 200
BATCH_SIZE = 1
# Forwarded to gan.train as sample_interval (counted in batches).
PRINT_EVERY_N_BATCHES = 100

# Fixed test image file names, one per domain, forwarded to gan.train.
TEST_A_FILE = '00010.jpg'
TEST_B_FILE = '2014-08-04 20:20:12.jpg'

In [None]:
# Train for EPOCHS epochs on batches of BATCH_SIZE drawn from data_loader.
# sample_interval and the two test files are passed straight through;
# presumably they drive periodic sample output under RUN_FOLDER (confirm in the model class).
gan.train(
    data_loader,
    run_folder=RUN_FOLDER,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    sample_interval=PRINT_EVERY_N_BATCHES,
    test_A_file=TEST_A_FILE,
    test_B_file=TEST_B_FILE,
)

## Loss

In [None]:
# Plot the translator loss history recorded in gan.translator_losses.
# NOTE(review): indices 1/3/5 appear to pick the first of paired A/B
# adversarial / cycle / identity terms, and index 0 the total. Confirm the
# record layout against the model's training step before trusting the labels.
loss_series = [
    # (index into each loss record, legend label, colour, line width)
    (1, 'discriminator (adversarial) loss', 'green', 0.1),
    (3, 'cycle (reconstruction) loss', 'blue', 0.1),
    (5, 'identity loss', 'red', 0.25),
    (0, 'total loss', 'black', 0.25),
]

fig, ax = plt.subplots(figsize=(20, 10))
for idx, label, color, width in loss_series:
    ax.plot([record[idx] for record in gan.translator_losses],
            color=color, linewidth=width, label=label)

ax.set_xlabel('batch', fontsize=18)
ax.set_ylabel('loss', fontsize=16)
ax.set_title('Translator losses during training', fontsize=18)
# Fix the y-range so occasional large values do not squash the rest of the curve.
ax.set_ylim(0, 5)

legend = ax.legend()
# The plotted lines are very thin; thicken only the legend swatches so they are visible.
for legend_line in legend.get_lines():
    legend_line.set_linewidth(2)

plt.show()