In [None]:
# Select the non-interactive backend before any other library can import pyplot.
import matplotlib
matplotlib.use('Agg')

import copy
import os
from zipfile import ZipFile
try:
    from urllib import urlretrieve
except ImportError:
    from urllib.request import urlretrieve

import numpy as np
import pandas as pd 
import torch
import torchvision
import torchvision.datasets as dset
from IPython.display import display
from PIL import Image
from PIL import Image, ImageDraw
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms as T

from ml_wrappers.model.image_model_wrapper import WrappedObjectDetectionModel
from responsibleai_vision.common.constants import ImageColumns
from vision_explanation_methods.error_labeling.error_labeling import (
    ErrorLabeling, ErrorLabelType)

In [None]:

def _coco_annotation_to_row(annotation):
    """Turn one COCO annotation dict into ``[label, x1, y1, x2, y2, iscrowd]``.

    The ``bbox`` entry is read as ``[x, y, width, height]``; the width and
    height are added to the origin to obtain the opposite corner.
    """
    x, y, width, height = (float(v) for v in annotation['bbox'][:4])
    return [annotation['category_id'], x, y, x + width, y + height,
            int(annotation['iscrowd'])]


def load_mscoco_object_detection_dataset_labels(start_index=4, stop_index=10):
    """Read ground-truth boxes for a slice of the extracted COCO val2017 set.

    The images and ``instances_val2017.json`` must already be extracted under
    ``./dataCoco``.

    :param start_index: First dataset index to read (inclusive). Defaults to 4.
    :param stop_index: Dataset index to stop at (exclusive). Defaults to 10.
    :return: ``(image_ids, image_labels)`` where ``image_ids`` is a list of
        image ids as strings and ``image_labels`` holds, per image, a list of
        ``[category_id, x1, y1, x2, y2, iscrowd]`` rows.
    """
    path2data = './dataCoco/val2017/'
    path2json = './dataCoco/annotations/instances_val2017.json'
    coco_val = dset.CocoDetection(root=path2data, annFile=path2json)

    image_ids = []
    image_labels = []
    for index in range(start_index, stop_index):
        _, annotations = coco_val[index]
        # An image without any annotation yields an empty list, so reading
        # annotations[0] would raise IndexError; fall back to the dataset's
        # own id table (CocoDetection.ids) in that case.
        if annotations:
            image_id = annotations[0]['image_id']
        else:
            image_id = coco_val.ids[index]
        image_ids.append(str(image_id))
        image_labels.append(
            [_coco_annotation_to_row(annotation) for annotation in annotations])

    return image_ids, image_labels

def load_mscoco_object_detection_dataset():
    """Download COCO val2017 (if needed) and build the image/label table.

    Both the image archive and the annotation archive are extracted into
    ``./dataCoco`` and the downloaded zip files are removed afterwards. An
    archive is skipped when the path it is expected to produce already exists,
    so re-running the notebook does not fetch the data again. Delete
    ``./dataCoco`` to force a clean download, e.g. after an interrupted
    extraction.

    :return: A ``pandas.DataFrame`` with one row per image and the columns
        ``ImageColumns.IMAGE.value`` (path to the jpg under
        ``./dataCoco/val2017/``) and ``ImageColumns.LABEL.value`` (the list of
        label rows returned by ``load_mscoco_object_detection_dataset_labels``).
    """
    # create data folder if it doesn't exist.
    os.makedirs("dataCoco", exist_ok=True)

    # (download url, local zip file, path that exists once it is extracted)
    archives = [
        ("http://images.cocodataset.org/zips/val2017.zip",
         "./val2017.zip",
         "./dataCoco/val2017"),
        ("http://images.cocodataset.org/annotations/annotations_trainval2017.zip",
         "./annotations_trainval2017.zip",
         "./dataCoco/annotations/instances_val2017.json"),
    ]
    for download_url, data_file, extracted_path in archives:
        if os.path.exists(extracted_path):
            continue
        urlretrieve(download_url, filename=data_file)
        # `archive`, not `zip`: keeps the builtin usable further down.
        with ZipFile(data_file, "r") as archive:
            print("extracting files...")
            archive.extractall(path="./dataCoco")
            print("done")
        # delete zip file
        os.remove(data_file)

    ids, labels = load_mscoco_object_detection_dataset_labels()

    # Build every row first and create the frame once: DataFrame.append was
    # removed in pandas 2.0 and copied the whole frame on each call.
    records = [
        {ImageColumns.IMAGE.value:
            "./dataCoco/val2017/" + '%012d' % (int(image_id)) + '.jpg',
         ImageColumns.LABEL.value: image_labels}
        for image_id, image_labels in zip(ids, labels)
    ]
    return pd.DataFrame(records, columns=[ImageColumns.IMAGE.value,
                                          ImageColumns.LABEL.value])

