In [None]:
import numpy as np
from sklearn.metrics import jaccard_score
from PIL import Image

# Colour-coded ground-truth annotation -> integer class label.
# Keys are (R, G, B) pixel values as read from the ground-truth PNGs;
# label 255 marks invalid pixels.
gt_color_to_label = dict([
    ((128, 0, 0), 0),        # building
    ((128, 128, 128), 1),    # sky
    ((128, 64, 128), 2),     # road
    ((128, 128, 0), 3),      # vegetation
    ((0, 0, 192), 4),        # sidewalk
    ((64, 0, 128), 5),       # car
    ((64, 64, 0), 6),        # pedestrian
    ((0, 128, 192), 7),      # cyclist
    ((192, 128, 128), 8),    # signate
    ((64, 64, 128), 9),      # fence
    ((192, 192, 128), 10),   # pole
    ((0, 0, 0), 255),        # invalid
])

# Input locations (machine-specific absolute paths; edit for your setup).
img_gt_dir = '/home/ganlu/bkism_ws/src/BKISemanticMapping/data/data_kitti_15/kitti_15/'
img_pred_dir = '/home/ganlu/bkism_ws/src/BKISemanticMapping/data/data_kitti_15/reproj_img/'
evaluation_list_file = '/home/ganlu/bkism_ws/src/BKISemanticMapping/data/data_kitti_15/evaluatioList.txt'

# Numeric frame ids to evaluate (read as floats by np.loadtxt; presumably one
# per line). ndmin=1 keeps a single-entry list iterable: without it np.loadtxt
# returns a 0-d array and `for img_id in evaluation_list` raises TypeError.
evaluation_list = np.loadtxt(evaluation_list_file, ndmin=1)

def _rgb_to_labels(img_rgb, color_to_label):
    """Map an (H, W, 3) RGB annotation image to an (H, W) integer label image.

    Every pixel colour must be a key of ``color_to_label``. A colour that is
    missing raises ``KeyError`` with that colour, as a per-pixel dict lookup
    would. This vectorised form replaces a Python loop over every pixel.
    """
    labels = np.full(img_rgb.shape[:2], -1, dtype=np.int32)  # -1 = not matched yet
    for color, label in color_to_label.items():
        labels[np.all(img_rgb == np.asarray(color), axis=-1)] = label
    unmatched = labels < 0
    if unmatched.any():
        raise KeyError(tuple(int(c) for c in img_rgb[unmatched][0]))
    return labels


img_pred_all = []
img_gt_all = []
for img_id in evaluation_list:
    # np.loadtxt yields float ids; image files use 6-digit zero-padded names.
    img_id = '%06d' % int(img_id)
    print(img_id)

    # Read images. convert('RGB') also copes with palette/RGBA ground-truth
    # PNGs, which would otherwise not match the (R, G, B) colour keys.
    with Image.open(img_gt_dir + img_id + '.png') as im:
        img_gt_color = np.array(im.convert('RGB'))
    with Image.open(img_pred_dir + img_id + '_bw.png') as im:
        img_pred = np.array(im)

    # Convert rgb to label
    img_gt = _rgb_to_labels(img_gt_color, gt_color_to_label)

    # Pixels are compared position by position after flattening, so the two
    # images must have identical shapes.
    if img_pred.shape != img_gt.shape:
        raise ValueError('frame %s: prediction shape %s does not match ground-truth shape %s'
                         % (img_id, img_pred.shape, img_gt.shape))

    img_gt_all.append(img_gt)
    img_pred_all.append(img_pred)
    
# Flatten every frame into one long pixel vector. Concatenating ravelled
# frames gives the same order as stacking then flattening, but also works
# when frames differ in size and avoids a temporary 3-D copy.
if not img_gt_all:
    raise ValueError('evaluation list is empty: no frames to evaluate')
img_gt_all = np.concatenate([img.ravel() for img in img_gt_all])
img_pred_all = np.concatenate([img.ravel() for img in img_pred_all])

# Ignore sky (1) and invalid (255) labels in gt. One combined mask keeps gt
# and pred aligned and filters both arrays only once.
keep = (img_gt_all != 1) & (img_gt_all != 255)
img_pred_all = img_pred_all[keep]
img_gt_all = img_gt_all[keep]

# Optional: also ignore pixels the prediction marks invalid (255). Left
# disabled, such pixels count as misclassified.
#keep = img_pred_all != 255
#img_gt_all = img_gt_all[keep]
#img_pred_all = img_pred_all[keep]


# IoU for each class. With average=None, jaccard_score returns one score per
# label in the sorted union of gt and pred labels, i.e. the array printed
# first. Passing that same array as `labels` makes the pairing explicit.
evaluated_labels = np.union1d(img_gt_all, img_pred_all)
print(evaluated_labels)
print(jaccard_score(img_gt_all, img_pred_all, labels=evaluated_labels, average=None))