Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/handlers/oauth_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ func (h *OAuthHandler) OAuthCallback(c *gin.Context) {

// generateRandomState returns a URL-safe base64-encoded string of nBytes random bytes.
func generateRandomState(nBytes int) (string, error) {
b, err := util.CryptoRandomBytes(int64(nBytes))
b, err := util.CryptoRandomBytes(nBytes)
if err != nil {
return "", err
}
Expand Down
2 changes: 1 addition & 1 deletion internal/store/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func (s *Store) Close(ctx context.Context) error {
// generateRandomPassword generates a random password of specified length.
// Uses base64url encoding and truncates to length printable characters.
func generateRandomPassword(length int) (string, error) {
b, err := util.CryptoRandomBytes(int64(length))
b, err := util.CryptoRandomBytes(length)
if err != nil {
return "", err
}
Expand Down
32 changes: 9 additions & 23 deletions internal/util/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package util
import (
"context"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestSetIPContext(t *testing.T) {
Expand All @@ -28,20 +31,13 @@ func TestSetIPContext(t *testing.T) {
ctx := context.Background()
newCtx := SetIPContext(ctx, tt.ip)

if newCtx == nil {
t.Fatal("SetIPContext returned nil context")
}
require.NotNil(t, newCtx, "SetIPContext returned nil context")

// Try to retrieve the IP
retrievedIP := GetIPFromContext(newCtx)
if tt.expected {
if retrievedIP != tt.ip {
t.Errorf("Expected IP %s, got %s", tt.ip, retrievedIP)
}
assert.Equal(t, tt.ip, retrievedIP)
} else {
if retrievedIP != "" {
t.Errorf("Expected empty IP, but got %s", retrievedIP)
}
assert.Empty(t, retrievedIP)
}
})
}
Expand Down Expand Up @@ -80,29 +76,19 @@ func TestGetIPFromContext(t *testing.T) {
}

ip := GetIPFromContext(ctx)
if ip != tt.expected {
t.Errorf("Expected IP %q, got %q", tt.expected, ip)
}
assert.Equal(t, tt.expected, ip)
})
}
}

func TestIPContextChaining(t *testing.T) {
// Test that context values are preserved when chaining
type testKey int
const testKeyOther testKey = 0

ctx := context.Background()
ctx = context.WithValue(ctx, testKeyOther, "other_value")
ctx = SetIPContext(ctx, "192.168.1.1")

// Check IP is accessible
if GetIPFromContext(ctx) != "192.168.1.1" {
t.Error("IP context was not preserved")
}

// Check other values are accessible
if val := ctx.Value(testKeyOther); val != "other_value" {
t.Error("Other context values were not preserved")
}
assert.Equal(t, "192.168.1.1", GetIPFromContext(ctx))
assert.Equal(t, "other_value", ctx.Value(testKeyOther))
}
4 changes: 2 additions & 2 deletions internal/util/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ import (
)

// CryptoRandomBytes generates cryptographically secure random bytes
func CryptoRandomBytes(length int64) ([]byte, error) {
func CryptoRandomBytes(length int) ([]byte, error) {
buf := make([]byte, length)
_, err := rand.Read(buf)
return buf, err
}

// CryptoRandomString generates a random hex string for salts
func CryptoRandomString(length int) (string, error) {
randomBytes, err := CryptoRandomBytes(int64((length + 1) / 2))
randomBytes, err := CryptoRandomBytes((length + 1) / 2)
if err != nil {
return "", err
}
Expand Down
48 changes: 48 additions & 0 deletions internal/util/crypto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ func TestCryptoRandomBytes(t *testing.T) {

assert.NotEqual(t, bytes1, bytes2, "Random bytes should not be identical")
})

t.Run("Zero length returns empty slice", func(t *testing.T) {
bytes, err := CryptoRandomBytes(0)
require.NoError(t, err)
assert.Empty(t, bytes)
})
}

func TestCryptoRandomString(t *testing.T) {
Expand All @@ -41,6 +47,28 @@ func TestCryptoRandomString(t *testing.T) {
"Character '%c' is not a valid hex digit", c)
}
})

t.Run("Odd length produces correct length", func(t *testing.T) {
str, err := CryptoRandomString(7)
require.NoError(t, err)
assert.Len(t, str, 7)
})

t.Run("Length of 1", func(t *testing.T) {
str, err := CryptoRandomString(1)
require.NoError(t, err)
assert.Len(t, str, 1)
})

t.Run("Generate unique values", func(t *testing.T) {
str1, err := CryptoRandomString(32)
require.NoError(t, err)

str2, err := CryptoRandomString(32)
require.NoError(t, err)

assert.NotEqual(t, str1, str2)
})
}

func TestSHA256Hex(t *testing.T) {
Expand Down Expand Up @@ -109,4 +137,24 @@ func TestHashToken(t *testing.T) {

assert.NotEqual(t, hash1, hash2)
})

t.Run("Empty token produces valid hash", func(t *testing.T) {
hash := HashToken("", "some-salt")
assert.Len(t, hash, 100)
assert.NotEmpty(t, hash)
})

t.Run("Empty salt produces valid hash", func(t *testing.T) {
hash := HashToken("some-token", "")
assert.Len(t, hash, 100)
assert.NotEmpty(t, hash)
})

t.Run("Output contains only hex characters", func(t *testing.T) {
hash := HashToken("test-token", "test-salt")
for _, c := range hash {
assert.Truef(t, (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f'),
"Character '%c' is not a valid hex digit", c)
}
})
}
43 changes: 39 additions & 4 deletions internal/util/url_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package util

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestIsRedirectSafe(t *testing.T) {
Expand Down Expand Up @@ -147,6 +149,36 @@ func TestIsRedirectSafe(t *testing.T) {
want: false,
},

// Attack vectors - non-http schemes
{
name: "ftp URL",
redirectURL: "ftp://evil.com/file",
baseURL: baseURL,
want: false,
},
{
name: "file URL",
redirectURL: "file:///etc/passwd",
baseURL: baseURL,
want: false,
},

// Edge cases - invalid baseURL
{
name: "invalid baseURL with absolute redirect",
redirectURL: "http://localhost:8080/device",
baseURL: "://invalid-url",
want: false,
},

// Edge cases - invalid redirect URL
{
name: "unparseable redirect URL",
redirectURL: "http://[::1:bad",
baseURL: baseURL,
want: false,
},

// Valid edge cases
{
name: "path with fragments",
Expand All @@ -160,15 +192,18 @@ func TestIsRedirectSafe(t *testing.T) {
baseURL: baseURL,
want: true,
},
{
name: "scheme-only URL without host is unsafe",
redirectURL: "http:",
baseURL: baseURL,
want: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsRedirectSafe(tt.redirectURL, tt.baseURL)
if got != tt.want {
t.Errorf("IsRedirectSafe(%q, %q) = %v, want %v",
tt.redirectURL, tt.baseURL, got, tt.want)
}
assert.Equalf(t, tt.want, got, "IsRedirectSafe(%q, %q)", tt.redirectURL, tt.baseURL)
})
}
}
Loading