Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
262 lines (231 sloc) 5.9 KB
package main
import (
"bytes"
"fmt"
"go/ast"
"go/build"
"go/format"
"go/parser"
"go/token"
"io/ioutil"
"log"
"os"
"path"
"path/filepath"
"strings"
"github.com/pkg/errors"
"github.com/spf13/pflag"
"golang.org/x/tools/go/loader"
"golang.org/x/tools/imports"
)
// options holds the tool's command-line configuration.
type options struct {
	// pkgs are the packages to migrate, taken from the positional arguments.
	pkgs []string
	// dryRun (--dry-run) suppresses writing files; debug (--debug) enables
	// debugf output.
	dryRun, debug bool
	// cmpImportName (--cmp-pkg-import-alias) is the import alias used for the
	// cmp package.
	cmpImportName string
	// showLoaderErrors (--print-loader-errors) prints package loading errors;
	// useAllFiles (--ignore-build-tags) loads files regardless of build tags.
	showLoaderErrors, useAllFiles bool
}
// main parses the command line, configures logging, and runs the migration
// over the packages given as positional arguments. Errors are reported and
// turned into an exit status by handleExitError.
func main() {
	prog := os.Args[0]
	flagSet, opts := setupFlags(prog)

	parseErr := flagSet.Parse(os.Args[1:])
	handleExitError(prog, parseErr)

	setupLogging(opts)
	opts.pkgs = flagSet.Args()

	runErr := run(*opts)
	handleExitError(prog, runErr)
}
// setupLogging removes all prefix flags from the standard logger, so messages
// print without a timestamp, and enables debugf output when opts.debug is set.
func setupLogging(opts *options) {
	const plainOutput = 0 // no date/time/file prefix on log lines
	log.SetFlags(plainOutput)
	enableDebug = opts.debug
}
// enableDebug gates debugf output. It starts false and is set from the
// --debug flag by setupLogging.
var enableDebug bool

