From dbb1fa7a4694c20c6ee1a0c48146da6eb05a0f23 Mon Sep 17 00:00:00 2001 From: Alex Luong Date: Tue, 27 Aug 2024 10:43:36 +0700 Subject: [PATCH 01/21] refactor: make config & redis dependency more explicit --- cmd/app/main.go | 54 +++++++++++++------------ internal/config/config.go | 50 +++++++++++++---------- internal/destination/handlers.go | 21 ++++++---- internal/destination/model.go | 24 +++++++---- internal/otel/exporter.go | 31 +++++++-------- internal/otel/otel.go | 11 +++--- internal/redis/redis.go | 55 ++++++++++++-------------- internal/services/api/api.go | 20 +++++++--- internal/services/api/router.go | 10 ++--- internal/services/delivery/delivery.go | 8 +++- internal/services/log/log.go | 20 +++++++--- 11 files changed, 174 insertions(+), 130 deletions(-) diff --git a/cmd/app/main.go b/cmd/app/main.go index 3790fc6d..fe9765f4 100644 --- a/cmd/app/main.go +++ b/cmd/app/main.go @@ -11,7 +11,6 @@ import ( "github.com/hookdeck/EventKit/internal/config" "github.com/hookdeck/EventKit/internal/otel" - "github.com/hookdeck/EventKit/internal/redis" "github.com/hookdeck/EventKit/internal/services/api" "github.com/hookdeck/EventKit/internal/services/delivery" "github.com/hookdeck/EventKit/internal/services/log" @@ -34,7 +33,8 @@ func main() { func run(mainContext context.Context) error { flags := config.ParseFlags() - if err := config.Parse(flags); err != nil { + c, err := config.Parse(flags) + if err != nil { return err } @@ -51,8 +51,8 @@ func run(mainContext context.Context) error { ctx, cancel := context.WithCancel(mainContext) // Set up OpenTelemetry. - if config.OpenTelemetry != nil { - otelShutdown, err := otel.SetupOTelSDK(ctx) + if c.OpenTelemetry != nil { + otelShutdown, err := otel.SetupOTelSDK(ctx, c) if err != nil { cancel() return err @@ -61,12 +61,6 @@ func run(mainContext context.Context) error { defer func() { err = errors.Join(err, otelShutdown(context.Background())) }() - - // QUESTION: what if a service doesn't need Redis? Is it unnecessary to initalize the client here? - if err := redis.InstrumentOpenTelemetry(); err != nil { - cancel() - return err - } } // Initialize waitgroup @@ -76,22 +70,30 @@ func run(mainContext context.Context) error { // Construct services based on config services := []Service{} - switch config.Service { - case config.ServiceTypeAPI: - services = append(services, api.NewService(ctx, wg, logger)) - case config.ServiceTypeLog: - services = append(services, log.NewService(ctx, wg, logger)) - case config.ServiceTypeDelivery: - services = append(services, delivery.NewService(ctx, wg, logger)) - case config.ServiceTypeSingular: - services = append(services, - api.NewService(ctx, wg, logger), - log.NewService(ctx, wg, logger), - delivery.NewService(ctx, wg, logger), - ) - default: - cancel() - return fmt.Errorf("unknown service: %s", flags.Service) + + if c.Service == config.ServiceTypeAPI || c.Service == config.ServiceTypeSingular { + service, err := api.NewService(ctx, wg, c, logger) + if err != nil { + cancel() + return err + } + services = append(services, service) + } + if c.Service == config.ServiceTypeDelivery || c.Service == config.ServiceTypeSingular { + service, err := delivery.NewService(ctx, wg, c, logger) + if err != nil { + cancel() + return err + } + services = append(services, service) + } + if c.Service == config.ServiceTypeLog || c.Service == config.ServiceTypeSingular { + service, err := log.NewService(ctx, wg, c, logger) + if err != nil { + cancel() + return err + } + services = append(services, service) } // Start services diff --git a/internal/config/config.go b/internal/config/config.go index 91e78db8..cd23571d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -6,14 +6,14 @@ import ( "strconv" "github.com/joho/godotenv" - "github.com/spf13/viper" + v "github.com/spf13/viper" ) const ( Namespace = "EventKit" ) -var ( +type Config struct { Service ServiceType Port int Hostname string @@ -24,7 +24,7 @@ var ( RedisDatabase int OpenTelemetry *OpenTelemetryConfig -) +} var defaultConfig = map[string]any{ "PORT": 3333, @@ -34,12 +34,15 @@ var defaultConfig = map[string]any{ "REDIS_DATABASE": 0, } -func Parse(flags Flags) error { +func Parse(flags Flags) (*Config, error) { var err error - Hostname, err = os.Hostname() + // Use a local Viper instance to avoid leaking configuration to global scope + viper := v.New() + + hostname, err := os.Hostname() if err != nil { - return err + return nil, err } // Load .env file to environment variables @@ -49,9 +52,9 @@ func Parse(flags Flags) error { } // Parse service type from flag - Service, err = ServiceTypeFromString(flags.Service) + service, err := ServiceTypeFromString(flags.Service) if err != nil { - return err + return nil, err } // Set default config values @@ -63,29 +66,36 @@ func Parse(flags Flags) error { if flags.Config != "" { viper.SetConfigFile(flags.Config) if err := viper.ReadInConfig(); err != nil { - return err + return nil, err } } // Bind environemnt variable to viper viper.AutomaticEnv() - // Initialize config values - Port = mustInt("PORT") - RedisHost = viper.GetString("REDIS_HOST") - RedisPort = mustInt("REDIS_PORT") - RedisPassword = viper.GetString("REDIS_PASSWORD") - RedisDatabase = mustInt("REDIS_DATABASE") - - OpenTelemetry, err = parseOpenTelemetryConfig() + openTelemetry, err := parseOpenTelemetryConfig() if err != nil { - return err + return nil, err + } + + viper.Get("PORT") + + // Initialize config values + config := &Config{ + Hostname: hostname, + Service: service, + Port: mustInt(viper, "PORT"), + RedisHost: viper.GetString("REDIS_HOST"), + RedisPort: mustInt(viper, "REDIS_PORT"), + RedisPassword: viper.GetString("REDIS_PASSWORD"), + RedisDatabase: mustInt(viper, "REDIS_DATABASE"), + OpenTelemetry: openTelemetry, } - return nil + return config, nil } -func mustInt(configName string) int { +func mustInt(viper *v.Viper, configName string) int { i, err := strconv.Atoi(viper.GetString(configName)) if err != nil { log.Fatalf("%s = '%s' is not int", configName, viper.GetString(configName)) diff --git a/internal/destination/handlers.go b/internal/destination/handlers.go index 21f711e0..58c8cdee 100644 --- a/internal/destination/handlers.go +++ b/internal/destination/handlers.go @@ -6,12 +6,17 @@ import ( "github.com/gin-gonic/gin" "github.com/google/uuid" + "github.com/hookdeck/EventKit/internal/redis" ) -type DestinationHandlers struct{} +type DestinationHandlers struct { + model *DestinationModel +} -func NewHandlers() *DestinationHandlers { - return &DestinationHandlers{} +func NewHandlers(redisClient *redis.Client) *DestinationHandlers { + return &DestinationHandlers{ + model: NewDestinationModel(redisClient), + } } func (h *DestinationHandlers) List(c *gin.Context) { @@ -29,7 +34,7 @@ func (h *DestinationHandlers) Create(c *gin.Context) { ID: id, Name: json.Name, } - if err := SetDestination(c.Request.Context(), destination); err != nil { + if err := h.model.Set(c.Request.Context(), destination); err != nil { log.Println(err) c.Status(http.StatusInternalServerError) return @@ -42,7 +47,7 @@ func (h *DestinationHandlers) Create(c *gin.Context) { func (h *DestinationHandlers) Retrieve(c *gin.Context) { destinationID := c.Param("destinationID") - destination, err := GetDestination(c.Request.Context(), destinationID) + destination, err := h.model.Get(c.Request.Context(), destinationID) if err != nil { log.Println(err) c.Status(http.StatusInternalServerError) @@ -68,7 +73,7 @@ func (h *DestinationHandlers) Update(c *gin.Context) { // Get destination. destinationID := c.Param("destinationID") - destination, err := GetDestination(c.Request.Context(), destinationID) + destination, err := h.model.Get(c.Request.Context(), destinationID) if err != nil { log.Println(err) c.Status(http.StatusInternalServerError) @@ -81,7 +86,7 @@ func (h *DestinationHandlers) Update(c *gin.Context) { // Update destination destination.Name = json.Name - if err := SetDestination(c.Request.Context(), *destination); err != nil { + if err := h.model.Set(c.Request.Context(), *destination); err != nil { log.Println(err) c.Status(http.StatusInternalServerError) return @@ -94,7 +99,7 @@ func (h *DestinationHandlers) Update(c *gin.Context) { func (h *DestinationHandlers) Delete(c *gin.Context) { destinationID := c.Param("destinationID") - destination, err := ClearDestination(c.Request.Context(), destinationID) + destination, err := h.model.Clear(c.Request.Context(), destinationID) if err != nil { log.Println(err) c.Status(http.StatusInternalServerError) diff --git a/internal/destination/model.go b/internal/destination/model.go index 2283e55d..14411cff 100644 --- a/internal/destination/model.go +++ b/internal/destination/model.go @@ -20,8 +20,18 @@ type UpdateDestinationRequest struct { Name string `json:"name" binding:"required"` } -func GetDestination(c context.Context, id string) (*Destination, error) { - destination, err := redis.Client().Get(c, redisDestinationID(id)).Result() +type DestinationModel struct { + redisClient *redis.Client +} + +func NewDestinationModel(redisClient *redis.Client) *DestinationModel { + return &DestinationModel{ + redisClient: redisClient, + } +} + +func (m *DestinationModel) Get(c context.Context, id string) (*Destination, error) { + destination, err := m.redisClient.Get(c, redisDestinationID(id)).Result() if err == redis.Nil { return nil, nil } else if err != nil { @@ -33,22 +43,22 @@ func GetDestination(c context.Context, id string) (*Destination, error) { }, nil } -func SetDestination(c context.Context, destination Destination) error { - if err := redis.Client().Set(c, redisDestinationID(destination.ID), destination.Name, 0).Err(); err != nil { +func (m *DestinationModel) Set(c context.Context, destination Destination) error { + if err := m.redisClient.Set(c, redisDestinationID(destination.ID), destination.Name, 0).Err(); err != nil { return err } return nil } -func ClearDestination(c context.Context, id string) (*Destination, error) { - destination, err := GetDestination(c, id) +func (m *DestinationModel) Clear(c context.Context, id string) (*Destination, error) { + destination, err := m.Get(c, id) if err != nil { return nil, err } if destination == nil { return nil, nil } - if err := redis.Client().Del(c, redisDestinationID(id)).Err(); err != nil { + if err := m.redisClient.Del(c, redisDestinationID(id)).Err(); err != nil { return nil, err } return destination, nil diff --git a/internal/otel/exporter.go b/internal/otel/exporter.go index ac1848e3..5e61b01a 100644 --- a/internal/otel/exporter.go +++ b/internal/otel/exporter.go @@ -11,28 +11,27 @@ import ( "go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp" "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp" - "go.opentelemetry.io/otel/sdk/log" "go.opentelemetry.io/otel/sdk/metric" "go.opentelemetry.io/otel/sdk/trace" ) -func newTraceProvider(ctx context.Context) (*trace.TracerProvider, error) { - if config.OpenTelemetry.Traces == nil { +func newTraceProvider(ctx context.Context, c *config.Config) (*trace.TracerProvider, error) { + if c.OpenTelemetry.Traces == nil { return nil, nil } var err error var traceExporter trace.SpanExporter - if config.OpenTelemetry.Traces.Protocol == config.OpenTelemetryProtocolGRPC { + if c.OpenTelemetry.Traces.Protocol == config.OpenTelemetryProtocolGRPC { traceExporter, err = otlptracegrpc.New(ctx, otlptracegrpc.WithInsecure(), // TODO: support TLS - otlptracegrpc.WithEndpoint(config.OpenTelemetry.Traces.Endpoint), + otlptracegrpc.WithEndpoint(c.OpenTelemetry.Traces.Endpoint), ) } else { traceExporter, err = otlptracehttp.New(ctx, otlptracehttp.WithInsecure(), // TODO: support TLS - otlptracehttp.WithEndpointURL(ensureHTTPEndpoint("traces", config.OpenTelemetry.Traces.Endpoint)), + otlptracehttp.WithEndpointURL(ensureHTTPEndpoint("traces", c.OpenTelemetry.Traces.Endpoint)), ) } // traceExporter, err = stdouttrace.New() @@ -51,22 +50,22 @@ func newTraceProvider(ctx context.Context) (*trace.TracerProvider, error) { return traceProvider, nil } -func newMeterProvider(ctx context.Context) (*metric.MeterProvider, error) { - if config.OpenTelemetry.Metrics == nil { +func newMeterProvider(ctx context.Context, c *config.Config) (*metric.MeterProvider, error) { + if c.OpenTelemetry.Metrics == nil { return nil, nil } var err error var metricExporter metric.Exporter - if config.OpenTelemetry.Metrics.Protocol == config.OpenTelemetryProtocolGRPC { + if c.OpenTelemetry.Metrics.Protocol == config.OpenTelemetryProtocolGRPC { metricExporter, err = otlpmetricgrpc.New(ctx, otlpmetricgrpc.WithInsecure(), // TODO: support TLS - otlpmetricgrpc.WithEndpoint(config.OpenTelemetry.Metrics.Endpoint), + otlpmetricgrpc.WithEndpoint(c.OpenTelemetry.Metrics.Endpoint), ) } else { metricExporter, err = otlpmetrichttp.New(ctx, otlpmetrichttp.WithInsecure(), // TODO: support TLS - otlpmetrichttp.WithEndpointURL(ensureHTTPEndpoint("metrics", config.OpenTelemetry.Metrics.Endpoint)), + otlpmetrichttp.WithEndpointURL(ensureHTTPEndpoint("metrics", c.OpenTelemetry.Metrics.Endpoint)), ) } @@ -83,22 +82,22 @@ func newMeterProvider(ctx context.Context) (*metric.MeterProvider, error) { return meterProvider, nil } -func newLoggerProvider(ctx context.Context) (*log.LoggerProvider, error) { - if config.OpenTelemetry.Logs == nil { +func newLoggerProvider(ctx context.Context, c *config.Config) (*log.LoggerProvider, error) { + if c.OpenTelemetry.Logs == nil { return nil, nil } var err error var logExporter log.Exporter - if config.OpenTelemetry.Logs.Protocol == config.OpenTelemetryProtocolGRPC { + if c.OpenTelemetry.Logs.Protocol == config.OpenTelemetryProtocolGRPC { logExporter, err = otlploggrpc.New(ctx, otlploggrpc.WithInsecure(), // TODO: support TLS - otlploggrpc.WithEndpoint(config.OpenTelemetry.Logs.Endpoint), + otlploggrpc.WithEndpoint(c.OpenTelemetry.Logs.Endpoint), ) } else { logExporter, err = otlploghttp.New(ctx, otlploghttp.WithInsecure(), // TODO: support TLS - otlploghttp.WithEndpointURL(ensureHTTPEndpoint("logs", config.OpenTelemetry.Logs.Endpoint)), + otlploghttp.WithEndpointURL(ensureHTTPEndpoint("logs", c.OpenTelemetry.Logs.Endpoint)), ) } diff --git a/internal/otel/otel.go b/internal/otel/otel.go index deea74e9..2d4aa805 100644 --- a/internal/otel/otel.go +++ b/internal/otel/otel.go @@ -5,6 +5,7 @@ import ( "errors" "time" + "github.com/hookdeck/EventKit/internal/config" "go.opentelemetry.io/contrib/instrumentation/host" "go.opentelemetry.io/contrib/instrumentation/runtime" "go.opentelemetry.io/otel" @@ -15,9 +16,9 @@ import ( // TODO: Consider supporting the official OTEL configuration format. // https://opentelemetry.io/docs/collector/configuration/ -// setupOTelSDK bootstraps the OpenTelemetry pipeline. +// SetupOTelSDK bootstraps the OpenTelemetry pipeline. // If it does not return an error, make sure to call shutdown for proper cleanup. -func SetupOTelSDK(ctx context.Context) (shutdown func(context.Context) error, err error) { +func SetupOTelSDK(ctx context.Context, c *config.Config) (shutdown func(context.Context) error, err error) { var shutdownFuncs []func(context.Context) error // shutdown calls cleanup functions registered via shutdownFuncs. @@ -42,7 +43,7 @@ func SetupOTelSDK(ctx context.Context) (shutdown func(context.Context) error, er otel.SetTextMapPropagator(prop) // Set up trace provider. - tracerProvider, err := newTraceProvider(ctx) + tracerProvider, err := newTraceProvider(ctx, c) if err != nil { handleErr(err) return @@ -53,7 +54,7 @@ func SetupOTelSDK(ctx context.Context) (shutdown func(context.Context) error, er } // Set up meter provider. - meterProvider, err := newMeterProvider(ctx) + meterProvider, err := newMeterProvider(ctx, c) if err != nil { handleErr(err) return @@ -75,7 +76,7 @@ func SetupOTelSDK(ctx context.Context) (shutdown func(context.Context) error, er } // Set up logger provider. - loggerProvider, err := newLoggerProvider(ctx) + loggerProvider, err := newLoggerProvider(ctx, c) if err != nil { handleErr(err) return diff --git a/internal/redis/redis.go b/internal/redis/redis.go index 72b71deb..45245ef5 100644 --- a/internal/redis/redis.go +++ b/internal/redis/redis.go @@ -1,7 +1,7 @@ package redis import ( - "errors" + "context" "fmt" "sync" @@ -10,44 +10,41 @@ import ( r "github.com/redis/go-redis/v9" ) -var ( - client *r.Client - once sync.Once -) - +// Reexport go-redis's Nil constant for DX purposes. const ( Nil = r.Nil ) -func InstrumentOpenTelemetry() error { - once.Do(initializeClient) +type Client = r.Client - if config.OpenTelemetry == nil { - return errors.New("OpenTelemetry config is nil") - } - if config.OpenTelemetry.Traces != nil { - if err := redisotel.InstrumentTracing(client); err != nil { - return err - } - } - if config.OpenTelemetry.Metrics != nil { - if err := redisotel.InstrumentMetrics(client); err != nil { - return err - } - } +var ( + once sync.Once + client *r.Client + initializationError error +) - return nil +func New(ctx context.Context, c *config.Config) (*r.Client, error) { + once.Do(func() { + initializeClient(ctx, c) + initializationError = instrumentOpenTelemetry() + }) + return client, initializationError } -func Client() *r.Client { - once.Do(initializeClient) - return client +func instrumentOpenTelemetry() error { + if err := redisotel.InstrumentTracing(client); err != nil { + return err + } + if err := redisotel.InstrumentMetrics(client); err != nil { + return err + } + return nil } -func initializeClient() { +func initializeClient(_ context.Context, c *config.Config) { client = r.NewClient(&r.Options{ - Addr: fmt.Sprintf("%s:%d", config.RedisHost, config.RedisPort), - Password: config.RedisPassword, - DB: config.RedisDatabase, + Addr: fmt.Sprintf("%s:%d", c.RedisHost, c.RedisPort), + Password: c.RedisPassword, + DB: c.RedisDatabase, }) } diff --git a/internal/services/api/api.go b/internal/services/api/api.go index c63f0f11..3c9f55b8 100644 --- a/internal/services/api/api.go +++ b/internal/services/api/api.go @@ -9,6 +9,7 @@ import ( "time" "github.com/hookdeck/EventKit/internal/config" + "github.com/hookdeck/EventKit/internal/redis" "github.com/uptrace/opentelemetry-go-extra/otelzap" "go.uber.org/zap" ) @@ -18,16 +19,23 @@ type APIService struct { logger *otelzap.Logger } -func NewService(ctx context.Context, wg *sync.WaitGroup, logger *otelzap.Logger) *APIService { +func NewService(ctx context.Context, wg *sync.WaitGroup, c *config.Config, logger *otelzap.Logger) (*APIService, error) { + wg.Add(1) + + redisClient, err := redis.New(ctx, c) + if err != nil { + return nil, err + } + + router := NewRouter(c, logger, redisClient) + service := &APIService{} service.logger = logger service.server = &http.Server{ - Addr: fmt.Sprintf(":%d", config.Port), - Handler: NewRouter(logger), + Addr: fmt.Sprintf(":%d", c.Port), + Handler: router, } - wg.Add(1) - go func() { defer wg.Done() <-ctx.Done() @@ -41,7 +49,7 @@ func NewService(ctx context.Context, wg *sync.WaitGroup, logger *otelzap.Logger) logger.Ctx(ctx).Info("http server shutted down") }() - return service + return service, nil } func (s *APIService) Run(ctx context.Context) error { diff --git a/internal/services/api/router.go b/internal/services/api/router.go index 8954c14c..c152be60 100644 --- a/internal/services/api/router.go +++ b/internal/services/api/router.go @@ -6,23 +6,21 @@ import ( "github.com/gin-gonic/gin" "github.com/hookdeck/EventKit/internal/config" "github.com/hookdeck/EventKit/internal/destination" + "github.com/hookdeck/EventKit/internal/redis" "github.com/uptrace/opentelemetry-go-extra/otelzap" "go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin" ) -func NewRouter(logger *otelzap.Logger) http.Handler { +func NewRouter(c *config.Config, logger *otelzap.Logger, redisClient *redis.Client) http.Handler { r := gin.Default() - - if config.OpenTelemetry != nil { - r.Use(otelgin.Middleware(config.Hostname)) - } + r.Use(otelgin.Middleware(c.Hostname)) r.GET("/healthz", func(c *gin.Context) { logger.Ctx(c.Request.Context()).Info("health check") c.Status(http.StatusOK) }) - destinationHandlers := destination.NewHandlers() + destinationHandlers := destination.NewHandlers(redisClient) r.GET("/destinations", destinationHandlers.List) r.POST("/destinations", destinationHandlers.Create) diff --git a/internal/services/delivery/delivery.go b/internal/services/delivery/delivery.go index 99f29c6e..6acea462 100644 --- a/internal/services/delivery/delivery.go +++ b/internal/services/delivery/delivery.go @@ -4,6 +4,7 @@ import ( "context" "sync" + "github.com/hookdeck/EventKit/internal/config" "github.com/uptrace/opentelemetry-go-extra/otelzap" "go.uber.org/zap" ) @@ -12,16 +13,19 @@ type DeliveryService struct { logger *otelzap.Logger } -func NewService(ctx context.Context, wg *sync.WaitGroup, logger *otelzap.Logger) *DeliveryService { +func NewService(ctx context.Context, wg *sync.WaitGroup, c *config.Config, logger *otelzap.Logger) (*DeliveryService, error) { wg.Add(1) go func() { defer wg.Done() <-ctx.Done() logger.Ctx(ctx).Info("service shutdown", zap.String("service", "delivery")) }() - return &DeliveryService{ + + service := &DeliveryService{ logger: logger, } + + return service, nil } func (s *DeliveryService) Run(ctx context.Context) error { diff --git a/internal/services/log/log.go b/internal/services/log/log.go index dcb964a7..66bed998 100644 --- a/internal/services/log/log.go +++ b/internal/services/log/log.go @@ -6,25 +6,35 @@ import ( "sync" "time" + "github.com/hookdeck/EventKit/internal/config" "github.com/hookdeck/EventKit/internal/redis" "github.com/uptrace/opentelemetry-go-extra/otelzap" "go.uber.org/zap" ) type LogService struct { - logger *otelzap.Logger + logger *otelzap.Logger + redisClient *redis.Client } -func NewService(ctx context.Context, wg *sync.WaitGroup, logger *otelzap.Logger) *LogService { +func NewService(ctx context.Context, wg *sync.WaitGroup, c *config.Config, logger *otelzap.Logger) (*LogService, error) { wg.Add(1) go func() { defer wg.Done() <-ctx.Done() logger.Ctx(ctx).Info("service shutdown", zap.String("service", "log")) }() - return &LogService{ - logger: logger, + + redisClient, err := redis.New(ctx, c) + if err != nil { + return nil, err } + + service := &LogService{} + service.logger = logger + service.redisClient = redisClient + + return service, nil } func (s *LogService) Run(ctx context.Context) error { @@ -36,7 +46,7 @@ func (s *LogService) Run(ctx context.Context) error { } for range time.Tick(time.Second * 1) { - keys, err := redis.Client().Keys(ctx, "destination:*").Result() + keys, err := s.redisClient.Keys(ctx, "destination:*").Result() if err != nil { s.logger.Ctx(ctx).Error("error", zap.Error(err), From b22a54d46a86f9e0622a1ea27b4b1fb577145ddf Mon Sep 17 00:00:00 2001 From: Alex Luong Date: Tue, 27 Aug 2024 12:03:33 +0700 Subject: [PATCH 02/21] refactor: move config type declaration to individual packages --- cmd/app/main.go | 19 +++---- internal/config/config.go | 31 +++++------ internal/config/otel.go | 72 +++++++------------------- internal/otel/config.go | 35 +++++++++++++ internal/otel/exporter.go | 32 ++++++------ internal/otel/otel.go | 12 +++-- internal/redis/config.go | 8 +++ internal/redis/redis.go | 13 +++-- internal/services/api/api.go | 8 +-- internal/services/api/router.go | 4 +- internal/services/delivery/delivery.go | 2 +- internal/services/log/log.go | 4 +- 12 files changed, 126 insertions(+), 114 deletions(-) create mode 100644 internal/otel/config.go create mode 100644 internal/redis/config.go diff --git a/cmd/app/main.go b/cmd/app/main.go index fe9765f4..cdca038e 100644 --- a/cmd/app/main.go +++ b/cmd/app/main.go @@ -33,7 +33,7 @@ func main() { func run(mainContext context.Context) error { flags := config.ParseFlags() - c, err := config.Parse(flags) + cfg, err := config.Parse(flags) if err != nil { return err } @@ -51,8 +51,9 @@ func run(mainContext context.Context) error { ctx, cancel := context.WithCancel(mainContext) // Set up OpenTelemetry. - if c.OpenTelemetry != nil { - otelShutdown, err := otel.SetupOTelSDK(ctx, c) + fmt.Println("cfg.OpenTelemetry", cfg.OpenTelemetry) + if cfg.OpenTelemetry != nil { + otelShutdown, err := otel.SetupOTelSDK(ctx, cfg.OpenTelemetry) if err != nil { cancel() return err @@ -71,24 +72,24 @@ func run(mainContext context.Context) error { // Construct services based on config services := []Service{} - if c.Service == config.ServiceTypeAPI || c.Service == config.ServiceTypeSingular { - service, err := api.NewService(ctx, wg, c, logger) + if cfg.Service == config.ServiceTypeAPI || cfg.Service == config.ServiceTypeSingular { + service, err := api.NewService(ctx, wg, cfg, logger) if err != nil { cancel() return err } services = append(services, service) } - if c.Service == config.ServiceTypeDelivery || c.Service == config.ServiceTypeSingular { - service, err := delivery.NewService(ctx, wg, c, logger) + if cfg.Service == config.ServiceTypeDelivery || cfg.Service == config.ServiceTypeSingular { + service, err := delivery.NewService(ctx, wg, cfg, logger) if err != nil { cancel() return err } services = append(services, service) } - if c.Service == config.ServiceTypeLog || c.Service == config.ServiceTypeSingular { - service, err := log.NewService(ctx, wg, c, logger) + if cfg.Service == config.ServiceTypeLog || cfg.Service == config.ServiceTypeSingular { + service, err := log.NewService(ctx, wg, cfg, logger) if err != nil { cancel() return err diff --git a/internal/config/config.go b/internal/config/config.go index cd23571d..e8e60a33 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,10 +1,13 @@ package config import ( + "fmt" "log" "os" "strconv" + "github.com/hookdeck/EventKit/internal/otel" + "github.com/hookdeck/EventKit/internal/redis" "github.com/joho/godotenv" v "github.com/spf13/viper" ) @@ -18,12 +21,8 @@ type Config struct { Port int Hostname string - RedisHost string - RedisPort int - RedisPassword string - RedisDatabase int - - OpenTelemetry *OpenTelemetryConfig + Redis *redis.RedisConfig + OpenTelemetry *otel.OpenTelemetryConfig } var defaultConfig = map[string]any{ @@ -73,22 +72,24 @@ func Parse(flags Flags) (*Config, error) { // Bind environemnt variable to viper viper.AutomaticEnv() - openTelemetry, err := parseOpenTelemetryConfig() + openTelemetry, err := parseOpenTelemetryConfig(viper) if err != nil { return nil, err } - viper.Get("PORT") + fmt.Println(openTelemetry) // Initialize config values config := &Config{ - Hostname: hostname, - Service: service, - Port: mustInt(viper, "PORT"), - RedisHost: viper.GetString("REDIS_HOST"), - RedisPort: mustInt(viper, "REDIS_PORT"), - RedisPassword: viper.GetString("REDIS_PASSWORD"), - RedisDatabase: mustInt(viper, "REDIS_DATABASE"), + Hostname: hostname, + Service: service, + Port: mustInt(viper, "PORT"), + Redis: &redis.RedisConfig{ + Host: viper.GetString("REDIS_HOST"), + Port: mustInt(viper, "REDIS_PORT"), + Password: viper.GetString("REDIS_PASSWORD"), + Database: mustInt(viper, "REDIS_DATABASE"), + }, OpenTelemetry: openTelemetry, } diff --git a/internal/config/otel.go b/internal/config/otel.go index daa9117c..26e8c87a 100644 --- a/internal/config/otel.go +++ b/internal/config/otel.go @@ -1,35 +1,15 @@ package config import ( - "fmt" - - "github.com/spf13/viper" -) - -type OpenTelemetryProtocol string - -const ( - OpenTelemetryProtocolGRPC OpenTelemetryProtocol = "grpc" - OpenTelemetryProtocolHTTPProtobuf OpenTelemetryProtocol = "http/protobuf" - OpenTelemetryProtocolHTTPJSON OpenTelemetryProtocol = "http/json" + "github.com/hookdeck/EventKit/internal/otel" + v "github.com/spf13/viper" ) -type OpenTelemetryTypeConfig struct { - Endpoint string - Protocol OpenTelemetryProtocol -} - -type OpenTelemetryConfig struct { - Traces *OpenTelemetryTypeConfig - Metrics *OpenTelemetryTypeConfig - Logs *OpenTelemetryTypeConfig -} - // If the user has set OTEL_SERVICE_NAME, we assume they are managing their own OpenTelemetry configuration. // When parsing config, we assume if the user has set OTEL_EXPORTER_OTLP_ENDPOINT, they will use all 3 // Traces, Metrics, and Logs. // If the user doesn't want to use all 3, they will have to specify each one individually. -func parseOpenTelemetryConfig() (*OpenTelemetryConfig, error) { +func parseOpenTelemetryConfig(viper *v.Viper) (*otel.OpenTelemetryConfig, error) { if viper.GetString("OTEL_SERVICE_NAME") == "" { return nil, nil } @@ -40,93 +20,81 @@ func parseOpenTelemetryConfig() (*OpenTelemetryConfig, error) { defaultProtocol = "grpc" } - tracesConfig, err := parseTracesConfig(defaultEndpoint, defaultProtocol) + tracesConfig, err := parseTracesConfig(viper, defaultEndpoint, defaultProtocol) if err != nil { return nil, err } - metricsConfig, err := parseMetricsConfig(defaultEndpoint, defaultProtocol) + metricsConfig, err := parseMetricsConfig(viper, defaultEndpoint, defaultProtocol) if err != nil { return nil, err } - logsConfig, err := parseLogsConfig(defaultEndpoint, defaultProtocol) + logsConfig, err := parseLogsConfig(viper, defaultEndpoint, defaultProtocol) if err != nil { return nil, err } - return &OpenTelemetryConfig{ + return &otel.OpenTelemetryConfig{ Traces: tracesConfig, Metrics: metricsConfig, Logs: logsConfig, }, nil } -func parseTracesConfig(defaultEndpoint, defaultProtocol string) (*OpenTelemetryTypeConfig, error) { - endpoint := getTypeSpecificWithDefault("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", defaultEndpoint) +func parseTracesConfig(viper *v.Viper, defaultEndpoint, defaultProtocol string) (*otel.OpenTelemetryTypeConfig, error) { + endpoint := getTypeSpecificWithDefault(viper, "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", defaultEndpoint) if endpoint == "" { return nil, nil } - protocol, err := OpenTelemetryProtocolFromString(getTypeSpecificWithDefault("OTEL_EXPORTER_OTLP_TRACES_PROTOCOL", defaultProtocol)) + protocol, err := otel.OpenTelemetryProtocolFromString(getTypeSpecificWithDefault(viper, "OTEL_EXPORTER_OTLP_TRACES_PROTOCOL", defaultProtocol)) if err != nil { return nil, err } - return &OpenTelemetryTypeConfig{ + return &otel.OpenTelemetryTypeConfig{ Endpoint: endpoint, Protocol: protocol, }, nil } -func parseMetricsConfig(defaultEndpoint, defaultProtocol string) (*OpenTelemetryTypeConfig, error) { - endpoint := getTypeSpecificWithDefault("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", defaultEndpoint) +func parseMetricsConfig(viper *v.Viper, defaultEndpoint, defaultProtocol string) (*otel.OpenTelemetryTypeConfig, error) { + endpoint := getTypeSpecificWithDefault(viper, "OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", defaultEndpoint) if endpoint == "" { return nil, nil } - protocol, err := OpenTelemetryProtocolFromString(getTypeSpecificWithDefault("OTEL_EXPORTER_OTLP_METRICS_PROTOCOL", defaultProtocol)) + protocol, err := otel.OpenTelemetryProtocolFromString(getTypeSpecificWithDefault(viper, "OTEL_EXPORTER_OTLP_METRICS_PROTOCOL", defaultProtocol)) if err != nil { return nil, err } - return &OpenTelemetryTypeConfig{ + return &otel.OpenTelemetryTypeConfig{ Endpoint: endpoint, Protocol: protocol, }, nil } -func parseLogsConfig(defaultEndpoint, defaultProtocol string) (*OpenTelemetryTypeConfig, error) { - endpoint := getTypeSpecificWithDefault("OTEL_EXPORTER_OTLP_LOGS_ENDPOINT", defaultEndpoint) +func parseLogsConfig(viper *v.Viper, defaultEndpoint, defaultProtocol string) (*otel.OpenTelemetryTypeConfig, error) { + endpoint := getTypeSpecificWithDefault(viper, "OTEL_EXPORTER_OTLP_LOGS_ENDPOINT", defaultEndpoint) if endpoint == "" { return nil, nil } - protocol, err := OpenTelemetryProtocolFromString(getTypeSpecificWithDefault("OTEL_EXPORTER_OTLP_LOGS_PROTOCOL", defaultProtocol)) + protocol, err := otel.OpenTelemetryProtocolFromString(getTypeSpecificWithDefault(viper, "OTEL_EXPORTER_OTLP_LOGS_PROTOCOL", defaultProtocol)) if err != nil { return nil, err } - return &OpenTelemetryTypeConfig{ + return &otel.OpenTelemetryTypeConfig{ Endpoint: endpoint, Protocol: protocol, }, nil } -func getTypeSpecificWithDefault(otelTypeKey string, defaultValue string) string { +func getTypeSpecificWithDefault(viper *v.Viper, otelTypeKey string, defaultValue string) string { value := viper.GetString(otelTypeKey) if value == "" { return defaultValue } return value } - -func OpenTelemetryProtocolFromString(s string) (OpenTelemetryProtocol, error) { - switch s { - case "grpc": - return OpenTelemetryProtocolGRPC, nil - case "http/protobuf": - return OpenTelemetryProtocolHTTPProtobuf, nil - case "http/json": - return OpenTelemetryProtocolHTTPJSON, nil - } - return OpenTelemetryProtocol(""), fmt.Errorf("unknown OpenTelemetry protocol: %s", s) -} diff --git a/internal/otel/config.go b/internal/otel/config.go new file mode 100644 index 00000000..046090aa --- /dev/null +++ b/internal/otel/config.go @@ -0,0 +1,35 @@ +package otel + +import "fmt" + +type OpenTelemetryProtocol string + +type OpenTelemetryTypeConfig struct { + Endpoint string + Protocol OpenTelemetryProtocol +} + +type OpenTelemetryConfig struct { + Traces *OpenTelemetryTypeConfig + Metrics *OpenTelemetryTypeConfig + Logs *OpenTelemetryTypeConfig +} + +const ( + OpenTelemetryProtocolGRPC OpenTelemetryProtocol = "grpc" + OpenTelemetryProtocolHTTPProtobuf OpenTelemetryProtocol = "http/protobuf" + OpenTelemetryProtocolHTTPJSON OpenTelemetryProtocol = "http/json" +) + +func OpenTelemetryProtocolFromString(s string) (OpenTelemetryProtocol, error) { + switch s { + case "grpc": + return OpenTelemetryProtocolGRPC, nil + case "http/protobuf": + return OpenTelemetryProtocolHTTPProtobuf, nil + case "http/json": + return OpenTelemetryProtocolHTTPJSON, nil + default: + return OpenTelemetryProtocol(""), fmt.Errorf("unknown OpenTelemetry protocol: %s", s) + } +} diff --git a/internal/otel/exporter.go b/internal/otel/exporter.go index 5e61b01a..4eb0d0a5 100644 --- a/internal/otel/exporter.go +++ b/internal/otel/exporter.go @@ -4,7 +4,6 @@ import ( "context" "time" - "github.com/hookdeck/EventKit/internal/config" "go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc" "go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploghttp" "go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc" @@ -16,25 +15,24 @@ import ( "go.opentelemetry.io/otel/sdk/trace" ) -func newTraceProvider(ctx context.Context, c *config.Config) (*trace.TracerProvider, error) { - if c.OpenTelemetry.Traces == nil { +func newTraceProvider(ctx context.Context, config *OpenTelemetryConfig) (*trace.TracerProvider, error) { + if config.Traces == nil { return nil, nil } var err error var traceExporter trace.SpanExporter - if c.OpenTelemetry.Traces.Protocol == config.OpenTelemetryProtocolGRPC { + if config.Traces.Protocol == OpenTelemetryProtocolGRPC { traceExporter, err = otlptracegrpc.New(ctx, otlptracegrpc.WithInsecure(), // TODO: support TLS - otlptracegrpc.WithEndpoint(c.OpenTelemetry.Traces.Endpoint), + otlptracegrpc.WithEndpoint(config.Traces.Endpoint), ) } else { traceExporter, err = otlptracehttp.New(ctx, otlptracehttp.WithInsecure(), // TODO: support TLS - otlptracehttp.WithEndpointURL(ensureHTTPEndpoint("traces", c.OpenTelemetry.Traces.Endpoint)), + otlptracehttp.WithEndpointURL(ensureHTTPEndpoint("traces", config.Traces.Endpoint)), ) } - // traceExporter, err = stdouttrace.New() if err != nil { return nil, err @@ -50,22 +48,22 @@ func newTraceProvider(ctx context.Context, c *config.Config) (*trace.TracerProvi return traceProvider, nil } -func newMeterProvider(ctx context.Context, c *config.Config) (*metric.MeterProvider, error) { - if c.OpenTelemetry.Metrics == nil { +func newMeterProvider(ctx context.Context, config *OpenTelemetryConfig) (*metric.MeterProvider, error) { + if config.Metrics == nil { return nil, nil } var err error var metricExporter metric.Exporter - if c.OpenTelemetry.Metrics.Protocol == config.OpenTelemetryProtocolGRPC { + if config.Metrics.Protocol == OpenTelemetryProtocolGRPC { metricExporter, err = otlpmetricgrpc.New(ctx, otlpmetricgrpc.WithInsecure(), // TODO: support TLS - otlpmetricgrpc.WithEndpoint(c.OpenTelemetry.Metrics.Endpoint), + otlpmetricgrpc.WithEndpoint(config.Metrics.Endpoint), ) } else { metricExporter, err = otlpmetrichttp.New(ctx, otlpmetrichttp.WithInsecure(), // TODO: support TLS - otlpmetrichttp.WithEndpointURL(ensureHTTPEndpoint("metrics", c.OpenTelemetry.Metrics.Endpoint)), + otlpmetrichttp.WithEndpointURL(ensureHTTPEndpoint("metrics", config.Metrics.Endpoint)), ) } @@ -82,22 +80,22 @@ func newMeterProvider(ctx context.Context, c *config.Config) (*metric.MeterProvi return meterProvider, nil } -func newLoggerProvider(ctx context.Context, c *config.Config) (*log.LoggerProvider, error) { - if c.OpenTelemetry.Logs == nil { +func newLoggerProvider(ctx context.Context, config *OpenTelemetryConfig) (*log.LoggerProvider, error) { + if config.Logs == nil { return nil, nil } var err error var logExporter log.Exporter - if c.OpenTelemetry.Logs.Protocol == config.OpenTelemetryProtocolGRPC { + if config.Logs.Protocol == OpenTelemetryProtocolGRPC { logExporter, err = otlploggrpc.New(ctx, otlploggrpc.WithInsecure(), // TODO: support TLS - otlploggrpc.WithEndpoint(c.OpenTelemetry.Logs.Endpoint), + otlploggrpc.WithEndpoint(config.Logs.Endpoint), ) } else { logExporter, err = otlploghttp.New(ctx, otlploghttp.WithInsecure(), // TODO: support TLS - otlploghttp.WithEndpointURL(ensureHTTPEndpoint("logs", c.OpenTelemetry.Logs.Endpoint)), + otlploghttp.WithEndpointURL(ensureHTTPEndpoint("logs", config.Logs.Endpoint)), ) } diff --git a/internal/otel/otel.go b/internal/otel/otel.go index 2d4aa805..0b1e241b 100644 --- a/internal/otel/otel.go +++ b/internal/otel/otel.go @@ -3,9 +3,9 @@ package otel import ( "context" "errors" + "fmt" "time" - "github.com/hookdeck/EventKit/internal/config" "go.opentelemetry.io/contrib/instrumentation/host" "go.opentelemetry.io/contrib/instrumentation/runtime" "go.opentelemetry.io/otel" @@ -18,7 +18,9 @@ import ( // SetupOTelSDK bootstraps the OpenTelemetry pipeline. // If it does not return an error, make sure to call shutdown for proper cleanup. -func SetupOTelSDK(ctx context.Context, c *config.Config) (shutdown func(context.Context) error, err error) { +func SetupOTelSDK(ctx context.Context, config *OpenTelemetryConfig) (shutdown func(context.Context) error, err error) { + fmt.Println("config", config) + var shutdownFuncs []func(context.Context) error // shutdown calls cleanup functions registered via shutdownFuncs. @@ -43,7 +45,7 @@ func SetupOTelSDK(ctx context.Context, c *config.Config) (shutdown func(context. otel.SetTextMapPropagator(prop) // Set up trace provider. - tracerProvider, err := newTraceProvider(ctx, c) + tracerProvider, err := newTraceProvider(ctx, config) if err != nil { handleErr(err) return @@ -54,7 +56,7 @@ func SetupOTelSDK(ctx context.Context, c *config.Config) (shutdown func(context. } // Set up meter provider. - meterProvider, err := newMeterProvider(ctx, c) + meterProvider, err := newMeterProvider(ctx, config) if err != nil { handleErr(err) return @@ -76,7 +78,7 @@ func SetupOTelSDK(ctx context.Context, c *config.Config) (shutdown func(context. } // Set up logger provider. - loggerProvider, err := newLoggerProvider(ctx, c) + loggerProvider, err := newLoggerProvider(ctx, config) if err != nil { handleErr(err) return diff --git a/internal/redis/config.go b/internal/redis/config.go new file mode 100644 index 00000000..85628a01 --- /dev/null +++ b/internal/redis/config.go @@ -0,0 +1,8 @@ +package redis + +type RedisConfig struct { + Host string + Port int + Password string + Database int +} diff --git a/internal/redis/redis.go b/internal/redis/redis.go index 45245ef5..7171a7c5 100644 --- a/internal/redis/redis.go +++ b/internal/redis/redis.go @@ -5,7 +5,6 @@ import ( "fmt" "sync" - "github.com/hookdeck/EventKit/internal/config" "github.com/redis/go-redis/extra/redisotel/v9" r "github.com/redis/go-redis/v9" ) @@ -23,9 +22,9 @@ var ( initializationError error ) -func New(ctx context.Context, c *config.Config) (*r.Client, error) { +func New(ctx context.Context, config *RedisConfig) (*r.Client, error) { once.Do(func() { - initializeClient(ctx, c) + initializeClient(ctx, config) initializationError = instrumentOpenTelemetry() }) return client, initializationError @@ -41,10 +40,10 @@ func instrumentOpenTelemetry() error { return nil } -func initializeClient(_ context.Context, c *config.Config) { +func initializeClient(_ context.Context, config *RedisConfig) { client = r.NewClient(&r.Options{ - Addr: fmt.Sprintf("%s:%d", c.RedisHost, c.RedisPort), - Password: c.RedisPassword, - DB: c.RedisDatabase, + Addr: fmt.Sprintf("%s:%d", config.Host, config.Port), + Password: config.Password, + DB: config.Database, }) } diff --git a/internal/services/api/api.go b/internal/services/api/api.go index 3c9f55b8..1f9e024e 100644 --- a/internal/services/api/api.go +++ b/internal/services/api/api.go @@ -19,20 +19,20 @@ type APIService struct { logger *otelzap.Logger } -func NewService(ctx context.Context, wg *sync.WaitGroup, c *config.Config, logger *otelzap.Logger) (*APIService, error) { +func NewService(ctx context.Context, wg *sync.WaitGroup, cfg *config.Config, logger *otelzap.Logger) (*APIService, error) { wg.Add(1) - redisClient, err := redis.New(ctx, c) + redisClient, err := redis.New(ctx, cfg.Redis) if err != nil { return nil, err } - router := NewRouter(c, logger, redisClient) + router := NewRouter(cfg, logger, redisClient) service := &APIService{} service.logger = logger service.server = &http.Server{ - Addr: fmt.Sprintf(":%d", c.Port), + Addr: fmt.Sprintf(":%d", cfg.Port), Handler: router, } diff --git a/internal/services/api/router.go b/internal/services/api/router.go index c152be60..d9d79c45 100644 --- a/internal/services/api/router.go +++ b/internal/services/api/router.go @@ -11,9 +11,9 @@ import ( "go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin" ) -func NewRouter(c *config.Config, logger *otelzap.Logger, redisClient *redis.Client) http.Handler { +func NewRouter(cfg *config.Config, logger *otelzap.Logger, redisClient *redis.Client) http.Handler { r := gin.Default() - r.Use(otelgin.Middleware(c.Hostname)) + r.Use(otelgin.Middleware(cfg.Hostname)) r.GET("/healthz", func(c *gin.Context) { logger.Ctx(c.Request.Context()).Info("health check") diff --git a/internal/services/delivery/delivery.go b/internal/services/delivery/delivery.go index 6acea462..6e46df3b 100644 --- a/internal/services/delivery/delivery.go +++ b/internal/services/delivery/delivery.go @@ -13,7 +13,7 @@ type DeliveryService struct { logger *otelzap.Logger } -func NewService(ctx context.Context, wg *sync.WaitGroup, c *config.Config, logger *otelzap.Logger) (*DeliveryService, error) { +func NewService(ctx context.Context, wg *sync.WaitGroup, cfg *config.Config, logger *otelzap.Logger) (*DeliveryService, error) { wg.Add(1) go func() { defer wg.Done() diff --git a/internal/services/log/log.go b/internal/services/log/log.go index 66bed998..17882629 100644 --- a/internal/services/log/log.go +++ b/internal/services/log/log.go @@ -17,7 +17,7 @@ type LogService struct { redisClient *redis.Client } -func NewService(ctx context.Context, wg *sync.WaitGroup, c *config.Config, logger *otelzap.Logger) (*LogService, error) { +func NewService(ctx context.Context, wg *sync.WaitGroup, cfg *config.Config, logger *otelzap.Logger) (*LogService, error) { wg.Add(1) go func() { defer wg.Done() @@ -25,7 +25,7 @@ func NewService(ctx context.Context, wg *sync.WaitGroup, c *config.Config, logge logger.Ctx(ctx).Info("service shutdown", zap.String("service", "log")) }() - redisClient, err := redis.New(ctx, c) + redisClient, err := redis.New(ctx, cfg.Redis) if err != nil { return nil, err } From 714f177b1e083757954e47291ebdc8cfb4f0b86a Mon Sep 17 00:00:00 2001 From: Alex Luong Date: Tue, 27 Aug 2024 12:04:15 +0700 Subject: [PATCH 03/21] chore: remove debug log --- cmd/app/main.go | 1 - internal/otel/otel.go | 3 --- 2 files changed, 4 deletions(-) diff --git a/cmd/app/main.go b/cmd/app/main.go index cdca038e..fabe26ee 100644 --- a/cmd/app/main.go +++ b/cmd/app/main.go @@ -51,7 +51,6 @@ func run(mainContext context.Context) error { ctx, cancel := context.WithCancel(mainContext) // Set up OpenTelemetry. - fmt.Println("cfg.OpenTelemetry", cfg.OpenTelemetry) if cfg.OpenTelemetry != nil { otelShutdown, err := otel.SetupOTelSDK(ctx, cfg.OpenTelemetry) if err != nil { diff --git a/internal/otel/otel.go b/internal/otel/otel.go index 0b1e241b..3227661c 100644 --- a/internal/otel/otel.go +++ b/internal/otel/otel.go @@ -3,7 +3,6 @@ package otel import ( "context" "errors" - "fmt" "time" "go.opentelemetry.io/contrib/instrumentation/host" @@ -19,8 +18,6 @@ import ( // SetupOTelSDK bootstraps the OpenTelemetry pipeline. // If it does not return an error, make sure to call shutdown for proper cleanup. func SetupOTelSDK(ctx context.Context, config *OpenTelemetryConfig) (shutdown func(context.Context) error, err error) { - fmt.Println("config", config) - var shutdownFuncs []func(context.Context) error // shutdown calls cleanup functions registered via shutdownFuncs. From 7908bf0f2f3f42bdc011940a573d21b647f1346c Mon Sep 17 00:00:00 2001 From: Alex Luong Date: Tue, 27 Aug 2024 12:58:14 +0700 Subject: [PATCH 04/21] test: Destination model --- go.mod | 3 + .../destination_test/model_test.go | 77 +++++++++++++++++++ 2 files changed, 80 insertions(+) create mode 100644 internal/destination/destination_test/model_test.go diff --git a/go.mod b/go.mod index 4e768647..005183c5 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/redis/go-redis/extra/redisotel/v9 v9.5.3 github.com/redis/go-redis/v9 v9.6.1 github.com/spf13/viper v1.19.0 + github.com/stretchr/testify v1.9.0 github.com/uptrace/opentelemetry-go-extra/otelzap v0.3.1 go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.53.0 go.opentelemetry.io/contrib/instrumentation/host v0.54.0 @@ -34,6 +35,7 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/gabriel-vasile/mimetype v1.4.4 // indirect @@ -57,6 +59,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect github.com/redis/go-redis/extra/rediscmd/v9 v9.5.3 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect diff --git a/internal/destination/destination_test/model_test.go b/internal/destination/destination_test/model_test.go new file mode 100644 index 00000000..8e7ccb92 --- /dev/null +++ b/internal/destination/destination_test/model_test.go @@ -0,0 +1,77 @@ +package destination_test + +import ( + "context" + "testing" + + "github.com/google/uuid" + "github.com/hookdeck/EventKit/internal/destination" + "github.com/hookdeck/EventKit/internal/redis" + "github.com/stretchr/testify/assert" +) + +func getRedisClient() (*redis.Client, error) { + return redis.New(context.Background(), &redis.RedisConfig{ + Host: "localhost", + Port: 6379, + Password: "password", + Database: 0, + }) +} + +func TestDestinationModel(t *testing.T) { + redisClient, err := getRedisClient() + if err != nil { + t.Fatal(err) + } + model := destination.NewDestinationModel(redisClient) + + input := destination.Destination{ + ID: uuid.New().String(), + Name: "Test Destination", + } + + t.Run("gets empty", func(t *testing.T) { + actual, err := model.Get(context.Background(), input.ID) + assert.Nil(t, actual, "model.Get() should return nil when there's no value") + assert.Nil(t, err, "model.Get() should not return an error when there's no value") + }) + + t.Run("sets", func(t *testing.T) { + err := model.Set(context.Background(), input) + assert.Nil(t, err, "model.Set() should not return an error") + + value, err := redisClient.Get(context.Background(), "destination:"+input.ID).Result() + if err != nil { + t.Fatal(err) + } + assert.Equal(t, input.Name, value, "model.Set() should set destination name %s", input.Name) + }) + + t.Run("gets", func(t *testing.T) { + actual, err := model.Get(context.Background(), input.ID) + assert.Nil(t, err, "model.Get() should not return an error") + assert.Equal(t, input, *actual, "model.Get() should return %s", input) + }) + + t.Run("overrides", func(t *testing.T) { + input.Name = "Test Destination 2" + + err := model.Set(context.Background(), input) + assert.Nil(t, err, "model.Set() should not return an error", input) + + actual, err := model.Get(context.Background(), input.ID) + assert.Nil(t, err, "model.Get() should not return an error") + assert.Equal(t, input, *actual, "model.Get() should return %s", input) + }) + + t.Run("clears", func(t *testing.T) { + deleted, err := model.Clear(context.Background(), input.ID) + assert.Nil(t, err, "model.Clear() should not return an error") + assert.Equal(t, *deleted, input, "model.Clear() should return deleted value", input) + + actual, err := model.Get(context.Background(), input.ID) + assert.Nil(t, actual, "model.Clear() should properly remove value") + assert.Nil(t, err, "model.Clear() should properly remove value") + }) +} From 272d1a12182c00cd7b10f59ebf0eaa729627ea28 Mon Sep 17 00:00:00 2001 From: Alex Luong Date: Tue, 27 Aug 2024 14:38:04 +0700 Subject: [PATCH 05/21] test: Use miniredis for testing --- go.mod | 3 +++ go.sum | 6 ++++++ .../destination_test/model_test.go | 16 ++------------- internal/util/testutil/testutil.go | 20 +++++++++++++++++++ 4 files changed, 31 insertions(+), 14 deletions(-) create mode 100644 internal/util/testutil/testutil.go diff --git a/go.mod b/go.mod index 005183c5..ad7951bb 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/hookdeck/EventKit go 1.23.0 require ( + github.com/alicebob/miniredis/v2 v2.33.0 github.com/gin-gonic/gin v1.10.0 github.com/google/uuid v1.6.0 github.com/joho/godotenv v1.5.1 @@ -29,6 +30,7 @@ require ( ) require ( + github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect github.com/bytedance/sonic v1.11.9 // indirect github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect @@ -76,6 +78,7 @@ require ( github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect github.com/uptrace/opentelemetry-go-extra/otelutil v0.3.1 // indirect + github.com/yuin/gopher-lua v1.1.1 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.28.0 // indirect go.opentelemetry.io/otel/metric v1.29.0 // indirect diff --git a/go.sum b/go.sum index de67fd17..1f715532 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,7 @@ +github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk= +github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= +github.com/alicebob/miniredis/v2 v2.33.0 h1:uvTF0EDeu9RLnUEG27Db5I68ESoIxTiXbNUiji6lZrA= +github.com/alicebob/miniredis/v2 v2.33.0/go.mod h1:MhP4a3EU7aENRi9aO+tHfTBZicLqQevyi/DJpoj6mi0= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= @@ -145,6 +149,8 @@ github.com/uptrace/opentelemetry-go-extra/otelutil v0.3.1 h1:Suvl9fe12MM0oi8/rcG github.com/uptrace/opentelemetry-go-extra/otelutil v0.3.1/go.mod h1:aiX/F5+EYbY2ed2OQEYRXzMcNGvI9pip5gW2ZtBDers= github.com/uptrace/opentelemetry-go-extra/otelzap v0.3.1 h1:0iCp8hx3PFhGihubKHxyOCdIlIPxzUr0VsK+rvlMGdk= github.com/uptrace/opentelemetry-go-extra/otelzap v0.3.1/go.mod h1:FXrjpUJDqwqofvXWG3YNxQwhg2876tUpZASj8VvOMAM= +github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= +github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.53.0 h1:ktt8061VV/UU5pdPF6AcEFyuPxMizf/vU6eD1l+13LI= diff --git a/internal/destination/destination_test/model_test.go b/internal/destination/destination_test/model_test.go index 8e7ccb92..c923a707 100644 --- a/internal/destination/destination_test/model_test.go +++ b/internal/destination/destination_test/model_test.go @@ -6,24 +6,12 @@ import ( "github.com/google/uuid" "github.com/hookdeck/EventKit/internal/destination" - "github.com/hookdeck/EventKit/internal/redis" + "github.com/hookdeck/EventKit/internal/util/testutil" "github.com/stretchr/testify/assert" ) -func getRedisClient() (*redis.Client, error) { - return redis.New(context.Background(), &redis.RedisConfig{ - Host: "localhost", - Port: 6379, - Password: "password", - Database: 0, - }) -} - func TestDestinationModel(t *testing.T) { - redisClient, err := getRedisClient() - if err != nil { - t.Fatal(err) - } + redisClient := testutil.CreateTestRedisClient(t) model := destination.NewDestinationModel(redisClient) input := destination.Destination{ diff --git a/internal/util/testutil/testutil.go b/internal/util/testutil/testutil.go new file mode 100644 index 00000000..52a43bc5 --- /dev/null +++ b/internal/util/testutil/testutil.go @@ -0,0 +1,20 @@ +package testutil + +import ( + "testing" + + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" +) + +func CreateTestRedisClient(t *testing.T) *redis.Client { + mr := miniredis.RunT(t) + + t.Cleanup(func() { + mr.Close() + }) + + return redis.NewClient(&redis.Options{ + Addr: mr.Addr(), + }) +} From c6bf701d2b67e2d2170dc700575d2c632d5c6b81 Mon Sep 17 00:00:00 2001 From: Alex Luong Date: Tue, 27 Aug 2024 16:43:37 +0700 Subject: [PATCH 06/21] test: Destination handlers --- .../destination_test/handlers_test.go | 214 ++++++++++++++++++ .../destination_test/model_test.go | 2 + 2 files changed, 216 insertions(+) create mode 100644 internal/destination/destination_test/handlers_test.go diff --git a/internal/destination/destination_test/handlers_test.go b/internal/destination/destination_test/handlers_test.go new file mode 100644 index 00000000..1e7b81f1 --- /dev/null +++ b/internal/destination/destination_test/handlers_test.go @@ -0,0 +1,214 @@ +package destination_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "github.com/hookdeck/EventKit/internal/destination" + "github.com/hookdeck/EventKit/internal/util/testutil" + "github.com/stretchr/testify/assert" +) + +func setupRouter(destinationHandlers *destination.DestinationHandlers) *gin.Engine { + gin.SetMode(gin.TestMode) + r := gin.Default() + r.GET("/destinations", destinationHandlers.List) + r.POST("/destinations", destinationHandlers.Create) + r.GET("/destinations/:destinationID", destinationHandlers.Retrieve) + r.PATCH("/destinations/:destinationID", destinationHandlers.Update) + r.DELETE("/destinations/:destinationID", destinationHandlers.Delete) + return r +} + +func TestDestinationListHandler(t *testing.T) { + t.Parallel() + + redisClient := testutil.CreateTestRedisClient(t) + handlers := destination.NewHandlers(redisClient) + router := setupRouter(handlers) + + t.Run("should return 501", func(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/destinations", nil) + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusNotImplemented, w.Code) + }) +} + +func TestDestinationCreateHandler(t *testing.T) { + t.Parallel() + + redisClient := testutil.CreateTestRedisClient(t) + handlers := destination.NewHandlers(redisClient) + router := setupRouter(handlers) + + t.Run("should create", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + + exampleDestination := destination.CreateDestinationRequest{ + Name: "Test Destination", + } + destinationJSON, _ := json.Marshal(exampleDestination) + req, _ := http.NewRequest("POST", "/destinations", strings.NewReader(string(destinationJSON))) + router.ServeHTTP(w, req) + + var destinationResponse map[string]any + json.Unmarshal(w.Body.Bytes(), &destinationResponse) + + assert.Equal(t, http.StatusCreated, w.Code) + assert.Equal(t, exampleDestination.Name, destinationResponse["name"]) + assert.NotEqual(t, "", destinationResponse["id"]) + }) +} + +func TestDestinationRetrieveHandler(t *testing.T) { + t.Parallel() + + redisClient := testutil.CreateTestRedisClient(t) + handlers := destination.NewHandlers(redisClient) + router := setupRouter(handlers) + + t.Run("should return 404 when there's no destination", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/destinations/invalid_id", nil) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusNotFound, w.Code) + }) + + t.Run("should retrieve when there's a destination", func(t *testing.T) { + t.Parallel() + + // Setup test destination + exampleDestination := destination.Destination{ + ID: uuid.New().String(), + Name: "Test Destination", + } + redisClient.Set(context.Background(), "destination:"+exampleDestination.ID, exampleDestination.Name, 0) + + // Test HTTP request + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/destinations/"+exampleDestination.ID, nil) + router.ServeHTTP(w, req) + + var destinationResponse map[string]any + json.Unmarshal(w.Body.Bytes(), &destinationResponse) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, exampleDestination.ID, destinationResponse["id"]) + assert.Equal(t, exampleDestination.Name, destinationResponse["name"]) + + // Clean up + redisClient.Del(context.Background(), "destination:"+exampleDestination.ID) + }) +} + +func TestDestinationUpdateHandler(t *testing.T) { + t.Parallel() + + redisClient := testutil.CreateTestRedisClient(t) + handlers := destination.NewHandlers(redisClient) + router := setupRouter(handlers) + + initialDestination := destination.Destination{ + Name: "Test Destination", + } + + updateDestinationRequest := destination.UpdateDestinationRequest{ + Name: "Updated Destination", + } + updateDestinationJSON, _ := json.Marshal(updateDestinationRequest) + + t.Run("should validate", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("PATCH", "/destinations/invalid_id", nil) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) + }) + + t.Run("should return 404 when there's no destination", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("PATCH", "/destinations/invalid_id", strings.NewReader(string(updateDestinationJSON))) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusNotFound, w.Code) + }) + + t.Run("should update destination", func(t *testing.T) { + t.Parallel() + + // Setup initial destination + newDestination := initialDestination + newDestination.ID = uuid.New().String() + redisClient.Set(context.Background(), "destination:"+newDestination.ID, newDestination.Name, 0) + + // Test HTTP request + w := httptest.NewRecorder() + req, _ := http.NewRequest("PATCH", "/destinations/"+newDestination.ID, strings.NewReader(string(updateDestinationJSON))) + router.ServeHTTP(w, req) + + var destinationResponse map[string]any + json.Unmarshal(w.Body.Bytes(), &destinationResponse) + + assert.Equal(t, http.StatusAccepted, w.Code) + assert.Equal(t, newDestination.ID, destinationResponse["id"]) + assert.Equal(t, updateDestinationRequest.Name, destinationResponse["name"]) + + // Clean up + redisClient.Del(context.Background(), "destination:"+newDestination.ID) + }) +} + +func TestDestinationDeleteHandler(t *testing.T) { + redisClient := testutil.CreateTestRedisClient(t) + handlers := destination.NewHandlers(redisClient) + router := setupRouter(handlers) + + t.Run("should return 404 when there's no destination", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("DELETE", "/destinations/invalid_id", nil) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusNotFound, w.Code) + }) + + t.Run("should delete destination", func(t *testing.T) { + t.Parallel() + + // Setup initial destination + newDestination := destination.Destination{ + ID: uuid.New().String(), + Name: "Test Destination", + } + redisClient.Set(context.Background(), "destination:"+newDestination.ID, newDestination.Name, 0) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("DELETE", "/destinations/"+newDestination.ID, nil) + router.ServeHTTP(w, req) + + var destinationResponse map[string]any + json.Unmarshal(w.Body.Bytes(), &destinationResponse) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, newDestination.ID, destinationResponse["id"]) + assert.Equal(t, newDestination.Name, destinationResponse["name"]) + }) +} diff --git a/internal/destination/destination_test/model_test.go b/internal/destination/destination_test/model_test.go index c923a707..929755ab 100644 --- a/internal/destination/destination_test/model_test.go +++ b/internal/destination/destination_test/model_test.go @@ -11,6 +11,8 @@ import ( ) func TestDestinationModel(t *testing.T) { + t.Parallel() + redisClient := testutil.CreateTestRedisClient(t) model := destination.NewDestinationModel(redisClient) From 640cab1151b2e8ca8ad18fd09b68431c978ba66a Mon Sep 17 00:00:00 2001 From: Alex Luong Date: Tue, 27 Aug 2024 16:48:30 +0700 Subject: [PATCH 07/21] chore: Remove debug log --- internal/config/config.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index e8e60a33..aff86240 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,7 +1,6 @@ package config import ( - "fmt" "log" "os" "strconv" @@ -77,8 +76,6 @@ func Parse(flags Flags) (*Config, error) { return nil, err } - fmt.Println(openTelemetry) - // Initialize config values config := &Config{ Hostname: hostname, From 6023a5939ef300ec9e7eacc9b15d47f9dc70c6c9 Mon Sep 17 00:00:00 2001 From: Alex Luong Date: Wed, 28 Aug 2024 08:10:58 +0700 Subject: [PATCH 08/21] chore: Support API_PORT env --- internal/config/config.go | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/internal/config/config.go b/internal/config/config.go index aff86240..35ad9d63 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -80,7 +80,7 @@ func Parse(flags Flags) (*Config, error) { config := &Config{ Hostname: hostname, Service: service, - Port: mustInt(viper, "PORT"), + Port: getPort(viper), Redis: &redis.RedisConfig{ Host: viper.GetString("REDIS_HOST"), Port: mustInt(viper, "REDIS_PORT"), @@ -100,3 +100,14 @@ func mustInt(viper *v.Viper, configName string) int { } return i } + +func getPort(viper *v.Viper) int { + port := mustInt(viper, "PORT") + if viper.GetString("API_PORT") != "" { + apiPort, err := strconv.Atoi(viper.GetString("API_PORT")) + if err == nil { + port = apiPort + } + } + return port +} From 889b53fd66d04b11825d643ea7185b366edbd074 Mon Sep 17 00:00:00 2001 From: Alex Luong Date: Wed, 28 Aug 2024 08:30:14 +0700 Subject: [PATCH 09/21] feat: Authentication middleware --- internal/config/config.go | 2 ++ internal/services/api/auth_middleware.go | 45 ++++++++++++++++++++++++ internal/services/api/router.go | 1 + 3 files changed, 48 insertions(+) create mode 100644 internal/services/api/auth_middleware.go diff --git a/internal/config/config.go b/internal/config/config.go index 35ad9d63..464f87ca 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -19,6 +19,7 @@ type Config struct { Service ServiceType Port int Hostname string + APIKey string Redis *redis.RedisConfig OpenTelemetry *otel.OpenTelemetryConfig @@ -81,6 +82,7 @@ func Parse(flags Flags) (*Config, error) { Hostname: hostname, Service: service, Port: getPort(viper), + APIKey: viper.GetString("API_KEY"), Redis: &redis.RedisConfig{ Host: viper.GetString("REDIS_HOST"), Port: mustInt(viper, "REDIS_PORT"), diff --git a/internal/services/api/auth_middleware.go b/internal/services/api/auth_middleware.go new file mode 100644 index 00000000..15bcacb7 --- /dev/null +++ b/internal/services/api/auth_middleware.go @@ -0,0 +1,45 @@ +package api + +import ( + "errors" + "fmt" + "net/http" + "strings" + + "github.com/gin-gonic/gin" +) + +func authMiddleware(apiKey string) gin.HandlerFunc { + if apiKey == "" { + return func(c *gin.Context) { + c.Next() + } + } + + return func(c *gin.Context) { + authorizationToken, err := extractBearerToken(c.GetHeader("Authorization")) + if err != nil { + // TODO: Consider sending a more detailed error message. + // Currently we don't have clear specs on how to send back error message. + c.AbortWithStatus(http.StatusUnauthorized) + return + } + + fmt.Println(apiKey) + if authorizationToken != apiKey { + // TODO: Consider sending a more detailed error message. + // Currently we don't have clear specs on how to send back error message. + c.AbortWithStatus(http.StatusUnauthorized) + return + } + + c.Next() + } +} + +func extractBearerToken(header string) (string, error) { + if !strings.HasPrefix(header, "Bearer ") { + return "", errors.New("invalid bearer token") + } + return strings.TrimPrefix(header, "Bearer "), nil +} diff --git a/internal/services/api/router.go b/internal/services/api/router.go index d9d79c45..78b6debb 100644 --- a/internal/services/api/router.go +++ b/internal/services/api/router.go @@ -14,6 +14,7 @@ import ( func NewRouter(cfg *config.Config, logger *otelzap.Logger, redisClient *redis.Client) http.Handler { r := gin.Default() r.Use(otelgin.Middleware(cfg.Hostname)) + r.Use(authMiddleware(cfg.APIKey)) r.GET("/healthz", func(c *gin.Context) { logger.Ctx(c.Request.Context()).Info("health check") From ed38e7cbf86ce107c5cfd32025f4031734ce20fd Mon Sep 17 00:00:00 2001 From: Alex Luong Date: Wed, 28 Aug 2024 08:41:31 +0700 Subject: [PATCH 10/21] test: Authentication middleware --- internal/services/api/auth_middleware_test.go | 86 +++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 internal/services/api/auth_middleware_test.go diff --git a/internal/services/api/auth_middleware_test.go b/internal/services/api/auth_middleware_test.go new file mode 100644 index 00000000..864c6938 --- /dev/null +++ b/internal/services/api/auth_middleware_test.go @@ -0,0 +1,86 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +func setupRouter(apiKey string) *gin.Engine { + gin.SetMode(gin.TestMode) + r := gin.Default() + r.Use(authMiddleware(apiKey)) + r.GET("/healthz", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + return r +} + +func TestPublicRouter(t *testing.T) { + t.Parallel() + + router := setupRouter("") + + t.Run("should accept requests without a token", func(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/healthz", nil) + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("should accept requests with an invalid authorization token", func(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/healthz", nil) + req.Header.Set("Authorization", "invalid key") + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("should accept requests with a valid authorization token", func(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/healthz", nil) + req.Header.Set("Authorization", "Bearer key") + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + }) +} + +func TestPrivateRouter(t *testing.T) { + t.Parallel() + + const apiKey = "key" + + router := setupRouter(apiKey) + + t.Run("should reject requests without a token", func(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/healthz", nil) + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusUnauthorized, w.Code) + }) + + t.Run("should reject requests with an invalid authorization token", func(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/healthz", nil) + req.Header.Set("Authorization", "invalid key") + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusUnauthorized, w.Code) + }) + + t.Run("should accept requests with a valid authorization token", func(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/healthz", nil) + req.Header.Set("Authorization", "Bearer "+apiKey) + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + }) +} From 950dcb829a575fce278f13004dd9229e35963b3a Mon Sep 17 00:00:00 2001 From: Alex Luong Date: Wed, 28 Aug 2024 08:42:34 +0700 Subject: [PATCH 11/21] chore: Remove debug log --- internal/services/api/auth_middleware.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/internal/services/api/auth_middleware.go b/internal/services/api/auth_middleware.go index 15bcacb7..6ddd2dcf 100644 --- a/internal/services/api/auth_middleware.go +++ b/internal/services/api/auth_middleware.go @@ -2,7 +2,6 @@ package api import ( "errors" - "fmt" "net/http" "strings" @@ -24,15 +23,12 @@ func authMiddleware(apiKey string) gin.HandlerFunc { c.AbortWithStatus(http.StatusUnauthorized) return } - - fmt.Println(apiKey) if authorizationToken != apiKey { // TODO: Consider sending a more detailed error message. // Currently we don't have clear specs on how to send back error message. c.AbortWithStatus(http.StatusUnauthorized) return } - c.Next() } } From 9aa9347c4dd05bab3337883771589680e8ea6f2d Mon Sep 17 00:00:00 2001 From: Alex Luong Date: Wed, 28 Aug 2024 08:56:55 +0700 Subject: [PATCH 12/21] chore: Rename middleware name to specify API key mechanism --- internal/services/api/auth_middleware.go | 2 +- internal/services/api/auth_middleware_test.go | 4 ++-- internal/services/api/router.go | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/internal/services/api/auth_middleware.go b/internal/services/api/auth_middleware.go index 6ddd2dcf..c0af656c 100644 --- a/internal/services/api/auth_middleware.go +++ b/internal/services/api/auth_middleware.go @@ -8,7 +8,7 @@ import ( "github.com/gin-gonic/gin" ) -func authMiddleware(apiKey string) gin.HandlerFunc { +func apiKeyAuthMiddleware(apiKey string) gin.HandlerFunc { if apiKey == "" { return func(c *gin.Context) { c.Next() diff --git a/internal/services/api/auth_middleware_test.go b/internal/services/api/auth_middleware_test.go index 864c6938..165bf68d 100644 --- a/internal/services/api/auth_middleware_test.go +++ b/internal/services/api/auth_middleware_test.go @@ -12,7 +12,7 @@ import ( func setupRouter(apiKey string) *gin.Engine { gin.SetMode(gin.TestMode) r := gin.Default() - r.Use(authMiddleware(apiKey)) + r.Use(apiKeyAuthMiddleware(apiKey)) r.GET("/healthz", func(c *gin.Context) { c.Status(http.StatusOK) }) @@ -51,7 +51,7 @@ func TestPublicRouter(t *testing.T) { }) } -func TestPrivateRouter(t *testing.T) { +func TestPrivateAPIKeyRouter(t *testing.T) { t.Parallel() const apiKey = "key" diff --git a/internal/services/api/router.go b/internal/services/api/router.go index 78b6debb..c4b0d4d0 100644 --- a/internal/services/api/router.go +++ b/internal/services/api/router.go @@ -14,7 +14,7 @@ import ( func NewRouter(cfg *config.Config, logger *otelzap.Logger, redisClient *redis.Client) http.Handler { r := gin.Default() r.Use(otelgin.Middleware(cfg.Hostname)) - r.Use(authMiddleware(cfg.APIKey)) + r.Use(apiKeyAuthMiddleware(cfg.APIKey)) r.GET("/healthz", func(c *gin.Context) { logger.Ctx(c.Request.Context()).Info("health check") From ab3ffbe2f061d6989e40e83da92e4692066a6082 Mon Sep 17 00:00:00 2001 From: Alex Luong Date: Wed, 28 Aug 2024 12:13:04 +0700 Subject: [PATCH 13/21] feat: Tenant's CRUD --- internal/services/api/router.go | 8 ++ internal/tenant/handlers.go | 99 +++++++++++++++++ internal/tenant/handlers_test.go | 170 +++++++++++++++++++++++++++++ internal/tenant/model.go | 61 +++++++++++ internal/tenant/model_test.go | 68 ++++++++++++ internal/util/testutil/testutil.go | 11 ++ 6 files changed, 417 insertions(+) create mode 100644 internal/tenant/handlers.go create mode 100644 internal/tenant/handlers_test.go create mode 100644 internal/tenant/model.go create mode 100644 internal/tenant/model_test.go diff --git a/internal/services/api/router.go b/internal/services/api/router.go index c4b0d4d0..cc5c23d8 100644 --- a/internal/services/api/router.go +++ b/internal/services/api/router.go @@ -7,6 +7,7 @@ import ( "github.com/hookdeck/EventKit/internal/config" "github.com/hookdeck/EventKit/internal/destination" "github.com/hookdeck/EventKit/internal/redis" + "github.com/hookdeck/EventKit/internal/tenant" "github.com/uptrace/opentelemetry-go-extra/otelzap" "go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin" ) @@ -21,6 +22,13 @@ func NewRouter(cfg *config.Config, logger *otelzap.Logger, redisClient *redis.Cl c.Status(http.StatusOK) }) + tenantHandlers := tenant.NewHandlers(logger, redisClient) + + r.PUT("/:tenantID", tenantHandlers.Upsert) + r.GET("/:tenantID", tenantHandlers.Retrieve) + r.DELETE("/:tenantID", tenantHandlers.Delete) + r.GET("/:tenantID/portal", tenantHandlers.RetrievePortal) + destinationHandlers := destination.NewHandlers(redisClient) r.GET("/destinations", destinationHandlers.List) diff --git a/internal/tenant/handlers.go b/internal/tenant/handlers.go new file mode 100644 index 00000000..c8385dda --- /dev/null +++ b/internal/tenant/handlers.go @@ -0,0 +1,99 @@ +package tenant + +import ( + "net/http" + "time" + + "github.com/gin-gonic/gin" + "github.com/hookdeck/EventKit/internal/redis" + "github.com/uptrace/opentelemetry-go-extra/otelzap" + "go.uber.org/zap" +) + +type TenantHandlers struct { + logger *otelzap.Logger + model *TenantModel +} + +func NewHandlers(logger *otelzap.Logger, redisClient *redis.Client) *TenantHandlers { + return &TenantHandlers{ + logger: logger, + model: NewTenantModel(redisClient), + } +} + +func (h *TenantHandlers) Upsert(c *gin.Context) { + logger := h.logger.Ctx(c.Request.Context()) + tenantID := c.Param("tenantID") + + // Check existing tenant. + tenant, err := h.model.Get(c.Request.Context(), tenantID) + if err != nil { + logger.Error("failed to get tenant", zap.Error(err)) + c.Status(http.StatusInternalServerError) + return + } + + // If tenant already exists, return. + if tenant != nil { + c.JSON(http.StatusOK, tenant) + return + } + + // Create new tenant. + tenant = &Tenant{ + ID: tenantID, + CreatedAt: time.Now().String(), + } + if err := h.model.Set(c.Request.Context(), *tenant); err != nil { + logger.Error("failed to set tenant", zap.Error(err)) + c.Status(http.StatusInternalServerError) + return + } + c.JSON(http.StatusCreated, tenant) +} + +func (h *TenantHandlers) Retrieve(c *gin.Context) { + logger := h.logger.Ctx(c.Request.Context()) + tenantID := c.Param("tenantID") + tenant, err := h.model.Get(c.Request.Context(), tenantID) + if err != nil { + logger.Error("failed to get tenant", zap.Error(err)) + c.Status(http.StatusInternalServerError) + return + } + if tenant == nil { + c.Status(http.StatusNotFound) + return + } + c.JSON(http.StatusOK, tenant) +} + +func (h *TenantHandlers) Delete(c *gin.Context) { + logger := h.logger.Ctx(c.Request.Context()) + tenantID := c.Param("tenantID") + tenant, err := h.model.Get(c.Request.Context(), tenantID) + if err != nil { + logger.Error("failed to get tenant", zap.Error(err)) + c.Status(http.StatusInternalServerError) + return + } + if tenant == nil { + c.Status(http.StatusNotFound) + return + } + tenant, err = h.model.Clear(c.Request.Context(), tenantID) + if err != nil { + logger.Error("failed to delete tenant", zap.Error(err)) + c.Status(http.StatusInternalServerError) + return + } + + // TODO: delete associated destinations + + c.JSON(http.StatusOK, gin.H{"success": true}) +} + +func (h *TenantHandlers) RetrievePortal(c *gin.Context) { + c.Status(http.StatusNotImplemented) +} diff --git a/internal/tenant/handlers_test.go b/internal/tenant/handlers_test.go new file mode 100644 index 00000000..d719c8c6 --- /dev/null +++ b/internal/tenant/handlers_test.go @@ -0,0 +1,170 @@ +package tenant_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "github.com/hookdeck/EventKit/internal/tenant" + "github.com/hookdeck/EventKit/internal/util/testutil" + "github.com/stretchr/testify/assert" +) + +func setupRouter(tenantHandlers *tenant.TenantHandlers) *gin.Engine { + gin.SetMode(gin.TestMode) + r := gin.Default() + r.PUT("/:tenantID", tenantHandlers.Upsert) + r.GET("/:tenantID", tenantHandlers.Retrieve) + r.DELETE("/:tenantID", tenantHandlers.Delete) + r.GET("/:tenantID/portal", tenantHandlers.RetrievePortal) + return r +} + +func TestDestinationUpsertHandler(t *testing.T) { + t.Parallel() + + logger := testutil.CreateTestLogger(t) + redisClient := testutil.CreateTestRedisClient(t) + model := tenant.NewTenantModel(redisClient) + handlers := tenant.NewHandlers(logger, redisClient) + router := setupRouter(handlers) + + t.Run("should create when there's no existing tenant", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + + id := uuid.New().String() + req, _ := http.NewRequest("PUT", "/"+id, nil) + router.ServeHTTP(w, req) + + var response map[string]any + json.Unmarshal(w.Body.Bytes(), &response) + + assert.Equal(t, http.StatusCreated, w.Code) + assert.Equal(t, id, response["id"]) + assert.NotEqual(t, "", response["created_at"]) + }) + + t.Run("should return tenant when there's already one", func(t *testing.T) { + t.Parallel() + + // Setup + existingResource := tenant.Tenant{ + ID: uuid.New().String(), + CreatedAt: time.Now().String(), + } + model.Set(context.Background(), existingResource) + + // Request + w := httptest.NewRecorder() + req, _ := http.NewRequest("PUT", "/"+existingResource.ID, nil) + router.ServeHTTP(w, req) + var response map[string]any + json.Unmarshal(w.Body.Bytes(), &response) + + // Test + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, existingResource.ID, response["id"]) + assert.Equal(t, existingResource.CreatedAt, response["created_at"]) + + // Cleanup + model.Clear(context.Background(), existingResource.ID) + }) +} + +func TestTenantRetrieveHandler(t *testing.T) { + t.Parallel() + + logger := testutil.CreateTestLogger(t) + redisClient := testutil.CreateTestRedisClient(t) + model := tenant.NewTenantModel(redisClient) + handlers := tenant.NewHandlers(logger, redisClient) + router := setupRouter(handlers) + + t.Run("should return 404 when there's no tenant", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/invalid_id", nil) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusNotFound, w.Code) + }) + + t.Run("should retrieve tenant", func(t *testing.T) { + t.Parallel() + + // Setup + existingResource := tenant.Tenant{ + ID: uuid.New().String(), + CreatedAt: time.Now().String(), + } + model.Set(context.Background(), existingResource) + + // Request + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/"+existingResource.ID, nil) + router.ServeHTTP(w, req) + var response map[string]any + json.Unmarshal(w.Body.Bytes(), &response) + + // Assert + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, existingResource.ID, response["id"]) + assert.Equal(t, existingResource.CreatedAt, response["created_at"]) + + // Cleanup + model.Clear(context.Background(), existingResource.ID) + }) +} + +func TestTenantDeleteHandler(t *testing.T) { + t.Parallel() + + logger := testutil.CreateTestLogger(t) + redisClient := testutil.CreateTestRedisClient(t) + model := tenant.NewTenantModel(redisClient) + handlers := tenant.NewHandlers(logger, redisClient) + router := setupRouter(handlers) + + t.Run("should return 404 when there's no tenant", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("DELETE", "/invalid_id", nil) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusNotFound, w.Code) + }) + + t.Run("should delete tenant", func(t *testing.T) { + t.Parallel() + + // Setup + existingResource := tenant.Tenant{ + ID: uuid.New().String(), + CreatedAt: time.Now().String(), + } + model.Set(context.Background(), existingResource) + + // Request + w := httptest.NewRecorder() + req, _ := http.NewRequest("DELETE", "/"+existingResource.ID, nil) + router.ServeHTTP(w, req) + var response map[string]any + json.Unmarshal(w.Body.Bytes(), &response) + + // Test + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, true, response["success"]) + + // Cleanup + model.Clear(context.Background(), existingResource.ID) + }) +} diff --git a/internal/tenant/model.go b/internal/tenant/model.go new file mode 100644 index 00000000..652842fb --- /dev/null +++ b/internal/tenant/model.go @@ -0,0 +1,61 @@ +package tenant + +import ( + "context" + "fmt" + + "github.com/hookdeck/EventKit/internal/redis" +) + +type Tenant struct { + ID string `json:"id"` + CreatedAt string `json:"created_at"` +} + +type TenantModel struct { + redisClient *redis.Client +} + +func NewTenantModel(redisClient *redis.Client) *TenantModel { + return &TenantModel{ + redisClient: redisClient, + } +} + +func (m *TenantModel) Get(c context.Context, id string) (*Tenant, error) { + destination, err := m.redisClient.Get(c, redisTenantID(id)).Result() + if err == redis.Nil { + return nil, nil + } else if err != nil { + return nil, err + } + return &Tenant{ + ID: id, + CreatedAt: destination, + }, nil +} + +func (m *TenantModel) Set(c context.Context, tenant Tenant) error { + if err := m.redisClient.Set(c, redisTenantID(tenant.ID), tenant.CreatedAt, 0).Err(); err != nil { + return err + } + return nil +} + +func (m *TenantModel) Clear(c context.Context, id string) (*Tenant, error) { + destination, err := m.Get(c, id) + if err != nil { + return nil, err + } + if destination == nil { + return nil, nil + } + if err := m.redisClient.Del(c, redisTenantID(id)).Err(); err != nil { + return nil, err + } + return destination, nil +} + +func redisTenantID(tenantID string) string { + return fmt.Sprintf("tenant:%s", tenantID) +} diff --git a/internal/tenant/model_test.go b/internal/tenant/model_test.go new file mode 100644 index 00000000..9fec405e --- /dev/null +++ b/internal/tenant/model_test.go @@ -0,0 +1,68 @@ +package tenant_test + +import ( + "context" + "testing" + "time" + + "github.com/google/uuid" + "github.com/hookdeck/EventKit/internal/tenant" + "github.com/hookdeck/EventKit/internal/util/testutil" + "github.com/stretchr/testify/assert" +) + +func TestTenantModel(t *testing.T) { + t.Parallel() + + redisClient := testutil.CreateTestRedisClient(t) + model := tenant.NewTenantModel(redisClient) + + input := tenant.Tenant{ + ID: uuid.New().String(), + CreatedAt: time.Now().String(), + } + + t.Run("gets empty", func(t *testing.T) { + actual, err := model.Get(context.Background(), input.ID) + assert.Nil(t, actual, "model.Get() should return nil when there's no value") + assert.Nil(t, err, "model.Get() should not return an error when there's no value") + }) + + t.Run("sets", func(t *testing.T) { + err := model.Set(context.Background(), input) + assert.Nil(t, err, "model.Set() should not return an error") + + value, err := redisClient.Get(context.Background(), "tenant:"+input.ID).Result() + if err != nil { + t.Fatal(err) + } + assert.Equal(t, input.CreatedAt, value, "model.Set() should set tenant created timestamp %s", input.CreatedAt) + }) + + t.Run("gets", func(t *testing.T) { + actual, err := model.Get(context.Background(), input.ID) + assert.Nil(t, err, "model.Get() should not return an error") + assert.Equal(t, input, *actual, "model.Get() should return %s", input) + }) + + t.Run("overrides", func(t *testing.T) { + input.CreatedAt = time.Now().String() + + err := model.Set(context.Background(), input) + assert.Nil(t, err, "model.Set() should not return an error", input) + + actual, err := model.Get(context.Background(), input.ID) + assert.Nil(t, err, "model.Get() should not return an error") + assert.Equal(t, input, *actual, "model.Get() should return %s", input) + }) + + t.Run("clears", func(t *testing.T) { + deleted, err := model.Clear(context.Background(), input.ID) + assert.Nil(t, err, "model.Clear() should not return an error") + assert.Equal(t, *deleted, input, "model.Clear() should return deleted value", input) + + actual, err := model.Get(context.Background(), input.ID) + assert.Nil(t, actual, "model.Clear() should properly remove value") + assert.Nil(t, err, "model.Clear() should properly remove value") + }) +} diff --git a/internal/util/testutil/testutil.go b/internal/util/testutil/testutil.go index 52a43bc5..c5f691c6 100644 --- a/internal/util/testutil/testutil.go +++ b/internal/util/testutil/testutil.go @@ -5,6 +5,9 @@ import ( "github.com/alicebob/miniredis/v2" "github.com/redis/go-redis/v9" + "github.com/uptrace/opentelemetry-go-extra/otelzap" + "go.uber.org/zap" + "go.uber.org/zap/zaptest" ) func CreateTestRedisClient(t *testing.T) *redis.Client { @@ -18,3 +21,11 @@ func CreateTestRedisClient(t *testing.T) *redis.Client { Addr: mr.Addr(), }) } + +func CreateTestLogger(t *testing.T) *otelzap.Logger { + zapLogger := zaptest.NewLogger(t) + logger := otelzap.New(zapLogger, + otelzap.WithMinLevel(zap.InfoLevel), + ) + return logger +} From 5c5338f20c7f803f6a761eac83a352a94a025605 Mon Sep 17 00:00:00 2001 From: Alex Luong Date: Wed, 28 Aug 2024 13:57:20 +0700 Subject: [PATCH 14/21] refactor: Use Redis hash for tenant --- internal/tenant/model.go | 40 +++++++++++++++++++++++------------ internal/tenant/model_test.go | 4 ++-- 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/internal/tenant/model.go b/internal/tenant/model.go index 652842fb..47073a93 100644 --- a/internal/tenant/model.go +++ b/internal/tenant/model.go @@ -8,8 +8,8 @@ import ( ) type Tenant struct { - ID string `json:"id"` - CreatedAt string `json:"created_at"` + ID string `json:"id" redis:"id"` + CreatedAt string `json:"created_at" redis:"created_at"` } type TenantModel struct { @@ -23,37 +23,51 @@ func NewTenantModel(redisClient *redis.Client) *TenantModel { } func (m *TenantModel) Get(c context.Context, id string) (*Tenant, error) { - destination, err := m.redisClient.Get(c, redisTenantID(id)).Result() - if err == redis.Nil { + hash, err := m.redisClient.HGetAll(c, redisTenantID(id)).Result() + if err != nil { + return nil, err + } + if len(hash) == 0 { return nil, nil - } else if err != nil { + } + tenant := &Tenant{} + if err = tenant.parseRedisHash(hash); err != nil { return nil, err } - return &Tenant{ - ID: id, - CreatedAt: destination, - }, nil + return tenant, nil } func (m *TenantModel) Set(c context.Context, tenant Tenant) error { - if err := m.redisClient.Set(c, redisTenantID(tenant.ID), tenant.CreatedAt, 0).Err(); err != nil { + if err := m.redisClient.HSet(c, redisTenantID(tenant.ID), tenant).Err(); err != nil { return err } return nil } func (m *TenantModel) Clear(c context.Context, id string) (*Tenant, error) { - destination, err := m.Get(c, id) + tenant, err := m.Get(c, id) if err != nil { return nil, err } - if destination == nil { + if tenant == nil { return nil, nil } if err := m.redisClient.Del(c, redisTenantID(id)).Err(); err != nil { return nil, err } - return destination, nil + return tenant, nil +} + +func (t *Tenant) parseRedisHash(hash map[string]string) error { + if hash["id"] == "" { + return fmt.Errorf("missing id") + } + t.ID = hash["id"] + if hash["created_at"] == "" { + return fmt.Errorf("missing created_at") + } + t.CreatedAt = hash["created_at"] + return nil } func redisTenantID(tenantID string) string { diff --git a/internal/tenant/model_test.go b/internal/tenant/model_test.go index 9fec405e..a6893e96 100644 --- a/internal/tenant/model_test.go +++ b/internal/tenant/model_test.go @@ -32,11 +32,11 @@ func TestTenantModel(t *testing.T) { err := model.Set(context.Background(), input) assert.Nil(t, err, "model.Set() should not return an error") - value, err := redisClient.Get(context.Background(), "tenant:"+input.ID).Result() + hash, err := redisClient.HGetAll(context.Background(), "tenant:"+input.ID).Result() if err != nil { t.Fatal(err) } - assert.Equal(t, input.CreatedAt, value, "model.Set() should set tenant created timestamp %s", input.CreatedAt) + assert.Equal(t, input.CreatedAt, hash["created_at"], "model.Set() should set tenant created timestamp %s", input.CreatedAt) }) t.Run("gets", func(t *testing.T) { From 8292e4645ed84c2edb8bda75f4c7679622ad31ca Mon Sep 17 00:00:00 2001 From: Alex Luong Date: Wed, 28 Aug 2024 14:36:05 +0700 Subject: [PATCH 15/21] refactor: Use time.Time for tenant created timestamp --- go.mod | 1 + internal/tenant/handlers.go | 2 +- internal/tenant/handlers_test.go | 20 ++++++++++++++------ internal/tenant/model.go | 11 ++++++++--- internal/tenant/model_test.go | 19 ++++++++++++------- 5 files changed, 36 insertions(+), 17 deletions(-) diff --git a/go.mod b/go.mod index ad7951bb..37f827d0 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.23.0 require ( github.com/alicebob/miniredis/v2 v2.33.0 github.com/gin-gonic/gin v1.10.0 + github.com/google/go-cmp v0.6.0 github.com/google/uuid v1.6.0 github.com/joho/godotenv v1.5.1 github.com/redis/go-redis/extra/redisotel/v9 v9.5.3 diff --git a/internal/tenant/handlers.go b/internal/tenant/handlers.go index c8385dda..1fe749b7 100644 --- a/internal/tenant/handlers.go +++ b/internal/tenant/handlers.go @@ -43,7 +43,7 @@ func (h *TenantHandlers) Upsert(c *gin.Context) { // Create new tenant. tenant = &Tenant{ ID: tenantID, - CreatedAt: time.Now().String(), + CreatedAt: time.Now(), } if err := h.model.Set(c.Request.Context(), *tenant); err != nil { logger.Error("failed to set tenant", zap.Error(err)) diff --git a/internal/tenant/handlers_test.go b/internal/tenant/handlers_test.go index d719c8c6..f72d074c 100644 --- a/internal/tenant/handlers_test.go +++ b/internal/tenant/handlers_test.go @@ -57,7 +57,7 @@ func TestDestinationUpsertHandler(t *testing.T) { // Setup existingResource := tenant.Tenant{ ID: uuid.New().String(), - CreatedAt: time.Now().String(), + CreatedAt: time.Now(), } model.Set(context.Background(), existingResource) @@ -71,7 +71,11 @@ func TestDestinationUpsertHandler(t *testing.T) { // Test assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, existingResource.ID, response["id"]) - assert.Equal(t, existingResource.CreatedAt, response["created_at"]) + createdAt, err := time.Parse(time.RFC3339Nano, response["created_at"].(string)) + if err != nil { + t.Fatal(err) + } + assert.True(t, existingResource.CreatedAt.Equal(createdAt)) // Cleanup model.Clear(context.Background(), existingResource.ID) @@ -103,7 +107,7 @@ func TestTenantRetrieveHandler(t *testing.T) { // Setup existingResource := tenant.Tenant{ ID: uuid.New().String(), - CreatedAt: time.Now().String(), + CreatedAt: time.Now(), } model.Set(context.Background(), existingResource) @@ -114,10 +118,14 @@ func TestTenantRetrieveHandler(t *testing.T) { var response map[string]any json.Unmarshal(w.Body.Bytes(), &response) - // Assert + // Test assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, existingResource.ID, response["id"]) - assert.Equal(t, existingResource.CreatedAt, response["created_at"]) + createdAt, err := time.Parse(time.RFC3339Nano, response["created_at"].(string)) + if err != nil { + t.Fatal(err) + } + assert.True(t, existingResource.CreatedAt.Equal(createdAt)) // Cleanup model.Clear(context.Background(), existingResource.ID) @@ -149,7 +157,7 @@ func TestTenantDeleteHandler(t *testing.T) { // Setup existingResource := tenant.Tenant{ ID: uuid.New().String(), - CreatedAt: time.Now().String(), + CreatedAt: time.Now(), } model.Set(context.Background(), existingResource) diff --git a/internal/tenant/model.go b/internal/tenant/model.go index 47073a93..1d5ebad6 100644 --- a/internal/tenant/model.go +++ b/internal/tenant/model.go @@ -3,13 +3,14 @@ package tenant import ( "context" "fmt" + "time" "github.com/hookdeck/EventKit/internal/redis" ) type Tenant struct { - ID string `json:"id" redis:"id"` - CreatedAt string `json:"created_at" redis:"created_at"` + ID string `json:"id" redis:"id"` + CreatedAt time.Time `json:"created_at" redis:"created_at"` } type TenantModel struct { @@ -66,7 +67,11 @@ func (t *Tenant) parseRedisHash(hash map[string]string) error { if hash["created_at"] == "" { return fmt.Errorf("missing created_at") } - t.CreatedAt = hash["created_at"] + createdAt, err := time.Parse(time.RFC3339Nano, hash["created_at"]) + if err != nil { + return err + } + t.CreatedAt = createdAt return nil } diff --git a/internal/tenant/model_test.go b/internal/tenant/model_test.go index a6893e96..cac50912 100644 --- a/internal/tenant/model_test.go +++ b/internal/tenant/model_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/google/uuid" "github.com/hookdeck/EventKit/internal/tenant" "github.com/hookdeck/EventKit/internal/util/testutil" @@ -19,7 +20,7 @@ func TestTenantModel(t *testing.T) { input := tenant.Tenant{ ID: uuid.New().String(), - CreatedAt: time.Now().String(), + CreatedAt: time.Now(), } t.Run("gets empty", func(t *testing.T) { @@ -36,30 +37,34 @@ func TestTenantModel(t *testing.T) { if err != nil { t.Fatal(err) } - assert.Equal(t, input.CreatedAt, hash["created_at"], "model.Set() should set tenant created timestamp %s", input.CreatedAt) + createdAt, err := time.Parse(time.RFC3339Nano, hash["created_at"]) + if err != nil { + t.Fatal(err) + } + assert.True(t, input.CreatedAt.Equal(createdAt), "model.Set() should set tenant created timestamp %s", input.CreatedAt) }) t.Run("gets", func(t *testing.T) { actual, err := model.Get(context.Background(), input.ID) assert.Nil(t, err, "model.Get() should not return an error") - assert.Equal(t, input, *actual, "model.Get() should return %s", input) + assert.True(t, cmp.Equal(input, *actual), "model.Get() should return %s", input) }) t.Run("overrides", func(t *testing.T) { - input.CreatedAt = time.Now().String() + input.CreatedAt = time.Now() err := model.Set(context.Background(), input) - assert.Nil(t, err, "model.Set() should not return an error", input) + assert.Nil(t, err, "model.Set() should not return an error") actual, err := model.Get(context.Background(), input.ID) assert.Nil(t, err, "model.Get() should not return an error") - assert.Equal(t, input, *actual, "model.Get() should return %s", input) + assert.True(t, cmp.Equal(input, *actual), "model.Get() should return %s", input) }) t.Run("clears", func(t *testing.T) { deleted, err := model.Clear(context.Background(), input.ID) assert.Nil(t, err, "model.Clear() should not return an error") - assert.Equal(t, *deleted, input, "model.Clear() should return deleted value", input) + assert.True(t, cmp.Equal(*deleted, input), "model.Clear() should return deleted value") actual, err := model.Get(context.Background(), input.ID) assert.Nil(t, actual, "model.Clear() should properly remove value") From 1ff8340e9e90922e2472ba2183d291ab9bf6fbb2 Mon Sep 17 00:00:00 2001 From: Alex Luong Date: Wed, 28 Aug 2024 20:52:36 +0700 Subject: [PATCH 16/21] feat: Generate tenant scope JWT token --- go.mod | 1 + go.sum | 2 + internal/config/config.go | 18 ++++--- internal/services/api/router.go | 2 +- internal/tenant/handlers.go | 34 +++++++++--- internal/tenant/handlers_test.go | 6 +-- internal/tenant/jwt.go | 41 ++++++++++++++ internal/tenant/jwt_test.go | 91 ++++++++++++++++++++++++++++++++ 8 files changed, 177 insertions(+), 18 deletions(-) create mode 100644 internal/tenant/jwt.go create mode 100644 internal/tenant/jwt_test.go diff --git a/go.mod b/go.mod index 37f827d0..521a5db7 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.23.0 require ( github.com/alicebob/miniredis/v2 v2.33.0 github.com/gin-gonic/gin v1.10.0 + github.com/golang-jwt/jwt/v5 v5.2.1 github.com/google/go-cmp v0.6.0 github.com/google/uuid v1.6.0 github.com/joho/godotenv v1.5.1 diff --git a/go.sum b/go.sum index 1f715532..aa35c72c 100644 --- a/go.sum +++ b/go.sum @@ -52,6 +52,8 @@ github.com/go-playground/validator/v10 v10.22.0 h1:k6HsTZ0sTnROkhS//R0O+55JgM8C4 github.com/go-playground/validator/v10 v10.22.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= diff --git a/internal/config/config.go b/internal/config/config.go index 464f87ca..11437805 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -16,10 +16,11 @@ const ( ) type Config struct { - Service ServiceType - Port int - Hostname string - APIKey string + Service ServiceType + Port int + Hostname string + APIKey string + JWTSecret string Redis *redis.RedisConfig OpenTelemetry *otel.OpenTelemetryConfig @@ -79,10 +80,11 @@ func Parse(flags Flags) (*Config, error) { // Initialize config values config := &Config{ - Hostname: hostname, - Service: service, - Port: getPort(viper), - APIKey: viper.GetString("API_KEY"), + Hostname: hostname, + Service: service, + Port: getPort(viper), + APIKey: viper.GetString("API_KEY"), + JWTSecret: viper.GetString("JWT_SECRET"), Redis: &redis.RedisConfig{ Host: viper.GetString("REDIS_HOST"), Port: mustInt(viper, "REDIS_PORT"), diff --git a/internal/services/api/router.go b/internal/services/api/router.go index cc5c23d8..c5f9291e 100644 --- a/internal/services/api/router.go +++ b/internal/services/api/router.go @@ -22,7 +22,7 @@ func NewRouter(cfg *config.Config, logger *otelzap.Logger, redisClient *redis.Cl c.Status(http.StatusOK) }) - tenantHandlers := tenant.NewHandlers(logger, redisClient) + tenantHandlers := tenant.NewHandlers(logger, redisClient, cfg.JWTSecret) r.PUT("/:tenantID", tenantHandlers.Upsert) r.GET("/:tenantID", tenantHandlers.Retrieve) diff --git a/internal/tenant/handlers.go b/internal/tenant/handlers.go index 1fe749b7..c26f6dde 100644 --- a/internal/tenant/handlers.go +++ b/internal/tenant/handlers.go @@ -11,14 +11,16 @@ import ( ) type TenantHandlers struct { - logger *otelzap.Logger - model *TenantModel + logger *otelzap.Logger + model *TenantModel + jwtSecret string } -func NewHandlers(logger *otelzap.Logger, redisClient *redis.Client) *TenantHandlers { +func NewHandlers(logger *otelzap.Logger, redisClient *redis.Client, jwtSecret string) *TenantHandlers { return &TenantHandlers{ - logger: logger, - model: NewTenantModel(redisClient), + logger: logger, + model: NewTenantModel(redisClient), + jwtSecret: jwtSecret, } } @@ -95,5 +97,25 @@ func (h *TenantHandlers) Delete(c *gin.Context) { } func (h *TenantHandlers) RetrievePortal(c *gin.Context) { - c.Status(http.StatusNotImplemented) + logger := h.logger.Ctx(c.Request.Context()) + tenantID := c.Param("tenantID") + tenant, err := h.model.Get(c.Request.Context(), tenantID) + if err != nil { + logger.Error("failed to get tenant", zap.Error(err)) + c.Status(http.StatusInternalServerError) + return + } + if tenant == nil { + c.Status(http.StatusNotFound) + return + } + jwtToken, err := JWT.New(h.jwtSecret, tenantID) + if err != nil { + logger.Error("failed to create jwt token", zap.Error(err)) + c.Status(http.StatusInternalServerError) + return + } + c.JSON(http.StatusOK, gin.H{ + "redirect_url": "https://example.com?token=" + jwtToken, + }) } diff --git a/internal/tenant/handlers_test.go b/internal/tenant/handlers_test.go index f72d074c..8825c192 100644 --- a/internal/tenant/handlers_test.go +++ b/internal/tenant/handlers_test.go @@ -31,7 +31,7 @@ func TestDestinationUpsertHandler(t *testing.T) { logger := testutil.CreateTestLogger(t) redisClient := testutil.CreateTestRedisClient(t) model := tenant.NewTenantModel(redisClient) - handlers := tenant.NewHandlers(logger, redisClient) + handlers := tenant.NewHandlers(logger, redisClient, "") router := setupRouter(handlers) t.Run("should create when there's no existing tenant", func(t *testing.T) { @@ -88,7 +88,7 @@ func TestTenantRetrieveHandler(t *testing.T) { logger := testutil.CreateTestLogger(t) redisClient := testutil.CreateTestRedisClient(t) model := tenant.NewTenantModel(redisClient) - handlers := tenant.NewHandlers(logger, redisClient) + handlers := tenant.NewHandlers(logger, redisClient, "") router := setupRouter(handlers) t.Run("should return 404 when there's no tenant", func(t *testing.T) { @@ -138,7 +138,7 @@ func TestTenantDeleteHandler(t *testing.T) { logger := testutil.CreateTestLogger(t) redisClient := testutil.CreateTestRedisClient(t) model := tenant.NewTenantModel(redisClient) - handlers := tenant.NewHandlers(logger, redisClient) + handlers := tenant.NewHandlers(logger, redisClient, "") router := setupRouter(handlers) t.Run("should return 404 when there's no tenant", func(t *testing.T) { diff --git a/internal/tenant/jwt.go b/internal/tenant/jwt.go new file mode 100644 index 00000000..8ae65719 --- /dev/null +++ b/internal/tenant/jwt.go @@ -0,0 +1,41 @@ +package tenant + +import ( + "time" + + jwt "github.com/golang-jwt/jwt/v5" +) + +const issuer = "eventkit" + +var signingMethod = jwt.SigningMethodHS256 + +type jsonwebtoken struct{} + +var JWT = jsonwebtoken{} + +func (_ jsonwebtoken) New(jwtKey string, tenantID string) (string, error) { + now := time.Now() + token := jwt.NewWithClaims(signingMethod, jwt.MapClaims{ + "iss": issuer, + "sub": tenantID, + "iat": now.Unix(), + "exp": now.Add(time.Hour).Unix(), + }) + return token.SignedString([]byte(jwtKey)) +} + +func (_ jsonwebtoken) Verify(jwtKey string, tokenString string, tenantID string) (bool, error) { + token, err := jwt.Parse( + tokenString, + func(token *jwt.Token) (interface{}, error) { + return []byte(jwtKey), nil + }, + jwt.WithIssuer(issuer), + jwt.WithSubject(tenantID), + ) + if err != nil { + return false, err + } + return token.Valid, nil +} diff --git a/internal/tenant/jwt_test.go b/internal/tenant/jwt_test.go new file mode 100644 index 00000000..e21d5171 --- /dev/null +++ b/internal/tenant/jwt_test.go @@ -0,0 +1,91 @@ +package tenant_test + +import ( + "testing" + "time" + + jwt "github.com/golang-jwt/jwt/v5" + "github.com/hookdeck/EventKit/internal/tenant" + "github.com/stretchr/testify/assert" +) + +func TestJWT(t *testing.T) { + t.Parallel() + + const issuer = "eventkit" + const jwtKey = "supersecret" + const tenantID = "tenantID" + var signingMethod = jwt.SigningMethodHS256 + + t.Run("should generate a new jwt token", func(t *testing.T) { + t.Parallel() + token, err := tenant.JWT.New(jwtKey, tenantID) + assert.Nil(t, err) + assert.NotEqual(t, "", token) + }) + + t.Run("should verify a valid jwt token", func(t *testing.T) { + t.Parallel() + token, err := tenant.JWT.New(jwtKey, tenantID) + if err != nil { + t.Fatal(err) + } + valid, err := tenant.JWT.Verify(jwtKey, token, tenantID) + assert.Nil(t, err) + assert.True(t, valid) + }) + + t.Run("should reject a token from a different issuer", func(t *testing.T) { + t.Parallel() + now := time.Now() + jwtToken := jwt.NewWithClaims(signingMethod, jwt.MapClaims{ + "iss": "not-eventkit", + "sub": tenantID, + "iat": now.Unix(), + "exp": now.Add(time.Hour).Unix(), + }) + token, err := jwtToken.SignedString([]byte(jwtKey)) + if err != nil { + t.Fatal(err) + } + valid, err := tenant.JWT.Verify(jwtKey, token, tenantID) + assert.ErrorContains(t, err, "token has invalid claims: token has invalid issuer") + assert.NotEqual(t, true, valid) + }) + + t.Run("should reject a token for a different tenant", func(t *testing.T) { + t.Parallel() + now := time.Now() + jwtToken := jwt.NewWithClaims(signingMethod, jwt.MapClaims{ + "iss": issuer, + "sub": "different_tenantID", + "iat": now.Unix(), + "exp": now.Add(time.Hour).Unix(), + }) + token, err := jwtToken.SignedString([]byte(jwtKey)) + if err != nil { + t.Fatal(err) + } + valid, err := tenant.JWT.Verify(jwtKey, token, tenantID) + assert.ErrorContains(t, err, "token has invalid claims: token has invalid subject") + assert.NotEqual(t, true, valid) + }) + + t.Run("should reject an expired token", func(t *testing.T) { + t.Parallel() + now := time.Now() + jwtToken := jwt.NewWithClaims(signingMethod, jwt.MapClaims{ + "iss": issuer, + "sub": tenantID, + "iat": now.Add(-2 * time.Hour).Unix(), + "exp": now.Add(-time.Hour).Unix(), + }) + token, err := jwtToken.SignedString([]byte(jwtKey)) + if err != nil { + t.Fatal(err) + } + valid, err := tenant.JWT.Verify(jwtKey, token, tenantID) + assert.ErrorContains(t, err, "token has invalid claims: token is expired") + assert.NotEqual(t, true, valid) + }) +} From 11b65a8f8cfbb76a35934b0811ec5bfeb7b1112c Mon Sep 17 00:00:00 2001 From: Alex Luong Date: Wed, 28 Aug 2024 21:26:37 +0700 Subject: [PATCH 17/21] test: Add JWT test case for malformed token --- internal/tenant/jwt_test.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/internal/tenant/jwt_test.go b/internal/tenant/jwt_test.go index e21d5171..36e87edf 100644 --- a/internal/tenant/jwt_test.go +++ b/internal/tenant/jwt_test.go @@ -35,6 +35,13 @@ func TestJWT(t *testing.T) { assert.True(t, valid) }) + t.Run("should reject an invalid token", func(t *testing.T) { + t.Parallel() + valid, err := tenant.JWT.Verify(jwtKey, "invalid_token", tenantID) + assert.ErrorContains(t, err, "token is malformed") + assert.NotEqual(t, true, valid) + }) + t.Run("should reject a token from a different issuer", func(t *testing.T) { t.Parallel() now := time.Now() From 10a681c578786cadbf2b0f8c55deddc49b8ffce6 Mon Sep 17 00:00:00 2001 From: Alex Luong Date: Wed, 28 Aug 2024 21:28:53 +0700 Subject: [PATCH 18/21] feat: Support tenant scope JWT auth --- internal/services/api/auth_middleware.go | 32 ++++++++++++++++++++++++ internal/services/api/router.go | 23 ++++++++++++----- 2 files changed, 49 insertions(+), 6 deletions(-) diff --git a/internal/services/api/auth_middleware.go b/internal/services/api/auth_middleware.go index c0af656c..f44ce23f 100644 --- a/internal/services/api/auth_middleware.go +++ b/internal/services/api/auth_middleware.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/gin-gonic/gin" + "github.com/hookdeck/EventKit/internal/tenant" ) func apiKeyAuthMiddleware(apiKey string) gin.HandlerFunc { @@ -33,6 +34,37 @@ func apiKeyAuthMiddleware(apiKey string) gin.HandlerFunc { } } +func apiKeyOrTenantJWTAuthMiddleware(apiKey string, jwtKey string) gin.HandlerFunc { + return func(c *gin.Context) { + authorizationToken, err := extractBearerToken(c.GetHeader("Authorization")) + if err != nil { + // TODO: Consider sending a more detailed error message. + // Currently we don't have clear specs on how to send back error message. + c.AbortWithStatus(http.StatusUnauthorized) + return + } + if authorizationToken == apiKey { + c.Next() + return + } + tenantID := c.Param("tenantID") + valid, err := tenant.JWT.Verify(jwtKey, authorizationToken, tenantID) + if err != nil { + // TODO: Consider sending a more detailed error message. + // Currently we don't have clear specs on how to send back error message. + c.AbortWithStatus(http.StatusUnauthorized) + return + } + if !valid { + // TODO: Consider sending a more detailed error message. + // Currently we don't have clear specs on how to send back error message. + c.AbortWithStatus(http.StatusUnauthorized) + return + } + c.Next() + } +} + func extractBearerToken(header string) (string, error) { if !strings.HasPrefix(header, "Bearer ") { return "", errors.New("invalid bearer token") diff --git a/internal/services/api/router.go b/internal/services/api/router.go index c5f9291e..34233171 100644 --- a/internal/services/api/router.go +++ b/internal/services/api/router.go @@ -15,7 +15,6 @@ import ( func NewRouter(cfg *config.Config, logger *otelzap.Logger, redisClient *redis.Client) http.Handler { r := gin.Default() r.Use(otelgin.Middleware(cfg.Hostname)) - r.Use(apiKeyAuthMiddleware(cfg.APIKey)) r.GET("/healthz", func(c *gin.Context) { logger.Ctx(c.Request.Context()).Info("health check") @@ -23,13 +22,25 @@ func NewRouter(cfg *config.Config, logger *otelzap.Logger, redisClient *redis.Cl }) tenantHandlers := tenant.NewHandlers(logger, redisClient, cfg.JWTSecret) + destinationHandlers := destination.NewHandlers(redisClient) - r.PUT("/:tenantID", tenantHandlers.Upsert) - r.GET("/:tenantID", tenantHandlers.Retrieve) - r.DELETE("/:tenantID", tenantHandlers.Delete) - r.GET("/:tenantID/portal", tenantHandlers.RetrievePortal) + // Admin router is a router group with the API key auth mechanism. + adminRouter := r.Group("/", apiKeyAuthMiddleware(cfg.APIKey)) - destinationHandlers := destination.NewHandlers(redisClient) + adminRouter.PUT("/:tenantID", tenantHandlers.Upsert) + adminRouter.GET("/:tenantID/portal", tenantHandlers.RetrievePortal) + + // Tenant router is a router group that accepts either + // - a tenant's JWT token OR + // - the preconfigured API key + // + // If the EventKit service deployment isn't configured with an API key, then + // it's assumed that the service runs in a secure environment + // and the JWT check is NOT necessary either. + tenantRouter := r.Group("/", apiKeyOrTenantJWTAuthMiddleware(cfg.APIKey, cfg.JWTSecret)) + + tenantRouter.GET("/:tenantID", tenantHandlers.Retrieve) + tenantRouter.DELETE("/:tenantID", tenantHandlers.Delete) r.GET("/destinations", destinationHandlers.List) r.POST("/destinations", destinationHandlers.Create) From c30355363886e800801ca86a80c05aefee0b3289 Mon Sep 17 00:00:00 2001 From: Alex Luong Date: Wed, 28 Aug 2024 22:07:51 +0700 Subject: [PATCH 19/21] chore: Rename param of jwt functions --- internal/tenant/jwt.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/internal/tenant/jwt.go b/internal/tenant/jwt.go index 8ae65719..209e18b9 100644 --- a/internal/tenant/jwt.go +++ b/internal/tenant/jwt.go @@ -14,7 +14,7 @@ type jsonwebtoken struct{} var JWT = jsonwebtoken{} -func (_ jsonwebtoken) New(jwtKey string, tenantID string) (string, error) { +func (_ jsonwebtoken) New(jwtSecret string, tenantID string) (string, error) { now := time.Now() token := jwt.NewWithClaims(signingMethod, jwt.MapClaims{ "iss": issuer, @@ -22,14 +22,14 @@ func (_ jsonwebtoken) New(jwtKey string, tenantID string) (string, error) { "iat": now.Unix(), "exp": now.Add(time.Hour).Unix(), }) - return token.SignedString([]byte(jwtKey)) + return token.SignedString([]byte(jwtSecret)) } -func (_ jsonwebtoken) Verify(jwtKey string, tokenString string, tenantID string) (bool, error) { +func (_ jsonwebtoken) Verify(jwtSecret string, tokenString string, tenantID string) (bool, error) { token, err := jwt.Parse( tokenString, func(token *jwt.Token) (interface{}, error) { - return []byte(jwtKey), nil + return []byte(jwtSecret), nil }, jwt.WithIssuer(issuer), jwt.WithSubject(tenantID), From d6517bb8309ae868c1591cb4ec5717b57c2a25c7 Mon Sep 17 00:00:00 2001 From: Alex Luong Date: Wed, 28 Aug 2024 22:08:25 +0700 Subject: [PATCH 20/21] test: Router & different auth mechanism --- internal/services/api/api.go | 12 +- internal/services/api/auth_middleware.go | 3 + internal/services/api/router.go | 19 +- internal/services/api/router_test.go | 271 +++++++++++++++++++++++ 4 files changed, 296 insertions(+), 9 deletions(-) create mode 100644 internal/services/api/router_test.go diff --git a/internal/services/api/api.go b/internal/services/api/api.go index 1f9e024e..86219116 100644 --- a/internal/services/api/api.go +++ b/internal/services/api/api.go @@ -9,7 +9,9 @@ import ( "time" "github.com/hookdeck/EventKit/internal/config" + "github.com/hookdeck/EventKit/internal/destination" "github.com/hookdeck/EventKit/internal/redis" + "github.com/hookdeck/EventKit/internal/tenant" "github.com/uptrace/opentelemetry-go-extra/otelzap" "go.uber.org/zap" ) @@ -27,7 +29,15 @@ func NewService(ctx context.Context, wg *sync.WaitGroup, cfg *config.Config, log return nil, err } - router := NewRouter(cfg, logger, redisClient) + router := NewRouter( + RouterConfig{ + Hostname: cfg.Hostname, + APIKey: cfg.APIKey, + JWTSecret: cfg.JWTSecret, + }, + tenant.NewHandlers(logger, redisClient, cfg.JWTSecret), + destination.NewHandlers(redisClient), + ) service := &APIService{} service.logger = logger diff --git a/internal/services/api/auth_middleware.go b/internal/services/api/auth_middleware.go index f44ce23f..96082b07 100644 --- a/internal/services/api/auth_middleware.go +++ b/internal/services/api/auth_middleware.go @@ -66,6 +66,9 @@ func apiKeyOrTenantJWTAuthMiddleware(apiKey string, jwtKey string) gin.HandlerFu } func extractBearerToken(header string) (string, error) { + if header == "" { + return "", nil + } if !strings.HasPrefix(header, "Bearer ") { return "", errors.New("invalid bearer token") } diff --git a/internal/services/api/router.go b/internal/services/api/router.go index 34233171..4921505d 100644 --- a/internal/services/api/router.go +++ b/internal/services/api/router.go @@ -4,26 +4,29 @@ import ( "net/http" "github.com/gin-gonic/gin" - "github.com/hookdeck/EventKit/internal/config" "github.com/hookdeck/EventKit/internal/destination" - "github.com/hookdeck/EventKit/internal/redis" "github.com/hookdeck/EventKit/internal/tenant" - "github.com/uptrace/opentelemetry-go-extra/otelzap" "go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin" ) -func NewRouter(cfg *config.Config, logger *otelzap.Logger, redisClient *redis.Client) http.Handler { +type RouterConfig struct { + Hostname string + APIKey string + JWTSecret string +} + +func NewRouter( + cfg RouterConfig, + tenantHandlers *tenant.TenantHandlers, + destinationHandlers *destination.DestinationHandlers, +) http.Handler { r := gin.Default() r.Use(otelgin.Middleware(cfg.Hostname)) r.GET("/healthz", func(c *gin.Context) { - logger.Ctx(c.Request.Context()).Info("health check") c.Status(http.StatusOK) }) - tenantHandlers := tenant.NewHandlers(logger, redisClient, cfg.JWTSecret) - destinationHandlers := destination.NewHandlers(redisClient) - // Admin router is a router group with the API key auth mechanism. adminRouter := r.Group("/", apiKeyAuthMiddleware(cfg.APIKey)) diff --git a/internal/services/api/router_test.go b/internal/services/api/router_test.go new file mode 100644 index 00000000..6fec6547 --- /dev/null +++ b/internal/services/api/router_test.go @@ -0,0 +1,271 @@ +package api_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/uuid" + "github.com/hookdeck/EventKit/internal/destination" + "github.com/hookdeck/EventKit/internal/services/api" + "github.com/hookdeck/EventKit/internal/tenant" + "github.com/hookdeck/EventKit/internal/util/testutil" + "github.com/stretchr/testify/assert" +) + +func TestRouterWithAPIKey(t *testing.T) { + t.Parallel() + + apiKey := "api_key" + jwtSecret := "jwt_secret" + + logger := testutil.CreateTestLogger(t) + redisClient := testutil.CreateTestRedisClient(t) + + router := api.NewRouter( + api.RouterConfig{ + Hostname: "", + APIKey: apiKey, + JWTSecret: jwtSecret, + }, + tenant.NewHandlers(logger, redisClient, jwtSecret), + destination.NewHandlers(redisClient), + ) + + tenantID := "tenantID" + validToken, err := tenant.JWT.New(jwtSecret, tenantID) + if err != nil { + t.Fatal(err) + } + + t.Run("healthcheck should work", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/healthz", nil) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("should block unauthenticated request to admin routes", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("PUT", "/"+uuid.New().String(), nil) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusUnauthorized, w.Code) + }) + + t.Run("should block tenant-auth request to admin routes", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("PUT", "/"+uuid.New().String(), nil) + req.Header.Set("Authorization", "Bearer "+validToken) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusUnauthorized, w.Code) + }) + + t.Run("should allow admin request to admin routes", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("PUT", "/"+uuid.New().String(), nil) + req.Header.Set("Authorization", "Bearer "+apiKey) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusCreated, w.Code) + }) + + t.Run("should block unauthenticated request to tenant routes", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/tenantID", nil) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusUnauthorized, w.Code) + }) + + t.Run("should allow admin request to tenant routes", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/tenantIDnotfound", nil) + req.Header.Set("Authorization", "Bearer "+apiKey) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusNotFound, w.Code) + }) + + t.Run("should allow admin request to tenant routes", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/tenantIDnotfound", nil) + req.Header.Set("Authorization", "Bearer "+apiKey) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusNotFound, w.Code) + }) + + t.Run("should allow tenant-auth request to admin routes", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/"+tenantID, nil) + req.Header.Set("Authorization", "Bearer "+validToken) + router.ServeHTTP(w, req) + + // A bit awkward that the tenant is not found, but the request is authenticated + // and the 404 response is handled by the handler which is what we're testing here (routing). + assert.Equal(t, http.StatusNotFound, w.Code) + }) + + t.Run("should block invalid tenant-auth request to admin routes", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/"+tenantID, nil) + req.Header.Set("Authorization", "Bearer invalid") + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusUnauthorized, w.Code) + }) +} + +func TestRouterWithoutAPIKey(t *testing.T) { + t.Parallel() + + apiKey := "" + jwtSecret := "jwt_secret" + + logger := testutil.CreateTestLogger(t) + redisClient := testutil.CreateTestRedisClient(t) + + router := api.NewRouter( + api.RouterConfig{ + Hostname: "", + APIKey: apiKey, + JWTSecret: jwtSecret, + }, + tenant.NewHandlers(logger, redisClient, jwtSecret), + destination.NewHandlers(redisClient), + ) + + tenantID := "tenantID" + validToken, err := tenant.JWT.New(jwtSecret, tenantID) + if err != nil { + t.Fatal(err) + } + + t.Run("healthcheck should work", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/healthz", nil) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("should allow unauthenticated request to admin routes", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("PUT", "/"+uuid.New().String(), nil) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusCreated, w.Code) + }) + + t.Run("should allow tenant-auth request to admin routes", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("PUT", "/"+uuid.New().String(), nil) + req.Header.Set("Authorization", "Bearer "+validToken) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusCreated, w.Code) + }) + + t.Run("should allow admin request to admin routes", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("PUT", "/"+uuid.New().String(), nil) + req.Header.Set("Authorization", "Bearer "+apiKey) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusCreated, w.Code) + }) + + t.Run("should allow unauthenticated request to tenant routes", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/tenantID", nil) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusNotFound, w.Code) + }) + + t.Run("should allow admin request to tenant routes", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/tenantIDnotfound", nil) + req.Header.Set("Authorization", "Bearer "+apiKey) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusNotFound, w.Code) + }) + + t.Run("should allow admin request to tenant routes", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/tenantIDnotfound", nil) + req.Header.Set("Authorization", "Bearer "+apiKey) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusNotFound, w.Code) + }) + + t.Run("should allow tenant-auth request to admin routes", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/"+tenantID, nil) + req.Header.Set("Authorization", "Bearer "+validToken) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusNotFound, w.Code) + }) + + t.Run("should block request with invalid bearer authorization header", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/"+tenantID, nil) + req.Header.Set("Authorization", "NotBearer "+validToken) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusUnauthorized, w.Code) + }) + + t.Run("should block request with bearer authorization header with invalid token", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/"+tenantID, nil) + req.Header.Set("Authorization", "Bearer invalid") + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusUnauthorized, w.Code) + }) +} From dbc433130017316e331d92820e49fe0d41c0301f Mon Sep 17 00:00:00 2001 From: Alex Luong Date: Wed, 28 Aug 2024 22:20:31 +0700 Subject: [PATCH 21/21] chore: Update .env.example --- .env.example | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.env.example b/.env.example index 9462409f..1966073b 100644 --- a/.env.example +++ b/.env.example @@ -1,4 +1,7 @@ PORT=4000 +API_PORT=4000 +API_KEY=apikey +JWT_SECRET=jwtsecret # REDIS_HOST= REDIS_PORT=6379