In [1]:
!pip install imageio



In [2]:
# Mount Google Drive at /content/drive so the dataset is reachable.
# The dataset paths used later are relative ('drive/My Drive/...'), so they
# presumably assume the notebook runs from /content — TODO confirm if run elsewhere.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
!cd 'drive/My Drive/Projet DataSience/dataset_clean_degraded' && ls

/bin/bash: line 0: cd: drive/My Drive/Projet DataSience/dataset_clean_degraded: No such file or directory


In [4]:
import numpy as np 
import pandas as pd 
import scipy
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
from skimage import transform
from __future__ import print_function, division

from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import datetime
import sys
import os
#from imageio import imread
import imageio

Using TensorFlow backend.


In [0]:
def load_data(batch_size=1, is_val=True, path_files=None, img_res=(128, 128)):
  """
  Sample `batch_size` random (clean, degraded) image pairs, e.g. to visualize
  the generator's progress between epochs.

  Parameters
  ----------
  batch_size : int
      Number of pairs to return. Pairs are drawn without replacement unless the
      directory holds fewer than `batch_size` images.
  is_val : bool
      Read from the validation directory (default) instead of the training one.
      Ignored when `path_files` is given.
  path_files : str, optional
      Dataset root containing `clean/` and `degraded/` sub-directories with
      identically named files. Defaults to the Drive directory chosen by `is_val`.
  img_res : tuple of int
      (rows, cols) that every image is resized to.

  Returns
  -------
  (numpy.ndarray, numpy.ndarray)
      Clean and degraded batches, each of shape (batch_size, rows, cols, ...),
      with pixel values divided by 127.5 and shifted by -1 (i.e. [0, 255] ->
      [-1, 1], assuming `imread` returns 8-bit pixel values).

  Raises
  ------
  ValueError
      If the `clean/` directory contains no files.
  """
  if path_files is None:
    path_files = ('drive/My Drive/Projet DataScience/Data/Val' if is_val
                  else 'drive/My Drive/Projet DataScience/Data/Train/dataset_clean_degraded')

  clean_dir = os.path.join(path_files, 'clean')
  degraded_dir = os.path.join(path_files, 'degraded')

  # Sorted so that a fixed numpy seed reproduces the same sample.
  files = sorted(os.listdir(clean_dir))
  if not files:
    raise ValueError('No images found in %s' % clean_dir)

  # Avoid showing the same pair twice when there are enough images to choose from.
  batch_images = np.random.choice(files, size=batch_size, replace=len(files) < batch_size)

  clean_images = []
  degraded_images = []
  for image in batch_images:
    clean = imread(os.path.join(clean_dir, image))
    degraded = imread(os.path.join(degraded_dir, image))

    # preserve_range keeps the original pixel scale, so the normalization
    # below does not depend on resize's float-conversion conventions.
    clean_images.append(transform.resize(clean, img_res, preserve_range=True))
    degraded_images.append(transform.resize(degraded, img_res, preserve_range=True))

  # Normalize to [-1, 1] to match the generator's tanh output range.
  clean_images = np.array(clean_images) / 127.5 - 1.
  degraded_images = np.array(degraded_images) / 127.5 - 1.

  return clean_images, degraded_images

In [0]:
def load_batch(batch_size=1, is_val=False, path_files=None, img_res=(128, 128)):
  """
  Yield successive (clean, degraded) batches that cover the dataset once.

  Uses the same loading and normalization as `load_data`, but walks the file
  list in order and yields `len(files) // batch_size` full batches. The trailing
  incomplete batch is dropped because `train` builds its label arrays with a
  fixed first dimension of `batch_size`.

  Parameters
  ----------
  batch_size : int
      Number of pairs per batch (must be >= 1).
  is_val : bool
      Read from the validation directory instead of the training one (default).
      Ignored when `path_files` is given.
  path_files : str, optional
      Dataset root containing `clean/` and `degraded/` sub-directories with
      identically named files. Defaults to the Drive directory chosen by `is_val`.
  img_res : tuple of int
      (rows, cols) that every image is resized to.

  Yields
  ------
  (numpy.ndarray, numpy.ndarray)
      Clean and degraded batches of shape (batch_size, rows, cols, ...), with
      pixel values divided by 127.5 and shifted by -1.

  Raises
  ------
  ValueError
      If `batch_size` < 1 (raised when iteration starts).
  """
  if batch_size < 1:
    raise ValueError('batch_size must be >= 1, got %r' % (batch_size,))

  if path_files is None:
    path_files = ('drive/My Drive/Projet DataScience/Data/Val' if is_val
                  else 'drive/My Drive/Projet DataScience/Data/Train/dataset_clean_degraded')

  clean_dir = os.path.join(path_files, 'clean')
  degraded_dir = os.path.join(path_files, 'degraded')
  files = sorted(os.listdir(clean_dir))

  # One pass over the dataset: the number of batches depends on how many
  # files there are, not on the batch size.
  n_batches = len(files) // batch_size

  for i in range(n_batches):
    batch = files[i * batch_size:(i + 1) * batch_size]

    clean_images = []
    degraded_images = []
    for image in batch:
      clean = imread(os.path.join(clean_dir, image))
      degraded = imread(os.path.join(degraded_dir, image))

      # preserve_range keeps the original pixel scale for the normalization below.
      clean_images.append(transform.resize(clean, img_res, preserve_range=True))
      degraded_images.append(transform.resize(degraded, img_res, preserve_range=True))

    # Normalize to [-1, 1] to match the generator's tanh output range.
    yield (np.array(clean_images) / 127.5 - 1.,
           np.array(degraded_images) / 127.5 - 1.)

