# In [14]:
# Drives with your 30.01_quaternary steering model and overlays traffic-light predictions.

import os
import json
import time
import random
import numpy as np
import pygame
import carla

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image
from carla import VehicleLightState
import cv2  # for resizing to 224x224
from pathlib import Path

# ========= Steering model (YOUR 30.01_quaternary) =========
# Output classes of the steering classifier, listed in class-index order:
# 0-6 = left (hardest -> minimal), 7 = straight, 8-14 = right (minimal -> hardest).
CLASS_NAMES_STEER = [
    "Left Hardest",    # 0
    "Left Harder",     # 1
    "Left Hard",       # 2
    "Left Medium",     # 3
    "Left Light",      # 4
    "Left Slight",     # 5
    "Left Minimal",    # 6
    "No Turning",      # 7
    "Right Minimal",   # 8
    "Right Slight",    # 9
    "Right Light",     # 10
    "Right Medium",    # 11
    "Right Hard",      # 12
    "Right Harder",    # 13
    "Right Hardest",   # 14
]

class SteeringClassifier(nn.Module):
    """ResNet-18 image features fused with a turn-signal embedding.

    The ResNet-18 classification head is replaced by an identity, so the
    backbone yields its 512-d pooled features. These are concatenated with a
    16-d embedding of the turn-signal index (3 possible indices) and passed
    through a 2-layer MLP that produces ``num_classes`` steering logits.
    """

    def __init__(self, num_classes=15):
        super().__init__()
        # Prefer the newer torchvision weights API; fall back to the legacy flag.
        try:
            backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        except Exception:
            backbone = models.resnet18(pretrained=True)
        backbone.fc = nn.Identity()  # keep pooled features, drop ImageNet head
        self.cnn = backbone
        self.turn_embed = nn.Embedding(3, 16)
        hidden_units = 128
        self.fc = nn.Sequential(
            nn.Linear(512 + 16, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, num_classes),
        )

    def forward(self, image, turn_signal):
        """Return steering logits for an image batch and turn-signal indices.

        Args:
            image: normalised image batch accepted by the ResNet backbone.
            turn_signal: long tensor of indices in {0, 1, 2}, one per sample.
        """
        fused_parts = [self.cnn(image), self.turn_embed(turn_signal)]
        return self.fc(torch.cat(fused_parts, dim=1))

# Inference device: CUDA when available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the steering network, load its trained weights onto `device`, switch to eval mode.
steer_model = SteeringClassifier(num_classes=15).to(device)
steer_model.load_state_dict(
    torch.load("../Models/30.01_quaternary_model_final.pth", map_location=device)
)
steer_model.eval()

# Preprocessing for steering frames: HxWx3 array -> 224x224 tensor normalised
# with the ImageNet mean/std.
image_transform_steer = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Steering class index -> VehicleControl.steer value (negative = left, positive = right).
# NOTE(review): index 4 ("Left Light") is -0.20 but index 10 ("Right Light") is +0.15;
# every other left/right pair is mirror-symmetric. Confirm which value is intended.
steering_map = dict(
    enumerate(
        [
            -0.8, -0.6, -0.4, -0.25, -0.20, -0.10, -0.05,  # 0-6  Left Hardest .. Left Minimal
            0.0,                                          # 7    No Turning
            0.05, 0.10, 0.15, 0.25, 0.40, 0.60, 0.80,     # 8-14 Right Minimal .. Right Hardest
        ]
    )
)

# ========= Traffic-light classifier (your training artifacts) =========
SAVE_DIR       = "../Models"
WEIGHTS_PTH_TL = os.path.join(SAVE_DIR, "traffic_classifier_state.pth")
CLASS_MAP_JSON = os.path.join(SAVE_DIR, "class_mapping.json")

