In [8]:
import tensorflow.keras.backend as K
import tensorflow as tf
from tensorflow.core.framework import graph_pb2

from model_2d import hand_landmark_2d_model
from utils import get_pretrained_tflite_weights, set_pretrained_weights, define_fake_4_channels_graph, display_nodes

In [2]:
# Build two pieces that must agree on spatial size:
#   * a stand-alone "fake" graph whose 4-channel placeholder is followed by a
#     StridedSlice (it is spliced in front of the frozen model further below);
#   * the 3-channel Keras hand-landmark model that will receive the weights.
# Deriving both shapes from one (height, width) pair keeps them from drifting.
spatial_size = (256, 256)
fake_graph = define_fake_4_channels_graph(input_size=spatial_size + (4,))

input_size = spatial_size + (3,)
model = hand_landmark_2d_model(input_size)
model.summary()

Instructions for updating:
Colocations handled automatically by placer.
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 256, 256, 3)  0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 128, 128, 16) 448         input_1[0][0]                    
__________________________________________________________________________________________________
p_re_lu (PReLU)                 (None, 128, 128, 16) 16          conv2d[0][0]                     
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 128, 128, 8)  136         p_re_lu[0][0]                    
_____________________________________

In [3]:
# Pull the weight tensors (and the layer names they belong to) out of the
# reference TFLite model, then load them into the Keras model built above.
tflite_model_path = "./pretrained_models/hand_landmark.tflite"
tflite_weights = get_pretrained_tflite_weights(tflite_model_path)
pretrained_weights_dict, layer_names = tflite_weights

set_pretrained_weights(model, pretrained_weights_dict, layer_names)
print("Set all pretrained weights")

Set all pretrained weights


In [18]:
# Freeze the Keras model: fold the session's variables into Const nodes so the
# resulting GraphDef can be edited and serialized on its own.
sess = K.get_session()
output_names = [tensor.op.name for tensor in model.outputs]
# `tf.graph_util.convert_variables_to_constants` is deprecated; its deprecation
# notice points to the tf.compat.v1 spelling used here.
frozen_def = tf.compat.v1.graph_util.convert_variables_to_constants(
    sess, sess.graph_def, output_names)

## check nodes: the placeholder `input_1` should be what feeds `conv2d/Conv2D`
display_nodes(frozen_def.node[0:30])

Instructions for updating:
Use tf.compat.v1.graph_util.convert_variables_to_constants
Instructions for updating:
Use tf.compat.v1.graph_util.extract_sub_graph
INFO:tensorflow:Froze 503 variables.
INFO:tensorflow:Converted 503 variables to const ops.
0 input_1 Placeholder
1 conv2d/kernel Const
2 conv2d/bias Const
3 conv2d/Conv2D/ReadVariableOp Identity
└─── 0 ─ conv2d/kernel
4 conv2d/Conv2D Conv2D
└─── 0 ─ input_1
└─── 1 ─ conv2d/Conv2D/ReadVariableOp
5 conv2d/BiasAdd/ReadVariableOp Identity
└─── 0 ─ conv2d/bias
6 conv2d/BiasAdd BiasAdd
└─── 0 ─ conv2d/Conv2D
└─── 1 ─ conv2d/BiasAdd/ReadVariableOp
7 p_re_lu/alpha Const
8 p_re_lu/Relu Relu
└─── 0 ─ conv2d/BiasAdd
9 p_re_lu/ReadVariableOp Identity
└─── 0 ─ p_re_lu/alpha
10 p_re_lu/Neg Neg
└─── 0 ─ p_re_lu/ReadVariableOp
11 p_re_lu/Neg_1 Neg
└─── 0 ─ conv2d/BiasAdd
12 p_re_lu/Relu_1 Relu
└─── 0 ─ p_re_lu/Neg_1
13 p_re_lu/mul Mul
└─── 0 ─ p_re_lu/Neg
└─── 1 ─ p_re_lu/Relu_1
14 p_re_lu/add Add
└─── 0 ─ p_re_lu/Relu
└─── 1 ─ p_re_lu/mul
15 co

In [19]:
# Inspect the front-end graph returned by define_fake_4_channels_graph: its
# Placeholder `input_1` feeds `slicing_inputs/strided_slice` (StridedSlice).
display_nodes(fake_graph.node)

0 input_1 Placeholder
1 slicing_inputs/strided_slice/stack Const
2 slicing_inputs/strided_slice/stack_1 Const
3 slicing_inputs/strided_slice/stack_2 Const
4 slicing_inputs/strided_slice StridedSlice
└─── 0 ─ input_1
└─── 1 ─ slicing_inputs/strided_slice/stack
└─── 2 ─ slicing_inputs/strided_slice/stack_1
└─── 3 ─ slicing_inputs/strided_slice/stack_2


