## Basic Example

### An Notes:
- Behavior Cloning (BC)
- Import it with `from imitation.algorithms import bc`
- Inputs to `bc.BC`:
    - `observation_space`: a Gym `Space`
    - `action_space`: a Gym `Space`
    - `rng`: a NumPy random `Generator`
    - `policy` (optional): an `ActorCriticPolicy` from Stable-Baselines3
    - `demonstrations`: an iterable of `Trajectory` objects, a `TransitionsMinimal`, or an iterable of mappings
    
- `ActorCriticPolicy` comes from Stable-Baselines3
- `Space` comes from the OpenAI Gym library

In [1]:
import gym
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy

# A "rollout" is how an episode plays out when a policy acts from the current
# state; the expert trained here is what generates the demonstration rollouts.
env = gym.make("CartPole-v1")

# Expert PPO hyperparameters (seed fixed so training is repeatable).
ppo_kwargs = dict(
    seed=0,
    batch_size=64,
    n_steps=64,
    n_epochs=10,
    learning_rate=3e-4,
    ent_coef=0.0,
)
expert = PPO(MlpPolicy, env, **ppo_kwargs)

# 1_000 timesteps is only a quick run; use ~100_000 to train a proficient expert.
expert.learn(total_timesteps=1_000)

<stable_baselines3.ppo.ppo.PPO at 0x7f243ce11d80>

In [2]:
from stable_baselines3.common.evaluation import evaluate_policy

# evaluate_policy returns (mean_reward, std_reward); report the mean over 10 episodes.
reward, _ = evaluate_policy(model=expert, env=env, n_eval_episodes=10)
print(reward)



185.1


In [3]:
import numpy as np
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from stable_baselines3.common.vec_env import DummyVecEnv

# Seeded so the sampled demonstrations are reproducible (matches the expert's seed=0).
rng = np.random.default_rng(0)

# Roll out the expert until at least 50 complete episodes have been collected.
rollouts = rollout.rollout(
    expert,
    DummyVecEnv([lambda: RolloutInfoWrapper(env)]),
    rollout.make_sample_until(min_timesteps=None, min_episodes=50),
    rng=rng,
)

# Flatten per-episode trajectories into individual transitions with 5 fields:
# - obs:      observation at step t
# - acts:     action taken at step t
# - infos:    per-step info dicts
# - next_obs: observation at step t+1 (an observation, not the next action)
# - dones:    whether the episode ended after this step
transitions = rollout.flatten_trajectories(rollouts)

In [4]:
# Display the flattened demonstrations (fields: obs, acts, infos, next_obs, dones).
transitions

Transitions(obs=array([[ 0.01667576, -0.02619564, -0.03331875, -0.00105004],
       [ 0.01615185,  0.16938792, -0.03333976, -0.30405644],
       [ 0.01953961,  0.36496875, -0.03942088, -0.6070647 ],
       ...,
       [ 0.4787453 ,  0.8429384 ,  0.17751434,  0.70167726],
       [ 0.49560407,  1.0352138 ,  0.19154789,  0.46971196],
       [ 0.51630837,  0.83797556,  0.20094213,  0.8161297 ]],
      dtype=float32), acts=array([1, 1, 0, ..., 1, 0, 1]), infos=array([{}, {}, {}, ..., {}, {}, {}], dtype=object), next_obs=array([[ 0.01615185,  0.16938792, -0.03333976, -0.30405644],
       [ 0.01953961,  0.36496875, -0.03942088, -0.6070647 ],
       [ 0.02683898,  0.17041951, -0.05156218, -0.32705432],
       ...,
       [ 0.49560407,  1.0352138 ,  0.19154789,  0.46971196],
       [ 0.51630837,  0.83797556,  0.20094213,  0.8161297 ],
       [ 0.5330679 ,  1.029863  ,  0.21726473,  0.59277016]],
      dtype=float32), dones=array([False, False, False, ..., False, False,  True]))

In [5]:
# Per-step info dicts; the entries shown in the output are all empty.
transitions.infos

