Press q in opencv window to stop.

In [None]:
from pathlib import Path
import time

import cv2
import mss
import numpy as np
import torch
import torchvision
import torchvision.transforms.functional as TF

MODEL_PATH = Path('checkpoints/final_efficientnet_v2_s.pth')
if not MODEL_PATH.exists():
    MODEL_PATH = Path('../checkpoints/final_efficientnet_v2_s.pth')
if not MODEL_PATH.exists():
    raise FileNotFoundError(f'Model file not found: {MODEL_PATH.resolve()}')

IMG_SIZE = 224
NUM_CLASSES = 10
MEAN = [0.4914, 0.4822, 0.4465]
STD = [0.2470, 0.2435, 0.2616]
CLASS_NAMES = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

CAPTURE_BACKEND = 'mss'  # 'mss' or 'opencv_camera'
USE_PRIMARY_MONITOR = True
MONITOR_INDEX = 1
CAPTURE_REGION = {'top': 120, 'left': 200, 'width': 1280, 'height': 720}
CAMERA_INDEX = 0

TARGET_FPS = 10
WINDOW_NAME = 'Terraria Biome Detection'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if CAPTURE_BACKEND not in {'mss', 'opencv_camera'}:
    raise ValueError("CAPTURE_BACKEND must be 'mss' or 'opencv_camera'.")

print('Device:', DEVICE)
print('Capture backend:', CAPTURE_BACKEND)
print('Model path:', MODEL_PATH.resolve())

Device: cpu
Capture backend: mss
Model path: C:\Users\David Kim\Documents\Programming\Terraria-Biome-Detection\checkpoints\final_efficientnet_v2_s.pth


In [2]:
def load_model():
    model = torchvision.models.efficientnet_v2_s(weights=None, num_classes=NUM_CLASSES)
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    return model.to(DEVICE).eval()


def preprocess(frame_bgr):
    rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    rgb = cv2.resize(rgb, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA)
    x = torch.from_numpy(rgb).permute(2, 0, 1).float().div_(255.0)
    x = TF.normalize(x, MEAN, STD)
    return x.unsqueeze(0).to(DEVICE)


def predict(frame_bgr):
    with torch.inference_mode():
        probs = torch.softmax(MODEL(preprocess(frame_bgr)), dim=1)[0]
    idx = int(probs.argmax().item())
    return CLASS_NAMES[idx], float(probs[idx].item())


def draw_overlay(frame, label, conf, fps):
    out = frame.copy()
    cv2.putText(out, f'Pred: {label}', (20, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 0), 2)
    cv2.putText(out, f'Conf: {conf:.3f}', (20, 80), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 255), 2)
    cv2.putText(out, f'FPS: {fps:.1f}', (20, 120), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 200, 0), 2)
    cv2.putText(out, 'Press q to quit', (20, out.shape[0] - 20), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (200, 200, 200), 2)
    return out


MODEL = load_model()

In [3]:
if CAPTURE_BACKEND == 'mss':
    sct = mss.mss()
    monitor = sct.monitors[MONITOR_INDEX] if USE_PRIMARY_MONITOR else CAPTURE_REGION
    print('Capture region:', monitor)

    def grab():
        return np.asarray(sct.grab(monitor), dtype=np.uint8)[:, :, :3]

    cleanup = sct.close
else:
    cap = cv2.VideoCapture(CAMERA_INDEX)
    if not cap.isOpened():
        raise RuntimeError(f'Could not open camera index {CAMERA_INDEX}.')
    print(f'Using camera backend at index {CAMERA_INDEX}')

    def grab():
        ok, frame = cap.read()
        if not ok or frame is None:
            raise RuntimeError('Failed to read frame from camera backend.')
        return frame

    cleanup = cap.release

frame_interval = 1.0 / max(1, TARGET_FPS)
cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)

try:
    while True:
        t0 = time.perf_counter()
        frame = grab()
        label, conf = predict(frame)

        dt = max(time.perf_counter() - t0, 1e-6)
        fps_display = 1.0 / dt

        cv2.imshow(WINDOW_NAME, draw_overlay(frame, label, conf, fps_display))
        if (cv2.waitKey(1) & 0xFF) == ord('q'):
            break

        sleep_time = frame_interval - dt
        if sleep_time > 0:
            time.sleep(sleep_time)
finally:
    cleanup()
    cv2.destroyAllWindows()

Capture region: {'left': 0, 'top': 0, 'width': 1920, 'height': 1080}


KeyboardInterrupt: 