Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/models.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ jobs:
python-version: '3.10'
- run: |
python -m pip install --upgrade pip
pip install torch torchvision boto3
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
pip install boto3 ultralytics
- run: |
python tools/convert-models.py
31 changes: 31 additions & 0 deletions tools/convert-models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import boto3
from botocore.exceptions import ClientError
from ultralytics import settings


def upload_blob(bucket_name, source_file_name, destination_blob_name):
Expand Down Expand Up @@ -110,9 +111,34 @@ def blob_exist(bucket_name, blob_name):
'fasterrcnn_resnet50_v2': 'https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth',
'fasterrcnn_mobilenet_v3_large': 'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth',
'fasterrcnn_mobilenet_v3_large_320': 'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth',
'yolo_v8_l': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8l.pt',
'yolo_v8_l_seg': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8l-seg.pt',
'yolo_v8_m': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8m.pt',
'yolo_v8_m_seg': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8m-seg.pt',
'yolo_v8_n': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt',
'yolo_v8_s': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8s.pt',
'yolo_v8_s_seg': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8s-seg.pt',
'yolo_v8_x': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8x.pt',
'yolo_v11_l': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11l.pt',
'yolo_v11_m': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11m.pt',
'yolo_v11_n': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n.pt',
'yolo_v11_s': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11s.pt',
'yolo_v11_s_cls': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11s-cls.pt',
'yolo_v11_s_obb': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11s-obb.pt',
'yolo_v11_s_pose': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11s-pose.pt',
'yolo_v11_s_seg': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11s-seg.pt',
'yolo_v11_x': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11x.pt',
'yolo_v12_l': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo12l.pt',
'yolo_v12_m': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo12m.pt',
'yolo_v12_n': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo12n.pt',
'yolo_v12_s': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo12s.pt',
'yolo_v12_x': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo12x.pt',
}

os.makedirs("models", exist_ok=True)
# yolo specifics
os.makedirs("runs", exist_ok=True)
settings.update({"runs_dir": "runs/", "weights_dir": "models/", "sync": False})

for name, url in models.items():
fpath = "models/" + name + ".pth"
Expand All @@ -124,6 +150,11 @@ def blob_exist(bucket_name, blob_name):
# download from url, convert and upload the converted weights
m = load_state_dict_from_url(url, progress=False)
converted = {}

# yolo models weights are embedded in a BaseModel per https://github.com/ultralytics/ultralytics/blob/main/ultralytics/nn/tasks.py#L309
if name.startswith("yolo_"):
m = m["model"].model.float().state_dict()

for nm, par in m.items():
converted.update([(nm, par.clone())])
torch.save(converted, fpath, _use_new_zipfile_serialization=True)
Expand Down
Loading