In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [52]:
import os
import numpy as np
import random
import matplotlib.pyplot as plt
import tensorflow as tf
import cv2
from tqdm import tqdm

Dataset

In [51]:
root_path = '/Applications/ML projects/Image Colorisation/archive/indoorCVPR_09/Images'

In [55]:
def read_images(root_path):
  colored = []
  gray = []
  for FOLDER in tqdm(os.listdir(root_path)):
    if FOLDER == '.DS_Store':
      continue
    folder_path = os.path.join(root_path, FOLDER)
    for IMAGE in os.listdir(folder_path):
      try:
        image_path = os.path.join(folder_path, IMAGE)
        img = cv2.imread(image_path)
        img = cv2.resize(img, (128, 128))
        img = img.astype(np.float32)
        img = img / 255.0
        gray_image = np.mean(img, axis=2, keepdims=True)
        gray_image= np.concatenate([gray_image] * 3, axis=2)
        colored.append(img)
        gray.append(gray_image)
      except:
        continue
  return np.array(gray), np.array(colored)


In [56]:
gray, colored = read_images(root_path)

100%|██████████| 68/68 [01:21<00:00,  1.19s/it]


In [57]:
filename = 'gc_128.npz'
np.savez_compressed(filename, gray, colored)
print('Saved Dataset: ', filename)

Saved Dataset:  gc_128.npz


Model Architecture - MIRNet

In [58]:
from keras.layers import Add, GlobalAveragePooling2D, Conv2D, Concatenate, MaxPooling2D, UpSampling2D, Input, BatchNormalization, LeakyReLU, Activation
from keras import Model
from keras.optimizers import Adam

In [59]:
def selective_kernel_feature_fusion(multi_scale_feature1, multi_scale_feature2, multi_scale_feature3):
  channels = list(multi_scale_feature1.shape)[-1]
  combined_feature = Add()(
      [multi_scale_feature1, multi_scale_feature2, multi_scale_feature3]
  )
  gap = GlobalAveragePooling2D()(combined_feature)
  channel_wise_statistics = tf.reshape(gap, shape=(-1, 1, 1, channels))
  compact_feature_representation = Conv2D(
      filters=channels // 8, kernel_size=(1, 1), activation='relu'
  )(channel_wise_statistics)
  feature_descriptor1 = Conv2D(channels, kernel_size=(1, 1), activation='softmax')(compact_feature_representation)
  feature_descriptor2 = Conv2D(channels, kernel_size=(1, 1), activation='softmax')(compact_feature_representation)
  feature_descriptor3 = Conv2D(channels, kernel_size=(1, 1), activation='softmax')(compact_feature_representation)
  feature1 = multi_scale_feature1 * feature_descriptor1
  feature2 = multi_scale_feature2 * feature_descriptor2
  feature3 = multi_scale_feature3 * feature_descriptor3
  aggregate_feature = Add()([feature1, feature2, feature3])
  return aggregate_feature

In [60]:
def channel_attention_block(input_tensor):
  channels = list(input_tensor.shape)[-1]
  gap = GlobalAveragePooling2D()(input_tensor)
  feature_descriptor = tf.reshape(gap, shape=(-1, 1, 1, channels))
  feature_activations = Conv2D(
      filters=channels // 8, kernel_size=(1, 1), activation='relu'
  )(feature_descriptor)
  feature_activations = Conv2D(
      filters=channels, kernel_size=(1, 1), activation='sigmoid'
  )(feature_activations)
  return input_tensor * feature_activations

In [61]:
def spatial_attention_block(input_tensor):
  average_pooling = tf.reduce_mean(input_tensor, axis=-1)
  average_pooling = tf.expand_dims(average_pooling, axis=-1)
  max_pooling = tf.reduce_max(input_tensor, axis=-1)
  max_pooling = tf.expand_dims(max_pooling, axis=-1)
  concatenated = Concatenate(axis=-1)([average_pooling, max_pooling])
  feature_map = Conv2D(1, kernel_size=(1, 1))(concatenated)
  feature_map = tf.nn.sigmoid(feature_map)
  return input_tensor * feature_map

In [62]:
def dual_attention_unit_block(input_tensor):
  channels = list(input_tensor.shape)[-1]
  feature_map = Conv2D(
      channels, kernel_size=(3, 3), padding='same', activation='relu'
  )(input_tensor)
  feature_map = Conv2D(
      channels, kernel_size=(3, 3), padding='same'
  )(feature_map)

  channel_attention = channel_attention_block(feature_map)
  spatial_attention = spatial_attention_block(feature_map)
  concatenation = Concatenate(axis=-1)([channel_attention, spatial_attention])
  concatenation = Conv2D(channels, kernel_size=(1, 1))(concatenation)

  return Add()([input_tensor, concatenation])