// debugf logs a printf-style message prefixed with "DEBUG: " through the
// standard logger. It does nothing unless enableDebug is true.
func debugf(msg string, args ...interface{}) {
	if !enableDebug {
		return
	}
	log.Printf("DEBUG: "+msg, args...)
}
// setupFlags builds the tool's flag set (named name, returning parse errors
// rather than exiting) and the options value its flags write into. The
// returned options is filled in when the flag set is parsed.
func setupFlags(name string) (*pflag.FlagSet, *options) {
	opts := &options{}
	fs := pflag.NewFlagSet(name, pflag.ContinueOnError)

	fs.BoolVar(&opts.dryRun, "dry-run", false, "don't write changes to file")
	fs.BoolVar(&opts.debug, "debug", false, "enable debug logging")
	fs.StringVar(&opts.cmpImportName, "cmp-pkg-import-alias", "is", "import alias to use for the assert/cmp package")
	fs.BoolVar(&opts.showLoaderErrors, "print-loader-errors", false, "print errors from loading source")
	fs.BoolVar(&opts.useAllFiles, "ignore-build-tags", false, "migrate all files ignoring build tags")

	printUsage := func() {
		fmt.Fprintf(os.Stderr, `Usage: %s [OPTIONS] PACKAGE [PACKAGE...]
Migrate calls from testify/{assert|require} to gotest.tools/assert.
%s`, name, fs.FlagUsages())
	}
	fs.Usage = printUsage

	return fs, opts
}
// handleExitError returns normally when err is nil. Otherwise it terminates
// the process: with status 0 when err is pflag.ErrHelp (help was requested),
// or after logging "<name>: Error: <err>" with status 3 for any other error.
func handleExitError(name string, err error) {
	if err == nil {
		return
	}
	if err == pflag.ErrHelp {
		os.Exit(0)
	}
	log.Println(name + ": Error: " + err.Error())
	os.Exit(3)
}
// run loads the packages named in opts.pkgs (including their tests) and
// migrates every file that imports testify's assert or require package.
// Files without such imports are skipped. Each migrated file is formatted with
// formatFile and written back in place. With opts.dryRun, migrateFile is still
// called but nothing is formatted or written.
//
// It returns the first load, format, or write error, wrapped with context.
func run(opts options) error {
	program, err := loadProgram(opts)
	if err != nil {
		return errors.Wrap(err, "failed to load program")
	}

	pkgs := program.InitialPackages()
	debugf("package count: %d", len(pkgs))

	fileset := program.Fset
	for _, pkg := range pkgs {
		for _, astFile := range pkg.Files {
			absFilename := fileset.File(astFile.Pos()).Name()
			filename := relativePath(absFilename)
			importNames := newImportNames(astFile.Imports, opts)
			if !importNames.hasTestifyImports() {
				debugf("skipping file %s, no imports", filename)
				continue
			}

			debugf("migrating %s with imports: %#v", filename, importNames)
			m := migration{
				file:        astFile,
				fileset:     fileset,
				importNames: importNames,
				pkgInfo:     pkg,
			}
			migrateFile(m)
			if opts.dryRun {
				continue
			}

			raw, err := formatFile(m)
			if err != nil {
				return errors.Wrapf(err, "failed to format %s", filename)
			}
			// The file was just parsed, so it normally exists and WriteFile
			// keeps its current permissions. The mode is only used if the file
			// must be recreated; 0644 avoids producing an unreadable 0000 file.
			if err := ioutil.WriteFile(absFilename, raw, 0644); err != nil {
				return errors.Wrapf(err, "failed to write file %s", filename)
			}
		}
	}
	return nil
}
// loadProgram parses (keeping comments) and type-checks the packages named in
// opts.pkgs together with their tests. It uses the build context from
// buildContext and resolves imports through the fake importer. Type errors do
// not abort loading (AllowErrors).
//
// When opts.showLoaderErrors is false, type-checker errors are silenced.
// Otherwise the per-package errors are printed to stdout once loading
// succeeds. On failure it returns a nil program and the loader's error.
func loadProgram(opts options) (*loader.Program, error) {
	fakeImporter, err := newFakeImporter()
	if err != nil {
		return nil, err
	}
	defer fakeImporter.Close()

	conf := loader.Config{
		Fset:        token.NewFileSet(),
		ParserMode:  parser.ParseComments,
		Build:       buildContext(opts),
		AllowErrors: true,
		FindPackage: fakeImporter.Import,
	}
	for _, pkg := range opts.pkgs {
		conf.ImportWithTests(pkg)
	}
	if !opts.showLoaderErrors {
		// A non-nil handler replaces the loader's default error reporting.
		conf.TypeChecker.Error = func(e error) {}
	}

	program, err := conf.Load()
	if err != nil {
		// Load returns a nil *Program alongside an error, so the program must
		// not be inspected here (it previously caused a nil dereference when
		// --print-loader-errors was set).
		return nil, err
	}

	if opts.showLoaderErrors {
		for p, pkg := range program.AllPackages {
			if len(pkg.Errors) > 0 {
				fmt.Printf("Package %s loaded with some errors:\n", p.Name())
				for _, err := range pkg.Errors {
					fmt.Println(" ", err.Error())
				}
			}
		}
	}
	return program, nil
}
// buildContext returns a copy of build.Default adjusted for opts:
// UseAllFiles follows opts.useAllFiles (--ignore-build-tags). GOPATH is taken
// from the environment when the GOPATH variable is set, even if it is empty.
func buildContext(opts options) *build.Context {
	bc := build.Default
	bc.UseAllFiles = opts.useAllFiles

	gopath, isSet := os.LookupEnv("GOPATH")
	if isSet {
		bc.GOPATH = gopath
	}
	return &bc
}
// relativePath returns p relative to the current working directory, for
// shorter log messages. If the working directory cannot be determined or p
// cannot be made relative to it, p is returned unchanged.
func relativePath(p string) string {
	if wd, wdErr := os.Getwd(); wdErr == nil {
		if rel, relErr := filepath.Rel(wd, p); relErr == nil {
			return rel
		}
	}
	return p
}
// importNames holds the identifiers a file uses for the testify packages and
// the identifiers the migration will use for the replacement packages.
type importNames struct {
	testifyAssert  string // local name of the testify assert import; "" when absent
	testifyRequire string // local name of the testify require import; "" when absent
	assert         string // identifier to use for the replacement assert package
	cmp            string // identifier to use for the replacement cmp package
}

