From c672dff9d7a456abb81542ba5e61c372540e54b7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 4 Mar 2024 08:34:53 -0600 Subject: [PATCH] Fix CVE-2024-27304 SQL injection can occur if an attacker can cause a single query or bind message to exceed 4 GB in size. An integer overflow in the calculated message size can cause the one large message to be sent as multiple messages under the attacker's control. Update to pgproto3/v2 v2.3.3 which checks for too large messages. --- auth_scram.go | 12 +++- go.mod | 2 +- go.sum | 2 + krb5.go | 6 +- pgconn.go | 169 +++++++++++++++++++++++++++++++++++++++++++------ pgconn_test.go | 16 +++-- 6 files changed, 178 insertions(+), 29 deletions(-) diff --git a/auth_scram.go b/auth_scram.go index d8d7111..1545b7c 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -41,7 +41,11 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { AuthMechanism: "SCRAM-SHA-256", Data: sc.clientFirstMessage(), } - _, err = c.conn.Write(saslInitialResponse.Encode(nil)) + buf, err := saslInitialResponse.Encode(nil) + if err != nil { + return err + } + _, err = c.conn.Write(buf) if err != nil { return err } @@ -60,7 +64,11 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { saslResponse := &pgproto3.SASLResponse{ Data: []byte(sc.clientFinalMessage()), } - _, err = c.conn.Write(saslResponse.Encode(nil)) + buf, err = saslResponse.Encode(nil) + if err != nil { + return err + } + _, err = c.conn.Write(buf) if err != nil { return err } diff --git a/go.mod b/go.mod index 98a95dc..e007fcc 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.3.2 + github.com/jackc/pgproto3/v2 v2.3.3 github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a github.com/stretchr/testify v1.8.1 golang.org/x/crypto v0.6.0 diff --git a/go.sum b/go.sum index 272719c..fbf8fe7 100644 --- a/go.sum +++ b/go.sum @@ -40,6 +40,8 @@ github.com/jackc/pgproto3/v2 v2.3.1 h1:nwj7qwf0S+Q7ISFfBndqeLwSwxs+4DPsbRFjECT1Y github.com/jackc/pgproto3/v2 v2.3.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.3.2 h1:7eY55bdBeCz1F2fTzSz69QC+pG46jYq9/jtSPiJ5nn0= github.com/jackc/pgproto3/v2 v2.3.2/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.3.3 h1:1HLSx5H+tXR9pW3in3zaztoEwQYRC9SQaYUHjTSUOag= +github.com/jackc/pgproto3/v2 v2.3.3/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= diff --git a/krb5.go b/krb5.go index 08427b8..1639b72 100644 --- a/krb5.go +++ b/krb5.go @@ -62,7 +62,11 @@ func (c *PgConn) gssAuth() error { gssResponse := &pgproto3.GSSResponse{ Data: nextData, } - _, err = c.conn.Write(gssResponse.Encode(nil)) + buf, err := gssResponse.Encode(nil) + if err != nil { + return err + } + _, err = c.conn.Write(buf) if err != nil { return err } diff --git a/pgconn.go b/pgconn.go index e531303..894baa2 100644 --- a/pgconn.go +++ b/pgconn.go @@ -314,7 +314,11 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig startupMsg.Parameters["database"] = config.Database } - if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil { + buf, err := startupMsg.Encode(pgConn.wbuf) + if err != nil { + return nil, &connectError{config: config, msg: "failed to write startup message", err: err} + } + if _, err := pgConn.conn.Write(buf); err != nil { pgConn.conn.Close() return nil, &connectError{config: config, msg: "failed to write startup message", err: err} } @@ -419,7 +423,11 @@ func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { func (pgConn *PgConn) txPasswordMessage(password string) (err error) { msg := &pgproto3.PasswordMessage{Password: password} - _, err = pgConn.conn.Write(msg.Encode(pgConn.wbuf)) + buf, err := msg.Encode(pgConn.wbuf) + if err != nil { + return err + } + _, err = pgConn.conn.Write(buf) return err } @@ -832,9 +840,19 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ } buf := pgConn.wbuf - buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) - buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf) - buf = (&pgproto3.Sync{}).Encode(buf) + var err error + buf, err = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) + if err != nil { + return nil, err + } + buf, err = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf) + if err != nil { + return nil, err + } + buf, err = (&pgproto3.Sync{}).Encode(buf) + if err != nil { + return nil, err + } n, err := pgConn.conn.Write(buf) if err != nil { @@ -1006,7 +1024,14 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { } buf := pgConn.wbuf - buf = (&pgproto3.Query{String: sql}).Encode(buf) + var err error + buf, err = (&pgproto3.Query{String: sql}).Encode(buf) + if err != nil { + return &MultiResultReader{ + closed: true, + err: err, + } + } n, err := pgConn.conn.Write(buf) if err != nil { @@ -1080,8 +1105,24 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] } buf := pgConn.wbuf - buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) - buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) + var err error + buf, err = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) + if err != nil { + result.concludeCommand(nil, err) + pgConn.contextWatcher.Unwatch() + result.closed = true + pgConn.unlock() + return result + } + + buf, err = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) + if err != nil { + result.concludeCommand(nil, err) + pgConn.contextWatcher.Unwatch() + result.closed = true + pgConn.unlock() + return result + } pgConn.execExtendedSuffix(buf, result) @@ -1107,7 +1148,15 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa } buf := pgConn.wbuf - buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) + var err error + buf, err = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) + if err != nil { + result.concludeCommand(nil, err) + pgConn.contextWatcher.Unwatch() + result.closed = true + pgConn.unlock() + return result + } pgConn.execExtendedSuffix(buf, result) @@ -1150,9 +1199,31 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by } func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { - buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) - buf = (&pgproto3.Execute{}).Encode(buf) - buf = (&pgproto3.Sync{}).Encode(buf) + var err error + buf, err = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) + if err != nil { + result.concludeCommand(nil, err) + pgConn.contextWatcher.Unwatch() + result.closed = true + pgConn.unlock() + return + } + buf, err = (&pgproto3.Execute{}).Encode(buf) + if err != nil { + result.concludeCommand(nil, err) + pgConn.contextWatcher.Unwatch() + result.closed = true + pgConn.unlock() + return + } + buf, err = (&pgproto3.Sync{}).Encode(buf) + if err != nil { + result.concludeCommand(nil, err) + pgConn.contextWatcher.Unwatch() + result.closed = true + pgConn.unlock() + return + } n, err := pgConn.conn.Write(buf) if err != nil { @@ -1186,7 +1257,12 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm // Send copy to command buf := pgConn.wbuf - buf = (&pgproto3.Query{String: sql}).Encode(buf) + var err error + buf, err = (&pgproto3.Query{String: sql}).Encode(buf) + if err != nil { + pgConn.unlock() + return nil, err + } n, err := pgConn.conn.Write(buf) if err != nil { @@ -1246,7 +1322,12 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co // Send copy to command buf := pgConn.wbuf - buf = (&pgproto3.Query{String: sql}).Encode(buf) + var err error + buf, err = (&pgproto3.Query{String: sql}).Encode(buf) + if err != nil { + pgConn.unlock() + return nil, err + } n, err := pgConn.conn.Write(buf) if err != nil { @@ -1322,10 +1403,20 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co buf = buf[:0] if copyErr == io.EOF || pgErr != nil { copyDone := &pgproto3.CopyDone{} - buf = copyDone.Encode(buf) + var err error + buf, err = copyDone.Encode(buf) + if err != nil { + pgConn.asyncClose() + return nil, err + } } else { copyFail := &pgproto3.CopyFail{Message: copyErr.Error()} - buf = copyFail.Encode(buf) + var err error + buf, err = copyFail.Encode(buf) + if err != nil { + pgConn.asyncClose() + return nil, err + } } _, err = pgConn.conn.Write(buf) if err != nil { @@ -1632,24 +1723,54 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { // Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip. type Batch struct { buf []byte + err error } // ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions. func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { - batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf) + if batch.err != nil { + return + } + + batch.buf, batch.err = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf) + if batch.err != nil { + return + } batch.ExecPrepared("", paramValues, paramFormats, resultFormats) } // ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions. func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { - batch.buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf) - batch.buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf) - batch.buf = (&pgproto3.Execute{}).Encode(batch.buf) + if batch.err != nil { + return + } + + batch.buf, batch.err = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf) + if batch.err != nil { + return + } + + batch.buf, batch.err = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf) + if batch.err != nil { + return + } + + batch.buf, batch.err = (&pgproto3.Execute{}).Encode(batch.buf) + if batch.err != nil { + return + } } // ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a // transaction is already in progress or SQL contains transaction control statements. func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader { + if batch.err != nil { + return &MultiResultReader{ + closed: true, + err: batch.err, + } + } + if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, @@ -1675,7 +1796,13 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR pgConn.contextWatcher.Watch(ctx) } - batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) + batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf) + if batch.err != nil { + multiResult.closed = true + multiResult.err = batch.err + pgConn.unlock() + return multiResult + } // A large batch can deadlock without concurrent reading and writing. If the Write fails the underlying net.Conn is // closed. This is all that can be done without introducing a race condition or adding a concurrent safe communication diff --git a/pgconn_test.go b/pgconn_test.go index 3b5aa66..bd89771 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -2001,7 +2001,8 @@ func TestConnSendBytesAndReceiveMessage(t *testing.T) { defer closeConn(t, pgConn) queryMsg := pgproto3.Query{String: "select 42"} - buf := queryMsg.Encode(nil) + buf, err := queryMsg.Encode(nil) + require.NoError(t, err) err = pgConn.SendBytes(ctx, buf) require.NoError(t, err) @@ -2315,9 +2316,9 @@ func TestSNISupport(t *testing.T) { return } - srv.Write((&pgproto3.AuthenticationOk{}).Encode(nil)) - srv.Write((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil)) - srv.Write((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil)) + srv.Write(mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil))) + srv.Write(mustEncode((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil))) + srv.Write(mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil))) serverSNINameChan <- sniHost }() @@ -2385,3 +2386,10 @@ func TestCopyFrom(t *testing.T) { _, err = pgConn.CopyFrom(context.Background(), r2, "COPY t FROM STDIN") assert.NoError(t, err) } + +func mustEncode(buf []byte, err error) []byte { + if err != nil { + panic(err) + } + return buf +}