diff --git a/spec/Main.savi b/spec/Main.savi
index 54ef3ad..50600c1 100644
--- a/spec/Main.savi
+++ b/spec/Main.savi
@@ -27,6 +27,11 @@
     Spec.Run(Tensor.Op.SplitV.Spec).new(env)
     Spec.Run(Tensor.Op.Square.Spec).new(env)
 
+    ///
+    // Tensor.Comp specs
+
+    Spec.Run(Tensor.Comp.TensorDot.Outer.Spec).new(env)
+
     ///
     // Tensor.Gen specs
 
diff --git a/spec/Tensor.Comp.TensorDot.Outer.Spec.savi b/spec/Tensor.Comp.TensorDot.Outer.Spec.savi
new file mode 100644
index 0000000..76a3ff5
--- /dev/null
+++ b/spec/Tensor.Comp.TensorDot.Outer.Spec.savi
@@ -0,0 +1,66 @@
+:class Tensor.Comp.TensorDot.Outer.Spec
+  :is Spec
+  :const describes: "Tensor.Comp.TensorDot.Outer"
+
+  :it "is equivalent to matrix multiplication for rank-2 tensors"
+    _WithGraphHelper.run(@env) -> (g, session | assert no_error: (
+      result = session.compute!(
+        g.tensordot_outer!("example"
+          g.const!("A", Tensor(F64).from_array([
+            1.0, 2.0
+            3.0, 4.0
+          ]).try_reshape([2, 2]))
+          g.const!("B", Tensor(F64).from_array([
+            5.0, 6.0
+            7.0, 8.0
+          ]).try_reshape([2, 2]))
+        )
+      )
+
+      assert: result.as!(Tensor(F64)).into_array == [
+        1.0 * 5.0 + 2.0 * 7.0, 1.0 * 6.0 + 2.0 * 8.0 // Arow1⋅Bcol1, Arow1⋅Bcol2
+        3.0 * 5.0 + 4.0 * 7.0, 3.0 * 6.0 + 4.0 * 8.0 // Arow2⋅Bcol1, Arow2⋅Bcol2
+      ]
+    ))
+
+  :it "handles larger-rank tensors by applying to the outer axes"
+    _WithGraphHelper.run(@env) -> (g, session | assert no_error: (
+      result = session.compute!(
+        g.tensordot_outer!("example"
+          g.const!("A", Tensor(F64).from_array([
+            1, 2, 3
+            4, 5, 6
+
+            7, 8, 9
+            10, 11, 12
+          ]).try_reshape([2, 2, 3]))
+          g.const!("B", Tensor(F64).from_array([
+            13, 14
+            15, 16
+
+            17, 18
+            19, 20
+
+            21, 22
+            23, 24
+          ]).try_reshape([3, 2, 2]))
+        )
+      )
+
+      assert: result.shape_into_array == [2, 2, 2, 2]
+      assert: result.as!(Tensor(F64)).into_array == [
+        1.0 * 13.0 + 2.0 * 17.0 + 3.0 * 21.0, 1.0 * 14.0 + 2.0 * 18.0 + 3.0 * 22.0 // A00_ ⋅ B_00, A00_ ⋅ B_01
+        1.0 * 15.0 + 2.0 * 19.0 + 3.0 * 23.0, 1.0 * 16.0 + 2.0 * 20.0 + 3.0 * 24.0 // A00_ ⋅ B_10, A00_ ⋅ B_11
+
+        4.0 * 13.0 + 5.0 * 17.0 + 6.0 * 21.0, 4.0 * 14.0 + 5.0 * 18.0 + 6.0 * 22.0 // A01_ ⋅ B_00, A01_ ⋅ B_01
+        4.0 * 15.0 + 5.0 * 19.0 + 6.0 * 23.0, 4.0 * 16.0 + 5.0 * 20.0 + 6.0 * 24.0 // A01_ ⋅ B_10, A01_ ⋅ B_11
+
+        //
+
+        7.0 * 13.0 + 8.0 * 17.0 + 9.0 * 21.0, 7.0 * 14.0 + 8.0 * 18.0 + 9.0 * 22.0 // A10_ ⋅ B_00, A10_ ⋅ B_01
+        7.0 * 15.0 + 8.0 * 19.0 + 9.0 * 23.0, 7.0 * 16.0 + 8.0 * 20.0 + 9.0 * 24.0 // A10_ ⋅ B_10, A10_ ⋅ B_11
+
+        10.0 * 13.0 + 11.0 * 17.0 + 12.0 * 21.0, 10.0 * 14.0 + 11.0 * 18.0 + 12.0 * 22.0 // A11_ ⋅ B_00, A11_ ⋅ B_01
+        10.0 * 15.0 + 11.0 * 19.0 + 12.0 * 23.0, 10.0 * 16.0 + 11.0 * 20.0 + 12.0 * 24.0 // A11_ ⋅ B_10, A11_ ⋅ B_11
+      ]
+    ))
diff --git a/src/Tensor.Comp.TensorDot.Outer.savi b/src/Tensor.Comp.TensorDot.Outer.savi
new file mode 100644
index 0000000..83201a5
--- /dev/null
+++ b/src/Tensor.Comp.TensorDot.Outer.savi
@@ -0,0 +1,41 @@
+:: Specialized implementation of tensordot with [-1, 0] hard-coded as the axes.
+::
+:: This is a common use case for tensordot, and we can skip some complexity
+:: by implementing only this case.
+::
+:: It is effectively a matrix multiplication, with the last dimension of the
+:: first tensor and the first dimension of the second tensor as the matrix axes.
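+::
+:: In index notation: result[a_free..., b_free...] is the sum over k of
+:: a[a_free..., k] * b[k, b_free...].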
+:module Tensor.Comp.TensorDot.Outer
+  :fun build!(g Tensor.Graph.Helper.Methods, name String, a, b)
+    a_shape = g.shape!("\(name).a_shape", a)
+    b_shape = g.shape!("\(name).b_shape", b)
+
+    // Split each dynamic shape into the free dimensions and the single
+    // contraction dimension: a = [a_free..., k] and b = [k, b_free...].
+    a_shape_split = g.split_varying!("\(name).a_shape_split", a_shape, 0, [-1, 1])
+    b_shape_split = g.split_varying!("\(name).b_shape_split", b_shape, 0, [1, -1])
+
+    a_free_dims = a_shape_split.output_slice(0)
+    a_target_dim = a_shape_split.output_slice(1)
+    b_target_dim = b_shape_split.output_slice(0)
+    b_free_dims = b_shape_split.output_slice(1)
+
+    neg_one = g.const!("\(name).neg_one", Tensor(I32).from_array([-1]))
+    zero_axis = g.const!("\(name).zero_axis", Tensor(I32).scalar(0))
+
+    // Flatten each tensor to a rank-2 matrix, multiply the two matrices,
+    // then restore the free dimensions in the shape of the result.
+    g.reshape_dynamic!("\(name).result"
+      g.matmul!("\(name).matmul"
+        g.reshape_dynamic!("\(name).a_reshape", a
+          g.concat_dynamic!("\(name).a_new_shape", [neg_one, a_target_dim], zero_axis)
+        )
+        g.reshape_dynamic!("\(name).b_reshape", b
+          g.concat_dynamic!("\(name).b_new_shape", [b_target_dim, neg_one], zero_axis)
+        )
+      )
+      g.concat_dynamic!("\(name).c_new_shape", [a_free_dims, b_free_dims], zero_axis)
+    )
diff --git a/src/Tensor.Graph.Helper.savi b/src/Tensor.Graph.Helper.savi
index 8eb1064..842b2dc 100644
--- a/src/Tensor.Graph.Helper.savi
+++ b/src/Tensor.Graph.Helper.savi
@@ -228,3 +228,9 @@
 
   :fun ref select!(name, condition, true_case, false_case)
     Tensor.Op.Select.new!(@graph, name, condition, true_case, false_case)
+
+  ///
+  // Composite Operations
+
+  :fun ref tensordot_outer!(name, a, b)
+    Tensor.Comp.TensorDot.Outer.build!(@, name, a, b)
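
A possible follow-up spec (a sketch, not part of the diff above): the existing
tests cover square rank-2 and rank-3 inputs, so a non-square rank-2 case would
also exercise the shape handling. It reuses only constructs already shown in
the specs above (`_WithGraphHelper`, `g.const!`, `g.tensordot_outer!`); the
names "A", "B", and "example" are illustrative.

  :it "handles non-square rank-2 tensors"
    _WithGraphHelper.run(@env) -> (g, session | assert no_error: (
      result = session.compute!(
        g.tensordot_outer!("example"
          g.const!("A", Tensor(F64).from_array([
            1.0, 2.0, 3.0
            4.0, 5.0, 6.0
          ]).try_reshape([2, 3])) // A is 2x3
          g.const!("B", Tensor(F64).from_array([
            7.0, 8.0
            9.0, 10.0
            11.0, 12.0
          ]).try_reshape([3, 2])) // B is 3x2
        )
      )

      // Contracting A's last axis with B's first axis yields a 2x2 result.
      assert: result.shape_into_array == [2, 2]
      assert: result.as!(Tensor(F64)).into_array == [
        1.0 * 7.0 + 2.0 * 9.0 + 3.0 * 11.0, 1.0 * 8.0 + 2.0 * 10.0 + 3.0 * 12.0
        4.0 * 7.0 + 5.0 * 9.0 + 6.0 * 11.0, 4.0 * 8.0 + 5.0 * 10.0 + 6.0 * 12.0
      ]
    ))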