Skip to content

Commit

Permalink
Raise error when comparing a device with something that's not a dtype. (
Browse files Browse the repository at this point in the history
  • Loading branch information
dfalbel committed Aug 14, 2023
1 parent c45fe87 commit 28a5b09
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
3 changes: 3 additions & 0 deletions R/dtype.R
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ torch_qint32 <- function() torch_dtype$new(cpp_torch_qint32())

#' @export
`==.torch_dtype` <- function(e1, e2) {
if (!is_torch_dtype(e1) || !is_torch_dtype(e2)) {
runtime_error("One of the objects is not a dtype. Comparison is not possible.")
}
cpp_dtype_to_string(e1$ptr) == cpp_dtype_to_string(e2$ptr)
}

Expand Down
9 changes: 9 additions & 0 deletions tests/testthat/test-dtype.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,12 @@ test_that("can set select devices using strings", {
}

})

test_that("error when comparing dtypes", {

expect_error(
NULL == torch_float64(),
"not a dtype"
)

})

0 comments on commit 28a5b09

Please sign in to comment.