// hasTestifyImports reports whether the file imports testify's assert or
// require package under any name.
func (p importNames) hasTestifyImports() bool {
	return !(p.testifyAssert == "" && p.testifyRequire == "")
}

// matchesTestify reports whether ident is the local name of either testify
// import.
func (p importNames) matchesTestify(ident *ast.Ident) bool {
	switch ident.Name {
	case p.testifyAssert, p.testifyRequire:
		return true
	}
	return false
}
// funcNameFromTestifyName maps the local name of a testify import to the
// replacement function name: the testify assert import maps to funcNameCheck
// and the testify require import to funcNameAssert. The assert name is
// checked first. Any other name is a programming error and panics.
func (p importNames) funcNameFromTestifyName(name string) string {
	if name == p.testifyAssert {
		return funcNameCheck
	}
	if name == p.testifyRequire {
		return funcNameAssert
	}
	panic("unexpected testify import name " + name)
}
// newImportNames scans a file's import specs and returns the identifiers the
// migration should use:
//   - testifyAssert / testifyRequire: the local name of each testify import
//     (its explicit alias, or "assert" / "require"), left "" when not
//     imported.
//   - assert: path.Base(pkgAssert), or "gtyassert" when some other import
//     already uses that name.
//   - cmp: path.Base(pkgCmp), overridden by opt.cmpImportName when non-empty.
func newImportNames(imports []*ast.ImportSpec, opt options) importNames {
	names := importNames{
		assert: path.Base(pkgAssert),
		cmp:    path.Base(pkgCmp),
	}
	for _, spec := range imports {
		// Import paths may be interpreted ("...") or raw (`...`) string
		// literals; strip either delimiter so raw-quoted testify imports
		// are recognised too.
		switch strings.Trim(spec.Path.Value, "\"`") {
		case pkgTestifyAssert, pkgGopkgTestifyAssert:
			names.testifyAssert = identOrDefault(spec.Name, "assert")
		case pkgTestifyRequire, pkgGopkgTestifyRequire:
			names.testifyRequire = identOrDefault(spec.Name, "require")
		default:
			// Some other import already claims the assert package's default
			// name, so switch to a distinct alias to avoid a clash.
			if importedAs(spec, path.Base(pkgAssert)) {
				names.assert = "gtyassert"
			}
		}
	}
	if opt.cmpImportName != "" {
		names.cmp = opt.cmpImportName
	}
	return names
}
// importedAs reports whether spec may introduce the identifier pkg into the
// file. That is the case when the last element of its import path is pkg, or
// when the import is explicitly named pkg. The path check is deliberately
// conservative: it also matches renamed imports whose path ends in pkg.
func importedAs(spec *ast.ImportSpec, pkg string) bool {
	// Import paths may be interpreted ("...") or raw (`...`) string literals;
	// strip either delimiter before taking the base element.
	importPath := strings.Trim(spec.Path.Value, "\"`")
	if path.Base(importPath) == pkg {
		return true
	}
	return spec.Name != nil && spec.Name.Name == pkg
}
// identOrDefault returns the name of ident, or def when ident is nil (as for
// an import spec written without an explicit local name).
func identOrDefault(ident *ast.Ident, def string) string {
	if ident == nil {
		return def
	}
	return ident.Name
}
// formatFile prints the (migrated) AST of migration.file back to Go source with
// go/format, then runs the result through imports.Process (goimports) using
// the file's name from the file set. It returns the final source bytes, or
// the first error from either step.
func formatFile(migration migration) ([]byte, error) {
	var out bytes.Buffer
	if err := format.Node(&out, migration.fileset, migration.file); err != nil {
		return nil, err
	}
	name := migration.fileset.File(migration.file.Pos()).Name()
	return imports.Process(name, out.Bytes(), nil)
}
You can’t perform that action at this time.