Skip to content

Commit

Permalink
Merge branch 'master' into fix/cos-nan
Browse files Browse the repository at this point in the history
  • Loading branch information
hejunchao committed Oct 12, 2023
2 parents 3182ee3 + 01da2d0 commit aca38ee
Show file tree
Hide file tree
Showing 15 changed files with 95 additions and 21 deletions.
10 changes: 10 additions & 0 deletions src/Nncase.Core/Quantization/QuantizeOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,16 @@ public class QuantizeOptions
/// </summary>
public string QuantScheme { get; set; } = string.Empty;

/// <summary>
/// Gets or sets the imported quant info used for the inner check; this field should not be displayed in config.toml.
/// </summary>
public string QuantSchemeInnerCheck { get; set; } = string.Empty;

/// <summary>
/// Gets or sets a value indicating whether strict mode is enabled.
/// </summary>
public bool QuantSchemeStrictMode { get; set; }

/// <summary>
/// Gets or sets a value indicating whether export quant scheme.
/// </summary>
Expand Down
4 changes: 4 additions & 0 deletions src/Nncase.Importer/Onnx/Conv2D.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ private Expr VisitConv2D(in NodeProto op)
var strides = GetStrideAttribute(op).ToArray<long>().ToList();

var isConv1D = IsConv1D(weights);
List<string> wOutputNames = new() { weights.Metadata.OutputNames![0] };
if (isConv1D)
{
dilation.Add(1);
Expand All @@ -34,9 +35,12 @@ private Expr VisitConv2D(in NodeProto op)
weights = To4D(weights);
}

weights.Metadata.OutputNames = wOutputNames;
var pads = AutoPad(op, autoPad, input, weights, strides.ToArray<long>(), dilation.ToArray(), isConv1D);
pads.InferenceType();
var conv = F.NN.Conv2D(input, weights, bias, strides.ToArray(), pads, dilation.ToArray(), PadMode.Constant, group);
List<string> outputNames = new() { op.Name };
conv.Metadata.OutputNames = outputNames;
if (isConv1D)
{
conv = Squeeze(conv, new[] { 3 });
Expand Down
4 changes: 3 additions & 1 deletion src/Nncase.Importer/Onnx/Gemm.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System.Collections.Generic;
using Nncase.IR;
using Nncase.IR.Tensors;
using Onnx;
Expand All @@ -27,6 +27,8 @@ private Expr VisitGemm(in NodeProto op)
}

var mm = F.Tensors.MatMul(a, b);
List<string> outputNames = new() { op.Name };
mm.Metadata.OutputNames = outputNames;
if (alpha != 1.0f)
{
mm = mm * alpha;
Expand Down
7 changes: 5 additions & 2 deletions src/Nncase.Importer/Onnx/MatMul.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using Nncase.IR;
using Onnx;
using F = Nncase.IR.F;
Expand All @@ -13,7 +13,10 @@ public partial class OnnxImporter
/// <summary>
/// Imports an ONNX MatMul node as an IR matrix-multiply expression and
/// records the ONNX node name in the expression's output-name metadata.
/// </summary>
/// <param name="op">The ONNX node being imported.</param>
/// <returns>The IR expression computing the matrix product.</returns>
private Expr VisitMatMul(in NodeProto op)
{
    var (a, b) = GetInputExprs(op, 0, 1);
    var matmul = IR.F.Math.MatMul(a, b);

    // Tag the result with the ONNX node name so the op can be identified later
    // (e.g. in metadata-driven passes). Fixes the merge residue where the old
    // `return IR.F.Math.MatMul(a, b);` line survived alongside its replacement,
    // making the lines below unreachable.
    List<string> outputNames = new() { op.Name };
    matmul.Metadata.OutputNames = outputNames;
    return matmul;
}
}
}
7 changes: 6 additions & 1 deletion src/Nncase.Importer/TFLite/Activations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,12 @@ private Expr VisitPRelu(in tflite.Operator op)
/// <summary>
/// Imports a TFLite LeakyRelu operator, applying the operator's Alpha option
/// and recording the output tensor's name in the node's metadata.
/// </summary>
/// <param name="op">The TFLite operator being imported.</param>
/// <returns>The IR expression for the LeakyRelu activation.</returns>
private Expr VisitLeakyRelu(in tflite.Operator op)
{
    var input = GetInputExprs(op, 0);
    var node = F.NN.LeakyRelu(input, op.BuiltinOptionsAsLeakyReluOptions().Alpha);

    // Name the node after its TFLite output tensor for traceability. Fixes the
    // merge residue where the old `return F.NN.LeakyRelu(...)` line survived
    // alongside its replacement, making the lines below unreachable.
    List<string> outputNames = new() { GetOutputTensor(op, 0).Name };
    node.Metadata.OutputNames = outputNames;

    return node;
}

