Commit
RTDETRDetectionModel TorchScript, ONNX Predict and Val support (ultralytics#8818)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
glenn-jocher authored and hmurari committed Apr 17, 2024
1 parent 4ac314e commit b9f37f4
Showing 7 changed files with 52 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
@@ -118,7 +118,7 @@ jobs:
run: |
yolo checks
pip list
- name: Benchmark World DetectionModel
- name: Benchmark YOLOWorld DetectionModel
shell: bash
run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/yolov8s-worldv2.pt' imgsz=160 verbose=0.318
- name: Benchmark SegmentationModel
4 changes: 2 additions & 2 deletions tests/test_cli.py
@@ -65,8 +65,8 @@ def test_export(model, format):
def test_rtdetr(task="detect", model="yolov8n-rtdetr.yaml", data="coco8.yaml"):
"""Test the RTDETR functionality with the Ultralytics framework."""
# Warning: MUST use imgsz=640
run(f"yolo train {task} model={model} data={data} --imgsz= 640 epochs =1, cache = disk") # add coma, spaces to args
run(f"yolo predict {task} model={model} source={ASSETS / 'bus.jpg'} imgsz=640 save save_crop save_txt")
run(f"yolo train {task} model={model} data={data} --imgsz= 160 epochs =1, cache = disk") # add coma, spaces to args
run(f"yolo predict {task} model={model} source={ASSETS / 'bus.jpg'} imgsz=160 save save_crop save_txt")


@pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="MobileSAM Clip is not supported in Python 3.12")
2 changes: 0 additions & 2 deletions ultralytics/models/rtdetr/model.py
@@ -36,8 +36,6 @@ def __init__(self, model="rtdetr-l.pt") -> None:
Raises:
NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'.
"""
if model and Path(model).suffix not in (".pt", ".yaml", ".yml"):
raise NotImplementedError("RT-DETR only supports creating from *.pt, *.yaml, or *.yml files.")
super().__init__(model=model, task="detect")

@property
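Removing the suffix check above is what lets exported weights (e.g. `.onnx` or `.torchscript` files) be passed straight to the `RTDETR` class. A minimal end-to-end sketch of the workflow this commit enables, assuming the standard Ultralytics Python API; the checkpoint and image names are illustrative:

```python
from ultralytics import RTDETR

# Export a trained RT-DETR checkpoint (file name is illustrative)
model = RTDETR("rtdetr-l.pt")
onnx_path = model.export(format="onnx")  # or format="torchscript"

# With the suffix restriction removed, the exported file can be loaded directly
exported = RTDETR(onnx_path)
results = exported.predict("bus.jpg", imgsz=640)      # Predict support
metrics = exported.val(data="coco8.yaml", imgsz=640)  # Val support
```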
5 changes: 4 additions & 1 deletion ultralytics/models/rtdetr/predict.py
@@ -38,14 +38,17 @@ def postprocess(self, preds, img, orig_imgs):
The method filters detections based on confidence and class if specified in `self.args`.
Args:
preds (torch.Tensor): Raw predictions from the model.
preds (list): List of [predictions, extra] from the model.
img (torch.Tensor): Processed input images.
orig_imgs (list or torch.Tensor): Original, unprocessed images.
Returns:
(list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence scores,
and class labels.
"""
if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference
preds = [preds, None]

nd = preds[0].shape[-1]
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)

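The guard added above normalizes the two inference paths: native PyTorch inference returns a list whose first element is the decoder output, while TorchScript/ONNX backends return that tensor directly, so `postprocess` wraps the bare tensor to keep the downstream split identical (the same guard is added to `val.py` below). A small sketch with hypothetical shapes showing how the wrapped tensor then splits into boxes and class scores:

```python
import torch

# Hypothetical decoder output shape for illustration: (batch, num_queries, 4 + num_classes)
preds = torch.rand(1, 300, 4 + 80)  # bare Tensor, as returned by TorchScript/ONNX inference

if not isinstance(preds, (list, tuple)):  # PyTorch returns a list; exported backends return the tensor
    preds = [preds, None]

nd = preds[0].shape[-1]
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
print(bboxes.shape, scores.shape)  # torch.Size([1, 300, 4]) torch.Size([1, 300, 80])
```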
3 changes: 3 additions & 0 deletions ultralytics/models/rtdetr/val.py
@@ -94,6 +94,9 @@ def build_dataset(self, img_path, mode="val", batch=None):

def postprocess(self, preds):
"""Apply Non-maximum suppression to prediction outputs."""
if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference
preds = [preds, None]

bs, _, nd = preds[0].shape
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
bboxes *= self.args.imgsz
2 changes: 1 addition & 1 deletion ultralytics/utils/checks.py
@@ -493,7 +493,7 @@ def check_file(file, suffix="", download=True, hard=True):
downloads.safe_download(url=url, file=file, unzip=False)
return file
else: # search
files = glob.glob(str(ROOT / "cfg" / "**" / file), recursive=True) # find file
files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file)) # find file
if not files and hard:
raise FileNotFoundError(f"'{file}' does not exist")
elif len(files) > 1 and hard:
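`check_file` previously searched only under `ROOT / "cfg"`; it now globs the entire package tree and falls back to the directory above it, so configs such as `rtdetr-l.yaml` that live elsewhere in the package still resolve. A minimal sketch of the broadened lookup, with `ROOT` standing in for the installed package root:

```python
import glob
from pathlib import Path

ROOT = Path("ultralytics")  # stand-in for the installed package root, for illustration
file = "rtdetr-l.yaml"

# Search the whole package tree first, then fall back to one directory above it
files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file))
print(files)
```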
41 changes: 41 additions & 0 deletions ultralytics/utils/files.py
@@ -145,3 +145,44 @@ def get_latest_run(search_dir="."):
"""Return path to most recent 'last.pt' in /runs (i.e. to --resume from)."""
last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True)
return max(last_list, key=os.path.getctime) if last_list else ""


def update_models(model_names=("yolov8n.pt",), source_dir=Path("."), update_names=False):
"""
Updates and re-saves specified YOLO models in an 'updated_models' subdirectory.
Args:
model_names (tuple, optional): Model filenames to update, defaults to ("yolov8n.pt").
source_dir (Path, optional): Directory containing models and target subdirectory, defaults to current directory.
update_names (bool, optional): Update model names from a data YAML.
Example:
```python
from ultralytics.utils.files import update_models
model_names = (f"rtdetr-{size}.pt" for size in "lx")
update_models(model_names)
```
"""
from ultralytics import YOLO
from ultralytics.nn.autobackend import default_class_names

target_dir = source_dir / "updated_models"
target_dir.mkdir(parents=True, exist_ok=True) # Ensure target directory exists

for model_name in model_names:
model_path = source_dir / model_name
print(f"Loading model from {model_path}")

# Load model
model = YOLO(model_path)
model.half()
if update_names: # update model names from a dataset YAML
model.model.names = default_class_names("coco8.yaml")

# Define new save path
save_path = target_dir / model_name

# Save model using model.save()
print(f"Re-saving {model_name} model to {save_path}")
model.save(save_path, use_dill=False)
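A hedged usage sketch of the new `update_models` helper: it loads each checkpoint, converts it to half precision, optionally refreshes class names, and re-saves it without dill into an `updated_models/` subdirectory. The file names below are illustrative and must exist in `source_dir`:

```python
from pathlib import Path

from ultralytics.utils.files import update_models

# Re-save RT-DETR checkpoints so they deserialize without dill (names are illustrative)
update_models(model_names=("rtdetr-l.pt", "rtdetr-x.pt"), source_dir=Path("."), update_names=False)
# Updated copies land in ./updated_models/rtdetr-l.pt and ./updated_models/rtdetr-x.pt
```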
