# Object Following - Live Demo

In this notebook we'll show how you can follow an object with JetBot!  We'll use a pre-trained neural network
that was trained on the [COCO dataset](http://cocodataset.org) to detect 90 different common objects.  These include

* Person (index 1)
* Cup (index 47)

and many others (see [this file](https://github.com/tensorflow/models/blob/master/research/object_detection/data/mscoco_complete_label_map.pbtxt) for the full list of class indices).  The model comes from the [TensorFlow object detection API](https://github.com/tensorflow/models/tree/master/research/object_detection),
which also provides utilities for training object detectors for custom tasks!  Once the model is trained, we optimize it using NVIDIA TensorRT on the Jetson Nano.

This optimization makes the network fast enough for real-time execution on the Jetson Nano!  We won't walk through the training and optimization steps in this notebook, though.
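
To make the class indices concrete, here is a minimal sketch of how a follower could pick its target from a list of detections. The detection format (dicts with `label`, `confidence`, and a normalized `bbox`) and the helper names `box_center` and `pick_target` are assumptions for illustration; they are not taken from the `ObjectFollower` implementation.

```python
# Assumed detection format (hypothetical): a dict with a class index,
# a confidence score and a normalized box (x_min, y_min, x_max, y_max) in 0..1.


def box_center(bbox):
    """Return the box center relative to the image center (each axis in -0.5..0.5)."""
    x_min, y_min, x_max, y_max = bbox
    return ((x_min + x_max) / 2.0 - 0.5, (y_min + y_max) / 2.0 - 0.5)


def pick_target(detections, label, conf_th=0.5):
    """Keep detections of the tracked class at or above the confidence threshold
    and return the one closest to the image center, or None if none qualifies."""
    candidates = [
        d for d in detections
        if d["label"] == label and d["confidence"] >= conf_th
    ]
    if not candidates:
        return None

    def squared_distance(d):
        cx, cy = box_center(d["bbox"])
        return cx ** 2 + cy ** 2

    return min(candidates, key=squared_distance)


detections = [
    {"label": 47, "confidence": 0.91, "bbox": (0.10, 0.20, 0.30, 0.80)},  # cup, left side
    {"label": 47, "confidence": 0.75, "bbox": (0.45, 0.25, 0.65, 0.90)},  # cup, near center
    {"label": 1, "confidence": 0.88, "bbox": (0.48, 0.48, 0.52, 0.52)},   # different class
    {"label": 47, "confidence": 0.30, "bbox": (0.49, 0.49, 0.51, 0.51)},  # below threshold
]

target = pick_target(detections, label=47, conf_th=0.5)
print(target)
```

Here the second cup is selected: the first one is farther from the center, the third detection has a different class index, and the last one falls below the confidence threshold.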

Anyway, let's get started.  First, we'll import the ``ObjectFollower`` class, which takes our pre-trained SSD engine or YOLO engine.

In [None]:
import traitlets
import ipywidgets.widgets as widgets
from ipywidgets.widgets import HBox, Layout
from IPython.display import display, clear_output


In [None]:
# %pip install pandas
from jetbot.utils import model_selection
from jetbot import ObjectFollower

In [None]:
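# avoider_model='../collision_avoidance/best_model.pth'
OF = ObjectFollower(init_sensor_of=True)
OF.conf_th = 0.5  # detection confidence threshold (as the name suggests)

# Model selector restricted to TensorRT engines for object detection
trt_ms = model_selection(core_library="TensorRT")
trt_ms.model_function = "object detection"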

model_function_widget = widgets.Select(options=trt_ms.model_function_list, value="object detection",
                                       description='Model Function:')

model_type_widget = widgets.Select(options=trt_ms.model_type_list, value="SSD", description='Model Type:')
traitlets.dlink((trt_ms, 'model_type_list'), (model_type_widget, 'options'))
traitlets.dlink((model_type_widget, 'value'), (trt_ms, 'model_type'))
traitlets.dlink((trt_ms, 'model_type'), (OF, 'type_follower_model'))

model_path_widget = widgets.Select(options=trt_ms.model_path_list, description='Model Path:',
                                   layout=Layout(width='75%'))
traitlets.dlink((trt_ms, 'model_path_list'), (model_path_widget, 'options'))
traitlets.dlink((model_path_widget, 'value'), (trt_ms, 'model_path'))
traitlets.dlink((trt_ms, 'model_path'), (OF, 'follower_model'))
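
The cell above relies on `traitlets.dlink`, which creates a one-way link: a change to the source trait is copied to the target trait, but not the other way around. The following is a small standalone sketch of that behavior. The `Selector` and `Follower` classes are hypothetical stand-ins, not the real `model_selection` or `ObjectFollower` classes.

```python
import traitlets


class Selector(traitlets.HasTraits):
    # Hypothetical stand-in for the object that holds the user's choice.
    model_type = traitlets.Unicode("SSD")


class Follower(traitlets.HasTraits):
    # Hypothetical stand-in for the object that consumes the choice.
    type_follower_model = traitlets.Unicode()


selector = Selector()
follower = Follower()

# Creating the link immediately copies the source value to the target.
link = traitlets.dlink((selector, "model_type"), (follower, "type_follower_model"))
initial = follower.type_follower_model

# A change on the source side propagates to the target.
selector.model_type = "YOLO"
after_source_change = follower.type_follower_model

# A change on the target side does NOT propagate back to the source.
follower.type_follower_model = "manual"
source_after_target_change = selector.model_type

# After unlink(), the two traits are independent again.
link.unlink()
selector.model_type = "SSD"
after_unlink = follower.type_follower_model

print(initial, after_source_change, source_after_target_change, after_unlink)
```

This is why the order of the two tuples matters in each `dlink` call: the first tuple is the source, and the second is the target that gets overwritten.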


In [None]:
out = widgets.Output()

# image_widget = widgets.Image(format='jpeg', width=OF.img_width, height=OF.img_height)
image_widget = widgets.Image(format='jpeg', width=350, height=350)

# display(image_widget)
traitlets.dlink((OF, 'cap_image'), (image_widget, 'value'))

# display buttons
button_layout = widgets.Layout(width='100px', height='40px', align_self='center')
stop_button = widgets.Button(description='Stop', tooltip='Click to stop running', icon='stop', layout=button_layout)
stop_button.style.button_color='Red'

start_button = widgets.Button(description='Start', tooltip='Click to start running', icon='play', layout=button_layout)
start_button.style.button_color='lightBlue'
button_box = widgets.HBox([start_button, stop_button], layout=widgets.Layout(align_self='center'))


In [None]:
blocked_widget = widgets.FloatSlider(min=0.0, max=1.0, value=0.0, description='blocked')
label_widget = widgets.IntText(value=1, description='tracked label')  # target to be tracked
label_text_widget = widgets.Text(value='', description='label name')  # target name to be tracked
speed_widget = widgets.FloatSlider(value=0.18, min=0.05, max=0.5, step=0.001, description='speed', readout_format='.3f')
speed_gain_widget = widgets.FloatSlider(value=0.18, min=0.05, max=0.5, step=0.001, description='speed gain', readout_format='.3f')
turn_gain_widget = widgets.FloatSlider(value=0.25, min=0.05, max=0.5, step=0.001, description='turn gain', readout_format='.3f')
steering_bias_widget = widgets.FloatSlider(value=0.02, min=-0.1, max=0.1, step=0.001, description='steering bias', readout_format='.3f')

traitlets.dlink((OF, 'blocked'), (blocked_widget, 'value'))
traitlets.dlink((label_widget, 'value'), (OF, 'label'))
traitlets.dlink((OF, 'label_text'), (label_text_widget, 'value'))
traitlets.dlink((OF, 'speed_of'), (speed_widget, 'value'))
traitlets.dlink((turn_gain_widget, 'value'), (OF, 'turn_gain_of'))
traitlets.dlink((speed_gain_widget, 'value'), (OF, 'speed_gain_of'))
traitlets.dlink((steering_bias_widget, 'value'), (OF, 'steering_bias_of'))
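
To give a feel for what the speed, turn gain, and steering bias sliders control, here is a rough sketch of a proportional steering rule. It is an assumption about how such parameters are commonly combined, not the actual `ObjectFollower` logic; the function names `clamp` and `motor_commands` are hypothetical, and the speed gain is not modeled here.

```python
def clamp(value, low=-1.0, high=1.0):
    """Limit a motor command to the range low..high."""
    return max(low, min(high, value))


def motor_commands(x_offset, speed, turn_gain, steering_bias=0.0):
    """Return (left, right) motor values for a target at horizontal offset x_offset.

    x_offset is the target center relative to the image center:
    negative means left of center, positive means right of center.
    """
    # Proportional turn term plus a constant bias that compensates for
    # a robot drifting to one side (assumed meaning of the bias).
    turn = turn_gain * x_offset + steering_bias
    left = clamp(speed + turn)
    right = clamp(speed - turn)
    return left, right


# Target straight ahead: both motors get the same value.
print(motor_commands(0.0, speed=0.18, turn_gain=0.25))

# Target to the right: the left motor runs faster, so the robot turns right.
print(motor_commands(0.2, speed=0.18, turn_gain=0.25))
```

With a rule like this, a larger turn gain makes the robot react more sharply to an off-center target, while the steering bias shifts both commands by a fixed amount regardless of where the target is.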

In [None]:
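def start(change):
    # Button callback: start the object follower.
    OF.start_of(change)

def stop(change):
    # Button callback: stop the object follower.
    OF.stop_of(change)
    # IPython magic: clears all user-defined names (including OF and the widgets),
    # so re-run the notebook from the top before starting again.
    %reset -f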

In [None]:
display(HBox([model_type_widget, model_path_widget]))

display(widgets.VBox([
    widgets.HBox([image_widget, blocked_widget]),
    widgets.HBox([label_widget, label_text_widget]),
    widgets.HBox([speed_gain_widget, speed_widget]),
    turn_gain_widget,
    steering_bias_widget,
    button_box
]))

start_button.on_click(start)
stop_button.on_click(stop)