In [1]:
## Mount Google Drive so the code, dataset and checkpoint folders under /content/drive
## (used by the cells below) become reachable from this Colab runtime.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [2]:
import sys
import os

## Make the repo's TF-Keras package dir importable (provides `nets` and `utils`).
py_file_location = "/content/drive/MyDrive/image/FUnIE-GAN-master/TF-Keras"
_repo_path = os.path.abspath(py_file_location)
# Guard so re-running this cell does not keep appending duplicate entries to sys.path.
if _repo_path not in sys.path:
    sys.path.append(_repo_path)

In [3]:
import os
import numpy as np
from os.path import join, exists
## local libs
from nets.funieGAN import FUNIE_GAN
from utils.data_utils import DataLoader
from utils.plot_utils import save_val_samples_funieGAN

In [4]:
## Dataset location and output directories (all on the mounted Google Drive).
project_dir = "/content/drive/MyDrive/image/FUnIE-GAN-master"
data_dir = "/content/drive/MyDrive/image/dataset/EUVP Dataset/Paired/"
dataset_name = "underwater_imagenet" # options: {'underwater_imagenet', 'underwater_dark'}
dataset_path = join(data_dir, dataset_name)
# Fail early with a clear message instead of letting the loader silently find no pairs.
if not exists(dataset_path):
    raise FileNotFoundError("dataset folder not found: {0}".format(dataset_path))
data_loader = DataLoader(dataset_path, dataset_name)
## create dir for log and (sampled) validation data; exist_ok keeps re-runs of this cell safe
samples_dir = join(project_dir, "data/samples/funieGAN", dataset_name)
checkpoint_dir = join(project_dir, "data/checkpoints/funieGAN", dataset_name)
os.makedirs(samples_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)


3700 training pairs



In [5]:
## Training schedule
num_epoch = 200
batch_size = 4
val_interval = 2000   # save a sheet of validation samples every `val_interval` steps
N_val_samples = 3     # number of validation pairs rendered per sample sheet
# A zero interval would make `step % save_model_interval` raise ZeroDivisionError
# deep inside the training loop, so reject a too-small training set up front.
if data_loader.num_train < batch_size:
    raise ValueError("need at least one full batch: num_train={0}, batch_size={1}".format(
        data_loader.num_train, batch_size))
# Number of full batches in the training set (presumably one epoch of load_batch),
# so checkpoints are written once per epoch.
save_model_interval = data_loader.num_train//batch_size
num_step = num_epoch*save_model_interval

In [6]:
## Build FUnIE-GAN; the training loop uses its `generator`, `discriminator` and `combined` models.
funie_gan = FUNIE_GAN()
## Adversarial-loss targets, one block of labels per sample in a training batch.
## NOTE(review): `disc_patch` is presumably the discriminator's (PatchGAN) output shape — confirm in nets/funieGAN.
label_shape = (batch_size,) + funie_gan.disc_patch
valid = np.ones(label_shape)    # target for real (ground-truth) pairs
fake = np.zeros_like(valid)     # target for generated pairs, same shape/dtype as `valid`


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5
KerasTensor(type_spec=TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name='input_6'), name='input_6', description="created by layer 'input_6'")


In [None]:
## Training loop: per batch, one discriminator update (real + fake) then one generator
## update through the combined model, until exactly `num_step` steps have been taken.
step = 0
all_D_losses = []; all_G_losses = []
# `<` (not `<=`): the inner `break` only leaves the `for`, so `<=` re-entered the
# loop at step == num_step and trained one extra, never-checkpointed step.
while step < num_step:
    n_batches_this_epoch = 0
    for imgs_distorted, imgs_good in data_loader.load_batch(batch_size):
        n_batches_this_epoch += 1
        ## train the discriminator: real pairs -> `valid`, generated pairs -> `fake`
        # predict_on_batch avoids predict()'s per-call dataset/progress overhead for a single batch
        imgs_fake = funie_gan.generator.predict_on_batch(imgs_distorted)
        d_loss_real = funie_gan.discriminator.train_on_batch([imgs_good, imgs_distorted], valid)
        d_loss_fake = funie_gan.discriminator.train_on_batch([imgs_fake, imgs_distorted], fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        ## train the generator: adversarial target `valid` plus reconstruction target `imgs_good`
        g_loss = funie_gan.combined.train_on_batch([imgs_good, imgs_distorted], [valid, imgs_good])
        ## increment step, save losses, and print them
        # NOTE(review): element 0 is taken as the scalar (total) loss; the remaining
        # entries are presumably per-output losses / compiled metrics — confirm in nets/funieGAN.
        step += 1
        all_D_losses.append(d_loss[0])
        all_G_losses.append(g_loss[0])
        if step % 50 == 0:
            print("Step {0}/{1}: lossD: {2}, lossG: {3}".format(step, num_step, d_loss[0], g_loss[0]))
        ## validate and save generated samples at regular intervals
        # Separate names so the validation batch does not clobber the training batch variables.
        if step % val_interval == 0:
            val_distorted, val_good = data_loader.load_val_data(batch_size=N_val_samples)
            val_fake = funie_gan.generator.predict_on_batch(val_distorted)
            gen_imgs = np.concatenate([val_distorted, val_fake, val_good])
            gen_imgs = 0.5 * gen_imgs + 0.5 # Rescale to 0-1 (assumes images are in [-1, 1])
            save_val_samples_funieGAN(samples_dir, gen_imgs, step, N_samples=N_val_samples)
        ## save generator architecture (JSON) and weights (HDF5)
        if step % save_model_interval == 0:
            model_name = join(checkpoint_dir, ("model_%d" % step))
            with open(model_name+"_.json", "w") as json_file:
                json_file.write(funie_gan.generator.to_json())
            funie_gan.generator.save_weights(model_name+"_.h5")
            print("\nSaved trained model in {0}\n".format(checkpoint_dir))
        ## stop exactly at num_step (also terminates the outer while)
        if step >= num_step:
            break
    # Without this guard an empty loader would spin the while-loop forever.
    if n_batches_this_epoch == 0:
        raise RuntimeError("data_loader.load_batch({0}) yielded no batches".format(batch_size))

Step 50/185000: lossD: 0.39623644948005676, lossG: 0.5990022420883179
Step 100/185000: lossD: 0.31725554168224335, lossG: 0.5517651438713074
Step 150/185000: lossD: 0.4332434833049774, lossG: 0.43959829211235046
Step 200/185000: lossD: 0.25214336067438126, lossG: 0.5471800565719604
Step 250/185000: lossD: 0.24044166505336761, lossG: 0.4721641540527344
Step 300/185000: lossD: 0.164479598402977, lossG: 0.5547671914100647
Step 350/185000: lossD: 0.09350768476724625, lossG: 0.5931472778320312
Step 400/185000: lossD: 0.09400375932455063, lossG: 0.3725722134113312
Step 450/185000: lossD: 0.08519197441637516, lossG: 0.4613702893257141
Step 500/185000: lossD: 0.09560707584023476, lossG: 0.25143301486968994
Step 550/185000: lossD: 0.0757851405069232, lossG: 0.2448822259902954
Step 600/185000: lossD: 0.09753982536494732, lossG: 0.2579505145549774
Step 650/185000: lossD: 0.05247998237609863, lossG: 0.3017047345638275
Step 700/185000: lossD: 0.10431469790637493, lossG: 0.45131316781044006
Step 750