# Default mapping between dataset labels (-1 / 0 / 1) and model output indices.
label_to_index = {-1: 0, 0: 1, 1: 2}
index_to_label = dict(zip(label_to_index.values(), label_to_index.keys()))
CLASS_NAMES_TL = {-1: "No Light", 0: "Red", 1: "Green"}
CLASS_EMOJI_TL = {-1: "⚫️", 0: "🔴", 1: "🟢"}

# Override the defaults from the saved mapping, but only when it contains both
# directions. JSON object keys are strings, so keys and values are cast back to int.
if os.path.exists(CLASS_MAP_JSON):
    with open(CLASS_MAP_JSON, "r") as mapping_file:
        _saved_mapping = json.load(mapping_file)
    if "label_to_index" in _saved_mapping and "index_to_label" in _saved_mapping:
        label_to_index = {
            int(key): int(val) for key, val in _saved_mapping["label_to_index"].items()
        }
        index_to_label = {
            int(key): int(val) for key, val in _saved_mapping["index_to_label"].items()
        }

NUM_CLASSES_TL = 3
IN_CHANNELS_TL = 3  # camera is RGB

class TLResNetClassifier(nn.Module):
    """ResNet-18 traffic-light classifier built without pretrained weights.

    A non-RGB ``in_channels`` swaps in a matching stem convolution. The head is
    Dropout(0.2) followed by a Linear layer to ``num_classes`` logits.
    """

    def __init__(self, in_channels=3, num_classes=3):
        super().__init__()
        from torchvision.models import resnet18

        # Newer torchvision takes `weights=`; older releases only accept `pretrained=`.
        try:
            net = resnet18(weights=None)
        except Exception:
            net = resnet18(pretrained=False)

        if in_channels != 3:
            net.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        head_inputs = net.fc.in_features
        net.fc = nn.Sequential(nn.Dropout(0.2), nn.Linear(head_inputs, num_classes))
        self.net = net

    def forward(self, x):
        """Return logits. A 4-D batch smaller than 224 on either side is upsampled to 224x224."""
        too_small = x.dim() == 4 and min(x.shape[-2], x.shape[-1]) < 224
        if too_small:
            x = F.interpolate(x, size=(224, 224), mode="bilinear", align_corners=False)
        return self.net(x)

tl_model = TLResNetClassifier(IN_CHANNELS_TL, NUM_CLASSES_TL).to(device)
if os.path.exists(WEIGHTS_PTH_TL):
    state = torch.load(WEIGHTS_PTH_TL, map_location=device)
    # Accept either a bare state_dict or a checkpoint dict that wraps one.
    if isinstance(state, dict) and "model_state_dict" in state:
        tl_state_dict = state["model_state_dict"]
    else:
        tl_state_dict = state
    tl_model.load_state_dict(tl_state_dict, strict=True)
else:
    print(f"[WARN] Traffic-light weights not found at {WEIGHTS_PTH_TL}. Overlay will show N/A.")
tl_model.eval()


# Whether TL weights exist, decided once when this line runs rather than on every
# frame. This avoids a filesystem stat per call, and the untrained tl_model is never
# used if the weights file only appears after startup.
_TL_WEIGHTS_AVAILABLE = os.path.exists(WEIGHTS_PTH_TL)


