# RoadSafety

## Importing Modules

Importing libraries installed using PyPI and functions present in scripts created in for this project.

In [2]:
import tensorflow as tf
import numpy as np
from PIL import Image

## Executing Pre-Trained TensorFlow Model

In [11]:
def load_graph(frozen_graph_filename):
    
    """
    Args:
        frozen_graph_filename (str): Full path to the .pb file.
    """
    # We load the protobuf file from the disk and parse it to retrieve the unserialized graph_def
    with tf.io.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())

    # Then, we import the graph_def into a new Graph and returns it
    with tf.Graph().as_default() as graph:
        # The name var will prefix every op/nodes in your graph
        # Since we load everything in a new graph, this is not needed
        tf.import_graph_def(graph_def, name="prefix")
        return graph


def segment(graph, image_file):
    """
    Does the segmentation on the given image.
    Args:
        graph (Tensorflow Graph)
        image_file (str): Full path to your image
    Returns:
        segmentation_mask (np.array): The segmentation mask of the image.
    """
    # We access the input and output nodes
    x = graph.get_tensor_by_name('prefix/ImageTensor:0')
    y = graph.get_tensor_by_name('prefix/SemanticPredictions:0')

    # We launch a Session
    with tf.compat.v1.Session(graph=graph) as sess:
        image = Image.open(image_file)
        image = image.resize((10, 10)) # 299
        image_array = np.array(image)
        image_array = np.expand_dims(image_array, axis=0)

        # Note: we don't nee to initialize/restore anything
        # There is no Variables in this graph, only hardcoded constants
        pred = sess.run(y, feed_dict={x: image_array})

        pred = pred.squeeze()
        
        print(pred)

    return pred

def get_n_rgb_colors(n):
    """
    n = 30 (number of classes)
    Get n evenly spaced RGB colors.
    Returns:
        rgb_colors (list): List of RGB colors.
    """
    max_value = 16581375  # 255**3
    interval = int(max_value / n)
    colors = [hex(I)[2:].zfill(6) for I in range(0, max_value, interval)]

    rgb_colors = [(int(i[:2], 16), int(i[2:4], 16), int(i[4:], 16)) for i in colors]
    
    print(rgb_colors)

    return rgb_colors

def parse_pred(pred, n_classes):
    """
    Parses a prediction and returns the prediction as a PIL.Image.
    Args:
        pred (np.array)
    Returns:
        parsed_pred (PIL.Image): Parsed prediction that we can view as an image.
    """
    uni = np.unique(pred)

    empty = np.empty((pred.shape[1], pred.shape[0], 3))

    colors = get_n_rgb_colors(n_classes)

    for i, u in enumerate(uni):
        idx = np.transpose(np.nonzero(pred == u))
        c = colors[u]
        empty[idx[:, 0], idx[:, 1]] = [c[0], c[1], c[2]]

    parsed_pred = np.array(empty, dtype=np.uint8)
    parsed_pred = Image.fromarray(parsed_pred)

    return parsed_pred


if __name__ == '__main__':
    N_CLASSES = 30
    MODEL_FILE = '/Users/luisrita/PycharmProjects/RoadSafety/cityscapes/frozen_inference_graph.pb'
    IMAGE_FILE = '/Users/luisrita/PycharmProjects/RoadSafety/img/london.png'

    graph = load_graph(MODEL_FILE)
    prediction = segment(graph, IMAGE_FILE)
    segmented_image = parse_pred(prediction, N_CLASSES)

    segmented_image.show()

[[2 2 2 7 7 7 7 7 7 7]
 [2 2 2 7 7 7 7 7 7 7]
 [8 7 7 7 7 7 7 7 7 7]
 [7 7 7 2 2 2 7 7 7 7]
 [7 7 7 2 2 8 8 8 8 8]
 [7 7 2 2 2 2 2 8 7 7]
 [7 7 2 2 2 2 2 7 7 7]
 [7 2 2 2 2 2 2 7 7 7]
 [2 2 2 2 2 2 2 7 7 7]
 [2 2 2 2 2 2 2 2 7 2]]
[(0, 0, 0), (8, 111, 8), (16, 222, 16), (25, 77, 24), (33, 188, 32), (42, 43, 40), (50, 154, 48), (59, 9, 56), (67, 120, 64), (75, 231, 72), (84, 86, 80), (92, 197, 88), (101, 52, 96), (109, 163, 104), (118, 18, 112), (126, 129, 120), (134, 240, 128), (143, 95, 136), (151, 206, 144), (160, 61, 152), (168, 172, 160), (177, 27, 168), (185, 138, 176), (193, 249, 184), (202, 104, 192), (210, 215, 200), (219, 70, 208), (227, 181, 216), (236, 36, 224), (244, 147, 232), (253, 2, 240)]
