In [None]:
"A client to call the model server and get predictions. Supports both gRPC and REST modes of communication"

In [None]:
import glob
import json
import os

import cv2
import grpc
import numpy as np
import requests
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis.prediction_service_pb2_grpc import PredictionServiceStub

In [None]:
import matplotlib.image as mpimg
# This is needed to display the images.
%matplotlib inline
import matplotlib.pyplot as plt

def visualize(**images):
    """
    Plot the given images side by side in a single row.

    :param images: keyword arguments mapping a panel name to an image; the name
        (underscores turned into spaces, title-cased) becomes the panel title.
    :return: None, the figure is shown with ``plt.show()``
    """
    total = len(images)
    plt.figure(figsize=(64, 20))
    for position, (name, image) in enumerate(images.items(), start=1):
        plt.subplot(1, total, position)
        # Axis ticks carry no information for image panels.
        plt.xticks([])
        plt.yticks([])
        plt.title(name.replace('_', ' ').title())
        plt.imshow(image)
    plt.show()

In [None]:
# Configuration: data locations and TensorFlow Serving connection settings.

# NOTE: these are absolute paths on the original author's machine; change them before running elsewhere.
# IMG_DIR is globbed for *.png inputs by main(); LABEL_DIR must hold a label file with the same name per image.
IMG_DIR = "/Users/magarwal/office/satellite-road-extraction-service/road_extraction_outputs/test_results/images"
LABEL_DIR = "/Users/magarwal/office/satellite-road-extraction-service/road_extraction_outputs/test_results/labels"
# Only referenced by the (commented-out) save_image calls in infer_and_evaluate.
OUTPUT_PATH = "/Users/magarwal/office/satellite-road-extraction-service/road_extraction_outputs/test_results"

# Model server address; get_tf_serving_client picks the gRPC or REST port depending on the mode.
TF_SERVING_HOST = "localhost"
TF_SERVING_GRPC_PORT = "8500"
TF_SERVING_REST_PORT = "8501"

# Signature name sent with every predict request (gRPC and REST).
MODEL_SIGNATURE_NAME = 'serving_default'

In [None]:
def get_tf_serving_client(mode='grpc'):
    """
    Build a TFServingClient pointed at the configured model server.

    :param mode: 'grpc' selects the gRPC port; any other value selects the REST port
    :type mode: str
    :return: a client bound to TF_SERVING_HOST and the port matching ``mode``
    :rtype: TFServingClient
    """
    port = TF_SERVING_GRPC_PORT if mode == 'grpc' else TF_SERVING_REST_PORT
    return TFServingClient(TF_SERVING_HOST, port, mode)