In [22]:
# Splice the 4-channel front end onto the frozen model.
#   * Every frozen node that consumed the original 3-channel placeholder is
#     rewired to read the slice output instead. Matching is by name rather than
#     by a fixed node index, so it survives reordering or additional consumers.
#   * The original placeholder is dropped by name. The front end's own
#     4-channel `input_1` placeholder takes its place.
# Work on a copy of `frozen_def` so re-running this cell gives the same result.
original_input_name = 'input_1'
slice_output_name = 'slicing_inputs/strided_slice'

rewired_def = graph_pb2.GraphDef()
rewired_def.CopyFrom(frozen_def)

n_rewired = 0
for node in rewired_def.node:
    for i, input_ref in enumerate(node.input):
        # Data inputs may be written as "name" or "name:0"; control inputs as "^name".
        if input_ref in (original_input_name, original_input_name + ':0'):
            node.input[i] = slice_output_name
            n_rewired += 1
        elif input_ref == '^' + original_input_name:
            node.input[i] = '^' + slice_output_name
            n_rewired += 1
assert n_rewired > 0, "no node in the frozen graph consumed the original input placeholder"

# merge fake graph (placeholder + slice) with the rewired model minus its old placeholder
concat_nodes = list(fake_graph.node) + [
    node for node in rewired_def.node if node.name != original_input_name]

## check again
display_nodes(concat_nodes[:30])

0 input_1 Placeholder
1 slicing_inputs/strided_slice/stack Const
2 slicing_inputs/strided_slice/stack_1 Const
3 slicing_inputs/strided_slice/stack_2 Const
4 slicing_inputs/strided_slice StridedSlice
└─── 0 ─ input_1
└─── 1 ─ slicing_inputs/strided_slice/stack
└─── 2 ─ slicing_inputs/strided_slice/stack_1
└─── 3 ─ slicing_inputs/strided_slice/stack_2
5 conv2d/kernel Const
6 conv2d/bias Const
7 conv2d/Conv2D/ReadVariableOp Identity
└─── 0 ─ conv2d/kernel
8 conv2d/Conv2D Conv2D
└─── 0 ─ slicing_inputs/strided_slice
└─── 1 ─ conv2d/Conv2D/ReadVariableOp
9 conv2d/BiasAdd/ReadVariableOp Identity
└─── 0 ─ conv2d/bias
10 conv2d/BiasAdd BiasAdd
└─── 0 ─ conv2d/Conv2D
└─── 1 ─ conv2d/BiasAdd/ReadVariableOp
11 p_re_lu/alpha Const
12 p_re_lu/Relu Relu
└─── 0 ─ conv2d/BiasAdd
13 p_re_lu/ReadVariableOp Identity
└─── 0 ─ p_re_lu/alpha
14 p_re_lu/Neg Neg
└─── 0 ─ p_re_lu/ReadVariableOp
15 p_re_lu/Neg_1 Neg
└─── 0 ─ conv2d/BiasAdd
16 p_re_lu/Relu_1 Relu
└─── 0 ─ p_re_lu/Neg_1
17 p_re_lu/mul Mul
└─── 0 

In [23]:
# Serialize the spliced graph to a binary GraphDef (.pb) for the TFLite converter.
out_path = "./pretrained_models/hand_landmark_4channels.pb"
concat_graph_def = graph_pb2.GraphDef()
# A bare GraphDef() has empty `versions`/`library`. Carry over whatever the
# frozen model had, so only the node list differs from it.
concat_graph_def.versions.CopyFrom(frozen_def.versions)
concat_graph_def.library.CopyFrom(frozen_def.library)
concat_graph_def.node.extend(concat_nodes)

# SerializeToString() returns bytes, so open the file in binary mode.
with tf.gfile.GFile(out_path, 'wb') as f:
    f.write(concat_graph_def.SerializeToString())

In [32]:
# Convert the 4-channel frozen graph into a TFLite flatbuffer.
outPath = "./pretrained_models/hand_landmark_4channels.tflite"
frozen_graph_file = "./pretrained_models/hand_landmark_4channels.pb"  # written by the save cell above
input_names = ['input_1']
# NOTE(review): hard-coded output op names; confirm they still match
# `model.outputs` if the Keras model definition changes.
output_names = ['ld_21_2d/Reshape', 'output_handflag/Reshape']
converter = tf.lite.TFLiteConverter.from_frozen_graph(frozen_graph_file, input_names, output_names)

tflite_model = converter.convert()
# A context manager guarantees the file is flushed and closed; the byte count
# is kept as the cell's displayed result.
with open(outPath, "wb") as tflite_file:
    n_bytes_written = tflite_file.write(tflite_model)
n_bytes_written

11403944