Skip to content

Commit

Permalink
support LayerNorm channel first (C#) (#1204)
Browse files Browse the repository at this point in the history
* support LayerNorm channel first

* fix review : #1204
  • Loading branch information
curioyang committed May 14, 2024
1 parent de4b7bd commit ae5df31
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 8 deletions.
1 change: 1 addition & 0 deletions src/Nncase.Compiler/Compiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ public void TargetIndependentPass(IPassManager passManager)
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern3>();
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern4>();
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern5>();
p.Add<Passes.Rules.Neutral.ConvertLayerNormChannelFirstToLast>();
p.Add<Passes.Rules.Neutral.FoldGeluWithScale>();
p.Add<Passes.Rules.Neutral.FoldGeneralGelu>();
p.Add<Passes.Rules.Neutral.FoldSwishPattern1>();
Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Core/IR/NN/Functional.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public static class NN

public static Call BatchNormalization(Expr input, Expr scale, Expr bias, Expr input_mean, Expr input_var, Expr epsilon, Expr momentum) => new Call(new BatchNormalization(), input, scale, bias, input_mean, input_var, epsilon, momentum);

public static Call LayerNorm(int axis, float epsilon, Expr input, Expr scale, Expr bias, bool hasMean = true) => new Call(new LayerNorm(axis, epsilon, hasMean), input, scale, bias);
public static Call LayerNorm(int axis, float epsilon, Expr input, Expr scale, Expr bias, bool hasMean = true, bool channelFirst = false) => new Call(new LayerNorm(axis, epsilon, hasMean, channelFirst), input, scale, bias);

public static Call BatchToSpace(Expr input, Expr blockShape, Expr crops) => new Call(new BatchToSpace(), input, blockShape, crops);

Expand Down
4 changes: 3 additions & 1 deletion src/Nncase.Core/IR/NN/LayerNorm.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,7 @@ public sealed partial class LayerNorm : Op

public bool UseMean { get; }

public override string DisplayProperty() => $"Axis: {Axis}, Epsilon: {Epsilon}, UseMean: {UseMean}";
public bool ChannelFirst { get; }

public override string DisplayProperty() => $"Axis: {Axis}, Epsilon: {Epsilon}, UseMean: {UseMean}, ChannelFirst: {ChannelFirst}";
}
127 changes: 121 additions & 6 deletions src/Nncase.Passes/Rules/Neutral/FoldLayerNorm.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Linq;
using DryIoc;
using Nncase.IR;
using Nncase.IR.F;
using Nncase.IR.Math;
using Nncase.IR.NN;
using Nncase.IR.Tensors;
using Nncase.PatternMatch;
using static Nncase.IR.F.NN;
using static Nncase.PatternMatch.F.Math;
using static Nncase.PatternMatch.F.NN;
using static Nncase.PatternMatch.F.Tensors;
using static Nncase.PatternMatch.Utility;

Expand Down Expand Up @@ -72,7 +80,14 @@ public sealed partial class FoldLayerNormPattern1 : RewriteRule<CallPattern>
if (subCall[Binary.Lhs] == rd1Call[Reduce.Input] && divCall[Binary.Lhs] == subCall)
{
var axis = addBetaCall.CheckedShape.Count - gamma.CheckedShape.Count;
return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta);
bool cFirst = false;
var axes = rd1Call[Reduce.Axis].Evaluate().AsTensor().ToArray<int>();
if (axes.Length == 1 && axes[0] != input.CheckedShape.Count - 1 && axes[0] != -1)
{
cFirst = true;
}

return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta, channelFirst: cFirst);
}

return null;
Expand Down Expand Up @@ -133,7 +148,14 @@ public sealed partial class FoldLayerNormPattern2 : RewriteRule<CallPattern>
divCall[Binary.Lhs] == subCall)
{
var axis = addBetaCall.CheckedShape.Count - gamma.CheckedShape.Count;
return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta);
bool cFirst = false;
var axes = rd1Call[Reduce.Axis].Evaluate().AsTensor().ToArray<int>();
if (axes.Length == 1 && axes[0] != input.CheckedShape.Count - 1 && axes[0] != -1)
{
cFirst = true;
}

return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta, channelFirst: cFirst);
}

return null;
Expand Down Expand Up @@ -202,7 +224,14 @@ public sealed partial class FoldLayerNormPattern3 : RewriteRule<CallPattern>
mulXCall[Binary.Lhs] == subMuCall[Binary.Lhs] && mulXCall[Binary.Lhs] == rdMuCall[Reduce.Input])
{
var axis = addAllCall.CheckedShape.Count - gamma.CheckedShape.Count;
return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta);
bool cFirst = false;
var axes = rdMuCall[Reduce.Axis].Evaluate().AsTensor().ToArray<int>();
if (axes.Length == 1 && axes[0] != input.CheckedShape.Count - 1 && axes[0] != -1)
{
cFirst = true;
}

return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta, channelFirst: cFirst);
}

return null;
Expand Down Expand Up @@ -273,7 +302,14 @@ public sealed partial class FoldLayerNormPattern4 : RewriteRule<CallPattern>
subCall[Binary.Lhs] == mulXCall[Binary.Lhs] && subCall[Binary.Rhs] == mulMuCall[Binary.Lhs] && mulXCall[Binary.Lhs] == meanCall[Reduce.Input])
{
var axis = addAllCall.CheckedShape.Count - gamma.CheckedShape.Count;
return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta);
bool cFirst = false;
var axes = meanCall[Reduce.Axis].Evaluate().AsTensor().ToArray<int>();
if (axes.Length == 1 && axes[0] != input.CheckedShape.Count - 1 && axes[0] != -1)
{
cFirst = true;
}

