Skip to content

Commit

Permalink
Revert "Revert "GNNE-1980 Xpu Sram Codegen (#1148)""
Browse files Browse the repository at this point in the history
This reverts commit dbc04fa.
  • Loading branch information
zhangyang2057 committed Dec 26, 2023
1 parent dbc04fa commit be412c2
Show file tree
Hide file tree
Showing 24 changed files with 321 additions and 105 deletions.
26 changes: 26 additions & 0 deletions src/Nncase.Core/CompilerServices.cs
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,15 @@ public interface ICompilerServicesProvider
/// <param name="options">Options.</param>
/// <returns>Rewrited expression.</returns>
Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options);

/// <summary>
/// Using EGraph rewrite expression.
/// </summary>
/// <param name="expr">Expression.</param>
/// <param name="rules">Rewrite rules.</param>
/// <param name="options">Options.</param>
/// <returns>Rewrited expression.</returns>
IEGraph ERewrite(IEGraph expr, IEnumerable<IRewriteRule> rules, RunPassContext options);
}

internal interface ICompilerServicesProviderInternal
Expand Down Expand Up @@ -409,6 +418,18 @@ public static Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassC
return Provider.ERewrite(expr, rules, options);
}

/// <summary>
/// Using EGraph rewrite expression.
/// </summary>
/// <param name="graph">Expression.</param>
/// <param name="rules">Rewrite rules.</param>
/// <param name="options">Options.</param>
/// <returns>Rewrited expression.</returns>
public static IEGraph ERewrite(IEGraph graph, IEnumerable<IRewriteRule> rules, RunPassContext options)
{
return Provider.ERewrite(graph, rules, options);
}

/// <summary>
/// Match enodes as root.
/// </summary>
Expand Down Expand Up @@ -677,4 +698,9 @@ public Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext
{
return _eGraphrewriteProvider.ERewrite(expr, rules, options);
}

public IEGraph ERewrite(IEGraph graph, IEnumerable<IRewriteRule> rules, RunPassContext options)
{
return _eGraphrewriteProvider.ERewrite(graph, rules, options);
}
}
10 changes: 10 additions & 0 deletions src/Nncase.Core/Enum/BinaryOp.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,14 @@ public enum BinaryOp : byte
/// Right Shift.
/// </summary>
RightShift,

/// <summary>
/// Floor Div.
/// </summary>
FloorDiv,

/// <summary>
/// Ceil Div.
/// </summary>
CeilDiv,
}
17 changes: 16 additions & 1 deletion src/Nncase.Core/IR/Buffers/Allocate.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,20 @@ namespace Nncase.IR.Buffers;
/// </summary>
public sealed partial class Allocate : Op
{
public TensorType ElemType { get; }
/// <summary>
/// Get the input parameter.
/// </summary>
public static readonly ParameterInfo Size = new(typeof(Allocate), 0, "size", TypePatternUtility.IsIntegralScalar());

/// <summary>
/// Gets the alloacted buffer type.
/// </summary>
public DataType ElemType { get; }

public TIR.MemoryLocation Location { get; }

/// <inheritdoc/>
public override bool CanFoldConstCall => false;

public override string DisplayProperty() => $"{ElemType}, {Location}";
}
2 changes: 2 additions & 0 deletions src/Nncase.Core/IR/Buffers/Functional.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,6 @@ public static class Buffer
/// create the uninitialized buffer.
/// </summary>
public static Call Uninitialized(DataType dataType, TIR.MemoryLocation memoryLocation, Expr shape) => new Call(new Uninitialized(dataType, memoryLocation), shape);

public static Call Allocate(Expr size, DataType dataType, TIR.MemoryLocation location) => new Call(new Allocate(dataType, location), size);
}
52 changes: 38 additions & 14 deletions src/Nncase.Core/IR/TensorConst.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,18 @@ public TensorConst(Tensor tensor)
Value = tensor;
}

public TensorConst(Tensor tensor, IRArray<SBP> ndsbp, Placement placement)
: base(new DistributedType(new TensorType(tensor.ElementType, tensor.Shape), ndsbp, placement))
{
Value = tensor;
}

public Tensor Value { get; }

/// <summary>
/// Gets value type.
/// </summary>
public new TensorType ValueType => (TensorType)base.ValueType;
public new IRType ValueType => base.ValueType;

/// <summary>
/// Create TensorConstant from a <see cref="byte"/>.
Expand Down Expand Up @@ -122,25 +128,43 @@ public TensorConst(Tensor tensor)
public static bool operator !=(TensorConst? left, TensorConst? right) => !(left == right);

