<a href="https://colab.research.google.com/github/hnguyentt/GradCAM_and_GuidedGradCAM_tf2/blob/master/Visualization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Setup for Colab.
# When the notebook runs inside Google Colab and there is no local `src`
# directory, clone the project repository and move its `assets` and `src`
# folders into the working directory. The next cell imports from `src.*` and
# reads files under `./assets/...`, so both must sit next to the notebook.
import sys
from pathlib import Path
# Colab detection: the presence of the `google.colab` module in sys.modules
# is used as the signal that this kernel is a Colab runtime.
IN_COLAB = 'google.colab' in sys.modules

srcpath = Path("src")

if IN_COLAB and not srcpath.exists(): # running in Colab without a local src --> fetch it
    # The three lines below are IPython shell escapes (`!cmd`), not plain
    # Python; they only work inside a Jupyter/Colab kernel.
    !git clone https://github.com/hnguyentt/GradCAM_and_GuidedGradCAM_tf2
    !mv GradCAM_and_GuidedGradCAM_tf2/assets .
    !mv GradCAM_and_GuidedGradCAM_tf2/src .

In [3]:
#@title Click on the Run button to run
from IPython.display import display, Javascript, HTML, clear_output, IFrame
from ipywidgets import interact, interactive, fixed, interact_manual, AppLayout, GridspecLayout
import ipywidgets as widgets
from src.gradcam import GradCAM, overlay_gradCAM
from src.guidedBackprop import GuidedBackprop, deprocess_image
from src.utils import preprocess, predict, SAMPLE_DIR, array2bytes, DECODE, INV_MAP
from src.models import load_ResNet50PlusFC, load_VanilaResNet50
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import os, cv2
import matplotlib.pyplot as plt

# Header
header = widgets.HTML('<font color="#1f77b4" face="sans-serif"><center><h1>DEMO GradCAM and Guided GradCAM</h1></center></font>',
                      layout=widgets.Layout(height='auto'))
# Logo: the initial content of the centre panel.
# Read the PNG inside a context manager so the file handle is closed
# immediately instead of lingering until garbage collection.
with open("./assets/illustrations/VietAIlogo.png", "rb") as _logo_file:
    _logo_png = _logo_file.read()
logo = widgets.Image(
    value=_logo_png,
    format='png',
    width='auto',
    height='auto',
    align="center-align"
)

# Dropdowns
def on_change_im(change):
    """Preview the newly selected sample image in the centre panel.

    Registered via ``imgs.observe``; every notification other than a change
    of the dropdown's ``value`` is ignored. The left and right panels are
    cleared because they may still show CAM results for a previous image.

    When the "--Select" placeholder is chosen, or the file cannot be decoded
    (``cv2.imread`` returns None), nothing is updated. Previously this case
    crashed inside ``cv2.cvtColor``.

    Args:
        change: the traitlets change dict (keys used: "type", "name", "new").
    """
    if change['type'] != "change" or change["name"] != "value":
        return
    filename = change["new"]
    if filename == "--Select":
        return
    img = cv2.imread(os.path.join(SAMPLE_DIR, filename))
    if img is None:  # cv2.imread signals a missing/unreadable file with None
        return
    im_arr = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Named `preview` (not `logo`) to avoid shadowing the global logo widget.
    preview = widgets.Image(
        value=array2bytes(im_arr),
        format='png',
        width='auto',
        height='auto',
        align="center-align"
    )
    grid[5:13, 1:9] = widgets.HTML("")
    grid[5:13, 11:19] = preview
    grid[5:13, 21:29] = widgets.HTML("")
