### Import Modules

In [None]:
#general modules
from os import listdir
from os.path import isfile, join
import time

#installed modules
import numpy as np
import tensorflow as tf
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

### Specify model and labels paths

In [None]:
# Locations of the frozen model graph and its label file (one label per line).
modelFullPath, labelsFullPath = (
    '/Users/mrubashkin/Desktop/output_graph.pb',
    '/Users/mrubashkin/Desktop/output_labels.txt',
)

In [None]:
def create_graph():
    """Import the frozen GraphDef at ``modelFullPath`` into the default graph.

    The nodes are added to whatever graph is currently TensorFlow's default,
    with no name prefix. Nothing is returned.
    """
    with tf.gfile.FastGFile(modelFullPath, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # name='' keeps the original node names (e.g. 'final_result') unprefixed.
    tf.import_graph_def(graph_def, name='')

def create_and_persist_graph():
    """Load the frozen GraphDef at ``modelFullPath`` into a new graph.

    Returns:
        tf.Graph: a dedicated graph holding the imported nodes under their
        original names. Pass it as ``graph=`` to ``tf.Session`` to run it.
    """
    graph = tf.Graph()
    with tf.gfile.FastGFile(modelFullPath, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # Graph.as_default() returns a context manager; it only redirects
    # graph construction while inside the `with`. Importing into a private
    # graph means re-running this cell does not add another copy of the
    # model to the shared default graph.
    with graph.as_default():
        tf.import_graph_def(graph_def, name='')
    return graph
        
def run_inference_on_image(image_path=None, num_top_predictions=6):
    """Classify one JPEG with the persisted graph and return label scores.

    Args:
        image_path: path of the image to classify. When omitted, the
            module-level ``imagePath`` global is used (the classification
            loop sets it before each call).
        num_top_predictions: how many of the highest-scoring labels to keep.
            The plotting cell reads six class labels from every result, so
            the default keeps six.

    Returns:
        dict mapping label string -> score from the 'final_result' softmax
        tensor, for the top ``num_top_predictions`` labels; or None when the
        image file does not exist.
    """
    if image_path is None:
        image_path = imagePath

    if not tf.gfile.Exists(image_path):
        tf.logging.fatal('File does not exist %s', image_path)
        return None

    with tf.gfile.FastGFile(image_path, 'rb') as image_file:
        image_data = image_file.read()

    # Text mode yields str on both Python 2 and 3 (str() of a bytes line
    # would produce "b'...'" on Python 3 and break the label keys).
    with open(labelsFullPath, 'r') as labels_file:
        labels = [line.replace('\n', '') for line in labels_file]

    with tf.Session(graph=persisted_result) as sess:
        softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
        predictions = sess.run(softmax_tensor,
                               {'DecodeJpeg/contents:0': image_data})
    predictions = np.squeeze(predictions)

    # argsort is ascending: reverse it so the best score comes first, then
    # keep the first N (an N of 0 correctly yields no labels).
    top_k = predictions.argsort()[::-1][:num_top_predictions]
    return {labels[node_id]: predictions[node_id] for node_id in top_k}

### Creates graph from saved GraphDef.

In [None]:

# Load the model graph once; run_inference_on_image runs its sessions on it.
persisted_result = create_and_persist_graph()

In [None]:
image_dir = '/Users/mrubashkin/Desktop/trains_to_classify/'
# os.listdir returns entries in arbitrary order, but results are plotted
# against "Frame Number", so sort by filename. The sort is lexicographic,
# so frame numbers in filenames must be zero-padded to sort in frame order.
# Hidden files (e.g. macOS .DS_Store) are skipped: they are not images and
# would fail JPEG decoding.
image_files = sorted(
    f for f in listdir(image_dir)
    if not f.startswith('.') and isfile(join(image_dir, f))
)

In [None]:
output = []
print('Classifying image in path: ' + image_dir)
for image in image_files:
    # run_inference_on_image reads the global imagePath, so set it first.
    imagePath = image_dir + image

    # Time each classification end to end and keep its score dict.
    start_time = time.time()
    classification_scores = run_inference_on_image()
    print(classification_scores)
    output.append(classification_scores)
    elapsed = time.time() - start_time
    print('Classification of %s took %s seconds \n' % (image, elapsed))

In [None]:
# Per-class score series, one entry per classified frame, for graphing.
trucks = []
cars = []
empty_road_rails = []
caltrain = []
freight_train = []
light_rail = []

# (destination series, label key in each classification dict)
_series_keys = (
    (trucks, 'trucks'),
    (cars, 'cars'),
    (caltrain, 'caltrain'),
    (freight_train, 'freight train'),
    (light_rail, 'light rail'),
    (empty_road_rails, 'empty road rails'),
)
for data_point in output:
    for _series, _key in _series_keys:
        _series.append(data_point[_key])

In [None]:
fig = plt.figure(num=None, figsize=(10, 5), dpi=80, facecolor='w', edgecolor='k')

plt.rcParams.update({'font.size': 14})
# NOTE(review): the x-range is hard-coded to 215 frames; adjust for other inputs.
plt.xlim([0, 215])
plt.xlabel('Frame Number', fontsize=18)
plt.ylabel('Probability', fontsize=18)

# One line per class: (per-frame scores, legend label, line colour).
# Passing color/label on each line replaces Axes.set_color_cycle (deprecated
# since matplotlib 1.5 and since removed) and keeps colour, series and legend
# entry tied together rather than relying on call order.
for _scores, _label, _colour in (
    (trucks, 'Trucks', 'orange'),
    (cars, 'Cars', 'gray'),
    (caltrain, 'Caltrain', 'red'),
    (freight_train, 'Freight Train', 'green'),
    (light_rail, 'Light Rail', 'blue'),
    (empty_road_rails, 'No Vehicle', 'black'),
):
    plt.plot(_scores, color=_colour, linewidth=1.5, label=_label)

# Anchor the legend outside the axes, to the right of the plot area.
plt.legend(bbox_to_anchor=(1.4, 1.025))

fig.show()
fig.savefig('/Users/mrubashkin/Desktop/test.png', dpi=fig.dpi, bbox_inches='tight')

In [None]:
# Plot image using Pillow and matplotlib: shows the last file classified by
# the loop above, since it leaves imagePath pointing at that file.
img = Image.open(imagePath)
# thumbnail() resizes in place to fit within 64x64, preserving aspect ratio.
# LANCZOS is the filter ANTIALIAS aliased; ANTIALIAS was removed in Pillow 10.
img.thumbnail((64, 64), Image.LANCZOS)
# Create the figure before drawing so the image lands on it (the original
# drew first and then opened an empty figure, and called an undefined mpimg).
plt.figure()
imgplot = plt.imshow(img)
plt.show()