Skip to content

Commit

Permalink
look for package name on local filesystem
Browse files Browse the repository at this point in the history
  • Loading branch information
incu6us committed Jul 24, 2020
1 parent 86b4af1 commit 612795e
Show file tree
Hide file tree
Showing 12 changed files with 108 additions and 26 deletions.
1 change: 0 additions & 1 deletion go.mod
Expand Up @@ -3,7 +3,6 @@ module github.com/incu6us/goimports-reviser
go 1.13

require (
github.com/davecgh/go-spew v1.1.0
github.com/pkg/errors v0.9.1
github.com/stretchr/testify v1.5.1
golang.org/x/tools v0.0.0-20200407143752-a3568bac92ae
Expand Down
8 changes: 7 additions & 1 deletion main.go
Expand Up @@ -31,6 +31,8 @@ var (
shouldShowVersion *bool
shouldRemoveUnusedImports *bool
shouldSetAlias *bool

gopath = os.Getenv("GOPATH")
)

var projectName, filePath string
Expand Down Expand Up @@ -114,7 +116,7 @@ func main() {
options = append(options, reviser.OptionUseAliasForVersionSuffix)
}

formattedOutput, hasChange, err := reviser.Execute(projectName, filePath, options...)
formattedOutput, hasChange, err := reviser.Execute(gopath, projectName, filePath, options...)
if err != nil {
log.Fatalf("%+v", errors.WithStack(err))
}
Expand All @@ -139,6 +141,10 @@ func validateInputs(projectName, filePath string) error {
errMessages = append(errMessages, fmt.Sprintf("-%s should be set", filePathArg))
}

if gopath == "" {
errMessages = append(errMessages, "GOPATH environment variable should be set")
}

if len(errMessages) > 0 {
return errors.New(strings.Join(errMessages, "\n"))
}
Expand Down
91 changes: 78 additions & 13 deletions pkg/astutil/astutil.go
Expand Up @@ -2,21 +2,32 @@ package astutil

import (
"go/ast"
"go/parser"
"go/token"
"io/ioutil"
"os"
"path"
"path/filepath"
"strconv"
"strings"
)

const (
srcPathPrefix = "src"
pathSeparator = string(os.PathSeparator)
goFileExtensionSuffix = ".go"
)

// UsesImport is a similar to astutil.UsesImport but with skipping version in the import path
func UsesImport(f *ast.File, importPath string) bool {
func UsesImport(f *ast.File, gopath, importPath string) bool {
importIdentNames := make(map[string]struct{}, len(f.Imports))

var importSpec *ast.ImportSpec
for _, spec := range f.Imports {
name := spec.Name.String()
switch name {
case "<nil>":
pkgName, _ := PackageNameFromImportPath(importPath)
pkgName, _ := PackageNameFromImportPath(gopath, importPath)
importIdentNames[pkgName] = struct{}{}
case "_", ".":
return true
Expand All @@ -36,7 +47,7 @@ func UsesImport(f *ast.File, importPath string) bool {
ident, ok := sel.X.(*ast.Ident)
if ok {
if _, ok := importIdentNames[ident.Name]; ok {
pkg, _ := PackageNameFromImportPath(importPath)
pkg, _ := PackageNameFromImportPath(gopath, importPath)
if (ident.Name == pkg || ident.Name == importSpec.Name.String()) && ident.Obj == nil {
used = true
return
Expand All @@ -49,23 +60,36 @@ func UsesImport(f *ast.File, importPath string) bool {
return used
}

// PackageNameFromImportPath will return package alias name
// and true if it has a version suffix in the end of the path (ex.: github.com/go-pg/pg/v9)
func PackageNameFromImportPath(importPath string) (string, bool) {
var hasVersionSuffix bool
// PackageNameFromImportPath will return package name
// and true if import base suffix is different from its package name
func PackageNameFromImportPath(gopath, importPath string) (string, bool) {
pkgNameFromPath := path.Base(importPath)

base := path.Base(importPath)
if strings.HasPrefix(base, "v") {
if _, err := strconv.Atoi(base[1:]); err == nil {
hasVersionSuffix = true
if strings.HasPrefix(pkgNameFromPath, "v") {
if _, err := strconv.Atoi(pkgNameFromPath[1:]); err == nil {
dir := path.Dir(importPath)
if dir != "." {
base = path.Base(dir)
pkgNameFromPath = path.Base(dir)
}

return pkgNameFromPath, true
}
}

return base, hasVersionSuffix
pkgNameFromFS, err := resolvePackageName(gopath, importPath)
if err != nil {
if os.IsNotExist(err) {
return pkgNameFromPath, false
}

panic(err)
}

if pkgNameFromFS != pkgNameFromPath {
return pkgNameFromFS, true
}

return pkgNameFromFS, false
}

type visitFn func(node ast.Node)
Expand All @@ -74,3 +98,44 @@ func (f visitFn) Visit(node ast.Node) ast.Visitor {
f(node)
return f
}

// resolvePackageName resolves import to package name token(on local FS)
// Input:
// 1 - GOPATH value
// 2 - import package name(like: github.com/pkg/errors)
// Output:
// 1 - package (like: errors)
// 2 - error
func resolvePackageName(gopath string, pkg string) (string, error) {
srcPath := strings.Join([]string{gopath, srcPathPrefix}, pathSeparator)

pkgPath := strings.Join([]string{srcPath, pkg}, pathSeparator)

fileInfos, err := ioutil.ReadDir(pkgPath)
if err != nil {
return "", err
}

for _, fileInfo := range fileInfos {
if fileInfo.IsDir() {
continue
}

if filepath.Ext(fileInfo.Name()) != goFileExtensionSuffix {
continue
}

relativePathToFile := strings.Join([]string{pkgPath, fileInfo.Name()}, pathSeparator)

pf, err := parser.ParseFile(token.NewFileSet(), relativePathToFile, nil, parser.PackageClauseOnly)
if err != nil {
return "", err
}

if pf.Name != nil {
return pf.Name.String(), nil
}
}

return "", nil
}
4 changes: 3 additions & 1 deletion pkg/astutil/astutil_test.go
Expand Up @@ -9,6 +9,8 @@ import (
)

func TestUsesImport(t *testing.T) {
const gopath = "./testdata"

type args struct {
fileData string
path string
Expand Down Expand Up @@ -143,7 +145,7 @@ func main(){
require.Nil(t, err)
}

if got := UsesImport(f, tt.args.path); got != tt.want {
if got := UsesImport(f, gopath, tt.args.path); got != tt.want {
t.Errorf("UsesImport() = %v, want %v", got, tt.want)
}
})
Expand Down
1 change: 1 addition & 0 deletions pkg/astutil/testdata/src/github.com/go-pg/pg/v9/pg.go
@@ -0,0 +1 @@
package pg
@@ -0,0 +1 @@
package innderpkg
1 change: 1 addition & 0 deletions pkg/astutil/testdata/src/some-pkg-go/1
@@ -0,0 +1 @@
package test
Empty file.
3 changes: 3 additions & 0 deletions pkg/astutil/testdata/src/some-pkg-go/lib.go
@@ -0,0 +1,3 @@
// Some comments here

package useful_pkg
14 changes: 7 additions & 7 deletions reviser/reviser.go
Expand Up @@ -46,7 +46,7 @@ func (o Options) shouldUseAliasForVersionSuffix() bool {
}

// Revise imports and format the code
func Execute(projectName, filePath string, options ...Option) ([]byte, bool, error) {
func Execute(gopath, projectName, filePath string, options ...Option) ([]byte, bool, error) {
originalContent, err := ioutil.ReadFile(filePath)
if err != nil {
return nil, false, err
Expand All @@ -59,7 +59,7 @@ func Execute(projectName, filePath string, options ...Option) ([]byte, bool, err
return nil, false, err
}

importsWithMetadata := combineAllImportsWithMetadata(pf, options)
importsWithMetadata := combineAllImportsWithMetadata(pf, gopath, options)

stdImports, generalImports, projectImports := groupImports(projectName, importsWithMetadata)

Expand Down Expand Up @@ -218,7 +218,7 @@ func importWithComment(imprt string, commentsMetadata map[string]*commentsMetada
return fmt.Sprintf("%s %s", imprt, comment)
}

func combineAllImportsWithMetadata(f *ast.File, options Options) map[string]*commentsMetadata {
func combineAllImportsWithMetadata(f *ast.File, gopath string, options Options) map[string]*commentsMetadata {
importsWithMetadata := map[string]*commentsMetadata{}

shouldRemoveUnusedImports := options.shouldRemoveUnusedImports()
Expand All @@ -233,15 +233,15 @@ func combineAllImportsWithMetadata(f *ast.File, options Options) map[string]*com
var importSpecStr string
importSpec := spec.(*ast.ImportSpec)

if shouldRemoveUnusedImports && !astutil.UsesImport(f, strings.Trim(importSpec.Path.Value, `"`)) {
if shouldRemoveUnusedImports && !astutil.UsesImport(f, gopath, strings.Trim(importSpec.Path.Value, `"`)) {
continue
}

if importSpec.Name != nil {
importSpecStr = strings.Join([]string{importSpec.Name.String(), importSpec.Path.Value}, " ")
} else {
if shouldUseAliasForVersionSuffix {
importSpecStr = setAliasForVersionedImportSpec(importSpec)
importSpecStr = setAliasForVersionedImportSpec(gopath, importSpec)
} else {
importSpecStr = importSpec.Path.Value
}
Expand All @@ -259,10 +259,10 @@ func combineAllImportsWithMetadata(f *ast.File, options Options) map[string]*com
return importsWithMetadata
}

func setAliasForVersionedImportSpec(importSpec *ast.ImportSpec) string {
func setAliasForVersionedImportSpec(gopath string, importSpec *ast.ImportSpec) string {
var importSpecStr string

aliasName, ok := astutil.PackageNameFromImportPath(strings.Trim(importSpec.Path.Value, `"`))
aliasName, ok := astutil.PackageNameFromImportPath(gopath, strings.Trim(importSpec.Path.Value, `"`))
if ok {
importSpecStr = fmt.Sprintf("%s %s", aliasName, importSpec.Path.Value)
} else {
Expand Down
9 changes: 6 additions & 3 deletions reviser/reviser_test.go
Expand Up @@ -7,7 +7,10 @@ import (
"github.com/stretchr/testify/assert"
)

const gopath = "./testdata"

func TestExecute(t *testing.T) {

type args struct {
projectName string
filePath string
Expand Down Expand Up @@ -342,7 +345,7 @@ import (
}

t.Run(tt.name, func(t *testing.T) {
got, hasChange, err := Execute(tt.args.projectName, tt.args.filePath)
got, hasChange, err := Execute(gopath, tt.args.projectName, tt.args.filePath)
if (err != nil) != tt.wantErr {
t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down Expand Up @@ -558,7 +561,7 @@ const webDirectory = "web"
}

t.Run(tt.name, func(t *testing.T) {
got, hasChange, err := Execute(tt.args.projectName, tt.args.filePath, OptionRemoveUnusedImports)
got, hasChange, err := Execute(gopath, tt.args.projectName, tt.args.filePath, OptionRemoveUnusedImports)
if (err != nil) != tt.wantErr {
t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down Expand Up @@ -659,7 +662,7 @@ func main() {
}

t.Run(tt.name, func(t *testing.T) {
got, hasChange, err := Execute(tt.args.projectName, tt.args.filePath, OptionUseAliasForVersionSuffix)
got, hasChange, err := Execute(gopath, tt.args.projectName, tt.args.filePath, OptionUseAliasForVersionSuffix)
if (err != nil) != tt.wantErr {
t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down
@@ -0,0 +1 @@
package innderpkg

0 comments on commit 612795e

Please sign in to comment.