In [1]:
import matplotlib.pyplot as plt
import torch
from torchvision import datasets, utils
from torchvision.transforms import v2

from detection_tools.data.pascal_voc import CLASSES, tensorize_target

In [2]:
# Build the backbone with its default constructor arguments; it produces the
# feature maps consumed by the rest of this notebook.
from detection_tools.ssd_utils import VGG16FeatureExtractor
feature_extractor = VGG16FeatureExtractor()

In [3]:
# Bare expression: shows the module repr so the layer layout can be inspected.
feature_extractor

VGG16FeatureExtractor(
  (features_conv4_3): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=

In [4]:
# Smoke test: feed one random 3-channel 300x300 tensor through the extractor
# and print the shape of every feature map it returns.
img = torch.randn(1, 3, 300, 300)
features = feature_extractor(img)
for feature in features:
    print(feature.shape)

torch.Size([1, 512, 38, 38])
torch.Size([1, 1024, 19, 19])
torch.Size([1, 512, 10, 10])
torch.Size([1, 256, 5, 5])
torch.Size([1, 256, 3, 3])
torch.Size([1, 256, 1, 1])


In [5]:
# Image preprocessing: convert to a tensor image, resize to 300x300, cast to
# float32 with value scaling, then normalize per channel.
transform = v2.Compose([
    v2.ToImage(),
    v2.Resize((300, 300)),
    v2.ToDtype(torch.float32, scale=True),
    # NOTE(review): this mean equals (123, 117, 104) / 255, which torchvision's
    # SSD300-VGG16 pairs with std = 1/255, while the std used here is the usual
    # ImageNet std (normally paired with mean [0.485, 0.456, 0.406]).
    # Confirm which statistics the VGG16FeatureExtractor weights expect.
    v2.Normalize(mean=[0.48235, 0.45882, 0.40784], std=[0.229, 0.224, 0.225]),
])
# Only `root` and the two transforms are passed, so every other VOCDetection
# option keeps its library default.
# NOTE(review): `transform` is applied to the image only and `target_transform`
# to the annotation only, so the 300x300 resize above does not touch the boxes
# here. Presumably tensorize_target accounts for that; verify.
ds = datasets.VOCDetection(
    root="../data/Pascal_VOC_Detection",
    transform=transform,
    target_transform=tensorize_target
)

In [6]:
def collate_fn(batch):
    """Collate ``(image, target)`` samples into one batch.

    Images are stacked along a new leading dimension. Targets are not
    stacked; they are returned as a plain list, one entry per sample, in the
    same order as the images.

    Args:
        batch: Sequence of samples where index 0 is the image tensor and
            index 1 is that image's target.

    Returns:
        Tuple ``(images, targets)``: the stacked image tensor and the list
        of targets.
    """
    image_list = [sample[0] for sample in batch]
    target_list = [sample[1] for sample in batch]
    return torch.stack(image_list, dim=0), target_list

In [7]:
# Batches of 4 samples, assembled by collate_fn (images stacked, targets kept
# as a list).
# Shuffling draws from a dedicated fixed-seed generator so that a fresh
# Restart & Run All selects the same first batch; without it every run picks
# different images and the outputs recorded below cannot be reproduced.
dl = torch.utils.data.DataLoader(
    ds,
    batch_size=4,
    collate_fn=collate_fn,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)

In [None]:
# Pull a single batch from the loader: `images` is the stacked image tensor
# and `targets` is the per-image list returned by collate_fn.
images, targets = next(iter(dl))

In [9]:
# Run the backbone on the real batch; the result is the collection of feature
# maps that the detection head consumes later.
fmaps = feature_extractor(images)

In [10]:
# Print the shape of every feature map produced for the batch.
for fmap in fmaps:
    print(fmap.shape)

torch.Size([4, 512, 38, 38])
torch.Size([4, 1024, 19, 19])
torch.Size([4, 512, 10, 10])
torch.Size([4, 256, 5, 5])
torch.Size([4, 256, 3, 3])
torch.Size([4, 256, 1, 1])


In [11]:
# Anchor setup: six aspect-ratio lists are supplied, presumably one per
# feature map returned by the extractor (confirm in AnchorGenerator).
from detection_tools.ssd_utils import AnchorGenerator
anchor_generator = AnchorGenerator(
    aspect_ratios=[
        [1.0, 2.0],
        [1.0, 2.0, 3.0],
        [1.0, 2.0, 3.0],
        [1.0, 2.0, 3.0],
        [1.0, 2.0],
        [1.0, 2.0]
    ]
)
# Generate the anchors from (300, 300), presumably the input image size, and
# the extractor's MAP_SHAPES_300 class attribute.
anchors = anchor_generator.generate_anchors((300, 300), VGG16FeatureExtractor.MAP_SHAPES_300)

In [12]:
# One anchor set per image in the batch (every list entry refers to the same
# object). The set is regenerated from `anchor_generator` rather than wrapping
# the existing `anchors` variable, so re-running this cell is idempotent and
# cannot nest the list one level deeper on each execution.
anchors = [
    anchor_generator.generate_anchors((300, 300), VGG16FeatureExtractor.MAP_SHAPES_300)
] * images.shape[0]

In [13]:
# Match anchors against ground-truth boxes, one image at a time.
from detection_tools.ssd_utils import Matcher, SSDHead, SSDLoss
matcher = Matcher(iou_threshold=0.5)
matches = []
for target, anchor in zip(targets, anchors):
    # One result per image, built from that image's "boxes" and its anchor set.
    # The printed output shows -1 for most anchors and non-negative values
    # elsewhere, presumably ground-truth indices (confirm in Matcher).
    match = matcher(target["boxes"], anchor)
    matches.append(match)
print(matches)

[tensor([-1, -1, -1,  ..., -1, -1, -1]), tensor([-1, -1, -1,  ..., -1, -1, -1]), tensor([-1, -1, -1,  ..., -1, -1, -1]), tensor([-1, -1, -1,  ...,  0,  0, -1])]


In [14]:
# Per image, report how many anchors carry a non-negative match value.
for match in matches:
    # Counting the True entries of the mask equals the number of elements the
    # mask would select.
    print((match >= 0).sum().item())

3
296
38
15


In [15]:
# Offset helper built with default arguments; it is handed to both the loss
# and the predictor through their `o_handler` argument.
from detection_tools.ssd_utils import OffsetHandler
o_handler = OffsetHandler()

In [16]:
# Length of each entry in the generator's `height_width_pairs`; this list is
# later passed to the head as its `num_anchors` argument.
num_anchors_per_location_per_map = list(map(len, anchor_generator.height_width_pairs))
print(num_anchors_per_location_per_map)

[4, 6, 6, 6, 4, 4]


In [17]:
# Channel count of each backbone feature map, in the order the maps are
# returned; passed to the head as its `channels` argument.
in_channels = [512, 1024, 512, 256, 256, 256]
# Guard against the literal drifting from the backbone: dimension 1 of every
# feature map must equal the corresponding entry.
assert in_channels == [fmap.shape[1] for fmap in fmaps], "in_channels does not match the feature maps"

In [18]:
# Detection head with len(CLASSES) + 1 classes (the extra slot is presumably
# background; confirm in SSDHead), configured with the per-map anchor counts
# and the per-map input channels.
head = SSDHead(num_classes=len(CLASSES)+1, num_anchors=num_anchors_per_location_per_map, channels=in_channels)

In [19]:
# Forward the feature maps through the head and print the shapes of the two
# entries it returns under "offsets" and "cls_logits".
head_outputs = head(fmaps)
print(head_outputs["offsets"].shape)
print(head_outputs["cls_logits"].shape)

torch.Size([4, 8732, 4])
torch.Size([4, 8732, 21])


In [20]:
# Loss object sharing the same offset handler; neg_pos_ratio=3 presumably sets
# the negative-to-positive sampling ratio (confirm in SSDLoss).
loss = SSDLoss(neg_pos_ratio=3, o_handler=o_handler)

In [21]:
# Evaluate the loss on the batch. With raw=True the printed result is a dict
# holding separate 'reg_loss' and 'cls_loss' tensors.
loss_vals = loss(targets, head_outputs, anchors, matches, raw=True)
print(loss_vals)

{'reg_loss': tensor(0.8864, grad_fn=<DivBackward0>), 'cls_loss': tensor(20.4386, grad_fn=<DivBackward0>)}


In [22]:
# Predictor built around the same offset handler that the loss uses.
from detection_tools.ssd_utils import SSDPredictor
predictor = SSDPredictor(o_handler=o_handler)


In [23]:
# Turn the raw head outputs and the per-image anchors into predictions.
# NOTE(review): the head has not been trained in this notebook, so these
# predictions only exercise the pipeline.
preds = predictor(head_outputs, anchors)

In [28]:
# Show the "scores" entry of the first image's predictions.
# NOTE(review): this dumps the whole tensor; consider displaying a slice or
# summary statistics to keep the saved output small.
preds[0]["scores"]

tensor([0.8074, 0.8046, 0.7584, 0.7526, 0.7494, 0.7437, 0.7432, 0.7298, 0.7150,
        0.7113, 0.7089, 0.7060, 0.6956, 0.6882, 0.6851, 0.6810, 0.6775, 0.6764,
        0.6753, 0.6739, 0.6708, 0.6676, 0.6650, 0.6561, 0.6526, 0.6515, 0.6513,
        0.6463, 0.6438, 0.6397, 0.6385, 0.6381, 0.6366, 0.6334, 0.6309, 0.6308,
        0.6263, 0.6240, 0.6238, 0.6197, 0.6171, 0.6164, 0.6136, 0.6088, 0.6070,
        0.6064, 0.6057, 0.6051, 0.6037, 0.6035, 0.5950, 0.5940, 0.5933, 0.5911,
        0.5903, 0.5881, 0.5874, 0.5861, 0.5855, 0.5831, 0.5828, 0.5817, 0.5813,
        0.5809, 0.5755, 0.5731, 0.5724, 0.5723, 0.5708, 0.5707, 0.5672, 0.5661,
        0.5652, 0.5638, 0.5615, 0.5605, 0.5604, 0.5602, 0.5601, 0.5585, 0.5582,
        0.5548, 0.5520, 0.5478, 0.5445, 0.5424, 0.5423, 0.5416, 0.5390, 0.5381,
        0.5380, 0.5379, 0.5367, 0.5352, 0.5347, 0.5335, 0.5321, 0.5287, 0.5285,
        0.5279, 0.5244, 0.5238, 0.5224, 0.5224, 0.5215, 0.5214, 0.5200, 0.5200,
        0.5198, 0.5189, 0.5187, 0.5180, 