## Basic Example

### An Notes:
- Behavior Cloning
- import behavior cloning from imitation.algorithms
- behavior cloning inputs
    - observation space (Space data structure)
    - action space (Space data structure)
    - rng (Generator Data Structure)
    - policy: ActorCriticPolicy data structure (from stablebaselines3 )
    - demonstrations: (Trajectory iterable, TransitionsMinimal, or Iterable Mapping)
    
- ActorCriticPolicy => stablebaselines3
- Space => OpenAI Gym library

In [1]:
import gym
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy

env = gym.make("CartPole-v1")
expert = PPO(
    policy=MlpPolicy,
    env=env,
    seed=0,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0003,
    n_epochs=10,
    n_steps=64,
)
expert.learn(1000)  # Note: set to 100000 to train a proficient expert


#rollout definition:
# how something plays out from current state

<stable_baselines3.ppo.ppo.PPO at 0x7ff6f5521ea0>

In [2]:
from stable_baselines3.common.evaluation import evaluate_policy

reward, _ = evaluate_policy(expert, env, 10)
print(reward)



185.1


In [3]:
import numpy as np
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from stable_baselines3.common.vec_env import DummyVecEnv

rng = np.random.default_rng()
rollouts = rollout.rollout(
    expert,
    DummyVecEnv([lambda: RolloutInfoWrapper(env)]),
    rollout.make_sample_until(min_timesteps=None, min_episodes=50),
    rng=rng,
)
transitions = rollout.flatten_trajectories(rollouts)
# transitions have 3 attributes
# - obs
# - acts aka actions
# - next_obs aka next_actions
# - dones
# - info

In [4]:
transitions

Transitions(obs=array([[ 0.03072335, -0.00977635, -0.04372443, -0.03105337],
       [ 0.03052782, -0.20424488, -0.04434549,  0.24751975],
       [ 0.02644292, -0.00851858, -0.0393951 , -0.05881438],
       ...,
       [ 0.23966257,  0.9135243 ,  0.18704192,  0.02855892],
       [ 0.25793305,  0.71628124,  0.18761311,  0.3739335 ],
       [ 0.2722587 ,  0.5190588 ,  0.19509178,  0.7194112 ]],
      dtype=float32), acts=array([0, 1, 1, ..., 0, 0, 1]), infos=array([{}, {}, {}, ..., {}, {}, {}], dtype=object), next_obs=array([[ 0.03052782, -0.20424488, -0.04434549,  0.24751975],
       [ 0.02644292, -0.00851858, -0.0393951 , -0.05881438],
       [ 0.02627255,  0.18714543, -0.04057138, -0.36366186],
       ...,
       [ 0.25793305,  0.71628124,  0.18761311,  0.3739335 ],
       [ 0.2722587 ,  0.5190588 ,  0.19509178,  0.7194112 ],
       [ 0.28263986,  0.71102333,  0.20948   ,  0.4939206 ]],
      dtype=float32), dones=array([False, False, False, ..., False, False,  True]))

In [5]:
transitions.infos

array([{}, {}, {}, ..., {}, {}, {}], dtype=object)

In [6]:
transitions.infos.shape

(3064,)

In [7]:
transitions.acts

array([0, 1, 1, ..., 0, 0, 1])

In [8]:
from imitation.algorithms import bc

bc_trainer = bc.BC(
    observation_space=env.observation_space,
    action_space=env.action_space,
    demonstrations=transitions,
    rng=rng,
)

In [9]:
reward_before_training, _ = evaluate_policy(bc_trainer.policy, env, 10)
print(f"Reward before training: {reward_before_training}")

Reward before training: 104.0


In [10]:
bc_trainer.train(n_epochs=1)
reward_after_training, _ = evaluate_policy(bc_trainer.policy, env, 10)
print(f"Reward after training: {reward_after_training}")

0batch [00:00, ?batch/s]

---------------------------------
| batch_size        | 32        |
| bc/               |           |
|    batch          | 0         |
|    ent_loss       | -0.000693 |
|    entropy        | 0.693     |
|    epoch          | 0         |
|    l2_loss        | 0         |
|    l2_norm        | 72.5      |
|    loss           | 0.692     |
|    neglogp        | 0.693     |
|    prob_true_act  | 0.5       |
|    samples_so_far | 32        |
---------------------------------


