From ed5e207cdc7ff9244238d821d57c9244a623ccc8 Mon Sep 17 00:00:00 2001 From: Jannis Mattheis Date: Fri, 19 Jan 2018 22:08:44 +0100 Subject: [PATCH 1/2] Add models --- model/message.go | 9 +++++++++ model/token.go | 12 ++++++++++++ model/user.go | 9 +++++++++ 3 files changed, 30 insertions(+) create mode 100644 model/message.go create mode 100644 model/token.go create mode 100644 model/user.go diff --git a/model/message.go b/model/message.go new file mode 100644 index 00000000..44069e59 --- /dev/null +++ b/model/message.go @@ -0,0 +1,9 @@ +package model + +type Message struct { + ID uint `gorm:"primary_key" gorm:"AUTO_INCREMENT;primary_key;index"` + TokenID string + Message string + Title string + Priority int +} diff --git a/model/token.go b/model/token.go new file mode 100644 index 00000000..e295041d --- /dev/null +++ b/model/token.go @@ -0,0 +1,12 @@ +package model + +type Token struct { + Name string + DefaultTitle string + Description string + Icon string + WriteOnly bool + UserID uint `gorm:"index"` + Id string `gorm:"primary_key;unique_index"` + Messages []Message +} diff --git a/model/user.go b/model/user.go new file mode 100644 index 00000000..e995eb13 --- /dev/null +++ b/model/user.go @@ -0,0 +1,9 @@ +package model + +type User struct { + ID uint `gorm:"primary_key;unique_index" gorm:"AUTO_INCREMENT"` + Name string + Pass []byte + Admin bool + Tokens []Token +} From 223322f31552da11f91d330a50889a4655b9f440 Mon Sep 17 00:00:00 2001 From: Jannis Mattheis Date: Fri, 19 Jan 2018 22:19:27 +0100 Subject: [PATCH 2/2] Add authentication middleware --- auth/authentication.go | 101 +++++++++++++++++++++++ auth/authentication_test.go | 155 ++++++++++++++++++++++++++++++++++++ auth/password.go | 17 ++++ auth/password_test.go | 21 +++++ 4 files changed, 294 insertions(+) create mode 100644 auth/authentication.go create mode 100644 auth/authentication_test.go create mode 100644 auth/password.go create mode 100644 auth/password_test.go diff --git a/auth/authentication.go b/auth/authentication.go new file mode 100644 index 00000000..f377901e --- /dev/null +++ b/auth/authentication.go @@ -0,0 +1,101 @@ +package auth + +import ( + "errors" + "github.com/gin-gonic/gin" + "github.com/jmattheis/memo/model" + "strings" +) + +const ( + headerName = "Authorization" + headerSchema = "ApiKey " + typeAdmin = 0 + typeAll = 1 + typeWriteOnly = 2 +) + +type Database interface { + GetTokenById(id string) *model.Token + GetUserByName(name string) *model.User + GetUserById(id uint) *model.User +} + +type Auth struct { + DB Database +} + +func (a *Auth) RequireAdmin() gin.HandlerFunc { + return a.requireToken(typeAdmin) +} + +func (a *Auth) RequireAll() gin.HandlerFunc { + return a.requireToken(typeAll) +} + +func (a *Auth) RequireWrite() gin.HandlerFunc { + return a.requireToken(typeWriteOnly) +} + +func (a *Auth) tokenFromQueryOrHeader(ctx *gin.Context) *model.Token { + if token := a.tokenFromQuery(ctx); token != nil { + return token + } else if token := a.tokenFromHeader(ctx); token != nil { + return token + } + return nil +} + +func (a *Auth) tokenFromQuery(ctx *gin.Context) *model.Token { + if token := ctx.Request.URL.Query().Get("token"); token != "" { + return a.DB.GetTokenById(token) + } + return nil +} + +func (a *Auth) tokenFromHeader(ctx *gin.Context) *model.Token { + if header := ctx.Request.Header.Get(headerName); header != "" && strings.HasPrefix(header, headerSchema) { + return a.DB.GetTokenById(strings.TrimPrefix(header, headerSchema)) + } + return nil +} + +func (a *Auth) userFromBasicAuth(ctx *gin.Context) *model.User { + if name, pass, ok := ctx.Request.BasicAuth(); ok { + if user := a.DB.GetUserByName(name); user != nil && ComparePassword(user.Pass, []byte(pass)) { + return user + } + } + return nil +} + +func (a *Auth) isAuthenticated(checkType int, token *model.Token, user *model.User) bool { + if token == nil && user == nil { + return false + } + + switch checkType { + case typeWriteOnly: + return true + case typeAll: + return user != nil || (token != nil && !token.WriteOnly) + default: + if user == nil { + user = a.DB.GetUserById(token.UserID) + } + return user != nil && user.Admin + } +} + +func (a *Auth) requireToken(checkType int) gin.HandlerFunc { + return func(ctx *gin.Context) { + token := a.tokenFromQueryOrHeader(ctx) + user := a.userFromBasicAuth(ctx) + + if a.isAuthenticated(checkType, token, user) { + ctx.Next() + } else { + ctx.AbortWithError(401, errors.New("could not authenticate")) + } + } +} diff --git a/auth/authentication_test.go b/auth/authentication_test.go new file mode 100644 index 00000000..b7b7b1b7 --- /dev/null +++ b/auth/authentication_test.go @@ -0,0 +1,155 @@ +package auth + +import ( + "fmt" + "github.com/gin-gonic/gin" + "github.com/jmattheis/memo/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + "net/http/httptest" + "testing" +) + +func TestSuite(t *testing.T) { + suite.Run(t, new(AuthenticationSuite)) +} + +type AuthenticationSuite struct { + suite.Suite + auth *Auth +} + +func (s *AuthenticationSuite) SetupSuite() { + gin.SetMode(gin.TestMode) + s.auth = &Auth{&DBMock{}} +} + +func (s *AuthenticationSuite) TestQueryToken() { + s.assertQueryRequest("token", "ergerogerg", s.auth.RequireWrite, 401) + s.assertQueryRequest("token", "ergerogerg", s.auth.RequireAll, 401) + s.assertQueryRequest("token", "ergerogerg", s.auth.RequireAdmin, 401) + + s.assertQueryRequest("tokenx", "all", s.auth.RequireWrite, 401) + s.assertQueryRequest("tokenx", "all", s.auth.RequireAll, 401) + s.assertQueryRequest("tokenx", "all", s.auth.RequireAdmin, 401) + + s.assertQueryRequest("token", "writeonly", s.auth.RequireWrite, 200) + s.assertQueryRequest("token", "writeonly", s.auth.RequireAll, 401) + s.assertQueryRequest("token", "writeonly", s.auth.RequireAdmin, 401) + + s.assertQueryRequest("token", "all", s.auth.RequireWrite, 200) + s.assertQueryRequest("token", "all", s.auth.RequireAll, 200) + s.assertQueryRequest("token", "all", s.auth.RequireAdmin, 401) + + s.assertQueryRequest("token", "admin", s.auth.RequireWrite, 200) + s.assertQueryRequest("token", "admin", s.auth.RequireAll, 200) + s.assertQueryRequest("token", "admin", s.auth.RequireAdmin, 200) +} + +func (s *AuthenticationSuite) assertQueryRequest(key, value string, f fMiddleware, code int) { + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest("GET", fmt.Sprintf("/?%s=%s", key, value), nil) + f()(ctx) + assert.Equal(s.T(), code, recorder.Code) +} + +func (s *AuthenticationSuite) TestNothingProvided() { + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest("GET", "/", nil) + s.auth.RequireWrite()(ctx) + assert.Equal(s.T(), 401, recorder.Code) +} + +func (s *AuthenticationSuite) TestHeaderApiKeyToken() { + s.assertHeaderRequest("Authorization", "ergerogerg", s.auth.RequireWrite, 401) + s.assertHeaderRequest("Authorization", "ergerogerg", s.auth.RequireAll, 401) + s.assertHeaderRequest("Authorization", "ergerogerg", s.auth.RequireAdmin, 401) + + s.assertHeaderRequest("Authorization", "ApiKey ergerogerg", s.auth.RequireWrite, 401) + s.assertHeaderRequest("Authorization", "ApiKey ergerogerg", s.auth.RequireAll, 401) + s.assertHeaderRequest("Authorization", "ApiKey ergerogerg", s.auth.RequireAdmin, 401) + + s.assertHeaderRequest("Authorizationx", "ApiKey all", s.auth.RequireWrite, 401) + s.assertHeaderRequest("Authorizationx", "ApiKey all", s.auth.RequireAll, 401) + s.assertHeaderRequest("Authorizationx", "ApiKey all", s.auth.RequireAdmin, 401) + + s.assertHeaderRequest("Authorization", "ApiKey writeonly", s.auth.RequireWrite, 200) + s.assertHeaderRequest("Authorization", "ApiKey writeonly", s.auth.RequireAll, 401) + s.assertHeaderRequest("Authorization", "ApiKey writeonly", s.auth.RequireAdmin, 401) + + s.assertHeaderRequest("Authorization", "ApiKey all", s.auth.RequireWrite, 200) + s.assertHeaderRequest("Authorization", "ApiKey all", s.auth.RequireAll, 200) + s.assertHeaderRequest("Authorization", "ApiKey all", s.auth.RequireAdmin, 401) + + s.assertHeaderRequest("Authorization", "ApiKey admin", s.auth.RequireWrite, 200) + s.assertHeaderRequest("Authorization", "ApiKey admin", s.auth.RequireAll, 200) + s.assertHeaderRequest("Authorization", "ApiKey admin", s.auth.RequireAdmin, 200) +} + +func (s *AuthenticationSuite) TestBasicAuth() { + s.assertHeaderRequest("Authorization", "Basic ergerogerg", s.auth.RequireWrite, 401) + s.assertHeaderRequest("Authorization", "Basic ergerogerg", s.auth.RequireAll, 401) + s.assertHeaderRequest("Authorization", "Basic ergerogerg", s.auth.RequireAdmin, 401) + + // user existing:pw + s.assertHeaderRequest("Authorization", "Basic ZXhpc3Rpbmc6cHc=", s.auth.RequireWrite, 200) + s.assertHeaderRequest("Authorization", "Basic ZXhpc3Rpbmc6cHc=", s.auth.RequireAll, 200) + s.assertHeaderRequest("Authorization", "Basic ZXhpc3Rpbmc6cHc=", s.auth.RequireAdmin, 401) + + // user admin:pw + s.assertHeaderRequest("Authorization", "Basic YWRtaW46cHc=", s.auth.RequireWrite, 200) + s.assertHeaderRequest("Authorization", "Basic YWRtaW46cHc=", s.auth.RequireAll, 200) + s.assertHeaderRequest("Authorization", "Basic YWRtaW46cHc=", s.auth.RequireAdmin, 200) + + // user admin:pwx + s.assertHeaderRequest("Authorization", "Basic YWRtaW46cHd4", s.auth.RequireWrite, 401) + s.assertHeaderRequest("Authorization", "Basic YWRtaW46cHd4", s.auth.RequireAll, 401) + s.assertHeaderRequest("Authorization", "Basic YWRtaW46cHd4", s.auth.RequireAdmin, 401) +} + +func (s *AuthenticationSuite) assertHeaderRequest(key, value string, f fMiddleware, code int) { + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest("GET", "/", nil) + ctx.Request.Header.Set(key, value) + f()(ctx) + assert.Equal(s.T(), code, recorder.Code) +} + +type fMiddleware func() gin.HandlerFunc +type DBMock struct{} + +func (d *DBMock) GetTokenById(id string) *model.Token { + if id == "writeonly" { + return &model.Token{Id: "valid", WriteOnly: true, UserID: 1} + } + if id == "all" { + return &model.Token{Id: "valid", WriteOnly: false, UserID: 1} + } + if id == "admin" { + return &model.Token{Id: "valid", WriteOnly: false, UserID: 2} + } + return nil +} + +func (d *DBMock) GetUserByName(name string) *model.User { + if name == "existing" { + return &model.User{Name: "existing", Pass: CreatePassword("pw")} + } + if name == "admin" { + return &model.User{Name: "admin", Pass: CreatePassword("pw"), Admin: true} + } + return nil +} +func (d *DBMock) GetUserById(id uint) *model.User { + if id == 1 { + return &model.User{Name: "existing", Pass: CreatePassword("pw"), Admin: false} + } + + if id == 2 { + return &model.User{Name: "existing", Pass: CreatePassword("pw"), Admin: true} + } + return nil +} diff --git a/auth/password.go b/auth/password.go new file mode 100644 index 00000000..9d475065 --- /dev/null +++ b/auth/password.go @@ -0,0 +1,17 @@ +package auth + +import "golang.org/x/crypto/bcrypt" + +var strength = 13 + +func CreatePassword(pw string) []byte { + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(pw), strength) + if err != nil { + panic(err) + } + return hashedPassword +} + +func ComparePassword(hashedPassword, password []byte) bool { + return bcrypt.CompareHashAndPassword(hashedPassword, password) == nil +} diff --git a/auth/password_test.go b/auth/password_test.go new file mode 100644 index 00000000..cd7025c8 --- /dev/null +++ b/auth/password_test.go @@ -0,0 +1,21 @@ +package auth + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestPasswordSuccess(t *testing.T) { + password := CreatePassword("secret") + assert.Equal(t, true, ComparePassword(password, []byte("secret"))) +} + +func TestPasswordFailure(t *testing.T) { + password := CreatePassword("secret") + assert.Equal(t, false, ComparePassword(password, []byte("secretx"))) +} + +func TestBCryptFailure(t *testing.T) { + strength = 12312 // invalid value + assert.Panics(t, func() { CreatePassword("secret") }) +}