In [568]:
import os
import torch
import json
import h5py
import numpy as np
from matplotlib.pyplot import imshow
from PIL import Image, ImageDraw, ImageFont
import cv2 

from matplotlib import pyplot as plt

In [569]:
import fontconfig

# find a font file
fonts = fontconfig.query(lang='en')
for i in range(1, len(fonts)):
    if fonts[i].fontformat == 'TrueType':
        absolute_path = fonts[i].file
        
        print(absolute_path)
        break

/usr/share/fonts/truetype/dejavu/DejaVuSerif.ttf


In [570]:
image_file = json.load(open('../datasets/vg/image_data.json'))
vocab_file = json.load(open('../datasets/vg/VG-SGG-dicts-with-attri.json'))
data_file = h5py.File('../datasets/vg/VG-SGG-with-attri.h5', 'r')
# remove invalid image
corrupted_ims = [1592, 1722, 4616, 4617]
tmp = []
for item in image_file:
    if int(item['image_id']) not in corrupted_ims:
        tmp.append(item)
image_file = tmp

In [571]:
# load detected results
path = '../checkpoints/iba0.02_s2_inv_prop0.03_power0.5_sum_v3-predcls/'
detected_origin_path = path + 'inference/VG_stanford_filtered_with_attribute_test/'
dir_fmap = path + '/rib/'

In [572]:
detected_origin_result = torch.load(detected_origin_path + 'eval_results.pytorch')
detected_info = json.load(open(detected_origin_path + 'visual_info.json'))

In [628]:
def load_fmap(fname): 
    fmap = np.load(fname)
    fmap = np.mean(fmap, axis=1)
    fmap = fmap * (fmap > 0) # ReLU 
    fmap = fmap - fmap.min()  
    fmap /= fmap.max()
    return fmap 
    
def resize_map(fmap, size): 
    return cv2.resize(fmap, size, cv2.INTER_AREA)

def load_mask(fname): 
    masks = np.load(fname)
    return masks 

def load_inds(fname): 
    rel_inds = np.load(fname)
    return rel_inds

def get_mask(masks, rel_inds, query_rel): 
    idx = rel_inds.tolist().index(query_rel)
    return masks[idx, 0, :, :], idx

def get_predicate(rels, query_rel): 
    idx = rels[:, 0:2].tolist().index(query_rel)
    return rels[idx, 2]

def get_union(box1, box2): 
    ubox = [0, 0, 0, 0]
    
    ubox[0] = min(box1[0], box2[0]).astype(int)
    ubox[1] = min(box1[1], box2[1]).astype(int)
    ubox[2] = max(box1[2], box2[2]).astype(int)
    ubox[3] = max(box1[3], box2[3]).astype(int)
    return ubox

In [629]:
def resize_boxes(img, boxes, im_scale=(1024.,)):
    new_boxes = np.copy(boxes)
    h, w = img.shape[0:2]
    if len(im_scale) == 2: 
        scale_w = im_scale[1] / w
        scale_h = im_scale[0] / h 
    else:
        if h > w: 
            scale_h = im_scale[0] / h 
            scale_w = scale_h 
        else:
            scale_w = im_scale[0] / w 
            scale_h = scale_w 
    #ipdb.set_trace()
    new_boxes[:, [0, 2]] = new_boxes[:, [0, 2]] / scale_w
    new_boxes[:, [1, 3]] = new_boxes[:, [1, 3]] / scale_h
    new_boxes = new_boxes.astype(np.int)
    return  new_boxes

In [630]:
# load fmap
def get_fmap(idx):
    # (1, 256, 232, 152)
    fmap = load_fmap(dir_fmap + '{}_fmap.npy'.format(idx))
    fmap = np.swapaxes(fmap, 0,2)
    # (90, 256, 7, 7)
    rib_fmap = load_fmap(dir_fmap + '{}_rib_fmap.npy'.format(idx))
    rib_fmap = np.swapaxes(rib_fmap, 0,2)
    # (90, 1, 15, 15)
    mask = load_mask(dir_fmap + '{}_mask.npy'.format(idx))
    mask = np.swapaxes(mask, 1,3)
    # (90, 2)
    rel_inds = load_inds(dir_fmap + '{}_inds.npy'.format(idx))
    
    return fmap, rib_fmap, mask, rel_inds

