Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GNNE-1980 Xpu Sram Codegen #1148

Merged
merged 14 commits into from
Dec 22, 2023
26 changes: 26 additions & 0 deletions src/Nncase.Core/CompilerServices.cs
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,15 @@
/// <param name="options">Options.</param>
/// <returns>Rewrited expression.</returns>
Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options);

/// <summary>
/// Using EGraph rewrite expression.
/// </summary>
/// <param name="expr">Expression.</param>
/// <param name="rules">Rewrite rules.</param>
/// <param name="options">Options.</param>
/// <returns>Rewrited expression.</returns>
IEGraph ERewrite(IEGraph expr, IEnumerable<IRewriteRule> rules, RunPassContext options);
}

internal interface ICompilerServicesProviderInternal
Expand Down Expand Up @@ -409,6 +418,18 @@
return Provider.ERewrite(expr, rules, options);
}

/// <summary>
/// Using EGraph rewrite expression.
/// </summary>
/// <param name="graph">Expression.</param>
/// <param name="rules">Rewrite rules.</param>
/// <param name="options">Options.</param>
/// <returns>Rewrited expression.</returns>
public static IEGraph ERewrite(IEGraph graph, IEnumerable<IRewriteRule> rules, RunPassContext options)
{
return Provider.ERewrite(graph, rules, options);

Check warning on line 430 in src/Nncase.Core/CompilerServices.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/CompilerServices.cs#L430

Added line #L430 was not covered by tests
}

/// <summary>
/// Match enodes as root.
/// </summary>
Expand Down Expand Up @@ -677,4 +698,9 @@
{
return _eGraphrewriteProvider.ERewrite(expr, rules, options);
}

public IEGraph ERewrite(IEGraph graph, IEnumerable<IRewriteRule> rules, RunPassContext options)
{
return _eGraphrewriteProvider.ERewrite(graph, rules, options);

Check warning on line 704 in src/Nncase.Core/CompilerServices.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/CompilerServices.cs#L704

Added line #L704 was not covered by tests
}
}
10 changes: 10 additions & 0 deletions src/Nncase.Core/Enum/BinaryOp.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,14 @@ public enum BinaryOp : byte
/// Right Shift.
/// </summary>
RightShift,

/// <summary>
/// Floor Div.
/// </summary>
FloorDiv,

/// <summary>
/// Ceil Div.
/// </summary>
CeilDiv,
}
17 changes: 16 additions & 1 deletion src/Nncase.Core/IR/Buffers/Allocate.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,20 @@
/// </summary>
public sealed partial class Allocate : Op
{
public TensorType ElemType { get; }
/// <summary>
/// Get the input parameter.
/// </summary>
public static readonly ParameterInfo Size = new(typeof(Allocate), 0, "size", TypePatternUtility.IsIntegralScalar());

Check warning on line 19 in src/Nncase.Core/IR/Buffers/Allocate.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/IR/Buffers/Allocate.cs#L19

Added line #L19 was not covered by tests

/// <summary>
/// Gets the alloacted buffer type.
/// </summary>
public DataType ElemType { get; }

Check warning on line 24 in src/Nncase.Core/IR/Buffers/Allocate.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/IR/Buffers/Allocate.cs#L24

Added line #L24 was not covered by tests

public TIR.MemoryLocation Location { get; }

Check warning on line 26 in src/Nncase.Core/IR/Buffers/Allocate.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/IR/Buffers/Allocate.cs#L26

Added line #L26 was not covered by tests

/// <inheritdoc/>
public override bool CanFoldConstCall => false;

Check warning on line 29 in src/Nncase.Core/IR/Buffers/Allocate.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/IR/Buffers/Allocate.cs#L29

Added line #L29 was not covered by tests

public override string DisplayProperty() => $"{ElemType}, {Location}";

Check warning on line 31 in src/Nncase.Core/IR/Buffers/Allocate.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/IR/Buffers/Allocate.cs#L31

Added line #L31 was not covered by tests
}
2 changes: 2 additions & 0 deletions src/Nncase.Core/IR/Buffers/Functional.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,6 @@
/// create the uninitialized buffer.
/// </summary>
public static Call Uninitialized(DataType dataType, TIR.MemoryLocation memoryLocation, Expr shape) => new Call(new Uninitialized(dataType, memoryLocation), shape);

