diff --git a/src/cmd/compile/internal/inline/interleaved/interleaved.go b/src/cmd/compile/internal/inline/interleaved/interleaved.go index c5334d0300f2a..e55b0f1aeeb7b 100644 --- a/src/cmd/compile/internal/inline/interleaved/interleaved.go +++ b/src/cmd/compile/internal/inline/interleaved/interleaved.go @@ -83,39 +83,108 @@ func DevirtualizeAndInlineFunc(fn *ir.Func, profile *pgo.Profile) { fmt.Printf("%v: function %v considered 'big'; reducing max cost of inlinees\n", ir.Line(fn), fn) } - // Walk fn's body and apply devirtualization and inlining. - var inlCalls []*ir.InlinedCallExpr - var edit func(ir.Node) ir.Node - edit = func(n ir.Node) ir.Node { + match := func(n ir.Node) bool { switch n := n.(type) { + case *ir.CallExpr: + return true case *ir.TailCallStmt: n.Call.NoInline = true // can't inline yet } + return false + } + + edit := func(n ir.Node) ir.Node { + call, ok := n.(*ir.CallExpr) + if !ok { // previously inlined + return nil + } + + devirtualize.StaticCall(call) + if inlCall := inline.TryInlineCall(fn, call, bigCaller, profile); inlCall != nil { + return inlCall + } + return nil + } + + fixpoint(fn, match, edit) + }) +} + +// fixpoint repeatedly edits a function until it stabilizes. +// +// First, fixpoint applies match to every node n within fn. Then it +// iteratively applies edit to each node satisfying match(n). +// +// If edit(n) returns nil, no change is made. Otherwise, the result +// replaces n in fn's body, and fixpoint iterates at least once more. +// +// After an iteration where all edit calls return nil, fixpoint +// returns. +func fixpoint(fn *ir.Func, match func(ir.Node) bool, edit func(ir.Node) ir.Node) { + // Consider the expression "f(g())". We want to be able to replace + // "g()" in-place with its inlined representation. But if we first + // replace "f(...)" with its inlined representation, then "g()" will + // instead appear somewhere within this new AST. + // + // To mitigate this, each matched node n is wrapped in a ParenExpr, + // so we can reliably replace n in-place by assigning ParenExpr.X. + // It's safe to use ParenExpr here, because typecheck already + // removed them all. + + var parens []*ir.ParenExpr + var mark func(ir.Node) ir.Node + mark = func(n ir.Node) ir.Node { + if _, ok := n.(*ir.ParenExpr); ok { + return n // already visited n.X before wrapping + } - ir.EditChildren(n, edit) + ok := match(n) - if call, ok := n.(*ir.CallExpr); ok { - devirtualize.StaticCall(call) + ir.EditChildren(n, mark) - if inlCall := inline.TryInlineCall(fn, call, bigCaller, profile); inlCall != nil { - inlCalls = append(inlCalls, inlCall) - n = inlCall - } + if ok { + paren := ir.NewParenExpr(n.Pos(), n) + paren.SetType(n.Type()) + paren.SetTypecheck(n.Typecheck()) + + parens = append(parens, paren) + n = paren + } + + return n + } + ir.EditChildren(fn, mark) + + // Edit until stable. + for { + done := true + + for i := 0; i < len(parens); i++ { // can't use "range parens" here + paren := parens[i] + if new := edit(paren.X); new != nil { + // Update AST and recursively mark nodes. + paren.X = new + ir.EditChildren(new, mark) // mark may append to parens + done = false } + } - return n + if done { + break } - ir.EditChildren(fn, edit) - - // If we inlined any calls, we want to recursively visit their - // bodies for further devirtualization and inlining. However, we - // need to wait until *after* the original function body has been - // expanded, or else inlCallee can have false positives (e.g., - // #54632). - for len(inlCalls) > 0 { - call := inlCalls[0] - inlCalls = inlCalls[1:] - ir.EditChildren(call, edit) + } + + // Finally, remove any parens we inserted. + if len(parens) == 0 { + return // short circuit + } + var unparen func(ir.Node) ir.Node + unparen = func(n ir.Node) ir.Node { + if paren, ok := n.(*ir.ParenExpr); ok { + n = paren.X } - }) + ir.EditChildren(n, unparen) + return n + } + ir.EditChildren(fn, unparen) } diff --git a/src/cmd/compile/internal/ir/expr.go b/src/cmd/compile/internal/ir/expr.go index da5b437f99e6a..345828c1638a8 100644 --- a/src/cmd/compile/internal/ir/expr.go +++ b/src/cmd/compile/internal/ir/expr.go @@ -856,13 +856,19 @@ func IsAddressable(n Node) bool { // "g()" expression. func StaticValue(n Node) Node { for { - if n.Op() == OCONVNOP { - n = n.(*ConvExpr).X - continue - } - - if n.Op() == OINLCALL { - n = n.(*InlinedCallExpr).SingleResult() + switch n1 := n.(type) { + case *ConvExpr: + if n1.Op() == OCONVNOP { + n = n1.X + continue + } + case *InlinedCallExpr: + if n1.Op() == OINLCALL { + n = n1.SingleResult() + continue + } + case *ParenExpr: + n = n1.X continue }