In [1]:
import os
import numpy as np
import h5py
import cv2
import json
from PIL import Image

import torch
import torch.nn.functional as F
import torchvision
from torch import nn

import spatial
from model.bgnet_model import AGRNN

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
faster_rcnn = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True, rpn_post_nms_top_n_test=200, \
                                                                 box_batch_size_per_image=128, box_score_thresh=0.1, box_nms_thresh=0.3)
faster_rcnn.cuda()
faster_rcnn.eval()

node_num = []
features = None
spatial_feat = None
word2vec_emb = None
roi_labels = None
bg = None

checkpoint = torch.load("checkpoints/run_bg_final_final/v8/epoch_train/checkpoint_21_epoch.pth", map_location=device)
model = AGRNN(feat_type=checkpoint['feat_type'], bias=checkpoint['bias'], bn=checkpoint['bn'], dropout=checkpoint['dropout'], multi_attn=checkpoint['multi_head'], layer=checkpoint['layers'], diff_edge=checkpoint['diff_edge']) #2 )
model.load_state_dict(checkpoint['state_dict'])
model.to(device)
model.eval()

img = cv2.imread("HICO_train2015_00000019.jpg")
# img = cv2.imread("HICO_test2015_00000016.jpg")
word2vec = h5py.File("datasets/processed/hico/hico_word2vec.hdf5", 'r')
coco_dict = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

In [3]:
outputs = []
def hook(module, input, output):
    outputs.append(output)
faster_rcnn.roi_heads.box_head.fc7.register_forward_hook(hook)

img_norm = img / 255
frcnn_img_tensor = torch.from_numpy(img_norm)
frcnn_img_tensor = frcnn_img_tensor.permute([2,0,1]).float().to(device) # chw format
rcnn_input = [frcnn_img_tensor]

out = faster_rcnn(rcnn_input)[0]
features = outputs[0]
node_num.append(len(features))

In [4]:
faster_rcnn

FasterRCNN(
  (transform): GeneralizedRCNNTransform()
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d()
      (relu): ReLU(inplace)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d()
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d()
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d()
          (relu): ReLU(inplace)
          (downsample): Sequential(
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): FrozenBatchNorm2d()
          )
        )
  

In [5]:
out

{'boxes': tensor([[227.7711, 229.1417, 459.6313, 377.8958],
         [263.7496, 128.2099, 373.9128, 357.7357],
         [200.9035, 181.5686, 248.2359, 300.0794],
         [302.1853, 141.7461, 365.0817, 212.2273],
         [471.0718,  97.3743, 564.3862, 237.3736]], device='cuda:0',
        grad_fn=<StackBackward>),
 'labels': tensor([ 2,  1,  1, 27,  1], device='cuda:0'),
 'scores': tensor([0.9994, 0.9994, 0.9043, 0.7061, 0.1448], device='cuda:0',
        grad_fn=<IndexBackward>)}

In [8]:
bboxes = None
for i in range(len(out['scores'])):
    if out['scores'][i] < 0.7:
        bboxes = out['boxes'][:i].detach().cpu()
        roi_labels = [out['labels'][:i].detach().cpu().numpy()]
img_wh = [img.shape[1], img.shape[0]]
spatial_feat = spatial.calculate_spatial_feats(bboxes, img_wh)
spatial_feat = torch.Tensor(spatial_feat).to(device)

In [9]:
word2vec_emb = np.empty((0,300))
for id in roi_labels[0]:
    vec = word2vec[coco_dict[id]]
    word2vec_emb = np.vstack((word2vec_emb, vec))
word2vec_emb = torch.Tensor(word2vec_emb).to(device)

In [10]:
mask = np.ones_like(img) * 255
for bbox in bboxes:
    bbox = bbox.detach().cpu().numpy()
    cv2.rectangle(mask, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 0, 0), thickness=-1)
background_img = cv2.bitwise_and(img, mask, mask=None)
background_img_resize = cv2.resize(background_img, (64, 64), interpolation=cv2.INTER_AREA)
background_img_tensor = torch.from_numpy(background_img)
res_background_input = background_img_tensor.unsqueeze(0)
res_background_input = res_background_input.permute([0,3,1,2]).float().to(device)

In [11]:
with torch.no_grad():
    node_num = [len(roi_labels[0])]
    features = features[:node_num[0]]
    print(type(node_num))
    print(type(features))
    print(type(spatial_feat))
    print(type(word2vec_emb))
    print(type(roi_labels))
    model_preds = model(node_num, features, spatial_feat, word2vec_emb, roi_labels, bg=res_background_input, validation=True)
model_preds.shape

<class 'list'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'list'>




torch.Size([3, 117])

In [12]:
model_preds

tensor([[-11.3312, -10.6611, -10.1204,  -7.5408,  -6.8844, -13.5915, -10.2435,
         -13.0439,  -4.1131,  -8.4764, -13.6136, -11.5625, -13.9360, -11.9764,
         -12.5769,  -7.6644,  -9.8749, -11.4811, -10.2135, -10.7590,  -9.8430,
          -6.3988, -13.3266,  -6.8163,  -4.4017, -11.3143,  -9.3741, -11.8405,
          -7.7789, -12.5106,  -6.9803, -12.9873,  -8.0797, -12.4345, -12.5315,
          -9.0891,  -2.7476, -11.7124, -11.3940,  -6.9715, -13.2066,  -5.4328,
         -12.8884,  -4.7522,  -8.3489,  -9.4233, -12.9479,  -9.7786, -12.5447,
          -6.5293, -11.9104, -13.1376,  -8.5792, -13.0538, -11.4977, -12.2834,
         -12.7899,  -3.7650,  -7.2723, -13.1795, -12.5887, -12.5138, -11.8188,
         -12.4179, -12.6438,  -8.2007, -12.0279,  -8.4913, -12.2052, -11.6880,
          -8.2221,  -9.9318,  -6.8128,  -5.6316, -12.5346,  -9.6821,  -2.7512,
          -6.3312,  -6.7688,  -8.4502, -10.5278, -10.5099, -11.4391, -10.8101,
         -13.0536, -10.8512,  -4.1419,  -3.0312, -13