Skip to content

Commit

Permalink
Feature/egraph extract constrains (#1175)
Browse files Browse the repository at this point in the history
* add egraph-extract-constrains
* reorder SwapBinaryArgs
  • Loading branch information
xhuohai committed Mar 11, 2024
1 parent 1cdea27 commit dbcde6f
Show file tree
Hide file tree
Showing 15 changed files with 233 additions and 444 deletions.
2 changes: 1 addition & 1 deletion src/Nncase.Compiler/Compiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ public void TargetIndependentPass(IPassManager passManager)
p.Add<Passes.Rules.Neutral.NormAxisReshape>();
p.Add<Passes.Rules.Neutral.NormAxisReduceArg>();
p.Add<Passes.Rules.Neutral.NormAxisSlice>();
p.Add<Passes.Rules.Neutral.SwapBinaryArgs>();
p.Add<Passes.Rules.Neutral.SqueezeTransposeShape>();
p.Add<Passes.Rules.Neutral.Squeeze5DTranspose>();
p.Add<Passes.Rules.Neutral.SqueezeBinaryShape>();
Expand Down Expand Up @@ -141,6 +140,7 @@ public void TargetIndependentPass(IPassManager passManager)
p.Add<Passes.Rules.Neutral.FoldNopReduce>();
p.Add<Passes.Rules.Neutral.SliceToGetItem>();
p.Add<Passes.Rules.Neutral.FoldTwoPads>();
p.Add<Passes.Rules.Neutral.SwapBinaryArgs>();
p.Add<Passes.Rules.Neutral.FoldDilatedConv2D>();
});

Expand Down
17 changes: 17 additions & 0 deletions src/Nncase.Core/TIR/TIRUtilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,21 @@ public static IReadOnlyList<TIR.Range> ComputeBounds(IReadOnlyList<TIR.Range> su
IR.F.Math.Max(0, t.First.Start),
IR.F.Math.Min(t.Second.FixedValue, t.First.Stop),
t.First.Step)).ToArray();

public static bool TryGetFixedRegions(TIR.BufferRegion region, out (int Start, int Stop, int Step)[] slice)
{
slice = new (int Start, int Stop, int Step)[region.Region.Length];
for (int i = 0; i < region.Region.Length; i++)
{
var rg = region.Region[i];
if (rg is not Range { Start: IR.TensorConst start, Stop: IR.TensorConst stop, Step: IR.TensorConst step })
{
return false;
}

slice[i] = (start.Value.ToScalar<int>(), stop.Value.ToScalar<int>(), step.Value.ToScalar<int>());
}

return true;
}
}
1 change: 0 additions & 1 deletion src/Nncase.Core/Utilities/ShapeExprUtility.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using GiGraph.Dot.Output.Writers.Edges;
using Nncase.Diagnostics;
using Nncase.IR;
using Nncase.IR.Tensors;
Expand Down
47 changes: 43 additions & 4 deletions src/Nncase.EGraph/CostModel/EGraphCostPrinter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,25 @@ internal static DotGraph DumpEgraphAsDot(IEGraph eGraph, CostModel.EGraphCostMod
return printer.SaveToStream(file);
}

/// <summary>
/// find the minCostEnode in eclass.
/// <remarks>
/// the marker first.
/// </remarks>
/// </summary>
internal static ENode MinByWithMarker(EClass eClass, CostModel.EGraphCostModel costModel)
{
return eClass.Nodes.OrderBy(e => e.Expr, ENodeTypeComparer.Instance).MinBy(x => x.Expr is Marker ? CostModel.Cost.Zero : costModel[x])!;
}

/// <summary>
/// find the minCostEnode in eclass skip marker.
/// </summary>
internal static ENode MinByWithOutMarker(EClass eClass, CostModel.EGraphCostModel costModel)
{
return eClass.Nodes.Where(e => e.Expr is not Marker).MinBy(x => costModel[x])!;
}

private DotGraph AttachEGraphCost(CostModel.EGraphCostModel costModel, EClass entry)
{
// 1. display each enode costs.
Expand Down Expand Up @@ -72,12 +91,12 @@ void Dfs(EClass curclass)
continue;
}

var minCostEnode = parent.MinByWithMarker(costModel);
var minCostEnode = MinByWithMarker(parent, costModel);

// when this marker ecalss has been visited, skip it.
if (markerEclassMemo.Contains(parent))
{
minCostEnode = parent.MinByWithOutMarker(costModel);
minCostEnode = MinByWithOutMarker(parent, costModel);
}

