Skip to content

Commit

Permalink
Do not allow protocol messages larger than ~1GB
Browse files Browse the repository at this point in the history
The PostgreSQL server will reject messages greater than ~1 GB anyway.
However, worse than that is that a message that is larger than 4 GB
could wrap the 32-bit integer message size and be interpreted by the
server as multiple messages. This could allow a malicious client to
inject arbitrary protocol messages.

GHSA-mrww-27vc-gghv
  • Loading branch information
jackc committed Mar 4, 2024
1 parent c1b0a01 commit adbb38f
Show file tree
Hide file tree
Showing 61 changed files with 472 additions and 390 deletions.
46 changes: 41 additions & 5 deletions pgconn/pgconn.go
Expand Up @@ -1674,25 +1674,55 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) {
// Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip.
type Batch struct {
buf []byte
err error
}

// ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions.
func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf)
if batch.err != nil {
return
}

batch.buf, batch.err = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf)
if batch.err != nil {
return
}
batch.ExecPrepared("", paramValues, paramFormats, resultFormats)
}

// ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions.
func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
batch.buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf)
batch.buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf)
batch.buf = (&pgproto3.Execute{}).Encode(batch.buf)
if batch.err != nil {
return
}

batch.buf, batch.err = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf)
if batch.err != nil {
return
}

batch.buf, batch.err = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf)
if batch.err != nil {
return
}

batch.buf, batch.err = (&pgproto3.Execute{}).Encode(batch.buf)
if batch.err != nil {
return
}
}

// ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a
// transaction is already in progress or SQL contains transaction control statements. This is a simpler way of executing
// multiple queries in a single round trip than using pipeline mode.
func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader {
if batch.err != nil {
return &MultiResultReader{
closed: true,
err: batch.err,
}
}

if err := pgConn.lock(); err != nil {
return &MultiResultReader{
closed: true,
Expand All @@ -1718,7 +1748,13 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
pgConn.contextWatcher.Watch(ctx)
}

batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)
batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf)
if batch.err != nil {
multiResult.closed = true
multiResult.err = batch.err
pgConn.unlock()
return multiResult
}

pgConn.enterPotentialWriteReadDeadlock()
defer pgConn.exitPotentialWriteReadDeadlock()
Expand Down
13 changes: 10 additions & 3 deletions pgconn/pgconn_test.go
Expand Up @@ -3363,9 +3363,9 @@ func TestSNISupport(t *testing.T) {
return
}

srv.Write((&pgproto3.AuthenticationOk{}).Encode(nil))
srv.Write((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil))
srv.Write((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil))
srv.Write(mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil)))
srv.Write(mustEncode((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil)))
srv.Write(mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil)))

serverSNINameChan <- sniHost
}()
Expand Down Expand Up @@ -3472,3 +3472,10 @@ func TestFatalErrorReceivedInPipelineMode(t *testing.T) {
err = pipeline.Close()
require.Error(t, err)
}

func mustEncode(buf []byte, err error) []byte {
if err != nil {
panic(err)
}
return buf
}
7 changes: 3 additions & 4 deletions pgproto3/authentication_cleartext_password.go
Expand Up @@ -35,11 +35,10 @@ func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 8)
func (src *AuthenticationCleartextPassword) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
return dst
return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.
Expand Down
7 changes: 3 additions & 4 deletions pgproto3/authentication_gss.go
Expand Up @@ -27,11 +27,10 @@ func (a *AuthenticationGSS) Decode(src []byte) error {
return nil
}

func (a *AuthenticationGSS) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 4)
func (a *AuthenticationGSS) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeGSS)
return dst
return finishMessage(dst, sp)
}

func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {
Expand Down
7 changes: 3 additions & 4 deletions pgproto3/authentication_gss_continue.go
Expand Up @@ -31,12 +31,11 @@ func (a *AuthenticationGSSContinue) Decode(src []byte) error {
return nil
}

func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, int32(len(a.Data))+8)
func (a *AuthenticationGSSContinue) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
dst = append(dst, a.Data...)
return dst
return finishMessage(dst, sp)
}

func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {
Expand Down
7 changes: 3 additions & 4 deletions pgproto3/authentication_md5_password.go
Expand Up @@ -38,12 +38,11 @@ func (dst *AuthenticationMD5Password) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 12)
func (src *AuthenticationMD5Password) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeMD5Password)
dst = append(dst, src.Salt[:]...)
return dst
return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.
Expand Down
7 changes: 3 additions & 4 deletions pgproto3/authentication_ok.go
Expand Up @@ -35,11 +35,10 @@ func (dst *AuthenticationOk) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationOk) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 8)
func (src *AuthenticationOk) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeOk)
return dst
return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.
Expand Down
10 changes: 3 additions & 7 deletions pgproto3/authentication_sasl.go
Expand Up @@ -47,10 +47,8 @@ func (dst *AuthenticationSASL) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASL) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *AuthenticationSASL) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeSASL)

for _, s := range src.AuthMechanisms {
Expand All @@ -59,9 +57,7 @@ func (src *AuthenticationSASL) Encode(dst []byte) []byte {
}
dst = append(dst, 0)

pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))

