diff --git a/src/Nncase.Evaluator/NN/LayerNorm.cs b/src/Nncase.Evaluator/NN/LayerNorm.cs index b76300efbf..f844f5fc75 100644 --- a/src/Nncase.Evaluator/NN/LayerNorm.cs +++ b/src/Nncase.Evaluator/NN/LayerNorm.cs @@ -145,7 +145,7 @@ private UInt128 GetRingReduceCommunicate(DistributedType distributedType, int[] #if true private Tensor LayerNormImpl(Tensor input, Tensor scale, Tensor bias, int axis, float epsilon, bool useMean = true) { - int outputSize = 1; + int outerSize = 1; int innerSize = 1; float[] inputArray = input.ToArray(); float[] outputArray = new float[inputArray.Length]; @@ -157,7 +157,7 @@ private Tensor LayerNormImpl(Tensor input, Tensor scale, Tensor bias, int axis, for (int i = 0; i < axis; i++) { - outputSize *= inShape[i]; + outerSize *= inShape[i]; } for (int i = axis; i < inShape.Length; i++) @@ -165,15 +165,17 @@ private Tensor LayerNormImpl(Tensor input, Tensor scale, Tensor bias, int axis, innerSize *= inShape[i]; } - for (int batch = 0; batch < outputSize; batch++) + for (int batch = 0; batch < outerSize; batch++) { float mean1 = 0f; if (useMean) { for (int i = 0; i < innerSize; i++) { - mean1 += inputArray[(i + (batch * innerSize)) % inputArray.Length] / innerSize; + mean1 += inputArray[(i + (batch * innerSize)) % inputArray.Length]; } + + mean1 /= innerSize; } float[] sub = new float[innerSize]; @@ -191,9 +193,11 @@ private Tensor LayerNormImpl(Tensor input, Tensor scale, Tensor bias, int axis, float mean2 = 0f; for (int i = 0; i < innerSize; i++) { - mean2 += pow[i] / innerSize; + mean2 += pow[i]; } + mean2 /= innerSize; + float add = mean2 + epsilon; float sqrt = (float)System.Math.Sqrt(add); diff --git a/src/Nncase.Passes/Rules/Neutral/BatchNormToBinary.cs b/src/Nncase.Passes/Rules/Neutral/BatchNormToBinary.cs index fc34c4dc72..5e69dd90d5 100644 --- a/src/Nncase.Passes/Rules/Neutral/BatchNormToBinary.cs +++ b/src/Nncase.Passes/Rules/Neutral/BatchNormToBinary.cs @@ -43,16 +43,10 @@ public sealed partial class BatchNormToBinary : IRewriteRule var shape = input.CheckedShape.ToValueArray(); var bnShape = Enumerable.Repeat(1, shape.Length - 1).ToArray(); bnShape[0] = shape[1]; - var scaleBn = IR.F.Math.Div(gamma, IR.F.Math.Sqrt(IR.F.Math.Add(var, eps))); - var biasBn = IR.F.Math.Sub(beta, IR.F.Math.Mul(gamma, IR.F.Math.Div(mean, IR.F.Math.Sqrt(IR.F.Math.Add(var, eps))))); - - var matmul = IR.F.Math.Mul(input, Reshape(scaleBn, bnShape)); - List outputNames = new() { bnCall.Metadata.OutputNames?[0] + "_BN_Matmul" }; - matmul.Metadata.OutputNames = outputNames; - outputNames.Clear(); - outputNames.Add(bnCall.Metadata.OutputNames?[0] + "_BN_Binary"); - var binary = IR.F.Math.Add(matmul, Reshape(biasBn, bnShape)); - binary.Metadata.OutputNames = outputNames; + var scaleBn = IR.F.Math.Div(gamma, IR.F.Math.Sqrt(IR.F.Math.Add(var, eps))).With(metadata: new IRMetadata() { OutputNames = new[] { bnCall.Metadata.OutputNames?[0] + "_Scale" } }); + var biasBn = IR.F.Math.Sub(beta, IR.F.Math.Mul(gamma, IR.F.Math.Div(mean, IR.F.Math.Sqrt(IR.F.Math.Add(var, eps))))).With(metadata: new IRMetadata() { OutputNames = new[] { bnCall.Metadata.OutputNames?[0] + "_Bias" } }); + var mul = IR.F.Math.Mul(input, Reshape(scaleBn, bnShape).With(metadata: new IRMetadata() { OutputNames = new[] { bnCall.Metadata.OutputNames?[0] + "_Scale" } })).With(metadata: new IRMetadata() { OutputNames = new[] { bnCall.Metadata.OutputNames?[0] + "_BN_Mul" } }); + var binary = IR.F.Math.Add(mul, Reshape(biasBn, bnShape).With(metadata: new IRMetadata() { OutputNames = new[] { bnCall.Metadata.OutputNames?[0] + "_Bias" } })).With(metadata: new IRMetadata() { OutputNames = new[] { bnCall.Metadata.OutputNames?[0] + "_BN_Add" } }); return binary; } } diff --git a/src/Nncase.Passes/Rules/Neutral/CombineTranspose.cs b/src/Nncase.Passes/Rules/Neutral/CombineTranspose.cs index 4d979c8a56..3cfa6b4a89 100644 --- a/src/Nncase.Passes/Rules/Neutral/CombineTranspose.cs +++ b/src/Nncase.Passes/Rules/Neutral/CombineTranspose.cs @@ -105,7 +105,7 @@ public CombineConstBinaryTranspose() } } - var newConst = Reshape(x, newShape.ToArray()); + var newConst = Reshape(x, newShape.ToArray()).InheritMetaData(x); return Transpose(Binary(binary.BinaryOp, newConst, y).InheritMetaData(binaryCall), perm); } @@ -125,7 +125,7 @@ public CombineConstBinaryTranspose() } } - var newConst = Reshape(y, newShape.ToArray()); + var newConst = Reshape(y, newShape.ToArray()).InheritMetaData(y); return Transpose(Binary(binary.BinaryOp, x, newConst).InheritMetaData(binaryCall), perm); } diff --git a/src/Nncase.Passes/Rules/Neutral/SqueezeShape.cs b/src/Nncase.Passes/Rules/Neutral/SqueezeShape.cs index 289e5baacb..f3a50c126b 100644 --- a/src/Nncase.Passes/Rules/Neutral/SqueezeShape.cs +++ b/src/Nncase.Passes/Rules/Neutral/SqueezeShape.cs @@ -402,6 +402,6 @@ private static List GetOutputShape(List a, List b) var outputShape = GetOutputShape(lShape.ToValueList(), rShape.ToValueList()); - return Reshape(Binary(binary.BinaryOp, Reshape(lhs, newLShape.ToArray()), Reshape(rhs, newRShape.ToArray())).With(metadata: binaryCall.Metadata), outputShape.ToArray()); + return Reshape(Binary(binary.BinaryOp, Reshape(lhs, newLShape.ToArray()).With(metadata: lhs.Metadata), Reshape(rhs, newRShape.ToArray()).With(metadata: rhs.Metadata)).With(metadata: binaryCall.Metadata), outputShape.ToArray()); } }