Skip to content

Commit

Permalink
Compiler optimization first iteration (#165)
Browse files Browse the repository at this point in the history
* dead code elimination phase 1

* combine dead code elimination with return fix code

* remove last instruction tracking from compiler code (not needed)

* fix a symbol table block scope bug

* add some more tests
  • Loading branch information
d5 committed Mar 24, 2019
1 parent 01fe30f commit b9c1c92
Show file tree
Hide file tree
Showing 9 changed files with 416 additions and 98 deletions.
7 changes: 3 additions & 4 deletions compiler/compilation_scope.go
Expand Up @@ -5,8 +5,7 @@ import "github.com/d5/tengo/compiler/source"
// CompilationScope represents the compiled instructions of a single scope,
// together with its symbol initialization state and source map.
type CompilationScope struct {
instructions []byte
lastInstructions [2]EmittedInstruction
symbolInit map[string]bool
sourceMap map[int]source.Pos
instructions []byte
symbolInit map[string]bool
sourceMap map[int]source.Pos
}
144 changes: 89 additions & 55 deletions compiler/compiler.go
Expand Up @@ -6,6 +6,7 @@ import (
"io/ioutil"
"path/filepath"
"reflect"
"sort"
"strings"

"github.com/d5/tengo"
Expand Down Expand Up @@ -292,6 +293,15 @@ func (c *Compiler) Compile(node ast.Node) error {
}

case *ast.BlockStmt:
if len(node.Stmts) == 0 {
return nil
}

c.symbolTable = c.symbolTable.Fork(true)
defer func() {
c.symbolTable = c.symbolTable.Parent(false)
}()

for _, stmt := range node.Stmts {
if err := c.Compile(stmt); err != nil {
return err
Expand Down Expand Up @@ -404,8 +414,8 @@ func (c *Compiler) Compile(node ast.Node) error {
return err
}

// add OpReturn if function returns nothing
c.fixReturn(node)
// code optimization
c.optimizeFunc(node)

freeSymbols := c.symbolTable.FreeSymbols()
numLocals := c.symbolTable.MaxSymbols()
Expand Down Expand Up @@ -688,33 +698,6 @@ func (c *Compiler) addInstruction(b []byte) int {
return posNewIns
}

// setLastInstruction records (op, pos) as the most recently emitted
// instruction of the current scope. The previously recorded instruction is
// shifted into the second slot first, so the last two instructions stay
// available.
func (c *Compiler) setLastInstruction(op Opcode, pos int) {
	last := &c.scopes[c.scopeIndex].lastInstructions

	// slot 0 is the newest instruction, slot 1 the one before it
	last[1] = last[0]
	last[0].Opcode = op
	last[0].Position = pos
}

// lastInstructionIs reports whether the most recently recorded instruction
// of the current scope has opcode op. It is always false while the scope has
// no instructions at all.
func (c *Compiler) lastInstructionIs(op Opcode) bool {
	return len(c.currentInstructions()) != 0 &&
		c.scopes[c.scopeIndex].lastInstructions[0].Opcode == op
}

// removeLastInstruction truncates the current scope's instructions at the
// position of the most recently recorded instruction, then makes the
// previously recorded instruction the most recent one again. When tracing is
// enabled, the removed instruction is printed first.
func (c *Compiler) removeLastInstruction() {
	last := &c.scopes[c.scopeIndex].lastInstructions
	cut := last[0].Position

	if c.trace != nil {
		removed := FormatInstructions(c.scopes[c.scopeIndex].instructions[cut:], cut)[0]
		c.printTrace(fmt.Sprintf("DELET %s", removed))
	}

	c.scopes[c.scopeIndex].instructions = c.currentInstructions()[:cut]
	last[0] = last[1]
}

func (c *Compiler) replaceInstruction(pos int, inst []byte) {
copy(c.currentInstructions()[pos:], inst)

Expand All @@ -731,36 +714,88 @@ func (c *Compiler) changeOperand(opPos int, operand ...int) {
c.replaceInstruction(opPos, inst)
}

// fixReturn appends "return" statement at the end of the function if
// 1) the function does not have a "return" statement at the end.
// 2) or, there are jump instructions that jump to the end of the function.
func (c *Compiler) fixReturn(node ast.Node) {
var appendReturn bool
// optimizeFunc performs code-level optimization on the current function's instructions:
// it removes unreachable (dead code) instructions and appends a "return" instruction if needed.
func (c *Compiler) optimizeFunc(node ast.Node) {
// any instructions between RETURN and the function end
// or instructions between RETURN and jump target position
// are considered as unreachable.

// pass 1. identify all jump destinations
var dsts []int
iterateInstructions(c.scopes[c.scopeIndex].instructions, func(pos int, opcode Opcode, operands []int) bool {
switch opcode {
case OpJump, OpJumpFalsy, OpAndJump, OpOrJump:
dsts = append(dsts, operands[0])
}

return true
})
sort.Ints(dsts) // sort jump positions

var newInsts []byte

// pass 2. eliminate dead code
posMap := make(map[int]int) // old position to new position
var dstIdx int
var deadCode bool
iterateInstructions(c.scopes[c.scopeIndex].instructions, func(pos int, opcode Opcode, operands []int) bool {
switch {
case opcode == OpReturn:
if deadCode {
return true
}
deadCode = true
case dstIdx < len(dsts) && pos == dsts[dstIdx]:
dstIdx++
deadCode = false
case deadCode:
return true
}

if !c.lastInstructionIs(OpReturn) {
appendReturn = true
} else {
var lastOp Opcode
insts := c.scopes[c.scopeIndex].instructions
endPos := len(insts)
iterateInstructions(insts, func(pos int, opcode Opcode, operands []int) bool {
defer func() { lastOp = opcode }()

switch opcode {
case OpJump, OpJumpFalsy, OpAndJump, OpOrJump:
dst := operands[0]
if dst == endPos && lastOp != OpReturn {
appendReturn = true
return false
} else if dst > endPos {
panic(fmt.Errorf("wrong jump position: %d (end: %d)", dst, endPos))
}
posMap[pos] = len(newInsts)
newInsts = append(newInsts, MakeInstruction(opcode, operands...)...)
return true
})

// pass 3. update jump positions
// Jump operands found in newInsts still hold positions in the ORIGINAL
// instruction stream, so they are translated through posMap. A jump to the
// original end of the function has no posMap entry (no instruction lives
// there); it is redirected to the end of the new stream, where a "return"
// is appended afterwards.
var lastOp Opcode
var appendReturn bool
endPos := len(c.scopes[c.scopeIndex].instructions) // end of the original instructions
newEndPos := len(newInsts)                         // end of the optimized instructions
iterateInstructions(newInsts, func(pos int, opcode Opcode, operands []int) bool {
	switch opcode {
	case OpJump, OpJumpFalsy, OpAndJump, OpOrJump:
		oldDst := operands[0]
		if newDst, ok := posMap[oldDst]; ok {
			copy(newInsts[pos:], MakeInstruction(opcode, newDst))
		} else if oldDst == endPos {
			// there's a jump instruction that jumps to the end of function
			// compiler should append "return".
			copy(newInsts[pos:], MakeInstruction(opcode, newEndPos))
			appendReturn = true
		} else {
			// report the offending (original) destination
			panic(fmt.Errorf("invalid jump position: %d", oldDst))
		}
	}
	lastOp = opcode
	return true
})
// a function whose last instruction is not RETURN falls off its end
if lastOp != OpReturn {
	appendReturn = true
}

return true
})
// pass 4. update source map
newSourceMap := make(map[int]source.Pos)
for pos, srcPos := range c.scopes[c.scopeIndex].sourceMap {
newPos, ok := posMap[pos]
if ok {
newSourceMap[newPos] = srcPos
}
}

c.scopes[c.scopeIndex].instructions = newInsts
c.scopes[c.scopeIndex].sourceMap = newSourceMap

// append "return"
if appendReturn {
c.emit(node, OpReturn, 0)
}
Expand All @@ -775,7 +810,6 @@ func (c *Compiler) emit(node ast.Node, opcode Opcode, operands ...int) int {
inst := MakeInstruction(opcode, operands...)
pos := c.addInstruction(inst)
c.scopes[c.scopeIndex].sourceMap[pos] = filePos
c.setLastInstruction(opcode, pos)

if c.trace != nil {
c.printTrace(fmt.Sprintf("EMIT %s",
Expand Down
3 changes: 2 additions & 1 deletion compiler/compiler_module.go
Expand Up @@ -49,7 +49,8 @@ func (c *Compiler) compileModule(node ast.Node, moduleName, modulePath string, s
return nil, err
}

moduleCompiler.fixReturn(node)
// code optimization
moduleCompiler.optimizeFunc(node)

compiledFunc := moduleCompiler.Bytecode().MainFunction
compiledFunc.NumLocals = symbolTable.MaxSymbols()
Expand Down
124 changes: 124 additions & 0 deletions compiler/compiler_optimize_test.go
@@ -0,0 +1,124 @@
package compiler_test

import (
"testing"

"github.com/d5/tengo/compiler"
)

// TestCompilerDeadCode verifies that instructions following a "return" are
// not present in the compiled function body. Each case compiles a function
// literal and compares the result against the expected bytecode.
//
// Note that the expected constants of the first case still contain
// intObject(5), which is only referenced by the eliminated code.
func TestCompilerDeadCode(t *testing.T) {
	// straight-line code after a return
	expect(t, `
func() {
a := 4
return a
b := 5 // dead code from here
c := a
return b
}`,
		bytecode(
			concat(
				compiler.MakeInstruction(compiler.OpConstant, 2),
				compiler.MakeInstruction(compiler.OpPop)),
			objectsArray(
				intObject(4),
				intObject(5),
				compiledFunction(0, 0,
					compiler.MakeInstruction(compiler.OpConstant, 0),
					compiler.MakeInstruction(compiler.OpDefineLocal, 0),
					compiler.MakeInstruction(compiler.OpGetLocal, 0),
					compiler.MakeInstruction(compiler.OpReturn, 1)))))

	// dead code in both branches of an if/else; the jump operand (9) is
	// expressed in positions of the optimized instructions
	expect(t, `
func() {
if true {
return 5
a := 4 // dead code from here
b := a
return b
} else {
return 4
c := 5 // dead code from here
d := c
return d
}
}`, bytecode(
		concat(
			compiler.MakeInstruction(compiler.OpConstant, 2),
			compiler.MakeInstruction(compiler.OpPop)),
		objectsArray(
			intObject(5),
			intObject(4),
			compiledFunction(0, 0,
				compiler.MakeInstruction(compiler.OpTrue),
				compiler.MakeInstruction(compiler.OpJumpFalsy, 9),
				compiler.MakeInstruction(compiler.OpConstant, 0),
				compiler.MakeInstruction(compiler.OpReturn, 1),
				compiler.MakeInstruction(compiler.OpConstant, 1),
				compiler.MakeInstruction(compiler.OpReturn, 1)))))

	// dead code at the end of a loop body
	expect(t, `
func() {
a := 1
for {
if a == 5 {
return 10
}
5 + 5
return 20
b := a
return b
}
}`, bytecode(
		concat(
			compiler.MakeInstruction(compiler.OpConstant, 4),
			compiler.MakeInstruction(compiler.OpPop)),
		objectsArray(
			intObject(1),
			intObject(5),
			intObject(10),
			intObject(20),
			compiledFunction(0, 0,
				compiler.MakeInstruction(compiler.OpConstant, 0),
				compiler.MakeInstruction(compiler.OpDefineLocal, 0),
				compiler.MakeInstruction(compiler.OpGetLocal, 0),
				compiler.MakeInstruction(compiler.OpConstant, 1),
				compiler.MakeInstruction(compiler.OpEqual),
				compiler.MakeInstruction(compiler.OpJumpFalsy, 19),
				compiler.MakeInstruction(compiler.OpConstant, 2),
				compiler.MakeInstruction(compiler.OpReturn, 1),
				compiler.MakeInstruction(compiler.OpConstant, 1),
				compiler.MakeInstruction(compiler.OpConstant, 1),
				compiler.MakeInstruction(compiler.OpBinaryOp, 11),
				compiler.MakeInstruction(compiler.OpPop),
				compiler.MakeInstruction(compiler.OpConstant, 3),
				compiler.MakeInstruction(compiler.OpReturn, 1)))))
}

0 comments on commit b9c1c92

Please sign in to comment.