From 52d483fe23ee625c5f470c7b12bb811f7ccbb3a8 Mon Sep 17 00:00:00 2001 From: Joe Eli McIlvain Date: Fri, 31 Mar 2023 22:17:40 -0700 Subject: [PATCH] Add `Tensor.Op.Greater` and `Tensor.Op.Lesser`. These compare numeric tensors and output boolean tensors. --- spec/Main.savi | 2 ++ spec/Tensor.Op.Greater.Spec.savi | 51 ++++++++++++++++++++++++++++++++ spec/Tensor.Op.Lesser.Spec.savi | 51 ++++++++++++++++++++++++++++++++ src/Tensor.Graph.Helper.savi | 15 ++++++++++ src/Tensor.Op.Greater.savi | 15 ++++++++++ src/Tensor.Op.Lesser.savi | 15 ++++++++++ 6 files changed, 149 insertions(+) create mode 100644 spec/Tensor.Op.Greater.Spec.savi create mode 100644 spec/Tensor.Op.Lesser.Spec.savi create mode 100644 src/Tensor.Op.Greater.savi create mode 100644 src/Tensor.Op.Lesser.savi diff --git a/spec/Main.savi b/spec/Main.savi index f907590..7d27fe4 100644 --- a/spec/Main.savi +++ b/spec/Main.savi @@ -6,6 +6,8 @@ Spec.Run(Tensor.Op.Bitcast.Spec).new(env) Spec.Run(Tensor.Op.Cast.Spec).new(env) Spec.Run(Tensor.Op.Const.Spec).new(env) + Spec.Run(Tensor.Op.Greater.Spec).new(env) + Spec.Run(Tensor.Op.Lesser.Spec).new(env) Spec.Run(Tensor.Op.Logical.Spec).new(env) Spec.Run(Tensor.Op.MatMul.Spec).new(env) Spec.Run(Tensor.Op.Softmax.Spec).new(env) diff --git a/spec/Tensor.Op.Greater.Spec.savi b/spec/Tensor.Op.Greater.Spec.savi new file mode 100644 index 0000000..27693d0 --- /dev/null +++ b/spec/Tensor.Op.Greater.Spec.savi @@ -0,0 +1,51 @@ +:class Tensor.Op.Greater.Spec + :is Spec + :const describes: "Tensor.Op.Greater" + + :it "checks if the 1st operand's values are greater than those of the 2nd" + _WithGraphHelper.run(@env) -> (g, session | assert no_error: ( + result = session.compute!( + g.greater!("example" + g.const!("x", Tensor(I32).from_array([1, 2, 3])) + g.const!("y", Tensor(I32).from_array([3, 2, 1])) + ) + ) + + assert: result.as!(Tensor(Bool)).into_array == [False, False, True] + )) + + :it "may optionally include equal values as being true" + _WithGraphHelper.run(@env) -> (g, session | assert no_error: ( + result = session.compute!( + g.greater_or_equal!("example" + g.const!("x", Tensor(I32).from_array([1, 2, 3])) + g.const!("y", Tensor(I32).from_array([3, 2, 1])) + ) + ) + + assert: result.as!(Tensor(Bool)).into_array == [False, True, True] + )) + + :it "complains if the operands are of different types" + _WithGraphHelper.run(@env, False) -> (g, session | + assert error: g.greater!("example" + g.const!("x", Tensor(I32).from_array([1, 2, 3])) + g.const!("y", Tensor(I8).from_array([3, 2, 1])) + ) + ) + + :it "complains if the operands are of different sizes" + _WithGraphHelper.run(@env, False) -> (g, session | + assert error: g.greater!("example" + g.const!("x", Tensor(I32).from_array([1, 2, 3])) + g.const!("y", Tensor(I32).from_array([3, 2, 1, 0])) + ) + ) + + :it "complains if the operands are of different shapes" + _WithGraphHelper.run(@env, False) -> (g, session | + assert error: g.greater!("example" + g.const!("x", Tensor(I32).from_array([0, 1, 2, 3])) + g.const!("y", Tensor(I32).from_array([3, 2, 1, 0]).try_reshape([2, 2])) + ) + ) diff --git a/spec/Tensor.Op.Lesser.Spec.savi b/spec/Tensor.Op.Lesser.Spec.savi new file mode 100644 index 0000000..656ad83 --- /dev/null +++ b/spec/Tensor.Op.Lesser.Spec.savi @@ -0,0 +1,51 @@ +:class Tensor.Op.Lesser.Spec + :is Spec + :const describes: "Tensor.Op.Lesser" + + :it "checks if the 1st operand's values are lesser than those of the 2nd" + _WithGraphHelper.run(@env) -> (g, session | assert no_error: ( + result = session.compute!( + g.lesser!("example" + g.const!("x", Tensor(I32).from_array([1, 2, 3])) + g.const!("y", Tensor(I32).from_array([3, 2, 1])) + ) + ) + + assert: result.as!(Tensor(Bool)).into_array == [True, False, False] + )) + + :it "may optionally include equal values as being true" + _WithGraphHelper.run(@env) -> (g, session | assert no_error: ( + result = session.compute!( + g.lesser_or_equal!("example" + g.const!("x", Tensor(I32).from_array([1, 2, 3])) + g.const!("y", Tensor(I32).from_array([3, 2, 1])) + ) + ) + + assert: result.as!(Tensor(Bool)).into_array == [True, True, False] + )) + + :it "complains if the operands are of different types" + _WithGraphHelper.run(@env, False) -> (g, session | + assert error: g.lesser!("example" + g.const!("x", Tensor(I32).from_array([1, 2, 3])) + g.const!("y", Tensor(I8).from_array([3, 2, 1])) + ) + ) + + :it "complains if the operands are of different sizes" + _WithGraphHelper.run(@env, False) -> (g, session | + assert error: g.lesser!("example" + g.const!("x", Tensor(I32).from_array([1, 2, 3])) + g.const!("y", Tensor(I32).from_array([3, 2, 1, 0])) + ) + ) + + :it "complains if the operands are of different shapes" + _WithGraphHelper.run(@env, False) -> (g, session | + assert error: g.lesser!("example" + g.const!("x", Tensor(I32).from_array([0, 1, 2, 3])) + g.const!("y", Tensor(I32).from_array([3, 2, 1, 0]).try_reshape([2, 2])) + ) + ) diff --git a/src/Tensor.Graph.Helper.savi b/src/Tensor.Graph.Helper.savi index 637429c..eba2918 100644 --- a/src/Tensor.Graph.Helper.savi +++ b/src/Tensor.Graph.Helper.savi @@ -24,6 +24,21 @@ :fun ref logical_or!(name, x, y) Tensor.Op.Logical.Or.new!(@graph, name, x, y) + /// + // Comparative Operations + + :fun ref greater!(name, x, y) + Tensor.Op.Greater.new!(@graph, name, x, y, False) + + :fun ref greater_or_equal!(name, x, y) + Tensor.Op.Greater.new!(@graph, name, x, y, True) + + :fun ref lesser!(name, x, y) + Tensor.Op.Lesser.new!(@graph, name, x, y, False) + + :fun ref lesser_or_equal!(name, x, y) + Tensor.Op.Lesser.new!(@graph, name, x, y, True) + /// // Type Conversions diff --git a/src/Tensor.Op.Greater.savi b/src/Tensor.Op.Greater.savi new file mode 100644 index 0000000..c04f148 --- /dev/null +++ b/src/Tensor.Op.Greater.savi @@ -0,0 +1,15 @@ +:: Compares input values from the two operands to check if x is greater than y. +:: +:: The two inputs must be numeric tensors of the same type, size, and shape. +:: The output is a boolean tensor indicating the result of each comparison. +:: +:: If the `or_equal` parameter is set to true, equal values in the operands +:: will also cause the corresponding output value to be `True`. +:struct box Tensor.Op.Greater + :is Tensor.Op + + :fun non new!(graph Tensor.Graph, name, x, y, or_equal = False) + op_name = if or_equal ("GreaterEqual" | "Greater") + @_new(graph.new_operation(op_name, name) -> (builder | + builder.add_input(x).add_input(y).finish! + )) diff --git a/src/Tensor.Op.Lesser.savi b/src/Tensor.Op.Lesser.savi new file mode 100644 index 0000000..bbd9c74 --- /dev/null +++ b/src/Tensor.Op.Lesser.savi @@ -0,0 +1,15 @@ +:: Compares input values from the two operands to check if x is less than y. +:: +:: The two inputs must be numeric tensors of the same type, size, and shape. +:: The output is a boolean tensor indicating the result of each comparison. +:: +:: If the `or_equal` parameter is set to true, equal values in the operands +:: will also cause the corresponding output value to be `True`. +:struct box Tensor.Op.Lesser + :is Tensor.Op + + :fun non new!(graph Tensor.Graph, name, x, y, or_equal = False) + op_name = if or_equal ("LessEqual" | "Less") + @_new(graph.new_operation(op_name, name) -> (builder | + builder.add_input(x).add_input(y).finish! + ))