In [1]:
import os
import copy

import torch
import numpy as np

from e2edet.utils.configuration import load_yaml
from e2edet.utils.general import get_root
from e2edet.utils.box_ops import box_cxcywh_to_xyxy
from e2edet.model import build_model
from e2edet.module import build_matcher
from e2edet.dataset import build_dataset
from e2edet.dataset.coco import ConvertCocoPolysToMask
from e2edet.dataset.helper.collate_fn import collate2d

  from .autonotebook import tqdm as notebook_tqdm
  def points_in_convex_polygon_jit(points, polygon, clockwise=True):


In [2]:
os.environ["E2E_DATASETS"] = "/media/scratch1/duykien/data/coco"

In [3]:
log_dir = os.path.join(get_root(), "..", "save/COCO-InstanceSegmentation/boxer2d_vit")
model_path = "boxer2d_vit_b_w16_4g_5x_ss_lsj_final/boxer2d_vit_final.pth"
config_path = "boxer2d_vit_b_w16_4g_5x_ss_lsj_final/config.yaml"

In [4]:
def _get_src_permutation_idx(indices, num_references=4):
    # permute predictions following indices
    batch_idx = torch.cat(
        [torch.full_like(src, i) for i, (src, _) in enumerate(indices)]
    )  # [batch_size * num_target_boxes]
    src_idx = torch.cat(
        [src for (src, _) in indices]
    )  # [batch_size * num_target_boxes]
    return batch_idx, torch.div(src_idx, num_references, rounding_mode='floor')

In [5]:
idx_to_refsize = [1, 2, 4, 8]

def boxes_to_labels(boxes):
    # area_range = [[0 ** 2, 32 ** 2], [32 ** 2, 96 ** 2], [96 ** 2, 1e5 ** 2]]
    # area_label = ["small", "medium", "large"]
    assert (boxes[..., 2:] >= boxes[..., :2]).all().item()
    areas = torch.prod(boxes[..., 2:] - boxes[..., :2], dim=-1)
    labels = []

    for area in areas:
        area = area.item()

        if area < (32 ** 2):
            labels.append(0) # small
        elif area < (96 ** 2):
            labels.append(1) # medium
        else:
            labels.append(2) # large

    return torch.tensor(labels, dtype=torch.int64)

In [6]:
class SimPLRDemo:
    def __init__(self, root_path, model_path, config_path, current_device=torch.device("cuda")):
        model_path = os.path.join(root_path, model_path)
        config_path = os.path.join(root_path, config_path)
        self.current_device = current_device
        print("Loading model from", model_path)
        
        self.config = load_yaml(config_path)
        self._init_processors()

        self.model = self._build_simplr(model_path)
        self.matcher = build_matcher(self.config.loss.params.matcher)

    def _init_processors(self):
        task = self.config.task
        task_config = getattr(self.config.dataset_config, task)
        
        self.prepare = ConvertCocoPolysToMask(task_config["use_mask"])
        self.dataset = build_dataset(self.config, "val", self.current_device)

    def _build_simplr(self, model_path):
        num_classes = self.dataset.get_answer_size()
        other_args = {"num_classes": num_classes}
        
        model = build_model(self.config, **other_args)
        
        ext = model_path.split(".")[-1]
        state_dict = torch.load(model_path, map_location="cpu")
        if ext == "ckpt":
            state_dict = state_dict["model"]
            
        if list(state_dict.keys())[0].startswith('module') and not hasattr(model, 'module'):
            state_dict = self._multi_gpu_state_to_single(state_dict)
        
        print("Loading model:", model.load_state_dict(state_dict))
        model.to(self.current_device)
        model.eval()
        
        return model

    def _multi_gpu_state_to_single(self, state_dict):
        new_sd = {}
        for k, v in state_dict.items():
            if not k.startswith('module.'):
                raise TypeError("Not a multiple GPU state of dict")
            k1 = k[7:]
            new_sd[k1] = v
        return new_sd

    @torch.no_grad()
    def predict(self, idx=0):
        sample, target = self.dataset[idx]
        
        batch = collate2d([(sample, target)])

        sample, target = self.dataset.prepare_batch(batch)
        
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            outputs = self.model(sample, target)

        enc_output = outputs["enc_outputs"][0]
        
        bin_target = copy.deepcopy(target)
        for bt in bin_target:
            bt["labels"] = torch.zeros_like(bt["labels"])

        indices = self.matcher(enc_output, bin_target)

        attn = self.model.transformer.encoder.layers[5].self_attn.attn.float()
        boxes = self.model.transformer.encoder.layers[5].self_attn.boxes.float()
        vit_attn = self.model.backbone.net.blocks[11].attn.attn.float()

        src_idx = _get_src_permutation_idx(indices)
        boxes_target = torch.cat([t["boxes"][i] for t, (_, i) in zip(target, indices)], dim=0)
        size = target[0]["orig_size"]

        boxes_target = box_cxcywh_to_xyxy(boxes_target * size[[1, 0, 1, 0]])
        labels = boxes_to_labels(boxes_target)
        attn = attn[src_idx]
        boxes = box_cxcywh_to_xyxy(boxes[src_idx] * size[[1, 0, 1, 0]])

        return attn, boxes, labels, vit_attn, src_idx

