Skip to content

Commit

Permalink
Merge pull request #18 from jemc-savi/add/op
Browse files Browse the repository at this point in the history
Add `Tensor.Comp.TensorDot.Outer`, a special case of tensordot.
  • Loading branch information
jemc committed Jun 12, 2023
2 parents 897a9aa + bdbe597 commit b755d53
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 0 deletions.
5 changes: 5 additions & 0 deletions spec/Main.savi
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
Spec.Run(Tensor.Op.SplitV.Spec).new(env)
Spec.Run(Tensor.Op.Square.Spec).new(env)

///
// Tensor.Comp specs

Spec.Run(Tensor.Comp.TensorDot.Outer.Spec).new(env)

///
// Tensor.Gen specs

Expand Down
66 changes: 66 additions & 0 deletions spec/Tensor.Comp.TensorDot.Outer.Spec.savi
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
:: Specs for `Tensor.Comp.TensorDot.Outer`, which contracts the last axis of
:: its first argument against the first axis of its second argument, keeping
:: all other ("outer") axes free in the result.
:class Tensor.Comp.TensorDot.Outer.Spec
:is Spec
:const describes: "Tensor.Comp.TensorDot.Outer"

:it "is equivalent to matrix multiplication for rank-2 tensors"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.tensordot_outer!("example"
// A is a 2x2 matrix.
g.const!("A", Tensor(F64).from_array([
1.0, 2.0
3.0, 4.0
]).try_reshape([2, 2]))
// B is a 2x2 matrix.
g.const!("B", Tensor(F64).from_array([
5.0, 6.0
7.0, 8.0
]).try_reshape([2, 2]))
)
)

// For rank-2 inputs, contracting the last axis of A with the first axis
// of B is exactly the standard matrix product A × B.
assert: result.as!(Tensor(F64)).into_array == [
1.0 * 5.0 + 2.0 * 7.0, 1.0 * 6.0 + 2.0 * 8.0 // Arow1⋅Bcol1, Arow1⋅Bcol2
3.0 * 5.0 + 4.0 * 7.0, 3.0 * 6.0 + 4.0 * 8.0 // Arow2⋅Bcol1, Arow2⋅Bcol2
]
))

:it "handles larger-rank tensors by applying to the outer axes"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.tensordot_outer!("example"
// A has shape [2, 2, 3]; its last axis (size 3) will be contracted.
g.const!("A", Tensor(F64).from_array([
1, 2, 3
4, 5, 6

7, 8, 9
10, 11, 12
]).try_reshape([2, 2, 3]))
// B has shape [3, 2, 2]; its first axis (size 3) will be contracted.
g.const!("B", Tensor(F64).from_array([
13, 14
15, 16

17, 18
19, 20

21, 22
23, 24
]).try_reshape([3, 2, 2]))
)
)

