In [35]:
import torch
from torch import nn 

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecNormalize

# Location of the rl-baselines3-zoo training run whose policy is exported.
log_path = "/home/daniel/Documents/github/mcx-lab/rl_ws/src/rl-baselines3-zoo/logs/ppo/A1GymEnv-v0_2"
model_path = log_path + "/A1GymEnv-v0.zip"
normalizer_path = log_path + "/A1GymEnv-v0/vecnormalize.pkl"

# NOTE(review): `deterministic` is normally an argument of `predict()`;
# confirm that passing it to `load()` has the intended effect.
model = PPO.load(model_path, deterministic=True)

# Load the pickled VecNormalize wrapper to get at its running statistics.
# NOTE(review): pickle.load executes arbitrary code from the file -- only use
# it on checkpoints you produced yourself.
import pickle
with open(normalizer_path, "rb") as pkl_file:
    vec_normalize = pickle.load(pkl_file)

# Show the values that are baked into the exported normalization layer.
for stat in (
    vec_normalize.obs_rms.mean,
    vec_normalize.obs_rms.var,
    vec_normalize.epsilon,
    vec_normalize.clip_obs,
):
    print(stat)

# Convert VecNormalize to a torch layer 
# so it can be exported as a layer in the JIT module
import numpy as np
class Normalize(nn.Module):
    """Torch implementation of Stable Baselines3 VecNormalize observation scaling.

    The forward pass computes
    ``clip((x - mean) / sqrt(var + epsilon), -clip, clip)``.

    The statistics are registered as *buffers* rather than plain tensor
    attributes. Buffers follow the module through ``.to(device)``,
    ``.double()`` etc. and are included in ``state_dict()``, so the layer
    keeps working when the surrounding network is moved between CPU and GPU.
    Buffers are not parameters, so ``named_parameters()`` is unaffected.

    Args:
        mean: Running mean of the observations.
        var: Running variance of the observations (same shape as ``mean``).
        epsilon: Small constant added to the variance before the square root.
        clip: Symmetric bound applied to the normalized observation.
    """

    def __init__(self, mean: np.ndarray, var: np.ndarray, epsilon: float, clip: float):
        super().__init__()
        # Stored as float32 to match the dtype of the policy weights.
        self.register_buffer("mean", torch.tensor(mean, dtype=torch.float32))
        self.register_buffer("var", torch.tensor(var, dtype=torch.float32))
        self.register_buffer("epsilon", torch.tensor(epsilon, dtype=torch.float32))
        self.register_buffer("clip", torch.tensor(clip, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the normalized and clipped version of observation ``x``."""
        scaled = (x - self.mean) / torch.sqrt(self.var + self.epsilon)
        return torch.clip(scaled, -self.clip, self.clip)

# Torch counterpart of the loaded VecNormalize statistics.
normalize_layer = Normalize(
    mean=vec_normalize.obs_rms.mean,
    var=vec_normalize.obs_rms.var,
    epsilon=vec_normalize.epsilon,
    clip=vec_normalize.clip_obs,
)

# Flatten the actor MLP: modules() also yields Sequential containers,
# which are dropped so only the individual layers remain.
actor_layers = [
    layer
    for layer in model.policy.mlp_extractor.policy_net.modules()
    if not isinstance(layer, nn.Sequential)
]

# Export the policy as one sequential network:
# normalization -> actor MLP layers -> action head.
policy = nn.Sequential(normalize_layer, *actor_layers, model.policy.action_net)
policy.eval()

print(policy)

[ 0.95492715 -0.13722686 -0.01805571  0.03465187  0.09965638 -0.01244039
  0.09339193  0.09982642 -0.37575056  0.64906948 -1.52038246  0.36914969
  0.78088707 -1.51810128 -0.21349837  1.09994564 -1.33341672  0.32087463
  1.07842145 -1.31339599  0.09066599 -0.04599702  0.57847774 -0.12386647
 -0.13965874  0.6348717   0.04317015 -0.0747653   0.44737636 -0.01981669
 -0.05129586  0.22097944  0.01679691  0.00310517]
[1.83628951e-01 8.45550566e-02 3.91434501e-02 2.54915481e-02
 1.39900123e-01 5.84977073e+00 2.91519409e+00 1.42981600e+00
 5.00709927e-02 1.04206885e-01 7.97609170e-02 6.65389511e-02
 9.19186959e-02 8.66994638e-02 9.58188620e-02 4.02969370e-02
 9.00284100e-02 8.37713567e-02 4.98324783e-02 1.03873686e-01
 8.87176997e+00 2.83332862e+01 1.78218377e+01 8.00413824e+00
 2.54492348e+01 2.18691213e+01 1.25091122e+01 4.72442452e+00
 1.74890142e+01 7.78838515e+00 4.67038593e+00 1.30536541e+01
 2.07140641e-05 8.75077372e-05]
1e-08
10.0
Sequential(
  (0): Normalize()
  (1): Linear(in_featur

In [32]:
# Export weights to csv
import json
import os

import numpy as np

# np.savetxt / open() do not create missing directories, so make sure the
# output tree exists first (otherwise the cell fails on a fresh checkout).
for out_dir in ("model/weights", "model/sample_outputs"):
    os.makedirs(out_dir, exist_ok=True)

# One CSV file per parameter tensor; the shapes go into the metadata file.
shapes = {}
for name, parameter in policy.named_parameters():
    # '.' in parameter names (e.g. "1.weight") is replaced to get plain file names.
    name = name.replace('.', '_')
    print(name)
    param_np = parameter.detach().cpu().numpy()
    shapes[name] = param_np.shape
    print(parameter.dtype)
    print(parameter.device)
    np.savetxt(f"model/weights/{name}.csv", param_np, delimiter=",")

# Export metadata

with open("model/metadata.json", 'w') as jsonfile:
    json.dump({'parameter_shapes': shapes}, jsonfile)

# Add a couple of expected outputs to sanity check the conversion.
# The observation size is taken from the normalizer statistics (first layer
# of `policy`) instead of being hard-coded.
obs_dim = policy[0].mean.shape[0]
sample_inputs = {
    'zeros': torch.zeros(obs_dim),
    'ones': torch.ones(obs_dim)
}
for name, inp in sample_inputs.items():
    # Inference only: no autograd graph is needed for the reference outputs.
    with torch.no_grad():
        oup = policy(inp.to(model.device)).detach().cpu().numpy()
    inp = inp.detach().cpu().numpy()
    print(name, oup)
    np.savetxt(f"model/sample_outputs/{name}_in.csv", inp, delimiter=",")
    np.savetxt(f"model/sample_outputs/{name}_out.csv", oup, delimiter=",")



1_weight
torch.float32
cuda:0
1_bias
torch.float32
cuda:0
3_weight
torch.float32
cuda:0
3_bias
torch.float32
cuda:0
5_weight
torch.float32
cuda:0
5_bias
torch.float32
cuda:0
zeros [-4.394763   -4.8699784  -2.2114382   3.285827   -0.49302593 -1.6082126
 -1.4629492   0.9332253   0.7294416  -2.6659343   0.16742828 -0.6772122 ]
ones [-9.693275   -9.4261265   2.4467864   5.6924467   1.210065   -1.0285163
  4.7852015   4.5193543   0.21731241  3.306598   -1.005662    1.2743167 ]


In [36]:
# Convert to TorchScript

# Tracing is done on the CPU: both the example input and the network are
# placed there before torch.jit.trace records the forward pass.
cpu_device = torch.device('cpu')
example = torch.rand(34).to(cpu_device)
policy = policy.to(cpu_device)

# Record the forward pass on the example observation and save the result.
traced_script_module = torch.jit.trace(policy, example)
traced_script_module.save("normalize_and_policy_network.pt")