In [38]:
%reload_ext autoreload
%autoreload 2
%matplotlib notebook

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import torch
from torch.distributions import Bernoulli, MultivariateNormal, Categorical
import gym

import notebook_setup
from tqdm.auto import tqdm, trange
from systems import CartPoleEnv
from systems import CartPoleContinuousEnv
from ppo import ActorCriticDiscrete, ActorCriticMultiBinary, ActorCriticBox, PPO, DEVICE

# Policies

## Discrete

In [None]:
# Settings for the discrete-action run; they are unpacked as keyword
# arguments into PPO(...) together with the environment and policy class.
ppo_params = {
    'state_dim': 4,
    'action_dim': 2,
    'n_latent_var': 32,
    'lr': 0.02,
    'epochs': 5,
    'update_interval': 500,
}

In [None]:
# Train PPO with the discrete actor-critic policy on a fresh CartPoleEnv.
# NOTE(review): whether the 30000 passed to learn() counts timesteps or
# episodes is defined in ppo.py, which is not visible here -- confirm.
agent = PPO(CartPoleEnv(), ActorCriticDiscrete, **ppo_params)
rewards = agent.learn(30000)

# Plot on a figure created in this cell. With `%matplotlib notebook` a bare
# plt.scatter() targets the current figure, which can still be a live figure
# belonging to an earlier cell.
fig, ax = plt.subplots()
ax.scatter(np.arange(len(rewards)), rewards)
ax.set(
    title='PPO training rewards on CartPoleEnv (ActorCriticDiscrete)',
    xlabel='Index',
    ylabel='Reward',
);  # trailing ';' keeps the artist repr out of the cell output

In [None]:
# Bare expression: instantiate the discrete policy so the notebook shows its
# repr as the cell output. The instance is not bound to a name.
ActorCriticDiscrete(
    state_dim=4,
    action_dim=2,
    n_latent_var=32,
)

## Continuous

In [None]:
# Continuous-action environment, created through gym's registry.
env = gym.make('LunarLanderContinuous-v2')

# State and action sizes are taken from the first dimension of the
# environment's observation and action spaces rather than hard-coded.
ppo_params = {
    'state_dim': env.observation_space.shape[0],
    'action_dim': env.action_space.shape[0],
    'n_latent_var': 64,
    'lr': 0.0003,
    'epochs': 75,
    'update_interval': 3000,
}

In [None]:
# Train PPO with the Box (continuous) actor-critic policy on the `env` and
# `ppo_params` currently bound in the kernel.
# NOTE(review): the unit of the 3000 passed to learn() is defined in ppo.py,
# which is not visible here -- confirm.
agent = PPO(env, ActorCriticBox, **ppo_params)
rewards = agent.learn(3000)

# Plot on a figure created in this cell. With `%matplotlib notebook` a bare
# plt.scatter() targets the current figure, which can still be a live figure
# belonging to an earlier cell.
fig, ax = plt.subplots()
ax.scatter(np.arange(len(rewards)), rewards)
ax.set(
    title='PPO training rewards (ActorCriticBox)',
    xlabel='Index',
    ylabel='Reward',
);  # trailing ';' keeps the artist repr out of the cell output

In [None]:
# Continuous-action cart-pole environment from the local `systems` package.
env = CartPoleContinuousEnv()

# State and action sizes are taken from the first dimension of the
# environment's observation and action spaces rather than hard-coded.
ppo_params = {
    'state_dim': env.observation_space.shape[0],
    'action_dim': env.action_space.shape[0],
    'n_latent_var': 32,
    'lr': 0.02,
    'epochs': 25,
    'update_interval': 500,
}

In [None]:
# Train PPO with the Box (continuous) actor-critic policy on the `env` and
# `ppo_params` currently bound in the kernel.
# NOTE(review): the unit of the 10000 passed to learn() is defined in ppo.py,
# which is not visible here -- confirm.
agent = PPO(env, ActorCriticBox, **ppo_params)
rewards = agent.learn(10000)

# Plot on a figure created in this cell. With `%matplotlib notebook` a bare
# plt.scatter() targets the current figure, which can still be a live figure
# belonging to an earlier cell.
fig, ax = plt.subplots()
ax.scatter(np.arange(len(rewards)), rewards)
ax.set(
    title='PPO training rewards (ActorCriticBox)',
    xlabel='Index',
    ylabel='Reward',
);  # trailing ';' keeps the artist repr out of the cell output

## Discretized Continuous

In [None]:
class ActorCriticBoxDiscrete(ActorCriticBox):
    """ActorCriticBox variant whose predicted action is collapsed to 0 or 1.

    NOTE(review): the parent's predict()/evaluate() live in ppo.py and are not
    visible here; only the post-processing below is defined by this class.
    """

    def predict(self, state):
        """Return ``(action, logprob)`` with the action reduced to an int.

        The parent's action is read as a scalar via ``.item()``, clipped to
        the range [0, 1] and rounded. The parent's log-probability is
        returned untouched.
        """
        raw_action, logprob = super().predict(state)
        clipped = np.clip(raw_action.item(), 0, 1)
        return int(np.round(clipped)), logprob

    def evaluate(self, state, action):
        """Forward to the parent's evaluate() and return its three results."""
        logprobs, value, entropy = super().evaluate(state, action)
        return logprobs, value, entropy

# Cart-pole environment from the local `systems` package, driven here by a
# policy with a single action dimension.
env = CartPoleEnv()

ppo_params = {
    'state_dim': 4,
    'action_dim': 1,
    'n_latent_var': 64,
    'lr': 0.002,
    'epochs': 50,
    'update_interval': 500,
}