private Expr VisitHardSwish(in tflite.Operator op)
Expand Down
9 changes: 9 additions & 0 deletions src/Nncase.Importer/TFLite/Binary.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ private Expr VisitBinary(in tflite.Operator op, BinaryOp binaryOp, tflite.Activa
(var lhs, var rhs) = GetInputExprs(op, 0, 1);

var node = F.Math.Binary(binaryOp, lhs, rhs);
List<string> outputNames = new();

var outputsLength = op.GetOutputsArray().Length;
for (int i = 0; i < outputsLength; i++)
{
outputNames.Add(GetOutputTensor(op, i).Name + "_FusedBinary");
}

node.Metadata.OutputNames = outputNames;
return Activate(node, activation);
}

Expand Down
9 changes: 6 additions & 3 deletions src/Nncase.Importer/TFLite/Conv2DTranspose.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ private Expr VisitConv2DTranspose(in tflite.Operator op)
var padding = F.ShapeExpr.GetPaddings(F.Tensors.Stack(new IR.Tuple(newOutShape), 0), wShape, stride, dilation, options.Padding == tflite.Padding.SAME, false);
var clamp = ValueRange<float>.Full;

return F.Tensors.NCHWToNHWC(F.Math.Clamp(
F.NN.Conv2DTranspose(
var conv2DTranspose = F.NN.Conv2DTranspose(
F.Tensors.NHWCToNCHW(input),
F.Tensors.NHWCToNCHW(weights),
bias,
Expand All @@ -54,7 +53,11 @@ private Expr VisitConv2DTranspose(in tflite.Operator op)
Tensor.From<long>(new long[] { 0, 0, 0, 0 }),
dilation,
PadMode.Constant,
1),
1);
List<string> outputNames = new() { GetInputTensor(op, 0).Name };
conv2DTranspose.Metadata.OutputNames = outputNames;
return F.Tensors.NCHWToNHWC(F.Math.Clamp(
conv2DTranspose,
clamp.Min,
clamp.Max));
}
Expand Down
9 changes: 8 additions & 1 deletion src/Nncase.Importer/TFLite/MatMul.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,14 @@ private Expr VisitMatMul(in tflite.Operator op, bool isFullyConnected = true)
? GetInputExprs(op, 2)
: Expand(Cast(0, GetDataType(GetInputTensor(op, 0).Type)), new[] { otherTensor.Shape(0) }).Evaluate().AsTensor();

var mm = MatMul(lhs, rhs) + bias;
var matmul = MatMul(lhs, rhs);
List<string> outputNames = new() { GetInputTensor(op, 0).Name + "_matmul" };
matmul.Metadata.OutputNames = outputNames;
outputNames.Clear();
outputNames.Add(GetInputTensor(op, 0).Name + "_bias");
bias.Metadata.OutputNames = outputNames;
var mm = matmul + bias;

return fusedActivationFunction switch
{
ActivationFunctionType.NONE => mm,
Expand Down
5 changes: 4 additions & 1 deletion src/Nncase.Passes/Rules/Neutral/AddPreProcess.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,10 @@ protected override Task<IRModule> RunCoreAsync(IRModule module, RunPassContext o
if (inputType != InputType.Float32)
{
var qP = QuantParamOf(QuantMode.UnsignedMode, new[] { inputRange[0], inputRange[1] }, 8);
newInput = Dequantize(newInput, qP, DataTypes.Float32);
var dequantize = Dequantize(newInput, qP, DataTypes.Float32);
List<string> outputNames = new() { input.Metadata.OutputNames?[0] + "_PreDequantize" };
dequantize.Metadata.OutputNames = outputNames;
newInput = dequantize;
}

// Letterbox
Expand Down
10 changes: 9 additions & 1 deletion src/Nncase.Passes/Rules/Neutral/AddRangeOfAndMarker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ public static bool CheckOp(Op op)
// 这里必须要对matmul的rhs进行判断,如果matmul是动态的那么不会走量化,如果是静态的那么一定会转到conv2d
// 因此认为matmul的rhs为const的情况下一定能转成conv2d
bool isWeights = ((call.Target is Conv2D || call.Target is Conv2DTranspose) && (i == 1))
|| (call.Target is LSTM && i > 0)
|| (call.Target is LSTM && (i == 1 || i == 2))
|| (call.Target is MatMul && i == 1 && callParams[1] is TensorConst);

if (!configExist && !useAutoMixQuant)
Expand Down Expand Up @@ -190,6 +190,14 @@ private Marker WrapNormalOutput(Call call, bool configExist, bool useAutoMixQuan
private IR.Tuple WrapLSTMOutput(Call call, int outputSize, bool configExist, bool useAutoMixQuant, RunPassContext context)
{
var outputs = Enumerable.Range(0, outputSize).Select(i => IR.F.Tensors.GetItem(call, i)).ToArray();
for (int i = 0; i < outputSize; i++)
{
var outputNames = new List<string>();
var getItem = IR.F.Tensors.GetItem(call, i);
outputNames.Add("LSTMOutput_" + call.Metadata.OutputNames?[i]);
outputs[i].Metadata.OutputNames = outputNames;
}

foreach (var o in outputs)
{
context.MatchOptions.SuppressPattern(o, Pattern);
Expand Down
12 changes: 10 additions & 2 deletions src/Nncase.Passes/Rules/Neutral/BatchNormToBinary.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Collections.Immutable;
using System.Linq;
using Nncase.IR;
using Nncase.IR.Math;
using Nncase.Passes;
using Nncase.PatternMatch;
using static Nncase.IR.F.NN;
Expand All @@ -32,7 +33,7 @@ public sealed partial class BatchNormToBinary : IRewriteRule
IsTensorConst("var"),
IsTensorConst("eps"));

private Expr? GetReplace(Expr input, Tensor<float> gamma, Tensor<float> beta, Tensor<float> mean, Tensor<float> var, Tensor<float> eps)
private Expr? GetReplace(Expr input, Tensor<float> gamma, Tensor<float> beta, Tensor<float> mean, Tensor<float> var, Tensor<float> eps, Expr bn, Call bnCall)
{
if (input.CheckedShape.Rank <= 1)
{
Expand All @@ -45,6 +46,13 @@ public sealed partial class BatchNormToBinary : IRewriteRule
var scaleBn = IR.F.Math.Div(gamma, IR.F.Math.Sqrt(IR.F.Math.Add(var, eps)));
var biasBn = IR.F.Math.Sub(beta, IR.F.Math.Mul(gamma, IR.F.Math.Div(mean, IR.F.Math.Sqrt(IR.F.Math.Add(var, eps)))));

return IR.F.Math.Add(IR.F.Math.Mul(input, Reshape(scaleBn, bnShape)), Reshape(biasBn, bnShape));
var matmul = IR.F.Math.Mul(input, Reshape(scaleBn, bnShape));
List<string> outputNames = new() { bnCall.Metadata.OutputNames?[0] + "_BN_Matmul" };
matmul.Metadata.OutputNames = outputNames;
outputNames.Clear();
outputNames.Add(bnCall.Metadata.OutputNames?[0] + "_BN_Binary");
var binary = IR.F.Math.Add(matmul, Reshape(biasBn, bnShape));
binary.Metadata.OutputNames = outputNames;
return binary;
}
}
20 changes: 16 additions & 4 deletions src/Nncase.Passes/Rules/Neutral/MatMulToConv2D.cs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,10 @@ public sealed partial class BroadcastMatMul : IRewriteRule

var ifShape = new int[] { -1, aShape[^2].FixedValue, aShape[^1].FixedValue };
var wShape = new int[] { -1, newBShape[^2], newBShape[^1] };
return Reshape(MatMul(Reshape(a, ifShape), Reshape(IR.F.Tensors.Broadcast(b, newBShape), wShape)), newOutputShape);
var bBroadCast = IR.F.Tensors.Broadcast(b, newBShape);
List<string> outputNames = new() { b.Metadata.OutputNames![0] + "_bBroadCast" };
bBroadCast.Metadata.OutputNames = outputNames;
return Reshape(MatMul(Reshape(a, ifShape), Reshape(bBroadCast, wShape)), newOutputShape);
}
else if (aShape.Rank < bShape.Rank)
{
Expand All @@ -163,7 +166,10 @@ public sealed partial class BroadcastMatMul : IRewriteRule

var ifShape = new int[] { -1, newAShape[^2], newAShape[^1] };
var wShape = new int[] { -1, bShape[^2].FixedValue, bShape[^1].FixedValue };
return Reshape(MatMul(Reshape(IR.F.Tensors.Broadcast(a, newAShape), ifShape), Reshape(b, wShape)), newOutputShape);
var aBroadCast = IR.F.Tensors.Broadcast(a, newAShape);
List<string> outputNames = new() { a.Metadata.OutputNames![0] + "_aBroadCast" };
aBroadCast.Metadata.OutputNames = outputNames;
return Reshape(MatMul(Reshape(aBroadCast, ifShape), Reshape(b, wShape)), newOutputShape);
}
else
{
Expand All @@ -182,13 +188,19 @@ public sealed partial class BroadcastMatMul : IRewriteRule

var ifShape = new int[] { -1, newAShape[^2], newAShape[^1] };
var wShape = new int[] { -1, newBShape[^2], newBShape[^1] };
var bBroadCast = IR.F.Tensors.Broadcast(b, newBShape);
List<string> bOutputNames = new() { b.Metadata.OutputNames?[0] + "_bBroadCast" };
bBroadCast.Metadata.OutputNames = bOutputNames;
var aBroadCast = IR.F.Tensors.Broadcast(a, newAShape);
List<string> aOutputNames = new() { a.Metadata.OutputNames?[0] + "_aBroadCast" };
aBroadCast.Metadata.OutputNames = aOutputNames;
return Reshape(
MatMul(
Reshape(
IR.F.Tensors.Broadcast(a, newAShape),
aBroadCast,
ifShape),
Reshape(
IR.F.Tensors.Broadcast(b, newBShape),
bBroadCast,
wShape)).InheritMetaData(matMulCall),
newOutputShape);
}
Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Passes/Rules/Neutral/ReshapeMatMul.cs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ public partial class ReshapeMatMul : RewriteRule<Pattern>
}
}

var end = IR.F.Tensors.Reshape(IR.F.Tensors.MatMul(lhs, rhs), outputShape);
var end = IR.F.Tensors.Reshape(IR.F.Tensors.MatMul(lhs, rhs).InheritMetaData(matmul), outputShape);
return end;
}
}
4 changes: 2 additions & 2 deletions src/Nncase.Quantization/Quantization/Quantizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -277,10 +277,10 @@ public async Task RunAsync(RunPassContext options)
}

