In [None]:
import os
import numpy as np
import import_ipynb
from tqdm import tqdm as tqdm
from original_reccnn import RecCNN
from comcnn import ComCNN

import matplotlib.image as mpimg
from PIL import Image
import cv2

from skimage.measure import compare_ssim as ssim

In [None]:
def compressThis(x_input, quality_factor):
    """JPEG-compress every image in a batch and return the decoded images.

    Args:
        x_input: batch of images indexed along axis 0 (per the original note,
            shaped (num, x, y, c)). Each image is rendered with
            ``matplotlib.image.imsave`` first, so float images are presumably
            expected in [0, 1] -- TODO confirm for the ComCNN output.
        quality_factor: JPEG quality passed to PIL's encoder (``QF`` in the
            config cell).

    Returns:
        np.ndarray of the JPEG-decoded images stacked along axis 0.
    """
    import io  # local import: only needed for the in-memory buffers below

    decoded = []
    for single in x_input:
        # Render losslessly (PNG) before the real JPEG encode. The previous
        # version wrote this intermediate as a JPEG at the library's default
        # quality, so each image was compressed twice. In-memory buffers also
        # avoid leaving/clobbering fixed-name files in the working directory.
        png_buf = io.BytesIO()
        mpimg.imsave(png_buf, single, format='png')
        png_buf.seek(0)
        # matplotlib writes RGBA PNGs; JPEG has no alpha channel.
        rendered = Image.open(png_buf).convert('RGB')

        jpeg_buf = io.BytesIO()
        rendered.save(jpeg_buf, "JPEG", quality=quality_factor)
        jpeg_buf.seek(0)
        decoded.append(np.asarray(Image.open(jpeg_buf)))

    return np.array(decoded)

