Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion timeout.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func New(opts ...Option) gin.HandlerFunc {
for k, vv := range tw.Header() {
dst[k] = vv
}
tw.ResponseWriter.WriteHeader(tw.code)

if _, err := tw.ResponseWriter.Write(buffer.Bytes()); err != nil {
panic(err)
}
Expand Down
5 changes: 4 additions & 1 deletion writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ func (w *Writer) Write(data []byte) (int, error) {
return w.body.Write(data)
}

// WriteHeader will write http status code
// WriteHeader sends an HTTP response header with the provided status code.
// If the response writer has already written headers or if a timeout has occurred,
// this method does nothing.
func (w *Writer) WriteHeader(code int) {
checkWriteHeaderCode(code)
if w.timeout || w.wroteHeaders {
Expand All @@ -48,6 +50,7 @@ func (w *Writer) WriteHeader(code int) {
defer w.mu.Unlock()

w.writeHeader(code)
w.ResponseWriter.WriteHeader(code)
}

func (w *Writer) writeHeader(code int) {
Expand Down
146 changes: 146 additions & 0 deletions writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package timeout

import (
"fmt"
"log"
"net/http"
"net/http/httptest"
"strconv"
Expand Down Expand Up @@ -57,3 +58,148 @@ func TestWriter_Status(t *testing.T) {
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.Equal(t, strconv.Itoa(http.StatusInternalServerError), req.Header.Get("X-Status-Code-MW-Set"))
}

// testNew is a copy of New() with a small change to the timeoutHandler() function.
// ref: https://github.com/gin-contrib/timeout/issues/31
func testNew(duration time.Duration) gin.HandlerFunc {
return New(
WithTimeout(duration),
WithHandler(func(c *gin.Context) { c.Next() }),
WithResponse(timeoutHandler()),
)
}

// timeoutHandler returns a handler that returns a 504 Gateway Timeout error.
func timeoutHandler() gin.HandlerFunc {
gatewayTimeoutErr := struct {
Error string `json:"error"`
}{
Error: "Timed out.",
}

return func(c *gin.Context) {
log.Printf("request timed out: [method=%s,path=%s]",
c.Request.Method, c.Request.URL.Path)
c.JSON(http.StatusGatewayTimeout, gatewayTimeoutErr)
}
}

// TestHTTPStatusCode tests the HTTP status code of the response.
func TestHTTPStatusCode(t *testing.T) {
gin.SetMode(gin.ReleaseMode)

type testCase struct {
Name string
Method string
Path string
ExpStatusCode int
Handler gin.HandlerFunc
}

var (
cases = []testCase{
{
Name: "Plain text (200)",
Method: http.MethodGet,
Path: "/me",
ExpStatusCode: http.StatusOK,
Handler: func(ctx *gin.Context) {
ctx.String(http.StatusOK, "I'm text!")
},
},
{
Name: "Plain text (201)",
Method: http.MethodGet,
Path: "/me",
ExpStatusCode: http.StatusCreated,
Handler: func(ctx *gin.Context) {
ctx.String(http.StatusCreated, "I'm created!")
},
},
{
Name: "Plain text (204)",
Method: http.MethodGet,
Path: "/me",
ExpStatusCode: http.StatusNoContent,
Handler: func(ctx *gin.Context) {
ctx.String(http.StatusNoContent, "")
},
},
{
Name: "Plain text (400)",
Method: http.MethodGet,
Path: "/me",
ExpStatusCode: http.StatusBadRequest,
Handler: func(ctx *gin.Context) {
ctx.String(http.StatusBadRequest, "")
},
},
{
Name: "JSON (200)",
Method: http.MethodGet,
Path: "/me",
ExpStatusCode: http.StatusOK,
Handler: func(ctx *gin.Context) {
ctx.JSON(http.StatusOK, gin.H{"field": "value"})
},
},
{
Name: "JSON (201)",
Method: http.MethodGet,
Path: "/me",
ExpStatusCode: http.StatusCreated,
Handler: func(ctx *gin.Context) {
ctx.JSON(http.StatusCreated, gin.H{"field": "value"})
},
},
{
Name: "JSON (204)",
Method: http.MethodGet,
Path: "/me",
ExpStatusCode: http.StatusNoContent,
Handler: func(ctx *gin.Context) {
ctx.JSON(http.StatusNoContent, nil)
},
},
{
Name: "JSON (400)",
Method: http.MethodGet,
Path: "/me",
ExpStatusCode: http.StatusBadRequest,
Handler: func(ctx *gin.Context) {
ctx.JSON(http.StatusBadRequest, nil)
},
},
{
Name: "No reply",
Method: http.MethodGet,
Path: "/me",
ExpStatusCode: http.StatusOK,
Handler: func(ctx *gin.Context) {},
},
}

initCase = func(c testCase) (*http.Request, *httptest.ResponseRecorder) {
return httptest.NewRequest(c.Method, c.Path, nil), httptest.NewRecorder()
}
)

for i := range cases {
t.Run(cases[i].Name, func(tt *testing.T) {
tt.Logf("Test case [%s]", cases[i].Name)

router := gin.Default()

router.Use(testNew(1 * time.Second))
router.GET("/*root", cases[i].Handler)

req, resp := initCase(cases[i])
router.ServeHTTP(resp, req)

if resp.Code != cases[i].ExpStatusCode {
tt.Errorf("response is different from expected:\nexp: >>>%d<<<\ngot: >>>%d<<<",
cases[i].ExpStatusCode, resp.Code)
}
})
}
}