In [631]:
def draw_single_box(pic, box, color='red', draw_info=None):
    draw = ImageDraw.Draw(pic)
    
    # get a font
    fnt = ImageFont.truetype(absolute_path, 14, encoding="unic")
    
    x1,y1,x2,y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
    draw.rectangle(((x1, y1), (x2, y2)), outline=color)
    if draw_info:
        draw.rectangle(((x1, y1), (x1+60, y1+15)), fill=color)
        info = draw_info
        draw.text((x1, y1), info, font=fnt)
        
def draw_union_box(pic, box, color='red', draw_info=None):
    draw = ImageDraw.Draw(pic)
    
    # get a font
    fnt = ImageFont.truetype(absolute_path, 14, encoding="unic")
    
    x1,y1,x2,y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
    
    draw.rectangle(((x1-1, y1-1), (x2+1, y2+1)), outline=color)
    if draw_info:
        draw.rectangle(((x1, y2-15), (x1+190, y2)), fill=color)
        info = draw_info
        draw.text((x1 + 5, y2-15), info, font=fnt)

def print_list(name, input_list):
    for i, item in enumerate(input_list):
        print(name + ' ' + str(i) + ': ' + str(item))
    
def draw_image(img_path, boxes, labels, gt_rels, pred_rels, pred_rel_score, pred_rel_label, print_img=True):
    pic = Image.open(img_path)
    num_obj = boxes.shape[0]
    for i in range(num_obj):
        info = labels[i]
        draw_single_box(pic, boxes[i], draw_info=info)
    if print_img:
        display(pic)
    if print_img:
        print('*' * 50)
        print_list('gt_boxes', labels)
        print('*' * 50)
        print_list('gt_rels', gt_rels)
        print('*' * 50)
    print_list('pred_rels', pred_rels[:20])
    print('*' * 50)
    
    return None

In [632]:
idx = 1

In [633]:
fmap, rib_fmap, mask, rel_inds = get_fmap(idx=idx)
rib_fmap[:,:,1].shape

(15, 15)

In [634]:
# get image info by index
def get_info_by_idx(idx, det_input, thres=0.5):
    groundtruth = det_input['groundtruths'][idx]
    prediction = det_input['predictions'][idx]
    # image path
    img_path = detected_info[idx]['img_file']
    # boxes
    boxes = groundtruth.bbox
    
    # object labels
    idx2label = vocab_file['idx_to_label']
    labels = ['{}-{}'.format(idx,idx2label[str(i)]) for idx, i in enumerate(groundtruth.get_field('labels').tolist())]
    pred_labels = ['{}-{}'.format(idx,idx2label[str(i)]) for idx, i in enumerate(prediction.get_field('pred_labels').tolist())]
    
    # groundtruth relation triplet
    idx2pred = vocab_file['idx_to_predicate']
    gt_rels = groundtruth.get_field('relation_tuple').tolist()
    gt_rels = [(labels[i[0]], idx2pred[str(i[2])], labels[i[1]]) for i in gt_rels]
    
    # prediction relation triplet
    pred_rel_pair = prediction.get_field('rel_pair_idxs').tolist()
    pred_rel_label = prediction.get_field('pred_rel_scores')
    pred_rel_label[:,0] = 0
    pred_rel_score, pred_rel_label = pred_rel_label.max(-1)
    
    #mask = pred_rel_score > thres
    #pred_rel_score = pred_rel_score[mask]
    #pred_rel_label = pred_rel_label[mask]
    pred_rels = [(pred_labels[i[0]], idx2pred[str(j)], pred_labels[i[1]]) for i, j in zip(pred_rel_pair, pred_rel_label.tolist())]
    return img_path, boxes, labels, gt_rels, pred_rels, pred_rel_score, pred_rel_label

In [635]:
pred_path = './pred_boxes/'
try:
    os.mkdir(pred_path)
except :
    None

In [636]:
img_path,boxes,labels,gt_rels,pred_rels,pred_rel_score,pred_rel_label= get_info_by_idx(idx=idx,det_input=detected_origin_result)

In [637]:
union_box

