diff --git a/.gitignore b/.gitignore index 69dcb52..61ae2e1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ # Added by goreleaser init: dist/ +cache/ diff --git a/.golangci.yml b/.golangci.yml index d6392d5..017f601 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -54,7 +54,7 @@ linters: - bodyclose # checks whether HTTP response body is closed successfully - canonicalheader # checks whether net/http.Header uses canonical header - copyloopvar # detects places where loop variables are copied (Go 1.22+) - - cyclop # checks function and package cyclomatic complexity + # - cyclop # checks function and package cyclomatic complexity - depguard # checks if package imports are in a list of acceptable packages - dupl # tool for code clone detection - durationcheck # checks for two durations multiplied together diff --git a/cmd/sfptcd/main.go b/cmd/sfptcd/main.go index 9c44ee7..a75e0b9 100644 --- a/cmd/sfptcd/main.go +++ b/cmd/sfptcd/main.go @@ -2,21 +2,43 @@ package main import ( "context" + "log/slog" + "net/http" + "os" + "time" "github.com/alecthomas/kong" + "github.com/block/sfptc/internal/config" "github.com/block/sfptc/internal/logging" ) var cli struct { - logging.Config `prefix:"log-"` + Config *os.File `hcl:"-" help:"Configuration file path." placeholder:"PATH" required:""` + Bind string `hcl:"bind" default:"127.0.0.1:8080" help:"Bind address for the server."` + LoggingConfig logging.Config `embed:"" prefix:"log-"` } func main() { - kong.Parse(&cli) + kctx := kong.Parse(&cli) ctx := context.Background() - logger, ctx := logging.Configure(ctx, cli.Config) + logger, ctx := logging.Configure(ctx, cli.LoggingConfig) - logger.InfoContext(ctx, "Starting sfptcd") + mux := http.NewServeMux() + + err := config.Load(ctx, cli.Config, mux) + kctx.FatalIfErrorf(err) + + logger.InfoContext(ctx, "Starting sfptcd", slog.String("bind", cli.Bind)) + + server := &http.Server{ + Addr: cli.Bind, + Handler: mux, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + ReadHeaderTimeout: 10 * time.Second, + } + err = server.ListenAndServe() + kctx.FatalIfErrorf(err) } diff --git a/go.mod b/go.mod index d41febc..4635eb4 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/block/sfptc go 1.25.5 require ( - github.com/alecthomas/hcl/v2 v2.3.0 + github.com/alecthomas/hcl/v2 v2.3.1 github.com/alecthomas/kong v1.13.0 github.com/lmittmann/tint v1.1.2 ) @@ -15,7 +15,7 @@ require ( require ( github.com/alecthomas/assert/v2 v2.11.0 - github.com/alecthomas/errors v0.8.3 + github.com/alecthomas/errors v0.9.1 github.com/alecthomas/participle/v2 v2.1.4 // indirect github.com/alecthomas/repr v0.5.2 // indirect github.com/pkg/xattr v0.4.12 diff --git a/go.sum b/go.sum index 9b8f497..c545120 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,9 @@ github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0= github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k= -github.com/alecthomas/errors v0.8.3 h1:IPyQj2fU3GGsl6C/r4OPmYgqgNSDLWJLE/ln2fLjwas= -github.com/alecthomas/errors v0.8.3/go.mod h1:l8mjMEHMGUdIWPMNtvDyRYPVS1fQFXHFXc/iVCCLGkI= -github.com/alecthomas/hcl/v2 v2.3.0 h1:voBoBfb69MBRFkJ5NyMN/cSFfevVZKJIoxwfuJ1j2gU= -github.com/alecthomas/hcl/v2 v2.3.0/go.mod h1:4UUp66q8ony5j8tm2bANErujUpZ3GgHBLgaKxTUQlQI= +github.com/alecthomas/errors v0.9.1 h1:JNXtU30rtMNARCkW41OTZ4yL6Lyocq20xIJgIw2raqI= +github.com/alecthomas/errors v0.9.1/go.mod h1:l8mjMEHMGUdIWPMNtvDyRYPVS1fQFXHFXc/iVCCLGkI= +github.com/alecthomas/hcl/v2 v2.3.1 h1:Nkj0svGJawz920nQyWUhD2PYmD47p7BB9vc2e3kft1o= +github.com/alecthomas/hcl/v2 v2.3.1/go.mod h1:4UUp66q8ony5j8tm2bANErujUpZ3GgHBLgaKxTUQlQI= github.com/alecthomas/kong v1.13.0 h1:5e/7XC3ugvhP1DQBmTS+WuHtCbcv44hsohMgcvVxSrA= github.com/alecthomas/kong v1.13.0/go.mod h1:wrlbXem1CWqUV5Vbmss5ISYhsVPkBb1Yo7YKJghju2I= github.com/alecthomas/participle/v2 v2.1.4 h1:W/H79S8Sat/krZ3el6sQMvMaahJ+XcM9WSI2naI7w2U= diff --git a/internal/cache/api.go b/internal/cache/api.go index 170e016..eb9dad3 100644 --- a/internal/cache/api.go +++ b/internal/cache/api.go @@ -12,20 +12,33 @@ import ( "github.com/alecthomas/hcl/v2" ) -var registry = map[string]func(config *hcl.Block) (Cache, error){} +// ErrNotFound is returned when a cache backend is not found. +var ErrNotFound = errors.New("cache backend not found") + +var registry = map[string]func(ctx context.Context, config *hcl.Block) (Cache, error){} // Factory is a function that creates a new cache instance from the given hcl-tagged configuration struct. type Factory[Config any, C Cache] func(ctx context.Context, config Config) (C, error) // Register a cache factory function. func Register[Config any, C Cache](id string, factory Factory[Config, C]) { - registry[id] = func(config *hcl.Block) (Cache, error) { + registry[id] = func(ctx context.Context, config *hcl.Block) (Cache, error) { var cfg Config if err := hcl.UnmarshalBlock(config, &cfg); err != nil { return nil, errors.WithStack(err) } - return factory(context.Background(), cfg) + return factory(ctx, cfg) + } +} + +// Create a new cache instance from the given name and configuration. +// +// Will return "ErrNotFound" if the cache backend is not found. +func Create(ctx context.Context, name string, config *hcl.Block) (Cache, error) { + if factory, ok := registry[name]; ok { + return errors.WithStack2(factory(ctx, config)) } + return nil, errors.Errorf("%s: %w", name, ErrNotFound) } // Key represents a unique identifier for a cached object. @@ -59,6 +72,8 @@ func (k *Key) MarshalText() ([]byte, error) { // A Cache knows how to retrieve, create and delete objects from a cache. type Cache interface { + // String describes the Cache implementation. + String() string // Open an existing file in the cache. // // Expired files SHOULD not be returned. diff --git a/internal/cache/disk.go b/internal/cache/disk.go index 21df964..45cf3e5 100644 --- a/internal/cache/disk.go +++ b/internal/cache/disk.go @@ -112,6 +112,8 @@ func NewDisk(ctx context.Context, config DiskConfig) (*Disk, error) { return disk, nil } +func (d *Disk) String() string { return "disk:" + d.config.Root } + func (d *Disk) Close() error { d.stop() return nil diff --git a/internal/cache/memory.go b/internal/cache/memory.go index 3ed9844..dcb60c5 100644 --- a/internal/cache/memory.go +++ b/internal/cache/memory.go @@ -3,6 +3,7 @@ package cache import ( "bytes" "context" + "fmt" "io" "os" "sync" @@ -12,10 +13,10 @@ import ( ) func init() { - Register("memory", NewMemoryCache) + Register("memory", NewMemory) } -type MemoryCacheConfig struct { +type MemoryConfig struct { LimitMB int `hcl:"limit-mb,optional" help:"Maximum size of the disk cache in megabytes (defaults to 1GB)." default:"1024"` MaxTTL time.Duration `hcl:"max-ttl,optional" help:"Maximum time-to-live for entries in the disk cache (defaults to 1 hour)." default:"1h"` } @@ -25,21 +26,23 @@ type memoryEntry struct { expiresAt time.Time } -type memoryCache struct { - config MemoryCacheConfig +type Memory struct { + config MemoryConfig mu sync.RWMutex entries map[Key]*memoryEntry currentSize int64 } -func NewMemoryCache(_ context.Context, config MemoryCacheConfig) (Cache, error) { - return &memoryCache{ +func NewMemory(_ context.Context, config MemoryConfig) (*Memory, error) { + return &Memory{ config: config, entries: make(map[Key]*memoryEntry), }, nil } -func (m *memoryCache) Open(_ context.Context, key Key) (io.ReadCloser, error) { +func (m *Memory) String() string { return fmt.Sprintf("memory:%dMB", m.config.LimitMB) } + +func (m *Memory) Open(_ context.Context, key Key) (io.ReadCloser, error) { m.mu.RLock() defer m.mu.RUnlock() @@ -55,7 +58,7 @@ func (m *memoryCache) Open(_ context.Context, key Key) (io.ReadCloser, error) { return io.NopCloser(bytes.NewReader(entry.data)), nil } -func (m *memoryCache) Create(_ context.Context, key Key, ttl time.Duration) (io.WriteCloser, error) { +func (m *Memory) Create(_ context.Context, key Key, ttl time.Duration) (io.WriteCloser, error) { if ttl == 0 { ttl = m.config.MaxTTL } @@ -70,7 +73,7 @@ func (m *memoryCache) Create(_ context.Context, key Key, ttl time.Duration) (io. return writer, nil } -func (m *memoryCache) Delete(_ context.Context, key Key) error { +func (m *Memory) Delete(_ context.Context, key Key) error { m.mu.Lock() defer m.mu.Unlock() @@ -83,7 +86,7 @@ func (m *memoryCache) Delete(_ context.Context, key Key) error { return nil } -func (m *memoryCache) Close() error { +func (m *Memory) Close() error { m.mu.Lock() defer m.mu.Unlock() @@ -92,7 +95,7 @@ func (m *memoryCache) Close() error { } type memoryWriter struct { - cache *memoryCache + cache *Memory key Key buf *bytes.Buffer expiresAt time.Time @@ -142,7 +145,7 @@ func (w *memoryWriter) Close() error { return nil } -func (m *memoryCache) evictOldest(neededSpace int64) { +func (m *Memory) evictOldest(neededSpace int64) { type entryInfo struct { key Key size int64 diff --git a/internal/cache/memory_test.go b/internal/cache/memory_test.go index 730466a..953e62c 100644 --- a/internal/cache/memory_test.go +++ b/internal/cache/memory_test.go @@ -13,7 +13,7 @@ import ( func TestMemoryCache(t *testing.T) { cachetest.Suite(t, func(t *testing.T) cache.Cache { ctx := t.Context() - c, err := cache.NewMemoryCache(ctx, cache.MemoryCacheConfig{MaxTTL: 100 * time.Millisecond}) + c, err := cache.NewMemory(ctx, cache.MemoryConfig{MaxTTL: 100 * time.Millisecond}) assert.NoError(t, err) return c }) diff --git a/internal/cache/remote/client.go b/internal/cache/remote/client.go index 139cd8e..220effe 100644 --- a/internal/cache/remote/client.go +++ b/internal/cache/remote/client.go @@ -29,6 +29,8 @@ func NewClient(baseURL string) *Client { } } +func (c *Client) String() string { return "remote:" + c.baseURL } + // Open retrieves an object from the remote cache. func (c *Client) Open(ctx context.Context, key cache.Key) (io.ReadCloser, error) { url := fmt.Sprintf("%s/%s", c.baseURL, key.String()) diff --git a/internal/cache/remote/client_test.go b/internal/cache/remote/client_test.go index 8a05924..871c4b4 100644 --- a/internal/cache/remote/client_test.go +++ b/internal/cache/remote/client_test.go @@ -12,19 +12,21 @@ import ( "github.com/block/sfptc/internal/cache/cachetest" "github.com/block/sfptc/internal/cache/remote" "github.com/block/sfptc/internal/logging" + "github.com/block/sfptc/internal/strategy" ) func TestRemoteClient(t *testing.T) { cachetest.Suite(t, func(t *testing.T) cache.Cache { ctx := t.Context() _, ctx = logging.Configure(ctx, logging.Config{Level: slog.LevelError}) - memCache, err := cache.NewMemoryCache(ctx, cache.MemoryCacheConfig{ + memCache, err := cache.NewMemory(ctx, cache.MemoryConfig{ MaxTTL: 100 * time.Millisecond, }) assert.NoError(t, err) t.Cleanup(func() { memCache.Close() }) - server := remote.NewServer(ctx, memCache) + server, err := strategy.NewDefault(ctx, strategy.DefaultConfig{}, memCache) + assert.NoError(t, err) ts := httptest.NewServer(server) t.Cleanup(ts.Close) diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..00b1392 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,74 @@ +// Package config loads HCL configuration and uses that to construct the cache backend, and proxy strategies. +package config + +import ( + "context" + "io" + "net/http" + "strings" + + "github.com/alecthomas/errors" + "github.com/alecthomas/hcl/v2" + + "github.com/block/sfptc/internal/cache" + "github.com/block/sfptc/internal/logging" + "github.com/block/sfptc/internal/strategy" +) + +// Load HCL configuration and uses that to construct the cache backend, and proxy strategies. +func Load(ctx context.Context, r io.Reader, mux *http.ServeMux) error { + logger := logging.FromContext(ctx) + ast, err := hcl.Parse(r) + if err != nil { + return errors.WithStack(err) + } + + strategyCandidates := []*hcl.Block{ + // Always enable the default strategy + {Name: "default", Labels: []string{"/api/v1/"}}, + } + + // First pass, instantiate caches + var caches []cache.Cache + for _, node := range ast.Entries { + switch node := node.(type) { + case *hcl.Block: + c, err := cache.Create(ctx, node.Name, node) + if errors.Is(err, cache.ErrNotFound) { + strategyCandidates = append(strategyCandidates, node) + continue + } else if err != nil { + return errors.Errorf("%s: %w", node.Pos, err) + } + caches = append(caches, c) + + case *hcl.Attribute: + return errors.Errorf("%s: attributes are not allowed", node.Pos) + } + } + if len(caches) != 1 { + return errors.Errorf("%s: expected exactly one cache backend, got %d", ast.Pos, len(caches)) + } + + cache := caches[0] + + logger.DebugContext(ctx, "Cache backend", "cache", cache) + + // Second pass, instantiate strategies and bind them to the mux. + for _, block := range strategyCandidates { + if len(block.Labels) != 1 { + return errors.Errorf("%s: block must have exactly one label defining the server mount point", block.Pos) + } + pattern := block.Labels[0] + block.Labels = nil + s, err := strategy.Create(ctx, block.Name, block, cache) + if err != nil { + return errors.Errorf("%s: %w", block.Pos, err) + } + + logger.DebugContext(ctx, "Adding strategy", "strategy", s, "pattern", pattern) + + mux.Handle(pattern, http.StripPrefix(strings.TrimSuffix(pattern, "/"), s)) + } + return nil +} diff --git a/internal/server/server.go b/internal/server/server.go deleted file mode 100644 index 43da9f6..0000000 --- a/internal/server/server.go +++ /dev/null @@ -1,40 +0,0 @@ -// Package server implements the HTTP server for the caching proxy. -package server - -import ( - "context" - "log/slog" - "net/http" - - "github.com/block/sfptc/internal/cache" - "github.com/block/sfptc/internal/cache/remote" - "github.com/block/sfptc/internal/logging" -) - -type Option func(*Server) - -type Server struct { - logger *slog.Logger - cache cache.Cache - mux *http.ServeMux - server *remote.Server -} - -var _ http.Handler = (*Server)(nil) - -func New(ctx context.Context, cache cache.Cache, options ...Option) *Server { - s := &Server{ - logger: logging.FromContext(ctx), - cache: cache, - mux: http.NewServeMux(), - server: remote.NewServer(ctx, cache), - } - for _, option := range options { - option(s) - } - return s -} - -func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - s.mux.ServeHTTP(w, r) -} diff --git a/internal/strategy/api.go b/internal/strategy/api.go index 04c2893..870e341 100644 --- a/internal/strategy/api.go +++ b/internal/strategy/api.go @@ -7,23 +7,39 @@ import ( "github.com/alecthomas/errors" "github.com/alecthomas/hcl/v2" + + "github.com/block/sfptc/internal/cache" ) -var registry = map[string]func(config *hcl.Block) (Strategy, error){} +// ErrNotFound is returned when a strategy is not found. +var ErrNotFound = errors.New("strategy not found") + +var registry = map[string]func(ctx context.Context, config *hcl.Block, cache cache.Cache) (Strategy, error){} -type Factory[Config any] func(ctx context.Context, config Config) (Strategy, error) +type Factory[Config any, S Strategy] func(ctx context.Context, config Config, cache cache.Cache) (S, error) -// Register a new caching strategy. -func Register[Config any](id string, factory Factory[Config]) { - registry[id] = func(config *hcl.Block) (Strategy, error) { +// Register a new proxy strategy. +func Register[Config any, S Strategy](id string, factory Factory[Config, S]) { + registry[id] = func(ctx context.Context, config *hcl.Block, cache cache.Cache) (Strategy, error) { var cfg Config - if err := hcl.UnmarshalBlock(config, &cfg); err != nil { + if err := hcl.UnmarshalBlock(config, &cfg, hcl.AllowExtra(false)); err != nil { return nil, errors.WithStack(err) } - return factory(context.Background(), cfg) + return factory(ctx, cfg, cache) + } +} + +// Create a new proxy strategy. +// +// Will return "ErrNotFound" if the strategy is not found. +func Create(ctx context.Context, name string, config *hcl.Block, cache cache.Cache) (Strategy, error) { + if factory, ok := registry[name]; ok { + return errors.WithStack2(factory(ctx, config, cache)) } + return nil, errors.Errorf("%s: %w", name, ErrNotFound) } type Strategy interface { - Register(mux *http.ServeMux) + String() string + http.Handler } diff --git a/internal/cache/remote/server.go b/internal/strategy/default.go similarity index 50% rename from internal/cache/remote/server.go rename to internal/strategy/default.go index b02f9ce..e1db9d0 100644 --- a/internal/cache/remote/server.go +++ b/internal/strategy/default.go @@ -1,5 +1,4 @@ -// Package remote provides the server and client for a remote [cache.Cache] implementation. -package remote +package strategy import ( "context" @@ -14,17 +13,25 @@ import ( "github.com/block/sfptc/internal/logging" ) -// Server side implementation of the cache protocol. -type Server struct { +func init() { + Register("default", NewDefault) +} + +type DefaultConfig struct{} + +var _ Strategy = (*Default)(nil) + +// The Default strategy represents v1 of the proxy API. +type Default struct { cache cache.Cache logger *slog.Logger mux *http.ServeMux } -var _ http.Handler = (*Server)(nil) +var _ http.Handler = (*Default)(nil) -func NewServer(ctx context.Context, cache cache.Cache) *Server { - s := &Server{ +func NewDefault(ctx context.Context, _ DefaultConfig, cache cache.Cache) (*Default, error) { + s := &Default{ logger: logging.FromContext(ctx), cache: cache, mux: http.NewServeMux(), @@ -32,43 +39,45 @@ func NewServer(ctx context.Context, cache cache.Cache) *Server { s.mux.Handle("GET /{key}", http.HandlerFunc(s.getObject)) s.mux.Handle("POST /{key}", http.HandlerFunc(s.putObject)) s.mux.Handle("DELETE /{key}", http.HandlerFunc(s.deleteObject)) - return s + return s, nil } -func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - s.mux.ServeHTTP(w, r) +func (d *Default) String() string { return "default" } + +func (d *Default) ServeHTTP(w http.ResponseWriter, r *http.Request) { + d.mux.ServeHTTP(w, r) } -func (s *Server) getObject(w http.ResponseWriter, r *http.Request) { +func (d *Default) getObject(w http.ResponseWriter, r *http.Request) { key, err := cache.ParseKey(r.PathValue("key")) if err != nil { - s.httpError(w, http.StatusBadRequest, err, "Invalid key") + d.httpError(w, http.StatusBadRequest, err, "Invalid key") return } - cr, err := s.cache.Open(r.Context(), key) + cr, err := d.cache.Open(r.Context(), key) if err != nil { if errors.Is(err, os.ErrNotExist) { - s.httpError(w, http.StatusNotFound, err, "Cache object not found", slog.String("key", key.String())) + d.httpError(w, http.StatusNotFound, err, "Cache object not found", slog.String("key", key.String())) return } - s.httpError(w, http.StatusInternalServerError, err, "Failed to open cache object", slog.String("key", key.String())) + d.httpError(w, http.StatusInternalServerError, err, "Failed to open cache object", slog.String("key", key.String())) return } _, err = io.Copy(w, cr) if err != nil { - s.logger.Error("Failed to copy cache object to response", slog.String("error", err.Error()), slog.String("key", key.String())) + d.logger.Error("Failed to copy cache object to response", slog.String("error", err.Error()), slog.String("key", key.String())) } if cerr := cr.Close(); cerr != nil { - s.logger.Error("Failed to close cache reader", slog.String("error", cerr.Error()), slog.String("key", key.String())) + d.logger.Error("Failed to close cache reader", slog.String("error", cerr.Error()), slog.String("key", key.String())) } } -func (s *Server) putObject(w http.ResponseWriter, r *http.Request) { +func (d *Default) putObject(w http.ResponseWriter, r *http.Request) { key, err := cache.ParseKey(r.PathValue("key")) if err != nil { - s.httpError(w, http.StatusBadRequest, err, "Invalid key") + d.httpError(w, http.StatusBadRequest, err, "Invalid key") return } @@ -77,48 +86,48 @@ func (s *Server) putObject(w http.ResponseWriter, r *http.Request) { if ttlh != "" { ttl, err = time.ParseDuration(ttlh) if err != nil { - s.httpError(w, http.StatusBadRequest, err, "Invalid Time-To-Live header format, must be in Go duration format eg. 1h") + d.httpError(w, http.StatusBadRequest, err, "Invalid Time-To-Live header format, must be in Go duration format eg. 1h") return } } - cw, err := s.cache.Create(r.Context(), key, ttl) + cw, err := d.cache.Create(r.Context(), key, ttl) if err != nil { - s.httpError(w, http.StatusInternalServerError, err, "Failed to create cache writer", slog.String("key", key.String())) + d.httpError(w, http.StatusInternalServerError, err, "Failed to create cache writer", slog.String("key", key.String())) return } if _, err := io.Copy(cw, r.Body); err != nil { - s.httpError(w, http.StatusInternalServerError, err, "Failed to copy request body to cache writer") + d.httpError(w, http.StatusInternalServerError, err, "Failed to copy request body to cache writer") return } if err := cw.Close(); err != nil { - s.httpError(w, http.StatusInternalServerError, err, "Failed to close cache writer") + d.httpError(w, http.StatusInternalServerError, err, "Failed to close cache writer") return } } -func (s *Server) deleteObject(w http.ResponseWriter, r *http.Request) { +func (d *Default) deleteObject(w http.ResponseWriter, r *http.Request) { key, err := cache.ParseKey(r.PathValue("key")) if err != nil { - s.httpError(w, http.StatusBadRequest, err, "Invalid key") + d.httpError(w, http.StatusBadRequest, err, "Invalid key") return } - err = s.cache.Delete(r.Context(), key) + err = d.cache.Delete(r.Context(), key) if err != nil { if errors.Is(err, os.ErrNotExist) { - s.httpError(w, http.StatusNotFound, err, "Cache object not found", slog.String("key", key.String())) + d.httpError(w, http.StatusNotFound, err, "Cache object not found", slog.String("key", key.String())) return } - s.httpError(w, http.StatusInternalServerError, err, "Failed to delete cache object", slog.String("key", key.String())) + d.httpError(w, http.StatusInternalServerError, err, "Failed to delete cache object", slog.String("key", key.String())) return } } -func (s *Server) httpError(w http.ResponseWriter, code int, err error, message string, args ...any) { +func (d *Default) httpError(w http.ResponseWriter, code int, err error, message string, args ...any) { args = append(args, slog.String("error", err.Error())) - s.logger.Error(message, args...) + d.logger.Error(message, args...) http.Error(w, message, code) } diff --git a/internal/strategy/host.go b/internal/strategy/host.go new file mode 100644 index 0000000..5188de9 --- /dev/null +++ b/internal/strategy/host.go @@ -0,0 +1,125 @@ +package strategy + +import ( + "context" + "fmt" + "io" + "log/slog" + "net/http" + "net/url" + "os" + + "github.com/alecthomas/errors" + + "github.com/block/sfptc/internal/cache" + "github.com/block/sfptc/internal/logging" +) + +func init() { + Register("host", NewHost) +} + +// HostConfig represents the configuration for the Host strategy. +// +// In HCL it looks something like this: +// +// host "/github/" { +// target = "https://github.com/" +// } +// +// In this example, the strategy will be mounted under "/github". +type HostConfig struct { + Target string `hcl:"target" help:"The target URL to proxy requests to."` +} + +// The Host [Strategy] forwards all GET requests to the specified host, caching the response payloads. +type Host struct { + target *url.URL + cache cache.Cache + client *http.Client + logger *slog.Logger +} + +var _ Strategy = (*Host)(nil) + +func NewHost(ctx context.Context, config HostConfig, cache cache.Cache) (*Host, error) { + u, err := url.Parse(config.Target) + if err != nil { + return nil, fmt.Errorf("invalid target URL: %w", err) + } + return &Host{ + target: u, + cache: cache, + client: &http.Client{}, + logger: logging.FromContext(ctx), + }, nil +} + +func (d *Host) String() string { return "host:" + d.target.Host + d.target.Path } + +func (d *Host) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + targetURL := *d.target + targetURL.Path = r.URL.Path + targetURL.RawQuery = r.URL.RawQuery + fullURL := targetURL.String() + + key := cache.NewKey(fullURL) + + cr, err := d.cache.Open(r.Context(), key) + if err == nil { + defer cr.Close() + if _, err := io.Copy(w, cr); err != nil { + d.logger.Error("Failed to copy cached response", slog.String("error", err.Error()), slog.String("url", fullURL)) + } + return + } + + if !errors.Is(err, os.ErrNotExist) { + d.logger.Error("Failed to open cache", slog.String("error", err.Error()), slog.String("url", fullURL)) + } + + req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, fullURL, nil) + if err != nil { + d.httpError(w, http.StatusInternalServerError, err, "Failed to create request", slog.String("url", fullURL)) + return + } + + resp, err := d.client.Do(req) + if err != nil { + d.httpError(w, http.StatusBadGateway, err, "Failed to fetch from target", slog.String("url", fullURL)) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + w.WriteHeader(resp.StatusCode) + if _, err := io.Copy(w, resp.Body); err != nil { + d.logger.Error("Failed to copy error response", slog.String("error", err.Error()), slog.String("url", fullURL)) + } + return + } + + cw, err := d.cache.Create(r.Context(), key, 0) + if err != nil { + d.httpError(w, http.StatusInternalServerError, err, "Failed to create cache entry", slog.String("url", fullURL)) + return + } + + mw := io.MultiWriter(w, cw) + _, copyErr := io.Copy(mw, resp.Body) + closeErr := cw.Close() + if err := errors.Join(copyErr, closeErr); err != nil { + d.logger.Error("Failed to write to cache", slog.String("error", err.Error()), slog.String("url", fullURL)) + } +} + +func (d *Host) httpError(w http.ResponseWriter, code int, err error, message string, args ...any) { + args = append(args, slog.String("error", err.Error())) + d.logger.Error(message, args...) + http.Error(w, message, code) +} diff --git a/internal/strategy/host_test.go b/internal/strategy/host_test.go new file mode 100644 index 0000000..be7abe7 --- /dev/null +++ b/internal/strategy/host_test.go @@ -0,0 +1,99 @@ +package strategy_test + +import ( + "context" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/alecthomas/assert/v2" + + "github.com/block/sfptc/internal/cache" + "github.com/block/sfptc/internal/logging" + "github.com/block/sfptc/internal/strategy" +) + +func TestHostCaching(t *testing.T) { + callCount := 0 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + callCount++ + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("response")) + })) + defer backend.Close() + + _, ctx := logging.Configure(context.Background(), logging.Config{Level: slog.LevelError}) + memCache, err := cache.NewMemory(ctx, cache.MemoryConfig{MaxTTL: time.Hour}) + assert.NoError(t, err) + defer memCache.Close() + + host, err := strategy.NewHost(ctx, strategy.HostConfig{Target: backend.URL}, memCache) + assert.NoError(t, err) + + req1 := httptest.NewRequest(http.MethodGet, "/test", nil) + w1 := httptest.NewRecorder() + host.ServeHTTP(w1, req1) + + assert.Equal(t, http.StatusOK, w1.Code) + assert.Equal(t, "response", w1.Body.String()) + assert.Equal(t, 1, callCount) + + req2 := httptest.NewRequest(http.MethodGet, "/test", nil) + w2 := httptest.NewRecorder() + host.ServeHTTP(w2, req2) + + assert.Equal(t, http.StatusOK, w2.Code) + assert.Equal(t, "response", w2.Body.String()) + assert.Equal(t, 1, callCount, "second request should be served from cache") +} + +func TestHostNonOKStatus(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte("not found")) + })) + defer backend.Close() + + _, ctx := logging.Configure(context.Background(), logging.Config{Level: slog.LevelError}) + memCache, err := cache.NewMemory(ctx, cache.MemoryConfig{MaxTTL: time.Hour}) + assert.NoError(t, err) + defer memCache.Close() + + host, err := strategy.NewHost(ctx, strategy.HostConfig{Target: backend.URL}, memCache) + assert.NoError(t, err) + + req := httptest.NewRequest(http.MethodGet, "/missing", nil) + w := httptest.NewRecorder() + host.ServeHTTP(w, req) + + assert.Equal(t, http.StatusNotFound, w.Code) + assert.Equal(t, "not found", w.Body.String()) + + key := cache.NewKey(backend.URL + "/missing") + _, err = memCache.Open(context.Background(), key) + assert.Error(t, err, "non-OK responses should not be cached") +} + +func TestHostInvalidTargetURL(t *testing.T) { + _, ctx := logging.Configure(context.Background(), logging.Config{Level: slog.LevelError}) + memCache, err := cache.NewMemory(ctx, cache.MemoryConfig{MaxTTL: time.Hour}) + assert.NoError(t, err) + defer memCache.Close() + + _, err = strategy.NewHost(ctx, strategy.HostConfig{Target: "://invalid"}, memCache) + assert.Error(t, err) +} + +func TestHostString(t *testing.T) { + _, ctx := logging.Configure(context.Background(), logging.Config{Level: slog.LevelError}) + memCache, err := cache.NewMemory(ctx, cache.MemoryConfig{MaxTTL: time.Hour}) + assert.NoError(t, err) + defer memCache.Close() + + host, err := strategy.NewHost(ctx, strategy.HostConfig{Target: "https://example.com/prefix"}, memCache) + assert.NoError(t, err) + + assert.Equal(t, "host:example.com/prefix", host.String()) +} diff --git a/sfptc.hcl b/sfptc.hcl new file mode 100644 index 0000000..8549a0a --- /dev/null +++ b/sfptc.hcl @@ -0,0 +1,14 @@ +# strategy git {} +# strategy docker {} +# strategy hermit {} +# strategy artifactory { +# mitm = ["artifactory.global.square"] +# } + +host "/github/" { + target = "https://github.com/" +} + +disk { + root = "./cache" +}