# TensorFlow-TensorRT Sample
This notebook shows a simple process for optimizing a TensorFlow model with TensorRT.

## Import dependent packages

In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.python.ops import data_flow_ops
import tensorflow.contrib.tensorrt as trt

import numpy as np
import time
from tensorflow.python.platform import gfile
from tensorflow.python.client import timeline
import argparse, sys, itertools,datetime
import json
# Show INFO-level TensorFlow log messages in the cell output.
tf.logging.set_verbosity(tf.logging.INFO)

import os
# Restrict the CUDA devices visible to this process to device "0".
os.environ["CUDA_VISIBLE_DEVICES"]="0" #selects a specific device

  return f(*args, **kwds)


Instructions for updating:
Use the retry module or similar alternatives.


## TensorRT integration options
The name of the **output layer** can be obtained from the code used to freeze the model.

In [2]:
# Settings read by the helper functions and by the driver cell of this notebook.
config = {
    # Model: frozen GraphDef file and the name of the node used as its output.
    "frozen_model_file": "./frozen_models/resnetV150_frozen.pb",
    "output_layer": "resnet_v1_50/predictions/Reshape_1",

    # Which variants the driver cell benchmarks.
    "FP32": True,
    "FP16": True,
    "INT8": True,
    "native": True,

    # Number of timed loops passed to timeGraph.
    "num_loops": 20,
    # How many top-ranked labels are printed per variant.
    "topN": 10,
    # Batch size used for the input tensor and for the TensorRT engines.
    "batch_size": 128,
    "dump_diff": True,
    # When true, the warm-up runs are traced and written to a timeline file.
    "with_timeline": True,
    # The driver cell shifts this left by 20 bits before passing it on as the
    # TensorRT workspace size in bytes, so the value is effectively in MiB.
    "workspace_size": 1 << 10,
    # When true, the frozen model file is re-imported and rewritten first.
    "update_graphdef": True,
}

In [3]:
def read_tensor_from_image_file(file_name, input_height=224, input_width=224,
                                input_mean=0, input_std=255):
  """Read an image file and return it resized and normalized.

  The decoding ops are built in a private graph and evaluated in a session
  that is closed before returning, so repeated calls neither add nodes to the
  default graph nor leave a session open.

  Args:
    file_name: path of the image file to read.
    input_height: target height for the bilinear resize.
    input_width: target width for the bilinear resize.
    input_mean: value subtracted from every pixel after resizing.
    input_std: value each pixel is divided by after the mean is subtracted.

  Returns:
    A two-element list of arrays: the normalized image with a leading batch
    dimension of 1, and the same data transposed with perm=(0, 3, 1, 2).
  """
  graph = tf.Graph()
  with graph.as_default():
    file_reader = tf.read_file(file_name, "file_reader")
    # NOTE(review): the op is decode_png although the node is named
    # 'jpg_reader' and the notebook feeds a .jpg file - confirm this is intended.
    image_reader = tf.image.decode_png(file_reader, channels=3,
                                       name='jpg_reader')
    float_caster = tf.cast(image_reader, tf.float32)
    dims_expander = tf.expand_dims(float_caster, 0)
    resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
    normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
    transposed = tf.transpose(normalized, perm=(0, 3, 1, 2))

  # Cap this helper session at half of the GPU memory, as before.
  session_config = tf.ConfigProto(
      gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.50))
  with tf.Session(graph=graph, config=session_config) as sess:
    return sess.run([normalized, transposed])

In [4]:
def getSimpleGraphDef():
  """Build a tiny conv/bias/relu/max-pool graph and return its GraphDef.

  The final squeeze op is named after config["output_layer"]. A summary of the
  graph is written to the "origgraph" directory, which is removed first if it
  already exists.
  """
  log_dir = "origgraph"
  if gfile.Exists(log_dir):
    gfile.DeleteRecursively(log_dir)

  graph = tf.Graph()
  with graph.as_default():
    images = tf.placeholder(
        dtype=tf.float32, shape=(None, 224, 224, 3), name="input")
    # 1x1 convolution kernel mapping 3 input channels to 6 output channels.
    kernel = tf.constant(
        [[[[1., 0.5, 4., 6., 0.5, 1.],
           [1., 0.5, 1., 1., 0.5, 1.],
           [1., 1., 1., 1., 1., 1.]]]],
        name="weights",
        dtype=tf.float32)
    convolved = tf.nn.conv2d(
        input=images,
        filter=kernel,
        strides=[1, 1, 1, 1],
        dilations=[1, 1, 1, 1],
        padding="SAME",
        name="conv")
    bias = tf.constant([4., 1.5, 2., 3., 5., 7.], name="bias", dtype=tf.float32)
    biased = tf.nn.bias_add(convolved, bias, name="biasAdd")
    activated = tf.nn.relu(biased, name="relu")
    passthrough = tf.identity(activated, name="ID")
    pooled = tf.nn.max_pool(
        passthrough,
        ksize=[1, 2, 2, 1],
        strides=[1, 2, 2, 1],
        padding="VALID",
        name="max_pool")
    # Only the node itself is needed; the Python handle is not used.
    tf.squeeze(pooled, name=config["output_layer"])
    summary_writer = tf.summary.FileWriter(log_dir, graph)
    summary_writer.close()

  return graph.as_graph_def()