In [None]:
def extract_lr(X, scale):
    """Downsample every image in a batch by an integer factor.

    Args:
        X: image batch indexed as (num, height, width, ...).
        scale: integer downsampling factor applied to both spatial axes.

    Returns:
        np.ndarray of the resized images, each (height // scale, width // scale).
    """
    height, width = X.shape[1], X.shape[2]
    # cv2.resize takes dsize as (width, height) and requires integers;
    # true division would yield floats under Python 3 and be rejected.
    dsize = (width // scale, height // scale)
    lr = [cv2.resize(X[i], dsize) for i in tqdm(range(X.shape[0]))]
    return np.array(lr)

In [None]:
def load_images(dset_location, count=None):
    """Read up to ``count`` images from a directory as an RGB batch.

    Args:
        dset_location: directory whose entries are all image files.
        count: maximum number of images to load. Defaults to the module-level
            ``images_count`` from the config cell, so ``load_images(path)``
            behaves as before.

    Returns:
        np.ndarray stacking the images along axis 0, channels in RGB order.
        NOTE(review): the images presumably all share one size (hr_dim x
        hr_dim); otherwise np.array cannot build a regular 4-D batch.

    Raises:
        ValueError: if an entry cannot be decoded by OpenCV.
    """
    if count is None:
        count = images_count

    print("Extracting image locations..")
    # os.listdir order is filesystem-dependent; sorting makes the selected
    # subset (and therefore the later train/valid split) reproducible.
    images_location = [os.path.join(dset_location, name)
                       for name in tqdm(sorted(os.listdir(dset_location)))]

    print("Extracting images..")
    X = []
    for im_loc in tqdm(images_location[:count]):
        img = cv2.imread(im_loc)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError("Could not read image file: " + im_loc)
        # OpenCV decodes to BGR; convert to RGB for matplotlib/PIL consumers.
        X.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    return np.array(X)

In [None]:
def split_dataset(X, split_ratio):
    """Split a 4-D batch into a leading (train) and trailing (valid) part.

    The cut index is ``int(split_ratio * X.shape[0])`` (truncated toward
    zero). No shuffling is done: the first rows become the training part.

    Args:
        X: array with exactly four axes; split along axis 0.
        split_ratio: fraction of rows assigned to the training part.

    Returns:
        (x_train, x_valid), both views into ``X``.
    """
    cut = int(split_ratio * X.shape[0])
    trailing_axes = (slice(None),) * 3
    x_train = X[(slice(None, cut),) + trailing_axes]
    x_valid = X[(slice(cut, None),) + trailing_axes]
    return x_train, x_valid

In [None]:
from keras.layers import Input
from keras.callbacks import ModelCheckpoint
import tensorflow as tf
from keras.models import Model

def run():
    """Train RecCNN, then ComCNN, then alternate end-to-end with JPEG in the loop.

    All settings come from module-level names defined in the config cell
    (dset_location, split_ratio, scale, lr_dim, hr_dim, channels, filters, QF,
    perform_initial_training, initial_reccnn_epochs, initial_comcnn_epochs,
    iterations, num_of_epochs, reccnn_loss_fun, comcnn_loss_fun).

    Side effects: writes sample PNGs to the working directory and to
    ``final-phase/``, and best weights (by the ModelCheckpoint default
    monitor) to ``reccnn.weights.best.hdf5`` / ``comcnn.weights.best.hdf5``.
    """
    # ----- data loading and pre-processing -------------
    X = load_images(dset_location)
    x_train, x_valid = split_dataset(X, split_ratio)

    print("train images shape : " + str(x_train.shape))
    print("valid images shape : " + str(x_valid.shape))

    print("Generating LR images of train images..")
    x_train_lr = extract_lr(x_train, scale)
    print("Generating LR images of valid images..")
    x_valid_lr = extract_lr(x_valid, scale)

    # Scale pixel values (assumed 8-bit) to float32 in [0, 1].
    x_train = x_train.astype('float32') / 255
    x_valid = x_valid.astype('float32') / 255
    x_train_lr = x_train_lr.astype('float32') / 255
    x_valid_lr = x_valid_lr.astype('float32') / 255

    print("train LR images shape : " + str(x_train_lr.shape))
    print("valid LR images shape : " + str(x_valid_lr.shape))
    mpimg.imsave('sample_lr.png', x_train_lr[1])
    mpimg.imsave('sample_hr.png', x_train[1])

    #-------- rec-cnn and initial training ----------------
    inp = Input(shape=(lr_dim, lr_dim, channels))
    rec_cnn = RecCNN(c=channels)
    out = rec_cnn.model(filters=filters, scale=scale, h=lr_dim, w=lr_dim, inp=inp)
    model_reccnn = Model(inp, out)
    model_reccnn.compile(optimizer='adam', loss=reccnn_loss_fun, metrics=['accuracy'])
    checkpointr = ModelCheckpoint(filepath='reccnn.weights.best.hdf5', save_best_only=True, verbose=1)
    model_reccnn.summary()

    if perform_initial_training:
        # RecCNN alone: learn to reconstruct HR images from plain LR images.
        print('Entered into reccnn-initial training phase..')
        model_reccnn.fit(x=x_train_lr, y=x_train, validation_data=(x_valid_lr, x_valid), epochs=initial_reccnn_epochs, shuffle=True, verbose=1, batch_size=8, callbacks=[checkpointr])

        sample_valid_lr = x_valid_lr[:5]
        sample_valid = model_reccnn.predict(sample_valid_lr)
        mpimg.imsave('initial_reccnn_input.png', sample_valid_lr[1])
        mpimg.imsave('initial_reccnn_output.png', sample_valid[1])

    #---------com-cnn and initial training -------------
    inp = Input(shape=(hr_dim, hr_dim, channels))
    com_cnn = ComCNN(c=channels)
    model_comcnn = Model(inp, model_reccnn(com_cnn.compact(inp)))
    # NOTE(review): assumes layers[4] is the nested model_reccnn, so fitting
    # model_comcnn updates only ComCNN. The flag is set before compile because
    # Keras applies `trainable` changes at compile time. model_reccnn itself
    # was compiled earlier; whether its later fit() still updates weights
    # after this flip depends on the Keras version -- confirm.
    model_comcnn.layers[4].trainable = False
    model_comcnn.compile(optimizer='adam', loss=comcnn_loss_fun, metrics=['accuracy'])
    checkpointc = ModelCheckpoint(filepath='comcnn.weights.best.hdf5', save_best_only=True, verbose=1)
    model_comcnn.summary()

    if perform_initial_training:
        # ComCNN + frozen RecCNN as an autoencoder on HR images.
        print('Entered into comcnn-initial training phase..')
        model_comcnn.fit(x=x_train, y=x_train, validation_data=(x_valid, x_valid), epochs=initial_comcnn_epochs, shuffle=True, verbose=1, batch_size=8, callbacks=[checkpointc])

        sample_valid_hr = x_valid[:5]
        sample_valid = model_comcnn.predict(sample_valid_hr)
        mpimg.imsave('initial_hr.png', sample_valid_hr[1])
        mpimg.imsave('after_intial_comcnn_hr.png', sample_valid[1])

    #------- end-to-end training ----------
    print("Entered into final training phase..")
    # Sub-model producing the compact representation xm. It shares layers (and
    # therefore weights) with model_comcnn, so building it once is enough.
    # NOTE(review): assumes layers[3] is the last ComCNN layer -- TODO confirm.
    upto_comcnn = Model(model_comcnn.input, model_comcnn.layers[3].output)
    # mpimg.imsave does not create directories; without this the first
    # iteration would crash after all initial training has already run.
    os.makedirs('final-phase', exist_ok=True)
    for i in tqdm(range(iterations)):
        # calculating xm using comcnn
        xm = upto_comcnn.predict(x_train)
        xm_valid = upto_comcnn.predict(x_valid)
        mpimg.imsave('final-phase/'+str(i)+'-xm.png', xm[0])
        # JPEG round trip of xm; decoded 8-bit pixels rescaled to [0, 1].
        xm = compressThis(xm, QF).astype('float32') / 255
        xm_valid = compressThis(xm_valid, QF).astype('float32') / 255
        mpimg.imsave('final-phase/'+str(i)+'-xm_compress.png', xm[0])

        # final training phase: RecCNN on JPEG-decoded xm, then ComCNN through
        # the (frozen inside model_comcnn) RecCNN.
        model_reccnn.fit(x=xm, y=x_train, validation_data=(xm_valid, x_valid),epochs=num_of_epochs, shuffle=True, verbose=1, batch_size=8, callbacks=[checkpointr])
        model_comcnn.fit(x=x_train, y=x_train, validation_data=(x_valid, x_valid),epochs=num_of_epochs, shuffle=True, verbose=1, batch_size=8, callbacks=[checkpointc])


In [None]:
# ----- data -------------------------
# Override with the DSET_LOCATION environment variable so the notebook is not
# tied to one machine's home directory; the default is the original path.
dset_location = os.environ.get(
    'DSET_LOCATION',
    '/home/titanxpascal/Documents/sem6proj/img-compression/Subset16k/Subset16k')
images_count = 2500   # at most this many images are loaded from dset_location
split_ratio = 0.8     # fraction of loaded images used for training
QF = 10               # JPEG quality applied to the compact representation

# ----- initial training phase ---------
perform_initial_training = True
initial_reccnn_epochs = 5
initial_comcnn_epochs = 5

# ----- final training phase ----------
iterations = 50       # alternating RecCNN/ComCNN rounds
num_of_epochs = 2     # epochs per model per round

# ----- shared -----------------------
channels = 3          # image channels for both RecCNN and ComCNN

# ----- rec-cnn parameters -----------
filters = 64
scale = 4
lr_dim = 64
reccnn_loss_fun = 'mean_squared_error'

# ----- com-cnn parameters -----------
hr_dim = 256
comcnn_loss_fun = 'mean_squared_error'

# RecCNN maps lr_dim inputs up by `scale` and is trained against hr_dim
# targets, so the sizes must agree.
assert hr_dim == lr_dim * scale, "hr_dim must equal lr_dim * scale"