Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 39 additions & 7 deletions timeout.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package timeout

import (
"fmt"
"net/http"
"runtime/debug"
"time"

"github.com/gin-gonic/gin"
Expand All @@ -20,49 +23,77 @@ func New(opts ...Option) gin.HandlerFunc {
response: defaultResponse,
}

// Loop through each option
// Apply each option to the Timeout instance
for _, opt := range opts {
if opt == nil {
panic("timeout Option not be nil")
panic("timeout Option must not be nil")
}

// Call the option giving the instantiated
// Call the option to configure the Timeout instance
opt(t)
}

// If timeout is not set or is negative, return the original handler directly (no timeout logic).
if t.timeout <= 0 {
return t.handler
}

// Initialize the buffer pool for response writers.
bufPool = &BufferPool{}

return func(c *gin.Context) {
// Channel to signal handler completion.
finish := make(chan struct{}, 1)
panicChan := make(chan interface{}, 1)
// panicChan transmits both the panic value and the stack trace.
type panicInfo struct {
Value interface{}
Stack []byte
}
panicChan := make(chan panicInfo, 1)

// Swap the response writer with a buffered writer.
w := c.Writer
buffer := bufPool.Get()
tw := NewWriter(w, buffer)
c.Writer = tw
buffer.Reset()

// Run the handler in a separate goroutine to enforce timeout and catch panics.
go func() {
defer func() {
if p := recover(); p != nil {
panicChan <- p
// Capture both the panic value and the stack trace.
panicChan <- panicInfo{
Value: p,
Stack: debug.Stack(),
}
}
}()
t.handler(c)
finish <- struct{}{}
}()

select {
case p := <-panicChan:
case pi := <-panicChan:
// Handler panicked: free buffer, restore writer, and print stack trace if in debug mode.
tw.FreeBuffer()
c.Writer = w
panic(p)
// If in debug mode, write error and stack trace to response for easier debugging.
if gin.IsDebugging() {
// Add the panic error to Gin's error list and write 500 status and stack trace to response.
// Check the error return value of c.Error to satisfy errcheck linter.
_ = c.Error(fmt.Errorf("%v", pi.Value))
c.Writer.WriteHeader(http.StatusInternalServerError)
// Use fmt.Fprintf instead of Write([]byte(fmt.Sprintf(...))) to satisfy staticcheck.
_, _ = fmt.Fprintf(c.Writer, "panic caught: %v\n", pi.Value)
_, _ = c.Writer.Write([]byte("Panic stack trace:\n"))
_, _ = c.Writer.Write(pi.Stack)
}
// In non-debug mode, re-throw the panic to be handled by the upper middleware.
panic(pi)

case <-finish:
// Handler finished successfully: flush buffer to response.
tw.mu.Lock()
defer tw.mu.Unlock()
dst := tw.ResponseWriter.Header()
Expand All @@ -77,6 +108,7 @@ func New(opts ...Option) gin.HandlerFunc {
bufPool.Put(buffer)

case <-time.After(t.timeout):
// Handler timed out: abort context and write timeout response.
c.Abort()
tw.mu.Lock()
defer tw.mu.Unlock()
Expand Down
54 changes: 33 additions & 21 deletions timeout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,26 +82,6 @@ func TestSuccess(t *testing.T) {
assert.Equal(t, "", w.Body.String())
}

func panicResponse(c *gin.Context) {
panic("test")
}

func TestPanic(t *testing.T) {
r := gin.New()
r.Use(gin.Recovery())
r.GET("/", New(
WithTimeout(1*time.Second),
WithHandler(panicResponse),
))

w := httptest.NewRecorder()
req, _ := http.NewRequestWithContext(context.Background(), "GET", "/", nil)
r.ServeHTTP(w, req)

assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.Equal(t, "", w.Body.String())
}

func TestLargeResponse(t *testing.T) {
r := gin.New()
r.GET("/slow", New(
Expand Down Expand Up @@ -132,7 +112,10 @@ func TestLargeResponse(t *testing.T) {
wg.Wait()
}

// Test to ensure no further middleware is executed after timeout (covers c.Next() removal)
/*
Test to ensure no further middleware is executed after timeout (covers c.Next() removal)
This test verifies that after a timeout occurs, no subsequent middleware is executed.
*/
func TestNoNextAfterTimeout(t *testing.T) {
r := gin.New()
called := false
Expand All @@ -151,3 +134,32 @@ func TestNoNextAfterTimeout(t *testing.T) {
assert.Equal(t, http.StatusRequestTimeout, w.Code)
assert.False(t, called, "next middleware should not be called after timeout")
}

/*
TestTimeoutPanic: verifies the behavior when a panic occurs inside a handler wrapped by the timeout middleware.
This test ensures that a panic in the handler is caught by CustomRecovery and returns a 500 status code
with the panic message.
*/
func TestTimeoutPanic(t *testing.T) {
r := gin.New()
// Use CustomRecovery to catch panics and return a custom error message.
r.Use(gin.CustomRecovery(func(c *gin.Context, recovered interface{}) {
c.String(http.StatusInternalServerError, "panic caught: %v", recovered)
}))

// Register the timeout middleware; the handler will panic.
r.GET("/panic", New(
WithTimeout(100*time.Millisecond),
WithHandler(func(c *gin.Context) {
panic("timeout panic test")
}),
))

w := httptest.NewRecorder()
req, _ := http.NewRequestWithContext(context.Background(), "GET", "/panic", nil)
r.ServeHTTP(w, req)

// Verify the response status code and body.
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.Contains(t, w.Body.String(), "panic caught: timeout panic test")
}
Loading