In [33]:
import torch
import numpy as np
import os

class Predictor:
    def __init__(self, imgs, modelPath='data/yoloModels/best.pt'):
        self.imgs = imgs
        self.model = torch.hub.load('ultralytics/yolov5','custom', path=modelPath, verbose=False)
        self.centers = []
        self.results = [] # [[[xmin, xmax, ymin, ymax, certainty, classID, className, [centerX, centerY]], ...], ...]
        self.getPredictions()


    def __str__(self):
        return f'{len(self.imgs)} images, {self.name}'

    def getPredictions(self):
        for i, img in enumerate(self.imgs):
            currResult = self.model(img).pandas().xyxy[0].to_numpy()
            self.results.append([])
            for j, res in enumerate(currResult):
                self.results[-1].append([])
                for k, val in enumerate(res):
                    self.results[-1][-1].append(val)
                
                center = self.getCenter(*self.results[-1][-1][:4])
                self.results[-1][-1].append(center)


    def getCenter(self, xmin, xmax, ymin, ymax):
        return np.array([(xmin + xmax)/2, (ymin + ymax)/2]).astype(int)


In [34]:
path = 'data\conveyorImages'
imgs = [os.path.join(path, f) for f in os.listdir(path) if f.endswith('.png')]


pred = Predictor(imgs=imgs)


YOLOv5  2022-4-27 torch 1.11.0+cpu CPU

Fusing layers... 
Model summary: 213 layers, 7012822 parameters, 0 gradients
Adding AutoShape... 


In [35]:
pred.results

[[[856.2462158203125,
   310.13323974609375,
   998.8065185546875,
   426.71905517578125,
   0.9314594268798828,
   0,
   'box',
   array([583, 712])]],
 [[732.5384521484375,
   334.01318359375,
   891.630859375,
   474.98284912109375,
   0.9456979632377625,
   0,
   'box',
   array([533, 683])]],
 [[628.7694091796875,
   357.84063720703125,
   793.7801513671875,
   512.0861206054688,
   0.9429858922958374,
   0,
   'box',
   array([493, 652])]],
 [[575.2089233398438,
   368.7104187011719,
   741.6259155273438,
   529.5492553710938,
   0.9529181122779846,
   0,
   'box',
   array([471, 635])]]]