Skip to content

Commit

Permalink
Refactor JWT authentication and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nilBorodulia committed Apr 2, 2024
1 parent 73387ef commit 85c1045
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 185 deletions.
90 changes: 45 additions & 45 deletions auth_header_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,51 +2,51 @@ package rest

import (
"fmt"
"net/http"
"github.com/golang-jwt/jwt"
"github.com/golang-jwt/jwt"
"net/http"
)

func AuthenticationJwt(headerName, secret string, userCondition func(claims map[string]interface{}) error) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
if r.Header[headerName] == nil {
http.Error(w, "Can not find token in header", http.StatusForbidden)
return
}

token, _ := jwt.Parse(r.Header[headerName][0], func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("[ERROR] There was an error in parsing")
}

return []byte(secret), nil
})

if token == nil {
http.Error(w, "Invalid token", http.StatusUnauthorized)
return
}

if !token.Valid {
http.Error(w, "Forbidden", http.StatusForbidden)
return
}

claims, ok := token.Claims.(jwt.MapClaims)

if !ok {
w.Write([]byte("couldn't parse claims"));
w.WriteHeader(http.StatusUnauthorized)
return
}

if err := userCondition(claims); err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}

next.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
func AuthenticationJwt(headerName, secret string, userCondition func(claims map[string]interface{}) error) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
if r.Header[headerName] == nil {
http.Error(w, "Can not find token in header", http.StatusForbidden)
return
}

token, _ := jwt.Parse(r.Header[headerName][0], func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("[ERROR] There was an error in parsing")
}

return []byte(secret), nil
})

if token == nil {
http.Error(w, "Invalid token", http.StatusUnauthorized)
return
}

if !token.Valid {
http.Error(w, "Forbidden", http.StatusForbidden)
return
}

claims, ok := token.Claims.(jwt.MapClaims)

if !ok {
w.Write([]byte("couldn't parse claims"))
w.WriteHeader(http.StatusUnauthorized)
return
}

if err := userCondition(claims); err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}

next.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
}
196 changes: 98 additions & 98 deletions auth_header_jwt_test.go
Original file line number Diff line number Diff line change
@@ -1,113 +1,113 @@
package rest

import (
"fmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"io"
"net/http"
"net/http/httptest"
"testing"
"io"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"fmt"
)

func TestHeaderJwtTokenAuth(t *testing.T) {
jwtToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.WKQfGgHiRhXdkdz6Qy90gMQhYf3uK-GMeyAQBEs1EbQ"
jwtFail := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.1F5StBaWKNe53iB2919Agg3nMcCdwINDWlT0sNBaMbE"
jwtToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.WKQfGgHiRhXdkdz6Qy90gMQhYf3uK-GMeyAQBEs1EbQ"
jwtFail := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.1F5StBaWKNe53iB2919Agg3nMcCdwINDWlT0sNBaMbE"

handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("blabla blabla"))
require.NoError(t, err)
})
headerName := "Api-Token"
ts := httptest.NewServer(AuthenticationJwt(
headerName,
"1234567890",
func(claims map[string]interface{}) error {return nil})(handler))
defer ts.Close()
{
req, err := http.NewRequest("GET", ts.URL+"/ping", nil)
require.NoError(t, err)
req.Header.Set(headerName, jwtToken)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)
defer resp.Body.Close()
}
{
req, err := http.NewRequest("GET", ts.URL+"/ping", nil)
require.NoError(t, err)
req.Header.Set(headerName, "invalid")
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, "Invalid token\n", string(b))
}
{
req, err := http.NewRequest("GET", ts.URL+"/ping", nil)
require.NoError(t, err)
req.Header.Set(headerName, jwtFail)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, "Forbidden\n", string(b))
}
{
req, err := http.NewRequest("GET", ts.URL+"/ping", nil)
require.NoError(t, err)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, "Can not find token in header\n", string(b))
}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("blabla blabla"))
require.NoError(t, err)
})
headerName := "Api-Token"
ts := httptest.NewServer(AuthenticationJwt(
headerName,
"1234567890",
func(claims map[string]interface{}) error { return nil })(handler))
defer ts.Close()
{
req, err := http.NewRequest("GET", ts.URL+"/ping", nil)
require.NoError(t, err)
req.Header.Set(headerName, jwtToken)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)
defer resp.Body.Close()
}
{
req, err := http.NewRequest("GET", ts.URL+"/ping", nil)
require.NoError(t, err)
req.Header.Set(headerName, "invalid")
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, "Invalid token\n", string(b))
}
{
req, err := http.NewRequest("GET", ts.URL+"/ping", nil)
require.NoError(t, err)
req.Header.Set(headerName, jwtFail)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, "Forbidden\n", string(b))
}
{
req, err := http.NewRequest("GET", ts.URL+"/ping", nil)
require.NoError(t, err)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, "Can not find token in header\n", string(b))
}
}

func TestHeaderJwtTokenAuthCheckClaim(t *testing.T) {
jwtToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiMTIzIn0.tsuWzcw0zCYzHoq0Kflun7cxVJWKMdQwWczNhU2h2IQ"
jwtFail := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJub3RfdXNlciI6IjEyMyJ9.wGj6rh93KK83eaehCoxwmMCyEvyEQXadeJykayMkEd8"
jwtToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiMTIzIn0.tsuWzcw0zCYzHoq0Kflun7cxVJWKMdQwWczNhU2h2IQ"
jwtFail := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJub3RfdXNlciI6IjEyMyJ9.wGj6rh93KK83eaehCoxwmMCyEvyEQXadeJykayMkEd8"

handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("blabla blabla"))
require.NoError(t, err)
})
headerName := "Api-Token"
ts := httptest.NewServer(AuthenticationJwt(
headerName,
"1234567890",
func(claims map[string]interface{}) error {
if claims["user_id"] == nil {
return fmt.Errorf("user_id not found")
}
return nil
})(handler))
defer ts.Close()
{
req, err := http.NewRequest("GET", ts.URL+"/ping", nil)
require.NoError(t, err)
req.Header.Set(headerName, jwtToken)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)
defer resp.Body.Close()
}
{
req, err := http.NewRequest("GET", ts.URL+"/ping", nil)
require.NoError(t, err)
req.Header.Set(headerName, jwtFail)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, "user_id not found\n", string(b))
}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("blabla blabla"))
require.NoError(t, err)
})
headerName := "Api-Token"
ts := httptest.NewServer(AuthenticationJwt(
headerName,
"1234567890",
func(claims map[string]interface{}) error {
if claims["user_id"] == nil {
return fmt.Errorf("user_id not found")
}
return nil
})(handler))
defer ts.Close()
{
req, err := http.NewRequest("GET", ts.URL+"/ping", nil)
require.NoError(t, err)
req.Header.Set(headerName, jwtToken)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)
defer resp.Body.Close()
}
{
req, err := http.NewRequest("GET", ts.URL+"/ping", nil)
require.NoError(t, err)
req.Header.Set(headerName, jwtFail)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, "user_id not found\n", string(b))
}
}
23 changes: 12 additions & 11 deletions auth_header_token.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
package rest

import (
"net/http"
"net/http"
)

func Authentication(headerName, token string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
apiToken := r.Header.Get(headerName)
if apiToken != token {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
}
return http.HandlerFunc(fn)
}
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
apiToken := r.Header.Get(headerName)
if apiToken != token {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
}
Loading

0 comments on commit 85c1045

Please sign in to comment.