<img src="./img/hpe_logo.png" alt="HPE Logo" width="300">

<h1>Request Prediction from KServe InferenceService</h1>

<h5>Date: 07/26/23</h5>
<h5>Version: 1.0</h5>
<h5>Author(s): andrew.mendez@hpe.com</h5>



<img src="./img/platform_step0.png" alt="Enterprise Machine Learning platform architecture" width="850">

<h3>Import modules and define functions</h3>
The cell below imports the modules and libraries used throughout the demo; PyTorch and torchvision are imported together with the helper functions further down.

In [1]:
# %pip install -q ipywidgets  # uncomment if ipywidgets is not installed in this kernel's environment

In [14]:
# imports
import sys
import glob
import base64
import json
import requests
import matplotlib.pyplot as plt

from skimage import io
from PIL import Image, ImageDraw
from ipywidgets import interact, interactive
import ipywidgets as widgets
import io

from tqdm import tqdm
from multiprocessing import Pool

<h3>Step 1: Set up connection details for KServe and define the image directory</h3>

In [15]:
# Glob pattern matching the .jpg images that will be sent to the model for prediction.
images = r"./../e2e_blogposts/ngc_blog/xview_dataset/train_images_rgb_no_neg_filt_32/train_images_640_02_filt_32/*.jpg"
endpoint_name = 'detection-deploy'
model_name = 'sat-detection'
# Both the host we connect to and the HTTP Host header sent by `predict`
# follow the pattern <endpoint_name>.<domain>, so build it once.
KSERVE_DOMAIN = "models.mlds-kserve.us.rdlabs.hpecorp.net"
ingress_host = "{}.{}".format(endpoint_name, KSERVE_DOMAIN)
ingress_port = "80"
service_hostname = ingress_host
print(ingress_host)

detection-deploy.models.mlds-kserve.us.rdlabs.hpecorp.net


In [40]:
import torchvision
import torch
# Module-level flag that run_nms sets once it has seen a response spanning three classes.
# (A `global` statement at module scope is a no-op, so none is needed here.)
is3classes = False

def run_nms(pred_d, iou_threshold=0.2):
    '''
    Apply class-aware non-maximum suppression to one image's predictions.

    Parameters
    ----------
    pred_d : list of dict
        Entries of the form ``{<class name>: [x1, y1, x2, y2], 'score': float}``,
        as found in ``response['predictions'][0]``.
    iou_threshold : float, optional
        Boxes of the same class overlapping a higher-scoring box by more than
        this IoU are discarded. Defaults to 0.2, the previously hard-coded value.

    Returns
    -------
    list of dict
        Surviving predictions in the same ``{<class name>: bbox, 'score': s}``
        format, in the order returned by ``torchvision.ops.batched_nms``
        (decreasing score).
    '''
    global is3classes
    if not pred_d:
        return []

    # Always use the full mapping: 'Fixed-wing Aircraft' and 'Cargo Plane' have the
    # same ids as in the old 2-class map, so results are unchanged. Using it
    # unconditionally avoids a KeyError when 'Small Aircraft' appears together with
    # fewer than two other classes. That case was common in pool workers, which do
    # not share the `is3classes` flag.
    cat2id = {'Fixed-wing Aircraft': 1, 'Cargo Plane': 2, 'Small Aircraft': 3}

    bboxes, scores, classes = [], [], []
    for pred in pred_d:
        # The class name is whichever key is not 'score'; don't rely on dict order.
        cl_name = next(key for key in pred if key != 'score')
        if cl_name not in cat2id:
            # Unknown label: give it its own id so NMS still treats it as a separate class.
            cat2id[cl_name] = max(cat2id.values()) + 1
        classes.append(cat2id[cl_name])
        scores.append(pred['score'])
        bboxes.append(pred[cl_name])

    # Kept for backward compatibility: the flag is still set once three classes are seen.
    if len(set(classes)) == 3:
        is3classes = True

    id2cat = {v: k for k, v in cat2id.items()}
    classes_t = torch.tensor(classes, dtype=torch.int64)
    bboxes_t = torch.tensor(bboxes, dtype=torch.float32).reshape(-1, 4)
    scores_t = torch.tensor(scores, dtype=torch.float32)

    keep = torchvision.ops.batched_nms(bboxes_t, scores_t, classes_t, iou_threshold=iou_threshold)

    return [
        {id2cat[cl]: bbox, 'score': s}
        for cl, bbox, s in zip(
            classes_t[keep].tolist(),
            bboxes_t[keep].tolist(),
            scores_t[keep].tolist(),
        )
    ]
