In [26]:
import numpy as np
import torch

In [129]:
# Quick check of loss_RAVG on random inputs. loss_RAVG is defined in a later
# cell of this notebook, so that cell has to be executed before this one.
y = 0.5

# Scalar embeddings for 5 steps, each step holding 2 entities.
E = torch.rand((5, 2))
V = torch.rand((5, 2))

# Random 0/1 step-to-step reference matrix (5 x 5).
R = torch.rand((5, 5)).round()

loss_RAVG(y, R, E, V)

tensor(1.4185)

In [132]:
# TODO: modify function to work with batches and non-scalar embeddings.

def loss_RAVG(y, R, E, V):
    '''
    Custom loss function for reference-aware visual grounding. This function must
    use PyTorch functions for any parameters that are to be trained. Some of the
    embeddings have been modified to match the VisualBERT model.

    The outputs of the VG extension are the alignment scores between the bounding
    box proposals and entities for that particular step. The size of this matrix
    is E x P, where E is the number of entities and P is the number of proposals.

    Source: http://vision.stanford.edu/pdf/huang-buch-2018cvpr.

    Input:
        y: alignment penalty (hyperparameter)
            Weight applied to a step pair that IS linked by a reference. Pairs
            without a reference get weight 1.

        R:
            references
            -----------------------
            The size of the matrix is M x M. M is the number of steps.

            R_lj contains a 1 if there is any backward reference between a_l and a_j,
            and a 0 otherwise. The penalty computation below relies on R being
            binary.

        V:
            visual embeddings
            -----------------------
            The size of the matrix is M x J. M is the number of steps and J is the maximum
            number of entities across all of the steps.

            V_mj contains the visual embedding that corresponds to e_mj entity. This is the
            b_mj bounding box (bounding box with highest alignment score). Since each of the
            steps may have a different number of entities, there may be some padding of 1
            required.

            Note: for now, we assume that the embedding is a scalar.

        E:
            reference aware entity embeddings
            -----------------------
            The size of the matrix is M x J. E_mj contains the reference aware entity embedding
            of e_mj. Since each of the steps may have a different number of entities, there may be
            some padding of -inf required.

            Note: for now, we assume that the embedding is a scalar.

    Output:
        Scalar tensor:
            sum_l sum_m [ g_lm * max(0, S_lm - S_ll) + g_ml * max(0, S_ml - S_ll) ]
        where g is the reference-based penalty (1 without a reference, y with one).

    NOTE(review): a -inf entry in E makes every product it takes part in -inf
    (or nan against a 0 in V), and the sum over j then carries that into S.
    Padded entities probably need to be masked out before the sum - confirm
    before feeding padded inputs.
    '''

    # Outer product between E and V: every possible alignment score across the
    # entire video. scores[l, m, j, k] is the alignment score of e_mj and b_lk
    # (same as equation), so the tensor is indexed (l, m, j, k) and has size
    # M x M x J x J.
    #
    # Ref: stackoverflow.com/questions/24839481/python-matrix-outer-product
    scores = torch.einsum('mj, lk -> lmjk', E, V)

    # The best alignment score between e_mj and b_lk (over all k).
    max_k_align = scores.max(dim=3).values

    # Sum all of the best alignment scores from e_mj (over all j). M x M.
    S_lm = max_k_align.sum(dim=2)

    # Transposed version: S_ml[l, m] == S_lm[m, l].
    S_ml = S_lm.transpose(0, 1)

    # Self alignment scores as an M x 1 column, so that broadcasting subtracts
    # S_lm[l, l] from every entry of row l.
    S_ll = S_lm.diagonal().unsqueeze(1)

    # Reference based penalty: 1 if there is no reference between the two
    # steps, y (hyperparameter) otherwise. For binary R this maps 0 -> 1 and
    # 1 -> y. (Using R * y here would map 0 -> 0 and silently drop every
    # unreferenced pair from the loss.)
    Y_lm = 1 - R * (1 - y)
    Y_ml = Y_lm.transpose(0, 1)

    # Hinge terms. clamp(min=0) is max(0, x) without having to allocate a zero
    # matrix on the matching device / dtype.
    hinge_lm = (S_lm - S_ll).clamp(min=0)
    hinge_ml = (S_ml - S_ll).clamp(min=0)

    # Sum over both l and m.
    loss = (Y_lm * hinge_lm + Y_ml * hinge_ml).sum()

    return loss

