<a href="https://colab.research.google.com/github/ghaiszaher/Foggy-CycleGAN/blob/master/Foggy_CycleGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CycleFoggyGAN

## Set up the environment

In [None]:
import sys
# Detect whether the notebook is running inside Google Colab; later cells
# branch on this flag (project cloning, batch size, Drive mounting, log paths).
colab = 'google.colab' in sys.modules
if colab:
    # Ask Colab to select TensorFlow 2.x. This is deliberately best-effort: any
    # error from the magic is ignored and whatever TensorFlow is installed gets
    # imported below.
    # noinspection PyBroadException
    try:
        %tensorflow_version 2.x
    except Exception:
        pass
import tensorflow as tf

In [None]:
# Show which TensorFlow version was actually imported (the notebook targets 2.x).
# noinspection PyUnresolvedReferences
print(tf.__version__)

In [None]:
import os

import tensorflow_datasets as tfds
from IPython.display import clear_output

# Disable TFDS progress bars to keep notebook output readable.
tfds.disable_progress_bar()
# Sentinel that lets tf.data tune parallelism/prefetch buffers at runtime.
AUTOTUNE = tf.data.experimental.AUTOTUNE

In [None]:
if colab:
    os.environ['PROJECT_DIR'] = project_dir = '/content/Foggy-CycleGAN'
    if os.path.isdir(project_dir):
        input("Project already exists... files in existing folder\n"+
              "({}) will be deleted.\nEnter anything to continue\n"
              .format(project_dir))
        !rm -r $PROJECT_DIR
    import getpass
    os.environ['GITHUB_USER'] = input("Enter Username: ")
    os.environ['GITHUB_PASSWORD'] = getpass.getpass("Enter Password: ")
    !cd /content && git clone https://$GITHUB_USER:$GITHUB_PASSWORD@github.com/ghaiszaher/Foggy-CycleGAN.git
    os.chdir(project_dir)
    print("Done. Project cloned to " + project_dir)

## Input Pipeline

In [None]:
# Parameters handed to DatasetInitializer.prepare_dataset (BUFFER_SIZE is
# presumably the shuffle buffer — confirm in lib.dataset).
BUFFER_SIZE = 1000
# Colab runs use a larger batch than local runs.
if colab:
    BATCH_SIZE = 5
else:
    BATCH_SIZE = 1
# Image size in pixels (presumably the size images are brought to by the pipeline).
IMG_WIDTH = IMG_HEIGHT = 256

In [None]:
from lib.dataset import DatasetInitializer
# Use the image-size constants instead of repeating the literal 256, so that
# editing IMG_WIDTH / IMG_HEIGHT actually changes the pipeline.
# NOTE(review): argument order assumed to be (width, height) — confirm against
# lib.dataset.DatasetInitializer (both are 256 today, so behavior is unchanged).
datasetInit = DatasetInitializer(IMG_WIDTH, IMG_HEIGHT)
# Train/test datasets for both domains: clear images and foggy images.
(train_clear, train_fog), (test_clear, test_fog) = datasetInit.prepare_dataset(
    BUFFER_SIZE, BATCH_SIZE, AUTOTUNE)

In [None]:
# TODO: fix sample images and read them from a folder instead of taking the
#  first element yielded by each training dataset.
# One sample per domain, reused below to visualise model outputs.
sample_clear, sample_fog = (next(iter(dataset)) for dataset in (train_clear, train_fog))

## Build Generators

In [None]:
from lib.models import ModelsBuilder

# NOTE(review): OUTPUT_CHANNELS is not referenced by any cell shown here and is
# not passed to ModelsBuilder — confirm whether it is still needed.
OUTPUT_CHANNELS = 3
# Factory used below to create both generators and both discriminators.
models_builder = ModelsBuilder()

In [None]:
# clear->fog is built with the transmission-map option enabled, fog->clear without it.
generator_clear2fog, generator_fog2clear = (
    models_builder.build_generator(use_transmission_map=flag) for flag in (True, False)
)

In [None]:
# Save the clear->fog architecture diagram; the trailing ';' suppresses
# IPython's display of the returned value.
tf.keras.utils.plot_model(generator_clear2fog, to_file='generator_clear2fog.png',
                          dpi=64, show_shapes=True);

