diff --git a/CHANGELOG.md b/CHANGELOG.md
index 57ea3bb647..6cdd4d442f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
--
+- YOLO-X (new) detection example and refactoring ([#1324](https://github.com/catalyst-team/catalyst/pull/1324))
 
 ### Changed
 
diff --git a/examples/detection/README.md b/examples/detection/README.md
index 5246410aba..bdf4cedd79 100644
--- a/examples/detection/README.md
+++ b/examples/detection/README.md
@@ -25,28 +25,44 @@ python3 to_coco.py
 
 
 ## Single Shot Detector (SSD)
 
-For training SSD it is required for the dataset to return bounding box tensor with shapes `[num anchors, 4]` and labels tensor with shapes `[num anchors]`, `num anchors` is a fixed number of predictions generated by a model.
+SSD expects the dataset to return a bounding box tensor with shape `[num anchors, 4]` and a labels tensor with shape `[num anchors]`, where `num anchors` is a fixed number of predictions generated by the model.
 Each bounding box should be in a format `[x_center, y_center, width, height]`.
 Usually, there are fewer objects on an image than `num anchors` returned by a model, to handle this your model should have an additional class for `background` and bounding boxes to ignore (or empty bounding boxes) should be marked with that class.
 
 ```bash
 catalyst-dl run \
-    --config catalyst/examples/detection/ssd-config.yaml \
+    --config catalyst/examples/detection/configs/ssd-config.yaml \
     --expdir catalyst/examples/detection \
-    --logdir ssd-detection-logs
+    --logdir ssd-detection-logs \
     --verbose
 ```
 
 
-## Training CenterNet
+## CenterNet
 
-For training CenterNet it is required for the dataset to return heatmap tensor with object centers (expected shape is `[num classes, heatmap height, heatmap width]`) and tensor for Width-Height regression (expected shape is `[2, heatmap height, heatmap width]`).
+CenterNet returns two tensors: a class heatmap (tensor with shape `[num classes, heatmap height, heatmap width]`) and a width-height regression (tensor with shape `[2, heatmap height, heatmap width]`). The latter represents bounding box width and height. To get object centers (`x_center` and `y_center`) you can apply MaxPool with a `3x3` kernel, stride `1`, and padding `1` to the class heatmap.
 
 ```bash
 catalyst-dl run \
-    --config catalyst/examples/detection/centernet-config.yaml \
+    --config catalyst/examples/detection/configs/centernet-config.yaml \
     --expdir catalyst/examples/detection \
-    --logdir centernet-detection-logs
+    --logdir centernet-detection-logs \
+    --verbose
+```
+
+
+## YOLO-X
+
+YOLO-X predicts a tensor with shape `[num anchors, 4 + 1 + num classes]`. This tensor consists of bounding box coordinates, a confidence score, and class probabilities.
+The first 4 numbers of each anchor are the bounding box coordinates, index 4 is the confidence score, and the numbers after index 4 are the class probabilities.
+
+**NOTE:** YOLO-X predicts bounding boxes in the format `(x_center, y_center, w, h)`.
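+
+For example, a single raw prediction of shape `[num anchors, 4 + 1 + num classes]` can be split into boxes, confidences, and class ids as follows (a minimal NumPy sketch; the `output` array here is a placeholder, see `process_yolo_x_output` in `callbacks.py` for the full post-processing with NMS):
+
+```python
+import numpy as np
+
+# `output` stands for the predictions for a single image,
+# e.g. `runner.batch["predicted_tensor"][i].detach().cpu().numpy()`
+output = np.random.rand(3549, 4 + 1 + 4).astype(np.float32)  # placeholder: 3549 anchors, 4 classes
+
+boxes_xywh = output[:, :4]  # (x_center, y_center, width, height)
+confidences = output[:, 4]  # confidence score per anchor
+class_ids = np.argmax(output[:, 5:], axis=-1)  # most probable class per anchor
+
+# keep only confident predictions (0.25 matches `confidence_threshold` in the config)
+keep = confidences > 0.25
+boxes_xywh, confidences, class_ids = boxes_xywh[keep], confidences[keep], class_ids[keep]
+```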
+
+```bash
+catalyst-dl run \
+    --config catalyst/examples/detection/configs/yolo-x-config.yaml \
+    --expdir catalyst/examples/detection \
+    --logdir yolo-x-detection-logs \
     --verbose
 ```
diff --git a/examples/detection/__init__.py b/examples/detection/__init__.py
index a50a3dfd4e..d803aa2e8c 100644
--- a/examples/detection/__init__.py
+++ b/examples/detection/__init__.py
@@ -7,17 +7,35 @@
 from .callbacks import DetectionMeanAveragePrecision  # noqa: E402
 from .criterion import CenterNetCriterion, SSDCriterion  # noqa: E402
-from .custom_runner import CenterNetDetectionRunner, SSDDetectionRunner  # noqa: E402
-from .dataset import CenterNetDataset, SSDDataset  # noqa: E402
-from .model import CenterNet, SingleShotDetector  # noqa: E402
+from .custom_runner import (  # noqa: E402
+    CenterNetDetectionRunner,
+    SSDDetectionRunner,
+    YOLOXDetectionRunner,
+)
+from .dataset import CenterNetDataset, SSDDataset, YOLOXDataset  # noqa: E402
+from .models import (  # noqa: E402
+    CenterNet,
+    SingleShotDetector,
+    yolo_x_tiny,
+    yolo_x_small,
+    yolo_x_medium,
+    yolo_x_large,
+    yolo_x_big,
+)
 
 # runers
 Registry(SSDDetectionRunner)
 Registry(CenterNetDetectionRunner)
+Registry(YOLOXDetectionRunner)
 
 # models
 Registry(SingleShotDetector)
 Registry(CenterNet)
+Registry(yolo_x_tiny)
+Registry(yolo_x_small)
+Registry(yolo_x_medium)
+Registry(yolo_x_large)
+Registry(yolo_x_big)
 
 # criterions
 Registry(SSDCriterion)
@@ -29,3 +47,4 @@
 # datasets
 Registry(SSDDataset)
 Registry(CenterNetDataset)
+Registry(YOLOXDataset)
diff --git a/examples/detection/callbacks.py b/examples/detection/callbacks.py
index 90d4f6839e..475653b18e 100644
--- a/examples/detection/callbacks.py
+++ b/examples/detection/callbacks.py
@@ -191,6 +191,60 @@ def process_centernet_output(
         yield pred_sample, gt_sample
 
 
+def process_yolo_x_output(
+    predicted_tensor,
+    gt_boxes,
+    gt_labels,
+    iou_threshold=0.5,
+):
+    """Generate bbox and classes from YOLO-X model outputs.
+
+    Args:
+        predicted_tensor (torch.Tensor): model outputs,
+            expected shape [batch, num anchors, 4 + 1 + num classes].
+        gt_boxes (torch.Tensor): ground truth bounding boxes,
+            expected shape [batch, num anchors, 4].
+        gt_labels (torch.Tensor): ground truth bounding box labels,
+            expected shape [batch, num anchors].
+        iou_threshold (float): IoU threshold to use in NMS.
+            Default is ``0.5``.
+ + Yields: + predicted sample (np.ndarray) and ground truth sample (np.ndarray) + """ + batch_size = predicted_tensor.size(0) + outputs = predicted_tensor.detach().cpu().numpy() + + _pred_boxes = outputs[:, :, :4] + _pred_confidence = outputs[:, :, 4] + _pred_cls = np.argmax(outputs[:, :, 5:], -1) + + _gt_boxes = gt_boxes.cpu().numpy() + _gt_classes = gt_labels.cpu().numpy() + _gt_boxes_mask = _gt_boxes.sum(axis=2) > 0 + + for i in range(batch_size): + # build predictions + sample_bboxes = change_box_order(_pred_boxes[i], "xywh2xyxy") + sample_bboxes, sample_classes, sample_confs = nms_filter( + sample_bboxes, _pred_cls[i], _pred_confidence[i], iou_threshold=iou_threshold + ) + pred_sample = np.concatenate( + [sample_bboxes, sample_classes[:, None], sample_confs[:, None]], -1 + ) + pred_sample = pred_sample.astype(np.float32) + + # build ground truth + sample_gt_mask = _gt_boxes_mask[i] + sample_gt_bboxes = change_box_order(_gt_boxes[i][sample_gt_mask], "xywh2xyxy") + sample_gt_classes = _gt_classes[i][sample_gt_mask] + gt_sample = np.zeros((sample_gt_classes.shape[0], 7), dtype=np.float32) + gt_sample[:, :4] = sample_gt_bboxes + gt_sample[:, 4] = sample_gt_classes + + yield pred_sample, gt_sample + + class DetectionMeanAveragePrecision(Callback): """Compute mAP for Object Detection task.""" @@ -208,7 +262,8 @@ def __init__( Default is ``1``. metric_key (str): name of a metric. Default is ``"mAP"``. - output_type (str): model output type. Valid values are ``"ssd"`` or ``"centernet"``. + output_type (str): model output type. Valid values are ``"ssd"`` or + ``"centernet"`` or ``"yolo-x"``. Default is ``"ssd"``. iou_threshold (float): IoU threshold to use in NMS. Default is ``0.5``. @@ -217,7 +272,7 @@ def __init__( Default is ``0.5``. """ super().__init__(order=CallbackOrder.Metric) - assert output_type in ("ssd", "centernet") + assert output_type in ("ssd", "centernet", "yolo-x") self.num_classes = num_classes self.metric_key = metric_key @@ -261,6 +316,14 @@ def on_batch_end(self, runner: "IRunner"): # noqa: D102, F821 confidence_threshold=self.confidence_threshold, ): self.metric_fn.add(predicted_sample, ground_truth_sample) + elif self.output_type == "yolo-x": + p_tensor = runner.batch["predicted_tensor"] + gt_box = runner.batch["bboxes"] + gt_labels = runner.batch["labels"] + for predicted_sample, ground_truth_sample in process_yolo_x_output( + p_tensor, gt_box, gt_labels, iou_threshold=self.iou_threshold + ): + self.metric_fn.add(predicted_sample, ground_truth_sample) def on_loader_end(self, runner: "IRunner"): # noqa: D102, F821 if not runner.is_valid_loader: diff --git a/examples/detection/centernet-config.yaml b/examples/detection/configs/centernet-config.yaml similarity index 87% rename from examples/detection/centernet-config.yaml rename to examples/detection/configs/centernet-config.yaml index 584633c867..b6cd45f0ad 100644 --- a/examples/detection/centernet-config.yaml +++ b/examples/detection/configs/centernet-config.yaml @@ -6,6 +6,8 @@ shared: train_imgs_dir: &train_imgs_dir ./datasets/fruit-detection/data/images valid_file: &valid ./datasets/fruit-detection/dataset.json valid_imgs_dir: &valid_imgs_dir ./datasets/fruit-detection/data/images + images_height: &img_h 512 + images_width: &img_w 512 runner: _target_: CenterNetDetectionRunner @@ -38,8 +40,8 @@ stages: _target_: albumentations.Compose transforms: - _target_: albu.Resize - height: 512 - width: 512 + height: *img_h + width: *img_w - _target_: albu.Normalize - _target_: albu.ToTensorV2 bbox_params: @@ -54,8 +56,8 @@ 
stages: _target_: albumentations.Compose transforms: - _target_: albu.Resize - height: 512 - width: 512 + height: *img_h + width: *img_w - _target_: albu.Normalize - _target_: albu.ToTensorV2 bbox_params: @@ -72,7 +74,7 @@ stages: callbacks: &callbacks periodic_validation: - _target_: PeriodicLoaderCallback + _target_: catalyst.callbacks.PeriodicLoaderCallback valid_loader_key: valid valid_metric_key: mAP minimize: False @@ -86,5 +88,5 @@ stages: confidence_threshold: 0.25 optimizer: - _target_: OptimizerCallback + _target_: catalyst.callbacks.OptimizerCallback metric_key: loss diff --git a/examples/detection/ssd-config.yaml b/examples/detection/configs/ssd-config.yaml similarity index 88% rename from examples/detection/ssd-config.yaml rename to examples/detection/configs/ssd-config.yaml index 9977d451e1..8fd03f0e6f 100644 --- a/examples/detection/ssd-config.yaml +++ b/examples/detection/configs/ssd-config.yaml @@ -7,6 +7,8 @@ shared: train_imgs_dir: &train_imgs_dir ./datasets/fruit-detection/data/images valid_file: &valid ./datasets/fruit-detection/dataset.json valid_imgs_dir: &valid_imgs_dir ./datasets/fruit-detection/data/images + images_height: &img_h 416 + images_width: &img_w 416 runner: _target_: SSDDetectionRunner @@ -39,8 +41,8 @@ stages: _target_: albumentations.Compose transforms: - _target_: albu.Resize - height: 300 - width: 300 + height: *img_h + width: *img_w - _target_: albu.Normalize - _target_: albu.ToTensorV2 bbox_params: @@ -55,8 +57,8 @@ stages: _target_: albumentations.Compose transforms: - _target_: albu.Resize - height: 300 - width: 300 + height: *img_h + width: *img_w - _target_: albu.Normalize - _target_: albu.ToTensorV2 bbox_params: @@ -74,7 +76,7 @@ stages: callbacks: &callbacks periodic_validation: - _target_: PeriodicLoaderCallback + _target_: catalyst.callbacks.PeriodicLoaderCallback valid_loader_key: valid valid_metric_key: mAP minimize: False @@ -87,5 +89,5 @@ stages: iou_threshold: 0.5 optimizer: - _target_: OptimizerCallback + _target_: catalyst.callbacks.OptimizerCallback metric_key: loss diff --git a/examples/detection/configs/yolo-x-config.yaml b/examples/detection/configs/yolo-x-config.yaml new file mode 100644 index 0000000000..ad1766e7bf --- /dev/null +++ b/examples/detection/configs/yolo-x-config.yaml @@ -0,0 +1,88 @@ +shared: + num_classes: &num_classes 4 + num_epochs: &num_epochs 100 + validation_period: &val_period 5 + train_file: &train /mnt/4tb/datasets/fruit-detection/dataset.json + train_imgs_dir: &train_imgs_dir /mnt/4tb/datasets/fruit-detection/data/images + valid_file: &valid /mnt/4tb/datasets/fruit-detection/dataset.json + valid_imgs_dir: &valid_imgs_dir /mnt/4tb/datasets/fruit-detection/data/images + images_height: &img_h 416 + images_width: &img_w 416 + +runner: + _target_: YOLOXDetectionRunner + +engine: + _target_: DeviceEngine + +model: + _target_: yolo_x_tiny + num_classes: *num_classes + +loggers: + console: + _target_: ConsoleLogger + +stages: + initial_training: + num_epochs: *num_epochs + + loaders: &loaders + batch_size: 4 + num_workers: 4 + + datasets: + train: + _target_: YOLOXDataset + coco_json_path: *train + images_dir: *train_imgs_dir + transforms: + _target_: albumentations.Compose + transforms: + - _target_: albu.Resize + height: *img_h + width: *img_w + - _target_: albu.Normalize + - _target_: albu.ToTensorV2 + bbox_params: + _target_: albu.BboxParams + format: albumentations + + valid: + _target_: YOLOXDataset + coco_json_path: *valid + images_dir: *valid_imgs_dir + transforms: + _target_: albumentations.Compose + 
transforms: + - _target_: albu.Resize + height: *img_h + width: *img_w + - _target_: albu.Normalize + - _target_: albu.ToTensorV2 + bbox_params: + _target_: albu.BboxParams + format: albumentations + + optimizer: + _target_: torch.optim.AdamW + lr: 0.001 + + callbacks: &callbacks + periodic_validation: + _target_: catalyst.callbacks.PeriodicLoaderCallback + valid_loader_key: valid + valid_metric_key: mAP + minimize: False + valid: *val_period + + mAP: + _target_: DetectionMeanAveragePrecision + num_classes: *num_classes + output_type: yolo-x + iou_threshold: 0.5 + confidence_threshold: 0.25 + + optimizer: + _target_: catalyst.callbacks.OptimizerCallback + metric_key: loss diff --git a/examples/detection/custom_runner.py b/examples/detection/custom_runner.py index a4c68edd88..2f48e1d6fb 100644 --- a/examples/detection/custom_runner.py +++ b/examples/detection/custom_runner.py @@ -1,3 +1,7 @@ +# flake8: noqa + +import torch + from catalyst.runners import ConfigRunner @@ -23,6 +27,21 @@ def handle_batch(self, batch): class CenterNetDetectionRunner(ConfigRunner): """Runner for CenterNet models.""" + def get_loaders(self, stage: str): + """Insert into loaders collate_fn. + + Args: + stage (str): sage name + + Returns: + ordered dict with torch.utils.data.DataLoader + """ + loaders = super().get_loaders(stage) + for item in loaders.values(): + if hasattr(item.dataset, "collate_fn"): + item.collate_fn = item.dataset.collate_fn + return loaders + def handle_batch(self, batch): """Do a forward pass and compute loss. @@ -42,6 +61,13 @@ def handle_batch(self, batch): self.batch_metrics["regression_loss"] = regression_loss.item() self.batch_metrics["loss"] = loss + +class YOLOXDetectionRunner(ConfigRunner): + """Runner for YOLO-X models.""" + + def get_model(self, *args, **kwargs): + return super().get_model(*args, **kwargs)() + def get_loaders(self, stage: str): """Insert into loaders collate_fn. @@ -56,3 +82,19 @@ def get_loaders(self, stage: str): if hasattr(item.dataset, "collate_fn"): item.collate_fn = item.dataset.collate_fn return loaders + + def handle_batch(self, batch): + """Do a forward pass and compute loss. + + Args: + batch (Dict[str, Any]): batch of data. + """ + + if self.is_train_loader: + images = batch["image"] + targets = torch.cat([batch["labels"].unsqueeze(-1), batch["bboxes"]], -1) + loss = self.model(images, targets) + self.batch_metrics["loss"] = loss + else: + predictions = self.model(batch["image"]) + self.batch["predicted_tensor"] = predictions diff --git a/examples/detection/dataset.py b/examples/detection/dataset.py index ae6f4c6866..178ca8cf98 100644 --- a/examples/detection/dataset.py +++ b/examples/detection/dataset.py @@ -106,6 +106,29 @@ def clip(values, min_value=0.0, max_value=1.0): return [min(max(num, min_value), max_value) for num in values] +def change_box_order(boxes, order): + """Change box order between + (xmin, ymin, xmax, ymax) <-> (xcenter, ycenter, width, height). + + Args: + boxes: (torch.Tensor or np.ndarray) bounding boxes, sized [N,4]. + order: (str) either "xyxy2xywh" or "xywh2xyxy". + + Returns: + (torch.Tensor) converted bounding boxes, sized [N,4]. 
+ """ + if order not in {"xyxy2xywh", "xywh2xyxy"}: + raise ValueError("`order` should be one of 'xyxy2xywh'/'xywh2xyxy'!") + + concat_fn = torch.cat if isinstance(boxes, torch.Tensor) else np.concatenate + + a = boxes[:, :2] + b = boxes[:, 2:] + if order == "xyxy2xywh": + return concat_fn([(a + b) / 2, b - a], 1) + return concat_fn([a - b / 2, a + b / 2], 1) + + class SSDDataset(Dataset): def __init__( self, @@ -344,3 +367,96 @@ def collate_fn(batch): for k in ("image", "heatmap", "wh_regr"): packed_batch[k] = torch.stack(packed_batch[k], 0) return packed_batch + + +class YOLOXDataset(Dataset): + def __init__(self, coco_json_path, images_dir=None, transforms=None, max_objects_on_image=120): + self.file = coco_json_path + self.images_dir = images_dir + self.transforms = transforms + self.max_objects_on_image = max_objects_on_image + + self.images, self.categories = load_coco_json(coco_json_path) + self.images_list = sorted(self.images.keys()) + + self.class_to_cid = { + cls_idx: cat_id for cls_idx, cat_id in enumerate(sorted(self.categories.keys())) + } + self.cid_to_class = {v: k for k, v in self.class_to_cid.items()} + self.num_classes = len(self.class_to_cid) + self.class_labels = [ + self.categories[self.class_to_cid[cls_idx]] + for cls_idx in range(len(self.class_to_cid)) + ] + + def __len__(self): + return len(self.images_list) + + def __getitem__(self, index): + img_id = self.images_list[index] + img_record = self.images[img_id] + + path = img_record["file_name"] + if self.images_dir is not None: + path = os.path.join(self.images_dir, path) + image = read_image(path) + + boxes = [] # each element is a tuple of (x1, y1, x2, y2, "class") + raw_annotations = img_record["annotations"][: self.max_objects_on_image] + for annotation in raw_annotations: + xyxy = pixels_to_absolute( + annotation["bbox"], img_record["width"], img_record["height"] + ) + xyxy = clip(xyxy, 0.0, 1.0) + bbox_class = str(self.cid_to_class[annotation["category_id"]]) + boxes.append(xyxy + [str(bbox_class)]) + + if self.transforms is not None: + transformed = self.transforms(image=image, bboxes=boxes) + image, boxes = transformed["image"], transformed["bboxes"] + else: + image = torch.from_numpy((image / 255.0).astype(np.float32)).permute(2, 0, 1) + + bboxes = np.zeros((self.max_objects_on_image, 4), dtype=np.float32) + classes = np.zeros(self.max_objects_on_image, dtype=np.int32) + for idx, (x1, y1, x2, y2, box_cls) in enumerate(boxes[: self.max_objects_on_image]): + bboxes[idx, :] = [x1, y1, x2, y2] + classes[idx] = int(box_cls) + + # scaling [0,1] -> [h, w] + bboxes = bboxes * (image.size(1), image.size(2), image.size(1), image.size(2)) + bboxes = torch.from_numpy(bboxes) + bboxes = change_box_order(bboxes, "xyxy2xywh") + classes = torch.LongTensor(classes) + + return { + "image": image, + "boxes": bboxes, + "classes": classes, + } + + @staticmethod + def collate_fn(batch): + """ + Collect batch for YOLO X model. + + Args: + batch (List[Dict[str, torch.Tensor]]): + List with records from YOLOXDataset. 
+ + Returns: + images batch with shape [B, C, H, W] + boxes with shape [B, MAX_OBJECTS, 4] + classes with shape [B, MAX_OBJECTS,] + """ + images, boxes, classes = [], [], [] + for item in batch: + images.append(item["image"]) + boxes.append(item["boxes"]) + classes.append(item["classes"]) + + images = torch.stack(images) + boxes = torch.stack(boxes) + classes = torch.stack(classes) + + return {"image": images, "bboxes": boxes, "labels": classes} diff --git a/examples/detection/models/__init__.py b/examples/detection/models/__init__.py new file mode 100644 index 0000000000..525223853a --- /dev/null +++ b/examples/detection/models/__init__.py @@ -0,0 +1,11 @@ +# flake8: noqa + +from .centernet import CenterNet +from .ssd import SingleShotDetector +from .yolo_x import ( + yolo_x_tiny, + yolo_x_small, + yolo_x_medium, + yolo_x_large, + yolo_x_big, +) diff --git a/examples/detection/models/centernet.py b/examples/detection/models/centernet.py new file mode 100644 index 0000000000..35b41755c6 --- /dev/null +++ b/examples/detection/models/centernet.py @@ -0,0 +1,128 @@ +# flake8: noqa + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.models import mobilenet, resnet + +_backbones = { + "resnet18": (resnet.resnet18, 512), + "resnet34": (resnet.resnet34, 512), + "resnet50": (resnet.resnet50, 2048), + "resnet101": (resnet.resnet101, 2048), + "resnet152": (resnet.resnet152, 2048), + "mobilenet_v2": (mobilenet.mobilenet_v2, 1280) + # "mobilenet_v3_small": (torchvision.models.mobilenet_v3_small, 576), + # "mobilenet_v3_large": (torchvision.models.mobilenet_v3_large, 960), +} + + +class DoubleConv(nn.Module): + """(conv => BN => ReLU) * 2""" + + def __init__(self, in_ch, out_ch): + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_ch, out_ch, 3, padding=1), + nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True), + nn.Conv2d(out_ch, out_ch, 3, padding=1), + nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + x = self.conv(x) + return x + + +class Interpolate(nn.Module): + def __init__( + self, + size=None, + scale_factor=None, + mode="nearest", + align_corners=None, + recompute_scale_factor=None, + ): + super().__init__() + self.size = size + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + self.recompute_scale_factor = recompute_scale_factor + + def forward(self, inputs): + return F.interpolate( + inputs, + size=self.size, + scale_factor=self.scale_factor, + mode=self.mode, + align_corners=self.align_corners, + recompute_scale_factor=self.recompute_scale_factor, + ) + + +class UpDoubleConv(nn.Module): + def __init__(self, in_channels, out_channels, mode=None): + super().__init__() + self.in_ch = in_channels + self.out_ch = out_channels + self.mode = mode + if mode is None: + self.up = nn.ConvTranspose2d(in_channels, in_channels, 2, stride=2) + else: + align_corners = None if mode == "nearest" else True + self.up = Interpolate(scale_factor=2, mode=mode, align_corners=align_corners) + self.conv = DoubleConv(in_channels, out_channels) + + def forward(self, x1, x2=None): + x1 = self.up(x1) + + if x2 is not None: + x = torch.cat([x2, x1], dim=1) + # input is CHW + diffY = x2.size()[2] - x1.size()[2] + diffX = x2.size()[3] - x1.size()[3] + x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2)) + else: + x = x1 + + x = self.conv(x) + return x + + +class CenterNet(nn.Module): + def __init__(self, num_classes=1, backbone="resnet18", upsample_mode="nearest"): + 
super().__init__() + # create backbone. + basemodel = _backbones[backbone][0](pretrained=True) + if backbone == "mobilenet_v2": + layers = list(basemodel.children())[:-1] + else: + layers = list(basemodel.children())[:-2] + basemodel = nn.Sequential(*layers) + # set basemodel + self.base_model = basemodel + self.upsample_mode = upsample_mode + + num_ch = _backbones[backbone][1] + + # original upsample mode was "bilinear" + self.up1 = UpDoubleConv(num_ch, 512, upsample_mode) + self.up2 = UpDoubleConv(512, 256, upsample_mode) + self.up3 = UpDoubleConv(256, 256, upsample_mode) + # output classification + self.out_classification = nn.Conv2d(256, num_classes, 1) + # output residue + self.out_residue = nn.Conv2d(256, 2, 1) + + def forward(self, x): + x = self.base_model(x) + # Add positional info + x = self.up1(x) + x = self.up2(x) + x = self.up3(x) + c = self.out_classification(x) # NOTE: do not forget to apply sigmoid to obtain scores! + r = self.out_residue(x) + return c, r diff --git a/examples/detection/model.py b/examples/detection/models/ssd.py similarity index 59% rename from examples/detection/model.py rename to examples/detection/models/ssd.py index 2ed9259390..060032d1d7 100644 --- a/examples/detection/model.py +++ b/examples/detection/models/ssd.py @@ -2,10 +2,7 @@ import torch import torch.nn as nn -import torch.nn.functional as F -from torchvision.models import mobilenet, resnet - -__all__ = ("SingleShotDetector", "CenterNet") +from torchvision.models import resnet _channels_map = { "resnet18": [256, 512, 512, 256, 256, 128], @@ -15,17 +12,6 @@ "resnet152": [1024, 512, 512, 256, 256, 256], } -_backbones = { - "resnet18": (resnet.resnet18, 512), - "resnet34": (resnet.resnet34, 512), - "resnet50": (resnet.resnet50, 2048), - "resnet101": (resnet.resnet101, 2048), - "resnet152": (resnet.resnet152, 2048), - "mobilenet_v2": (mobilenet.mobilenet_v2, 1280) - # "mobilenet_v3_small": (torchvision.models.mobilenet_v3_small, 576), - # "mobilenet_v3_large": (torchvision.models.mobilenet_v3_large, 960), -} - class ResnetBackbone(nn.Module): def __init__(self, backbone="resnet50", backbone_path=None): @@ -169,114 +155,3 @@ class confidence logits (torch.Tensor) with shapes [B, A, N_CLASSES], locs, confs = self.bbox_view(detection_feed, self.loc, self.conf) return locs, confs - - -class DoubleConv(nn.Module): - """(conv => BN => ReLU) * 2""" - - def __init__(self, in_ch, out_ch): - super().__init__() - self.conv = nn.Sequential( - nn.Conv2d(in_ch, out_ch, 3, padding=1), - nn.BatchNorm2d(out_ch), - nn.ReLU(inplace=True), - nn.Conv2d(out_ch, out_ch, 3, padding=1), - nn.BatchNorm2d(out_ch), - nn.ReLU(inplace=True), - ) - - def forward(self, x): - x = self.conv(x) - return x - - -class Interpolate(nn.Module): - def __init__( - self, - size=None, - scale_factor=None, - mode="nearest", - align_corners=None, - recompute_scale_factor=None, - ): - super().__init__() - self.size = size - self.scale_factor = scale_factor - self.mode = mode - self.align_corners = align_corners - self.recompute_scale_factor = recompute_scale_factor - - def forward(self, inputs): - return F.interpolate( - inputs, - size=self.size, - scale_factor=self.scale_factor, - mode=self.mode, - align_corners=self.align_corners, - recompute_scale_factor=self.recompute_scale_factor, - ) - - -class UpDoubleConv(nn.Module): - def __init__(self, in_channels, out_channels, mode=None): - super().__init__() - self.in_ch = in_channels - self.out_ch = out_channels - self.mode = mode - if mode is None: - self.up = nn.ConvTranspose2d(in_channels, 
in_channels, 2, stride=2) - else: - align_corners = None if mode == "nearest" else True - self.up = Interpolate(scale_factor=2, mode=mode, align_corners=align_corners) - self.conv = DoubleConv(in_channels, out_channels) - - def forward(self, x1, x2=None): - x1 = self.up(x1) - - if x2 is not None: - x = torch.cat([x2, x1], dim=1) - # input is CHW - diffY = x2.size()[2] - x1.size()[2] - diffX = x2.size()[3] - x1.size()[3] - x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2)) - else: - x = x1 - - x = self.conv(x) - return x - - -class CenterNet(nn.Module): - def __init__(self, num_classes=1, backbone="resnet18", upsample_mode="nearest"): - super().__init__() - # create backbone. - basemodel = _backbones[backbone][0](pretrained=True) - if backbone == "mobilenet_v2": - layers = list(basemodel.children())[:-1] - else: - layers = list(basemodel.children())[:-2] - basemodel = nn.Sequential(*layers) - # set basemodel - self.base_model = basemodel - self.upsample_mode = upsample_mode - - num_ch = _backbones[backbone][1] - - # original upsample mode was "bilinear" - self.up1 = UpDoubleConv(num_ch, 512, upsample_mode) - self.up2 = UpDoubleConv(512, 256, upsample_mode) - self.up3 = UpDoubleConv(256, 256, upsample_mode) - # output classification - self.out_classification = nn.Conv2d(256, num_classes, 1) - # output residue - self.out_residue = nn.Conv2d(256, 2, 1) - - def forward(self, x): - x = self.base_model(x) - # Add positional info - x = self.up1(x) - x = self.up2(x) - x = self.up3(x) - c = self.out_classification(x) # NOTE: do not forget to apply sigmoid to obtain scores! - r = self.out_residue(x) - return c, r diff --git a/examples/detection/models/yolo_x.py b/examples/detection/models/yolo_x.py new file mode 100644 index 0000000000..eacc8a7f21 --- /dev/null +++ b/examples/detection/models/yolo_x.py @@ -0,0 +1,1268 @@ +# flake8: noqa + +import logging +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +logger = logging.getLogger("yolo-x") + + +def get_activation(name="silu", inplace=True): + if name == "silu": + module = nn.SiLU(inplace=inplace) + elif name == "relu": + module = nn.ReLU(inplace=inplace) + elif name == "lrelu": + module = nn.LeakyReLU(0.1, inplace=inplace) + else: + raise AttributeError("Unsupported act type: {}".format(name)) + return module + + +class BaseConv(nn.Module): + """A Conv2d -> Batchnorm -> silu/leaky relu block""" + + def __init__(self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"): + super().__init__() + # same padding + pad = (ksize - 1) // 2 + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=ksize, + stride=stride, + padding=pad, + groups=groups, + bias=bias, + ) + self.bn = nn.BatchNorm2d(out_channels) + self.act = get_activation(act, inplace=True) + + def forward(self, x): + return self.act(self.bn(self.conv(x))) + + def fuseforward(self, x): + return self.act(self.conv(x)) + + +class DWConv(nn.Module): + """Depthwise Conv + Conv""" + + def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"): + super().__init__() + self.dconv = BaseConv( + in_channels, + in_channels, + ksize=ksize, + stride=stride, + groups=in_channels, + act=act, + ) + self.pconv = BaseConv(in_channels, out_channels, ksize=1, stride=1, groups=1, act=act) + + def forward(self, x): + x = self.dconv(x) + return self.pconv(x) + + +class Bottleneck(nn.Module): + # Standard bottleneck + def __init__( + self, + in_channels, + out_channels, + shortcut=True, + 
expansion=0.5, + depthwise=False, + act="silu", + ): + super().__init__() + hidden_channels = int(out_channels * expansion) + Conv = DWConv if depthwise else BaseConv + self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act) + self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act) + self.use_add = shortcut and in_channels == out_channels + + def forward(self, x): + y = self.conv2(self.conv1(x)) + if self.use_add: + y = y + x + return y + + +class ResLayer(nn.Module): + "Residual layer with `in_channels` inputs." + + def __init__(self, in_channels: int): + super().__init__() + mid_channels = in_channels // 2 + self.layer1 = BaseConv(in_channels, mid_channels, ksize=1, stride=1, act="lrelu") + self.layer2 = BaseConv(mid_channels, in_channels, ksize=3, stride=1, act="lrelu") + + def forward(self, x): + out = self.layer2(self.layer1(x)) + return x + out + + +class SPPBottleneck(nn.Module): + """Spatial pyramid pooling layer used in YOLOv3-SPP""" + + def __init__(self, in_channels, out_channels, kernel_sizes=(5, 9, 13), activation="silu"): + super().__init__() + hidden_channels = in_channels // 2 + self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=activation) + self.m = nn.ModuleList( + [nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) for ks in kernel_sizes] + ) + conv2_channels = hidden_channels * (len(kernel_sizes) + 1) + self.conv2 = BaseConv(conv2_channels, out_channels, 1, stride=1, act=activation) + + def forward(self, x): + x = self.conv1(x) + x = torch.cat([x] + [m(x) for m in self.m], dim=1) + x = self.conv2(x) + return x + + +class CSPLayer(nn.Module): + """C3 in yolov5, CSP Bottleneck with 3 convolutions""" + + def __init__( + self, + in_channels, + out_channels, + n=1, + shortcut=True, + expansion=0.5, + depthwise=False, + act="silu", + ): + """ + Args: + in_channels (int): input channels. + out_channels (int): output channels. + n (int): number of Bottlenecks. Default value: 1. + """ + # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + hidden_channels = int(out_channels * expansion) # hidden channels + self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act) + self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act) + self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1, act=act) + module_list = [ + Bottleneck(hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act) + for _ in range(n) + ] + self.m = nn.Sequential(*module_list) + + def forward(self, x): + x_1 = self.conv1(x) + x_2 = self.conv2(x) + x_1 = self.m(x_1) + x = torch.cat((x_1, x_2), dim=1) + return self.conv3(x) + + +class Focus(nn.Module): + """Focus width and height information into channel space.""" + + def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"): + super().__init__() + self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act) + + def forward(self, x): + # shape of x (b,c,w,h) -> y(b,4c,w/2,h/2) + patch_top_left = x[..., ::2, ::2] + patch_top_right = x[..., ::2, 1::2] + patch_bot_left = x[..., 1::2, ::2] + patch_bot_right = x[..., 1::2, 1::2] + x = torch.cat( + ( + patch_top_left, + patch_bot_left, + patch_top_right, + patch_bot_right, + ), + dim=1, + ) + return self.conv(x) + + +class Darknet(nn.Module): + # number of blocks from dark2 to dark5. 
+ depth2blocks = {21: [1, 2, 2, 1], 53: [2, 8, 8, 4]} + + def __init__( + self, + depth, + in_channels=3, + stem_out_channels=32, + out_features=("dark3", "dark4", "dark5"), + ): + """ + Args: + depth (int): depth of darknet used in model, usually use [21, 53] for this param. + in_channels (int): number of input channels, for example, use 3 for RGB image. + stem_out_channels (int): number of output chanels of darknet stem. + It decides channels of darknet layer2 to layer5. + out_features (Tuple[str]): desired output layer name. + """ + super().__init__() + assert out_features, "please provide output features of Darknet" + self.out_features = out_features + self.stem = nn.Sequential( + BaseConv(in_channels, stem_out_channels, ksize=3, stride=1, act="lrelu"), + *self.make_group_layer(stem_out_channels, num_blocks=1, stride=2), + ) + in_channels = stem_out_channels * 2 # 64 + + num_blocks = Darknet.depth2blocks[depth] + # create darknet with `stem_out_channels` and `num_blocks` layers. + # to make model structure more clear, we don't use `for` statement in python. + self.dark2 = nn.Sequential(*self.make_group_layer(in_channels, num_blocks[0], stride=2)) + in_channels *= 2 # 128 + self.dark3 = nn.Sequential(*self.make_group_layer(in_channels, num_blocks[1], stride=2)) + in_channels *= 2 # 256 + self.dark4 = nn.Sequential(*self.make_group_layer(in_channels, num_blocks[2], stride=2)) + in_channels *= 2 # 512 + + self.dark5 = nn.Sequential( + *self.make_group_layer(in_channels, num_blocks[3], stride=2), + *self.make_spp_block([in_channels, in_channels * 2], in_channels * 2), + ) + + def make_group_layer(self, in_channels: int, num_blocks: int, stride: int = 1): + "starts with conv layer then has `num_blocks` `ResLayer`" + return [ + BaseConv(in_channels, in_channels * 2, ksize=3, stride=stride, act="lrelu"), + *[(ResLayer(in_channels * 2)) for _ in range(num_blocks)], + ] + + def make_spp_block(self, filters_list, in_filters): + m = nn.Sequential( + *[ + BaseConv(in_filters, filters_list[0], 1, stride=1, act="lrelu"), + BaseConv(filters_list[0], filters_list[1], 3, stride=1, act="lrelu"), + SPPBottleneck( + in_channels=filters_list[1], + out_channels=filters_list[0], + activation="lrelu", + ), + BaseConv(filters_list[0], filters_list[1], 3, stride=1, act="lrelu"), + BaseConv(filters_list[1], filters_list[0], 1, stride=1, act="lrelu"), + ] + ) + return m + + def forward(self, x): + outputs = {} + x = self.stem(x) + outputs["stem"] = x + x = self.dark2(x) + outputs["dark2"] = x + x = self.dark3(x) + outputs["dark3"] = x + x = self.dark4(x) + outputs["dark4"] = x + x = self.dark5(x) + outputs["dark5"] = x + return {k: v for k, v in outputs.items() if k in self.out_features} + + +class CSPDarknet(nn.Module): + def __init__( + self, + dep_mul, + wid_mul, + out_features=("dark3", "dark4", "dark5"), + depthwise=False, + act="silu", + ): + super().__init__() + assert out_features, "please provide output features of Darknet" + self.out_features = out_features + Conv = DWConv if depthwise else BaseConv + + base_channels = int(wid_mul * 64) # 64 + base_depth = max(round(dep_mul * 3), 1) # 3 + + # stem + self.stem = Focus(3, base_channels, ksize=3, act=act) + + # dark2 + self.dark2 = nn.Sequential( + Conv(base_channels, base_channels * 2, 3, 2, act=act), + CSPLayer( + base_channels * 2, + base_channels * 2, + n=base_depth, + depthwise=depthwise, + act=act, + ), + ) + + # dark3 + self.dark3 = nn.Sequential( + Conv(base_channels * 2, base_channels * 4, 3, 2, act=act), + CSPLayer( + base_channels * 4, + 
base_channels * 4, + n=base_depth * 3, + depthwise=depthwise, + act=act, + ), + ) + + # dark4 + self.dark4 = nn.Sequential( + Conv(base_channels * 4, base_channels * 8, 3, 2, act=act), + CSPLayer( + base_channels * 8, + base_channels * 8, + n=base_depth * 3, + depthwise=depthwise, + act=act, + ), + ) + + # dark5 + self.dark5 = nn.Sequential( + Conv(base_channels * 8, base_channels * 16, 3, 2, act=act), + SPPBottleneck(base_channels * 16, base_channels * 16, activation=act), + CSPLayer( + base_channels * 16, + base_channels * 16, + n=base_depth, + shortcut=False, + depthwise=depthwise, + act=act, + ), + ) + + def forward(self, x): + outputs = {} + x = self.stem(x) + outputs["stem"] = x + x = self.dark2(x) + outputs["dark2"] = x + x = self.dark3(x) + outputs["dark3"] = x + x = self.dark4(x) + outputs["dark4"] = x + x = self.dark5(x) + outputs["dark5"] = x + return {k: v for k, v in outputs.items() if k in self.out_features} + + +class YOLOPAFPN(nn.Module): + """ + YOLOv3 model. Darknet 53 is the default backbone of this model. + """ + + def __init__( + self, + depth=1.0, + width=1.0, + in_features=("dark3", "dark4", "dark5"), + in_channels=[256, 512, 1024], + depthwise=False, + act="silu", + ): + super().__init__() + self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act) + self.in_features = in_features + self.in_channels = in_channels + Conv = DWConv if depthwise else BaseConv + + self.upsample = nn.Upsample(scale_factor=2, mode="nearest") + self.lateral_conv0 = BaseConv( + int(in_channels[2] * width), int(in_channels[1] * width), 1, 1, act=act + ) + self.C3_p4 = CSPLayer( + int(2 * in_channels[1] * width), + int(in_channels[1] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act, + ) # cat + + self.reduce_conv1 = BaseConv( + int(in_channels[1] * width), int(in_channels[0] * width), 1, 1, act=act + ) + self.C3_p3 = CSPLayer( + int(2 * in_channels[0] * width), + int(in_channels[0] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act, + ) + + # bottom-up conv + self.bu_conv2 = Conv( + int(in_channels[0] * width), int(in_channels[0] * width), 3, 2, act=act + ) + self.C3_n3 = CSPLayer( + int(2 * in_channels[0] * width), + int(in_channels[1] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act, + ) + + # bottom-up conv + self.bu_conv1 = Conv( + int(in_channels[1] * width), int(in_channels[1] * width), 3, 2, act=act + ) + self.C3_n4 = CSPLayer( + int(2 * in_channels[1] * width), + int(in_channels[2] * width), + round(3 * depth), + False, + depthwise=depthwise, + act=act, + ) + + def forward(self, input): + """ + Args: + inputs: input images. + + Returns: + Tuple[Tensor]: FPN feature. 
+ """ + + # backbone + out_features = self.backbone(input) + features = [out_features[f] for f in self.in_features] + [x2, x1, x0] = features + + fpn_out0 = self.lateral_conv0(x0) # 1024->512/32 + f_out0 = self.upsample(fpn_out0) # 512/16 + f_out0 = torch.cat([f_out0, x1], 1) # 512->1024/16 + f_out0 = self.C3_p4(f_out0) # 1024->512/16 + + fpn_out1 = self.reduce_conv1(f_out0) # 512->256/16 + f_out1 = self.upsample(fpn_out1) # 256/8 + f_out1 = torch.cat([f_out1, x2], 1) # 256->512/8 + pan_out2 = self.C3_p3(f_out1) # 512->256/8 + + p_out1 = self.bu_conv2(pan_out2) # 256->256/16 + p_out1 = torch.cat([p_out1, fpn_out1], 1) # 256->512/16 + pan_out1 = self.C3_n3(p_out1) # 512->512/16 + + p_out0 = self.bu_conv1(pan_out1) # 512->512/32 + p_out0 = torch.cat([p_out0, fpn_out0], 1) # 512->1024/32 + pan_out0 = self.C3_n4(p_out0) # 1024->1024/32 + + outputs = (pan_out2, pan_out1, pan_out0) + return outputs + + +def bboxes_iou(bboxes_a, bboxes_b, xyxy=True): + if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4: + raise IndexError() + + if xyxy: + tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2]) + br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:]) + area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1) + area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1) + else: + tl = torch.max( + (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2), + (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2), + ) + br = torch.min( + (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2), + (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2), + ) + + area_a = torch.prod(bboxes_a[:, 2:], 1) + area_b = torch.prod(bboxes_b[:, 2:], 1) + en = (tl < br).type(tl.type()).prod(dim=2) + area_i = torch.prod(br - tl, 2) * en # * ((tl < br).all()) + return area_i / (area_a[:, None] + area_b - area_i) + + +class IOUloss(nn.Module): + def __init__(self, reduction="none", loss_type="iou"): + super(IOUloss, self).__init__() + self.reduction = reduction + self.loss_type = loss_type + + def forward(self, pred, target): + assert pred.shape[0] == target.shape[0] + + pred = pred.view(-1, 4) + target = target.view(-1, 4) + tl = torch.max((pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)) + br = torch.min((pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)) + + area_p = torch.prod(pred[:, 2:], 1) + area_g = torch.prod(target[:, 2:], 1) + + en = (tl < br).type(tl.type()).prod(dim=1) + area_i = torch.prod(br - tl, 1) * en + iou = (area_i) / (area_p + area_g - area_i + 1e-16) + + if self.loss_type == "iou": + loss = 1 - iou ** 2 + elif self.loss_type == "giou": + c_tl = torch.min((pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)) + c_br = torch.max((pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)) + area_c = torch.prod(c_br - c_tl, 1) + giou = iou - (area_c - area_i) / area_c.clamp(1e-16) + loss = 1 - giou.clamp(min=-1.0, max=1.0) + + if self.reduction == "mean": + loss = loss.mean() + elif self.reduction == "sum": + loss = loss.sum() + + return loss + + +class YOLOXHead(nn.Module): + def __init__( + self, + num_classes, + width=1.0, + strides=[8, 16, 32], + in_channels=[256, 512, 1024], + act="silu", + depthwise=False, + ): + """ + Args: + act (str): activation type of conv. + Default is `"silu"`. + depthwise (bool): wheather apply depthwise conv in conv branch. + Default is `False`. 
+ """ + super().__init__() + + self.n_anchors = 1 + self.num_classes = num_classes + self.decode_in_inference = True # for deploy, set to False + + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + self.cls_preds = nn.ModuleList() + self.reg_preds = nn.ModuleList() + self.obj_preds = nn.ModuleList() + self.stems = nn.ModuleList() + _conv_class = DWConv if depthwise else BaseConv + + for i in range(len(in_channels)): + self.stems.append( + BaseConv( + in_channels=int(in_channels[i] * width), + out_channels=int(256 * width), + ksize=1, + stride=1, + act=act, + ) + ) + self.cls_convs.append( + nn.Sequential( + *[ + _conv_class( + in_channels=int(256 * width), + out_channels=int(256 * width), + ksize=3, + stride=1, + act=act, + ), + _conv_class( + in_channels=int(256 * width), + out_channels=int(256 * width), + ksize=3, + stride=1, + act=act, + ), + ] + ) + ) + self.reg_convs.append( + nn.Sequential( + *[ + _conv_class( + in_channels=int(256 * width), + out_channels=int(256 * width), + ksize=3, + stride=1, + act=act, + ), + _conv_class( + in_channels=int(256 * width), + out_channels=int(256 * width), + ksize=3, + stride=1, + act=act, + ), + ] + ) + ) + self.cls_preds.append( + nn.Conv2d( + in_channels=int(256 * width), + out_channels=self.n_anchors * self.num_classes, + kernel_size=1, + stride=1, + padding=0, + ) + ) + self.reg_preds.append( + nn.Conv2d( + in_channels=int(256 * width), + out_channels=4, + kernel_size=1, + stride=1, + padding=0, + ) + ) + self.obj_preds.append( + nn.Conv2d( + in_channels=int(256 * width), + out_channels=self.n_anchors * 1, + kernel_size=1, + stride=1, + padding=0, + ) + ) + + self.use_l1 = False + self.l1_loss = nn.L1Loss(reduction="none") + self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none") + self.iou_loss = IOUloss(reduction="none") + self.strides = strides + # self.grids = nn.ParameterList( + # [nn.Parameter(torch.zeros(1), requires_grad=False) for _ in range(len(in_channels))] + # ) + self.grids = [torch.zeros(1)] * len(in_channels) + # self.expanded_strides = [None] * len(in_channels) + + def initialize_biases(self, prior_prob): + for conv in self.cls_preds: + b = conv.bias.view(self.n_anchors, -1) + b.data.fill_(-math.log((1 - prior_prob) / prior_prob)) + conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True) + + for conv in self.obj_preds: + b = conv.bias.view(self.n_anchors, -1) + b.data.fill_(-math.log((1 - prior_prob) / prior_prob)) + conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True) + + def forward(self, xin, labels=None, imgs=None): + outputs = [] + origin_preds = [] + x_shifts = [] + y_shifts = [] + expanded_strides = [] + + for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate( + zip(self.cls_convs, self.reg_convs, self.strides, xin) + ): + x = self.stems[k](x) + cls_x = x + reg_x = x + + cls_feat = cls_conv(cls_x) + cls_output = self.cls_preds[k](cls_feat) + + reg_feat = reg_conv(reg_x) + reg_output = self.reg_preds[k](reg_feat) + obj_output = self.obj_preds[k](reg_feat) + + if self.training: + output = torch.cat([reg_output, obj_output, cls_output], 1) + output, grid = self.get_output_and_grid( + output, k, stride_this_level, xin[0].type() + ) + x_shifts.append(grid[:, :, 0]) + y_shifts.append(grid[:, :, 1]) + expanded_strides.append( + torch.zeros(1, grid.shape[1]).fill_(stride_this_level).type_as(xin[0]) + ) + if self.use_l1: + batch_size = reg_output.shape[0] + hsize, wsize = reg_output.shape[-2:] + reg_output = reg_output.view(batch_size, self.n_anchors, 4, hsize, wsize) + 
reg_output = reg_output.permute(0, 1, 3, 4, 2).reshape(batch_size, -1, 4) + origin_preds.append(reg_output.clone()) + + else: + output = torch.cat([reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1) + + outputs.append(output) + + if self.training: + return self.get_losses( + imgs, + x_shifts, + y_shifts, + expanded_strides, + labels, + torch.cat(outputs, 1), + origin_preds, + dtype=xin[0].dtype, + ) + else: + self.hw = [x.shape[-2:] for x in outputs] + # [batch, n_anchors_all, 85] + outputs = torch.cat([x.flatten(start_dim=2) for x in outputs], dim=2).permute(0, 2, 1) + if self.decode_in_inference: + return self.decode_outputs(outputs, dtype=xin[0].type()) + else: + return outputs + + def get_output_and_grid(self, output, k, stride, dtype): + grid = self.grids[k] + + batch_size = output.shape[0] + n_ch = 5 + self.num_classes + hsize, wsize = output.shape[-2:] + if grid.shape[2:4] != output.shape[2:4]: + yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)]) + grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype) + self.grids[k] = grid + + output = output.view(batch_size, self.n_anchors, n_ch, hsize, wsize) + output = output.permute(0, 1, 3, 4, 2).reshape( + batch_size, self.n_anchors * hsize * wsize, -1 + ) + grid = grid.view(1, -1, 2) + output[..., :2] = (output[..., :2] + grid) * stride + output[..., 2:4] = torch.exp(output[..., 2:4]) * stride + return output, grid + + def decode_outputs(self, outputs, dtype): + grids = [] + strides = [] + for (hsize, wsize), stride in zip(self.hw, self.strides): + yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)]) + grid = torch.stack((xv, yv), 2).view(1, -1, 2) + grids.append(grid) + shape = grid.shape[:2] + strides.append(torch.full((*shape, 1), stride)) + + grids = torch.cat(grids, dim=1).type(dtype) + strides = torch.cat(strides, dim=1).type(dtype) + + outputs[..., :2] = (outputs[..., :2] + grids) * strides + outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides + return outputs + + def get_losses( + self, + imgs, + x_shifts, + y_shifts, + expanded_strides, + labels, + outputs, + origin_preds, + dtype, + ): + bbox_preds = outputs[:, :, :4] # [batch, n_anchors_all, 4] + obj_preds = outputs[:, :, 4].unsqueeze(-1) # [batch, n_anchors_all, 1] + cls_preds = outputs[:, :, 5:] # [batch, n_anchors_all, n_cls] + + # calculate targets + mixup = labels.shape[2] > 5 + if mixup: + label_cut = labels[..., :5] + else: + label_cut = labels + nlabel = (label_cut.sum(dim=2) > 0).sum(dim=1) # number of objects + + total_num_anchors = outputs.shape[1] + x_shifts = torch.cat(x_shifts, 1) # [1, n_anchors_all] + y_shifts = torch.cat(y_shifts, 1) # [1, n_anchors_all] + expanded_strides = torch.cat(expanded_strides, 1) + if self.use_l1: + origin_preds = torch.cat(origin_preds, 1) + + cls_targets = [] + reg_targets = [] + l1_targets = [] + obj_targets = [] + fg_masks = [] + + num_fg = 0.0 + num_gts = 0.0 + + for batch_idx in range(outputs.shape[0]): + num_gt = int(nlabel[batch_idx]) + num_gts += num_gt + if num_gt == 0: + cls_target = outputs.new_zeros((0, self.num_classes)) + reg_target = outputs.new_zeros((0, 4)) + l1_target = outputs.new_zeros((0, 4)) + obj_target = outputs.new_zeros((total_num_anchors, 1)) + fg_mask = outputs.new_zeros(total_num_anchors).bool() + else: + gt_bboxes_per_image = labels[batch_idx, :num_gt, 1:5] + gt_classes = labels[batch_idx, :num_gt, 0] + bboxes_preds_per_image = bbox_preds[batch_idx] + + try: + ( + gt_matched_classes, + fg_mask, + pred_ious_this_matching, + matched_gt_inds, + 
num_fg_img, + ) = self.get_assignments( # noqa + batch_idx, + num_gt, + total_num_anchors, + gt_bboxes_per_image, + gt_classes, + bboxes_preds_per_image, + expanded_strides, + x_shifts, + y_shifts, + cls_preds, + bbox_preds, + obj_preds, + labels, + imgs, + ) + except RuntimeError: + logger.error( + "OOM RuntimeError is raised due to the huge " + "memory cost during label assignment. " + "CPU mode is applied in this batch. " + "If you want to avoid this issue, " + "try to reduce the batch size or image size." + ) + torch.cuda.empty_cache() + ( + gt_matched_classes, + fg_mask, + pred_ious_this_matching, + matched_gt_inds, + num_fg_img, + ) = self.get_assignments( # noqa + batch_idx, + num_gt, + total_num_anchors, + gt_bboxes_per_image, + gt_classes, + bboxes_preds_per_image, + expanded_strides, + x_shifts, + y_shifts, + cls_preds, + bbox_preds, + obj_preds, + labels, + imgs, + "cpu", + ) + + torch.cuda.empty_cache() + num_fg += num_fg_img + + cls_target = F.one_hot( + gt_matched_classes.to(torch.int64), self.num_classes + ) * pred_ious_this_matching.unsqueeze(-1) + obj_target = fg_mask.unsqueeze(-1) + reg_target = gt_bboxes_per_image[matched_gt_inds] + if self.use_l1: + l1_target = self.get_l1_target( + outputs.new_zeros((num_fg_img, 4)), + gt_bboxes_per_image[matched_gt_inds], + expanded_strides[0][fg_mask], + x_shifts=x_shifts[0][fg_mask], + y_shifts=y_shifts[0][fg_mask], + ) + + cls_targets.append(cls_target) + reg_targets.append(reg_target) + obj_targets.append(obj_target.to(dtype)) + fg_masks.append(fg_mask) + if self.use_l1: + l1_targets.append(l1_target) + + cls_targets = torch.cat(cls_targets, 0) + reg_targets = torch.cat(reg_targets, 0) + obj_targets = torch.cat(obj_targets, 0) + fg_masks = torch.cat(fg_masks, 0) + if self.use_l1: + l1_targets = torch.cat(l1_targets, 0) + + num_fg = max(num_fg, 1) + loss_iou = (self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)).sum() / num_fg + loss_obj = (self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)).sum() / num_fg + loss_cls = ( + self.bcewithlog_loss(cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets) + ).sum() / num_fg + if self.use_l1: + loss_l1 = (self.l1_loss(origin_preds.view(-1, 4)[fg_masks], l1_targets)).sum() / num_fg + else: + loss_l1 = 0.0 + + reg_weight = 5.0 + loss = reg_weight * loss_iou + loss_obj + loss_cls + loss_l1 + + return ( + loss, + reg_weight * loss_iou, + loss_obj, + loss_cls, + loss_l1, + num_fg / max(num_gts, 1), + ) + + def get_l1_target(self, l1_target, gt, stride, x_shifts, y_shifts, eps=1e-8): + l1_target[:, 0] = gt[:, 0] / stride - x_shifts + l1_target[:, 1] = gt[:, 1] / stride - y_shifts + l1_target[:, 2] = torch.log(gt[:, 2] / stride + eps) + l1_target[:, 3] = torch.log(gt[:, 3] / stride + eps) + return l1_target + + @torch.no_grad() + def get_assignments( + self, + batch_idx, + num_gt, + total_num_anchors, + gt_bboxes_per_image, + gt_classes, + bboxes_preds_per_image, + expanded_strides, + x_shifts, + y_shifts, + cls_preds, + bbox_preds, + obj_preds, + labels, + imgs, + mode="gpu", + ): + + if mode == "cpu": + print("------------CPU Mode for This Batch-------------") + gt_bboxes_per_image = gt_bboxes_per_image.cpu().float() + bboxes_preds_per_image = bboxes_preds_per_image.cpu().float() + gt_classes = gt_classes.cpu().float() + expanded_strides = expanded_strides.cpu().float() + x_shifts = x_shifts.cpu() + y_shifts = y_shifts.cpu() + + fg_mask, is_in_boxes_and_center = self.get_in_boxes_info( + gt_bboxes_per_image, + expanded_strides, + x_shifts, + y_shifts, + total_num_anchors, + 
num_gt, + ) + + bboxes_preds_per_image = bboxes_preds_per_image[fg_mask] + cls_preds_ = cls_preds[batch_idx][fg_mask] + obj_preds_ = obj_preds[batch_idx][fg_mask] + num_in_boxes_anchor = bboxes_preds_per_image.shape[0] + + if mode == "cpu": + gt_bboxes_per_image = gt_bboxes_per_image.cpu() + bboxes_preds_per_image = bboxes_preds_per_image.cpu() + + pair_wise_ious = bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False) + + gt_cls_per_image = ( + F.one_hot(gt_classes.to(torch.int64), self.num_classes) + .float() + .unsqueeze(1) + .repeat(1, num_in_boxes_anchor, 1) + ) + pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8) + + if mode == "cpu": + cls_preds_, obj_preds_ = cls_preds_.cpu(), obj_preds_.cpu() + + cls_preds_ = ( + cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() + * obj_preds_.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() # noqa: W503 + ) + pair_wise_cls_loss = F.binary_cross_entropy( + cls_preds_.sqrt_(), gt_cls_per_image, reduction="none" + ).sum(-1) + del cls_preds_ + + cost = ( + pair_wise_cls_loss + 3.0 * pair_wise_ious_loss + 100000.0 * (~is_in_boxes_and_center) + ) + + ( + num_fg, + gt_matched_classes, + pred_ious_this_matching, + matched_gt_inds, + ) = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask) + del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss + + if mode == "cpu": + gt_matched_classes = gt_matched_classes.cuda() + fg_mask = fg_mask.cuda() + pred_ious_this_matching = pred_ious_this_matching.cuda() + matched_gt_inds = matched_gt_inds.cuda() + + return ( + gt_matched_classes, + fg_mask, + pred_ious_this_matching, + matched_gt_inds, + num_fg, + ) + + def get_in_boxes_info( + self, + gt_bboxes_per_image, + expanded_strides, + x_shifts, + y_shifts, + total_num_anchors, + num_gt, + ): + expanded_strides_per_image = expanded_strides[0] + x_shifts_per_image = x_shifts[0] * expanded_strides_per_image + y_shifts_per_image = y_shifts[0] * expanded_strides_per_image + # fmt: off + x_centers_per_image = ( + (x_shifts_per_image + 0.5 * expanded_strides_per_image) + .unsqueeze(0) + .repeat(num_gt, 1) + ) # [n_anchor] -> [n_gt, n_anchor] + y_centers_per_image = ( + (y_shifts_per_image + 0.5 * expanded_strides_per_image) + .unsqueeze(0) + .repeat(num_gt, 1) + ) + # fmt: on + + gt_bboxes_per_image_l = ( + (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2]) + .unsqueeze(1) + .repeat(1, total_num_anchors) + ) + gt_bboxes_per_image_r = ( + (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2]) + .unsqueeze(1) + .repeat(1, total_num_anchors) + ) + gt_bboxes_per_image_t = ( + (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3]) + .unsqueeze(1) + .repeat(1, total_num_anchors) + ) + gt_bboxes_per_image_b = ( + (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3]) + .unsqueeze(1) + .repeat(1, total_num_anchors) + ) + + b_l = x_centers_per_image - gt_bboxes_per_image_l + b_r = gt_bboxes_per_image_r - x_centers_per_image + b_t = y_centers_per_image - gt_bboxes_per_image_t + b_b = gt_bboxes_per_image_b - y_centers_per_image + bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2) + + is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0 + is_in_boxes_all = is_in_boxes.sum(dim=0) > 0 + # in fixed center + + center_radius = 2.5 + + gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat( + 1, total_num_anchors + ) - center_radius * expanded_strides_per_image.unsqueeze(0) + gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat( + 1, total_num_anchors + ) + center_radius 
* expanded_strides_per_image.unsqueeze(0) + gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat( + 1, total_num_anchors + ) - center_radius * expanded_strides_per_image.unsqueeze(0) + gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat( + 1, total_num_anchors + ) + center_radius * expanded_strides_per_image.unsqueeze(0) + + c_l = x_centers_per_image - gt_bboxes_per_image_l + c_r = gt_bboxes_per_image_r - x_centers_per_image + c_t = y_centers_per_image - gt_bboxes_per_image_t + c_b = gt_bboxes_per_image_b - y_centers_per_image + center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2) + is_in_centers = center_deltas.min(dim=-1).values > 0.0 + is_in_centers_all = is_in_centers.sum(dim=0) > 0 + + # in boxes and in centers + is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all + + is_in_boxes_and_center = ( + is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor] + ) + return is_in_boxes_anchor, is_in_boxes_and_center + + def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask): + # Dynamic K + # --------------------------------------------------------------- + matching_matrix = torch.zeros_like(cost) + + ious_in_boxes_matrix = pair_wise_ious + n_candidate_k = 10 + topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1) + dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1) + for gt_idx in range(num_gt): + _, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False) + matching_matrix[gt_idx][pos_idx] = 1.0 + + del topk_ious, dynamic_ks, pos_idx + + anchor_matching_gt = matching_matrix.sum(0) + if (anchor_matching_gt > 1).sum() > 0: + cost_min, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0) + matching_matrix[:, anchor_matching_gt > 1] *= 0.0 + matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0 + fg_mask_inboxes = matching_matrix.sum(0) > 0.0 + num_fg = fg_mask_inboxes.sum().item() + + fg_mask[fg_mask.clone()] = fg_mask_inboxes + + matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0) + gt_matched_classes = gt_classes[matched_gt_inds] + + pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[fg_mask_inboxes] + return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds + + +class YOLOX(nn.Module): + """ + YOLOX model module. + The network returns loss values from three YOLO layers during training and detection results during test. 
+ NOTE: + - model predicts bounding boxes in image size ranges + - bounding boxes format is [x_center, y_center, width, height] + - output format is [x_center, y_center, width, height, confidence, class_0_prob, ..., class_N_prob] + """ + + def __init__(self, backbone=None, head=None): + super().__init__() + if backbone is None: + backbone = YOLOPAFPN() + if head is None: + head = YOLOXHead(80) + + self.backbone = backbone + self.head = head + + def forward(self, x, targets=None): + # fpn output content features of [dark3, dark4, dark5] + fpn_outs = self.backbone(x) + + if self.training: + assert targets is not None + loss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg = self.head(fpn_outs, targets, x) + return loss + else: + outputs = self.head(fpn_outs) + return outputs + + +def _init_fn(M): + for m in M.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eps = 1e-3 + m.momentum = 0.03 + + +def _yolo_x(num_classes=80, depth=1, width=1): + in_channels = [256, 512, 1024] + backbone = YOLOPAFPN(depth, width, in_channels=in_channels) + head = YOLOXHead(num_classes, width, in_channels=in_channels) + model = YOLOX(backbone, head) + model.apply(_init_fn) + model.head.initialize_biases(1e-2) + return model + + +def yolo_x_tiny(num_classes=80, *args, **kwargs): + """YOLO X tiny. + Model expects 416x416 images and for that size will return `3549` anchors. + + Args: + num_classes (int): number of classes to use for detection. + Default value is `80`. + + Returns: + YOLOX model. + """ + model = _yolo_x(num_classes, depth=0.33, width=0.375) + return model + + +def yolo_x_small(num_classes=80, *args, **kwargs): + """YOLO X small. + Model expects 640x640 images and for that size will return `8400` anchors. + + Args: + num_classes (int): number of classes to use for detection. + Default value is `80`. + + Returns: + YOLOX model. + """ + model = _yolo_x(num_classes, depth=0.33, width=0.50) + return model + + +def yolo_x_medium(num_classes=80, *args, **kwargs): + """YOLO X medium. + Model expects 640x640 images and for that size will return `8400` anchors. + + Args: + num_classes (int): number of classes to use for detection. + Default value is `80`. + + Returns: + YOLOX model. + """ + model = _yolo_x(num_classes, depth=0.67, width=0.75) + return model + + +def yolo_x_large(num_classes=80, *args, **kwargs): + """YOLO X large. + Model expects 640x640 images and for that size will return `8400` anchors. + + Args: + num_classes (int): number of classes to use for detection. + Default value is `80`. + + Returns: + YOLOX model. + """ + model = _yolo_x(num_classes, depth=1.0, width=1.0) + return model + + +def yolo_x_big(num_classes=80, *args, **kwargs): + """YOLO X. + Model expects 640x640 images and for that size will return `8400` anchors. + + Args: + num_classes (int): number of classes to use for detection. + Default value is `80`. + + Returns: + YOLOX model. + """ + model = _yolo_x(num_classes, depth=1.33, width=1.25) + return model
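The factory functions above differ only in the `depth` and `width` multipliers passed to `_yolo_x`. As a quick smoke test of the anchor counts mentioned in the docstrings, a forward pass in eval mode can be run like this (a minimal sketch, not part of the diff; it assumes the `examples/detection` package is importable as `examples.detection`):

```python
import torch

from examples.detection.models import yolo_x_tiny  # assumed import path

model = yolo_x_tiny(num_classes=4)
model.eval()

with torch.no_grad():
    # 416x416 input with strides (8, 16, 32) -> 52*52 + 26*26 + 13*13 = 3549 anchors
    outputs = model(torch.randn(1, 3, 416, 416))

# [batch, num anchors, 4 + 1 + num classes] == torch.Size([1, 3549, 9])
print(outputs.shape)
```

In training mode the same forward pass expects `targets` of shape `[batch, max objects, 5]` with `(class, x_center, y_center, w, h)` per object, which is what `YOLOXDetectionRunner.handle_batch` builds from the dataset's `labels` and `bboxes` tensors.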