support layernorm channel first(C#) #1204

Merged · 2 commits · May 14, 2024
1 change: 1 addition & 0 deletions src/Nncase.Compiler/Compiler.cs
@@ -108,6 +108,7 @@ public void TargetIndependentPass(IPassManager passManager)
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern3>();
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern4>();
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern5>();
p.Add<Passes.Rules.Neutral.ConvertLayerNormChannelFirstToLast>();
p.Add<Passes.Rules.Neutral.FoldGeluWithScale>();
p.Add<Passes.Rules.Neutral.FoldGeneralGelu>();
p.Add<Passes.Rules.Neutral.FoldSwishPattern1>();
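The conversion rule is registered immediately after the five FoldLayerNormPattern rules, so a LayerNorm that a fold pattern reconstructs with ChannelFirst = true is normalized to channel-last within the same target-independent pass. Conceptually, the intended rewrite chain looks like this (a sketch, not code from the PR):

// Decomposed mean/variance subgraph (e.g. from an NCHW ONNX import)
//   --FoldLayerNormPattern1..5-->  LayerNorm(axis: 1, ChannelFirst: true)
//   --ConvertLayerNormChannelFirstToLast-->
//     Transpose (channel to last) -> LayerNorm (last axis) -> Transpose (back)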
2 changes: 1 addition & 1 deletion src/Nncase.Core/IR/NN/Functional.cs
@@ -34,7 +34,7 @@ public static class NN

public static Call BatchNormalization(Expr input, Expr scale, Expr bias, Expr input_mean, Expr input_var, Expr epsilon, Expr momentum) => new Call(new BatchNormalization(), input, scale, bias, input_mean, input_var, epsilon, momentum);

public static Call LayerNorm(int axis, float epsilon, Expr input, Expr scale, Expr bias, bool hasMean = true) => new Call(new LayerNorm(axis, epsilon, hasMean), input, scale, bias);
public static Call LayerNorm(int axis, float epsilon, Expr input, Expr scale, Expr bias, bool hasMean = true, bool channelFirst = false) => new Call(new LayerNorm(axis, epsilon, hasMean, channelFirst), input, scale, bias);

public static Call BatchToSpace(Expr input, Expr blockShape, Expr crops) => new Call(new BatchToSpace(), input, blockShape, crops);

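For reference, a minimal sketch of the new overload in use; the Expr operands and the NCHW layout are assumptions for illustration:

// Hypothetical usage: rank-4 NCHW input, normalized over the channel axis.
// input, scale and bias are Exprs built elsewhere.
var ln = NN.LayerNorm(axis: 1, epsilon: 1e-5f, input, scale, bias, hasMean: true, channelFirst: true);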
4 changes: 3 additions & 1 deletion src/Nncase.Core/IR/NN/LayerNorm.cs
@@ -39,5 +39,7 @@ public sealed partial class LayerNorm : Op

public bool UseMean { get; }

public override string DisplayProperty() => $"Axis: {Axis}, Epsilon: {Epsilon}, UseMean: {UseMean}";
public bool ChannelFirst { get; }

public override string DisplayProperty() => $"Axis: {Axis}, Epsilon: {Epsilon}, UseMean: {UseMean}, ChannelFirst: {ChannelFirst}";
}
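With the new property, DisplayProperty surfaces the layout alongside the other attributes, for example (illustrative values):

// For Axis = 1, Epsilon = 1e-5f, UseMean = true, ChannelFirst = true:
// "Axis: 1, Epsilon: 1E-05, UseMean: True, ChannelFirst: True"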
127 changes: 121 additions & 6 deletions src/Nncase.Passes/Rules/Neutral/FoldLayerNorm.cs
@@ -1,11 +1,19 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Linq;
using DryIoc;
using Nncase.IR;
using Nncase.IR.F;
using Nncase.IR.Math;
using Nncase.IR.NN;
using Nncase.IR.Tensors;
using Nncase.PatternMatch;
using static Nncase.IR.F.NN;
using static Nncase.PatternMatch.F.Math;
using static Nncase.PatternMatch.F.NN;
using static Nncase.PatternMatch.F.Tensors;
using static Nncase.PatternMatch.Utility;

@@ -72,7 +80,14 @@ public sealed partial class FoldLayerNormPattern1 : RewriteRule<CallPattern>
if (subCall[Binary.Lhs] == rd1Call[Reduce.Input] && divCall[Binary.Lhs] == subCall)
{
var axis = addBetaCall.CheckedShape.Count - gamma.CheckedShape.Count;
return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta);
bool cFirst = false;
var axes = rd1Call[Reduce.Axis].Evaluate().AsTensor().ToArray<int>();
if (axes.Length == 1 && axes[0] != input.CheckedShape.Count - 1 && axes[0] != -1)
{
cFirst = true;
}

return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta, channelFirst: cFirst);
}

return null;
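Each of the five fold patterns now applies the same layout test to the Reduce that computes the statistics: the fused LayerNorm is flagged channel-first when the reduction covers exactly one axis and that axis is neither the last axis nor -1. A worked check, assuming a rank-4 input:

// axes = [1], rank 4: 1 != 3 and 1 != -1 -> cFirst = true  (NCHW-style)
// axes = [3], rank 4: 3 == 3              -> cFirst = false (already channel-last)
// axes = [2, 3]: more than one reduce axis -> cFirst = false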
@@ -133,7 +148,14 @@ public sealed partial class FoldLayerNormPattern2 : RewriteRule<CallPattern>
divCall[Binary.Lhs] == subCall)
{
var axis = addBetaCall.CheckedShape.Count - gamma.CheckedShape.Count;
return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta);
bool cFirst = false;
var axes = rd1Call[Reduce.Axis].Evaluate().AsTensor().ToArray<int>();
if (axes.Length == 1 && axes[0] != input.CheckedShape.Count - 1 && axes[0] != -1)
{
cFirst = true;
}

return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta, channelFirst: cFirst);
}

return null;
@@ -202,7 +224,14 @@ public sealed partial class FoldLayerNormPattern3 : RewriteRule<CallPattern>
mulXCall[Binary.Lhs] == subMuCall[Binary.Lhs] && mulXCall[Binary.Lhs] == rdMuCall[Reduce.Input])
{
var axis = addAllCall.CheckedShape.Count - gamma.CheckedShape.Count;
return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta);
bool cFirst = false;
var axes = rdMuCall[Reduce.Axis].Evaluate().AsTensor().ToArray<int>();
if (axes.Length == 1 && axes[0] != input.CheckedShape.Count - 1 && axes[0] != -1)
{
cFirst = true;
}

return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta, channelFirst: cFirst);
}

return null;
@@ -273,7 +302,14 @@ public sealed partial class FoldLayerNormPattern4 : RewriteRule<CallPattern>
subCall[Binary.Lhs] == mulXCall[Binary.Lhs] && subCall[Binary.Rhs] == mulMuCall[Binary.Lhs] && mulXCall[Binary.Lhs] == meanCall[Reduce.Input])
{
var axis = addAllCall.CheckedShape.Count - gamma.CheckedShape.Count;
return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta);
bool cFirst = false;
var axes = meanCall[Reduce.Axis].Evaluate().AsTensor().ToArray<int>();
if (axes.Length == 1 && axes[0] != input.CheckedShape.Count - 1 && axes[0] != -1)
{
cFirst = true;
}

return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta, channelFirst: cFirst);
}

return null;
@@ -321,15 +357,94 @@ public sealed partial class FoldLayerNormPattern5 : RewriteRule<CallPattern>
IsTensorConst("two"))),
IsTensorConst("eps"))))));

private Expr? GetReplace(Call pow2Call, TensorConst eps, TensorConst gamma, Expr input, TensorConst one, TensorConst two)
private Expr? GetReplace(Call pow2Call, Call rdVarCall, TensorConst eps, TensorConst gamma, Expr input, TensorConst one, TensorConst two)
{
if (input == pow2Call[Binary.Lhs] && one.Value.Cast<float>()[0] == 1f && two.Value.Cast<float>()[0] == 2f)
{
var axis = pow2Call.CheckedShape.Count - gamma.CheckedShape.Count;
var beta = Tensor.FromScalar(0f, gamma.CheckedShape);
return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta, hasMean: false);
bool cFirst = false;
var axes = rdVarCall[Reduce.Axis].Evaluate().AsTensor().ToArray<int>();
if (axes.Length == 1 && axes[0] != input.CheckedShape.Count - 1 && axes[0] != -1)
{
cFirst = true;
}

return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta, hasMean: false, channelFirst: cFirst);
}

return null;
}
}

[RuleGenerator]
public sealed partial class ConvertLayerNormChannelFirstToLast : RewriteRule<CallPattern>
{
public override CallPattern Pattern { get; } =
IsLayerNorm(
"ln",
"_",
_ => true,
IsWildcard("x"),
IsWildcard("scale"),
IsWildcard("bias"));

private static List<int> GetPermWithAxis(long axis, int shapeSize)
{
var perm = new List<int>();
for (int i = 0; i < shapeSize; i++)
{
if (i != axis)
{
perm.Add(i);
}
}

perm.Add((int)axis);
return perm;
}
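// Worked examples: GetPermWithAxis moves `axis` to the last slot.
// GetPermWithAxis(1, 4) -> [0, 2, 3, 1]  (NCHW -> NHWC)
// GetPermWithAxis(1, 3) -> [0, 2, 1]     (NCL  -> NLC)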

private Expr? GetReplace(LayerNorm ln, Expr x, Expr scale, Expr bias)
{
if (!ln.ChannelFirst)
{
return null;
}

int axis = ln.Axis;
float eps = ln.Epsilon;
bool useMean = ln.UseMean;
if ((axis == x.CheckedShape.Count - 1) || (axis == -1))
{
return null;
}

var inPerm = GetPermWithAxis(axis, x.CheckedShape.Count);
var outPerm = new int[inPerm.Count];
for (int i = 0; i < inPerm.Count; i++)
{
// Invert inPerm so the final Transpose restores the original layout.
// (inPerm[inPerm[i]] only equals the inverse when axis == rank - 3, e.g. NCHW axis 1.)
outPerm[inPerm[i]] = i;
}
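// Worked example (rank 4, axis = 1):
// inPerm  = [0, 2, 3, 1] moves C to the end (NCHW -> NHWC);
// outPerm = [0, 3, 1, 2] restores it, since outPerm[inPerm[i]] == i.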

var newScale = scale;
var newBias = bias;

// Scale and bias are expected to share the same rank, so one permutation transposes both.
if (scale.CheckedShape.Count != 1 && bias.CheckedShape.Count != 1)
{
int axisGap = x.CheckedShape.Count - scale.CheckedShape.Count;
if (axisGap > axis)
{
// Never reach here.
return null;
}

var constPerm = GetPermWithAxis(axis - axisGap, scale.CheckedShape.Count);

newScale = Tensors.Transpose(scale, constPerm.ToArray());
newBias = Tensors.Transpose(bias, constPerm.ToArray());
}

return Tensors.Transpose(LayerNorm(x.CheckedShape.Count - 1, eps, Tensors.Transpose(x, inPerm.ToArray()), newScale, newBias, useMean, true), outPerm.ToArray());
}
}
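Putting it together: for an NCHW input, LayerNorm(axis: 1, ChannelFirst: true) is rewritten into a channel-last LayerNorm wrapped in a transpose pair. A minimal sketch with an assumed shape:

// Assumed x: [1, 32, 64, 64] (NCHW), axis = 1, rank = 4.
var nhwc = Tensors.Transpose(x, new[] { 0, 2, 3, 1 }); // [1, 64, 64, 32]
var ln = LayerNorm(3, eps, nhwc, scale, bias, useMean, true); // normalize the last axis
var y = Tensors.Transpose(ln, new[] { 0, 3, 1, 2 }); // back to [1, 32, 64, 64]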