In [None]:
import sys

import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F

import cv2
from PIL import Image
from torchvision import transforms
from pipelines.video_action_recognition_config import get_cfg_defaults
from models.tuber_ava import build_model
from glob import glob
import json
import datasets.video_transforms as T
import random

def read_label_map(label_map_path):
    """Parse a pbtxt-style label map into an ``{id: name}`` dictionary.

    The file is scanned line by line. Each line is split at its first colon:
    a key ending in ``id`` supplies an integer id, a key containing ``name``
    supplies the label text (single quotes are removed). As soon as both an
    id and a name have been seen they are stored as one entry and the pair is
    reset, so the two fields may appear in either order. Lines without a
    colon (block openers/closers) and unrelated fields are ignored.

    Args:
        label_map_path: Path of the label-map text file.

    Returns:
        dict mapping integer ids to label names. The extra entry
        ``81 -> "happens"`` is always present, even for an empty file.
    """
    item_id = None
    item_name = None
    items = {}

    with open(label_map_path, "r") as file:
        for line in file:
            # The original called line.replace(" ", "") without using the
            # result. Whitespace is now stripped from the key only, so the
            # label text itself keeps its inner spaces.
            key, colon, value = line.partition(":")
            if not colon:
                continue
            key = key.strip()

            if key.endswith("id"):
                item_id = int(value.strip())
            elif "name" in key:
                item_name = value.replace("'", "").strip()

            if item_id is not None and item_name is not None:
                items[item_id] = item_name
                item_id = None
                item_name = None

    # Set once after parsing instead of on every input line.
    items[81] = "happens"

    return items

# Label id -> name table; later cells index it with ``items[i+1]`` to turn
# positions of a multi-hot class vector into readable label names.
items = read_label_map("../assets/ava_action_list_v2.1.pbtxt")

def make_transforms(image_set, cfg):
    """Build the clip transform pipeline for one data split.

    Args:
        image_set: ``'train'``, ``'val'`` or ``'visual'``.
        cfg: Config object; ``cfg.CONFIG.DATA.IMG_SIZE`` is passed to the
            crop/resize transform and printed.

    Returns:
        A ``T.Compose``. ``'train'`` adds horizontal flip, random size crop
        and colour jitter before tensor conversion + normalisation;
        ``'val'`` and ``'visual'`` both resize and then convert + normalise.

    Raises:
        ValueError: if ``image_set`` is none of the three names above.
    """
    to_normalized_tensor = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    print("transform image crop: {}".format(cfg.CONFIG.DATA.IMG_SIZE))

    if image_set == 'train':
        augmentations = [
            T.RandomHorizontalFlip(),
            T.RandomSizeCrop_Custom(cfg.CONFIG.DATA.IMG_SIZE),
            T.ColorJitter(),
            to_normalized_tensor,
        ]
        return T.Compose(augmentations)

    # Evaluation and visualisation share the same deterministic pipeline.
    if image_set in ('val', 'visual'):
        deterministic = [
            T.Resize_Custom(cfg.CONFIG.DATA.IMG_SIZE),
            to_normalized_tensor,
        ]
        return T.Compose(deterministic)

    raise ValueError(f'unknown {image_set}')

def _load_module_checkpoint(net, checkpoint, banner):
    """Copy matching ``checkpoint['model']`` tensors into ``net`` and report coverage.

    Checkpoint keys are shortened by their first 7 characters before being
    compared with the model's own state-dict keys; this assumes every key
    carries a ``"module."`` prefix (presumably saved from a DataParallel /
    DistributedDataParallel wrapper - confirm). Keys without that prefix
    would be mis-sliced.

    Args:
        net: Model whose weights are overwritten in place.
        checkpoint: Loaded checkpoint dict with a ``'model'`` entry.
        banner: Line printed before the three statistics.

    Returns:
        Tuple ``(net_state, matched, unused, missing)``:
        the updated state dict that was loaded, the checkpoint tensors that
        were used, the checkpoint tensors with no counterpart in ``net``,
        and the ``net`` tensors that the checkpoint does not provide.
    """
    prefix = "module."
    cut = len(prefix)
    net_state = net.state_dict()
    ckpt_state = checkpoint['model']

    print(banner)
    matched = {k[cut:]: v for k, v in ckpt_state.items() if k[cut:] in net_state}
    # Fixed: the first copy keyed this dict with k[:7] (always "module."),
    # collapsing every unused tensor into one entry.
    unused = {k[cut:]: v for k, v in ckpt_state.items() if k[cut:] not in net_state}
    # Fixed: two copies looked up "module" + k (missing dot), so every
    # parameter was reported as not found.
    missing = {k: v for k, v in net_state.items() if prefix + k not in ckpt_state}
    print("# successfully loaded model layers:", len(matched))
    print("# unused model layers:", len(unused))
    print("# not found layers:", len(missing))

    net_state.update(matched)
    net.load_state_dict(net_state)
    return net_state, matched, unused, missing


# Original TubeR model.
cfg = get_cfg_defaults()
cfg.merge_from_file("./configuration/TubeR_CSN50_AVA21.yaml")
model, _, _ = build_model(cfg)

# Current model; note that each import below rebinds the name `build_model`.
from models.dab_conv_trans import build_model
cfg2 = get_cfg_defaults()
cfg2.merge_from_file("./configuration/Dab_conv_trans_CSN152_AVA22.yaml")
current_model, _, _ = build_model(cfg2)

# Baseline model.
from models.dab_baseline import build_model
cfg3 = get_cfg_defaults()
cfg3.merge_from_file("./configuration/Dab_hier_CSN50_AVA22.yaml")
baseline_model, _, _ = build_model(cfg3)

# map_location="cpu" keeps the checkpoints off the GPUs (and loadable on a
# machine whose device layout differs from the one they were saved on);
# load_state_dict copies the values into the model parameters regardless.
og_checkpoint = torch.load("../pretrained_models/main/tuber.pth", map_location="cpu")
# my_checkpoint = torch.load("../pretrained_models/main/dab_hoper.pth")
my_checkpoint = torch.load("/mnt/video_nfs4/users/jinsung/results/tubelet-transformer/AVA_Tuber/Dab_conv_trans_CSN152_270-251/checkpoints/ckpt_epoch_06.pth", map_location="cpu")
baseline_checkpoint = torch.load("../pretrained_models/main/baseline.pth", map_location="cpu")

model_dict, pretrained_dict, unused_dict, not_found_dict = _load_module_checkpoint(
    model, og_checkpoint, "---------og------------")

curr_model_dict, pretrained_dict, unused_dict, not_found_dict = _load_module_checkpoint(
    current_model, my_checkpoint, "---------mine------------")

baseline_model_dict, pretrained_dict, unused_dict, not_found_dict = _load_module_checkpoint(
    baseline_model, baseline_checkpoint, "---------baseline------------")

print("--------- all models are successfully loaded ------------")