In [0]:
def imread(path):
  """
  Read an image file as a float64 RGB array.

  Parameters
  ----------
  path : str
      Path of the image file.

  Returns
  -------
  numpy.ndarray
      Array of shape (height, width, 3) and dtype float64, holding the decoded
      8-bit pixel values (0..255).

  Notes
  -----
  `pilmode='RGB'` converts grayscale / RGBA / palette images to three channels,
  as the previous `scipy.misc.imread(path, mode='RGB')` call did; the models in
  this notebook are built for 3-channel inputs. This assumes the files are
  decoded by imageio's Pillow plugin (PNG/JPEG) — TODO confirm for other formats.
  `np.float64` replaces the `np.float` alias, which NumPy 1.24 removed.
  """
  return imageio.imread(path, pilmode='RGB').astype(np.float64)

In [0]:
def build_generator():
  """
  Build the U-Net generator that maps a degraded image to a restored one.

  Reads the notebook-level settings `img_shape`, `gf` and `channels`.
  The encoder applies seven stride-2 convolution stages. The decoder applies
  six upsampling stages, each concatenated with the encoder feature map of the
  same resolution (skip connection). A final upsampling step and a tanh
  convolution produce an image with `channels` channels in [-1, 1].
  Layers are created in a fixed order, which keeps their automatic names stable.
  """

  def down(tensor, n_filters, f_size=4, bn=True):
    """Encoder stage: stride-2 Conv2D -> LeakyReLU(0.2) -> optional BatchNorm."""
    out = Conv2D(n_filters, kernel_size=f_size, strides=2, padding='same')(tensor)
    out = LeakyReLU(alpha=0.2)(out)
    return BatchNormalization(momentum=0.8)(out) if bn else out

  def up(tensor, skip, n_filters, f_size=4, dropout_rate=0):
    """Decoder stage: upsample x2 -> ReLU Conv2D -> optional Dropout -> BatchNorm -> concat skip."""
    out = UpSampling2D(size=2)(tensor)
    out = Conv2D(n_filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(out)
    if dropout_rate:
      out = Dropout(dropout_rate)(out)
    out = BatchNormalization(momentum=0.8)(out)
    return Concatenate()([out, skip])

  inputs = Input(shape=img_shape)

  # Encoder; only the first stage skips BatchNorm. Every stage output is kept for the skips.
  encoder_filters = [gf, gf * 2, gf * 4, gf * 8, gf * 8, gf * 8, gf * 8]
  skips = []
  x = inputs
  for idx, n_filters in enumerate(encoder_filters):
    x = down(x, n_filters, bn=(idx != 0))
    skips.append(x)

  # Decoder: pair the bottleneck path with encoder outputs 6, 5, ..., 1.
  decoder_filters = [gf * 8, gf * 8, gf * 8, gf * 4, gf * 2, gf]
  for skip, n_filters in zip(reversed(skips[:-1]), decoder_filters):
    x = up(x, skip, n_filters)

  x = UpSampling2D(size=2)(x)
  restored = Conv2D(channels, kernel_size=4, strides=1, padding='same', activation='tanh')(x)

  return Model(inputs, restored)

In [0]:
def build_discriminator():
  """
  Build the PatchGAN discriminator.

  Takes an (image, conditioning degraded image) pair, both of notebook-level
  shape `img_shape`, and stacks them along the channel axis. It then applies four
  stride-2 convolution stages with `df`, 2*df, 4*df and 8*df filters, the first
  without BatchNorm. A final 1-filter convolution outputs a validity map with one
  value per patch. With 'same' padding the map is (rows/16, cols/16, 1).
  """

  def block(tensor, n_filters, f_size=4, bn=True):
    """Stride-2 Conv2D -> LeakyReLU(0.2) -> optional BatchNorm."""
    out = Conv2D(n_filters, kernel_size=f_size, strides=2, padding='same')(tensor)
    out = LeakyReLU(alpha=0.2)(out)
    return BatchNormalization(momentum=0.8)(out) if bn else out

  target_in = Input(shape=img_shape)
  condition_in = Input(shape=img_shape)

  # Channel-wise concatenation of the judged image and its condition.
  features = Concatenate(axis=-1)([target_in, condition_in])
  for idx, multiplier in enumerate((1, 2, 4, 8)):
    features = block(features, df * multiplier, bn=idx > 0)

  validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(features)

  return Model([target_in, condition_in], validity)

In [18]:
# Input shape
img_rows = 128
img_cols = 128
channels = 3

img_shape = (img_rows, img_cols, channels)

# Calculate output shape of D (PatchGAN): the discriminator applies four
# stride-2 convolutions, so each spatial dimension shrinks by 2**4.
patchrows = int(img_rows / 2**4)
patchcols = int(img_cols / 2**4)
disc_patch = (patchrows, patchcols, 1)

# Number of filters in the first layer of G and D
gf = 64
df = 64

# Adam(learning rate, beta_1). The same optimizer instance is passed to both
# the discriminator and the combined model below.
# NOTE(review): 0.002 is 10x the 2e-4 commonly used for pix2pix-style GANs —
# confirm this is intentional.
optimizer = Adam(0.002, 0.5)

# Build and compile the discriminator (least-squares loss on the validity map).
discriminator = build_discriminator()
discriminator.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])