95batch [00:01, 104.75batch/s]
95batch [00:01, 94.20batch/s] [A


Reward after training: 121.8


In [11]:
bc_trainer.save_policy("bc_policy")
# Supposed to use reconstruct policy to load back in

In [12]:
# Can extract model weights from policy parameters
for i in expert.policy.parameters():
    print(i)
    print(type(i))

Parameter containing:
tensor([[-1.9931e-01,  1.1871e-01,  2.1681e-01, -2.8612e-02],
        [ 2.5342e-01,  2.2416e-01, -1.6489e-01, -2.8988e-01],
        [-2.1801e-01,  1.8459e-01,  6.3882e-02,  1.8900e-01],
        [ 1.1188e-01,  8.0900e-02,  2.2855e-01,  1.2489e-01],
        [ 1.2619e-01, -1.3515e-01,  4.8238e-02, -7.1517e-02],
        [ 2.3515e-01, -1.3411e-01, -4.0923e-01,  3.0698e-01],
        [-1.4544e-01, -1.9093e-02,  8.2905e-02, -1.7243e-01],
        [-1.1564e-01,  2.7794e-01,  3.3697e-02,  9.1425e-02],
        [ 9.1015e-02, -4.9472e-02,  1.2093e-01, -2.0191e-01],
        [ 7.9515e-02,  1.2860e-01, -1.4759e-01, -5.5095e-01],
        [ 3.1623e-01,  2.1549e-02,  1.8243e-03, -9.6266e-02],
        [-9.5484e-02,  2.2127e-01,  3.7253e-02,  7.4869e-03],
        [-1.3339e-01, -2.5973e-01,  4.0252e-01, -2.8703e-02],
        [-2.2588e-01, -7.9444e-02, -8.7333e-02,  5.8340e-03],
        [-1.2040e-01, -1.7251e-01, -5.0705e-03, -1.2205e-01],
        [ 5.5180e-02,  8.1029e-02, -2.3748e-02, 

In [13]:
expert_weights = expert.policy.parameters_to_vector()
expert_weights.shape

(9155,)

In [14]:
bc_weights = bc_trainer.policy.parameters_to_vector()
bc_weights.shape
# Hmm, why are the shapes different???

(2531,)

# Now Train BC for Air Hockey

In [15]:
from air_hockey_challenge.framework.evaluate_agent import evaluate, custom_evaluate
from baseline.baseline_agent.baseline_agent import build_agent

import pickle

In [16]:
config = {'render': True, 'quiet': False, 'n_episodes': 5, 'n_cores': 1, 'log_dir': 'logs', 'seed': None, 'generate_score': 'phase-1', 'env_list': ['3dof-hit']}
custom_evaluate(build_agent, **config)

=== CUSTOM EVALUATE ===


  0%|                                                     | 0/5 [00:00<?, ?it/s]

Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY


 20%|█████████                                    | 1/5 [00:06<00:24,  6.19s/it]

Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY


 40%|██████████████████                           | 2/5 [00:12<00:19,  6.35s/it]

Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY


 60%|███████████████████████████                  | 3/5 [00:19<00:13,  6.51s/it]

Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY


 80%|████████████████████████████████████         | 4/5 [00:25<00:06,  6.29s/it]

Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY


                                                                                

DATA: [(1.0, 12.0, defaultdict(<class 'list'>, {'Episode 0': ['jerk > 10000', 'max computation_time > 0.02s'], 'Episode 1': ['jerk > 10000', 'max computation_time > 0.02s'], 'Episode 2': ['jerk > 10000', 'mean computation_time > 0.02s'], 'Episode 3': ['jerk > 10000', 'mean computation_time > 0.02s'], 'Episode 4': ['jerk > 10000', 'mean computation_time > 0.02s']}), dict_keys(['joint_pos_constr', 'joint_vel_constr', 'ee_constr']), 574)]
Environment:        3dof-hit
Number of Episodes: 5
Success:            1.0000
Penalty:            12.0
Number of Violations: 
  Jerk              5
  Computation Time  5
  Total             10
-------------------------------------------------





In [79]:
with open("training_data.pkl", "rb") as f:
    training_data = pickle.load(f)
    obs = training_data["obs"]
    actions = training_data["actions"].reshape(training_data["actions"].shape[0],-1)
    next_obs = training_data["next_obs"] # <- for next run, this will be renamed to next obs
    dones = training_data["dones"]
    info = training_data["info"]

In [80]:
from imitation.data.types import Transitions

transitions = Transitions(obs=obs, acts=actions, infos=info, next_obs=next_obs, dones=dones)

In [81]:
from air_hockey_challenge.framework.air_hockey_challenge_wrapper import AirHockeyChallengeWrapper
import gym

mdp = AirHockeyChallengeWrapper(env="3dof-hit")
obs_space = mdp.info.observation_space
bc_obs_space = gym.spaces.Box(low=obs_space.low, high=obs_space.high, shape=obs_space.shape)


#this deceives u.... action_space from mdp only shows the torque limits
action_space = mdp.info.action_space

jnt_range = mdp.base_env.env_info['robot']['robot_model'].jnt_range
vel_range = mdp.base_env.env_info['robot']['joint_vel_limit'].T * 0.95

low = np.hstack((jnt_range[:,0][:,np.newaxis], vel_range[:,0][:,np.newaxis])).T
high = np.hstack((jnt_range[:,1][:,np.newaxis], vel_range[:,1][:,np.newaxis])).T

bc_action_space = gym.spaces.Box(low=low, high=high, shape=(2,3))

In [82]:
from imitation.algorithms import bc
import numpy as np

rng = np.random.default_rng()
bc_trainer = bc.BC(
    observation_space=bc_obs_space,
    action_space=bc_action_space,
    demonstrations=transitions,
    rng=rng,
)

In [None]:
bc_trainer.train(n_epochs=100)

0batch [00:00, ?batch/s]

--------------------------------
| batch_size        | 32       |
| bc/               |          |
|    batch          | 0        |
|    ent_loss       | -0.00847 |
|    entropy        | 8.47     |
|    epoch          | 0        |
|    l2_loss        | 0        |
|    l2_norm        | 89.4     |
|    loss           | 6.97     |
|    neglogp        | 6.98     |
|    prob_true_act  | 0.000998 |
|    samples_so_far | 32       |
--------------------------------


12batch [00:00, 54.70batch/s]
31batch [00:00, 51.89batch/s][A
49batch [00:00, 51.98batch/s][A
67batch [00:01, 50.81batch/s][A
85batch [00:01, 52.82batch/s][A
97batch [00:01, 50.97batch/s][A
114batch [00:02, 51.50batch/s][A
133batch [00:02, 55.04batch/s][A
151batch [00:02, 51.72batch/s][A
166batch [00:03, 59.80batch/s][A
177batch [00:03, 71.75batch/s][A
200batch [00:03, 88.69batch/s][A
218batch [00:03, 72.99batch/s][A
235batch [00:04, 70.14batch/s][A
245batch [00:04, 77.42batch/s][A
266batch [00:04, 87.01batch/s][A
283batch [00:04, 66.03batch/s][A
301batch [00:04, 71.80batch/s][A
312batch [00:04, 80.88batch/s][A
334batch [00:05, 79.82batch/s][A
351batch [00:05, 64.20batch/s][A
372batch [00:05, 79.41batch/s][A
383batch [00:05, 86.40batch/s][A
405batch [00:06, 86.58batch/s][A
424batch [00:06, 86.36batch/s][A
441batch [00:06, 64.65batch/s][A
459batch [00:06, 73.79batch/s][A
468batch [00:07, 77.34batch/s][A
485batch [00:07, 74.01batch/s][A
494batch [00:07, 74.82ba

--------------------------------
| batch_size        | 32       |
| bc/               |          |
|    batch          | 500      |
|    ent_loss       | -0.00526 |
|    entropy        | 5.26     |
|    epoch          | 29       |
|    l2_loss        | 0        |
|    l2_norm        | 96.2     |
|    loss           | 2.31     |
|    neglogp        | 2.32     |
|    prob_true_act  | 0.0991   |
|    samples_so_far | 16032    |
--------------------------------


510batch [00:07, 66.35batch/s]
525batch [00:07, 69.38batch/s][A
535batch [00:07, 77.61batch/s][A
556batch [00:08, 87.43batch/s][A
575batch [00:08, 81.85batch/s][A
592batch [00:08, 73.66batch/s][A
607batch [00:08, 60.98batch/s][A
627batch [00:09, 55.37batch/s][A
640batch [00:09, 57.22batch/s][A
659batch [00:09, 53.52batch/s][A
679batch [00:10, 58.52batch/s][A
693batch [00:10, 62.12batch/s][A
714batch [00:10, 62.39batch/s][A
730batch [00:11, 69.07batch/s][A
737batch [00:11, 67.02batch/s][A
758batch [00:11, 80.90batch/s][A
775batch [00:11, 70.93batch/s][A
793batch [00:11, 77.95batch/s][A
809batch [00:12, 74.16batch/s][A
831batch [00:12, 62.60batch/s][A
845batch [00:12, 58.42batch/s][A
867batch [00:13, 66.90batch/s][A
881batch [00:13, 64.92batch/s][A
898batch [00:13, 72.51batch/s][A
914batch [00:13, 72.87batch/s][A
931batch [00:13, 78.66batch/s][A
947batch [00:14, 69.00batch/s][A
963batch [00:14, 69.34batch/s][A
984batch [00:14, 60.31batch/s][A
997batch [00:14, 

--------------------------------
| batch_size        | 32       |
| bc/               |          |
|    batch          | 1000     |
|    ent_loss       | -0.00222 |
|    entropy        | 2.22     |
|    epoch          | 58       |
|    l2_loss        | 0        |
|    l2_norm        | 103      |
|    loss           | -0.757   |
|    neglogp        | -0.755   |
|    prob_true_act  | 2.13     |
|    samples_so_far | 32032    |
--------------------------------


1003batch [00:15, 54.72batch/s]
1016batch [00:15, 57.33batch/s][A
1034batch [00:15, 53.29batch/s][A
1052batch [00:16, 51.36batch/s][A
1066batch [00:16, 59.07batch/s][A
1087batch [00:16, 63.44batch/s][A
1101batch [00:16, 61.83batch/s][A
1115batch [00:17, 64.00batch/s][A
1133batch [00:17, 73.85batch/s][A
1145batch [00:17, 86.55batch/s][A
1171batch [00:17, 105.23batch/s][A
1182batch [00:17, 94.32batch/s] [A
1205batch [00:17, 95.67batch/s][A
1216batch [00:18, 98.73batch/s][A
1237batch [00:18, 85.04batch/s][A
1255batch [00:18, 81.32batch/s][A
1272batch [00:18, 75.45batch/s][A
1289batch [00:19, 70.95batch/s][A
1308batch [00:19, 78.46batch/s][A
1325batch [00:19, 71.09batch/s][A
1340batch [00:19, 60.70batch/s][A
1355batch [00:20, 59.32batch/s][A
1370batch [00:20, 64.10batch/s][A
1390batch [00:20, 80.52batch/s][A
1410batch [00:20, 73.50batch/s][A
1427batch [00:21, 77.25batch/s][A
1443batch [00:21, 72.04batch/s][A
1460batch [00:21, 75.22batch/s][A
1475batch [00:21, 61.1

--------------------------------
| batch_size        | 32       |
| bc/               |          |
|    batch          | 1500     |
|    ent_loss       | 0.000779 |
|    entropy        | -0.779   |
|    epoch          | 88       |
|    l2_loss        | 0        |
|    l2_norm        | 109      |
|    loss           | -3.58    |
|    neglogp        | -3.58    |
|    prob_true_act  | 37.6     |
|    samples_so_far | 48032    |
--------------------------------


1510batch [00:22, 79.64batch/s]
1520batch [00:22, 84.90batch/s][A


In [84]:
config = {'render': True, 'quiet': False, 'n_episodes': 5, 'n_cores': 1, 'log_dir': 'logs', 'seed': None, 'generate_score': 'phase-1', 'env_list': ['3dof-hit']}
evaluate(build_agent, **config)

  0%|                                                     | 0/5 [00:00<?, ?it/s]

Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY


 20%|█████████                                    | 1/5 [00:05<00:21,  5.32s/it]

Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY


 40%|██████████████████                           | 2/5 [00:10<00:15,  5.16s/it]

Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY


 60%|███████████████████████████                  | 3/5 [00:17<00:12,  6.09s/it]

Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY


 80%|████████████████████████████████████         | 4/5 [00:23<00:06,  6.22s/it]

Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY


                                                                                

Environment:        3dof-hit
Number of Episodes: 5
Success:            1.0000
Penalty:            9.0
Number of Violations: 
  Jerk              5
  Computation Time  5
  Total             10
-------------------------------------------------