array([{}, {}, {}, ..., {}, {}, {}], dtype=object)

In [6]:
# One info entry per transition, i.e. the total number of collected steps.
transitions.infos.shape

(3064,)

In [7]:
# Expert actions, one per transition (the values shown are 0 and 1).
transitions.acts

array([1, 1, 0, ..., 1, 0, 1])

In [8]:
from imitation.algorithms import bc

# No `policy` is passed, so BC constructs its own default policy for these spaces.
bc_trainer = bc.BC(
    demonstrations=transitions,
    observation_space=env.observation_space,
    action_space=env.action_space,
    rng=rng,
)

In [9]:
# Baseline: evaluate the untrained BC policy over 10 episodes.
reward_before_training, _ = evaluate_policy(
    bc_trainer.policy, env, n_eval_episodes=10
)
print(f"Reward before training: {reward_before_training}")

Reward before training: 104.0


In [10]:
# One pass over the demonstrations, then re-evaluate with the same 10-episode budget.
bc_trainer.train(n_epochs=1)
reward_after_training, _ = evaluate_policy(
    bc_trainer.policy, env, n_eval_episodes=10
)
print(f"Reward after training: {reward_after_training}")

0batch [00:00, ?batch/s]

---------------------------------
| batch_size        | 32        |
| bc/               |           |
|    batch          | 0         |
|    ent_loss       | -0.000693 |
|    entropy        | 0.693     |
|    epoch          | 0         |
|    l2_loss        | 0         |
|    l2_norm        | 72.5      |
|    loss           | 0.692     |
|    neglogp        | 0.692     |
|    prob_true_act  | 0.5       |
|    samples_so_far | 32        |
---------------------------------