var quantSchemeString = JsonConvert.SerializeObject(quantScheme, Newtonsoft.Json.Formatting.Indented);
_quantizeOptions.QuantScheme = quantSchemeString;
_quantizeOptions.QuantSchemeInnerCheck = quantSchemeString;
if (Path.Exists(DumpScope.Current.Directory))
{
File.WriteAllText(Path.Join(DumpScope.Current.Directory, "..", "..", "QuantScheme.json"), _quantizeOptions.QuantScheme);
File.WriteAllText(Path.Join(DumpScope.Current.Directory, "..", "..", "QuantScheme.json"), quantSchemeString);
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/Nncase.Tests/Quant/UnitTestExportQuantScheme.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public async Task TestExportQuantSchemeForWeightsByTensorConv2D()
var readJson = "{\"Version\":\"1.0\",\"Model\":null,\"Outputs\":[{\"Name\":\"weight\",\"DataType\":\"u8\",\"DataRange\":[{\"Min\":0.0,\"Max\":0.9988426,\"IsFull\":false}],\"DataRangeMode\":\"by_tensor\"}]}";
var quantScheme = JsonConvert.DeserializeObject<QuantScheme>(readJson);
var expectedQuantScheme = JsonConvert.SerializeObject(quantScheme, Newtonsoft.Json.Formatting.Indented);
Assert.Equal(expectedQuantScheme, CompileOptions.QuantizeOptions.QuantScheme);
Assert.Equal(expectedQuantScheme, CompileOptions.QuantizeOptions.QuantSchemeInnerCheck);
}

