# Import all the necessary packages

In [1]:
from enum import IntEnum
import time
import jsonrpclib
import subprocess
from subprocess import PIPE, Popen
from threading  import Thread
import sys
import re
from collections import OrderedDict

import PySimpleGUI as sg

from gym import Env, error, spaces, utils
from stable_baselines3 import DQN, PPO, A2C, TD3, SAC
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, CheckpointCallback, EvalCallback
from stable_baselines3.common import results_plotter
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.results_plotter import load_results, ts2xy, plot_results
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common import results_plotter
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.results_plotter import load_results, ts2xy, plot_results

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import multivariate_normal

import os
import requests
import shutil
import tempfile
import xml.etree.ElementTree as ET
from io import StringIO, BytesIO

import cv2
import numpy as np
import torch
from PIL import Image
from IPython.display import clear_output
import gym
from cv2 import QRCodeDetector
from pyzbar import pyzbar

import olympe
from olympe.messages.ardrone3.Piloting import TakeOff, Landing, moveBy, PCMD, moveTo
from olympe.messages.ardrone3.PilotingState import FlyingStateChanged, PositionChanged, GpsLocationChanged, moveToChanged
from olympe.enums.ardrone3.PilotingState import FlyingStateChanged_State as FlyingState
from olympe.messages.ardrone3.GPSSettingsState import GPSFixStateChanged, HomeChanged
from olympe.messages.gimbal import set_target, attitude
from olympe.messages.camera import (
    set_camera_mode,
    set_photo_mode,
    take_photo,
    photo_progress,
)
from olympe.media import (
    media_created,
    resource_created,
    media_removed,
    resource_removed,
    resource_downloaded,
    indexing_state,
    delete_media,
    download_media,
    download_media_thumbnail,
    MediaEvent,
)

from pynput.keyboard import Listener, Key, KeyCode
from collections import defaultdict

# Olympe logging configuration, laid out so the indentation matches the
# actual brace nesting.
# NOTE(review): "ulog" is a top-level key here (a sibling of "loggers"), not
# an entry inside "loggers" -- presumably it was meant to be nested there;
# confirm against the Olympe logging documentation before moving it.
olympe.log.update_config({
    "loggers": {
        "olympe": {
            "handlers": [],
        },
    },
    "ulog": {
        "level": "OFF",
        "handlers": [],
    },
})

# Define the constants

In [2]:
# Drone address and media port; each can be overridden through the
# environment variable of the same name.
DRONE_IP = os.environ.get("DRONE_IP", "10.202.0.1")
DRONE_MEDIA_PORT = os.environ.get("DRONE_MEDIA_PORT", "80")

# Base URL of the drone's HTTP server and of its media REST endpoint.
ANAFI_URL = f"http://{DRONE_IP}/"
ANAFI_MEDIA_API_URL = f"{ANAFI_URL}api/v1/media/medias/"

# Define the classes

## Action