In [None]:
class TFServingClient:
    """
    Client for a TensorFlow Serving model server.

    Supports two transports, selected by ``mode``: 'grpc' (PredictionService stub)
    and anything else, which is treated as REST (JSON over HTTP).
    """

    # Per-request timeout in seconds, applied to both the gRPC and the REST call.
    REQUEST_TIMEOUT_SECONDS = 300

    def __init__(self, host: str, port: str, mode='grpc'):
        """
        :param host: host name or IP of the model server
        :type host: str
        :param port: port of the gRPC or REST endpoint, matching ``mode``
        :type port: str
        :param mode: 'grpc' for gRPC; any other value uses the REST API
        :type mode: str
        """
        self.host = host
        self.port = port
        self.mode = mode
        # Always defined so that _predict_grpc can tell "not connected" from "connected".
        self.channel = None
        self.stub = None
        if self.mode == 'grpc':
            self.stub = self._get_stub()

    def predict(self, image: np.ndarray, model_name: str, version: int = None, **kwargs) -> np.ndarray:
        """
        Takes image batch, model name, and optionally the model version and returns the parsed predictions.
        :param image: A batch of images
        :param model_name: Name of the model to call
        :param version: Version number of the model. Leave blank for latest.
        :param kwargs: forwarded to the gRPC path only (currently unused there)
        :return: Parsed response of the model
        """
        if self.mode == 'grpc':
            result = self._predict_grpc(image, model_name, version, **kwargs)
        else:
            result = self._predict_rest(image, model_name, version)
        return result

    def _predict_grpc(self, image: np.ndarray, model_name: str, version: int = None, **kwargs):
        """
        Connect to the model server over a gRPC channel and return the predictions. Prefer this over the REST API.

        The channel is closed after every request (also when the request fails) and is
        transparently re-opened on the next call, so one client can serve many requests.

        :param image: A batch of images; sent as a float32 tensor of shape [1, 512, 512, 3]
        :param model_name: Name of the model to call; only "road_segmentation" is supported
        :param version: Version number of the model. Leave blank for latest.
        :return: Parsed response of the model
        :raises ValueError: if ``model_name`` is not supported
        """
        # Fail loudly instead of silently returning None for a model we cannot parse.
        if model_name != "road_segmentation":
            raise ValueError(f"Unsupported model for gRPC prediction: {model_name!r}")

        request = predict_pb2.PredictRequest()
        request.model_spec.name = model_name
        request.model_spec.signature_name = MODEL_SIGNATURE_NAME
        if version is not None:
            request.model_spec.version_value = version

        # tf.make_tensor_proto is the same helper as tf.contrib.util.make_tensor_proto,
        # but it also exists in TensorFlow 2.x where tf.contrib was removed.
        input_proto = tf.make_tensor_proto(image, dtype='float32', shape=[1, 512, 512, 3])
        request.inputs["input_1"].CopyFrom(input_proto)

        # Re-open the channel if a previous request closed it.
        if self.stub is None:
            self.stub = self._get_stub()
        try:
            resp = self.stub.Predict(request, self.REQUEST_TIMEOUT_SECONDS)
        finally:
            # Release the channel even if the RPC raised.
            self.channel.close()
            self.channel = None
            self.stub = None
        return self._parse_resp_road_segmentation(resp)

    def _predict_rest(self, image: np.ndarray, model_name: str, version: int = None) -> np.ndarray:
        """
        Get predictions using the model server's REST API.
        :param image: A batch of images
        :param model_name: Name of the model to call
        :param version: Version number of the model. Leave blank for latest.
        :return: Parsed response of the model
        """
        data = json.dumps({
            'signature_name': MODEL_SIGNATURE_NAME,
            'instances': image.tolist()
        })
        if version is None:
            url = f'http://{self.host}:{self.port}/v1/models/{model_name}:predict'
        else:
            url = f'http://{self.host}:{self.port}/v1/models/{model_name}/versions/{version}:predict'
        headers = {'Content-type': 'application/json'}
        # Without a timeout requests would wait forever on an unresponsive server.
        resp = requests.post(url, data=data, headers=headers, timeout=self.REQUEST_TIMEOUT_SECONDS)
        return self._parse_resp(resp)

    def _parse_resp(self, resp) -> np.ndarray:
        """
        Parse the response of a REST predict call.
        :param resp: HTTP response object exposing ``raise_for_status()`` and ``json()``
        :return: the "predictions" field of the JSON body as a float32 array
        :raises ValueError: if the JSON body has no "predictions" field
        """
        # Surface HTTP errors (4xx/5xx) instead of trying to parse an error body.
        resp.raise_for_status()
        payload = resp.json()
        if 'predictions' not in payload:
            raise ValueError(f"Response contains no 'predictions' field: {payload!r}")
        return np.asarray(payload['predictions'], dtype=np.float32)

    def _get_stub(self) -> 'PredictionServiceStub':
        """
        If gRPC is chosen, this method establishes a channel and returns a prediction stub.
        The channel is stored on ``self.channel`` so it can be closed later.
        :return: A Tensorflow_serving API prediction stub
        """
        # Raise the receive limit (4096 * 4096 * 5 bytes) so large prediction tensors fit.
        self.channel = grpc.insecure_channel(f'{self.host}:{self.port}',
                                             options=[('grpc.max_receive_message_length', 4096 * 4096 * 5)])
        return PredictionServiceStub(self.channel)

    def _parse_resp_road_segmentation(self, resp):
        """
        Parse the response for grpc.
        :param resp: PredictResponse whose "sigmoid" output holds 512 * 512 float values
        :return: float32 array of shape (512, 512)
        """
        result = np.array(resp.outputs["sigmoid"].float_val).reshape((512, 512)).astype(np.float32)
        return result