// The result shape is A's free axes [2, 2] followed by B's free
// axes [2, 2] — the contracted axes (size 3) disappear.
assert: result.shape_into_array == [2, 2, 2, 2]
// Each element is a dot product along the contracted axis: A's last
// axis index pairs with B's first axis index.
assert: result.as!(Tensor(F64)).into_array == [
1.0 * 13.0 + 2.0 * 17.0 + 3.0 * 21.0, 1.0 * 14.0 + 2.0 * 18.0 + 3.0 * 22.0 // A00_ ⋅ B_00, A00_ ⋅ B_01
1.0 * 15.0 + 2.0 * 19.0 + 3.0 * 23.0, 1.0 * 16.0 + 2.0 * 20.0 + 3.0 * 24.0 // A00_ ⋅ B_10, A00_ ⋅ B_11

4.0 * 13.0 + 5.0 * 17.0 + 6.0 * 21.0, 4.0 * 14.0 + 5.0 * 18.0 + 6.0 * 22.0 // A01_ ⋅ B_00, A01_ ⋅ B_01
4.0 * 15.0 + 5.0 * 19.0 + 6.0 * 23.0, 4.0 * 16.0 + 5.0 * 20.0 + 6.0 * 24.0 // A01_ ⋅ B_10, A01_ ⋅ B_11

//

7.0 * 13.0 + 8.0 * 17.0 + 9.0 * 21.0, 7.0 * 14.0 + 8.0 * 18.0 + 9.0 * 22.0 // A10_ ⋅ B_00, A10_ ⋅ B_01
7.0 * 15.0 + 8.0 * 19.0 + 9.0 * 23.0, 7.0 * 16.0 + 8.0 * 20.0 + 9.0 * 24.0 // A10_ ⋅ B_10, A10_ ⋅ B_11

10.0 * 13.0 + 11.0 * 17.0 + 12.0 * 21.0, 10.0 * 14.0 + 11.0 * 18.0 + 12.0 * 22.0 // A11_ ⋅ B_00, A11_ ⋅ B_01
10.0 * 15.0 + 11.0 * 19.0 + 12.0 * 23.0, 10.0 * 16.0 + 11.0 * 20.0 + 12.0 * 24.0 // A11_ ⋅ B_10, A11_ ⋅ B_11
]
))
34 changes: 34 additions & 0 deletions src/Tensor.Comp.Tensordot.Outer.savi
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
:: Specialized implementation of tensordot with [-1, 0] hard-coded as the axes
:: (contracting the last axis of the first tensor against the first axis of
:: the second tensor, keeping all other "outer" axes free in the result).
::
:: This is a common use case for tensordot, and we can skip some complexity
:: by implementing only this case.
::
:: It is effectively a matrix multiplication, with the last dimension of the
:: first tensor and the first dimension of the second tensor as the matrix
:: axes that get contracted away.
:module Tensor.Comp.TensorDot.Outer
:: Emit the graph operations that contract the last axis of `a` against the
:: first axis of `b`, using the helper `g` to build each op and the given
:: `name` as a prefix for every emitted op's name.
:: Raises an error if any underlying graph operation fails to build.
:fun build!(g Tensor.Graph.Helper.Methods, name String, a, b)
// Capture the (dynamic, runtime-determined) shapes of both inputs.
a_shape = g.shape!("\(name).a_shape", a)
b_shape = g.shape!("\(name).b_shape", b)

// Split each shape vector into free dims and the contracted ("target")
// dim: for `a` the target is the last dim; for `b` it is the first dim.
a_shape_split = g.split_varying!("\(name).a_shape_split", a_shape, 0, [-1, 1])
b_shape_split = g.split_varying!("\(name).b_shape_split", b_shape, 0, [1, -1])

a_free_dims = a_shape_split.output_slice(0)
a_target_dim = a_shape_split.output_slice(1)
b_target_dim = b_shape_split.output_slice(0)
b_free_dims = b_shape_split.output_slice(1)

// Constants used when building the dynamic reshape targets below:
// -1 lets one dimension be inferred; 0 is the concat axis.
neg_one = g.const!("\(name).neg_one", Tensor(I32).from_array([-1]))
zero_axis = g.const!("\(name).zero_axis", Tensor(I32).scalar(0))

// Flatten `a` to rank 2 as [-1, a_target_dim] and `b` to rank 2 as
// [b_target_dim, -1], multiply them as matrices (contracting the target
// dims), then reshape the product back to the free dims of `a`
// followed by the free dims of `b`.
g.reshape_dynamic!("\(name).result"
g.matmul!("\(name).matmul"
g.reshape_dynamic!("\(name).a_reshape", a
g.concat_dynamic!("\(name).a_new_shape", [neg_one, a_target_dim], zero_axis)
)
g.reshape_dynamic!("\(name).b_reshape", b
g.concat_dynamic!("\(name).b_new_shape", [b_target_dim, neg_one], zero_axis)
)
)
g.concat_dynamic!("\(name).c_new_shape", [a_free_dims, b_free_dims], zero_axis)
)
6 changes: 6 additions & 0 deletions src/Tensor.Graph.Helper.savi
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,9 @@

:: Emit a `Tensor.Op.Select` operation into the graph with the given name
:: and the given `condition`, `true_case`, and `false_case` inputs.
:: Raises an error if the operation fails to build.
:fun ref select!(name, condition, true_case, false_case)
Tensor.Op.Select.new!(@graph, name, condition, true_case, false_case)

///
// Composite Operations

:: Emit the composite group of operations defined by
:: `Tensor.Comp.TensorDot.Outer` (contracting the last axis of `a` against
:: the first axis of `b`), using `name` as a prefix for the emitted op names.
:: Raises an error if any underlying operation fails to build.
:fun ref tensordot_outer!(name, a, b)
Tensor.Comp.TensorDot.Outer.build!(@, name, a, b)

0 comments on commit b755d53

Please sign in to comment.