Skip to content

Commit

Permalink
join all standard library imports
Browse files Browse the repository at this point in the history
This code is a bit tricky since we have to be careful to not break other
imports, and to not move comments around too much.

Fixes #8.
  • Loading branch information
mvdan committed Apr 27, 2019
1 parent 4bef639 commit b7afc71
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 0 deletions.
46 changes: 46 additions & 0 deletions internal/gofumpt.go
Expand Up @@ -187,6 +187,7 @@ func (f *fumpter) visit(node ast.Node) {
pos = comments[0].Pos()
}

// multiline top-level declarations should be separated
multi := f.posLine(decl.Pos()) < f.posLine(decl.End())
if (multi && lastMulti) &&
f.posLine(lastEnd)+1 == f.posLine(pos) {
Expand Down Expand Up @@ -222,6 +223,9 @@ func (f *fumpter) visit(node ast.Node) {
}

case *ast.GenDecl:
if node.Tok == token.IMPORT && node.Lparen.IsValid() {
f.joinStdImports(node)
}
if len(node.Specs) == 1 && node.Lparen.IsValid() {
// If the single spec has any comment, it must go before
// the entire declaration now.
Expand Down Expand Up @@ -329,3 +333,45 @@ func (f *fumpter) visit(node ast.Node) {
f.removeLines(openLine, closeLine)
}
}

// joinStdImports ensures that all standard library imports are together and at
// the top of the imports list.
func (f *fumpter) joinStdImports(d *ast.GenDecl) {
var std, other []ast.Spec
for _, spec := range d.Specs {
spec := spec.(*ast.ImportSpec)
// First, separate the non-std imports.
if strings.Contains(spec.Path.Value, ".") {
other = append(other, spec)
continue
}
if len(other) > 0 {
// If we're moving this std import further up, reset its
// position, to avoid breaking comments.
setPos(reflect.ValueOf(spec), d.Pos())
}
std = append(std, spec)
}
// Finally, join the imports, keeping std at the top.
d.Specs = append(std, other...)
}

var posType = reflect.TypeOf(token.NoPos)

// setPos recursively sets all position fields in the node v to pos.
func setPos(v reflect.Value, pos token.Pos) {
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if !v.IsValid() {
return
}
if v.Type() == posType {
v.Set(reflect.ValueOf(pos))
}
if v.Kind() == reflect.Struct {
for i := 0; i < v.NumField(); i++ {
setPos(v.Field(i), pos)
}
}
}
39 changes: 39 additions & 0 deletions testdata/scripts/std-imports.txt
@@ -0,0 +1,39 @@
[gofumports] skip 'don''t add or remove imports'

gofumpt -w foo.go .
cmp foo.go foo.go.golden

-- foo.go --
package p

import (
"io"

_ "bufio" // for a side effect
)

import (
"os"

"foo.localhost/other"

bytes_ "bytes"

"io"
)
-- foo.go.golden --
package p

import (
"io"

_ "bufio" // for a side effect
)

import (
bytes_ "bytes"
"io"
"os"

"foo.localhost/other"
)

0 comments on commit b7afc71

Please sign in to comment.