def plot_pred(im, pred_d, thres=0.15):
    '''
    Draw predicted boxes scoring above ``thres`` onto ``im`` and display it.

    Parameters
    ----------
    im : PIL.Image.Image
        Image to annotate. It is modified in place.
    pred_d : dict
        Response from ``predict``; boxes are read from ``pred_d['predictions'][0]``,
        each entry shaped ``{<class name>: [x1, y1, x2, y2], 'score': float}``.
    thres : float, optional
        Only predictions with a score strictly greater than this are drawn.

    Returns
    -------
    PIL.Image.Image
        The annotated image (the same object as ``im``).
    '''
    draw = ImageDraw.Draw(im)
    try:
        for pred in pred_d['predictions'][0]:
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if len(pred) != 2 or 'score' not in pred:
                raise ValueError("Unexpected prediction format: {}".format(pred))
            cl_name = next(key for key in pred if key != 'score')
            x1, y1, x2, y2 = pred[cl_name][:4]
            if pred['score'] > thres:
                draw.rectangle([x1, y1, x2, y2], outline=(255, 0, 0), fill=None, width=1)
                draw.text([x1, y1 - 10], "{} :{:.2f}".format(cl_name, pred['score']), fill=(250, 0, 0))
    except Exception as e:
        # Best effort: report a malformed response but still show the image.
        print(e)
    # Create the figure outside the `try` so the image always gets its own,
    # consistently sized figure, even when the response could not be parsed.
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(im)
    plt.show()
    return im

def predict(args, timeout=60):
    '''
    Base64-encode one image, send it to the KServe v1 ``:predict`` endpoint and
    apply NMS to the returned predictions.

    Parameters
    ----------
    args : sequence
        ``[image_path, ingress_host, ingress_port, model_name, service_hostname]``.
        These are packed into a single sequence so the function can be dispatched
        through ``run_apply_async_multiprocessing``. Extra items are ignored.
    timeout : float, optional
        Seconds to wait for the HTTP response before ``requests.Timeout`` is raised.

    Returns
    -------
    dict
        Parsed JSON response whose ``['predictions'][0]`` has been replaced by the
        output of ``run_nms``.

    Raises
    ------
    requests.HTTPError
        If the service answers with a non-2xx status.
    '''
    image_path, ingress_host, ingress_port, model_name, service_hostname = args[:5]

    # Re-encode as JPEG (as the service has always received). Close the source
    # file promptly; JPEG cannot store alpha/palette modes, so convert those to RGB.
    with Image.open(image_path) as image, io.BytesIO() as buffer:
        if image.mode not in ('RGB', 'L'):
            image = image.convert('RGB')
        image.save(buffer, format='JPEG')
        image_bytes = buffer.getvalue()

    request = {
        "instances": [
            {
                "data": base64.b64encode(image_bytes).decode("utf-8")
            }
        ]
    }

    url = f"http://{ingress_host}:{ingress_port}/v1/models/{model_name}:predict"
    headers = {'Host': service_hostname}
    response = requests.post(url, data=json.dumps(request), headers=headers, timeout=timeout)
    # Surface HTTP failures directly instead of as a KeyError on 'predictions'.
    response.raise_for_status()
    res = response.json()

    res['predictions'][0] = run_nms(res['predictions'][0])
    return res

def visualize(idx, thres=0.15):
    '''
    Show image ``imgs[idx]`` with the boxes from ``resps[idx]`` overlaid.

    Relies on the notebook-level lists ``imgs`` (image paths) and ``resps``
    (responses from ``predict``). These are index-aligned because ``resps`` is
    built by iterating over ``imgs`` in order.

    Parameters
    ----------
    idx : int
        Position in ``imgs`` / ``resps``.
    thres : float, optional
        Minimum score for a box to be drawn (forwarded to ``plot_pred``).
    '''
    if not 0 <= idx < len(resps):
        print("idx {} is out of range: {} responses available".format(idx, len(resps)))
        return
    # Use a context manager so the image file handle is released after plotting.
    with Image.open(imgs[idx]) as im:
        plot_pred(im, resps[idx], thres)
    
def run_apply_async_multiprocessing(func, argument_list, num_processes):
    '''
    Call ``func`` on every item of ``argument_list`` in a process pool, with a tqdm progress bar.

    Tuples in ``argument_list`` are unpacked into positional arguments. Any other
    item (e.g. a list) is passed to ``func`` as a single argument.

    Parameters
    ----------
    func : callable
        Picklable function executed in the worker processes.
    argument_list : iterable
        One entry per call.
    num_processes : int
        Number of worker processes.

    Returns
    -------
    list
        Results in the same order as ``argument_list``. An exception raised in a
        worker is re-raised here by ``job.get()``.
    '''
    # The context manager terminates the workers if anything below raises, so no
    # orphaned processes are left behind.
    with Pool(processes=num_processes) as pool:
        jobs = [
            pool.apply_async(func, argument if isinstance(argument, tuple) else (argument,))
            for argument in argument_list
        ]
        pool.close()  # no further tasks will be submitted
        results = [job.get() for job in tqdm(jobs)]
        pool.join()
    return results

In [17]:
# Sort the matches: glob's order depends on the filesystem, and the prediction
# results in `resps` are indexed by position in this list.
imgs = sorted(glob.glob(images, recursive=True))
assert imgs, "No images matched {!r}".format(images)

In [41]:
# Smoke test: send only the first image before launching the full batch.
predict([imgs[0], ingress_host, ingress_port, model_name, service_hostname])

n_classes:  3


