# PT Model Testing

Test a specified model against a folder of images to verify detection accuracy.

In [45]:
# Cell 2: Imports and setup
import os
import cv2
import torch
import matplotlib.pyplot as plt
from ultralytics import YOLO
from IPython.display import display
import ipywidgets as widgets
import numpy as np

# Set paths and parameters
# Relative paths are resolved against the notebook's current working directory.
SOURCE_FOLDER = os.path.abspath('../containerdoors/images/')
MODEL_PATH = '../models/unified-detection-75epocc-13072025.pt'
CONFIDENCE_THRESHOLD = 0.3
# File extensions (lower-case) that count as JPG images when listing SOURCE_FOLDER.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg')

# Load model
model = YOLO(MODEL_PATH)
# NOTE(review): Ultralytics' YOLO wrapper takes `conf` as a predict-time
# argument (model(img, conf=...)); this attribute assignment looks like a
# YOLOv5-hub idiom and may have no effect on inference — confirm and pass the
# threshold at call time if so.
model.conf = CONFIDENCE_THRESHOLD

# Get list of JPG files (case-insensitive match, sorted for a stable order)
image_files = sorted(
    f for f in os.listdir(SOURCE_FOLDER) if f.lower().endswith(IMAGE_EXTENSIONS)
)
current_index = 0  # For tracking which image is shown

In [46]:
# Cell 3: Preprocessing function
def preprocess_image(filepath, target_size=320):
    """Load an image as grayscale and fit it onto a square white canvas.

    The image is scaled so its height equals ``target_size`` (aspect ratio
    preserved), then padded on the right with white (255) until its width is
    also ``target_size``.

    Args:
        filepath: Path of the image file to read.
        target_size: Side length in pixels of the square output. Defaults to
            320, the value this notebook originally hard-coded.

    Returns:
        A ``(img_tensor, img_rgb)`` tuple:
        - ``img_tensor``: float32 torch tensor of shape
          ``(1, 3, target_size, target_size)`` with values scaled to [0, 1].
        - ``img_rgb``: the same picture as a
          ``(target_size, target_size, 3)`` numpy array, for display.

    Raises:
        ValueError: If the file cannot be read, or if the resized image is
            wider than ``target_size`` (i.e. the source is wider than tall).
    """
    img_gray = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
    if img_gray is None:
        raise ValueError(f"Image not found or can't be read: {filepath}")

    h, w = img_gray.shape
    scale = target_size / h
    # Clamp to 1px: int() truncation can yield 0 for very tall, narrow images,
    # which cv2.resize rejects.
    new_w = max(1, int(w * scale))

    # Check before resizing so an oversized image fails without wasted work.
    pad_w = target_size - new_w
    if pad_w < 0:
        raise ValueError(f"Image width after resize exceeds {target_size}px: {filepath}")

    img_resized = cv2.resize(img_gray, (new_w, target_size))

    # Pad only on the right so the image content stays anchored at x=0.
    img_padded = cv2.copyMakeBorder(img_resized, 0, 0, 0, pad_w, cv2.BORDER_CONSTANT, value=255)

    img_rgb = cv2.cvtColor(img_padded, cv2.COLOR_GRAY2RGB)  # for display
    # HWC -> CHW, add a batch dimension, and scale to [0, 1] for YOLO.
    img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1).float().unsqueeze(0) / 255.0

    return img_tensor, img_rgb

In [47]:
# Cell 4: Display and inference function
def show_image_with_detections(index):
    """Run the model on one image and plot it with its detections.

    Reads ``image_files[index]`` from ``SOURCE_FOLDER``, preprocesses it with
    ``preprocess_image``, runs ``model`` on the resulting tensor, and draws a
    red rectangle plus a "<class> <confidence>" label for every returned box.
    The plot title shows the 1-based position of the image in the list.

    Args:
        index: Position in ``image_files`` of the image to show.
    """
    if not image_files:
        # Nothing to index into; report it instead of failing with IndexError.
        print(f"No images found in {SOURCE_FOLDER}")
        return

    file_path = os.path.join(SOURCE_FOLDER, image_files[index])
    img_tensor, img_np = preprocess_image(file_path)

    # Pass the threshold explicitly: Ultralytics reads `conf` from the predict
    # call, not from an attribute set on the model object.
    results = model(img_tensor, conf=CONFIDENCE_THRESHOLD)[0]
    boxes = results.boxes
    class_names = results.names

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(img_np)  # already RGB, so no colour conversion is needed

    if boxes:
        for box in boxes:
            cls = int(box.cls[0])
            conf = float(box.conf[0])
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                       edgecolor='red', facecolor='none', linewidth=2))
            # Label sits just above the top-left corner of the box.
            ax.text(x1, y1 - 5, f"{class_names[cls]} {conf:.2f}", color='white',
                    backgroundcolor='red', fontsize=8)

    ax.set_title(f"{index + 1} / {len(image_files)}")
    ax.axis('off')
    plt.show()

In [48]:
# Cell 5: UI to navigate images
back_button = widgets.Button(description="⬅️ Back")
next_button = widgets.Button(description="Next ➡️")
out = widgets.Output()


def update_ui(change=None):
    """Redraw the output area with the image at ``current_index``.

    ``change`` is accepted so the function can also be used as a widget
    callback; it is not used.
    """
    out.clear_output(wait=True)
    with out:
        if not image_files:
            # Avoid indexing into an empty list when the folder has no images.
            print("No images to display.")
            return
        show_image_with_detections(current_index)


def _step(delta):
    """Move ``current_index`` by ``delta`` (wrapping around) and redraw."""
    global current_index
    if image_files:  # modulo by zero would raise on an empty list
        current_index = (current_index + delta) % len(image_files)
    update_ui()


def on_next_clicked(b):
    """Button callback: advance to the next image, wrapping to the first."""
    _step(1)


def on_back_clicked(b):
    """Button callback: go to the previous image, wrapping to the last."""
    _step(-1)


back_button.on_click(on_back_clicked)
next_button.on_click(on_next_clicked)

display(widgets.HBox([back_button, next_button]), out)
update_ui()

HBox(children=(Button(description='⬅️ Back', style=ButtonStyle()), Button(description='Next ➡️', style=ButtonS…

Output()