In [1]:
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Conv2D,DepthwiseConv2D,ReLU
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import RandomUniform
from tensorflow.keras import layers
from tensorflow.keras.utils import get_custom_objects
import numpy as np
import wandb
from wandb.keras import WandbCallback
print(tf.__version__)
print(tf.keras.__version__)

1.15.2
2.2.4-tf


In [2]:
print(tf.test.gpu_device_name())

/device:GPU:0


In [3]:
# wandb init
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mkrislara[0m (use `wandb login --relogin` to force relogin)


True

In [5]:
config=dict(
    learning_rate=0.0001,
    batch_size=2,
    epochs=5000,
    loss='MAE',
    Dataset='SRDataset',
    Model='SRHW',
    )


In [99]:
# Data Loading - load and convert to yuv extract y, normalize for Y and X
def scaling(image):
    image =image/255.0
    return image

def preprocess_input(inp,input_size,upscale):
    inp=tf.image.rgb_to_yuv(inp)
    last_dim=len(inp.shape)-1
    y,u,v=tf.split(inp,3,axis=last_dim)
    y=tf.transpose(y,[2,0,1])
    y=tf.expand_dims(y,axis=0)
    y=tf.image.resize_bicubic(y,input_size//upscale)
    return y

def preprocess_target(inp):
    inp=tf.image.rgb_to_yuv(inp)
    last_dim=len(inp.shape)-1
    y,u,v=tf.split(inp,3,axis=last_dim)
    print('ytar shape:',y.shape)
    return tf.expand_dims(y,axis=0)

AUTOTUNE = tf.data.experimental.AUTOTUNE

def data_from_file(file_path):
    print(file_path)
    img=tf.io.read_file(file_path)
    img=tf.io.decode_image(img,3,dtype=tf.dtypes.float32,expand_animations=False)
    print(img.shape)
    return img
    
def get_data(data_dir,batch_size=2,img_size=128,upscale=2,shuffle=True):
    data_path=tf.data.Dataset.list_files(data_dir,shuffle=False)
    print(data_path)
    data=data_path.map(lambda x:data_from_file(x))
    data.batch(batch_size)
    data=data.map(lambda x:scaling(x))
    data=data.map(lambda x:(preprocess_input(x,input_size=img_size,upscale=upscale),preprocess_target(x)))
    data=data.prefetch(buffer_size=AUTOTUNE)
    return data

In [63]:
# Model definition
class SRHW(tf.keras.Model):
    def __init__(self,dim,upscale=2,quant=False):
        super(SRHW,self).__init__()
        self.Conv1=Conv2D(32,3,padding="same",
                          use_bias=False,data_format='channels_first',
                          kernel_initializer=RandomUniform(minval=-0.05, maxval=0.05, seed=None))
        self.DWConv1=DepthwiseConv2D((1,5),padding='same',
                                    data_format='channels_first',use_bias=False,
                                    depthwise_initializer=RandomUniform(minval=-0.05, maxval=0.05, seed=None))
        self.PWConv1=Conv2D(16,1,use_bias=False,padding='same',
                           data_format='channels_first',
                           kernel_initializer=RandomUniform(minval=-0.05, maxval=0.05, seed=None))
        self.DWConv2=DepthwiseConv2D((1,5),padding='same',
                                    data_format='channels_first',use_bias=False,
                                    depthwise_initializer=RandomUniform(minval=-0.05, maxval=0.05, seed=None))
        self.PWConv2=Conv2D(32,1,use_bias=False,padding='same',
                           data_format='channels_first',
                           kernel_initializer=RandomUniform(minval=-0.05, maxval=0.05, seed=None))
        self.DWConv3=DepthwiseConv2D(3,padding='same',
                                    data_format='channels_first',use_bias=False,
                                    depthwise_initializer=RandomUniform(minval=-0.05, maxval=0.05, seed=None))
        self.PWConv3=Conv2D(16,1,use_bias=False,padding='same',
                           data_format='channels_first',
                           kernel_initializer=RandomUniform(minval=-0.05, maxval=0.05, seed=None))
        self.DWConv4=DepthwiseConv2D(3,padding='same',
                                    data_format='channels_first',use_bias=False,
                                    depthwise_initializer=RandomUniform(minval=-0.05, maxval=0.05, seed=None))
        self.PWConv4=Conv2D(upscale**2,1,use_bias=False,padding='same',
                           data_format='channels_first',
                           kernel_initializer=RandomUniform(minval=-0.05, maxval=0.05, seed=None))
        self.upscale=upscale
        self.relu1=ReLU()
        self.relu2=ReLU()
        self.relu3=ReLU()
        self.relu4=ReLU()
        self.quant=quant
        self.dim=dim
        
    def call(self,x):
        x=self.Conv1(x)
        res=self.relu1(x)
        res=self.DWConv1(res)
        res=self.PWConv1(res)
        res=self.relu2(res)
        res=self.PWConv2(self.DWConv2(res))
        x=x+res
        x=self.relu3(x)
        x=self.relu4(self.PWConv3(self.DWConv3(x)))
        x=self.PWConv4(self.DWConv4(x))
        if self.quant:
            return x
        else:
            return tf.nn.depth_to_space(x,self.upscale,data_format='NCHW')
    def build_graph(self):
        x = tf.keras.layers.Input(shape=(self.dim))
        return tf.keras.Model(inputs=[x], outputs=self.call(x))

In [64]:
def create_model(input_dims,upscale,quant):
    dim = input_dims
    model = SRHW(dim,upscale=upscale,quant=quant)
    model.build((None, *dim))
    model.build_graph().summary()
    tf.keras.utils.plot_model(
        model.build_graph(), to_file='model.png', show_shapes=True, show_layer_names=True,
        rankdir='TB', expand_nested=False, dpi=96
    )
    return model

In [65]:
# create_model((1,64,64),True)

In [66]:
# callbacks wandb -save best model
callbacks=[
    WandbCallback(
        monitor='val_loss',
        mode='min',
        save_model=True,
        log_weights=True
        )
]

In [67]:
# metrics for compile
class PSNR(tf.keras.metrics.Metric):
    def __init__(self,name='PSNR',**kwargs):
        super(PSNR,self).__init__(name=name,**kwargs)
        self.psnr=self.add_weight(name='psnr', initializer='zeros')
    
    def update_state(self,y_true,y_pred,sample_weight=None):
        self.psnr.assign_add=tf.image.psnr(y_true,y_pred,max_val=1.0)
    
    def result(self):
        return self.psnr

class SSIM(tf.keras.metrics.Metric):
    def __init__(self,name='SSIM',**kwargs):
        super(SSIM,self).__init__(name=name,**kwargs)
        self.ssim=self.add_weight(name='ssim',initializer='zeros')
    
    def update_state(self,y_true,y_pred,sample_weight=None):
        y_true=tf.transpose(y_true,[0,2,3,1])
        y_pred=tf.transpose(y_pred,[0,2,3,1])
        self.ssim.assign_add=tf.image.ssim(y_true,y_pred,max_val=1.0)
    
    def result(self):
        return self.ssim
        

In [68]:
# Compile and fit (generator) with callbacks and metrics


In [69]:
# images_nhwc = tf.placeholder(tf.float32, [None, 200, 300, 3])  # input batch
# out = tf.transpose(images_nhwc, [0, 3, 1, 2])
# print(out.get_shape())

In [100]:
train_dir='/workspace/SRDataset/train/HR/'
val_dir='/workspace/SRDataset/val/HR/'
img_size=128
upscale=2

In [102]:
def train(config):
    #Consolidate all the above functions
#     wandb.init(config=config)
#     config=wandb.config
    train_ds=get_data(str(train_dir+'*.png'),batch_size=2,img_size=img_size,upscale=upscale,shuffle=True)
    val_ds=get_data(val_dir,batch_size=config.batch_size,img_height=img_height*2,img_width=img_width*2,upscale=upscale,shuffle=True)
#     model=get_model((1,img_height,img_width),upscale=upscale,quant=False)
#     optimizer=Adam(learning_rate=config.learning_rate,name='Adam')
#     model.compile(optimizer,loss=tf.keras.losses.MeanAbsoluteError(),metrics=[PSNR(),SSIM()])
#     model.fit(train_ds,epochs=config.epochs,callbacks=callbacks,validation_data=valid_ds,verbose=2)


In [103]:
train(config)

<DatasetV1Adapter shapes: (), types: tf.string>
Tensor("args_0:0", shape=(), dtype=string)
(?, ?, ?)


ValueError: in converted code:

    <ipython-input-99-719cd71e5d28>:37 None  *
        data=data.map(lambda x:(preprocess_input(x,input_size=img_size,upscale=upscale),preprocess_target(x)))
    <ipython-input-99-719cd71e5d28>:12 preprocess_input  *
        y=tf.image.resize_bicubic(y,input_size//upscale)
    /opt/vitis_ai/conda/envs/vitis-ai-tensorflow/lib/python3.6/site-packages/tensorflow_core/python/ops/image_ops_impl.py:3547 resize_bicubic
        name=name)
    /opt/vitis_ai/conda/envs/vitis-ai-tensorflow/lib/python3.6/site-packages/tensorflow_core/python/ops/gen_image_ops.py:3288 resize_bicubic
        half_pixel_centers=half_pixel_centers, name=name)
    /opt/vitis_ai/conda/envs/vitis-ai-tensorflow/lib/python3.6/site-packages/tensorflow_core/python/framework/op_def_library.py:794 _apply_op_helper
        op_def=op_def)
    /opt/vitis_ai/conda/envs/vitis-ai-tensorflow/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py:548 create_op
        compute_device)
    /opt/vitis_ai/conda/envs/vitis-ai-tensorflow/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:3426 _create_op_internal
        op_def=op_def)
    /opt/vitis_ai/conda/envs/vitis-ai-tensorflow/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:1770 __init__
        control_input_ops)
    /opt/vitis_ai/conda/envs/vitis-ai-tensorflow/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:1610 _create_c_op
        raise ValueError(str(e))

    ValueError: Shape must be rank 1 but is rank 0 for 'ResizeBicubic' (op: 'ResizeBicubic') with input shapes: [1,1,?,?], [].
