diff --git a/cmd/errtrace/slices_contains_go121.go b/cmd/errtrace/slices_contains_go121.go new file mode 100644 index 0000000..c6188ca --- /dev/null +++ b/cmd/errtrace/slices_contains_go121.go @@ -0,0 +1,9 @@ +//go:build go1.21 + +package main + +import "slices" + +func slicesContains[T comparable](s []T, find T) bool { + return slices.Contains[[]T](s, find) +} diff --git a/cmd/errtrace/slices_contains_pre_go121.go b/cmd/errtrace/slices_contains_pre_go121.go new file mode 100644 index 0000000..ec19412 --- /dev/null +++ b/cmd/errtrace/slices_contains_pre_go121.go @@ -0,0 +1,12 @@ +//go:build !go1.21 + +package main + +func slicesContains[T comparable](s []T, find T) bool { + for _, v := range s { + if v == find { + return true + } + } + return false +} diff --git a/cmd/errtrace/testdata/toolexec-test/main_test.go b/cmd/errtrace/testdata/toolexec-test/main_test.go new file mode 100644 index 0000000..22aed24 --- /dev/null +++ b/cmd/errtrace/testdata/toolexec-test/main_test.go @@ -0,0 +1,7 @@ +package main + +import "testing" + +func TestFoo(t *testing.T) { + t.Errorf("fail") +} diff --git a/cmd/errtrace/toolexec.go b/cmd/errtrace/toolexec.go index ed77d14..d21cb42 100644 --- a/cmd/errtrace/toolexec.go +++ b/cmd/errtrace/toolexec.go @@ -71,9 +71,7 @@ func (cmd *mainCmd) toolExecRewrite(pkg string, args []string) (exitCode int) { } // We only modify files that import errtrace, so stdlib is never eliglble. - // To avoid unnecessary parsing, use a heuristic to detect stdlib packages -- - // whether the name contains ".". - if !strings.Contains(pkg, ".") { + if isStdLib(args) { return cmd.runOriginal(args) } @@ -224,3 +222,8 @@ func readBuildSHA() (_ string, ok bool) { } return sha, sha != "" } + +// isStdLib checks if the current execution is for stdlib. +func isStdLib(args []string) bool { + return slicesContains(args, "-std") +} diff --git a/cmd/errtrace/toolexec_test.go b/cmd/errtrace/toolexec_test.go index 5e590d9..7985a16 100644 --- a/cmd/errtrace/toolexec_test.go +++ b/cmd/errtrace/toolexec_test.go @@ -56,23 +56,87 @@ func TestToolExec(t *testing.T) { } sort.Strings(wantTraces) - t.Run("no toolexec", func(t *testing.T) { - stdout, _ := runGo(t, testProg, "run", ".") - if lines := fileLines(stdout); len(lines) > 0 { - t.Errorf("expected no file:line, got %v", lines) - } - }) + tests := []struct { + name string + goArgs func(t testing.TB) []string + wantTraces []string + }{ + { + name: "no toolexec", + goArgs: func(t testing.TB) []string { + return []string{"."} + }, + wantTraces: nil, + }, + { + name: "toolexec with pkg", + goArgs: func(t testing.TB) []string { + return []string{"-toolexec", errTraceCmd, "."} + }, + wantTraces: wantTraces, + }, + { + name: "toolexec with files", + goArgs: func(t testing.TB) []string { + files, err := goListFiles([]string{testProg}) + if err != nil { + t.Fatalf("list go files in %v: %v", testProg, err) + } + + // TODO: Once go.1.20 is dropped, we can use slices.DeleteFunc + nonTest := files[:0] + for _, file := range files { + if !strings.HasSuffix(file, "_test.go") { + nonTest = append(nonTest, file) + } + } + + args := []string{"-toolexec", errTraceCmd} + args = append(args, nonTest...) + return args + }, + wantTraces: wantTraces, + }, + } - t.Run("with toolexec", func(t *testing.T) { - stdout, _ := runGo(t, testProg, "run", "-toolexec", errTraceCmd, ".") - gotLines := fileLines(stdout) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + args := tt.goArgs(t) - sort.Strings(gotLines) - if d := diff.Diff(wantTraces, gotLines); d != "" { - t.Errorf("diff in traces:\n%s", d) - t.Errorf("go run output:\n%s", stdout) - } - }) + verify := func(t testing.TB, stdout string) { + gotLines := fileLines(stdout) + sort.Strings(gotLines) + + if d := diff.Diff(tt.wantTraces, gotLines); d != "" { + t.Errorf("diff in traces:\n%s", d) + t.Errorf("go run output:\n%s", stdout) + } + } + + t.Run("go run", func(t *testing.T) { + runArgs := append([]string{"run"}, args...) + stdout, _ := runGo(t, testProg, runArgs...) + verify(t, stdout) + }) + + t.Run("go build", func(t *testing.T) { + outExe := filepath.Join(t.TempDir(), "main") + if runtime.GOOS == "windows" { + outExe += ".exe" + } + + runArgs := append([]string{"build", "-o", outExe}, args...) + runGo(t, testProg, runArgs...) + + cmd := exec.Command(outExe) + output, err := cmd.Output() + if err != nil { + t.Fatalf("run built binary: %v", err) + } + verify(t, string(output)) + }) + }) + } } func findTraceLines(t testing.TB, file string) []int {