In [None]:
### Import packages
import os
import math
import time
import signal
import typing
import warnings
import cv2, PIL
import threading
import traceback
import numpy as np
import pandas as pd
from cv2 import aruco
import matplotlib as mpl
from djitellopy import Tello
import ipywidgets as widgets
from typing import Tuple, List
import matplotlib.pyplot as plt
from IPython.display import clear_output, display
%matplotlib nbagg
warnings.filterwarnings('ignore')

In [None]:
### Create buttons for controlling the drone
# Each spec is (description, button_style, icon). button_style is one of
# 'success', 'info', 'warning', 'danger' or ''; icon is a FontAwesome name without the `fa-` prefix.
# The descriptions double as the keys on_button_clicked dispatches on.
dir_items = [
    widgets.Button(description=label, button_style=style, icon=glyph)
    for label, style, glyph in (
        ('left', 'success', 'arrow-left'),
        ('forward', 'success', 'arrow-up'),
        ('backward', 'success', 'arrow-down'),
        ('right', 'success', 'arrow-right'),
    )
]
rot_items = [
    widgets.Button(description=label, button_style=style, icon=glyph)
    for label, style, glyph in (
        ('rotate left', 'info', 'rotate-left'),
        ('rotate right', 'info', 'rotate-right'),
        ('stop', 'danger', 'stop'),
        ('land', 'warning', 'caret-down'),
    )
]

In [None]:
'''
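Let's first define some actions.

The drone should be able to:
* Rotate left
* Rotate right
* Move left
* Move right
* Move forward
* Move backwards

That gives 6 movement actions, plus a STOP / hold-position action (index 6) used by the manual controls.
How big can we make our state space?

State space: the tag's (x, y) position in the processed image, discretized into a GRID_N x GRID_N grid,
plus the tag's heading, discretized into ANGLE_N bins.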
'''
### GLOBAL VARIABLES ###
# Camera intrinsics and distortion coefficients used for marker pose estimation
camera_matrix = np.array([[230.912545, 0.000000, 120.790180],
                          [0.000000, 230.691836, 158.098598],
                          [0.000000, 0.000000, 1.000000]])

distortion_coeff = np.array([0.278917, -0.485155, 0.002496, 0.004631, 0.000000])
### Create Aruco tag detection framework
aruco_dict = aruco.Dictionary_get(aruco.DICT_APRILTAG_36h11)
parameters = aruco.DetectorParameters_create()

# Image size parameters
MARKER_SIZE = 0.2 # Marker size is 200 mm
# process_image keeps 240 rows of the raw frame and rotates it 90 degrees clockwise,
# so the processed image is 240 columns wide.
IMAGE_W = 240 # Width of the processed image: the x / column direction (tag pose x)
IMAGE_H = 320 # Height of the processed image: the y / row direction (tag pose y)
GRID_N = 10 # Discretization for image
ANGLE_N = 4 # Discretization for rotation

# RL Parameters - These might need to be tuned later on
ALPHA = 0.1 # Learning rate of the Q-update
GAMMA = 0.95 # Discount factor of the Q-update
NEW_ACTION = False
USE_ACTION = 6 # Default to holding in position (STOP)
# Actions 0-5 are movements and 6 is STOP (see next_step / take_action), so the Q-table needs
# 7 action slots. With 6, choosing STOP (the default USE_ACTION / the 'stop' button) would
# index past the end of the table.
NUM_ACTIONS = 7
WAIT_INPUT = False

def state_to_grid_ref(state: np.array) -> np.ndarray:
    '''
    Map a continuous tag pose onto the indices of its Q-table cell.

    param:
        state (np.array): x, y (pixels in the processed image), theta (degrees)
    returns:
        (np.array): [x_bin, y_bin, theta_bin]. x_bin and y_bin lie in [0, GRID_N - 1]
        and theta_bin lies in [0, 3].
    raises:
        ValueError: if theta is not a comparable number (e.g. NaN)
    '''
    # Discretize x-y. Clip to GRID_N - 1, the last valid index. A coordinate lying exactly on the
    # far image edge (e.g. x == IMAGE_W, which next_step's clipping can produce) would otherwise
    # land in bin GRID_N and index past the end of the Q-table.
    x_grid = int(np.clip(state[0] // (IMAGE_W / GRID_N), 0, GRID_N - 1))
    y_grid = int(np.clip(state[1] // (IMAGE_H / GRID_N), 0, GRID_N - 1))
    theta_raw = state[2]
    # Heading bins (degrees):
    #   0 -> (-15, 15), 1 -> [-135, -15], 2 -> beyond +/-135, 3 -> [15, 135]
    if -15 < theta_raw < 15: # Forward direction
        theta_grid = 0
    elif -135 <= theta_raw <= -15: # Left direction
        theta_grid = 1
    elif theta_raw > 135 or theta_raw < -135: # Down direction
        theta_grid = 2
    elif 15 <= theta_raw <= 135: # Right direction
        theta_grid = 3
    else:
        # Only reachable for NaN, which compares False against every bound
        raise ValueError(f'theta must be a finite angle in degrees, got {theta_raw}')
    return np.array([x_grid, y_grid, theta_grid])

def create_q_table(x_y_space: int, rotation_increments: int, num_actions: int):
    '''
    Allocate a fresh Q-value table with every entry set to zero.

    param:
        x_y_space (int): Number of grid cells along each of x and y
        rotation_increments (int): Number of heading bins
        num_actions (int): Number of selectable actions
    returns:
        (np.ndarray): float array of shape (x_y_space, x_y_space, rotation_increments, num_actions)
    '''
    # Axes: x bin, y bin, heading bin, action
    table_shape = (x_y_space, x_y_space, rotation_increments, num_actions)
    return np.zeros(table_shape)

def create_policy_table(q_table: np.ndarray) -> np.ndarray:
    '''
    Derive a greedy policy from a Q-table. For every (x, y, theta) cell, pick the action
    with the highest Q-value; ties resolve to the lowest action index.

    param:
        q_table (np.ndarray): 4-D table of Q-values indexed by (x, y, theta, action)
    returns:
        (np.ndarray): 3-D table holding the best action index for each (x, y, theta) cell
    '''
    greedy_actions = np.asarray(q_table).argmax(axis=3)
    return greedy_actions


def get_action_from_policy(policy: np.ndarray, state, q_table: np.ndarray = None) -> int:
    '''
    Pick the greedy action for the current state, breaking ties uniformly at random.

    Ties are broken over the Q-values rather than read from `policy`. An argmax policy table
    always reports the first of several equally good actions (e.g. every action in a fresh
    all-zero table), which would keep the drone trying a single option.

    param:
        policy (np.ndarray): Action table for the state space. NOTE(review): not consulted
            by this function; kept so existing call sites keep working.
        state (np.array): x, y, theta of the tag
        q_table (np.ndarray, optional): Q-values to choose from; defaults to the global Q_TABLE
    returns:
        (int): Action to execute on the drone
    '''
    if q_table is None:
        q_table = Q_TABLE
    # Q-values of every action in this state's cell
    action_values = q_table[tuple(state_to_grid_ref(state))]
    # Every action sharing the maximum value is a candidate
    best_actions = np.flatnonzero(action_values == action_values.max())
    return np.random.choice(best_actions)

def setup_drone():
    '''
    Connect to the Tello drone and start its video stream.

    returns:
        (Tello) Connected instance of the DJI Tello drone
    raises:
        Exception: whatever Tello()/connect()/streamon() raised, after its traceback is printed
    '''
    try:
        # Create tello instance
        tello = Tello()
        # Connect to the drone
        tello.connect()
        # Start streaming data
        tello.streamon()
        return tello
    except Exception as e:
        print(f'Oops! an error occured: {str(e)}')
        # print_exc() writes the traceback itself (and returns None), so it is not wrapped in print()
        traceback.print_exc()
        # Re-raise: callers use the returned drone immediately, so returning None would only defer
        # the failure to a confusing AttributeError on the first drone command
        raise
        
def process_image(img: np.ndarray) -> np.ndarray:
    '''
    Crop a raw video frame to its first 240 rows and rotate it 90 degrees clockwise.

    param:
        img (np.ndarray): Raw frame from the drone's video stream
    returns:
        (np.ndarray): Cropped and rotated image
    '''
    # Rows past the first 240 are garbage bytes from the framebuffer
    usable_rows = img[:240]
    # Turn the image into the orientation the rest of the pipeline expects.
    # (Sharpening with a 3x3 cv2.filter2D kernel was tried for tag detection and is disabled.)
    return cv2.rotate(usable_rows, cv2.ROTATE_90_CLOCKWISE)

def detect_aruco_and_pose(img: np.ndarray) -> np.array:
    '''
    Find AprilTag (36h11) markers with OpenCV's aruco module, show an annotated debug view,
    and return the pose of the first marker.

    param:
        img (np.ndarray): Processed downward-facing camera image
    returns:
        (np.ndarray): Pose [x, y, theta] from pose_estimation, or an empty array when no tag is visible
    '''
    corners, ids, _rejected = aruco.detectMarkers(img, aruco_dict, parameters=parameters)
    if len(corners) == 0:
        # No tag in view; the caller skips this frame
        return np.array([])

    annotated = aruco.drawDetectedMarkers(img.copy(), corners, ids)
    # Given the image, estimate the tag pose
    pose = pose_estimation(corners, ids, img.copy())

    # Shared putText styling: font face, scale 0.5, blue (BGR), 2 px thick, anti-aliased
    text_style = (cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2, cv2.LINE_AA)
    x_px, y_px, theta_deg = int(pose[0]), int(pose[1]), int(pose[2])
    annotated = cv2.putText(annotated, f'x:{x_px} y:{y_px} t:{theta_deg}', (20, 20), *text_style)
    # Filled red dot on the tag centre
    annotated = cv2.circle(annotated, (x_px, y_px), 10, (0, 0, 255), -1)
    # Grid cell the pose falls into
    grid_ref = state_to_grid_ref(pose)
    annotated = cv2.putText(annotated, f'Grid Ref: {grid_ref}', (0, 320), *text_style)
    cv2.imshow('Aruco Tags', annotated)
    return pose
    
def isRotationMatrix(R: np.ndarray) -> bool:
    '''
    Check that R is orthonormal: R^T R must match the 3x3 identity to within 1e-6
    (norm of the difference).

    param:
        R (np.ndarray): Candidate 3x3 rotation matrix
    returns:
        (bool): True when R passes the orthonormality check
    '''
    residual = np.identity(3, dtype=R.dtype) - R.T @ R
    return np.linalg.norm(residual) < 1e-6
 
def rotationMatrixToEulerAngles(R: np.ndarray) -> np.array:
    '''
    Convert a rotation matrix to Euler angles in radians.
    The result is the same as MATLAB except the order of the euler angles
    (x and z are swapped).

    param:
        R (np.ndarray): 3x3 rotation matrix (must pass isRotationMatrix)
    returns:
        (np.array): [roll (x), pitch (y), yaw (z)] in radians
    '''
    assert(isRotationMatrix(R))
    r00, r10, r20 = R[0, 0], R[1, 0], R[2, 0]
    sy = math.sqrt(r00 * r00 + r10 * r10)
    # Pitch is computed the same way in both cases
    pitch = math.atan2(-r20, sy)
    if sy < 1e-6:
        # Near-singular case (pitch close to +/-90 degrees): yaw is fixed at 0
        # and the remaining rotation is expressed as roll
        roll = math.atan2(-R[1, 2], R[1, 1])
        yaw = 0
    else:
        roll = math.atan2(R[2, 1], R[2, 2])
        yaw = math.atan2(r10, r00)
    return np.array([roll, pitch, yaw])

def extract_rotation(rvect: np.ndarray) -> float:
    '''
    Turn a Rodrigues (axis-angle) rotation vector into its rotation about the z-axis.

    param:
        rvect (np.ndarray): Rotation in axis-angle form, as produced by marker pose estimation
    returns:
        (float): z-axis (yaw) Euler angle in degrees
    '''
    rotation_matrix = cv2.Rodrigues(rvect)[0]
    roll_pitch_yaw = rotationMatrixToEulerAngles(rotation_matrix)
    return np.degrees(roll_pitch_yaw[2])

def pose_estimation(corners: List, ids: List, frame: np.ndarray) -> np.ndarray:
    '''
    Estimate the pose of the first detected marker and draw its outline and axes onto `frame`.

    param:
        corners (List): Marker corners as returned by aruco.detectMarkers
        ids (List): Marker ids as returned by aruco.detectMarkers
        frame (np.ndarray): Image to draw on (modified in place)
    returns:
        (np.ndarray): [x, y, theta]. x and y are the marker centre in image pixels; theta is
        its rotation about the camera z-axis in degrees. Empty array when no marker was detected.
    '''
    if len(corners) == 0 or ids is None or len(ids) == 0:
        return np.array([])
    # Only the first detected marker drives the state
    marker_corners = corners[0]
    # Estimate the rotation and translation vectors of the marker relative to the camera
    rotation, translation, _ = cv2.aruco.estimatePoseSingleMarkers(marker_corners, MARKER_SIZE, camera_matrix,
                                                                   distortion_coeff)
    # Draw a square around the markers
    cv2.aruco.drawDetectedMarkers(frame, corners)
    # Draw Axis
    cv2.drawFrameAxes(frame, camera_matrix, distortion_coeff, rotation, translation, 0.01)
    # Position IN the image: mean of the marker's four corner points (x = column, y = row)
    x = marker_corners[0][:, 0].mean()
    y = marker_corners[0][:, 1].mean()
    theta = extract_rotation(rotation)
    return np.array([x, y, theta])
    

def update_q_table(state: np.array, action: int, state_prime: np.array) -> None:
    '''
    Apply one tabular Q-learning update for taking `action` in `state`:

        Q(s,a) <- Q(s,a) + ALPHA * [R(s,a) + GAMMA * max_a' Q(s',a') - Q(s,a)]

    param:
        state (np.array): Current x, y, theta of the tag
        action (int): Index of the action taken
        state_prime (np.array): Predicted x, y, theta after the action
    return:
        None (Q_TABLE is modified in place)
    '''
    global Q_TABLE
    # Convert grid references to tuples. Indexing with a bare integer ndarray is NumPy fancy
    # indexing along axis 0 and would select whole slabs of the table instead of one cell.
    state_cell = tuple(state_to_grid_ref(state))
    next_cell = tuple(state_to_grid_ref(state_prime))
    best_next_value = np.max(Q_TABLE[next_cell])
    td_target = reward(state, action) + GAMMA * best_next_value
    Q_TABLE[state_cell + (action,)] += ALPHA * (td_target - Q_TABLE[state_cell + (action,)])

def reward(state: np.array, action: int) -> float:
    '''
    Calculate the reward of being in a state.

    The drone should first face the target heading (forward heading bin), and only then be
    rewarded for centring the tag in the image:
        -0.1  heading is not in the forward bin
         0    forward, but more than 50 px from the image centre
         1    forward and within 50 px of the image centre

    param:
        state (np.array): Current x, y, theta of the tag
        action (int): Encoded action (currently unused by the reward)
    returns:
        (float): Reward for occupying the current state
    '''
    goal_state = np.array([IMAGE_W/2, IMAGE_H/2])
    distance = np.linalg.norm(state[:2] - goal_state)
    print(f'Distance to goal: {distance}')
    grid_ref = state_to_grid_ref(state)
    # Heading is checked first. The small penalty pushes the agent to rotate towards the
    # forward heading before it is rewarded for position.
    if grid_ref[2] != 0:
        return -0.1
    # Facing forward but still far from the centre of the image
    if distance > 50:
        return 0
    return 1 # Huzzah! We are in a good enough state
    
    
def stop_action() -> None:
    '''
    Send a burst of ten all-zero RC commands to the global `tello` so the drone
    does not keep drifting in the last commanded direction.
    '''
    remaining = 10
    while remaining:
        tello.send_rc_control(0, 0, 0, 0)
        remaining -= 1
        
def next_step(state: np.array, action: int) -> np.array:
    '''
    Given the current state and a selected action, estimate the location of the
    fiducial marker in the next time step

    e.g. (state) + velocity command * timestep ~= new state

    The tello has imperfect vehicle dynamics, so this gives a reasoned estimate
    of the next state s' for the Bellman (Q-learning) update

    [0] Rotate left
    [1] Rotate right
    [2] Move left
    [3] Move right
    [4] Move forward
    [5] Move backwards
    [6] STOP

    NOTE: THESE TRANSITION STEPS ARE DEPENDENT ON ALTITUDE, WHICH SHOULD REMAIN FIXED!

    TARGET HEIGHT = 1 meter above the ground

    param:
        state (np.array): x,y,theta of the tag (pixels, pixels, degrees); not modified
        action (int): Encoded action to execute
    returns:
        (np.array): New predicted state. x/y are clipped to the image; for the rotation
        actions, theta is wrapped into [-180, 180).
    '''
    def wrap_degrees(angle):
        # Wrap a heading into [-180, 180) so rotating past +/-180 continues on the other side
        return (angle + 180) % 360 - 180

    # Work on a float copy. `next_state = state` would alias the caller's array, and the x/y
    # updates below would silently overwrite the caller's current state.
    next_state = np.array(state, dtype=float)
    new_theta = next_state[2]
    if action == 0: # Rotate left, decrease by 5 degrees
        new_theta = wrap_degrees(next_state[2] - 5)
    elif action == 1: # Rotate right, increase by 5 degrees
        new_theta = wrap_degrees(next_state[2] + 5)
    elif action == 2: # Move left, target will shift to right
        next_state[0] += 24
    elif action == 3: # Move right in the world, target will shift to left
        next_state[0] -= 24
    elif action == 4: # Move forwards, target will shift down
        next_state[1] += 30
    elif action == 5: # Move backwards, target will shift up
        next_state[1] -= 30
    # action 6 (STOP): the tag is expected to stay where it is

    # Keep the predicted position inside the image
    new_x = np.clip(next_state[0], 0, IMAGE_W)
    new_y = np.clip(next_state[1], 0, IMAGE_H)
    return np.array([new_x, new_y, new_theta])

    
def take_action(action: int) -> None:
    '''
    Send one RC velocity command for the chosen action, hold it for 0.1 s, then stop the drone.

    param:
        action (int): Encoded action (0-6, same numbering as next_step); any other code
            sends an all-zero command
    '''
    SPEED = 20 # This value is hard-coded and can be adjusted as need be
    # (left/right, forward/backward, up/down, yaw) velocity for each action code
    commands = {
        0: (0, 0, 0, -SPEED),  # Rotate left
        1: (0, 0, 0, SPEED),   # Rotate right
        2: (-SPEED, 0, 0, 0),  # Move left
        3: (SPEED, 0, 0, 0),   # Move right
        4: (0, SPEED, 0, 0),   # Move forward
        5: (0, -SPEED, 0, 0),  # Move backwards
        6: (0, 0, 0, 0),       # Hold position
    }
    # NOTE(review): earlier comments labelled yaw = -SPEED as "clockwise"; the Tello rc yaw
    # channel is usually clockwise for positive values - confirm against the hardware.
    lr_vel, fb_vel, ud_vel, yaw = commands.get(action, (0, 0, 0, 0))
    tello.send_rc_control(lr_vel, fb_vel, ud_vel, yaw)
    time.sleep(0.1)
    # Countermand the command so the drone does not keep drifting
    stop_action()

In [None]:
### Create global tello drone instance
tello = setup_drone()
# Guard against setup_drone() handing back None after a failed connection. Without this, the
# first drone command below fails with a confusing AttributeError on None.
if tello is None:
    raise RuntimeError('Could not connect to the Tello drone; see the error printed above.')

# Build the learning tables before taking off, so nothing here can fail while the drone is airborne
# Create new value table
Q_TABLE = create_q_table(GRID_N,ANGLE_N,NUM_ACTIONS)
# Create new policy
POLICY = create_policy_table(Q_TABLE)

tello.takeoff()
tello.set_video_direction(Tello.CAMERA_DOWNWARD)
# Zero out any residual control inputs
tello.send_rc_control(0, 0, 0, 0)

# Signal handler to remove dangling connections
def shutdown_drone(sig, frame):
    '''
    Land the drone (if still airborne) and save the learned Q-table and policy under ./models.

    Used both as the SIGINT handler (sig is the signal number) and directly by the
    'land' button (sig is None).

    param:
        sig: Signal number, or None when called directly
        frame: Current stack frame (unused)
    raises:
        KeyboardInterrupt: when invoked as a signal handler, so the running control loop stops
    '''
    global POLICY
    print('Gracefully shutting down connection...')
    try:
        # The training loops land the drone when they exit, and landing an already-landed
        # drone may fail, so only land if the drone is still flagged as flying
        if tello.is_flying:
            tello.land()
    finally:
        # Save even if landing raised, so the learned values are never lost
        os.makedirs('./models', exist_ok=True)
        # The training loops keep their policy in a local variable, so rebuild it from the Q-table
        POLICY = create_policy_table(Q_TABLE)
        #### SAVE THE Q TABLE ###
        np.save('./models/q_vals_temp.npy', Q_TABLE)
        ### SAVE THE POLICY
        np.save('./models/policy_temp.npy', POLICY)
    if sig is not None:
        # Installing this handler replaces the default SIGINT behaviour (raising KeyboardInterrupt).
        # Raise it ourselves so an interrupted control loop actually stops instead of carrying on.
        raise KeyboardInterrupt

# Register the shutdown command
signal.signal(signal.SIGINT, shutdown_drone)

# Define callback function for button presses
def on_button_clicked(b):
    '''
    ipywidgets click handler. 'land' shuts the drone down; every other button records its
    action code in USE_ACTION and clears WAIT_INPUT to release the waiting training loop.

    param:
        b: The clicked button; its description selects the action
    '''
    global USE_ACTION, WAIT_INPUT
    label = b.description
    print(label)
    if label == 'land':
        shutdown_drone(None, None)
        return
    action_codes = {
        'rotate left': 0,
        'rotate right': 1,
        'left': 2,
        'right': 3,
        'forward': 4,
        'backward': 5,
        'stop': 6,
    }
    USE_ACTION = action_codes[label]
    WAIT_INPUT = False

# Wire every control button (movement row, then rotation/stop/land row) to the shared click handler
for button in [*dir_items, *rot_items]:
    button.on_click(on_button_clicked)

# Main execution loop
def train_e_greedy():
    '''
    Epsilon-greedy Q-learning loop.

    On every frame where a tag is visible, take a random movement about 10% of the time and
    the current policy's choice otherwise. Update the Q-table with the predicted next state,
    execute the action and rebuild the policy. Press 'q' in the OpenCV window to stop;
    the drone lands when the loop exits.
    '''
    quit_key = ord('q')
    policy = create_policy_table(Q_TABLE)
    while True:
        # Grab and pre-process the current image from the drone
        frame = process_image(tello.get_frame_read().frame)
        cv2.imshow('Tello Vid', frame)
        if cv2.waitKey(1) & 0xFF == quit_key:
            break
        # Detect the tag and calculate its current pose (x, y, theta)
        state = detect_aruco_and_pose(frame)
        if len(state) > 0:
            explore = np.random.rand() > 0.90
            if explore:
                action = np.random.choice(range(6))
            else:
                action = get_action_from_policy(policy, state)
            print(f'State: {state} Action: {action}')
            # Predict where the tag ends up, learn from it, then actually fly the action
            state_prime = next_step(state, action)
            update_q_table(state, action, state_prime)
            take_action(action)
            policy = create_policy_table(Q_TABLE)
        time.sleep(0.1)
    tello.land()
        
def train_human_in_the_loop():
    '''
    Q-learning loop with a human in the loop.

    On every frame where a tag is visible, the action comes from one of three sources:
    a random movement (~10%), the user clicking a control button (~40%) or the current
    policy (~50%). The Q-table is updated with the predicted next state, the action is
    executed and the policy is rebuilt. Press 'q' in the OpenCV window to stop; the drone
    lands when the loop exits.
    '''
    global WAIT_INPUT
    POLICY = create_policy_table(Q_TABLE)
    # The same button objects are shown whenever the user is asked for input
    controls = widgets.VBox([widgets.HBox(dir_items), widgets.HBox(rot_items)])
    while True:
        # grab current image from the drone
        img = tello.get_frame_read().frame
        img = process_image(img)
        cv2.imshow('Tello Vid', img)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
        # Detect aruco tag and calculate its current pose (x,y,theta)
        state = detect_aruco_and_pose(img)
        if len(state) > 0:
            random_number = np.random.rand()
            if random_number > 0.90:
                # 10% of the time, try a random action
                action = np.random.choice(range(6))
            elif random_number > 0.50:
                # 40% of the time, we ask the user for the correct input
                WAIT_INPUT = True
                display(controls)
                # NOTE(review): widget click events are normally handled by the kernel between
                # cell executions, so clicks may not reach on_button_clicked while this loop runs.
                # Confirm; a background thread or a UI-polling helper may be needed.
                while WAIT_INPUT:
                    time.sleep(0.1)
                # Read the choice only after the callback has cleared WAIT_INPUT. Reading it
                # inside the wait loop could return the previous button press.
                action = USE_ACTION
            else:
                # 50% of the time, we take an action from the policy
                action = get_action_from_policy(POLICY, state)
            print(f'State: {state} Action: {action}')
            # Given an action and the current state, predict where the drone will end up next
            state_prime = next_step(state, action)
            # Update the Q table
            update_q_table(state, action, state_prime)
            # apply the action
            take_action(action)
            # update the policy
            POLICY = create_policy_table(Q_TABLE)
        time.sleep(0.1)
    tello.land()
        
def evaluate_policy(POLICY: np.ndarray) -> None:
    '''
    Fly the drone with a fixed, previously saved policy (no learning, no exploration).
    Press 'q' in the OpenCV window to stop; the drone lands when the loop exits.

    param:
        POLICY (np.ndarray): 3-D table of action indices indexed by the grid reference from
            state_to_grid_ref (e.g. the output of create_policy_table, or ./models/policy.npy)
    '''
    while True:
        # grab current image from the drone
        img = tello.get_frame_read().frame
        img = process_image(img)
        cv2.imshow('Tello Vid', img)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
        # Detect aruco tag and calculate its current pose (x,y,theta)
        state = detect_aruco_and_pose(img)
        if len(state) > 0:
            # Look the action up in the policy being evaluated, so the saved policy drives the
            # drone rather than the live Q_TABLE (which would also be used to overwrite it)
            action = POLICY[tuple(state_to_grid_ref(state))]
            print(f'State: {state} Action: {action}')
            # apply the action
            take_action(action)
        time.sleep(0.1)
    tello.land()

In [None]:
# No `train` function exists; run the epsilon-greedy trainer.
# Swap in train_human_in_the_loop() to train with the manual control buttons.
train_e_greedy()

In [None]:
cv2.destroyAllWindows()
os.makedirs('./models', exist_ok=True)
# The training loops keep their policy in a local variable, so the global POLICY still holds
# the untrained table; rebuild it from the learned Q-values before saving
POLICY = create_policy_table(Q_TABLE)
#### SAVE THE Q TABLE ###
np.save('./models/q_vals.npy', Q_TABLE)
### SAVE THE POLICY
np.save('./models/policy.npy', POLICY)
# The training loop already lands the drone when it exits; only land again if still airborne
if tello.is_flying:
    tello.land()