In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np

from utils.loaders import DataLoader

In [None]:
from models.CycleGenerativeAdversarialNetwork import CycleGenerativeAdversarialNetwork

In [None]:
# Scratch check: (12,) + (4, 4, 1) concatenates to (12, 4, 4, 1), so this
# displays a zero-filled array with that 4-D shape.
np.zeros((12,) + (4, 4, 1))

In [None]:
# Scratch check: tuple concatenation, evaluates to (12, 4, 4, 1).
(12,) + (4, 4, 1)

---

## Setup Run

In [None]:
# Run parameters: identify this experiment and where its artefacts are stored.
SECTION = 'paint'
RUN_ID = '0001'
DATA_NAME = 'apple2orange'

# Artefacts live under run/<SECTION>/<RUN_ID>_<DATA_NAME>.
RUN_FOLDER = 'run/{}/'.format(SECTION)
RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])

# os.makedirs creates any missing parent (a fresh checkout may have no 'run/'
# directory, on which a plain os.mkdir raises FileNotFoundError), and
# exist_ok=True keeps the cell re-runnable while also repairing a run folder
# that exists but is missing one of its subfolders.
for subfolder in ('viz', 'images', 'weights'):
    os.makedirs(os.path.join(RUN_FOLDER, subfolder), exist_ok=True)

# 'build' makes the later cell call gan.save(RUN_FOLDER); any other value makes
# it load weights from RUN_FOLDER/weights/adversarial_model.h5 instead.
mode = 'build'

---
## Data

In [None]:
# Side length shared by the data loader (img_res) and the model (the first two
# entries of input_dim), so both work on square IMAGE_SIZE x IMAGE_SIZE inputs.
IMAGE_SIZE = 128

In [None]:
# Loader for the DATA_NAME dataset, configured with a square
# (IMAGE_SIZE, IMAGE_SIZE) image resolution.
data_loader = DataLoader(
    dataset_name=DATA_NAME,
    img_res=(IMAGE_SIZE,) * 2,
)

---
## Model architecture

In [None]:
# Instantiate the CycleGAN wrapper. The lambda_* values are the weights given to
# the discriminator, reconstruction and identity terms (presumably of the
# translator loss -- confirm in the model class).
gan = CycleGenerativeAdversarialNetwork(
    # Square input; the trailing 3 is presumably the RGB channel count.
    input_dim=(IMAGE_SIZE, IMAGE_SIZE, 3),
    learning_rate=2e-4,
    lambda_discriminator=1.0,
    lambda_reconstruction=10.0,
    lambda_identity=2.0,
    translator_model_type='unet',
    translator_first_layer_filters=32,
    discriminator_first_layer_filters=32,
    # odd choice, but this is what the book used.
    discriminator_loss='mse',
)


In [None]:
# Show the summary() report of translator_BA (the name suggests the B -> A
# translator -- confirm in the model class).
gan.translator_BA.summary()

In [None]:
# Show the summary() report of translator_AB (the name suggests the A -> B
# translator -- confirm in the model class).
gan.translator_AB.summary()

In [None]:
# Show the summary() report of discriminator_A (presumably the discriminator
# for domain-A images -- confirm in the model class).
gan.discriminator_A.summary()

In [None]:
# Show the summary() report of discriminator_B (presumably the discriminator
# for domain-B images -- confirm in the model class).
gan.discriminator_B.summary()

In [None]:
# In 'build' mode call gan.save on the run folder; for every other mode,
# restore weights from the run folder's weights file instead.
if mode != 'build':
    gan.load_weights(
        os.path.join(RUN_FOLDER, 'weights/adversarial_model.h5')
    )
else:
    gan.save(RUN_FOLDER)

## Train

In [None]:
# Image file names handed to gan.train as test_A_file / test_B_file.
TEST_A_FILE = 'n07740461_14740.jpg'
TEST_B_FILE = 'n07749192_4241.jpg'

# Training schedule, passed to gan.train as epochs / batch_size / sample_interval.
EPOCHS = 200
BATCH_SIZE = 1
PRINT_EVERY_N_BATCHES = 100

In [None]:
# Start training with the configured schedule; RUN_FOLDER and the two test
# image file names are passed through to gan.train unchanged.
gan.train(
    data_loader,
    run_folder=RUN_FOLDER,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    sample_interval=PRINT_EVERY_N_BATCHES,
    test_A_file=TEST_A_FILE,
    test_B_file=TEST_B_FILE,
)

## Loss

In [None]:
# Plot the per-batch values stored in gan.translator_losses.
# Each record is indexed positionally: the meaning of indices 1, 3 and 5
# follows the labels this cell has always used (discriminator / cycle / ID).
# Index 0 is presumably the overall combined loss -- TODO confirm in gan.train.
# Rows are (record index, legend label, colour, line width), drawn in order so
# the black index-0 curve ends up on top, as before.
loss_series = [
    (1, 'discriminator loss', 'green', 0.1),
    (3, 'cycle loss', 'blue', 0.1),
    (5, 'identity loss', 'red', 0.25),
    (0, 'total loss', 'black', 0.25),
]

fig, ax = plt.subplots(figsize=(20, 10))

for index, label, color, width in loss_series:
    ax.plot(
        [record[index] for record in gan.translator_losses],
        color=color,
        linewidth=width,
        label=label,
    )

ax.set_title('Translator losses during training', fontsize=18)
ax.set_xlabel('batch', fontsize=18)
ax.set_ylabel('loss', fontsize=16)

# Clip the y-axis so early spikes do not flatten the rest of the curves.
ax.set_ylim(0, 5)

# The curves are drawn very thin, so thicken the legend handles to keep the
# colour key readable.
legend = ax.legend(loc='upper right', fontsize=14)
for handle in legend.get_lines():
    handle.set_linewidth(2)

plt.show()