# NOTE: this rebinds `transforms`, shadowing the `torchvision.transforms`
# module imported at the top of the notebook. The name is kept because a
# later cell calls `transforms(imgs, target)`.
transforms=make_transforms("val", cfg2)

# NOTE(review): `model` is not switched to eval mode here (the call is
# commented out), so it keeps whatever mode build_model left it in - confirm
# this is intended before comparing its attention maps with the other two.
# model.eval()
current_model.eval()
baseline_model.eval()
# sample_image1_path = "/mnt/tmp/frames/xeGWXqSvC-8/xeGWXqSvC-8_000360.jpg" #False
# sample_image2_path = "/mnt/tmp/frames/CMCPhm2L400/CMCPhm2L400_011200.jpg" #False
# sample_image3_path = "/mnt/tmp/frames/Gvp-cj3bmIY/Gvp-cj3bmIY_024750.jpg" #True

# Validation annotations; `video_frame_bbox` is read by load_annotation().
# '/home/nsml/assets/ava_{}_v21.json'
# Opened in a `with` block so the file handle is closed after parsing.
with open(cfg2.CONFIG.DATA.ANNO_PATH.format("val")) as anno_file:
    val_bbox_json = json.load(anno_file)
video_frame_bbox = val_bbox_json["video_frame_bbox"]

def sim_matrix(a, b, eps=1e-8):
    """Batched cosine similarity between two sets of vectors.

    a: hs x bs x dim
    b: nq x bs x dim
    output: bs x hs x nq

    Vector norms are clamped from below by ``eps`` before dividing, so
    all-zero vectors produce zero similarity instead of NaN.
    """
    # Put the batch axis first, as torch.bmm expects.
    lhs = a.permute(1, 0, 2)
    rhs = b.permute(1, 0, 2)

    lhs_len = lhs.norm(dim=2, keepdim=True).clamp(min=eps)
    rhs_len = rhs.norm(dim=2, keepdim=True).clamp(min=eps)

    lhs_unit = lhs / lhs_len
    rhs_unit = rhs / rhs_len
    return torch.bmm(lhs_unit, rhs_unit.transpose(1, 2))

def load_annotation(sample_id, video_frame_list): # (list marking the key frames of val or train)
    """Build the detection target dict for one annotated key frame.

    Args:
        sample_id: Key into the module-level ``video_frame_bbox`` mapping
            (loadvideo passes a ``"<video>,<second>"`` string).
        video_frame_list: List of frame image paths; only the first one is
            read, to obtain the original frame height and width.

    Returns:
        dict with ``image_id`` (``[sample_id with ',' replaced by '_', 16]``),
        ``boxes`` (float tensor, N x 5), ``raw_boxes`` (``boxes`` with one
        zero column prepended when N > 0), ``labels`` (float tensor, N x 80
        multi-hot) and ``orig_size`` / ``size`` (``[int(nh), int(nw)]``).
    """

    num_classes = 80
    boxes, classes = [], []
    target = {}

    # NOTE(review): cv2.imread returns None for an unreadable file, in which
    # case the .shape access below raises AttributeError.
    first_img = cv2.imread(video_frame_list[0])

    # Target size: the shorter side becomes 256, the other side is scaled by
    # the aspect ratio (kept as a float here).
    oh = first_img.shape[0]
    ow = first_img.shape[1]
    if oh <= ow:
        nh = 256
        nw = 256 * (ow / oh)
    else:
        nw = 256
        nh = 256 * (oh / ow)

    # 32 // 2 == 16: stored as the first column of every box and as key_pos.
    p_t = int(32 // 2)
    key_pos = p_t

    anno_entity = video_frame_bbox[sample_id]

    for i, bbox in enumerate(anno_entity["bboxes"]):
        # Multi-hot label vector built from the class indices in "acts".
        label_tmp = np.zeros((num_classes, ))
        acts_p = anno_entity["acts"][i]
        for l in acts_p:
            label_tmp[l] = 1

        # Boxes without any label are dropped.
        if np.sum(label_tmp) == 0: continue
        # Scale the stored coordinates by the target width/height and convert
        # to integers (presumably the stored values are normalised to [0, 1] -
        # confirm against the annotation file).
        p_x = np.int_(bbox[0] * nw)
        p_y = np.int_(bbox[1] * nh)
        p_w = np.int_(bbox[2] * nw)
        p_h = np.int_(bbox[3] * nh)

        boxes.append([p_t, p_x, p_y, p_w, p_h])
        classes.append(label_tmp)

    boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 5)
    # NOTE(review): with 5 columns, 1::3 selects columns 1 and 4 and 2::3
    # selects only column 2, so column 3 is never clamped and column 4 is
    # clamped by the width bound. Verify this is the intended column pairing.
    boxes[:, 1::3].clamp_(min=0, max=int(nw))
    boxes[:, 2::3].clamp_(min=0, max=nh)

    if boxes.shape[0]:
        # Prepend one zero-valued column to every row.
        raw_boxes = F.pad(boxes, (1, 0, 0, 0), value=0)
    else:
        raw_boxes = boxes
    classes = np.array(classes)
    classes = torch.as_tensor(classes, dtype=torch.float32).reshape(-1, num_classes)

    target["image_id"] = [str(sample_id).replace(",", "_"), key_pos]
    target['boxes'] = boxes
    target['raw_boxes'] = raw_boxes
    target["labels"] = classes
    target["orig_size"] = torch.as_tensor([int(nh), int(nw)])
    target["size"] = torch.as_tensor([int(nh), int(nw)])
    # self.index_cnt = self.index_cnt + 1

    return target

def loadvideo(start_img, vid, frame_key):
    """Load a 32-frame clip (every 2nd frame) plus its annotation target.

    Args:
        start_img: Index of the first frame to sample.
        vid: Video id; frames are looked up under ``/mnt/tmp/frames/<vid>``.
        frame_key: Annotation key forwarded to ``load_annotation``.

    Returns:
        ``(frames, target)`` where ``frames`` is a list of resized PIL images
        and ``target`` the dict from ``load_annotation``; ``([], [])`` when
        the frame directory contains no ``.jpg`` files.
    """
    frames_dir = "/mnt/tmp/frames/{}".format(vid)
    frame_files = sorted(glob(frames_dir + '/*.jpg'))

    if not frame_files:
        print("path doesnt exist", frames_dir)
        return [], []

    target = load_annotation(frame_key, frame_files)

    # NOTE(review): the second positional argument of np.max is `axis`, not a
    # lower bound, so this call does not clamp negative values. Out-of-range
    # indices are handled by the np.clip below instead.
    start_img = np.max(start_img, 0)
    stop_img = start_img + 32 * 2
    last_valid = len(frame_files) - 1
    sampled = list(np.clip(range(start_img, stop_img, 2), 0, last_valid))

    # PIL expects (width, height); orig_size is stored as [height, width].
    buffer = [
        Image.open(frame_files[idx]).resize((target['orig_size'][1], target['orig_size'][0]))
        for idx in sampled
    ]

    return buffer, target



In [None]:
# TubeR's csn152, ava21 results
plt.figure()
%matplotlib inline

