-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from jemc-savi/add/some-ops
Add some `Tensor.Op`s and make it easy to run them.
- Loading branch information
Showing
17 changed files
with
503 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
:class Tensor.Op.Const.Spec | ||
:is Spec | ||
:const describes: "Tensor.Op.Const" | ||
|
||
:it "emits a constant tensor value" | ||
_WithGraphHelper.run(@env) -> (g, session | assert no_error: ( | ||
result = session.compute!( | ||
g.const!("example" | ||
Tensor(F64).from_array([1, 2, 3, 4]) | ||
) | ||
) | ||
|
||
assert: result.as!(Tensor(F64)).into_array == [1, 2, 3, 4] | ||
)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
:class Tensor.Op.MatMul.Spec | ||
:is Spec | ||
:const describes: "Tensor.Op.MatMul" | ||
|
||
:fun non f64_2x2(a, b, c, d) | ||
Tensor(F64).from_array([a, b, c, d]).try_reshape([2, 2]) | ||
|
||
:it "computes matrix multiplication" | ||
_WithGraphHelper.run(@env) -> (g, session | assert no_error: ( | ||
result = session.compute!( | ||
g.matmul!("example" | ||
g.const!("A", @f64_2x2(1.0, 2.0, 3.0, 4.0)) | ||
g.const!("B", @f64_2x2(5.0, 6.0, 7.0, 8.0)) | ||
) | ||
) | ||
|
||
assert: result.as!(Tensor(F64)).into_array == [ | ||
1.0 * 5.0 + 2.0 * 7.0, 1.0 * 6.0 + 2.0 * 8.0 // Arow1⋅Bcol1, Arow1⋅Bcol2 | ||
3.0 * 5.0 + 4.0 * 7.0, 3.0 * 6.0 + 4.0 * 8.0 // Arow2⋅Bcol1, Arow2⋅Bcol2 | ||
] | ||
)) | ||
|
||
:it "computes matrix multiplication with the first matrix transposed" | ||
_WithGraphHelper.run(@env) -> (g, session | assert no_error: ( | ||
result = session.compute!( | ||
g.matmul_with_a_transposed!("example" | ||
g.const!("A", @f64_2x2(1.0, 2.0, 3.0, 4.0)) | ||
g.const!("B", @f64_2x2(5.0, 6.0, 7.0, 8.0)) | ||
) | ||
) | ||
|
||
assert: result.as!(Tensor(F64)).into_array == [ | ||
1.0 * 5.0 + 3.0 * 7.0, 1.0 * 6.0 + 3.0 * 8.0 // Acol1⋅Bcol1, Acol1⋅Bcol2 | ||
2.0 * 5.0 + 4.0 * 7.0, 2.0 * 6.0 + 4.0 * 8.0 // Acol2⋅Bcol1, Acol2⋅Bcol2 | ||
] | ||
)) | ||
|
||
:it "computes matrix multiplication with the second matrix transposed" | ||
_WithGraphHelper.run(@env) -> (g, session | assert no_error: ( | ||
result = session.compute!( | ||
g.matmul_with_b_transposed!("example" | ||
g.const!("A", @f64_2x2(1.0, 2.0, 3.0, 4.0)) | ||
g.const!("B", @f64_2x2(5.0, 6.0, 7.0, 8.0)) | ||
) | ||
) | ||
|
||
assert: result.as!(Tensor(F64)).into_array == [ | ||
1.0 * 5.0 + 2.0 * 6.0, 1.0 * 7.0 + 2.0 * 8.0 // Arow1⋅Brow1, Arow1⋅Brow2 | ||
3.0 * 5.0 + 4.0 * 6.0, 3.0 * 7.0 + 4.0 * 8.0 // Arow2⋅Brow1, Arow2⋅Brow2 | ||
] | ||
)) | ||
|
||
:it "computes matrix multiplication with both matrices transposed" | ||
_WithGraphHelper.run(@env) -> (g, session | assert no_error: ( | ||
result = session.compute!( | ||
g.matmul_with_both_transposed!("example" | ||
g.const!("A", @f64_2x2(1.0, 2.0, 3.0, 4.0)) | ||
g.const!("B", @f64_2x2(5.0, 6.0, 7.0, 8.0)) | ||
) | ||
) | ||
|
||
assert: result.as!(Tensor(F64)).into_array == [ | ||
1.0 * 5.0 + 3.0 * 6.0, 1.0 * 7.0 + 3.0 * 8.0 // Acol1⋅Brow1, Acol1⋅Brow2 | ||
2.0 * 5.0 + 4.0 * 6.0, 2.0 * 7.0 + 4.0 * 8.0 // Acol2⋅Brow1, Acol2⋅Brow2 | ||
] | ||
)) | ||
|
||
:it "complains when one of the inputs is a scalar (rank 0 tensor)" | ||
_WithGraphHelper.run(@env, False) -> (g, session | | ||
assert error: g.matmul!("example" | ||
g.const!("A", Tensor(F64).scalar(99)) | ||
g.const!("B", @f64_2x2(5.0, 6.0, 7.0, 8.0)) | ||
) | ||
) | ||
|
||
:it "complains when one of the inputs is a vector (rank 1 tensor)" | ||
_WithGraphHelper.run(@env, False) -> (g, session | | ||
assert error: g.matmul!("example" | ||
g.const!("A", Tensor(F64).from_array([1, 2, 3, 4])) | ||
g.const!("B", @f64_2x2(5.0, 6.0, 7.0, 8.0)) | ||
) | ||
) | ||
|
||
:it "complains when one of the inputs has a rank higher 2" | ||
_WithGraphHelper.run(@env, False) -> (g, session | | ||
assert error: g.matmul!("example" | ||
g.const!("A", @f64_2x2(1.0, 2.0, 3.0, 4.0).try_reshape([2, 1, 2])) | ||
g.const!("B", @f64_2x2(5.0, 6.0, 7.0, 8.0)) | ||
) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
:class Tensor.Op.Softmax.Spec | ||
:is Spec | ||
:const describes: "Tensor.Op.Softmax" | ||
|
||
:it "computes the softmax function of a vector (tensor rank 1)" | ||
_WithGraphHelper.run(@env) -> (g, session | assert no_error: ( | ||
result = session.compute!( | ||
g.softmax!("example" | ||
g.const!("input", Tensor(F64).from_array([1, 2, 3, 4, 5])) | ||
) | ||
) | ||
|
||
assert: result.as!(Tensor(F64)).into_array == [ | ||
0.01165623095603961 | ||
0.031684920796124276 | ||
0.08612854443626873 | ||
0.23412165725273662 | ||
0.6364086465588309 | ||
] | ||
)) | ||
|
||
:it "complains when applied to a scalar (rank 0 tensor)" | ||
_WithGraphHelper.run(@env, False) -> (g, session | | ||
assert error: ( | ||
g.softmax!("example" | ||
g.const!("input", Tensor(F64).scalar(99)) | ||
) | ||
) | ||
) | ||
|
||
:it "when applied to a higher rank, computes each inner row separately" | ||
_WithGraphHelper.run(@env) -> (g, session | assert no_error: ( | ||
result = session.compute!( | ||
g.softmax!("example" | ||
g.const!("input" | ||
Tensor(F64).from_array([ | ||
1, 2, 3 | ||
1, 2, 0 // with implicit bias, this is equivalent to 2, 3, 1 | ||
|
||
4, 5, 6 // with implicit bias, this is equivalent to 1, 2, 3 | ||
4, 5, 0 // but here the pattern changes, as 0 is far from 4 & 5 | ||
|
||
7, 8, 9 // and this is also equivalent to 1, 2, 3 | ||
7, 8, 0 // and this 0 is even father from 7 & 8 | ||
]).try_reshape([3, 2, 3]) | ||
) | ||
) | ||
) | ||
|
||
assert: result.as!(Tensor(F64)).into_array == [ | ||
0.09003057317038046, 0.2447284710547976, 0.6652409557748219 | ||
0.24472847105479759, 0.6652409557748218, 0.09003057317038045 | ||
|
||
0.09003057317038046, 0.2447284710547976, 0.6652409557748219 | ||
0.2676231541498623, 0.7274751568004648, 0.004901689049672922 | ||
|
||
0.09003057317038046, 0.2447284710547976, 0.6652409557748219 | ||
0.26887548158545244, 0.7308793357119101, 0.00024518270263755956 | ||
] | ||
)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
:struct _WithGraphHelper | ||
:let graph: Tensor.Graph.new | ||
:let session: Tensor.Graph.Session.new(@graph) | ||
:copies Tensor.Graph.Helper.Methods | ||
|
||
:fun non run(env Env, print_errors = True) | ||
graph = Tensor.Graph.new | ||
g = Tensor.Graph.Helper.new(graph) | ||
session = Tensor.Graph.Session.new(graph) | ||
|
||
yield (g, session) | ||
|
||
if print_errors ( | ||
graph.errors.each -> (error | env.err.print(error.message)) | ||
session.errors.each -> (error | env.err.print(error.message)) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
:struct Tensor.Graph.Helper | ||
:let graph Tensor.Graph | ||
:new (@graph) | ||
:copies Tensor.Graph.Helper.Methods | ||
|
||
:trait Tensor.Graph.Helper.Methods | ||
:fun graph @->(Tensor.Graph) | ||
|
||
/// | ||
// Value Sources | ||
|
||
:fun ref const!(name, value) | ||
Tensor.Op.Const.new!(@graph, name, value) | ||
|
||
/// | ||
// Unary Operations | ||
|
||
:fun ref softmax!(name, input) | ||
Tensor.Op.Softmax.new!(@graph, name, input) | ||
|
||
/// | ||
// Binary Operations | ||
|
||
:fun ref matmul!(name, a, b) | ||
Tensor.Op.MatMul.new!(@graph, name, a, b, False, False) | ||
|
||
:fun ref matmul_with_a_transposed!(name, a, b) | ||
Tensor.Op.MatMul.new!(@graph, name, a, b, True, False) | ||
|
||
:fun ref matmul_with_b_transposed!(name, a, b) | ||
Tensor.Op.MatMul.new!(@graph, name, a, b, False, True) | ||
|
||
:fun ref matmul_with_both_transposed!(name, a, b) | ||
Tensor.Op.MatMul.new!(@graph, name, a, b, True, True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,12 @@ | ||
:struct Tensor.Graph.Input | ||
:let operation Tensor.Graph.Operation | ||
:struct box Tensor.Graph.Input | ||
:let op Tensor.Graph.Operation | ||
:let index USize | ||
:new (@operation, @index) | ||
:new box (@op, @index) | ||
|
||
:fun _to_ffi | ||
_FFI.Input._new(@operation._ptr, @index.i32) | ||
_FFI.Input._new(@op._ptr, @index.i32) | ||
|
||
:struct _FFI.Input | ||
:let operation_ptr CPointer(_FFI.Operation) | ||
:let index I32 | ||
:new _new(@operation_ptr, @index) | ||
:let _op_ptr CPointer(_FFI.Operation) | ||
:let _index I32 | ||
:new _new(@_op_ptr, @_index) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,23 @@ | ||
:struct Tensor.Graph.Output | ||
:let operation Tensor.Graph.Operation | ||
:trait box Tensor.Graph.CanOutput | ||
:fun output Tensor.Graph.Output | ||
|
||
:struct box Tensor.Graph.Output | ||
:let op Tensor.Graph.Operation | ||
:let index USize | ||
:new (@operation, @index) | ||
:new box (@op, @index) | ||
|
||
:is Tensor.Graph.CanOutput | ||
:fun output: @ | ||
|
||
:fun hash USize: @op._ptr.address.hash.bit_xor(@index.hash) | ||
:fun "=="(that @'box) Bool | ||
@op._ptr.address == that.op._ptr.address | ||
&& @index == that.index | ||
|
||
:fun _to_ffi | ||
_FFI.Output._new(@operation._ptr, @index.i32) | ||
_FFI.Output._new(@op._ptr, @index.i32) | ||
|
||
:struct _FFI.Output | ||
:let operation_ptr CPointer(_FFI.Operation) | ||
:let index I32 | ||
:new _new(@operation_ptr, @index) | ||
:struct box _FFI.Output | ||
:let _op_ptr CPointer(_FFI.Operation) | ||
:let _index I32 | ||
:new box _new(@_op_ptr, @_index) |
Oops, something went wrong.