In [1]:
from enum import IntEnum
import time
import jsonrpclib
import subprocess
from subprocess import PIPE, Popen
from threading  import Thread
import sys
import re
from collections import OrderedDict

import PySimpleGUI as sg

from gym import Env, error, spaces, utils
from stable_baselines3 import DQN, PPO, A2C, TD3, SAC
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
from stable_baselines3.common.env_checker import check_env
import numpy as np

import os
import requests
import shutil
import tempfile
import xml.etree.ElementTree as ET
from io import StringIO, BytesIO

import cv2
import numpy as np
import torch
from PIL import Image

import olympe
from olympe.messages.ardrone3.Piloting import TakeOff, Landing, moveBy, PCMD, moveTo
from olympe.messages.ardrone3.PilotingState import FlyingStateChanged, PositionChanged, GpsLocationChanged, moveToChanged
from olympe.enums.ardrone3.PilotingState import FlyingStateChanged_State as FlyingState
from olympe.messages.ardrone3.GPSSettingsState import GPSFixStateChanged, HomeChanged
from olympe.messages.gimbal import set_target, attitude
from olympe.messages.camera import (
    set_camera_mode,
    set_photo_mode,
    take_photo,
    photo_progress,
)
from olympe.media import (
    media_created,
    resource_created,
    media_removed,
    resource_removed,
    resource_downloaded,
    indexing_state,
    delete_media,
    download_media,
    download_media_thumbnail,
    MediaEvent,
)

from pynput.keyboard import Listener, Key, KeyCode
from collections import defaultdict

# Intended to quiet Olympe's logging for this notebook: the "olympe" logger gets no handlers.
# NOTE(review): "ulog" below is a sibling of "loggers", not nested under it. If
# update_config forwards this dict to logging.config.dictConfig, unknown top-level keys
# are ignored, so the "ulog" settings probably have no effect. Before moving it under
# "loggers", confirm "OFF" is accepted as a level ("OFF" is not a stdlib logging level name).
olympe.log.update_config({
    "loggers": {
        "olympe": {
                # Empty handler list: no handler configured here emits "olympe" records.
                "handlers": []
            }
        },
        "ulog": {
            "level": "OFF",
            "handlers": [],
        }
})

In [2]:
class KeyboardCtrl(Listener):
    """Background pynput keyboard listener used to stop the run loop with Esc.

    Every pressed key is remembered in ``_key_pressed``. Once Esc has been seen, the
    callback returns False, which stops the listener thread.
    """

    def __init__(self, ctrl_keys=None):
        # ``ctrl_keys`` is accepted but not used.
        self._key_pressed = defaultdict(lambda: False)
        # Not read anywhere in this notebook.
        self._last_action_ts = defaultdict(lambda: 0.0)
        super().__init__(on_press=self._on_press)
        self.start()

    def _on_press(self, key):
        # Character keys are recorded by their char, special keys by the Key member.
        if isinstance(key, (KeyCode, Key)):
            slot = key.char if isinstance(key, KeyCode) else key
            self._key_pressed[slot] = True
        # Returning False from a pynput callback stops the listener.
        return not self._key_pressed[Key.esc]

    def quit(self):
        """True once the listener has stopped or Esc has been pressed."""
        return not self.running or self._key_pressed[Key.esc]

