Skip to content

Commit

Permalink
CSRF/RequestID mw: switch math/random usage to crypto/random
Browse files Browse the repository at this point in the history
  • Loading branch information
aldas committed Jul 21, 2023
1 parent 3f8ae15 commit 626f13e
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 9 deletions.
4 changes: 2 additions & 2 deletions middleware/csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"time"

"github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
)

type (
Expand Down Expand Up @@ -103,6 +102,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
if config.TokenLength == 0 {
config.TokenLength = DefaultCSRFConfig.TokenLength
}

if config.TokenLookup == "" {
config.TokenLookup = DefaultCSRFConfig.TokenLookup
}
Expand Down Expand Up @@ -132,7 +132,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {

token := ""
if k, err := c.Cookie(config.CookieName); err != nil {
token = random.String(config.TokenLength) // Generate token
token = randomString(config.TokenLength)
} else {
token = k.Value // Reuse token
}
Expand Down
3 changes: 1 addition & 2 deletions middleware/csrf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"testing"

"github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -233,7 +232,7 @@ func TestCSRF(t *testing.T) {
assert.Error(t, h(c))

// Valid CSRF token
token := random.String(32)
token := randomString(32)
req.Header.Set(echo.HeaderCookie, "_csrf="+token)
req.Header.Set(echo.HeaderXCSRFToken, token)
if assert.NoError(t, h(c)) {
Expand Down
3 changes: 1 addition & 2 deletions middleware/rate_limiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"time"

"github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
"github.com/stretchr/testify/assert"
"golang.org/x/time/rate"
)
Expand Down Expand Up @@ -410,7 +409,7 @@ func TestNewRateLimiterMemoryStore(t *testing.T) {
func generateAddressList(count int) []string {
addrs := make([]string, count)
for i := 0; i < count; i++ {
addrs[i] = random.String(15)
addrs[i] = randomString(15)
}
return addrs
}
Expand Down
5 changes: 2 additions & 3 deletions middleware/request_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package middleware

import (
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
)

type (
Expand All @@ -12,7 +11,7 @@ type (
Skipper Skipper

// Generator defines a function to generate an ID.
// Optional. Default value random.String(32).
// Optional. Defaults to generator for random string of length 32.
Generator func() string

// RequestIDHandler defines a function which is executed for a request id.
Expand Down Expand Up @@ -73,5 +72,5 @@ func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc {
}

func generator() string {
return random.String(32)
return randomString(32)
}
17 changes: 17 additions & 0 deletions middleware/util.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package middleware

import (
"crypto/rand"
"fmt"
"strings"
)

Expand Down Expand Up @@ -52,3 +54,18 @@ func matchSubdomain(domain, pattern string) bool {
}
return false
}

func randomString(length uint8) string {
charset := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

bytes := make([]byte, length)
_, err := rand.Read(bytes)
if err != nil {
// we are out of random. let the request fail
panic(fmt.Errorf("echo randomString failed to read random bytes: %w", err))
}
for i, b := range bytes {
bytes[i] = charset[b%byte(len(charset))]
}
return string(bytes)
}
24 changes: 24 additions & 0 deletions middleware/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,27 @@ func Test_matchSubdomain(t *testing.T) {
assert.Equal(t, v.expected, matchSubdomain(v.domain, v.pattern))
}
}

func TestRandomString(t *testing.T) {
var testCases = []struct {
name string
whenLength uint8
expect string
}{
{
name: "ok, 16",
whenLength: 16,
},
{
name: "ok, 32",
whenLength: 32,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
uid := randomString(tc.whenLength)
assert.Len(t, uid, int(tc.whenLength))
})
}
}

0 comments on commit 626f13e

Please sign in to comment.