Merge pull request #2 from jemc-savi/add/some-ops
Add some `Tensor.Op`s and make it easy to run them.
jemc committed Mar 31, 2023
2 parents 360de40 + a2aaa45 commit c6ebce2
Showing 17 changed files with 503 additions and 35 deletions.
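In brief: the new `Tensor.Graph.Helper` methods build ops with a single call each, and `Session.compute!` runs the graph and returns the value of one output. A minimal sketch distilled from the specs in this diff — the bare `try` wrapper is an assumption about how a non-spec caller would handle the error-raising `!` calls, which the specs handle with `assert no_error:` instead:

  graph = Tensor.Graph.new
  g = Tensor.Graph.Helper.new(graph)
  session = Tensor.Graph.Session.new(graph)

  try (
    // Multiply two 2x2 constants and fetch the result (values from the specs).
    product = g.matmul!("product"
      g.const!("A", Tensor(F64).from_array([1.0, 2.0, 3.0, 4.0]).try_reshape([2, 2]))
      g.const!("B", Tensor(F64).from_array([5.0, 6.0, 7.0, 8.0]).try_reshape([2, 2]))
    )
    result = session.compute!(product.output(0))
  )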
6 changes: 3 additions & 3 deletions manifest.savi
@@ -1,6 +1,9 @@
:manifest lib Tensor
:sources "src/*.savi"

:dependency Map v0
:from "github:savi-lang/Map"

:manifest bin "spec"
:copies Tensor
:sources "spec/*.savi"
@@ -11,9 +14,6 @@
:depends on Time
:depends on Timer

:transitive dependency Map v0
:from "github:savi-lang/Map"

:transitive dependency Time v0
:from "github:savi-lang/Time"

3 changes: 3 additions & 0 deletions spec/Main.savi
@@ -3,4 +3,7 @@
Spec.Process.run(env, [
Spec.Run(Tensor.Spec).new(env)
Spec.Run(Tensor.Graph.Spec).new(env)
Spec.Run(Tensor.Op.Const.Spec).new(env)
Spec.Run(Tensor.Op.MatMul.Spec).new(env)
Spec.Run(Tensor.Op.Softmax.Spec).new(env)
])
22 changes: 19 additions & 3 deletions spec/Tensor.Graph.Spec.savi
@@ -21,19 +21,35 @@
.set_attr_tensor!("value", b_value)
.finish!
)
product = graph.new_operation("MatMul", "example") -> (builder |
product1 = graph.new_operation("MatMul", "product1") -> (builder |
builder
.add_input(a.output(0))
.add_input(b.output(0))
.finish!
)
product2 = graph.new_operation("MatMul", "product2") -> (builder |
builder
.add_input(a.output(0))
.add_input(b.output(0))
.set_attr_bool("transpose_a", True)
.finish!
)

result = session.open!.hacky_temporary_run!(product)

result = session.compute!(product1.output(0))
assert: result.as!(Tensor(F64)).into_array == [
1.0 * 5.0 + 2.0 * 7.0, 1.0 * 6.0 + 2.0 * 8.0 // row1⋅col1, row1⋅col2
3.0 * 5.0 + 4.0 * 7.0, 3.0 * 6.0 + 4.0 * 8.0 // row2⋅col1, row2⋅col2
]

results = session.compute_many!([product1.output(0), product2.output(0)])
assert: results[product1.output(0)]!.as!(Tensor(F64)).into_array == [
1.0 * 5.0 + 2.0 * 7.0, 1.0 * 6.0 + 2.0 * 8.0 // row1⋅col1, row1⋅col2
3.0 * 5.0 + 4.0 * 7.0, 3.0 * 6.0 + 4.0 * 8.0 // row2⋅col1, row2⋅col2
]
assert: results[product2.output(0)]!.as!(Tensor(F64)).into_array == [
1.0 * 5.0 + 3.0 * 7.0, 1.0 * 6.0 + 3.0 * 8.0 // Acol1⋅Bcol1, Acol1⋅Bcol2
2.0 * 5.0 + 4.0 * 7.0, 2.0 * 6.0 + 4.0 * 8.0 // Acol2⋅Bcol1, Acol2⋅Bcol2
]
|
graph.errors.each -> (error | @env.err.print(error.message))
session.errors.each -> (error | @env.err.print(error.message))
14 changes: 14 additions & 0 deletions spec/Tensor.Op.Const.Spec.savi
@@ -0,0 +1,14 @@
:class Tensor.Op.Const.Spec
:is Spec
:const describes: "Tensor.Op.Const"

:it "emits a constant tensor value"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.const!("example"
Tensor(F64).from_array([1, 2, 3, 4])
)
)

assert: result.as!(Tensor(F64)).into_array == [1, 2, 3, 4]
))
90 changes: 90 additions & 0 deletions spec/Tensor.Op.MatMul.Spec.savi
@@ -0,0 +1,90 @@
:class Tensor.Op.MatMul.Spec
:is Spec
:const describes: "Tensor.Op.MatMul"

:fun non f64_2x2(a, b, c, d)
Tensor(F64).from_array([a, b, c, d]).try_reshape([2, 2])

:it "computes matrix multiplication"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.matmul!("example"
g.const!("A", @f64_2x2(1.0, 2.0, 3.0, 4.0))
g.const!("B", @f64_2x2(5.0, 6.0, 7.0, 8.0))
)
)

assert: result.as!(Tensor(F64)).into_array == [
1.0 * 5.0 + 2.0 * 7.0, 1.0 * 6.0 + 2.0 * 8.0 // Arow1⋅Bcol1, Arow1⋅Bcol2
3.0 * 5.0 + 4.0 * 7.0, 3.0 * 6.0 + 4.0 * 8.0 // Arow2⋅Bcol1, Arow2⋅Bcol2
]
))

:it "computes matrix multiplication with the first matrix transposed"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.matmul_with_a_transposed!("example"
g.const!("A", @f64_2x2(1.0, 2.0, 3.0, 4.0))
g.const!("B", @f64_2x2(5.0, 6.0, 7.0, 8.0))
)
)

assert: result.as!(Tensor(F64)).into_array == [
1.0 * 5.0 + 3.0 * 7.0, 1.0 * 6.0 + 3.0 * 8.0 // Acol1⋅Bcol1, Acol1⋅Bcol2
2.0 * 5.0 + 4.0 * 7.0, 2.0 * 6.0 + 4.0 * 8.0 // Acol2⋅Bcol1, Acol2⋅Bcol2
]
))

:it "computes matrix multiplication with the second matrix transposed"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.matmul_with_b_transposed!("example"
g.const!("A", @f64_2x2(1.0, 2.0, 3.0, 4.0))
g.const!("B", @f64_2x2(5.0, 6.0, 7.0, 8.0))
)
)

assert: result.as!(Tensor(F64)).into_array == [
1.0 * 5.0 + 2.0 * 6.0, 1.0 * 7.0 + 2.0 * 8.0 // Arow1⋅Brow1, Arow1⋅Brow2
3.0 * 5.0 + 4.0 * 6.0, 3.0 * 7.0 + 4.0 * 8.0 // Arow2⋅Brow1, Arow2⋅Brow2
]
))

:it "computes matrix multiplication with both matrices transposed"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.matmul_with_both_transposed!("example"
g.const!("A", @f64_2x2(1.0, 2.0, 3.0, 4.0))
g.const!("B", @f64_2x2(5.0, 6.0, 7.0, 8.0))
)
)

assert: result.as!(Tensor(F64)).into_array == [
1.0 * 5.0 + 3.0 * 6.0, 1.0 * 7.0 + 3.0 * 8.0 // Acol1⋅Brow1, Acol1⋅Brow2
2.0 * 5.0 + 4.0 * 6.0, 2.0 * 7.0 + 4.0 * 8.0 // Acol2⋅Brow1, Acol2⋅Brow2
]
))

:it "complains when one of the inputs is a scalar (rank 0 tensor)"
_WithGraphHelper.run(@env, False) -> (g, session |
assert error: g.matmul!("example"
g.const!("A", Tensor(F64).scalar(99))
g.const!("B", @f64_2x2(5.0, 6.0, 7.0, 8.0))
)
)

:it "complains when one of the inputs is a vector (rank 1 tensor)"
_WithGraphHelper.run(@env, False) -> (g, session |
assert error: g.matmul!("example"
g.const!("A", Tensor(F64).from_array([1, 2, 3, 4]))
g.const!("B", @f64_2x2(5.0, 6.0, 7.0, 8.0))
)
)

:it "complains when one of the inputs has a rank higher 2"
_WithGraphHelper.run(@env, False) -> (g, session |
assert error: g.matmul!("example"
g.const!("A", @f64_2x2(1.0, 2.0, 3.0, 4.0).try_reshape([2, 1, 2]))
g.const!("B", @f64_2x2(5.0, 6.0, 7.0, 8.0))
)
)
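For reference, the transpose variants' expected values follow the standard identities (background math, not code from this diff):

  $(A^\top B)_{ij} = \sum_k A_{ki} B_{kj}$    $(A B^\top)_{ij} = \sum_k A_{ik} B_{jk}$

which is why the inline comments read the left operand by columns when A is transposed and the right operand by rows when B is transposed.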
60 changes: 60 additions & 0 deletions spec/Tensor.Op.Softmax.Spec.savi
@@ -0,0 +1,60 @@
:class Tensor.Op.Softmax.Spec
:is Spec
:const describes: "Tensor.Op.Softmax"

:it "computes the softmax function of a vector (tensor rank 1)"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.softmax!("example"
g.const!("input", Tensor(F64).from_array([1, 2, 3, 4, 5]))
)
)

assert: result.as!(Tensor(F64)).into_array == [
0.01165623095603961
0.031684920796124276
0.08612854443626873
0.23412165725273662
0.6364086465588309
]
))

:it "complains when applied to a scalar (rank 0 tensor)"
_WithGraphHelper.run(@env, False) -> (g, session |
assert error: (
g.softmax!("example"
g.const!("input", Tensor(F64).scalar(99))
)
)
)

:it "when applied to a higher rank, computes each inner row separately"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.softmax!("example"
g.const!("input"
Tensor(F64).from_array([
1, 2, 3
1, 2, 0 // with implicit bias, this is equivalent to 2, 3, 1

4, 5, 6 // with implicit bias, this is equivalent to 1, 2, 3
4, 5, 0 // but here the pattern changes, as 0 is far from 4 & 5

7, 8, 9 // and this is also equivalent to 1, 2, 3
7, 8, 0 // and this 0 is even farther from 7 & 8
]).try_reshape([3, 2, 3])
)
)
)

assert: result.as!(Tensor(F64)).into_array == [
0.09003057317038046, 0.2447284710547976, 0.6652409557748219
0.24472847105479759, 0.6652409557748218, 0.09003057317038045

0.09003057317038046, 0.2447284710547976, 0.6652409557748219
0.2676231541498623, 0.7274751568004648, 0.004901689049672922

0.09003057317038046, 0.2447284710547976, 0.6652409557748219
0.26887548158545244, 0.7308793357119101, 0.00024518270263755956
]
))
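For reference, the expected values follow from the standard softmax definition (background math, not code from this diff):

  $\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}$

so the first entry above is $e^1 / (e^1 + e^2 + e^3 + e^4 + e^5) \approx 0.011656$, and the shift invariance $\mathrm{softmax}(x + c) = \mathrm{softmax}(x)$ is what the "equivalent to" comments in the rank-3 case rely on.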
16 changes: 16 additions & 0 deletions spec/_WithGraphHelper.savi
@@ -0,0 +1,16 @@
:struct _WithGraphHelper
:let graph: Tensor.Graph.new
:let session: Tensor.Graph.Session.new(@graph)
:copies Tensor.Graph.Helper.Methods

:fun non run(env Env, print_errors = True)
graph = Tensor.Graph.new
g = Tensor.Graph.Helper.new(graph)
session = Tensor.Graph.Session.new(graph)

yield (g, session)

if print_errors (
graph.errors.each -> (error | env.err.print(error.message))
session.errors.each -> (error | env.err.print(error.message))
)
34 changes: 34 additions & 0 deletions src/Tensor.Graph.Helper.savi
@@ -0,0 +1,34 @@
:struct Tensor.Graph.Helper
:let graph Tensor.Graph
:new (@graph)
:copies Tensor.Graph.Helper.Methods

:trait Tensor.Graph.Helper.Methods
:fun graph @->(Tensor.Graph)

///
// Value Sources

:fun ref const!(name, value)
Tensor.Op.Const.new!(@graph, name, value)

///
// Unary Operations

:fun ref softmax!(name, input)
Tensor.Op.Softmax.new!(@graph, name, input)

///
// Binary Operations

:fun ref matmul!(name, a, b)
Tensor.Op.MatMul.new!(@graph, name, a, b, False, False)

:fun ref matmul_with_a_transposed!(name, a, b)
Tensor.Op.MatMul.new!(@graph, name, a, b, True, False)

:fun ref matmul_with_b_transposed!(name, a, b)
Tensor.Op.MatMul.new!(@graph, name, a, b, False, True)

:fun ref matmul_with_both_transposed!(name, a, b)
Tensor.Op.MatMul.new!(@graph, name, a, b, True, True)
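Since each helper returns the op it builds and `add_input` accepts any `Tensor.Graph.CanOutput` (see `Tensor.Graph.Operation` and `Tensor.Graph.Output` below), helper calls nest directly, as the specs do when passing `g.const!` results into `g.matmul!`. A hedged sketch going one step further — nesting `matmul!` under `softmax!` is assumed to compose the same way, and the names "scores" and "logits" are illustrative:

  // Hypothetical composition: softmax over a matrix product.
  scores = g.softmax!("scores"
    g.matmul!("logits"
      g.const!("A", Tensor(F64).from_array([1.0, 2.0, 3.0, 4.0]).try_reshape([2, 2]))
      g.const!("B", Tensor(F64).from_array([5.0, 6.0, 7.0, 8.0]).try_reshape([2, 2]))
    )
  )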
14 changes: 7 additions & 7 deletions src/Tensor.Graph.Input.savi
@@ -1,12 +1,12 @@
:struct Tensor.Graph.Input
:let operation Tensor.Graph.Operation
:struct box Tensor.Graph.Input
:let op Tensor.Graph.Operation
:let index USize
:new (@operation, @index)
:new box (@op, @index)

:fun _to_ffi
_FFI.Input._new(@operation._ptr, @index.i32)
_FFI.Input._new(@op._ptr, @index.i32)

:struct _FFI.Input
:let operation_ptr CPointer(_FFI.Operation)
:let index I32
:new _new(@operation_ptr, @index)
:let _op_ptr CPointer(_FFI.Operation)
:let _index I32
:new _new(@_op_ptr, @_index)
20 changes: 18 additions & 2 deletions src/Tensor.Graph.Operation.savi
@@ -22,9 +22,18 @@
:new _new(@graph, op_type String, oper_name String)
@_ptr = @_ffi.new(@graph._ptr, op_type.cpointer, oper_name.cpointer)

:fun ref add_input(from_output Tensor.Graph.Output)
:fun ref add_input(can_output Tensor.Graph.CanOutput)
return @ if @_ptr.is_null
@_ffi.add_input(@_ptr, from_output._to_ffi)
@_ffi.add_input(@_ptr, can_output.output._to_ffi)
@

:fun ref set_attr_bool(attr_name String, value Bool)
return @ if @_ptr.is_null
@_ffi.set_attr_bool(
@_ptr
attr_name.cstring
value.u8
)
@

:fun ref set_attr_type(attr_name String, type_code I32)
@@ -80,6 +89,13 @@
// ) None
// :foreign_name TF_AddOutput

:ffi set_attr_bool(
ptr CPointer(@)
attr_name CPointer(U8)
value U8
) None
:foreign_name TF_SetAttrBool

:ffi set_attr_type(
ptr CPointer(@)
attr_name CPointer(U8)
27 changes: 19 additions & 8 deletions src/Tensor.Graph.Output.savi
@@ -1,12 +1,23 @@
:struct Tensor.Graph.Output
:let operation Tensor.Graph.Operation
:trait box Tensor.Graph.CanOutput
:fun output Tensor.Graph.Output

:struct box Tensor.Graph.Output
:let op Tensor.Graph.Operation
:let index USize
:new (@operation, @index)
:new box (@op, @index)

:is Tensor.Graph.CanOutput
:fun output: @

:fun hash USize: @op._ptr.address.hash.bit_xor(@index.hash)
:fun "=="(that @'box) Bool
@op._ptr.address == that.op._ptr.address
&& @index == that.index

:fun _to_ffi
_FFI.Output._new(@operation._ptr, @index.i32)
_FFI.Output._new(@op._ptr, @index.i32)

:struct _FFI.Output
:let operation_ptr CPointer(_FFI.Operation)
:let index I32
:new _new(@operation_ptr, @index)
:struct box _FFI.Output
:let _op_ptr CPointer(_FFI.Operation)
:let _index I32
:new box _new(@_op_ptr, @_index)
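The `hash` and `"=="` functions added to `Tensor.Graph.Output` above let outputs serve as map keys, which is how `compute_many!` results are looked up in `Tensor.Graph.Spec` — and, presumably, why the `Map` dependency was promoted from transitive to direct in `manifest.savi`:

  results = session.compute_many!([product1.output(0), product2.output(0)])
  result = results[product1.output(0)]!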