In [6]:
class Action:
    """Prototype action executor wrapping an olympe drone (early draft).

    NOTE(review): a later cell rebinds the name ``Action`` to an IntEnum. After that
    cell runs, ``Drone.__init__`` (which calls ``Action(self.drone)``) no longer
    reaches this class.
    """

    def __init__(self, drone):
        self.drone = drone
        # NOTE(review): this expectation is never .wait()-ed, so it does not block until
        # a GPS fix; takeoff() below waits for the fix itself.
        self.drone(GPSFixStateChanged(_policy = 'wait'))

    def takeoff(self):
        """Take off and wait until hovering (does nothing if already hovering).

        Uses the same expectation as the final ``AnafiEnv.takeoff``. If not already
        hovering, it waits up to 10 s for a GPS fix, then sends TakeOff and waits up
        to 5 s for the hovering state.

        Raises:
            RuntimeError: if the drone does not reach the hovering state.
        """
        hovering = self.drone(
            FlyingStateChanged(state="hovering", _policy="check")
            | (
                GPSFixStateChanged(fixed=1, _timeout=10)
                >> (
                    TakeOff(_no_expect=True)
                    & FlyingStateChanged(state="hovering", _policy="wait", _timeout=5)
                )
            )
        ).wait().success()
        if not hovering:
            raise RuntimeError("Takeoff failed: drone is not hovering")

    def land(self):
        """Send a PCMD with all movement arguments 0, then land and wait until landed.

        Raises:
            RuntimeError: if the landed state is not reported.
        """
        self.drone(PCMD(1, 0, 0, 0, 0, 0) >> FlyingStateChanged(state="hovering", _timeout=5)).wait()
        # Evaluated outside an ``assert`` so Landing is still sent under ``python -O``.
        landed = self.drone(Landing() >> FlyingStateChanged(state="landed")).wait().success()
        if not landed:
            raise RuntimeError("Landing failed")

    def __len__(self):
        # __len__ must return an int; the former ``pass`` returned None and made len()
        # fail with an unrelated-looking TypeError.
        raise NotImplementedError("Action.__len__ is not implemented")

class Reward:
    """Prototype reward source: counts detector hits on a single photo.

    NOTE(review): the final ``AnafiEnv`` (later cell) has its own ``take_photo_single``.
    This class is only used by the prototype ``Drone`` wrapper.
    """

    def __init__(self, drone, drone_ip="10.202.0.1", model=None):
        """Store the drone handle and configure single-photo mode.

        Args:
            drone: connected ``olympe.Drone``.
            drone_ip: drone address used for the media web API (same default as AnafiEnv).
            model: detector callable whose result exposes ``.xyxy`` (e.g. the YOLOv5
                hub model loaded by AnafiEnv).
        """
        self.drone = drone
        # take_photo_single() reads both attributes; previously they were never set.
        self.DRONE_IP = drone_ip
        self.model = model
        self.setup_photo_single_mode()

    def setup_photo_single_mode(self):
        """Put camera 0 in single-photo mode and pitch gimbal 0 to -90 degrees."""
        self.drone(set_camera_mode(cam_id=0, value="photo")).wait()

        self.drone(
            set_photo_mode(
                cam_id=0,
                mode="single",
                format="rectilinear",
                file_format="jpeg",
                burst="burst_14_over_1s",
                bracketing="preset_1ev",
                capture_interval=0.0,
                )
            ).wait()

        pitch = -90.0
        self.drone(
            set_target(
                gimbal_id=0,
                control_mode="position",
                yaw_frame_of_reference="none",
                yaw=0.0,
                pitch_frame_of_reference="absolute",
                pitch=pitch,
                roll_frame_of_reference="none",
                roll=0.0,
                )
            >> attitude(
                pitch_absolute=pitch, _policy="wait", _float_tol=(1e-3, 1e-1)
                )
            ).wait(_timeout=20)

    def take_photo_single(self):
        """Hover, take one photo, run ``self.model`` on it and return the detection count.

        Returns:
            int: number of rows in ``result.xyxy[0]`` for the last downloaded resource.
            Returns 0 when the photo was not saved or the media has no resources.

        Raises:
            RuntimeError: if no detection model was provided.
            requests.HTTPError: if the media web API returns an error status.
        """
        if self.model is None:
            raise RuntimeError("Reward needs a detection model: pass model=... to Reward()")

        # Drone web server URL and its media web API URL.
        anafi_url = "http://{}/".format(self.DRONE_IP)
        anafi_media_api_url = anafi_url + "api/v1/media/medias/"

        self.drone(PCMD(1, 0, 0, 0, 0, 0) >> FlyingStateChanged(state="hovering", _timeout=5)).wait()
        # Register the expectation before triggering the photo so the event is not missed.
        photo_saved = self.drone(photo_progress(result="photo_saved", _policy="wait"))
        self.drone(take_photo(cam_id=0)).wait()
        if not photo_saved.wait().success():
            print("Photos not saved")
            # Don't read a media id from an event that never arrived.
            return 0
        media_id = photo_saved.received_events().last().args["media_id"]

        # Download the photos associated with this media id and run the detector.
        detections = 0
        media_info_response = requests.get(anafi_media_api_url + media_id)
        media_info_response.raise_for_status()
        for resource in media_info_response.json()["resources"]:
            image_response = requests.get(anafi_url + resource["url"], stream=True)
            image_response.raise_for_status()
            img = Image.open(BytesIO(image_response.content))
            # As before, only the last resource's count is kept.
            detections = self.model(img).xyxy[0].shape[0]

        return detections
    
