In [55]:
import tensorrt as trt

In [56]:
import onnxruntime as ort

# Path to the exported ONNX model that the inference session loads.
onnx_model_path = "./yolon_no_prune.onnx"
# Alternative execution providers; swap one of these in to run on CUDA or on the CPU instead:
# session = ort.InferenceSession(onnx_model_path, providers=['CUDAExecutionProvider'])
# session = ort.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider'])
# Active configuration: only the TensorRT execution provider is requested.
session = ort.InferenceSession(onnx_model_path, providers=['TensorrtExecutionProvider'])

In [57]:
# Directory that is scanned for the input images (.png / .jpg / .jpeg) to run inference on.
image_folder_path = "./JPEGImages"

In [58]:
import numpy as np
import torch
from torchvision.ops import nms

def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """Convert anchor-relative (left, top, right, bottom) distances into boxes.

    Args:
        distance: tensor whose `dim` axis has size 4; the first two entries are
            the distances to the top-left corner, the last two the distances
            to the bottom-right corner.
        anchor_points: tensor broadcastable against each half of `distance`.
        xywh: when True return (center_x, center_y, width, height); otherwise
            return the two corners (x1, y1, x2, y2).
        dim: axis that holds the four distances.

    Returns:
        Tensor with the box coordinates concatenated along `dim`.
    """
    # Split into the (left, top) and (right, bottom) halves.
    dist_lt, dist_rb = torch.split(distance, 2, dim)
    top_left = anchor_points - dist_lt
    bottom_right = anchor_points + dist_rb
    if not xywh:
        return torch.cat((top_left, bottom_right), dim)  # xyxy bbox
    center = (top_left + bottom_right) / 2
    size = bottom_right - top_left
    return torch.cat((center, size), dim)  # xywh bbox


def decode_box(num_classes, input_shape, dbox, cls, anchors, strides):
    """Decode raw head outputs into normalised (cx, cy, w, h, class scores) rows.

    Args:
        num_classes: unused here; kept so the call signature stays stable.
        input_shape: (height, width) of the network input, used to normalise
            the box coordinates.
        dbox: distance predictions with the four ltrb values on axis 1.
        cls: raw class logits; a sigmoid is applied here.
        anchors: anchor points; a leading batch axis is added before decoding.
        strides: per-anchor stride that scales the decoded boxes
            (presumably back to input pixels - confirm against the exporter).

    Returns:
        Tensor with axes 1 and 2 swapped relative to the inputs, whose first
        four channels are the boxes divided by (w, h, w, h) of the input and
        whose remaining channels are the sigmoid class scores.
    """
    boxes_xywh = dist2bbox(dbox, anchors.unsqueeze(0), xywh=True, dim=1) * strides
    class_scores = cls.sigmoid()
    y = torch.cat((boxes_xywh, class_scores), 1).permute(0, 2, 1)

    in_h, in_w = input_shape[0], input_shape[1]
    normaliser = torch.Tensor([in_w, in_h, in_w, in_h]).to(y.device)
    y[:, :, :4] = y[:, :, :4] / normaliser
    return y

def yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image):
    """Map normalised network-space boxes back to pixel boxes on the original image.

    Args:
        box_xy: array (..., 2) of box centres as (x, y), normalised to the
            network input.
        box_wh: array (..., 2) of box sizes as (w, h), normalised to the
            network input. Neither input array is modified.
        input_shape: (height, width) of the network input.
        image_shape: (height, width) of the original image.
        letterbox_image: True when the image was letterboxed (aspect-preserving
            resize plus padding) before inference, so the padding has to be
            removed from the coordinates.

    Returns:
        Array (..., 4) ordered (y_min, x_min, y_max, x_max) and scaled by
        `image_shape`.
    """
    # Work in (y, x) / (h, w) order so the values line up with the
    # (height, width) shapes. These are views of the caller's arrays.
    box_yx = box_xy[..., ::-1]
    box_hw = box_wh[..., ::-1]
    input_shape = np.array(input_shape)
    image_shape = np.array(image_shape)

    if letterbox_image:
        # Size of the resized image content inside the padded network input.
        new_shape = np.round(image_shape * np.min(input_shape / image_shape))
        offset = (input_shape - new_shape) / 2. / input_shape
        scale = input_shape / new_shape

        box_yx = (box_yx - offset) * scale
        # Out-of-place on purpose: `box_hw` is a view of the caller's `box_wh`,
        # so `*=` would silently rewrite the caller's data (and raises for
        # integer arrays).
        box_hw = box_hw * scale

    box_mins = box_yx - (box_hw / 2.)
    box_maxes = box_yx + (box_hw / 2.)
    boxes = np.concatenate([box_mins[..., 0:2], box_maxes[..., 0:2]], axis=-1)
    # `boxes` is a fresh array, so scaling it in place is safe.
    boxes *= np.concatenate([image_shape, image_shape], axis=-1)
    return boxes

def non_max_suppression(prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4):
    """Filter decoded predictions by confidence and apply per-class NMS.

    Args:
        prediction: tensor indexed [image, anchor, channel]; channels 0-3 are
            (cx, cy, w, h) and channels 4 .. 4 + num_classes - 1 are class
            scores. The caller's tensor is left unmodified.
        num_classes: number of class-score channels.
        input_shape: (height, width) of the network input.
        image_shape: (height, width) of the original image.
        letterbox_image: whether the input was letterboxed; forwarded to
            `yolo_correct_boxes`.
        conf_thres: minimum best-class score for an anchor to be kept.
        nms_thres: IoU threshold handed to `nms`.

    Returns:
        A list with one entry per image: None when no anchor passes
        `conf_thres`, otherwise a numpy array with columns
        (box[0..3], class_conf, class_index), where the box columns are what
        `yolo_correct_boxes` returns.
    """
    # Convert (cx, cy, w, h) to corner form on a private copy so the caller's
    # tensor is not rewritten as a side effect.
    prediction = prediction.clone()
    centers = prediction[:, :, 0:2].clone()
    half_sizes = prediction[:, :, 2:4] / 2
    prediction[:, :, 0:2] = centers - half_sizes
    prediction[:, :, 2:4] = centers + half_sizes

    output = [None for _ in range(len(prediction))]
    for i, image_pred in enumerate(prediction):
        class_conf, class_pred = torch.max(image_pred[:, 4:4 + num_classes], 1, keepdim=True)
        # Keep the mask 1-D. Squeezing it turns a single-anchor mask into a
        # 0-d bool, which adds an axis when used as an index and breaks the
        # concatenation below.
        conf_mask = class_conf[:, 0] >= conf_thres
        image_pred = image_pred[conf_mask]
        class_conf = class_conf[conf_mask]
        class_pred = class_pred[conf_mask]
        if not image_pred.size(0):
            continue

        # Columns: four box corners, best class score, best class index.
        detections = torch.cat((image_pred[:, :4], class_conf.float(), class_pred.float()), 1)
        unique_labels = detections[:, -1].unique()

        # NMS is run independently for every class present in this image.
        for c in unique_labels:
            detections_class = detections[detections[:, -1] == c]
            keep = nms(detections_class[:, :4], detections_class[:, 4], nms_thres)
            max_detections = detections_class[keep]
            output[i] = max_detections if output[i] is None else torch.cat((output[i], max_detections))

        if output[i] is not None:
            output[i] = output[i].cpu().numpy()
            # Back to centre/size form, which is what yolo_correct_boxes expects.
            box_xy, box_wh = (output[i][:, 0:2] + output[i][:, 2:4]) / 2, output[i][:, 2:4] - output[i][:, 0:2]
            output[i][:, :4] = yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
    return output

    

In [59]:
# Network input size as (height, width); images are letterboxed to this size before inference.
input_shape = [1280, 1280]
# Number of detection classes, i.e. how many class-score channels are read during post-processing.
num_classes = 7

