In [22]:
from kitti_detection import config
from kitti_detection.dataset import DataSample, class_names, load_train_val_test_dataset
from kitti_detection.utils import display_samples_h, generalized_box_iou, box_cxcywh_to_xyxy

import torch
import torchvision
from torch import nn, optim, Tensor
from torch.nested import nested_tensor
from torch.utils.data import DataLoader
from torchvision.transforms import v2
from torchvision.tv_tensors import BoundingBoxes

from scipy.optimize import linear_sum_assignment

In [23]:
# Training-time preprocessing pipeline (torchvision transforms v2).
#   1. Take a random 370x370 crop.
#   2. Drop bounding boxes that the crop made degenerate or pushed out of the image.
#   3. Convert images to float32. `scale=True` rescales integer images to [0, 1];
#      float images are left unscaled. Without it, integer inputs would keep their
#      0-255 range. NOTE(review): the dtype delivered by the dataset is not visible
#      here -- confirm.
#   4. Normalize with the ImageNet channel statistics, which are defined for
#      [0, 1]-ranged images (matches the ImageNet-pretrained backbone).
transforms = v2.Compose([
    v2.RandomCrop(size=(370, 370)),
    v2.SanitizeBoundingBoxes(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [24]:
# Load the train / validation / test splits from the project's dataset helper.
train_dataset, val_dataset, test_dataset = load_train_val_test_dataset()

# Attach the preprocessing pipeline to the training split (val/test are not touched in this cell).
train_dataset.transform = transforms

In [25]:
# Number of object categories; the DETR classification head below uses n_classes + 1 outputs.
n_classes = len(class_names)

In [26]:
def p(t):
    """Debug helper: print ``t``, then ``t.size()``, then an empty line."""
    print(t)
    # A single call emits the size followed by the blank separator line.
    print(t.size(), end="\n\n")

In [27]:
def get_1d_pos_encoding(l, dim):
    """Build a fixed sine/cosine encoding for ``l`` positions.

    For every ``i`` in ``range(dim // 2)`` a ramp of ``l`` evenly spaced values from
    0 to ``10000 ** (2 * i / dim)`` is created; its sine and cosine become two
    adjacent output columns.

    NOTE(review): the textbook sinusoidal encoding divides the position by
    ``10000 ** (2 * i / dim)``; here that value is the *end point* of the ramp
    instead, so later columns oscillate faster. Left as-is -- confirm intent.

    Args:
        l: Number of positions (rows of the result).
        dim: Encoding width; ``dim // 2`` sine/cosine column pairs are produced.

    Returns:
        Tensor of shape ``(l, 2 * (dim // 2))``.
    """
    column_pairs = []
    for i in range(dim // 2):
        ramp = torch.linspace(0, 10000**(2*i/dim), steps=l)
        # (l, 2): sine in the first column, cosine in the second.
        column_pairs.append(torch.stack((ramp.sin(), ramp.cos()), dim=1))
    return torch.cat(column_pairs, dim=1)

def create_pos_encoding(h, w, dim):
    """Build a 2-D positional encoding for an ``h`` x ``w`` grid.

    Half of the channels encode the column position and the other half the row
    position; both halves come from ``get_1d_pos_encoding``.

    Args:
        h: Grid height.
        w: Grid width.
        dim: Total encoding width; each axis gets ``dim // 2`` channels.

    Returns:
        Tensor of shape ``(h, w, C)`` with the column encoding first and the row
        encoding second along the last axis (``C == dim`` when ``dim`` is a
        multiple of 4).
    """
    per_axis = dim // 2
    along_width = get_1d_pos_encoding(w, per_axis)    # one row per column index
    along_height = get_1d_pos_encoding(h, per_axis)   # one row per row index

    # Broadcast each 1-D table over the other axis of the grid.
    col_grid = along_width.unsqueeze(0).expand(h, -1, -1)
    row_grid = along_height.unsqueeze(1).expand(-1, w, -1)

    return torch.cat((col_grid, row_grid), dim=-1)

In [28]:
class DETR(nn.Module):
    """DETR-style detector: ResNet-18 features, a transformer, and per-query heads.

    The backbone feature map is projected to ``dim_embed`` channels, summed with a
    fixed 2-D positional encoding, flattened into a token sequence and passed to
    ``nn.Transformer`` together with 20 fixed query embeddings. Every decoded
    query yields class logits (``n_classes + 1`` entries) and a 4-value box.

    Args:
        dim_embed: Transformer model width (``nn.Transformer`` requires it to be
            divisible by ``nhead=8``).
    """

    def __init__(self, dim_embed=256):
        super().__init__()
        self.backbone = self._backbone()
        # 1x1 convolution: 512 backbone channels -> transformer width.
        self.conv = nn.Conv2d(512, dim_embed, kernel_size=1)

        # Fixed (non-learned) encodings; buffers follow the module across devices.
        self.register_buffer('pos_embedding', create_pos_encoding(12, 12, dim_embed)) # (12, 12, dim_embed)
        self.register_buffer('query_pos_embedding', get_1d_pos_encoding(20, dim_embed)) # (20, dim_embed)

        self.transformer = nn.Transformer(dim_embed, nhead=8, num_encoder_layers=4, num_decoder_layers=4, batch_first=True)

        self.linear_class = nn.Linear(dim_embed, n_classes + 1)
        self.linear_bbox = nn.Linear(dim_embed, 4)

    def _backbone(self) -> nn.Module:
        """Return an ImageNet-pretrained ResNet-18 truncated after ``layer4``.

        The classifier (``fc``) and global pooling (``avgpool``) are deleted and
        ``forward`` is replaced so the module returns the spatial feature map.
        """
        backbone = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
        del backbone.fc
        del backbone.avgpool

        def _forward(bb: torchvision.models.ResNet, x):
            # Stem
            x = bb.conv1(x)
            x = bb.bn1(x)
            x = bb.relu(x)
            x = bb.maxpool(x)

            # Residual stages
            x = bb.layer1(x)
            x = bb.layer2(x)
            x = bb.layer3(x)
            x = bb.layer4(x)
            return x

        # Instance-level override: the stock forward would call the deleted layers.
        backbone.forward = lambda x: _forward(backbone, x)
        # TODO BatchNorm freeze
        return backbone

    def _pos_encoding_for(self, x):
        """Return a positional encoding matching the ``(H', W')`` grid of ``x``.

        ``x`` is channels-last, ``(B, H', W', C)``. The pre-computed buffer is
        reused when the grid matches it; otherwise an encoding of the right size
        is built on the fly so inputs that were not cropped to the training
        resolution still work.
        """
        h, w = x.shape[1:3]
        if (h, w) == tuple(self.pos_embedding.shape[:2]):
            return self.pos_embedding
        return create_pos_encoding(h, w, x.shape[-1]).to(device=x.device, dtype=x.dtype)

    def forward(self, input):
        """Run detection on a batch of images.

        Args:
            input: Image batch fed to the ResNet backbone -- presumably
                ``(B, 3, H, W)`` float; TODO confirm against the dataset.

        Returns:
            Dict with
            ``'pred_logits'``: ``(B, 20, n_classes + 1)`` raw class scores, and
            ``'pred_boxes'``: ``(B, 20, 4)`` box values squashed into (0, 1) by a sigmoid.
        """
        x = self.backbone(input) # (B, 512, H', W')
        x = self.conv(x) # (B, dim_embed, H', W')

        x = x.permute(0, 2, 3, 1) # (B, H', W', dim_embed)
        x = x + self._pos_encoding_for(x)
        x = x.flatten(1, 2) # (B, H'*W', dim_embed)

        # Tile the queries to the *actual* batch size (a fixed count breaks on
        # any batch whose size differs, e.g. the last partial batch).
        q = self.query_pos_embedding.repeat(x.size(0), 1, 1) # (B, 20, dim_embed)
        q = self.transformer(x, q)

        return {
            'pred_logits': self.linear_class(q),
            'pred_boxes': torch.sigmoid(self.linear_bbox(q))
        }


In [29]:
def _collate(samples):
    imgs = tuple( img for img, _ in samples )
    targets = tuple( target for _, target in samples )
    return torch.stack(imgs), tuple(targets)

In [30]:
# Shuffled training batches of 4; `_collate` stacks the images and keeps targets as a per-image tuple.
data_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=_collate)

In [31]:
# Pull a single batch to experiment with in the following cells.
imgs, targets = next(iter(data_loader))

In [32]:
# Instantiate the detector with its default embedding width (dim_embed=256).
model = DETR()

In [33]:
@torch.no_grad()
def match_indices(outputs: dict, targets, cost_bbox_factor=0.3, cost_class_factor=0.3, cost_giou_factor=0.3):
    """Match predictions to ground-truth objects with the Hungarian algorithm.

    For every image a cost matrix (queries x targets) is built as a weighted sum
    of a classification cost, an L1 box cost and a negative generalized-IoU cost;
    ``scipy.optimize.linear_sum_assignment`` then finds the minimum-cost
    one-to-one assignment. Runs without gradient tracking.

    Params:
        outputs: This is a dict that contains at least these entries:
                "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates

        targets: This is a sequence of targets (len(targets) = batch_size), where each target is a dict containing:
                "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                        objects in the target) containing the class labels
                "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates

        cost_bbox_factor: Weight of the L1 box cost.
        cost_class_factor: Weight of the classification cost.
        cost_giou_factor: Weight of the generalized-IoU cost.

    Returns:
        A list of size batch_size, containing tuples of (index_i, index_j) where:
            - index_i is the indices of the selected predictions (in order)
            - index_j is the indices of the corresponding selected targets (in order)
        Both are int64 CPU tensors. For each batch element, it holds:
            len(index_i) = len(index_j) = min(num_queries, num_target_boxes)

    NOTE(review): both predicted and target boxes are passed through
    ``box_cxcywh_to_xyxy``, so targets are assumed to be in the same (cx, cy, w, h)
    convention and scale as the predictions -- confirm against the dataset.
    """
    bs, num_queries = outputs["pred_logits"].shape[:2]

    # We flatten to compute the cost matrices in a batch
    out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
    out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

    # Also concat the target labels and boxes
    tgt_ids = torch.cat([v["labels"] for v in targets])
    tgt_bbox = torch.cat([v["boxes"] for v in targets])

    # Compute the classification cost. Contrary to the loss, we don't use the NLL,
    # but approximate it in 1 - proba[target class].
    # The 1 is a constant that doesn't change the matching, it can be omitted.
    cost_class = -out_prob[:, tgt_ids]

    # Compute the L1 cost between boxes
    cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

    # Compute the giou cost between boxes
    cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))

    # Final cost matrix
    C = cost_bbox_factor * cost_bbox + cost_class_factor * cost_class + cost_giou_factor * cost_giou

    sizes = [len(v["boxes"]) for v in targets]
    # Give the last dimension explicitly: with zero targets in the whole batch the
    # tensor is empty and a `-1` placeholder cannot be inferred (view() raises).
    C = C.view(bs, num_queries, sum(sizes)).cpu()

    # Column block i holds the targets of image i; row block i holds its queries.
    indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
    return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]

