# TODO
- Benchmark con `.train()` / `.eval()`
- Benchmark con `with torch.no_grad()` / sin
- BGR -> RGB (y el entrenamiento igual)
- Evento "DESCONEXIÓN" del mando?

In [None]:
import threading
import time
import urllib.parse
import urllib.request

import cv2
import ipywidgets
import numpy as np
import torch
import torchvision
from inputs import get_gamepad
from IPython.display import display
from jetcam.csi_camera import CSICamera
from jetcam.utils import bgr8_to_jpeg
from jetracer.nvidia_racecar import NvidiaRacecar
from matplotlib import pyplot as plt
from torch2trt import TRTModule

# Constants

In [None]:
# Camera capture/output resolution in pixels (alternative: 960 x 540).
CAP_WIDTH, CAP_HEIGHT = 640, 480

# Side length of the square model input; frames are resized to SZ x SZ.
SZ = 224

# Initial car calibration, applied when the gamepad loop starts.
THROTTLE_GAIN = 0.15
STEERING_OFFSET = 0.15
STEERING_GAIN = -0.55

In [None]:
# ImageNet per-channel normalization statistics, kept on the GPU and used by
# get_road_center.
# NOTE(review): these are listed in R, G, B order, but camera frames are later
# encoded with bgr8_to_jpeg, i.e. they appear to be BGR (see the "BGR -> RGB"
# TODO). Presumably the model was trained with the same ordering; confirm
# before swapping channels.
mean = torch.Tensor([0.485, 0.456, 0.406]).cuda() # R, G, B
std = torch.Tensor([0.229, 0.224, 0.225]).cuda()

# Read pretrained model

In [None]:
# Load the torch2trt-converted road-following model from disk.
# NOTE(review): torch.load unpickles the file; only load trusted model files.
model_trt = TRTModule()
model_trt.load_state_dict(torch.load('models/road_following_model_trt.pth'))

In [None]:
# Switch the module to inference mode (see the .train()/.eval() benchmark TODO).
model_trt.eval()

# Create RaceCar object

In [None]:
# Handle to the car; throttle, steering and their gain/offset are set on it below.
car = NvidiaRacecar()

# Create camera object

In [None]:
# Alternative: have the camera output frames directly at model resolution.
#camera = CSICamera(width=SZ, height=SZ, capture_fps=30) # fps=65
# Capture and output full CAP_WIDTH x CAP_HEIGHT frames; get_road_center
# resizes them to SZ x SZ for the model.
camera = CSICamera(
    capture_width=CAP_WIDTH, 
    capture_height=CAP_HEIGHT,
    width=CAP_WIDTH, 
    height=CAP_HEIGHT, 
    capture_fps=30) # fps=65

In [None]:
#image = camera.read()

In [None]:
#plt.imshow(image)

In [None]:
#plt.imshow(image)

# Display functions

In [None]:
DISPLAY_BASE_URL = "http://localhost:8000"
DISPLAY_TIMEOUT_S = 0.5  # keep short: these calls run inside the gamepad loop


def _display_get(path):
    """Send a best-effort GET request to the local status-display server.

    Failures (server down, timeout, HTTP error) are printed and otherwise
    ignored, so a missing display never aborts the driving loop.
    """
    url = f"{DISPLAY_BASE_URL}/{path}"
    try:
        with urllib.request.urlopen(url, timeout=DISPLAY_TIMEOUT_S):
            pass
    except OSError as err:  # URLError, HTTPError and socket timeouts are OSErrors
        print(f"display request failed: {url} ({err})")


def format_status(throttle_gain, steering_offset):
    """Return the two-line calibration text shown on the display."""
    return '\n'.join([
        f"th_gain: {throttle_gain:.2f}",
        f"st_offs: {steering_offset:.2f}",
    ])


def update_display():
    """Show the car's current throttle gain and steering offset on the display."""
    info = format_status(car.throttle_gain, car.steering_offset)
    # The text contains a newline and spaces, which must be percent-encoded
    # before being placed in the URL path.
    _display_get("text/" + urllib.parse.quote(info))


def reset_display():
    """Request `stats/on` on the display server (presumably its default stats view)."""
    _display_get("stats/on")

# Driving (AI / manual) thread

In [None]:
# Live preview of annotated frames; JPEG bytes are assigned to preview_widget.value.
preview_widget = ipywidgets.Image(width=CAP_WIDTH, height=CAP_HEIGHT)
display(preview_widget)

In [None]:
def get_road_center(image):
    """Locate the road center in a camera frame using the TensorRT model.

    Args:
        image: HxWx3 uint8 frame of any size; it is resized to SZ x SZ for the
            model, and the input array itself is not modified.

    Returns:
        (nx, ny, ix, iy): the model prediction clamped to [-1, 1], plus the
        same point mapped to integer pixel coordinates in the original frame.
    """
    frame_h, frame_w = image.shape[:2]

    # uint8 HWC -> float CHW in [0, 1] at model resolution, then move to the GPU.
    model_input = torchvision.transforms.functional.to_tensor(
        cv2.resize(image, (SZ, SZ))
    ).cuda()

    # Normalize with ImageNet stats, since the model is ImageNet-pretrained.
    model_input = (model_input - mean[:, None, None]) / std[:, None, None]

    # Add a batch dimension (1, 3, SZ, SZ) and cast to float16 for faster inference.
    model_input = model_input.unsqueeze(0).half()

    # Predict, clamp to [-1, 1], then drop the batch dimension and copy to the CPU.
    prediction = torch.clamp(model_trt(model_input), min=-1, max=1)
    nx, ny = prediction.flatten().cpu().tolist()

    def to_pixel(coord, extent):
        # Linearly map [-1, 1] onto [0, extent - 1], truncating to an int.
        return int((coord + 1) / 2 * (extent - 1))

    return nx, ny, to_pixel(nx, frame_w), to_pixel(ny, frame_h)

