Skip to content

Commit

Permalink
add missing meta information for quant (#1130)
Browse files Browse the repository at this point in the history
Co-authored-by: guodongliang <guodongliang@canaan-creative.com>
  • Loading branch information
uranus0515 and guodongliang committed Nov 20, 2023
1 parent 3820af1 commit 2571d16
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 13 deletions.
14 changes: 4 additions & 10 deletions src/Nncase.Passes/Rules/Neutral/BatchNormToBinary.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,10 @@ public sealed partial class BatchNormToBinary : IRewriteRule
var shape = input.CheckedShape.ToValueArray();
var bnShape = Enumerable.Repeat(1, shape.Length - 1).ToArray();
bnShape[0] = shape[1];
var scaleBn = IR.F.Math.Div(gamma, IR.F.Math.Sqrt(IR.F.Math.Add(var, eps)));
var biasBn = IR.F.Math.Sub(beta, IR.F.Math.Mul(gamma, IR.F.Math.Div(mean, IR.F.Math.Sqrt(IR.F.Math.Add(var, eps)))));

var matmul = IR.F.Math.Mul(input, Reshape(scaleBn, bnShape));
List<string> outputNames = new() { bnCall.Metadata.OutputNames?[0] + "_BN_Matmul" };
matmul.Metadata.OutputNames = outputNames;
outputNames.Clear();
outputNames.Add(bnCall.Metadata.OutputNames?[0] + "_BN_Binary");
var binary = IR.F.Math.Add(matmul, Reshape(biasBn, bnShape));
binary.Metadata.OutputNames = outputNames;
var scaleBn = IR.F.Math.Div(gamma, IR.F.Math.Sqrt(IR.F.Math.Add(var, eps))).With(metadata: new IRMetadata() { OutputNames = new[] { bnCall.Metadata.OutputNames?[0] + "_Scale" } });
var biasBn = IR.F.Math.Sub(beta, IR.F.Math.Mul(gamma, IR.F.Math.Div(mean, IR.F.Math.Sqrt(IR.F.Math.Add(var, eps))))).With(metadata: new IRMetadata() { OutputNames = new[] { bnCall.Metadata.OutputNames?[0] + "_Bias" } });
var mul = IR.F.Math.Mul(input, Reshape(scaleBn, bnShape).With(metadata: new IRMetadata() { OutputNames = new[] { bnCall.Metadata.OutputNames?[0] + "_Scale" } })).With(metadata: new IRMetadata() { OutputNames = new[] { bnCall.Metadata.OutputNames?[0] + "_BN_Mul" } });
var binary = IR.F.Math.Add(mul, Reshape(biasBn, bnShape).With(metadata: new IRMetadata() { OutputNames = new[] { bnCall.Metadata.OutputNames?[0] + "_Bias" } })).With(metadata: new IRMetadata() { OutputNames = new[] { bnCall.Metadata.OutputNames?[0] + "_BN_Add" } });
return binary;
}
}
4 changes: 2 additions & 2 deletions src/Nncase.Passes/Rules/Neutral/CombineTranspose.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ public CombineConstBinaryTranspose()
}
}

var newConst = Reshape(x, newShape.ToArray());
var newConst = Reshape(x, newShape.ToArray()).InheritMetaData(x);
return Transpose(Binary(binary.BinaryOp, newConst, y).InheritMetaData(binaryCall), perm);
}

Expand All @@ -125,7 +125,7 @@ public CombineConstBinaryTranspose()
}
}

var newConst = Reshape(y, newShape.ToArray());
var newConst = Reshape(y, newShape.ToArray()).InheritMetaData(y);
return Transpose(Binary(binary.BinaryOp, x, newConst).InheritMetaData(binaryCall), perm);
}

Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Passes/Rules/Neutral/SqueezeShape.cs
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,6 @@ private static List<int> GetOutputShape(List<int> a, List<int> b)

var outputShape = GetOutputShape(lShape.ToValueList(), rShape.ToValueList());

return Reshape(Binary(binary.BinaryOp, Reshape(lhs, newLShape.ToArray()), Reshape(rhs, newRShape.ToArray())).With(metadata: binaryCall.Metadata), outputShape.ToArray());
return Reshape(Binary(binary.BinaryOp, Reshape(lhs, newLShape.ToArray()).With(metadata: lhs.Metadata), Reshape(rhs, newRShape.ToArray()).With(metadata: rhs.Metadata)).With(metadata: binaryCall.Metadata), outputShape.ToArray());
}
}

0 comments on commit 2571d16

Please sign in to comment.