diff --git a/src/Nncase.Core/CompilerServices.cs b/src/Nncase.Core/CompilerServices.cs index 9c666bcd21..2f731a7fdf 100644 --- a/src/Nncase.Core/CompilerServices.cs +++ b/src/Nncase.Core/CompilerServices.cs @@ -208,6 +208,15 @@ public interface ICompilerServicesProvider /// Options. /// Rewrited expression. Expr ERewrite(Expr expr, IEnumerable rules, RunPassContext options); + + /// + /// Using EGraph rewrite expression. + /// + /// Expression. + /// Rewrite rules. + /// Options. + /// Rewrited expression. + IEGraph ERewrite(IEGraph expr, IEnumerable rules, RunPassContext options); } internal interface ICompilerServicesProviderInternal @@ -409,6 +418,18 @@ public static Expr ERewrite(Expr expr, IEnumerable rules, RunPassC return Provider.ERewrite(expr, rules, options); } + /// + /// Using EGraph rewrite expression. + /// + /// Expression. + /// Rewrite rules. + /// Options. + /// Rewrited expression. + public static IEGraph ERewrite(IEGraph graph, IEnumerable rules, RunPassContext options) + { + return Provider.ERewrite(graph, rules, options); + } + /// /// Match enodes as root. /// @@ -677,4 +698,9 @@ public Expr ERewrite(Expr expr, IEnumerable rules, RunPassContext { return _eGraphrewriteProvider.ERewrite(expr, rules, options); } + + public IEGraph ERewrite(IEGraph graph, IEnumerable rules, RunPassContext options) + { + return _eGraphrewriteProvider.ERewrite(graph, rules, options); + } } diff --git a/src/Nncase.Core/Enum/BinaryOp.cs b/src/Nncase.Core/Enum/BinaryOp.cs index 45afba68cc..fa4092fa48 100644 --- a/src/Nncase.Core/Enum/BinaryOp.cs +++ b/src/Nncase.Core/Enum/BinaryOp.cs @@ -93,4 +93,14 @@ public enum BinaryOp : byte /// Right Shift. /// RightShift, + + /// + /// Floor Div. + /// + FloorDiv, + + /// + /// Ceil Div. + /// + CeilDiv, } diff --git a/src/Nncase.Core/IR/Buffers/Allocate.cs b/src/Nncase.Core/IR/Buffers/Allocate.cs index 14b44010ec..ff7bdd13c0 100644 --- a/src/Nncase.Core/IR/Buffers/Allocate.cs +++ b/src/Nncase.Core/IR/Buffers/Allocate.cs @@ -13,5 +13,20 @@ namespace Nncase.IR.Buffers; /// public sealed partial class Allocate : Op { - public TensorType ElemType { get; } + /// + /// Get the input parameter. + /// + public static readonly ParameterInfo Size = new(typeof(Allocate), 0, "size", TypePatternUtility.IsIntegralScalar()); + + /// + /// Gets the alloacted buffer type. + /// + public DataType ElemType { get; } + + public TIR.MemoryLocation Location { get; } + + /// + public override bool CanFoldConstCall => false; + + public override string DisplayProperty() => $"{ElemType}, {Location}"; } diff --git a/src/Nncase.Core/IR/Buffers/Functional.cs b/src/Nncase.Core/IR/Buffers/Functional.cs index a2e3507a5f..463c4f1e2c 100644 --- a/src/Nncase.Core/IR/Buffers/Functional.cs +++ b/src/Nncase.Core/IR/Buffers/Functional.cs @@ -42,4 +42,6 @@ public static class Buffer /// create the uninitialized buffer. /// public static Call Uninitialized(DataType dataType, TIR.MemoryLocation memoryLocation, Expr shape) => new Call(new Uninitialized(dataType, memoryLocation), shape); + + public static Call Allocate(Expr size, DataType dataType, TIR.MemoryLocation location) => new Call(new Allocate(dataType, location), size); } diff --git a/src/Nncase.Core/IR/TensorConst.cs b/src/Nncase.Core/IR/TensorConst.cs index 9e651978ed..07dccfbcc3 100644 --- a/src/Nncase.Core/IR/TensorConst.cs +++ b/src/Nncase.Core/IR/TensorConst.cs @@ -20,12 +20,18 @@ public TensorConst(Tensor tensor) Value = tensor; } + public TensorConst(Tensor tensor, IRArray ndsbp, Placement placement) + : base(new DistributedType(new TensorType(tensor.ElementType, tensor.Shape), ndsbp, placement)) + { + Value = tensor; + } + public Tensor Value { get; } /// /// Gets value type. /// - public new TensorType ValueType => (TensorType)base.ValueType; + public new IRType ValueType => base.ValueType; /// /// Create TensorConstant from a . @@ -122,25 +128,43 @@ public TensorConst(Tensor tensor) public static bool operator !=(TensorConst? left, TensorConst? right) => !(left == right); /// - public override string ToString() => ValueType switch + public override string ToString() { - var x when x.IsScalar => - x.DType switch - { - var dtype when DataTypes.IsIntegral(dtype) => Value.ToScalar().ToString(), - var dtype when DataTypes.IsFloat(dtype) => Value.ToScalar().ToString(), - var dtype when DataTypes.IsPointer(dtype) => Value.ToScalar().ToString(), - var dtype when dtype == DataTypes.Boolean => Value.ToScalar().ToString(), - _ => $"{x.DType.GetDisplayName()} {x.Shape}", - }, - _ => $"{ValueType.DType.GetDisplayName()} {ValueType.Shape}", - }; + var type = ValueType switch + { + DistributedType dt => dt.TensorType, + TensorType tt => tt, + _ => throw new NotSupportedException("Not supported const type: " + ValueType), + }; + + return type switch + { + var x when x.IsScalar => + x.DType switch + { + var dtype when DataTypes.IsIntegral(dtype) => Value.ToScalar().ToString(), + var dtype when DataTypes.IsFloat(dtype) => Value.ToScalar().ToString(), + var dtype when DataTypes.IsPointer(dtype) => Value.ToScalar().ToString(), + var dtype when dtype == DataTypes.Boolean => Value.ToScalar().ToString(), + _ => $"{x.DType.GetDisplayName()} {x.Shape}", + }, + _ => $"{type.DType.GetDisplayName()} {type.Shape}", + }; + } /// public override TExprResult Accept(ExprFunctor functor, TContext context) => functor.VisitTensorConst(this, context); - public TensorConst With(Tensor? value = null) => new TensorConst(value ?? Value); + public TensorConst With(Tensor? value = null) + { + if (value is null && ValueType is DistributedType dt) + { + return new TensorConst(Value, dt.NdSBP, dt.Placement); + } + + return new TensorConst(value ?? Value); + } /// public override bool Equals(object? obj) => Equals(obj as TensorConst); diff --git a/src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs b/src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs index 7559954a49..1a64b55917 100644 --- a/src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs +++ b/src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs @@ -23,6 +23,37 @@ protected override Expr RewriteLeafBlock(Block expr) { if (predicate) { + if (expr.AllocBuffers.Length > 0) + { + var lets = expr.AllocBuffers.ToArray().Select(b => (T.Let(out var v, b.MemSpan.Start, b.Name + "_ptr"), v)).ToArray(); + for (int i = 0; i < lets.Length - 1; i++) + { + lets[i].Item1.Body(lets[i + 1].Item1); + } + + var map = new Dictionary(ReferenceEqualityComparer.Instance); + for (int i = 0; i < expr.AllocBuffers.Length; i++) + { + map.Add(expr.AllocBuffers[i].MemSpan.Start, lets[i].v); + } + + var mutator = new Substitutor(e => + { + if (map.TryGetValue(e, out var r)) + { + return r; + } + + return null; + }); + + var initBody = mutator.Visit(expr.InitBody, Unit.Default); + var body = mutator.Visit(expr.Body, Unit.Default); + + lets[^1].Item1.Body(initBody, body); + return lets[0].Item1.Build(); + } + return T.Sequential(expr.InitBody, expr.Body); } else diff --git a/src/Nncase.Core/PatternMatch/ConstPattern.cs b/src/Nncase.Core/PatternMatch/ConstPattern.cs index cb9a52aebe..bca5ec63a1 100644 --- a/src/Nncase.Core/PatternMatch/ConstPattern.cs +++ b/src/Nncase.Core/PatternMatch/ConstPattern.cs @@ -66,9 +66,9 @@ public static partial class Utility public static TensorConstPattern IsConst(string? name, Func cond) => new( x => { - if (DataTypes.IsFloat(x.ValueType.DType)) + if (DataTypes.IsFloat(x.CheckedDataType)) { - if (x.ValueType.IsScalar) + if (x.CheckedShape.IsScalar) { return cond(x.Value.ToScalar()); } @@ -93,9 +93,9 @@ public static partial class Utility public static TensorConstPattern IsConst(string? name, Func cond) => new( x => { - if (DataTypes.IsIntegral(x.ValueType.DType)) + if (DataTypes.IsIntegral(x.CheckedDataType)) { - if (x.ValueType.IsScalar) + if (x.CheckedShape.IsScalar) { return cond(x.Value.ToScalar()); } diff --git a/src/Nncase.Core/TIR/Builders/NestBodyExprBuilder.cs b/src/Nncase.Core/TIR/Builders/NestBodyExprBuilder.cs index 60d79a6b89..2add795b71 100644 --- a/src/Nncase.Core/TIR/Builders/NestBodyExprBuilder.cs +++ b/src/Nncase.Core/TIR/Builders/NestBodyExprBuilder.cs @@ -51,7 +51,7 @@ public T Build() public ISequentialBuilder InsertBody(int index, params object[] exprOrBuilders) { - _subBuilders[_subBuilders.Length - 1].InsertBody(index, exprOrBuilders); + _subBuilders[index < 0 ? _subBuilders.Length + index : index].Body(exprOrBuilders); return this; } } diff --git a/src/Nncase.Core/TIR/Script.cs b/src/Nncase.Core/TIR/Script.cs index 28740e43ab..ffda2527df 100644 --- a/src/Nncase.Core/TIR/Script.cs +++ b/src/Nncase.Core/TIR/Script.cs @@ -134,8 +134,16 @@ public static ISequentialBuilder Grid(out Var[] loopVars, LoopMode loopMode { string[] names = { "i", "j", "k", "l" }; var newLoopVars = loopVars = new Var[ranges.Length]; - return new NestBodyExprBuilder(ranges.Select((rg, i) => - T.ForLoop(out newLoopVars[i], rg, loopMode, names[i % 4] + (i / 4 == 0 ? string.Empty : (i / 4).ToString())).Body()).ToArray()); + var newLoops = ranges.Select((rg, i) => T.ForLoop(out newLoopVars[i], rg, loopMode, names[i % 4] + (i / 4 == 0 ? string.Empty : (i / 4).ToString())).Body()).ToArray(); + return new NestBodyExprBuilder(newLoops); + } + + public static ISequentialBuilder Grid(out Var[] loopVars, out ISequentialBuilder[] loops, LoopMode loopMode, params TIR.Range[] ranges) + { + string[] names = { "i", "j", "k", "l" }; + var newLoopVars = loopVars = new Var[ranges.Length]; + var newLoops = loops = ranges.Select((rg, i) => T.ForLoop(out newLoopVars[i], rg, loopMode, names[i % 4] + (i / 4 == 0 ? string.Empty : (i / 4).ToString())).Body()).ToArray(); + return new NestBodyExprBuilder(loops); } /// @@ -223,6 +231,49 @@ public static Buffer CreateBuffer(TensorType tensorType, MemoryLocation location return buffer; } + /// + /// create the buffer by expressions. + /// + public static Buffer CreateBuffer(DataType dataType, Expr[] dimensions, MemoryLocation location, out Buffer buffer, [CallerArgumentExpression("buffer")] string name = "") + { + if (name.StartsWith("var ")) + { + name = name[4..]; + } + + var strides = TensorUtilities.GetStrides(dimensions); + var size = TensorUtilities.GetProduct(dimensions.ToArray()) * dataType.SizeInBytes; + var memspan = new MemSpan(size, location); + buffer = new Buffer(name, dataType, memspan, dimensions, strides); + return buffer; + } + + public static Buffer CreateBuffer(DataType dataType, Expr[] dimensions, Expr[] strides, MemSpan memSpan, out Buffer buffer, [CallerArgumentExpression("buffer")] string name = "") + { + if (name.StartsWith("var ")) + { + name = name[4..]; + } + + buffer = new Buffer(name, dataType, memSpan, dimensions, strides); + return buffer; + } + + public static Buffer AttachBuffer(Expr start, TensorType tensorType, MemoryLocation location, out Buffer buffer, [CallerArgumentExpression("buffer")] string name = "") + { + if (name.StartsWith("var ")) + { + name = name[4..]; + } + + var dimensions = tensorType.Shape.ToValueArray(); + var strides = TensorUtilities.GetStrides(dimensions); + var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * tensorType.DType.SizeInBytes; + var memspan = new MemSpan(start, size, location); + buffer = new Buffer(name, tensorType.DType, memspan, dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray()); + return buffer; + } + /// /// create buffer by const. /// @@ -233,11 +284,11 @@ public static Buffer AttachBuffer(TensorConst @const, out Buffer buffer, [Caller name = name[4..]; } - var dimensions = @const.ValueType.Shape.ToValueArray(); + var dimensions = @const.CheckedShape.ToValueArray(); var strides = TensorUtilities.GetStrides(dimensions); - var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * @const.ValueType.DType.SizeInBytes; + var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * @const.CheckedDataType.SizeInBytes; var memspan = new MemSpan(IR.F.Buffer.DDrOf(@const), size, MemoryLocation.Rdata); - buffer = new Buffer(name, @const.ValueType.DType, memspan, dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray()); + buffer = new Buffer(name, @const.CheckedDataType, memspan, dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray()); return buffer; } diff --git a/src/Nncase.Core/TensorUtilities.cs b/src/Nncase.Core/TensorUtilities.cs index 79f658aefa..146e5c6cfa 100644 --- a/src/Nncase.Core/TensorUtilities.cs +++ b/src/Nncase.Core/TensorUtilities.cs @@ -69,7 +69,7 @@ public static Expr GetProduct(ReadOnlySpan dimensions, int startIndex = 0) for (int i = startIndex; i < dimensions.Length; i++) { var dimension = dimensions[i]; - product *= IR.F.Math.Require(dimension >= 0, dimension, "Dimension is out of range."); + product *= dimension; } return product; diff --git a/src/Nncase.Core/Utilities/DistributedUtility.cs b/src/Nncase.Core/Utilities/DistributedUtility.cs index 2061a40958..eb4a84be0d 100644 --- a/src/Nncase.Core/Utilities/DistributedUtility.cs +++ b/src/Nncase.Core/Utilities/DistributedUtility.cs @@ -2,6 +2,7 @@ // Licensed under the Apache license. See LICENSE file in the project root for full license information. using System.Diagnostics.CodeAnalysis; +using NetFabric.Hyperlinq; using Nncase.IR; namespace Nncase.Utilities; @@ -26,11 +27,7 @@ public static IReadOnlyList> GetLeafCandidateNDSBPs(TensorType tens ndsbps.Add(ndsbp); } - return ndsbps.CartesianProduct(). - Select(ndsbp => ndsbp.ToArray()). - Where(ndsbp => IsDistributable(tensorType, ndsbp, placement)). - Select(ndsbp => new IRArray(ndsbp)). - ToArray(); + return ndsbps.CartesianProduct().Select(ndsbp => ndsbp.ToArray()).Where(ndsbp => IsDistributable(tensorType, ndsbp, placement)).Select(ndsbp => new IRArray(ndsbp)).ToArray(); } public static IReadOnlyList> GetPartialCandidateNDSBPs(DistributedType distributedType) @@ -65,11 +62,7 @@ public static IReadOnlyList> GetPartialCandidateNDSBPs(DistributedT } } - return candidateNdsbps.CartesianProduct(). - Select(ndsbp => ndsbp.ToArray()). - Where(ndsbp => IsDistributable(tensorType, ndsbp, placement)). - Select(ndsbp => new IRArray(ndsbp)). - ToArray(); + return candidateNdsbps.CartesianProduct().Select(ndsbp => ndsbp.ToArray()).Where(ndsbp => IsDistributable(tensorType, ndsbp, placement)).Select(ndsbp => new IRArray(ndsbp)).ToArray(); } public static bool IsDistributable(TensorType tensorType, ReadOnlySpan ndsbp, Placement placement) @@ -131,24 +124,74 @@ public static Expr[] TryGetNonUniformDividedShape(DistributedType distributedTyp } return hierarchies.Select((divs, axis) => + { + Expr dim; + if (divs.Any()) { - Expr dim; - if (divs.Any()) + var divsor = (int)TensorUtilities.GetProduct(divs.Select(h => distributedType.Placement.Hierarchy[h]).ToArray()); + var (res, rem) = Math.DivRem(shape[axis], divsor); + if (rem == 0) { - var divsor = (int)TensorUtilities.GetProduct(divs.Select(h => distributedType.Placement.Hierarchy[h]).ToArray()); - var (res, rem) = Math.DivRem(shape[axis], divsor); - dim = IR.F.Math.Select( - TensorUtilities.GetIndex(hierarchyStrides.TakeLast(divs.Count).Select(s => (Expr)s).ToArray(), divs.Select(h => ids[h]).ToArray()) < (divsor - 1), - res, - res + rem); + return res; } - else + + dim = IR.F.Math.Select( + TensorUtilities.GetIndex(hierarchyStrides.TakeLast(divs.Count).Select(s => (Expr)s).ToArray(), divs.Select(h => ids[h]).ToArray()) < (divsor - 1), + res, + res + rem); + } + else + { + dim = distributedType.TensorType.Shape[axis].FixedValue; + } + + return dim; + }).ToArray(); + } + + public static List TryGetNonUniformDividedSlice(DistributedType distributedType) + { + var shape = distributedType.TensorType.Shape.ToValueArray(); + var hierarchies = Enumerable.Range(0, shape.Length).Select(i => new List()).ToArray(); + for (int i = 0; i < distributedType.NdSBP.Count; i++) + { + if (distributedType.NdSBP[i] is SBPSplit { Axis: int axis }) + { + hierarchies[axis].Add(i); + } + } + + var spliList = hierarchies.Select, int[]>((divs, axis) => + { + int[] dim; + if (divs.Any()) + { + var divsor = (int)TensorUtilities.GetProduct(divs.Select(h => distributedType.Placement.Hierarchy[h]).ToArray()); + var (res, rem) = Math.DivRem(shape[axis], divsor); + if (rem == 0) { - dim = distributedType.TensorType.Shape[axis].FixedValue; + return new[] { res }; } - return dim; - }).ToArray(); + dim = new[] { res, res + rem }; + } + else + { + dim = distributedType.TensorType.Shape.ToValueArray().Skip(axis).Take(1).ToArray(); + } + + return dim; + }).ToList(); + + IEnumerable ret = new[] { Array.Empty() }; + foreach (int[] array in spliList) + { + ret = from seq in ret + from item in array + select seq.Concat(new[] { item }).ToArray(); + } + + return ret.ToList(); } public static bool IsDivideBy(int input, int divisor) @@ -174,17 +217,14 @@ public static bool IsDivideExactly(int input, int divisor) public static float GetDividedTensorEfficiency(DistributedType distributedType, int burstLength) { var (tiles, shape) = GetDividedTile(distributedType); - return Enumerable.Range(0, tiles.Count). - Select(i => tiles[i].Ranges(0, shape[i])). - CartesianProduct(). - Select(rgs => - { - var slice = rgs.ToArray(); - var iscontiguous = TensorUtilities.IsContiguousSlice(shape.ToArray(), slice, out var contiguousStart); - var size = TensorUtilities.GetProduct(tiles.ToArray(), contiguousStart) * distributedType.TensorType.DType.SizeInBytes; - var (div, rem) = Math.DivRem(size, burstLength); - return ((div * 1.0f) + ((float)rem / burstLength)) / (div + 1); - }).Average(); + return Enumerable.Range(0, tiles.Count).Select(i => tiles[i].Ranges(0, shape[i])).CartesianProduct().Select(rgs => + { + var slice = rgs.ToArray(); + var iscontiguous = TensorUtilities.IsContiguousSlice(shape.ToArray(), slice, out var contiguousStart); + var size = TensorUtilities.GetProduct(tiles.ToArray(), contiguousStart) * distributedType.TensorType.DType.SizeInBytes; + var (div, rem) = Math.DivRem(size, burstLength); + return ((div * 1.0f) + ((float)rem / burstLength)) / (div + 1); + }).Average(); } public static TensorType GetDividedTensorType(DistributedType distributedType) diff --git a/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs b/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs index 4a447073b8..079dafc907 100644 --- a/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs +++ b/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs @@ -248,6 +248,8 @@ public ILPrintVisitor(TextWriter textWriter, bool display_callable, int indent_l _scope = new(textWriter, indent_level); } + public override string DefaultVisitType(IRType type) => type.ToString(); + /// public override string VisitType(AnyType type) => "any"; diff --git a/src/Nncase.Evaluator/Buffers/Allocate.cs b/src/Nncase.Evaluator/Buffers/Allocate.cs index d1a75d3375..d09f1654f3 100644 --- a/src/Nncase.Evaluator/Buffers/Allocate.cs +++ b/src/Nncase.Evaluator/Buffers/Allocate.cs @@ -14,6 +14,6 @@ public partial class AllocateEvaluator : ITypeInferencer /// public IRType Visit(ITypeInferenceContext context, Allocate target) { - return TensorType.Pointer(target.ElemType.DType); + return TensorType.Pointer(target.ElemType); } } diff --git a/src/Nncase.Evaluator/Buffers/DDrOf.cs b/src/Nncase.Evaluator/Buffers/DDrOf.cs index 86c53e04b7..b329ee1787 100644 --- a/src/Nncase.Evaluator/Buffers/DDrOf.cs +++ b/src/Nncase.Evaluator/Buffers/DDrOf.cs @@ -12,8 +12,13 @@ namespace Nncase.Evaluator.Buffers; [TypeInferGenerator] public partial class DDrOfEvaluator : ITypeInferencer { - private IRType Visit(TensorType input) + private IRType Visit(IRType input) { - return TensorType.Pointer(input.DType); + return input switch + { + DistributedType d => TensorType.Pointer(d.TensorType.DType), + TensorType t => TensorType.Pointer(t.DType), + _ => new InvalidType(input.GetType().Name), + }; } } diff --git a/src/Nncase.Evaluator/Math/Binary.cs b/src/Nncase.Evaluator/Math/Binary.cs index 1ac424b0be..4e8ab2e493 100755 --- a/src/Nncase.Evaluator/Math/Binary.cs +++ b/src/Nncase.Evaluator/Math/Binary.cs @@ -214,6 +214,8 @@ private IRType Visit(Binary target, DistributedType a, DistributedType b) BinaryOp.Sub => a - b, BinaryOp.Mul => a * b, BinaryOp.Div => a / b, + BinaryOp.FloorDiv => (int)System.Math.Floor((float)a / b), + BinaryOp.CeilDiv => (int)System.Math.Ceiling((float)a / b), BinaryOp.Mod => a % b, BinaryOp.Min => System.Math.Min(a, b), BinaryOp.Max => System.Math.Max(a, b), @@ -227,6 +229,8 @@ private IRType Visit(Binary target, DistributedType a, DistributedType b) BinaryOp.Sub => a - b, BinaryOp.Mul => a * b, BinaryOp.Div => a / b, + BinaryOp.FloorDiv => (uint)System.Math.Floor((float)a / b), + BinaryOp.CeilDiv => (uint)System.Math.Ceiling((float)a / b), BinaryOp.Mod => a % b, BinaryOp.Min => System.Math.Min(a, b), BinaryOp.Max => System.Math.Max(a, b), @@ -242,6 +246,8 @@ private IRType Visit(Binary target, DistributedType a, DistributedType b) BinaryOp.Sub => a - b, BinaryOp.Mul => a * b, BinaryOp.Div => a / b, + BinaryOp.FloorDiv => (ulong)System.Math.Floor((float)a / b), + BinaryOp.CeilDiv => (ulong)System.Math.Ceiling((float)a / b), BinaryOp.Mod => a % b, BinaryOp.Min => System.Math.Min(a, b), BinaryOp.Max => System.Math.Max(a, b), @@ -262,6 +268,8 @@ private IRType Visit(Binary target, DistributedType a, DistributedType b) BinaryOp.Sub => a - b, BinaryOp.Mul => a * b, BinaryOp.Div => a / b, + BinaryOp.FloorDiv => (long)System.Math.Floor((float)a / b), + BinaryOp.CeilDiv => (long)System.Math.Ceiling((float)a / b), BinaryOp.Mod => a % b, BinaryOp.Min => System.Math.Min(a, b), BinaryOp.Max => System.Math.Max(a, b), @@ -298,6 +306,8 @@ static OrtKISharp.Tensor Mod(OrtKISharp.Tensor a, OrtKISharp.Tensor b) BinaryOp.Sub => a - b, BinaryOp.Mul => a * b, BinaryOp.Div => a / b, + BinaryOp.FloorDiv => OrtKI.Floor(a.Cast(OrtDataType.Float) / b.Cast(OrtDataType.Float)).Cast(a.DataType), + BinaryOp.CeilDiv => OrtKI.Ceil(a.Cast(OrtDataType.Float) / b.Cast(OrtDataType.Float)).Cast(a.DataType), BinaryOp.Mod => Mod(a, b), BinaryOp.Min => OrtKI.Min(new[] { a, b }), BinaryOp.Max => OrtKI.Max(new[] { a, b }), diff --git a/src/Nncase.Evaluator/Math/MatMul.cs b/src/Nncase.Evaluator/Math/MatMul.cs index 1f19b64388..4642a4f8d5 100644 --- a/src/Nncase.Evaluator/Math/MatMul.cs +++ b/src/Nncase.Evaluator/Math/MatMul.cs @@ -23,7 +23,7 @@ public static IRType VisitDistributedType(DistributedType a, DistributedType b) { if (VisitTensorType(a.TensorType, b.TensorType) is not TensorType outType) { - return new InvalidType(string.Empty); + return new InvalidType($"{a.TensorType} {b.TensorType} not support"); } if (a.Placement != b.Placement) @@ -162,7 +162,7 @@ public IRType Visit(ITypeInferenceContext context, MatMul target) { (DistributedType a, DistributedType b) => VisitDistributedType(a, b), (TensorType a, TensorType b) => VisitTensorType(a, b), - _ => new InvalidType(string.Empty), + _ => new InvalidType($"{lhs} {rhs} not support"), }; } diff --git a/src/Nncase.Importer/Onnx/QLinearConv.cs b/src/Nncase.Importer/Onnx/QLinearConv.cs index 0ab9ebc99c..3e1bb7e07a 100644 --- a/src/Nncase.Importer/Onnx/QLinearConv.cs +++ b/src/Nncase.Importer/Onnx/QLinearConv.cs @@ -29,13 +29,13 @@ private Expr VisitQLinearConv(in NodeProto op) var group = GetIntAttribute(op, "group", 1); var strides = GetStrideAttribute(op); - int? stridesValueLen = ((TensorConst)strides).ValueType.Shape[0].Value; + int? stridesValueLen = ((TensorConst)strides).CheckedShape[0].Value; for (var i = 0; i < stridesValueLen; i++) { System.Diagnostics.Trace.Assert(((TensorConst)strides).Value.Cast()[i] <= (long)int.MaxValue); } - int? dilationValueLen = ((TensorConst)dilation).ValueType.Shape[0].Value; + int? dilationValueLen = ((TensorConst)dilation).CheckedShape[0].Value; for (var i = 0; i < dilationValueLen; i++) { System.Diagnostics.Trace.Assert(((TensorConst)dilation).Value.Cast()[i] <= (long)int.MaxValue); @@ -63,16 +63,16 @@ private Expr VisitQLinearConv(in NodeProto op) if (bias == null) { - int? ocNumber = ((TensorConst)weights).ValueType.Shape[0].Value; + int? ocNumber = ((TensorConst)weights).CheckedShape[0].Value; var zeroBias = new TensorConst(new int[ocNumber == null ? default(int) : ocNumber.Value]); var conv = F.NN.Conv2D(inputDeq, weightsDeq, zeroBias, strideConst, pads, dilationConst, PadMode.Constant, group); - return Quantize(conv, new QuantParam(((TensorConst)yZeroPoint).Value.ToScalar(), ((TensorConst)yScale).Value.ToScalar()), ((TensorConst)yZeroPoint).ValueType.DType); + return Quantize(conv, new QuantParam(((TensorConst)yZeroPoint).Value.ToScalar(), ((TensorConst)yScale).Value.ToScalar()), ((TensorConst)yZeroPoint).CheckedDataType); } else { var biasDeq = Dequantize(bias, new QuantParam(0, ((TensorConst)xScale).Value.ToScalar() * ((TensorConst)wScale).Value.ToScalar()), DataTypes.Float32); var conv = F.NN.Conv2D(inputDeq, weightsDeq, biasDeq, strideConst, pads, dilationConst, PadMode.Constant, group); - return Quantize(conv, new QuantParam(((TensorConst)yZeroPoint).Value.ToScalar(), ((TensorConst)yScale).Value.ToScalar()), ((TensorConst)yZeroPoint).ValueType.DType); + return Quantize(conv, new QuantParam(((TensorConst)yZeroPoint).Value.ToScalar(), ((TensorConst)yScale).Value.ToScalar()), ((TensorConst)yZeroPoint).CheckedDataType); } } } diff --git a/src/Nncase.Importer/Onnx/QLinearMatmul.cs b/src/Nncase.Importer/Onnx/QLinearMatmul.cs index 5ab2ad73d8..892a6b6c2d 100644 --- a/src/Nncase.Importer/Onnx/QLinearMatmul.cs +++ b/src/Nncase.Importer/Onnx/QLinearMatmul.cs @@ -25,7 +25,7 @@ private Expr VisitQLinearMatMul(in NodeProto op) var aDeq = Dequantize(input_a, new QuantParam(((TensorConst)aZeroPoint).Value.ToScalar(), ((TensorConst)aScale).Value.ToScalar()), DataTypes.Float32); var bDeq = Dequantize(input_b, new QuantParam(((TensorConst)bZeroPoint).Value.ToScalar(), ((TensorConst)bScale).Value.ToScalar()), DataTypes.Float32); var matmul = F.Tensors.MatMul(aDeq, bDeq); - return Quantize(matmul, new QuantParam(((TensorConst)yZeroPoint).Value.ToScalar(), ((TensorConst)yScale).Value.ToScalar()), ((TensorConst)yZeroPoint).ValueType.DType); + return Quantize(matmul, new QuantParam(((TensorConst)yZeroPoint).Value.ToScalar(), ((TensorConst)yScale).Value.ToScalar()), ((TensorConst)yZeroPoint).CheckedDataType); } } } diff --git a/src/Nncase.Importer/Onnx/Quantize.cs b/src/Nncase.Importer/Onnx/Quantize.cs index cc33583711..6a4c771cb4 100644 --- a/src/Nncase.Importer/Onnx/Quantize.cs +++ b/src/Nncase.Importer/Onnx/Quantize.cs @@ -23,7 +23,7 @@ private Expr VisitQuantizeLinear(in NodeProto op) new QuantParam( biasConst.Value.ToScalar(), scaleConst.Value.ToScalar()), - ((TensorConst)bias).ValueType.DType); + ((TensorConst)bias).CheckedDataType); } throw new NotImplementedException("Onnx importer not impl for dynamic scale and bias"); diff --git a/src/Nncase.Passes/DDrBufferSchdeulePass.cs b/src/Nncase.Passes/DDrBufferSchdeulePass.cs index 80aebda267..15a6505686 100644 --- a/src/Nncase.Passes/DDrBufferSchdeulePass.cs +++ b/src/Nncase.Passes/DDrBufferSchdeulePass.cs @@ -134,7 +134,7 @@ protected override Expr RewriteLeafBuffer(TIR.Buffer expr) protected override TIR.MemSpan RewriteLeafMemSpan(TIR.MemSpan memSpan) { - if (memSpan is { Location: MemoryLocation.Rdata, Start: Call { Target: IR.Buffers.DDrOf, Arguments: var arg } } && arg[0] is Const { ValueType: TensorType constType } @const) + if (memSpan is { Location: MemoryLocation.Rdata, Start: Call { Target: IR.Buffers.DDrOf, Arguments: var arg } } && arg[0] is Const @const) { if (!ModuleRdataMaps.TryGetValue(Entry.ModuleKind, out var moduleRdataMap)) { @@ -163,7 +163,7 @@ protected override TIR.MemSpan RewriteLeafMemSpan(TIR.MemSpan memSpan) Changed = true; } - return memSpan.With(new TensorConst(Tensor.FromPointer((ulong)memRange.Min, constType.DType)), memRange.Max - memRange.Min); + return memSpan.With(new TensorConst(Tensor.FromPointer((ulong)memRange.Min, @const.CheckedDataType)), memRange.Max - memRange.Min); } return memSpan; diff --git a/src/Nncase.Passes/ModulePass.cs b/src/Nncase.Passes/ModulePass.cs index dc35cd0843..25f8edd98f 100644 --- a/src/Nncase.Passes/ModulePass.cs +++ b/src/Nncase.Passes/ModulePass.cs @@ -30,7 +30,7 @@ protected override Task OnPassStartAsync(IRModule input, RunPassContext context) { foreach (var func in input.Functions) { - DumpScope.Current.DumpIR(func, func.Name, "Start"); + DumpScope.Current.DumpIR(func, string.Empty, "Start"); } } @@ -44,7 +44,7 @@ protected override Task OnPassEndAsync(IRModule post, RunPassContext context) { foreach (var func in post.Functions) { - DumpScope.Current.DumpIR(func, func.Name, "End"); + DumpScope.Current.DumpIR(func, string.Empty, "End"); } } diff --git a/src/Nncase.Tests/Core/IR/UnitTestConst.cs b/src/Nncase.Tests/Core/IR/UnitTestConst.cs index 80b8a5e669..45e4e4585f 100644 --- a/src/Nncase.Tests/Core/IR/UnitTestConst.cs +++ b/src/Nncase.Tests/Core/IR/UnitTestConst.cs @@ -18,7 +18,7 @@ public void TestByte() byte expected = 1; Const c = expected; var tc = (TensorConst)c; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); @@ -30,7 +30,7 @@ public void TestUshort() ushort expected = 1; Const c = expected; var tc = (TensorConst)c; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); @@ -42,7 +42,7 @@ public void TestUint() uint expected = 1; Const c = expected; var tc = (TensorConst)c; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); @@ -54,7 +54,7 @@ public void TestUlong() ulong expected = 1; Const c = expected; var tc = (TensorConst)c; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); @@ -66,7 +66,7 @@ public void TestSbyte() sbyte expected = 1; Const c = expected; var tc = (TensorConst)c; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); @@ -78,7 +78,7 @@ public void TestShort() short expected = 1; Const c = expected; var tc = (TensorConst)c; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); @@ -90,7 +90,7 @@ public void TestInt() int expected = 1; Const c = expected; var tc = (TensorConst)c; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); @@ -102,7 +102,7 @@ public void TestLong() long expected = 1; Const c = expected; var tc = (TensorConst)c; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); @@ -114,7 +114,7 @@ public void TestHalf() var expected = (Half)1F; Const c = expected; var tc = (TensorConst)c; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); @@ -126,7 +126,7 @@ public void TestFloat() var expected = 1F; Const c = expected; var tc = (TensorConst)c; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); @@ -138,7 +138,7 @@ public void TestDouble() var expected = 1D; Const c = expected; var tc = (TensorConst)c; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); @@ -150,7 +150,7 @@ public void TestBfloat16() var expected = (BFloat16)1F; Const c = expected; var tc = (TensorConst)c; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); @@ -162,7 +162,7 @@ public void TestBool() var expected = false; Const c = expected; var tc = (TensorConst)c; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); @@ -175,7 +175,7 @@ public void TestUtf8Char() Utf8Char expected = b; Const c = expected; var tc = (TensorConst)c; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); @@ -188,7 +188,7 @@ public void TestFromTensorValue() var tv = new TensorValue(b); var c = Const.FromValue(tv); var tc = (TensorConst)c; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(b, list[0]); diff --git a/src/Nncase.Tests/Core/IR/UnitTestTensorConst.cs b/src/Nncase.Tests/Core/IR/UnitTestTensorConst.cs index c444a0e45b..733fe17b09 100644 --- a/src/Nncase.Tests/Core/IR/UnitTestTensorConst.cs +++ b/src/Nncase.Tests/Core/IR/UnitTestTensorConst.cs @@ -18,7 +18,7 @@ public void TestByte() { byte expected = 1; TensorConst tc = expected; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); @@ -29,7 +29,7 @@ public void TestUshort() { ushort expected = 1; TensorConst tc = expected; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); @@ -40,7 +40,7 @@ public void TestUint() { uint expected = 1; TensorConst tc = expected; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); @@ -51,7 +51,7 @@ public void TestUlong() { ulong expected = 1; TensorConst tc = expected; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); @@ -62,7 +62,7 @@ public void TestSbyte() { sbyte expected = 1; TensorConst tc = expected; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); @@ -73,7 +73,7 @@ public void TestShort() { short expected = 1; TensorConst tc = expected; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); @@ -84,7 +84,7 @@ public void TestInt() { int expected = 1; TensorConst tc = expected; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); @@ -95,7 +95,7 @@ public void TestLong() { long expected = 1; TensorConst tc = expected; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); @@ -106,7 +106,7 @@ public void TestHalf() { var expected = (Half)1F; TensorConst tc = expected; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); @@ -117,7 +117,7 @@ public void TestFloat() { var expected = 1F; TensorConst tc = expected; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); @@ -128,7 +128,7 @@ public void TestDouble() { var expected = 1D; TensorConst tc = expected; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); @@ -139,7 +139,7 @@ public void TestBfloat16() { var expected = (BFloat16)1F; TensorConst tc = expected; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); @@ -150,7 +150,7 @@ public void TestBool() { var expected = false; TensorConst tc = expected; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); @@ -162,7 +162,7 @@ public void TestUtf8Char() byte b = 1; Utf8Char expected = b; TensorConst tc = expected; - Assert.True(tc.ValueType.IsScalar); + Assert.True(tc.CheckedShape.IsScalar); Assert.Equal(DataType.FromType(), tc.Value.ElementType); var list = (IList)tc.Value; Assert.Equal(expected, list[0]); diff --git a/src/Nncase.Tests/Core/UnitTestTIR.cs b/src/Nncase.Tests/Core/UnitTestTIR.cs index f0c40be178..79b72f431a 100644 --- a/src/Nncase.Tests/Core/UnitTestTIR.cs +++ b/src/Nncase.Tests/Core/UnitTestTIR.cs @@ -99,7 +99,7 @@ public void TestForSegment() public void TestGrid() { var grid1 = T.Grid(out _, LoopMode.Serial, new Range(-1f, 1f, 1)); - var grid2 = T.Grid(out _, out _, new(1, 1)); + var grid2 = T.Grid(out _, LoopMode.Serial, new Range(1, 1, 1)); Assert.Equal(grid1.GetType(), grid2.GetType()); }