In [None]:
def get_instance_segmentation_model():
    """Build torchvision's Faster R-CNN (ResNet-50 + FPN) with ``pretrained=True``.

    Despite the function name, the network returned here is the
    ``fasterrcnn_resnet50_fpn`` object detector, not a segmentation model.

    :return: The model object created by torchvision.
    """
    detection_zoo = torchvision.models.detection
    return detection_zoo.fasterrcnn_resnet50_fpn(pretrained=True)
# Run on the GPU when torch can see one, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = get_instance_segmentation_model()

model.to(device)
# The notebook only runs inference. torchvision detection models expect
# targets in training mode, so switch to eval mode explicitly instead of
# relying on a downstream wrapper to do it.
model.eval()

In [None]:
# Build the image-path / label table used by the evaluation loop. This call is
# the one that performs the COCO download and extraction into ./dataCoco, so it
# can be slow.
data = load_mscoco_object_detection_dataset()

# Display names for the 90 COCO category slots. The array is 0-based, so the
# name for annotation `category_id` k is class_names[k - 1]; slots with no
# named category are filled with "unknown".
class_names = np.array([
    "person", "bicycle", "car", "motorcycle", "airplane",
    "bus", "train", "truck", "boat", "traffic light",
    "fire hydrant", "unknown", "stop sign", "parking meter", "bench",
    "bird", "cat", "dog", "horse", "sheep",
    "cow", "elephant", "bear", "zebra", "giraffe",
    "unknown", "backpack", "umbrella", "unknown", "unknown",
    "handbag", "tie", "suitcase", "frisbee", "skis",
    "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
    "skateboard", "surfboard", "tennis racket", "bottle", "unknown",
    "wine glass", "cup", "fork", "knife", "spoon",
    "bowl", "banana", "apple", "sandwich", "orange",
    "broccoli", "carrot", "hot dog", "pizza", "donut",
    "cake", "chair", "couch", "potted plant", "bed",
    "unknown", "dining table", "unknown", "unknown", "toilet",
    "unknown", "tv", "laptop", "mouse", "remote",
    "keyboard", "cell phone", "microwave", "oven", "toaster",
    "sink", "refrigerator", "unknown", "book", "clock",
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
])

In [None]:
def highlight_cells(value):
    """Return the CSS used to style one cell of the match-matrix table.

    :param value: The cell value to inspect.
    :return: ``'background-color: yellow'`` when ``value`` equals
        ``ErrorLabelType.MATCH``, otherwise an empty string (no styling).
    """
    is_match = value == ErrorLabelType.MATCH
    return 'background-color: yellow' if is_match else ''

In [None]:
def remove_first_last(array):
    """Strip the first and last element of ``array``.

    :param array: Any sliceable sequence.
    :return: A one-element list holding ``array[1:-1]`` when ``array`` has
        more than two elements, otherwise an empty list. Note that the slice
        is wrapped in an outer list.
    """
    if len(array) <= 2:
        return []
    return [array[1:-1]]

