Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into fix/metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
FusionBolt committed Nov 21, 2023
2 parents 592f43f + f6e4674 commit 2eccfc0
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 18 deletions.
14 changes: 9 additions & 5 deletions src/Nncase.Evaluator/NN/LayerNorm.cs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ private UInt128 GetRingReduceCommunicate(DistributedType distributedType, int[]
#if true
private Tensor LayerNormImpl(Tensor input, Tensor scale, Tensor bias, int axis, float epsilon, bool useMean = true)
{
int outputSize = 1;
int outerSize = 1;
int innerSize = 1;
float[] inputArray = input.ToArray<float>();
float[] outputArray = new float[inputArray.Length];
Expand All @@ -157,23 +157,25 @@ private Tensor LayerNormImpl(Tensor input, Tensor scale, Tensor bias, int axis,

for (int i = 0; i < axis; i++)
{
outputSize *= inShape[i];
outerSize *= inShape[i];
}

for (int i = axis; i < inShape.Length; i++)
{
innerSize *= inShape[i];
}

for (int batch = 0; batch < outputSize; batch++)
for (int batch = 0; batch < outerSize; batch++)
{
float mean1 = 0f;
if (useMean)
{
for (int i = 0; i < innerSize; i++)
{
mean1 += inputArray[(i + (batch * innerSize)) % inputArray.Length] / innerSize;
mean1 += inputArray[(i + (batch * innerSize)) % inputArray.Length];
}

mean1 /= innerSize;
}

float[] sub = new float[innerSize];
Expand All @@ -191,9 +193,11 @@ private Tensor LayerNormImpl(Tensor input, Tensor scale, Tensor bias, int axis,
float mean2 = 0f;
for (int i = 0; i < innerSize; i++)
{
mean2 += pow[i] / innerSize;
mean2 += pow[i];
}

mean2 /= innerSize;

float add = mean2 + epsilon;
float sqrt = (float)System.Math.Sqrt(add);

Expand Down
14 changes: 4 additions & 10 deletions src/Nncase.Passes/Rules/Neutral/BatchNormToBinary.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,10 @@ public sealed partial class BatchNormToBinary : IRewriteRule
var shape = input.CheckedShape.ToValueArray();
var bnShape = Enumerable.Repeat(1, shape.Length - 1).ToArray();
bnShape[0] = shape[1];
var scaleBn = IR.F.Math.Div(gamma, IR.F.Math.Sqrt(IR.F.Math.Add(var, eps)));
var biasBn = IR.F.Math.Sub(beta, IR.F.Math.Mul(gamma, IR.F.Math.Div(mean, IR.F.Math.Sqrt(IR.F.Math.Add(var, eps)))));

var matmul = IR.F.Math.Mul(input, Reshape(scaleBn, bnShape));
List<string> outputNames = new() { bnCall.Metadata.OutputNames?[0] + "_BN_Matmul" };
matmul.Metadata.OutputNames = outputNames;
outputNames.Clear();
outputNames.Add(bnCall.Metadata.OutputNames?[0] + "_BN_Binary");
var binary = IR.F.Math.Add(matmul, Reshape(biasBn, bnShape));
binary.Metadata.OutputNames = outputNames;
var scaleBn = IR.F.Math.Div(gamma, IR.F.Math.Sqrt(IR.F.Math.Add(var, eps))).With(metadata: new IRMetadata() { OutputNames = new[] { bnCall.Metadata.OutputNames?[0] + "_Scale" } });
var biasBn = IR.F.Math.Sub(beta, IR.F.Math.Mul(gamma, IR.F.Math.Div(mean, IR.F.Math.Sqrt(IR.F.Math.Add(var, eps))))).With(metadata: new IRMetadata() { OutputNames = new[] { bnCall.Metadata.OutputNames?[0] + "_Bias" } });
var mul = IR.F.Math.Mul(input, Reshape(scaleBn, bnShape).With(metadata: new IRMetadata() { OutputNames = new[] { bnCall.Metadata.OutputNames?[0] + "_Scale" } })).With(metadata: new IRMetadata() { OutputNames = new[] { bnCall.Metadata.OutputNames?[0] + "_BN_Mul" } });
var binary = IR.F.Math.Add(mul, Reshape(biasBn, bnShape).With(metadata: new IRMetadata() { OutputNames = new[] { bnCall.Metadata.OutputNames?[0] + "_Bias" } })).With(metadata: new IRMetadata() { OutputNames = new[] { bnCall.Metadata.OutputNames?[0] + "_BN_Add" } });
return binary;
}
}
4 changes: 2 additions & 2 deletions src/Nncase.Passes/Rules/Neutral/CombineTranspose.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ public CombineConstBinaryTranspose()
}
}

var newConst = Reshape(x, newShape.ToArray());
var newConst = Reshape(x, newShape.ToArray()).InheritMetaData(x);
return Transpose(Binary(binary.BinaryOp, newConst, y).InheritMetaData(binaryCall), perm);
}

Expand All @@ -125,7 +125,7 @@ public CombineConstBinaryTranspose()
}
}

var newConst = Reshape(y, newShape.ToArray());
var newConst = Reshape(y, newShape.ToArray()).InheritMetaData(y);
return Transpose(Binary(binary.BinaryOp, x, newConst).InheritMetaData(binaryCall), perm);
}

Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Passes/Rules/Neutral/SqueezeShape.cs
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,6 @@ private static List<int> GetOutputShape(List<int> a, List<int> b)

var outputShape = GetOutputShape(lShape.ToValueList(), rShape.ToValueList());

return Reshape(Binary(binary.BinaryOp, Reshape(lhs, newLShape.ToArray()), Reshape(rhs, newRShape.ToArray())).With(metadata: binaryCall.Metadata), outputShape.ToArray());
return Reshape(Binary(binary.BinaryOp, Reshape(lhs, newLShape.ToArray()).With(metadata: lhs.Metadata), Reshape(rhs, newRShape.ToArray()).With(metadata: rhs.Metadata)).With(metadata: binaryCall.Metadata), outputShape.ToArray());
}
}

0 comments on commit 2eccfc0

Please sign in to comment.