In [33]:
import stable_baselines3 as sb3
import os
import torch as th
# Saved SB3 agent checkpoint to export (a relative path under ./models).
AGENT_CKP_PATH_3 = os.path.join("models", "02_mppt_a2c.tar")
# File name intended for the exported ONNX graph.
export_onnx_file = "SB3_RL_MPPT.onnx"
class OnnxablePolicy(th.nn.Module):
    """Bundle an SB3 actor-critic policy's sub-networks into one exportable module.

    ``forward`` runs the shared extractor. It then feeds the extractor's two latent
    outputs to the action head and the value head respectively, and returns both
    results as a tuple. This gives ``torch.onnx.export`` a plain
    tensor-in / tensors-out graph.

    Args:
        extractor: module returning an ``(action_latent, value_latent)`` pair
            (in this notebook, ``model.policy.mlp_extractor``).
        action_net: module applied to the action latent.
        value_net: module applied to the value latent.
    """

    def __init__(self, extractor, action_net, value_net):
        super().__init__()
        # Assigning nn.Module attributes registers them as sub-modules, so their
        # parameters belong to this wrapper's parameters()/state_dict().
        self.extractor, self.action_net, self.value_net = extractor, action_net, value_net

    def forward(self, observation):
        """Return ``(action_net(action_latent), value_net(value_latent))``.

        NOTE: the observation is used as-is. You may have to process (normalize) it
        in the correct way before using this; see
        `common.preprocessing.preprocess_obs`.

        The policy's ``features_extractor`` is not applied here either. For the
        FlattenExtractor shown for this model on a (batch, features) input, that is
        presumably a no-op, but verify it for other observation shapes.
        """
        latent_pi, latent_vf = self.extractor(observation)
        action_out = self.action_net(latent_pi)
        value_out = self.value_net(latent_vf)
        return action_out, value_out


# Load the trained A2C agent on CPU and wrap its policy sub-networks for ONNX export.
model = sb3.A2C.load(AGENT_CKP_PATH_3, device="cpu")
onnxable_model = OnnxablePolicy(
    model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net
)

# OnnxablePolicy.forward has no data-dependent branching, so only the dummy input's
# shape matters for tracing: (batch=1, *observation_space.shape).
observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size)

# Write to the configured `export_onnx_file` rather than a separate hard-coded name,
# so the constant and the file actually produced agree.
th.onnx.export(
    onnxable_model,
    dummy_input,
    export_onnx_file,
    opset_version=9,
    input_names=["input"],
)


verbose: False, log level: Level.ERROR



In [34]:
# Names of the tensors in the loaded policy's state dict. This is the same dict that
# `model.get_parameters()["policy"]` returns.
model.policy.state_dict().keys()

odict_keys(['mlp_extractor.shared_net.0.weight', 'mlp_extractor.shared_net.0.bias', 'mlp_extractor.policy_net.0.weight', 'mlp_extractor.policy_net.0.bias', 'mlp_extractor.policy_net.2.weight', 'mlp_extractor.policy_net.2.bias', 'mlp_extractor.value_net.0.weight', 'mlp_extractor.value_net.0.bias', 'action_net.weight', 'action_net.bias', 'value_net.weight', 'value_net.bias'])

In [35]:
# Display the loaded policy's module tree (layer types and in/out feature sizes).
model.policy

ActorCriticPolicy(
  (features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (mlp_extractor): MlpExtractor(
    (shared_net): Sequential(
      (0): Linear(in_features=2, out_features=128, bias=True)
      (1): SELU()
    )
    (policy_net): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): SELU()
      (2): Linear(in_features=128, out_features=15, bias=True)
      (3): SELU()
    )
    (value_net): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): SELU()
    )
  )
  (action_net): Linear(in_features=15, out_features=15, bias=True)
  (value_net): Linear(in_features=128, out_features=1, bias=True)
)

In [21]:
# Weight shape of the action head, read from the saved policy parameters.
model.get_parameters()["policy"]["action_net.weight"].size()

torch.Size([15, 15])

In [22]:
# Weight shape of the value head, read from the saved policy parameters.
model.get_parameters()["policy"]["value_net.weight"].size()

torch.Size([1, 128])

In [26]:
# Weight shape of the first Linear layer in the extractor's policy branch.
model.get_parameters()["policy"]["mlp_extractor.policy_net.0.weight"].size()

torch.Size([128, 128])

In [28]:
# Weight shape of the second Linear layer in the extractor's policy branch.
model.get_parameters()["policy"]["mlp_extractor.policy_net.2.weight"].size()

torch.Size([15, 128])

In [36]:
import numpy as np

# Action value table: raw step values divided by a common scale. It has 15 entries,
# presumably one per output of the policy's 15-way action head.
# NOTE(review): the units of these steps and the meaning of the divisor 56 are not
# documented in this notebook. TODO: confirm against the MPPT environment.
ACTION_STEPS = np.array([-25, -15, -10, -5, -3, -2, -1, 0, 1, 2, 3, 5, 10, 15, 25])
ACTION_SCALE = 56
my_actions = ACTION_STEPS / ACTION_SCALE

In [37]:
# Shape of the action table (presumably expected to match the 15 action-head outputs).
np.shape(my_actions)

(15,)

In [38]:

# NOTE(review): this cell re-creates exactly the same array as the earlier
# `my_actions` cell. It is redundant and can likely be deleted.
my_actions = np.divide([-25, -15, -10, -5, -3, -2, -1, 0, 1, 2, 3, 5, 10, 15, 25], 56)

In [39]:
# Re-check the shape of the action table after it was redefined.
np.shape(my_actions)

(15,)