Skip to content

Commit

Permalink
Backport fixes from pgx v5
Browse files Browse the repository at this point in the history
Check for overflow on uint16 sizes in pgproto3

Do not allow protocol messages larger than ~1GB

The PostgreSQL server will reject messages greater than ~1 GB anyway.
However, worse than that is that a message that is larger than 4 GB
could wrap the 32-bit integer message size and be interpreted by the
server as multiple messages. This could allow a malicious client to
inject arbitrary protocol messages.

GHSA-mrww-27vc-gghv
  • Loading branch information
jackc committed Mar 4, 2024
1 parent 0c0f7b0 commit 945c212
Show file tree
Hide file tree
Showing 59 changed files with 359 additions and 359 deletions.
7 changes: 3 additions & 4 deletions authentication_cleartext_password.go
Expand Up @@ -35,11 +35,10 @@ func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 8)
func (src *AuthenticationCleartextPassword) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
return dst
return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.
Expand Down
8 changes: 4 additions & 4 deletions authentication_gss.go
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/binary"
"encoding/json"
"errors"

"github.com/jackc/pgio"
)

Expand All @@ -26,11 +27,10 @@ func (a *AuthenticationGSS) Decode(src []byte) error {
return nil
}

func (a *AuthenticationGSS) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 4)
func (a *AuthenticationGSS) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeGSS)
return dst
return finishMessage(dst, sp)
}

func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {
Expand Down
8 changes: 4 additions & 4 deletions authentication_gss_continue.go
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/binary"
"encoding/json"
"errors"

"github.com/jackc/pgio"
)

Expand All @@ -30,12 +31,11 @@ func (a *AuthenticationGSSContinue) Decode(src []byte) error {
return nil
}

func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, int32(len(a.Data))+8)
func (a *AuthenticationGSSContinue) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
dst = append(dst, a.Data...)
return dst
return finishMessage(dst, sp)
}

func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {
Expand Down
7 changes: 3 additions & 4 deletions authentication_md5_password.go
Expand Up @@ -38,12 +38,11 @@ func (dst *AuthenticationMD5Password) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 12)
func (src *AuthenticationMD5Password) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeMD5Password)
dst = append(dst, src.Salt[:]...)
return dst
return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.
Expand Down
7 changes: 3 additions & 4 deletions authentication_ok.go
Expand Up @@ -35,11 +35,10 @@ func (dst *AuthenticationOk) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationOk) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 8)
func (src *AuthenticationOk) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeOk)
return dst
return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.
Expand Down
10 changes: 3 additions & 7 deletions authentication_sasl.go
Expand Up @@ -46,10 +46,8 @@ func (dst *AuthenticationSASL) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASL) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *AuthenticationSASL) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeSASL)

for _, s := range src.AuthMechanisms {
Expand All @@ -58,9 +56,7 @@ func (src *AuthenticationSASL) Encode(dst []byte) []byte {
}
dst = append(dst, 0)

pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))

return dst
return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.
Expand Down
12 changes: 3 additions & 9 deletions authentication_sasl_continue.go
Expand Up @@ -38,17 +38,11 @@ func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *AuthenticationSASLContinue) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)

dst = append(dst, src.Data...)

pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))

return dst
return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.
Expand Down
12 changes: 3 additions & 9 deletions authentication_sasl_final.go
Expand Up @@ -38,17 +38,11 @@ func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *AuthenticationSASLFinal) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)

dst = append(dst, src.Data...)

pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))

return dst
return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Unmarshaler.
Expand Down
15 changes: 10 additions & 5 deletions backend.go
Expand Up @@ -49,7 +49,12 @@ func NewBackend(cr ChunkReader, w io.Writer) *Backend {

// Send sends a message to the frontend.
func (b *Backend) Send(msg BackendMessage) error {
_, err := b.w.Write(msg.Encode(nil))
buf, err := msg.Encode(nil)
if err != nil {
return err
}

_, err = b.w.Write(buf)
return err
}

Expand Down Expand Up @@ -184,11 +189,11 @@ func (b *Backend) Receive() (FrontendMessage, error) {
// contextual identification of FrontendMessages. For example, in the
// PG message flow documentation for PasswordMessage:
//
// Byte1('p')
// Byte1('p')
//
// Identifies the message as a password response. Note that this is also used for
// GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from
// the context.
// Identifies the message as a password response. Note that this is also used for
// GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from
// the context.
//
// Since the Frontend does not know about the state of a backend, it is important
// to call SetAuthType() after an authentication request is received by the Frontend.
Expand Down
7 changes: 3 additions & 4 deletions backend_key_data.go
Expand Up @@ -29,12 +29,11 @@ func (dst *BackendKeyData) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *BackendKeyData) Encode(dst []byte) []byte {
dst = append(dst, 'K')
dst = pgio.AppendUint32(dst, 12)
func (src *BackendKeyData) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'K')
dst = pgio.AppendUint32(dst, src.ProcessID)
dst = pgio.AppendUint32(dst, src.SecretKey)
return dst
return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.
Expand Down
4 changes: 2 additions & 2 deletions backend_test.go
Expand Up @@ -71,8 +71,8 @@ func TestStartupMessage(t *testing.T) {
"username": "tester",
},
}
dst := []byte{}
dst = want.Encode(dst)
dst, err := want.Encode([]byte{})
require.NoError(t, err)

