# CycleGAN train

In [1]:
import os
import matplotlib.pyplot as plt

from models.cycleGAN import CycleGAN
from utils.loaders import DataLoader

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [12]:

# run params
SECTION = 'paint'
RUN_ID = '0001'
DATA_NAME = 'apple2orange'
RUN_FOLDER = 'run/{}/'.format(SECTION)
RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])

if not os.path.exists(RUN_FOLDER):
    os.mkdir(RUN_FOLDER)
    os.mkdir(os.path.join(RUN_FOLDER, 'viz'))
    os.mkdir(os.path.join(RUN_FOLDER, 'images'))
    os.mkdir(os.path.join(RUN_FOLDER, 'weights'))

mode =  'load' # 'build' # 

# data

In [13]:
IMAGE_SIZE = 128

In [14]:

data_loader = DataLoader(dataset_name=DATA_NAME, img_res=(IMAGE_SIZE, IMAGE_SIZE))


# architecture

In [15]:

if mode == 'build':

    gan = CycleGAN(
        input_dim = (IMAGE_SIZE,IMAGE_SIZE,3)
        ,learning_rate = 0.0002
        , buffer_max_length = 50
        , lambda_validation = 1
        , lambda_reconstr = 10
        , lambda_id = 2
        , generator_type = 'unet'
        , gen_n_filters = 32
        , disc_n_filters = 32
        )

    gan.save(RUN_FOLDER)

else:
    gan.load_weights(os.path.join(RUN_FOLDER, 'weights/weights.h5'))
    


OSError: Unable to open file (unable to open file: name = 'run/paint/0001_apple2orange/weights/weights.h5', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)

In [6]:
gan.g_BA.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            (None, 128, 128, 3)  0                                            
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, 64, 64, 32)   1568        input_4[0][0]                    
__________________________________________________________________________________________________
instance_normalization_14 (Inst (None, 64, 64, 32)   0           conv2d_19[0][0]                  
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 64, 64, 32)   0           instance_normalization_14[0][0]  
__________________________________________________________________________________________________
conv2d_20 

In [7]:
gan.g_AB.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            (None, 128, 128, 3)  0                                            
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 64, 64, 32)   1568        input_3[0][0]                    
__________________________________________________________________________________________________
instance_normalization_7 (Insta (None, 64, 64, 32)   0           conv2d_11[0][0]                  
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 64, 64, 32)   0           instance_normalization_7[0][0]   
__________________________________________________________________________________________________
conv2d_12 

In [8]:
gan.d_A.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 128, 128, 3)       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 64, 64, 32)        1568      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 64, 64, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 32, 32, 64)        32832     
_________________________________________________________________
instance_normalization_1 (In (None, 32, 32, 64)        0         
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 32, 32, 64)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 16, 16, 128)       131200    
__________

In [9]:
gan.d_B.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         (None, 128, 128, 3)       0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 64, 64, 32)        1568      
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 64, 64, 32)        0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 32, 32, 64)        32832     
_________________________________________________________________
instance_normalization_4 (In (None, 32, 32, 64)        0         
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 32, 32, 64)        0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 16, 16, 128)       131200    
__________

# train

In [10]:
BATCH_SIZE = 1
EPOCHS = 200
PRINT_EVERY_N_BATCHES = 10

TEST_A_FILE = 'n07740461_14740.jpg'
TEST_B_FILE = 'n07749192_4241.jpg'

In [11]:
gan.train(data_loader
        , run_folder = RUN_FOLDER
        , epochs=EPOCHS
        , test_A_file = TEST_A_FILE
        , test_B_file = TEST_B_FILE
        , batch_size=BATCH_SIZE
        , sample_interval=PRINT_EVERY_N_BATCHES)
        

  if issubdtype(ts, int):
  elif issubdtype(type(size), float):
  'Discrepancy between trainable weights and collected trainable'


[Epoch 0/200] [Batch 0/995] [D loss: 1.585021, acc:  28%] [G loss: 21.178360, adv: 2.265190, recon: 1.566643, id: 1.623370] time: 0:00:16.474303 
[Epoch 0/200] [Batch 1/995] [D loss: 1.155459, acc:  39%] [G loss: 19.084061, adv: 2.154721, recon: 1.425696, id: 1.336189] time: 0:00:25.547761 
[Epoch 0/200] [Batch 2/995] [D loss: 0.857845, acc:  44%] [G loss: 15.271852, adv: 1.815620, recon: 1.095064, id: 1.252795] time: 0:00:26.692618 
[Epoch 0/200] [Batch 3/995] [D loss: 0.727391, acc:  41%] [G loss: 15.037607, adv: 1.497770, recon: 1.134178, id: 1.099030] time: 0:00:27.969016 
[Epoch 0/200] [Batch 4/995] [D loss: 0.618439, acc:  42%] [G loss: 15.101970, adv: 1.553823, recon: 1.108093, id: 1.233609] time: 0:00:29.177382 
[Epoch 0/200] [Batch 5/995] [D loss: 0.715419, acc:  45%] [G loss: 12.531412, adv: 1.287579, recon: 0.925545, id: 0.994191] time: 0:00:30.414578 
[Epoch 0/200] [Batch 6/995] [D loss: 0.611518, acc:  43%] [G loss: 15.110529, adv: 1.355310, recon: 1.142199, id: 1.166613] 

KeyboardInterrupt: 

# loss

In [None]:
fig = plt.figure(figsize=(20,10))

plt.plot([x[1] for x in gan.g_losses], color='green', linewidth=0.1) #DISCRIM LOSS
# plt.plot([x[2] for x in gan.g_losses], color='orange', linewidth=0.1)
plt.plot([x[3] for x in gan.g_losses], color='blue', linewidth=0.1) #CYCLE LOSS
# plt.plot([x[4] for x in gan.g_losses], color='orange', linewidth=0.25)
plt.plot([x[5] for x in gan.g_losses], color='red', linewidth=0.25) #ID LOSS
# plt.plot([x[6] for x in gan.g_losses], color='orange', linewidth=0.25)

plt.plot([x[0] for x in gan.g_losses], color='black', linewidth=0.25)

# plt.plot([x[0] for x in gan.d_losses], color='black', linewidth=0.25)

plt.xlabel('batch', fontsize=18)
plt.ylabel('loss', fontsize=16)

plt.ylim(0, 5)

plt.show()