return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta, channelFirst: cFirst);
}

return null;
Expand Down Expand Up @@ -321,15 +357,94 @@ public sealed partial class FoldLayerNormPattern5 : RewriteRule<CallPattern>
IsTensorConst("two"))),
IsTensorConst("eps"))))));

private Expr? GetReplace(Call pow2Call, TensorConst eps, TensorConst gamma, Expr input, TensorConst one, TensorConst two)
private Expr? GetReplace(Call pow2Call, Call rdVarCall, TensorConst eps, TensorConst gamma, Expr input, TensorConst one, TensorConst two)
{
if (input == pow2Call[Binary.Lhs] && one.Value.Cast<float>()[0] == 1f && two.Value.Cast<float>()[0] == 2f)
{
var axis = pow2Call.CheckedShape.Count - gamma.CheckedShape.Count;
var beta = Tensor.FromScalar(0f, gamma.CheckedShape);
return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta, hasMean: false);
bool cFirst = false;
var axes = rdVarCall[Reduce.Axis].Evaluate().AsTensor().ToArray<int>();
if (axes.Length == 1 && axes[0] != input.CheckedShape.Count - 1 && axes[0] != -1)
{
cFirst = true;
}

return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta, hasMean: false, channelFirst: cFirst);
}

return null;
}
}

[RuleGenerator]
public sealed partial class ConvertLayerNormChannelFirstToLast : RewriteRule<CallPattern>
{
    /// <summary>
    /// Gets the pattern. Matches every LayerNorm call; whether a rewrite is
    /// actually needed (ChannelFirst set and axis not already last) is decided
    /// in <see cref="GetReplace"/>.
    /// </summary>
    public override CallPattern Pattern { get; } =
        IsLayerNorm(
            "ln",
            "_",
            _ => true,
            IsWildcard("x"),
            IsWildcard("scale"),
            IsWildcard("bias"));

    /// <summary>
    /// Builds the permutation that moves <paramref name="axis"/> to the last
    /// position while preserving the relative order of the other dimensions,
    /// e.g. axis = 1, shapeSize = 4 -> [0, 2, 3, 1].
    /// </summary>
    private static List<int> GetPermWithAxis(long axis, int shapeSize)
    {
        var perm = new List<int>();
        for (int i = 0; i < shapeSize; i++)
        {
            if (i != axis)
            {
                perm.Add(i);
            }
        }

        perm.Add((int)axis);
        return perm;
    }

    /// <summary>
    /// Rewrites a channel-first LayerNorm into transpose -> channel-last
    /// LayerNorm -> inverse transpose, so later stages only have to handle the
    /// last-axis case. Returns null when no rewrite is needed.
    /// </summary>
    private Expr? GetReplace(LayerNorm ln, Expr x, Expr scale, Expr bias)
    {
        if (!ln.ChannelFirst)
        {
            return null;
        }

        int axis = ln.Axis;
        float eps = ln.Epsilon;
        bool useMean = ln.UseMean;

        // Already normalizing over the last axis: nothing to do.
        if ((axis == x.CheckedShape.Count - 1) || (axis == -1))
        {
            return null;
        }

        var inPerm = GetPermWithAxis(axis, x.CheckedShape.Count);

        // BUG FIX: the output permutation must be the exact inverse of inPerm
        // so the final transpose restores the original layout. The previous
        // code used inPerm[inPerm[i]] (inPerm composed with itself), which
        // coincides with the inverse only when the moved-axis cycle has length
        // 3 (e.g. rank 4, axis 1). For rank 3 with axis 1 it yielded the
        // identity permutation, leaving the result in the wrong layout.
        // Inverse permutation: outPerm[inPerm[i]] = i.
        var outPerm = new int[inPerm.Count];
        for (int i = 0; i < inPerm.Count; i++)
        {
            outPerm[inPerm[i]] = i;
        }

        var newScale = scale;
        var newBias = bias;

        // Non-1-D scale/bias follow the input layout, so they are transposed
        // too; the permutation of scale and bias must be the same.
        if (scale.CheckedShape.Count != 1 && bias.CheckedShape.Count != 1)
        {
            int axisGap = x.CheckedShape.Count - scale.CheckedShape.Count;
            if (axisGap > axis)
            {
                // Never reach here: the normalized axis cannot lie outside the
                // dimensions covered by scale/bias.
                return null;
            }

            var constPerm = GetPermWithAxis(axis - axisGap, scale.CheckedShape.Count);

            newScale = Tensors.Transpose(scale, constPerm.ToArray());
            newBias = Tensors.Transpose(bias, constPerm.ToArray());
        }

        // NOTE(review): ChannelFirst stays true on the rewritten node, as in
        // the original code; the rule cannot re-fire because the new axis is
        // the last one. Confirm downstream consumers expect the flag kept set.
        return Tensors.Transpose(
            LayerNorm(x.CheckedShape.Count - 1, eps, Tensors.Transpose(x, inPerm.ToArray()), newScale, newBias, useMean, true),
            outPerm);
    }
}

0 comments on commit ae5df31

Please sign in to comment.