In [4]:
import torch
import torch.nn as nn
import numpy as np
import os

# ORG-Module

Object Relational Graph is a module that learns to describe an object based on its relationship with others in a video. The algorithm consists many steps and stated in the following order:

1. Apply pretrained object detector to capture severall class-agnostic proposal.
2. The object features is captured on each keyframes.
3. The object features then stored in R, where i is the i-th keyframes, and k is the k-th object.
4. The number of objects extracted from each frames are five objects.
5. The R variable consist of 5 independent object features.
6. Define Object Set R K x d, where K is the number of object nodes, and d is the dimension features.
7. Define A, where A is a relation coefficient matrix between K nodes.
8. Before feeding to A, the R variable is feed to **Fully connected layer** with bias resulting in R'.
9. Then A is the product of fully connected layer between R' and R'T
10. After that, the product is activated using softmax function and named A^
11. Apply the GCN function, R^ = A^ . R . Wr, Where Wr is learnable parameter
12. R^ is the enhanced object features with interaction message between objects

# Develop Side

In [16]:
# the object feats has the dimension of Frames x Objs x features
# with batch dimension it becomes 4-D tensor

feat_dims = 512
k_object = 5

# this means the object is the second object
# of the first frame

r_obj_feats = torch.rand(k_objects, feat_dims)

In [45]:
# based on ORG paper A is equal to:
# φ(R) . transpose(ψ(R))
# where : ...
# φ(R) = R . Wi + bi
# ψ(R) = R . wj + bj

in_features = feat_dims
out_features = feat_dims

sigma_r = nn.Linear(in_features, out_features)
psi_r = nn.Linear(in_features, out_features)
a_softmax = nn.Softmax(dim=1)

w_r = nn.Linear(in_features, out_features, bias=False)

In [46]:
sigma_r_out = sigma_r(r_obj_feats)
psi_r_out = psi_r(r_obj_feats)

In [47]:
a_coeff_mat = torch.matmul(sigma_r_out, torch.t(psi_r_out))

In [48]:
a_hat = a_softmax(a_coeff_mat)

In [49]:
a_hat_mul_r = torch.matmul(a_hat, r_obj_feats)

In [50]:
output = w_r(a_hat_mul_r)

In [53]:
output

tensor([[-0.2951,  0.0876, -0.1391,  ...,  0.0284, -0.2573, -0.5024],
        [-0.2546,  0.1451, -0.1221,  ..., -0.0057, -0.2320, -0.4991],
        [-0.3415,  0.0536, -0.1762,  ...,  0.0159, -0.2983, -0.4954],
        [-0.2284,  0.1758, -0.1091,  ..., -0.0296, -0.2088, -0.4882],
        [-0.2611,  0.1293, -0.1218,  ...,  0.0073, -0.2321, -0.4986]],
       grad_fn=<MmBackward0>)

# Class Side (Alpha)

In [55]:
class ORG(nn.Module):
    
    def __init__(self, feat_dims):
        super(ORG, self).__init__()
        '''
        Object Relational Graph (ORG) is a module that learns 
        to describe an object based on its relationship 
        with others in a video.
        
        Arguments:
            feat_size : The object feature size that obtained from
                        the last fully-connected layer of the backbone
                        of Faster R-CNN, this case is 512
        '''
        
        sigma_r = nn.Linear(feat_dims, feat_dims)
        psi_r = nn.Linear(feat_dims, feat_dims)
        
        a_softmax = nn.Softmax(dim=1)
        
        w_r = nn.Linear(feat_dims, feat_dims, bias=False)
        
    def forward(self, r_obj_feat):
        sigma_r_out = sigma_r(r_obj_feats)
        psi_r_out = psi_r(r_obj_feats)
        
        a_coeff_mat = torch.matmul(sigma_r_out, torch.t(psi_r_out))
        a_hat = a_softmax(a_coeff_mat)
        
        a_hat_mul_r = torch.matmul(a_hat, r_obj_feats)
        output = w_r(a_hat_mul_r)
        
        return output

In [57]:
org_module = ORG(feat_dims)

In [60]:
r_hat = org_module(r_obj_feats)
r_hat

tensor([[-0.2951,  0.0876, -0.1391,  ...,  0.0284, -0.2573, -0.5024],
        [-0.2546,  0.1451, -0.1221,  ..., -0.0057, -0.2320, -0.4991],
        [-0.3415,  0.0536, -0.1762,  ...,  0.0159, -0.2983, -0.4954],
        [-0.2284,  0.1758, -0.1091,  ..., -0.0296, -0.2088, -0.4882],
        [-0.2611,  0.1293, -0.1218,  ...,  0.0073, -0.2321, -0.4986]],
       grad_fn=<MmBackward0>)

# In Practice Using Faster R-CNN Object Features (Beta)