In [134]:
def train(model, train_loader, valid_loader, epochs=25, lr=0.001, batch_size=10, y=0.5):
    '''
    Training loop for the model (VisualBERT + extensions).

    Input:
        model: module called as model(videos, transcripts); its parameters are
            optimized with Adam.
        train_loader: iterable yielding (videos, transcripts) batches.
        valid_loader: validation batches (not used yet).
        epochs: number of passes over train_loader.
        lr: learning rate for Adam.
        batch_size: not used yet.
        y: alignment penalty forwarded to loss_RAVG.

    The loss/accuracy arrays below are allocated but not filled in or returned
    yet (see the TODO at the end of the epoch loop).
    '''

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Output losses.
    train_loss = np.zeros(epochs)
    valid_loss = np.zeros(epochs)

    # Output accuracies.
    train_accuracy = np.zeros(epochs)
    valid_accuracy = np.zeros(epochs)

    for epoch in range(epochs):
        for videos, transcripts in train_loader:
            # Note that there is no target since supervision is
            # provided only through the alignment.

            # Zero out any gradients.
            optimizer.zero_grad()

            # Run inference (forward pass).
            outputs = model(videos, transcripts)

            # TODO: Generate R, E, V (from outputs). Until this is implemented
            # the names below are not defined inside this function: they
            # either resolve to notebook globals or raise NameError.
            m_loss_RAVG = loss_RAVG(y, R, E, V)

            # TODO: Loss for GARR.
            m_loss_GARR = 0

            loss = m_loss_RAVG + m_loss_GARR

            # Backpropagation (backward pass).
            loss.backward()

            # Update parameters.
            optimizer.step()

        # TODO: save loss and accuracy at each epoch, plot (and checkpoint).

$$P\left(d_{i j}=(l, k) \mid E, A, B, R\right)=\operatorname{sigmoid}\left(\psi\left(b_{l k}\right)^{T} \phi_{e}^{R}\left(e_{i j}\right)\right)$$

$$\phi_{e}^{R}\left(e_{i j}\right)=\operatorname{wordEmbd}\left(e_{i j}\right)+\phi_{a}^{R}\left(a_{o}\right)$$

$$
\begin{aligned}
\max _{D_{l}} P\left(D_{l} \mid \bar{G}_{l}, B_{l}\right) &>\max _{D_{l}} P\left(D_{l} \mid \bar{G}_{l}, B_{m}\right) \\
\max _{D_{l}} P\left(D_{l} \mid \bar{G}_{l}, B_{l}\right) &>\max _{D_{n}} P\left(D_{n} \mid \bar{G}_{n}, B_{l}\right)
\end{aligned}
$$

$$S_{l m}^{R}=\sum_{j} \max _{k}\left\langle\phi_{e}^{R}\left(e_{m j}\right), \psi_{b}\left(b_{l k}\right)\right\rangle$$

$$
\begin{aligned}
\mathcal{L}_{\text{RA-MIL}}=\sum_{l}\Big[ & \sum_{m} \gamma_{l m} \cdot \max \left(0, S_{l m}^{R}-S_{l l}^{R}+\Delta\right) \\
& +\sum_{m} \gamma_{m l} \cdot \max \left(0, S_{m l}^{R}-S_{l l}^{R}+\Delta\right)\Big]
\end{aligned}
$$

Loss function terminology:
    
        l, m: step indices
        j, k: entity indices
        R: reference resolution edges

        a_m: action step m
        e_mj: j'th entity in step m
        b_lk: bounding box for the k'th entity in step l

        a(R, m, j): action referred to by the entity e_mj

        ψ(l, k):
            visual embedding of bounding box
            -----------------------
            ψ = VisualBERT_embedding(b_lk)

        φA(m):
            action embedding
            -----------------------
            φA = avg_j(VisualBERT_embedding(e_mj))

        φE(R, m, j):
            reference-aware entity embedding
            -----------------------
            φE = word_embedding(e_mj) + φA(a(R, m, j))

        γ(l, m):
            reference-based penalty
            -----------------------
            γ = 1     : if none of the entities in step m (a_m) have a reference to step l (a_l)
            0 < γ < 1 : if at least one entity in step m (a_m) has a reference to step l (a_l)

        score(R, m, j, l, k):
            alignment score between entity (e_mj) and bounding box (b_lk)
            -----------------------
            score = φE(R, m, j) · ψ(l, k)

        S(R, l, m):
            alignment score between steps (a_l and a_m)
            -----------------------
            S = sum_j(max_k(φE(R, m, j) · ψ(l, k)))

        Loss = sum_l
               (
                    sum_m [   γ(l, m) * max(0, S(R, l, m) - S(R, l, l))   ] +
                    sum_m [   γ(m, l) * max(0, S(R, m, l) - S(R, l, l))   ]
               )