Skip to content

Commit

Permalink
Various panics fixes (#68)
Browse files Browse the repository at this point in the history
* Refactoring and panic fixes

- refactored message parsing
- introduced shared message utils
- fixed both panics

* Introduced custom Message, Request and Response

* Fixed response body processing

* Moved SPOA's method into corresponding file

* Fixed linting

* Using default values from common Message

* Removed message struct

* Added miss if rule engine is off

* Added message tests

* Different exceptions on message parsing

* Arguments lazy loading

* chore: makes lint happy.

---------

Co-authored-by: José Carlos Chávez <jcchavezs@gmail.com>
  • Loading branch information
zc-devs and jcchavezs committed Jun 30, 2023
1 parent f24802c commit dd5eb86
Show file tree
Hide file tree
Showing 4 changed files with 421 additions and 250 deletions.
119 changes: 119 additions & 0 deletions internal/message.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// Copyright The OWASP Coraza contributors
// SPDX-License-Identifier: Apache-2.0

package internal

import (
"fmt"
"net"

spoe "github.com/criteo/haproxy-spoe-go"
)

type message struct {
msg *spoe.Message
args map[string]interface{}
}

func NewMessage(msg *spoe.Message) (*message, error) {
message := message{
msg: msg,
args: make(map[string]interface{}, msg.Args.Count()),
}
return &message, nil
}

func (m *message) findArg(name string) (interface{}, error) {
argVal, exist := m.args[name]
if exist {
return argVal, nil
}

ai := m.msg.Args
for ai.Next() {
m.args[ai.Arg.Name] = ai.Arg.Value
if ai.Arg.Name == name {
return ai.Arg.Value, nil
}
}

return nil, &ArgNotFoundError{name}
}

func (m *message) getStringArg(name string) (string, error) {
argVal, err := m.findArg(name)
if err != nil {
return "", err
}
if argVal == nil {
return "", nil
}
val, ok := argVal.(string)
if !ok {
return "", &typeMismatchError{name, "string", argVal}
}
return val, nil
}

func (m *message) getIntArg(name string) (int, error) {
argVal, err := m.findArg(name)
if err != nil {
return 0, err
}
if argVal == nil {
return 0, nil
}
val, ok := argVal.(int)
if !ok {
return 0, &typeMismatchError{name, "int", argVal}
}
return val, nil
}

func (m *message) getByteArrayArg(name string) ([]byte, error) {
argVal, err := m.findArg(name)
if err != nil {
return nil, err
}
if argVal == nil {
return nil, nil
}
val, ok := argVal.([]byte)
if !ok {
return nil, &typeMismatchError{name, "[]byte", argVal}
}
return val, nil
}

func (m *message) getIpArg(name string) (net.IP, error) {
argVal, err := m.findArg(name)
if err != nil {
return nil, err
}
if argVal == nil {
return nil, nil
}
val, ok := argVal.(net.IP)
if !ok {
return nil, &typeMismatchError{name, "net.IP", argVal}
}
return val, nil
}

type ArgNotFoundError struct {
argName string
}

func (e *ArgNotFoundError) Error() string {
return fmt.Sprintf("Argument '%s' not found", e.argName)
}

type typeMismatchError struct {
key string
expectedType string
actualValue interface{}
}

func (e *typeMismatchError) Error() string {
return fmt.Sprintf("Invalid argument for %s, %s expected, got %T", e.key, e.expectedType, e.actualValue)
}
233 changes: 81 additions & 152 deletions internal/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,168 +6,97 @@ package internal
import (
"fmt"
"net"
"time"

"github.com/corazawaf/coraza/v3/types"
spoe "github.com/criteo/haproxy-spoe-go"
"go.uber.org/zap"
)

func (s *SPOA) processRequest(msg spoe.Message) ([]spoe.Action, error) {
var (
ok bool
method = ""
path = "/"
query = ""
version = "1.1"
srcIP net.IP
srcPort = 0
dstIP net.IP
dstPort = 0
tx types.Transaction
)
var app *application

defer func() {
if tx == nil || app == nil {
return
}
if tx.IsInterrupted() {
tx.ProcessLogging()
if err := tx.Close(); err != nil {
app.logger.Error("failed to close transaction", zap.String("transaction_id", tx.ID()), zap.String("error", err.Error()))
}
} else {
if app.cfg.NoResponseCheck {
return
}
err := app.cache.SetWithExpire(tx.ID(), tx, time.Millisecond*time.Duration(app.cfg.TransactionTTLMilliseconds))
if err != nil {
app.logger.Error(fmt.Sprintf("failed to cache transaction: %s", err.Error()))
}
}
}()

for msg.Args.Next() {
arg := msg.Args.Arg
if arg.Name != "app" && app == nil {
return nil, fmt.Errorf("app is not set")
}

switch arg.Name {
case "app":
var ok bool
app, ok = s.applications[arg.Value.(string)]
if !ok {
if len(s.defaultApplication) > 0 {
app, ok = s.applications[s.defaultApplication]
if !ok {
return nil, fmt.Errorf("default application not found: %s", s.defaultApplication)
}
app.logger.Debug("application not found, using default", zap.Any("application", arg.Value), zap.String("default", s.defaultApplication))
} else {
return nil, fmt.Errorf("application not found: %v", arg.Value)
}
}
case "id":
id, ok := arg.Value.(string)
if !ok {
return nil, fmt.Errorf("invalid argument for http request id, string expected, got %v", arg.Value)
}
tx = app.waf.NewTransactionWithID(id)
case "src-ip":
srcIP, ok = arg.Value.(net.IP)
if !ok {
return nil, fmt.Errorf("invalid argument for src ip, net.IP expected, got %v", arg.Value)
}
case "src-port":
srcPort, ok = arg.Value.(int)
if !ok {
return nil, fmt.Errorf("invalid argument for src port, integer expected, got %v", arg.Value)
}
case "dst-ip":
dstIP, ok = arg.Value.(net.IP)
if !ok {
return nil, fmt.Errorf("invalid argument for dst ip, net.IP expected, got %v", arg.Value)
}
case "dst-port":
dstPort, ok = arg.Value.(int)
if !ok {
return nil, fmt.Errorf("invalid argument for dst port, integer expected, got %v", arg.Value)
}
case "method":
method, ok = arg.Value.(string)
if !ok {
return nil, fmt.Errorf("invalid argument for http request method, string expected, got %v", arg.Value)
}
case "path":
path, ok = arg.Value.(string)
if !ok {
app.logger.Error(fmt.Sprintf("invalid argument for http request path, string expected, got %v", arg.Value))
path = "/"
}
case "query":
query, ok = arg.Value.(string)
if !ok && arg.Value != nil {
app.logger.Error(fmt.Sprintf("invalid argument for http request query, string expected, got %v", arg.Value))
query = ""
}
case "version":
version, ok = arg.Value.(string)
if !ok {
app.logger.Error(fmt.Sprintf("invalid argument for http request version, string expected, got %v", arg.Value))
version = "1.1"
}
case "headers":
value, ok := arg.Value.(string)
if !ok {
app.logger.Error(fmt.Sprintf("invalid argument for http request headers, string expected, got %v", arg.Value))
value = ""
}

headers, err := s.readHeaders(value)
if err != nil {
return nil, err
}

for key, values := range headers {
for _, v := range values {
tx.AddRequestHeader(key, v)
}
}
case "body":
body, ok := arg.Value.([]byte)
if !ok {
return nil, fmt.Errorf("invalid argument for http request body, []byte expected, got %v", arg.Value)
}

it, _, err := tx.WriteRequestBody(body)
if err != nil {
return nil, err
}
if it != nil {
return s.processInterruption(it, hit), nil
}
default:
app.logger.Error("invalid message on the http frontend request", zap.String("name", arg.Name), zap.Any("value", arg.Value))
}
}
type request struct {
msg *message
app string
id string
srcIp net.IP
srcPort int
dstIp net.IP
dstPort int
method string
path string
query string
version string
headers string
body []byte
}

//app.logger.Debug(fmt.Sprintf("ProcessConnection: %s:%d -> %s:%d", srcIP.String(), srcPort, dstIP.String(), dstPort))
tx.ProcessConnection(srcIP.String(), srcPort, dstIP.String(), dstPort)
func NewRequest(spoeMsg *spoe.Message) (*request, error) {
msg, err := NewMessage(spoeMsg)
if err != nil {
return nil, err
}

//app.logger.Debug(fmt.Sprintf("ProcessURI: %s %s?%s %s", method, path, query, "HTTP/"+version))
tx.ProcessURI(path+"?"+query, method, "HTTP/"+version)
request := request{}
request.msg = msg

if it := tx.ProcessRequestHeaders(); it != nil {
return s.processInterruption(it, hit), nil
request.app, err = msg.getStringArg("app")
if err != nil {
return nil, err
}
it, err := tx.ProcessRequestBody()

request.id, err = request.msg.getStringArg("id")
if err != nil {
return nil, err
}
if it != nil {
return s.processInterruption(it, hit), nil

return &request, nil
}

func (req *request) init() error {
var err error

req.srcIp, err = req.msg.getIpArg("src-ip")
if err != nil {
return err
}

req.srcPort, err = req.msg.getIntArg("src-port")
if err != nil {
return err
}

req.dstIp, err = req.msg.getIpArg("dst-ip")
if err != nil {
return err
}
return s.message(miss), nil

req.dstPort, err = req.msg.getIntArg("dst-port")
if err != nil {
return err
}

req.method, err = req.msg.getStringArg("method")
if err != nil {
return err
}

req.path, err = req.msg.getStringArg("path")
if err != nil {
fmt.Println(err.Error())
}

req.query, err = req.msg.getStringArg("query")
if err != nil {
fmt.Println(err.Error())
}

req.version, err = req.msg.getStringArg("version")
if err != nil {
fmt.Println(err.Error())
}

req.headers, err = req.msg.getStringArg("headers")
if err != nil {
fmt.Println(err.Error())
}

req.body, _ = req.msg.getByteArrayArg("body")

return nil
}

0 comments on commit dd5eb86

Please sign in to comment.