go_on = True
while go_on:
    gpu_num = random.randint(0,7)

    detection = './tuber_res/{}.txt'.format(gpu_num) #numbers are changeable
    gt = './tuber_res/GT_{}.txt'.format(gpu_num)

    baseline_detection = './baseline_res/{}.txt'.format(gpu_num)
    baseline_gt = './baseline_res/GT_{}.txt'.format(gpu_num)

    # my current model's results
    my_detection = './my_res/{}.txt'.format(gpu_num) #numbers are changeable
    my_gt = './my_res/GT_{}.txt'.format(gpu_num)
    # my_detection = './res/{}.txt'.format(gpu_num) #numbers are changeable
    # my_gt = './res/GT_{}.txt'.format(gpu_num)


    # what label am I interested in?
    # label = 54 # "stand", for example
    # LOI = [12, 14, 21, 23, 25, 29, 33, 35, 37, 40, 42, 44, 45, 46, 55, 56, 59, 60, 61, 62, 63, 64, 65, 68, 72, 75, 77]
    # label = random.sample(LOI, 1)[0]
    label = random.randint(1,80)
    
    # find a video key frame with desired label

    key_frame_candidates = []
    with open(my_gt) as f:
        for line in f.readlines():
            img_id = line.split(' ')[0]
            annotation = [int(float(n)) for n in line.split('[')[1].split(']')[0].split(',')]
            multi_hot_obj_label = annotation[6:]
            if multi_hot_obj_label[label-1] == 1:
                key_frame_candidates.append(img_id)
    
    # with open(baseline_gt) as f:
    #     for line in f.readlines():
    #         img_id = line.split(' ')[0]
    #         annotation = [int(float(n)) for n in line.split('[')[1].split(']')[0].split(',')]
    #         multi_hot_obj_label = annotation[6:]
    #         if multi_hot_obj_label[label-1] == 1:
    #             key_frame_candidates.append(img_id)

    # pick one of key_frame_candidates:
    if len(key_frame_candidates) == 0:
        print("no key frame found with following label: {}".format(items[label]), " try another gpu_num")

    ind = random.randint(0, len(key_frame_candidates)-1)
    key_frame = key_frame_candidates[ind]
    frame_second = key_frame.split("_")[-1]
    vid = "_".join(key_frame.split('_')[:-1])
    """
    # frame_key is one of "xeGWXqSvC-8,0911", "CMCPhm2L400,1274", "Gvp-cj3bmIY,1725", "Gvp-cj3bmIY_1675"
    frame_key = "Gvp-cj3bmIY,1675" 
    vid, frame_second = frame_key.split(',')
    """
    frame_key = ",".join([vid, frame_second])
    timef = int(frame_second) - 900
    start_img = np.max((timef * 30 - 32 // 2 * 2, 0))

    imgs, target = loadvideo(start_img, vid, frame_key)

    """
    start_img: start_img number, int
    vid: xeGWXqSvC-8, CMCPhm2L400, Gcp-cj3bmIY
    frame_key: 0911, 1274, 1725

    """
    orig_vid = imgs
    plt.imshow(orig_vid[16])
    plt.show()
    response = input()
    
    if response == "y":
        go_on = False

# Apply the validation transform to the selected clip and its target.
imgs, target = transforms(imgs, target)
# Height and width of the transformed frames.
ho,wo = imgs[0].shape[-2], imgs[0].shape[-1]
# Stack the per-frame tensors along a new leading axis, then swap the first
# two axes (presumably (T, C, H, W) -> (C, T, H, W) - TODO confirm).
imgs = torch.stack(imgs, dim=0)
imgs = imgs.permute(1, 0, 2, 3)
    

# print(len(imgs), imgs[0].shape, target)

# TubeR model and current model share cuda:0; the baseline runs on cuda:1.
device = "cuda:0"
model = model.to(device)
imgs = imgs.to(device)

device2 = "cuda:0"
current_model = current_model.to(device2)
imgs2 = imgs.to(device2)

device3 = "cuda:1"
baseline_model = baseline_model.to(device3)
imgs3 = imgs.to(device3)

# print(attn_weights.shape)

# print(attn_weights.shape)


In [None]:
# Buffers that the forward hooks registered below append to; one group per
# model. Each starts as an empty list and is filled during the forward pass.

#tuber og model
conv_features, enc_attn_weights, dec_attn_weights = [], [], []
enc_features = []
query_features = []
cls_enc_attn_weights, cls_dec_attn_weights = [], []
cls_enc_features = []
cls_query_features = []

#current model
curr_conv_features, curr_enc_attn_weights, curr_dec_attn_weights = [], [], []
curr_enc_features = []
curr_query_features = []
curr_key_features = []
curr_cls_enc_attn_weights, curr_cls_dec_attn_weights = [], []
curr_cls_enc_features = []
curr_cls_query_features = []
curr_actor_features = []
curr_global_features = []

#baseline model
baseline_conv_features, baseline_enc_attn_weights, baseline_dec_attn_weights = [], [], []
baseline_enc_features = []
baseline_query_features = []
baseline_cls_enc_attn_weights, baseline_cls_dec_attn_weights = [], []
baseline_cls_enc_features = []
baseline_cls_query_features = []
baseline_final_features = []


# Forward hooks on the original TubeR model. The attention hooks store
# output[1] as "weights" and output[0] as "features" (presumably the modules
# return (attn_output, attn_weights) like nn.MultiheadAttention - confirm).
hooks = [
    # Output of the second-to-last block of backbone layer4.
    model.backbone.body.layer4[-2].register_forward_hook(
        lambda self, input, output: conv_features.append(output)
    ),
    # Last encoder layer self-attention: second and first output element.
    model.transformer.encoder.layers[-1].self_attn.register_forward_hook(
        lambda self, input, output: enc_attn_weights.append(output[1])
    ),
    model.transformer.encoder.layers[-1].self_attn.register_forward_hook(
        lambda self, input, output: enc_features.append(output[0])
    ),
    # model.transformer.decoder.layers[-1].multihead_attn.register_forward_hook(
    #     lambda self, input, output: dec_attn_weights.append(output[1])
    # ),
    model.transformer.decoder.layers[-1].multihead_attn.register_forward_hook(
        lambda self, input, output: query_features.append(output[0])
    ),
]

# One hook per decoder layer 0..5; entries are appended in the order the
# layers run, so dec_attn_weights[-1] is the last one that fired.
for i in range(6):
    hooks.append(
        model.transformer.decoder.layers[i].multihead_attn.register_forward_hook(
            lambda self, input, output: dec_attn_weights.append(output[1])
        )
    )

# Hooks on `model.encoder` and `model.cross_attn`.
# NOTE(review): the buffers filled here (cls_*) are not unpacked after the
# forward pass in this notebook (those lines are commented out).
hooks2 = [
    model.encoder.layers[-1].self_attn_t.register_forward_hook(
        lambda self, input, output: cls_enc_attn_weights.append(output[1])
    ),
    model.encoder.layers[-1].self_attn_t.register_forward_hook(
        lambda self, input, output: cls_enc_features.append(output[0])
    ),
    model.cross_attn.register_forward_hook(
        lambda self, input, output: cls_dec_attn_weights.append(output[1])
    ),    
    model.cross_attn.register_forward_hook(
        lambda self, input, output: cls_query_features.append(output[0])
    ),    
]

# Forward hooks on the current model.
cur_hooks = []
# Output of the second-to-last block of backbone layer4.
cur_hooks.append(current_model.backbone.body.layer4[-2].register_forward_hook(
        lambda self, input, output: curr_conv_features.append(output)
        ),
    )
# cur_hooks.append(
#     current_model.transformer.decoder.conv_blocks[-1].register_forward_hook(
#         lambda self, input, output: curr_query_features.append(output)
#     )
# )
# Despite the buffer name, this stores the SECOND output element of the
# decoder cross-attention (one entry per call of that module); later cells
# reshape each entry to (15, 80, h, w) and plot it as an attention map.
cur_hooks.append(
    current_model.transformer.decoder.cross_attn.register_forward_hook(
        lambda self, input, output: curr_query_features.append(output[1])
    )
)
# cur_hooks.append(
#     current_model.transformer.decoder.q_proj.register_forward_hook(
#         lambda self, input, output: curr_query_features.append(output)
#     )
# )
# cur_hooks.append(
#     current_model.transformer.decoder.k_proj.register_forward_hook(
#         lambda self, input, output: curr_key_features.append(output)
#     )
# )
# Full encoder output.
cur_hooks.append(
    current_model.transformer.encoder.register_forward_hook(
        lambda self, input, output: curr_global_features.append(output)
    )
)
# Output of the decoder's cls_norm module.
cur_hooks.append(
    current_model.transformer.decoder.cls_norm.register_forward_hook(
        lambda self, input, output: curr_actor_features.append(output)
    )
)
# for i in range(6):
    # cur_hooks.append(
    #     current_model.transformer.decoder.cls_layers[i].cross_attn.register_forward_hook(
    #         lambda self, input, output: curr_cls_dec_attn_weights.append(output[1])
    #     )
    # )


# Forward hooks on the baseline model.
baseline_hooks = []
baseline_hooks.append(baseline_model.backbone.body.layer4[-2].register_forward_hook(
        lambda self, input, output: baseline_conv_features.append(output)
        ),
    )
# Keeps only index 16 along the first axis of the norm2_t output
# (presumably the key-frame time step - TODO confirm the axis meaning).
baseline_hooks.append(baseline_model.transformer.encoder.layers[-1].norm2_t.register_forward_hook(
        lambda self, input, output: baseline_enc_features.append(output[16])
        ),
    )
baseline_hooks.append(baseline_model.transformer.decoder.layers[-1].norm2.register_forward_hook(
        lambda self, input, output: baseline_final_features.append(output)
        ),
    )
# Second output element of the cross-attention in decoder layers 0..5.
for i in range(6):
    baseline_hooks.append(
        baseline_model.transformer.decoder.layers[i].cross_attn.register_forward_hook(
            lambda self, input, output: baseline_dec_attn_weights.append(output[1])
        )
    )
    # baseline_hooks.append(
    #     baseline_model.transformer.decoder.cls_layers[i].cross_attn.register_forward_hook(
    #         lambda self, input, output: baseline_cls_dec_attn_weights.append(output)
    #     )
    # )

# Run every model once on the clip (a batch axis is added with unsqueeze);
# the registered forward hooks fill their buffers during these calls.
all_hooks = hooks + hooks2 + cur_hooks + baseline_hooks
try:
    outputs = model(imgs.unsqueeze(0))
    my_outputs = current_model(imgs2.unsqueeze(0))
    baseline_outputs = baseline_model(imgs3.unsqueeze(0))
finally:
    # Always detach the hooks, even when a forward pass raises. Otherwise
    # they stay attached and re-running the registration cell stacks
    # duplicates that append to the buffers several times per forward.
    for hook in all_hooks:
        hook.remove()

# Replace the single-entry buffers by their first (only) captured value.
# Buffers fed by several hooks/calls (dec_attn_weights, curr_query_features)
# stay lists.
conv_features = conv_features[0]
enc_attn_weights = enc_attn_weights[0]
dec_attn_weights = dec_attn_weights
enc_features = enc_features[0]
query_features = query_features[0]
# cls_enc_attn_weights = cls_enc_attn_weights[0]
# cls_dec_attn_weights = cls_dec_attn_weights[0]
# cls_enc_features = cls_enc_features[0]
# cls_query_features = cls_query_features[0]

curr_conv_features = curr_conv_features[0]
curr_query_features = curr_query_features
# curr_key_features = curr_key_features
# Taken directly from the decoder attribute instead of a hook.
curr_key_features = current_model.transformer.decoder.class_queries
# curr_dec_attn_weights = curr_dec_attn_weights
# curr_cls_dec_attn_weights = curr_cls_dec_attn_weights
# curr_cls_enc_features = curr_cls_enc_features[0]
# curr_cls_query_features = curr_cls_query_features
# curr_cls_dec_attn_weights = curr_cls_dec_attn_weights #/ len(curr_cls_dec_attn_weights)
# curr_dec_attn_weights = curr_dec_attn_weights #/ len(curr_dec_attn_weights)

baseline_conv_features = baseline_conv_features[0]
# baseline_dec_attn_weights = baseline_dec_attn_weights
baseline_enc_features = baseline_enc_features[0]
baseline_final_features = baseline_final_features[0]

In [None]:
# Display frame 16 of the selected clip (16 is the key position used by
# load_annotation).
orig_vid[16]

## Test to see if matplotlib works well

In [None]:
# Spatial size (last two dims) of the current model's captured backbone
# feature map; later cells reuse h and w to reshape attention maps.
h, w = curr_conv_features.shape[-2:]
# attn = (curr_query_features*curr_key_features).sum(dim=1).flatten(1).softmax(dim=1).reshape(-1, 1, h, w)

In [None]:
# Inspect the shape of the first captured cross-attention entry.
curr_query_features[0].shape
# current_model.transformer.decoder.q_proj.weight.squeeze().shape

In [None]:
# im=misc.imread("photosAfterAverage/exampleAfterAverage1.jpg")
%matplotlib inline
# Overlay one attention map of the current model on frame 16 of the clip.
h, w = curr_conv_features.shape[-2:]
plt.figure()
plt.imshow(orig_vid[16])
# Last captured cross-attention entry viewed as (15, 80, h, w); presumably
# (query, class, height, width) - TODO confirm the axis meaning.
attn = curr_query_features[-1].view(15, 80, h, w)
# attn = (query[:, None].expand(-1, 80, -1, -1, -1) * curr_key_features_).sum(dim=2).flatten(2).softmax(dim=2).reshape(15, -1, h, w)[:, :, None]
# Axis limits of the displayed frame; the heat map is stretched over the same
# extent. Later plotting cells reuse xmin/xmax/ymin/ymax from here.
xmin, xmax = plt.xlim()
ymin, ymax = plt.ylim()
# Map at first-axis index 6, second-axis index 53.
plt.imshow(attn[6, 53].detach().cpu().view(h, w), cmap='seismic', interpolation='bicubic', alpha=.6, extent=(xmin, xmax, ymin, ymax))
# plt.imshow(curr_query_features[0][0, 2].detach().cpu().view(h, w), cmap='copper', interpolation='bicubic', alpha=.6, extent=(xmin, xmax, ymin, ymax))
# plt.imshow(cls_dec_attn_weights[0, 2, :].detach().cpu().view(h, w), cmap='copper', interpolation='bicubic', alpha=.7, extent=(xmin, xmax, ymin, ymax))

# plt.imshow(test.detach().cpu().view(h, w), cmap='copper', interpolation='nearest', alpha=.8, extent=(xmin, xmax, ymin, ymax))
plt.show()

In [None]:
# Inspect the shape of the attention tensor built in the previous cell.
attn.shape

In [None]:
# Rebuild `attn` as a LIST: the first six captured cross-attention entries,
# each reshaped to (15, 80, h, w). The grid cell for the current model reads
# attn[-1] from this list.
h, w = curr_conv_features.shape[-2:]
attn = []
for i in range(6):
    attn_ = curr_query_features[i].reshape(15,80,h,w)
    attn.append(attn_)

In [None]:
# Overlay the norm (over dim 1) of the captured baseline encoder features on
# frame 16 of the clip.
plt.figure()
plt.imshow(orig_vid[16])

xmin, xmax = plt.xlim()
ymin, ymax = plt.ylim()
# The bare expression below has no effect in the middle of a cell.
baseline_enc_features
# NOTE(review): h and w come from the CURRENT model's feature map; confirm
# the baseline features have the same spatial size.
plt.imshow(baseline_enc_features.norm(dim=1).detach().cpu().view(h, w), cmap='seismic', interpolation='bicubic', alpha=.6, extent=(xmin, xmax, ymin, ymax))
# plt.imshow(cls_dec_attn_weights[0, 2, :].detach().cpu().view(h, w), cmap='copper', interpolation='bicubic', alpha=.7, extent=(xmin, xmax, ymin, ymax))

# plt.imshow(test.detach().cpu().view(h, w), cmap='copper', interpolation='nearest', alpha=.8, extent=(xmin, xmax, ymin, ymax))
plt.show()

## load annotation files
#### TODO: load it with listfiles()

In [None]:
# frame_key is one of "xeGWXqSvC-8,0911", "CMCPhm2L400,1274", "Gvp-cj3bmIY,1725"

def _read_query_logits(path_template, frame_id, num_shards=8):
    """Collect the numeric payload of every result line belonging to ``frame_id``.

    Args:
        path_template: Path with one ``{}`` placeholder for the shard index.
        frame_id: Image id that must equal the first space-separated field.
        num_shards: Number of shard files (indices ``0 .. num_shards-1``).

    Returns:
        List with one list of floats per matching line, parsed from the
        comma-separated text between the first ``[`` and the following ``]``.
    """
    rows = []
    for shard in range(num_shards):
        with open(path_template.format(shard)) as f:
            for line in f:
                if line.split(' ')[0] != frame_id:
                    continue
                rows.append([float(n) for n in line.split('[')[1].split(']')[0].split(',')])
    return rows


def _read_gt_annotations(path_template, frame_id, num_shards=8):
    """Collect ground-truth boxes and label names for ``frame_id``.

    Each matching line contributes one box: values 2..5 of its bracketed list
    are the coordinates and the values from index 6 on are a multi-hot class
    vector that is translated to names through the module-level ``items``.

    Returns:
        ``{frame_id: {"obj": [names per box], "coord": [coords per box]}}``,
        or an empty dict when no line matches.
    """
    annos = {}
    for shard in range(num_shards):
        with open(path_template.format(shard)) as f:
            for line in f:
                img_id = line.split(' ')[0]
                if img_id != frame_id:
                    continue
                values = [int(float(n)) for n in line.split('[')[1].split(']')[0].split(',')]
                coord = values[2:6]
                names = [items[idx + 1] for idx, flag in enumerate(values[6:]) if flag]
                entry = annos.setdefault(img_id, {"obj": [], "coord": []})
                entry["obj"].append(names)
                entry["coord"].append(coord)
    return annos


# Original TubeR results.
detection = './tuber_res/{}.txt' #numbers are changeable
gt = './tuber_res/GT_{}.txt'
query_logits = _read_query_logits(detection, key_frame)
anno_dict = _read_gt_annotations(gt, key_frame)

# Current model results.
my_detection = './my_res/{}.txt'
my_gt = './my_res/GT_{}.txt'
# my_detection = './res/{}.txt'
# my_gt = './res/GT_{}.txt'
my_query_logits = _read_query_logits(my_detection, key_frame)
my_anno_dict = _read_gt_annotations(my_gt, key_frame)

# Baseline results.
baseline_detection = './baseline_res/{}.txt'
baseline_gt = './baseline_res/GT_{}.txt'
baseline_query_logits = _read_query_logits(baseline_detection, key_frame)
baseline_anno_dict = _read_gt_annotations(baseline_gt, key_frame)

In [None]:
import visualization_utils_custom as vis_utils

# Draw every ground-truth box of the selected key frame in green and print
# its label list.
# NOTE(review): img_show is the same object as orig_vid[16]; if the drawing
# helper modifies the image in place (presumably it does - confirm), the
# boxes stay on the stored frame for every later cell that copies it.
img_show = orig_vid[16]
key_frame_gt = my_anno_dict[key_frame]
# Plain zip: the former enumerate index was discarded into `_`, which
# overwrote IPython's "last output" variable for no benefit.
for obj, gt_coord in zip(key_frame_gt["obj"], key_frame_gt["coord"]):
    gt_cat = str(obj)
    gt_xmin, gt_ymin, gt_xmax, gt_ymax = gt_coord[0], gt_coord[1], gt_coord[2], gt_coord[3]
    vis_utils.draw_bounding_box_on_image(
        img_show, gt_ymin, gt_xmin, gt_ymax, gt_xmax,
        color = 'Green',
        display_str_list=[gt_cat],
        use_normalized_coordinates=False,
        margin2=20
    )
    print(gt_cat)
img_show

### Visualization of decoder query attention (of the second transformer)
(Simpler version)

In [None]:
import copy
import visualization_utils_custom as vis_utils
import matplotlib.animation as animation
import matplotlib.image as mpimg
import imageio.v2 as imageio
from IPython import display


def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to corner form.

    The last dimension of ``x`` must have size 4; the result has the same
    shape with the last dimension holding (x_min, y_min, x_max, y_max).
    """
    center_x, center_y, width, height = x.unbind(-1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = (
        center_x - half_w,
        center_y - half_h,
        center_x + half_w,
        center_y + half_h,
    )
    return torch.stack(corners, dim=-1)

%matplotlib inline
# 3 x 5 grid for the current model: one axis per query index 0..14, showing
# frame t with the predicted box (if any class score exceeds 0.5) and the
# attention map of `class_of_interest` from the last entry of the `attn` list.
# Relies on h, w, xmin, xmax, ymin, ymax set by earlier cells.
# curr_weights = (curr_cls_dec_attn_weights + curr_dec_attn_weights)[0]
# curr_weights = (curr_dec_attn_weights)[0]
curr_weights = attn
# Zero-based index into the class axis; the title looks up items[index + 1].
class_of_interest = 73

h2, w2 = curr_conv_features.shape[-2:]
# curr_weights = curr_weights.transpose(0,1).reshape(1, 15, 32, h2, w2)
# curr_weights = curr_weights.view(1,15,32,h2,w2)
fig, axs = plt.subplots(ncols=5, nrows=3, figsize=(h,w//3+1))
ims = []
t = 16

target_img = orig_vid[t]
fig.suptitle("current model (time {})".format(t) + " class: {}".format(items[class_of_interest+1]))

for i, ax_i in enumerate(axs):
    for j in range(5):
        # Work on a copy so boxes drawn for one query do not show up in the next.
        tgt_img = copy.deepcopy(target_img)
        # coord = my_query_logits[i*5+j][:4]
        coord = box_cxcywh_to_xyxy(my_outputs["pred_boxes"][0, i*5+j])
        # logits = my_query_logits[i*5+j][4:-1]
        # Per-class sigmoid score scaled by the softmax value at index 1 of pred_logits_b.
        logits = my_outputs["pred_logits"][0, i*5+j].sigmoid() * my_outputs["pred_logits_b"].softmax(-1)[0, i*5+j, 1:2]
        bxmin, bymin, bxmax, bymax = coord[0], coord[1], coord[2], coord[3]
        # Names of all classes whose score exceeds 0.5.
        cat = [items[i+1] for i, e in enumerate([k>0.5 for k in logits]) if e]
        ax_i[j].set_title("query {}".format(i*5+j))
        if len(cat) == 0:                
            pass
        elif t == 16:
            # for _, (obj, gt_coord) in enumerate(zip(my_anno_dict[key_frame]["obj"], my_anno_dict[key_frame]["coord"])):
            #     gt_cat = str(obj)
            #     gt_xmin, gt_ymin, gt_xmax, gt_ymax = gt_coord[0], gt_coord[1], gt_coord[2], gt_coord[3]
            #     vis_utils.draw_bounding_box_on_image(
            #         tgt_img, gt_ymin, gt_xmin, gt_ymax, gt_xmax,
            #         color = 'Green',
            #         display_str_list=[gt_cat],
            #         use_normalized_coordinates=False,
            #         margin2=20
            #     )
            #     print(gt_cat)   
            vis_utils.draw_bounding_box_on_image(
                tgt_img, bymin, bxmin, bymax, bxmax,
                    color = 'Yellow',
                    display_str_list=cat,
                    # use_normalized_coordinates=False,
                    use_normalized_coordinates=True,
                    margin2=30
                )
            
            print(cat)
                
        else:
            pass
        ax_i[j].imshow(tgt_img)
        ax_i[j].imshow(attn[-1][i*5+j, class_of_interest].detach().cpu().view(h, w), cmap='seismic', interpolation='bicubic', alpha=.5, extent=(xmin, xmax, ymin, ymax))



In [None]:
import copy
import visualization_utils_custom as vis_utils
import matplotlib.animation as animation
import matplotlib.image as mpimg
import imageio.v2 as imageio
from IPython import display

%matplotlib inline
baseline_weights = baseline_dec_attn_weights[-1]
h2, w2 = baseline_conv_features.shape[-2:]
baseline_weights = baseline_weights.reshape(1, 15, 1, h2, w2)
baseline_weights = baseline_weights.view(1,15,1,h2,w2)
fig, axs = plt.subplots(ncols=5, nrows=3, figsize=(h,w//3+1))
ims = []
t = 16

target_img = orig_vid[t]
fig.suptitle("baseline (time {})".format(t))

for i, ax_i in enumerate(axs):
    for j in range(5):
        tgt_img = copy.deepcopy(target_img)
        coord = baseline_query_logits[i*5+j][:4]
        logits = baseline_query_logits[i*5+j][4:-1]
        bxmin, bymin, bxmax, bymax = coord[0], coord[1], coord[2], coord[3]
        cat = [items[i+1] for i, e in enumerate([k>0.4 for k in logits]) if e]
        ax_i[j].set_title("query {}".format(i*5+j))
        if len(cat) == 0:                
            continue
        elif t == 16:
            for _, (obj, gt_coord) in enumerate(zip(baseline_anno_dict[key_frame]["obj"], baseline_anno_dict[key_frame]["coord"])):
                gt_cat = str(obj)
                gt_xmin, gt_ymin, gt_xmax, gt_ymax = gt_coord[0], gt_coord[1], gt_coord[2], gt_coord[3]
                vis_utils.draw_bounding_box_on_image(
                    tgt_img, gt_ymin, gt_xmin, gt_ymax, gt_xmax,
                    color = 'Green',
                    display_str_list=[gt_cat],
                    use_normalized_coordinates=False,
                    margin2=20
                )   
            vis_utils.draw_bounding_box_on_image(
                tgt_img, bymin, bxmin, bymax, bxmax,
                    color = 'Yellow',
                    display_str_list=cat,
                    use_normalized_coordinates=False,
                    margin2=30
                )            
        else:
            pass
        ax_i[j].imshow(tgt_img)
        ax_i[j].imshow(baseline_weights[0, i*5+j, 0].detach().cpu().view(h, w), cmap='seismic', interpolation='bicubic', alpha=.6, extent=(xmin, xmax, ymin, ymax))


In [None]:
# Sigmoid scores of class index 53 for every query of the current model.
# my_outputs["pred_logits_b"].softmax(-1)[:, :, 1:2] # 1, 15, 1
my_outputs["pred_logits"][0, :, 53].sigmoid()
# my_outputs["pred_boxes"]

In [None]:
import copy
import visualization_utils_custom as vis_utils
%matplotlib inline

# 3x5 grid, one panel per query, drawn on frame 16 of the clip.
fig, axs = plt.subplots(nrows=3, ncols=5, figsize=(h, w // 3 + 1))
target_img = orig_vid[16]
# axs.flat visits the grid row by row, so the running index equals row * 5 + column.
for query_idx, panel in enumerate(axs.flat):
    tgt_img = copy.deepcopy(target_img)
    query_row = query_logits[query_idx]
    coord, logits = query_row[:4], query_row[4:-1]
    bxmin, bymin, bxmax, bymax = coord
    # Labels whose score is above 0.7; the lookup into `items` is offset by one.
    cat = [items[class_idx + 1] for class_idx, score in enumerate(logits) if score > 0.7]
    panel.set_title("query {}".format(query_idx))
    if not cat:
        # No label passes the threshold: the panel stays empty apart from its title.
        continue

    # Green boxes: every entry of anno_dict[key_frame], labelled with str(obj).
    annotations = anno_dict[key_frame]
    for obj, gt_coord in zip(annotations["obj"], annotations["coord"]):
        gt_xmin, gt_ymin, gt_xmax, gt_ymax = gt_coord[:4]
        vis_utils.draw_bounding_box_on_image(
            tgt_img, gt_ymin, gt_xmin, gt_ymax, gt_xmax,
            color='Green',
            display_str_list=[str(obj)],
            use_normalized_coordinates=False,
            margin2=20,
        )
    # Yellow box: this query's own box, labelled with the thresholded labels.
    vis_utils.draw_bounding_box_on_image(
        tgt_img, bymin, bxmin, bymax, bxmax,
        color='Yellow',
        display_str_list=cat,
        use_normalized_coordinates=False,
        margin2=30,
    )
    panel.imshow(tgt_img)
    # Overlay the last entry of dec_attn_weights for this query.
    heatmap = dec_attn_weights[-1][0, query_idx].detach().cpu().view(h, w)
    panel.imshow(
        heatmap,
        cmap='seismic',
        interpolation='bicubic',
        alpha=.6,
        extent=(xmin, xmax, ymin, ymax),
    )

In [None]:
import copy
import visualization_utils_custom as vis_utils
import matplotlib.animation as animation
import matplotlib.image as mpimg
import imageio.v2 as imageio
from IPython import display

%matplotlib inline

def plot_along_(l):
    """Draw the baseline per-query overlay grid for one layer and return it as an image.

    Args:
        l: Layer number shown in the figure title; the weights are taken from
            ``baseline_dec_attn_weights[5 - l]``.

    Returns:
        The figure, saved to ``./temp.png`` and read back with ``imageio.imread``.
    """
    h2, w2 = baseline_conv_features.shape[-2:]
    # reshape already gives (1, 15, 1, h2, w2); a following view would be a no-op.
    layer_weights = baseline_dec_attn_weights[5 - l].reshape(1, 15, 1, h2, w2)
    fig, axs = plt.subplots(nrows=3, ncols=5, figsize=(h, w // 3 + 1))
    frame_idx = 16

    frame = orig_vid[frame_idx]
    fig.suptitle("baseline (layer {})".format(l))

    # axs.flat walks the 3x5 grid row by row: running index == row * 5 + column.
    for query_idx, panel in enumerate(axs.flat):
        canvas = copy.deepcopy(frame)
        query_row = baseline_query_logits[query_idx]
        box, scores = query_row[:4], query_row[4:-1]
        bxmin, bymin, bxmax, bymax = box
        # Labels whose score is above 0.4; the lookup into `items` is offset by one.
        labels = [items[class_idx + 1] for class_idx, score in enumerate(scores) if score > 0.4]
        panel.set_title("query {}".format(query_idx))
        if not labels:
            # Nothing above threshold: the panel keeps only its title.
            continue
        if frame_idx == 16:
            # Annotation overlay (disabled in this cell):
            # for obj, gt_coord in zip(baseline_anno_dict[key_frame]["obj"],
            #                          baseline_anno_dict[key_frame]["coord"]):
            #     gt_xmin, gt_ymin, gt_xmax, gt_ymax = gt_coord[:4]
            #     vis_utils.draw_bounding_box_on_image(
            #         tgt_img, gt_ymin, gt_xmin, gt_ymax, gt_xmax,
            #         color='Green', display_str_list=[str(obj)],
            #         use_normalized_coordinates=False, margin2=20)
            vis_utils.draw_bounding_box_on_image(
                canvas, bymin, bxmin, bymax, bxmax,
                color='Yellow',
                display_str_list=labels,
                use_normalized_coordinates=False,
                margin2=30,
            )
        panel.imshow(canvas)
        heatmap = layer_weights[0, query_idx, 0].detach().cpu().view(h, w)
        panel.imshow(
            heatmap,
            cmap='seismic',
            interpolation='bicubic',
            alpha=.6,
            extent=(xmin, xmax, ymin, ymax),
        )

    # Round-trip through a PNG so the caller receives the rendered figure as pixels.
    fig.savefig('./temp.png')
    return imageio.imread('./temp.png')

# Render layers 0..5 in order, then stitch the frames into an animated GIF.
ims = [plot_along_(l) for l in range(6)]

imageio.mimsave('./temp.gif', ims, duration=0.3)

# Last expression of the cell: shows the GIF inline.
display.Image("./temp.gif")

In [None]:
import copy
import visualization_utils_custom as vis_utils
import matplotlib.animation as animation
import matplotlib.image as mpimg
import imageio.v2 as imageio
from IPython import display

%matplotlib inline

def plot_along_(l):
    """Draw the current model's per-query overlay grid for one layer and return it as an image.

    Each of the 15 panels shows the frame ``orig_vid[t]`` with, when at least
    one score of that query exceeds 0.5, the query's box (yellow) and the
    weight map ``attn[l][query, class_of_interest]`` blended on top.

    Args:
        l: Index into the module-level ``attn`` sequence; also printed in the
            figure title as the layer number.

    Returns:
        The figure, saved to ``./temp.png`` and read back with ``imageio.imread``.
    """
    # NOTE(review): an earlier comment gave this shape as (15, 1, 16, 25), but
    # axis 1 is indexed with class_of_interest = 11 below, so it presumably has
    # one channel per class -- confirm.
    curr_weights = attn[l]
    fig, axs = plt.subplots(nrows=3, ncols=5, figsize=(h, w // 3 + 1))
    class_of_interest = 11
    # NOTE(review): `t` is not assigned in this function; it is read from the
    # notebook's global scope, so the frame shown (and whether boxes are drawn)
    # depends on whichever cell last set `t`.
    target_img = orig_vid[t]
    fig.suptitle("current model (layer {}), label: {}".format(l, items[class_of_interest + 1]))

    # axs.flat walks the 3x5 grid row by row: running index == row * 5 + column.
    for query_idx, panel in enumerate(axs.flat):
        tgt_img = copy.deepcopy(target_img)
        query_row = my_query_logits[query_idx]
        coord, logits = query_row[:4], query_row[4:-1]
        bxmin, bymin, bxmax, bymax = coord[0], coord[1], coord[2], coord[3]
        # Labels whose score is above 0.5; the lookup into `items` is offset by one.
        cat = [items[class_idx + 1] for class_idx, score in enumerate(logits) if score > 0.5]
        panel.set_title("query {}".format(query_idx))
        if not cat:
            # Nothing above threshold: the panel keeps only its title.
            continue
        if t == 16:
            # Annotation overlay (disabled in this cell):
            # for obj, gt_coord in zip(my_anno_dict[key_frame]["obj"],
            #                          my_anno_dict[key_frame]["coord"]):
            #     gt_xmin, gt_ymin, gt_xmax, gt_ymax = gt_coord[:4]
            #     vis_utils.draw_bounding_box_on_image(
            #         tgt_img, gt_ymin, gt_xmin, gt_ymax, gt_xmax,
            #         color='Green', display_str_list=[str(obj)],
            #         use_normalized_coordinates=False, margin2=20)
            vis_utils.draw_bounding_box_on_image(
                tgt_img, bymin, bxmin, bymax, bxmax,
                color='Yellow',
                display_str_list=cat,
                use_normalized_coordinates=False,
                margin2=30,
            )
        panel.imshow(tgt_img)
        heatmap = curr_weights[query_idx, class_of_interest].detach().cpu().view(h, w)
        panel.imshow(
            heatmap,
            cmap='seismic',
            interpolation='bicubic',
            alpha=.6,
            extent=(xmin, xmax, ymin, ymax),
        )

    # Round-trip through a PNG so the caller receives the rendered figure as pixels.
    fig.savefig('./temp.png')
    return imageio.imread('./temp.png')

# One rendered frame per index 0..5 of the current model's weights.
ims = [plot_along_(l) for l in range(6)]

# Write the frames as an animated GIF.
imageio.mimsave('./temp.gif', ims, duration=0.3)

# Last expression of the cell: shows the GIF inline.
display.Image("./temp.gif")

In [None]:
import copy
import visualization_utils_custom as vis_utils

%matplotlib inline
# Classification-decoder weights, summed over all entries of dec_attn_weights,
# drawn per query on frame 16.
fig, axs = plt.subplots(nrows=3, ncols=5, figsize=(h, w // 3 + 1))
ims = []
target_img = orig_vid[16]
# `t` only feeds the figure title in this cell; the frame shown is orig_vid[16].
t = 2
fig.suptitle("time {}".format(t))
dec_attn_weights_ = sum(dec_attn_weights)

# axs.flat walks the 3x5 grid row by row: running index == row * 5 + column.
for query_idx, panel in enumerate(axs.flat):
    tgt_img = copy.deepcopy(target_img)
    query_row = query_logits[query_idx]
    coord, logits = query_row[:4], query_row[4:-1]
    bxmin, bymin, bxmax, bymax = coord
    # Labels whose score is above 0.4; the lookup into `items` is offset by one.
    cat = [items[class_idx + 1] for class_idx, score in enumerate(logits) if score > 0.4]
    panel.set_title("query {}".format(query_idx))
    if not cat:
        # No label passes the threshold: the panel stays empty apart from its title.
        continue

    # Green boxes: every entry of anno_dict[key_frame], labelled with str(obj).
    annotations = anno_dict[key_frame]
    for obj, gt_coord in zip(annotations["obj"], annotations["coord"]):
        gt_xmin, gt_ymin, gt_xmax, gt_ymax = gt_coord[:4]
        vis_utils.draw_bounding_box_on_image(
            tgt_img, gt_ymin, gt_xmin, gt_ymax, gt_xmax,
            color='Green',
            display_str_list=[str(obj)],
            use_normalized_coordinates=False,
            margin2=20,
        )
    # Yellow box: this query's own box, labelled with the thresholded labels.
    vis_utils.draw_bounding_box_on_image(
        tgt_img, bymin, bxmin, bymax, bxmax,
        color='Yellow',
        display_str_list=cat,
        use_normalized_coordinates=False,
        margin2=30,
    )

    panel.imshow(tgt_img)
    heatmap = dec_attn_weights_[0, query_idx].detach().cpu().view(h, w)
    panel.imshow(
        heatmap,
        cmap='copper',
        interpolation='bicubic',
        alpha=.6,
        extent=(xmin, xmax, ymin, ymax),
    )

In [None]:
import copy
import visualization_utils_custom as vis_utils
import matplotlib.animation as animation
import matplotlib.image as mpimg
import imageio.v2 as imageio
from IPython import display

# %matplotlib inline

def plot_along_(l):
    """Draw the "tuber" per-query overlay grid for one layer and return it as an image.

    Args:
        l: Layer number shown in the figure title; the weights are taken from
            ``dec_attn_weights[5 - l]``.

    Returns:
        The figure, saved to ``./temp.png`` and read back with ``imageio.imread``.
    """
    h2, w2 = conv_features.shape[-2:]
    # reshape already gives (1, 15, 1, h2, w2); a following view would be a no-op.
    layer_weights = dec_attn_weights[5 - l].reshape(1, 15, 1, h2, w2)
    fig, axs = plt.subplots(nrows=3, ncols=5, figsize=(h, w // 3 + 1))
    frame_idx = 16

    frame = orig_vid[frame_idx]
    fig.suptitle("tuber (layer {})".format(l))

    # axs.flat walks the 3x5 grid row by row: running index == row * 5 + column.
    for query_idx, panel in enumerate(axs.flat):
        canvas = copy.deepcopy(frame)
        query_row = query_logits[query_idx]
        box, scores = query_row[:4], query_row[4:-1]
        bxmin, bymin, bxmax, bymax = box
        # Labels whose score is above 0.4; the lookup into `items` is offset by one.
        labels = [items[class_idx + 1] for class_idx, score in enumerate(scores) if score > 0.4]
        panel.set_title("query {}".format(query_idx))
        if not labels:
            # Nothing above threshold: the panel keeps only its title.
            continue
        if frame_idx == 16:
            # Boxes are only drawn when the displayed frame is frame 16.
            # Green: every entry of anno_dict[key_frame], labelled with str(obj).
            annotations = anno_dict[key_frame]
            for obj, gt_coord in zip(annotations["obj"], annotations["coord"]):
                gt_xmin, gt_ymin, gt_xmax, gt_ymax = gt_coord[:4]
                vis_utils.draw_bounding_box_on_image(
                    canvas, gt_ymin, gt_xmin, gt_ymax, gt_xmax,
                    color='Green',
                    display_str_list=[str(obj)],
                    use_normalized_coordinates=False,
                    margin2=20,
                )
            # Yellow: this query's own box, labelled with the thresholded labels.
            vis_utils.draw_bounding_box_on_image(
                canvas, bymin, bxmin, bymax, bxmax,
                color='Yellow',
                display_str_list=labels,
                use_normalized_coordinates=False,
                margin2=30,
            )
        panel.imshow(canvas)
        heatmap = layer_weights[0, query_idx].detach().cpu().view(h, w)
        panel.imshow(
            heatmap,
            cmap='copper',
            interpolation='bicubic',
            alpha=.6,
            extent=(xmin, xmax, ymin, ymax),
        )

    # Round-trip through a PNG so the caller receives the rendered figure as pixels.
    fig.savefig('./temp.png')
    return imageio.imread('./temp.png')

# Collect the rendered grid for each of the six layers, in order.
ims = [plot_along_(l) for l in range(6)]

# Assemble the frames into ./temp.gif.
imageio.mimsave('./temp.gif', ims, duration=0.3)

# Last expression of the cell: shows the GIF inline.
display.Image("./temp.gif")