class State:
    """Placeholder observation source; for now it only keeps the drone handle."""

    def __init__(self, drone):
        # object.__init__ takes no arguments; call it explicitly for clarity.
        super().__init__()
        self.drone = drone

class Drone:
    """Prototype wrapper: an olympe connection plus State/Reward/Action helpers.

    It connects and takes off on construction, and lands and disconnects in
    ``__del__``. ``take_action`` and the state/reward getters are still stubs.

    NOTE(review): once the later cell that rebinds ``Action`` to an IntEnum has run,
    ``Action(self.drone)`` below no longer refers to the action-executor class.
    """

    def __init__(self, drone_ip):
        self.drone = olympe.Drone(drone_ip)
        self.drone.connect()

        self.state = State(self.drone)
        self.reward = Reward(self.drone)
        self.action = Action(self.drone)

        self._takeoff()

    def take_action(self, action):
        """Stub: ignores ``action`` and returns ``(state, reward) == (0, 0)``."""
        state = 0
        reward = 0

        return state, reward

    def reset(self):
        """Return the current observation (None until ``_get_state`` is implemented)."""
        # Previously returned the undefined name ``state`` (NameError).
        return self._get_state()

    def _get_state(self):
        pass

    def _get_reward(self):
        pass

    def _takeoff(self):
        self.action.takeoff()

    def _land(self):
        self.action.land()

    def __del__(self):
        """Land and disconnect; tolerate a partially constructed instance."""
        # __init__ may have raised before these attributes were assigned.
        drone = getattr(self, "drone", None)
        if drone is None:
            return
        try:
            if getattr(self, "action", None) is not None:
                self._land()
        finally:
            drone.disconnect()
            # The former ``del state`` / ``del reward`` / ``del action`` targeted
            # non-existent locals (UnboundLocalError). Drop the attributes instead.
            for name in ("state", "reward", "action"):
                self.__dict__.pop(name, None)

class Simulation:
    """Static helpers for the Sphinx simulator and the moving-target plugin."""

    # JSON-RPC endpoint of the Sphinx simulator used by disable_battery().
    SPHINX_URL = 'http://127.0.0.1:8383'
    # Control file of the moving_target plugin; start/stop_target_movement write
    # "1" / "0" to it. Override this attribute on machines with a different checkout.
    TOGGLE_MOVEMENT_PATH = "/home/daniel/dp-code/plugins/moving_target/toggle_movement.txt"

    def __init__(self):
        pass

    @staticmethod
    def disable_battery(url=None):
        """Set the simulated anafi4k battery's discharge speed factor to 0.

        Args:
            url: Sphinx JSON-RPC endpoint; defaults to ``Simulation.SPHINX_URL``.
        """
        sphinx = jsonrpclib.Server(url or Simulation.SPHINX_URL)
        sphinx.SetParam(machine='anafi4k',
                        object='lipobattery/lipobattery',
                        parameter='discharge_speed_factor',
                        value='0')

    @staticmethod
    def _write_toggle(value):
        """Overwrite the moving-target control file with ``value``."""
        # ``with`` closes the file even if the write raises.
        with open(Simulation.TOGGLE_MOVEMENT_PATH, "w") as toggle_file:
            toggle_file.write(value)

    @staticmethod
    def stop_target_movement():
        """Write "0" to the moving-target control file."""
        Simulation._write_toggle("0")

    @staticmethod
    def start_target_movement():
        """Write "1" to the moving-target control file."""
        Simulation._write_toggle("1")
    
