diff --git a/src/Nncase.Core/CompilerServices.cs b/src/Nncase.Core/CompilerServices.cs
index 9c666bcd21..2f731a7fdf 100644
--- a/src/Nncase.Core/CompilerServices.cs
+++ b/src/Nncase.Core/CompilerServices.cs
@@ -208,6 +208,15 @@ public interface ICompilerServicesProvider
/// Options.
/// Rewrited expression.
Expr ERewrite(Expr expr, IEnumerable rules, RunPassContext options);
+
+ ///
+ /// Rewrite the expression using EGraph.
+ ///
+ /// Expression.
+ /// Rewrite rules.
+ /// Options.
+ /// Rewritten expression.
+ IEGraph ERewrite(IEGraph expr, IEnumerable rules, RunPassContext options);
}
internal interface ICompilerServicesProviderInternal
@@ -409,6 +418,18 @@ public static Expr ERewrite(Expr expr, IEnumerable rules, RunPassC
return Provider.ERewrite(expr, rules, options);
}
+ ///
+ /// Rewrite the expression using EGraph.
+ ///
+ /// Expression.
+ /// Rewrite rules.
+ /// Options.
+ /// Rewritten expression.
+ public static IEGraph ERewrite(IEGraph graph, IEnumerable rules, RunPassContext options)
+ {
+ return Provider.ERewrite(graph, rules, options);
+ }
+
///
/// Match enodes as root.
///
@@ -677,4 +698,9 @@ public Expr ERewrite(Expr expr, IEnumerable rules, RunPassContext
{
return _eGraphrewriteProvider.ERewrite(expr, rules, options);
}
+
+ public IEGraph ERewrite(IEGraph graph, IEnumerable rules, RunPassContext options)
+ {
+ return _eGraphrewriteProvider.ERewrite(graph, rules, options);
+ }
}
diff --git a/src/Nncase.Core/Enum/BinaryOp.cs b/src/Nncase.Core/Enum/BinaryOp.cs
index 45afba68cc..fa4092fa48 100644
--- a/src/Nncase.Core/Enum/BinaryOp.cs
+++ b/src/Nncase.Core/Enum/BinaryOp.cs
@@ -93,4 +93,14 @@ public enum BinaryOp : byte
/// Right Shift.
///
RightShift,
+
+ ///
+ /// Floor Div.
+ ///
+ FloorDiv,
+
+ ///
+ /// Ceil Div.
+ ///
+ CeilDiv,
}
diff --git a/src/Nncase.Core/IR/Buffers/Allocate.cs b/src/Nncase.Core/IR/Buffers/Allocate.cs
index 14b44010ec..ff7bdd13c0 100644
--- a/src/Nncase.Core/IR/Buffers/Allocate.cs
+++ b/src/Nncase.Core/IR/Buffers/Allocate.cs
@@ -13,5 +13,20 @@ namespace Nncase.IR.Buffers;
///
public sealed partial class Allocate : Op
{
- public TensorType ElemType { get; }
+ ///
+ /// Gets the size input parameter.
+ ///
+ public static readonly ParameterInfo Size = new(typeof(Allocate), 0, "size", TypePatternUtility.IsIntegralScalar());
+
+ ///
+ /// Gets the allocated buffer element type.
+ ///
+ public DataType ElemType { get; }
+
+ public TIR.MemoryLocation Location { get; }
+
+ ///
+ public override bool CanFoldConstCall => false;
+
+ public override string DisplayProperty() => $"{ElemType}, {Location}";
}
diff --git a/src/Nncase.Core/IR/Buffers/Functional.cs b/src/Nncase.Core/IR/Buffers/Functional.cs
index a2e3507a5f..463c4f1e2c 100644
--- a/src/Nncase.Core/IR/Buffers/Functional.cs
+++ b/src/Nncase.Core/IR/Buffers/Functional.cs
@@ -42,4 +42,6 @@ public static class Buffer
/// create the uninitialized buffer.
///
public static Call Uninitialized(DataType dataType, TIR.MemoryLocation memoryLocation, Expr shape) => new Call(new Uninitialized(dataType, memoryLocation), shape);
+
+ public static Call Allocate(Expr size, DataType dataType, TIR.MemoryLocation location) => new Call(new Allocate(dataType, location), size);
}
diff --git a/src/Nncase.Core/IR/TensorConst.cs b/src/Nncase.Core/IR/TensorConst.cs
index 9e651978ed..07dccfbcc3 100644
--- a/src/Nncase.Core/IR/TensorConst.cs
+++ b/src/Nncase.Core/IR/TensorConst.cs
@@ -20,12 +20,18 @@ public TensorConst(Tensor tensor)
Value = tensor;
}
+ public TensorConst(Tensor tensor, IRArray ndsbp, Placement placement)
+ : base(new DistributedType(new TensorType(tensor.ElementType, tensor.Shape), ndsbp, placement))
+ {
+ Value = tensor;
+ }
+
public Tensor Value { get; }
///
/// Gets value type.
///
- public new TensorType ValueType => (TensorType)base.ValueType;
+ public new IRType ValueType => base.ValueType;
///
/// Create TensorConstant from a .
@@ -122,25 +128,43 @@ public TensorConst(Tensor tensor)
public static bool operator !=(TensorConst? left, TensorConst? right) => !(left == right);
///
- public override string ToString() => ValueType switch
+ public override string ToString()
{
- var x when x.IsScalar =>
- x.DType switch
- {
- var dtype when DataTypes.IsIntegral(dtype) => Value.ToScalar().ToString(),
- var dtype when DataTypes.IsFloat(dtype) => Value.ToScalar().ToString(),
- var dtype when DataTypes.IsPointer(dtype) => Value.ToScalar().ToString(),
- var dtype when dtype == DataTypes.Boolean => Value.ToScalar().ToString(),
- _ => $"{x.DType.GetDisplayName()} {x.Shape}",
- },
- _ => $"{ValueType.DType.GetDisplayName()} {ValueType.Shape}",
- };
+ var type = ValueType switch
+ {
+ DistributedType dt => dt.TensorType,
+ TensorType tt => tt,
+ _ => throw new NotSupportedException("Not supported const type: " + ValueType),
+ };
+
+ return type switch
+ {
+ var x when x.IsScalar =>
+ x.DType switch
+ {
+ var dtype when DataTypes.IsIntegral(dtype) => Value.ToScalar().ToString(),
+ var dtype when DataTypes.IsFloat(dtype) => Value.ToScalar().ToString(),
+ var dtype when DataTypes.IsPointer(dtype) => Value.ToScalar().ToString(),
+ var dtype when dtype == DataTypes.Boolean => Value.ToScalar().ToString(),
+ _ => $"{x.DType.GetDisplayName()} {x.Shape}",
+ },
+ _ => $"{type.DType.GetDisplayName()} {type.Shape}",
+ };
+ }
///
public override TExprResult Accept(ExprFunctor functor, TContext context)
=> functor.VisitTensorConst(this, context);
- public TensorConst With(Tensor? value = null) => new TensorConst(value ?? Value);
+ public TensorConst With(Tensor? value = null)
+ {
+ if (value is null && ValueType is DistributedType dt)
+ {
+ return new TensorConst(Value, dt.NdSBP, dt.Placement);
+ }
+
+ return new TensorConst(value ?? Value);
+ }
///
public override bool Equals(object? obj) => Equals(obj as TensorConst);
diff --git a/src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs b/src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs
index 7559954a49..1a64b55917 100644
--- a/src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs
+++ b/src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs
@@ -23,6 +23,37 @@ protected override Expr RewriteLeafBlock(Block expr)
{
if (predicate)
{
+ if (expr.AllocBuffers.Length > 0)
+ {
+ var lets = expr.AllocBuffers.ToArray().Select(b => (T.Let(out var v, b.MemSpan.Start, b.Name + "_ptr"), v)).ToArray();
+ for (int i = 0; i < lets.Length - 1; i++)
+ {
+ lets[i].Item1.Body(lets[i + 1].Item1);
+ }
+
+ var map = new Dictionary(ReferenceEqualityComparer.Instance);
+ for (int i = 0; i < expr.AllocBuffers.Length; i++)
+ {
+ map.Add(expr.AllocBuffers[i].MemSpan.Start, lets[i].v);
+ }
+
+ var mutator = new Substitutor(e =>
+ {
+ if (map.TryGetValue(e, out var r))
+ {
+ return r;
+ }
+
+ return null;
+ });
+
+ var initBody = mutator.Visit(expr.InitBody, Unit.Default);
+ var body = mutator.Visit(expr.Body, Unit.Default);
+
+ lets[^1].Item1.Body(initBody, body);
+ return lets[0].Item1.Build();
+ }
+
return T.Sequential(expr.InitBody, expr.Body);
}
else
diff --git a/src/Nncase.Core/PatternMatch/ConstPattern.cs b/src/Nncase.Core/PatternMatch/ConstPattern.cs
index cb9a52aebe..bca5ec63a1 100644
--- a/src/Nncase.Core/PatternMatch/ConstPattern.cs
+++ b/src/Nncase.Core/PatternMatch/ConstPattern.cs
@@ -66,9 +66,9 @@ public static partial class Utility
public static TensorConstPattern IsConst(string? name, Func cond) => new(
x =>
{
- if (DataTypes.IsFloat(x.ValueType.DType))
+ if (DataTypes.IsFloat(x.CheckedDataType))
{
- if (x.ValueType.IsScalar)
+ if (x.CheckedShape.IsScalar)
{
return cond(x.Value.ToScalar());
}
@@ -93,9 +93,9 @@ public static partial class Utility
public static TensorConstPattern IsConst(string? name, Func cond) => new(
x =>
{
- if (DataTypes.IsIntegral(x.ValueType.DType))
+ if (DataTypes.IsIntegral(x.CheckedDataType))
{
- if (x.ValueType.IsScalar)
+ if (x.CheckedShape.IsScalar)
{
return cond(x.Value.ToScalar());
}
diff --git a/src/Nncase.Core/TIR/Builders/NestBodyExprBuilder.cs b/src/Nncase.Core/TIR/Builders/NestBodyExprBuilder.cs
index 60d79a6b89..2add795b71 100644
--- a/src/Nncase.Core/TIR/Builders/NestBodyExprBuilder.cs
+++ b/src/Nncase.Core/TIR/Builders/NestBodyExprBuilder.cs
@@ -51,7 +51,7 @@ public T Build()
public ISequentialBuilder InsertBody(int index, params object[] exprOrBuilders)
{
- _subBuilders[_subBuilders.Length - 1].InsertBody(index, exprOrBuilders);
+ _subBuilders[index < 0 ? _subBuilders.Length + index : index].Body(exprOrBuilders);
return this;
}
}
diff --git a/src/Nncase.Core/TIR/Script.cs b/src/Nncase.Core/TIR/Script.cs
index 28740e43ab..ffda2527df 100644
--- a/src/Nncase.Core/TIR/Script.cs
+++ b/src/Nncase.Core/TIR/Script.cs
@@ -134,8 +134,16 @@ public static ISequentialBuilder Grid(out Var[] loopVars, LoopMode loopMode
{
string[] names = { "i", "j", "k", "l" };
var newLoopVars = loopVars = new Var[ranges.Length];
- return new NestBodyExprBuilder(ranges.Select((rg, i) =>
- T.ForLoop(out newLoopVars[i], rg, loopMode, names[i % 4] + (i / 4 == 0 ? string.Empty : (i / 4).ToString())).Body()).ToArray());
+ var newLoops = ranges.Select((rg, i) => T.ForLoop(out newLoopVars[i], rg, loopMode, names[i % 4] + (i / 4 == 0 ? string.Empty : (i / 4).ToString())).Body()).ToArray();
+ return new NestBodyExprBuilder(newLoops);
+ }
+
+ public static ISequentialBuilder Grid(out Var[] loopVars, out ISequentialBuilder[] loops, LoopMode loopMode, params TIR.Range[] ranges)
+ {
+ string[] names = { "i", "j", "k", "l" };
+ var newLoopVars = loopVars = new Var[ranges.Length];
+ var newLoops = loops = ranges.Select((rg, i) => T.ForLoop(out newLoopVars[i], rg, loopMode, names[i % 4] + (i / 4 == 0 ? string.Empty : (i / 4).ToString())).Body()).ToArray();
+ return new NestBodyExprBuilder(loops);
}
///
@@ -223,6 +231,49 @@ public static Buffer CreateBuffer(TensorType tensorType, MemoryLocation location
return buffer;
}
+ ///
+ /// create the buffer by expressions.
+ ///
+ public static Buffer CreateBuffer(DataType dataType, Expr[] dimensions, MemoryLocation location, out Buffer buffer, [CallerArgumentExpression("buffer")] string name = "")
+ {
+ if (name.StartsWith("var "))
+ {
+ name = name[4..];
+ }
+
+ var strides = TensorUtilities.GetStrides(dimensions);
+ var size = TensorUtilities.GetProduct(dimensions.ToArray()) * dataType.SizeInBytes;
+ var memspan = new MemSpan(size, location);
+ buffer = new Buffer(name, dataType, memspan, dimensions, strides);
+ return buffer;
+ }
+
+ public static Buffer CreateBuffer(DataType dataType, Expr[] dimensions, Expr[] strides, MemSpan memSpan, out Buffer buffer, [CallerArgumentExpression("buffer")] string name = "")
+ {
+ if (name.StartsWith("var "))
+ {
+ name = name[4..];
+ }
+
+ buffer = new Buffer(name, dataType, memSpan, dimensions, strides);
+ return buffer;
+ }
+
+ public static Buffer AttachBuffer(Expr start, TensorType tensorType, MemoryLocation location, out Buffer buffer, [CallerArgumentExpression("buffer")] string name = "")
+ {
+ if (name.StartsWith("var "))
+ {
+ name = name[4..];
+ }
+
+ var dimensions = tensorType.Shape.ToValueArray();
+ var strides = TensorUtilities.GetStrides(dimensions);
+ var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * tensorType.DType.SizeInBytes;
+ var memspan = new MemSpan(start, size, location);
+ buffer = new Buffer(name, tensorType.DType, memspan, dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray());
+ return buffer;
+ }
+
///
/// create buffer by const.
///
@@ -233,11 +284,11 @@ public static Buffer AttachBuffer(TensorConst @const, out Buffer buffer, [Caller
name = name[4..];
}
- var dimensions = @const.ValueType.Shape.ToValueArray();
+ var dimensions = @const.CheckedShape.ToValueArray();
var strides = TensorUtilities.GetStrides(dimensions);
- var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * @const.ValueType.DType.SizeInBytes;
+ var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * @const.CheckedDataType.SizeInBytes;
var memspan = new MemSpan(IR.F.Buffer.DDrOf(@const), size, MemoryLocation.Rdata);
- buffer = new Buffer(name, @const.ValueType.DType, memspan, dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray());
+ buffer = new Buffer(name, @const.CheckedDataType, memspan, dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray());
return buffer;
}
diff --git a/src/Nncase.Core/TensorUtilities.cs b/src/Nncase.Core/TensorUtilities.cs
index 79f658aefa..146e5c6cfa 100644
--- a/src/Nncase.Core/TensorUtilities.cs
+++ b/src/Nncase.Core/TensorUtilities.cs
@@ -69,7 +69,7 @@ public static Expr GetProduct(ReadOnlySpan dimensions, int startIndex = 0)
for (int i = startIndex; i < dimensions.Length; i++)
{
var dimension = dimensions[i];
- product *= IR.F.Math.Require(dimension >= 0, dimension, "Dimension is out of range.");
+ product *= dimension;
}
return product;
diff --git a/src/Nncase.Core/Utilities/DistributedUtility.cs b/src/Nncase.Core/Utilities/DistributedUtility.cs
index 2061a40958..eb4a84be0d 100644
--- a/src/Nncase.Core/Utilities/DistributedUtility.cs
+++ b/src/Nncase.Core/Utilities/DistributedUtility.cs
@@ -2,6 +2,7 @@
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System.Diagnostics.CodeAnalysis;
+using NetFabric.Hyperlinq;
using Nncase.IR;
namespace Nncase.Utilities;
@@ -26,11 +27,7 @@ public static IReadOnlyList> GetLeafCandidateNDSBPs(TensorType tens
ndsbps.Add(ndsbp);
}
- return ndsbps.CartesianProduct().
- Select(ndsbp => ndsbp.ToArray()).
- Where(ndsbp => IsDistributable(tensorType, ndsbp, placement)).
- Select(ndsbp => new IRArray(ndsbp)).
- ToArray();
+ return ndsbps.CartesianProduct().Select(ndsbp => ndsbp.ToArray()).Where(ndsbp => IsDistributable(tensorType, ndsbp, placement)).Select(ndsbp => new IRArray(ndsbp)).ToArray();
}
public static IReadOnlyList> GetPartialCandidateNDSBPs(DistributedType distributedType)
@@ -65,11 +62,7 @@ public static IReadOnlyList> GetPartialCandidateNDSBPs(DistributedT
}
}
- return candidateNdsbps.CartesianProduct().
- Select(ndsbp => ndsbp.ToArray()).
- Where(ndsbp => IsDistributable(tensorType, ndsbp, placement)).
- Select(ndsbp => new IRArray(ndsbp)).
- ToArray();
+ return candidateNdsbps.CartesianProduct().Select(ndsbp => ndsbp.ToArray()).Where(ndsbp => IsDistributable(tensorType, ndsbp, placement)).Select(ndsbp => new IRArray(ndsbp)).ToArray();
}
public static bool IsDistributable(TensorType tensorType, ReadOnlySpan ndsbp, Placement placement)
@@ -131,24 +124,74 @@ public static Expr[] TryGetNonUniformDividedShape(DistributedType distributedTyp
}
return hierarchies.Select((divs, axis) =>
+ {
+ Expr dim;
+ if (divs.Any())
{
- Expr dim;
- if (divs.Any())
+ var divsor = (int)TensorUtilities.GetProduct(divs.Select(h => distributedType.Placement.Hierarchy[h]).ToArray());
+ var (res, rem) = Math.DivRem(shape[axis], divsor);
+ if (rem == 0)
{
- var divsor = (int)TensorUtilities.GetProduct(divs.Select(h => distributedType.Placement.Hierarchy[h]).ToArray());
- var (res, rem) = Math.DivRem(shape[axis], divsor);
- dim = IR.F.Math.Select(
- TensorUtilities.GetIndex(hierarchyStrides.TakeLast(divs.Count).Select(s => (Expr)s).ToArray(), divs.Select(h => ids[h]).ToArray()) < (divsor - 1),
- res,
- res + rem);
+ return res;
}
- else
+
+ dim = IR.F.Math.Select(
+ TensorUtilities.GetIndex(hierarchyStrides.TakeLast(divs.Count).Select(s => (Expr)s).ToArray(), divs.Select(h => ids[h]).ToArray()) < (divsor - 1),
+ res,
+ res + rem);
+ }
+ else
+ {
+ dim = distributedType.TensorType.Shape[axis].FixedValue;
+ }
+
+ return dim;
+ }).ToArray();
+ }
+
+ public static List TryGetNonUniformDividedSlice(DistributedType distributedType)
+ {
+ var shape = distributedType.TensorType.Shape.ToValueArray();
+ var hierarchies = Enumerable.Range(0, shape.Length).Select(i => new List()).ToArray();
+ for (int i = 0; i < distributedType.NdSBP.Count; i++)
+ {
+ if (distributedType.NdSBP[i] is SBPSplit { Axis: int axis })
+ {
+ hierarchies[axis].Add(i);
+ }
+ }
+
+ var spliList = hierarchies.Select, int[]>((divs, axis) =>
+ {
+ int[] dim;
+ if (divs.Any())
+ {
+ var divsor = (int)TensorUtilities.GetProduct(divs.Select(h => distributedType.Placement.Hierarchy[h]).ToArray());
+ var (res, rem) = Math.DivRem(shape[axis], divsor);
+ if (rem == 0)
{
- dim = distributedType.TensorType.Shape[axis].FixedValue;
+ return new[] { res };
}
- return dim;
- }).ToArray();
+ dim = new[] { res, res + rem };
+ }
+ else
+ {
+ dim = distributedType.TensorType.Shape.ToValueArray().Skip(axis).Take(1).ToArray();
+ }
+
+ return dim;
+ }).ToList();
+
+ IEnumerable ret = new[] { Array.Empty() };
+ foreach (int[] array in spliList)
+ {
+ ret = from seq in ret
+ from item in array
+ select seq.Concat(new[] { item }).ToArray();
+ }
+
+ return ret.ToList();
}
public static bool IsDivideBy(int input, int divisor)
@@ -174,17 +217,14 @@ public static bool IsDivideExactly(int input, int divisor)
public static float GetDividedTensorEfficiency(DistributedType distributedType, int burstLength)
{
var (tiles, shape) = GetDividedTile(distributedType);
- return Enumerable.Range(0, tiles.Count).
- Select(i => tiles[i].Ranges(0, shape[i])).
- CartesianProduct().
- Select(rgs =>
- {
- var slice = rgs.ToArray();
- var iscontiguous = TensorUtilities.IsContiguousSlice(shape.ToArray(), slice, out var contiguousStart);
- var size = TensorUtilities.GetProduct(tiles.ToArray(), contiguousStart) * distributedType.TensorType.DType.SizeInBytes;
- var (div, rem) = Math.DivRem(size, burstLength);
- return ((div * 1.0f) + ((float)rem / burstLength)) / (div + 1);
- }).Average();
+ return Enumerable.Range(0, tiles.Count).Select(i => tiles[i].Ranges(0, shape[i])).CartesianProduct().Select(rgs =>
+ {
+ var slice = rgs.ToArray();
+ var iscontiguous = TensorUtilities.IsContiguousSlice(shape.ToArray(), slice, out var contiguousStart);
+ var size = TensorUtilities.GetProduct(tiles.ToArray(), contiguousStart) * distributedType.TensorType.DType.SizeInBytes;
+ var (div, rem) = Math.DivRem(size, burstLength);
+ return ((div * 1.0f) + ((float)rem / burstLength)) / (div + 1);
+ }).Average();
}
public static TensorType GetDividedTensorType(DistributedType distributedType)
diff --git a/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs b/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs
index 4a447073b8..079dafc907 100644
--- a/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs
+++ b/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs
@@ -248,6 +248,8 @@ public ILPrintVisitor(TextWriter textWriter, bool display_callable, int indent_l
_scope = new(textWriter, indent_level);
}
+ public override string DefaultVisitType(IRType type) => type.ToString();
+
///
public override string VisitType(AnyType type) => "any";
diff --git a/src/Nncase.Evaluator/Buffers/Allocate.cs b/src/Nncase.Evaluator/Buffers/Allocate.cs
index d1a75d3375..d09f1654f3 100644
--- a/src/Nncase.Evaluator/Buffers/Allocate.cs
+++ b/src/Nncase.Evaluator/Buffers/Allocate.cs
@@ -14,6 +14,6 @@ public partial class AllocateEvaluator : ITypeInferencer
///
public IRType Visit(ITypeInferenceContext context, Allocate target)
{
- return TensorType.Pointer(target.ElemType.DType);
+ return TensorType.Pointer(target.ElemType);
}
}
diff --git a/src/Nncase.Evaluator/Buffers/DDrOf.cs b/src/Nncase.Evaluator/Buffers/DDrOf.cs
index 86c53e04b7..b329ee1787 100644
--- a/src/Nncase.Evaluator/Buffers/DDrOf.cs
+++ b/src/Nncase.Evaluator/Buffers/DDrOf.cs
@@ -12,8 +12,13 @@ namespace Nncase.Evaluator.Buffers;
[TypeInferGenerator]
public partial class DDrOfEvaluator : ITypeInferencer
{
- private IRType Visit(TensorType input)
+ private IRType Visit(IRType input)
{
- return TensorType.Pointer(input.DType);
+ return input switch
+ {
+ DistributedType d => TensorType.Pointer(d.TensorType.DType),
+ TensorType t => TensorType.Pointer(t.DType),
+ _ => new InvalidType(input.GetType().Name),
+ };
}
}
diff --git a/src/Nncase.Evaluator/Math/Binary.cs b/src/Nncase.Evaluator/Math/Binary.cs
index 1ac424b0be..4e8ab2e493 100755
--- a/src/Nncase.Evaluator/Math/Binary.cs
+++ b/src/Nncase.Evaluator/Math/Binary.cs
@@ -214,6 +214,8 @@ private IRType Visit(Binary target, DistributedType a, DistributedType b)
BinaryOp.Sub => a - b,
BinaryOp.Mul => a * b,
BinaryOp.Div => a / b,
+ BinaryOp.FloorDiv => (int)System.Math.Floor((float)a / b),
+ BinaryOp.CeilDiv => (int)System.Math.Ceiling((float)a / b),
BinaryOp.Mod => a % b,
BinaryOp.Min => System.Math.Min(a, b),
BinaryOp.Max => System.Math.Max(a, b),
@@ -227,6 +229,8 @@ private IRType Visit(Binary target, DistributedType a, DistributedType b)
BinaryOp.Sub => a - b,
BinaryOp.Mul => a * b,
BinaryOp.Div => a / b,
+ BinaryOp.FloorDiv => (uint)System.Math.Floor((float)a / b),
+ BinaryOp.CeilDiv => (uint)System.Math.Ceiling((float)a / b),
BinaryOp.Mod => a % b,
BinaryOp.Min => System.Math.Min(a, b),
BinaryOp.Max => System.Math.Max(a, b),
@@ -242,6 +246,8 @@ private IRType Visit(Binary target, DistributedType a, DistributedType b)
BinaryOp.Sub => a - b,
BinaryOp.Mul => a * b,
BinaryOp.Div => a / b,
+ BinaryOp.FloorDiv => (ulong)System.Math.Floor((float)a / b),
+ BinaryOp.CeilDiv => (ulong)System.Math.Ceiling((float)a / b),
BinaryOp.Mod => a % b,
BinaryOp.Min => System.Math.Min(a, b),
BinaryOp.Max => System.Math.Max(a, b),
@@ -262,6 +268,8 @@ private IRType Visit(Binary target, DistributedType a, DistributedType b)
BinaryOp.Sub => a - b,
BinaryOp.Mul => a * b,
BinaryOp.Div => a / b,
+ BinaryOp.FloorDiv => (long)System.Math.Floor((float)a / b),
+ BinaryOp.CeilDiv => (long)System.Math.Ceiling((float)a / b),
BinaryOp.Mod => a % b,
BinaryOp.Min => System.Math.Min(a, b),
BinaryOp.Max => System.Math.Max(a, b),
@@ -298,6 +306,8 @@ static OrtKISharp.Tensor Mod(OrtKISharp.Tensor a, OrtKISharp.Tensor b)
BinaryOp.Sub => a - b,
BinaryOp.Mul => a * b,
BinaryOp.Div => a / b,
+ BinaryOp.FloorDiv => OrtKI.Floor(a.Cast(OrtDataType.Float) / b.Cast(OrtDataType.Float)).Cast(a.DataType),
+ BinaryOp.CeilDiv => OrtKI.Ceil(a.Cast(OrtDataType.Float) / b.Cast(OrtDataType.Float)).Cast(a.DataType),
BinaryOp.Mod => Mod(a, b),
BinaryOp.Min => OrtKI.Min(new[] { a, b }),
BinaryOp.Max => OrtKI.Max(new[] { a, b }),
diff --git a/src/Nncase.Evaluator/Math/MatMul.cs b/src/Nncase.Evaluator/Math/MatMul.cs
index 1f19b64388..4642a4f8d5 100644
--- a/src/Nncase.Evaluator/Math/MatMul.cs
+++ b/src/Nncase.Evaluator/Math/MatMul.cs
@@ -23,7 +23,7 @@ public static IRType VisitDistributedType(DistributedType a, DistributedType b)
{
if (VisitTensorType(a.TensorType, b.TensorType) is not TensorType outType)
{
- return new InvalidType(string.Empty);
+ return new InvalidType($"{a.TensorType} {b.TensorType} not support");
}
if (a.Placement != b.Placement)
@@ -162,7 +162,7 @@ public IRType Visit(ITypeInferenceContext context, MatMul target)
{
(DistributedType a, DistributedType b) => VisitDistributedType(a, b),
(TensorType a, TensorType b) => VisitTensorType(a, b),
- _ => new InvalidType(string.Empty),
+ _ => new InvalidType($"{lhs} {rhs} not support"),
};
}
diff --git a/src/Nncase.Importer/Onnx/QLinearConv.cs b/src/Nncase.Importer/Onnx/QLinearConv.cs
index 0ab9ebc99c..3e1bb7e07a 100644
--- a/src/Nncase.Importer/Onnx/QLinearConv.cs
+++ b/src/Nncase.Importer/Onnx/QLinearConv.cs
@@ -29,13 +29,13 @@ private Expr VisitQLinearConv(in NodeProto op)
var group = GetIntAttribute(op, "group", 1);
var strides = GetStrideAttribute(op);
- int? stridesValueLen = ((TensorConst)strides).ValueType.Shape[0].Value;
+ int? stridesValueLen = ((TensorConst)strides).CheckedShape[0].Value;
for (var i = 0; i < stridesValueLen; i++)
{
System.Diagnostics.Trace.Assert(((TensorConst)strides).Value.Cast()[i] <= (long)int.MaxValue);
}
- int? dilationValueLen = ((TensorConst)dilation).ValueType.Shape[0].Value;
+ int? dilationValueLen = ((TensorConst)dilation).CheckedShape[0].Value;
for (var i = 0; i < dilationValueLen; i++)
{
System.Diagnostics.Trace.Assert(((TensorConst)dilation).Value.Cast()[i] <= (long)int.MaxValue);
@@ -63,16 +63,16 @@ private Expr VisitQLinearConv(in NodeProto op)
if (bias == null)
{
- int? ocNumber = ((TensorConst)weights).ValueType.Shape[0].Value;
+ int? ocNumber = ((TensorConst)weights).CheckedShape[0].Value;
var zeroBias = new TensorConst(new int[ocNumber == null ? default(int) : ocNumber.Value]);
var conv = F.NN.Conv2D(inputDeq, weightsDeq, zeroBias, strideConst, pads, dilationConst, PadMode.Constant, group);
- return Quantize(conv, new QuantParam(((TensorConst)yZeroPoint).Value.ToScalar(), ((TensorConst)yScale).Value.ToScalar()), ((TensorConst)yZeroPoint).ValueType.DType);
+ return Quantize(conv, new QuantParam(((TensorConst)yZeroPoint).Value.ToScalar(), ((TensorConst)yScale).Value.ToScalar()), ((TensorConst)yZeroPoint).CheckedDataType);
}
else
{
var biasDeq = Dequantize(bias, new QuantParam(0, ((TensorConst)xScale).Value.ToScalar() * ((TensorConst)wScale).Value.ToScalar()), DataTypes.Float32);
var conv = F.NN.Conv2D(inputDeq, weightsDeq, biasDeq, strideConst, pads, dilationConst, PadMode.Constant, group);
- return Quantize(conv, new QuantParam(((TensorConst)yZeroPoint).Value.ToScalar(), ((TensorConst)yScale).Value.ToScalar()), ((TensorConst)yZeroPoint).ValueType.DType);
+ return Quantize(conv, new QuantParam(((TensorConst)yZeroPoint).Value.ToScalar(), ((TensorConst)yScale).Value.ToScalar()), ((TensorConst)yZeroPoint).CheckedDataType);
}
}
}
diff --git a/src/Nncase.Importer/Onnx/QLinearMatmul.cs b/src/Nncase.Importer/Onnx/QLinearMatmul.cs
index 5ab2ad73d8..892a6b6c2d 100644
--- a/src/Nncase.Importer/Onnx/QLinearMatmul.cs
+++ b/src/Nncase.Importer/Onnx/QLinearMatmul.cs
@@ -25,7 +25,7 @@ private Expr VisitQLinearMatMul(in NodeProto op)
var aDeq = Dequantize(input_a, new QuantParam(((TensorConst)aZeroPoint).Value.ToScalar(), ((TensorConst)aScale).Value.ToScalar()), DataTypes.Float32);
var bDeq = Dequantize(input_b, new QuantParam(((TensorConst)bZeroPoint).Value.ToScalar(), ((TensorConst)bScale).Value.ToScalar()), DataTypes.Float32);
var matmul = F.Tensors.MatMul(aDeq, bDeq);
- return Quantize(matmul, new QuantParam(((TensorConst)yZeroPoint).Value.ToScalar(), ((TensorConst)yScale).Value.ToScalar()), ((TensorConst)yZeroPoint).ValueType.DType);
+ return Quantize(matmul, new QuantParam(((TensorConst)yZeroPoint).Value.ToScalar(), ((TensorConst)yScale).Value.ToScalar()), ((TensorConst)yZeroPoint).CheckedDataType);
}
}
}
diff --git a/src/Nncase.Importer/Onnx/Quantize.cs b/src/Nncase.Importer/Onnx/Quantize.cs
index cc33583711..6a4c771cb4 100644
--- a/src/Nncase.Importer/Onnx/Quantize.cs
+++ b/src/Nncase.Importer/Onnx/Quantize.cs
@@ -23,7 +23,7 @@ private Expr VisitQuantizeLinear(in NodeProto op)
new QuantParam(
biasConst.Value.ToScalar(),
scaleConst.Value.ToScalar()),
- ((TensorConst)bias).ValueType.DType);
+ ((TensorConst)bias).CheckedDataType);
}
throw new NotImplementedException("Onnx importer not impl for dynamic scale and bias");
diff --git a/src/Nncase.Passes/DDrBufferSchdeulePass.cs b/src/Nncase.Passes/DDrBufferSchdeulePass.cs
index 80aebda267..15a6505686 100644
--- a/src/Nncase.Passes/DDrBufferSchdeulePass.cs
+++ b/src/Nncase.Passes/DDrBufferSchdeulePass.cs
@@ -134,7 +134,7 @@ protected override Expr RewriteLeafBuffer(TIR.Buffer expr)
protected override TIR.MemSpan RewriteLeafMemSpan(TIR.MemSpan memSpan)
{
- if (memSpan is { Location: MemoryLocation.Rdata, Start: Call { Target: IR.Buffers.DDrOf, Arguments: var arg } } && arg[0] is Const { ValueType: TensorType constType } @const)
+ if (memSpan is { Location: MemoryLocation.Rdata, Start: Call { Target: IR.Buffers.DDrOf, Arguments: var arg } } && arg[0] is Const @const)
{
if (!ModuleRdataMaps.TryGetValue(Entry.ModuleKind, out var moduleRdataMap))
{
@@ -163,7 +163,7 @@ protected override TIR.MemSpan RewriteLeafMemSpan(TIR.MemSpan memSpan)
Changed = true;
}
- return memSpan.With(new TensorConst(Tensor.FromPointer((ulong)memRange.Min, constType.DType)), memRange.Max - memRange.Min);
+ return memSpan.With(new TensorConst(Tensor.FromPointer((ulong)memRange.Min, @const.CheckedDataType)), memRange.Max - memRange.Min);
}
return memSpan;
diff --git a/src/Nncase.Passes/ModulePass.cs b/src/Nncase.Passes/ModulePass.cs
index dc35cd0843..25f8edd98f 100644
--- a/src/Nncase.Passes/ModulePass.cs
+++ b/src/Nncase.Passes/ModulePass.cs
@@ -30,7 +30,7 @@ protected override Task OnPassStartAsync(IRModule input, RunPassContext context)
{
foreach (var func in input.Functions)
{
- DumpScope.Current.DumpIR(func, func.Name, "Start");
+ DumpScope.Current.DumpIR(func, string.Empty, "Start");
}
}
@@ -44,7 +44,7 @@ protected override Task OnPassEndAsync(IRModule post, RunPassContext context)
{
foreach (var func in post.Functions)
{
- DumpScope.Current.DumpIR(func, func.Name, "End");
+ DumpScope.Current.DumpIR(func, string.Empty, "End");
}
}
diff --git a/src/Nncase.Tests/Core/IR/UnitTestConst.cs b/src/Nncase.Tests/Core/IR/UnitTestConst.cs
index 80b8a5e669..45e4e4585f 100644
--- a/src/Nncase.Tests/Core/IR/UnitTestConst.cs
+++ b/src/Nncase.Tests/Core/IR/UnitTestConst.cs
@@ -18,7 +18,7 @@ public void TestByte()
byte expected = 1;
Const c = expected;
var tc = (TensorConst)c;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
@@ -30,7 +30,7 @@ public void TestUshort()
ushort expected = 1;
Const c = expected;
var tc = (TensorConst)c;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
@@ -42,7 +42,7 @@ public void TestUint()
uint expected = 1;
Const c = expected;
var tc = (TensorConst)c;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
@@ -54,7 +54,7 @@ public void TestUlong()
ulong expected = 1;
Const c = expected;
var tc = (TensorConst)c;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
@@ -66,7 +66,7 @@ public void TestSbyte()
sbyte expected = 1;
Const c = expected;
var tc = (TensorConst)c;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
@@ -78,7 +78,7 @@ public void TestShort()
short expected = 1;
Const c = expected;
var tc = (TensorConst)c;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
@@ -90,7 +90,7 @@ public void TestInt()
int expected = 1;
Const c = expected;
var tc = (TensorConst)c;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
@@ -102,7 +102,7 @@ public void TestLong()
long expected = 1;
Const c = expected;
var tc = (TensorConst)c;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
@@ -114,7 +114,7 @@ public void TestHalf()
var expected = (Half)1F;
Const c = expected;
var tc = (TensorConst)c;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
@@ -126,7 +126,7 @@ public void TestFloat()
var expected = 1F;
Const c = expected;
var tc = (TensorConst)c;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
@@ -138,7 +138,7 @@ public void TestDouble()
var expected = 1D;
Const c = expected;
var tc = (TensorConst)c;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
@@ -150,7 +150,7 @@ public void TestBfloat16()
var expected = (BFloat16)1F;
Const c = expected;
var tc = (TensorConst)c;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
@@ -162,7 +162,7 @@ public void TestBool()
var expected = false;
Const c = expected;
var tc = (TensorConst)c;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
@@ -175,7 +175,7 @@ public void TestUtf8Char()
Utf8Char expected = b;
Const c = expected;
var tc = (TensorConst)c;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
@@ -188,7 +188,7 @@ public void TestFromTensorValue()
var tv = new TensorValue(b);
var c = Const.FromValue(tv);
var tc = (TensorConst)c;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(b, list[0]);
diff --git a/src/Nncase.Tests/Core/IR/UnitTestTensorConst.cs b/src/Nncase.Tests/Core/IR/UnitTestTensorConst.cs
index c444a0e45b..733fe17b09 100644
--- a/src/Nncase.Tests/Core/IR/UnitTestTensorConst.cs
+++ b/src/Nncase.Tests/Core/IR/UnitTestTensorConst.cs
@@ -18,7 +18,7 @@ public void TestByte()
{
byte expected = 1;
TensorConst tc = expected;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
@@ -29,7 +29,7 @@ public void TestUshort()
{
ushort expected = 1;
TensorConst tc = expected;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
@@ -40,7 +40,7 @@ public void TestUint()
{
uint expected = 1;
TensorConst tc = expected;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
@@ -51,7 +51,7 @@ public void TestUlong()
{
ulong expected = 1;
TensorConst tc = expected;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
@@ -62,7 +62,7 @@ public void TestSbyte()
{
sbyte expected = 1;
TensorConst tc = expected;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
@@ -73,7 +73,7 @@ public void TestShort()
{
short expected = 1;
TensorConst tc = expected;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
@@ -84,7 +84,7 @@ public void TestInt()
{
int expected = 1;
TensorConst tc = expected;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
@@ -95,7 +95,7 @@ public void TestLong()
{
long expected = 1;
TensorConst tc = expected;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
@@ -106,7 +106,7 @@ public void TestHalf()
{
var expected = (Half)1F;
TensorConst tc = expected;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
@@ -117,7 +117,7 @@ public void TestFloat()
{
var expected = 1F;
TensorConst tc = expected;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
@@ -128,7 +128,7 @@ public void TestDouble()
{
var expected = 1D;
TensorConst tc = expected;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
@@ -139,7 +139,7 @@ public void TestBfloat16()
{
var expected = (BFloat16)1F;
TensorConst tc = expected;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
@@ -150,7 +150,7 @@ public void TestBool()
{
var expected = false;
TensorConst tc = expected;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
@@ -162,7 +162,7 @@ public void TestUtf8Char()
byte b = 1;
Utf8Char expected = b;
TensorConst tc = expected;
- Assert.True(tc.ValueType.IsScalar);
+ Assert.True(tc.CheckedShape.IsScalar);
Assert.Equal(DataType.FromType(), tc.Value.ElementType);
var list = (IList)tc.Value;
Assert.Equal(expected, list[0]);
diff --git a/src/Nncase.Tests/Core/UnitTestTIR.cs b/src/Nncase.Tests/Core/UnitTestTIR.cs
index f0c40be178..79b72f431a 100644
--- a/src/Nncase.Tests/Core/UnitTestTIR.cs
+++ b/src/Nncase.Tests/Core/UnitTestTIR.cs
@@ -99,7 +99,7 @@ public void TestForSegment()
public void TestGrid()
{
var grid1 = T.Grid(out _, LoopMode.Serial, new Range(-1f, 1f, 1));
- var grid2 = T.Grid(out _, out _, new(1, 1));
+ var grid2 = T.Grid(out _, LoopMode.Serial, new Range(1, 1, 1));
Assert.Equal(grid1.GetType(), grid2.GetType());
}