In [None]:
"""
The structure of the code:

- First we need to create the dataset from the precomputed bounding boxes for vqa dataset.
- Then create a dataloader to load the bounding box pairs
- The the Evaluation part

"""

In [1]:
# All the imports go here
import numpy as np
import json
import base64
import copy
import glob
import time
import random
import matplotlib.pyplot as plt

import torch.nn as nn
import torch
import torch.utils.data as Data
from PIL import Image
from torchvision import transforms

from vis_rel.function.config import config, update_config
from vis_rel.modules.frcnn_classifier import Net

update_config('cfgs/vis_rel/frcnn.yaml')

In [None]:
# colors for visualization
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]

def plot_rectangles(pil_img, boxes):
    plt.figure(figsize=(16,10))
    plt.imshow(pil_img)
    ax = plt.gca()
    colors = COLORS * 100
    for (xmin, ymin, xmax, ymax), c in zip(boxes, colors):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=3))
    plt.axis('off')
    plt.show()

In [2]:
"""
The Dataset creation part
"""

sample_boxes = 20

def createBoundingBoxPairs(path, verbose = False):
    if verbose:
        start_time = time.time()
    image_id = int(path.split('/')[-1].replace('.json', ''))
    objects = json.load(open(path))
    im_info = (objects['image_w'], objects['image_h'])
    
    temp_bb_pairs = []
    
    boxes = np.frombuffer(base64.decodebytes(objects['boxes'].encode()), dtype=np.float32).reshape((objects['num_boxes'], -1))
    
    boxes = random.sample(list(boxes), min(len(boxes), sample_boxes))
    
    if verbose:
        # plot objects on images
        img = Image.open(glob.glob('data/coco/*/*' + str(image_id) + '.jpg')[0])
        img = img.convert('RGB')
        
        plot_rectangles(img, boxes)
    
    for i in range(len(boxes)):
        for j in range(len(boxes)):
            if i != j:
                temp_pair_dic = {
                    'subj_bbox': boxes[i],
                    'obj_bbox': boxes[j],
                    'union_bbox': np.array([min(boxes[i][0], boxes[j][0]), min(boxes[i][1], boxes[j][1]), max(boxes[i][2], boxes[j][2]), max(boxes[i][3], boxes[j][3])], dtype=np.float32),
                    'im_info': im_info,
                    'image_id': image_id
                }
                temp_bb_pairs.append(temp_pair_dic)
    if verbose:            
        end_time = time.time()
        print('Time taken in image id: {} is {}'.format(image_id, end_time-start_time))
                
    return temp_bb_pairs


def createBoundingBoxPairDataset():
    # both training and testing bb pairs
    bb_pairs = []
    
    objs_path_list = glob.glob('data/coco/vgbua_res101_precomputed/*faster_rcnn_genome/*.json')[:128]
    
    # now for each image create the bb pairs
    count = 0
    total = len(objs_path_list)
    for path in objs_path_list:
        bb_pairs.extend(createBoundingBoxPairs(path, verbose = False))
        count += 1
        print("\rImages done : {} / {}".format(count, total), end="  ")
        
    return bb_pairs
    

In [3]:
bb_pairs_dataset = createBoundingBoxPairDataset()
print("Length of bb pairs is : {}".format(len(bb_pairs_dataset)))

Images done : 128 / 128  Length of bb pairs is : 47692


In [4]:
import pickle
print("saving to pickle file")
pickle.dump(bb_pairs_dataset, open('data/bb_pairs_dataset_tiny.pkl', 'wb'), pickle.HIGHEST_PROTOCOL)
print("done")

saving to pickle file
done


In [None]:
# try on a single image
img = Image.open('data/coco/train2014/COCO_train2014_000000291797.jpg')



In [None]:
# some method to save the list of dictionaries containing numpy arrays
to

In [None]:
"""
The dataloader for the dataset
"""

import random

