# Imports

In [12]:
#Data manipulation 
import numpy as np
import pandas as pd

#Image processing / manipulation
from skimage.color import rgb2lab, lab2rgb
import cv2

#Deep learning libraries
import tensorflow as tf
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras import applications
from tensorflow.keras.layers import Conv2D, Activation,Reshape, UpSampling2D,Conv2DTranspose
from tensorflow.keras.layers import Layer
from tensorflow.keras import backend as K
from tensorflow import keras
from tensorflow.keras import layers

#Utility libraries
import os
from tqdm import tqdm

# Helper functions

In [28]:

def normalize(input_image, real_image):
  """Scale an (input, target) image pair by dividing both by 256.

  Args:
    input_image: Image data that supports division by a scalar.
    real_image: Target image data that supports division by a scalar.

  Returns:
    The tuple ``(input_image / 256, real_image / 256)``.
  """
  divisor = 256
  return input_image / divisor, real_image / divisor


def deprocess(imgs):
    """Convert unit-scaled image data to 8-bit pixel values.

    The input is multiplied by 255, saturated to the range [0, 255] and cast
    to ``uint8`` (the cast drops any fractional part).

    Args:
        imgs: Array-like or tensor that supports ``* 255`` and conversion
            through ``np.array``.

    Returns:
        A new ``np.uint8`` array with the same shape as the scaled input.
    """
    scaled = np.array(imgs * 255)
    return np.clip(scaled, 0, 255).astype(np.uint8)

def read_img(img):
    """Extract the Lab lightness plane of every image in a batch.

    Each image is resized to 224x224 and converted with OpenCV's
    ``COLOR_BGR2Lab`` (so the channels are interpreted as B, G, R); only
    channel 0 (L) is kept, with a trailing singleton channel axis.

    Args:
        img: Batch object exposing ``.numpy()``; iterating the result must
            yield one image per step.

    Returns:
        NumPy array stacking the per-image lightness planes.
    """
    lightness_planes = [
        cv2.cvtColor(cv2.resize(frame, (224, 224)), cv2.COLOR_BGR2Lab)[:, :, 0][..., np.newaxis]
        for frame in img.numpy()
    ]
    return np.array(lightness_planes)


def reconstruct(batchX, predictedY):
    """Merge lightness and chroma batches and convert them to RGB.

    The two batches are joined along the last axis and every resulting
    image is converted with OpenCV's ``COLOR_Lab2RGB``.

    Args:
        batchX: Batch of lightness planes (first channel(s) of the Lab image).
        predictedY: Batch of chroma planes appended after ``batchX``.

    Returns:
        NumPy array holding one RGB image per batch entry.
    """
    lab_batch = np.concatenate((batchX, predictedY), axis=-1)
    return np.array([cv2.cvtColor(lab, cv2.COLOR_Lab2RGB) for lab in lab_batch])


#Essentials for global autoencoder

class FusionLayer(Layer):
    """Fuse a spatial feature map with a per-image embedding vector.

    The embedding is replicated at every spatial position of the feature
    map and concatenated to it along the channel axis (axis 3).
    """

    def call(self, inputs, mask=None):
        """Tile the embedding over the feature map and concatenate.

        Args:
            inputs: Pair ``(imgs, embs)`` where ``imgs`` is a 4-D feature map
                and ``embs`` a 2-D ``(batch, embedding_len)`` tensor.
            mask: Unused.

        Returns:
            Tensor with the channels of ``imgs`` followed by the tiled
            embedding.
        """
        imgs, embs = inputs
        # Read both spatial dimensions: the reshape target has to contain
        # exactly the dim1 * dim2 copies produced by K.repeat, which using
        # dim1 twice only guarantees for square feature maps.
        dim1, dim2 = imgs.shape[1], imgs.shape[2]
        # The batch size is taken dynamically so it may be unknown at build time.
        reshaped_shape = (tf.shape(imgs)[0], dim1, dim2, embs.shape[1])
        embs = K.repeat(embs, dim1 * dim2)
        embs = K.reshape(embs, tf.stack(reshaped_shape))
        return K.concatenate([imgs, embs], axis=3)

    def compute_output_shape(self, input_shapes):
        """Return the feature-map shape with the embedding length added to the channels."""
        # Must have 2 tensors as input
        assert input_shapes and len(input_shapes) == 2
        imgs_shape, embs_shape = input_shapes

        # The batch size of the two tensors must match
        assert imgs_shape[0] == embs_shape[0]

        # (batch_size, width, height, embedding_len + depth)
        return imgs_shape[:3] + (imgs_shape[3] + embs_shape[1],)