var (minCostDotnode, table) = NodesMap[minCostEnode];
Expand All @@ -93,7 +112,7 @@ void Dfs(EClass curclass)
if (minCostEnode.Expr is Marker && child == parent)
{
markerEclassMemo.Add(child);
var otherminCostENode = child.MinByWithOutMarker(costModel);
var otherminCostENode = MinByWithOutMarker(child, costModel);
var (childDotNode, _) = NodesMap[otherminCostENode];
_dotGraph.Edges.Add(childDotNode, minCostDotnode, edge =>
{
Expand All @@ -103,7 +122,7 @@ void Dfs(EClass curclass)
}
else
{
var childEnode = child.Find().MinByWithMarker(costModel);
var childEnode = MinByWithMarker(child.Find(), costModel);
var (childDotNode, _) = NodesMap[childEnode];
_dotGraph.Edges.Add(childDotNode, minCostDotnode, edge =>
{
Expand All @@ -126,3 +145,23 @@ void Dfs(EClass curclass)
return _dotGraph;
}
}

internal sealed class ENodeTypeComparer : IComparer<Expr>
{
public static readonly ENodeTypeComparer Instance = new();

public int Compare(Expr? x, Expr? y) => (x, y) switch
{
(null, null) => 0,
(Expr, null) => 1,
(null, Expr) => -1,
(Expr, Expr) => GetPriority(x).CompareTo(GetPriority(y)),
};

private int GetPriority(Expr x) => x switch
{
Marker => 0,
Const => 1,
_ => 2,
};
}
50 changes: 50 additions & 0 deletions src/Nncase.EGraph/Passes/EGraphExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Google.OrTools.Sat;
using Nncase.CostModel;
using Nncase.Diagnostics;
using Nncase.IR;
using Nncase.PatternMatch;
using static Nncase.PatternMatch.F.Math;
using static Nncase.PatternMatch.Utility;

namespace Nncase.Passes;

/// <summary>
/// EGraph extract extensions.
/// </summary>
public static class EGraphExtensions
{
/// <summary>
/// Extract egraph.
/// </summary>
/// <param name="eGraph">egraph.</param>
/// <param name="root">Root eclass.</param>
/// <param name="basefunc_cost_evaluator">base func cost evaluator.</param>
/// <param name="constrains">the cp model constrains.</param>
public static Expr Extract(this IEGraph eGraph, EClass root, Evaluator.IBaseFuncCostEvaluator? basefunc_cost_evaluator, EGraphExtractConstrains[] constrains)
{
// 1. set enode expr with more accuracy type.
foreach (var eclass in eGraph.Classes)
{
foreach (var nodes in eclass.Nodes)
{
if (eclass.CheckedType.CompareTo(nodes.Expr.CheckedType) > 0)
{
nodes.Expr.CheckedType = eclass.CheckedType;
}
}
}

// 2. start the cost evaluator
var costModel = new CostModel.EGraphCostEvaluator(root.Find(), basefunc_cost_evaluator, false).Evaluate();

return new EGraphExtractor(costModel).Extract(root.Find(), eGraph, constrains);
}
}
95 changes: 0 additions & 95 deletions src/Nncase.EGraph/Passes/EGraphExtractExtensions.cs

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,20 @@
using Nncase.Diagnostics;
using Nncase.IR;

namespace Nncase.Passes.EGraphExtractors;
namespace Nncase.Passes;

internal class SatExtractor : IExtractor
public delegate void EGraphExtractConstrains(CpModel model, IReadOnlyDictionary<ENode, BoolVar> vars);

internal class EGraphExtractor
{
private readonly EGraphCostModel _costModel;

public SatExtractor(EGraphCostModel costModel)
public EGraphExtractor(EGraphCostModel costModel)
{
_costModel = costModel;
}

public Expr Extract(EClass root, IEGraph eGraph, out IReadOnlyDictionary<ENode, bool> picks)
public Expr Extract(EClass root, IEGraph eGraph, EGraphExtractConstrains[] constrains)
{
var cpmodel = new CpModel();

Expand Down Expand Up @@ -68,6 +70,11 @@ public Expr Extract(EClass root, IEGraph eGraph, out IReadOnlyDictionary<ENode,
EliminateAllCycles(root, new(), new(), visited, cpmodel, vars);
}

foreach (var constrain in constrains)
{
constrain(cpmodel, vars);
}

// 3. add pick weights for all enode.
cpmodel.Minimize(LinearExpr.WeightedSum(eGraph.Nodes.Select(n => vars[n]), eGraph.Nodes.Select(n => checked((long)_costModel[n].Score))));

Expand Down Expand Up @@ -121,7 +128,7 @@ public Expr Extract(EClass root, IEGraph eGraph, out IReadOnlyDictionary<ENode,
throw new InvalidProgramException("SatExtract Failed!");
}

picks = eGraph.Nodes.ToDictionary(e => e, e => solver.BooleanValue(vars[e]));
var picks = eGraph.Nodes.ToDictionary(e => e, e => solver.BooleanValue(vars[e]));
using (var dumpStream = enableDump ? DumpScope.Current.OpenFile("Costs/Pick.dot") : Stream.Null)
{
EGraphPrinter.DumpEgraphAsDot(eGraph, _costModel, picks, root.Find(), dumpStream);
Expand Down
Loading

0 comments on commit dbcde6f

Please sign in to comment.