# Visual Relationship Detection

In this tutorial, we focus on the task of classifying visual relationships between objects in an image. Any given image may contain many such relationships, each defined formally as a `subject <predicate> object` triple (e.g. `person <riding> bike`). In the relationship `man riding bicycle`, for example, "man" and "bicycle" are the subject and object, respectively, and "riding" is the relationship predicate.

![Visual Relationships](https://cs.stanford.edu/people/ranjaykrishna/vrd/dataset.png)

In the examples shown above, the red box represents the _subject_ while the green box represents the _object_. The _predicate_ (e.g. kick) denotes the relationship that connects the subject and the object.
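As a minimal sketch of the triple structure described above, one relationship can be held in a small named tuple. The `Relationship` class and its field names are hypothetical illustrations, not part of the VRD loader used later in this notebook.

```python
from typing import NamedTuple


class Relationship(NamedTuple):
    """Hypothetical container for one `subject <predicate> object` triple."""

    subject: str
    predicate: str
    obj: str  # named `obj` to avoid shadowing Python's built-in `object`

    def __str__(self) -> str:
        # Render in the `subject <predicate> object` form used in the text
        return f"{self.subject} <{self.predicate}> {self.obj}"


rel = Relationship(subject="man", predicate="riding", obj="bicycle")
print(rel)  # man <riding> bicycle
```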

For the purpose of this tutorial, we operate over the [Visual Relationship Detection (VRD) dataset](https://cs.stanford.edu/people/ranjaykrishna/vrd/) and focus on action relationships. We define our classification task as **identifying which of three relationships holds between the objects represented by a pair of bounding boxes.**

In [None]:
import os
import pandas as pd

if os.path.basename(os.getcwd()) == "snorkel-tutorials":
    os.chdir("visual_relation")

## 1. Load Dataset
We load the VRD dataset and keep only images with at least one action predicate in them, since these are more difficult to classify than geometric relationships like `above` or `next to`. We load the train, valid, and test sets as Pandas `DataFrame` objects with the following fields:
- `label`: The relationship between the objects. 0: `RIDE`, 1: `CARRY`, 2: `OTHER` action predicates
- `object_bbox`: coordinates of the bounding box for the object `[ymin, ymax, xmin, xmax]`
- `object_category`: category of the object
- `source_img`: filename for the corresponding image the relationship is in
- `subject_bbox`: coordinates of the bounding box for the subject `[ymin, ymax, xmin, xmax]`
- `subject_category`: category of the subject
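To make the field layout concrete, here is a sketch with two made-up rows using the same column names; the values and filenames are invented for illustration and do not come from VRD. Because boxes are stored as `[ymin, ymax, xmin, xmax]`, height comes from the first two entries and width from the last two.

```python
import pandas as pd

# Toy rows mimicking the fields listed above (values are made up)
df_toy = pd.DataFrame(
    {
        "label": [0, 2],  # 0: RIDE, 2: OTHER
        "object_bbox": [[300, 500, 100, 400], [50, 120, 600, 700]],
        "object_category": ["bike", "table"],
        "source_img": ["img_0001.jpg", "img_0002.jpg"],  # hypothetical filenames
        "subject_bbox": [[100, 420, 150, 350], [40, 300, 200, 320]],
        "subject_category": ["person", "person"],
    }
)


def box_height_width(bbox):
    """Return (height, width) for a box stored as [ymin, ymax, xmin, xmax]."""
    ymin, ymax, xmin, xmax = bbox
    return ymax - ymin, xmax - xmin


print(box_height_width(df_toy.loc[0, "subject_bbox"]))  # (320, 200)
```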

If you run this notebook for the first time with `sample=True`, it will take ~15 mins to download all the required sample data.

The sampled version of the dataset **uses the same 26 data points across the train, dev, and test sets.
This setting is meant to demonstrate quickly how Snorkel works with this task, not to demonstrate performance.** The cell below sets `sample=False`, which downloads the full VRD dataset instead and takes considerably longer (~3 hours).

In [None]:
from utils import load_vrd_data
# changed IMAGES_URL in download_full_data.sh to "http://imagenet.stanford.edu/internal/jcjohns/scene_graphs/sg_dataset.zip"
# setting sample=False will take ~3 hours to run (downloads full VRD dataset)
sample = False
is_test = os.environ.get("TRAVIS") == "true" or os.environ.get("IS_TEST") == "true"
df_train, df_valid, df_test = load_vrd_data(sample, is_test)

print("Train Relationships: ", len(df_train))
print("Dev Relationships: ", len(df_valid))
print("Test Relationships: ", len(df_test))

df_train.head()

Note that the training `DataFrame` has a `label` field containing all -1s, which denotes that the training set is unlabeled. In this tutorial, we will assign probabilistic labels to the training set by writing labeling functions over attributes of the subject and object!
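A quick sanity check for the "all -1s" convention is sketched below on toy frames; `df_train_toy`, `df_valid_toy`, and `is_unlabeled` are hypothetical stand-ins, and on the real data you would pass `df_train` or `df_valid` instead.

```python
import numpy as np
import pandas as pd

UNLABELED = -1  # placeholder value used for "no gold label"

# Toy stand-ins for the frames returned by load_vrd_data
df_train_toy = pd.DataFrame({"label": [-1, -1, -1]})
df_valid_toy = pd.DataFrame({"label": [0, 2, 1]})


def is_unlabeled(df):
    """True if every row carries the -1 placeholder label."""
    return bool(np.all(df["label"].values == UNLABELED))


print(is_unlabeled(df_train_toy))  # True
print(is_unlabeled(df_valid_toy))  # False
```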

## 2. Writing Labeling Functions
We now write labeling functions to detect what relationship exists between pairs of bounding boxes. To do so, we can encode various intuitions into the labeling functions:
* _Categorical_ intuition: knowledge about the categories of subjects and objects usually involved in these relationships (e.g., `person` is usually the subject for predicates like `ride` and `carry`)
* _Spatial_ intuition: knowledge about the relative positions of the subject and object (e.g., the subject is usually higher than the object for the predicate `ride`)
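The "subject is higher than the object" intuition can be sketched as a comparison of vertical box centers. This assumes pixel coordinates with the origin at the top-left, so y grows downward and "higher" means a smaller y value; `subject_above_object` is a hypothetical helper, not one of the labeling functions defined in this notebook.

```python
# Same [ymin, ymax, xmin, xmax] layout as the dataset's bounding boxes
YMIN, YMAX, XMIN, XMAX = 0, 1, 2, 3


def subject_above_object(subject_bbox, object_bbox):
    """Hypothetical check: is the subject's vertical center above the object's?

    Assumes image y-coordinates grow downward (origin at the top-left).
    """
    subject_center_y = (subject_bbox[YMIN] + subject_bbox[YMAX]) / 2
    object_center_y = (object_bbox[YMIN] + object_bbox[YMAX]) / 2
    return subject_center_y < object_center_y


# A rider box (rows 100-420) sitting over a bike box (rows 300-500)
print(subject_above_object([100, 420, 150, 350], [300, 500, 100, 400]))  # True
```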

In [None]:
RIDE = 0
CARRY = 1
OTHER = 2
ABSTAIN = -1

We begin with labeling functions that encode categorical intuition: we use knowledge about subject-object category pairs that commonly occur with `RIDE` and `CARRY`, and about which subjects or objects are unlikely to be involved in either relationship.

In [None]:
from snorkel.labeling import labeling_function

# Category-based LFs
@labeling_function()
def lf_ride_object(x):
    if x.subject_category == "person":
        if x.object_category in [
            "bike",
            "snowboard",
            "motorcycle",
            "horse",
            "bus",
            "truck",
            "elephant",
        ]:
            return RIDE
    return ABSTAIN


@labeling_function()
def lf_carry_object(x):
    if x.subject_category == "person":
        if x.object_category in ["bag", "surfboard", "skis"]:
            return CARRY
    return ABSTAIN


@labeling_function()
def lf_carry_subject(x):
    if x.object_category == "person":
        if x.subject_category in ["chair", "bike", "snowboard", "motorcycle", "horse"]:
            return CARRY
    return ABSTAIN


@labeling_function()
def lf_not_person(x):
    if x.subject_category != "person":
        return OTHER
    return ABSTAIN

We now encode our spatial intuition, which includes measuring the distance between the bounding boxes and comparing their relative areas.
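As a worked example of these two spatial quantities, the sketch below computes them on a pair of made-up boxes. It mirrors the `area` helper and the thresholds (`<= 1000` for distance, `<= 0.5` for area ratio) used in the labeling functions that follow; the box values are invented for illustration.

```python
import numpy as np

YMIN, YMAX, XMIN, XMAX = 0, 1, 2, 3


def area(bbox):
    # height * width for a [ymin, ymax, xmin, xmax] box
    return (bbox[YMAX] - bbox[YMIN]) * (bbox[XMAX] - bbox[XMIN])


subject_bbox = [100, 420, 150, 350]  # toy values, not from VRD
object_bbox = [300, 500, 100, 400]

# Euclidean distance between the two 4-d coordinate vectors
dist = np.linalg.norm(np.array(subject_bbox) - np.array(object_bbox))  # about 226.72

# Subject-to-object area ratio: 64000 / 60000
ratio = area(subject_bbox) / area(object_bbox)  # about 1.067

print(dist <= 1000)  # True: the distance condition for voting OTHER holds
print(ratio <= 0.5)  # False: the area-ratio condition does not hold
```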

In [None]:
YMIN = 0
YMAX = 1
XMIN = 2
XMAX = 3

In [None]:
import numpy as np

# Distance-based LFs
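@labeling_function()
def lf_ydist(x):
    # Note: despite the name, this compares the boxes' XMAX (right-edge) coordinates,
    # voting OTHER when the subject's right edge lies left of the object's right edge
    if x.subject_bbox[XMAX] < x.object_bbox[XMAX]:
        return OTHER
    return ABSTAIN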


@labeling_function()
def lf_dist(x):
    if np.linalg.norm(np.array(x.subject_bbox) - np.array(x.object_bbox)) <= 1000:
        return OTHER
    return ABSTAIN


def area(bbox):
    return (bbox[YMAX] - bbox[YMIN]) * (bbox[XMAX] - bbox[XMIN])


# Size-based LF
@labeling_function()
def lf_area(x):
    if area(x.subject_bbox) / area(x.object_bbox) <= 0.5:
        return OTHER
    return ABSTAIN

Note that the labeling functions have varying empirical accuracies and coverages. Due to class imbalance among our chosen relationships, labeling functions that output `OTHER` have higher coverage than those for `RIDE` or `CARRY`, mirroring the class distribution of the dataset.
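For intuition about what coverage and empirical accuracy mean, here is a small NumPy sketch on a toy label matrix (rows are data points, columns are LFs, -1 means abstain). It is meant to roughly correspond to the Coverage and Emp. Acc. columns of `LFAnalysis.lf_summary`, but it is an illustration, not Snorkel's implementation; the third toy LF plays the role of a high-coverage `OTHER` function.

```python
import numpy as np

ABSTAIN = -1

# Toy label matrix: 4 data points x 3 LFs
L_toy = np.array([
    [0, -1, 2],
    [-1, -1, 2],
    [0, 1, -1],
    [-1, -1, 2],
])
Y_toy = np.array([0, 2, 1, 2])  # toy gold labels


def coverage(L):
    """Fraction of data points on which each LF votes (does not abstain)."""
    return (L != ABSTAIN).mean(axis=0)


def empirical_accuracy(L, Y):
    """Accuracy of each LF restricted to the points where it votes."""
    accs = []
    for j in range(L.shape[1]):
        voted = L[:, j] != ABSTAIN
        accs.append(float((L[voted, j] == Y[voted]).mean()) if voted.any() else float("nan"))
    return np.array(accs)


print(coverage(L_toy))                   # [0.5  0.25 0.75]
print(empirical_accuracy(L_toy, Y_toy))  # approx. [0.5, 1.0, 0.667]
```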

In [None]:
from snorkel.labeling import PandasLFApplier

lfs = [
    lf_ride_object,
    lf_carry_object,
    lf_carry_subject,
    lf_not_person,
    lf_ydist,
    lf_dist,
    lf_area,
]

applier = PandasLFApplier(lfs)
L_train = applier.apply(df_train)
L_valid = applier.apply(df_valid)

In [None]:
from snorkel.labeling import LFAnalysis

Y_valid = df_valid.label.values
LFAnalysis(L_valid, lfs).lf_summary(Y_valid)

## 2b. Import the dependency-structure learners and the informed label model

In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../')
from Our_Monitors.CD_Monitor import CDM, Informed_LabelModel
from Our_Monitors.CDGA_Monitor import CDGAM
from Our_Monitors.New_Monitor import NM
from Our_Monitors.utils import ModVarma_InCov

# The validation split serves as the dev set for the dependency learners and model selection
L_dev = L_valid
Y_dev = Y_valid

from dependency_model.varma_deps_functions import get_varma_edges, get_varma_with_gold_edges
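The learners imported above produce the `deps` that are later passed as `edges` to `Informed_LabelModel`, and the `Empty` baseline uses an empty list. Assuming those edges are pairs of LF column indices (an assumption; the exact format is defined by the imported modules), the sketch below shows that output shape using a deliberately crude, hypothetical heuristic: connect two LFs when they agree on at least `thresh` of the points where both vote. It is not the CDM, CDGAM, NM, or Varma method.

```python
import itertools

import numpy as np

ABSTAIN = -1


def agreement_edges(L, thresh=0.9, min_overlap=1):
    """Hypothetical heuristic returning LF index pairs (i, j) as candidate edges.

    An edge is added when LFs i and j agree on at least `thresh` of the
    points where both vote, with at least `min_overlap` such points.
    """
    edges = []
    for i, j in itertools.combinations(range(L.shape[1]), 2):
        both = (L[:, i] != ABSTAIN) & (L[:, j] != ABSTAIN)
        if both.sum() < min_overlap:
            continue
        if (L[both, i] == L[both, j]).mean() >= thresh:
            edges.append((i, j))
    return edges


L_toy = np.array([
    [0, 0, -1],
    [2, 2, 2],
    [-1, -1, 1],
    [1, 1, 2],
])
print(agreement_edges(L_toy))  # [(0, 1)]
```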

#### Toggle display of warnings (JavaScript)

In [None]:
%%javascript
(function(on) {
const e=$( "<a>Setup failed</a>" );
const ns="js_jupyter_suppress_warnings";
var cssrules=$("#"+ns);
if(!cssrules.length) cssrules = $("<style id='"+ns+"' type='text/css'>div.output_stderr { } </style>").appendTo("head");
e.click(function() {
    var s='Showing';  
    cssrules.empty()
    if(on) {
        s='Hiding';
        cssrules.append("div.output_stderr, div[data-mime-type*='.stderr'] { display:none; }");
    }
    e.text(s+' warnings (click to toggle)');
    on=!on;
}).click();
$(element).append(e);
})(true);

### Validation analysis of the label model across dependency structures and training hyperparameters
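The validation cell that follows fits one label model per combination of dependency method, dependency parameter, number of epochs, and learning rate, keeping the best-scoring configuration per method. The sketch below enumerates the same grids with `itertools.product` to make the cost explicit; it should match the `total` counter printed in that cell for the `deps_names` list used there.

```python
import itertools

# Same grids as the validation cell
epochs_list = [50, 100, 500, 1000, 5000][::-1]
sig_list = [0.01, 0.05, 0.1, 0.5, 0.75, 0.9]
thresh_list = [0.1, 0.5, 1, 1.5]
lr_list = [0.01, 0.025, 0.05, 0.1]

# Dependency-parameter grid searched for each method
deps_params = {
    "NM_NP": sig_list,
    "CDGAM": sig_list,
    "Varma_Gold": thresh_list,
    "Varma": thresh_list,
    "Empty": [-1],  # no dependency edges, so nothing to tune
}

configs = [
    (name, param, n_eps, lr)
    for name, params in deps_params.items()
    for param, n_eps, lr in itertools.product(params, epochs_list, lr_list)
]
print(len(configs))  # 420 = (6 + 6 + 4 + 4 + 1) * 5 * 4
```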

In [None]:
L_test = applier.apply(df_test)
Y_test = df_test.label.values

In [None]:
%%time
from snorkel.analysis import metric_score
from snorkel.utils import probs_to_preds

#epochs_list = [50, 100, 500, 1000, 5000, 6000][::-1]
#sig_list = [0.01, 0.05, 0.1, 0.5, 0.6, 0.75, 0.9]
#sig_crazy_list = [0.75, 0.9]
#thresh_list = [0.1, 0.25, 0.5, 1, 1.5, 2]
#lr_list = [0.01, 0.02, 0.05, 0.07, 0.1, 0.2]

epochs_list = [50, 100, 500, 1000, 5000][::-1]
sig_list = [0.01, 0.05, 0.1, 0.5, 0.75, 0.9]
sig_crazy_list = [0.75, 0.9]
thresh_list = [0.1, 0.5, 1, 1.5]
lr_list = [0.01, 0.025, 0.05, 0.1]

deps_names = ['NM_NP', 'CDGAM', 'Varma_Gold', 'Varma', 'Empty']  # dependency methods to compare (edit to change the comparison)

info = {
    'CDM': {'deps_params': sig_list, 'snorkel_eps': epochs_list, 'snorkel_lr': lr_list},
    'NM': {'deps_params': sig_list, 'snorkel_eps': epochs_list, 'snorkel_lr': lr_list},
    'NM_NP': {'deps_params': sig_list, 'snorkel_eps': epochs_list, 'snorkel_lr': lr_list},
    'CDGAM': {'deps_params': sig_list, 'snorkel_eps': epochs_list, 'snorkel_lr': lr_list},
    'Mod_Varma': {'deps_params': thresh_list, 'snorkel_eps': epochs_list, 'snorkel_lr': lr_list},
    'Varma': {'deps_params': thresh_list, 'snorkel_eps': epochs_list, 'snorkel_lr': lr_list},
    'Varma_Gold': {'deps_params': thresh_list, 'snorkel_eps': epochs_list, 'snorkel_lr': lr_list},
    'Empty': {'deps_params': [-1], 'snorkel_eps': epochs_list, 'snorkel_lr': lr_list},
}

f1_based_store = {key: {'param': -1, 'n_eps': -1, 'lr': -1, 'f1': -1} for key in deps_names}

def overall_deps_fn(deps_name, param):
    """Return the LF dependency edges estimated by method `deps_name` with parameter `param`."""
    if deps_name == 'CDM':
        deps = CDM(L_dev, Y_dev, k=3, sig=param, policy='new', verbose=False, return_more_info=False)
    elif deps_name == 'CDGAM':
        deps = CDGAM(L_dev, k=3, sig=param, policy='new', verbose=False, return_more_info=False)
    elif deps_name == 'NM':
        deps = NM(L_dev, Y_dev, k=3, sig=param, policy='old', verbose=False, return_more_info=False)
    elif deps_name == 'NM_NP':
        deps = NM(L_dev, Y_dev, k=3, sig=param, policy='new', verbose=False, return_more_info=False)
    elif deps_name == 'Mod_Varma':
        deps = ModVarma_InCov(L_dev, Y_dev, thresh=param)
    elif deps_name == 'Varma':
        deps = get_varma_edges(L_dev, thresh=param)
    elif deps_name == 'Varma_Gold':
        deps = get_varma_with_gold_edges(L_dev, Y_dev, thresh=param)
    elif deps_name == 'Empty':
        deps = []  # baseline: no dependency edges
    else:
        raise ValueError(f"Unknown dependency method: {deps_name}")
    return deps

ct = 0  # number of configurations fitted so far
total = sum([len(info[deps_name]['deps_params']) for deps_name in deps_names]) * len(epochs_list) * len(lr_list)
for deps_name in deps_names:
    for param in info[deps_name]['deps_params']:
        
        deps = overall_deps_fn(deps_name, param)
        
        for n_eps in info[deps_name]['snorkel_eps']:
            for lr in info[deps_name]['snorkel_lr']:
                
                label_model = Informed_LabelModel(edges = deps, cardinality=3, verbose=True)
                label_model.fit(L_train, seed=12345, lr=lr, log_freq=n_eps/10, n_epochs=n_eps)
                
                probs_dev = label_model.predict_proba(L_dev)
                preds_dev = probs_to_preds(probs_dev)
                # Note: despite its name, `f1` holds dev-set accuracy, not F1
                f1 = label_model.score(L_dev, Y_dev)['accuracy']
                
                if f1>f1_based_store[deps_name]['f1']:
                    print(deps_name, param, deps, n_eps, lr, " | f1: ", f1)
                    f1_based_store[deps_name] = {'param': param, 'n_eps': n_eps, 'lr': lr, 'f1': f1}
                
                ct += 1
                print(ct, " / ", total)


In [None]:
%%time
# test using params from validation
f1_based_test_store = {key: {'f1': -1} for key in deps_names}

for deps_name in deps_names:
    param = f1_based_store[deps_name]['param']
    n_eps = f1_based_store[deps_name]['n_eps']
    lr = f1_based_store[deps_name]['lr']
    
    deps = overall_deps_fn(deps_name, param)
    
    label_model = Informed_LabelModel(edges = deps, cardinality=3, verbose=True)
    label_model.fit(L_train, seed=12345, lr=lr, log_freq=n_eps/10, n_epochs=n_eps)
    
    probs_test = label_model.predict_proba(L_test)
    preds_test = probs_to_preds(probs_test)
    # Note: despite its name, `f1_test` holds test-set accuracy, not F1
    f1_test = label_model.score(L_test, Y_test)['accuracy']

    f1_based_test_store[deps_name] = {'f1': f1_test}

In [None]:
print(pd.DataFrame(f1_based_store))
print(pd.DataFrame(f1_based_test_store))