In [None]:
# Save the fog->clear architecture diagram; the trailing ';' suppresses
# IPython's display of the returned value.
tf.keras.utils.plot_model(generator_fog2clear, to_file='generator_fog2clear.png',
                          dpi=64, show_shapes=True);

In [None]:
from lib.plot import plot_generators_predictions

# Preview both translation directions on the fixed samples; at this point in
# the notebook the generators have only just been built.
plot_generators_predictions(
    generator_clear2fog, sample_clear,
    generator_fog2clear, sample_fog,
).show()

## Build Discriminators

In [None]:
# One discriminator per domain, both created by the same builder call.
discriminator_fog, discriminator_clear = (
    models_builder.build_discriminator() for _ in range(2)
)

In [None]:
# Save the discriminator architecture diagram; the trailing ';' suppresses
# IPython's display of the returned value.
tf.keras.utils.plot_model(discriminator_fog, to_file="discriminator_fog.png",
                          dpi=64, show_shapes=True);

In [None]:
from lib.plot import plot_discriminators_predictions

# Show each discriminator's output on the sample from its own domain.
plot_discriminators_predictions(
    discriminator_clear, sample_clear,
    discriminator_fog, sample_fog,
).show()

## Checkpoints

In [None]:
# Resolve where weights are stored: inside the mounted Google Drive project
# folder on Colab, ./weights/ otherwise.
mount_path = None  # defined on every path to suppress warnings
drive_project_path = None
if not colab:
    weights_path = "./weights/"
else:
    # noinspection PyUnresolvedReferences
    from google.colab import drive
    mount_path = '/content/drive'
    drive.mount(mount_path)
    drive_project_path = os.path.join(mount_path, "My Drive/Colab Notebooks/Cycle-Foggy-GAN/")
    weights_path = os.path.join(drive_project_path, 'weights/')

In [None]:
from lib.train import Trainer

# The Trainer wraps all four networks; checkpoints go to weights_path.
trainer = Trainer(
    generator_clear2fog,
    generator_fog2clear,
    discriminator_fog,
    discriminator_clear,
)
trainer.configure_checkpoint(weights_path=weights_path)


## Training 

In [None]:
# Colab form toggle: when True, a TensorBoard server is launched and the
# flag is forwarded to trainer.train().
use_tensorboard = True #@param{type:"boolean"}

In [None]:
# Start TensorBoard on the trainer's log directory (redirected to Drive on Colab).
if use_tensorboard:
    import tensorboard
    from urllib.parse import urlparse
    tb = tensorboard.program.TensorBoard()
    if colab:
        trainer.tensorboard_base_logdir = os.path.join(drive_project_path,"tensorboard_logs/")
    tb.configure(argv=[None, '--logdir', trainer.tensorboard_base_logdir])
    url = tb.launch()
    if colab:
        # launch() can bind a port other than 6006 (e.g. when this cell is
        # re-run while an earlier server is still alive), so embed the port
        # the server actually reports instead of a hard-coded one.
        tensorboard.notebook.display(port=urlparse(url).port or 6006, height=1000)
    else:
        print(url)


In [None]:
# On Colab, keep image logs and the trainer config in the same Drive project
# folder as the weights.
if colab:
    trainer.image_log_path = os.path.join(drive_project_path, "image_logs/")
    trainer.config_path = os.path.join(drive_project_path, "trainer_config.json")

In [None]:
# Load the trainer's saved configuration (presumably from trainer.config_path)
# once, here; the train() call below passes load_config_first=False,
# presumably so the config is not loaded a second time.
trainer.load_config()

In [None]:
# Train on the clear/foggy training sets. The callback clears the cell output
# (waiting for new output first) and the fixed samples are passed as test images.
trainer.train(
    train_clear,
    train_fog,
    sample_test=(sample_clear, sample_fog),
    use_tensorboard=use_tensorboard,
    clear_output_callback=lambda: clear_output(wait=True),
    load_config_first=False,
)

## Testing 

In [None]:
# Visualise both generators on the first few elements (batches, presumably)
# of the held-out test sets.
num_test_batches = 5
for clear, fog in zip(test_clear.take(num_test_batches), test_fog.take(num_test_batches)):
    # Call .show() as the preview cell does, so each comparison is rendered as
    # soon as it is produced (and also outside an inline notebook backend).
    plot_generators_predictions(generator_clear2fog, clear, generator_fog2clear, fog).show()