Feature request: torch.nn.Flatten() equivalent module for nn_sequential #716
Comments
Okay, I finally found a hint for implementing the class myself in the mlverse documentation:

You could do something like this, by implementing a custom module:

library(torch)

nn_flatten <- nn_module(
  initialize = function(start_dim = 1, end_dim = -1) {
    self$start_dim <- start_dim
    self$end_dim <- end_dim
  },
  forward = function(x) {
    torch_flatten(x, start_dim = self$start_dim, end_dim = self$end_dim)
  }
)

model <- nn_sequential(
  nn_linear(3, 1),
  nn_flatten(1, 2)
)

model(torch_randn(10, 3))
#> torch_tensor
#> 0.4452
#> 0.7583
#> 0.7753
#> 0.9112
#> 0.0293
#> -0.3305
#> 0.5906
#> -0.3069
#> -0.0519
#> 0.1821
#> [ CPUFloatType{10} ][ grad_fn = <ViewBackward> ]

Created on 2021-10-18 by the reprex package (v2.0.0)
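The reprex works because `torch_flatten` collapses the dimensions from `start_dim` through `end_dim` into a single one: `nn_linear(3, 1)` maps the `(10, 3)` input to `(10, 1)`, and flattening dims 1..2 (R's 1-indexed dims) yields a length-10 vector. As a rough illustration of that shape arithmetic only, here is a numpy sketch (not the torch implementation; note Python dims are 0-indexed where R torch's are 1-indexed):

```python
import numpy as np

def flatten(x, start_dim=0, end_dim=-1):
    """Collapse dims start_dim..end_dim (inclusive) into one,
    mirroring the shape semantics of torch_flatten / torch.flatten."""
    if end_dim < 0:
        end_dim += x.ndim
    shape = list(x.shape)
    merged = 1
    for d in shape[start_dim:end_dim + 1]:
        merged *= d
    return x.reshape(shape[:start_dim] + [merged] + shape[end_dim + 1:])

# (10, 1) flattened over its two dims (R dims 1..2, Python dims 0..1)
# becomes a length-10 vector, matching the reprex output above:
out = flatten(np.zeros((10, 1)), start_dim=0, end_dim=1)
print(out.shape)  # (10,)
```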
@dfalbel, besides these beginner problems, I would very much like to contribute the small improvements I feel I can make.
I wanted to transfer the PyTorch Basics examples to R torch:
https://pytorch.org/tutorials/beginner/pytorch_with_examples.html#nn-module

Within those examples, I am not able to find an equivalent of the nn_flatten module in the R torch library. According to https://stackoverflow.com/questions/53953460/how-to-flatten-input-in-nn-sequential-in-pytorch, torch.nn.Flatten() could be implemented in PyTorch as a custom class, e.g. in earlier versions, before torch.nn.Flatten() existed there. Could this be done in R torch as well? If so, how?
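For context, the Stack Overflow recipe referenced above defines a tiny module that keeps the batch (first) dimension and collapses everything else, which in PyTorch subclasses `nn.Module` and returns `x.view(x.size(0), -1)`. A hypothetical numpy stand-in sketching just that behavior (not the real PyTorch class):

```python
import numpy as np

class Flatten:
    """Stand-in for the pre-nn.Flatten PyTorch recipe: keep the batch
    (first) dimension and collapse all remaining dimensions into one."""
    def __call__(self, x):
        return x.reshape(x.shape[0], -1)

batch = np.zeros((32, 1, 28, 28))  # e.g. a batch of MNIST-sized images
print(Flatten()(batch).shape)      # (32, 784)
```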