In [1]:
import cv2
import torch
import numpy as np
from PIL import Image
from torch2trt import torch2trt

from jetracer.nvidia_racecar import NvidiaRacecar
from jetcam.csi_camera import CSICamera

from utils import preprocess 


In [2]:


# All models and input tensors in this notebook are placed on the GPU.
device = torch.device('cuda')
# Classifier labels, indexed by the argmax of the classification model's output.
CLASSES = ['drive', 'stop']

In [4]:
# Road-following model
import torchvision


def _load_resnet18(weights_path, num_outputs):
    """Build a ResNet-18 with a ``num_outputs``-wide head and load ``weights_path`` into it.

    Returns the model moved to ``device``, in eval mode, with fp16 parameters,
    ready for TensorRT conversion.
    """
    model = torchvision.models.resnet18(pretrained=False)
    model.fc = torch.nn.Linear(512, num_outputs)
    model = model.to(device).eval().half()
    # map_location lets a checkpoint saved on a different device (CPU or another GPU
    # index) load here; load_state_dict casts the tensors to the model's fp16 params.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    return model


# Two regression outputs: update_display reads them as the (x, y) target.
follow_model = _load_resnet18('road_following_model_11_6.pth', 2)

# Classification model: one output per label in CLASSES.
classification_model = _load_resnet18('drive_stop_model_10_6.pth', len(CLASSES))

# TensorRT optimization: both engines are traced with one 3x224x224 fp16 frame.
dummy_input = torch.zeros((1, 3, 224, 224)).to(device).half()
follow_model_trt = torch2trt(follow_model, [dummy_input], fp16_mode=True)
classification_model_trt = torch2trt(classification_model, [dummy_input], fp16_mode=True)


In [12]:
# Persist the TensorRT engines under the exact filenames the "run when already optimized"
# cell loads, so a later session can skip the slow torch2trt conversion.
torch.save(follow_model_trt.state_dict(), 'road_following_model_11_6trt.pth')
torch.save(classification_model_trt.state_dict(), 'classification_model_trt_12_06.pth')

In [3]:
# Hardware handles. Frames are captured at 224x224, the input size the TensorRT engines were traced with.
camera = CSICamera(capture_fps=30, height=224, width=224)
car = NvidiaRacecar()


In [4]:
# Run this instead of the conversion cell when the TensorRT engines are already saved to disk.
from torch2trt import TRTModule

follow_model_trt = TRTModule()
follow_model_trt.load_state_dict(torch.load('road_following_model_11_6trt.pth'))

classification_model_trt = TRTModule()
classification_model_trt.load_state_dict(torch.load('classification_model_trt_12_06.pth'))


<All keys matched successfully>

In [5]:
def get_classification(image_tensor, model=None, classes=None):
    """Return the label predicted for one preprocessed frame.

    Args:
        image_tensor: preprocessed input for the classifier. Because of ``.item()`` below,
            the argmax over dim 1 must yield a single element (i.e. batch size 1).
        model: callable mapping ``image_tensor`` to a score tensor whose dim 1 holds
            per-class scores. Defaults to ``classification_model_trt``.
        classes: labels indexed by class id. Defaults to ``CLASSES``.

    Returns:
        The label at the argmax of the model's scores.
    """
    # Resolve defaults lazily so callers passing both arguments never touch the globals.
    if model is None:
        model = classification_model_trt
    if classes is None:
        classes = CLASSES
    with torch.no_grad():
        scores = model(image_tensor)
        pred_idx = torch.argmax(scores, dim=1).item()
        return classes[pred_idx]


In [13]:
import cv2
import ipywidgets
from jetcam.utils import bgr8_to_jpeg

# Live preview sized to the camera frame; update_display() pushes JPEG bytes into it.
prediction_widget = ipywidgets.Image(
    width=camera.width,
    height=camera.height,
    format='jpeg',
)
display(prediction_widget)

def update_display(image, output, decision):
    """Annotate a camera frame with the road target and decision, then show it in the widget.

    Args:
        image: camera frame as passed to ``bgr8_to_jpeg``; it is copied, not modified.
        output: flat array from ``follow_model_trt``. ``output[0]`` is x and, when present,
            ``output[1]`` is y. Both are normalized so that -1..1 spans the frame.
        decision: label text drawn near the top-left corner.
    """
    norm_x = float(output[0])
    # Without a y prediction, fall back to 0.5 (normalized), which places the marker
    # three quarters of the way down the frame, not at its vertical center.
    norm_y = 0.5 if len(output) <= 1 else float(output[1])

    # Map normalized [-1, 1] coordinates onto pixel positions.
    pixel_x = int(camera.width * (0.5 + norm_x / 2.0))
    pixel_y = int(camera.height * (0.5 + norm_y / 2.0))

    annotated = cv2.circle(image.copy(), (pixel_x, pixel_y), 8, (255, 0, 0), 3)
    cv2.putText(annotated, decision, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
    prediction_widget.value = bgr8_to_jpeg(annotated)

Image(value=b'', format='jpeg', height='224', width='224')

In [16]:
# Control-loop tuning, used only while the classifier says "drive":
#   steering = model_x * STEERING_GAIN + STEERING_BIAS
THROTTLE = 0.115       # constant throttle command while driving
STEERING_GAIN = 0.9    # scales the road-following model's x output
STEERING_BIAS = 0.05   # offset added after scaling

In [17]:
import time

# Main control loop: classify each frame, steer from the road-following model while the
# classifier says "drive", otherwise hold the car still. Interrupt the kernel to exit.
try:
    while True:
        raw_image = camera.read()
        preprocessed = preprocess(raw_image).half().unsqueeze(0).to(device)

        # Run both models on the same preprocessed image
        decision = get_classification(preprocessed)
        output = follow_model_trt(preprocessed).cpu().numpy().flatten()

        if decision == "drive":
            car.steering = float(output[0]) * STEERING_GAIN + STEERING_BIAS
            car.throttle = THROTTLE
        else:
            # If stopped, set output to default for display
            output = np.array([0.0, 0.5])
            car.throttle = 0.0
            car.steering = 0.0

        update_display(raw_image, output, decision)

        time.sleep(0.05)

except KeyboardInterrupt:
    print("Stopped by user")
finally:
    # Always leave the car stopped, both on Ctrl-C and when anything above raises
    # (camera read, inference, display). Non-interrupt exceptions still propagate.
    car.throttle = 0
    car.steering = 0

Stopped by user


In [9]:
from IPython.display import display
import ipywidgets as widgets

# NOTE(review): no code visible in this notebook reads state_widget.value; the control
# loop does not consult it. Wire it in (or remove it) before relying on it as a stop switch.
state_widget = widgets.ToggleButtons(
    description='State:',
    options=['stop', 'live'],
    value='stop',
)
display(state_widget)


ToggleButtons(description='State:', options=('stop', 'live'), value='stop')