Support multi-input / multi-output models #28

Closed
seanmor5 opened this issue Mar 31, 2021 · 3 comments

Labels
kind:feature New feature or request

Comments

@seanmor5
Contributor

No description provided.

@seanmor5
Contributor Author

seanmor5 commented Apr 6, 2021

We already support multi-input models (just call Axon.input multiple times). To support multi-output models, I propose a tuple combinator:

```elixir
base = Axon.input({nil, 784}) |> Axon.dense(128)

out1 = base |> Axon.softmax()
out2 = base |> Axon.relu()

model = Axon.tuple([out1, out2])
```

@seanmor5
Contributor Author

seanmor5 commented Apr 7, 2021

Full support for multi-input models was added in #46.

To support multi-output models, we need to add a tuple composite layer:

```elixir
tuple(inputs, opts) :: %Axon{op: :tuple, parent: [input1, input2, ...]}
```

This layer will only be supported inside other composite layers, as the last layer of the network, or as the input to nx layers. Supporting it in nx layers also means deciding how to handle tuple output shapes coming out of the nx layer. Consider this trivial example:

```elixir
{x1, x2} = {Axon.input({nil, 32}), Axon.input({nil, 32})}

Axon.tuple([x1, x2])
|> Axon.nx(fn {x1, x2} -> {Nx.cos(x1), Nx.sin(x2)} end)
```

By default, we can wrap tuple output shapes in Axon.tuple. The limitation is that the outputs then cannot be handled separately unless we also introduce decomposition functions for working with the Axon.tuple layer.
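To make that trade-off concrete, here is a minimal pure-Elixir sketch that assumes a toy graph representation, not Axon's real one. The struct `TupleLayerSketch`, its `:map` op (standing in for an nx layer), and the `{:elem, i}` decomposition op are all hypothetical; they only show how a tuple node, an nx-style layer over a tuple, and a decomposition step could fit together.

```elixir
defmodule TupleLayerSketch do
  # Hypothetical toy graph node, loosely mirroring the proposed
  # %Axon{op: :tuple, parent: [...]}. Not Axon's actual data structure.
  defstruct [:op, :parent, :fun]

  # :input returns the value fed to the graph
  def eval(%__MODULE__{op: :input}, input), do: input

  # :tuple evaluates every parent and packs the results into a tuple
  def eval(%__MODULE__{op: :tuple, parent: parents}, input) do
    parents
    |> Enum.map(&eval(&1, input))
    |> List.to_tuple()
  end

  # :map stands in for an nx layer: apply `fun` to the parent's value,
  # which may itself be a tuple
  def eval(%__MODULE__{op: :map, parent: parent, fun: fun}, input) do
    fun.(eval(parent, input))
  end

  # {:elem, i} is a hypothetical decomposition op that extracts one
  # element of a tuple-valued parent so it can be handled separately
  def eval(%__MODULE__{op: {:elem, i}, parent: parent}, input) do
    elem(eval(parent, input), i)
  end
end

x = %TupleLayerSketch{op: :input}
pair = %TupleLayerSketch{op: :tuple, parent: [x, x]}

# nx-style layer that consumes a tuple and returns a tuple
both = %TupleLayerSketch{op: :map, parent: pair, fun: fn {a, b} -> {a + 1, b * 10} end}

# Without decomposition, downstream layers only see the whole tuple
TupleLayerSketch.eval(both, 2)
# => {3, 20}

# With decomposition, each output can feed its own branch
first = %TupleLayerSketch{op: {:elem, 0}, parent: both}
TupleLayerSketch.eval(first, 2)
# => 3
```

In this toy version the decomposition op is just `elem/2` applied to the evaluated parent; a real layer would also need to carry the per-element output shapes.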

@seanmor5
Contributor Author

seanmor5 commented Apr 15, 2021

I've added support for multi output models: 71fce20

Rather than introducing a new layer, you can pass tuples directly to Axon.predict, Axon.init, and Axon.compile. For example:

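```elixir
defmodule MultiOutput do
  # Module name is illustrative; `def` must live inside a module
  def model do
    inp = Axon.input({nil, 784})
    x = Axon.dense(inp, 128)
    y = Axon.dense(inp, 10)
    # Returning a tuple of Axon graphs gives a multi-output model
    {x, y}
  end
end

params = Axon.init(MultiOutput.model())

# Returns a tuple with one output tensor per element of the model tuple
Axon.predict(MultiOutput.model(), params, Nx.random_uniform({32, 784}))
```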
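The idea behind passing tuples directly can be sketched in plain Elixir: a predict-style function that recurses into (possibly nested) tuples of outputs and returns results with the same tuple structure. This is a hypothetical illustration, assuming each output is a 1-arity function standing in for a compiled graph; `MultiOutputSketch.predict/2` is not how Axon.predict is implemented.

```elixir
defmodule MultiOutputSketch do
  # Hypothetical: a "model" is either a single output (a 1-arity function
  # standing in for a compiled graph) or a tuple of models, possibly nested.
  def predict(model, input) when is_tuple(model) do
    model
    |> Tuple.to_list()
    |> Enum.map(&predict(&1, input))
    |> List.to_tuple()
  end

  def predict(model, input) when is_function(model, 1), do: model.(input)
end

# Shared "base" computation, playing the role of the common input branch;
# in this sketch it is simply recomputed for each output
base = fn xs -> Enum.map(xs, &(&1 * 2)) end

out1 = fn xs -> xs |> base.() |> Enum.sum() end
out2 = fn xs -> xs |> base.() |> Enum.max() end

# The result mirrors the shape of the model tuple
MultiOutputSketch.predict({out1, out2}, [1, 2, 3])
# => {12, 6}

# Nested tuples work the same way
MultiOutputSketch.predict({out1, {out2, out1}}, [1, 2, 3])
# => {12, {6, 12}}
```

Recursing on the tuple structure is what lets a plain tuple act as the output container, so the final outputs need no dedicated tuple layer.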

Support for multi-output / tuple models in the training API is still missing.
