In [1]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from matplotlib.widgets import Button
from PIL import Image

# ===========================
# CONFIGURATION
# ===========================
# Run on CUDA when a GPU is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Steering class labels. `predict` maps the steering head's argmax index into
# this list, so position i is the label of output i.
# NOTE(review): the order must match the label encoding used at training time —
# do not sort or reorder.
CLASS_NAMES = [
    "L3", "L2", "L1", "L0", "S0", "R0", "R1", "R2", "R3",
    "S1", "S2", "S3", "L4", "R4", "S4",
]

# ===========================
# IMAGE TRANSFORMS
# ===========================
# Per-frame preprocessing before the forward pass:
#   1. resize to 224x224,
#   2. convert to a float tensor with values in [0, 1],
#   3. normalize each channel with the ImageNet mean/std used by torchvision's
#      pretrained ResNet weights.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# ===========================
# MULTITASK MODEL
# ===========================
class MultiTaskModel(nn.Module):
    """ResNet-18 image encoder plus a turn-signal embedding, with two heads.

    The 512-d ResNet-18 features (classifier replaced by identity) are
    concatenated with a 16-d embedding of the turn signal. The result goes
    through a shared 128-unit ReLU layer and then into two heads:

    * ``steering_head``: logits over ``num_classes`` steering classes.
    * ``velocity_head``: a single regressed velocity value.

    The attribute names (``cnn``, ``turn_embed``, ``shared_fc``,
    ``steering_head``, ``velocity_head``) and the layer order inside
    ``shared_fc`` determine the ``state_dict`` keys. They must stay as-is to
    load existing checkpoints.
    """

    def __init__(self, num_classes=15, pretrained=True):
        """Build the network.

        Args:
            num_classes: Number of steering classes (output size of
                ``steering_head``).
            pretrained: If True (the default, matching earlier behavior),
                initialize the backbone with torchvision's ImageNet weights,
                which may require a download. Pass False when a full
                checkpoint is loaded right after construction, since those
                weights are overwritten anyway.
        """
        super().__init__()
        # NOTE(review): ``pretrained=`` is deprecated in newer torchvision in
        # favour of ``weights=``; kept here for compatibility with older versions.
        resnet = models.resnet18(pretrained=pretrained)
        resnet.fc = nn.Identity()  # expose the 512-d pooled features instead of ImageNet logits
        self.cnn = resnet
        # 3 embedding rows: the viewer feeds turn signal + 1, i.e. indices 0..2.
        self.turn_embed = nn.Embedding(3, 16)
        self.shared_fc = nn.Sequential(
            nn.Linear(512 + 16, 128),
            nn.ReLU()
        )
        self.steering_head = nn.Linear(128, num_classes)
        self.velocity_head = nn.Linear(128, 1)

    def forward(self, image, turn_signal):
        """Run both heads.

        Args:
            image: Image batch accepted by ResNet-18, e.g. ``(B, 3, 224, 224)``
                as produced by ``transform``.
            turn_signal: Long tensor of shape ``(B,)`` with embedding indices
                in ``[0, 3)``.

        Returns:
            Tuple ``(steering_logits, velocity)`` with shapes
            ``(B, num_classes)`` and ``(B, 1)``.
        """
        x_img = self.cnn(image)
        x_signal = self.turn_embed(turn_signal)
        x = torch.cat((x_img, x_signal), dim=1)
        x = self.shared_fc(x)
        steering_logits = self.steering_head(x)
        velocity = self.velocity_head(x)
        return steering_logits, velocity

# ===========================
# LOAD DATA
# ===========================
# NOTE(review): these paths resolve against the current working directory
# (for a notebook, the directory the kernel was started in), not this file.
DATA_DIR = Path("../CarlaData")


def _require_file(path):
    """Return ``path`` if it is an existing file, else raise naming its absolute location."""
    if not path.is_file():
        raise FileNotFoundError(
            f"Missing data file: {path.resolve()} (working directory: {Path.cwd()})"
        )
    return path


# allow_pickle=True lets np.load unpickle object arrays, which can execute
# arbitrary code -- only use it on trusted data files.
images = np.load(_require_file(DATA_DIR / "val_images.npy"), allow_pickle=True)
angles = np.load(_require_file(DATA_DIR / "val_angles.npy"))
turn_signals = np.load(_require_file(DATA_DIR / "val_turn_signals.npy"))

# Ground-truth velocities are optional. When the file is absent, an empty array
# makes the viewer's `len(velocities) > idx` check report "no GT velocity".
_velocities_path = DATA_DIR / "val_velocities.npy"
velocities = np.load(_velocities_path) if _velocities_path.is_file() else np.empty(0)

# The viewer wraps indices with `% len(images)` and reads angles/turn signals
# at the same index. Fail early instead of with ZeroDivisionError/IndexError
# in the middle of browsing.
if len(images) == 0:
    raise ValueError("val_images.npy contains no samples")
if min(len(angles), len(turn_signals)) < len(images):
    raise ValueError(
        f"Ground truth shorter than images: {len(images)} images, "
        f"{len(angles)} angles, {len(turn_signals)} turn signals"
    )

# ===========================
# LOAD MODEL
# ===========================
# Build the network and overwrite its weights with the trained checkpoint.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model = MultiTaskModel()
model.load_state_dict(
    torch.load("../Models/30.01_quinary_model_final.pth", map_location=device)
)
# Module.to() and Module.eval() both return the module itself, so this is
# still the same object, now on `device` and in inference mode.
model = model.to(device).eval()

