Create mechanism for easy model composition #59
I've been experimenting a bit, and after starting #81 I believe I have a solution to this issue. Introduce `Axon.function`:

```elixir
generator =
  Axon.input({nil, 100})
  |> Axon.dense(128, activation: :tanh)
  |> Axon.dense(512, activation: :tanh)
  |> Axon.dense(784, activation: :tanh)
  |> Axon.reshape({1, 28, 28})
  |> Axon.function()

discriminator =
  Axon.input({nil, 1, 28, 28})
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(1, activation: :sigmoid)
  |> Axon.function()

# Compose the two models; note the closing parenthesis for each call
joint = discriminator.(generator.(Axon.input({nil, 100})))
```
With this, `generator` and `discriminator` are still separate objects. The biggest question then becomes how execution and compilation should behave when they encounter an `Axon.function`.

Also note that the reason we can't just do:

```elixir
generator = fn x ->
  x
  |> Axon.dense(128, activation: :tanh)
  |> Axon.dense(512, activation: :tanh)
  |> Axon.dense(784, activation: :tanh)
  |> Axon.reshape({1, 28, 28})
  |> Axon.function()
end

discriminator = fn x ->
  x
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(1, activation: :sigmoid)
  |> Axon.function()
end

g = generator.(Axon.input({nil, 100}))
d = discriminator.(Axon.input({nil, 784}))
joint = discriminator.(generator.(Axon.input({nil, 100})))
```

is because of how Axon's compiler works. Subsequent calls to both `generator` and `discriminator` in the above yield brand-new models with new, uniquely named parameters, rather than yielding the same model on each call, which is what …
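To make the distinction concrete, here is a toy model in plain Elixir, with no Axon involved. The `SharingSketch` module and its functions are hypothetical and only loosely mimic the behaviour described above; this is not how Axon's compiler actually names parameters. A plain builder function allocates a fresh parameter name on every call, while a "block"-style wrapper allocates the name once and reuses it on every application:

```elixir
defmodule SharingSketch do
  # Hypothetical toy layer constructor (not Axon's API): every call
  # allocates a parameter with a fresh, unique name.
  def dense(input, units) do
    %{op: :dense, input: input, units: units, param: fresh_name("dense")}
  end

  # Hypothetical toy "block": the parameter name is allocated once, when the
  # block is created, and every application of the returned function reuses it.
  def dense_block(units) do
    name = fresh_name("dense")
    fn input -> %{op: :dense, input: input, units: units, param: name} end
  end

  # Unique, monotonically increasing suffix for parameter names.
  defp fresh_name(prefix) do
    "#{prefix}_#{System.unique_integer([:positive, :monotonic])}"
  end
end

# A plain function builds brand-new parameters on every call...
plain = fn x -> SharingSketch.dense(x, 128) end
p1 = plain.(:input_a)
p2 = plain.(:input_b)
IO.puts(p1.param == p2.param)  # prints "false"

# ...whereas the block shares one parameter across calls.
shared = SharingSketch.dense_block(128)
s1 = shared.(:input_a)
s2 = shared.(:input_b)
IO.puts(s1.param == s2.param)  # prints "true"
```

The design point is simply *when* parameter identity is fixed: at each call for a plain function, or once at construction time for a shared wrapper.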
This is possible with blocks now.
For now, we'll only consider how this should work in the model creation and execution API, but it will touch the training API as well.
Consider the models in a basic GAN:
In order to train, what you'd want to do is something like:
And then you can alternate between `step_d` and `step_g` to train on valid and fake images. Unfortunately, we currently don't support model composition in this sense: you can define the functions `generator` and `discriminator` without an input block, but there's no way to cleanly determine which parameters belong to which model. Ideally, you'd be able to compose models in some way so that when you initialize, predict, train, etc., parameters are grouped:

Whatever the implementation is, it will involve adding some metadata to parameters that expresses their ownership by a given model. From an API perspective, one option is to introduce `Axon.compose` for composing Axon structs into a single model while preserving parameter information, although I'm not sure I love that right now.
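A minimal sketch of what ownership metadata could enable, assuming a hypothetical flat parameter list in which each entry carries an `:owner` tag. `OwnershipSketch`, `group_by_owner/1`, and `update_owner/3` are illustrative names, not Axon API, and real parameters would be Nx tensors rather than floats:

```elixir
defmodule OwnershipSketch do
  # Group a flat list of parameters (each tagged with a hypothetical :owner)
  # into a map of owner => %{parameter_name => value}.
  def group_by_owner(params) do
    params
    |> Enum.group_by(& &1.owner, &{&1.name, &1.value})
    |> Map.new(fn {owner, pairs} -> {owner, Map.new(pairs)} end)
  end

  # Apply `update_fun` to every parameter owned by `owner`, leaving the
  # other sub-models' parameters untouched.
  def update_owner(grouped, owner, update_fun) do
    Map.update!(grouped, owner, fn group ->
      Map.new(group, fn {name, value} -> {name, update_fun.(value)} end)
    end)
  end
end

params = [
  %{owner: :generator, name: "dense_0_kernel", value: 1.0},
  %{owner: :generator, name: "dense_1_kernel", value: 2.0},
  %{owner: :discriminator, name: "dense_2_kernel", value: 3.0}
]

grouped = OwnershipSketch.group_by_owner(params)
# grouped == %{generator: %{"dense_0_kernel" => 1.0, "dense_1_kernel" => 2.0},
#              discriminator: %{"dense_2_kernel" => 3.0}}

# A stand-in for a discriminator step: only the discriminator's parameters change.
after_d = OwnershipSketch.update_owner(grouped, :discriminator, &(&1 - 0.5))
```

Grouping by owner is what would let a `step_d`-style update touch only the discriminator's parameters while the generator's stay frozen, and vice versa for `step_g`.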