In [60]:
import os
import time
import cv2
import numpy as np
from tqdm import tqdm  # Import tqdm for the progress bar
from PIL import ImageDraw, ImageFont, Image
import colorsys

# Assuming you have already loaded the model and created a session
# session = ort.InferenceSession(onnx_model_path)
def preprocess_input(image):
    """Scale pixel values by 1/255 and return the result as a new array.

    The division is done out of place, so the caller's array is left untouched
    and integer inputs (e.g. uint8) work too; an in-place `/=` would raise a
    casting error for those.

    Args:
        image: numeric array of pixel values.

    Returns:
        A new array holding `image / 255.0`.
    """
    return image / 255.0

def resize_image(image, size, letterbox_image):
    """Resize a PIL image to `size`, optionally keeping its aspect ratio.

    Args:
        image: PIL image to resize.
        size: target (width, height).
        letterbox_image: when truthy, scale the image to fit inside `size`
            without distortion and centre it on a grey (128, 128, 128) RGB
            canvas; otherwise stretch it straight to `size`.

    Returns:
        The resized PIL image.
    """
    src_w, src_h = image.size
    target_w, target_h = size

    if not letterbox_image:
        return image.resize((target_w, target_h), Image.BICUBIC)

    # Largest uniform scale at which the whole image still fits the target.
    ratio = min(target_w / src_w, target_h / src_h)
    fit_w = int(src_w * ratio)
    fit_h = int(src_h * ratio)

    fitted = image.resize((fit_w, fit_h), Image.BICUBIC)
    canvas = Image.new('RGB', size, (128, 128, 128))
    # Centre the content; the remaining border stays grey.
    canvas.paste(fitted, ((target_w - fit_w) // 2, (target_h - fit_h) // 2))
    return canvas

def preprocess_image(image_path):
    """Load an image file and turn it into the network input array.

    The image is letterboxed to the module-level `input_shape`
    (height, width), passed through `preprocess_input`, reordered from HWC to
    CHW, and given a leading batch axis.

    Args:
        image_path: path of the image file to load.

    Returns:
        float32 array with axes (1, channel, height, width).
    """
    # The context manager closes the file handle even if resizing fails.
    with Image.open(image_path) as img:
        letterboxed = resize_image(img, (input_shape[1], input_shape[0]), True)
    pixels = preprocess_input(np.array(letterboxed, dtype='float32'))
    # HWC -> CHW, then add the batch dimension.
    return np.expand_dims(np.transpose(pixels, (2, 0, 1)), 0)
import re

def natural_sort_key(s):
    """Build a sort key that orders embedded numbers numerically ("2" before "10").

    The string is split on runs of decimal digits. Because the pattern has a
    capture group, the pieces alternate text, number, text, ... so the
    odd-indexed pieces are exactly the digit runs. Those become ints and the
    text pieces are lower-cased for a case-insensitive comparison.

    Args:
        s: string to build the key for (e.g. a file path).

    Returns:
        List alternating lower-cased text pieces and ints.
    """
    # Decide by position rather than str.isdigit(): isdigit() is also true for
    # characters such as superscript digits, which int() rejects.
    return [
        int(piece) if index % 2 else piece.lower()
        for index, piece in enumerate(re.split(r'(\d+)', s))
    ]

# Collect the input images in natural order; the extension check is
# case-insensitive so files such as "IMG_1.JPG" are picked up too.
image_paths = sorted(
    [os.path.join(image_folder_path, f) for f in os.listdir(image_folder_path)
     if f.lower().endswith(('.png', '.jpg', '.jpeg'))],
    key=natural_sort_key,
)
print(image_paths)
image_paths = image_paths[:1000]

# One colour per detection class, evenly spaced around the hue circle. These
# depend only on num_classes, so they are built once rather than per image.
hsv_tuples = [(x / num_classes, 1., 1.) for x in range(num_classes)]
colors = [tuple(int(channel * 255) for channel in colorsys.hsv_to_rgb(*hsv)) for hsv in hsv_tuples]

# Overlay colour for each segmentation class id; ids not listed stay transparent.
MASK_COLORS = {
    1: (255, 0, 0),  # red
    2: (0, 255, 0),  # green
    3: (0, 0, 255),  # blue
}


def _build_mask_overlay(mask_scores, image_shape, input_shape):
    """Turn the segmentation output into an RGBA overlay the size of the original image.

    `mask_scores` is indexed [batch, class, row, col]; only the first batch
    element is used. The crop assumes the mask has the same resolution as the
    network input - TODO confirm against the exported model.
    """
    # Take the argmax on the raw scores. Casting to an integer type first would
    # truncate them toward zero and can change which class wins.
    class_map = torch.argmax(mask_scores, 1)[0].cpu().detach().numpy()

    mask_rgb = np.zeros((*class_map.shape, 3), dtype=np.uint8)
    for class_id, color in MASK_COLORS.items():
        mask_rgb[class_map == class_id] = color

    # Undo the letterbox with the same arithmetic as resize_image, so the crop
    # matches where the image content was actually pasted (landscape or portrait).
    image_h, image_w = int(image_shape[0]), int(image_shape[1])
    input_h, input_w = input_shape[0], input_shape[1]
    scale = min(input_w / image_w, input_h / image_h)
    content_w, content_h = int(image_w * scale), int(image_h * scale)
    pad_left, pad_top = (input_w - content_w) // 2, (input_h - content_h) // 2
    content = np.ascontiguousarray(mask_rgb[pad_top:pad_top + content_h, pad_left:pad_left + content_w])
    mask_rgb = cv2.resize(content, (image_w, image_h), interpolation=cv2.INTER_LINEAR)

    # Half-transparent where any class colour is present, fully transparent elsewhere.
    alpha = ((mask_rgb > 0).any(axis=2) * 128).astype(np.uint8)
    return Image.fromarray(np.dstack((mask_rgb, alpha)))


def _draw_detections(image, top_label, top_conf, top_boxes, colors, input_shape):
    """Draw an outlined box and a '<class> <score>' caption for every detection, in place."""
    # Font size and border thickness scale with the image size.
    font_size = int(np.floor(3e-2 * image.size[1] + 0.5))
    font = ImageFont.truetype(font='./simhei.ttf', size=font_size)
    thickness = int(max((image.size[0] + image.size[1]) // np.mean(input_shape), 1))
    draw = ImageDraw.Draw(image)

    for det_idx, class_id in enumerate(top_label):
        predicted_class = int(class_id)
        score = top_conf[det_idx]
        top, left, bottom, right = top_boxes[det_idx]

        # Clamp the box to the image bounds.
        top = max(0, int(np.floor(top)))
        left = max(0, int(np.floor(left)))
        bottom = min(image.size[1], int(np.floor(bottom)))
        right = min(image.size[0], int(np.floor(right)))

        label = '{} {:.2f}'.format(predicted_class, score)
        if hasattr(draw, 'textbbox'):
            # ImageDraw.textsize was removed in Pillow 10. The right/bottom
            # edge of the bounding box at the origin gives the same extent.
            _, _, text_w, text_h = draw.textbbox((0, 0), label, font=font)
        else:
            text_w, text_h = draw.textsize(label, font)
        text_w, text_h = int(np.ceil(text_w)), int(np.ceil(text_h))
        # Printed as bytes to keep the log format of earlier runs.
        print(label.encode('utf-8'), top, left, bottom, right)

        # Put the caption above the box when it fits, otherwise just inside it.
        text_x = left
        text_y = top - text_h if top - text_h >= 0 else top + 1

        color = colors[predicted_class]
        for inset in range(thickness):
            x0, y0, x1, y1 = left + inset, top + inset, right - inset, bottom - inset
            # Stop once the insets cross; newer Pillow rejects inverted rectangles.
            if x1 < x0 or y1 < y0:
                break
            draw.rectangle([x0, y0, x1, y1], outline=color)
        draw.rectangle([text_x, text_y, text_x + text_w, text_y + text_h], fill=color)
        draw.text((text_x, text_y), label, fill=(0, 0, 0), font=font)
    del draw


# Make sure the destination for the annotated images exists.
os.makedirs('./output', exist_ok=True)

total_time = 0.0
k = 0
# Process each image and perform inference with a progress bar
for image_path in tqdm(image_paths, desc="Processing images"):
    input_image = preprocess_image(image_path)
    image = Image.open(image_path)
    image_shape = np.array(np.shape(image)[0:2])  # (height, width)

    # Timed section: inference, post-processing, drawing and saving.
    start_time = time.perf_counter()
    # Assumes the model's input tensor is named 'input'.
    raw_outputs = [torch.tensor(arr) for arr in session.run(None, {'input': input_image})]
    # Author's note on the output order: 0 and 1 are correct, 6 corresponds to 4, 5 corresponds to 3.
    predictions = decode_box(num_classes, input_shape, raw_outputs[0], raw_outputs[1], raw_outputs[5], raw_outputs[6])
    results = non_max_suppression(predictions, num_classes, input_shape,
                                  image_shape, True, conf_thres=0.5, nms_thres=0.3)

    # Images without any detection are not saved.
    if results[0] is not None:
        top_label = np.array(results[0][:, 5], dtype='int32')
        top_conf = results[0][:, 4]
        top_boxes = results[0][:, :4]
        print(image_path)
        print(f'image_shape: {image_shape}')
        print(f'top_label: {top_label}')
        print(f'top_conf: {top_conf}')
        print(f'top_boxes: {top_boxes}')

        _draw_detections(image, top_label, top_conf, top_boxes, colors, input_shape)

        # Blend the segmentation overlay exactly once per image, so its opacity
        # does not depend on how many boxes were detected.
        mask_image = _build_mask_overlay(raw_outputs[7], image_shape, input_shape)
        image = image.convert("RGBA")
        image.alpha_composite(mask_image)
        image.save(f'./output/{k}.png')
        k = k + 1

    # Accumulate for every image, including ones without detections, so the
    # FPS figure is not inflated by untimed images.
    total_time = total_time + (time.perf_counter() - start_time)


# Guard against an empty folder, where no time was accumulated.
fps = len(image_paths) / total_time if total_time > 0 else 0.0

print(f"Processed {len(image_paths)} images in {total_time:.2f} seconds.")
print(f"Frame per second (FPS): {fps:.2f}")

['./JPEGImages/0.jpg', './JPEGImages/1.jpg', './JPEGImages/2.jpg', './JPEGImages/3.jpg', './JPEGImages/4.jpg', './JPEGImages/5.jpg', './JPEGImages/6.jpg', './JPEGImages/7.jpg', './JPEGImages/8.jpg', './JPEGImages/9.jpg', './JPEGImages/10.jpg', './JPEGImages/11.jpg', './JPEGImages/12.jpg']


Processing images:   0%|          | 0/13 [00:00<?, ?it/s]

Processing images:   8%|▊         | 1/13 [00:00<00:05,  2.30it/s]

./JPEGImages/0.jpg
image_shape: [ 720 1280]
top_label: [0]
top_conf: [0.7588528]
top_boxes: [[330.1515  463.5061  346.01993 482.37976]]
b'0 0.76' 330 463 346 482


Processing images:  15%|█▌        | 2/13 [00:00<00:04,  2.31it/s]

./JPEGImages/1.jpg
image_shape: [ 720 1280]
top_label: [0]
top_conf: [0.77763134]
top_boxes: [[403.4912  332.62305 442.713   375.2893 ]]
b'0 0.78' 403 332 442 375
./JPEGImages/2.jpg
image_shape: [ 720 1280]
top_label: [0 0 0 0]
top_conf: [0.90948296 0.89358085 0.8629413  0.7074271 ]
top_boxes: [[327.31598  -1.47192 484.00757 226.59546]
 [362.62933 456.36084 459.98343 581.25635]
 [342.83203 245.34608 393.43063 304.55188]
 [334.53024 195.3917  370.79248 240.9744 ]]
b'0 0.91' 327 0 484 226
b'0 0.89' 362 456 459 581
b'0 0.86' 342 245 393 304
b'0 0.71' 334 195 370 240


Processing images:  23%|██▎       | 3/13 [00:01<00:05,  1.96it/s]

./JPEGImages/3.jpg
image_shape: [1080 1920]
top_label: [0 0 0 0 0 0 1]
top_conf: [0.86952496 0.86128706 0.79200983 0.7785965  0.7571234  0.6416642
 0.7644294 ]
top_boxes: [[ 513.69556 1419.8953   613.6967  1667.4802 ]
 [ 521.3985   859.8302   615.79767  985.80334]
 [ 525.99896  388.0211   588.63226  474.7704 ]
 [ 529.89044  784.23486  562.3664   823.19   ]
 [ 531.3195   737.2389   556.40454  764.218  ]
 [ 525.63354  535.7323   554.7002   579.95245]
 [ 524.3351  1151.9467   606.939   1193.6678 ]]
b'0 0.87' 513 1419 613 1667
b'0 0.86' 521 859 615 985
b'0 0.79' 525 388 588 474
b'0 0.78' 529 784 562 823
b'0 0.76' 531 737 556 764
b'0 0.64' 525 535 554 579
b'1 0.76' 524 1151 606 1193


Processing images:  38%|███▊      | 5/13 [00:02<00:04,  1.76it/s]

./JPEGImages/4.jpg
image_shape: [ 720 1280]
top_label: [0]
top_conf: [0.74688995]
top_boxes: [[380.5364  580.53345 395.8494  598.0051 ]]
b'0 0.75' 380 580 395 598


Processing images:  46%|████▌     | 6/13 [00:03<00:03,  1.91it/s]

./JPEGImages/5.jpg
image_shape: [ 720 1280]
top_label: [0]
top_conf: [0.5746191]
top_boxes: [[445.75278 446.6018  463.13177 467.0672 ]]
b'0 0.57' 445 446 463 467
./JPEGImages/6.jpg
image_shape: [1088 1920]
top_label: [0 0 1 2]
top_conf: [0.8759316  0.86156726 0.8493998  0.67359865]
top_boxes: [[ 556.484    710.46185  681.69415  850.79114]
 [ 577.0404   545.7759   684.98285  680.75415]
 [ 552.46356  906.0999   648.14197  986.18756]
 [ 551.38403 1079.0781   631.5888  1105.9963 ]]
b'0 0.88' 556 710 681 850
b'0 0.86' 577 545 684 680
b'1 0.85' 552 906 648 986
b'2 0.67' 551 1079 631 1105


Processing images:  54%|█████▍    | 7/13 [00:03<00:03,  1.66it/s]

./JPEGImages/7.jpg
image_shape: [1080 1920]
top_label: [0 0 0 0 0 0 0 0 0 0 4 5]
top_conf: [0.9145428  0.8965528  0.8744245  0.8733102  0.8489684  0.84777987
 0.8039182  0.77353615 0.6918869  0.57517165 0.8654269  0.89768445]
top_boxes: [[4.78953674e+02 2.12649246e+02 6.29408203e+02 4.82037506e+02]
 [5.62563721e+02 1.24140755e+02 6.73785645e+02 3.43844452e+02]
 [5.32186951e+02 9.15501526e+02 6.57321777e+02 1.04398474e+03]
 [5.53872070e+02 6.66079346e+02 6.85565186e+02 8.33956360e+02]
 [5.42225647e+02 3.85875031e+02 6.37761353e+02 5.21197754e+02]
 [5.37519409e+02 7.23595619e-01 6.54739624e+02 1.16539795e+02]
 [5.56659607e+02 7.98532471e+02 6.18690735e+02 8.63140320e+02]
 [5.55472290e+02 7.16643219e+01 6.57645813e+02 2.11347153e+02]
 [5.51677307e+02 5.08522156e+02 6.20945679e+02 5.54957825e+02]
 [5.58800354e+02 1.03104968e+03 6.15208313e+02 1.05857983e+03]
 [3.84687195e+02 9.25074524e+02 4.34500458e+02 9.42665344e+02]
 [3.82269928e+02 1.02110663e+03 4.34565186e+02 1.04556995e+03]]
b'0 0.

Processing images:  62%|██████▏   | 8/13 [00:04<00:03,  1.44it/s]

./JPEGImages/8.jpg
image_shape: [1080 1920]
top_label: [0 0 0 0]
top_conf: [0.9120898  0.8793356  0.87576884 0.68811953]
top_boxes: [[488.19662     1.9689274 758.20984   328.40012  ]
 [506.0826    555.47406   622.6379    691.66766  ]
 [460.23618   752.111     606.8192    917.6165   ]
 [524.45746   723.18994   550.1065    749.72485  ]]
b'0 0.91' 488 1 758 328
b'0 0.88' 506 555 622 691
b'0 0.88' 460 752 606 917
b'0 0.69' 524 723 550 749


Processing images:  69%|██████▉   | 9/13 [00:05<00:02,  1.52it/s]

./JPEGImages/9.jpg
image_shape: [1080 1920]
top_label: [0 0 0]
top_conf: [0.87558335 0.85386    0.7444073 ]
top_boxes: [[ 334.07336   1408.817      460.43317   1544.2117   ]
 [ 337.09003    966.3564     410.7061    1053.1396   ]
 [ 334.12927      1.6483784  472.5522     141.2091   ]]
b'0 0.88' 334 1408 460 1544
b'0 0.85' 337 966 410 1053
b'0 0.74' 334 1 472 141


Processing images:  77%|███████▋  | 10/13 [00:06<00:02,  1.46it/s]

./JPEGImages/10.jpg
image_shape: [1080 1920]
top_label: [0 6]
top_conf: [0.87161624 0.5254349 ]
top_boxes: [[ 529.59973  981.4095   607.31726 1072.2589 ]
 [ 236.51384  181.82945  278.41318  244.2766 ]]
b'0 0.87' 529 981 607 1072
b'6 0.53' 236 181 278 244


Processing images:  85%|████████▍ | 11/13 [00:06<00:01,  1.47it/s]

./JPEGImages/11.jpg
image_shape: [ 720 1280]
top_label: [0 2]
top_conf: [0.7758542  0.68949246]
top_boxes: [[307.88986 343.11658 339.5513  423.54495]
 [313.40268 789.3176  370.82462 809.5403 ]]
b'0 0.78' 307 343 339 423
b'2 0.69' 313 789 370 809


Processing images:  92%|█████████▏| 12/13 [00:07<00:00,  1.59it/s]

./JPEGImages/12.jpg
image_shape: [1080 1920]
top_label: [0 0 0 0 0 1]
top_conf: [0.89559865 0.8579432  0.84982526 0.6104842  0.53023744 0.5017248 ]
top_boxes: [[589.4994    62.665413 751.87225  338.50867 ]
 [711.85406  657.76013  799.4238   758.24274 ]
 [744.72864  854.0421   817.8033   912.14264 ]
 [748.04926  534.5122   778.7087   598.3492  ]
 [756.8212   617.64185  782.9046   666.5429  ]
 [774.58887  948.55145  818.86426  976.40814 ]]
b'0 0.90' 589 62 751 338
b'0 0.86' 711 657 799 758
b'0 0.85' 744 854 817 912
b'0 0.61' 748 534 778 598
b'0 0.53' 756 617 782 666
b'1 0.50' 774 948 818 976


Processing images: 100%|██████████| 13/13 [00:08<00:00,  1.61it/s]

Processed 13 images in 7.67 seconds.
Frame per second (FPS): 1.69



