In [1]:
!pip -q install keras-flops

[0m

In [2]:
import os
import numpy as np
import gc

#import keras
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import ModelCheckpoint,CSVLogger
from tensorflow.keras import layers as L
from tensorflow.keras.models import Sequential , Model
from tensorflow.keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D, Reshape, Dense, multiply, Permute, Concatenate, Conv2D, Add, Activation, Lambda
from tensorflow.keras.layers import *
import tensorflow_addons as tfa
import keras_flops
from keras_flops import get_flops

import warnings
# Silence Python warnings and TensorFlow's native (C++) log output.
# NOTE(review): TF_CPP_MIN_LOG_LEVEL is normally read when TensorFlow's native
# logging first initializes; because tensorflow is imported above, messages
# emitted during import may not be suppressed. Set it before the import to be sure.
warnings.filterwarnings(action="ignore")
os.environ.update(TF_CPP_MIN_LOG_LEVEL='3')

# Models

In [3]:
def keras2tflite(model, name, fp16=False, new_converter=False):
    """Convert a Keras model to a TFLite flatbuffer and write it to `name`.

    Args:
        model: a built ``tf.keras`` model.
        name: output path for the ``.tflite`` file.
        fp16: if True, apply post-training float16 quantization.
        new_converter: value assigned to ``experimental_new_converter``.
            Defaults to False (legacy converter), the behavior this helper has
            always had. The MLIR converter can mis-convert some architectures,
            so compare the outputs of both converters when in doubt.

    Returns:
        None. The converted model is written to ``name``.
    """
    print('converting...')
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    if fp16:
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        # float16 must be the only supported type: if tf.int8 is also listed,
        # the converter picks the smallest type (int8) and silently falls back
        # to dynamic-range int8 quantization instead of fp16.
        converter.target_spec.supported_types = [tf.float16]

    # If conversion fails on unsupported ops, consider:
    #   converter.target_spec.supported_ops = [
    #       tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
    converter.experimental_new_converter = new_converter
    tflite_model = converter.convert()
    # Context manager guarantees the file handle is flushed and closed.
    with open(name, "wb") as f:
        f.write(tflite_model)
    print('saved!')

In [5]:
# (H, W, C) constants; IMG_* is the default input shape of `build_ours`.
IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS = 400, 400, 3
HR_HEIGHT,  HR_WIDTH,  HR_CHANNELS  = 1080, 1920, 3
LR_HEIGHT,  LR_WIDTH,  LR_CHANNELS  = 256, 256, 3

# Activation applied by `convolution_block` when activation=True.
# Replace with e.g. tf.keras.layers.LeakyReLU() to change it for every block.
conv_activation = tf.keras.activations.elu

def convolution_block(x, filters, size, strides=(1,1), padding='same', activation=True, bn=False, dilation=1):
    """Conv2D, then optional BatchNormalization, then optional `conv_activation`.

    Args:
        x: input tensor (assumed channels-last; BatchNorm normalizes axis 3).
        filters: number of output channels.
        size: Conv2D kernel size.
        strides, padding: forwarded to Conv2D.
        activation: apply the module-level `conv_activation` when True.
        bn: insert BatchNormalization(axis=3) after the conv when True.
        dilation: forwarded to Conv2D as `dilation_rate`.

    Returns:
        The resulting tensor.
    """
    conv = tf.keras.layers.Conv2D(filters, size, strides=strides, padding=padding,
                                  dilation_rate=dilation)
    out = conv(x)
    if bn:
        out = tf.keras.layers.BatchNormalization(axis=3)(out)
    return conv_activation(out) if activation else out

def residual_subblock(blockInput,num_filters):
    """Two 3x3 conv+activation layers with an identity skip connection.

    `blockInput` must already have `num_filters` channels for the Add to work.
    """
    out = blockInput
    for _ in range(2):
        out = convolution_block(out, num_filters, (3,3), activation=True)
    return tf.keras.layers.Add()([out, blockInput])

def inverted_linear_residual_block(x, expand=64, squeeze=16):
    """Inverted residual: 1x1 expand (ELU), 3x3 depthwise (ELU), linear 1x1 squeeze, plus skip.

    `x` must have `squeeze` channels so that it can be added to the projection.
    """
    expanded = Conv2D(expand, (1,1), activation='elu', strides=(1,1), padding='same')(x)
    mixed = DepthwiseConv2D((3,3), activation='elu', strides=(1,1), padding='same')(expanded)
    projected = Conv2D(squeeze, (1,1))(mixed)
    return Add()([projected, x])

def inverted_proj_block(x, proj=16):
    """3x3 depthwise conv (ELU) followed by a linear 1x1 projection to `proj` channels; no skip."""
    depthwise = DepthwiseConv2D((3,3),  activation='elu', strides=(1,1), padding='same')
    pointwise = Conv2D(proj, (1,1))
    return pointwise(depthwise(x))

def CALayer(blockInput,num_filters):
    '''
    Dilated Attention Block (DAB).

    Three parallel 3x3 conv branches (dilation 1, 2 and 4, `num_filters` each)
    are concatenated and mapped by a 3x3 conv + sigmoid to a gate with the same
    channel count as `blockInput`. Returns `blockInput` multiplied by the gate.
    '''
    in_channels = blockInput.shape[-1]
    branches = [
        convolution_block(blockInput, num_filters, (3,3), activation=True, dilation=rate)
        for rate in (1, 2, 4)
    ]
    merged = Concatenate(axis=3)(branches)
    gate = tf.keras.layers.Conv2D(in_channels, (3,3), strides=1, padding="same")(merged)
    gate = Activation("sigmoid")(gate)
    return tf.keras.layers.Multiply()([blockInput, gate])


def residual_dense_attention(blockInput, num_filters=16):
    '''
    Residual-dense block: every stage sees the concatenation of the block input
    and all previous stage outputs.

    An initial 3x3 conv is followed by two dense stages (concat -> 3x3 conv ->
    residual_subblock). Everything is then concatenated once more and fused by
    a 3x3 ReLU conv to `num_filters` channels. Despite the name, no attention
    layer is applied here.
    '''
    features = [blockInput]
    current = convolution_block(blockInput, num_filters, size=(3,3), strides=(1,1))
    num_dense_stages = 2
    for _ in range(num_dense_stages):
        features.append(current)
        dense_in = Concatenate(axis=3)(features)
        current = convolution_block(dense_in, num_filters, size=(3,3), strides=(1,1))
        current = residual_subblock(current, num_filters)

    features.append(current)
    fused = Concatenate(axis=3)(features)
    return tf.keras.layers.Conv2D(num_filters, (3,3), strides=(1,1), padding="same", activation='relu')(fused)

def DAB (x, dim=64):
    '''
    Attention block (DAB).

    Two 3x3 ReLU convs produce features that are gated by two branches: a
    spatial-attention gate and a second gate (labelled "channel attention"
    below). The two gated tensors are concatenated, fused by a 1x1 ReLU conv to
    `dim` channels, and added to the block input.

    Args:
        x: input tensor. Its channel count must equal `dim`, because the final
           Add combines it directly with the `dim`-channel branch (the 1x1
           input projection that would lift this restriction is commented out).
        dim: number of filters for every conv in the block.

    Returns:
        Tensor with the same shape as `x`.
    '''
    
    # Block input, kept for the final residual Add.
    inputs = x
    
    # Two 3x3 ReLU convs; the loop index `i` is unused.
    for i in range(2):
        x = tf.keras.layers.Conv2D(dim, (3,3), strides=(1,1), padding="same", activation='relu')(x)
    
    # Conv features that both gates multiply.
    shortcut = x
    
    # Per-pixel mean and max across axis 3 (the channel axis for channels-last
    # inputs); each yields a single-channel map.
    gap = Lambda(lambda x: K.mean(x, axis=3, keepdims=True))(x)
    gmp = Lambda(lambda x: K.max(x, axis=3, keepdims=True))(x)
    
    ## spatial attention
    # concat(mean, max) -> 3x3 conv -> sigmoid gate with `dim` channels.
    gap_gmp = Concatenate(axis=3)([gap, gmp])
    gap_gmp = tf.keras.layers.Conv2D(dim, (3,3), strides=(1,1), 
                               padding="same", 
                               activation='sigmoid')(gap_gmp)
    
    spatial_attention = multiply([shortcut, gap_gmp])
    
    ## channel attention
    # NOTE(review): despite the label, this branch is driven by `gap`, a
    # per-pixel channel-mean map (H x W x 1 for channels-last input), so the
    # resulting sigmoid gate varies over space rather than being one weight per
    # channel as in typical channel attention. Confirm this is intended.
    x1 = tf.keras.layers.Conv2D(dim, (1,1), strides=(1,1), 
                               padding="same", 
                               activation='relu')(gap)
    x1 = tf.keras.layers.Conv2D(dim, (1,1), strides=(1,1), 
                               padding="same", 
                               activation='sigmoid')(x1)
    
    channel_attention = multiply([shortcut, x1])
    
    
    # Fuse both gated tensors back to `dim` channels.
    attention = Concatenate(axis=3)([spatial_attention, channel_attention])
    x2 = tf.keras.layers.Conv2D(dim, (1,1), strides=(1,1), 
                               padding="same", 
                               activation='relu')(attention)
    
    #input_project = tf.keras.layers.Conv2D(dim, (1,1), strides=(1,1), 
    #                           padding="same", 
    #                           activation='relu')(inputs)
    
    # Residual connection to the block input (requires input channels == dim).
    out = Add()([inputs, x2])
    return out
    

def RRG(x, kernel_size, reduction, n_feats=64, num_dab=8):
    '''Recursive Residual Group
    source: https://github.com/swz30/CycleISP

    Chains `num_dab` DAB blocks, applies a 3x3 ReLU conv, and adds the result
    to the group input.

    Args:
        x: input tensor; must have `n_feats` channels for the residual adds.
        kernel_size: unused by this implementation (kept in the signature).
        reduction: unused by this implementation (kept in the signature).
        n_feats: filters for every DAB and for the final conv.
        num_dab: number of chained DAB blocks.

    Returns:
        Tensor with the same shape as `x`.

    NOTE(review): `RRG` is redefined later in the notebook with a different
    signature; that later definition shadows this one.
    '''
    shortcut = x
    for _ in range(num_dab):
        x = DAB(x, dim=n_feats)

    x = tf.keras.layers.Conv2D(n_feats, (3,3), strides=(1,1), padding="same", activation='relu')(x)
    return Add()([shortcut, x])

def basic_encoder(blockInput,num_filters,activation=True):
    """Two stacked 3x3 conv blocks summed with a parallel single 3x3 conv block.

    Args:
        blockInput: input tensor.
        num_filters: output channels of every conv.
        activation: if True, apply LeakyReLU to the sum. The inner conv blocks
            always use `conv_activation`, regardless of this flag.
    """
    main = blockInput
    for _ in range(2):
        main = convolution_block(main, num_filters, (3,3), activation=True)
    adder = tf.keras.layers.Add()
    side = convolution_block(blockInput, num_filters, (3,3), activation=True)
    merged = adder([main, side])
    return tf.keras.layers.LeakyReLU()(merged) if activation else merged

### Main Blocks

In [6]:
# Short aliases for Keras activation functions used as plain callables.
gelu, selu, elu = (
    tf.keras.activations.gelu,
    tf.keras.activations.selu,
    tf.keras.activations.elu,
)

def attention_block(x, dim=16):
    """Two attention gates on `x`, fused and added back as a residual.

    Both gates are built from per-pixel statistics across the channel axis:
      * spatial gate: concat(mean, max) -> 3x3 conv -> sigmoid;
      * second ("channel") gate: mean map -> 1x1 ReLU conv -> 1x1 sigmoid conv.
    Each gate multiplies `x`. The two gated tensors are concatenated, fused by a
    linear 1x1 conv to `dim` channels, and added to `x`.

    Args:
        x: input tensor; must have `dim` channels for the multiplies and Add.
        dim: filters for the gate and fusion convs.

    Returns:
        Tensor with the same shape as `x`.
    """
    def conv(filters, kernel, act):
        return tf.keras.layers.Conv2D(filters, kernel, strides=(1,1),
                                      padding="same", activation=act)

    # Single-channel per-pixel statistics across axis 3 (channels-last assumed).
    avg_map = Lambda(lambda t: K.mean(t, axis=3, keepdims=True))(x)
    max_map = Lambda(lambda t: K.max(t, axis=3, keepdims=True))(x)

    stats = Concatenate(axis=3)([avg_map, max_map])
    spatial_gate = conv(dim, (3,3), 'sigmoid')(stats)
    spatially_gated = multiply([x, spatial_gate])

    # NOTE(review): this gate is driven by the per-pixel mean map, so it varies
    # over space rather than being one weight per channel — confirm intent.
    second_gate = conv(dim, (1,1), 'relu')(avg_map)
    second_gate = conv(dim, (1,1), 'sigmoid')(second_gate)
    channel_gated = multiply([x, second_gate])

    both = Concatenate(axis=3)([spatially_gated, channel_gated])
    fused = conv(dim, (1,1), None)(both)
    return Add()([x, fused])
    

def RRG(x,dim=16):
    '''Recursive Residual Group (lightweight variant).
    source: https://github.com/swz30/CycleISP

    Applies a single `attention_block` with `dim` filters and returns its
    output. (Previously the function returned an undefined name `out`, which
    raised NameError on every call.)

    NOTE(review): this redefinition shadows the earlier
    `RRG(x, kernel_size, reduction, n_feats, num_dab)` in this notebook.
    '''
    return attention_block(x, dim)

def flatten(x) :
    """Flatten all non-batch dimensions of `x` into one: (B, ...) -> (B, prod(...)).

    `tf.layers` does not exist in TensorFlow 2.x, so the Keras Flatten layer is used.
    """
    return tf.keras.layers.Flatten()(x)

def hw_flatten(x) :
    """Collapse the spatial dims of a channels-last tensor: (B, H, W, C) -> (B, H*W, C).

    The batch size is read from the dynamic shape because the static batch
    dimension of a Keras symbolic tensor is None, which tf.reshape rejects.
    The channel dimension (`x.shape[-1]`) must be statically known.
    """
    return tf.reshape(x, shape=[tf.shape(x)[0], -1, x.shape[-1]])

def sagan_block(x, channels):
    """Self-attention over spatial positions (SAGAN-style), output gated by a sigmoid 1x1 conv.

    f (keys), g (queries) and h (values) are 1x1 convs of `x` with `channels`
    filters. Positions are flattened to N = H*W. The attention weights are
    softmax(g @ f^T) over the key axis, and the attended values (att @ h) are
    reshaped back to (B, H, W, channels) before a final sigmoid 1x1 conv.

    The previous implementation transposed *all* axes of `f` and multiplied
    element-wise, which fails to broadcast for virtually any input shape.

    NOTE: the attention map is (B, N, N), so memory grows quadratically with
    H*W; this is only practical on small feature maps.
    """
    f = Conv2D(channels, (1,1), activation=None, strides=(1,1), padding='same')(x)
    g = Conv2D(channels, (1,1), activation=None, strides=(1,1), padding='same')(x)
    h = Conv2D(channels, (1,1), activation=None, strides=(1,1), padding='same')(x)

    h_shape = tf.shape(h)
    # (B, H, W, C) -> (B, N, C); `channels` is a Python int so the last dim stays static.
    f_flat = tf.reshape(f, [h_shape[0], -1, channels])
    g_flat = tf.reshape(g, [h_shape[0], -1, channels])
    h_flat = tf.reshape(h, [h_shape[0], -1, channels])

    att_map = tf.matmul(g_flat, f_flat, transpose_b=True)    # (B, N, N)
    att_map = tf.keras.activations.softmax(att_map)           # softmax over keys (last axis)
    fe = tf.matmul(att_map, h_flat)                           # (B, N, C)
    fe = tf.reshape(fe, [h_shape[0], h_shape[1], h_shape[2], channels])
    fe = Conv2D(channels, (1,1), activation='sigmoid', strides=(1,1), padding='same')(fe)
    return fe

def sat(x, channels=3):
    """Attention gate: a 7x7 -> 5x5 -> 3x3 conv stack (ReLU, ReLU, sigmoid)
    with `channels` filters builds a mask that multiplies `x` element-wise.

    `x` must be broadcast-compatible with the `channels`-channel mask.
    """
    mask = x
    for kernel, act in (((7,7), 'relu'), ((5,5), 'relu'), ((3,3), 'sigmoid')):
        mask = Conv2D(channels, kernel, activation=act, strides=(1,1), padding='same')(mask)
    return x * mask

def inv_block(x, channels=3):
    """Inverted-residual-style block whose shortcut is also projected.

    Branch:   1x1 conv -> 3x3 depthwise conv -> ELU -> linear 1x1 conv.
    Shortcut: 1x1 ReLU conv of `x` to `channels`, so `x` may have any channel count.
    Returns branch + shortcut.
    """
    branch = Conv2D(channels, (1,1), activation=None, strides=(1,1), padding='same')(x)
    branch = DepthwiseConv2D((3,3), activation=None, strides=(1,1), padding='same')(branch)
    branch = elu(branch)
    branch = Conv2D(channels, (1,1))(branch)

    shortcut = Conv2D(channels, (1,1), activation='relu', strides=(1,1), padding='same')(x)
    return Add()([branch, shortcut])
    
def baseblock(x, channels=32):
    """Two-stage residual block: a depthwise spatial branch, then a pointwise (1x1) branch.

    Stage 1: 1x1 conv -> 3x3 depthwise -> ELU -> 1x1 conv gives `mixed`;
             residual = mixed + x.
    Stage 2: 1x1 conv -> ELU -> 1x1 conv applied to `mixed`;
             output = stage-2 result + residual.
    `x` must have `channels` channels for the residual adds.

    NOTE(review): stage 2 consumes `mixed` (the pre-residual branch), not the
    residual sum. Kept as-is because changing it alters the architecture —
    confirm this is intended.
    """
    mixed = Conv2D(channels, (1,1), activation=None, strides=(1,1), padding='same')(x)
    mixed = DepthwiseConv2D((3,3), activation=None, strides=(1,1), padding='same')(mixed)
    mixed = elu(mixed)
    mixed = Conv2D(channels, (1,1))(mixed)
    residual = Add()([mixed, x])

    hidden = Conv2D(channels, (1,1), activation=None, strides=(1,1), padding='same')(mixed)
    hidden = elu(hidden)
    hidden = Conv2D(channels, (1,1), activation=None, strides=(1,1), padding='same')(hidden)
    return Add()([hidden, residual])

In [7]:
def build_ours(input_shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS),learning_rate=0.001):
    """Build and compile the encoder-decoder enhancement model.

    Encoder: for each width in `encoder_dim`, a 3x3 ReLU conv, two `baseblock`s
    and an `attention_block`. Each stage's output is kept as a skip feature, and
    2x max-pooling is applied between stages (not after the last one).
    Decoder: for each width in `decoder_dim`, the same conv/baseblock/attention
    stack, then 2x bilinear upsampling and concatenation with the encoder
    feature at the matching resolution.
    Head: `inv_block` to 3 channels, `sat` gating, and a global residual add
    with the input.

    Args:
        input_shape: (H, W, C). H and W must be divisible by
            2 ** (len(encoder_dim) - 1) = 4 so the skip shapes line up. C should
            be 3 (or 1) because the 3-channel head output is added to the input.
        learning_rate: Adam learning rate.

    Returns:
        A compiled `tf.keras.Model` with MSE loss.
    """
    encoder_dim = [16, 32, 64]
    decoder_dim = [32, 16]
    encoder_fes = []

    inputs = tf.keras.layers.Input(input_shape)
    x = inputs

    for e, dim in enumerate(encoder_dim):
        x = tf.keras.layers.Conv2D(dim, (3, 3), activation="relu", padding="same")(x)
        x = baseblock(x, dim)
        x = baseblock(x, dim)
        x = attention_block(x, dim)
        encoder_fes.append(x)
        print('e', e, x.shape)
        if e != len(encoder_dim) - 1:
            x = tf.keras.layers.MaxPooling2D((2, 2))(x)

    for d, dim in enumerate(decoder_dim):
        x = tf.keras.layers.Conv2D(dim, (3, 3), activation="relu", padding="same")(x)
        x = baseblock(x, dim)
        x = baseblock(x, dim)
        x = attention_block(x, dim)
        print('d', d, x.shape)
        x = tf.keras.layers.UpSampling2D(size=(2,2),interpolation='bilinear')(x)
        # After upsampling, decoder stage d is at the resolution of encoder
        # stage len(encoder_dim) - 2 - d (i.e. stages 1, 0 for three encoders).
        skip = encoder_fes[len(encoder_dim) - 2 - d]
        x = tf.keras.layers.Concatenate()([x, skip])

    x = inv_block(x, 3)
    x = sat(x)
    x = x + inputs

    model = tf.keras.models.Model(inputs=[inputs], outputs=[x])
    # `lr` is deprecated and rejected by newer Keras optimizers; `learning_rate`
    # works across TF2 versions.
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate), loss='mse')
    return model

