Skip to content

Commit

Permalink
cmd/compile/internal/inline/inlheur: assign scores to callsites
Browse files Browse the repository at this point in the history
Assign scores to callsites based on previously computed function
properties and callsite properties. This currently works by taking the
size score for the function (as computed by CanInline) and then making
a series of adjustments, positive or negative based on various
function and callsite properties.

NB: much work also remaining on deciding what are the best score
adjustment values for specific heuristics. I've picked a bunch of
arbitrary constants, but they will almost certainly need tuning and
tweaking to arrive at something that has good performance.

Updates #61502.

Change-Id: I887403f95e76d7aa2708494b8686c6026861a6ed
Reviewed-on: https://go-review.googlesource.com/c/go/+/511566
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
  • Loading branch information
thanm committed Sep 8, 2023
1 parent dc0548f commit 746f7e1
Show file tree
Hide file tree
Showing 10 changed files with 574 additions and 73 deletions.
18 changes: 10 additions & 8 deletions src/cmd/compile/internal/inline/inl.go
Expand Up @@ -293,15 +293,10 @@ func CanInline(fn *ir.Func, profile *pgo.Profile) {
base.Fatalf("CanInline no nname %+v", fn)
}

canInline := func(fn *ir.Func) { CanInline(fn, profile) }

var funcProps *inlheur.FuncProps
if goexperiment.NewInliner {
funcProps = inlheur.AnalyzeFunc(fn, canInline)
}