{'predictions': [[{'Fixed-wing Aircraft': [230.6951904296875,
     562.4132080078125,
     291.6294860839844,
     633.3889770507812],
    'score': 0.4123193025588989},
   {'Cargo Plane': [607.5602416992188,
     536.1907958984375,
     631.2140502929688,
     550.3070678710938],
    'score': 0.40741395950317383},
   {'Small Aircraft': [319.1665954589844,
     312.05633544921875,
     347.3628234863281,
     427.73974609375],
    'score': 0.4052237868309021},
   {'Small Aircraft': [285.889404296875,
     358.40582275390625,
     354.6211853027344,
     486.673583984375],
    'score': 0.3940927982330322},
   {'Small Aircraft': [480.6795349121094,
     120.5506820678711,
     586.2841186523438,
     483.2882995605469],
    'score': 0.39343053102493286},
   {'Small Aircraft': [196.9119110107422,
     335.21435546875,
     555.2933959960938,
     640.0],
    'score': 0.3909900188446045},
   {'Small Aircraft': [408.0142517089844,
     148.9851531982422,
     508.7037048339844,
     482.8180

<h3>Step 2: Request predictions from the KServe InferenceService and display the results</h3>

In [42]:
is3classes = False
# Each entry must stay a list (not a tuple) so it is passed to `predict` as one argument.
request_args = [
    [img_path, ingress_host, ingress_port, model_name, service_hostname]
    for img_path in imgs
]
resps = run_apply_async_multiprocessing(predict, request_args, num_processes=4)

  0%|          | 0/33 [00:00<?, ?it/s]

n_classes:  3
n_classes:  3
n_classes:  3
n_classes:  3


  3%|▎         | 1/33 [00:08<04:22,  8.19s/it]

n_classes:  2


 15%|█▌        | 5/33 [00:10<00:46,  1.67s/it]

n_classes:  3


 18%|█▊        | 6/33 [00:12<00:47,  1.77s/it]

n_classes:  3


 21%|██        | 7/33 [00:14<00:48,  1.85s/it]

n_classes:  3


 24%|██▍       | 8/33 [00:16<00:47,  1.92s/it]

n_classes:  3


 27%|██▋       | 9/33 [00:18<00:46,  1.94s/it]

n_classes:  3


 30%|███       | 10/33 [00:20<00:45,  1.98s/it]

n_classes:  3


 33%|███▎      | 11/33 [00:22<00:44,  2.02s/it]

n_classes:  2


 36%|███▋      | 12/33 [00:24<00:42,  2.05s/it]

n_classes:  3


 39%|███▉      | 13/33 [00:27<00:42,  2.12s/it]

n_classes:  3


 42%|████▏     | 14/33 [00:29<00:40,  2.11s/it]

n_classes:  2


 45%|████▌     | 15/33 [00:31<00:37,  2.11s/it]

n_classes:  3


 48%|████▊     | 16/33 [00:33<00:35,  2.08s/it]

n_classes:  3


 52%|█████▏    | 17/33 [00:35<00:33,  2.08s/it]

n_classes:  3


 55%|█████▍    | 18/33 [00:37<00:31,  2.09s/it]

n_classes:  3


 58%|█████▊    | 19/33 [00:39<00:29,  2.12s/it]

n_classes:  3


 61%|██████    | 20/33 [00:41<00:27,  2.12s/it]

n_classes:  3


 64%|██████▎   | 21/33 [00:43<00:25,  2.09s/it]

n_classes:  3


 67%|██████▋   | 22/33 [00:45<00:22,  2.08s/it]

n_classes:  1


 70%|██████▉   | 23/33 [00:48<00:20,  2.09s/it]

n_classes:  3


 73%|███████▎  | 24/33 [00:50<00:18,  2.09s/it]

n_classes:  3


 76%|███████▌  | 25/33 [00:52<00:16,  2.12s/it]

n_classes:  3


 79%|███████▉  | 26/33 [00:54<00:15,  2.15s/it]

n_classes:  3


 82%|████████▏ | 27/33 [00:56<00:12,  2.13s/it]

n_classes:  2


 85%|████████▍ | 28/33 [00:58<00:10,  2.13s/it]

n_classes:  3


 88%|████████▊ | 29/33 [01:01<00:08,  2.17s/it]

n_classes:  3


 91%|█████████ | 30/33 [01:03<00:06,  2.16s/it]

n_classes:  2


 94%|█████████▍| 31/33 [01:05<00:04,  2.14s/it]

n_classes:  3


 97%|█████████▋| 32/33 [01:07<00:02,  2.12s/it]

n_classes:  3


100%|██████████| 33/33 [01:09<00:00,  2.10s/it]


In [43]:
# Valid indices run from 0 to len(resps) - 1; the previous max=len(resps) allowed an IndexError.
interact(
    visualize,
    idx=widgets.IntSlider(min=0, max=max(len(resps) - 1, 0), step=1, value=0),
    thres=widgets.FloatSlider(min=0, max=1.0, step=0.1, value=0.00),
);

interactive(children=(IntSlider(value=0, description='idx', max=33), FloatSlider(value=0.0, description='thres…