In [None]:
import carla
import pygame
import time
import random
import numpy as np
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
from carla import VehicleLightState

# === Model Class Labels ===
# Index i is the label of model output class i: seven left-turn intensities
# (hardest first), straight ahead, then the same seven intensities to the
# right (mildest first), for 15 classes in total.
_TURN_INTENSITIES = ["Hardest", "Harder", "Hard", "Medium", "Light", "Slight", "Minimal"]
CLASS_NAMES = (
    [f"Left {level}" for level in _TURN_INTENSITIES]
    + ["No Turning"]
    + [f"Right {level}" for level in reversed(_TURN_INTENSITIES)]
)

# === PyTorch Model Definition ===
class SteeringClassifier(nn.Module):
    def __init__(self, num_classes=15):
        super(SteeringClassifier, self).__init__()
        resnet = models.resnet18(pretrained=True)
        resnet.fc = nn.Identity()
        self.cnn = resnet
        self.turn_embed = nn.Embedding(3, 16)
        self.fc = nn.Sequential(
            nn.Linear(512 + 16, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, image, turn_signal):
        x_img = self.cnn(image)
        x_signal = self.turn_embed(turn_signal)
        x = torch.cat((x_img, x_signal), dim=1)
        return self.fc(x)

# === Initialize PyTorch Model ===
# Use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = SteeringClassifier(num_classes=15)
# map_location lets a checkpoint saved on a GPU machine load on a CPU-only one.
model.load_state_dict(
    torch.load("../Models/30.01_quaternary_model_final.pth", map_location=device)
)
model.eval().to(device)  # eval(): batch-norm uses stored statistics during driving

# Per-frame preprocessing: RGB uint8 HxWx3 array -> 224x224 float tensor,
# normalized with the standard ImageNet channel mean/std.
image_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# === Pygame Setup ===
# Window that shows the upscaled camera feed with the HUD text on top.
pygame.init()
width = 800
height = 600
screen = pygame.display.set_mode((width, height))
pygame.display.set_caption("CARLA Model-Controlled Driving")

# === Connect to CARLA ===
# Expects a CARLA server on localhost:2000; loading a town can be slow,
# hence the generous RPC timeout (seconds).
client = carla.Client("localhost", 2000)
client.set_timeout(25.0)
client.load_world("Town05")
world = client.get_world()
blueprint_library = world.get_blueprint_library()

# === Spawn Vehicle ===
# The ego car is driven by the model, so CARLA's autopilot stays off.
vehicle_bp = blueprint_library.filter("vehicle.dodge.charger_2020")[0]
candidate_spawns = world.get_map().get_spawn_points()
spawn_point = random.choice(candidate_spawns)
vehicle = world.spawn_actor(vehicle_bp, spawn_point)
vehicle.set_autopilot(False)

# === Spectator Setup ===
# Repositioned behind the car on every iteration of the main loop.
spectator = world.get_spectator()

# === Attach Camera Sensor ===
# Forward RGB camera mounted on the car; its frames feed both the model and
# the pygame display via the process_image callback.
camera_bp = blueprint_library.find('sensor.camera.rgb')
for attr_name, attr_value in (
    ('image_size_x', '448'),
    ('image_size_y', '252'),
    ('fov', '145'),
    ('sensor_tick', '0.1'),
):
    camera_bp.set_attribute(attr_name, attr_value)
cam_transform = carla.Transform(carla.Location(x=1.5, z=2.4))
camera = world.spawn_actor(camera_bp, cam_transform, attach_to=vehicle)

# Shared variables: written by the camera callback, read by the main loop.
camera_image = None  # latest frame as a pygame Surface scaled for display
camera_np = None     # latest frame as an RGB uint8 numpy array for the model
pred_label = "Loading..."

def process_image(image):
    """Camera callback: convert a CARLA frame into the shared model/display buffers.

    Args:
        image: the sensor image passed by ``camera.listen``; ``raw_data`` is
            read as ``height x width x 4`` uint8 pixels (BGRA channel order).

    Side effects:
        Sets the globals ``camera_np`` (RGB uint8 array, ``height x width x 3``)
        and ``camera_image`` (pygame Surface scaled to the window size).
        ``camera_np`` is assigned first, so whenever ``camera_image`` is set
        ``camera_np`` is too.
    """
    global camera_image, camera_np
    pixels = np.frombuffer(image.raw_data, dtype=np.uint8)
    pixels = pixels.reshape((image.height, image.width, 4))
    rgb = pixels[:, :, :3][:, :, ::-1]  # drop alpha, reorder BGR -> RGB
    camera_np = rgb.copy()
    # surfarray expects (x, y) indexing, hence the axis swap.
    surface = pygame.surfarray.make_surface(rgb.swapaxes(0, 1))
    # Scale to the actual window size instead of a hard-coded 800x600.
    camera_image = pygame.transform.scale(surface, (width, height))

# Start streaming frames into process_image.
camera.listen(process_image)

# === Clock ===
clock = pygame.time.Clock()

# Usage banner: list every control the main loop actually handles.
print("Model-controlled driving started.")
print("Q/E to turn on left/right signal, R to cancel.")
print("W/S to raise/lower throttle, P to respawn the vehicle.")
print("ESC or close window to exit.")

# === Steering Mapping for 15 Classes ===
# Class id -> CARLA steer command; ids line up with CLASS_NAMES, so the
# "Left ..." classes get negative values and the "Right ..." classes positive.
# NOTE(review): the two sides are not mirror images: class 4 ("Left Light")
# maps to -0.20 while class 10 ("Right Light") maps to 0.15. Confirm this
# asymmetry is intentional.
steering_map = dict(enumerate([
    -0.8, -0.6, -0.4, -0.25, -0.20, -0.1, -0.05,
    0.0,
    0.05, 0.1, 0.15, 0.25, 0.4, 0.6, 0.8,
]))

def respawn_vehicle():
    """Replace the ego vehicle and its camera with fresh actors at a random free spawn point.

    The new vehicle is spawned *before* the old actors are destroyed. If every
    spawn point is blocked, the old vehicle and camera are left untouched, so
    the module-level ``vehicle``/``camera`` globals stay valid for cleanup.

    Side effects:
        Rebinds the globals ``vehicle`` and ``camera``. Clears the buffered
        frame (``camera_np``/``camera_image``) so the model is never fed an
        image captured at the old location.

    Raises:
        RuntimeError: if no spawn point on the map is free.
    """
    global vehicle, camera, camera_np, camera_image

    # try_spawn_actor returns None (instead of raising) when the point is
    # occupied, so try the spawn points in random order until one is free.
    spawn_points = list(world.get_map().get_spawn_points())
    random.shuffle(spawn_points)
    vehicle_new = None
    for spawn_point in spawn_points:
        vehicle_new = world.try_spawn_actor(vehicle_bp, spawn_point)
        if vehicle_new is not None:
            break
    if vehicle_new is None:
        raise RuntimeError("Respawn failed: no free spawn point available")
    vehicle_new.set_autopilot(False)

    # Tear down the old camera (stop its callback first), then the old vehicle.
    if camera.is_listening:
        camera.stop()
    camera.destroy()
    vehicle.destroy()
    vehicle = vehicle_new

    # Discard the last frame from the old camera.
    camera_np = None
    camera_image = None

    # Reattach the camera to the new vehicle.
    camera_new = world.spawn_actor(camera_bp, cam_transform, attach_to=vehicle_new)
    camera_new.listen(process_image)
    camera = camera_new


# === Initial Control Variables ===
throttle = 0.3
THROTTLE_STEP = 0.005    # throttle change per frame while W/S is held
MIN_THROTTLE = 0.1       # S stops lowering the throttle at this value
MAX_THROTTLE = 1.0       # top of CARLA's throttle range
STEERING_HOLD_S = 0.05   # 50 ms: how long a smoothed value is held before the new command is applied
MIN_SIGNAL_STEER = 0.02  # minimum steering magnitude toward an active blinker

# Steering-smoothing state. It must persist across frames: re-initialising it
# inside the loop made the averaging and hold branches unreachable.
prev_steering_value = 0.0         # last raw (post-blinker) model command
applied_steer = 0.0               # value actually sent to the vehicle
last_steering_time = time.time()  # when prev_steering_value last changed
steering_value = 0.0              # shown on the HUD even before the first prediction
hud_font = pygame.font.SysFont(None, 30)  # created once, not every frame

# === Main Loop ===
try:
    while True:
        clock.tick(60)
        pygame.event.pump()
        keys = pygame.key.get_pressed()

        # Turn signal input (Q/E/R); with no key held, read back the blinker state.
        if keys[pygame.K_q]:
            vehicle.set_light_state(VehicleLightState.LeftBlinker)
            turn_signal = -1
        elif keys[pygame.K_e]:
            vehicle.set_light_state(VehicleLightState.RightBlinker)
            turn_signal = 1
        elif keys[pygame.K_r]:
            vehicle.set_light_state(VehicleLightState.NONE)
            turn_signal = 0
        else:
            light_state = vehicle.get_light_state()
            if light_state == VehicleLightState.LeftBlinker:
                turn_signal = -1
            elif light_state == VehicleLightState.RightBlinker:
                turn_signal = 1
            else:
                turn_signal = 0

        # Throttle input (W/S), kept within [MIN_THROTTLE, MAX_THROTTLE].
        if keys[pygame.K_w]:
            throttle = min(MAX_THROTTLE, throttle + THROTTLE_STEP)
        if keys[pygame.K_s] and throttle > MIN_THROTTLE:
            throttle = max(MIN_THROTTLE, throttle - THROTTLE_STEP)

        # Reset vehicle on 'P'
        if keys[pygame.K_p]:
            print("Respawning vehicle...")
            respawn_vehicle()
            throttle = 0.3
            prev_steering_value = 0.0
            applied_steer = 0.0
            last_steering_time = time.time()
            time.sleep(1)  # Prevent rapid respawning if key is held down
            continue  # Skip rest of loop for this frame

        # === Predict and Control ===
        frame = camera_np  # snapshot; the camera callback may replace it concurrently
        if frame is not None:
            with torch.no_grad():
                img = image_transform(frame).unsqueeze(0).to(device)
                # Embedding indices: turn_signal -1/0/1 -> 0/1/2.
                signal = torch.tensor([turn_signal + 1], dtype=torch.long, device=device)
                output = model(img, signal)
            pred_class = output.argmax(dim=1).item()
            pred_label = CLASS_NAMES[pred_class]
            steering_value = steering_map[pred_class]

            # Enforce a minimum turn toward an active blinker.
            if turn_signal == -1:
                steering_value = min(steering_value, -MIN_SIGNAL_STEER)
            elif turn_signal == 1:
                steering_value = max(steering_value, MIN_SIGNAL_STEER)

            # === Steering control logic ===
            #  * straight commands, and the first turn after straight, apply directly;
            #  * a changed command is averaged with the previous one, and that
            #    average is held;
            #  * once the same command has persisted STEERING_HOLD_S, it is applied as-is.
            now = time.time()
            if steering_value == 0.0 or prev_steering_value == 0.0:
                applied_steer = steering_value
                prev_steering_value = steering_value
                last_steering_time = now
            elif steering_value != prev_steering_value:
                applied_steer = (steering_value + prev_steering_value) / 2
                prev_steering_value = steering_value
                last_steering_time = now
            elif now - last_steering_time >= STEERING_HOLD_S:
                applied_steer = steering_value
            # else: still inside the hold window; keep applying the smoothed value

        control = carla.VehicleControl()
        control.throttle = throttle
        control.steer = applied_steer
        vehicle.apply_control(control)

        # === Spectator follows vehicle ===
        car_transform = vehicle.get_transform()
        forward_vector = car_transform.get_forward_vector()
        cam_location = car_transform.location - forward_vector * 8 + carla.Location(z=3)
        cam_rotation = carla.Rotation(pitch=-10, yaw=car_transform.rotation.yaw)
        spectator.set_transform(carla.Transform(cam_location, cam_rotation))

        # === Display Feed + Prediction ===
        if camera_image is not None:
            screen.blit(camera_image, (0, 0))
            hud_lines = (
                (f"Predicted: {pred_label}", (255, 255, 0)),
                (f"Steering: {steering_value:.3f}", (0, 255, 0)),
                (f"Throttle: {throttle:.1f}", (0, 200, 255)),
            )
            for row, (text, color) in enumerate(hud_lines):
                screen.blit(hud_font.render(text, True, color), (20, 20 + 30 * row))
            pygame.display.flip()

        # Quit on window close or ESC. ESC is read from the key state directly,
        # so it works even when no event is queued this frame.
        if keys[pygame.K_ESCAPE] or any(event.type == pygame.QUIT for event in pygame.event.get()):
            raise KeyboardInterrupt

except KeyboardInterrupt:
    print("Exiting and cleaning up...")

finally:
    # Tear down the sensor before the vehicle it is attached to.
    if camera.is_listening:
        camera.stop()
    camera.destroy()
    vehicle.destroy()
    pygame.quit()


Model-controlled driving started.
Q/E to turn on left/right signal, R to cancel.
ESC or close window to exit.
Exiting and cleaning up...
