In [1]:
from maskrcnn_benchmark.data.datasets.visual_genome import load_info, load_graphs,load_image_filenames
from maskrcnn_benchmark.structures.bounding_box import BoxList
import numpy as np
from PIL import Image
import random
import torch
import cv2
from collections import defaultdict

# Scale of the stored ground-truth boxes: get_groundtruth divides the raw
# boxes by this value and multiplies by the longer image side to recover
# the original box coordinates.
BOX_SCALE = 1024

In [2]:
# Root of the dataset tree; the trailing slash is relied on by the paths below.
base_path = "/data/sdc/SGG_data/"

# Visual Genome inputs consumed by the loading cell.
dict_file = f"{base_path}VG/VG-SGG-dicts-with-attri.json"
roidb_file = f"{base_path}VG/VG-SGG-with-attri.h5"
image_file = f"{base_path}VG/image_data.json"
image_dir = f"{base_path}VG/VG_100K"

In [3]:
# Index -> name lookup tables for object classes, predicates and attributes.
ind_to_classes, ind_to_predicates, ind_to_attributes = load_info(dict_file)

# Per-image annotations of the training split, plus the mask that selects
# the images belonging to that split.
split_mask, gt_boxes, gt_classes, gt_attributes, relationships = load_graphs(
    roidb_file,
    split='train',
    num_im=-1,
    num_val_im=5000,
    filter_empty_rels=False,
    filter_non_overlap=False,
)
filenames, img_info = load_image_filenames(image_dir, image_file)

# Keep only the file names / image infos selected by split_mask so they can
# be indexed with the same position as the annotation lists above.
selected_idx = np.where(split_mask)[0]
filenames = [filenames[i] for i in selected_idx]
img_infos = [img_info[i] for i in selected_idx]


In [5]:
def get_groundtruth(index, evaluation=False, flip_img=False):
    """Build the ground-truth ``BoxList`` for one image of the loaded split.

    Args:
        index: Position of the image in the split-filtered lists
            (``img_infos``, ``gt_boxes``, ``gt_classes``, ``gt_attributes``,
            ``relationships``).
        evaluation: Unused; kept so existing callers keep working.
        flip_img: If True, mirror the boxes horizontally inside the image
            width.

    Returns:
        A ``BoxList`` in 'xyxy' mode carrying the fields

        * ``"labels"`` and ``"attributes"``: taken from ``gt_classes`` /
          ``gt_attributes``;
        * ``"relation"``: a ``(num_box, num_box)`` int64 map holding the
          predicate index for each related box pair and 0 elsewhere;
        * ``"relation_tuple"``: an int64 tensor of shape ``(num_rel, 3)``
          with rows ``(first box index, second box index, predicate index)``.
          It is ``(0, 3)`` when the image has no relations.

    Note:
        When a box pair has several predicates, one is drawn with
        ``np.random.choice``, so repeated calls may return different
        predicates unless NumPy's global RNG is seeded.
    """
    img_info = img_infos[index]
    w, h = img_info['width'], img_info['height']
    # important: recover original box from BOX_SCALE
    box = gt_boxes[index] / BOX_SCALE * max(w, h)
    box = torch.from_numpy(box).reshape(-1, 4)  # guard against no boxes
    if flip_img:
        # Both right-hand sides are new tensors, so the in-place column
        # writes below do not read already-overwritten values.
        new_xmin = w - box[:, 2]
        new_xmax = w - box[:, 0]
        box[:, 0] = new_xmin
        box[:, 2] = new_xmax
    target = BoxList(box, (w, h), 'xyxy')  # xyxy

    target.add_field("labels", torch.from_numpy(gt_classes[index]))
    target.add_field("attributes", torch.from_numpy(gt_attributes[index]))

    # Collapse duplicate (box, box) pairs: keep one randomly chosen predicate
    # per pair.
    all_rel_sets = defaultdict(list)
    for (o0, o1, r) in relationships[index]:  # rows of a (num_rel, 3) array
        all_rel_sets[(o0, o1)].append(r)
    relation = [(k[0], k[1], np.random.choice(v)) for k, v in all_rel_sets.items()]
    # reshape keeps the array two-dimensional even when there are no
    # relations, so callers can always slice column 2.
    relation = np.array(relation, dtype=np.int32).reshape(-1, 3)

    # add relation to target; every pair is unique after the de-duplication
    # above, so each cell of the map is written at most once.
    num_box = len(target)
    relation_map = torch.zeros((num_box, num_box), dtype=torch.int64)
    for first_idx, second_idx, predicate in relation.tolist():
        relation_map[first_idx, second_idx] = predicate
    target.add_field("relation", relation_map, is_triplet=True)

    target = target.clip_to_image(remove_empty=False)
    # Added after clipping, as before.
    target.add_field("relation_tuple", torch.from_numpy(relation).long())  # for evaluation
    return target


