Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: race condition on StrID #1084

Merged
merged 2 commits into from
Jun 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions examples/http-server/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,45 @@ func TestHttpServer(t *testing.T) {
})
}
}

// TestHttpServerConcurrent is meant to be run with the "-race" flag.
// Multiple requests are sent concurrently to the server and race conditions are checked.
// It is especially useful to ensure that rules and their metadata are not edited in an unsafe way
// after parsing time.
func TestHttpServerConcurrent(t *testing.T) {
tests := []struct {
name string
path string
expStatus int
body []byte // if body is populated, POST request is sent
}{
{"negative", "/", 200, nil},
{"positive for query parameter 1", "/?id=0", 403, nil},
{"positive for request body", "/", 403, []byte("password")},
}
// Spin up the test server with default.conf configuration
testServer := setupTestServer(t)
defer testServer.Close()
// a t.Run wraps all the concurrent tests and permits to close the server only once test is done
// See https://github.com/golang/go/issues/17791
t.Run("concurrent test", func(t *testing.T) {
for _, tc := range tests {
tt := tc
for i := 0; i < 10; i++ {
// Each test case is added 10 times and then run concurrently
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var statusCode int
if tt.body == nil {
statusCode = doGetRequest(t, testServer.URL+tt.path)
} else {
statusCode = doPostRequest(t, testServer.URL+tt.path, tt.body)
}
if want, have := tt.expStatus, statusCode; want != have {
t.Errorf("Unexpected status code, want: %d, have: %d", want, have)
}
})
}
}
})
}
2 changes: 2 additions & 0 deletions internal/actions/ctl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,8 @@ func TestCtl(t *testing.T) {
waf := corazawaf.NewWAF()
waf.Logger = logger
r := corazawaf.NewRule()
r.ID_ = 1
r.LogID_ = "1"
err := waf.Rules.Add(r)
if err != nil {
t.Fatalf("failed to add rule: %s", err.Error())
Expand Down
1 change: 1 addition & 0 deletions internal/actions/id.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ func (a *idFn) Init(r plugintypes.RuleMetadata, data string) error {

cr := r.(*corazawaf.Rule)
cr.ID_ = int(i)
cr.LogID_ = strconv.Itoa(i)
return nil
}

Expand Down
47 changes: 20 additions & 27 deletions internal/corazarules/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,30 @@
package corazarules

import (
"strconv"

"github.com/corazawaf/coraza/v3/types"
)

const noID = 0

// RuleMetadata is used to store rule metadata
// that can be used across packages
type RuleMetadata struct {
ID_ int
File_ string
Line_ int
Rev_ string
Severity_ types.RuleSeverity
Version_ string
Tags_ []string
Maturity_ int
Accuracy_ int
Operator_ string
Phase_ types.RulePhase
Raw_ string
SecMark_ string
cachedStrID_ string
ID_ int
// Stores the string representation of the rule ID for logging purposes.
// If the rule is part of a chain, the parent ID is used as log ID.
// This approach prevents repeated computations in performance-critical sections, enhancing efficiency.
// It is stored for performance reasons, avoiding to perfrom the computation multiple times in the hot path
LogID_ string
File_ string
Line_ int
Rev_ string
Severity_ types.RuleSeverity
Version_ string
Tags_ []string
Maturity_ int
Accuracy_ int
Operator_ string
Phase_ types.RulePhase
Raw_ string
SecMark_ string
// Contains the Id of the parent rule if you are inside
// a chain. Otherwise, it will be 0
ParentID_ int
Expand Down Expand Up @@ -85,13 +85,6 @@ func (r *RuleMetadata) SecMark() string {
return r.SecMark_
}

func (r *RuleMetadata) StrID() string {
if r.cachedStrID_ == "" {
rid := r.ID_
if rid == noID {
rid = r.ParentID_
}
r.cachedStrID_ = strconv.Itoa(rid)
}
return r.cachedStrID_
func (r *RuleMetadata) LogID() string {
return r.LogID_
}
2 changes: 1 addition & 1 deletion internal/corazawaf/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ func (r *Rule) doEvaluate(logger debuglog.Logger, phase types.RulePhase, tx *Tra
defer logger.Debug().Msg("Finished rule evaluation")

ruleCol := tx.variables.rule
ruleCol.SetIndex("id", 0, r.StrID())
ruleCol.SetIndex("id", 0, r.LogID())
if r.Msg != nil {
ruleCol.SetIndex("msg", 0, r.Msg.String())
}
Expand Down
12 changes: 12 additions & 0 deletions internal/corazawaf/rule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ func TestMatchEvaluate(t *testing.T) {
r.Msg, _ = macro.NewMacro("Message")
r.LogData, _ = macro.NewMacro("Data Message")
r.ID_ = 1
r.LogID_ = "1"
if err := r.AddVariable(variables.ArgsGet, "", false); err != nil {
t.Error(err)
}
Expand All @@ -44,6 +45,7 @@ func TestMatchEvaluate(t *testing.T) {
func TestNoMatchEvaluate(t *testing.T) {
r := NewRule()
r.ID_ = 1
r.LogID_ = "1"
if err := r.AddVariable(variables.ArgsGet, "", false); err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -89,6 +91,7 @@ func TestNoMatchEvaluateBecauseOfException(t *testing.T) {
r.Msg, _ = macro.NewMacro("Message")
r.LogData, _ = macro.NewMacro("Data Message")
r.ID_ = 1
r.LogID_ = "1"
if err := r.AddVariable(tc.variable, "", false); err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -129,6 +132,7 @@ func (*dummyFlowAction) Type() plugintypes.ActionType {
func TestFlowActionIfDetectionOnlyEngine(t *testing.T) {
r := NewRule()
r.ID_ = 1
r.LogID_ = "1"
r.operator = nil
action := &dummyFlowAction{}
_ = r.AddAction("dummyFlowAction", action)
Expand Down Expand Up @@ -174,13 +178,15 @@ func TestMatchVariableRunsActionTypeNondisruptive(t *testing.T) {
func TestDisruptiveActionFromChainNotEvaluated(t *testing.T) {
r := NewRule()
r.ID_ = 1
r.LogID_ = "1"
r.operator = nil
r.HasChain = true
action := &dummyNonDisruptiveAction{}
_ = r.AddAction("dummyNonDisruptiveAction", action)
chainedRule := NewRule()
chainedRule.ID_ = 0
chainedRule.ParentID_ = 1
chainedRule.LogID_ = "1"
chainedRule.operator = nil
chainedAction := &dummyDenyAction{}
_ = chainedRule.AddAction("dummyDenyAction", chainedAction)
Expand All @@ -201,6 +207,7 @@ func TestRuleDetailsTransferredToTransaction(t *testing.T) {
r := NewRule()
r.ID_ = 0
r.ParentID_ = 1
r.LogID_ = "1"
r.Capture = true
r.operator = nil
tx := NewWAF().NewTransaction()
Expand All @@ -226,6 +233,7 @@ func TestSecActionMessagePropagationInMatchData(t *testing.T) {
r.Msg, _ = macro.NewMacro("Message")
r.LogData, _ = macro.NewMacro("Data Message")
r.ID_ = 1
r.LogID_ = "1"
// SecAction uses nil operator
r.operator = nil
tx := NewWAF().NewTransaction()
Expand Down Expand Up @@ -545,13 +553,15 @@ func TestTransformArgNoCacheForTXVariable(t *testing.T) {
func TestCaptureNotPropagatedToInnerChainRule(t *testing.T) {
r := NewRule()
r.ID_ = 1
r.LogID_ = "1"
r.operator = nil
r.HasChain = true
r.Phase_ = 1
r.Capture = true
chainedRule := NewRule()
chainedRule.ID_ = 0
chainedRule.ParentID_ = 1
chainedRule.LogID_ = "1"
chainedRule.operator = nil
chainedRule.Capture = false
r.Chain = chainedRule
Expand All @@ -567,6 +577,7 @@ func TestCaptureNotPropagatedToInnerChainRule(t *testing.T) {
func TestExpandMacroAfterWholeRuleEvaluation(t *testing.T) {
r := NewRule()
r.ID_ = 1
r.LogID_ = "1"
r.operator = nil
r.HasChain = true
r.Phase_ = 1
Expand All @@ -577,6 +588,7 @@ func TestExpandMacroAfterWholeRuleEvaluation(t *testing.T) {
chainedRule := NewRule()
chainedRule.ID_ = 0
chainedRule.ParentID_ = 1
chainedRule.LogID_ = "1"
chainedRule.operator = nil

_ = r.AddVariable(variables.RequestURI, "", false)
Expand Down
1 change: 1 addition & 0 deletions internal/corazawaf/transaction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,7 @@ func TestLogCallback(t *testing.T) {
tx := waf.NewTransaction()
rule := NewRule()
rule.ID_ = 1
rule.LogID_ = "1"
rule.Phase_ = 1
rule.Log = true
_ = rule.AddAction("deny", testCase.action)
Expand Down
1 change: 1 addition & 0 deletions internal/seclang/directives.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ func directiveSecMarker(options *DirectiveOptions) error {
rule.Raw_ = fmt.Sprintf("SecMarker %s", options.Opts)
rule.SecMark_ = options.Opts
rule.ID_ = 0
rule.LogID_ = "0"
rule.Phase_ = 0
rule.Line_ = options.Parser.LastLine
rule.File_ = options.Parser.ConfigFile
Expand Down
3 changes: 3 additions & 0 deletions internal/seclang/rule_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,9 @@ func ParseRule(options RuleOptions) (*corazawaf.Rule, error) {

if parent := getLastRuleExpectingChain(options.WAF); parent != nil {
rule.ParentID_ = parent.ID_
// While the ID_ will be kept to 0 being a chain rule, the LogID_ is meant to be
// the printable ID that represents the chain rule, therefore the parent's ID is inherited.
rule.LogID_ = parent.LogID_
lastChain := parent
for lastChain.Chain != nil {
lastChain = lastChain.Chain
Expand Down
2 changes: 1 addition & 1 deletion magefile.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func Test() error {
return err
}

if err := sh.RunV("go", "test", "./examples/http-server"); err != nil {
if err := sh.RunV("go", "test", "./examples/http-server", "-race"); err != nil {
return err
}

Expand Down
Loading