In [None]:
from enum import IntEnum
import time
import jsonrpclib
import subprocess
from subprocess import PIPE, Popen
from threading  import Thread
import sys
import re

from gym import Env, error, spaces, utils
from stable_baselines3 import DQN, PPO, A2C, TD3, SAC
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
from stable_baselines3.common.env_checker import check_env
import numpy as np

import os
import requests
import shutil
import tempfile
import xml.etree.ElementTree as ET
from io import StringIO, BytesIO

import cv2
import numpy as np
import torch
from PIL import Image

import olympe
from olympe.messages.ardrone3.Piloting import TakeOff, Landing, moveBy, PCMD, moveTo
from olympe.messages.ardrone3.PilotingState import FlyingStateChanged, GpsLocationChanged, moveToChanged
from olympe.enums.ardrone3.PilotingState import FlyingStateChanged_State as FlyingState
from olympe.messages.ardrone3.GPSSettingsState import GPSFixStateChanged
from olympe.messages.gimbal import set_target
from olympe.messages.camera import (
    set_camera_mode,
    set_photo_mode,
    take_photo,
    photo_progress,
)

from pynput.keyboard import Listener, Key, KeyCode
from collections import defaultdict

# Silence Olympe logging for this notebook: the "olympe" logger gets an empty
# handler list, and the "ulog" entry is switched OFF with no handlers.
# NOTE(review): "ulog" is a top-level key here, a sibling of "loggers" rather
# than a child of it (the original indentation hid this) -- confirm that this
# is the placement Olympe's log configuration expects.
olympe.log.update_config(
    {
        "loggers": {
            "olympe": {"handlers": []},
        },
        "ulog": {
            "level": "OFF",
            "handlers": [],
        },
    }
)

In [None]:
class KeyboardCtrl(Listener):
    """Keyboard listener that remembers which keys have been pressed.

    The listener is started as soon as the object is constructed. Pressing
    ESC makes the press callback return False (which, per pynput's callback
    convention, stops the listener) and makes `quit()` report True.
    """

    def __init__(self, ctrl_keys=None):
        """Create the pressed-key table and start listening.

        Args:
            ctrl_keys: accepted for interface compatibility; not used here.
        """
        # Keys never seen default to "not pressed".
        self._key_pressed = defaultdict(bool)
        # Per-key timestamp table; initialised here but not read in this class.
        self._last_action_ts = defaultdict(float)
        super().__init__(on_press=self._on_press)
        self.start()

    def _on_press(self, key):
        """Record a key press; return False once ESC has been pressed."""
        if isinstance(key, KeyCode):
            # Character keys are stored under their character.
            self._key_pressed[key.char] = True
        elif isinstance(key, Key):
            # Special keys are stored under the Key member itself.
            self._key_pressed[key] = True
        return not self._key_pressed[Key.esc]

    def quit(self):
        """Return True when the listener is not running or ESC was pressed."""
        if not self.running:
            return True
        return self._key_pressed[Key.esc]

In [None]:
class Action(IntEnum):
    """Discrete actions the agent can choose, numbered from 0.

    Declared with class syntax so the enum's own name matches the name it is
    bound to (the previous functional form called the type 'ACTION', which
    made reprs misleading and prevented pickling by reference).
    """

    UP = 0
    DOWN = 1
    LEFT = 2
    RIGHT = 3
    FORWARD = 4
    BACKWARD = 5
    HOVER = 6

