Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allows setting the dtype using a string. #1045

Merged
merged 2 commits into from May 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Expand Up @@ -15,6 +15,7 @@
- `nnf_gelu()` and `nn_gelu()` gained the `approximate` argument. (#1043)
- Implemented `!=` for torch devices. (#1042)
- `load_state_dict()` for optimizers now default to cloning the tensors in the state dict, so they don't keep references to objects in the dict. (#1041)
- Allows setting the dtype with a string. (#1045)

# torch 0.10.0

Expand Down
49 changes: 33 additions & 16 deletions R/tensor.R
Expand Up @@ -67,24 +67,41 @@ Tensor <- R7Class(
cpp_tensor_numel(self$ptr)
},
to = function(dtype = NULL, device = NULL, other = NULL, non_blocking = FALSE,
              copy = FALSE, memory_format = NULL) {
  # Converts this tensor to a different dtype and/or device.
  #
  # dtype: target dtype (a torch_dtype or a string such as "double"); when
  #   NULL the tensor's current dtype is kept.
  # device: target device; when NULL the tensor stays on its current device.
  # other: a tensor whose dtype and device are copied; mutually exclusive
  #   with `dtype` and `device`.
  # non_blocking, copy, memory_format: forwarded to the underlying `_to` call.

  has_device <- !is.null(device)
  has_dtype <- !is.null(dtype)
  has_other <- !is.null(other)

  if (has_other) {
    # `other` fully determines the target; can't also pass device or dtype.
    if (has_device || has_dtype) {
      cli::cli_abort("Had {.arg other} but {.arg device} or {.arg dtype} are non {.val NULL}")
    }

    return(private$`_to`(other = other, non_blocking = non_blocking, copy = copy))
  }

  # Default to the current dtype so a device-only move preserves the dtype.
  if (!has_dtype) {
    dtype <- self$dtype
  }

  if (has_device) {
    private$`_to`(
      dtype = dtype,
      device = device,
      non_blocking = non_blocking,
      copy = copy,
      memory_format = memory_format
    )
  } else {
    private$`_to`(
      dtype = dtype,
      non_blocking = non_blocking,
      copy = copy,
      memory_format = memory_format
    )
  }
},
bool = function(memory_format = torch_preserve_format()) {
self$to(torch_bool(), memory_format = memory_format)
Expand Down
11 changes: 10 additions & 1 deletion inst/include/lantern/lantern.h
Expand Up @@ -2581,6 +2581,15 @@ HOST_API void lantern_cpu_set_rng_state (void* state)

}

// Host-side wrapper that resolves a dtype name string (e.g. "float32",
// "double") to a Dtype handle by forwarding to the dynamically loaded
// _lantern_Dtype_from_string symbol. Checks the lantern library is loaded
// first and routes errors through the standard host handler, matching the
// other HOST_API wrappers in this header.
LANTERN_API void* (LANTERN_PTR _lantern_Dtype_from_string) (void* dtype_str);
HOST_API void* lantern_Dtype_from_string (void* dtype_str)
{
LANTERN_CHECK_LOADED
void* ret = _lantern_Dtype_from_string(dtype_str);
LANTERN_HOST_HANDLER;
return ret;
}

/* Autogen Headers -- Start */
LANTERN_API void* (LANTERN_PTR _lantern__cast_byte_tensor_bool)(void* self, void* non_blocking);
HOST_API void* lantern__cast_byte_tensor_bool(void* self, void* non_blocking) { LANTERN_CHECK_LOADED void* ret = _lantern__cast_byte_tensor_bool(self, non_blocking); LANTERN_HOST_HANDLER return ret; }
Expand Down Expand Up @@ -10156,7 +10165,7 @@ LOAD_SYMBOL(_lantern_jit_execute);
LOAD_SYMBOL(_lantern_jit_operator_info);
LOAD_SYMBOL(_lantern_jit_all_schemas_for);
LOAD_SYMBOL(_lantern_function_schema_list_at);

LOAD_SYMBOL(_lantern_Dtype_from_string);
/* Autogen Symbols -- Start */
LOAD_SYMBOL(_lantern__cast_byte_tensor_bool)
LOAD_SYMBOL(_lantern__cast_char_tensor_bool)
Expand Down
4 changes: 3 additions & 1 deletion src/codegen.cpp
Expand Up @@ -66,7 +66,9 @@ std::string cpp_arg_to_torch_type(SEXP obj,
if (e_scalar_type && Rf_inherits(obj, "torch_dtype")) {
return "ScalarType";
}

if (e_scalar_type && is_character) {
return "ScalarType";
}
if (e_scalar_type && is_null) {
return "ScalarType";
}
Expand Down
11 changes: 10 additions & 1 deletion src/lantern/include/lantern/lantern.h
Expand Up @@ -2581,6 +2581,15 @@ HOST_API void lantern_cpu_set_rng_state (void* state)

}

// Host-side wrapper that resolves a dtype name string (e.g. "float32",
// "double") to a Dtype handle by forwarding to the dynamically loaded
// _lantern_Dtype_from_string symbol. Checks the lantern library is loaded
// first and routes errors through the standard host handler, matching the
// other HOST_API wrappers in this header.
LANTERN_API void* (LANTERN_PTR _lantern_Dtype_from_string) (void* dtype_str);
HOST_API void* lantern_Dtype_from_string (void* dtype_str)
{
LANTERN_CHECK_LOADED
void* ret = _lantern_Dtype_from_string(dtype_str);
LANTERN_HOST_HANDLER;
return ret;
}

/* Autogen Headers -- Start */
LANTERN_API void* (LANTERN_PTR _lantern__cast_byte_tensor_bool)(void* self, void* non_blocking);
HOST_API void* lantern__cast_byte_tensor_bool(void* self, void* non_blocking) { LANTERN_CHECK_LOADED void* ret = _lantern__cast_byte_tensor_bool(self, non_blocking); LANTERN_HOST_HANDLER return ret; }
Expand Down Expand Up @@ -10156,7 +10165,7 @@ LOAD_SYMBOL(_lantern_jit_execute);
LOAD_SYMBOL(_lantern_jit_operator_info);
LOAD_SYMBOL(_lantern_jit_all_schemas_for);
LOAD_SYMBOL(_lantern_function_schema_list_at);

LOAD_SYMBOL(_lantern_Dtype_from_string);
/* Autogen Symbols -- Start */
LOAD_SYMBOL(_lantern__cast_byte_tensor_bool)
LOAD_SYMBOL(_lantern__cast_char_tensor_bool)
Expand Down
46 changes: 46 additions & 0 deletions src/lantern/src/Dtype.cpp
Expand Up @@ -25,6 +25,52 @@ LANTERN_DTYPE_FUN(qint32, kQInt32)
LANTERN_DTYPE_FUN(cfloat, kComplexFloat)
LANTERN_DTYPE_FUN(cdouble, kComplexDouble)
LANTERN_DTYPE_FUN(byte, kByte)

// Translates a dtype name into a lantern Dtype handle.
//
// dtype_str: lantern string wrapper holding a name such as "float32",
//   "double" or "int64". Aliases follow torch naming conventions
//   ("half" == "float16", "long" == "int64", ...).
// Returns an owned Dtype created via make_raw::Dtype. Unknown names throw
// std::runtime_error, which the LANTERN_FUNCTION_* macros route to the
// caller's error handling.
void* _lantern_Dtype_from_string (void* dtype_str) {
  LANTERN_FUNCTION_START
  auto str = from_raw::string(dtype_str);
  // Resolve the name inside an immediately-invoked lambda so `dtype` can be
  // initialized in a single expression.
  auto dtype = [&str] () {
    if (str == "float" || str == "float32") {
      return torch::kFloat32;
    } else if (str == "float64" || str == "double") {
      return torch::kFloat64;
    } else if (str == "float16" || str == "half") {
      return torch::kFloat16;
    } else if (str == "bfloat16") {
      return at::kBFloat16;
    } else if (str == "complex32" || str == "chalf") {
      return torch::kComplexHalf;
    } else if (str == "complex64" || str == "cfloat") {
      return torch::kComplexFloat;
    } else if (str == "complex128" || str == "cdouble") {
      return torch::kComplexDouble;
    } else if (str == "uint8" || str == "byte") {
      // "byte" is accepted for consistency with the
      // LANTERN_DTYPE_FUN(byte, kByte) accessor defined above.
      return torch::kByte;
    } else if (str == "int8") {
      return torch::kInt8;
    } else if (str == "int16" || str == "short") {
      return torch::kInt16;
    } else if (str == "int32" || str == "int") {
      return torch::kInt32;
    } else if (str == "int64" || str == "long") {
      return torch::kInt64;
    } else if (str == "bool") {
      return torch::kBool;
    } else if (str == "quint8") {
      return torch::kQUInt8;
    } else if (str == "qint8") {
      return torch::kQInt8;
    } else if (str == "qint32") {
      return torch::kQInt32;
    } else if (str == "quint4x2") {
      return torch::kQUInt4x2;
    } else {
      throw std::runtime_error("Error unknown type " + str);
    }
  }();
  return make_raw::Dtype(dtype);
  LANTERN_FUNCTION_END
}

void* _lantern_Dtype_type(void *dtype) {
LANTERN_FUNCTION_START
Expand Down
5 changes: 5 additions & 0 deletions src/torch_api.cpp
Expand Up @@ -521,6 +521,11 @@ XPtrTorchDtype from_sexp_dtype(SEXP x) {
auto out = Rcpp::as<Rcpp::XPtr<XPtrTorchDtype>>(x);
return XPtrTorchDtype(out->get_shared());
}

if (TYPEOF(x) == STRSXP) {
auto dtype_string = Rcpp::as<XPtrTorchstring>(x);
return XPtrTorchDtype(lantern_Dtype_from_string(dtype_string.get()));
}

if (TYPEOF(x) == NILSXP) {
return XPtrTorchDtype();
Expand Down
28 changes: 28 additions & 0 deletions tests/testthat/test-dtype.R
Expand Up @@ -38,3 +38,31 @@ test_that("Default dtype", {

torch_set_default_dtype(torch_float())
})

test_that("can set dtypes using strings", {
  # Each string alias must produce the same dtype as the corresponding
  # torch_* dtype constructor.
  dtypes <- list(
    "float32" = torch_float32(),
    "float" = torch_float(),
    "float64" = torch_float64(),
    "double" = torch_double(),
    "float16" = torch_float16(),
    "half" = torch_half(),
    "uint8" = torch_uint8(),
    "int8" = torch_int8(),
    "int16" = torch_int16(),
    "short" = torch_short(),
    "int32" = torch_int32(),
    "int" = torch_int(),
    "int64" = torch_int64(),
    "long" = torch_long(),
    "bool" = torch_bool()
  )

  for (i in seq_along(dtypes)) {
    x <- torch_empty(10, 10, dtype = names(dtypes)[i])
    y <- torch_empty(10, 10, dtype = dtypes[[i]])

    # Bug fix: the original asserted `x$device == y$device`, which is
    # vacuously true and never exercised the dtype-from-string path.
    expect_true(x$dtype == y$dtype)
  }

})
9 changes: 9 additions & 0 deletions tests/testthat/test-tensor.R
Expand Up @@ -505,4 +505,13 @@ test_that("can make a byte tensor from a raw vector", {

expect_equal(as.array(ten), x)
expect_equal(rawToChar(as.array(ten)), "hello world")
})

test_that("to can change both device and dtype", {
  # Passing string forms for both arguments in a single $to() call should
  # update the dtype and move the tensor at the same time.
  original <- torch_randn(10, 10)
  converted <- original$to(device = "meta", dtype = "double")

  expect_true(converted$device == torch_device("meta"))
  expect_true(converted$dtype == torch_double())
})
2 changes: 1 addition & 1 deletion tools/create-decls.R
Expand Up @@ -30,7 +30,7 @@ make_load_symbols <- function(decls) {

decls <- readr::read_lines(
"
char* _lantern_Tensor_data_ptr_byte (void *self)
void* _lantern_Dtype_from_string (void* dtype_str)
"
)

Expand Down