### Export to ONNX

In [12]:
from stable_baselines3 import PPO
import torch
from torch.nn import Parameter
import gym

class OnnxablePolicy(torch.nn.Module):
    """Wrap a policy's action path in a module suitable for ONNX tracing.

    The wrapper reuses the policy's ``mlp_extractor`` and ``action_net`` and
    exposes a ``forward`` that takes two observation tensors and returns four
    tensors: two constants, the action output, and a constant holding the
    action output size.

    NOTE(review): the constant outputs (version 3, memory size 0) look like
    the metadata tensors of the Unity ML-Agents model format -- confirm
    against the consumer of the exported file.

    Args:
        policy: Object exposing ``mlp_extractor``, ``action_net`` and
            ``value_net`` attributes. ``mlp_extractor`` must return a pair
            ``(action_hidden, value_hidden)`` when called.
        output_sizes: Sequence whose first element is stored as the
            ``action_out_shape`` constant.
    """

    def __init__(self, policy, output_sizes: list):
        super().__init__()
        self.extractor = policy.mlp_extractor
        self.action_net = policy.action_net
        # Still registered so the wrapper's attributes and state_dict keep
        # their layout, but forward() never evaluates it.
        self.value_net = policy.value_net

        # Constant outputs, stored as frozen float32 CPU parameters.
        self.version_number = self._constant(3)
        self.memory_size = self._constant(0)
        self.action_out_shape = self._constant(output_sizes[0])

    @staticmethod
    def _constant(value):
        """Return a one-element float32 CPU parameter that is never trained."""
        tensor = torch.tensor([value], dtype=torch.float32, device='cpu')
        return Parameter(tensor, requires_grad=False)

    def forward(self, observation0, observation1):
        """Compute the action output for a pair of observation tensors.

        Args:
            observation0: First observation tensor.
            observation1: Second observation tensor; joined to
                ``observation0`` along dimension 1.

        Returns:
            Tuple ``(version_number, memory_size, action_out,
            action_out_shape)``.
        """
        # NOTE: You may have to process (normalize) observation in the correct
        #       way before using this. See `common.preprocessing.preprocess_obs`
        observation = torch.cat((observation0, observation1), dim=1)

        # The extractor yields one latent per head; only the action latent
        # is needed because the value estimate is not part of the outputs.
        action_hidden, _ = self.extractor(observation)
        action_out = self.action_net(action_hidden)

        return self.version_number, self.memory_size, action_out, self.action_out_shape


# Load the trained agent and move its policy to the CPU before tracing.
# Example: model = PPO("MlpPolicy", "Pendulum-v1")
model = PPO.load('TargetSeeker.zip')
model.policy.to("cpu")

# Second-dimension sizes of the two observation inputs (used for the dummy
# tensors below) and the size stored in the wrapper's action-shape constant.
input_sizes = [85, 12]
output_sizes = [3]
onnxable_model = OnnxablePolicy(model.policy, output_sizes)

# Input names
input_names = ['obs_0', 'obs_1']

# Output names: two constants first, then the action tensor and its shape
# constant, whose names depend on the kind of action space.
output_names = ['version_number', 'memory_size']
action_space = model.policy.action_space
if isinstance(action_space, gym.spaces.Discrete):
    output_names += ['discrete_actions', 'discrete_action_output_shape']
elif isinstance(action_space, gym.spaces.Box):
    output_names += ['continuous_actions', 'continuous_action_output_shape']
else:
    # Without this guard only two names would be supplied for the four
    # tensors the wrapper returns, and the dynamic-axis loop below would
    # tag the constant outputs instead of the action outputs.
    raise TypeError(
        f'Unsupported action space for ONNX export: {type(action_space).__name__}'
    )

# Dynamic axes: dimension 0 of every input and of the last two outputs is
# labelled 'batch'.
# NOTE(review): this also tags the one-element action-shape constant, which
# does not appear to vary with batch size -- confirm this is intended.
dynamic_axes = {name: {0: 'batch'} for name in input_names + output_names[-2:]}

print(f'{input_names}\n{output_names}\n{dynamic_axes}')

# Export the model to ONNX, tracing it with one random sample per input.
dummy_input = tuple(torch.randn(1, x) for x in input_sizes)
torch.onnx.export(
    onnxable_model,
    dummy_input,
    'TargetSeeker.onnx',
    opset_version=9,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
)

['obs_0', 'obs_1']
['version_number', 'memory_size', 'continuous_actions', 'continuous_action_output_shape']
{'obs_0': {0: 'batch'}, 'obs_1': {0: 'batch'}, 'continuous_actions': {0: 'batch'}, 'continuous_action_output_shape': {0: 'batch'}}


### Test exported model

In [21]:
##### Load and test with onnx

import onnx
import onnxruntime as ort
import numpy as np

# Read the exported file back and let the ONNX checker validate it.
onnx_path = 'TargetSeeker.onnx'
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)

# Print the graph's input names, then its output names.
for graph_values in (onnx_model.graph.input, onnx_model.graph.output):
    print([x.name for x in graph_values])

# Run one inference on all-zero float32 observations of shape (1, 85) and
# (1, 12), fed under the input names 'obs_0' and 'obs_1'.
observation0 = np.zeros((1, 85), dtype=np.float32)
observation1 = np.zeros((1, 12), dtype=np.float32)
ort_sess = ort.InferenceSession(onnx_path)
feeds = {'obs_0': observation0, 'obs_1': observation1}
out = ort_sess.run(None, feeds)
print(out)

['obs_0', 'obs_1']
['version_number', 'memory_size', 'continuous_actions', 'continuous_action_output_shape']
[array([3.], dtype=float32), array([0.], dtype=float32), array([[ 0.09262253,  0.02039998, -0.03820657]], dtype=float32), array([3.], dtype=float32)]
