In [None]:
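# this script drives with our 30.01_quaternary steering model and overlays traffic light predictions
# we add a hard stop when a red light is detected with strong confidence; it is released as soon as the
# prediction is no longer red or the red confidence falls below 0.70 (debounce constants exist but are not applied yet)
# the top corners of camera2 are always shown in grayscale; the whole hud frame turns grayscale once the
# accumulated left-turn angle exceeds 2.5 degrees
# new: when the left indicator engages, we count how many degrees our car has turned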

import os
import json
import time
import random
import numpy as np
import pygame
import carla

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image
from carla import VehicleLightState
import cv2  # for resizing to 224x224
from pathlib import Path

# this part sets up our steering model. we use resnet18 as a feature extractor and add a small head
# the network takes both the image and the turn signal state and then predicts one of 15 steering classes
# display names indexed by predicted class: strongest left first, "No Turning" at index 7,
# strongest right last
_STEER_INTENSITIES = ("Minimal", "Slight", "Light", "Medium", "Hard", "Harder", "Hardest")
CLASS_NAMES_STEER = (
    [f"Left {level}" for level in reversed(_STEER_INTENSITIES)]
    + ["No Turning"]
    + [f"Right {level}" for level in _STEER_INTENSITIES]
)

class SteeringClassifier(nn.Module):
    """ResNet-18 image features plus a turn-signal embedding, classified into steering bins.

    Args:
        num_classes: number of output logits (15 by default).
    """

    def __init__(self, num_classes=15):
        super().__init__()

        # newer torchvision takes `weights=`; fall back to `pretrained=` when that call fails
        try:
            backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        except Exception:
            backbone = models.resnet18(pretrained=True)

        # replace the backbone's own classifier so it passes its features straight through
        backbone.fc = nn.Identity()
        self.cnn = backbone

        # one learned 16-d vector per turn-signal index (3 possible indices)
        self.turn_embed = nn.Embedding(3, 16)

        # attribute names (cnn / turn_embed / fc) are also the state_dict keys, keep them as is
        head_layers = [
            nn.Linear(512 + 16, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, image, turn_signal):
        """Return class logits for a batch of images and the matching turn-signal indices."""
        features = [self.cnn(image), self.turn_embed(turn_signal)]
        return self.fc(torch.cat(features, dim=1))

# pick the device first: the checkpoint is mapped onto it while loading
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# steering network: restore the trained weights, then switch to inference mode
steer_model = SteeringClassifier(num_classes=15)
steer_model = steer_model.to(device)
steer_model.load_state_dict(
    torch.load("../Models/30.01_quaternary_model_final.pth", map_location=device)
)
steer_model.eval()

# steering input pipeline: array -> PIL image -> 224x224 -> float tensor -> per-channel normalization
_steer_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
image_transform_steer = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    _steer_normalize,
])

# steering command for each predicted class index, listed in class order (index 0 first)
# NOTE(review): index 4 maps to -0.20 while its mirror index 10 maps to 0.15; confirm the asymmetry is intended
_STEER_COMMANDS = (
    -0.8, -0.6, -0.4, -0.25, -0.20, -0.10, -0.05,
    0.0,
    0.05, 0.10, 0.15, 0.25, 0.40, 0.60, 0.80,
)
steering_map = dict(enumerate(_STEER_COMMANDS))

# traffic light classifier artifacts: where the weights and the saved class mapping live
SAVE_DIR       = "../Models"
WEIGHTS_PTH_TL = os.path.join(SAVE_DIR, "traffic_classifier_state.pth")
CLASS_MAP_JSON = os.path.join(SAVE_DIR, "class_mapping.json")

# default mapping between labels (-1, 0, 1) and classifier output indices, plus display names
CLASS_NAMES_TL = {-1: "No Light", 0: "Red", 1: "Green"}
label_to_index = {-1: 0, 0: 1, 1: 2}
index_to_label = {idx: lbl for lbl, idx in label_to_index.items()}

# a saved mapping replaces the defaults only when the file exists and carries both tables;
# json keys arrive as strings, so keys and values are cast back to int
if os.path.exists(CLASS_MAP_JSON):
    with open(CLASS_MAP_JSON, "r") as fh:
        saved_mapping = json.load(fh)
    if all(table in saved_mapping for table in ("label_to_index", "index_to_label")):
        label_to_index = {
            int(lbl): int(idx) for lbl, idx in saved_mapping["label_to_index"].items()
        }
        index_to_label = {
            int(idx): int(lbl) for idx, lbl in saved_mapping["index_to_label"].items()
        }

NUM_CLASSES_TL = 3
IN_CHANNELS_TL = 3  # camera is rgb

# traffic light classifier based on resnet18 with a small head
class TLResNetClassifier(nn.Module):
    """ResNet-18 built without pretrained weights, with a dropout + linear head.

    Args:
        in_channels: channels of the input image; any value other than 3 swaps in a new first conv.
        num_classes: number of output logits.
    """

    def __init__(self, in_channels=3, num_classes=3):
        super().__init__()
        from torchvision.models import resnet18

        # `weights=` is tried first; `pretrained=` is the fallback when that call fails
        try:
            net = resnet18(weights=None)
        except Exception:
            net = resnet18(pretrained=False)

        if in_channels != 3:
            net.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # `net` is the state_dict prefix, so the attribute name has to stay
        head = nn.Sequential(nn.Dropout(0.2), nn.Linear(net.fc.in_features, num_classes))
        net.fc = head
        self.net = net

    def forward(self, x):
        """Run the backbone; 4-d inputs under 224 px in height or width are upsampled to 224x224 first."""
        needs_upsample = x.dim() == 4 and (x.shape[2] < 224 or x.shape[3] < 224)
        if needs_upsample:
            x = F.interpolate(x, size=(224, 224), mode="bilinear", align_corners=False)
        return self.net(x)

# build the traffic light model, restore trained weights when the file is there, then set eval
tl_model = TLResNetClassifier(IN_CHANNELS_TL, NUM_CLASSES_TL).to(device)
if os.path.exists(WEIGHTS_PTH_TL):
    state = torch.load(WEIGHTS_PTH_TL, map_location=device)
    # the file is either a checkpoint dict holding "model_state_dict" or the state dict itself
    if isinstance(state, dict) and "model_state_dict" in state:
        tl_state_dict = state["model_state_dict"]
    else:
        tl_state_dict = state
    tl_model.load_state_dict(tl_state_dict, strict=True)
else:
    print(f"[warn] traffic-light weights not found at {WEIGHTS_PTH_TL}. overlay will show n/a.")
tl_model.eval()

# per-channel mean / std used to normalize 3-channel traffic light inputs
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

# this helper predicts the traffic light class for a given rgb frame. returns name and confidence
@torch.no_grad()
def predict_traffic_light(np_rgb):
    """Classify the traffic light state shown in one camera frame.

    Args:
        np_rgb: image array, either HxW or HxWxC. If its maximum exceeds 1.5 the
            values are divided by 255. The caller's array is never modified.

    Returns:
        (name, confidence): the display name resolved through index_to_label and
        CLASS_NAMES_TL, and the softmax probability of the winning class.
        ("N/A", 0.0) when the weights file does not exist.

    Raises:
        ValueError: if the channel count cannot be matched to IN_CHANNELS_TL.
    """
    if not os.path.exists(WEIGHTS_PTH_TL):
        return "N/A", 0.0

    x = np.asarray(np_rgb)
    if x.ndim == 2:  # HxW -> HxWx1
        x = x[..., None]
    if x.dtype != np.float32:
        x = x.astype(np.float32)
    if x.max() > 1.5:  # guard if uint8 slipped in
        # divide out of place: for float32 input `x` still aliases the caller's array, and an
        # in-place `/=` would rescale their data (or fail outright on a read-only buffer)
        x = x / 255.0

    # HWC -> CHW
    x = np.transpose(x, (2, 0, 1))

    # channel alignment
    if x.shape[0] != IN_CHANNELS_TL:
        if IN_CHANNELS_TL == 3 and x.shape[0] == 1:
            x = np.repeat(x, 3, axis=0)  # gray -> rgb
        elif IN_CHANNELS_TL == 1 and x.shape[0] == 3:
            x = x.mean(axis=0, keepdims=True)  # rgb -> gray
        else:
            raise ValueError(f"channel mismatch for tl model: got {x.shape[0]}, expected {IN_CHANNELS_TL}")

    # to tensor (N,C,H,W)
    xt = torch.from_numpy(x).unsqueeze(0).to(device)

    # resize to 224x224 (same as training/validation)
    _, _, H, W = xt.shape
    if H != 224 or W != 224:
        xt = F.interpolate(xt, size=(224, 224), mode="bilinear", align_corners=False)

    # normalize: per-channel mean/std for 3 channels, otherwise map [0, 1] to [-1, 1]
    if IN_CHANNELS_TL == 3:
        mean = torch.tensor(IMAGENET_MEAN, device=device).view(1, 3, 1, 1)
        std  = torch.tensor(IMAGENET_STD,  device=device).view(1, 3, 1, 1)
        xt = (xt - mean) / std
    else:
        xt = (xt - 0.5) / 0.5

    # predict: softmax over the single batch item, keep the most likely class
    logits = tl_model(xt)
    probs = torch.softmax(logits, dim=1).squeeze(0)
    conf, pred_idx = torch.max(probs, dim=0)

    # output index -> label -> display name (falls back to the raw label as text)
    pred_label = index_to_label[int(pred_idx.item())]
    name = CLASS_NAMES_TL.get(pred_label, str(pred_label))
    return name, float(conf.item())

# now we set up pygame and carla. this handles rendering, vehicle, and camera sensors
# hud window and fonts
pygame.init()
width, height = 1024, 768  # window size in pixels; camera frames are scaled to this
screen = pygame.display.set_mode((width, height))
pygame.display.set_caption("carla: steering model + traffic-light overlay")
font = pygame.font.SysFont(None, 30)
font_big = pygame.font.SysFont(None, 36)

# connect to the simulator on localhost:2000 and load the Town05 map
client = carla.Client("localhost", 2000)
client.set_timeout(25.0)
client.load_world("Town05")
world = client.get_world()
bp = world.get_blueprint_library()

# spawn the vehicle and set manual control
# prefer the dodge charger blueprint; fall back to a model3 match when that filter is empty
vehicle_bp = (bp.filter("vehicle.dodge.charger_2020") or bp.filter("vehicle.*model3*"))[0]
spawn_point = random.choice(world.get_map().get_spawn_points())
vehicle = world.spawn_actor(vehicle_bp, spawn_point)
vehicle.set_autopilot(False)

spectator = world.get_spectator()

# attach cameras. camera 1 (448x252, fov 145) feeds the steering model;
# camera 2 (1024x768, fov 110) feeds the traffic light model and the hud
cam_bp = bp.find('sensor.camera.rgb')
cam_bp.set_attribute('image_size_x', '448')
cam_bp.set_attribute('image_size_y', '252')
cam_bp.set_attribute('fov', '145')
cam_bp.set_attribute('sensor_tick', '0.1')
cam_transform = carla.Transform(carla.Location(x=1.5, z=2.4))
camera = world.spawn_actor(cam_bp, cam_transform, attach_to=vehicle)

# latest frames and predictions, shared between the camera callbacks and the main loop
camera_image = None
camera_np = None
camera_image2 = None
camera_np2 = None
pred_label_steer = "Loading..."
pred_label_tl = "N/A"
conf_tl = 0.0  # default confidence for hud

# second camera: same mounting transform as camera 1, different resolution and fov
cam_bp2 = bp.find('sensor.camera.rgb')
cam_bp2.set_attribute('image_size_x', '1024')
cam_bp2.set_attribute('image_size_y', '768')
cam_bp2.set_attribute('fov', '110')
cam_bp2.set_attribute('sensor_tick', '0.1')
cam_transform2 = carla.Transform(carla.Location(x=1.5, z=2.4))
camera2 = world.spawn_actor(cam_bp2, cam_transform2, attach_to=vehicle)

# camera callbacks convert raw frames into numpy arrays and pygame surfaces
def process_image(image):
    """Camera-1 callback: publish the frame as an rgb array (camera_np) and a window-sized surface (camera_image)."""
    global camera_image, camera_np
    raw = np.frombuffer(image.raw_data, dtype=np.uint8).reshape((image.height, image.width, 4))
    rgb = raw[:, :, 2::-1]  # first three channels in reversed order -> rgb
    camera_np = rgb.copy()
    # pygame surfaces are indexed (x, y), hence the axis swap
    surface = pygame.surfarray.make_surface(np.swapaxes(rgb, 0, 1))
    camera_image = pygame.transform.scale(surface, (width, height))

def process_image2(image):
    """Camera-2 callback: crop/mask the frame, gray out both top corners, publish array and surface.

    Writes camera_np2 (a copy of the processed frame) and camera_image2 (the frame as a
    pygame surface scaled to the window size).
    """
    global camera_image2, camera_np2

    def desaturate(frame, rows, cols):
        # swap the selected rectangle for its grayscale version, kept as 3 channels
        gray = cv2.cvtColor(frame[rows, cols], cv2.COLOR_RGB2GRAY)
        frame[rows, cols] = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)

    raw = np.frombuffer(image.raw_data, dtype=np.uint8).reshape((image.height, image.width, 4))

    # bgra -> rgb, then the tl crop (zoom + blackout + side blur/desat)
    frame = traffic_light_crop(raw[:, :, 2::-1])

    frame_h, frame_w = frame.shape[:2]
    box_w = (2 * frame_w) // 6
    box_h = frame_h // 4

    # top right: double-height box against the right edge
    desaturate(frame, slice(0, box_h * 2), slice(max(0, frame_w - box_w), frame_w))
    # top left: single-height box against the left edge
    desaturate(frame, slice(0, box_h), slice(0, max(0, box_w)))

    # array for the model and the main loop, surface for the hud
    camera_np2 = frame.copy()
    hud_surface = pygame.surfarray.make_surface(frame.swapaxes(0, 1))
    camera_image2 = pygame.transform.scale(hud_surface, (width, height))

# register the frame callbacks on both cameras
# NOTE(review): process_image2 calls traffic_light_crop, which is only defined further down in this
# script; a frame delivered before that definition has run would raise NameError inside the
# callback. consider moving these listen() calls below the helper.
camera.listen(process_image)
camera2.listen(process_image2)

clock = pygame.time.Clock()  # the main loop uses this to cap its rate
print("model-controlled driving started.")
print("q/e to turn on left/right signal, r to cancel. w/s throttle up/down. p respawn. esc to exit.")

# traffic light crop helper. zoom center, blackout bottom third, blur sides
def traffic_light_crop(img: np.ndarray) -> np.ndarray:
    """
    prepares a camera frame for the traffic light view.

    steps, in order: zoom 1.5x into the center (crop, then resize back to the input
    width and height), zero every row from two thirds of the height downward, then blur
    and desaturate the left and right quarter-width strips.

    anything that is None or not 3-dimensional is handed back untouched. the result is
    normally a new array; the input itself is only written to when the center crop comes
    out empty and the resize is skipped.
    note: when the zoomed image is single-channel, the strip treatment runs on a temporary
    color copy and does not show up in the returned image.
    """
    if img is None or img.ndim != 3:
        return img

    height_px, width_px = img.shape[:2]

    # center zoom: keep the middle 1/1.5 of each axis and scale it back up
    zoom = 1.5
    crop_w = max(1, int(width_px / zoom))
    crop_h = max(1, int(height_px / zoom))
    left = max(0, (width_px - crop_w) // 2)
    top = max(0, (height_px - crop_h) // 2)
    center = img[top:top + crop_h, left:left + crop_w]
    result = cv2.resize(center, (width_px, height_px)) if center.size > 0 else img

    # blackout bottom third
    result[height_px * 2 // 3:, ...] = 0

    # the hsv round trip below needs three channels
    if result.ndim == 2 or result.shape[2] == 1:
        canvas = cv2.cvtColor(result, cv2.COLOR_GRAY2BGR)
    else:
        canvas = result

    def soften(strip: np.ndarray) -> np.ndarray:
        # 21x21 gaussian blur, then saturation forced to zero
        hsv = cv2.cvtColor(cv2.GaussianBlur(strip, (21, 21), 0), cv2.COLOR_BGR2HSV)
        hsv[:, :, 1] = 0
        return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)

    # left strip first, then right strip
    strip_w = max(1, width_px // 4)
    for cols in (slice(None, strip_w), slice(-strip_w, None)):
        canvas[:, cols, :] = soften(canvas[:, cols, :])

    # we keep color for the ui path
    return result

# helpers to track left-turn yaw degrees
def _wrap_deg(d: float) -> float:
    """map an angle in degrees onto the half-open interval [-180, 180)."""
    shifted = d + 180.0
    return shifted % 360.0 - 180.0


def _delta_deg(prev: float, cur: float) -> float:
    """signed shortest rotation from prev to cur in degrees, safe across the +/-180 seam."""
    raw_difference = cur - prev
    return _wrap_deg(raw_difference)

# control state and params
throttle = 0.3  # starting throttle; the main loop adjusts it with w/s
prev_steering_value = 0.0  # last predicted steering value, compared against the next prediction
last_steering_time = time.time()  # when the predicted steering value last changed
STEERING_HOLD_S = 0.05  # 50 ms an unchanged prediction must persist before the main loop applies it in full
steering_value = 0.0  # init so hud has a value on first frames

# thresholds for hard stop on red
RED_CONF_THRESHOLD     = 0.83  # minimum "Red" confidence that engages the hard stop
RELEASE_CONF_THRESHOLD = 0.75  # confidence used when recording the last clearly-not-red time
# NOTE(review): UNBRAKE_DEBOUNCE_S and SPEED_EPS are not referenced anywhere else in the visible script
UNBRAKE_DEBOUNCE_S     = 0.02
SPEED_EPS              = 0.30

hard_stop_active = False  # while True the main loop overrides throttle/brake
# NOTE(review): last_nonred_time is written by the main loop but never read in the visible script
last_nonred_time = time.time()

# flags for signals and left-turn tracking
left_signal_on = False
left_turn_tracking_active = False  # True between the rising and falling edge of the left signal
left_turn_deg = 0.0  # signed yaw change accumulated while tracking, in degrees
left_turn_prev_yaw = None  # yaw sampled on the previous frame; None while not tracking
prev_left_signal_on = False  # edge detection

def speed_mps(v):
    """return the magnitude of v.get_velocity(), i.e. the euclidean norm of its x, y, z components."""
    velocity = v.get_velocity()
    vx, vy, vz = velocity.x, velocity.y, velocity.z
    squared_norm = vx**2 + vy**2 + vz**2
    return squared_norm ** 0.5

# helper to respawn vehicle and reattach sensors
def respawn_vehicle():
    """Replace the current vehicle and both cameras with fresh actors at a random spawn point.

    Tears down camera, camera2 and vehicle on a best-effort basis, each one independently
    so that a failure on one actor cannot leave the others alive. Then spawns a new
    vehicle, re-attaches both cameras with their original blueprints and transforms,
    restarts the frame callbacks, and rebinds the module-level `vehicle`, `camera` and
    `camera2`.

    Raises:
        RuntimeError: if none of the map's spawn points accepts the new vehicle.
    """
    global vehicle, camera, camera2

    # sensors first, then the vehicle they are attached to
    for sensor in (camera, camera2):
        try:
            if sensor.is_listening:
                sensor.stop()
        except Exception as err:
            print(f"[warn] could not stop sensor during respawn: {err}")
        try:
            sensor.destroy()
        except Exception as err:
            print(f"[warn] could not destroy sensor during respawn: {err}")
    try:
        vehicle.destroy()
    except Exception as err:
        print(f"[warn] could not destroy vehicle during respawn: {err}")

    # a spawn point can be blocked; walk them in random order instead of raising on the
    # first collision, which would leave the globals pointing at destroyed actors
    spawn_points = list(world.get_map().get_spawn_points())
    random.shuffle(spawn_points)
    vehicle_new = None
    for sp in spawn_points:
        vehicle_new = world.try_spawn_actor(vehicle_bp, sp)  # None when the spot is not free
        if vehicle_new is not None:
            break
    if vehicle_new is None:
        raise RuntimeError("respawn failed: no free spawn point for the vehicle")
    vehicle_new.set_autopilot(False)

    camera_new = world.spawn_actor(cam_bp, cam_transform, attach_to=vehicle_new)
    camera_new.listen(process_image)

    camera_new2 = world.spawn_actor(cam_bp2, cam_transform2, attach_to=vehicle_new)
    camera_new2.listen(process_image2)

    # reset globals
    vehicle = vehicle_new
    camera = camera_new
    camera2 = camera_new2

# main loop. we predict steering and traffic light state, apply controls, draw hud, and follow with spectator
# steer command sent on the previous frame. it is re-sent while a repeated prediction is still
# inside the hold window, because every frame starts from a fresh VehicleControl (steer 0.0)
applied_steer = 0.0

try:
    while True:
        clock.tick(60)
        pygame.event.pump()
        keys = pygame.key.get_pressed()

        # update turn signal state with q/e/r (use bitmask checks because lights can combine)
        if keys[pygame.K_q]:
            vehicle.set_light_state(VehicleLightState.LeftBlinker)
            turn_signal = -1
        elif keys[pygame.K_e]:
            vehicle.set_light_state(VehicleLightState.RightBlinker)
            turn_signal = 1
        elif keys[pygame.K_r]:
            vehicle.set_light_state(VehicleLightState.NONE)
            turn_signal = 0
        else:
            # no key held: read the signal back from the vehicle's current light state
            ls = vehicle.get_light_state()
            if ls & VehicleLightState.LeftBlinker:
                turn_signal = -1
            elif ls & VehicleLightState.RightBlinker:
                turn_signal = 1
            else:
                turn_signal = 0

        # drive the gray-out state and angle tracking flags from current signals
        ls_now = vehicle.get_light_state()
        left_signal_on = bool(ls_now & VehicleLightState.LeftBlinker) or (turn_signal == -1)
        right_signal_on = bool(ls_now & VehicleLightState.RightBlinker) or (turn_signal == 1)
        no_signal_on = (not left_signal_on) and (not right_signal_on)
        if no_signal_on or right_signal_on:
            left_turn_deg = 0.0

        # start/stop turn-angle tracking on left signal edges
        if left_signal_on and not prev_left_signal_on:
            left_turn_tracking_active = True
            left_turn_deg = 0.0
            left_turn_prev_yaw = vehicle.get_transform().rotation.yaw
        elif (not left_signal_on) and prev_left_signal_on:
            left_turn_tracking_active = False
            left_turn_prev_yaw = None
        prev_left_signal_on = left_signal_on

        # throttle tweaks (clamped to [0.1, 1.0])
        if keys[pygame.K_w]:
            throttle = min(1.0, throttle + 0.005)
        if keys[pygame.K_s]:
            throttle = max(0.1, throttle - 0.005)

        # respawn on demand
        if keys[pygame.K_p]:
            print("respawning vehicle...")
            respawn_vehicle()
            throttle = 0.3
            time.sleep(0.5)
            continue

        # fresh control every frame: anything not assigned below stays at its default
        control = carla.VehicleControl()
        control.throttle = throttle

        # predict steering
        if camera_np is not None:
            with torch.no_grad():
                img_t = image_transform_steer(camera_np).unsqueeze(0).to(device)
                signal_t = torch.tensor([[turn_signal + 1]], dtype=torch.long).to(device)  # {0,1,2}
                out = steer_model(img_t, signal_t.squeeze(1))
                pred_class = out.argmax(dim=1).item()
                pred_label_steer = CLASS_NAMES_STEER[pred_class]
                steering_value = steering_map[pred_class]

                # small bias so signal intent is honored
                if turn_signal == -1 and steering_value > -0.02:
                    steering_value = -0.02
                elif turn_signal == 1 and steering_value < 0.02:
                    steering_value = 0.02

                # hold logic to avoid chattering
                current_time = time.time()
                if steering_value == 0.0 or prev_steering_value == 0.0:
                    # going to or coming from straight: apply immediately
                    applied_steer = steering_value
                    prev_steering_value = steering_value
                    last_steering_time = current_time
                elif steering_value != prev_steering_value:
                    # prediction changed: ease in with the midpoint of old and new
                    applied_steer = (steering_value + prev_steering_value) / 2
                    prev_steering_value = steering_value
                    last_steering_time = current_time
                elif current_time - last_steering_time >= STEERING_HOLD_S:
                    # same prediction for long enough: apply it in full
                    applied_steer = steering_value
                # otherwise we are inside the hold window and keep the last applied command.
                # previously nothing was assigned in that case, so steer snapped back to 0.0
                control.steer = applied_steer

        # predict traffic light and manage hard stop gate
        if camera_np2 is not None:
            with torch.no_grad():
                pred_label_tl, conf_tl = predict_traffic_light(camera_np2)

            name_is_red   = (pred_label_tl == "Red")
            name_is_green = (pred_label_tl == "Green")
            name_is_none  = (pred_label_tl == "No Light")

            now = time.time()

            # activate hard stop only if confidently red
            if name_is_red and conf_tl >= RED_CONF_THRESHOLD:
                hard_stop_active = True
            else:
                # track last time we were clearly not red
                # NOTE(review): the trailing `not name_is_red` makes the two confidence checks
                # redundant, and last_nonred_time is not read by the release logic below
                if (name_is_green and conf_tl >= RELEASE_CONF_THRESHOLD) or \
                   (name_is_none  and conf_tl >= RELEASE_CONF_THRESHOLD) or \
                   (not name_is_red):
                    last_nonred_time = now

            # deactivate if no longer red or red but too uncertain
            # (0.83 to engage, 0.70 to release; release is immediate, no debounce is applied)
            if hard_stop_active and (not name_is_red or (name_is_red and conf_tl < 0.70)):
                hard_stop_active = False

        # apply controls with hard-stop override
        if hard_stop_active:
            control.throttle = 0.0
            control.brake = 1.0
            control.hand_brake = False
            # best effort: add the brake light on top of whatever lights are already on
            try:
                ls = vehicle.get_light_state()
                vehicle.set_light_state(ls | VehicleLightState.Brake)
            except Exception:
                pass
        else:
            control.brake = max(0.0, control.brake)

        vehicle.apply_control(control)

        # spectator follow camera: 8 units behind along the forward vector, 3 up, pitched down 10 degrees
        tr = vehicle.get_transform()
        fwd = tr.get_forward_vector()
        cam_location = tr.location - fwd * 8 + carla.Location(z=3)
        cam_rotation = carla.Rotation(pitch=-10, yaw=tr.rotation.yaw)
        spectator.set_transform(carla.Transform(cam_location, cam_rotation))

        # accumulate signed degrees turned while left signal is active
        if left_turn_tracking_active:
            cur_yaw = tr.rotation.yaw
            if left_turn_prev_yaw is not None:
                d = _delta_deg(left_turn_prev_yaw, cur_yaw)
                # if we want absolute-only accumulation use: left_turn_deg += abs(d)
                left_turn_deg += d
            left_turn_prev_yaw = cur_yaw

        # draw hud
        if camera_np2 is not None:
            # full-screen grayscale if absolute left-turn exceeds small threshold
            frame_np = camera_np2
            if abs(left_turn_deg) > 2.5:
                g = cv2.cvtColor(frame_np, cv2.COLOR_RGB2GRAY)
                frame_np = cv2.cvtColor(g, cv2.COLOR_GRAY2RGB)

            camera_image2 = pygame.surfarray.make_surface(frame_np.swapaxes(0, 1))
            camera_image2 = pygame.transform.scale(camera_image2, (width, height))
            screen.blit(camera_image2, (0, 0))

            label_turn = font.render(f"Steer class: {pred_label_steer}", True, (255, 255, 0))
            label_steer = font.render(f"Steering: {steering_value:.3f}", True, (0, 255, 0))
            label_speed = font.render(f"Throttle: {throttle:.2f}", True, (0, 200, 255))
            screen.blit(label_turn, (20, 20))
            screen.blit(label_steer, (20, 50))
            screen.blit(label_speed, (20, 80))

            # degrees since left indicator engaged
            label_turn_count = font.render(f"Left-turn Δ: {left_turn_deg:.1f}°", True, (255, 255, 255))
            screen.blit(label_turn_count, (20, 110))

            # traffic light banner
            banner = pygame.Surface((width, 46), pygame.SRCALPHA)
            banner.fill((0, 0, 0, 140))
            screen.blit(banner, (0, height - 46))
            tl_text = font_big.render(
                f"Traffic Light: {pred_label_tl} ({conf_tl*100:.1f}%)",
                True, (255, 255, 255)
            )
            screen.blit(tl_text, (20, height - 40))

            if hard_stop_active:
                tl_text2 = font_big.render("HARD STOP ACTIVE", True, (255, 80, 80))
                screen.blit(tl_text2, (width - 320, height - 40))

            pygame.display.flip()

        # quit events: drain the queue, and check esc directly so it does not depend on
        # an event happening to be queued this frame
        window_closed = any(event.type == pygame.QUIT for event in pygame.event.get())
        if window_closed or keys[pygame.K_ESCAPE]:
            raise KeyboardInterrupt

except KeyboardInterrupt:
    print("exiting and cleaning up...")
finally:
    # best-effort teardown, one actor at a time so a single failure cannot leak the rest;
    # sensors go before the vehicle they are attached to
    for sensor in (camera, camera2):
        try:
            sensor.stop()
        except Exception as err:
            print(f"[warn] could not stop sensor on exit: {err}")
        try:
            sensor.destroy()
        except Exception as err:
            print(f"[warn] could not destroy sensor on exit: {err}")
    try:
        vehicle.destroy()
    except Exception as err:
        print(f"[warn] could not destroy vehicle on exit: {err}")
    pygame.quit()


Model-controlled driving started.
Q/E to turn on left/right signal, R to cancel. W/S throttle up/down. P respawn. ESC to exit.
Exiting and cleaning up...


: 