In [63]:
def down_sampling_block(input_tensor):
  channels = list(input_tensor.shape)[-1]
  main_branch = Conv2D(channels, kernel_size=(1, 1), activation='relu')(input_tensor)
  main_branch = Conv2D(channels, kernel_size=(3, 3), padding='same', activation='relu')(main_branch)
  main_branch = MaxPooling2D()(main_branch)
  main_branch = Conv2D(channels * 2, kernel_size=(1, 1))(main_branch)
  skip_branch = MaxPooling2D()(input_tensor)
  skip_branch = Conv2D(channels * 2, kernel_size=(1, 1))(skip_branch)
  return Add()([main_branch, skip_branch])

In [64]:
def up_sampling_block(input_tensor):
  channels = list(input_tensor.shape)[-1]
  main_branch = Conv2D(channels, kernel_size=(1, 1), activation='relu')(input_tensor)
  main_branch = Conv2D(channels, kernel_size=(3, 3), padding='same', activation='relu')(main_branch)
  main_branch = UpSampling2D()(main_branch)
  main_branch = Conv2D(channels // 2, kernel_size=(1, 1))(main_branch)
  skip_branch = UpSampling2D()(input_tensor)
  skip_branch = Conv2D(channels // 2, kernel_size=(1, 1))(skip_branch)
  return Add()([main_branch, skip_branch])

In [65]:
def multi_scale_residual_block(input_tensor, channels):
  feature1 = input_tensor
  feature2 = down_sampling_block(feature1)
  feature3 = down_sampling_block(feature2)

  feature1_dau1 = dual_attention_unit_block(feature1)
  feature2_dau1 = dual_attention_unit_block(feature2)
  feature3_dau1 = dual_attention_unit_block(feature3)

  skff1 = selective_kernel_feature_fusion(
      feature1_dau1,
      up_sampling_block(feature2_dau1),
      up_sampling_block(up_sampling_block(feature3_dau1))
  )

  skff2 = selective_kernel_feature_fusion(
      down_sampling_block(feature1_dau1),
      feature2_dau1,
      up_sampling_block(feature3_dau1)
  )

  skff3 = selective_kernel_feature_fusion(
      down_sampling_block(down_sampling_block(feature1_dau1)),
      down_sampling_block(feature2_dau1),
      feature3_dau1
  )

  feature1_dau2 = dual_attention_unit_block(skff1)
  feature2_dau2 = up_sampling_block(dual_attention_unit_block(skff2))
  feature3_dau2 = up_sampling_block(up_sampling_block(dual_attention_unit_block(skff3)))

  skff_ = selective_kernel_feature_fusion(feature1_dau2, feature2_dau2, feature3_dau2)
  feature = Conv2D(channels, kernel_size=(3, 3), padding='same')(skff_)

  return Add()([input_tensor, feature])

In [66]:
def recursive_residual_block(input_tensor, msrb_count, channels):
  x = Conv2D(channels, kernel_size=(3, 3), padding='same')(input_tensor)
  for _ in range(msrb_count):
    x = multi_scale_residual_block(x, channels)
  x = Conv2D(channels, kernel_size=(3, 3), padding='same')(x)
  return Add()([input_tensor, x])

In [67]:
def MIRNet_Model(rrb_count, msrb_count, channels):
  input_tensor = Input(shape=(128, 128, 3))
  x = Conv2D(channels, kernel_size=(3, 3), padding='same')(input_tensor)
  for _ in range(rrb_count):
    x = recursive_residual_block(x, msrb_count, channels)
  x = Conv2D(3, kernel_size=(3, 3), padding='same')(x)
  output_tensor = Add()([input_tensor, x])
  return Model(input_tensor, output_tensor)

Build MIRNet Model

In [68]:
RRB_COUNT = 3
MSRB_COUNT = 2
CHANNELS = 64

In [69]:
MIRNet_model = MIRNet_Model(RRB_COUNT, MSRB_COUNT, CHANNELS)

Discriminator

In [70]:
def discriminator():
    mimg = Input(shape=(128, 128, 3))
    nmimg = Input(shape=(128, 128, 3))

    merged = Concatenate()([mimg, nmimg])

    d = Conv2D(64, (4, 4), strides=(2, 2), padding='same')(merged)
    d = LeakyReLU(alpha=0.2)(d)

    d = Conv2D(128, (4, 4), strides=(2, 2), padding='same')(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)

    d = Conv2D(256, (4, 4), strides=(2, 2), padding='same')(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)

    d = Conv2D(512, (4, 4), strides=(2, 2), padding='same')(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)

    d = Conv2D(512, (4, 4), strides=(2, 2), padding='same')(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)

    d = Conv2D(1, (4, 4), padding='same')(d)
    patch_out = Activation('sigmoid')(d)

    model = Model([mimg, nmimg], patch_out)

    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, loss_weights=[0.5])
    return model

In [71]:
discriminator_model = discriminator()

Model

In [72]:
def main_model(MIRNet_model, discriminator_model):
    for layer in discriminator_model.layers:
        if not isinstance(layer, BatchNormalization):
            layer.trainable = False

    source = Input(shape=(128, 128, 3))
    gen_out = MIRNet_model(source)
    dis_out = discriminator_model([source, gen_out])

    model = Model(source, [dis_out, gen_out])
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss=['binary_crossentropy', 'mae'], optimizer=opt, loss_weights=[1, 100])
    return model

