# Test Behavior Cloning Network inside of Air Hockey Challenge

Description:
- A Behavior Cloning network was trained using a policy class from Stable Baselines3 (which inherits from nn.Module)
- The goal is to use that network as a policy inside MushroomRL (i.e. port the network over)
- Load the trained weights
- Evaluate the network

### Transfer weights

In [9]:
import torch
import numpy as np
from imitation.policies.base import FeedForward32Policy

In [10]:
# Build the gym observation/action spaces needed to construct the policy network
from air_hockey_challenge.framework.air_hockey_challenge_wrapper import AirHockeyChallengeWrapper
import gym

mdp = AirHockeyChallengeWrapper(env="3dof-hit")
obs_space = mdp.info.observation_space
bc_obs_space = gym.spaces.Box(low=obs_space.low, high=obs_space.high, shape=obs_space.shape)

# Careful: the action space reported by the mdp is misleading here -- it only
# shows the torque limits, so the position/velocity space is assembled by hand.
action_space = mdp.info.action_space

robot_info = mdp.base_env.env_info['robot']
jnt_range = robot_info['robot_model'].jnt_range
vel_range = robot_info['joint_vel_limit'].T * 0.95

# Row 0 holds the joint-position bounds, row 1 the (95%-scaled) joint-velocity bounds.
low = np.vstack((jnt_range[:, 0], vel_range[:, 0]))
high = np.vstack((jnt_range[:, 1], vel_range[:, 1]))

bc_action_space = gym.spaces.Box(low=low, high=high, shape=(2, 3))

In [12]:
# File holding the trained behavior-cloning weights
PATH = 'intro_il_model.pt'

# The constructor requires an lr_schedule; this one returns the largest float32
# for any input. NOTE(review): presumably a placeholder because the policy is
# only used for inference in this notebook -- confirm.
model = FeedForward32Policy(
    observation_space=bc_obs_space,
    action_space=bc_action_space,
    lr_schedule=lambda _: torch.finfo(torch.float32).max,
)

In [13]:
# Load the trained weights into the freshly constructed policy.
# map_location='cpu' remaps every stored tensor to the CPU while loading, so a
# checkpoint saved from a GPU still loads on a machine without one (the agent
# below runs on 'cpu').
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load(PATH, map_location='cpu'))

<All keys matched successfully>

### Load network into an agent for testing
Notes:
- FeedForward32Policy is an actor-critic policy.
- This means we only use the output of the action network and ignore the critic/value network.
- The action returned by the forward pass is all we need to control the robot.
- https://stable-baselines.readthedocs.io/en/master/_modules/stable_baselines/common/policies.html#ActorCriticPolicy
- The link above breaks down the model structure.

In [34]:
model

FeedForward32Policy(
  (features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (pi_features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (vf_features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (mlp_extractor): MlpExtractor(
    (policy_net): Sequential(
      (0): Linear(in_features=12, out_features=32, bias=True)
      (1): Tanh()
      (2): Linear(in_features=32, out_features=32, bias=True)
      (3): Tanh()
    )
    (value_net): Sequential(
      (0): Linear(in_features=12, out_features=32, bias=True)
      (1): Tanh()
      (2): Linear(in_features=32, out_features=32, bias=True)
      (3): Tanh()
    )
  )
  (action_net): Linear(in_features=32, out_features=6, bias=True)
  (value_net): Linear(in_features=32, out_features=1, bias=True)
)

In [42]:
# Smoke test: push one all-zero observation (batch of 1, 12 features) through
# the policy. forward() returns three values; only the first (the action) is kept.
dummy_obs = torch.zeros(1, 12)
action, _, _ = model.forward(dummy_obs)

In [44]:
action

tensor([[[ 0.4628, -0.4467, -0.6038],
         [-0.9175,  0.6231,  0.8011]]], grad_fn=<ReshapeAliasBackward0>)

In [95]:
from air_hockey_challenge.framework import AgentBase
from air_hockey_challenge.framework.challenge_core import ChallengeCore, CustomChallengeCore

In [108]:
class BehaviorCloneAgent(AgentBase):
    """Agent that controls the robot with a behavior-cloned policy network.

    On the first step after construction or ``reset()`` the agent commands the
    current joint positions with zero velocity. On every later step the
    observation is fed through ``model`` and its action output is returned.

    Args:
        env_info (dict): The environment information, forwarded to ``AgentBase``.
        model: Policy network. Calling it with a batch of observations must
            return a 3-tuple whose first element is the action tensor, and it
            must accept a ``deterministic`` keyword.
        device: Torch device used for inference (e.g. ``'cpu'``). When not
            ``None`` the model is moved to this device.
        deterministic (bool): If True (default) the policy is asked for its
            deterministic action instead of a random sample, which makes
            evaluation runs repeatable.
        kwargs (any): Additional settings forwarded to ``AgentBase``.
    """

    def __init__(self, env_info, model, device, *, deterministic=True, **kwargs):
        super().__init__(env_info, **kwargs)
        self.new_start = True
        self.hold_position = None

        self.device = device
        # Keep model and inputs on the same device; None means "leave as is".
        self.model = model if device is None else model.to(device)
        self.deterministic = deterministic

    def reset(self):
        """Forget the held position so the next step holds the arm in place again."""
        self.new_start = True
        self.hold_position = None

    def draw_action(self, observation):
        """Return the action for ``observation``.

        Args:
            observation: Flat observation vector for a single environment step.

        Returns:
            numpy array with two rows: row 0 is the joint-position command,
            row 1 the joint-velocity command.
        """
        if self.new_start:
            # First step of an episode: stay at the current joint positions.
            self.new_start = False
            self.hold_position = self.get_joint_pos(observation)
            velocity = np.zeros_like(self.hold_position)
            return np.vstack([self.hold_position, velocity])

        # Add a batch dimension of 1; -1 keeps this independent of the
        # observation length.
        obs_tensor = torch.as_tensor(
            observation, dtype=torch.float32, device=self.device
        ).reshape(1, -1)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            action, _, _ = self.model(obs_tensor, deterministic=self.deterministic)

        # Drop the batch dimension and hand back a CPU numpy array regardless
        # of the inference device.
        return action.reshape(2, -1).cpu().numpy()

In [109]:
def build_agent(env_info, **kwargs):
    """
    Factory that creates the agent controlling the environment.

    The returned agent is a BehaviorCloneAgent; every keyword argument is
    passed straight through to its constructor.

    Args:
        env_info (dict): The environment information
        kwargs (any): Additional settings for the agent (e.g. from agent_config.yml)
    Returns:
         (AgentBase) An instance of the Agent
    """
    agent = BehaviorCloneAgent(env_info, **kwargs)
    return agent

In [110]:
# Recreate the environment with position-velocity actions, then wrap the
# loaded policy in an agent that runs on the CPU.
mdp = AirHockeyChallengeWrapper(
    env="3dof-hit",
    action_type="position-velocity",
    interpolation_order=3,
    debug=True,
)
env_info = mdp.base_env.env_info
agent = BehaviorCloneAgent(env_info, model, 'cpu')

In [111]:
from air_hockey_challenge.framework.evaluate_agent import evaluate

In [112]:
# Evaluate the behavior-cloned agent on the 3dof-hit task with rendering on.
# NOTE(review): the positional 5 and 1 are presumably the number of episodes
# and the number of cores -- confirm against evaluate()'s signature.
# model/device are extra keyword arguments, presumably forwarded to build_agent.
evaluate(
    build_agent,
    'some_folder/',
    ['3dof-hit'],
    5,
    1,
    model=model,
    device='cpu',
    render=True,
)

KeyboardInterrupt: 