Skip to content

Commit

Permalink
Improve / simplify goroutine management (#31)
Browse files Browse the repository at this point in the history
* Sync goroutines with errgroup

* Lint
  • Loading branch information
mmetc committed Mar 17, 2023
1 parent 10187ca commit 457d587
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 1,867 deletions.
6 changes: 3 additions & 3 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import (
)

var (
blocklistMirrorLogFilePath string = "crowdsec-blocklist-mirror.log"
blocklistMirrorAccessLogFilePath string = "crowdsec-blocklist-mirror_access.log"
blocklistMirrorLogFilePath = "crowdsec-blocklist-mirror.log"
blocklistMirrorAccessLogFilePath = "crowdsec-blocklist-mirror_access.log"
)

type CrowdsecConfig struct {
Expand Down Expand Up @@ -126,7 +126,7 @@ func (cfg *Config) ValidateAndSetDefaults() error {
}

if !strings.HasSuffix(cfg.CrowdsecConfig.LapiURL, "/") {
cfg.CrowdsecConfig.LapiURL = cfg.CrowdsecConfig.LapiURL + "/"
cfg.CrowdsecConfig.LapiURL += "/"
}

if cfg.CrowdsecConfig.UpdateFrequency == "" {
Expand Down
6 changes: 3 additions & 3 deletions config/crowdsec-blocklist-mirror.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ crowdsec_config:
insecure_skip_verify: false

blocklists:
- format: plain_text # Supported formats are either of "plain_text", "mikrotik"
- format: plain_text # Supported formats are either "plain_text" or "mikrotik"
endpoint: /security/blocklist
authentication:
type: none # Supported types are either of "none", "ip_based", "basic"
type: none # Supported types are either "none", "ip_based" or "basic"
user:
password:
trusted_ips: # IP ranges, or IPs which don't require auth to access this blocklist
trusted_ips: # IP ranges, or IPs that don't require auth to access this blocklist
- 127.0.0.1
- ::1

Expand Down
2 changes: 1 addition & 1 deletion formatters.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/models"
)

var FormattersByName map[string]func(w http.ResponseWriter, r *http.Request) = map[string]func(w http.ResponseWriter, r *http.Request){
var FormattersByName = map[string]func(w http.ResponseWriter, r *http.Request){
"plain_text": PlainTextFormatter,
"mikrotik": MikroTikFormatter,
}
Expand Down
7 changes: 4 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ go 1.20

require (
github.com/crowdsecurity/crowdsec v1.4.6
github.com/crowdsecurity/go-cs-bouncer v0.0.2
github.com/crowdsecurity/go-cs-bouncer v0.0.3
github.com/felixge/httpsnoop v1.0.3
github.com/prometheus/client_golang v1.14.0
github.com/sirupsen/logrus v1.9.0
golang.org/x/sync v0.1.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v2 v2.4.0
)
Expand Down Expand Up @@ -43,8 +44,8 @@ require (
github.com/prometheus/common v0.40.0 // indirect
github.com/prometheus/procfs v0.9.0 // indirect
go.mongodb.org/mongo-driver v1.11.2 // indirect
golang.org/x/net v0.4.0 // indirect
golang.org/x/sys v0.5.0 // indirect
golang.org/x/net v0.8.0 // indirect
golang.org/x/sys v0.6.0 // indirect
google.golang.org/protobuf v1.28.1 // indirect
gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
Expand Down
1,824 changes: 10 additions & 1,814 deletions go.sum

Large diffs are not rendered by default.

117 changes: 76 additions & 41 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@ package main

import (
"bytes"
"context"
"errors"
"flag"
"fmt"
"net/http"
"os"
"strings"
"time"

"github.com/crowdsecurity/crowdsec/pkg/apiclient"
"github.com/crowdsecurity/crowdsec/pkg/models"
Expand All @@ -16,17 +19,27 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
)

var globalDecisionRegistry DecisionRegistry = DecisionRegistry{
var globalDecisionRegistry = DecisionRegistry{
ActiveDecisionsByValue: make(map[string]*models.Decision),
}

func runServer(config Config) {
func listenAndServe(server *http.Server, config Config) error {
if config.TLS.CertFile != "" && config.TLS.KeyFile != "" {
log.Infof("Starting server with TLS at %s", config.ListenURI)
return server.ListenAndServeTLS(config.TLS.CertFile, config.TLS.KeyFile)
}
log.Infof("Starting server at %s", config.ListenURI)
return server.ListenAndServe()
}

func runServer(ctx context.Context, g *errgroup.Group, config Config) error {
for _, blockListCFG := range config.Blocklists {
f, err := getHandlerForBlockList(blockListCFG)
if err != nil {
log.Fatal(err)
return err
}
http.HandleFunc(blockListCFG.Endpoint, f)
log.Infof("serving blocklist in format %s at endpoint %s", blockListCFG.Format, blockListCFG.Endpoint)
Expand All @@ -38,33 +51,31 @@ func runServer(config Config) {
http.Handle(config.Metrics.Endpoint, promhttp.Handler())
}

if config.TLS.CertFile != "" && config.TLS.KeyFile != "" {
log.Infof("Starting server with TLS at %s", config.ListenURI)
if config.EnableAccessLogs {
log.Fatal(http.ListenAndServeTLS(
config.ListenURI,
config.TLS.CertFile,
config.TLS.KeyFile,
CombinedLoggingHandler(config.getLoggerForFile(blocklistMirrorAccessLogFilePath), http.DefaultServeMux)))
} else {
log.Fatal(http.ListenAndServeTLS(
config.ListenURI,
config.TLS.CertFile,
config.TLS.KeyFile,
nil))
}
} else {
log.Infof("Starting server at %s", config.ListenURI)
if config.EnableAccessLogs {
log.Fatal(http.ListenAndServe(
config.ListenURI,
CombinedLoggingHandler(config.getLoggerForFile(blocklistMirrorAccessLogFilePath), http.DefaultServeMux)))
} else {
log.Fatal(http.ListenAndServe(
config.ListenURI,
nil))
}
var logHandler http.Handler
if config.EnableAccessLogs {
logHandler = CombinedLoggingHandler(config.getLoggerForFile(blocklistMirrorAccessLogFilePath), http.DefaultServeMux)
}

server := &http.Server{
Addr: config.ListenURI,
Handler: logHandler,
}

g.Go(func() error {
err := listenAndServe(server, config)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
}
return nil
})

<-ctx.Done()

serverCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
server.Shutdown(serverCtx) //nolint: contextcheck

return nil
}

func main() {
Expand Down Expand Up @@ -122,20 +133,44 @@ func main() {
log.Fatal(err)
}

go func() {
decisionStreamer.Run()
log.Fatal("can't access LAPI")
}()
go runServer(config)
g, ctx := errgroup.WithContext(context.Background())

g.Go(func() error {
decisionStreamer.Run(ctx)
return fmt.Errorf("stream api init failed")
})

for decisions := range decisionStreamer.Stream {
if len(decisions.New) > 0 {
log.Infof("received %d new decisions", len(decisions.New))
g.Go(func() error {
err := runServer(ctx, g, config)
if err != nil {
return fmt.Errorf("blocklist server failed: %w", err)
}
if len(decisions.Deleted) > 0 {
log.Infof("received %d expired decisions", len(decisions.Deleted))
return nil
})

g.Go(func() error {
for {
select {
case <-ctx.Done():
log.Info("terminating bouncer process")
return nil
case decisions := <-decisionStreamer.Stream:
if decisions == nil {
continue
}
if len(decisions.New) > 0 {
log.Infof("received %d new decisions", len(decisions.New))
globalDecisionRegistry.AddDecisions(decisions.New)
}
if len(decisions.Deleted) > 0 {
log.Infof("received %d expired decisions", len(decisions.Deleted))
globalDecisionRegistry.DeleteDecisions(decisions.Deleted)
}
}
}
globalDecisionRegistry.AddDecisions(decisions.New)
globalDecisionRegistry.DeleteDecisions(decisions.Deleted)
})

if err := g.Wait(); err != nil {
log.Fatalf("process return with error: %s", err)
}
}
4 changes: 2 additions & 2 deletions test/test_blocklist_mirror.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def test_no_lapi(bouncer, bm_cfg_factory):
with bouncer(bm_binary, cfg) as bm:
bm.wait_for_lines_fnmatch([
"*connection refused*",
# "*terminating bouncer process*",
"*can't access LAPI*",
"*terminating bouncer process*",
"*stream api init failed*",
])
bm.proc.wait(timeout=0.2)
assert not bm.proc.is_running()

0 comments on commit 457d587

Please sign in to comment.