In [54]:
# Predicate index of 'walking on' in the predicate vocabulary.
wo_idx = ind_to_predicates.index('walking on')

# Image indices whose relations do / do not contain 'walking on'.
include_wo_samples = []
others = []
for idx, _ in enumerate(filenames):
    target = get_groundtruth(idx)
    target_rels = target.get_field('relation_tuple')
    if len(target_rels) > 0:
        predicates = target_rels[:, 2]
        has_wo = wo_idx in predicates
        print(predicates, has_wo)

        bucket = include_wo_samples if has_wo else others
        bucket.append(idx)

    # Stop once more than 20 matching images have been collected.
    if len(include_wo_samples) > 20:
        break

tensor([20, 20]) False
tensor([31, 20]) False
tensor([49, 47, 48, 49, 49]) False
tensor([31]) False
tensor([20, 20, 30, 30]) False
tensor([29, 20]) False
tensor([48, 38, 20, 38, 50, 22, 22, 50]) False
tensor([21, 30]) False
tensor([20, 20, 20,  8,  8]) False
tensor([31, 31]) False
tensor([30, 30]) False
tensor([48, 20,  8, 49]) False
tensor([48, 31, 21, 31, 20, 48]) False
tensor([31]) False
tensor([30]) False
tensor([31]) False
tensor([37, 22, 47, 48]) False
tensor([31, 31, 48, 48, 48, 20, 38, 48, 48]) False
tensor([29, 29]) False
tensor([31]) False
tensor([20, 31, 31, 31, 31, 48, 48, 48, 43, 20, 31, 30, 38, 38, 48, 48]) False
tensor([50]) False
tensor([21,  8, 48, 48]) False
tensor([40, 31, 20, 20, 31, 31, 49, 20, 49, 49, 20, 48, 33]) False
tensor([23, 30]) False
tensor([31, 31]) False
tensor([ 9, 20, 30]) False
tensor([48, 30, 31]) False
tensor([31]) False
tensor([20, 20]) False
tensor([22, 10, 10, 22, 22, 22, 22, 31, 50, 20]) False
tensor([50, 50]) False
tensor([50, 31, 20,  8, 21, 

In [62]:
import os,glob

# Pick two images that contain 'walking on' and one that does not.
wo_sample_idx = random.sample(include_wo_samples, 2)
other_sample_idx = random.choice(others)

# For every picked image, write each ground-truth object crop and the full
# image into a directory named after the sample index.
for sample_id in wo_sample_idx + [other_sample_idx]:
    img = cv2.imread(filenames[sample_id])
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f'could not read image {filenames[sample_id]}')
    target = get_groundtruth(sample_id)

    os.makedirs(str(sample_id), exist_ok=True)

    # Raw gt_boxes are stored in BOX_SCALE coordinates; the boxes held by
    # the target have been rescaled to the image and clipped to it, so they
    # are the ones that line up with the pixels of `img`.
    boxes = target.bbox.tolist()
    written_per_class = defaultdict(int)
    for box, cls in zip(boxes, gt_classes[sample_id]):
        x1, y1, x2, y2 = (int(v) for v in box)
        obj_img = img[y1:y2, x1:x2, ...]
        if obj_img.size == 0:
            # Degenerate box: nothing to write.
            continue

        cls_name = ind_to_classes[cls]
        # First crop of a class is '<class>.png'; later ones get '-1', '-2', ...
        # Counting in memory keeps the names stable when the cell is re-run.
        already_written = written_per_class[cls_name]
        suffix = f'-{already_written}' if already_written else ''
        cv2.imwrite(f'{sample_id}/{cls_name}{suffix}.png', obj_img)
        written_per_class[cls_name] += 1

    cv2.imwrite(f'{sample_id}/ori_img.png', img)

In [68]:
import os

# Write one '<class>-<predicate>-<class>' line per relation of each sample
# into '<sample_id>/relation.txt'.
sample_ids = [360, 432, 535]
for sample_id in sample_ids:
    target = get_groundtruth(sample_id)
    rel_tuple = target.get_field('relation_tuple')
    labels = gt_classes[sample_id]

    # Do not rely on an earlier cell having created the directory.
    os.makedirs(str(sample_id), exist_ok=True)

    # Open once per sample and truncate, so re-running the cell replaces the
    # file instead of appending duplicate lines to it.
    with open(f'{sample_id}/relation.txt', 'w') as rel_txt:
        for rel_t in rel_tuple.tolist():
            head_cls = ind_to_classes[labels[rel_t[0]]]
            tail_cls = ind_to_classes[labels[rel_t[1]]]
            predicate = ind_to_predicates[rel_t[2]]
            rel_txt.write(f'{head_cls}-{predicate}-{tail_cls}\n')
            

In [74]:
import copy
import os

# For each relation of each sample, draw the two related boxes on a copy of
# the image and save the crop of their union box as
# '<sample_id>/<class>-<predicate>-<class>.png'.
sample_ids = [360, 432, 535]
for sample_id in sample_ids:
    img = cv2.imread(filenames[sample_id])
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f'could not read image {filenames[sample_id]}')
    target = get_groundtruth(sample_id)
    rel_tuple = target.get_field('relation_tuple')
    # Raw gt_boxes are stored in BOX_SCALE coordinates; the boxes held by
    # the target have been rescaled to the image and clipped to it.
    boxes = target.bbox
    labels = gt_classes[sample_id]

    # cv2.imwrite does not create missing directories.
    os.makedirs(str(sample_id), exist_ok=True)

    written_per_name = defaultdict(int)
    for rel_t in rel_tuple.tolist():
        head_box, head_cls = boxes[rel_t[0]], ind_to_classes[labels[rel_t[0]]]
        tail_box, tail_cls = boxes[rel_t[1]], ind_to_classes[labels[rel_t[1]]]
        predicate = ind_to_predicates[rel_t[2]]

        # Draw on a private copy so boxes of other relations do not pile up.
        new_img = img.copy()
        for rel_box in (head_box, tail_box):
            x1, y1, x2, y2 = (int(v) for v in rel_box.tolist())
            color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
            cv2.rectangle(new_img, (x1, y1), (x2, y2), color, 1)

        # Smallest box enclosing both boxes: min of the top-left corners,
        # max of the bottom-right corners.
        union_box = torch.cat((torch.min(head_box[:2], tail_box[:2]), torch.max(head_box[2:], tail_box[2:])), dim=-1)
        ux1, uy1, ux2, uy2 = (int(v) for v in union_box.tolist())

        rel_tri_img = new_img[uy1:uy2, ux1:ux2, ...]
        if rel_tri_img.size == 0:
            # Degenerate union box: nothing to write.
            continue

        # Several relations can share the same three names; suffix repeats
        # with '-1', '-2', ... so they do not overwrite each other.
        stem = f'{head_cls}-{predicate}-{tail_cls}'
        already_written = written_per_name[stem]
        suffix = f'-{already_written}' if already_written else ''
        cv2.imwrite(f'{sample_id}/{stem}{suffix}.png', rel_tri_img)
        written_per_name[stem] += 1
        


: 