In [1]:
from kitti_detection import config
from kitti_detection.dataset import DataSample, class_names, load_train_val_test_dataset
from kitti_detection.utils import display_samples_h

import torch
import torchvision
from torch import nn, optim, Tensor
from torch.nested import nested_tensor
from torch.utils.data import DataLoader
from torchvision.transforms import v2
from torchvision.tv_tensors import BoundingBoxes

In [2]:
# Training-time preprocessing applied to every sample of the training split.
transforms = v2.Compose([
    # Fixed-size random crop so every image ends up 370x370.
    v2.RandomCrop(size=(370, 370)),
    # Remove boxes that the crop left degenerate or outside the image.
    v2.SanitizeBoundingBoxes(),
    # scale=True maps integer images to [0, 1] before normalisation; for
    # inputs that are already floating point it is a plain dtype cast.
    # Without it, uint8 images would be normalised on a 0..255 range.
    v2.ToDtype(torch.float32, scale=True),
    # ImageNet channel statistics, matching the ImageNet-pretrained backbone.
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [3]:
def p(t):
    """Debug helper: print ``t``, then ``t.size()``, then an empty line."""
    print(t)
    # end="\n\n" terminates the size line and emits the trailing blank line.
    print(t.size(), end="\n\n")

In [4]:
class DETR(nn.Module):
    """DETR-style detector skeleton: ResNet-18 features, transformer and prediction heads.

    Args:
        dim_embed: Channel width of the projected backbone features, the
            positional encodings and the transformer.

    Notes:
        * The classification head reads the notebook-level ``n_classes`` and
          the positional buffers are built with the notebook-level helpers
          ``create_pos_encoding`` / ``get_1d_pos_encoding``; all three must be
          defined before the model is instantiated.
        * ``forward`` currently only runs the backbone. The projection,
          transformer and heads are constructed but not yet wired in.
    """

    def __init__(self, dim_embed=256):
        # nn.Module has to be initialised before any sub-module or buffer is
        # assigned; doing it last makes the first assignment raise.
        super().__init__()

        self.backbone = self._backbone()
        # 1x1 conv projecting the 512 backbone channels to the embedding width.
        self.conv = nn.Conv2d(512, dim_embed, kernel_size=1)

        # Fixed (non-trainable) encodings: a 12x12 spatial grid (the feature-map
        # size obtained from the notebook's 370x370 crops) and 20 query slots.
        self.register_buffer('pos_embedding', create_pos_encoding(12, 12, dim_embed)) # (12, 12, dim_embed)
        self.register_buffer('query_pos_embedding', get_1d_pos_encoding(20, dim_embed)) # (20, dim_embed)

        self.transformer = nn.Transformer(dim_embed, nhead=8, num_encoder_layers=4, num_decoder_layers=4, batch_first=True)

        # n_classes + 1 logits per query; the extra slot is presumably DETR's
        # "no object" class -- TODO confirm against the loss.
        self.linear_class = nn.Linear(dim_embed, n_classes + 1)
        self.linear_bbox = nn.Linear(dim_embed, 4)

    def _backbone(self) -> nn.Module:
        """Return an ImageNet-pretrained ResNet-18 whose forward stops after ``layer4``."""
        backbone = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
        # The pooling and fully-connected layers are not used by the
        # replacement forward below.
        del backbone.fc
        del backbone.avgpool

        def _forward(bb: torchvision.models.ResNet, x):
            # Stem.
            x = bb.conv1(x)
            x = bb.bn1(x)
            x = bb.relu(x)
            x = bb.maxpool(x)

            # Residual stages; the output of layer4 is returned unpooled.
            x = bb.layer1(x)
            x = bb.layer2(x)
            x = bb.layer3(x)
            x = bb.layer4(x)
            return x

        backbone.forward = lambda x: _forward(backbone, x)
        # Hand the patched model back; without this, self.backbone would be None.
        return backbone

    def forward(self, input):
        """Run the backbone on ``input`` and return its feature map.

        The feature-map size is printed as a shape check while the rest of
        the pipeline is still being assembled.
        """
        x = self.backbone(input)
        print(x.size())
        return x



In [5]:
# Load the three dataset splits through the project helper.
train_dataset, val_dataset, test_dataset = load_train_val_test_dataset()

# Only the training split is given the preprocessing pipeline here; the
# validation and test splits are left as returned by the loader.
train_dataset.transform = transforms

In [6]:
# Number of object categories, taken from the dataset's class-name list.
# DETR.__init__ reads this name when it builds its classification head.
n_classes = len(class_names)

In [7]:
# batch_size=None switches off automatic batching, so the loader yields one
# shuffled dataset sample at a time instead of a collated batch.
data_loader = DataLoader(train_dataset, batch_size=None, shuffle=True)
#display_samples_h([next(iter(data_loader)) for _ in range(4)])

In [8]:
# Shape check: element 0 of a sample is the image tensor (channels, height, width).
next(iter(data_loader))[0].size()

torch.Size([3, 370, 370])

In [9]:
# Stack the image tensors of four samples into one (4, 3, 370, 370) batch.
# All four come from a single shuffled pass over the loader, so none can
# repeat; building a fresh iterator per sample reshuffled every time.
# range(4) is listed first so zip stops before requesting a fifth sample.
input_batch = torch.stack([sample[0] for _, sample in zip(range(4), data_loader)])

In [10]:
# Confirm the stacked batch layout: (batch, channels, height, width).
print(input_batch.size())

torch.Size([4, 3, 370, 370])


In [11]:
# Embedding width for the notebook-level experiments (same value as DETR's
# default `dim_embed`).
dim_embed = 256

In [32]:
def get_1d_pos_encoding(l, dim):
    """Return the sinusoidal positional encoding for ``l`` positions.

    Column ``2*i`` holds ``sin(pos / 10000**(2*i/dim))`` and column ``2*i + 1``
    holds the matching cosine, for ``pos = 0 .. l-1`` and
    ``i = 0 .. dim//2 - 1``.

    Args:
        l: Number of positions (rows of the result).
        dim: Requested encoding width.

    Returns:
        Tensor of shape ``(l, 2 * (dim // 2))`` in the default float dtype;
        the width equals ``dim`` whenever ``dim`` is even.
    """
    dtype = torch.get_default_dtype()
    positions = torch.arange(l, dtype=dtype).unsqueeze(1)  # (l, 1)
    pair_index = torch.arange(dim // 2, dtype=dtype)  # (dim // 2,)
    # Positions are divided by the wavelength term, so the frequency falls as
    # the channel index grows and does not depend on the sequence length.
    inv_wavelength = torch.pow(10000.0, -2.0 * pair_index / dim)
    angles = positions * inv_wavelength  # (l, dim // 2)
    # Stack on a trailing axis and flatten so sin/cos of one frequency sit in
    # adjacent columns: sin_0, cos_0, sin_1, cos_1, ...
    return torch.stack((angles.sin(), angles.cos()), dim=-1).flatten(1)

def create_pos_encoding(h, w, dim):
    """Build a 2-D positional encoding for an ``h`` x ``w`` grid.

    Each axis gets a 1-D encoding requested at width ``dim // 2``. The
    width-axis encoding fills the first part of the last dimension and the
    height-axis encoding the second part.

    Returns:
        Tensor of shape ``(h, w, C)`` where ``C`` is twice the width produced
        by ``get_1d_pos_encoding(..., dim // 2)``.
    """
    half = dim // 2
    # One w-long encoding, shared by every row of the grid.
    along_width = get_1d_pos_encoding(w, half).unsqueeze(0).expand(h, -1, -1)
    # One h-long encoding, shared by every column of the grid.
    along_height = get_1d_pos_encoding(h, half).unsqueeze(1).expand(-1, w, -1)

    return torch.cat([along_width, along_height], dim=-1)

In [15]:
# Build the query encoding here so this cell runs on a fresh kernel: 20 query
# slots of width dim_embed, the same call DETR.__init__ uses for its buffer.
query_pos_embedding = get_1d_pos_encoding(20, dim_embed)
query_pos_embedding.size()

torch.Size([20, 256])

In [30]:
# Manual walk-through of the detector's forward pass on `input_batch`.
# NOTE(review): the modules and embeddings used below (`backbone`, `conv`,
# `pos_embedding`, `transformer`, `linear_class`, `linear_bbox`, ...) must
# already exist in the kernel; not all of them are assigned in the cells
# visible here -- confirm before relying on Restart & Run All.
x = backbone(input_batch) # (4, 512, 12, 12)
x = conv(x) # (4, 256, 12, 12)

# Move channels last so the positional grid broadcasts over the batch, then
# merge the two spatial axes into one sequence axis.
x = x.permute(0, 2, 3, 1) # (4, 12, 12, 256)
x = x + pos_embedding
x = x.flatten(1, 2) # (4, 144, 256)

# One copy of the query encoding per image; the batch size is read from the
# features instead of being hard-coded.
q = query_pos_embedding.repeat(x.size(0), 1, 1) # (4, 20, 256)

# Image tokens are the source sequence, the queries the target sequence.
q = transformer(x, q)

pred_logits = linear_class(q)
# Sigmoid keeps every box value in (0, 1).
pred_bboxes = torch.sigmoid(linear_bbox(q))

In [31]:
# Print the predicted boxes followed by their shape.
p(pred_bboxes)

tensor([[[0.2989, 0.3752, 0.6840, 0.2632],
         [0.3852, 0.2919, 0.7943, 0.4567],
         [0.3604, 0.3874, 0.8037, 0.4133],
         [0.2853, 0.3429, 0.7407, 0.5648],
         [0.3136, 0.4665, 0.6625, 0.3559],
         [0.4037, 0.4302, 0.8041, 0.3987],
         [0.2630, 0.4818, 0.7175, 0.2959],
         [0.2425, 0.4006, 0.7613, 0.3108],
         [0.2651, 0.3428, 0.7743, 0.3723],
         [0.4702, 0.4934, 0.7538, 0.3022],
         [0.3149, 0.3545, 0.6274, 0.4723],
         [0.4575, 0.3948, 0.6069, 0.4462],
         [0.4119, 0.5753, 0.6982, 0.4272],
         [0.3378, 0.4335, 0.6368, 0.3993],
         [0.3655, 0.3129, 0.7284, 0.4467],
         [0.3443, 0.4044, 0.8465, 0.3359],
         [0.3950, 0.4006, 0.7266, 0.4710],
         [0.2930, 0.2794, 0.7905, 0.4170],
         [0.2964, 0.4281, 0.5667, 0.4747],
         [0.3113, 0.3822, 0.7032, 0.4106]],

        [[0.3520, 0.3492, 0.6631, 0.3911],
         [0.2674, 0.3208, 0.8503, 0.3617],
         [0.3668, 0.2737, 0.7978, 0.2949],
         

In [21]:
# Show how many object categories the dataset defines.
print(n_classes)

9


In [None]:
# Print the class count. The bare name `n_` is not defined anywhere and would
# raise NameError when the notebook is run top to bottom.
print(n_classes)