
Unwrapping arguments passed from R #454

Merged
dfalbel merged 39 commits into master from refactor-argument-wrapping on Feb 12, 2021

Conversation

@dfalbel (Member) commented on Jan 26, 2021

This PR allows unwrapping arguments passed from R without relying on autogenerated code.

It will eventually allow us to remove the argument_to_torch_type function, which is the main source of overhead when calling torch functions (a sketch of how the current conversion path works follows the excerpt below):

torch/R/codegen-utils.R, lines 34 to 169 at 480a011:

argument_to_torch_type <- function(obj, expected_types, arg_name) {
if (is.name(obj))
return(NULL)
if (any(arg_name == c("index", "indices", "dims")) && any("Tensor" == expected_types) && is_torch_tensor(obj))
return(list(as_1_based_tensor(obj), "Tensor"))
if (any("Tensor" == expected_types) && is_torch_tensor(obj))
return(list(obj, "Tensor"))
if (any("Scalar" == expected_types) && is_torch_scalar(obj))
return(list(obj$ptr, "Scalar"))
if (any("DimnameList" == expected_types) && is_torch_dimname_list(obj))
return(list(obj$ptr, "DimnameList"))
if (arg_name == "indices" && any("TensorList" == expected_types) && is_torch_tensor_list(obj))
return(list(as_1_based_tensor_list(obj)$ptr, "TensorList"))
if (any("TensorList" == expected_types) && is_torch_tensor_list(obj))
return(list(obj$ptr, "TensorList"))
if (any("TensorOptions" == expected_types) && is_torch_tensor_options(obj))
return(list(obj$ptr, "TensorOptions"))
if (any("MemoryFormat" == expected_types) && is_torch_memory_format(obj))
return(list(obj$ptr, "MemoryFormat"))
if (any("ScalarType" == expected_types) && is_torch_dtype(obj))
return(list(obj$ptr, "ScalarType"))
if (any("ScalarType" == expected_types) && is.null(obj))
return(list(cpp_nullopt(), "ScalarType"))
if (any("Scalar" == expected_types) && is_scalar_atomic(obj))
return(list(torch_scalar(obj)$ptr, "Scalar"))
if (arg_name == "index" && any("Tensor" == expected_types) && is.atomic(obj) && !is.null(obj))
return(list(torch_tensor(obj - 1, dtype = torch_long())$ptr, "Tensor"))
if (any("Tensor" == expected_types) && is.atomic(obj) && !is.null(obj))
return(list(torch_tensor(obj)$ptr, "Tensor"))
if (any("DimnameList" == expected_types) && is.character(obj))
return(list(torch_dimname_list(obj)$ptr, "DimnameList"))
if (any("IntArrayRef" == expected_types) && (is.numeric(obj) || is.list(obj)) && arg_name %in% c("dims", "dims_self", "dims_other"))
return(list(as_1_based_dim(obj), "IntArrayRef"))
if (any("IntArrayRef" == expected_types) && any("DimnameList" == expected_types) && is.numeric(obj))
return(list(as_1_based_dim(obj), "IntArrayRef"))
if (any("IntArrayRef" == expected_types) && is.numeric(obj))
return(list(as.integer(obj), "IntArrayRef"))
if (any("IntArrayRef" == expected_types) && is.list(obj))
return(list(as.integer(obj), "IntArrayRef"))
if (any("ArrayRef<double>" == expected_types) && is.numeric(obj))
return(list(obj, "ArrayRef<double>"))
if (any("IntArrayRef" == expected_types) && is.null(obj))
return(list(NULL, "IntArrayRef"))
if (any("ArrayRef<double>" == expected_types) && is.null(obj))
return(list(NULL, "ArrayRef<double>"))
if (any("int64_t" == expected_types) && is.numeric(obj) && length(obj) == 1 && any(arg_name == c("dim", "dim0", "dim1", "dim2", "start_dim", "end_dim", "index")))
return(list(as_1_based_dim(obj), "int64_t"))
if (any("int64_t" == expected_types) && is.numeric(obj) && length(obj) == 1)
return(list(as.integer(obj), "int64_t"))
if (any("bool" == expected_types) && is.logical(obj) && length(obj) == 1)
return(list(obj, "bool"))
if (any("double" == expected_types) && is.numeric(obj) && length(obj) == 1)
return(list(as.double(obj), "double"))
if (any("std::string" == expected_types) && is.character(obj))
return(list(obj, "std::string"))
if (any(c("std::array<bool,4>", "std::array<bool,3>", "std::array<bool,2>") %in% expected_types) && is.logical(obj))
return(list(obj, paste0("std::array<bool,", length(obj), ">")))
if (any("TensorOptions" == expected_types) && is.list(obj))
return(list(as_torch_tensor_options(obj)$ptr, "TensorOptions"))
if (arg_name == "indices" && any("TensorList" == expected_types) && is.list(obj))
return(list(torch_tensor_list(lapply(obj, function(x) x$sub(1L, 1L)))$ptr, "TensorList"))
if (any("TensorList" == expected_types) && is.list(obj))
return(list(torch_tensor_list(obj)$ptr, "TensorList"))
if (any("MemoryFormat" == expected_types) && is.null(obj))
return(list(cpp_nullopt(), "MemoryFormat"))
if (any("Generator" == expected_types) && is_torch_generator(obj))
return(list(obj$ptr, "Generator"))
if (any("Generator" == expected_types) && is.null(obj))
return(list(.generator_null$ptr, "Generator"))
if (any("Scalar" == expected_types) && is.null(obj))
return(list(cpp_nullopt(), "Scalar"))
if (any("int64_t" == expected_types) && is.null(obj))
return(list(NULL, "int64_t"))
if (any("Tensor" == expected_types) && length(obj) == 0 && is.list(obj))
return(list(cpp_tensor_undefined(), "Tensor"))
if (any("Tensor" == expected_types) && is.null(obj))
return(list(cpp_tensor_undefined(), "Tensor"))
if (any("double" == expected_types) && is.null(obj))
return(list(NULL, "double"))
if (any("Device" == expected_types) && is_torch_device(obj))
return(list(obj$ptr, "Device"))
if (any("Device" == expected_types) && is.character(obj))
return(list(torch_device(obj)$ptr, "Device"))
if (any("TensorList" == expected_types) && is.numeric(obj))
return(list(torch_tensor_list(list(torch_tensor(obj)))$ptr, "TensorList"))
if (any("TensorList" == expected_types) && is_torch_tensor(obj))
return(list(torch_tensor_list(list(obj))$ptr, "TensorList"))
if (any("Scalar" == expected_types) && is_torch_tensor(obj))
return(list(torch_scalar(obj$item())$ptr, "Scalar"))
stop("Can't convert argument", call.=FALSE)
}
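
For context, here is a minimal sketch, not the actual generated code, of how the autogenerated wrappers route every argument through argument_to_torch_type (an internal, unexported function) before reaching C++. The wrapper name and the expected_types table below are hypothetical; the point is that for a call like torch_mm(x, w) both tensors walk the if-chain above on every single call, which is where the per-call R overhead comes from:

torch_mm_wrapper_sketch <- function(self, mat2) {
  args <- list(self = self, mat2 = mat2)
  expected_types <- list(self = "Tensor", mat2 = "Tensor")
  converted <- lapply(names(args), function(nm) {
    # every argument is matched against the if-chain above until a type fits
    argument_to_torch_type(args[[nm]], expected_types[[nm]], nm)
  })
  # the resulting pointers and type tags are then handed to the C++ entry point
  converted
}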

TODO:

  • int64_t index autocasting (see the sketch below)
  • IndexTensorList autocasting
  • remove all other transformations of obj in argument_to_torch_type.
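
To make the index autocasting TODO concrete: R is 1-based while libtorch indices are 0-based, so index-like arguments currently get shifted on the R side (see the index/indices branches above). A hedged sketch of that conversion, using a hypothetical helper name, is below; the TODO is to move this shift into C++:

to_zero_based_index <- function(obj) {
  if (is_torch_tensor(obj)) {
    obj$sub(1L)                                    # tensor index: subtract 1 elementwise
  } else {
    torch_tensor(obj - 1, dtype = torch_long())    # atomic vector index: shift then wrap
  }
}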

This allows us to gain some speedups, especially with small models, where the R overhead can be significantly larger.

A simple benchmark gives us a roughly 37% reduction in median time for small ops:

library(torch)

x <- torch_randn(10, 10)
w <- torch_randn(10, 10)

bench::mark(
  a = torch_mm(x, w)
)

CRAN torch:

# A tibble: 1 x 13
  expression     min  median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result memory time  gc   
  <bch:expr> <bch:t> <bch:t>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm> <list> <list> <lis> <lis>
1 a           54.6µs  62.9µs    14855.    19.9KB     14.8  7016     7      472ms <trch… <Rpro… <bch… <tib…

This PR:

# A tibble: 1 x 13
  expression     min  median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result memory time  gc   
  <bch:expr> <bch:t> <bch:t>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm> <list> <list> <lis> <lis>
1 a           35.3µs  39.1µs    23586.    4.98KB     9.44  9996     4      424ms <trch… <Rpro… <bch… <tib…
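
Reading the medians off the two runs above, the improvement works out as follows (numbers copied from the benchmark outputs, so treat them as approximate):

1 - 39.1 / 62.9   # ~0.38, i.e. roughly a 37-38% reduction in median time
23586 / 14855     # ~1.59, i.e. about 1.6x more iterations per second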

This PR also makes it possible to rewrite the dispatcher itself in C++ in the future.

@dfalbel added the lantern label (Use this label if your PR affects lantern so it's built in the CI) on Feb 4, 2021
@dfalbel merged commit a6e5bca into master on Feb 12, 2021
@dfalbel deleted the refactor-argument-wrapping branch on Feb 12, 2021