Export as ONNX Model #25

Closed
CesMak opened this issue Mar 24, 2020 · 1 comment

Comments

@CesMak

CesMak commented Mar 24, 2020

Hey,

Thanks for sharing this awesome code!

I would also like to export my trained policy as an ONNX model, but I can't figure out how to use it afterwards; so far it has not worked for me.

This is how I export it:

    torch_out = torch.onnx._export(ppo.policy, input_vector, path+".onnx",  export_params=True)

To get the export to work, I also had to implement a forward method:


    import numpy as np
    import torch
    import torch.nn as nn
    from torch.distributions import Categorical

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    class ActorCritic(nn.Module):
        def __init__(self, state_dim, action_dim, n_latent_var):
            super(ActorCritic, self).__init__()
            #... same as your code

        def forward(self, state_input):
            # Needed so torch.onnx can call the module. act() returns a plain
            # Python int (via .item()), which tracing records as a constant.
            return torch.tensor(self.act(state_input, None))

        def act(self, state, memory):
            if type(state) is np.ndarray:
                state = torch.from_numpy(state).float().to(device)
            action_probs = self.action_layer(state)
            # here make a filter for only possible actions!
            # action_probs = action_probs * memory.leagalCards
            dist = Categorical(action_probs)

            action = dist.sample()

            if memory is not None:
                memory.states.append(state)
                memory.actions.append(action)
                memory.logprobs.append(dist.log_prob(action))

            return action.item()

Then I tried to run the ONNX model with onnxruntime like this, but it always returns the same action :(


    import numpy as np
    import onnxruntime

    def getOnnxAction(path, x):
        '''Input:
        x:      240x1 list of binary values
        path:   *.onnx file (with correct model)'''
        ort_session = onnxruntime.InferenceSession(path)
        ort_inputs  = {ort_session.get_inputs()[0].name: np.asarray(x, dtype=np.float32)}
        ort_outs    = ort_session.run(None, ort_inputs)
        return np.asarray(ort_outs)[0]

Any ideas what is going wrong here?
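A likely explanation, shown as a minimal sketch under assumptions: torch.onnx.export (in its default tracing-based mode) records the tensor operations that run on the example input. Because act() turns the sampled action into a Python int with .item(), the tracer cannot follow it, and torch.tensor() of that int is stored in the graph as a constant. The sketch below uses torch.jit.trace as a stand-in for the exporter's tracing; SamplingPolicy and its 4-dimensional state are hypothetical and not the project's code.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical


class SamplingPolicy(nn.Module):
    """Hypothetical stand-in for ActorCritic: a tiny action_layer plus a
    forward() that samples and calls .item(), like act() in the question."""

    def __init__(self, state_dim=4, action_dim=3):
        super().__init__()
        self.action_layer = nn.Sequential(
            nn.Linear(state_dim, action_dim),
            nn.Softmax(dim=-1),
        )

    def forward(self, state):
        probs = self.action_layer(state)
        action = Categorical(probs).sample()
        # .item() yields a plain Python int the tracer cannot follow, and
        # torch.tensor() of that int is recorded as a constant in the trace.
        return torch.tensor(action.item())


torch.manual_seed(0)
policy = SamplingPolicy()
# check_trace=False skips re-running the random forward pass for a consistency check.
traced = torch.jit.trace(policy, torch.rand(4), check_trace=False)

# Whatever state is fed in, the traced graph returns the action sampled during tracing.
actions = {int(traced(torch.rand(4))) for _ in range(20)}
print(actions)  # a set containing a single action index
```

When tracing this, PyTorch prints TracerWarning messages for the .item() and torch.tensor() calls; those warnings are the hint that the exported graph no longer depends on its input.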

@CesMak
Author

CesMak commented Mar 27, 2020

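I found the solution: export only the policy's action_layer (the network that outputs the action probabilities) instead of the whole policy, and pick the action outside the ONNX model:

    torch.onnx.export(ppo_test.policy_old.action_layer, torch.rand(240), path+".onnx")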
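With only action_layer exported, the ONNX model returns a probability vector, so choosing the action has to happen in Python. Below is a minimal NumPy sketch, assuming the exported layer outputs probabilities (as implied by Categorical(action_probs) in act()). choose_action and its legal_mask parameter are hypothetical helpers; the mask echoes the "filter for only possible actions" comment in the question.

```python
import numpy as np


def choose_action(action_probs, legal_mask=None, greedy=False, rng=None):
    """Pick an action index from the probabilities returned by the ONNX model.

    Hypothetical helper: legal_mask (1 = allowed, 0 = forbidden) must leave at
    least one allowed action with non-zero probability.
    """
    probs = np.asarray(action_probs, dtype=np.float64).ravel()
    if legal_mask is not None:
        # Zero out forbidden actions; renormalised below.
        probs = probs * np.asarray(legal_mask, dtype=np.float64).ravel()
    # Renormalise; this also absorbs float32 rounding in the ONNX output.
    probs = probs / probs.sum()
    if greedy:
        # Deterministic play: take the most likely action.
        return int(np.argmax(probs))
    # Stochastic play, analogous to dist.sample() in the PyTorch policy.
    rng = np.random.default_rng() if rng is None else rng
    return int(rng.choice(len(probs), p=probs))


# Example with a made-up 4-action probability vector.
probs = np.array([0.1, 0.6, 0.2, 0.1], dtype=np.float32)
print(choose_action(probs, greedy=True))                           # 1
print(choose_action(probs, legal_mask=[1, 0, 1, 1], greedy=True))  # 2
print(choose_action(probs, rng=np.random.default_rng(0)))          # some index in 0..3
```

Combined with getOnnxAction from the question, this would be called as choose_action(getOnnxAction(path + ".onnx", x)); greedy=True gives a deterministic policy for evaluation, while the default sampling keeps the stochastic behaviour used during training.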

You can close this issue.
