In [1]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K

from mnist import MNIST
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.layers import Conv2D, MaxPool2D, GlobalAveragePooling2D, Dense, Flatten, BatchNormalization, ReLU, Reshape, Lambda
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
# Read the MNIST training split from the local python-mnist data directory.
mndata = MNIST('./python-mnist/data/')
raw_images, raw_labels = mndata.load_training()

# Images: one 28x28 grid per sample. Pixel values are left unscaled here;
# the model's Lambda layer performs the cast to float32 and the /255 scaling.
images = np.array(raw_images).reshape(-1, 28, 28)

# Labels: one-hot encode the class ids by picking rows of an identity matrix
# (row i has a single 1 in column i).
class_ids = np.array(raw_labels)
b = np.eye(class_ids.max() + 1)[class_ids]
labels = b

# Sanity check of the resulting array shapes.
images.shape, labels.shape

((60000, 28, 28), (60000, 10))

In [3]:
# Hold out 10% of the samples; the fixed seed keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    images,
    labels,
    test_size=0.1,
    random_state=42,
)

# Model

In [4]:
# Deliberately tiny CNN: two conv/pool stages followed by a 4-unit bottleneck
# and a 10-way classification head.
model = Sequential()

# Add the channel axis expected by Conv2D: (28, 28) -> (28, 28, 1).
model.add(Reshape((28,28,1), input_shape=(28,28)))
# Cast the raw pixel values to float32 and scale them by 1/255 inside the
# graph, so callers can feed the unscaled image array directly.
model.add(Lambda(lambda x: K.cast(x, "float32")/255))

# Convolution stage 1.
model.add(Conv2D(2,kernel_size=(2,2)))
model.add(ReLU())
model.add(MaxPool2D(pool_size=(2,2)))

# Convolution stage 2.
model.add(Conv2D(4,kernel_size=(2,2)))
model.add(ReLU())
model.add(MaxPool2D())

# Classifier head.
model.add(Flatten())
model.add(Dense(4))
model.add(ReLU())
# The classes are mutually exclusive and the loss is categorical cross-entropy,
# which expects a probability distribution over the classes. Softmax provides
# that; independent per-class sigmoids do not sum to 1.
model.add(Dense(10, activation='softmax'))

opt = Adam(learning_rate=0.0001)
model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])

Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


In [5]:
# Print the layer-by-layer architecture with output shapes and parameter counts.
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
reshape (Reshape)            (None, 28, 28, 1)         0         
_________________________________________________________________
lambda (Lambda)              (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d (Conv2D)              (None, 27, 27, 2)         10        
_________________________________________________________________
re_lu (ReLU)                 (None, 27, 27, 2)         0         
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 13, 13, 2)         0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 12, 12, 4)         36        
_________________________________________________________________
re_lu_1 (ReLU)               (None, 12, 12, 4)         0

# Training

In [None]:
# Train for 8 epochs with mini-batches of 16; the held-out split is evaluated
# after every epoch as validation data.
model.fit(
    x=X_train,
    y=y_train,
    batch_size=16,
    epochs=8,
    validation_data=(X_test, y_test),
)

Train on 54000 samples, validate on 6000 samples
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Epoch 1/8
Epoch 2/8
Epoch 3/8
Epoch 4/8
Epoch 5/8
Epoch 6/8

## Original model accuracy

In [None]:
# Run the trained model on the held-out split; the final layer has 10 units,
# so `preds` holds 10 class scores per input image.
preds = model.predict(X_test)

In [None]:
# Class index per sample: the highest-scoring output for the predictions, and
# the position of the 1 in the one-hot row for the ground truth.
predicted_classes = preds.argmax(axis=1)
true_classes = y_test.argmax(axis=1)

# Boolean hit/miss vector; its mean is the fraction classified correctly.
res = predicted_classes == true_classes
res.mean()

# Save

In [None]:
# Save the trained Keras model to model.h5 in the current working directory.
model.save("model.h5")

# Convert to a frozen TensorFlow graph (.pb)

In [None]:
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a pruned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.

    The ``keep_var_names`` and ``output_names`` arguments are only read; the
    lists passed in by the caller are never modified.

    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    graph = session.graph
    with graph.as_default():
        # Collect the op name of every global variable once; the same list
        # feeds both the freeze set and the set of nodes to keep.
        variable_op_names = [v.op.name for v in tf.global_variables()]
        freeze_var_names = list(set(variable_op_names).difference(keep_var_names or []))
        # Build a NEW list instead of using `+=`: in-place extension would
        # append every variable name to the caller's list, silently corrupting
        # it for any later use (e.g. as a TF-TRT node blacklist).
        all_output_names = list(output_names or []) + variable_op_names
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            # Strip device placements so the graph is not tied to the
            # hardware it was built on.
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = tf.graph_util.convert_variables_to_constants(
            session, input_graph_def, all_output_names, freeze_var_names)
        return frozen_graph

In [None]:
# Graph-level op names of the Keras model's input and output tensors; the
# frozen-graph export and the TF-TRT conversion address nodes by these names.
k_inputs = [tensor.op.name for tensor in model.inputs]
k_outputs = [tensor.op.name for tensor in model.outputs]

(k_inputs, k_outputs)

In [None]:
# Freeze the live Keras session into a constant-only GraphDef.
# A copy of k_outputs is passed so the helper cannot alter the original list,
# which is reused later as the TF-TRT node blacklist and must contain only the
# model's output node names.
frozen_graph = freeze_session(K.get_session(), output_names=list(k_outputs))

In [None]:
# Serialize the frozen GraphDef as a binary protobuf at tf_model/my_model.pb.
tf.train.write_graph(
    frozen_graph,
    logdir="tf_model",
    name="my_model.pb",
    as_text=False,
)

# TODO: To TRT

In [None]:
# Fresh graph plus an interactive session bound to it.
graph = tf.Graph()
sess = tf.InteractiveSession(graph=graph)

# Read the serialized frozen model back from disk and parse it into a GraphDef.
graph_def = tf.GraphDef()
with tf.gfile.GFile("./tf_model/my_model.pb", 'rb') as f:
    serialized_graph = f.read()
    graph_def.ParseFromString(serialized_graph)

In [None]:
from tensorflow.python.compiler.tensorrt import trt_convert as trt

In [None]:
# TF-TRT conversion settings, gathered in one place for readability.
# NOTE(review): INT8 precision normally needs a calibration pass on
# representative input data after convert() — confirm against the TF-TRT docs
# for the installed TensorFlow version before using the result for inference.
trt_conversion_params = dict(
    input_graph_def=graph_def,
    nodes_blacklist=k_outputs,         # output nodes must not be folded into TRT engines
    max_workspace_size_bytes=1 << 32,  # 4 GiB
    precision_mode='INT8',
    minimum_segment_size=2,
    is_dynamic_op=True,
    maximum_cached_engines=100,
)
converter = trt.TrtGraphConverter(**trt_conversion_params)
trt_frozen_graph = converter.convert()

In [None]:
# Inspect what kind of object converter.convert() returned.
type(trt_frozen_graph)