server := &interruptReader{}
server.push(dst)
Expand Down
21 changes: 14 additions & 7 deletions bind.go
Expand Up @@ -5,7 +5,9 @@ import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"math"

"github.com/jackc/pgio"
)
Expand Down Expand Up @@ -108,21 +110,25 @@ func (dst *Bind) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Bind) Encode(dst []byte) []byte {
dst = append(dst, 'B')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *Bind) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'B')

dst = append(dst, src.DestinationPortal...)
dst = append(dst, 0)
dst = append(dst, src.PreparedStatement...)
dst = append(dst, 0)

if len(src.ParameterFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many parameter format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
for _, fc := range src.ParameterFormatCodes {
dst = pgio.AppendInt16(dst, fc)
}

if len(src.Parameters) > math.MaxUint16 {
return nil, errors.New("too many parameters")
}
dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
for _, p := range src.Parameters {
if p == nil {
Expand All @@ -134,14 +140,15 @@ func (src *Bind) Encode(dst []byte) []byte {
dst = append(dst, p...)
}

if len(src.ResultFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many result format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
for _, fc := range src.ResultFormatCodes {
dst = pgio.AppendInt16(dst, fc)
}

pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))

return dst
return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.
Expand Down
4 changes: 2 additions & 2 deletions bind_complete.go
Expand Up @@ -20,8 +20,8 @@ func (dst *BindComplete) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *BindComplete) Encode(dst []byte) []byte {
return append(dst, '2', 0, 0, 0, 4)
func (src *BindComplete) Encode(dst []byte) ([]byte, error) {
return append(dst, '2', 0, 0, 0, 4), nil
}

// MarshalJSON implements encoding/json.Marshaler.
Expand Down
20 changes: 20 additions & 0 deletions bind_test.go
@@ -0,0 +1,20 @@
package pgproto3_test

import (
"testing"

"github.com/jackc/pgproto3/v2"
"github.com/stretchr/testify/require"
)

func TestBindBiggerThanMaxMessageBodyLen(t *testing.T) {
t.Parallel()

// Maximum allowed size.
_, err := (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-16)}}).Encode(nil)
require.NoError(t, err)

// 1 byte too big
_, err = (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-15)}}).Encode(nil)
require.Error(t, err)
}
4 changes: 2 additions & 2 deletions cancel_request.go
Expand Up @@ -36,12 +36,12 @@ func (dst *CancelRequest) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 4 byte message length.
func (src *CancelRequest) Encode(dst []byte) []byte {
func (src *CancelRequest) Encode(dst []byte) ([]byte, error) {
dst = pgio.AppendInt32(dst, 16)
dst = pgio.AppendInt32(dst, cancelRequestCode)
dst = pgio.AppendUint32(dst, src.ProcessID)
dst = pgio.AppendUint32(dst, src.SecretKey)
return dst
return dst, nil
}

// MarshalJSON implements encoding/json.Marshaler.
Expand Down
14 changes: 3 additions & 11 deletions close.go
Expand Up @@ -4,8 +4,6 @@ import (
"bytes"
"encoding/json"
"errors"

"github.com/jackc/pgio"
)

type Close struct {
Expand Down Expand Up @@ -37,18 +35,12 @@ func (dst *Close) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Close) Encode(dst []byte) []byte {
dst = append(dst, 'C')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)

func (src *Close) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'C')
dst = append(dst, src.ObjectType)
dst = append(dst, src.Name...)
dst = append(dst, 0)

pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))

return dst
return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.
Expand Down
4 changes: 2 additions & 2 deletions close_complete.go
Expand Up @@ -20,8 +20,8 @@ func (dst *CloseComplete) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CloseComplete) Encode(dst []byte) []byte {
return append(dst, '3', 0, 0, 0, 4)
func (src *CloseComplete) Encode(dst []byte) ([]byte, error) {
return append(dst, '3', 0, 0, 0, 4), nil
}

// MarshalJSON implements encoding/json.Marshaler.
Expand Down
14 changes: 3 additions & 11 deletions command_complete.go
Expand Up @@ -3,8 +3,6 @@ package pgproto3
import (
"bytes"
"encoding/json"

"github.com/jackc/pgio"
)

type CommandComplete struct {
Expand All @@ -28,17 +26,11 @@ func (dst *CommandComplete) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CommandComplete) Encode(dst []byte) []byte {
dst = append(dst, 'C')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)

func (src *CommandComplete) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'C')
dst = append(dst, src.CommandTag...)
dst = append(dst, 0)

pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))

return dst
return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.
Expand Down

0 comments on commit 945c212

Please sign in to comment.