In [34]:
# Forward pass on the sample batch; the result is a dict with 'pred_logits' and 'pred_boxes'.
output = model(imgs)
#match_indices(output, targets)


In [35]:
# Sanity check: expected layout is (batch, num_queries, n_classes + 1).
print(output["pred_logits"].shape)

torch.Size([4, 20, 10])


In [36]:
# Hungarian matching between the predictions and the ground truth of the sample batch.
indices = match_indices(output, targets)

In [37]:
def _get_src_permutation_idx(indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx


# Show the (image index, query index) pairs of all matched predictions in the sample batch.
_get_src_permutation_idx(indices)

(tensor([0, 0, 1, 1, 1, 2, 2, 3, 3, 3, 3, 3]),
 tensor([ 2, 15,  0,  2,  5,  0, 16,  0,  1,  9, 15, 18]))

In [38]:
class HungarianLoss(nn.Module):
    """Work-in-progress loss module; currently an unfinished stub.

    NOTE(review): no ``__init__`` or ``forward`` is defined and the only method
    stops right after reading the logits. A standalone ``loss_labels`` function
    is defined further down in this notebook -- TODO finish or remove this class.
    """
    def loss_labels(self, outputs, targets, indices):
        """Read ``outputs['pred_logits']``; nothing is computed yet (returns ``None``)."""
        pred_logits = outputs['pred_logits']

        

In [39]:
# Per-class weights for the classification loss.
# Derive the class count from `n_classes` (len(class_names)) instead of repeating
# a literal, so it cannot drift from the size of the model's classification head.
num_classes = n_classes
# One weight per real class plus a final entry for the extra "no object" class.
empty_weight = torch.ones(num_classes + 1)
# Down-weight the "no object" class, which is the target of every unmatched query.
eofs_coef = 0.1
empty_weight[-1] = eofs_coef

In [40]:
# Inspect the class weights: 1.0 for every class except the last entry.
empty_weight

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        0.1000])

