diff --git a/README.md b/README.md index d98bd7bfa7da..b25c6fca983c 100755 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ $ pip install -r requirements.txt * [Supervisely Ecosystem](https://github.com/ultralytics/yolov5/issues/2518)  🌟 NEW * [Multi-GPU Training](https://github.com/ultralytics/yolov5/issues/475) * [PyTorch Hub](https://github.com/ultralytics/yolov5/issues/36)  ⭐ NEW -* [ONNX and TorchScript Export](https://github.com/ultralytics/yolov5/issues/251) +* [TorchScript, ONNX, CoreML Export](https://github.com/ultralytics/yolov5/issues/251) 🚀 * [Test-Time Augmentation (TTA)](https://github.com/ultralytics/yolov5/issues/303) * [Model Ensembling](https://github.com/ultralytics/yolov5/issues/318) * [Model Pruning/Sparsity](https://github.com/ultralytics/yolov5/issues/304) @@ -130,7 +130,7 @@ import torch model = torch.hub.load('ultralytics/yolov5', 'yolov5s') # Image -img = 'https://github.com/ultralytics/yolov5/raw/master/data/images/zidane.jpg' +img = 'https://ultralytics.com/images/zidane.jpg' # Inference results = model(img) diff --git a/__init__.py b/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/data/GlobalWheat2020.yaml b/data/GlobalWheat2020.yaml new file mode 100644 index 000000000000..f45182b43e25 --- /dev/null +++ b/data/GlobalWheat2020.yaml @@ -0,0 +1,55 @@ +# Global Wheat 2020 dataset http://www.global-wheat.com/ +# Train command: python train.py --data GlobalWheat2020.yaml +# Default dataset location is next to YOLOv5: +# /parent_folder +# /datasets/GlobalWheat2020 +# /yolov5 + + +# train and val data as 1) directory: path/images/, 2) file: path/images.txt, or 3) list: [path1/images/, path2/images/] +train: # 3422 images + - ../datasets/GlobalWheat2020/images/arvalis_1 + - ../datasets/GlobalWheat2020/images/arvalis_2 + - ../datasets/GlobalWheat2020/images/arvalis_3 + - ../datasets/GlobalWheat2020/images/ethz_1 + - ../datasets/GlobalWheat2020/images/rres_1 + - ../datasets/GlobalWheat2020/images/inrae_1 + - ../datasets/GlobalWheat2020/images/usask_1 + +val: # 748 images (WARNING: train set contains ethz_1) + - ../datasets/GlobalWheat2020/images/ethz_1 + +test: # 1276 images + - ../datasets/GlobalWheat2020/images/utokyo_1 + - ../datasets/GlobalWheat2020/images/utokyo_2 + - ../datasets/GlobalWheat2020/images/nau_1 + - ../datasets/GlobalWheat2020/images/uq_1 + +# number of classes +nc: 1 + +# class names +names: [ 'wheat_head' ] + + +# download command/URL (optional) -------------------------------------------------------------------------------------- +download: | + from utils.general import download, Path + + # Download + dir = Path('../datasets/GlobalWheat2020') # dataset directory + urls = ['https://zenodo.org/record/4298502/files/global-wheat-codalab-official.zip', + 'https://github.com/ultralytics/yolov5/releases/download/v1.0/GlobalWheat2020_labels.zip'] + download(urls, dir=dir) + + # Make Directories + for p in 'annotations', 'images', 'labels': + (dir / p).mkdir(parents=True, exist_ok=True) + + # Move + for p in 'arvalis_1', 'arvalis_2', 'arvalis_3', 'ethz_1', 'rres_1', 'inrae_1', 'usask_1', \ + 'utokyo_1', 'utokyo_2', 'nau_1', 'uq_1': + (dir / p).rename(dir / 'images' / p) # move to /images + f = (dir / p).with_suffix('.json') # json file + if f.exists(): + f.rename((dir / 'annotations' / p).with_suffix('.json')) # move to /annotations diff --git a/data/SKU-110K.yaml b/data/SKU-110K.yaml new file mode 100644 index 000000000000..a8c1f25b385a --- /dev/null +++ b/data/SKU-110K.yaml @@ -0,0 +1,52 @@ +# SKU-110K retail items dataset 
https://github.com/eg4000/SKU110K_CVPR19 +# Train command: python train.py --data SKU-110K.yaml +# Default dataset location is next to YOLOv5: +# /parent_folder +# /datasets/SKU-110K +# /yolov5 + + +# train and val data as 1) directory: path/images/, 2) file: path/images.txt, or 3) list: [path1/images/, path2/images/] +train: ../datasets/SKU-110K/train.txt # 8219 images +val: ../datasets/SKU-110K/val.txt # 588 images +test: ../datasets/SKU-110K/test.txt # 2936 images + +# number of classes +nc: 1 + +# class names +names: [ 'object' ] + + +# download command/URL (optional) -------------------------------------------------------------------------------------- +download: | + import shutil + from tqdm import tqdm + from utils.general import np, pd, Path, download, xyxy2xywh + + # Download + datasets = Path('../datasets') # download directory + urls = ['http://trax-geometry.s3.amazonaws.com/cvpr_challenge/SKU110K_fixed.tar.gz'] + download(urls, dir=datasets, delete=False) + + # Rename directories + dir = (datasets / 'SKU-110K') + if dir.exists(): + shutil.rmtree(dir) + (datasets / 'SKU110K_fixed').rename(dir) # rename dir + (dir / 'labels').mkdir(parents=True, exist_ok=True) # create labels dir + + # Convert labels + names = 'image', 'x1', 'y1', 'x2', 'y2', 'class', 'image_width', 'image_height' # column names + for d in 'annotations_train.csv', 'annotations_val.csv', 'annotations_test.csv': + x = pd.read_csv(dir / 'annotations' / d, names=names).values # annotations + images, unique_images = x[:, 0], np.unique(x[:, 0]) + with open((dir / d).with_suffix('.txt').__str__().replace('annotations_', ''), 'w') as f: + f.writelines(f'./images/{s}\n' for s in unique_images) + for im in tqdm(unique_images, desc=f'Converting {dir / d}'): + cls = 0 # single-class dataset + with open((dir / 'labels' / im).with_suffix('.txt'), 'a') as f: + for r in x[images == im]: + w, h = r[6], r[7] # image width, height + xywh = xyxy2xywh(np.array([[r[1] / w, r[2] / h, r[3] / w, r[4] / h]]))[0] # instance + f.write(f"{cls} {xywh[0]:.5f} {xywh[1]:.5f} {xywh[2]:.5f} {xywh[3]:.5f}\n") # write label diff --git a/data/visdrone.yaml b/data/VisDrone.yaml similarity index 95% rename from data/visdrone.yaml rename to data/VisDrone.yaml index c23e6bc286f8..c4603b200132 100644 --- a/data/visdrone.yaml +++ b/data/VisDrone.yaml @@ -1,5 +1,5 @@ # VisDrone2019-DET dataset https://github.com/VisDrone/VisDrone-Dataset -# Train command: python train.py --data visdrone.yaml +# Train command: python train.py --data VisDrone.yaml # Default dataset location is next to YOLOv5: # /parent_folder # /VisDrone @@ -20,11 +20,7 @@ names: [ 'pedestrian', 'people', 'bicycle', 'car', 'van', 'truck', 'tricycle', ' # download command/URL (optional) -------------------------------------------------------------------------------------- download: | - import os - from pathlib import Path - - from utils.general import download - + from utils.general import download, os, Path def visdrone2yolo(dir): from PIL import Image diff --git a/data/hyp.finetune_objects365.yaml b/data/hyp.finetune_objects365.yaml new file mode 100644 index 000000000000..2b104ef2d9bf --- /dev/null +++ b/data/hyp.finetune_objects365.yaml @@ -0,0 +1,28 @@ +lr0: 0.00258 +lrf: 0.17 +momentum: 0.779 +weight_decay: 0.00058 +warmup_epochs: 1.33 +warmup_momentum: 0.86 +warmup_bias_lr: 0.0711 +box: 0.0539 +cls: 0.299 +cls_pw: 0.825 +obj: 0.632 +obj_pw: 1.0 +iou_t: 0.2 +anchor_t: 3.44 +anchors: 3.2 +fl_gamma: 0.0 +hsv_h: 0.0188 +hsv_s: 0.704 +hsv_v: 0.36 +degrees: 0.0 +translate: 0.0902 +scale: 
0.491 +shear: 0.0 +perspective: 0.0 +flipud: 0.0 +fliplr: 0.5 +mosaic: 1.0 +mixup: 0.0 diff --git a/data/objects365.yaml b/data/objects365.yaml new file mode 100644 index 000000000000..eb99995903cf --- /dev/null +++ b/data/objects365.yaml @@ -0,0 +1,102 @@ +# Objects365 dataset https://www.objects365.org/ +# Train command: python train.py --data objects365.yaml +# Default dataset location is next to YOLOv5: +# /parent_folder +# /datasets/objects365 +# /yolov5 + +# train and val data as 1) directory: path/images/, 2) file: path/images.txt, or 3) list: [path1/images/, path2/images/] +train: ../datasets/objects365/images/train # 1742289 images +val: ../datasets/objects365/images/val # 5570 images + +# number of classes +nc: 365 + +# class names +names: [ 'Person', 'Sneakers', 'Chair', 'Other Shoes', 'Hat', 'Car', 'Lamp', 'Glasses', 'Bottle', 'Desk', 'Cup', + 'Street Lights', 'Cabinet/shelf', 'Handbag/Satchel', 'Bracelet', 'Plate', 'Picture/Frame', 'Helmet', 'Book', + 'Gloves', 'Storage box', 'Boat', 'Leather Shoes', 'Flower', 'Bench', 'Potted Plant', 'Bowl/Basin', 'Flag', + 'Pillow', 'Boots', 'Vase', 'Microphone', 'Necklace', 'Ring', 'SUV', 'Wine Glass', 'Belt', 'Monitor/TV', + 'Backpack', 'Umbrella', 'Traffic Light', 'Speaker', 'Watch', 'Tie', 'Trash bin Can', 'Slippers', 'Bicycle', + 'Stool', 'Barrel/bucket', 'Van', 'Couch', 'Sandals', 'Basket', 'Drum', 'Pen/Pencil', 'Bus', 'Wild Bird', + 'High Heels', 'Motorcycle', 'Guitar', 'Carpet', 'Cell Phone', 'Bread', 'Camera', 'Canned', 'Truck', + 'Traffic cone', 'Cymbal', 'Lifesaver', 'Towel', 'Stuffed Toy', 'Candle', 'Sailboat', 'Laptop', 'Awning', + 'Bed', 'Faucet', 'Tent', 'Horse', 'Mirror', 'Power outlet', 'Sink', 'Apple', 'Air Conditioner', 'Knife', + 'Hockey Stick', 'Paddle', 'Pickup Truck', 'Fork', 'Traffic Sign', 'Balloon', 'Tripod', 'Dog', 'Spoon', 'Clock', + 'Pot', 'Cow', 'Cake', 'Dinning Table', 'Sheep', 'Hanger', 'Blackboard/Whiteboard', 'Napkin', 'Other Fish', + 'Orange/Tangerine', 'Toiletry', 'Keyboard', 'Tomato', 'Lantern', 'Machinery Vehicle', 'Fan', + 'Green Vegetables', 'Banana', 'Baseball Glove', 'Airplane', 'Mouse', 'Train', 'Pumpkin', 'Soccer', 'Skiboard', + 'Luggage', 'Nightstand', 'Tea pot', 'Telephone', 'Trolley', 'Head Phone', 'Sports Car', 'Stop Sign', + 'Dessert', 'Scooter', 'Stroller', 'Crane', 'Remote', 'Refrigerator', 'Oven', 'Lemon', 'Duck', 'Baseball Bat', + 'Surveillance Camera', 'Cat', 'Jug', 'Broccoli', 'Piano', 'Pizza', 'Elephant', 'Skateboard', 'Surfboard', + 'Gun', 'Skating and Skiing shoes', 'Gas stove', 'Donut', 'Bow Tie', 'Carrot', 'Toilet', 'Kite', 'Strawberry', + 'Other Balls', 'Shovel', 'Pepper', 'Computer Box', 'Toilet Paper', 'Cleaning Products', 'Chopsticks', + 'Microwave', 'Pigeon', 'Baseball', 'Cutting/chopping Board', 'Coffee Table', 'Side Table', 'Scissors', + 'Marker', 'Pie', 'Ladder', 'Snowboard', 'Cookies', 'Radiator', 'Fire Hydrant', 'Basketball', 'Zebra', 'Grape', + 'Giraffe', 'Potato', 'Sausage', 'Tricycle', 'Violin', 'Egg', 'Fire Extinguisher', 'Candy', 'Fire Truck', + 'Billiards', 'Converter', 'Bathtub', 'Wheelchair', 'Golf Club', 'Briefcase', 'Cucumber', 'Cigar/Cigarette', + 'Paint Brush', 'Pear', 'Heavy Truck', 'Hamburger', 'Extractor', 'Extension Cord', 'Tong', 'Tennis Racket', + 'Folder', 'American Football', 'earphone', 'Mask', 'Kettle', 'Tennis', 'Ship', 'Swing', 'Coffee Machine', + 'Slide', 'Carriage', 'Onion', 'Green beans', 'Projector', 'Frisbee', 'Washing Machine/Drying Machine', + 'Chicken', 'Printer', 'Watermelon', 'Saxophone', 'Tissue', 'Toothbrush', 'Ice cream', 'Hot-air 
balloon', + 'Cello', 'French Fries', 'Scale', 'Trophy', 'Cabbage', 'Hot dog', 'Blender', 'Peach', 'Rice', 'Wallet/Purse', + 'Volleyball', 'Deer', 'Goose', 'Tape', 'Tablet', 'Cosmetics', 'Trumpet', 'Pineapple', 'Golf Ball', + 'Ambulance', 'Parking meter', 'Mango', 'Key', 'Hurdle', 'Fishing Rod', 'Medal', 'Flute', 'Brush', 'Penguin', + 'Megaphone', 'Corn', 'Lettuce', 'Garlic', 'Swan', 'Helicopter', 'Green Onion', 'Sandwich', 'Nuts', + 'Speed Limit Sign', 'Induction Cooker', 'Broom', 'Trombone', 'Plum', 'Rickshaw', 'Goldfish', 'Kiwi fruit', + 'Router/modem', 'Poker Card', 'Toaster', 'Shrimp', 'Sushi', 'Cheese', 'Notepaper', 'Cherry', 'Pliers', 'CD', + 'Pasta', 'Hammer', 'Cue', 'Avocado', 'Hamimelon', 'Flask', 'Mushroom', 'Screwdriver', 'Soap', 'Recorder', + 'Bear', 'Eggplant', 'Board Eraser', 'Coconut', 'Tape Measure/Ruler', 'Pig', 'Showerhead', 'Globe', 'Chips', + 'Steak', 'Crosswalk Sign', 'Stapler', 'Camel', 'Formula 1', 'Pomegranate', 'Dishwasher', 'Crab', + 'Hoverboard', 'Meat ball', 'Rice Cooker', 'Tuba', 'Calculator', 'Papaya', 'Antelope', 'Parrot', 'Seal', + 'Butterfly', 'Dumbbell', 'Donkey', 'Lion', 'Urinal', 'Dolphin', 'Electric Drill', 'Hair Dryer', 'Egg tart', + 'Jellyfish', 'Treadmill', 'Lighter', 'Grapefruit', 'Game board', 'Mop', 'Radish', 'Baozi', 'Target', 'French', + 'Spring Rolls', 'Monkey', 'Rabbit', 'Pencil Case', 'Yak', 'Red Cabbage', 'Binoculars', 'Asparagus', 'Barbell', + 'Scallop', 'Noddles', 'Comb', 'Dumpling', 'Oyster', 'Table Tennis paddle', 'Cosmetics Brush/Eyeliner Pencil', + 'Chainsaw', 'Eraser', 'Lobster', 'Durian', 'Okra', 'Lipstick', 'Cosmetics Mirror', 'Curling', 'Table Tennis' ] + + +# download command/URL (optional) -------------------------------------------------------------------------------------- +download: | + from pycocotools.coco import COCO + from tqdm import tqdm + + from utils.general import download, Path + + # Make Directories + dir = Path('../datasets/objects365') # dataset directory + for p in 'images', 'labels': + (dir / p).mkdir(parents=True, exist_ok=True) + for q in 'train', 'val': + (dir / p / q).mkdir(parents=True, exist_ok=True) + + # Download + url = "https://dorc.ks3-cn-beijing.ksyun.com/data-set/2020Objects365%E6%95%B0%E6%8D%AE%E9%9B%86/train/" + download([url + 'zhiyuan_objv2_train.tar.gz'], dir=dir, delete=False) # annotations json + download([url + f for f in [f'patch{i}.tar.gz' for i in range(51)]], dir=dir / 'images' / 'train', + curl=True, delete=False, threads=8) + + # Move + train = dir / 'images' / 'train' + for f in tqdm(train.rglob('*.jpg'), desc=f'Moving images'): + f.rename(train / f.name) # move to /images/train + + # Labels + coco = COCO(dir / 'zhiyuan_objv2_train.json') + names = [x["name"] for x in coco.loadCats(coco.getCatIds())] + for cid, cat in enumerate(names): + catIds = coco.getCatIds(catNms=[cat]) + imgIds = coco.getImgIds(catIds=catIds) + for im in tqdm(coco.loadImgs(imgIds), desc=f'Class {cid + 1}/{len(names)} {cat}'): + width, height = im["width"], im["height"] + path = Path(im["file_name"]) # image filename + try: + with open(dir / 'labels' / 'train' / path.with_suffix('.txt').name, 'a') as file: + annIds = coco.getAnnIds(imgIds=im["id"], catIds=catIds, iscrowd=None) + for a in coco.loadAnns(annIds): + x, y, w, h = a['bbox'] # bounding box in xywh (xy top-left corner) + x, y = x + w / 2, y + h / 2 # xy to center + file.write(f"{cid} {x / width:.5f} {y / height:.5f} {w / width:.5f} {h / height:.5f}\n") + + except Exception as e: + print(e) diff --git a/data/scripts/get_argoverse_hd.sh 
b/data/scripts/get_argoverse_hd.sh index 18131a6764d6..331509914568 100644 --- a/data/scripts/get_argoverse_hd.sh +++ b/data/scripts/get_argoverse_hd.sh @@ -36,7 +36,7 @@ for val in annotation_files: img_name = a['images'][img_id]['name'] img_label_name = img_name[:-3] + "txt" - obj_class = annot['category_id'] + cls = annot['category_id'] # instance class id x_center, y_center, width, height = annot['bbox'] x_center = (x_center + width / 2) / 1920. # offset and scale y_center = (y_center + height / 2) / 1200. # offset and scale @@ -46,11 +46,10 @@ for val in annotation_files: img_dir = "./labels/" + a['seq_dirs'][a['images'][annot['image_id']]['sid']] Path(img_dir).mkdir(parents=True, exist_ok=True) - if img_dir + "/" + img_label_name not in label_dict: label_dict[img_dir + "/" + img_label_name] = [] - label_dict[img_dir + "/" + img_label_name].append(f"{obj_class} {x_center} {y_center} {width} {height}\n") + label_dict[img_dir + "/" + img_label_name].append(f"{cls} {x_center} {y_center} {width} {height}\n") for filename in label_dict: with open(filename, "w") as file: diff --git a/data/scripts/get_coco128.sh b/data/scripts/get_coco128.sh new file mode 100644 index 000000000000..395043b5b2dc --- /dev/null +++ b/data/scripts/get_coco128.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# COCO128 dataset https://www.kaggle.com/ultralytics/coco128 +# Download command: bash data/scripts/get_coco128.sh +# Train command: python train.py --data coco128.yaml +# Default dataset location is next to /yolov5: +# /parent_folder +# /coco128 +# /yolov5 + +# Download/unzip images and labels +d='../' # unzip directory +url=https://github.com/ultralytics/yolov5/releases/download/v1.0/ +f='coco128.zip' # or 'coco2017labels-segments.zip', 68 MB +echo 'Downloading' $url$f ' ...' +curl -L $url$f -o $f && unzip -q $f -d $d && rm $f & # download, unzip, remove in background + +wait # finish background tasks diff --git a/detect.py b/detect.py index ba42f349dbaf..732fec698006 100644 --- a/detect.py +++ b/detect.py @@ -5,7 +5,6 @@ import cv2 import torch import torch.backends.cudnn as cudnn -from numpy import random from models.experimental import attempt_load from utils.datasets import LoadStreams, LoadImages @@ -69,7 +68,8 @@ def detect(opt): pred = model(img, augment=opt.augment)[0] # Apply NMS - pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms) + pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, opt.classes, opt.agnostic_nms, + max_det=opt.max_det) t2 = time_synchronized() # Apply Classifier @@ -79,7 +79,7 @@ def detect(opt): # Process detections for i, det in enumerate(pred): # detections per image if webcam: # batch_size >= 1 - p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count + p, s, im0, frame = path[i], f'{i}: ', im0s[i].copy(), dataset.count else: p, s, im0, frame = path, '', im0s.copy(), getattr(dataset, 'frame', 0) @@ -88,6 +88,7 @@ def detect(opt): txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # img.txt s += '%gx%g ' % img.shape[2:] # print string gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh + imc = im0.copy() if opt.save_crop else im0 # for opt.save_crop if len(det): # Rescale boxes from img_size to im0 size det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round() @@ -108,10 +109,9 @@ def detect(opt): if save_img or opt.save_crop or view_img: # Add bbox to image c = int(cls) # integer class label = None if opt.hide_labels else 
(names[c] if opt.hide_conf else f'{names[c]} {conf:.2f}') - plot_one_box(xyxy, im0, label=label, color=colors(c, True), line_thickness=opt.line_thickness) if opt.save_crop: - save_one_box(xyxy, im0s, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True) + save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True) # Print time (inference + NMS) print(f'{s}Done. ({t2 - t1:.3f}s)') @@ -154,6 +154,7 @@ def detect(opt): parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)') parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold') parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS') + parser.add_argument('--max-det', type=int, default=1000, help='maximum number of detections per image') parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') parser.add_argument('--view-img', action='store_true', help='display results') parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') @@ -172,7 +173,7 @@ def detect(opt): parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences') opt = parser.parse_args() print(opt) - check_requirements(exclude=('pycocotools', 'thop')) + check_requirements(exclude=('tensorboard', 'pycocotools', 'thop')) with torch.no_grad(): if opt.update: # update all models (to fix SourceChangeWarning) diff --git a/hubconf.py b/hubconf.py index e42d0b59bd2a..f74e70c85a65 100644 --- a/hubconf.py +++ b/hubconf.py @@ -5,20 +5,10 @@ model = torch.hub.load('ultralytics/yolov5', 'yolov5s') """ -from pathlib import Path - import torch -from models.yolo import Model -from utils.general import check_requirements, set_logging -from utils.google_utils import attempt_download -from utils.torch_utils import select_device - -dependencies = ['torch', 'yaml'] -check_requirements(Path(__file__).parent / 'requirements.txt', exclude=('pycocotools', 'thop')) - -def create(name, pretrained, channels, classes, autoshape, verbose): +def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None): """Creates a specified YOLOv5 model Arguments: @@ -26,105 +16,98 @@ def create(name, pretrained, channels, classes, autoshape, verbose): pretrained (bool): load pretrained weights into the model channels (int): number of input channels classes (int): number of model classes + autoshape (bool): apply YOLOv5 .autoshape() wrapper to model + verbose (bool): print all information to screen + device (str, torch.device, None): device to use for model parameters Returns: - pytorch model + YOLOv5 pytorch model """ + from pathlib import Path + + from models.yolo import Model, attempt_load + from utils.general import check_requirements, set_logging + from utils.google_utils import attempt_download + from utils.torch_utils import select_device + + check_requirements(Path(__file__).parent / 'requirements.txt', exclude=('tensorboard', 'pycocotools', 'thop')) + set_logging(verbose=verbose) + + fname = Path(name).with_suffix('.pt') # checkpoint filename try: - set_logging(verbose=verbose) - - cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path - model = Model(cfg, channels, classes) - if pretrained: - fname = f'{name}.pt' # checkpoint filename - attempt_download(fname) # download if not found locally - ckpt = torch.load(fname, map_location=torch.device('cpu')) # load - msd = model.state_dict() # model state_dict 
- csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 - csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter - model.load_state_dict(csd, strict=False) # load - if len(ckpt['model'].names) == classes: - model.names = ckpt['model'].names # set class names attribute - if autoshape: - model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS - device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available + if pretrained and channels == 3 and classes == 80: + model = attempt_load(fname, map_location=torch.device('cpu')) # download/load FP32 model + else: + cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path + model = Model(cfg, channels, classes) # create model + if pretrained: + attempt_download(fname) # download if not found locally + ckpt = torch.load(fname, map_location=torch.device('cpu')) # load + msd = model.state_dict() # model state_dict + csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 + csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter + model.load_state_dict(csd, strict=False) # load + if len(ckpt['model'].names) == classes: + model.names = ckpt['model'].names # set class names attribute + if autoshape: + model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS + device = select_device('0' if torch.cuda.is_available() else 'cpu') if device is None else torch.device(device) return model.to(device) except Exception as e: help_url = 'https://github.com/ultralytics/yolov5/issues/36' - s = 'Cache maybe be out of date, try force_reload=True. See %s for help.' % help_url + s = 'Cache may be out of date, try `force_reload=True`. See %s for help.' % help_url raise Exception(s) from e -def custom(path_or_model='path/to/model.pt', autoshape=True, verbose=True): - """YOLOv5-custom model https://github.com/ultralytics/yolov5 - - Arguments (3 options): - path_or_model (str): 'path/to/model.pt' - path_or_model (dict): torch.load('path/to/model.pt') - path_or_model (nn.Module): torch.load('path/to/model.pt')['model'] - - Returns: - pytorch model - """ - set_logging(verbose=verbose) - - model = torch.load(path_or_model) if isinstance(path_or_model, str) else path_or_model # load checkpoint - if isinstance(model, dict): - model = model['ema' if model.get('ema') else 'model'] # load model - - hub_model = Model(model.yaml).to(next(model.parameters()).device) # create - hub_model.load_state_dict(model.float().state_dict()) # load state_dict - hub_model.names = model.names # class names - if autoshape: - hub_model = hub_model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS - device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available - return hub_model.to(device) +def custom(path='path/to/model.pt', autoshape=True, verbose=True, device=None): + # YOLOv5 custom or local model + return _create(path, autoshape=autoshape, verbose=verbose, device=device) -def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True): +def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None): # YOLOv5-small model https://github.com/ultralytics/yolov5 - return create('yolov5s', pretrained, channels, classes, autoshape, verbose) + return _create('yolov5s', pretrained, channels, classes, autoshape, verbose, device) -def yolov5m(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True): +def yolov5m(pretrained=True, channels=3, 
classes=80, autoshape=True, verbose=True, device=None): # YOLOv5-medium model https://github.com/ultralytics/yolov5 - return create('yolov5m', pretrained, channels, classes, autoshape, verbose) + return _create('yolov5m', pretrained, channels, classes, autoshape, verbose, device) -def yolov5l(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True): +def yolov5l(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None): # YOLOv5-large model https://github.com/ultralytics/yolov5 - return create('yolov5l', pretrained, channels, classes, autoshape, verbose) + return _create('yolov5l', pretrained, channels, classes, autoshape, verbose, device) -def yolov5x(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True): +def yolov5x(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None): # YOLOv5-xlarge model https://github.com/ultralytics/yolov5 - return create('yolov5x', pretrained, channels, classes, autoshape, verbose) + return _create('yolov5x', pretrained, channels, classes, autoshape, verbose, device) -def yolov5s6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True): +def yolov5s6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None): # YOLOv5-small-P6 model https://github.com/ultralytics/yolov5 - return create('yolov5s6', pretrained, channels, classes, autoshape, verbose) + return _create('yolov5s6', pretrained, channels, classes, autoshape, verbose, device) -def yolov5m6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True): +def yolov5m6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None): # YOLOv5-medium-P6 model https://github.com/ultralytics/yolov5 - return create('yolov5m6', pretrained, channels, classes, autoshape, verbose) + return _create('yolov5m6', pretrained, channels, classes, autoshape, verbose, device) -def yolov5l6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True): +def yolov5l6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None): # YOLOv5-large-P6 model https://github.com/ultralytics/yolov5 - return create('yolov5l6', pretrained, channels, classes, autoshape, verbose) + return _create('yolov5l6', pretrained, channels, classes, autoshape, verbose, device) -def yolov5x6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True): +def yolov5x6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None): # YOLOv5-xlarge-P6 model https://github.com/ultralytics/yolov5 - return create('yolov5x6', pretrained, channels, classes, autoshape, verbose) + return _create('yolov5x6', pretrained, channels, classes, autoshape, verbose, device) if __name__ == '__main__': - model = create(name='yolov5s', pretrained=True, channels=3, classes=80, autoshape=True, verbose=True) # pretrained - # model = custom(path_or_model='path/to/model.pt') # custom + model = _create(name='yolov5s', pretrained=True, channels=3, classes=80, autoshape=True, verbose=True) # pretrained + # model = custom(path='path/to/model.pt') # custom # Verify inference import cv2 diff --git a/models/common.py b/models/common.py index 9764d4c3a6c0..4211db406c3d 100644 --- a/models/common.py +++ b/models/common.py @@ -215,26 +215,28 @@ class NMS(nn.Module): conf = 0.25 # confidence threshold iou = 0.45 # IoU threshold classes = None # (optional list) filter by class + max_det = 1000 # maximum number of detections per image def __init__(self): super(NMS, 
self).__init__() def forward(self, x): - return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) + return non_max_suppression(x[0], self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det) -class autoShape(nn.Module): +class AutoShape(nn.Module): # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS conf = 0.25 # NMS confidence threshold iou = 0.45 # NMS IoU threshold classes = None # (optional list) filter by class + max_det = 1000 # maximum number of detections per image def __init__(self, model): - super(autoShape, self).__init__() + super(AutoShape, self).__init__() self.model = model.eval() def autoshape(self): - print('autoShape already enabled, skipping... ') # model already converted to model.autoshape() + print('AutoShape already enabled, skipping... ') # model already converted to model.autoshape() return self @torch.no_grad() @@ -285,7 +287,7 @@ def forward(self, imgs, size=640, augment=False, profile=False): t.append(time_synchronized()) # Post-process - y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS + y = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det) # NMS for i in range(n): scale_coords(shape1, y[i][:, :4], shape0[i]) diff --git a/models/experimental.py b/models/experimental.py index 548353c93be0..afa787907104 100644 --- a/models/experimental.py +++ b/models/experimental.py @@ -110,7 +110,9 @@ def forward(self, x, augment=False): return y, None # inference, train output -def attempt_load(weights, map_location=None): +def attempt_load(weights, map_location=None, inplace=True): + from models.yolo import Detect, Model + # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a model = Ensemble() for w in weights if isinstance(weights, list) else [weights]: @@ -120,15 +122,16 @@ def attempt_load(weights, map_location=None): # Compatibility updates for m in model.modules(): - if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: - m.inplace = True # pytorch 1.7.0 compatibility + if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]: + m.inplace = inplace # pytorch 1.7.0 compatibility elif type(m) is Conv: m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility if len(model) == 1: return model[-1] # return model else: - print('Ensemble created with %s\n' % weights) - for k in ['names', 'stride']: + print(f'Ensemble created with {weights}\n') + for k in ['names']: setattr(model, k, getattr(model[-1], k)) + model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride # max stride return model # return ensemble diff --git a/models/export.py b/models/export.py index 243819fab855..c3636ff346ec 100644 --- a/models/export.py +++ b/models/export.py @@ -1,7 +1,7 @@ -"""Exports a YOLOv5 *.pt model to ONNX and TorchScript formats +"""Exports a YOLOv5 *.pt model to TorchScript, ONNX, CoreML formats Usage: - $ export PYTHONPATH="$PWD" && python models/export.py --weights yolov5s.pt --img 640 --batch 1 + $ python path/to/models/export.py --weights yolov5s.pt --img 640 --batch 1 """ import argparse @@ -13,6 +13,7 @@ import torch import torch.nn as nn +from torch.utils.mobile_optimizer import optimize_for_mobile from sparseml.pytorch.utils import ModuleExporter from sparseml.pytorch.utils.quantization import skip_onnx_input_quantize @@ -21,7 +22,7 @@ from 
models.experimental import attempt_load from models.yolo import Model from utils.activations import Hardswish, SiLU -from utils.general import set_logging, check_img_size +from utils.general import colorstr, check_img_size, check_requirements, file_size, set_logging from utils.google_utils import attempt_download from utils.sparse import SparseMLWrapper from utils.torch_utils import select_device, intersect_dicts, is_parallel, torch_distributed_zero_first @@ -101,11 +102,18 @@ def load_checkpoint(type_, weights, device, cfg=None, hyp=None, nc=None, recipe= parser.add_argument('--weights', type=str, default='./yolov3.pt', help='weights path') # from yolov3/models/ parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size') # height, width parser.add_argument('--batch-size', type=int, default=1, help='batch size') - parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes') - parser.add_argument('--grid', action='store_true', help='export Detect() layer grid') parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') + parser.add_argument('--include', nargs='+', default=['torchscript', 'onnx', 'coreml'], help='include formats') + parser.add_argument('--half', action='store_true', help='FP16 half-precision export') + parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True') + parser.add_argument('--train', action='store_true', help='model.train() mode') + parser.add_argument('--optimize', action='store_true', help='optimize TorchScript for mobile') # TorchScript-only + parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes') # ONNX-only + parser.add_argument('--simplify', action='store_true', help='simplify ONNX model') # ONNX-only + parser.add_argument('--opset-version', type=int, default=12, help='ONNX opset version') # ONNX-only opt = parser.parse_args() opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand + opt.include = [x.lower() for x in opt.include] print(opt) set_logging() t = time.time() @@ -119,11 +127,16 @@ def load_checkpoint(type_, weights, device, cfg=None, hyp=None, nc=None, recipe= # Checks gs = int(max(model.stride)) # grid size (max stride) opt.img_size = [check_img_size(x, gs) for x in opt.img_size] # verify img_size are gs-multiples + assert not (opt.device.lower() == 'cpu' and opt.half), '--half only compatible with GPU export, i.e. use --device 0' # Input img = torch.zeros(opt.batch_size, 3, *opt.img_size).to(device) # image size(1,3,320,192) iDetection # Update model + if opt.half: + img, model = img.half(), model.half() # to FP16 + if opt.train: + model.train() # training mode (no grid construction in Detect layer) for k, m in model.named_modules(): m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility if isinstance(m, models.common.Conv): # assign export-friendly activations @@ -131,63 +144,88 @@ def load_checkpoint(type_, weights, device, cfg=None, hyp=None, nc=None, recipe= m.act = Hardswish() elif isinstance(m.act, nn.SiLU): m.act = SiLU() - # elif isinstance(m, models.yolo.Detect): - # m.forward = m.forward_export # assign forward (optional) - model.model[-1].export = not opt.grid # set Detect() layer grid export - y = model(img) # dry run - - # TorchScript export - try: - print('\nStarting TorchScript export with torch %s...' 
% torch.__version__) - f = opt.weights.replace('.pt', '.torchscript.pt') # filename - ts = torch.jit.trace(model, img, strict=False) - ts.save(f) - print('TorchScript export success, saved as %s' % f) - except Exception as e: - print('TorchScript export failure: %s' % e) - - # ONNX export - try: - import onnx - - print('\nStarting ONNX export with onnx %s...' % onnx.__version__) - f = opt.weights.replace('.pt', '.onnx') # filename - if not sparseml_wrapper.enabled: - torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'], - output_names=['classes', 'boxes'] if y is None else ['output'], - dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # size(1,3,640,640) - 'output': {0: 'batch', 2: 'y', 3: 'x'}} if opt.dynamic else None) - else: - # export through SparseML so quantized and pruned graphs can be corrected - save_dir = '/'.join(f.split('/')[:-1]) - save_name = f.split('/')[-1] - exporter = ModuleExporter(model, save_dir) - exporter.export_onnx(img, name=save_name, convert_qat=True) - try: - skip_onnx_input_quantize(f, f) - except: - pass - - # Checks - onnx_model = onnx.load(f) # load onnx model - onnx.checker.check_model(onnx_model) # check onnx model - # print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model - print('ONNX export success, saved as %s' % f) - except Exception as e: - print('ONNX export failure: %s' % e) - - # CoreML export - try: - import coremltools as ct - - print('\nStarting CoreML export with coremltools %s...' % ct.__version__) - # convert model from torchscript and apply pixel scaling as per detect.py - model = ct.convert(ts, inputs=[ct.ImageType(name='image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])]) - f = opt.weights.replace('.pt', '.mlmodel') # filename - model.save(f) - print('CoreML export success, saved as %s' % f) - except Exception as e: - print('CoreML export failure: %s' % e) + elif isinstance(m, models.yolo.Detect): + m.inplace = opt.inplace + m.onnx_dynamic = opt.dynamic + # m.forward = m.forward_export # assign forward (optional) + + for _ in range(2): + y = model(img) # dry runs + print(f"\n{colorstr('PyTorch:')} starting from {opt.weights} ({file_size(opt.weights):.1f} MB)") + + # TorchScript export ----------------------------------------------------------------------------------------------- + if 'torchscript' in opt.include or 'coreml' in opt.include: + prefix = colorstr('TorchScript:') + try: + print(f'\n{prefix} starting export with torch {torch.__version__}...') + f = opt.weights.replace('.pt', '.torchscript.pt') # filename + ts = torch.jit.trace(model, img, strict=False) + (optimize_for_mobile(ts) if opt.optimize else ts).save(f) + print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') + except Exception as e: + print(f'{prefix} export failure: {e}') + + # ONNX export ------------------------------------------------------------------------------------------------------ + if 'onnx' in opt.include: + prefix = colorstr('ONNX:') + try: + import onnx + + print(f'{prefix} starting export with onnx {onnx.__version__}...') + f = opt.weights.replace('.pt', '.onnx') # filename + if not sparseml_wrapper.enabled: + torch.onnx.export(model, img, f, verbose=False, opset_version=opt.opset_version, input_names=['images'], + dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # size(1,3,640,640) + 'output': {0: 'batch', 2: 'y', 3: 'x'}} if opt.dynamic else None) + else: + # export through SparseML so quantized and pruned graphs can be corrected + 
save_dir = '/'.join(f.split('/')[:-1]) + save_name = f.split('/')[-1] + exporter = ModuleExporter(model, save_dir) + exporter.export_onnx(img, name=save_name, convert_qat=True) + try: + skip_onnx_input_quantize(f, f) + except: + pass + + # Checks + model_onnx = onnx.load(f) # load onnx model + onnx.checker.check_model(model_onnx) # check onnx model + # print(onnx.helper.printable_graph(model_onnx.graph)) # print + + # Simplify + if opt.simplify: + try: + check_requirements(['onnx-simplifier']) + import onnxsim + + print(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...') + model_onnx, check = onnxsim.simplify( + model_onnx, + dynamic_input_shape=opt.dynamic, + input_shapes={'images': list(img.shape)} if opt.dynamic else None) + assert check, 'assert check failed' + onnx.save(model_onnx, f) + except Exception as e: + print(f'{prefix} simplifier failure: {e}') + print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') + except Exception as e: + print(f'{prefix} export failure: {e}') + + # CoreML export ---------------------------------------------------------------------------------------------------- + if 'coreml' in opt.include: + prefix = colorstr('CoreML:') + try: + import coremltools as ct + + print(f'{prefix} starting export with coremltools {ct.__version__}...') + assert opt.train, 'CoreML exports should be placed in model.train() mode with `python export.py --train`' + model = ct.convert(ts, inputs=[ct.ImageType('image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])]) + f = opt.weights.replace('.pt', '.mlmodel') # filename + model.save(f) + print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') + except Exception as e: + print(f'{prefix} export failure: {e}') # Finish print('\nExport complete (%.2fs). Visualize with https://github.com/lutzroeder/netron.' % (time.time() - t)) \ No newline at end of file diff --git a/models/yolo.py b/models/yolo.py index a8c3b3622010..474fb8c5ca24 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -25,9 +25,9 @@ class Detect(nn.Module): stride = None # strides computed during build - export = False # onnx export + onnx_dynamic = False # ONNX export parameter - def __init__(self, nc=80, anchors=(), ch=()): # detection layer + def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer super(Detect, self).__init__() self.nc = nc # number of classes self.no = nc + 5 # number of outputs per anchor @@ -38,23 +38,28 @@ def __init__(self, nc=80, anchors=(), ch=()): # detection layer self.register_buffer('anchors', a) # shape(nl,na,2) self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2) self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv + self.inplace = inplace # use in-place ops (e.g. slice assignment) def forward(self, x): # x = x.copy() # for profiling z = [] # inference output - self.training |= self.export for i in range(self.nl): x[i] = self.m[i](x[i]) # conv bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85) x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous() if not self.training: # inference - if self.grid[i].shape[2:4] != x[i].shape[2:4]: + if self.grid[i].shape[2:4] != x[i].shape[2:4] or self.onnx_dynamic: self.grid[i] = self._make_grid(nx, ny).to(x[i].device) y = x[i].sigmoid() - y[..., 0:2] = (y[..., 0:2] * 2. 
- 0.5 + self.grid[i]) * self.stride[i] # xy - y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh + if self.inplace: + y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy + y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh + else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953 + xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy + wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2) # wh + y = torch.cat((xy, wh, y[..., 4:]), -1) z.append(y.view(bs, -1, self.no)) return x if self.training else (torch.cat(z, 1), x) @@ -86,12 +91,14 @@ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, i self.yaml['anchors'] = round(anchors) # override yaml value self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist self.names = [str(i) for i in range(self.yaml['nc'])] # default names + self.inplace = self.yaml.get('inplace', True) # logger.info([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))]) # Build strides, anchors m = self.model[-1] # Detect() if isinstance(m, Detect): s = 256 # 2x min stride + m.inplace = self.inplace m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward m.anchors /= m.stride.view(-1, 1, 1) check_anchor_order(m) @@ -106,24 +113,23 @@ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, i def forward(self, x, augment=False, profile=False): if augment: - img_size = x.shape[-2:] # height, width - s = [1, 0.83, 0.67] # scales - f = [None, 3, None] # flips (2-ud, 3-lr) - y = [] # outputs - for si, fi in zip(s, f): - xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max())) - yi = self.forward_once(xi)[0] # forward - # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save - yi[..., :4] /= si # de-scale - if fi == 2: - yi[..., 1] = img_size[0] - yi[..., 1] # de-flip ud - elif fi == 3: - yi[..., 0] = img_size[1] - yi[..., 0] # de-flip lr - y.append(yi) - return torch.cat(y, 1), None # augmented inference, train + return self.forward_augment(x) # augmented inference, None else: return self.forward_once(x, profile) # single-scale inference, train + def forward_augment(self, x): + img_size = x.shape[-2:] # height, width + s = [1, 0.83, 0.67] # scales + f = [None, 3, None] # flips (2-ud, 3-lr) + y = [] # outputs + for si, fi in zip(s, f): + xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max())) + yi = self.forward_once(xi)[0] # forward + # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save + yi = self._descale_pred(yi, fi, si, img_size) + y.append(yi) + return torch.cat(y, 1), None # augmented inference, train + def forward_once(self, x, profile=False): y, dt = [], [] # outputs for m in self.model: @@ -147,6 +153,23 @@ def forward_once(self, x, profile=False): logger.info('%.1fms total' % sum(dt)) return x + def _descale_pred(self, p, flips, scale, img_size): + # de-scale predictions following augmented inference (inverse operation) + if self.inplace: + p[..., :4] /= scale # de-scale + if flips == 2: + p[..., 1] = img_size[0] - p[..., 1] # de-flip ud + elif flips == 3: + p[..., 0] = img_size[1] - p[..., 0] # de-flip lr + else: + x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale # de-scale + if flips == 2: + y = img_size[0] - y # de-flip ud + elif flips == 3: + x = img_size[1] - x # de-flip lr + p = 
torch.cat((x, y, wh, p[..., 4:]), -1) + return p + def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency # https://arxiv.org/abs/1708.02002 section 3.3 # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1. @@ -193,9 +216,9 @@ def nms(self, mode=True): # add or remove NMS module self.model = self.model[:-1] # remove return self - def autoshape(self): # add autoShape module - logger.info('Adding autoShape... ') - m = autoShape(self) # wrap model + def autoshape(self): # add AutoShape module + logger.info('Adding AutoShape... ') + m = AutoShape(self) # wrap model copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes return m diff --git a/requirements.txt b/requirements.txt index 7d5581435bb4..f651925f3a67 100755 --- a/requirements.txt +++ b/requirements.txt @@ -21,10 +21,11 @@ pandas # export -------------------------------------- # coremltools>=4.1 -# onnx>=1.8.1 +# onnx>=1.9.0 # scikit-learn==0.19.2 # for coreml quantization # extras -------------------------------------- -thop # FLOPS computation +# Cython # for pycocotools https://github.com/cocodataset/cocoapi/issues/172 pycocotools>=2.0 # COCO mAP sparseml==0.3.1 # Pruning and Quantization +thop # FLOPS computation diff --git a/test.py b/test.py index fa30b04438f6..4196dfe1ef40 100644 --- a/test.py +++ b/test.py @@ -188,8 +188,8 @@ def test(data, # Per target class for cls in torch.unique(tcls_tensor): - ti = (cls == tcls_tensor).nonzero(as_tuple=False).view(-1) # prediction indices - pi = (cls == pred[:, 5]).nonzero(as_tuple=False).view(-1) # target indices + ti = (cls == tcls_tensor).nonzero(as_tuple=False).view(-1) # target indices + pi = (cls == pred[:, 5]).nonzero(as_tuple=False).view(-1) # prediction indices # Search for detections if pi.shape[0]: @@ -310,7 +310,7 @@ def test(data, opt.save_json |= opt.data.endswith('coco.yaml') opt.data = check_file(opt.data) # check file print(opt) - check_requirements() + check_requirements(exclude=('tensorboard', 'pycocotools', 'thop')) if opt.task in ('train', 'val', 'test'): # run normally test(opt.data, diff --git a/train.py b/train.py index 30ff4974e7a7..4d3a5f87620a 100644 --- a/train.py +++ b/train.py @@ -346,9 +346,9 @@ def train(hyp, opt, device, tb_writer=None): if plots and ni < 3: f = save_dir / f'train_batch{ni}.jpg' # filename Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start() - # if tb_writer: - # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch) - # tb_writer.add_graph(torch.jit.trace(model, imgs, strict=False), []) # add model graph + if tb_writer: + tb_writer.add_graph(torch.jit.trace(model, imgs, strict=False), []) # add model graph + # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch) elif plots and ni == 10 and wandb_logger.wandb: wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in save_dir.glob('train*.jpg') if x.exists()]}) @@ -385,8 +385,6 @@ def train(hyp, opt, device, tb_writer=None): # Write with open(results_file, 'a') as f: f.write(s + '%10.4g' * 7 % results + '\n') # append metrics, val_loss - if len(opt.name) and opt.bucket: - os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name)) # Log tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss @@ -436,7 +434,7 @@ def train(hyp, opt, device, tb_writer=None): # Test best.pt logger.info('%g epochs completed in %.3f hours.\n' % 
(epoch - start_epoch + 1, (time.time() - t0) / 3600)) if opt.data.endswith('coco.yaml') and nc == 80: # if COCO - for m in (last, best) if best.exists() else (last): # speed, mAP tests + for m in [last, best] if best.exists() else [last]: # speed, mAP tests test_model, _ = load_checkpoint('ensemble', m, device) results, _, _ = test.test(opt.data, batch_size=batch_size * 2, @@ -462,7 +460,7 @@ def train(hyp, opt, device, tb_writer=None): if wandb_logger.wandb and not opt.evolve: # Log the stripped model wandb_logger.wandb.log_artifact(str(final), type='model', name='run_' + wandb_logger.wandb_run.id + '_model', - aliases=['last', 'best', 'stripped']) + aliases=['latest', 'best', 'stripped']) wandb_logger.finish_run() else: dist.destroy_process_group() @@ -517,7 +515,7 @@ def train(hyp, opt, device, tb_writer=None): set_logging(opt.global_rank) if opt.global_rank in [-1, 0]: check_git_status() - check_requirements() + check_requirements(exclude=('pycocotools', 'thop')) # Resume wandb_run = check_wandb_resume(opt) @@ -546,6 +544,7 @@ def train(hyp, opt, device, tb_writer=None): device = torch.device('cuda', opt.local_rank) dist.init_process_group(backend='nccl', init_method='env://') # distributed backend assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count' + assert not opt.image_weights, '--image-weights argument is not compatible with DDP training' opt.batch_size = opt.total_batch_size // opt.world_size # Hyperparameters diff --git a/tutorial.ipynb b/tutorial.ipynb index 245b46aa7d9f..3954feadfcb2 100644 --- a/tutorial.ipynb +++ b/tutorial.ipynb @@ -563,7 +563,7 @@ "clear_output()\n", "print(f\"Setup complete. Using torch {torch.__version__} ({torch.cuda.get_device_properties(0).name if torch.cuda.is_available() else 'CPU'})\")" ], - "execution_count": 31, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -889,7 +889,7 @@ "id": "bOy5KI2ncnWd" }, "source": [ - "# Tensorboard (optional)\n", + "# Tensorboard (optional)\n", "%load_ext tensorboard\n", "%tensorboard --logdir runs/train" ], @@ -902,9 +902,10 @@ "id": "2fLAV42oNb7M" }, "source": [ - "# Weights & Biases (optional)\n", - "%pip install -q wandb \n", - "!wandb login # use 'wandb disabled' or 'wandb enabled' to disable or enable" + "# Weights & Biases (optional)\n", + "%pip install -q wandb\n", + "import wandb\n", + "wandb.login()" ], "execution_count": null, "outputs": [] diff --git a/utils/autoanchor.py b/utils/autoanchor.py index 75b350da729c..87dc394c832e 100644 --- a/utils/autoanchor.py +++ b/utils/autoanchor.py @@ -3,7 +3,6 @@ import numpy as np import torch import yaml -from scipy.cluster.vq import kmeans from tqdm import tqdm from utils.general import colorstr @@ -76,6 +75,8 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10 Usage: from utils.autoanchor import *; _ = kmean_anchors() """ + from scipy.cluster.vq import kmeans + thr = 1. 
/ thr prefix = colorstr('autoanchor: ') diff --git a/utils/datasets.py b/utils/datasets.py index 3fcdddd7c013..36416b14e138 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -172,7 +172,7 @@ def __next__(self): ret_val, img0 = self.cap.read() self.frame += 1 - print(f'video {self.count + 1}/{self.nf} ({self.frame}/{self.nframes}) {path}: ', end='') + print(f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: ', end='') else: # Read image @@ -193,7 +193,7 @@ def __next__(self): def new_video(self, path): self.frame = 0 self.cap = cv2.VideoCapture(path) - self.nframes = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) + self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) def __len__(self): return self.nf # number of files @@ -270,7 +270,7 @@ def __init__(self, sources='streams.txt', img_size=640, stride=32): sources = [sources] n = len(sources) - self.imgs = [None] * n + self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n self.sources = [clean_str(x) for x in sources] # clean source names for later for i, s in enumerate(sources): # index, source # Start thread to read frames from video stream @@ -284,12 +284,13 @@ def __init__(self, sources='streams.txt', img_size=640, stride=32): assert cap.isOpened(), f'Failed to open {s}' w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - self.fps = cap.get(cv2.CAP_PROP_FPS) % 100 + self.fps[i] = max(cap.get(cv2.CAP_PROP_FPS) % 100, 0) or 30.0 # 30 FPS fallback + self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf') # infinite stream fallback _, self.imgs[i] = cap.read() # guarantee first frame - thread = Thread(target=self.update, args=([i, cap]), daemon=True) - print(f' success ({w}x{h} at {self.fps:.2f} FPS).') - thread.start() + self.threads[i] = Thread(target=self.update, args=([i, cap]), daemon=True) + print(f" success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)") + self.threads[i].start() print('') # newline # check for common shapes @@ -298,18 +299,17 @@ def __init__(self, sources='streams.txt', img_size=640, stride=32): if not self.rect: print('WARNING: Different stream shapes detected. 
For optimal performance supply similarly-shaped streams.')

-    def update(self, index, cap):
-        # Read next stream frame in a daemon thread
-        n = 0
-        while cap.isOpened():
+    def update(self, i, cap):
+        # Read stream `i` frames in daemon thread
+        n, f = 0, self.frames[i]
+        while cap.isOpened() and n < f:
             n += 1
             # _, self.imgs[index] = cap.read()
             cap.grab()
-            if n == 4:  # read every 4th frame
+            if n % 4 == 0:  # read every 4th frame
                 success, im = cap.retrieve()
-                self.imgs[index] = im if success else self.imgs[index] * 0
-                n = 0
-            time.sleep(1 / self.fps)  # wait time
+                self.imgs[i] = im if success else self.imgs[i] * 0
+            time.sleep(1 / self.fps[i])  # wait time

     def __iter__(self):
         self.count = -1
@@ -317,12 +317,12 @@ def __iter__(self):

     def __next__(self):
         self.count += 1
-        img0 = self.imgs.copy()
-        if cv2.waitKey(1) == ord('q'):  # q to quit
+        if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'):  # q to quit
             cv2.destroyAllWindows()
             raise StopIteration

         # Letterbox
+        img0 = self.imgs.copy()
         img = [letterbox(x, self.img_size, auto=self.rect, stride=self.stride)[0] for x in img0]

         # Stack
@@ -490,20 +490,23 @@ def cache_labels(self, path=Path('./labels.cache'), prefix=''):
                 x[im_file] = [l, shape, segments]
             except Exception as e:
                 nc += 1
-                print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
+                logging.info(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')

         pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \
                     f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
         pbar.close()

         if nf == 0:
-            print(f'{prefix}WARNING: No labels found in {path}. See {help_url}')
+            logging.info(f'{prefix}WARNING: No labels found in {path}. See {help_url}')

         x['hash'] = get_hash(self.label_files + self.img_files)
         x['results'] = nf, nm, ne, nc, i + 1
         x['version'] = 0.1  # cache version
-        torch.save(x, path)  # save for next time
-        logging.info(f'{prefix}New cache created: {path}')
+        try:
+            torch.save(x, path)  # save for next time
+            logging.info(f'{prefix}New cache created: {path}')
+        except Exception as e:
+            logging.info(f'{prefix}WARNING: Cache directory {path.parent} is not writeable: {e}')  # path not writeable
         return x

     def __len__(self):
diff --git a/utils/flask_rest_api/README.md b/utils/flask_rest_api/README.md
index 0cdc51be692d..324c2416dcd9 100644
--- a/utils/flask_rest_api/README.md
+++ b/utils/flask_rest_api/README.md
@@ -1,5 +1,5 @@
 # Flask REST API
-[REST](https://en.wikipedia.org/wiki/Representational_state_transfer) [API](https://en.wikipedia.org/wiki/API)s are commonly used to expose Machine Learning (ML) models to other services. This folder contains an example REST API created using Flask to expose the `yolov5s` model from [PyTorch Hub](https://pytorch.org/hub/ultralytics_yolov5/).
+[REST](https://en.wikipedia.org/wiki/Representational_state_transfer) [API](https://en.wikipedia.org/wiki/API)s are commonly used to expose Machine Learning (ML) models to other services. This folder contains an example REST API created using Flask to expose the YOLOv5s model from [PyTorch Hub](https://pytorch.org/hub/ultralytics_yolov5/).
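+
+Before the setup steps, here is a minimal Python client sketch in the spirit of `example_request.py` (a sketch only; it assumes the API is already running locally on the default port, and `zidane.jpg` is any test image in the working directory):
+
+```python
+import pprint
+
+import requests
+
+DETECTION_URL = "http://localhost:5000/v1/object-detection/yolov5s"
+TEST_IMAGE = "zidane.jpg"  # any local test image
+
+# POST the image as multipart form data under the `image` field the API expects
+with open(TEST_IMAGE, "rb") as f:
+    response = requests.post(DETECTION_URL, files={"image": f})
+
+pprint.pprint(response.json())  # one dict per detected object
+```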
diff --git a/utils/flask_rest_api/README.md b/utils/flask_rest_api/README.md
index 0cdc51be692d..324c2416dcd9 100644
--- a/utils/flask_rest_api/README.md
+++ b/utils/flask_rest_api/README.md
@@ -1,5 +1,5 @@
 # Flask REST API
-[REST](https://en.wikipedia.org/wiki/Representational_state_transfer) [API](https://en.wikipedia.org/wiki/API)s are commonly used to expose Machine Learning (ML) models to other services. This folder contains an example REST API created using Flask to expose the `yolov5s` model from [PyTorch Hub](https://pytorch.org/hub/ultralytics_yolov5/).
+[REST](https://en.wikipedia.org/wiki/Representational_state_transfer) [API](https://en.wikipedia.org/wiki/API)s are commonly used to expose Machine Learning (ML) models to other services. This folder contains an example REST API created using Flask to expose the YOLOv5s model from [PyTorch Hub](https://pytorch.org/hub/ultralytics_yolov5/).

 ## Requirements
@@ -22,30 +22,47 @@ Then use [curl](https://curl.se/) to perform a request:
 $ curl -X POST -F image=@zidane.jpg 'http://localhost:5000/v1/object-detection/yolov5s'
 ```

-The model inference results are returned:
+The model inference results are returned as a JSON response:

-```shell
-[{'class': 0,
-  'confidence': 0.8197850585,
-  'name': 'person',
-  'xmax': 1159.1403808594,
-  'xmin': 750.912902832,
-  'ymax': 711.2583007812,
-  'ymin': 44.0350036621},
- {'class': 0,
-  'confidence': 0.5667674541,
-  'name': 'person',
-  'xmax': 1065.5523681641,
-  'xmin': 116.0448303223,
-  'ymax': 713.8904418945,
-  'ymin': 198.4603881836},
- {'class': 27,
-  'confidence': 0.5661227107,
-  'name': 'tie',
-  'xmax': 516.7975463867,
-  'xmin': 416.6880187988,
-  'ymax': 717.0524902344,
-  'ymin': 429.2020568848}]
+```json
+[
+  {
+    "class": 0,
+    "confidence": 0.8900438547,
+    "height": 0.9318675399,
+    "name": "person",
+    "width": 0.3264600933,
+    "xcenter": 0.7438579798,
+    "ycenter": 0.5207948685
+  },
+  {
+    "class": 0,
+    "confidence": 0.8440024257,
+    "height": 0.7155083418,
+    "name": "person",
+    "width": 0.6546785235,
+    "xcenter": 0.427829951,
+    "ycenter": 0.6334488392
+  },
+  {
+    "class": 27,
+    "confidence": 0.3771208823,
+    "height": 0.3902671337,
+    "name": "tie",
+    "width": 0.0696444362,
+    "xcenter": 0.3675483763,
+    "ycenter": 0.7991207838
+  },
+  {
+    "class": 27,
+    "confidence": 0.3527112305,
+    "height": 0.1540903747,
+    "name": "tie",
+    "width": 0.0336618312,
+    "xcenter": 0.7814827561,
+    "ycenter": 0.5065554976
+  }
+]
 ```

 An example Python script to perform inference using [requests](https://docs.python-requests.org/en/master/) is given in `example_request.py`.
diff --git a/utils/flask_rest_api/restapi.py b/utils/flask_rest_api/restapi.py
index 9d88f618905d..a54e2309715c 100644
--- a/utils/flask_rest_api/restapi.py
+++ b/utils/flask_rest_api/restapi.py
@@ -24,15 +24,14 @@ def predict():
         img = Image.open(io.BytesIO(image_bytes))

-        results = model(img, size=640)
-        data = results.pandas().xyxy[0].to_json(orient="records")
-        return data
+        results = model(img, size=640)  # reduce size=320 for faster inference
+        return results.pandas().xyxy[0].to_json(orient="records")


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Flask api exposing yolov5 model")
+    parser = argparse.ArgumentParser(description="Flask API exposing YOLOv5 model")
     parser.add_argument("--port", default=5000, type=int, help="port number")
     args = parser.parse_args()

-    model = torch.hub.load("ultralytics/yolov5", "yolov5s", force_reload=True).autoshape()  # force_reload to recache
+    model = torch.hub.load("ultralytics/yolov5", "yolov5s", force_reload=True)  # force_reload to recache
     app.run(host="0.0.0.0", port=args.port)  # debug=True causes Restarting with stat
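To round out the Flask changes: querying the API from Python mirrors the curl call in the README above. A minimal sketch along the lines of the referenced `example_request.py` (the URL is the documented default; the image path is a placeholder for any local test image):

```python
import pprint

import requests

DETECTION_URL = 'http://localhost:5000/v1/object-detection/yolov5s'
IMAGE = 'zidane.jpg'  # placeholder: any local test image

# POST the image as the 'image' form field, matching the curl example
with open(IMAGE, 'rb') as f:
    response = requests.post(DETECTION_URL, files={'image': f}).json()

pprint.pprint(response)  # list of detections, as in the README sample
```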
diff --git a/utils/general.py b/utils/general.py
index 3edf899d7411..7e0ac772bb03 100755
--- a/utils/general.py
+++ b/utils/general.py
@@ -16,6 +16,7 @@
 import cv2
 import numpy as np
 import pandas as pd
+import pkg_resources as pkg
 import torch
 import torchvision
 import yaml
@@ -51,11 +52,20 @@ def get_latest_run(search_dir='.'):
     return max(last_list, key=os.path.getctime) if last_list else ''


-def isdocker():
+def is_docker():
     # Is environment a Docker container
     return Path('/workspace').exists()  # or Path('/.dockerenv').exists()


+def is_colab():
+    # Is environment a Google Colab instance
+    try:
+        import google.colab
+        return True
+    except ImportError:
+        return False
+
+
 def emojis(str=''):
     # Return platform-dependent emoji-safe version of string
     return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
@@ -81,7 +91,7 @@ def check_git_status():
     print(colorstr('github: '), end='')
     try:
         assert Path('.git').exists(), 'skipping check (not a git repository)'
-        assert not isdocker(), 'skipping check (Docker image)'
+        assert not is_docker(), 'skipping check (Docker image)'
         assert check_online(), 'skipping check (offline)'

         cmd = 'git fetch && git config --get remote.origin.url'
@@ -98,10 +108,19 @@ def check_git_status():
         print(e)


+def check_python(minimum='3.7.0', required=True):
+    # Check current python version vs. required python version
+    current = platform.python_version()
+    result = pkg.parse_version(current) >= pkg.parse_version(minimum)
+    if required:
+        assert result, f'Python {minimum} required by YOLOv5, but Python {current} is currently installed'
+    return result
+
+
 def check_requirements(requirements='requirements.txt', exclude=()):
     # Check installed dependencies meet requirements (pass *.txt file or list of packages)
-    import pkg_resources as pkg
     prefix = colorstr('red', 'bold', 'requirements:')
+    check_python()  # check python version
     if isinstance(requirements, (str, Path)):  # requirements.txt file
         file = Path(requirements)
         if not file.exists():
@@ -118,7 +137,10 @@ def check_requirements(requirements='requirements.txt', exclude=()):
         except Exception as e:  # DistributionNotFound or VersionConflict if requirements not met
             n += 1
             print(f"{prefix} {r} not found and is required by YOLOv5, attempting auto-update...")
-            print(subprocess.check_output(f"pip install '{r}'", shell=True).decode())
+            try:
+                print(subprocess.check_output(f"pip install '{r}'", shell=True).decode())
+            except Exception as e:
+                print(f'{prefix} {e}')

     if n:  # if packages updated
         source = file.resolve() if 'file' in locals() else requirements
@@ -138,7 +160,8 @@ def check_img_size(img_size, s=32):
 def check_imshow():
     # Check if environment supports image displays
     try:
-        assert not isdocker(), 'cv2.imshow() is disabled in Docker environments'
+        assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
+        assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
         cv2.imshow('test', np.zeros((1, 1, 3)))
         cv2.waitKey(1)
         cv2.destroyAllWindows()
@@ -183,25 +206,34 @@ def check_dataset(dict):
             raise Exception('Dataset not found.')


-def download(url, dir='.', multi_thread=False):
+def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
     # Multi-threaded file download and unzip function
     def download_one(url, dir):
         # Download 1 file
         f = dir / Path(url).name  # filename
         if not f.exists():
             print(f'Downloading {url} to {f}...')
-            torch.hub.download_url_to_file(url, f, progress=True)  # download
-        if f.suffix in ('.zip', '.gz'):
+            if curl:
+                os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -")  # curl download, retry and resume on fail
+            else:
+                torch.hub.download_url_to_file(url, f, progress=True)  # torch download
+        if unzip and f.suffix in ('.zip', '.gz'):
             print(f'Unzipping {f}...')
             if f.suffix == '.zip':
-                os.system(f'unzip -qo {f} -d {dir} && rm {f}')  # unzip -quiet -overwrite
+                s = f'unzip -qo {f} -d {dir}'  # unzip -quiet -overwrite
             elif f.suffix == '.gz':
-                os.system(f'tar xfz {f} --directory {f.parent} && rm {f}')  # unzip
+                s = f'tar xfz {f} --directory {f.parent}'  # unzip
+            if delete:  # delete zip file after unzip
+                s += f' && rm {f}'
+            os.system(s)

     dir = Path(dir)
     dir.mkdir(parents=True, exist_ok=True)  # make directory
-    if multi_thread:
-        ThreadPool(8).imap(lambda x: download_one(*x), zip(url, repeat(dir)))  # 8 threads
+    if threads > 1:
+        pool = ThreadPool(threads)
+        pool.imap(lambda x: download_one(*x), zip(url, repeat(dir)))  # multi-threaded
+        pool.close()
+        pool.join()
     else:
         for u in [url] if isinstance(url, str) else url:
             download_one(u, dir)


@@ -453,7 +485,7 @@ def wh_iou(wh1, wh2):

 def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
-                        labels=()):
+                        labels=(), max_det=300):
     """Runs Non-Maximum Suppression (NMS) on inference results

     Returns:
@@ -463,9 +495,12 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
     nc = prediction.shape[2] - 5  # number of classes
     xc = prediction[..., 4] > conf_thres  # candidates

+    # Checks
+    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
+    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
+
     # Settings
     min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
-    max_det = 300  # maximum number of detections per image
     max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
     time_limit = 10.0  # seconds to quit after
     redundant = True  # require redundant detections
@@ -627,8 +662,8 @@ def apply_classifier(x, model, img, im0):
     return x


-def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False, BGR=False):
-    # Save an image crop as {file} with crop size multiplied by {gain} and padded by {pad} pixels
+def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False, BGR=False, save=True):
+    # Save image crop as {file} with crop size multiplied by {gain} and padded by {pad} pixels; save and/or return crop
     xyxy = torch.tensor(xyxy).view(-1, 4)
     b = xyxy2xywh(xyxy)  # boxes
     if square:
@@ -636,8 +671,10 @@ def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False, BG
         b[:, 2:] = b[:, 2:] * gain + pad  # box wh * gain + pad
     xyxy = xywh2xyxy(b).long()
     clip_coords(xyxy, im.shape)
-    crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2])]
-    cv2.imwrite(str(increment_path(file, mkdir=True).with_suffix('.jpg')), crop if BGR else crop[..., ::-1])
+    crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
+    if save:
+        cv2.imwrite(str(increment_path(file, mkdir=True).with_suffix('.jpg')), crop)
+    return crop


 def increment_path(path, exist_ok=False, sep='', mkdir=False):
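For reference, the reworked `download()` above is what the dataset-YAML download scripts call into. A minimal usage sketch (the URLs and target directory are placeholders):

```python
from pathlib import Path

from utils.general import download

urls = ['https://example.com/images.zip',  # placeholder URLs
        'https://example.com/labels.zip']
# threads=2 exercises the ThreadPool branch; delete=False keeps the archives after unzipping
download(urls, dir=Path('../datasets/example'), unzip=True, delete=False, curl=False, threads=2)
```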
diff --git a/utils/google_utils.py b/utils/google_utils.py
index 6a4660bad509..63d3e5b212f3 100644
--- a/utils/google_utils.py
+++ b/utils/google_utils.py
@@ -21,6 +21,7 @@ def attempt_download(file, repo='ultralytics/yolov5'):
     file = Path(str(file).strip().replace("'", ''))

     if not file.exists():
+        file.parent.mkdir(parents=True, exist_ok=True)  # make parent dir (if required)
         try:
             response = requests.get(f'https://api.github.com/repos/{repo}/releases/latest').json()  # github api
             assets = [x['name'] for x in response['assets']]  # release assets, i.e. ['yolov5s.pt', 'yolov5m.pt', ...]
@@ -47,7 +48,7 @@ def attempt_download(file, repo='ultralytics/yolov5'):
             assert redundant, 'No secondary mirror'
             url = f'https://storage.googleapis.com/{repo}/ckpt/{name}'
             print(f'Downloading {url} to {file}...')
-            os.system(f'curl -L {url} -o {file}')  # torch.hub.download_url_to_file(url, weights)
+            os.system(f"curl -L '{url}' -o '{file}' --retry 3 -C -")  # curl download, retry and resume on fail
         finally:
             if not file.exists() or file.stat().st_size < 1E6:  # check
                 file.unlink(missing_ok=True)  # remove partial downloads
diff --git a/utils/plots.py b/utils/plots.py
index ab6448aa96eb..8313ef210f90 100644
--- a/utils/plots.py
+++ b/utils/plots.py
@@ -16,7 +16,6 @@
 import torch
 import yaml
 from PIL import Image, ImageDraw, ImageFont
-from scipy.signal import butter, filtfilt

 from utils.general import xywh2xyxy, xyxy2xywh
 from utils.metrics import fitness
@@ -29,7 +28,10 @@
 class Colors:
     # Ultralytics color palette https://ultralytics.com/
     def __init__(self):
-        self.palette = [self.hex2rgb(c) for c in matplotlib.colors.TABLEAU_COLORS.values()]
+        # hex = matplotlib.colors.TABLEAU_COLORS.values()
+        hex = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
+               '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
+        self.palette = [self.hex2rgb('#' + c) for c in hex]
         self.n = len(self.palette)

     def __call__(self, i, bgr=False):
@@ -54,6 +56,8 @@ def hist2d(x, y, n=100):


 def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
+    from scipy.signal import butter, filtfilt
+
     # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
     def butter_lowpass(cutoff, fs, order):
         nyq = 0.5 * fs
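On the fixed palette above: `Colors` cycles a 20-entry color table by class index. A small sketch of the call pattern (the instance is constructed directly here for illustration):

```python
from utils.plots import Colors

colors = Colors()  # 20 fixed colors, cycled by index
for c in (0, 5, 19, 20):  # index 20 wraps back to the first color
    print(colors(c), colors(c, bgr=True))  # RGB tuple, and BGR for OpenCV drawing
```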
@@ -64,11 +68,10 @@ def butter_lowpass(cutoff, fs, order):
     return filtfilt(b, a, data)  # forward-backward filter


-def plot_one_box(x, im, color=None, label=None, line_thickness=3):
+def plot_one_box(x, im, color=(128, 128, 128), label=None, line_thickness=3):
     # Plots one bounding box on image 'im' using OpenCV
     assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to plot_one_box() input image.'
     tl = line_thickness or round(0.002 * (im.shape[0] + im.shape[1]) / 2) + 1  # line/font thickness
-    color = color or [random.randint(0, 255) for _ in range(3)]
     c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
     cv2.rectangle(im, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
     if label:
@@ -79,17 +82,16 @@ def plot_one_box(x, im, color=None, label=None, line_thickness=3):
         cv2.putText(im, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)


-def plot_one_box_PIL(box, im, color=None, label=None, line_thickness=None):
+def plot_one_box_PIL(box, im, color=(128, 128, 128), label=None, line_thickness=None):
     # Plots one bounding box on image 'im' using PIL
     im = Image.fromarray(im)
     draw = ImageDraw.Draw(im)
     line_thickness = line_thickness or max(int(min(im.size) / 200), 2)
-    draw.rectangle(box, width=line_thickness, outline=tuple(color))  # plot
+    draw.rectangle(box, width=line_thickness, outline=color)  # plot
     if label:
-        fontsize = max(round(max(im.size) / 40), 12)
-        font = ImageFont.truetype("Arial.ttf", fontsize)
+        font = ImageFont.truetype("Arial.ttf", size=max(round(max(im.size) / 40), 12))
         txt_width, txt_height = font.getsize(label)
-        draw.rectangle([box[0], box[1] - txt_height + 4, box[0] + txt_width, box[1]], fill=tuple(color))
+        draw.rectangle([box[0], box[1] - txt_height + 4, box[0] + txt_width, box[1]], fill=color)
         draw.text((box[0], box[1] - txt_height + 1), label, fill=(255, 255, 255), font=font)
     return np.asarray(im)

@@ -295,7 +297,8 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
     # matplotlib labels
     matplotlib.use('svg')  # faster
     ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
-    ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
+    y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
+    # [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)]  # update colors bug #3195
     ax[0].set_ylabel('instances')
     if 0 < len(names) < 30:
         ax[0].set_xticks(range(len(names)))
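The drawing helpers above now take an explicit RGB default instead of a random per-call color. A quick sketch of direct usage (the image and box are synthetic):

```python
import numpy as np

from utils.plots import plot_one_box

im = np.zeros((640, 640, 3), dtype=np.uint8)  # blank, contiguous canvas
plot_one_box([100, 100, 300, 400], im, color=(255, 56, 56), label='person 0.89')
```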
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index 06b902954655..36360136e891 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -72,11 +72,12 @@ def select_device(device='', batch_size=None):

     cuda = not cpu and torch.cuda.is_available()
     if cuda:
-        n = torch.cuda.device_count()
-        if n > 1 and batch_size:  # check that batch_size is compatible with device_count
+        devices = device.split(',') if device else range(torch.cuda.device_count())  # i.e. 0,1,6,7
+        n = len(devices)  # device count
+        if n > 1 and batch_size:  # check batch_size is divisible by device_count
             assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
         space = ' ' * len(s)
-        for i, d in enumerate(device.split(',') if device else range(n)):
+        for i, d in enumerate(devices):
             p = torch.cuda.get_device_properties(i)
             s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n"  # bytes to MB
     else:
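On the device parsing above: `select_device` accepts a comma-separated index string and now checks batch divisibility against the parsed list. A sketch (the multi-GPU call only succeeds on a host with at least two visible GPUs):

```python
import torch

from utils.torch_utils import select_device

device = select_device('cpu')  # force CPU; safe everywhere
if torch.cuda.device_count() >= 2:
    device = select_device('0,1', batch_size=64)  # batch 64 must divide by 2 GPUs
```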
diff --git a/utils/wandb_logging/wandb_utils.py b/utils/wandb_logging/wandb_utils.py
index d8fbd1ef42aa..57ce9035a777 100644
--- a/utils/wandb_logging/wandb_utils.py
+++ b/utils/wandb_logging/wandb_utils.py
@@ -1,3 +1,4 @@
+"""Utilities and tools for tracking runs with Weights & Biases."""
 import json
 import sys
 from pathlib import Path
@@ -9,7 +10,7 @@
 sys.path.append(str(Path(__file__).parent.parent.parent))  # add utils/ to path
 from utils.datasets import LoadImagesAndLabels
 from utils.datasets import img2label_paths
-from utils.general import colorstr, xywh2xyxy, check_dataset
+from utils.general import colorstr, xywh2xyxy, check_dataset, check_file

 try:
     import wandb
@@ -35,8 +36,9 @@ def get_run_info(run_path):
     run_path = Path(remove_prefix(run_path, WANDB_ARTIFACT_PREFIX))
     run_id = run_path.stem
     project = run_path.parent.stem
+    entity = run_path.parent.parent.stem
     model_artifact_name = 'run_' + run_id + '_model'
-    return run_id, project, model_artifact_name
+    return entity, project, run_id, model_artifact_name


 def check_wandb_resume(opt):
@@ -44,9 +46,9 @@ def check_wandb_resume(opt):
     if isinstance(opt.resume, str):
         if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
             if opt.global_rank not in [-1, 0]:  # For resuming DDP runs
-                run_id, project, model_artifact_name = get_run_info(opt.resume)
+                entity, project, run_id, model_artifact_name = get_run_info(opt.resume)
                 api = wandb.Api()
-                artifact = api.artifact(project + '/' + model_artifact_name + ':latest')
+                artifact = api.artifact(entity + '/' + project + '/' + model_artifact_name + ':latest')
                 modeldir = artifact.download()
                 opt.weights = str(Path(modeldir) / "last.pt")
             return True
@@ -54,7 +56,7 @@ def check_wandb_resume(opt):


 def process_wandb_config_ddp_mode(opt):
-    with open(opt.data) as f:
+    with open(check_file(opt.data)) as f:
         data_dict = yaml.safe_load(f)  # data dict
     train_dir, val_dir = None, None
     if isinstance(data_dict['train'], str) and data_dict['train'].startswith(WANDB_ARTIFACT_PREFIX):
@@ -78,6 +80,18 @@ def process_wandb_config_ddp_mode(opt):


 class WandbLogger():
+    """Log training runs, datasets, models, and predictions to Weights & Biases.
+
+    This logger sends information to W&B at wandb.ai. By default, this information
+    includes hyperparameters, system configuration and metrics, model metrics,
+    and basic data metrics and analyses.
+
+    By providing additional command line arguments to train.py, datasets,
+    models and predictions can also be logged.
+
+    For more on how this logger is used, see the Weights & Biases documentation:
+    https://docs.wandb.com/guides/integrations/yolov5
+    """
     def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
         # Pre-training routine --
         self.job_type = job_type
@@ -85,16 +99,17 @@ def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
        # It's more elegant to stick to 1 wandb.init call, but useful config data is overwritten in the WandbLogger's wandb.init call
         if isinstance(opt.resume, str):  # checks resume from artifact
             if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
-                run_id, project, model_artifact_name = get_run_info(opt.resume)
+                entity, project, run_id, model_artifact_name = get_run_info(opt.resume)
                 model_artifact_name = WANDB_ARTIFACT_PREFIX + model_artifact_name
                 assert wandb, 'install wandb to resume wandb runs'
                 # Resume wandb-artifact:// runs here| workaround for not overwriting wandb.config
-                self.wandb_run = wandb.init(id=run_id, project=project, resume='allow')
+                self.wandb_run = wandb.init(id=run_id, project=project, entity=entity, resume='allow')
                 opt.resume = model_artifact_name
         elif self.wandb:
             self.wandb_run = wandb.init(config=opt,
                                         resume="allow",
                                         project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
+                                        entity=opt.entity,
                                         name=name,
                                         job_type=job_type,
                                         id=run_id) if not wandb.run else wandb.run
@@ -115,7 +130,7 @@ def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
     def check_and_upload_dataset(self, opt):
         assert wandb, 'Install wandb to upload dataset'
         check_dataset(self.data_dict)
-        config_path = self.log_dataset_artifact(opt.data,
+        config_path = self.log_dataset_artifact(check_file(opt.data),
                                                 opt.single_cls,
                                                 'YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem)
         print("Created dataset config file ", config_path)
@@ -158,7 +173,8 @@ def setup_training(self, opt, data_dict):

     def download_dataset_artifact(self, path, alias):
         if isinstance(path, str) and path.startswith(WANDB_ARTIFACT_PREFIX):
-            dataset_artifact = wandb.use_artifact(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
+            artifact_path = Path(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
+            dataset_artifact = wandb.use_artifact(artifact_path.as_posix())
             assert dataset_artifact is not None, "Error: W&B dataset artifact doesn't exist"
             datadir = dataset_artifact.download()
             return datadir, dataset_artifact
@@ -171,8 +187,8 @@ def download_dataset_artifact(self, path, alias):
             modeldir = model_artifact.download()
             epochs_trained = model_artifact.metadata.get('epochs_trained')
             total_epochs = model_artifact.metadata.get('total_epochs')
-            assert epochs_trained < total_epochs, 'training to %g epochs is finished, nothing to resume.' % (
-                total_epochs)
+            is_finished = total_epochs is None
+            assert not is_finished, 'training is finished, can only resume incomplete runs.'
             return modeldir, model_artifact
         return None, None

@@ -187,7 +203,7 @@ def log_model(self, path, opt, epoch, fitness_score, best_model=False):
         })
         model_artifact.add_file(str(path / 'last.pt'), name='last.pt')
         wandb.log_artifact(model_artifact,
-                           aliases=['latest', 'epoch ' + str(self.current_epoch), 'best' if best_model else ''])
+                           aliases=['latest', 'last', 'epoch ' + str(self.current_epoch), 'best' if best_model else ''])
         print("Saving model artifact on epoch ", epoch + 1)

     def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False):
@@ -196,9 +212,9 @@ def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=
         nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names'])
         names = {k: v for k, v in enumerate(names)}  # to index dictionary
         self.train_artifact = self.create_dataset_table(LoadImagesAndLabels(
-            data['train']), names, name='train') if data.get('train') else None
+            data['train'], rect=True, batch_size=1), names, name='train') if data.get('train') else None
         self.val_artifact = self.create_dataset_table(LoadImagesAndLabels(
-            data['val']), names, name='val') if data.get('val') else None
+            data['val'], rect=True, batch_size=1), names, name='val') if data.get('val') else None
         if data.get('train'):
             data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'train')
         if data.get('val'):
@@ -243,16 +259,12 @@ def create_dataset_table(self, dataset, class_to_id, name='dataset'):
         table = wandb.Table(columns=["id", "train_image", "Classes", "name"])
         class_set = wandb.Classes([{'id': id, 'name': name} for id, name in class_to_id.items()])
         for si, (img, labels, paths, shapes) in enumerate(tqdm(dataset)):
-            height, width = shapes[0]
-            labels[:, 2:] = (xywh2xyxy(labels[:, 2:].view(-1, 4))) * torch.Tensor([width, height, width, height])
             box_data, img_classes = [], {}
-            for cls, *xyxy in labels[:, 1:].tolist():
+            for cls, *xywh in labels[:, 1:].tolist():
                 cls = int(cls)
-                box_data.append({"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
+                box_data.append({"position": {"middle": [xywh[0], xywh[1]], "width": xywh[2], "height": xywh[3]},
                                  "class_id": cls,
-                                 "box_caption": "%s" % (class_to_id[cls]),
-                                 "scores": {"acc": 1},
-                                 "domain": "pixel"})
+                                 "box_caption": "%s" % (class_to_id[cls])})
                 img_classes[cls] = class_to_id[cls]
             boxes = {"ground_truth": {"box_data": box_data, "class_labels": class_to_id}}  # inference-space
             table.add_data(si, wandb.Image(paths, classes=class_set, boxes=boxes), json.dumps(img_classes),
@@ -294,7 +306,7 @@ def end_epoch(self, best_result=False):
             if self.result_artifact:
                 train_results = wandb.JoinedTable(self.val_table, self.result_table, "id")
                 self.result_artifact.add(train_results, 'result')
-                wandb.log_artifact(self.result_artifact, aliases=['latest', 'epoch ' + str(self.current_epoch),
+                wandb.log_artifact(self.result_artifact, aliases=['latest', 'last', 'epoch ' + str(self.current_epoch),
                                                                   ('best' if best_result else '')])
                 self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
                 self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
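Finally, a usage note on the entity-aware run paths introduced above: a resume string now has the form `wandb-artifact://<entity>/<project>/<run_id>` as passed to train.py via `--resume`. A minimal sketch of how `get_run_info()` decomposes it (the entity, project, and run id below are placeholders, and the prefix constant is assumed to be `'wandb-artifact://'` as defined elsewhere in `wandb_utils.py`):

```python
from pathlib import Path

WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'  # assumed value of the module constant

resume = WANDB_ARTIFACT_PREFIX + 'my-entity/my-project/3aq1x9p2'  # hypothetical run path
run_path = Path(resume[len(WANDB_ARTIFACT_PREFIX):])  # what remove_prefix() yields
entity, project, run_id = run_path.parent.parent.stem, run_path.parent.stem, run_path.stem
model_artifact_name = 'run_' + run_id + '_model'
print(entity, project, run_id, model_artifact_name)
# -> my-entity my-project 3aq1x9p2 run_3aq1x9p2_model
```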