-
Notifications
You must be signed in to change notification settings - Fork 401
/
ssd.py
153 lines (125 loc) · 6.19 KB
/
ssd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Single Shot Object Detection model with pretrained ResNet34 backbone extending :class:`.ComposerModel`."""
import os
import tempfile
from typing import Any, Sequence, Tuple, Union
import numpy as np
import requests
from torch import Tensor
from torchmetrics import Metric, MetricCollection
from composer.models.base import ComposerModel
from composer.models.ssd.base_model import Loss
from composer.models.ssd.ssd300 import SSD300
from composer.models.ssd.utils import Encoder, SSDTransformer, dboxes300_coco
from composer.utils.import_helpers import MissingConditionalImportError
__all__ = ['SSD']
class SSD(ComposerModel):
    """Single Shot Object detection Model with pretrained ResNet34 backbone extending :class:`.ComposerModel`.

    Args:
        input_size (int, optional): input image size. Default: ``300``.
        overlap_threshold (float, optional): Minimum IOU threshold for NMS. Default: ``0.5``.
        nms_max_detections (int, optional): Max number of boxes after NMS. Default: ``200``.
        num_classes (int, optional): The number of classes to detect. Default: ``80``.
        data (str, optional): path to coco dataset. Default: ``"/localdisk/coco"``.
    """

    def __init__(self,
                 input_size: int = 300,
                 overlap_threshold: float = 0.5,
                 nms_max_detections: int = 200,
                 num_classes: int = 80,
                 data: str = '/localdisk/coco'):
        super().__init__()
        self.input_size = input_size
        self.overlap_threshold = overlap_threshold
        self.nms_max_detections = nms_max_detections
        self.num_classes = num_classes

        # Download the pretrained ResNet34 backbone weights to a temporary file;
        # SSD300 loads them from disk during construction, after which the
        # temporary directory is discarded.
        url = 'https://download.pytorch.org/models/resnet34-333f7ec4.pth'
        with tempfile.TemporaryDirectory() as tempdir:
            with requests.get(url, stream=True) as r:
                r.raise_for_status()  # fail loudly on a bad HTTP response
                pretrained_backbone = os.path.join(tempdir, 'weights.pth')
                with open(pretrained_backbone, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
                self.module = SSD300(self.num_classes, model_path=pretrained_backbone)

        # Default boxes are shared between the loss and the box decoder.
        dboxes = dboxes300_coco()
        self.loss_func = Loss(dboxes)
        self.encoder = Encoder(dboxes)
        self.data = data
        self.MAP = coco_map(self.data)

        # Validation dataset: needed at eval time to map predicted label
        # indices back to COCO category ids (see validate()).
        val_annotate = os.path.join(self.data, 'annotations/instances_val2017.json')
        val_coco_root = os.path.join(self.data, 'val2017')
        val_trans = SSDTransformer(dboxes, (self.input_size, self.input_size), val=True)
        from composer.datasets.coco import COCODetection
        self.val_coco = COCODetection(val_coco_root, val_annotate, val_trans)

    def loss(self, outputs: Any, batch: Any) -> Union[Tensor, Sequence[Tensor]]:
        """Compute the SSD multibox loss for one batch.

        Args:
            outputs: ``(ploc, plabel)`` tuple produced by :meth:`forward`.
            batch: 5-tuple whose last two elements are the ground-truth boxes and labels.

        Raises:
            TypeError: if the ground-truth boxes are not a single Tensor.
        """
        (_, _, _, bbox, label) = batch  #type: ignore
        if not isinstance(bbox, Tensor):
            raise TypeError('bbox must be a singular tensor')
        # Loss expects boxes as (N, 4, num_boxes); the batch provides (N, num_boxes, 4).
        trans_bbox = bbox.transpose(1, 2).contiguous()
        ploc, plabel = outputs
        gloc, glabel = trans_bbox, label
        loss = self.loss_func(ploc, plabel, gloc, glabel)
        return loss

    def metrics(self, train: bool = False) -> Union[Metric, MetricCollection]:
        """Return the mAP metric (the same metric is used for train and eval)."""
        return self.MAP

    def forward(self, batch: Any) -> Tuple[Tensor, Tensor]:
        """Run the SSD300 module on the images in ``batch``.

        Returns:
            ``(ploc, plabel)``: predicted box locations and class logits.
        """
        (img, _, _, _, _) = batch  #type: ignore
        ploc, plabel = self.module(img)
        return ploc, plabel  #type: ignore

    def validate(self, batch: Any) -> Tuple[Any, Any]:
        """Decode predictions for one validation batch into COCO result rows.

        Each returned row is ``[image_id, x, y, w, h, score, category_id]`` in
        original-image pixel coordinates, as expected by ``COCO.loadRes``.
        """
        # Map the contiguous training label indices back to COCO category ids.
        inv_map = {v: k for k, v in self.val_coco.label_map.items()}
        ret = []
        overlap_threshold = self.overlap_threshold
        nms_max_detections = self.nms_max_detections
        (img, img_id, img_size, _, _) = batch  #type: ignore
        ploc, plabel = self.module(img)

        results = []
        try:
            results = self.encoder.decode_batch(ploc,
                                                plabel,
                                                overlap_threshold,
                                                nms_max_detections,
                                                nms_valid_thresh=0.05)
        except Exception:
            # decode_batch can fail when no box clears the thresholds; treat
            # that as "no detections" rather than aborting validation.
            print('No object detected')

        (htot, wtot) = [d.cpu().numpy() for d in img_size]  #type: ignore
        img_id = img_id.cpu().numpy()  #type: ignore
        if len(results) > 0:
            # Iterate over batch elements
            for img_id_, wtot_, htot_, result in zip(img_id, wtot, htot, results):
                loc, label, prob = [r.cpu().numpy() for r in result]  #type: ignore
                # Iterate over image detections; boxes are normalized
                # (x1, y1, x2, y2) and are rescaled to pixel (x, y, w, h).
                for loc_, label_, prob_ in zip(loc, label, prob):
                    ret.append([
                        img_id_,
                        loc_[0] * wtot_,
                        loc_[1] * htot_,
                        (loc_[2] - loc_[0]) * wtot_,
                        (loc_[3] - loc_[1]) * htot_,
                        prob_,
                        inv_map[label_],
                    ])
        return ret, ret
class coco_map(Metric):
    """Torchmetrics wrapper computing COCO bbox mAP over accumulated predictions.

    Args:
        data (str): path to the COCO dataset root; must contain
            ``annotations/instances_val2017.json``.

    Raises:
        MissingConditionalImportError: if ``pycocotools`` is not installed.
    """

    def __init__(self, data):
        super().__init__()
        try:
            from pycocotools.coco import COCO
        except ImportError as e:
            raise MissingConditionalImportError(extra_deps_group='coco',
                                                conda_channel='conda-forge',
                                                conda_package='pycocotools') from e
        # One list of detection rows ([img_id, x, y, w, h, score, cat_id])
        # is appended per update() call.
        self.add_state('predictions', default=[])
        val_annotate = os.path.join(data, 'annotations/instances_val2017.json')
        self.cocogt = COCO(annotation_file=val_annotate)

    def update(self, pred, target):
        """Accumulate one batch of detection rows.

        ``target`` is unused: the ground truth comes from the COCO annotation
        file loaded in ``__init__``.
        """
        self.predictions.append(pred)  #type: ignore

    def compute(self):
        """Evaluate accumulated detections against the ground truth and return mAP.

        Raises:
            MissingConditionalImportError: if ``pycocotools`` is not installed.
        """
        try:
            from pycocotools.cocoeval import COCOeval
        except ImportError as e:
            raise MissingConditionalImportError(extra_deps_group='coco',
                                                conda_channel='conda-forge',
                                                conda_package='pycocotools') from e
        # predictions is a list of per-batch row lists; flatten to the N x 7
        # array that COCO.loadRes requires (a nested array would be rejected).
        flat_rows = [row for batch in self.predictions for row in batch]  #type: ignore
        cocoDt = self.cocogt.loadRes(np.array(flat_rows))
        E = COCOeval(self.cocogt, cocoDt, iouType='bbox')
        E.evaluate()
        E.accumulate()
        E.summarize()
        # stats[0] is AP @ IoU=0.50:0.95, area=all, maxDets=100.
        return E.stats[0]