diff --git a/R/transforms-tensor.R b/R/transforms-tensor.R index 79f5221e..fae1414e 100644 --- a/R/transforms-tensor.R +++ b/R/transforms-tensor.R @@ -340,7 +340,7 @@ transform_adjust_contrast.torch_tensor <- function(img, contrast_factor) { #' @export transform_adjust_hue.torch_tensor <- function(img, hue_factor) { - if (hue_factor < 0.5 || hue_factor > 0.5) + if (hue_factor < -0.5 || hue_factor > 0.5) value_error("hue_factor must be between -0.5 and 0.5.") check_img(img) diff --git a/tests/testthat/test-transforms.R b/tests/testthat/test-transforms.R index 27ac927b..e38f652a 100644 --- a/tests/testthat/test-transforms.R +++ b/tests/testthat/test-transforms.R @@ -185,3 +185,15 @@ test_that("linear transformation", { expect_equal(dim(out), c(3, 24, 32)) }) +test_that("adjust hue", { + + hue_factor <- c(-0.45, -0.25, 0.0, 0.25, 0.45) + x <- torch::torch_rand(3, 24, 32) + + for (f in hue_factor) { + out <- transform_adjust_hue(x, f) + expect_equal(dim(out), dim(x)) + } + +}) +