In [None]:
%%timeit
# Benchmark one full frame: capture, inference, annotation and preview update.
# (The %%timeit cell magic must stay on the first line of the cell.)
image = camera.read()

# Get center of road
nx, ny, ix, iy = get_road_center(image)

# Draw circle and preview
cv2.circle(image, (ix, iy), 8, (0, 255, 0), 2)

jpg = bgr8_to_jpeg(image)
preview_widget.value = jpg

In [None]:
def drive():
    """Apply one control step to the car.

    Manual mode: copy manual_throttle / manual_steering onto the car.
    AI mode: read a camera frame and steer toward the predicted road center.
    """
    if manual_drive:
        # NOTE(review): manual_throttle / manual_steering are not defined in the
        # visible notebook cells; confirm where they are set before calling this.
        car.throttle = manual_throttle
        car.steering = manual_steering
    else:
        image = camera.read()
        # Reuse the shared preprocessing + inference path; nx is the clamped
        # horizontal road-center prediction in [-1, 1].
        nx, _, _, _ = get_road_center(image)
        car.steering = nx

In [None]:
def update_image(change):
    """Camera observer: steer from the model in AI mode and refresh the preview.

    Args:
        change: change notification from camera.observe(..., names='value');
            change['new'] is the latest camera frame.
    """
    # With the camera configured above, the frame is expected to be a uint8
    # array of shape (CAP_HEIGHT, CAP_WIDTH, 3); get_road_center resizes it
    # to SZ x SZ for the model.
    image = change['new']

    # Get center of road (normalized and pixel coordinates)
    nx, ny, ix, iy = get_road_center(image)

    # AI drive: steer toward the predicted horizontal road center
    if not manual_drive:
        car.steering = nx

    # Draw the predicted point on the frame and publish it to the preview
    cv2.circle(image, (ix, iy), 8, (0, 255, 0), 2)
    preview_widget.value = bgr8_to_jpeg(image)

In [None]:
# Call update_image whenever the camera publishes a new frame (its 'value' trait).
camera.observe(update_image, names='value')

In [None]:
# Start continuous capture; each new frame triggers the update_image observer.
camera.running = True

In [None]:
# Gamepad control loop.
#   - ABS_Y axis -> throttle, ABS_Z axis -> steering (manual mode only)
#   - Y/A buttons: throttle gain +/-; B/X buttons: steering offset -/+
#   - START: AI mode, SELECT: manual mode
#   - Hold both rear triggers (TL2 + TR2) to exit
# Cleanup runs in `finally`, so the car is stopped and the camera released even
# if the loop raises (e.g. gamepad unplugged or kernel interrupted).
right_trigger = left_trigger = 0
car.throttle = 0.0
car.throttle_gain = THROTTLE_GAIN
car.steering_offset = STEERING_OFFSET
car.steering_gain = STEERING_GAIN
manual_drive = True

AXIS_CENTER = 127.5        # axis readings are centered here; (s - c) / c maps to [-1, 1]
GAIN_STEP = 0.05           # per-press change of throttle gain / steering offset
MAX_STEERING_OFFSET = 0.3  # steering offset is kept within [-0.3, 0.3]

try:
    while not (right_trigger and left_trigger):
        for event in get_gamepad():
            if event.ev_type == 'Absolute' and manual_drive:
                if event.code == 'ABS_Y':
                    # Negated so that (presumably) pushing the stick up drives forward
                    car.throttle = -(event.state - AXIS_CENTER) / AXIS_CENTER
                elif event.code == 'ABS_Z':
                    car.steering = (event.state - AXIS_CENTER) / AXIS_CENTER
            elif event.ev_type == 'Key' and event.code == 'BTN_TR2':
                right_trigger = event.state
            elif event.ev_type == 'Key' and event.code == 'BTN_TL2':
                left_trigger = event.state
            elif event.ev_type == 'Key' and event.state == 1:
                # Button presses only (state 1); releases are ignored
                if event.code == 'BTN_WEST':  # Y / UP
                    car.throttle_gain = min(1.0, car.throttle_gain + GAIN_STEP)
                    update_display()
                elif event.code == 'BTN_SOUTH':  # A / DOWN
                    car.throttle_gain = max(0.0, car.throttle_gain - GAIN_STEP)
                    update_display()
                elif event.code == 'BTN_EAST':  # B / RIGHT
                    car.steering_offset = max(-MAX_STEERING_OFFSET,
                                              car.steering_offset - GAIN_STEP)
                    # Re-assign steering so the new offset reaches the servo
                    # (presumably assigning an unchanged value would not)
                    car.steering = 0.1
                    car.steering = 0
                    update_display()
                elif event.code == 'BTN_NORTH':  # X / LEFT
                    car.steering_offset = min(MAX_STEERING_OFFSET,
                                              car.steering_offset + GAIN_STEP)
                    # Re-assign steering so the new offset reaches the servo
                    car.steering = 0.1
                    car.steering = 0
                    update_display()
                elif event.code == 'BTN_START':
                    print("AI mode")
                    manual_drive = False
                    car.throttle = 1.
                elif event.code == 'BTN_SELECT':
                    print("Manual mode")
                    manual_drive = True
                    car.throttle = 0.
finally:
    # Stop the car first, then detach the camera pipeline and reset the display.
    manual_drive = True
    car.throttle = 0
    camera.running = False
    camera.unobserve_all()
    reset_display()