# ===========================
# PREDICT FUNCTION
# ===========================
def predict(model, image, turn_signal, device='cuda'):
    """Run one forward pass and return the predicted steering label and velocity.

    Args:
        model: Module called as ``model(image_batch, turn_batch)`` and returning
            ``(steering_logits, velocity)``. It is moved to ``device`` and put
            in eval mode in place.
        image: PIL image, or NumPy array accepted by ``Image.fromarray``
            (presumably ``uint8`` HxWxC -- TODO confirm the stored dtype and
            channel order). It is converted to 3-channel RGB before ``transform``.
        turn_signal: Turn signal as used by the viewer (-1 left, 0 none,
            1 right). It is shifted by +1 to index the 3-row turn embedding.
        device: Device to run on. The ``'cuda'`` default fails on CPU-only
            machines, so callers should pass the configured device explicitly.

    Returns:
        ``(label, velocity)``: the ``CLASS_NAMES`` entry of the arg-max steering
        logit and the predicted velocity as a Python float.
    """
    model.eval()
    model.to(device)

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # The normalization in `transform` has exactly 3 channels. RGBA or
    # grayscale inputs would otherwise fail inside Normalize. This is a no-op
    # for images that are already RGB.
    image = image.convert("RGB")

    img_tensor = transform(image).unsqueeze(0).to(device)
    # Shape (1,): one embedding index for the single-image batch.
    turn_tensor = torch.tensor([int(turn_signal) + 1], dtype=torch.long, device=device)

    with torch.no_grad():
        steering_logits, velocity = model(img_tensor, turn_tensor)
        pred_class = steering_logits.argmax(dim=1).item()
        pred_velocity = velocity.item()

    return CLASS_NAMES[pred_class], pred_velocity

# ===========================
# VIEWER SETUP
# ===========================
fig, ax = plt.subplots()
# Leave the bottom 20% of the figure free for the navigation buttons.
fig.subplots_adjust(bottom=0.2)
# Index of the sample currently shown. It is boxed in a one-element list so the
# button callbacks can mutate it without a `global` statement.
index = [0]

def show_image(idx):
    """Draw sample ``idx`` with its ground truth and the model's prediction.

    Clears ``ax``, displays ``images[idx]``, runs ``predict`` on it with the
    sample's turn signal, and writes ground truth and prediction into the title.

    Args:
        idx: Index into the module-level ``images``, ``angles`` and
            ``turn_signals`` arrays (and ``velocities`` when it is long enough).
    """
    ax.clear()
    img = images[idx]
    steering = angles[idx]
    # Ground-truth velocity may be missing or shorter than the image array.
    velocity = velocities[idx] if len(velocities) > idx else None
    turn = int(turn_signals[idx])
    turn_label = "Left" if turn == -1 else "Right" if turn == 1 else "None"

    pred_label, pred_velocity = predict(model, img, turn, device)

    ax.imshow(img)

    gt_parts = [f"GT Steering: {steering:.2f}"]
    if velocity is not None:
        gt_parts.append(f"GT Vel: {velocity:.2f}")
    # Plain-text labels: matplotlib's default font (DejaVu Sans) lacks the emoji
    # used previously, which rendered as empty boxes with missing-glyph warnings.
    # Split over two lines so the title fits within the figure width.
    ax.set_title(
        f"Idx: {idx} | " + " | ".join(gt_parts) + f" | Signal: {turn_label}\n"
        f"Pred Steering: {pred_label} | Pred Vel: {pred_velocity:.2f}"
    )
    ax.axis('off')
    # Redraw this specific figure (plt.draw() targets whichever figure is
    # current) once control returns to the GUI event loop.
    fig.canvas.draw_idle()

def _step_image(offset):
    """Shift the current index by ``offset``, wrapping around, and redraw."""
    index[0] = (index[0] + offset) % len(images)
    show_image(index[0])


def next_image(event):
    """Button callback: show the next sample, wrapping to the first."""
    _step_image(1)


def prev_image(event):
    """Button callback: show the previous sample, wrapping to the last."""
    _step_image(-1)

# ===========================
# BUTTON SETUP
# ===========================
# Button axes are [left, bottom, width, height] in figure coordinates, placed
# in the strip below the image that subplots_adjust(bottom=0.2) keeps free.
axprev = fig.add_axes([0.1, 0.05, 0.2, 0.075])
axnext = fig.add_axes([0.7, 0.05, 0.2, 0.075])
# Keep module-level references to the buttons: matplotlib widgets stop
# responding once they are garbage-collected.
bprev = Button(axprev, 'Previous')
bnext = Button(axnext, 'Next')
bprev.on_clicked(prev_image)
bnext.on_clicked(next_image)

# ===========================
# LAUNCH VIEWER
# ===========================
# Draw the first sample, then hand control to matplotlib's event loop so the
# Previous/Next buttons can drive the viewer.
# NOTE(review): this file appears to come from a Jupyter notebook. Under the
# default inline backend the figure is a static image and the buttons do
# nothing; an interactive backend (e.g. `%matplotlib widget`) is presumably
# required -- confirm.
show_image(index[0])
plt.show()


FileNotFoundError: [Errno 2] No such file or directory: '../CarlaData/val_images.npy'