if base.Debug.DumpInlFuncProps != "" {
inlheur.DumpFuncProps(fn, base.Debug.DumpInlFuncProps, canInline)
if goexperiment.NewInliner || inlheur.UnitTesting() {
funcProps = inlheur.AnalyzeFunc(fn,
func(fn *ir.Func) { CanInline(fn, profile) })
}

var reason string // reason, if any, that the function was not inlined
Expand Down Expand Up @@ -803,6 +798,13 @@ func isBigFunc(fn *ir.Func) bool {
// InlineCalls/inlnode walks fn's statements and expressions and substitutes any
// calls made to inlineable functions. This is the external entry point.
func InlineCalls(fn *ir.Func, profile *pgo.Profile) {
if goexperiment.NewInliner && !fn.Wrapper() {
inlheur.ScoreCalls(fn)
}
if base.Debug.DumpInlFuncProps != "" && !fn.Wrapper() {
inlheur.DumpFuncProps(fn, base.Debug.DumpInlFuncProps,
func(fn *ir.Func) { CanInline(fn, profile) })
}
savefn := ir.CurFunc
ir.CurFunc = fn
bigCaller := isBigFunc(fn)
Expand Down
35 changes: 30 additions & 5 deletions src/cmd/compile/internal/inline/inlheur/analyze.go
Expand Up @@ -24,6 +24,7 @@ const (
debugTraceParams
debugTraceExprClassify
debugTraceCalls
debugTraceScoring
)

// propAnalyzer interface is used for defining one or more analyzer
Expand Down Expand Up @@ -76,6 +77,9 @@ func AnalyzeFunc(fn *ir.Func, canInline func(*ir.Func)) *FuncProps {
base.FatalfAt(fn.Pos(), "%v", err)
}
fpmap[fn] = entry
if fn.Inl != nil && fn.Inl.Properties == "" {
fn.Inl.Properties = entry.props.SerializeToString()
}
return fp
}

Expand Down Expand Up @@ -139,12 +143,26 @@ func UnitTesting() bool {
}

// DumpFuncProps computes and caches function properties for the func
// 'fn', or if fn is nil, writes out the cached set of properties to
// the file given in 'dumpfile'. Used for the "-d=dumpinlfuncprops=..."
// command line flag, intended for use primarily in unit testing.
// 'fn' and any closures it contains, or if fn is nil, it writes out the
// cached set of properties to the file given in 'dumpfile'. Used for
// the "-d=dumpinlfuncprops=..." command line flag, intended for use
// primarily in unit testing.
func DumpFuncProps(fn *ir.Func, dumpfile string, canInline func(*ir.Func)) {
if fn != nil {
dmp := func(fn *ir.Func) {

if !goexperiment.NewInliner {
ScoreCalls(fn)
}
captureFuncDumpEntry(fn, canInline)
}
captureFuncDumpEntry(fn, canInline)
dmp(fn)
ir.Visit(fn, func(n ir.Node) {
if clo, ok := n.(*ir.ClosureExpr); ok {
dmp(clo.Func)
}
})
} else {
emitDumpToFile(dumpfile)
}
Expand Down Expand Up @@ -185,9 +203,16 @@ func emitDumpToFile(dumpfile string) {
dumpBuffer = nil
}

// captureFuncDumpEntry analyzes function 'fn' and adds a entry
// for it to 'dumpBuffer'. Used for unit testing.
// captureFuncDumpEntry grabs the function properties object for 'fn'
// and enqueues it for later dumping. Used for the
// "-d=dumpinlfuncprops=..." command line flag, intended for use
// primarily in unit testing.
func captureFuncDumpEntry(fn *ir.Func, canInline func(*ir.Func)) {
if debugTrace&debugTraceFuncs != 0 {
fmt.Fprintf(os.Stderr, "=-= capturing dump for %v:\n",
fn.Sym().Name)
}

// avoid capturing compiler-generated equality funcs.
if strings.HasPrefix(fn.Sym().Name, ".eq.") {
return
Expand Down
84 changes: 81 additions & 3 deletions src/cmd/compile/internal/inline/inlheur/analyze_func_callsites.go
Expand Up @@ -5,10 +5,12 @@
package inlheur

import (
"cmd/compile/internal/base"
"cmd/compile/internal/ir"
"cmd/compile/internal/pgo"
"fmt"
"os"
"sort"
"strings"
)

Expand Down Expand Up @@ -120,26 +122,102 @@ func (csa *callSiteAnalyzer) determinePanicPathBits(call ir.Node, r CSPropBits)
}

func (csa *callSiteAnalyzer) addCallSite(callee *ir.Func, call *ir.CallExpr) {
flags := csa.flagsForNode(call)
// FIXME: maybe bulk-allocate these?
cs := &CallSite{
Call: call,
Callee: callee,
Assign: csa.containingAssignment(call),
Flags: csa.flagsForNode(call),
Id: uint(len(csa.cstab)),
Flags: flags,
ID: uint(len(csa.cstab)),
}
if _, ok := csa.cstab[call]; ok {
fmt.Fprintf(os.Stderr, "*** cstab duplicate entry at: %s\n",
fmtFullPos(call.Pos()))
fmt.Fprintf(os.Stderr, "*** call: %+v\n", call)
panic("bad")
}
if callee.Inl != nil {
// Set initial score for callsite to the cost computed
// by CanInline; this score will be refined later based
// on heuristics.
cs.Score = int(callee.Inl.Cost)
}

csa.cstab[call] = cs
if debugTrace&debugTraceCalls != 0 {
fmt.Fprintf(os.Stderr, "=-= added callsite: callee=%s call=%v\n",
callee.Sym().Name, callee)
}
}

csa.cstab[call] = cs
// ScoreCalls assigns numeric scores to each of the callsites in
// function 'fn'; the lower the score, the more helpful we think it
// will be to inline.
//
// Unlike a lot of the other inline heuristics machinery, callsite
// scoring can't be done as part of the CanInline call for a function,
// due to fact that we may be working on a non-trivial SCC. So for
// example with this SCC:
//
// func foo(x int) { func bar(x int, f func()) {
// if x != 0 { f()
// bar(x, func(){}) foo(x-1)
// } }
// }
//
// We don't want to perform scoring for the 'foo' call in "bar" until
// after foo has been analyzed, but it's conceivable that CanInline
// might visit bar before foo for this SCC.
func ScoreCalls(fn *ir.Func) {
enableDebugTraceIfEnv()
defer disableDebugTrace()
if debugTrace&debugTraceScoring != 0 {
fmt.Fprintf(os.Stderr, "=-= ScoreCalls(%v)\n", ir.FuncName(fn))
}

fih, ok := fpmap[fn]
if !ok {
// TODO: add an assert/panic here.
return
}

// Sort callsites to avoid any surprises with non deterministic
// map iteration order (this is probably not needed, but here just
// in case).
csl := make([]*CallSite, 0, len(fih.cstab))
for _, cs := range fih.cstab {
csl = append(csl, cs)
}
sort.Slice(csl, func(i, j int) bool {
return csl[i].ID < csl[j].ID
})

// Score each call site.
for _, cs := range csl {
var cprops *FuncProps
fihcprops := false
desercprops := false
if fih, ok := fpmap[cs.Callee]; ok {
cprops = fih.props
fihcprops = true
} else if cs.Callee.Inl != nil {
cprops = DeserializeFromString(cs.Callee.Inl.Properties)
desercprops = true
} else {
if base.Debug.DumpInlFuncProps != "" {
fmt.Fprintf(os.Stderr, "=-= *** unable to score call to %s from %s\n", cs.Callee.Sym().Name, fmtFullPos(cs.Call.Pos()))
panic("should never happen")
} else {
continue
}
}
cs.Score, cs.ScoreMask = computeCallSiteScore(cs.Callee, cprops, cs.Call, cs.Flags)

if debugTrace&debugTraceScoring != 0 {
fmt.Fprintf(os.Stderr, "=-= scoring call at %s: flags=%d score=%d fih=%v deser=%v\n", fmtFullPos(cs.Call.Pos()), cs.Flags, cs.Score, fihcprops, desercprops)
}
}
}

func (csa *callSiteAnalyzer) nodeVisitPre(n ir.Node) {
Expand Down
42 changes: 29 additions & 13 deletions src/cmd/compile/internal/inline/inlheur/callsite.go
Expand Up @@ -22,15 +22,16 @@ import (
// appears in the form of a top-level statement, e.g. "x := foo()"),
// "Flags" contains properties of the call that might be useful for
// making inlining decisions, "Score" is the final score assigned to
// the site, and "Id" is a numeric ID for the site within its
// the site, and "ID" is a numeric ID for the site within its
// containing function.
type CallSite struct {
Callee *ir.Func
Call *ir.CallExpr
Assign ir.Node
Flags CSPropBits
Score int
Id uint
Callee *ir.Func
Call *ir.CallExpr
Assign ir.Node
Flags CSPropBits
Score int
ScoreMask scoreAdjustTyp
ID uint
}

// CallSiteTab is a table of call sites, keyed by call expr.
Expand All @@ -53,8 +54,19 @@ const (

// encodedCallSiteTab is a table keyed by "encoded" callsite
// (stringified src.XPos plus call site ID) mapping to a value of call
// property bits.
type encodedCallSiteTab map[string]CSPropBits
// property bits and score.
type encodedCallSiteTab map[string]propsAndScore

type propsAndScore struct {
props CSPropBits
score int
mask scoreAdjustTyp
}

func (pas propsAndScore) String() string {
return fmt.Sprintf("P=%s|S=%d|M=%s", pas.props.String(),
pas.score, pas.mask.String())
}

func (cst CallSiteTab) merge(other CallSiteTab) error {
for k, v := range other {
Expand All @@ -80,17 +92,21 @@ func fmtFullPos(p src.XPos) string {

func encodeCallSiteKey(cs *CallSite) string {
var sb strings.Builder
// FIXME: rewrite line offsets relative to function start
// FIXME: maybe rewrite line offsets relative to function start?
sb.WriteString(fmtFullPos(cs.Call.Pos()))
fmt.Fprintf(&sb, "|%d", cs.Id)
fmt.Fprintf(&sb, "|%d", cs.ID)
return sb.String()
}

func buildEncodedCallSiteTab(tab CallSiteTab) encodedCallSiteTab {
r := make(encodedCallSiteTab)
for _, cs := range tab {
k := encodeCallSiteKey(cs)
r[k] = cs.Flags
r[k] = propsAndScore{
props: cs.Flags,
score: cs.Score,
mask: cs.ScoreMask,
}
}
return r
}
Expand All @@ -109,7 +125,7 @@ func dumpCallSiteComments(w io.Writer, tab CallSiteTab, ecst encodedCallSiteTab)
sort.Strings(tags)
for _, s := range tags {
v := ecst[s]
fmt.Fprintf(w, "// callsite: %s flagstr %q flagval %d\n", s, v.String(), v)
fmt.Fprintf(w, "// callsite: %s flagstr %q flagval %d score %d mask %d maskstr %q\n", s, v.props.String(), v.props, v.score, v.mask, v.mask.String())
}
fmt.Fprintf(w, "// %s\n", csDelimiter)
}
40 changes: 26 additions & 14 deletions src/cmd/compile/internal/inline/inlheur/funcprops_test.go
Expand Up @@ -72,8 +72,7 @@ func TestFuncProperties(t *testing.T) {
continue
}
if eidx >= len(eentries) {
t.Errorf("missing expected entry for %s, skipping",
dentry.fname)
t.Errorf("testcase %s missing expected entry for %s, skipping", tc, dentry.fname)
continue
}
eentry := eentries[eidx]
Expand Down Expand Up @@ -124,20 +123,18 @@ func compareEntries(t *testing.T, tc string, dentry *fnInlHeur, dcsites encodedC
// Compare call sites.
for k, ve := range ecsites {
if vd, ok := dcsites[k]; !ok {
t.Errorf("missing expected callsite %q in func %q",
dfn, k)
t.Errorf("testcase %q missing expected callsite %q in func %q", tc, k, dfn)
continue
} else {
if vd != ve {
t.Errorf("callsite %q in func %q: got %s want %s",
k, dfn, vd.String(), ve.String())
t.Errorf("testcase %q callsite %q in func %q: got %+v want %+v",
tc, k, dfn, vd.String(), ve.String())
}
}
}
for k := range dcsites {
if _, ok := ecsites[k]; !ok {
t.Errorf("unexpected extra callsite %q in func %q",
dfn, k)
t.Errorf("testcase %q unexpected extra callsite %q in func %q", tc, k, dfn)
}
}
}
Expand Down Expand Up @@ -276,13 +273,12 @@ func (dr *dumpReader) readEntry() (fnInlHeur, encodedCallSiteTab, error) {
if line == csDelimiter {
break
}
// expected format: "// callsite: <expanded pos> flagstr <desc> flagval <flags>"
// expected format: "// callsite: <expanded pos> flagstr <desc> flagval <flags> score <score> mask <scoremask> maskstr <scoremaskstring>"
fields := strings.Fields(line)
if len(fields) != 6 {
return fih, nil, fmt.Errorf("malformed callsite %s line %d: %s",
dr.p, dr.ln, line)
if len(fields) != 12 {
return fih, nil, fmt.Errorf("malformed callsite (nf=%d) %s line %d: %s", len(fields), dr.p, dr.ln, line)
}
if fields[2] != "flagstr" || fields[4] != "flagval" {
if fields[2] != "flagstr" || fields[4] != "flagval" || fields[6] != "score" || fields[8] != "mask" || fields[10] != "maskstr" {
return fih, nil, fmt.Errorf("malformed callsite %s line %d: %s",
dr.p, dr.ln, line)
}
Expand All @@ -293,7 +289,23 @@ func (dr *dumpReader) readEntry() (fnInlHeur, encodedCallSiteTab, error) {
return fih, nil, fmt.Errorf("bad flags val %s line %d: %q err=%v",
dr.p, dr.ln, line, err)
}
callsites[tag] = CSPropBits(flags)
scorestr := fields[7]
score, err2 := strconv.Atoi(scorestr)
if err2 != nil {
return fih, nil, fmt.Errorf("bad score val %s line %d: %q err=%v",
dr.p, dr.ln, line, err2)
}
maskstr := fields[9]
mask, err3 := strconv.Atoi(maskstr)
if err3 != nil {
return fih, nil, fmt.Errorf("bad mask val %s line %d: %q err=%v",
dr.p, dr.ln, line, err3)
}
callsites[tag] = propsAndScore{
props: CSPropBits(flags),
score: score,
mask: scoreAdjustTyp(mask),
}
}

// Consume function delimiter.
Expand Down

0 comments on commit 746f7e1

Please sign in to comment.