diff --git a/CLAUDE.md b/CLAUDE.md index 01a79a2..32755c2 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -64,7 +64,7 @@ make ci # Runs: fmt → vet → lint → test-unit → test-race → secu - **Timezone Support**: Offline utilities + calendar integration ✅ - **Email Signing**: GPG/PGP email signing (RFC 3156 PGP/MIME) ✅ - **AI Chat**: Web-based chat interface using locally installed AI agents ✅ -- **Credential Storage**: System keyring (see below) +- **Credential Storage**: System keyring for secrets; file-backed grant cache for non-secret grant metadata (see below) - **Web UI**: Air - browser-based interface (localhost:7365) **Details:** See `docs/ARCHITECTURE.md` @@ -84,13 +84,16 @@ make ci # Runs: fmt → vet → lint → test-unit → test-race → secu --- -## Credential Storage (Keyring) +## Credential Storage -Credentials stored in system keyring (service: `"nylas"`) via `nylas auth config`. +Secrets are stored in the system keyring (service: `"nylas"`) via `nylas auth config`. +Grant metadata and the local default grant are stored outside the keyring in the grant cache. 
-**Key files:** `internal/ports/secrets.go` (constants), `internal/adapters/keyring/` (implementation), `internal/app/auth/config.go` (setup) +**Key files:** `internal/ports/secrets.go` (constants), `internal/adapters/keyring/` (secret storage), `internal/adapters/grantcache/` (grant metadata cache), `internal/app/auth/config.go` (setup) -**Keys:** `client_id`, `api_key`, `client_secret`, `org_id`, `grants`, `default_grant`, `grant_token_` +**Secret keys:** `client_id`, `api_key`, `client_secret`, `org_id` + +**Grant cache:** non-secret grant ID, email, provider, and default grant at `filepath.Join(os.UserCacheDir(), "nylas", "grants.json")` **Disable keyring:** `NYLAS_DISABLE_KEYRING=true` (falls back to encrypted file at `~/.config/nylas/`) @@ -110,9 +113,11 @@ Credentials stored in system keyring (service: `"nylas"`) via `nylas auth config - `internal/ports/output.go` - OutputWriter interface for pluggable formatting - `internal/adapters/output/` - Table, JSON, YAML, Quiet output adapters - `internal/httputil/` - HTTP response helpers (WriteJSON, LimitedBody, DecodeJSON) +- `internal/adapters/grantcache/` - File-backed local grant metadata/default cache - `internal/adapters/gpg/` - GPG/PGP email signing service (2026) - `internal/adapters/mime/` - RFC 3156 PGP/MIME message builder (2026) - `internal/chat/` - AI chat interface with local agent support (2026) +- `internal/webguard/` - Shared localhost web UI request guards - `internal/cli/setup/` - First-time setup wizard (`nylas init`) **Full inventory:** `docs/ARCHITECTURE.md` diff --git a/README.md b/README.md index 175fdd3..d9fcf76 100644 --- a/README.md +++ b/README.md @@ -149,7 +149,7 @@ Step-by-step tutorials on [cli.nylas.com](https://cli.nylas.com/guides): ## Configuration -Credentials are stored in your system keyring (macOS Keychain, Linux Secret Service, Windows Credential Manager). Nothing is written to plain-text files. 
+Credentials are stored in your system keyring (macOS Keychain, Linux Secret Service, Windows Credential Manager). Non-secret grant metadata, such as account email/provider and the local default grant, is cached separately for fast local lookup. ```bash nylas auth status # Check what's configured diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index 5a9907b..4f9fe01 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -18,6 +18,7 @@ internal/ ai/ # AI providers (Claude, OpenAI, Groq, Ollama) analytics/ # Focus optimizer, meeting scorer keyring/ # Secret storage + grantcache/ # Non-secret local grant metadata/default cache config/ # Configuration validation mcp/ # MCP proxy server slack/ # Slack API client @@ -26,6 +27,7 @@ internal/ browser/ # Browser automation tunnel/ # Cloudflare tunnel webhookserver/ # Webhook server + webguard/ # Shared localhost web UI request guards cli/ # CLI commands common/ # Shared helpers (client, context, errors, flags, format, html, timeutil) admin/ # API key management @@ -182,14 +184,15 @@ url := qb.BuildURL(baseURL) - `utilities.go` - Utilities interface - `webhook_server.go` - Webhook server interface -3. **Adapters** (`internal/adapters/`) - 12 adapter directories +3. 
**Adapters** (`internal/adapters/`) - 13 adapter directories | Adapter | Files | Purpose | |---------|-------|---------| | `nylas/` | 94 | Nylas API client (messages, calendars, contacts, events) | | `ai/` | 24 | AI clients (Claude, OpenAI, Groq, Ollama), email analyzer | | `analytics/` | 14 | Focus optimizer, conflict resolver, meeting scorer | - | `keyring/` | 6 | Credential storage (system keyring, file-based) | + | `keyring/` | 8 | Secret storage (system keyring, encrypted file fallback) | + | `grantcache/` | 2 | Non-secret local grant metadata/default cache | | `mcp/` | 8 | MCP proxy server for AI assistants | | `slack/` | 21 | Slack API client (channels, messages, users) | | `config/` | 5 | Configuration validation | diff --git a/docs/COMMANDS.md b/docs/COMMANDS.md index a0c666e..516bfd5 100644 --- a/docs/COMMANDS.md +++ b/docs/COMMANDS.md @@ -439,8 +439,9 @@ nylas webhook pubsub delete --yes ```bash nylas webhook test send # Send test payload nylas webhook test payload [trigger-type] # Generate test payload -nylas webhook server # Start local webhook server -nylas webhook server --port 8080 --tunnel cloudflared # With public tunnel +nylas webhook server # Interactive preflight (offers cloudflared tunnel) +nylas webhook server --no-tunnel # Loopback-only (skip preflight) +nylas webhook server --port 8080 --tunnel cloudflared --secret xxx # Public tunnel + HMAC verify ``` **Details:** `docs/commands/webhooks.md` diff --git a/docs/commands/webhooks.md b/docs/commands/webhooks.md index 1a91f56..bccf6f5 100644 --- a/docs/commands/webhooks.md +++ b/docs/commands/webhooks.md @@ -22,20 +22,35 @@ nylas webhook pubsub delete --yes Start a local webhook server for development and testing: ```bash -# Start server (local only) +# Interactive: detects cloudflared and prompts to enable a public tunnel. +# (Nylas can't deliver webhooks to localhost, so a tunnel is needed to +# receive real events.) 
nylas webhook server -# Start with public tunnel (cloudflared required) -nylas webhook server --tunnel cloudflared +# Skip the prompt and run loopback-only (useful for local curl tests +# or non-interactive environments) +nylas webhook server --no-tunnel -# Custom port -nylas webhook server --port 8080 --tunnel cloudflared +# Start with public tunnel (cloudflared required) + signature verification +nylas webhook server --tunnel cloudflared --secret your-webhook-secret + +# Custom port with a tunnel +nylas webhook server --port 8080 --tunnel cloudflared --secret your-webhook-secret ``` -**Install cloudflared:** +When `--tunnel` is set, `--secret` is required (or pass `--allow-unsigned` +to opt out explicitly). The interactive preflight will prompt for a +secret inline when you accept the tunnel; leaving it empty opts into +unsigned mode. + +**Cloudflared install:** + +On macOS, the preflight will offer to run `brew install cloudflared` for +you when cloudflared isn't on `PATH`. On other platforms, see: + ```bash -brew install cloudflared # macOS -# Or download from: github.com/cloudflare/cloudflared +brew install cloudflared # macOS (manual) +# Linux/Windows: https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/installation/ ``` ### TUI Webhook Server diff --git a/docs/security/overview.md b/docs/security/overview.md index 2da502f..14f4939 100644 --- a/docs/security/overview.md +++ b/docs/security/overview.md @@ -14,7 +14,7 @@ nylas auth config # Configure API credentials (stored securely) ### Keyring Storage -Credentials are stored in the system keyring under service name `"nylas"`: +Secrets are stored in the system keyring under service name `"nylas"`: | Key | Constant | Description | |-----|----------|-------------| @@ -22,9 +22,10 @@ Credentials are stored in the system keyring under service name `"nylas"`: | `api_key` | `ports.KeyAPIKey` | Nylas API key (Bearer auth) | | `client_secret` | `ports.KeyClientSecret` | Provider 
OAuth secret (Google/Microsoft) | | `org_id` | `ports.KeyOrgID` | Nylas Organization ID | -| `grants` | `grantsKey` | JSON array of grant info (ID, email, provider) | -| `default_grant` | `defaultGrantKey` | Default grant ID for CLI operations | -| `grant_token_` | `ports.GrantTokenKey()` | Per-grant access tokens | + +Grant IDs, emails, providers, and the local default grant are non-secret metadata. +They are stored in the grant cache at `filepath.Join(os.UserCacheDir(), "nylas", "grants.json")`. +Keyring remains secrets-only. ### Implementation Files @@ -32,7 +33,7 @@ Credentials are stored in the system keyring under service name `"nylas"`: |------|---------| | `internal/ports/secrets.go` | Key constants (`KeyClientID`, `KeyAPIKey`, etc.) | | `internal/adapters/keyring/keyring.go` | System keyring implementation | -| `internal/adapters/keyring/grants.go` | Grant storage (`grants`, `default_grant`) | +| `internal/adapters/grantcache/cache.go` | File-backed non-secret grant metadata/default cache | | `internal/app/auth/config.go` | `SetupConfig()` saves credentials to keyring | ### Platform Backends @@ -53,6 +54,7 @@ NYLAS_DISABLE_KEYRING=true # Force encrypted file store (useful for testing/CI Non-sensitive settings stored in `~/.config/nylas/config.yaml`: - Region (us/eu) - Callback port +- Local default grant mirror --- diff --git a/internal/adapters/ai/base_client.go b/internal/adapters/ai/base_client.go index 8be6d69..0a2da8f 100644 --- a/internal/adapters/ai/base_client.go +++ b/internal/adapters/ai/base_client.go @@ -173,3 +173,92 @@ func FallbackStreamChat(ctx context.Context, req *domain.ChatRequest, chatFunc f } return callback(resp.Content) } + +// openAICompatibleResponse is the shared shape of /v1/chat/completions +// responses across providers that speak the OpenAI API surface (OpenAI, +// Groq, Together, Anyscale, etc.). Kept private to this package. 
+type openAICompatibleResponse struct { + Choices []struct { + Message struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []struct { + ID string `json:"id"` + Type string `json:"type"` + Function struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + } `json:"function"` + } `json:"tool_calls,omitempty"` + } `json:"message"` + } `json:"choices"` + Model string `json:"model"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` +} + +// OpenAICompatibleChat performs a chat request against any provider that +// implements the OpenAI /v1/chat/completions surface. provider is used to +// label the response and shape error messages. +// +// Callers (OpenAIClient, GroqClient, …) should validate IsConfigured before +// calling this method. +func (b *BaseClient) OpenAICompatibleChat(ctx context.Context, provider string, req *domain.ChatRequest, tools []domain.Tool) (*domain.ChatResponse, error) { + body := map[string]any{ + "model": b.GetModel(req.Model), + "messages": ConvertMessagesToMaps(req.Messages), + } + if req.MaxTokens > 0 { + body["max_tokens"] = req.MaxTokens + } + if req.Temperature > 0 { + body["temperature"] = req.Temperature + } + if len(tools) > 0 { + body["tools"] = ConvertToolsOpenAIFormat(tools) + body["tool_choice"] = "auto" + } + + headers := map[string]string{ + "Authorization": "Bearer " + b.apiKey, + } + + var raw openAICompatibleResponse + if err := b.DoJSONRequestAndDecode(ctx, "POST", "/chat/completions", body, headers, &raw); err != nil { + return nil, err + } + if len(raw.Choices) == 0 { + return nil, fmt.Errorf("no response from %s", provider) + } + + resp := &domain.ChatResponse{ + Content: raw.Choices[0].Message.Content, + Model: raw.Model, + Provider: provider, + Usage: domain.TokenUsage{ + PromptTokens: raw.Usage.PromptTokens, + CompletionTokens: raw.Usage.CompletionTokens, 
+ TotalTokens: raw.Usage.TotalTokens, + }, + } + for _, tc := range raw.Choices[0].Message.ToolCalls { + var args map[string]any + if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil { + // The model emitted a tool-call with malformed JSON arguments. + // Silently dropping it would leave the caller wondering why + // `len(ToolCalls)` is short — return the parse error so the + // scheduler can decide whether to retry or surface it. + return nil, fmt.Errorf("model tool-call %q has invalid JSON arguments: %w", + tc.Function.Name, err) + } + resp.ToolCalls = append(resp.ToolCalls, domain.ToolCall{ + ID: tc.ID, + Function: tc.Function.Name, + Arguments: args, + }) + } + return resp, nil +} diff --git a/internal/adapters/ai/groq_client.go b/internal/adapters/ai/groq_client.go index 4869656..088c419 100644 --- a/internal/adapters/ai/groq_client.go +++ b/internal/adapters/ai/groq_client.go @@ -2,7 +2,6 @@ package ai import ( "context" - "encoding/json" "fmt" "github.com/nylas/cli/internal/domain" @@ -48,92 +47,14 @@ func (c *GroqClient) Chat(ctx context.Context, req *domain.ChatRequest) (*domain return c.ChatWithTools(ctx, req, nil) } -// ChatWithTools sends a chat request with function calling. +// ChatWithTools sends a chat request with function calling. Groq exposes the +// OpenAI /v1/chat/completions surface, so this delegates to the shared +// pipeline. 
func (c *GroqClient) ChatWithTools(ctx context.Context, req *domain.ChatRequest, tools []domain.Tool) (*domain.ChatResponse, error) { if !c.IsConfigured() { return nil, fmt.Errorf("groq API key not configured") } - - // Prepare Groq request (OpenAI-compatible format) - groqReq := map[string]any{ - "model": c.GetModel(req.Model), - "messages": ConvertMessagesToMaps(req.Messages), - } - - if req.MaxTokens > 0 { - groqReq["max_tokens"] = req.MaxTokens - } - - if req.Temperature > 0 { - groqReq["temperature"] = req.Temperature - } - - // Tools support - if len(tools) > 0 { - groqReq["tools"] = ConvertToolsOpenAIFormat(tools) - groqReq["tool_choice"] = "auto" - } - - // Send request using base client - var groqResp struct { - Choices []struct { - Message struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCalls []struct { - ID string `json:"id"` - Type string `json:"type"` - Function struct { - Name string `json:"name"` - Arguments string `json:"arguments"` - } `json:"function"` - } `json:"tool_calls,omitempty"` - } `json:"message"` - } `json:"choices"` - Model string `json:"model"` - Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` - } `json:"usage"` - } - - headers := map[string]string{ - "Authorization": "Bearer " + c.apiKey, - } - - if err := c.DoJSONRequestAndDecode(ctx, "POST", "/chat/completions", groqReq, headers, &groqResp); err != nil { - return nil, err - } - - if len(groqResp.Choices) == 0 { - return nil, fmt.Errorf("no response from Groq") - } - - response := &domain.ChatResponse{ - Content: groqResp.Choices[0].Message.Content, - Model: groqResp.Model, - Provider: "groq", - Usage: domain.TokenUsage{ - PromptTokens: groqResp.Usage.PromptTokens, - CompletionTokens: groqResp.Usage.CompletionTokens, - TotalTokens: groqResp.Usage.TotalTokens, - }, - } - - // Convert tool calls if present - for _, tc := range 
groqResp.Choices[0].Message.ToolCalls { - var args map[string]any - if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err == nil { - response.ToolCalls = append(response.ToolCalls, domain.ToolCall{ - ID: tc.ID, - Function: tc.Function.Name, - Arguments: args, - }) - } - } - - return response, nil + return c.OpenAICompatibleChat(ctx, "groq", req, tools) } // StreamChat streams chat responses. diff --git a/internal/adapters/ai/openai_client.go b/internal/adapters/ai/openai_client.go index 6bec859..9918cd2 100644 --- a/internal/adapters/ai/openai_client.go +++ b/internal/adapters/ai/openai_client.go @@ -2,7 +2,6 @@ package ai import ( "context" - "encoding/json" "fmt" "github.com/nylas/cli/internal/domain" @@ -48,92 +47,13 @@ func (c *OpenAIClient) Chat(ctx context.Context, req *domain.ChatRequest) (*doma return c.ChatWithTools(ctx, req, nil) } -// ChatWithTools sends a chat request with function calling. +// ChatWithTools sends a chat request with function calling, delegating to +// the shared OpenAI-compatible pipeline in BaseClient. 
func (c *OpenAIClient) ChatWithTools(ctx context.Context, req *domain.ChatRequest, tools []domain.Tool) (*domain.ChatResponse, error) { if !c.IsConfigured() { return nil, fmt.Errorf("openai API key not configured") } - - // Prepare OpenAI request - openaiReq := map[string]any{ - "model": c.GetModel(req.Model), - "messages": ConvertMessagesToMaps(req.Messages), - } - - if req.MaxTokens > 0 { - openaiReq["max_tokens"] = req.MaxTokens - } - - if req.Temperature > 0 { - openaiReq["temperature"] = req.Temperature - } - - // Tools support - if len(tools) > 0 { - openaiReq["tools"] = ConvertToolsOpenAIFormat(tools) - openaiReq["tool_choice"] = "auto" - } - - // Send request using base client - var openaiResp struct { - Choices []struct { - Message struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCalls []struct { - ID string `json:"id"` - Type string `json:"type"` - Function struct { - Name string `json:"name"` - Arguments string `json:"arguments"` - } `json:"function"` - } `json:"tool_calls,omitempty"` - } `json:"message"` - } `json:"choices"` - Model string `json:"model"` - Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` - } `json:"usage"` - } - - headers := map[string]string{ - "Authorization": "Bearer " + c.apiKey, - } - - if err := c.DoJSONRequestAndDecode(ctx, "POST", "/chat/completions", openaiReq, headers, &openaiResp); err != nil { - return nil, err - } - - if len(openaiResp.Choices) == 0 { - return nil, fmt.Errorf("no response from OpenAI") - } - - response := &domain.ChatResponse{ - Content: openaiResp.Choices[0].Message.Content, - Model: openaiResp.Model, - Provider: "openai", - Usage: domain.TokenUsage{ - PromptTokens: openaiResp.Usage.PromptTokens, - CompletionTokens: openaiResp.Usage.CompletionTokens, - TotalTokens: openaiResp.Usage.TotalTokens, - }, - } - - // Convert tool calls if present - for _, tc := range 
openaiResp.Choices[0].Message.ToolCalls { - var args map[string]any - if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err == nil { - response.ToolCalls = append(response.ToolCalls, domain.ToolCall{ - ID: tc.ID, - Function: tc.Function.Name, - Arguments: args, - }) - } - } - - return response, nil + return c.OpenAICompatibleChat(ctx, "openai", req, tools) } // StreamChat streams chat responses. diff --git a/internal/adapters/ai/pattern_learner.go b/internal/adapters/ai/pattern_learner.go index 1526988..23abf5a 100644 --- a/internal/adapters/ai/pattern_learner.go +++ b/internal/adapters/ai/pattern_learner.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "os" "strings" "time" @@ -151,6 +152,7 @@ func (p *PatternLearner) fetchHistoricalEvents(ctx context.Context, req *LearnPa } allEvents := []domain.Event{} + skipped := []string{} // Fetch events from each calendar for _, calendar := range calendars { @@ -161,13 +163,26 @@ func (p *PatternLearner) fetchHistoricalEvents(ctx context.Context, req *LearnPa }) if err != nil { - // Skip calendar if error (might be read-only, etc.) + // Some calendars are read-only or temporarily unavailable. Record + // the skip with the underlying error so the caller (and test + // harness) can see analysis was partial — silently dropping the + // calendar gives the user "patterns" computed from incomplete + // data and no way to know. + skipped = append(skipped, fmt.Sprintf("%s: %v", calendar.ID, err)) continue } allEvents = append(allEvents, events...) } + if len(skipped) > 0 { + // Log to stderr; downstream callers that already check for + // PartialAnalysis on the returned struct (set below) get the same + // signal without depending on a logger interface. 
+ fmt.Fprintf(os.Stderr, "warn: pattern analysis skipped %d calendar(s): %s\n", + len(skipped), strings.Join(skipped, "; ")) + } + // Filter out recurring events if not requested if !req.IncludeRecurring { filtered := []domain.Event{} @@ -184,7 +199,7 @@ func (p *PatternLearner) fetchHistoricalEvents(ctx context.Context, req *LearnPa } // calculateAnalysisPeriod calculates the actual period analyzed. -func (p *PatternLearner) calculateAnalysisPeriod(events []domain.Event, requestedDays int) AnalysisPeriod { +func (p *PatternLearner) calculateAnalysisPeriod(events []domain.Event, _ int) AnalysisPeriod { if len(events) == 0 { return AnalysisPeriod{} } @@ -320,10 +335,10 @@ func (p *PatternLearner) buildPatternContext(events []domain.Event, acceptance [ } // SavePatterns saves learned patterns (stub for future storage implementation). +// Returns an error rather than nil so callers can't mistake the no-op for a +// successful persist — pairs with LoadPatterns which already errors. func (p *PatternLearner) SavePatterns(ctx context.Context, patterns *SchedulingPatterns) error { - // Future: Save to local storage/database - // For now, this is a no-op - return nil + return fmt.Errorf("pattern storage not yet implemented") } // LoadPatterns loads previously learned patterns (stub for future storage implementation). diff --git a/internal/adapters/ai/pattern_learner_test.go b/internal/adapters/ai/pattern_learner_test.go index 181d586..06e564c 100644 --- a/internal/adapters/ai/pattern_learner_test.go +++ b/internal/adapters/ai/pattern_learner_test.go @@ -154,9 +154,12 @@ func TestPatternLearner_SaveLoadPatterns(t *testing.T) { ctx := context.Background() learner := &PatternLearner{} - t.Run("SavePatterns returns nil (stub)", func(t *testing.T) { + t.Run("SavePatterns returns not implemented error", func(t *testing.T) { + // SavePatterns is a stub; returning a real error keeps a caller + // from mistaking the no-op for a successful persist. 
err := learner.SavePatterns(ctx, &SchedulingPatterns{}) - assert.NoError(t, err) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not yet implemented") }) t.Run("LoadPatterns returns not implemented error", func(t *testing.T) { diff --git a/internal/adapters/ai/scheduler_tools.go b/internal/adapters/ai/scheduler_tools.go index 087397d..5534895 100644 --- a/internal/adapters/ai/scheduler_tools.go +++ b/internal/adapters/ai/scheduler_tools.go @@ -2,12 +2,7 @@ package ai import ( "context" - "encoding/json" "fmt" - "math" - "sort" - "strconv" - "strings" "time" tzutil "github.com/nylas/cli/internal/adapters/utilities/timezone" @@ -35,7 +30,11 @@ func (s *AIScheduler) runFindMeetingTime(ctx context.Context, args map[string]an return "", fmt.Errorf("duration must be greater than zero") } - searchStart, searchEnd, err := dateRangeArgs(args["dateRange"], requestLocation(req)) + loc, err := requestLocation(req) + if err != nil { + return "", err + } + searchStart, searchEnd, err := dateRangeArgs(args["dateRange"], loc) if err != nil { return "", err } @@ -46,7 +45,7 @@ func (s *AIScheduler) runFindMeetingTime(ctx context.Context, args map[string]an return "", fmt.Errorf("failed to get availability: %w", err) } - slots := rankAvailableSlots(result.Data.TimeSlots, requestLocation(req)) + slots := rankAvailableSlots(result.Data.TimeSlots, loc) payload := map[string]any{ "status": "success", "message": fmt.Sprintf("Found %d available time slots", len(slots)), @@ -103,7 +102,7 @@ func (s *AIScheduler) runCheckDST(ctx context.Context, args map[string]any) (str return marshalToolResult(payload) } -func (s *AIScheduler) runValidateWorkingHours(ctx context.Context, args map[string]any) (string, error) { +func (s *AIScheduler) runValidateWorkingHours(_ context.Context, args map[string]any) (string, error) { timezoneID, err := stringArg(args, "timezone", "") if err != nil { return "", err @@ -269,7 +268,10 @@ func (s *AIScheduler) runGetAvailability(ctx context.Context, args 
map[string]an return "", fmt.Errorf("participants are required") } - loc := requestLocation(req) + loc, err := requestLocation(req) + if err != nil { + return "", err + } startTime, err := timeArg(args, "startTime", loc) if err != nil { return "", err @@ -352,307 +354,3 @@ func (s *AIScheduler) runGetTimezoneInfo(ctx context.Context, args map[string]an return marshalToolResult(payload) } - -func buildAvailabilityRequest(participants []string, startTime, endTime time.Time, durationMinutes int) *domain.AvailabilityRequest { - availParticipants := make([]domain.AvailabilityParticipant, 0, len(participants)) - for _, email := range participants { - availParticipants = append(availParticipants, domain.AvailabilityParticipant{ - Email: email, - }) - } - - return &domain.AvailabilityRequest{ - StartTime: startTime.Unix(), - EndTime: endTime.Unix(), - DurationMinutes: durationMinutes, - Participants: availParticipants, - IntervalMinutes: 30, - } -} - -func rankAvailableSlots(slots []domain.AvailableSlot, loc *time.Location) []map[string]any { - type rankedSlot struct { - slot domain.AvailableSlot - score int - } - - ranked := make([]rankedSlot, 0, len(slots)) - for _, slot := range slots { - start := time.Unix(slot.StartTime, 0).In(loc) - ranked = append(ranked, rankedSlot{ - slot: slot, - score: localTimeScore(start), - }) - } - - sort.Slice(ranked, func(i, j int) bool { - if ranked[i].score == ranked[j].score { - return ranked[i].slot.StartTime < ranked[j].slot.StartTime - } - return ranked[i].score > ranked[j].score - }) - - limit := len(ranked) - if limit > 10 { - limit = 10 - } - - result := make([]map[string]any, 0, limit) - for _, entry := range ranked[:limit] { - result = append(result, map[string]any{ - "start": time.Unix(entry.slot.StartTime, 0).UTC().Format(time.RFC3339), - "end": time.Unix(entry.slot.EndTime, 0).UTC().Format(time.RFC3339), - "score": entry.score, - "emails": entry.slot.Emails, - "timezone": loc.String(), - }) - } - - return result -} - -func 
localTimeScore(start time.Time) int { - localHour := float64(start.Hour()) + float64(start.Minute())/60 - distanceFromIdeal := math.Abs(localHour - 13) - - score := 100 - int(distanceFromIdeal*8) - switch start.Weekday() { - case time.Tuesday, time.Wednesday, time.Thursday: - score += 5 - case time.Saturday, time.Sunday: - score -= 25 - } - - if score < 0 { - return 0 - } - if score > 100 { - return 100 - } - return score -} - -func (s *AIScheduler) defaultWritableCalendarID(ctx context.Context, grantID string) (string, error) { - calendars, err := s.nylasClient.GetCalendars(ctx, grantID) - if err != nil { - return "", fmt.Errorf("failed to list calendars: %w", err) - } - if len(calendars) == 0 { - return "", fmt.Errorf("no calendars available") - } - - for _, cal := range calendars { - if cal.IsPrimary && !cal.ReadOnly { - return cal.ID, nil - } - } - for _, cal := range calendars { - if !cal.ReadOnly { - return cal.ID, nil - } - } - - return "", fmt.Errorf("no writable calendar available") -} - -func marshalToolResult(payload map[string]any) (string, error) { - bytes, err := json.Marshal(payload) - if err != nil { - return "", fmt.Errorf("failed to marshal result: %w", err) - } - return string(bytes), nil -} - -func requestLocation(req *ScheduleRequest) *time.Location { - if req == nil || req.UserTimezone == "" { - return time.UTC - } - - loc, err := time.LoadLocation(req.UserTimezone) - if err != nil { - return time.UTC - } - return loc -} - -func dateRangeArgs(value any, loc *time.Location) (time.Time, time.Time, error) { - rangeArgs, ok := value.(map[string]any) - if !ok { - return time.Time{}, time.Time{}, fmt.Errorf("dateRange must be an object") - } - - startValue, ok := rangeArgs["start"] - if !ok { - return time.Time{}, time.Time{}, fmt.Errorf("dateRange.start is required") - } - endValue, ok := rangeArgs["end"] - if !ok { - return time.Time{}, time.Time{}, fmt.Errorf("dateRange.end is required") - } - - startDate, err := 
time.ParseInLocation("2006-01-02", fmt.Sprint(startValue), loc) - if err != nil { - return time.Time{}, time.Time{}, fmt.Errorf("invalid dateRange.start: %w", err) - } - endDate, err := time.ParseInLocation("2006-01-02", fmt.Sprint(endValue), loc) - if err != nil { - return time.Time{}, time.Time{}, fmt.Errorf("invalid dateRange.end: %w", err) - } - - endOfDay := time.Date(endDate.Year(), endDate.Month(), endDate.Day(), 23, 59, 59, 0, loc) - return startDate, endOfDay, nil -} - -func participantEmailsArg(args map[string]any, key string) ([]string, error) { - value, ok := args[key] - if !ok { - return nil, nil - } - - switch typed := value.(type) { - case []string: - return cleanStrings(typed), nil - case []any: - result := make([]string, 0, len(typed)) - for _, item := range typed { - value, ok := item.(string) - if !ok { - return nil, fmt.Errorf("%s entries must be strings", key) - } - if strings.TrimSpace(value) != "" { - result = append(result, strings.TrimSpace(value)) - } - } - return result, nil - case string: - return cleanStrings(strings.Split(typed, ",")), nil - default: - return nil, fmt.Errorf("%s must be a string array", key) - } -} - -func stringArg(args map[string]any, key, fallback string) (string, error) { - value, ok := args[key] - if !ok || value == nil { - return fallback, nil - } - - switch typed := value.(type) { - case string: - if strings.TrimSpace(typed) == "" { - return fallback, nil - } - return strings.TrimSpace(typed), nil - default: - return "", fmt.Errorf("%s must be a string", key) - } -} - -func intArg(args map[string]any, key string, fallback int) (int, error) { - value, ok := args[key] - if !ok || value == nil { - return fallback, nil - } - - switch typed := value.(type) { - case int: - return typed, nil - case int64: - return int(typed), nil - case float64: - return int(typed), nil - case string: - parsed, err := strconv.Atoi(strings.TrimSpace(typed)) - if err != nil { - return 0, fmt.Errorf("%s must be an integer", key) - } - 
return parsed, nil - default: - return 0, fmt.Errorf("%s must be an integer", key) - } -} - -func timeArg(args map[string]any, key string, loc *time.Location) (time.Time, error) { - value, ok := args[key] - if !ok || value == nil { - return time.Time{}, fmt.Errorf("%s is required", key) - } - - raw, ok := value.(string) - if !ok { - return time.Time{}, fmt.Errorf("%s must be a string", key) - } - raw = strings.TrimSpace(raw) - if raw == "" { - return time.Time{}, fmt.Errorf("%s is required", key) - } - - layouts := []string{ - time.RFC3339, - "2006-01-02T15:04:05", - "2006-01-02 15:04:05", - "2006-01-02T15:04", - "2006-01-02 15:04", - } - - for _, layout := range layouts { - var ( - parsed time.Time - err error - ) - - if layout == time.RFC3339 { - parsed, err = time.Parse(layout, raw) - } else { - parsed, err = time.ParseInLocation(layout, raw, loc) - } - if err == nil { - return parsed, nil - } - } - - return time.Time{}, fmt.Errorf("invalid %s: %q", key, raw) -} - -func clockMinutes(value string) (int, error) { - parts := strings.Split(value, ":") - if len(parts) != 2 { - return 0, fmt.Errorf("expected HH:MM") - } - - hour, err := strconv.Atoi(parts[0]) - if err != nil { - return 0, fmt.Errorf("invalid hour") - } - minute, err := strconv.Atoi(parts[1]) - if err != nil { - return 0, fmt.Errorf("invalid minute") - } - if hour < 0 || hour > 23 || minute < 0 || minute > 59 { - return 0, fmt.Errorf("hour must be 0-23 and minute must be 0-59") - } - - return hour*60 + minute, nil -} - -func formatOffset(seconds int) string { - sign := "+" - if seconds < 0 { - sign = "-" - seconds = -seconds - } - - hours := seconds / 3600 - minutes := (seconds % 3600) / 60 - return fmt.Sprintf("%s%02d:%02d", sign, hours, minutes) -} - -func cleanStrings(values []string) []string { - result := make([]string, 0, len(values)) - for _, value := range values { - if strings.TrimSpace(value) != "" { - result = append(result, strings.TrimSpace(value)) - } - } - return result -} diff --git 
a/internal/adapters/ai/scheduler_tools_helpers.go b/internal/adapters/ai/scheduler_tools_helpers.go new file mode 100644 index 0000000..14453e1 --- /dev/null +++ b/internal/adapters/ai/scheduler_tools_helpers.go @@ -0,0 +1,320 @@ +package ai + +import ( + "context" + "encoding/json" + "fmt" + "math" + "sort" + "strconv" + "strings" + "time" + + "github.com/nylas/cli/internal/domain" +) + +func buildAvailabilityRequest(participants []string, startTime, endTime time.Time, durationMinutes int) *domain.AvailabilityRequest { + availParticipants := make([]domain.AvailabilityParticipant, 0, len(participants)) + for _, email := range participants { + availParticipants = append(availParticipants, domain.AvailabilityParticipant{ + Email: email, + }) + } + + return &domain.AvailabilityRequest{ + StartTime: startTime.Unix(), + EndTime: endTime.Unix(), + DurationMinutes: durationMinutes, + Participants: availParticipants, + IntervalMinutes: 30, + } +} + +func rankAvailableSlots(slots []domain.AvailableSlot, loc *time.Location) []map[string]any { + type rankedSlot struct { + slot domain.AvailableSlot + score int + } + + ranked := make([]rankedSlot, 0, len(slots)) + for _, slot := range slots { + start := time.Unix(slot.StartTime, 0).In(loc) + ranked = append(ranked, rankedSlot{ + slot: slot, + score: localTimeScore(start), + }) + } + + sort.Slice(ranked, func(i, j int) bool { + if ranked[i].score == ranked[j].score { + return ranked[i].slot.StartTime < ranked[j].slot.StartTime + } + return ranked[i].score > ranked[j].score + }) + + limit := min(len(ranked), 10) + + result := make([]map[string]any, 0, limit) + for _, entry := range ranked[:limit] { + result = append(result, map[string]any{ + "start": time.Unix(entry.slot.StartTime, 0).UTC().Format(time.RFC3339), + "end": time.Unix(entry.slot.EndTime, 0).UTC().Format(time.RFC3339), + "score": entry.score, + "emails": entry.slot.Emails, + "timezone": loc.String(), + }) + } + + return result +} + +func localTimeScore(start 
time.Time) int { + localHour := float64(start.Hour()) + float64(start.Minute())/60 + distanceFromIdeal := math.Abs(localHour - 13) + + score := 100 - int(distanceFromIdeal*8) + switch start.Weekday() { + case time.Tuesday, time.Wednesday, time.Thursday: + score += 5 + case time.Saturday, time.Sunday: + score -= 25 + } + + if score < 0 { + return 0 + } + if score > 100 { + return 100 + } + return score +} + +func (s *AIScheduler) defaultWritableCalendarID(ctx context.Context, grantID string) (string, error) { + calendars, err := s.nylasClient.GetCalendars(ctx, grantID) + if err != nil { + return "", fmt.Errorf("failed to list calendars: %w", err) + } + if len(calendars) == 0 { + return "", fmt.Errorf("no calendars available") + } + + for _, cal := range calendars { + if cal.IsPrimary && !cal.ReadOnly { + return cal.ID, nil + } + } + for _, cal := range calendars { + if !cal.ReadOnly { + return cal.ID, nil + } + } + + return "", fmt.Errorf("no writable calendar available") +} + +func marshalToolResult(payload map[string]any) (string, error) { + bytes, err := json.Marshal(payload) + if err != nil { + return "", fmt.Errorf("failed to marshal result: %w", err) + } + return string(bytes), nil +} + +// requestLocation resolves the user's timezone for a scheduling request. +// A bad timezone ID (e.g. "PST" instead of "America/Los_Angeles") yields +// an error rather than silently falling back to UTC — slot rankings produced +// against the wrong zone look correct but are wrong by hours, which is +// exactly the kind of failure the user has no easy way to catch.
+func requestLocation(req *ScheduleRequest) (*time.Location, error) { + if req == nil || req.UserTimezone == "" { + return time.UTC, nil + } + + loc, err := time.LoadLocation(req.UserTimezone) + if err != nil { + return nil, fmt.Errorf("invalid timezone %q: %w", req.UserTimezone, err) + } + return loc, nil +} + +func dateRangeArgs(value any, loc *time.Location) (time.Time, time.Time, error) { + rangeArgs, ok := value.(map[string]any) + if !ok { + return time.Time{}, time.Time{}, fmt.Errorf("dateRange must be an object") + } + + startValue, ok := rangeArgs["start"] + if !ok { + return time.Time{}, time.Time{}, fmt.Errorf("dateRange.start is required") + } + endValue, ok := rangeArgs["end"] + if !ok { + return time.Time{}, time.Time{}, fmt.Errorf("dateRange.end is required") + } + + startDate, err := time.ParseInLocation("2006-01-02", fmt.Sprint(startValue), loc) + if err != nil { + return time.Time{}, time.Time{}, fmt.Errorf("invalid dateRange.start: %w", err) + } + endDate, err := time.ParseInLocation("2006-01-02", fmt.Sprint(endValue), loc) + if err != nil { + return time.Time{}, time.Time{}, fmt.Errorf("invalid dateRange.end: %w", err) + } + + endOfDay := time.Date(endDate.Year(), endDate.Month(), endDate.Day(), 23, 59, 59, 0, loc) + return startDate, endOfDay, nil +} + +func participantEmailsArg(args map[string]any, key string) ([]string, error) { + value, ok := args[key] + if !ok { + return nil, nil + } + + switch typed := value.(type) { + case []string: + return cleanStrings(typed), nil + case []any: + result := make([]string, 0, len(typed)) + for _, item := range typed { + value, ok := item.(string) + if !ok { + return nil, fmt.Errorf("%s entries must be strings", key) + } + if strings.TrimSpace(value) != "" { + result = append(result, strings.TrimSpace(value)) + } + } + return result, nil + case string: + return cleanStrings(strings.Split(typed, ",")), nil + default: + return nil, fmt.Errorf("%s must be a string array", key) + } +} + +func stringArg(args 
map[string]any, key, fallback string) (string, error) { + value, ok := args[key] + if !ok || value == nil { + return fallback, nil + } + + switch typed := value.(type) { + case string: + if strings.TrimSpace(typed) == "" { + return fallback, nil + } + return strings.TrimSpace(typed), nil + default: + return "", fmt.Errorf("%s must be a string", key) + } +} + +func intArg(args map[string]any, key string, fallback int) (int, error) { + value, ok := args[key] + if !ok || value == nil { + return fallback, nil + } + + switch typed := value.(type) { + case int: + return typed, nil + case int64: + return int(typed), nil + case float64: + return int(typed), nil + case string: + parsed, err := strconv.Atoi(strings.TrimSpace(typed)) + if err != nil { + return 0, fmt.Errorf("%s must be an integer", key) + } + return parsed, nil + default: + return 0, fmt.Errorf("%s must be an integer", key) + } +} + +func timeArg(args map[string]any, key string, loc *time.Location) (time.Time, error) { + value, ok := args[key] + if !ok || value == nil { + return time.Time{}, fmt.Errorf("%s is required", key) + } + + raw, ok := value.(string) + if !ok { + return time.Time{}, fmt.Errorf("%s must be a string", key) + } + raw = strings.TrimSpace(raw) + if raw == "" { + return time.Time{}, fmt.Errorf("%s is required", key) + } + + layouts := []string{ + time.RFC3339, + "2006-01-02T15:04:05", + "2006-01-02 15:04:05", + "2006-01-02T15:04", + "2006-01-02 15:04", + } + + for _, layout := range layouts { + var ( + parsed time.Time + err error + ) + + if layout == time.RFC3339 { + parsed, err = time.Parse(layout, raw) + } else { + parsed, err = time.ParseInLocation(layout, raw, loc) + } + if err == nil { + return parsed, nil + } + } + + return time.Time{}, fmt.Errorf("invalid %s: %q", key, raw) +} + +func clockMinutes(value string) (int, error) { + parts := strings.Split(value, ":") + if len(parts) != 2 { + return 0, fmt.Errorf("expected HH:MM") + } + + hour, err := strconv.Atoi(parts[0]) + if err != 
nil { + return 0, fmt.Errorf("invalid hour") + } + minute, err := strconv.Atoi(parts[1]) + if err != nil { + return 0, fmt.Errorf("invalid minute") + } + if hour < 0 || hour > 23 || minute < 0 || minute > 59 { + return 0, fmt.Errorf("hour must be 0-23 and minute must be 0-59") + } + + return hour*60 + minute, nil +} + +func formatOffset(seconds int) string { + sign := "+" + if seconds < 0 { + sign = "-" + seconds = -seconds + } + + hours := seconds / 3600 + minutes := (seconds % 3600) / 60 + return fmt.Sprintf("%s%02d:%02d", sign, hours, minutes) +} + +func cleanStrings(values []string) []string { + result := make([]string, 0, len(values)) + for _, value := range values { + if strings.TrimSpace(value) != "" { + result = append(result, strings.TrimSpace(value)) + } + } + return result +} diff --git a/internal/adapters/config/config_test.go b/internal/adapters/config/config_test.go index d4339bf..0e3755d 100644 --- a/internal/adapters/config/config_test.go +++ b/internal/adapters/config/config_test.go @@ -3,6 +3,7 @@ package config import ( "os" "path/filepath" + "strings" "testing" "github.com/nylas/cli/internal/domain" @@ -121,6 +122,42 @@ func TestFileStore_LoadSaveRoundTrip(t *testing.T) { } } +func TestFileStore_SaveDropsLegacyGrants(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + legacyYAML := []byte(`region: us +callback_port: 9007 +default_grant: grant-1 +grants: + - id: grant-1 + email: user@example.com + provider: google +`) + if err := os.WriteFile(configPath, legacyYAML, 0600); err != nil { + t.Fatalf("failed to write legacy config: %v", err) + } + + store := NewFileStore(configPath) + cfg, err := store.Load() + if err != nil { + t.Fatalf("Load() error = %v", err) + } + if cfg.DefaultGrant != "grant-1" { + t.Fatalf("DefaultGrant = %q, want %q", cfg.DefaultGrant, "grant-1") + } + + if err := store.Save(cfg); err != nil { + t.Fatalf("Save() error = %v", err) + } + data, err := os.ReadFile(configPath) + if err 
!= nil { + t.Fatalf("failed to read saved config: %v", err) + } + if strings.Contains(string(data), "grants:") { + t.Fatalf("saved config still contains legacy grants list:\n%s", string(data)) + } +} + func TestFileStore_LoadNonExistent(t *testing.T) { tmpDir := t.TempDir() configPath := filepath.Join(tmpDir, "nonexistent.yaml") diff --git a/internal/adapters/grantcache/cache.go b/internal/adapters/grantcache/cache.go new file mode 100644 index 0000000..98b38da --- /dev/null +++ b/internal/adapters/grantcache/cache.go @@ -0,0 +1,282 @@ +package grantcache + +import ( + "encoding/json" + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "sync" + "time" + + "github.com/nylas/cli/internal/domain" +) + +const ( + fileVersion = 1 + lockWait = 10 * time.Second + staleLockAge = 2 * time.Minute +) + +// Store implements ports.GrantStore using a plaintext JSON file for +// non-secret grant metadata and local default-grant preference. +type Store struct { + path string + mu sync.RWMutex +} + +type fileShape struct { + Version int `json:"version"` + DefaultGrant string `json:"default_grant,omitempty"` + Grants []domain.GrantInfo `json:"grants"` +} + +// New creates a file-backed grant metadata store. +func New(path string) *Store { + return &Store{path: path} +} + +// SaveGrant saves or replaces one grant in local metadata. +func (s *Store) SaveGrant(info domain.GrantInfo) error { + if info.ID == "" || info.Email == "" { + return domain.ErrInvalidInput + } + return s.mutate(func(shape *fileShape) error { + for i, grant := range shape.Grants { + if grant.ID == info.ID { + shape.Grants[i] = info + return nil + } + } + shape.Grants = append(shape.Grants, info) + return nil + }) +} + +// ReplaceGrants replaces locally cached grant metadata after a successful +// live API listing. The default grant is preserved only if it still exists. 
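The default-preservation rule in `ReplaceGrants` reduces to a small decision: after a live listing replaces the cache, the stored default survives only if its ID still appears in the fresh list. A standalone stand-in (the real store operates on `domain.GrantInfo` values, not bare IDs):

```go
package main

import "fmt"

// replaceDefault models the ReplaceGrants rule from the diff: keep the
// current default grant ID only if it is present in the fresh listing,
// otherwise clear it so a stale default can't point at a deleted grant.
func replaceDefault(current string, freshIDs []string) string {
	for _, id := range freshIDs {
		if id == current {
			return current
		}
	}
	return ""
}

func main() {
	fmt.Printf("%q\n", replaceDefault("grant-1", []string{"grant-1", "grant-2"})) // "grant-1"
	fmt.Printf("%q\n", replaceDefault("grant-1", []string{"grant-2"}))            // ""
}
```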
+func (s *Store) ReplaceGrants(grants []domain.GrantInfo) error { + for _, grant := range grants { + if grant.ID == "" || grant.Email == "" { + return domain.ErrInvalidInput + } + } + return s.mutate(func(shape *fileShape) error { + defaultGrant := shape.DefaultGrant + shape.Grants = append([]domain.GrantInfo(nil), grants...) + if defaultGrant != "" && !containsGrantID(shape.Grants, defaultGrant) { + shape.DefaultGrant = "" + } + return nil + }) +} + +// GetGrant retrieves grant info by ID. +func (s *Store) GetGrant(grantID string) (*domain.GrantInfo, error) { + grants, err := s.ListGrants() + if err != nil { + return nil, err + } + for _, grant := range grants { + if grant.ID == grantID { + out := grant + return &out, nil + } + } + return nil, domain.ErrGrantNotFound +} + +// GetGrantByEmail retrieves grant info by email. +func (s *Store) GetGrantByEmail(email string) (*domain.GrantInfo, error) { + grants, err := s.ListGrants() + if err != nil { + return nil, err + } + for _, grant := range grants { + if grant.Email == email { + out := grant + return &out, nil + } + } + return nil, domain.ErrGrantNotFound +} + +// ListGrants returns locally cached grant metadata. +func (s *Store) ListGrants() ([]domain.GrantInfo, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + shape, err := s.read() + if err != nil { + return nil, err + } + return append([]domain.GrantInfo(nil), shape.Grants...), nil +} + +// DeleteGrant removes one grant from local metadata. +func (s *Store) DeleteGrant(grantID string) error { + return s.mutate(func(shape *fileShape) error { + grants := shape.Grants[:0] + for _, grant := range shape.Grants { + if grant.ID != grantID { + grants = append(grants, grant) + } + } + shape.Grants = grants + if shape.DefaultGrant == grantID { + shape.DefaultGrant = "" + } + return nil + }) +} + +// SetDefaultGrant stores the local default-grant preference. 
+func (s *Store) SetDefaultGrant(grantID string) error { + return s.mutate(func(shape *fileShape) error { + shape.DefaultGrant = grantID + return nil + }) +} + +// GetDefaultGrant returns the local default-grant preference. +func (s *Store) GetDefaultGrant() (string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + shape, err := s.read() + if err != nil { + return "", err + } + if shape.DefaultGrant == "" { + return "", domain.ErrNoDefaultGrant + } + return shape.DefaultGrant, nil +} + +// ClearGrants removes all local grant metadata and default preference. +func (s *Store) ClearGrants() error { + s.mu.Lock() + defer s.mu.Unlock() + + return s.withFileLock(func() error { + if err := os.Remove(s.path); err != nil && !errors.Is(err, fs.ErrNotExist) { + return err + } + return nil + }) +} + +func (s *Store) mutate(fn func(*fileShape) error) error { + s.mu.Lock() + defer s.mu.Unlock() + + return s.withFileLock(func() error { + shape, err := s.read() + if err != nil { + return err + } + if err := fn(shape); err != nil { + return err + } + return s.write(shape) + }) +} + +func (s *Store) read() (*fileShape, error) { + data, err := os.ReadFile(s.path) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return &fileShape{Version: fileVersion}, nil + } + return nil, err + } + + var shape fileShape + if err := json.Unmarshal(data, &shape); err != nil { + return &fileShape{Version: fileVersion}, nil + } + if shape.Version == 0 { + shape.Version = fileVersion + } + return &shape, nil +} + +func (s *Store) write(shape *fileShape) error { + shape.Version = fileVersion + + dir := filepath.Dir(s.path) + if err := os.MkdirAll(dir, 0o700); err != nil { + return err + } + + data, err := json.MarshalIndent(shape, "", " ") + if err != nil { + return err + } + + tmp, err := os.CreateTemp(dir, ".grants-*.tmp") + if err != nil { + return err + } + tmpPath := tmp.Name() + defer func() { _ = os.Remove(tmpPath) }() + + if _, err := tmp.Write(data); err != nil { + _ = tmp.Close() + 
return err + } + if err := tmp.Chmod(0o600); err != nil { + _ = tmp.Close() + return err + } + if err := tmp.Close(); err != nil { + return err + } + return os.Rename(tmpPath, s.path) +} + +func (s *Store) withFileLock(fn func() error) error { + dir := filepath.Dir(s.path) + if err := os.MkdirAll(dir, 0o700); err != nil { + return err + } + + lockPath := s.path + ".lock" + deadline := time.Now().Add(lockWait) + for { + err := os.Mkdir(lockPath, 0o700) + if err == nil { + defer func() { _ = os.Remove(lockPath) }() + return fn() + } + if !errors.Is(err, fs.ErrExist) { + return err + } + if isStaleLock(lockPath) { + _ = os.RemoveAll(lockPath) + continue + } + if time.Now().After(deadline) { + return fmt.Errorf("timed out waiting for grant cache lock: %s", lockPath) + } + time.Sleep(10 * time.Millisecond) + } +} + +func isStaleLock(path string) bool { + info, err := os.Stat(path) + if err != nil { + return false + } + return time.Since(info.ModTime()) > staleLockAge +} + +func containsGrantID(grants []domain.GrantInfo, grantID string) bool { + for _, grant := range grants { + if grant.ID == grantID { + return true + } + } + return false +} diff --git a/internal/adapters/grantcache/cache_test.go b/internal/adapters/grantcache/cache_test.go new file mode 100644 index 0000000..e36e645 --- /dev/null +++ b/internal/adapters/grantcache/cache_test.go @@ -0,0 +1,185 @@ +package grantcache + +import ( + "fmt" + "os" + "path/filepath" + "sync" + "testing" + + "github.com/nylas/cli/internal/domain" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStore_MissingFileReturnsEmpty(t *testing.T) { + store := New(filepath.Join(t.TempDir(), "grants.json")) + + grants, err := store.ListGrants() + require.NoError(t, err) + assert.Empty(t, grants) + + defaultGrant, err := store.GetDefaultGrant() + assert.Empty(t, defaultGrant) + assert.ErrorIs(t, err, domain.ErrNoDefaultGrant) +} + +func TestStore_CorruptFileReturnsEmpty(t *testing.T) { + path := 
filepath.Join(t.TempDir(), "grants.json") + require.NoError(t, os.WriteFile(path, []byte("{not-json"), 0o600)) + + store := New(path) + grants, err := store.ListGrants() + require.NoError(t, err) + assert.Empty(t, grants) + + defaultGrant, err := store.GetDefaultGrant() + assert.Empty(t, defaultGrant) + assert.ErrorIs(t, err, domain.ErrNoDefaultGrant) +} + +func TestStore_SaveGetDeleteAndClear(t *testing.T) { + path := filepath.Join(t.TempDir(), "nylas", "grants.json") + store := New(path) + + grant1 := domain.GrantInfo{ID: "grant-1", Email: "one@example.com", Provider: domain.ProviderGoogle} + grant2 := domain.GrantInfo{ID: "grant-2", Email: "two@example.com", Provider: domain.ProviderMicrosoft} + + require.NoError(t, store.SaveGrant(grant1)) + require.NoError(t, store.SaveGrant(grant2)) + require.NoError(t, store.SetDefaultGrant(grant1.ID)) + + byID, err := store.GetGrant(grant1.ID) + require.NoError(t, err) + assert.Equal(t, grant1, *byID) + + byEmail, err := store.GetGrantByEmail(grant2.Email) + require.NoError(t, err) + assert.Equal(t, grant2, *byEmail) + + grants, err := store.ListGrants() + require.NoError(t, err) + assert.Equal(t, []domain.GrantInfo{grant1, grant2}, grants) + + updated := domain.GrantInfo{ID: "grant-2", Email: "new-two@example.com", Provider: domain.ProviderGoogle} + require.NoError(t, store.SaveGrant(updated)) + grants, err = store.ListGrants() + require.NoError(t, err) + assert.Equal(t, []domain.GrantInfo{grant1, updated}, grants) + + require.NoError(t, store.DeleteGrant(grant1.ID)) + _, err = store.GetDefaultGrant() + assert.ErrorIs(t, err, domain.ErrNoDefaultGrant) + _, err = store.GetGrant(grant1.ID) + assert.ErrorIs(t, err, domain.ErrGrantNotFound) + + require.NoError(t, store.ClearGrants()) + grants, err = store.ListGrants() + require.NoError(t, err) + assert.Empty(t, grants) + _, err = os.Stat(path) + assert.ErrorIs(t, err, os.ErrNotExist) +} + +func TestStore_FilePermissions(t *testing.T) { + path := filepath.Join(t.TempDir(), 
"nylas", "grants.json") + store := New(path) + + require.NoError(t, store.SaveGrant(domain.GrantInfo{ + ID: "grant-1", + Email: "one@example.com", + })) + + dirInfo, err := os.Stat(filepath.Dir(path)) + require.NoError(t, err) + assert.Equal(t, os.FileMode(0o700), dirInfo.Mode().Perm()) + + fileInfo, err := os.Stat(path) + require.NoError(t, err) + assert.Equal(t, os.FileMode(0o600), fileInfo.Mode().Perm()) +} + +func TestStore_ReplaceGrantsPreservesValidDefault(t *testing.T) { + store := New(filepath.Join(t.TempDir(), "grants.json")) + require.NoError(t, store.SaveGrant(domain.GrantInfo{ + ID: "grant-1", + Email: "old@example.com", + Provider: domain.ProviderGoogle, + })) + require.NoError(t, store.SetDefaultGrant("grant-1")) + + require.NoError(t, store.ReplaceGrants([]domain.GrantInfo{ + {ID: "grant-1", Email: "new@example.com", Provider: domain.ProviderMicrosoft}, + {ID: "grant-2", Email: "two@example.com", Provider: domain.ProviderGoogle}, + })) + + defaultGrant, err := store.GetDefaultGrant() + require.NoError(t, err) + assert.Equal(t, "grant-1", defaultGrant) + + grant, err := store.GetGrant("grant-1") + require.NoError(t, err) + assert.Equal(t, "new@example.com", grant.Email) + assert.Equal(t, domain.ProviderMicrosoft, grant.Provider) +} + +func TestStore_ReplaceGrantsClearsMissingDefault(t *testing.T) { + store := New(filepath.Join(t.TempDir(), "grants.json")) + require.NoError(t, store.SaveGrant(domain.GrantInfo{ + ID: "grant-1", + Email: "old@example.com", + Provider: domain.ProviderGoogle, + })) + require.NoError(t, store.SetDefaultGrant("grant-1")) + + require.NoError(t, store.ReplaceGrants([]domain.GrantInfo{ + {ID: "grant-2", Email: "two@example.com", Provider: domain.ProviderGoogle}, + })) + + defaultGrant, err := store.GetDefaultGrant() + assert.Empty(t, defaultGrant) + assert.ErrorIs(t, err, domain.ErrNoDefaultGrant) +} + +func TestStore_RejectsInvalidGrantInfo(t *testing.T) { + store := New(filepath.Join(t.TempDir(), "grants.json")) + + 
assert.ErrorIs(t, store.SaveGrant(domain.GrantInfo{ID: "grant-1"}), domain.ErrInvalidInput) + assert.ErrorIs(t, store.ReplaceGrants([]domain.GrantInfo{{ID: "grant-1"}}), domain.ErrInvalidInput) +} + +func TestStore_ConcurrentWritersPreserveUpdates(t *testing.T) { + path := filepath.Join(t.TempDir(), "grants.json") + stores := []*Store{New(path), New(path)} + + var wg sync.WaitGroup + errs := make(chan error, 40) + for i := 0; i < 40; i++ { + i := i + wg.Add(1) + go func() { + defer wg.Done() + errs <- stores[i%len(stores)].SaveGrant(domain.GrantInfo{ + ID: fmt.Sprintf("grant-%02d", i), + Email: fmt.Sprintf("user-%02d@example.com", i), + }) + }() + } + wg.Wait() + close(errs) + + for err := range errs { + require.NoError(t, err) + } + + grants, err := New(path).ListGrants() + require.NoError(t, err) + assert.Len(t, grants, 40) + seen := make(map[string]bool, len(grants)) + for _, grant := range grants { + seen[grant.ID] = true + } + for i := 0; i < 40; i++ { + assert.True(t, seen[fmt.Sprintf("grant-%02d", i)]) + } +} diff --git a/internal/adapters/keyring/crossplatform_test.go b/internal/adapters/keyring/crossplatform_test.go index ed5219c..b2b9980 100644 --- a/internal/adapters/keyring/crossplatform_test.go +++ b/internal/adapters/keyring/crossplatform_test.go @@ -250,14 +250,65 @@ func TestEncryptedFileStore_RequiresPassphraseForWrites(t *testing.T) { t.Cleanup(func() { _ = os.Setenv(fileStorePassphraseEnv, orig) }) } + // Fresh install: no legacy file, no passphrase. Construction succeeds so + // callers can probe the empty store, but Set must fail with a clear + // message pointing at NYLAS_FILE_STORE_PASSPHRASE. 
store, err := NewEncryptedFileStore(t.TempDir()) if err != nil { - t.Fatalf("NewEncryptedFileStore failed: %v", err) + t.Fatalf("NewEncryptedFileStore should not fail on a fresh install: %v", err) + } + + if err := store.Set("api_key", "value"); err == nil { + t.Fatal("Set succeeded without passphrase on fresh install") + } else if !strings.Contains(err.Error(), fileStorePassphraseEnv) { + t.Fatalf("Set error %q does not mention %s", err.Error(), fileStorePassphraseEnv) + } +} + +// TestEncryptedFileStore_ReadsLegacyWithoutPassphraseButRefusesWrite verifies +// that existing fallback-store users are not locked out of read-only commands +// after upgrade, while writes still require the passphrase migration path. +func TestEncryptedFileStore_ReadsLegacyWithoutPassphraseButRefusesWrite(t *testing.T) { + tmpDir := t.TempDir() + + legacyKey, err := deriveLegacyKey() + if err != nil { + t.Fatalf("deriveLegacyKey failed: %v", err) + } + legacyCiphertext, err := encryptWithKey(legacyKey, []byte(`{"api_key":"old-value"}`)) + if err != nil { + t.Fatalf("encryptWithKey failed: %v", err) + } + secretsPath := filepath.Join(tmpDir, ".secrets.enc") + if err := os.WriteFile(secretsPath, legacyCiphertext, 0600); err != nil { + t.Fatalf("failed to write legacy file: %v", err) + } + + orig := os.Getenv(fileStorePassphraseEnv) + if orig != "" { + _ = os.Unsetenv(fileStorePassphraseEnv) + t.Cleanup(func() { _ = os.Setenv(fileStorePassphraseEnv, orig) }) + } + + // Legacy file exists → construction should succeed (store is openable). + store, err := NewEncryptedFileStore(tmpDir) + if err != nil { + t.Fatalf("NewEncryptedFileStore failed when legacy file exists: %v", err) } - err = store.Set("api_key", "value") + // Read-only commands must continue to work against the legacy file. 
+ value, err := store.Get("api_key") + if err != nil { + t.Fatalf("Get failed without passphrase on legacy file: %v", err) + } + if value != "old-value" { + t.Fatalf("Get returned %q, want old-value", value) + } + + // Write must also fail with migration-required error. + err = store.Set("api_key", "new-value") if err == nil { - t.Fatal("Set succeeded without passphrase") + t.Fatal("Set succeeded without passphrase on legacy file") } if !strings.Contains(err.Error(), fileStorePassphraseEnv) { t.Fatalf("Set error %q does not mention %s", err.Error(), fileStorePassphraseEnv) @@ -543,3 +594,136 @@ func TestConcurrentAccess(t *testing.T) { t.Errorf("Concurrent access error: %v", err) } } + +// TestEncryptedFileStore_MigratesOnFirstGet verifies that the one-shot migration +// happens on the first Get, not only after a subsequent Set. +// After Get returns the plaintext, the on-disk file must already be re-encrypted +// with the passphrase-derived key and must no longer be decryptable by the legacy key. +func TestEncryptedFileStore_MigratesOnFirstGet(t *testing.T) { + tmpDir := t.TempDir() + passphrase := setFileStorePassphrase(t) + + legacyKey, err := deriveLegacyKey() + if err != nil { + t.Fatalf("deriveLegacyKey failed: %v", err) + } + + legacyCiphertext, err := encryptWithKey(legacyKey, []byte(`{"migrate_key":"migrate_value"}`)) + if err != nil { + t.Fatalf("encryptWithKey failed: %v", err) + } + + secretsPath := filepath.Join(tmpDir, ".secrets.enc") + if err := os.WriteFile(secretsPath, legacyCiphertext, 0600); err != nil { + t.Fatalf("failed to write legacy secrets file: %v", err) + } + + store, err := NewEncryptedFileStore(tmpDir) + if err != nil { + t.Fatalf("NewEncryptedFileStore failed: %v", err) + } + + // First Get — should trigger migration inline. 
+ value, err := store.Get("migrate_key") + if err != nil { + t.Fatalf("Get failed: %v", err) + } + if value != "migrate_value" { + t.Fatalf("Get returned %q, want %q", value, "migrate_value") + } + + // After Get, the on-disk file must already use the passphrase-derived key. + data, err := os.ReadFile(secretsPath) + if err != nil { + t.Fatalf("failed to read secrets file after migration: %v", err) + } + + // Legacy key must no longer decrypt the file. + if _, err := decryptWithKey(legacyKey, data); err == nil { + t.Fatal("on-disk file is still decryptable with the legacy key after Get-triggered migration") + } + + // Passphrase-derived key must decrypt successfully. + saltData, err := os.ReadFile(filepath.Join(tmpDir, ".secrets.salt")) + if err != nil { + t.Fatalf("failed to read salt file: %v", err) + } + decodedSalt, err := base64.StdEncoding.DecodeString(strings.TrimSpace(string(saltData))) + if err != nil { + t.Fatalf("failed to decode salt: %v", err) + } + if _, err := decryptWithKey(derivePassphraseKey([]byte(passphrase), decodedSalt), data); err != nil { + t.Fatalf("failed to decrypt migrated file with passphrase-derived key: %v", err) + } +} + +// TestDetectKeyType verifies the detectKeyType helper across the expected states. +func TestDetectKeyType(t *testing.T) { + t.Run("none_when_no_file", func(t *testing.T) { + tmpDir := t.TempDir() + setFileStorePassphrase(t) + + store, err := NewEncryptedFileStore(tmpDir) + // Fresh install with passphrase set — construction should succeed. + if err != nil { + t.Fatalf("NewEncryptedFileStore failed: %v", err) + } + // No file written yet. 
+		kt, err := store.detectKeyType()
+		if err != nil {
+			t.Fatalf("detectKeyType failed: %v", err)
+		}
+		if kt != fileStoreKeyNone {
+			t.Fatalf("detectKeyType = %d, want fileStoreKeyNone (%d)", kt, fileStoreKeyNone)
+		}
+	})
+
+	t.Run("passphrase_only_after_write", func(t *testing.T) {
+		tmpDir := t.TempDir()
+		setFileStorePassphrase(t)
+
+		store, err := NewEncryptedFileStore(tmpDir)
+		if err != nil {
+			t.Fatalf("NewEncryptedFileStore failed: %v", err)
+		}
+		if err := store.Set("k", "v"); err != nil {
+			t.Fatalf("Set failed: %v", err)
+		}
+		kt, err := store.detectKeyType()
+		if err != nil {
+			t.Fatalf("detectKeyType failed: %v", err)
+		}
+		if kt != fileStoreKeyPassphraseOnly {
+			t.Fatalf("detectKeyType = %d, want fileStoreKeyPassphraseOnly (%d)", kt, fileStoreKeyPassphraseOnly)
+		}
+	})
+
+	t.Run("legacy_only_before_migration", func(t *testing.T) {
+		tmpDir := t.TempDir()
+		setFileStorePassphrase(t)
+
+		legacyKey, err := deriveLegacyKey()
+		if err != nil {
+			t.Fatalf("deriveLegacyKey failed: %v", err)
+		}
+		ct, err := encryptWithKey(legacyKey, []byte(`{"k":"v"}`))
+		if err != nil {
+			t.Fatalf("encryptWithKey failed: %v", err)
+		}
+		if err := os.WriteFile(filepath.Join(tmpDir, ".secrets.enc"), ct, 0600); err != nil {
+			t.Fatalf("WriteFile failed: %v", err)
+		}
+
+		store, err := NewEncryptedFileStore(tmpDir)
+		if err != nil {
+			t.Fatalf("NewEncryptedFileStore failed: %v", err)
+		}
+		kt, err := store.detectKeyType()
+		if err != nil {
+			t.Fatalf("detectKeyType failed: %v", err)
+		}
+		if kt != fileStoreKeyLegacyOnly {
+			t.Fatalf("detectKeyType = %d, want fileStoreKeyLegacyOnly (%d)", kt, fileStoreKeyLegacyOnly)
+		}
+	})
+}
diff --git a/internal/adapters/keyring/file.go b/internal/adapters/keyring/file.go
index ea13ef6..2b8e030 100644
--- a/internal/adapters/keyring/file.go
+++ b/internal/adapters/keyring/file.go
@@ -1,17 +1,13 @@
 package keyring
 
 import (
-	"crypto/aes"
-	"crypto/cipher"
 	"crypto/rand"
-	"crypto/sha256"
 	"encoding/base64"
 	"encoding/json"
 	"fmt"
 	"io"
 	"os"
 	"path/filepath"
-	"runtime"
 	"strings"
 	"sync"
@@ -24,9 +20,28 @@
 	fileStoreSaltSize = 16
 )
 
+// fileStoreKeyType describes which key(s) can decrypt the on-disk .secrets.enc file.
+type fileStoreKeyType int
+
+const (
+	fileStoreKeyNone           fileStoreKeyType = iota // file does not exist or neither key decrypts it
+	fileStoreKeyLegacyOnly                             // decryptable only with the legacy machine-derived key
+	fileStoreKeyPassphraseOnly                         // decryptable only with the passphrase-derived key
+	fileStoreKeyBoth                                   // decryptable with either key
+)
+
 // EncryptedFileStore implements SecretStore using an encrypted file.
 // This is a fallback for environments where the system keyring is unavailable.
-// Uses AES-256-GCM encryption with a key derived from user-supplied secret material.
+// Uses AES-256-GCM encryption with an Argon2id key derived from a user-supplied
+// passphrase set via NYLAS_FILE_STORE_PASSPHRASE.
+//
+// REQUIREMENT: NYLAS_FILE_STORE_PASSPHRASE must be set before using this store
+// on a fresh install. Existing installations that used the legacy machine-derived
+// key remain readable without a passphrase for backward compatibility, and will
+// be migrated automatically the first time NYLAS_FILE_STORE_PASSPHRASE is supplied.
+//
+// To avoid the file store entirely, leave NYLAS_DISABLE_KEYRING unset and let
+// the system keyring be used, or run `nylas auth config` to reconfigure.
 type EncryptedFileStore struct {
 	path    string
 	keyPath string
@@ -37,8 +52,15 @@
 	mu sync.RWMutex
 }
 
-// NewEncryptedFileStore creates a new EncryptedFileStore.
-// The secrets are stored in an encrypted file within the config directory.
+// NewEncryptedFileStore creates a new EncryptedFileStore rooted in configDir.
+//
+// Construction always succeeds — a fresh install (no passphrase, no legacy
+// file) yields a store whose reads return ErrSecretNotFound and whose writes
+// fail with a clear "set NYLAS_FILE_STORE_PASSPHRASE" error. This lets
+// callers probe an empty store without crashing.
+//
+// To actually persist secrets, set NYLAS_FILE_STORE_PASSPHRASE to a strong
+// value, or run `nylas auth config` to switch to the system keyring.
 func NewEncryptedFileStore(configDir string) (*EncryptedFileStore, error) {
 	path := filepath.Join(configDir, ".secrets.enc")
 	keyPath := filepath.Join(configDir, ".secrets.key")
@@ -56,6 +78,17 @@ func NewEncryptedFileStore(configDir string) (*EncryptedFileStore, error) {
 
 	var passphrase []byte
 	if value := os.Getenv(fileStorePassphraseEnv); value != "" {
+		// Enforce a minimum length so a 4-character passphrase isn't held
+		// up as adequate defense. This is a deliberately gentle floor (12
+		// characters) — long enough to make offline brute-force materially
+		// expensive when combined with Argon2id, short enough that real
+		// users can comply.
+		if len(value) < minPassphraseLen {
+			return nil, fmt.Errorf(
+				"%s must be at least %d characters (got %d) — pick a longer passphrase",
+				fileStorePassphraseEnv, minPassphraseLen, len(value),
+			)
+		}
 		passphrase = []byte(value)
 	}
@@ -90,9 +123,16 @@ func (f *EncryptedFileStore) Set(key, value string) error {
 }
 
 // Get retrieves a secret value for the given key.
+//
+// Holds the exclusive lock — not RLock — because loadSecrets→decrypt may
+// run migrateToPassphrase on legacy data, which writes BOTH .secrets.salt
+// and .secrets.enc. Two concurrent first-readers under RLock can interleave
+// those writes and leave a salt/ciphertext pair that no longer decrypts.
+// CLI workloads aren't read-heavy, so serializing reads is the right
+// trade for guaranteed migration correctness.
 func (f *EncryptedFileStore) Get(key string) (string, error) {
-	f.mu.RLock()
-	defer f.mu.RUnlock()
+	f.mu.Lock()
+	defer f.mu.Unlock()
 
 	secrets, err := f.loadSecrets()
 	if err != nil {
@@ -139,6 +179,56 @@ func (f *EncryptedFileStore) Name() string {
 	return "encrypted file"
 }
 
+// detectKeyType returns which key(s) can currently decrypt the on-disk file.
+// It reads the file once and probes each key in order. If the file does not
+// exist, fileStoreKeyNone is returned with no error.
+func (f *EncryptedFileStore) detectKeyType() (fileStoreKeyType, error) {
+	data, err := os.ReadFile(f.path)
+	if err != nil {
+		if os.IsNotExist(err) {
+			return fileStoreKeyNone, nil
+		}
+		return fileStoreKeyNone, err
+	}
+
+	hasPassphrase := false
+	if key, err := f.passphraseKey(false); err == nil {
+		if _, err := decryptWithKey(key, data); err == nil {
+			hasPassphrase = true
+		}
+		zeroBytes(key)
+	}
+
+	hasLegacy := f.canDecryptWithLegacyKeys(data)
+
+	switch {
+	case hasPassphrase && hasLegacy:
+		return fileStoreKeyBoth, nil
+	case hasPassphrase:
+		return fileStoreKeyPassphraseOnly, nil
+	case hasLegacy:
+		return fileStoreKeyLegacyOnly, nil
+	default:
+		return fileStoreKeyNone, nil
+	}
+}
+
+// canDecryptWithLegacyKeys returns true when the ciphertext can be opened by
+// either the migration master key or the legacy machine-derived key.
+func (f *EncryptedFileStore) canDecryptWithLegacyKeys(data []byte) bool {
+	if len(f.migrationKey) > 0 {
+		if _, err := decryptWithKey(f.migrationKey, data); err == nil {
+			return true
+		}
+	}
+	if len(f.legacyKey) > 0 {
+		if _, err := decryptWithKey(f.legacyKey, data); err == nil {
+			return true
+		}
+	}
+	return false
+}
+
 // loadSecrets loads and decrypts the secrets file.
 func (f *EncryptedFileStore) loadSecrets() (map[string]string, error) {
 	data, err := os.ReadFile(f.path)
@@ -158,7 +248,7 @@ func (f *EncryptedFileStore) loadSecrets() (map[string]string, error) {
 	return secrets, nil
 }
 
-// saveSecrets encrypts and saves the secrets file.
+// saveSecrets encrypts and saves the secrets file atomically.
 func (f *EncryptedFileStore) saveSecrets(secrets map[string]string) error {
 	plaintext, err := json.Marshal(secrets)
 	if err != nil {
 		return err
@@ -170,16 +260,42 @@ func (f *EncryptedFileStore) saveSecrets(secrets map[string]string) error {
 		return err
 	}
 
-	// Ensure directory exists
+	// Ensure directory exists.
 	dir := filepath.Dir(f.path)
 	if err := os.MkdirAll(dir, 0700); err != nil {
 		return err
 	}
 
-	// Write with restrictive permissions
-	if err := os.WriteFile(f.path, ciphertext, 0600); err != nil {
+	// Atomic write: write to a sibling temp file, then rename.
+	tmp, err := os.CreateTemp(dir, ".secrets.enc.tmp.*")
+	if err != nil {
+		return err
+	}
+	tmpPath := tmp.Name()
+
+	// Clean up the temp file on any failure path.
+	committed := false
+	defer func() {
+		if !committed {
+			_ = os.Remove(tmpPath)
+		}
+	}()
+
+	if err := tmp.Chmod(0600); err != nil {
+		_ = tmp.Close()
+		return err
+	}
+	if _, err := tmp.Write(ciphertext); err != nil {
+		_ = tmp.Close()
+		return err
+	}
+	if err := tmp.Close(); err != nil {
 		return err
 	}
+	if err := os.Rename(tmpPath, f.path); err != nil {
+		return err
+	}
+	committed = true
 
 	// Remove the plaintext migration key once the store has been rewritten.
 	if f.keyPath != "" {
@@ -190,88 +306,160 @@
 }
 
 // encrypt encrypts plaintext using AES-256-GCM.
+//
+// Passphrase rules enforced here:
+// - If passphrase is set: encrypt with the passphrase-derived key.
+// - If passphrase is unset AND a legacy file exists: refuse writes until the
+//   caller sets NYLAS_FILE_STORE_PASSPHRASE so the migration path in decrypt
+//   can re-encrypt before mutation.
+// - If passphrase is unset AND no legacy file: refuse with a "set the
+//   passphrase" error. Construction itself still succeeds on a fresh install
+//   (see NewEncryptedFileStore); only writes are blocked.
 func (f *EncryptedFileStore) encrypt(plaintext []byte) ([]byte, error) {
+	if len(f.passphrase) == 0 {
+		// Distinguish between "legacy file exists" and "fresh install" for clearer errors.
+		if _, statErr := os.Stat(f.path); statErr == nil {
+			return nil, fmt.Errorf(
+				"encrypted file store requires %s to migrate from the legacy machine-derived key. "+
+					"Set it and re-run, or run `nylas auth config` to switch to the system keyring",
+				fileStorePassphraseEnv,
+			)
+		}
+		return nil, fmt.Errorf(
+			"%s must be set to use the encrypted file secret store. "+
+				"Set it and re-run, or run `nylas auth config` to switch to the system keyring",
+			fileStorePassphraseEnv,
+		)
+	}
+
 	key, err := f.passphraseKey(true)
 	if err != nil {
 		return nil, err
 	}
+	defer zeroBytes(key)
 	return encryptWithKey(key, plaintext)
 }
 
 // decrypt decrypts ciphertext using AES-256-GCM.
+//
+// Decryption order:
+//  1. Passphrase key (if passphrase is set) — normal path.
+//  2. Legacy key (migration master key or machine-derived key):
+//     - If passphrase is NOT set: return plaintext for read-only backward
+//       compatibility. Writes still fail in encrypt until a passphrase is set.
+//     - If passphrase IS set: re-encrypt with the passphrase key (one-shot
+//       migration), print a notice to stderr, and return the plaintext.
+//  3. Neither key works: return an error.
 func (f *EncryptedFileStore) decrypt(data []byte) ([]byte, error) {
-	if key, err := f.passphraseKey(false); err == nil {
-		plaintext, err := decryptWithKey(key, data)
-		if err == nil {
-			return plaintext, nil
+	// 1. Try passphrase key first.
+	if len(f.passphrase) > 0 {
+		if key, err := f.passphraseKey(false); err == nil {
+			plaintext, decErr := decryptWithKey(key, data)
+			zeroBytes(key)
+			if decErr == nil {
+				return plaintext, nil
+			}
+		} else if !os.IsNotExist(err) {
+			return nil, err
 		}
-	} else if !os.IsNotExist(err) && len(f.passphrase) > 0 {
-		return nil, err
+		// Passphrase set but salt missing or passphrase wrong — fall through to legacy.
 	}
 
-	if len(f.migrationKey) > 0 {
-		plaintext, err := decryptWithKey(f.migrationKey, data)
-		if err == nil {
+	// 2. Try legacy keys.
+	if plaintext, ok := f.tryLegacyDecrypt(data); ok {
+		if len(f.passphrase) == 0 {
 			return plaintext, nil
 		}
-	}
-	if len(f.legacyKey) > 0 {
-		plaintext, err := decryptWithKey(f.legacyKey, data)
-		if err == nil {
-			return plaintext, nil
+		// Passphrase is available — perform one-shot migration.
+		if migrateErr := f.migrateToPassphrase(plaintext); migrateErr != nil {
+			// Migration failed (e.g. disk write error). Return the plaintext so
+			// the caller's operation still succeeds; the legacy file remains intact.
+			fmt.Fprintf(os.Stderr, "warning: failed to migrate encrypted file store: %v\n", migrateErr)
+		} else {
+			fmt.Fprintf(os.Stderr,
+				"notice: encrypted file store migrated to passphrase-derived key (%s)\n",
+				fileStorePassphraseEnv)
 		}
+		return plaintext, nil
 	}
 
+	// 3. Nothing worked.
 	if len(f.passphrase) == 0 {
-		return nil, fmt.Errorf("%s must be set to unlock the encrypted file store", fileStorePassphraseEnv)
+		return nil, fmt.Errorf(
+			"%s must be set to unlock the encrypted file store",
+			fileStorePassphraseEnv,
+		)
 	}
-
 	return nil, fmt.Errorf("failed to decrypt encrypted file store with the configured passphrase")
 }
 
-func encryptWithKey(key, plaintext []byte) ([]byte, error) {
-	block, err := aes.NewCipher(key)
-	if err != nil {
-		return nil, err
+// tryLegacyDecrypt attempts decryption with the migration master key first,
+// then the machine-derived legacy key. Returns the plaintext and whether
+// decryption succeeded.
+func (f *EncryptedFileStore) tryLegacyDecrypt(data []byte) (plaintext []byte, ok bool) {
+	if len(f.migrationKey) > 0 {
+		if pt, err := decryptWithKey(f.migrationKey, data); err == nil {
+			return pt, true
+		}
 	}
+	if len(f.legacyKey) > 0 {
+		if pt, err := decryptWithKey(f.legacyKey, data); err == nil {
+			return pt, true
+		}
+	}
+	return nil, false
+}
 
-	gcm, err := cipher.NewGCM(block)
+// migrateToPassphrase re-encrypts plaintext with the passphrase-derived key and
+// atomically writes it to disk. If this fails, the on-disk legacy file is left
+// intact so the next invocation can retry.
+func (f *EncryptedFileStore) migrateToPassphrase(plaintext []byte) error {
+	ciphertext, err := f.encrypt(plaintext)
 	if err != nil {
-		return nil, err
+		return fmt.Errorf("re-encrypt for migration: %w", err)
 	}
 
-	nonce := make([]byte, gcm.NonceSize())
-	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
-		return nil, err
+	dir := filepath.Dir(f.path)
+	if err := os.MkdirAll(dir, 0700); err != nil {
+		return fmt.Errorf("mkdir for migration: %w", err)
 	}
 
-	ciphertext := gcm.Seal(nonce, nonce, plaintext, nil)
-	return []byte(base64.StdEncoding.EncodeToString(ciphertext)), nil
-}
-
-func decryptWithKey(key, data []byte) ([]byte, error) {
-	ciphertext, err := base64.StdEncoding.DecodeString(string(data))
+	tmp, err := os.CreateTemp(dir, ".secrets.enc.tmp.*")
 	if err != nil {
-		return nil, err
+		return fmt.Errorf("create temp file for migration: %w", err)
 	}
+	tmpPath := tmp.Name()
 
-	block, err := aes.NewCipher(key)
-	if err != nil {
-		return nil, err
-	}
+	committed := false
+	defer func() {
+		if !committed {
+			_ = os.Remove(tmpPath)
+		}
+	}()
 
-	gcm, err := cipher.NewGCM(block)
-	if err != nil {
-		return nil, err
+	if err := tmp.Chmod(0600); err != nil {
+		_ = tmp.Close()
+		return fmt.Errorf("chmod temp file for migration: %w", err)
+	}
+	if _, err := tmp.Write(ciphertext); err != nil {
+		_ = tmp.Close()
+		return fmt.Errorf("write temp file for migration: %w", err)
 	}
+	if err := tmp.Close(); err != nil {
+		return fmt.Errorf("close temp file for migration: %w", err)
+	}
+	if err := os.Rename(tmpPath, f.path); err != nil {
+		return fmt.Errorf("rename temp file for migration: %w", err)
+	}
+	committed = true
 
-	if len(ciphertext) < gcm.NonceSize() {
-		return nil, fmt.Errorf("ciphertext too short")
+	// Remove the plaintext migration master key now that re-encryption succeeded.
+	if f.keyPath != "" {
+		_ = os.Remove(f.keyPath)
 	}
 
-	nonce, ciphertext := ciphertext[:gcm.NonceSize()], ciphertext[gcm.NonceSize():]
-	return gcm.Open(nil, nonce, ciphertext, nil)
+	return nil
 }
 
 func readCompatibilityMasterKey(path string) ([]byte, error) {
@@ -342,84 +530,26 @@ func (f *EncryptedFileStore) loadSalt(create bool) ([]byte, error) {
 	return salt, nil
 }
 
+// argon2id parameters. The OWASP 2024 guidance is t=2, m=19MiB, p=1 as
+// the absolute minimum for password storage; modern hosts comfortably
+// support t=3, m=64MiB, p=4 for a CLI use case where derive happens once
+// per process. Increasing t from 1 (the previous setting) to 3 raises
+// offline-cracking cost ~3x for any attacker who exfiltrates the salt and
+// ciphertext.
+const (
+	argon2idTime    uint32 = 3
+	argon2idMemory  uint32 = 64 * 1024 // 64 MiB
+	argon2idThreads uint8  = 4
+	argon2idKeyLen  uint32 = 32
+
+	// minPassphraseLen is the minimum length we accept for
+	// NYLAS_FILE_STORE_PASSPHRASE. Argon2id alone cannot save a 4-character
+	// passphrase from offline brute-force.
+	minPassphraseLen = 12
+)
+
 func derivePassphraseKey(passphrase, salt []byte) []byte {
 	// Argon2id keeps the fallback store bound to user-supplied secret material
 	// instead of host metadata while staying fast enough for CLI use.
-	return argon2.IDKey(passphrase, salt, 1, 64*1024, 4, 32)
-}
-
-// deriveLegacyKey derives the pre-v2 machine-specific fallback key so older
-// encrypted files can still be read and rewritten with a passphrase-derived key.
-func deriveLegacyKey() ([]byte, error) {
-	// Collect machine-specific identifiers
-	var identifiers []byte
-
-	// Add hostname
-	hostname, _ := os.Hostname()
-	identifiers = append(identifiers, []byte(hostname)...)
-
-	// Add user info
-	identifiers = append(identifiers, []byte(os.Getenv("USER"))...)
-	identifiers = append(identifiers, []byte(os.Getenv("USERNAME"))...) // Windows
-
-	// Add home directory
-	home, _ := os.UserHomeDir()
-	identifiers = append(identifiers, []byte(home)...)
-
-	// Add OS-specific machine ID if available
-	machineID := getMachineID()
-	identifiers = append(identifiers, []byte(machineID)...)
-
-	// Add a static salt to prevent rainbow table attacks
-	salt := []byte("nylas-cli-v1-secret-store")
-	identifiers = append(identifiers, salt...)
-
-	// Derive key using SHA-256
-	hash := sha256.Sum256(identifiers)
-	return hash[:], nil
-}
-
-// getMachineID attempts to read a machine-specific ID.
-func getMachineID() string {
-	switch runtime.GOOS {
-	case "linux":
-		// Try /etc/machine-id (systemd)
-		if data, err := os.ReadFile("/etc/machine-id"); err == nil {
-			return string(data)
-		}
-		// Try /var/lib/dbus/machine-id
-		if data, err := os.ReadFile("/var/lib/dbus/machine-id"); err == nil {
-			return string(data)
-		}
-	case "darwin":
-		// Try to get hardware UUID from system profiler
-		if data, err := os.ReadFile("/var/db/SystemKey"); err == nil {
-			return string(data)
-		}
-		// Fallback: use boot time + serial from IOKit
-		if data, err := os.ReadFile("/Library/Preferences/SystemConfiguration/com.apple.smb.server.plist"); err == nil {
-			return string(data)
-		}
-	case "windows":
-		// Try to read MachineGuid from registry path
-		programData := os.Getenv("PROGRAMDATA")
-		if programData != "" {
-			// Construct and clean the path to prevent traversal
-			guidPath := filepath.Join(programData, "Microsoft", "Crypto", "RSA", "MachineKeys", ".GUID")
-			cleanPath := filepath.Clean(guidPath)
-
-			// Validate the path starts with the expected base (security check)
-			if strings.HasPrefix(cleanPath, filepath.Clean(programData)) {
-				if data, err := os.ReadFile(cleanPath); err == nil {
-					return string(data)
-				}
-			}
-		}
-		// Fallback: use system drive serial
-		systemRoot := os.Getenv("SYSTEMROOT")
-		if systemRoot != "" {
-			return systemRoot
-		}
-	}
-	return ""
+	return argon2.IDKey(passphrase, salt, argon2idTime, argon2idMemory, argon2idThreads, argon2idKeyLen)
 }
diff --git a/internal/adapters/keyring/file_crypto.go b/internal/adapters/keyring/file_crypto.go
new file mode 100644
index 0000000..ba7d828
--- /dev/null
+++ b/internal/adapters/keyring/file_crypto.go
@@ -0,0 +1,71 @@
+package keyring
+
+import (
+	"crypto/aes"
+	"crypto/cipher"
+	"crypto/rand"
+	"encoding/base64"
+	"fmt"
+	"io"
+)
+
+// zeroBytes overwrites b with zeros so a key derived from a user
+// passphrase doesn't linger in heap memory after use. Go's GC retains
+// allocations until they're collected; for a long-running `nylas air` /
+// `nylas chat` process the derived AES key would otherwise survive in
+// RAM for the lifetime of the process.
+func zeroBytes(b []byte) {
+	for i := range b {
+		b[i] = 0
+	}
+}
+
+// encryptWithKey encrypts plaintext with AES-256-GCM using the given key.
+// The returned bytes are base64-encoded and include the nonce prepended to
+// the ciphertext.
+func encryptWithKey(key, plaintext []byte) ([]byte, error) {
+	block, err := aes.NewCipher(key)
+	if err != nil {
+		return nil, err
+	}
+
+	gcm, err := cipher.NewGCM(block)
+	if err != nil {
+		return nil, err
+	}
+
+	nonce := make([]byte, gcm.NonceSize())
+	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
+		return nil, err
+	}
+
+	ciphertext := gcm.Seal(nonce, nonce, plaintext, nil)
+	return []byte(base64.StdEncoding.EncodeToString(ciphertext)), nil
+}
+
+// decryptWithKey decrypts base64-encoded AES-256-GCM ciphertext using the
+// given key. The nonce is read from the first gcm.NonceSize() bytes of the
+// decoded ciphertext.
+func decryptWithKey(key, data []byte) ([]byte, error) {
+	ciphertext, err := base64.StdEncoding.DecodeString(string(data))
+	if err != nil {
+		return nil, err
+	}
+
+	block, err := aes.NewCipher(key)
+	if err != nil {
+		return nil, err
+	}
+
+	gcm, err := cipher.NewGCM(block)
+	if err != nil {
+		return nil, err
+	}
+
+	if len(ciphertext) < gcm.NonceSize() {
+		return nil, fmt.Errorf("ciphertext too short")
+	}
+
+	nonce, ciphertext := ciphertext[:gcm.NonceSize()], ciphertext[gcm.NonceSize():]
+	return gcm.Open(nil, nonce, ciphertext, nil)
+}
diff --git a/internal/adapters/keyring/file_legacy.go b/internal/adapters/keyring/file_legacy.go
new file mode 100644
index 0000000..3188440
--- /dev/null
+++ b/internal/adapters/keyring/file_legacy.go
@@ -0,0 +1,83 @@
+package keyring
+
+import (
+	"crypto/sha256"
+	"os"
+	"path/filepath"
+	"runtime"
+	"strings"
+)
+
+// deriveLegacyKey derives the pre-v2 machine-specific fallback key so older
+// encrypted files can still be read and then re-encrypted with a
+// passphrase-derived key (one-shot migration).
+//
+// The key is a SHA-256 hash of concatenated host metadata. It is
+// intentionally low-entropy compared to a user-supplied passphrase and exists
+// only to allow migration from legacy installations.
+func deriveLegacyKey() ([]byte, error) {
+	var identifiers []byte
+
+	hostname, _ := os.Hostname()
+	identifiers = append(identifiers, []byte(hostname)...)
+
+	identifiers = append(identifiers, []byte(os.Getenv("USER"))...)
+	identifiers = append(identifiers, []byte(os.Getenv("USERNAME"))...) // Windows
+
+	home, _ := os.UserHomeDir()
+	identifiers = append(identifiers, []byte(home)...)
+
+	identifiers = append(identifiers, []byte(getMachineID())...)
+
+	// Static salt to prevent rainbow table attacks against this specific construction.
+	identifiers = append(identifiers, []byte("nylas-cli-v1-secret-store")...)
+
+	hash := sha256.Sum256(identifiers)
+	return hash[:], nil
+}
+
+// getMachineID attempts to read a platform-specific machine identifier.
+// Returns an empty string when no identifier is available; callers handle this
+// gracefully by concatenating an empty contribution.
+//
+// On macOS, both candidate paths typically require elevated privileges and
+// won't exist on a stock install — most macOS users will fall through with
+// an empty machine ID. That's intentional: this helper feeds the legacy
+// machine-derived migration key, which now exists only to decrypt files
+// written by older versions and re-encrypt them under the user-supplied
+// passphrase. Empty contribution means the legacy key is weaker, but the
+// migration path requires NYLAS_FILE_STORE_PASSPHRASE to be set anyway.
+func getMachineID() string {
+	switch runtime.GOOS {
+	case "linux":
+		if data, err := os.ReadFile("/etc/machine-id"); err == nil {
+			return string(data)
+		}
+		if data, err := os.ReadFile("/var/lib/dbus/machine-id"); err == nil {
+			return string(data)
+		}
+	case "darwin":
+		// Both files are typically root-owned on modern macOS.
+		if data, err := os.ReadFile("/var/db/SystemKey"); err == nil {
+			return string(data)
+		}
+		if data, err := os.ReadFile("/Library/Preferences/SystemConfiguration/com.apple.smb.server.plist"); err == nil {
+			return string(data)
+		}
+	case "windows":
+		programData := os.Getenv("PROGRAMDATA")
+		if programData != "" {
+			guidPath := filepath.Join(programData, "Microsoft", "Crypto", "RSA", "MachineKeys", ".GUID")
+			cleanPath := filepath.Clean(guidPath)
+			if strings.HasPrefix(cleanPath, filepath.Clean(programData)) {
+				if data, err := os.ReadFile(cleanPath); err == nil {
+					return string(data)
+				}
+			}
+		}
+		if systemRoot := os.Getenv("SYSTEMROOT"); systemRoot != "" {
+			return systemRoot
+		}
+	}
+	return ""
+}
diff --git a/internal/adapters/keyring/file_migration_concurrent_test.go b/internal/adapters/keyring/file_migration_concurrent_test.go
new file mode 100644
index 0000000..55a0be9
--- /dev/null
+++ b/internal/adapters/keyring/file_migration_concurrent_test.go
@@ -0,0 +1,131 @@
+package keyring
+
+import (
+	"encoding/base64"
+	"os"
+	"path/filepath"
+	"strings"
+	"sync"
+	"testing"
+)
+
+// TestEncryptedFileStore_ConcurrentFirstReadMigration covers the race
+// flagged in code review: when several callers hit Get on a legacy-only
+// .secrets.enc, the first read triggers migration which writes BOTH
+// .secrets.salt and .secrets.enc. Under the old RLock-based Get, two
+// readers could interleave those writes and leave a salt that didn't
+// match the on-disk ciphertext — silently bricking the store.
+//
+// The test launches N concurrent Get calls, then re-opens the store
+// from disk and verifies the migrated salt+ciphertext are self-
+// consistent and round-trip the original plaintext.
+func TestEncryptedFileStore_ConcurrentFirstReadMigration(t *testing.T) {
+	tmpDir := t.TempDir()
+	passphrase := setFileStorePassphrase(t)
+
+	// Seed a legacy-encrypted secrets file. Use the machine-derived
+	// legacy key — the same path real installations would have come
+	// from before passphrase support landed.
+	legacyKey, err := deriveLegacyKey()
+	if err != nil {
+		t.Fatalf("deriveLegacyKey failed: %v", err)
+	}
+	const legacyValue = "legacy-value"
+	legacyJSON := []byte(`{"api_key":"` + legacyValue + `"}`)
+	legacyCiphertext, err := encryptWithKey(legacyKey, legacyJSON)
+	if err != nil {
+		t.Fatalf("encryptWithKey failed: %v", err)
+	}
+	secretsPath := filepath.Join(tmpDir, ".secrets.enc")
+	if err := os.WriteFile(secretsPath, legacyCiphertext, 0600); err != nil {
+		t.Fatalf("WriteFile failed: %v", err)
+	}
+
+	store, err := NewEncryptedFileStore(tmpDir)
+	if err != nil {
+		t.Fatalf("NewEncryptedFileStore failed: %v", err)
+	}
+
+	// Fan out — enough goroutines to make any race observable, gated
+	// on a single barrier so they all hit Get in roughly the same
+	// instant.
+	const concurrency = 32
+	var (
+		wg      sync.WaitGroup
+		barrier = make(chan struct{})
+		mu      sync.Mutex
+		results = make([]string, 0, concurrency)
+		errs    = make([]error, 0)
+	)
+	wg.Add(concurrency)
+	for range concurrency {
+		go func() {
+			defer wg.Done()
+			<-barrier
+			v, gerr := store.Get("api_key")
+			mu.Lock()
+			defer mu.Unlock()
+			if gerr != nil {
+				errs = append(errs, gerr)
+				return
+			}
+			results = append(results, v)
+		}()
+	}
+	close(barrier)
+	wg.Wait()
+
+	if len(errs) > 0 {
+		t.Fatalf("concurrent Get returned %d errors; first: %v", len(errs), errs[0])
+	}
+	if len(results) != concurrency {
+		t.Fatalf("got %d results, want %d", len(results), concurrency)
+	}
+	for i, v := range results {
+		if v != legacyValue {
+			t.Fatalf("result[%d] = %q, want %q", i, v, legacyValue)
+		}
+	}
+
+	// On-disk consistency check: salt and ciphertext must match.
+	saltRaw, err := os.ReadFile(filepath.Join(tmpDir, ".secrets.salt"))
+	if err != nil {
+		t.Fatalf("read .secrets.salt: %v", err)
+	}
+	salt, err := base64.StdEncoding.DecodeString(strings.TrimSpace(string(saltRaw)))
+	if err != nil {
+		t.Fatalf("decode salt: %v", err)
+	}
+	ciphertext, err := os.ReadFile(secretsPath)
+	if err != nil {
+		t.Fatalf("read .secrets.enc: %v", err)
+	}
+	plaintext, err := decryptWithKey(derivePassphraseKey([]byte(passphrase), salt), ciphertext)
+	if err != nil {
+		t.Fatalf("salt/ciphertext mismatch — store would be unrecoverable: %v", err)
+	}
+	if !strings.Contains(string(plaintext), legacyValue) {
+		t.Fatalf("decrypted plaintext %q missing %q", string(plaintext), legacyValue)
+	}
+
+	// And the migrated ciphertext must NOT decrypt with the legacy key
+	// any more — proves the migration actually flipped the encryption.
+	if _, err := decryptWithKey(legacyKey, ciphertext); err == nil {
+		t.Fatal("migrated ciphertext still decrypts with the legacy key")
+	}
+
+	// Fresh store from the same dir should also Get the value, proving
+	// the persisted state is openable from a cold start (not just from
+	// the in-process store that performed the migration).
+	fresh, err := NewEncryptedFileStore(tmpDir)
+	if err != nil {
+		t.Fatalf("re-open NewEncryptedFileStore: %v", err)
+	}
+	v, err := fresh.Get("api_key")
+	if err != nil {
+		t.Fatalf("fresh.Get after migration: %v", err)
+	}
+	if v != legacyValue {
+		t.Fatalf("fresh.Get = %q, want %q", v, legacyValue)
+	}
+}
diff --git a/internal/adapters/keyring/grants.go b/internal/adapters/keyring/grants.go
deleted file mode 100644
index c0c4f90..0000000
--- a/internal/adapters/keyring/grants.go
+++ /dev/null
@@ -1,212 +0,0 @@
-package keyring
-
-import (
-	"encoding/json"
-
-	"github.com/nylas/cli/internal/domain"
-	"github.com/nylas/cli/internal/ports"
-)
-
-const (
-	grantsKey       = "grants"
-	defaultGrantKey = "default_grant"
-)
-
-// GrantStore implements ports.GrantStore using a SecretStore backend.
-type GrantStore struct {
-	secrets ports.SecretStore
-}
-
-// NewGrantStore creates a new GrantStore.
-func NewGrantStore(secrets ports.SecretStore) *GrantStore {
-	return &GrantStore{secrets: secrets}
-}
-
-// SaveGrant saves grant info to storage.
-func (g *GrantStore) SaveGrant(info domain.GrantInfo) error {
-	if info.ID == "" || info.Email == "" {
-		return domain.ErrInvalidInput
-	}
-
-	grants, err := g.ListGrants()
-	if err != nil && err != domain.ErrSecretNotFound {
-		return err
-	}
-
-	// Check if grant already exists and update it
-	found := false
-	for i, grant := range grants {
-		if grant.ID == info.ID {
-			grants[i] = info
-			found = true
-			break
-		}
-	}
-	if !found {
-		grants = append(grants, info)
-	}
-
-	return g.saveGrants(grants)
-}
-
-// GetGrant retrieves grant info by ID.
-func (g *GrantStore) GetGrant(grantID string) (*domain.GrantInfo, error) {
-	grants, err := g.ListGrants()
-	if err != nil {
-		return nil, err
-	}
-
-	for _, grant := range grants {
-		if grant.ID == grantID {
-			return &grant, nil
-		}
-	}
-	return nil, domain.ErrGrantNotFound
-}
-
-// GetGrantByEmail retrieves grant info by email.
-func (g *GrantStore) GetGrantByEmail(email string) (*domain.GrantInfo, error) {
-	grants, err := g.ListGrants()
-	if err != nil {
-		return nil, err
-	}
-
-	for _, grant := range grants {
-		if grant.Email == email {
-			return &grant, nil
-		}
-	}
-	return nil, domain.ErrGrantNotFound
-}
-
-// ListGrants returns all stored grants.
-func (g *GrantStore) ListGrants() ([]domain.GrantInfo, error) {
-	data, err := g.secrets.Get(grantsKey)
-	if err != nil {
-		if err == domain.ErrSecretNotFound {
-			if err := g.repairDefaultGrant(nil); err != nil {
-				return nil, err
-			}
-			return []domain.GrantInfo{}, nil
-		}
-		return nil, err
-	}
-
-	var grants []domain.GrantInfo
-	if err := json.Unmarshal([]byte(data), &grants); err != nil {
-		return nil, err
-	}
-
-	sanitized, changed := sanitizeGrants(grants)
-	if changed {
-		if err := g.saveGrants(sanitized); err != nil {
-			return nil, err
-		}
-	}
-	if err := g.repairDefaultGrant(sanitized); err != nil {
-		return nil, err
-	}
-
-	return sanitized, nil
-}
-
-// DeleteGrant removes a grant from storage.
-func (g *GrantStore) DeleteGrant(grantID string) error {
-	grants, err := g.ListGrants()
-	if err != nil {
-		return err
-	}
-
-	newGrants := make([]domain.GrantInfo, 0, len(grants))
-	for _, grant := range grants {
-		if grant.ID != grantID {
-			newGrants = append(newGrants, grant)
-		}
-	}
-
-	// Also delete the grant's token if it exists
-	_ = g.secrets.Delete(ports.GrantTokenKey(grantID))
-
-	if err := g.saveGrants(newGrants); err != nil {
-		return err
-	}
-
-	return g.repairDefaultGrant(newGrants)
-}
-
-// SetDefaultGrant sets the default grant ID.
-func (g *GrantStore) SetDefaultGrant(grantID string) error {
-	return g.secrets.Set(defaultGrantKey, grantID)
-}
-
-// GetDefaultGrant returns the default grant ID.
-func (g *GrantStore) GetDefaultGrant() (string, error) {
-	_, err := g.ListGrants()
-	if err != nil {
-		return "", err
-	}
-
-	grantID, err := g.secrets.Get(defaultGrantKey)
-	if err != nil {
-		if err == domain.ErrSecretNotFound {
-			return "", domain.ErrNoDefaultGrant
-		}
-		return "", err
-	}
-
-	return grantID, nil
-}
-
-// ClearGrants removes all grants from storage.
-func (g *GrantStore) ClearGrants() error {
-	_ = g.secrets.Delete(grantsKey)
-	_ = g.secrets.Delete(defaultGrantKey)
-	return nil
-}
-
-func (g *GrantStore) saveGrants(grants []domain.GrantInfo) error {
-	data, err := json.Marshal(grants)
-	if err != nil {
-		return err
-	}
-	return g.secrets.Set(grantsKey, string(data))
-}
-
-func sanitizeGrants(grants []domain.GrantInfo) ([]domain.GrantInfo, bool) {
-	sanitized := make([]domain.GrantInfo, 0, len(grants))
-	changed := false
-	for _, grant := range grants {
-		if grant.ID == "" || grant.Email == "" {
-			changed = true
-			continue
-		}
-		sanitized = append(sanitized, grant)
-	}
-	return sanitized, changed
-}
-
-func (g *GrantStore) repairDefaultGrant(grants []domain.GrantInfo) error {
-	defaultID, err := g.secrets.Get(defaultGrantKey)
-	if err != nil {
-		if err != domain.ErrSecretNotFound {
-			return err
-		}
-		if len(grants) == 0 {
-			return nil
-		}
-		return g.secrets.Set(defaultGrantKey, grants[0].ID)
-	}
-
-	for _, grant := range grants {
-		if grant.ID == defaultID {
-			return nil
-		}
-	}
-
-	if len(grants) == 0 {
-		_ = g.secrets.Delete(defaultGrantKey)
-		return nil
-	}
-
-	return g.secrets.Set(defaultGrantKey, grants[0].ID)
-}
diff --git a/internal/adapters/keyring/keyring.go b/internal/adapters/keyring/keyring.go
index 003b8c8..a3541b5 100644
--- a/internal/adapters/keyring/keyring.go
+++ b/internal/adapters/keyring/keyring.go
@@ -3,6 +3,7 @@
 
 import (
 	"errors"
+	"fmt"
 	"os"
 
 	"github.com/nylas/cli/internal/domain"
@@ -28,7 +29,7 @@ func (k *SystemKeyring) Set(key, value string) error {
 
 // Get retrieves a secret value for the given key.
 func (k *SystemKeyring) Get(key string) (string, error) {
 	value, err := keyring.Get(serviceName, key)
-	if err == keyring.ErrNotFound {
+	if errors.Is(err, keyring.ErrNotFound) {
 		return "", domain.ErrSecretNotFound
 	}
 	return value, err
@@ -97,23 +98,34 @@ func NewSecretStore(configDir string) (ports.SecretStore, error) {
 		return nil, err
 	}
 
-	// Migrate credentials from file store to keyring
-	if apiKey != "" {
-		_ = kr.Set(ports.KeyAPIKey, apiKey)
+	// Migrate credentials from file store to keyring. Keep going on per-key
+	// failures so a single broken entry doesn't block the rest of the move,
+	// but surface the failures so the user knows something didn't migrate.
+	var migrationErrs []error
+	migrate := func(key, value string) {
+		if value == "" {
+			return
+		}
+		if err := kr.Set(key, value); err != nil {
+			migrationErrs = append(migrationErrs, fmt.Errorf("migrate %s: %w", key, err))
+		}
 	}
-	if clientID, err := fileStore.Get(ports.KeyClientID); err == nil && clientID != "" {
-		_ = kr.Set(ports.KeyClientID, clientID)
+
+	migrate(ports.KeyAPIKey, apiKey)
+	if clientID, err := fileStore.Get(ports.KeyClientID); err == nil {
+		migrate(ports.KeyClientID, clientID)
 	}
-	if clientSecret, err := fileStore.Get(ports.KeyClientSecret); err == nil && clientSecret != "" {
-		_ = kr.Set(ports.KeyClientSecret, clientSecret)
+	if clientSecret, err := fileStore.Get(ports.KeyClientSecret); err == nil {
+		migrate(ports.KeyClientSecret, clientSecret)
 	}
-	// Migrate grants data
-	if grants, err := fileStore.Get("grants"); err == nil && grants != "" {
-		_ = kr.Set("grants", grants)
-	}
-	if defaultGrant, err := fileStore.Get("default_grant"); err == nil && defaultGrant != "" {
-		_ = kr.Set("default_grant", defaultGrant)
+
+	if len(migrationErrs) > 0 {
+		// Print to stderr but do not fail — the keyring is usable even with
+		// partial migration; users may need to re-run `nylas auth config`.
+		fmt.Fprintf(os.Stderr, "warning: %d secrets failed to migrate from file store to keyring; re-run `nylas auth config` to retry\n", len(migrationErrs))
+		for _, e := range migrationErrs {
+			fmt.Fprintf(os.Stderr, "  - %v\n", e)
+		}
 	}
 
 	return kr, nil
diff --git a/internal/adapters/keyring/keyring_test.go b/internal/adapters/keyring/keyring_test.go
index f9fda6a..3045a74 100644
--- a/internal/adapters/keyring/keyring_test.go
+++ b/internal/adapters/keyring/keyring_test.go
@@ -248,238 +248,3 @@ func TestCrossPlatformKeyDerivation(t *testing.T) {
 
 	t.Logf("Cross-platform key derivation works on: %s", runtime.GOOS)
 }
-
-func TestGrantStore(t *testing.T) {
-	secrets := keyring.NewMockSecretStore()
-	store := keyring.NewGrantStore(secrets)
-
-	t.Run("save and get grant", func(t *testing.T) {
-		info := domain.GrantInfo{
-			ID:       "test-grant-1",
-			Email:    "test@example.com",
-			Provider: domain.ProviderGoogle,
-		}
-
-		err := store.SaveGrant(info)
-		require.NoError(t, err)
-
-		retrieved, err := store.GetGrant("test-grant-1")
-		require.NoError(t, err)
-		assert.Equal(t, info.ID, retrieved.ID)
-		assert.Equal(t, info.Email, retrieved.Email)
-	})
-
-	t.Run("get grant by email", func(t *testing.T) {
-		retrieved, err := store.GetGrantByEmail("test@example.com")
-		require.NoError(t, err)
-		assert.Equal(t, "test-grant-1", retrieved.ID)
-	})
-
-	t.Run("list grants", func(t *testing.T) {
-		grants, err := store.ListGrants()
-		require.NoError(t, err)
-		assert.Len(t, grants, 1)
-	})
-
-	t.Run("set and get default grant", func(t *testing.T) {
-		err := store.SetDefaultGrant("test-grant-1")
-		require.NoError(t, err)
-
-		defaultID, err := store.GetDefaultGrant()
-		require.NoError(t, err)
-		assert.Equal(t, "test-grant-1", defaultID)
-	})
-
-	t.Run("delete grant", func(t *testing.T) {
-		err := store.DeleteGrant("test-grant-1")
-		require.NoError(t, err)
-
-		_, err = store.GetGrant("test-grant-1")
-		assert.ErrorIs(t, err, domain.ErrGrantNotFound)
-	})
-}
-
-func TestGrantStore_MultipleGrants(t *testing.T)
{ - secrets := keyring.NewMockSecretStore() - store := keyring.NewGrantStore(secrets) - - // Create test grants - grant1 := domain.GrantInfo{ - ID: "grant-1", - Email: "user1@gmail.com", - Provider: domain.ProviderGoogle, - } - grant2 := domain.GrantInfo{ - ID: "grant-2", - Email: "user2@outlook.com", - Provider: domain.ProviderMicrosoft, - } - grant3 := domain.GrantInfo{ - ID: "grant-3", - Email: "user3@gmail.com", - Provider: domain.ProviderGoogle, - } - - t.Run("save multiple grants", func(t *testing.T) { - require.NoError(t, store.SaveGrant(grant1)) - require.NoError(t, store.SaveGrant(grant2)) - require.NoError(t, store.SaveGrant(grant3)) - - grants, err := store.ListGrants() - require.NoError(t, err) - assert.Len(t, grants, 3) - }) - - t.Run("get grant by email from multiple", func(t *testing.T) { - retrieved, err := store.GetGrantByEmail("user2@outlook.com") - require.NoError(t, err) - assert.Equal(t, "grant-2", retrieved.ID) - assert.Equal(t, domain.ProviderMicrosoft, retrieved.Provider) - }) - - t.Run("set and switch default grant", func(t *testing.T) { - // Set grant1 as default - require.NoError(t, store.SetDefaultGrant("grant-1")) - defaultID, err := store.GetDefaultGrant() - require.NoError(t, err) - assert.Equal(t, "grant-1", defaultID) - - // Switch to grant2 - require.NoError(t, store.SetDefaultGrant("grant-2")) - defaultID, err = store.GetDefaultGrant() - require.NoError(t, err) - assert.Equal(t, "grant-2", defaultID) - - // Switch to grant3 - require.NoError(t, store.SetDefaultGrant("grant-3")) - defaultID, err = store.GetDefaultGrant() - require.NoError(t, err) - assert.Equal(t, "grant-3", defaultID) - }) - - t.Run("update existing grant preserves order", func(t *testing.T) { - updatedGrant2 := domain.GrantInfo{ - ID: "grant-2", - Email: "user2-updated@outlook.com", - Provider: domain.ProviderMicrosoft, - } - require.NoError(t, store.SaveGrant(updatedGrant2)) - - retrieved, err := store.GetGrant("grant-2") - require.NoError(t, err) - 
assert.Equal(t, "user2-updated@outlook.com", retrieved.Email) - - // Should still have 3 grants - grants, err := store.ListGrants() - require.NoError(t, err) - assert.Len(t, grants, 3) - }) - - t.Run("delete grant preserves others", func(t *testing.T) { - require.NoError(t, store.DeleteGrant("grant-2")) - - grants, err := store.ListGrants() - require.NoError(t, err) - assert.Len(t, grants, 2) - - // Remaining grants should be grant1 and grant3 - _, err = store.GetGrant("grant-1") - require.NoError(t, err) - _, err = store.GetGrant("grant-3") - require.NoError(t, err) - _, err = store.GetGrant("grant-2") - assert.ErrorIs(t, err, domain.ErrGrantNotFound) - }) - - t.Run("clear grants removes all", func(t *testing.T) { - require.NoError(t, store.ClearGrants()) - - grants, err := store.ListGrants() - require.NoError(t, err) - assert.Len(t, grants, 0) - - _, err = store.GetDefaultGrant() - assert.ErrorIs(t, err, domain.ErrNoDefaultGrant) - }) - - t.Run("no default grant error", func(t *testing.T) { - _, err := store.GetDefaultGrant() - assert.ErrorIs(t, err, domain.ErrNoDefaultGrant) - }) -} - -// TestGrantStore_DefaultGrantBehavior tests the behavior of default grants -// when grants are deleted or local grant storage needs repair. 
-func TestGrantStore_DefaultGrantBehavior(t *testing.T) { - secrets := keyring.NewMockSecretStore() - store := keyring.NewGrantStore(secrets) - - grant1 := domain.GrantInfo{ - ID: "grant-1", - Email: "user1@gmail.com", - Provider: domain.ProviderGoogle, - } - - grant2 := domain.GrantInfo{ - ID: "grant-2", - Email: "user2@gmail.com", - Provider: domain.ProviderGoogle, - } - - t.Run("delete default grant switches to another stored grant", func(t *testing.T) { - require.NoError(t, store.SaveGrant(grant1)) - require.NoError(t, store.SaveGrant(grant2)) - require.NoError(t, store.SetDefaultGrant(grant1.ID)) - - require.NoError(t, store.DeleteGrant(grant1.ID)) - - _, err := store.GetGrant(grant1.ID) - assert.ErrorIs(t, err, domain.ErrGrantNotFound) - - defaultID, err := store.GetDefaultGrant() - require.NoError(t, err) - assert.Equal(t, "grant-2", defaultID) - }) - - t.Run("delete last remaining default grant clears default", func(t *testing.T) { - require.NoError(t, store.ClearGrants()) - - require.NoError(t, store.SaveGrant(grant1)) - require.NoError(t, store.SetDefaultGrant(grant1.ID)) - - require.NoError(t, store.DeleteGrant(grant1.ID)) - - _, err := store.GetGrant(grant1.ID) - assert.ErrorIs(t, err, domain.ErrGrantNotFound) - - _, err = store.GetDefaultGrant() - assert.ErrorIs(t, err, domain.ErrNoDefaultGrant) - }) - - t.Run("clear grants clears default", func(t *testing.T) { - // Re-add a grant - require.NoError(t, store.SaveGrant(grant1)) - require.NoError(t, store.SetDefaultGrant(grant1.ID)) - - // Clear all grants - require.NoError(t, store.ClearGrants()) - - // Now default should be cleared - _, err := store.GetDefaultGrant() - assert.ErrorIs(t, err, domain.ErrNoDefaultGrant) - }) - - t.Run("list grants repairs malformed local entries and stale default", func(t *testing.T) { - require.NoError(t, secrets.Set("grants", `[{"id":"","email":"","provider":"nylas"},{"id":"grant-2","email":"user2@gmail.com","provider":"google"}]`)) - require.NoError(t, 
store.SetDefaultGrant("missing-grant")) - - grants, err := store.ListGrants() - require.NoError(t, err) - require.Len(t, grants, 1) - assert.Equal(t, "grant-2", grants[0].ID) - - defaultID, err := store.GetDefaultGrant() - require.NoError(t, err) - assert.Equal(t, "grant-2", defaultID) - }) -} diff --git a/internal/adapters/mcp/proxy_sse_test.go b/internal/adapters/mcp/proxy_sse_test.go index b64ade2..ddc3a0c 100644 --- a/internal/adapters/mcp/proxy_sse_test.go +++ b/internal/adapters/mcp/proxy_sse_test.go @@ -132,6 +132,11 @@ func (m *mockGrantStore) SaveGrant(info domain.GrantInfo) error { return nil } +func (m *mockGrantStore) ReplaceGrants(grants []domain.GrantInfo) error { + m.grants = append([]domain.GrantInfo(nil), grants...) + return nil +} + func (m *mockGrantStore) GetGrant(grantID string) (*domain.GrantInfo, error) { for _, g := range m.grants { if g.ID == grantID { diff --git a/internal/adapters/nylas/admin.go b/internal/adapters/nylas/admin.go index c13bd46..bc5de13 100644 --- a/internal/adapters/nylas/admin.go +++ b/internal/adapters/nylas/admin.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "net/url" "github.com/nylas/cli/internal/domain" ) @@ -32,8 +33,11 @@ func (c *HTTPClient) ListApplications(ctx context.Context) ([]domain.Application return nil, c.parseError(resp) } - // Read body once (special handling: API may return array or single object) - body, err := io.ReadAll(resp.Body) + // Read body once (special handling: API may return array or single object). + // Bound the read so a misbehaving upstream cannot OOM us, and so that an + // auth/error JSON containing tokens or PII cannot be echoed unbounded into + // our error string below. 
+ body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodySize)) if err != nil { return nil, err } @@ -57,17 +61,22 @@ func (c *HTTPClient) ListApplications(ctx context.Context) ([]domain.Application } } - // If both fail, return error with response body for debugging - return nil, fmt.Errorf("failed to decode applications response: %s", string(body)) + // Both decodings failed. Report status only — never echo the raw body + // since it may carry tokens, customer data, or other sensitive fields. + return nil, fmt.Errorf("failed to decode applications response (status %d, %d bytes)", resp.StatusCode, len(body)) } +// maxResponseBodySize bounds bodies read for ad-hoc decoding (1 MiB). Larger +// bodies are an upstream bug; truncating prevents secret leakage in errors. +const maxResponseBodySize = 1 << 20 + // GetApplication retrieves a specific application. func (c *HTTPClient) GetApplication(ctx context.Context, appID string) (*domain.Application, error) { if err := validateRequired("application ID", appID); err != nil { return nil, err } - queryURL := fmt.Sprintf("%s/v3/applications/%s", c.baseURL, appID) + queryURL := fmt.Sprintf("%s/v3/applications/%s", c.baseURL, url.PathEscape(appID)) var result struct { Data domain.Application `json:"data"` @@ -102,7 +111,7 @@ func (c *HTTPClient) UpdateApplication(ctx context.Context, appID string, req *d return nil, err } - queryURL := fmt.Sprintf("%s/v3/applications/%s", c.baseURL, appID) + queryURL := fmt.Sprintf("%s/v3/applications/%s", c.baseURL, url.PathEscape(appID)) resp, err := c.doJSONRequest(ctx, "PATCH", queryURL, req) if err != nil { @@ -123,7 +132,7 @@ func (c *HTTPClient) DeleteApplication(ctx context.Context, appID string) error if err := validateRequired("application ID", appID); err != nil { return err } - queryURL := fmt.Sprintf("%s/v3/applications/%s", c.baseURL, appID) + queryURL := fmt.Sprintf("%s/v3/applications/%s", c.baseURL, url.PathEscape(appID)) return c.doDelete(ctx, queryURL) } @@ -148,7 
+157,7 @@ func (c *HTTPClient) GetCallbackURI(ctx context.Context, uriID string) (*domain. return nil, err } - queryURL := fmt.Sprintf("%s/v3/applications/callback-uris/%s", c.baseURL, uriID) + queryURL := fmt.Sprintf("%s/v3/applications/callback-uris/%s", c.baseURL, url.PathEscape(uriID)) var result struct { Data domain.CallbackURI `json:"data"` @@ -190,7 +199,7 @@ func (c *HTTPClient) UpdateCallbackURI(ctx context.Context, uriID string, req *d return nil, err } - queryURL := fmt.Sprintf("%s/v3/applications/callback-uris/%s", c.baseURL, uriID) + queryURL := fmt.Sprintf("%s/v3/applications/callback-uris/%s", c.baseURL, url.PathEscape(uriID)) resp, err := c.doJSONRequest(ctx, "PATCH", queryURL, req) if err != nil { @@ -211,7 +220,7 @@ func (c *HTTPClient) DeleteCallbackURI(ctx context.Context, uriID string) error if err := validateRequired("callback URI ID", uriID); err != nil { return err } - queryURL := fmt.Sprintf("%s/v3/applications/callback-uris/%s", c.baseURL, uriID) + queryURL := fmt.Sprintf("%s/v3/applications/callback-uris/%s", c.baseURL, url.PathEscape(uriID)) return c.doDelete(ctx, queryURL) } @@ -236,7 +245,7 @@ func (c *HTTPClient) GetConnector(ctx context.Context, connectorID string) (*dom return nil, err } - queryURL := fmt.Sprintf("%s/v3/connectors/%s", c.baseURL, connectorID) + queryURL := fmt.Sprintf("%s/v3/connectors/%s", c.baseURL, url.PathEscape(connectorID)) var result struct { Data domain.Connector `json:"data"` @@ -271,7 +280,7 @@ func (c *HTTPClient) UpdateConnector(ctx context.Context, connectorID string, re return nil, err } - queryURL := fmt.Sprintf("%s/v3/connectors/%s", c.baseURL, connectorID) + queryURL := fmt.Sprintf("%s/v3/connectors/%s", c.baseURL, url.PathEscape(connectorID)) resp, err := c.doJSONRequest(ctx, "PATCH", queryURL, req) if err != nil { @@ -292,7 +301,7 @@ func (c *HTTPClient) DeleteConnector(ctx context.Context, connectorID string) er if err := validateRequired("connector ID", connectorID); err != nil { return err } - 
queryURL := fmt.Sprintf("%s/v3/connectors/%s", c.baseURL, connectorID) + queryURL := fmt.Sprintf("%s/v3/connectors/%s", c.baseURL, url.PathEscape(connectorID)) return c.doDelete(ctx, queryURL) } @@ -304,7 +313,7 @@ func (c *HTTPClient) ListCredentials(ctx context.Context, connectorID string) ([ return nil, err } - queryURL := fmt.Sprintf("%s/v3/connectors/%s/credentials", c.baseURL, connectorID) + queryURL := fmt.Sprintf("%s/v3/connectors/%s/credentials", c.baseURL, url.PathEscape(connectorID)) var result struct { Data []domain.ConnectorCredential `json:"data"` @@ -321,7 +330,7 @@ func (c *HTTPClient) GetCredential(ctx context.Context, credentialID string) (*d return nil, err } - queryURL := fmt.Sprintf("%s/v3/credentials/%s", c.baseURL, credentialID) + queryURL := fmt.Sprintf("%s/v3/credentials/%s", c.baseURL, url.PathEscape(credentialID)) var result struct { Data domain.ConnectorCredential `json:"data"` @@ -338,7 +347,7 @@ func (c *HTTPClient) CreateCredential(ctx context.Context, connectorID string, r return nil, err } - queryURL := fmt.Sprintf("%s/v3/connectors/%s/credentials", c.baseURL, connectorID) + queryURL := fmt.Sprintf("%s/v3/connectors/%s/credentials", c.baseURL, url.PathEscape(connectorID)) resp, err := c.doJSONRequest(ctx, "POST", queryURL, req) if err != nil { @@ -360,7 +369,7 @@ func (c *HTTPClient) UpdateCredential(ctx context.Context, credentialID string, return nil, err } - queryURL := fmt.Sprintf("%s/v3/credentials/%s", c.baseURL, credentialID) + queryURL := fmt.Sprintf("%s/v3/credentials/%s", c.baseURL, url.PathEscape(credentialID)) resp, err := c.doJSONRequest(ctx, "PATCH", queryURL, req) if err != nil { @@ -381,32 +390,75 @@ func (c *HTTPClient) DeleteCredential(ctx context.Context, credentialID string) if err := validateRequired("credential ID", credentialID); err != nil { return err } - queryURL := fmt.Sprintf("%s/v3/credentials/%s", c.baseURL, credentialID) + queryURL := fmt.Sprintf("%s/v3/credentials/%s", c.baseURL, 
url.PathEscape(credentialID)) return c.doDelete(ctx, queryURL) } // Admin Grant Operations -// ListAllGrants retrieves all grants with optional filtering. +// grantPageSize is the per-request page size used when walking grants. +// The Nylas v3 /v3/grants endpoint accepts limit values up to 200 and +// defaults to 10 — using the maximum minimises the number of round-trips +// when paginating large result sets. +const grantPageSize = 200 + +// maxGrantPages caps the number of pages ListAllGrants will fetch as a +// hard safety ceiling. At grantPageSize=200 this allows up to 200,000 +// grants — well above any realistic tenant. +const maxGrantPages = 1000 + +// ListAllGrants retrieves all grants matching the optional filters, +// transparently walking offset/limit pagination so callers always see the +// complete result set. +// +// The Nylas v3 /v3/grants endpoint is offset-paginated (limit, offset), +// not cursor-paginated — its response does not include next_cursor — so +// pagination stops when a page returns fewer than grantPageSize grants. +// +// When params.Limit is positive, at most that many grants are returned +// and pagination stops once the cap is reached. When params is nil or +// Limit is zero, every page is fetched. func (c *HTTPClient) ListAllGrants(ctx context.Context, params *domain.GrantsQueryParams) ([]domain.Grant, error) { baseURL := fmt.Sprintf("%s/v3/grants", c.baseURL) - qb := NewQueryBuilder() + maxResults := 0 + connectorID := "" + status := "" + offset := 0 if params != nil { - qb.AddInt("limit", params.Limit). - AddInt("offset", params.Offset). - Add("connector_id", params.ConnectorID). - Add("status", params.Status) - } - queryURL := qb.BuildURL(baseURL) + maxResults = params.Limit + connectorID = params.ConnectorID + status = params.Status + offset = params.Offset + } + + grants := make([]domain.Grant, 0) + for range maxGrantPages { + queryURL := NewQueryBuilder(). + AddInt("limit", grantPageSize). + AddInt("offset", offset). 
+ Add("connector_id", connectorID). + Add("status", status). + BuildURL(baseURL) + + var result struct { + Data []domain.Grant `json:"data"` + } + if err := c.doGet(ctx, queryURL, &result); err != nil { + return nil, err + } - var result struct { - Data []domain.Grant `json:"data"` - } - if err := c.doGet(ctx, queryURL, &result); err != nil { - return nil, err + grants = append(grants, result.Data...) + if maxResults > 0 && len(grants) >= maxResults { + return grants[:maxResults], nil + } + // Last page: server returned fewer than a full page. + if len(result.Data) < grantPageSize { + return grants, nil + } + offset += len(result.Data) } - return result.Data, nil + return nil, fmt.Errorf("failed to paginate grants: exceeded max page count (%d)", maxGrantPages) } // GetGrantStats retrieves grant statistics. diff --git a/internal/adapters/nylas/admin_grants_test.go b/internal/adapters/nylas/admin_grants_test.go index b4d67b0..e0bcdcb 100644 --- a/internal/adapters/nylas/admin_grants_test.go +++ b/internal/adapters/nylas/admin_grants_test.go @@ -3,6 +3,7 @@ package nylas_test import ( "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" "testing" @@ -58,9 +59,9 @@ func TestHTTPClient_ListAllGrants_WithParams(t *testing.T) { assert.Equal(t, "/v3/grants", r.URL.Path) assert.Equal(t, "GET", r.Method) - // Check query parameters + // limit is the API page size (max 200), not the caller's cap. query := r.URL.Query() - assert.Equal(t, "10", query.Get("limit")) + assert.Equal(t, "200", query.Get("limit")) assert.Equal(t, "conn-123", query.Get("connector_id")) response := map[string]any{ @@ -94,6 +95,104 @@ func TestHTTPClient_ListAllGrants_WithParams(t *testing.T) { assert.Equal(t, "google", string(grants[0].Provider)) } +func TestHTTPClient_ListAllGrants_FollowsPagination(t *testing.T) { + // Regression: the Nylas v3 /v3/grants endpoint is offset-paginated + // (limit, offset) and does NOT return next_cursor. 
ListAllGrants + // previously made a single request and silently truncated to the API + // default page size (10), so any tenant with >10 grants — including + // the `nylas auth config` flow via SyncGrants → ListGrants — only + // saw the first page. + const apiPageSize = 200 // mirrors the unexported grantPageSize constant + full := make([]map[string]any, 0, apiPageSize) + for i := range apiPageSize { + full = append(full, map[string]any{ + "id": fmt.Sprintf("grant-page1-%d", i), + "provider": "google", + "grant_status": "valid", + }) + } + + calls := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls++ + query := r.URL.Query() + assert.Equal(t, "200", query.Get("limit"), "should request the API max page size") + + w.Header().Set("Content-Type", "application/json") + switch query.Get("offset") { + case "", "0": + // First page is full — implementation must fetch another. + _ = json.NewEncoder(w).Encode(map[string]any{"data": full}) + case "200": + _ = json.NewEncoder(w).Encode(map[string]any{ + "data": []map[string]any{ + {"id": "grant-tail-1", "provider": "microsoft", "grant_status": "valid"}, + {"id": "grant-tail-2", "provider": "microsoft", "grant_status": "valid"}, + }, + }) + default: + t.Fatalf("unexpected offset %q", query.Get("offset")) + } + })) + defer server.Close() + + client := nylas.NewHTTPClient() + client.SetCredentials("client-id", "secret", "api-key") + client.SetBaseURL(server.URL) + + grants, err := client.ListAllGrants(context.Background(), nil) + require.NoError(t, err) + assert.Equal(t, 2, calls, "should have advanced offset and made a second request") + assert.Len(t, grants, 202) + assert.Equal(t, "grant-tail-2", grants[201].ID) +} + +func TestHTTPClient_ListAllGrants_StopsOnShortPage(t *testing.T) { + calls := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls++ + w.Header().Set("Content-Type", "application/json") + _ = 
json.NewEncoder(w).Encode(map[string]any{ + "data": []map[string]any{ + {"id": "grant-1", "provider": "google", "grant_status": "valid"}, + {"id": "grant-2", "provider": "google", "grant_status": "valid"}, + }, + }) + })) + defer server.Close() + + client := nylas.NewHTTPClient() + client.SetCredentials("client-id", "secret", "api-key") + client.SetBaseURL(server.URL) + + grants, err := client.ListAllGrants(context.Background(), nil) + require.NoError(t, err) + assert.Equal(t, 1, calls, "short first page must terminate pagination") + assert.Len(t, grants, 2) +} + +func TestHTTPClient_ListAllGrants_LimitCapsResults(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "data": []map[string]any{ + {"id": "grant-1", "provider": "google", "grant_status": "valid"}, + {"id": "grant-2", "provider": "google", "grant_status": "valid"}, + {"id": "grant-3", "provider": "google", "grant_status": "valid"}, + }, + }) + })) + defer server.Close() + + client := nylas.NewHTTPClient() + client.SetCredentials("client-id", "secret", "api-key") + client.SetBaseURL(server.URL) + + grants, err := client.ListAllGrants(context.Background(), &domain.GrantsQueryParams{Limit: 2}) + require.NoError(t, err) + assert.Len(t, grants, 2, "client-side limit should cap results") +} + func TestHTTPClient_GetGrantStats(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "/v3/grants", r.URL.Path) diff --git a/internal/adapters/nylas/agent.go b/internal/adapters/nylas/agent.go index 507b342..8da113b 100644 --- a/internal/adapters/nylas/agent.go +++ b/internal/adapters/nylas/agent.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "net/url" "github.com/nylas/cli/internal/domain" ) @@ -93,7 +94,7 @@ func (c *HTTPClient) UpdateAgentAccount(ctx context.Context, grantID, email, 
app return nil, fmt.Errorf("%w: grant is not a nylas agent account (provider=%s)", domain.ErrInvalidGrant, grant.Provider) } - queryURL := fmt.Sprintf("%s/v3/grants/%s", c.baseURL, grantID) + queryURL := fmt.Sprintf("%s/v3/grants/%s", c.baseURL, url.PathEscape(grantID)) settings := make(map[string]any) settings["email"] = email if grant.Settings.PolicyID != "" { diff --git a/internal/adapters/nylas/agent_test.go b/internal/adapters/nylas/agent_test.go index 476b42d..dc1ca29 100644 --- a/internal/adapters/nylas/agent_test.go +++ b/internal/adapters/nylas/agent_test.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "net/http/httptest" + "strconv" "strings" "sync/atomic" "testing" @@ -44,7 +45,10 @@ func TestListAgentAccounts(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "/v3/grants", r.URL.Path) assert.Equal(t, http.MethodGet, r.Method) - assert.Equal(t, "nylas", r.URL.Query().Get("provider")) + // listManagedGrants intentionally does NOT pass provider= as a + // server-side filter — the filtered listing has been observed to + // lag freshly-created managed grants. We filter client-side. + assert.Equal(t, "", r.URL.Query().Get("provider")) response := map[string]any{ "data": []map[string]any{ @@ -59,6 +63,13 @@ func TestListAgentAccounts(t *testing.T) { }, "created_at": time.Now().Unix(), }, + // Mix in a non-nylas grant — the client must filter it out. 
+ { + "id": "google-001", + "email": "user@gmail.com", + "provider": "google", + "grant_status": "valid", + }, }, } @@ -73,7 +84,7 @@ func TestListAgentAccounts(t *testing.T) { accounts, err := client.ListAgentAccounts(context.Background()) require.NoError(t, err) - require.Len(t, accounts, 1) + require.Len(t, accounts, 1, "non-nylas grants must be filtered client-side") assert.Equal(t, "agent-001", accounts[0].ID) assert.Equal(t, "agent@example.com", accounts[0].Email) assert.Equal(t, "policy-123", accounts[0].Settings.PolicyID) @@ -84,11 +95,28 @@ func TestListAgentAccounts_PaginatesAllResults(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "/v3/grants", r.URL.Path) assert.Equal(t, http.MethodGet, r.Method) - assert.Equal(t, "nylas", r.URL.Query().Get("provider")) + // No server-side provider filter — see TestListAgentAccounts. + assert.Equal(t, "", r.URL.Query().Get("provider")) + assert.Equal(t, "", r.URL.Query().Get("page_token")) + assert.Equal(t, strconv.Itoa(grantPageSize), r.URL.Query().Get("limit")) requests++ var response map[string]any - if r.URL.Query().Get("page_token") == "" { + switch r.URL.Query().Get("offset") { + case "": + page := make([]map[string]any, grantPageSize) + for i := range page { + page[i] = map[string]any{ + "id": "google-" + strconv.Itoa(i), + "email": "user" + strconv.Itoa(i) + "@gmail.com", + "provider": "google", + "grant_status": "valid", + } + } + response = map[string]any{ + "data": page, + } + case strconv.Itoa(grantPageSize): response = map[string]any{ "data": []map[string]any{ { @@ -97,13 +125,6 @@ func TestListAgentAccounts_PaginatesAllResults(t *testing.T) { "provider": "nylas", "grant_status": "valid", }, - }, - "next_cursor": "cursor-2", - } - } else { - assert.Equal(t, "cursor-2", r.URL.Query().Get("page_token")) - response = map[string]any{ - "data": []map[string]any{ { "id": "agent-002", "email": "second@example.com", @@ -112,6 +133,8 @@ 
func TestListAgentAccounts_PaginatesAllResults(t *testing.T) { }, }, } + default: + t.Fatalf("unexpected offset %q", r.URL.Query().Get("offset")) } w.Header().Set("Content-Type", "application/json") diff --git a/internal/adapters/nylas/attachments.go b/internal/adapters/nylas/attachments.go index 45e8194..2375e75 100644 --- a/internal/adapters/nylas/attachments.go +++ b/internal/adapters/nylas/attachments.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net/http" + "net/url" "github.com/nylas/cli/internal/domain" ) @@ -22,7 +23,7 @@ type attachmentResponse struct { // GetAttachment retrieves attachment metadata. func (c *HTTPClient) GetAttachment(ctx context.Context, grantID, messageID, attachmentID string) (*domain.Attachment, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/messages/%s/attachments/%s", c.baseURL, grantID, messageID, attachmentID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/messages/%s/attachments/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(messageID), url.PathEscape(attachmentID)) var result struct { Data attachmentResponse `json:"data"` @@ -44,7 +45,7 @@ func (c *HTTPClient) GetAttachment(ctx context.Context, grantID, messageID, atta // DownloadAttachment downloads attachment content. 
func (c *HTTPClient) DownloadAttachment(ctx context.Context, grantID, messageID, attachmentID string) (io.ReadCloser, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/messages/%s/attachments/%s/download", c.baseURL, grantID, messageID, attachmentID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/messages/%s/attachments/%s/download", c.baseURL, url.PathEscape(grantID), url.PathEscape(messageID), url.PathEscape(attachmentID)) req, err := http.NewRequestWithContext(ctx, "GET", queryURL, nil) if err != nil { diff --git a/internal/adapters/nylas/auth.go b/internal/adapters/nylas/auth.go index 5cde7c3..9ccbb5a 100644 --- a/internal/adapters/nylas/auth.go +++ b/internal/adapters/nylas/auth.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "net/http" + "net/url" "github.com/nylas/cli/internal/domain" ) @@ -71,23 +72,17 @@ func (c *HTTPClient) ExchangeCode(ctx context.Context, code, redirectURI, codeVe }, nil } -// ListGrants lists all grants for the application. +// ListGrants lists all grants for the application, transparently +// following offset/limit pagination so callers always see the complete +// result set. The Nylas v3 default page size (10) would otherwise +// silently truncate accounts with more than ten grants. func (c *HTTPClient) ListGrants(ctx context.Context) ([]domain.Grant, error) { - queryURL := c.baseURL + "/v3/grants" - - var result struct { - Data []domain.Grant `json:"data"` - } - if err := c.doGet(ctx, queryURL, &result); err != nil { - return nil, err - } - - return result.Data, nil + return c.ListAllGrants(ctx, nil) } // GetGrant retrieves a specific grant. func (c *HTTPClient) GetGrant(ctx context.Context, grantID string) (*domain.Grant, error) { - queryURL := c.baseURL + "/v3/grants/" + grantID + queryURL := c.baseURL + "/v3/grants/" + url.PathEscape(grantID) var result struct { Data domain.Grant `json:"data"` @@ -101,7 +96,7 @@ func (c *HTTPClient) GetGrant(ctx context.Context, grantID string) (*domain.Gran // RevokeGrant revokes a grant.
func (c *HTTPClient) RevokeGrant(ctx context.Context, grantID string) error { - req, err := http.NewRequestWithContext(ctx, "DELETE", c.baseURL+"/v3/grants/"+grantID, nil) + req, err := http.NewRequestWithContext(ctx, "DELETE", c.baseURL+"/v3/grants/"+url.PathEscape(grantID), nil) if err != nil { return err } diff --git a/internal/adapters/nylas/auth_test.go b/internal/adapters/nylas/auth_test.go index 4c07532..ecfe19c 100644 --- a/internal/adapters/nylas/auth_test.go +++ b/internal/adapters/nylas/auth_test.go @@ -5,6 +5,7 @@ package nylas import ( "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" "testing" @@ -269,6 +270,52 @@ func TestHTTPClient_ListGrants(t *testing.T) { } } +func TestHTTPClient_ListGrants_FollowsPagination(t *testing.T) { + // Regression: ListGrants issued a single GET /v3/grants without + // paginating, so tenants with more grants than the v3 default page + // size (10) silently lost the rest — including in the `nylas auth + // config` flow, which would only sync the first page (e.g. 9 valid + // grants out of 25). The /v3/grants endpoint is offset-paginated, so + // ListGrants must walk offsets until a short page is returned. 
+ const apiPageSize = 200 + full := make([]map[string]any, 0, apiPageSize) + for i := range apiPageSize { + full = append(full, map[string]any{ + "id": fmt.Sprintf("grant-page1-%d", i), + "provider": "google", + "grant_status": "valid", + }) + } + + calls := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls++ + w.Header().Set("Content-Type", "application/json") + switch r.URL.Query().Get("offset") { + case "", "0": + _ = json.NewEncoder(w).Encode(map[string]any{"data": full}) + case "200": + _ = json.NewEncoder(w).Encode(map[string]any{ + "data": []map[string]any{ + {"id": "grant-tail", "provider": "microsoft", "grant_status": "valid"}, + }, + }) + default: + t.Fatalf("unexpected offset %q", r.URL.Query().Get("offset")) + } + })) + defer server.Close() + + client := newTestClient("test-api-key", "test-client-id", "") + client.SetBaseURL(server.URL) + + grants, err := client.ListGrants(context.Background()) + require.NoError(t, err) + assert.Equal(t, 2, calls, "should advance offset and fetch a second page") + assert.Len(t, grants, 201) + assert.Equal(t, "grant-tail", grants[200].ID) +} + func TestHTTPClient_GetGrant(t *testing.T) { tests := []struct { name string diff --git a/internal/adapters/nylas/calendars_calendars.go b/internal/adapters/nylas/calendars_calendars.go index 3d54314..cd6577e 100644 --- a/internal/adapters/nylas/calendars_calendars.go +++ b/internal/adapters/nylas/calendars_calendars.go @@ -3,6 +3,7 @@ package nylas import ( "context" "fmt" + "net/url" "github.com/nylas/cli/internal/domain" ) @@ -12,7 +13,7 @@ func (c *HTTPClient) GetCalendars(ctx context.Context, grantID string) ([]domain return nil, err } - queryURL := fmt.Sprintf("%s/v3/grants/%s/calendars", c.baseURL, grantID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/calendars", c.baseURL, url.PathEscape(grantID)) var result struct { Data []calendarResponse `json:"data"` @@ -33,7 +34,7 @@ func (c *HTTPClient) GetCalendar(ctx 
context.Context, grantID, calendarID string return nil, err } - queryURL := fmt.Sprintf("%s/v3/grants/%s/calendars/%s", c.baseURL, grantID, calendarID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/calendars/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(calendarID)) var result struct { Data calendarResponse `json:"data"` @@ -52,7 +53,7 @@ func (c *HTTPClient) CreateCalendar(ctx context.Context, grantID string, req *do return nil, err } - queryURL := fmt.Sprintf("%s/v3/grants/%s/calendars", c.baseURL, grantID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/calendars", c.baseURL, url.PathEscape(grantID)) payload := map[string]any{ "name": req.Name, @@ -92,7 +93,7 @@ func (c *HTTPClient) UpdateCalendar(ctx context.Context, grantID, calendarID str return nil, err } - queryURL := fmt.Sprintf("%s/v3/grants/%s/calendars/%s", c.baseURL, grantID, calendarID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/calendars/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(calendarID)) payload := make(map[string]any) if req.Name != nil { @@ -135,7 +136,7 @@ func (c *HTTPClient) DeleteCalendar(ctx context.Context, grantID, calendarID str if err := validateRequired("calendar ID", calendarID); err != nil { return err } - queryURL := fmt.Sprintf("%s/v3/grants/%s/calendars/%s", c.baseURL, grantID, calendarID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/calendars/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(calendarID)) return c.doDelete(ctx, queryURL) } diff --git a/internal/adapters/nylas/calendars_events.go b/internal/adapters/nylas/calendars_events.go index ec475b2..681e7b4 100644 --- a/internal/adapters/nylas/calendars_events.go +++ b/internal/adapters/nylas/calendars_events.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "net/url" "github.com/nylas/cli/internal/domain" ) @@ -31,7 +32,7 @@ func (c *HTTPClient) GetEventsWithCursor(ctx context.Context, grantID, calendarI params.Limit = 10 } - baseURL := fmt.Sprintf("%s/v3/grants/%s/events", c.baseURL, grantID) + 
baseURL := fmt.Sprintf("%s/v3/grants/%s/events", c.baseURL, url.PathEscape(grantID)) queryURL := NewQueryBuilder(). Add("calendar_id", calendarID). AddInt("limit", params.Limit). @@ -74,7 +75,7 @@ func (c *HTTPClient) GetEvent(ctx context.Context, grantID, calendarID, eventID if err := validateRequired("event ID", eventID); err != nil { return nil, err } - baseURL := fmt.Sprintf("%s/v3/grants/%s/events/%s", c.baseURL, grantID, eventID) + baseURL := fmt.Sprintf("%s/v3/grants/%s/events/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(eventID)) queryURL := NewQueryBuilder().Add("calendar_id", calendarID).BuildURL(baseURL) var result struct { @@ -90,7 +91,7 @@ func (c *HTTPClient) GetEvent(ctx context.Context, grantID, calendarID, eventID // CreateEvent creates a new event. func (c *HTTPClient) CreateEvent(ctx context.Context, grantID, calendarID string, req *domain.CreateEventRequest) (*domain.Event, error) { - baseURL := fmt.Sprintf("%s/v3/grants/%s/events", c.baseURL, grantID) + baseURL := fmt.Sprintf("%s/v3/grants/%s/events", c.baseURL, url.PathEscape(grantID)) queryURL := NewQueryBuilder().Add("calendar_id", calendarID).BuildURL(baseURL) payload := map[string]any{ @@ -142,7 +143,7 @@ func (c *HTTPClient) CreateEvent(ctx context.Context, grantID, calendarID string // UpdateEvent updates an existing event. func (c *HTTPClient) UpdateEvent(ctx context.Context, grantID, calendarID, eventID string, req *domain.UpdateEventRequest) (*domain.Event, error) { - baseURL := fmt.Sprintf("%s/v3/grants/%s/events/%s", c.baseURL, grantID, eventID) + baseURL := fmt.Sprintf("%s/v3/grants/%s/events/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(eventID)) queryURL := NewQueryBuilder().Add("calendar_id", calendarID).BuildURL(baseURL) payload := make(map[string]any) @@ -195,14 +196,14 @@ func (c *HTTPClient) UpdateEvent(ctx context.Context, grantID, calendarID, event // DeleteEvent deletes an event. 
func (c *HTTPClient) DeleteEvent(ctx context.Context, grantID, calendarID, eventID string) error { - baseURL := fmt.Sprintf("%s/v3/grants/%s/events/%s", c.baseURL, grantID, eventID) + baseURL := fmt.Sprintf("%s/v3/grants/%s/events/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(eventID)) queryURL := NewQueryBuilder().Add("calendar_id", calendarID).BuildURL(baseURL) return c.doDelete(ctx, queryURL) } // SendRSVP sends an RSVP response to an event invitation. func (c *HTTPClient) SendRSVP(ctx context.Context, grantID, calendarID, eventID string, req *domain.SendRSVPRequest) error { - baseURL := fmt.Sprintf("%s/v3/grants/%s/events/%s/send-rsvp", c.baseURL, grantID, eventID) + baseURL := fmt.Sprintf("%s/v3/grants/%s/events/%s/send-rsvp", c.baseURL, url.PathEscape(grantID), url.PathEscape(eventID)) queryURL := NewQueryBuilder().Add("calendar_id", calendarID).BuildURL(baseURL) payload := map[string]any{ @@ -223,7 +224,7 @@ func (c *HTTPClient) SendRSVP(ctx context.Context, grantID, calendarID, eventID // GetFreeBusy retrieves free/busy information. 
func (c *HTTPClient) GetFreeBusy(ctx context.Context, grantID string, freeBusyReq *domain.FreeBusyRequest) (*domain.FreeBusyResponse, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/calendars/free-busy", c.baseURL, grantID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/calendars/free-busy", c.baseURL, url.PathEscape(grantID)) resp, err := c.doJSONRequest(ctx, "POST", queryURL, freeBusyReq) if err != nil { diff --git a/internal/adapters/nylas/calendars_virtual_recurring.go b/internal/adapters/nylas/calendars_virtual_recurring.go index dfc4c45..6c8f602 100644 --- a/internal/adapters/nylas/calendars_virtual_recurring.go +++ b/internal/adapters/nylas/calendars_virtual_recurring.go @@ -3,6 +3,7 @@ package nylas import ( "context" "fmt" + "net/url" "github.com/nylas/cli/internal/domain" ) @@ -48,7 +49,7 @@ func (c *HTTPClient) ListVirtualCalendarGrants(ctx context.Context) ([]domain.Vi // GetVirtualCalendarGrant retrieves a single virtual calendar grant by ID. func (c *HTTPClient) GetVirtualCalendarGrant(ctx context.Context, grantID string) (*domain.VirtualCalendarGrant, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s", c.baseURL, grantID) + queryURL := fmt.Sprintf("%s/v3/grants/%s", c.baseURL, url.PathEscape(grantID)) var result struct { Data domain.VirtualCalendarGrant `json:"data"` @@ -62,7 +63,7 @@ func (c *HTTPClient) GetVirtualCalendarGrant(ctx context.Context, grantID string // DeleteVirtualCalendarGrant deletes a virtual calendar grant. 
func (c *HTTPClient) DeleteVirtualCalendarGrant(ctx context.Context, grantID string) error { - queryURL := fmt.Sprintf("%s/v3/grants/%s", c.baseURL, grantID) + queryURL := fmt.Sprintf("%s/v3/grants/%s", c.baseURL, url.PathEscape(grantID)) return c.doDelete(ctx, queryURL) } @@ -77,7 +78,7 @@ func (c *HTTPClient) GetRecurringEventInstances(ctx context.Context, grantID, ca params.ExpandRecurring = true } - baseURL := fmt.Sprintf("%s/v3/grants/%s/events", c.baseURL, grantID) + baseURL := fmt.Sprintf("%s/v3/grants/%s/events", c.baseURL, url.PathEscape(grantID)) queryURL := NewQueryBuilder(). Add("calendar_id", calendarID). Add("master_event_id", masterEventID). diff --git a/internal/adapters/nylas/contacts.go b/internal/adapters/nylas/contacts.go index 737d5f2..f5ce062 100644 --- a/internal/adapters/nylas/contacts.go +++ b/internal/adapters/nylas/contacts.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "net/url" "github.com/nylas/cli/internal/domain" "github.com/nylas/cli/internal/util" @@ -55,7 +56,7 @@ func (c *HTTPClient) GetContacts(ctx context.Context, grantID string, params *do // GetContactsWithCursor retrieves contacts with pagination cursor. func (c *HTTPClient) GetContactsWithCursor(ctx context.Context, grantID string, params *domain.ContactQueryParams) (*domain.ContactListResponse, error) { - baseURL := fmt.Sprintf("%s/v3/grants/%s/contacts", c.baseURL, grantID) + baseURL := fmt.Sprintf("%s/v3/grants/%s/contacts", c.baseURL, url.PathEscape(grantID)) qb := NewQueryBuilder() if params != nil { @@ -94,7 +95,7 @@ func (c *HTTPClient) GetContact(ctx context.Context, grantID, contactID string) // GetContactWithPicture retrieves a single contact by ID with optional profile picture. 
func (c *HTTPClient) GetContactWithPicture(ctx context.Context, grantID, contactID string, includePicture bool) (*domain.Contact, error) { - baseURL := fmt.Sprintf("%s/v3/grants/%s/contacts/%s", c.baseURL, grantID, contactID) + baseURL := fmt.Sprintf("%s/v3/grants/%s/contacts/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(contactID)) queryURL := NewQueryBuilder().AddBool("profile_picture", includePicture).BuildURL(baseURL) var result struct { @@ -110,7 +111,7 @@ func (c *HTTPClient) GetContactWithPicture(ctx context.Context, grantID, contact // CreateContact creates a new contact. func (c *HTTPClient) CreateContact(ctx context.Context, grantID string, req *domain.CreateContactRequest) (*domain.Contact, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/contacts", c.baseURL, grantID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/contacts", c.baseURL, url.PathEscape(grantID)) resp, err := c.doJSONRequest(ctx, "POST", queryURL, req) if err != nil { @@ -130,7 +131,7 @@ func (c *HTTPClient) CreateContact(ctx context.Context, grantID string, req *dom // UpdateContact updates an existing contact. func (c *HTTPClient) UpdateContact(ctx context.Context, grantID, contactID string, req *domain.UpdateContactRequest) (*domain.Contact, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/contacts/%s", c.baseURL, grantID, contactID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/contacts/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(contactID)) resp, err := c.doJSONRequest(ctx, "PUT", queryURL, req, http.StatusOK) if err != nil { @@ -150,13 +151,13 @@ func (c *HTTPClient) UpdateContact(ctx context.Context, grantID, contactID strin // DeleteContact deletes a contact. 
func (c *HTTPClient) DeleteContact(ctx context.Context, grantID, contactID string) error { - queryURL := fmt.Sprintf("%s/v3/grants/%s/contacts/%s", c.baseURL, grantID, contactID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/contacts/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(contactID)) return c.doDelete(ctx, queryURL) } // GetContactGroups retrieves contact groups for a grant. func (c *HTTPClient) GetContactGroups(ctx context.Context, grantID string) ([]domain.ContactGroup, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/contacts/groups", c.baseURL, grantID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/contacts/groups", c.baseURL, url.PathEscape(grantID)) var result struct { Data []contactGroupResponse `json:"data"` @@ -170,7 +171,7 @@ func (c *HTTPClient) GetContactGroups(ctx context.Context, grantID string) ([]do // GetContactGroup retrieves a single contact group by ID. func (c *HTTPClient) GetContactGroup(ctx context.Context, grantID, groupID string) (*domain.ContactGroup, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/contacts/groups/%s", c.baseURL, grantID, groupID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/contacts/groups/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(groupID)) var result struct { Data contactGroupResponse `json:"data"` @@ -185,7 +186,7 @@ func (c *HTTPClient) GetContactGroup(ctx context.Context, grantID, groupID strin // CreateContactGroup creates a new contact group. 
func (c *HTTPClient) CreateContactGroup(ctx context.Context, grantID string, req *domain.CreateContactGroupRequest) (*domain.ContactGroup, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/contacts/groups", c.baseURL, grantID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/contacts/groups", c.baseURL, url.PathEscape(grantID)) resp, err := c.doJSONRequest(ctx, "POST", queryURL, req) if err != nil { @@ -205,7 +206,7 @@ func (c *HTTPClient) CreateContactGroup(ctx context.Context, grantID string, req // UpdateContactGroup updates an existing contact group. func (c *HTTPClient) UpdateContactGroup(ctx context.Context, grantID, groupID string, req *domain.UpdateContactGroupRequest) (*domain.ContactGroup, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/contacts/groups/%s", c.baseURL, grantID, groupID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/contacts/groups/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(groupID)) resp, err := c.doJSONRequest(ctx, "PUT", queryURL, req, http.StatusOK) if err != nil { @@ -225,7 +226,7 @@ func (c *HTTPClient) UpdateContactGroup(ctx context.Context, grantID, groupID st // DeleteContactGroup deletes a contact group. 
func (c *HTTPClient) DeleteContactGroup(ctx context.Context, grantID, groupID string) error { - queryURL := fmt.Sprintf("%s/v3/grants/%s/contacts/groups/%s", c.baseURL, grantID, groupID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/contacts/groups/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(groupID)) return c.doDelete(ctx, queryURL) } diff --git a/internal/adapters/nylas/drafts.go b/internal/adapters/nylas/drafts.go index d4a8f95..702b3fa 100644 --- a/internal/adapters/nylas/drafts.go +++ b/internal/adapters/nylas/drafts.go @@ -9,6 +9,8 @@ import ( "mime/multipart" "net/http" "net/textproto" + "net/url" + "slices" "time" "github.com/nylas/cli/internal/domain" @@ -59,7 +61,7 @@ func (c *HTTPClient) GetDrafts(ctx context.Context, grantID string, limit int) ( limit = 10 } - baseURL := fmt.Sprintf("%s/v3/grants/%s/drafts", c.baseURL, grantID) + baseURL := fmt.Sprintf("%s/v3/grants/%s/drafts", c.baseURL, url.PathEscape(grantID)) queryURL := NewQueryBuilder().AddInt("limit", limit).BuildURL(baseURL) var result struct { @@ -74,7 +76,7 @@ func (c *HTTPClient) GetDrafts(ctx context.Context, grantID string, limit int) ( // GetDraft retrieves a single draft by ID. func (c *HTTPClient) GetDraft(ctx context.Context, grantID, draftID string) (*domain.Draft, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/drafts/%s", c.baseURL, grantID, draftID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/drafts/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(draftID)) var result struct { Data draftResponse `json:"data"` @@ -129,7 +131,7 @@ func buildDraftPayload(req *domain.CreateDraftRequest, includeSignature bool) ma // createDraftWithJSON creates a draft using JSON encoding (no attachments or small attachments). 
func (c *HTTPClient) createDraftWithJSON(ctx context.Context, grantID string, req *domain.CreateDraftRequest) (*domain.Draft, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/drafts", c.baseURL, grantID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/drafts", c.baseURL, url.PathEscape(grantID)) resp, err := c.doJSONRequest(ctx, "POST", queryURL, buildDraftPayload(req, true)) if err != nil { @@ -147,25 +149,24 @@ func (c *HTTPClient) createDraftWithJSON(ctx context.Context, grantID string, re return &draft, nil } -// createDraftWithMultipart creates a draft with attachments using multipart/form-data. -func (c *HTTPClient) createDraftWithMultipart(ctx context.Context, grantID string, req *domain.CreateDraftRequest) (*domain.Draft, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/drafts", c.baseURL, grantID) - +// doMultipartDraft sends a multipart/form-data draft request and decodes the +// response into out. Used by both createDraftWithMultipart and updateDraftWithMultipart. +func (c *HTTPClient) doMultipartDraft(ctx context.Context, method, url string, payload map[string]any, attachments []domain.Attachment, out any, acceptedStatuses ...int) error { // Create multipart form var buf bytes.Buffer writer := multipart.NewWriter(&buf) // Add message as JSON field - messageJSON, err := json.Marshal(buildDraftPayload(req, true)) + messageJSON, err := json.Marshal(payload) if err != nil { - return nil, fmt.Errorf("failed to marshal message: %w", err) + return fmt.Errorf("failed to marshal message: %w", err) } if err := writer.WriteField("message", string(messageJSON)); err != nil { - return nil, fmt.Errorf("failed to write message field: %w", err) + return fmt.Errorf("failed to write message field: %w", err) } // Add each attachment as a file - for i, att := range req.Attachments { + for i, att := range attachments { if len(att.Content) == 0 { continue // Skip attachments without content } @@ -181,38 +182,45 @@ func (c *HTTPClient) createDraftWithMultipart(ctx 
context.Context, grantID strin part, err := writer.CreatePart(h) if err != nil { - return nil, fmt.Errorf("failed to create attachment part: %w", err) + return fmt.Errorf("failed to create attachment part: %w", err) } if _, err := part.Write(att.Content); err != nil { - return nil, fmt.Errorf("failed to write attachment content: %w", err) + return fmt.Errorf("failed to write attachment content: %w", err) } } if err := writer.Close(); err != nil { - return nil, fmt.Errorf("failed to close multipart writer: %w", err) + return fmt.Errorf("failed to close multipart writer: %w", err) } - httpReq, err := http.NewRequestWithContext(ctx, "POST", queryURL, &buf) + httpReq, err := http.NewRequestWithContext(ctx, method, url, &buf) if err != nil { - return nil, err + return err } httpReq.Header.Set("Content-Type", writer.FormDataContentType()) c.setAuthHeader(httpReq) resp, err := c.doRequest(ctx, httpReq) if err != nil { - return nil, fmt.Errorf("%w: %v", domain.ErrNetworkError, err) + return fmt.Errorf("%w: %v", domain.ErrNetworkError, err) } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { - return nil, c.parseError(resp) + if !slices.Contains(acceptedStatuses, resp.StatusCode) { + defer func() { _ = resp.Body.Close() }() + return c.parseError(resp) } + return c.decodeJSONResponse(resp, out) +} + +// createDraftWithMultipart creates a draft with attachments using multipart/form-data. 
+func (c *HTTPClient) createDraftWithMultipart(ctx context.Context, grantID string, req *domain.CreateDraftRequest) (*domain.Draft, error) { + queryURL := fmt.Sprintf("%s/v3/grants/%s/drafts", c.baseURL, url.PathEscape(grantID)) + var result struct { Data draftResponse `json:"data"` } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + if err := c.doMultipartDraft(ctx, "POST", queryURL, buildDraftPayload(req, true), req.Attachments, &result, http.StatusOK, http.StatusCreated); err != nil { return nil, err } @@ -223,7 +231,7 @@ func (c *HTTPClient) createDraftWithMultipart(ctx context.Context, grantID strin // CreateDraftWithAttachmentFromReader creates a draft with an attachment from an io.Reader. // This is useful for large attachments or streaming file uploads. func (c *HTTPClient) CreateDraftWithAttachmentFromReader(ctx context.Context, grantID string, req *domain.CreateDraftRequest, filename string, contentType string, reader io.Reader) (*domain.Draft, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/drafts", c.baseURL, grantID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/drafts", c.baseURL, url.PathEscape(grantID)) payload := buildDraftPayload(req, true) // Use pipe to stream multipart data @@ -280,21 +288,22 @@ func (c *HTTPClient) CreateDraftWithAttachmentFromReader(ctx context.Context, gr if err != nil { return nil, fmt.Errorf("%w: %v", domain.ErrNetworkError, err) } - defer func() { _ = resp.Body.Close() }() // Wait for writer goroutine to finish if writerErr := <-errCh; writerErr != nil { + _ = resp.Body.Close() return nil, writerErr } if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { + defer func() { _ = resp.Body.Close() }() return nil, c.parseError(resp) } var result struct { Data draftResponse `json:"data"` } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + if err := c.decodeJSONResponse(resp, &result); err != nil { return nil, err } @@ -304,13 +313,13 @@ func (c *HTTPClient) 
CreateDraftWithAttachmentFromReader(ctx context.Context, gr // DeleteDraft deletes a draft. func (c *HTTPClient) DeleteDraft(ctx context.Context, grantID, draftID string) error { - queryURL := fmt.Sprintf("%s/v3/grants/%s/drafts/%s", c.baseURL, grantID, draftID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/drafts/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(draftID)) return c.doDelete(ctx, queryURL) } // SendDraft sends a draft. func (c *HTTPClient) SendDraft(ctx context.Context, grantID, draftID string, req *domain.SendDraftRequest) (*domain.Message, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/drafts/%s", c.baseURL, grantID, draftID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/drafts/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(draftID)) var bodyReader io.Reader if req != nil && req.SignatureID != "" { @@ -332,16 +341,16 @@ func (c *HTTPClient) SendDraft(ctx context.Context, grantID, draftID string, req if err != nil { return nil, fmt.Errorf("%w: %v", domain.ErrNetworkError, err) } - defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { + defer func() { _ = resp.Body.Close() }() return nil, c.parseError(resp) } var result struct { Data messageResponse `json:"data"` } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + if err := c.decodeJSONResponse(resp, &result); err != nil { return nil, err } diff --git a/internal/adapters/nylas/drafts_update.go b/internal/adapters/nylas/drafts_update.go index 8a9d3f5..7889cc7 100644 --- a/internal/adapters/nylas/drafts_update.go +++ b/internal/adapters/nylas/drafts_update.go @@ -1,13 +1,10 @@ package nylas import ( - "bytes" "context" - "encoding/json" "fmt" - "mime/multipart" "net/http" - "net/textproto" + "net/url" "github.com/nylas/cli/internal/domain" ) @@ -23,7 +20,7 @@ func (c *HTTPClient) UpdateDraft(ctx context.Context, grantID, draftID string, r // updateDraftWithJSON updates a draft using JSON encoding (no 
attachments). func (c *HTTPClient) updateDraftWithJSON(ctx context.Context, grantID, draftID string, req *domain.CreateDraftRequest) (*domain.Draft, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/drafts/%s", c.baseURL, grantID, draftID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/drafts/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(draftID)) resp, err := c.doJSONRequest(ctx, "PUT", queryURL, buildDraftPayload(req, false), http.StatusOK) if err != nil { @@ -43,73 +40,12 @@ func (c *HTTPClient) updateDraftWithJSON(ctx context.Context, grantID, draftID s // updateDraftWithMultipart updates a draft with attachments using multipart/form-data. func (c *HTTPClient) updateDraftWithMultipart(ctx context.Context, grantID, draftID string, req *domain.CreateDraftRequest) (*domain.Draft, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/drafts/%s", c.baseURL, grantID, draftID) - - // Build the message JSON - message := buildDraftPayload(req, false) - - // Create multipart form - var buf bytes.Buffer - writer := multipart.NewWriter(&buf) - - // Add message as JSON field - messageJSON, err := json.Marshal(message) - if err != nil { - return nil, fmt.Errorf("failed to marshal message: %w", err) - } - if err := writer.WriteField("message", string(messageJSON)); err != nil { - return nil, fmt.Errorf("failed to write message field: %w", err) - } - - // Add each attachment as a file - for i, att := range req.Attachments { - if len(att.Content) == 0 { - continue // Skip attachments without content - } - - // Create form file with proper headers - h := make(textproto.MIMEHeader) - h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="file%d"; filename="%s"`, i, att.Filename)) - if att.ContentType != "" { - h.Set("Content-Type", att.ContentType) - } else { - h.Set("Content-Type", "application/octet-stream") - } - - part, err := writer.CreatePart(h) - if err != nil { - return nil, fmt.Errorf("failed to create attachment part: %w", err) - } - if _, err := 
part.Write(att.Content); err != nil { - return nil, fmt.Errorf("failed to write attachment content: %w", err) - } - } - - if err := writer.Close(); err != nil { - return nil, fmt.Errorf("failed to close multipart writer: %w", err) - } - - httpReq, err := http.NewRequestWithContext(ctx, "PUT", queryURL, &buf) - if err != nil { - return nil, err - } - httpReq.Header.Set("Content-Type", writer.FormDataContentType()) - c.setAuthHeader(httpReq) - - resp, err := c.doRequest(ctx, httpReq) - if err != nil { - return nil, fmt.Errorf("%w: %v", domain.ErrNetworkError, err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - return nil, c.parseError(resp) - } + queryURL := fmt.Sprintf("%s/v3/grants/%s/drafts/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(draftID)) var result struct { Data draftResponse `json:"data"` } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + if err := c.doMultipartDraft(ctx, "PUT", queryURL, buildDraftPayload(req, false), req.Attachments, &result, http.StatusOK); err != nil { return nil, err } diff --git a/internal/adapters/nylas/folders.go b/internal/adapters/nylas/folders.go index 719efa7..f5e394f 100644 --- a/internal/adapters/nylas/folders.go +++ b/internal/adapters/nylas/folders.go @@ -3,6 +3,7 @@ package nylas import ( "context" "fmt" + "net/url" "github.com/nylas/cli/internal/domain" "github.com/nylas/cli/internal/util" @@ -29,7 +30,7 @@ func (c *HTTPClient) GetFolders(ctx context.Context, grantID string) ([]domain.F return nil, err } - queryURL := fmt.Sprintf("%s/v3/grants/%s/folders", c.baseURL, grantID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/folders", c.baseURL, url.PathEscape(grantID)) var result struct { Data []folderResponse `json:"data"` @@ -50,7 +51,7 @@ func (c *HTTPClient) GetFolder(ctx context.Context, grantID, folderID string) (* return nil, err } - queryURL := fmt.Sprintf("%s/v3/grants/%s/folders/%s", c.baseURL, grantID, folderID) + queryURL := 
fmt.Sprintf("%s/v3/grants/%s/folders/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(folderID)) var result struct { Data folderResponse `json:"data"` @@ -69,7 +70,7 @@ func (c *HTTPClient) CreateFolder(ctx context.Context, grantID string, req *doma return nil, err } - queryURL := fmt.Sprintf("%s/v3/grants/%s/folders", c.baseURL, grantID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/folders", c.baseURL, url.PathEscape(grantID)) payload := map[string]any{ "name": req.Name, @@ -109,7 +110,7 @@ func (c *HTTPClient) UpdateFolder(ctx context.Context, grantID, folderID string, return nil, err } - queryURL := fmt.Sprintf("%s/v3/grants/%s/folders/%s", c.baseURL, grantID, folderID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/folders/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(folderID)) payload := make(map[string]any, 4) // Pre-allocate for up to 4 fields if req.Name != "" { @@ -149,7 +150,7 @@ func (c *HTTPClient) DeleteFolder(ctx context.Context, grantID, folderID string) if err := validateRequired("folder ID", folderID); err != nil { return err } - queryURL := fmt.Sprintf("%s/v3/grants/%s/folders/%s", c.baseURL, grantID, folderID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/folders/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(folderID)) return c.doDelete(ctx, queryURL) } diff --git a/internal/adapters/nylas/managed_grants.go b/internal/adapters/nylas/managed_grants.go index f493ef3..212259e 100644 --- a/internal/adapters/nylas/managed_grants.go +++ b/internal/adapters/nylas/managed_grants.go @@ -3,6 +3,7 @@ package nylas import ( "context" "fmt" + "net/url" "github.com/nylas/cli/internal/domain" ) @@ -23,20 +24,27 @@ type agentSettingsPayload struct { PolicyID string `json:"policy_id,omitempty"` } +// listManagedGrants returns every grant whose provider matches `provider`. 
+// +// We deliberately do NOT pass `provider=` as a server-side filter: +// the server-side filtered listing has been observed to lag freshly- +// created managed grants by tens of seconds (>70s in the worst case), +// while the unfiltered listing surfaces new grants almost immediately. +// Trade ~4x more page bytes (typical tenants have <100 grants) for +// freshness and predictability. We filter to `provider` client-side. func (c *HTTPClient) listManagedGrants(ctx context.Context, provider domain.Provider) ([]managedGrantResponse, error) { baseURL := fmt.Sprintf("%s/v3/grants", c.baseURL) - pageToken := "" + offset := 0 grants := make([]managedGrantResponse, 0) - for { + for range maxGrantPages { queryURL := NewQueryBuilder(). - Add("provider", string(provider)). - Add("page_token", pageToken). + AddInt("limit", grantPageSize). + AddInt("offset", offset). BuildURL(baseURL) var result struct { - Data []managedGrantResponse `json:"data"` - NextCursor string `json:"next_cursor,omitempty"` + Data []managedGrantResponse `json:"data"` } if err := c.doGet(ctx, queryURL, &result); err != nil { return nil, err @@ -48,20 +56,16 @@ func (c *HTTPClient) listManagedGrants(ctx context.Context, provider domain.Prov } } - if result.NextCursor == "" { - break + if len(result.Data) < grantPageSize { + return grants, nil } - if result.NextCursor == pageToken { - return nil, fmt.Errorf("failed to paginate managed grants: repeated cursor %q", result.NextCursor) - } - pageToken = result.NextCursor + offset += len(result.Data) } - - return grants, nil + return nil, fmt.Errorf("failed to paginate managed grants: exceeded max page count (%d)", maxGrantPages) } func (c *HTTPClient) getManagedGrant(ctx context.Context, grantID string) (*managedGrantResponse, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s", c.baseURL, grantID) + queryURL := fmt.Sprintf("%s/v3/grants/%s", c.baseURL, url.PathEscape(grantID)) var result struct { Data managedGrantResponse `json:"data"` diff --git 
a/internal/adapters/nylas/messages.go b/internal/adapters/nylas/messages.go index c6d5233..32948f7 100644 --- a/internal/adapters/nylas/messages.go +++ b/internal/adapters/nylas/messages.go @@ -4,6 +4,7 @@ import ( "context" "encoding/base64" "fmt" + "net/url" "strings" "time" @@ -94,7 +95,7 @@ func (c *HTTPClient) GetMessagesWithCursor(ctx context.Context, grantID string, params.Limit = 10 } - baseURL := fmt.Sprintf("%s/v3/grants/%s/messages", c.baseURL, grantID) + baseURL := fmt.Sprintf("%s/v3/grants/%s/messages", c.baseURL, url.PathEscape(grantID)) queryURL := NewQueryBuilder(). AddInt("limit", params.Limit). Add("page_token", params.PageToken). @@ -147,7 +148,7 @@ func (c *HTTPClient) GetMessageWithFields(ctx context.Context, grantID, messageI return nil, err } - baseURL := fmt.Sprintf("%s/v3/grants/%s/messages/%s", c.baseURL, grantID, messageID) + baseURL := fmt.Sprintf("%s/v3/grants/%s/messages/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(messageID)) queryURL := NewQueryBuilder().Add("fields", fields).BuildURL(baseURL) var result struct { @@ -163,7 +164,7 @@ func (c *HTTPClient) GetMessageWithFields(ctx context.Context, grantID, messageI // UpdateMessage updates message properties. func (c *HTTPClient) UpdateMessage(ctx context.Context, grantID, messageID string, req *domain.UpdateMessageRequest) (*domain.Message, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/messages/%s", c.baseURL, grantID, messageID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/messages/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(messageID)) payload := make(map[string]any, 3) // Pre-allocate for up to 3 fields if req.Unread != nil { @@ -194,7 +195,7 @@ func (c *HTTPClient) UpdateMessage(ctx context.Context, grantID, messageID strin // DeleteMessage deletes a message (moves to trash). 
func (c *HTTPClient) DeleteMessage(ctx context.Context, grantID, messageID string) error { - queryURL := fmt.Sprintf("%s/v3/grants/%s/messages/%s", c.baseURL, grantID, messageID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/messages/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(messageID)) return c.doDelete(ctx, queryURL) } diff --git a/internal/adapters/nylas/messages_send.go b/internal/adapters/nylas/messages_send.go index d020c12..7cd4243 100644 --- a/internal/adapters/nylas/messages_send.go +++ b/internal/adapters/nylas/messages_send.go @@ -8,6 +8,7 @@ import ( "io" "mime/multipart" "net/http" + "net/url" "github.com/nylas/cli/internal/domain" "github.com/nylas/cli/internal/util" @@ -53,7 +54,7 @@ func buildSendMessagePayload(req *domain.SendMessageRequest, includeSignature bo // SendMessage sends an email. func (c *HTTPClient) SendMessage(ctx context.Context, grantID string, req *domain.SendMessageRequest) (*domain.Message, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/messages/send", c.baseURL, grantID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/messages/send", c.baseURL, url.PathEscape(grantID)) resp, err := c.doJSONRequest(ctx, "POST", queryURL, buildSendMessagePayload(req, true), http.StatusOK, http.StatusCreated, http.StatusAccepted) if err != nil { @@ -74,7 +75,7 @@ func (c *HTTPClient) SendMessage(ctx context.Context, grantID string, req *domai // SendRawMessage sends a raw RFC 822 MIME message via the Nylas API. // Uses multipart/form-data with type=mime query parameter per Nylas API v3 spec. 
func (c *HTTPClient) SendRawMessage(ctx context.Context, grantID string, rawMIME []byte) (*domain.Message, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/messages/send?type=mime", c.baseURL, grantID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/messages/send?type=mime", c.baseURL, url.PathEscape(grantID)) // Create multipart form data var body bytes.Buffer @@ -106,10 +107,10 @@ func (c *HTTPClient) SendRawMessage(ctx context.Context, grantID string, rawMIME if err != nil { return nil, fmt.Errorf("%w: %v", domain.ErrNetworkError, err) } - defer func() { _ = resp.Body.Close() }() // Check status code if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusAccepted { + defer func() { _ = resp.Body.Close() }() return nil, c.parseError(resp) } @@ -117,7 +118,7 @@ func (c *HTTPClient) SendRawMessage(ctx context.Context, grantID string, rawMIME var result struct { Data messageResponse `json:"data"` } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + if err := c.decodeJSONResponse(resp, &result); err != nil { return nil, err } @@ -127,7 +128,7 @@ func (c *HTTPClient) SendRawMessage(ctx context.Context, grantID string, rawMIME // ListScheduledMessages retrieves all scheduled messages for a grant. 
func (c *HTTPClient) ListScheduledMessages(ctx context.Context, grantID string) ([]domain.ScheduledMessage, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/messages/schedules", c.baseURL, grantID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/messages/schedules", c.baseURL, url.PathEscape(grantID)) req, err := http.NewRequestWithContext(ctx, "GET", queryURL, nil) if err != nil { @@ -139,9 +140,9 @@ func (c *HTTPClient) ListScheduledMessages(ctx context.Context, grantID string) if err != nil { return nil, fmt.Errorf("%w: %v", domain.ErrNetworkError, err) } - defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { + defer func() { _ = resp.Body.Close() }() return nil, c.parseError(resp) } @@ -155,7 +156,7 @@ func (c *HTTPClient) ListScheduledMessages(ctx context.Context, grantID string) CloseTime int64 `json:"close_time"` } `json:"data"` } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + if err := c.decodeJSONResponse(resp, &result); err != nil { return nil, err } @@ -177,7 +178,7 @@ func (c *HTTPClient) ListScheduledMessages(ctx context.Context, grantID string) // GetScheduledMessage retrieves a specific scheduled message. 
func (c *HTTPClient) GetScheduledMessage(ctx context.Context, grantID, scheduleID string) (*domain.ScheduledMessage, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/messages/schedules/%s", c.baseURL, grantID, scheduleID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/messages/schedules/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(scheduleID)) req, err := http.NewRequestWithContext(ctx, "GET", queryURL, nil) if err != nil { @@ -189,12 +190,13 @@ func (c *HTTPClient) GetScheduledMessage(ctx context.Context, grantID, scheduleI if err != nil { return nil, fmt.Errorf("%w: %v", domain.ErrNetworkError, err) } - defer func() { _ = resp.Body.Close() }() if resp.StatusCode == http.StatusNotFound { + _ = resp.Body.Close() return nil, domain.ErrMessageNotFound } if resp.StatusCode != http.StatusOK { + defer func() { _ = resp.Body.Close() }() return nil, c.parseError(resp) } @@ -208,7 +210,7 @@ func (c *HTTPClient) GetScheduledMessage(ctx context.Context, grantID, scheduleI CloseTime int64 `json:"close_time"` } `json:"data"` } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + if err := c.decodeJSONResponse(resp, &result); err != nil { return nil, err } @@ -221,7 +223,7 @@ func (c *HTTPClient) GetScheduledMessage(ctx context.Context, grantID, scheduleI // CancelScheduledMessage cancels a scheduled message. func (c *HTTPClient) CancelScheduledMessage(ctx context.Context, grantID, scheduleID string) error { - queryURL := fmt.Sprintf("%s/v3/grants/%s/messages/schedules/%s", c.baseURL, grantID, scheduleID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/messages/schedules/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(scheduleID)) req, err := http.NewRequestWithContext(ctx, "DELETE", queryURL, nil) if err != nil { @@ -245,7 +247,7 @@ func (c *HTTPClient) CancelScheduledMessage(ctx context.Context, grantID, schedu // SmartCompose generates an AI-powered email draft based on a prompt. // Uses Nylas Smart Compose API (requires Plus package). 
func (c *HTTPClient) SmartCompose(ctx context.Context, grantID string, req *domain.SmartComposeRequest) (*domain.SmartComposeSuggestion, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/messages/smart-compose", c.baseURL, grantID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/messages/smart-compose", c.baseURL, url.PathEscape(grantID)) payload := map[string]any{ "prompt": req.Prompt, @@ -268,16 +270,16 @@ func (c *HTTPClient) SmartCompose(ctx context.Context, grantID string, req *doma if err != nil { return nil, fmt.Errorf("%w: %v", domain.ErrNetworkError, err) } - defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { + defer func() { _ = resp.Body.Close() }() return nil, c.parseError(resp) } var result struct { Data domain.SmartComposeSuggestion `json:"data"` } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + if err := c.decodeJSONResponse(resp, &result); err != nil { return nil, err } @@ -287,7 +289,7 @@ func (c *HTTPClient) SmartCompose(ctx context.Context, grantID string, req *doma // SmartComposeReply generates an AI-powered reply to a specific message. // Uses Nylas Smart Compose API (requires Plus package). 
func (c *HTTPClient) SmartComposeReply(ctx context.Context, grantID, messageID string, req *domain.SmartComposeRequest) (*domain.SmartComposeSuggestion, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/messages/%s/smart-compose", c.baseURL, grantID, messageID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/messages/%s/smart-compose", c.baseURL, url.PathEscape(grantID), url.PathEscape(messageID)) payload := map[string]any{ "prompt": req.Prompt, @@ -310,16 +312,16 @@ func (c *HTTPClient) SmartComposeReply(ctx context.Context, grantID, messageID s if err != nil { return nil, fmt.Errorf("%w: %v", domain.ErrNetworkError, err) } - defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { + defer func() { _ = resp.Body.Close() }() return nil, c.parseError(resp) } var result struct { Data domain.SmartComposeSuggestion `json:"data"` } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + if err := c.decodeJSONResponse(resp, &result); err != nil { return nil, err } diff --git a/internal/adapters/nylas/notetakers.go b/internal/adapters/nylas/notetakers.go index 55e2c18..0561f02 100644 --- a/internal/adapters/nylas/notetakers.go +++ b/internal/adapters/nylas/notetakers.go @@ -3,6 +3,7 @@ package nylas import ( "context" "fmt" + "net/url" "time" "github.com/nylas/cli/internal/domain" @@ -52,7 +53,7 @@ func (c *HTTPClient) ListNotetakers(ctx context.Context, grantID string, params params.Limit = 10 } - baseURL := fmt.Sprintf("%s/v3/grants/%s/notetakers", c.baseURL, grantID) + baseURL := fmt.Sprintf("%s/v3/grants/%s/notetakers", c.baseURL, url.PathEscape(grantID)) queryURL := NewQueryBuilder(). AddInt("limit", params.Limit). Add("page_token", params.PageToken). @@ -71,7 +72,7 @@ func (c *HTTPClient) ListNotetakers(ctx context.Context, grantID string, params // GetNotetaker retrieves a single notetaker by ID. 
func (c *HTTPClient) GetNotetaker(ctx context.Context, grantID, notetakerID string) (*domain.Notetaker, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/notetakers/%s", c.baseURL, grantID, notetakerID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/notetakers/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(notetakerID)) var result struct { Data notetakerResponse `json:"data"` @@ -86,7 +87,7 @@ func (c *HTTPClient) GetNotetaker(ctx context.Context, grantID, notetakerID stri // CreateNotetaker creates a new notetaker to join a meeting. func (c *HTTPClient) CreateNotetaker(ctx context.Context, grantID string, req *domain.CreateNotetakerRequest) (*domain.Notetaker, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/notetakers", c.baseURL, grantID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/notetakers", c.baseURL, url.PathEscape(grantID)) payload := map[string]any{ "meeting_link": req.MeetingLink, @@ -126,13 +127,13 @@ func (c *HTTPClient) CreateNotetaker(ctx context.Context, grantID string, req *d // DeleteNotetaker deletes/cancels a notetaker. func (c *HTTPClient) DeleteNotetaker(ctx context.Context, grantID, notetakerID string) error { - queryURL := fmt.Sprintf("%s/v3/grants/%s/notetakers/%s", c.baseURL, grantID, notetakerID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/notetakers/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(notetakerID)) return c.doDelete(ctx, queryURL) } // GetNotetakerMedia retrieves the media (recording/transcript) for a notetaker. 
func (c *HTTPClient) GetNotetakerMedia(ctx context.Context, grantID, notetakerID string) (*domain.MediaData, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/notetakers/%s/media", c.baseURL, grantID, notetakerID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/notetakers/%s/media", c.baseURL, url.PathEscape(grantID), url.PathEscape(notetakerID)) var result struct { Data struct { diff --git a/internal/adapters/nylas/scheduler.go b/internal/adapters/nylas/scheduler.go index 65ac5bb..58d29c2 100644 --- a/internal/adapters/nylas/scheduler.go +++ b/internal/adapters/nylas/scheduler.go @@ -30,7 +30,7 @@ func (c *HTTPClient) GetSchedulerConfiguration(ctx context.Context, configID str return nil, err } - queryURL := fmt.Sprintf("%s/v3/scheduling/configurations/%s", c.baseURL, configID) + queryURL := fmt.Sprintf("%s/v3/scheduling/configurations/%s", c.baseURL, url.PathEscape(configID)) var result struct { Data domain.SchedulerConfiguration `json:"data"` @@ -65,7 +65,7 @@ func (c *HTTPClient) UpdateSchedulerConfiguration(ctx context.Context, configID return nil, err } - queryURL := fmt.Sprintf("%s/v3/scheduling/configurations/%s", c.baseURL, configID) + queryURL := fmt.Sprintf("%s/v3/scheduling/configurations/%s", c.baseURL, url.PathEscape(configID)) resp, err := c.doJSONRequest(ctx, "PUT", queryURL, req) if err != nil { @@ -86,7 +86,7 @@ func (c *HTTPClient) DeleteSchedulerConfiguration(ctx context.Context, configID if err := validateRequired("configuration ID", configID); err != nil { return err } - queryURL := fmt.Sprintf("%s/v3/scheduling/configurations/%s", c.baseURL, configID) + queryURL := fmt.Sprintf("%s/v3/scheduling/configurations/%s", c.baseURL, url.PathEscape(configID)) return c.doDelete(ctx, queryURL) } @@ -116,7 +116,7 @@ func (c *HTTPClient) GetSchedulerSession(ctx context.Context, sessionID string) return nil, err } - queryURL := fmt.Sprintf("%s/v3/scheduling/sessions/%s", c.baseURL, sessionID) + queryURL := fmt.Sprintf("%s/v3/scheduling/sessions/%s", 
c.baseURL, url.PathEscape(sessionID)) var result struct { Data domain.SchedulerSession `json:"data"` @@ -149,7 +149,7 @@ func (c *HTTPClient) GetBooking(ctx context.Context, bookingID string) (*domain. return nil, err } - queryURL := fmt.Sprintf("%s/v3/scheduling/bookings/%s", c.baseURL, bookingID) + queryURL := fmt.Sprintf("%s/v3/scheduling/bookings/%s", c.baseURL, url.PathEscape(bookingID)) var result struct { Data domain.Booking `json:"data"` @@ -166,7 +166,7 @@ func (c *HTTPClient) ConfirmBooking(ctx context.Context, bookingID string, req * return nil, err } - queryURL := fmt.Sprintf("%s/v3/scheduling/bookings/%s", c.baseURL, bookingID) + queryURL := fmt.Sprintf("%s/v3/scheduling/bookings/%s", c.baseURL, url.PathEscape(bookingID)) resp, err := c.doJSONRequest(ctx, "PUT", queryURL, req) if err != nil { @@ -243,7 +243,7 @@ func (c *HTTPClient) GetSchedulerPage(ctx context.Context, pageID string) (*doma return nil, err } - queryURL := fmt.Sprintf("%s/v3/scheduling/pages/%s", c.baseURL, pageID) + queryURL := fmt.Sprintf("%s/v3/scheduling/pages/%s", c.baseURL, url.PathEscape(pageID)) var result struct { Data domain.SchedulerPage `json:"data"` @@ -278,7 +278,7 @@ func (c *HTTPClient) UpdateSchedulerPage(ctx context.Context, pageID string, req return nil, err } - queryURL := fmt.Sprintf("%s/v3/scheduling/pages/%s", c.baseURL, pageID) + queryURL := fmt.Sprintf("%s/v3/scheduling/pages/%s", c.baseURL, url.PathEscape(pageID)) resp, err := c.doJSONRequest(ctx, "PUT", queryURL, req) if err != nil { @@ -299,6 +299,6 @@ func (c *HTTPClient) DeleteSchedulerPage(ctx context.Context, pageID string) err if err := validateRequired("page ID", pageID); err != nil { return err } - queryURL := fmt.Sprintf("%s/v3/scheduling/pages/%s", c.baseURL, pageID) + queryURL := fmt.Sprintf("%s/v3/scheduling/pages/%s", c.baseURL, url.PathEscape(pageID)) return c.doDelete(ctx, queryURL) } diff --git a/internal/adapters/nylas/signatures.go b/internal/adapters/nylas/signatures.go index 
a83722d..1145d86 100644 --- a/internal/adapters/nylas/signatures.go +++ b/internal/adapters/nylas/signatures.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "net/url" "time" "github.com/nylas/cli/internal/domain" @@ -20,7 +21,7 @@ type signatureResponse struct { // GetSignatures retrieves all signatures for a grant. func (c *HTTPClient) GetSignatures(ctx context.Context, grantID string) ([]domain.Signature, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/signatures", c.baseURL, grantID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/signatures", c.baseURL, url.PathEscape(grantID)) var result struct { Data []signatureResponse `json:"data"` @@ -34,7 +35,7 @@ func (c *HTTPClient) GetSignatures(ctx context.Context, grantID string) ([]domai // GetSignature retrieves a specific signature. func (c *HTTPClient) GetSignature(ctx context.Context, grantID, signatureID string) (*domain.Signature, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/signatures/%s", c.baseURL, grantID, signatureID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/signatures/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(signatureID)) var result struct { Data signatureResponse `json:"data"` @@ -49,7 +50,7 @@ func (c *HTTPClient) GetSignature(ctx context.Context, grantID, signatureID stri // CreateSignature creates a new signature. func (c *HTTPClient) CreateSignature(ctx context.Context, grantID string, req *domain.CreateSignatureRequest) (*domain.Signature, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/signatures", c.baseURL, grantID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/signatures", c.baseURL, url.PathEscape(grantID)) resp, err := c.doJSONRequestNoRetry(ctx, http.MethodPost, queryURL, req) if err != nil { @@ -69,7 +70,7 @@ func (c *HTTPClient) CreateSignature(ctx context.Context, grantID string, req *d // UpdateSignature updates an existing signature. 
func (c *HTTPClient) UpdateSignature(ctx context.Context, grantID, signatureID string, req *domain.UpdateSignatureRequest) (*domain.Signature, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/signatures/%s", c.baseURL, grantID, signatureID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/signatures/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(signatureID)) resp, err := c.doJSONRequestNoRetry(ctx, http.MethodPut, queryURL, req, http.StatusOK) if err != nil { @@ -89,7 +90,7 @@ func (c *HTTPClient) UpdateSignature(ctx context.Context, grantID, signatureID s // DeleteSignature deletes a signature. func (c *HTTPClient) DeleteSignature(ctx context.Context, grantID, signatureID string) error { - queryURL := fmt.Sprintf("%s/v3/grants/%s/signatures/%s", c.baseURL, grantID, signatureID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/signatures/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(signatureID)) return c.doDelete(ctx, queryURL) } diff --git a/internal/adapters/nylas/threads.go b/internal/adapters/nylas/threads.go index 3c6c027..be4d878 100644 --- a/internal/adapters/nylas/threads.go +++ b/internal/adapters/nylas/threads.go @@ -3,6 +3,7 @@ package nylas import ( "context" "fmt" + "net/url" "time" "github.com/nylas/cli/internal/domain" @@ -40,7 +41,7 @@ func (c *HTTPClient) GetThreads(ctx context.Context, grantID string, params *dom params.Limit = 10 } - baseURL := fmt.Sprintf("%s/v3/grants/%s/threads", c.baseURL, grantID) + baseURL := fmt.Sprintf("%s/v3/grants/%s/threads", c.baseURL, url.PathEscape(grantID)) queryURL := NewQueryBuilder(). AddInt("limit", params.Limit). AddInt("offset", params.Offset). @@ -64,7 +65,7 @@ func (c *HTTPClient) GetThreads(ctx context.Context, grantID string, params *dom // GetThread retrieves a single thread by ID. 
func (c *HTTPClient) GetThread(ctx context.Context, grantID, threadID string) (*domain.Thread, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/threads/%s", c.baseURL, grantID, threadID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/threads/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(threadID)) var result struct { Data threadResponse `json:"data"` @@ -79,7 +80,7 @@ func (c *HTTPClient) GetThread(ctx context.Context, grantID, threadID string) (* // UpdateThread updates thread properties. func (c *HTTPClient) UpdateThread(ctx context.Context, grantID, threadID string, req *domain.UpdateMessageRequest) (*domain.Thread, error) { - queryURL := fmt.Sprintf("%s/v3/grants/%s/threads/%s", c.baseURL, grantID, threadID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/threads/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(threadID)) payload := make(map[string]any) if req.Unread != nil { @@ -113,7 +114,7 @@ func (c *HTTPClient) DeleteThread(ctx context.Context, grantID, threadID string) if err := validateRequired("thread ID", threadID); err != nil { return err } - queryURL := fmt.Sprintf("%s/v3/grants/%s/threads/%s", c.baseURL, grantID, threadID) + queryURL := fmt.Sprintf("%s/v3/grants/%s/threads/%s", c.baseURL, url.PathEscape(grantID), url.PathEscape(threadID)) return c.doDelete(ctx, queryURL) } diff --git a/internal/adapters/nylas/transactional.go b/internal/adapters/nylas/transactional.go index 3f9bea0..16392c1 100644 --- a/internal/adapters/nylas/transactional.go +++ b/internal/adapters/nylas/transactional.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "net/url" "github.com/nylas/cli/internal/domain" ) @@ -24,7 +25,7 @@ func buildTransactionalSendPayload(req *domain.SendMessageRequest) map[string]an // SendTransactionalMessage sends an email via the domain-based transactional endpoint. 
// Used for managed Nylas grants: POST /v3/domains/{domain}/messages/send func (c *HTTPClient) SendTransactionalMessage(ctx context.Context, domainName string, req *domain.SendMessageRequest) (*domain.Message, error) { - queryURL := fmt.Sprintf("%s/v3/domains/%s/messages/send", c.baseURL, domainName) + queryURL := fmt.Sprintf("%s/v3/domains/%s/messages/send", c.baseURL, url.PathEscape(domainName)) resp, err := c.doJSONRequest(ctx, "POST", queryURL, buildTransactionalSendPayload(req), http.StatusOK, http.StatusCreated, http.StatusAccepted) if err != nil { diff --git a/internal/adapters/nylas/webhooks.go b/internal/adapters/nylas/webhooks.go index 67d2d72..3e23e7b 100644 --- a/internal/adapters/nylas/webhooks.go +++ b/internal/adapters/nylas/webhooks.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "net/url" "time" "github.com/nylas/cli/internal/domain" @@ -44,7 +45,7 @@ func (c *HTTPClient) GetWebhook(ctx context.Context, webhookID string) (*domain. return nil, err } - queryURL := fmt.Sprintf("%s/v3/webhooks/%s", c.baseURL, webhookID) + queryURL := fmt.Sprintf("%s/v3/webhooks/%s", c.baseURL, url.PathEscape(webhookID)) var result struct { Data webhookResponse `json:"data"` @@ -83,7 +84,7 @@ func (c *HTTPClient) UpdateWebhook(ctx context.Context, webhookID string, req *d return nil, err } - queryURL := fmt.Sprintf("%s/v3/webhooks/%s", c.baseURL, webhookID) + queryURL := fmt.Sprintf("%s/v3/webhooks/%s", c.baseURL, url.PathEscape(webhookID)) resp, err := c.doJSONRequest(ctx, "PUT", queryURL, req) if err != nil { @@ -106,7 +107,7 @@ func (c *HTTPClient) DeleteWebhook(ctx context.Context, webhookID string) error if err := validateRequired("webhook ID", webhookID); err != nil { return err } - queryURL := fmt.Sprintf("%s/v3/webhooks/%s", c.baseURL, webhookID) + queryURL := fmt.Sprintf("%s/v3/webhooks/%s", c.baseURL, url.PathEscape(webhookID)) return c.doDelete(ctx, queryURL) } @@ -119,7 +120,7 @@ func (c *HTTPClient) RotateWebhookSecret( return nil, err } - queryURL 
:= fmt.Sprintf("%s/v3/webhooks/rotate-secret/%s", c.baseURL, webhookID) + queryURL := fmt.Sprintf("%s/v3/webhooks/rotate-secret/%s", c.baseURL, url.PathEscape(webhookID)) resp, err := c.doJSONRequest(ctx, "POST", queryURL, nil) if err != nil { diff --git a/internal/adapters/oauth/server.go b/internal/adapters/oauth/server.go index af9898d..23a6887 100644 --- a/internal/adapters/oauth/server.go +++ b/internal/adapters/oauth/server.go @@ -4,9 +4,11 @@ package oauth import ( "context" "crypto/subtle" + "errors" "fmt" "net" "net/http" + "strings" "sync" "time" @@ -15,14 +17,15 @@ import ( // CallbackServer implements the OAuth callback server. type CallbackServer struct { - port int - server *http.Server - listener net.Listener - codeChan chan string - errChan chan error - once sync.Once - mu sync.RWMutex - state string + port int + server *http.Server + listener net.Listener + listeners []net.Listener + codeChan chan string + errChan chan error + once sync.Once + mu sync.RWMutex + state string } // NewCallbackServer creates a new callback server. @@ -52,21 +55,73 @@ func (s *CallbackServer) Start() error { IdleTimeout: 60 * time.Second, } - var err error - s.listener, err = net.Listen("tcp", fmt.Sprintf(":%d", s.port)) + // Bind to loopback only. The browser follows the advertised localhost + // redirect URI, which can resolve to either IPv4 or IPv6 loopback + // depending on host configuration. Listen on both loopback families when + // available without accepting LAN traffic. 
+ listeners, port, err := listenLoopback(s.port) if err != nil { return fmt.Errorf("failed to start callback server: %w", err) } + s.port = port + s.listeners = listeners + s.listener = listeners[0] - go func() { - if err := s.server.Serve(s.listener); err != http.ErrServerClosed { - s.errChan <- err - } - }() + for _, listener := range listeners { + go s.serve(listener) + } return nil } +func listenLoopback(port int) ([]net.Listener, int, error) { + ipv4, err := net.Listen("tcp4", fmt.Sprintf("127.0.0.1:%d", port)) + if err != nil { + return nil, 0, err + } + + actualPort := port + if actualPort == 0 { + tcpAddr, ok := ipv4.Addr().(*net.TCPAddr) + if !ok { + _ = ipv4.Close() + return nil, 0, fmt.Errorf("unexpected listener address type %T", ipv4.Addr()) + } + actualPort = tcpAddr.Port + } + + listeners := []net.Listener{ipv4} + ipv6, err := net.Listen("tcp6", fmt.Sprintf("[::1]:%d", actualPort)) + if err != nil { + if !isIPv6LoopbackUnavailable(err) { + _ = ipv4.Close() + return nil, 0, fmt.Errorf("failed to start IPv6 callback listener on port %d: %w", actualPort, err) + } + return listeners, actualPort, nil + } + + return append(listeners, ipv6), actualPort, nil +} + +func isIPv6LoopbackUnavailable(err error) bool { + msg := strings.ToLower(err.Error()) + return strings.Contains(msg, "address family not supported") || + strings.Contains(msg, "cannot assign requested address") || + strings.Contains(msg, "can't assign requested address") || + strings.Contains(msg, "protocol not available") +} + +func (s *CallbackServer) serve(listener net.Listener) { + if err := s.server.Serve(listener); err != nil && + !errors.Is(err, http.ErrServerClosed) && + !errors.Is(err, net.ErrClosed) { + select { + case s.errChan <- err: + default: + } + } +} + // Stop stops the callback server. 
func (s *CallbackServer) Stop() error { if s.server != nil { diff --git a/internal/adapters/oauth/server_test.go b/internal/adapters/oauth/server_test.go index 3be99a9..d95d1fc 100644 --- a/internal/adapters/oauth/server_test.go +++ b/internal/adapters/oauth/server_test.go @@ -3,8 +3,10 @@ package oauth import ( "context" "errors" + "net" "net/http" "net/http/httptest" + "strconv" "testing" "time" @@ -59,6 +61,37 @@ func TestCallbackServer_GetRedirectURI(t *testing.T) { } } +func TestCallbackServer_StartAcceptsIPv6LoopbackForAdvertisedLocalhost(t *testing.T) { + probe, err := net.Listen("tcp6", "[::1]:0") + if err != nil { + t.Skipf("IPv6 loopback is not available: %v", err) + } + _ = probe.Close() + + server := NewCallbackServer(0) + if err := server.Start(); err != nil { + t.Fatalf("Start() error = %v", err) + } + defer func() { _ = server.Stop() }() + server.setExpectedState("test-state") + + client := &http.Client{ + Timeout: time.Second, + Transport: &http.Transport{ + Proxy: nil, + }, + } + + resp, err := client.Get("http://[::1]:" + strconv.Itoa(server.port) + "/callback?code=test-code&state=test-state") + if err != nil { + t.Fatalf("IPv6 loopback callback request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + t.Fatalf("IPv6 loopback callback status = %d, want %d", resp.StatusCode, http.StatusOK) + } +} + func TestCallbackServer_handleCallback_Success(t *testing.T) { server := NewCallbackServer(8080) server.setExpectedState("test-state-123") diff --git a/internal/adapters/webhookserver/server.go b/internal/adapters/webhookserver/server.go index 459eeb7..8ee72bb 100644 --- a/internal/adapters/webhookserver/server.go +++ b/internal/adapters/webhookserver/server.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "html/template" "io" "net" "net/http" @@ -14,17 +15,76 @@ import ( "github.com/nylas/cli/internal/ports" ) +const ( + // maxWebhookBodyBytes caps the request body size accepted by 
the webhook
+	// receiver. Nylas events are well under 100 KB; the cap exists to prevent
+	// a malicious sender (the public tunnel URL is reachable by anyone) from
+	// asking the server to allocate gigabytes of RAM before HMAC verifies.
+	maxWebhookBodyBytes = 1 << 20 // 1 MiB
+
+	// maxConcurrentHandlers bounds the goroutines that fan out registered
+	// handlers. Without a bound, a flood of events combined with slow handlers
+	// would let an attacker drive unbounded goroutine creation.
+	maxConcurrentHandlers = 32
+)
+
+// LocalBaseURL returns the loopback URL used by the webhook receiver and
+// local tunnels. It intentionally uses IPv4 loopback because the server binds
+// to 127.0.0.1, not localhost's platform-dependent IPv4/IPv6 resolution.
+func LocalBaseURL(port int) string {
+	return fmt.Sprintf("http://127.0.0.1:%d", port)
+}
+
+func localEndpointURL(port int, path string) string {
+	return LocalBaseURL(port) + path
+}
+
+// rootTemplate renders the webhook server landing page. html/template
+// HTML-escapes every value, preventing PublicURL (and any future
+// user-derived fields) from breaking out of the document.
+var rootTemplate = template.Must(template.New("root").Parse(`<!DOCTYPE html>
+<html>
+<head><title>Nylas Webhook Server</title></head>
+<body>
+<h1>Nylas Webhook Server</h1>
+<p>Status: Running</p>
+<p>Events Received: {{.EventsReceived}}</p>
+<p>Started: {{.StartedAt}}</p>
+<h2>Webhook Endpoint</h2>
+<p>{{.PublicURL}}</p>
+<p>Send POST requests to this URL to receive webhook events.</p>
+<h2>Health Check</h2>
+<p>
{{.HealthURL}}
+ +`)) + // Server implements the WebhookServer interface. type Server struct { - config ports.WebhookServerConfig - server *http.Server - listener net.Listener - tunnel ports.Tunnel - events chan *ports.WebhookEvent - handlers []ports.WebhookEventHandler - stats ports.WebhookServerStats - mu sync.RWMutex - startedAt time.Time + config ports.WebhookServerConfig + server *http.Server + listener net.Listener + tunnel ports.Tunnel + events chan *ports.WebhookEvent + handlers []ports.WebhookEventHandler + handlerSlots chan struct{} + seenSignatures map[string]time.Time + stats ports.WebhookServerStats + mu sync.RWMutex + startedAt time.Time + closeOnce sync.Once } // NewServer creates a new webhook server. @@ -37,10 +97,12 @@ func NewServer(config ports.WebhookServerConfig) *Server { } return &Server{ - config: config, - events: make(chan *ports.WebhookEvent, 100), + config: config, + events: make(chan *ports.WebhookEvent, 100), + handlerSlots: make(chan struct{}, maxConcurrentHandlers), + seenSignatures: make(map[string]time.Time), stats: ports.WebhookServerStats{ - LocalURL: fmt.Sprintf("http://localhost:%d%s", config.Port, config.Path), + LocalURL: localEndpointURL(config.Port, config.Path), }, } } @@ -66,9 +128,12 @@ func (s *Server) Start(ctx context.Context) error { ReadHeaderTimeout: 10 * time.Second, } - // Start listener + // Start listener bound to loopback only. Tunnels (cloudflared, ngrok) + // connect to 127.0.0.1 — there is no use case for accepting webhooks + // directly from the LAN, and binding to 0.0.0.0 would let any host on + // the local network forge webhook events. 
var err error - s.listener, err = net.Listen("tcp", fmt.Sprintf(":%d", s.config.Port)) + s.listener, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", s.config.Port)) if err != nil { return fmt.Errorf("failed to start listener on port %d: %w", s.config.Port, err) } @@ -76,7 +141,7 @@ func (s *Server) Start(ctx context.Context) error { s.startedAt = time.Now() s.mu.Lock() s.stats.StartedAt = s.startedAt - s.stats.LocalURL = fmt.Sprintf("http://localhost:%d%s", s.config.Port, s.config.Path) + s.stats.LocalURL = localEndpointURL(s.config.Port, s.config.Path) s.mu.Unlock() // Start HTTP server in goroutine @@ -86,7 +151,6 @@ func (s *Server) Start(ctx context.Context) error { // Start tunnel if configured if s.tunnel != nil { - localURL := fmt.Sprintf("http://localhost:%d", s.config.Port) publicURL, err := s.tunnel.Start(ctx) if err != nil { _ = s.Stop() // Ignore stop error - we're returning tunnel start error @@ -98,35 +162,38 @@ func (s *Server) Start(ctx context.Context) error { s.stats.TunnelProvider = s.config.TunnelProvider s.stats.TunnelStatus = string(s.tunnel.Status()) s.mu.Unlock() - - _ = localURL // used by tunnel } return nil } -// Stop stops the webhook server and tunnel. +// Stop stops the webhook server and tunnel. Safe to call more than once — +// the channel is closed under sync.Once, and the events channel is only +// closed after http.Server.Shutdown returns (which waits for in-flight +// handlers to complete) so producers cannot race the close. func (s *Server) Stop() error { var errs []error - // Stop tunnel first - if s.tunnel != nil { - if err := s.tunnel.Stop(); err != nil { - errs = append(errs, fmt.Errorf("tunnel stop: %w", err)) + s.closeOnce.Do(func() { + // Stop tunnel first. 
+ if s.tunnel != nil { + if err := s.tunnel.Stop(); err != nil { + errs = append(errs, fmt.Errorf("tunnel stop: %w", err)) + } } - } - // Stop HTTP server - if s.server != nil { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - if err := s.server.Shutdown(ctx); err != nil { - errs = append(errs, fmt.Errorf("server shutdown: %w", err)) + // Stop HTTP server. Shutdown blocks until all in-flight requests have + // finished, so when it returns no handler is sending to s.events. + if s.server != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := s.server.Shutdown(ctx); err != nil { + errs = append(errs, fmt.Errorf("server shutdown: %w", err)) + } } - } - // Close events channel - close(s.events) + close(s.events) + }) if len(errs) > 0 { return errs[0] @@ -192,10 +259,17 @@ func (s *Server) handleWebhook(w http.ResponseWriter, r *http.Request) { return } - // Read body + // Cap request body size so a malicious sender on a public tunnel can't + // drive unbounded RAM allocation. MaxBytesReader closes the body and + // returns an error from ReadAll once the limit is exceeded. + r.Body = http.MaxBytesReader(w, r.Body, maxWebhookBodyBytes) body, err := io.ReadAll(r.Body) if err != nil { - http.Error(w, "Failed to read body", http.StatusBadRequest) + // MaxBytesReader returns *http.MaxBytesError once the cap is hit; any + // other read error (timeout, connection reset) is also surfaced as a + // 413 to keep the response simple — the client cannot recover either + // way. + http.Error(w, "Request body too large or unreadable", http.StatusRequestEntityTooLarge) return } defer func() { _ = r.Body.Close() }() @@ -253,6 +327,31 @@ func (s *Server) handleWebhook(w http.ResponseWriter, r *http.Request) { } } } + + // Replay protection. The signature has already been verified above. + // When configured, reject events whose CloudEvents `time` field is + // older than the allowed skew. 
Payloads without `time` are covered + // by the signed-body dedupe below. + if s.config.WebhookSecret != "" && s.config.MaxEventAge > 0 { + if rawTime, ok := payload["time"].(string); ok { + eventTime, terr := time.Parse(time.RFC3339, rawTime) + if terr != nil { + http.Error(w, "Invalid event timestamp", http.StatusBadRequest) + return + } + skew := time.Since(eventTime) + if skew > s.config.MaxEventAge || skew < -s.config.MaxEventAge { + http.Error(w, "Event timestamp outside allowed skew", http.StatusUnauthorized) + return + } + } + } + } + + if s.shouldSuppressSignedReplay(signature, time.Now()) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("Duplicate webhook ignored")) + return } // Update stats @@ -261,20 +360,43 @@ func (s *Server) handleWebhook(w http.ResponseWriter, r *http.Request) { s.stats.LastEventAt = time.Now() s.mu.Unlock() - // Send to channel (non-blocking) + // Send to channel non-blocking. If the buffer is full we drop the + // *new* event (not the oldest) — bump a stat so callers can see the + // loss. select { case s.events <- event: default: - // Channel full, drop oldest + s.mu.Lock() + s.stats.EventsDropped++ + s.mu.Unlock() } - // Call handlers + // Call handlers. Goroutines are bounded by handlerSlots so a flood of + // events combined with slow handlers cannot drive unbounded goroutine + // creation. If the slot pool is saturated, we run synchronously rather + // than dropping the call — handlers are typically cheap (channel sends). s.mu.RLock() handlers := s.handlers s.mu.RUnlock() for _, handler := range handlers { - go handler(event) + select { + case s.handlerSlots <- struct{}{}: + go func(h ports.WebhookEventHandler) { + defer func() { + <-s.handlerSlots + // Recover so a buggy handler can't take down the server. + _ = recover() + }() + h(event) + }(handler) + default: + // Slot pool full — invoke inline to apply backpressure. 
+ func() { + defer func() { _ = recover() }() + handler(event) + }() + } } // Respond with 200 OK @@ -291,6 +413,7 @@ func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { "status": "healthy", "started_at": stats.StartedAt, "events_received": stats.EventsReceived, + "events_dropped": stats.EventsDropped, "local_url": stats.LocalURL, "public_url": stats.PublicURL, } @@ -312,46 +435,46 @@ func (s *Server) handleRoot(w http.ResponseWriter, r *http.Request) { } stats := s.GetStats() + data := struct { + EventsReceived int + StartedAt string + PublicURL string + HealthURL string + }{ + EventsReceived: stats.EventsReceived, + StartedAt: stats.StartedAt.Format(time.RFC3339), + PublicURL: stats.PublicURL, + HealthURL: s.GetPublicURL() + "/health", + } - html := fmt.Sprintf(` - - - Nylas Webhook Server - - - -

-<h1>Nylas Webhook Server</h1>
-<div>
-<div>Status: Running</div>
-<div>Events Received: %d</div>
-<div>Started: %s</div>
-</div>
-<h2>Webhook Endpoint</h2>
-<div>%s</div>
-<p>Send POST requests to this URL to receive webhook events.</p>
-<h2>Health Check</h2>
-<div>%s/health</div>
- -`, - stats.EventsReceived, - stats.StartedAt.Format(time.RFC3339), - stats.PublicURL, - s.GetPublicURL(), - ) - - w.Header().Set("Content-Type", "text/html") - _, _ = w.Write([]byte(html)) // Ignore write error - best effort response + w.Header().Set("Content-Type", "text/html; charset=utf-8") + _ = rootTemplate.Execute(w, data) // best-effort response } // verifySignature verifies the webhook signature using HMAC-SHA256. func (s *Server) verifySignature(payload []byte, signature string) bool { return VerifySignature(payload, signature, s.config.WebhookSecret) } + +func (s *Server) shouldSuppressSignedReplay(signature string, now time.Time) bool { + if s.config.WebhookSecret == "" || s.config.MaxEventAge <= 0 || signature == "" { + return false + } + + s.mu.Lock() + defer s.mu.Unlock() + + cutoff := now.Add(-s.config.MaxEventAge) + for key, seenAt := range s.seenSignatures { + if seenAt.Before(cutoff) { + delete(s.seenSignatures, key) + } + } + + if seenAt, ok := s.seenSignatures[signature]; ok && !seenAt.Before(cutoff) { + return true + } + + s.seenSignatures[signature] = now + return false +} diff --git a/internal/adapters/webhookserver/server_test.go b/internal/adapters/webhookserver/server_test.go index f4dc6bc..322caf0 100644 --- a/internal/adapters/webhookserver/server_test.go +++ b/internal/adapters/webhookserver/server_test.go @@ -56,7 +56,7 @@ func TestServer_StartStop(t *testing.T) { // Server should be running localURL := server.GetLocalURL() - assert.Contains(t, localURL, "/webhook") + assert.Equal(t, fmt.Sprintf("http://127.0.0.1:%d/webhook", port), localURL) // Stop the server err = server.Stop() @@ -86,7 +86,7 @@ func TestServer_GetStats(t *testing.T) { }) stats := server.GetStats() - assert.Equal(t, "http://localhost:3001/webhook", stats.LocalURL) + assert.Equal(t, "http://127.0.0.1:3001/webhook", stats.LocalURL) assert.Equal(t, 0, stats.EventsReceived) } @@ -319,7 +319,7 @@ func TestServer_GetLocalURL(t *testing.T) { }) url := 
server.GetLocalURL() - assert.Equal(t, "http://localhost:8080/api/hooks", url) + assert.Equal(t, "http://127.0.0.1:8080/api/hooks", url) } func TestServer_GetPublicURL(t *testing.T) { @@ -330,7 +330,7 @@ func TestServer_GetPublicURL(t *testing.T) { // Without tunnel, public URL equals local URL url := server.GetPublicURL() - assert.Equal(t, "http://localhost:8080/webhook", url) + assert.Equal(t, "http://127.0.0.1:8080/webhook", url) } func signWebhookPayload(secret string, payload []byte) string { @@ -338,3 +338,149 @@ func signWebhookPayload(secret string, payload []byte) string { _, _ = mac.Write(payload) return hex.EncodeToString(mac.Sum(nil)) } + +// TestServer_HandleWebhook_RejectsOversizedBody verifies the body-size cap +// rejects payloads larger than the configured limit before allocating +// gigabytes of RAM. This is the gate that prevents a malicious sender on +// a public tunnel URL from driving the receiver out of memory. +func TestServer_HandleWebhook_RejectsOversizedBody(t *testing.T) { + server := NewServer(ports.WebhookServerConfig{Port: 0, Path: "/webhook"}) + + // 2 MiB — twice the cap. + oversized := bytes.Repeat([]byte("A"), 2<<20) + + req := httptest.NewRequest(http.MethodPost, "/webhook", bytes.NewReader(oversized)) + rec := httptest.NewRecorder() + server.handleWebhook(rec, req) + + assert.Equal(t, http.StatusRequestEntityTooLarge, rec.Code, + "oversized body must be rejected with 413, got %d", rec.Code) +} + +// TestServer_HandleWebhook_DropsOldEvents_ReplayWindow exercises the +// MaxEventAge gate. With MaxEventAge configured, an event whose +// CloudEvents `time` field is older than the window is rejected as a +// replay even when the HMAC verifies — bound to a captured signature, an +// attacker would otherwise be able to replay a single signed body +// indefinitely. 
+func TestServer_HandleWebhook_DropsOldEvents_ReplayWindow(t *testing.T) { + secret := "test-secret" + server := NewServer(ports.WebhookServerConfig{ + Port: 0, + Path: "/webhook", + WebhookSecret: secret, + MaxEventAge: 30 * time.Second, + }) + + // Body with a `time` field 5 minutes in the past. + oldTime := time.Now().Add(-5 * time.Minute).UTC().Format(time.RFC3339) + body := []byte(`{"id":"evt_1","type":"message.created","time":"` + oldTime + `"}`) + sig := signWebhookPayload(secret, body) + + req := httptest.NewRequest(http.MethodPost, "/webhook", bytes.NewReader(body)) + req.Header.Set("X-Nylas-Signature", sig) + rec := httptest.NewRecorder() + server.handleWebhook(rec, req) + + assert.Equal(t, http.StatusUnauthorized, rec.Code, + "stale event must be rejected as replay, got %d", rec.Code) +} + +// TestServer_HandleWebhook_AcceptsRecentEvents_ReplayWindow is the +// positive twin of the replay test: a signed event with a fresh `time` +// passes the gate. +func TestServer_HandleWebhook_AcceptsRecentEvents_ReplayWindow(t *testing.T) { + secret := "test-secret" + server := NewServer(ports.WebhookServerConfig{ + Port: 0, + Path: "/webhook", + WebhookSecret: secret, + MaxEventAge: 60 * time.Second, + }) + + now := time.Now().UTC().Format(time.RFC3339) + body := []byte(`{"id":"evt_2","type":"message.created","time":"` + now + `"}`) + sig := signWebhookPayload(secret, body) + + req := httptest.NewRequest(http.MethodPost, "/webhook", bytes.NewReader(body)) + req.Header.Set("X-Nylas-Signature", sig) + rec := httptest.NewRecorder() + server.handleWebhook(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code, "fresh event must be accepted") +} + +func TestServer_HandleWebhook_DeduplicatesSignedPayloadWithoutTime(t *testing.T) { + secret := "test-secret" + server := NewServer(ports.WebhookServerConfig{ + Port: 0, + Path: "/webhook", + WebhookSecret: secret, + MaxEventAge: 60 * time.Second, + }) + + body := []byte(`{"id":"evt_without_time","type":"message.created"}`) + 
sig := signWebhookPayload(secret, body) + + req := httptest.NewRequest(http.MethodPost, "/webhook", bytes.NewReader(body)) + req.Header.Set("X-Nylas-Signature", sig) + rec := httptest.NewRecorder() + server.handleWebhook(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "first signed event must be accepted") + assert.Equal(t, 1, server.GetStats().EventsReceived) + + select { + case <-server.Events(): + default: + t.Fatal("expected first signed event to be queued") + } + + req = httptest.NewRequest(http.MethodPost, "/webhook", bytes.NewReader(body)) + req.Header.Set("X-Nylas-Signature", sig) + rec = httptest.NewRecorder() + server.handleWebhook(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code, "duplicate signed event should be acknowledged without processing") + assert.Equal(t, 1, server.GetStats().EventsReceived, "duplicate signed body must not be counted as a new event") + select { + case event := <-server.Events(): + t.Fatalf("duplicate signed event was queued: %+v", event) + default: + } +} + +// TestServer_HandleHealth_SurfacesEventsDropped confirms the health +// response includes the events_dropped counter so operators can detect a +// slow consumer without parsing logs. 
+func TestServer_HandleHealth_SurfacesEventsDropped(t *testing.T) { + server := NewServer(ports.WebhookServerConfig{Port: 0, Path: "/webhook"}) + server.mu.Lock() + server.stats.EventsDropped = 7 + server.mu.Unlock() + + req := httptest.NewRequest(http.MethodGet, "/health", nil) + rec := httptest.NewRecorder() + server.handleHealth(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + var body map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &body)) + require.Contains(t, body, "events_dropped") + assert.Equal(t, float64(7), body["events_dropped"]) +} + +// TestServer_StartBindsLoopbackOnly asserts the listener address is on a +// loopback interface — guards against an accidental change from +// 127.0.0.1: to :PORT (which would let any host on the LAN forge events). +func TestServer_StartBindsLoopbackOnly(t *testing.T) { + server := NewServer(ports.WebhookServerConfig{Port: reserveTCPPort(t), Path: "/webhook"}) + require.NoError(t, server.Start(context.Background())) + defer func() { _ = server.Stop() }() + + addr := server.listener.Addr().String() + host, _, err := net.SplitHostPort(addr) + require.NoError(t, err) + ip := net.ParseIP(host) + require.NotNil(t, ip, "could not parse listener host as IP: %s", host) + assert.True(t, ip.IsLoopback(), "listener bound to non-loopback address: %s", addr) +} diff --git a/internal/air/handlers_ai_complete.go b/internal/air/handlers_ai_complete.go index 01f2407..73d165f 100644 --- a/internal/air/handlers_ai_complete.go +++ b/internal/air/handlers_ai_complete.go @@ -10,6 +10,8 @@ import ( "strconv" "strings" "time" + + "github.com/nylas/cli/internal/httputil" ) const smartComposeTimeout = 5 * time.Second @@ -42,11 +44,7 @@ func (s *Server) handleAIComplete(w http.ResponseWriter, r *http.Request) { } if req.Text == "" { - w.Header().Set("Content-Type", "application/json") - resp := CompleteResponse{Suggestion: "", Confidence: 0} - if err := json.NewEncoder(w).Encode(resp); err != nil { - http.Error(w, 
"Failed to encode", http.StatusInternalServerError) - } + httputil.WriteJSON(w, http.StatusOK, CompleteResponse{Suggestion: "", Confidence: 0}) return } @@ -56,14 +54,10 @@ func (s *Server) handleAIComplete(w http.ResponseWriter, r *http.Request) { suggestion := getAICompletion(r.Context(), req.Text, req.MaxLength) - w.Header().Set("Content-Type", "application/json") - resp := CompleteResponse{ + httputil.WriteJSON(w, http.StatusOK, CompleteResponse{ Suggestion: suggestion, Confidence: 0.8, - } - if err := json.NewEncoder(w).Encode(resp); err != nil { - http.Error(w, "Failed to encode response", http.StatusInternalServerError) - } + }) } // getAICompletion gets completion from Claude via CLI @@ -175,10 +169,7 @@ func (s *Server) handleNLSearch(w http.ResponseWriter, r *http.Request) { result := parseNaturalLanguageSearch(req.Query) - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(result); err != nil { - http.Error(w, "Failed to encode response", http.StatusInternalServerError) - } + httputil.WriteJSON(w, http.StatusOK, result) } // parseNaturalLanguageSearch converts NL query to search params diff --git a/internal/air/handlers_ai_config.go b/internal/air/handlers_ai_config.go index 1d303b7..d13bec1 100644 --- a/internal/air/handlers_ai_config.go +++ b/internal/air/handlers_ai_config.go @@ -2,8 +2,11 @@ package air import ( "encoding/json" + "maps" "net/http" "sync" + + "github.com/nylas/cli/internal/httputil" ) // AIConfig represents AI provider configuration @@ -81,10 +84,7 @@ func (s *Server) handleGetAIConfig(w http.ResponseWriter, r *http.Request) { config.APIKey = "***" + config.APIKey[max(0, len(config.APIKey)-4):] } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(config); err != nil { - http.Error(w, "Failed to encode", http.StatusInternalServerError) - } + httputil.WriteJSON(w, http.StatusOK, config) } // handleUpdateAIConfig updates AI configuration @@ -118,20 +118,14 @@ 
func (s *Server) handleUpdateAIConfig(w http.ResponseWriter, r *http.Request) { aiStore.config.Temperature = req.Temperature } if req.TaskModels != nil { - for task, model := range req.TaskModels { - aiStore.config.TaskModels[task] = model - } + maps.Copy(aiStore.config.TaskModels, req.TaskModels) } if req.UsageBudget > 0 { aiStore.config.UsageBudget = req.UsageBudget } aiStore.config.Enabled = req.Enabled - w.Header().Set("Content-Type", "application/json") - resp := map[string]string{"status": "updated"} - if err := json.NewEncoder(w).Encode(resp); err != nil { - http.Error(w, "Failed to encode", http.StatusInternalServerError) - } + httputil.WriteJSON(w, http.StatusOK, map[string]string{"status": "updated"}) } // handleTestAIConnection tests the AI provider connection @@ -153,10 +147,7 @@ func (s *Server) handleTestAIConnection(w http.ResponseWriter, r *http.Request) result["message"] = "API key required" } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(result); err != nil { - http.Error(w, "Failed to encode", http.StatusInternalServerError) - } + httputil.WriteJSON(w, http.StatusOK, result) } // handleGetAIUsage returns AI usage statistics @@ -172,10 +163,7 @@ func (s *Server) handleGetAIUsage(w http.ResponseWriter, r *http.Request) { "percentUsed": (aiStore.config.UsageSpent / aiStore.config.UsageBudget) * 100, } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(response); err != nil { - http.Error(w, "Failed to encode", http.StatusInternalServerError) - } + httputil.WriteJSON(w, http.StatusOK, response) } // GetAIProviders returns available AI providers @@ -207,10 +195,7 @@ func (s *Server) handleGetAIProviders(w http.ResponseWriter, r *http.Request) { }, } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(providers); err != nil { - http.Error(w, "Failed to encode", http.StatusInternalServerError) - } + httputil.WriteJSON(w, 
http.StatusOK, providers) } // RecordAIUsage records AI usage for a task diff --git a/internal/air/handlers_analytics.go b/internal/air/handlers_analytics.go index 2ce9733..a91f6dd 100644 --- a/internal/air/handlers_analytics.go +++ b/internal/air/handlers_analytics.go @@ -1,10 +1,11 @@ package air import ( - "encoding/json" "net/http" "sync" "time" + + "github.com/nylas/cli/internal/httputil" ) // EmailAnalytics represents email analytics data @@ -121,10 +122,7 @@ func (s *Server) handleGetAnalyticsDashboard(w http.ResponseWriter, r *http.Requ aStore.mu.RLock() defer aStore.mu.RUnlock() - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(aStore.analytics); err != nil { - http.Error(w, "Failed to encode", http.StatusInternalServerError) - } + httputil.WriteJSON(w, http.StatusOK, aStore.analytics) } // handleGetAnalyticsTrends returns email trends @@ -144,10 +142,7 @@ func (s *Server) handleGetAnalyticsTrends(w http.ResponseWriter, r *http.Request "dailyVolume": aStore.analytics.DailyVolume, } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(response); err != nil { - http.Error(w, "Failed to encode", http.StatusInternalServerError) - } + httputil.WriteJSON(w, http.StatusOK, response) } // handleGetFocusTimeSuggestions returns suggested focus time blocks @@ -198,10 +193,7 @@ func (s *Server) handleGetFocusTimeSuggestions(w http.ResponseWriter, r *http.Re Score: 90, }) - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(suggestions); err != nil { - http.Error(w, "Failed to encode", http.StatusInternalServerError) - } + httputil.WriteJSON(w, http.StatusOK, suggestions) } // handleGetProductivityStats returns productivity metrics @@ -219,10 +211,7 @@ func (s *Server) handleGetProductivityStats(w http.ResponseWriter, r *http.Reque "emailsProcessed": aStore.analytics.TotalArchived + aStore.analytics.TotalDeleted, } - w.Header().Set("Content-Type", 
"application/json") - if err := json.NewEncoder(w).Encode(response); err != nil { - http.Error(w, "Failed to encode", http.StatusInternalServerError) - } + httputil.WriteJSON(w, http.StatusOK, response) } // RecordEmailReceived records a received email diff --git a/internal/air/handlers_bundles.go b/internal/air/handlers_bundles.go index 25924f6..610c1ae 100644 --- a/internal/air/handlers_bundles.go +++ b/internal/air/handlers_bundles.go @@ -7,6 +7,8 @@ import ( "strings" "sync" "time" + + "github.com/nylas/cli/internal/httputil" ) // Bundle represents an email bundle/category @@ -156,10 +158,7 @@ func (s *Server) handleGetBundles(w http.ResponseWriter, r *http.Request) { bundles = append(bundles, b) } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(bundles); err != nil { - http.Error(w, "Failed to encode bundles", http.StatusInternalServerError) - } + httputil.WriteJSON(w, http.StatusOK, bundles) } // handleBundleCategorize assigns an email to a bundle @@ -187,11 +186,7 @@ func (s *Server) handleBundleCategorize(w http.ResponseWriter, r *http.Request) bundleStore.mu.Unlock() } - w.Header().Set("Content-Type", "application/json") - resp := map[string]string{"bundleId": bundleID} - if err := json.NewEncoder(w).Encode(resp); err != nil { - http.Error(w, "Failed to encode response", http.StatusInternalServerError) - } + httputil.WriteJSON(w, http.StatusOK, map[string]string{"bundleId": bundleID}) } // categorizeEmail determines which bundle an email belongs to @@ -295,10 +290,7 @@ func (s *Server) handleUpdateBundle(w http.ResponseWriter, r *http.Request) { existing.LastUpdated = time.Now() } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(bundle); err != nil { - http.Error(w, "Failed to encode response", http.StatusInternalServerError) - } + httputil.WriteJSON(w, http.StatusOK, bundle) } // handleGetBundleEmails returns emails for a specific bundle @@ -319,8 +311,5 @@ func (s 
*Server) handleGetBundleEmails(w http.ResponseWriter, r *http.Request) { } } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(emailIDs); err != nil { - http.Error(w, "Failed to encode response", http.StatusInternalServerError) - } + httputil.WriteJSON(w, http.StatusOK, emailIDs) } diff --git a/internal/air/handlers_cache.go b/internal/air/handlers_cache.go index e8766c3..2013cb8 100644 --- a/internal/air/handlers_cache.go +++ b/internal/air/handlers_cache.go @@ -7,6 +7,7 @@ import ( "time" "github.com/nylas/cli/internal/air/cache" + "github.com/nylas/cli/internal/domain" ) // CacheStatusResponse represents the cache status API response. @@ -175,20 +176,32 @@ func (s *Server) handleCacheSync(w http.ResponseWriter, r *http.Request) { return } - // Get email from query param (optional - sync all if not specified) + // Get email from query param (optional - sync default account if not specified) email := r.URL.Query().Get("email") - // Get all grants - grants, err := s.grantStore.ListGrants() - if err != nil { - writeJSON(w, http.StatusOK, CacheSyncResponse{ - Success: false, - Error: "Failed to get accounts", - }) - return + var grants []domain.GrantInfo + if email == "" { + grant, err := s.resolveDefaultGrantInfo() + if err != nil { + writeJSON(w, http.StatusOK, CacheSyncResponse{ + Success: false, + Error: "Failed to get default account", + }) + return + } + grants = []domain.GrantInfo{grant} + } else { + var err error + grants, err = s.grantStore.ListGrants() + if err != nil { + writeJSON(w, http.StatusOK, CacheSyncResponse{ + Success: false, + Error: "Failed to get accounts", + }) + return + } } - // Sync accounts synced := 0 for _, grant := range grants { if !grant.Provider.IsSupportedByAir() { diff --git a/internal/air/handlers_config.go b/internal/air/handlers_config.go index aac4757..3cb3044 100644 --- a/internal/air/handlers_config.go +++ b/internal/air/handlers_config.go @@ -162,6 +162,7 @@ func (s *Server) 
handleSetDefaultGrant(w http.ResponseWriter, r *http.Request) { }) return } + s.refreshActiveAccountCacheRuntime() writeJSON(w, http.StatusOK, SetDefaultGrantResponse{ Success: true, diff --git a/internal/air/handlers_config_test.go b/internal/air/handlers_config_test.go index ff8bc8a..378fb47 100644 --- a/internal/air/handlers_config_test.go +++ b/internal/air/handlers_config_test.go @@ -10,6 +10,7 @@ import ( configadapter "github.com/nylas/cli/internal/adapters/config" keyringadapter "github.com/nylas/cli/internal/adapters/keyring" + "github.com/nylas/cli/internal/air/cache" authapp "github.com/nylas/cli/internal/app/auth" "github.com/nylas/cli/internal/domain" ) @@ -294,6 +295,68 @@ func TestHandleSetDefaultGrant_RejectsUnsupportedProviders(t *testing.T) { } } +func TestHandleSetDefaultGrantRefreshesActiveOfflineQueue(t *testing.T) { + t.Parallel() + + manager, err := cache.NewManager(cache.Config{BasePath: t.TempDir()}) + if err != nil { + t.Fatalf("new cache manager: %v", err) + } + t.Cleanup(func() { + _ = manager.Close() + }) + + settings := cache.DefaultSettings() + settings.Enabled = true + settings.OfflineQueueEnabled = true + + grantStore := &testGrantStore{ + grants: []domain.GrantInfo{ + {ID: "grant-old", Email: "old@example.com", Provider: domain.ProviderGoogle}, + {ID: "grant-new", Email: "new@example.com", Provider: domain.ProviderNylas}, + }, + defaultGrant: "grant-old", + } + server := &Server{ + cacheManager: manager, + cacheSettings: settings, + configStore: configadapter.NewMockConfigStore(), + grantStore: grantStore, + offlineQueues: make(map[string]*cache.OfflineQueue), + isOnline: false, + } + t.Cleanup(server.stopBackgroundSync) + + if err := server.withOfflineQueue("old@example.com", func(*cache.OfflineQueue) error { return nil }); err != nil { + t.Fatalf("initialize old offline queue: %v", err) + } + + req := httptest.NewRequest(http.MethodPost, "/api/grants/default", strings.NewReader(`{"grant_id":"grant-new"}`)) + 
req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + server.handleSetDefaultGrant(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + if grantStore.defaultGrant != "grant-new" { + t.Fatalf("expected default grant to switch to grant-new, got %q", grantStore.defaultGrant) + } + + server.offlineQueuesMu.RLock() + defer server.offlineQueuesMu.RUnlock() + if len(server.offlineQueues) != 1 { + t.Fatalf("expected exactly one offline queue after switch, got %d", len(server.offlineQueues)) + } + if server.offlineQueues["new@example.com"] == nil { + t.Fatal("expected new default account offline queue to be initialized") + } + if server.offlineQueues["old@example.com"] != nil { + t.Fatal("did not expect old default account offline queue to remain initialized") + } +} + func TestBuildPageData_UsesResolvedSupportedDefaultGrantWithoutPersistingIt(t *testing.T) { t.Parallel() diff --git a/internal/air/handlers_email_cache_runtime_test.go b/internal/air/handlers_email_cache_runtime_test.go index f544cb8..005d80d 100644 --- a/internal/air/handlers_email_cache_runtime_test.go +++ b/internal/air/handlers_email_cache_runtime_test.go @@ -38,6 +38,11 @@ func (s *testGrantStore) SaveGrant(info domain.GrantInfo) error { return nil } +func (s *testGrantStore) ReplaceGrants(grants []domain.GrantInfo) error { + s.grants = append([]domain.GrantInfo(nil), grants...) 
+ return nil +} + func (s *testGrantStore) GetGrant(grantID string) (*domain.GrantInfo, error) { for i := range s.grants { if s.grants[i].ID == grantID { diff --git a/internal/air/handlers_focus_mode.go b/internal/air/handlers_focus_mode.go index 6dceb8c..2b4938b 100644 --- a/internal/air/handlers_focus_mode.go +++ b/internal/air/handlers_focus_mode.go @@ -3,15 +3,18 @@ package air import ( "encoding/json" "net/http" + "slices" "sync" "time" + + "github.com/nylas/cli/internal/httputil" ) // FocusModeState represents the current focus mode state type FocusModeState struct { IsActive bool `json:"isActive"` - StartedAt time.Time `json:"startedAt,omitempty"` - EndsAt time.Time `json:"endsAt,omitempty"` + StartedAt time.Time `json:"startedAt,omitzero"` + EndsAt time.Time `json:"endsAt,omitzero"` Duration int `json:"duration"` // minutes PomodoroMode bool `json:"pomodoroMode"` SessionCount int `json:"sessionCount"` @@ -112,10 +115,7 @@ func (s *Server) handleGetFocusModeState(w http.ResponseWriter, r *http.Request) } } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(response); err != nil { - http.Error(w, "Failed to encode", http.StatusInternalServerError) - } + httputil.WriteJSON(w, http.StatusOK, response) } // handleStartFocusMode starts a focus mode session @@ -154,10 +154,7 @@ func (s *Server) handleStartFocusMode(w http.ResponseWriter, r *http.Request) { InBreak: false, } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(fmStore.state); err != nil { - http.Error(w, "Failed to encode", http.StatusInternalServerError) - } + httputil.WriteJSON(w, http.StatusOK, fmStore.state) } // handleStopFocusMode stops the current focus mode session @@ -171,14 +168,10 @@ func (s *Server) handleStopFocusMode(w http.ResponseWriter, r *http.Request) { fmStore.state.IsActive = false fmStore.state.InBreak = false - w.Header().Set("Content-Type", "application/json") - resp := map[string]any{ + 
httputil.WriteJSON(w, http.StatusOK, map[string]any{ "status": "stopped", "sessionCount": fmStore.state.SessionCount, - } - if err := json.NewEncoder(w).Encode(resp); err != nil { - http.Error(w, "Failed to encode", http.StatusInternalServerError) - } + }) } // handleStartBreak starts a break in pomodoro mode @@ -203,10 +196,7 @@ func (s *Server) handleStartBreak(w http.ResponseWriter, r *http.Request) { fmStore.state.EndsAt = now.Add(time.Duration(breakDuration) * time.Minute) fmStore.state.BreakDuration = breakDuration - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(fmStore.state); err != nil { - http.Error(w, "Failed to encode", http.StatusInternalServerError) - } + httputil.WriteJSON(w, http.StatusOK, fmStore.state) } // handleGetFocusModeSettings returns focus mode settings @@ -214,10 +204,7 @@ func (s *Server) handleGetFocusModeSettings(w http.ResponseWriter, r *http.Reque fmStore.mu.RLock() defer fmStore.mu.RUnlock() - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(fmStore.settings); err != nil { - http.Error(w, "Failed to encode", http.StatusInternalServerError) - } + httputil.WriteJSON(w, http.StatusOK, fmStore.settings) } // handleUpdateFocusModeSettings updates focus mode settings @@ -232,11 +219,7 @@ func (s *Server) handleUpdateFocusModeSettings(w http.ResponseWriter, r *http.Re fmStore.settings = &settings fmStore.mu.Unlock() - w.Header().Set("Content-Type", "application/json") - resp := map[string]string{"status": "updated"} - if err := json.NewEncoder(w).Encode(resp); err != nil { - http.Error(w, "Failed to encode", http.StatusInternalServerError) - } + httputil.WriteJSON(w, http.StatusOK, map[string]string{"status": "updated"}) } // IsFocusModeActive returns whether focus mode is active @@ -256,11 +239,5 @@ func ShouldAllowNotification(senderEmail string) bool { } // Check if sender is in allowed list - for _, allowed := range fmStore.settings.AllowedSenders { - if 
allowed == senderEmail { - return true - } - } - - return false + return slices.Contains(fmStore.settings.AllowedSenders, senderEmail) } diff --git a/internal/air/handlers_notetaker.go b/internal/air/handlers_notetaker.go index 71a15cb..878ae60 100644 --- a/internal/air/handlers_notetaker.go +++ b/internal/air/handlers_notetaker.go @@ -9,6 +9,7 @@ import ( "time" "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/httputil" ) // States to filter out from notetaker list @@ -124,10 +125,7 @@ func (s *Server) handleListNotetakers(w http.ResponseWriter, r *http.Request) { } } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(response); err != nil { - http.Error(w, "Failed to encode", http.StatusInternalServerError) - } + httputil.WriteJSON(w, http.StatusOK, response) } // domainToNotetakerResponse converts a domain.Notetaker to NotetakerResponse @@ -205,10 +203,7 @@ func (s *Server) handleCreateNotetaker(w http.ResponseWriter, r *http.Request) { return } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(domainToNotetakerResponse(nt)); err != nil { - http.Error(w, "Failed to encode", http.StatusInternalServerError) - } + httputil.WriteJSON(w, http.StatusOK, domainToNotetakerResponse(nt)) } // handleGetNotetaker returns a single notetaker from the Nylas API @@ -233,10 +228,7 @@ func (s *Server) handleGetNotetaker(w http.ResponseWriter, r *http.Request) { return } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(domainToNotetakerResponse(nt)); err != nil { - http.Error(w, "Failed to encode", http.StatusInternalServerError) - } + httputil.WriteJSON(w, http.StatusOK, domainToNotetakerResponse(nt)) } // handleGetNotetakerMedia returns media for a notetaker from the Nylas API @@ -273,10 +265,7 @@ func (s *Server) handleGetNotetakerMedia(w http.ResponseWriter, r *http.Request) media.TranscriptSize = mediaData.Transcript.Size } - 
w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(media); err != nil { - http.Error(w, "Failed to encode", http.StatusInternalServerError) - } + httputil.WriteJSON(w, http.StatusOK, media) } // handleDeleteNotetaker cancels a notetaker via the Nylas API diff --git a/internal/air/handlers_read_receipts.go b/internal/air/handlers_read_receipts.go index 7bbc1ff..67ea4a9 100644 --- a/internal/air/handlers_read_receipts.go +++ b/internal/air/handlers_read_receipts.go @@ -5,13 +5,15 @@ import ( "net/http" "sync" "time" + + "github.com/nylas/cli/internal/httputil" ) // ReadReceipt represents a read receipt for a sent email type ReadReceipt struct { EmailID string `json:"emailId"` Recipient string `json:"recipient"` - OpenedAt time.Time `json:"openedAt,omitempty"` + OpenedAt time.Time `json:"openedAt,omitzero"` OpenCount int `json:"openCount"` Device string `json:"device,omitempty"` Location string `json:"location,omitempty"` @@ -55,10 +57,7 @@ func (s *Server) handleGetReadReceipts(w http.ResponseWriter, r *http.Request) { if emailID != "" { if receipt, ok := rrStore.receipts[emailID]; ok { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(receipt); err != nil { - http.Error(w, "Failed to encode", http.StatusInternalServerError) - } + httputil.WriteJSON(w, http.StatusOK, receipt) return } http.Error(w, "Receipt not found", http.StatusNotFound) @@ -70,10 +69,7 @@ func (s *Server) handleGetReadReceipts(w http.ResponseWriter, r *http.Request) { receipts = append(receipts, r) } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(receipts); err != nil { - http.Error(w, "Failed to encode", http.StatusInternalServerError) - } + httputil.WriteJSON(w, http.StatusOK, receipts) } // handleTrackOpen records an email open (tracking pixel endpoint) @@ -131,10 +127,7 @@ func (s *Server) handleGetReadReceiptSettings(w http.ResponseWriter, r *http.Req rrStore.mu.RLock() 
     defer rrStore.mu.RUnlock()
 
-    w.Header().Set("Content-Type", "application/json")
-    if err := json.NewEncoder(w).Encode(rrStore.settings); err != nil {
-        http.Error(w, "Failed to encode", http.StatusInternalServerError)
-    }
+    httputil.WriteJSON(w, http.StatusOK, rrStore.settings)
 }
 
 // handleUpdateReadReceiptSettings updates settings
@@ -149,11 +142,7 @@ func (s *Server) handleUpdateReadReceiptSettings(w http.ResponseWriter, r *http.
     rrStore.settings = &settings
     rrStore.mu.Unlock()
 
-    w.Header().Set("Content-Type", "application/json")
-    resp := map[string]string{"status": "updated"}
-    if err := json.NewEncoder(w).Encode(resp); err != nil {
-        http.Error(w, "Failed to encode", http.StatusInternalServerError)
-    }
+    httputil.WriteJSON(w, http.StatusOK, map[string]string{"status": "updated"})
 }
 
 // RegisterEmailForTracking registers an email for read tracking
diff --git a/internal/air/handlers_reply_later.go b/internal/air/handlers_reply_later.go
index f20ffcb..260e429 100644
--- a/internal/air/handlers_reply_later.go
+++ b/internal/air/handlers_reply_later.go
@@ -5,6 +5,8 @@ import (
     "net/http"
     "sync"
     "time"
+
+    "github.com/nylas/cli/internal/httputil"
 )
 
 // ReplyLaterItem represents an email in the reply later queue
@@ -13,7 +15,7 @@ type ReplyLaterItem struct {
     Subject  string    `json:"subject"`
     From     string    `json:"from"`
     AddedAt  time.Time `json:"addedAt"`
-    RemindAt time.Time `json:"remindAt,omitempty"`
+    RemindAt time.Time `json:"remindAt,omitzero"`
     DraftID  string    `json:"draftId,omitempty"`
     Notes    string    `json:"notes,omitempty"`
     Priority int       `json:"priority"` // 1=high, 2=medium, 3=low
@@ -56,10 +58,7 @@ func (s *Server) handleGetReplyLaterItems(w http.ResponseWriter, r *http.Request
     }
 
-    w.Header().Set("Content-Type", "application/json")
-    if err := json.NewEncoder(w).Encode(items); err != nil {
-        http.Error(w, "Failed to encode", http.StatusInternalServerError)
-    }
+    httputil.WriteJSON(w, http.StatusOK, items)
 }
 
 // handleAddToReplyLater adds an email to reply later queue
@@ -117,10 +116,7 @@ func (s *Server) handleAddToReplyLater(w http.ResponseWriter, r *http.Request) {
     rlStore.items[req.EmailID] = item
     rlStore.mu.Unlock()
 
-    w.Header().Set("Content-Type", "application/json")
-    if err := json.NewEncoder(w).Encode(item); err != nil {
-        http.Error(w, "Failed to encode", http.StatusInternalServerError)
-    }
+    httputil.WriteJSON(w, http.StatusOK, item)
 }
 
 // handleUpdateReplyLater updates a reply later item
@@ -158,10 +154,7 @@ func (s *Server) handleUpdateReplyLater(w http.ResponseWriter, r *http.Request)
     }
     item.IsCompleted = req.IsCompleted
 
-    w.Header().Set("Content-Type", "application/json")
-    if err := json.NewEncoder(w).Encode(item); err != nil {
-        http.Error(w, "Failed to encode", http.StatusInternalServerError)
-    }
+    httputil.WriteJSON(w, http.StatusOK, item)
 }
 
 // handleRemoveFromReplyLater removes an email from reply later queue
diff --git a/internal/air/handlers_screener.go b/internal/air/handlers_screener.go
index 64d57d6..f3a19c7 100644
--- a/internal/air/handlers_screener.go
+++ b/internal/air/handlers_screener.go
@@ -5,6 +5,8 @@ import (
     "net/http"
     "sync"
     "time"
+
+    "github.com/nylas/cli/internal/httputil"
 )
 
 // ScreenedSender represents a sender pending approval
@@ -46,10 +48,7 @@ func (s *Server) handleGetScreenedSenders(w http.ResponseWriter, r *http.Request
     }
 
-    w.Header().Set("Content-Type", "application/json")
-    if err := json.NewEncoder(w).Encode(senders); err != nil {
-        http.Error(w, "Failed to encode", http.StatusInternalServerError)
-    }
+    httputil.WriteJSON(w, http.StatusOK, senders)
 }
 
 // handleScreenerAllow allows a sender
@@ -83,11 +82,7 @@ func (s *Server) handleScreenerAllow(w http.ResponseWriter, r *http.Request) {
     }
 
-    w.Header().Set("Content-Type", "application/json")
-    resp := map[string]string{"status": "allowed", "destination": req.Destination}
-    if err := json.NewEncoder(w).Encode(resp); err != nil {
-        http.Error(w, "Failed to encode", http.StatusInternalServerError)
-    }
+    httputil.WriteJSON(w, http.StatusOK, map[string]string{"status": "allowed", "destination": req.Destination})
 }
 
 // handleScreenerBlock blocks a sender
@@ -114,11 +109,7 @@ func (s *Server) handleScreenerBlock(w http.ResponseWriter, r *http.Request) {
     }
 
-    w.Header().Set("Content-Type", "application/json")
-    resp := map[string]string{"status": "blocked"}
-    if err := json.NewEncoder(w).Encode(resp); err != nil {
-        http.Error(w, "Failed to encode", http.StatusInternalServerError)
-    }
+    httputil.WriteJSON(w, http.StatusOK, map[string]string{"status": "blocked"})
 }
 
 // handleAddToScreener adds a new sender for screening
@@ -156,11 +147,7 @@ func (s *Server) handleAddToScreener(w http.ResponseWriter, r *http.Request) {
     }
 
-    w.Header().Set("Content-Type", "application/json")
-    resp := map[string]string{"status": "pending"}
-    if err := json.NewEncoder(w).Encode(resp); err != nil {
-        http.Error(w, "Failed to encode", http.StatusInternalServerError)
-    }
+    httputil.WriteJSON(w, http.StatusOK, map[string]string{"status": "pending"})
 }
 
 // IsSenderAllowed checks if a sender is allowed
diff --git a/internal/air/integration_base_test.go b/internal/air/integration_base_test.go
index 7f1ce7b..dba8966 100644
--- a/internal/air/integration_base_test.go
+++ b/internal/air/integration_base_test.go
@@ -12,6 +12,7 @@ import (
     "github.com/nylas/cli/internal/adapters/keyring"
     "github.com/nylas/cli/internal/adapters/nylas"
     authapp "github.com/nylas/cli/internal/app/auth"
+    "github.com/nylas/cli/internal/cli/common"
     "github.com/nylas/cli/internal/domain"
     "github.com/nylas/cli/internal/ports"
 )
@@ -26,7 +27,10 @@ func testServer(t *testing.T) *Server {
         t.Skipf("Skipping: cannot access secret store: %v", err)
     }
 
-    grantStore := keyring.NewGrantStore(secretStore)
+    grantStore, err := common.NewDefaultGrantStore()
+    if err != nil {
+        t.Skipf("Skipping: cannot access grant store: %v", err)
+    }
     configSvc := authapp.NewConfigService(configStore, secretStore)
 
     // Check configuration
diff --git a/internal/air/layout_regression_test.go b/internal/air/layout_regression_test.go
index 5f4797e..1e7e0b0 100644
--- a/internal/air/layout_regression_test.go
+++ b/internal/air/layout_regression_test.go
@@ -122,3 +122,38 @@ func TestMainLayoutHasExplicitHeight(t *testing.T) {
         t.Error("main-layout should use calc(100vh - ...) for explicit height calculation")
     }
 }
+
+func TestAccountDropdownIsViewportBoundedAndScrollable(t *testing.T) {
+    t.Parallel()
+
+    cssContent, err := staticFiles.ReadFile("static/css/components-account.css")
+    if err != nil {
+        t.Fatalf("failed to read components-account.css: %v", err)
+    }
+
+    rule := cssRule(t, string(cssContent), ".account-dropdown")
+    required := []string{
+        "max-height:",
+        "overflow-y: auto",
+        "overscroll-behavior: contain",
+    }
+    for _, declaration := range required {
+        if !strings.Contains(rule, declaration) {
+            t.Errorf(".account-dropdown must include %q so long account lists fit the viewport and scroll", declaration)
+        }
+    }
+}
+
+func cssRule(t *testing.T, css, selector string) string {
+    t.Helper()
+
+    start := strings.Index(css, selector+" {")
+    if start == -1 {
+        t.Fatalf("missing CSS rule for %s", selector)
+    }
+    end := strings.Index(css[start:], "}")
+    if end == -1 {
+        t.Fatalf("unterminated CSS rule for %s", selector)
+    }
+    return css[start : start+end]
+}
diff --git a/internal/air/middleware.go b/internal/air/middleware.go
index 47e5498..f33e116 100644
--- a/internal/air/middleware.go
+++ b/internal/air/middleware.go
@@ -3,12 +3,12 @@ package air
 import (
     "compress/gzip"
     "io"
-    "net"
     "net/http"
-    "net/url"
     "strconv"
     "strings"
     "time"
+
+    "github.com/nylas/cli/internal/webguard"
 )
 
 // gzipResponseWriter wraps http.ResponseWriter to support gzip compression.
@@ -180,14 +180,28 @@ func SecurityHeadersMiddleware(next http.Handler) http.Handler {
     // Referrer policy
     w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
 
-    // Content Security Policy (relaxed for local development)
+    // Content Security Policy.
+    //
+    // 'unsafe-inline' is removed from script-src: all inline onclick/
+    // onchange handlers and inline
@@ -123,7 +72,10 @@
@@ -209,6 +161,9 @@
 {{end}}
diff --git a/internal/air/templates/pages/calendar.gohtml b/internal/air/templates/pages/calendar.gohtml
index 55ac5f5..737f78b 100644
--- a/internal/air/templates/pages/calendar.gohtml
+++ b/internal/air/templates/pages/calendar.gohtml
@@ -3,7 +3,7 @@
 {{if eq .Provider "nylas"}}
-
 {{end}}
diff --git a/internal/air/templates/pages/notetaker.gohtml b/internal/air/templates/pages/notetaker.gohtml
index 59739c9..e7c952b 100644
--- a/internal/air/templates/pages/notetaker.gohtml
+++ b/internal/air/templates/pages/notetaker.gohtml
@@ -5,12 +5,12 @@
-
-
+
+
-
-
+
+
@@ -21,7 +21,7 @@
 🎙️
 No recordings yet
 Join a meeting to start recording
-
+
@@ -33,7 +33,7 @@
-
@@ -18,7 +18,7 @@
 Policies
 Managed mailbox policies available for this Nylas account.
-
+
@@ -35,7 +35,7 @@
 Rules
 Inbound processing rules configured for this Nylas account.
-
+
diff --git a/internal/air/templates/partials/header.gohtml b/internal/air/templates/partials/header.gohtml
index a585fe2..eda8861 100644
--- a/internal/air/templates/partials/header.gohtml
+++ b/internal/air/templates/partials/header.gohtml
@@ -7,27 +7,27 @@
 Nylas Air
@@ -156,8 +156,8 @@
diff --git a/internal/air/templates/partials/modals_navigation.gohtml b/internal/air/templates/partials/modals_navigation.gohtml
index 3281935..287a494 100644
--- a/internal/air/templates/partials/modals_navigation.gohtml
+++ b/internal/air/templates/partials/modals_navigation.gohtml
@@ -1,7 +1,7 @@
 {{define "modals_navigation"}}
-