def on_change_model(change):
    """Show a status line for the model chosen in the "Model" dropdown.

    Only reacts to changes of the dropdown's ``value``. It only updates the
    status text; the model itself is loaded by ``check_button`` when "Show"
    is clicked. Choosing the "--Select" placeholder clears the message
    instead of reporting "Model --Select loaded.".

    Args:
        change: the traitlets change dict (keys used: "type", "name", "new").
    """
    if change['type'] != "change" or change["name"] != "value":
        return
    if change["new"] == "--Select":
        grid[2, :8] = widgets.HTML("")
        return
    grid[2, :8] = widgets.HTML("<center><p>Model %s loaded.</p></center>" % change["new"])
        
# Dropdown options. "--Select" is a placeholder and is pinned as the first
# entry; the sample files follow in sorted order. Samples are listed from
# SAMPLE_DIR, the same directory on_change_im and check_button read from, so
# every listed entry can actually be opened. Hidden files (".…") are skipped.
im_ls = ["--Select"] + sorted(x for x in os.listdir(SAMPLE_DIR) if not x.startswith("."))
model_ls = ["--Select", "VanilaResNet50", "ResNet50PlusFC"]
class_ls = ["--Select", "Cat", "Dog"]
models = widgets.Dropdown(options=model_ls, description="Model", layout={'width': 'auto'}, disabled=False)
imgs = widgets.Dropdown(options=im_ls, description="Image", layout={'width': 'auto'}, disabled=False)
classes = widgets.Dropdown(options=class_ls, description="Class", layout={'width': 'auto'}, disabled=False)
# The handlers filter for value changes themselves.
imgs.observe(on_change_im)
models.observe(on_change_model)
# Notes
# note = widgets.HTML("<p><b>Modes:</b><br>- VanilaResNet50 is retrained ResNet50 model on Dog vs. Cat dataset<br>"
#                     "- ResNet50PlusFC is the model with 2 FC layers added to ResNet50, retrained on Dog vs. Cat dataset"
#                     "<br><b>Classes: </b>If not specified, GradCAM & Guided GradCAM will be calculated based on the predicted class</p>")

# button
def create_expanded_button(description, button_style):
    """Create a ``widgets.Button`` whose layout height and width are both 'auto'.

    Args:
        description: the text shown on the button.
        button_style: the ipywidgets style name (e.g. "info").

    Returns:
        The new Button. Per its name, the 'auto' sizing is meant to let it
        expand to the grid slot it is placed in.
    """
    auto_size = widgets.Layout(height='auto', width='auto')
    btn = widgets.Button(
        description=description,
        button_style=button_style,
        layout=auto_size,
    )
    return btn
pred_but = create_expanded_button("Show", "info")

# Layouts: a 20x30 GridspecLayout. Row 0 holds the title and row 1 the
# controls. The centre panel (rows 5-12, columns 11-18) starts with the logo.
# Placements are applied in the order listed.
grid = GridspecLayout(20, 30, height='700px')
_initial_cells = (
    ((0, slice(None)), header),
    ((1, slice(None, 8)), models),
    ((1, slice(8, 17)), imgs),
    ((1, slice(17, 24)), classes),
    ((1, slice(25, None)), pred_but),
    ((slice(5, 13), slice(11, 19)), logo),
)
for _cell, _child in _initial_cells:
    grid[_cell] = _child
# grid[17:,:] = note
display(grid)

def _as_image_widget(arr):
    """Wrap an RGB image array in an auto-sized PNG ``widgets.Image``."""
    return widgets.Image(
        value=array2bytes(arr),
        format='png',
        width='auto',
        height='auto',
        align="center-align"
    )