In [None]:
# Utils

def get_filename(file):
    """
    Return the final component (file name) of a path.

    Uses ``os.path.basename`` so the platform's path separators are honoured
    instead of assuming "/".

    :param file: path to a file
    :return: the part of ``file`` after the last path separator
    """
    return os.path.basename(file)

def read_img(filepath):
    """
    Read an image from disk without any colour/depth conversion (cv2.IMREAD_UNCHANGED).

    :param filepath: path of the image file
    :return: the decoded image array
    :raises OSError: if OpenCV could not read an image from ``filepath``
    """
    image = cv2.imread(filepath, cv2.IMREAD_UNCHANGED)
    # cv2.imread signals failure by returning None rather than raising.
    if image is None:
        raise OSError(f"Could not read image: {filepath}")
    return image

def save_image(image, path):
    """
    Resize an image to 1500x1500 (area interpolation) and write it to disk.

    :param image: image array to save
    :param path: destination file path
    :return: None
    """
    target_size = (1500, 1500)
    cv2.imwrite(path, cv2.resize(image, target_size, interpolation=cv2.INTER_AREA))


In [None]:
# IMAGE OPERATIONS

def crop_img_label(img, mask):
    """
    Resize an image and its mask to 1536x1536 and cut both into a 3x3 grid of 512x512 tiles.

    :param img: input image
    :param mask: label mask belonging to ``img``
    :return: (image tiles, mask tiles), each a list of 9 tiles in row-major order
    """
    tile, grid = 512, 3
    side = tile * grid
    image = cv2.resize(img, (side, side), interpolation=cv2.INTER_AREA)
    label = cv2.resize(mask, (side, side), interpolation=cv2.INTER_AREA)

    # Top-left corner of every tile, rows first, then columns.
    corners = [(r * tile, c * tile) for r in range(grid) for c in range(grid)]
    img_arr = [image[top:top + tile, left:left + tile] for top, left in corners]
    label_arr = [label[top:top + tile, left:left + tile] for top, left in corners]
    return img_arr, label_arr

def merge_imgs(images):
    """
    Stitch nine tiles (row-major order) back into one 3x3 mosaic.

    :param images: sequence whose first nine entries are equally sized tiles
    :return: a single array with tiles 0-2 in the top row, 3-5 in the middle, 6-8 at the bottom
    """
    rows = [
        np.concatenate((images[start], images[start + 1], images[start + 2]), axis=1)
        for start in (0, 3, 6)
    ]
    return np.concatenate(rows, axis=0)

def convert_images_dtype(images):
    """
    Convert a single image or a batch of images to float32 via
    ``tf.image.convert_image_dtype`` (uint8 input is expected to end up in the 0-1 range).

    :param images: image of rank 3 or batch of images of rank 4
    :return: float32 tensor with the same shape as ``images``
    :raises ValueError: if ``images`` has neither 3 nor 4 dimensions
    """
    images_shape = images.shape
    # The error was previously constructed but never raised, and ranks below 3 slipped through.
    if len(images_shape) not in (3, 4):
        raise ValueError("'image' must have either 3 or 4 dimensions, "
                         "received `{}`.".format(images_shape))

    # The conversion is element-wise, so a batch can be converted in one call;
    # this avoids tf.map_fn, which needs an explicit output dtype when the dtype changes.
    return tf.image.convert_image_dtype(images, dtype=tf.float32)


def smoothen_detection(mask, dilate_iter=2):
    """
    Smooth a mask with a dilate / erode / dilate sequence followed by a 3x3 median blur.

    :param mask: mask image to smooth
    :type mask: np.ndarray
    :param dilate_iter: iterations for each dilation; the erosion in between runs twice as many
    :type dilate_iter: int
    :return: the smoothed mask
    :rtype: np.ndarray
    """
    structuring = np.ones((3, 3), np.uint8)
    grown = cv2.dilate(mask, structuring, iterations=dilate_iter)
    shrunk = cv2.erode(grown, structuring, iterations=dilate_iter * 2)
    regrown = cv2.dilate(shrunk, structuring, iterations=dilate_iter)
    return cv2.medianBlur(regrown, 3)