In [3]:
class Action:
    """Discrete movement controller for a drone flying over a 5x5 grid of cells.

    Cells are numbered 1..25 row by row: the id grows by 1 towards the right
    and by 5 towards the back, so cell 1 is the forward-left corner, cell 13
    is the centre (the home position) and cell 25 the backward-right corner.
    A move that would leave the grid keeps the drone in its current cell.
    """

    # Available moves; the integer values are the action ids (0..8).
    Move = IntEnum(
        'MOVE',
        'FORWARD BACKWARD LEFT RIGHT FORWARD_LEFT FORWARD_RIGHT BACKWARD_LEFT BACKWARD_RIGHT HOVER',
        start=0
    )

    # Number of cells along each side of the square grid.
    _GRID_SIZE = 5

    def __init__(self, drone):
        """Remember *drone* and its current GPS location as the grid centre.

        :param drone: object that is callable with Olympe expectations and
            exposes ``get_state``.
        """
        self.drone = drone
        self.home = self.drone.get_state(GpsLocationChanged)

        self.current_cell = self._get_cell(13)
        # Cells lying on each edge of the grid: a move towards that edge is
        # rejected when the drone is in one of them.
        self.invalid_left_cells = [1, 6, 11, 16, 21]
        self.invalid_forward_cells = [1, 2, 3, 4, 5]
        self.invalid_right_cells = [5, 10, 15, 20, 25]
        self.invalid_backward_cells = [21, 22, 23, 24, 25]

    def take_action(self, action):
        """Fly to the cell designated by *action* (if any) and record it.

        :param action: a ``Move`` member or its integer value.
        :return: ``(old_cell_id, new_cell_id, action_name)``; both ids are
            equal when the drone hovers or the move would leave the grid.
        :raises ValueError: if *action* is not a valid ``Move`` value.
        """
        next_cell = self._get_cell(self._get_next_cell_id(action))
        action_name = self._get_action_name(action)

        old_cell_id, new_cell_id = self.current_cell["id"], next_cell["id"]
        if old_cell_id != new_cell_id:
            self._move_to_cell(next_cell)
            self.current_cell = next_cell

        return old_cell_id, new_cell_id, action_name

    def reset(self):
        """Fly back to the centre cell (13).

        :return: ``(old_cell_id, new_cell_id)``.
        """
        next_cell = self._get_cell(13)
        self._move_to_cell(next_cell)

        old_cell_id, new_cell_id = self.current_cell["id"], next_cell["id"]
        self.current_cell = next_cell

        return old_cell_id, new_cell_id

    def _get_cell(self, cell_id):
        """Return the cell description for the 1-based *cell_id*."""
        return self._cell_coords[cell_id - 1]

    def _get_action_name(self, action):
        """Return a human readable label such as ``"Moving Forward_left"``.

        ``Move(...).name`` is used rather than parsing ``str(member)``:
        since Python 3.11 ``str()`` of an ``IntEnum`` member is just the
        number, which made the previous ``split(".")[1]`` raise IndexError.
        """
        direction = self.Move(action).name.capitalize()
        return "Hovering" if "hover" in direction.lower() else "Moving " + direction

    def _move_to_cell(self, next_cell):
        """Send a ``moveTo`` towards *next_cell* and wait for its completion.

        NOTE(review): the outcome of the expectation is not inspected, so the
        callers update ``current_cell`` even when the move times out.
        """
        self.drone(
            moveTo(next_cell["latitude"],  next_cell["longitude"], next_cell["altitude"], "HEADING_DURING", 90.0)
            >> moveToChanged(status="DONE", _timeout=15)
        ).wait()

    def _get_next_cell_id(self, action):
        """Return the id of the cell reached by *action* from the current cell.

        The current id is returned for ``HOVER`` and for moves that would
        cross a grid edge.

        :raises ValueError: if *action* is not a valid ``Move`` value
            (previously this surfaced as an ``UnboundLocalError``).
        """
        move = self.Move(action)
        current_id = self.current_cell["id"]
        if move == self.Move.HOVER:
            return current_id

        forward, backward = self.invalid_forward_cells, self.invalid_backward_cells
        left, right = self.invalid_left_cells, self.invalid_right_cells
        # move -> (id offset, cells from which the move would leave the grid)
        transitions = {
            self.Move.FORWARD: (-5, forward),
            self.Move.BACKWARD: (5, backward),
            self.Move.LEFT: (-1, left),
            self.Move.RIGHT: (1, right),
            self.Move.FORWARD_LEFT: (-6, forward + left),
            self.Move.FORWARD_RIGHT: (-4, forward + right),
            self.Move.BACKWARD_LEFT: (4, backward + left),
            self.Move.BACKWARD_RIGHT: (6, backward + right),
        }
        offset, blocked_cells = transitions[move]
        if current_id in blocked_cells:
            return current_id
        return current_id + offset

    @property
    def _cell_coords(self):
        """Return the 25 cells as ``OrderedDict``s ordered by id.

        Each cell has the keys ``id``, ``latitude``, ``longitude`` and
        ``altitude``; coordinates are offsets from ``self.home``.
        """
        altitude = 6.0
        dlong = 6.8e-5 # in degrees == 5 meters along x-axis (forward[+]-backward[-])
        dlat = 7.2e-5 # in degrees == 8 meters along y-axis (left[+]-right[-])

        home_lat = self.home["latitude"]
        home_long = self.home["longitude"]

        size = self._GRID_SIZE
        half = size // 2
        cells = []
        for row in range(size):  # row 0 is the most forward row
            for col in range(size):  # col 0 is the left-most column
                cells.append(OrderedDict([
                    ('id', row * size + col + 1),
                    ('latitude', home_lat + (half - col) * dlat),
                    ('longitude', home_long + (half - row) * dlong),
                    ('altitude', altitude),
                ]))
        return cells

    def __len__(self):
        """Return the number of available moves (size of the action space)."""
        return len(self.Move)

