Skip to content

Commit

Permalink
Feature/add squeeze binary shape (#1059)
Browse files Browse the repository at this point in the history
* add SqueezeBinaryShape

* fix np.bool

* format

* recover binary test

* add unit test for SqueezeBinaryShape

* fix pytest: data with [0,1) maybe occur 'x div zero'

* fix review

---------

Co-authored-by: yanghaoqi <yanghaoqi_intern@canaan-creative.com>
  • Loading branch information
curioyang and yanghaoqi committed Aug 19, 2023
1 parent 9026361 commit 5f64821
Show file tree
Hide file tree
Showing 5 changed files with 252 additions and 46 deletions.
42 changes: 22 additions & 20 deletions src/Nncase.Compiler/Compiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -97,26 +97,28 @@ public void TargetIndependentPass(IPassManager passManager)
});

passManager.AddWithName<DataflowPass>("SqueezeShape").Configure(p =>
{
p.Add<Passes.Rules.Neutral.SqueezeTransposeShape>();
p.Add<Passes.Rules.Neutral.Squeeze5DTranspose>();
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern1>();
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern2>();
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern3>();
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern4>();
p.Add<Passes.Rules.Neutral.FoldGeluWithScale>();
p.Add<Passes.Rules.Neutral.FoldGeneralGelu>();
p.Add<Passes.Rules.Neutral.FoldSwishPattern1>();
p.Add<Passes.Rules.Neutral.FoldSwishPattern2>();
p.Add<Passes.Rules.Neutral.FoldHardSwish1>();
p.Add<Passes.Rules.Neutral.FoldHardSwish2>();
p.Add<Passes.Rules.Neutral.FoldHardSwish3>();
p.Add<Passes.Rules.Neutral.FoldHardSwish4>();
p.Add<Passes.Rules.Neutral.FoldHardSwish5>();
p.Add<Passes.Rules.Neutral.FoldTwoSlices>();
p.Add<Passes.Rules.Neutral.FocusFull>();
p.Add<Passes.Rules.Neutral.ReshapeMatMul>();
});
{
p.Add<Passes.Rules.Neutral.SqueezeTransposeShape>();
p.Add<Passes.Rules.Neutral.Squeeze5DTranspose>();
p.Add<Passes.Rules.Neutral.SqueezeBinaryShape>();
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern1>();
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern2>();
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern3>();
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern4>();
p.Add<Passes.Rules.Neutral.FoldGeluWithScale>();
p.Add<Passes.Rules.Neutral.FoldGeneralGelu>();
p.Add<Passes.Rules.Neutral.FoldSwishPattern1>();
p.Add<Passes.Rules.Neutral.FoldSwishPattern2>();
p.Add<Passes.Rules.Neutral.FoldHardSwish1>();
p.Add<Passes.Rules.Neutral.FoldHardSwish2>();
p.Add<Passes.Rules.Neutral.FoldHardSwish3>();
p.Add<Passes.Rules.Neutral.FoldHardSwish4>();
p.Add<Passes.Rules.Neutral.FoldHardSwish5>();
p.Add<Passes.Rules.Neutral.FoldTwoSlices>();
p.Add<Passes.Rules.Neutral.FocusFull>();
p.Add<Passes.Rules.Neutral.ReshapeMatMul>();
});

passManager.AddWithName<EGraphRulesPass>("NeutralOptimizeTranspose").Configure(p =>
{
p.Add<Passes.Rules.Neutral.FoldConstCall>();
Expand Down
156 changes: 156 additions & 0 deletions src/Nncase.Passes/Rules/Neutral/SqueezeShape.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using Nncase.IR.NN;
using Nncase.IR.Tensors;
using Nncase.PatternMatch;
using static Nncase.IR.F.Math;
using static Nncase.IR.F.Tensors;
using static Nncase.IR.TypePatternUtility;
using static Nncase.PatternMatch.F.Math;
Expand Down Expand Up @@ -245,3 +246,158 @@ public sealed partial class SqueezeTransposeShape : IRewriteRule
return Reshape(Transpose(Reshape(input, new_shape.ToArray()), new_perm.ToArray()), newOutputShape);
}
}

