
Visualizing Attention maps on input images #145

Closed
mineshmathew opened this issue Aug 6, 2019 · 5 comments


mineshmathew commented Aug 6, 2019

❓ Questions and Help

I was wondering if there is any way to visualize the attention weights over the original inputs for VQA or captioning. I see such figures in papers.

Is there a script already available for this?

apsdehal (Contributor) commented Aug 7, 2019

No, we don't have these scripts. But you can save the attention weights from the model and use the functions below to plot them on the image:

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import skimage.transform  # plain "import skimage" does not load the submodule
import cv2

from PIL import Image

cmap = matplotlib.cm.get_cmap('jet')
cmap.set_bad(color="k", alpha=0.0)

def attention_bbox_interpolation(im, bboxes, att):
    # att: one softmax weight per box; bboxes: (x1, y1, x2, y2) in pixel coords
    softmax = att
    assert len(softmax) == len(bboxes)

    img_h, img_w = im.shape[:2]
    opacity = np.zeros((img_h, img_w), np.float32)
    for bbox, weight in zip(bboxes, softmax):
        x1, y1, x2, y2 = bbox
        opacity[int(y1):int(y2), int(x1):int(x2)] += weight
    opacity = np.minimum(opacity, 1)

    # apply the colormap to the 2-D opacity map: (H, W) -> RGBA (H, W, 4);
    # adding a trailing axis first would give (H, W, 1, 4), which
    # Image.fromarray rejects
    vis_im = np.array(Image.fromarray(cmap(opacity, bytes=True), 'RGBA'))
    vis_im = vis_im.astype(im.dtype)
    vis_im = cv2.addWeighted(im, 0.7, vis_im, 0.5, 0)
    vis_im = vis_im.astype(im.dtype)

    return vis_im


def attention_grid_interpolation(im, att):
    # att: flattened 14x14 grid attention over convolutional features
    softmax = np.reshape(att, (14, 14))
    opacity = skimage.transform.resize(softmax, im.shape[:2], order=3)
    opacity = opacity[..., np.newaxis]
    opacity = opacity * 0.95 + 0.05  # keep the image faintly visible everywhere

    vis_im = opacity * im + (1 - opacity) * 255  # blend toward white
    vis_im = vis_im.astype(im.dtype)
    return vis_im

def visualize_pred(im_path, boxes, att_weights):
    im = cv2.imread(im_path)                   # OpenCV loads images as BGR
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGBA)  # convert to RGBA for plotting

    M = min(len(boxes), len(att_weights))
    im_ocr_att = attention_bbox_interpolation(im, boxes[:M], att_weights[:M])
    plt.imshow(im_ocr_att)
    plt.show()

Might have some errors though.
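For reference, here is a minimal sketch of the expected inputs, based on my reading of the code above (not a documented API): boxes are pixel-space `(x1, y1, x2, y2)` tuples and `att_weights` is one softmax score per box. The core opacity accumulation can be checked standalone with NumPy, no image needed:

```python
import numpy as np

# hypothetical inputs: two detected regions on a 100x100 image
boxes = [(10, 10, 50, 50), (40, 40, 90, 90)]
att_weights = [0.7, 0.3]  # one softmax weight per box

# same accumulation as attention_bbox_interpolation
opacity = np.zeros((100, 100), np.float32)
for (x1, y1, x2, y2), w in zip(boxes, att_weights):
    opacity[int(y1):int(y2), int(x1):int(x2)] += w
opacity = np.minimum(opacity, 1)  # overlapping boxes are clipped to 1

print(opacity[30, 30])  # inside the first box only -> ~0.7
print(opacity.max())    # overlap region, clipped -> ~1.0
```

Note that where boxes overlap, weights add up and are then clipped, so heavily overlapping detections can saturate the map.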

@CCYChongyanChen

FYI, here is an alternative to the function attention_grid_interpolation(im, att), which generates a white mask with varying opacity:

import skimage.filters

def get_blend_map(img, att_map, blur=True, overlap=True):
    # min-max normalize the attention map to [0, 1]
    att_map -= att_map.min()
    if att_map.max() > 0:
        att_map /= att_map.max()
    att_map = skimage.transform.resize(att_map, img.shape[:2], order=3)
    if blur:
        att_map = skimage.filters.gaussian(att_map, 0.02 * max(img.shape[:2]))
        att_map -= att_map.min()
        att_map /= att_map.max()
    cmap = plt.get_cmap('jet')
    att_map_v = cmap(att_map)
    att_map_v = np.delete(att_map_v, 3, 2)  # drop the alpha channel
    plt.imshow(att_map_v)
    # plt.imshow(img, alpha=0.2)
    plt.show()
    # cv2.imwrite("attentionmap.jpg", att_map_v * 255)
    return att_map_v
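A quick standalone check of the normalization step in this function (input values here are illustrative; only NumPy is needed): a flattened grid-attention vector is reshaped to 14x14 and min-max scaled to [0, 1] before being resized and colormapped:

```python
import numpy as np

att = np.random.RandomState(0).rand(196)  # hypothetical 14*14 grid attention
att_map = att.reshape(14, 14)

# same min-max normalization as get_blend_map
att_map = att_map - att_map.min()
if att_map.max() > 0:
    att_map = att_map / att_map.max()

print(att_map.min(), att_map.max())  # 0.0 1.0
```

The guard on `att_map.max() > 0` avoids a division by zero when the attention is uniform after shifting.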


shamanthak-hegde commented Jul 26, 2022


Hi, I understand what visualize_pred and attention_bbox_interpolation do. But what is get_blend_map or attention_grid_interpolation for, and how do we use them?
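For context, this is just my reading of the code above, not an authoritative answer: attention_bbox_interpolation is for region-based attention, where each weight corresponds to a detected bounding box, while attention_grid_interpolation and get_blend_map are for grid attention, where the model attends over a fixed 14x14 spatial grid of convolutional features. The two therefore expect differently shaped weights:

```python
import numpy as np

# hypothetical shapes for the two attention styles
region_att = np.full(36, 1 / 36)   # e.g. 36 detected boxes -> one weight per box
grid_att = np.full(196, 1 / 196)   # 14*14 = 196 grid cells, flattened

# region weights pair with a boxes array of shape (36, 4);
# grid weights are reshaped to the spatial grid instead
print(region_att.shape, grid_att.reshape(14, 14).shape)
```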


micdist commented Oct 13, 2022

@apsdehal Hi, thanks for sharing the code. I'm trying it with DETR and I get a size mismatch on the weights. What shape of weights did you develop this for? Does it expect the weights already projected onto the image dimensions, from which a bounding-box-sized slice is then extracted? Thanks.


batubb commented Jul 16, 2024

So you always visualize the first attention head? You ignore the others?
