In [None]:
import cv2
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

In [None]:
# Conv layer whose activations are explained; block5_conv3 is the last
# convolutional layer of Keras' VGG16.
LAYER_NAME = 'block5_conv3'
model = tf.keras.applications.vgg16.VGG16(weights='imagenet', include_top=True)

# Two-output view of `model`: (activations of LAYER_NAME, class predictions)
# from a single forward pass, so one GradientTape can differentiate a class
# score with respect to the conv activations. `model.input` is the model's
# single input tensor (no extra list nesting).
grad_model = tf.keras.models.Model(
    inputs=model.input,
    outputs=[model.get_layer(LAYER_NAME).output, model.output],
)

In [None]:
# Re-assert a softmax activation on the classifier head so the `predictions`
# output of `grad_model` is a probability vector. The layer object is shared
# with `model` (grad_model was built from model's tensors), so `model` changes too.
# NOTE(review): Keras' VGG16 head presumably already uses softmax, which would
# make this a no-op; Grad-CAM is usually computed on the pre-softmax score.
grad_model.get_layer(index=-1).activation = tf.keras.activations.softmax

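# Grad-CAM

Each channel of the target conv layer is weighted by the spatial mean of the class-score gradient, and the weighted feature maps are summed into a coarse localization map.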

In [None]:
def grad_cam(img_path, cls_idx=-1, colormap_type=cv2.COLORMAP_JET, show_max_cam=False, img_size=224, softmax=True, omit_neg=True, preprocess_fn=None):
    """Compute a Grad-CAM heatmap for one image using the global ``grad_model``.

    The channel weights are the spatial mean of d(class score)/d(activations)
    of the target conv layer; the CAM is the weighted sum of its feature maps.

    Args:
        img_path: Path of the image, read with ``cv2.imread`` (BGR order).
        cls_idx: Class index to explain; ``-1`` explains the top-scoring class.
        colormap_type: OpenCV colormap applied to the CAM; ``None`` renders
            it in grayscale.
        show_max_cam: If True, print the raw CAM's max and min.
        img_size: Side length of the square the image is resized to.
        softmax: If True, return the model output as-is; if False, apply
            ``tf.nn.softmax`` to it first.
        omit_neg: If True, clip negative CAM values to 0 and divide by the
            max; otherwise min-max normalize the CAM to [0, 1].
        preprocess_fn: Optional callable mapping the resized float32 RGB image
            (img_size, img_size, 3) to the model input. When None, pixels are
            scaled to [-1, 1] (the previous, default behavior).
            NOTE(review): Keras documents that ImageNet VGG16 expects
            ``tf.keras.applications.vgg16.preprocess_input`` instead; consider
            passing it when ``grad_model`` wraps VGG16.

    Returns:
        ``(img_bgr, heatmap, img_final, argmax, preds)``: the image as loaded,
        the colorized CAM blended over the resized image, the resized image
        and the heatmap side by side (uint8, BGR), the argmax of ``preds``,
        and the (optionally softmaxed) output vector. Five ``None`` values
        are returned if the image cannot be read.
    """
    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        print("failed to load image")
        return [None] * 5

    img_bgr_resize = cv2.resize(img_bgr, (img_size, img_size))
    img_rgb = cv2.cvtColor(img_bgr_resize, cv2.COLOR_BGR2RGB)
    if preprocess_fn is None:
        img_input = ((img_rgb - 127.5) / 127.5).astype(np.float32)
    else:
        img_input = np.asarray(preprocess_fn(img_rgb.astype(np.float32)), dtype=np.float32)

    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_input[np.newaxis])
        if cls_idx == -1:
            cls_idx = np.argmax(predictions[0])
        class_score = predictions[:, cls_idx]

    grads = tape.gradient(class_score, conv_outputs)[0].numpy()  # (H', W', C)
    feature_maps = conv_outputs[0].numpy()                       # (H', W', C)

    # Global-average-pool the gradients -> one weight per channel, then take
    # the weighted sum of the feature maps over the channel axis -> (H', W').
    weights = grads.mean(axis=(0, 1))
    cam = (feature_maps @ weights).astype(np.float32)

    if show_max_cam:
        print(np.max(cam), np.min(cam))

    if omit_neg:
        # Keep only positive evidence; a CAM with no positive value would
        # otherwise divide by zero and produce NaNs.
        cam = np.maximum(cam, 0)
        cam_max = cam.max()
        cam = cam / cam_max if cam_max > 0 else np.zeros_like(cam)
    else:
        # `cam` is already a NumPy array here (no `.numpy()`); guard flat CAMs.
        cam_min, cam_max = cam.min(), cam.max()
        cam_range = cam_max - cam_min
        cam = (cam - cam_min) / cam_range if cam_range > 0 else np.zeros_like(cam)

    cam = cv2.resize(cam, (img_size, img_size))
    cam = np.uint8(255 * cam)
    if colormap_type is None:
        cam = cv2.cvtColor(cam, cv2.COLOR_GRAY2BGR)
    else:
        cam = cv2.applyColorMap(cam, colormap_type)

    # Half-intensity image plus full-intensity colormap (uint8 saturates at 255).
    heatmap = cv2.addWeighted(img_bgr_resize.astype(np.uint8), 0.5, cam, 1, 0)

    preds = predictions[0] if softmax else tf.nn.softmax(predictions[0])
    argmax = np.argmax(preds)

    # Resized input (left) and heatmap (right) side by side.
    img_final = np.hstack([img_bgr_resize, heatmap]).astype(np.uint8)

    return img_bgr, heatmap, img_final, argmax, preds


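## Grad-CAM++

Like Grad-CAM, but each channel weight is a sum of the positive gradients weighted by per-position coefficients α derived from higher-order derivatives of exp(class score).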

In [None]:
def grad_cam_pp(img_path, cls_idx=-1, colormap_type=cv2.COLORMAP_JET, show_max_cam=False, img_size=224, softmax=True, omit_neg=True, preprocess_fn=None):
    """Compute a Grad-CAM++ heatmap for one image using the global ``grad_model``.

    Uses the closed form with S = exp(Y^c): the k-th derivative of S with
    respect to the conv activations A is S * (dY^c/dA)**k, the per-position
    coefficients are

        alpha = d2S / (2 * d2S + sum_ab(A) * d3S),

    and each channel weight is sum_ab(alpha * relu(dS/dA)).

    Args:
        img_path: Path of the image, read with ``cv2.imread`` (BGR order).
        cls_idx: Class index to explain; ``-1`` explains the top-scoring class.
        colormap_type: OpenCV colormap applied to the CAM; ``None`` renders
            it in grayscale.
        show_max_cam: If True, print the raw CAM's max and min.
        img_size: Side length of the square the image is resized to.
        softmax: If True, return the model output as-is; if False, apply
            ``tf.nn.softmax`` to it first.
        omit_neg: If True, clip negative CAM values to 0 and divide by the
            max; otherwise min-max normalize the CAM to [0, 1].
        preprocess_fn: Optional callable mapping the resized float32 RGB image
            (img_size, img_size, 3) to the model input. When None, pixels are
            scaled to [-1, 1] (the previous, default behavior).
            NOTE(review): Keras documents that ImageNet VGG16 expects
            ``tf.keras.applications.vgg16.preprocess_input`` instead.

    Returns:
        ``(img_bgr, heatmap, img_final, argmax, preds)``: the image as loaded,
        the colorized CAM blended over the resized image, the resized image
        and the heatmap side by side (uint8, BGR), the argmax of ``preds``,
        and the (optionally softmaxed) output vector. Five ``None`` values
        are returned if the image cannot be read.
    """
    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        print("failed to load image")
        return [None] * 5

    img_bgr_resize = cv2.resize(img_bgr, (img_size, img_size))
    img_rgb = cv2.cvtColor(img_bgr_resize, cv2.COLOR_BGR2RGB)
    if preprocess_fn is None:
        img_input = ((img_rgb - 127.5) / 127.5).astype(np.float32)
    else:
        img_input = np.asarray(preprocess_fn(img_rgb.astype(np.float32)), dtype=np.float32)

    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_input[np.newaxis])
        if cls_idx == -1:
            cls_idx = np.argmax(predictions[0])
        class_score = predictions[:, cls_idx]

    grads = tape.gradient(class_score, conv_outputs)[0].numpy()  # (H', W', C)
    feature_maps = conv_outputs[0].numpy()                       # (H', W', C)

    # NOTE(review): the exp closed form assumes Y^c is piecewise linear in A
    # (a pre-softmax score); with softmax outputs it is an approximation.
    score = np.exp(class_score.numpy())  # shape (1,), broadcasts over grads
    first_derivative = score * grads
    second_derivative = first_derivative * grads
    third_derivative = second_derivative * grads

    # Spatial sum of each feature map -> (C,), broadcast over (H', W', C).
    global_sum = feature_maps.sum(axis=(0, 1))

    # The numerator of alpha is the second derivative (not the activations A).
    # Zero denominators are replaced by 1 to avoid division by zero.
    alpha_denom = 2.0 * second_derivative + third_derivative * global_sum
    alpha_denom = np.where(alpha_denom != 0.0, alpha_denom, 1.0)
    alphas = second_derivative / alpha_denom

    positive_grads = np.maximum(first_derivative, 0.0)

    # Normalize each channel's alphas by their sum over positions with a
    # positive gradient; channels whose sum is 0 are left unscaled.
    alphas_sum = np.where(positive_grads > 0, alphas, 0.0).sum(axis=(0, 1))
    alphas = alphas / np.where(alphas_sum != 0.0, alphas_sum, 1.0)

    weights = (positive_grads * alphas).sum(axis=(0, 1))  # (C,)
    cam = (feature_maps @ weights).astype(np.float32)     # (H', W')

    if show_max_cam:
        print(np.max(cam), np.min(cam))

    if omit_neg:
        # Keep only positive evidence; a CAM with no positive value would
        # otherwise divide by zero and produce NaNs.
        cam = np.maximum(cam, 0)
        cam_max = cam.max()
        cam = cam / cam_max if cam_max > 0 else np.zeros_like(cam)
    else:
        # `cam` is already a NumPy array here (no `.numpy()`); guard flat CAMs.
        cam_min, cam_max = cam.min(), cam.max()
        cam_range = cam_max - cam_min
        cam = (cam - cam_min) / cam_range if cam_range > 0 else np.zeros_like(cam)

    cam = cv2.resize(cam, (img_size, img_size))
    cam = np.uint8(255 * cam)
    if colormap_type is None:
        cam = cv2.cvtColor(cam, cv2.COLOR_GRAY2BGR)
    else:
        cam = cv2.applyColorMap(cam, colormap_type)

    # Half-intensity image plus full-intensity colormap (uint8 saturates at 255).
    heatmap = cv2.addWeighted(img_bgr_resize.astype(np.uint8), 0.5, cam, 1, 0)

    preds = predictions[0] if softmax else tf.nn.softmax(predictions[0])
    argmax = np.argmax(preds)

    # Resized input (left) and heatmap (right) side by side.
    img_final = np.hstack([img_bgr_resize, heatmap]).astype(np.uint8)

    return img_bgr, heatmap, img_final, argmax, preds


In [None]:
# Image to explain; other sample images: "cat.jpg", "elephant.jpg", "goldfish.jpg".
img_path = "bear.jpg"
cls_idx = -1  # -1 explains the top-scoring class

src_img, heatmap, img_final, argmax, scores = grad_cam(img_path, cls_idx, cv2.COLORMAP_JET, show_max_cam=True, img_size=224, omit_neg=True)
src_img2, heatmap2, img_final2, argmax2, scores2 = grad_cam_pp(img_path, cls_idx, cv2.COLORMAP_JET, show_max_cam=True, img_size=224, omit_neg=True)
if img_final is None or img_final2 is None:
    # Fail with a clear message instead of a cv2 error from cvtColor(None).
    raise FileNotFoundError(f"could not read image {img_path!r}")

# The functions return OpenCV BGR images; matplotlib expects RGB.
img_final = cv2.cvtColor(img_final, cv2.COLOR_BGR2RGB)
img_final2 = cv2.cvtColor(img_final2, cv2.COLOR_BGR2RGB)
print(argmax)

fig, ax = plt.subplots(nrows=2, figsize=(10, 7), sharex=True, gridspec_kw={'hspace': 0.2})
panels = (
    (img_final, f"Grad-CAM (class {argmax}): input | heatmap"),
    (img_final2, f"Grad-CAM++ (class {argmax2}): input | heatmap"),
)
for axis, (image, title) in zip(ax, panels):
    axis.imshow(image)
    axis.set_title(title)
    axis.axis("off")
plt.show()