-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #18 from jemc-savi/add/op
Add `Tensor.Comp.Tensordot.Outer`, a special case of tensordot.
- Loading branch information
Showing
4 changed files
with
111 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
:class Tensor.Comp.TensorDot.Outer.Spec | ||
:is Spec | ||
:const describes: "Tensor.Comp.TensorDot.Outer" | ||
|
||
:it "is equivalent to matrix multiplication for rank-2 tensors" | ||
_WithGraphHelper.run(@env) -> (g, session | assert no_error: ( | ||
result = session.compute!( | ||
g.tensordot_outer!("example" | ||
g.const!("A", Tensor(F64).from_array([ | ||
1.0, 2.0 | ||
3.0, 4.0 | ||
]).try_reshape([2, 2])) | ||
g.const!("B", Tensor(F64).from_array([ | ||
5.0, 6.0 | ||
7.0, 8.0 | ||
]).try_reshape([2, 2])) | ||
) | ||
) | ||
|
||
assert: result.as!(Tensor(F64)).into_array == [ | ||
1.0 * 5.0 + 2.0 * 7.0, 1.0 * 6.0 + 2.0 * 8.0 // Arow1⋅Bcol1, Arow1⋅Bcol2 | ||
3.0 * 5.0 + 4.0 * 7.0, 3.0 * 6.0 + 4.0 * 8.0 // Arow2⋅Bcol1, Arow2⋅Bcol2 | ||
] | ||
)) | ||
|
||
:it "handles larger-rank tensors by applying to the outer axes" | ||
_WithGraphHelper.run(@env) -> (g, session | assert no_error: ( | ||
result = session.compute!( | ||
g.tensordot_outer!("example" | ||
g.const!("A", Tensor(F64).from_array([ | ||
1, 2, 3 | ||
4, 5, 6 | ||
|
||
7, 8, 9 | ||
10, 11, 12 | ||
]).try_reshape([2, 2, 3])) | ||
g.const!("B", Tensor(F64).from_array([ | ||
13, 14 | ||
15, 16 | ||
|
||
17, 18 | ||
19, 20 | ||
|
||
21, 22 | ||
23, 24 | ||
]).try_reshape([3, 2, 2])) | ||
) | ||
) | ||
|
||
assert: result.shape_into_array == [2, 2, 2, 2] | ||
assert: result.as!(Tensor(F64)).into_array == [ | ||
1.0 * 13.0 + 2.0 * 17.0 + 3.0 * 21.0, 1.0 * 14.0 + 2.0 * 18.0 + 3.0 * 22.0 // A00_ ⋅ B_00, A00_ ⋅ B_01 | ||
1.0 * 15.0 + 2.0 * 19.0 + 3.0 * 23.0, 1.0 * 16.0 + 2.0 * 20.0 + 3.0 * 24.0 // A00_ ⋅ B_10, A00_ ⋅ B_11 | ||
|
||
4.0 * 13.0 + 5.0 * 17.0 + 6.0 * 21.0, 4.0 * 14.0 + 5.0 * 18.0 + 6.0 * 22.0 // A01_ ⋅ B_00, A01_ ⋅ B_01 | ||
4.0 * 15.0 + 5.0 * 19.0 + 6.0 * 23.0, 4.0 * 16.0 + 5.0 * 20.0 + 6.0 * 24.0 // A01_ ⋅ B_10, A01_ ⋅ B_11 | ||
|
||
// | ||
|
||
7.0 * 13.0 + 8.0 * 17.0 + 9.0 * 21.0, 7.0 * 14.0 + 8.0 * 18.0 + 9.0 * 22.0 // A10_ ⋅ B_00, A10_ ⋅ B_01 | ||
7.0 * 15.0 + 8.0 * 19.0 + 9.0 * 23.0, 7.0 * 16.0 + 8.0 * 20.0 + 9.0 * 24.0 // A10_ ⋅ B_10, A10_ ⋅ B_11 | ||
|
||
10.0 * 13.0 + 11.0 * 17.0 + 12.0 * 21.0, 10.0 * 14.0 + 11.0 * 18.0 + 12.0 * 22.0 // A11_ ⋅ B_00, A11_ ⋅ B_01 | ||
10.0 * 15.0 + 11.0 * 19.0 + 12.0 * 23.0, 10.0 * 16.0 + 11.0 * 20.0 + 12.0 * 24.0 // A11_ ⋅ B_10, A11_ ⋅ B_11 | ||
] | ||
)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
:: Specialized implementation of tensordot with [0, -1] hard-coded as the axes. | ||
:: | ||
:: This is a common use case for tensordot, and we can skip some complexity | ||
:: by implementing only this case. | ||
:: | ||
:: It is effectively a matrix multiplication, with the first dimension of the | ||
:: first tensor and the last dimension of the second tensor as the matrix axes. | ||
:module Tensor.Comp.TensorDot.Outer | ||
:fun build!(g Tensor.Graph.Helper.Methods, name String, a, b) | ||
a_shape = g.shape!("\(name).a_shape", a) | ||
b_shape = g.shape!("\(name).b_shape", b) | ||
|
||
a_shape_split = g.split_varying!("\(name).a_shape_split", a_shape, 0, [-1, 1]) | ||
b_shape_split = g.split_varying!("\(name).b_shape_split", b_shape, 0, [1, -1]) | ||
|
||
a_free_dims = a_shape_split.output_slice(0) | ||
a_target_dim = a_shape_split.output_slice(1) | ||
b_target_dim = b_shape_split.output_slice(0) | ||
b_free_dims = b_shape_split.output_slice(1) | ||
|
||
neg_one = g.const!("\(name).neg_one", Tensor(I32).from_array([-1])) | ||
zero_axis = g.const!("\(name).zero_axis", Tensor(I32).scalar(0)) | ||
|
||
g.reshape_dynamic!("\(name).result" | ||
g.matmul!("\(name).matmul" | ||
g.reshape_dynamic!("\(name).a_reshape", a | ||
g.concat_dynamic!("\(name).a_new_shape", [neg_one, a_target_dim], zero_axis) | ||
) | ||
g.reshape_dynamic!("\(name).b_reshape", b | ||
g.concat_dynamic!("\(name).b_new_shape", [b_target_dim, neg_one], zero_axis) | ||
) | ||
) | ||
g.concat_dynamic!("\(name).c_new_shape", [a_free_dims, b_free_dims], zero_axis) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters