In [None]:
pip install fastai

In [2]:
import numpy as np
import glob

In [None]:
# Obtain the COCO sample through fastai and point at its `train_sample` folder.
from fastai.data.external import untar_data, URLs
coco_path = untar_data(URLs.COCO_SAMPLE)
coco_path = str(coco_path) + "/train_sample"
paths = glob.glob(coco_path + "/*.jpg") # Grabbing all the image file names
# Fixed seed so the subset and the split below are the same on every run.
np.random.seed(123)
paths_subset = np.random.choice(paths, 2_000, replace=False) # choosing 2,000 images randomly (no repeats)
rand_idxs = np.random.permutation(2_000)
train_idxs = rand_idxs[:1600] # choosing the first 1,600 shuffled indices as training set
val_idxs = rand_idxs[1600:] # choosing the remaining 400 as validation set
train_paths = paths_subset[train_idxs]
val_paths = paths_subset[val_idxs]
print(len(train_paths), len(val_paths))

B&W image **creation**

In [17]:
from PIL import Image 
#from tqdm.notebook import tqdm
path_target = "/root/.fastai/data/coco_sample/train_sample/"
path_train = "/root/.fastai/data/coco_sample/train/"
imgurls = !ls -1 {path_target}
for image in imgurls:
  try:
    img = PIL.Image.open(path_target + image)
    img = img.convert('L')
    img.save(path_train + image)
  except:
    pass




In [None]:

import PIL
path_train = "/root/.fastai/data/coco_sample/train/"
imgurls = !ls -1 {path_train}
for image in tqdm(imgurls):
  img = PIL.Image.open(path_train + image)
  img = img.convert(mode='RGB')
  img.save(path_train + image)

## Path Initialization

In [None]:

# Folder the model inputs are read from (the converted copies written by the cells above).
path_train = "/root/.fastai/data/coco_sample/train/"

# Folder with the original images; `load_image` reads the targets from here.
path_target = "/root/.fastai/data/coco_sample/train_sample/"


# Capture the directory listing of path_train (one entry per line) via the shell.
imgurls = !ls -1 {path_train}
print(len(imgurls))

In [None]:

# Split the listed file names into a leading train part and a trailing test part.
number_of_images = len(imgurls)
print(f"The number of total images are {number_of_images}")
train_percentage = 0.8
# One shared boundary index: the previous `boundary + 1` start for the test
# slice silently dropped the image sitting exactly at the boundary.
split_index = int(train_percentage * number_of_images)
train_urls = imgurls[:split_index]
test_urls = imgurls[split_index:]
print(f"The number of images in train is {len(train_urls)} and in test {len(test_urls)} ")

# Data Augmentation

- Random jitter: enlarges the image to 572x572 and then crops a random 512x512 patch from it.


- Flip: as part of the random jitter function, the image is mirrored horizontally (or not) depending on a random variable.

In [21]:
## Data augmentation: shared configuration
# Side length, in pixels, that `load_image` resizes to and `random_jitter` crops to.
img_size = 512
import tensorflow as tf


@tf.function
def resize(input_img, tar_img, img_size):
    """Resize an input/target pair to a square of side `img_size`.

    Both images go through `tf.image.resize` with the same
    `[img_size, img_size]` size and are returned as `(input_img, tar_img)`.
    """
    square = [img_size, img_size]
    return tf.image.resize(input_img, square), tf.image.resize(tar_img, square)


def normalize(input_img, tar_img):
    """Rescale both images by 1/255 and shift them down by one.

    Returns the pair `(input_img, tar_img)` after the same affine map has been
    applied to each.

    NOTE(review): for pixel values in 0..255 this produces values in [-1, 0],
    not [-1, 1]. `generate_images` de-normalizes with `* 1 + 1`, which matches
    this range, while the generator's last layer is a tanh - confirm whether
    `/ 127.5` was intended before changing either side.
    """
    return input_img / 255. - 1, tar_img / 255. - 1

def random_jitter(input_img, tar_img):
    """Apply the same random crop and random horizontal flip to both images.

    The pair is first enlarged to 572x572, then stacked so that a single
    `tf.image.random_crop` call cuts the same `img_size` x `img_size` window
    (3 channels) out of both. Afterwards both images are mirrored left-right
    together when a uniform random draw exceeds 0.5.
    """
    enlarged_input, enlarged_target = resize(input_img, tar_img, 572)

    # Stacking makes the crop offsets identical for input and target.
    pair = tf.stack([enlarged_input, enlarged_target], axis=0)
    cropped_pair = tf.image.random_crop(pair, size=[2, img_size, img_size, 3])
    input_img = cropped_pair[0]
    tar_img = cropped_pair[1]

    mirror = tf.random.uniform(()) > 0.5
    if mirror:
        input_img = tf.image.flip_left_right(input_img)
        tar_img = tf.image.flip_left_right(tar_img)
    return input_img, tar_img

# Loading the images

In [22]:
def load_image(filename, augment=True):
    """Load one input/target pair and prepare it for the model.

    Args:
        filename: file name appended to the module-level `path_train` (input
            image) and `path_target` (target image) folders.
        augment: when true, `random_jitter` is applied to the pair.

    Returns:
        `(input_img, tar_img)` as float32 tensors, resized to
        `img_size` x `img_size` and passed through `normalize`.
    """
    # channels=3 makes the decoder always return three channels, so a
    # single-channel (grayscale) JPEG no longer yields a 1-channel tensor that
    # breaks the 3-channel crop in `random_jitter`. This also makes the former
    # `[..., :3]` slice unnecessary.
    input_img = tf.cast(
        tf.image.decode_jpeg(tf.io.read_file(path_train + filename), channels=3),
        tf.float32)
    tar_img = tf.cast(
        tf.image.decode_jpeg(tf.io.read_file(path_target + filename), channels=3),
        tf.float32)

    input_img, tar_img = resize(input_img, tar_img, img_size)
    if augment:
        input_img, tar_img = random_jitter(input_img, tar_img)

    input_img, tar_img = normalize(input_img, tar_img)
    return input_img, tar_img

def load_train_image(filename):
    """Load a training pair: same as `load_image` with augmentation enabled."""
    return load_image(filename, augment=True)

def load_test_image(filename):
    """Load a test pair: same as `load_image` with augmentation disabled."""
    return load_image(filename, augment=False)

In [None]:
import matplotlib.pyplot as plt

# Load the pair once: `load_train_image` augments randomly, so two separate
# calls would show an input and a target taken from different crops/flips.
preview_input, preview_target = load_train_image(train_urls[0])

# `normalize` maps 0..255 to -1..0, so adding 1 brings the values back to the
# 0..1 range that imshow expects for floats (the previous `(x + 1) / 2` only
# reached 0.5 and rendered the images at half brightness).
plt.figure()
plt.imshow(preview_input + 1.);
plt.figure()
plt.imshow(preview_target + 1.);

# Dataset creation

In [24]:
# Training pipeline: file names -> augmented (input, target) pairs -> batches of one.
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_urls)
    .map(load_train_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(1)
)

