# In [None]:
import carla
import numpy as np
import cv2
import pygame
import random
import torch
import torch.nn as nn
from torchvision import models, transforms
from carla import VehicleLightState
import time

# === Pygame setup ===
# Window that shows the upscaled camera feed and the turn-signal overlay.
pygame.init()
width, height = 800, 600
screen = pygame.display.set_mode((width, height))
pygame.display.set_caption("CARLA PyTorch Driving")

# === Connect to CARLA ===
client = carla.Client('localhost', 2000)
client.set_timeout(25.0)
client.load_world("Town04")
world = client.get_world()


# === Enable synchronous mode ===
def _apply_synchronous_settings(carla_world, delta_seconds):
    """Switch *carla_world* to fixed-step synchronous mode.

    Afterwards the server only advances when the client calls
    ``world.tick()``; each tick covers ``delta_seconds`` of simulated time.
    Returns the settings object that was applied.
    """
    world_settings = carla_world.get_settings()
    world_settings.synchronous_mode = True
    world_settings.fixed_delta_seconds = delta_seconds
    carla_world.apply_settings(world_settings)
    return world_settings


settings = _apply_synchronous_settings(world, 0.05)

# === Spawn vehicle ===
blueprint_library = world.get_blueprint_library()
vehicle_bp = blueprint_library.find('vehicle.tesla.model3')

# world.spawn_actor() raises when the spawn point is blocked rather than
# returning None, so a None-check after it can never fire. try_spawn_actor()
# returns None on failure: walk the map's spawn points in random order and
# take the first free one.
spawn_points = list(world.get_map().get_spawn_points())
random.shuffle(spawn_points)

vehicle = None
for spawn_point in spawn_points:
    vehicle = world.try_spawn_actor(vehicle_bp, spawn_point)
    if vehicle is not None:
        break

if vehicle is None:
    # Synchronous mode was enabled earlier and this failure happens before the
    # main loop's cleanup is in place; drop back to asynchronous mode so the
    # server is not left waiting for ticks from a client that is exiting.
    fallback_settings = world.get_settings()
    fallback_settings.synchronous_mode = False
    fallback_settings.fixed_delta_seconds = None
    world.apply_settings(fallback_settings)
    raise RuntimeError("Could not spawn vehicle due to collisions.")

# === Attach RGB camera ===
# 448x252 RGB image, 145° field of view, one capture every 0.1 s of
# simulation time (CARLA's `sensor_tick`).
camera_bp = blueprint_library.find('sensor.camera.rgb')
for attr_name, attr_value in (
    ('image_size_x', '448'),
    ('image_size_y', '252'),
    ('fov', '145'),
    ('sensor_tick', '0.1'),
):
    camera_bp.set_attribute(attr_name, attr_value)

# Mount point relative to the parent vehicle: 1.5 m forward, 2.4 m up.
camera_mount = carla.Location(x=1.5, z=2.4)
camera_transform = carla.Transform(camera_mount)
camera = world.spawn_actor(camera_bp, camera_transform, attach_to=vehicle)

# === Spectator follow ===
spectator = world.get_spectator()
def update_spectator():
    """Move the simulator's spectator view to a chase position behind the vehicle.

    The eye point is 8 units back along the vehicle's forward vector and
    3 units up. It shares the vehicle's yaw and looks 10° downward.
    """
    vehicle_tf = vehicle.get_transform()
    behind = vehicle_tf.get_forward_vector() * 8
    eye = vehicle_tf.location - behind + carla.Location(z=3)
    look = carla.Rotation(pitch=-10, yaw=vehicle_tf.rotation.yaw)
    spectator.set_transform(carla.Transform(eye, look))

# === Globals ===
# Written by the camera callback, read by the main loop.
camera_image = None               # latest frame as a window-sized pygame Surface
signal_indicator = "NONE"         # overlay label: "LEFT", "RIGHT" or "NONE"
control = carla.VehicleControl()  # reused control object applied each frame

# === Display helper ===
def process_image(image):
    """Convert a CARLA camera frame into a window-sized pygame surface.

    The raw buffer holds 4 bytes per pixel with the first three in BGR order.
    They are reordered to RGB and transposed to the (width, height, 3) layout
    pygame.surfarray expects. The result is stored in ``camera_image``.
    """
    global camera_image
    pixels = np.frombuffer(image.raw_data, dtype=np.uint8)
    pixels = pixels.reshape((image.height, image.width, 4))
    rgb = pixels[:, :, 2::-1]  # channels 2,1,0 -> B,G,R reversed to R,G,B
    frame = pygame.surfarray.make_surface(rgb.transpose(1, 0, 2))
    camera_image = pygame.transform.scale(frame, (width, height))