## Drone

In [4]:
class Drone:
    """Agent-side wrapper around an Olympe drone for the grid-search task.

    The wrapper connects to the drone, takes off and then exposes
    ``take_action`` / ``reset``: both move the drone between grid cells
    through ``Action`` and turn what is detected in the reached cell into
    the observation ``{t, cell_id, [I1, ..., In]}`` and a scalar reward.
    """

    def __init__(self, drone_ip, num_targets, max_timestep, is_training=False):
        """Connect to the drone, take off and initialise the episode state.

        :param drone_ip: address handed to ``olympe.Drone``.
        :param num_targets: number of targets tracked in the observation.
        :param max_timestep: number of steps after which an episode ends.
        :param is_training: when true the camera is not configured and
            detection relies on ``target_positions`` instead of photos.
        """
        self.is_training = is_training
        self.num_targets = num_targets
        self.max_timestep = max_timestep
        self.timestep = 0
        self.visited_targets = np.zeros(self.num_targets, dtype=bool)
        # Cell id of every target; overwritten by the environment on reset.
        self.target_positions = np.zeros(self.num_targets, dtype=np.uint8)

        self.drone = olympe.Drone(drone_ip)
        self.drone.connect()

        # NOTE(review): the returned expectation is not waited on, so this
        # may not actually block until a GPS fix -- confirm whether a
        # `.wait()` was intended.
        self.drone(GPSFixStateChanged(_policy = 'wait'))
        self._takeoff()
        if not is_training:
            self._setup_camera()

        self.action = Action(self.drone)

    def take_action(self, action):
        """Execute *action* and return the Gym-style step tuple.

        :return: ``(state, reward, done, info)`` where *info* holds the
            action label and the ``"Cell a --> Cell b"`` transition.
        """
        old_cell, new_cell, action_name = self.action.take_action(action)
        self.timestep += 1
        detected_targets = self._detect_targets(new_cell)

        # The reward must be computed before _get_state marks the detected
        # targets as visited.
        reward = self._get_reward(detected_targets)
        state = self._get_state(new_cell, detected_targets)

        done = bool(self.timestep >= self.max_timestep or np.all(self.visited_targets))
        if done:
            # `state` is a copy, so clearing here does not alter it.
            self.visited_targets[:] = False
        info = {
            "action": str(action_name),
            "direction": "Cell " + str(old_cell) + " --> " + "Cell " + str(new_cell)
        }

        return state, reward, done, info

    def reset(self):
        """Fly back to the centre cell and return the first observation.

        The visited flags are cleared here as well, so a reset issued before
        the episode finished does not leak targets into the next episode.
        """
        old_cell, new_cell = self.action.reset()
        self.visited_targets[:] = False
        detected_targets = self._detect_targets(new_cell)
        self.timestep = 0
        return self._get_state(new_cell, detected_targets)

    def _get_state(self, new_cell, detected_targets):
        """Mark *detected_targets* as visited and build the observation.

        :return: ``uint8`` array ``[t, cell_id, I1, ..., In]``.
        """
        self.visited_targets[detected_targets] = True

        return np.concatenate(([self.timestep, new_cell], self.visited_targets)).astype(np.uint8)

    def _get_reward(self, detected_targets):
        """Return 1.5 per newly detected target, or -1 when there is none."""
        reward_scale = 1.5
        num_new_targets = np.count_nonzero(detected_targets & ~self.visited_targets)
        return reward_scale * num_new_targets if num_new_targets > 0 else -1

    def _detect_targets(self, cell_id):
        """Return a boolean mask of the targets detected in *cell_id*.

        In training mode the mask comes from ``target_positions``.  Otherwise
        a photo is taken and every decoded code whose payload is a target
        number in ``1..num_targets`` is flagged.  (The photo path used to be
        unreachable because of an unconditional early return.)
        """
        if self.is_training:
            return self.target_positions == cell_id

        detected_targets = np.zeros(self.num_targets, dtype=bool)

        img = self._take_photo()
        if img is None:
            return detected_targets

        for result in pyzbar.decode(img):
            try:
                idx = int(result.data) - 1
            except ValueError:
                # Payload is not a target number.
                continue
            # Reject out-of-range numbers explicitly: a plain index would let
            # a payload of 0 wrap around to the last target.
            if 0 <= idx < self.num_targets:
                detected_targets[idx] = True

        return detected_targets

    @staticmethod
    def _require(ok, message):
        """Raise ``AssertionError(message)`` unless *ok*.

        Used instead of ``assert <command>`` because ``python -O`` strips
        assert statements together with the command they wrap.
        """
        if not ok:
            raise AssertionError(message)

    def _setup_camera(self):
        """Wait for media indexing, select single-photo mode, point the gimbal down.

        :raises AssertionError: if a configuration step fails.
        """
        self._require(self.drone.media_autoconnect, "media autoconnect is disabled")
        self.drone.media.integrity_check = True
        is_indexed = False
        while not is_indexed:
            is_indexed = self.drone.media(
                indexing_state(state="indexed")
            ).wait(_timeout=5).success()

        self.drone(set_camera_mode(cam_id=0, value="photo")).wait()

        photo_mode_set = self.drone(
            set_photo_mode(
                cam_id=0,
                mode="single",
                format= "rectilinear",
                file_format="jpeg",
                # the following are ignored in photo single mode
                burst="burst_14_over_1s",
                bracketing="preset_1ev",
                capture_interval=5.0,
            )
        ).wait().success()
        self._require(photo_mode_set, "set_photo_mode failed")

        gimbal_set = self.drone(
            set_target(
                gimbal_id=0,
                control_mode="position",
                yaw_frame_of_reference="none",
                yaw=0.0,
                pitch_frame_of_reference="absolute",
                pitch=-90.0,
                roll_frame_of_reference="none",
                roll=0.0,
                )
            >> attitude(
                pitch_absolute=-90.0, _policy="wait", _float_tol=(1e-3, 1e-1)
                )
            ).wait(_timeout=20).success()
        self._require(gimbal_set, "gimbal set_target failed")

    def _take_photo(self):
        """Take a photo, download it and delete it from the drone.

        :return: the photo as a PIL image, or ``None`` when the photo was not
            saved in time or could not be downloaded.
        """
        # The expectation is registered before the command is sent.
        photo_saved = self.drone(photo_progress(result="photo_saved", _policy="wait"))
        self.drone(take_photo(cam_id=0)).wait()

        for _ in range(3):
            if photo_saved.wait(_timeout=5).success():
                break
        else:
            print("take_photo timedout")
            return None

        media_id = photo_saved.received_events().last().args["media_id"]
        img = self._download_photo(media_id)
        # Delete the media even when the download failed, so failed
        # attempts do not pile up on the drone.
        self._delete_photo(media_id)
        return img

    def _download_photo(self, media_id):
        """Fetch the first resource of *media_id*; return a PIL image or ``None``."""
        try:
            for _ in range(5):
                media_info_response = requests.get(ANAFI_MEDIA_API_URL + media_id, timeout=10)
                if media_info_response.status_code == 200:
                    break
            media_info_response.raise_for_status()

            resource = media_info_response.json()["resources"][0]
            image_response = requests.get(ANAFI_URL + resource["url"], stream=True, timeout=10)
            image_response.raise_for_status()
            content = image_response.content
        except requests.exceptions.RequestException as err:
            print(err)
            return None

        return Image.open(BytesIO(content))

    def _delete_photo(self, media_id):
        """Try up to three times to delete *media_id*; return whether it worked."""
        delete = None
        for _ in range(3):
            delete = delete_media(media_id, _timeout=10)
            if self.drone.media(delete).wait().success():
                return True
        print("Failed to delete media {} {}".format(media_id, delete.explain()))
        return False

    def _takeoff(self):
        """Take off, retrying once if the first attempt does not succeed."""
        takeoff_success = self._success_if_takeoff()
        if not takeoff_success:
            print("Retrying taking off...")
            takeoff_success = self._success_if_takeoff()

    def _success_if_takeoff(self):
        """Return True once hovering, either already or after a ``TakeOff``."""
        return self.drone(
                FlyingStateChanged(state="hovering")
                | (TakeOff() & FlyingStateChanged(state="hovering"))
            ).wait(10).success()

    def _land(self):
        """Stop any motion, then land.

        :raises AssertionError: if the landing expectation fails.
        """
        self.drone(PCMD(1, 0, 0, 0, 0, 0) >> FlyingStateChanged(state="hovering", _timeout=5)).wait()
        landed = self.drone(Landing() >> FlyingStateChanged(state="landed")).wait().success()
        self._require(landed, "landing failed")

    def __del__(self):
        """Land and disconnect when the wrapper is garbage collected."""
        drone = getattr(self, "drone", None)
        if drone is None:
            # __init__ failed before the drone object was created.
            return
        try:
            self._land()
        finally:
            drone.disconnect()