95batch [00:00, 110.21batch/s]
95batch [00:00, 110.04batch/s][A


Reward after training: 111.6


In [11]:
# Persist the trained BC policy; load it back later with `bc.reconstruct_policy`.
bc_policy_path = "bc_policy"
bc_trainer.save_policy(bc_policy_path)

In [12]:
# Can extract model weights from policy parameters
for i in expert.policy.parameters():
    print(i)
    print(type(i))

Parameter containing:
tensor([[-1.9931e-01,  1.1871e-01,  2.1681e-01, -2.8612e-02],
        [ 2.5342e-01,  2.2416e-01, -1.6489e-01, -2.8988e-01],
        [-2.1801e-01,  1.8459e-01,  6.3882e-02,  1.8900e-01],
        [ 1.1188e-01,  8.0900e-02,  2.2855e-01,  1.2489e-01],
        [ 1.2619e-01, -1.3515e-01,  4.8238e-02, -7.1517e-02],
        [ 2.3515e-01, -1.3411e-01, -4.0923e-01,  3.0698e-01],
        [-1.4544e-01, -1.9093e-02,  8.2905e-02, -1.7243e-01],
        [-1.1564e-01,  2.7794e-01,  3.3697e-02,  9.1425e-02],
        [ 9.1015e-02, -4.9472e-02,  1.2093e-01, -2.0191e-01],
        [ 7.9515e-02,  1.2860e-01, -1.4759e-01, -5.5095e-01],
        [ 3.1623e-01,  2.1549e-02,  1.8243e-03, -9.6266e-02],
        [-9.5484e-02,  2.2127e-01,  3.7253e-02,  7.4869e-03],
        [-1.3339e-01, -2.5973e-01,  4.0252e-01, -2.8703e-02],
        [-2.2588e-01, -7.9444e-02, -8.7333e-02,  5.8340e-03],
        [-1.2040e-01, -1.7251e-01, -5.0705e-03, -1.2205e-01],
        [ 5.5180e-02,  8.1029e-02, -2.3748e-02, 

In [13]:
# Flatten all of the expert policy's weights into one 1-D vector and show its size.
expert_weights = expert.policy.parameters_to_vector()
expert_weights.shape

(9155,)

In [14]:
# Why this differs from the expert's 9155: both are actor-critic MLPs for a 4-dim
# observation and 2 actions, but with different hidden widths. Assuming two hidden
# layers each for the policy and value networks, plus action and value heads:
#   expert (64, 64): 2*(4*64+64 + 64*64+64) + (64*2+2) + (64+1) = 9155
#   BC     (32, 32): 2*(4*32+32 + 32*32+32) + (32*2+2) + (32+1) = 2531
# So BC's default policy is a smaller network than PPO's default MlpPolicy.
bc_weights = bc_trainer.policy.parameters_to_vector()
bc_weights.shape

(2531,)

## Training BC for Air Hockey

In [9]:
from air_hockey_challenge.framework.evaluate_agent import evaluate, custom_evaluate
from baseline.baseline_agent.baseline_agent import build_agent
import pickle

In [29]:
# Evaluate the baseline agent on 3dof-hit: 1000 episodes, no rendering, 4 cores.
config = dict(
    render=False,
    quiet=True,
    n_episodes=1000,
    n_cores=4,
    log_dir='logs',
    seed=None,
    generate_score='phase-1',
    env_list=['3dof-hit'],
)
custom_evaluate(build_agent, **config)

=== CUSTOM EVALUATE ===
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  RE

In [30]:
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("training_data.pkl", "rb") as f:
    training_data = pickle.load(f)

obs = training_data["obs"]
# Flatten each action to one row per sample (N, -1); presumably so it matches
# the flat action vector BC evaluates - verify against bc_action_space.
raw_actions = training_data["actions"]
actions = raw_actions.reshape(raw_actions.shape[0], -1)
# NOTE(review): an earlier note said this key would be renamed for the next
# data dump; confirm "next_obs" matches the file being loaded.
next_obs = training_data["next_obs"]
dones = training_data["dones"]
info = training_data["info"]

In [31]:
from imitation.data.types import Transitions

# Wrap the loaded arrays in imitation's Transitions container for BC training.
transitions = Transitions(
    obs=obs,
    acts=actions,
    infos=info,
    next_obs=next_obs,
    dones=dones,
)

In [32]:
from air_hockey_challenge.framework.air_hockey_challenge_wrapper import AirHockeyChallengeWrapper
import gym

import numpy as np  # used below; don't rely on the CartPole section's import

mdp = AirHockeyChallengeWrapper(env="3dof-hit")

# Observation space: copy the environment's bounds into a gym Box for BC.
obs_space = mdp.info.observation_space
bc_obs_space = gym.spaces.Box(
    low=obs_space.low, high=obs_space.high, shape=obs_space.shape
)

# Caution: mdp.info.action_space only reflects the torque limits, not the
# joint position/velocity bounds used for BC's actions, so build those below.
action_space = mdp.info.action_space

# Fraction of the joint velocity limit used as the action bound.
VEL_LIMIT_SCALE = 0.95
jnt_range = mdp.base_env.env_info['robot']['robot_model'].jnt_range
vel_range = mdp.base_env.env_info['robot']['joint_vel_limit'].T * VEL_LIMIT_SCALE

# Row 0: joint position bounds, row 1: joint velocity bounds -> (2, n_joints).
low = np.vstack((jnt_range[:, 0], vel_range[:, 0]))
high = np.vstack((jnt_range[:, 1], vel_range[:, 1]))

# Derive the shape from the bounds instead of hardcoding (2, 3).
bc_action_space = gym.spaces.Box(low=low, high=high, shape=low.shape)

In [33]:
from imitation.algorithms import bc
import numpy as np

import torch  # only needed to detect an available GPU

# Fixed seed so the generator BC uses is repeatable across runs.
rng = np.random.default_rng(0)
bc_trainer = bc.BC(
    observation_space=bc_obs_space,
    action_space=bc_action_space,
    demonstrations=transitions,
    rng=rng,
    batch_size=256,
    # Fall back to CPU instead of failing on machines without CUDA.
    device="cuda" if torch.cuda.is_available() else "cpu",
)

In [35]:
# log_rollouts_n_episodes only takes effect when log_rollouts_venv is also
# passed, so it is omitted. For this continuous action space, prob_true_act is
# the mean of exp(log_prob), i.e. of densities, so values > 1 and a negative
# loss (as in the logs) are expected rather than a bug.
bc_trainer.train(n_epochs=200, progress_bar=True, log_interval=10000)

0batch [00:00, ?batch/s]

--------------------------------
| batch_size        | 256      |
| bc/               |          |
|    batch          | 0        |
|    ent_loss       | -0.00851 |
|    entropy        | 8.51     |
|    epoch          | 0        |
|    l2_loss        | 0        |
|    l2_norm        | 88.5     |
|    loss           | 8.12     |
|    neglogp        | 8.13     |
|    prob_true_act  | 0.000318 |
|    samples_so_far | 256      |
--------------------------------


125batch [00:03, 40.87batch/s]
255batch [00:06, 41.56batch/s][A
385batch [00:08, 42.63batch/s][A
516batch [00:12, 40.54batch/s][A
644batch [00:14, 42.58batch/s][A
771batch [00:17, 56.18batch/s][A
899batch [00:19, 60.45batch/s][A
1032batch [00:22, 48.77batch/s][A
1158batch [00:25, 41.39batch/s][A
1288batch [00:28, 43.08batch/s][A
1417batch [00:31, 42.61batch/s][A
1547batch [00:34, 42.57batch/s][A
1677batch [00:36, 52.03batch/s][A
1806batch [00:39, 59.51batch/s][A
1934batch [00:41, 45.04batch/s][A
2064batch [00:44, 45.59batch/s][A
2189batch [00:47, 40.86batch/s][A
2322batch [00:50, 42.72batch/s][A
2447batch [00:53, 46.61batch/s][A
2578batch [00:55, 41.53batch/s][A
2708batch [00:58, 44.00batch/s][A
2833batch [01:01, 50.87batch/s][A
2964batch [01:04, 42.51batch/s][A
3095batch [01:07, 45.81batch/s][A
3225batch [01:10, 41.41batch/s][A
3351batch [01:13, 44.51batch/s][A
3483batch [01:16, 41.92batch/s][A
3612batch [01:18, 49.46batch/s][A
3738batch [01:21, 43.65batch/s]

--------------------------------
| batch_size        | 256      |
| bc/               |          |
|    batch          | 10000    |
|    ent_loss       | 0.0181   |
|    entropy        | -18.1    |
|    epoch          | 77       |
|    l2_loss        | 0        |
|    l2_norm        | 182      |
|    loss           | -18.2    |
|    neglogp        | -18.2    |
|    prob_true_act  | 4.21e+08 |
|    samples_so_far | 2560256  |
--------------------------------


10061batch [03:43, 45.04batch/s]
10191batch [03:46, 41.39batch/s][A
10316batch [03:49, 31.76batch/s][A
10446batch [03:52, 48.02batch/s][A
10486batch [03:53, 39.94batch/s][A

Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from: 

10576batch [03:55, 45.09batch/s]
10707batch [03:58, 44.81batch/s][A
10832batch [04:01, 41.03batch/s][A
10962batch [04:04, 41.96batch/s][A
10967batch [04:04, 40.13batch/s][A

Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from: 

11089batch [04:07, 51.47batch/s]
11222batch [04:09, 57.32batch/s][A
11352batch [04:12, 58.84batch/s][A
11478batch [04:14, 58.83batch/s][A
11610batch [04:17, 41.84batch/s][A
11736batch [04:20, 41.59batch/s][A
11866batch [04:23, 44.85batch/s][A
11996batch [04:26, 40.65batch/s][A
12126batch [04:29, 46.10batch/s][A
12253batch [04:31, 43.28batch/s][A
12383batch [04:34, 41.90batch/s][A
12512batch [04:37, 47.33batch/s][A
12638batch [04:40, 47.38batch/s][A
12768batch [04:43, 44.16batch/s][A
12884batch [04:45, 42.49batch/s][A

Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from: 

12899batch [04:46, 42.40batch/s]
13025batch [04:49, 47.46batch/s][A
13154batch [04:51, 58.17batch/s][A
13283batch [04:54, 54.26batch/s][A
13347batch [04:56, 43.96batch/s][A

Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from:  READY  to:  SMASH
Agent:  1 Switch tactic from:  SMASH  to:  READY
Agent:  1 Switch tactic from: 

13412batch [04:57, 40.89batch/s]
13545batch [05:00, 50.26batch/s][A
13674batch [05:03, 49.81batch/s][A
13800batch [05:05, 45.06batch/s][A
13932batch [05:08, 49.29batch/s][A
14061batch [05:11, 49.27batch/s][A
14186batch [05:13, 43.37batch/s][A
14317batch [05:16, 41.37batch/s][A
14447batch [05:19, 41.90batch/s][A
14576batch [05:22, 45.57batch/s][A
14703batch [05:25, 58.88batch/s][A
14834batch [05:27, 60.61batch/s][A
14960batch [05:30, 47.32batch/s][A
15091batch [05:32, 42.30batch/s][A
15222batch [05:35, 46.05batch/s][A
15348batch [05:38, 51.47batch/s][A
15479batch [05:40, 60.02batch/s][A
15604batch [05:43, 60.05batch/s][A
15734batch [05:45, 49.35batch/s][A
15865batch [05:48, 46.20batch/s][A
15995batch [05:51, 40.21batch/s][A
16124batch [05:53, 44.32batch/s][A
16254batch [05:57, 41.25batch/s][A
16379batch [06:00, 40.20batch/s][A
16508batch [06:03, 42.70batch/s][A
16638batch [06:06, 40.36batch/s][A
16767batch [06:09, 40.54batch/s][A
16899batch [06:12, 44.03batch/s

--------------------------------
| batch_size        | 256      |
| bc/               |          |
|    batch          | 20000    |
|    ent_loss       | 0.0189   |
|    entropy        | -18.9    |
|    epoch          | 155      |
|    l2_loss        | 0        |
|    l2_norm        | 192      |
|    loss           | -19.1    |
|    neglogp        | -19.1    |
|    prob_true_act  | 1.3e+09  |
|    samples_so_far | 5120256  |
--------------------------------


20122batch [07:25, 43.18batch/s]
20253batch [07:28, 46.14batch/s][A
20378batch [07:31, 41.15batch/s][A
20505batch [07:33, 57.23batch/s][A
20640batch [07:36, 58.34batch/s][A
20764batch [07:38, 52.41batch/s][A
20893batch [07:41, 51.79batch/s][A
21026batch [07:44, 48.28batch/s][A
21154batch [07:46, 40.93batch/s][A
21285batch [07:49, 40.92batch/s][A
21410batch [07:52, 45.23batch/s][A
21540batch [07:55, 43.73batch/s][A
21671batch [07:58, 41.09batch/s][A
21797batch [08:01, 45.15batch/s][A
21930batch [08:04, 45.33batch/s][A
22055batch [08:07, 43.21batch/s][A
22185batch [08:10, 44.13batch/s][A
22315batch [08:13, 40.78batch/s][A
22445batch [08:16, 40.63batch/s][A
22575batch [08:19, 40.36batch/s][A
22700batch [08:22, 41.29batch/s][A
22832batch [08:25, 50.77batch/s][A
22958batch [08:28, 44.04batch/s][A
23088batch [08:31, 45.13batch/s][A
23216batch [08:34, 44.44batch/s][A
23347batch [08:37, 43.12batch/s][A
23477batch [08:40, 40.52batch/s][A
23607batch [08:43, 44.57batch/s

In [51]:
import torch  # required for torch.save; otherwise NameError on a fresh kernel

# Save only the policy weights; to restore, rebuild a policy with the same
# spaces and call load_state_dict on it.
torch.save(bc_trainer.policy.state_dict(), 'intro_il_model.pt')