In [1]:
import random
import json
import numpy as np
import pandas as pd

import skimage

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

from torchvision.models._utils import IntermediateLayerGetter

from models.modeling.deeplab import *
from dataloader.talk2car import *

from PIL import Image
from skimage.transform import resize

from models.model import JointModel

from utils.im_processing import *
from utils.metrics import *

from collections import Counter
from nltk.corpus import stopwords

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
class Args:
    lr = 3e-4
    batch_size = 64
    num_workers = 4
    image_encoder = "deeplabv3_plus"
    num_layers = 1
    num_encoder_layers = 1
    dropout = 0.25
    skip_conn = False
    model_path = "./saved_model/talk2car/baseline_drop_0.25_bs_64_el_1_sl_40_bce_0.49785.pth"
    loss = "bce"
    dataroot = "/ssd_scratch/cvit/kanishk/"
    glove_path = "/ssd_scratch/cvit/kanishk/glove/"
    dataset = "talk2car"
    task = "talk2car"
    split = "val"
    seq_len = 40
    image_dim = 448
    mask_dim = 448
    mask_thresh = 0.3
    area_thresh = 0.4
    topk = 10
    metric = "pointing_game"

args = Args()

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_gpu = torch.cuda.device_count()
print(f'{device} being used with {n_gpu} GPUs!!')

cuda being used with 2 GPUs!!


In [5]:
print("Initializing dataset")

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
to_tensor = transforms.ToTensor()
resize = transforms.Resize((args.image_dim, args.image_dim))


val_dataset = Talk2Car(
    root=args.dataroot,
    split=args.split,
    transform=transforms.Compose([resize, to_tensor, normalize]),
    mask_transform=transforms.Compose([ResizeAnnotation(args.mask_dim)]),
    glove_path=args.glove_path
)

val_loader = DataLoader(
    val_dataset, shuffle=True, batch_size=1, num_workers=0, pin_memory=True
)

data_len = val_dataset.__len__()
print(f'Length of dataset: {data_len}')

Initializing dataset
Length of dataset: 1163


In [6]:
return_layers = {"layer2": "layer2", "layer3": "layer3", "layer4": "layer4"}

model = DeepLab(num_classes=21, backbone="resnet", output_stride=16)
model.load_state_dict(torch.load("./models/deeplab-resnet.pth.tar")["state_dict"])

image_encoder = IntermediateLayerGetter(model.backbone, return_layers)

for param in image_encoder.parameters():
    param.requires_grad_(False)

In [7]:
in_channels = 2048
out_channels = 512
stride = 2

joint_model = JointModel(
    in_channels=in_channels,
    out_channels=out_channels,
    stride=stride,
    num_layers=args.num_layers,
    num_encoder_layers=args.num_encoder_layers,
    dropout=args.dropout,
    skip_conn=args.skip_conn,
    mask_dim=args.mask_dim,
)

if n_gpu > 1:
    image_encoder = nn.DataParallel(image_encoder)
    joint_model = nn.DataParallel(joint_model)

state_dict = torch.load(args.model_path)
if "state_dict" in state_dict:
    state_dict = state_dict["state_dict"]
joint_model.load_state_dict(state_dict) 

joint_model.to(device)
image_encoder.to(device)

image_encoder.eval();
joint_model.eval();

In [8]:
def compute_mask_IOU(masks, target, thresh=0.3):
    assert(target.shape[-2:] == masks.shape[-2:])
    temp = ((masks>thresh) * target)
    intersection = temp.sum()
    union = (((masks>thresh) + target) - temp).sum()
    return intersection, union

def meanIOU(m, gt, t):
    temp = ((m > t)*gt)
    inter = temp.sum()
    union = ((m > t) + gt - temp).sum()
    return inter/union

In [9]:
def intersection(command, key_words):
    count = 0
    words = []
    for key in key_words:
        if key in command:
            words.append(key)
            count += 1
    return count, words

### Evaluation Metric

In [10]:
args.metric = "recall_at_k"
args.topk = 5000

args.mask_thresh = 0.3

image_encoder.eval()
joint_model.eval()

print(f'{args.metric} for {args.split} split')

exp_params = []

if args.metric == "intersection_at_t":
    exp_params = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
elif args.metric == "recall_at_k":
    exp_params = [1, 5, 10, 50, 100, 500, 1000, 5000]
