From faf35444001af66cf4c3cf878bea02c4252770b0 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Sun, 8 Mar 2026 12:22:12 +0800 Subject: [PATCH 1/4] refactor(util): simplify CryptoRandomBytes parameter and improve test coverage - Change CryptoRandomBytes parameter type from int64 to int, removing unnecessary casts at call sites - Migrate context_test.go and url_test.go to use testify assertions for consistency - Add edge case tests for zero length, odd length, empty token/salt, and invalid URLs Co-Authored-By: Claude Opus 4.6 --- internal/handlers/oauth_handler.go | 2 +- internal/store/sqlite.go | 2 +- internal/util/context_test.go | 32 ++++++-------------- internal/util/crypto.go | 4 +-- internal/util/crypto_test.go | 48 ++++++++++++++++++++++++++++++ internal/util/url_test.go | 43 +++++++++++++++++++++++--- 6 files changed, 100 insertions(+), 31 deletions(-) diff --git a/internal/handlers/oauth_handler.go b/internal/handlers/oauth_handler.go index 889d537a..0086ccfd 100644 --- a/internal/handlers/oauth_handler.go +++ b/internal/handlers/oauth_handler.go @@ -249,7 +249,7 @@ func (h *OAuthHandler) OAuthCallback(c *gin.Context) { // generateRandomState returns a URL-safe base64-encoded string of nBytes random bytes. func generateRandomState(nBytes int) (string, error) { - b, err := util.CryptoRandomBytes(int64(nBytes)) + b, err := util.CryptoRandomBytes(nBytes) if err != nil { return "", err } diff --git a/internal/store/sqlite.go b/internal/store/sqlite.go index 5f019236..18575458 100644 --- a/internal/store/sqlite.go +++ b/internal/store/sqlite.go @@ -129,7 +129,7 @@ func (s *Store) Close(ctx context.Context) error { // generateRandomPassword generates a random password of specified length. // Uses base64url encoding and truncates to length printable characters. func generateRandomPassword(length int) (string, error) { - b, err := util.CryptoRandomBytes(int64(length)) + b, err := util.CryptoRandomBytes(length) if err != nil { return "", err } diff --git a/internal/util/context_test.go b/internal/util/context_test.go index 4bb4ea4f..f147608c 100644 --- a/internal/util/context_test.go +++ b/internal/util/context_test.go @@ -3,6 +3,9 @@ package util import ( "context" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestSetIPContext(t *testing.T) { @@ -28,20 +31,13 @@ func TestSetIPContext(t *testing.T) { ctx := context.Background() newCtx := SetIPContext(ctx, tt.ip) - if newCtx == nil { - t.Fatal("SetIPContext returned nil context") - } + require.NotNil(t, newCtx, "SetIPContext returned nil context") - // Try to retrieve the IP retrievedIP := GetIPFromContext(newCtx) if tt.expected { - if retrievedIP != tt.ip { - t.Errorf("Expected IP %s, got %s", tt.ip, retrievedIP) - } + assert.Equal(t, tt.ip, retrievedIP) } else { - if retrievedIP != "" { - t.Errorf("Expected empty IP, but got %s", retrievedIP) - } + assert.Empty(t, retrievedIP) } }) } @@ -80,15 +76,12 @@ func TestGetIPFromContext(t *testing.T) { } ip := GetIPFromContext(ctx) - if ip != tt.expected { - t.Errorf("Expected IP %q, got %q", tt.expected, ip) - } + assert.Equal(t, tt.expected, ip) }) } } func TestIPContextChaining(t *testing.T) { - // Test that context values are preserved when chaining type testKey int const testKeyOther testKey = 0 @@ -96,13 +89,6 @@ func TestIPContextChaining(t *testing.T) { ctx = context.WithValue(ctx, testKeyOther, "other_value") ctx = SetIPContext(ctx, "192.168.1.1") - // Check IP is accessible - if GetIPFromContext(ctx) != "192.168.1.1" { - t.Error("IP context was not preserved") - } - - // Check other values are accessible - if val := ctx.Value(testKeyOther); val != "other_value" { - t.Error("Other context values were not preserved") - } + assert.Equal(t, "192.168.1.1", GetIPFromContext(ctx)) + assert.Equal(t, "other_value", ctx.Value(testKeyOther)) } diff --git a/internal/util/crypto.go b/internal/util/crypto.go index c35ff492..86a31af1 100644 --- a/internal/util/crypto.go +++ b/internal/util/crypto.go @@ -9,7 +9,7 @@ import ( ) // CryptoRandomBytes generates cryptographically secure random bytes -func CryptoRandomBytes(length int64) ([]byte, error) { +func CryptoRandomBytes(length int) ([]byte, error) { buf := make([]byte, length) _, err := rand.Read(buf) return buf, err @@ -17,7 +17,7 @@ func CryptoRandomBytes(length int64) ([]byte, error) { // CryptoRandomString generates a random hex string for salts func CryptoRandomString(length int) (string, error) { - randomBytes, err := CryptoRandomBytes(int64((length + 1) / 2)) + randomBytes, err := CryptoRandomBytes((length + 1) / 2) if err != nil { return "", err } diff --git a/internal/util/crypto_test.go b/internal/util/crypto_test.go index d835fe0c..7089ba0a 100644 --- a/internal/util/crypto_test.go +++ b/internal/util/crypto_test.go @@ -23,6 +23,12 @@ func TestCryptoRandomBytes(t *testing.T) { assert.NotEqual(t, bytes1, bytes2, "Random bytes should not be identical") }) + + t.Run("Zero length returns empty slice", func(t *testing.T) { + bytes, err := CryptoRandomBytes(0) + require.NoError(t, err) + assert.Empty(t, bytes) + }) } func TestCryptoRandomString(t *testing.T) { @@ -41,6 +47,28 @@ func TestCryptoRandomString(t *testing.T) { "Character '%c' is not a valid hex digit", c) } }) + + t.Run("Odd length produces correct length", func(t *testing.T) { + str, err := CryptoRandomString(7) + require.NoError(t, err) + assert.Len(t, str, 7) + }) + + t.Run("Length of 1", func(t *testing.T) { + str, err := CryptoRandomString(1) + require.NoError(t, err) + assert.Len(t, str, 1) + }) + + t.Run("Generate unique values", func(t *testing.T) { + str1, err := CryptoRandomString(32) + require.NoError(t, err) + + str2, err := CryptoRandomString(32) + require.NoError(t, err) + + assert.NotEqual(t, str1, str2) + }) } func TestSHA256Hex(t *testing.T) { @@ -109,4 +137,24 @@ func TestHashToken(t *testing.T) { assert.NotEqual(t, hash1, hash2) }) + + t.Run("Empty token produces valid hash", func(t *testing.T) { + hash := HashToken("", "some-salt") + assert.Len(t, hash, 100) + assert.NotEmpty(t, hash) + }) + + t.Run("Empty salt produces valid hash", func(t *testing.T) { + hash := HashToken("some-token", "") + assert.Len(t, hash, 100) + assert.NotEmpty(t, hash) + }) + + t.Run("Output contains only hex characters", func(t *testing.T) { + hash := HashToken("test-token", "test-salt") + for _, c := range hash { + assert.True(t, (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f'), + "Character '%c' is not a valid hex digit", c) + } + }) } diff --git a/internal/util/url_test.go b/internal/util/url_test.go index dea8b252..6f84d250 100644 --- a/internal/util/url_test.go +++ b/internal/util/url_test.go @@ -2,6 +2,8 @@ package util import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestIsRedirectSafe(t *testing.T) { @@ -147,6 +149,36 @@ func TestIsRedirectSafe(t *testing.T) { want: false, }, + // Attack vectors - non-http schemes + { + name: "ftp URL", + redirectURL: "ftp://evil.com/file", + baseURL: baseURL, + want: false, + }, + { + name: "file URL", + redirectURL: "file:///etc/passwd", + baseURL: baseURL, + want: false, + }, + + // Edge cases - invalid baseURL + { + name: "invalid baseURL with absolute redirect", + redirectURL: "http://localhost:8080/device", + baseURL: "://invalid-url", + want: false, + }, + + // Edge cases - invalid redirect URL + { + name: "unparseable redirect URL", + redirectURL: "http://[::1:bad", + baseURL: baseURL, + want: false, + }, + // Valid edge cases { name: "path with fragments", @@ -160,15 +192,18 @@ func TestIsRedirectSafe(t *testing.T) { baseURL: baseURL, want: true, }, + { + name: "absolute URL without host matches any baseURL", + redirectURL: "http:", + baseURL: baseURL, + want: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := IsRedirectSafe(tt.redirectURL, tt.baseURL) - if got != tt.want { - t.Errorf("IsRedirectSafe(%q, %q) = %v, want %v", - tt.redirectURL, tt.baseURL, got, tt.want) - } + assert.Equal(t, tt.want, got, "IsRedirectSafe(%q, %q)", tt.redirectURL, tt.baseURL) }) } } From 67db77e0077a016c96da801726a22a7fd7c3d780 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Sun, 8 Mar 2026 12:26:34 +0800 Subject: [PATCH 2/4] Update internal/util/url_test.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- internal/util/url_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/util/url_test.go b/internal/util/url_test.go index 6f84d250..291fb80a 100644 --- a/internal/util/url_test.go +++ b/internal/util/url_test.go @@ -193,10 +193,10 @@ func TestIsRedirectSafe(t *testing.T) { want: true, }, { - name: "absolute URL without host matches any baseURL", + name: "scheme-only URL without host is unsafe", redirectURL: "http:", baseURL: baseURL, - want: true, + want: false, }, } From 442eab23a833e7002d574cc99eaaf8029eb9c4a7 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Sun, 8 Mar 2026 12:26:41 +0800 Subject: [PATCH 3/4] Update internal/util/url_test.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- internal/util/url_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/util/url_test.go b/internal/util/url_test.go index 291fb80a..eaf742d7 100644 --- a/internal/util/url_test.go +++ b/internal/util/url_test.go @@ -203,7 +203,7 @@ func TestIsRedirectSafe(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := IsRedirectSafe(tt.redirectURL, tt.baseURL) - assert.Equal(t, tt.want, got, "IsRedirectSafe(%q, %q)", tt.redirectURL, tt.baseURL) + assert.Equalf(t, tt.want, got, "IsRedirectSafe(%q, %q)", tt.redirectURL, tt.baseURL) }) } } From 566cc4716ccf79ce5c6a9f7542d13679632cd90b Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Sun, 8 Mar 2026 12:26:51 +0800 Subject: [PATCH 4/4] Update internal/util/crypto_test.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- internal/util/crypto_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/util/crypto_test.go b/internal/util/crypto_test.go index 7089ba0a..03f5a1f1 100644 --- a/internal/util/crypto_test.go +++ b/internal/util/crypto_test.go @@ -153,7 +153,7 @@ func TestHashToken(t *testing.T) { t.Run("Output contains only hex characters", func(t *testing.T) { hash := HashToken("test-token", "test-salt") for _, c := range hash { - assert.True(t, (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f'), + assert.Truef(t, (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f'), "Character '%c' is not a valid hex digit", c) } })