diff --git a/src/app/webapi/component/auth/component.go b/src/app/webapi/component/auth/component.go index 432d52b..35075cc 100644 --- a/src/app/webapi/component/auth/component.go +++ b/src/app/webapi/component/auth/component.go @@ -18,5 +18,5 @@ type Endpoint struct { // Routes will set up the endpoints. func (p *Endpoint) Routes(router component.IRouter) { - router.Get("/v1/auth", component.H(p.Index)) + router.Get("/v1/auth", p.Index) } diff --git a/src/app/webapi/component/auth/index_test.go b/src/app/webapi/component/auth/index_test.go index e208161..d8d124f 100644 --- a/src/app/webapi/component/auth/index_test.go +++ b/src/app/webapi/component/auth/index_test.go @@ -50,5 +50,5 @@ func TestIndexError(t *testing.T) { mux.ServeHTTP(w, r) assert.Equal(t, http.StatusInternalServerError, w.Code) - assert.Equal(t, `{"status":"Internal Server Error","message":"generate error"}`+"\n", w.Body.String()) + assert.Equal(t, `generate error`+"\n", w.Body.String()) } diff --git a/src/app/webapi/component/core_mock.go b/src/app/webapi/component/core_mock.go index 1621443..a7f205d 100644 --- a/src/app/webapi/component/core_mock.go +++ b/src/app/webapi/component/core_mock.go @@ -1,40 +1,16 @@ package component import ( - "log" - "app/webapi/internal/bind" "app/webapi/internal/response" "app/webapi/internal/testutil" - "app/webapi/pkg/database" "app/webapi/pkg/query" ) -// TestDatabase returns a test database. -func TestDatabase(dbSpecificDB bool) *database.DBW { - dbc := new(database.Connection) - dbc.Hostname = "127.0.0.1" - dbc.Port = 3306 - dbc.Username = "root" - dbc.Password = "" - dbc.Database = "webapitest" - dbc.Parameter = "parseTime=true&allowNativePasswords=true" - - connection, err := dbc.Connect(dbSpecificDB) - if err != nil { - log.Println("DB Error:", err) - } - - dbw := database.New(connection) - - return dbw -} - // NewCoreMock returns all mocked dependencies. func NewCoreMock() (Core, *CoreMock) { ml := new(testutil.MockLogger) - //md := new(testutil.MockDatabase) - md := TestDatabase(true) + md := testutil.ConnectDatabase(true) mq := query.New(md) mt := new(testutil.MockToken) resp := response.New() diff --git a/src/app/webapi/component/handler.go b/src/app/webapi/component/handler.go deleted file mode 100644 index 5fd7ee6..0000000 --- a/src/app/webapi/component/handler.go +++ /dev/null @@ -1,41 +0,0 @@ -package component - -import ( - "encoding/json" - "net/http" - - "app/webapi/internal/response" -) - -// H is used to wrapper all endpoint functions so they work with generic -// routers. -type H func(http.ResponseWriter, *http.Request) (int, error) - -// ServeHTTP handles all the errors from the HTTP handlers. -func (fn H) ServeHTTP(w http.ResponseWriter, r *http.Request) { - status, err := fn(w, r) - // Handle only errors. - if status >= 400 { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - - resp := new(response.OKResponse) - resp.Body.Status = http.StatusText(status) - if err != nil { - resp.Body.Message = err.Error() - } - - err := json.NewEncoder(w).Encode(resp.Body) - if err != nil { - w.Write([]byte(`{"status":"Internal Server Error","message":"problem encoding JSON"}`)) - return - } - } - - // Only output 500 errors. - /*if status >= 500 { - if err != nil { - log.Println(err) - } - }*/ -} diff --git a/src/app/webapi/component/interface.go b/src/app/webapi/component/interface.go index 6604c14..d11c86c 100644 --- a/src/app/webapi/component/interface.go +++ b/src/app/webapi/component/interface.go @@ -1,6 +1,7 @@ package component import ( + "app/webapi/pkg/router" "database/sql" "net/http" "time" @@ -10,20 +11,9 @@ import ( // IDatabase provides data query capabilities. type IDatabase interface { - Select(dest interface{}, query string, args ...interface{}) error - Get(dest interface{}, query string, args ...interface{}) error Exec(query string, args ...interface{}) (sql.Result, error) - QueryRowScan(dest interface{}, query string, args ...interface{}) error - - ExistsString(err error, s string) (bool, string, error) - Error(err error) error - AffectedRows(result sql.Result) int - - /*PaginatedResults(results interface{}, fn func() (results interface{}, total int, err error)) (total int, err error) - RecordExistsInt(fn func() (exists bool, ID int64, err error)) (exists bool, ID int64, err error) - RecordExistsString(fn func() (exists bool, ID string, err error)) (exists bool, ID string, err error) - AddRecordInt(fn func() (ID int64, err error)) (ID int64, err error) - AddRecordString(fn func() (ID string, err error)) (ID string, err error)*/ + Get(dest interface{}, query string, args ...interface{}) error + Select(dest interface{}, query string, args ...interface{}) error } // IQuery provides default queries. @@ -36,40 +26,6 @@ type IQuery interface { DeleteAll(dest query.IRecord) (affected int, err error) } -// IRecord provides table information. -type IRecord interface { - Table() string - PrimaryKey() string -} - -/* -// IQuery provides data query capabilities. -type IQuery interface { - Select(dest interface{}, query string, args ...interface{}) error - Get(dest interface{}, query string, args ...interface{}) error - Exec(query string, args ...interface{}) (sql.Result, error) - - ExistsString(err error, s string) (bool, string, error) - Error(err error) error - AffectedRows(result sql.Result) int -}*/ - -// IDatabase provides data query capabilities. -/*type IDatabase interface { - //LastInsertID(r sql.Result, err error) (int64, error) - //MySQLTimestamp(t time.Time) string - //GoTimestamp(s string) (time.Time, error) - - //ExistsID(err error, ID int64) (bool, int64, error) - - PaginatedResults(results interface{}, fn func() (results interface{}, total int, err error)) (total int, err error) - RecordExistsInt(fn func() (exists bool, ID int64, err error)) (exists bool, ID int64, err error) - RecordExistsString(fn func() (exists bool, ID string, err error)) (exists bool, ID string, err error) - AddRecordInt(fn func() (ID int64, err error)) (ID int64, err error) - AddRecordString(fn func() (ID string, err error)) (ID string, err error) - //ExecQuery(fn func() (err error)) (err error) -}*/ - // ILogger provides logging capabilities. type ILogger interface { //ControllerError(r *http.Request, err error, a ...interface{}) @@ -79,13 +35,13 @@ type ILogger interface { // IRouter provides routing capabilities. type IRouter interface { - Delete(path string, fn http.Handler) - Get(path string, fn http.Handler) - Head(path string, fn http.Handler) - Options(path string, fn http.Handler) - Patch(path string, fn http.Handler) - Post(path string, fn http.Handler) - Put(path string, fn http.Handler) + Delete(path string, fn router.Handler) + Get(path string, fn router.Handler) + Head(path string, fn router.Handler) + Options(path string, fn router.Handler) + Patch(path string, fn router.Handler) + Post(path string, fn router.Handler) + Put(path string, fn router.Handler) } // IBind provides bind and validation for requests. diff --git a/src/app/webapi/component/root/component.go b/src/app/webapi/component/root/component.go index 9796995..d52e337 100644 --- a/src/app/webapi/component/root/component.go +++ b/src/app/webapi/component/root/component.go @@ -18,5 +18,5 @@ type Endpoint struct { // Routes will set up the endpoints. func (p *Endpoint) Routes(router component.IRouter) { - router.Get("/v1", component.H(p.Index)) + router.Get("/v1", p.Index) } diff --git a/src/app/webapi/component/user/component.go b/src/app/webapi/component/user/component.go index 53c3326..33bd9f5 100644 --- a/src/app/webapi/component/user/component.go +++ b/src/app/webapi/component/user/component.go @@ -18,10 +18,10 @@ type Endpoint struct { // Routes will set up the endpoints. func (p *Endpoint) Routes(router component.IRouter) { - router.Post("/v1/user", component.H(p.Create)) - router.Get("/v1/user/:user_id", component.H(p.Show)) - router.Get("/v1/user", component.H(p.Index)) - router.Put("/v1/user/:user_id", component.H(p.Update)) - router.Delete("/v1/user/:user_id", component.H(p.Destroy)) - router.Delete("/v1/user", component.H(p.DestroyAll)) + router.Post("/v1/user", p.Create) + router.Get("/v1/user/:user_id", p.Show) + router.Get("/v1/user", p.Index) + router.Put("/v1/user/:user_id", p.Update) + router.Delete("/v1/user/:user_id", p.Destroy) + router.Delete("/v1/user", p.DestroyAll) } diff --git a/src/app/webapi/component/user/create_test.go b/src/app/webapi/component/user/create_test.go index 9892489..66fd872 100644 --- a/src/app/webapi/component/user/create_test.go +++ b/src/app/webapi/component/user/create_test.go @@ -9,7 +9,7 @@ import ( "app/webapi/component" "app/webapi/component/user" - "app/webapi/internal/testdb" + "app/webapi/internal/testutil" "app/webapi/pkg/router" "app/webapi/store" @@ -17,7 +17,7 @@ import ( ) func TestCreate(t *testing.T) { - testdb.SetupTest(t) + testutil.LoadDatabase(t) core, _ := component.NewCoreMock() mux := router.New() @@ -39,7 +39,7 @@ func TestCreate(t *testing.T) { } func TestCreateUserAlreadyExists(t *testing.T) { - testdb.SetupTest(t) + testutil.LoadDatabase(t) core, _ := component.NewCoreMock() mux := router.New() @@ -61,11 +61,11 @@ func TestCreateUserAlreadyExists(t *testing.T) { mux.ServeHTTP(w, r) assert.Equal(t, http.StatusBadRequest, w.Code) - assert.Contains(t, w.Body.String(), `{"status":"Bad Request","message":"user already exists"}`) + assert.Contains(t, w.Body.String(), `user already exists`) } func TestCreateBadEmail(t *testing.T) { - testdb.SetupTest(t) + testutil.LoadDatabase(t) core, _ := component.NewCoreMock() mux := router.New() diff --git a/src/app/webapi/component/user/destroy_all_test.go b/src/app/webapi/component/user/destroy_all_test.go index e32cdad..f7fc527 100644 --- a/src/app/webapi/component/user/destroy_all_test.go +++ b/src/app/webapi/component/user/destroy_all_test.go @@ -7,7 +7,7 @@ import ( "app/webapi/component" "app/webapi/component/user" - "app/webapi/internal/testdb" + "app/webapi/internal/testutil" "app/webapi/pkg/router" "app/webapi/store" @@ -15,7 +15,7 @@ import ( ) func TestDestroyAll(t *testing.T) { - testdb.SetupTest(t) + testutil.LoadDatabase(t) core, _ := component.NewCoreMock() mux := router.New() @@ -35,7 +35,7 @@ func TestDestroyAll(t *testing.T) { } func TestDestroyAllNoUsers(t *testing.T) { - testdb.SetupTest(t) + testutil.LoadDatabase(t) core, _ := component.NewCoreMock() mux := router.New() @@ -47,5 +47,5 @@ func TestDestroyAllNoUsers(t *testing.T) { mux.ServeHTTP(w, r) assert.Equal(t, http.StatusBadRequest, w.Code) - assert.Contains(t, w.Body.String(), `{"status":"Bad Request","message":"no users to delete"}`) + assert.Contains(t, w.Body.String(), `no users to delete`) } diff --git a/src/app/webapi/component/user/destroy_test.go b/src/app/webapi/component/user/destroy_test.go index e52f7fa..9e4f325 100644 --- a/src/app/webapi/component/user/destroy_test.go +++ b/src/app/webapi/component/user/destroy_test.go @@ -9,7 +9,7 @@ import ( "app/webapi/component" "app/webapi/component/user" - "app/webapi/internal/testdb" + "app/webapi/internal/testutil" "app/webapi/pkg/router" "app/webapi/store" @@ -17,7 +17,7 @@ import ( ) func TestDestroy(t *testing.T) { - testdb.SetupTest(t) + testutil.LoadDatabase(t) core, _ := component.NewCoreMock() mux := router.New() diff --git a/src/app/webapi/component/user/index_test.go b/src/app/webapi/component/user/index_test.go index 4e823e8..71be531 100644 --- a/src/app/webapi/component/user/index_test.go +++ b/src/app/webapi/component/user/index_test.go @@ -7,7 +7,7 @@ import ( "app/webapi/component" "app/webapi/component/user" - "app/webapi/internal/testdb" + "app/webapi/internal/testutil" "app/webapi/pkg/router" "app/webapi/store" @@ -15,7 +15,7 @@ import ( ) func TestIndexEmpty(t *testing.T) { - testdb.SetupTest(t) + testutil.LoadDatabase(t) core, _ := component.NewCoreMock() mux := router.New() @@ -30,7 +30,7 @@ func TestIndexEmpty(t *testing.T) { } func TestIndexOne(t *testing.T) { - testdb.SetupTest(t) + testutil.LoadDatabase(t) core, _ := component.NewCoreMock() mux := router.New() diff --git a/src/app/webapi/component/user/request_test.go b/src/app/webapi/component/user/request_test.go index 93e528e..867edd2 100644 --- a/src/app/webapi/component/user/request_test.go +++ b/src/app/webapi/component/user/request_test.go @@ -8,7 +8,7 @@ import ( "app/webapi/component" "app/webapi/component/user" - "app/webapi/internal/testdb" + "app/webapi/internal/testutil" "app/webapi/pkg/router" "github.com/stretchr/testify/assert" @@ -21,7 +21,7 @@ func TestRequestValidation(t *testing.T) { } { arr := strings.Split(v, " ") - testdb.SetupTest(t) + testutil.LoadDatabase(t) core, _ := component.NewCoreMock() mux := router.New() diff --git a/src/app/webapi/component/user/show_test.go b/src/app/webapi/component/user/show_test.go index 6bbeda0..4f4d70e 100644 --- a/src/app/webapi/component/user/show_test.go +++ b/src/app/webapi/component/user/show_test.go @@ -7,7 +7,7 @@ import ( "app/webapi/component" "app/webapi/component/user" - "app/webapi/internal/testdb" + "app/webapi/internal/testutil" "app/webapi/pkg/router" "app/webapi/store" @@ -15,7 +15,7 @@ import ( ) func TestShowOne(t *testing.T) { - testdb.SetupTest(t) + testutil.LoadDatabase(t) core, _ := component.NewCoreMock() mux := router.New() @@ -34,7 +34,7 @@ func TestShowOne(t *testing.T) { } func TestShowNotFound(t *testing.T) { - testdb.SetupTest(t) + testutil.LoadDatabase(t) core, _ := component.NewCoreMock() mux := router.New() @@ -45,5 +45,5 @@ func TestShowNotFound(t *testing.T) { mux.ServeHTTP(w, r) assert.Equal(t, http.StatusBadRequest, w.Code) - assert.Contains(t, w.Body.String(), `{"status":"Bad Request","message":"item not found"}`) + assert.Contains(t, w.Body.String(), `item not found`) } diff --git a/src/app/webapi/component/user/update_test.go b/src/app/webapi/component/user/update_test.go index dc3b2a5..5295b95 100644 --- a/src/app/webapi/component/user/update_test.go +++ b/src/app/webapi/component/user/update_test.go @@ -9,7 +9,7 @@ import ( "app/webapi/component" "app/webapi/component/user" - "app/webapi/internal/testdb" + "app/webapi/internal/testutil" "app/webapi/pkg/router" "app/webapi/store" @@ -17,7 +17,7 @@ import ( ) func TestUpdateUserAllFields(t *testing.T) { - testdb.SetupTest(t) + testutil.LoadDatabase(t) core, _ := component.NewCoreMock() mux := router.New() @@ -51,7 +51,7 @@ func TestUpdateUserAllFields(t *testing.T) { } func TestUpdateMissingFields(t *testing.T) { - testdb.SetupTest(t) + testutil.LoadDatabase(t) core, _ := component.NewCoreMock() mux := router.New() diff --git a/src/app/webapi/internal/bind/bind_test.go b/src/app/webapi/internal/bind/bind_test.go index 452f1c2..bb864e8 100644 --- a/src/app/webapi/internal/bind/bind_test.go +++ b/src/app/webapi/internal/bind/bind_test.go @@ -18,8 +18,8 @@ func TestSuccess(t *testing.T) { mux := router.New() - mux.Post("/user/:user_id", http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { + mux.Post("/user/:user_id", router.Handler( + func(w http.ResponseWriter, r *http.Request) (status int, err error) { called = true // swagger:parameters UserCreate @@ -43,6 +43,7 @@ func TestSuccess(t *testing.T) { assert.Equal(t, "10", req.UserID) assert.Equal(t, "john", req.FirstName) assert.Equal(t, "smith", req.LastName) + return http.StatusOK, nil })) form := url.Values{} @@ -62,8 +63,8 @@ func TestMissingPointer(t *testing.T) { mux := router.New() - mux.Post("/user/:user_id", http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { + mux.Post("/user/:user_id", router.Handler( + func(w http.ResponseWriter, r *http.Request) (status int, err error) { called = true // swagger:parameters UserCreate @@ -86,6 +87,7 @@ func TestMissingPointer(t *testing.T) { assert.Equal(t, "", req.UserID) assert.Equal(t, "", req.FirstName) assert.Equal(t, "", req.LastName) + return http.StatusOK, nil })) form := url.Values{} diff --git a/src/app/webapi/internal/response/response.go b/src/app/webapi/internal/response/response.go index fe8b932..4279e00 100644 --- a/src/app/webapi/internal/response/response.go +++ b/src/app/webapi/internal/response/response.go @@ -100,9 +100,9 @@ func (o *Output) Results(w http.ResponseWriter, body interface{}, data interface return http.StatusOK, nil } -// OKResponse returns 200. -// swagger:response OKResponse -type OKResponse struct { +// GenericResponse returns any status code. +// swagger:response GenericResponse +type GenericResponse struct { // in: body Body struct { // Status contains the string of the HTTP status. @@ -128,26 +128,32 @@ type CreatedResponse struct { } } +// OKResponse returns 200. +// swagger:response OKResponse +type OKResponse struct { + GenericResponse +} + // BadRequestResponse returns 400. // swagger:response BadRequestResponse type BadRequestResponse struct { - OKResponse + GenericResponse } // UnauthorizedResponse returns 401. // swagger:response UnauthorizedResponse type UnauthorizedResponse struct { - OKResponse + GenericResponse } // NotFoundResponse returns 404. // swagger:response NotFoundResponse type NotFoundResponse struct { - OKResponse + GenericResponse } // InternalServerErrorResponse returns 500. // swagger:response InternalServerErrorResponse type InternalServerErrorResponse struct { - OKResponse + GenericResponse } diff --git a/src/app/webapi/internal/testdb/testdb.go b/src/app/webapi/internal/testdb/testdb.go deleted file mode 100644 index c9e2937..0000000 --- a/src/app/webapi/internal/testdb/testdb.go +++ /dev/null @@ -1,40 +0,0 @@ -package testdb - -import ( - "app/webapi/component" - "io/ioutil" - "log" - "os" - "strings" - "testing" -) - -// SetupTest will set up the DB for the testss. -func SetupTest(t *testing.T) { - db := component.TestDatabase(false) - db.Exec(`DROP DATABASE IF EXISTS webapitest`) - db.Exec(`CREATE DATABASE webapitest DEFAULT CHARSET = utf8 COLLATE = utf8_unicode_ci`) - - db = component.TestDatabase(true) - b, err := ioutil.ReadFile("../../../../../migration/tables-only.sql") - if err != nil { - log.Println(err) - os.Exit(1) - } - - // Split each statement. - stmts := strings.Split(string(b), ";") - for i, s := range stmts { - if i == len(stmts)-1 { - break - } - _, err = db.Exec(s) - if err != nil { - log.Println(err) - } - } - - //exit := m.Run() - //db.Exec(`DROP DATABASE IF EXISTS webapitest`) - //os.Exit(exit) -} diff --git a/src/app/webapi/internal/testutil/database.go b/src/app/webapi/internal/testutil/database.go new file mode 100644 index 0000000..00709fa --- /dev/null +++ b/src/app/webapi/internal/testutil/database.go @@ -0,0 +1,57 @@ +package testutil + +import ( + "io/ioutil" + "log" + "os" + "strings" + "testing" + + "app/webapi/pkg/database" +) + +// ConnectDatabase returns a test database connection. +func ConnectDatabase(dbSpecificDB bool) *database.DBW { + dbc := new(database.Connection) + dbc.Hostname = "127.0.0.1" + dbc.Port = 3306 + dbc.Username = "root" + dbc.Password = "" + dbc.Database = "webapitest" + dbc.Parameter = "parseTime=true&allowNativePasswords=true" + + connection, err := dbc.Connect(dbSpecificDB) + if err != nil { + log.Println("DB Error:", err) + } + + dbw := database.New(connection) + + return dbw +} + +// LoadDatabase will set up the DB for the tests. +func LoadDatabase(t *testing.T) { + db := ConnectDatabase(false) + db.Exec(`DROP DATABASE IF EXISTS webapitest`) + db.Exec(`CREATE DATABASE webapitest DEFAULT CHARSET = utf8 COLLATE = utf8_unicode_ci`) + + db = ConnectDatabase(true) + b, err := ioutil.ReadFile("../../../../../migration/tables-only.sql") + if err != nil { + log.Println(err) + os.Exit(1) + } + + // Split each statement. + stmts := strings.Split(string(b), ";") + for i, s := range stmts { + if i == len(stmts)-1 { + break + } + _, err = db.Exec(s) + if err != nil { + log.Println(err) + } + } +} diff --git a/src/app/webapi/internal/testutil/mock.go b/src/app/webapi/internal/testutil/mock.go deleted file mode 100644 index b3ce5fe..0000000 --- a/src/app/webapi/internal/testutil/mock.go +++ /dev/null @@ -1,289 +0,0 @@ -package testutil - -import ( - "database/sql" - "errors" - "net/http" - "reflect" - "time" -) - -// MockLogger . -type MockLogger struct{} - -// MockSQLResponse . -type MockSQLResponse struct{} - -// LastInsertId . -func (ms *MockSQLResponse) LastInsertId() (int64, error) { - return 0, nil -} - -// RowsAffected . -func (ms *MockSQLResponse) RowsAffected() (int64, error) { - return 0, nil -} - -// MockDatabase . -type MockDatabase struct{} - -// Select . -func (d *MockDatabase) Select(dest interface{}, query string, args ...interface{}) error { - return nil -} - -// Get . -func (d *MockDatabase) Get(dest interface{}, query string, args ...interface{}) error { - return nil -} - -// Exec . -func (d *MockDatabase) Exec(query string, args ...interface{}) (sql.Result, error) { - ms := new(MockSQLResponse) - return ms, nil -} - -// QueryRowScan . -//FIXME: This just returns nil. -func (d *MockDatabase) QueryRowScan(dest interface{}, query string, args ...interface{}) error { - return nil -} - -// ExistsString . -func (d *MockDatabase) ExistsString(err error, s string) (bool, string, error) { - return false, "", nil -} - -// Error . -func (d *MockDatabase) Error(err error) error { - return nil -} - -// AffectedRows . -func (d *MockDatabase) AffectedRows(result sql.Result) int { - return 0 -} - -// ***************************************************************************** - -type recordExistsIntFunc func() (exists bool, ID int64, err error) - -var ( - recordExistsInt recordExistsIntFunc - - // RecordExistsIntNot returns false, 0, nil. - RecordExistsIntNot = func() (exists bool, ID int64, err error) { - return false, 0, nil - } -) - -// SetRecordExistsInt will set the function. -func (d *MockDatabase) SetRecordExistsInt(fn recordExistsIntFunc) { - recordExistsInt = fn -} - -// RecordExistsInt returns the ID if a record exists. -func (d *MockDatabase) RecordExistsInt(fn func() (exists bool, ID int64, err error)) ( - exists bool, ID int64, err error) { - // Use the default. - fnInternal := recordExistsInt - if fnInternal == nil { - fnInternal = fn - } - - return fnInternal() -} - -// ***************************************************************************** - -type recordExistsStringFunc func() (exists bool, ID string, err error) - -var ( - recordExistsString recordExistsStringFunc - - // RecordExistsStringNot returns false, "", nil. - RecordExistsStringNot = func() (exists bool, ID string, err error) { - return false, "", nil - } -) - -// SetRecordExistsString will set the function. -func (d *MockDatabase) SetRecordExistsString(fn recordExistsStringFunc) { - recordExistsString = fn -} - -// RecordExistsString returns the ID if a record exists. -func (d *MockDatabase) RecordExistsString(fn func() (exists bool, ID string, err error)) ( - exists bool, ID string, err error) { - // Use the default. - fnInternal := recordExistsString - if fnInternal == nil { - fnInternal = fn - } - - return fnInternal() -} - -// ***************************************************************************** - -type addRecordIntFunc func() (ID int64, err error) - -var ( - addRecordInt addRecordIntFunc - - // AddRecordIntDefault returns 0, nil. - AddRecordIntDefault = func() (ID int64, err error) { - return 0, nil - } -) - -// SetAddRecordInt will set the function. -func (d *MockDatabase) SetAddRecordInt(fn addRecordIntFunc) { - addRecordInt = fn -} - -// AddRecordInt returns the ID if the record is created. -func (d *MockDatabase) AddRecordInt(fn func() (ID int64, err error)) (ID int64, err error) { - // Use the default. - fnInternal := addRecordInt - if fnInternal == nil { - fnInternal = fn - } - return fnInternal() -} - -// ***************************************************************************** - -type addRecordStringFunc func() (ID string, err error) - -var ( - addRecordString addRecordStringFunc - - // AddRecordStringDefault returns "", nil. - AddRecordStringDefault = func() (ID string, err error) { - return "", nil - } -) - -// SetAddRecordString will set the function. -func (d *MockDatabase) SetAddRecordString(fn addRecordStringFunc) { - addRecordString = fn -} - -// AddRecordString returns the ID if the record is created. -func (d *MockDatabase) AddRecordString(fn func() (ID string, err error)) (ID string, err error) { - // Use the default. - fnInternal := addRecordString - if fnInternal == nil { - fnInternal = fn - } - return fnInternal() -} - -// ***************************************************************************** - -// PaginatedResults returns the paginated results of a query. -func (d *MockDatabase) PaginatedResults(i interface{}, fn func() ( - interface{}, int, error)) (int, error) { - v := reflect.ValueOf(i) - if v.Kind() != reflect.Ptr { - return 0, errors.New("must pass a pointer, not a value") - } - - // Use the default. - fnInternal := paginatedResults - if fnInternal == nil { - fnInternal = fn - } - - results, d2, d3 := fnInternal() - if results == nil { - return d2, d3 - } - - arrPtr := reflect.ValueOf(i) - value := arrPtr.Elem() - itemPtr := reflect.ValueOf(results) - value.Set(itemPtr) - - return d2, d3 -} - -type paginatedResultsFunc func() (interface{}, int, error) - -var paginatedResults paginatedResultsFunc - -// PaginatedResultsEmpty returns nil, 0, nil. -var PaginatedResultsEmpty = func() (interface{}, int, error) { - return nil, 0, nil -} - -// SetPaginatedResults will set the paginated results function. -func (d *MockDatabase) SetPaginatedResults(fn paginatedResultsFunc) { - paginatedResults = fn -} - -// ***************************************************************************** - -/* -// MockBind . -type MockBind struct{} - -// FormUnmarshal . -func (mb *MockBind) FormUnmarshal(i interface{}, r *http.Request) (err error) { - return nil -} - -// Validate . -func (mb *MockBind) Validate(s interface{}) error { - return nil -}*/ - -// MockResponse . -type MockResponse struct{} - -// Created . -func (mr *MockResponse) Created(w http.ResponseWriter, recordID string) (int, error) { - return 0, nil -} - -// Results . -func (mr *MockResponse) Results(w http.ResponseWriter, body interface{}, data interface{}) (int, error) { - return 0, nil -} - -// OK . -func (mr *MockResponse) OK(w http.ResponseWriter, message string) (int, error) { - return 0, nil -} - -// MockToken . -type MockToken struct{} - -// ***************************************************************************** - -type generateFunc func(userID string, duration time.Duration) (string, error) - -var ( - generate generateFunc - - // GenerateDefault returns "", nil. - GenerateDefault = func(userID string, duration time.Duration) (string, error) { - return "", nil - } -) - -// SetGenerate will set the function. -func (mt *MockToken) SetGenerate(fn generateFunc) { - generate = fn -} - -// Generate . -func (mt *MockToken) Generate(userID string, duration time.Duration) (string, error) { - // Use the default. - fnInternal := generate - if fnInternal == nil { - fnInternal = GenerateDefault - } - return fnInternal(userID, duration) -} diff --git a/src/app/webapi/internal/testutil/mock_logger.go b/src/app/webapi/internal/testutil/mock_logger.go new file mode 100644 index 0000000..05e7c11 --- /dev/null +++ b/src/app/webapi/internal/testutil/mock_logger.go @@ -0,0 +1,4 @@ +package testutil + +// MockLogger is a mocked logger. +type MockLogger struct{} diff --git a/src/app/webapi/internal/testutil/mock_token.go b/src/app/webapi/internal/testutil/mock_token.go new file mode 100644 index 0000000..9cdf3ab --- /dev/null +++ b/src/app/webapi/internal/testutil/mock_token.go @@ -0,0 +1,32 @@ +package testutil + +import "time" + +// MockToken is a mocked webtoken. +type MockToken struct{} + +type generateFunc func(userID string, duration time.Duration) (string, error) + +var ( + generate generateFunc + + // GenerateDefault returns "", nil. + GenerateDefault = func(userID string, duration time.Duration) (string, error) { + return "", nil + } +) + +// SetGenerate will set the function. +func (mt *MockToken) SetGenerate(fn generateFunc) { + generate = fn +} + +// Generate . +func (mt *MockToken) Generate(userID string, duration time.Duration) (string, error) { + // Use the default. + fnInternal := generate + if fnInternal == nil { + fnInternal = GenerateDefault + } + return fnInternal(userID, duration) +} diff --git a/src/app/webapi/middleware/jwt/jwt.go b/src/app/webapi/middleware/jwt/jwt.go index e89f3c2..b02e058 100644 --- a/src/app/webapi/middleware/jwt/jwt.go +++ b/src/app/webapi/middleware/jwt/jwt.go @@ -7,7 +7,7 @@ import ( "strings" "app/webapi/internal/response" - "app/webapi/internal/webtoken" + "app/webapi/pkg/webtoken" ) // Config contains the dependencies for the handler. diff --git a/src/app/webapi/pkg/database/database.go b/src/app/webapi/pkg/database/database.go index 3289739..f353c87 100644 --- a/src/app/webapi/pkg/database/database.go +++ b/src/app/webapi/pkg/database/database.go @@ -2,9 +2,6 @@ package database import ( "database/sql" - "errors" - "reflect" - "time" "github.com/jmoiron/sqlx" ) @@ -27,11 +24,6 @@ func (d *DBW) Select(dest interface{}, query string, args ...interface{}) error return d.db.Select(dest, query, args...) } -// QueryRowScan returns a single result. -func (d *DBW) QueryRowScan(dest interface{}, query string, args ...interface{}) error { - return d.db.QueryRow(query, args...).Scan(dest) -} - // Get using this DB. // Any placeholder parameters are replaced with supplied args. // An error is returned if the result set is empty. @@ -45,68 +37,12 @@ func (d *DBW) Exec(query string, args ...interface{}) (sql.Result, error) { return d.db.Exec(query, args...) } -// LastInsertID returns the last inserted ID. -func (d *DBW) LastInsertID(r sql.Result, err error) (int64, error) { - if err != nil { - return 0, err - } - - return r.LastInsertId() -} - -// MySQLTimestamp returns a MySQL timestamp. -func (d *DBW) MySQLTimestamp(t time.Time) string { - return t.Format("2006-01-02 15:04:05") -} - -// GoTimestamp returns a Go timestamp. -func (d *DBW) GoTimestamp(s string) (time.Time, error) { - return time.Parse("2006-01-02 15:04:05", s) -} - -// ExistsID returns the proper ID and other values based on the query results. -func (d *DBW) ExistsID(err error, ID int64) (bool, int64, error) { - if err == nil { - return true, ID, nil - } else if err == sql.ErrNoRows { - return false, 0, nil - } - return false, 0, err -} - -// ExistsString returns the proper string and other values based on the query results. -func (d *DBW) ExistsString(err error, s string) (bool, string, error) { - if err == nil { - return true, s, nil - } else if err == sql.ErrNoRows { - return false, "", nil - } - return false, "", err -} - -// Error will return nil if the error is sql.ErrNoRows. -func (d *DBW) Error(err error) error { - if err == sql.ErrNoRows { - return nil - } - return err -} - -// AffectedRows returns the number of rows affected by the query. -func (d *DBW) AffectedRows(result sql.Result) int { - if result == nil { - return 0 - } - - // If successful, get the number of affected rows. - count, err := result.RowsAffected() - if err != nil { - return 0 - } - - return int(count) +// QueryRowScan returns a single result. +func (d *DBW) QueryRowScan(dest interface{}, query string, args ...interface{}) error { + return d.db.QueryRow(query, args...).Scan(dest) } +/* // PaginatedResults returns the paginated results of a query. func (d *DBW) PaginatedResults(i interface{}, fn func() (interface{}, int, error)) (int, error) { @@ -126,31 +62,4 @@ func (d *DBW) PaginatedResults(i interface{}, fn func() (interface{}, int, value.Set(itemPtr) return d2, d3 -} - -// RecordExistsInt returns the ID if a record exists. -func (d *DBW) RecordExistsInt(fn func() (exists bool, ID int64, err error)) ( - exists bool, ID int64, err error) { - return fn() -} - -// RecordExistsString returns the ID if a record exists. -func (d *DBW) RecordExistsString(fn func() (exists bool, ID string, err error)) ( - exists bool, ID string, err error) { - return fn() -} - -// AddRecordInt returns the ID if the record is created. -func (d *DBW) AddRecordInt(fn func() (ID int64, err error)) (ID int64, err error) { - return fn() -} - -// AddRecordString returns the ID if the record is created. -func (d *DBW) AddRecordString(fn func() (ID string, err error)) (ID string, err error) { - return fn() -} - -// ExecQuery returns an error if the query failed. -func (d *DBW) ExecQuery(fn func() (err error)) (err error) { - return fn() -} +}*/ diff --git a/src/app/webapi/pkg/query/helper.go b/src/app/webapi/pkg/query/helper.go new file mode 100644 index 0000000..85b0b4c --- /dev/null +++ b/src/app/webapi/pkg/query/helper.go @@ -0,0 +1,46 @@ +package query + +import "database/sql" + +// recordExists returns if the record exists or not. +func recordExists(err error) (bool, error) { + if err == nil { + return true, nil + } else if err == sql.ErrNoRows { + return false, nil + } + return false, err +} + +// recordExistsString returns the proper string is the record exists. +func recordExistsString(err error, s string) (bool, string, error) { + if err == nil { + return true, s, nil + } else if err == sql.ErrNoRows { + return false, "", nil + } + return false, "", err +} + +// suppressNoRowsError will return nil if the error is sql.ErrNoRows. +func suppressNoRowsError(err error) error { + if err == sql.ErrNoRows { + return nil + } + return err +} + +// affectedRows returns the number of rows affected by the query. +func affectedRows(result sql.Result) int { + if result == nil { + return 0 + } + + // If successful, get the number of affected rows. + count, err := result.RowsAffected() + if err != nil { + return 0 + } + + return int(count) +} diff --git a/src/app/webapi/pkg/query/interface.go b/src/app/webapi/pkg/query/interface.go new file mode 100644 index 0000000..fd928d5 --- /dev/null +++ b/src/app/webapi/pkg/query/interface.go @@ -0,0 +1,17 @@ +package query + +import "database/sql" + +// IDatabase provides data query capabilities. +type IDatabase interface { + Get(dest interface{}, query string, args ...interface{}) error + Exec(query string, args ...interface{}) (sql.Result, error) + Select(dest interface{}, query string, args ...interface{}) error + QueryRowScan(dest interface{}, query string, args ...interface{}) error +} + +// IRecord provides table information. +type IRecord interface { + Table() string + PrimaryKey() string +} diff --git a/src/app/webapi/pkg/query/query.go b/src/app/webapi/pkg/query/query.go index 1636f38..3d992a9 100644 --- a/src/app/webapi/pkg/query/query.go +++ b/src/app/webapi/pkg/query/query.go @@ -1,7 +1,6 @@ package query import ( - "database/sql" "fmt" ) @@ -17,24 +16,9 @@ type Q struct { db IDatabase } -// IDatabase provides data query capabilities. -type IDatabase interface { - Get(dest interface{}, query string, args ...interface{}) error - Exec(query string, args ...interface{}) (sql.Result, error) - Select(dest interface{}, query string, args ...interface{}) error - QueryRowScan(dest interface{}, query string, args ...interface{}) error - - Error(err error) error - AffectedRows(result sql.Result) int - - ExistsString(err error, s string) (bool, string, error) -} - -// IRecord provides table information. -type IRecord interface { - Table() string - PrimaryKey() string -} +// ***************************************************************************** +// Find +// ***************************************************************************** // FindOneByID will find a record by string ID. func (q *Q) FindOneByID(dest IRecord, ID string) (exists bool, err error) { @@ -43,7 +27,7 @@ func (q *Q) FindOneByID(dest IRecord, ID string) (exists bool, err error) { WHERE %s = ? LIMIT 1`, dest.Table(), dest.PrimaryKey()), ID) - return (err != sql.ErrNoRows), q.db.Error(err) + return recordExists(err) } // FindAll returns all users. @@ -57,7 +41,7 @@ func (q *Q) FindAll(dest IRecord) (total int, err error) { `, dest.PrimaryKey(), dest.Table())) if err != nil { - return total, q.db.Error(err) + return total, suppressNoRowsError(err) } err = q.db.Select(dest, fmt.Sprintf(`SELECT * FROM %s`, dest.Table())) @@ -76,7 +60,7 @@ func (q *Q) DeleteOneByID(dest IRecord, ID string) (affected int, err error) { return 0, err } - return q.db.AffectedRows(result), err + return affectedRows(result), err } // DeleteAll removes all records. @@ -86,7 +70,7 @@ func (q *Q) DeleteAll(dest IRecord) (affected int, err error) { return 0, err } - return q.db.AffectedRows(result), err + return affectedRows(result), err } // ***************************************************************************** @@ -112,28 +96,5 @@ func (q *Q) ExistsByField(db IRecord, field string, value string) (found bool, I LIMIT 1`, db.PrimaryKey(), db.Table(), field), value) - //TODO: Add this to a function so it can be reused. - if err == nil { - return true, ID, nil - } else if err == sql.ErrNoRows { - return false, "", nil - } - return false, "", err - - /*err = q.db.Get(db, fmt.Sprintf(` - SELECT %s FROM %s - WHERE %s = ? - LIMIT 1`, db.PrimaryKey(), db.Table(), field), - value) - return recordExists(err)*/ -} - -// recordExists returns if the record exists or not. -func recordExists(err error) (bool, error) { - if err == nil { - return true, nil - } else if err == sql.ErrNoRows { - return false, nil - } - return false, err + return recordExistsString(err, ID) } diff --git a/src/app/webapi/pkg/router/handler.go b/src/app/webapi/pkg/router/handler.go new file mode 100644 index 0000000..5df8443 --- /dev/null +++ b/src/app/webapi/pkg/router/handler.go @@ -0,0 +1,28 @@ +package router + +import ( + "net/http" +) + +// Handler is used to wrapper all endpoint functions so they work with generic +// routers. +type Handler func(http.ResponseWriter, *http.Request) (int, error) + +// ServeHTTP is a settable function that receives the status and error from +// the function call. +var ServeHTTP = func(w http.ResponseWriter, r *http.Request, status int, + err error) { + if status >= 400 { + if err != nil { + http.Error(w, err.Error(), status) + } else { + http.Error(w, "", status) + } + } +} + +// ServeHTTP handles all the errors from the HTTP handlers. +func (fn Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + status, err := fn(w, r) + ServeHTTP(w, r, status, err) +} diff --git a/src/app/webapi/pkg/router/helper.go b/src/app/webapi/pkg/router/helper.go index 7eb7f3f..904f5ae 100644 --- a/src/app/webapi/pkg/router/helper.go +++ b/src/app/webapi/pkg/router/helper.go @@ -1,40 +1,36 @@ package router -import ( - "net/http" -) - // Delete is a shortcut for router.Handle("DELETE", path, handle) -func (m *Mux) Delete(path string, fn http.Handler) { +func (m *Mux) Delete(path string, fn Handler) { m.router.Handle("DELETE", path, fn) } // Get is a shortcut for router.Handle("GET", path, handle) -func (m *Mux) Get(path string, fn http.Handler) { +func (m *Mux) Get(path string, fn Handler) { m.router.Handle("GET", path, fn) } // Head is a shortcut for router.Handle("HEAD", path, handle) -func (m *Mux) Head(path string, fn http.Handler) { +func (m *Mux) Head(path string, fn Handler) { m.router.Handle("HEAD", path, fn) } // Options is a shortcut for router.Handle("OPTIONS", path, handle) -func (m *Mux) Options(path string, fn http.Handler) { +func (m *Mux) Options(path string, fn Handler) { m.router.Handle("OPTIONS", path, fn) } // Patch is a shortcut for router.Handle("PATCH", path, handle) -func (m *Mux) Patch(path string, fn http.Handler) { +func (m *Mux) Patch(path string, fn Handler) { m.router.Handle("PATCH", path, fn) } // Post is a shortcut for router.Handle("POST", path, handle) -func (m *Mux) Post(path string, fn http.Handler) { +func (m *Mux) Post(path string, fn Handler) { m.router.Handle("POST", path, fn) } // Put is a shortcut for router.Handle("PUT", path, handle) -func (m *Mux) Put(path string, fn http.Handler) { +func (m *Mux) Put(path string, fn Handler) { m.router.Handle("PUT", path, fn) } diff --git a/src/app/webapi/pkg/router/router_test.go b/src/app/webapi/pkg/router/router_test.go index f38dc94..c806a2d 100644 --- a/src/app/webapi/pkg/router/router_test.go +++ b/src/app/webapi/pkg/router/router_test.go @@ -17,9 +17,10 @@ import ( func TestParams(t *testing.T) { mux := router.New() - mux.Get("/user/:name", http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { + mux.Get("/user/:name", router.Handler( + func(w http.ResponseWriter, r *http.Request) (status int, err error) { assert.Equal(t, "john", router.Params(r, "name")) + return http.StatusOK, nil })) r := httptest.NewRequest("GET", "/user/john", nil) @@ -30,9 +31,10 @@ func TestParams(t *testing.T) { func TestInstance(t *testing.T) { mux := router.New() - mux.Get("/user/:name", http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { + mux.Get("/user/:name", router.Handler( + func(w http.ResponseWriter, r *http.Request) (status int, err error) { assert.Equal(t, "john", router.Params(r, "name")) + return http.StatusOK, nil })) r := httptest.NewRequest("GET", "/user/john", nil) @@ -47,10 +49,11 @@ func TestPostForm(t *testing.T) { form := url.Values{} form.Add("username", "jsmith") - mux.Post("/user", http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { + mux.Post("/user", router.Handler( + func(w http.ResponseWriter, r *http.Request) (status int, err error) { r.ParseForm() assert.Equal(t, "jsmith", r.FormValue("username")) + return http.StatusOK, nil })) r := httptest.NewRequest("POST", "/user", strings.NewReader(form.Encode())) @@ -67,13 +70,13 @@ func TestPostJSON(t *testing.T) { }) assert.Nil(t, err) - mux.Post("/user", http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { + mux.Post("/user", router.Handler( + func(w http.ResponseWriter, r *http.Request) (status int, err error) { b, err := ioutil.ReadAll(r.Body) assert.Nil(t, err) r.Body.Close() assert.Equal(t, `{"username":"jsmith"}`, string(b)) - + return http.StatusOK, nil })) r := httptest.NewRequest("POST", "/user", bytes.NewBuffer(j)) @@ -87,9 +90,10 @@ func TestGet(t *testing.T) { called := false - mux.Get("/user", http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { + mux.Get("/user", router.Handler( + func(w http.ResponseWriter, r *http.Request) (status int, err error) { called = true + return http.StatusOK, nil })) r := httptest.NewRequest("GET", "/user", nil) @@ -104,9 +108,10 @@ func TestDelete(t *testing.T) { called := false - mux.Delete("/user", http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { + mux.Delete("/user", router.Handler( + func(w http.ResponseWriter, r *http.Request) (status int, err error) { called = true + return http.StatusOK, nil })) r := httptest.NewRequest("DELETE", "/user", nil) @@ -121,9 +126,10 @@ func TestHead(t *testing.T) { called := false - mux.Head("/user", http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { + mux.Head("/user", router.Handler( + func(w http.ResponseWriter, r *http.Request) (status int, err error) { called = true + return http.StatusOK, nil })) r := httptest.NewRequest("HEAD", "/user", nil) @@ -138,9 +144,10 @@ func TestOptions(t *testing.T) { called := false - mux.Options("/user", http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { + mux.Options("/user", router.Handler( + func(w http.ResponseWriter, r *http.Request) (status int, err error) { called = true + return http.StatusOK, nil })) r := httptest.NewRequest("OPTIONS", "/user", nil) @@ -155,9 +162,10 @@ func TestPatch(t *testing.T) { called := false - mux.Patch("/user", http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { + mux.Patch("/user", router.Handler( + func(w http.ResponseWriter, r *http.Request) (status int, err error) { called = true + return http.StatusOK, nil })) r := httptest.NewRequest("PATCH", "/user", nil) @@ -172,9 +180,10 @@ func TestPut(t *testing.T) { called := false - mux.Put("/user", http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { + mux.Put("/user", router.Handler( + func(w http.ResponseWriter, r *http.Request) (status int, err error) { called = true + return http.StatusOK, nil })) r := httptest.NewRequest("PUT", "/user", nil) diff --git a/src/app/webapi/internal/webtoken/clock.go b/src/app/webapi/pkg/webtoken/clock.go similarity index 100% rename from src/app/webapi/internal/webtoken/clock.go rename to src/app/webapi/pkg/webtoken/clock.go diff --git a/src/app/webapi/internal/webtoken/testdata/config.json b/src/app/webapi/pkg/webtoken/testdata/config.json similarity index 100% rename from src/app/webapi/internal/webtoken/testdata/config.json rename to src/app/webapi/pkg/webtoken/testdata/config.json diff --git a/src/app/webapi/internal/webtoken/webtoken.go b/src/app/webapi/pkg/webtoken/webtoken.go similarity index 92% rename from src/app/webapi/internal/webtoken/webtoken.go rename to src/app/webapi/pkg/webtoken/webtoken.go index dc70a41..a71ac4b 100644 --- a/src/app/webapi/internal/webtoken/webtoken.go +++ b/src/app/webapi/pkg/webtoken/webtoken.go @@ -1,13 +1,13 @@ package webtoken import ( + "crypto/rand" "encoding/base64" "errors" + "fmt" "strconv" "time" - "app/webapi/pkg/securegen" - jwt "github.com/dgrijalva/jwt-go" ) @@ -69,6 +69,17 @@ func (c *Configuration) SetClock(clock IClock) { c.clock = clock } +// randomID generates a UUID for use as an ID. +func randomID() (string, error) { + b := make([]byte, 16) + _, err := rand.Read(b) + if err != nil { + return "", err + } + + return fmt.Sprintf("%x-%x-%x-%x-%x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:]), nil +} + // Generate will generate a JWT. func (c *Configuration) Generate(userID string, duration time.Duration) (string, error) { // Ensure a secret is present. @@ -80,7 +91,7 @@ func (c *Configuration) Generate(userID string, duration time.Duration) (string, now := c.clock.Now() // Generate a unique ID. - unique, err := securegen.UUID() + unique, err := randomID() if err != nil { return "", err } diff --git a/src/app/webapi/internal/webtoken/webtoken_test.go b/src/app/webapi/pkg/webtoken/webtoken_test.go similarity index 99% rename from src/app/webapi/internal/webtoken/webtoken_test.go rename to src/app/webapi/pkg/webtoken/webtoken_test.go index 52cdecc..33dbdca 100644 --- a/src/app/webapi/internal/webtoken/webtoken_test.go +++ b/src/app/webapi/pkg/webtoken/webtoken_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "app/webapi/internal/webtoken" + "app/webapi/pkg/webtoken" "github.com/stretchr/testify/assert" ) diff --git a/src/app/webapi/webapi.go b/src/app/webapi/webapi.go index 2ec589d..6af3e59 100644 --- a/src/app/webapi/webapi.go +++ b/src/app/webapi/webapi.go @@ -37,12 +37,12 @@ import ( "app/webapi/component/user" "app/webapi/internal/bind" "app/webapi/internal/response" - "app/webapi/internal/webtoken" "app/webapi/pkg/database" "app/webapi/pkg/logger" "app/webapi/pkg/query" "app/webapi/pkg/router" "app/webapi/pkg/server" + "app/webapi/pkg/webtoken" ) // ***************************************************************************** @@ -85,11 +85,39 @@ func Routes(config *AppConfig, appLogger logger.ILog) *router.Mux { user.New(core).Routes(r) // Set up the 404 page. - r.Instance().NotFound = component.H( + r.Instance().NotFound = router.Handler( func(w http.ResponseWriter, r *http.Request) (int, error) { return http.StatusNotFound, nil }) + // Set the handling of all responses. + router.ServeHTTP = func(w http.ResponseWriter, r *http.Request, status int, err error) { + // Handle only errors. + if status >= 400 { + resp := new(response.GenericResponse) + resp.Body.Status = http.StatusText(status) + if err != nil { + resp.Body.Message = err.Error() + } + + // Write the content. + w.WriteHeader(status) + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(resp.Body) + if err != nil { + w.Write([]byte(`{"status":"Internal Server Error","message":"problem encoding JSON"}`)) + return + } + } + + // Display server errors. + if status >= 500 { + if err != nil { + l.Printf("%v", err) + } + } + } + return r }