def getGlobal_encoder(model_input):
  """Attach a frozen ImageNet InceptionResNetV2 to ``model_input``.

  The input tensor is concatenated with itself three times (presumably to
  turn a single-channel image into the three channels the backbone expects;
  ``model()`` passes a 256x256x1 input) and fed to InceptionResNetV2 built
  with its classification head (``include_top=True``).

  Args:
    model_input: Keras tensor to use as the backbone input.

  Returns:
    The output tensor of the backbone, whose weights are marked non-trainable.
  """
  from tensorflow.keras.applications.inception_resnet_v2 import InceptionResNetV2

  stacked = layers.Concatenate()([model_input] * 3)
  backbone = InceptionResNetV2(input_tensor=stacked,
                               include_top=True,
                               weights='imagenet')
  backbone.trainable = False
  return backbone.output

#Model architectures (trained weights are loaded into these in a later cell)
def model():
    """Build the autoencoder with a fused global-feature branch.

    A 256x256x1 input goes through a convolutional encoder (three stride-2
    stages) and, in parallel, through ``getGlobal_encoder``. ``FusionLayer``
    joins the two, and a decoder with three 2x upsampling steps produces a
    two-channel ``tanh`` output.

    Returns:
        An uncompiled Keras ``Model`` mapping the input to the decoder output.
    """
    from tensorflow.keras import Input

    grayscale_in = Input(shape=(256, 256, 1))

    # Encoder: (filters, stride) for each 3x3 ReLU convolution, in order.
    encoder_spec = (
        (64, 2),
        (128, 1),
        (128, 2),
        (256, 1),
        (256, 2),
        (512, 1),
        (512, 1),
    )
    features = grayscale_in
    for n_filters, stride in encoder_spec:
        features = Conv2D(n_filters, (3, 3), activation='relu', padding='same',
                          strides=stride)(features)

    # Global branch, built after the encoder so layer creation order is kept.
    global_embedding = getGlobal_encoder(grayscale_in)

    # Fusion of local features with the global embedding.
    x = FusionLayer()([features, global_embedding])
    x = tf.keras.layers.Conv2D(256, (3, 3), activation='relu', padding='same')(x)

    # Decoder: two conv + upsample stages, then the output head.
    for n_filters in (128, 64):
        x = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(x)
        x = UpSampling2D((2, 2))(x)
    for n_filters in (32, 16):
        x = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(x)
    x = Conv2D(2, (3, 3), activation='tanh', padding='same')(x)
    x = UpSampling2D((2, 2))(x)
    return Model(grayscale_in, x)


def downsample(filters, size, apply_batchnorm=True):
  """Build a stride-2 convolution block.

  Args:
    filters: Number of convolution filters.
    size: Convolution kernel size.
    apply_batchnorm: Whether to insert batch normalisation after the
      convolution.

  Returns:
    A ``tf.keras.Sequential`` of Conv2D (stride 2, no bias), optional
    BatchNormalization, and LeakyReLU.
  """
  block = tf.keras.Sequential()

  stages = [
      tf.keras.layers.Conv2D(
          filters, size, strides=2, padding='same', use_bias=False,
          kernel_initializer=tf.random_normal_initializer(0., 0.02)),
  ]
  if apply_batchnorm:
    stages.append(tf.keras.layers.BatchNormalization())
  stages.append(tf.keras.layers.LeakyReLU())

  for stage in stages:
    block.add(stage)
  return block

def upsample(filters, size, apply_dropout=False):
  """Build a stride-2 transposed-convolution block.

  Args:
    filters: Number of transposed-convolution filters.
    size: Kernel size.
    apply_dropout: Whether to insert ``Dropout(0.5)`` after batch
      normalisation.

  Returns:
    A ``tf.keras.Sequential`` of Conv2DTranspose (stride 2, no bias),
    BatchNormalization, optional Dropout, and ReLU.
  """
  block = tf.keras.Sequential()

  stages = [
      tf.keras.layers.Conv2DTranspose(
          filters, size, strides=2, padding='same', use_bias=False,
          kernel_initializer=tf.random_normal_initializer(0., 0.02)),
      tf.keras.layers.BatchNormalization(),
  ]
  if apply_dropout:
    stages.append(tf.keras.layers.Dropout(0.5))
  stages.append(tf.keras.layers.ReLU())

  for stage in stages:
    block.add(stage)
  return block
  


