diff --git a/internal/config/config.go b/internal/config/config.go index aff86240..464f87ca 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -19,6 +19,7 @@ type Config struct { Service ServiceType Port int Hostname string + APIKey string Redis *redis.RedisConfig OpenTelemetry *otel.OpenTelemetryConfig @@ -80,7 +81,8 @@ func Parse(flags Flags) (*Config, error) { config := &Config{ Hostname: hostname, Service: service, - Port: mustInt(viper, "PORT"), + Port: getPort(viper), + APIKey: viper.GetString("API_KEY"), Redis: &redis.RedisConfig{ Host: viper.GetString("REDIS_HOST"), Port: mustInt(viper, "REDIS_PORT"), @@ -100,3 +102,14 @@ func mustInt(viper *v.Viper, configName string) int { } return i } + +func getPort(viper *v.Viper) int { + port := mustInt(viper, "PORT") + if viper.GetString("API_PORT") != "" { + apiPort, err := strconv.Atoi(viper.GetString("API_PORT")) + if err == nil { + port = apiPort + } + } + return port +} diff --git a/internal/destination/destination_test/handlers_test.go b/internal/destination/destination_test/handlers_test.go new file mode 100644 index 00000000..1e7b81f1 --- /dev/null +++ b/internal/destination/destination_test/handlers_test.go @@ -0,0 +1,214 @@ +package destination_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "github.com/hookdeck/EventKit/internal/destination" + "github.com/hookdeck/EventKit/internal/util/testutil" + "github.com/stretchr/testify/assert" +) + +func setupRouter(destinationHandlers *destination.DestinationHandlers) *gin.Engine { + gin.SetMode(gin.TestMode) + r := gin.Default() + r.GET("/destinations", destinationHandlers.List) + r.POST("/destinations", destinationHandlers.Create) + r.GET("/destinations/:destinationID", destinationHandlers.Retrieve) + r.PATCH("/destinations/:destinationID", destinationHandlers.Update) + r.DELETE("/destinations/:destinationID", destinationHandlers.Delete) + return r +} + +func TestDestinationListHandler(t *testing.T) { + t.Parallel() + + redisClient := testutil.CreateTestRedisClient(t) + handlers := destination.NewHandlers(redisClient) + router := setupRouter(handlers) + + t.Run("should return 501", func(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/destinations", nil) + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusNotImplemented, w.Code) + }) +} + +func TestDestinationCreateHandler(t *testing.T) { + t.Parallel() + + redisClient := testutil.CreateTestRedisClient(t) + handlers := destination.NewHandlers(redisClient) + router := setupRouter(handlers) + + t.Run("should create", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + + exampleDestination := destination.CreateDestinationRequest{ + Name: "Test Destination", + } + destinationJSON, _ := json.Marshal(exampleDestination) + req, _ := http.NewRequest("POST", "/destinations", strings.NewReader(string(destinationJSON))) + router.ServeHTTP(w, req) + + var destinationResponse map[string]any + json.Unmarshal(w.Body.Bytes(), &destinationResponse) + + assert.Equal(t, http.StatusCreated, w.Code) + assert.Equal(t, exampleDestination.Name, destinationResponse["name"]) + assert.NotEqual(t, "", destinationResponse["id"]) + }) +} + +func TestDestinationRetrieveHandler(t *testing.T) { + t.Parallel() + + redisClient := testutil.CreateTestRedisClient(t) + handlers := destination.NewHandlers(redisClient) + router := setupRouter(handlers) + + t.Run("should return 404 when there's no destination", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/destinations/invalid_id", nil) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusNotFound, w.Code) + }) + + t.Run("should retrieve when there's a destination", func(t *testing.T) { + t.Parallel() + + // Setup test destination + exampleDestination := destination.Destination{ + ID: uuid.New().String(), + Name: "Test Destination", + } + redisClient.Set(context.Background(), "destination:"+exampleDestination.ID, exampleDestination.Name, 0) + + // Test HTTP request + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/destinations/"+exampleDestination.ID, nil) + router.ServeHTTP(w, req) + + var destinationResponse map[string]any + json.Unmarshal(w.Body.Bytes(), &destinationResponse) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, exampleDestination.ID, destinationResponse["id"]) + assert.Equal(t, exampleDestination.Name, destinationResponse["name"]) + + // Clean up + redisClient.Del(context.Background(), "destination:"+exampleDestination.ID) + }) +} + +func TestDestinationUpdateHandler(t *testing.T) { + t.Parallel() + + redisClient := testutil.CreateTestRedisClient(t) + handlers := destination.NewHandlers(redisClient) + router := setupRouter(handlers) + + initialDestination := destination.Destination{ + Name: "Test Destination", + } + + updateDestinationRequest := destination.UpdateDestinationRequest{ + Name: "Updated Destination", + } + updateDestinationJSON, _ := json.Marshal(updateDestinationRequest) + + t.Run("should validate", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("PATCH", "/destinations/invalid_id", nil) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) + }) + + t.Run("should return 404 when there's no destination", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("PATCH", "/destinations/invalid_id", strings.NewReader(string(updateDestinationJSON))) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusNotFound, w.Code) + }) + + t.Run("should update destination", func(t *testing.T) { + t.Parallel() + + // Setup initial destination + newDestination := initialDestination + newDestination.ID = uuid.New().String() + redisClient.Set(context.Background(), "destination:"+newDestination.ID, newDestination.Name, 0) + + // Test HTTP request + w := httptest.NewRecorder() + req, _ := http.NewRequest("PATCH", "/destinations/"+newDestination.ID, strings.NewReader(string(updateDestinationJSON))) + router.ServeHTTP(w, req) + + var destinationResponse map[string]any + json.Unmarshal(w.Body.Bytes(), &destinationResponse) + + assert.Equal(t, http.StatusAccepted, w.Code) + assert.Equal(t, newDestination.ID, destinationResponse["id"]) + assert.Equal(t, updateDestinationRequest.Name, destinationResponse["name"]) + + // Clean up + redisClient.Del(context.Background(), "destination:"+newDestination.ID) + }) +} + +func TestDestinationDeleteHandler(t *testing.T) { + redisClient := testutil.CreateTestRedisClient(t) + handlers := destination.NewHandlers(redisClient) + router := setupRouter(handlers) + + t.Run("should return 404 when there's no destination", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("DELETE", "/destinations/invalid_id", nil) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusNotFound, w.Code) + }) + + t.Run("should delete destination", func(t *testing.T) { + t.Parallel() + + // Setup initial destination + newDestination := destination.Destination{ + ID: uuid.New().String(), + Name: "Test Destination", + } + redisClient.Set(context.Background(), "destination:"+newDestination.ID, newDestination.Name, 0) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("DELETE", "/destinations/"+newDestination.ID, nil) + router.ServeHTTP(w, req) + + var destinationResponse map[string]any + json.Unmarshal(w.Body.Bytes(), &destinationResponse) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, newDestination.ID, destinationResponse["id"]) + assert.Equal(t, newDestination.Name, destinationResponse["name"]) + }) +} diff --git a/internal/destination/destination_test/model_test.go b/internal/destination/destination_test/model_test.go new file mode 100644 index 00000000..929755ab --- /dev/null +++ b/internal/destination/destination_test/model_test.go @@ -0,0 +1,67 @@ +package destination_test + +import ( + "context" + "testing" + + "github.com/google/uuid" + "github.com/hookdeck/EventKit/internal/destination" + "github.com/hookdeck/EventKit/internal/util/testutil" + "github.com/stretchr/testify/assert" +) + +func TestDestinationModel(t *testing.T) { + t.Parallel() + + redisClient := testutil.CreateTestRedisClient(t) + model := destination.NewDestinationModel(redisClient) + + input := destination.Destination{ + ID: uuid.New().String(), + Name: "Test Destination", + } + + t.Run("gets empty", func(t *testing.T) { + actual, err := model.Get(context.Background(), input.ID) + assert.Nil(t, actual, "model.Get() should return nil when there's no value") + assert.Nil(t, err, "model.Get() should not return an error when there's no value") + }) + + t.Run("sets", func(t *testing.T) { + err := model.Set(context.Background(), input) + assert.Nil(t, err, "model.Set() should not return an error") + + value, err := redisClient.Get(context.Background(), "destination:"+input.ID).Result() + if err != nil { + t.Fatal(err) + } + assert.Equal(t, input.Name, value, "model.Set() should set destination name %s", input.Name) + }) + + t.Run("gets", func(t *testing.T) { + actual, err := model.Get(context.Background(), input.ID) + assert.Nil(t, err, "model.Get() should not return an error") + assert.Equal(t, input, *actual, "model.Get() should return %s", input) + }) + + t.Run("overrides", func(t *testing.T) { + input.Name = "Test Destination 2" + + err := model.Set(context.Background(), input) + assert.Nil(t, err, "model.Set() should not return an error", input) + + actual, err := model.Get(context.Background(), input.ID) + assert.Nil(t, err, "model.Get() should not return an error") + assert.Equal(t, input, *actual, "model.Get() should return %s", input) + }) + + t.Run("clears", func(t *testing.T) { + deleted, err := model.Clear(context.Background(), input.ID) + assert.Nil(t, err, "model.Clear() should not return an error") + assert.Equal(t, *deleted, input, "model.Clear() should return deleted value", input) + + actual, err := model.Get(context.Background(), input.ID) + assert.Nil(t, actual, "model.Clear() should properly remove value") + assert.Nil(t, err, "model.Clear() should properly remove value") + }) +} diff --git a/internal/services/api/auth_middleware.go b/internal/services/api/auth_middleware.go new file mode 100644 index 00000000..c0af656c --- /dev/null +++ b/internal/services/api/auth_middleware.go @@ -0,0 +1,41 @@ +package api + +import ( + "errors" + "net/http" + "strings" + + "github.com/gin-gonic/gin" +) + +func apiKeyAuthMiddleware(apiKey string) gin.HandlerFunc { + if apiKey == "" { + return func(c *gin.Context) { + c.Next() + } + } + + return func(c *gin.Context) { + authorizationToken, err := extractBearerToken(c.GetHeader("Authorization")) + if err != nil { + // TODO: Consider sending a more detailed error message. + // Currently we don't have clear specs on how to send back error message. + c.AbortWithStatus(http.StatusUnauthorized) + return + } + if authorizationToken != apiKey { + // TODO: Consider sending a more detailed error message. + // Currently we don't have clear specs on how to send back error message. + c.AbortWithStatus(http.StatusUnauthorized) + return + } + c.Next() + } +} + +func extractBearerToken(header string) (string, error) { + if !strings.HasPrefix(header, "Bearer ") { + return "", errors.New("invalid bearer token") + } + return strings.TrimPrefix(header, "Bearer "), nil +} diff --git a/internal/services/api/auth_middleware_test.go b/internal/services/api/auth_middleware_test.go new file mode 100644 index 00000000..165bf68d --- /dev/null +++ b/internal/services/api/auth_middleware_test.go @@ -0,0 +1,86 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +func setupRouter(apiKey string) *gin.Engine { + gin.SetMode(gin.TestMode) + r := gin.Default() + r.Use(apiKeyAuthMiddleware(apiKey)) + r.GET("/healthz", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + return r +} + +func TestPublicRouter(t *testing.T) { + t.Parallel() + + router := setupRouter("") + + t.Run("should accept requests without a token", func(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/healthz", nil) + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("should accept requests with an invalid authorization token", func(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/healthz", nil) + req.Header.Set("Authorization", "invalid key") + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("should accept requests with a valid authorization token", func(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/healthz", nil) + req.Header.Set("Authorization", "Bearer key") + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + }) +} + +func TestPrivateAPIKeyRouter(t *testing.T) { + t.Parallel() + + const apiKey = "key" + + router := setupRouter(apiKey) + + t.Run("should reject requests without a token", func(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/healthz", nil) + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusUnauthorized, w.Code) + }) + + t.Run("should reject requests with an invalid authorization token", func(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/healthz", nil) + req.Header.Set("Authorization", "invalid key") + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusUnauthorized, w.Code) + }) + + t.Run("should accept requests with a valid authorization token", func(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/healthz", nil) + req.Header.Set("Authorization", "Bearer "+apiKey) + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + }) +} diff --git a/internal/services/api/router.go b/internal/services/api/router.go index d9d79c45..c4b0d4d0 100644 --- a/internal/services/api/router.go +++ b/internal/services/api/router.go @@ -14,6 +14,7 @@ import ( func NewRouter(cfg *config.Config, logger *otelzap.Logger, redisClient *redis.Client) http.Handler { r := gin.Default() r.Use(otelgin.Middleware(cfg.Hostname)) + r.Use(apiKeyAuthMiddleware(cfg.APIKey)) r.GET("/healthz", func(c *gin.Context) { logger.Ctx(c.Request.Context()).Info("health check")