### Beginning

In [0]:
from google.colab import drive

# Every dataset, checkpoint, metric and image path used below lives under
# "/content/drive/My Drive/...", so Drive must be mounted at this location first.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [0]:
import tensorflow as tf
import numpy as np
import h5py, os, glob, re, datetime

In [0]:
# Root directories of the DIV2K HDF5 files (whole images vs. pre-cut patches).
hdf_directory = "/content/drive/My Drive/DL/datasets/DIV2K HDF/"
patched_hdf_directory = "/content/drive/My Drive/DL/datasets/DIV2K HDF PATCHED/"


def current_time():
    """Return the current local time as "MM/DD HH:MM:SS" (used in log lines)."""
    return datetime.datetime.now().strftime("%m/%d %H:%M:%S")

def get_hdf_dir(training, hsv, patching):
    """Return the directory holding the requested HDF5 dataset variant.

    The result is ``<root><color model><split>`` where
      * root is ``patched_hdf_directory`` when ``patching == True``,
        otherwise ``hdf_directory``;
      * color model is ``"RGB/"`` when ``hsv == False``, otherwise ``"HSV/"``;
      * split is ``"train/"`` when ``training == True``, otherwise ``"valid/"``.
    """
    root = hdf_directory
    if patching == True:
        root = patched_hdf_directory

    color_model = "HSV/"
    if hsv == False:
        color_model = "RGB/"

    split = "valid/"
    if training == True:
        split = "train/"

    return root + color_model + split
  
