**This is an example of how to trace a model with JIT (`torch.jit.trace`) and export it to ONNX.**

In [None]:
!pip install onnx
!pip install onnxruntime
!pip install git+https://github.com/Denys88/rl_games
!pip install envpool
!pip install gym
!pip install pygame
!pip install -U colabgymrender

In [None]:
from rl_games.torch_runner import Runner
import os
import yaml
import torch
import matplotlib.pyplot as plt
import gym
from IPython import display
import numpy as np
import onnx
import onnxruntime as ort
%matplotlib inline

In [None]:
!nvidia-smi -L

In [None]:
# Load the TensorBoard notebook extension, which provides the %tensorboard magic.
%load_ext tensorboard

In [None]:
# Start TensorBoard inline, reading logs from runs/ (the training output root;
# the checkpoint restored later lives under runs/pendulum_onnx/nn/).
%tensorboard --logdir 'runs/'

In [None]:
# rl_games training configuration: PPO ('a2c_continuous' with 'ppo': True) on
# envpool's Pendulum-v1, using a small shared ('separate': False) actor-critic
# MLP and a fixed, state-independent sigma. 'full_experiment_name' and 'name'
# determine where the checkpoint is written (runs/pendulum_onnx/nn/pendulum.pth).
config = {'params': {'algo': {'name': 'a2c_continuous'},
  'config': {'bound_loss_type': 'regularisation',
   'bounds_loss_coef': 0.0,
   'clip_value': False,
   'critic_coef': 4,
   'e_clip': 0.2,
   'entropy_coef': 0.0,
   'env_config': {'env_name': 'Pendulum-v1', 'seed': 5},
   'env_name': 'envpool',
   'full_experiment_name' : 'pendulum_onnx',
   'save_best_after' : 20,
   'gamma': 0.99,
   'grad_norm': 1.0,
   'horizon_length': 32,
   'kl_threshold': 0.008,
   # A real float rather than the string '3e-4', so the value has the numeric
   # type the optimizer / adaptive LR schedule work with.
   'learning_rate': 3e-4,
   'lr_schedule': 'adaptive',
   'max_epochs': 200,
   'mini_epochs': 5,
   'minibatch_size': 1024,
   'name': 'pendulum',
   'normalize_advantage': True,
   'normalize_input': True,
   'normalize_value': True,
   'num_actors': 64,
   'player': {'render': True},
   'ppo': True,
   'reward_shaper': {'scale_value': 0.1},
   'schedule_type': 'standard',
   'score_to_win': 20000,
   'tau': 0.95,
   'truncate_grads': True,
   'use_smooth_clamp': False,
   'value_bootstrap': True},
  'model': {'name': 'continuous_a2c_logstd'},
  'network': {'mlp': {'activation': 'elu',
    'initializer': {'name': 'default'},
    'units': [32, 32]},
   'name': 'actor_critic',
   'separate': False,
   'space': {'continuous': {'fixed_sigma': True,
     'mu_activation': 'None',
     'mu_init': {'name': 'default'},
     'sigma_activation': 'None',
     'sigma_init': {'name': 'const_initializer', 'val': 0}}}},
  'seed': 5}}

In [None]:
# Train the agent described by `config`. The following cells restore the
# resulting checkpoint from runs/pendulum_onnx/nn/.
runner = Runner()
runner.load(config)
train_args = {'train': True}
runner.run(train_args)

In [None]:
class ModelWrapper(torch.nn.Module):
    '''
    Expose only the raw neural-network part of an rl_games model for tracing/export.

    Exporting the full model does not work (apparently an ONNX issue with torch
    distributions), so ``forward`` normalizes the observation with the model's
    own ``norm_obs`` and then calls ``a2c_network`` directly, bypassing the
    distribution/sampling logic.
    '''
    def __init__(self, model):
        super().__init__()
        self._model = model

    def forward(self, input_dict):
        '''
        Run the wrapped model's network on normalized observations.

        Args:
            input_dict: dict with an 'obs' entry; any other entries
                (e.g. 'rnn_states') are passed through unchanged.

        Returns:
            Whatever ``self._model.a2c_network`` returns for the normalized inputs.
        '''
        # Shallow-copy so the caller's dict keeps the raw observation: writing
        # back into it would re-normalize 'obs' on every subsequent call.
        network_inputs = dict(input_dict)
        network_inputs['obs'] = self._model.norm_obs(input_dict['obs'])
        return self._model.a2c_network(network_inputs)