[RuleGenerator]
public sealed partial class SqueezeBinaryShape : IRewriteRule
{
/// <inheritdoc/>
public IPattern Pattern { get; } = IsBinary("binary", "binaryCall", x => true, IsWildcard("lhs") with { TypePattern = HasFixedShape() }, IsWildcard("rhs") with { TypePattern = HasFixedShape() });

/// <summary>
/// Squeeze input shape.
/// </summary>
/// <param name="a"> left input shape.</param>
/// <param name="b"> right input shape.</param>
/// <returns> Squeeze flag, new lhs, new rhs. </returns>
public (bool SqueezeOrNot, List<int> NewAShape, List<int> NewBShape) SqueezeInputShape(List<int> a, List<int> b)
{
var aSize = a.Count;
var bSize = b.Count;

var squeezeTimes = Math.Max(
aSize > 4 ? aSize - 4 : 0,
bSize > 4 ? bSize - 4 : 0);

if (squeezeTimes <= 0)
{
return (false, a, b);
}

List<int> newA = a;
List<int> newB = b;

if (aSize == bSize)
{
if (a.SequenceEqual(b))
{
newA = SqueezeShape(a);
newB = SqueezeShape(b);
}
else
{
var canFold = Enumerable.Repeat(true, aSize).ToArray();
var foldIndexCouples = new List<(int, int)>();

for (int i = 0; i < aSize; i++)
{
if (a[i] != b[i])
{
canFold[i] = false;
}
}

for (int i = aSize - 1; i > 0; i--)
{
if (canFold[i] && canFold[i - 1])
{
foldIndexCouples.Add((i - 1, i));
}
}

while (squeezeTimes > 0 && foldIndexCouples.Count > 0)
{
var (front, back) = foldIndexCouples[0];
newA[front] *= newA[back];
newB[front] *= newB[back];

newA.RemoveAt(back);
newB.RemoveAt(back);

foldIndexCouples.RemoveAt(0);
squeezeTimes--;
}

for (int i = newA.Count - 1, count = newA.Count - 5; i >= 0 && count >= 0; i--)
{
if (newA[i] * newB[i] == 1)
{
newA.RemoveAt(i);
newB.RemoveAt(i);
count--;
}
}

if (newA.Count > 4)
{
return (false, newA, newB);
}
}
}
else
{
if (aSize != 1)
{
newA = SqueezeShape(a);
}

if (bSize != 1)
{
newB = SqueezeShape(b);
}
}

return (true, newA, newB);
}

private static List<int> SqueezeShape(List<int> shape)
{
var newShape = new List<int> { 1, 1, 1, 1 };

for (int i = shape.Count - 1, k = 3; i >= 0; i--)
{
newShape[k] *= shape[i];
if (k > 0)
{
k--;
}
}

return newShape;
}

private static List<int> GetOutputShape(List<int> a, List<int> b)
{
if (a.Count == 1)
{
return b;
}

if (b.Count == 1)
{
return a;
}

var outputShape = a;
for (int i = 0; i < a.Count; i++)
{
outputShape[i] = Math.Max(a[i], b[i]);
}

return outputShape;
}

private Expr? GetReplace(Binary binary, Call binaryCall, Expr lhs, Expr rhs)
{
var lShape = lhs.CheckedShape.Count == 0 ? new Shape(new List<int> { 1 }) : lhs.CheckedShape;
var rShape = rhs.CheckedShape.Count == 0 ? new Shape(new List<int> { 1 }) : rhs.CheckedShape;
var (result, newLShape, newRShape) = SqueezeInputShape(lShape.ToValueList(), rShape.ToValueList());
if (!result)
{
return null;
}

var outputShape = GetOutputShape(lShape.ToValueList(), rShape.ToValueList());

return Reshape(Binary(binary.BinaryOp, Reshape(lhs, newLShape.ToArray()), Reshape(rhs, newRShape.ToArray())), outputShape.ToArray());
}
}
51 changes: 51 additions & 0 deletions src/Nncase.Tests/Rules/Neutral/UnitTestSqueezeTransposeShape.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Text;
using System.Threading.Tasks;
using Nncase.Diagnostics;
using Nncase.IR.Math;
using Nncase.Passes;
using Nncase.Passes.Rules.Neutral;
using Nncase.Tests.TestFixture;
Expand Down Expand Up @@ -60,3 +61,53 @@ public void TestSqueezeTransposeShapeNegative(int[] shape, int[] perm)
TestNotMatch<SqueezeTransposeShape>(rootPre);
}
}

