Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/egraph extract constrains #1175

Merged
merged 4 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Nncase.Compiler/Compiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ public void TargetIndependentPass(IPassManager passManager)
p.Add<Passes.Rules.Neutral.NormAxisReshape>();
p.Add<Passes.Rules.Neutral.NormAxisReduceArg>();
p.Add<Passes.Rules.Neutral.NormAxisSlice>();
p.Add<Passes.Rules.Neutral.SwapBinaryArgs>();
p.Add<Passes.Rules.Neutral.SqueezeTransposeShape>();
p.Add<Passes.Rules.Neutral.Squeeze5DTranspose>();
p.Add<Passes.Rules.Neutral.SqueezeBinaryShape>();
Expand Down Expand Up @@ -141,6 +140,7 @@ public void TargetIndependentPass(IPassManager passManager)
p.Add<Passes.Rules.Neutral.FoldNopReduce>();
p.Add<Passes.Rules.Neutral.SliceToGetItem>();
p.Add<Passes.Rules.Neutral.FoldTwoPads>();
p.Add<Passes.Rules.Neutral.SwapBinaryArgs>();
p.Add<Passes.Rules.Neutral.FoldDilatedConv2D>();
});

Expand Down
17 changes: 17 additions & 0 deletions src/Nncase.Core/TIR/TIRUtilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,21 @@
IR.F.Math.Max(0, t.First.Start),
IR.F.Math.Min(t.Second.FixedValue, t.First.Stop),
t.First.Step)).ToArray();

public static bool TryGetFixedRegions(TIR.BufferRegion region, out (int Start, int Stop, int Step)[] slice)
{
slice = new (int Start, int Stop, int Step)[region.Region.Length];

Check warning on line 77 in src/Nncase.Core/TIR/TIRUtilities.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/TIR/TIRUtilities.cs#L77

Added line #L77 was not covered by tests
for (int i = 0; i < region.Region.Length; i++)
{
var rg = region.Region[i];

Check warning on line 80 in src/Nncase.Core/TIR/TIRUtilities.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/TIR/TIRUtilities.cs#L80

Added line #L80 was not covered by tests
if (rg is not Range { Start: IR.TensorConst start, Stop: IR.TensorConst stop, Step: IR.TensorConst step })
{
return false;

Check warning on line 83 in src/Nncase.Core/TIR/TIRUtilities.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/TIR/TIRUtilities.cs#L83

Added line #L83 was not covered by tests
}

slice[i] = (start.Value.ToScalar<int>(), stop.Value.ToScalar<int>(), step.Value.ToScalar<int>());

Check warning on line 86 in src/Nncase.Core/TIR/TIRUtilities.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/TIR/TIRUtilities.cs#L86

Added line #L86 was not covered by tests
}

return true;

Check warning on line 89 in src/Nncase.Core/TIR/TIRUtilities.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Core/TIR/TIRUtilities.cs#L89

Added line #L89 was not covered by tests
}
}
1 change: 0 additions & 1 deletion src/Nncase.Core/Utilities/ShapeExprUtility.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using GiGraph.Dot.Output.Writers.Edges;
using Nncase.Diagnostics;
using Nncase.IR;
using Nncase.IR.Tensors;
Expand Down
47 changes: 43 additions & 4 deletions src/Nncase.EGraph/CostModel/EGraphCostPrinter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,25 @@
return printer.SaveToStream(file);
}

/// <summary>
/// find the minCostEnode in eclass.
/// <remarks>
/// the marker first.
/// </remarks>
/// </summary>
internal static ENode MinByWithMarker(EClass eClass, CostModel.EGraphCostModel costModel)
{
return eClass.Nodes.OrderBy(e => e.Expr, ENodeTypeComparer.Instance).MinBy(x => x.Expr is Marker ? CostModel.Cost.Zero : costModel[x])!;
}

/// <summary>
/// find the minCostEnode in eclass skip marker.
/// </summary>
internal static ENode MinByWithOutMarker(EClass eClass, CostModel.EGraphCostModel costModel)
{
return eClass.Nodes.Where(e => e.Expr is not Marker).MinBy(x => costModel[x])!;
}

private DotGraph AttachEGraphCost(CostModel.EGraphCostModel costModel, EClass entry)
{
// 1. display each enode costs.
Expand Down Expand Up @@ -72,12 +91,12 @@
continue;
}

var minCostEnode = parent.MinByWithMarker(costModel);
var minCostEnode = MinByWithMarker(parent, costModel);

Check warning on line 94 in src/Nncase.EGraph/CostModel/EGraphCostPrinter.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.EGraph/CostModel/EGraphCostPrinter.cs#L94

Added line #L94 was not covered by tests

// when this marker ecalss has been visited, skip it.
if (markerEclassMemo.Contains(parent))
{
minCostEnode = parent.MinByWithOutMarker(costModel);
minCostEnode = MinByWithOutMarker(parent, costModel);

Check warning on line 99 in src/Nncase.EGraph/CostModel/EGraphCostPrinter.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.EGraph/CostModel/EGraphCostPrinter.cs#L99

Added line #L99 was not covered by tests
}

