# YOLOv5 Part 1: Inference on the Wildfire-Smoke Test Images

In [1]:
import os
import glob as glob
import matplotlib.pyplot as plt
import cv2
import random
import numpy as np
import torch

# Seed every RNG the notebook draws from: `plot` samples images with the stdlib
# `random` module, so seeding NumPy alone would leave that choice unreproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

In [2]:
# Define some parameters.
# The defaults are paths on the original author's machine; set the environment
# variables WILDFIRE_WEIGHTS / WILDFIRE_TEST_IMAGES to run the notebook
# elsewhere without editing this cell.
PRETRAINED_MODEL_LOCATION = os.environ.get(
    "WILDFIRE_WEIGHTS",
    '/home/martin/Projects/ongoing/ukw/ukw_detection_system/ukw_detection_system/wildfire_smoke/models/full/v3-epochs35-fix-dataset/wildfire-model/weights/best.pt',
)
TEST_IMAGES_LOCATION = os.environ.get(
    "WILDFIRE_TEST_IMAGES",
    '/home/martin/Projects/ongoing/ukw/ukw_detection_system/ukw_detection_system/wildfire_smoke/data/test',
)
# Predictions are stored beside the weights folder:
# <run>/weights/best.pt  ->  <run>/test_inferences
OUTPUT_FOLDER = os.path.join(
    os.path.dirname(os.path.dirname(PRETRAINED_MODEL_LOCATION)), "test_inferences"
)

# Read by set_res_dir(): True reserves a new results_<N> name, False reuses the latest.
TRAIN = True

In [3]:
%cd yolov5/
# Helper function for naming the results directory.
def set_res_dir():
    """Return the ``results_<N>`` name for this run.

    N comes from counting the entries under ``runs/train`` (relative to the
    current working directory). When the global ``TRAIN`` flag is set, the
    name one past that count is returned and printed; otherwise the name for
    the current count is returned.
    """
    existing = len(glob.glob('runs/train/*'))
    print(f"Current number of result directories: {existing}")
    if not TRAIN:
        return f"results_{existing}"
    res_dir = f"results_{existing + 1}"
    print(res_dir)
    return res_dir

/home/martin/Projects/ongoing/ukw/ukw_detection_system/ukw_detection_system/wildfire_smoke/experiments/yolov5/yolov5


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [4]:
# Function to convert bounding boxes in YOLO format to xmin, ymin, xmax, ymax.
def yolo2bbox(bboxes):
    """Convert a YOLO ``(x_center, y_center, width, height)`` box to corners.

    Only the first four items of ``bboxes`` are read. Units are preserved:
    normalized input yields normalized corner coordinates.

    Returns:
        tuple: ``(xmin, ymin, xmax, ymax)``.
    """
    x_c, y_c, box_w, box_h = bboxes[0], bboxes[1], bboxes[2], bboxes[3]
    half_w = box_w / 2
    half_h = box_h / 2
    return x_c - half_w, y_c - half_h, x_c + half_w, y_c + half_h


# Function to draw the boxes
def plot_box(image, bboxes, labels):
    # Need the image height and width to denormalize
    # the bounding box coordinates
    h, w, _ = image.shape
    for box_num, box in enumerate(bboxes):
        x1, y1, x2, y2 = yolo2bbox(box)
        # denormalize the coordinates
        xmin = int(x1*w)
        ymin = int(y1*h)
        xmax = int(x2*w)
        ymax = int(y2*h)
        width = xmax - xmin
        height = ymax - ymin

        class_name = class_names[int(labels[box_num])]

        cv2.rectangle(
            image,
            (xmin, ymin), (xmax, ymax),
            color=colors[class_names.index(class_name)],
            thickness=2
        )

        font_scale = min(1,max(3,int(w/500)))
        font_thickness = min(2, max(10,int(w/50)))

        p1, p2 = (int(xmin), int(ymin)), (int(xmax), int(ymax))
        # Text width and height
        tw, th = cv2.getTextSize(
            class_name,
            0, fontScale=font_scale, thickness=font_thickness
        )[0]
        p2 = p1[0] + tw, p1[1] + -th - 10
        cv2.rectangle(
            image,
            p1, p2,
            color=colors[class_names.index(class_name)],
            thickness=-1,
        )
        cv2.putText(
            image,
            class_name,
            (xmin+1, ymin-10),
            cv2.FONT_HERSHEY_SIMPLEX,
            font_scale,
            (255, 255, 255),
            font_thickness
        )
    return image


# Function to plot images with the bounding boxes.
def plot(image_paths, label_paths, num_samples):
    all_training_images = glob.glob(image_paths)
    all_training_labels = glob.glob(label_paths)
    all_training_images.sort()
    all_training_labels.sort()

    num_images = len(all_training_images)

    plt.figure(figsize=(15, 12))
    for i in range(num_samples):
        j = random.randint(0,num_images-1)
        image = cv2.imread(all_training_images[j])
        with open(all_training_labels[j], 'r') as f:
            bboxes = []
            labels = []
            label_lines = f.readlines()
            for label_line in label_lines:
                label = label_line[0]
                bbox_string = label_line[2:]
                x_c, y_c, w, h = bbox_string.split(' ')
                x_c = float(x_c)
                y_c = float(y_c)
                w = float(w)
                h = float(h)
                bboxes.append([x_c, y_c, w, h])
                labels.append(label)
        result_image = plot_box(image, bboxes, labels)
        plt.subplot(2, 2, i+1)
        plt.imshow(result_image[:, :, ::-1])
        plt.axis('off')
    plt.subplots_adjust(wspace=0)
    plt.tight_layout()
    plt.show()

