In [1]:
import os
import numpy as np

from utils.register_dataset import register_vrd_dataset
from config import parse_args, get_vrd_cfg
from utils.trainer import CustomTrainer

### [Optional] Training Visual Backbone Module

In [2]:
#Train Dataset
cfg = get_vrd_cfg()
register_vrd_dataset('vrd')
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = CustomTrainer(cfg)

[32m[03/17 08:03:44 d2.engine.defaults]: [0mModel:
GeneralizedRCNN(
  (backbone): FPN(
    (fpn_lateral2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral4): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral5): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (top_block): LastLevelMaxPool()
    (bottom_up): ResNet(
      (stem): BasicStem(
        (conv1): Conv2d(
          3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
          (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)
        )
      )
 

[32m[03/17 08:03:45 d2.data.build]: [0mUsing training sampler TrainingSampler


In [None]:
cfg.OUTPUT_DIR = os.path.join(cfg.OUTPUT_DIR, 'detectors')

In [None]:
trainer.resume_or_load(resume=False)
trainer.train()

In [None]:
from detectron2.modeling import build_model
from detectron2.checkpoint import DetectionCheckpointer

detectron = build_model(cfg)

In [None]:
b = detectron.backbone
b = None
detectron

In [None]:
detectron.roi_heads.box_pooler

In [None]:
from detectron2.modeling.poolers import ROIPooler
pooler = ROIPooler((7, 7), pooler_type='ROIAlignV2', scales=[1/4, 1/8, 1/16, 1/32], sampling_ratio=4)
pooler

In [None]:
#Detections Output
import random
from utils.register_dataset import register_vrd_dataset
from detectron2.utils.visualizer import Visualizer
from detectron2.data import DatasetCatalog, MetadataCatalog

register_vrd_dataset('vrd')
detectron_dataset = DatasetCatalog.get("vrd_train")
vrd_metadata = MetadataCatalog.get("vrd_train")

In [None]:
%matplotlib inline
import cv2
import matplotlib as mpl
from matplotlib import pyplot as plt
for d in random.sample(detectron_dataset, 3):
    print(d["file_name"])
    img = cv2.imread(d["file_name"])
    visualizer = Visualizer(img[:, :, ::-1], metadata=vrd_metadata, scale=1)
    vis = visualizer.draw_dataset_dict(d)
    img = vis.get_image()
    plt.figure(dpi=1200)
    plt.imshow(img)
    plt.show()

In [None]:
import os
import json
from detectron2.structures import BoxMode
from config import ROOT_DIR, VRD_DATASET_PATH, VG200_DATASET_PATH

def load_dataset_annotations(dataset_name):
    """
        Load raw annotation dataset and convert it into standard detectron2 format
        Input:
            dataset_name: 'vrd/train' or 'vrd/val'
    """
    dataset_dicts = []
    set_name, set_type = dataset_name.split('/')
    if (set_name == 'vrd'):
        
        #VRD data path
        dataset_path = VRD_DATASET_PATH

        #load annotations file
        annotations = {}
        file_name = "new_annotations_{}.json".format(set_type)
        with open(os.path.join(dataset_path, file_name)) as annotations_file:
            annotations = json.load(annotations_file)
        
        #Map 'annotations' to the correct format
        img_file_dir = os.path.join(VRD_DATASET_PATH, '{}_images'.format(set_type))
        for key, val in annotations.items():
            record = {}

            #populating image fields
            img_file_path = os.path.join(img_file_dir, key)
            height, width = cv2.imread(img_file_path).shape[:2]

            record['file_name'] = img_file_path
            record['height'] = height
            record['width'] = width
            record['image_id'] = int(key.split('.')[0])

            #populating annotations from record
            objs = []
            visited = set()
            for anno in val:
                
                subj = anno['subject'] #ymin, ymax, xmin, xmax
                obj = anno['object']
                bbox_subj = subj['bbox']
                bbox_obj = obj['bbox']
                new_bbox_subj = [bbox_subj[2], bbox_subj[0], bbox_subj[3], bbox_subj[1]]
                new_bbox_obj = [bbox_obj[2], bbox_obj[0], bbox_obj[3], bbox_obj[1]]
                
                subj_dict = {
                    'bbox': new_bbox_subj,
                    'bbox_mode': BoxMode.XYXY_ABS,
                    'category_id': subj['category']
                }
                if (tuple(bbox_subj) not in visited):
                    objs.append(subj_dict)
                    visited.add(tuple(bbox_subj))

                obj_dict = {
                    'bbox': new_bbox_obj,
                    'bbox_mode': BoxMode.XYXY_ABS,
                    'category_id': obj['category'],
                }
                if (tuple(bbox_obj) not in visited):
                    objs.append(obj_dict)
                    visited.add(tuple(bbox_obj))

            record['annotations'] = objs
                         
            #add img to dataset dict
            dataset_dicts.append(record)
    
    return dataset_dicts
data = load_dataset_annotations('vrd/train')

### [Required] DataSet and DataLoader (Step 1)

In [2]:
import os
import cv2
import json
import torch
import random
import copy

import numpy as np
from torch._C import import_ir_module
from utils.boxes import boxes_union
from torch.utils.data import Dataset, DataLoader
import detectron2.data.transforms as T
from detectron2.data import detection_utils as utils
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
import matplotlib.pyplot as plt

# from config import ROOT_DIR, VRD_DATASET_PATH, VG200_DATASET_PATH
VRD_DATASET_PATH = "../data/vrd/"

# custom transformation
detectron_transform = T.Resize((800, 800))

class VRDDataset(Dataset):
    def __init__(self, set_type="train", transform=detectron_transform):
        """
            Note:
                transform only applys on bounding boxes. The transformation to the image should be done by detectron2 dataloader.
        """
        annotations_path = os.path.join(
            VRD_DATASET_PATH, "new_annotations_{}.json".format(set_type)
        )
        self.images_dir = os.path.join(VRD_DATASET_PATH, "{}_images".format(set_type))
        self.transform = transform

        with open(annotations_path) as fp:
            raw_annotation = json.load(fp)
        self.annotations = list(raw_annotation.items())

        # check if the data is pre-generated
        roidb_chkpt = os.path.join(VRD_DATASET_PATH, "vrd_roidb_{}.json")
        if os.path.exists(roidb_chkpt):
            with open(roidb_chkpt) as fp:
                self.roidb = json.load(fp)
        else:
            roidb = []
            for img_name, annotations in raw_annotation.items():
                subj_bboxes = []
                obj_bboxes = []
                union_bboxes = []
                unique_objects = {}
                subj_classes = []
                obj_classes = []
                pred_classes = []
                for anno in annotations:
                    subj = anno["subject"]
                    obj = anno["object"]

                    subj["bbox"] = [
                        subj["bbox"][2],
                        subj["bbox"][0],
                        subj["bbox"][3],
                        subj["bbox"][1],
                    ]  # XYXY
                    obj["bbox"] = [
                        obj["bbox"][2],
                        obj["bbox"][0],
                        obj["bbox"][3],
                        obj["bbox"][1],
                    ]

                    union_bboxes.append(
                        boxes_union(np.array([subj["bbox"]]), np.array([obj["bbox"]]))[0]
                    )
                    subj_bboxes.append(subj["bbox"])
                    obj_bboxes.append(obj["bbox"])
                    unique_objects[(tuple(subj["bbox"]))] = {
                        "bbox": subj["bbox"],
                        "bbox_mode": 0,  # BoxMode.XYXY_ABS
                        "category_id": subj["category"],
                    }
                    unique_objects[(tuple(obj["bbox"]))] = {
                        "bbox": obj["bbox"],
                        "bbox_mode": 0,
                        "category_id": obj["category"],
                    }
                    subj_classes.append(subj["category"])
                    obj_classes.append(obj["category"])
                    pred_classes.append(anno["predicate"])
                                
                roidb.append(
                    {
                        # Detectron
                        "file_name": os.path.join(self.images_dir, img_name),
                        "image_id": int(img_name.split(".")[0]),
                        "annotations": list(unique_objects.values()),
                        # Relationships
                        "relationships": {
                            "subj_bboxes": subj_bboxes,
                            "obj_bboxes": obj_bboxes,
                            "union_bboxes": union_bboxes,
                            "subj_classes": subj_classes,
                            "pred_classes": pred_classes,
                            "obj_classes": obj_classes,
                        },
                    }
                )
            self.roidb = roidb
            
    def __len__(self):
        return len(self.roidb)

    def __getitem__(self, idx):
        item = self.roidb[idx]
        cv2.setNumThreads(0)
        image = cv2.imread(item["file_name"])
        
        #get transformation
        auginput = T.AugInput(image)
        transform = self.transform(auginput)
        auginput2 = T.AugInput(image)
        transform2 = self.transform(auginput2)
        relationships = item["relationships"]
        
        #update bboxes
        subj_bboxes = []
        obj_bboxes = []
        union_bboxes = []
        
        for subj_bbox, obj_bbox in zip(relationships['subj_bboxes'], relationships['obj_bboxes']):
            new_subj_box = transform.apply_box(subj_bbox)
            new_obj_box = transform2.apply_box(obj_bbox)
            new_union_box = boxes_union(copy.deepcopy(new_subj_box), copy.deepcopy(new_obj_box))[0]
            
            subj_bboxes.append(new_subj_box)
            obj_bboxes.append(new_obj_box)
            union_bboxes.append(new_union_box)
        relationships['subj_bboxes'] = subj_bboxes
        relationships['obj_bboxes'] = obj_bboxes
        relationships['union_bboxes'] = union_bboxes
        # add height and width
        height, width = image.shape[:2]
        item["height"] = height
        item["width"] = width
        
        return item


def get_object_classes(set_name):
    if set_name == "vrd":
        classes = []
        with open(os.path.join(VRD_DATASET_PATH, "objects.json")) as fp:
            classes = json.load(fp)
        return classes
    elif set_name == "vg":
        # to be implemented
        pass
    else:
        raise NotImplementedError


def get_predicate_classes(set_name):
    if set_name == "vrd":
        classes = []
        with open(os.path.join(VRD_DATASET_PATH, "predicates.json")) as fp:
            classes = json.load(fp)
        # add unknown predicate class for missing predicates
        # classes.insert(0, 'unknown')
        classes.append('unrelated')
        return classes
    elif set_name == "vg":
        # to be implemented
        pass
    else:
        raise NotImplementedError


def visualize_image_bboxes(image, instances, object_classes=None):
    """
        Inputs:
            image: image tensor
            instances: Instances object from detectron2
    """
    img = copy.deepcopy(image)
    instances_dict = instances[0].get_fields()
    bounding_boxes = [box.tolist() for box in instances_dict['pred_boxes']]
    labels = instances_dict['pred_classes']
    
    for bbox in bounding_boxes:
        img = cv2.rectangle(img, [int(coord) for coord in bbox[0:2]], [int(coord) for coord in bbox[2:4]], (255,255,255), 2)
    plt.imshow(img)
    plt.show()

def visualize_bboxes(dataset):
    pred_classes = get_predicate_classes("vrd")
    obj_classes = get_object_classes("vrd")

    for i in random.sample(range(len(dataset)), 60):
        image, cropped_img, anno = dataset[i]
        img = copy.deepcopy(image)
        subj = anno["subject"]
        obj = anno["object"]

        subject_class = obj_classes[subj["category"]]
        predicate_class = pred_classes[anno["predicate"]]
        object_class = obj_classes[obj["category"]]

        img = cv2.rectangle(img, subj["bbox"][0:2], subj["bbox"][2:4], (0, 0, 255), 2)
        img = cv2.rectangle(img, obj["bbox"][0:2], obj["bbox"][2:4], (255, 0, 0), 2)

        cv2.imshow(
            " ".join((subject_class, predicate_class, object_class)), cropped_img
        )
        cv2.waitKey(3000)

In [3]:
#Extract ROI Features
"""
Relevant Links:
    https://towardsdatascience.com/faster-rcnn-object-detection-f865e5ed7fc4#:~:text=Faster%20RCNN%20is%20an%20object,SSD%20(%20Single%20Shot%20Detector).
"""
from detectron2.modeling import build_model
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.structures import ImageList

from detectron2.modeling import build_backbone
from detectron2.modeling.poolers import ROIPooler
from transformers import BertTokenizer, BertModel

from config import get_vrd_cfg

def get_roi_features(images, box_lists, output_size=(14,14)):
    """
        Get image features from the backbone network
        Input:
            images: (ImageList.from_tensors) with dimension (B,C,W,H)
            box_lists: A list of N boxes
    """
    cfg = get_vrd_cfg()
    backbone = build_backbone(cfg)
    pooler = ROIPooler(output_size, pooler_type='ROIAlignV2', scales=[1/4, 1/8, 1/16, 1/32, 1/64], sampling_ratio=4)
    feature_maps = backbone(images)
    feature_maps = [feature_maps['p{}'.format(i)] for i in range(2,7)]
    regions_feature = pooler(feature_maps, box_lists)
    print(regions_feature.shape)

    return regions_feature

def generate_roi_features(dataset):
    pass
    # images = ImageList.from_tensors(tensors)  # preprocessed input tensor
    # model = build_model(cfg)
    # DetectionCheckpointer(model).load("output/model_final.pth")
    # model.eval()
    # features = model.backbone(images.tensor)
    # proposals, _ = model.proposal_generator(images, features)
    # instances, _ = model.roi_heads(images, features, proposals)
    # mask_features = [features[f] for f in model.roi_heads.in_features]
    # mask_features = model.roi_heads.mask_pooler(mask_features, [x.pred_boxes for x in instances])


In [4]:
#[Only Run once] Register dataset with detectron2 instead of using my own dataloader
from detectron2.data import DatasetCatalog, MetadataCatalog
from utils.annotations import get_object_classes

def get_vrd_dicts(dataset_name):
    set_name, set_type = dataset_name.split('/')
    dataset = None
    if (set_name == 'vrd'):
        dataset = VRDDataset(set_type=set_type) 
    return dataset

def register_vrd_dataset(set_name):
    """
        Register dataset and its metadata to the detectron2 engine
        Input:
            set_name: 'vrd' or vg200
    """
    thing_classes = get_object_classes(set_name)

    #register the annotations
    for d_type in ['train', 'val']:
        DatasetCatalog.register('_'.join((set_name, d_type)), lambda d_type=d_type: get_vrd_dicts('/'.join((set_name, d_type))))
        MetadataCatalog.get('_'.join((set_name, d_type))).set(thing_classes=thing_classes)
register_vrd_dataset('vrd')

### [Optional] Testing/Check whether ROI is updated with Image augmentation

In [None]:
#build dataloader
from detectron2.data import (
    DatasetCatalog, DatasetMapper,
    build_detection_train_loader
)
from config import get_vrd_cfg
import detectron2.data.transforms as T

cfg = get_vrd_cfg()
dataloader = build_detection_train_loader(cfg,
    mapper=DatasetMapper(cfg, is_train=True, augmentations=[
        T.Resize((800, 800))
    ])
)

In [None]:
data_iter = iter(dataloader)

In [None]:
del data_iter

In [None]:
train_features = next(data_iter)
train_features = train_features[0] #first element, dataloader with batch size 1

In [None]:
data = train_features
relationships = data['relationships']
subj_boxes = relationships['subj_bboxes']
union_boxes = relationships['union_bboxes']
obj_boxes = relationships['obj_bboxes']
subj_classes = relationships['subj_classes']
pred_classes = relationships['pred_classes']
obj_classes = relationships['obj_classes']

In [None]:
from utils.boxes import boxes_intersect

In [None]:
def is_line_overlap(box1, box2):
    """
        box1: (xmin1, xmax1)
        box2: (xmin2, xmax2)
    """
    xmin1, xmax1 = box1
    xmin2, xmax2 = box2
    return xmax1 >= xmin2 and xmax2 >= xmin1

def is_box_overlap(box1, box2):
    xmin1, ymin1, xmax1, ymax1 = box1
    xmin2, ymin2, xmax2, ymax2 = box2
    
    return is_line_overlap((xmin1, xmax1), (xmin2, xmax2)) and is_line_overlap((ymin1, ymax1), (ymin2, ymax2))

boxes_intersect(np.array([box[0] for box in subj_boxes]), np.array([box[0] for box in obj_boxes]))
# np.array([subj["bbox"]]), np.array([obj["bbox"]])

In [None]:
import matplotlib.pyplot as plt
import cv2
import numpy as np
from utils.annotations import get_object_classes

#thing classes
thing_classes = get_object_classes('vrd')

#Visualizing the trained_features
# train_features = data
img = train_features['image']
img = img.numpy()
img = copy.deepcopy(np.transpose(img, (1,2, 0)))

boxes = train_features['instances'].get_fields()['gt_boxes']
classes = train_features['instances'].get_fields()['gt_classes']
for cls, box in zip(classes, boxes):
    int_box = [int(i) for i in box]
    
plt.figure(dpi=800)
plt.imshow(img[:,:,[2,1,0]])
plt.show()

In [None]:
#Get Region of Interests of the ground truth bboxes of the image
images = torch.unsqueeze(train_features['image'], axis=0)
boxes = train_features['instances'].get_fields()['gt_boxes']
visual_features = get_roi_features(images.float(), box_lists=[boxes], output_size=(7, 7))

### [Required] BERT Modeling (extracting features from text) (Step 2)

In [5]:
from transformers import BertTokenizer, BertModel
from collections import defaultdict

def get_word_features(triples, model, tokenizer):
    """
    Args:
        triples: (Subj, Pred, Obj)
    Return:
        dict of [CLS, Subj, Pred, Obj, SEP] embeddings
    """
    results = {}
    
    # Load pre-trained model tokenizer (vocabulary)
    marked_text = "[CLS] " + " ".join(triples) + " [SEP]"
    tokenized_text = tokenizer.tokenize(marked_text)
    
    # Save the token split to average them later on
    token_placements = defaultdict(list)
    triples_temp = list(triples)
    for i, tok in enumerate(tokenized_text):
        stip_tok = tok.replace('#', '')
        if (stip_tok in triples_temp[0]):
            token_placements['subj'].append(i)
            triples_temp[0] = triples_temp[0].replace(stip_tok, '')
        elif (stip_tok in triples_temp[1]):
            token_placements['pred'].append(i)
            triples_temp[1] = triples_temp[1].replace(stip_tok, '')
        elif (stip_tok in triples_temp[2]):
            token_placements['obj'].append(i)
            triples_temp[2] = triples_temp[2].replace(stip_tok, '')
        elif (not tok == '[CLS]' and not tok == '[SEP]'):
            print(tok, triples)

    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    segments_ids = [1] * len(tokenized_text)  # one sentence

    # Convert inputs to PyTorch tensors
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])

    # Put the model in "evaluation" mode, meaning feed-forward operation.
    model.eval()

    with torch.no_grad():
        outputs = model(tokens_tensor, segments_tensors)

        # Evaluating the model will return a different number of objects based on
        # how it's  configured in the `from_pretrained` call earlier. In this case,
        # becase we set `output_hidden_states = True`, the third item will be the
        # hidden states from all layers. See the documentation for more details:
        # https://huggingface.co/transformers/model_doc/bert.html#bertmodel
        hidden_states = outputs[2]

    token_embeddings = torch.stack(hidden_states, dim=0)
    token_embeddings.size()

    # remove dimension 1
    token_embeddings = torch.squeeze(token_embeddings, dim=1)
    token_embeddings = token_embeddings.permute(1, 0, 2)

    # get token embeddings (list of token embeddings)
    token_vecs_cat = []
    for token in token_embeddings:
        cat_vec = torch.cat((token[-1], token[-2], token[-3], token[-4]), dim=0)
        token_vecs_cat.append(cat_vec)
    results['CLS'] = token_vecs_cat[0]
    results['SEP'] = token_vecs_cat[-1]

    # average the token embeddings for word that are splitted to get word embeddings
    for key, val in token_placements.items():
        results[key] = token_vecs_cat[val[0]]
        for i in range(1, len(val)):
            results[key] += token_vecs_cat[val[i]]
        results[key] = results[key] / len(val)
        
    return results

In [6]:
# model = BertModel.from_pretrained(
#     "bert-base-uncased",
#     output_hidden_states=True,  # Whether the model returns all hidden-states.
# )
# tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")


In [7]:
#object and predicate labels
from utils.annotations import get_object_classes, get_predicate_classes

object_classes = get_object_classes('vrd')
predicate_classes = get_predicate_classes('vrd')

In [8]:
import json
import os

triples_embeddings_path = '../generated/triples_embeddings.pt'
def get_triples_features(set_name='vrd'):
    triples_memo = {}
#     triples_memo = torch.load(triples_embeddings_path)
    if (os.path.exists(triples_embeddings_path)):
        triples_memo = torch.load(triples_embeddings_path)
        return triples_memo
        
    # initialize the model and tokenizer
    model = BertModel.from_pretrained(
        "bert-base-uncased",
        output_hidden_states=True,  # Whether the model returns all hidden-states.
    )
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    
    # iterate through the all the triples, and extract the features
    dataset = get_vrd_dicts('vrd/train')
    for train_feat in dataset:
        rel = train_feat['relationships']
        for subj_ind, pred_ind, obj_ind in zip(rel['subj_classes'], rel['pred_classes'], rel['obj_classes']):
            for neg_obj in object_classes:
                print((subj_ind, pred_ind, obj_ind))
                #neg subj
                triples_text = (neg_obj, predicate_classes[pred_ind], object_classes[obj_ind])
                if ('-'.join(triples_text) in triples_memo):
                    continue
                word_feat = get_word_features(triples_text, model, tokenizer)
                triples_memo['-'.join(triples_text)] = word_feat
                
                #neg obj
                triples_text = (object_classes[subj_ind], predicate_classes[pred_ind], neg_obj)
                if ('-'.join(triples_text) in triples_memo):
                    continue
                word_feat = get_word_features(triples_text, model, tokenizer)
                triples_memo['-'.join(triples_text)] = word_feat
                
    
    try:
        torch.save(triples_memo, triples_embeddings_path)
    except:
        pass

    return triples_memo
token_embeddings = get_triples_features()

### [Optional] Visualize language triples similarities

In [None]:
from numpy.linalg import norm
def cosine_similarity(a, b):
    return (a @ b.T) / (norm(a)*norm(b))

def get_predicate_similarity_scores(test_triples_feats, token_name='pred'):
    """
        Input:
            test_triples_feat: triple feature to test
            token: pred, obj, subj, or CLS
        token
    """
    similarity_scores = {}
    for triples, features in triples_memo.items():
        try:
            similarity_scores[triples] = cosine_similarity(test_triples_feats[token_name], features[token_name])
        except:
            print(triples)
    return similarity_scores

In [None]:
test_triples = 'motorcycle-on-road'
test_triples_feats = triples_memo[test_triples]
similarity_scores = get_predicate_similarity_scores(test_triples_feats, token_name='obj')
dict(sorted(similarity_scores.items(), key=lambda item: item[1], reverse=True))

In [None]:
stacks = {
    'CLS': [],
    'SEP': [],
    'subj': [],
    'pred': [],
    'obj': [],
    'subj-obj': [],
    'subj+pred': [],
}
triples_label = []
for key, trip in triples_memo.items():
    for k, val in trip.items():
        stacks[k].append(val)
    stacks['subj-obj'].append(torch.cat((trip['subj'], trip['obj']), dim=0))
    stacks['subj+pred'].append((trip['subj'] + trip['pred'])/2)
    
    triples_label.append(key)

In [None]:
import matplotlib as mp

#color maping for plot
colors =  [ list(np.random.choice(range(256), size=3)/255) for _ in range(100)]
cmap = mp.colors.ListedColormap(colors, name='from_list')

In [None]:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

#t-sne 2-dim predicate
tsne = TSNE(2, perplexity=30, verbose=1)
predicate_stack = np.array([item.numpy() for item in stacks['pred']])
tsne_proj = tsne.fit_transform(predicate_stack)

In [None]:
#subj-obj 2-dim t-sne
tsne = TSNE(2, perplexity=30, verbose=1)
subj_obj_stack = np.array([item.numpy() for item in stacks['subj-obj']])
tsne_proj = tsne.fit_transform(subj_obj_stack)

#plotting
fig, ax = plt.subplots(figsize=(16,16))
num_categories = 100
predicate_classes = get_predicate_classes('vrd')
pred_labels = np.array([item.split('-')[1] for item in triples_label])
for lab, pred in enumerate(predicate_classes):
    indices = pred_labels == pred
    ax.scatter(tsne_proj[indices,0],tsne_proj[indices,1], c=np.array(cmap(lab)).reshape(1,4), label = pred ,alpha=0.5)
ax.legend(fontsize='large', markerscale=2)
plt.show()

In [None]:
#(subj+pred)/2 2-dim t-sne
tsne = TSNE(2, perplexity=30, verbose=1)
subj_pred_stack = np.array([item.numpy() for item in stacks['subj+pred']])
tsne_proj = tsne.fit_transform(subj_pred_stack)

#plotting
fig, ax = plt.subplots(figsize=(16,16))
num_categories = 100
predicate_classes = get_predicate_classes('vrd')
pred_labels = np.array([item.split('-')[1] for item in triples_label])
for lab, pred in enumerate(predicate_classes):
    indices = pred_labels == pred
    ax.scatter(tsne_proj[indices,0],tsne_proj[indices,1], c=np.array(cmap(lab)).reshape(1,4), label = pred ,alpha=0.5)
ax.legend(fontsize='large', markerscale=2)
plt.show()

In [None]:
#obj 2-dim t-sne
tsne = TSNE(2, perplexity=30, verbose=1)
obj_stack = np.array([item.numpy() for item in stacks['obj']])
tsne_proj = tsne.fit_transform(obj_stack)

#plotting
fig, ax = plt.subplots(figsize=(16,16))
num_categories = 100
predicate_classes = get_predicate_classes('vrd')
pred_labels = np.array([item.split('-')[1] for item in triples_label])
for lab, pred in enumerate(predicate_classes):
    indices = pred_labels == pred
    ax.scatter(tsne_proj[indices,0],tsne_proj[indices,1], c=np.array(cmap(lab)).reshape(1,4), label = pred ,alpha=0.5)
ax.legend(fontsize='large', markerscale=2)
plt.show()

In [None]:
#word feature length
obj_stack[0].shape

In [None]:
#roi feature length
visual_features.reshape((6, 256*7*7)).shape

In [None]:
import random
random.random()

In [None]:
torch.save(model.token_embeddings, triples_embeddings_path)

### [Required] VRDTransR Model

In [13]:
import os
import torch
import copy
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import ModuleDict

from utils.annotations import get_object_classes, get_predicate_classes
from utils.boxes import boxes_union

from modeling.roi_features import get_roi_features
from detectron2.modeling import build_model
from detectron2.modeling import build_backbone
from detectron2.modeling.poolers import ROIPooler
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.structures.boxes import Boxes
import random

class RelTransR(nn.Module):
    def __init__(self, cfg, pooling_size=(7,7), training=True):
        super().__init__()
        self.cfg = cfg
        
        # Object and Predicate Classes
        self.object_classes = get_object_classes('vrd')
        self.predicate_classes = get_predicate_classes('vrd')

        # Embeddings dimensions
        self.visual_feature_dim = 256*pooling_size[0]*pooling_size[1]
        self.visual_hidden_dim = 128*pooling_size[0]*pooling_size[1]
        self.word_feature_dim = 3072
        self.trans_feature_dim = 256
        self.emb_feature_dim = 64
        
        # Spatial Module
        self.spatial_feature_dim = 22
        self.spatial_hidden_dim = 64
        self.fc_spatial = torch.nn.Sequential(
            torch.nn.Linear(self.spatial_feature_dim, self.spatial_hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(self.spatial_hidden_dim, self.spatial_hidden_dim),
        )
        
        # Visual Modal
        self.detectron = build_model(cfg)
        if (training):
            self._load_detectron_chkpoints(cfg)

        # Seperate for predicate
        self.backbone = copy.deepcopy(self.detectron.backbone)
        self.pooler = copy.deepcopy(self.detectron.roi_heads.box_pooler)

        # Language Modal
        self.bert_model = BertModel.from_pretrained(
            "bert-base-uncased",
            # Whether the model returns all hidden-states.
            output_hidden_states=True,
        )
        self.bert_model.to('cuda')
        self.bert_model.eval()
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        
        # Pre-trained token embeddings (static without changes for now)
        self.triples_embeddings_path = '../generated/triples_embeddings.pt'

        # Fully connect language
        self.fc_word = ModuleDict({
            'subj': torch.nn.Sequential(
                torch.nn.Linear(self.word_feature_dim, self.trans_feature_dim),
                #torch.nn.BatchNorm1d(self.trans_feature_dim),
                torch.nn.LeakyReLU(0.1),
                torch.nn.Linear(self.trans_feature_dim, self.emb_feature_dim),
            ),
            'pred': torch.nn.Sequential(
                torch.nn.Linear(self.word_feature_dim, self.trans_feature_dim),
                #torch.nn.BatchNorm1d(self.trans_feature_dim),
                torch.nn.LeakyReLU(0.1),
                torch.nn.Linear(self.trans_feature_dim, self.emb_feature_dim),
            ),
            'obj': torch.nn.Sequential(
                torch.nn.Linear(self.word_feature_dim, self.trans_feature_dim),
                #torch.nn.BatchNorm1d(self.trans_feature_dim),
                torch.nn.LeakyReLU(0.1),
                torch.nn.Linear(self.trans_feature_dim, self.emb_feature_dim),
            ),
        })
        
        # Fully connect roi
        self.fc_rois = ModuleDict({
            'subj': torch.nn.Sequential(
                torch.nn.Linear(self.visual_feature_dim, self.visual_hidden_dim),
                #torch.nn.BatchNorm1d(self.visual_hidden_dim),
                torch.nn.LeakyReLU(0.1),
                torch.nn.Linear(self.visual_hidden_dim, self.trans_feature_dim),
                #torch.nn.BatchNorm1d(self.trans_feature_dim),
                torch.nn.LeakyReLU(0.1),
                torch.nn.Linear(self.trans_feature_dim, self.trans_feature_dim),
            ),
            'pred': torch.nn.Sequential(
                torch.nn.Linear(self.visual_feature_dim + self.spatial_hidden_dim, self.visual_hidden_dim),
                #torch.nn.BatchNorm1d(self.visual_hidden_dim),
                torch.nn.LeakyReLU(0.1),
                torch.nn.Linear(self.visual_hidden_dim, self.trans_feature_dim),
                #torch.nn.BatchNorm1d(self.trans_feature_dim),
                torch.nn.LeakyReLU(0.1),
                torch.nn.Linear(self.trans_feature_dim, self.trans_feature_dim),
            ),
            'obj': torch.nn.Sequential(
                torch.nn.Linear(self.visual_feature_dim, self.visual_hidden_dim),
                #torch.nn.BatchNorm1d(self.visual_hidden_dim),
                torch.nn.LeakyReLU(0.1),
                torch.nn.Linear(self.visual_hidden_dim, self.trans_feature_dim),
                #torch.nn.BatchNorm1d(self.trans_feature_dim),
                torch.nn.LeakyReLU(0.1),
                torch.nn.Linear(self.trans_feature_dim, self.trans_feature_dim),
            ),
        })
        
        self.fc_rois2 = ModuleDict({
            'subj': torch.nn.Sequential(
                torch.nn.Linear(self.trans_feature_dim, self.emb_feature_dim),
                #torch.nn.BatchNorm1d(self.emb_feature_dim),
                torch.nn.LeakyReLU(),
                torch.nn.Linear(self.emb_feature_dim, self.emb_feature_dim),
            ),
            'pred': torch.nn.Sequential(
                torch.nn.Linear(self.trans_feature_dim, self.emb_feature_dim),
                #torch.nn.BatchNorm1d(self.emb_feature_dim),
                torch.nn.LeakyReLU(),
                torch.nn.Linear(self.emb_feature_dim, self.emb_feature_dim),
            ),
            'obj': torch.nn.Sequential(
                torch.nn.Linear(self.trans_feature_dim, self.emb_feature_dim),
                #torch.nn.BatchNorm1d(self.emb_feature_dim),
                torch.nn.LeakyReLU(),
                torch.nn.Linear(self.emb_feature_dim, self.emb_feature_dim),
            ),
        })
        
        # Triplet Loss (Cosine Distance)
        self.triplet_loss = ModuleDict({
            'subj': nn.TripletMarginWithDistanceLoss(distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y), margin=0.2),
            'pred': nn.TripletMarginWithDistanceLoss(distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y), margin=0.2),
            'obj': nn.TripletMarginWithDistanceLoss(distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y), margin=0.2),
        })

    def _freeze_parameters(self, cfg):
        freeze_detectron = True
        if freeze_detectron:
            for param in self.detectron.parameters():
                param.requires_grad = False

            for param in self.backbone.parameters():
                param.requires_grad = False
                    
    def _load_detectron_chkpoints(self, cfg):
        """
            Extension of __init__ for modules
        """
        # Load Detectron2 Pre-Trained Weights
        if cfg.VRD_RESNETS101_PRETRAINED_WEIGHTS is not None:
            DetectionCheckpointer(self.detectron).load(
                os.path.join(cfg.OUTPUT_DIR, 'model_final.pth')
            )
    
    def _load_words_chkpoints(self, cfg):
        """
            Extension of __init__ for modules
        """
        return get_triples_features(cfg.DATASETS.TRAIN[0].split('_')[0])

    
    def _save_words_chkpoints(self, cfg):
        torch.save(self.token_embeddings, self.triples_embeddings_path)
    
    def _get_bert_features(self, triples):
        """
        Args:
            triples: (Subj, Pred, Obj)
        Return:
            dict of [CLS, Subj, Pred, Obj, SEP] embeddings
        """
        results = {}

        # Load pre-trained model tokenizer (vocabulary)
        marked_text = "[CLS] " + " ".join(triples) + " [SEP]"
        tokenized_text = self.tokenizer.tokenize(marked_text)

        # Save the token split to average them later on
        token_placements = defaultdict(list)
        triples_temp = list(triples)
        for i, tok in enumerate(tokenized_text):
            stip_tok = tok.replace('#', '')
            if (stip_tok in triples_temp[0]):
                token_placements['subj'].append(i)
                triples_temp[0] = triples_temp[0].replace(stip_tok, '')
            elif (stip_tok in triples_temp[1]):
                token_placements['pred'].append(i)
                triples_temp[1] = triples_temp[1].replace(stip_tok, '')
            elif (stip_tok in triples_temp[2]):
                token_placements['obj'].append(i)
                triples_temp[2] = triples_temp[2].replace(stip_tok, '')
            elif (not tok == '[CLS]' and not tok == '[SEP]'):
                print(tok, triples)

        indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text)
        segments_ids = [1] * len(tokenized_text)  # one sentence

        # Convert inputs to PyTorch tensors
        tokens_tensor = torch.tensor([indexed_tokens]).to('cuda')
        segments_tensors = torch.tensor([segments_ids]).to('cuda')

        # Put the model in "evaluation" mode, meaning feed-forward operation.
        with torch.no_grad():
            outputs = self.bert_model(tokens_tensor, segments_tensors)

            # Evaluating the model will return a different number of objects based on
            # how it's  configured in the `from_pretrained` call earlier. In this case,
            # becase we set `output_hidden_states = True`, the third item will be the
            # hidden states from all layers. See the documentation for more details:
            # https://huggingface.co/transformers/model_doc/bert.html#bertmodel
            hidden_states = outputs[2]

        token_embeddings = torch.stack(hidden_states, dim=0)
        token_embeddings.size()

        # remove dimension 1
        token_embeddings = torch.squeeze(token_embeddings, dim=1)
        token_embeddings = token_embeddings.permute(1, 0, 2)

        # get token embeddings (list of token embeddings)
        token_vecs_cat = []
        for token in token_embeddings:
            cat_vec = torch.cat((token[-1], token[-2], token[-3], token[-4]), dim=0)
            token_vecs_cat.append(cat_vec)
        results['CLS'] = token_vecs_cat[0]
        results['SEP'] = token_vecs_cat[-1]

        # average the token embeddings for word that are splitted to get word embeddings
        for key, val in token_placements.items():
            results[key] = token_vecs_cat[val[0]]
            for i in range(1, len(val)):
                results[key] += token_vecs_cat[val[i]]
            results[key] = results[key] / len(val)

        return results
    
    def _get_word_features(self, triples):
        """
            Args:
                triples: ([(subj, pred, obj)]) list of triple
            Return:
                resulting embeddings for subjs, preds, and objs
        """
        subj_embeddings = []
        pred_embeddings = []
        obj_embeddings = []
        
        for subj, pred, obj in triples:
            subj_cls_label = self.object_classes[subj]
            pred_cls_label = self.predicate_classes[pred]
            obj_cls_label = self.object_classes[obj]
            triples_text = '-'.join((subj_cls_label, pred_cls_label, obj_cls_label))
            query_embeddings = None
            if (triples_text not in token_embeddings):
                query_embeddings = self._get_bert_features((subj_cls_label, pred_cls_label, obj_cls_label))
                token_embeddings[triples_text] = query_embeddings
            else:
                query_embeddings = token_embeddings[triples_text]
            subj_embeddings.append(query_embeddings['subj'].to('cuda'))
            pred_embeddings.append(query_embeddings['pred'].to('cuda'))
            obj_embeddings.append(query_embeddings['obj'].to('cuda'))
        
        subj_embeddings = torch.stack(subj_embeddings).to('cuda')
        pred_embeddings = torch.stack(pred_embeddings).to('cuda')
        obj_embeddings = torch.stack(obj_embeddings).to('cuda')
        
        return subj_embeddings, pred_embeddings, obj_embeddings
    
    def _get_word_predicate_features(self, subj, obj):
        """
            Input:
                subj: subject index
                obj: object index
            Output:
                A stack pred_embeddings for the above two subj and obj
        """
        subj_cls_label = self.object_classes[subj]
        obj_cls_label = self.object_classes[obj]

        #predicate label
        pred_embeddings = []
        subj_embeddings = []
        obj_embeddings = []
        for pred, pred_cls_label in enumerate(self.predicate_classes):
            triples_text = '-'.join((subj_cls_label, pred_cls_label, obj_cls_label))
            query_embeddings = None
            if (triples_text not in token_embeddings):
                query_embeddings = self._get_bert_features((subj_cls_label, pred_cls_label, obj_cls_label))
                token_embeddings[triples_text] = query_embeddings
            else:
                query_embeddings = token_embeddings[triples_text]
            subj_embeddings.append(query_embeddings['subj'].to('cuda'))
            pred_embeddings.append(query_embeddings['pred'].to('cuda'))
            obj_embeddings.append(query_embeddings['obj'].to('cuda'))
            
        subj_embeddings = torch.stack(subj_embeddings).to('cuda')
        pred_embeddings = torch.stack(pred_embeddings).to('cuda')
        obj_embeddings = torch.stack(obj_embeddings).to('cuda')

        return subj_embeddings, pred_embeddings, obj_embeddings
    
    def get_instances_prediction(self, data):
        """
            This function gets the predicted instances from the object detector, 
            and is only relevant to relationship detection evaluation.
            
            Input:
                data - data format for detectron2
            Output:
                instances - Instances object by the detectron2 that contain predicted instances
        """
        instances = []
        
        images =  torch.unsqueeze(data['image'], axis=0).cuda().float()
        images_list = ImageList(images, [(800, 800)])
        self.detectron.eval()
        features = self.detectron.backbone(images)
        proposals, _ = self.detectron.proposal_generator(images_list, features)
        instances, _ = self.detectron.roi_heads(images, features, proposals)
        
        return instances
    
    def enumerate_relationships_from_instances(self, instances):
        """
            This function convert the format of instancesinto all possible combinations
            of relationships in the detectron2 format
            
            Input:
                instances - Instances object by detectron2
            Output:
                relationships - a dictionary of relationships in the vrdtransr input format
        """
        enumerated_relationships = {
            'subj_bboxes': [],
            'obj_bboxes': [],
            'union_bboxes': [],
            'subj_classes': [],
            'obj_classes': [],
            'subj_scores': [],
            'obj_scores': []
        }
        
        # TO DO: implement to support batch images instead of a single image
        instance_dict = instances[0].get_fields()
        
        pred_bboxes = instance_dict['pred_boxes']
        pred_cls = instance_dict['pred_classes']
        pred_scores = instance_dict['scores']
        
        for i in range(len(pred_cls)):
            for j in range(i + 1, len(pred_cls)):
                obj1_bbox = pred_bboxes[i].tensor.cpu().detach().numpy().astype(int)
                obj2_bbox = pred_bboxes[j].tensor.cpu().detach().numpy().astype(int)
                union_bbox = boxes_union(copy.deepcopy(obj1_bbox), copy.deepcopy(obj2_bbox))[0]
                
                obj1_label = pred_cls[i]
                obj2_label = pred_cls[j]
                
                obj1_score = pred_scores[i]
                obj2_score = pred_scores[j]
                
                enumerated_relationships['subj_bboxes'].append(obj1_bbox)
                enumerated_relationships['obj_bboxes'].append(obj2_bbox)
                enumerated_relationships['union_bboxes'].append(union_bbox)
                enumerated_relationships['subj_classes'].append(obj1_label)
                enumerated_relationships['obj_classes'].append(obj2_label)
                enumerated_relationships['subj_scores'].append(obj1_score)
                enumerated_relationships['obj_scores'].append(obj2_score)
                
                
                enumerated_relationships['subj_bboxes'].append(obj2_bbox)
                enumerated_relationships['obj_bboxes'].append(obj1_bbox)
                enumerated_relationships['union_bboxes'].append(union_bbox)
                enumerated_relationships['subj_classes'].append(obj2_label)
                enumerated_relationships['obj_classes'].append(obj1_label)
                enumerated_relationships['subj_scores'].append(obj2_score)
                enumerated_relationships['obj_scores'].append(obj1_score)
                
        return enumerated_relationships
    
    def get_predicted_relationships(self, data):
        """
            This function gets the predicted relationships from the object detector, 
            and is only relevant to relationship detection evaluation.
            
            Input:
                data - data format for detectron2
            Output:
                relationships - Instances object by the detectron2 that contain predicted instances
        """
        #get predicted objects in the given image
        instances = self.get_instances_prediction(data)
        
        #enumerate the relationships with the predicted instances (bounding boxes and labels)
        relationships = self.enumerate_relationships_from_instances(instances)
        
        return relationships

    
    def get_predicate_distances(self, data, is_rel_eval=False):
        """
            Predict model's prediction based on the given data.
            Return the prediction predicate, visual relationship, and phrase (to be implemented)
            Input:
                data: vrdtranse input format
                is_rel_eval: boolean whether the evaluation is predicate detection or relationship detection
        """
        relationships = data["relationships"].copy()
        all_predicate_distances = [] # for each (subj, obj) pair, we get a set of distances
        all_subject_distances = [] #distance between subject visual and language
        all_object_distances = []
        all_predicate_subtract_distances = []
        all_transe_visual_feature = []
        all_transe_language_feature = []
        all_transe_distance_feature = []
        
        #forward features for gt_visual and gt_text
        fc_features = self.forward(data, None, get_fc_features=True)
        
        rel_cnt = 0
        with torch.no_grad():
            for subj, obj in zip(relationships['subj_classes'], relationships['obj_classes']):
                #compute all potential predicate embeddings for the (subj, obj) pair
                all_subj_embeddings, all_predicate_embeddings, all_object_embeddings = self._get_word_predicate_features(subj=subj, obj=obj)

                #languge features
                fc_all_subject_embeddings = self.fc_word["subj"](all_subj_embeddings)
                fc_all_predicate_embeddings = self.fc_word["pred"](all_predicate_embeddings)
                fc_all_object_embeddings = self.fc_word["obj"](all_object_embeddings)
                fc_pred_transe_language_feature = fc_all_subject_embeddings + fc_all_predicate_embeddings - fc_all_object_embeddings

                #visual features
                fc_pred_visual_feature = fc_features["visual"]["pred"][rel_cnt,:]
                fc_pred_subtract_visual_feature = fc_features["visual"]["obj"][rel_cnt,:] - fc_features["visual"]["subj"][rel_cnt,:]
                fc_pred_transe_visual_feature = fc_features["visual"]["subj"][rel_cnt,:] + fc_features["visual"]["pred"][rel_cnt,:] - fc_features["visual"]["obj"][rel_cnt,:]
                fc_subj_visual_feature = fc_features["visual"]["subj"][rel_cnt,:]
                fc_obj_visual_feature = fc_features["visual"]["obj"][rel_cnt,:]
                
                #compute distance between the fc_features["visual"]["pred"] and fc_predicate_embeddings to get top n
                pdist = lambda x, y: 1.0 - F.cosine_similarity(x, y)
                distance = []
                distance_subject = []
                distance_object = []
                distance_subtract = []
                distance_transe = []
                
                for subj_emb, pred_emb, obj_emb, pred_transe_emb in zip(fc_all_subject_embeddings, fc_all_predicate_embeddings, fc_all_object_embeddings, fc_pred_transe_language_feature):
                    distance.append(pdist(torch.unsqueeze(fc_pred_visual_feature, dim=0), torch.unsqueeze(pred_emb, dim=0)))
                    distance_subtract.append(pdist(torch.unsqueeze(fc_pred_subtract_visual_feature, dim=0), torch.unsqueeze(obj_emb - subj_emb, dim=0)))
                    distance_transe.append(pdist(torch.unsqueeze(fc_pred_transe_visual_feature, dim=0), torch.unsqueeze(pred_transe_emb, dim=0)))
                    distance_subject.append(pdist(torch.unsqueeze(fc_subj_visual_feature, dim=0), torch.unsqueeze(subj_emb, dim=0)))
                    distance_object.append(pdist(torch.unsqueeze(fc_obj_visual_feature, dim=0), torch.unsqueeze(obj_emb, dim=0)))
                     
                    
                # add set of distances to the given relationship
                all_predicate_distances.append(distance)
                all_subject_distances.append(distance_subject)
                all_object_distances.append(distance_object)
                all_predicate_subtract_distances.append(distance_subtract)
                all_transe_visual_feature.append(fc_pred_transe_visual_feature)
                all_transe_language_feature.append(fc_pred_transe_language_feature)
                all_transe_distance_feature.append(distance_transe)
                rel_cnt += 1
        
        if (is_rel_eval):
            return all_predicate_distances, all_subject_distances, all_object_distances, all_predicate_subtract_distances, all_transe_distance_feature
                
        return all_predicate_distances, all_predicate_subtract_distances, all_transe_distance_feature
    
    def get_triples_distance(self, data):
        """
            POINTER 3 [currently redundant]
            Predicts the distance between subjects, predicates, objects, and subtracted distance between subject and object
            
            Only relevant for visual relationship detection task (evaluation)
        """
        relationships = data["relationships"].copy()
        all_predicate_distances = []
        all_subject_distances = []
        all_object_distances = []
        all_predicate_subtract_distances = []
        
        rel_cnt = 0
        with torch.no_grad():
            for subj, obj in zip(relationships['subj_classes'], relationships['obj_classes']):
                #compute all potential predicate embeddings for the (subj, obj) pair
                all_subj_embeddings, all_predicate_embeddings, all_object_embeddings = self._get_word_predicate_features(subj=subj, obj=obj)

                #languge features
                fc_all_subject_embeddings = self.fc_word["subj"](all_subj_embeddings)
                fc_all_predicate_embeddings = self.fc_word["pred"](all_predicate_embeddings)
                fc_all_object_embeddings = self.fc_word["obj"](all_object_embeddings)
                fc_pred_transe_language_feature = fc_all_subject_embeddings - fc_all_object_embeddings

                #visual features
                fc_pred_visual_feature = fc_features["visual"]["pred"][rel_cnt,:]
                fc_pred_subtract_visual_feature = fc_features["visual"]["obj"][rel_cnt,:] - fc_features["visual"]["subj"][rel_cnt,:]
                fc_pred_transe_visual_feature = fc_features["visual"]["subj"][rel_cnt,:] + fc_features["visual"]["pred"][rel_cnt,:] - fc_features["visual"]["obj"][rel_cnt,:]
                
                #compute distance between the fc_features["visual"]["pred"] and fc_predicate_embeddings to get top n
                pdist = lambda x, y: 1.0 - F.cosine_similarity(x, y)
                distance = []
                distance_subtract = []
                distance_transe = []
                
                for subj_emb, pred_emb, obj_emb, pred_transe_emb in zip(fc_all_subject_embeddings, fc_all_predicate_embeddings, fc_all_object_embeddings, fc_pred_transe_language_feature):
                    distance.append(pdist(torch.unsqueeze(fc_pred_visual_feature, dim=0), torch.unsqueeze(pred_emb, dim=0)))
                    distance_subtract.append(pdist(torch.unsqueeze(fc_pred_subtract_visual_feature, dim=0), torch.unsqueeze(obj_emb - subj_emb, dim=0)))
                    distance_transe.append(pdist(torch.unsqueeze(fc_pred_transe_visual_feature, dim=0), torch.unsqueeze(pred_transe_emb, dim=0)))

                # add set of distances to the given relationship
                all_predicate_distances.append(distance)
                all_predicate_subtract_distances.append(distance_subtract)
                all_transe_visual_feature.append(fc_pred_transe_visual_feature)
                all_transe_language_feature.append(fc_pred_transe_language_feature)
                all_transe_distance_feature.append(distance_transe)
                rel_cnt += 1
        
        fc_features = self.forward(data, None, get_fc_features=True)

        
        return all_predicate_distances, all_subject_distances, all_object_distances, all_predicate_subtract_distances
    
    def _get_word_object_features(self, subj, pred):
        """
            Input:
                subj: subject index
                pred: predicate index
            Output:
                A stack pred_embeddings for the above two subj and obj
        """
        subj_cls_label = self.object_classes[subj]
        pred_cls_label = self.predicate_classes[pred]

        #predicate label
        pred_embeddings = []
        subj_embeddings = []
        obj_embeddings = []
        for obj, obj_cls_label in enumerate(self.object_classes):
            triples_text = '-'.join((subj_cls_label, pred_cls_label, obj_cls_label))
            query_embeddings = None
            if (triples_text not in token_embeddings):
                query_embeddings = self._get_bert_features((subj_cls_label, pred_cls_label, obj_cls_label))
                token_embeddings[triples_text] = query_embeddings
            else:
                query_embeddings = token_embeddings[triples_text]
            subj_embeddings.append(query_embeddings['subj'].to('cuda'))
            pred_embeddings.append(query_embeddings['pred'].to('cuda'))
            obj_embeddings.append(query_embeddings['obj'].to('cuda'))
            
        subj_embeddings = torch.stack(subj_embeddings).to('cuda')
        pred_embeddings = torch.stack(pred_embeddings).to('cuda')
        obj_embeddings = torch.stack(obj_embeddings).to('cuda')

        return subj_embeddings, pred_embeddings, obj_embeddings
    
    def get_object_distances(self, data):
        """
            Predict model's prediction based on the given data.
            Return the prediction predicate, phrase (to be implemented), and visual relationship (to be implemented)
        """
        relationships = data["relationships"]
        all_object_distances = [] # for each (subj, pred) pair, we get a set of distances
        all_object_add_distances = [] 

        #forward features for gt_visual and gt_text
        fc_features = self.forward(data, None, get_fc_features=True, obfuscate_object=True)

        rel_cnt = 0
        with torch.no_grad():
            for subj, pred in zip(relationships['subj_classes'], relationships['pred_classes']):
                #compute all potential predicate embeddings for the (subj, obj) pair
                all_subj_embeddings, all_predicate_embeddings, all_object_embeddings = self._get_word_object_features(subj=subj, pred=pred)

                #languge features
                fc_all_subject_embeddings = self.fc_word["subj"](all_subj_embeddings)
                fc_all_predicate_embeddings = self.fc_word["pred"](all_predicate_embeddings)
                fc_all_object_embeddings = self.fc_word["obj"](all_object_embeddings)

                #visual features
                fc_obj_add_visual_feature = fc_features["visual"]["subj"][rel_cnt,:] + fc_features["visual"]["pred"][rel_cnt,:]

                #compute distance between the fc_features["visual"]["pred"] and fc_predicate_embeddings to get top n
                pdist = lambda x, y: 1.0 - F.cosine_similarity(x, y)
                distance_add = []
                distance = []
                
                for subj_emb, pred_emb, obj_emb in zip(fc_all_subject_embeddings, fc_all_predicate_embeddings, fc_all_object_embeddings):
                    distance.append(pdist(torch.unsqueeze(fc_obj_add_visual_feature, dim=0), torch.unsqueeze(obj_emb, dim=0)))
                    distance_add.append(pdist(torch.unsqueeze(fc_obj_add_visual_feature, dim=0), torch.unsqueeze(subj_emb + pred_emb, dim=0)))
                    
                # add set of distances to the given relationship
                all_object_distances.append(distance)
                all_object_add_distances.append(distance_add)
                rel_cnt += 1

        return all_object_distances, all_object_add_distances                
    
    def _get_prediced_bboxes(self, data):
        
        data['height'] = 800
        data['width'] = 800

        self.detectron.eval()
        with torch.no_grad():
            outputs = self.detectron([data])
        
        return outputs
    
    def _get_roi_features(self, images, box_lists):
        """
            Get image features from the backbone network
            Input:
                images: (ImageList.from_tensors) with dimension (C,W,H)
                box_lists: A list of N boxes
            Return:
                features:[N, 7*7*256]
        """
        N = len(box_lists[0])
        
        cfg = self.cfg
        feature_maps = self.backbone(images)
        feature_maps = [feature_maps['p{}'.format(i)] for i in range(2,6)]
        regions_feature = self.pooler(feature_maps, box_lists)
        return regions_feature.reshape((N, self.visual_feature_dim))

    def get_unrelated_instance(self, bbox, cls, gt_tuple_boxes, gt_classes, memo, other_memo=None):
        negative_example = {}

        tuple_bbox = tuple(bbox)
        for i, neg_bbox in enumerate(gt_tuple_boxes):
            if (other_memo is not None and neg_bbox not in other_memo and neg_bbox not in memo):
                negative_example = {
                    'bbox': torch.from_numpy(np.asarray(neg_bbox)).float().cuda(), #convert to tensor float
                    'cls': gt_classes[i]
                }
                return negative_example
            elif (neg_bbox != tuple_bbox and neg_bbox not in memo):
                negative_example = {
                    'bbox': torch.from_numpy(np.asarray(neg_bbox)).float().cuda(), #convert to tensor float
                    'cls': gt_classes[i]
                }
                return negative_example

        return negative_example
    
    def generate_negative_examples(self, data, K=3):
        """
            for each triple relation in data, generate K negative examples

            return: [{
                'subj_bboxes': Boxes(tensor[[X,Y,X2,Y2],...])),
                'union_bboxes': Boxes(tensor(([[X,Y,X2,Y2],...])),
                'obj_bboxes': Boxes(tensor([[X,Y,X2,Y2],...])),
                'subj_classes': [cls_subj,...],
                'pred_classes': [cls_pred,...],
                'obj_classes': [cls_obj,...]
            }]
        """
        boxes = data['instances'].get_fields()['gt_boxes']
        gt_tuple_boxes = [tuple([ele.item() for ele in box]) for box in boxes] #convert ground truth boxes into tuples

        classes = data['instances'].get_fields()['gt_classes']
        gt_classes = [int(item) for item in classes]

        # shuffle to random select first K
        zip_gt_data = list(zip(gt_tuple_boxes, gt_classes))
        random.shuffle(zip_gt_data)
        gt_tuple_boxes, gt_classes = zip(*zip_gt_data)
        
        relationships = data['relationships']
        subj_boxes = relationships['subj_bboxes']
        union_boxes = relationships['union_bboxes']
        obj_boxes = relationships['obj_bboxes']
        subj_classes = relationships['subj_classes']
        pred_classes = relationships['pred_classes']
        obj_classes = relationships['obj_classes']
        
        #generate K negative examples
        neg_examples = []
        memo_subj = set()
        memo_obj = set()
        existed_predicates = dict(zip([tuple(item) for item in data['relationships']['union_bboxes']], data['relationships']['pred_classes']))
        
        # negative within the image
        for i in range(min(len(gt_tuple_boxes)-1, K)):
            neg_ex = defaultdict(list)
            if (len(memo_subj) == len(gt_tuple_boxes)-1 or len(memo_obj) == len(gt_tuple_boxes)-1):
                break
            
            try:
                for j in range(len(subj_boxes)): #iterate through the relationships in the image
                    #subj
                    subj_box = subj_boxes[j]
                    subj_cls = subj_classes[j]

                    #ISSUE: in the case where the number of object in the image is actually smaller than K, it is kind useless
                    unrelated_subj_instance = self.get_unrelated_instance(subj_box[0], subj_cls, gt_tuple_boxes, gt_classes, memo=memo_subj)
                    neg_ex['subj_bboxes'].append(unrelated_subj_instance['bbox'])
                    neg_ex['subj_classes'].append(unrelated_subj_instance['cls'])

                    #obj
                    obj_box = obj_boxes[j]
                    obj_cls = obj_classes[j]
                    other_memo = set()
                    other_memo.add(tuple(subj_box[0]))
                    unrelated_obj_instance = self.get_unrelated_instance(obj_box[0], obj_cls, gt_tuple_boxes, gt_classes, memo=memo_obj, other_memo=other_memo)
                    neg_ex['obj_bboxes'].append(unrelated_obj_instance['bbox'])
                    neg_ex['obj_classes'].append(unrelated_obj_instance['cls'])
                    
                    #pred
                    new_union_box = boxes_union(copy.deepcopy(unrelated_subj_instance['bbox'].reshape(1,4).to('cpu')),
                                                copy.deepcopy(unrelated_obj_instance['bbox'].reshape(1,4).to('cpu')))[0]
                    new_predicate_class = len(self.predicate_classes) - 1
                    if (tuple(new_union_box) in existed_predicates):
                        new_predicate_class = existed_predicates[tuple(new_union_box)]
                    neg_ex['union_bboxes'].append(torch.from_numpy(np.asarray(new_union_box)).float().cuda())
                    neg_ex['pred_classes'].append(new_predicate_class)

                for j in range(len(subj_boxes)):
                    memo_subj.add(tuple(neg_ex['subj_bboxes'][j].tolist()))
                    memo_obj.add(tuple(neg_ex['obj_bboxes'][j].tolist()))
            except:
                break

            #stack the bounding boxes
            neg_ex['subj_bboxes'] = Boxes(torch.stack(neg_ex['subj_bboxes']))
            neg_ex['obj_bboxes'] = Boxes(torch.stack(neg_ex['obj_bboxes']))
            neg_ex['union_bboxes'] = Boxes(torch.stack(neg_ex['union_bboxes']))
            
            #append to memory
            neg_examples.append(neg_ex)
            
        return neg_examples

    def get_spatial_features(self, relationships, is_negative=False):
        """
        Args:
            data: see data definition in forward function
        Return:
            spatial_features: a tensor of spatial features containing coordinates of the bounding box
        """
        def spatial_delta(entity1, entity2):
            """
                entity1, entity2: [X,Y,X,Y]
            """
            
            width1, height1 = entity1[2] - entity1[0], entity1[3] - entity1[1]
            width2, height2 = entity2[2] - entity2[0], entity2[3] - entity2[1]
            
            delta_feat = [
                (entity1[0] - entity2[0])/width2,
                (entity1[1] - entity2[1])/height2,
                np.log(width1/width2),
                np.log(height1/height2),
            ]
            return delta_feat
        
        def spatial_coordinates(entity):
            """
                entity: [X,Y,X,Y]
            """
            width, height = entity[2] - entity[0], entity[3] - entity[1]
            coordinate_feat = [
                entity[0]/800,
                entity[1]/800,
                entity[2]/800,
                entity[3]/800,
                width*height/800*800
            ]
            return coordinate_feat
        
        spatial_features = []
        #iterate through every relationship pair and construct an array of spatial feature    
        for subj_bbox, obj_bbox, union_bbox in zip(relationships['subj_bboxes'], relationships['obj_bboxes'], relationships['union_bboxes']):
            feat = []
            #XYXY
            if (is_negative):
                feat.extend(spatial_delta(subj_bbox.cpu(), obj_bbox.cpu()))
                feat.extend(spatial_delta(subj_bbox.cpu(), union_bbox.cpu()))
                feat.extend(spatial_delta(union_bbox.cpu(), obj_bbox.cpu()))
                feat.extend(spatial_coordinates(subj_bbox.cpu()))
                feat.extend(spatial_coordinates(obj_bbox.cpu()))            
            else:
                feat.extend(spatial_delta(subj_bbox[0], obj_bbox[0]))
                feat.extend(spatial_delta(subj_bbox[0], union_bbox))
                feat.extend(spatial_delta(union_bbox, obj_bbox[0]))
                feat.extend(spatial_coordinates(subj_bbox[0]))
                feat.extend(spatial_coordinates(obj_bbox[0]))
            
            spatial_features.append(torch.from_numpy(np.asarray(feat)).float().cuda())
        
        return torch.stack(spatial_features)
    
    def eval_phrase_detection(self, data):
        #add batch dim to image
        image = torch.unsqueeze(data['image'], axis=0).cuda()
        
        #get prediction bounding boxes
        output = self._get_prediced_bboxes(data)
        bboxes = output[0]['instances'].get_fields()['pred_boxes']
        bboxes_features = self._get_roi_features(image.float(), box_lists=[bboxes])
        
        bboxes_classes = output[0]['instances'].get_fields()['pred_classes']
        conf_score = output[0]['instances'].get_fields()['scores']
        #iterate through every pairs and compute score
        return output, bboxes_features
    
    def forward(self, data, negative_examples, get_fc_features=False, obfuscate_object=False, **kwargs):
        """
        Args:
            data: {
                    #Detectron
                    'file_name': os.path.join(self.images_dir, img_name),
                    'image_id': int(img_name.split('.')[0]),
                    'annotations': list(unique_objects.values()),

                    #Relationships
                    'relationships': {
                        'subj_bboxes': subj_bboxes,
                        'obj_bboxes': obj_bboxes,
                        'union_bboxes': union_bboxes,

                        'subj_classes': subj_classes,
                        'pred_classes': pred_classes,
                        'obj_classes': obj_classes,
                    }
                }
            context: #any external data (not implemented as of current)
        """
        image = torch.unsqueeze(data['image'], axis=0).cuda()

        relationships = data['relationships']
        
        subj_bboxes = Boxes([bbox[0] for bbox in relationships['subj_bboxes']]).to('cuda')
        if (obfuscate_object):
            union_bboxes = subj_bboxes
        else:
            union_bboxes = Boxes([bbox for bbox in relationships['union_bboxes']]).to('cuda')
        obj_bboxes = Boxes([bbox[0] for bbox in relationships['obj_bboxes']]).to('cuda')

        
        #ground_truth features
        gt_features = {"visual": {}, "textual": {}}

        #fully connected features
        fc_features = {"visual": {}, "textual": {}}


        #spatial features
        gt_spatial_features = self.fc_spatial(self.get_spatial_features(relationships))
        
        #extract visual features from backbone and ROIPool for n relations in the image
        gt_features["visual"]["subj"] = self._get_roi_features(image.float(), box_lists=[subj_bboxes])
        gt_features["visual"]["pred"] = self._get_roi_features(image.float(), box_lists=[union_bboxes])
        gt_features["visual"]["pred"] = torch.cat((gt_features["visual"]["pred"], gt_spatial_features), 1)
        gt_features["visual"]["obj"] = self._get_roi_features(image.float(), box_lists=[obj_bboxes])
        
        #fc visual (rois1 and rois2)
        fc_features["visual"]["subj"] = self.fc_rois["subj"](gt_features["visual"]["subj"])
        fc_features["visual"]["pred"] = self.fc_rois["pred"](gt_features["visual"]["pred"])
        fc_features["visual"]["obj"] = self.fc_rois["obj"](gt_features["visual"]["obj"])

        fc_features["visual"]["subj"] = self.fc_rois2["subj"](fc_features["visual"]["subj"])
        fc_features["visual"]["pred"] = self.fc_rois2["pred"](fc_features["visual"]["pred"])
        fc_features["visual"]["obj"] = self.fc_rois2["obj"](fc_features["visual"]["obj"])

        
        if (get_fc_features):
            return fc_features
    
        #extract word embeddings for n examples in the image
        word_embeddings = self._get_word_features(list(zip(relationships['subj_classes'],
                                                           relationships['pred_classes'],
                                                           relationships['obj_classes'])))
        gt_features["textual"]["subj"] = word_embeddings[0]
        gt_features["textual"]["pred"] = word_embeddings[1]
        gt_features["textual"]["obj"] = word_embeddings[2]


        #fc word
        fc_features["textual"]["subj"] = self.fc_word["subj"](gt_features["textual"]["subj"])
        fc_features["textual"]["pred"] = self.fc_word["pred"](gt_features["textual"]["pred"])
        fc_features["textual"]["obj"] = self.fc_word["obj"](gt_features["textual"]["obj"])

        #Visual and Language Consistency losses triplet_loss(anchor, positive, negative)
        triplet_losses = {"subj": None, "pred": None, "obj": None} #, "transr": None}
        
        #NEGATIVE EXAMPLES
        for neg_relationships in negative_examples:

            neg_subj_boxes = neg_relationships['subj_bboxes']
            neg_union_boxes = neg_relationships['union_bboxes']
            neg_obj_boxes = neg_relationships['obj_bboxes']

            #dictionary to store gt_features
            neg_features = {"visual": {}, "textual": {}}

            #spatial features
            neg_spatial_features = self.fc_spatial(self.get_spatial_features(neg_relationships, is_negative=True))

            #extract visual features from backbone and ROIPool for n relations in the image
            neg_features["visual"]["subj"] = self._get_roi_features(image.float(), box_lists=[neg_subj_boxes])
            neg_features["visual"]["pred"] = self._get_roi_features(image.float(), box_lists=[neg_union_boxes])
            neg_features["visual"]["pred"] = torch.cat((neg_features["visual"]["pred"], neg_spatial_features), 1)
            neg_features["visual"]["obj"] = self._get_roi_features(image.float(), box_lists=[neg_obj_boxes])

            #extract word embeddings for n examples in the image
            neg_word_embeddings = self._get_word_features(list(zip(neg_relationships['subj_classes'],
                                                               neg_relationships['pred_classes'],
                                                               neg_relationships['obj_classes'])))
            neg_features["textual"]["subj"] = neg_word_embeddings[0]
            neg_features["textual"]["pred"] = neg_word_embeddings[1]
            neg_features["textual"]["obj"] = neg_word_embeddings[2]

            #neg fc visual
            neg_fc_features = {"visual": {}, "textual": {}}
            neg_fc_features["visual"]["subj"] = self.fc_rois["subj"](neg_features["visual"]["subj"])
            neg_fc_features["visual"]["pred"] = self.fc_rois["pred"](neg_features["visual"]["pred"])
            neg_fc_features["visual"]["obj"] = self.fc_rois["obj"](neg_features["visual"]["obj"])

            neg_fc_features["visual"]["subj"] = self.fc_rois2["subj"](neg_fc_features["visual"]["subj"])
            neg_fc_features["visual"]["pred"] = self.fc_rois2["pred"](neg_fc_features["visual"]["pred"])
            neg_fc_features["visual"]["obj"] = self.fc_rois2["obj"](neg_fc_features["visual"]["obj"])
            
            #neg fc word
            neg_fc_features["textual"]["subj"] = self.fc_word["subj"](neg_features["textual"]["subj"])
            neg_fc_features["textual"]["pred"] = self.fc_word["pred"](neg_features["textual"]["pred"])
            neg_fc_features["textual"]["obj"] = self.fc_word["obj"](neg_features["textual"]["obj"])

            if (triplet_losses["subj"] is None):
                triplet_losses["subj"] = self.triplet_loss["subj"](fc_features["visual"]["subj"], fc_features["textual"]["subj"], neg_fc_features["visual"]["subj"]) + self.triplet_loss["subj"](fc_features["textual"]["subj"], fc_features["visual"]["subj"], neg_fc_features["textual"]["subj"])
                triplet_losses["pred"] = self.triplet_loss["pred"](fc_features["visual"]["pred"], fc_features["textual"]["pred"], neg_fc_features["visual"]["pred"]) + self.triplet_loss["pred"](fc_features["textual"]["pred"], fc_features["visual"]["pred"], neg_fc_features["textual"]["pred"])
                triplet_losses["obj"] = self.triplet_loss["obj"](fc_features["visual"]["obj"], fc_features["textual"]["obj"], neg_fc_features["visual"]["obj"]) + self.triplet_loss["obj"](fc_features["textual"]["obj"],fc_features["visual"]["obj"],neg_fc_features["textual"]["obj"])
                
                triplet_losses["transr"] = self.triplet_loss["pred"](fc_features["visual"]["subj"] + fc_features["visual"]["pred"], fc_features["visual"]["obj"], neg_fc_features["visual"]["subj"] + neg_fc_features["visual"]["pred"])
                triplet_losses["transr"] += self.triplet_loss["pred"](fc_features["textual"]["subj"] + fc_features["textual"]["pred"], fc_features["textual"]["obj"], neg_fc_features["textual"]["subj"] + neg_fc_features["textual"]["pred"])
                triplet_losses["transr"] += self.triplet_loss["pred"](fc_features["visual"]["subj"] + fc_features["visual"]["pred"], fc_features["visual"]["obj"], neg_fc_features["visual"]["obj"])
                triplet_losses["transr"] += self.triplet_loss["pred"](fc_features["textual"]["subj"] + fc_features["textual"]["pred"], fc_features["textual"]["obj"], neg_fc_features["textual"]["obj"])

            else:
                triplet_losses["subj"] += self.triplet_loss["subj"](fc_features["visual"]["subj"], fc_features["textual"]["subj"], neg_fc_features["visual"]["subj"]) + self.triplet_loss["subj"](fc_features["textual"]["subj"],fc_features["visual"]["subj"], neg_fc_features["textual"]["subj"])
                triplet_losses["pred"] += self.triplet_loss["pred"](fc_features["visual"]["pred"], fc_features["textual"]["pred"], neg_fc_features["visual"]["pred"]) + self.triplet_loss["pred"](fc_features["textual"]["pred"],fc_features["visual"]["pred"],neg_fc_features["textual"]["pred"])
                triplet_losses["obj"] += self.triplet_loss["obj"](fc_features["visual"]["obj"], fc_features["textual"]["obj"], neg_fc_features["visual"]["obj"]) + self.triplet_loss["obj"](fc_features["textual"]["obj"],fc_features["visual"]["obj"],neg_fc_features["textual"]["obj"])

                triplet_losses["transr"] += self.triplet_loss["pred"](fc_features["visual"]["subj"] + fc_features["visual"]["pred"], fc_features["visual"]["obj"], neg_fc_features["visual"]["subj"] + neg_fc_features["visual"]["pred"])
                triplet_losses["transr"] += self.triplet_loss["pred"](fc_features["textual"]["subj"] + fc_features["textual"]["pred"], fc_features["textual"]["obj"], neg_fc_features["textual"]["subj"] + neg_fc_features["textual"]["pred"])
                triplet_losses["transr"] += self.triplet_loss["pred"](fc_features["visual"]["subj"] + fc_features["visual"]["pred"], fc_features["visual"]["obj"], neg_fc_features["visual"]["obj"])
                triplet_losses["transr"] += self.triplet_loss["pred"](fc_features["textual"]["subj"] + fc_features["textual"]["pred"], fc_features["textual"]["obj"], neg_fc_features["textual"]["obj"])

        #get triples
        language_negatives = []
        for subj, pred, obj in zip(relationships['subj_classes'], relationships['pred_classes'], relationships['obj_classes']):
            subj_neg_classes = copy.deepcopy(self.object_classes)
            subj_neg_classes.remove(self.object_classes[subj])
            obj_neg_classes = copy.deepcopy(self.object_classes)
            obj_neg_classes.remove(self.object_classes[obj])
            
            negative_triples = [(self.object_classes.index(neg), pred, obj) for neg in subj_neg_classes]
            negative_triples.extend([(subj, pred, self.object_classes.index(neg)) for neg in obj_neg_classes])
            
            language_negatives.append(negative_triples)
 
        n_neg_lang = 50
        for i in random.sample(range(len(language_negatives[0])), n_neg_lang):
            neg_example = [item[i] for item in language_negatives]
            
            neg_word_embeddings = self._get_word_features(neg_example)
            neg_features_subj = neg_word_embeddings[0]
            neg_features_pred = neg_word_embeddings[1]
            neg_features_obj = neg_word_embeddings[2]
            
            neg_fc_features_subj = self.fc_word["subj"](neg_features_subj)
            neg_fc_features_pred = self.fc_word["pred"](neg_features_pred)
            neg_fc_features_obj = self.fc_word["obj"](neg_features_obj)

            triplet_losses["subj"] = self.triplet_loss["subj"](fc_features["textual"]["subj"], fc_features["textual"]["subj"], neg_fc_features_subj) + self.triplet_loss["subj"](fc_features["textual"]["subj"], fc_features["visual"]["subj"], neg_fc_features_subj)
            triplet_losses["pred"] = self.triplet_loss["pred"](fc_features["textual"]["pred"], fc_features["textual"]["pred"], neg_fc_features_pred) + self.triplet_loss["pred"](fc_features["textual"]["pred"], fc_features["visual"]["pred"], neg_fc_features_pred)
            triplet_losses["obj"] = self.triplet_loss["obj"](fc_features["textual"]["obj"], fc_features["textual"]["obj"], neg_fc_features_obj) + self.triplet_loss["obj"](fc_features["textual"]["obj"],fc_features["visual"]["obj"], neg_fc_features_obj)            
        #divide by the number of training examples
        triplet_losses["subj"] = triplet_losses["subj"] / (len(negative_examples) + n_neg_lang)
        triplet_losses["pred"] = triplet_losses["pred"] / (len(negative_examples) + n_neg_lang)
        triplet_losses["obj"] = triplet_losses["obj"] / (len(negative_examples) + n_neg_lang)      
        triplet_losses["transr"] =  triplet_losses["transr"] / len(negative_examples)
        return triplet_losses

In [14]:
import torch

def load_checkpoint(model, chkpoint_path, optimizer=None):
    chkpoint = torch.load(chkpoint_path)
    
    # incompatibility fixes
    
    model.load_state_dict(chkpoint['model'])
    if (optimizer is not None):
        optimizer.load_state_dict(chkpoint['optimizer'])
    return chkpoint['it'], chkpoint['epoch'], chkpoint['losses']

### [Optional] Test Instance Detector (Object detector)

In [11]:
from detectron2.data import (
    DatasetCatalog, DatasetMapper,
    build_detection_train_loader,
    build_detection_test_loader,
)
from config import get_vrd_cfg
import detectron2.data.transforms as T

cfg = get_vrd_cfg()
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = .6
cfg.DATASETS.TEST = ("vrd_val", )

#test dataset
test_dataset = DatasetCatalog.get("vrd_val")
test_dataloader = build_detection_test_loader(dataset=test_dataset,
    mapper=DatasetMapper(cfg, is_train=True, augmentations=[
        T.Resize((800, 800))
    ])
)


#train dataset
train_dataset = DatasetCatalog.get("vrd_train")
train_dataloader = build_detection_test_loader(dataset=train_dataset,
    mapper=DatasetMapper(cfg, is_train=True, augmentations=[
        T.Resize((800, 800))
    ])
)

In [12]:
model = RelTransR(cfg)
device = torch.device("cuda")
model.to(device)

#Run only once
model.eval()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


RelTransR(
  (fc_spatial): Sequential(
    (0): Linear(in_features=22, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
  )
  (detectron): GeneralizedRCNN(
    (backbone): FPN(
      (fpn_lateral2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
      (fpn_output2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (fpn_lateral3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
      (fpn_output3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (fpn_lateral4): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
      (fpn_output4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (fpn_lateral5): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
      (fpn_output5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (top_block): LastLevelMaxPool()
      (bottom_up): ResNet(
        (stem): BasicStem(
          (conv1): Con

In [15]:
chkpoint_path = '../checkpoint/'
model_name = 'vltranse_langcon_model_8000.pt'
chkpoint_full_path = os.path.join(chkpoint_path, model_name)
it, start_epoch, losses = load_checkpoint(model, chkpoint_full_path)

In [16]:
from collections import defaultdict
from detectron2.evaluation import COCOEvaluator, DatasetEvaluator, DatasetEvaluators, inference_on_dataset

object_classes = model.object_classes

class Counter(DatasetEvaluator):
    def reset(self):
        self.count = 0
        self.gt_count = 0
        self.object_areas = defaultdict(list)
        
    def process(self, inputs, outputs):
        for output in outputs:
            self.count += len(output["instances"])
        
        for inp in inputs:
            self.gt_count += len(inp["instances"])

            annotations = inp['instances'].get_fields()    
            for box, cls in zip(annotations['gt_boxes'], annotations['gt_classes']):
                object_width = box[2] - box[0] + 1
                object_height = box[3] - box[1] + 1
                object_area = object_width * object_height
                self.object_areas['area-{}'.format(object_classes[cls])].append(object_area)
                
    def evaluate(self):
        # save self.count somewhere, or print it, or return it.
        resulting_dict = {"count": self.count, "gt_count": self.gt_count, "object_average_area": {}}
        
        for key, areas in self.object_areas.items():
            resulting_dict["object_average_area"][key] = sum(areas) / len(areas)
        
        return resulting_dict
    
def beatify_detectron2_results(eval_results):
    """
        Beautify the results output by detectron2
    """
    object_areas = eval_results['object_average_area']
    object_area_ap = {}
    
    for eval_method, eval_result in eval_results.items():
        if (eval_method == 'count'):
            print("Total Objects Detected:", eval_result)
        elif (eval_method == 'gt_count'):
            print("Total Labeled Objects:", eval_result)
        elif (eval_method == 'object_average_area'):
            continue
        else:
            print("Evaluation results for {}".format(eval_method))

            resulting_string = ""
            for i, (key, res) in enumerate(eval_result.items()):
                resulting_string = "".join((resulting_string, "|   {:>16}\t->\t{:5.2f}".format(key, res)))
                if ((i + 1) <= 6):
                    resulting_string = "".join((resulting_string, "   |"))
                if ((i + 1) == 6):
                    resulting_string = "".join((resulting_string, "\nEvaluation results by object category\n"))
                elif ((i + 1) > 6):
                    object_cls = key.split('-')[1]
                    area_key = 'area-{}'.format(object_cls)
                    object_area = object_areas[area_key]
                    
                    resulting_string = "".join((resulting_string, "( {:5.2f} area )\t|".format(object_area)))
                    object_area_ap[object_cls] = (res, object_area)
                    
                if ((i + 1) % 2 == 0):
                    resulting_string = "".join((resulting_string, "\n"))
            print(resulting_string)
            
    return object_area_ap

##### Object Detector Evaluation for Test Dataset

In [None]:
eval_results = inference_on_dataset(
    model.detectron,
    test_dataloader,
    DatasetEvaluators([COCOEvaluator('vrd_val', output_dir="../generated/coco_evaluations_val"), Counter()]))


In [None]:
object_area_ap = beatify_detectron2_results(eval_results)

In [None]:
import seaborn as sns
import pandas as pd

test_object_data = [ (key, item[0], item[1]) for key, item in object_area_ap.items()]

area_ap_df = pd.DataFrame(pd.DataFrame(list(test_object_data),
               columns =['Object', 'AP', 'Area']))

sns.scatterplot(data=area_ap_df, x="Area", y="AP")


In [None]:
rank_ap = np.argsort([item[1] for item in test_object_data])
ranked_object = np.array([item[0] for item in test_object_data])[rank_ap]
print("Worst Performing Objects:", ranked_object[:20])

##### Object detector performance for Train Dataset

In [None]:
eval_results = inference_on_dataset(
    model.detectron,
    train_dataloader,
    DatasetEvaluators([COCOEvaluator('vrd_train', output_dir="../generated/coco_evaluations_train/"), Counter()]))

In [None]:
object_area_ap = beatify_detectron2_results(eval_results)

In [None]:
import seaborn as sns
import pandas as pd

train_object_data = [ (key, item[0], item[1]) for key, item in object_area_ap.items()]

area_ap_df = pd.DataFrame(pd.DataFrame(list(train_object_data),
               columns =['Object', 'AP', 'Area']))

sns.scatterplot(data=area_ap_df, x="Area", y="AP")


In [None]:
rank_ap = np.argsort([item[1] for item in train_object_data])
ranked_object = np.array([item[0] for item in train_object_data])[rank_ap]
print("Worst Performing Objects:", ranked_object[:20])

### [Optional] Test

In [16]:
from detectron2.data import (
    DatasetCatalog, DatasetMapper,
    build_detection_train_loader,
    build_detection_test_loader,    
)
from config import get_vrd_cfg
import detectron2.data.transforms as T

cfg = get_vrd_cfg()

# DEPRECARTED: Old Dataloader Code
# cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = .6
# cfg.DATASETS.TRAIN = ("vrd_val", )
# test_dataloader = build_detection_train_loader(cfg,
#     mapper=DatasetMapper(cfg, is_train=True, augmentations=[
#         T.Resize((800, 800))
#     ])
# )

cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = .6

In [None]:
model = RelTransR(cfg)
device = torch.device("cuda")
model.to(device)

#Run only once
model.eval()

#Load Model
chkpoint_path = '../checkpoint/'
model_name = 'vrd2_model_transr_23000.pt'
chkpoint_full_path = os.path.join(chkpoint_path, model_name)
it, start_epoch, losses = load_checkpoint(model, chkpoint_full_path)

In [27]:
# Get triples that are in the training set
import os, json

trained_triples_path = '../generated/trained_triples.json'
trained_triples = {}
if (os.path.exists(trained_triples_path)):
    with open(trained_triples_path, 'r') as file:
        trained_triples = json.load(file)
else:
    iter_dataloader = iter(train_dataloader)
    n_iters = len(train_dataloader.dataset.dataset)
    for i in range(n_iters):
        print(i)
        data = next(iter_dataloader)[0]
        relationships = data['relationships']
        for j in range(len(relationships['subj_classes'])):
            subj_cls = model.object_classes[relationships['subj_classes'][j]]
            pred_cls = model.predicate_classes[relationships['pred_classes'][j]]
            obj_cls = model.object_classes[relationships['obj_classes'][j]]
            trained_triples['{}-{}-{}'.format(subj_cls, pred_cls, obj_cls)] = 1

    with open(trained_triples_path, 'w') as file:
        file.write(json.dumps(trained_triples))

#### EVAL 1/ RELATIONSHIP PREDICTION TASK (NO GROUND TRUTH LABELS)

##### Test Data Evaluation

In [22]:
#!/usr/bin/env python
# coding: utf-8


import os
import json
import torch
import wandb
import time


from torch.optim.lr_scheduler import StepLR
from utils.register_dataset import register_vrd_dataset
from config import GENERATED_DIR, get_vrd_cfg, CHECKPOINT_DIR

from modeling.word_features import get_triples_features, get_trained_triples_memo

from utils.annotations import get_object_classes, get_predicate_classes
from utils.trainer import load_checkpoint
from utils.eval_helpers import eval_dataset, eval_config

from detectron2.data import (
    DatasetCatalog,
    DatasetMapper,
    build_detection_train_loader,
    build_detection_test_loader,
)
import detectron2.data.transforms as T


def load_model(cfg, checkpoint_model_name):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = RelTransR(cfg)
    model.to(device)

    chkpoint_full_path = os.path.join(CHECKPOINT_DIR, checkpoint_model_name)
    _, _, _ = load_checkpoint(model, chkpoint_full_path)

    return model


def test_model(checkpoint_model_name, dataset_name="vrd_val"):
    cfg = get_vrd_cfg()
    cfg.DATASETS.TEST = ("dataset_name",)
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.4

    # [Only Run once] Register dataset with detectron2 instead of using my own dataloader
#     register_vrd_dataset(dataset_name.split("_")[0])

    # Load model
    model = load_model(cfg, checkpoint_model_name)
    # Set model to evaluation mode
    model.eval()

    # Parallel
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # For zeroshot
    trained_triples = get_trained_triples_memo()

    # Dataset and DataLoader
    # test_dataset = DatasetCatalog.get(dataset_name)
    # test_dataloader = build_detection_test_loader(
    #     dataset=test_dataset,
    #     mapper=DatasetMapper(cfg, is_train=True, augmentations=[T.Resize((800, 800))]),
    # )

    # Iterate and Test
    recall_results = {}
    i = 0
    for is_zeroshot in [False, True]:
        for nre in [50, 100]:
            for conf in eval_config:
                if i == 0:
                    i += 1
                    continue
                test_dataset = DatasetCatalog.get(dataset_name)
                test_dataloader = build_detection_test_loader(
                    dataset=test_dataset,
                    mapper=DatasetMapper(
                        cfg, is_train=True, augmentations=[T.Resize((800, 800))]
                    ),
                )

                if is_zeroshot:
                    print(("is_zeroshot", nre, conf))
                    recall_results[("is_zeroshot", nre, conf)] = eval_dataset(
                        test_dataloader,
                        model,
                        nre=nre,
                        config=conf,
                        trained_triples=trained_triples,
                    )
                else:
                    print(("not_zeroshot", nre, conf))
                    recall_results[("not_zeroshot", nre, conf)] = eval_dataset(
                        test_dataloader,
                        model,
                        nre=nre,
                        config=conf,
                        trained_triples=None,
                    )

                test_dataloader = None
                del test_dataloader
                del test_dataset
                del model

                model = load_model(cfg, checkpoint_model_name)
                # Set model to evaluation mode
                model.eval()

    recall_results_path = f"{GENERATED_DIR}/{checkpoint_model_name.split('.')[0]}.json"

    with open(recall_results_path, "w") as file:
        file.write(json.dumps(recall_results))


test_model("vltranse_langcon_model_8000.pt")


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


('not_zeroshot', 50, 'triple_subtract_dist')


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
  result = getattr(asarray(obj), method)(*args, **kwds)
  result = getattr(asarray(obj), method)(*args, **kwds)
  sorted_val = np.array(val)[sorted_by_distance]
  sorted_val = np.array(val)[sorted_by_distance]


|----------------------Iter 20------------------------|
| TOP 50 |               Recall  3.54               |
|----------------------Iter 40------------------------|
| TOP 50 |               Recall  4.35               |
|----------------------Iter 60------------------------|
| TOP 50 |               Recall  6.10               |
|----------------------Iter 100------------------------|
| TOP 50 |               Recall  4.50               |
|----------------------Iter 140------------------------|
| TOP 50 |               Recall  4.97               |
|----------------------Iter 160------------------------|
| TOP 50 |               Recall  5.00               |
|----------------------Iter 180------------------------|
| TOP 50 |               Recall  5.05               |
|----------------------Iter 200------------------------|
| TOP 50 |               Recall  5.55               |
|----------------------Iter 220------------------------|
| TOP 50 |               Recall  5.59               |
|---

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


('not_zeroshot', 50, 'triple_transe_dist')
|----------------------Iter 20------------------------|
| TOP 50 |               Recall  3.54               |
|----------------------Iter 40------------------------|
| TOP 50 |               Recall  4.78               |
|----------------------Iter 60------------------------|
| TOP 50 |               Recall  6.98               |
|----------------------Iter 100------------------------|
| TOP 50 |               Recall  5.11               |
|----------------------Iter 140------------------------|
| TOP 50 |               Recall  5.28               |
|----------------------Iter 160------------------------|
| TOP 50 |               Recall  5.45               |
|----------------------Iter 180------------------------|
| TOP 50 |               Recall  5.36               |
|----------------------Iter 200------------------------|
| TOP 50 |               Recall  5.76               |
|----------------------Iter 220------------------------|
| TOP 50 |     

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


('not_zeroshot', 50, 'pred_dist')
|----------------------Iter 20------------------------|
| TOP 50 |               Recall  6.19               |
|----------------------Iter 40------------------------|
| TOP 50 |               Recall  5.65               |
|----------------------Iter 60------------------------|
| TOP 50 |               Recall  7.27               |
|----------------------Iter 100------------------------|
| TOP 50 |               Recall  5.26               |
|----------------------Iter 140------------------------|
| TOP 50 |               Recall  4.76               |
|----------------------Iter 160------------------------|
| TOP 50 |               Recall  5.09               |
|----------------------Iter 180------------------------|
| TOP 50 |               Recall  4.97               |
|----------------------Iter 200------------------------|
| TOP 50 |               Recall  5.62               |
|----------------------Iter 220------------------------|
| TOP 50 |              

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


('not_zeroshot', 50, 'pred_subtract_dist')
|----------------------Iter 20------------------------|
| TOP 50 |               Recall  7.08               |
|----------------------Iter 40------------------------|
| TOP 50 |               Recall  6.52               |
|----------------------Iter 60------------------------|
| TOP 50 |               Recall  7.27               |
|----------------------Iter 100------------------------|
| TOP 50 |               Recall  4.95               |
|----------------------Iter 140------------------------|
| TOP 50 |               Recall  5.07               |
|----------------------Iter 160------------------------|
| TOP 50 |               Recall  5.36               |
|----------------------Iter 180------------------------|
| TOP 50 |               Recall  5.44               |
|----------------------Iter 200------------------------|
| TOP 50 |               Recall  6.20               |
|----------------------Iter 220------------------------|
| TOP 50 |     

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


('not_zeroshot', 50, 'pred_transe_dist')
|----------------------Iter 20------------------------|
| TOP 50 |               Recall  4.42               |
|----------------------Iter 40------------------------|
| TOP 50 |               Recall  5.22               |
|----------------------Iter 60------------------------|
| TOP 50 |               Recall  6.98               |
|----------------------Iter 100------------------------|
| TOP 50 |               Recall  5.11               |
|----------------------Iter 140------------------------|
| TOP 50 |               Recall  4.87               |


KeyboardInterrupt: 

In [20]:
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = .4
cfg.DATASETS.TRAIN = ("vrd_val", )

test_dataloader = build_detection_train_loader(cfg,
    mapper=DatasetMapper(cfg, is_train=True, augmentations=[
        T.Resize((800, 800))
    ])
)
test_data_iter = iter(test_dataloader)

[32m[03/17 09:00:10 d2.data.dataset_mapper]: [0m[DatasetMapper] Augmentations used in training: [Resize(shape=(800, 800))]
[32m[03/17 09:00:10 d2.data.build]: [0mUsing training sampler TrainingSampler


In [None]:
recall_results_path = '../generated/results_recall_vrd2_model_transr_23000.json'

with open(recall_results_path, 'w') as file:
    file.write(json.dumps(recall_results))

##### Train Data Evaluation

In [17]:
#train dataset
cfg.DATASETS.TEST = ("vrd_train", )

train_dataset = DatasetCatalog.get("vrd_train")
train_dataloader = build_detection_test_loader(dataset=train_dataset,
    mapper=DatasetMapper(cfg, is_train=True, augmentations=[
        T.Resize((800, 800))
    ])
)

[32m[03/17 08:05:21 d2.data.dataset_mapper]: [0m[DatasetMapper] Augmentations used in training: [Resize(shape=(800, 800))]


In [18]:
cnt = {}
train_dataset[0]


{'file_name': '/home/thanhtran/dev/vrdtransr/data/vrd/train_images/000000002658.jpg',
 'image_id': 2658,
 'annotations': [{'bbox': [306, 94, 590, 175],
   'bbox_mode': 0,
   'category_id': 22},
  {'bbox': [324, 336, 458, 489], 'bbox_mode': 0, 'category_id': 5},
  {'bbox': [539, 265, 680, 397], 'bbox_mode': 0, 'category_id': 7},
  {'bbox': [425, 273, 556, 394], 'bbox_mode': 0, 'category_id': 7},
  {'bbox': [302, 270, 420, 396], 'bbox_mode': 0, 'category_id': 7},
  {'bbox': [159, 272, 293, 401], 'bbox_mode': 0, 'category_id': 7},
  {'bbox': [47, 271, 156, 407], 'bbox_mode': 0, 'category_id': 7}],
 'relationships': {'subj_bboxes': [array([[306, 141, 590, 262]]),
   array([[306, 141, 590, 262]]),
   array([[539, 397, 680, 595]]),
   array([[539, 397, 680, 595]]),
   array([[425, 409, 556, 591]]),
   array([[425, 409, 556, 591]]),
   array([[302, 405, 420, 594]]),
   array([[159, 408, 293, 601]]),
   array([[ 47, 406, 156, 610]]),
   array([[324, 504, 458, 733]])],
  'obj_bboxes': [array([[

In [None]:
recall_results = {}
for is_zeroshot in [False, True]:
    for nre in [50, 100]:
        for conf in eval_config:
            train_dataloader = build_detection_test_loader(dataset=train_dataset,
                mapper=DatasetMapper(cfg, is_train=True, augmentations=[
                    T.Resize((800, 800))
                ])
            )
            with torch.no_grad():
                if (is_zeroshot):
                    print(('is_zeroshot', nre, conf))
                    recall_results[('is_zeroshot', nre, conf)] = eval_dataset(train_dataloader, model, nre=50, config=conf, trained_triples=trained_triples)
                else:
                    print(('not_zeroshot', nre, conf))
                    recall_results[('not_zeroshot',nre, conf)] = eval_dataset(train_dataloader, model, nre=50, config=conf, trained_triples=None)

            train_data_iter = None
            del train_dataloader
            

#### EVAL 2/ PREDICATE PREDICTION TASK (GROUND TRUTH BOUNDING BOXES AND LABELS PROVIDED)

In [23]:
import numpy as np
#Evalutation of the models performance on detective visual predicate
#default
total_true_positive = 0
total_true_positive_5 = 0
total_relationships = 0

total_true_positive_zeroshot = 0
total_true_positive_5_zeroshot = 0
total_relationships_zeroshot = 0

#sub
total_true_positive_sub = 0
total_true_positive_5_sub = 0

total_true_positive_sub_zeroshot = 0
total_true_positive_5_sub_zeroshot = 0

#mul
total_true_positive_mul = 0
total_true_positive_5_mul = 0

total_true_positive_mul_zeroshot = 0
total_true_positive_5_mul_zeroshot = 0

#transe
total_true_positive_transe = 0
total_true_positive_5_transe = 0

total_true_positive_transe_mul = 0
total_true_positive_5_transe_mul = 0

n_examples = len(test_dataloader.dataset.dataset.dataset)

for i in range(n_examples):
    data = next(test_data_iter)[0]
    relationships = data["relationships"]
    
    if (len(relationships['subj_bboxes']) == 0):
        #no relationship annotations for the given image
        continue
    
    with torch.no_grad():
        predicate_distances, predicate_subtract_distances, transe_distances = model.get_predicate_distances(data)
        
    #Verbose
    if (i % 20 == 0 and i > 0):
        recall_zeroshot_1 = 0
        recall_zeroshot_5 = 0
        if (total_relationships_zeroshot > 0):
            recall_zeroshot_1 = total_true_positive_zeroshot / total_relationships_zeroshot * 100
            recall_zeroshot_5 = total_true_positive_5_zeroshot / total_relationships_zeroshot * 100

        recall_sub_1 = total_true_positive_sub / total_relationships * 100
        recall_sub_5 = total_true_positive_5_sub / total_relationships * 100
        recall_zeroshot_sub_1 = 0
        recall_zeroshot_sub_5 = 0
        if (total_relationships_zeroshot > 0):
            recall_zeroshot_sub_1 = total_true_positive_sub_zeroshot / total_relationships_zeroshot * 100
            recall_zeroshot_sub_5 = total_true_positive_5_sub_zeroshot / total_relationships_zeroshot * 100        
        
        recall_mul_1 = total_true_positive_mul / total_relationships * 100
        recall_mul_5 = total_true_positive_5_mul / total_relationships * 100
        recall_zeroshot_mul_1 = 0
        recall_zeroshot_mul_5 = 0
        if (total_relationships_zeroshot > 0):
            recall_zeroshot_mul_1 = total_true_positive_mul_zeroshot / total_relationships_zeroshot * 100
            recall_zeroshot_mul_5 = total_true_positive_5_mul_zeroshot / total_relationships_zeroshot * 100        
        
        recall_transe_1 = total_true_positive_transe / total_relationships * 100
        recall_transe_5 = total_true_positive_5_transe / total_relationships * 100
        
        recall_transe_mul_1 = total_true_positive_transe_mul / total_relationships * 100
        recall_transe_mul_5 = total_true_positive_5_transe_mul / total_relationships * 100
        
        recall_1 =  total_true_positive / total_relationships * 100
        recall_5 = total_true_positive_5 / total_relationships * 100
        print("|----------------------Iter {}------------------------|".format(i))
        print("| Recall top 1 {:5.2f} | Recall zeroshot top 1 {:5.2f} |".format(recall_1, recall_zeroshot_1))
        print("| Recall top 5 {:5.2f} | Recall zeroshot top 5 {:5.2f} |".format(recall_5, recall_zeroshot_5))
        print("| Recall sub top 1 {:5.2f} | Recall zeroshot sub top 1 {:5.2f} |".format(recall_sub_1, recall_zeroshot_sub_1))
        print("| Recall sub top 5 {:5.2f} | Recall zeroshot sub top 5 {:5.2f} |".format(recall_sub_5, recall_zeroshot_sub_5))
        print("| Recall mul top 1 {:5.2f} | Recall zeroshot mul top 1 {:5.2f} |".format(recall_mul_1, recall_zeroshot_mul_1))
        print("| Recall mul top 5 {:5.2f} | Recall zeroshot mul top 5 {:5.2f} |".format(recall_mul_5, recall_zeroshot_mul_5))        
        print("| Recall transe 1 {:5.2f} | Recall transe mul 1 {:5.2f} |".format(recall_transe_1, recall_transe_mul_1))
        print("| Recall transe 5 {:5.2f} | Recall transe mul 5 {:5.2f} |".format(recall_transe_5, recall_transe_mul_5))

    #select top 5 to calculate recall
    
    for j, pred_distance in enumerate(predicate_distances):
        sorted_pred_distance = np.argsort([item.cpu() for item in pred_distance])
        
        top_5_pred = sorted_pred_distance[:5]
        top_1_pred = sorted_pred_distance[0]
        
        #subtract
        sorted_pred_subtract_distance = np.argsort([item.cpu() for item in predicate_subtract_distances[j]])
        top_5_pred_sub = sorted_pred_subtract_distance[:5]
        top_1_pred_sub = sorted_pred_subtract_distance[0]
        
        #transe
        sorted_transe_distances = np.argsort([item.cpu() for item in transe_distances[j]])
        top_5_vtranse = sorted_transe_distances[:5]
        top_1_vtranse = sorted_transe_distances[0]

        #pred * transe
        pred_transe_multiply_distances = [ item1.cpu() * item2.cpu() for item1,item2 in zip(pred_distance, transe_distances[j])]
        top_5_pred_transe = np.argsort(pred_transe_multiply_distances[j].cpu())[:5]
        top_1_pred_transe = np.argsort(pred_transe_multiply_distances[j].cpu())[0]
        
        #pred * sub
        pred_multiply_distances = [ item1.cpu() * item2.cpu() for item1,item2 in zip(pred_distance, predicate_subtract_distances[j])]
        top_5_pred_mul = np.argsort(pred_multiply_distances)[:5]
        top_1_pred_mul = np.argsort(pred_multiply_distances)[0]

        gt_subj = model.object_classes[relationships['subj_classes'][j]]
        gt_pred = model.predicate_classes[relationships['pred_classes'][j]]
        gt_obj = model.object_classes[relationships['obj_classes'][j]]
            
        #compute true positive
        if (top_1_pred == relationships['pred_classes'][j]):
            total_true_positive += 1
        if (relationships['pred_classes'][j] in top_5_pred):
            total_true_positive_5 += 1
        
        if ('{}-{}-{}'.format(gt_subj, gt_pred, gt_obj) not in trained_triples):
            
            #default pred
            if (top_1_pred == relationships['pred_classes'][j]):
                total_true_positive_zeroshot += 1
            if (relationships['pred_classes'][j] in top_5_pred):
                total_true_positive_5_zeroshot += 1
                
            #sub
            if (top_1_pred_sub == relationships['pred_classes'][j]):
                total_true_positive_sub_zeroshot += 1
            if (relationships['pred_classes'][j] in top_5_pred_sub):
                total_true_positive_5_sub_zeroshot += 1
                
            #mul (pred*sub)
            if (top_1_pred_mul == relationships['pred_classes'][j]):
                total_true_positive_mul_zeroshot += 1
            if (relationships['pred_classes'][j] in top_5_pred_mul):
                total_true_positive_5_mul_zeroshot += 1
                
            total_relationships_zeroshot += 1
        
        #sub
        if (top_1_pred_sub == relationships['pred_classes'][j]):
            total_true_positive_sub += 1
        if (relationships['pred_classes'][j] in top_5_pred_sub):
            total_true_positive_5_sub += 1
        
        #pred * sub
        if (top_1_pred_mul == relationships['pred_classes'][j]):
            total_true_positive_mul += 1
        if (relationships['pred_classes'][j] in top_5_pred_mul):
            total_true_positive_5_mul += 1

        #transe
        if (top_1_vtranse == relationships['pred_classes'][j]):
            total_true_positive_transe += 1
        if (relationships['pred_classes'][j] in top_5_vtranse):
            total_true_positive_5_transe += 1
        
        #pred * sub
        if (top_1_pred_transe == relationships['pred_classes'][j]):
            total_true_positive_transe_mul += 1
        if (relationships['pred_classes'][j] in top_5_pred_transe):
            total_true_positive_5_transe_mul += 1
            
    total_relationships += len(predicate_distances)

print("Final recall top 1: ", total_true_positive / total_relationships * 100)
print("Final recall top 5: ", total_true_positive_5 / total_relationships * 100)


NameError: name 'model' is not defined

In [None]:
recall_zeroshot_1 = 0
recall_zeroshot_5 = 0
if (total_relationships_zeroshot > 0):
    recall_zeroshot_1 = total_true_positive_zeroshot / total_relationships_zeroshot * 100
    recall_zeroshot_5 = total_true_positive_5_zeroshot / total_relationships_zeroshot * 100

recall_sub_1 = total_true_positive_sub / total_relationships * 100
recall_sub_5 = total_true_positive_5_sub / total_relationships * 100
recall_zeroshot_sub_1 = 0
recall_zeroshot_sub_5 = 0
if (total_relationships_zeroshot > 0):
    recall_zeroshot_sub_1 = total_true_positive_sub_zeroshot / total_relationships_zeroshot * 100
    recall_zeroshot_sub_5 = total_true_positive_5_sub_zeroshot / total_relationships_zeroshot * 100        

recall_mul_1 = total_true_positive_mul / total_relationships * 100
recall_mul_5 = total_true_positive_5_mul / total_relationships * 100
recall_zeroshot_mul_1 = 0
recall_zeroshot_mul_5 = 0
if (total_relationships_zeroshot > 0):
    recall_transe_mul_1 = total_true_positive_transe_mul / total_relationships * 100
    recall_transe_mul_5 = total_true_positive_5_transe_mul / total_relationships * 100

recall_1 =  total_true_positive / total_relationships * 100
recall_5 = total_true_positive_5 / total_relationships * 100
print("|----------------------Iter {}------------------------|".format(i))
print("| Recall top 1 {:5.2f} | Recall zeroshot top 1 {:5.2f} |".format(recall_1, recall_zeroshot_1))
print("| Recall top 5 {:5.2f} | Recall zeroshot top 5 {:5.2f} |".format(recall_5, recall_zeroshot_5))
print("| Recall sub top 1 {:5.2f} | Recall zeroshot sub top 1 {:5.2f} |".format(recall_sub_1, recall_zeroshot_sub_1))
print("| Recall sub top 5 {:5.2f} | Recall zeroshot sub top 5 {:5.2f} |".format(recall_sub_5, recall_zeroshot_sub_5))
print("| Recall mul top 1 {:5.2f} | Recall zeroshot mul top 1 {:5.2f} |".format(recall_mul_1, recall_zeroshot_mul_1))
print("| Recall mul top 5 {:5.2f} | Recall zeroshot mul top 5 {:5.2f} |".format(recall_mul_5, recall_zeroshot_mul_5))        
print("| Recall transe 1 {:5.2f} | Recall transe mul 1 {:5.2f} |".format(recall_transe_1, recall_transe_mul_1))
print("| Recall transe 5 {:5.2f} | Recall transe mul 5 {:5.2f} |".format(recall_transe_5, recall_transe_mul_5))
recall_zeroshot_mul_1 = total_true_positive_mul_zeroshot / total_relationships_zeroshot * 100
recall_zeroshot_mul_5 = total_true_positive_5_mul_zeroshot / total_relationships_zeroshot * 100        

recall_transe_1 = total_true_positive_transe / total_relationships * 100
recall_transe_5 = total_true_positive_5_transe / total_relationships * 100

recall_transe_mul_1 = total_true_positive_transe_mul / total_relationships * 100
recall_transe_mul_5 = total_true_positive_5_transe_mul / total_relationships * 100

recall_1 =  total_true_positive / total_relationships * 100
recall_5 = total_true_positive_5 / total_relationships * 100
print("|----------------------Iter {}------------------------|".format(i))
print("| Recall top 1 {:5.2f} | Recall zeroshot top 1 {:5.2f} |".format(recall_1, recall_zeroshot_1))
print("| Recall top 5 {:5.2f} | Recall zeroshot top 5 {:5.2f} |".format(recall_5, recall_zeroshot_5))
print("| Recall sub top 1 {:5.2f} | Recall zeroshot sub top 1 {:5.2f} |".format(recall_sub_1, recall_zeroshot_sub_1))
print("| Recall sub top 5 {:5.2f} | Recall zeroshot sub top 5 {:5.2f} |".format(recall_sub_5, recall_zeroshot_sub_5))
print("| Recall mul top 1 {:5.2f} | Recall zeroshot mul top 1 {:5.2f} |".format(recall_mul_1, recall_zeroshot_mul_1))
print("| Recall mul top 5 {:5.2f} | Recall zeroshot mul top 5 {:5.2f} |".format(recall_mul_5, recall_zeroshot_mul_5))        
print("| Recall transe 1 {:5.2f} | Recall transe mul 1 {:5.2f} |".format(recall_transe_1, recall_transe_mul_1))
print("| Recall transe 5 {:5.2f} | Recall transe mul 5 {:5.2f} |".format(recall_transe_5, recall_transe_mul_5))

In [None]:
#Evalutation of the models performance on detective visual object (no TransE)
import numpy as np
#total
total_relationships = 0
total_relationships_zeroshot = 0

#default
total_true_positive = 0
total_true_positive_5 = 0
total_true_positive_10 = 0

total_true_positive_zeroshot = 0
total_true_positive_5_zeroshot = 0
total_true_positive_10_zeroshot = 0

#add
total_true_positive_add = 0
total_true_positive_5_add = 0
total_true_positive_10_add = 0

total_true_positive_add_zeroshot = 0
total_true_positive_5_add_zeroshot = 0
total_true_positive_10_add_zeroshot = 0

#mul
total_true_positive_mul = 0
total_true_positive_5_mul = 0
total_true_positive_10_mul = 0

total_true_positive_mul_zeroshot = 0
total_true_positive_5_mul_zeroshot = 0
total_true_positive_10_mul_zeroshot = 0

for i in range(n_examples):
    data = next(test_data_iter)[0]
    relationships = data["relationships"]
    with torch.no_grad():
        object_distances, object_add_distances = model.get_object_distances(data)
        
    #Verbose
    if (i % 20 == 0 and i > 0):
        recall_1 =  total_true_positive / total_relationships * 100
        recall_5 = total_true_positive_5 / total_relationships * 100
        recall_10 = total_true_positive_10 / total_relationships * 100
        recall_zeroshot_1 = 0
        recall_zeroshot_5 = 0
        recall_zeroshot_10 = 0
        if (total_relationships_zeroshot > 0):
            recall_zeroshot_1 = total_true_positive_zeroshot / total_relationships_zeroshot * 100
            recall_zeroshot_5 = total_true_positive_5_zeroshot / total_relationships_zeroshot * 100
            recall_zeroshot_10 = total_true_positive_10_zeroshot / total_relationships_zeroshot * 100

        recall_add_1 = total_true_positive_add / total_relationships * 100
        recall_add_5 = total_true_positive_5_add / total_relationships * 100
        recall_add_10 = total_true_positive_10_add / total_relationships * 100
        recall_zeroshot_add_1 = 0
        recall_zeroshot_add_5 = 0
        recall_zeroshot_add_10 = 0        
        if (total_relationships_zeroshot > 0):
            recall_zeroshot_add_1 = total_true_positive_add_zeroshot / total_relationships_zeroshot * 100
            recall_zeroshot_add_5 = total_true_positive_5_add_zeroshot / total_relationships_zeroshot * 100
            recall_zeroshot_add_10 = total_true_positive_10_add_zeroshot / total_relationships_zeroshot * 100
        
        recall_mul_1 = total_true_positive_mul / total_relationships * 100
        recall_mul_5 = total_true_positive_5_mul / total_relationships * 100
        recall_mul_10 = total_true_positive_10_mul / total_relationships * 100
        recall_zeroshot_mul_1 = 0
        recall_zeroshot_mul_5 = 0
        recall_zeroshot_mul_10 = 0
        if (total_relationships_zeroshot > 0):
            recall_zeroshot_mul_1 = total_true_positive_mul_zeroshot / total_relationships_zeroshot * 100
            recall_zeroshot_mul_5 = total_true_positive_5_mul_zeroshot / total_relationships_zeroshot * 100        
            recall_zeroshot_mul_10 = total_true_positive_10_mul_zeroshot / total_relationships_zeroshot * 100        
        
        print("|----------------------Iter {}------------------------|".format(i))
        print("| Recall top 1 {:5.2f} | Recall zeroshot top 1 {:5.2f} |".format(recall_1, recall_zeroshot_1))
        print("| Recall top 5 {:5.2f} | Recall zeroshot top 5 {:5.2f} |".format(recall_5, recall_zeroshot_5))
        print("| Recall top 10 {:5.2f} | Recall zeroshot top 10 {:5.2f} |".format(recall_10, recall_zeroshot_10))
        print("| Recall add top 1 {:5.2f} | Recall zeroshot sub top 1 {:5.2f} |".format(recall_add_1, recall_zeroshot_add_1))
        print("| Recall add top 5 {:5.2f} | Recall zeroshot sub top 5 {:5.2f} |".format(recall_add_5, recall_zeroshot_add_5))
        print("| Recall add top 10 {:5.2f} | Recall zeroshot sub top 10 {:5.2f} |".format(recall_add_10, recall_zeroshot_add_10))
        print("| Recall mul top 1 {:5.2f} | Recall zeroshot mul top 1 {:5.2f} |".format(recall_mul_1, recall_zeroshot_mul_1))
        print("| Recall mul top 5 {:5.2f} | Recall zeroshot mul top 5 {:5.2f} |".format(recall_mul_5, recall_zeroshot_mul_5))        
        print("| Recall mul top 10 {:5.2f} | Recall zeroshot mul top 10 {:5.2f} |".format(recall_mul_10, recall_zeroshot_mul_10))        

    #select top 5 to calculate recall
    for j, obj_distance in enumerate(object_distances):
        top_1_obj = np.argsort(obj_distance)[0]
        top_5_obj = np.argsort(obj_distance)[:5]
        top_10_obj = np.argsort(obj_distance)[:10]
        
        #add
        top_1_obj_add = np.argsort(object_add_distances[j])[0]
        top_5_obj_add = np.argsort(object_add_distances[j])[:5]
        top_10_obj_add = np.argsort(object_add_distances[j])[:10]
        
        #pred * add
        obj_multiply_distances = [ item1 * item2 for item1,item2 in zip(obj_distance, object_add_distances[j])]
        top_1_obj_mul = np.argsort(obj_multiply_distances)[0]
        top_5_obj_mul = np.argsort(obj_multiply_distances)[:5]
        top_10_obj_mul = np.argsort(obj_multiply_distances)[:10]

        gt_subj = model.object_classes[relationships['subj_classes'][j]]
        gt_pred = model.predicate_classes[relationships['pred_classes'][j]]
        gt_obj = model.object_classes[relationships['obj_classes'][j]]
            
        #compute true positive        
        if ('{}-{}-{}'.format(gt_subj, gt_pred, gt_obj) not in trained_triples):
            
            #default pred
            if (top_1_obj == relationships['obj_classes'][j]):
                total_true_positive_zeroshot += 1
            if (relationships['obj_classes'][j] in top_5_obj):
                total_true_positive_5_zeroshot += 1
            if (relationships['obj_classes'][j] in top_10_obj):
                total_true_positive_10_zeroshot += 1

            #sub
            if (top_1_obj_add == relationships['obj_classes'][j]):
                total_true_positive_add_zeroshot += 1
            if (relationships['obj_classes'][j] in top_5_obj_add):
                total_true_positive_5_add_zeroshot += 1
            if (relationships['obj_classes'][j] in top_10_obj_add):
                total_true_positive_10_add_zeroshot += 1
                
            #mul (pred*sub)
            if (top_1_obj_mul == relationships['obj_classes'][j]):
                total_true_positive_mul_zeroshot += 1
            if (relationships['obj_classes'][j] in top_5_obj_mul):
                total_true_positive_5_mul_zeroshot += 1
            if (relationships['obj_classes'][j] in top_10_obj_mul):
                total_true_positive_10_mul_zeroshot += 1
                
            total_relationships_zeroshot += 1
        
        #normal
        if (top_1_obj == relationships['obj_classes'][j]):
            total_true_positive += 1
        if (relationships['obj_classes'][j] in top_5_obj):
            total_true_positive_5 += 1
        if (relationships['obj_classes'][j] in top_10_obj):
            total_true_positive_10 += 1
            
        #add
        if (top_1_obj_add == relationships['obj_classes'][j]):
            total_true_positive_add += 1
        if (relationships['obj_classes'][j] in top_5_obj_add):
            total_true_positive_5_add += 1
        if (relationships['obj_classes'][j] in top_10_obj_add):
            total_true_positive_10_add += 1
        
        #pred * add
        if (top_1_obj_mul == relationships['obj_classes'][j]):
            total_true_positive_mul += 1
        if (relationships['obj_classes'][j] in top_5_obj_mul):
            total_true_positive_5_mul += 1
        if (relationships['obj_classes'][j] in top_10_obj_mul):
            total_true_positive_10_mul += 1
    total_relationships += len(object_distances)

print("Final recall top 1: ", total_true_positive / total_relationships * 100)
print("Final recall top 5: ", total_true_positive_5 / total_relationships * 100)

In [None]:
recall_1 =  total_true_positive / total_relationships * 100
recall_5 = total_true_positive_5 / total_relationships * 100
recall_10 = total_true_positive_10 / total_relationships * 100
recall_zeroshot_1 = 0
recall_zeroshot_5 = 0
recall_zeroshot_10 = 0
if (total_relationships_zeroshot > 0):
    recall_zeroshot_1 = total_true_positive_zeroshot / total_relationships_zeroshot * 100
    recall_zeroshot_5 = total_true_positive_5_zeroshot / total_relationships_zeroshot * 100
    recall_zeroshot_10 = total_true_positive_10_zeroshot / total_relationships_zeroshot * 100

recall_add_1 = total_true_positive_add / total_relationships * 100
recall_add_5 = total_true_positive_5_add / total_relationships * 100
recall_add_10 = total_true_positive_10_add / total_relationships * 100
recall_zeroshot_add_1 = 0
recall_zeroshot_add_5 = 0
recall_zeroshot_add_10 = 0        
if (total_relationships_zeroshot > 0):
    recall_zeroshot_add_1 = total_true_positive_add_zeroshot / total_relationships_zeroshot * 100
    recall_zeroshot_add_5 = total_true_positive_5_add_zeroshot / total_relationships_zeroshot * 100
    recall_zeroshot_add_10 = total_true_positive_10_add_zeroshot / total_relationships_zeroshot * 100

recall_mul_1 = total_true_positive_mul / total_relationships * 100
recall_mul_5 = total_true_positive_5_mul / total_relationships * 100
recall_mul_10 = total_true_positive_10_mul / total_relationships * 100
recall_zeroshot_mul_1 = 0
recall_zeroshot_mul_5 = 0
recall_zeroshot_mul_10 = 0
if (total_relationships_zeroshot > 0):
    recall_zeroshot_mul_1 = total_true_positive_mul_zeroshot / total_relationships_zeroshot * 100
    recall_zeroshot_mul_5 = total_true_positive_5_mul_zeroshot / total_relationships_zeroshot * 100        
    recall_zeroshot_mul_10 = total_true_positive_10_mul_zeroshot / total_relationships_zeroshot * 100        

print("|----------------------Iter {}------------------------|".format(i))
print("| Recall top 1 {:5.2f} | Recall zeroshot top 1 {:5.2f} |".format(recall_1, recall_zeroshot_1))
print("| Recall top 5 {:5.2f} | Recall zeroshot top 5 {:5.2f} |".format(recall_5, recall_zeroshot_5))
print("| Recall top 10 {:5.2f} | Recall zeroshot top 10 {:5.2f} |".format(recall_10, recall_zeroshot_10))
print("| Recall add top 1 {:5.2f} | Recall zeroshot sub top 1 {:5.2f} |".format(recall_add_1, recall_zeroshot_add_1))
print("| Recall add top 5 {:5.2f} | Recall zeroshot sub top 5 {:5.2f} |".format(recall_add_5, recall_zeroshot_add_5))
print("| Recall add top 10 {:5.2f} | Recall zeroshot sub top 10 {:5.2f} |".format(recall_add_10, recall_zeroshot_add_10))
print("| Recall mul top 1 {:5.2f} | Recall zeroshot mul top 1 {:5.2f} |".format(recall_mul_1, recall_zeroshot_mul_1))
print("| Recall mul top 5 {:5.2f} | Recall zeroshot mul top 5 {:5.2f} |".format(recall_mul_5, recall_zeroshot_mul_5))        
print("| Recall mul top 10 {:5.2f} | Recall zeroshot mul top 10 {:5.2f} |".format(recall_mul_10, recall_zeroshot_mul_10))        


In [None]:
data_iter = iter(train_dataloader)
data = next(data_iter)[0]
# model = RelTransR(cfg)
# negative_examples = model.generate_negative_examples(data)
# results = model(data,negative_examples)

# Training Scripts

In [11]:
from detectron2.data import (
    DatasetCatalog, DatasetMapper,
    build_detection_train_loader
)
from config import get_vrd_cfg
import detectron2.data.transforms as T

cfg = get_vrd_cfg()
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = .7
batch_size = 1

# Dataloaders
if ('test_dataloader' in vars()):
    del test_dataloader

train_dataloader = build_detection_train_loader(cfg,
    mapper=DatasetMapper(cfg, is_train=True, augmentations=[
        T.Resize((800, 800))
    ])
)


In [12]:
train_dataset = DatasetCatalog.get("vrd_train")
train_dataloader = build_detection_train_loader(cfg,dataset=train_dataset,
    mapper=DatasetMapper(cfg, is_train=True, augmentations=[
        T.Resize((800, 800))
    ])
)

In [13]:
import torch
import wandb
from torch.optim.lr_scheduler import StepLR
torch.manual_seed(0)

#params
n_datapoints = len(train_dataloader.dataset.dataset.dataset)
# n_datapoints = 3780
n_iters = cfg.SOLVER.MAX_ITER
num_epochs = int(n_iters / n_datapoints)
chkpoint_it = n_datapoints #create a checkpoint every 1000 iterations

#model
model = RelTransR(cfg)
device = torch.device("cuda")
model.to(device)

#scheduler
learning_rate = 0.001
gamma = cfg.SOLVER.GAMMA
momentum = 0.9
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)
scheduler = StepLR(optimizer, step_size=4, gamma=0.1)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [14]:
# model = RelTransR(cfg)
# device = torch.device("cuda")
# model.to(device)
# wandb.watch(model, log="all", log_freq=20)

In [15]:
#wandb
log_interval = 20
wandb.init(project="vrdtransr-project-vltranse-lang-con-2", entity="herobaby71")
wandb.config = {
    "seed": 0,
    "learning_rate": learning_rate,
    "gamma": 0.1,
    "momentum": 0.9,
    "epochs": num_epochs,
    "n_iters": n_iters,
    "batch_size": 1
}
wandb.watch(model, log="all", log_freq=20)

[34m[1mwandb[0m: Currently logged in as: [33mherobaby71[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.11 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


[]

In [16]:
def train_log(loss, lr, it, epoch, loss_subj, loss_pred, loss_obj, loss_transr):
    # Where the magic happens
    wandb.log({"lr":lr, "epoch": epoch, "loss": loss, "loss_subj": loss_subj, "loss_pred": loss_pred, "loss_obj": loss_obj, "loss_transr":loss_transr}, step=it)

In [17]:
t = torch.cuda.get_device_properties(0).total_memory
r = torch.cuda.memory_reserved(0)
a = torch.cuda.memory_allocated(0)
f = r-a  # free inside reserved
print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
print("torch.cuda.memory_reserved: %fGB"%(torch.cuda.memory_reserved(0)/1024/1024/1024))
print("torch.cuda.free: %fGB"%(f/1024/1024/1024))

torch.cuda.memory_allocated: 2.142324GB
torch.cuda.memory_reserved: 2.193359GB
torch.cuda.free: 0.051035GB


In [18]:
import time
import os

#checkpoint every 2000 steps
chkpoint_it = 2000

it = 0
initial_it = 0 # checkpoint intial iteration to resume training
start_time = time.time()
losses = []

#load checkpoint
load_chkpoint = False
if (load_chkpoint):
    chkpoint_path = '../checkpoint/'
    model_name = 'vrd2_model_18000.pt'
    chkpoint_full_path = os.path.join(chkpoint_path, model_name)
    it, start_epoch, losses = load_checkpoint(model, chkpoint_full_path, optimizer=optimizer)
    initial_it = it

In [19]:
import os, json

trained_triples_path = '../generated/trained_triples.json'
trained_triples = {}
if (os.path.exists(trained_triples_path)):
    with open(trained_triples_path) as file:
        trained_triples = json.load(file)
else:
    iter_dataloader = iter(train_dataloader)
    n_iters = len(train_dataloader.dataset.dataset)
    for i in range(n_iters):
        print(i)
        data = next(iter_dataloader)[0]
        relationships = data['relationships']
        for j in range(len(relationships['subj_classes'])):
            subj_cls = model.object_classes[relationships['subj_classes'][j]]
            pred_cls = model.predicate_classes[relationships['pred_classes'][j]]
            obj_cls = model.object_classes[relationships['obj_classes'][j]]
            trained_triples['{}-{}-{}'.format(subj_cls, pred_cls, obj_cls)] = 1

    with open(trained_triples_path, 'w') as file:
        file.write(json.dumps(trained_triples))

In [24]:
iter_dataloader = iter(train_dataloader)
interval_cnt = 0
chkpoint_path = '../checkpoint'

#Losses
total_loss = 0
subj_loss = 0
obj_loss = 0
pred_loss = 0
transr_loss = 0

for i in range(n_iters):    
    #continue training from the previous checkpoint
    if (i < initial_it % n_datapoints):
        continue
        
    #iterator
    try:
        data = next(iter_dataloader)[0]
    except StopIteration:
        print("iterator has reach its end at iteration {}. Initializing a new iterator.".format(str(it)))
        iter_dataloader = iter(train_dataloader)
        data = next(iter_dataloader)[0]


    if (len(data['relationships']['subj_bboxes']) <= 1):
        #image has only one relationship, cannot train
        print("an image has been removed for this batch")
        continue
        
    #other exclusion due to bad label
    if ('1841.jpg' in data["file_name"]):
        print("this image has bad label and has been removed.")
        continue

    optimizer.zero_grad()

    #forward passes
    negative_examples = {}
    negative_examples = model.generate_negative_examples(data)
#     print("number of objects:", len( data['instances'].get_fields()['gt_boxes']))
#     print("number of neg examples:", len(negative_examples))
    triplet_losses = model(data, negative_examples)
    
    #compute gradient backward
    final_loss = triplet_losses['obj'] + triplet_losses['pred'] + triplet_losses['subj'] + triplet_losses['transr']
    final_loss.backward()
    optimizer.step()

    #total loss
    total_loss += final_loss.item()
    subj_loss += triplet_losses['subj'].item()
    pred_loss += triplet_losses['pred'].item()
    obj_loss += triplet_losses['obj'].item()
    transr_loss += triplet_losses['transr'].item()
    
    interval_cnt += 1
    if (it > initial_it and it % log_interval == 0 and it > 0):
        current_loss = total_loss / interval_cnt
        losses.append(current_loss)
        elapsed = time.time() - start_time
        epoch = it / n_datapoints
        print('| it {} | epoch {} | lr {} | ms/batch {:5.2f} | loss {:5.2f}'.format(
            it, int(epoch), scheduler.get_last_lr()[0], elapsed * 1000 / log_interval, current_loss))
        train_log(current_loss, scheduler.get_last_lr()[0], it, int(epoch),
                  loss_subj=subj_loss/interval_cnt, loss_pred=pred_loss/interval_cnt,
                  loss_obj=obj_loss/interval_cnt, loss_transr=transr_loss/interval_cnt)
        total_loss = 0
        subj_loss = 0
        pred_loss = 0
        obj_loss = 0
        transr_loss = 0
        interval_cnt = 0
        start_time = time.time()

    if (it > initial_it and it % chkpoint_it == 0 and it > 0):
        chkpnt = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "it": it,
            "losses": losses
        }
        torch.save(chkpnt, os.path.join(chkpoint_path, 'vltranse_langcon_model_{}.pt'.format(str(it))))
        
    #increment total count
    it = it + 1

| it 2000 | epoch 0 | lr 0.001 | ms/batch 47845.62 | loss  0.20
an image has been removed for this batch
| it 2020 | epoch 0 | lr 0.001 | ms/batch 1189.71 | loss  0.22
an image has been removed for this batch
| it 2040 | epoch 0 | lr 0.001 | ms/batch 1061.65 | loss  0.19
an image has been removed for this batch
an image has been removed for this batch
an image has been removed for this batch
an image has been removed for this batch
| it 2060 | epoch 0 | lr 0.001 | ms/batch 1023.86 | loss  0.20
an image has been removed for this batch
an image has been removed for this batch
an image has been removed for this batch
| it 2080 | epoch 0 | lr 0.001 | ms/batch 1057.20 | loss  0.18
an image has been removed for this batch
an image has been removed for this batch
| it 2100 | epoch 0 | lr 0.001 | ms/batch 1030.11 | loss  0.17
an image has been removed for this batch
| it 2120 | epoch 0 | lr 0.001 | ms/batch 1053.40 | loss  0.22
an image has been removed for this batch
an image has been removed

an image has been removed for this batch
| it 3100 | epoch 0 | lr 0.001 | ms/batch 1041.84 | loss  0.15
an image has been removed for this batch
an image has been removed for this batch
an image has been removed for this batch
| it 3120 | epoch 0 | lr 0.001 | ms/batch 1030.97 | loss  0.21
an image has been removed for this batch
| it 3140 | epoch 0 | lr 0.001 | ms/batch 1068.13 | loss  0.18
an image has been removed for this batch
| it 3160 | epoch 0 | lr 0.001 | ms/batch 1078.71 | loss  0.17
an image has been removed for this batch
an image has been removed for this batch
| it 3180 | epoch 0 | lr 0.001 | ms/batch 1049.23 | loss  0.13
an image has been removed for this batch
| it 3200 | epoch 0 | lr 0.001 | ms/batch 1036.07 | loss  0.18
an image has been removed for this batch
an image has been removed for this batch
an image has been removed for this batch
| it 3220 | epoch 0 | lr 0.001 | ms/batch 1059.68 | loss  0.20
| it 3240 | epoch 0 | lr 0.001 | ms/batch 1061.62 | loss  0.17
an i

| it 4140 | epoch 1 | lr 0.001 | ms/batch 1025.75 | loss  0.05
an image has been removed for this batch
an image has been removed for this batch
| it 4160 | epoch 1 | lr 0.001 | ms/batch 1080.34 | loss  0.06
an image has been removed for this batch
| it 4180 | epoch 1 | lr 0.001 | ms/batch 1092.91 | loss  0.05
an image has been removed for this batch
an image has been removed for this batch
| it 4200 | epoch 1 | lr 0.001 | ms/batch 1075.48 | loss  0.08
| it 4220 | epoch 1 | lr 0.001 | ms/batch 1020.23 | loss  0.04
an image has been removed for this batch
| it 4240 | epoch 1 | lr 0.001 | ms/batch 1060.29 | loss  0.04
an image has been removed for this batch
an image has been removed for this batch
| it 4260 | epoch 1 | lr 0.001 | ms/batch 1007.01 | loss  0.03
an image has been removed for this batch
an image has been removed for this batch
| it 4280 | epoch 1 | lr 0.001 | ms/batch 1091.82 | loss  0.04
| it 4300 | epoch 1 | lr 0.001 | ms/batch 1051.16 | loss  0.05
an image has been remov

an image has been removed for this batch
| it 5220 | epoch 1 | lr 0.001 | ms/batch 1092.63 | loss  0.05
an image has been removed for this batch
an image has been removed for this batch
an image has been removed for this batch
| it 5240 | epoch 1 | lr 0.001 | ms/batch 1068.93 | loss  0.03
an image has been removed for this batch
an image has been removed for this batch
an image has been removed for this batch
| it 5260 | epoch 1 | lr 0.001 | ms/batch 1076.47 | loss  0.03
an image has been removed for this batch
an image has been removed for this batch
an image has been removed for this batch
| it 5280 | epoch 1 | lr 0.001 | ms/batch 1065.89 | loss  0.03
an image has been removed for this batch
an image has been removed for this batch
an image has been removed for this batch
an image has been removed for this batch
| it 5300 | epoch 1 | lr 0.001 | ms/batch 1023.64 | loss  0.03
an image has been removed for this batch
an image has been removed for this batch
an image has been removed for

an image has been removed for this batch
an image has been removed for this batch
| it 6180 | epoch 1 | lr 0.001 | ms/batch 1059.91 | loss  0.16
an image has been removed for this batch
an image has been removed for this batch
an image has been removed for this batch
an image has been removed for this batch
| it 6200 | epoch 1 | lr 0.001 | ms/batch 1073.77 | loss  0.12
an image has been removed for this batch
an image has been removed for this batch
an image has been removed for this batch
| it 6220 | epoch 1 | lr 0.001 | ms/batch 1067.72 | loss  0.09
an image has been removed for this batch
an image has been removed for this batch
| it 6240 | epoch 1 | lr 0.001 | ms/batch 1074.21 | loss  0.10
an image has been removed for this batch
an image has been removed for this batch
an image has been removed for this batch
| it 6260 | epoch 1 | lr 0.001 | ms/batch 1105.49 | loss  0.05
| it 6280 | epoch 1 | lr 0.001 | ms/batch 1016.32 | loss  0.03
| it 6300 | epoch 1 | lr 0.001 | ms/batch 1059.0

an image has been removed for this batch
an image has been removed for this batch
| it 7180 | epoch 1 | lr 0.001 | ms/batch 1049.10 | loss  0.04
an image has been removed for this batch
| it 7200 | epoch 1 | lr 0.001 | ms/batch 1065.56 | loss  0.03
an image has been removed for this batch
an image has been removed for this batch
an image has been removed for this batch
an image has been removed for this batch
an image has been removed for this batch
an image has been removed for this batch
an image has been removed for this batch
| it 7220 | epoch 1 | lr 0.001 | ms/batch 1072.06 | loss  0.04
an image has been removed for this batch
| it 7240 | epoch 1 | lr 0.001 | ms/batch 1069.19 | loss  0.02
an image has been removed for this batch
an image has been removed for this batch
an image has been removed for this batch
| it 7260 | epoch 1 | lr 0.001 | ms/batch 1083.35 | loss  0.03
an image has been removed for this batch
| it 7280 | epoch 1 | lr 0.001 | ms/batch 1030.33 | loss  0.02
an imag

KeyboardInterrupt: 

In [23]:
it

2000

### [Optional] Test Eval Model Performance

In [None]:
from detectron2.data import (
    DatasetCatalog, DatasetMapper,
    build_detection_train_loader
)

cfg = get_vrd_cfg()

#model
# model = RelTransR(cfg)
# device = torch.device("cuda")
# model.to(device)

#test dataloader
cfg.DATASETS.TRAIN = ("vrd_val", )
test_dataloader = build_detection_train_loader(cfg,
    mapper=DatasetMapper(cfg, is_train=True, augmentations=[
        T.Resize((800, 800))
    ])
)

In [None]:
global_loss = 0
avg_loss = 0
cnt = 0

total_loss = 0
pred_loss = 0
interval_cnt = 0

model.eval()
test_data_iter = iter(test_dataloader)
broken_image = []

with torch.no_grad():
    for i in range(2000):
        data = next(test_data_iter)[0]

        #If there is only one relationship (not suitable to be evaluate with the given loss)
        if (len(data['relationships']['subj_bboxes']) == 1):
            #image has only one relationship, cannot train
            print("an image has been removed for this batch")
            broken_image.append(data)
            continue

        # get negative examples and compute losses
        negative_examples = model.generate_negative_examples(data)
        triplet_losses = model(data, negative_examples)
        final_loss = triplet_losses['obj'] + triplet_losses['subj'] + triplet_losses['pred']

        #output interval loss
        total_loss += final_loss.item()
        pred_loss += triplet_losses['subj'].item()
        interval_cnt += 1
        if (i > 0 and i % 20 == 0):
            current_loss = total_loss / interval_cnt
            pred_current_loss = pred_loss / interval_cnt
            print('| iter: {} | loss {:5.2f} | pred_loss {:5.2f}'.format(i, current_loss, pred_current_loss))
            total_loss = 0
            pred_loss = 0
            interval_cnt = 0

        #update global loss
        global_loss += final_loss.item()
        cnt += 1
avg_loss = total_loss / cnt

In [None]:
global_loss/cnt

In [None]:
c

In [None]:
import torch
import torch.nn.functional as F

x = torch.tensor([[1,2,3], [2,3,4]])
y = torch.tensor([[1,2,3.2]])

1- F.cosine_similarity(x, y)

In [None]:

x = torch.tensor([[1,2,3], [2,3,4]])
y = torch.tensor([[1.5,2.25], [2.1,3.2]])

In [None]:
x.shape

In [None]:
torch.cat((x, y), 1)