diff --git a/src/Nncase.Compiler/Compiler.cs b/src/Nncase.Compiler/Compiler.cs index 3d83a6b0b4..c21b7247d3 100644 --- a/src/Nncase.Compiler/Compiler.cs +++ b/src/Nncase.Compiler/Compiler.cs @@ -100,7 +100,6 @@ public void TargetIndependentPass(IPassManager passManager) p.Add(); p.Add(); p.Add(); - p.Add(); p.Add(); p.Add(); p.Add(); @@ -141,6 +140,7 @@ public void TargetIndependentPass(IPassManager passManager) p.Add(); p.Add(); p.Add(); + p.Add(); p.Add(); }); diff --git a/src/Nncase.Core/TIR/TIRUtilities.cs b/src/Nncase.Core/TIR/TIRUtilities.cs index 6a1142e71c..6dcad9ad16 100644 --- a/src/Nncase.Core/TIR/TIRUtilities.cs +++ b/src/Nncase.Core/TIR/TIRUtilities.cs @@ -71,4 +71,21 @@ public static IReadOnlyList ComputeBounds(IReadOnlyList su IR.F.Math.Max(0, t.First.Start), IR.F.Math.Min(t.Second.FixedValue, t.First.Stop), t.First.Step)).ToArray(); + + public static bool TryGetFixedRegions(TIR.BufferRegion region, out (int Start, int Stop, int Step)[] slice) + { + slice = new (int Start, int Stop, int Step)[region.Region.Length]; + for (int i = 0; i < region.Region.Length; i++) + { + var rg = region.Region[i]; + if (rg is not Range { Start: IR.TensorConst start, Stop: IR.TensorConst stop, Step: IR.TensorConst step }) + { + return false; + } + + slice[i] = (start.Value.ToScalar(), stop.Value.ToScalar(), step.Value.ToScalar()); + } + + return true; + } } diff --git a/src/Nncase.Core/Utilities/ShapeExprUtility.cs b/src/Nncase.Core/Utilities/ShapeExprUtility.cs index c1da04dca4..9c73513cf7 100644 --- a/src/Nncase.Core/Utilities/ShapeExprUtility.cs +++ b/src/Nncase.Core/Utilities/ShapeExprUtility.cs @@ -1,7 +1,6 @@ // Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. -using GiGraph.Dot.Output.Writers.Edges; using Nncase.Diagnostics; using Nncase.IR; using Nncase.IR.Tensors; diff --git a/src/Nncase.EGraph/CostModel/EGraphCostPrinter.cs b/src/Nncase.EGraph/CostModel/EGraphCostPrinter.cs index c8f3bd5c38..ac2d1c1caa 100644 --- a/src/Nncase.EGraph/CostModel/EGraphCostPrinter.cs +++ b/src/Nncase.EGraph/CostModel/EGraphCostPrinter.cs @@ -32,6 +32,25 @@ internal static DotGraph DumpEgraphAsDot(IEGraph eGraph, CostModel.EGraphCostMod return printer.SaveToStream(file); } + /// + /// find the minCostEnode in eclass. + /// + /// the marker first. + /// + /// + internal static ENode MinByWithMarker(EClass eClass, CostModel.EGraphCostModel costModel) + { + return eClass.Nodes.OrderBy(e => e.Expr, ENodeTypeComparer.Instance).MinBy(x => x.Expr is Marker ? CostModel.Cost.Zero : costModel[x])!; + } + + /// + /// find the minCostEnode in eclass skip marker. + /// + internal static ENode MinByWithOutMarker(EClass eClass, CostModel.EGraphCostModel costModel) + { + return eClass.Nodes.Where(e => e.Expr is not Marker).MinBy(x => costModel[x])!; + } + private DotGraph AttachEGraphCost(CostModel.EGraphCostModel costModel, EClass entry) { // 1. display each enode costs. @@ -72,12 +91,12 @@ void Dfs(EClass curclass) continue; } - var minCostEnode = parent.MinByWithMarker(costModel); + var minCostEnode = MinByWithMarker(parent, costModel); // when this marker ecalss has been visited, skip it. if (markerEclassMemo.Contains(parent)) { - minCostEnode = parent.MinByWithOutMarker(costModel); + minCostEnode = MinByWithOutMarker(parent, costModel); } var (minCostDotnode, table) = NodesMap[minCostEnode]; @@ -93,7 +112,7 @@ void Dfs(EClass curclass) if (minCostEnode.Expr is Marker && child == parent) { markerEclassMemo.Add(child); - var otherminCostENode = child.MinByWithOutMarker(costModel); + var otherminCostENode = MinByWithOutMarker(child, costModel); var (childDotNode, _) = NodesMap[otherminCostENode]; _dotGraph.Edges.Add(childDotNode, minCostDotnode, edge => { @@ -103,7 +122,7 @@ void Dfs(EClass curclass) } else { - var childEnode = child.Find().MinByWithMarker(costModel); + var childEnode = MinByWithMarker(child.Find(), costModel); var (childDotNode, _) = NodesMap[childEnode]; _dotGraph.Edges.Add(childDotNode, minCostDotnode, edge => { @@ -126,3 +145,23 @@ void Dfs(EClass curclass) return _dotGraph; } } + +internal sealed class ENodeTypeComparer : IComparer +{ + public static readonly ENodeTypeComparer Instance = new(); + + public int Compare(Expr? x, Expr? y) => (x, y) switch + { + (null, null) => 0, + (Expr, null) => 1, + (null, Expr) => -1, + (Expr, Expr) => GetPriority(x).CompareTo(GetPriority(y)), + }; + + private int GetPriority(Expr x) => x switch + { + Marker => 0, + Const => 1, + _ => 2, + }; +} diff --git a/src/Nncase.EGraph/Passes/EGraphExtensions.cs b/src/Nncase.EGraph/Passes/EGraphExtensions.cs new file mode 100644 index 0000000000..5a349cec48 --- /dev/null +++ b/src/Nncase.EGraph/Passes/EGraphExtensions.cs @@ -0,0 +1,50 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using Google.OrTools.Sat; +using Nncase.CostModel; +using Nncase.Diagnostics; +using Nncase.IR; +using Nncase.PatternMatch; +using static Nncase.PatternMatch.F.Math; +using static Nncase.PatternMatch.Utility; + +namespace Nncase.Passes; + +/// +/// EGraph extract extensions. +/// +public static class EGraphExtensions +{ + /// + /// Extract egraph. + /// + /// egraph. + /// Root eclass. + /// base func cost evaluator. + /// the cp model constrains. + public static Expr Extract(this IEGraph eGraph, EClass root, Evaluator.IBaseFuncCostEvaluator? basefunc_cost_evaluator, EGraphExtractConstrains[] constrains) + { + // 1. set enode expr with more accuracy type. + foreach (var eclass in eGraph.Classes) + { + foreach (var nodes in eclass.Nodes) + { + if (eclass.CheckedType.CompareTo(nodes.Expr.CheckedType) > 0) + { + nodes.Expr.CheckedType = eclass.CheckedType; + } + } + } + + // 2. start the cost evaluator + var costModel = new CostModel.EGraphCostEvaluator(root.Find(), basefunc_cost_evaluator, false).Evaluate(); + + return new EGraphExtractor(costModel).Extract(root.Find(), eGraph, constrains); + } +} diff --git a/src/Nncase.EGraph/Passes/EGraphExtractExtensions.cs b/src/Nncase.EGraph/Passes/EGraphExtractExtensions.cs deleted file mode 100644 index 7c0cfbdc26..0000000000 --- a/src/Nncase.EGraph/Passes/EGraphExtractExtensions.cs +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright (c) Canaan Inc. All rights reserved. -// Licensed under the Apache license. See LICENSE file in the project root for full license information. - -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Text; -using Nncase.CostModel; -using Nncase.Diagnostics; -using Nncase.IR; -using Nncase.PatternMatch; -using static Nncase.PatternMatch.F.Math; -using static Nncase.PatternMatch.Utility; - -namespace Nncase.Passes; - -/// -/// EGraph extract extensions. -/// -public static class EGraphExtractExtensions -{ - /// - /// Extract egraph. - /// - /// eGraph. - /// Root eclass. - /// base func cost evaluator. - /// the picks. - /// Extracted root expression. - public static Expr Extract(this IEGraph eGraph, EClass root, Evaluator.IBaseFuncCostEvaluator? basefunc_cost_evaluator, out IReadOnlyDictionary picks) - { - // 1. set enode expr with more accuracy type. - foreach (var eclass in eGraph.Classes) - { - foreach (var nodes in eclass.Nodes) - { - if (eclass.CheckedType.CompareTo(nodes.Expr.CheckedType) > 0) - { - nodes.Expr.CheckedType = eclass.CheckedType; - } - } - } - - // 2. start the cost evaluator - var costModel = new EGraphCostEvaluator(root.Find(), basefunc_cost_evaluator, false).Evaluate(); - - // if (DumpScope.Current.IsEnabled(DumpFlags.EGraphCost)) - // { - // using var fs = DumpScope.Current.OpenFile(Path.Combine("Costs", $"V{eGraph.Version}.dot")); - // EGraphPrinter.DumpEgraphAsDot(eGraph, costModel, root.Find(), fs); - // } - // return new EGraphExtractor(costModel).Extract(root.Find(), eGraph); - return new EGraphExtractors.SatExtractor(costModel).Extract(root.Find(), eGraph, out picks); - } - - /// - /// find the minCostEnode in eclass. - /// - /// the marker first. - /// - /// - internal static ENode MinByWithMarker(this EClass eClass, CostModel.EGraphCostModel costModel) - { - return eClass.Nodes.OrderBy(e => e.Expr, ENodeTypeComparer.Instance).MinBy(x => x.Expr is Marker ? Cost.Zero : costModel[x])!; - } - - /// - /// find the minCostEnode in eclass skip marker. - /// - internal static ENode MinByWithOutMarker(this EClass eClass, CostModel.EGraphCostModel costModel) - { - return eClass.Nodes.Where(e => e.Expr is not Marker).MinBy(x => costModel[x])!; - } - - internal sealed class ENodeTypeComparer : IComparer - { - public static readonly ENodeTypeComparer Instance = new(); - - public int Compare(Expr? x, Expr? y) => (x, y) switch - { - (null, null) => 0, - (Expr, null) => 1, - (null, Expr) => -1, - (Expr, Expr) => GetPriority(x).CompareTo(GetPriority(y)), - }; - - private int GetPriority(Expr x) => x switch - { - Marker => 0, - Const => 1, - _ => 2, - }; - } -} diff --git a/src/Nncase.EGraph/Passes/EGraphExtractors/SatExtractor.cs b/src/Nncase.EGraph/Passes/EGraphExtractor.cs similarity index 94% rename from src/Nncase.EGraph/Passes/EGraphExtractors/SatExtractor.cs rename to src/Nncase.EGraph/Passes/EGraphExtractor.cs index fab2fc05ea..c7eba45704 100644 --- a/src/Nncase.EGraph/Passes/EGraphExtractors/SatExtractor.cs +++ b/src/Nncase.EGraph/Passes/EGraphExtractor.cs @@ -11,18 +11,20 @@ using Nncase.Diagnostics; using Nncase.IR; -namespace Nncase.Passes.EGraphExtractors; +namespace Nncase.Passes; -internal class SatExtractor : IExtractor +public delegate void EGraphExtractConstrains(CpModel model, IReadOnlyDictionary vars); + +internal class EGraphExtractor { private readonly EGraphCostModel _costModel; - public SatExtractor(EGraphCostModel costModel) + public EGraphExtractor(EGraphCostModel costModel) { _costModel = costModel; } - public Expr Extract(EClass root, IEGraph eGraph, out IReadOnlyDictionary picks) + public Expr Extract(EClass root, IEGraph eGraph, EGraphExtractConstrains[] constrains) { var cpmodel = new CpModel(); @@ -68,6 +70,11 @@ public Expr Extract(EClass root, IEGraph eGraph, out IReadOnlyDictionary vars[n]), eGraph.Nodes.Select(n => checked((long)_costModel[n].Score)))); @@ -121,7 +128,7 @@ public Expr Extract(EClass root, IEGraph eGraph, out IReadOnlyDictionary e, e => solver.BooleanValue(vars[e])); + var picks = eGraph.Nodes.ToDictionary(e => e, e => solver.BooleanValue(vars[e])); using (var dumpStream = enableDump ? DumpScope.Current.OpenFile("Costs/Pick.dot") : Stream.Null) { EGraphPrinter.DumpEgraphAsDot(eGraph, _costModel, picks, root.Find(), dumpStream); diff --git a/src/Nncase.EGraph/Passes/EGraphExtractors/Extractor.cs b/src/Nncase.EGraph/Passes/EGraphExtractors/Extractor.cs deleted file mode 100644 index fcf2abd729..0000000000 --- a/src/Nncase.EGraph/Passes/EGraphExtractors/Extractor.cs +++ /dev/null @@ -1,200 +0,0 @@ -// Copyright (c) Canaan Inc. All rights reserved. -// Licensed under the Apache license. See LICENSE file in the project root for full license information. - -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Text; -using Nncase.CostModel; -using Nncase.Diagnostics; -using Nncase.IR; -using Nncase.PatternMatch; -using static Nncase.PatternMatch.F.Math; -using static Nncase.PatternMatch.Utility; - -namespace Nncase.Passes.EGraphExtractors; - -internal interface IExtractor -{ - Expr Extract(EClass root, IEGraph eGraph, out IReadOnlyDictionary picks); -} - -internal class Extractor : IExtractor -{ - private readonly EGraphCostModel _costModel; - private readonly Dictionary _eclassMemo = new(); - private readonly Dictionary _markerEclassMemo = new(); - private readonly Dictionary _picks = new(); - private StreamWriter? _dumpWriter; - - public Extractor(EGraphCostModel costModel) - { - _costModel = costModel; - } - - public Expr Extract(EClass root, IEGraph eGraph, out IReadOnlyDictionary picks) - { - _dumpWriter = DumpScope.Current.IsEnabled(DumpFlags.EGraphCost) - ? new StreamWriter(DumpScope.Current.OpenFile($"{nameof(Extractor)}_Class_{root.Id}.txt")) - : null; - try - { - Visit(root); - } - finally - { - _dumpWriter?.Dispose(); - } - - foreach (var enode in eGraph.Nodes) - { - if (!_picks.ContainsKey(enode)) - { - _picks[enode] = false; - } - } - - picks = _picks; - return _eclassMemo[root]; - } - - private void Visit(EClass eclass) - { - var stack = new Stack<(EClass, ENode)>(); - stack.Push((eclass, eclass.MinByWithMarker(_costModel))); - var markerEclassSet = new HashSet(); - while (stack.Any()) - { - (eclass, var minCostEnode) = stack.Peek(); - if (_eclassMemo.ContainsKey(eclass)) - { - stack.Pop(); - continue; - } - - Expr? expr = null; - switch (minCostEnode.Expr) - { - case Var or TensorConst or TupleConst or Op or Fusion or None: - expr = minCostEnode.Expr; - break; - case Function or Call or IR.Tuple or Marker or IR.If: - var childrenExprs = new List(); - foreach (var child in minCostEnode.Children) - { - if (!_eclassMemo.TryGetValue(child, out var childExpr)) - { - if (minCostEnode.Expr is Marker && child == eclass) - { - if (!_markerEclassMemo.TryGetValue(eclass, out var markerInputExpr)) - { - markerEclassSet.Add(eclass); - stack.Push((eclass, eclass.MinByWithOutMarker(_costModel))); - } - else - { - childrenExprs.Add(markerInputExpr); - } - } - else - { - stack.Push((child, child.MinByWithMarker(_costModel))); - } - } - else - { - childrenExprs.Add(childExpr); - } - } - - if (childrenExprs.Count != minCostEnode.Children.Count) - { - break; - } - - expr = minCostEnode.Expr switch - { - Function function => Visit(minCostEnode, function, new(childrenExprs)), - Call call => Visit(minCostEnode, call, new(childrenExprs)), - IR.Tuple tuple => Visit(minCostEnode, tuple, new(childrenExprs)), - Marker marker => Visit(minCostEnode, marker, new(childrenExprs)), - IR.If @if => Visit(minCostEnode, @if, new(childrenExprs)), - _ => throw new ArgumentException("Unsupported expression type."), - }; - - break; - default: - throw new ArgumentException("Unsupported expression type."); - } - - if (expr is null) - { - continue; - } - - if (markerEclassSet.Contains(eclass) && minCostEnode.Expr is not Marker) - { - _markerEclassMemo.Add(eclass, expr); - } - else - { - _eclassMemo.Add(eclass, expr); - } - - _picks[minCostEnode] = true; - stack.Pop(); - } - } - - private Marker Visit(ENode enode, Marker marker, IRArray children) - { - var target = children[0]; - var attr = children[1]; - return marker.With(target: target, attribute: attr); - } - - private Function Visit(ENode enode, Function func, IRArray children) - { - if (children.Count == 0) - { - return func; - } - - var body = children[0]; - return func.With(body: body); - } - - private IR.Tuple Visit(ENode enode, IR.Tuple tuple, IRArray children) - { - return tuple.With(fields: children.ToArray()); - } - - private IR.If Visit(ENode enode, IR.If @if, IRArray children) - { - return @if.With(condition: children[^3], then: children[^2], @else: children[^1], paramList: children[..^3].ToArray()); - } - - private Call Visit(ENode enode, Call call, IRArray children) - { - var target = children[0]; - var arguments = children.Skip(1); - - // for mix quant debug. - if (call.EnodeQuantConfigWithCosine != null && _dumpWriter != null) - { - _dumpWriter.WriteLine(call + " " + call.CheckedType); - for (int i = 0; i < call.EnodeQuantConfigWithCosine.Count; i++) - { - for (int j = 0; j < call.EnodeQuantConfigWithCosine[i].Item1.Count; j++) - { - _dumpWriter.Write(call.EnodeQuantConfigWithCosine[i].Item1[j] + " "); - } - - _dumpWriter.WriteLine(call.EnodeQuantConfigWithCosine[i].Item3); - } - } - - return call.With(target: target, arguments: arguments.ToArray(), call.Metadata); - } -} diff --git a/src/Nncase.EGraph/Passes/RewriteProvider.cs b/src/Nncase.EGraph/Passes/RewriteProvider.cs index 07d3416edf..ad64c22073 100644 --- a/src/Nncase.EGraph/Passes/RewriteProvider.cs +++ b/src/Nncase.EGraph/Passes/RewriteProvider.cs @@ -36,7 +36,7 @@ public Expr ERewrite(Expr expr, IEnumerable rules, RunPassContext var graph = new EGraph(expr); ERewrite(graph, rules, options); - var post = graph.Extract(graph.Root!, null, out _); + var post = graph.Extract(graph.Root!, null, Array.Empty()); return post; } diff --git a/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs b/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs index 13e8ab86f0..3b76f1673b 100644 --- a/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs +++ b/src/Nncase.Passes/BufferSchedule/BufferScheduleTypes.cs @@ -3,54 +3,34 @@ namespace Nncase.Passes.BufferSchedule; -internal sealed class TimeInterval +public sealed class Interval { - public TimeInterval(int start, int end) - { - Brith = start; - Death = end; - } - - public int Brith { get; set; } - - public int Death { get; set; } - - public int Size => Death - Brith; - - public override string ToString() - { - return $"TimeInterval({Brith}, {Death})"; - } -} - -internal sealed class MemSpan -{ - public MemSpan(int start, int end) + public Interval(int start, int end) { Start = start; - End = end; + Stop = end; } public int Start { get; set; } - public int End { get; set; } + public int Stop { get; set; } - public int Size => End - Start; + public int Size => Stop - Start; public override string ToString() { - return $"MemSpan({Start}, {End})"; + return $"Interval({Start}, {Stop})"; } } -internal class ScheduleBuffer +public class ScheduleBuffer { - public ScheduleBuffer(string name, int number, TimeInterval interval, MemSpan span, int[] shape, int[] strides, bool inplace) + public ScheduleBuffer(string name, int number, Interval timeInterval, Interval memInterval, int[] shape, int[] strides, bool inplace) { Name = name; Number = number; - Interval = interval; - Span = span; + TimeInterval = timeInterval; + MemInterval = memInterval; Shape = shape; Strides = strides; Inplace = inplace; @@ -60,9 +40,9 @@ public ScheduleBuffer(string name, int number, TimeInterval interval, MemSpan sp public int Number { get; } - public TimeInterval Interval { get; } + public Interval TimeInterval { get; } - public MemSpan Span { get; } + public Interval MemInterval { get; } public int[] Shape { get; } @@ -72,6 +52,6 @@ public ScheduleBuffer(string name, int number, TimeInterval interval, MemSpan sp public override string ToString() { - return $"ScheduledBuffer('{Name}', {Number}, {Interval}, {Span}, ConstraintsMode.No, [{string.Join(",", Shape)}], [{string.Join(",", Strides)}], {Inplace})"; + return $"ScheduledBuffer('{Name}', {Number}, {TimeInterval}, {MemInterval}, ConstraintsMode.No, [{string.Join(",", Shape)}], [{string.Join(",", Strides)}], {Inplace})"; } } diff --git a/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs b/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs index 25f7d6f5b8..8d02aff4e8 100644 --- a/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs +++ b/src/Nncase.Passes/BufferSchedule/BufferScheduler.cs @@ -13,12 +13,42 @@ namespace Nncase.Passes.BufferSchedule; -internal sealed class BufferScheduler +public class BufferScheduler { - public IReadOnlyDictionary CollectLifeTime(Function func) + public virtual void ExternalConstrains(CpModel model, IReadOnlyDictionary bufferMap, IReadOnlyDictionary boxs) { - var c = new LifeTimeCollector(); - return c.Collect(func); + foreach (var (expr, item) in bufferMap) + { + if (expr is Call { Target: IR.Tensors.Concat } concatCall && concatCall.Arguments[0] is IR.Tuple tuple) + { + // the concat inputs must contiguous + int offset = 0; + for (int i = 0; i < tuple.Fields.Length; i++) + { + model.Add((boxs[concatCall].Y.StartExpr() + offset) == boxs[tuple.Fields[i]].Y.StartExpr()); + offset += bufferMap[tuple.Fields[i]].MemInterval.Size; + } + } + else if (expr is Call { Target: IR.Tensors.Split } splitCall) + { + // the split must equal with input. + model.Add(boxs[splitCall].Y.StartExpr() == boxs[splitCall.Arguments[0]].Y.StartExpr()); + + // the split outputs must contiguous + var users = splitCall.GetUsers(); + int offset = 0; + foreach (var user in users.OrderBy(e => ((Call)e).Arguments[1].Evaluate().AsTensor().ToScalar())) + { + model.Add((boxs[splitCall].Y.StartExpr() + offset) == boxs[user].Y.StartExpr()); + offset += bufferMap[user].MemInterval.Size; + } + } + else if (expr is Call { Target: IR.Tensors.Reshape } reshapCall) + { + // the reshape must equal with it's input. + model.Add(boxs[reshapCall].Y.StartExpr() == boxs[reshapCall.Arguments[0]].Y.StartExpr()); + } + } } public void Schedule(IReadOnlyDictionary bufferMap) @@ -30,21 +60,21 @@ public void Schedule(IReadOnlyDictionary bufferMap) var yStarts = new List(); foreach (var (expr, item) in bufferMap) { - var xInterval = model.NewIntervalVar(model.NewConstant(item.Interval.Brith), model.NewConstant(item.Interval.Size), model.NewConstant(item.Interval.Death), item.Name + $"{item.Number}_x"); + var xInterval = model.NewIntervalVar(model.NewConstant(item.TimeInterval.Start), model.NewConstant(item.TimeInterval.Size), model.NewConstant(item.TimeInterval.Stop), item.Name + $"{item.Number}_x"); - var upbound = 2147483648 - item.Span.End; + var upbound = 2147483648 - item.MemInterval.Stop; if (upbound <= 0) { throw new System.NotSupportedException(); } var memStartVar = model.NewIntVar(0, upbound, $"{item.Name}_{item.Number}_y_start"); - var yInterval = model.NewFixedSizeIntervalVar(memStartVar, item.Span.End, $"{item.Name}_{item.Number}_y"); + var yInterval = model.NewFixedSizeIntervalVar(memStartVar, item.MemInterval.Stop, $"{item.Name}_{item.Number}_y"); noOverlap.AddRectangle(xInterval, yInterval); yStarts.Add(memStartVar); boxs.Add(expr, (xInterval, yInterval)); - for (int time = item.Interval.Brith; time < item.Interval.Death; time++) + for (int time = item.TimeInterval.Start; time < item.TimeInterval.Stop; time++) { if (!timeMap.TryGetValue(time, out var timelist)) { @@ -56,38 +86,7 @@ public void Schedule(IReadOnlyDictionary bufferMap) } } - foreach (var (expr, item) in bufferMap) - { - if (expr is Call { Target: IR.Tensors.Concat } concatCall && concatCall.Arguments[0] is IR.Tuple tuple) - { - // the concat inputs must contiguous - int offset = 0; - for (int i = 0; i < tuple.Fields.Length; i++) - { - model.Add((boxs[concatCall].Y.StartExpr() + offset) == boxs[tuple.Fields[i]].Y.StartExpr()); - offset += bufferMap[tuple.Fields[i]].Span.Size; - } - } - else if (expr is Call { Target: IR.Tensors.Split } splitCall) - { - // the split must equal with input. - model.Add(boxs[splitCall].Y.StartExpr() == boxs[splitCall.Arguments[0]].Y.StartExpr()); - - // the split outputs must contiguous - var users = splitCall.GetUsers(); - int offset = 0; - foreach (var user in users.OrderBy(e => ((Call)e).Arguments[1].Evaluate().AsTensor().ToScalar())) - { - model.Add((boxs[splitCall].Y.StartExpr() + offset) == boxs[user].Y.StartExpr()); - offset += bufferMap[user].Span.Size; - } - } - else if (expr is Call { Target: IR.Tensors.Reshape } reshapCall) - { - // the reshape must equal with it's input. - model.Add(boxs[reshapCall].Y.StartExpr() == boxs[reshapCall.Arguments[0]].Y.StartExpr()); - } - } + ExternalConstrains(model, bufferMap, boxs); model.Minimize(LinearExpr.Sum(yStarts)); @@ -99,10 +98,10 @@ public void Schedule(IReadOnlyDictionary bufferMap) throw new System.NotSupportedException(); } - foreach (var (k, v) in bufferMap) + foreach (var (k, _) in bufferMap) { - bufferMap[k].Span.Start = checked((int)solver.Value(boxs[k].Y.StartExpr())); - bufferMap[k].Span.End = checked((int)solver.Value(boxs[k].Y.EndExpr())); + bufferMap[k].MemInterval.Start = checked((int)solver.Value(boxs[k].Y.StartExpr())); + bufferMap[k].MemInterval.Stop = checked((int)solver.Value(boxs[k].Y.EndExpr())); } } @@ -119,18 +118,11 @@ import itertools from typing import List @dataclass -class TimeInterval(): +class Interval(): start: int end: int def __str__(self) -> str: - return f'(start: {self.start}, end {self.end})' - -@dataclass -class MemSpan(): - depth_start: int - depth_end: int - def __str__(self) -> str: - return f'(start: {self.depth_start}, size {self.depth_end - self.depth_start})' + return f'(start: {self.start}, end {self.end}, size {self.end - self.start})' class ConstraintsMode(Enum): No = 0 @@ -140,8 +132,8 @@ class ConstraintsMode(Enum): class ScheduledBuffer(): name: str number: int - interval: TimeInterval - location: MemSpan + time_interval: Interval + mem_interval: Interval constraints: ConstraintsMode shape: List[int] stride: List[int] @@ -166,8 +158,8 @@ class ScheduledBuffer(): 'height': [], 'alpha': [], 'color': [], - 'location': [], - 'interval': [], + 'mem_interval': [], + 'time_interval': [], 'shape': [], 'stride': [], } @@ -177,10 +169,10 @@ class ScheduledBuffer(): color_dict = {} for buffer in buffers: source['name'].append(buffer.name) - width = buffer.interval.end - buffer.interval.start - x = buffer.interval.start + (width / 2) - height = buffer.location.depth_end - buffer.location.depth_start - y = buffer.location.depth_start + (height / 2) + width = buffer.time_interval.end - buffer.time_interval.start + x = buffer.time_interval.start + (width / 2) + height = buffer.mem_interval.end - buffer.mem_interval.start + y = buffer.mem_interval.start + (height / 2) y_range_max = max(y_range_max, y) x_range_max = max(x_range_max, buffer.interval.end) source['x'].append(x) @@ -193,13 +185,13 @@ class ScheduledBuffer(): color_dict[buffer.name] = color source['color'].append(color) source['alpha'].append(0.2 if buffer.inplace else 1.0) - source['interval'].append(str(buffer.interval)) - source['location'].append(str(buffer.location)) + source['time_interval'].append(str(buffer.time_interval)) + source['mem_interval'].append(str(buffer.mem_interval)) source['shape'].append(','.join([str(s) for s in buffer.shape])) source['stride'].append(','.join([str(s) for s in buffer.stride])) source = ColumnDataSource(source) -hover = HoverTool(tooltips=[('name', '@name'), ('interval', '@interval'), ('location', '@location'), +hover = HoverTool(tooltips=[('name', '@name'), ('time_interval', '@time_interval'), ('mem_interval', '@mem_interval'), ('shape', '@shape'), ('stride', '@stride')]) p = figure(tools=[hover, WheelPanTool(), SaveTool(), WheelZoomTool(), ResetTool()], width=1280, height=720, diff --git a/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs b/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs index 1edcc263dd..938317762b 100644 --- a/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs +++ b/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs @@ -10,16 +10,16 @@ namespace Nncase.Passes.BufferSchedule; -internal sealed class LifeTimeCollector : ExprVisitor +public class LifeTimeCollector : ExprVisitor { public int TimeStamp { get; private set; } - public Dictionary LifenessMap { get; } = new(ReferenceEqualityComparer.Instance); + public Dictionary LifenessMap { get; } = new(ReferenceEqualityComparer.Instance); - public IReadOnlyDictionary Collect(Function entry) + public IReadOnlyDictionary Collect(Expr expr) { - Visit(entry.Body); - Update(entry.Body); // avoid final call time interval size == 1. + Visit(expr); + Update(expr); // avoid final call time interval size == 1. Alias(); var d = new Dictionary(ReferenceEqualityComparer.Instance); @@ -32,8 +32,7 @@ internal sealed class LifeTimeCollector : ExprVisitor Var va => va.Name, _ => k.GetType().Name, }; - var size = GetSize(k.CheckedType, out var shape, out var stride); - + var size = ComputeBufferSize(k.CheckedType, out var shape, out var stride); d.Add(k, new(name, count++, v, new(0, size), shape, stride, false)); } @@ -62,6 +61,29 @@ protected override Unit VisitLeafCall(Call expr) return Unit.Default; } + protected virtual int ComputeBufferSize(IRType type, out int[] shape, out int[] stride) + { + shape = Array.Empty(); + stride = Array.Empty(); + var size = 0; + if (type is TensorType tensorType) + { + shape = tensorType.Shape.ToValueArray(); + stride = TensorUtilities.GetStrides(shape); + size = TensorUtilities.GetSize(shape, stride, tensorType.DType.SizeInBytes); + } + else if (type is TupleType tupleType) + { + size = 0; + foreach (var item in tupleType) + { + size += ComputeBufferSize(item, out _, out _); + } + } + + return size; + } + private void Update(Expr expr) { if (expr is Const or None) @@ -85,7 +107,7 @@ private void Update(Expr expr) } else { - interval.Death = TimeStamp + 1; + interval.Stop = TimeStamp + 1; } LifenessMap[expr] = interval; @@ -123,12 +145,12 @@ private void Alias() } while (changed); } - private bool AliasTime(Call call, TimeInterval interval) + private bool AliasTime(Call call, Interval interval) { - var brith = call.GetArguments().Select(arg => LifenessMap[arg].Death).Concat(new[] { interval.Brith }).Max(); - var death = call.GetUsers().Select(usr => LifenessMap[usr].Brith).Concat(new[] { interval.Death }).Min(); + var brith = call.GetArguments().Select(arg => LifenessMap[arg].Stop).Concat(new[] { interval.Start }).Max(); + var death = call.GetUsers().Select(usr => LifenessMap[usr].Start).Concat(new[] { interval.Stop }).Min(); - if (brith == interval.Brith && death == interval.Death) + if (brith == interval.Start && death == interval.Stop) { return false; } @@ -138,31 +160,8 @@ private bool AliasTime(Call call, TimeInterval interval) throw new InvalidOperationException(); } - interval.Brith = brith; - interval.Death = death; + interval.Start = brith; + interval.Stop = death; return true; } - - private int GetSize(IRType type, out int[] shape, out int[] stride) - { - shape = Array.Empty(); - stride = Array.Empty(); - var size = 0; - if (type is TensorType tensorType) - { - shape = tensorType.Shape.ToValueArray(); - stride = TensorUtilities.GetStrides(shape); - size = TensorUtilities.GetSize(shape, stride, tensorType.DType.SizeInBytes); - } - else if (type is TupleType tupleType) - { - size = 0; - foreach (var item in tupleType) - { - size += GetSize(item, out _, out _); - } - } - - return size; - } } diff --git a/src/Nncase.Passes/DDrBufferSchdeulePass.cs b/src/Nncase.Passes/DDrBufferSchdeulePass.cs index 15a6505686..2103b7b045 100644 --- a/src/Nncase.Passes/DDrBufferSchdeulePass.cs +++ b/src/Nncase.Passes/DDrBufferSchdeulePass.cs @@ -46,7 +46,8 @@ protected override async Task RunCoreAsync(IRModule module, RunPassCon if (module.Entry is Function { ModuleKind: Callable.StackVMModuleKind, Body: Expr body } func && IsFixedType(body.CheckedType)) { var sch = new BufferSchedule.BufferScheduler(); - var buffers = sch.CollectLifeTime(func); + var c = new BufferSchedule.LifeTimeCollector(); + var buffers = c.Collect(func.Body); sch.Schedule(buffers); using (var fs = Diagnostics.DumpScope.Current.OpenFile("draw_buffers.py")) { diff --git a/src/Nncase.Passes/EGraphExtractPass.cs b/src/Nncase.Passes/EGraphExtractPass.cs index 2c2baa12b8..ba028ea765 100644 --- a/src/Nncase.Passes/EGraphExtractPass.cs +++ b/src/Nncase.Passes/EGraphExtractPass.cs @@ -24,7 +24,7 @@ public EGraphExtractPass(IBaseFuncCostEvaluator? costEvaluator = null) protected override Task RunCoreAsync(IEGraph input, RunPassContext context) { - var post = (BaseFunction)input.Extract(input.Root!, _costEvaluator, out _); + var post = (BaseFunction)input.Extract(input.Root!, _costEvaluator, Array.Empty()); IRHelpers.DCE(post); return Task.FromResult(post); } diff --git a/src/Nncase.Tests/CostModel/UnitTestEGraphCostModel.cs b/src/Nncase.Tests/CostModel/UnitTestEGraphCostModel.cs index c14174bdeb..2fe2ba6d3f 100644 --- a/src/Nncase.Tests/CostModel/UnitTestEGraphCostModel.cs +++ b/src/Nncase.Tests/CostModel/UnitTestEGraphCostModel.cs @@ -47,10 +47,10 @@ public void TestEGraphExtractMinBy() }, }; - Assert.IsType(list.OrderBy(e => e, EGraphExtractExtensions.ENodeTypeComparer.Instance).First()); + Assert.IsType(list.OrderBy(e => e, ENodeTypeComparer.Instance).First()); Assert.True(cost[b] < cost[c]); - Assert.IsType(list.OrderBy(e => e, EGraphExtractExtensions.ENodeTypeComparer.Instance).MinBy(e => cost[e])); + Assert.IsType(list.OrderBy(e => e, ENodeTypeComparer.Instance).MinBy(e => cost[e])); } }