diff --git a/.github/workflows/models.yaml b/.github/workflows/models.yaml index b1433bb4..59a30da2 100644 --- a/.github/workflows/models.yaml +++ b/.github/workflows/models.yaml @@ -21,6 +21,7 @@ jobs: python-version: '3.10' - run: | python -m pip install --upgrade pip - pip install torch torchvision boto3 + pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu + pip install boto3 ultralytics - run: | python tools/convert-models.py diff --git a/tools/convert-models.py b/tools/convert-models.py index 639c9fd2..5d82677f 100644 --- a/tools/convert-models.py +++ b/tools/convert-models.py @@ -3,6 +3,7 @@ import os import boto3 from botocore.exceptions import ClientError +from ultralytics import settings def upload_blob(bucket_name, source_file_name, destination_blob_name): @@ -110,9 +111,34 @@ def blob_exist(bucket_name, blob_name): 'fasterrcnn_resnet50_v2': 'https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth', 'fasterrcnn_mobilenet_v3_large': 'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth', 'fasterrcnn_mobilenet_v3_large_320': 'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth', + 'yolo_v8_l': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8l.pt', + 'yolo_v8_l_seg': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8l-seg.pt', + 'yolo_v8_m': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8m.pt', + 'yolo_v8_m_seg': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8m-seg.pt', + 'yolo_v8_n': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', + 'yolo_v8_s': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8s.pt', + 'yolo_v8_s_seg': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8s-seg.pt', + 'yolo_v8_x': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8x.pt', + 'yolo_v11_l': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11l.pt', + 'yolo_v11_m': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11m.pt', + 'yolo_v11_n': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n.pt', + 'yolo_v11_s': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11s.pt', + 'yolo_v11_s_cls': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11s-cls.pt', + 'yolo_v11_s_obb': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11s-obb.pt', + 'yolo_v11_s_pose': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11s-pose.pt', + 'yolo_v11_s_seg': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11s-seg.pt', + 'yolo_v11_x': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11x.pt', + 'yolo_v12_l': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo12l.pt', + 'yolo_v12_m': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo12m.pt', + 'yolo_v12_n': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo12n.pt', + 'yolo_v12_s': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo12s.pt', + 'yolo_v12_x': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo12x.pt', } os.makedirs("models", exist_ok=True) +# yolo specifics +os.makedirs("runs", exist_ok=True) +settings.update({"runs_dir": "runs/", "weights_dir": "models/", "sync": False}) for name, url in models.items(): fpath = "models/" + name + ".pth" @@ -124,6 +150,11 @@ def blob_exist(bucket_name, blob_name): # download from url, convert and upload the converted weights m = load_state_dict_from_url(url, progress=False) converted = {} + + # yolo models weights are embedded in a BaseModel per https://github.com/ultralytics/ultralytics/blob/main/ultralytics/nn/tasks.py#L309 + if name.startswith("yolo_"): + m = m["model"].model.float().state_dict() + for nm, par in m.items(): converted.update([(nm, par.clone())]) torch.save(converted, fpath, _use_new_zipfile_serialization=True)