In [7]:
demo = SimPLRDemo(log_dir, model_path, config_path, torch.device("cuda"))

Loading model from /media/deepstorage01/home2/duykien/test/3D-ObjectDect/e2edet/../save/COCO-InstanceSegmentation/boxer2d_vit/boxer2d_vit_b_w16_4g_5x_ss_lsj_final/boxer2d_vit_final.pth
loading annotations into memory...
Done (t=0.72s)
creating index...
index created!
loss_mode: focal
Loaded pretrained mae_base_patch16: _IncompatibleKeys(missing_keys=['blocks.0.attn.rel_pos_h', 'blocks.0.attn.rel_pos_w', 'blocks.1.attn.rel_pos_h', 'blocks.1.attn.rel_pos_w', 'blocks.2.attn.rel_pos_h', 'blocks.2.attn.rel_pos_w', 'blocks.3.attn.rel_pos_h', 'blocks.3.attn.rel_pos_w', 'blocks.4.attn.rel_pos_h', 'blocks.4.attn.rel_pos_w', 'blocks.5.attn.rel_pos_h', 'blocks.5.attn.rel_pos_w', 'blocks.6.attn.rel_pos_h', 'blocks.6.attn.rel_pos_w', 'blocks.7.attn.rel_pos_h', 'blocks.7.attn.rel_pos_w', 'blocks.8.attn.rel_pos_h', 'blocks.8.attn.rel_pos_w', 'blocks.9.attn.rel_pos_h', 'blocks.9.attn.rel_pos_w', 'blocks.10.attn.rel_pos_h', 'blocks.10.attn.rel_pos_w', 'blocks.11.attn.rel_pos_h', 'blocks.11.attn.rel_pos

In [8]:
idx = 0

attn, boxes, labels, vit_attn, src_idx = demo.predict(idx)
print(attn.shape)
print(labels.shape)

torch.Size([20, 12, 4, 4])
torch.Size([20])


In [9]:
print(labels)
print(vit_attn.shape)

tensor([0, 1, 0, 0, 0, 2, 1, 0, 0, 1, 0, 0, 1, 2, 1, 1, 1, 0, 0, 1])
torch.Size([1, 12, 4096, 4096])


In [10]:
stats = {}

ids = list(torch.load("ids.pth"))

y, x = torch.meshgrid((torch.arange(64) + 0.5), (torch.arange(64) + 0.5), indexing="ij")
coords = torch.stack([x, y], dim=-1).flatten(0, 1).float()

rel_dist = torch.cdist(coords.unsqueeze(0), coords.unsqueeze(0)).cuda()

for idx in ids:
    attn, boxes, labels, vit_attn, src_idx = demo.predict(idx)

    batch_idx, len_idx = src_idx
    h_idx, w_idx = (len_idx // 128) // 2, (len_idx % 128) // 2
    len_idx = h_idx * 64 + w_idx
    vit_attn = (vit_attn[:, :, len_idx] * rel_dist[:, len_idx].unsqueeze(0)).sum(-1).squeeze(0).transpose(0, 1)

    stats[idx] = (attn.cpu(), boxes.cpu(), labels, vit_attn.cpu())


  h_idx, w_idx = (len_idx // 128) // 2, (len_idx % 128) // 2


# Analysis

In [11]:
ids = list(stats.keys())

In [12]:
large = torch.cat([stats[idx][0][stats[idx][2] == 2] for idx in ids], dim=0)
medium = torch.cat([stats[idx][0][stats[idx][2] == 1] for idx in ids], dim=0)
small = torch.cat([stats[idx][0][stats[idx][2] == 0] for idx in ids], dim=0)

In [13]:
print(large.shape[0])
print(medium.shape[0])
print(small.shape[0])

229
217
202


In [14]:
from collections import Counter

large_stats = Counter(large.sum(-1).max(dim=-1)[1].flatten().tolist())
print("large:", large_stats)

medium_stats = Counter(medium.sum(-1).max(dim=-1)[1].flatten().tolist())
print("medium:", medium_stats)

small_stats = Counter(small.sum(-1).max(dim=-1)[1].flatten().tolist())
print("small:", small_stats)

large: Counter({3: 965, 1: 651, 0: 570, 2: 562})
medium: Counter({0: 865, 1: 842, 2: 827, 3: 70})
small: Counter({0: 1375, 1: 884, 3: 84, 2: 81})


In [16]:
vit_large = torch.cat([stats[idx][3][stats[idx][2] == 2] for idx in ids], dim=0)
vit_medium = torch.cat([stats[idx][3][stats[idx][2] == 1] for idx in ids], dim=0)
vit_small = torch.cat([stats[idx][3][stats[idx][2] == 0] for idx in ids], dim=0)

In [24]:
print("vit_large:", vit_large.mean() * 16, vit_large.std(-1).mean())
print("vit_medium:", vit_medium.mean() * 16, vit_medium.std(-1).mean())
print("vit_small:", vit_small.mean() * 16, vit_small.std(-1).mean())

vit_large: tensor(253.3443) tensor(6.1593)
vit_medium: tensor(209.3601) tensor(7.1412)
vit_small: tensor(201.6862) tensor(6.8789)