def remove_small_connected_objects(img, min_size):
    """
    Drop small connected patches from a predicted mask.

    :param img: mask to clean; components are found with 8-connectivity
    :param min_size: minimum number of pixels a patch needs in order to be kept
    :return: uint8 mask of the same height/width with kept patches set to 255, everything else 0
    """
    label_count, label_map, stats, _ = cv2.connectedComponentsWithStats(img, connectivity=8)
    # Last stats column is the pixel count; row 0 is the background label, which is never kept.
    areas = stats[1:, -1]

    # Lookup table from label id to output value, applied to the whole label map at once.
    value_for_label = np.zeros(label_count, np.uint8)
    value_for_label[1:][areas >= min_size] = 255
    return value_for_label[label_map]


def expand_dims(image_, axis=0):
    """
    Insert a new axis of length one into an array (thin wrapper over ``np.expand_dims``).

    :param image_: array to expand
    :type image_: np.ndarray
    :param axis: position of the new axis; 0 adds a leading batch axis
    :type axis: int
    :return: the array with one extra dimension
    :rtype: np.ndarray
    """
    return np.expand_dims(image_, axis)

In [None]:
# Evaluation Metrics Utils

def calculate_precision(gt, pred):
    """
    Precision = TP / (TP + FP), i.e. overlapping pixels divided by all predicted pixels.

    :param gt: binary ground-truth mask
    :param pred: binary predicted mask
    :return: precision, or None when nothing was predicted (sum of ``pred`` is 0)
    """
    true_positives = np.logical_and(gt, pred)
    # Every predicted pixel is either a true or a false positive.
    predicted_total = np.sum(pred)
    if predicted_total == 0:
        return None
    return np.sum(true_positives) / predicted_total


def calculate_recall(gt, pred):
    """
    Recall = TP / (TP + FN), i.e. overlapping pixels divided by all ground-truth pixels.

    :param gt: binary ground-truth mask
    :param pred: binary predicted mask
    :return: recall, or None when the ground truth is empty (sum of ``gt`` is 0)
    """
    true_positives = np.logical_and(gt, pred)
    # Every ground-truth pixel is either a true positive or a false negative.
    actual_total = np.sum(gt)
    if actual_total == 0:
        return None
    return np.sum(true_positives) / actual_total


def calculate_iou(gt, pred):
    """
    Intersection over union of two binary masks.

    :param gt: binary ground-truth mask
    :param pred: binary predicted mask
    :return: IoU, or None when both masks are empty (union is 0)
    """
    overlap = np.logical_and(gt, pred)
    combined = np.logical_or(gt, pred)
    combined_total = np.sum(combined)
    # Both masks all-black: the score is undefined.
    if combined_total == 0:
        return None
    return np.sum(overlap) / combined_total


def calculate_F1(precision, recall):
    """
    F1 score: harmonic mean of precision and recall, 2 * P * R / (P + R).

    :param precision: precision value, or None when it is undefined
    :param recall: recall value, or None when it is undefined
    :return: the F1 score, or None when either input is None or both are 0
    """
    # calculate_precision / calculate_recall return None for undefined scores;
    # F1 is then undefined too (previously this raised a TypeError).
    if precision is None or recall is None:
        return None
    summation = precision + recall
    if summation == 0:
        return None
    return (2 * precision * recall) / summation


def cal_mean(acc_list):
    """
    Mean of the defined (non-None) values in a list of scores.

    :param acc_list: list of scores; None entries mark undefined scores and are skipped
    :return: mean of the remaining values, or None when there are none
        (avoids a 0 / 0 division for empty or all-None input)
    """
    not_none_values = [val for val in acc_list if val is not None]
    if not not_none_values:
        return None
    return np.mean(not_none_values)