class AnafiEnv(Env):
    """First-draft gym environment built on the prototype ``Drone`` wrapper.

    NOTE(review): a later cell redefines ``AnafiEnv``; that later class is the one the
    run cell instantiates. This draft also sets no ``action_space``/``observation_space``.
    """

    def __init__(self, drone_ip="10.202.0.1"):
        super(AnafiEnv, self).__init__()

        Simulation.disable_battery()
        # Keep the target still while connecting and taking off.
        Simulation.stop_target_movement()

        self.agent = Drone(drone_ip)
        Simulation.start_target_movement()

    def step(self, action):
        """Apply ``action`` and return ``(obs, reward, done, info)``; episodes never end yet."""
        obs, reward = self.agent.take_action(action)
        # ``done`` was previously never assigned, so every step raised NameError.
        done = False

        return obs, reward, done, {}

    def reset(self):
        return self.agent.reset()

    def render(self, mode='human'):
        pass

    def close(self):
        """Pause the target and drop the drone wrapper."""
        Simulation.stop_target_movement()
        # Drone.__del__ lands and disconnects once the wrapper is garbage collected.
        del self.agent

In [10]:
# Discrete action set for the grid world. Values run 0..8 in declaration order, so
# they line up with ``spaces.Discrete(len(Action))``.
# NOTE(review): this rebinds ``Action``, shadowing the prototype Action class above.
Action = IntEnum(
    "ACTION",
    [
        "FORWARD", "BACKWARD", "LEFT", "RIGHT",
        "FORWARD_LEFT", "FORWARD_RIGHT", "BACKWARD_LEFT", "BACKWARD_RIGHT",
        "HOVER",
    ],
    start=0,
)