/// <inheritdoc/>
public override string ToString() => ValueType switch
public override string ToString()
{
var x when x.IsScalar =>
x.DType switch
{
var dtype when DataTypes.IsIntegral(dtype) => Value.ToScalar<long>().ToString(),
var dtype when DataTypes.IsFloat(dtype) => Value.ToScalar<float>().ToString(),
var dtype when DataTypes.IsPointer(dtype) => Value.ToScalar<ulong>().ToString(),
var dtype when dtype == DataTypes.Boolean => Value.ToScalar<bool>().ToString(),
_ => $"{x.DType.GetDisplayName()} {x.Shape}",
},
_ => $"{ValueType.DType.GetDisplayName()} {ValueType.Shape}",
};
var type = ValueType switch
{
DistributedType dt => dt.TensorType,
TensorType tt => tt,
_ => throw new NotSupportedException("Not supported const type: " + ValueType),
};

return type switch
{
var x when x.IsScalar =>
x.DType switch
{
var dtype when DataTypes.IsIntegral(dtype) => Value.ToScalar<long>().ToString(),
var dtype when DataTypes.IsFloat(dtype) => Value.ToScalar<float>().ToString(),
var dtype when DataTypes.IsPointer(dtype) => Value.ToScalar<ulong>().ToString(),
var dtype when dtype == DataTypes.Boolean => Value.ToScalar<bool>().ToString(),
_ => $"{x.DType.GetDisplayName()} {x.Shape}",
},
_ => $"{type.DType.GetDisplayName()} {type.Shape}",
};
}

/// <inheritdoc/>
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context)
=> functor.VisitTensorConst(this, context);

public TensorConst With(Tensor? value = null) => new TensorConst(value ?? Value);
public TensorConst With(Tensor? value = null)
{
if (value is null && ValueType is DistributedType dt)
{
return new TensorConst(Value, dt.NdSBP, dt.Placement);
}

return new TensorConst(value ?? Value);
}

