Create mechanism for easy model composition #59

Closed
seanmor5 opened this issue Apr 19, 2021 · 2 comments
Labels
kind:feature New feature or request note:discussion Details or approval are up for discussion

Comments

@seanmor5
Contributor

For now, we'll only consider how this should work in the model creation and execution API, but it will touch the training API as well.

Consider the models in a basic GAN:

generator =
  Axon.input({nil, 100})
  |> Axon.dense(128, activation: :tanh)
  |> Axon.dense(512, activation: :tanh)
  |> Axon.dense(784, activation: :tanh)
  |> Axon.reshape({1, 28, 28})

discriminator =
  Axon.input({nil, 1, 28, 28})
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(1, activation: :sigmoid)

To train, you'd want something like:

combined = compose(discriminator, generator)  # represents D(G(input)) 
step_d = Axon.Training.step(discriminator, :binary_cross_entropy, Axon.Optimizers.sgd(0.005))
step_g = Axon.Training.step(combined, :binary_cross_entropy, Axon.Optimizers.adam(0.01))

You could then alternate between step_d and step_g to train on real and fake images. Unfortunately, we don't currently support model composition in this sense - you can define generator and discriminator as functions without an input block, but there's no way to cleanly determine which parameters belong to which model. Ideally, you'd be able to compose models so that parameters stay grouped when you initialize, predict, train, etc.:

combined = compose(discriminator, generator)
{d_params, g_params} = combined_params = Axon.init(combined)
Axon.predict(combined, combined_params, input)

{{d_params, g_params}, _} =
  combined
  |> Axon.Training.step(:binary_cross_entropy, Axon.Optimizers.adam(0.01))
  |> Axon.Training.train(inputs, targets)
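
For concreteness, the alternating loop might then look something like the sketch below, where d_step and g_step stand in for whatever per-batch update functions the step API would yield (these names are hypothetical):

{d_params, g_params} =
  Enum.reduce(batches, {d_params, g_params}, fn {real, noise}, {d_params, g_params} ->
    # Generate fake images with the current generator parameters.
    fake = Axon.predict(generator, g_params, noise)
    # Update the discriminator to separate real from fake images...
    d_params = d_step.(d_params, {real, fake})
    # ...then update the generator through D(G(input)), holding the
    # discriminator's parameters fixed.
    g_params = g_step.({d_params, g_params}, noise)
    {d_params, g_params}
  end)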

Whatever the implementation is, it will involve adding some metadata to parameters that expresses their ownership by a given model. From an API perspective, one option is to introduce Axon.compose for composing Axon structs into a single model while preserving parameter information, although I'm not sure I love that right now.
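
One way to picture that grouping: initialization of a composed model could return parameters namespaced per sub-model, so that freezing or updating one side is a plain map operation. This layout is purely hypothetical:

# Hypothetical layout: each sub-model's parameters live under their own key.
combined_params = %{
  "discriminator" => d_params,
  "generator" => g_params
}

# A generator-only update then leaves the discriminator's slice untouched.
combined_params = %{combined_params | "generator" => new_g_params}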

@seanmor5 seanmor5 added kind:feature New feature or request note:discussion Details or approval are up for discussion labels Apr 19, 2021
@seanmor5
Contributor Author

I've been experimenting a bit, and after starting #81 I believe I have a solution to this issue: introduce Axon.function. The idea is that Axon.function takes a block of layers with inputs and returns an anonymous function whose arity matches the number of inputs in the block. So the GAN would look like:

generator =
  Axon.input({nil, 100})
  |> Axon.dense(128, activation: :tanh)
  |> Axon.dense(512, activation: :tanh)
  |> Axon.dense(784, activation: :tanh)
  |> Axon.reshape({1, 28, 28})
  |> Axon.function()

discriminator =
  Axon.input({nil, 1, 28, 28})
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(1, activation: :sigmoid)
  |> Axon.function()

joint = discriminator.(generator.(Axon.input({nil, 100})))

And generator and discriminator are still separate objects. The biggest question then becomes how execution and compilation behave when they encounter an Axon.function.
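
For example, a block with two inputs would yield an arity-2 function under this proposal (a hypothetical sketch, since Axon.function doesn't exist yet):

# Hypothetical: two inputs in the block yield a two-argument function.
concat_block =
  Axon.concatenate(Axon.input({nil, 16}), Axon.input({nil, 16}))
  |> Axon.dense(8)
  |> Axon.function()

merged = concat_block.(Axon.input({nil, 16}), Axon.input({nil, 16}))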

Also note the reason we can't just do:

generator = fn x ->
  x
  |> Axon.dense(128, activation: :tanh)
  |> Axon.dense(512, activation: :tanh)
  |> Axon.dense(784, activation: :tanh)
  |> Axon.reshape({1, 28, 28})
  |> Axon.function()
end

discriminator = fn x ->
  x
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(1, activation: :sigmoid)
  |> Axon.function()
end

g = generator.(Axon.input({nil, 100}))
d = discriminator.(Axon.input({nil, 784}))

joint = discriminator.(generator.(Axon.input({nil, 100})))

is because of how Axon's compiler works. Each subsequent call to generator or discriminator above yields a brand-new model with new, uniquely named parameters, rather than the same model on each call - which is what Axon.function would do.
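
Concretely, with the anonymous-function version each call builds a fresh graph (the parameter names below are only illustrative):

g1 = generator.(Axon.input({nil, 100}))  # e.g. parameters dense_0..dense_2
g2 = generator.(Axon.input({nil, 100}))  # e.g. parameters dense_3..dense_5
# g1 and g2 share no parameters, so training one never updates the other.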

@seanmor5
Contributor Author

This is possible with blocks now
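
For reference, a minimal sketch with Axon.block (note that Axon.input also takes a name in the newer API); each application of the same block reuses one set of parameters:

# Both applications of dense_block below share a single set of parameters.
dense_block = Axon.block(&Axon.dense(&1, 128, activation: :relu))

model =
  Axon.input("data", shape: {nil, 128})
  |> dense_block.()
  |> dense_block.()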