[133.78906, 35.64453, 465.8203, 349.1211]

In [None]:
for idx in range(200, 400):
    
    # get predictions
    img_path,boxes,labels,gt_rels,pred_rels,pred_rel_score,pred_rel_label= get_info_by_idx(
        idx=idx,det_input=detected_origin_result)

    # get image name
    img_name = img_path.split('/')[-1].split('.')[0]

    # get Image
    img = cv2.imread(img_path)
    
    np_boxes = boxes.cpu().numpy()

    for rel_idx in range(len(gt_rels)):
        
        gt_subj_idx = int(gt_rels[rel_idx][0].split('-')[0])
        gt_obj_idx = int(gt_rels[rel_idx][2].split('-')[0])

        gt_subj_label = gt_rels[rel_idx][0].split('-')[1]
        gt_obj_label = gt_rels[rel_idx][2].split('-')[1]
        gt_rel_label = gt_rels[rel_idx][1]

        gt_triplet = '< ' + gt_subj_label + ', ' + gt_rel_label + ', ' + gt_obj_label + ' >' 

        for rel_jdx in range(len(pred_rels)):

            # sample name
            img_rel_name = pred_path + img_name + '_{}.pdf'.format(str(rel_jdx))
            img_jdx = img.copy()
            
            pred_subj_idx = int(pred_rels[rel_jdx][0].split('-')[0])
            pred_obj_idx = int(pred_rels[rel_jdx][2].split('-')[0])

            pred_subj_label = pred_rels[rel_jdx][0].split('-')[1]
            pred_obj_label = pred_rels[rel_jdx][2].split('-')[1]
            pred_rel_label = pred_rels[rel_jdx][1]

            # get union boxes 
            subj_box = np_boxes[pred_subj_idx]
            obj_box = np_boxes[pred_obj_idx]

            # [96.67969, 123.53516, 498.53516, 306.15234]
            union_box = get_union(subj_box, obj_box)

            # pred triplet
            pred_triplet = '< ' + pred_subj_label + ', ' + pred_rel_label + ', ' + pred_obj_label + ' >'


            # get heatmaps
            fmap, rib_fmap, mask, rel_inds = get_fmap(idx=idx)

            if False:
                fmap = resize_map(fmap[0], (img.shape[1], img.shape[0]))
            else:
                fmap_w = union_box[2] - union_box[0]
                fmap_h = union_box[3] - union_box[1]
                fmap = resize_map(rib_fmap[:,:,rel_jdx], (fmap_w, fmap_h))

            if False:
                heatmap = fmap[union_box[1]:union_box[3]+1, union_box[0]:union_box[2]+1]
            else:
                heatmap = fmap

            heatmap = np.array(heatmap * 255, dtype = np.uint8)
            heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_WINTER)

            overlay = 0.5
            img_jdx[union_box[1]:union_box[3], union_box[0]:union_box[2], :] += np.uint8(overlay * heatmap)

            if gt_subj_idx == pred_subj_idx and gt_obj_idx == pred_obj_idx:

                # You may need to convert the color.
                img_jdx = cv2.cvtColor(img_jdx, cv2.COLOR_BGR2RGB)
                im_pil = Image.fromarray(img_jdx)

                draw_single_box(im_pil, subj_box, draw_info=pred_subj_label, color='green')
                draw_single_box(im_pil, obj_box, draw_info=pred_obj_label, color='magenta')
                draw_union_box(im_pil, union_box, draw_info=pred_triplet, color='red')


                #display(im_pil)
                print(img_rel_name)
                #display(im_pil)
                plt.imshow(im_pil)
                plt.axis('off')
                plt.savefig(img_rel_name, dpi=200, bbox_inches='tight', pad_inches=0.1)
            
            