public static Call Allocate(Expr size, DataType dataType, TIR.MemoryLocation location) => new Call(new Allocate(dataType, location), size);

Check warning on line 46 in src/Nncase.Core/IR/Buffers/Functional.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/IR/Buffers/Functional.cs#L46

Added line #L46 was not covered by tests
}
52 changes: 38 additions & 14 deletions src/Nncase.Core/IR/TensorConst.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,18 @@
Value = tensor;
}

public TensorConst(Tensor tensor, IRArray<SBP> ndsbp, Placement placement)
: base(new DistributedType(new TensorType(tensor.ElementType, tensor.Shape), ndsbp, placement))

Check warning on line 24 in src/Nncase.Core/IR/TensorConst.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/IR/TensorConst.cs#L24

Added line #L24 was not covered by tests
{
Value = tensor;
}

Check warning on line 27 in src/Nncase.Core/IR/TensorConst.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/IR/TensorConst.cs#L26-L27

Added lines #L26 - L27 were not covered by tests

public Tensor Value { get; }

/// <summary>
/// Gets value type.
/// </summary>
public new TensorType ValueType => (TensorType)base.ValueType;
public new IRType ValueType => base.ValueType;

/// <summary>
/// Create TensorConstant from a <see cref="byte"/>.
Expand Down Expand Up @@ -122,25 +128,43 @@
public static bool operator !=(TensorConst? left, TensorConst? right) => !(left == right);

/// <inheritdoc/>
public override string ToString() => ValueType switch
public override string ToString()
{
var x when x.IsScalar =>
x.DType switch
{
var dtype when DataTypes.IsIntegral(dtype) => Value.ToScalar<long>().ToString(),
var dtype when DataTypes.IsFloat(dtype) => Value.ToScalar<float>().ToString(),
var dtype when DataTypes.IsPointer(dtype) => Value.ToScalar<ulong>().ToString(),
var dtype when dtype == DataTypes.Boolean => Value.ToScalar<bool>().ToString(),
_ => $"{x.DType.GetDisplayName()} {x.Shape}",
},
_ => $"{ValueType.DType.GetDisplayName()} {ValueType.Shape}",
};
var type = ValueType switch
{
DistributedType dt => dt.TensorType,

Check warning on line 135 in src/Nncase.Core/IR/TensorConst.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/IR/TensorConst.cs#L135

Added line #L135 was not covered by tests
TensorType tt => tt,
_ => throw new NotSupportedException("Not supported const type: " + ValueType),
};

return type switch
{
var x when x.IsScalar =>
x.DType switch
{
var dtype when DataTypes.IsIntegral(dtype) => Value.ToScalar<long>().ToString(),
var dtype when DataTypes.IsFloat(dtype) => Value.ToScalar<float>().ToString(),
var dtype when DataTypes.IsPointer(dtype) => Value.ToScalar<ulong>().ToString(),
var dtype when dtype == DataTypes.Boolean => Value.ToScalar<bool>().ToString(),
_ => $"{x.DType.GetDisplayName()} {x.Shape}",
},
_ => $"{type.DType.GetDisplayName()} {type.Shape}",
};
}

/// <inheritdoc/>
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context)
=> functor.VisitTensorConst(this, context);

public TensorConst With(Tensor? value = null) => new TensorConst(value ?? Value);
public TensorConst With(Tensor? value = null)
{
if (value is null && ValueType is DistributedType dt)
{
return new TensorConst(Value, dt.NdSBP, dt.Placement);

Check warning on line 163 in src/Nncase.Core/IR/TensorConst.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/IR/TensorConst.cs#L163

Added line #L163 was not covered by tests
}

return new TensorConst(value ?? Value);
}

