In [None]:
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Conv2D,DepthwiseConv2D,ReLU,add
from tensorflow.keras.optimizers import Adam,SGD
from tensorflow.keras.initializers import RandomUniform
from tensorflow.keras import layers
from tensorflow.keras.utils import get_custom_objects
import numpy as np
import wandb
import pathlib
from IPython.display import display
from tensorflow.keras.preprocessing.image import array_to_img
from tensorflow.keras.preprocessing.image import img_to_array,load_img
from wandb.keras import WandbCallback
from tensorflow.keras import backend as K
import random
import math
import PIL
import glob
# Log the TensorFlow and bundled Keras versions so runs can be reproduced.
print(tf.__version__)
print(tf.keras.__version__)
# tf.enable_eager_execution()  # TF 1.x-only call, kept disabled (NOTE(review): assumes a TF 2.x runtime)

In [None]:
# Sanity check: report which GPU device (if any) TensorFlow sees before training.
print(tf.test.gpu_device_name())

In [None]:
# Authenticate with Weights & Biases; required before creating the sweep below.
wandb.login()

In [None]:
# W&B sweep definition: random search over learning rate and batch size,
# ranking runs by the lowest validation loss. Epochs and optimizer are fixed.
sweep_config = dict(
    name="SRHW_Sweep",
    method="random",
    metric=dict(name="val_loss", goal="minimize"),
    parameters=dict(
        learning_rate=dict(values=[0.0001, 0.00015, 0.0002]),
        batch_size=dict(values=[1, 2, 4, 8]),
        epochs=dict(value=5000),
        optimizer=dict(value="adam"),
    ),
)
sweep_id = wandb.sweep(sweep_config, entity="krislara", project="SRHW_TensorflowV1.1")

In [None]:
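# Data loading: decode images to float32, convert RGB to YUV and keep only the Y (luma) channel; model inputs (X) are additionally bicubic-downscaled by `upscale`, targets (Y) keep full resolution.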
def preprocess_input(inp,input_size,upscale):
    """Build the low-resolution model input from RGB image(s).

    Converts ``inp`` from RGB to YUV, keeps only the Y (luma) channel and
    bicubic-resizes it to ``(input_size // upscale, input_size // upscale)``.

    Args:
        inp: float RGB image tensor, channels last; a leading batch dimension
            is optional (``get_data`` passes a batch).
        input_size: side length of the high-resolution image.
        upscale: super-resolution factor.

    Returns:
        The resized Y channel (trailing channel dimension of size 1),
        clipped to [0, 1].
    """
    # Slice instead of tf.split: the U and V channels are never used.
    y = tf.image.rgb_to_yuv(inp)[..., :1]
    lr_side = input_size // upscale
    # NOTE(review): tf.image.resize does not antialias by default when
    # downsampling, while the evaluation code downsamples with PIL BICUBIC;
    # consider antialias=True to match — confirm before retraining.
    y = tf.image.resize(y, (lr_side, lr_side), method='bicubic')
    # Bicubic interpolation can overshoot the input range near edges. The
    # inference path feeds PIL uint8 pixels / 255, which always lie in
    # [0, 1], so clip here to train on the same value range.
    return tf.clip_by_value(y, 0.0, 1.0)

def preprocess_target(inp):
    """Return the full-resolution Y (luma) channel of RGB image(s) as the target.

    ``inp`` is a float RGB tensor with channels last (batched or not); the
    result keeps every leading dimension and has a trailing channel dim of 1.
    """
    yuv = tf.image.rgb_to_yuv(inp)
    return yuv[..., :1]

# Sentinel asking tf.data to choose buffer sizes at runtime; used for `prefetch` in `get_data`.
AUTOTUNE = tf.data.experimental.AUTOTUNE

def data_from_file(file_path):
    """Read one image file and decode it as a 3-channel float32 tensor.

    ``expand_animations=False`` makes GIFs decode to a single frame, so the
    result is always rank 3 (height, width, 3). With a float dtype, TF
    rescales integer pixel data to [0, 1].
    """
    raw_bytes = tf.io.read_file(file_path)
    return tf.io.decode_image(raw_bytes, channels=3, dtype=tf.dtypes.float32,
                              expand_animations=False)
    
def get_data(data_dir,batch_size=2,img_size=128,upscale=2,shuffle=True):
    """Build a batched tf.data pipeline of (low-res Y, high-res Y) pairs.

    Args:
        data_dir: glob pattern for the image files (e.g. ``".../HR/*"``).
        batch_size: images per batch. All images in a batch must share one
            spatial size, because batching happens before resizing.
        img_size: expected side length of the high-resolution images; inputs
            are resized to ``img_size // upscale`` (NOTE(review): the actual
            image size is not checked).
        upscale: super-resolution factor.
        shuffle: shuffle the file order (``list_files`` reshuffles each epoch).

    Returns:
        A prefetched ``tf.data.Dataset`` yielding ``(inputs, targets)``.
    """
    files = tf.data.Dataset.list_files(data_dir, shuffle=shuffle)
    # Decoding is the expensive per-element step; run it in parallel. tf.data
    # map is deterministic by default, so element order is unchanged.
    images = files.map(data_from_file, num_parallel_calls=AUTOTUNE)
    batches = images.batch(batch_size)
    pairs = batches.map(
        lambda x: (preprocess_input(x, input_size=img_size, upscale=upscale),
                   preprocess_target(x)),
        num_parallel_calls=AUTOTUNE,
    )
    return pairs.prefetch(buffer_size=AUTOTUNE)

In [None]:
#functional model
def get_model(upscale=2,quant=True):
    """Build the SRHW super-resolution network and print its summary.

    Channels-last, single-channel (luma) input of any spatial size:
    3x3 stem conv -> ReLU, a residual branch of two
    depthwise(1x5) -> pointwise -> ReLU stages added back to the stem, then
    depthwise(3x3) -> pointwise(16) -> ReLU -> depthwise(3x3) ->
    pointwise(upscale**2). No layer uses a bias.

    Args:
        upscale: super-resolution factor; sets the last conv's channel count
            (``upscale**2``) and the ``depth_to_space`` block size.
        quant: if True the model outputs the ``upscale**2``-channel tensor
            before pixel shuffle, otherwise the ``depth_to_space`` result.

    Returns:
        The ``tf.keras.Model``.
    """
    def uniform():
        # New initializer instance per layer, as when written inline per layer.
        return RandomUniform(minval=-0.5, maxval=0.5, seed=None)

    def depthwise(kernel_size):
        return DepthwiseConv2D(kernel_size, padding='same', use_bias=False,
                               depthwise_initializer=uniform())

    def pointwise(filters):
        return Conv2D(filters, 1, use_bias=False, padding='same',
                      kernel_initializer=uniform())

    inputs = tf.keras.Input(shape=(None, None, 1), name='LR')
    stem = Conv2D(32, (3, 3), padding="same", use_bias=False,
                  kernel_initializer=uniform())(inputs)
    stem = ReLU(trainable=False)(stem)

    # Residual branch, layers created in the same order as applied.
    branch = depthwise((1, 5))(stem)
    branch = pointwise(16)(branch)
    branch = ReLU(trainable=False)(branch)
    branch = depthwise((1, 5))(branch)
    branch = pointwise(32)(branch)
    branch = ReLU(trainable=False)(branch)
    features = add([stem, branch])

    # Tail: two depthwise/pointwise stages down to upscale**2 channels.
    features = depthwise(3)(features)
    features = pointwise(16)(features)
    features = ReLU(trainable=False)(features)
    features = depthwise(3)(features)
    features = pointwise(upscale ** 2)(features)

    shuffled = tf.nn.depth_to_space(features, upscale, data_format='NHWC', name='HR')
    model = tf.keras.Model(inputs=inputs, outputs=features if quant else shuffled)
    model.summary()
    return model
    

In [None]:
def PSNR(y_true, y_pred, max_pixel=1.0):
    """Peak signal-to-noise ratio in dB.

    The MSE is averaged over the whole batch (all images and pixels) before
    taking the log, so this is the PSNR of the batch MSE, not the mean of
    per-image PSNRs.

    Args:
        y_true: target tensor.
        y_pred: prediction tensor with the same shape as ``y_true``.
        max_pixel: largest possible pixel value (1.0 for images in [0, 1]).

    Returns:
        Scalar tensor ``10 * log10(max_pixel**2 / mse)``.
    """
    mse = K.mean(K.square(y_pred - y_true))
    # Floor the MSE so an exact match gives a large finite value instead of inf.
    mse = K.maximum(mse, K.epsilon())
    # log10(x) = ln(x) / ln(10); use the exact constant, not the rounded 2.303.
    return 10.0 * K.log((max_pixel ** 2) / mse) / math.log(10.0)

def SSIM(y_true, y_pred):
    """Mean structural similarity over the batch, for images with max value 1.0."""
    per_image_ssim = tf.image.ssim(y_true, y_pred, max_val=1.0)
    return tf.reduce_mean(per_image_ssim)

class PredictionCallback(tf.keras.callbacks.Callback):
    """Keras callback that periodically shows a super-resolved validation image.

    Every ``every_n_epochs`` epochs it picks a random PNG from ``val_dir``,
    downscales it by ``upscale`` with PIL bicubic resampling, runs its luma
    (Y of YCbCr) channel through the model and displays the first prediction
    inline with IPython ``display``.

    Args:
        val_dir: directory searched for ``*.png`` files (str or path-like).
        image_size: stored as ``self.size``; not used by the callback itself.
        upscale: factor the validation image is downscaled by before prediction.
        every_n_epochs: display period in epochs (default 100).
    """

    def __init__(self,val_dir,image_size,upscale,every_n_epochs=100):
        super(PredictionCallback,self).__init__()
        self.size=image_size
        self.upscale=upscale
        self.every_n_epochs=every_n_epochs
        self.img_files=glob.glob(str(val_dir)+"/*.png")

    def on_epoch_end(self,epoch,logs=None):
        # `epoch` is a 0-based index, so epoch+1 is the number of completed epochs.
        if (epoch+1)%self.every_n_epochs != 0:
            return
        if not self.img_files:
            # Nothing to preview; don't abort a long training run over it.
            return
        # Sample uniformly over all files (no hard-coded count, index 0 included).
        img=load_img(random.choice(self.img_files))
        img=img.resize(
            (img.size[0] // self.upscale, img.size[1] // self.upscale),
            PIL.Image.BICUBIC,
        )
        y, cb, cr = img.convert("YCbCr").split()
        y = tf.keras.preprocessing.image.img_to_array(y)
        y = y.astype("float32") / 255.0
        batch = np.expand_dims(y, axis=0)
        y_pred=self.model.predict(batch)
        display(array_to_img(y_pred[0]))

In [None]:
# Folders of high-resolution (HR) images for training and validation.
train_dir, val_dir = map(pathlib.Path, ('/workspace/SRDataset/train/HR/',
                                        '/workspace/SRDataset/val/HR/'))
# Expected HR side length for training images, and the super-resolution factor.
img_size, upscale = 128, 2

In [None]:
def train():
    """Run one W&B sweep trial: build datasets and model, then fit.

    Hyperparameters come from ``wandb.config`` (sweep values override
    ``config_defaults``). Uses the notebook-level ``train_dir``, ``val_dir``,
    ``img_size`` and ``upscale`` settings.

    Raises:
        ValueError: if ``config.optimizer`` is neither ``'adam'`` nor ``'sgd'``.
    """
    config_defaults=dict(
        learning_rate=0.0001,
        batch_size=4,
        epochs=5000,
        optimizer='adam',
        loss='MAE',
        Dataset='SRDataset',
        Model='SRHW',
    )
    wandb.init(config=config_defaults)
    config=wandb.config
    train_ds=get_data(str(train_dir/'*'),batch_size=config.batch_size,img_size=img_size,upscale=upscale,shuffle=True)
    # NOTE(review): fixed validation batch of 88 — presumably the whole
    # validation set in a single batch; confirm.
    valid_ds=get_data(str(val_dir/'*'),batch_size=88,img_size=img_size*2,upscale=upscale,shuffle=True)
    # Use the same factor as the data pipeline so depth_to_space output
    # matches the target resolution.
    model=get_model(upscale=upscale,quant=False)
    if config.optimizer=='adam':
        optimizer=Adam(learning_rate=config.learning_rate,name='Adam',clipnorm=1.0)
    elif config.optimizer=='sgd':
        optimizer=SGD(learning_rate=config.learning_rate,momentum=0.9,name='SGD',clipnorm=1.0)
    else:
        # Previously fell through and crashed with UnboundLocalError below.
        raise ValueError("Unsupported optimizer %r; expected 'adam' or 'sgd'" % (config.optimizer,))
    model.compile(optimizer,loss=tf.keras.losses.MeanAbsoluteError(),metrics=[PSNR,SSIM])
    callbacks=[
        WandbCallback(
            monitor='val_loss',
            mode='min',
            save_weights_only=True,
            log_weights=True
        ),
        PredictionCallback(val_dir,img_size*2,upscale)
    ]
    model.fit(train_ds,epochs=config.epochs,callbacks=callbacks,validation_data=valid_ds)

In [None]:
# Start a sweep agent that calls `train` once per hyperparameter set drawn by the sweep.
# NOTE(review): no `count` is passed; with the "random" method the agent presumably keeps launching runs until stopped.
wandb.agent(sweep_id,train)

In [None]:
# from https://keras.io/examples/vision/super_resolution_sub_pixel/
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset


model_path = '/workspace/SRHW/tf_ckpt/model1.h5'
# Custom metric functions passed to `model.compile`, keyed by name, for use as
# `custom_objects` when reloading a saved model.
dependencies = dict(PSNR=PSNR, SSIM=SSIM)

# trained_model=tf.keras.models.load_model(model_path,custom_objects=dependencies)

def plot_results(img, prefix, title,area):
    """Show ``img`` in a new matplotlib figure titled ``title``.

    ``prefix`` and ``area`` are accepted but unused. They were meant for a
    zoomed inset (``area`` = x1, x2, y1, y2 limits) and for saving the figure
    as ``"<prefix>-<title>.png"``; that code is currently disabled.
    """
    pixels = tf.keras.preprocessing.image.img_to_array(img).astype("float32") / 255.0
    fig, ax = plt.subplots()
    # Flip the rows and draw with origin="lower": the image appears upright
    # while the y-axis increases upwards.
    ax.imshow(pixels[::-1], origin="lower")
    ax.set_title(title)
    plt.show()

def get_lowres_image(img, upscale_factor):
    """Return ``img`` resized to 1/``upscale_factor`` of its size (bicubic).

    Width and height are floor-divided by ``upscale_factor``. The result is
    the low-resolution model input.
    """
    width, height = img.size
    target_size = (width // upscale_factor, height // upscale_factor)
    return img.resize(target_size, PIL.Image.BICUBIC)
    

def upscale_image(model, img):
    """Super-resolve a PIL image with ``model`` and return the result as RGB.

    Only the luma (Y of YCbCr) channel goes through the model. The Cb/Cr
    channels are bicubic-resized to the predicted size and merged back.

    Args:
        model: Keras model taking a (1, H, W, 1) float batch in [0, 1] and
            returning a single-channel image batch, i.e. built with
            ``quant=False``. A ``quant=True`` output (``upscale**2`` channels)
            cannot be reshaped to 2-D below and raises.
        img: low-resolution PIL image.

    Returns:
        PIL image in RGB mode.
    """
    ycbcr = img.convert("YCbCr")
    y, cb, cr = ycbcr.split()
    y = tf.keras.preprocessing.image.img_to_array(y)
    y = y.astype("float32") / 255.0

    batch = np.expand_dims(y, axis=0)
    out = model.predict(batch)

    out_img_y = out[0] * 255.0
    # Round before casting: np.uint8 truncates, which would bias every pixel
    # downward by up to one gray level.
    out_img_y = np.round(out_img_y.clip(0, 255))
    out_img_y = out_img_y.reshape((np.shape(out_img_y)[0], np.shape(out_img_y)[1]))
    out_img_y = PIL.Image.fromarray(np.uint8(out_img_y), mode="L")

    # Restore the image in RGB color space.
    out_img_cb = cb.resize(out_img_y.size, PIL.Image.BICUBIC)
    out_img_cr = cr.resize(out_img_y.size, PIL.Image.BICUBIC)
    out_img = PIL.Image.merge("YCbCr", (out_img_y, out_img_cb, out_img_cr)).convert(
        "RGB"
    )
    return out_img

# Qualitative check: super-resolve one random validation image and plot the
# low-res input, the high-res reference and the model output.
img_files=glob.glob(str(val_dir)+"/*.png")
# Sample uniformly over all files (no hard-coded count, index 0 included).
img=tf.keras.preprocessing.image.load_img(random.choice(img_files))
lr=get_lowres_image(img,upscale)
# Resize the reference to exactly `upscale` x the low-res size so it lines up
# with the model output when a side is not divisible by `upscale`.
w,h=lr.size[0]*upscale,lr.size[1]*upscale
hr=img.resize((w,h))
# `train()` keeps its model local, so on a fresh kernel no module-level
# `model` exists. Reuse one created interactively, otherwise load `model_path`.
# NOTE(review): WandbCallback in `train()` uses save_weights_only=True;
# confirm `model_path` holds a full saved model rather than weights only.
if "model" not in globals():
    model=tf.keras.models.load_model(model_path,custom_objects=dependencies)
sr=upscale_image(model,lr)
plot_results(lr,0, "lowres",(32, 96, 32, 96))
plot_results(hr, 0, "highres",(96, 160, 96, 160))
plot_results(sr, 0, "superres",(96, 160, 96, 160))

In [None]:
#dugout
# metrics for compile
# class PSNR(tf.keras.metrics.Metric):
#     def __init__(self,name='PSNR',**kwargs):
#         super(PSNR,self).__init__(name=name,**kwargs)
#         self.psnr=self.add_weight(name='PSNR', initializer='zeros')
    
#     def update_state(self,y_true,y_pred,sample_weight=None):
#         self.psnr.assign_add=tf.image.psnr(y_true,y_pred,max_val=1.0)
    
#     def result(self):
#         return self.psnr

# class SSIM(tf.keras.metrics.Metric):
#     def __init__(self,name='SSIM',**kwargs):
#         super(SSIM,self).__init__(name=name,**kwargs)
#         self.ssim=self.add_weight(name='SSIM',initializer='zeros')
    
#     def update_state(self,y_true,y_pred,sample_weight=None):
#         self.ssim.assign_add=tf.image.ssim(y_true,y_pred,max_val=1.0)
    
#     def result(self):
#         return self.ssim


# def create_model(input_dims,upscale,quant):
#     dim1,dim2= input_dims[1]//upscale,input_dims[2]//upscale 
#     model = SRHW((1,dim1,dim2),upscale=upscale,quant=quant)
#     model.build((None,1,dim1,dim2))
#     model.build_graph().summary()
#     tf.keras.utils.plot_model(
#         model.build_graph(), to_file='model.png', show_shapes=True, show_layer_names=True,
#         rankdir='TB', expand_nested=False, dpi=96
#     )
#     return model


# Model subclassing
# class SRHW(tf.keras.Model):
#     def __init__(self,upscale=2,quant=False,dim=(1,64,64)):
#         super(SRHW,self).__init__()
#         self.Conv1=Conv2D(32,3,padding="same",
#                           use_bias=False,data_format='channels_first',
#                           kernel_initializer=RandomUniform(minval=-0.05, maxval=0.05, seed=None))
#         self.DWConv1=DepthwiseConv2D((1,5),padding='same',
#                                     data_format='channels_first',use_bias=False,
#                                     depthwise_initializer=RandomUniform(minval=-0.05, maxval=0.05, seed=None))
#         self.PWConv1=Conv2D(16,1,use_bias=False,padding='same',
#                            data_format='channels_first',
#                            kernel_initializer=RandomUniform(minval=-0.05, maxval=0.05, seed=None))
#         self.DWConv2=DepthwiseConv2D((1,5),padding='same',
#                                     data_format='channels_first',use_bias=False,
#                                     depthwise_initializer=RandomUniform(minval=-0.05, maxval=0.05, seed=None))
#         self.PWConv2=Conv2D(32,1,use_bias=False,padding='same',
#                            data_format='channels_first',
#                            kernel_initializer=RandomUniform(minval=-0.05, maxval=0.05, seed=None))
#         self.DWConv3=DepthwiseConv2D(3,padding='same',
#                                     data_format='channels_first',use_bias=False,
#                                     depthwise_initializer=RandomUniform(minval=-0.05, maxval=0.05, seed=None))
#         self.PWConv3=Conv2D(16,1,use_bias=False,padding='same',
#                            data_format='channels_first',
#                            kernel_initializer=RandomUniform(minval=-0.05, maxval=0.05, seed=None))
#         self.DWConv4=DepthwiseConv2D(3,padding='same',
#                                     data_format='channels_first',use_bias=False,
#                                     depthwise_initializer=RandomUniform(minval=-0.05, maxval=0.05, seed=None))
#         self.PWConv4=Conv2D(upscale**2,1,use_bias=False,padding='same',
#                            data_format='channels_first',
#                            kernel_initializer=RandomUniform(minval=-0.05, maxval=0.05, seed=None))
#         self.upscale=upscale
#         self.relu1=ReLU()
#         self.relu2=ReLU()
#         self.relu3=ReLU()
#         self.relu4=ReLU()
#         self.quant=quant
#         self.dim=dim
        
#     def call(self,x):
#         x=self.Conv1(x)
#         res=self.relu1(x)
#         res=self.DWConv1(res)
#         res=self.PWConv1(res)
#         res=self.relu2(res)
#         res=self.PWConv2(self.DWConv2(res))
#         x=x+res
#         x=self.relu3(x)
#         x=self.relu4(self.PWConv3(self.DWConv3(x)))
#         x=self.PWConv4(self.DWConv4(x))
#         if self.quant:
#             return x
#         else:
#             return tf.nn.depth_to_space(x,self.upscale,data_format='NCHW')
#     def build_graph(self):
#         x = tf.keras.layers.Input(shape=(self.dim))
#         return tf.keras.Model(inputs=[x], outputs=self.call(x))
