diff --git a/internal/imports/fix.go b/internal/imports/fix.go index 2578166c8ba..d9ba3b786dd 100644 --- a/internal/imports/fix.go +++ b/internal/imports/fix.go @@ -702,7 +702,7 @@ func addExternalCandidates(pass *pass, refs references, filename string) error { go func(pkgName string, symbols map[string]bool) { defer wg.Done() - found, err := findImport(ctx, pass.env, dirScan, pkgName, symbols, filename) + found, err := findImport(ctx, pass, dirScan, pkgName, symbols, filename) if err != nil { firstErrOnce.Do(func() { @@ -1028,7 +1028,7 @@ func loadExports(ctx context.Context, env *ProcessEnv, expectPackage string, pkg // findImport searches for a package with the given symbols. // If no package is found, findImport returns ("", false, nil) -func findImport(ctx context.Context, env *ProcessEnv, dirScan []*pkg, pkgName string, symbols map[string]bool, filename string) (*pkg, error) { +func findImport(ctx context.Context, pass *pass, dirScan []*pkg, pkgName string, symbols map[string]bool, filename string) (*pkg, error) { pkgDir, err := filepath.Abs(filename) if err != nil { return nil, err @@ -1038,7 +1038,12 @@ func findImport(ctx context.Context, env *ProcessEnv, dirScan []*pkg, pkgName st // Find candidate packages, looking only at their directory names first. var candidates []pkgDistance for _, pkg := range dirScan { - if pkg.dir != pkgDir && pkgIsCandidate(filename, pkgName, pkg) { + if pkg.dir == pkgDir && pass.f.Name.Name == pkgName { + // The candidate is in the same directory and has the + // same package name. Don't try to import ourselves. + continue + } + if pkgIsCandidate(filename, pkgName, pkg) { candidates = append(candidates, pkgDistance{ pkg: pkg, distance: distance(pkgDir, pkg.dir), @@ -1051,7 +1056,7 @@ func findImport(ctx context.Context, env *ProcessEnv, dirScan []*pkg, pkgName st // ones. Note that this sorts by the de-vendored name, so // there's no "penalty" for vendoring. sort.Sort(byDistanceOrImportPathShortLength(candidates)) - if env.Debug { + if pass.env.Debug { for i, c := range candidates { log.Printf("%s candidate %d/%d: %v in %v", pkgName, i+1, len(candidates), c.pkg.importPathShort, c.pkg.dir) } @@ -1090,9 +1095,9 @@ func findImport(ctx context.Context, env *ProcessEnv, dirScan []*pkg, pkgName st wg.Done() }() - exports, err := loadExports(ctx, env, pkgName, c.pkg) + exports, err := loadExports(ctx, pass.env, pkgName, c.pkg) if err != nil { - if env.Debug { + if pass.env.Debug { log.Printf("loading exports in dir %s (seeking package %s): %v", c.pkg.dir, pkgName, err) } resc <- nil diff --git a/internal/imports/fix_test.go b/internal/imports/fix_test.go index 3a8445a201f..a5f7f7fe35e 100644 --- a/internal/imports/fix_test.go +++ b/internal/imports/fix_test.go @@ -2116,6 +2116,32 @@ const _ = pkg.X }.processTest(t, "foo.com", "pkg/b.go", nil, nil, want) } +func TestExternalTestImportsPackageUnderTest(t *testing.T) { + const provide = `package pkg +func DoIt(){} +` + const input = `package pkg_test + +var _ = pkg.DoIt` + + const want = `package pkg_test + +import "foo.com/pkg" + +var _ = pkg.DoIt +` + + testConfig{ + module: packagestest.Module{ + Name: "foo.com", + Files: fm{ + "pkg/provide.go": provide, + "pkg/x_test.go": input, + }, + }, + }.processTest(t, "foo.com", "pkg/x_test.go", nil, nil, want) +} + func TestPkgIsCandidate(t *testing.T) { tests := []struct { name string