# Load discriminator weights if a checkpoint with exactly this name exists.
# The second positional argument is by_name=True: weights are matched to layers
# by their (auto-generated) names, so the model must be built in the same order.
# NOTE(review): train() saves checkpoints under names containing the final loss,
# so this hard-coded file name must be updated by hand to resume a newer run.
filepath_discriminator_weights = 'weights_discriminator_loss_0.102622256.hdf5'
if os.path.isfile(filepath_discriminator_weights):
  discriminator.load_weights(filepath_discriminator_weights, True)
  print('discriminator weights loaded')

# Build the generator
generator = build_generator()

# Load generator weights (same by-name / file-name caveats as above)
filepath_generator_weights = 'weights_generator_loss_5.8947883.hdf5'
if os.path.isfile(filepath_generator_weights):
  generator.load_weights(filepath_generator_weights, True)
  print('generator weights loaded')

# Input images and their conditioning images
clean_image = Input(shape=img_shape)
degraded_image = Input(shape=img_shape)

# By conditioning on degraded_image generate a fake version of clean_image
fake_clean_image = generator(degraded_image)

# For the combined model we will only train the generator.
# This is set after discriminator.compile(), so the standalone discriminator
# keeps training in train_on_batch while `combined` (compiled below) freezes it
# (relies on Keras capturing `trainable` at compile time).
discriminator.trainable = False

# Discriminators determines validity of translated images / condition pairs
valid = discriminator([fake_clean_image, degraded_image])

# clean_image is listed as an input although no output depends on it, so that
# combined.train_on_batch takes the same [clean, degraded] input list as D.
# Losses: MSE on the validity map (weight 1) + MAE between the generated and
# target clean image (weight 100).
combined = Model(inputs=[clean_image, degraded_image], outputs=[valid, fake_clean_image])
combined.compile(loss=['mse', 'mae'],
                              loss_weights=[1, 100],
                              optimizer=optimizer)

discriminator weights loaded
generator weights loaded


