Skip to content

Commit

Permalink
Add Ping, PanicRecovery, and AppInfo middlewares
Browse files Browse the repository at this point in the history
  • Loading branch information
nilBorodulia committed Jan 22, 2024
1 parent 307fb8e commit 2373339
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 29 deletions.
58 changes: 36 additions & 22 deletions middleware.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package rest

import (
"log"
"net/http"
"strings"
"encoding/json"
"log"
"net/http"
"runtime/debug"
"strings"
)

// Ping middleware response with pong to /ping. Stops chain if ping request detected
Expand All @@ -18,28 +18,42 @@ func Ping(next http.Handler) http.Handler {
return
}
if r.Method == "POST" && strings.HasSuffix(strings.ToLower(r.URL.Path), "/ping") {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
resp := make(map[string]string)
resp["message"] = "PONG"
jsonResp, _ := json.Marshal(resp)
_, _ = w.Write(jsonResp)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
resp := make(map[string]string)
resp["message"] = "PONG"
jsonResp, _ := json.Marshal(resp)
_, _ = w.Write(jsonResp)
return
}
next.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}

func PanicRecovery(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
defer func() {
if err := recover(); err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
log.Println("An error occurred:", err)
log.Println(string(debug.Stack()))
}
}()
next.ServeHTTP(w, req)
})
}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
defer func() {
if err := recover(); err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
log.Println("An error occurred:", err)
log.Println(string(debug.Stack()))
}
}()
next.ServeHTTP(w, req)
})
}

func AppInfo(app, author, version string) func(http.Handler) http.Handler {
f := func(h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Author", author)
w.Header().Set("App-Name", app)
w.Header().Set("App-Version", version)

h.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
return f
}
36 changes: 29 additions & 7 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ import (
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"bytes"
"encoding/json"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestMiddleware_Ping(t *testing.T) {
Expand Down Expand Up @@ -45,16 +45,16 @@ func TestMiddleware_PingPost(t *testing.T) {
ts := httptest.NewServer(Ping(handler))
defer ts.Close()

var jsonData = []byte("")
var jsonData = []byte("")

resp, err := http.Post(ts.URL + "/ping", "application/json", bytes.NewBuffer(jsonData))
resp, err := http.Post(ts.URL+"/ping", "application/json", bytes.NewBuffer(jsonData))
require.Nil(t, err)
assert.Equal(t, 200, resp.StatusCode)
defer resp.Body.Close()

var res map[string]interface{}

json.NewDecoder(resp.Body).Decode(&res)
json.NewDecoder(resp.Body).Decode(&res)

assert.NoError(t, err)
assert.Equal(t, "PONG", res["message"])
Expand All @@ -68,10 +68,32 @@ func TestMiddleware_PanicRecovery(t *testing.T) {
ts := httptest.NewServer(PanicRecovery(handler))
defer ts.Close()

var jsonData = []byte("")
var jsonData = []byte("")

resp, err := http.Post(ts.URL + "/error", "application/json", bytes.NewBuffer(jsonData))
resp, err := http.Post(ts.URL+"/error", "application/json", bytes.NewBuffer(jsonData))
require.Nil(t, err)
assert.Equal(t, 500, resp.StatusCode)
defer resp.Body.Close()
}

func TestMiddleware_AppInfo(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("bla bla"))
require.NoError(t, err)
})
ts := httptest.NewServer(AppInfo("app-name", "Nil", "12345")(handler))
defer ts.Close()

resp, err := http.Get(ts.URL + "/bla")
require.Nil(t, err)
assert.Equal(t, 200, resp.StatusCode)
defer resp.Body.Close()

b, err := io.ReadAll(resp.Body)
assert.NoError(t, err)

assert.Equal(t, "bla bla", string(b))
assert.Equal(t, "app-name", resp.Header.Get("App-Name"))
assert.Equal(t, "12345", resp.Header.Get("App-Version"))
assert.Equal(t, "Nil", resp.Header.Get("Author"))
}

0 comments on commit 2373339

Please sign in to comment.