Skip to content

Commit

Permalink
Merge pull request #9 from jemc-savi/add/slice-op
Browse files Browse the repository at this point in the history
Add `Tensor.Op.Slice`.
  • Loading branch information
jemc committed Apr 3, 2023
2 parents 31d666f + 30c52f1 commit c7bb83d
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 3 deletions.
1 change: 1 addition & 0 deletions spec/Main.savi
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@
Spec.Run(Tensor.Op.Pack.Spec).new(env)
Spec.Run(Tensor.Op.Reshape.Spec).new(env)
Spec.Run(Tensor.Op.Select.Spec).new(env)
Spec.Run(Tensor.Op.Slice.Spec).new(env)
Spec.Run(Tensor.Op.Softmax.Spec).new(env)
])
70 changes: 70 additions & 0 deletions spec/Tensor.Op.Slice.Spec.savi
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
:class Tensor.Op.Slice.Spec
:is Spec
:const describes: "Tensor.Op.Slice"

:it "slices a contiguous portion of the input tensor"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.slice!("example"
g.const!("input"
Tensor(F64).from_array([
10, 11, 12, 13, 14
15, 16, 17, 18, 19
20, 21, 22, 23, 24
25, 26, 27, 28, 29

30, 31, 32, 33, 34
35, 36, 37, 38, 39
40, 41, 42, 43, 44
45, 46, 47, 48, 49

50, 51, 52, 53, 54
55, 56, 57, 58, 59
60, 61, 62, 63, 64
65, 66, 67, 68, 69
]).try_reshape([3, 4, 5])
)
g.const!("begin_indices", Tensor(I32).from_array([1, 2, 1]))
g.const!("output_shape", Tensor(I32).from_array([2, 2, 3]))
)
)

assert: result.shape_into_array == [2, 2, 3]
assert: result.as!(Tensor(F64)).into_array == [
41, 42, 43
46, 47, 48

61, 62, 63
66, 67, 68
]
))

:it "complains on session comput if the output shape is out of bounds"
_WithGraphHelper.run(@env, False) -> (g, session |
assert error: session.compute!(
g.slice!("example"
g.const!("input"
Tensor(F64).from_array([
10, 11, 12, 13, 14
15, 16, 17, 18, 19
20, 21, 22, 23, 24
25, 26, 27, 28, 29

30, 31, 32, 33, 34
35, 36, 37, 38, 39
40, 41, 42, 43, 44
45, 46, 47, 48, 49

50, 51, 52, 53, 54
55, 56, 57, 58, 59
60, 61, 62, 63, 64
65, 66, 67, 68, 69
]).try_reshape([3, 4, 5])
)
g.const!("begin_indices", Tensor(I32).from_array([1, 2, 1]))
g.const!("output_shape", Tensor(I32).from_array([3, 2, 3]))
// (since we started at index 1 in the first dimension,
// an output shape with size 3 in that dimension is out of bounds)
)
)
)
12 changes: 9 additions & 3 deletions src/Tensor.Graph.Helper.savi
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,6 @@
:fun ref cast_with_floating_point_truncation!(name, input, output_type)
Tensor.Op.Cast.new!(@graph, name, input, output_type, True)

:fun ref pack!(name, inputs, axis USize = 0)
Tensor.Op.Pack.new!(@graph, name, inputs, axis)

:fun ref reshape!(name, input, output_shape Array(USize))
Tensor.Op.Reshape.new!(@graph, name, input
@const!("\(name).new_shape"
Expand All @@ -63,6 +60,15 @@
)
)

///
// Fan-out/Fan-in Operations

:fun ref pack!(name, inputs, axis USize = 0)
Tensor.Op.Pack.new!(@graph, name, inputs, axis)

:fun ref slice!(name, input, begin_indices, output_shape)
Tensor.Op.Slice.new!(@graph, name, input, begin_indices, output_shape)

///
// Other Unary Operations

Expand Down
21 changes: 21 additions & 0 deletions src/Tensor.Op.Slice.savi
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
:: Create a new output tensor from a contiguous slice of the given input tensor.
::
:: Data will be taken from all dimensions of the input tensor, beginning at
:: the indices indicated by the `begin_indices` operand, and continuing up
:: through contiguous elements for the sizes indicated by the `output_shape`
:: operand. Both of these operands must therefore be vectors (rank-1 tensors)
:: with the same number of elements as the total dimension count of the input.
:struct box Tensor.Op.Slice
:is Tensor.Op
:fun non new!(graph Tensor.Graph, name
input
begin_indices
output_shape
)
@_new(graph.new_operation("Slice", name) -> (builder |
builder
.add_input(input)
.add_input(begin_indices)
.add_input(output_shape)
.finish!
))

0 comments on commit c7bb83d

Please sign in to comment.