diff --git a/go.mod b/go.mod index 94bba07..000ce57 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.26 require ( github.com/byteness/percent v0.2.2 + github.com/google/uuid v1.6.0 github.com/open-cli-collective/cli-common v0.3.0 github.com/spf13/cobra v1.10.2 modernc.org/sqlite v1.51.0 @@ -11,7 +12,6 @@ require ( require ( github.com/dustin/go-humanize v1.0.1 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect diff --git a/internal/ledger/ledger.go b/internal/ledger/ledger.go new file mode 100644 index 0000000..5a1baca --- /dev/null +++ b/internal/ledger/ledger.go @@ -0,0 +1,1357 @@ +// Package ledger stores cr review runs in SQLite. +package ledger + +import ( + "context" + "database/sql" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/google/uuid" + "github.com/open-cli-collective/codereview-cli/internal/dbmig" + "github.com/open-cli-collective/codereview-cli/internal/review" + sqlite "modernc.org/sqlite" // register the SQLite database/sql driver. + sqlite3 "modernc.org/sqlite/lib" +) + +const ( + // SchemaVersion is the current ledger schema version. + SchemaVersion = 1 + // DefaultBusyTimeout is the SQLite busy timeout configured at open. + DefaultBusyTimeout = 5 * time.Second + writeQueueSize = 64 +) + +var ( + // ErrClosed means a mutating operation was attempted after Store.Close. + ErrClosed = errors.New("ledger: store closed") + // ErrNotFound means a requested ledger row does not exist. + ErrNotFound = errors.New("ledger: not found") + // ErrInvalidInput means a caller supplied an invalid storage value. + ErrInvalidInput = errors.New("ledger: invalid input") + // ErrRunExists means AllocateRun recovery mode received an existing run id. + ErrRunExists = errors.New("ledger: run already exists") +) + +// PostMode records whether a run is live or dry-run. +type PostMode string + +// PostMode values. +const ( + PostModeLive PostMode = "live" + PostModeDryRun PostMode = "dry_run" +) + +func (m PostMode) String() string { return string(m) } + +// Valid reports whether m is a known post mode. +func (m PostMode) Valid() bool { + switch m { + case PostModeLive, PostModeDryRun: + return true + default: + return false + } +} + +// ParsePostMode parses a storage post mode. +func ParsePostMode(value string) (PostMode, error) { + mode := PostMode(normalizeStorageValue(value)) + if !mode.Valid() { + return "", invalidInput("post_mode", value) + } + return mode, nil +} + +// Outcome records a finalized run outcome. +type Outcome string + +// Outcome values. +const ( + OutcomeIncomplete Outcome = "incomplete" + OutcomeApproved Outcome = "approved" + OutcomeRequestChanges Outcome = "request_changes" + OutcomeComment Outcome = "comment" + OutcomeNothingToReview Outcome = "nothing_to_review" + OutcomeDryRun Outcome = "dry_run" + OutcomeAborted Outcome = "aborted" + OutcomeFailed Outcome = "failed" +) + +func (o Outcome) String() string { return string(o) } + +// Valid reports whether o is a known run outcome. +func (o Outcome) Valid() bool { + switch o { + case OutcomeIncomplete, OutcomeApproved, OutcomeRequestChanges, OutcomeComment, OutcomeNothingToReview, OutcomeDryRun, OutcomeAborted, OutcomeFailed: + return true + default: + return false + } +} + +// ParseOutcome parses a storage outcome. +func ParseOutcome(value string) (Outcome, error) { + outcome := Outcome(normalizeStorageValue(value)) + if !outcome.Valid() { + return "", invalidInput("outcome", value) + } + return outcome, nil +} + +// SessionRole records the role attached to an LLM session row. +type SessionRole string + +// SessionRole values. +const ( + SessionRoleOrchestrator SessionRole = "orchestrator" + SessionRoleReviewer SessionRole = "reviewer" +) + +func (r SessionRole) String() string { return string(r) } + +// Valid reports whether r is a known session role. +func (r SessionRole) Valid() bool { + switch r { + case SessionRoleOrchestrator, SessionRoleReviewer: + return true + default: + return false + } +} + +// ParseSessionRole parses a storage session role. +func ParseSessionRole(value string) (SessionRole, error) { + role := SessionRole(normalizeStorageValue(value)) + if !role.Valid() { + return "", invalidInput("session role", value) + } + return role, nil +} + +// PlannedActionKind identifies the host-side action to execute later. +type PlannedActionKind string + +// PlannedActionKind values. +const ( + PlannedActionInlineComment PlannedActionKind = "inline_comment" + PlannedActionThreadReply PlannedActionKind = "thread_reply" + PlannedActionResolveThread PlannedActionKind = "resolve_thread" + PlannedActionRollupComment PlannedActionKind = "rollup_comment" + PlannedActionSubmitReview PlannedActionKind = "submit_review" +) + +func (k PlannedActionKind) String() string { return string(k) } + +// Valid reports whether k is a known planned action kind. +func (k PlannedActionKind) Valid() bool { + switch k { + case PlannedActionInlineComment, PlannedActionThreadReply, PlannedActionResolveThread, PlannedActionRollupComment, PlannedActionSubmitReview: + return true + default: + return false + } +} + +// ParsePlannedActionKind parses a storage planned action kind. +func ParsePlannedActionKind(value string) (PlannedActionKind, error) { + kind := PlannedActionKind(normalizeStorageValue(value)) + if !kind.Valid() { + return "", invalidInput("planned action kind", value) + } + return kind, nil +} + +// PlannedActionStatus records the durable outbox status for a planned action. +type PlannedActionStatus string + +// PlannedActionStatus values. +const ( + PlannedActionPending PlannedActionStatus = "pending" + PlannedActionPosted PlannedActionStatus = "posted" + PlannedActionFailedTerminal PlannedActionStatus = "failed_terminal" + PlannedActionSuperseded PlannedActionStatus = "superseded" + PlannedActionPlannedOnly PlannedActionStatus = "planned_only" +) + +func (s PlannedActionStatus) String() string { return string(s) } + +// Valid reports whether s is a known planned action status. +func (s PlannedActionStatus) Valid() bool { + switch s { + case PlannedActionPending, PlannedActionPosted, PlannedActionFailedTerminal, PlannedActionSuperseded, PlannedActionPlannedOnly: + return true + default: + return false + } +} + +// ParsePlannedActionStatus parses a storage planned action status. +func ParsePlannedActionStatus(value string) (PlannedActionStatus, error) { + status := PlannedActionStatus(normalizeStorageValue(value)) + if !status.Valid() { + return "", invalidInput("planned action status", value) + } + return status, nil +} + +// Store owns the SQLite connection and serializes mutating ledger writes. +type Store struct { + db *sql.DB + writes chan writeRequest + done chan struct{} + closeErr error + mu sync.Mutex + closed bool +} + +type writeRequest struct { + ctx context.Context + fn func(context.Context, *sql.DB) error + res chan error +} + +// PR is the stable pull-request identity row. +type PR struct { + PRKey string + PRURL string + FirstSeenAt time.Time +} + +// Run is one cr review invocation recorded in the ledger. +type Run struct { + RunID string + PRKey string + SHA string + BaseSHA string + Attempt int + Profile string + PostingIdentity string + PostMode PostMode + StartedAt time.Time + HeartbeatAt *time.Time + CompletedAt *time.Time + Outcome *Outcome + ArtifactPath string + BlockingCount int + MajorCount int + MinorCount int + NitsCount int +} + +// AllocateRunParams contains the inputs needed to create a run row. +type AllocateRunParams struct { + PRKey string + PRURL string + RunID string + SHA string + BaseSHA string + Profile string + PostingIdentity string + PostMode PostMode + StartedAt time.Time + ArtifactPath string +} + +// Session records LLM session usage for a run. +type Session struct { + SessionRowID string + RunID string + ProviderSessionID string + Role SessionRole + AgentID *string + Adapter string + Model string + Effort *string + StartedAt time.Time + CompletedAt *time.Time + DurationMS *int64 + TokensIn *int64 + TokensOut *int64 + CacheRead *int64 + CacheCreate *int64 + CostUSD *float64 +} + +// Finding records one harness-assigned review finding. +type Finding struct { + FindingID review.FindingID + RunID string + SessionRowID string + Severity review.Severity + FilePath string + Side *review.DiffSide + Line *int64 + DiffPosition *int64 + Anchoring review.Anchoring + Body string +} + +// PlannedAction records one outbox action planned for a run. +type PlannedAction struct { + ActionID string + RunID string + Kind PlannedActionKind + FindingID *string + ThreadID *string + PlannedAt time.Time + PayloadJSON string + Status PlannedActionStatus + Required bool + Attempts int + AttemptedAt *time.Time + PostedAt *time.Time + UpstreamID *string + Error *string +} + +// NamedSession records cross-run provider session reuse metadata. +type NamedSession struct { + Name string + Profile string + Provider string + Adapter string + Model string + Host string + ProviderSessionID string + CreatedAt time.Time + LastUsedAt time.Time +} + +// Open opens or creates a ledger database at path and applies migrations. +func Open(ctx context.Context, path string) (*Store, error) { + if strings.TrimSpace(path) == "" { + return nil, invalidInput("path", path) + } + if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil { + return nil, fmt.Errorf("ledger: create db parent: %w", err) + } + + db, err := sql.Open("sqlite", path) + if err != nil { + return nil, fmt.Errorf("ledger: open sqlite: %w", err) + } + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) + + if err := configureSQLite(ctx, db); err != nil { + _ = db.Close() + return nil, err + } + if _, err := dbmig.Apply(ctx, db, migrations()); err != nil { + _ = db.Close() + return nil, fmt.Errorf("ledger: migrate: %w", err) + } + + store := &Store{ + db: db, + writes: make(chan writeRequest, writeQueueSize), + done: make(chan struct{}), + } + go store.writer() + return store, nil +} + +// Close stops the writer goroutine and closes the database. +func (s *Store) Close() error { + s.mu.Lock() + if s.closed { + err := s.closeErr + s.mu.Unlock() + return err + } + s.closed = true + close(s.writes) + s.mu.Unlock() + + <-s.done + err := s.db.Close() + + s.mu.Lock() + s.closeErr = err + s.mu.Unlock() + return err +} + +func (s *Store) writer() { + defer close(s.done) + for req := range s.writes { + req.res <- req.fn(req.ctx, s.db) + close(req.res) + } +} + +func (s *Store) write(ctx context.Context, fn func(context.Context, *sql.DB) error) error { + s.mu.Lock() + if s.closed { + s.mu.Unlock() + return ErrClosed + } + req := writeRequest{ctx: ctx, fn: fn, res: make(chan error, 1)} + select { + case s.writes <- req: + s.mu.Unlock() + case <-ctx.Done(): + s.mu.Unlock() + return ctx.Err() + } + + // Once dispatched, report the writer result rather than racing context + // cancellation against a write that may already have committed. + return <-req.res +} + +func (s *Store) checkOpen() error { + s.mu.Lock() + defer s.mu.Unlock() + if s.closed { + return ErrClosed + } + return nil +} + +func configureSQLite(ctx context.Context, db *sql.DB) error { + if _, err := db.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil { + return fmt.Errorf("ledger: enable foreign keys: %w", err) + } + if _, err := db.ExecContext(ctx, "PRAGMA journal_mode = WAL"); err != nil { + return fmt.Errorf("ledger: enable WAL: %w", err) + } + if _, err := db.ExecContext(ctx, fmt.Sprintf("PRAGMA busy_timeout = %d", DefaultBusyTimeout.Milliseconds())); err != nil { + return fmt.Errorf("ledger: set busy timeout: %w", err) + } + return nil +} + +func migrations() []dbmig.Migration { + return []dbmig.Migration{{ + Version: 1, + Name: "ledger schema", + Up: func(ctx context.Context, tx *sql.Tx) error { + for _, statement := range schemaStatements { + if _, err := tx.ExecContext(ctx, statement); err != nil { + return err + } + } + return nil + }, + }} +} + +// AllocateRun creates a run and allocates its attempt transactionally. +func (s *Store) AllocateRun(ctx context.Context, params AllocateRunParams) (Run, error) { + if err := validateAllocateRunParams(params); err != nil { + return Run{}, err + } + + var run Run + err := s.write(ctx, func(ctx context.Context, db *sql.DB) error { + inserted, err := allocateRunTx(ctx, db, params) + if err != nil { + return err + } + run = inserted + return nil + }) + if err != nil { + return Run{}, err + } + return run, nil +} + +func allocateRunTx(ctx context.Context, db *sql.DB, params AllocateRunParams) (Run, error) { + for { + run, err := allocateRunOnce(ctx, db, params) + if err == nil { + return run, nil + } + retry, classifyErr := classifyAllocateConstraint(ctx, db, params, err) + if classifyErr != nil { + return Run{}, classifyErr + } + if !retry { + return Run{}, err + } + } +} + +func classifyAllocateConstraint(ctx context.Context, db *sql.DB, params AllocateRunParams, err error) (bool, error) { + if !isSQLiteConstraintError(err) { + return false, err + } + if params.RunID != "" { + exists, existsErr := runIDExists(ctx, db, params.RunID) + if existsErr != nil { + return false, existsErr + } + if exists { + return false, fmt.Errorf("%w: %s", ErrRunExists, params.RunID) + } + } + if isResumeAttemptConstraintError(err) { + return true, nil + } + return false, err +} + +func runIDExists(ctx context.Context, db *sql.DB, runID string) (bool, error) { + var exists int + if err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM runs WHERE run_id = ?", runID).Scan(&exists); err != nil { + return false, fmt.Errorf("ledger: check run id after constraint: %w", err) + } + return exists > 0, nil +} + +func allocateRunOnce(ctx context.Context, db *sql.DB, params AllocateRunParams) (Run, error) { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return Run{}, fmt.Errorf("ledger: begin allocate run: %w", err) + } + defer func() { + _ = tx.Rollback() + }() + + if params.RunID != "" { + var exists int + if err := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM runs WHERE run_id = ?", params.RunID).Scan(&exists); err != nil { + return Run{}, fmt.Errorf("ledger: check run id: %w", err) + } + if exists > 0 { + return Run{}, fmt.Errorf("%w: %s", ErrRunExists, params.RunID) + } + } + + if _, err := tx.ExecContext(ctx, ` +INSERT INTO prs (pr_key, pr_url, first_seen_at) +VALUES (?, ?, ?) +ON CONFLICT(pr_key) DO UPDATE SET pr_url = excluded.pr_url`, + params.PRKey, params.PRURL, encodeTime(params.StartedAt), + ); err != nil { + return Run{}, fmt.Errorf("ledger: upsert pr: %w", err) + } + + var attempt int + if err := tx.QueryRowContext(ctx, ` +SELECT COALESCE(MAX(attempt), 0) + 1 +FROM runs +WHERE pr_key = ? AND sha = ? AND base_sha = ? AND profile = ? AND posting_identity = ?`, + params.PRKey, params.SHA, params.BaseSHA, params.Profile, params.PostingIdentity, + ).Scan(&attempt); err != nil { + return Run{}, fmt.Errorf("ledger: allocate attempt: %w", err) + } + + runID := params.RunID + if runID == "" { + runID = uuid.NewString() + } + run := Run{ + RunID: runID, + PRKey: params.PRKey, + SHA: params.SHA, + BaseSHA: params.BaseSHA, + Attempt: attempt, + Profile: params.Profile, + PostingIdentity: params.PostingIdentity, + PostMode: params.PostMode, + StartedAt: params.StartedAt.UTC(), + ArtifactPath: params.ArtifactPath, + } + + if _, err := tx.ExecContext(ctx, ` +INSERT INTO runs ( + run_id, pr_key, sha, base_sha, attempt, profile, posting_identity, post_mode, + started_at, heartbeat_at, completed_at, outcome, artifact_path, + blocking_count, major_count, minor_count, nits_count +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + run.RunID, run.PRKey, run.SHA, run.BaseSHA, run.Attempt, run.Profile, run.PostingIdentity, + run.PostMode.String(), encodeTime(run.StartedAt), nil, nil, nil, run.ArtifactPath, + run.BlockingCount, run.MajorCount, run.MinorCount, run.NitsCount, + ); err != nil { + return Run{}, fmt.Errorf("ledger: insert run: %w", err) + } + + if err := tx.Commit(); err != nil { + return Run{}, fmt.Errorf("ledger: commit allocate run: %w", err) + } + return run, nil +} + +// GetPR returns a pull-request identity row by key. +func (s *Store) GetPR(ctx context.Context, prKey string) (PR, error) { + if strings.TrimSpace(prKey) == "" { + return PR{}, invalidInput("pr_key", prKey) + } + if err := s.checkOpen(); err != nil { + return PR{}, err + } + var pr PR + var firstSeenAt string + err := s.db.QueryRowContext(ctx, "SELECT pr_key, pr_url, first_seen_at FROM prs WHERE pr_key = ?", prKey). + Scan(&pr.PRKey, &pr.PRURL, &firstSeenAt) + if errors.Is(err, sql.ErrNoRows) { + return PR{}, ErrNotFound + } + if err != nil { + return PR{}, fmt.Errorf("ledger: get pr: %w", err) + } + parsed, err := parseTime(firstSeenAt) + if err != nil { + return PR{}, err + } + pr.FirstSeenAt = parsed + return pr, nil +} + +// GetRun returns a run by id. +func (s *Store) GetRun(ctx context.Context, runID string) (Run, error) { + if strings.TrimSpace(runID) == "" { + return Run{}, invalidInput("run_id", runID) + } + if err := s.checkOpen(); err != nil { + return Run{}, err + } + row := s.db.QueryRowContext(ctx, ` +SELECT run_id, pr_key, sha, base_sha, attempt, profile, posting_identity, post_mode, + started_at, heartbeat_at, completed_at, outcome, artifact_path, + blocking_count, major_count, minor_count, nits_count +FROM runs WHERE run_id = ?`, runID) + run, err := scanRun(row) + if errors.Is(err, sql.ErrNoRows) { + return Run{}, ErrNotFound + } + if err != nil { + return Run{}, fmt.Errorf("ledger: get run: %w", err) + } + return run, nil +} + +// DeleteRun deletes a run and lets SQLite cascade child rows. +func (s *Store) DeleteRun(ctx context.Context, runID string) error { + if strings.TrimSpace(runID) == "" { + return invalidInput("run_id", runID) + } + return s.write(ctx, func(ctx context.Context, db *sql.DB) error { + result, err := db.ExecContext(ctx, "DELETE FROM runs WHERE run_id = ?", runID) + if err != nil { + return fmt.Errorf("ledger: delete run: %w", err) + } + affected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("ledger: delete run rows affected: %w", err) + } + if affected == 0 { + return ErrNotFound + } + return nil + }) +} + +// InsertSession inserts an LLM session row. +func (s *Store) InsertSession(ctx context.Context, session Session) error { + if err := validateSession(session); err != nil { + return err + } + return s.write(ctx, func(ctx context.Context, db *sql.DB) error { + _, err := db.ExecContext(ctx, ` +INSERT INTO sessions ( + session_row_id, run_id, provider_session_id, role, agent_id, adapter, model, effort, + started_at, completed_at, duration_ms, tokens_in, tokens_out, cache_read, cache_create, cost_usd +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + session.SessionRowID, session.RunID, session.ProviderSessionID, session.Role.String(), session.AgentID, + session.Adapter, session.Model, session.Effort, encodeTime(session.StartedAt), encodeOptionalTime(session.CompletedAt), + session.DurationMS, session.TokensIn, session.TokensOut, session.CacheRead, session.CacheCreate, session.CostUSD, + ) + if err != nil { + return fmt.Errorf("ledger: insert session: %w", err) + } + return nil + }) +} + +// GetSession returns a session by internal row id. +func (s *Store) GetSession(ctx context.Context, sessionRowID string) (Session, error) { + if strings.TrimSpace(sessionRowID) == "" { + return Session{}, invalidInput("session_row_id", sessionRowID) + } + if err := s.checkOpen(); err != nil { + return Session{}, err + } + row := s.db.QueryRowContext(ctx, ` +SELECT session_row_id, run_id, provider_session_id, role, agent_id, adapter, model, effort, + started_at, completed_at, duration_ms, tokens_in, tokens_out, cache_read, cache_create, cost_usd +FROM sessions WHERE session_row_id = ?`, sessionRowID) + session, err := scanSession(row) + if errors.Is(err, sql.ErrNoRows) { + return Session{}, ErrNotFound + } + if err != nil { + return Session{}, fmt.Errorf("ledger: get session: %w", err) + } + return session, nil +} + +// InsertFinding inserts a validated harness finding row. +func (s *Store) InsertFinding(ctx context.Context, finding Finding) error { + if err := validateFinding(finding); err != nil { + return err + } + return s.write(ctx, func(ctx context.Context, db *sql.DB) error { + _, err := db.ExecContext(ctx, ` +INSERT INTO findings ( + finding_id, run_id, session_row_id, severity, file_path, side, line, diff_position, anchoring, body +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + finding.FindingID.String(), finding.RunID, finding.SessionRowID, finding.Severity.String(), finding.FilePath, + finding.Side, finding.Line, finding.DiffPosition, finding.Anchoring.String(), finding.Body, + ) + if err != nil { + return fmt.Errorf("ledger: insert finding: %w", err) + } + return nil + }) +} + +// ListFindings lists findings for a run in stable order. +func (s *Store) ListFindings(ctx context.Context, runID string) ([]Finding, error) { + if strings.TrimSpace(runID) == "" { + return nil, invalidInput("run_id", runID) + } + if err := s.checkOpen(); err != nil { + return nil, err + } + rows, err := s.db.QueryContext(ctx, ` +SELECT finding_id, run_id, session_row_id, severity, file_path, side, line, diff_position, anchoring, body +FROM findings WHERE run_id = ? ORDER BY finding_id`, runID) + if err != nil { + return nil, fmt.Errorf("ledger: list findings: %w", err) + } + defer rows.Close() + + var findings []Finding + for rows.Next() { + finding, err := scanFinding(rows) + if err != nil { + return nil, err + } + findings = append(findings, finding) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("ledger: list findings rows: %w", err) + } + return findings, nil +} + +// InsertPlannedAction inserts an outbox planned-action row. +func (s *Store) InsertPlannedAction(ctx context.Context, action PlannedAction) error { + if err := validatePlannedAction(action); err != nil { + return err + } + return s.write(ctx, func(ctx context.Context, db *sql.DB) error { + required := 0 + if action.Required { + required = 1 + } + _, err := db.ExecContext(ctx, ` +INSERT INTO planned_actions ( + action_id, run_id, kind, finding_id, thread_id, planned_at, payload_json, status, + required, attempts, attempted_at, posted_at, upstream_id, error +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + action.ActionID, action.RunID, action.Kind.String(), action.FindingID, action.ThreadID, encodeTime(action.PlannedAt), + action.PayloadJSON, action.Status.String(), required, action.Attempts, encodeOptionalTime(action.AttemptedAt), + encodeOptionalTime(action.PostedAt), action.UpstreamID, action.Error, + ) + if err != nil { + return fmt.Errorf("ledger: insert planned action: %w", err) + } + return nil + }) +} + +// ListPlannedActions lists planned actions for a run in stable order. +func (s *Store) ListPlannedActions(ctx context.Context, runID string) ([]PlannedAction, error) { + if strings.TrimSpace(runID) == "" { + return nil, invalidInput("run_id", runID) + } + if err := s.checkOpen(); err != nil { + return nil, err + } + rows, err := s.db.QueryContext(ctx, ` +SELECT action_id, run_id, kind, finding_id, thread_id, planned_at, payload_json, status, + required, attempts, attempted_at, posted_at, upstream_id, error +FROM planned_actions WHERE run_id = ? ORDER BY action_id`, runID) + if err != nil { + return nil, fmt.Errorf("ledger: list planned actions: %w", err) + } + defer rows.Close() + + var actions []PlannedAction + for rows.Next() { + action, err := scanPlannedAction(rows) + if err != nil { + return nil, err + } + actions = append(actions, action) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("ledger: list planned action rows: %w", err) + } + return actions, nil +} + +// UpsertNamedSession inserts or updates a named provider session row. +func (s *Store) UpsertNamedSession(ctx context.Context, session NamedSession) error { + if err := validateNamedSession(session); err != nil { + return err + } + return s.write(ctx, func(ctx context.Context, db *sql.DB) error { + _, err := db.ExecContext(ctx, ` +INSERT INTO named_sessions ( + name, profile, provider, adapter, model, host, provider_session_id, created_at, last_used_at +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) +ON CONFLICT(name) DO UPDATE SET + profile = excluded.profile, + provider = excluded.provider, + adapter = excluded.adapter, + model = excluded.model, + host = excluded.host, + provider_session_id = excluded.provider_session_id, + last_used_at = excluded.last_used_at`, + session.Name, session.Profile, session.Provider, session.Adapter, session.Model, session.Host, + session.ProviderSessionID, encodeTime(session.CreatedAt), encodeTime(session.LastUsedAt), + ) + if err != nil { + return fmt.Errorf("ledger: upsert named session: %w", err) + } + return nil + }) +} + +// GetNamedSession returns a named provider session row. +func (s *Store) GetNamedSession(ctx context.Context, name string) (NamedSession, error) { + if strings.TrimSpace(name) == "" { + return NamedSession{}, invalidInput("name", name) + } + if err := s.checkOpen(); err != nil { + return NamedSession{}, err + } + row := s.db.QueryRowContext(ctx, ` +SELECT name, profile, provider, adapter, model, host, provider_session_id, created_at, last_used_at +FROM named_sessions WHERE name = ?`, name) + session, err := scanNamedSession(row) + if errors.Is(err, sql.ErrNoRows) { + return NamedSession{}, ErrNotFound + } + if err != nil { + return NamedSession{}, fmt.Errorf("ledger: get named session: %w", err) + } + return session, nil +} + +func validateAllocateRunParams(params AllocateRunParams) error { + for field, value := range map[string]string{ + "pr_key": params.PRKey, + "pr_url": params.PRURL, + "sha": params.SHA, + "base_sha": params.BaseSHA, + "profile": params.Profile, + "posting_identity": params.PostingIdentity, + "artifact_path": params.ArtifactPath, + } { + if strings.TrimSpace(value) == "" { + return invalidInput(field, value) + } + } + if params.StartedAt.IsZero() { + return invalidInput("started_at", "") + } + if !params.PostMode.Valid() { + return invalidInput("post_mode", params.PostMode.String()) + } + return nil +} + +func validateSession(session Session) error { + for field, value := range map[string]string{ + "session_row_id": session.SessionRowID, + "run_id": session.RunID, + "provider_session_id": session.ProviderSessionID, + "adapter": session.Adapter, + "model": session.Model, + } { + if strings.TrimSpace(value) == "" { + return invalidInput(field, value) + } + } + if !session.Role.Valid() { + return invalidInput("role", session.Role.String()) + } + if session.StartedAt.IsZero() { + return invalidInput("started_at", "") + } + return nil +} + +func validateFinding(finding Finding) error { + for field, value := range map[string]string{ + "finding_id": finding.FindingID.String(), + "run_id": finding.RunID, + "session_row_id": finding.SessionRowID, + "file_path": finding.FilePath, + "body": finding.Body, + } { + if strings.TrimSpace(value) == "" { + return invalidInput(field, value) + } + } + if !finding.Severity.Valid() { + return invalidInput("severity", finding.Severity.String()) + } + if finding.Side != nil && !finding.Side.Valid() { + return invalidInput("side", finding.Side.String()) + } + if !finding.Anchoring.Valid() { + return invalidInput("anchoring", finding.Anchoring.String()) + } + return nil +} + +func validatePlannedAction(action PlannedAction) error { + for field, value := range map[string]string{ + "action_id": action.ActionID, + "run_id": action.RunID, + "payload_json": action.PayloadJSON, + } { + if strings.TrimSpace(value) == "" { + return invalidInput(field, value) + } + } + if !action.Kind.Valid() { + return invalidInput("kind", action.Kind.String()) + } + if !action.Status.Valid() { + return invalidInput("status", action.Status.String()) + } + if action.PlannedAt.IsZero() { + return invalidInput("planned_at", "") + } + if action.Attempts < 0 { + return invalidInput("attempts", fmt.Sprint(action.Attempts)) + } + return nil +} + +func validateNamedSession(session NamedSession) error { + for field, value := range map[string]string{ + "name": session.Name, + "profile": session.Profile, + "provider": session.Provider, + "adapter": session.Adapter, + "model": session.Model, + "host": session.Host, + "provider_session_id": session.ProviderSessionID, + } { + if strings.TrimSpace(value) == "" { + return invalidInput(field, value) + } + } + if session.CreatedAt.IsZero() { + return invalidInput("created_at", "") + } + if session.LastUsedAt.IsZero() { + return invalidInput("last_used_at", "") + } + return nil +} + +func scanRun(row interface{ Scan(...any) error }) (Run, error) { + var ( + run Run + postMode string + startedAt string + heartbeatAt sql.NullString + completedAt sql.NullString + outcome sql.NullString + blocking, major int + minor, nits int + ) + if err := row.Scan( + &run.RunID, &run.PRKey, &run.SHA, &run.BaseSHA, &run.Attempt, &run.Profile, &run.PostingIdentity, + &postMode, &startedAt, &heartbeatAt, &completedAt, &outcome, &run.ArtifactPath, + &blocking, &major, &minor, &nits, + ); err != nil { + return Run{}, err + } + parsedPostMode, err := ParsePostMode(postMode) + if err != nil { + return Run{}, err + } + run.PostMode = parsedPostMode + run.StartedAt, err = parseTime(startedAt) + if err != nil { + return Run{}, err + } + if heartbeatAt.Valid { + parsed, err := parseTime(heartbeatAt.String) + if err != nil { + return Run{}, err + } + run.HeartbeatAt = &parsed + } + if completedAt.Valid { + parsed, err := parseTime(completedAt.String) + if err != nil { + return Run{}, err + } + run.CompletedAt = &parsed + } + if outcome.Valid { + parsed, err := ParseOutcome(outcome.String) + if err != nil { + return Run{}, err + } + run.Outcome = &parsed + } + run.BlockingCount = blocking + run.MajorCount = major + run.MinorCount = minor + run.NitsCount = nits + return run, nil +} + +func scanSession(row interface{ Scan(...any) error }) (Session, error) { + var ( + session Session + role string + agentID sql.NullString + effort sql.NullString + startedAt string + completedAt sql.NullString + durationMS sql.NullInt64 + tokensIn sql.NullInt64 + tokensOut sql.NullInt64 + cacheRead sql.NullInt64 + cacheCreate sql.NullInt64 + costUSD sql.NullFloat64 + ) + if err := row.Scan( + &session.SessionRowID, &session.RunID, &session.ProviderSessionID, &role, &agentID, &session.Adapter, + &session.Model, &effort, &startedAt, &completedAt, &durationMS, &tokensIn, &tokensOut, + &cacheRead, &cacheCreate, &costUSD, + ); err != nil { + return Session{}, err + } + parsedRole, err := ParseSessionRole(role) + if err != nil { + return Session{}, err + } + session.Role = parsedRole + session.AgentID = stringPtrFromNull(agentID) + session.Effort = stringPtrFromNull(effort) + session.StartedAt, err = parseTime(startedAt) + if err != nil { + return Session{}, err + } + session.CompletedAt, err = timePtrFromNull(completedAt) + if err != nil { + return Session{}, err + } + session.DurationMS = int64PtrFromNull(durationMS) + session.TokensIn = int64PtrFromNull(tokensIn) + session.TokensOut = int64PtrFromNull(tokensOut) + session.CacheRead = int64PtrFromNull(cacheRead) + session.CacheCreate = int64PtrFromNull(cacheCreate) + session.CostUSD = float64PtrFromNull(costUSD) + return session, nil +} + +func scanFinding(row interface{ Scan(...any) error }) (Finding, error) { + var ( + finding Finding + findingID string + severity string + side sql.NullString + line sql.NullInt64 + diffPosition sql.NullInt64 + anchoring string + ) + if err := row.Scan( + &findingID, &finding.RunID, &finding.SessionRowID, &severity, &finding.FilePath, + &side, &line, &diffPosition, &anchoring, &finding.Body, + ); err != nil { + return Finding{}, err + } + finding.FindingID = review.FindingID(findingID) + parsedSeverity, err := review.ParseSeverity(severity) + if err != nil { + return Finding{}, err + } + finding.Severity = parsedSeverity + if side.Valid { + parsed, err := review.ParseDiffSide(side.String) + if err != nil { + return Finding{}, err + } + finding.Side = &parsed + } + finding.Line = int64PtrFromNull(line) + finding.DiffPosition = int64PtrFromNull(diffPosition) + parsedAnchoring, err := review.ParseAnchoring(anchoring) + if err != nil { + return Finding{}, err + } + finding.Anchoring = parsedAnchoring + return finding, nil +} + +func scanPlannedAction(row interface{ Scan(...any) error }) (PlannedAction, error) { + var ( + action PlannedAction + kind string + findingID sql.NullString + threadID sql.NullString + plannedAt string + status string + required int + attemptedAt sql.NullString + postedAt sql.NullString + upstreamID sql.NullString + errorText sql.NullString + ) + if err := row.Scan( + &action.ActionID, &action.RunID, &kind, &findingID, &threadID, &plannedAt, &action.PayloadJSON, + &status, &required, &action.Attempts, &attemptedAt, &postedAt, &upstreamID, &errorText, + ); err != nil { + return PlannedAction{}, err + } + parsedKind, err := ParsePlannedActionKind(kind) + if err != nil { + return PlannedAction{}, err + } + action.Kind = parsedKind + parsedStatus, err := ParsePlannedActionStatus(status) + if err != nil { + return PlannedAction{}, err + } + action.Status = parsedStatus + action.FindingID = stringPtrFromNull(findingID) + action.ThreadID = stringPtrFromNull(threadID) + action.PlannedAt, err = parseTime(plannedAt) + if err != nil { + return PlannedAction{}, err + } + action.Required = required != 0 + action.AttemptedAt, err = timePtrFromNull(attemptedAt) + if err != nil { + return PlannedAction{}, err + } + action.PostedAt, err = timePtrFromNull(postedAt) + if err != nil { + return PlannedAction{}, err + } + action.UpstreamID = stringPtrFromNull(upstreamID) + action.Error = stringPtrFromNull(errorText) + return action, nil +} + +func scanNamedSession(row interface{ Scan(...any) error }) (NamedSession, error) { + var session NamedSession + var createdAt, lastUsedAt string + if err := row.Scan( + &session.Name, &session.Profile, &session.Provider, &session.Adapter, &session.Model, &session.Host, + &session.ProviderSessionID, &createdAt, &lastUsedAt, + ); err != nil { + return NamedSession{}, err + } + parsedCreated, err := parseTime(createdAt) + if err != nil { + return NamedSession{}, err + } + parsedLastUsed, err := parseTime(lastUsedAt) + if err != nil { + return NamedSession{}, err + } + session.CreatedAt = parsedCreated + session.LastUsedAt = parsedLastUsed + return session, nil +} + +func encodeTime(value time.Time) string { + return value.UTC().Format(time.RFC3339Nano) +} + +func encodeOptionalTime(value *time.Time) any { + if value == nil { + return nil + } + return encodeTime(*value) +} + +func parseTime(value string) (time.Time, error) { + parsed, err := time.Parse(time.RFC3339Nano, value) + if err != nil { + return time.Time{}, fmt.Errorf("ledger: parse time %q: %w", value, err) + } + return parsed, nil +} + +func timePtrFromNull(value sql.NullString) (*time.Time, error) { + if !value.Valid { + return nil, nil + } + parsed, err := parseTime(value.String) + if err != nil { + return nil, err + } + return &parsed, nil +} + +func stringPtrFromNull(value sql.NullString) *string { + if !value.Valid { + return nil + } + return &value.String +} + +func int64PtrFromNull(value sql.NullInt64) *int64 { + if !value.Valid { + return nil + } + return &value.Int64 +} + +func float64PtrFromNull(value sql.NullFloat64) *float64 { + if !value.Valid { + return nil + } + return &value.Float64 +} + +func normalizeStorageValue(value string) string { + return strings.ToLower(strings.TrimSpace(value)) +} + +func invalidInput(field, value string) error { + return fmt.Errorf("%w: %s %q", ErrInvalidInput, field, value) +} + +func isSQLiteConstraintError(err error) bool { + var sqliteErr *sqlite.Error + if !errors.As(err, &sqliteErr) { + return false + } + return sqliteErr.Code()&0xff == sqlite3.SQLITE_CONSTRAINT +} + +func isResumeAttemptConstraintError(err error) bool { + message := err.Error() + for _, column := range []string{ + "runs.pr_key", + "runs.sha", + "runs.base_sha", + "runs.profile", + "runs.posting_identity", + "runs.attempt", + } { + if !strings.Contains(message, column) { + return false + } + } + return true +} + +var schemaStatements = []string{ + `CREATE TABLE prs ( + pr_key TEXT PRIMARY KEY, + pr_url TEXT NOT NULL, + first_seen_at TEXT NOT NULL +)`, + `CREATE TABLE runs ( + run_id TEXT PRIMARY KEY, + pr_key TEXT NOT NULL REFERENCES prs(pr_key), + sha TEXT NOT NULL, + base_sha TEXT NOT NULL, + attempt INTEGER NOT NULL, + profile TEXT NOT NULL, + posting_identity TEXT NOT NULL, + post_mode TEXT NOT NULL DEFAULT 'live', + started_at TEXT NOT NULL, + heartbeat_at TEXT, + completed_at TEXT, + outcome TEXT, + artifact_path TEXT NOT NULL, + blocking_count INTEGER NOT NULL DEFAULT 0, + major_count INTEGER NOT NULL DEFAULT 0, + minor_count INTEGER NOT NULL DEFAULT 0, + nits_count INTEGER NOT NULL DEFAULT 0, + UNIQUE(pr_key, sha, base_sha, profile, posting_identity, attempt) +)`, + `CREATE INDEX runs_pr_sha ON runs(pr_key, sha)`, + `CREATE INDEX runs_resume ON runs(pr_key, sha, base_sha, profile, posting_identity, post_mode, outcome)`, + `CREATE INDEX runs_started_at ON runs(started_at)`, + `CREATE TABLE sessions ( + session_row_id TEXT PRIMARY KEY, + run_id TEXT NOT NULL REFERENCES runs(run_id) ON DELETE CASCADE, + provider_session_id TEXT NOT NULL, + role TEXT NOT NULL, + agent_id TEXT, + adapter TEXT NOT NULL, + model TEXT NOT NULL, + effort TEXT, + started_at TEXT NOT NULL, + completed_at TEXT, + duration_ms INTEGER, + tokens_in INTEGER, + tokens_out INTEGER, + cache_read INTEGER, + cache_create INTEGER, + cost_usd REAL +)`, + `CREATE INDEX sessions_run ON sessions(run_id)`, + `CREATE INDEX sessions_provider ON sessions(provider_session_id)`, + `CREATE TABLE findings ( + finding_id TEXT PRIMARY KEY, + run_id TEXT NOT NULL REFERENCES runs(run_id) ON DELETE CASCADE, + session_row_id TEXT NOT NULL REFERENCES sessions(session_row_id) ON DELETE CASCADE, + severity TEXT NOT NULL, + file_path TEXT NOT NULL, + side TEXT, + line INTEGER, + diff_position INTEGER, + anchoring TEXT NOT NULL, + body TEXT NOT NULL +)`, + `CREATE INDEX findings_run ON findings(run_id)`, + `CREATE TABLE planned_actions ( + action_id TEXT PRIMARY KEY, + run_id TEXT NOT NULL REFERENCES runs(run_id) ON DELETE CASCADE, + kind TEXT NOT NULL, + finding_id TEXT REFERENCES findings(finding_id) ON DELETE CASCADE, + thread_id TEXT, + planned_at TEXT NOT NULL, + payload_json TEXT NOT NULL, + status TEXT NOT NULL, + required INTEGER NOT NULL DEFAULT 0, + attempts INTEGER NOT NULL DEFAULT 0, + attempted_at TEXT, + posted_at TEXT, + upstream_id TEXT, + error TEXT +)`, + `CREATE INDEX planned_actions_run ON planned_actions(run_id)`, + `CREATE INDEX planned_actions_status ON planned_actions(status)`, + `CREATE TABLE named_sessions ( + name TEXT PRIMARY KEY, + profile TEXT NOT NULL, + provider TEXT NOT NULL, + adapter TEXT NOT NULL, + model TEXT NOT NULL, + host TEXT NOT NULL, + provider_session_id TEXT NOT NULL, + created_at TEXT NOT NULL, + last_used_at TEXT NOT NULL +)`, +} diff --git a/internal/ledger/ledger_test.go b/internal/ledger/ledger_test.go new file mode 100644 index 0000000..fc0f752 --- /dev/null +++ b/internal/ledger/ledger_test.go @@ -0,0 +1,1173 @@ +package ledger + +import ( + "context" + "database/sql" + "errors" + "path/filepath" + "reflect" + "slices" + "strconv" + "sync" + "testing" + "time" + + "github.com/open-cli-collective/codereview-cli/internal/review" + _ "modernc.org/sqlite" +) + +func TestOpenMigratesFreshDatabaseAndAppliesStartupContract(t *testing.T) { + store := openStore(t) + + if version := queryInt(t, store.db, "SELECT schema_version FROM meta"); version != SchemaVersion { + t.Fatalf("schema_version = %d, want %d", version, SchemaVersion) + } + if got := queryInt(t, store.db, "PRAGMA foreign_keys"); got != 1 { + t.Fatalf("PRAGMA foreign_keys = %d, want 1", got) + } + if got := queryString(t, store.db, "PRAGMA journal_mode"); got != "wal" { + t.Fatalf("PRAGMA journal_mode = %q, want wal", got) + } + if got := queryInt(t, store.db, "PRAGMA busy_timeout"); int64(got) != DefaultBusyTimeout.Milliseconds() { + t.Fatalf("PRAGMA busy_timeout = %d, want %d", got, DefaultBusyTimeout.Milliseconds()) + } + + for _, table := range []string{"prs", "runs", "sessions", "findings", "planned_actions", "named_sessions"} { + if !tableExists(t, store.db, table) { + t.Fatalf("table %s does not exist", table) + } + } + for _, index := range []string{"runs_pr_sha", "runs_resume", "runs_started_at", "sessions_run", "sessions_provider", "findings_run", "planned_actions_run", "planned_actions_status"} { + if !indexExists(t, store.db, index) { + t.Fatalf("index %s does not exist", index) + } + } + wantResumeIndex := []string{"pr_key", "sha", "base_sha", "profile", "posting_identity", "post_mode", "outcome"} + if got := indexColumns(t, store.db, "runs_resume"); !reflect.DeepEqual(got, wantResumeIndex) { + t.Fatalf("runs_resume columns = %#v, want %#v", got, wantResumeIndex) + } +} + +func TestForeignKeyCascadeDeletesRunChildren(t *testing.T) { + store := openStore(t) + ctx := context.Background() + run := allocateRun(t, store, validAllocateRunParams()) + session := validSession(run.RunID) + insertSession(t, store, session) + insertFinding(t, store, validFinding(run.RunID, session.SessionRowID)) + insertPlannedAction(t, store, validPlannedAction(run.RunID)) + + if err := store.DeleteRun(ctx, run.RunID); err != nil { + t.Fatalf("DeleteRun: %v", err) + } + + if _, err := store.GetRun(ctx, run.RunID); !errors.Is(err, ErrNotFound) { + t.Fatalf("GetRun after delete error = %v, want ErrNotFound", err) + } + if count := queryInt(t, store.db, "SELECT COUNT(*) FROM sessions"); count != 0 { + t.Fatalf("sessions count = %d, want 0", count) + } + if count := queryInt(t, store.db, "SELECT COUNT(*) FROM findings"); count != 0 { + t.Fatalf("findings count = %d, want 0", count) + } + if count := queryInt(t, store.db, "SELECT COUNT(*) FROM planned_actions"); count != 0 { + t.Fatalf("planned_actions count = %d, want 0", count) + } +} + +func TestForeignKeyCascadeDeletesSessionChildren(t *testing.T) { + store := openStore(t) + run := allocateRun(t, store, validAllocateRunParams()) + session := validSession(run.RunID) + insertSession(t, store, session) + insertFinding(t, store, validFinding(run.RunID, session.SessionRowID)) + insertPlannedAction(t, store, validPlannedAction(run.RunID)) + + execSQL(t, store.db, "DELETE FROM sessions WHERE session_row_id = ?", session.SessionRowID) + + if count := queryInt(t, store.db, "SELECT COUNT(*) FROM runs"); count != 1 { + t.Fatalf("runs count = %d, want 1", count) + } + if count := queryInt(t, store.db, "SELECT COUNT(*) FROM sessions"); count != 0 { + t.Fatalf("sessions count = %d, want 0", count) + } + if count := queryInt(t, store.db, "SELECT COUNT(*) FROM findings"); count != 0 { + t.Fatalf("findings count = %d, want 0", count) + } + if count := queryInt(t, store.db, "SELECT COUNT(*) FROM planned_actions"); count != 0 { + t.Fatalf("planned_actions count = %d, want 0", count) + } +} + +func TestRunUniqueResumeAttemptConstraint(t *testing.T) { + store := openStore(t) + run := allocateRun(t, store, validAllocateRunParams()) + + _, err := store.db.ExecContext( + context.Background(), + `INSERT INTO runs ( + run_id, pr_key, sha, base_sha, attempt, profile, posting_identity, post_mode, + started_at, artifact_path + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + "manual-run", run.PRKey, run.SHA, run.BaseSHA, run.Attempt, run.Profile, run.PostingIdentity, + run.PostMode.String(), encodeTime(run.StartedAt), "/tmp/manual", + ) + if err == nil { + t.Fatal("duplicate resume attempt insert error = nil, want unique constraint failure") + } +} + +func TestAllocateRunAllocatesSequentialAttempts(t *testing.T) { + store := openStore(t) + params := validAllocateRunParams() + + first := allocateRun(t, store, params) + params.RunID = "" + params.StartedAt = params.StartedAt.Add(time.Minute) + params.ArtifactPath = "/tmp/run-2" + second := allocateRun(t, store, params) + + if first.Attempt != 1 { + t.Fatalf("first attempt = %d, want 1", first.Attempt) + } + if second.Attempt != 2 { + t.Fatalf("second attempt = %d, want 2", second.Attempt) + } +} + +func TestAllocateRunSeparatesAttemptSequencesByResumeKey(t *testing.T) { + store := openStore(t) + base := validAllocateRunParams() + baseRun := allocateRun(t, store, base) + if baseRun.Attempt != 1 { + t.Fatalf("base attempt = %d, want 1", baseRun.Attempt) + } + + tests := []struct { + name string + mutate func(*AllocateRunParams) + }{ + {name: "pr key", mutate: func(p *AllocateRunParams) { p.PRKey = "github_other_repo_1" }}, + {name: "sha", mutate: func(p *AllocateRunParams) { p.SHA = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb" }}, + {name: "base sha", mutate: func(p *AllocateRunParams) { p.BaseSHA = "cccccccccccccccccccccccccccccccccccccccc" }}, + {name: "profile", mutate: func(p *AllocateRunParams) { p.Profile = "other" }}, + {name: "posting identity", mutate: func(p *AllocateRunParams) { p.PostingIdentity = "other@example.com" }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + params := base + params.RunID = "" + params.ArtifactPath = filepath.Join("/tmp", tt.name) + tt.mutate(¶ms) + + run := allocateRun(t, store, params) + if run.Attempt != 1 { + t.Fatalf("attempt = %d, want independent sequence at 1", run.Attempt) + } + }) + } +} + +func TestAllocateRunConcurrentSameKey(t *testing.T) { + store := openStore(t) + ctx := context.Background() + params := validAllocateRunParams() + + const callers = 2 + var wg sync.WaitGroup + results := make(chan Run, callers) + errs := make(chan error, callers) + for i := range callers { + wg.Add(1) + go func(i int) { + defer wg.Done() + p := params + p.RunID = "" + p.ArtifactPath = filepath.Join("/tmp", "concurrent", strconv.Itoa(i)) + run, err := store.AllocateRun(ctx, p) + if err != nil { + errs <- err + return + } + results <- run + }(i) + } + wg.Wait() + close(results) + close(errs) + + for err := range errs { + t.Fatalf("AllocateRun concurrent error: %v", err) + } + var attempts []int + for run := range results { + attempts = append(attempts, run.Attempt) + } + slices.Sort(attempts) + if !reflect.DeepEqual(attempts, []int{1, 2}) { + t.Fatalf("attempts = %v, want [1 2]", attempts) + } +} + +func TestAllocateRunRecoveryMode(t *testing.T) { + store := openStore(t) + params := validAllocateRunParams() + params.RunID = "recovered-run" + + run := allocateRun(t, store, params) + if run.RunID != "recovered-run" { + t.Fatalf("RunID = %q, want recovered-run", run.RunID) + } + if run.Attempt != 1 { + t.Fatalf("Attempt = %d, want 1", run.Attempt) + } + + got, err := store.GetRun(context.Background(), "recovered-run") + if err != nil { + t.Fatalf("GetRun: %v", err) + } + if got.RunID != run.RunID || got.Attempt != run.Attempt { + t.Fatalf("GetRun = %#v, want run %#v", got, run) + } +} + +func TestAllocateRunRecoveryModeUsesNextAttempt(t *testing.T) { + store := openStore(t) + params := validAllocateRunParams() + first := allocateRun(t, store, params) + + params.RunID = "recovered-run" + params.StartedAt = params.StartedAt.Add(time.Minute) + params.ArtifactPath = "/tmp/recovered-run" + recovered := allocateRun(t, store, params) + + if first.Attempt != 1 { + t.Fatalf("first attempt = %d, want 1", first.Attempt) + } + if recovered.Attempt != 2 { + t.Fatalf("recovered attempt = %d, want 2", recovered.Attempt) + } +} + +func TestAllocateRunRecoveryRejectsExistingRunID(t *testing.T) { + store := openStore(t) + params := validAllocateRunParams() + params.RunID = "same-run" + original := allocateRun(t, store, params) + + params.PRURL = "https://example.test/changed" + params.ArtifactPath = "/tmp/changed" + _, err := store.AllocateRun(context.Background(), params) + if !errors.Is(err, ErrRunExists) { + t.Fatalf("AllocateRun duplicate error = %v, want ErrRunExists", err) + } + + got, err := store.GetRun(context.Background(), original.RunID) + if err != nil { + t.Fatalf("GetRun original: %v", err) + } + if got.ArtifactPath != original.ArtifactPath { + t.Fatalf("ArtifactPath = %q, want original %q", got.ArtifactPath, original.ArtifactPath) + } + pr, err := store.GetPR(context.Background(), original.PRKey) + if err != nil { + t.Fatalf("GetPR original: %v", err) + } + if pr.PRURL != "https://example.test/pr/36" { + t.Fatalf("PRURL = %q, want original URL", pr.PRURL) + } + if count := queryInt(t, store.db, "SELECT COUNT(*) FROM runs"); count != 1 { + t.Fatalf("runs count = %d, want 1", count) + } +} + +func TestClassifyAllocateConstraintForRecoveryMode(t *testing.T) { + store := openStore(t) + params := validAllocateRunParams() + params.RunID = "existing-run" + run := allocateRun(t, store, params) + duplicateRunIDErr := duplicateRunIDConstraintError(t, store, run) + + retry, err := classifyAllocateConstraint(context.Background(), store.db, params, duplicateRunIDErr) + if !errors.Is(err, ErrRunExists) { + t.Fatalf("classify existing run error = %v, want ErrRunExists", err) + } + if retry { + t.Fatal("classify existing run retry = true, want false") + } + + params.RunID = "missing-run" + retry, err = classifyAllocateConstraint(context.Background(), store.db, params, duplicateRunIDErr) + if err == nil { + t.Fatal("classify missing run id constraint error = nil, want original constraint") + } + if retry { + t.Fatal("classify missing run id constraint retry = true, want false") + } + + resumeErr := duplicateResumeAttemptConstraintError(t, store, run) + retry, err = classifyAllocateConstraint(context.Background(), store.db, params, resumeErr) + if err != nil { + t.Fatalf("classify recovery resume collision error = %v, want nil", err) + } + if !retry { + t.Fatal("classify recovery resume collision retry = false, want true") + } + + params.RunID = "" + retry, err = classifyAllocateConstraint(context.Background(), store.db, params, resumeErr) + if err != nil { + t.Fatalf("classify fresh resume collision error = %v, want nil", err) + } + if !retry { + t.Fatal("classify fresh resume collision retry = false, want true") + } + + params.RunID = "missing-run" + foreignKeyErr := missingSessionForeignKeyConstraintError(t, store) + retry, err = classifyAllocateConstraint(context.Background(), store.db, params, foreignKeyErr) + if err == nil { + t.Fatal("classify foreign-key constraint error = nil, want original constraint") + } + if retry { + t.Fatal("classify foreign-key constraint retry = true, want false") + } +} + +func TestSQLiteConstraintClassificationRequiresDriverError(t *testing.T) { + if isSQLiteConstraintError(errors.New("constraint failed")) { + t.Fatal("isSQLiteConstraintError(text error) = true, want false") + } +} + +func TestAllocateRunPreservesPRFirstSeenAndUpdatesURL(t *testing.T) { + store := openStore(t) + params := validAllocateRunParams() + first := allocateRun(t, store, params) + + params.RunID = "" + params.PRURL = "https://example.test/new-url" + params.StartedAt = params.StartedAt.Add(time.Minute) + params.ArtifactPath = "/tmp/second" + allocateRun(t, store, params) + + pr, err := store.GetPR(context.Background(), first.PRKey) + if err != nil { + t.Fatalf("GetPR: %v", err) + } + if pr.PRURL != "https://example.test/new-url" { + t.Fatalf("PRURL = %q, want updated URL", pr.PRURL) + } + if !pr.FirstSeenAt.Equal(first.StartedAt) { + t.Fatalf("FirstSeenAt = %s, want first run started_at %s", pr.FirstSeenAt, first.StartedAt) + } +} + +func TestAllocateRunRollsBackPRUpsertWhenRunInsertFails(t *testing.T) { + store := openStore(t) + params := validAllocateRunParams() + + execSQL(t, store.db, `CREATE TRIGGER fail_runs_insert +BEFORE INSERT ON runs +BEGIN + INSERT INTO missing_table VALUES (1); +END`) + + _, err := store.AllocateRun(context.Background(), params) + if err == nil { + t.Fatal("AllocateRun error = nil, want trigger failure") + } + if count := queryInt(t, store.db, "SELECT COUNT(*) FROM prs WHERE pr_key = ?", params.PRKey); count != 0 { + t.Fatalf("prs count for failed allocation = %d, want 0", count) + } + if count := queryInt(t, store.db, "SELECT COUNT(*) FROM runs WHERE run_id = ?", params.RunID); count != 0 { + t.Fatalf("runs count for failed allocation = %d, want 0", count) + } +} + +func TestInvalidMutationsLeaveTargetTablesUnchanged(t *testing.T) { + t.Run("allocate run", func(t *testing.T) { + store := openStore(t) + params := validAllocateRunParams() + params.PRKey = "" + + _, err := store.AllocateRun(context.Background(), params) + if !errors.Is(err, ErrInvalidInput) { + t.Fatalf("AllocateRun error = %v, want ErrInvalidInput", err) + } + assertTableCount(t, store.db, "prs", 0) + assertTableCount(t, store.db, "runs", 0) + }) + + t.Run("session", func(t *testing.T) { + store := openStore(t) + run := allocateRun(t, store, validAllocateRunParams()) + session := validSession(run.RunID) + session.SessionRowID = "" + + err := store.InsertSession(context.Background(), session) + if !errors.Is(err, ErrInvalidInput) { + t.Fatalf("InsertSession error = %v, want ErrInvalidInput", err) + } + assertTableCount(t, store.db, "sessions", 0) + }) + + t.Run("finding", func(t *testing.T) { + store := openStore(t) + run := allocateRun(t, store, validAllocateRunParams()) + session := validSession(run.RunID) + insertSession(t, store, session) + finding := validFinding(run.RunID, session.SessionRowID) + finding.FindingID = "" + + err := store.InsertFinding(context.Background(), finding) + if !errors.Is(err, ErrInvalidInput) { + t.Fatalf("InsertFinding error = %v, want ErrInvalidInput", err) + } + assertTableCount(t, store.db, "findings", 0) + }) + + t.Run("planned action", func(t *testing.T) { + store := openStore(t) + run := allocateRun(t, store, validAllocateRunParams()) + action := validPlannedAction(run.RunID) + action.ActionID = "" + + err := store.InsertPlannedAction(context.Background(), action) + if !errors.Is(err, ErrInvalidInput) { + t.Fatalf("InsertPlannedAction error = %v, want ErrInvalidInput", err) + } + assertTableCount(t, store.db, "planned_actions", 0) + }) + + t.Run("named session", func(t *testing.T) { + store := openStore(t) + session := validNamedSession() + session.Name = "" + + err := store.UpsertNamedSession(context.Background(), session) + if !errors.Is(err, ErrInvalidInput) { + t.Fatalf("UpsertNamedSession error = %v, want ErrInvalidInput", err) + } + assertTableCount(t, store.db, "named_sessions", 0) + }) +} + +func TestTypedPersistenceRoundTrips(t *testing.T) { + store := openStore(t) + run := allocateRun(t, store, validAllocateRunParams()) + + session := validSession(run.RunID) + insertSession(t, store, session) + gotSession, err := store.GetSession(context.Background(), session.SessionRowID) + if err != nil { + t.Fatalf("GetSession: %v", err) + } + if !reflect.DeepEqual(gotSession, session) { + t.Fatalf("GetSession = %#v, want %#v", gotSession, session) + } + + finding := validFinding(run.RunID, session.SessionRowID) + insertFinding(t, store, finding) + findings, err := store.ListFindings(context.Background(), run.RunID) + if err != nil { + t.Fatalf("ListFindings: %v", err) + } + if !reflect.DeepEqual(findings, []Finding{finding}) { + t.Fatalf("ListFindings = %#v, want %#v", findings, []Finding{finding}) + } + + action := validPlannedAction(run.RunID) + insertPlannedAction(t, store, action) + actions, err := store.ListPlannedActions(context.Background(), run.RunID) + if err != nil { + t.Fatalf("ListPlannedActions: %v", err) + } + if !reflect.DeepEqual(actions, []PlannedAction{action}) { + t.Fatalf("ListPlannedActions = %#v, want %#v", actions, []PlannedAction{action}) + } + + named := validNamedSession() + if err := store.UpsertNamedSession(context.Background(), named); err != nil { + t.Fatalf("UpsertNamedSession: %v", err) + } + gotNamed, err := store.GetNamedSession(context.Background(), named.Name) + if err != nil { + t.Fatalf("GetNamedSession: %v", err) + } + if !reflect.DeepEqual(gotNamed, named) { + t.Fatalf("GetNamedSession = %#v, want %#v", gotNamed, named) + } +} + +func TestInvalidInputsReturnErrInvalidInputBeforeMutation(t *testing.T) { + tests := []struct { + name string + run func(*testing.T, *Store) error + }{ + {name: "allocate empty pr key", run: func(_ *testing.T, s *Store) error { + params := validAllocateRunParams() + params.PRKey = "" + _, err := s.AllocateRun(context.Background(), params) + return err + }}, + {name: "allocate empty pr url", run: func(_ *testing.T, s *Store) error { + params := validAllocateRunParams() + params.PRURL = "" + _, err := s.AllocateRun(context.Background(), params) + return err + }}, + {name: "allocate empty sha", run: func(_ *testing.T, s *Store) error { + params := validAllocateRunParams() + params.SHA = "" + _, err := s.AllocateRun(context.Background(), params) + return err + }}, + {name: "allocate empty base sha", run: func(_ *testing.T, s *Store) error { + params := validAllocateRunParams() + params.BaseSHA = "" + _, err := s.AllocateRun(context.Background(), params) + return err + }}, + {name: "allocate empty profile", run: func(_ *testing.T, s *Store) error { + params := validAllocateRunParams() + params.Profile = "" + _, err := s.AllocateRun(context.Background(), params) + return err + }}, + {name: "allocate empty posting identity", run: func(_ *testing.T, s *Store) error { + params := validAllocateRunParams() + params.PostingIdentity = "" + _, err := s.AllocateRun(context.Background(), params) + return err + }}, + {name: "allocate zero started_at", run: func(_ *testing.T, s *Store) error { + params := validAllocateRunParams() + params.StartedAt = time.Time{} + _, err := s.AllocateRun(context.Background(), params) + return err + }}, + {name: "allocate empty artifact path", run: func(_ *testing.T, s *Store) error { + params := validAllocateRunParams() + params.ArtifactPath = "" + _, err := s.AllocateRun(context.Background(), params) + return err + }}, + {name: "allocate bad post mode", run: func(_ *testing.T, s *Store) error { + params := validAllocateRunParams() + params.PostMode = PostMode("preview") + _, err := s.AllocateRun(context.Background(), params) + return err + }}, + {name: "session missing id", run: func(t *testing.T, s *Store) error { + run := allocateRun(t, s, validAllocateRunParams()) + session := validSession(run.RunID) + session.SessionRowID = "" + return s.InsertSession(context.Background(), session) + }}, + {name: "session missing run id", run: func(t *testing.T, s *Store) error { + run := allocateRun(t, s, validAllocateRunParams()) + session := validSession(run.RunID) + session.RunID = "" + return s.InsertSession(context.Background(), session) + }}, + {name: "session missing provider session id", run: func(t *testing.T, s *Store) error { + run := allocateRun(t, s, validAllocateRunParams()) + session := validSession(run.RunID) + session.ProviderSessionID = "" + return s.InsertSession(context.Background(), session) + }}, + {name: "session bad role", run: func(t *testing.T, s *Store) error { + run := allocateRun(t, s, validAllocateRunParams()) + session := validSession(run.RunID) + session.Role = SessionRole("manager") + return s.InsertSession(context.Background(), session) + }}, + {name: "session missing adapter", run: func(t *testing.T, s *Store) error { + run := allocateRun(t, s, validAllocateRunParams()) + session := validSession(run.RunID) + session.Adapter = "" + return s.InsertSession(context.Background(), session) + }}, + {name: "session missing model", run: func(t *testing.T, s *Store) error { + run := allocateRun(t, s, validAllocateRunParams()) + session := validSession(run.RunID) + session.Model = "" + return s.InsertSession(context.Background(), session) + }}, + {name: "session zero started_at", run: func(t *testing.T, s *Store) error { + run := allocateRun(t, s, validAllocateRunParams()) + session := validSession(run.RunID) + session.StartedAt = time.Time{} + return s.InsertSession(context.Background(), session) + }}, + {name: "finding missing id", run: func(t *testing.T, s *Store) error { + run := allocateRun(t, s, validAllocateRunParams()) + session := validSession(run.RunID) + insertSession(t, s, session) + finding := validFinding(run.RunID, session.SessionRowID) + finding.FindingID = "" + return s.InsertFinding(context.Background(), finding) + }}, + {name: "finding missing run id", run: func(t *testing.T, s *Store) error { + run := allocateRun(t, s, validAllocateRunParams()) + session := validSession(run.RunID) + insertSession(t, s, session) + finding := validFinding(run.RunID, session.SessionRowID) + finding.RunID = "" + return s.InsertFinding(context.Background(), finding) + }}, + {name: "finding missing session row id", run: func(t *testing.T, s *Store) error { + run := allocateRun(t, s, validAllocateRunParams()) + session := validSession(run.RunID) + insertSession(t, s, session) + finding := validFinding(run.RunID, session.SessionRowID) + finding.SessionRowID = "" + return s.InsertFinding(context.Background(), finding) + }}, + {name: "finding bad severity", run: func(t *testing.T, s *Store) error { + run := allocateRun(t, s, validAllocateRunParams()) + session := validSession(run.RunID) + insertSession(t, s, session) + finding := validFinding(run.RunID, session.SessionRowID) + finding.Severity = review.Severity("medium") + return s.InsertFinding(context.Background(), finding) + }}, + {name: "finding missing file path", run: func(t *testing.T, s *Store) error { + run := allocateRun(t, s, validAllocateRunParams()) + session := validSession(run.RunID) + insertSession(t, s, session) + finding := validFinding(run.RunID, session.SessionRowID) + finding.FilePath = "" + return s.InsertFinding(context.Background(), finding) + }}, + {name: "finding bad side", run: func(t *testing.T, s *Store) error { + run := allocateRun(t, s, validAllocateRunParams()) + session := validSession(run.RunID) + insertSession(t, s, session) + finding := validFinding(run.RunID, session.SessionRowID) + side := review.DiffSide("BOTH") + finding.Side = &side + return s.InsertFinding(context.Background(), finding) + }}, + {name: "finding bad anchoring", run: func(t *testing.T, s *Store) error { + run := allocateRun(t, s, validAllocateRunParams()) + session := validSession(run.RunID) + insertSession(t, s, session) + finding := validFinding(run.RunID, session.SessionRowID) + finding.Anchoring = review.Anchoring("file") + return s.InsertFinding(context.Background(), finding) + }}, + {name: "finding missing body", run: func(t *testing.T, s *Store) error { + run := allocateRun(t, s, validAllocateRunParams()) + session := validSession(run.RunID) + insertSession(t, s, session) + finding := validFinding(run.RunID, session.SessionRowID) + finding.Body = "" + return s.InsertFinding(context.Background(), finding) + }}, + {name: "planned action missing action id", run: func(t *testing.T, s *Store) error { + run := allocateRun(t, s, validAllocateRunParams()) + action := validPlannedAction(run.RunID) + action.ActionID = "" + return s.InsertPlannedAction(context.Background(), action) + }}, + {name: "planned action missing run id", run: func(t *testing.T, s *Store) error { + run := allocateRun(t, s, validAllocateRunParams()) + action := validPlannedAction(run.RunID) + action.RunID = "" + return s.InsertPlannedAction(context.Background(), action) + }}, + {name: "planned action bad kind", run: func(t *testing.T, s *Store) error { + run := allocateRun(t, s, validAllocateRunParams()) + action := validPlannedAction(run.RunID) + action.Kind = PlannedActionKind("comment") + return s.InsertPlannedAction(context.Background(), action) + }}, + {name: "planned action zero planned_at", run: func(t *testing.T, s *Store) error { + run := allocateRun(t, s, validAllocateRunParams()) + action := validPlannedAction(run.RunID) + action.PlannedAt = time.Time{} + return s.InsertPlannedAction(context.Background(), action) + }}, + {name: "planned action missing payload", run: func(t *testing.T, s *Store) error { + run := allocateRun(t, s, validAllocateRunParams()) + action := validPlannedAction(run.RunID) + action.PayloadJSON = "" + return s.InsertPlannedAction(context.Background(), action) + }}, + {name: "planned action bad status", run: func(t *testing.T, s *Store) error { + run := allocateRun(t, s, validAllocateRunParams()) + action := validPlannedAction(run.RunID) + action.Status = PlannedActionStatus("done") + return s.InsertPlannedAction(context.Background(), action) + }}, + {name: "planned action negative attempts", run: func(t *testing.T, s *Store) error { + run := allocateRun(t, s, validAllocateRunParams()) + action := validPlannedAction(run.RunID) + action.Attempts = -1 + return s.InsertPlannedAction(context.Background(), action) + }}, + {name: "named session missing name", run: func(_ *testing.T, s *Store) error { + named := validNamedSession() + named.Name = "" + return s.UpsertNamedSession(context.Background(), named) + }}, + {name: "named session missing profile", run: func(_ *testing.T, s *Store) error { + named := validNamedSession() + named.Profile = "" + return s.UpsertNamedSession(context.Background(), named) + }}, + {name: "named session missing provider", run: func(_ *testing.T, s *Store) error { + named := validNamedSession() + named.Provider = "" + return s.UpsertNamedSession(context.Background(), named) + }}, + {name: "named session missing adapter", run: func(_ *testing.T, s *Store) error { + named := validNamedSession() + named.Adapter = "" + return s.UpsertNamedSession(context.Background(), named) + }}, + {name: "named session missing model", run: func(_ *testing.T, s *Store) error { + named := validNamedSession() + named.Model = "" + return s.UpsertNamedSession(context.Background(), named) + }}, + {name: "named session missing host", run: func(_ *testing.T, s *Store) error { + named := validNamedSession() + named.Host = "" + return s.UpsertNamedSession(context.Background(), named) + }}, + {name: "named session missing provider session id", run: func(_ *testing.T, s *Store) error { + named := validNamedSession() + named.ProviderSessionID = "" + return s.UpsertNamedSession(context.Background(), named) + }}, + {name: "named session zero created_at", run: func(_ *testing.T, s *Store) error { + named := validNamedSession() + named.CreatedAt = time.Time{} + return s.UpsertNamedSession(context.Background(), named) + }}, + {name: "named session zero last_used_at", run: func(_ *testing.T, s *Store) error { + named := validNamedSession() + named.LastUsedAt = time.Time{} + return s.UpsertNamedSession(context.Background(), named) + }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := openStore(t) + err := tt.run(t, store) + if !errors.Is(err, ErrInvalidInput) { + t.Fatalf("error = %v, want ErrInvalidInput", err) + } + }) + } +} + +func TestStorageValueParsers(t *testing.T) { + if got, err := ParseOutcome(" request_changes "); err != nil || got != OutcomeRequestChanges { + t.Fatalf("ParseOutcome = %q, %v; want request_changes, nil", got, err) + } + if _, err := ParseOutcome("changes_requested"); !errors.Is(err, ErrInvalidInput) { + t.Fatalf("ParseOutcome invalid error = %v, want ErrInvalidInput", err) + } + if got, err := ParsePlannedActionKind("ROLLUP_COMMENT"); err != nil || got != PlannedActionRollupComment { + t.Fatalf("ParsePlannedActionKind = %q, %v; want rollup_comment, nil", got, err) + } +} + +func TestWriteReturnsWriterResultAfterDispatchDespiteContextCancellation(t *testing.T) { + store := openStore(t) + ctx, cancel := context.WithCancel(context.Background()) + entered := make(chan struct{}) + release := make(chan struct{}) + errCh := make(chan error, 1) + + go func() { + errCh <- store.write(ctx, func(context.Context, *sql.DB) error { + close(entered) + <-release + return nil + }) + }() + + <-entered + cancel() + close(release) + + select { + case err := <-errCh: + if err != nil { + t.Fatalf("write error = %v, want nil", err) + } + case <-time.After(time.Second): + t.Fatal("write did not return") + } +} + +func TestCloseStopsWriterAndRejectsMutation(t *testing.T) { + store := openStore(t) + if err := store.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + if err := store.Close(); err != nil { + t.Fatalf("second Close: %v", err) + } + + _, err := store.AllocateRun(context.Background(), validAllocateRunParams()) + if !errors.Is(err, ErrClosed) { + t.Fatalf("AllocateRun after Close error = %v, want ErrClosed", err) + } + + readChecks := []struct { + name string + run func() error + }{ + {name: "GetPR", run: func() error { + _, err := store.GetPR(context.Background(), "github_open-cli_codereview-cli_36") + return err + }}, + {name: "GetRun", run: func() error { + _, err := store.GetRun(context.Background(), "run-1") + return err + }}, + {name: "GetSession", run: func() error { + _, err := store.GetSession(context.Background(), "session-row-1") + return err + }}, + {name: "ListFindings", run: func() error { + _, err := store.ListFindings(context.Background(), "run-1") + return err + }}, + {name: "ListPlannedActions", run: func() error { + _, err := store.ListPlannedActions(context.Background(), "run-1") + return err + }}, + {name: "GetNamedSession", run: func() error { + _, err := store.GetNamedSession(context.Background(), "daily") + return err + }}, + } + for _, check := range readChecks { + t.Run(check.name, func(t *testing.T) { + if err := check.run(); !errors.Is(err, ErrClosed) { + t.Fatalf("%s after Close error = %v, want ErrClosed", check.name, err) + } + }) + } +} + +func openStore(t *testing.T) *Store { + t.Helper() + + return openStoreAt(t, filepath.Join(t.TempDir(), "ledger.db")) +} + +func openStoreAt(t *testing.T, path string) *Store { + t.Helper() + + store, err := Open(context.Background(), path) + if err != nil { + t.Fatalf("Open: %v", err) + } + t.Cleanup(func() { + if err := store.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + }) + return store +} + +func validAllocateRunParams() AllocateRunParams { + return AllocateRunParams{ + PRKey: "github_open-cli_codereview-cli_36", + PRURL: "https://example.test/pr/36", + RunID: "run-1", + SHA: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + BaseSHA: "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", + Profile: "default", + PostingIdentity: "reviewer@example.com", + PostMode: PostModeLive, + StartedAt: time.Date(2026, 5, 30, 12, 0, 0, 0, time.UTC), + ArtifactPath: "/tmp/run-1", + } +} + +func allocateRun(t *testing.T, store *Store, params AllocateRunParams) Run { + t.Helper() + + run, err := store.AllocateRun(context.Background(), params) + if err != nil { + t.Fatalf("AllocateRun: %v", err) + } + return run +} + +func validSession(runID string) Session { + completed := time.Date(2026, 5, 30, 12, 3, 0, 0, time.UTC) + duration := int64(1200) + tokensIn := int64(100) + tokensOut := int64(20) + cacheRead := int64(5) + cacheCreate := int64(7) + cost := 0.42 + + return Session{ + SessionRowID: "session-row-1", + RunID: runID, + ProviderSessionID: "provider-session-1", + Role: SessionRoleReviewer, + AgentID: strPtr("harness:architecture"), + Adapter: "codex_cli", + Model: "gpt-5.5", + Effort: strPtr("xhigh"), + StartedAt: time.Date(2026, 5, 30, 12, 1, 0, 0, time.UTC), + CompletedAt: &completed, + DurationMS: &duration, + TokensIn: &tokensIn, + TokensOut: &tokensOut, + CacheRead: &cacheRead, + CacheCreate: &cacheCreate, + CostUSD: &cost, + } +} + +func insertSession(t *testing.T, store *Store, session Session) { + t.Helper() + if err := store.InsertSession(context.Background(), session); err != nil { + t.Fatalf("InsertSession: %v", err) + } +} + +func validFinding(runID, sessionRowID string) Finding { + side := review.DiffSideRight + line := int64(42) + diffPosition := int64(17) + + return Finding{ + FindingID: review.FindingID("finding-1"), + RunID: runID, + SessionRowID: sessionRowID, + Severity: review.SeverityMajor, + FilePath: "internal/ledger/ledger.go", + Side: &side, + Line: &line, + DiffPosition: &diffPosition, + Anchoring: review.AnchoringInline, + Body: "finding body", + } +} + +func insertFinding(t *testing.T, store *Store, finding Finding) { + t.Helper() + if err := store.InsertFinding(context.Background(), finding); err != nil { + t.Fatalf("InsertFinding: %v", err) + } +} + +func validPlannedAction(runID string) PlannedAction { + attemptedAt := time.Date(2026, 5, 30, 12, 4, 0, 0, time.UTC) + return PlannedAction{ + ActionID: "action-1", + RunID: runID, + Kind: PlannedActionInlineComment, + FindingID: strPtr("finding-1"), + ThreadID: nil, + PlannedAt: time.Date(2026, 5, 30, 12, 2, 0, 0, time.UTC), + PayloadJSON: `{"body":"hello"}`, + Status: PlannedActionPending, + Required: true, + Attempts: 1, + AttemptedAt: &attemptedAt, + PostedAt: nil, + UpstreamID: nil, + Error: strPtr("retry later"), + } +} + +func insertPlannedAction(t *testing.T, store *Store, action PlannedAction) { + t.Helper() + if err := store.InsertPlannedAction(context.Background(), action); err != nil { + t.Fatalf("InsertPlannedAction: %v", err) + } +} + +func duplicateRunIDConstraintError(t *testing.T, store *Store, run Run) error { + t.Helper() + + _, err := store.db.ExecContext( + context.Background(), + `INSERT INTO runs ( + run_id, pr_key, sha, base_sha, attempt, profile, posting_identity, post_mode, + started_at, artifact_path + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + run.RunID, run.PRKey, run.SHA, run.BaseSHA, run.Attempt+1, run.Profile, run.PostingIdentity, + run.PostMode.String(), encodeTime(run.StartedAt.Add(time.Minute)), "/tmp/duplicate-run-id", + ) + if err == nil { + t.Fatal("duplicate run id insert error = nil, want constraint failure") + } + if !isSQLiteConstraintError(err) { + t.Fatalf("duplicate run id error = %v, want SQLite constraint", err) + } + return err +} + +func duplicateResumeAttemptConstraintError(t *testing.T, store *Store, run Run) error { + t.Helper() + + _, err := store.db.ExecContext( + context.Background(), + `INSERT INTO runs ( + run_id, pr_key, sha, base_sha, attempt, profile, posting_identity, post_mode, + started_at, artifact_path + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + "manual-run", run.PRKey, run.SHA, run.BaseSHA, run.Attempt, run.Profile, run.PostingIdentity, + run.PostMode.String(), encodeTime(run.StartedAt.Add(time.Minute)), "/tmp/duplicate-resume", + ) + if err == nil { + t.Fatal("duplicate resume attempt insert error = nil, want constraint failure") + } + if !isSQLiteConstraintError(err) { + t.Fatalf("duplicate resume attempt error = %v, want SQLite constraint", err) + } + return err +} + +func missingSessionForeignKeyConstraintError(t *testing.T, store *Store) error { + t.Helper() + + session := validSession("missing-run") + session.SessionRowID = "missing-run-session" + _, err := store.db.ExecContext( + context.Background(), + `INSERT INTO sessions ( + session_row_id, run_id, provider_session_id, role, agent_id, adapter, model, effort, + started_at, completed_at, duration_ms, tokens_in, tokens_out, cache_read, cache_create, cost_usd + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + session.SessionRowID, session.RunID, session.ProviderSessionID, session.Role.String(), session.AgentID, + session.Adapter, session.Model, session.Effort, encodeTime(session.StartedAt), encodeOptionalTime(session.CompletedAt), + session.DurationMS, session.TokensIn, session.TokensOut, session.CacheRead, session.CacheCreate, session.CostUSD, + ) + if err == nil { + t.Fatal("missing run session insert error = nil, want constraint failure") + } + if !isSQLiteConstraintError(err) { + t.Fatalf("missing run session error = %v, want SQLite constraint", err) + } + return err +} + +func validNamedSession() NamedSession { + return NamedSession{ + Name: "daily", + Profile: "default", + Provider: "openai", + Adapter: "codex_cli", + Model: "gpt-5.5", + Host: "github.com", + ProviderSessionID: "provider-session-1", + CreatedAt: time.Date(2026, 5, 30, 11, 0, 0, 0, time.UTC), + LastUsedAt: time.Date(2026, 5, 30, 12, 0, 0, 0, time.UTC), + } +} + +func queryInt(t *testing.T, db *sql.DB, query string, args ...any) int { + t.Helper() + var got int + if err := db.QueryRowContext(context.Background(), query, args...).Scan(&got); err != nil { + t.Fatalf("query int %q: %v", query, err) + } + return got +} + +func queryString(t *testing.T, db *sql.DB, query string, args ...any) string { + t.Helper() + var got string + if err := db.QueryRowContext(context.Background(), query, args...).Scan(&got); err != nil { + t.Fatalf("query string %q: %v", query, err) + } + return got +} + +func execSQL(t *testing.T, db *sql.DB, statement string, args ...any) { + t.Helper() + + if _, err := db.ExecContext(context.Background(), statement, args...); err != nil { + t.Fatalf("exec %q: %v", statement, err) + } +} + +func assertTableCount(t *testing.T, db *sql.DB, table string, want int) { + t.Helper() + + var query string + switch table { + case "prs": + query = "SELECT COUNT(*) FROM prs" + case "runs": + query = "SELECT COUNT(*) FROM runs" + case "sessions": + query = "SELECT COUNT(*) FROM sessions" + case "findings": + query = "SELECT COUNT(*) FROM findings" + case "planned_actions": + query = "SELECT COUNT(*) FROM planned_actions" + case "named_sessions": + query = "SELECT COUNT(*) FROM named_sessions" + default: + t.Fatalf("unsupported table %q", table) + } + got := queryInt(t, db, query) + if got != want { + t.Fatalf("%s count = %d, want %d", table, got, want) + } +} + +func tableExists(t *testing.T, db *sql.DB, name string) bool { + t.Helper() + return queryInt(t, db, "SELECT COUNT(*) FROM sqlite_master WHERE type = 'table' AND name = ?", name) == 1 +} + +func indexExists(t *testing.T, db *sql.DB, name string) bool { + t.Helper() + return queryInt(t, db, "SELECT COUNT(*) FROM sqlite_master WHERE type = 'index' AND name = ?", name) == 1 +} + +func indexColumns(t *testing.T, db *sql.DB, name string) []string { + t.Helper() + + var query string + switch name { + case "runs_resume": + query = "PRAGMA index_info(runs_resume)" + default: + t.Fatalf("unsupported index %q", name) + } + rows, err := db.QueryContext(context.Background(), query) + if err != nil { + t.Fatalf("PRAGMA index_info(%s): %v", name, err) + } + defer rows.Close() + + var columns []string + for rows.Next() { + var seqno, cid int + var column string + if err := rows.Scan(&seqno, &cid, &column); err != nil { + t.Fatalf("scan index_info(%s): %v", name, err) + } + columns = append(columns, column) + } + if err := rows.Err(); err != nil { + t.Fatalf("index_info(%s) rows: %v", name, err) + } + return columns +} + +func strPtr(value string) *string { + return &value +}