/// <inheritdoc/>
public override bool Equals(object? obj) => Equals(obj as TensorConst);
Expand Down
31 changes: 31 additions & 0 deletions src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,37 @@
{
if (predicate)
{
if (expr.AllocBuffers.Length > 0)
{
var lets = expr.AllocBuffers.ToArray().Select(b => (T.Let(out var v, b.MemSpan.Start, b.Name + "_ptr"), v)).ToArray();
for (int i = 0; i < lets.Length - 1; i++)
{
lets[i].Item1.Body(lets[i + 1].Item1);

Check warning on line 31 in src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs#L31

Added line #L31 was not covered by tests
}

var map = new Dictionary<Expr, Expr>(ReferenceEqualityComparer.Instance);

Check warning on line 34 in src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs#L34

Added line #L34 was not covered by tests
for (int i = 0; i < expr.AllocBuffers.Length; i++)
{
map.Add(expr.AllocBuffers[i].MemSpan.Start, lets[i].v);

Check warning on line 37 in src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs#L37

Added line #L37 was not covered by tests
}

var mutator = new Substitutor(e =>
{

Check warning on line 41 in src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs#L40-L41

Added lines #L40 - L41 were not covered by tests
if (map.TryGetValue(e, out var r))
{
return r;
}

return null;
});

Check warning on line 48 in src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs#L43-L48

Added lines #L43 - L48 were not covered by tests

var initBody = mutator.Visit(expr.InitBody, Unit.Default);
var body = mutator.Visit(expr.Body, Unit.Default);

Check warning on line 51 in src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs#L50-L51

Added lines #L50 - L51 were not covered by tests

lets[^1].Item1.Body(initBody, body);
return lets[0].Item1.Build();

Check warning on line 54 in src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs#L53-L54

Added lines #L53 - L54 were not covered by tests
}

return T.Sequential(expr.InitBody, expr.Body);
}
else
Expand Down
8 changes: 4 additions & 4 deletions src/Nncase.Core/PatternMatch/ConstPattern.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ public static partial class Utility
public static TensorConstPattern IsConst(string? name, Func<float, bool> cond) => new(
x =>
{
if (DataTypes.IsFloat(x.ValueType.DType))
if (DataTypes.IsFloat(x.CheckedDataType))
{
if (x.ValueType.IsScalar)
if (x.CheckedShape.IsScalar)
{
return cond(x.Value.ToScalar<float>());
}
Expand All @@ -93,9 +93,9 @@ public static partial class Utility
public static TensorConstPattern IsConst(string? name, Func<int, bool> cond) => new(
x =>
{
if (DataTypes.IsIntegral(x.ValueType.DType))
if (DataTypes.IsIntegral(x.CheckedDataType))
{
if (x.ValueType.IsScalar)
if (x.CheckedShape.IsScalar)
{
return cond(x.Value.ToScalar<int>());
}
Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Core/TIR/Builders/NestBodyExprBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public T Build()

public ISequentialBuilder<T> InsertBody(int index, params object[] exprOrBuilders)
{
_subBuilders[_subBuilders.Length - 1].InsertBody(index, exprOrBuilders);
_subBuilders[index < 0 ? _subBuilders.Length + index : index].Body(exprOrBuilders);
return this;
}
}
61 changes: 56 additions & 5 deletions src/Nncase.Core/TIR/Script.cs
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,16 @@
{
string[] names = { "i", "j", "k", "l" };
var newLoopVars = loopVars = new Var[ranges.Length];
return new NestBodyExprBuilder<For>(ranges.Select((rg, i) =>
T.ForLoop(out newLoopVars[i], rg, loopMode, names[i % 4] + (i / 4 == 0 ? string.Empty : (i / 4).ToString())).Body()).ToArray());
var newLoops = ranges.Select((rg, i) => T.ForLoop(out newLoopVars[i], rg, loopMode, names[i % 4] + (i / 4 == 0 ? string.Empty : (i / 4).ToString())).Body()).ToArray();
return new NestBodyExprBuilder<For>(newLoops);
}

public static ISequentialBuilder<For> Grid(out Var[] loopVars, out ISequentialBuilder<For>[] loops, LoopMode loopMode, params TIR.Range[] ranges)
{
string[] names = { "i", "j", "k", "l" };
var newLoopVars = loopVars = new Var[ranges.Length];

Check warning on line 144 in src/Nncase.Core/TIR/Script.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/TIR/Script.cs#L143-L144

Added lines #L143 - L144 were not covered by tests
var newLoops = loops = ranges.Select((rg, i) => T.ForLoop(out newLoopVars[i], rg, loopMode, names[i % 4] + (i / 4 == 0 ? string.Empty : (i / 4).ToString())).Body()).ToArray();
return new NestBodyExprBuilder<For>(loops);

Check warning on line 146 in src/Nncase.Core/TIR/Script.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/TIR/Script.cs#L146

Added line #L146 was not covered by tests
}

/// <summary>
Expand Down Expand Up @@ -223,6 +231,49 @@
return buffer;
}

/// <summary>
/// create the buffer by expressions.
/// </summary>
public static Buffer CreateBuffer(DataType dataType, Expr[] dimensions, MemoryLocation location, out Buffer buffer, [CallerArgumentExpression("buffer")] string name = "")
{
if (name.StartsWith("var "))
{
name = name[4..];

Check warning on line 241 in src/Nncase.Core/TIR/Script.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/TIR/Script.cs#L241

Added line #L241 was not covered by tests
}

var strides = TensorUtilities.GetStrides(dimensions);
var size = TensorUtilities.GetProduct(dimensions.ToArray()) * dataType.SizeInBytes;
var memspan = new MemSpan(size, location);
buffer = new Buffer(name, dataType, memspan, dimensions, strides);
return buffer;

Check warning on line 248 in src/Nncase.Core/TIR/Script.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/TIR/Script.cs#L244-L248

Added lines #L244 - L248 were not covered by tests
}

public static Buffer CreateBuffer(DataType dataType, Expr[] dimensions, Expr[] strides, MemSpan memSpan, out Buffer buffer, [CallerArgumentExpression("buffer")] string name = "")
{
if (name.StartsWith("var "))
{
name = name[4..];

Check warning on line 255 in src/Nncase.Core/TIR/Script.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/TIR/Script.cs#L255

Added line #L255 was not covered by tests
}

buffer = new Buffer(name, dataType, memSpan, dimensions, strides);
return buffer;

Check warning on line 259 in src/Nncase.Core/TIR/Script.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/TIR/Script.cs#L258-L259

Added lines #L258 - L259 were not covered by tests
}

public static Buffer AttachBuffer(Expr start, TensorType tensorType, MemoryLocation location, out Buffer buffer, [CallerArgumentExpression("buffer")] string name = "")
{
if (name.StartsWith("var "))
{
name = name[4..];

Check warning on line 266 in src/Nncase.Core/TIR/Script.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/TIR/Script.cs#L266

Added line #L266 was not covered by tests
}

var dimensions = tensorType.Shape.ToValueArray();
var strides = TensorUtilities.GetStrides(dimensions);
var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * tensorType.DType.SizeInBytes;
var memspan = new MemSpan(start, size, location);

Check warning on line 272 in src/Nncase.Core/TIR/Script.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/TIR/Script.cs#L269-L272

Added lines #L269 - L272 were not covered by tests
buffer = new Buffer(name, tensorType.DType, memspan, dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray());
return buffer;

Check warning on line 274 in src/Nncase.Core/TIR/Script.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/TIR/Script.cs#L274

Added line #L274 was not covered by tests
}

/// <summary>
/// create buffer by const.
/// </summary>
Expand All @@ -233,11 +284,11 @@
name = name[4..];
}

var dimensions = @const.ValueType.Shape.ToValueArray();
var dimensions = @const.CheckedShape.ToValueArray();
var strides = TensorUtilities.GetStrides(dimensions);
var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * @const.ValueType.DType.SizeInBytes;
var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * @const.CheckedDataType.SizeInBytes;
var memspan = new MemSpan(IR.F.Buffer.DDrOf(@const), size, MemoryLocation.Rdata);
buffer = new Buffer(name, @const.ValueType.DType, memspan, dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray());
buffer = new Buffer(name, @const.CheckedDataType, memspan, dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray());
return buffer;
}

Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Core/TensorUtilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public static Expr GetProduct(ReadOnlySpan<Expr> dimensions, int startIndex = 0)
for (int i = startIndex; i < dimensions.Length; i++)
{
var dimension = dimensions[i];
product *= IR.F.Math.Require(dimension >= 0, dimension, "Dimension is out of range.");
product *= dimension;
}

return product;
Expand Down
Loading
Loading