Skip to content

Commit

Permalink
implemented 'array' for Tensor objects, added 'sweep' method for appl…
Browse files Browse the repository at this point in the history
…ying operations across dimensions and associated unit tests, this also led to requiring additional paste0 call when dealing with multiple arguments
  • Loading branch information
Determan authored and Determan committed Dec 19, 2017
1 parent fdac4d0 commit 3ca0aef
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 3 deletions.
60 changes: 57 additions & 3 deletions R/Tensor.R
Expand Up @@ -195,13 +195,27 @@ Tensor <- R6Class("Tensor",
}
private$.initializer = FALSE
},
"array" = {
self$tensor = initializer
if(missing(shape)){
private$.shape = private$.get_shape(initializer)
}
private$.initializer = FALSE
},
"numeric" = {
self$tensor = initializer
if(missing(shape)){
private$.shape = private$.get_shape(initializer)
}
private$.initializer = FALSE
},
"integer" = {
self$tensor = initializer
if(missing(shape)){
private$.shape = private$.get_shape(initializer)
}
private$.initializer = FALSE
},
{
if(inherits(initializer, "gpuMatrix") | inherits(initializer, "vclMatrix")){

Expand Down Expand Up @@ -1250,6 +1264,31 @@ Tensor <- R6Class("Tensor",
invisible(self)
},

sweep = function(MARGIN, STATS, FUN, name = NA){
name = private$.createName(name)

args = c(paste0("MARGIN = ", MARGIN),
paste0("FUN = '", FUN, "'"))

x_tensor = if(!is(STATS, "Tensor")) Tensor$new(STATS) else STATS

# function is single input operation, so take last node
input_shapes = if(length(self$graph) > 0) tail(self$graph, 1)[[1]]$output_shapes else list()
# function doesn't change shape
output_shapes = input_shapes

Node$new(self,
ops = list(Operation$new("sweep", args = args)),
name = name,
input_nodes = if(length(self$graph) > 0) tail(self$graph, 1) else list(),
output_nodes = list(),
input_tensors = list("STATS" = x_tensor),
input_shapes = input_shapes,
output_shapes = output_shapes)

invisible(self)
},

compute = function(feed_list = NA){
if(private$.initializer){

Expand Down Expand Up @@ -1344,11 +1383,16 @@ Tensor <- R6Class("Tensor",

}else{
if(!is.null(args)){
f = parse(text = paste(func, '(output,', args, ')'))
f = parse(text = paste(func, '(output,', paste0(args, collapse = ", "), ')'))
}else{
f = parse(text = paste(func, '(output)'))
}

# print('args')
# print(args)
# print('expression')
# print(f)

output = eval(f)

}
Expand Down Expand Up @@ -1381,9 +1425,16 @@ Tensor <- R6Class("Tensor",
if(!is.null(args)){
# print('args not null')
if(is.na(op$order)){
f = parse(text = paste(prefix, '(output,', inputs, ", ", args, ')'))
f = parse(text = paste(prefix, '(output,',
inputs, ", ",
paste0(args, collapse = ", "),
')'))
}else{
f = parse(text = paste(prefix, '(', inputs, ", output,", args, ')'))
f = parse(text = paste(prefix, '(',
inputs,
", output,",
paste0(args, collapse = ", "),
')'))
}
}else{

Expand All @@ -1398,6 +1449,8 @@ Tensor <- R6Class("Tensor",
}
}

# print('args')
# print(args)
# print('expression')
# print(f)

Expand Down Expand Up @@ -1540,6 +1593,7 @@ Tensor <- R6Class("Tensor",
"integer" = length(value),
"numeric" = length(value),
"matrix" = dim(value),
"array" = dim(value),
{
if(is(value, "gpuMatrix") | is(value, "vclMatrix")){
dim(value)
Expand Down
61 changes: 61 additions & 0 deletions tests/testthat/test_apply.R
@@ -0,0 +1,61 @@
library(lazytensor)
context("Apply Operations")

# basic matrix - 2D
mat <- matrix(1:12, nrow = 4)

## array - 3D
A <- array(1:24, dim = 4:2)


test_that("Sweep Matrix", {

# columns
baseC <- sweep(mat, 1, seq(4), FUN = '-')
# rows
baseR <- sweep(mat, 2, seq(3), FUN = '-')

# Tensor computations
A_tensor <- Tensor$new(mat)
tensorC <- A_tensor$sweep(1, seq(4), FUN = '-')$compute()
A_tensor$drop()
tensorR <- A_tensor$sweep(2, seq(3), FUN = '-')$compute()

expect_equal(tensorC, baseC, tolerance=.Machine$double.eps ^ 0.5,
info="columnwise sweep elements not equivalent",
check.attributes=FALSE)
expect_equal(tensorR, baseR, tolerance=.Machine$double.eps ^ 0.5,
info="rowwise sweep elements not equivalent",
check.attributes=FALSE)
})

test_that("Sweep Array (3D)", {

# columns
baseC <- sweep(A, 1, seq(4), FUN = '-')
# rows
baseR <- sweep(A, 2, seq(3), FUN = '-')
# N dimension
baseN <- sweep(A, 3, seq(2), FUN = '-')

# Tensor computations
A_tensor <- Tensor$new(A)
tensorC <- A_tensor$sweep(1, seq(4), FUN = '-')$compute()
A_tensor$drop()
tensorR <- A_tensor$sweep(2, seq(3), FUN = '-')$compute()
A_tensor$drop()
tensorN <- A_tensor$sweep(3, seq(2), FUN = '-')$compute()

expect_equal(tensorC, baseC, tolerance=.Machine$double.eps ^ 0.5,
info="columnwise sweep elements not equivalent",
check.attributes=FALSE)
expect_equal(tensorR, baseR, tolerance=.Machine$double.eps ^ 0.5,
info="rowwise sweep elements not equivalent",
check.attributes=FALSE)
expect_equal(tensorN, baseN, tolerance=.Machine$double.eps ^ 0.5,
info="N-wise sweep elements not equivalent",
check.attributes=FALSE)
})



0 comments on commit 3ca0aef

Please sign in to comment.