Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v14] improve postgres bind params audit logging #34432

Merged
merged 1 commit into from Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
79 changes: 78 additions & 1 deletion lib/srv/db/audit_test.go
Expand Up @@ -75,6 +75,68 @@ func TestAuditPostgres(t *testing.T) {
requireEvent(t, testCtx, libevents.PostgresBindCode)
requireEvent(t, testCtx, libevents.PostgresExecuteCode)

bindTests := []struct {
desc string
sql string
params [][]byte
formatCodes []int16
wantParams []string
}{
{
desc: "zero format codes applies text format to all params",
sql: "select $1, $2",
params: [][]byte{[]byte("fish"), []byte("cat")},
wantParams: []string{"fish", "cat"},
},
{
desc: "one text format codes applies text format to all params",
sql: "select $1, $2",
params: [][]byte{[]byte("fish"), []byte("cat")},
formatCodes: []int16{0}, // text format.
wantParams: []string{"fish", "cat"},
},
{
desc: "one binary format codes applies binary format to all params",
sql: "select $1, $2",
params: [][]byte{[]byte("fish"), []byte("cat")},
formatCodes: []int16{1}, // binary format.
// event should encode binary as base64 strings.
wantParams: []string{"ZmlzaA==", "Y2F0"},
},
{
desc: "apply corresponding format code to each param",
sql: "select $1, $2, $3",
params: [][]byte{[]byte("fish"), []byte("cat"), []byte("dog")},
formatCodes: []int16{1, 0, 0}, // binary, text, text format.
wantParams: []string{"ZmlzaA==", "cat", "dog"},
},
{
desc: "more than one format codes but fewer than params is invalid bind",
sql: "select $1, $2, $3",
params: [][]byte{[]byte("fish"), []byte("cat"), []byte("dog")},
formatCodes: []int16{1, 0}, // binary, text.
wantParams: nil, // don't log params for invalid bind.
},
{
desc: "more format codes than params is invalid bind",
sql: "select $1, $2",
params: [][]byte{[]byte("fish"), []byte("cat")},
formatCodes: []int16{1, 0, 0}, // binary, text, text(missing)
wantParams: nil, // don't log params for invalid bind.
},
}
for _, test := range bindTests {
t.Run(test.desc, func(t *testing.T) {
resultUnnamed := psql.ExecParams(ctx, test.sql, test.params, nil, test.formatCodes, nil).Read()
require.NotNil(t, resultUnnamed)
require.NoError(t, resultUnnamed.Err)
requireEvent(t, testCtx, libevents.PostgresParseCode)
event := requireBindEvent(t, testCtx)
require.Equal(t, test.wantParams, event.Parameters)
requireEvent(t, testCtx, libevents.PostgresExecuteCode)
})
}

// Closing connection should trigger session end event.
err = psql.Close(ctx)
require.NoError(t, err)
Expand Down Expand Up @@ -280,23 +342,37 @@ func TestAuditClickHouseHTTP(t *testing.T) {
}

func assertDatabaseQueryFromAuditEvent(t *testing.T, event events.AuditEvent, wantQuery string) {
t.Helper()
query, ok := event.(*events.DatabaseSessionQuery)
require.True(t, ok)
require.Equal(t, wantQuery, query.DatabaseQuery)
}

func requireEvent(t *testing.T, testCtx *testContext, code string) {
func requireBindEvent(t *testing.T, testCtx *testContext) *events.PostgresBind {
t.Helper()
event := requireEvent(t, testCtx, libevents.PostgresBindCode)
bindEvent, ok := event.(*events.PostgresBind)
require.True(t, ok)
require.NotNil(t, bindEvent)
return bindEvent
}

func requireEvent(t *testing.T, testCtx *testContext, code string) events.AuditEvent {
t.Helper()
event := waitForAnyEvent(t, testCtx)
require.Equal(t, code, event.GetCode())
return event
}

func requireQueryEvent(t *testing.T, testCtx *testContext, code, query string) {
t.Helper()
event := waitForAnyEvent(t, testCtx)
require.Equal(t, code, event.GetCode())
require.Equal(t, query, event.(*events.DatabaseSessionQuery).DatabaseQuery)
}

func waitForAnyEvent(t *testing.T, testCtx *testContext) events.AuditEvent {
t.Helper()
select {
case event := <-testCtx.emitter.C():
return event
Expand All @@ -308,6 +384,7 @@ func waitForAnyEvent(t *testing.T, testCtx *testContext) events.AuditEvent {

// waitForEvent waits for particular event code ignoring other events.
func waitForEvent(t *testing.T, testCtx *testContext, code string) events.AuditEvent {
t.Helper()
for {
select {
case event := <-testCtx.emitter.C():
Expand Down
33 changes: 22 additions & 11 deletions lib/srv/db/postgres/engine.go
Expand Up @@ -19,6 +19,7 @@ package postgres
import (
"context"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -603,9 +604,12 @@ func formatParameters(parameters [][]byte, formatCodes []int16) (formatted []str
// by "parameter format codes" in the Bind message (0 - text, 1 - binary).
//
// Be a bit paranoid and make sure that number of format codes matches the
// number of parameters, or there are no format codes in which case all
// parameters will be text.
if len(formatCodes) != 0 && len(formatCodes) != len(parameters) {
// number of parameters, or there are zero or one format codes.
// zero format codes applies text format to all params.
// one format code applies the same format code to all params.
// https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-BIND
// https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-FUNCTIONCALL
if len(formatCodes) > 1 && len(formatCodes) != len(parameters) {
logrus.Warnf("Postgres parameter format codes and parameters don't match: %#v %#v.",
parameters, formatCodes)
return formatted
Expand All @@ -614,23 +618,30 @@ func formatParameters(parameters [][]byte, formatCodes []int16) (formatted []str
// According to Bind message documentation, if there are no parameter
// format codes, it may mean that either there are no parameters, or
// that all parameters use default text format.
if len(formatCodes) == 0 {
formatted = append(formatted, string(p))
continue
var formatCode int16
switch len(formatCodes) {
case 0:
// use default 0 (text) format for all params.
case 1:
// apply the same format code to all params.
formatCode = formatCodes[0]
default:
// apply format code corresponding to this param.
formatCode = formatCodes[i]
}
switch formatCodes[i] {

switch formatCode {
case parameterFormatCodeText:
// Text parameters can just be converted to their string
// representation.
formatted = append(formatted, string(p))
case parameterFormatCodeBinary:
// For binary parameters, just put a placeholder to avoid
// spamming the audit log with unreadable info.
formatted = append(formatted, "<binary>")
// For binary parameters, encode the parameter as a base64 string.
formatted = append(formatted, base64.StdEncoding.EncodeToString(p))
default:
// Should never happen but...
logrus.Warnf("Unknown Postgres parameter format code: %#v.",
formatCodes[i])
formatCode)
formatted = append(formatted, "<unknown>")
}
}
Expand Down