Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Tensor.Op.Greater and Tensor.Op.Lesser. #5

Merged
merged 1 commit into from
Apr 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions spec/Main.savi
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
Spec.Run(Tensor.Op.Bitcast.Spec).new(env)
Spec.Run(Tensor.Op.Cast.Spec).new(env)
Spec.Run(Tensor.Op.Const.Spec).new(env)
Spec.Run(Tensor.Op.Greater.Spec).new(env)
Spec.Run(Tensor.Op.Lesser.Spec).new(env)
Spec.Run(Tensor.Op.Logical.Spec).new(env)
Spec.Run(Tensor.Op.MatMul.Spec).new(env)
Spec.Run(Tensor.Op.Softmax.Spec).new(env)
Expand Down
51 changes: 51 additions & 0 deletions spec/Tensor.Op.Greater.Spec.savi
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
:class Tensor.Op.Greater.Spec
:is Spec
:const describes: "Tensor.Op.Greater"

:it "checks if the 1st operand's values are greater than those of the 2nd"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.greater!("example"
g.const!("x", Tensor(I32).from_array([1, 2, 3]))
g.const!("y", Tensor(I32).from_array([3, 2, 1]))
)
)

assert: result.as!(Tensor(Bool)).into_array == [False, False, True]
))

:it "may optionally include equal values as being true"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.greater_or_equal!("example"
g.const!("x", Tensor(I32).from_array([1, 2, 3]))
g.const!("y", Tensor(I32).from_array([3, 2, 1]))
)
)

assert: result.as!(Tensor(Bool)).into_array == [False, True, True]
))

:it "complains if the operands are of different types"
_WithGraphHelper.run(@env, False) -> (g, session |
assert error: g.greater!("example"
g.const!("x", Tensor(I32).from_array([1, 2, 3]))
g.const!("y", Tensor(I8).from_array([3, 2, 1]))
)
)

:it "complains if the operands are of different sizes"
_WithGraphHelper.run(@env, False) -> (g, session |
assert error: g.greater!("example"
g.const!("x", Tensor(I32).from_array([1, 2, 3]))
g.const!("y", Tensor(I32).from_array([3, 2, 1, 0]))
)
)

:it "complains if the operands are of different shapes"
_WithGraphHelper.run(@env, False) -> (g, session |
assert error: g.greater!("example"
g.const!("x", Tensor(I32).from_array([0, 1, 2, 3]))
g.const!("y", Tensor(I32).from_array([3, 2, 1, 0]).try_reshape([2, 2]))
)
)
51 changes: 51 additions & 0 deletions spec/Tensor.Op.Lesser.Spec.savi
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
:class Tensor.Op.Lesser.Spec
:is Spec
:const describes: "Tensor.Op.Lesser"

:it "checks if the 1st operand's values are lesser than those of the 2nd"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.lesser!("example"
g.const!("x", Tensor(I32).from_array([1, 2, 3]))
g.const!("y", Tensor(I32).from_array([3, 2, 1]))
)
)

assert: result.as!(Tensor(Bool)).into_array == [True, False, False]
))

:it "may optionally include equal values as being true"
_WithGraphHelper.run(@env) -> (g, session | assert no_error: (
result = session.compute!(
g.lesser_or_equal!("example"
g.const!("x", Tensor(I32).from_array([1, 2, 3]))
g.const!("y", Tensor(I32).from_array([3, 2, 1]))
)
)

assert: result.as!(Tensor(Bool)).into_array == [True, True, False]
))

:it "complains if the operands are of different types"
_WithGraphHelper.run(@env, False) -> (g, session |
assert error: g.lesser!("example"
g.const!("x", Tensor(I32).from_array([1, 2, 3]))
g.const!("y", Tensor(I8).from_array([3, 2, 1]))
)
)

:it "complains if the operands are of different sizes"
_WithGraphHelper.run(@env, False) -> (g, session |
assert error: g.lesser!("example"
g.const!("x", Tensor(I32).from_array([1, 2, 3]))
g.const!("y", Tensor(I32).from_array([3, 2, 1, 0]))
)
)

:it "complains if the operands are of different shapes"
_WithGraphHelper.run(@env, False) -> (g, session |
assert error: g.lesser!("example"
g.const!("x", Tensor(I32).from_array([0, 1, 2, 3]))
g.const!("y", Tensor(I32).from_array([3, 2, 1, 0]).try_reshape([2, 2]))
)
)
15 changes: 15 additions & 0 deletions src/Tensor.Graph.Helper.savi
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,21 @@
:fun ref logical_or!(name, x, y)
Tensor.Op.Logical.Or.new!(@graph, name, x, y)

///
// Comparative Operations

:fun ref greater!(name, x, y)
Tensor.Op.Greater.new!(@graph, name, x, y, False)

:fun ref greater_or_equal!(name, x, y)
Tensor.Op.Greater.new!(@graph, name, x, y, True)

:fun ref lesser!(name, x, y)
Tensor.Op.Lesser.new!(@graph, name, x, y, False)

:fun ref lesser_or_equal!(name, x, y)
Tensor.Op.Lesser.new!(@graph, name, x, y, True)

///
// Type Conversions

Expand Down
15 changes: 15 additions & 0 deletions src/Tensor.Op.Greater.savi
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
:: Compares input values from the two operands to check if x is greater than y.
::
:: The two inputs must be numeric tensors of the same type, size, and shape.
:: The output is a boolean tensor indicating the result of each comparison.
::
:: If the `or_equal` parameter is set to true, equal values in the operands
:: will also cause the corresponding output value to be `True`.
:struct box Tensor.Op.Greater
:is Tensor.Op

:fun non new!(graph Tensor.Graph, name, x, y, or_equal = False)
op_name = if or_equal ("GreaterEqual" | "Greater")
@_new(graph.new_operation(op_name, name) -> (builder |
builder.add_input(x).add_input(y).finish!
))
15 changes: 15 additions & 0 deletions src/Tensor.Op.Lesser.savi
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
:: Compares input values from the two operands to check if x is less than y.
::
:: The two inputs must be numeric tensors of the same type, size, and shape.
:: The output is a boolean tensor indicating the result of each comparison.
::
:: If the `or_equal` parameter is set to true, equal values in the operands
:: will also cause the corresponding output value to be `True`.
:struct box Tensor.Op.Lesser
:is Tensor.Op

:fun non new!(graph Tensor.Graph, name, x, y, or_equal = False)
op_name = if or_equal ("LessEqual" | "Less")
@_new(graph.new_operation(op_name, name) -> (builder |
builder.add_input(x).add_input(y).finish!
))