In [7]:
from dataset import YouCookII
from dataset import YouCookIICollate
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from model import Model

import numpy as np
import torch
import matplotlib.pyplot as plt

import itertools
import torch
import einops
import torch.nn.functional as F

from transformers import LxmertModel, LxmertTokenizer
from torch import nn
from torch.nn.utils.rnn import pad_sequence

from model import *

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [9]:
ACTION = '[unused3]'

lxmert_tokenizer = LxmertTokenizer.from_pretrained("unc-nlp/lxmert-base-uncased")
lxmert_tokenizer.add_special_tokens({"additional_special_tokens": [ACTION]})
lxmert_tokenizer.encode([ACTION])

lxmert = LxmertModel.from_pretrained("unc-nlp/lxmert-base-uncased")
lxmert.to(device)

ACTION_TOKEN = lxmert_tokenizer.convert_tokens_to_ids(ACTION)

In [41]:
NUM_ACTIONS = 8
MAX_DETECTIONS=20
BATCH_SIZE = 1

DETECTION_EMBEDDING_SIZE = 2048
OUTPUT_EMBEDDING_SIZE = 768
NUM_FRAMES_PER_STEP=5
MAX_DETECTIONS=20
CANDIDATES = NUM_FRAMES_PER_STEP * MAX_DETECTIONS

dataset = YouCookII(NUM_ACTIONS, "/h/sagar/ece496-capstone/datasets/ycii")
collate = YouCookIICollate(MAX_DETECTIONS=MAX_DETECTIONS)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate)

In [11]:
model = Model(device, MAX_DETECTIONS=20)

In [None]:
_, boxes, features, steps, entities, entity_count, _, _ = next(iter(dataloader))

In [44]:
model.eval()

with torch.no_grad():

    for data in dataloader:
        _, boxes, features, steps, entities, entity_count, _, _ = data
        loss_data, VG, RR = model(BATCH_SIZE, NUM_ACTIONS, steps, features, boxes, entities, entity_count)
        
        loss = compute_loss_batched(loss_data)
        accuracy = compute_alignment_accuracy_batched(loss_data)
        
        print("Loss: {}, Accuracy: {}".format(float(loss), accuracy))

tensor(9.0554)
Loss: 1411.52197265625, Accuracy: 0.4489795918367347
tensor(-1.0769)
Loss: 1218.94677734375, Accuracy: 0.5523809523809524
tensor(4.6980)
Loss: 1309.8203125, Accuracy: 0.45918367346938777
tensor(1.7204)
Loss: 1305.0159912109375, Accuracy: 0.5142857142857142
tensor(14.7055)
Loss: 1380.8665771484375, Accuracy: 0.4857142857142857
tensor(7.7888)
Loss: 1084.072509765625, Accuracy: 0.5803571428571429
tensor(5.3244)
Loss: 1290.072998046875, Accuracy: 0.45918367346938777
tensor(-2.6532)
Loss: 1280.8931884765625, Accuracy: 0.5523809523809524
tensor(15.9109)
Loss: 1337.2987060546875, Accuracy: 0.43956043956043955
tensor(6.6066)
Loss: 1459.7196044921875, Accuracy: 0.36507936507936506
tensor(0.7305)
Loss: 1435.6927490234375, Accuracy: 0.48739495798319327
tensor(6.8205)
Loss: 1409.686767578125, Accuracy: 0.3673469387755102
tensor(14.0222)
Loss: 1453.2303466796875, Accuracy: 0.45714285714285713
tensor(6.5387)
Loss: 1272.58984375, Accuracy: 0.5
tensor(11.9544)
Loss: 1462.01123046875, Ac

In [69]:
get_alignment_loss(model, dataloader, BATCH_SIZE)

tensor([[76803.7812]])

In [67]:
def get_alignment_loss(model, dataloader, batch_size):
    loss = 0
    
    with torch.no_grad():
        for data in dataloader:
            _, boxes, features, steps, entities, entity_count, _, _ = data
            loss_data, VG, RR = model(BATCH_SIZE, NUM_ACTIONS, steps, features, boxes, entities, entity_count)
            
            loss = loss + compute_loss_batched(loss_data)
        
    return loss

In [63]:
def compute_alignment_accuracy_batched(loss_data):
    total = 0
    correct = 0
    
    alignment_scores, entity_count, BATCH_SIZE, NUM_ACTIONS, MAX_ENTITIES = loss_data
    
    for batch_idx in range(BATCH_SIZE):
        _alignment_scores = alignment_scores[batch_idx]
        _entity_count = entity_count[batch_idx]
        
        _total, _correct = compute_alignment_accuracy((_alignment_scores, _entity_count, NUM_ACTIONS, MAX_ENTITIES))
        
        total = total + _total
        correct = correct + _correct
            
    return (total, correct)

