In [5]:
import time
import torch
import numpy as np
import pygetwindow as gw
from selenium import webdriver
from selenium.webdriver.common.by import By
from torchvision import transforms
from PIL import Image
import mss
import cv2
import os

In [6]:
# ==== Model Definitions ====
from torchvision import models
import torch.nn as nn

class AngleClassifier(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.model = models.resnet18(pretrained=False)
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        return self.model(x)

class ComponentClassifier(nn.Module):
    def __init__(self, num_outputs):
        super().__init__()
        self.model = models.resnet18(pretrained=False)
        self.model.fc = nn.Sequential(
            nn.Linear(self.model.fc.in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_outputs),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

In [12]:
# ==== Step 1: Launch Web Simulation ====
url = "http://103.233.100.26:8080/"
driver = webdriver.Chrome()
print("⌛ Opening browser...")
driver.get(url)

⌛ Opening browser...


In [13]:
# ==== Step 2: Wait and Get Canvas Bounding Box ====
def get_canvas_bbox(driver):
    canvas = driver.find_element(By.TAG_NAME, "canvas")
    canvas_rect = driver.execute_script("""
        const rect = arguments[0].getBoundingClientRect();
        return {x: rect.left, y: rect.top, width: rect.width, height: rect.height};
    """, canvas)

    win = next((w for w in gw.getWindowsWithTitle("") if "Car Control" in w.title), None)
    if not win:
        raise Exception("Window with title containing 'Car Control' not found.")
    win_x, win_y = win.left, win.top

    offset_x, offset_y = 2, 180
    dpi_scale = driver.execute_script("return window.devicePixelRatio") * 1.01

    left = int(win_x + canvas_rect['x'] * dpi_scale + offset_x)
    top = int(win_y + canvas_rect['y'] * dpi_scale + offset_y)
    right = int(left + canvas_rect['width'] * dpi_scale)
    bottom = int(top + canvas_rect['height'] * dpi_scale)

    return (left, top, right, bottom)

bbox = get_canvas_bbox(driver)
print(f"📷 Canvas captured at: {bbox}")

📷 Canvas captured at: (186, 268, 1737, 1049)


In [16]:
# ==== Step 3: Load Models from .pt File ====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load("car_multi_model.pt", map_location=device)

angle_classes = checkpoint["angle_label_encoder"]
angle_model = AngleClassifier(num_classes=len(angle_classes))
component_model = ComponentClassifier(num_outputs=5)

angle_model.load_state_dict(checkpoint["angle_model_state_dict"])
component_model.load_state_dict(checkpoint["component_model_state_dict"])
angle_model.eval().to(device)
component_model.eval().to(device)

# ==== Step 4: Image Capture + Preprocess ====
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

def capture_and_preprocess(bbox):
    with mss.mss() as sct:
        screenshot = sct.grab(bbox)
        img = Image.frombytes('RGB', screenshot.size, screenshot.rgb)
        img_tensor = transform(img).unsqueeze(0).to(device)
        return img, img_tensor

# ==== Step 5: Run Prediction Loop with GUI ====
print("🟢 Starting real-time inference... Press 'q' to quit.")
while True:
    img, img_tensor = capture_and_preprocess(bbox)

    with torch.no_grad():
        angle_logits = angle_model(img_tensor)
        component_probs = component_model(img_tensor)

    angle_pred = angle_classes[angle_logits.argmax(dim=1).item()]
    component_labels = ['FL', 'FR', 'RL', 'RR', 'Hood']
    component_pred = (component_probs > 0.5).int().squeeze().cpu().numpy()
    component_result = dict(zip(component_labels, component_pred))

    # Show results
    canvas_cv = np.array(img)[:, :, ::-1].copy()  # Convert to OpenCV BGR format with memory fix
    cv2.putText(canvas_cv, f"Angle: {angle_pred}", (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

    for i, (k, v) in enumerate(component_result.items()):
        text = f"{k}: {'Open' if v else 'Closed'}"
        cv2.putText(canvas_cv, text, (10, 60 + i * 25),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)  # Red text, thickness 2


    cv2.imshow("Real-Time Prediction", canvas_cv)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cv2.destroyAllWindows()
# driver.quit()


  checkpoint = torch.load("car_multi_model.pt", map_location=device)


🟢 Starting real-time inference... Press 'q' to quit.