In [6]:
def getResnet50():
  """Load the frozen model at config["frozen_model_file"] as a tf.GraphDef."""
  with gfile.FastGFile(config["frozen_model_file"], 'rb') as model_file:
    serialized = model_file.read()
    frozen_graph = tf.GraphDef()
    frozen_graph.ParseFromString(serialized)
  return frozen_graph

In [5]:
def updateGraphDef(fileName):
  """Re-import a serialized GraphDef and overwrite the file with the result.

  The file is parsed, imported into a fresh graph with an empty name prefix,
  and the graph's GraphDef is then serialized back to the same path. The
  default graph is reset as a side effect.

  Args:
    fileName: path of the serialized GraphDef; it is rewritten in place.
  """
  with gfile.FastGFile(fileName, 'rb') as source:
    loaded = tf.GraphDef()
    loaded.ParseFromString(source.read())

  tf.reset_default_graph()
  fresh_graph = tf.Graph()
  with fresh_graph.as_default():
    tf.import_graph_def(loaded, name="")

  with gfile.FastGFile(fileName, 'wb') as target:
    target.write(fresh_graph.as_graph_def().SerializeToString())

In [7]:
def printStats(graphName,timings,batch_size):
  """Print throughput and per-batch latency statistics for one graph.

  Args:
    graphName: label written into the "RES" summary line.
    timings: sequence of seconds-per-batch measurements, or None to print
      nothing.
    batch_size: number of images processed per batch.
  """
  if timings is None:
    return

  seconds_per_batch = np.array(timings)
  images_per_second = batch_size / seconds_per_batch

  mean_time = np.mean(timings)
  # Mean speed is derived from the mean time, not from the per-run speeds.
  mean_speed = batch_size / mean_time
  time_spread = np.std(timings)
  speed_spread = np.std(images_per_second)

  summary = (mean_speed, speed_spread, mean_time, time_spread)
  print("images/s : %.1f +/- %.1f, s/batch: %.5f +/- %.5f" % summary)
  print("RES, %s, %s, %.2f, %.2f, %.5f, %.5f" % ((graphName, batch_size) + summary))

In [8]:
def getFP32(batch_size=128,workspace_size=1<<30):
  """Build the FP32 TensorRT-optimized graph for the frozen ResNet-50 model.

  The optimized GraphDef is also saved to "resnetV150_TRTFP32.pb".

  Args:
    batch_size: passed to TensorRT as max_batch_size.
    workspace_size: passed to TensorRT as max_workspace_size_bytes.

  Returns:
    The optimized GraphDef.
  """
  optimized = trt.create_inference_graph(
      getResnet50(),
      [config["output_layer"]],
      max_batch_size=batch_size,
      max_workspace_size_bytes=workspace_size,
      precision_mode="FP32")
  with gfile.FastGFile("resnetV150_TRTFP32.pb", 'wb') as out_file:
    out_file.write(optimized.SerializeToString())
  return optimized

In [9]:
def getFP16(batch_size=128,workspace_size=1<<30):
  """Build the FP16 TensorRT-optimized graph for the frozen ResNet-50 model.

  The optimized GraphDef is also saved to "resnetV150_TRTFP16.pb".

  Args:
    batch_size: passed to TensorRT as max_batch_size.
    workspace_size: passed to TensorRT as max_workspace_size_bytes.

  Returns:
    The optimized GraphDef.
  """
  optimized = trt.create_inference_graph(
      getResnet50(),
      [config["output_layer"]],
      max_batch_size=batch_size,
      max_workspace_size_bytes=workspace_size,
      precision_mode="FP16")
  with gfile.FastGFile("resnetV150_TRTFP16.pb", 'wb') as out_file:
    out_file.write(optimized.SerializeToString())
  return optimized