## Simulation

In [5]:
class Simulation:
    """Static helpers for the simulator's JSON-RPC interface and the moving-target plugin files."""

    # JSON-RPC proxy shared by all helpers.
    sphinx = jsonrpclib.Server('http://127.0.0.1:8383')

    def __init__(self):
        pass

    @staticmethod
    def disable_battery():
        """Set the ``discharge_speed_factor`` of the ``anafi4k`` battery to ``'0'``."""
        Simulation.sphinx.SetParam(machine='anafi4k',
                                   object='lipobattery/lipobattery',
                                   parameter='discharge_speed_factor',
                                   value='0')

    @staticmethod
    def reset_world():
        """Trigger the ``world_reset_all`` action on the simulated world."""
        Simulation.sphinx.TriggerAction(machine='world',
                                        object='fwman/fwman',
                                        action='world_reset_all')

    @staticmethod
    def gen_targets_pos(num_targets, mu_x=-7.5, variance_x=3, mu_y=-11, variance_y=5):
        """Sample target positions and return the grid cell id of each.

        Positions are drawn from a 2-D normal distribution with mean
        ``(mu_x, mu_y)`` and diagonal covariance ``(variance_x, variance_y)``.

        :return: integer array of length *num_targets* holding cell ids in
            1..25, or 0 for a position below the lowest x or y boundary.
            There is no upper-bound check: positions beyond the highest
            boundary are assigned to the outermost row/column.
        """
        locs = multivariate_normal([mu_x, mu_y], [[variance_x, 0], [0, variance_y]]).rvs(size=num_targets)
        # rvs() drops the sample axis for a single sample; restore it so
        # num_targets == 1 works.
        locs = np.reshape(locs, (-1, 2))
        x, y = locs[:, 0], locs[:, 1]

        # Lower boundary of each grid row (x) and column (y), listed in the
        # order of increasing cell id.
        x_lower_bounds = [7.5, 2.5, -2.5, -7.5, -12.5]
        y_lower_bounds = [12, 4, -4, -12, -20]

        # np.select keeps the first matching condition, so the row-major
        # order of this list is what maps a position to its cell.
        cell_boundaries = [
            (x > x_bound) & (y > y_bound)
            for x_bound in x_lower_bounds
            for y_bound in y_lower_bounds
        ]
        cell_ids = np.arange(25) + 1

        return np.select(cell_boundaries, cell_ids)

    @staticmethod
    def _write_flag(path, value):
        """Overwrite the plugin control file at *path* with *value*."""
        with open(path, "w") as f:
            f.write(value)

    @staticmethod
    def cease_targets():
        """Write ``"0"`` to the plugin's movement toggle file."""
        Simulation._write_flag("../plugins/moving_target/toggle_movement.txt", "0")

    @staticmethod
    def move_targets():
        """Write ``"1"`` to the plugin's movement toggle file."""
        Simulation._write_flag("../plugins/moving_target/toggle_movement.txt", "1")

    @staticmethod
    def reset_targets(release=False):
        """Write ``"1"`` to the plugin's reset file, or ``"0"`` when *release* is true."""
        Simulation._write_flag(
            "../plugins/moving_target/reset_position.txt",
            "0" if release else "1",
        )
        