In [41]:
import torch.nn.functional as F

def loss_labels(outputs, targets, indices):
    """Classification loss (weighted cross-entropy over all queries).

    Matched queries are assigned the label of their ground-truth object; every
    other query is assigned the extra class index ``num_classes``. The
    module-level ``empty_weight`` supplies the per-class weights.

    Args:
        outputs: Dict containing ``'pred_logits'`` of dim
            [batch_size, num_queries, num_classes + 1].
        targets: Sequence of per-image dicts; each must contain the key "labels"
            with a tensor of dim [nb_target_boxes].
        indices: Per-image ``(prediction_indices, target_indices)`` pairs, as
            returned by ``match_indices``.

    Returns:
        ``{'loss_ce': scalar tensor}``.
    """
    assert 'pred_logits' in outputs
    src_logits = outputs['pred_logits']

    idx = _get_src_permutation_idx(indices)
    # Labels of the matched ground-truth objects, in the same order as `idx`.
    target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
    # Start with every query labelled as the extra class, then fill in the matches.
    target_classes = torch.full(src_logits.shape[:2], num_classes,
                                dtype=torch.int64, device=src_logits.device)
    # Align dtype/device with the destination so the scatter cannot fail on a mismatch.
    target_classes[idx] = target_classes_o.to(target_classes)

    # cross_entropy expects the class dimension second: [batch, classes, queries].
    # The weight vector is created on CPU at module level; move it next to the logits.
    class_weight = empty_weight.to(src_logits.device)
    loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, weight=class_weight)
    losses = {'loss_ce': loss_ce}
    return losses

In [43]:
# Evaluate the classification loss on the sample batch.
loss_labels(output, targets, indices)