Skip to content

Commit

Permalink
Merge pull request #9 from microsoft/torch_policy
Browse files Browse the repository at this point in the history
Remove dynamix axis
  • Loading branch information
gabrielmittag committed Dec 7, 2023
2 parents ba9409d + d1a8e59 commit 7258bda
Showing 1 changed file with 3 additions and 11 deletions.
14 changes: 3 additions & 11 deletions torch_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,30 +51,22 @@ def forward(self, x, h, c):
torchBwModel.eval()
torch.onnx.export(
torchBwModel,
(torch_dummy_inputs, torch_initial_hidden_state, torch_initial_cell_state),
(torch_dummy_inputs[0:1, 0:1, :], torch_initial_hidden_state, torch_initial_cell_state),
model_path,
opset_version=11,
input_names=['obs', 'hidden_states', 'cell_states'], # the model's input names
output_names=['output', 'state_out', 'cell_out'], # the model's output names
dynamic_axes={
'obs' : {0: 'batch_size', 1: 'seq_len'},
'hidden_states' : {0: 'batch_size'},
'cell_states' : {0: 'batch_size'},
'state_out' : {0: 'batch_size'},
'cell_out' : {0: 'batch_size'},
'output' : {0: 'batch_size', 1: 'seq_len'},
}
)

# verify tf and onnx models outputs
# verify torch and onnx models outputs
ort_session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
onnx_hidden_state, onnx_cell_state = (np.zeros((1, hidden_size), dtype=np.float32), np.zeros((1, hidden_size), dtype=np.float32))
torch_hidden_state, torch_cell_state = (torch.as_tensor(onnx_hidden_state), torch.as_tensor(onnx_cell_state))
# online interaction: step through the environment 1 time step at a time
with torch.no_grad():
for i in tqdm(range(dummy_inputs.shape[1])):
torch_estimate, torch_hidden_state, torch_cell_state = torchBwModel(torch_dummy_inputs[0:1, i:i+1, :], torch_hidden_state, torch_cell_state)
feed_dict= {'obs': dummy_inputs[0:1,i:i+1,:], 'hidden_states': onnx_hidden_state, 'cell_states': onnx_cell_state}
feed_dict= {'obs': dummy_inputs[0:1, i:i+1, :], 'hidden_states': onnx_hidden_state, 'cell_states': onnx_cell_state}
onnx_estimate, onnx_hidden_state, onnx_cell_state = ort_session.run(None, feed_dict)
assert np.allclose(torch_estimate.numpy(), onnx_estimate, atol=1e-6), 'Failed to match model outputs!'
assert np.allclose(torch_hidden_state, onnx_hidden_state, atol=1e-7), 'Failed to match hidden state1'
Expand Down

0 comments on commit 7258bda

Please sign in to comment.