# Train PPO with the discretized Box policy on the `env` and `ppo_params`
# currently bound in the kernel.
# NOTE(review): the unit of the 10000 passed to learn() is defined in ppo.py,
# which is not visible here -- confirm.
agent = PPO(env, ActorCriticBoxDiscrete, **ppo_params)
rewards = agent.learn(10000)

# Plot on a figure created in this cell. With `%matplotlib notebook` a bare
# plt.scatter() targets the current figure, which can still be a live figure
# belonging to an earlier cell.
fig, ax = plt.subplots()
ax.scatter(np.arange(len(rewards)), rewards)
ax.set(
    title='PPO training rewards (ActorCriticBoxDiscrete)',
    xlabel='Index',
    ylabel='Reward',
);  # trailing ';' keeps the artist repr out of the cell output

# Quadcopter

In [6]:
from systems.quadcopter import Quadcopter, QuadcopterSupervisorEnv, Controller, QUADPARAMS, CONTROLLER_PARAMS

In [114]:
# Build the supervisor environment around a Controller that wraps a
# Quadcopter constructed with no arguments.
env = QuadcopterSupervisorEnv(Controller(Quadcopter()))

In [116]:
# Reset with no arguments and display what reset() returns; the recorded
# output is a 12-element float32 array.
env.reset()

array([17.150793  , 10.0325985 , -8.3591585 ,  0.5563135 ,  0.74002427,
        0.9572367 ,  0.46991718, -0.06050808,  0.4406542 , -0.05996131,
        0.02197874, -0.05602194], dtype=float32)

In [10]:
# Display the controller's target (recorded output: a 3-element array).
# NOTE(review): this cell's recorded execution count (10) is lower than that
# of the cell constructing `env` (114), so the stored output may come from a
# different env instance -- re-run top to bottom to confirm.
env.ctrl.target

array([ 1.36089122,  8.51193277, -8.57927884])

In [11]:
# Display the environment's start position (recorded output: a 3-element array).
# NOTE(review): this cell's recorded execution count (11) is lower than that
# of the cell constructing `env` (114), so the stored output may come from a
# different env instance -- re-run top to bottom to confirm.
env.start_pos

array([0.97627008, 4.30378733, 2.05526752])

In [94]:
# Baseline rollout: reset to a fixed start/target, then call env.step(0.)
# at every step and record the first three state components in `pos`.
env = QuadcopterSupervisorEnv(Controller(Quadcopter()))
T = 1200
pos = np.zeros(shape=(T, 3))
env.reset(
    position=(0, 0, 0),
    linear_rate=(0, 0, 0),
    orientation=(0, 0, 0),
    angular_rate=(0, 0, 0),
    target=(3, 1, 2),
)
pos[0] = env.start_pos  # row 0 holds the start position, before any step
for t in tqdm(range(1, T), leave=False):
    env.step(0.)
    pos[t] = env.ctrl.quadcopter.state[:3]

HBox(children=(FloatProgress(value=0.0, max=1199.0), HTML(value='')))

In [119]:
# Second rollout from the same start/target, this time calling
# env.step(factor) at every step; positions go into `pos_`.
# `T` is reused from the baseline rollout so both arrays have the same length.
env = QuadcopterSupervisorEnv(Controller(Quadcopter()))
factor = 0.15
pos_ = np.zeros(shape=(T, 3))
env.reset(
    position=(0, 0, 0),
    linear_rate=(0, 0, 0),
    orientation=(0, 0, 0),
    angular_rate=(0, 0, 0),
    target=(3, 1, 2),
)
pos_[0] = env.start_pos  # row 0 holds the start position, before any step
for t in tqdm(range(1, T), leave=False):
    env.step(factor)
    pos_[t] = env.ctrl.quadcopter.state[:3]

HBox(children=(FloatProgress(value=0.0, max=1199.0), HTML(value='')))

In [None]:
# Compare the two recorded rollouts: `pos` (env.step(0.)) against `pos_`
# (env.step(factor)). Top panel: 3D trajectories. Bottom panel: each
# coordinate against the step index.
fig = plt.figure(figsize=(8, 8), constrained_layout=True)
gs = fig.add_gridspec(3, 1)

# One colour per rollout, shared by both panels so a colour always refers to
# the same rollout.
baseline_color, supervised_color = 'b', 'r'

# --- 3D trajectories (upper two thirds of the figure) ---
ax_3d = fig.add_subplot(gs[0:2, 0], projection='3d')
ax_3d.plot(pos[:, 0], pos[:, 1], pos[:, 2],
           label='No supervision', color=baseline_color)
ax_3d.plot(pos_[:, 0], pos_[:, 1], pos_[:, 2],
           label=f'{factor:.2f}', color=supervised_color)
# Both rollouts were reset with the same position and target, so the most
# recent `env` is used to annotate them.
ax_3d.text(*env.start_pos, "start")
ax_3d.text(*env.ctrl.target, "end")
ax_3d.set_xlabel('x')
ax_3d.set_ylabel('y')
ax_3d.set_zlabel('z')
ax_3d.legend()

# --- Per-coordinate traces (lower third) ---
# The line style encodes the coordinate (x dotted, y solid, z dashed); the
# colour encodes the rollout.
ax = fig.add_subplot(gs[2:, 0])
for column, (axis_name, line_style) in enumerate(zip('xyz', (':', '-', '--'))):
    ax.plot(pos[:, column], color=baseline_color, linestyle=line_style,
            label=f'{axis_name} (no supervision)')
    ax.plot(pos_[:, column], color=supervised_color, linestyle=line_style,
            label=f'{axis_name} ({factor:.2f})')
ax.set_xlabel('Step')
ax.set_ylabel('Position')
ax.legend(ncol=3)

plt.show()