-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add
Tensor.Op.Greater
and Tensor.Op.Lesser
.
These compare numeric tensors and output boolean tensors.
- Loading branch information
Showing
6 changed files
with
149 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
:class Tensor.Op.Greater.Spec | ||
:is Spec | ||
:const describes: "Tensor.Op.Greater" | ||
|
||
:it "checks if the 1st operand's values are greater than those of the 2nd" | ||
_WithGraphHelper.run(@env) -> (g, session | assert no_error: ( | ||
result = session.compute!( | ||
g.greater!("example" | ||
g.const!("x", Tensor(I32).from_array([1, 2, 3])) | ||
g.const!("y", Tensor(I32).from_array([3, 2, 1])) | ||
) | ||
) | ||
|
||
assert: result.as!(Tensor(Bool)).into_array == [False, False, True] | ||
)) | ||
|
||
:it "may optionally include equal values as being true" | ||
_WithGraphHelper.run(@env) -> (g, session | assert no_error: ( | ||
result = session.compute!( | ||
g.greater_or_equal!("example" | ||
g.const!("x", Tensor(I32).from_array([1, 2, 3])) | ||
g.const!("y", Tensor(I32).from_array([3, 2, 1])) | ||
) | ||
) | ||
|
||
assert: result.as!(Tensor(Bool)).into_array == [False, True, True] | ||
)) | ||
|
||
:it "complains if the operands are of different types" | ||
_WithGraphHelper.run(@env, False) -> (g, session | | ||
assert error: g.greater!("example" | ||
g.const!("x", Tensor(I32).from_array([1, 2, 3])) | ||
g.const!("y", Tensor(I8).from_array([3, 2, 1])) | ||
) | ||
) | ||
|
||
:it "complains if the operands are of different sizes" | ||
_WithGraphHelper.run(@env, False) -> (g, session | | ||
assert error: g.greater!("example" | ||
g.const!("x", Tensor(I32).from_array([1, 2, 3])) | ||
g.const!("y", Tensor(I32).from_array([3, 2, 1, 0])) | ||
) | ||
) | ||
|
||
:it "complains if the operands are of different shapes" | ||
_WithGraphHelper.run(@env, False) -> (g, session | | ||
assert error: g.greater!("example" | ||
g.const!("x", Tensor(I32).from_array([0, 1, 2, 3])) | ||
g.const!("y", Tensor(I32).from_array([3, 2, 1, 0]).try_reshape([2, 2])) | ||
) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
:class Tensor.Op.Lesser.Spec | ||
:is Spec | ||
:const describes: "Tensor.Op.Lesser" | ||
|
||
:it "checks if the 1st operand's values are lesser than those of the 2nd" | ||
_WithGraphHelper.run(@env) -> (g, session | assert no_error: ( | ||
result = session.compute!( | ||
g.lesser!("example" | ||
g.const!("x", Tensor(I32).from_array([1, 2, 3])) | ||
g.const!("y", Tensor(I32).from_array([3, 2, 1])) | ||
) | ||
) | ||
|
||
assert: result.as!(Tensor(Bool)).into_array == [True, False, False] | ||
)) | ||
|
||
:it "may optionally include equal values as being true" | ||
_WithGraphHelper.run(@env) -> (g, session | assert no_error: ( | ||
result = session.compute!( | ||
g.lesser_or_equal!("example" | ||
g.const!("x", Tensor(I32).from_array([1, 2, 3])) | ||
g.const!("y", Tensor(I32).from_array([3, 2, 1])) | ||
) | ||
) | ||
|
||
assert: result.as!(Tensor(Bool)).into_array == [True, True, False] | ||
)) | ||
|
||
:it "complains if the operands are of different types" | ||
_WithGraphHelper.run(@env, False) -> (g, session | | ||
assert error: g.lesser!("example" | ||
g.const!("x", Tensor(I32).from_array([1, 2, 3])) | ||
g.const!("y", Tensor(I8).from_array([3, 2, 1])) | ||
) | ||
) | ||
|
||
:it "complains if the operands are of different sizes" | ||
_WithGraphHelper.run(@env, False) -> (g, session | | ||
assert error: g.lesser!("example" | ||
g.const!("x", Tensor(I32).from_array([1, 2, 3])) | ||
g.const!("y", Tensor(I32).from_array([3, 2, 1, 0])) | ||
) | ||
) | ||
|
||
:it "complains if the operands are of different shapes" | ||
_WithGraphHelper.run(@env, False) -> (g, session | | ||
assert error: g.lesser!("example" | ||
g.const!("x", Tensor(I32).from_array([0, 1, 2, 3])) | ||
g.const!("y", Tensor(I32).from_array([3, 2, 1, 0]).try_reshape([2, 2])) | ||
) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
:: Compares input values from the two operands to check if x is greater than y. | ||
:: | ||
:: The two inputs must be numeric tensors of the same type, size, and shape. | ||
:: The output is a boolean tensor indicating the result of each comparison. | ||
:: | ||
:: If the `or_equal` parameter is set to true, equal values in the operands | ||
:: will also cause the corresponding output value to be `True`. | ||
:struct box Tensor.Op.Greater | ||
:is Tensor.Op | ||
|
||
:fun non new!(graph Tensor.Graph, name, x, y, or_equal = False) | ||
op_name = if or_equal ("GreaterEqual" | "Greater") | ||
@_new(graph.new_operation(op_name, name) -> (builder | | ||
builder.add_input(x).add_input(y).finish! | ||
)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
:: Compares input values from the two operands to check if x is less than y. | ||
:: | ||
:: The two inputs must be numeric tensors of the same type, size, and shape. | ||
:: The output is a boolean tensor indicating the result of each comparison. | ||
:: | ||
:: If the `or_equal` parameter is set to true, equal values in the operands | ||
:: will also cause the corresponding output value to be `True`. | ||
:struct box Tensor.Op.Lesser | ||
:is Tensor.Op | ||
|
||
:fun non new!(graph Tensor.Graph, name, x, y, or_equal = False) | ||
op_name = if or_equal ("LessEqual" | "Less") | ||
@_new(graph.new_operation(op_name, name) -> (builder | | ||
builder.add_input(x).add_input(y).finish! | ||
)) |