The following function runs YOLOv5 inference (`detect.py`) on a folder of images.

In [5]:
# Helper function for inference on images.
def inference(RES_DIR, data_path, weights, output_folder_path, view_img=True):
    """Run YOLOv5's ``detect.py`` on ``data_path`` and return the output directory.

    Must be called from the YOLOv5 repository root, because ``detect.py`` is
    resolved relative to the current working directory.

    Args:
        RES_DIR: Unused; kept so existing calls keep working.
        data_path: Image file, directory or glob passed as ``--source``.
        weights: Weights file passed as ``--weights``.
        output_folder_path: Passed as ``--name``. YOLOv5 joins it onto its
            ``--project`` (``runs/detect``) with pathlib, so an absolute path
            is used as-is.
        view_img: Forward ``--view-img``, which opens an OpenCV window per image.

    Returns:
        ``output_folder_path``.

    Raises:
        subprocess.CalledProcessError: if ``detect.py`` exits with a non-zero status.
    """
    import subprocess
    import sys

    print(f"Inference detection directories: {output_folder_path}")
    cmd = [
        sys.executable, "-u", "detect.py",  # the kernel's interpreter, unbuffered
        "--weights", str(weights),
        "--source", str(data_path),
        "--name", str(output_folder_path),
        # Without --exist-ok YOLOv5 appends 2, 3, ... when the directory already
        # exists, so the returned path would not be where results were written.
        "--exist-ok",
    ]
    if view_img:
        cmd.append("--view-img")

    # An argument list (no shell) keeps paths with spaces intact. Output is
    # relayed line by line so progress still appears in the notebook cell.
    with subprocess.Popen(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
    ) as proc:
        for line in proc.stdout:
            print(line, end="")
    if proc.returncode != 0:
        raise subprocess.CalledProcessError(proc.returncode, cmd)
    return output_folder_path

We may also want to inspect the annotated images in a detection directory. The following function takes a directory path and plots up to `n_images` of the images in it.

In [6]:
def visualize(INFER_DIR, n_images, pattern="*.jpg"):
    """Display up to ``n_images`` annotated images from a YOLOv5 detect run.

    Args:
        INFER_DIR: Run directory. A relative name is resolved under
            ``runs/detect``; an absolute path (such as the ``OUTPUT_FOLDER``
            passed to ``inference``) is used unchanged.
        n_images: Maximum number of images to show.
        pattern: Glob for the image files inside the run directory.
    """
    # os.path.join drops 'runs/detect' when INFER_DIR is absolute. The old
    # f"runs/detect/{INFER_DIR}" produced 'runs/detect//home/...', which
    # matched nothing.
    infer_path = os.path.join("runs", "detect", INFER_DIR)
    # glob order is arbitrary; sort so the same images are shown on every run.
    infer_images = sorted(glob.glob(os.path.join(infer_path, pattern)))
    print(infer_images)
    for pred_image in infer_images[:n_images]:
        image = cv2.imread(pred_image)
        if image is None:
            print(f"Skipping unreadable image: {pred_image}")
            continue
        plt.figure(figsize=(19, 16))
        plt.imshow(image[:, :, ::-1])  # OpenCV BGR -> matplotlib RGB
        plt.axis('off')
        plt.show()

**Run inference on the test images.**

In [7]:
# Run inference on the test images. set_res_dir() only prints and names the
# run here: inference() ignores RES_DIR and takes its run name from OUTPUT_FOLDER.
RES_DIR = set_res_dir()
image_infer_dir = inference(
    RES_DIR,
    data_path=TEST_IMAGES_LOCATION,
    weights=PRETRAINED_MODEL_LOCATION,
    output_folder_path=OUTPUT_FOLDER,
)

Current number of result directories: 0
results_1
Inference detection directories: /home/martin/Projects/ongoing/ukw/ukw_detection_system/ukw_detection_system/wildfire_smoke/models/full/v3-epochs35-fix-dataset/wildfire-model/test_inferences
[34m[1mdetect: [0mweights=['/home/martin/Projects/ongoing/ukw/ukw_detection_system/ukw_detection_system/wildfire_smoke/models/full/v3-epochs35-fix-dataset/wildfire-model/weights/best.pt'], source=/home/martin/Projects/ongoing/ukw/ukw_detection_system/ukw_detection_system/wildfire_smoke/data/test, data=data/coco128.yaml, imgsz=[640, 640], conf_thres=0.25, iou_thres=0.45, max_det=1000, device=, view_img=True, save_txt=False, save_csv=False, save_conf=False, save_crop=False, nosave=False, classes=None, agnostic_nms=False, augment=False, visualize=False, update=False, project=runs/detect, name=/home/martin/Projects/ongoing/ukw/ukw_detection_system/ukw_detection_system/wildfire_smoke/models/full/v3-epochs35-fix-dataset/wildfire-model/test_inferences, 