## Basic Example

In [7]:
import gym
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy

env = gym.make("CartPole-v1")
expert = PPO(
    policy=MlpPolicy,
    env=env,
    seed=0,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0003,
    n_epochs=10,
    n_steps=64,
)
expert.learn(1000)  # Note: set to 100000 to train a proficient expert

<stable_baselines3.ppo.ppo.PPO at 0x7f594834fd90>

In [8]:
from stable_baselines3.common.evaluation import evaluate_policy

reward, _ = evaluate_policy(expert, env, 10)
print(reward)

51.9




In [9]:
import numpy as np
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from stable_baselines3.common.vec_env import DummyVecEnv

rng = np.random.default_rng()
rollouts = rollout.rollout(
    expert,
    DummyVecEnv([lambda: RolloutInfoWrapper(env)]),
    rollout.make_sample_until(min_timesteps=None, min_episodes=50),
    rng=rng,
)
transitions = rollout.flatten_trajectories(rollouts)
# transitions have 3 attributes
# - obs
# - acts aka actions
# - next_obs aka next_actions
# - dones
# - info

In [9]:
transitions

Transitions(obs=array([[-1.3357522e-02, -3.5211290e-04,  3.4990080e-02,  1.9791661e-02],
       [-1.3364565e-02, -1.9595793e-01,  3.5385914e-02,  3.2330579e-01],
       [-1.7283723e-02, -1.3572598e-03,  4.1852027e-02,  4.1988794e-02],
       ...,
       [-1.2179057e-01, -6.1762249e-01,  1.5395266e-01,  1.0212862e+00],
       [-1.3414302e-01, -4.2484939e-01,  1.7437838e-01,  7.8062999e-01],
       [-1.4264001e-01, -6.2188470e-01,  1.8999098e-01,  1.1227086e+00]],
      dtype=float32), acts=array([0, 1, 0, ..., 1, 0, 1]), infos=array([{}, {}, {}, ..., {}, {}, {}], dtype=object), next_obs=array([[-0.01336456, -0.19595793,  0.03538591,  0.3233058 ],
       [-0.01728372, -0.00135726,  0.04185203,  0.04198879],
       [-0.01731087, -0.1970536 ,  0.0426918 ,  0.34757715],
       ...,
       [-0.13414302, -0.4248494 ,  0.17437838,  0.78063   ],
       [-0.14264001, -0.6218847 ,  0.18999098,  1.1227086 ],
       [-0.1550777 , -0.4296917 ,  0.21244515,  0.8951285 ]],
      dtype=float32), dones=

In [10]:
transitions.infos

array([{}, {}, {}, ..., {}, {}, {}], dtype=object)

In [11]:
transitions.infos.shape

(1779,)

In [7]:
transitions.acts

array([0, 1, 0, ..., 1, 0, 1])

In [4]:
from imitation.algorithms import bc

bc_trainer = bc.BC(
    observation_space=env.observation_space,
    action_space=env.action_space,
    demonstrations=transitions,
    rng=rng,
)

In [5]:
reward_before_training, _ = evaluate_policy(bc_trainer.policy, env, 10)
print(f"Reward before training: {reward_before_training}")

Reward before training: 23.7


In [6]:
bc_trainer.train(n_epochs=1)
reward_after_training, _ = evaluate_policy(bc_trainer.policy, env, 10)
print(f"Reward after training: {reward_after_training}")

0batch [00:00, ?batch/s]

---------------------------------
| batch_size        | 32        |
| bc/               |           |
|    batch          | 0         |
|    ent_loss       | -0.000693 |
|    entropy        | 0.693     |
|    epoch          | 0         |
|    l2_loss        | 0         |
|    l2_norm        | 72.5      |
|    loss           | 0.692     |
|    neglogp        | 0.693     |
|    prob_true_act  | 0.5       |
|    samples_so_far | 32        |
---------------------------------


54batch [00:00, 534.57batch/s]
55batch [00:00, 526.19batch/s][A


Reward after training: 62.0


In [7]:
bc_trainer.save_policy("bc_policy")
# Supposed to use reconstruct policy to load back in

In [10]:
# Can extract model weights from policy parameters
for i in expert.policy.parameters():
    print(i)
    print(type(i))

Parameter containing:
tensor([[-1.9988e-01,  1.2384e-01,  2.1468e-01, -3.8810e-02],
        [ 2.4016e-01,  2.1442e-01, -1.4569e-01, -2.7289e-01],
        [-2.0511e-01,  1.9610e-01,  4.1891e-02,  1.6776e-01],
        [ 1.1887e-01,  8.9274e-02,  2.1588e-01,  1.0640e-01],
        [ 1.2598e-01, -1.2922e-01,  4.7004e-02, -7.9553e-02],
        [ 2.4999e-01, -1.2543e-01, -4.2739e-01,  2.9047e-01],
        [-1.5636e-01, -2.6764e-02,  9.8490e-02, -1.5604e-01],
        [-1.0875e-01,  2.8341e-01,  2.3356e-02,  7.9646e-02],
        [ 8.0613e-02, -5.7643e-02,  1.3630e-01, -1.8588e-01],
        [ 6.0324e-02,  1.1692e-01, -1.2352e-01, -5.3381e-01],
        [ 3.1052e-01,  1.4436e-02,  1.2801e-02, -7.9789e-02],
        [-1.0405e-01,  2.1224e-01,  5.2555e-02,  2.5662e-02],
        [-1.3170e-01, -2.5280e-01,  3.9184e-01, -5.5117e-02],
        [-2.1820e-01, -7.3212e-02, -9.7979e-02, -7.8742e-03],
        [-1.1662e-01, -1.6924e-01, -6.8760e-03, -1.2541e-01],
        [ 6.4953e-02,  9.1821e-02, -4.0011e-02, 

In [15]:
expert_weights = expert.policy.parameters_to_vector()
expert_weights.shape

(9155,)

In [16]:
bc_weights = bc_trainer.policy.parameters_to_vector()
bc_weights.shape
# Hmm, why are the shapes different???

(array([-0.00279474,  0.6086224 , -0.30300206, ...,  0.09945919,
         0.16489056,  0.        ], dtype=float32),
 (2531,))

# Now Train BC for Air Hockey

In [1]:
from air_hockey_challenge.framework.evaluate_agent import evaluate, custom_evaluate
from baseline.baseline_agent.baseline_agent import build_agent

import pickle

In [2]:
config = {'render': True, 'quiet': False, 'n_episodes': 5, 'n_cores': 1, 'log_dir': 'logs', 'seed': None, 'generate_score': 'phase-1', 'env_list': ['3dof-hit']}
custom_evaluate(build_agent, **config)

=== CUSTOM EVALUATE ===


  0%|                                                                                                                                                                                                                   | 0/5 [00:00<?, ?it/s]

Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY


 20%|████████████████████████████████████████▌                                                                                                                                                                  | 1/5 [00:04<00:19,  4.80s/it]

Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY


 40%|█████████████████████████████████████████████████████████████████████████████████▏                                                                                                                         | 2/5 [00:08<00:12,  4.09s/it]

Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY


 60%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                                 | 3/5 [00:12<00:08,  4.03s/it]

Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY


 80%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                        | 4/5 [00:16<00:03,  3.99s/it]

Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY


                                                                                                                                                                                                                                              

[(1.0, 4.0, defaultdict(<class 'list'>, {'Episode 0': ['jerk > 10000', 'max computation_time > 0.02s'], 'Episode 1': ['jerk > 10000'], 'Episode 4': ['jerk > 10000', 'max computation_time > 0.02s']}), dict_keys(['joint_pos_constr', 'joint_vel_constr', 'ee_constr']), 648)]
DATA: [(1.0, 4.0, defaultdict(<class 'list'>, {'Episode 0': ['jerk > 10000', 'max computation_time > 0.02s'], 'Episode 1': ['jerk > 10000'], 'Episode 4': ['jerk > 10000', 'max computation_time > 0.02s']}), dict_keys(['joint_pos_constr', 'joint_vel_constr', 'ee_constr']), 648)]
Environment:        3dof-hit
Number of Episodes: 5
Success:            1.0000
Penalty:            4.0
Number of Violations: 
  Jerk              3
  Computation Time  2
  Total             5
-------------------------------------------------





In [3]:
with open("training_data.pkl", "rb") as f:
    training_data = pickle.load(f)
    obs = training_data["obs"]
    actions = training_data["actions"]
    next_obs = training_data["next_obs"] # <- for next run, this will be renamed to next obs
    dones = training_data["dones"]
    info = training_data["info"]

In [6]:
from imitation.data.types import Transitions

transitions = Transitions(obs=obs, acts=actions, infos=info, next_obs=next_obs, dones=dones)

In [18]:
from air_hockey_challenge.framework.air_hockey_challenge_wrapper import AirHockeyChallengeWrapper
import gym

mdp = AirHockeyChallengeWrapper(env="3dof-hit")
obs_space = mdp.info.observation_space
bc_obs_space = gym.spaces.Box(low=obs_space.low, high=obs_space.high, shape=obs_space.shape)
action_space = mdp.info.action_space
bc_action_space = gym.spaces.Box(low=action_space.low, high=action_space.high, shape=action_space.shape)

In [19]:
from imitation.algorithms import bc
import numpy as np

rng = np.random.default_rng()
bc_trainer = bc.BC(
    observation_space=bc_obs_space,
    action_space=bc_action_space,
    demonstrations=transitions,
    rng=rng,
)

In [20]:
bc_trainer.train(n_epochs=1)

0batch [00:00, ?batch/s]


ValueError: Value is not broadcastable with batch_shape+event_shape: torch.Size([32, 2, 3]) vs torch.Size([32, 3]).

In [2]:
config = {'render': True, 'quiet': False, 'n_episodes': 5, 'n_cores': 1, 'log_dir': 'logs', 'seed': None, 'generate_score': 'phase-1', 'env_list': ['3dof-hit']}
evaluate(build_agent, **config)

  0%|                                                                                           | 0/5 [00:00<?, ?it/s]

Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY


 20%|████████████████▌                                                                  | 1/5 [00:03<00:14,  3.74s/it]

Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY


 40%|█████████████████████████████████▏                                                 | 2/5 [00:07<00:11,  3.70s/it]

Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY


 60%|█████████████████████████████████████████████████▊                                 | 3/5 [00:11<00:07,  3.67s/it]

Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY


 80%|██████████████████████████████████████████████████████████████████▍                | 4/5 [00:14<00:03,  3.64s/it]

Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY


                                                                                                                      

Environment:        3dof-hit
Number of Episodes: 5
Success:            1.0000
Penalty:            6.5
Number of Violations: 
  Jerk              4
  Computation Time  5
  Total             9
-------------------------------------------------