var (minCostDotnode, table) = NodesMap[minCostEnode];
Expand All @@ -93,7 +112,7 @@
if (minCostEnode.Expr is Marker && child == parent)
{
markerEclassMemo.Add(child);
var otherminCostENode = child.MinByWithOutMarker(costModel);
var otherminCostENode = MinByWithOutMarker(child, costModel);

Check warning on line 115 in src/Nncase.EGraph/CostModel/EGraphCostPrinter.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.EGraph/CostModel/EGraphCostPrinter.cs#L115

Added line #L115 was not covered by tests
var (childDotNode, _) = NodesMap[otherminCostENode];
_dotGraph.Edges.Add(childDotNode, minCostDotnode, edge =>
{
Expand All @@ -103,7 +122,7 @@
}
else
{
var childEnode = child.Find().MinByWithMarker(costModel);
var childEnode = MinByWithMarker(child.Find(), costModel);

Check warning on line 125 in src/Nncase.EGraph/CostModel/EGraphCostPrinter.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.EGraph/CostModel/EGraphCostPrinter.cs#L125

Added line #L125 was not covered by tests
var (childDotNode, _) = NodesMap[childEnode];
_dotGraph.Edges.Add(childDotNode, minCostDotnode, edge =>
{
Expand All @@ -126,3 +145,23 @@
return _dotGraph;
}
}

internal sealed class ENodeTypeComparer : IComparer<Expr>
{
public static readonly ENodeTypeComparer Instance = new();

public int Compare(Expr? x, Expr? y) => (x, y) switch
{
(null, null) => 0,
(Expr, null) => 1,
(null, Expr) => -1,

Check warning on line 157 in src/Nncase.EGraph/CostModel/EGraphCostPrinter.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.EGraph/CostModel/EGraphCostPrinter.cs#L155-L157

Added lines #L155 - L157 were not covered by tests
(Expr, Expr) => GetPriority(x).CompareTo(GetPriority(y)),
};

private int GetPriority(Expr x) => x switch
{
Marker => 0,

Check warning on line 163 in src/Nncase.EGraph/CostModel/EGraphCostPrinter.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.EGraph/CostModel/EGraphCostPrinter.cs#L163

Added line #L163 was not covered by tests
Const => 1,
_ => 2,
};
}
50 changes: 50 additions & 0 deletions src/Nncase.EGraph/Passes/EGraphExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Google.OrTools.Sat;
using Nncase.CostModel;
using Nncase.Diagnostics;
using Nncase.IR;
using Nncase.PatternMatch;
using static Nncase.PatternMatch.F.Math;
using static Nncase.PatternMatch.Utility;

namespace Nncase.Passes;

/// <summary>
/// EGraph extract extensions.
/// </summary>
public static class EGraphExtensions
{
/// <summary>
/// Extract egraph.
/// </summary>
/// <param name="eGraph">egraph.</param>
/// <param name="root">Root eclass.</param>
/// <param name="basefunc_cost_evaluator">base func cost evaluator.</param>
/// <param name="constrains">the cp model constrains.</param>
public static Expr Extract(this IEGraph eGraph, EClass root, Evaluator.IBaseFuncCostEvaluator? basefunc_cost_evaluator, EGraphExtractConstrains[] constrains)
{
// 1. set enode expr with more accuracy type.
foreach (var eclass in eGraph.Classes)
{
foreach (var nodes in eclass.Nodes)
{
if (eclass.CheckedType.CompareTo(nodes.Expr.CheckedType) > 0)
{
nodes.Expr.CheckedType = eclass.CheckedType;

Check warning on line 40 in src/Nncase.EGraph/Passes/EGraphExtensions.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.EGraph/Passes/EGraphExtensions.cs#L40

Added line #L40 was not covered by tests
}
}
}

// 2. start the cost evaluator
var costModel = new CostModel.EGraphCostEvaluator(root.Find(), basefunc_cost_evaluator, false).Evaluate();

return new EGraphExtractor(costModel).Extract(root.Find(), eGraph, constrains);
}
}
95 changes: 0 additions & 95 deletions src/Nncase.EGraph/Passes/EGraphExtractExtensions.cs

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,20 @@
using Nncase.Diagnostics;
using Nncase.IR;

namespace Nncase.Passes.EGraphExtractors;
namespace Nncase.Passes;

internal class SatExtractor : IExtractor
public delegate void EGraphExtractConstrains(CpModel model, IReadOnlyDictionary<ENode, BoolVar> vars);

internal class EGraphExtractor
{
private readonly EGraphCostModel _costModel;

public SatExtractor(EGraphCostModel costModel)
public EGraphExtractor(EGraphCostModel costModel)
{
_costModel = costModel;
}

public Expr Extract(EClass root, IEGraph eGraph, out IReadOnlyDictionary<ENode, bool> picks)
public Expr Extract(EClass root, IEGraph eGraph, EGraphExtractConstrains[] constrains)
{
var cpmodel = new CpModel();

Expand Down Expand Up @@ -68,6 +70,11 @@
EliminateAllCycles(root, new(), new(), visited, cpmodel, vars);
}

foreach (var constrain in constrains)
{
constrain(cpmodel, vars);

Check warning on line 75 in src/Nncase.EGraph/Passes/EGraphExtractor.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.EGraph/Passes/EGraphExtractor.cs#L75

Added line #L75 was not covered by tests
}

// 3. add pick weights for all enode.
cpmodel.Minimize(LinearExpr.WeightedSum(eGraph.Nodes.Select(n => vars[n]), eGraph.Nodes.Select(n => checked((long)_costModel[n].Score))));

Expand Down Expand Up @@ -121,7 +128,7 @@
throw new InvalidProgramException("SatExtract Failed!");
}

picks = eGraph.Nodes.ToDictionary(e => e, e => solver.BooleanValue(vars[e]));
var picks = eGraph.Nodes.ToDictionary(e => e, e => solver.BooleanValue(vars[e]));
using (var dumpStream = enableDump ? DumpScope.Current.OpenFile("Costs/Pick.dot") : Stream.Null)
{
EGraphPrinter.DumpEgraphAsDot(eGraph, _costModel, picks, root.Find(), dumpStream);
Expand Down
Loading
Loading