Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
SQL injection can occur if an attacker can cause a single query or bind
message to exceed 4 GB in size. An integer overflow in the calculated
message size can cause the one large message to be sent as multiple
messages under the attacker's control.

Update to pgproto3/v2 v2.3.3 which checks for too large messages.
  • Loading branch information
jackc committed Mar 4, 2024
1 parent e82f7d1 commit c672dff
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 29 deletions.
12 changes: 10 additions & 2 deletions auth_scram.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,11 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
AuthMechanism: "SCRAM-SHA-256",
Data: sc.clientFirstMessage(),
}
_, err = c.conn.Write(saslInitialResponse.Encode(nil))
buf, err := saslInitialResponse.Encode(nil)
if err != nil {
return err
}
_, err = c.conn.Write(buf)
if err != nil {
return err
}
Expand All @@ -60,7 +64,11 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
saslResponse := &pgproto3.SASLResponse{
Data: []byte(sc.clientFinalMessage()),
}
_, err = c.conn.Write(saslResponse.Encode(nil))
buf, err = saslResponse.Encode(nil)
if err != nil {
return err
}
_, err = c.conn.Write(buf)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ require (
github.com/jackc/pgio v1.0.0
github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65
github.com/jackc/pgpassfile v1.0.0
github.com/jackc/pgproto3/v2 v2.3.2
github.com/jackc/pgproto3/v2 v2.3.3
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a
github.com/stretchr/testify v1.8.1
golang.org/x/crypto v0.6.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ github.com/jackc/pgproto3/v2 v2.3.1 h1:nwj7qwf0S+Q7ISFfBndqeLwSwxs+4DPsbRFjECT1Y
github.com/jackc/pgproto3/v2 v2.3.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
github.com/jackc/pgproto3/v2 v2.3.2 h1:7eY55bdBeCz1F2fTzSz69QC+pG46jYq9/jtSPiJ5nn0=
github.com/jackc/pgproto3/v2 v2.3.2/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
github.com/jackc/pgproto3/v2 v2.3.3 h1:1HLSx5H+tXR9pW3in3zaztoEwQYRC9SQaYUHjTSUOag=
github.com/jackc/pgproto3/v2 v2.3.3/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg=
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
Expand Down
6 changes: 5 additions & 1 deletion krb5.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,11 @@ func (c *PgConn) gssAuth() error {
gssResponse := &pgproto3.GSSResponse{
Data: nextData,
}
_, err = c.conn.Write(gssResponse.Encode(nil))
buf, err := gssResponse.Encode(nil)
if err != nil {
return err
}
_, err = c.conn.Write(buf)
if err != nil {
return err
}
Expand Down
169 changes: 148 additions & 21 deletions pgconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,11 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
startupMsg.Parameters["database"] = config.Database
}

if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil {
buf, err := startupMsg.Encode(pgConn.wbuf)
if err != nil {
return nil, &connectError{config: config, msg: "failed to write startup message", err: err}
}
if _, err := pgConn.conn.Write(buf); err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed to write startup message", err: err}
}
Expand Down Expand Up @@ -419,7 +423,11 @@ func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {

func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
msg := &pgproto3.PasswordMessage{Password: password}
_, err = pgConn.conn.Write(msg.Encode(pgConn.wbuf))
buf, err := msg.Encode(pgConn.wbuf)
if err != nil {
return err
}
_, err = pgConn.conn.Write(buf)
return err
}

Expand Down Expand Up @@ -832,9 +840,19 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
}

buf := pgConn.wbuf
buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf)
buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf)
buf = (&pgproto3.Sync{}).Encode(buf)
var err error
buf, err = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf)
if err != nil {
return nil, err
}
buf, err = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf)
if err != nil {
return nil, err
}
buf, err = (&pgproto3.Sync{}).Encode(buf)
if err != nil {
return nil, err
}

n, err := pgConn.conn.Write(buf)
if err != nil {
Expand Down Expand Up @@ -1006,7 +1024,14 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
}

buf := pgConn.wbuf
buf = (&pgproto3.Query{String: sql}).Encode(buf)
var err error
buf, err = (&pgproto3.Query{String: sql}).Encode(buf)
if err != nil {
return &MultiResultReader{
closed: true,
err: err,
}
}

n, err := pgConn.conn.Write(buf)
if err != nil {
Expand Down Expand Up @@ -1080,8 +1105,24 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
}

buf := pgConn.wbuf
buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf)
buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf)
var err error
buf, err = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf)
if err != nil {
result.concludeCommand(nil, err)
pgConn.contextWatcher.Unwatch()
result.closed = true
pgConn.unlock()
return result
}

buf, err = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf)
if err != nil {
result.concludeCommand(nil, err)
pgConn.contextWatcher.Unwatch()
result.closed = true
pgConn.unlock()
return result
}

pgConn.execExtendedSuffix(buf, result)

Expand All @@ -1107,7 +1148,15 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
}

buf := pgConn.wbuf
buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf)
var err error
buf, err = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf)
if err != nil {
result.concludeCommand(nil, err)
pgConn.contextWatcher.Unwatch()
result.closed = true
pgConn.unlock()
return result
}

pgConn.execExtendedSuffix(buf, result)