# === Model Definition ===
class DrivingModel(nn.Module):
    """Steering regressor: ResNet-18 image features fused with a turn-signal value.

    The attribute names (``cnn_backbone``, ``signal_fc``, ``combined_fc``) and
    the layer order inside each ``nn.Sequential`` determine the state-dict
    keys, so they must stay in sync with the saved checkpoint.
    """

    def __init__(self):
        super().__init__()
        # ResNet-18 minus its final fc layer; it ends in global average pooling,
        # so the backbone yields (B, 512, 1, 1) features.
        backbone_layers = list(models.resnet18(pretrained=False).children())
        self.cnn_backbone = nn.Sequential(*backbone_layers[:-1])

        # Lifts the single signal value to a 32-dim embedding.
        self.signal_fc = nn.Sequential(nn.Linear(1, 32), nn.ReLU())

        # Head mapping the concatenated 512 + 32 features to one output.
        self.combined_fc = nn.Sequential(
            nn.Linear(512 + 32, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, img, signal):
        """Return a (B, 1) tensor of raw steering predictions.

        Args:
            img: (B, 3, H, W) image batch (normalization presumably matching
                training — confirm).
            signal: (B, 1) float tensor of turn-signal values.
        """
        feats = self.cnn_backbone(img)
        feats = feats.view(feats.size(0), -1)      # (B, 512)
        embedded = self.signal_fc(signal)          # (B, 32)
        fused = torch.cat((feats, embedded), dim=1)  # (B, 544)
        return self.combined_fc(fused)

# === Load Model ===
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DrivingModel().to(device)
# map_location remaps every stored tensor onto `device`, so the checkpoint
# loads regardless of the device it was saved from.
checkpoint_path = "../output/best_steering_model.pth"  # relative to the working directory
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
# Inference only: eval() makes the backbone's BatchNorm layers use running stats.
model.eval()

# === Image Preprocessing ===
# HxWx3 uint8 RGB array -> PIL image -> 224x224 -> float tensor in [0, 1]
# -> per-channel normalization with the standard ImageNet mean/std.
# NOTE(review): must match the preprocessing used at training time — confirm.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# === Inference callback ===
def _turn_signal_value():
    """Encode the vehicle's blinker state as the model's signal input.

    Returns -1 when the left blinker bit is set, 1 when only the right one is,
    and 0 otherwise (left takes precedence if both are set).
    """
    lights = vehicle.get_light_state()
    if lights & VehicleLightState.LeftBlinker:
        return -1
    if lights & VehicleLightState.RightBlinker:
        return 1
    return 0


def on_image(image):
    """Camera listener: predict steering from the frame and drive the vehicle.

    Feeds the RGB frame and the current blinker value to the model, clamps
    the prediction to [-1, 1], and applies it with a fixed 0.4 throttle and
    no brake. Then refreshes the spectator view and the display surface.
    """
    bgra = np.frombuffer(image.raw_data, dtype=np.uint8).reshape((image.height, image.width, 4))
    rgb = bgra[:, :, :3][:, :, ::-1]  # drop alpha, BGR -> RGB

    signal = _turn_signal_value()

    img_tensor = transform(rgb).unsqueeze(0).to(device)
    signal_tensor = torch.tensor([[signal]], dtype=torch.float32).to(device)

    with torch.no_grad():
        raw_steer = model(img_tensor, signal_tensor).item()
    steer = float(np.clip(raw_steer, -1.0, 1.0))

    print(f"[DEBUG] Signal: {signal}, Predicted steer: {steer:.3f}")

    control.steer, control.throttle, control.brake = steer, 0.4, 0.0
    vehicle.apply_control(control)

    update_spectator()
    process_image(image)

# === Start camera listener and run the main loop ===
# The listener is registered inside the try so the cleanup below always runs.
try:
    camera.listen(on_image)

    clock = pygame.time.Clock()
    # SysFont does a system font lookup; build it once rather than every frame.
    font = pygame.font.SysFont("Arial", 36)
    running = True
    while running:
        world.tick()

        # Drain the event queue before sampling key state. get_pressed()
        # reflects the state as of the last event pump, and ESC is detected
        # from its own KEYDOWN event rather than from a stale key snapshot.
        for event in pygame.event.get():
            if event.type == pygame.QUIT or (
                event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE
            ):
                running = False
        keys = pygame.key.get_pressed()

        # Turn signals: Q = left, E = right, R = clear. on_image reads the
        # blinker back through vehicle.get_light_state().
        if keys[pygame.K_q]:
            vehicle.set_light_state(VehicleLightState.LeftBlinker)
            signal_indicator = "LEFT"
        elif keys[pygame.K_e]:
            vehicle.set_light_state(VehicleLightState.RightBlinker)
            signal_indicator = "RIGHT"
        elif keys[pygame.K_r]:
            vehicle.set_light_state(VehicleLightState.NONE)
            signal_indicator = "NONE"

        # Draw the latest camera frame (set by the camera callback) and overlay.
        if camera_image is not None:
            screen.blit(camera_image, (0, 0))
            text = font.render(f"Signal Indicator: {signal_indicator}", True, (255, 255, 255))
            screen.blit(text, (20, 20))
            pygame.display.flip()

        clock.tick(30)

finally:
    # Leave synchronous mode first. If actor teardown below fails, the server
    # must not be left blocked waiting for ticks that never arrive. Start from
    # the current settings so unrelated options are not reset to defaults.
    exit_settings = world.get_settings()
    exit_settings.synchronous_mode = False
    exit_settings.fixed_delta_seconds = None
    world.apply_settings(exit_settings)
    try:
        camera.stop()
        camera.destroy()
        vehicle.destroy()
    finally:
        pygame.quit()
    print("Shutdown complete.")
