In [12]:
%pip install huggingface_sb3 stable_baselines3 gymnasium plotly

Note: you may need to restart the kernel to use updated packages.


In [2]:
from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO

# Download the pretrained CartPole agent archive from the Hugging Face Hub.
checkpoint = load_from_hub(repo_id="sb3/ppo-CartPole-v1", filename="ppo-CartPole-v1.zip")

# Constant stand-ins passed to PPO.load in place of the saved schedule objects
# (the checkpoint was saved under a different Python / SB3 version than this
# kernel — see the system info printed by the load call).
custom_objects = dict(
    learning_rate=0.0,
    lr_schedule=lambda _: 0.0,
    clip_range=lambda _: 0.0,
)

model = PPO.load(
    checkpoint,
    custom_objects=custom_objects,
    print_system_info=True,
)

  from .autonotebook import tqdm as notebook_tqdm


== CURRENT SYSTEM INFO ==
- OS: macOS-14.4.1-arm64-arm-64bit Darwin Kernel Version 23.4.0: Fri Mar 15 00:12:41 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T8103
- Python: 3.11.9
- Stable-Baselines3: 2.4.0
- PyTorch: 2.5.1
- GPU Enabled: False
- Numpy: 1.26.4
- Cloudpickle: 3.1.0
- Gymnasium: 1.0.0
- OpenAI Gym: 0.26.2

== SAVED MODEL SYSTEM INFO ==
- OS: Linux-5.15.0-97-generic-x86_64-with-glibc2.35 # 107-Ubuntu SMP Wed Feb 7 13:26:48 UTC 2024
- Python: 3.10.9
- Stable-Baselines3: 2.3.0a3
- PyTorch: 2.2.0+cpu
- GPU Enabled: False
- Numpy: 1.24.4
- Cloudpickle: 3.0.0
- Gymnasium: 0.29.1
- OpenAI Gym: 0.26.2



In [3]:
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
import gymnasium as gym

# Evaluate the loaded agent over 10 deterministic episodes on a
# Monitor-wrapped CartPole environment and report mean / std episode reward.
eval_env = Monitor(env=gym.make("CartPole-v1"))
mean_reward, std_reward = evaluate_policy(
    model=model,
    env=eval_env,
    n_eval_episodes=10,
    deterministic=True,
)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")

mean_reward=500.00 +/- 0.0


In [6]:
def run(env_id="CartPole-v1", render_mode="human", deterministic=True):
    """Play a single episode with the notebook-level ``model`` and print its return.

    Args:
        env_id: Environment id handed to ``gym.make``.
        render_mode: Render mode handed to ``gym.make``.
        deterministic: Forwarded to ``model.predict``.

    The environment is always closed, even if prediction or stepping raises.
    """
    env = gym.make(env_id, render_mode=render_mode)
    total_reward = 0.0
    try:
        obs, _info = env.reset()
        terminated = truncated = False
        # The episode ends when the env reports either termination or truncation.
        while not (terminated or truncated):
            action, _state = model.predict(obs, deterministic=deterministic)
            obs, reward, terminated, truncated, _info = env.step(action)
            total_reward += reward
    finally:
        # Release the environment (and its renderer) on every exit path.
        env.close()

    print(f"Total reward: {total_reward:.2f}")
run()

Total reward: 500.00


In [15]:
# Display the policy's module tree (the repr lists every sub-network and layer).
model.policy

ActorCriticPolicy(
  (features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (pi_features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (vf_features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (mlp_extractor): MlpExtractor(
    (policy_net): Sequential(
      (0): Linear(in_features=4, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
    )
    (value_net): Sequential(
      (0): Linear(in_features=4, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
    )
  )
  (action_net): Linear(in_features=64, out_features=2, bias=True)
  (value_net): Linear(in_features=64, out_features=1, bias=True)
)

In [17]:
# The two Linear layers of the actor MLP sit at indices 0 and 2 of policy_net
# (indices 1 and 3 are the Tanh activations).
_pi_layers = model.policy.mlp_extractor.policy_net
_first_linear, _second_linear = _pi_layers[0], _pi_layers[2]

# Weights are transposed and everything is detached from the autograd graph.
hidden_layer_1_wts = _first_linear.weight.T.detach()
hidden_layer_1_bs = _first_linear.bias.detach()
hidden_layer_2_wts = _second_linear.weight.T.detach()
hidden_layer_2_bs = _second_linear.bias.detach()

In [18]:
import plotly.express as px

# Heatmaps of the transposed weight matrices: rows are a layer's inputs,
# columns are its outputs. The diverging RdBu scale is centred on 0 so that
# colour encodes the sign of each weight.
for _title, _weights, _row_label, _col_label in [
    ("Policy hidden layer 1 weights", hidden_layer_1_wts, "input feature", "hidden unit (layer 1)"),
    ("Policy hidden layer 2 weights", hidden_layer_2_wts, "hidden unit (layer 1)", "hidden unit (layer 2)"),
]:
    px.imshow(
        _weights,
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0,
        title=_title,
        labels={"x": _col_label, "y": _row_label, "color": "weight"},
        aspect="auto",  # keep the 4-row first-layer matrix readable
    ).show()