diff --git a/.github/actions/setup-go/action.yaml b/.github/actions/setup-go/action.yaml new file mode 100644 index 0000000..8c1d5f6 --- /dev/null +++ b/.github/actions/setup-go/action.yaml @@ -0,0 +1,24 @@ +name: "Setup Go" +description: | + Sets up the Go environment for tests, builds, etc. +inputs: + version: + description: "The Go version to use." + default: "1.26.2" + use-cache: + description: "Whether to use the cache." + default: "true" +runs: + using: "composite" + steps: + - name: Setup Go + uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5.5.0 + with: + go-version: ${{ inputs.version }} + cache: ${{ inputs.use-cache }} + + # It isn't necessary that we ever do this, but it helps separate the "setup" + # from the "run" times. + - name: go mod download + shell: bash + run: go mod download -x diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000..1fde89c --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,68 @@ +name: quality + +on: + push: + branches: + - main + pull_request: + workflow_dispatch: + +permissions: + contents: read + +# Cancel in-progress runs for pull requests when developers push additional +# changes. +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.event_name == 'pull_request' }} + +jobs: + fmt: + name: fmt + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false + + - name: make fmt + run: make fmt + + - name: Check unstaged + run: | + if [[ -n $(git ls-files --other --modified --exclude-standard) ]]; then + echo "Unexpected difference in directories after formatting. Run 'make fmt' and include the output in the commit." + exit 1 + fi + + lint: + name: lint + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false + + - name: Setup Go + uses: ./.github/actions/setup-go + + - name: make lint + run: make lint + + test: + name: test + runs-on: ubuntu-latest + timeout-minutes: 15 + steps: + - name: Checkout + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false + + - name: Setup Go + uses: ./.github/actions/setup-go + + - name: make test + run: make test diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ac29b5a --- /dev/null +++ b/.gitignore @@ -0,0 +1,34 @@ +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Code coverage profiles and other test artifacts +*.out +coverage.* +*.coverprofile +profile.cov + +# Go workspace file +go.work +go.work.sum + +# env file +.env + +# Editor/IDE +.idea/ +.vscode/ + +# Key files +*.key +*.pub +*.pem + +# Output directory +build/ diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..e29a148 --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,28 @@ +version: "2" + +linters: + exclusions: + rules: + - path: _test\.go + linters: + - gosec + text: "G304: Potential file inclusion via variable" + enable: + - goconst + - gocritic + - gosec + - misspell + - nakedret + - revive + - unconvert + - unparam + settings: + govet: + enable: + - shadow + misspell: + locale: US + revive: + rules: + - name: package-comments + disabled: true diff --git a/.prettierrc b/.prettierrc new file mode 100644 index 0000000..edb389b --- /dev/null +++ b/.prettierrc @@ -0,0 +1,14 @@ +{ + "printWidth": 120, + "semi": false, + "trailingComma": "all", + "overrides": [ + { + "files": ["./*.md", "./**/*.md"], + "options": { + "printWidth": 80, + "proseWrap": "always" + } + } + ] +} diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..690c38c --- /dev/null +++ b/Makefile @@ -0,0 +1,27 @@ +FIND_EXCLUSIONS= \ + -not \( \( -path '*/.git/*' -o -path './build/*' -o -path './vendor/*' -o -path '*/.terraform/*' \) -prune \) +GO_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.go' -not -name '*_test.go') +GO_FMT_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.go' -print0 | xargs -0 grep -E --null -L '^// Code generated .* DO NOT EDIT\.$$' | tr '\0' ' ') + +default: build + +build/whichtests: $(GO_SRC_FILES) go.mod go.sum + mkdir -p ./build + go build -o ./build/whichtests . + +build: build/whichtests +.PHONY: build + +fmt: + go mod tidy + go run golang.org/x/tools/cmd/goimports@v0.35.0 -w $(GO_FMT_FILES) + go run mvdan.cc/gofumpt@v0.8.0 -w -l $(GO_FMT_FILES) +.PHONY: fmt + +lint: + go run github.com/golangci/golangci-lint/v2/cmd/golangci-lint@v2.4.0 run ./... +.PHONY: lint + +test: + go test -test.v -timeout 30s -cover ./... +.PHONY: test diff --git a/README.md b/README.md new file mode 100644 index 0000000..7fcda2f --- /dev/null +++ b/README.md @@ -0,0 +1,63 @@ +# whichtests + +`whichtests` is the Go test-plan generator that drives the `flake-go` CI +workflow in `coder/coder`. Given a base/head git revision pair (or a +GitHub Actions event), it walks the diff, parses each changed test +file, picks the smallest set of tests to rerun, and emits a workflow +matrix plus a human-readable Markdown summary. + +## Building and running + +```sh +go build ./ +./whichtests --help +``` + +Typical invocation against the local working tree: + +```sh +./whichtests \ + --repo-root . \ + --base-sha origin/main \ + --head-sha HEAD \ + --out-matrix ./flake-matrix.json \ + --out-summary - +``` + +In GitHub Actions: + +```sh +go run ./ \ + --repo-root . \ + --github-actions \ + --out-matrix "$RUNNER_TEMP/flake-matrix.json" +``` + +For `pull_request` events, checkout must use the PR head SHA, for example `github.event.pull_request.head.sha`. The default synthetic merge ref is rejected because the checked-out `HEAD` must match `pull_request.head.sha`. + +The matrix JSON contains `include` rows with `package`, `run_regex`, and `test_count`. `package` is normally one safe Go package pattern. If the matrix cap is hit, the final overflow row stores a space-separated list of safe package tokens in `package`, leaves `run_regex` empty, and sets `test_count` to `1`; this is the contract consumed by the current `flake-go` workflow. + +## File layout + +The binary is a single `package main`, split into focused files: + +| File | Responsibility | +| --------------- | ------------------------------------------------------------------- | +| `cli.go` | `main`, flag parsing, command orchestration (`runCommand`). | +| `config.go` | `config` / `commandConfig` types and defaults. | +| `request.go` | `runRequest`, `diffRange`, revision validation. | +| `gitexec.go` | `gitRunner` / `gitFetcher` types and the real `exec.Command` impl. | +| `diff.go` | Reading and parsing `git diff`, change kinds, hunks, line ranges. | +| `snapshot.go` | AST snapshot parsing, `fileSnapshot`, and `sharedDecl`. | +| `broadening.go` | Per-kind broadening rules (`broadeningScope`). | +| `selection.go` | Per-change selection logic (`selectChange`, broaden vs narrow). | +| `inventory.go` | `inventoryCache` for package/directory test discovery. | +| `plan.go` | Plan construction, matrix and summary rendering (`buildExecutionPlan`, `selectTestPlan`). | +| `githubactions.go` | GitHub Actions request builder and history preparation. | +| `publish.go` | Single sink for matrix and summary outputs. | + +## Testing + +```sh +go test ./... +``` diff --git a/broadening.go b/broadening.go new file mode 100644 index 0000000..969bf7b --- /dev/null +++ b/broadening.go @@ -0,0 +1,60 @@ +package main + +type broadeningScope uint8 + +const ( + broadeningNone broadeningScope = iota + broadeningPackage + broadeningDirectory +) + +func broadeningScopeForOldHunk(decls []sharedDecl, candidate lineRange) broadeningScope { + scope := broadeningNone + for _, decl := range decls { + if !decl.Range.overlaps(candidate) { + continue + } + scope = max(scope, decl.broadeningScopeOnOldSide()) + } + return scope +} + +func broadeningScopeForNewHunk(decls []sharedDecl, oldSnapshot *fileSnapshot, candidate lineRange) broadeningScope { + scope := broadeningNone + for _, decl := range decls { + if !decl.Range.overlaps(candidate) { + continue + } + scope = max(scope, decl.broadeningScopeOnNewSide(oldSnapshot)) + } + return scope +} + +func (decl sharedDecl) broadeningScopeOnOldSide() broadeningScope { + switch decl.Kind { + case sharedDeclInit, sharedDeclTestMain: + // Go builds package and package_test files into one test binary. + // Init and TestMain changes can affect every test in the directory. + return broadeningDirectory + case sharedDeclImport, sharedDeclVar, sharedDeclConst, sharedDeclType, sharedDeclHelper: + return broadeningPackage + } + return broadeningNone +} + +func (decl sharedDecl) broadeningScopeOnNewSide(oldSnapshot *fileSnapshot) broadeningScope { + switch decl.Kind { + // TODO: Decide whether new imports should narrow to tests that still + // reference package-local declarations. Today any import edit broadens + // the package. + case sharedDeclImport: + return broadeningPackage + case sharedDeclInit, sharedDeclTestMain: + return broadeningDirectory + case sharedDeclVar, sharedDeclConst, sharedDeclType, sharedDeclHelper: + if oldSnapshot != nil && oldSnapshot.hasAnySharedKey(decl.Keys) { + return broadeningPackage + } + } + return broadeningNone +} diff --git a/broadening_test.go b/broadening_test.go new file mode 100644 index 0000000..bd67685 --- /dev/null +++ b/broadening_test.go @@ -0,0 +1,51 @@ +package main + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBroadeningScopeForOldHunkChoosesMaxOverlappingScope(t *testing.T) { + t.Parallel() + + data := []byte(`package sample + +import "testing" + +func init() { + register() +} + +func TestAlpha(t *testing.T) {} +`) + snapshot, err := parseFileSnapshot(data) + require.NoError(t, err) + candidate := rangeSpan( + singleLineRange(t, string(data), `import "testing"`), + singleLineRange(t, string(data), "register()"), + ) + require.Equal(t, broadeningDirectory, broadeningScopeForOldHunk(snapshot.shared, candidate)) +} + +func TestBroadeningScopeForNewHunkChoosesMaxOverlappingScope(t *testing.T) { + t.Parallel() + + data := []byte(`package sample + +import "testing" + +func TestMain(m *testing.M) { + m.Run() +} + +func TestAlpha(t *testing.T) {} +`) + snapshot, err := parseFileSnapshot(data) + require.NoError(t, err) + candidate := rangeSpan( + singleLineRange(t, string(data), `import "testing"`), + singleLineRange(t, string(data), "m.Run()"), + ) + require.Equal(t, broadeningDirectory, broadeningScopeForNewHunk(snapshot.shared, nil, candidate)) +} diff --git a/cli.go b/cli.go new file mode 100644 index 0000000..0c24ac3 --- /dev/null +++ b/cli.go @@ -0,0 +1,95 @@ +// Command whichtests produces deterministic Go test plans for the +// flake-go workflow. +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "io" + "os" +) + +func main() { + cfg := defaultCommandConfig() + flags := flag.NewFlagSet(os.Args[0], flag.ExitOnError) + flags.StringVar(&cfg.RepoRoot, "repo-root", cfg.RepoRoot, "repository root") + flags.StringVar(&cfg.BaseSHA, "base-sha", cfg.BaseSHA, "base revision to diff against") + flags.StringVar(&cfg.HeadSHA, "head-sha", cfg.HeadSHA, "head revision to diff against") + flags.StringVar(&cfg.OutMatrix, "out-matrix", cfg.OutMatrix, "path to write workflow matrix JSON") + flags.StringVar(&cfg.OutSummary, "out-summary", cfg.OutSummary, "path to write Markdown summary, or - for stdout") + flags.BoolVar(&cfg.GitHubActions, "github-actions", cfg.GitHubActions, "read diff range and output paths from GitHub Actions environment") + if err := flags.Parse(os.Args[1:]); err != nil { + _, _ = fmt.Fprintln(os.Stderr, err) + os.Exit(2) + } + if err := runCommand(context.Background(), cfg, os.Stdout, os.Stderr, execGit, execGitFetch); err != nil { + _, _ = fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} + +func runCommand(ctx context.Context, cfg commandConfig, stdout, stderr io.Writer, git gitRunner, fetch gitFetcher) error { + var ( + req runRequest + err error + ) + if cfg.GitHubActions { + req, err = githubActionsRunRequest(ctx, cfg, git) + } else { + req, err = explicitRunRequest(cfg.config) + } + if err != nil { + return err + } + return executeRunRequest(ctx, req, stdout, stderr, git, fetch) +} + +func explicitRunRequest(cfg config) (runRequest, error) { + cfg = cfg.withDefaults() + if cfg.BaseSHA == "" { + return runRequest{}, errors.New("--base-sha is required") + } + if cfg.OutMatrix == "" { + return runRequest{}, errors.New("--out-matrix is required") + } + if err := validateRevisionArg("--base-sha", cfg.BaseSHA); err != nil { + return runRequest{}, err + } + if err := validateRevisionArg("--head-sha", cfg.HeadSHA); err != nil { + return runRequest{}, err + } + return runRequest{ + RepoRoot: cfg.RepoRoot, + Range: diffRange{ + BaseSHA: cfg.BaseSHA, + HeadSHA: cfg.HeadSHA, + }, + Sinks: outputSinks{ + OutMatrix: cfg.OutMatrix, + OutSummary: cfg.OutSummary, + }, + }, nil +} + +func executeRunRequest(ctx context.Context, req runRequest, stdout, stderr io.Writer, git gitRunner, fetch gitFetcher) error { + if err := ensureRangeAvailable(ctx, &req, git, fetch); err != nil { + return err + } + selectorCfg := config{ + RepoRoot: req.RepoRoot, + BaseSHA: req.Range.BaseSHA, + HeadSHA: req.Range.HeadSHA, + } + changedFiles, result, err := selectTestPlan(ctx, selectorCfg, git) + if err != nil { + return err + } + summary := renderSummary(changedFiles, result.Summary) + if err := publishPlan(req.Sinks, result.Matrix, summary, stdout); err != nil { + return err + } + _, _ = fmt.Fprintf(stderr, "selected %d package targets from %d changed test files\n", len(result.Matrix.Include), len(changedFiles)) + return nil +} diff --git a/cli_test.go b/cli_test.go new file mode 100644 index 0000000..7e5b478 --- /dev/null +++ b/cli_test.go @@ -0,0 +1,744 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRunValidationErrors(t *testing.T) { + t.Parallel() + + var stdout bytes.Buffer + var stderr bytes.Buffer + neverGit := func(_ context.Context, _ string, _ ...string) (gitResult, error) { + return gitResult{}, errors.New("git should not be called") + } + + err := runCommand(t.Context(), commandConfig{config: config{OutMatrix: "matrix.json"}}, &stdout, &stderr, neverGit, nil) + require.EqualError(t, err, "--base-sha is required") + + err = runCommand(t.Context(), commandConfig{config: config{BaseSHA: "base"}}, &stdout, &stderr, neverGit, nil) + require.EqualError(t, err, "--out-matrix is required") + + err = runCommand(t.Context(), commandConfig{config: config{BaseSHA: "-bad", OutMatrix: "matrix.json"}}, &stdout, &stderr, neverGit, nil) + require.ErrorContains(t, err, "must not start with '-'") + + err = runCommand(t.Context(), commandConfig{config: config{BaseSHA: "base:bad", OutMatrix: "matrix.json"}}, &stdout, &stderr, neverGit, nil) + require.ErrorContains(t, err, "must not contain ':'") + + err = runCommand(t.Context(), commandConfig{config: config{BaseSHA: "base\x00bad", OutMatrix: "matrix.json"}}, &stdout, &stderr, neverGit, nil) + require.ErrorContains(t, err, "must not contain NUL bytes") +} + +func TestRunWritesMatrixAndSummaryWithPackageScopedEntries(t *testing.T) { + t.Parallel() + + repoRoot := t.TempDir() + baseFiles := map[string]string{ + "pkgone/shared_test.go": `package one + +import "testing" + +func TestShared(t *testing.T) { + t.Log("before one") +} +`, + "pkgtwo/shared_test.go": `package two + +import "testing" + +func TestShared(t *testing.T) { + t.Log("before two") +} +`, + } + headFiles := map[string]string{ + "pkgone/shared_test.go": `package one + +import "testing" + +func TestShared(t *testing.T) { + t.Log("changed one") +} +`, + "pkgtwo/shared_test.go": `package two + +import "testing" + +func TestShared(t *testing.T) { + t.Log("changed two") +} +`, + } + repo := fakeGitRepo{ + changes: []testFileChange{ + {Kind: changeModified, OldPath: "pkgone/shared_test.go", NewPath: "pkgone/shared_test.go"}, + {Kind: changeModified, OldPath: "pkgtwo/shared_test.go", NewPath: "pkgtwo/shared_test.go"}, + }, + revisions: map[string]map[string]string{"base": baseFiles, "head": headFiles}, + diffOutputs: map[string]string{ + "pkgone/shared_test.go": diffForChange( + singleLineRange(t, baseFiles["pkgone/shared_test.go"], `t.Log("before one")`), + singleLineRange(t, headFiles["pkgone/shared_test.go"], `t.Log("changed one")`), + ), + "pkgtwo/shared_test.go": diffForChange( + singleLineRange(t, baseFiles["pkgtwo/shared_test.go"], `t.Log("before two")`), + singleLineRange(t, headFiles["pkgtwo/shared_test.go"], `t.Log("changed two")`), + ), + }, + } + + matrixPath := filepath.Join(repoRoot, "matrix.json") + summaryPath := filepath.Join(repoRoot, "summary.md") + var stdout bytes.Buffer + var stderr bytes.Buffer + err := runCommand(t.Context(), commandConfig{config: config{RepoRoot: repoRoot, BaseSHA: "base", HeadSHA: "head", OutMatrix: matrixPath, OutSummary: summaryPath}}, &stdout, &stderr, repo.runner(t), nil) + require.NoError(t, err) + require.Empty(t, stdout.String()) + require.Contains(t, stderr.String(), "selected 2 package targets") + + var matrix matrixOutput + matrixData, err := os.ReadFile(matrixPath) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(matrixData, &matrix)) + require.Len(t, matrix.Include, 2) + require.Equal(t, "./pkgone", matrix.Include[0].Package) + require.Equal(t, "^(TestShared)(/.*)?$", matrix.Include[0].RunRegex) + require.Equal(t, "10", matrix.Include[0].TestCount) + require.Equal(t, "./pkgtwo", matrix.Include[1].Package) + + summary, err := os.ReadFile(summaryPath) + require.NoError(t, err) + require.Contains(t, string(summary), "Selected 2 tests across 2 package targets") + require.Contains(t, string(summary), "### `./pkgone`") + require.Contains(t, string(summary), "### `./pkgtwo`") +} + +func TestRunWritesSummaryToStdout(t *testing.T) { + t.Parallel() + + repoRoot := t.TempDir() + baseFiles := map[string]string{ + "pkg/sample_test.go": `package sample + +import "testing" + +func TestAlpha(t *testing.T) { + t.Log("before") +} +`, + } + headFiles := map[string]string{ + "pkg/sample_test.go": `package sample + +import "testing" + +func TestAlpha(t *testing.T) { + t.Log("after") +} +`, + } + repo := fakeGitRepo{ + changes: []testFileChange{{Kind: changeModified, OldPath: "pkg/sample_test.go", NewPath: "pkg/sample_test.go"}}, + revisions: map[string]map[string]string{"base": baseFiles, "head": headFiles}, + diffOutputs: map[string]string{ + "pkg/sample_test.go": diffForChange( + singleLineRange(t, baseFiles["pkg/sample_test.go"], `t.Log("before")`), + singleLineRange(t, headFiles["pkg/sample_test.go"], `t.Log("after")`), + ), + }, + } + + matrixPath := filepath.Join(repoRoot, "matrix.json") + var stdout bytes.Buffer + var stderr bytes.Buffer + err := runCommand(t.Context(), commandConfig{config: config{RepoRoot: repoRoot, BaseSHA: "base", HeadSHA: "head", OutMatrix: matrixPath, OutSummary: "-"}}, &stdout, &stderr, repo.runner(t), nil) + require.NoError(t, err) + require.Contains(t, stdout.String(), "## Go test flake detector selection") + require.Contains(t, stdout.String(), "### `./pkg`") + require.Contains(t, stderr.String(), "selected 1 package targets") +} + +func TestRunBroadensTestMainAcrossPackageAndPackageTest(t *testing.T) { + t.Parallel() + + repoRoot := t.TempDir() + baseFiles := map[string]string{ + "pkg/setup_test.go": `package sample + +import ( + "os" + "testing" +) + +func TestMain(m *testing.M) { + os.Exit(m.Run()) +} +`, + "pkg/internal_test.go": `package sample + +import "testing" + +func TestInternal(t *testing.T) { + t.Log("internal") +} +`, + "pkg/external_test.go": `package sample_test + +import "testing" + +func TestExternal(t *testing.T) { + t.Log("external") +} +`, + } + headFiles := map[string]string{ + "pkg/setup_test.go": `package sample + +import ( + "fmt" + "os" + "testing" +) + +func TestMain(m *testing.M) { + fmt.Println("setup") + os.Exit(m.Run()) +} +`, + "pkg/internal_test.go": baseFiles["pkg/internal_test.go"], + "pkg/external_test.go": baseFiles["pkg/external_test.go"], + } + repo := fakeGitRepo{ + changes: []testFileChange{{Kind: changeModified, OldPath: "pkg/setup_test.go", NewPath: "pkg/setup_test.go"}}, + revisions: map[string]map[string]string{"base": baseFiles, "head": headFiles}, + diffOutputs: map[string]string{ + "pkg/setup_test.go": diffForChange( + singleLineRange(t, baseFiles["pkg/setup_test.go"], `os.Exit(m.Run())`), + singleLineRange(t, headFiles["pkg/setup_test.go"], `fmt.Println("setup")`), + ), + }, + } + + matrixPath := filepath.Join(repoRoot, "matrix.json") + summaryPath := filepath.Join(repoRoot, "summary.md") + var stdout bytes.Buffer + var stderr bytes.Buffer + err := runCommand(t.Context(), commandConfig{config: config{RepoRoot: repoRoot, BaseSHA: "base", HeadSHA: "head", OutMatrix: matrixPath, OutSummary: summaryPath}}, &stdout, &stderr, repo.runner(t), nil) + require.NoError(t, err) + + var matrix matrixOutput + matrixData, err := os.ReadFile(matrixPath) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(matrixData, &matrix)) + require.Len(t, matrix.Include, 1) + require.Equal(t, "./pkg", matrix.Include[0].Package) + require.Equal(t, "^(TestExternal|TestInternal)(/.*)?$", matrix.Include[0].RunRegex) + + summary, err := os.ReadFile(summaryPath) + require.NoError(t, err) + require.Contains(t, string(summary), "TestInternal") + require.Contains(t, string(summary), "TestExternal") +} + +func TestRunHandlesRename(t *testing.T) { + t.Parallel() + + repoRoot := t.TempDir() + oldPath := "pkg/old_test.go" + newPath := "pkg/new_test.go" + baseFiles := map[string]string{ + oldPath: `package sample + +import "testing" + +func TestRenamed(t *testing.T) { + t.Log("before rename") +} +`, + } + headFiles := map[string]string{ + newPath: `package sample + +import "testing" + +func TestRenamed(t *testing.T) { + t.Log("after rename") +} +`, + } + repo := fakeGitRepo{ + changes: []testFileChange{{Kind: changeRenamed, OldPath: oldPath, NewPath: newPath}}, + revisions: map[string]map[string]string{"base": baseFiles, "head": headFiles}, + diffOutputs: map[string]string{ + oldPath + "\x00" + newPath: diffForChange( + singleLineRange(t, baseFiles[oldPath], `t.Log("before rename")`), + singleLineRange(t, headFiles[newPath], `t.Log("after rename")`), + ), + }, + } + + matrixPath := filepath.Join(repoRoot, "matrix.json") + summaryPath := filepath.Join(repoRoot, "summary.md") + var stdout bytes.Buffer + var stderr bytes.Buffer + err := runCommand(t.Context(), commandConfig{config: config{RepoRoot: repoRoot, BaseSHA: "base", HeadSHA: "head", OutMatrix: matrixPath, OutSummary: summaryPath}}, &stdout, &stderr, repo.runner(t), nil) + require.NoError(t, err) + + var matrix matrixOutput + matrixData, err := os.ReadFile(matrixPath) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(matrixData, &matrix)) + require.Len(t, matrix.Include, 1) + require.Equal(t, "./pkg", matrix.Include[0].Package) + require.Equal(t, "^(TestRenamed)(/.*)?$", matrix.Include[0].RunRegex) + + summary, err := os.ReadFile(summaryPath) + require.NoError(t, err) + require.Contains(t, string(summary), newPath) +} + +func TestRunUsesHeadRevisionInsteadOfWorkingTree(t *testing.T) { + t.Parallel() + + repoRoot := t.TempDir() + writeTestFile(t, repoRoot, "pkg/sample_test.go", `package sample + +import "testing" + +func TestWorkingTree(t *testing.T) { + t.Log("working tree") +} +`) + + baseFiles := map[string]string{ + "pkg/sample_test.go": `package sample + +import "testing" + +func TestAlpha(t *testing.T) { + t.Log("alpha") +} +`, + } + headFiles := map[string]string{ + "pkg/sample_test.go": `package sample + +import "testing" + +func TestHead(t *testing.T) { + t.Log("head") +} +`, + } + repo := fakeGitRepo{ + changes: []testFileChange{{Kind: changeModified, OldPath: "pkg/sample_test.go", NewPath: "pkg/sample_test.go"}}, + revisions: map[string]map[string]string{"base": baseFiles, "head": headFiles}, + diffOutputs: map[string]string{ + "pkg/sample_test.go": diffForChange( + singleLineRange(t, baseFiles["pkg/sample_test.go"], `func TestAlpha`), + singleLineRange(t, headFiles["pkg/sample_test.go"], `func TestHead`), + ), + }, + } + + matrixPath := filepath.Join(repoRoot, "matrix.json") + var stdout bytes.Buffer + var stderr bytes.Buffer + err := runCommand(t.Context(), commandConfig{config: config{RepoRoot: repoRoot, BaseSHA: "base", HeadSHA: "head", OutMatrix: matrixPath}}, &stdout, &stderr, repo.runner(t), nil) + require.NoError(t, err) + + var matrix matrixOutput + matrixData, err := os.ReadFile(matrixPath) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(matrixData, &matrix)) + require.Len(t, matrix.Include, 1) + require.Equal(t, "^(TestHead)(/.*)?$", matrix.Include[0].RunRegex) + require.NotContains(t, string(matrixData), "TestWorkingTree") +} + +func TestRunSkipsNonRunnableChangedTestFiles(t *testing.T) { + t.Parallel() + + repoRoot := t.TempDir() + headFiles := map[string]string{ + "pkg/testdata/example_test.go": `package sample + +import "testing" + +func TestIgnored(t *testing.T) { + t.Log("ignored") +} +`, + "pkg/_ignored_test.go": `package sample + +import "testing" + +func TestUnderscoreIgnored(t *testing.T) { + t.Log("ignored") +} +`, + "pkg/.hidden_test.go": `package sample + +import "testing" + +func TestHiddenIgnored(t *testing.T) { + t.Log("ignored") +} +`, + } + repo := fakeGitRepo{ + changes: []testFileChange{ + {Kind: changeAdded, NewPath: "pkg/testdata/example_test.go"}, + {Kind: changeAdded, NewPath: "pkg/_ignored_test.go"}, + {Kind: changeAdded, NewPath: "pkg/.hidden_test.go"}, + }, + revisions: map[string]map[string]string{"base": {}, "head": headFiles}, + diffOutputs: map[string]string{}, + } + + matrixPath := filepath.Join(repoRoot, "matrix.json") + summaryPath := filepath.Join(repoRoot, "summary.md") + var stdout bytes.Buffer + var stderr bytes.Buffer + err := runCommand(t.Context(), commandConfig{config: config{RepoRoot: repoRoot, BaseSHA: "base", HeadSHA: "head", OutMatrix: matrixPath, OutSummary: summaryPath}}, &stdout, &stderr, repo.runner(t), nil) + require.NoError(t, err) + + var matrix matrixOutput + matrixData, err := os.ReadFile(matrixPath) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(matrixData, &matrix)) + require.Empty(t, matrix.Include) + + summary, err := os.ReadFile(summaryPath) + require.NoError(t, err) + require.Contains(t, string(summary), "No changed `*_test.go` files were detected") +} + +func TestRunToleratesDuplicateRunnableNamesInPackageInventory(t *testing.T) { + t.Parallel() + + repoRoot := t.TempDir() + linuxPath := "pkg/platform_linux_test.go" + windowsPath := "pkg/platform_windows_test.go" + baseFiles := map[string]string{ + linuxPath: `//go:build linux + +package sample + +import "testing" + +func TestPlatform(t *testing.T) { + t.Log("linux before") +} +`, + windowsPath: `//go:build windows + +package sample + +import "testing" + +func TestPlatform(t *testing.T) { + t.Log("windows") +} +`, + } + headFiles := map[string]string{ + linuxPath: `//go:build linux + +package sample + +import "testing" + +func TestPlatform(t *testing.T) { + t.Log("linux after") +} +`, + windowsPath: baseFiles[windowsPath], + } + repo := fakeGitRepo{ + changes: []testFileChange{{Kind: changeModified, OldPath: linuxPath, NewPath: linuxPath}}, + revisions: map[string]map[string]string{"base": baseFiles, "head": headFiles}, + diffOutputs: map[string]string{ + linuxPath: diffForChange( + singleLineRange(t, baseFiles[linuxPath], `t.Log("linux before")`), + singleLineRange(t, headFiles[linuxPath], `t.Log("linux after")`), + ), + }, + } + + matrixPath := filepath.Join(repoRoot, "matrix.json") + var stdout bytes.Buffer + var stderr bytes.Buffer + err := runCommand(t.Context(), commandConfig{config: config{RepoRoot: repoRoot, BaseSHA: "base", HeadSHA: "head", OutMatrix: matrixPath}}, &stdout, &stderr, repo.runner(t), nil) + require.NoError(t, err) + + var matrix matrixOutput + matrixData, err := os.ReadFile(matrixPath) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(matrixData, &matrix)) + require.Len(t, matrix.Include, 1) + require.Equal(t, "^(TestPlatform)(/.*)?$", matrix.Include[0].RunRegex) +} + +func TestRunHandlesDeletedSetupFile(t *testing.T) { + t.Parallel() + + repoRoot := t.TempDir() + setupPath := "pkg/setup_test.go" + testPath := "pkg/alpha_test.go" + baseFiles := map[string]string{ + setupPath: `package sample + +import "testing" + +func setup(t *testing.T) { + t.Helper() + t.Log("setup") +} +`, + testPath: `package sample + +import "testing" + +func TestAlpha(t *testing.T) { + t.Log("alpha") +} +`, + } + headFiles := map[string]string{ + testPath: baseFiles[testPath], + } + repo := fakeGitRepo{ + changes: []testFileChange{{Kind: changeDeleted, OldPath: setupPath}}, + revisions: map[string]map[string]string{"base": baseFiles, "head": headFiles}, + diffOutputs: map[string]string{ + setupPath: diffForChange( + singleLineRange(t, baseFiles[setupPath], `t.Log("setup")`), + emptyRangeAt(1), + ), + }, + } + + matrixPath := filepath.Join(repoRoot, "matrix.json") + summaryPath := filepath.Join(repoRoot, "summary.md") + var stdout bytes.Buffer + var stderr bytes.Buffer + err := runCommand(t.Context(), commandConfig{config: config{RepoRoot: repoRoot, BaseSHA: "base", HeadSHA: "head", OutMatrix: matrixPath, OutSummary: summaryPath}}, &stdout, &stderr, repo.runner(t), nil) + require.NoError(t, err) + + var matrix matrixOutput + matrixData, err := os.ReadFile(matrixPath) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(matrixData, &matrix)) + require.Len(t, matrix.Include, 1) + require.Equal(t, "^(TestAlpha)(/.*)?$", matrix.Include[0].RunRegex) + + summary, err := os.ReadFile(summaryPath) + require.NoError(t, err) + require.Contains(t, string(summary), setupPath) + require.Contains(t, string(summary), "TestAlpha") +} + +func TestRunBroadensInitAcrossPackageAndPackageTest(t *testing.T) { + t.Parallel() + + repoRoot := t.TempDir() + setupPath := "pkg/external_setup_test.go" + baseFiles := map[string]string{ + setupPath: `package sample_test + +func init() { + println("before") +} +`, + "pkg/internal_test.go": `package sample + +import "testing" + +func TestInternal(t *testing.T) { + t.Log("internal") +} +`, + "pkg/external_test.go": `package sample_test + +import "testing" + +func TestExternal(t *testing.T) { + t.Log("external") +} +`, + } + headFiles := map[string]string{ + setupPath: `package sample_test + +func init() { + println("after") +} +`, + "pkg/internal_test.go": baseFiles["pkg/internal_test.go"], + "pkg/external_test.go": baseFiles["pkg/external_test.go"], + } + repo := fakeGitRepo{ + changes: []testFileChange{{Kind: changeModified, OldPath: setupPath, NewPath: setupPath}}, + revisions: map[string]map[string]string{"base": baseFiles, "head": headFiles}, + diffOutputs: map[string]string{ + setupPath: diffForChange( + singleLineRange(t, baseFiles[setupPath], `println("before")`), + singleLineRange(t, headFiles[setupPath], `println("after")`), + ), + }, + } + + matrixPath := filepath.Join(repoRoot, "matrix.json") + var stdout bytes.Buffer + var stderr bytes.Buffer + err := runCommand(t.Context(), commandConfig{config: config{RepoRoot: repoRoot, BaseSHA: "base", HeadSHA: "head", OutMatrix: matrixPath}}, &stdout, &stderr, repo.runner(t), nil) + require.NoError(t, err) + + var matrix matrixOutput + matrixData, err := os.ReadFile(matrixPath) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(matrixData, &matrix)) + require.Len(t, matrix.Include, 1) + require.Equal(t, "^(TestExternal|TestInternal)(/.*)?$", matrix.Include[0].RunRegex) +} + +func TestRunHandlesCrossDirectoryRenamePrecisely(t *testing.T) { + t.Parallel() + + repoRoot := t.TempDir() + oldPath := "oldpkg/moved_test.go" + newPath := "newpkg/moved_test.go" + baseFiles := map[string]string{ + oldPath: `package oldpkg + +import "testing" + +func TestMoved(t *testing.T) { + t.Log("before") +} +`, + "oldpkg/stable_test.go": `package oldpkg + +import "testing" + +func TestOldStable(t *testing.T) { + t.Log("old") +} +`, + } + headFiles := map[string]string{ + newPath: `package newpkg + +import "testing" + +func TestMoved(t *testing.T) { + t.Log("after") +} +`, + "oldpkg/stable_test.go": baseFiles["oldpkg/stable_test.go"], + } + repo := fakeGitRepo{ + changes: []testFileChange{{Kind: changeRenamed, OldPath: oldPath, NewPath: newPath}}, + revisions: map[string]map[string]string{"base": baseFiles, "head": headFiles}, + diffOutputs: map[string]string{ + oldPath + "\x00" + newPath: diffForChange( + singleLineRange(t, baseFiles[oldPath], `t.Log("before")`), + singleLineRange(t, headFiles[newPath], `t.Log("after")`), + ), + }, + } + + matrixPath := filepath.Join(repoRoot, "matrix.json") + var stdout bytes.Buffer + var stderr bytes.Buffer + err := runCommand(t.Context(), commandConfig{config: config{RepoRoot: repoRoot, BaseSHA: "base", HeadSHA: "head", OutMatrix: matrixPath}}, &stdout, &stderr, repo.runner(t), nil) + require.NoError(t, err) + + var matrix matrixOutput + matrixData, err := os.ReadFile(matrixPath) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(matrixData, &matrix)) + require.Len(t, matrix.Include, 1) + require.Equal(t, "./newpkg", matrix.Include[0].Package) + require.Equal(t, "^(TestMoved)(/.*)?$", matrix.Include[0].RunRegex) +} + +func TestRunHandlesCrossDirectoryRenameSourceFallout(t *testing.T) { + t.Parallel() + + repoRoot := t.TempDir() + oldPath := "oldpkg/setup_test.go" + newPath := "newpkg/setup_test.go" + baseFiles := map[string]string{ + oldPath: `package oldpkg + +import "testing" + +func setup(t *testing.T) { + t.Helper() + t.Log("before") +} +`, + "oldpkg/stable_test.go": `package oldpkg + +import "testing" + +func TestOldStable(t *testing.T) { + t.Log("old") +} +`, + } + headFiles := map[string]string{ + newPath: `package newpkg + +import "testing" + +func setup(t *testing.T) { + t.Helper() + t.Log("after") +} +`, + "oldpkg/stable_test.go": baseFiles["oldpkg/stable_test.go"], + "newpkg/stable_test.go": `package newpkg + +import "testing" + +func TestNewStable(t *testing.T) { + t.Log("new") +} +`, + } + repo := fakeGitRepo{ + changes: []testFileChange{{Kind: changeRenamed, OldPath: oldPath, NewPath: newPath}}, + revisions: map[string]map[string]string{"base": baseFiles, "head": headFiles}, + diffOutputs: map[string]string{ + oldPath + "\x00" + newPath: diffForChange( + singleLineRange(t, baseFiles[oldPath], `t.Log("before")`), + singleLineRange(t, headFiles[newPath], `t.Log("after")`), + ), + }, + } + + matrixPath := filepath.Join(repoRoot, "matrix.json") + var stdout bytes.Buffer + var stderr bytes.Buffer + err := runCommand(t.Context(), commandConfig{config: config{RepoRoot: repoRoot, BaseSHA: "base", HeadSHA: "head", OutMatrix: matrixPath}}, &stdout, &stderr, repo.runner(t), nil) + require.NoError(t, err) + + var matrix matrixOutput + matrixData, err := os.ReadFile(matrixPath) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(matrixData, &matrix)) + require.Len(t, matrix.Include, 1) + require.Equal(t, "./oldpkg", matrix.Include[0].Package) + require.Equal(t, "^(TestOldStable)(/.*)?$", matrix.Include[0].RunRegex) +} diff --git a/config.go b/config.go new file mode 100644 index 0000000..db9984d --- /dev/null +++ b/config.go @@ -0,0 +1,48 @@ +package main + +import ( + "cmp" +) + +const ( + defaultRepoRoot = "." + defaultHeadSHA = "HEAD" + defaultOutSummary = "-" + defaultTestCount = "10" + runOnceTestCount = "1" + + // Package-wide and matrix-wide caps keep the detector cheap by + // running broad fallback targets once instead of repeatedly. + maxMatrixEntries = 20 + maxBroadenedTests = 50 + maxOverflowSummaries = 10 +) + +type config struct { + RepoRoot string + BaseSHA string + HeadSHA string + OutMatrix string + OutSummary string +} + +func defaultConfig() config { + return config{}.withDefaults() +} + +func (cfg config) withDefaults() config { + cfg.RepoRoot = cmp.Or(cfg.RepoRoot, defaultRepoRoot) + cfg.HeadSHA = cmp.Or(cfg.HeadSHA, defaultHeadSHA) + cfg.OutSummary = cmp.Or(cfg.OutSummary, defaultOutSummary) + return cfg +} + +type commandConfig struct { + config + + GitHubActions bool +} + +func defaultCommandConfig() commandConfig { + return commandConfig{config: defaultConfig()} +} diff --git a/diff.go b/diff.go new file mode 100644 index 0000000..1780b60 --- /dev/null +++ b/diff.go @@ -0,0 +1,268 @@ +package main + +import ( + "cmp" + "context" + "fmt" + "path/filepath" + "regexp" + "slices" + "strconv" + "strings" +) + +var hunkHeaderRE = regexp.MustCompile(`^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@`) + +type changeKind string + +// changeKind mirrors git diff status letters. T is a type change. +const ( + changeAdded changeKind = "A" + changeDeleted changeKind = "D" + changeModified changeKind = "M" + changeRenamed changeKind = "R" + changeType changeKind = "T" +) + +type testFileChange struct { + Kind changeKind + OldPath string + NewPath string +} + +func (change testFileChange) displayPath() string { + return cmp.Or(change.NewPath, change.OldPath) +} + +func (change testFileChange) pathspecs() []string { + oldPath := cmp.Or(change.OldPath, change.NewPath) + newPath := cmp.Or(change.NewPath, change.OldPath) + if oldPath == "" { + return []string{newPath} + } + if newPath == "" || newPath == oldPath { + return []string{oldPath} + } + return []string{oldPath, newPath} +} + +// lineRange uses End < Start to represent an empty span from a zero-count diff +// hunk. hasLines reports whether the span contains any real source lines. +type lineRange struct { + Start int + End int +} + +type diffHunk struct { + Old lineRange + New lineRange +} + +func newSideOnlyHunks(hunks []diffHunk) []diffHunk { + trimmed := make([]diffHunk, 0, len(hunks)) + for _, hunk := range hunks { + hunk.Old = lineRange{} + trimmed = append(trimmed, hunk) + } + return trimmed +} + +func listChangedTestFiles(ctx context.Context, cfg config, git gitRunner) ([]testFileChange, error) { + result, err := git( + ctx, + cfg.RepoRoot, + "diff", + "--name-status", + "-z", + "--find-renames", + "--diff-filter=ADMRT", + diffRangeSpec(cfg), + ) + if err != nil { + return nil, err + } + if result.Stdout == "" { + return nil, nil + } + + fields := strings.Split(result.Stdout, "\x00") + changes := make([]testFileChange, 0) + for index := 0; index < len(fields); { + status := fields[index] + index++ + if status == "" { + continue + } + kind, err := parseChangeKind(status) + if err != nil { + return nil, err + } + switch kind { + case changeRenamed: + if index+1 >= len(fields) { + return nil, fmt.Errorf("rename status %q is missing paths", status) + } + oldPath := cleanGitPath(fields[index]) + newPath := cleanGitPath(fields[index+1]) + index += 2 + change := testFileChange{Kind: kind, OldPath: oldPath, NewPath: newPath} + if !isRunnableTestFilePath(change.OldPath) && !isRunnableTestFilePath(change.NewPath) { + continue + } + changes = append(changes, change) + default: + if index >= len(fields) { + return nil, fmt.Errorf("status %q is missing a path", status) + } + path := cleanGitPath(fields[index]) + index++ + change := testFileChange{Kind: kind, OldPath: path, NewPath: path} + switch kind { + case changeAdded: + change.OldPath = "" + case changeDeleted: + change.NewPath = "" + } + if !isRunnableTestFilePath(change.displayPath()) { + continue + } + changes = append(changes, change) + } + } + slices.SortFunc(changes, func(left, right testFileChange) int { + return cmp.Compare(left.displayPath(), right.displayPath()) + }) + return changes, nil +} + +func parseChangeKind(status string) (changeKind, error) { + switch { + case strings.HasPrefix(status, string(changeAdded)): + return changeAdded, nil + case strings.HasPrefix(status, string(changeDeleted)): + return changeDeleted, nil + case strings.HasPrefix(status, string(changeModified)): + return changeModified, nil + case strings.HasPrefix(status, string(changeRenamed)): + return changeRenamed, nil + case strings.HasPrefix(status, string(changeType)): + return changeType, nil + default: + return "", fmt.Errorf("unsupported diff status %q", status) + } +} + +func cleanGitPath(path string) string { + return filepath.ToSlash(filepath.Clean(path)) +} + +func isRunnableTestFilePath(path string) bool { + if !strings.HasSuffix(path, "_test.go") { + return false + } + cleanPath := cleanGitPath(path) + baseName := filepath.Base(cleanPath) + if strings.HasPrefix(baseName, ".") || strings.HasPrefix(baseName, "_") { + return false + } + for segment := range strings.SplitSeq(filepath.ToSlash(filepath.Dir(cleanPath)), "/") { + if segment == "." || segment == "" { + continue + } + if segment == "testdata" || segment == "vendor" || strings.HasPrefix(segment, ".") || strings.HasPrefix(segment, "_") { + return false + } + } + return true +} + +func listDiffHunks(ctx context.Context, cfg config, git gitRunner, change testFileChange) ([]diffHunk, error) { + args := []string{"diff", "--unified=0", "--no-color", "--find-renames", diffRangeSpec(cfg), "--"} + args = append(args, change.pathspecs()...) + result, err := git(ctx, cfg.RepoRoot, args...) + if err != nil { + return nil, err + } + return parseDiffHunks(result.Stdout) +} + +func parseDiffHunks(diff string) ([]diffHunk, error) { + hunks := make([]diffHunk, 0) + for line := range strings.Lines(diff) { + line = strings.TrimSuffix(line, "\n") + matches := hunkHeaderRE.FindStringSubmatch(line) + if matches == nil { + continue + } + oldRange, err := parseRange(matches[1], matches[2]) + if err != nil { + return nil, err + } + newRange, err := parseRange(matches[3], matches[4]) + if err != nil { + return nil, err + } + hunks = append(hunks, diffHunk{Old: oldRange, New: newRange}) + } + return hunks, nil +} + +func parseRange(startText, countText string) (lineRange, error) { + start, err := parseNonNegativeInt(startText) + if err != nil { + return lineRange{}, err + } + count := 1 + if countText != "" { + count, err = parseNonNegativeInt(countText) + if err != nil { + return lineRange{}, err + } + } + if count == 0 { + if start == 0 { + start = 1 + } + return lineRange{Start: start, End: start - 1}, nil + } + return lineRange{Start: start, End: start + count - 1}, nil +} + +func parseNonNegativeInt(value string) (int, error) { + parsed, err := strconv.Atoi(value) + if err != nil { + return 0, fmt.Errorf("parse integer %q: %w", value, err) + } + if parsed < 0 { + return 0, fmt.Errorf("negative value %q", value) + } + return parsed, nil +} + +func fileExistsAtRevision(ctx context.Context, cfg config, git gitRunner, revision, filePath string) (bool, error) { + result, err := git(ctx, cfg.RepoRoot, "ls-tree", "-z", "--name-only", revision, "--", filePath) + if err != nil { + return false, fmt.Errorf("check whether %s exists at %s: %w", filePath, revision, err) + } + cleanPath := cleanGitPath(filePath) + for part := range strings.SplitSeq(result.Stdout, "\x00") { + if part == "" { + continue + } + if cleanGitPath(part) == cleanPath { + return true, nil + } + } + return false, nil +} + +func (r lineRange) hasLines() bool { + return r.Start > 0 && r.End >= r.Start +} + +func (r lineRange) overlaps(other lineRange) bool { + if !r.hasLines() || !other.hasLines() { + return false + } + return r.Start <= other.End && other.Start <= r.End +} diff --git a/diff_test.go b/diff_test.go new file mode 100644 index 0000000..a9b379b --- /dev/null +++ b/diff_test.go @@ -0,0 +1,87 @@ +package main + +import ( + "errors" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseChangeKind(t *testing.T) { + t.Parallel() + + tests := []struct { + status string + want changeKind + }{ + {status: "A", want: changeAdded}, + {status: "D", want: changeDeleted}, + {status: "M", want: changeModified}, + {status: "R100", want: changeRenamed}, + {status: "T", want: changeType}, + } + for _, tt := range tests { + t.Run(tt.status, func(t *testing.T) { + t.Parallel() + kind, err := parseChangeKind(tt.status) + require.NoError(t, err) + require.Equal(t, tt.want, kind) + }) + } + _, err := parseChangeKind("X") + require.ErrorContains(t, err, "unsupported diff status") +} + +func TestParseDiffHunks(t *testing.T) { + t.Parallel() + + hunks, err := parseDiffHunks(strings.Join([]string{ + "@@ -10 +12 @@", + "@@ -0,0 +5,3 @@", + "@@ -20,4 +30,6 @@", + "@@ malformed @@", + }, "\n")) + require.NoError(t, err) + require.Equal(t, []diffHunk{ + {Old: lineRange{Start: 10, End: 10}, New: lineRange{Start: 12, End: 12}}, + {Old: lineRange{Start: 1, End: 0}, New: lineRange{Start: 5, End: 7}}, + {Old: lineRange{Start: 20, End: 23}, New: lineRange{Start: 30, End: 35}}, + }, hunks) +} + +func TestParseNonNegativeInt(t *testing.T) { + t.Parallel() + + value, err := parseNonNegativeInt("0") + require.NoError(t, err) + require.Zero(t, value) + + value, err = parseNonNegativeInt("42") + require.NoError(t, err) + require.Equal(t, 42, value) + + _, err = parseNonNegativeInt("x") + require.Error(t, err) +} + +func TestFileExistsAtRevisionPropagatesLsTreeFailures(t *testing.T) { + t.Parallel() + + repo := fakeGitRepo{ + revisions: map[string]map[string]string{ + "head": { + "pkg/sample_test.go": `package sample +`, + }, + }, + failures: map[string]gitResponse{ + gitKey("ls-tree", "-z", "--name-only", "head", "--", "pkg/sample_test.go"): { + result: gitResult{}, + err: errors.New("fatal: ls-tree failed"), + }, + }, + } + _, err := fileExistsAtRevision(t.Context(), config{RepoRoot: t.TempDir()}, repo.runner(t), "head", "pkg/sample_test.go") + require.ErrorContains(t, err, "check whether pkg/sample_test.go exists at head") +} diff --git a/gitexec.go b/gitexec.go new file mode 100644 index 0000000..2e49c74 --- /dev/null +++ b/gitexec.go @@ -0,0 +1,53 @@ +package main + +import ( + "bytes" + "context" + "fmt" + "os" + "os/exec" + "strings" +) + +type gitResult struct { + Stdout string +} + +type gitRunner func(ctx context.Context, dir string, args ...string) (gitResult, error) + +type gitFetcher func(ctx context.Context, dir string, spec fetchSpec) (gitResult, error) + +func ensureRevisionExists(ctx context.Context, cfg config, git gitRunner, revision string) error { + _, err := git(ctx, cfg.RepoRoot, "cat-file", "-e", revision+"^{commit}") + if err != nil { + return fmt.Errorf("revision %s is not available: %w", revision, err) + } + return nil +} + +func execGit(ctx context.Context, dir string, args ...string) (gitResult, error) { + cmd := exec.CommandContext(ctx, "git", args...) + cmd.Dir = dir + cmd.Env = append(os.Environ(), "LC_ALL=C", "LANG=C") + + var stdout bytes.Buffer + var stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + err := cmd.Run() + result := gitResult{Stdout: stdout.String()} + if err == nil { + return result, nil + } + message := strings.TrimSpace(stderr.String()) + if message == "" { + message = strings.TrimSpace(result.Stdout) + } + if strings.Contains(message, "no merge base") { + return result, fmt.Errorf("git %s: %s. Ensure both revisions have full history before diffing", strings.Join(args, " "), message) + } + if message == "" { + message = err.Error() + } + return result, fmt.Errorf("git %s: %s", strings.Join(args, " "), message) +} diff --git a/gitexec_test.go b/gitexec_test.go new file mode 100644 index 0000000..671b6fc --- /dev/null +++ b/gitexec_test.go @@ -0,0 +1,36 @@ +package main + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestExecGitNoMergeBaseDiagnosticIsGeneric(t *testing.T) { + t.Parallel() + + requireGit(t) + repoRoot := t.TempDir() + runGit(t, repoRoot, "init") + runGit(t, repoRoot, "config", "user.email", "test@example.com") + runGit(t, repoRoot, "config", "user.name", "Test User") + writeTestFile(t, repoRoot, "pkg/sample_test.go", `package sample +`) + runGit(t, repoRoot, "add", ".") + runGit(t, repoRoot, "commit", "-m", "base") + runGit(t, repoRoot, "branch", "left") + + runGit(t, repoRoot, "checkout", "--orphan", "right") + require.NoError(t, os.Remove(filepath.Join(repoRoot, "pkg", "sample_test.go"))) + writeTestFile(t, repoRoot, "pkg/sample_test.go", `package sample +`) + runGit(t, repoRoot, "add", ".") + runGit(t, repoRoot, "commit", "-m", "right") + + _, err := execGit(t.Context(), repoRoot, "diff", "left...right", "--", "pkg/sample_test.go") + require.Error(t, err) + require.ErrorContains(t, err, "Ensure both revisions have full history before diffing") + require.NotContains(t, err.Error(), `diffing "pkg/sample_test.go"`) +} diff --git a/gitfake_test.go b/gitfake_test.go new file mode 100644 index 0000000..e45f32d --- /dev/null +++ b/gitfake_test.go @@ -0,0 +1,175 @@ +package main + +import ( + "context" + "errors" + "fmt" + "slices" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +type fakeGitRepo struct { + changes []testFileChange + revisions map[string]map[string]string + diffOutputs map[string]string + mergeBases map[string]string + headSHA string + failures map[string]gitResponse +} + +type gitResponse struct { + result gitResult + err error +} + +func (repo fakeGitRepo) runner(t *testing.T) gitRunner { + t.Helper() + return func(_ context.Context, _ string, args ...string) (gitResult, error) { + t.Helper() + if response, ok := repo.failures[gitKey(args...)]; ok { + return response.result, response.err + } + switch args[0] { + case "diff": + return repo.diffResponse(t, args) + case "cat-file": + return repo.catFileResponse(t, args) + case "show": + return repo.showResponse(t, args) + case "ls-tree": + return repo.lsTreeResponse(t, args) + case "merge-base": + return repo.mergeBaseResponse(t, args) + case "rev-parse": + return repo.revParseResponse(t, args) + default: + t.Fatalf("unexpected git command: %v", args) + return gitResult{}, nil + } + } +} + +func (repo fakeGitRepo) diffResponse(t *testing.T, args []string) (gitResult, error) { + t.Helper() + if len(args) >= 2 && args[1] == "--name-status" { + return gitResult{Stdout: repo.nameStatusOutput()}, nil + } + separator := slices.Index(args, "--") + require.NotEqual(t, -1, separator) + paths := args[separator+1:] + output, ok := repo.diffOutputs[strings.Join(paths, "\x00")] + if !ok { + t.Fatalf("unexpected diff paths %q", strings.Join(paths, "\x00")) + } + return gitResult{Stdout: output}, nil +} + +func (repo fakeGitRepo) nameStatusOutput() string { + parts := make([]string, 0, len(repo.changes)*3) + for _, change := range repo.changes { + switch change.Kind { + case changeRenamed: + parts = append(parts, "R100", change.OldPath, change.NewPath) + case changeAdded: + parts = append(parts, string(change.Kind), change.NewPath) + case changeDeleted: + parts = append(parts, string(change.Kind), change.OldPath) + default: + parts = append(parts, string(change.Kind), change.displayPath()) + } + } + return strings.Join(parts, "\x00") + "\x00" +} + +func (repo fakeGitRepo) catFileResponse(t *testing.T, args []string) (gitResult, error) { + t.Helper() + require.Len(t, args, 3) + require.Equal(t, "-e", args[1]) + spec := args[2] + if revision, ok := strings.CutSuffix(spec, "^{commit}"); ok { + if _, ok := repo.revisions[revision]; ok { + return gitResult{}, nil + } + return gitFailure(fmt.Sprintf("fatal: bad revision %q", revision)) + } + revision, path := splitRevisionPath(t, spec) + if _, ok := repo.revisions[revision][path]; ok { + return gitResult{}, nil + } + return gitFailure(fmt.Sprintf("fatal: path %q does not exist in %q", path, revision)) +} + +func (repo fakeGitRepo) showResponse(t *testing.T, args []string) (gitResult, error) { + t.Helper() + require.Len(t, args, 2) + revision, path := splitRevisionPath(t, args[1]) + content, ok := repo.revisions[revision][path] + if !ok { + return gitFailure(fmt.Sprintf("fatal: path %q does not exist in %q", path, revision)) + } + return gitResult{Stdout: content}, nil +} + +func (repo fakeGitRepo) lsTreeResponse(t *testing.T, args []string) (gitResult, error) { + t.Helper() + separator := slices.Index(args, "--") + require.Greater(t, separator, 1) + require.Less(t, separator+1, len(args)) + revision := args[separator-1] + pathspec := cleanGitPath(args[separator+1]) + files := make([]string, 0) + for filePath := range repo.revisions[revision] { + cleanPath := cleanGitPath(filePath) + if pathspec != "." && !strings.HasPrefix(cleanPath, pathspec+"/") && cleanPath != pathspec { + continue + } + files = append(files, cleanPath) + } + slices.Sort(files) + return gitResult{Stdout: strings.Join(files, "\x00") + "\x00"}, nil +} + +func (repo fakeGitRepo) mergeBaseResponse(t *testing.T, args []string) (gitResult, error) { + t.Helper() + require.Len(t, args, 3) + key := gitKey(args...) + if repo.mergeBases != nil { + if base, ok := repo.mergeBases[key]; ok { + return gitResult{Stdout: base + "\n"}, nil + } + } + left := args[1] + if _, ok := repo.revisions[left]; ok { + return gitResult{Stdout: left + "\n"}, nil + } + return gitFailure(fmt.Sprintf("fatal: no merge base for %s and %s", args[1], args[2])) +} + +func (repo fakeGitRepo) revParseResponse(t *testing.T, args []string) (gitResult, error) { + t.Helper() + require.Equal(t, []string{"rev-parse", "HEAD"}, args) + head := repo.headSHA + if head == "" { + head = "head" + } + return gitResult{Stdout: head + "\n"}, nil +} + +func splitRevisionPath(t *testing.T, spec string) (revision string, path string) { + t.Helper() + revision, path, ok := strings.Cut(spec, ":") + require.True(t, ok) + return revision, cleanGitPath(path) +} + +func gitFailure(stderr string) (gitResult, error) { + return gitResult{}, errors.New(stderr) +} + +func gitKey(args ...string) string { + // NUL is a stable separator because git diff pathspecs can contain spaces. + return strings.Join(args, "\x00") +} diff --git a/githubactions.go b/githubactions.go new file mode 100644 index 0000000..5adddc0 --- /dev/null +++ b/githubactions.go @@ -0,0 +1,328 @@ +package main + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "os" + "regexp" + "strings" + "unicode" + "unicode/utf8" +) + +// defaultDispatchBaseRef follows coder/coder's default branch name. +const defaultDispatchBaseRef = "main" + +var repoFullNameRE = regexp.MustCompile(`^[A-Za-z0-9][A-Za-z0-9-]*/[A-Za-z0-9_.-]+$`) + +type githubEvent struct { + PullRequest struct { + Base struct { + SHA string `json:"sha"` + Ref string `json:"ref"` + Repo struct { + FullName string `json:"full_name"` + } `json:"repo"` + } `json:"base"` + Head struct { + SHA string `json:"sha"` + } `json:"head"` + } `json:"pull_request"` + Inputs struct { + BaseSHA string `json:"base_sha"` + HeadSHA string `json:"head_sha"` + } `json:"inputs"` +} + +func githubActionsRunRequest(ctx context.Context, cfg commandConfig, git gitRunner) (runRequest, error) { + baseCfg := cfg.withDefaults() + if baseCfg.OutMatrix == "" { + return runRequest{}, errors.New("--out-matrix is required") + } + + eventName := os.Getenv("GITHUB_EVENT_NAME") + if eventName == "" { + return runRequest{}, errors.New("GITHUB_EVENT_NAME is required") + } + eventPath := os.Getenv("GITHUB_EVENT_PATH") + if eventPath == "" { + return runRequest{}, errors.New("GITHUB_EVENT_PATH is required") + } + githubOutput := os.Getenv("GITHUB_OUTPUT") + if githubOutput == "" { + return runRequest{}, errors.New("GITHUB_OUTPUT is required") + } + stepSummary := os.Getenv("GITHUB_STEP_SUMMARY") + + event, err := readGitHubEvent(eventPath) + if err != nil { + return runRequest{}, err + } + currentHead, err := currentHeadSHA(ctx, baseCfg.RepoRoot, git) + if err != nil { + return runRequest{}, err + } + + req := runRequest{ + RepoRoot: baseCfg.RepoRoot, + Range: diffRange{ + HeadSHA: currentHead, + }, + Sinks: outputSinks{ + OutMatrix: baseCfg.OutMatrix, + OutSummary: baseCfg.OutSummary, + GitHubOutput: githubOutput, + GitHubStepSummary: stepSummary, + }, + } + + switch eventName { + case "pull_request": + return pullRequestRunRequest(req, event) + case "workflow_dispatch": + return workflowDispatchRunRequest(req, event) + default: + return runRequest{}, fmt.Errorf("unsupported GitHub event %q", eventName) + } +} + +func pullRequestRunRequest(req runRequest, event githubEvent) (runRequest, error) { + baseSHA := event.PullRequest.Base.SHA + if err := validateRevisionArg("pull_request.base.sha", baseSHA); err != nil { + return runRequest{}, err + } + baseRef := event.PullRequest.Base.Ref + if err := validateRef("pull_request.base.ref", baseRef); err != nil { + return runRequest{}, err + } + baseRepo := event.PullRequest.Base.Repo.FullName + if err := validateRepoFullName("pull_request.base.repo.full_name", baseRepo); err != nil { + return runRequest{}, err + } + payloadHead := event.PullRequest.Head.SHA + if err := validateRevisionArg("pull_request.head.sha", payloadHead); err != nil { + return runRequest{}, err + } + if req.Range.HeadSHA != payloadHead { + return runRequest{}, fmt.Errorf("checked out HEAD %s does not match pull_request.head.sha %s; update actions/checkout ref to the pull request head commit", req.Range.HeadSHA, payloadHead) + } + + baseURL := githubRepoURL(baseRepo) + req.Range.BaseSHA = baseSHA + req.Fetches = []fetchSpec{ + {Remote: baseURL, Ref: branchFetchRef(baseRef)}, + {Remote: baseURL, Ref: baseSHA}, + } + return req, nil +} + +func workflowDispatchRunRequest(req runRequest, event githubEvent) (runRequest, error) { + if headSHA := event.Inputs.HeadSHA; headSHA != "" { + if err := validateRevisionArg("workflow_dispatch.inputs.head_sha", headSHA); err != nil { + return runRequest{}, err + } + if req.Range.HeadSHA != headSHA { + return runRequest{}, fmt.Errorf("checked out HEAD %s does not match workflow_dispatch.inputs.head_sha %s; update actions/checkout ref to the requested head commit", req.Range.HeadSHA, headSHA) + } + } + + baseSHA := event.Inputs.BaseSHA + mainFetch := fetchSpec{Remote: "origin", Ref: remoteTrackingRefspec(defaultDispatchBaseRef)} + if baseSHA != "" { + if err := validateRevisionArg("workflow_dispatch.inputs.base_sha", baseSHA); err != nil { + return runRequest{}, err + } + req.Range.BaseSHA = baseSHA + req.Fetches = []fetchSpec{mainFetch, {Remote: "origin", Ref: baseSHA}} + return req, nil + } + + req.Fetches = []fetchSpec{mainFetch} + req.MergeBaseRef = "origin/" + defaultDispatchBaseRef + return req, nil +} + +func readGitHubEvent(path string) (githubEvent, error) { + // #nosec G304: path comes from the GitHub Actions runner environment. + data, err := os.ReadFile(path) + if err != nil { + return githubEvent{}, fmt.Errorf("read GitHub event payload %s: %w", path, err) + } + var event githubEvent + if err := json.Unmarshal(data, &event); err != nil { + return githubEvent{}, fmt.Errorf("parse GitHub event payload %s: %w", path, err) + } + return event, nil +} + +func currentHeadSHA(ctx context.Context, repoRoot string, git gitRunner) (string, error) { + result, err := git(ctx, repoRoot, "rev-parse", "HEAD") + if err != nil { + return "", fmt.Errorf("resolve checked out HEAD: %w", err) + } + head := strings.TrimSpace(result.Stdout) + if err := validateRevisionArg("checked out HEAD", head); err != nil { + return "", err + } + return head, nil +} + +func ensureRangeAvailable(ctx context.Context, req *runRequest, git gitRunner, fetch gitFetcher) error { + if req.RepoRoot == "" { + req.RepoRoot = defaultRepoRoot + } + if err := validateRevisionArg("head revision", req.Range.HeadSHA); err != nil { + return err + } + if req.Range.BaseSHA != "" { + return ensureConcreteRangeAvailable(ctx, req, git, fetch) + } + if req.MergeBaseRef == "" { + return errors.New("base revision is required") + } + if err := runFetches(ctx, req, fetch); err != nil { + return err + } + + baseSHA, err := gitMergeBase(ctx, req.RepoRoot, git, req.Range.HeadSHA, req.MergeBaseRef) + if err != nil { + return fmt.Errorf("failed to resolve merge-base between %s and %s after fetching base history: %w", req.Range.HeadSHA, req.MergeBaseRef, err) + } + if err := validateRevisionArg("resolved base revision", baseSHA); err != nil { + return err + } + req.Range.BaseSHA = baseSHA + return nil +} + +func ensureConcreteRangeAvailable(ctx context.Context, req *runRequest, git gitRunner, fetch gitFetcher) error { + if err := validateRevisionArg("base revision", req.Range.BaseSHA); err != nil { + return err + } + _, mergeErr := gitMergeBase(ctx, req.RepoRoot, git, req.Range.BaseSHA, req.Range.HeadSHA) + if mergeErr == nil { + return nil + } + if len(req.Fetches) == 0 { + return fmt.Errorf("unable to resolve merge base for %s...%s: %w", req.Range.BaseSHA, req.Range.HeadSHA, mergeErr) + } + if fetch == nil { + return errors.New("history fetch is required but no fetcher was configured") + } + + attempts := []error{fmt.Errorf("initial merge-base: %w", mergeErr)} + for _, spec := range req.Fetches { + if err := validateFetchSpec(spec); err != nil { + attempts = append(attempts, fmt.Errorf("validate fetch spec %s: %w", spec.Ref, err)) + continue + } + if _, err := fetch(ctx, req.RepoRoot, spec); err != nil { + attempts = append(attempts, fmt.Errorf("fetch %s from %s: %w", spec.Ref, spec.Remote, err)) + continue + } + _, err := gitMergeBase(ctx, req.RepoRoot, git, req.Range.BaseSHA, req.Range.HeadSHA) + if err == nil { + return nil + } + attempts = append(attempts, fmt.Errorf("merge-base after fetching %s from %s: %w", spec.Ref, spec.Remote, err)) + } + return fmt.Errorf("unable to resolve a merge base for %s...%s after fetching base history: %w", req.Range.BaseSHA, req.Range.HeadSHA, errors.Join(attempts...)) +} + +func runFetches(ctx context.Context, req *runRequest, fetch gitFetcher) error { + if len(req.Fetches) == 0 { + return nil + } + if fetch == nil { + return errors.New("history fetch is required but no fetcher was configured") + } + for _, spec := range req.Fetches { + if err := validateFetchSpec(spec); err != nil { + return err + } + if _, err := fetch(ctx, req.RepoRoot, spec); err != nil { + return fmt.Errorf("fetch %s from %s: %w", spec.Ref, spec.Remote, err) + } + } + return nil +} + +func validateFetchSpec(spec fetchSpec) error { + if spec.Remote == "" || spec.Ref == "" { + return fmt.Errorf("invalid fetch spec: remote and ref are required") + } + return nil +} + +func gitMergeBase(ctx context.Context, repoRoot string, git gitRunner, left, right string) (string, error) { + result, err := git(ctx, repoRoot, "merge-base", left, right) + if err != nil { + return "", err + } + base := strings.TrimSpace(result.Stdout) + if base == "" { + return "", fmt.Errorf("git merge-base %s %s returned no revision", left, right) + } + return base, nil +} + +func execGitFetch(ctx context.Context, dir string, spec fetchSpec) (gitResult, error) { + return execGit(ctx, dir, "fetch", "--no-tags", spec.Remote, spec.Ref) +} + +func validateRef(name, value string) error { + if value == "" { + return fmt.Errorf("%s is required", name) + } + if strings.HasPrefix(value, "-") { + return fmt.Errorf("%s must not start with '-': %q", name, value) + } + if !utf8.ValidString(value) || strings.ContainsRune(value, '\x00') { + return fmt.Errorf("%s must not contain invalid bytes", name) + } + if strings.HasPrefix(value, "/") || strings.HasSuffix(value, "/") || strings.Contains(value, "//") { + return fmt.Errorf("%s must be a safe branch ref: %q", name, value) + } + if strings.Contains(value, "..") || strings.Contains(value, "@{") || strings.HasSuffix(value, ".lock") { + return fmt.Errorf("%s must be a safe branch ref: %q", name, value) + } + for _, r := range value { + if unicode.IsControl(r) || unicode.IsSpace(r) { + return fmt.Errorf("%s must not contain control or whitespace characters: %q", name, value) + } + switch r { + case ':', '^', '~', '?', '*', '[', '\\': + return fmt.Errorf("%s must be a safe branch ref: %q", name, value) + } + } + for segment := range strings.SplitSeq(value, "/") { + if segment == "" || strings.HasPrefix(segment, ".") { + return fmt.Errorf("%s must be a safe branch ref: %q", name, value) + } + } + return nil +} + +func validateRepoFullName(name, value string) error { + if value == "" { + return fmt.Errorf("%s is required", name) + } + if !repoFullNameRE.MatchString(value) || strings.Contains(value, "..") { + return fmt.Errorf("%s must be a GitHub owner/repository name: %q", name, value) + } + return nil +} + +func githubRepoURL(fullName string) string { + return "https://github.com/" + fullName + ".git" +} + +func branchFetchRef(ref string) string { + return "refs/heads/" + ref +} + +func remoteTrackingRefspec(ref string) string { + return branchFetchRef(ref) + ":refs/remotes/origin/" + ref +} diff --git a/githubactions_test.go b/githubactions_test.go new file mode 100644 index 0000000..59a89e0 --- /dev/null +++ b/githubactions_test.go @@ -0,0 +1,432 @@ +package main + +import ( + "bytes" + "context" + "errors" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGitHubActionsRunRequestPullRequest(t *testing.T) { + eventPath := writeGitHubEvent(t, `{ + "pull_request": { + "base": { + "sha": "base123", + "ref": "main", + "repo": {"full_name": "coder/coder"} + }, + "head": {"sha": "head123"} + }, + "ignored": true + }`) + t.Setenv("GITHUB_EVENT_NAME", "pull_request") + t.Setenv("GITHUB_EVENT_PATH", eventPath) + t.Setenv("GITHUB_OUTPUT", "output.txt") + t.Setenv("GITHUB_STEP_SUMMARY", "summary.md") + t.Setenv("UNRELATED_EXTRA_ENV", "ignored") + + req, err := githubActionsRunRequest(t.Context(), commandConfig{ + config: config{RepoRoot: "/repo", OutMatrix: "matrix.json"}, + }, fakeGitRepo{headSHA: "head123"}.runner(t)) + require.NoError(t, err) + require.Equal(t, "/repo", req.RepoRoot) + require.Equal(t, diffRange{BaseSHA: "base123", HeadSHA: "head123"}, req.Range) + require.Equal(t, []fetchSpec{ + {Remote: "https://github.com/coder/coder.git", Ref: "refs/heads/main"}, + {Remote: "https://github.com/coder/coder.git", Ref: "base123"}, + }, req.Fetches) + require.Equal(t, "matrix.json", req.Sinks.OutMatrix) + require.Equal(t, "output.txt", req.Sinks.GitHubOutput) + require.Equal(t, "summary.md", req.Sinks.GitHubStepSummary) +} + +func TestGitHubActionsRunRequestVerifiesPullRequestHead(t *testing.T) { + eventPath := writeGitHubEvent(t, `{ + "pull_request": { + "base": { + "sha": "base123", + "ref": "main", + "repo": {"full_name": "coder/coder"} + }, + "head": {"sha": "expected-head"} + } + }`) + t.Setenv("GITHUB_EVENT_NAME", "pull_request") + t.Setenv("GITHUB_EVENT_PATH", eventPath) + t.Setenv("GITHUB_OUTPUT", "output.txt") + + _, err := githubActionsRunRequest(t.Context(), commandConfig{ + config: config{RepoRoot: "/repo", OutMatrix: "matrix.json"}, + }, fakeGitRepo{headSHA: "actual-head"}.runner(t)) + require.ErrorContains(t, err, "checked out HEAD actual-head does not match pull_request.head.sha expected-head") +} + +func TestGitHubActionsRunRequestRequiresPullRequestHead(t *testing.T) { + eventPath := writeGitHubEvent(t, `{ + "pull_request": { + "base": { + "sha": "base123", + "ref": "main", + "repo": {"full_name": "coder/coder"} + }, + "head": {"sha": ""} + } + }`) + t.Setenv("GITHUB_EVENT_NAME", "pull_request") + t.Setenv("GITHUB_EVENT_PATH", eventPath) + t.Setenv("GITHUB_OUTPUT", "output.txt") + + _, err := githubActionsRunRequest(t.Context(), commandConfig{ + config: config{RepoRoot: "/repo", OutMatrix: "matrix.json"}, + }, fakeGitRepo{headSHA: "head123"}.runner(t)) + require.ErrorContains(t, err, "pull_request.head.sha is required") +} + +func TestGitHubActionsRunRequestWorkflowDispatchExplicitRange(t *testing.T) { + eventPath := writeGitHubEvent(t, `{ + "inputs": { + "base_sha": "base123", + "head_sha": "head123" + } + }`) + t.Setenv("GITHUB_EVENT_NAME", "workflow_dispatch") + t.Setenv("GITHUB_EVENT_PATH", eventPath) + t.Setenv("GITHUB_OUTPUT", "output.txt") + + req, err := githubActionsRunRequest(t.Context(), commandConfig{ + config: config{RepoRoot: "/repo", OutMatrix: "matrix.json"}, + }, fakeGitRepo{headSHA: "head123"}.runner(t)) + require.NoError(t, err) + require.Equal(t, diffRange{BaseSHA: "base123", HeadSHA: "head123"}, req.Range) + require.Equal(t, []fetchSpec{ + {Remote: "origin", Ref: "refs/heads/main:refs/remotes/origin/main"}, + {Remote: "origin", Ref: "base123"}, + }, req.Fetches) + require.Empty(t, req.MergeBaseRef) +} + +func TestRunCommandGitHubActionsWritesOutputs(t *testing.T) { + oldContent := `package sample + +import "testing" + +func TestAlpha(t *testing.T) { + t.Log("old") +} +` + newContent := `package sample + +import "testing" + +func TestAlpha(t *testing.T) { + t.Log("new") +} +` + rangeForAlpha := singleLineRange(t, newContent, `t.Log("new")`) + repo := fakeGitRepo{ + headSHA: "head123", + changes: []testFileChange{{ + Kind: changeModified, + OldPath: "pkg/sample_test.go", + NewPath: "pkg/sample_test.go", + }}, + revisions: map[string]map[string]string{ + "base123": {"pkg/sample_test.go": oldContent}, + "head123": {"pkg/sample_test.go": newContent}, + }, + diffOutputs: map[string]string{ + "pkg/sample_test.go": diffForChange(rangeForAlpha, rangeForAlpha), + }, + } + tmpDir := t.TempDir() + eventPath := writeGitHubEvent(t, `{ + "pull_request": { + "base": { + "sha": "base123", + "ref": "main", + "repo": {"full_name": "coder/coder"} + }, + "head": {"sha": "head123"} + } + }`) + outputPath := filepath.Join(tmpDir, "github-output.txt") + stepSummaryPath := filepath.Join(tmpDir, "step-summary.md") + localSummaryPath := filepath.Join(tmpDir, "summary.md") + matrixPath := filepath.Join(tmpDir, "matrix.json") + t.Setenv("GITHUB_EVENT_NAME", "pull_request") + t.Setenv("GITHUB_EVENT_PATH", eventPath) + t.Setenv("GITHUB_OUTPUT", outputPath) + t.Setenv("GITHUB_STEP_SUMMARY", stepSummaryPath) + + var stdout bytes.Buffer + var stderr bytes.Buffer + fetch := func(context.Context, string, fetchSpec) (gitResult, error) { + t.Fatal("unexpected fetch") + return gitResult{}, nil + } + err := runCommand(t.Context(), commandConfig{ + config: config{RepoRoot: "/repo", OutMatrix: matrixPath, OutSummary: localSummaryPath}, + GitHubActions: true, + }, &stdout, &stderr, repo.runner(t), fetch) + require.NoError(t, err) + + matrixData, err := os.ReadFile(matrixPath) + require.NoError(t, err) + require.JSONEq(t, `{"include":[{"package":"./pkg","run_regex":"^(TestAlpha)(/.*)?$","test_count":"10"}]}`, string(matrixData)) + outputData, err := os.ReadFile(outputPath) + require.NoError(t, err) + require.Equal(t, "matrix="+string(bytes.TrimSpace(matrixData))+"\n", string(outputData)) + stepSummary, err := os.ReadFile(stepSummaryPath) + require.NoError(t, err) + require.Contains(t, string(stepSummary), `"pkg/sample_test.go"`) + require.Contains(t, string(stepSummary), "TestAlpha") + require.Empty(t, stdout.String()) + require.Contains(t, stderr.String(), "selected 1 package targets") +} + +func TestEnsureRangeAvailableWorkflowDispatchDefaultBase(t *testing.T) { + t.Parallel() + + req := runRequest{ + RepoRoot: "/repo", + Range: diffRange{HeadSHA: "head123"}, + Fetches: []fetchSpec{{Remote: "origin", Ref: "refs/heads/main:refs/remotes/origin/main"}}, + MergeBaseRef: "origin/main", + } + repo := fakeGitRepo{ + revisions: map[string]map[string]string{ + "base123": {}, + "head123": {}, + }, + mergeBases: map[string]string{ + gitKey("merge-base", "head123", "origin/main"): "base123", + }, + } + var fetches []fetchSpec + fetch := func(_ context.Context, _ string, spec fetchSpec) (gitResult, error) { + fetches = append(fetches, spec) + return gitResult{}, nil + } + err := ensureRangeAvailable(t.Context(), &req, repo.runner(t), fetch) + require.NoError(t, err) + require.Equal(t, "base123", req.Range.BaseSHA) + require.Equal(t, []fetchSpec{{Remote: "origin", Ref: "refs/heads/main:refs/remotes/origin/main"}}, fetches) +} + +func TestEnsureRangeAvailableFetchesLazily(t *testing.T) { + t.Parallel() + + req := runRequest{ + RepoRoot: "/repo", + Range: diffRange{BaseSHA: "base123", HeadSHA: "head123"}, + Fetches: []fetchSpec{{Remote: "https://github.com/coder/coder.git", Ref: "refs/heads/main"}}, + } + repo := fakeGitRepo{revisions: map[string]map[string]string{"base123": {}, "head123": {}}} + fetch := func(_ context.Context, _ string, spec fetchSpec) (gitResult, error) { + t.Fatalf("unexpected fetch: %+v", spec) + return gitResult{}, nil + } + require.NoError(t, ensureRangeAvailable(t.Context(), &req, repo.runner(t), fetch)) +} + +func TestEnsureRangeAvailableFetchesWhenMergeBaseIsMissing(t *testing.T) { + t.Parallel() + + req := runRequest{ + RepoRoot: "/repo", + Range: diffRange{BaseSHA: "base123", HeadSHA: "head123"}, + Fetches: []fetchSpec{ + {Remote: "https://github.com/coder/coder.git", Ref: "refs/heads/main"}, + {Remote: "https://github.com/coder/coder.git", Ref: "base123"}, + }, + } + mergeBaseCalls := 0 + git := func(_ context.Context, _ string, args ...string) (gitResult, error) { + require.Equal(t, []string{"merge-base", "base123", "head123"}, args) + mergeBaseCalls++ + if mergeBaseCalls == 1 { + return gitFailure("fatal: no merge base") + } + return gitResult{Stdout: "base123\n"}, nil + } + var fetches []fetchSpec + fetch := func(_ context.Context, _ string, spec fetchSpec) (gitResult, error) { + fetches = append(fetches, spec) + return gitResult{}, nil + } + require.NoError(t, ensureRangeAvailable(t.Context(), &req, git, fetch)) + require.Equal(t, 2, mergeBaseCalls) + require.Equal(t, req.Fetches[:1], fetches) +} + +func TestGitHubActionsRunRequestValidatesInputsBeforeFetch(t *testing.T) { + tests := []struct { + name string + eventName string + eventJSON string + want string + }{ + { + name: "bad base revision", + eventName: "pull_request", + eventJSON: `{"pull_request":{"base":{"sha":"-bad","ref":"main","repo":{"full_name":"coder/coder"}},"head":{"sha":"head123"}}}`, + want: "must not start with '-'", + }, + { + name: "bad base ref", + eventName: "pull_request", + eventJSON: `{"pull_request":{"base":{"sha":"base123","ref":"main:evil","repo":{"full_name":"coder/coder"}},"head":{"sha":"head123"}}}`, + want: "safe branch ref", + }, + { + name: "bad base repo", + eventName: "pull_request", + eventJSON: `{"pull_request":{"base":{"sha":"base123","ref":"main","repo":{"full_name":"../coder"}},"head":{"sha":"head123"}}}`, + want: "owner/repository", + }, + { + name: "bad dispatch head", + eventName: "workflow_dispatch", + eventJSON: `{"inputs":{"head_sha":"head:bad"}}`, + want: "must not contain ':'", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + eventPath := writeGitHubEvent(t, tc.eventJSON) + t.Setenv("GITHUB_EVENT_NAME", tc.eventName) + t.Setenv("GITHUB_EVENT_PATH", eventPath) + t.Setenv("GITHUB_OUTPUT", "output.txt") + + _, err := githubActionsRunRequest(t.Context(), commandConfig{ + config: config{RepoRoot: "/repo", OutMatrix: "matrix.json"}, + }, fakeGitRepo{headSHA: "head123"}.runner(t)) + require.ErrorContains(t, err, tc.want) + }) + } +} + +func TestEnsureRangeAvailableFallsBackWhenFirstFetchFails(t *testing.T) { + t.Parallel() + + req := runRequest{ + RepoRoot: "/repo", + Range: diffRange{BaseSHA: "base123", HeadSHA: "head123"}, + Fetches: []fetchSpec{ + {Remote: "https://github.com/coder/coder.git", Ref: "refs/heads/main"}, + {Remote: "https://github.com/coder/coder.git", Ref: "base123"}, + }, + } + mergeBaseCalls := 0 + git := func(_ context.Context, _ string, args ...string) (gitResult, error) { + require.Equal(t, []string{"merge-base", "base123", "head123"}, args) + mergeBaseCalls++ + if mergeBaseCalls == 1 { + return gitFailure("fatal: no merge base") + } + return gitResult{Stdout: "base123\n"}, nil + } + var fetches []fetchSpec + fetch := func(_ context.Context, _ string, spec fetchSpec) (gitResult, error) { + fetches = append(fetches, spec) + if len(fetches) == 1 { + return gitResult{}, errors.New("network refused") + } + return gitResult{}, nil + } + require.NoError(t, ensureRangeAvailable(t.Context(), &req, git, fetch)) + require.Equal(t, 2, mergeBaseCalls) + require.Equal(t, req.Fetches, fetches) +} + +func TestEnsureRangeAvailableSkipsInvalidFetchSpecWhenLaterFetchSucceeds(t *testing.T) { + t.Parallel() + + req := runRequest{ + RepoRoot: "/repo", + Range: diffRange{BaseSHA: "base123", HeadSHA: "head123"}, + Fetches: []fetchSpec{ + {Remote: "", Ref: "refs/heads/main"}, + {Remote: "https://github.com/coder/coder.git", Ref: "base123"}, + }, + } + mergeBaseCalls := 0 + git := func(_ context.Context, _ string, args ...string) (gitResult, error) { + require.Equal(t, []string{"merge-base", "base123", "head123"}, args) + mergeBaseCalls++ + if mergeBaseCalls == 1 { + return gitFailure("fatal: no merge base") + } + return gitResult{Stdout: "base123\n"}, nil + } + var fetches []fetchSpec + fetch := func(_ context.Context, _ string, spec fetchSpec) (gitResult, error) { + fetches = append(fetches, spec) + return gitResult{}, nil + } + require.NoError(t, ensureRangeAvailable(t.Context(), &req, git, fetch)) + require.Equal(t, 2, mergeBaseCalls) + require.Equal(t, []fetchSpec{{Remote: "https://github.com/coder/coder.git", Ref: "base123"}}, fetches) +} + +func TestEnsureRangeAvailableReportsInvalidFetchSpecWithAttempts(t *testing.T) { + t.Parallel() + + req := runRequest{ + RepoRoot: "/repo", + Range: diffRange{BaseSHA: "base123", HeadSHA: "head123"}, + Fetches: []fetchSpec{ + {Remote: "", Ref: "refs/heads/main"}, + }, + } + git := func(_ context.Context, _ string, args ...string) (gitResult, error) { + require.Equal(t, []string{"merge-base", "base123", "head123"}, args) + return gitFailure("fatal: no merge base") + } + fetch := func(_ context.Context, _ string, spec fetchSpec) (gitResult, error) { + t.Fatalf("unexpected fetch for invalid spec: %+v", spec) + return gitResult{}, nil + } + err := ensureRangeAvailable(t.Context(), &req, git, fetch) + require.Error(t, err) + require.ErrorContains(t, err, "initial merge-base") + require.ErrorContains(t, err, "validate fetch spec refs/heads/main") + require.ErrorContains(t, err, "invalid fetch spec") +} + +func TestEnsureRangeAvailableReportsAllFetchFailures(t *testing.T) { + t.Parallel() + + req := runRequest{ + RepoRoot: "/repo", + Range: diffRange{BaseSHA: "base123", HeadSHA: "head123"}, + Fetches: []fetchSpec{ + {Remote: "https://github.com/coder/coder.git", Ref: "refs/heads/main"}, + {Remote: "https://github.com/coder/coder.git", Ref: "base123"}, + }, + } + git := func(_ context.Context, _ string, args ...string) (gitResult, error) { + require.Equal(t, []string{"merge-base", "base123", "head123"}, args) + return gitFailure("fatal: no merge base") + } + fetch := func(_ context.Context, _ string, spec fetchSpec) (gitResult, error) { + return gitResult{}, errors.New("fetch failed for " + spec.Ref) + } + err := ensureRangeAvailable(t.Context(), &req, git, fetch) + require.Error(t, err) + require.ErrorContains(t, err, "initial merge-base") + require.ErrorContains(t, err, "fetch refs/heads/main from https://github.com/coder/coder.git") + require.ErrorContains(t, err, "fetch base123 from https://github.com/coder/coder.git") +} + +func writeGitHubEvent(t *testing.T, content string) string { + t.Helper() + path := filepath.Join(t.TempDir(), "event.json") + require.NoError(t, os.WriteFile(path, []byte(content), 0o600)) + return path +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..0c26715 --- /dev/null +++ b/go.mod @@ -0,0 +1,11 @@ +module github.com/coder/whichtests + +go 1.26.2 + +require github.com/stretchr/testify v1.10.0 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..713a0b4 --- /dev/null +++ b/go.sum @@ -0,0 +1,10 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/helpers_test.go b/helpers_test.go new file mode 100644 index 0000000..ae801a3 --- /dev/null +++ b/helpers_test.go @@ -0,0 +1,114 @@ +package main + +import ( + "fmt" + "maps" + "os" + "path/filepath" + "slices" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func selectionNames(selection *packageSelection) []string { + if selection == nil { + return nil + } + return slices.Sorted(maps.Keys(selection.Tests)) +} + +// mustPackageInventory builds a packageInventory for the synthetic "pkg" +// directory and "sample" package used throughout the test suite. +func mustPackageInventory(t *testing.T, files map[string]string) packageInventory { + t.Helper() + const packageName = "sample" + inventory := packageInventory{ + Key: packageKey{Dir: "pkg", Name: packageName}, + Tests: map[string]struct{}{}, + } + for _, content := range files { + snapshot, err := parseFileSnapshot([]byte(content)) + require.NoError(t, err) + require.Equal(t, packageName, snapshot.packageName) + for testName := range snapshot.tests { + inventory.Tests[testName] = struct{}{} + } + } + return inventory +} + +func mustFileSnapshot(t *testing.T, data []byte) fileSnapshot { + t.Helper() + snapshot, err := parseFileSnapshot(data) + require.NoError(t, err) + return snapshot +} + +func mustOptionalFileSnapshot(t *testing.T, data []byte) *fileSnapshot { + t.Helper() + if data == nil { + return nil + } + snapshot := mustFileSnapshot(t, data) + return &snapshot +} + +func diffForChange(oldRange, newRange lineRange) string { + return fmt.Sprintf("@@ -%s +%s @@\n", formatDiffRange(oldRange), formatDiffRange(newRange)) +} + +func formatDiffRange(r lineRange) string { + if !r.hasLines() { + start := r.Start + if start == 0 { + start = 1 + } + return fmt.Sprintf("%d,0", start) + } + count := r.End - r.Start + 1 + if count == 1 { + return fmt.Sprintf("%d", r.Start) + } + return fmt.Sprintf("%d,%d", r.Start, count) +} + +func singleLineRange(t *testing.T, content, needle string) lineRange { + t.Helper() + line := lineNumberForSubstring(t, content, needle) + return lineRange{Start: line, End: line} +} + +func rangeSpan(start, end lineRange) lineRange { + return lineRange{Start: start.Start, End: end.End} +} + +func emptyRangeAt(start int) lineRange { + return lineRange{Start: start, End: start - 1} +} + +func lineNumberForSubstring(t *testing.T, content, needle string) int { + t.Helper() + lineNumber := 0 + for index, line := range strings.Split(content, "\n") { + if !strings.Contains(line, needle) { + continue + } + if lineNumber != 0 { + t.Fatalf("needle %q matched more than once", needle) + } + lineNumber = index + 1 + } + if lineNumber == 0 { + t.Fatalf("needle %q not found", needle) + } + return lineNumber +} + +func writeTestFile(t *testing.T, root, relativePath, content string) { + t.Helper() + path := filepath.Join(root, filepath.FromSlash(relativePath)) + require.NoError(t, os.MkdirAll(filepath.Dir(path), 0o750)) + require.NoError(t, os.WriteFile(path, []byte(content), 0o600)) +} diff --git a/integration_test.go b/integration_test.go new file mode 100644 index 0000000..eae49a8 --- /dev/null +++ b/integration_test.go @@ -0,0 +1,189 @@ +package main + +import ( + "bytes" + "encoding/json" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRunWithRealGitHandlesAddedFileAtRevision(t *testing.T) { + t.Parallel() + + requireGit(t) + repoRoot := t.TempDir() + runGit(t, repoRoot, "init") + runGit(t, repoRoot, "config", "user.email", "test@example.com") + runGit(t, repoRoot, "config", "user.name", "Test User") + runGit(t, repoRoot, "commit", "--allow-empty", "-m", "base") + baseSHA := strings.TrimSpace(runGit(t, repoRoot, "rev-parse", "HEAD")) + + writeTestFile(t, repoRoot, "pkg/new_test.go", `package sample + +import "testing" + +func TestAdded(t *testing.T) { + t.Log("added") +} +`) + runGit(t, repoRoot, "add", ".") + runGit(t, repoRoot, "commit", "-m", "head") + headSHA := strings.TrimSpace(runGit(t, repoRoot, "rev-parse", "HEAD")) + + matrixPath := filepath.Join(repoRoot, "matrix.json") + summaryPath := filepath.Join(repoRoot, "summary.md") + var stdout bytes.Buffer + var stderr bytes.Buffer + err := runCommand(t.Context(), commandConfig{config: config{RepoRoot: repoRoot, BaseSHA: baseSHA, HeadSHA: headSHA, OutMatrix: matrixPath, OutSummary: summaryPath}}, &stdout, &stderr, execGit, nil) + require.NoError(t, err) + + var matrix matrixOutput + matrixData, err := os.ReadFile(matrixPath) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(matrixData, &matrix)) + require.Len(t, matrix.Include, 1) + require.Equal(t, "./pkg", matrix.Include[0].Package) + require.Equal(t, "^(TestAdded)(/.*)?$", matrix.Include[0].RunRegex) + + summary, err := os.ReadFile(summaryPath) + require.NoError(t, err) + require.Contains(t, string(summary), "TestAdded") +} + +func TestRunWithRealGitHandlesDeletedSetupFile(t *testing.T) { + t.Parallel() + + requireGit(t) + repoRoot := t.TempDir() + runGit(t, repoRoot, "init") + runGit(t, repoRoot, "config", "user.email", "test@example.com") + runGit(t, repoRoot, "config", "user.name", "Test User") + writeTestFile(t, repoRoot, "pkg/setup_test.go", `package sample + +import "testing" + +func setup(t *testing.T) { + t.Helper() +} +`) + writeTestFile(t, repoRoot, "pkg/alpha_test.go", `package sample + +import "testing" + +func TestAlpha(t *testing.T) { + t.Log("alpha") +} +`) + runGit(t, repoRoot, "add", ".") + runGit(t, repoRoot, "commit", "-m", "base") + baseSHA := strings.TrimSpace(runGit(t, repoRoot, "rev-parse", "HEAD")) + + runGit(t, repoRoot, "rm", "pkg/setup_test.go") + runGit(t, repoRoot, "commit", "-m", "head") + headSHA := strings.TrimSpace(runGit(t, repoRoot, "rev-parse", "HEAD")) + + matrixPath := filepath.Join(repoRoot, "matrix.json") + summaryPath := filepath.Join(repoRoot, "summary.md") + var stdout bytes.Buffer + var stderr bytes.Buffer + err := runCommand(t.Context(), commandConfig{config: config{RepoRoot: repoRoot, BaseSHA: baseSHA, HeadSHA: headSHA, OutMatrix: matrixPath, OutSummary: summaryPath}}, &stdout, &stderr, execGit, nil) + require.NoError(t, err) + + var matrix matrixOutput + matrixData, err := os.ReadFile(matrixPath) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(matrixData, &matrix)) + require.Len(t, matrix.Include, 1) + require.Equal(t, "./pkg", matrix.Include[0].Package) + require.Equal(t, "^(TestAlpha)(/.*)?$", matrix.Include[0].RunRegex) + + summary, err := os.ReadFile(summaryPath) + require.NoError(t, err) + require.Contains(t, string(summary), "pkg/setup_test.go") + require.Contains(t, string(summary), "TestAlpha") +} + +func TestEnsureRangeAvailableWithRealGitFetchesMovedBase(t *testing.T) { + t.Parallel() + + requireGit(t) + root := t.TempDir() + workRoot := filepath.Join(root, "work") + bareRoot := filepath.Join(root, "upstream.git") + cloneRoot := filepath.Join(root, "clone") + require.NoError(t, os.MkdirAll(workRoot, 0o750)) + runGit(t, workRoot, "init") + runGit(t, workRoot, "config", "user.email", "test@example.com") + runGit(t, workRoot, "config", "user.name", "Test User") + writeTestFile(t, workRoot, "pkg/sample_test.go", `package sample + +import "testing" + +func TestAlpha(t *testing.T) { + t.Log("base") +} +`) + runGit(t, workRoot, "add", ".") + runGit(t, workRoot, "commit", "-m", "base") + runGit(t, workRoot, "branch", "-M", "main") + + runGit(t, workRoot, "checkout", "-b", "feature") + writeTestFile(t, workRoot, "pkg/sample_test.go", `package sample + +import "testing" + +func TestAlpha(t *testing.T) { + t.Log("feature") +} +`) + runGit(t, workRoot, "commit", "-am", "feature") + headSHA := strings.TrimSpace(runGit(t, workRoot, "rev-parse", "HEAD")) + + runGit(t, workRoot, "checkout", "main") + writeTestFile(t, workRoot, "README.md", "base branch moved\n") + runGit(t, workRoot, "add", "README.md") + runGit(t, workRoot, "commit", "-m", "move base") + baseSHA := strings.TrimSpace(runGit(t, workRoot, "rev-parse", "HEAD")) + + runGit(t, workRoot, "init", "--bare", bareRoot) + runGit(t, workRoot, "remote", "add", "origin", bareRoot) + runGit(t, workRoot, "push", "origin", "main", "feature") + runGit(t, root, "clone", "--single-branch", "--branch", "feature", "file://"+bareRoot, cloneRoot) + _, err := execGit(t.Context(), cloneRoot, "cat-file", "-e", baseSHA+"^{commit}") + require.Error(t, err) + + req := runRequest{ + RepoRoot: cloneRoot, + Range: diffRange{BaseSHA: baseSHA, HeadSHA: headSHA}, + Fetches: []fetchSpec{ + {Remote: "origin", Ref: remoteTrackingRefspec(defaultDispatchBaseRef)}, + {Remote: "origin", Ref: baseSHA}, + }, + } + err = ensureRangeAvailable(t.Context(), &req, execGit, execGitFetch) + require.NoError(t, err) + changed := strings.TrimSpace(runGit(t, cloneRoot, "diff", "--name-only", baseSHA+"..."+headSHA)) + require.Equal(t, "pkg/sample_test.go", changed) +} + +func requireGit(t *testing.T) { + t.Helper() + if _, err := exec.LookPath("git"); err != nil { + t.Skipf("git is not available on PATH: %v", err) + } +} + +func runGit(t *testing.T, dir string, args ...string) string { + t.Helper() + cmd := exec.Command("git", args...) + cmd.Dir = dir + cmd.Env = append(os.Environ(), "LC_ALL=C", "LANG=C") + output, err := cmd.CombinedOutput() + require.NoErrorf(t, err, "git %s failed: %s", strings.Join(args, " "), string(output)) + return string(output) +} diff --git a/inventory.go b/inventory.go new file mode 100644 index 0000000..6456746 --- /dev/null +++ b/inventory.go @@ -0,0 +1,226 @@ +package main + +import ( + "cmp" + "context" + "fmt" + "maps" + "path/filepath" + "slices" + "strings" +) + +// inventoryCache owns repository facts for a single run. Returned package +// inventories and parsed snapshots alias cached maps and slices, so callers must +// treat them as read-only unless they clone before mutating. +type inventoryCache struct { + cfg config + git gitRunner + validRevisions map[string]struct{} + files map[revisionFileKey]cachedFile + fileLists map[string][]string + packages map[string]packageInventory +} + +type revisionFileKey struct { + Revision string + Path string +} + +type cachedFile struct { + existenceKnown bool + exists bool + parsed bool + snapshot parsedFileSnapshot +} + +func newInventoryCache(cfg config, git gitRunner) *inventoryCache { + return &inventoryCache{ + cfg: cfg, + git: git, + validRevisions: map[string]struct{}{}, + files: map[revisionFileKey]cachedFile{}, + fileLists: map[string][]string{}, + packages: map[string]packageInventory{}, + } +} + +func (cache *inventoryCache) ensureRevisionExists(ctx context.Context, revision string) error { + if _, ok := cache.validRevisions[revision]; ok { + return nil + } + if err := ensureRevisionExists(ctx, cache.cfg, cache.git, revision); err != nil { + return err + } + cache.validRevisions[revision] = struct{}{} + return nil +} + +func (cache *inventoryCache) noteFileExists(revision, filePath string) { + key := revisionFileKey{Revision: revision, Path: cleanGitPath(filePath)} + file := cache.files[key] + file.existenceKnown = true + file.exists = true + cache.files[key] = file +} + +// parseFileAtRevision returns a parsed snapshot for an existing file. The +// returned snapshot aliases cache state and must be treated as read-only. +func (cache *inventoryCache) parseFileAtRevision(ctx context.Context, revision, filePath string) (parsedFileSnapshot, bool, error) { + key := revisionFileKey{Revision: revision, Path: cleanGitPath(filePath)} + file := cache.files[key] + if file.parsed { + return file.snapshot, true, nil + } + if err := cache.ensureRevisionExists(ctx, revision); err != nil { + return parsedFileSnapshot{}, false, err + } + if !file.existenceKnown { + exists, err := fileExistsAtRevision(ctx, cache.cfg, cache.git, revision, key.Path) + if err != nil { + return parsedFileSnapshot{}, false, err + } + file.existenceKnown = true + file.exists = exists + cache.files[key] = file + } + if !file.exists { + return parsedFileSnapshot{}, false, nil + } + + result, err := cache.git(ctx, cache.cfg.RepoRoot, "show", revision+":"+key.Path) + if err != nil { + return parsedFileSnapshot{}, false, fmt.Errorf("read %s at %s: %w", key.Path, revision, err) + } + parsed, err := parseSnapshotForPath(key.Path, []byte(result.Stdout)) + if err != nil { + return parsedFileSnapshot{}, true, fmt.Errorf("parse %s at %s: %w", key.Path, revision, err) + } + file.parsed = true + file.snapshot = parsed + cache.files[key] = file + return parsed, true, nil +} + +func (cache *inventoryCache) parseChangeFileAtRevision(ctx context.Context, revision, filePath string) (parsedFileSnapshot, bool, error) { + if filePath == "" || !isRunnableTestFilePath(filePath) { + return parsedFileSnapshot{}, false, nil + } + return cache.parseFileAtRevision(ctx, revision, filePath) +} + +// loadPackageInventory returns an inventory whose maps alias cache state. Callers +// must treat the result as read-only or clone maps before mutating. +func (cache *inventoryCache) loadPackageInventory(ctx context.Context, revision string, key packageKey) (packageInventory, error) { + cacheKey := revision + "\x00" + key.Dir + "\x00" + key.Name + if inventory, ok := cache.packages[cacheKey]; ok { + return inventory, nil + } + + files, err := cache.listTestFilesInDir(ctx, revision, key.Dir) + if err != nil { + return packageInventory{}, err + } + inventory := packageInventory{ + Key: key, + Tests: map[string]struct{}{}, + } + for _, filePath := range files { + parsed, exists, err := cache.parseFileAtRevision(ctx, revision, filePath) + if err != nil { + return packageInventory{}, err + } + if !exists || parsed.Snapshot.packageName != key.Name { + continue + } + for testName := range parsed.Snapshot.tests { + inventory.Tests[testName] = struct{}{} + } + } + cache.packages[cacheKey] = inventory + return inventory, nil +} + +func (cache *inventoryCache) listTestFilesInDir(ctx context.Context, revision, dir string) ([]string, error) { + cleanDir := filepath.ToSlash(filepath.Clean(dir)) + cacheKey := revision + "\x00" + cleanDir + if files, ok := cache.fileLists[cacheKey]; ok { + return files, nil + } + if err := cache.ensureRevisionExists(ctx, revision); err != nil { + return nil, err + } + pathspec := cmp.Or(cleanDir, ".") + result, err := cache.git(ctx, cache.cfg.RepoRoot, "ls-tree", "-r", "-z", "--name-only", revision, "--", pathspec) + if err != nil { + return nil, err + } + files := make([]string, 0) + for part := range strings.SplitSeq(result.Stdout, "\x00") { + if part == "" { + continue + } + filePath := cleanGitPath(part) + if !isRunnableTestFilePath(filePath) { + continue + } + if filepath.ToSlash(filepath.Dir(filePath)) != cleanDir { + continue + } + files = append(files, filePath) + cache.noteFileExists(revision, filePath) + } + slices.Sort(files) + cache.fileLists[cacheKey] = files + return files, nil +} + +func (cache *inventoryCache) directoryWideSelections(ctx context.Context, revision, dir string, files map[string]struct{}) ([]*packageSelection, error) { + inventories, err := cache.loadDirectoryInventories(ctx, revision, dir) + if err != nil { + return nil, err + } + selections := make([]*packageSelection, 0, len(inventories)) + for _, inventory := range inventories { + selection := allPackageTestsSelectionForFiles(inventory, maps.Clone(files)) + if selection == nil { + continue + } + selections = append(selections, selection) + } + return selections, nil +} + +func (cache *inventoryCache) loadDirectoryInventories(ctx context.Context, revision, dir string) ([]packageInventory, error) { + files, err := cache.listTestFilesInDir(ctx, revision, dir) + if err != nil { + return nil, err + } + packageNames := map[string]struct{}{} + for _, filePath := range files { + parsed, exists, err := cache.parseFileAtRevision(ctx, revision, filePath) + if err != nil { + return nil, err + } + if !exists { + continue + } + packageNames[parsed.Snapshot.packageName] = struct{}{} + } + keys := make([]packageKey, 0, len(packageNames)) + for packageName := range packageNames { + keys = append(keys, packageKey{Dir: filepath.ToSlash(filepath.Clean(dir)), Name: packageName}) + } + slices.SortFunc(keys, func(left, right packageKey) int { + return cmp.Compare(left.Name, right.Name) + }) + inventories := make([]packageInventory, 0, len(keys)) + for _, key := range keys { + inventory, err := cache.loadPackageInventory(ctx, revision, key) + if err != nil { + return nil, err + } + inventories = append(inventories, inventory) + } + return inventories, nil +} diff --git a/inventory_test.go b/inventory_test.go new file mode 100644 index 0000000..190d5c1 --- /dev/null +++ b/inventory_test.go @@ -0,0 +1,122 @@ +package main + +import ( + "context" + "maps" + "slices" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestLoadPackageInventoryReturnsParseErrors(t *testing.T) { + t.Parallel() + + repo := fakeGitRepo{ + revisions: map[string]map[string]string{ + "head": { + "pkg/good_test.go": `package sample + +import "testing" + +func TestGood(t *testing.T) {} +`, + "pkg/broken_test.go": `package sample + +import "testing" + +func TestBroken(t *testing.T) { +`, + }, + }, + } + cache := newInventoryCache(config{RepoRoot: "/repo"}, repo.runner(t)) + _, err := cache.loadPackageInventory(t.Context(), "head", packageKey{Dir: "pkg", Name: "sample"}) + require.ErrorContains(t, err, "parse pkg/broken_test.go at head") +} + +func TestLoadPackageInventoryCachesResults(t *testing.T) { + t.Parallel() + + repo := fakeGitRepo{ + revisions: map[string]map[string]string{ + "head": { + "pkg/alpha_test.go": `package sample + +import "testing" + +func TestAlpha(t *testing.T) {} +`, + }, + }, + } + counter := newCountingGitRunner(repo.runner(t)) + cache := newInventoryCache(config{RepoRoot: "/repo"}, counter.run) + key := packageKey{Dir: "pkg", Name: "sample"} + inventory, err := cache.loadPackageInventory(t.Context(), "head", key) + require.NoError(t, err) + require.Equal(t, []string{"TestAlpha"}, slices.Sorted(maps.Keys(inventory.Tests))) + firstCommandCount := counter.total + + inventory, err = cache.loadPackageInventory(t.Context(), "head", key) + require.NoError(t, err) + require.Equal(t, []string{"TestAlpha"}, slices.Sorted(maps.Keys(inventory.Tests))) + require.Equal(t, firstCommandCount, counter.total) +} + +func TestLoadDirectoryInventoriesSharesSnapshotsAcrossPackages(t *testing.T) { + t.Parallel() + + repo := fakeGitRepo{ + revisions: map[string]map[string]string{ + "head": { + "pkg/internal_test.go": `package foo + +import "testing" + +func TestInternal(t *testing.T) {} +`, + "pkg/external_test.go": `package foo_test + +import "testing" + +func TestExternal(t *testing.T) {} +`, + }, + }, + } + counter := newCountingGitRunner(repo.runner(t)) + cache := newInventoryCache(config{RepoRoot: "/repo"}, counter.run) + inventories, err := cache.loadDirectoryInventories(t.Context(), "head", "pkg") + require.NoError(t, err) + require.Len(t, inventories, 2) + require.Equal(t, []packageKey{ + {Dir: "pkg", Name: "foo"}, + {Dir: "pkg", Name: "foo_test"}, + }, []packageKey{inventories[0].Key, inventories[1].Key}) + require.Equal(t, 1, counter.counts["cat-file"]) + require.Equal(t, 1, counter.counts["ls-tree"]) + require.Equal(t, 2, counter.counts["show"]) + firstCommandCount := counter.total + + inventories, err = cache.loadDirectoryInventories(t.Context(), "head", "pkg") + require.NoError(t, err) + require.Len(t, inventories, 2) + require.Equal(t, firstCommandCount, counter.total) +} + +type countingGitRunner struct { + runner gitRunner + counts map[string]int + total int +} + +func newCountingGitRunner(runner gitRunner) *countingGitRunner { + return &countingGitRunner{runner: runner, counts: map[string]int{}} +} + +func (counter *countingGitRunner) run(ctx context.Context, dir string, args ...string) (gitResult, error) { + counter.counts[args[0]]++ + counter.total++ + return counter.runner(ctx, dir, args...) +} diff --git a/plan.go b/plan.go new file mode 100644 index 0000000..8ea8c06 --- /dev/null +++ b/plan.go @@ -0,0 +1,288 @@ +package main + +import ( + "context" + "fmt" + "maps" + "regexp" + "slices" + "strconv" + "strings" +) + +var ( + safeTestNameRE = regexp.MustCompile(`^[A-Za-z0-9_]+$`) + safePackagePatternRE = regexp.MustCompile(`^(?:\.|\./[A-Za-z0-9._/-]+)$`) +) + +type matrixOutput struct { + Include []matrixEntry `json:"include"` +} + +// matrixEntry.Package is a single safe package token except for overflow rows, +// where it is a space-separated list of safe package tokens consumed by the +// flake-go workflow. +type matrixEntry struct { + Package string `json:"package"` + RunRegex string `json:"run_regex,omitempty"` + TestCount string `json:"test_count"` +} + +type summaryReport struct { + Entries []summaryEntry +} + +type summaryEntry struct { + Label string + Files []string + Tests []string + RunAll bool + TestCount string + Notes []string +} + +type buildResult struct { + Matrix matrixOutput + Summary summaryReport +} + +type executionAccumulator struct { + Files map[string]struct{} + Tests map[string]struct{} + Broadened bool + RunAll bool + TestCount string + Notes []string +} + +func selectTestPlan(ctx context.Context, cfg config, git gitRunner) ([]string, buildResult, error) { + changes, err := listChangedTestFiles(ctx, cfg, git) + if err != nil { + return nil, buildResult{}, err + } + changedFiles := make([]string, 0, len(changes)) + for _, change := range changes { + changedFiles = append(changedFiles, change.displayPath()) + } + + cache := newInventoryCache(cfg, git) + selections := map[packageKey]*packageSelection{} + for _, change := range changes { + if err = selectChange(ctx, cache, selections, change); err != nil { + return nil, buildResult{}, err + } + } + + result, err := buildExecutionPlan(selections) + if err != nil { + return nil, buildResult{}, err + } + return changedFiles, result, nil +} + +func buildExecutionPlan(selections map[packageKey]*packageSelection) (buildResult, error) { + accumulators := map[string]*executionAccumulator{} + for key, selection := range selections { + packagePath := packagePattern(key.Dir) + if !isSafePackagePattern(packagePath) { + return buildResult{}, fmt.Errorf("unsafe package path %q", packagePath) + } + entry := accumulators[packagePath] + if entry == nil { + entry = &executionAccumulator{ + Files: map[string]struct{}{}, + Tests: map[string]struct{}{}, + TestCount: defaultTestCount, + } + accumulators[packagePath] = entry + } + entry.Broadened = entry.Broadened || selection.Broadened + maps.Copy(entry.Files, selection.Files) + maps.Copy(entry.Tests, selection.Tests) + } + + orderedPackages := slices.Sorted(maps.Keys(accumulators)) + result := buildResult{Matrix: matrixOutput{Include: []matrixEntry{}}} + for _, packagePath := range orderedPackages { + entry := accumulators[packagePath] + tests := slices.Sorted(maps.Keys(entry.Tests)) + files := slices.Sorted(maps.Keys(entry.Files)) + if entry.Broadened && len(tests) > maxBroadenedTests { + entry.RunAll = true + entry.TestCount = runOnceTestCount + entry.Notes = append(entry.Notes, fmt.Sprintf("Package-wide broadening selected %d tests, above the %d-test cap, so this target will run all tests once.", len(tests), maxBroadenedTests)) + } + if unsafeTestCount := unsafeRunRegexTestCount(tests); unsafeTestCount > 0 { + entry.RunAll = true + entry.TestCount = runOnceTestCount + entry.Notes = append(entry.Notes, fmt.Sprintf("Selected %d test names that cannot be passed safely through RUN, so this target will run all tests once.", unsafeTestCount)) + } + runRegex := "" + if !entry.RunAll { + runRegex = buildRunRegex(tests) + } + result.Matrix.Include = append(result.Matrix.Include, matrixEntry{ + Package: packagePath, + RunRegex: runRegex, + TestCount: entry.TestCount, + }) + result.Summary.Entries = append(result.Summary.Entries, summaryEntry{ + Label: packagePath, + Files: files, + Tests: tests, + RunAll: entry.RunAll, + TestCount: entry.TestCount, + Notes: entry.Notes, + }) + } + + if len(result.Matrix.Include) > maxMatrixEntries { + keep := maxMatrixEntries - 1 + overflowPackages := make([]string, 0, len(result.Matrix.Include)-keep) + overflowFiles := map[string]struct{}{} + for _, entry := range result.Matrix.Include[keep:] { + overflowPackages = append(overflowPackages, entry.Package) + } + for _, entry := range result.Summary.Entries[keep:] { + for _, filePath := range entry.Files { + overflowFiles[filePath] = struct{}{} + } + } + note := fmt.Sprintf("Matrix target cap %d hit. Collapsed %d additional packages into one overflow target that runs once.", maxMatrixEntries, len(overflowPackages)) + result.Matrix.Include = result.Matrix.Include[:keep] + result.Matrix.Include = append(result.Matrix.Include, matrixEntry{ + Package: strings.Join(overflowPackages, " "), + TestCount: runOnceTestCount, + }) + result.Summary.Entries = result.Summary.Entries[:keep] + result.Summary.Entries = append(result.Summary.Entries, summaryEntry{ + Label: fmt.Sprintf("overflow target (%d packages)", len(overflowPackages)), + Files: slices.Sorted(maps.Keys(overflowFiles)), + RunAll: true, + TestCount: runOnceTestCount, + Notes: []string{ + note, + summarizePackages(overflowPackages), + }, + }) + } + + return result, nil +} + +func summarizePackages(packages []string) string { + display := packages + if len(display) > maxOverflowSummaries { + display = display[:maxOverflowSummaries] + } + quoted := make([]string, 0, len(display)) + for _, packagePath := range display { + quoted = append(quoted, "`"+packagePath+"`") + } + note := "Packages: " + strings.Join(quoted, ", ") + if len(packages) > len(display) { + note += fmt.Sprintf(", and %d more.", len(packages)-len(display)) + } + return note +} + +func isSafePackagePattern(packagePath string) bool { + if !safePackagePatternRE.MatchString(packagePath) { + return false + } + if packagePath == "." { + return true + } + trimmed, ok := strings.CutPrefix(packagePath, "./") + if !ok { + return false + } + for segment := range strings.SplitSeq(trimmed, "/") { + if segment == ".." { + return false + } + } + return true +} + +func unsafeRunRegexTestCount(tests []string) int { + count := 0 + for _, testName := range tests { + if !safeTestNameRE.MatchString(testName) { + count++ + } + } + return count +} + +func buildRunRegex(tests []string) string { + quoted := make([]string, 0, len(tests)) + for _, testName := range tests { + quoted = append(quoted, regexp.QuoteMeta(testName)) + } + return "^(" + strings.Join(quoted, "|") + ")(/.*)?$" +} + +func renderSummary(changedFiles []string, summary summaryReport) string { + var builder strings.Builder + _, _ = builder.WriteString("## Go test flake detector selection\n\n") + if len(changedFiles) == 0 { + _, _ = builder.WriteString("No changed `*_test.go` files were detected.\n") + return builder.String() + } + if len(summary.Entries) == 0 { + _, _ = builder.WriteString("Changed `*_test.go` files were detected, but no runnable top-level tests were selected.\n\n") + _, _ = builder.WriteString("Files:\n") + for _, filePath := range changedFiles { + _, _ = builder.WriteString("- " + renderSummaryFilePath(filePath) + "\n") + } + return builder.String() + } + + totalTests := 0 + for _, entry := range summary.Entries { + totalTests += len(entry.Tests) + } + _, _ = fmt.Fprintf(&builder, "Selected %d tests across %d package targets.\n\n", totalTests, len(summary.Entries)) + for _, entry := range summary.Entries { + _, _ = builder.WriteString("### `" + entry.Label + "`\n\n") + _, _ = builder.WriteString("Files:\n") + for _, filePath := range entry.Files { + _, _ = builder.WriteString("- " + renderSummaryFilePath(filePath) + "\n") + } + if len(entry.Notes) > 0 { + _, _ = builder.WriteString("\nNotes:\n") + for _, note := range entry.Notes { + _, _ = builder.WriteString("- " + note + "\n") + } + } + if entry.RunAll { + _, _ = builder.WriteString("\nRuns all tests in this target " + countDescription(entry.TestCount) + ".\n") + if len(entry.Tests) > 0 { + _, _ = builder.WriteString("\nAttributed tests:\n") + for _, testName := range entry.Tests { + _, _ = builder.WriteString("- `" + testName + "`\n") + } + } + _, _ = builder.WriteString("\n") + continue + } + _, _ = builder.WriteString("\nTests:\n") + for _, testName := range entry.Tests { + _, _ = builder.WriteString("- `" + testName + "`\n") + } + _, _ = builder.WriteString("\n") + } + return builder.String() +} + +func renderSummaryFilePath(filePath string) string { + return strconv.QuoteToASCII(filePath) +} + +func countDescription(count string) string { + if count == "1" { + return "once" + } + return count + " times" +} diff --git a/plan_test.go b/plan_test.go new file mode 100644 index 0000000..4fa8e12 --- /dev/null +++ b/plan_test.go @@ -0,0 +1,153 @@ +package main + +import ( + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRenderSummaryNoChangedFiles(t *testing.T) { + t.Parallel() + + summary := renderSummary(nil, summaryReport{}) + require.Contains(t, summary, "No changed `*_test.go` files were detected") +} + +func TestRenderSummaryNoRunnableTests(t *testing.T) { + t.Parallel() + + summary := renderSummary([]string{"pkg/changed_test.go"}, summaryReport{}) + require.Contains(t, summary, "no runnable top-level tests were selected") + require.Contains(t, summary, "pkg/changed_test.go") +} + +func TestRenderSummaryQuotesFilenames(t *testing.T) { + t.Parallel() + + summary := renderSummary([]string{"pkg/with`tick_test.go", "pkg/with\nnewline_test.go"}, summaryReport{}) + require.Contains(t, summary, `"pkg/with`+"`"+`tick_test.go"`) + require.Contains(t, summary, `"pkg/with\nnewline_test.go"`) + require.NotContains(t, summary, "pkg/with\nnewline_test.go") +} + +func TestBuildExecutionPlanRunsAllForUnsafeTestNames(t *testing.T) { + t.Parallel() + + selection := &packageSelection{ + Key: packageKey{Dir: "pkg", Name: "sample"}, + Tests: map[string]struct{}{"TestAlpha": {}, "TestĪ›": {}}, + Files: map[string]struct{}{"pkg/sample_test.go": {}}, + } + result, err := buildExecutionPlan(map[packageKey]*packageSelection{selection.Key: selection}) + require.NoError(t, err) + require.Len(t, result.Matrix.Include, 1) + require.Empty(t, result.Matrix.Include[0].RunRegex) + require.Equal(t, "1", result.Matrix.Include[0].TestCount) + require.True(t, result.Summary.Entries[0].RunAll) + require.Contains(t, result.Summary.Entries[0].Notes[0], "cannot be passed safely") +} + +func TestBuildExecutionPlanRejectsUnsafePackagePaths(t *testing.T) { + t.Parallel() + + key := packageKey{Dir: "pkg$(echo bad)", Name: "sample"} + _, err := buildExecutionPlan(map[packageKey]*packageSelection{ + key: { + Key: key, + Tests: map[string]struct{}{"TestAlpha": {}}, + Files: map[string]struct{}{"pkg$(echo bad)/sample_test.go": {}}, + }, + }) + require.ErrorContains(t, err, "unsafe package path") +} + +func TestIsSafePackagePatternAllowsSafeNamesAndRejectsTraversal(t *testing.T) { + t.Parallel() + + for _, packagePath := range []string{".", "./foo_bar", "./foo-bar", "./foo.bar", "./foo/bar_baz"} { + require.True(t, isSafePackagePattern(packagePath), packagePath) + } + for _, packagePath := range []string{"./foo/../bar", "./..", "./foo/..", "../foo", "./foo bar"} { + require.False(t, isSafePackagePattern(packagePath), packagePath) + } +} + +func TestBuildExecutionPlanCapsBroadenedTarget(t *testing.T) { + t.Parallel() + + selection := &packageSelection{ + Key: packageKey{Dir: "pkg", Name: "sample"}, + Tests: map[string]struct{}{}, + Files: map[string]struct{}{"pkg/setup_test.go": {}}, + Broadened: true, + } + for index := range maxBroadenedTests + 1 { + selection.Tests[fmt.Sprintf("Test%03d", index)] = struct{}{} + } + result, err := buildExecutionPlan(map[packageKey]*packageSelection{selection.Key: selection}) + require.NoError(t, err) + require.Len(t, result.Matrix.Include, 1) + require.Equal(t, "1", result.Matrix.Include[0].TestCount) + require.Empty(t, result.Matrix.Include[0].RunRegex) + require.True(t, result.Summary.Entries[0].RunAll) + require.Contains(t, result.Summary.Entries[0].Notes[0], "above the 50-test cap") +} + +func TestBuildExecutionPlanCapsMatrixTargets(t *testing.T) { + t.Parallel() + + selections := map[packageKey]*packageSelection{} + for index := range maxMatrixEntries + maxOverflowSummaries + 2 { + key := packageKey{Dir: fmt.Sprintf("pkg%02d", index), Name: "sample"} + selections[key] = &packageSelection{ + Key: key, + Tests: map[string]struct{}{fmt.Sprintf("Test%02d", index): {}}, + Files: map[string]struct{}{fmt.Sprintf("pkg%02d/file_test.go", index): {}}, + } + } + result, err := buildExecutionPlan(selections) + require.NoError(t, err) + require.Len(t, result.Matrix.Include, maxMatrixEntries) + overflow := result.Matrix.Include[len(result.Matrix.Include)-1] + require.Equal(t, strings.Join([]string{ + "./pkg19", "./pkg20", "./pkg21", "./pkg22", "./pkg23", "./pkg24", "./pkg25", + "./pkg26", "./pkg27", "./pkg28", "./pkg29", "./pkg30", "./pkg31", + }, " "), overflow.Package) + require.Empty(t, overflow.RunRegex) + require.Equal(t, "1", overflow.TestCount) + for _, packagePath := range strings.Fields(overflow.Package) { + require.True(t, isSafePackagePattern(packagePath), packagePath) + } + overflowSummary := result.Summary.Entries[len(result.Summary.Entries)-1] + require.Contains(t, overflowSummary.Notes[0], "Matrix target cap") + require.Contains(t, overflowSummary.Notes[1], "and 3 more") + summary := renderSummary([]string{"pkg00/file_test.go"}, result.Summary) + require.Equal(t, 1, strings.Count(summary, "Matrix target cap")) +} + +func TestBuildExecutionPlanKeepsSameNamePackageAndExternalTestsPrecise(t *testing.T) { + t.Parallel() + + selections := map[packageKey]*packageSelection{ + {Dir: "pkg", Name: "sample"}: { + Key: packageKey{Dir: "pkg", Name: "sample"}, + Tests: map[string]struct{}{"TestShared": {}}, + Files: map[string]struct{}{"pkg/internal_test.go": {}}, + }, + {Dir: "pkg", Name: "sample_test"}: { + Key: packageKey{Dir: "pkg", Name: "sample_test"}, + Tests: map[string]struct{}{"TestShared": {}}, + Files: map[string]struct{}{"pkg/external_test.go": {}}, + }, + } + result, err := buildExecutionPlan(selections) + require.NoError(t, err) + require.Len(t, result.Matrix.Include, 1) + require.Equal(t, "./pkg", result.Matrix.Include[0].Package) + require.Equal(t, "^(TestShared)(/.*)?$", result.Matrix.Include[0].RunRegex) + require.Equal(t, "10", result.Matrix.Include[0].TestCount) + require.False(t, result.Summary.Entries[0].RunAll) + require.Empty(t, result.Summary.Entries[0].Notes) +} diff --git a/publish.go b/publish.go new file mode 100644 index 0000000..30d4502 --- /dev/null +++ b/publish.go @@ -0,0 +1,114 @@ +package main + +import ( + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "strings" +) + +const defaultGitHubOutputValueLimit = 1024 * 1024 + +func publishPlan(sinks outputSinks, matrix matrixOutput, summary string, stdout io.Writer) error { + matrixData, err := marshalMatrix(matrix) + if err != nil { + return err + } + if sinks.OutMatrix != "" { + if err := writeFile(sinks.OutMatrix, append(matrixData, '\n')); err != nil { + return err + } + } + if sinks.OutSummary != "" { + if err := writeSummary(sinks.OutSummary, summary, stdout); err != nil { + return err + } + } + if sinks.GitHubOutput != "" { + if err := appendGitHubOutput(sinks.GitHubOutput, "matrix", string(matrixData)); err != nil { + return err + } + } + if sinks.GitHubStepSummary != "" { + if err := appendFile(sinks.GitHubStepSummary, []byte(summary)); err != nil { + return err + } + } + return nil +} + +func marshalMatrix(matrix matrixOutput) ([]byte, error) { + if matrix.Include == nil { + matrix.Include = []matrixEntry{} + } + data, err := json.Marshal(matrix) + if err != nil { + return nil, fmt.Errorf("marshal matrix json: %w", err) + } + return data, nil +} + +func appendGitHubOutput(path, name, value string) error { + if err := ensureGitHubOutputFits(name, value, defaultGitHubOutputValueLimit); err != nil { + return err + } + return appendFile(path, []byte(name+"="+value+"\n")) +} + +func ensureGitHubOutputFits(name, value string, limit int) error { + if strings.ContainsAny(value, "\r\n") { + return fmt.Errorf("GitHub output %s must be a single line", name) + } + if len(value) > limit { + return fmt.Errorf("GitHub output %s is %d bytes, above the %d byte limit", name, len(value), limit) + } + return nil +} + +func writeSummary(path, summary string, stdout io.Writer) error { + if path == "-" { + _, err := io.WriteString(stdout, summary) + return err + } + return writeFile(path, []byte(summary)) +} + +func writeFile(path string, data []byte) error { + dir := filepath.Dir(path) + if dir != "." { + if err := os.MkdirAll(dir, 0o750); err != nil { + return fmt.Errorf("mkdir %s: %w", dir, err) + } + } + if err := os.WriteFile(path, data, 0o600); err != nil { + return fmt.Errorf("write %s: %w", path, err) + } + return nil +} + +func appendFile(path string, data []byte) (err error) { + dir := filepath.Dir(path) + if dir != "." { + if err = os.MkdirAll(dir, 0o750); err != nil { + return fmt.Errorf("mkdir %s: %w", dir, err) + } + } + // #nosec G304: path is a user-supplied output path or a GitHub Actions runner path. + file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o600) + if err != nil { + return fmt.Errorf("open %s: %w", path, err) + } + defer func() { + // Surface Close errors only if Write succeeded; write paths can + // lose data on a deferred fsync/flush failure. + if cerr := file.Close(); cerr != nil && err == nil { + err = fmt.Errorf("close %s: %w", path, cerr) + } + }() + if _, err := file.Write(data); err != nil { + return fmt.Errorf("append %s: %w", path, err) + } + return nil +} diff --git a/publish_test.go b/publish_test.go new file mode 100644 index 0000000..26eb816 --- /dev/null +++ b/publish_test.go @@ -0,0 +1,60 @@ +package main + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPublishPlanWritesCompactGitHubOutputs(t *testing.T) { + t.Parallel() + + root := t.TempDir() + matrixPath := filepath.Join(root, "matrix.json") + summaryPath := filepath.Join(root, "summary.md") + outputPath := filepath.Join(root, "output.txt") + stepSummaryPath := filepath.Join(root, "step-summary.md") + summary := "## Summary\n" + err := publishPlan(outputSinks{ + OutMatrix: matrixPath, + OutSummary: summaryPath, + GitHubOutput: outputPath, + GitHubStepSummary: stepSummaryPath, + }, matrixOutput{Include: []matrixEntry{{Package: "./pkg", RunRegex: "^(TestAlpha)(/.*)?$", TestCount: "10"}}}, summary, nil) + require.NoError(t, err) + + matrixData, err := os.ReadFile(matrixPath) + require.NoError(t, err) + wantMatrix := `{"include":[{"package":"./pkg","run_regex":"^(TestAlpha)(/.*)?$","test_count":"10"}]}` + require.Equal(t, wantMatrix+"\n", string(matrixData)) + + outputData, err := os.ReadFile(outputPath) + require.NoError(t, err) + require.Equal(t, "matrix="+wantMatrix+"\n", string(outputData)) + outputValue := strings.TrimSuffix(strings.TrimPrefix(string(outputData), "matrix="), "\n") + require.NotContains(t, outputValue, "\n") + + localSummary, err := os.ReadFile(summaryPath) + require.NoError(t, err) + require.Equal(t, summary, string(localSummary)) + stepSummary, err := os.ReadFile(stepSummaryPath) + require.NoError(t, err) + require.Equal(t, summary, string(stepSummary)) +} + +func TestPublishPlanWritesEmptyMatrixAndRejectsUnsafeOutput(t *testing.T) { + t.Parallel() + + matrixData, err := marshalMatrix(matrixOutput{}) + require.NoError(t, err) + require.Equal(t, `{"include":[]}`, string(matrixData)) + + err = ensureGitHubOutputFits("matrix", "first\nsecond", defaultGitHubOutputValueLimit) + require.ErrorContains(t, err, "single line") + + err = ensureGitHubOutputFits("matrix", "too-long", 3) + require.ErrorContains(t, err, "above the 3 byte limit") +} diff --git a/request.go b/request.go new file mode 100644 index 0000000..e5976b1 --- /dev/null +++ b/request.go @@ -0,0 +1,53 @@ +package main + +import ( + "fmt" + "strings" +) + +type diffRange struct { + BaseSHA string + HeadSHA string +} + +type runRequest struct { + RepoRoot string + Range diffRange + Fetches []fetchSpec + MergeBaseRef string + Sinks outputSinks +} + +type fetchSpec struct { + Remote string + Ref string +} + +type outputSinks struct { + OutMatrix string + OutSummary string + GitHubOutput string + GitHubStepSummary string +} + +// validateRevisionArg rejects git revision strings that would be unsafe to pass +// as a single argv element. It is not a SHA-format validator. +func validateRevisionArg(name, revision string) error { + if revision == "" { + return fmt.Errorf("%s is required", name) + } + if strings.HasPrefix(revision, "-") { + return fmt.Errorf("%s must not start with '-': %q", name, revision) + } + if strings.Contains(revision, ":") { + return fmt.Errorf("%s must not contain ':': %q", name, revision) + } + if strings.ContainsRune(revision, '\x00') { + return fmt.Errorf("%s must not contain NUL bytes", name) + } + return nil +} + +func diffRangeSpec(cfg config) string { + return cfg.BaseSHA + "..." + cfg.HeadSHA +} diff --git a/request_test.go b/request_test.go new file mode 100644 index 0000000..53cf13a --- /dev/null +++ b/request_test.go @@ -0,0 +1,44 @@ +package main + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestValidateRevisionArgAllowsCommonGitRevisions(t *testing.T) { + t.Parallel() + + for _, revision := range []string{ + "HEAD", + "HEAD~3", + "origin/main", + "refs/heads/main", + "v1.2.3", + "abc1234", + "0123456789abcdef0123456789abcdef01234567", + } { + require.NoError(t, validateRevisionArg("revision", revision), revision) + } +} + +func TestValidateRevisionArgRejectsUnsafeValues(t *testing.T) { + t.Parallel() + + tests := []struct { + revision string + want string + }{ + {revision: "", want: "is required"}, + {revision: "-bad", want: "must not start with '-'"}, + {revision: "head:bad", want: "must not contain ':'"}, + {revision: "head\x00bad", want: "must not contain NUL bytes"}, + } + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + t.Parallel() + err := validateRevisionArg("--head-sha", tt.revision) + require.ErrorContains(t, err, tt.want) + }) + } +} diff --git a/selection.go b/selection.go new file mode 100644 index 0000000..879ad44 --- /dev/null +++ b/selection.go @@ -0,0 +1,298 @@ +package main + +import ( + "context" + "fmt" + "maps" + "path/filepath" + "slices" +) + +type packageKey struct { + Dir string + Name string +} + +type packageInventory struct { + Key packageKey + Tests map[string]struct{} +} + +type packageSelection struct { + Key packageKey + Tests map[string]struct{} + Files map[string]struct{} + Broadened bool + DirectoryWide bool +} + +type parsedFileSnapshot struct { + Key packageKey + Snapshot fileSnapshot +} + +// selectChange handles the four diff states for a runnable test file: add +// (old absent, new present), delete (old present, new absent), in-place modify +// (both sides present with the same package key), and cross-package move or +// package rename (both sides present with different package keys). +func selectChange(ctx context.Context, cache *inventoryCache, selections map[packageKey]*packageSelection, change testFileChange) error { + cfg := cache.cfg + hunks, err := listDiffHunks(ctx, cfg, cache.git, change) + if err != nil { + return fmt.Errorf("list diff hunks for %s: %w", change.displayPath(), err) + } + if len(hunks) == 0 { + return nil + } + + oldParsed, oldExists, err := cache.parseChangeFileAtRevision(ctx, cfg.BaseSHA, change.OldPath) + if err != nil { + return fmt.Errorf("resolve old package for %s: %w", change.displayPath(), err) + } + newParsed, newExists, err := cache.parseChangeFileAtRevision(ctx, cfg.HeadSHA, change.NewPath) + if err != nil { + return fmt.Errorf("resolve new package for %s: %w", change.displayPath(), err) + } + if change.expectsOldFile() && !oldExists { + return fmt.Errorf("base revision %s is missing %s", cfg.BaseSHA, change.OldPath) + } + if change.expectsNewFile() && !newExists { + return fmt.Errorf("head revision %s is missing %s", cfg.HeadSHA, change.NewPath) + } + + var oldFile *parsedFileSnapshot + if oldExists { + oldFile = &oldParsed + } + var newFile *parsedFileSnapshot + if newExists { + newFile = &newParsed + } + + if newFile != nil { + inventory, err := cache.loadPackageInventory(ctx, cfg.HeadSHA, newFile.Key) + if err != nil { + return fmt.Errorf("load package inventory for %s: %w", newFile.Key.String(), err) + } + var oldSnapshot *fileSnapshot + selectionHunks := hunks + if oldFile != nil && oldFile.Key == newFile.Key { + oldSnapshot = &oldFile.Snapshot + } else { + selectionHunks = newSideOnlyHunks(hunks) + } + selection := selectTestsFromHunks(change, oldSnapshot, newFile.Snapshot, inventory, selectionHunks) + if err := mergeSelection(ctx, cache, selections, selection); err != nil { + return err + } + } + + if oldFile != nil && (newFile == nil || oldFile.Key != newFile.Key) { + inventory, err := cache.loadPackageInventory(ctx, cfg.HeadSHA, oldFile.Key) + if err != nil { + return fmt.Errorf("load package inventory for %s: %w", oldFile.Key.String(), err) + } + sourceChange := testFileChange{Kind: changeDeleted, OldPath: change.OldPath} + selection := selectSourceRemoval(sourceChange, oldFile.Snapshot, inventory, hunks) + if err := mergeSelection(ctx, cache, selections, selection); err != nil { + return err + } + } + + return nil +} + +func (change testFileChange) expectsOldFile() bool { + oldRequired, _ := change.Kind.expectedFileSides() + return oldRequired && isRunnableTestFilePath(change.OldPath) +} + +func (change testFileChange) expectsNewFile() bool { + _, newRequired := change.Kind.expectedFileSides() + return newRequired && isRunnableTestFilePath(change.NewPath) +} + +func (kind changeKind) expectedFileSides() (oldRequired bool, newRequired bool) { + switch kind { + case changeAdded: + return false, true + case changeDeleted: + return true, false + case changeModified, changeRenamed, changeType: + return true, true + } + // Unknown change kinds intentionally require both sides so selectChange fails + // loud. parseChangeKind is the choke point for supported git diff statuses. + return true, true +} + +func parseSnapshotForPath(filePath string, data []byte) (parsedFileSnapshot, error) { + snapshot, err := parseFileSnapshot(data) + if err != nil { + return parsedFileSnapshot{}, fmt.Errorf("parse package clause: %w", err) + } + return parsedFileSnapshot{ + Key: packageKey{Dir: filepath.ToSlash(filepath.Dir(filePath)), Name: snapshot.packageName}, + Snapshot: snapshot, + }, nil +} + +func mergeSelection(ctx context.Context, cache *inventoryCache, selections map[packageKey]*packageSelection, selection *packageSelection) error { + if selection == nil { + return nil + } + if !selection.DirectoryWide { + if len(selection.Tests) > 0 { + mergePackageSelection(selections, selection) + } + return nil + } + + expanded, err := cache.directoryWideSelections(ctx, cache.cfg.HeadSHA, selection.Key.Dir, selection.Files) + if err != nil { + return fmt.Errorf("load directory-wide inventory for %s: %w", packagePattern(selection.Key.Dir), err) + } + for _, expandedSelection := range expanded { + mergePackageSelection(selections, expandedSelection) + } + return nil +} + +func mergePackageSelection(selections map[packageKey]*packageSelection, selection *packageSelection) { + merged := selections[selection.Key] + if merged == nil { + merged = &packageSelection{ + Key: selection.Key, + Tests: map[string]struct{}{}, + Files: map[string]struct{}{}, + } + selections[selection.Key] = merged + } + merged.Broadened = merged.Broadened || selection.Broadened + maps.Copy(merged.Files, selection.Files) + maps.Copy(merged.Tests, selection.Tests) +} + +func selectTestsFromHunks(change testFileChange, oldSnapshot *fileSnapshot, newSnapshot fileSnapshot, newInventory packageInventory, hunks []diffHunk) *packageSelection { + if oldSnapshot == nil && needsOldSnapshot(hunks) { + return allPackageTestsSelection(newInventory, change.displayPath()) + } + + selected := map[string]struct{}{} + for _, hunk := range hunks { + if oldSnapshot != nil { + switch broadeningScopeForOldHunk(oldSnapshot.shared, hunk.Old) { + case broadeningDirectory: + return allDirectoryTestsSelection(newInventory.Key.Dir, change.displayPath()) + case broadeningPackage: + return allPackageTestsSelection(newInventory, change.displayPath()) + } + } + switch broadeningScopeForNewHunk(newSnapshot.shared, oldSnapshot, hunk.New) { + case broadeningDirectory: + return allDirectoryTestsSelection(newInventory.Key.Dir, change.displayPath()) + case broadeningPackage: + return allPackageTestsSelection(newInventory, change.displayPath()) + } + addMatchingTests(selected, newSnapshot.tests, hunk.New) + if oldSnapshot == nil { + continue + } + for name, declRange := range oldSnapshot.tests { + if !declRange.overlaps(hunk.Old) { + continue + } + if _, ok := newInventory.Tests[name]; ok { + selected[name] = struct{}{} + } + } + } + if len(selected) == 0 { + return nil + } + return &packageSelection{ + Key: newInventory.Key, + Tests: selected, + Files: map[string]struct{}{change.displayPath(): {}}, + } +} + +func selectSourceRemoval(change testFileChange, oldSnapshot fileSnapshot, inventory packageInventory, hunks []diffHunk) *packageSelection { + selected := map[string]struct{}{} + for _, hunk := range hunks { + switch broadeningScopeForOldHunk(oldSnapshot.shared, hunk.Old) { + case broadeningDirectory: + return allDirectoryTestsSelection(inventory.Key.Dir, change.displayPath()) + case broadeningPackage: + return allPackageTestsSelection(inventory, change.displayPath()) + } + for name, declRange := range oldSnapshot.tests { + if !declRange.overlaps(hunk.Old) { + continue + } + if _, ok := inventory.Tests[name]; ok { + selected[name] = struct{}{} + } + } + } + if len(selected) == 0 { + return nil + } + return &packageSelection{ + Key: inventory.Key, + Tests: selected, + Files: map[string]struct{}{change.displayPath(): {}}, + } +} + +func allPackageTestsSelection(inventory packageInventory, filePath string) *packageSelection { + return allPackageTestsSelectionForFiles(inventory, map[string]struct{}{filePath: {}}) +} + +func allPackageTestsSelectionForFiles(inventory packageInventory, files map[string]struct{}) *packageSelection { + selection := &packageSelection{ + Key: inventory.Key, + Tests: map[string]struct{}{}, + Files: files, + Broadened: true, + } + maps.Copy(selection.Tests, inventory.Tests) + if len(selection.Tests) == 0 { + return nil + } + return selection +} + +func allDirectoryTestsSelection(dir, filePath string) *packageSelection { + return &packageSelection{ + Key: packageKey{Dir: dir}, + Files: map[string]struct{}{filePath: {}}, + DirectoryWide: true, + } +} + +func needsOldSnapshot(hunks []diffHunk) bool { + return slices.ContainsFunc(hunks, func(hunk diffHunk) bool { + return hunk.Old.hasLines() + }) +} + +func addMatchingTests(selected map[string]struct{}, tests map[string]lineRange, candidate lineRange) { + for name, declRange := range tests { + if declRange.overlaps(candidate) { + selected[name] = struct{}{} + } + } +} + +func (key packageKey) String() string { + return fmt.Sprintf("%s (%s)", packagePattern(key.Dir), key.Name) +} + +func packagePattern(dir string) string { + cleanDir := filepath.ToSlash(filepath.Clean(dir)) + if cleanDir == "." { + return "." + } + return "./" + cleanDir +} diff --git a/selection_test.go b/selection_test.go new file mode 100644 index 0000000..1b8c11a --- /dev/null +++ b/selection_test.go @@ -0,0 +1,783 @@ +package main + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSelectTestsForSnapshots(t *testing.T) { + t.Parallel() + + const changedPath = "pkg/changed_test.go" + change := testFileChange{Kind: changeModified, OldPath: changedPath, NewPath: changedPath} + + const ( + // These fixtures hoist repeated row sources. + selectionFixture01 = `package sample + +import "testing" + +func TestAlpha(t *testing.T) { + t.Log("before alpha") +} + +func TestBeta(t *testing.T) { + t.Log("stable beta") +} +` + selectionFixture02 = `package sample + +import "testing" + +func TestAlpha(t *testing.T) { + t.Log("changed alpha") +} + +func TestBeta(t *testing.T) { + t.Log("stable beta") +} +` + selectionFixture03 = `package sample + +import "testing" + +func TestAlpha(t *testing.T) { + t.Log("alpha") +} +` + selectionFixture04 = `package sample + +import "testing" + +func TestAlpha(t *testing.T) { + t.Log("alpha") +} + +func TestBeta(t *testing.T) { + t.Log("new beta") +} +` + selectionFixture05 = `package sample + +import "testing" + +func setup(t *testing.T) { + t.Helper() + t.Log("before helper") +} + +func TestAlpha(t *testing.T) { + setup(t) +} +` + selectionFixture06 = `package sample + +import "testing" + +func setup(t *testing.T) { + t.Helper() + t.Log("changed helper") +} + +func TestAlpha(t *testing.T) { + setup(t) +} +` + selectionFixture07 = `package sample + +import "testing" + +var packageValue = 1 + +func TestAlpha(t *testing.T) { + t.Log(packageValue) +} +` + selectionFixture08 = `package sample + +import "testing" + +var packageValue = 2 + +func TestAlpha(t *testing.T) { + t.Log(packageValue) +} +` + selectionFixture09 = `package sample + +import ( + "testing" +) + +func TestAlpha(t *testing.T) { + t.Log("alpha") +} + +func TestBeta(t *testing.T) { + t.Log("beta") +} +` + selectionFixture10 = `package sample + +import ( + "fmt" + "testing" +) + +func TestAlpha(t *testing.T) { + t.Log("alpha") +} + +func TestBeta(t *testing.T) { + t.Log("beta") +} +` + selectionFixture11 = `package sample + +import "testing" + +func TestAlpha(t *testing.T) { + t.Log("alpha") +} + +func setupCase(t *testing.T) { + t.Helper() + t.Log("beta helper") +} + +func TestBeta(t *testing.T) { + setupCase(t) +} +` + selectionFixture12 = `package sample + +import ( + "fmt" + "testing" +) + +func TestAlpha(t *testing.T) { + t.Log("alpha") +} +` + selectionFixture13 = `package sample + +import ( + "testing" +) + +func TestAlpha(t *testing.T) { + t.Log("alpha") +} +` + selectionFixture14 = `package sample + +import "testing" + +func TestBeta(t *testing.T) { + t.Log("beta") +} +` + selectionFixture15 = `package sample + +import ( + "os" + "testing" +) + +func TestMain(m *testing.M) { + os.Exit(m.Run()) +} +` + selectionFixture16 = `package sample + +import ( + "fmt" + "os" + "testing" +) + +func TestMain(m *testing.M) { + fmt.Println("setup") + os.Exit(m.Run()) +} +` + selectionFixture17 = `package sample + +import "testing" + +func init() { + register("before") +} +` + selectionFixture18 = `package sample + +import "testing" + +func init() { + register("after") +} +` + selectionFixture19 = `package sample + +import "testing" + +func setup(t *testing.T) { + t.Helper() + t.Log("helper") +} + +func TestAlpha(t *testing.T) { + setup(t) +} +` + selectionFixture20 = `package sample + +import "testing" + +func TestBeta(t *testing.T) { + t.Log("new beta") +} +` + selectionFixture21 = `package sample + +import . "testing" + +func TestAlpha(t *T) { + t.Log("before alpha") +} +` + selectionFixture22 = `package sample + +import . "testing" + +func TestAlpha(t *T) { + t.Log("changed alpha") +} +` + ) + + tests := []struct { + name string + oldData []byte + newData []byte + inventory packageInventory + hunks []diffHunk + wantTests []string + wantBroadened bool + wantDirectoryWide bool + wantNoSelection bool + }{ + { + name: "body change selects only changed test", + oldData: []byte(selectionFixture01), + newData: []byte(selectionFixture02), + inventory: mustPackageInventory(t, map[string]string{ + changedPath: selectionFixture02, + }), + hunks: []diffHunk{{ + Old: singleLineRange(t, selectionFixture01, `t.Log("before alpha")`), + New: singleLineRange(t, selectionFixture02, `t.Log("changed alpha")`), + }}, + wantTests: []string{"TestAlpha"}, + }, + { + name: "new top-level test selects only new test", + oldData: []byte(selectionFixture03), + newData: []byte(selectionFixture04), + inventory: mustPackageInventory(t, map[string]string{ + changedPath: selectionFixture04, + }), + hunks: []diffHunk{{ + Old: emptyRangeAt(7), + New: singleLineRange(t, selectionFixture04, `t.Log("new beta")`), + }}, + wantTests: []string{"TestBeta"}, + }, + { + name: "existing helper change broadens across package", + oldData: []byte(selectionFixture05), + newData: []byte(selectionFixture06), + inventory: mustPackageInventory(t, map[string]string{ + changedPath: selectionFixture06, + "pkg/sibling_test.go": `package sample + +import "testing" + +func TestBeta(t *testing.T) { + setup(t) +} +`, + }), + hunks: []diffHunk{{ + Old: singleLineRange(t, selectionFixture05, `t.Log("before helper")`), + New: singleLineRange(t, selectionFixture06, `t.Log("changed helper")`), + }}, + wantTests: []string{"TestAlpha", "TestBeta"}, + wantBroadened: true, + }, + { + name: "package variable change broadens across package", + oldData: []byte(selectionFixture07), + newData: []byte(selectionFixture08), + inventory: mustPackageInventory(t, map[string]string{ + changedPath: selectionFixture08, + "pkg/sibling_test.go": `package sample + +import "testing" + +func TestBeta(t *testing.T) { + t.Log(packageValue) +} +`, + }), + hunks: []diffHunk{{ + Old: singleLineRange(t, selectionFixture07, "var packageValue = 1"), + New: singleLineRange(t, selectionFixture08, "var packageValue = 2"), + }}, + wantTests: []string{"TestAlpha", "TestBeta"}, + wantBroadened: true, + }, + { + name: "additive import broadens package", + oldData: []byte(selectionFixture09), + newData: []byte(selectionFixture10), + inventory: mustPackageInventory(t, map[string]string{ + changedPath: selectionFixture10, + }), + hunks: []diffHunk{{ + Old: emptyRangeAt(singleLineRange(t, selectionFixture09, `"testing"`).Start), + New: singleLineRange(t, selectionFixture10, `"fmt"`), + }}, + wantTests: []string{"TestAlpha", "TestBeta"}, + wantBroadened: true, + }, + { + name: "additive helper with new test stays narrow", + oldData: []byte(selectionFixture03), + newData: []byte(selectionFixture11), + inventory: mustPackageInventory(t, map[string]string{ + changedPath: selectionFixture11, + }), + hunks: []diffHunk{{ + Old: emptyRangeAt(7), + New: rangeSpan( + singleLineRange(t, selectionFixture11, "func setupCase(t *testing.T) {"), + singleLineRange(t, selectionFixture11, "setupCase(t)"), + ), + }}, + wantTests: []string{"TestBeta"}, + }, + { + name: "removed import broadens across package", + oldData: []byte(selectionFixture12), + newData: []byte(selectionFixture13), + inventory: mustPackageInventory(t, map[string]string{ + changedPath: selectionFixture13, + "pkg/sibling_test.go": selectionFixture14, + }), + hunks: []diffHunk{{ + Old: singleLineRange(t, selectionFixture12, `"fmt"`), + New: emptyRangeAt(singleLineRange(t, selectionFixture13, `"testing"`).Start), + }}, + wantTests: []string{"TestAlpha", "TestBeta"}, + wantBroadened: true, + }, + { + name: "TestMain broadens across sibling files in same package", + oldData: []byte(selectionFixture15), + newData: []byte(selectionFixture16), + inventory: mustPackageInventory(t, map[string]string{ + changedPath: selectionFixture16, + "pkg/internal_test.go": selectionFixture03, + }), + hunks: []diffHunk{{ + Old: singleLineRange(t, selectionFixture15, `os.Exit(m.Run())`), + New: singleLineRange(t, selectionFixture16, `fmt.Println("setup")`), + }}, + wantDirectoryWide: true, + }, + { + name: "init broadens across sibling files in same package", + oldData: []byte(selectionFixture17), + newData: []byte(selectionFixture18), + inventory: mustPackageInventory(t, map[string]string{ + changedPath: selectionFixture18, + "pkg/internal_test.go": selectionFixture03, + }), + hunks: []diffHunk{{ + Old: singleLineRange(t, selectionFixture17, `register("before")`), + New: singleLineRange(t, selectionFixture18, `register("after")`), + }}, + wantDirectoryWide: true, + }, + { + name: "deleted helper uses old snapshot to broaden package", + oldData: []byte(selectionFixture19), + newData: []byte(selectionFixture03), + inventory: mustPackageInventory(t, map[string]string{ + changedPath: selectionFixture03, + "pkg/sibling_test.go": selectionFixture14, + }), + hunks: []diffHunk{{ + Old: rangeSpan( + singleLineRange(t, selectionFixture19, "func setup(t *testing.T) {"), + singleLineRange(t, selectionFixture19, `t.Log("helper")`), + ), + New: emptyRangeAt(singleLineRange(t, selectionFixture03, `func TestAlpha(t *testing.T) {`).Start), + }}, + wantTests: []string{"TestAlpha", "TestBeta"}, + wantBroadened: true, + }, + { + name: "brand-new file with additive hunk selects only new tests", + oldData: nil, + newData: []byte(selectionFixture20), + inventory: mustPackageInventory(t, map[string]string{ + changedPath: selectionFixture20, + }), + hunks: []diffHunk{{ + Old: emptyRangeAt(1), + New: rangeSpan( + singleLineRange(t, selectionFixture20, "func TestBeta(t *testing.T) {"), + singleLineRange(t, selectionFixture20, `t.Log("new beta")`), + ), + }}, + wantTests: []string{"TestBeta"}, + }, + { + name: "dot imported testing is recognized", + oldData: []byte(selectionFixture21), + newData: []byte(selectionFixture22), + inventory: mustPackageInventory(t, map[string]string{ + changedPath: selectionFixture22, + }), + hunks: []diffHunk{{ + Old: singleLineRange(t, selectionFixture21, `t.Log("before alpha")`), + New: singleLineRange(t, selectionFixture22, `t.Log("changed alpha")`), + }}, + wantTests: []string{"TestAlpha"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + oldSnapshot := mustOptionalFileSnapshot(t, tt.oldData) + newSnapshot := mustFileSnapshot(t, tt.newData) + selection := selectTestsFromHunks(change, oldSnapshot, newSnapshot, tt.inventory, tt.hunks) + if tt.wantNoSelection { + require.Nil(t, selection) + return + } + require.NotNil(t, selection) + require.Equal(t, tt.wantDirectoryWide, selection.DirectoryWide) + if tt.wantDirectoryWide { + require.Empty(t, selection.Tests) + require.Contains(t, selection.Files, changedPath) + } else { + require.Equal(t, tt.wantTests, selectionNames(selection)) + } + require.Equal(t, tt.wantBroadened, selection.Broadened) + }) + } +} + +func TestSelectTestsForSnapshotsTreatsTestMethodsAsSharedHelpers(t *testing.T) { + t.Parallel() + + change := testFileChange{Kind: changeModified, OldPath: "pkg/changed_test.go", NewPath: "pkg/changed_test.go"} + oldData := []byte(`package sample + +import "testing" + +type suite struct{} + +func (suite) TestMethod(t *testing.T) { + t.Log("before method") +} + +func TestAlpha(t *testing.T) { + t.Log("alpha") +} +`) + newData := []byte(`package sample + +import "testing" + +type suite struct{} + +func (suite) TestMethod(t *testing.T) { + t.Log("changed method") +} + +func TestAlpha(t *testing.T) { + t.Log("alpha") +} +`) + inventory := mustPackageInventory(t, map[string]string{ + "pkg/changed_test.go": string(newData), + "pkg/sibling_test.go": `package sample + +import "testing" + +func TestBeta(t *testing.T) { + t.Log("beta") +} +`, + }) + oldSnapshot := mustOptionalFileSnapshot(t, oldData) + newSnapshot := mustFileSnapshot(t, newData) + selection := selectTestsFromHunks(change, oldSnapshot, newSnapshot, inventory, []diffHunk{{ + Old: singleLineRange(t, string(oldData), `t.Log("before method")`), + New: singleLineRange(t, string(newData), `t.Log("changed method")`), + }}) + require.NotNil(t, selection) + require.Equal(t, []string{"TestAlpha", "TestBeta"}, selectionNames(selection)) + require.True(t, selection.Broadened) +} + +func TestSelectTestsForSnapshotsAdditiveSharedDeclsStayNarrow(t *testing.T) { + t.Parallel() + + change := testFileChange{Kind: changeModified, OldPath: "pkg/changed_test.go", NewPath: "pkg/changed_test.go"} + basePrefix := `package sample + +import "testing" + +func TestAlpha(t *testing.T) { + t.Log("alpha") +} + +` + cases := []struct { + name string + declaration string + needle string + }{ + {name: "var", declaration: "var packageValue = 1\n", needle: "var packageValue = 1"}, + {name: "const", declaration: "const packageValue = 1\n", needle: "const packageValue = 1"}, + {name: "type", declaration: "type packageValue struct{}\n", needle: "type packageValue struct{}"}, + } + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + oldData := []byte(basePrefix) + newData := []byte(basePrefix + tt.declaration + ` +func TestBeta(t *testing.T) { + t.Log("beta") +} +`) + inventory := mustPackageInventory(t, map[string]string{ + "pkg/changed_test.go": string(newData), + }) + oldSnapshot := mustOptionalFileSnapshot(t, oldData) + newSnapshot := mustFileSnapshot(t, newData) + selection := selectTestsFromHunks(change, oldSnapshot, newSnapshot, inventory, []diffHunk{{ + Old: emptyRangeAt(7), + New: rangeSpan( + singleLineRange(t, string(newData), tt.needle), + singleLineRange(t, string(newData), `t.Log("beta")`), + ), + }}) + require.NotNil(t, selection) + require.Equal(t, []string{"TestBeta"}, selectionNames(selection)) + require.False(t, selection.Broadened) + }) + } +} + +func TestSelectTestsForSnapshotsBroadensAddedImports(t *testing.T) { + t.Parallel() + + change := testFileChange{Kind: changeModified, OldPath: "pkg/changed_test.go", NewPath: "pkg/changed_test.go"} + oldData := []byte(`package sample + +import "testing" + +func TestAlpha(t *testing.T) { + t.Log("alpha") +} +`) + newData := []byte(`package sample + +import ( + _ "example.com/sideeffect" + "testing" +) + +func TestAlpha(t *testing.T) { + t.Log("alpha") +} +`) + inventory := mustPackageInventory(t, map[string]string{ + "pkg/changed_test.go": string(newData), + "pkg/sibling_test.go": `package sample + +import "testing" + +func TestBeta(t *testing.T) { + t.Log("beta") +} +`, + }) + oldSnapshot := mustOptionalFileSnapshot(t, oldData) + newSnapshot := mustFileSnapshot(t, newData) + selection := selectTestsFromHunks(change, oldSnapshot, newSnapshot, inventory, []diffHunk{{ + Old: emptyRangeAt(3), + New: singleLineRange(t, string(newData), `_ "example.com/sideeffect"`), + }}) + require.NotNil(t, selection) + require.Equal(t, []string{"TestAlpha", "TestBeta"}, selectionNames(selection)) + require.True(t, selection.Broadened) +} + +func TestMergePackageSelectionCombinesSamePackageFiles(t *testing.T) { + t.Parallel() + + key := packageKey{Dir: "pkg", Name: "sample"} + selections := map[packageKey]*packageSelection{} + mergePackageSelection(selections, &packageSelection{ + Key: key, + Tests: map[string]struct{}{"TestAlpha": {}}, + Files: map[string]struct{}{"pkg/alpha_test.go": {}}, + }) + mergePackageSelection(selections, &packageSelection{ + Key: key, + Tests: map[string]struct{}{"TestBeta": {}}, + Files: map[string]struct{}{"pkg/beta_test.go": {}}, + Broadened: true, + }) + + require.Equal(t, []string{"TestAlpha", "TestBeta"}, selectionNames(selections[key])) + require.True(t, selections[key].Broadened) + require.Contains(t, selections[key].Files, "pkg/alpha_test.go") + require.Contains(t, selections[key].Files, "pkg/beta_test.go") +} + +func TestSelectChangeRequiresOldFileWhenKindExpectsIt(t *testing.T) { + t.Parallel() + + oldPath := "pkg/old_test.go" + newPath := "pkg/new_test.go" + newContent := `package sample + +import "testing" + +func TestAlpha(t *testing.T) {} +` + tests := []struct { + name string + change testFileChange + key string + }{ + { + name: "modified", + change: testFileChange{Kind: changeModified, OldPath: oldPath, NewPath: oldPath}, + key: oldPath, + }, + { + name: "renamed", + change: testFileChange{Kind: changeRenamed, OldPath: oldPath, NewPath: newPath}, + key: oldPath + "\x00" + newPath, + }, + { + name: "deleted", + change: testFileChange{Kind: changeDeleted, OldPath: oldPath}, + key: oldPath, + }, + { + name: "type", + change: testFileChange{Kind: changeType, OldPath: oldPath, NewPath: oldPath}, + key: oldPath, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + headFiles := map[string]string{} + if tt.change.NewPath != "" { + headFiles[tt.change.NewPath] = newContent + } + repo := fakeGitRepo{ + revisions: map[string]map[string]string{ + "base": {}, + "head": headFiles, + }, + diffOutputs: map[string]string{ + tt.key: diffForChange(lineRange{Start: 5, End: 5}, lineRange{Start: 5, End: 5}), + }, + } + cache := newInventoryCache(config{RepoRoot: "/repo", BaseSHA: "base", HeadSHA: "head"}, repo.runner(t)) + err := selectChange(t.Context(), cache, map[packageKey]*packageSelection{}, tt.change) + require.ErrorContains(t, err, "base revision base is missing "+oldPath) + }) + } +} + +func TestSelectChangeRequiresNewFileWhenKindExpectsIt(t *testing.T) { + t.Parallel() + + oldPath := "pkg/base_side_test.go" + newPath := "pkg/head_side_test.go" + oldContent := `package sample + +import "testing" + +func TestAlpha(t *testing.T) {} +` + tests := []struct { + name string + change testFileChange + key string + wantPath string + }{ + { + name: "added", + change: testFileChange{Kind: changeAdded, NewPath: newPath}, + key: newPath, + wantPath: newPath, + }, + { + name: "modified", + change: testFileChange{Kind: changeModified, OldPath: oldPath, NewPath: oldPath}, + key: oldPath, + wantPath: oldPath, + }, + { + name: "renamed", + change: testFileChange{Kind: changeRenamed, OldPath: oldPath, NewPath: newPath}, + key: oldPath + "\x00" + newPath, + wantPath: newPath, + }, + { + name: "type", + change: testFileChange{Kind: changeType, OldPath: oldPath, NewPath: oldPath}, + key: oldPath, + wantPath: oldPath, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + baseFiles := map[string]string{} + if tt.change.OldPath != "" { + baseFiles[tt.change.OldPath] = oldContent + } + repo := fakeGitRepo{ + revisions: map[string]map[string]string{ + "base": baseFiles, + "head": {}, + }, + diffOutputs: map[string]string{ + tt.key: diffForChange(lineRange{Start: 5, End: 5}, lineRange{Start: 5, End: 5}), + }, + } + cache := newInventoryCache(config{RepoRoot: "/repo", BaseSHA: "base", HeadSHA: "head"}, repo.runner(t)) + err := selectChange(t.Context(), cache, map[packageKey]*packageSelection{}, tt.change) + require.ErrorContains(t, err, "head revision head is missing "+tt.wantPath) + }) + } +} diff --git a/snapshot.go b/snapshot.go new file mode 100644 index 0000000..19b6e48 --- /dev/null +++ b/snapshot.go @@ -0,0 +1,282 @@ +package main + +import ( + "bytes" + "fmt" + "go/ast" + "go/parser" + "go/printer" + "go/token" + "slices" + "strings" + "unicode" + "unicode/utf8" +) + +type fileSnapshot struct { + packageName string + tests map[string]lineRange + shared []sharedDecl + sharedKeys map[string]struct{} +} + +type sharedDeclKind uint8 + +const ( + sharedDeclImport sharedDeclKind = iota + 1 + sharedDeclVar + sharedDeclConst + sharedDeclType + sharedDeclHelper + sharedDeclInit + sharedDeclTestMain +) + +type sharedDecl struct { + Range lineRange + Kind sharedDeclKind + Keys []string +} + +func parseFileSnapshot(data []byte) (fileSnapshot, error) { + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, "", data, parser.SkipObjectResolution) + if err != nil { + return fileSnapshot{}, err + } + + testingDotImport := hasTestingDotImport(file) + snapshot := fileSnapshot{ + packageName: file.Name.Name, + tests: map[string]lineRange{}, + sharedKeys: map[string]struct{}{}, + } + for _, decl := range file.Decls { + rangeForDecl := nodeRange(fset, decl) + funcDecl, ok := decl.(*ast.FuncDecl) + if !ok { + genDecl, ok := decl.(*ast.GenDecl) + if !ok { + snapshot.addSharedDecl(sharedDecl{Range: rangeForDecl, Kind: sharedDeclHelper}) + continue + } + snapshot.addSharedDecl(classifyGenDecl(rangeForDecl, genDecl)) + continue + } + if funcDecl.Name == nil { + snapshot.addSharedDecl(sharedDecl{Range: rangeForDecl, Kind: sharedDeclHelper}) + continue + } + + name := funcDecl.Name.Name + switch { + case name == "TestMain": + snapshot.addSharedDecl(sharedDecl{ + Range: rangeForDecl, + Kind: sharedDeclTestMain, + Keys: []string{"func:TestMain"}, + }) + case name == "init": + snapshot.addSharedDecl(sharedDecl{Range: rangeForDecl, Kind: sharedDeclInit}) + case isTopLevelTestFunc(funcDecl, testingDotImport), isTopLevelFuzzFunc(funcDecl, testingDotImport), isTopLevelExampleFunc(funcDecl): + snapshot.tests[name] = rangeForDecl + default: + snapshot.addSharedDecl(sharedDecl{ + Range: rangeForDecl, + Kind: sharedDeclHelper, + Keys: []string{funcIdentity(fset, funcDecl)}, + }) + } + } + return snapshot, nil +} + +func hasTestingDotImport(file *ast.File) bool { + for _, importSpec := range file.Imports { + if importSpec == nil || importSpec.Path == nil { + continue + } + if importSpec.Name == nil || importSpec.Name.Name != "." { + continue + } + if strings.Trim(importSpec.Path.Value, `"`) == "testing" { + return true + } + } + return false +} + +func classifyGenDecl(rangeForDecl lineRange, decl *ast.GenDecl) sharedDecl { + shared := sharedDecl{Range: rangeForDecl} + switch decl.Tok { + case token.IMPORT: + shared.Kind = sharedDeclImport + case token.VAR: + shared.Kind = sharedDeclVar + shared.Keys = genDeclKeys("var", decl.Specs) + case token.CONST: + shared.Kind = sharedDeclConst + shared.Keys = genDeclKeys("const", decl.Specs) + case token.TYPE: + shared.Kind = sharedDeclType + shared.Keys = genDeclKeys("type", decl.Specs) + default: + shared.Kind = sharedDeclHelper + } + return shared +} + +func genDeclKeys(prefix string, specs []ast.Spec) []string { + keys := make([]string, 0, len(specs)) + for _, spec := range specs { + switch typed := spec.(type) { + case *ast.TypeSpec: + if typed.Name == nil || typed.Name.Name == "_" { + continue + } + keys = append(keys, prefix+":"+typed.Name.Name) + case *ast.ValueSpec: + for _, name := range typed.Names { + if name == nil || name.Name == "_" { + continue + } + keys = append(keys, prefix+":"+name.Name) + } + } + } + slices.Sort(keys) + return keys +} + +func funcIdentity(fset *token.FileSet, fn *ast.FuncDecl) string { + if fn.Name == nil { + return "" + } + if fn.Recv == nil || len(fn.Recv.List) == 0 { + return "func:" + fn.Name.Name + } + return "method:" + exprString(fset, fn.Recv.List[0].Type) + "." + fn.Name.Name +} + +func exprString(fset *token.FileSet, expr ast.Expr) string { + var buffer bytes.Buffer + if err := printer.Fprint(&buffer, fset, expr); err != nil { + return fmt.Sprintf("%T", expr) + } + return buffer.String() +} + +func nodeRange(fset *token.FileSet, node ast.Node) lineRange { + start := fset.Position(node.Pos()).Line + end := fset.Position(node.End()).Line + if end < start { + end = start + } + return lineRange{Start: start, End: end} +} + +func isTopLevelTestFunc(fn *ast.FuncDecl, testingDotImport bool) bool { + if fn.Recv != nil || !hasRunnableName(fn.Name, "Test", false) { + return false + } + if hasParamSelectorName(fn, "T") { + return true + } + return testingDotImport && hasParamIdentName(fn, "T") +} + +func isTopLevelFuzzFunc(fn *ast.FuncDecl, testingDotImport bool) bool { + if fn.Recv != nil || !hasRunnableName(fn.Name, "Fuzz", false) { + return false + } + if hasParamSelectorName(fn, "F") { + return true + } + return testingDotImport && hasParamIdentName(fn, "F") +} + +func isTopLevelExampleFunc(fn *ast.FuncDecl) bool { + return fn.Recv == nil && hasRunnableName(fn.Name, "Example", true) && fn.Type != nil && fn.Type.Params != nil && len(fn.Type.Params.List) == 0 +} + +func hasRunnableName(name *ast.Ident, prefix string, allowBare bool) bool { + if name == nil { + return false + } + rest, ok := strings.CutPrefix(name.Name, prefix) + if !ok { + return false + } + if rest == "" { + return allowBare + } + r, _ := utf8.DecodeRuneInString(rest) + return r == '_' || !unicode.IsLower(r) +} + +func hasParamSelectorName(fn *ast.FuncDecl, expectedName string) bool { + if fn.Type == nil || fn.Type.Params == nil { + return false + } + params := fn.Type.Params.List + if len(params) != 1 { + return false + } + name, ok := pointerSelectorName(params[0].Type) + return ok && name == expectedName +} + +func hasParamIdentName(fn *ast.FuncDecl, expectedName string) bool { + if fn.Type == nil || fn.Type.Params == nil { + return false + } + params := fn.Type.Params.List + if len(params) != 1 { + return false + } + name, ok := pointerIdentName(params[0].Type) + return ok && name == expectedName +} + +func pointerSelectorName(expr ast.Expr) (string, bool) { + star, ok := expr.(*ast.StarExpr) + if !ok { + return "", false + } + selector, ok := star.X.(*ast.SelectorExpr) + if !ok || selector.Sel == nil { + return "", false + } + return selector.Sel.Name, true +} + +func pointerIdentName(expr ast.Expr) (string, bool) { + star, ok := expr.(*ast.StarExpr) + if !ok { + return "", false + } + ident, ok := star.X.(*ast.Ident) + if !ok { + return "", false + } + return ident.Name, true +} + +func (snapshot *fileSnapshot) addSharedDecl(decl sharedDecl) { + snapshot.shared = append(snapshot.shared, decl) + for _, key := range decl.Keys { + if key == "" { + continue + } + snapshot.sharedKeys[key] = struct{}{} + } +} + +func (snapshot *fileSnapshot) hasAnySharedKey(keys []string) bool { + for _, key := range keys { + if _, ok := snapshot.sharedKeys[key]; ok { + return true + } + } + return false +} diff --git a/snapshot_test.go b/snapshot_test.go new file mode 100644 index 0000000..38db136 --- /dev/null +++ b/snapshot_test.go @@ -0,0 +1,58 @@ +package main + +import ( + "maps" + "slices" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseFileSnapshotRejectsLowercaseSuffixes(t *testing.T) { + t.Parallel() + + snapshot, err := parseFileSnapshot([]byte(`package sample + +import "testing" + +func TestAlpha(t *testing.T) {} +func Testify(t *testing.T) {} +func FuzzAlpha(f *testing.F) {} +func Fuzzbar(f *testing.F) {} +func Example() {} +func ExampleFoo() {} +func Examplefoo() {} +`)) + require.NoError(t, err) + require.Equal(t, []string{"Example", "ExampleFoo", "FuzzAlpha", "TestAlpha"}, slices.Sorted(maps.Keys(snapshot.tests))) +} + +func TestParseFileSnapshotRecordsStructure(t *testing.T) { + t.Parallel() + + snapshot, err := parseFileSnapshot([]byte(`package sample + +import . "testing" + +const answer = 42 +var packageValue = answer +type fixture struct{} + +func helper() {} +func init() {} +func TestMain(m *M) {} +func TestAlpha(t *T) {} +func FuzzAlpha(f *F) {} +`)) + require.NoError(t, err) + require.Equal(t, "sample", snapshot.packageName) + require.Equal(t, []string{"FuzzAlpha", "TestAlpha"}, slices.Sorted(maps.Keys(snapshot.tests))) + require.Contains(t, snapshot.sharedKeys, "const:answer") + require.Contains(t, snapshot.sharedKeys, "var:packageValue") + require.Contains(t, snapshot.sharedKeys, "type:fixture") + require.Contains(t, snapshot.sharedKeys, "func:helper") + require.Contains(t, snapshot.sharedKeys, "func:TestMain") + require.True(t, slices.ContainsFunc(snapshot.shared, func(decl sharedDecl) bool { + return decl.Kind == sharedDeclInit + })) +}