else:
    exp_params = [0]

for param in exp_params:
    total_inter = 0
    total_union = 0

    total_accuracy = 0

    mean_IOU = 0
    total_accuracy = 0

    data_len = len(val_loader)

    n_iter = 0

    base_correct = 0

    for step, batch in enumerate(val_loader):

        img = batch["image"].cuda(non_blocking=True)

        phrase = batch["phrase"].cuda(non_blocking=True)
        phrase_mask = batch["phrase_mask"].cuda(non_blocking=True)
        index = batch["index"]

        gt_mask = batch["seg_mask"]
        gt_mask = gt_mask.squeeze(dim=1)

        _, h, w = gt_mask.shape

        if gt_mask[0, w//2, h//2] > 0:
            base_correct += 1

        batch_size = img.shape[0]
        img_mask = torch.ones(batch_size, 14 * 14, dtype=torch.int64).cuda(
            non_blocking=True
        )

        orig_phrase = batch["orig_phrase"][0]
        phrase_len = len(orig_phrase.split())

        with torch.no_grad():
            img = image_encoder(img)

        output_mask = joint_model(img, phrase, img_mask, phrase_mask)
        output_mask = output_mask.detach().cpu()

        inter, union = compute_batch_IOU(output_mask, gt_mask, args.mask_thresh)

        total_inter += inter.sum().item()
        total_union += union.sum().item()

        accuracy = 0
        if args.metric == "pointing_game":
            accuracy += pointing_game(output_mask, gt_mask)
        elif args.metric == "intersection_at_t":
            accuracy += intersection_at_t(output_mask, gt_mask, args.mask_thresh, param)
        elif args.metric == "recall_at_k":
            accuracy += recall_at_k(output_mask, gt_mask, param)
        elif args.metric == "dice_score":
            accuracy += dice_score(output_mask, gt_mask, args.mask_thresh)
        total_accuracy += accuracy

        score = 0 if union.item() == 0 else inter.item() / union.item()

        mean_IOU += score

        total_score = total_inter / total_union

    overall_IOU = total_inter / total_union
    mean_IOU = mean_IOU / data_len
    final_accuracy = total_accuracy / data_len

    if args.metric == "pointing_game":
        center_accuracy = base_correct / data_len
        print(f'Center Accuracy: {center_accuracy}')

    if args.metric == "intersection_at_t":
        print(f"Area_Thresh={param}: Accuracy:{final_accuracy}, Overall_IOU: {overall_IOU}, Mean_IOU: {mean_IOU}")
    elif args.metric =="recall_at_k":
        print(f"K={param}: Accuracy:{final_accuracy}, Overall_IOU: {overall_IOU}, Mean_IOU: {mean_IOU}")
    else:
        print(f"Pointing Game Accuracy:{final_accuracy}, Overall_IOU: {overall_IOU}, Mean_IOU: {mean_IOU}")

recall_at_k for val split
K=1: Accuracy:0.49785038693035255, Overall_IOU: 0.19882105059992924, Mean_IOU: 0.17025304851878786
K=5: Accuracy:0.5184866723989682, Overall_IOU: 0.19882105059992924, Mean_IOU: 0.17025304851878806
K=10: Accuracy:0.5270851246775581, Overall_IOU: 0.19882105059992924, Mean_IOU: 0.17025304851878817
K=50: Accuracy:0.5528804815133276, Overall_IOU: 0.19882105059992924, Mean_IOU: 0.17025304851878806
K=100: Accuracy:0.5692175408426483, Overall_IOU: 0.19882105059992924, Mean_IOU: 0.17025304851878795
K=500: Accuracy:0.6448839208942391, Overall_IOU: 0.19882105059992924, Mean_IOU: 0.17025304851878797
K=1000: Accuracy:0.6964746345657782, Overall_IOU: 0.19882105059992924, Mean_IOU: 0.17025304851878803
K=5000: Accuracy:0.8632846087704213, Overall_IOU: 0.19882105059992924, Mean_IOU: 0.1702530485187879


### Command Length Analysis 

In [11]:
args.metric = "pointing_game"
metric = args.metric

args.topk = 200

args.mask_thresh = 0.3

image_encoder.eval()
joint_model.eval()

total_inter = 0
total_union = 0

total_accuracy = 0

result_map = {
    '0-10': {},
    '10-20': {},
    '20-': {},
}

area_thresh = args.metric

mean_IOU = 0
total_accuracy = 0

data_len = len(val_loader)

n_iter = 0

base_correct = 0

for step, batch in enumerate(val_loader):

    img = batch["image"].cuda(non_blocking=True)

    phrase = batch["phrase"].cuda(non_blocking=True)
    phrase_mask = batch["phrase_mask"].cuda(non_blocking=True)
    index = batch["index"]

    gt_mask = batch["seg_mask"]
    gt_mask = gt_mask.squeeze(dim=1)
    
    _, h, w = gt_mask.shape
    
    if gt_mask[0, w//2, h//2] > 0:
        base_correct += 1

    batch_size = img.shape[0]
    img_mask = torch.ones(batch_size, 14 * 14, dtype=torch.int64).cuda(
        non_blocking=True
    )
    
    orig_phrase = batch["orig_phrase"][0]
    phrase_len = len(orig_phrase.split())
    
    with torch.no_grad():
        img = image_encoder(img)

    output_mask = joint_model(img, phrase, img_mask, phrase_mask)
    output_mask = output_mask.detach().cpu()
    
    # count, inter_words = intersection(orig_phrase)

    inter, union = compute_batch_IOU(output_mask, gt_mask, args.mask_thresh)

    total_inter += inter.sum().item()
    total_union += union.sum().item()

    accuracy = 0
    if args.metric == "pointing_game":
        accuracy += pointing_game(output_mask, gt_mask)
    elif args.metric == "recall_at_k":
        accuracy += recall_at_k(output_mask, gt_mask, args.topk)
    elif args.metric == "dice_score":
        accuracy += dice_score(output_mask, gt_mask, args.mask_thresh)
    total_accuracy += accuracy

    if phrase_len < 10:
        if metric not in result_map['0-10']:
            result_map['0-10'][metric] = []
        result_map['0-10'][metric].append(accuracy)
    elif phrase_len < 20:
        if metric not in result_map['10-20']:
            result_map['10-20'][metric] = []
        result_map['10-20'][metric].append(accuracy)
    else:
        if metric not in result_map['20-']:
            result_map['20-'][metric] = []
        result_map['20-'][metric].append(accuracy)
        
    score = 0 if union.item() == 0 else inter.item() / union.item()

    mean_IOU += score

    total_score = total_inter / total_union

overall_IOU = total_inter / total_union
mean_IOU = mean_IOU / data_len
final_accuracy = total_accuracy / data_len

if args.metric == "pointing_game":
    center_accuracy = base_correct / data_len
    print(f'Center Accuracy: {center_accuracy}')

print(f'{args.metric} for {args.split} split')
print(f"Accuracy:{final_accuracy}, Overall_IOU: {overall_IOU}, Mean_IOU: {mean_IOU}, Total: {n_iter}, ")

Center Accuracy: 0.0025795356835769563
pointing_game for val split
Accuracy:0.49785038693035255, Overall_IOU: 0.19882105059992924, Mean_IOU: 0.17025304851878797, Total: 0, 


In [12]:
print(f'Results at mask_thresh: {args.mask_thresh}')

refined_map = {}
for key1 in result_map:
    refined_map[key1] = {}
    total_len = 0
    for key2 in result_map[key1]:
        total_len += len(result_map[key1][key2])
        accuracy = torch.tensor(result_map[key1][key2])
        refined_map[key1][key2] = accuracy.mean().item()
    refined_map[key1]['count'] = total_len // len(result_map[key1])

df_result = pd.DataFrame.from_dict(refined_map,orient='index')
df_result.index.name = 'phrase_len'

df_result

Results at mask_thresh: 0.3


Unnamed: 0_level_0,pointing_game,count
phrase_len,Unnamed: 1_level_1,Unnamed: 2_level_1
0-10,0.520958,501
10-20,0.48552,587
20-,0.44,75


### Command Type Analysis

In [13]:
action_words = ["stop", "slow", "wait", "park", "speed", "change", "turn", "follow", "u-turn", "straight", "pull", "switch"]

In [14]:
args.metric = "pointing_game"
metric = args.metric

args.topk = 200

args.mask_thresh = 0.3

image_encoder.eval()
joint_model.eval()

total_inter = 0
total_union = 0

total_accuracy = 0

result_map = {key: {} for key in action_words}

mean_IOU = 0
total_accuracy = 0

data_len = len(val_loader)

n_iter = 0

base_correct = 0

for step, batch in enumerate(val_loader):

    img = batch["image"].cuda(non_blocking=True)

    phrase = batch["phrase"].cuda(non_blocking=True)
    phrase_mask = batch["phrase_mask"].cuda(non_blocking=True)
    index = batch["index"]
    
    gt_mask = batch["seg_mask"]
    gt_mask = gt_mask.squeeze(dim=1)
    
    _, h, w = gt_mask.shape
    
    if gt_mask[0, w//2, h//2] > 0:
        base_correct += 1

    batch_size = img.shape[0]
    img_mask = torch.ones(batch_size, 14 * 14, dtype=torch.int64).cuda(
        non_blocking=True
    )
    
    orig_phrase = batch["orig_phrase"][0]
    phrase_len = len(orig_phrase.split())
    
    with torch.no_grad():
        img = image_encoder(img)

    output_mask = joint_model(img, phrase, img_mask, phrase_mask)
    output_mask = output_mask.detach().cpu()
    
    count, inter_words = intersection(orig_phrase, action_words)

    inter, union = compute_batch_IOU(output_mask, gt_mask, args.mask_thresh)

    total_inter += inter.sum().item()
    total_union += union.sum().item()

    accuracy = 0
    if args.metric == "pointing_game":
        accuracy += pointing_game(output_mask, gt_mask)
    elif args.metric == "intersection_at_t":
        accuracy += intersection_at_t(output_mask, gt_mask, args.mask_thresh, args.area_thresh)
    elif args.metric == "recall_at_k":
        accuracy += recall_at_k(output_mask, gt_mask, args.topk)
    elif args.metric == "dice_score":
        accuracy += dice_score(output_mask, gt_mask, args.mask_thresh)
    total_accuracy += accuracy

    for word in action_words:
        if word in inter_words:
            if metric not in result_map[word]:
                result_map[word][metric] = []
            result_map[word][metric].append(accuracy)

    score = 0 if union.item() == 0 else inter.item() / union.item()

    mean_IOU += score

    total_score = total_inter / total_union

overall_IOU = total_inter / total_union
mean_IOU = mean_IOU / data_len
final_accuracy = total_accuracy / data_len

print(f'{args.metric} for {args.split} split')

if args.metric == "pointing_game":
    center_accuracy = base_correct / data_len
    print(f'Center Accuracy: {center_accuracy}')

print(f"Accuracy:{final_accuracy}, Overall_IOU: {overall_IOU}, Mean_IOU: {mean_IOU}, Total: {n_iter}, ")

pointing_game for val split
Center Accuracy: 0.0025795356835769563
Accuracy:0.49785038693035255, Overall_IOU: 0.19882105059992924, Mean_IOU: 0.170253048518788, Total: 0, 


In [15]:
print(f'Results at mask_thresh: {args.mask_thresh}')

refined_map = {}
for key1 in result_map:
    refined_map[key1] = {}
    total_len = 0
    for key2 in result_map[key1]:
        total_len += len(result_map[key1][key2])
        accuracy = torch.tensor(result_map[key1][key2])
        refined_map[key1][key2] = accuracy.mean().item()
    if len(result_map[key1]) > 0:
        refined_map[key1]['count'] = total_len // len(result_map[key1])

df_result = pd.DataFrame.from_dict(refined_map,orient='index')
df_result.index.name = 'action'

df_result

Results at mask_thresh: 0.3


Unnamed: 0_level_0,pointing_game,count
action,Unnamed: 1_level_1,Unnamed: 2_level_1
stop,0.510067,149
slow,0.565789,76
wait,0.477273,44
park,0.438127,299
speed,0.538462,13
change,0.416667,24
turn,0.506073,247
follow,0.488889,135
u-turn,0.333333,18
straight,0.384615,39


### Temporal Vs Non-Temporal Words

In [16]:
temporal_words = ["once", "then", "when", "sometime", "minutes", "possible", "safe", "wait", "while", "check", "before", "until", "after", "soon", "but", "slow", "follow", "u-turn", "you turn",]
action_words = ["stop", "slow", "wait", "park", "speed", "change", "turn", "follow", "u-turn", "straight", "pull", "switch", "pick"]

In [17]:
args.metric = "pointing_game"
metric = args.metric

args.topk = 200

args.mask_thresh = 0.3
args.area_thresh = 0.4

image_encoder.eval()
joint_model.eval()

total_inter = 0
total_union = 0

total_accuracy = 0

area_thresh = args.metric

result_map = {
    'non-temporal': {},
    'temporal': {},
}

mean_IOU = 0
total_accuracy = 0

data_len = len(val_loader)

n_iter = 0

for step, batch in enumerate(val_loader):

    img = batch["image"].cuda(non_blocking=True)

    phrase = batch["phrase"].cuda(non_blocking=True)
    phrase_mask = batch["phrase_mask"].cuda(non_blocking=True)
    index = batch["index"]

    gt_mask = batch["seg_mask"]
    gt_mask = gt_mask.squeeze(dim=1)
    
    _, h, w = gt_mask.shape
    
    if gt_mask[0, w//2, h//2] > 0:
        base_correct += 1

    batch_size = img.shape[0]
    img_mask = torch.ones(batch_size, 14 * 14, dtype=torch.int64).cuda(
        non_blocking=True
    )
    
    orig_phrase = batch["orig_phrase"][0]
    phrase_len = len(orig_phrase.split())
    
    with torch.no_grad():
        img = image_encoder(img)

    output_mask = joint_model(img, phrase, img_mask, phrase_mask)
    output_mask = output_mask.detach().cpu()
    
    count_temp, inter_words = intersection(orig_phrase, temporal_words)
    count_action, inter_words = intersection(orig_phrase, action_words)

    inter, union = compute_batch_IOU(output_mask, gt_mask, args.mask_thresh)

    total_inter += inter.sum().item()
    total_union += union.sum().item()

    accuracy = 0
    if args.metric == "pointing_game":
        accuracy += pointing_game(output_mask, gt_mask)
    elif args.metric == "intersection_at_t":
        accuracy += intersection_at_t(output_mask, gt_mask, args.mask_thresh, area_thresh)
    elif args.metric == "recall_at_k":
        accuracy += recall_at_k(output_mask, gt_mask, args.topk)
    elif args.metric == "dice_score":
        accuracy += dice_score(output_mask, gt_mask, args.mask_thresh)
    total_accuracy += accuracy

    if count_temp > 0 or count_action > 1:
        if area_thresh not in result_map['temporal']:
            result_map['temporal'][area_thresh] = []
        result_map['temporal'][area_thresh].append(accuracy)
    else:
        if area_thresh not in result_map['non-temporal']:
            result_map['non-temporal'][area_thresh] = []
        result_map['non-temporal'][area_thresh].append(accuracy)

    score = 0 if union.item() == 0 else inter.item() / union.item()

    mean_IOU += score

    total_score = total_inter / total_union

overall_IOU = total_inter / total_union
mean_IOU = mean_IOU / data_len
final_accuracy = total_accuracy / data_len

print(f'{args.metric} for {args.split} split')

if args.metric == "pointing_game":
    center_accuracy = base_correct / data_len
    print(f'Center Accuracy: {center_accuracy}')

print(f"Accuracy:{final_accuracy}, Overall_IOU: {overall_IOU}, Mean_IOU: {mean_IOU}, Total: {n_iter}, ")

pointing_game for val split
Center Accuracy: 0.005159071367153913
Accuracy:0.49785038693035255, Overall_IOU: 0.19882105059992924, Mean_IOU: 0.17025304851878814, Total: 0, 


In [18]:
print(f'Results at mask_thresh: {args.mask_thresh}')

refined_map = {}
for key1 in result_map:
    refined_map[key1] = {}
    total_len = 0
    for key2 in result_map[key1]:
        total_len += len(result_map[key1][key2])
        accuracy = torch.tensor(result_map[key1][key2])
        refined_map[key1][key2] = accuracy.mean().item()
    refined_map[key1]['count'] = total_len // len(result_map[key1])

df_result = pd.DataFrame.from_dict(refined_map,orient='index')
df_result.index.name = 'command_type'

df_result

Results at mask_thresh: 0.3


Unnamed: 0_level_0,pointing_game,count
command_type,Unnamed: 1_level_1,Unnamed: 2_level_1
non-temporal,0.493789,644
temporal,0.50289,519
