diff --git a/domain/repository/factory.go b/domain/repository/factory.go new file mode 100644 index 0000000..51ef3a9 --- /dev/null +++ b/domain/repository/factory.go @@ -0,0 +1,8 @@ +package repository + +type Factory interface { + Key() KeyRepository + Marketplace() MarketplaceRepository + AdminAudit() AdminAuditRepository + Reversal() ReversalRepository +} diff --git a/go.mod b/go.mod index 9e2dd21..4256af3 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.24.0 toolchain go1.24.7 require ( + github.com/go-chi/chi/v5 v5.2.4 github.com/go-chi/render v1.0.3 github.com/go-viper/mapstructure/v2 v2.4.0 github.com/google/go-cmp v0.7.0 diff --git a/go.sum b/go.sum index 1131cf1..f103270 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHk github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/go-chi/chi/v5 v5.2.4 h1:WtFKPHwlywe8Srng8j2BhOD9312j9cGUxG1SP4V2cR4= +github.com/go-chi/chi/v5 v5.2.4/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= github.com/go-chi/render v1.0.3 h1:AsXqd2a1/INaIfUSKq3G5uA8weYx20FOsM7uSoCyyt4= github.com/go-chi/render v1.0.3/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIofjao0= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= diff --git a/internal/testutil/factory.go b/internal/testutil/factory.go new file mode 100644 index 0000000..a2897a1 --- /dev/null +++ b/internal/testutil/factory.go @@ -0,0 +1,55 @@ +package testutil + +import ( + "testing" + + "reverse-watch/domain/repository" +) + +type factory struct { + key repository.KeyRepository + marketplace repository.MarketplaceRepository + adminAudit repository.AdminAuditRepository + reversal repository.ReversalRepository +} + +func NewTestFactory(t *testing.T) *factory { + t.Helper() + return &factory{} +} + +func (f *factory) Key() repository.KeyRepository { + return f.key +} + +func (f *factory) Marketplace() repository.MarketplaceRepository { + return f.marketplace +} + +func (f *factory) AdminAudit() repository.AdminAuditRepository { + return f.adminAudit +} + +func (f *factory) Reversal() repository.ReversalRepository { + return f.reversal +} + +func (f *factory) WithKey(key repository.KeyRepository) *factory { + f.key = key + return f +} + +func (f *factory) WithMarketplace(marketplace repository.MarketplaceRepository) *factory { + f.marketplace = marketplace + return f +} + +func (f *factory) WithAdminAudit(adminAudit repository.AdminAuditRepository) *factory { + f.adminAudit = adminAudit + return f +} + +func (f *factory) WithReversal(reversal repository.ReversalRepository) *factory { + f.reversal = reversal + return f +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..3a2e1dc --- /dev/null +++ b/main.go @@ -0,0 +1,64 @@ +package main + +import ( + "context" + "errors" + "fmt" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "reverse-watch/config" + "reverse-watch/domain/models" + "reverse-watch/logging" + "reverse-watch/server" +) + +func main() { + logging.Initialize() + cfg := config.Load() + models.InitSnowflakeGenerator(0, 0) + + logging.Log.Info("Starting Reverse Watch") + + done := make(chan os.Signal, 1) + signal.Notify(done, os.Interrupt, syscall.SIGTERM) + + srv, err := server.New(cfg) + if err != nil { + panic(err) + } + httpSrv := &http.Server{ + Addr: fmt.Sprintf("0.0.0.0:%s", cfg.HTTP.Port), + Handler: srv, + ReadHeaderTimeout: 5 * time.Second, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + IdleTimeout: 120 * time.Second, + } + + go func() { + logging.Log.Infof("Starting HTTP Server on %v", httpSrv.Addr) + if err := httpSrv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + panic(err) + } + }() + + <-done + + logging.Log.Info("Shutting down server connections gracefully") + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + if err := httpSrv.Shutdown(ctx); err != nil { + panic(err) + } + + // Close db connections + if err := srv.Close(); err != nil { + panic(err) + } +} diff --git a/middleware/factory.go b/middleware/factory.go new file mode 100644 index 0000000..a9731e9 --- /dev/null +++ b/middleware/factory.go @@ -0,0 +1,20 @@ +package middleware + +import ( + "context" + "net/http" + + "reverse-watch/domain/repository" +) + +const FactoryContextKey ContextKey = "factory" + +func FactoryMiddleware(factory repository.Factory) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), FactoryContextKey, factory) + next.ServeHTTP(w, r.WithContext(ctx)) + } + return http.HandlerFunc(fn) + } +} diff --git a/middleware/factory_test.go b/middleware/factory_test.go new file mode 100644 index 0000000..d7c5438 --- /dev/null +++ b/middleware/factory_test.go @@ -0,0 +1,39 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "reverse-watch/internal/testutil" +) + +func TestFactoryMiddleware(t *testing.T) { + var capturedRequest *http.Request + fn := func(w http.ResponseWriter, r *http.Request) { + capturedRequest = r + w.WriteHeader(http.StatusOK) + } + next := http.HandlerFunc(fn) + + factory := testutil.NewTestFactory(t) + factoryMiddleware := FactoryMiddleware(factory) + handler := factoryMiddleware(next) + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + handler.ServeHTTP(w, r) + + if w.Code != http.StatusOK { + t.Errorf("got status code %d, wanted %d", w.Code, http.StatusOK) + } + + if capturedRequest == nil { + t.Fatalf("captured request is nil") + } + + capturedFactory := capturedRequest.Context().Value(FactoryContextKey) + if capturedFactory == nil { + t.Fatalf("captured factory is nil") + } +} diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 71e496d..d7255e2 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -97,10 +97,10 @@ func TestAuthMiddleware(t *testing.T) { fn := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } - dummyHandler := http.HandlerFunc(fn) + next := http.HandlerFunc(fn) middlewareFunc := AuthMiddleware(keyRepo) - handler := middlewareFunc(dummyHandler) + handler := middlewareFunc(next) w := httptest.NewRecorder() r, err := tc.setup() diff --git a/middleware/permissions_test.go b/middleware/permissions_test.go index 6e61a90..cdc7a81 100644 --- a/middleware/permissions_test.go +++ b/middleware/permissions_test.go @@ -85,10 +85,10 @@ func TestRequirePermissions(t *testing.T) { fn := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } - dummyHandler := http.HandlerFunc(fn) + next := http.HandlerFunc(fn) middlewareFunc := RequirePermissions(tc.permissions...) - handler := middlewareFunc(dummyHandler) + handler := middlewareFunc(next) w := httptest.NewRecorder() r := tc.setup() diff --git a/repository/factory/factory.go b/repository/factory/factory.go new file mode 100644 index 0000000..1ad0419 --- /dev/null +++ b/repository/factory/factory.go @@ -0,0 +1,31 @@ +package factory + +import "reverse-watch/domain/repository" + +type factory struct { + private repository.PrivateRepository + public repository.PublicRepository +} + +func NewFactory(private repository.PrivateRepository, public repository.PublicRepository) repository.Factory { + return &factory{ + private: private, + public: public, + } +} + +func (f *factory) Key() repository.KeyRepository { + return f.private.Key() +} + +func (f *factory) Marketplace() repository.MarketplaceRepository { + return f.private.Marketplace() +} + +func (f *factory) AdminAudit() repository.AdminAuditRepository { + return f.private.AdminAudit() +} + +func (f *factory) Reversal() repository.ReversalRepository { + return f.public.Reversal() +} diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..f1a0027 --- /dev/null +++ b/server/server.go @@ -0,0 +1,74 @@ +package server + +import ( + "errors" + "fmt" + "net/http" + + "reverse-watch/config" + "reverse-watch/domain/repository" + "reverse-watch/logging" + rwmiddleware "reverse-watch/middleware" + "reverse-watch/repository/factory" + "reverse-watch/repository/private" + "reverse-watch/repository/public" + "reverse-watch/secret" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" +) + +type Server struct { + r chi.Router + + privateRepo repository.PrivateRepository + publicRepo repository.PublicRepository +} + +func New(cfg config.Config) (*Server, error) { + keygen := secret.NewKeyGenerator(cfg.Environment) + privateRepo, err := private.NewPrivateRepository(cfg, keygen) + if err != nil { + logging.Log.Errorf("failed to create private repository: %v", err) + return nil, fmt.Errorf("failed to create private repository: %v", err) + } + publicRepo, err := public.NewPublicRepository(cfg) + if err != nil { + privateRepo.Close() + logging.Log.Errorf("failed to create public repository: %v", err) + return nil, fmt.Errorf("failed to create public repository: %v", err) + } + + r := chi.NewRouter() + + r.Use(middleware.Recoverer) + r.Use(middleware.RequestID) + r.Use(middleware.RealIP) + r.Use(middleware.Logger) + + f := factory.NewFactory(privateRepo, publicRepo) + r.Use(rwmiddleware.FactoryMiddleware(f)) + + // TODO(zach): Define routes + + return &Server{ + r: r, + privateRepo: privateRepo, + publicRepo: publicRepo, + }, nil +} + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.r.ServeHTTP(w, r) +} + +func (s *Server) Close() error { + var errs []error + if err := s.privateRepo.Close(); err != nil { + errs = append(errs, err) + } + if err := s.publicRepo.Close(); err != nil { + errs = append(errs, err) + } + return errors.Join(errs...) +}