class DatasetLoader(Data.Dataset):
    
    def __init__(self, path, bb_pairs_dataset):
        
        self.image_dir = path
        
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])

        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.ToTensor(),
            self.normalize,
        ])
        
        self.bb_pairs = bb_pairs_dataset[:1024]       
        self.data_size = len(self.bb_pairs)
        
    def __getitem__(self, idx):
        
        bb_pair = self.bb_pairs[idx]
        
        img = Image.open(glob.glob(self.image_dir + '/*/*' + str(bb_pair['image_id']) + '.jpg')[0])
        img = img.convert('RGB')
        img = self.transform(img)
        
        inputs = {}
        
        # convert inputs to tensors
        inputs['subj_bbox'] = torch.from_numpy(bb_pair['subj_bbox'])
        inputs['obj_bbox'] = torch.from_numpy(bb_pair['obj_bbox'])
        inputs['union_bbox'] = torch.from_numpy(bb_pair['union_bbox'])
        
        inputs['im_info'] = torch.tensor(bb_pair['im_info'])
        inputs['image'] = img
        
        # image_id
        image_ids = torch.tensor(bb_pair['image_id'])
        
        return inputs, image_ids
        
        
    def __len__(self):
        return self.data_size

In [None]:
"""
The evaluator
"""
# batch_size
batch_size = 16

# Load the dataset
val_dataset = DatasetLoader('data/coco', bb_pairs_dataset)

val_dataloader = Data.DataLoader(
                    val_dataset,
                    batch_size,
                    shuffle = False,
                    num_workers = 4,
                    pin_memory = True,
                    sampler = None,
                    drop_last = True
                )

print('Loaded the dataset')

# os env
import os
#os.environ['CUDA_VISIBLE_DEVICES'] = '4,5,6,7'
# Load the net
model = Net(config)

# define a softmax obj
soft = nn.Softmax(-1)

# Load the state dict
path = 'output/output/vis_rel/ckpt_frcnn_train+val_low_lr_epoch7.pkl'
state_dict = torch.load(path, map_location=torch.device('cpu'))['state_dict']
new_state_dict = {k.replace('module.', ''):state_dict[k] for k in state_dict}
model.load_state_dict(new_state_dict)

model.eval()
print("here")
model.to(torch.device('cuda:4'))
print('not here')
model = nn.DataParallel(model, device_ids = ['cuda:4','cuda:5','cuda:6','cuda:7'])

# Load the idx to label relationship
rel_classes = json.load(open('data/relationship_classes.json'))
class_rel = {v:k for k, v in rel_classes.items()}
print(class_rel)

# define a threshold
threshold = 0.6

In [None]:
print("hello there")
import codecs
relationships = {}

with torch.no_grad():
    for step, (
            inputs,
            image_ids
        ) in enumerate(val_dataloader):

        print("step is : {}\n".format(step))

        for k, v in inputs.items():
            inputs[k] = v.cuda()

        feats, pred = model(inputs)

        # softmax over pred
        pred = soft(pred)

        for i in range(len(pred)):
            pred_ind = int(torch.argmax(pred[i]))
            pred_val = torch.max(pred[i])

            if pred_ind < 20 and pred_val > threshold:
                print("The prediction index is: {}".format(pred_ind))
                print("the max prediction value is {}".format(pred_val))
                print("The relationship is {}".format(str(class_rel[int(pred_ind)])))
                temp_rel = {
                    'predicate': str(class_rel[int(pred_ind)]),
                    'features': feats.cpu().detach().tolist(),
                    'subj_bbox': inputs['subj_bbox'][i].cpu().detach().tolist(),
                    'obj_bbox': inputs['obj_bbox'][i].cpu().detach().tolist()
                }

                if str(int(image_ids[i])) not in relationships:
                    relationships[str(int(image_ids[i]))] = []

                relationships[str(int(image_ids[i]))].append(temp_rel)

        print("\rProgress {}/{}".format(step, val_dataset.data_size/batch_size))


for k, v in realtionships:
    print("\rImage id : {}".format(str(k)), end=' ')
    pickle.dump(v, open('data/coco/vqa_relationships/' + str(k) + '.pkl', 'wb'), pickle.HIGHEST_PROTOCOL)



In [None]:
print(relationships.keys())