Skip to content

Commit

Permalink
refactor/eg: Add support for multi line after statements to eg.
Browse files Browse the repository at this point in the history
The semantics of this change are that the last line will be subsituted
in place of the expression, where as the lines before that will undergo
variable substitution and be prepended before the lowest (in the AST
tree sense) statement which included the expression.

Change-Id: Ie2571934dcc1b0a30b5cec157e690924a4ac2c5a
Reviewed-on: https://go-review.googlesource.com/77730
Reviewed-by: Alan Donovan <adonovan@google.com>
  • Loading branch information
clrprod authored and matloob committed Mar 19, 2018
1 parent 96caea4 commit 2226533
Show file tree
Hide file tree
Showing 10 changed files with 211 additions and 50 deletions.
32 changes: 31 additions & 1 deletion refactor/eg/eg.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ type Transformer struct {
env map[string]ast.Expr // maps parameter name to wildcard binding
importedObjs map[types.Object]*ast.SelectorExpr // objects imported by after().
before, after ast.Expr
afterStmts []ast.Stmt
allowWildcards bool

// Working state of Transform():
Expand Down Expand Up @@ -198,7 +199,7 @@ func NewTransformer(fset *token.FileSet, tmplPkg *types.Package, tmplFile *ast.F
if err != nil {
return nil, fmt.Errorf("before: %s", err)
}
after, err := soleExpr(afterDecl)
afterStmts, after, err := stmtAndExpr(afterDecl)
if err != nil {
return nil, fmt.Errorf("after: %s", err)
}
Expand Down Expand Up @@ -242,6 +243,7 @@ func NewTransformer(fset *token.FileSet, tmplPkg *types.Package, tmplFile *ast.F
importedObjs: make(map[types.Object]*ast.SelectorExpr),
before: before,
after: after,
afterStmts: afterStmts,
}

// Combine type info from the template and input packages, and
Expand Down Expand Up @@ -279,6 +281,7 @@ func WriteAST(fset *token.FileSet, filename string, f *ast.File) (err error) {
if err != nil {
return err
}

defer func() {
if err2 := fh.Close(); err != nil {
err = err2 // prefer earlier error
Expand Down Expand Up @@ -319,6 +322,33 @@ func soleExpr(fn *ast.FuncDecl) (ast.Expr, error) {
return nil, fmt.Errorf("must contain a single return or expression statement")
}

// stmtAndExpr returns the expression in the last return statement as well as the preceeding lines.
func stmtAndExpr(fn *ast.FuncDecl) ([]ast.Stmt, ast.Expr, error) {
if fn.Body == nil {
return nil, nil, fmt.Errorf("no body")
}

n := len(fn.Body.List)
if n == 0 {
return nil, nil, fmt.Errorf("must contain at least one statement")
}

stmts, last := fn.Body.List[:n-1], fn.Body.List[n-1]

switch last := last.(type) {
case *ast.ReturnStmt:
if len(last.Results) != 1 {
return nil, nil, fmt.Errorf("return statement must have a single operand")
}
return stmts, last.Results[0], nil

case *ast.ExprStmt:
return stmts, last.X, nil
}

return nil, nil, fmt.Errorf("must end with a single return or expression statement")
}

// mergeTypeInfo adds type info from src to dst.
func mergeTypeInfo(dst, src *types.Info) {
for k, v := range src.Types {
Expand Down
6 changes: 6 additions & 0 deletions refactor/eg/eg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ func Test(t *testing.T) {
"testdata/H.template",
"testdata/H1.go",

"testdata/I.template",
"testdata/I1.go",

"testdata/J.template",
"testdata/J1.go",

"testdata/bad_type.template",
"testdata/no_before.template",
"testdata/no_after_return.template",
Expand Down
152 changes: 105 additions & 47 deletions refactor/eg/rewrite.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,52 @@ import (
"golang.org/x/tools/go/ast/astutil"
)

// transformItem takes a reflect.Value representing a variable of type ast.Node
// transforms its child elements recursively with apply, and then transforms the
// actual element if it contains an expression.
func (tr *Transformer) transformItem(rv reflect.Value) (reflect.Value, bool, map[string]ast.Expr) {
// don't bother if val is invalid to start with
if !rv.IsValid() {
return reflect.Value{}, false, nil
}

rv, changed, newEnv := tr.apply(tr.transformItem, rv)

e := rvToExpr(rv)
if e == nil {
return rv, changed, newEnv
}

savedEnv := tr.env
tr.env = make(map[string]ast.Expr) // inefficient! Use a slice of k/v pairs

if tr.matchExpr(tr.before, e) {
if tr.verbose {
fmt.Fprintf(os.Stderr, "%s matches %s",
astString(tr.fset, tr.before), astString(tr.fset, e))
if len(tr.env) > 0 {
fmt.Fprintf(os.Stderr, " with:")
for name, ast := range tr.env {
fmt.Fprintf(os.Stderr, " %s->%s",
name, astString(tr.fset, ast))
}
}
fmt.Fprintf(os.Stderr, "\n")
}
tr.nsubsts++

// Clone the replacement tree, performing parameter substitution.
// We update all positions to n.Pos() to aid comment placement.
rv = tr.subst(tr.env, reflect.ValueOf(tr.after),
reflect.ValueOf(e.Pos()))
changed = true
newEnv = tr.env
}
tr.env = savedEnv

return rv, changed, newEnv
}

// Transform applies the transformation to the specified parsed file,
// whose type information is supplied in info, and returns the number
// of replacements that were made.
Expand All @@ -43,48 +89,14 @@ func (tr *Transformer) Transform(info *types.Info, pkg *types.Package, file *ast
if tr.verbose {
fmt.Fprintf(os.Stderr, "before: %s\n", astString(tr.fset, tr.before))
fmt.Fprintf(os.Stderr, "after: %s\n", astString(tr.fset, tr.after))
fmt.Fprintf(os.Stderr, "afterStmts: %s\n", tr.afterStmts)
}

var f func(rv reflect.Value) reflect.Value
f = func(rv reflect.Value) reflect.Value {
// don't bother if val is invalid to start with
if !rv.IsValid() {
return reflect.Value{}
}

rv = apply(f, rv)

e := rvToExpr(rv)
if e != nil {
savedEnv := tr.env
tr.env = make(map[string]ast.Expr) // inefficient! Use a slice of k/v pairs

if tr.matchExpr(tr.before, e) {
if tr.verbose {
fmt.Fprintf(os.Stderr, "%s matches %s",
astString(tr.fset, tr.before), astString(tr.fset, e))
if len(tr.env) > 0 {
fmt.Fprintf(os.Stderr, " with:")
for name, ast := range tr.env {
fmt.Fprintf(os.Stderr, " %s->%s",
name, astString(tr.fset, ast))
}
}
fmt.Fprintf(os.Stderr, "\n")
}
tr.nsubsts++

// Clone the replacement tree, performing parameter substitution.
// We update all positions to n.Pos() to aid comment placement.
rv = tr.subst(tr.env, reflect.ValueOf(tr.after),
reflect.ValueOf(e.Pos()))
}
tr.env = savedEnv
}

return rv
o, changed, _ := tr.apply(tr.transformItem, reflect.ValueOf(file))
if changed {
panic("BUG")
}
file2 := apply(f, reflect.ValueOf(file)).Interface().(*ast.File)
file2 := o.Interface().(*ast.File)

// By construction, the root node is unchanged.
if file != file2 {
Expand Down Expand Up @@ -150,45 +162,91 @@ var (
identType = reflect.TypeOf((*ast.Ident)(nil))
selectorExprType = reflect.TypeOf((*ast.SelectorExpr)(nil))
objectPtrType = reflect.TypeOf((*ast.Object)(nil))
statementType = reflect.TypeOf((*ast.Stmt)(nil)).Elem()
positionType = reflect.TypeOf(token.NoPos)
scopePtrType = reflect.TypeOf((*ast.Scope)(nil))
)

// apply replaces each AST field x in val with f(x), returning val.
// To avoid extra conversions, f operates on the reflect.Value form.
func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value {
// f takes a reflect.Value representing the variable to modify of type ast.Node.
// It returns a reflect.Value containing the transformed value of type ast.Node,
// whether any change was made, and a map of identifiers to ast.Expr (so we can
// do contextually correct substitutions in the parent statements).
func (tr *Transformer) apply(f func(reflect.Value) (reflect.Value, bool, map[string]ast.Expr), val reflect.Value) (reflect.Value, bool, map[string]ast.Expr) {
if !val.IsValid() {
return reflect.Value{}
return reflect.Value{}, false, nil
}

// *ast.Objects introduce cycles and are likely incorrect after
// rewrite; don't follow them but replace with nil instead
if val.Type() == objectPtrType {
return objectPtrNil
return objectPtrNil, false, nil
}

// similarly for scopes: they are likely incorrect after a rewrite;
// replace them with nil
if val.Type() == scopePtrType {
return scopePtrNil
return scopePtrNil, false, nil
}

switch v := reflect.Indirect(val); v.Kind() {
case reflect.Slice:
// no possible rewriting of statements.
if v.Type().Elem() != statementType {
changed := false
var envp map[string]ast.Expr
for i := 0; i < v.Len(); i++ {
e := v.Index(i)
o, localchanged, env := f(e)
if localchanged {
changed = true
// we clobber envp here,
// which means if we have two sucessive
// replacements inside the same statement
// we will only generate the setup for one of them.
envp = env
}
setValue(e, o)
}
return val, changed, envp
}

// statements are rewritten.
var out []ast.Stmt
for i := 0; i < v.Len(); i++ {
e := v.Index(i)
setValue(e, f(e))
o, changed, env := f(e)
if changed {
for _, s := range tr.afterStmts {
t := tr.subst(env, reflect.ValueOf(s), reflect.Value{}).Interface()
out = append(out, t.(ast.Stmt))
}
}
setValue(e, o)
out = append(out, e.Interface().(ast.Stmt))
}
return reflect.ValueOf(out), false, nil
case reflect.Struct:
changed := false
var envp map[string]ast.Expr
for i := 0; i < v.NumField(); i++ {
e := v.Field(i)
setValue(e, f(e))
o, localchanged, env := f(e)
if localchanged {
changed = true
envp = env
}
setValue(e, o)
}
return val, changed, envp
case reflect.Interface:
e := v.Elem()
setValue(v, f(e))
o, changed, env := f(e)
setValue(v, o)
return val, changed, env
}
return val
return val, false, nil
}

// subst returns a copy of (replacement) pattern with values from env
Expand Down
14 changes: 14 additions & 0 deletions refactor/eg/testdata/I.template
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// +build ignore

package templates

import (
"errors"
"fmt"
)

func before(s string) error { return fmt.Errorf("%s", s) }
func after(s string) error {
n := fmt.Sprintf("error - %s", s)
return errors.New(n)
}
9 changes: 9 additions & 0 deletions refactor/eg/testdata/I1.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// +build ignore

package I1

import "fmt"

func example() {
_ = fmt.Errorf("%s", "foo")
}
14 changes: 14 additions & 0 deletions refactor/eg/testdata/I1.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// +build ignore

package I1

import (
"errors"
"fmt"
)

func example() {

n := fmt.Sprintf("error - %s", "foo")
_ = errors.New(n)
}
11 changes: 11 additions & 0 deletions refactor/eg/testdata/J.template
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// +build ignore

package templates

import ()

func before(x int) int { return x + x + x }
func after(x int) int {
temp := x + x
return temp + x
}
10 changes: 10 additions & 0 deletions refactor/eg/testdata/J1.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// +build ignore

package I1

import "fmt"

func example() {
temp := 5
fmt.Print(temp + temp + temp)
}
11 changes: 11 additions & 0 deletions refactor/eg/testdata/J1.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// +build ignore

package I1

import "fmt"

func example() {
temp := 5
temp := temp + temp
fmt.Print(temp + temp)
}
2 changes: 0 additions & 2 deletions refactor/eg/testdata/no_after_return.template
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
package template

const shouldFail = "after: must contain a single statement"

func before() int { return 0 }
func after() int { println(); return 0 }

0 comments on commit 2226533

Please sign in to comment.