# Helper classes

In [None]:
class EnvManager():
    """Placeholder for environment bookkeeping; every method is a no-op."""

    def __init__(self):
        """Create the manager; nothing is initialised."""
        return None

    def store_state(self, state):
        """Not implemented: *state* is ignored and ``None`` is returned."""
        return None

    def store_reward(self, state):
        """Not implemented: the argument is ignored and ``None`` is returned."""
        return None

    def get_last_state(self):
        """Not implemented: always returns ``None``."""
        return None

class FlightListener(olympe.EventListener):
    """Olympe event listener that downloads the first resource of every newly created media."""

    default_queue_size = 100

    def __init__(self, *args, **kwds):
        """Forward all arguments to ``olympe.EventListener``."""
        super().__init__(*args, **kwds)

    @olympe.listen_event(media_created())
    def onMediaCreated(self, event, scheduler):
        """Download and open the first resource of the media carried by *event*.

        The opened image is not stored or returned.
        """
        resources = list(event.media.resources.values())
        download_url = ANAFI_URL + resources[0].url

        response = requests.get(download_url, stream=True)
        response.raise_for_status()

        payload = BytesIO(response.content)
        img = Image.open(payload)

# Define a Gym Environment

In [6]:
class AnafiEnv(Env):
    """Gym environment in which a drone searches a 25-cell grid for targets.

    Actions are the discrete moves of ``Drone.action``; observations are
    ``[t, cell_id, I1, ..., In]`` where ``Ik`` tells whether target *k* has
    been visited.
    """

    def __init__(self, num_targets, max_timestep, drone_ip="10.202.0.1", is_training=False):
        """Disable battery discharge, create the drone agent and define the spaces.

        :param num_targets: number of targets encoded in the observation.
        :param max_timestep: maximum number of steps per episode.
        :param drone_ip: address of the drone to connect to.
        :param is_training: forwarded to ``Drone``.
        """
        super(AnafiEnv, self).__init__()

        Simulation.disable_battery()
        # Simulation.cease_targets()  (target movement control is disabled)

        self.num_targets = num_targets
        self.max_timestep = max_timestep
        self.begin(num_targets, max_timestep, is_training, drone_ip)

        self.action_space = spaces.Discrete(len(self.agent.action))
        self.observation_space = spaces.Box( # {t, cell_id, [I1, I2, I3, ..., In]}
            low=np.array([0, 1] + num_targets*[0]),
            high=np.array([max_timestep, 25] + num_targets*[1]),
            dtype=np.uint8,
        )

        # Simulation.move_targets()  (target movement control is disabled)

    def begin(self, num_targets, max_timestep, is_training, drone_ip="10.202.0.1"):
        """Create the ``Drone`` agent and store it in ``self.agent``."""
        self.agent = Drone(drone_ip, num_targets, max_timestep, is_training)

    def step(self, action):
        """Apply *action* and return ``(obs, reward, done, info)`` from the agent."""
        obs, reward, done, info = self.agent.take_action(action)

        return obs, reward, done, info

    def reset(self):
        """Draw new target positions, reset the agent and return the first observation."""
        # Simulation.reset_targets()  (target movement control is disabled)
        self.agent.target_positions = Simulation.gen_targets_pos(self.num_targets)
        return self.agent.reset()

    def render(self, mode='human'):
        """Rendering is not implemented."""
        pass

    def close(self):
        """Release the agent; calling ``close`` more than once is harmless.

        Dropping the last reference to the agent lets ``Drone.__del__`` land
        and disconnect the drone.
        """
        # Simulation.cease_targets()  (target movement control is disabled)
        if hasattr(self, "agent"):
            del self.agent

