Skip to content

Commit

Permalink
Merge pull request #678 from mohamed-180/master
Browse files Browse the repository at this point in the history
Implement [ for nn_sequential modules.
  • Loading branch information
dfalbel committed Aug 20, 2021
2 parents 41697d5 + 5fca174 commit e2ea0c4
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 0 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Expand Up @@ -29,13 +29,15 @@ S3method("==",torch_tensor)
S3method(">",torch_tensor)
S3method(">=",torch_tensor)
S3method("[",dataset)
S3method("[",nn_sequential)
S3method("[",torch_tensor)
S3method("[<-",torch_tensor)
S3method("[[",R7)
S3method("[[",enum_env)
S3method("[[",nn_Module)
S3method("[[",nn_module)
S3method("[[",nn_module_list)
S3method("[[",nn_sequential)
S3method("[[",script_method)
S3method("[[",script_module)
S3method("[[<-",R7)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
@@ -1,6 +1,7 @@
# torch (development version)

- Additional info is showed when printing tensors like if it requires grad and the grad fn. (#668, #669, #673, @mohamed-180)
- We can now subset `nn_sequential` modules using `[`. (#678, @mohamed-180)

# torch 0.5.0

Expand Down
13 changes: 13 additions & 0 deletions R/nn.R
Expand Up @@ -626,6 +626,19 @@ length.nn_sequential <- function(x) {
length(x$.__enclos_env__$private$modules_)
}

#' @export
`[[.nn_sequential` <- function(x, y) {
if (rlang::is_scalar_integerish(y))
x$.__enclos_env__$private$modules_[[y]]
else
NextMethod("[[")
}

#' @export
`[.nn_sequential` <- function(x, y) {
nn_sequential(!!!lapply(y, function(i) x[[i]]))
}

#' Holds submodules in a list.
#'
#' [nn_module_list] can be indexed like a regular R list, but
Expand Down
21 changes: 21 additions & 0 deletions tests/testthat/test-nn.R
Expand Up @@ -555,3 +555,24 @@ test_that("calling to doesn't modify the requires_grad attribute of a parameter"
expect_true(!x$weight$requires_grad)

})

test_that("we can subset `nn_sequential`", {


x <- nn_sequential(
nn_relu(),
nn_tanh(),
nn_relu6(),
nn_relu(),
nn_tanh()
)

expect_true(inherits(x[[1]], "nn_relu"))
expect_true(inherits(x[[3]], "nn_relu6"))

y <- x[2:4]
expect_true(inherits(y, "nn_sequential"))
expect_true(inherits(y[[1]], "nn_tanh"))
expect_true(inherits(y[[2]], "nn_relu6"))

})

0 comments on commit e2ea0c4

Please sign in to comment.