@torch.no_grad()
def predict_traffic_light(np_rgb):
    """Classify the traffic-light state of one frame.

    Args:
        np_rgb: H x W x C image array, presumably uint8 RGB (it is scaled by
            1/255 and not further normalised before reaching ``tl_model``).

    Returns:
        ``(label, confidence)``. ``label`` is "<emoji> <name>" for the top class
        and ``confidence`` is its softmax probability. Returns ``("N/A", 0.0)``
        when the weights file was missing when this module-level check ran.

    Raises:
        ValueError: if the channel count cannot be adapted to ``IN_CHANNELS_TL``.
    """
    if not _TL_WEIGHTS_AVAILABLE:
        return "N/A", 0.0
    x = np_rgb.astype("float32") / 255.0
    x = np.transpose(x, (2, 0, 1))  # HWC -> CHW
    # Adapt RGB <-> single-channel if the model was built for a different count.
    if x.shape[0] != IN_CHANNELS_TL:
        if IN_CHANNELS_TL == 1 and x.shape[0] == 3:
            x = x.mean(axis=0, keepdims=True)
        elif IN_CHANNELS_TL == 3 and x.shape[0] == 1:
            x = np.repeat(x, 3, axis=0)
        else:
            raise ValueError(f"Channel mismatch for TL model: {x.shape[0]}")
    xt = torch.from_numpy(x).unsqueeze(0).to(device)
    logits = tl_model(xt)
    probs = torch.softmax(logits, dim=1).squeeze(0)
    conf, pred_idx = torch.max(probs, dim=0)
    pred_label = index_to_label[int(pred_idx.item())]
    name = CLASS_NAMES_TL.get(pred_label, str(pred_label))
    emoji = CLASS_EMOJI_TL.get(pred_label, "")
    return f"{emoji} {name}", float(conf.item())

# ========= pygame + CARLA =========
# Window and HUD fonts.
pygame.init()
width, height = 1024, 768
screen = pygame.display.set_mode((width, height))
pygame.display.set_caption("CARLA: Steering Model + Traffic-Light Overlay")
font, font_big = pygame.font.SysFont(None, 30), pygame.font.SysFont(None, 36)

# Simulator connection; loads Town05.
client = carla.Client("localhost", 2000)
client.set_timeout(25.0)
client.load_world("Town05")
world = client.get_world()
bp = world.get_blueprint_library()

# Ego vehicle: the Dodge Charger 2020 blueprint, else the first "*model3*" match.
_vehicle_choices = bp.filter("vehicle.dodge.charger_2020")
if not _vehicle_choices:
    _vehicle_choices = bp.filter("vehicle.*model3*")
vehicle_bp = _vehicle_choices[0]
spawn_point = random.choice(world.get_map().get_spawn_points())
vehicle = world.spawn_actor(vehicle_bp, spawn_point)
vehicle.set_autopilot(False)

spectator = world.get_spectator()

# Camera 1: 448x252, 145-degree FOV; its frames feed the steering model.
cam_bp = bp.find('sensor.camera.rgb')
for _attr_name, _attr_value in (
    ('image_size_x', '448'),
    ('image_size_y', '252'),
    ('fov', '145'),
    ('sensor_tick', '0.1'),
):
    cam_bp.set_attribute(_attr_name, _attr_value)
cam_transform = carla.Transform(carla.Location(x=1.5, z=2.4))
camera = world.spawn_actor(cam_bp, cam_transform, attach_to=vehicle)

# State shared between the sensor callbacks and the main loop.
camera_image = camera_np = None
camera_image2 = camera_np2 = None
pred_label_steer = "Loading..."
pred_label_tl = "N/A"
steering_value = 0.0

# Camera 2: 1024x768, 110-degree FOV; its frames feed the traffic-light classifier and the HUD.
cam_bp2 = bp.find('sensor.camera.rgb')
for _attr_name, _attr_value in (
    ('image_size_x', '1024'),
    ('image_size_y', '768'),
    ('fov', '110'),
    ('sensor_tick', '0.1'),
):
    cam_bp2.set_attribute(_attr_name, _attr_value)
cam_transform2 = carla.Transform(carla.Location(x=1.5, z=2.4))
camera2 = world.spawn_actor(cam_bp2, cam_transform2, attach_to=vehicle)

def process_image(image):
    """Sensor callback for camera 1: publish the newest frame.

    Sets ``camera_np`` to an RGB copy of the frame (used by the steering model)
    and ``camera_image`` to a pygame surface scaled to the window size.
    """
    global camera_image, camera_np
    bgra = np.frombuffer(image.raw_data, dtype=np.uint8).reshape((image.height, image.width, 4))
    rgb = bgra[:, :, 2::-1]  # drop the 4th channel and reverse B,G,R -> R,G,B
    camera_np = rgb.copy()
    # surfarray expects (width, height, channels), hence the axis swap.
    unscaled = pygame.surfarray.make_surface(rgb.swapaxes(0, 1))
    camera_image = pygame.transform.scale(unscaled, (width, height))