# Run the simulation

In [10]:
def disp_info(action, observation, reward, done, info):
    """Print a three-line summary of one environment step.

    Only ``observation``, ``reward`` and the ``"action"`` / ``"direction"``
    entries of ``info`` are printed; ``action`` and ``done`` are accepted
    but not used.
    """
    report = (
        ("Action:", info["action"] + ",", info["direction"]),
        ("State:", observation),
        ("Reward:", reward),
    )
    for fields in report:
        print(*fields)

In [None]:
# Drop the environment left over from a previous run of this cell, if any.
try:
    del env
except NameError:
    pass

env = AnafiEnv(num_targets=10, max_timestep=15, is_training=True)
try:
    observation = env.reset()
    # Fixed action sequence, replayed cyclically.
    actions = [7, 4, 7, 3, 1, 2, 8, 4, 1, 5, 6]
    for i in range(1000):
        # action = env.action_space.sample()
        action = actions[i % len(actions)]
        observation, reward, done, info = env.step(action)
        disp_info(action, observation, reward, done, info)

        if done:
            print("The episode has ended. Resetting environment...")
            observation = env.reset()
            disp_info(action, observation, 0, done, info)
finally:
    # Close the environment even if a step raises, so the agent is released.
    env.close()

# Check the environment

In [None]:
# Build a short-episode environment and run the Stable-Baselines3
# environment checker on it.
env = AnafiEnv(max_timestep=8, num_targets=10)
check_env(env)

