In [None]:
"""
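Inference code for YOLOv5: offline evaluation pipeline.
Runs batched inference over the BDD100K test split and collects per-image
detection statistics for mAP@0.5 and mAP@0.5:0.95.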
"""

In [None]:
from collections import namedtuple
from pathlib import Path

import torch

from detection_core import batch_inference
# TODO(review): this cell previously ended in a bare `from` (a SyntaxError). The cells below also use
# yaml, tqdm, create_dataloader, ConfusionMatrix, time_synchronized, scale_coords, xywh2xyxy, box_iou
# and make_inference, none of which are imported here (presumably YOLOv5's utils package) — add them.

In [None]:
# Paths and runtime settings for this evaluation run.
data_yaml = 'data/bdd100k.yaml'  # dataset description; 'nc' and 'test' are read from it below
weight = 'runs/train/exp11/weights/best.pt'  # checkpoint to evaluate
batch_size = 32
imgsz = 640  # square warm-up size and dataloader image size
device = 'cuda:0'  # inputs and model are cast to fp16 below, so a CUDA device is expected

# Evaluation options, exposed as an attribute-style namedtuple (`opt.augment`, ...) because that is
# what the dataloader and the inference loop receive. Built in one expression so `opt` never holds
# a plain dict at any point.
# NOTE(review): conf_thres / iou_thres are not read by the cells shown here — presumably consumed by
# the inference/NMS helper; confirm.
opt = namedtuple("Opt", ("conf_thres", "iou_thres", "augment", "verbose",
                         "save_txt", "save_conf", "save_json", "single_cls"))(
    conf_thres=0.001,
    iou_thres=0.6,
    augment=True,       # forwarded as make_inference(..., augment=opt.augment)
    verbose=False,
    save_txt=True,
    save_conf=False,
    save_json=False,
    single_cls=False,
)

In [None]:
# load data
with open(data_yaml) as f:
    data = yaml.safe_load(f)  # dataset dict ('nc', 'test', ...); plain config, so no custom tags needed

nc = data['nc']  # number of classes
iouv = torch.linspace(0.5, 0.95, 10).to(device)  # iou vector for mAP@0.5:0.95
niou = iouv.numel()

# Load model. `weight` from the config cell was otherwise unused and `model` was never defined,
# so this cell raised NameError on a fresh kernel.
# NOTE(review): assumes a YOLOv5-style checkpoint dict holding an nn.Module under 'ema' or 'model'
# (the keys YOLOv5's attempt_load reads). torch.load unpickles arbitrary objects: load trusted files
# only, with the YOLOv5 `models` package importable; on torch>=2.6 also pass weights_only=False.
ckpt = torch.load(weight, map_location=device)
model = ckpt['ema' if ckpt.get('ema') else 'model'].float().eval().half()

# Dataloader
# Warm up once with a dummy batch. A dedicated name keeps `img` free for the inference loop,
# and no_grad avoids building an autograd graph for this throwaway pass.
warmup_img = torch.zeros((1, 3, imgsz, imgsz), device=device)
with torch.no_grad():
    _ = model(warmup_img.half())  # run once
path = data['test']
dataloader = create_dataloader(path, imgsz, batch_size, model.stride.max(), opt, pad=0.5, rect=True)[0]

In [None]:
# define metrics
# State consumed by the inference loop: a confusion matrix over `nc` classes, a class-id -> name
# lookup, the tqdm column header, zeroed scalar metrics/timers, and the per-image stats list.
confusion_matrix = ConfusionMatrix(nc=nc)
names = dict(enumerate(model.names))
s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
p = r = f1 = mp = mr = map50 = map = t0 = t1 = 0.  # NOTE: `map` shadows the builtin from here on
seen = 0
stats = []

In [None]:
# run inference
# For every test batch: run the detector, then for each image mark which predictions are true
# positives at each IoU threshold in `iouv`, appending (correct, conf, pred_cls, target_cls) to `stats`.
for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
    img = img.to(device, non_blocking=True)
    img = img.half()  # to fp16, matching the half-precision model warm-up
    img /= 255.0  # 0 - 255 to 0.0 - 1.0
    targets = targets.to(device)
    nb, _, height, width = img.shape  # batch size, channels, height, width

    # Target rows are [image_index, cls, x, y, w, h]. Scale the box columns to pixels of the
    # letterboxed batch so scale_coords() below maps them to native space exactly like the
    # predictions; previously height/width were unpacked but never applied.
    # NOTE(review): assumes the dataloader yields normalized xywh, as YOLOv5's does — confirm.
    targets[:, 2:] *= torch.tensor([width, height, width, height], dtype=targets.dtype, device=device)

    with torch.no_grad():
        # Run model.
        # TODO(review): `make_inference` is not defined or imported in this notebook; the only
        # imported helper is detection_core.batch_inference — confirm which one is intended.
        t = time_synchronized()
        output = make_inference(img, augment=opt.augment)  # inference
        # Time the call once. `output` is consumed below as post-NMS rows
        # [x1, y1, x2, y2, conf, cls], so NMS presumably runs inside make_inference; adding the
        # same interval to t1 as well double-counted the inference time.
        t0 += time_synchronized() - t

    # Statistics per image
    for si, pred in enumerate(output):
        labels = targets[targets[:, 0] == si, 1:]  # rows for image si: [cls, x, y, w, h]
        n_label = len(labels)
        t_cls = labels[:, 0].tolist() if n_label else []  # target class
        path = Path(paths[si])  # not used further in this cell
        seen += 1

        if len(pred) == 0:
            # No detections: every target of this image counts as a miss.
            if n_label:
                stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), t_cls))
            continue

        # Predictions, rescaled to the image's native resolution
        predn = pred.clone()
        scale_coords(img[si].shape[1:], predn[:, :4], shapes[si][0], shapes[si][1])  # native-space pred

        # Assign all predictions as incorrect
        correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool, device=device)
        if n_label:  # was `nl`, which is never defined in this notebook (NameError)
            detected = []  # target indices
            tcls_tensor = labels[:, 0]

            # target boxes
            tbox = xywh2xyxy(labels[:, 1:5])
            scale_coords(img[si].shape[1:], tbox, shapes[si][0], shapes[si][1])  # native-space labels

            # Per target class
            for cls in torch.unique(tcls_tensor):
                ti = (cls == tcls_tensor).nonzero(as_tuple=False).view(-1)  # target indices
                pi = (cls == pred[:, 5]).nonzero(as_tuple=False).view(-1)  # prediction indices

                # Search for detections
                if pi.shape[0]:
                    # Prediction to target ious
                    ious, i = box_iou(predn[pi, :4], tbox[ti]).max(1)  # best ious, indices

                    # Append detections; each target is matched by at most one prediction
                    detected_set = set()
                    for j in (ious > iouv[0]).nonzero(as_tuple=False):
                        d = ti[i[j]]  # detected target
                        if d.item() not in detected_set:
                            detected_set.add(d.item())
                            detected.append(d)
                            correct[pi[j]] = ious[j] > iouv  # iou_thres is 1xn
                            if len(detected) == n_label:  # all targets already located in image
                                break

        # Append statistics (correct, conf, pcls, tcls)
        stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), t_cls))
        