diff --git a/gopls/internal/golang/codeaction.go b/gopls/internal/golang/codeaction.go index df4ca513ce2..ffac3bf2010 100644 --- a/gopls/internal/golang/codeaction.go +++ b/gopls/internal/golang/codeaction.go @@ -180,8 +180,16 @@ func fixedByImportFix(fix *imports.ImportFix, diagnostics []protocol.Diagnostic) // getExtractCodeActions returns any refactor.extract code actions for the selection. func getExtractCodeActions(pgf *parsego.File, rng protocol.Range, options *settings.Options) ([]protocol.CodeAction, error) { + var actions []protocol.CodeAction + + extractToNewFileActions, err := getExtractToNewFileCodeActions(pgf, rng, options) + if err != nil { + return nil, err + } + actions = append(actions, extractToNewFileActions...) + if rng.Start == rng.End { - return nil, nil + return actions, nil } start, end, err := pgf.RangePos(rng) @@ -226,7 +234,6 @@ func getExtractCodeActions(pgf *parsego.File, rng protocol.Range, options *setti } commands = append(commands, cmd) } - var actions []protocol.CodeAction for i := range commands { actions = append(actions, newCodeAction(commands[i].Title, protocol.RefactorExtract, &commands[i], nil, options)) } diff --git a/gopls/internal/golang/move_to_new_file.go b/gopls/internal/golang/move_to_new_file.go new file mode 100644 index 00000000000..e743d423dd5 --- /dev/null +++ b/gopls/internal/golang/move_to_new_file.go @@ -0,0 +1,361 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package golang + +// This file defines the code action "extract to a new file". + +// todo: rename file to extract_to_new_file.go after code review + +import ( + "context" + "errors" + "fmt" + "go/ast" + "go/format" + "go/token" + "go/types" + "os" + "path/filepath" + "strings" + + "golang.org/x/tools/gopls/internal/cache" + "golang.org/x/tools/gopls/internal/cache/parsego" + "golang.org/x/tools/gopls/internal/file" + "golang.org/x/tools/gopls/internal/protocol" + "golang.org/x/tools/gopls/internal/protocol/command" + "golang.org/x/tools/gopls/internal/settings" +) + +func getExtractToNewFileCodeActions(pgf *parsego.File, rng protocol.Range, _ *settings.Options) ([]protocol.CodeAction, error) { + ok := canExtractToNewFile(pgf, rng) + if !ok { + return nil, nil + } + cmd, err := command.NewExtractToNewFileCommand( + "Extract declarations to new file", + command.ExtractToNewFileArgs{URI: pgf.URI, Range: rng}, + ) + if err != nil { + return nil, err + } + return []protocol.CodeAction{{ + Title: "Extract declarations to new file", + Kind: protocol.RefactorExtract, + Command: &cmd, + }}, nil +} + +// canExtractToNewFile reports whether the code in the given range can be extracted to a new file. +func canExtractToNewFile(pgf *parsego.File, rng protocol.Range) bool { + _, err := extractToNewFileInternal(nil, nil, pgf, rng, true) + if err != nil { + return false + } else { + return true + } +} + +// ExtractToNewFile moves selected declarations into a new file. +func ExtractToNewFile( + ctx context.Context, + snapshot *cache.Snapshot, + fh file.Handle, + rng protocol.Range, +) (*protocol.WorkspaceEdit, error) { + pkg, pgf, err := NarrowestPackageForFile(ctx, snapshot, fh.URI()) + if err != nil { + return nil, err + } + return extractToNewFileInternal(fh, pkg, pgf, rng, false) +} + +// findImportEdits finds imports specs that needs to be added to the new file +// or deleted from the old file if the range is extracted to a new file. +func findImportEdits(pkg *cache.Package, pgf *parsego.File, start, end token.Pos) (adds []*ast.ImportSpec, deletes []*ast.ImportSpec) { + var ( + foundInSelection = make(map[*types.PkgName]bool) + foundInNonSelection = make(map[*types.PkgName]bool) + ) + for ident, use := range pkg.GetTypesInfo().Uses { + if pkgName, ok := use.(*types.PkgName); ok { + if contain(start, end, ident.Pos(), ident.End()) { + foundInSelection[pkgName] = true + } else { + foundInNonSelection[pkgName] = true + } + } + } + type NamePath struct { + Name string + Path string + } + + imports := make(map[NamePath]*ast.ImportSpec) + for _, v := range pgf.File.Imports { + path := strings.Trim(v.Path.Value, `"`) + if v.Name != nil { + imports[NamePath{v.Name.Name, path}] = v + } else { + imports[NamePath{"", path}] = v + } + } + + for pkgName := range foundInSelection { + importSpec := imports[NamePath{pkgName.Name(), pkgName.Imported().Path()}] + if importSpec == nil { + importSpec = imports[NamePath{"", pkgName.Imported().Path()}] + } + + adds = append(adds, importSpec) + if !foundInNonSelection[pkgName] { + deletes = append(deletes, importSpec) + } + } + + return adds, deletes +} + +// extractToNewFileInternal moves selected declarations into a new file. +func extractToNewFileInternal( + fh file.Handle, + pkg *cache.Package, + pgf *parsego.File, + rng protocol.Range, + dry bool, +) (*protocol.WorkspaceEdit, error) { + errorPrefix := "moveToANewFileInternal" + + start, end, err := pgf.RangePos(rng) + if err != nil { + return nil, fmt.Errorf("%s: %w", errorPrefix, err) + } + + start, end, filename, err := findRangeAndFilename(pgf, start, end) + if err != nil { + return nil, fmt.Errorf("%s: %w", errorPrefix, err) + } + + if dry { + return nil, nil + } + + end = skipWhiteSpaces(pgf, end) + + replaceRange, err := pgf.PosRange(start, end) + if err != nil { + return nil, fmt.Errorf("%s: %w", errorPrefix, err) + } + + adds, deletes := findImportEdits(pkg, pgf, start, end) + + var importDeletes []protocol.TextEdit + parenthesisFreeImports := findParenthesisFreeImports(pgf) + for _, importSpec := range deletes { + if decl := parenthesisFreeImports[importSpec]; decl != nil { + importDeletes = append(importDeletes, removeNode(pgf, decl)) + } else { + importDeletes = append(importDeletes, removeNode(pgf, importSpec)) + } + } + + importAdds := "" + if len(adds) > 0 { + importAdds += "import (" + for _, importSpec := range adds { + if importSpec.Name != nil { + importAdds += importSpec.Name.Name + " " + importSpec.Path.Value + "\n" + } else { + importAdds += importSpec.Path.Value + "\n" + } + } + importAdds += ")" + } + + createFileURI, err := resolveCreateFileURI(pgf.URI.Dir().Path(), filename) + if err != nil { + return nil, fmt.Errorf("%s: %w", errorPrefix, err) + } + + creatFileText, err := format.Source([]byte( + "package " + pgf.File.Name.Name + "\n" + + importAdds + "\n" + + string(pgf.Src[start-pgf.File.FileStart:end-pgf.File.FileStart]), + )) + if err != nil { + return nil, err + } + + return &protocol.WorkspaceEdit{ + DocumentChanges: []protocol.DocumentChanges{ + // original file edits + protocol.TextEditsToDocumentChanges(fh.URI(), fh.Version(), append( + importDeletes, + protocol.TextEdit{ + Range: replaceRange, + NewText: "", + }, + ))[0], + { + CreateFile: &protocol.CreateFile{ + Kind: "create", + URI: createFileURI, + }, + }, + // created file edits + protocol.TextEditsToDocumentChanges(createFileURI, 0, []protocol.TextEdit{ + { + Range: protocol.Range{}, + NewText: string(creatFileText), + }, + })[0], + }, + }, nil +} + +// resolveCreateFileURI checks that basename.go does not exists in dir, otherwise +// select basename.{1,2,3,4,5}.go as filename. +func resolveCreateFileURI(dir string, basename string) (protocol.DocumentURI, error) { + basename = strings.ToLower(basename) + newPath := filepath.Join(dir, basename+".go") + for count := 1; ; count++ { + if _, err := os.Stat(newPath); errors.Is(err, os.ErrNotExist) { + break + } + if count >= 5 { + return "", fmt.Errorf("resolveNewFileURI: exceeded retry limit") + } + filename := fmt.Sprintf("%s.%d.go", basename, count) + newPath = filepath.Join(dir, filename) + } + return protocol.URIFromPath(newPath), nil +} + +// findRangeAndFilename checks the selection is valid and extends range as needed and returns adjusted +// range and selected filename. +func findRangeAndFilename(pgf *parsego.File, start, end token.Pos) (token.Pos, token.Pos, string, error) { + if intersect(start, end, pgf.File.Package, pgf.File.Name.End()) { + return 0, 0, "", errors.New("selection cannot intersect a package declaration") + } + firstName := "" + for _, node := range pgf.File.Decls { + if intersect(start, end, node.Pos(), node.End()) { + if v, ok := node.(*ast.GenDecl); ok && v.Tok == token.IMPORT { + return 0, 0, "", errors.New("selection cannot intersect an import declaration") + } + if _, ok := node.(*ast.BadDecl); ok { + return 0, 0, "", errors.New("selection cannot intersect a bad declaration") + } + // should work when only selecting keyword "func" or function name + if v, ok := node.(*ast.FuncDecl); ok && contain(v.Pos(), v.Name.End(), start, end) { + start, end = v.Pos(), v.End() + } + // should work when only selecting keyword "type", "var", "const" + if v, ok := node.(*ast.GenDecl); ok && (v.Tok == token.TYPE && contain(v.Pos(), v.Pos()+4, start, end) || + v.Tok == token.CONST && contain(v.Pos(), v.Pos()+5, start, end) || + v.Tok == token.VAR && contain(v.Pos(), v.Pos()+3, start, end)) { + start, end = v.Pos(), v.End() + } + if !contain(start, end, node.Pos(), node.End()) { + return 0, 0, "", errors.New("selection cannot partially intersect a node") + } else { + if firstName == "" { + firstName = getNodeName(node) + } + // extends selection to docs comments + if c := getCommentGroup(node); c != nil { + if c.Pos() < start { + start = c.Pos() + } + } + } + } + } + for _, node := range pgf.File.Comments { + if intersect(start, end, node.Pos(), node.End()) { + if !contain(start, end, node.Pos(), node.End()) { + return 0, 0, "", errors.New("selection cannot partially intersect a comment") + } + } + } + if firstName == "" { + return 0, 0, "", errors.New("nothing selected") + } + return start, end, firstName, nil +} + +func skipWhiteSpaces(pgf *parsego.File, pos token.Pos) token.Pos { + i := pos + for ; i-pgf.File.FileStart < token.Pos(len(pgf.Src)); i++ { + c := pgf.Src[i-pgf.File.FileStart] + if c == ' ' || c == '\t' || c == '\n' { + continue + } else { + break + } + } + return i +} + +func getCommentGroup(node ast.Node) *ast.CommentGroup { + switch n := node.(type) { + case *ast.GenDecl: + return n.Doc + case *ast.FuncDecl: + return n.Doc + } + return nil +} + +func findParenthesisFreeImports(pgf *parsego.File) map[*ast.ImportSpec]*ast.GenDecl { + decls := make(map[*ast.ImportSpec]*ast.GenDecl) + for _, decl := range pgf.File.Decls { + if g, ok := decl.(*ast.GenDecl); ok { + if !g.Lparen.IsValid() && len(g.Specs) > 0 { + if v, ok := g.Specs[0].(*ast.ImportSpec); ok { + decls[v] = g + } + } + } + } + return decls +} + +// removeNode returns a TextEdit that removes the node +func removeNode(pgf *parsego.File, node ast.Node) protocol.TextEdit { + rng, _ := pgf.PosRange(node.Pos(), node.End()) + return protocol.TextEdit{Range: rng, NewText: ""} +} + +// getNodeName returns the first func name or variable name +func getNodeName(node ast.Node) string { + switch n := node.(type) { + case *ast.FuncDecl: + return n.Name.Name + case *ast.GenDecl: + if len(n.Specs) == 0 { + return "" + } + switch m := n.Specs[0].(type) { + case *ast.TypeSpec: + return m.Name.Name + case *ast.ValueSpec: + if len(m.Names) == 0 { + return "" + } + return m.Names[0].Name + } + } + return "" +} + +// intersect checks if [a, b) and [c, d) intersect, assuming a <= b and c <= d +func intersect(a, b, c, d token.Pos) bool { + return !(b <= c || d <= a) +} + +// contain checks if [a, b) contains [c, d), assuming a <= b and c <= d +func contain(a, b, c, d token.Pos) bool { + return a <= c && d <= b +} diff --git a/gopls/internal/protocol/command/command_gen.go b/gopls/internal/protocol/command/command_gen.go index 2ee0a5d19be..4d3a770ee66 100644 --- a/gopls/internal/protocol/command/command_gen.go +++ b/gopls/internal/protocol/command/command_gen.go @@ -30,6 +30,7 @@ const ( CheckUpgrades Command = "check_upgrades" DiagnoseFiles Command = "diagnose_files" EditGoDirective Command = "edit_go_directive" + ExtractToNewFile Command = "extract_to_new_file" FetchVulncheckResult Command = "fetch_vulncheck_result" GCDetails Command = "gc_details" Generate Command = "generate" @@ -66,6 +67,7 @@ var Commands = []Command{ CheckUpgrades, DiagnoseFiles, EditGoDirective, + ExtractToNewFile, FetchVulncheckResult, GCDetails, Generate, @@ -143,6 +145,12 @@ func Dispatch(ctx context.Context, params *protocol.ExecuteCommandParams, s Inte return nil, err } return nil, s.EditGoDirective(ctx, a0) + case "gopls.extract_to_new_file": + var a0 ExtractToNewFileArgs + if err := UnmarshalArgs(params.Arguments, &a0); err != nil { + return nil, err + } + return nil, s.ExtractToNewFile(ctx, a0) case "gopls.fetch_vulncheck_result": var a0 URIArg if err := UnmarshalArgs(params.Arguments, &a0); err != nil { @@ -379,6 +387,18 @@ func NewEditGoDirectiveCommand(title string, a0 EditGoDirectiveArgs) (protocol.C }, nil } +func NewExtractToNewFileCommand(title string, a0 ExtractToNewFileArgs) (protocol.Command, error) { + args, err := MarshalArgs(a0) + if err != nil { + return protocol.Command{}, err + } + return protocol.Command{ + Title: title, + Command: "gopls.extract_to_new_file", + Arguments: args, + }, nil +} + func NewFetchVulncheckResultCommand(title string, a0 URIArg) (protocol.Command, error) { args, err := MarshalArgs(a0) if err != nil { diff --git a/gopls/internal/protocol/command/interface.go b/gopls/internal/protocol/command/interface.go index b0dd06088bd..bc73064bff3 100644 --- a/gopls/internal/protocol/command/interface.go +++ b/gopls/internal/protocol/command/interface.go @@ -148,6 +148,11 @@ type Interface interface { // themselves. AddImport(context.Context, AddImportArgs) error + // ExtractToNewFile: Move selected codes to a new file + // + // Used by the code action of the same name. + ExtractToNewFile(context.Context, ExtractToNewFileArgs) error + // StartDebugging: Start the gopls debug server // // Start the gopls debug server if it isn't running, and return the debug @@ -330,6 +335,12 @@ type AddImportArgs struct { URI protocol.DocumentURI } +type ExtractToNewFileArgs struct { + // URI of the file + URI protocol.DocumentURI + Range protocol.Range +} + type ListKnownPackagesResult struct { // Packages is a list of packages relative // to the URIArg passed by the command request. diff --git a/gopls/internal/protocol/tsdocument_changes.go b/gopls/internal/protocol/tsdocument_changes.go index 2c7a524e178..393c5c85216 100644 --- a/gopls/internal/protocol/tsdocument_changes.go +++ b/gopls/internal/protocol/tsdocument_changes.go @@ -6,14 +6,16 @@ package protocol import ( "encoding/json" + "errors" "fmt" ) -// DocumentChanges is a union of a file edit and directory rename operations +// DocumentChanges is a union of a file edit, file creation, and directory rename operations // for package renaming feature. At most one field of this struct is non-nil. type DocumentChanges struct { TextDocumentEdit *TextDocumentEdit RenameFile *RenameFile + CreateFile *CreateFile } func (d *DocumentChanges) UnmarshalJSON(data []byte) error { @@ -26,10 +28,16 @@ func (d *DocumentChanges) UnmarshalJSON(data []byte) error { if _, ok := m["textDocument"]; ok { d.TextDocumentEdit = new(TextDocumentEdit) return json.Unmarshal(data, d.TextDocumentEdit) + } else if kind, ok := m["kind"]; ok { + if kind == "create" { + d.CreateFile = new(CreateFile) + return json.Unmarshal(data, d.CreateFile) + } else if kind == "rename" { + d.RenameFile = new(RenameFile) + return json.Unmarshal(data, d.RenameFile) + } } - - d.RenameFile = new(RenameFile) - return json.Unmarshal(data, d.RenameFile) + return errors.New("don't know how to unmarshal") } func (d *DocumentChanges) MarshalJSON() ([]byte, error) { @@ -37,6 +45,8 @@ func (d *DocumentChanges) MarshalJSON() ([]byte, error) { return json.Marshal(d.TextDocumentEdit) } else if d.RenameFile != nil { return json.Marshal(d.RenameFile) + } else if d.CreateFile != nil { + return json.Marshal(d.CreateFile) } - return nil, fmt.Errorf("Empty DocumentChanges union value") + return nil, fmt.Errorf("empty DocumentChanges union value") } diff --git a/gopls/internal/server/command.go b/gopls/internal/server/command.go index 19ea884f45d..a093e123c91 100644 --- a/gopls/internal/server/command.go +++ b/gopls/internal/server/command.go @@ -870,6 +870,22 @@ func (c *commandHandler) AddImport(ctx context.Context, args command.AddImportAr }) } +func (c *commandHandler) ExtractToNewFile(ctx context.Context, args command.ExtractToNewFileArgs) error { + return c.run(ctx, commandConfig{ + progress: "Extract to a new file", + forURI: args.URI, + }, func(ctx context.Context, deps commandDeps) error { + edit, err := golang.ExtractToNewFile(ctx, deps.snapshot, deps.fh, args.Range) + if err != nil { + return err + } + if _, err := c.s.client.ApplyEdit(ctx, &protocol.ApplyWorkspaceEditParams{Edit: *edit}); err != nil { + return fmt.Errorf("could not apply edits: %v", err) + } + return nil + }) +} + func (c *commandHandler) StartDebugging(ctx context.Context, args command.DebuggingArgs) (result command.DebuggingResult, _ error) { addr := args.Addr if addr == "" { diff --git a/gopls/internal/test/integration/fake/editor.go b/gopls/internal/test/integration/fake/editor.go index e93776408b6..8125f04887c 100644 --- a/gopls/internal/test/integration/fake/editor.go +++ b/gopls/internal/test/integration/fake/editor.go @@ -761,6 +761,18 @@ func (e *Editor) RegexpReplace(ctx context.Context, path, re, replace string) er return e.setBufferContentLocked(ctx, path, true, patched, edits) } +// OffsetLocation returns the Location for integers offsets start and end +func (e *Editor) OffsetLocation(bufName string, start, end int) (protocol.Location, error) { + e.mu.Lock() + buf, ok := e.buffers[bufName] + e.mu.Unlock() + if !ok { + return protocol.Location{}, ErrUnknownBuffer + } + + return buf.mapper.OffsetLocation(start, end) +} + // EditBuffer applies the given test edits to the buffer identified by path. func (e *Editor) EditBuffer(ctx context.Context, path string, edits []protocol.TextEdit) error { e.mu.Lock() @@ -1402,7 +1414,11 @@ func (e *Editor) applyDocumentChange(ctx context.Context, change protocol.Docume if change.TextDocumentEdit != nil { return e.applyTextDocumentEdit(ctx, *change.TextDocumentEdit) } - panic("Internal error: one of RenameFile or TextDocumentEdit must be set") + if change.CreateFile != nil { + path := e.sandbox.Workdir.URIToPath(change.CreateFile.URI) + return e.sandbox.Workdir.WriteFile(ctx, path, "") + } + panic("Internal error: one of RenameFile, CreateFile, or TextDocumentEdit must be set") } func (e *Editor) applyTextDocumentEdit(ctx context.Context, change protocol.TextDocumentEdit) error { diff --git a/gopls/internal/test/integration/misc/move_to_new_file_test.go b/gopls/internal/test/integration/misc/move_to_new_file_test.go new file mode 100644 index 00000000000..b236410190d --- /dev/null +++ b/gopls/internal/test/integration/misc/move_to_new_file_test.go @@ -0,0 +1,471 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +package misc + +// todo: rename file to extract_to_new_file_test.go after code review + +import ( + "fmt" + "regexp" + "strings" + "testing" + + . "golang.org/x/tools/gopls/internal/test/integration" + + "golang.org/x/tools/gopls/internal/protocol" +) + +func dedent(s string) string { + s = strings.TrimPrefix(s, "\n") + indents := regexp.MustCompile("^\t*").FindString(s) + return regexp.MustCompile(fmt.Sprintf("(?m)^\t{0,%d}", len(indents))).ReplaceAllString(s, "") +} + +func indent(s string) string { + return regexp.MustCompile("(?m)^").ReplaceAllString(s, "\t") +} + +// compileTemplate replaces two █ characters in text and write to dest and returns +// the location enclosed by the two █ +func compileTemplate(env *Env, text string, dest string) protocol.Location { + i := strings.Index(text, "█") + j := strings.LastIndex(text, "█") + if strings.Count(text, "█") != 2 { + panic("expecting exactly two █ characters in source") + } + out := text[:i] + text[i+len("█"):j] + text[j+len("█"):] + env.Sandbox.Workdir.WriteFile(env.Ctx, dest, out) + env.OpenFile(dest) + loc, err := env.Editor.OffsetLocation(dest, i, j-len("█")) + if err != nil { + panic(err) + } + return loc +} + +func TestExtractToNewFile(t *testing.T) { + const files = ` +-- go.mod -- +module mod.com + +go 1.18 +-- main.go -- +package main + +-- existing.go -- +package main + +-- existing2.go -- +package main + +-- existing2.1.go -- +package main + +` + for _, tc := range []struct { + name string + source string + fixed string + createdFilename string + created string + }{ + { + name: "func declaration", + source: ` + package main + + func _() {} + + // fn docs + █func fn() {}█ + `, + fixed: ` + package main + + func _() {} + + `, + createdFilename: "fn.go", + created: ` + package main + + // fn docs + func fn() {} + `, + }, + { + name: "only select function name", + source: ` + package main + func █F█() {} + `, + fixed: ` + package main + `, + createdFilename: "f.go", + created: ` + package main + + func F() {} + `, + }, + { + name: "zero-width range", + source: ` + package main + func ██F() {} + `, + fixed: ` + package main + `, + createdFilename: "f.go", + created: ` + package main + + func F() {} + `, + }, + { + name: "type declaration", + source: ` + package main + + // T docs + █type T int + type S int█ + `, + fixed: ` + package main + + `, + createdFilename: "t.go", + created: ` + package main + + // T docs + type T int + type S int + `, + }, + { + name: "const and var declaration", + source: ` + package main + + // c docs + █const c = 0 + var v = 0█ + `, + fixed: ` + package main + + `, + createdFilename: "c.go", + created: ` + package main + + // c docs + const c = 0 + + var v = 0 + `, + }, + { + name: "select only const keyword", + source: ` + package main + + █const█ ( + A = iota + B + C + ) + `, + fixed: ` + package main + + `, + createdFilename: "a.go", + created: ` + package main + + const ( + A = iota + B + C + ) + `, + }, + { + name: "select surrounding comments", + source: ` + package main + + █// above + + func fn() {} + + // below█ + `, + fixed: ` + package main + + `, + createdFilename: "fn.go", + created: ` + package main + + // above + + func fn() {} + + // below + `, + }, + + { + name: "create file name conflict", + source: ` + package main + █func existing() {}█ + `, + fixed: ` + package main + `, + createdFilename: "existing.1.go", + created: ` + package main + + func existing() {} + `, + }, + { + name: "create file name conflict again", + source: ` + package main + █func existing2() {}█ + `, + fixed: ` + package main + `, + createdFilename: "existing2.2.go", + created: ` + package main + + func existing2() {} + `, + }, + { + name: "imports", + source: ` + package main + import "fmt" + █func F() { + fmt.Println() + }█ + `, + fixed: ` + package main + + `, + createdFilename: "f.go", + created: ` + package main + + import ( + "fmt" + ) + + func F() { + fmt.Println() + } + `, + }, + { + name: "import alias", + source: ` + package main + import fmt2 "fmt" + █func F() { + fmt2.Println() + }█ + `, + fixed: ` + package main + + `, + createdFilename: "f.go", + created: ` + package main + + import ( + fmt2 "fmt" + ) + + func F() { + fmt2.Println() + } + `, + }, + { + name: "multiple imports", + source: ` + package main + import ( + "fmt" + "log" + ) + func init(){ + log.Println() + } + █func F() { + fmt.Println() + }█ + + `, + fixed: ` + package main + import ( + + "log" + ) + func init(){ + log.Println() + } + `, + createdFilename: "f.go", + created: ` + package main + + import ( + "fmt" + ) + + func F() { + fmt.Println() + } + `, + }, + } { + t.Run(tc.name, func(t *testing.T) { + Run(t, files, func(t *testing.T, env *Env) { + tc.source, tc.fixed, tc.created = dedent(tc.source), dedent(tc.fixed), dedent(tc.created) + filename := "source.go" + loc := compileTemplate(env, tc.source, filename) + actions, err := env.Editor.CodeAction(env.Ctx, loc, nil) + if err != nil { + t.Fatal(err) + } + var codeAction *protocol.CodeAction + for _, action := range actions { + if action.Title == "Extract declarations to new file" { + codeAction = &action + break + } + } + if codeAction == nil { + t.Fatal("cannot find Extract declarations to new file action") + } + + env.ApplyCodeAction(*codeAction) + got := env.BufferText(filename) + if tc.fixed != got { + t.Errorf(`incorrect output of fixed file: +source: +%s +got: +%s +want: +%s +`, indent(tc.source), indent(got), indent(tc.fixed)) + } + gotMoved := env.BufferText(tc.createdFilename) + if tc.created != gotMoved { + t.Errorf(`incorrect output of created file: +source: +%s +got created file: +%s +want created file: +%s +`, indent(tc.source), indent(gotMoved), indent(tc.created)) + } + + }) + }) + } +} + +func TestExtractToNewFileInvalidSelection(t *testing.T) { + const files = ` +-- go.mod -- +module mod.com + +go 1.18 +-- main.go -- +package main + +` + for _, tc := range []struct { + name string + source string + }{ + { + name: "select package declaration", + source: ` + █package main█ + func fn() {} + `, + }, + { + name: "select imports", + source: ` + package main + █import fmt█ + `, + }, + { + name: "select only comment", + source: ` + package main + █// comment█ + `, + }, + { + name: "selection does not contain whole top-level node", + source: ` + package main + func fn() { + █print(0)█ + } + `, + }, + { + name: "selection cross a comment", + source: ` + package main + + █func fn() {} // comment█ comment + `, + }, + } { + t.Run(tc.name, func(t *testing.T) { + Run(t, files, func(t *testing.T, env *Env) { + filename := "source.go" + loc := compileTemplate(env, dedent(tc.source), filename) + actions, err := env.Editor.CodeAction(env.Ctx, loc, nil) + if err != nil { + t.Fatal(err) + } + + for _, action := range actions { + if action.Title == "Extract declarations to new file" { + t.Errorf("should not offer code action") + return + } + } + }) + }) + } +}