class AnafiEnv(Env):
    """Gym environment: a simulated Parrot Anafi hopping over a 5x5 grid of GPS cells.

    Each ``step`` moves the drone by at most one cell according to an ``Action``. It
    then takes a photo with the gimbal pitched to -90 degrees, and the reward is the
    number of detections the YOLOv5 model returns for it.

    Grid layout (from ``cell_locs``): ids grow by 1 as latitude decreases and by 5 as
    longitude decreases. ``take_action`` maps FORWARD to id - 5 and LEFT to id - 1.

    NOTE(review): this class redefines the draft ``AnafiEnv`` from an earlier cell.
    """

    # Signed decimal with optional exponent, e.g. "3", "-1.5", ".5", "2.5e-05".
    _NUMBER_RE = re.compile(r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?")

    # Cell the drone is assumed to occupy right after takeoff (and after reset()).
    _HOME_CELL_ID = 13

    # (latitude, longitude) of cells 1..25, in id order.
    _CELL_COORDS = (
        (48.87892882877948, 2.367806016921019),   # 1
        (48.87891432875068, 2.367805928258828),   # 2
        (48.87889999449613, 2.367805968481272),   # 3
        (48.87888566666667, 2.3678058333333336),  # 4
        (48.87887116666668, 2.3678058333333336),  # 5
        (48.87892883333334, 2.367792166666667),   # 6
        (48.87891433333333, 2.367792166666667),   # 7
        (48.87890000000001, 2.367792166666667),   # 8
        (48.87888566666667, 2.367792166666667),   # 9
        (48.87887116666668, 2.367792166666667),   # 10
        (48.8789288284333, 2.367778637691261),    # 11
        (48.87891433106901, 2.3677786478023783),  # 12
        (48.87890000000001, 2.3677785),           # 13
        (48.87888566666667, 2.3677785),           # 14
        (48.87887116666668, 2.3677785),           # 15
        (48.87892883333334, 2.367765),            # 16
        (48.87891433106901, 2.367765),            # 17
        (48.87890000000001, 2.367765),            # 18
        (48.87888566666667, 2.367765),            # 19
        (48.87887116666668, 2.367765),            # 20
        (48.87892883333334, 2.367751333333333),   # 21
        (48.87891433106901, 2.367751333333333),   # 22
        (48.87890000000001, 2.367751333333333),   # 23
        (48.87888566666667, 2.367751333333333),   # 24
        (48.87887116666668, 2.367751333333333),   # 25
    )

    def __init__(self,
                 DRONE_IP="10.202.0.1",
                 x_move_dist=1,
                 y_move_dist=1.64,
                 boundaries=None):
        """Load the detector, connect, take off and start streaming the simulated pose.

        Args:
            DRONE_IP: address of the (simulated) drone; also used for the media web API.
            x_move_dist, y_move_dist: stored on the instance; not read by the grid moves.
            boundaries: dict with "x", "y", "z_min", "z_max"; stored but not read yet.
                Defaults to {"x": 5, "y": 5, "z_min": 3, "z_max": 8}.
        """
        super(AnafiEnv, self).__init__()

        # Fresh dict per instance instead of a shared mutable default argument.
        if boundaries is None:
            boundaries = {"x": 5, "y": 5, "z_min": 3, "z_max": 8}
        self.x_move_dist = x_move_dist
        self.y_move_dist = y_move_dist
        self.boundaries = boundaries
        self.DRONE_IP = DRONE_IP

        # Reward detector: only class 0, confidence threshold 0.7.
        self.model = torch.hub.load('../yolov5', 'custom', path='../weights/best.pt', source='local', verbose=False)
        self.model.conf = 0.7
        self.model.classes = [0]

        # Same Sphinx call as Simulation.disable_battery(): battery discharge factor 0.
        self.sphinx = jsonrpclib.Server('http://127.0.0.1:8383')
        self.sphinx.SetParam(machine='anafi4k', object='lipobattery/lipobattery', parameter='discharge_speed_factor', value='0')

        Simulation.stop_target_movement()

        self.drone = olympe.Drone(self.DRONE_IP)
        self.drone.connect()
        # NOTE(review): not .wait()-ed, so this does not block; takeoff() waits for a GPS fix.
        self.drone(GPSFixStateChanged(_policy = 'wait'))
        self.takeoff()
        self.setup_photo_single_mode()
        # GPS position captured right after takeoff; reset() flies back here.
        self.home = self.drone.get_state(GpsLocationChanged)
        # Latest simulated (x, y, z) position, written by the odometry thread.
        self.agent_coord = np.zeros(3)

        self.current_cell = self.get_cell(self._HOME_CELL_ID)
        self.invalid_left_cells = [1, 6, 11, 16, 21]
        self.invalid_forward_cells = [1, 2, 3, 4, 5]
        self.invalid_right_cells = [5, 10, 15, 20, 25]
        self.invalid_backward_cells = [21, 22, 23, 24, 25]

        self.action_space = spaces.Discrete(len(Action))
        self.observation_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32)

        # Stream the drone's pose from Gazebo; a daemon thread parses it into agent_coord.
        ON_POSIX = 'posix' in sys.builtin_module_names
        cmd = "parrot-gz topic -e /gazebo/default/pose/info | grep -A 5 'name: \"anafi4k\"'"
        # Own session (POSIX) so close() can stop the whole pipeline, not just /bin/sh.
        self.odom_proc = Popen(cmd, stdout=PIPE, bufsize=1, close_fds=ON_POSIX, shell=True,
                               start_new_session=ON_POSIX)
        self.odom_thread = Thread(target=self.set_current_agent_coord, args=[self.odom_proc.stdout], daemon=True)
        self.odom_thread.start()

        self.step_count = 0
        self.rewards = []

        # NOTE(review): one PCMD with gaz=80 followed by a 50 ms pause, apparently a short
        # climb nudge - confirm intent.
        self.drone(PCMD(1, 0, 0, 0, 80, 0))
        time.sleep(0.05)

        Simulation.start_target_movement()

    def setup_photo_single_mode(self):
        """Put camera 0 in single-photo mode and pitch gimbal 0 to -90 degrees."""
        self.drone(set_camera_mode(cam_id=0, value="photo")).wait()

        self.drone(
            set_photo_mode(
                cam_id=0,
                mode="single",
                format="rectilinear",
                file_format="jpeg",
                burst="burst_14_over_1s",
                bracketing="preset_1ev",
                capture_interval=0.0,
                )
            ).wait()

        pitch = -90.0
        self.drone(
            set_target(
                gimbal_id=0,
                control_mode="position",
                yaw_frame_of_reference="none",
                yaw=0.0,
                pitch_frame_of_reference="absolute",
                pitch=pitch,
                roll_frame_of_reference="none",
                roll=0.0,
                )
            >> attitude(
                pitch_absolute=pitch, _policy="wait", _float_tol=(1e-3, 1e-1)
                )
            ).wait(_timeout=20)

    def take_photo_single(self):
        """Hover, take one photo, run the detector on it and return the detection count.

        The moving target is paused around the photo and is always resumed afterwards,
        even if taking/downloading the photo raises. The photo is deleted from the
        drone afterwards so media does not accumulate on its web server.

        Returns:
            int: number of rows in ``result.xyxy[0]`` for the last downloaded resource.
            Returns 0 when the photo was not saved or the media has no resources.
        """
        # Drone web server URL and its media web API URL.
        anafi_url = "http://{}/".format(self.DRONE_IP)
        anafi_media_api_url = anafi_url + "api/v1/media/medias/"

        self.drone(PCMD(1, 0, 0, 0, 0, 0) >> FlyingStateChanged(state="hovering", _timeout=5)).wait()

        Simulation.stop_target_movement()
        try:
            time.sleep(1)
            return self._photo_detections(anafi_url, anafi_media_api_url)
        finally:
            time.sleep(1)
            Simulation.start_target_movement()

    def _photo_detections(self, anafi_url, media_api_url):
        """Take a photo, run the detector on its resources, delete it; return the count."""
        # Register the expectation before triggering the photo so the event is not missed.
        photo_saved = self.drone(photo_progress(result="photo_saved", _policy="wait"))
        self.drone(take_photo(cam_id=0)).wait()
        if not photo_saved.wait().success():
            print("Photos not saved")
            # Don't read a media id from an event that never arrived.
            return 0
        media_id = photo_saved.received_events().last().args["media_id"]

        # Download the photos associated with this media id.
        detections = 0
        media_info_response = requests.get(media_api_url + media_id)
        media_info_response.raise_for_status()
        for resource in media_info_response.json()["resources"]:
            image_response = requests.get(anafi_url + resource["url"], stream=True)
            image_response.raise_for_status()
            img = Image.open(BytesIO(image_response.content))
            # As before, only the last resource's count is kept.
            detections = self.model(img).xyxy[0].shape[0]

        delete = delete_media(media_id, _timeout=10)
        if not self.drone.media(delete).wait().success():
            # Previously this called an undefined ``logger`` (NameError).
            print("Failed to delete media {} {}".format(media_id, delete.explain()))

        return detections

    def get_current_agent_coord(self):
        return self.agent_coord

    def set_current_agent_coord(self, output):
        """Odometry thread body: parse ``x:``/``y:``/``z:`` lines into ``agent_coord``.

        Args:
            output: binary stream (the gz pipeline's stdout) read line by line until EOF.
        """
        axis_index = (("x:", 0), ("y:", 1), ("z:", 2))
        for raw_line in iter(output.readline, b''):
            line = raw_line.decode("utf-8", errors="replace")
            for key, idx in axis_index:
                if key in line:
                    # Only look after the key; keeps signs and exponents (e.g. "-2", "1e-05").
                    match = self._NUMBER_RE.search(line.split(key, 1)[1])
                    if match:
                        self.agent_coord[idx] = float(match.group())
                    break

    def _placeholder_obs(self):
        """Constant observation (0.5 on every axis) shaped like ``observation_space``."""
        # SB3's env checker requires a numpy array for a Box observation space.
        return np.full(self.observation_space.shape, 0.5, dtype=np.float32)

    def step(self, action):
        """Move according to ``action`` and score the new position.

        Returns:
            (obs, reward, done, info): obs is a placeholder; done is always False.
        """
        done = False

        self.take_action(action)
        self.step_count += 1

        # TODO: replace with a real observation (e.g. agent_coord normalised by boundaries)
        # and end episodes when leaving the boundaries or after N steps.
        obs = self._placeholder_obs()
        reward = self.take_photo_single()
        self.rewards.append(reward)

        return obs, reward, done, {}

    def take_action(self, action):
        """Move one grid cell in the direction of ``action``, or hover.

        Moves that would leave the 5x5 grid are turned into a hover.

        Raises:
            ValueError: if ``action`` is not a valid ``Action`` value.
        """
        # Accepts Action members, ints and numpy ints (e.g. from action_space.sample()).
        action = Action(int(action))
        if action == Action.HOVER:
            print("hovering...")
            return

        # action -> (id offset, cells from which the move would leave the grid, message)
        moves = {
            Action.LEFT: (-1, self.invalid_left_cells, "trying left but hovering..."),
            Action.RIGHT: (1, self.invalid_right_cells, "trying right but hovering..."),
            Action.FORWARD: (-5, self.invalid_forward_cells, "trying forward but hovering..."),
            Action.BACKWARD: (5, self.invalid_backward_cells, "trying backward but hovering..."),
            Action.FORWARD_RIGHT: (-4, self.invalid_forward_cells + self.invalid_right_cells, "trying FR but hovering..."),
            Action.FORWARD_LEFT: (-6, self.invalid_forward_cells + self.invalid_left_cells, "trying FL but hovering..."),
            Action.BACKWARD_RIGHT: (6, self.invalid_backward_cells + self.invalid_right_cells, "trying BR but hovering..."),
            Action.BACKWARD_LEFT: (4, self.invalid_backward_cells + self.invalid_left_cells, "trying BL but hovering..."),
        }
        offset, blocked_cells, blocked_msg = moves[action]

        current_id = self.current_cell["id"]
        if current_id in blocked_cells:
            print(blocked_msg)
            return

        next_cell = self.get_cell(current_id + offset)
        self.move_to_cell(next_cell)
        self.current_cell = next_cell

    def get_cell(self, cell_id):
        """Return the cell dict for the 1-based ``cell_id``.

        Raises:
            ValueError: if ``cell_id`` is outside 1..25 (index 0 used to wrap to cell 25).
        """
        cells = self.cell_locs
        if not 1 <= cell_id <= len(cells):
            raise ValueError("cell_id must be in 1..{}, got {}".format(len(cells), cell_id))
        return cells[cell_id - 1]

    def move_to_cell(self, next_cell):
        """Fly to ``next_cell`` with moveTo (heading 90) and return whether it completed."""
        print("moving from cell", self.current_cell["id"], "to cell", next_cell["id"])

        return self.drone(
            moveTo(next_cell["latitude"],  next_cell["longitude"], next_cell["altitude"], "HEADING_DURING", 90.0)
            >> moveToChanged(status="DONE", _timeout=15)
        ).wait().success()

    @property
    def cell_locs(self):
        """Fresh list of 25 OrderedDicts (id, latitude, longitude, altitude), in id order."""
        altitude = 1.5
        return [
            OrderedDict([('id', cell_id),
                         ('latitude', latitude),
                         ('longitude', longitude),
                         ('altitude', altitude)])
            for cell_id, (latitude, longitude) in enumerate(self._CELL_COORDS, start=1)
        ]

    def move(self, lin_x, lin_y, lin_z, ang_yaw):
        """Relative move; returns True if the drone reports hovering within 5 s."""
        # moveBy args: (forward-backward, right-left, down-up, clockwise-anticlockwise)
        success = self.drone(
            moveBy(lin_x, lin_y, -lin_z, ang_yaw)
            >> FlyingStateChanged(state="hovering", _timeout=5)
        ).wait().success()
        # Previously computed and discarded.
        return success

    def takeoff(self):
        """Take off (unless already hovering) after waiting up to 10 s for a GPS fix.

        Raises:
            RuntimeError: if the drone does not reach the hovering state.
        """
        # Evaluated outside an ``assert`` so TakeOff is still sent under ``python -O``.
        hovering = self.drone(
            FlyingStateChanged(state="hovering", _policy="check")
            | (
                GPSFixStateChanged(fixed=1, _timeout=10)
                >> (
                    TakeOff(_no_expect=True)
                    & FlyingStateChanged(state="hovering", _policy="wait", _timeout=5)
                )
            )
        ).wait().success()
        if not hovering:
            raise RuntimeError("Takeoff failed: drone is not hovering")

    def reset(self):
        """Fly back to the home position captured after takeoff and reset episode state."""
        latitude, longitude, altitude = self.home["latitude"], self.home["longitude"], self.home["altitude"]

        # NOTE(review): home comes from GpsLocationChanged; confirm its altitude uses the
        # same reference as moveTo's altitude argument.
        self.drone(
            moveTo(latitude,  longitude, altitude, "HEADING_DURING", 90.0)
            >> moveToChanged(status="DONE", _timeout=15)
        ).wait()

        # __init__ treats the post-takeoff (home) position as the home cell; without this,
        # the next move would be computed from the cell the previous episode ended in.
        self.current_cell = self.get_cell(self._HOME_CELL_ID)
        self.step_count = 0

        return self._placeholder_obs()

    def render(self, mode='human'):
        pass

    def _stop_odometry(self):
        """Terminate the gz pose pipeline started in __init__, if it is still running."""
        import signal

        proc = getattr(self, "odom_proc", None)
        if proc is None or proc.poll() is not None:
            return
        try:
            pgid = os.getpgid(proc.pid)
            # Only signal a group that is not our own (start_new_session gave it its own).
            if pgid != os.getpgid(0):
                os.killpg(pgid, signal.SIGTERM)
                return
        except (AttributeError, ProcessLookupError):
            # No process groups on this platform, or the process already exited.
            pass
        proc.terminate()

    def close(self):
        """Pause the target, land, disconnect and stop the pose-streaming subprocess.

        Raises:
            RuntimeError: if landing is not confirmed. The drone is then left connected
                so ``close()`` can be retried.
        """
        Simulation.stop_target_movement()
        try:
            self.drone(PCMD(1, 0, 0, 0, 0, 0) >> FlyingStateChanged(state="hovering", _timeout=5)).wait()
            # Evaluated outside an ``assert`` so Landing is still sent under ``python -O``.
            landed = self.drone(Landing() >> FlyingStateChanged(state="landed")).wait().success()
            if not landed:
                raise RuntimeError("Landing failed")
            self.drone.disconnect()
        finally:
            self._stop_odometry()

