From 8985bb29cb103d41c2fbcc2ebe9d4ec7f90aa67e Mon Sep 17 00:00:00 2001 From: Knut Wannheden Date: Tue, 31 Mar 2026 13:03:56 +0200 Subject: [PATCH] Fix multiple bugs in rewrite-go parser, visitor, printer, and RPC Motivation: Code review of the rewrite-go module identified several bugs and missing visitor coverage that would impact correctness and robustness. Summary: - Fix panic recovery in safeHandleRequest to return a proper JSON-RPC error response instead of nil when a panic is caught - Fix parser to handle multiple import blocks (e.g., `import "fmt"` followed by `import "os"`), using a new ImportBlock marker to track block boundaries - Fix visitAndCast and visitExpression to handle nil returns from visitors instead of panicking on the type assertion - Fix VisitCompilationUnit to visit PackageDecl and Imports, which were previously skipped entirely (preventing recipe visitors from transforming package names or imports) - Remove dead code in mapArrayType (unreachable closePrefix computation) --- rewrite-go/rewrite/cmd/rpc/main.go | 7 +- rewrite-go/rewrite/pkg/parser/go_parser.go | 91 ++++++++++---- rewrite-go/rewrite/pkg/printer/go_printer.go | 18 ++- rewrite-go/rewrite/pkg/rpc/space_rpc.go | 22 ++++ rewrite-go/rewrite/pkg/rpc/value_types.go | 3 + rewrite-go/rewrite/pkg/tree/go.go | 25 ++++ rewrite-go/rewrite/pkg/visitor/go_visitor.go | 20 ++++ rewrite-go/rewrite/test/import_test.go | 32 +++++ rewrite-go/rewrite/test/visitor_test.go | 112 ++++++++++++++++++ .../golang/marker/ImportBlock.java | 55 +++++++++ 10 files changed, 358 insertions(+), 27 deletions(-) create mode 100644 rewrite-go/rewrite/test/visitor_test.go create mode 100644 rewrite-go/src/main/java/org/openrewrite/golang/marker/ImportBlock.java diff --git a/rewrite-go/rewrite/cmd/rpc/main.go b/rewrite-go/rewrite/cmd/rpc/main.go index c6a990de8c4..7cf48870ba7 100644 --- a/rewrite-go/rewrite/cmd/rpc/main.go +++ b/rewrite-go/rewrite/cmd/rpc/main.go @@ -194,12 +194,17 @@ func (s *server) writeMessage(resp *jsonRPCResponse) error { } // safeHandleRequest wraps handleRequest with panic recovery. -func (s *server) safeHandleRequest(req *jsonRPCRequest) *jsonRPCResponse { +func (s *server) safeHandleRequest(req *jsonRPCRequest) (resp *jsonRPCResponse) { defer func() { if r := recover(); r != nil { buf := make([]byte, 4096) n := runtime.Stack(buf, false) s.logger.Printf("PANIC in %s: %v\n%s", req.Method, r, buf[:n]) + resp = &jsonRPCResponse{ + JSONRPC: "2.0", + ID: req.ID, + Error: &rpcError{Code: -32603, Message: fmt.Sprintf("Internal error: %v", r)}, + } } }() return s.handleRequest(req) diff --git a/rewrite-go/rewrite/pkg/parser/go_parser.go b/rewrite-go/rewrite/pkg/parser/go_parser.go index cdc51580b9a..6c50d3c9e8f 100644 --- a/rewrite-go/rewrite/pkg/parser/go_parser.go +++ b/rewrite-go/rewrite/pkg/parser/go_parser.go @@ -171,59 +171,102 @@ func (ctx *parseContext) mapFile(file *ast.File, sourcePath string) *tree.Compil } } -// mapImports maps the import declarations in the file. +// mapImports maps all import declarations in the file into a single Container. +// Go allows multiple import blocks; subsequent blocks are tracked via ImportBlock markers. func (ctx *parseContext) mapImports(file *ast.File) *tree.Container[*tree.Import] { - var importDecl *ast.GenDecl + // Collect all import GenDecls in order. + var importDecls []*ast.GenDecl for _, decl := range file.Decls { if gd, ok := decl.(*ast.GenDecl); ok && gd.Tok == token.IMPORT { - importDecl = gd - break + importDecls = append(importDecls, gd) } } - if importDecl == nil { + if len(importDecls) == 0 { return nil } - before := ctx.prefixAndSkip(importDecl.Pos(), len("import")) - var elements []tree.RightPadded[*tree.Import] - var containerMarkers tree.Markers - if importDecl.Lparen.IsValid() { - openParenPrefix := ctx.prefix(importDecl.Lparen) - ctx.skip(1) // skip "(" + prevGrouped := false + // First import block: captured into Container.Before and Container.Markers + first := importDecls[0] + before := ctx.prefixAndSkip(first.Pos(), len("import")) + + if first.Lparen.IsValid() { + prevGrouped = true + openParenPrefix := ctx.prefix(first.Lparen) + ctx.skip(1) // skip "(" containerMarkers = tree.Markers{ ID: uuid.New(), Entries: []tree.Marker{ tree.GroupedImport{Ident: uuid.New(), Before: openParenPrefix}, }, } + } - for _, spec := range importDecl.Specs { - is := spec.(*ast.ImportSpec) - imp := ctx.mapImportSpec(is) - elements = append(elements, tree.RightPadded[*tree.Import]{Element: imp}) - } + for _, spec := range first.Specs { + is := spec.(*ast.ImportSpec) + imp := ctx.mapImportSpec(is) + elements = append(elements, tree.RightPadded[*tree.Import]{Element: imp}) + } - closeParen := ctx.prefix(importDecl.Rparen) + if first.Lparen.IsValid() { + closeParen := ctx.prefix(first.Rparen) ctx.skip(1) // skip ")" - if len(elements) > 0 { elements[len(elements)-1].After = closeParen } - } else { - for _, spec := range importDecl.Specs { - is := spec.(*ast.ImportSpec) - imp := ctx.mapImportSpec(is) - elements = append(elements, tree.RightPadded[*tree.Import]{Element: imp}) + } + + // Subsequent import blocks: attach ImportBlock marker to first import of each + for _, importDecl := range importDecls[1:] { + blockBefore := ctx.prefixAndSkip(importDecl.Pos(), len("import")) + grouped := importDecl.Lparen.IsValid() + var groupedBefore tree.Space + if grouped { + groupedBefore = ctx.prefix(importDecl.Lparen) + ctx.skip(1) // skip "(" } + + ctx.mapImportBlockSpecs(importDecl, &elements, tree.ImportBlock{ + Ident: uuid.New(), + ClosePrevious: prevGrouped, + Before: blockBefore, + Grouped: grouped, + GroupedBefore: groupedBefore, + }) + + if grouped { + closeParen := ctx.prefix(importDecl.Rparen) + ctx.skip(1) // skip ")" + if len(elements) > 0 { + elements[len(elements)-1].After = closeParen + } + } + prevGrouped = grouped } container := tree.Container[*tree.Import]{Before: before, Elements: elements, Markers: containerMarkers} return &container } +// mapImportBlockSpecs maps the specs of a subsequent import block, attaching +// the ImportBlock marker to the first spec's Import node. +func (ctx *parseContext) mapImportBlockSpecs(decl *ast.GenDecl, elements *[]tree.RightPadded[*tree.Import], marker tree.ImportBlock) { + for j, spec := range decl.Specs { + is := spec.(*ast.ImportSpec) + imp := ctx.mapImportSpec(is) + if j == 0 { + imp.Markers = tree.Markers{ + ID: uuid.New(), + Entries: []tree.Marker{marker}, + } + } + *elements = append(*elements, tree.RightPadded[*tree.Import]{Element: imp}) + } +} + // mapImportSpec maps a single import spec. func (ctx *parseContext) mapImportSpec(spec *ast.ImportSpec) *tree.Import { prefix := ctx.prefix(spec.Pos()) @@ -1761,8 +1804,8 @@ func (ctx *parseContext) mapArrayType(expr *ast.ArrayType) tree.Expression { length = ctx.mapExpr(expr.Len) } - closePrefix := ctx.prefix(expr.Lbrack + token.Pos(ctx.findNextFrom('[', ctx.file.Offset(expr.Lbrack)) - ctx.file.Offset(expr.Lbrack))) // Find the `]` + var closePrefix tree.Space rbrackOff := ctx.findNext(']') if rbrackOff >= 0 { closePrefix = ctx.prefix(ctx.file.Pos(rbrackOff)) diff --git a/rewrite-go/rewrite/pkg/printer/go_printer.go b/rewrite-go/rewrite/pkg/printer/go_printer.go index 761bc6567f8..2d02e85c20f 100644 --- a/rewrite-go/rewrite/pkg/printer/go_printer.go +++ b/rewrite-go/rewrite/pkg/printer/go_printer.go @@ -75,15 +75,29 @@ func (p *GoPrinter) VisitCompilationUnit(cu *tree.CompilationUnit, param any) tr out.Append("import") grouped := tree.FindMarker[tree.GroupedImport](cu.Imports.Markers) - if grouped != nil { + isGrouped := grouped != nil + if isGrouped { p.visitSpace(grouped.Before, out) out.Append("(") } for _, rp := range cu.Imports.Elements { + block := tree.FindMarker[tree.ImportBlock](rp.Element.Markers) + if block != nil { + if block.ClosePrevious { + out.Append(")") + } + p.visitSpace(block.Before, out) + out.Append("import") + if block.Grouped { + p.visitSpace(block.GroupedBefore, out) + out.Append("(") + } + isGrouped = block.Grouped + } p.Visit(rp.Element, out) p.visitSpace(rp.After, out) } - if grouped != nil { + if isGrouped { out.Append(")") } } diff --git a/rewrite-go/rewrite/pkg/rpc/space_rpc.go b/rewrite-go/rewrite/pkg/rpc/space_rpc.go index eb904767b54..61e06b326e3 100644 --- a/rewrite-go/rewrite/pkg/rpc/space_rpc.go +++ b/rewrite-go/rewrite/pkg/rpc/space_rpc.go @@ -119,6 +119,13 @@ func sendMarkerCodecFields(v any, q *SendQueue) { // GroupedImport.rpcSend sends: id (UUID string), before whitespace (string) q.GetAndSend(m, func(x any) any { return x.(tree.GroupedImport).Ident.String() }, nil) q.GetAndSend(m, func(x any) any { return x.(tree.GroupedImport).Before.Whitespace }, nil) + case tree.ImportBlock: + // ImportBlock.rpcSend sends: id, closePrevious, before, grouped, groupedBefore + q.GetAndSend(m, func(x any) any { return x.(tree.ImportBlock).Ident.String() }, nil) + q.GetAndSend(m, func(x any) any { return x.(tree.ImportBlock).ClosePrevious }, nil) + q.GetAndSend(m, func(x any) any { return x.(tree.ImportBlock).Before.Whitespace }, nil) + q.GetAndSend(m, func(x any) any { return x.(tree.ImportBlock).Grouped }, nil) + q.GetAndSend(m, func(x any) any { return x.(tree.ImportBlock).GroupedBefore.Whitespace }, nil) case tree.ShortVarDecl: q.GetAndSend(m, func(x any) any { return x.(tree.ShortVarDecl).Ident.String() }, nil) case tree.VarKeyword: @@ -193,6 +200,21 @@ func receiveMarkersCodec(q *ReceiveQueue, before tree.Markers) tree.Markers { ws := receiveScalar[string](q, m.Before.Whitespace) m.Before = tree.Space{Whitespace: ws} return m + case tree.ImportBlock: + // ImportBlock.rpcReceive: id, closePrevious, before, grouped, groupedBefore + idStr := receiveScalar[string](q, m.Ident.String()) + if idStr != "" { + if parsed, err := uuid.Parse(idStr); err == nil { + m.Ident = parsed + } + } + m.ClosePrevious = receiveScalar[bool](q, m.ClosePrevious) + ws := receiveScalar[string](q, m.Before.Whitespace) + m.Before = tree.Space{Whitespace: ws} + m.Grouped = receiveScalar[bool](q, m.Grouped) + gbWs := receiveScalar[string](q, m.GroupedBefore.Whitespace) + m.GroupedBefore = tree.Space{Whitespace: gbWs} + return m case tree.ShortVarDecl: idStr := receiveScalar[string](q, m.Ident.String()) if idStr != "" { diff --git a/rewrite-go/rewrite/pkg/rpc/value_types.go b/rewrite-go/rewrite/pkg/rpc/value_types.go index 44cab46ab42..7d1046c6ced 100644 --- a/rewrite-go/rewrite/pkg/rpc/value_types.go +++ b/rewrite-go/rewrite/pkg/rpc/value_types.go @@ -86,6 +86,7 @@ func init() { // Go-specific marker valueType registrations (for send-side type resolution) RegisterValueType(reflect.TypeOf(tree.GroupedImport{}), "org.openrewrite.golang.marker.GroupedImport") + RegisterValueType(reflect.TypeOf(tree.ImportBlock{}), "org.openrewrite.golang.marker.ImportBlock") RegisterValueType(reflect.TypeOf(tree.ShortVarDecl{}), "org.openrewrite.golang.marker.ShortVarDecl") RegisterValueType(reflect.TypeOf(tree.VarKeyword{}), "org.openrewrite.golang.marker.VarKeyword") RegisterValueType(reflect.TypeOf(tree.ConstDecl{}), "org.openrewrite.golang.marker.ConstDecl") @@ -175,6 +176,8 @@ func init() { RegisterFactory("org.openrewrite.marker.SearchResult", func() any { return tree.SearchResult{} }) // GroupedImport: IS an RpcCodec, sends 2 sub-fields (id, before whitespace) RegisterFactory("org.openrewrite.golang.marker.GroupedImport", func() any { return tree.GroupedImport{} }) + // ImportBlock: IS an RpcCodec, sends 5 sub-fields (id, closePrevious, before, grouped, groupedBefore) + RegisterFactory("org.openrewrite.golang.marker.ImportBlock", func() any { return tree.ImportBlock{} }) // Go-specific markers: all are RpcCodec RegisterFactory("org.openrewrite.golang.marker.ShortVarDecl", func() any { return tree.ShortVarDecl{} }) RegisterFactory("org.openrewrite.golang.marker.VarKeyword", func() any { return tree.VarKeyword{} }) diff --git a/rewrite-go/rewrite/pkg/tree/go.go b/rewrite-go/rewrite/pkg/tree/go.go index 2d8eec2e3df..97f7eb1d1aa 100644 --- a/rewrite-go/rewrite/pkg/tree/go.go +++ b/rewrite-go/rewrite/pkg/tree/go.go @@ -52,6 +52,18 @@ func (n *CompilationUnit) WithStatements(statements []RightPadded[Statement]) *C return &c } +func (n *CompilationUnit) WithPackageDecl(pkg *RightPadded[*Identifier]) *CompilationUnit { + c := *n + c.PackageDecl = pkg + return &c +} + +func (n *CompilationUnit) WithImports(imports *Container[*Import]) *CompilationUnit { + c := *n + c.Imports = imports + return &c +} + func (n *CompilationUnit) WithEOF(eof Space) *CompilationUnit { c := *n c.EOF = eof @@ -462,6 +474,19 @@ type GroupedImport struct { func (g GroupedImport) ID() uuid.UUID { return g.Ident } +// ImportBlock is a marker on the first Import of a subsequent import block +// (2nd, 3rd, etc.) in files with multiple import declarations. It carries +// the information needed to print the block boundary. +type ImportBlock struct { + Ident uuid.UUID + ClosePrevious bool // true if the previous block was grouped (need to print ")") + Before Space // space before the "import" keyword + Grouped bool // true if this block uses import (...) + GroupedBefore Space // space between "import" and "(" (only if Grouped) +} + +func (b ImportBlock) ID() uuid.UUID { return b.Ident } + // MultiAssignment represents a multi-value assignment: `x, y = 1, 2` or `x, y := f()`. type MultiAssignment struct { ID uuid.UUID diff --git a/rewrite-go/rewrite/pkg/visitor/go_visitor.go b/rewrite-go/rewrite/pkg/visitor/go_visitor.go index 7e1a9be2654..b519c874a40 100644 --- a/rewrite-go/rewrite/pkg/visitor/go_visitor.go +++ b/rewrite-go/rewrite/pkg/visitor/go_visitor.go @@ -218,6 +218,19 @@ var _ VisitorI = (*GoVisitor)(nil) func (v *GoVisitor) VisitCompilationUnit(cu *tree.CompilationUnit, p any) tree.J { cu = cu.WithPrefix(v.self().VisitSpace(cu.Prefix, p)) cu = cu.WithMarkers(v.visitMarkers(cu.Markers, p)) + if cu.PackageDecl != nil { + pkg := *cu.PackageDecl + pkg.Element = visitAndCast[*tree.Identifier](v, pkg.Element, p) + pkg.After = v.self().VisitSpace(pkg.After, p) + cu = cu.WithPackageDecl(&pkg) + } + if cu.Imports != nil { + imports := *cu.Imports + imports.Before = v.self().VisitSpace(imports.Before, p) + imports.Markers = v.visitMarkers(imports.Markers, p) + imports.Elements = visitRightPaddedList(v, imports.Elements, p) + cu = cu.WithImports(&imports) + } cu = cu.WithStatements(visitRightPaddedList(v, cu.Statements, p)) cu = cu.WithEOF(v.self().VisitSpace(cu.EOF, p)) return cu @@ -552,11 +565,18 @@ func (v *GoVisitor) visitMarkers(markers tree.Markers, p any) tree.Markers { func visitAndCast[T tree.Tree](v *GoVisitor, t tree.Tree, p any) T { result := v.self().Visit(t, p) + if result == nil { + var zero T + return zero + } return result.(T) } func visitExpression(v *GoVisitor, expr tree.Expression, p any) tree.Expression { result := v.self().Visit(expr, p) + if result == nil { + return nil + } return result.(tree.Expression) } diff --git a/rewrite-go/rewrite/test/import_test.go b/rewrite-go/rewrite/test/import_test.go index 58a4c0d1fa1..844b6977ab9 100644 --- a/rewrite-go/rewrite/test/import_test.go +++ b/rewrite-go/rewrite/test/import_test.go @@ -48,3 +48,35 @@ func TestParseGroupedImports(t *testing.T) { } `)) } + +func TestParseMultipleImportBlocks(t *testing.T) { + NewRecipeSpec().RewriteRun(t, + Golang(` + package main + + import "fmt" + import "os" + + func hello() { + } + `)) +} + +func TestParseMultipleGroupedImportBlocks(t *testing.T) { + NewRecipeSpec().RewriteRun(t, + Golang(` + package main + + import ( + "fmt" + ) + + import ( + "os" + "strings" + ) + + func hello() { + } + `)) +} diff --git a/rewrite-go/rewrite/test/visitor_test.go b/rewrite-go/rewrite/test/visitor_test.go new file mode 100644 index 00000000000..8feb11b10a5 --- /dev/null +++ b/rewrite-go/rewrite/test/visitor_test.go @@ -0,0 +1,112 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://docs.moderne.io/licensing/moderne-source-available-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "testing" + + "github.com/openrewrite/rewrite/pkg/parser" + "github.com/openrewrite/rewrite/pkg/tree" + "github.com/openrewrite/rewrite/pkg/visitor" +) + +// deletingVisitor returns nil for Return nodes, exercising the nil guard in visitAndCast/visitExpression. +type deletingVisitor struct { + visitor.GoVisitor +} + +func (v *deletingVisitor) VisitReturn(ret *tree.Return, p any) tree.J { + return nil // delete the return statement +} + +func TestVisitorReturningNilDoesNotPanic(t *testing.T) { + src := "package main\n\nfunc foo() int {\n\treturn 1\n}\n" + p := parser.NewGoParser() + cu, err := p.Parse("test.go", src) + if err != nil { + t.Fatalf("parse error: %v", err) + } + + v := visitor.Init(&deletingVisitor{}) + // This should not panic even though VisitReturn returns nil. + result := v.Visit(cu, nil) + if result == nil { + t.Fatal("visitor returned nil for compilation unit") + } +} + +// importCountingVisitor counts how many Import nodes are visited. +type importCountingVisitor struct { + visitor.GoVisitor + count int +} + +func (v *importCountingVisitor) VisitImport(imp *tree.Import, p any) tree.J { + v.count++ + return imp +} + +func TestVisitorVisitsImports(t *testing.T) { + src := "package main\n\nimport (\n\t\"fmt\"\n\t\"os\"\n)\n\nfunc main() {\n}\n" + p := parser.NewGoParser() + cu, err := p.Parse("test.go", src) + if err != nil { + t.Fatalf("parse error: %v", err) + } + + v := visitor.Init(&importCountingVisitor{}) + v.Visit(cu, nil) + if v.count != 2 { + t.Errorf("expected 2 imports visited, got %d", v.count) + } +} + +// identCountingVisitor counts how many Identifier nodes are visited. +type identCountingVisitor struct { + visitor.GoVisitor + names []string +} + +func (v *identCountingVisitor) VisitIdentifier(ident *tree.Identifier, p any) tree.J { + v.names = append(v.names, ident.Name) + return ident +} + +func TestVisitorVisitsPackageDecl(t *testing.T) { + // "main" appears as the package name and as the function name. + // Without visiting PackageDecl, only the function name would be found. + src := "package pkg\n\nfunc foo() {\n}\n" + p := parser.NewGoParser() + cu, err := p.Parse("test.go", src) + if err != nil { + t.Fatalf("parse error: %v", err) + } + + v := visitor.Init(&identCountingVisitor{}) + v.Visit(cu, nil) + + found := false + for _, name := range v.names { + if name == "pkg" { + found = true + break + } + } + if !found { + t.Errorf("visitor did not visit package decl identifier 'pkg'; visited: %v", v.names) + } +} diff --git a/rewrite-go/src/main/java/org/openrewrite/golang/marker/ImportBlock.java b/rewrite-go/src/main/java/org/openrewrite/golang/marker/ImportBlock.java new file mode 100644 index 00000000000..e3fc1ef0776 --- /dev/null +++ b/rewrite-go/src/main/java/org/openrewrite/golang/marker/ImportBlock.java @@ -0,0 +1,55 @@ +/* + * Copyright 2025 the original author or authors. + *

+ * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://docs.moderne.io/licensing/moderne-source-available-license + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.golang.marker; + +import lombok.Value; +import lombok.With; +import org.openrewrite.java.tree.Space; +import org.openrewrite.marker.Marker; +import org.openrewrite.rpc.RpcCodec; +import org.openrewrite.rpc.RpcReceiveQueue; +import org.openrewrite.rpc.RpcSendQueue; + +import java.util.UUID; + +@Value +@With +public class ImportBlock implements Marker, RpcCodec { + UUID id; + boolean closePrevious; + Space before; + boolean grouped; + Space groupedBefore; + + @Override + public void rpcSend(ImportBlock after, RpcSendQueue q) { + q.getAndSend(after, Marker::getId); + q.getAndSend(after, ImportBlock::isClosePrevious); + q.getAndSend(after, b -> b.getBefore().getWhitespace()); + q.getAndSend(after, ImportBlock::isGrouped); + q.getAndSend(after, b -> b.getGroupedBefore().getWhitespace()); + } + + @Override + public ImportBlock rpcReceive(ImportBlock before, RpcReceiveQueue q) { + return before + .withId(q.receiveAndGet(before.getId(), UUID::fromString)) + .withClosePrevious(q.receiveAndGet(before.isClosePrevious(), Boolean::parseBoolean)) + .withBefore(Space.format(q.receive(before.getBefore() == null ? "" : before.getBefore().getWhitespace()))) + .withGrouped(q.receiveAndGet(before.isGrouped(), Boolean::parseBoolean)) + .withGroupedBefore(Space.format(q.receive(before.getGroupedBefore() == null ? "" : before.getGroupedBefore().getWhitespace()))); + } +}