/// <inheritdoc/>
public override bool Equals(object? obj) => Equals(obj as TensorConst);
Expand Down
31 changes: 31 additions & 0 deletions src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,37 @@ protected override Expr RewriteLeafBlock(Block expr)
{
if (predicate)
{
if (expr.AllocBuffers.Length > 0)
{
var lets = expr.AllocBuffers.ToArray().Select(b => (T.Let(out var v, b.MemSpan.Start, b.Name + "_ptr"), v)).ToArray();
for (int i = 0; i < lets.Length - 1; i++)
{
lets[i].Item1.Body(lets[i + 1].Item1);
}

var map = new Dictionary<Expr, Expr>(ReferenceEqualityComparer.Instance);
for (int i = 0; i < expr.AllocBuffers.Length; i++)
{
map.Add(expr.AllocBuffers[i].MemSpan.Start, lets[i].v);
}

var mutator = new Substitutor(e =>
{
if (map.TryGetValue(e, out var r))
{
return r;
}
return null;
});

var initBody = mutator.Visit(expr.InitBody, Unit.Default);
var body = mutator.Visit(expr.Body, Unit.Default);

lets[^1].Item1.Body(initBody, body);
return lets[0].Item1.Build();
}

return T.Sequential(expr.InitBody, expr.Body);
}
else
Expand Down
8 changes: 4 additions & 4 deletions src/Nncase.Core/PatternMatch/ConstPattern.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ public static partial class Utility
public static TensorConstPattern IsConst(string? name, Func<float, bool> cond) => new(
x =>
{
if (DataTypes.IsFloat(x.ValueType.DType))
if (DataTypes.IsFloat(x.CheckedDataType))
{
if (x.ValueType.IsScalar)
if (x.CheckedShape.IsScalar)
{
return cond(x.Value.ToScalar<float>());
}
Expand All @@ -93,9 +93,9 @@ public static partial class Utility
public static TensorConstPattern IsConst(string? name, Func<int, bool> cond) => new(
x =>
{
if (DataTypes.IsIntegral(x.ValueType.DType))
if (DataTypes.IsIntegral(x.CheckedDataType))
{
if (x.ValueType.IsScalar)
if (x.CheckedShape.IsScalar)
{
return cond(x.Value.ToScalar<int>());
}
Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Core/TIR/Builders/NestBodyExprBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public T Build()

public ISequentialBuilder<T> InsertBody(int index, params object[] exprOrBuilders)
{
_subBuilders[_subBuilders.Length - 1].InsertBody(index, exprOrBuilders);
_subBuilders[index < 0 ? _subBuilders.Length + index : index].Body(exprOrBuilders);
return this;
}
}
61 changes: 56 additions & 5 deletions src/Nncase.Core/TIR/Script.cs
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,16 @@ public static ISequentialBuilder<For> Grid(out Var[] loopVars, LoopMode loopMode
{
string[] names = { "i", "j", "k", "l" };
var newLoopVars = loopVars = new Var[ranges.Length];
return new NestBodyExprBuilder<For>(ranges.Select((rg, i) =>
T.ForLoop(out newLoopVars[i], rg, loopMode, names[i % 4] + (i / 4 == 0 ? string.Empty : (i / 4).ToString())).Body()).ToArray());
var newLoops = ranges.Select((rg, i) => T.ForLoop(out newLoopVars[i], rg, loopMode, names[i % 4] + (i / 4 == 0 ? string.Empty : (i / 4).ToString())).Body()).ToArray();
return new NestBodyExprBuilder<For>(newLoops);
}

public static ISequentialBuilder<For> Grid(out Var[] loopVars, out ISequentialBuilder<For>[] loops, LoopMode loopMode, params TIR.Range[] ranges)
{
string[] names = { "i", "j", "k", "l" };
var newLoopVars = loopVars = new Var[ranges.Length];
var newLoops = loops = ranges.Select((rg, i) => T.ForLoop(out newLoopVars[i], rg, loopMode, names[i % 4] + (i / 4 == 0 ? string.Empty : (i / 4).ToString())).Body()).ToArray();
return new NestBodyExprBuilder<For>(loops);
}

/// <summary>
Expand Down Expand Up @@ -223,6 +231,49 @@ public static Buffer CreateBuffer(TensorType tensorType, MemoryLocation location
return buffer;
}

/// <summary>
/// create the buffer by expressions.
/// </summary>
public static Buffer CreateBuffer(DataType dataType, Expr[] dimensions, MemoryLocation location, out Buffer buffer, [CallerArgumentExpression("buffer")] string name = "")
{
if (name.StartsWith("var "))
{
name = name[4..];
}

var strides = TensorUtilities.GetStrides(dimensions);
var size = TensorUtilities.GetProduct(dimensions.ToArray()) * dataType.SizeInBytes;
var memspan = new MemSpan(size, location);
buffer = new Buffer(name, dataType, memspan, dimensions, strides);
return buffer;
}

public static Buffer CreateBuffer(DataType dataType, Expr[] dimensions, Expr[] strides, MemSpan memSpan, out Buffer buffer, [CallerArgumentExpression("buffer")] string name = "")
{
if (name.StartsWith("var "))
{
name = name[4..];
}

buffer = new Buffer(name, dataType, memSpan, dimensions, strides);
return buffer;
}

public static Buffer AttachBuffer(Expr start, TensorType tensorType, MemoryLocation location, out Buffer buffer, [CallerArgumentExpression("buffer")] string name = "")
{
if (name.StartsWith("var "))
{
name = name[4..];
}

var dimensions = tensorType.Shape.ToValueArray();
var strides = TensorUtilities.GetStrides(dimensions);
var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * tensorType.DType.SizeInBytes;
var memspan = new MemSpan(start, size, location);
buffer = new Buffer(name, tensorType.DType, memspan, dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray());
return buffer;
}

/// <summary>
/// create buffer by const.
/// </summary>
Expand All @@ -233,11 +284,11 @@ public static Buffer AttachBuffer(TensorConst @const, out Buffer buffer, [Caller
name = name[4..];
}

var dimensions = @const.ValueType.Shape.ToValueArray();
var dimensions = @const.CheckedShape.ToValueArray();
var strides = TensorUtilities.GetStrides(dimensions);
var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * @const.ValueType.DType.SizeInBytes;
var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * @const.CheckedDataType.SizeInBytes;
var memspan = new MemSpan(IR.F.Buffer.DDrOf(@const), size, MemoryLocation.Rdata);
buffer = new Buffer(name, @const.ValueType.DType, memspan, dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray());
buffer = new Buffer(name, @const.CheckedDataType, memspan, dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray());
return buffer;
}

Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Core/TensorUtilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public static Expr GetProduct(ReadOnlySpan<Expr> dimensions, int startIndex = 0)
for (int i = startIndex; i < dimensions.Length; i++)
{
var dimension = dimensions[i];
product *= IR.F.Math.Require(dimension >= 0, dimension, "Dimension is out of range.");
product *= dimension;
}

return product;
Expand Down

0 comments on commit be412c2

Please sign in to comment.