diff --git a/cmd/cache-offloader.go b/cmd/cache-offloader.go index c49f723..b20bd2a 100644 --- a/cmd/cache-offloader.go +++ b/cmd/cache-offloader.go @@ -13,6 +13,7 @@ import ( "neurocode.io/cache-offloader/pkg/metrics" "neurocode.io/cache-offloader/pkg/probes" "neurocode.io/cache-offloader/pkg/storage" + "neurocode.io/cache-offloader/pkg/worker" ) func getInMemoryStorage(cfg config.Config) http.Cacher { @@ -61,8 +62,10 @@ func main() { cfg := config.New() setupLogging(cfg.ServerConfig.LogLevel) m := metrics.NewPrometheusCollector() + maxInFlightRevalidationRequests := 1000 opts := http.ServerOpts{ Config: cfg, + Worker: worker.NewUpdateQueue(maxInFlightRevalidationRequests), MetricsCollector: m, ReadinessChecker: probes.NewReadinessChecker(), } diff --git a/pkg/http/cache-mock_test.go b/pkg/http/cache-mock_test.go index 29252db..b4c3be9 100644 --- a/pkg/http/cache-mock_test.go +++ b/pkg/http/cache-mock_test.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: ./stale-while-revalidate.go +// Source: ./cache.go // Package http is a generated GoMock package. package http @@ -12,6 +12,41 @@ import ( model "neurocode.io/cache-offloader/pkg/model" ) +// MockWorker is a mock of Worker interface. +type MockWorker struct { + ctrl *gomock.Controller + recorder *MockWorkerMockRecorder +} + +// MockWorkerMockRecorder is the mock recorder for MockWorker. +type MockWorkerMockRecorder struct { + mock *MockWorker +} + +// NewMockWorker creates a new mock instance. +func NewMockWorker(ctrl *gomock.Controller) *MockWorker { + mock := &MockWorker{ctrl: ctrl} + mock.recorder = &MockWorkerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockWorker) EXPECT() *MockWorkerMockRecorder { + return m.recorder +} + +// Start mocks base method. +func (m *MockWorker) Start(arg0 string, arg1 func()) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Start", arg0, arg1) +} + +// Start indicates an expected call of Start. +func (mr *MockWorkerMockRecorder) Start(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockWorker)(nil).Start), arg0, arg1) +} + // MockCacher is a mock of Cacher interface. type MockCacher struct { ctrl *gomock.Controller diff --git a/pkg/http/cache.go b/pkg/http/cache.go index 3dd5422..5c1b12d 100644 --- a/pkg/http/cache.go +++ b/pkg/http/cache.go @@ -21,21 +21,27 @@ import ( ) //go:generate mockgen -source=./cache.go -destination=./cache-mock_test.go -package=http -type Cacher interface { - LookUp(context.Context, string) (*model.Response, error) - Store(context.Context, string, *model.Response) error -} +type ( + Worker interface { + Start(string, func()) + } + Cacher interface { + LookUp(context.Context, string) (*model.Response, error) + Store(context.Context, string, *model.Response) error + } -type MetricsCollector interface { - CacheHit(method string, statusCode int) - CacheMiss(method string, statusCode int) -} + MetricsCollector interface { + CacheHit(method string, statusCode int) + CacheMiss(method string, statusCode int) + } -type handler struct { - cacher Cacher - metricsCollector MetricsCollector - cfg config.CacheConfig -} + handler struct { + cacher Cacher + worker Worker + metricsCollector MetricsCollector + cfg config.CacheConfig + } +) func handleGzipServeErr(err error) { if err != nil { @@ -94,47 +100,50 @@ func errHandler(res http.ResponseWriter, req *http.Request, err error) { http.Error(res, "service unavailable", http.StatusBadGateway) } -func newCacheHandler(c Cacher, m MetricsCollector, cfg config.CacheConfig) handler { +func newCacheHandler(c Cacher, m MetricsCollector, w Worker, cfg config.CacheConfig) handler { return handler{ cacher: c, + worker: w, metricsCollector: m, cfg: cfg, } } -func (h handler) asyncCacheRevalidate(hashKey string, res http.ResponseWriter, req *http.Request) { - ctx := context.Background() - newReq := req.WithContext(ctx) - - netTransport := &http.Transport{ - MaxIdleConnsPerHost: 1000, - DisableKeepAlives: false, - IdleConnTimeout: time.Hour * 1, - Dial: (&net.Dialer{ - Timeout: 10 * time.Second, - KeepAlive: 30 * time.Second, - }).Dial, - TLSHandshakeTimeout: 10 * time.Second, - ResponseHeaderTimeout: 10 * time.Second, - } - client := &http.Client{ - Timeout: time.Second * 10, - Transport: netTransport, - } - - newReq.URL.Host = h.cfg.DownstreamHost.Host - newReq.URL.Scheme = h.cfg.DownstreamHost.Scheme - newReq.RequestURI = "" - resp, err := client.Do(newReq) - if err != nil { - log.Ctx(ctx).Error().Err(err).Msg("Errored when sending request to the server") +func (h handler) asyncCacheRevalidate(hashKey string, req *http.Request) func() { + return func() { + ctx := context.Background() + newReq := req.WithContext(ctx) + + netTransport := &http.Transport{ + MaxIdleConnsPerHost: 1000, + DisableKeepAlives: false, + IdleConnTimeout: time.Hour * 1, + Dial: (&net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + }).Dial, + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: 10 * time.Second, + } + client := &http.Client{ + Timeout: time.Second * 10, + Transport: netTransport, + } - return - } - err = h.cacheResponse(ctx, hashKey)(resp) + newReq.URL.Host = h.cfg.DownstreamHost.Host + newReq.URL.Scheme = h.cfg.DownstreamHost.Scheme + newReq.RequestURI = "" + resp, err := client.Do(newReq) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("Errored when sending request to the server") - if err != nil { - log.Print("Error occurred caching response") + return + } + err = h.cacheResponse(ctx, hashKey)(resp) + + if err != nil { + log.Print("Error occurred caching response") + } } } @@ -178,7 +187,7 @@ func (h handler) ServeHTTP(res http.ResponseWriter, req *http.Request) { h.metricsCollector.CacheHit(req.Method, result.Status) if result.IsStale() { - go h.asyncCacheRevalidate(hashKey, res, req) + go h.worker.Start(hashKey, h.asyncCacheRevalidate(hashKey, req)) } serveResponseFromMemory(res, result) } diff --git a/pkg/http/cache_test.go b/pkg/http/cache_test.go index 5f40b89..11d2cab 100644 --- a/pkg/http/cache_test.go +++ b/pkg/http/cache_test.go @@ -65,7 +65,7 @@ func mustURL(t *testing.T, downstreamURL string) *url.URL { func TestCacheHandler(t *testing.T) { ctrl := gomock.NewController(t) - // defer ctrl.Finish() + defer ctrl.Finish() proxied := http.StatusUseProxy endpoint := "/status/200?q=1" @@ -286,13 +286,18 @@ func TestCacheHandler(t *testing.T) { cfg: config.CacheConfig{ DownstreamHost: mustURL(t, downstreamServer.URL), }, + worker: func() Worker { + mock := NewMockWorker(ctrl) + mock.EXPECT().Start(gomock.Any(), gomock.Any()) + + return mock + }(), cacher: func() Cacher { mock := NewMockCacher(ctrl) mock.EXPECT().LookUp(gomock.Any(), gomock.Any()).Return(&model.Response{ Status: http.StatusOK, Body: []byte("hello"), }, nil) - // mock.EXPECT().Store(gomock.Any(), gomock.Any(), gomock.Any()) return mock }(), diff --git a/pkg/http/server.go b/pkg/http/server.go index 2f502df..3e21c54 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -16,13 +16,14 @@ import ( type ServerOpts struct { Config config.Config Cacher Cacher + Worker Worker MetricsCollector MetricsCollector ReadinessChecker ReadinessChecker } func RunServer(opts ServerOpts) { mux := h.NewServeMux() - mux.Handle("/", newCacheHandler(opts.Cacher, opts.MetricsCollector, opts.Config.CacheConfig)) + mux.Handle("/", newCacheHandler(opts.Cacher, opts.MetricsCollector, opts.Worker, opts.Config.CacheConfig)) mux.Handle("/metrics/prometheus", metricsHandler()) mux.HandleFunc("/probes/liveness", livenessHandler) diff --git a/pkg/metrics/prometheus_test.go b/pkg/metrics/prometheus_test.go index 69295e0..0131290 100644 --- a/pkg/metrics/prometheus_test.go +++ b/pkg/metrics/prometheus_test.go @@ -11,6 +11,8 @@ func TestPrometheusMetrics(t *testing.T) { t.Run("should return a prometheus registry", func(t *testing.T) { collector := NewPrometheusCollector() assert.NotNil(t, collector) + + prometheus.Unregister(collector.httpMetrics) }) t.Run("should use NA for invalid HTTP method", func(t *testing.T) { @@ -21,5 +23,6 @@ func TestPrometheusMetrics(t *testing.T) { assert.Nil(t, err) assert.NotNil(t, metric) + prometheus.Unregister(collector.httpMetrics) }) } diff --git a/pkg/storage/debouncer.go b/pkg/storage/debouncer.go deleted file mode 100644 index 868afc8..0000000 --- a/pkg/storage/debouncer.go +++ /dev/null @@ -1,81 +0,0 @@ -package storage - -import ( - "container/list" - "net/http" - "sync" -) - -type Debouncer struct { - mtx sync.RWMutex - requests *list.List - cache map[string]*list.Element - capacityMB float64 - sizeMB float64 -} - -type RequestNode struct { - value *http.Request -} - -func NewDebouncer(maxSizeMB float64) *Debouncer { - if maxSizeMB <= 0 { - maxSizeMB = 50.0 - } - - return &Debouncer{ - capacityMB: maxSizeMB, - sizeMB: 0.0, - requests: list.New(), - cache: make(map[string]*list.Element), - } -} - -func (debouncer *Debouncer) Store(key string, value *http.Request) { - debouncer.mtx.Lock() - defer debouncer.mtx.Unlock() - - bodySizeMB := debouncer.getSize(*value) - - if bodySizeMB < 0 || (bodySizeMB+debouncer.sizeMB > debouncer.capacityMB) { - return - } - - if _, found := debouncer.cache[key]; !found { - element := debouncer.requests.PushFront(&RequestNode{value: value}) - debouncer.cache[key] = element - } -} - -func (debouncer *Debouncer) GetNext() *http.Request { - debouncer.mtx.RLock() - defer debouncer.mtx.RUnlock() - - return debouncer.requests.Back().Value.(*RequestNode).value -} - -func (debouncer *Debouncer) Erase(key string) bool { - debouncer.mtx.RLock() - defer debouncer.mtx.RUnlock() - - if val, found := debouncer.cache[key]; found { - delete(debouncer.cache, key) - debouncer.requests.Remove(val) - - return true - } - - return false -} - -func (debouncer *Debouncer) getSize(value http.Request) float64 { - sizeBytes := value.ContentLength - - if sizeBytes < 0 { - return -1.0 - } - - sizeMB := float64(sizeBytes) / (1024 * 1024) - - return sizeMB -} diff --git a/pkg/worker/cache-updater.go b/pkg/worker/cache-updater.go new file mode 100644 index 0000000..db64b03 --- /dev/null +++ b/pkg/worker/cache-updater.go @@ -0,0 +1,48 @@ +package worker + +import ( + "sync" + + "github.com/rs/zerolog/log" +) + +type UpdateQueue struct { + mtx sync.RWMutex + queue map[string]bool + size int +} + +func NewUpdateQueue(size int) *UpdateQueue { + if size <= 0 { + size = 1000 + } + + return &UpdateQueue{ + queue: make(map[string]bool, size), + size: size, + } +} + +func (debouncer *UpdateQueue) Start(key string, work func()) { + if len(debouncer.queue) >= debouncer.size { + log.Warn().Msg("UpdateQueue is full, dropping request") + + return + } + + debouncer.mtx.Lock() + + if _, ok := debouncer.queue[key]; ok { + debouncer.mtx.Unlock() + + return + } + debouncer.queue[key] = true + debouncer.mtx.Unlock() + + work() + + debouncer.mtx.Lock() + delete(debouncer.queue, key) + debouncer.mtx.Unlock() +} diff --git a/pkg/worker/cache-updater_test.go b/pkg/worker/cache-updater_test.go new file mode 100644 index 0000000..0c99434 --- /dev/null +++ b/pkg/worker/cache-updater_test.go @@ -0,0 +1,21 @@ +package worker + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCacheUpdater(t *testing.T) { + t.Run("UpdateQueue shouldnt panic on negative numbers", func(t *testing.T) { + q := NewUpdateQueue(-1) + + assert.NotNil(t, q) + }) + t.Run("should do the work in a function", func(t *testing.T) { + q := NewUpdateQueue(1) + q.Start("test2", func() { + t.Log("test work") + }) + }) +}