def traffic_light_crop(img: np.ndarray, zoom_percent: float = 50.0,
                       blackout_bottom_third: bool = True) -> np.ndarray:
    """Prepare a frame for the traffic-light classifier.

    Zooms into the image centre by ``zoom_percent``: the central 1/(1 + zoom/100)
    of each axis is cropped and resized back to the original size. By default it
    then zeroes the bottom third of the frame.

    Args:
        img: H x W x C image array. ``None`` or a non-3-D array is returned
            unchanged.
        zoom_percent: extra magnification in percent (50 -> 1.5x); 0 disables
            the zoom.
        blackout_bottom_third: zero every row from ``2*H // 3`` downward.

    Returns:
        A new array with the same shape as ``img``; the input is never modified.

    Raises:
        ValueError: if ``zoom_percent`` is negative.
    """
    if img is None or img.ndim != 3:
        return img
    if zoom_percent < 0:
        raise ValueError(f"zoom_percent must be >= 0, got {zoom_percent}")

    h, w = img.shape[:2]
    out = img

    # ---- Center zoom ----
    z = 1.0 + zoom_percent / 100.0
    new_w = max(1, int(w / z))
    new_h = max(1, int(h / z))
    x1 = max(0, (w - new_w) // 2)
    y1 = max(0, (h - new_h) // 2)
    cropped = img[y1:y1 + new_h, x1:x1 + new_w]
    if cropped.size > 0:
        out = cv2.resize(cropped, (w, h))  # cv2 dsize is (width, height)
        if out.ndim == 2:
            # cv2.resize drops a singleton channel axis; restore it to keep HxWxC.
            out = out[:, :, np.newaxis]

    # Never write into the caller's array.
    if out is img:
        out = img.copy()

    # ---- Blackout bottom third ----
    if blackout_bottom_third:
        out[(2 * h) // 3:, :, :] = 0

    return out


def process_image2(image):
    """Sensor callback for camera 2: publish the cropped frame.

    Converts the frame to RGB and applies ``traffic_light_crop``. It then sets
    ``camera_np2`` (input to the traffic-light classifier) and ``camera_image2``
    (window-sized surface drawn as the HUD background).
    """
    global camera_image2, camera_np2
    array = np.frombuffer(image.raw_data, dtype=np.uint8)
    array = array.reshape((image.height, image.width, 4))
    array = array[:, :, :3][:, :, ::-1]  # RGB
    array = traffic_light_crop(array)
    camera_np2 = array.copy()
    camera_image_raw = pygame.surfarray.make_surface(array.swapaxes(0, 1))
    camera_image2 = pygame.transform.scale(camera_image_raw, (width, height))


# Start the sensor streams only now: every function the callbacks call is defined,
# so an early frame cannot hit a NameError in the sensor thread.
camera.listen(process_image)

camera2.listen(process_image2)

clock = pygame.time.Clock()
print("Model-controlled driving started.")
print("Q/E to turn on left/right signal, R to cancel. W/S throttle up/down. P respawn. ESC to exit.")


# Controls
STEERING_HOLD_S = 0.05            # 50 ms hold window used by the main loop's steering smoothing
throttle = 0.3                    # initial throttle; W/S adjust it at runtime
prev_steering_value = 0.0         # last steering target committed by the smoothing logic
last_steering_time = time.time()  # wall-clock time of the last steering-target change

def respawn_vehicle():
    """Replace the ego vehicle and both cameras with fresh actors.

    Tears down ``camera``, ``camera2`` and ``vehicle`` on a best-effort basis,
    each step isolated so one failure does not leak the other actors. It then
    spawns a new vehicle at a random free spawn point and re-attaches both
    cameras with their original blueprints, transforms and callbacks. The
    module globals are rebound to the new actors.

    ``camera_np``/``camera_np2`` are cleared so the main loop does not act on
    frames captured by the old cameras.

    Raises:
        RuntimeError: if no spawn point on the map is free.
    """
    global vehicle, camera, camera2, camera_np, camera_np2

    # Stop and destroy each sensor independently (previously one try covered both,
    # so an error on camera leaked camera2).
    for sensor in (camera, camera2):
        try:
            if sensor.is_listening:
                sensor.stop()
        except Exception:
            pass
        try:
            sensor.destroy()
        except Exception:
            pass
    try:
        vehicle.destroy()
    except Exception:
        pass

    camera_np = None
    camera_np2 = None

    # Try spawn points in random order; try_spawn_actor returns None when the
    # point is blocked instead of raising.
    spawn_points = list(world.get_map().get_spawn_points())
    random.shuffle(spawn_points)
    vehicle_new = None
    for sp in spawn_points:
        vehicle_new = world.try_spawn_actor(vehicle_bp, sp)
        if vehicle_new is not None:
            break
    if vehicle_new is None:
        raise RuntimeError("No free spawn point available for the vehicle")
    vehicle_new.set_autopilot(False)

    camera_new = world.spawn_actor(cam_bp, cam_transform, attach_to=vehicle_new)
    camera_new.listen(process_image)

    camera_new2 = world.spawn_actor(cam_bp2, cam_transform2, attach_to=vehicle_new)
    camera_new2.listen(process_image2)

    vehicle = vehicle_new
    camera = camera_new
    camera2 = camera_new2

# Steer actually sent to CARLA. It persists across frames so that, inside the hold
# window after a prediction change, the blended value is re-sent. Previously the
# steer fell back to VehicleControl's default of 0.0 in that window.
applied_steer = 0.0
# Confidence of the latest traffic-light prediction. Initialised here because the
# HUD can be drawn before the first prediction: the camera-2 callback sets
# camera_np2 and camera_image2 from another thread, one after the other.
conf_tl = 0.0

try:
    while True:
        clock.tick(60)
        pygame.event.pump()
        keys = pygame.key.get_pressed()

        # Turn signal state: Q/E set left/right, R clears; otherwise mirror the
        # blinker currently set on the vehicle.
        if keys[pygame.K_q]:
            vehicle.set_light_state(VehicleLightState.LeftBlinker)
            turn_signal = -1
        elif keys[pygame.K_e]:
            vehicle.set_light_state(VehicleLightState.RightBlinker)
            turn_signal = 1
        elif keys[pygame.K_r]:
            vehicle.set_light_state(VehicleLightState.NONE)
            turn_signal = 0
        else:
            ls = vehicle.get_light_state()
            if ls == VehicleLightState.LeftBlinker:
                turn_signal = -1
            elif ls == VehicleLightState.RightBlinker:
                turn_signal = 1
            else:
                turn_signal = 0

        # Throttle tweak, clamped to [0.1, 1.0]
        if keys[pygame.K_w]:
            throttle = min(1.0, throttle + 0.005)
        if keys[pygame.K_s]:
            throttle = max(0.1, throttle - 0.005)

        # Respawn; also drop the previous vehicle's steering history.
        if keys[pygame.K_p]:
            print("Respawning vehicle...")
            respawn_vehicle()
            throttle = 0.3
            prev_steering_value = 0.0
            applied_steer = 0.0
            last_steering_time = time.time()
            time.sleep(0.5)
            continue

        control = carla.VehicleControl()
        control.throttle = throttle

        # ====== PREDICT (steering + traffic light) ======

        if camera_np is not None:
            # Steering model (YOUR 30.01_quaternary drives here)
            with torch.no_grad():
                img_t = image_transform_steer(camera_np).unsqueeze(0).to(device)
                # Turn signal {-1, 0, 1} -> embedding index {0, 1, 2}, shape (1,)
                signal_t = torch.tensor([turn_signal + 1], dtype=torch.long, device=device)
                out = steer_model(img_t, signal_t)
            pred_class = out.argmax(dim=1).item()
            pred_label_steer = CLASS_NAMES_STEER[pred_class]
            steering_value = steering_map[pred_class]

            # Enforce minimum turn for active signal
            if turn_signal == -1 and steering_value > -0.02:
                steering_value = -0.02
            elif turn_signal == 1 and steering_value < 0.02:
                steering_value = 0.02

            # Smoothing / hold:
            #  * to or from straight: apply the new target immediately;
            #  * target changed: apply the midpoint of old and new, restart the timer;
            #  * target unchanged: keep the midpoint until STEERING_HOLD_S has
            #    elapsed, then apply the full target.
            current_time = time.time()
            if steering_value == 0.0 or prev_steering_value == 0.0:
                applied_steer = steering_value
                prev_steering_value = steering_value
                last_steering_time = current_time
            elif steering_value != prev_steering_value:
                applied_steer = (steering_value + prev_steering_value) / 2
                prev_steering_value = steering_value
                last_steering_time = current_time
            elif current_time - last_steering_time >= STEERING_HOLD_S:
                applied_steer = steering_value
            control.steer = applied_steer

        if camera_np2 is not None:
            # Traffic-light overlay (display only; does not change control).
            # predict_traffic_light is itself decorated with @torch.no_grad().
            pred_label_tl, conf_tl = predict_traffic_light(camera_np2)

        # Apply controls
        vehicle.apply_control(control)

        # Spectator follow: 8 m behind and 3 m above the vehicle, pitched down 10 degrees
        tr = vehicle.get_transform()
        fwd = tr.get_forward_vector()
        cam_location = tr.location - fwd * 8 + carla.Location(z=3)
        cam_rotation = carla.Rotation(pitch=-10, yaw=tr.rotation.yaw)
        spectator.set_transform(carla.Transform(cam_location, cam_rotation))

        # ====== Draw HUD ======
        if camera_image2:
            screen.blit(camera_image2, (0, 0))

            # Steering label
            label_turn = font.render(f"Steer class: {pred_label_steer}", True, (255, 255, 0))
            label_steer = font.render(f"Steering: {steering_value:.3f}", True, (0, 255, 0))
            label_speed = font.render(f"Throttle: {throttle:.2f}", True, (0, 200, 255))
            screen.blit(label_turn, (20, 20))
            screen.blit(label_steer, (20, 50))
            screen.blit(label_speed, (20, 80))

            # Traffic-light banner
            banner = pygame.Surface((width, 46), pygame.SRCALPHA)
            banner.fill((0, 0, 0, 140))
            screen.blit(banner, (0, height - 46))
            tl_text = font_big.render(
                f"Traffic Light: {pred_label_tl} ({conf_tl*100:.1f}%)",
                True, (255, 255, 255)
            )
            screen.blit(tl_text, (20, height - 40))

            pygame.display.flip()

        # Quit on window close, or on ESC. ESC is checked every frame; previously it
        # was only seen when some other event happened to be queued.
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                raise KeyboardInterrupt
        if keys[pygame.K_ESCAPE]:
            raise KeyboardInterrupt

except KeyboardInterrupt:
    print("Exiting and cleaning up...")
finally:
    # Best-effort teardown. Stop and destroy BOTH cameras (camera2 used to be
    # leaked, with its callback still running after pygame.quit()), then the
    # vehicle. Each step is isolated so one failure does not skip the rest.
    for sensor in (camera, camera2):
        try:
            sensor.stop()
        except Exception:
            pass
        try:
            sensor.destroy()
        except Exception:
            pass
    try:
        vehicle.destroy()
    except Exception:
        pass
    pygame.quit()


# Output from a previous run of this cell:
# Model-controlled driving started.
# Q/E to turn on left/right signal, R to cancel. W/S throttle up/down. P respawn. ESC to exit.
# Exiting and cleaning up...