[Fact]
Expand Down Expand Up @@ -78,7 +78,7 @@ public async Task TestExportQuantSchemeForWeightsByChannelConv2D()
var readJson = "{\"Version\":\"1.0\",\"Model\":null,\"Outputs\":[{\"Name\":\"weight\",\"DataType\":\"u8\",\"DataRange\":[{\"Min\":0.0,\"Max\":0.32098764,\"IsFull\":false},{\"Min\":0.33333334,\"Max\":0.654321,\"IsFull\":false},{\"Min\":0.6666667,\"Max\":0.9876543,\"IsFull\":false}],\"DataRangeMode\":\"by_channel\"}]}";
var quantScheme = JsonConvert.DeserializeObject<QuantScheme>(readJson);
var expectedQuantScheme = JsonConvert.SerializeObject(quantScheme, Newtonsoft.Json.Formatting.Indented);
Assert.Equal(expectedQuantScheme, CompileOptions.QuantizeOptions.QuantScheme);
Assert.Equal(expectedQuantScheme, CompileOptions.QuantizeOptions.QuantSchemeInnerCheck);
}

private async Task<DumpVisitor> TestExportQuantSchemeMainPassesAsync(Var input, Expr output, bool exportWeightRangeByChannel)
Expand Down

0 comments on commit aca38ee

Please sign in to comment.