public class UnitTestSqueezeBinaryShape : TransformTestBase
{
public static IEnumerable<object[]> TestSqueezeBinaryShapePosivateData =>
new[]
{
new object[] { new[] { 1 }, new[] { 1, 2, 4, 8, 3 } },
new object[] { new[] { 1, 2, 4, 8, 3 }, new[] { 1 } },
new object[] { new[] { 1, 2, 4, 8, 3 }, new[] { 1, 1, 4, 1, 1 } },
new object[] { new[] { 1, 2, 4, 8, 3 }, new[] { 1, 1, 4, 8, 1 } },
new object[] { new[] { 1, 2, 4, 8, 3 }, new[] { 1, 2, 4, 8, 3 } },
new object[] { new[] { 1, 2, 1, 8, 1 }, new[] { 3, 1, 6, 1, 1 } },
new object[] { new[] { 1, 2, 4, 1, 3, 1, 3 }, new[] { 1, 2, 4, 1, 1, 5, 1 } },
new object[] { new[] { 2, 3, 4, 8, 3, 5, 3, 5 }, new[] { 2, 3, 4, 8, 1, 5, 3, 5 } },
};

public static IEnumerable<object[]> TestSqueezeBinaryShapeNegativeData =>
new[]
{
new object[] { new[] { 2 }, new[] { 2 } },
new object[] { new[] { 1, 2 }, new[] { 1 } },
new object[] { new[] { 1, 2 }, new[] { 1, 2 } },
new object[] { new[] { 2 }, new[] { 1 } },
new object[] { new[] { 1, 2, 4 }, new[] { 1 } },
new object[] { new[] { 1, 2, 4 }, new[] { 1, 2, 1 } },
new object[] { new[] { 1, 2, 4 }, new[] { 1, 2, 4 } },
new object[] { new[] { 1, 2, 1, 5, 8 }, new[] { 3, 1, 4, 1, 8 } },
new object[] { new[] { 2, 3, 4, 8, 3, 5, 3, 5 }, new[] { 2, 1, 4, 8, 1, 5, 3, 5 } },
};

[Theory]
[MemberData(nameof(TestSqueezeBinaryShapePosivateData))]
public void TestSqueezeBinaryShapePositivate(int[] lShape, int[] rShape)
{
var a = Random.Normal(DataTypes.Float32, 0, 1, 0, lShape);
var b = Random.Normal(DataTypes.Float32, 0, 1, 0, rShape);
var rootPre = Math.Binary(BinaryOp.Add, a, b);
TestMatched<SqueezeBinaryShape>(rootPre);
}

[Theory]
[MemberData(nameof(TestSqueezeBinaryShapeNegativeData))]
public void TestSqueezeBinaryShapeNegative(int[] lShape, int[] rShape)
{
var a = Random.Normal(DataTypes.Float32, 0, 1, 0, lShape);
var b = Random.Normal(DataTypes.Float32, 0, 1, 0, rShape);
var rootPre = Math.Binary(BinaryOp.Add, a, b);
TestNotMatch<SqueezeBinaryShape>(rootPre);
}
}
2 changes: 1 addition & 1 deletion tests/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def from_random(self, shape: List[int], dtype: np.dtype, abs: bool = False) -> n
data = np.random.randint(0, 256, shape)
elif dtype == np.int8:
data = np.random.randint(-128, 128, shape)
elif dtype == np.bool:
elif dtype == bool:
data = np.random.rand(*shape) > 0.5
elif dtype == np.int32:
data = np.random.randint(1, 5, size=shape, dtype='int32')
Expand Down
47 changes: 22 additions & 25 deletions tests/importer/onnx_/basic/test_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,45 +23,42 @@ def _make_module(v_shape):
class BinaryModule(torch.nn.Module):
def __init__(self):
super(BinaryModule, self).__init__()
# self.v = torch.from_numpy(np.random.rand(*v_shape).astype(np.float32))
self.v = torch.from_numpy(np.ones(v_shape).astype(np.float32))

def forward(self, x):
outs = []
outs.append(torch.add(x, self.v))
# outs.append(torch.mul(x, self.v))
# outs.append(torch.sub(x, self.v))
# outs.append(torch.max(x, self.v))
# outs.append(torch.div(x, self.v))
# outs.append(torch.min(x, self.v))
# outs.append(torch.fmod(x, self.v))
outs.append(torch.mul(x, self.v))
outs.append(torch.sub(x, self.v))
outs.append(torch.max(x, self.v))
outs.append(torch.div(x, self.v))
outs.append(torch.min(x, self.v))
outs.append(torch.fmod(x, self.v))
return outs

return BinaryModule()


lhs_shapes = [
# [3],
# [64, 3],
# [3, 64, 3],
# [8, 3, 64, 3]
[1, 3, 24, 24]
[3],
[64, 3],
[3, 64, 3],
[8, 3, 64, 3],
]

rhs_shapes = [
# [1],
# [3],
# [1, 3],
# [64, 1],
# [64, 3],
# [3, 64, 1],
# [3, 64, 3],
# [8, 3, 64, 1],
# [8, 3, 64, 3],
# [8, 3, 1, 3],
# [8, 1, 64, 3],
# [1, 3, 64, 1]
[1, 3, 24, 24]
[1],
[3],
[1, 3],
[64, 1],
[64, 3],
[3, 64, 1],
[3, 64, 3],
[8, 3, 64, 1],
[8, 3, 64, 3],
[8, 3, 1, 3],
[8, 1, 64, 3],
[1, 3, 64, 1],
]


Expand Down

0 comments on commit 5f64821

Please sign in to comment.