Skip to content

Commit

Permalink
Merge pull request #16 from jtrw/develop
Browse files Browse the repository at this point in the history
Add SizeLimit middleware for body size validation
  • Loading branch information
nilBora committed Apr 8, 2024
2 parents 8d410f7 + 0cf874b commit 415f385
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 0 deletions.
40 changes: 40 additions & 0 deletions size_limit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package rest

import (
"bytes"
"io"
"net/http"
)

// SizeLimit middleware checks if body size is above the limit and returns StatusRequestEntityTooLarge (413)
func SizeLimit(size int64) func(http.Handler) http.Handler {

return func(h http.Handler) http.Handler {

fn := func(w http.ResponseWriter, r *http.Request) {

// check ContentLength
if r.ContentLength > size {
w.WriteHeader(http.StatusRequestEntityTooLarge)
return
}

// check size of the actual body
content, err := io.ReadAll(io.LimitReader(r.Body, size+1))
if err != nil {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
_ = r.Body.Close() // the original body already consumed

if int64(len(content)) > size {
w.WriteHeader(http.StatusRequestEntityTooLarge)
return
}
r.Body = io.NopCloser(bytes.NewReader(content))
h.ServeHTTP(w, r)
}

return http.HandlerFunc(fn)
}
}
69 changes: 69 additions & 0 deletions size_limit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package rest

import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/http/httputil"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestSizeLimit(t *testing.T) {

tbl := []struct {
method string
body string
code int
}{
{"GET", "", 200},
{"POST", "1234567", 200},
{"POST", "1234567890", 200},
{"POST", "12345678901", 413},
{"POST", "1234567", 200},
}

handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
require.NoError(t, err, "body read failed")
_, err = w.Write(body)
require.NoError(t, err, "body write failed")
dump, _ := httputil.DumpRequest(r, true)
t.Log(string(dump))
})

ts := httptest.NewServer(SizeLimit(10)(handler))
defer ts.Close()

for i, tt := range tbl {
i := i
tt := tt
for _, wrap := range []bool{false, true} {
wrap := wrap
t.Run(fmt.Sprintf("test-%d/%v", i, wrap), func(t *testing.T) {
client := http.Client{Timeout: 1 * time.Second}
var reader io.Reader = strings.NewReader(tt.body)
if wrap {
reader = io.NopCloser(reader) // to prevent ContentLength setting up
}
req, err := http.NewRequest(tt.method, fmt.Sprintf("%s/%d/%v", ts.URL, i, wrap), reader)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
require.Equal(t, tt.code, resp.StatusCode)

if resp.StatusCode != http.StatusRequestEntityTooLarge {
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, tt.body, string(body), "body match")
}
})
}
}
}

0 comments on commit 415f385

Please sign in to comment.