Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions browser_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,14 +189,11 @@ func performBrowserFlowWithUpdates(
}
// Hard error (CSRF mismatch, token exchange failure, OAuth rejection,
// etc.) — surface it to the user.
// Send error update and give UI time to process it
updates <- tui.FlowUpdate{
Type: tui.StepError,
Step: 3,
Message: err.Error(),
}
// Small delay to ensure error is displayed before returning
time.Sleep(100 * time.Millisecond)
return nil, false, fmt.Errorf("authentication failed: %w", err)
}

Expand Down
5 changes: 4 additions & 1 deletion callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@ func startCallbackServer(ctx context.Context, port int, expectedState string,
_ = srv.Shutdown(shutdownCtx)
}()

timer := time.NewTimer(callbackTimeout)
defer timer.Stop()

select {
case result := <-resultCh:
if result.Error != "" {
Expand All @@ -168,7 +171,7 @@ func startCallbackServer(ctx context.Context, port int, expectedState string,
}
return result.Storage, nil

case <-time.After(callbackTimeout):
case <-timer.C:
return nil, fmt.Errorf("%w after %s", ErrCallbackTimeout, callbackTimeout)

case <-ctx.Done():
Expand Down
6 changes: 2 additions & 4 deletions device_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ func handleDevicePollError(
return pollErrorResult{pollFail, fmt.Errorf("token exchange failed: %w", err)}
}

var errResp ErrorResponse
if json.Unmarshal(oauthErr.Body, &errResp) != nil || errResp.Error == "" {
errResp, ok := parseOAuthError(oauthErr.Body)
if !ok {
return pollErrorResult{
pollFail,
fmt.Errorf("token exchange failed (body: %s): %w", oauthErr.Body, err),
Expand Down Expand Up @@ -253,8 +253,6 @@ func performDeviceFlowWithUpdates(
Step: 2,
Message: fmt.Sprintf("Authorization failed: %v", err),
}
// Small delay to ensure error is displayed before returning
time.Sleep(100 * time.Millisecond)
return nil, fmt.Errorf("token poll failed: %w", err)
}

Expand Down
28 changes: 17 additions & 11 deletions tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@ type ErrorResponse struct {
ErrorDescription string `json:"error_description"`
}

// parseOAuthError attempts to unmarshal an OAuth error response from raw JSON.
// Returns the parsed response and true if successful, or a zero value and false
// if the body is not a valid OAuth error (missing "error" field).
func parseOAuthError(body []byte) (ErrorResponse, bool) {
var errResp ErrorResponse
if err := json.Unmarshal(body, &errResp); err != nil || errResp.Error == "" {
return ErrorResponse{}, false
}
return errResp, true
}

// readResponseBody reads the response body with a size limit to guard against oversized responses.
func readResponseBody(resp *http.Response) ([]byte, error) {
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodySize))
Expand All @@ -42,10 +53,10 @@ func readResponseBody(resp *http.Response) ([]byte, error) {
// formatHTTPError attempts to parse an OAuth error response from body,
// falling back to a generic status+body error message.
func formatHTTPError(body []byte, statusCode int) error {
var errResp ErrorResponse
if json.Unmarshal(body, &errResp) == nil && errResp.Error != "" {
if errResp.ErrorDescription != "" {
return fmt.Errorf("%s: %s", errResp.Error, errResp.ErrorDescription)
if errResp, ok := parseOAuthError(body); ok {
desc := strings.TrimSpace(errResp.ErrorDescription)
if desc != "" {
return fmt.Errorf("%s: %s", errResp.Error, desc)
}
return fmt.Errorf("%s", errResp.Error)
}
Expand Down Expand Up @@ -83,18 +94,13 @@ func doTokenExchange(
}

if resp.StatusCode != http.StatusOK {
var errResp ErrorResponse
if jsonErr := json.Unmarshal(body, &errResp); jsonErr == nil && errResp.Error != "" {
if errResp, ok := parseOAuthError(body); ok {
if errHook != nil {
if hookErr := errHook(errResp, body); hookErr != nil {
return nil, hookErr
}
}
desc := strings.TrimSpace(errResp.ErrorDescription)
if desc == "" {
return nil, fmt.Errorf("%s", errResp.Error)
}
return nil, fmt.Errorf("%s: %s", errResp.Error, desc)
return nil, formatHTTPError(body, resp.StatusCode)
}
return nil, fmt.Errorf(
"token exchange failed with status %d: %s",
Expand Down
32 changes: 0 additions & 32 deletions tui/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ func TestFlowUpdateHelpers(t *testing.T) {
Data: map[string]any{
"string_val": "hello",
"int_val": 42,
"float_val": 3.14,
"duration_val": 5000000000, // 5 seconds in nanoseconds
},
}
Expand All @@ -127,11 +126,6 @@ func TestFlowUpdateHelpers(t *testing.T) {
t.Errorf("GetInt: expected 42, got %d", got)
}

// Test GetFloat64
if got := update.GetFloat64("float_val"); got != 3.14 {
t.Errorf("GetFloat64: expected 3.14, got %f", got)
}

// Test missing keys return zero values
if got := update.GetString("missing"); got != "" {
t.Errorf("GetString (missing): expected empty string, got '%s'", got)
Expand Down Expand Up @@ -163,29 +157,3 @@ func TestFlowUpdate_FallbackField(t *testing.T) {
}
})
}

func TestFlowUpdateTypeString(t *testing.T) {
tests := []struct {
updateType FlowUpdateType
expected string
}{
{StepStart, "StepStart"},
{StepProgress, "StepProgress"},
{StepComplete, "StepComplete"},
{StepError, "StepError"},
{TimerTick, "TimerTick"},
{BrowserOpened, "BrowserOpened"},
{CallbackReceived, "CallbackReceived"},
{DeviceCodeReceived, "DeviceCodeReceived"},
{PollingUpdate, "PollingUpdate"},
{BackoffChanged, "BackoffChanged"},
}

for _, tt := range tests {
t.Run(tt.expected, func(t *testing.T) {
if got := tt.updateType.String(); got != tt.expected {
t.Errorf("Expected %s, got %s", tt.expected, got)
}
})
}
}
39 changes: 0 additions & 39 deletions tui/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,34 +42,6 @@ type FlowUpdate struct {
Data map[string]any // Additional data for specific update types
}

// String returns a human-readable representation of the FlowUpdateType.
func (t FlowUpdateType) String() string {
switch t {
case StepStart:
return "StepStart"
case StepProgress:
return "StepProgress"
case StepComplete:
return "StepComplete"
case StepError:
return "StepError"
case TimerTick:
return "TimerTick"
case BrowserOpened:
return "BrowserOpened"
case CallbackReceived:
return "CallbackReceived"
case DeviceCodeReceived:
return "DeviceCodeReceived"
case PollingUpdate:
return "PollingUpdate"
case BackoffChanged:
return "BackoffChanged"
default:
return "Unknown"
}
}

// Helper functions to extract data from FlowUpdate.Data

// GetString safely extracts a string value from Data.
Expand Down Expand Up @@ -104,14 +76,3 @@ func (u *FlowUpdate) GetDuration(key string) time.Duration {
}
return 0
}

// GetFloat64 safely extracts a float64 value from Data.
func (u *FlowUpdate) GetFloat64(key string) float64 {
if u.Data == nil {
return 0
}
if val, ok := u.Data[key].(float64); ok {
return val
}
return 0
}
29 changes: 0 additions & 29 deletions tui/styles.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,6 @@ var (
colorBright = lipgloss.Color("#FFFFFF")
)

// FormatDurationCompact formats a duration in MM:SS format (e.g., "2:05").
// Negative durations are formatted with a leading "-" (e.g., "-1:05").
func FormatDurationCompact(d time.Duration) string {
sign := ""
if d < 0 {
sign = "-"
d = -d
}
d = d.Round(time.Second)
totalSeconds := int(d.Seconds())
minutes := totalSeconds / 60
seconds := totalSeconds % 60
return fmt.Sprintf("%s%d:%02d", sign, minutes, seconds)
}

// FormatDurationHuman formats a duration in human-readable format (e.g., "1h 30m", "5m", "30s").
func FormatDurationHuman(d time.Duration) string {
if d < 0 {
Expand Down Expand Up @@ -64,20 +49,6 @@ func FormatDurationHuman(d time.Duration) string {
return fmt.Sprintf("%ds", seconds)
}

// FormatInterval formats a duration as a compact interval string (e.g., "5s", "2m30s").
func FormatInterval(d time.Duration) string {
seconds := int(d.Seconds())
if seconds < 60 {
return fmt.Sprintf("%ds", seconds)
}
minutes := seconds / 60
seconds %= 60
if seconds == 0 {
return fmt.Sprintf("%dm", minutes)
}
return fmt.Sprintf("%dm%ds", minutes, seconds)
}

// maskTokenPreview masks token for preview display (shows first 8 and last 4 chars)
func maskTokenPreview(token string) string {
if len(token) <= 16 {
Expand Down
Loading