def Generator():
  """Build the encoder-decoder generator with skip connections.

  Eight ``downsample`` blocks reduce a 256x256x3 input, seven ``upsample``
  blocks expand it again, each concatenated with the matching encoder
  output, and a final stride-2 transposed convolution with ``tanh``
  produces a three-channel output.

  Returns:
    An uncompiled ``tf.keras.Model``.
  """
  inputs = tf.keras.layers.Input(shape=[256, 256, 3])

  # Encoder: the first block has no batch norm; each block halves the
  # spatial size (256 -> 1 over the eight blocks).
  encoder_blocks = [downsample(64, 4, apply_batchnorm=False)]
  encoder_blocks += [downsample(width, 4)
                     for width in (128, 256, 512, 512, 512, 512, 512)]

  # Decoder: the first three blocks use dropout; each block doubles the
  # spatial size (1 -> 128 over the seven blocks).
  decoder_blocks = [upsample(512, 4, apply_dropout=True) for _ in range(3)]
  decoder_blocks += [upsample(width, 4) for width in (512, 256, 128, 64)]

  # Output head: 128 -> 256, three channels.
  last = tf.keras.layers.Conv2DTranspose(
      3, 4,
      strides=2,
      padding='same',
      kernel_initializer=tf.random_normal_initializer(0., 0.02),
      activation='tanh')

  x = inputs

  # Run the encoder, remembering every intermediate output.
  skips = []
  for down in encoder_blocks:
    x = down(x)
    skips.append(x)

  # The last encoder output is the bottleneck itself, so it is not a skip;
  # the rest are consumed deepest-first.
  for up, skip in zip(decoder_blocks, skips[-2::-1]):
    x = up(x)
    x = tf.keras.layers.Concatenate()([x, skip])

  return tf.keras.Model(inputs=inputs, outputs=last(x))



def colorization_model():
    """Build the two-output colourisation network on top of VGG16.

    A truncated ImageNet VGG16 (up to ``layers[-6]``) feeds two branches: a
    global branch, reduced to a 256-vector tiled over a 28x28 grid plus a
    1000-way softmax head, and a mid-level convolutional branch. The fused
    features are decoded through three 2x upsampling steps into a
    two-channel sigmoid output.

    Returns:
        An uncompiled Keras ``Model`` taking a 224x224x3 input and returning
        ``[colour_output, class_output]``.
    """
    input_img = Input((224, 224, 3))

    # VGG16 without top layers, cut at layers[-6]
    vgg = applications.vgg16.VGG16(weights='imagenet', include_top=False,
                                   input_shape=(224, 224, 3))
    vgg_trunk = Model(vgg.input, vgg.layers[-6].output)
    shared = vgg_trunk(input_img)

    def conv_bn(tensor, n_filters, stride):
        """Apply a 3x3 ReLU convolution followed by batch normalisation."""
        tensor = keras.layers.Conv2D(n_filters, (3, 3), padding='same',
                                     strides=(stride, stride),
                                     activation='relu')(tensor)
        return keras.layers.BatchNormalization()(tensor)

    # Global features: four conv+BN stages with strides 2, 1, 2, 1.
    global_features = shared
    for stride in (2, 1, 2, 1):
        global_features = conv_bn(global_features, 512, stride)

    # Global embedding, tiled so it can be concatenated with 28x28 features.
    embedding = keras.layers.Flatten()(global_features)
    for units in (1024, 512, 256):
        embedding = keras.layers.Dense(units)(embedding)
    embedding = keras.layers.RepeatVector(28 * 28)(embedding)
    embedding = keras.layers.Reshape((28, 28, 256))(embedding)

    # Classification head on the same global features.
    class_head = keras.layers.Flatten()(global_features)
    for units in (4096, 4096):
        class_head = keras.layers.Dense(units)(class_head)
    class_head = keras.layers.Dense(1000, activation='softmax')(class_head)

    # Mid-level features
    midlevel = shared
    for n_filters in (512, 256):
        midlevel = conv_bn(midlevel, n_filters, 1)

    # Fusion of (VGG16 + mid-level) with (VGG16 + global)
    fused = keras.layers.concatenate([midlevel, embedding])

    def conv(tensor, n_filters, kernel, activation='relu'):
        """Apply a stride-1 'same' convolution."""
        return keras.layers.Conv2D(n_filters, kernel, padding='same',
                                   strides=(1, 1),
                                   activation=activation)(tensor)

    # Fusion + colourisation decoder
    colour = conv(fused, 256, (1, 1))
    colour = conv(colour, 128, (3, 3))

    colour = keras.layers.UpSampling2D(size=(2, 2))(colour)
    colour = conv(colour, 64, (3, 3))
    colour = conv(colour, 64, (3, 3))

    colour = keras.layers.UpSampling2D(size=(2, 2))(colour)
    colour = conv(colour, 32, (3, 3))
    colour = conv(colour, 2, (3, 3), activation='sigmoid')
    colour = keras.layers.UpSampling2D(size=(2, 2))(colour)

    return Model(input_img, [colour, class_head])


