In [1]:
DATA_DIR = '../datasets/vis_2'

In [2]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
from mpl_toolkits.axes_grid1 import make_axes_locatable
import os
from PIL import Image 
import numpy as np
import os
from tqdm import tqdm
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import shutil
import time
import pickle
import tqdm
from csv import reader
from functools import reduce
import math
from time import sleep
from pathlib import Path
import csv
import itertools
import random
import gc
import statistics
import pandas as pd
import logging
import functools
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
import requests
from matplotlib.widgets import Button
from textwrap import wrap
import matplotlib.gridspec as gridspec
import ipywidgets as widgets
from IPython.display import display
from ipywidgets import widgets, interactive

In [3]:
class color:
   PURPLE = '\033[95m'
   CYAN = '\033[96m'
   DARKCYAN = '\033[36m'
   BLUE = '\033[94m'
   GREEN = '\033[92m'
   YELLOW = '\033[93m'
   RED = '\033[91m'
   BOLD = '\033[1m'
   UNDERLINE = '\033[4m'
   END = '\033[0m'

In [4]:
input_img = Image.open(f'{DATA_DIR}/input_img.png')
with open(f'{DATA_DIR}/grad_cam_dict.pickle', 'rb') as f:
    grad_cam_dict = pickle.load(f)

In [5]:
color_mapping_before_conversion = {
     0 : (  0,  0,  0),
     1 : (  0,  0,  0),
     2 : (  0,  0,  0),
     3 : (  0,  0,  0),
     4 : (  0,  0,  0),
     5 : (111, 74,  0),
     6 : ( 81,  0, 81),
     7 : (128, 64,128),
     8 : (244, 35,232),
     9 : (250,170,160),
    10 : (230,150,140),
    11 : ( 70, 70, 70),
    12 : (102,102,156),
    13 : (190,153,153),
    14 : (180,165,180),
    15 : (150,100,100),
    16 : (150,120, 90),
    17 : (153,153,153),
    18 : (153,153,153),
    19 : (250,170, 30),
    20 : (220,220,  0),
    21 : (107,142, 35),
    22 : (152,251,152),
    23 : ( 70,130,180),
    24 : (220, 20, 60),
    25 : (255,  0,  0),
    26 : (  0,  0,142),
    27 : (  0,  0, 70),
    28 : (  0, 60,100),
    29 : (  0,  0, 90),
    30 : (  0,  0,110),
    31 : (  0, 80,100),
    32 : (  0,  0,230),
    33 : (119, 11, 32),
    -1 : (  0,  0,142)
}

In [6]:
color_mapping_after_conversion = {
     255 : (  0,  0,  0),
     0 : (128, 64,128),
     1 : (244, 35,232),
     2 : ( 70, 70, 70),
     3 : (102,102,156),
     4 : (190,153,153),
     5 : (153,153,153),
     6 : (250,170, 30),
     7 : (220,220,  0),
     8 : (107,142, 35),
     9 : (152,251,152),
    10 : ( 70,130,180),
    11 : (220, 20, 60),
    12 : (255,  0,  0),
    13 : (  0,  0,142),
    14 : (  0,  0, 70),
    15 : (  0, 60,100),
    16 : (  0, 80,100),
    17 : (  0,  0,230),
    18 : (119, 11, 32)
}

In [7]:
id2class_after_conversion = {
     255 : "unlabeled",
     0 : "road",
     1 : "sidewalk",
     2 : "building",
     3 : "wall",
     4 : "fence",
     5 : "pole",
     6 : "tr. light",
     7 : "tr. sign",
     8 : "veg.",
     9 : "terrain",
    10 : "sky",
    11 : "person",
    12 : "rider",
    13 : "car",
    14 : "truck",
    15 : "bus",
    16 : "train",
    17 : "m. cycle",
    18 : "bicycle"
}

In [8]:
class2id_after_conversion = dict(map(reversed, id2class_after_conversion.items()))

In [9]:
def get_next_class_label(current_class_label = 'road'):
    current_class_id = class2id_after_conversion[current_class_label]
    current_class_id = (current_class_id+1)%NUM_CLASSES
    new_class_label = id2class_after_conversion[current_class_id]
    return new_class_label

In [10]:
label_mask = class2id_after_conversion.keys()
labels = list(set(label_mask))
labels.remove("unlabeled")
cmaps = ['turbo', 'Paired', 'nipy_spectral', 'viridis', 'rainbow']

In [11]:
def plot_cam_image(class_label, input_img, grad_cam, axs, ax, fig, cmap='turbo'):
    
    axs.imshow(input_img, alpha=1.0)
    bar2show = axs.imshow(grad_cam, alpha=0.5, cmap=cmap)
    
    divider = make_axes_locatable(axs)
    cax = divider.new_horizontal(size="5%", pad=0.1)
    fig.add_axes(cax)
    
    cbar = plt.colorbar(bar2show, cax=cax)
    cbar.set_label(label=f'GRAD-CAM Weights ({cmap})', fontsize = 8, color='black', rotation=90)
    
    display_string = "Current Object: " + class_label
    padding = " "*(30-len(display_string))
    display_string = padding + display_string + padding
    fig.text(0.38, 0.85, display_string, fontsize=10, bbox={'facecolor': 'blue', 'alpha': 0.5})

In [12]:
def cam_vis(class_label='road', cmap='turbo'):
    global input_img
    
    fig,ax=plt.subplots(ncols=2, nrows=1, figsize=(9,3))
    fig.patch.set_facecolor('skyblue')
    fig.patch.set_alpha(0.6)
    fig.dpi = 150
    
    ax[0].imshow(input_img, extent=[0, 1024, 0, 470])
    ax[0].set_xlabel('Input Image')
    ax[1].set_xlabel(f'Grad-CAM for: {class_label}')
    ax[0].set_xticks([])
    ax[1].set_xticks([])
    ax[0].set_yticks([])
    ax[1].set_yticks([])
    
    current_class_label = class_label
    plot_cam_image(current_class_label, input_img, grad_cam_dict[class_label], 
                   ax[1], ax, fig, cmap=cmap)
    
    # Hover Code
    annot = ax[1].annotate("", xy=(0,0), xytext=(5,5), textcoords="offset points", color='white')
    annot.set_visible(False)
    
    def on_hover(event):
        if event.inaxes == ax[1]:
            annot.xy = (event.xdata, event.ydata)
            x, y = event.xdata, event.ydata
            x, y = int(x), int(y)
            x, y = x, y-512
            annot.set_text(f"{grad_cam_dict[class_label][y][x]:.2f}")
            annot.set_visible(True)
        else:
            annot.set_visible(False)

    cid = fig.canvas.mpl_connect('motion_notify_event', on_hover)
    
    plt.show()

### Categorical cmaps:
1. Paired
2. nipy_spectral
### Sequential cmaps:
1. viridis
2. turbo

### A BAD BAD cmap:
* rainbow

### Suggested one: turbo

In [13]:
%matplotlib widget

class_dropdown = widgets.Dropdown(options=labels, 
                            value='road', 
                            description='Object:')
cmap_dropdown = widgets.Dropdown(options=cmaps, 
                            value='turbo', 
                            description='Colormap:')

interactive(cam_vis, class_label=class_dropdown, cmap=cmap_dropdown)

interactive(children=(Dropdown(description='Object:', options=('road', 'bus', 'terrain', 'tr. sign', 'm. cycle…