## Run the simulation

In [23]:
# Unbind any `env` left over from a previous run so the next cell starts clean.
# NOTE(review): this only removes the name; it does not land the drone or call env.close().
globals().pop("env", None)

In [24]:
%%capture --no-stdout 
# do not capture stdout but do capture stderr

layout = [
    [
            sg.Column([
                [sg.Text("Time step: 0", size=(20, 2), key="-STEP-", justification='center', font=("Helvetica, 12"))],
                [sg.Text("Reward", size=(20, 2), justification='center', font=("Helvetica, 14"))],
                [sg.Text("0", size=(20, 3), key="-REWARD-", justification='center', font=("Helvetica", 12))],
            ],
            element_justification='center'),
        ],
    ] 

window = sg.Window("Results", layout, margins=(25, 25))
env = AnafiEnv(x_move_dist=1, y_move_dist=1.64, boundaries={"x":2, "y":3.28, "z_min":1.1, "z_max":70})
control = KeyboardCtrl()

step = 0
event, values = window.read(timeout=10)
while not control.quit():
    action = env.action_space.sample()
    state, reward, done, _ = env.step(action)
    window['-REWARD-'].update(str(reward))
    step = step + 1
    window['-STEP-'].update("Time step: " + str(step))

    event, values = window.read(timeout=10)
    
    if event == sg.WIN_CLOSED:
        break

window.close()
env.reset()
env.close()

moving from cell 13 to cell 19
moving from cell 19 to cell 23
hovering...
hovering...
moving from cell 23 to cell 19
hovering...
moving from cell 19 to cell 18
moving from cell 18 to cell 17
moving from cell 17 to cell 13
moving from cell 13 to cell 19
moving from cell 19 to cell 15
moving from cell 15 to cell 14
moving from cell 14 to cell 19
moving from cell 19 to cell 18
moving from cell 18 to cell 23
