In [None]:
import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F

from modelhub import pretrained_sov_stg_s, ImageProcessor, ResultParser, draw_boxes

## Define helper functions

In [None]:
def rescale_bboxes(out_bbox, size):
    """Convert normalised (cx, cy, w, h) boxes to absolute (x1, y1, x2, y2) pixel boxes.

    Args:
        out_bbox: tensor of shape (num_boxes, 4) whose rows are normalised
            (cx, cy, w, h) boxes.
        size: ``(img_w, img_h)``; numbers or one-element tensors (the box cell
            passes ``img_size[:, 1], img_size[:, 0]``).

    Returns:
        Tensor of shape (num_boxes, 4) in pixel units, on ``out_bbox``'s device.
    """
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    # Put the scale on the boxes' device so GPU boxes no longer need a CPU round-trip.
    scale = torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(b.device)
    return b * scale

def box_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) to (x1, y1, x2, y2).

    Args:
        x: tensor whose last dimension has size 4, e.g. (num_boxes, 4) or a
            batched (batch, num_boxes, 4).

    Returns:
        Tensor of the same shape with corner coordinates in the last dimension.
    """
    # Working on the last axis (instead of axis 1) keeps (N, 4) inputs unchanged
    # while also accepting batched inputs.
    x_c, y_c, w, h = x.unbind(-1)
    half_w, half_h = 0.5 * w, 0.5 * h
    return torch.stack(
        [x_c - half_w, y_c - half_h, x_c + half_w, y_c + half_h], dim=-1
    )

def getJetColorRGB(v, vmin, vmax):
    """Map ``v`` in ``[vmin, vmax]`` to a jet-colormap colour in OpenCV (B, G, R) order.

    ``v`` is clamped to the range and then normalised to ``t`` in [0, 1]; low
    values go dark blue -> blue -> cyan -> yellow -> red -> dark red.

    Args:
        v: scalar value (Python/NumPy number or 0-dim tensor).
        vmin, vmax: value range; if they are equal every value maps to ``t = 0``.

    Returns:
        ``np.ndarray`` of shape (3,) holding (B, G, R) on a 0..256 scale.
    """
    v = min(max(float(v), vmin), vmax)
    dv = vmax - vmin
    # Normalise so the colour formulas below (written for [0, 1]) hold for any range.
    t = (v - vmin) / dv if dv != 0 else 0.0

    c = np.zeros((3))
    if t < 0.125:
        c[0] = 256 * (0.5 + (t * 4))      # B: 0.5 ~ 1
    elif t < 0.375:
        c[0] = 255
        c[1] = 256 * (t - 0.125) * 4      # G: 0 ~ 1
    elif t < 0.625:
        c[0] = 256 * (-4 * t + 2.5)       # B: 1 ~ 0
        c[1] = 255
        c[2] = 256 * (4 * (t - 0.375))    # R: 0 ~ 1
    elif t < 0.875:
        c[1] = 256 * (-4 * t + 3.5)       # G: 1 ~ 0
        c[2] = 255
    else:
        c[2] = 256 * (-4 * t + 4.5)       # R: 1 ~ 0.5
    return c

def getJetColorRB(v, vmin, vmax):
    """Map ``v`` in ``[vmin, vmax]`` to a blue-to-red blend in OpenCV (B, G, R) order.

    ``v`` is clamped to the range and normalised to ``t`` in [0, 1]; ``t = 0`` is
    pure blue, ``t = 1`` is pure red, and green stays 0.

    Args:
        v: scalar value (Python/NumPy number or 0-dim tensor).
        vmin, vmax: value range; if they are equal every value maps to ``t = 0``.

    Returns:
        ``np.ndarray`` of shape (3,) holding (B, G, R) on a 0..256 scale.
    """
    v = min(max(float(v), vmin), vmax)
    dv = vmax - vmin
    # Normalise by the range width so ranges other than [0, 1] work too.
    t = (v - vmin) / dv if dv != 0 else 0.0

    c = np.zeros((3))
    c[0] = 256 * (1 - t)  # B: 1 ~ 0
    c[2] = 256 * t        # R: 0 ~ 1
    return c


## Load the pretrained SOV-STG-S model

In [None]:
# Build the pretrained model from a local checkpoint. `pretrained_sov_stg_s` returns the
# model and a post-processor; the misspelled name `PostPeocessor` is kept because the
# forward cell calls it.
# NOTE(review): the meaning of the positional `True` is not visible here — confirm
# against modelhub's `pretrained_sov_stg_s` signature.
checkpoint_path = 'params/sov-stg-s.pth'
model, PostPeocessor = pretrained_sov_stg_s(checkpoint_path, True)
model.cuda()  # hard-coded GPU: this notebook needs a CUDA device
model.eval()  # evaluation mode for inference

## Load an example image

In [None]:
# Example HICO-DET test image. The variable name's typo (`iamge_path`) is kept because
# later cells read it.
iamge_path = 'data/hico_det/images/test2015/HICO_test2015_00000001.jpg'
# NOTE(review): presumably returns (model input, image size, original image); later cells
# index `img_size[:, 1]` as width and `img_size[:, 0]` as height — confirm in modelhub.
img, img_size, orig_img = ImageProcessor(iamge_path, device='cuda')

## Register forward hooks and run inference

In [None]:
# Buffers filled by the forward hooks below during inference; re-running this cell
# starts them empty again.
hook_value_cross_attn, sampling_offsets, attention_weights, verb_boxes_list = [], [], [], []


def hook_cross_attn_input(model, inputs, output):
    """Forward hook for the last decoder layer's cross-attention module.

    Appends ``[inputs[1], N, Len_q, n_heads, n_levels, n_points, inputs[3]]`` to
    ``hook_value_cross_attn``. ``N`` and ``Len_q`` are the first two dims of the
    3-D tensor ``inputs[0]``; the head/level/point counts are read from the hooked
    module, which arrives here as ``model``.
    """
    batch_size, num_queries, _ = inputs[0].shape
    hook_value_cross_attn.append([
        inputs[1],          # unpacked later as the reference points
        batch_size,
        num_queries,
        model.n_heads,
        model.n_levels,
        model.n_points,
        inputs[3],          # unpacked later as the value spatial shapes
    ])


def _append_sampling_offsets(module, inputs, output):
    """Keep the raw output of the ``sampling_offsets`` projection."""
    sampling_offsets.append(output)


def _append_attention_weights(module, inputs, output):
    """Keep the raw output of the ``attention_weights`` projection (softmaxed later)."""
    attention_weights.append(output)


def _append_verb_boxes(module, inputs, output):
    """Keep ``output[1][0]`` of ``vDec`` (used later as the verb boxes)."""
    verb_boxes_list.append(output[1][0])


_last_cross_attn = model.transformer.vDec.layers[-1].cross_attn
hooks = [
    _last_cross_attn.register_forward_hook(hook_cross_attn_input),
    _last_cross_attn.sampling_offsets.register_forward_hook(_append_sampling_offsets),
    _last_cross_attn.attention_weights.register_forward_hook(_append_attention_weights),
    model.transformer.vDec.register_forward_hook(_append_verb_boxes),
]

In [None]:
# Run inference once with the hooks attached. Gradients are not needed for this
# visualisation, so skip building the autograd graph. The hooks are removed in
# `finally` so a failing forward pass does not leave them registered (re-running the
# cells would otherwise stack duplicate hooks and fill the buffers twice).
try:
    with torch.no_grad():
        outputs = model([img])
        results = PostPeocessor(outputs, img_size)
finally:
    for hook in hooks:
        hook.remove()

In [None]:
# Rebuild the cross-attention sampling locations and weights from the captured hook values.
(reference_points, N_, Len_q,
 n_heads, n_levels, n_points, value_spatial_shapes) = hook_value_cross_attn[0]

# Raw offsets laid out as (N_, Len_q, n_heads, n_levels, n_points, 2).
sampling_offsets_reshape = sampling_offsets[0].view(N_, Len_q, n_heads, n_levels, n_points, 2)

# Broadcast the reference boxes over the head and point axes: the first two coordinates
# are the base location, the last two scale the offsets.
ref_base = reference_points[:, :, None, :, None, :2]
ref_extent = reference_points[:, :, None, :, None, 2:]
sampling_grids = ref_base + sampling_offsets_reshape / n_points * ref_extent * 0.5

# Softmax jointly over all levels*points of each head, then split the two axes again.
attention_logits_flat = attention_weights[0].view(N_, Len_q, n_heads, n_levels * n_points)
attention_weights_reshape = F.softmax(attention_logits_flat, -1).view(
    N_, Len_q, n_heads, n_levels, n_points)


In [None]:
# Convert the normalised sampling locations to input-image pixel coordinates and regroup
# them (together with their attention weights) per head and per feature level.
# The (width, height) scale is loop-invariant, so it is built once here.
grid_to_pixels = torch.tensor([img_size[:, 1], img_size[:, 0]],
                              dtype=torch.float32).to(sampling_grids.device)

sampling_grid_rescale_all_lvl = []
attention_weights_all_lvl = []

attention_weights_all_head = []
sampling_grid_rescale_all_head = []

for head_i in range(n_heads):
    # N_, Lq_, lvl, P_, 2 -> N_, Lq_, lvl*P_, 2
    sampling_grid_head_ = sampling_grids[:, :, head_i].flatten(2, 3)
    # N_, Lq_, lvl, P_ -> N_, Lq_, lvl*P_
    attention_weights_head_ = attention_weights_reshape[:, :, head_i].flatten(2, 3)
    attention_weights_all_head.append(attention_weights_head_)  # N_, Lq_, lvl*P_
    sampling_grid_rescale_all_head.append(sampling_grid_head_ * grid_to_pixels)  # N_, Lq_, lvl*P_, 2

# Only the level index is needed; the per-level (H_, W_) shape is not used here.
for lid_, _ in enumerate(value_spatial_shapes):
    # N_, Lq_, M_, P_, 2 -> N_, Lq_, M_*P_, 2
    sampling_grid_l_ = sampling_grids[:, :, :, lid_].flatten(2, 3)
    # N_, Lq_, M_, P_ -> N_, Lq_, M_*P_
    attention_weights_l_ = attention_weights_reshape[:, :, :, lid_].flatten(2, 3)
    attention_weights_all_lvl.append(attention_weights_l_)  # N_, Lq_, M_*P_
    sampling_grid_rescale_all_lvl.append(sampling_grid_l_ * grid_to_pixels)  # N_, Lq_, M_*P_, 2

In [None]:
out_obj_logits = outputs['pred_obj_logits']   # N_, Lq_, 80
out_verb_logits = outputs['pred_verb_logits'] # N_, Lq_, 117

# Per-query object confidence and class for the first image (softmax + max computed once).
obj_scores, out_obj_class = out_obj_logits.softmax(-1)[0].max(-1)  # Lq_, Lq_
verb_scores = out_verb_logits.sigmoid()[0]                         # Lq_, 117

# Score each query by object confidence x its best verb confidence (vectorised).
obj_scores = obj_scores * verb_scores.max(-1).values

# Visualise the single best-scoring query. argmax always yields exactly one index, so
# the mask below is never empty (ties, or a single query, are handled).
keep_num = int(obj_scores.argmax())
keep = torch.zeros_like(obj_scores, dtype=torch.bool)
keep[keep_num] = True


In [None]:
# Raw predictions (normalised cx, cy, w, h), shaped N_, Lq_, 4.
out_sub_boxes = outputs['pred_sub_boxes']
out_obj_boxes = outputs['pred_obj_boxes']
out_verb_boxes = verb_boxes_list[0]
# (width, height) of the input image.
im_size = (img_size[:, 1], img_size[:, 0])


def _kept_boxes_in_pixels(boxes):
    """Select the kept queries from a (Lq_, 4) tensor and convert them to pixel xyxy on CPU."""
    return rescale_bboxes(boxes[keep].cpu(), im_size)


# Learned anchor priors (sigmoid of the reference-point embeddings), Lq_ x 4 before selection.
sub_box_priors = _kept_boxes_in_pixels(model.refpoint_sub_embed.weight.sigmoid())
obj_box_priors = _kept_boxes_in_pixels(model.refpoint_obj_embed.weight.sigmoid())

# Predicted boxes for the first image in the batch.
sub_boxes, obj_boxes, verb_boxes = (
    _kept_boxes_in_pixels(batch_boxes[0])
    for batch_boxes in (out_sub_boxes, out_obj_boxes, out_verb_boxes)
)

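## Display results

Now let's visualize the predicted subject, object, and verb boxes, together with the cross-attention sampling points of the last `vDec` layer.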

In [None]:
OBJSUB_BOX_WIDTH = 7
POINT_SIZE = 9
POINT_ALPHA = 0.8
VERB_BOX_WIDTH = 6
RGB_POINT = False
output_dir = './'
DISPLAY_SCALE = 2                    # canvas and all drawn coordinates are upscaled by this factor
GREY_POINT_BLUE_LEVEL = 256 * 0.6    # points whose colour has blue >= this are drawn grey
                                     # (in RB mode: normalised weight <= 0.4)
GREY_COLOR = np.array([114.0, 114.0, 114.0])
SUB_COLOR, OBJ_COLOR, VERB_COLOR = (0, 220, 0), (0, 0, 220), (160, 48, 112)  # BGR

img_cv = cv2.imread(iamge_path)
if img_cv is None:
    # cv2.imread reports a missing/unreadable file by returning None instead of raising.
    raise FileNotFoundError('Could not read image: {}'.format(iamge_path))
height, width, _ = img_cv.shape


def _to_display_point(x, y):
    """Map an original-image (x, y) to integer coordinates on the upscaled canvas."""
    return (int(x) * DISPLAY_SCALE, int(y) * DISPLAY_SCALE)


def _save_visualization(image, suffix):
    """Write `image` to `output_dir` as '<input stem>_<suffix>.jpg'; raise if OpenCV fails."""
    stem = os.path.splitext(os.path.basename(iamge_path))[0]
    out_path = os.path.join(output_dir, '{}_{}.jpg'.format(stem, suffix))
    if not cv2.imwrite(out_path, image):
        raise OSError('cv2.imwrite failed to write {}'.format(out_path))
    return out_path


# One upscaled copy of the input per kept query (only the first one is drawn on).
img_upscaled = cv2.resize(img_cv, (width * DISPLAY_SCALE, height * DISPLAY_SCALE))
imgs_obj = [img_upscaled.copy() for _ in range(len(sub_boxes))]

# Anchor box priors of the kept queries.
img_prior = imgs_obj[0].copy()
for sub_prior, obj_prior in zip(sub_box_priors, obj_box_priors):
    cv2.rectangle(img_prior, _to_display_point(*sub_prior[:2]), _to_display_point(*sub_prior[2:]),
                  SUB_COLOR, OBJSUB_BOX_WIDTH)
    cv2.rectangle(img_prior, _to_display_point(*obj_prior[:2]), _to_display_point(*obj_prior[2:]),
                  OBJ_COLOR, OBJSUB_BOX_WIDTH)
_save_visualization(img_prior, 'prior')

# Sampling points of every head for the first kept query, coloured by attention weight.
img_i2 = imgs_obj[0].copy()
img_i2_base = imgs_obj[0].copy()
point_num = n_levels * n_points
point_all = 0  # number of points drawn in colour (not greyed out)
colormap = getJetColorRGB if RGB_POINT else getJetColorRB
for head_idx in range(n_heads):
    sampling_grid_head_i = sampling_grid_rescale_all_head[head_idx][0, keep].cpu()[0]       # lvl*P_, 2
    attention_weights_head_i = attention_weights_all_head[head_idx][0, keep].cpu()[0]   # lvl*P_

    # Ascending order so the strongest points are drawn last (on top).
    order = torch.argsort(attention_weights_head_i)
    attention_weights_head_i = attention_weights_head_i[order]
    sampling_grid_head_i = sampling_grid_head_i[order]

    attention_weights_head_min = attention_weights_head_i.min()
    gap = attention_weights_head_i.max() - attention_weights_head_min
    if gap <= 0:
        # All weights equal: avoid 0/0 (NaN colours); every point normalises to 0.
        gap = torch.ones_like(gap)
    for p_idx in range(point_num):
        x, y = sampling_grid_head_i[p_idx]
        weight_norm = (attention_weights_head_i[p_idx] - attention_weights_head_min) / gap
        color = colormap(weight_norm, 0, 1)
        if color[0] < GREY_POINT_BLUE_LEVEL:
            cv2.circle(img_i2, _to_display_point(x, y), POINT_SIZE, color, -1)
            point_all += 1
        else:
            cv2.circle(img_i2, _to_display_point(x, y), POINT_SIZE, GREY_COLOR, -1)

# Blend the points over the plain image, then draw the predicted boxes on top.
result = cv2.addWeighted(img_i2, POINT_ALPHA, img_i2_base, 1 - POINT_ALPHA, 0)
for sub_box, obj_box, verb_box in zip(sub_boxes, obj_boxes, verb_boxes):
    cv2.rectangle(result, _to_display_point(*sub_box[:2]), _to_display_point(*sub_box[2:]),
                  SUB_COLOR, OBJSUB_BOX_WIDTH)
    cv2.rectangle(result, _to_display_point(*obj_box[:2]), _to_display_point(*obj_box[2:]),
                  OBJ_COLOR, OBJSUB_BOX_WIDTH)
    cv2.rectangle(result, _to_display_point(*verb_box[:2]), _to_display_point(*verb_box[2:]),
                  VERB_COLOR, VERB_BOX_WIDTH)
_save_visualization(result, 'attn_all')