def compute_alignment_accuracy(loss_data):
    total = 0
    correct = 0
    
    alignment_scores, entity_count, NUM_ACTIONS, MAX_ENTITIES = loss_data
    
    # l: ENTITY_ACTION_ID
    # e: ENTITY_ID
    # m: CANDIDATE_ACTION_ID
    
    for m in range(NUM_ACTIONS):
        for e in range(entity_count[m]):
            for l in range(NUM_ACTIONS):
                if m == l:
                    continue
                    
                aligned = compute_alignment(alignment_scores[m, e, m, :], alignment_scores[m, e, l, :])
                
                if aligned:
                    correct = correct + 1
                    
                total = total + 1
                
    return total, correct
    
def compute_alignment(score_m, score_l):
    '''
        score_m: score between entity from STEP M and candidates from STEP M
        score_l: score between entity from STEP M and candidates from STEP L
        
    '''
    
    m_max = score_m.max()
    l_max = score_l.max()
    
    return (m_max > l_max)

In [57]:
def compute_loss_batched(loss_data, margin=10):
    loss = 0.
    
    alignment_scores, entity_count, BATCH_SIZE, NUM_ACTIONS, MAX_ENTITIES = loss_data
    
    for batch_idx in range(BATCH_SIZE):
        _alignment_scores = alignment_scores[batch_idx]
        _entity_count = entity_count[batch_idx]
        
        loss = loss + compute_loss((_alignment_scores, _entity_count, NUM_ACTIONS, MAX_ENTITIES), margin)
        
    return loss

def compute_loss(loss_data, margin):    
    alignment_scores, entity_count, NUM_ACTIONS, MAX_ENTITIES = loss_data
    
    # Recall the shape of the alignment scores tensor:
    # ENTITY_ACTION_ID * ENTITY * CANDIDATE_ACTION_ID * CANDIDATE.
    
    S = torch.zeros((NUM_ACTIONS, NUM_ACTIONS))
    zero = torch.zeros((1, 1))
    
    # l: ENTITY_ACTION_ID
    # m: CANDIDATE_ACTION_ID
    
    for l in range(NUM_ACTIONS):
        for m in range(NUM_ACTIONS):
            S[l][m] = compute_S(alignment_scores[m, :, l, :], entity_count[m])
                
    loss = 0.
    
    for l in range(NUM_ACTIONS):
        for m in range(NUM_ACTIONS):
            loss = loss + torch.max(S[l][m] - S[l][l] + margin, zero) + torch.max(S[m][l] - S[l][l] + margin, zero)
    
    return loss
        
def compute_S(scores, entity_count):
    '''
        scores: Alignment scores between entities from STEP M and candidates (boxes) from step STEP L.
    '''
    if entity_count == 0:
        return 0.
    
    # Remove padded dimension.
    scores = scores[:entity_count]
    
    S = scores.max(dim=-1)[0].sum()
    
    return S

# Testing

In [None]:
features = features.to(device)
boxes = boxes.to(device)

steps = remove_unused2(steps)

inputs = lxmert_tokenizer(
    steps,
    padding="longest",
    truncation=False,
    return_token_type_ids=True,
    return_attention_mask=True,
    add_special_tokens=True,
    return_tensors="pt"
)

inputs.input_ids = inputs.input_ids.to(device)
inputs.attention_mask = inputs.attention_mask.to(device)
inputs.token_type_ids = inputs.token_type_ids.to(device)

output = lxmert(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    visual_feats=features,
    visual_pos=boxes,
    token_type_ids=inputs.token_type_ids,
    return_dict=True,
    output_attentions=True
)

In [None]:
entity_idx = get_ent_inds(model, entities, steps)

linguistic_embeddings = get_entity_embeddings(output['language_output'], entity_idx)
vision_embeddings = output['vision_output']

In [None]:
split_sizes = torch.tensor(entity_count).flatten().tolist()
entity_embeddings = linguistic_embeddings.split(split_sizes)

In [None]:
NUM_ACTIONS = num_actions 

E = pad_sequence(entity_embeddings, batch_first=True)
E = E.reshape(-1, NUM_ACTIONS, E.shape[1], E.shape[2])

max_entities = E.shape[1]

In [None]:
V = vision_embeddings.reshape(BATCH_SIZE, NUM_ACTIONS, CANDIDATES, -1)

In [None]:
VG_scores = torch.einsum('bacs, baes -> baec', V, E)
VG_scores_max, VG_scores_index = VG_scores.max(dim=-1)

In [None]:
alignment_scores = torch.einsum('bqcs, bwes -> bweqc', V, E)