def batch_gen(training, hsv, epoch, batch):
    """Create patched minibatches / create nonpatched pairs "low-res"--"hi-res".


    Parameters
    ----------
        training: bool
            If true, "train" directory is used and the file order is shuffled;
            if false, "valid" directory is used in sorted file order.
        hsv: bool
            If true, the HSV color model is used; if false, the RGB color model is used.
        epoch: int
            If even, "bicubic" interpolation is used; if odd, "unknown" interpolation is used (for low-resolution images).
        batch: int
            The minibatch size (must be >= 1).


    Yields
    ------
        I. If "batch" is equal to 1 (non-patched HDF files):
            numpy array
                Low-resolution image with a leading axis of size 1.
            numpy array
                High-resolution image with a leading axis of size 1.
            int
                0-based position of the file in the (possibly shuffled) file list.

        II. If "batch" is greater than 1 (patched HDF files):
            numpy array
                "batch" low-resolution patches stacked along axis 0.
            numpy array
                The matching "batch" high-resolution patches.
            int
                1-based minibatch number, counted across all files.


    Raises
    ------
        ValueError
            If "batch" is smaller than 1.


    Notes
    -----
        Patches are shuffled within each file, and the trailing
        len(patches) % batch patches of a file are dropped.
    """
    if batch < 1:
        raise ValueError("batch must be >= 1, got {}".format(batch))

    # A minibatch of one is a whole image taken from the non-patched files;
    # larger minibatches are assembled from the pre-patched files.
    patching = batch != 1
    # Even epochs use the bicubic low-res inputs, odd epochs the "unknown" ones.
    lr_key = "lr_bicub" if epoch % 2 == 0 else "lr_unkn"

    hdfs = sorted(glob.glob(get_hdf_dir(training, hsv, patching) + "*.hdf5"))
    if training:
        np.random.shuffle(hdfs)

    batch_counter = 0
    for i, hdf_path in enumerate(hdfs):
        # Read each dataset once per file. Indexing `hdf[key][()]` per patch
        # re-reads the entire dataset from disk for every single patch.
        with h5py.File(hdf_path, "r") as hdf:
            lr_data = hdf[lr_key][()]
            hr_data = hdf["hr"][()]
            if patching:
                patched_shape = hdf["patched_shape"][()]
        # The file is already closed here, so no handle is held open while the
        # consumer runs a (possibly long) training step between yields.

        if not patching:
            yield np.expand_dims(lr_data, axis=0), np.expand_dims(hr_data, axis=0), i
            continue

        # Patches are stored flat along axis 0: patched_shape[0] * patched_shape[1] of them.
        patches_index = np.arange(patched_shape[0] * patched_shape[1])
        np.random.shuffle(patches_index)
        for j in range(len(patches_index) // batch):
            batch_counter += 1
            chosen = patches_index[j * batch:(j + 1) * batch]
            # Fancy indexing stacks the chosen patches along a new leading axis.
            yield lr_data[chosen], hr_data[chosen], batch_counter
      
def save_img(img, filename, hsv=False):
    """Build a TensorFlow op that writes ``img`` to ``filename`` as a PNG.

    Parameters
    ----------
        img: array-like or tf.Tensor
            A single image of rank 3 (height, width, channels) as required by
            ``tf.image.encode_png``. Float input is presumably in [0, 1]
            (the caller clips to that range).
        filename: str
            Destination path of the PNG file.
        hsv: bool
            If true, ``img`` is converted from HSV to RGB before encoding.

    Returns
    -------
        tf.Operation
            The file-write op. Nothing is written until it is run in a session.
    """
    if hsv:
        img = tf.image.hsv_to_rgb(img)
    # saturate=True clamps out-of-range values instead of letting them
    # over/underflow when scaled to uint8.
    img = tf.image.convert_image_dtype(img, tf.uint8, saturate=True)
    # Encode inside the graph. This avoids an eager .eval() (which silently
    # required a default session) and avoids baking the PNG bytes into the
    # graph as a new tf.constant on every call.
    return tf.write_file(filename, tf.image.encode_png(img))

### Architecture

In [0]:
# Channel width of the intermediate feature maps, and the number of residual
# blocks stacked by net().
FEATURE_MAPS_NUMBER, RES_BLOCKS_NUMBER = 64, 5

def conv_layer(x, W_shape, b_shape):
    """Apply a stride-1, SAME-padded 2-D convolution followed by a bias add.

    Gets variables named "W" and "b" in the current variable scope:
    a Xavier-initialised kernel of shape ``W_shape`` and a bias of shape
    ``b_shape`` initialised to 0.01.
    """
    kernel_init = tf.contrib.layers.xavier_initializer_conv2d()
    kernel = tf.get_variable("W", shape=W_shape, initializer=kernel_init)
    bias_init = tf.constant(0.01, shape=b_shape)
    bias = tf.get_variable("b", initializer=bias_init)
    unit_strides = [1, 1, 1, 1]
    features = tf.nn.conv2d(x, kernel, strides=unit_strides, padding="SAME")
    return tf.nn.bias_add(features, bias)

def res_block(x, scope):
    """Residual block: conv -> ReLU -> conv, scaled by 0.1, plus the input.

    Both 5x5 convolutions map FEATURE_MAPS_NUMBER channels to
    FEATURE_MAPS_NUMBER channels and live in the sub-scopes "conv1" and
    "conv2" of the variable scope ``scope`` (opened with AUTO_REUSE).
    """
    kernel_shape = [5, 5, FEATURE_MAPS_NUMBER, FEATURE_MAPS_NUMBER]
    bias_shape = [FEATURE_MAPS_NUMBER]
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        identity = x
        with tf.variable_scope("conv1", reuse=tf.AUTO_REUSE):
            hidden = conv_layer(x, kernel_shape, bias_shape)
        hidden = tf.nn.relu(hidden)
        with tf.variable_scope("conv2", reuse=tf.AUTO_REUSE):
            hidden = conv_layer(hidden, kernel_shape, bias_shape)
        # Scale the residual branch down before it is added to the identity.
        scaled = hidden * 0.1
    return scaled + identity

In [0]:
def net(x):
    """Build the x2 super-resolution network on top of ``x``.

    Pipeline: 5x5 conv to FEATURE_MAPS_NUMBER maps -> RES_BLOCKS_NUMBER
    residual blocks -> 5x5 conv whose output, scaled by 0.1, is added to the
    first conv's output -> 5x5 conv to 2 * FEATURE_MAPS_NUMBER maps -> pixel
    shuffle (depth_to_space, factor 2) -> 5x5 conv to 3 channels.

    Parameters
    ----------
        x: tf.Tensor
            NHWC batch with 3 channels.

    Returns
    -------
        tf.Tensor
            NHWC batch with 3 channels and twice the height and width of ``x``
            (all convolutions are stride 1 with SAME padding).
    """
    # Pixel-shuffle factor. depth_to_space turns C channels into
    # C // upscale**2 channels at upscale times the spatial resolution.
    upscale = 2
    shuffle_in_channels = FEATURE_MAPS_NUMBER * 2
    # Integer division keeps the variable shape integral. A float such as
    # FEATURE_MAPS_NUMBER / 2 only happened to work because it was whole.
    shuffle_out_channels = shuffle_in_channels // (upscale * upscale)

    #First Conv
    with tf.variable_scope("first_conv", reuse=tf.AUTO_REUSE):
        conv1 = conv_layer(x, [5, 5, 3, FEATURE_MAPS_NUMBER], [FEATURE_MAPS_NUMBER])
    shortcut1 = conv1
    #ResBlocks Stack
    res_stack = tf.contrib.layers.repeat(conv1, RES_BLOCKS_NUMBER, res_block, scope="res_stack")
    #Conv after ResBlocks
    with tf.variable_scope("after_res_block", reuse=tf.AUTO_REUSE):
        conv2 = conv_layer(res_stack, [5, 5, FEATURE_MAPS_NUMBER, FEATURE_MAPS_NUMBER], [FEATURE_MAPS_NUMBER])
    #Shortcut around the whole residual stack (residual branch scaled by 0.1)
    res1 = 0.1 * conv2 + shortcut1
    #Conv before PixelShuffle
    with tf.variable_scope("before_shuffling", reuse=tf.AUTO_REUSE):
        conv3 = conv_layer(res1, [5, 5, FEATURE_MAPS_NUMBER, shuffle_in_channels], [shuffle_in_channels])
    #PixelShuffle; NHWC matches the channels-last layout of the x/y placeholders
    shuffled = tf.nn.depth_to_space(conv3, upscale, data_format='NHWC')
    #Conv to 3 channels
    with tf.variable_scope("last_conv", reuse=tf.AUTO_REUSE):
        conv4 = conv_layer(shuffled, [5, 5, shuffle_out_channels, 3], [3])
    return conv4

In [0]:
tf.reset_default_graph()

# Both placeholders are channels-last float batches with free batch, height
# and width and 3 channels. x is fed the low-resolution inputs and y the
# high-resolution targets.
image_batch_shape = [None, None, None, 3]
x = tf.placeholder(tf.float32, image_batch_shape)
y = tf.placeholder(tf.float32, image_batch_shape)
# Forward pass: the network's reconstruction of y from x.
y_net = net(x)

### Loss

In [0]:
l2_lambda = 0.01

# L2 penalty over the model's trainable variables (kernels and biases).
# tf.trainable_variables() excludes non-trainable state such as the Adam
# optimizer's accumulators. The GLOBAL_VARIABLES collection would fold those
# into the penalty if this cell were re-run after an optimizer had been created.
vars_ = tf.trainable_variables()
regularization = tf.multiply(tf.add_n([tf.nn.l2_loss(var) for var in vars_]), l2_lambda)

learning_rate = tf.placeholder(tf.float32)
# Mean absolute error between the reconstruction and the high-res target.
# This is the value logged during training; the optimizer minimises MAE + L2.
mae = tf.reduce_mean(tf.abs(tf.subtract(y_net, y)))
train_step = tf.train.AdamOptimizer(learning_rate).minimize(tf.add(mae, regularization))

### Training

In [0]:
TRAINING = True
HSV = False
LAST_EPOCH = 12
LEARNING_RATE = 0.0005
BATCHSIZE = 16

# 0 starts a new run; N > 0 resumes from checkpoint "<save_path>N.ckpt".
print("Choose your initial epoch:")
INITIAL_EPOCH = int(input())
print("That's a great choice!")

if HSV == False:
    save_path = "/content/drive/My Drive/DL/ckpt/SR_RGB_"
    log_path = "/content/drive/My Drive/DL/metrics/loss_rgb.txt"
else:
    save_path = "/content/drive/My Drive/DL/ckpt/SR_HSV_"
    log_path = "/content/drive/My Drive/DL/metrics/loss_hsv.txt"

print("Start:", current_time())

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(max_to_keep=1000)
    if INITIAL_EPOCH == 0:
        # A new run starts with an empty loss log.
        with open(log_path, "w") as txt_file:
            pass
        print("OK, it's a new training (without disconnect, please BlessRNG).\n")
    else:
        saver.restore(sess, "{}{}.ckpt".format(save_path, INITIAL_EPOCH))
        print("Continue training from epoch {} (without disconnect, please BlessRNG).\n".format(INITIAL_EPOCH))
    for epoch in range(INITIAL_EPOCH, LAST_EPOCH):
        batches_seen = 0
        # Open the log once per epoch instead of once per minibatch. Line
        # buffering still hands every line to the OS as soon as it is written,
        # so little is lost if the runtime disconnects.
        with open(log_path, "a", buffering=1) as txt_file:
            for img_x, img_y, num in batch_gen(TRAINING, HSV, epoch, BATCHSIZE):
                _, loss = sess.run([train_step, mae], feed_dict={x: img_x, y: img_y, learning_rate: LEARNING_RATE})
                line = "time: {}... epoch {}... batch {}... loss = {:.7f}".format(current_time(), epoch + 1, num, loss)
                txt_file.write(line + "\n")
                if num % 2000 == 0:
                    print(line)
                batches_seen += 1
        if batches_seen == 0:
            # Without this check, a wrong or unmounted dataset path would
            # silently save untrained checkpoints for every epoch.
            raise RuntimeError("batch_gen yielded no minibatches in epoch {}; check that the HDF5 dataset "
                               "directory exists and contains *.hdf5 files".format(epoch + 1))
        saver.save(sess, "{}{}.ckpt".format(save_path, epoch + 1))

print("\nEnd:", current_time())

### Result

In [0]:
TRAINING = False
HSV = False
BATCHSIZE = 1
print("Choose your restoring epoch:")
EPOCH = int(input())
# 0-based positions of the validation files to evaluate (logged as num + 1).
test_images = [8, 40, 52]

if HSV == False:
    ckpt_path = "/content/drive/My Drive/DL/ckpt/SR_RGB_"
    log_path = "/content/drive/My Drive/DL/metrics/psnr_ssim_rgb.txt"
    img_save_path = "/content/drive/My Drive/DL/test_img/RGB/"
else:
    ckpt_path = "/content/drive/My Drive/DL/ckpt/SR_HSV_"
    log_path = "/content/drive/My Drive/DL/metrics/psnr_ssim_hsv.txt"
    img_save_path = "/content/drive/My Drive/DL/test_img/HSV/"

print("Start:", current_time())

# Metric ops are built once and fed per image, rather than adding new
# constant/SSIM/PSNR nodes to the graph for every evaluated image.
# NOTE(review): with HSV=True the metrics are computed on HSV values;
# confirm that this is intended.
upsc_ph = tf.placeholder(tf.float32, [None, None, None, 3])
orig_ph = tf.placeholder(tf.float32, [None, None, None, 3])
ssim_op = tf.image.ssim(upsc_ph, orig_ph, max_val=1.0)
psnr_op = tf.image.psnr(upsc_ph, orig_ph, max_val=1.0)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(sess, "{}{}.ckpt".format(ckpt_path, EPOCH))
    # y_net (built with the placeholders) is already net(x). Calling net(x)
    # again would only duplicate the same ops on the shared variables.
    get_upsc_img = y_net
    if EPOCH == 1:
        with open(log_path, "w") as txt_file:
            pass
    remaining = set(test_images)
    # The epoch argument 1 (odd) selects the "lr_unkn" low-resolution inputs.
    for lr_img, orig_img, num in batch_gen(TRAINING, HSV, 1, BATCHSIZE):
        if num not in remaining:
            continue
        remaining.discard(num)
        upsc_img = np.clip(sess.run(get_upsc_img, feed_dict={x: lr_img}), 0, 1)
        ssim, psnr = sess.run([ssim_op, psnr_op], feed_dict={upsc_ph: upsc_img, orig_ph: orig_img})
        img_saving = save_img(upsc_img[0], "{}{}_epoch{}.png".format(img_save_path, num + 1, EPOCH), hsv=HSV)
        sess.run(img_saving)
        line = "time: {}... epoch {}... image {}... psnr {:.7f}... ssim {:.7f}".format(
            current_time(), EPOCH, num + 1, psnr[0], ssim[0])
        with open(log_path, "a") as txt_file:
            txt_file.write(line + "\n")
        print(line)
        if not remaining:
            # All requested images are done; skip loading the remaining validation files.
            break

print("End:", current_time())