Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Examples: yolo-x training loop example #1324

Merged
merged 6 commits into from
Oct 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-
- YOLO-X (new) detection example and refactoring ([#1324](https://github.com/catalyst-team/catalyst/pull/1324))

### Changed

Expand Down
30 changes: 23 additions & 7 deletions examples/detection/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,28 +25,44 @@ python3 to_coco.py <images directory> <annotations directory> <output json file>

## Single Shot Detector (SSD)

For training SSD it is required for the dataset to return bounding box tensor with shapes `[num anchors, 4]` and labels tensor with shapes `[num anchors]`, `num anchors` is a fixed number of predictions generated by a model.
SSD expects that dataset will return bounding box tensor with shapes `[num anchors, 4]` and labels tensor with shapes `[num anchors]`, where `num anchors` is a fixed number of predictions generated by a model.

Each bounding box should be in a format `[x_center, y_center, width, height]`. Usually, there are fewer objects on an image than `num anchors` returned by a model, to handle this your model should have an additional class for `background` and bounding boxes to ignore (or empty bounding boxes) should be marked with that class.


```bash
catalyst-dl run \
--config catalyst/examples/detection/ssd-config.yaml \
--config catalyst/examples/detection/configs/ssd-config.yaml \
--expdir catalyst/examples/detection \
--logdir ssd-detection-logs
--logdir ssd-detection-logs \
--verbose
```


## Training CenterNet
## CenterNet

For training CenterNet it is required for the dataset to return heatmap tensor with object centers (expected shape is `[num classes, heatmap height, heatmap width]`) and tensor for Width-Height regression (expected shape is `[2, heatmap height, heatmap width]`).
CenterNet returns two tensors - class heatmap (tensor with shape `[num classes, heatmap height, heatmap width]`) and Width-Height regression (tensor with shape `[2, heatmap height, heatmap width]`). The last one represents bounding box width and height. To get object centers (`x_center` and `y_center`) you can apply MaxPool with `3x3` kernel size and stride - `1` and padding - `1` to class heatmap.

```bash
catalyst-dl run \
--config catalyst/examples/detection/centernet-config.yaml \
--config catalyst/examples/detection/configs/centernet-config.yaml \
--expdir catalyst/examples/detection \
--logdir centernet-detection-logs
--logdir centernet-detection-logs \
--verbose
```


## YOLO-X

YOLO-X predicts tensor with shape `[num anchors, 4 + 1 + num classes]`. This tensor consists of bounding box coordinates, class confidence, and class probability.
To extract bounding boxes need to get first 4 numbers for each anchor, confidences can be extracted by 4th index. The numbers after 4th index represent class probabilities.

**NOTE:** YOLO-X predicts bounding boxes in format `(x_center, y_center, w, h)`.

```bash
catalyst-dl run \
--config catalyst/examples/detection/configs/yolo-x-config.yaml \
--expdir catalyst/examples/detection \
--logdir centernet-detection-logs \
--verbose
```
25 changes: 22 additions & 3 deletions examples/detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,35 @@

from .callbacks import DetectionMeanAveragePrecision # noqa: E402
from .criterion import CenterNetCriterion, SSDCriterion # noqa: E402
from .custom_runner import CenterNetDetectionRunner, SSDDetectionRunner # noqa: E402
from .dataset import CenterNetDataset, SSDDataset # noqa: E402
from .model import CenterNet, SingleShotDetector # noqa: E402
from .custom_runner import ( # noqa: E402
CenterNetDetectionRunner,
SSDDetectionRunner,
YOLOXDetectionRunner,
)
from .dataset import CenterNetDataset, SSDDataset, YOLOXDataset # noqa: E402
from .models import ( # noqa: E402
CenterNet,
SingleShotDetector,
yolo_x_tiny,
yolo_x_small,
yolo_x_medium,
yolo_x_large,
yolo_x_big,
)

# runers
Registry(SSDDetectionRunner)
Registry(CenterNetDetectionRunner)
Registry(YOLOXDetectionRunner)

# models
Registry(SingleShotDetector)
Registry(CenterNet)
Registry(yolo_x_tiny)
Registry(yolo_x_small)
Registry(yolo_x_medium)
Registry(yolo_x_large)
Registry(yolo_x_big)

# criterions
Registry(SSDCriterion)
Expand All @@ -29,3 +47,4 @@
# datasets
Registry(SSDDataset)
Registry(CenterNetDataset)
Registry(YOLOXDataset)
67 changes: 65 additions & 2 deletions examples/detection/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,60 @@ def process_centernet_output(
yield pred_sample, gt_sample


def process_yolo_x_output(
predicted_tensor,
gt_boxes,
gt_labels,
iou_threshold=0.5,
):
"""Generate bbox and classes from YOLO-X model outputs.

Args:
predicted_tensor (torch.Tensor): model outputs,
expected shapes [batch, num anchors, 4 + 1 + num classes].
gt_boxes (torch.Tensor): ground truth bounding boxes,
expected shapes [batch, num anchors, 4].
gt_labels (torch.Tensor): ground truth bounding box labels,
expected shape [batch, num anchors].
iou_threshold (float): IoU threshold to use in NMS.
Default is ``0.5``.

Yields:
predicted sample (np.ndarray) and ground truth sample (np.ndarray)
"""
batch_size = predicted_tensor.size(0)
outputs = predicted_tensor.detach().cpu().numpy()

_pred_boxes = outputs[:, :, :4]
_pred_confidence = outputs[:, :, 4]
_pred_cls = np.argmax(outputs[:, :, 5:], -1)

_gt_boxes = gt_boxes.cpu().numpy()
_gt_classes = gt_labels.cpu().numpy()
_gt_boxes_mask = _gt_boxes.sum(axis=2) > 0

for i in range(batch_size):
# build predictions
sample_bboxes = change_box_order(_pred_boxes[i], "xywh2xyxy")
sample_bboxes, sample_classes, sample_confs = nms_filter(
sample_bboxes, _pred_cls[i], _pred_confidence[i], iou_threshold=iou_threshold
)
pred_sample = np.concatenate(
[sample_bboxes, sample_classes[:, None], sample_confs[:, None]], -1
)
pred_sample = pred_sample.astype(np.float32)

# build ground truth
sample_gt_mask = _gt_boxes_mask[i]
sample_gt_bboxes = change_box_order(_gt_boxes[i][sample_gt_mask], "xywh2xyxy")
sample_gt_classes = _gt_classes[i][sample_gt_mask]
gt_sample = np.zeros((sample_gt_classes.shape[0], 7), dtype=np.float32)
gt_sample[:, :4] = sample_gt_bboxes
gt_sample[:, 4] = sample_gt_classes

yield pred_sample, gt_sample


class DetectionMeanAveragePrecision(Callback):
"""Compute mAP for Object Detection task."""

Expand All @@ -208,7 +262,8 @@ def __init__(
Default is ``1``.
metric_key (str): name of a metric.
Default is ``"mAP"``.
output_type (str): model output type. Valid values are ``"ssd"`` or ``"centernet"``.
output_type (str): model output type. Valid values are ``"ssd"`` or
``"centernet"`` or ``"yolo-x"``.
Default is ``"ssd"``.
iou_threshold (float): IoU threshold to use in NMS.
Default is ``0.5``.
Expand All @@ -217,7 +272,7 @@ def __init__(
Default is ``0.5``.
"""
super().__init__(order=CallbackOrder.Metric)
assert output_type in ("ssd", "centernet")
assert output_type in ("ssd", "centernet", "yolo-x")

self.num_classes = num_classes
self.metric_key = metric_key
Expand Down Expand Up @@ -261,6 +316,14 @@ def on_batch_end(self, runner: "IRunner"): # noqa: D102, F821
confidence_threshold=self.confidence_threshold,
):
self.metric_fn.add(predicted_sample, ground_truth_sample)
elif self.output_type == "yolo-x":
p_tensor = runner.batch["predicted_tensor"]
gt_box = runner.batch["bboxes"]
gt_labels = runner.batch["labels"]
for predicted_sample, ground_truth_sample in process_yolo_x_output(
p_tensor, gt_box, gt_labels, iou_threshold=self.iou_threshold
):
self.metric_fn.add(predicted_sample, ground_truth_sample)

def on_loader_end(self, runner: "IRunner"): # noqa: D102, F821
if not runner.is_valid_loader:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ shared:
train_imgs_dir: &train_imgs_dir ./datasets/fruit-detection/data/images
valid_file: &valid ./datasets/fruit-detection/dataset.json
valid_imgs_dir: &valid_imgs_dir ./datasets/fruit-detection/data/images
images_height: &img_h 512
images_width: &img_w 512

runner:
_target_: CenterNetDetectionRunner
Expand Down Expand Up @@ -38,8 +40,8 @@ stages:
_target_: albumentations.Compose
transforms:
- _target_: albu.Resize
height: 512
width: 512
height: *img_h
width: *img_w
- _target_: albu.Normalize
- _target_: albu.ToTensorV2
bbox_params:
Expand All @@ -54,8 +56,8 @@ stages:
_target_: albumentations.Compose
transforms:
- _target_: albu.Resize
height: 512
width: 512
height: *img_h
width: *img_w
- _target_: albu.Normalize
- _target_: albu.ToTensorV2
bbox_params:
Expand All @@ -72,7 +74,7 @@ stages:

callbacks: &callbacks
periodic_validation:
_target_: PeriodicLoaderCallback
_target_: catalyst.callbacks.PeriodicLoaderCallback
valid_loader_key: valid
valid_metric_key: mAP
minimize: False
Expand All @@ -86,5 +88,5 @@ stages:
confidence_threshold: 0.25

optimizer:
_target_: OptimizerCallback
_target_: catalyst.callbacks.OptimizerCallback
metric_key: loss
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ shared:
train_imgs_dir: &train_imgs_dir ./datasets/fruit-detection/data/images
valid_file: &valid ./datasets/fruit-detection/dataset.json
valid_imgs_dir: &valid_imgs_dir ./datasets/fruit-detection/data/images
images_height: &img_h 416
images_width: &img_w 416

runner:
_target_: SSDDetectionRunner
Expand Down Expand Up @@ -39,8 +41,8 @@ stages:
_target_: albumentations.Compose
transforms:
- _target_: albu.Resize
height: 300
width: 300
height: *img_h
width: *img_w
- _target_: albu.Normalize
- _target_: albu.ToTensorV2
bbox_params:
Expand All @@ -55,8 +57,8 @@ stages:
_target_: albumentations.Compose
transforms:
- _target_: albu.Resize
height: 300
width: 300
height: *img_h
width: *img_w
- _target_: albu.Normalize
- _target_: albu.ToTensorV2
bbox_params:
Expand All @@ -74,7 +76,7 @@ stages:

callbacks: &callbacks
periodic_validation:
_target_: PeriodicLoaderCallback
_target_: catalyst.callbacks.PeriodicLoaderCallback
valid_loader_key: valid
valid_metric_key: mAP
minimize: False
Expand All @@ -87,5 +89,5 @@ stages:
iou_threshold: 0.5

optimizer:
_target_: OptimizerCallback
_target_: catalyst.callbacks.OptimizerCallback
metric_key: loss
88 changes: 88 additions & 0 deletions examples/detection/configs/yolo-x-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
shared:
num_classes: &num_classes 4
num_epochs: &num_epochs 100
validation_period: &val_period 5
train_file: &train /mnt/4tb/datasets/fruit-detection/dataset.json
train_imgs_dir: &train_imgs_dir /mnt/4tb/datasets/fruit-detection/data/images
valid_file: &valid /mnt/4tb/datasets/fruit-detection/dataset.json
valid_imgs_dir: &valid_imgs_dir /mnt/4tb/datasets/fruit-detection/data/images
images_height: &img_h 416
images_width: &img_w 416

runner:
_target_: YOLOXDetectionRunner

engine:
_target_: DeviceEngine

model:
_target_: yolo_x_tiny
num_classes: *num_classes

loggers:
console:
_target_: ConsoleLogger

stages:
initial_training:
num_epochs: *num_epochs

loaders: &loaders
batch_size: 4
num_workers: 4

datasets:
train:
_target_: YOLOXDataset
coco_json_path: *train
images_dir: *train_imgs_dir
transforms:
_target_: albumentations.Compose
transforms:
- _target_: albu.Resize
height: *img_h
width: *img_w
- _target_: albu.Normalize
- _target_: albu.ToTensorV2
bbox_params:
_target_: albu.BboxParams
format: albumentations

valid:
_target_: YOLOXDataset
coco_json_path: *valid
images_dir: *valid_imgs_dir
transforms:
_target_: albumentations.Compose
transforms:
- _target_: albu.Resize
height: *img_h
width: *img_w
- _target_: albu.Normalize
- _target_: albu.ToTensorV2
bbox_params:
_target_: albu.BboxParams
format: albumentations

optimizer:
_target_: torch.optim.AdamW
lr: 0.001

callbacks: &callbacks
periodic_validation:
_target_: catalyst.callbacks.PeriodicLoaderCallback
valid_loader_key: valid
valid_metric_key: mAP
minimize: False
valid: *val_period

mAP:
_target_: DetectionMeanAveragePrecision
num_classes: *num_classes
output_type: yolo-x
iou_threshold: 0.5
confidence_threshold: 0.25

optimizer:
_target_: catalyst.callbacks.OptimizerCallback
metric_key: loss
Loading