In [None]:
def infer_and_evaluate(img, mask, file_name):
    """
    Run inference on input image and evaluate accuracy parameters.

    The image is split into nine 512x512 tiles, each tile is sent to the
    "road_segmentation" model, the thresholded tile predictions are stitched
    back together and compared against ``mask``.

    :param img: input image (pixel values are divided by 255 before inference)
    :param mask: ground-truth mask; any pixel > 0 counts as positive
    :param file_name: name used in the printed report
    :return: tuple (iou, recall, precision, f1_score); an entry is None when that score is undefined
    """
    # Only the image tiles are needed; metrics are computed against the full-size mask.
    tiles, _ = crop_img_label(img, mask)

    pred_arr = []

    for tile in tiles:
        inp_img = np.multiply(tile, 1 / 255.0)
        batched_img = expand_dims(inp_img)

        # A fresh client per tile: the gRPC client closes its channel after each request.
        pred = get_tf_serving_client().predict(batched_img, "road_segmentation")
        pred = np.reshape(pred, (512, 512))

        # Binarise at 0.5, then drop patches smaller than 40 pixels.
        final_img = np.zeros((512, 512), np.uint8)
        final_img[pred > 0.5] = 255
        final_img = remove_small_connected_objects(final_img, 40)
        pred_arr.append(final_img)

    merged_pred = merge_imgs(pred_arr)
    # Resize back to the mask's own size (instead of a fixed 1500x1500) so that
    # prediction and ground truth always line up; cv2.resize expects (width, height).
    mask_height, mask_width = mask.shape[:2]
    merged_pred = cv2.resize(merged_pred, (mask_width, mask_height), interpolation=cv2.INTER_AREA)

    ## comment out the call below to skip visualising the inference data
    visualize(
            img = img,
            pred = merged_pred,
            mask = mask
        )

#     save_image(cv2.cvtColor(img, cv2.COLOR_RGB2BGR), f"{OUTPUT_PATH}/images/{file_name}")
#     save_image(merged_pred, f"{OUTPUT_PATH}/pred/{file_name}")
#     save_image(mask, f"{OUTPUT_PATH}/labels/{file_name}")

    # converting gt, pred into binary (0/1) masks.
    gt = np.where(mask > 0, 1, 0).astype('uint8')
    pred = np.where(merged_pred > 0, 1, 0).astype('uint8')

    iou = calculate_iou(gt, pred)
    recall = calculate_recall(gt, pred)
    precision = calculate_precision(gt, pred)
    f1_score = calculate_F1(precision, recall)

    print(f"for {file_name} - iou: {iou}, recall: {recall}, precision: {precision}, f1_score: {f1_score}")

    return iou, recall, precision, f1_score

In [None]:
def main():
    """
    Evaluate every .png image in IMG_DIR against the label of the same name in
    LABEL_DIR and print the per-image and mean IoU, recall, precision and F1 score.
    """
    # Sorted so that the per-image report comes out in the same order on every run.
    images = sorted(glob.glob(f"{IMG_DIR}/*.png"))
    if not images:
        # Nothing to average; avoid computing means over empty lists.
        print(f"No .png images found in {IMG_DIR}")
        return

    iou_list, recall_list, precision_list, f1_list = [], [], [], []
    for img_path in images:
        file_name = get_filename(img_path)
        label_path = f"{LABEL_DIR}/{file_name}"

        # OpenCV loads colour images as BGR; the rest of the pipeline works on RGB.
        bgr_img = read_img(img_path)
        rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
        label = read_img(label_path)

        iou, recall, precision, f1_score = infer_and_evaluate(rgb_img, label, file_name)
        iou_list.append(iou)
        recall_list.append(recall)
        precision_list.append(precision)
        f1_list.append(f1_score)

    mean_iou = cal_mean(iou_list)
    mean_recall = cal_mean(recall_list)
    mean_precision = cal_mean(precision_list)
    mean_f1_score = cal_mean(f1_list)
    print("*******Results*******")
    print(f"Mean - iou: {mean_iou}, recall: {mean_recall}, precision: {mean_precision}, f1_score: {mean_f1_score}")

In [None]:
# Run the full evaluation when this notebook/script is executed directly.
if __name__ == '__main__':
    main()