In [8]:
tf.keras.backend.clear_session()

# NOTE(review): input_shape is (H, W, C), so (3840, 2160, 3) is a portrait 4K
# frame, whereas the HR_* constants use landscape (1080, 1920). Confirm the
# intended orientation; FLOPs are identical either way, but the exported
# TFLite input shape is not.
model = build_ours(input_shape=(3840 , 2160, 3 ))
params_millions = model.count_params() / 1_000_000
print(params_millions, 'M params.')

flops = get_flops(model, batch_size=1)
print(f"FLOPS: {flops / 1e9} G")

MODEL_NAME = 'wours_4k.tflite'
keras2tflite(model, name=MODEL_NAME, fp16=True)

e 0 (None, 3840, 2160, 16)
e 1 (None, 1920, 1080, 32)
e 2 (None, 960, 540, 64)
d 0 (None, 960, 540, 32)
d 1 (None, 1920, 1080, 16)
0.133652 M params.

-max_depth                  10000
-min_bytes                  0
-min_peak_bytes             0
-min_residual_bytes         0
-min_output_bytes           0
-min_micros                 0
-min_accelerator_micros     0
-min_cpu_micros             0
-min_params                 0
-min_float_ops              1
-min_occurrence             0
-step                       -1
-order_by                   float_ops
-account_type_regexes       .*
-start_name_regexes         .*
-trim_name_regexes          
-show_name_regexes          .*
-hide_name_regexes          
-account_displayed_op_only  true
-select                     float_ops
-output                     stdout:


Doc:
scope: The nodes in the model graph are organized by their names, which is hierarchical like filesystem.
flops: Number of float operations. Note: Please read the implementation for 