In [None]:
# Restore the trained checkpoint. Its path follows the config:
# runs/<full_experiment_name>/nn/<name>.pth, i.e. runs/pendulum_onnx/nn/pendulum.pth here.
train_cfg = config['params']['config']
checkpoint_path = os.path.join('runs', train_cfg['full_experiment_name'], 'nn', train_cfg['name'] + '.pth')
agent = runner.create_player()
agent.restore(checkpoint_path)

import rl_games.algos_torch.flatten as flatten

# Example input: a batch of one zero observation. The traced graph is
# specialized to this (1, *obs_shape) shape.
inputs = {
    'obs': torch.zeros((1,) + agent.obs_shape).to(agent.device),
    'rnn_states': agent.states,
}

with torch.no_grad():
    # TracingAdapter flattens the dict input / structured output into plain
    # tensors for torch.jit.trace; allow_non_tensor=True lets non-tensor
    # entries (e.g. agent.states, presumably None for this MLP policy) through.
    adapter = flatten.TracingAdapter(ModelWrapper(agent.model), inputs, allow_non_tensor=True)
    traced = torch.jit.trace(adapter, adapter.flattened_inputs, check_trace=False)
    flattened_outputs = traced(*adapter.flattened_inputs)
    print(flattened_outputs)

# `args` must be the tuple of example inputs. Unpacking it with `*` only works
# while there is exactly one flattened input; with more, the second tensor
# would be taken as the output-file argument `f`.
torch.onnx.export(
    traced,
    adapter.flattened_inputs,
    "pendulum.onnx",
    verbose=True,
    input_names=['obs'],
    output_names=['mu', 'log_std', 'value'],
)

onnx_model = onnx.load("pendulum.onnx")

# Check that the model is well formed.
onnx.checker.check_model(onnx_model)

In [None]:
# Smoke-test the exported graph with onnxruntime. An explicit CPU provider works
# with both CPU-only and GPU builds of onnxruntime.
ort_model = ort.InferenceSession("pendulum.onnx", providers=["CPUExecutionProvider"])

# Zero observation with the same (1, *obs_shape) shape used for tracing,
# instead of a hard-coded (1, 3).
dummy_obs = np.zeros((1,) + tuple(agent.obs_shape), dtype=np.float32)
outputs = ort_model.run(None, {"obs": dummy_obs})
print(outputs)

In [None]:
# Select SDL's dummy video driver so SDL-based rendering (e.g. pygame) can run
# without a physical display.
os.environ.update(SDL_VIDEODRIVER="dummy")

In [None]:
# Roll out one Pendulum-v1 episode driven only by the exported ONNX policy.
# The graph returns raw network outputs (mu, log_std, value), so the action
# distribution is rebuilt here: sample from N(mu, exp(log_std)).
env = gym.make('Pendulum-v1')
obs = env.reset()

# rl_games' continuous agents clamp actions to [-1, 1] and rescale them to the
# action-space bounds when `clip_actions` is enabled (the default; this config
# does not override it), so the same transform is applied before env.step.
# NOTE(review): confirm against the installed rl_games version.
action_low = env.action_space.low
action_high = env.action_space.high

# Create one figure/image and update its pixels each step, instead of stacking
# a new imshow layer per frame.
fig, ax = plt.subplots()
ax.axis('off')
frame = ax.imshow(env.render(mode='rgb_array'))

total_reward = 0
num_steps = 0
done = False

while not done:
    outputs = ort_model.run(None, {"obs": np.expand_dims(obs, axis=0).astype(np.float32)})
    # Take the single batch row; works for any action dimensionality.
    mu = outputs[0][0]
    sigma = np.exp(outputs[1][0])
    action = np.clip(np.random.normal(mu, sigma), -1.0, 1.0)
    env_action = action_low + (action + 1.0) * 0.5 * (action_high - action_low)
    obs, reward, done, _info = env.step(env_action)
    total_reward += reward
    num_steps += 1

    frame.set_data(env.render(mode='rgb_array'))
    display.clear_output(wait=True)
    display.display(fig)

# Close the figure so the inline backend does not show it a second time,
# and release the environment's resources.
plt.close(fig)
env.close()
print(total_reward, num_steps)