Expand Down Expand Up @@ -1150,9 +1199,31 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
}

func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) {
buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf)
buf = (&pgproto3.Execute{}).Encode(buf)
buf = (&pgproto3.Sync{}).Encode(buf)
var err error
buf, err = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf)
if err != nil {
result.concludeCommand(nil, err)
pgConn.contextWatcher.Unwatch()
result.closed = true
pgConn.unlock()
return
}
buf, err = (&pgproto3.Execute{}).Encode(buf)
if err != nil {
result.concludeCommand(nil, err)
pgConn.contextWatcher.Unwatch()
result.closed = true
pgConn.unlock()
return
}
buf, err = (&pgproto3.Sync{}).Encode(buf)
if err != nil {
result.concludeCommand(nil, err)
pgConn.contextWatcher.Unwatch()
result.closed = true
pgConn.unlock()
return
}

n, err := pgConn.conn.Write(buf)
if err != nil {
Expand Down Expand Up @@ -1186,7 +1257,12 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm

// Send copy to command
buf := pgConn.wbuf
buf = (&pgproto3.Query{String: sql}).Encode(buf)
var err error
buf, err = (&pgproto3.Query{String: sql}).Encode(buf)
if err != nil {
pgConn.unlock()
return nil, err
}

n, err := pgConn.conn.Write(buf)
if err != nil {
Expand Down Expand Up @@ -1246,7 +1322,12 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co

// Send copy to command
buf := pgConn.wbuf
buf = (&pgproto3.Query{String: sql}).Encode(buf)
var err error
buf, err = (&pgproto3.Query{String: sql}).Encode(buf)
if err != nil {
pgConn.unlock()
return nil, err
}

n, err := pgConn.conn.Write(buf)
if err != nil {
Expand Down Expand Up @@ -1322,10 +1403,20 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
buf = buf[:0]
if copyErr == io.EOF || pgErr != nil {
copyDone := &pgproto3.CopyDone{}
buf = copyDone.Encode(buf)
var err error
buf, err = copyDone.Encode(buf)
if err != nil {
pgConn.asyncClose()
return nil, err
}
} else {
copyFail := &pgproto3.CopyFail{Message: copyErr.Error()}
buf = copyFail.Encode(buf)
var err error
buf, err = copyFail.Encode(buf)
if err != nil {
pgConn.asyncClose()
return nil, err
}
}
_, err = pgConn.conn.Write(buf)
if err != nil {
Expand Down Expand Up @@ -1632,24 +1723,54 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) {
// Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip.
type Batch struct {
buf []byte
err error
}

// ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions.
func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf)
if batch.err != nil {
return
}

batch.buf, batch.err = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf)
if batch.err != nil {
return
}
batch.ExecPrepared("", paramValues, paramFormats, resultFormats)
}

// ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions.
func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
batch.buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf)
batch.buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf)
batch.buf = (&pgproto3.Execute{}).Encode(batch.buf)
if batch.err != nil {
return
}

batch.buf, batch.err = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf)
if batch.err != nil {
return
}

batch.buf, batch.err = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf)
if batch.err != nil {
return
}

batch.buf, batch.err = (&pgproto3.Execute{}).Encode(batch.buf)
if batch.err != nil {
return
}
}

// ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a
// transaction is already in progress or SQL contains transaction control statements.
func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader {
if batch.err != nil {
return &MultiResultReader{
closed: true,
err: batch.err,
}
}

if err := pgConn.lock(); err != nil {
return &MultiResultReader{
closed: true,
Expand All @@ -1675,7 +1796,13 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
pgConn.contextWatcher.Watch(ctx)
}

batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)
batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf)
if batch.err != nil {
multiResult.closed = true
multiResult.err = batch.err
pgConn.unlock()
return multiResult
}

// A large batch can deadlock without concurrent reading and writing. If the Write fails the underlying net.Conn is
// closed. This is all that can be done without introducing a race condition or adding a concurrent safe communication
Expand Down
16 changes: 12 additions & 4 deletions pgconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2001,7 +2001,8 @@ func TestConnSendBytesAndReceiveMessage(t *testing.T) {
defer closeConn(t, pgConn)

queryMsg := pgproto3.Query{String: "select 42"}
buf := queryMsg.Encode(nil)
buf, err := queryMsg.Encode(nil)
require.NoError(t, err)

err = pgConn.SendBytes(ctx, buf)
require.NoError(t, err)
Expand Down Expand Up @@ -2315,9 +2316,9 @@ func TestSNISupport(t *testing.T) {
return
}

srv.Write((&pgproto3.AuthenticationOk{}).Encode(nil))
srv.Write((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil))
srv.Write((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil))
srv.Write(mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil)))
srv.Write(mustEncode((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil)))
srv.Write(mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil)))

serverSNINameChan <- sniHost
}()
Expand Down Expand Up @@ -2385,3 +2386,10 @@ func TestCopyFrom(t *testing.T) {
_, err = pgConn.CopyFrom(context.Background(), r2, "COPY t FROM STDIN")
assert.NoError(t, err)
}

func mustEncode(buf []byte, err error) []byte {
if err != nil {
panic(err)
}
return buf
}

0 comments on commit c672dff

Please sign in to comment.