class AnafiEnv(Env):
    """Gym environment that pilots a simulated Parrot ANAFI through Olympe.

    Each step executes one discrete `Action` as a relative move, takes a
    downward-facing photo, runs the loaded YOLOv5 model on it and uses the
    number of detections as the reward. The drone position is read in a
    background thread from the Gazebo pose topic.

    NOTE(review): observations are currently the fixed placeholder
    ``[0.5, 0.5, 0.5]`` and episodes never terminate (``done`` is always
    False).
    """

    # Signed decimal number with optional fraction and exponent, so values
    # such as "-5", "3.25" and "-1.2e-05" are all captured in full.
    _NUMBER_RE = re.compile(r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?")

    # Labels looked for in the pose output, in agent_coord index order.
    _AXIS_LABELS = ("x:", "y:", "z:")

    def __init__(self, 
                 DRONE_IP="10.202.0.1", 
                 move_dist=2, 
                 boundaries={"x":5, "y":5, "z_min":3, "z_max":8}):
        """Connect to the drone, take off and climb to the minimum altitude.

        Args:
            DRONE_IP: address used for the Olympe connection and the media
                web API.
            move_dist: distance passed to `moveBy` for every non-hover action.
            boundaries: dict with keys "x", "y", "z_min" and "z_max" that
                limits where the agent may move. It is copied, so later
                changes to the caller's dict (or to the shared default) do
                not affect this environment.
        """
        super(AnafiEnv, self).__init__()
        
        self.move_dist = move_dist
        # Copy so the environment never aliases the shared default dict.
        self.boundaries = dict(boundaries)
        self.DRONE_IP = DRONE_IP
        
        # Locally stored YOLOv5 model; `conf` and `classes` restrict which
        # detections are reported (threshold 0.7, class index 0 only).
        self.model = torch.hub.load('../yolov5', 'custom', path='../weights/best.pt', source='local', verbose=False)
        self.model.conf = 0.7
        self.model.classes = [0]
        
        # Simulator JSON-RPC endpoint; sets the battery discharge speed
        # factor to 0 (presumably so the simulated battery does not drain).
        self.sphinx = jsonrpclib.Server('http://127.0.0.1:8383')
        self.sphinx.SetParam(machine='anafi4k', object='lipobattery/lipobattery', parameter='discharge_speed_factor', value='0')
        
        self.drone = olympe.Drone(self.DRONE_IP)
        self.drone.connect()
        self.setup_photo_single_mode()
        self.takeoff()
        # GPS state right after take-off; reset() flies back to it.
        self.home = self.drone.get_state(GpsLocationChanged)
        self.agent_coord = np.zeros(3)
        
        self.action_space = spaces.Discrete(len(Action))
        self.observation_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32)
        
        # Stream the simulator pose topic and keep only the lines following
        # the drone's name; a daemon thread parses them into agent_coord.
        ON_POSIX = 'posix' in sys.builtin_module_names
        cmd = "parrot-gz topic -e /gazebo/default/pose/info | grep -A 5 'name: \"anafi4k\"'"
        p = Popen(cmd, stdout=PIPE, bufsize=1, close_fds=ON_POSIX, shell=True)
        self.odom_thread = Thread(target=self.set_current_agent_coord, args=[p.stdout], daemon=True)
        self.odom_thread.start()
        
        self.step_count = 0
        self.rewards = []
        
        # Climb (PCMD with a gaz value of 80) until the tracked altitude
        # reaches the lower boundary.
        while self.agent_coord[2] < self.boundaries["z_min"]:
            self.drone(PCMD(1, 0, 0, 0, 80, 0))
            time.sleep(0.05)
    
    def setup_photo_single_mode(self):
        """Put camera 0 in single-photo mode and set the gimbal pitch to -90."""
        self.drone(set_camera_mode(cam_id=0, value="photo")).wait()

        self.drone(
            set_photo_mode(
                cam_id=0,
                mode="single",
                format="rectilinear",
                file_format="jpeg",
                burst="burst_14_over_1s",
                bracketing="preset_1ev",
                capture_interval=0.0,
                )
            ).wait()
        
        self.drone(
            set_target(
                gimbal_id=0,
                control_mode="position",
                yaw_frame_of_reference="none",
                yaw=0.0,
                pitch_frame_of_reference="absolute",
                pitch=-90.0,
                roll_frame_of_reference="none",
                roll=0.0,
                )
            ).wait()
        
    def take_photo_single(self):
        """Take one photo, run the detector on it and count the detections.

        Returns:
            The number of rows in the detector's ``xyxy[0]`` result for the
            last downloaded resource of the new media, or 0 when the photo
            was not saved or the media lists no resources.

        Raises:
            requests.HTTPError: if a request to the drone web server fails.
        """
        # Drone web server URL
        ANAFI_URL = "http://{}/".format(self.DRONE_IP)

        # Drone media web API URL
        ANAFI_MEDIA_API_URL = ANAFI_URL + "api/v1/media/medias/"

        # Zero the piloting command and wait for hovering before shooting.
        self.drone(PCMD(1, 0, 0, 0, 0, 0) >> FlyingStateChanged(state="hovering", _timeout=5)).wait()
        # Register the expectation before triggering the photo so the
        # "photo_saved" event cannot be missed.
        photo_saved = self.drone(photo_progress(result="photo_saved", _policy="wait"))
        self.drone(take_photo(cam_id=0)).wait()
        if not photo_saved.wait().success():
            print("Photos not saved")
            # Without a saved photo there is no media to score.
            return 0
        media_id = photo_saved.received_events().last().args["media_id"]

        # download the photos associated with this media id
        media_info_response = requests.get(ANAFI_MEDIA_API_URL + media_id)
        media_info_response.raise_for_status()

        # Stays 0 when the media has no resources (previously this case
        # raised UnboundLocalError).
        detection_count = 0
        for resource in media_info_response.json()["resources"]:
            image_response = requests.get(ANAFI_URL + resource["url"], stream=True)
            image_response.raise_for_status()
            img = Image.open(BytesIO(image_response.content))
            result = self.model(img)
            detection_count = result.xyxy[0].shape[0]

        # TODO: Delete the image file in the web server after processing to avoid accumulation
        return detection_count
    
    def get_current_agent_coord(self):
        """Return the array holding the latest parsed (x, y, z) position."""
        return self.agent_coord
    
    def set_current_agent_coord(self, output):
        """Parse pose lines from `output` into `self.agent_coord` until EOF.

        Runs in the background thread started by `__init__`.

        Args:
            output: binary stream whose `readline` returns ``b''`` at EOF.
                Lines containing "x:", "y:" or "z:" update the matching
                coordinate; lines without a parsable number are skipped so a
                malformed line cannot kill the reader thread.
        """
        for raw_line in iter(output.readline, b''):
            if isinstance(raw_line, bytes):
                text = raw_line.decode("utf-8", errors="replace")
            else:
                text = str(raw_line)

            for axis_index, label in enumerate(self._AXIS_LABELS):
                if label not in text:
                    continue
                number = self._NUMBER_RE.search(text)
                if number is not None:
                    self.agent_coord[axis_index] = float(number.group())
                # Only the first matching label on a line is used.
                break
            
    def step(self, action):
        """Execute one action, photograph the scene and return the transition.

        Actions that would push the drone past a boundary are replaced by
        HOVER. The x, y and z guards are evaluated independently, so being
        close to one boundary no longer disables the checks on the others.

        Args:
            action: a value of `Action` (or an equal integer).

        Returns:
            ``(obs, reward, done, info)`` where `obs` is the placeholder
            ``[0.5, 0.5, 0.5]``, `reward` is the detection count from
            `take_photo_single`, `done` is always False and `info` is ``{}``.
        """
        done = False
        
        x, y, z = self.agent_coord
        limits = self.boundaries
        
        # change invalid actions to hover
        if limits["x"] - abs(x) < self.move_dist:
            if x > 0 and action == Action.FORWARD:
                action = Action.HOVER
            elif x < 0 and action == Action.BACKWARD:
                action = Action.HOVER
        if limits["y"] - abs(y) < self.move_dist:
            if y > 0 and action == Action.LEFT:
                action = Action.HOVER
            elif y < 0 and action == Action.RIGHT:
                action = Action.HOVER
        if (z - limits["z_min"] < self.move_dist or z < limits["z_min"]) and action == Action.DOWN:
            action = Action.HOVER
        if (limits["z_max"] - z < self.move_dist or z > limits["z_max"]) and action == Action.UP:
            action = Action.HOVER
       
        self.take_action(action)
        
        self.step_count += 1
        
        obs = [0.5, 0.5, 0.5]
        reward = self.take_photo_single()
        self.rewards.append(reward)