#For colourising images
def colouriseMethod1or2(image,model):
    """Colourise one RGB frame with a lightness-to-chroma model.

    The frame's Lab lightness is kept at its original resolution; a 256x256
    copy of the lightness (divided by 100) is passed to ``model``, whose
    two-channel output is resized back, multiplied by 128 and used as the
    a/b channels.

    Args:
        image: Single RGB image of shape (height, width, channels);
            presumably float values in [0, 1] as produced by the video loop.
        model: Callable mapping a (1, 256, 256, 1) lightness batch to a
            two-channel chroma batch.

    Returns:
        The colourised RGB image returned by ``lab2rgb`` for the frame.
    """
    batch = np.array([image])
    _, height, width, _ = batch.shape

    # Lightness at native resolution; kept on rgb2lab's own [0, 100] scale.
    lightness = rgb2lab(batch)[:, :, :, :1]

    resized = tf.image.resize(batch, [256, 256],
                              method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    #Lab mapping for model 1 & 2
    model_input = rgb2lab(np.asarray(resized))[:, :, :, :1] / 100

    pred = model(model_input)
    pred = tf.image.resize(pred, [height, width],
                           method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    # Assemble Lab with NumPy in a single dtype: rgb2lab may return float64
    # while the network output is float32, and tf.concat refuses to mix them.
    chroma = np.asarray(pred).astype(lightness.dtype) * 128
    lab = np.concatenate([lightness, chroma], axis=-1)

    return lab2rgb(lab)[0]

def colouriseMethod3(image):
    """Colourise one RGB frame with the module-level ``model3`` generator.

    ``model3`` receives a 256x256 copy of the frame and returns an RGB
    prediction; only the a/b channels of that prediction are used, resized
    back and combined with the frame's own Lab lightness.

    Args:
        image: Single RGB image of shape (height, width, channels);
            presumably float values in [0, 1] as produced by the video loop.

    Returns:
        The colourised RGB image returned by ``lab2rgb`` for the frame.
    """
    batch = np.array([image])
    _, height, width, _ = batch.shape

    # Lightness at native resolution; kept on rgb2lab's own [0, 100] scale.
    lightness = rgb2lab(batch)[:, :, :, :1]

    resized = tf.image.resize(batch, [256, 256],
                              method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    pred = model3(resized)
    chroma = rgb2lab(np.asarray(pred))[:, :, :, 1:]
    chroma = tf.image.resize(chroma, [height, width],
                             method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    # Assemble Lab with NumPy in a single dtype so a float64 frame combined
    # with the float32 network output no longer trips tf.concat's dtype check
    # (same approach as colouriseMethod1or2).
    chroma = np.asarray(chroma).astype(lightness.dtype)
    lab = np.concatenate([lightness, chroma], axis=-1)

    return lab2rgb(lab)[0]

def colouriseMethod4(image):
    """Colourise one RGB frame with the module-level ``model4`` network.

    The network input is the frame's lightness (via ``read_img``) tiled to
    three channels; its first output is resized back to the frame size and
    combined with the frame's own lightness by ``reconstruct``.

    Args:
        image: Single RGB image of shape (height, width, channels);
            presumably float values in [0, 1] as produced by the video loop.

    Returns:
        The array returned by ``reconstruct`` for the one-frame batch (the
        batch axis is NOT removed, unlike the other colourise functions).
    """
    batch = np.array([image])
    _, height, width, _ = batch.shape
    lightness = rgb2lab(batch)[:, :, :, 0] / 100
    lightness = lightness.reshape(lightness.shape + (1,))
    resized = tf.image.resize(batch, [256, 256],
                              method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    # The frame is RGB (it is passed to rgb2lab above), but read_img converts
    # with COLOR_BGR2Lab. Reverse the channel axis first so the lightness fed
    # to the network is not computed with red and blue swapped.
    # NOTE(review): assumes model4 was trained on true lightness - confirm
    # against the training pipeline.
    bgr = tf.reverse(resized, axis=[-1])

    L2 = read_img(bgr) / 100
    pred, _ = model4(np.tile(L2, [1, 1, 1, 3]))
    pred = tf.image.resize(pred, [height, width],
                           method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    return reconstruct(deprocess(lightness), deprocess(pred))

# Load models

In [13]:
# Model 1 is restored from a complete saved model file; models 2-4 are rebuilt
# from the architecture functions above and then receive their trained weights.
model1 = tf.keras.models.load_model('/trained_models/model.h5')#Simple AutoEncoder
# Global AutoEncoder (used when MODEL == 1)
model2 = model()
model2.load_weights('/trained_models/model2.h5')
# Pix2Pix generator (used by colouriseMethod3)
model3 = Generator()
model3.load_weights('/trained_models/generator.h5')
# ChromaGAN generator (used by colouriseMethod4)
model4 = colorization_model()
model4.load_weights('/trained_models/chroma_generator.h5')

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5


# Video colourisation logic

The logic below uses OpenCV's VideoCapture and VideoWriter to process the video frame by frame. It loops through every frame, colourises it with whichever model is selected, and writes the result to the output video. The colourised video is saved as "colourised.mp4".

In [31]:
PATH = "footage.mp4" #provide a path to a video you want to colourise
MODEL = 3 #0 - Simple AutoEncoder | 1 - Global AutoEncoder | 2 - Pix2Pix | 3 - ChromaGAN

# One entry per MODEL value; each maps a float RGB frame to a colourised frame.
_COLOURISERS = {
    0: lambda rgb: colouriseMethod1or2(rgb, model1),
    1: lambda rgb: colouriseMethod1or2(rgb, model2),
    2: colouriseMethod3,
    3: colouriseMethod4,
}


def _to_bgr_uint8(rgb_frame):
    """Convert a colourised RGB frame to the 8-bit BGR image VideoWriter takes."""
    frame = np.asarray(rgb_frame)
    if frame.ndim == 4:
        # colouriseMethod4 returns a one-image batch; drop the batch axis.
        frame = frame[0]
    if frame.dtype != np.uint8:
        # Float frames come from lab2rgb (expected in [0, 1]). Saturate before
        # casting: a plain np.uint8(256 * x) wraps 1.0 around to 0.
        frame = np.clip(frame * 256, 0, 255).astype(np.uint8)
    # Convert after the uint8 cast so cvtColor never receives float64 input.
    return cv2.cvtColor(np.ascontiguousarray(frame), cv2.COLOR_RGB2BGR)


if MODEL not in _COLOURISERS:
    raise ValueError("MODEL must be one of {}, got {!r}".format(sorted(_COLOURISERS), MODEL))
colourise = _COLOURISERS[MODEL]

cap = cv2.VideoCapture(PATH)
if not cap.isOpened():
    # Fail loudly instead of silently writing an empty video.
    raise IOError("Could not open video: {}".format(PATH))

frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

# Keep the source frame rate so the output plays at the original speed;
# fall back to 20 fps when the container does not report one.
fps = cap.get(cv2.CAP_PROP_FPS)
if not fps > 0:
    fps = 20.0
frame_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
              int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))

fourcc = cv2.VideoWriter_fourcc('X', 'V', 'I', 'D')
out = cv2.VideoWriter('colourised.mp4', fourcc, fps, frame_size)
try:
    with tqdm(total=frames) as pbar:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes to BGR uint8; the colourisers take float RGB.
            image = cv2.cvtColor(np.array(frame), cv2.COLOR_BGR2RGB)
            image = image.astype("float32") / 256

            out.write(_to_bgr_uint8(colourise(image)))
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
            pbar.update(1)
finally:
    # Release everything, also when a frame fails part-way through
    cap.release()
    out.release()
    cv2.destroyAllWindows()

100%|██████████| 349/349 [06:04<00:00,  1.04s/it]