# Train the agent

In [7]:
class SaveOnBestTrainingRewardCallback(BaseCallback):
    """
    Save the model whenever the mean training reward reaches a new best.

    The check runs every ``check_freq`` calls and reads the episode rewards
    from the monitor files found in ``log_dir``.

    :param check_freq: number of calls between two checks.
    :param log_dir: folder containing the files written by the ``Monitor``
      wrapper; the model is saved there as ``best_model``.
    :param verbose: verbosity level; progress is printed when it is > 0.
    """
    def __init__(self, check_freq: int, log_dir: str, verbose: int = 1):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.log_dir = log_dir
        self.save_path = os.path.join(log_dir, 'best_model')
        self.best_mean_reward = -np.inf

    def _init_callback(self) -> None:
        """Nothing to prepare: the save folder is not created here."""
        return None

    def _on_step(self) -> bool:
        """Run the periodic check; always returns ``True`` so training continues."""
        if self.n_calls % self.check_freq != 0:
            return True

        # Episode rewards recorded so far, indexed by timestep.
        timesteps, episode_rewards = ts2xy(load_results(self.log_dir), 'timesteps')
        if len(timesteps) == 0:
            return True

        # Mean training reward over the last 100 episodes.
        mean_reward = np.mean(episode_rewards[-100:])
        chatty = self.verbose > 0
        if chatty:
            print(f"Num timesteps: {self.num_timesteps}")
            print(f"Best mean reward: {self.best_mean_reward:.2f} - Last mean reward per episode: {mean_reward:.2f}")

        if mean_reward > self.best_mean_reward:
            self.best_mean_reward = mean_reward
            if chatty:
                print(f"Saving new best model to {self.save_path}")
            self.model.save(self.save_path)

        return True

In [8]:
# Discard the objects left over from a previous execution of this cell.
try:
    del env
    del model
except Exception:
    pass

# Logging layout: everything goes to logs/, files are prefixed with the run number.
run = 2
log_dir = "logs/"
os.makedirs(log_dir, exist_ok=True)
monitor_file = os.path.join(log_dir, str(run))

# Training environment, wrapped for episode monitoring and vectorisation.
env = DummyVecEnv([
    lambda: Monitor(
        AnafiEnv(num_targets=10, max_timestep=15, is_training=True),
        monitor_file,
    )
])