In [73]:
model = main_model(MIRNet_model, discriminator_model)

Data

In [74]:
def load_real_samples(filename):
  data = np.load(filename)
  source, target = data['arr_0'], data['arr_1']
  return source, target

In [75]:
def generate_real_samples(dataset, n_samples, patch_shape):
  trainA, trainB = dataset
  ix = np.random.randint(0, trainA.shape[0], n_samples)
  X1, X2 = trainA[ix], trainB[ix]
  y = np.ones((n_samples, patch_shape, patch_shape, 1))
  return [X1, X2], y

In [76]:
def generate_fake_samples(MIRNet_model, samples, patch_shape):
  X = MIRNet_model.predict(samples)
  y = np.zeros((len(X), patch_shape, patch_shape, 1))
  return X, y

Training Process

In [77]:
def summarize_performance(step, MIRNet_model, dataset, n_samples=3):
    [X_realA, X_realB], _ = generate_real_samples(dataset, n_samples, 1)
    X_fakeB, _ = generate_fake_samples(MIRNet_model, X_realA, 1)

    for i in range(n_samples):
        plt.subplot(3, n_samples, 1 + i)
        plt.axis('off')
        plt.imshow(X_realA[i])

    for i in range(n_samples):
        plt.subplot(3, n_samples, 1 + n_samples + i)
        plt.axis('off')
        plt.imshow(X_fakeB[i])

    for i in range(n_samples):
        plt.subplot(3, n_samples, 1 + n_samples * 2 + i)
        plt.axis('off')
        plt.imshow(X_realB[i])

    filename1 = 'plot_%06d.png' % (step+1)
    plt.savefig(filename1)
    plt.close()

    filename2 = 'model_%06d.h5' % (step+1)
    MIRNet_model.save(filename2)
    print('>Saved: %s and %s' % (filename1, filename2))

In [81]:
def train(MIRNet_model, discriminator_model, model, dataset, n_epochs=10):
  n_patch = discriminator_model.output_shape[1]
  trainA, trainB = dataset
  n_steps = int(len(trainA))
  n_batch = 1

  for i in range(n_epochs):
    for j in range(n_steps):
      [X_realA, X_realB], y_real = generate_real_samples(dataset, n_batch, n_patch)
      X_fakeB, y_fake = generate_fake_samples(MIRNet_model, X_realA, n_patch)
      d_loss1 = discriminator_model.train_on_batch([X_realA, X_realB], y_real)
      d_loss2 = discriminator_model.train_on_batch([X_realA, X_fakeB], y_fake)
      g_loss, _, _ = model.train_on_batch(X_realA, [y_real, X_realB])
      print('>%d, d1[%.3f] d2[%.3f] g[%.3f]' % (j+1, d_loss1, d_loss2, g_loss))
    summarize_performance(i, MIRNet_model, dataset)

In [82]:
dataset = load_real_samples('gc_128.npz')
print('Loaded: ', dataset[0].shape, dataset[1].shape)
image_shape = dataset[0].shape[1:]

Loaded:  (15590, 128, 128, 3) (15590, 128, 128, 3)


In [83]:
train(MIRNet_model, discriminator_model, model, dataset)

>1, d1[0.428] d2[1.742] g[39.723]
>2, d1[0.298] d2[1.078] g[35.657]
>3, d1[0.325] d2[0.911] g[40.863]
>4, d1[0.417] d2[0.976] g[29.192]
>5, d1[0.408] d2[0.939] g[21.968]
>6, d1[0.493] d2[0.796] g[16.022]
>7, d1[0.538] d2[0.686] g[10.539]
>8, d1[0.597] d2[0.926] g[17.306]
>9, d1[0.604] d2[0.524] g[25.755]
>10, d1[0.438] d2[0.616] g[12.870]
>11, d1[0.577] d2[0.512] g[16.209]
>12, d1[0.482] d2[0.641] g[12.501]
>13, d1[0.466] d2[0.597] g[9.900]
>14, d1[0.511] d2[0.737] g[8.762]
>15, d1[0.490] d2[0.521] g[13.198]
>16, d1[0.537] d2[0.376] g[39.160]
>17, d1[0.500] d2[0.511] g[11.683]
>18, d1[0.536] d2[0.631] g[21.087]
>19, d1[0.535] d2[0.589] g[9.617]
>20, d1[0.524] d2[0.581] g[16.441]
>21, d1[0.551] d2[0.570] g[13.543]
>22, d1[0.474] d2[0.411] g[11.475]
>23, d1[0.449] d2[0.430] g[8.611]
>24, d1[0.557] d2[0.363] g[12.674]
>25, d1[0.435] d2[0.461] g[11.555]
>26, d1[0.517] d2[0.595] g[13.067]
>27, d1[0.455] d2[0.517] g[11.741]
>28, d1[0.414] d2[0.521] g[13.288]
>29, d1[0.499] d2[0.646] g[12.889

KeyboardInterrupt: 

In [84]:
summarize_performance(1, MIRNet_model, dataset)



Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


>Saved: plot_000002.png and model_000002.h5
