Skip to content

Commit

Permalink
Support auditing chunked SQL Server packets (#29228) (#30245)
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielcorado committed Aug 10, 2023
1 parent 5484124 commit a54e268
Show file tree
Hide file tree
Showing 7 changed files with 459 additions and 76 deletions.
52 changes: 45 additions & 7 deletions lib/srv/db/sqlserver/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package sqlserver

import (
"bytes"
"context"
"io"
"net"
Expand Down Expand Up @@ -134,6 +135,11 @@ func (e *Engine) receiveFromClient(clientConn, serverConn io.ReadWriteCloser, cl
e.Log.Debug("Stop receiving from client.")
close(clientErrCh)
}()
// initialPacketHeader and chunkData are used to accumulate chunked packets
// to build a single packet with full contents for auditing.
var initialPacketHeader protocol.PacketHeader
var chunkData bytes.Buffer

for {
p, err := protocol.ReadPacket(clientConn)
if err != nil {
Expand All @@ -146,13 +152,24 @@ func (e *Engine) receiveFromClient(clientConn, serverConn io.ReadWriteCloser, cl
return
}

sqlPacket, err := protocol.ToSQLPacket(p)
switch {
case err != nil:
e.Log.WithError(err).Errorf("Failed to parse SQLServer packet.")
e.emitMalformedPacket(e.Context, sessionCtx, p)
default:
e.auditPacket(e.Context, sessionCtx, sqlPacket)
// Audit events are going to be emitted only on final messages, this way
// the packet parsing can be complete and provide the query/RPC
// contents.
if protocol.IsFinalPacket(p) {
sqlPacket, err := e.toSQLPacket(initialPacketHeader, p, &chunkData)
switch {
case err != nil:
e.Log.WithError(err).Errorf("Failed to parse SQLServer packet.")
e.emitMalformedPacket(e.Context, sessionCtx, p)
default:
e.auditPacket(e.Context, sessionCtx, sqlPacket)
}
} else {
if chunkData.Len() == 0 {
initialPacketHeader = p.Header()
}

chunkData.Write(p.Data())
}

_, err = serverConn.Write(p.Bytes())
Expand All @@ -164,6 +181,27 @@ func (e *Engine) receiveFromClient(clientConn, serverConn io.ReadWriteCloser, cl
}
}

// toSQLPacket Parses a regular (self-contained) or chunked packet into an SQL
// packet (used for auditing).
func (e *Engine) toSQLPacket(header protocol.PacketHeader, packet *protocol.BasicPacket, chunks *bytes.Buffer) (protocol.Packet, error) {
if chunks.Len() > 0 {
defer chunks.Reset()
chunks.Write(packet.Data())
// We're safe to "read" chunk using `Bytes()` function because the
// packet processing copies the packet contents.
packetData := chunks.Bytes()

var err error
// The final chucked packet header must be the first packet header.
packet, err = protocol.NewBasicPacket(header, packetData)
if err != nil {
return nil, trace.Wrap(err)
}
}

return protocol.ToSQLPacket(packet)
}

// receiveFromServer relays protocol messages received from MySQL database
// to MySQL client.
func (e *Engine) receiveFromServer(serverConn, clientConn io.ReadWriteCloser, serverErrCh chan<- error) {
Expand Down
139 changes: 126 additions & 13 deletions lib/srv/db/sqlserver/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,13 @@ func TestHandleConnectionAuditEvents(t *testing.T) {
}

tests := []struct {
name string
packet []byte
checks []check
name string
packets [][]byte
checks []check
}{
{
name: "rpc request procedure",
packet: fixtures.RPCClientRequest,
name: "rpc request procedure",
packets: [][]byte{fixtures.GenerateCustomRPCCallPacket("foo3")},
checks: []check{
hasNoErr(),
hasAuditEventCode(libevents.DatabaseSessionStartCode),
Expand All @@ -90,8 +90,8 @@ func TestHandleConnectionAuditEvents(t *testing.T) {
},
},
{
name: "rpc request param",
packet: fixtures.RPCClientRequestParam,
name: "rpc request param",
packets: [][]byte{fixtures.GenerateExecuteSQLRPCPacket("select @@version")},
checks: []check{
hasNoErr(),
hasAuditEventCode(libevents.DatabaseSessionStartCode),
Expand All @@ -110,8 +110,8 @@ func TestHandleConnectionAuditEvents(t *testing.T) {
},
},
{
name: "sql batch",
packet: fixtures.SQLBatch,
name: "sql batch",
packets: [][]byte{fixtures.GenerateBatchQueryPacket("\nselect 'foo' as 'bar'\n ")},
checks: []check{
hasNoErr(),
hasAuditEventCode(libevents.DatabaseSessionStartCode),
Expand All @@ -132,15 +132,114 @@ func TestHandleConnectionAuditEvents(t *testing.T) {
},
},
{
name: "malformed packet",
packet: fixtures.MalformedPacketTest,
name: "malformed packet",
packets: [][]byte{fixtures.MalformedPacketTest},
checks: []check{
hasNoErr(),
hasAuditEventCode(libevents.DatabaseSessionStartCode),
hasAuditEventCode(libevents.DatabaseSessionEndCode),
hasAuditEventCode(libevents.DatabaseSessionMalformedPacketCode),
},
},
{
name: "sql batch chunked packets",
packets: fixtures.GenerateBatchQueryChunkedPacket(5, "select 'foo' as 'bar'"),
checks: []check{
hasNoErr(),
hasAuditEventCode(libevents.DatabaseSessionStartCode),
hasAuditEventCode(libevents.DatabaseSessionEndCode),
hasAuditEvent(1, &events.DatabaseSessionQuery{
DatabaseMetadata: events.DatabaseMetadata{
DatabaseUser: "sa",
},
Metadata: events.Metadata{
Type: libevents.DatabaseSessionQueryEvent,
Code: libevents.DatabaseSessionQueryCode,
},
DatabaseQuery: "select 'foo' as 'bar'",
Status: events.Status{
Success: true,
},
}),
},
},
{
name: "rpc request param chunked",
packets: fixtures.GenerateExecuteSQLRPCChunkedPacket(5, "select @@version"),
checks: []check{
hasNoErr(),
hasAuditEventCode(libevents.DatabaseSessionStartCode),
hasAuditEventCode(libevents.DatabaseSessionEndCode),
hasAuditEvent(1, &events.SQLServerRPCRequest{
DatabaseMetadata: events.DatabaseMetadata{
DatabaseUser: "sa",
},
Metadata: events.Metadata{
Type: libevents.DatabaseSessionSQLServerRPCRequestEvent,
Code: libevents.SQLServerRPCRequestCode,
},
Parameters: []string{"select @@version"},
Procname: "Sp_ExecuteSql",
}),
},
},
{
name: "intercalated chunked messages",
packets: intercalateChunkedPacketMessages(
fixtures.GenerateExecuteSQLRPCChunkedPacket(5, "select @@version"),
fixtures.GenerateExecuteSQLRPCPacket("select 1"),
2,
),
checks: []check{
hasNoErr(),
hasAuditEventCode(libevents.DatabaseSessionStartCode),
hasAuditEventCode(libevents.DatabaseSessionEndCode),
hasAuditEvent(1, &events.SQLServerRPCRequest{
DatabaseMetadata: events.DatabaseMetadata{
DatabaseUser: "sa",
},
Metadata: events.Metadata{
Type: libevents.DatabaseSessionSQLServerRPCRequestEvent,
Code: libevents.SQLServerRPCRequestCode,
},
Parameters: []string{"select @@version"},
Procname: "Sp_ExecuteSql",
}),
hasAuditEvent(2, &events.SQLServerRPCRequest{
DatabaseMetadata: events.DatabaseMetadata{
DatabaseUser: "sa",
},
Metadata: events.Metadata{
Type: libevents.DatabaseSessionSQLServerRPCRequestEvent,
Code: libevents.SQLServerRPCRequestCode,
},
Parameters: []string{"select 1"},
Procname: "Sp_ExecuteSql",
}),
hasAuditEvent(3, &events.SQLServerRPCRequest{
DatabaseMetadata: events.DatabaseMetadata{
DatabaseUser: "sa",
},
Metadata: events.Metadata{
Type: libevents.DatabaseSessionSQLServerRPCRequestEvent,
Code: libevents.SQLServerRPCRequestCode,
},
Parameters: []string{"select @@version"},
Procname: "Sp_ExecuteSql",
}),
hasAuditEvent(4, &events.SQLServerRPCRequest{
DatabaseMetadata: events.DatabaseMetadata{
DatabaseUser: "sa",
},
Metadata: events.Metadata{
Type: libevents.DatabaseSessionSQLServerRPCRequestEvent,
Code: libevents.SQLServerRPCRequestCode,
},
Parameters: []string{"select 1"},
Procname: "Sp_ExecuteSql",
}),
},
},
}

for _, tc := range tests {
Expand All @@ -149,8 +248,11 @@ func TestHandleConnectionAuditEvents(t *testing.T) {
_, err := b.Write(fixtures.Login7)
require.NoError(t, err)

_, err = b.Write(tc.packet)
require.NoError(t, err)
for _, packet := range tc.packets {
_, err = b.Write(packet)
require.NoError(t, err)
}

emitterMock := &mockEmitter{}
audit, err := common.NewAudit(common.AuditConfig{Emitter: emitterMock})
require.NoError(t, err)
Expand Down Expand Up @@ -183,6 +285,17 @@ func TestHandleConnectionAuditEvents(t *testing.T) {
}
}

// intercalateChunkedPacketMessages intercalates a chunked packet with a regular packet a specified number of times.
func intercalateChunkedPacketMessages(chunkedPacket [][]byte, regularPacket []byte, repeat int) [][]byte {
var result [][]byte
for i := 0; i < repeat; i++ {
result = append(result, chunkedPacket...)
result = append(result, regularPacket)
}

return result
}

type mockConn struct {
net.Conn
buff bytes.Buffer
Expand Down
5 changes: 3 additions & 2 deletions lib/srv/db/sqlserver/protocol/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ const (
// packetHeaderSize is the size of the protocol packet header.
packetHeaderSize = 8

// packetStatusLast indicates that the packet is the last in the request.
packetStatusLast uint8 = 0x01
// PacketStatusLast indicates that the packet is the last in the request.
// https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/ce398f9a-7d47-4ede-8f36-9dd6fc21ca43
PacketStatusLast uint8 = 0x01

preLoginOptionVersion = 0x00
preLoginOptionEncryption = 0x01
Expand Down

0 comments on commit a54e268

Please sign in to comment.