#         obs = self.agent_coord / self.boundary
            
#         if (self.agent_coord >= self.).any() or self.step_count >= 10:
#             done = True
#             self.step_count = 0
        
        return obs, reward, done, {}

    def take_action(self, action):
        """Translate an `Action` into a relative move of `move_dist`."""
        if action == Action.UP:
            self.move(0, 0, self.move_dist, 0)
        elif action == Action.DOWN:
            self.move(0, 0, -self.move_dist, 0)
        elif action == Action.RIGHT:
            self.move(0, self.move_dist, 0, 0)
        elif action == Action.LEFT:
            self.move(0, -self.move_dist, 0, 0)
        elif action == Action.FORWARD:
            self.move(self.move_dist, 0, 0, 0)
        elif action == Action.BACKWARD:
            self.move(-self.move_dist, 0, 0, 0)
        elif action == Action.HOVER:
            self.move(0, 0, 0, 0)
    
    def move(self, lin_x, lin_y, lin_z, ang_yaw):
        """Send a relative move and wait until the drone is hovering again.

        Args:
            lin_x: forward (+) / backward (-) displacement.
            lin_y: right (+) / left (-) displacement.
            lin_z: up (+) / down (-) displacement; negated before sending
                because `moveBy` takes down as positive.
            ang_yaw: clockwise (+) / anticlockwise (-) rotation.

        Returns:
            True when the move and the return to hovering both succeeded.
        """
        return self.drone(
            moveBy(lin_x, lin_y, -lin_z, ang_yaw)  # (forward-backward, right-left, down-up, clockwise-anticlockwise)
            >> FlyingStateChanged(state="hovering", _timeout=5)
        ).wait().success()
                            
    def takeoff(self):
        """Take off unless already hovering; assert that hovering is reached.

        Raises:
            AssertionError: if neither the "already hovering" check nor the
                GPS-fix-then-take-off sequence succeeds.
        """
        assert self.drone(
                    FlyingStateChanged(state="hovering", _policy="check")
                    | (
                        GPSFixStateChanged(fixed=1, _timeout=10)
                        >> (
                            TakeOff(_no_expect=True)
                            & FlyingStateChanged(state="hovering", _policy="wait", _timeout=5)
                        )
                    )
                ).wait().success()
        
    def reset(self):
        """Fly back to the recorded home location and return the observation.

        Returns:
            The placeholder observation ``[0.5, 0.5, 0.5]``.
        """
        latitude, longitude, altitude = self.home["latitude"], self.home["longitude"], self.home["altitude"]
        
        # NOTE(review): `altitude` comes from GpsLocationChanged; confirm it
        # uses the same reference as the altitude argument of moveTo.
        self.drone(
            moveTo(latitude,  longitude, altitude, "HEADING_DURING", 90.0)
            >> moveToChanged(status="DONE", _timeout=15)
        ).wait()
    
#         return self.agent_coord / self.boundary
        return [0.5, 0.5, 0.5]
    
    def render(self, mode='human'):
        """No-op; this environment has no rendering of its own."""
        pass
    
    def close(self):
        """Stop moving, land and disconnect from the drone.

        Raises:
            AssertionError: if the landing sequence does not succeed.
        """
        self.drone(PCMD(1, 0, 0, 0, 0, 0) >> FlyingStateChanged(state="hovering", _timeout=5)).wait()
        assert self.drone(Landing() >> FlyingStateChanged(state="landed")).wait().success()
        self.drone.disconnect()

In [None]:
# Fly random actions until ESC is pressed (or the keyboard listener stops),
# then return home. The try/finally ensures the drone is landed and
# disconnected even if a step raises part-way through the run.
env = AnafiEnv(move_dist=4, boundaries={"x":10, "y":8, "z_min":3, "z_max":8})
try:
    control = KeyboardCtrl()

    while not control.quit():
        action = env.action_space.sample()
        env.step(action)

    env.reset()
finally:
    env.close()

In [None]:
# Display the per-step rewards recorded by env.step() during the run above
# (each entry is the detection count returned by take_photo_single()).
env.rewards