# Test pipeline: identical, but the pairs are loaded without augmentation.
test_dataset = (
    tf.data.Dataset.from_tensor_slices(test_urls)
    .map(load_test_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(1)
)

# Model Creation

In [25]:
from tensorflow.keras.layers import *
from tensorflow.keras.models import Sequential, Model

def downsample(filters, batch_norm=True):
  """Build an encoder stage that halves the spatial resolution.

  The stage is a `Sequential` of a 4x4 stride-2 'same'-padded convolution with
  `filters` output channels, an optional `BatchNormalization`, and a
  `LeakyReLU`. The convolution carries a bias only when `batch_norm` is false.
  """
  initializer = tf.random_normal_initializer(0, 0.02)

  stages = [
      Conv2D(filters=filters, strides=2, kernel_size=4, padding='same',
             kernel_initializer=initializer, use_bias=not batch_norm),
  ]
  if batch_norm:
    stages.append(BatchNormalization())
  stages.append(LeakyReLU())

  return Sequential(stages)

def upsample(filters, dropout=True):
  """Build a decoder stage that doubles the spatial resolution.

  The stage is a `Sequential` of a bias-free 4x4 stride-2 'same'-padded
  transposed convolution with `filters` output channels, an optional
  `Dropout(0.5)`, and a `ReLU`.

  NOTE(review): there is no normalization layer here even though the
  convolution is created with use_bias=False - confirm this is intentional.
  """
  initializer = tf.random_normal_initializer(0, 0.02)

  stages = [
      Conv2DTranspose(filters=filters, strides=2, kernel_size=4, padding='same',
                      kernel_initializer=initializer, use_bias=False),
  ]
  # Dropout is optional and sits between the convolution and the activation.
  if dropout:
    stages.append(Dropout(0.5))
  stages.append(ReLU())

  return Sequential(stages)


def Generator():
  """Build the encoder-decoder generator with skip connections.

  Eight `downsample` stages each halve the spatial size, seven `upsample`
  stages each double it and are concatenated with the encoder output of the
  same resolution, and a final tanh transposed convolution doubles the size
  once more and emits 3 channels. For the 512x512 images produced by the
  loaders the innermost feature map is 2x2.

  Returns:
      A Keras `Model` taking a `[None, None, 3]` image.
  """
  initializer = tf.random_normal_initializer(0, 0.02)

  inputs = Input(shape=[None, None, 3])

  # Encoder: only the first stage skips batch normalization.
  down_stack = [downsample(64, batch_norm=False)]
  down_stack += [downsample(width) for width in (128, 256, 512, 512, 512, 512, 512)]

  # Decoder: the first three stages use dropout, the remaining four do not.
  up_stack = [upsample(512) for _ in range(3)]
  up_stack += [upsample(width, dropout=False) for width in (512, 256, 128, 64)]

  last = Conv2DTranspose(filters=3, kernel_size=4, strides=2, padding="same",
                         kernel_initializer=initializer, activation='tanh')
  concat = Concatenate()

  x = inputs
  skips = []
  for enc in down_stack:
    x = enc(x)
    skips.append(x)

  # The innermost encoder output feeds the decoder directly; the other seven
  # are paired with the decoder stages from deepest to shallowest.
  for dec, skip in zip(up_stack, skips[-2::-1]):
    x = concat([dec(x), skip])

  return Model(inputs=inputs, outputs=last(x))


def Discriminator():
  """Build the two-input convolutional discriminator.

  The two 3-channel inputs (named "real_image" and "fake_image") are
  concatenated, passed through four `downsample` stages (64 filters without
  batch normalization, then three times 128), and reduced by a 4x4 stride-1
  convolution to a single-channel map of raw scores (no activation).

  Note: `train_step` calls this model with `[candidate image, input image]`,
  so the input names do not describe how the slots are used there.
  """
  initializer = tf.random_normal_initializer(0, 0.02)

  real_input = Input(shape=[None, None, 3], name="real_image")
  fake_input = Input(shape=[None, None, 3], name="fake_image")

  x = concatenate([real_input, fake_input])
  x = downsample(64, batch_norm=False)(x)
  for _ in range(3):
    x = downsample(128)(x)

  output = Conv2D(filters=1, kernel_size=4, strides=1,
                  kernel_initializer=initializer, padding='same')(x)
  return Model(inputs=[real_input, fake_input], outputs=output)

# Model initialization

In [26]:
# Instantiate the two networks.
generator = Generator()
discriminator = Discriminator()

# One Adam optimizer per network, both with the same hyper-parameters.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

# Losses

In [27]:

# Shared binary cross-entropy on raw logits, used by both loss functions.
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discrimator_loss(disc_real_output, disc_generated_output):
  """Discriminator loss: real outputs scored against ones plus generated outputs against zeros.

  Args:
      disc_real_output: discriminator logits for the target image.
      disc_generated_output: discriminator logits for the generated image.

  Returns:
      The sum of the two cross-entropy terms.
  """
  real_labels = tf.ones_like(disc_real_output)
  fake_labels = tf.zeros_like(disc_generated_output)
  return (loss_object(real_labels, disc_real_output)
          + loss_object(fake_labels, disc_generated_output))
  
  
# Weight of the L1 term relative to the adversarial term in the generator loss.
LAMBDA = 100

def generator_loss(disc_generated_output, gen_output, target):
  """Generator loss: adversarial term plus LAMBDA times the mean absolute error.

  Args:
      disc_generated_output: discriminator logits for the generated image.
      gen_output: image produced by the generator.
      target: image the generator output is compared against.

  Returns:
      Cross-entropy of the logits against ones, plus `LAMBDA` times the mean
      of `|target - gen_output|`.
  """
  adversarial = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
  reconstruction = tf.reduce_mean(tf.abs(target - gen_output))
  return adversarial + (LAMBDA * reconstruction)

# Generation of images for the model

In [28]:
def generate_images(model, test_input, tar, save_filename=False, display_imgs=True,
                    output_dir='/root/.fastai/data/coco_sample/Output/'):
  """Run `model` on a batch and optionally save and/or plot the first result.

  Args:
      model: callable model; invoked as `model(test_input, training=True)`.
      test_input: batch of input images; only element 0 is shown.
      tar: batch of target images; only element 0 is shown.
      save_filename: base name (without extension) for the saved prediction;
          any falsy value disables saving.
      display_imgs: when true, plot input, target and prediction side by side.
      output_dir: folder the '.jpg' file is written to; created if missing.
  """
  import os

  prediction = model(test_input, training=True)

  if save_filename:
    # save_img fails when the folder is missing, so make sure it exists first.
    os.makedirs(output_dir, exist_ok=True)
    tf.keras.preprocessing.image.save_img(
        os.path.join(output_dir, save_filename + '.jpg'), prediction[0,...])

  # Only touch matplotlib when something is drawn; previously an empty figure
  # was created and shown even with display_imgs=False.
  if not display_imgs:
    return

  plt.figure(figsize=(10,10))

  display_list = [test_input[0], tar[0], prediction[0]]
  title = ['Input Image', 'Ground Truth', 'Predicted Image']

  for i in range(3):
    plt.subplot(1, 3, i+1)
    plt.title(title[i])
    # `normalize` maps pixels to -1..0, so adding 1 restores the 0..1 range.
    plt.imshow(display_list[i] * 1 + 1)
    plt.axis('off')
  plt.show()

# Training

In [29]:
def train_step(input_image, target):
  """Run one optimization step for the generator and the discriminator.

  Args:
      input_image: batch fed to the generator and, as the second input, to the
          discriminator.
      target: batch of images the generator output is compared against.

  Returns:
      `(gen_loss, discr_loss)` for this step, so callers can log them
      (callers that ignore the return value are unaffected).
  """
  # Only the forward pass and the losses need to be recorded by the tapes.
  with tf.GradientTape() as gen_tape, tf.GradientTape() as discr_tape:
    output_image = generator(input_image, training=True)

    output_gen_discr = discriminator([output_image, input_image], training=True)
    output_target_discr = discriminator([target, input_image], training=True)

    discr_loss = discrimator_loss(output_target_discr, output_gen_discr)
    gen_loss = generator_loss(output_gen_discr, output_image, target)

  # Gradients and weight updates happen outside the tape context so that these
  # operations are not themselves recorded.
  generator_grads = gen_tape.gradient(gen_loss, generator.trainable_variables)
  discriminator_grads = discr_tape.gradient(discr_loss, discriminator.trainable_variables)

  generator_optimizer.apply_gradients(zip(generator_grads, generator.trainable_variables))
  discriminator_optimizer.apply_gradients(zip(discriminator_grads, discriminator.trainable_variables))

  return gen_loss, discr_loss

In [41]:
def train(dataset, epochs):
  """Train for `epochs` passes over `dataset`, previewing a test sample after each pass.

  Args:
      dataset: iterable of `(input_image, target)` batches passed to `train_step`.
      epochs: number of passes over `dataset`.

  Relies on the notebook-level `train_urls`, `test_dataset`, `generator` and
  `clear_output`.
  """
  total_steps = len(train_urls)
  for epoch in range(epochs):
    # enumerate() owns the step counter. Previously the preview loop below sat
    # inside this loop and reset the shared `imgi` counter after every step, so
    # the progress line never advanced and the `imgi == 113..116` skip could
    # never trigger; that dead branch has been removed.
    for step, (input_image, target) in enumerate(dataset, start=1):
      print ('epoch ' + str(epoch) + ' - train: ' + str(step) + '/' + str(total_steps))
      train_step (input_image, target)
      clear_output(wait=True)

    # Once per epoch: render and save predictions for the first test batch.
    for imgi, (inp, tar) in enumerate(test_dataset.take(1)):
      generate_images(generator, inp, tar, str(imgi) + '_' + str(epoch), display_imgs=True)

    # NOTE: no model checkpoint is saved here.
    

In [39]:
from IPython.display import clear_output

In [None]:
# Run the training loop over the training dataset for 25 epochs.
train(train_dataset, 25)

In [None]:
# Render and save predictions for the first 20 test batches.
# The index now advances per image (starting at 2, the original first index);
# before, `imgi` stayed at 2 and every prediction overwrote the same '2_300' file.
for imgi, (inp, tar) in enumerate(test_dataset.take(20), start=2):
  generate_images(generator, inp, tar, str(imgi) + '_' + str(300), display_imgs=True)