In [0]:
def show_images(epoch, batch_i, n_samples=3):
  """
  Plot a grid comparing degraded inputs, generator outputs and ground truth.

  Rows are Input / Output / Ground Truth. Columns are `n_samples` random
  validation pairs drawn with `load_data`, restored with the notebook-level
  `generator`.

  Parameters
  ----------
  epoch : int
      Current epoch, shown in the figure title.
  batch_i : int
      Index of the last batch processed, shown in the figure title.
  n_samples : int
      Number of validation pairs (grid columns) to display.
  """
  titles = ['Input', 'Output', 'Ground Truth']

  clean_images, degraded_images = load_data(batch_size=n_samples, is_val=True)
  fake_clean_images = generator.predict(degraded_images)

  # Images are scaled to [-1, 1]; map them to [0, 1] for imshow and clip any
  # small overshoot so matplotlib does not warn about out-of-range floats.
  rows = [np.clip(0.5 * imgs + 0.5, 0., 1.)
          for imgs in (degraded_images, fake_clean_images, clean_images)]

  # squeeze=False keeps `axs` 2-D even when n_samples == 1.
  fig, axs = plt.subplots(len(rows), n_samples, squeeze=False)
  fig.set_size_inches(4 * n_samples, 12)
  fig.suptitle('Epoch %d - batch %d' % (epoch, batch_i))
  for i, (title, imgs) in enumerate(zip(titles, rows)):
    for j in range(n_samples):
      axs[i, j].imshow(imgs[j])
      axs[i, j].set_title(title)
      axs[i, j].axis('off')
  plt.show()
  # Close this specific figure to free memory across repeated calls during training.
  plt.close(fig)

In [0]:
def train(epochs, batch_size=1, show_interval=10):
    """
    Train the conditional GAN (notebook-level `generator`, `discriminator`, `combined`).

    Each step:
      1. the generator restores the degraded batch;
      2. the discriminator is trained on real (clean, degraded) pairs labelled 1
         and on (generated, degraded) pairs labelled 0;
      3. `combined` trains the generator to make the discriminator output 1
         while also matching the clean image, using the losses and weights set
         when `combined` was compiled.

    Every `show_interval` epochs, the last batch's losses are printed and a sample
    grid is shown. Weights of both networks are saved once at the end, with the
    final losses in the file names. Nothing is saved when `epochs` is 0.

    Parameters
    ----------
    epochs : int
        Number of passes over `load_batch(batch_size)`.
    batch_size : int
        Pairs per training batch.
    show_interval : int
        Report/visualize every `show_interval` epochs (starting at epoch 0).

    Raises
    ------
    ValueError
        If `load_batch` yields no batch in an epoch (e.g. the dataset holds
        fewer images than `batch_size`), since there are no losses to report.
    """
    start_time = datetime.datetime.now()

    # Adversarial targets: one value per PatchGAN output cell.
    valid = np.ones((batch_size,) + disc_patch)
    fake = np.zeros((batch_size,) + disc_patch)

    d_loss = g_loss = None
    for epoch in range(epochs):
        batch_i = -1
        for batch_i, (clean_images, degraded_images) in enumerate(load_batch(batch_size)):
            # Train the discriminator: real pairs -> valid, generated pairs -> fake.
            fake_clean_images = generator.predict(degraded_images)
            d_loss_real = discriminator.train_on_batch([clean_images, degraded_images], valid)
            d_loss_fake = discriminator.train_on_batch([fake_clean_images, degraded_images], fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # Train the generator through the frozen discriminator.
            g_loss = combined.train_on_batch([clean_images, degraded_images], [valid, clean_images])

        if batch_i < 0:
            raise ValueError('load_batch(%d) yielded no batches; the dataset may hold '
                             'fewer images than batch_size' % batch_size)

        if epoch % show_interval == 0:
            elapsed_time = datetime.datetime.now() - start_time
            print("[Epoch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %f] time: %s" % (epoch, epochs, d_loss[0], 100*d_loss[1], g_loss[0], elapsed_time))
            show_images(epoch, batch_i)

    if d_loss is None:
        # epochs == 0: nothing was trained, so there is nothing to save.
        return

    # Save models
    generator_weights_filepath = 'weights_generator_loss_' + str(g_loss[0]) + '.hdf5'
    discriminator_weights_filepath = 'weights_discriminator_loss_' + str(d_loss[0]) + '.hdf5'

    generator.save_weights(generator_weights_filepath, overwrite=True)
    discriminator.save_weights(discriminator_weights_filepath, overwrite=True)

In [20]:
# Train for 120 epochs with batches of 32 pairs, printing losses and showing
# sample restorations every 3 epochs; weights are saved once after the last epoch.
train(epochs=120, batch_size=32, show_interval=3)

Output hidden; open in https://colab.research.google.com to view.