Create the INT8 calibration graph with `create_inference_graph`. The output is a frozen graph that is ready for calibration.

In [10]:
def getINT8CalibGraph(batch_size=128,workspace_size=1<<30):
  """Build the INT8 calibration graph for the frozen ResNet-50 model.

  The calibration GraphDef is also saved to "resnetV150_TRTINT8Calib.pb".

  Args:
    batch_size: passed to TensorRT as max_batch_size.
    workspace_size: passed to TensorRT as max_workspace_size_bytes.

  Returns:
    The calibration GraphDef.
  """
  calibration_graph = trt.create_inference_graph(
      getResnet50(),
      [config["output_layer"]],
      max_batch_size=batch_size,
      max_workspace_size_bytes=workspace_size,
      precision_mode="INT8")
  with gfile.FastGFile("resnetV150_TRTINT8Calib.pb", 'wb') as out_file:
    out_file.write(calibration_graph.SerializeToString())
  return calibration_graph

Then we use the TensorRT calibration graph to produce the calibrated INT8 inference graph.

In [11]:
def getINT8InferenceGraph(calibGraph):
  """Convert a calibration graph into an INT8 inference graph.

  The resulting GraphDef is also saved to "resnetV150_TRTINT8.pb".

  Args:
    calibGraph: calibration GraphDef handed to trt.calib_graph_to_infer_graph.

  Returns:
    The inference GraphDef.
  """
  inference_graph = trt.calib_graph_to_infer_graph(calibGraph)
  with gfile.FastGFile("resnetV150_TRTINT8.pb", 'wb') as out_file:
    out_file.write(inference_graph.SerializeToString())
  return inference_graph

In [12]:
def _mergeChromeTraces(run_metadata_list):
  """Merge the step stats of several runs into one Chrome-trace JSON string."""
  merged = None
  for metadata in run_metadata_list:
    trace = json.loads(
        timeline.Timeline(metadata.step_stats).generate_chrome_trace_format())
    if merged is None:
      merged = trace
    else:
      merged["traceEvents"].extend(trace["traceEvents"])
  return json.dumps(merged, indent=2)