def _load_label_font(size=12):
    """Return Arial at ``size`` points, or Pillow's built-in font if Arial is missing."""
    try:
        return ImageFont.truetype("arial.ttf", size)
    except OSError:
        # arial.ttf is not installed on many Linux/macOS machines; the default
        # font keeps the notebook running instead of raising here.
        return ImageFont.load_default()


def _draw_labelled_boxes(canvas, boxes, prefix, color, font, label_at_right):
    """Draw every box in ``boxes`` on ``canvas`` with a ``<prefix><index>`` tag."""
    for idx, box in enumerate(boxes):
        # Drop the first and last entry; the remaining values are used as
        # x1, y1, x2, y2 of the rectangle.
        d = box[1:-1]
        canvas.rectangle(((d[0], d[1]), (d[2], d[3])), outline=color, width=2)
        # Tags of the second group are anchored 20px left of the right edge so
        # they do not sit on top of a tag drawn at the left edge.
        text_x = d[2] - 20 if label_at_right else d[0]
        canvas.text((text_x, d[1]), text=prefix + str(idx), fill=color,
                    stroke_width=4, stroke_fill="white", font=font)


def draw_detections(base_img, detections, gt):
    """Draw predicted and ground-truth boxes on a copy of ``base_img``.

    Predictions are outlined in green and tagged ``p<index>`` at their
    top-left corner; ground-truth boxes are outlined in magenta and tagged
    ``gt<index>`` near their top-right corner. ``base_img`` is not modified.

    :param base_img: Image to annotate (deep-copied before drawing).
    :param detections: Iterable of predicted boxes; for each entry the first
        and last values are dropped and the rest are used as x1, y1, x2, y2.
    :param gt: Iterable of ground-truth boxes in the same layout.
    :return: The annotated copy of ``base_img``.
    """
    img_copy = copy.deepcopy(base_img)
    im = ImageDraw.Draw(img_copy)
    # Load the font once rather than once per drawn label.
    font = _load_label_font(12)
    _draw_labelled_boxes(im, detections, "p", "green", font,
                         label_at_right=False)
    _draw_labelled_boxes(im, gt, "gt", (255, 0, 255), font,
                         label_at_right=True)
    return img_copy

# Wrap the torchvision model so predictions come back in the list form that
# ErrorLabeling and draw_detections consume below.
detection_model = WrappedObjectDetectionModel(model=model, number_of_classes = 90)
lst=[]
lst_labels = []
# The transform is stateless, so build it once instead of once per image.
to_tensor = T.ToTensor()
for row in data.itertuples():
    # Close the file handle promptly and normalise to 3-channel RGB: convert()
    # returns an independent in-memory image, and greyscale/palette files would
    # otherwise produce a tensor without three channels.
    with Image.open(row.image) as image_file:
        image = image_file.convert("RGB")
    gt_label = (row.label)
    img_tensor = to_tensor(image).to(device).unsqueeze(0)
    pred_y = detection_model.predict(img_tensor)
    # The last argument (.7) is presumably the IoU threshold used for matching
    # -- TODO confirm against the ErrorLabeling signature.
    mng = ErrorLabeling('object_detection',
                  pred_y[0],
                  gt_label,
                  .7)
    mng.compute()
    img = draw_detections(image, pred_y[0], gt_label)
    display(img)
    # NOTE(review): _match_matrix is a private attribute of ErrorLabeling; it
    # is indexed here as rows = ground truth, columns = predictions.
    rows = [f"gt{i}" for i in range(len(mng._match_matrix))]
    cols = [f"p{i}" for i in range(len(mng._match_matrix[0]))]
    df = pd.DataFrame(mng._match_matrix, columns=cols)
    df.index = (rows)
    # Styler.applymap is deprecated in favour of Styler.map on newer pandas;
    # use whichever the installed version provides.
    styler = df.style
    style_cellwise = styler.map if hasattr(styler, "map") else styler.applymap
    highlighted_df = style_cellwise(highlight_cells)
    display(highlighted_df)
        