./pred_boxes/2343494_10.pdf
./pred_boxes/2343494_28.pdf
./pred_boxes/2343494_40.pdf
./pred_boxes/2343494_8.pdf
./pred_boxes/2343494_77.pdf
./pred_boxes/2343494_6.pdf
./pred_boxes/2343494_1.pdf
./pred_boxes/2343494_0.pdf
./pred_boxes/2343494_2.pdf
./pred_boxes/2343494_3.pdf
./pred_boxes/2343494_4.pdf
./pred_boxes/2343494_194.pdf
./pred_boxes/2343494_77.pdf
./pred_boxes/2343494_46.pdf
./pred_boxes/2343494_135.pdf
./pred_boxes/2343494_70.pdf
./pred_boxes/2343494_36.pdf
./pred_boxes/2343494_9.pdf
./pred_boxes/2343493_13.pdf
./pred_boxes/2343493_11.pdf
./pred_boxes/2343493_6.pdf
./pred_boxes/2343493_2.pdf
./pred_boxes/2343493_5.pdf
./pred_boxes/2343493_5.pdf
./pred_boxes/2343493_6.pdf
./pred_boxes/2343493_54.pdf
./pred_boxes/2343493_5.pdf
./pred_boxes/2343493_0.pdf
./pred_boxes/2343492_22.pdf
./pred_boxes/2343492_48.pdf
./pred_boxes/2343492_55.pdf
./pred_boxes/2343492_28.pdf
./pred_boxes/2343492_30.pdf
./pred_boxes/2343492_74.pdf
./pred_boxes/2343492_13.pdf
./pred_boxes/2343492_22.pdf
./pre

./pred_boxes/2343455_2.pdf
./pred_boxes/2343455_10.pdf
./pred_boxes/2343455_0.pdf
./pred_boxes/2343455_6.pdf
./pred_boxes/2343455_42.pdf
./pred_boxes/2343454_8.pdf
./pred_boxes/2343454_0.pdf
./pred_boxes/2343453_40.pdf
./pred_boxes/2343453_16.pdf
./pred_boxes/2343453_7.pdf
./pred_boxes/2343453_5.pdf
./pred_boxes/2343453_57.pdf
./pred_boxes/2343453_40.pdf
./pred_boxes/2343453_6.pdf
./pred_boxes/2343453_26.pdf
./pred_boxes/2343453_40.pdf
./pred_boxes/2343453_1.pdf
./pred_boxes/2343453_34.pdf
./pred_boxes/2343453_6.pdf
./pred_boxes/2343452_12.pdf
./pred_boxes/2343452_15.pdf
./pred_boxes/2343452_7.pdf
./pred_boxes/2343452_2.pdf
./pred_boxes/2343452_12.pdf
./pred_boxes/2343452_9.pdf
./pred_boxes/2343452_0.pdf
./pred_boxes/2343452_8.pdf
./pred_boxes/2343451_9.pdf
./pred_boxes/2343451_5.pdf
./pred_boxes/2343451_2.pdf
./pred_boxes/2343451_65.pdf
./pred_boxes/2343451_6.pdf
./pred_boxes/2343451_12.pdf
./pred_boxes/2343451_9.pdf
./pred_boxes/2343451_5.pdf
./pred_boxes/2343451_1.pdf
./pred_boxes/2

./pred_boxes/2343415_13.pdf
./pred_boxes/2343415_12.pdf
./pred_boxes/2343415_4.pdf
./pred_boxes/2343415_44.pdf
./pred_boxes/2343415_1.pdf
./pred_boxes/2343415_4.pdf
./pred_boxes/2343415_8.pdf
./pred_boxes/2343415_13.pdf
./pred_boxes/2343415_6.pdf
./pred_boxes/2343415_8.pdf
./pred_boxes/2343415_4.pdf
./pred_boxes/2343415_0.pdf
./pred_boxes/2343415_0.pdf
./pred_boxes/2343415_5.pdf
./pred_boxes/2343414_1.pdf
./pred_boxes/2343414_5.pdf
./pred_boxes/2343414_5.pdf
./pred_boxes/2343414_11.pdf
./pred_boxes/2343414_35.pdf
./pred_boxes/2343414_5.pdf
./pred_boxes/2343414_24.pdf
./pred_boxes/2343414_4.pdf
./pred_boxes/2343414_55.pdf
./pred_boxes/2343414_100.pdf
./pred_boxes/2343414_22.pdf
./pred_boxes/2343412_0.pdf
./pred_boxes/2343412_2.pdf
./pred_boxes/2343412_3.pdf
./pred_boxes/2343412_12.pdf
./pred_boxes/2343412_12.pdf
./pred_boxes/2343412_1.pdf