callback = SaveOnBestTrainingRewardCallback(check_freq=512, log_dir=log_dir)

# To start from scratch instead of resuming the previous run:
# model = PPO("MlpPolicy", env, n_steps=1024, verbose=1, tensorboard_log=log_dir)
model = PPO.load(os.path.join(log_dir, f"{run - 1}_run"), env)
model.learn(
    total_timesteps=15_000,
    callback=callback,
    tb_log_name=f"PPO_{run}",
    reset_num_timesteps=False,
)
model.save(os.path.join(log_dir, f"{run}_run"))

Using cuda device
Logging to logs/PPO_1_0
Num timesteps: 512
Best mean reward: -inf - Last mean reward per episode: -8.97
Saving new best model to logs/best_model
Num timesteps: 1024
Best mean reward: -8.97 - Last mean reward per episode: -10.12
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 14.8     |
|    ep_rew_mean     | -10.1    |
| time/              |          |
|    fps             | 0        |
|    iterations      | 1        |
|    time_elapsed    | 2544     |
|    total_timesteps | 1024     |
---------------------------------
Num timesteps: 1536
Best mean reward: -8.97 - Last mean reward per episode: -9.88
Num timesteps: 2048
Best mean reward: -8.97 - Last mean reward per episode: -8.79
Saving new best model to logs/best_model
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 14.6        |
|    ep_rew_mean          | -8.79       |
| time/                   |             |

Num timesteps: 8704
Best mean reward: 0.21 - Last mean reward per episode: 1.45
Saving new best model to logs/best_model
Num timesteps: 9216
Best mean reward: 1.45 - Last mean reward per episode: 2.67
Saving new best model to logs/best_model
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 12.4        |
|    ep_rew_mean          | 2.67        |
| time/                   |             |
|    fps                  | 0           |
|    iterations           | 9           |
|    time_elapsed         | 18938       |
|    total_timesteps      | 9216        |
| train/                  |             |
|    approx_kl            | 0.013835983 |
|    clip_fraction        | 0.143       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.89       |
|    explained_variance   | 0.0847      |
|    learning_rate        | 0.0003      |
|    loss                 | 9.59        |
|    n_updates            | 80          |
| 

2022-03-07 15:08:03,012 [ERROR] ulog - pomp - epoll_ctl(fd=73) err=9(Bad file descriptor)
2022-03-07 15:08:03,016 [ERROR] ulog - pomp - epoll_ctl op=2 cb=0x7fc467e194e0 userdata=0x7fc3037edf90


# Continue training

In [None]:
# Discard the objects left over from a previous execution of this cell.
try:
    del env
    del model
except Exception as e:
    pass

# Create log dir
log_dir = "tmp/"
os.makedirs(log_dir, exist_ok=True)

num_targets=10
max_timestep=15

env = AnafiEnv(num_targets=num_targets, max_timestep=max_timestep)
env = Monitor(env, log_dir)

# SaveOnBestTrainingRewardCallback only accepts check_freq, log_dir and
# verbose; the extra num_targets / max_timestep keywords raised a TypeError.
callback = SaveOnBestTrainingRewardCallback(check_freq=25, log_dir=log_dir)

checkpoint_callback = CheckpointCallback(save_freq=10, save_path='./logs/')

# Resume from a saved checkpoint and keep training with both callbacks.
model = PPO.load("logs/rl_model_180_steps.zip", env, verbose=1)
model.learn(total_timesteps=150, callback=[callback, checkpoint_callback])
model.save("best_model.zip")

# Test the agent

In [None]:
# Discard the objects left over from a previous execution of this cell.
try:
    del env
    del model
except Exception:
    pass

# Load the trained policy into a fresh (non-training) environment.
env = AnafiEnv(max_timestep=15, num_targets=10)
model = PPO.load("third_run", env, verbose=1)

# Roll the deterministic policy out for 1000 steps, resetting the
# environment whenever an episode ends.
obs = env.reset()
for i in range(1000):
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    if not done:
        continue
    obs = env.reset()