return dst
return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.
Expand Down
12 changes: 3 additions & 9 deletions pgproto3/authentication_sasl_continue.go
Expand Up @@ -38,17 +38,11 @@ func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *AuthenticationSASLContinue) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)

dst = append(dst, src.Data...)

pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))

return dst
return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.
Expand Down
12 changes: 3 additions & 9 deletions pgproto3/authentication_sasl_final.go
Expand Up @@ -38,17 +38,11 @@ func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *AuthenticationSASLFinal) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)

dst = append(dst, src.Data...)

pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))

return dst
return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Unmarshaler.
Expand Down
25 changes: 21 additions & 4 deletions pgproto3/backend.go
Expand Up @@ -16,7 +16,8 @@ type Backend struct {
// before it is actually transmitted (i.e. before Flush).
tracer *tracer

wbuf []byte
wbuf []byte
encodeError error

// Frontend message flyweights
bind Bind
Expand Down Expand Up @@ -55,18 +56,34 @@ func NewBackend(r io.Reader, w io.Writer) *Backend {
return &Backend{cr: cr, w: w}
}

// Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is
// called.
// Send sends a message to the frontend (i.e. the client). The message is buffered until Flush is called. Any error
// encountered will be returned from Flush.
func (b *Backend) Send(msg BackendMessage) {
if b.encodeError != nil {
return
}

prevLen := len(b.wbuf)
b.wbuf = msg.Encode(b.wbuf)
newBuf, err := msg.Encode(b.wbuf)
if err != nil {
b.encodeError = err
return
}
b.wbuf = newBuf

if b.tracer != nil {
b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg)
}
}

// Flush writes any pending messages to the frontend (i.e. the client).
func (b *Backend) Flush() error {
if err := b.encodeError; err != nil {
b.encodeError = nil
b.wbuf = b.wbuf[:0]
return &writeError{err: err, safeToRetry: true}
}

n, err := b.w.Write(b.wbuf)

const maxLen = 1024
Expand Down
7 changes: 3 additions & 4 deletions pgproto3/backend_key_data.go
Expand Up @@ -29,12 +29,11 @@ func (dst *BackendKeyData) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *BackendKeyData) Encode(dst []byte) []byte {
dst = append(dst, 'K')
dst = pgio.AppendUint32(dst, 12)
func (src *BackendKeyData) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'K')
dst = pgio.AppendUint32(dst, src.ProcessID)
dst = pgio.AppendUint32(dst, src.SecretKey)
return dst
return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.
Expand Down
4 changes: 2 additions & 2 deletions pgproto3/backend_test.go
Expand Up @@ -71,8 +71,8 @@ func TestStartupMessage(t *testing.T) {
"username": "tester",
},
}
dst := []byte{}
dst = want.Encode(dst)
dst, err := want.Encode([]byte{})
require.NoError(t, err)

server := &interruptReader{}
server.push(dst)
Expand Down
10 changes: 3 additions & 7 deletions pgproto3/bind.go
Expand Up @@ -108,10 +108,8 @@ func (dst *Bind) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Bind) Encode(dst []byte) []byte {
dst = append(dst, 'B')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *Bind) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'B')

dst = append(dst, src.DestinationPortal...)
dst = append(dst, 0)
Expand Down Expand Up @@ -139,9 +137,7 @@ func (src *Bind) Encode(dst []byte) []byte {
dst = pgio.AppendInt16(dst, fc)
}

pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))

return dst
return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.
Expand Down
4 changes: 2 additions & 2 deletions pgproto3/bind_complete.go
Expand Up @@ -20,8 +20,8 @@ func (dst *BindComplete) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *BindComplete) Encode(dst []byte) []byte {
return append(dst, '2', 0, 0, 0, 4)
func (src *BindComplete) Encode(dst []byte) ([]byte, error) {
return append(dst, '2', 0, 0, 0, 4), nil
}

// MarshalJSON implements encoding/json.Marshaler.
Expand Down
20 changes: 20 additions & 0 deletions pgproto3/bind_test.go
@@ -0,0 +1,20 @@
package pgproto3_test

import (
"testing"

"github.com/jackc/pgx/v5/pgproto3"
"github.com/stretchr/testify/require"
)

func TestBindBiggerThanMaxMessageBodyLen(t *testing.T) {
t.Parallel()

// Maximum allowed size.
_, err := (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-16)}}).Encode(nil)
require.NoError(t, err)

// 1 byte too big
_, err = (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-15)}}).Encode(nil)
require.Error(t, err)
}
4 changes: 2 additions & 2 deletions pgproto3/cancel_request.go
Expand Up @@ -36,12 +36,12 @@ func (dst *CancelRequest) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 4 byte message length.
func (src *CancelRequest) Encode(dst []byte) []byte {
func (src *CancelRequest) Encode(dst []byte) ([]byte, error) {
dst = pgio.AppendInt32(dst, 16)
dst = pgio.AppendInt32(dst, cancelRequestCode)
dst = pgio.AppendUint32(dst, src.ProcessID)
dst = pgio.AppendUint32(dst, src.SecretKey)
return dst
return dst, nil
}

// MarshalJSON implements encoding/json.Marshaler.
Expand Down

0 comments on commit adbb38f

Please sign in to comment.