
what's the reason for making register_module internal? #515

Closed

fwaris opened this issue Feb 13, 2022 · 8 comments

Comments
@fwaris (Contributor) commented Feb 13, 2022

I am building a functional wrapper around TorchSharp. It uses register_module to compose modules together in a functional style. It is now broken. Please advise on an appropriate replacement.

Making it internal enforces a very OO view of the models, which is not ideal in my view.

Example:

open TorchSharp
open TorchSharp.Fun

let model =
    torch.nn.Linear(100L, 10L)
    ->> torch.nn.Dropout(0.1)
    ->> torch.nn.ReLU()

let t = model.forward (...)

There are extensions for more complex forward invocations, but in general this style makes the structure of the model very apparent, and it's easier to experiment with variations on the base structure during model development.
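For context, here is a minimal sketch of how such a ->> operator might be built on register_module. This is illustrative only: FuncModule and its structure are assumptions, not the actual TorchSharp.Fun implementation.

open TorchSharp

// Joins two modules; forward pipes the input through both.
// A sketch only: FuncModule is hypothetical, not TorchSharp.Fun's code.
type FuncModule(a: torch.nn.Module, b: torch.nn.Module) as this =
    inherit torch.nn.Module("FuncModule")
    do
        // register_module is what makes the submodules visible to
        // parameters(), state_dict(), to(device), etc.
        this.register_module("a", a)
        this.register_module("b", b)
    override _.forward(t) = t |> a.forward |> b.forward

// ->> joins two modules into one
let (->>) (a: torch.nn.Module) (b: torch.nn.Module) =
    FuncModule(a, b) :> torch.nn.Module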

@lostmsu (Contributor) commented Feb 13, 2022

While I don't like the implicit submodule registration via reflection that we have currently, I don't see exactly what problem you are facing.

I forget the details of operator overloading in F#, but you could do

Sequential Chain(Module a, Module b) => torch.nn.Sequential(a, b);
Sequential Chain(Sequential s, Module m) => torch.nn.Sequential(s.layers.Append(m));
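In F#, the same idea can be phrased with torch.nn.Sequential directly, which sidesteps register_module because Sequential registers its children itself. A sketch, where the chain helper name is an assumption:

// Compose via Sequential instead of manual registration;
// Sequential takes care of registering its child modules.
let chain (a: torch.nn.Module) (b: torch.nn.Module) =
    torch.nn.Sequential(a, b)

let model = chain (torch.nn.Linear(100L, 10L)) (torch.nn.Dropout(0.1))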

@fwaris (Contributor, Author) commented Feb 13, 2022

The ->> operator implementation calls register_module to bind the submodule (but now it cannot).

Also, other similar functions, e.g. register_parameter, are still public. register_module has been accessible thus far without any adverse effects. Marking it internal needlessly limits future extensibility.

@lostmsu (Contributor) commented Feb 14, 2022

@fwaris why can't my Chain sample work as your ->> operator?

@NiklasGustafsson (Contributor)

@fwaris -- the reason I made it internal is that it's not in the PyTorch APIs. If there's a strong reason to have it be public, then it's not a big deal to make it so.

@NiklasGustafsson (Contributor) commented Feb 14, 2022

Also, @fwaris, is your "->>" operator similar to the TorchSharp F# operator -->?

    override _.forward(input) =
        input
        --> conv1 --> relu --> conv2 --> relu --> pool1 --> dropout1
        --> flatten
        --> fc1 --> relu --> dropout2 --> fc2
        --> logsm

Defined in Tensor.cs:

 public static Tensor op_MinusMinusGreater(Tensor t, torch.nn.Module m) => m.forward(t);

@fwaris (Contributor, Author) commented Feb 14, 2022

@NiklasGustafsson Thanks. Would really appreciate it if register_module were made public again. It would make it easier to build other APIs on top of TorchSharp.

The ->> operator is inspired by DiffSharp. It binds two modules together (to create joined modules).

The --> operator invokes the forward function of a module with a tensor. The two behave differently but can be intermixed.
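To illustrate the distinction (a sketch that assumes the hypothetical FuncModule-based ->> shown earlier):

// ->> joins modules into a new module; --> applies a tensor to a module.
let head = torch.nn.Linear(100L, 10L) ->> torch.nn.ReLU()   // a Module
let x = torch.randn (1L, 100L)
let y = x --> head                                          // a Tensor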

@NiklasGustafsson (Contributor)

The next release will include this change, but it will take a little while, because it is going to be a huge one (it adds support for parameter groups).

@fwaris (Contributor, Author) commented Feb 15, 2022

Thanks @NiklasGustafsson. I can work with the older version for now. It's good to know that we are getting parameter-group support: big models are hard to train and call for many tricks, e.g. different learning rates for different layers.
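Until that release ships, one way to approximate parameter groups is to run separate optimizers over different parameter subsets. A sketch under assumptions: body and head are hypothetical submodules, and the exact SGD and step signatures in the TorchSharp version of the day may differ:

// Approximating parameter groups before they land: one optimizer per
// layer group, each with its own learning rate.
let optBody = torch.optim.SGD(body.parameters(), 1e-4)
let optHead = torch.optim.SGD(head.parameters(), 1e-2)

// after loss.backward(), step each group with its own rate:
optBody.step() |> ignore
optHead.step() |> ignore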

fwaris closed this as completed Feb 15, 2022