From d58008b0ca5f08d2061a6f662a8b50c01aaade47 Mon Sep 17 00:00:00 2001 From: marco Date: Mon, 10 Jun 2024 09:40:24 +0200 Subject: [PATCH 1/2] do not merge - expect unit tests to fail --- cmd/crowdsec-cli/alerts.go | 2 +- cmd/crowdsec-cli/bouncers.go | 39 +++++----- cmd/crowdsec-cli/machines.go | 49 ++++++------- cmd/crowdsec-cli/papi.go | 2 +- cmd/crowdsec-cli/support.go | 12 ++-- cmd/crowdsec/api.go | 4 +- cmd/crowdsec/metrics.go | 4 +- pkg/apiserver/apic.go | 27 ++++--- pkg/apiserver/apic_metrics.go | 16 ++--- pkg/apiserver/apiserver.go | 6 +- pkg/apiserver/controllers/v1/alerts.go | 10 +-- pkg/apiserver/controllers/v1/decisions.go | 27 +++---- pkg/apiserver/controllers/v1/heartbeat.go | 2 +- pkg/apiserver/controllers/v1/machines.go | 2 +- pkg/apiserver/middlewares/v1/api_key.go | 12 ++-- pkg/apiserver/middlewares/v1/jwt.go | 14 ++-- pkg/apiserver/papi.go | 14 ++-- pkg/apiserver/papi_cmd.go | 13 ++-- pkg/database/alerts.go | 88 +++++++++++------------ pkg/database/bouncers.go | 41 +++++------ pkg/database/config.go | 11 +-- pkg/database/database.go | 2 - pkg/database/decisions.go | 77 ++++++++++---------- pkg/database/flush.go | 43 +++++------ pkg/database/lock.go | 23 +++--- pkg/database/machines.go | 59 +++++++-------- pkg/exprhelpers/helpers.go | 13 ++-- 27 files changed, 321 insertions(+), 291 deletions(-) diff --git a/cmd/crowdsec-cli/alerts.go b/cmd/crowdsec-cli/alerts.go index 7c9c5f23032..95ecdad02ea 100644 --- a/cmd/crowdsec-cli/alerts.go +++ b/cmd/crowdsec-cli/alerts.go @@ -575,7 +575,7 @@ func (cli *cliAlerts) NewFlushCmd() *cobra.Command { return err } log.Info("Flushing alerts. !! This may take a long time !!") - err = db.FlushAlerts(maxAge, maxItems) + err = db.FlushAlerts(cmd.Context(), maxAge, maxItems) if err != nil { return fmt.Errorf("unable to flush alerts: %w", err) } diff --git a/cmd/crowdsec-cli/bouncers.go b/cmd/crowdsec-cli/bouncers.go index f8628538378..88341088d82 100644 --- a/cmd/crowdsec-cli/bouncers.go +++ b/cmd/crowdsec-cli/bouncers.go @@ -1,6 +1,7 @@ package main import ( + "context" "encoding/csv" "encoding/json" "errors" @@ -83,10 +84,10 @@ Note: This command requires database direct access, so is intended to be run on return cmd } -func (cli *cliBouncers) list() error { +func (cli *cliBouncers) list(ctx context.Context) error { out := color.Output - bouncers, err := cli.db.ListBouncers() + bouncers, err := cli.db.ListBouncers(ctx) if err != nil { return fmt.Errorf("unable to list bouncers: %w", err) } @@ -134,15 +135,15 @@ func (cli *cliBouncers) newListCmd() *cobra.Command { Example: `cscli bouncers list`, Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, _ []string) error { - return cli.list() + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.list(cmd.Context()) }, } return cmd } -func (cli *cliBouncers) add(bouncerName string, key string) error { +func (cli *cliBouncers) add(ctx context.Context, bouncerName string, key string) error { var err error keyLength := 32 @@ -154,7 +155,7 @@ func (cli *cliBouncers) add(bouncerName string, key string) error { } } - _, err = cli.db.CreateBouncer(bouncerName, "", middlewares.HashSHA512(key), types.ApiKeyAuthType) + _, err = cli.db.CreateBouncer(ctx, bouncerName, "", middlewares.HashSHA512(key), types.ApiKeyAuthType) if err != nil { return fmt.Errorf("unable to create bouncer: %w", err) } @@ -188,8 +189,8 @@ func (cli *cliBouncers) newAddCmd() *cobra.Command { cscli bouncers add MyBouncerName --key `, Args: cobra.ExactArgs(1), DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, args []string) error { - return cli.add(args[0], key) + RunE: func(cmd *cobra.Command, args []string) error { + return cli.add(cmd.Context(), args[0], key) }, } @@ -201,8 +202,8 @@ cscli bouncers add MyBouncerName --key `, return cmd } -func (cli *cliBouncers) deleteValid(_ *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - bouncers, err := cli.db.ListBouncers() +func (cli *cliBouncers) deleteValid(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + bouncers, err := cli.db.ListBouncers(cmd.Context()) if err != nil { cobra.CompError("unable to list bouncers " + err.Error()) } @@ -218,9 +219,9 @@ func (cli *cliBouncers) deleteValid(_ *cobra.Command, args []string, toComplete return ret, cobra.ShellCompDirectiveNoFileComp } -func (cli *cliBouncers) delete(bouncers []string) error { +func (cli *cliBouncers) delete(ctx context.Context, bouncers []string) error { for _, bouncerID := range bouncers { - err := cli.db.DeleteBouncer(bouncerID) + err := cli.db.DeleteBouncer(ctx, bouncerID) if err != nil { return fmt.Errorf("unable to delete bouncer '%s': %w", bouncerID, err) } @@ -239,15 +240,15 @@ func (cli *cliBouncers) newDeleteCmd() *cobra.Command { Aliases: []string{"remove"}, DisableAutoGenTag: true, ValidArgsFunction: cli.deleteValid, - RunE: func(_ *cobra.Command, args []string) error { - return cli.delete(args) + RunE: func(cmd *cobra.Command, args []string) error { + return cli.delete(cmd.Context(), args) }, } return cmd } -func (cli *cliBouncers) prune(duration time.Duration, force bool) error { +func (cli *cliBouncers) prune(ctx context.Context, duration time.Duration, force bool) error { if duration < 2*time.Minute { if yes, err := askYesNo( "The duration you provided is less than 2 minutes. " + @@ -259,7 +260,7 @@ func (cli *cliBouncers) prune(duration time.Duration, force bool) error { } } - bouncers, err := cli.db.QueryBouncersLastPulltimeLT(time.Now().UTC().Add(-duration)) + bouncers, err := cli.db.QueryBouncersLastPulltimeLT(ctx, time.Now().UTC().Add(-duration)) if err != nil { return fmt.Errorf("unable to query bouncers: %w", err) } @@ -282,7 +283,7 @@ func (cli *cliBouncers) prune(duration time.Duration, force bool) error { } } - deleted, err := cli.db.BulkDeleteBouncers(bouncers) + deleted, err := cli.db.BulkDeleteBouncers(ctx, bouncers) if err != nil { return fmt.Errorf("unable to prune bouncers: %w", err) } @@ -307,8 +308,8 @@ func (cli *cliBouncers) newPruneCmd() *cobra.Command { DisableAutoGenTag: true, Example: `cscli bouncers prune -d 45m cscli bouncers prune -d 45m --force`, - RunE: func(_ *cobra.Command, _ []string) error { - return cli.prune(duration, force) + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.prune(cmd.Context(), duration, force) }, } diff --git a/cmd/crowdsec-cli/machines.go b/cmd/crowdsec-cli/machines.go index 7beaa5c7fdd..3a3823700bd 100644 --- a/cmd/crowdsec-cli/machines.go +++ b/cmd/crowdsec-cli/machines.go @@ -1,6 +1,7 @@ package main import ( + "context" saferand "crypto/rand" "encoding/csv" "encoding/json" @@ -151,10 +152,10 @@ Note: This command requires database direct access, so is intended to be run on return cmd } -func (cli *cliMachines) list() error { +func (cli *cliMachines) list(ctx context.Context) error { out := color.Output - machines, err := cli.db.ListMachines() + machines, err := cli.db.ListMachines(ctx) if err != nil { return fmt.Errorf("unable to list machines: %w", err) } @@ -206,8 +207,8 @@ func (cli *cliMachines) newListCmd() *cobra.Command { Example: `cscli machines list`, Args: cobra.NoArgs, DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, _ []string) error { - return cli.list() + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.list(cmd.Context()) }, } @@ -233,8 +234,8 @@ func (cli *cliMachines) newAddCmd() *cobra.Command { cscli machines add MyTestMachine --auto cscli machines add MyTestMachine --password MyPassword cscli machines add -f- --auto > /tmp/mycreds.yaml`, - RunE: func(_ *cobra.Command, args []string) error { - return cli.add(args, string(password), dumpFile, apiURL, interactive, autoAdd, force) + RunE: func(cmd *cobra.Command, args []string) error { + return cli.add(cmd.Context(), args, string(password), dumpFile, apiURL, interactive, autoAdd, force) }, } @@ -249,7 +250,7 @@ cscli machines add -f- --auto > /tmp/mycreds.yaml`, return cmd } -func (cli *cliMachines) add(args []string, machinePassword string, dumpFile string, apiURL string, interactive bool, autoAdd bool, force bool) error { +func (cli *cliMachines) add(ctx context.Context, args []string, machinePassword string, dumpFile string, apiURL string, interactive bool, autoAdd bool, force bool) error { var ( err error machineID string @@ -308,7 +309,7 @@ func (cli *cliMachines) add(args []string, machinePassword string, dumpFile stri password := strfmt.Password(machinePassword) - _, err = cli.db.CreateMachine(&machineID, &password, "", true, force, types.PasswordAuthType) + _, err = cli.db.CreateMachine(ctx, &machineID, &password, "", true, force, types.PasswordAuthType) if err != nil { return fmt.Errorf("unable to create machine: %w", err) } @@ -349,8 +350,8 @@ func (cli *cliMachines) add(args []string, machinePassword string, dumpFile stri return nil } -func (cli *cliMachines) deleteValid(_ *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - machines, err := cli.db.ListMachines() +func (cli *cliMachines) deleteValid(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + machines, err := cli.db.ListMachines(cmd.Context()) if err != nil { cobra.CompError("unable to list machines " + err.Error()) } @@ -366,9 +367,9 @@ func (cli *cliMachines) deleteValid(_ *cobra.Command, args []string, toComplete return ret, cobra.ShellCompDirectiveNoFileComp } -func (cli *cliMachines) delete(machines []string) error { +func (cli *cliMachines) delete(ctx context.Context, machines []string) error { for _, machineID := range machines { - if err := cli.db.DeleteWatcher(machineID); err != nil { + if err := cli.db.DeleteWatcher(ctx, machineID); err != nil { log.Errorf("unable to delete machine '%s': %s", machineID, err) return nil } @@ -388,15 +389,15 @@ func (cli *cliMachines) newDeleteCmd() *cobra.Command { Aliases: []string{"remove"}, DisableAutoGenTag: true, ValidArgsFunction: cli.deleteValid, - RunE: func(_ *cobra.Command, args []string) error { - return cli.delete(args) + RunE: func(cmd *cobra.Command, args []string) error { + return cli.delete(cmd.Context(), args) }, } return cmd } -func (cli *cliMachines) prune(duration time.Duration, notValidOnly bool, force bool) error { +func (cli *cliMachines) prune(ctx context.Context, duration time.Duration, notValidOnly bool, force bool) error { if duration < 2*time.Minute && !notValidOnly { if yes, err := askYesNo( "The duration you provided is less than 2 minutes. " + @@ -409,12 +410,12 @@ func (cli *cliMachines) prune(duration time.Duration, notValidOnly bool, force b } machines := []*ent.Machine{} - if pending, err := cli.db.QueryPendingMachine(); err == nil { + if pending, err := cli.db.QueryPendingMachine(ctx); err == nil { machines = append(machines, pending...) } if !notValidOnly { - if pending, err := cli.db.QueryLastValidatedHeartbeatLT(time.Now().UTC().Add(-duration)); err == nil { + if pending, err := cli.db.QueryLastValidatedHeartbeatLT(ctx, time.Now().UTC().Add(-duration)); err == nil { machines = append(machines, pending...) } } @@ -437,7 +438,7 @@ func (cli *cliMachines) prune(duration time.Duration, notValidOnly bool, force b } } - deleted, err := cli.db.BulkDeleteWatchers(machines) + deleted, err := cli.db.BulkDeleteWatchers(ctx, machines) if err != nil { return fmt.Errorf("unable to prune machines: %w", err) } @@ -465,8 +466,8 @@ cscli machines prune --duration 1h cscli machines prune --not-validated-only --force`, Args: cobra.NoArgs, DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, _ []string) error { - return cli.prune(duration, notValidOnly, force) + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.prune(cmd.Context(), duration, notValidOnly, force) }, } @@ -478,8 +479,8 @@ cscli machines prune --not-validated-only --force`, return cmd } -func (cli *cliMachines) validate(machineID string) error { - if err := cli.db.ValidateMachine(machineID); err != nil { +func (cli *cliMachines) validate(ctx context.Context, machineID string) error { + if err := cli.db.ValidateMachine(ctx, machineID); err != nil { return fmt.Errorf("unable to validate machine '%s': %w", machineID, err) } @@ -496,8 +497,8 @@ func (cli *cliMachines) newValidateCmd() *cobra.Command { Example: `cscli machines validate "machine_name"`, Args: cobra.ExactArgs(1), DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, args []string) error { - return cli.validate(args[0]) + RunE: func(cmd *cobra.Command, args []string) error { + return cli.validate(cmd.Context(), args[0]) }, } diff --git a/cmd/crowdsec-cli/papi.go b/cmd/crowdsec-cli/papi.go index a2fa0a90871..691ebaeeb3e 100644 --- a/cmd/crowdsec-cli/papi.go +++ b/cmd/crowdsec-cli/papi.go @@ -78,7 +78,7 @@ func (cli *cliPapi) NewStatusCmd() *cobra.Command { return fmt.Errorf("unable to get PAPI permissions: %w", err) } var lastTimestampStr *string - lastTimestampStr, err = db.GetConfigItem(apiserver.PapiPullKey) + lastTimestampStr, err = db.GetConfigItem(cmd.Context(), apiserver.PapiPullKey) if err != nil { lastTimestampStr = ptr.Of("never") } diff --git a/cmd/crowdsec-cli/support.go b/cmd/crowdsec-cli/support.go index 3b0f53cd6e1..7455bca233b 100644 --- a/cmd/crowdsec-cli/support.go +++ b/cmd/crowdsec-cli/support.go @@ -184,7 +184,7 @@ func (cli *cliSupport) dumpHubItems(zw *zip.Writer, hub *cwhub.Hub, itemType str return nil } -func (cli *cliSupport) dumpBouncers(zw *zip.Writer, db *database.Client) error { +func (cli *cliSupport) dumpBouncers(ctx context.Context, zw *zip.Writer, db *database.Client) error { log.Info("Collecting bouncers") if db == nil { @@ -193,7 +193,7 @@ func (cli *cliSupport) dumpBouncers(zw *zip.Writer, db *database.Client) error { out := new(bytes.Buffer) - bouncers, err := db.ListBouncers() + bouncers, err := db.ListBouncers(ctx) if err != nil { return fmt.Errorf("unable to list bouncers: %w", err) } @@ -207,7 +207,7 @@ func (cli *cliSupport) dumpBouncers(zw *zip.Writer, db *database.Client) error { return nil } -func (cli *cliSupport) dumpAgents(zw *zip.Writer, db *database.Client) error { +func (cli *cliSupport) dumpAgents(ctx context.Context, zw *zip.Writer, db *database.Client) error { log.Info("Collecting agents") if db == nil { @@ -216,7 +216,7 @@ func (cli *cliSupport) dumpAgents(zw *zip.Writer, db *database.Client) error { out := new(bytes.Buffer) - machines, err := db.ListMachines() + machines, err := db.ListMachines(ctx) if err != nil { return fmt.Errorf("unable to list machines: %w", err) } @@ -518,11 +518,11 @@ func (cli *cliSupport) dump(ctx context.Context, outFile string) error { } } - if err = cli.dumpBouncers(zipWriter, db); err != nil { + if err = cli.dumpBouncers(ctx, zipWriter, db); err != nil { log.Warnf("could not collect bouncers information: %s", err) } - if err = cli.dumpAgents(zipWriter, db); err != nil { + if err = cli.dumpAgents(ctx, zipWriter, db); err != nil { log.Warnf("could not collect agents information: %s", err) } diff --git a/cmd/crowdsec/api.go b/cmd/crowdsec/api.go index c57b8d87cff..f13d3583c07 100644 --- a/cmd/crowdsec/api.go +++ b/cmd/crowdsec/api.go @@ -1,6 +1,7 @@ package main import ( + "context" "errors" "fmt" "runtime" @@ -63,7 +64,8 @@ func serveAPIServer(apiServer *apiserver.APIServer) { go func() { defer trace.CatchPanic("crowdsec/runAPIServer") log.Debugf("serving API after %s ms", time.Since(crowdsecT0)) - if err := apiServer.Run(apiReady); err != nil { + ctx := context.TODO() + if err := apiServer.Run(ctx, apiReady); err != nil { log.Fatal(err) } }() diff --git a/cmd/crowdsec/metrics.go b/cmd/crowdsec/metrics.go index d3c6e172091..db422471cc1 100644 --- a/cmd/crowdsec/metrics.go +++ b/cmd/crowdsec/metrics.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "net/http" @@ -118,7 +119,8 @@ func computeDynamicMetrics(next http.Handler, dbClient *database.Client) http.Ha return } - decisions, err := dbClient.QueryDecisionCountByScenario() + ctx := context.TODO() + decisions, err := dbClient.QueryDecisionCountByScenario(ctx) if err != nil { log.Errorf("Error querying decisions for metrics: %v", err) next.ServeHTTP(w, r) diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index 68dc94367e2..d7d5295e6eb 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -82,7 +82,8 @@ func randomDuration(d time.Duration, delta time.Duration) time.Duration { func (a *apic) FetchScenariosListFromDB() ([]string, error) { scenarios := make([]string, 0) - machines, err := a.dbClient.ListMachines() + ctx := context.TODO() + machines, err := a.dbClient.ListMachines(ctx) if err != nil { return nil, fmt.Errorf("while listing machines: %w", err) } @@ -406,7 +407,8 @@ func (a *apic) CAPIPullIsOld() (bool, error) { alerts = alerts.Where(alert.HasDecisionsWith(decision.OriginEQ(database.CapiMachineID))) alerts = alerts.Where(alert.CreatedAtGTE(time.Now().UTC().Add(-time.Duration(1*time.Hour + 30*time.Minute)))) //nolint:unconvert - count, err := alerts.Count(a.dbClient.CTX) + ctx := context.TODO() + count, err := alerts.Count(ctx) if err != nil { return false, fmt.Errorf("while looking for CAPI alert: %w", err) } @@ -432,7 +434,8 @@ func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delet filter["scopes"] = []string{*decision.Scope} } - dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(filter) + ctx := context.TODO() + dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(ctx, filter) if err != nil { return 0, fmt.Errorf("expiring decisions error: %w", err) } @@ -464,7 +467,8 @@ func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisi filter["scopes"] = []string{*scope} } - dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(filter) + ctx := context.TODO() + dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(ctx, filter) if err != nil { return 0, fmt.Errorf("expiring decisions error: %w", err) } @@ -634,7 +638,8 @@ func (a *apic) PullTop(forcePull bool) error { log.Debug("Acquiring lock for pullCAPI") - err = a.dbClient.AcquirePullCAPILock() + ctx := context.TODO() + err = a.dbClient.AcquirePullCAPILock(ctx) if a.dbClient.IsLocked(err) { log.Info("PullCAPI is already running, skipping") return nil @@ -644,7 +649,8 @@ func (a *apic) PullTop(forcePull bool) error { defer func() { log.Debug("Releasing lock for pullCAPI") - if err := a.dbClient.ReleasePullCAPILock(); err != nil { + ctx := context.TODO() + if err := a.dbClient.ReleasePullCAPILock(ctx); err != nil { log.Errorf("while releasing lock: %v", err) } }() @@ -768,7 +774,8 @@ func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, addCounters map[string log.Warningf("sqlite is not using WAL mode, LAPI might become unresponsive when inserting the community blocklist") } - alertID, inserted, deleted, err := a.dbClient.UpdateCommunityBlocklist(alert) + ctx := context.TODO() + alertID, inserted, deleted, err := a.dbClient.UpdateCommunityBlocklist(ctx, alert) if err != nil { return fmt.Errorf("while saving alert from %s: %w", *alert.Source.Scope, err) } @@ -844,7 +851,8 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap ) if !forcePull { - lastPullTimestamp, err = a.dbClient.GetConfigItem(blocklistConfigItemName) + ctx := context.TODO() + lastPullTimestamp, err = a.dbClient.GetConfigItem(ctx, blocklistConfigItemName) if err != nil { return fmt.Errorf("while getting last pull timestamp for blocklist %s: %w", *blocklist.Name, err) } @@ -865,7 +873,8 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap return nil } - err = a.dbClient.SetConfigItem(blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat)) + ctx := context.TODO() + err = a.dbClient.SetConfigItem(ctx, blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat)) if err != nil { return fmt.Errorf("while setting last pull timestamp for blocklist %s: %w", *blocklist.Name, err) } diff --git a/pkg/apiserver/apic_metrics.go b/pkg/apiserver/apic_metrics.go index 128ce5a9639..dc24c9b3d45 100644 --- a/pkg/apiserver/apic_metrics.go +++ b/pkg/apiserver/apic_metrics.go @@ -14,8 +14,8 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/models" ) -func (a *apic) GetMetrics() (*models.Metrics, error) { - machines, err := a.dbClient.ListMachines() +func (a *apic) GetMetrics(ctx context.Context) (*models.Metrics, error) { + machines, err := a.dbClient.ListMachines(ctx) if err != nil { return nil, err } @@ -31,7 +31,7 @@ func (a *apic) GetMetrics() (*models.Metrics, error) { } } - bouncers, err := a.dbClient.ListBouncers() + bouncers, err := a.dbClient.ListBouncers(ctx) if err != nil { return nil, err } @@ -54,8 +54,8 @@ func (a *apic) GetMetrics() (*models.Metrics, error) { }, nil } -func (a *apic) fetchMachineIDs() ([]string, error) { - machines, err := a.dbClient.ListMachines() +func (a *apic) fetchMachineIDs(ctx context.Context) ([]string, error) { + machines, err := a.dbClient.ListMachines(ctx) if err != nil { return nil, err } @@ -75,7 +75,7 @@ func (a *apic) fetchMachineIDs() ([]string, error) { // Metrics are sent at start, then at the randomized metricsIntervalFirst, // then at regular metricsInterval. If a change is detected in the list // of machines, the next metrics are sent immediately. -func (a *apic) SendMetrics(stop chan (bool)) { +func (a *apic) SendMetrics(ctx context.Context, stop chan (bool)) { defer trace.CatchPanic("lapi/metricsToAPIC") // verify the list of machines every interval @@ -99,7 +99,7 @@ func (a *apic) SendMetrics(stop chan (bool)) { machineIDs := []string{} reloadMachineIDs := func() { - ids, err := a.fetchMachineIDs() + ids, err := a.fetchMachineIDs(ctx) if err != nil { log.Debugf("unable to get machines (%s), will retry", err) @@ -135,7 +135,7 @@ func (a *apic) SendMetrics(stop chan (bool)) { case <-metTicker.C: metTicker.Stop() - metrics, err := a.GetMetrics() + metrics, err := a.GetMetrics(ctx) if err != nil { log.Errorf("unable to get metrics (%s)", err) } diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index c6074801d7e..aa699377b82 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -170,7 +170,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { } if config.DbConfig.Flush != nil { - flushScheduler, err = dbClient.StartFlushScheduler(config.DbConfig.Flush) + flushScheduler, err = dbClient.StartFlushScheduler(ctx, config.DbConfig.Flush) if err != nil { return nil, err } @@ -300,7 +300,7 @@ func (s *APIServer) Router() (*gin.Engine, error) { return s.router, nil } -func (s *APIServer) Run(apiReady chan bool) error { +func (s *APIServer) Run(ctx context.Context, apiReady chan bool) error { defer trace.CatchPanic("lapi/runServer") tlsCfg, err := s.TLS.GetTLSConfig() @@ -364,7 +364,7 @@ func (s *APIServer) Run(apiReady chan bool) error { } s.apic.metricsTomb.Go(func() error { - s.apic.SendMetrics(make(chan bool)) + s.apic.SendMetrics(ctx, make(chan bool)) return nil }) } diff --git a/pkg/apiserver/controllers/v1/alerts.go b/pkg/apiserver/controllers/v1/alerts.go index 7483e8dcdf9..71ff2f38eac 100644 --- a/pkg/apiserver/controllers/v1/alerts.go +++ b/pkg/apiserver/controllers/v1/alerts.go @@ -254,7 +254,7 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { c.DBClient.CanFlush = false } - alerts, err := c.DBClient.CreateAlert(machineID, input) + alerts, err := c.DBClient.CreateAlert(gctx.Request.Context(), machineID, input) c.DBClient.CanFlush = true if err != nil { @@ -276,7 +276,7 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { // FindAlerts: returns alerts from the database based on the specified filter func (c *Controller) FindAlerts(gctx *gin.Context) { - result, err := c.DBClient.QueryAlertWithFilter(gctx.Request.URL.Query()) + result, err := c.DBClient.QueryAlertWithFilter(gctx.Request.Context(), gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) return @@ -302,7 +302,7 @@ func (c *Controller) FindAlertByID(gctx *gin.Context) { return } - result, err := c.DBClient.GetAlertByID(alertID) + result, err := c.DBClient.GetAlertByID(gctx.Request.Context(), alertID) if err != nil { c.HandleDBErrors(gctx, err) return @@ -336,7 +336,7 @@ func (c *Controller) DeleteAlertByID(gctx *gin.Context) { return } - err = c.DBClient.DeleteAlertByID(decisionID) + err = c.DBClient.DeleteAlertByID(gctx.Request.Context(), decisionID) if err != nil { c.HandleDBErrors(gctx, err) return @@ -355,7 +355,7 @@ func (c *Controller) DeleteAlerts(gctx *gin.Context) { return } - nbDeleted, err := c.DBClient.DeleteAlertWithFilter(gctx.Request.URL.Query()) + nbDeleted, err := c.DBClient.DeleteAlertWithFilter(gctx.Request.Context(), gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) return diff --git a/pkg/apiserver/controllers/v1/decisions.go b/pkg/apiserver/controllers/v1/decisions.go index 543c832095a..e4eca3df9d3 100644 --- a/pkg/apiserver/controllers/v1/decisions.go +++ b/pkg/apiserver/controllers/v1/decisions.go @@ -1,6 +1,7 @@ package v1 import ( + "context" "encoding/json" "fmt" "net/http" @@ -50,7 +51,7 @@ func (c *Controller) GetDecision(gctx *gin.Context) { return } - data, err = c.DBClient.QueryDecisionWithFilter(gctx.Request.URL.Query()) + data, err = c.DBClient.QueryDecisionWithFilter(gctx.Request.Context(), gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) @@ -73,7 +74,7 @@ func (c *Controller) GetDecision(gctx *gin.Context) { } if time.Now().UTC().Sub(bouncerInfo.LastPull) >= time.Minute { - if err := c.DBClient.UpdateBouncerLastPull(time.Now().UTC(), bouncerInfo.ID); err != nil { + if err := c.DBClient.UpdateBouncerLastPull(gctx.Request.Context(), time.Now().UTC(), bouncerInfo.ID); err != nil { log.Errorf("failed to update bouncer last pull: %v", err) } } @@ -91,7 +92,7 @@ func (c *Controller) DeleteDecisionById(gctx *gin.Context) { return } - nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionByID(decisionID) + nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionByID(gctx.Request.Context(), decisionID) if err != nil { c.HandleDBErrors(gctx, err) @@ -113,7 +114,7 @@ func (c *Controller) DeleteDecisionById(gctx *gin.Context) { } func (c *Controller) DeleteDecisions(gctx *gin.Context) { - nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionsWithFilter(gctx.Request.URL.Query()) + nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionsWithFilter(gctx.Request.Context(), gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) @@ -134,7 +135,7 @@ func (c *Controller) DeleteDecisions(gctx *gin.Context) { gctx.JSON(http.StatusOK, deleteDecisionResp) } -func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFunc func(map[string][]string) ([]*ent.Decision, error)) error { +func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFunc func(context.Context, map[string][]string) ([]*ent.Decision, error)) error { // respBuffer := bytes.NewBuffer([]byte{}) limit := 30000 //FIXME : make it configurable needComma := false @@ -148,7 +149,7 @@ func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFun filters["id_gt"] = []string{lastIdStr} } - data, err := dbFunc(filters) + data, err := dbFunc(gctx.Request.Context(), filters) if err != nil { return err } @@ -186,7 +187,7 @@ func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFun return nil } -func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPull time.Time, dbFunc func(time.Time, map[string][]string) ([]*ent.Decision, error)) error { +func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPull time.Time, dbFunc func(context.Context, time.Time, map[string][]string) ([]*ent.Decision, error)) error { //respBuffer := bytes.NewBuffer([]byte{}) limit := 30000 //FIXME : make it configurable needComma := false @@ -200,7 +201,7 @@ func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPul filters["id_gt"] = []string{lastIdStr} } - data, err := dbFunc(lastPull, filters) + data, err := dbFunc(gctx.Request.Context(), lastPull, filters) if err != nil { return err } @@ -310,7 +311,7 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en if val, ok := gctx.Request.URL.Query()["startup"]; ok { if val[0] == "true" { - data, err = c.DBClient.QueryAllDecisionsWithFilters(filters) + data, err = c.DBClient.QueryAllDecisionsWithFilters(gctx.Request.Context(), filters) if err != nil { log.Errorf("failed querying decisions: %v", err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) @@ -321,7 +322,7 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en ret["new"] = FormatDecisions(data) // getting expired decisions - data, err = c.DBClient.QueryExpiredDecisionsWithFilters(filters) + data, err = c.DBClient.QueryExpiredDecisionsWithFilters(gctx.Request.Context(), filters) if err != nil { log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) @@ -338,7 +339,7 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en } // getting new decisions - data, err = c.DBClient.QueryNewDecisionsSinceWithFilters(bouncerInfo.LastPull, filters) + data, err = c.DBClient.QueryNewDecisionsSinceWithFilters(gctx.Request.Context(), bouncerInfo.LastPull, filters) if err != nil { log.Errorf("unable to query new decision for '%s' : %v", bouncerInfo.Name, err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) @@ -349,7 +350,7 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en ret["new"] = FormatDecisions(data) // getting expired decisions - data, err = c.DBClient.QueryExpiredDecisionsSinceWithFilters(bouncerInfo.LastPull.Add((-2 * time.Second)), filters) // do we want to give exactly lastPull time ? + data, err = c.DBClient.QueryExpiredDecisionsSinceWithFilters(gctx.Request.Context(), bouncerInfo.LastPull.Add((-2 * time.Second)), filters) // do we want to give exactly lastPull time ? if err != nil { log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) @@ -396,7 +397,7 @@ func (c *Controller) StreamDecision(gctx *gin.Context) { if err == nil { //Only update the last pull time if no error occurred when sending the decisions to avoid missing decisions - if err := c.DBClient.UpdateBouncerLastPull(streamStartTime, bouncerInfo.ID); err != nil { + if err := c.DBClient.UpdateBouncerLastPull(gctx.Request.Context(), streamStartTime, bouncerInfo.ID); err != nil { log.Errorf("unable to update bouncer '%s' pull: %v", bouncerInfo.Name, err) } } diff --git a/pkg/apiserver/controllers/v1/heartbeat.go b/pkg/apiserver/controllers/v1/heartbeat.go index e1231eaa9ec..79a68dec6d6 100644 --- a/pkg/apiserver/controllers/v1/heartbeat.go +++ b/pkg/apiserver/controllers/v1/heartbeat.go @@ -9,7 +9,7 @@ import ( func (c *Controller) HeartBeat(gctx *gin.Context) { machineID, _ := getMachineIDFromContext(gctx) - if err := c.DBClient.UpdateMachineLastHeartBeat(machineID); err != nil { + if err := c.DBClient.UpdateMachineLastHeartBeat(gctx.Request.Context(), machineID); err != nil { c.HandleDBErrors(gctx, err) return } diff --git a/pkg/apiserver/controllers/v1/machines.go b/pkg/apiserver/controllers/v1/machines.go index 84a6ef2583c..09e9c992ef2 100644 --- a/pkg/apiserver/controllers/v1/machines.go +++ b/pkg/apiserver/controllers/v1/machines.go @@ -23,7 +23,7 @@ func (c *Controller) CreateMachine(gctx *gin.Context) { return } - if _, err := c.DBClient.CreateMachine(input.MachineID, input.Password, gctx.ClientIP(), false, false, types.PasswordAuthType); err != nil { + if _, err := c.DBClient.CreateMachine(gctx.Request.Context(), input.MachineID, input.Password, gctx.ClientIP(), false, false, types.PasswordAuthType); err != nil { c.HandleDBErrors(gctx, err) return } diff --git a/pkg/apiserver/middlewares/v1/api_key.go b/pkg/apiserver/middlewares/v1/api_key.go index 4561b8f7789..bfd03b41106 100644 --- a/pkg/apiserver/middlewares/v1/api_key.go +++ b/pkg/apiserver/middlewares/v1/api_key.go @@ -80,7 +80,7 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { }) bouncerName := fmt.Sprintf("%s@%s", extractedCN, c.ClientIP()) - bouncer, err := a.DbClient.SelectBouncerByName(bouncerName) + bouncer, err := a.DbClient.SelectBouncerByName(c.Request.Context(), bouncerName) // This is likely not the proper way, but isNotFound does not seem to work if err != nil && strings.Contains(err.Error(), "bouncer not found") { @@ -94,7 +94,7 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { logger.Infof("Creating bouncer %s", bouncerName) - bouncer, err = a.DbClient.CreateBouncer(bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType) + bouncer, err = a.DbClient.CreateBouncer(c.Request.Context(), bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType) if err != nil { logger.Errorf("while creating bouncer db entry: %s", err) return nil @@ -121,7 +121,7 @@ func (a *APIKey) authPlain(c *gin.Context, logger *log.Entry) *ent.Bouncer { hashStr := HashSHA512(val[0]) - bouncer, err := a.DbClient.SelectBouncer(hashStr) + bouncer, err := a.DbClient.SelectBouncer(c.Request.Context(), hashStr) if err != nil { logger.Errorf("while fetching bouncer info: %s", err) return nil @@ -163,7 +163,7 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { }) if bouncer.IPAddress == "" { - if err := a.DbClient.UpdateBouncerIP(clientIP, bouncer.ID); err != nil { + if err := a.DbClient.UpdateBouncerIP(c.Request.Context(), clientIP, bouncer.ID); err != nil { logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort() @@ -176,7 +176,7 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { if bouncer.IPAddress != clientIP && bouncer.IPAddress != "" && c.Request.Method != http.MethodHead { log.Warningf("new IP address detected for bouncer '%s': %s (old: %s)", bouncer.Name, clientIP, bouncer.IPAddress) - if err := a.DbClient.UpdateBouncerIP(clientIP, bouncer.ID); err != nil { + if err := a.DbClient.UpdateBouncerIP(c.Request.Context(), clientIP, bouncer.ID); err != nil { logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort() @@ -192,7 +192,7 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { } if bouncer.Version != useragent[1] || bouncer.Type != useragent[0] { - if err := a.DbClient.UpdateBouncerTypeAndVersion(useragent[0], useragent[1], bouncer.ID); err != nil { + if err := a.DbClient.UpdateBouncerTypeAndVersion(c.Request.Context(), useragent[0], useragent[1], bouncer.ID); err != nil { logger.Errorf("failed to update bouncer version and type: %s", err) c.JSON(http.StatusForbidden, gin.H{"message": "bad user agent"}) c.Abort() diff --git a/pkg/apiserver/middlewares/v1/jwt.go b/pkg/apiserver/middlewares/v1/jwt.go index 735c5f058cb..2c067853b9d 100644 --- a/pkg/apiserver/middlewares/v1/jwt.go +++ b/pkg/apiserver/middlewares/v1/jwt.go @@ -85,7 +85,7 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { ret.clientMachine, err = j.DbClient.Ent.Machine.Query(). Where(machine.MachineId(ret.machineID)). - First(j.DbClient.CTX) + First(c.Request.Context()) if ent.IsNotFound(err) { // Machine was not found, let's create it log.Infof("machine %s not found, create it", ret.machineID) @@ -102,7 +102,7 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { password := strfmt.Password(pwd) - ret.clientMachine, err = j.DbClient.CreateMachine(&ret.machineID, &password, "", true, true, types.TlsAuthType) + ret.clientMachine, err = j.DbClient.CreateMachine(c.Request.Context(), &ret.machineID, &password, "", true, true, types.TlsAuthType) if err != nil { return nil, fmt.Errorf("while creating machine entry for %s: %w", ret.machineID, err) } @@ -154,7 +154,7 @@ func (j *JWT) authPlain(c *gin.Context) (*authInput, error) { ret.clientMachine, err = j.DbClient.Ent.Machine.Query(). Where(machine.MachineId(ret.machineID)). - First(j.DbClient.CTX) + First(c.Request.Context()) if err != nil { log.Infof("Error machine login for %s : %+v ", ret.machineID, err) return nil, err @@ -209,7 +209,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { } } - err = j.DbClient.UpdateMachineScenarios(scenarios, auth.clientMachine.ID) + err = j.DbClient.UpdateMachineScenarios(c.Request.Context(), scenarios, auth.clientMachine.ID) if err != nil { log.Errorf("Failed to update scenarios list for '%s': %s\n", auth.machineID, err) return nil, jwt.ErrFailedAuthentication @@ -219,7 +219,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { clientIP := c.ClientIP() if auth.clientMachine.IpAddress == "" { - err = j.DbClient.UpdateMachineIP(clientIP, auth.clientMachine.ID) + err = j.DbClient.UpdateMachineIP(c.Request.Context(), clientIP, auth.clientMachine.ID) if err != nil { log.Errorf("Failed to update ip address for '%s': %s\n", auth.machineID, err) return nil, jwt.ErrFailedAuthentication @@ -229,7 +229,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { if auth.clientMachine.IpAddress != clientIP && auth.clientMachine.IpAddress != "" { log.Warningf("new IP address detected for machine '%s': %s (old: %s)", auth.clientMachine.MachineId, clientIP, auth.clientMachine.IpAddress) - err = j.DbClient.UpdateMachineIP(clientIP, auth.clientMachine.ID) + err = j.DbClient.UpdateMachineIP(c.Request.Context(), clientIP, auth.clientMachine.ID) if err != nil { log.Errorf("Failed to update ip address for '%s': %s\n", auth.clientMachine.MachineId, err) return nil, jwt.ErrFailedAuthentication @@ -242,7 +242,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { return nil, jwt.ErrFailedAuthentication } - if err := j.DbClient.UpdateMachineVersion(useragent[1], auth.clientMachine.ID); err != nil { + if err := j.DbClient.UpdateMachineVersion(c.Request.Context(), useragent[1], auth.clientMachine.ID); err != nil { log.Errorf("unable to update machine '%s' version '%s': %s", auth.clientMachine.MachineId, useragent[1], err) log.Errorf("bad user agent from : %s", clientIP) diff --git a/pkg/apiserver/papi.go b/pkg/apiserver/papi.go index 0d0fd0ecd42..60a9416b0bc 100644 --- a/pkg/apiserver/papi.go +++ b/pkg/apiserver/papi.go @@ -26,7 +26,7 @@ var SyncInterval = time.Second * 10 const PapiPullKey = "papi:last_pull" -var operationMap = map[string]func(*Message, *Papi, bool) error{ +var operationMap = map[string]func(context.Context, *Message, *Papi, bool) error{ "decision": DecisionCmd, "alert": AlertCmd, "management": ManagementCmd, @@ -148,7 +148,8 @@ func (p *Papi) handleEvent(event longpollclient.Event, sync bool) error { logger.Debugf("Calling operation '%s'", message.Header.OperationType) - err := operationFunc(message, p, sync) + ctx := context.TODO() + err := operationFunc(ctx, message, p, sync) if err != nil { return fmt.Errorf("'%s %s failed: %w", message.Header.OperationType, message.Header.OperationCmd, err) } @@ -236,7 +237,8 @@ func (p *Papi) Pull() error { lastTimestamp := time.Time{} - lastTimestampStr, err := p.DBClient.GetConfigItem(PapiPullKey) + ctx := context.TODO() + lastTimestampStr, err := p.DBClient.GetConfigItem(ctx, PapiPullKey) if err != nil { p.Logger.Warningf("failed to get last timestamp for papi pull: %s", err) } @@ -248,7 +250,8 @@ func (p *Papi) Pull() error { return fmt.Errorf("failed to marshal last timestamp: %w", err) } - if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil { + ctx := context.TODO() + if err := p.DBClient.SetConfigItem(ctx, PapiPullKey, string(binTime)); err != nil { p.Logger.Errorf("error setting papi pull last key: %s", err) } else { p.Logger.Debugf("config item '%s' set in database with value '%s'", PapiPullKey, string(binTime)) @@ -277,7 +280,8 @@ func (p *Papi) Pull() error { continue } - if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil { + ctx := context.TODO() + if err := p.DBClient.SetConfigItem(ctx, PapiPullKey, string(binTime)); err != nil { return fmt.Errorf("failed to update last timestamp: %w", err) } diff --git a/pkg/apiserver/papi_cmd.go b/pkg/apiserver/papi_cmd.go index a1137161698..30dafdcc726 100644 --- a/pkg/apiserver/papi_cmd.go +++ b/pkg/apiserver/papi_cmd.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "encoding/json" "fmt" "time" @@ -41,7 +42,7 @@ type listUnsubscribe struct { Name string `json:"name"` } -func DecisionCmd(message *Message, p *Papi, sync bool) error { +func DecisionCmd(ctx context.Context, message *Message, p *Papi, sync bool) error { switch message.Header.OperationCmd { case "delete": data, err := json.Marshal(message.Data) @@ -64,7 +65,7 @@ func DecisionCmd(message *Message, p *Papi, sync bool) error { filter := make(map[string][]string) filter["uuid"] = UUIDs - _, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(filter) + _, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(ctx, filter) if err != nil { return fmt.Errorf("unable to expire decisions %+v: %w", UUIDs, err) } @@ -93,7 +94,7 @@ func DecisionCmd(message *Message, p *Papi, sync bool) error { return nil } -func AlertCmd(message *Message, p *Papi, sync bool) error { +func AlertCmd(ctx context.Context, message *Message, p *Papi, sync bool) error { switch message.Header.OperationCmd { case "add": data, err := json.Marshal(message.Data) @@ -152,7 +153,7 @@ func AlertCmd(message *Message, p *Papi, sync bool) error { } // use a different method: alert and/or decision might already be partially present in the database - _, err = p.DBClient.CreateOrUpdateAlert("", alert) + _, err = p.DBClient.CreateOrUpdateAlert(ctx, "", alert) if err != nil { log.Errorf("Failed to create alerts in DB: %s", err) } else { @@ -166,7 +167,7 @@ func AlertCmd(message *Message, p *Papi, sync bool) error { return nil } -func ManagementCmd(message *Message, p *Papi, sync bool) error { +func ManagementCmd(ctx context.Context, message *Message, p *Papi, sync bool) error { if sync { p.Logger.Infof("Ignoring management command from PAPI in sync mode") return nil @@ -194,7 +195,7 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error { filter["origin"] = []string{types.ListOrigin} filter["scenario"] = []string{unsubscribeMsg.Name} - _, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(filter) + _, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(ctx, filter) if err != nil { return fmt.Errorf("unable to expire decisions for list %s : %w", unsubscribeMsg.Name, err) } diff --git a/pkg/database/alerts.go b/pkg/database/alerts.go index 3563adba68c..c54773ba5e7 100644 --- a/pkg/database/alerts.go +++ b/pkg/database/alerts.go @@ -35,12 +35,12 @@ const ( // CreateOrUpdateAlert is specific to PAPI : It checks if alert already exists, otherwise inserts it // if alert already exists, it checks it associated decisions already exists // if some associated decisions are missing (ie. previous insert ended up in error) it inserts them -func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) (string, error) { +func (c *Client) CreateOrUpdateAlert(ctx context.Context, machineID string, alertItem *models.Alert) (string, error) { if alertItem.UUID == "" { return "", errors.New("alert UUID is empty") } - alerts, err := c.Ent.Alert.Query().Where(alert.UUID(alertItem.UUID)).WithDecisions().All(c.CTX) + alerts, err := c.Ent.Alert.Query().Where(alert.UUID(alertItem.UUID)).WithDecisions().All(ctx) if err != nil && !ent.IsNotFound(err) { return "", fmt.Errorf("unable to query alerts for uuid %s: %w", alertItem.UUID, err) @@ -48,7 +48,7 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) // alert wasn't found, insert it (expected hotpath) if ent.IsNotFound(err) || len(alerts) == 0 { - alertIDs, err := c.CreateAlert(machineID, []*models.Alert{alertItem}) + alertIDs, err := c.CreateAlert(ctx, machineID, []*models.Alert{alertItem}) if err != nil { return "", fmt.Errorf("unable to create alert: %w", err) } @@ -165,7 +165,7 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) builderChunks := slicetools.Chunks(decisionBuilders, c.decisionBulkSize) for _, builderChunk := range builderChunks { - decisionsCreateRet, err := c.Ent.Decision.CreateBulk(builderChunk...).Save(c.CTX) + decisionsCreateRet, err := c.Ent.Decision.CreateBulk(builderChunk...).Save(ctx) if err != nil { return "", fmt.Errorf("creating alert decisions: %w", err) } @@ -178,7 +178,7 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) decisionChunks := slicetools.Chunks(decisions, c.decisionBulkSize) for _, decisionChunk := range decisionChunks { - err = c.Ent.Alert.Update().Where(alert.UUID(alertItem.UUID)).AddDecisions(decisionChunk...).Exec(c.CTX) + err = c.Ent.Alert.Update().Where(alert.UUID(alertItem.UUID)).AddDecisions(decisionChunk...).Exec(ctx) if err != nil { return "", fmt.Errorf("updating alert %s: %w", alertItem.UUID, err) } @@ -191,7 +191,7 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) // it takes care of creating the new alert with the associated decisions, and it will as well deleted the "older" overlapping decisions: // 1st pull, you get decisions [1,2,3]. it inserts [1,2,3] // 2nd pull, you get decisions [1,2,3,4]. it inserts [1,2,3,4] and will try to delete [1,2,3,4] with a different alert ID and same origin -func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, int, error) { +func (c *Client) UpdateCommunityBlocklist(ctx context.Context, alertItem *models.Alert) (int, int, int, error) { if alertItem == nil { return 0, 0, 0, errors.New("nil alert") } @@ -243,7 +243,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in SetScenarioVersion(*alertItem.ScenarioVersion). SetScenarioHash(*alertItem.ScenarioHash) - alertRef, err := alertB.Save(c.CTX) + alertRef, err := alertB.Save(ctx) if err != nil { return 0, 0, 0, errors.Wrapf(BulkError, "error creating alert : %s", err) } @@ -252,7 +252,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in return alertRef.ID, 0, 0, nil } - txClient, err := c.Ent.Tx(c.CTX) + txClient, err := c.Ent.Tx(ctx) if err != nil { return 0, 0, 0, errors.Wrapf(BulkError, "error creating transaction : %s", err) } @@ -346,7 +346,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in decision.OriginEQ(DecOrigin), decision.Not(decision.HasOwnerWith(alert.IDEQ(alertRef.ID))), decision.ValueIn(deleteChunk...), - )).Exec(c.CTX) + )).Exec(ctx) if err != nil { rollbackErr := txClient.Rollback() if rollbackErr != nil { @@ -362,7 +362,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in builderChunks := slicetools.Chunks(decisionBuilders, c.decisionBulkSize) for _, builderChunk := range builderChunks { - insertedDecisions, err := txClient.Decision.CreateBulk(builderChunk...).Save(c.CTX) + insertedDecisions, err := txClient.Decision.CreateBulk(builderChunk...).Save(ctx) if err != nil { rollbackErr := txClient.Rollback() if rollbackErr != nil { @@ -390,7 +390,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in return alertRef.ID, inserted, deleted, nil } -func (c *Client) createDecisionChunk(simulated bool, stopAtTime time.Time, decisions []*models.Decision) ([]*ent.Decision, error) { +func (c *Client) createDecisionChunk(ctx context.Context, simulated bool, stopAtTime time.Time, decisions []*models.Decision) ([]*ent.Decision, error) { decisionCreate := []*ent.DecisionCreate{} for _, decisionItem := range decisions { @@ -435,7 +435,7 @@ func (c *Client) createDecisionChunk(simulated bool, stopAtTime time.Time, decis return nil, nil } - ret, err := c.Ent.Decision.CreateBulk(decisionCreate...).Save(c.CTX) + ret, err := c.Ent.Decision.CreateBulk(decisionCreate...).Save(ctx) if err != nil { return nil, err } @@ -443,7 +443,7 @@ func (c *Client) createDecisionChunk(simulated bool, stopAtTime time.Time, decis return ret, nil } -func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts []*models.Alert) ([]string, error) { +func (c *Client) createAlertChunk(ctx context.Context, machineID string, owner *ent.Machine, alerts []*models.Alert) ([]string, error) { alertBuilders := []*ent.AlertCreate{} alertDecisions := [][]*ent.Decision{} @@ -539,7 +539,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ c.Log.Warningf("dropped 'serialized' field (machine %s / scenario %s)", machineID, *alertItem.Scenario) } - events, err = c.Ent.Event.CreateBulk(eventBulk...).Save(c.CTX) + events, err = c.Ent.Event.CreateBulk(eventBulk...).Save(ctx) if err != nil { return nil, errors.Wrapf(BulkError, "creating alert events: %s", err) } @@ -568,7 +568,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ SetValue(value) } - metas, err = c.Ent.Meta.CreateBulk(metaBulk...).Save(c.CTX) + metas, err = c.Ent.Meta.CreateBulk(metaBulk...).Save(ctx) if err != nil { c.Log.Warningf("error creating alert meta: %s", err) } @@ -578,7 +578,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ decisionChunks := slicetools.Chunks(alertItem.Decisions, c.decisionBulkSize) for _, decisionChunk := range decisionChunks { - decisionRet, err := c.createDecisionChunk(*alertItem.Simulated, stopAtTime, decisionChunk) + decisionRet, err := c.createDecisionChunk(ctx, *alertItem.Simulated, stopAtTime, decisionChunk) if err != nil { return nil, fmt.Errorf("creating alert decisions: %w", err) } @@ -635,7 +635,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ return nil, nil } - alertsCreateBulk, err := c.Ent.Alert.CreateBulk(alertBuilders...).Save(c.CTX) + alertsCreateBulk, err := c.Ent.Alert.CreateBulk(alertBuilders...).Save(ctx) if err != nil { return nil, errors.Wrapf(BulkError, "bulk creating alert : %s", err) } @@ -652,7 +652,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ for retry < maxLockRetries { // so much for the happy path... but sqlite3 errors work differently - _, err := c.Ent.Alert.Update().Where(alert.IDEQ(a.ID)).AddDecisions(d2...).Save(c.CTX) + _, err := c.Ent.Alert.Update().Where(alert.IDEQ(a.ID)).AddDecisions(d2...).Save(ctx) if err == nil { break } @@ -681,14 +681,14 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ return ret, nil } -func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]string, error) { +func (c *Client) CreateAlert(ctx context.Context, machineID string, alertList []*models.Alert) ([]string, error) { var ( owner *ent.Machine err error ) if machineID != "" { - owner, err = c.QueryMachineByID(machineID) + owner, err = c.QueryMachineByID(ctx, machineID) if err != nil { if !errors.Is(err, UserNotExists) { return nil, fmt.Errorf("machine '%s': %w", machineID, err) @@ -706,7 +706,7 @@ func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]str alertIDs := []string{} for _, alertChunk := range alertChunks { - ids, err := c.createAlertChunk(machineID, owner, alertChunk) + ids, err := c.createAlertChunk(ctx, machineID, owner, alertChunk) if err != nil { return nil, fmt.Errorf("machine '%s': %w", machineID, err) } @@ -715,7 +715,7 @@ func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]str } if owner != nil { - err = owner.Update().SetLastPush(time.Now().UTC()).Exec(c.CTX) + err = owner.Update().SetLastPush(time.Now().UTC()).Exec(ctx) if err != nil { return nil, fmt.Errorf("machine '%s': %w", machineID, err) } @@ -970,11 +970,11 @@ func (c *Client) AlertsCountPerScenario(filters map[string][]string) (map[string return counts, nil } -func (c *Client) TotalAlerts() (int, error) { - return c.Ent.Alert.Query().Count(c.CTX) +func (c *Client) TotalAlerts(ctx context.Context) (int, error) { + return c.Ent.Alert.Query().Count(ctx) } -func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, error) { +func (c *Client) QueryAlertWithFilter(ctx context.Context, filter map[string][]string) ([]*ent.Alert, error) { sort := "DESC" // we sort by desc by default if val, ok := filter["sort"]; ok { @@ -1021,7 +1021,7 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, WithOwner() if limit == 0 { - limit, err = alerts.Count(c.CTX) + limit, err = alerts.Count(ctx) if err != nil { return nil, fmt.Errorf("unable to count nb alerts: %w", err) } @@ -1033,7 +1033,7 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, alerts = alerts.Order(ent.Desc(alert.FieldCreatedAt), ent.Desc(alert.FieldID)) } - result, err := alerts.Limit(paginationSize).Offset(offset).All(c.CTX) + result, err := alerts.Limit(paginationSize).Offset(offset).All(ctx) if err != nil { return nil, errors.Wrapf(QueryFail, "pagination size: %d, offset: %d: %s", paginationSize, offset, err) } @@ -1062,35 +1062,35 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, return ret, nil } -func (c *Client) DeleteAlertGraphBatch(alertItems []*ent.Alert) (int, error) { +func (c *Client) DeleteAlertGraphBatch(ctx context.Context, alertItems []*ent.Alert) (int, error) { idList := make([]int, 0) for _, alert := range alertItems { idList = append(idList, alert.ID) } _, err := c.Ent.Event.Delete(). - Where(event.HasOwnerWith(alert.IDIn(idList...))).Exec(c.CTX) + Where(event.HasOwnerWith(alert.IDIn(idList...))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraphBatch : %s", err) return 0, errors.Wrapf(DeleteFail, "alert graph delete batch events") } _, err = c.Ent.Meta.Delete(). - Where(meta.HasOwnerWith(alert.IDIn(idList...))).Exec(c.CTX) + Where(meta.HasOwnerWith(alert.IDIn(idList...))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraphBatch : %s", err) return 0, errors.Wrapf(DeleteFail, "alert graph delete batch meta") } _, err = c.Ent.Decision.Delete(). - Where(decision.HasOwnerWith(alert.IDIn(idList...))).Exec(c.CTX) + Where(decision.HasOwnerWith(alert.IDIn(idList...))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraphBatch : %s", err) return 0, errors.Wrapf(DeleteFail, "alert graph delete batch decisions") } deleted, err := c.Ent.Alert.Delete(). - Where(alert.IDIn(idList...)).Exec(c.CTX) + Where(alert.IDIn(idList...)).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraphBatch : %s", err) return deleted, errors.Wrapf(DeleteFail, "alert graph delete batch") @@ -1101,10 +1101,10 @@ func (c *Client) DeleteAlertGraphBatch(alertItems []*ent.Alert) (int, error) { return deleted, nil } -func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error { +func (c *Client) DeleteAlertGraph(ctx context.Context, alertItem *ent.Alert) error { // delete the associated events _, err := c.Ent.Event.Delete(). - Where(event.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(c.CTX) + Where(event.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraph : %s", err) return errors.Wrapf(DeleteFail, "event with alert ID '%d'", alertItem.ID) @@ -1112,7 +1112,7 @@ func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error { // delete the associated meta _, err = c.Ent.Meta.Delete(). - Where(meta.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(c.CTX) + Where(meta.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraph : %s", err) return errors.Wrapf(DeleteFail, "meta with alert ID '%d'", alertItem.ID) @@ -1120,14 +1120,14 @@ func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error { // delete the associated decisions _, err = c.Ent.Decision.Delete(). - Where(decision.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(c.CTX) + Where(decision.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraph : %s", err) return errors.Wrapf(DeleteFail, "decision with alert ID '%d'", alertItem.ID) } // delete the alert - err = c.Ent.Alert.DeleteOne(alertItem).Exec(c.CTX) + err = c.Ent.Alert.DeleteOne(alertItem).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraph : %s", err) return errors.Wrapf(DeleteFail, "alert with ID '%d'", alertItem.ID) @@ -1136,26 +1136,26 @@ func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error { return nil } -func (c *Client) DeleteAlertByID(id int) error { - alertItem, err := c.Ent.Alert.Query().Where(alert.IDEQ(id)).Only(c.CTX) +func (c *Client) DeleteAlertByID(ctx context.Context, id int) error { + alertItem, err := c.Ent.Alert.Query().Where(alert.IDEQ(id)).Only(ctx) if err != nil { return err } - return c.DeleteAlertGraph(alertItem) + return c.DeleteAlertGraph(ctx, alertItem) } -func (c *Client) DeleteAlertWithFilter(filter map[string][]string) (int, error) { +func (c *Client) DeleteAlertWithFilter(ctx context.Context, filter map[string][]string) (int, error) { preds, err := AlertPredicatesFromFilter(filter) if err != nil { return 0, err } - return c.Ent.Alert.Delete().Where(preds...).Exec(c.CTX) + return c.Ent.Alert.Delete().Where(preds...).Exec(ctx) } -func (c *Client) GetAlertByID(alertID int) (*ent.Alert, error) { - alert, err := c.Ent.Alert.Query().Where(alert.IDEQ(alertID)).WithDecisions().WithEvents().WithMetas().WithOwner().First(c.CTX) +func (c *Client) GetAlertByID(ctx context.Context, alertID int) (*ent.Alert, error) { + alert, err := c.Ent.Alert.Query().Where(alert.IDEQ(alertID)).WithDecisions().WithEvents().WithMetas().WithOwner().First(ctx) if err != nil { /*record not found, 404*/ if ent.IsNotFound(err) { diff --git a/pkg/database/bouncers.go b/pkg/database/bouncers.go index 2cc6b9dcb47..ed330a775f6 100644 --- a/pkg/database/bouncers.go +++ b/pkg/database/bouncers.go @@ -1,6 +1,7 @@ package database import ( + "context" "fmt" "time" @@ -10,8 +11,8 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer" ) -func (c *Client) SelectBouncer(apiKeyHash string) (*ent.Bouncer, error) { - result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash)).First(c.CTX) +func (c *Client) SelectBouncer(ctx context.Context, apiKeyHash string) (*ent.Bouncer, error) { + result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash)).First(ctx) if err != nil { return nil, err } @@ -19,8 +20,8 @@ func (c *Client) SelectBouncer(apiKeyHash string) (*ent.Bouncer, error) { return result, nil } -func (c *Client) SelectBouncerByName(bouncerName string) (*ent.Bouncer, error) { - result, err := c.Ent.Bouncer.Query().Where(bouncer.NameEQ(bouncerName)).First(c.CTX) +func (c *Client) SelectBouncerByName(ctx context.Context, bouncerName string) (*ent.Bouncer, error) { + result, err := c.Ent.Bouncer.Query().Where(bouncer.NameEQ(bouncerName)).First(ctx) if err != nil { return nil, err } @@ -28,8 +29,8 @@ func (c *Client) SelectBouncerByName(bouncerName string) (*ent.Bouncer, error) { return result, nil } -func (c *Client) ListBouncers() ([]*ent.Bouncer, error) { - result, err := c.Ent.Bouncer.Query().All(c.CTX) +func (c *Client) ListBouncers(ctx context.Context) ([]*ent.Bouncer, error) { + result, err := c.Ent.Bouncer.Query().All(ctx) if err != nil { return nil, errors.Wrapf(QueryFail, "listing bouncers: %s", err) } @@ -37,14 +38,14 @@ func (c *Client) ListBouncers() ([]*ent.Bouncer, error) { return result, nil } -func (c *Client) CreateBouncer(name string, ipAddr string, apiKey string, authType string) (*ent.Bouncer, error) { +func (c *Client) CreateBouncer(ctx context.Context, name string, ipAddr string, apiKey string, authType string) (*ent.Bouncer, error) { bouncer, err := c.Ent.Bouncer. Create(). SetName(name). SetAPIKey(apiKey). SetRevoked(false). SetAuthType(authType). - Save(c.CTX) + Save(ctx) if err != nil { if ent.IsConstraintError(err) { return nil, fmt.Errorf("bouncer %s already exists", name) @@ -56,11 +57,11 @@ func (c *Client) CreateBouncer(name string, ipAddr string, apiKey string, authTy return bouncer, nil } -func (c *Client) DeleteBouncer(name string) error { +func (c *Client) DeleteBouncer(ctx context.Context, name string) error { nbDeleted, err := c.Ent.Bouncer. Delete(). Where(bouncer.NameEQ(name)). - Exec(c.CTX) + Exec(ctx) if err != nil { return err } @@ -72,13 +73,13 @@ func (c *Client) DeleteBouncer(name string) error { return nil } -func (c *Client) BulkDeleteBouncers(bouncers []*ent.Bouncer) (int, error) { +func (c *Client) BulkDeleteBouncers(ctx context.Context, bouncers []*ent.Bouncer) (int, error) { ids := make([]int, len(bouncers)) for i, b := range bouncers { ids[i] = b.ID } - nbDeleted, err := c.Ent.Bouncer.Delete().Where(bouncer.IDIn(ids...)).Exec(c.CTX) + nbDeleted, err := c.Ent.Bouncer.Delete().Where(bouncer.IDIn(ids...)).Exec(ctx) if err != nil { return nbDeleted, fmt.Errorf("unable to delete bouncers: %w", err) } @@ -86,10 +87,10 @@ func (c *Client) BulkDeleteBouncers(bouncers []*ent.Bouncer) (int, error) { return nbDeleted, nil } -func (c *Client) UpdateBouncerLastPull(lastPull time.Time, id int) error { +func (c *Client) UpdateBouncerLastPull(ctx context.Context, lastPull time.Time, id int) error { _, err := c.Ent.Bouncer.UpdateOneID(id). SetLastPull(lastPull). - Save(c.CTX) + Save(ctx) if err != nil { return fmt.Errorf("unable to update machine last pull in database: %w", err) } @@ -97,8 +98,8 @@ func (c *Client) UpdateBouncerLastPull(lastPull time.Time, id int) error { return nil } -func (c *Client) UpdateBouncerIP(ipAddr string, id int) error { - _, err := c.Ent.Bouncer.UpdateOneID(id).SetIPAddress(ipAddr).Save(c.CTX) +func (c *Client) UpdateBouncerIP(ctx context.Context, ipAddr string, id int) error { + _, err := c.Ent.Bouncer.UpdateOneID(id).SetIPAddress(ipAddr).Save(ctx) if err != nil { return fmt.Errorf("unable to update bouncer ip address in database: %w", err) } @@ -106,8 +107,8 @@ func (c *Client) UpdateBouncerIP(ipAddr string, id int) error { return nil } -func (c *Client) UpdateBouncerTypeAndVersion(bType string, version string, id int) error { - _, err := c.Ent.Bouncer.UpdateOneID(id).SetVersion(version).SetType(bType).Save(c.CTX) +func (c *Client) UpdateBouncerTypeAndVersion(ctx context.Context, bType string, version string, id int) error { + _, err := c.Ent.Bouncer.UpdateOneID(id).SetVersion(version).SetType(bType).Save(ctx) if err != nil { return fmt.Errorf("unable to update bouncer type and version in database: %w", err) } @@ -115,6 +116,6 @@ func (c *Client) UpdateBouncerTypeAndVersion(bType string, version string, id in return nil } -func (c *Client) QueryBouncersLastPulltimeLT(t time.Time) ([]*ent.Bouncer, error) { - return c.Ent.Bouncer.Query().Where(bouncer.LastPullLT(t)).All(c.CTX) +func (c *Client) QueryBouncersLastPulltimeLT(ctx context.Context, t time.Time) ([]*ent.Bouncer, error) { + return c.Ent.Bouncer.Query().Where(bouncer.LastPullLT(t)).All(ctx) } diff --git a/pkg/database/config.go b/pkg/database/config.go index 8c3578ad596..9a9e23b84f5 100644 --- a/pkg/database/config.go +++ b/pkg/database/config.go @@ -1,14 +1,15 @@ package database import ( + "context" "github.com/pkg/errors" "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" ) -func (c *Client) GetConfigItem(key string) (*string, error) { - result, err := c.Ent.ConfigItem.Query().Where(configitem.NameEQ(key)).First(c.CTX) +func (c *Client) GetConfigItem(ctx context.Context, key string) (*string, error) { + result, err := c.Ent.ConfigItem.Query().Where(configitem.NameEQ(key)).First(ctx) if err != nil && ent.IsNotFound(err) { return nil, nil } @@ -19,11 +20,11 @@ func (c *Client) GetConfigItem(key string) (*string, error) { return &result.Value, nil } -func (c *Client) SetConfigItem(key string, value string) error { +func (c *Client) SetConfigItem(ctx context.Context, key string, value string) error { - nbUpdated, err := c.Ent.ConfigItem.Update().SetValue(value).Where(configitem.NameEQ(key)).Save(c.CTX) + nbUpdated, err := c.Ent.ConfigItem.Update().SetValue(value).Where(configitem.NameEQ(key)).Save(ctx) if (err != nil && ent.IsNotFound(err)) || nbUpdated == 0 { //not found, create - err := c.Ent.ConfigItem.Create().SetName(key).SetValue(value).Exec(c.CTX) + err := c.Ent.ConfigItem.Create().SetName(key).SetValue(value).Exec(ctx) if err != nil { return errors.Wrapf(QueryFail, "insert config item: %s", err) } diff --git a/pkg/database/database.go b/pkg/database/database.go index 6f392c46d21..ad532bee9d4 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -23,7 +23,6 @@ import ( type Client struct { Ent *ent.Client - CTX context.Context Log *log.Logger CanFlush bool Type string @@ -109,7 +108,6 @@ func NewClient(ctx context.Context, config *csconfig.DatabaseCfg) (*Client, erro return &Client{ Ent: client, - CTX: ctx, Log: clog, CanFlush: true, Type: config.Type, diff --git a/pkg/database/decisions.go b/pkg/database/decisions.go index 294515d603e..8a21001eaa3 100644 --- a/pkg/database/decisions.go +++ b/pkg/database/decisions.go @@ -1,6 +1,7 @@ package database import ( + "context" "fmt" "strconv" "strings" @@ -120,7 +121,7 @@ func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string] return query, nil } -func (c *Client) QueryAllDecisionsWithFilters(filters map[string][]string) ([]*ent.Decision, error) { +func (c *Client) QueryAllDecisionsWithFilters(ctx context.Context, filters map[string][]string) ([]*ent.Decision, error) { query := c.Ent.Decision.Query().Where( decision.UntilGT(time.Now().UTC()), ) @@ -137,7 +138,7 @@ func (c *Client) QueryAllDecisionsWithFilters(filters map[string][]string) ([]*e query = query.Order(ent.Asc(decision.FieldID)) - data, err := query.All(c.CTX) + data, err := query.All(ctx) if err != nil { c.Log.Warningf("QueryAllDecisionsWithFilters : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "get all decisions with filters") @@ -146,7 +147,7 @@ func (c *Client) QueryAllDecisionsWithFilters(filters map[string][]string) ([]*e return data, nil } -func (c *Client) QueryExpiredDecisionsWithFilters(filters map[string][]string) ([]*ent.Decision, error) { +func (c *Client) QueryExpiredDecisionsWithFilters(ctx context.Context, filters map[string][]string) ([]*ent.Decision, error) { query := c.Ent.Decision.Query().Where( decision.UntilLT(time.Now().UTC()), ) @@ -164,7 +165,7 @@ func (c *Client) QueryExpiredDecisionsWithFilters(filters map[string][]string) ( return []*ent.Decision{}, errors.Wrap(QueryFail, "get expired decisions with filters") } - data, err := query.All(c.CTX) + data, err := query.All(ctx) if err != nil { c.Log.Warningf("QueryExpiredDecisionsWithFilters : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions") @@ -173,7 +174,7 @@ func (c *Client) QueryExpiredDecisionsWithFilters(filters map[string][]string) ( return data, nil } -func (c *Client) QueryDecisionCountByScenario() ([]*DecisionsByScenario, error) { +func (c *Client) QueryDecisionCountByScenario(ctx context.Context) ([]*DecisionsByScenario, error) { query := c.Ent.Decision.Query().Where( decision.UntilGT(time.Now().UTC()), ) @@ -186,7 +187,7 @@ func (c *Client) QueryDecisionCountByScenario() ([]*DecisionsByScenario, error) var r []*DecisionsByScenario - err = query.GroupBy(decision.FieldScenario, decision.FieldOrigin, decision.FieldType).Aggregate(ent.Count()).Scan(c.CTX, &r) + err = query.GroupBy(decision.FieldScenario, decision.FieldOrigin, decision.FieldType).Aggregate(ent.Count()).Scan(ctx, &r) if err != nil { c.Log.Warningf("QueryDecisionCountByScenario : %s", err) return nil, errors.Wrap(QueryFail, "count all decisions with filters") @@ -195,7 +196,7 @@ func (c *Client) QueryDecisionCountByScenario() ([]*DecisionsByScenario, error) return r, nil } -func (c *Client) QueryDecisionWithFilter(filter map[string][]string) ([]*ent.Decision, error) { +func (c *Client) QueryDecisionWithFilter(ctx context.Context, filter map[string][]string) ([]*ent.Decision, error) { var data []*ent.Decision var err error @@ -217,7 +218,7 @@ func (c *Client) QueryDecisionWithFilter(filter map[string][]string) ([]*ent.Dec decision.FieldValue, decision.FieldScope, decision.FieldOrigin, - ).Scan(c.CTX, &data) + ).Scan(ctx, &data) if err != nil { c.Log.Warningf("QueryDecisionWithFilter : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "query decision failed") @@ -254,7 +255,7 @@ func longestDecisionForScopeTypeValue(s *sql.Selector) { ) } -func (c *Client) QueryExpiredDecisionsSinceWithFilters(since time.Time, filters map[string][]string) ([]*ent.Decision, error) { +func (c *Client) QueryExpiredDecisionsSinceWithFilters(ctx context.Context, since time.Time, filters map[string][]string) ([]*ent.Decision, error) { query := c.Ent.Decision.Query().Where( decision.UntilLT(time.Now().UTC()), decision.UntilGT(since), @@ -272,7 +273,7 @@ func (c *Client) QueryExpiredDecisionsSinceWithFilters(since time.Time, filters query = query.Order(ent.Asc(decision.FieldID)) - data, err := query.All(c.CTX) + data, err := query.All(ctx) if err != nil { c.Log.Warningf("QueryExpiredDecisionsSinceWithFilters : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions with filters") @@ -281,7 +282,7 @@ func (c *Client) QueryExpiredDecisionsSinceWithFilters(since time.Time, filters return data, nil } -func (c *Client) QueryNewDecisionsSinceWithFilters(since time.Time, filters map[string][]string) ([]*ent.Decision, error) { +func (c *Client) QueryNewDecisionsSinceWithFilters(ctx context.Context, since time.Time, filters map[string][]string) ([]*ent.Decision, error) { query := c.Ent.Decision.Query().Where( decision.CreatedAtGT(since), decision.UntilGT(time.Now().UTC()), @@ -300,7 +301,7 @@ func (c *Client) QueryNewDecisionsSinceWithFilters(since time.Time, filters map[ query = query.Order(ent.Asc(decision.FieldID)) - data, err := query.All(c.CTX) + data, err := query.All(ctx) if err != nil { c.Log.Warningf("QueryNewDecisionsSinceWithFilters : %s", err) return []*ent.Decision{}, errors.Wrapf(QueryFail, "new decisions since '%s'", since.String()) @@ -309,20 +310,20 @@ func (c *Client) QueryNewDecisionsSinceWithFilters(since time.Time, filters map[ return data, nil } -func (c *Client) DeleteDecisionById(decisionID int) ([]*ent.Decision, error) { - toDelete, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(c.CTX) +func (c *Client) DeleteDecisionById(ctx context.Context, decisionID int) ([]*ent.Decision, error) { + toDelete, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(ctx) if err != nil { c.Log.Warningf("DeleteDecisionById : %s", err) return nil, errors.Wrapf(DeleteFail, "decision with id '%d' doesn't exist", decisionID) } - count, err := c.DeleteDecisions(toDelete) + count, err := c.DeleteDecisions(ctx, toDelete) c.Log.Debugf("deleted %d decisions", count) return toDelete, err } -func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string, []*ent.Decision, error) { +func (c *Client) DeleteDecisionsWithFilter(ctx context.Context, filter map[string][]string) (string, []*ent.Decision, error) { var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz int @@ -425,13 +426,13 @@ func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string, return "0", nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz) } - toDelete, err := decisions.All(c.CTX) + toDelete, err := decisions.All(ctx) if err != nil { c.Log.Warningf("DeleteDecisionsWithFilter : %s", err) return "0", nil, errors.Wrap(DeleteFail, "decisions with provided filter") } - count, err := c.DeleteDecisions(toDelete) + count, err := c.DeleteDecisions(ctx, toDelete) if err != nil { c.Log.Warningf("While deleting decisions : %s", err) return "0", nil, errors.Wrap(DeleteFail, "decisions with provided filter") @@ -441,7 +442,7 @@ func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string, } // ExpireDecisionsWithFilter updates the expiration time to now() for the decisions matching the filter, and returns the updated items -func (c *Client) ExpireDecisionsWithFilter(filter map[string][]string) (string, []*ent.Decision, error) { +func (c *Client) ExpireDecisionsWithFilter(ctx context.Context, filter map[string][]string) (string, []*ent.Decision, error) { var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz int @@ -550,13 +551,13 @@ func (c *Client) ExpireDecisionsWithFilter(filter map[string][]string) (string, return "0", nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz) } - DecisionsToDelete, err := decisions.All(c.CTX) + DecisionsToDelete, err := decisions.All(ctx) if err != nil { c.Log.Warningf("ExpireDecisionsWithFilter : %s", err) return "0", nil, errors.Wrap(DeleteFail, "expire decisions with provided filter") } - count, err := c.ExpireDecisions(DecisionsToDelete) + count, err := c.ExpireDecisions(ctx, DecisionsToDelete) if err != nil { return "0", nil, errors.Wrapf(DeleteFail, "expire decisions with provided filter : %s", err) } @@ -575,13 +576,13 @@ func decisionIDs(decisions []*ent.Decision) []int { // ExpireDecisions sets the expiration of a list of decisions to now() // It returns the number of impacted decisions for the CAPI/PAPI -func (c *Client) ExpireDecisions(decisions []*ent.Decision) (int, error) { +func (c *Client) ExpireDecisions(ctx context.Context, decisions []*ent.Decision) (int, error) { if len(decisions) <= decisionDeleteBulkSize { ids := decisionIDs(decisions) rows, err := c.Ent.Decision.Update().Where( decision.IDIn(ids...), - ).SetUntil(time.Now().UTC()).Save(c.CTX) + ).SetUntil(time.Now().UTC()).Save(ctx) if err != nil { return 0, fmt.Errorf("expire decisions with provided filter: %w", err) } @@ -594,7 +595,7 @@ func (c *Client) ExpireDecisions(decisions []*ent.Decision) (int, error) { total := 0 for _, chunk := range slicetools.Chunks(decisions, decisionDeleteBulkSize) { - rows, err := c.ExpireDecisions(chunk) + rows, err := c.ExpireDecisions(ctx, chunk) if err != nil { return total, err } @@ -607,13 +608,13 @@ func (c *Client) ExpireDecisions(decisions []*ent.Decision) (int, error) { // DeleteDecisions removes a list of decisions from the database // It returns the number of impacted decisions for the CAPI/PAPI -func (c *Client) DeleteDecisions(decisions []*ent.Decision) (int, error) { +func (c *Client) DeleteDecisions(ctx context.Context, decisions []*ent.Decision) (int, error) { if len(decisions) < decisionDeleteBulkSize { ids := decisionIDs(decisions) rows, err := c.Ent.Decision.Delete().Where( decision.IDIn(ids...), - ).Exec(c.CTX) + ).Exec(ctx) if err != nil { return 0, fmt.Errorf("hard delete decisions with provided filter: %w", err) } @@ -626,7 +627,7 @@ func (c *Client) DeleteDecisions(decisions []*ent.Decision) (int, error) { tot := 0 for _, chunk := range slicetools.Chunks(decisions, decisionDeleteBulkSize) { - rows, err := c.DeleteDecisions(chunk) + rows, err := c.DeleteDecisions(ctx, chunk) if err != nil { return tot, err } @@ -638,8 +639,8 @@ func (c *Client) DeleteDecisions(decisions []*ent.Decision) (int, error) { } // ExpireDecision set the expiration of a decision to now() -func (c *Client) ExpireDecisionByID(decisionID int) (int, []*ent.Decision, error) { - toUpdate, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(c.CTX) +func (c *Client) ExpireDecisionByID(ctx context.Context, decisionID int) (int, []*ent.Decision, error) { + toUpdate, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(ctx) // XXX: do we want 500 or 404 here? if err != nil || len(toUpdate) == 0 { @@ -651,12 +652,12 @@ func (c *Client) ExpireDecisionByID(decisionID int) (int, []*ent.Decision, error return 0, nil, ItemNotFound } - count, err := c.ExpireDecisions(toUpdate) + count, err := c.ExpireDecisions(ctx, toUpdate) return count, toUpdate, err } -func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) { +func (c *Client) CountDecisionsByValue(ctx context.Context, decisionValue string) (int, error) { var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz, count int @@ -674,7 +675,7 @@ func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) { return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter") } - count, err = decisions.Count(c.CTX) + count, err = decisions.Count(ctx) if err != nil { return 0, errors.Wrapf(err, "fail to count decisions") } @@ -682,7 +683,7 @@ func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) { return count, nil } -func (c *Client) CountActiveDecisionsByValue(decisionValue string) (int, error) { +func (c *Client) CountActiveDecisionsByValue(ctx context.Context, decisionValue string) (int, error) { var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz, count int @@ -702,7 +703,7 @@ func (c *Client) CountActiveDecisionsByValue(decisionValue string) (int, error) decisions = decisions.Where(decision.UntilGT(time.Now().UTC())) - count, err = decisions.Count(c.CTX) + count, err = decisions.Count(ctx) if err != nil { return 0, fmt.Errorf("fail to count decisions: %w", err) } @@ -710,7 +711,7 @@ func (c *Client) CountActiveDecisionsByValue(decisionValue string) (int, error) return count, nil } -func (c *Client) GetActiveDecisionsTimeLeftByValue(decisionValue string) (time.Duration, error) { +func (c *Client) GetActiveDecisionsTimeLeftByValue(ctx context.Context, decisionValue string) (time.Duration, error) { var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz int @@ -732,7 +733,7 @@ func (c *Client) GetActiveDecisionsTimeLeftByValue(decisionValue string) (time.D decisions = decisions.Order(ent.Desc(decision.FieldUntil)) - decision, err := decisions.First(c.CTX) + decision, err := decisions.First(ctx) if err != nil && !ent.IsNotFound(err) { return 0, fmt.Errorf("fail to get decision: %w", err) } @@ -744,7 +745,7 @@ func (c *Client) GetActiveDecisionsTimeLeftByValue(decisionValue string) (time.D return decision.Until.Sub(time.Now().UTC()), nil } -func (c *Client) CountDecisionsSinceByValue(decisionValue string, since time.Time) (int, error) { +func (c *Client) CountDecisionsSinceByValue(ctx context.Context, decisionValue string, since time.Time) (int, error) { ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(decisionValue) if err != nil { return 0, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", decisionValue, err) @@ -760,7 +761,7 @@ func (c *Client) CountDecisionsSinceByValue(decisionValue string, since time.Tim return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter") } - count, err := decisions.Count(c.CTX) + count, err := decisions.Count(ctx) if err != nil { return 0, errors.Wrapf(err, "fail to count decisions") } diff --git a/pkg/database/flush.go b/pkg/database/flush.go index 56e42715b2c..5063da1265f 100644 --- a/pkg/database/flush.go +++ b/pkg/database/flush.go @@ -1,6 +1,7 @@ package database import ( + "context" "errors" "fmt" "time" @@ -17,7 +18,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -func (c *Client) StartFlushScheduler(config *csconfig.FlushDBCfg) (*gocron.Scheduler, error) { +func (c *Client) StartFlushScheduler(ctx context.Context, config *csconfig.FlushDBCfg) (*gocron.Scheduler, error) { maxItems := 0 maxAge := "" @@ -36,7 +37,7 @@ func (c *Client) StartFlushScheduler(config *csconfig.FlushDBCfg) (*gocron.Sched // Init & Start cronjob every minute for alerts scheduler := gocron.NewScheduler(time.UTC) - job, err := scheduler.Every(1).Minute().Do(c.FlushAlerts, maxAge, maxItems) + job, err := scheduler.Every(1).Minute().Do(c.FlushAlerts, ctx, maxAge, maxItems) if err != nil { return nil, fmt.Errorf("while starting FlushAlerts scheduler: %w", err) } @@ -91,7 +92,7 @@ func (c *Client) StartFlushScheduler(config *csconfig.FlushDBCfg) (*gocron.Sched } } - baJob, err := scheduler.Every(1).Minute().Do(c.FlushAgentsAndBouncers, config.AgentsGC, config.BouncersGC) + baJob, err := scheduler.Every(1).Minute().Do(c.FlushAgentsAndBouncers, ctx, config.AgentsGC, config.BouncersGC) if err != nil { return nil, fmt.Errorf("while starting FlushAgentsAndBouncers scheduler: %w", err) } @@ -102,10 +103,10 @@ func (c *Client) StartFlushScheduler(config *csconfig.FlushDBCfg) (*gocron.Sched return scheduler, nil } -func (c *Client) FlushOrphans() { +func (c *Client) FlushOrphans(ctx context.Context) { /* While it has only been linked to some very corner-case bug : https://github.com/crowdsecurity/crowdsec/issues/778 */ /* We want to take care of orphaned events for which the parent alert/decision has been deleted */ - eventsCount, err := c.Ent.Event.Delete().Where(event.Not(event.HasOwner())).Exec(c.CTX) + eventsCount, err := c.Ent.Event.Delete().Where(event.Not(event.HasOwner())).Exec(ctx) if err != nil { c.Log.Warningf("error while deleting orphan events: %s", err) return @@ -116,7 +117,7 @@ func (c *Client) FlushOrphans() { } eventsCount, err = c.Ent.Decision.Delete().Where( - decision.Not(decision.HasOwner())).Where(decision.UntilLTE(time.Now().UTC())).Exec(c.CTX) + decision.Not(decision.HasOwner())).Where(decision.UntilLTE(time.Now().UTC())).Exec(ctx) if err != nil { c.Log.Warningf("error while deleting orphan decisions: %s", err) @@ -128,7 +129,7 @@ func (c *Client) FlushOrphans() { } } -func (c *Client) flushBouncers(authType string, duration *time.Duration) { +func (c *Client) flushBouncers(ctx context.Context, authType string, duration *time.Duration) { if duration == nil { return } @@ -137,7 +138,7 @@ func (c *Client) flushBouncers(authType string, duration *time.Duration) { bouncer.LastPullLTE(time.Now().UTC().Add(-*duration)), ).Where( bouncer.AuthTypeEQ(authType), - ).Exec(c.CTX) + ).Exec(ctx) if err != nil { c.Log.Errorf("while auto-deleting expired bouncers (%s): %s", authType, err) @@ -149,7 +150,7 @@ func (c *Client) flushBouncers(authType string, duration *time.Duration) { } } -func (c *Client) flushAgents(authType string, duration *time.Duration) { +func (c *Client) flushAgents(ctx context.Context, authType string, duration *time.Duration) { if duration == nil { return } @@ -158,7 +159,7 @@ func (c *Client) flushAgents(authType string, duration *time.Duration) { machine.LastHeartbeatLTE(time.Now().UTC().Add(-*duration)), machine.Not(machine.HasAlerts()), machine.AuthTypeEQ(authType), - ).Exec(c.CTX) + ).Exec(ctx) if err != nil { c.Log.Errorf("while auto-deleting expired machines (%s): %s", authType, err) @@ -170,23 +171,23 @@ func (c *Client) flushAgents(authType string, duration *time.Duration) { } } -func (c *Client) FlushAgentsAndBouncers(agentsCfg *csconfig.AuthGCCfg, bouncersCfg *csconfig.AuthGCCfg) error { +func (c *Client) FlushAgentsAndBouncers(ctx context.Context, agentsCfg *csconfig.AuthGCCfg, bouncersCfg *csconfig.AuthGCCfg) error { log.Debug("starting FlushAgentsAndBouncers") if agentsCfg != nil { - c.flushAgents(types.TlsAuthType, agentsCfg.CertDuration) - c.flushAgents(types.PasswordAuthType, agentsCfg.LoginPasswordDuration) + c.flushAgents(ctx, types.TlsAuthType, agentsCfg.CertDuration) + c.flushAgents(ctx, types.PasswordAuthType, agentsCfg.LoginPasswordDuration) } if bouncersCfg != nil { - c.flushBouncers(types.TlsAuthType, bouncersCfg.CertDuration) - c.flushBouncers(types.ApiKeyAuthType, bouncersCfg.ApiDuration) + c.flushBouncers(ctx, types.TlsAuthType, bouncersCfg.CertDuration) + c.flushBouncers(ctx, types.ApiKeyAuthType, bouncersCfg.ApiDuration) } return nil } -func (c *Client) FlushAlerts(MaxAge string, MaxItems int) error { +func (c *Client) FlushAlerts(ctx context.Context, MaxAge string, MaxItems int) error { var ( deletedByAge int deletedByNbItem int @@ -200,10 +201,10 @@ func (c *Client) FlushAlerts(MaxAge string, MaxItems int) error { } c.Log.Debug("Flushing orphan alerts") - c.FlushOrphans() + c.FlushOrphans(ctx) c.Log.Debug("Done flushing orphan alerts") - totalAlerts, err = c.TotalAlerts() + totalAlerts, err = c.TotalAlerts(ctx) if err != nil { c.Log.Warningf("FlushAlerts (max items count): %s", err) return fmt.Errorf("unable to get alerts count: %w", err) @@ -216,7 +217,7 @@ func (c *Client) FlushAlerts(MaxAge string, MaxItems int) error { "created_before": {MaxAge}, } - nbDeleted, err := c.DeleteAlertWithFilter(filter) + nbDeleted, err := c.DeleteAlertWithFilter(ctx, filter) if err != nil { c.Log.Warningf("FlushAlerts (max age): %s", err) return fmt.Errorf("unable to flush alerts with filter until=%s: %w", MaxAge, err) @@ -232,7 +233,7 @@ func (c *Client) FlushAlerts(MaxAge string, MaxItems int) error { // This gives us the oldest alert that we want to keep // We then delete all the alerts with an id lower than this one // We can do this because the id is auto-increment, and the database won't reuse the same id twice - lastAlert, err := c.QueryAlertWithFilter(map[string][]string{ + lastAlert, err := c.QueryAlertWithFilter(ctx, map[string][]string{ "sort": {"DESC"}, "limit": {"1"}, // we do not care about fetching the edges, we just want the id @@ -252,7 +253,7 @@ func (c *Client) FlushAlerts(MaxAge string, MaxItems int) error { if maxid > 0 { // This may lead to orphan alerts (at least on MySQL), but the next time the flush job will run, they will be deleted - deletedByNbItem, err = c.Ent.Alert.Delete().Where(alert.IDLT(maxid)).Exec(c.CTX) + deletedByNbItem, err = c.Ent.Alert.Delete().Where(alert.IDLT(maxid)).Exec(ctx) if err != nil { c.Log.Errorf("FlushAlerts: Could not delete alerts: %s", err) diff --git a/pkg/database/lock.go b/pkg/database/lock.go index d25b71870f0..e928c10da43 100644 --- a/pkg/database/lock.go +++ b/pkg/database/lock.go @@ -1,6 +1,7 @@ package database import ( + "context" "time" "github.com/pkg/errors" @@ -16,12 +17,12 @@ const ( CapiPullLockName = "pullCAPI" ) -func (c *Client) AcquireLock(name string) error { +func (c *Client) AcquireLock(ctx context.Context, name string) error { log.Debugf("acquiring lock %s", name) _, err := c.Ent.Lock.Create(). SetName(name). SetCreatedAt(types.UtcNow()). - Save(c.CTX) + Save(ctx) if ent.IsConstraintError(err) { return err } @@ -31,21 +32,21 @@ func (c *Client) AcquireLock(name string) error { return nil } -func (c *Client) ReleaseLock(name string) error { +func (c *Client) ReleaseLock(ctx context.Context, name string) error { log.Debugf("releasing lock %s", name) - _, err := c.Ent.Lock.Delete().Where(lock.NameEQ(name)).Exec(c.CTX) + _, err := c.Ent.Lock.Delete().Where(lock.NameEQ(name)).Exec(ctx) if err != nil { return errors.Wrapf(DeleteFail, "delete lock: %s", err) } return nil } -func (c *Client) ReleaseLockWithTimeout(name string, timeout int) error { +func (c *Client) ReleaseLockWithTimeout(ctx context.Context, name string, timeout int) error { log.Debugf("releasing lock %s with timeout of %d minutes", name, timeout) _, err := c.Ent.Lock.Delete().Where( lock.NameEQ(name), lock.CreatedAtLT(time.Now().UTC().Add(-time.Duration(timeout)*time.Minute)), - ).Exec(c.CTX) + ).Exec(ctx) if err != nil { return errors.Wrapf(DeleteFail, "delete lock: %s", err) @@ -57,21 +58,21 @@ func (c *Client) IsLocked(err error) bool { return ent.IsConstraintError(err) } -func (c *Client) AcquirePullCAPILock() error { +func (c *Client) AcquirePullCAPILock(ctx context.Context) error { /*delete orphan "old" lock if present*/ - err := c.ReleaseLockWithTimeout(CapiPullLockName, CAPIPullLockTimeout) + err := c.ReleaseLockWithTimeout(ctx, CapiPullLockName, CAPIPullLockTimeout) if err != nil { log.Errorf("unable to release pullCAPI lock: %s", err) } - return c.AcquireLock(CapiPullLockName) + return c.AcquireLock(ctx, CapiPullLockName) } -func (c *Client) ReleasePullCAPILock() error { +func (c *Client) ReleasePullCAPILock(ctx context.Context) error { log.Debugf("deleting lock %s", CapiPullLockName) _, err := c.Ent.Lock.Delete().Where( lock.NameEQ(CapiPullLockName), - ).Exec(c.CTX) + ).Exec(ctx) if err != nil { return errors.Wrapf(DeleteFail, "delete lock: %s", err) } diff --git a/pkg/database/machines.go b/pkg/database/machines.go index 7a64c1d4d6e..7118e08fae3 100644 --- a/pkg/database/machines.go +++ b/pkg/database/machines.go @@ -1,6 +1,7 @@ package database import ( + "context" "fmt" "time" @@ -16,7 +17,7 @@ import ( const CapiMachineID = types.CAPIOrigin const CapiListsMachineID = types.ListOrigin -func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipAddress string, isValidated bool, force bool, authType string) (*ent.Machine, error) { +func (c *Client) CreateMachine(ctx context.Context, machineID *string, password *strfmt.Password, ipAddress string, isValidated bool, force bool, authType string) (*ent.Machine, error) { hashPassword, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost) if err != nil { c.Log.Warningf("CreateMachine: %s", err) @@ -26,18 +27,18 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA machineExist, err := c.Ent.Machine. Query(). Where(machine.MachineIdEQ(*machineID)). - Select(machine.FieldMachineId).Strings(c.CTX) + Select(machine.FieldMachineId).Strings(ctx) if err != nil { return nil, errors.Wrapf(QueryFail, "machine '%s': %s", *machineID, err) } if len(machineExist) > 0 { if force { - _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(*machineID)).SetPassword(string(hashPassword)).Save(c.CTX) + _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(*machineID)).SetPassword(string(hashPassword)).Save(ctx) if err != nil { c.Log.Warningf("CreateMachine : %s", err) return nil, errors.Wrapf(UpdateFail, "machine '%s'", *machineID) } - machine, err := c.QueryMachineByID(*machineID) + machine, err := c.QueryMachineByID(ctx, *machineID) if err != nil { return nil, errors.Wrapf(QueryFail, "machine '%s': %s", *machineID, err) } @@ -53,7 +54,7 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA SetIpAddress(ipAddress). SetIsValidated(isValidated). SetAuthType(authType). - Save(c.CTX) + Save(ctx) if err != nil { c.Log.Warningf("CreateMachine : %s", err) @@ -63,11 +64,11 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA return machine, nil } -func (c *Client) QueryMachineByID(machineID string) (*ent.Machine, error) { +func (c *Client) QueryMachineByID(ctx context.Context, machineID string) (*ent.Machine, error) { machine, err := c.Ent.Machine. Query(). Where(machine.MachineIdEQ(machineID)). - Only(c.CTX) + Only(ctx) if err != nil { c.Log.Warningf("QueryMachineByID : %s", err) return &ent.Machine{}, errors.Wrapf(UserNotExists, "user '%s'", machineID) @@ -75,16 +76,16 @@ func (c *Client) QueryMachineByID(machineID string) (*ent.Machine, error) { return machine, nil } -func (c *Client) ListMachines() ([]*ent.Machine, error) { - machines, err := c.Ent.Machine.Query().All(c.CTX) +func (c *Client) ListMachines(ctx context.Context) ([]*ent.Machine, error) { + machines, err := c.Ent.Machine.Query().All(ctx) if err != nil { return nil, errors.Wrapf(QueryFail, "listing machines: %s", err) } return machines, nil } -func (c *Client) ValidateMachine(machineID string) error { - rets, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetIsValidated(true).Save(c.CTX) +func (c *Client) ValidateMachine(ctx context.Context, machineID string) error { + rets, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetIsValidated(true).Save(ctx) if err != nil { return errors.Wrapf(UpdateFail, "validating machine: %s", err) } @@ -94,11 +95,11 @@ func (c *Client) ValidateMachine(machineID string) error { return nil } -func (c *Client) QueryPendingMachine() ([]*ent.Machine, error) { +func (c *Client) QueryPendingMachine(ctx context.Context) ([]*ent.Machine, error) { var machines []*ent.Machine var err error - machines, err = c.Ent.Machine.Query().Where(machine.IsValidatedEQ(false)).All(c.CTX) + machines, err = c.Ent.Machine.Query().Where(machine.IsValidatedEQ(false)).All(ctx) if err != nil { c.Log.Warningf("QueryPendingMachine : %s", err) return nil, errors.Wrapf(QueryFail, "querying pending machines: %s", err) @@ -106,11 +107,11 @@ func (c *Client) QueryPendingMachine() ([]*ent.Machine, error) { return machines, nil } -func (c *Client) DeleteWatcher(name string) error { +func (c *Client) DeleteWatcher(ctx context.Context, name string) error { nbDeleted, err := c.Ent.Machine. Delete(). Where(machine.MachineIdEQ(name)). - Exec(c.CTX) + Exec(ctx) if err != nil { return err } @@ -122,59 +123,59 @@ func (c *Client) DeleteWatcher(name string) error { return nil } -func (c *Client) BulkDeleteWatchers(machines []*ent.Machine) (int, error) { +func (c *Client) BulkDeleteWatchers(ctx context.Context, machines []*ent.Machine) (int, error) { ids := make([]int, len(machines)) for i, b := range machines { ids[i] = b.ID } - nbDeleted, err := c.Ent.Machine.Delete().Where(machine.IDIn(ids...)).Exec(c.CTX) + nbDeleted, err := c.Ent.Machine.Delete().Where(machine.IDIn(ids...)).Exec(ctx) if err != nil { return nbDeleted, err } return nbDeleted, nil } -func (c *Client) UpdateMachineLastHeartBeat(machineID string) error { - _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetLastHeartbeat(time.Now().UTC()).Save(c.CTX) +func (c *Client) UpdateMachineLastHeartBeat(ctx context.Context, machineID string) error { + _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetLastHeartbeat(time.Now().UTC()).Save(ctx) if err != nil { return errors.Wrapf(UpdateFail, "updating machine last_heartbeat: %s", err) } return nil } -func (c *Client) UpdateMachineScenarios(scenarios string, ID int) error { +func (c *Client) UpdateMachineScenarios(ctx context.Context, scenarios string, ID int) error { _, err := c.Ent.Machine.UpdateOneID(ID). SetUpdatedAt(time.Now().UTC()). SetScenarios(scenarios). - Save(c.CTX) + Save(ctx) if err != nil { return fmt.Errorf("unable to update machine in database: %s", err) } return nil } -func (c *Client) UpdateMachineIP(ipAddr string, ID int) error { +func (c *Client) UpdateMachineIP(ctx context.Context, ipAddr string, ID int) error { _, err := c.Ent.Machine.UpdateOneID(ID). SetIpAddress(ipAddr). - Save(c.CTX) + Save(ctx) if err != nil { return fmt.Errorf("unable to update machine IP in database: %s", err) } return nil } -func (c *Client) UpdateMachineVersion(ipAddr string, ID int) error { +func (c *Client) UpdateMachineVersion(ctx context.Context, ipAddr string, ID int) error { _, err := c.Ent.Machine.UpdateOneID(ID). SetVersion(ipAddr). - Save(c.CTX) + Save(ctx) if err != nil { return fmt.Errorf("unable to update machine version in database: %s", err) } return nil } -func (c *Client) IsMachineRegistered(machineID string) (bool, error) { - exist, err := c.Ent.Machine.Query().Where().Select(machine.FieldMachineId).Strings(c.CTX) +func (c *Client) IsMachineRegistered(ctx context.Context, machineID string) (bool, error) { + exist, err := c.Ent.Machine.Query().Where().Select(machine.FieldMachineId).Strings(ctx) if err != nil { return false, err } @@ -189,6 +190,6 @@ func (c *Client) IsMachineRegistered(machineID string) (bool, error) { } -func (c *Client) QueryLastValidatedHeartbeatLT(t time.Time) ([]*ent.Machine, error) { - return c.Ent.Machine.Query().Where(machine.LastHeartbeatLT(t), machine.IsValidatedEQ(true)).All(c.CTX) +func (c *Client) QueryLastValidatedHeartbeatLT(ctx context.Context, t time.Time) ([]*ent.Machine, error) { + return c.Ent.Machine.Query().Where(machine.LastHeartbeatLT(t), machine.IsValidatedEQ(true)).All(ctx) } diff --git a/pkg/exprhelpers/helpers.go b/pkg/exprhelpers/helpers.go index 5c041aa2886..a0d00d584dd 100644 --- a/pkg/exprhelpers/helpers.go +++ b/pkg/exprhelpers/helpers.go @@ -2,6 +2,7 @@ package exprhelpers import ( "bufio" + "context" "encoding/base64" "fmt" "math" @@ -591,7 +592,8 @@ func GetDecisionsCount(params ...any) (any, error) { return 0, nil } - count, err := dbClient.CountDecisionsByValue(value) + ctx := context.TODO() + count, err := dbClient.CountDecisionsByValue(ctx, value) if err != nil { log.Errorf("Failed to get decisions count from value '%s'", value) return 0, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility @@ -613,7 +615,8 @@ func GetDecisionsSinceCount(params ...any) (any, error) { return 0, nil } sinceTime := time.Now().UTC().Add(-sinceDuration) - count, err := dbClient.CountDecisionsSinceByValue(value, sinceTime) + ctx := context.TODO() + count, err := dbClient.CountDecisionsSinceByValue(ctx, value, sinceTime) if err != nil { log.Errorf("Failed to get decisions count from value '%s'", value) return 0, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility @@ -627,7 +630,8 @@ func GetActiveDecisionsCount(params ...any) (any, error) { log.Error("No database config to call GetActiveDecisionsCount()") return 0, nil } - count, err := dbClient.CountActiveDecisionsByValue(value) + ctx := context.TODO() + count, err := dbClient.CountActiveDecisionsByValue(ctx, value) if err != nil { log.Errorf("Failed to get active decisions count from value '%s'", value) return 0, err @@ -641,7 +645,8 @@ func GetActiveDecisionsTimeLeft(params ...any) (any, error) { log.Error("No database config to call GetActiveDecisionsTimeLeft()") return 0, nil } - timeLeft, err := dbClient.GetActiveDecisionsTimeLeftByValue(value) + ctx := context.TODO() + timeLeft, err := dbClient.GetActiveDecisionsTimeLeftByValue(ctx, value) if err != nil { log.Errorf("Failed to get active decisions time left from value '%s'", value) return 0, err From a1a6749e44920f828a911fb7d9fd8b2bad8dbbb2 Mon Sep 17 00:00:00 2001 From: marco Date: Mon, 10 Jun 2024 10:02:20 +0200 Subject: [PATCH 2/2] add context to unit tests --- pkg/apiserver/apic_metrics_test.go | 2 +- pkg/apiserver/apic_test.go | 6 +++--- pkg/apiserver/apiserver_test.go | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pkg/apiserver/apic_metrics_test.go b/pkg/apiserver/apic_metrics_test.go index f3b9b352316..07266f09834 100644 --- a/pkg/apiserver/apic_metrics_test.go +++ b/pkg/apiserver/apic_metrics_test.go @@ -86,7 +86,7 @@ func TestAPICSendMetrics(t *testing.T) { stop := make(chan bool) httpmock.ZeroCallCounters() - go api.SendMetrics(stop) + go api.SendMetrics(context.TODO(), stop) time.Sleep(tc.duration) stop <- true diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 10f4cf9444b..0959077ab5b 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -379,7 +379,7 @@ func TestAPICGetMetrics(t *testing.T) { ExecX(context.Background()) } - foundMetrics, err := apiClient.GetMetrics() + foundMetrics, err := apiClient.GetMetrics(context.TODO()) require.NoError(t, err) assert.Equal(t, tc.expectedMetric.Bouncers, foundMetrics.Bouncers) @@ -920,7 +920,7 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) { require.NoError(t, err) blocklistConfigItemName := "blocklist:blocklist1:last_pull" - lastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName) + lastPullTimestamp, err := api.dbClient.GetConfigItem(context.TODO(), blocklistConfigItemName) require.NoError(t, err) assert.NotEqual(t, "", *lastPullTimestamp) @@ -932,7 +932,7 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) { err = api.PullTop(false) require.NoError(t, err) - secondLastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName) + secondLastPullTimestamp, err := api.dbClient.GetConfigItem(context.TODO(), blocklistConfigItemName) require.NoError(t, err) assert.Equal(t, *lastPullTimestamp, *secondLastPullTimestamp) } diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index 20c48337833..0aff4008556 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -167,7 +167,7 @@ func ValidateMachine(t *testing.T, machineID string, config *csconfig.DatabaseCf dbClient, err := database.NewClient(ctx, config) require.NoError(t, err) - err = dbClient.ValidateMachine(machineID) + err = dbClient.ValidateMachine(context.TODO(), machineID) require.NoError(t, err) } @@ -177,7 +177,7 @@ func GetMachineIP(t *testing.T, machineID string, config *csconfig.DatabaseCfg) dbClient, err := database.NewClient(ctx, config) require.NoError(t, err) - machines, err := dbClient.ListMachines() + machines, err := dbClient.ListMachines(context.TODO()) require.NoError(t, err) for _, machine := range machines { @@ -273,7 +273,7 @@ func CreateTestBouncer(t *testing.T, config *csconfig.DatabaseCfg) string { apiKey, err := middlewares.GenerateAPIKey(keyLength) require.NoError(t, err) - _, err = dbClient.CreateBouncer("test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType) + _, err = dbClient.CreateBouncer(context.TODO(), "test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType) require.NoError(t, err) return apiKey