
Feature request: torch.nn.Flatten() equivalent module for nn_sequential #716

Closed
gsgxnet opened this issue Oct 16, 2021 · 3 comments · Fixed by #773

Comments


gsgxnet commented Oct 16, 2021

I wanted to port the PyTorch Basics examples to R torch:
https://pytorch.org/tutorials/beginner/pytorch_with_examples.html#nn-module
That tutorial contains:

model = torch.nn.Sequential(
    torch.nn.Linear(3, 1),
    torch.nn.Flatten(0, 1)
)

I am not able to find an equivalent nn_flatten module in the R torch library.
According to https://stackoverflow.com/questions/53953460/how-to-flatten-input-in-nn-sequential-in-pytorch,
torch.nn.Flatten() could be implemented in PyTorch as a custom class, e.g.:

class Flatten(torch.nn.Module):
    def forward(self, x):
        # collapse all dimensions after the batch dimension
        batch_size = x.shape[0]
        return x.view(batch_size, -1)

This was the usual workaround back when torch.nn.Flatten() did not yet exist.
Could the same be done in R torch? If so, how?


gsgxnet commented Oct 17, 2021

Okay, I finally found a hint on implementing the class myself in the mlverse documentation:
https://torch.mlverse.org/docs/articles/examples/basic-nn-module.html and
https://torch.mlverse.org/docs/articles/getting-started/custom-nn.html
I will give it a try.
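A first, untested sketch based on those articles might be a direct translation of the Python class above (the module name flatten is just a placeholder of mine):

library(torch)

# Untested sketch: flatten everything after the batch dimension,
# mirroring the Python Flatten class above.
flatten <- nn_module(
  initialize = function() {},
  forward = function(x) {
    batch_size <- x$shape[1]       # dimensions are 1-based in R torch
    x$view(c(batch_size, -1))
  }
)

flatten()(torch_randn(10, 3, 4))$shape
#> [1] 10 12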


dfalbel commented Oct 18, 2021

You could do something like this, by implementing a custom module.
Do you want to file a PR so it's available by default in torch? Most modules are implemented similarly, e.g.: https://github.com/mlverse/torch/blob/master/R/nn-linear.R#L59-L84

library(torch)

nn_flatten <- nn_module(
  initialize = function(start_dim = 1, end_dim = -1) {
    self$start_dim <- start_dim
    self$end_dim <- end_dim
  },
  forward = function(x) {
    torch_flatten(x, start_dim = self$start_dim, end_dim = self$end_dim)
  }
)

model <- nn_sequential(
  nn_linear(3, 1),
  nn_flatten(1, 2)
)

model(torch_randn(10, 3))
#> torch_tensor
#>  0.4452
#>  0.7583
#>  0.7753
#>  0.9112
#>  0.0293
#> -0.3305
#>  0.5906
#> -0.3069
#> -0.0519
#>  0.1821
#> [ CPUFloatType{10} ][ grad_fn = <ViewBackward> ]

Created on 2021-10-18 by the reprex package (v2.0.0)
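(Note: dimension indices in R torch are 1-based, so nn_flatten(1, 2) above corresponds to torch.nn.Flatten(0, 1) in the Python example.)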


gsgxnet commented Oct 18, 2021

@dfalbel,
thanks, yes, I arrived at a very similar solution after some trial and error. I also looked at your implementations of other modules and found that they are, as you said, plain R code; in principle much simpler than I expected.
I would file a PR, but two obstacles remain:

  • While I know how to clone a GitHub repo and how to work on a separate local branch of it, I am still not sure how to create a PR. Could you point me to a good tutorial on doing it the right way?
  • PyTorch also comes with an Unflatten module, which is much more complicated, and I think only the two together would make a good PR (a rough sketch follows below).

Beginner's problems aside, I would very much like to make the small contributions I feel I can.
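Here is a rough, untested sketch of such an nn_unflatten. The name and the (dim, unflattened_size) signature mirror PyTorch's nn.Unflatten and are my own assumptions; the body uses only torch_reshape() and assumes a positive, 1-based dim:

library(torch)

nn_unflatten <- nn_module(
  initialize = function(dim, unflattened_size) {
    self$dim <- dim
    self$unflattened_size <- unflattened_size
  },
  forward = function(x) {
    shape <- x$shape
    new_shape <- c(
      shape[seq_len(self$dim - 1)],   # dims before the target dim
      self$unflattened_size,          # sizes replacing the target dim
      shape[-seq_len(self$dim)]       # dims after the target dim
    )
    torch_reshape(x, new_shape)
  }
)

# e.g. split a flat 6-element dimension back into 2 x 3
nn_unflatten(2, c(2, 3))(torch_randn(10, 6))$shape
#> [1] 10  2  3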