def timeGraph(gdef,batch_size=128,num_loops=100,dummy_input=None,timelineName=None):
  """Import a GraphDef, warm it up and time repeated inference runs.

  The graph's "input" node is fed from a constant tensor that is repeated
  forever, and the node named by config["output_layer"] is fetched.

  Args:
    gdef: GraphDef to benchmark.
    batch_size: batch size used only when dummy_input has to be generated.
    num_loops: number of timed loops; each loop averages 50 session runs.
    dummy_input: array fed to the graph; random data of shape
      (batch_size, 224, 224, 3) is generated when None.
    timelineName: when set, the warm-up runs are executed with full tracing
      and their merged Chrome trace is written to this path, replacing any
      existing file.

  Returns:
    A tuple (timings, comp, val, None). timings holds the average seconds
    per run for each loop, comp tells whether the last timed output equals
    the last warm-up output, and val is that last output. When num_loops is
    0 the warm-up output is returned as val. The fourth element is always
    None.
  """
  warmup_runs = 20
  runs_per_loop = 50

  tf.logging.info("Starting execution")
  gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.50)
  tf.reset_default_graph()
  g = tf.Graph()
  if dummy_input is None:
    dummy_input = np.random.random_sample((batch_size,224,224,3))
  with g.as_default():
    inc = tf.constant(dummy_input, dtype=tf.float32)
    dataset = tf.data.Dataset.from_tensors(inc)
    dataset = dataset.repeat()
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    out = tf.import_graph_def(
      graph_def=gdef,
      input_map={"input":next_element},
      return_elements=[ config["output_layer"] ]
    )
    outlist = [out[0].outputs[0]]

  timings = []

  with tf.Session(graph=g,config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    tf.logging.info("Starting Warmup cycle")
    if timelineName:
      if gfile.Exists(timelineName):
        gfile.Remove(timelineName)
      run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
      traces = [tf.RunMetadata() for _unused in range(warmup_runs)]
      for metadata in traces:
        valt = sess.run(outlist,options=run_options,run_metadata=metadata)
      with gfile.FastGFile(timelineName,"a") as tlf:
        tlf.write(_mergeChromeTraces(traces))
    else:
      for _unused in range(warmup_runs):
        valt = sess.run(outlist)
    tf.logging.info("Warmup done. Starting real timing")

    # Fall back to the warm-up output so that num_loops == 0 does not leave
    # `val` unbound.
    val = valt
    for i in range(num_loops):
      tstart = time.time()
      for _unused in range(runs_per_loop):
        val = sess.run(outlist)
      timings.append((time.time()-tstart)/float(runs_per_loop))
      print("iter ",i," ",timings[-1])

    # Both outputs are already NumPy arrays, so compare them on the host
    # rather than adding comparison ops and constants to the benchmark graph.
    comp = np.array_equal(val[0],valt[0])
    print("Comparison=",comp)
    tf.logging.info("Timing loop done!")
  return timings,comp,val[0],None

In [13]:
def score(nat,trt,topN=5):
  """Compare the top-N class rankings of two prediction arrays.

  Args:
    nat: 2-D array of reference predictions, one row per sample.
    trt: 2-D array of predictions to compare (this parameter shadows the
      module alias `trt` inside the function).
    topN: how many of the highest-scoring classes to compare.

  Returns:
    A tuple (identical, closeness): whether the top-N index sets match
    exactly in ascending-score order for every row, and the howClose score.
  """
  reference_top = np.argsort(nat)[:, -topN:]
  candidate_top = np.argsort(trt)[:, -topN:]
  identical = np.array_equal(reference_top, candidate_top)
  closeness = howClose(nat, trt, topN)
  return identical, closeness

In [14]:
def topX(arr,X):
  """Return the X largest values of each row and their column indices.

  Args:
    arr: 2-D NumPy array.
    X: number of entries to keep per row.

  Returns:
    A tuple (values, indices), both ordered from largest to smallest value
    within each row.
  """
  ascending = np.argsort(arr)
  largest_last = ascending[:, -X:]
  ind = largest_last[:, ::-1]
  # Column vector of row numbers so each row is paired with its own indices.
  row_ids = np.arange(np.shape(arr)[0])[:, np.newaxis]
  return arr[row_ids, ind], ind

In [15]:
def howClose(arr1,arr2,X):
  """Score how similar the top-X rankings of two arrays' first rows are.

  Each rank position scores 1 when both rankings hold the same index there,
  0.5 when the index from arr1 appears in arr2's ranking less than two
  positions away, and 0 otherwise.

  Args:
    arr1: 2-D array of reference scores; only the first row is compared.
    arr2: 2-D array of scores to compare; only the first row is compared.
    X: number of top entries to compare.

  Returns:
    The accumulated score divided by X.
  """
  first_ranking = topX(arr1, X)[1][0]
  second_ranking = topX(arr2, X)[1][0]

  total = 0.
  for position in range(X):
    wanted = first_ranking[position]
    if wanted == second_ranking[position]:
      total += 1
      continue
    # Where, if anywhere, the same index sits in the second ranking.
    matches = np.where(second_ranking == wanted)[0]
    if matches.shape[0] and np.abs(matches[0] - position) < 2:
      total += 0.5
  return total / X

In [16]:
def getLabels(labels,ids):
  """Look up the label of each id, using the key str(id + 1).

  Args:
    labels: mapping from string keys to label values.
    ids: iterable of integer ids.

  Returns:
    A list with one label per id, in the order given.
  """
  names = []
  for class_id in ids:
    names.append(labels[str(class_id + 1)])
  return names

In [None]:
# Driver cell: benchmarks the native graph and each enabled TensorRT variant,
# then prints the top-N labels predicted by every variant that ran.

# Last output of each variant; stays None for variants that are disabled.
valnative=None
valfp32=None
valfp16=None
valint8=None
# NOTE(review): `res` is not read anywhere in this cell.
res=[None,None,None,None]

print("Starting at",datetime.datetime.now())

# Optionally re-import the frozen model and rewrite the file in place.
if config["update_graphdef"]:
    updateGraphDef(config["frozen_model_file"])
# Random placeholder input; replaced further down by the tiled image batch.
dummy_input = np.random.random_sample((config["batch_size"],224,224,3))
with open("labellist.json","r") as lf:
    labels=json.load(lf)
imageName="grace_hopper.jpg"
# With input_mean=0 and input_std=1.0 the pixel values are left unscaled.
t = read_tensor_from_image_file(imageName,
                              input_height=224,
                              input_width=224,
                              input_mean=0,
                              input_std=1.0)
# NOTE(review): `tshape` is computed but not used afterwards in this cell.
tshape=list(t[0].shape)
tshape[0]=config["batch_size"]
# Repeat the first returned array along axis 0 to fill a whole batch.
tnhwcbatch=np.tile(t[0],(config["batch_size"],1,1,1))
dummy_input=tnhwcbatch
# Shift left by 20 bits (multiply by 2**20); passed on as the workspace size in bytes.
wsize=config["workspace_size"]<<20
timelineName=None
# timeGraph always returns None as its fourth value, so every "...RS"
# printStats call below returns without printing.
if config["native"]:
    if config["with_timeline"]: timelineName="NativeTimeline.json"
    timings,comp,valnative,mdstats=timeGraph(getResnet50(),config["batch_size"],
                                 config["num_loops"],dummy_input,timelineName)
    printStats("Native",timings,config["batch_size"])
    printStats("NativeRS",mdstats,config["batch_size"])
    
if config["FP32"]:
    if config["with_timeline"]: timelineName="FP32Timeline.json"
    timings,comp,valfp32,mdstats=timeGraph(getFP32(config["batch_size"],wsize),config["batch_size"],config["num_loops"],
                               dummy_input,timelineName)
    printStats("TRT-FP32",timings,config["batch_size"])
    printStats("TRT-FP32RS",mdstats,config["batch_size"])
    
if config["FP16"]:
    # NOTE(review): `k` is not used afterwards in this cell.
    k=0
    if config["with_timeline"]: timelineName="FP16Timeline.json"
    timings,comp,valfp16,mdstats=timeGraph(getFP16(config["batch_size"],wsize),config["batch_size"],
                                   config["num_loops"],dummy_input,timelineName)
    printStats("TRT-FP16",timings,config["batch_size"])
    printStats("TRT-FP16RS",mdstats,config["batch_size"])
    
if config["INT8"]:
    calibGraph=getINT8CalibGraph(config["batch_size"],wsize)
    print("Running Calibration")
    # Calibration pass: a single timed loop and no timeline output.
    timings,comp,_,mdstats=timeGraph(calibGraph,config["batch_size"],1,dummy_input)
    print("Creating inference graph")
    int8Graph=getINT8InferenceGraph(calibGraph)
    del calibGraph
    if config["with_timeline"]: timelineName="INT8Timeline.json"
    timings,comp,valint8,mdstats=timeGraph(int8Graph,config["batch_size"],
                                   config["num_loops"],dummy_input,timelineName)
    printStats("TRT-INT8",timings,config["batch_size"])
    printStats("TRT-INT8RS",mdstats,config["batch_size"])
# NOTE(review): `vals` is not read afterwards in this cell.
vals=[valnative,valfp32,valfp16,valint8]
# (enabled flag, display name, last output) for each variant.
enabled=[(config["native"],"native",valnative),
       (config["FP32"],"FP32",valfp32),
       (config["FP16"],"FP16",valfp16),
       (config["INT8"],"INT8",valint8)]
print("Done timing",datetime.datetime.now())
# Print the labels of the topN highest-scoring classes of the first batch row.
for i in enabled:
    if i[0]:
        print(i[1],getLabels(labels,topX(i[2],config["topN"])[1][0]))

Starting at 2018-07-30 04:11:28.817866
INFO:tensorflow:Starting execution
INFO:tensorflow:Starting Warmup cycle
INFO:tensorflow:Warmup done. Starting real timing
iter  0   0.10343759059906006
iter  1   0.10354409217834473
iter  2   0.10343996047973633
iter  3   0.10382778644561767
iter  4   0.1024544334411621
iter  5   0.10323407173156739
iter  6   0.10393102169036865
iter  7   0.10415175437927246
iter  8   0.10453024864196778
iter  9   0.103573637008667
iter  10   0.10539804458618164
iter  11   0.10416852951049804
iter  12   0.1039428472518921
iter  13   0.10392391681671143
iter  14   0.1047536039352417
iter  15   0.10394089698791503
iter  16   0.10423333168029786
iter  17   0.1036576795578003
iter  18   0.10333051681518554
iter  19   0.10330786228179932
Comparison= True
INFO:tensorflow:Timing loop done!
images/s : 1232.7 +/- 7.2, s/batch: 0.10384 +/- 0.00061
RES, Native, 128, 1232.68, 7.24, 0.10384, 0.00061
INFO:tensorflow:Starting execution
INFO:tensorflow:Starting Warmup cycle
INFO