From ae5df31a283c8c4c9450ac9282a9002adf3356db Mon Sep 17 00:00:00 2001
From: Curio Yang <39184746+curioyang@users.noreply.github.com>
Date: Tue, 14 May 2024 17:44:31 +0800
Subject: [PATCH] support layernorm channel first (C#) (#1204)

* support LayerNorm channel first

* fix review: https://github.com/kendryte/nncase/pull/1204#discussion_r1599285210

---
 src/Nncase.Compiler/Compiler.cs               |   1 +
 src/Nncase.Core/IR/NN/Functional.cs           |   2 +-
 src/Nncase.Core/IR/NN/LayerNorm.cs            |   4 +-
 .../Rules/Neutral/FoldLayerNorm.cs            | 127 +++++++++++++++++-
 4 files changed, 126 insertions(+), 8 deletions(-)

diff --git a/src/Nncase.Compiler/Compiler.cs b/src/Nncase.Compiler/Compiler.cs
index c21b7247d3..f34c90a08b 100644
--- a/src/Nncase.Compiler/Compiler.cs
+++ b/src/Nncase.Compiler/Compiler.cs
@@ -108,6 +108,7 @@ public void TargetIndependentPass(IPassManager passManager)
             p.Add();
             p.Add();
             p.Add();
+            p.Add<Passes.Rules.Neutral.ConvertLayerNormChannelFirstToLast>();
             p.Add();
             p.Add();
             p.Add();
diff --git a/src/Nncase.Core/IR/NN/Functional.cs b/src/Nncase.Core/IR/NN/Functional.cs
index e8a44d8a38..537da60eb0 100644
--- a/src/Nncase.Core/IR/NN/Functional.cs
+++ b/src/Nncase.Core/IR/NN/Functional.cs
@@ -34,7 +34,7 @@ public static class NN
     public static Call BatchNormalization(Expr input, Expr scale, Expr bias, Expr input_mean, Expr input_var, Expr epsilon, Expr momentum) => new Call(new BatchNormalization(), input, scale, bias, input_mean, input_var, epsilon, momentum);
 
-    public static Call LayerNorm(int axis, float epsilon, Expr input, Expr scale, Expr bias, bool hasMean = true) => new Call(new LayerNorm(axis, epsilon, hasMean), input, scale, bias);
+    public static Call LayerNorm(int axis, float epsilon, Expr input, Expr scale, Expr bias, bool hasMean = true, bool channelFirst = false) => new Call(new LayerNorm(axis, epsilon, hasMean, channelFirst), input, scale, bias);
 
     public static Call BatchToSpace(Expr input, Expr blockShape, Expr crops) => new Call(new BatchToSpace(), input, blockShape, crops);
diff --git a/src/Nncase.Core/IR/NN/LayerNorm.cs b/src/Nncase.Core/IR/NN/LayerNorm.cs
index 2474f44fc2..421421e908 100644
--- a/src/Nncase.Core/IR/NN/LayerNorm.cs
+++ b/src/Nncase.Core/IR/NN/LayerNorm.cs
@@ -39,5 +39,7 @@ public sealed partial class LayerNorm : Op
 
     public bool UseMean { get; }
 
-    public override string DisplayProperty() => $"Axis: {Axis}, Epsilon: {Epsilon}, UseMean: {UseMean}";
+    public bool ChannelFirst { get; }
+
+    public override string DisplayProperty() => $"Axis: {Axis}, Epsilon: {Epsilon}, UseMean: {UseMean}, ChannelFirst: {ChannelFirst}";
 }
diff --git a/src/Nncase.Passes/Rules/Neutral/FoldLayerNorm.cs b/src/Nncase.Passes/Rules/Neutral/FoldLayerNorm.cs
index c8df9e9e62..11033dbcab 100644
--- a/src/Nncase.Passes/Rules/Neutral/FoldLayerNorm.cs
+++ b/src/Nncase.Passes/Rules/Neutral/FoldLayerNorm.cs
@@ -1,11 +1,19 @@
 // Copyright (c) Canaan Inc. All rights reserved.
 // Licensed under the Apache license. See LICENSE file in the project root for full license information.
 
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using DryIoc;
 using Nncase.IR;
+using Nncase.IR.F;
 using Nncase.IR.Math;
+using Nncase.IR.NN;
+using Nncase.IR.Tensors;
 using Nncase.PatternMatch;
 using static Nncase.IR.F.NN;
 using static Nncase.PatternMatch.F.Math;
+using static Nncase.PatternMatch.F.NN;
 using static Nncase.PatternMatch.F.Tensors;
 using static Nncase.PatternMatch.Utility;
@@ -72,7 +80,14 @@ public sealed partial class FoldLayerNormPattern1 : RewriteRule<Pattern>
         if (subCall[Binary.Lhs] == rd1Call[Reduce.Input] && divCall[Binary.Lhs] == subCall)
         {
             var axis = addBetaCall.CheckedShape.Count - gamma.CheckedShape.Count;
-            return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta);
+            bool cFirst = false;
+            var axes = rd1Call[Reduce.Axis].Evaluate().AsTensor().ToArray<long>();
+            if (axes.Length == 1 && axes[0] != input.CheckedShape.Count - 1 && axes[0] != -1)
+            {
+                cFirst = true;
+            }
+
+            return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta, channelFirst: cFirst);
         }
 
         return null;
@@ -133,7 +148,14 @@ public sealed partial class FoldLayerNormPattern2 : RewriteRule<Pattern>
             divCall[Binary.Lhs] == subCall)
         {
             var axis = addBetaCall.CheckedShape.Count - gamma.CheckedShape.Count;
-            return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta);
+            bool cFirst = false;
+            var axes = rd1Call[Reduce.Axis].Evaluate().AsTensor().ToArray<long>();
+            if (axes.Length == 1 && axes[0] != input.CheckedShape.Count - 1 && axes[0] != -1)
+            {
+                cFirst = true;
+            }
+
+            return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta, channelFirst: cFirst);
         }
 
         return null;
@@ -202,7 +224,14 @@ public sealed partial class FoldLayerNormPattern3 : RewriteRule<Pattern>
             mulXCall[Binary.Lhs] == subMuCall[Binary.Lhs] && mulXCall[Binary.Lhs] == rdMuCall[Reduce.Input])
         {
             var axis = addAllCall.CheckedShape.Count - gamma.CheckedShape.Count;
-            return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta);
+            bool cFirst = false;
+            var axes = rdMuCall[Reduce.Axis].Evaluate().AsTensor().ToArray<long>();
+            if (axes.Length == 1 && axes[0] != input.CheckedShape.Count - 1 && axes[0] != -1)
+            {
+                cFirst = true;
+            }
+
+            return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta, channelFirst: cFirst);
         }
 
         return null;
@@ -273,7 +302,14 @@ public sealed partial class FoldLayerNormPattern4 : RewriteRule<Pattern>
             subCall[Binary.Lhs] == mulXCall[Binary.Lhs] && subCall[Binary.Rhs] == mulMuCall[Binary.Lhs] && mulXCall[Binary.Lhs] == meanCall[Reduce.Input])
         {
             var axis = addAllCall.CheckedShape.Count - gamma.CheckedShape.Count;
-            return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta);
+            bool cFirst = false;
+            var axes = meanCall[Reduce.Axis].Evaluate().AsTensor().ToArray<long>();
+            if (axes.Length == 1 && axes[0] != input.CheckedShape.Count - 1 && axes[0] != -1)
+            {
+                cFirst = true;
+            }
+
+            return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta, channelFirst: cFirst);
         }
 
         return null;
@@ -321,15 +357,94 @@ public sealed partial class FoldLayerNormPattern5 : RewriteRule<Pattern>
                     IsTensorConst("two"))),
                 IsTensorConst("eps"))))));
 
-    private Expr? GetReplace(Call pow2Call, TensorConst eps, TensorConst gamma, Expr input, TensorConst one, TensorConst two)
+    private Expr? GetReplace(Call pow2Call, Call rdVarCall, TensorConst eps, TensorConst gamma, Expr input, TensorConst one, TensorConst two)
     {
         if (input == pow2Call[Binary.Lhs] && one.Value.Cast<float>()[0] == 1f && two.Value.Cast<float>()[0] == 2f)
         {
             var axis = pow2Call.CheckedShape.Count - gamma.CheckedShape.Count;
             var beta = Tensor.FromScalar(0f, gamma.CheckedShape);
-            return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta, hasMean: false);
+            bool cFirst = false;
+            var axes = rdVarCall[Reduce.Axis].Evaluate().AsTensor().ToArray<long>();
+            if (axes.Length == 1 && axes[0] != input.CheckedShape.Count - 1 && axes[0] != -1)
+            {
+                cFirst = true;
+            }
+
+            return LayerNorm(axis, eps.Value.Cast<float>()[0], input, gamma, beta, hasMean: false, channelFirst: cFirst);
         }
 
         return null;
     }
 }
+
+[RuleGenerator]
+public sealed partial class ConvertLayerNormChannelFirstToLast : RewriteRule<CallPattern>
+{
+    public override CallPattern Pattern { get; } =
+        IsLayerNorm(
+            "ln",
+            "_",
+            _ => true,
+            IsWildcard("x"),
+            IsWildcard("scale"),
+            IsWildcard("bias"));
+
+    private static List<int> GetPermWithAxis(long axis, int shapeSize)
+    {
+        var perm = new List<int>();
+        for (int i = 0; i < shapeSize; i++)
+        {
+            if (i != axis)
+            {
+                perm.Add(i);
+            }
+        }
+
+        perm.Add((int)axis);
+        return perm;
+    }
+
+    private Expr? GetReplace(LayerNorm ln, Expr x, Expr scale, Expr bias)
+    {
+        if (!ln.ChannelFirst)
+        {
+            return null;
+        }
+
+        int axis = ln.Axis;
+        float eps = ln.Epsilon;
+        bool useMean = ln.UseMean;
+        if ((axis == x.CheckedShape.Count - 1) || (axis == -1))
+        {
+            return null;
+        }
+
+        var inPerm = GetPermWithAxis(axis, x.CheckedShape.Count);
+        var outPerm = new List<int>();
+        for (int i = 0; i < inPerm.Count; i++)
+        {
+            outPerm.Add(inPerm.IndexOf(i)); // inverse permutation: outPerm[inPerm[i]] == i
+        }
+
+        var newScale = scale;
+        var newBias = bias;
+
+        // The same permutation must be applied to scale and bias.
+        if (scale.CheckedShape.Count != 1 && bias.CheckedShape.Count != 1)
+        {
+            int axisGap = x.CheckedShape.Count - scale.CheckedShape.Count;
+            if (axisGap > axis)
+            {
+                // Never reach here: the normalized axis cannot precede the scale/bias dimensions.
+                return null;
+            }
+
+            var constPerm = GetPermWithAxis(axis - axisGap, scale.CheckedShape.Count);
+
+            newScale = Tensors.Transpose(scale, constPerm.ToArray());
+            newBias = Tensors.Transpose(bias, constPerm.ToArray());
+        }
+
+        return Tensors.Transpose(LayerNorm(x.CheckedShape.Count - 1, eps, Tensors.Transpose(x, inPerm.ToArray()), newScale, newBias, useMean, channelFirst: true), outPerm.ToArray());
+    }
+}
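
A note on the `outPerm` computation in `ConvertLayerNormChannelFirstToLast`: the final `Tensors.Transpose` must undo the first one, so `outPerm` has to be the inverse permutation of `inPerm` (`outPerm[inPerm[i]] == i`). `inPerm.IndexOf(i)` computes exactly that for any axis, whereas applying `inPerm` to itself twice coincides with the inverse only when the rotated axes form a 3-cycle (e.g. rank-4 NCHW with axis 1). The following is a minimal standalone C# sketch, independent of nncase (`GetPermWithAxis` mirrors the helper in the patch; `Invert` is a hypothetical helper introduced here for illustration), that checks the round-trip for every rank and axis:

    // Minimal standalone sketch (plain C#, no nncase dependency).
    using System;
    using System.Collections.Generic;
    using System.Linq;

    static List<int> GetPermWithAxis(int axis, int rank)
    {
        // All axes except `axis`, in order, then `axis` moved to the back.
        var perm = Enumerable.Range(0, rank).Where(i => i != axis).ToList();
        perm.Add(axis);
        return perm;
    }

    // Inverse permutation: inv[p[i]] == i for every i.
    static List<int> Invert(List<int> p) =>
        Enumerable.Range(0, p.Count).Select(i => p.IndexOf(i)).ToList();

    for (int rank = 2; rank <= 5; rank++)
    {
        for (int axis = 0; axis < rank - 1; axis++) // axis == rank - 1 needs no rewrite
        {
            var inPerm = GetPermWithAxis(axis, rank);
            var outPerm = Invert(inPerm);

            // Transposing by inPerm and then by outPerm must be the identity:
            // dim j of the final result reads dim inPerm[outPerm[j]] of the input.
            bool ok = Enumerable.Range(0, rank).All(j => inPerm[outPerm[j]] == j);
            Console.WriteLine(
                $"rank={rank} axis={axis} inPerm=[{string.Join(",", inPerm)}] " +
                $"outPerm=[{string.Join(",", outPerm)}] roundTripIdentity={ok}");
        }
    }

Every printed line should report roundTripIdentity=True; any axis for which it did not would mean the rewritten graph no longer restores the original channel-first layout after the channel-last LayerNorm.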