Skip to content

Commit

Permalink
Add Tensor.Op.Greater and Tensor.Op.Lesser.
Browse files Browse the repository at this point in the history
These each compare two numeric tensors and give a boolean tensor as the output.
  • Loading branch information
jemc committed Apr 1, 2023
1 parent e29b6bd commit fa35868
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 0 deletions.
2 changes: 2 additions & 0 deletions spec/Main.savi
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
Spec.Run(Tensor.Op.Bitcast.Spec).new(env)
Spec.Run(Tensor.Op.Cast.Spec).new(env)
Spec.Run(Tensor.Op.Const.Spec).new(env)
Spec.Run(Tensor.Op.Greater.Spec).new(env)
Spec.Run(Tensor.Op.Lesser.Spec).new(env)
Spec.Run(Tensor.Op.Logical.Spec).new(env)
Spec.Run(Tensor.Op.MatMul.Spec).new(env)
Spec.Run(Tensor.Op.Softmax.Spec).new(env)
Expand Down
51 changes: 51 additions & 0 deletions spec/Tensor.Op.Greater.Spec.savi
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
:class Tensor.Op.Greater.Spec
:is Spec
:const describes: "Tensor.Op.Greater"

:it "checks if the 1st operand's values are greater than those of the 2nd"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.greater!("example"
g.const!("x", Tensor(I32).from_array([1, 2, 3]))
g.const!("y", Tensor(I32).from_array([3, 2, 1]))
)
)

assert: result.as!(Tensor(Bool)).into_array == [False, False, True]
))

:it "may optionally include equal values as being true"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.greater_or_equal!("example"
g.const!("x", Tensor(I32).from_array([1, 2, 3]))
g.const!("y", Tensor(I32).from_array([3, 2, 1]))
)
)

assert: result.as!(Tensor(Bool)).into_array == [False, True, True]
))

:it "complains if the operands are of different types"
_WithGraphHelper.run(@env, False) -> (g, session |
assert error: g.greater!("example"
g.const!("x", Tensor(I32).from_array([1, 2, 3]))
g.const!("y", Tensor(I8).from_array([3, 2, 1]))
)
)

:it "complains if the operands are of different sizes"
_WithGraphHelper.run(@env, False) -> (g, session |
assert error: g.greater!("example"
g.const!("x", Tensor(I32).from_array([1, 2, 3]))
g.const!("y", Tensor(I32).from_array([3, 2, 1, 0]))
)
)

:it "complains if the operands are of different shapes"
_WithGraphHelper.run(@env, False) -> (g, session |
assert error: g.greater!("example"
g.const!("x", Tensor(I32).from_array([0, 1, 2, 3]))
g.const!("y", Tensor(I32).from_array([3, 2, 1, 0]).try_reshape([2, 2]))
)
)
51 changes: 51 additions & 0 deletions spec/Tensor.Op.Lesser.Spec.savi
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
:class Tensor.Op.Lesser.Spec
:is Spec
:const describes: "Tensor.Op.Lesser"

:it "checks if the 1st operand's values are lesser than those of the 2nd"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.lesser!("example"
g.const!("x", Tensor(I32).from_array([1, 2, 3]))
g.const!("y", Tensor(I32).from_array([3, 2, 1]))
)
)

assert: result.as!(Tensor(Bool)).into_array == [True, False, False]
))

:it "may optionally include equal values as being true"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.lesser_or_equal!("example"
g.const!("x", Tensor(I32).from_array([1, 2, 3]))
g.const!("y", Tensor(I32).from_array([3, 2, 1]))
)
)

assert: result.as!(Tensor(Bool)).into_array == [True, True, False]
))

:it "complains if the operands are of different types"
_WithGraphHelper.run(@env, False) -> (g, session |
assert error: g.lesser!("example"
g.const!("x", Tensor(I32).from_array([1, 2, 3]))
g.const!("y", Tensor(I8).from_array([3, 2, 1]))
)
)

:it "complains if the operands are of different sizes"
_WithGraphHelper.run(@env, False) -> (g, session |
assert error: g.lesser!("example"
g.const!("x", Tensor(I32).from_array([1, 2, 3]))
g.const!("y", Tensor(I32).from_array([3, 2, 1, 0]))
)
)

:it "complains if the operands are of different shapes"
_WithGraphHelper.run(@env, False) -> (g, session |
assert error: g.lesser!("example"
g.const!("x", Tensor(I32).from_array([0, 1, 2, 3]))
g.const!("y", Tensor(I32).from_array([3, 2, 1, 0]).try_reshape([2, 2]))
)
)
15 changes: 15 additions & 0 deletions src/Tensor.Graph.Helper.savi
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,21 @@
:fun ref logical_or!(name, x, y)
Tensor.Op.Logical.Or.new!(@graph, name, x, y)

///
// Comparative Operations

:fun ref greater!(name, x, y)
Tensor.Op.Greater.new!(@graph, name, x, y, False)

:fun ref greater_or_equal!(name, x, y)
Tensor.Op.Greater.new!(@graph, name, x, y, True)

:fun ref lesser!(name, x, y)
Tensor.Op.Lesser.new!(@graph, name, x, y, False)

:fun ref lesser_or_equal!(name, x, y)
Tensor.Op.Lesser.new!(@graph, name, x, y, True)

///
// Type Conversions

Expand Down
15 changes: 15 additions & 0 deletions src/Tensor.Op.Greater.savi
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
:: Compares input values from the two operands to check if x is greater than y.
::
:: The two inputs must be numeric tensors of the same type, size, and shape.
:: The output is a boolean tensor indicating the result of each comparison.
::
:: If the `or_equal` parameter is set to true, equal values in the operands
:: will also cause the corresponding output value to be `True`.
:struct box Tensor.Op.Greater
:is Tensor.Op

:fun non new!(graph Tensor.Graph, name, x, y, or_equal = False)
op_name = if or_equal ("GreaterEqual" | "Greater")
@_new(graph.new_operation(op_name, name) -> (builder |
builder.add_input(x).add_input(y).finish!
))
15 changes: 15 additions & 0 deletions src/Tensor.Op.Lesser.savi
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
:: Compares input values from the two operands to check if x is less than y.
::
:: The two inputs must be numeric tensors of the same type, size, and shape.
:: The output is a boolean tensor indicating the result of each comparison.
::
:: If the `or_equal` parameter is set to true, equal values in the operands
:: will also cause the corresponding output value to be `True`.
:struct box Tensor.Op.Lesser
:is Tensor.Op

:fun non new!(graph Tensor.Graph, name, x, y, or_equal = False)
op_name = if or_equal ("LessEqual" | "Less")
@_new(graph.new_operation(op_name, name) -> (builder |
builder.add_input(x).add_input(y).finish!
))

0 comments on commit fa35868

Please sign in to comment.