def showCAMs(img, x, GradCAM, GuidedBP, chosen_class, upsample_size):
    """Compute GradCAM, Guided Backpropagation and Guided GradCAM and display them.

    Args:
        img: the original image as loaded with ``cv2.imread`` by the caller
            (BGR channel order); it is the background of the GradCAM overlay.
        x: the preprocessed model input.
        GradCAM: a GradCAM instance. The parameter name shadows the class
            inside this function; it is kept for backward compatibility.
        GuidedBP: a GuidedBackprop instance.
        chosen_class: class index the heatmap is computed for.
        upsample_size: target size passed to both explainers; the caller
            builds it as (width, height) from ``img.shape``.

    The three results go into the left, centre and right panels of ``grid``
    (rows 5-12), with captions in row 4.
    """
    # GradCAM heatmap, overlaid on the original image
    cam3 = GradCAM.compute_heatmap(image=x, classIdx=chosen_class, upsample_size=upsample_size)
    gradcam = cv2.cvtColor(overlay_gradCAM(img, cam3), cv2.COLOR_BGR2RGB)
    # Guided backprop
    gb = GuidedBP.guided_backprop(x, upsample_size)
    gb_im = cv2.cvtColor(deprocess_image(gb), cv2.COLOR_BGR2RGB)
    # Guided GradCAM: guided-backprop map multiplied by the GradCAM heatmap
    guided_gradcam = cv2.cvtColor(deprocess_image(gb * cam3), cv2.COLOR_BGR2RGB)

    # Build all widgets before touching the grid, as the original did.
    gc = _as_image_widget(gradcam)
    gbim = _as_image_widget(gb_im)
    ggc = _as_image_widget(guided_gradcam)

    grid[4, 1:9] = widgets.HTML('<center><b>GradCAM</b></center>')
    grid[4, 11:19] = widgets.HTML('<center><b>Guided Backpropagation</b></center>')
    grid[4, 21:29] = widgets.HTML('<center><b>Guided GradCAM</b></center>')
    grid[5:13, 1:9] = gc
    grid[5:13, 11:19] = gbim
    grid[5:13, 21:29] = ggc
    
# Layer used by both GradCAM and Guided Backprop for every selectable model.
_TARGET_LAYER = "conv5_block3_out"
# Maps "Model" dropdown entries to the functions that build the model.
_MODEL_LOADERS = {
    "VanilaResNet50": load_VanilaResNet50,
    "ResNet50PlusFC": load_ResNet50PlusFC,
}


def check_button(sender):
    """Handle a click on "Show": predict the image's class and render the CAMs.

    Steps:
      1. Load the model chosen in ``models``.
      2. Read the image chosen in ``imgs`` and predict its class.
      3. Write the prediction into row 2 of ``grid``.
      4. Call ``showCAMs`` for the class picked in ``classes``, or for the
         predicted class when that dropdown is "--Select".

    If no model or image is selected, or the image cannot be read, a message
    is shown instead. Previously this case raised UnboundLocalError or
    AttributeError. The image is checked before the (slow) model load.

    Args:
        sender: the clicked Button; unused, but required by ``Button.on_click``.
    """
    loader = _MODEL_LOADERS.get(models.value)
    if loader is None or imgs.value == "--Select":
        grid[2, 9:18] = widgets.HTML("<center><span>Please select a model and an image.</span></center>")
        return
    img = cv2.imread(os.path.join(SAMPLE_DIR, imgs.value))
    if img is None:  # cv2.imread returns None for missing/unreadable files
        grid[2, 9:18] = widgets.HTML("<center><span>Could not read the selected image.</span></center>")
        return

    model = loader()
    gradCAM = GradCAM(model=model, layerName=_TARGET_LAYER)
    guidedBP = GuidedBackprop(model=model, layerName=_TARGET_LAYER)

    # img.shape is (rows, cols, channels); upsample_size is ordered (width, height).
    upsample_size = (img.shape[1], img.shape[0])
    x = preprocess(imgs.value)
    pred, _ = predict(model, x)
    classIdx = pred if classes.value == "--Select" else INV_MAP[classes.value]

    grid[2, 9:18] = widgets.HTML("<center><span>Predicted: <b>{}</b></span></center>".format(DECODE[pred]))

    showCAMs(img, x, gradCAM, guidedBP, classIdx, upsample_size)


pred_but.on_click(check_button)

GridspecLayout(children=(HTML(value='<font color="#1f77b4" face="sans-serif"><center><h1>DEMO GradCAM and Guid…

