From f04c58338b58927573d7e664c19b325220d848a4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 29 Apr 2017 10:02:38 -0500 Subject: [PATCH] Introduce pgproto3 package pgproto3 will wrap the message encoding and decoding for the PostgreSQL frontend/backend protocol version 3. --- .gitignore | 1 + .travis.yml | 2 +- conn.go | 312 ++++++++++++----------------- conn_pool_test.go | 2 +- copy_from.go | 30 +-- fastpath.go | 22 +- messages.go | 4 +- pgproto3/authentication.go | 54 +++++ pgproto3/backend_key_data.go | 47 +++++ pgproto3/big_endian.go | 37 ++++ pgproto3/bind_complete.go | 29 +++ pgproto3/close_complete.go | 29 +++ pgproto3/command_complete.go | 47 +++++ pgproto3/copy_both_response.go | 64 ++++++ pgproto3/copy_data.go | 41 ++++ pgproto3/copy_in_response.go | 64 ++++++ pgproto3/copy_out_response.go | 64 ++++++ pgproto3/data_row.go | 103 ++++++++++ pgproto3/empty_query_response.go | 29 +++ pgproto3/error_response.go | 197 ++++++++++++++++++ pgproto3/frontend.go | 70 +++++++ pgproto3/function_call_response.go | 73 +++++++ pgproto3/no_data.go | 29 +++ pgproto3/notice_response.go | 13 ++ pgproto3/notification_response.go | 65 ++++++ pgproto3/parameter_description.go | 60 ++++++ pgproto3/parameter_status.go | 62 ++++++ pgproto3/parse_complete.go | 29 +++ pgproto3/pgproto3.go | 88 ++++++++ pgproto3/query.go | 43 ++++ pgproto3/ready_for_query.go | 35 ++++ pgproto3/row_description.go | 101 ++++++++++ query.go | 30 ++- replication.go | 70 ++++--- 34 files changed, 1680 insertions(+), 266 deletions(-) create mode 100644 pgproto3/authentication.go create mode 100644 pgproto3/backend_key_data.go create mode 100644 pgproto3/big_endian.go create mode 100644 pgproto3/bind_complete.go create mode 100644 pgproto3/close_complete.go create mode 100644 pgproto3/command_complete.go create mode 100644 pgproto3/copy_both_response.go create mode 100644 pgproto3/copy_data.go create mode 100644 pgproto3/copy_in_response.go create mode 100644 pgproto3/copy_out_response.go create mode 100644 pgproto3/data_row.go create mode 100644 pgproto3/empty_query_response.go create mode 100644 pgproto3/error_response.go create mode 100644 pgproto3/frontend.go create mode 100644 pgproto3/function_call_response.go create mode 100644 pgproto3/no_data.go create mode 100644 pgproto3/notice_response.go create mode 100644 pgproto3/notification_response.go create mode 100644 pgproto3/parameter_description.go create mode 100644 pgproto3/parameter_status.go create mode 100644 pgproto3/parse_complete.go create mode 100644 pgproto3/pgproto3.go create mode 100644 pgproto3/query.go create mode 100644 pgproto3/ready_for_query.go create mode 100644 pgproto3/row_description.go diff --git a/.gitignore b/.gitignore index cb0cd901b..0ff008008 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,4 @@ _testmain.go *.exe conn_config_test.go +.envrc diff --git a/.travis.yml b/.travis.yml index 0045cf5a8..edacab396 100644 --- a/.travis.yml +++ b/.travis.yml @@ -52,7 +52,7 @@ install: - go get -u github.com/shopspring/decimal - go get -u gopkg.in/inconshreveable/log15.v2 - go get -u github.com/jackc/fake - - go get -u github.com/jackc/pgmock/pgmsg + - go get -u github.com/jackc/pgmock/pgproto3 - go get -u github.com/lib/pq - go get -u github.com/hashicorp/go-version - go get -u github.com/satori/go.uuid diff --git a/conn.go b/conn.go index c2cb408f4..7487b8ada 100644 --- a/conn.go +++ b/conn.go @@ -20,7 +20,7 @@ import ( "sync/atomic" "time" - "github.com/jackc/pgx/chunkreader" + "github.com/jackc/pgx/pgproto3" "github.com/jackc/pgx/pgtype" ) @@ -88,8 +88,8 @@ type Conn struct { lastActivityTime time.Time // the last time the connection was used wbuf [1024]byte writeBuf WriteBuf - pid int32 // backend pid - secretKey int32 // key to use to send a cancel query message to the server + pid uint32 // backend pid + secretKey uint32 // key to use to send a cancel query message to the server RuntimeParams map[string]string // parameters that have been reported by the server config ConnConfig // config used when establishing this connection txStatus byte @@ -98,7 +98,6 @@ type Conn struct { notifications []*Notification logger Logger logLevel int - mr msgReader fp *fastpath poolResetCount int preallocatedRows []Rows @@ -116,6 +115,8 @@ type Conn struct { closedChan chan error ConnInfo *pgtype.ConnInfo + + frontend *pgproto3.Frontend } // PreparedStatement is a description of a prepared statement @@ -133,7 +134,7 @@ type PrepareExOptions struct { // Notification is a message received from the PostgreSQL LISTEN/NOTIFY system type Notification struct { - PID int32 // backend pid that sent the notification + PID uint32 // backend pid that sent the notification Channel string // channel from which notification was received Payload string } @@ -213,8 +214,6 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) c.logLevel = LogLevelDebug } c.logger = c.config.Logger - c.mr.log = c.log - c.mr.shouldLog = c.shouldLog if c.config.User == "" { user, err := user.Current() @@ -290,7 +289,10 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl } } - c.mr.cr = chunkreader.NewChunkReader(c.conn) + c.frontend, err = pgproto3.NewFrontend(c.conn, c.conn) + if err != nil { + return err + } msg := newStartupMessage() @@ -317,29 +319,27 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl } for { - var t byte - var r *msgReader - t, r, err = c.rxMsg() + msg, err := c.rxMsg() if err != nil { return err } - switch t { - case backendKeyData: - c.rxBackendKeyData(r) - case authenticationX: - if err = c.rxAuthenticationX(r); err != nil { + switch msg := msg.(type) { + case *pgproto3.BackendKeyData: + c.rxBackendKeyData(msg) + case *pgproto3.Authentication: + if err = c.rxAuthenticationX(msg); err != nil { return err } - case readyForQuery: - c.rxReadyForQuery(r) + case *pgproto3.ReadyForQuery: + c.rxReadyForQuery(msg) if c.shouldLog(LogLevelInfo) { c.log(LogLevelInfo, "Connection established") } // Replication connections can't execute the queries to // populate the c.PgTypes and c.pgsqlAfInet - if _, ok := msg.options["replication"]; ok { + if _, ok := config.RuntimeParams["replication"]; ok { return nil } @@ -352,7 +352,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl return nil default: - if err = c.processContextFreeMsg(t, r); err != nil { + if err = c.processContextFreeMsg(msg); err != nil { return err } } @@ -393,7 +393,7 @@ where ( } // PID returns the backend PID for this connection. -func (c *Conn) PID() int32 { +func (c *Conn) PID() uint32 { return c.pid } @@ -744,22 +744,20 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared var softErr error for { - var t byte - var r *msgReader - t, r, err := c.rxMsg() + msg, err := c.rxMsg() if err != nil { return nil, err } - switch t { - case parameterDescription: - ps.ParameterOids = c.rxParameterDescription(r) + switch msg := msg.(type) { + case *pgproto3.ParameterDescription: + ps.ParameterOids = c.rxParameterDescription(msg) if len(ps.ParameterOids) > 65535 && softErr == nil { softErr = fmt.Errorf("PostgreSQL supports maximum of 65535 parameters, received %d", len(ps.ParameterOids)) } - case rowDescription: - ps.FieldDescriptions = c.rxRowDescription(r) + case *pgproto3.RowDescription: + ps.FieldDescriptions = c.rxRowDescription(msg) for i := range ps.FieldDescriptions { if dt, ok := c.ConnInfo.DataTypeForOid(ps.FieldDescriptions[i].DataType); ok { ps.FieldDescriptions[i].DataTypeName = dt.Name @@ -772,8 +770,8 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared return nil, fmt.Errorf("unknown oid: %d", ps.FieldDescriptions[i].DataType) } } - case readyForQuery: - c.rxReadyForQuery(r) + case *pgproto3.ReadyForQuery: + c.rxReadyForQuery(msg) if softErr == nil { c.preparedStatements[name] = ps @@ -781,7 +779,7 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared return ps, softErr default: - if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil { + if e := c.processContextFreeMsg(msg); e != nil && softErr == nil { softErr = e } } @@ -830,18 +828,16 @@ func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) { } for { - var t byte - var r *msgReader - t, r, err := c.rxMsg() + msg, err := c.rxMsg() if err != nil { return err } - switch t { - case closeComplete: + switch msg.(type) { + case *pgproto3.CloseComplete: return nil default: - err = c.processContextFreeMsg(t, r) + err = c.processContextFreeMsg(msg) if err != nil { return err } @@ -908,12 +904,12 @@ func (c *Conn) WaitForNotification(ctx context.Context) (notification *Notificat } for { - t, r, err := c.rxMsg() + msg, err := c.rxMsg() if err != nil { return nil, err } - err = c.processContextFreeMsg(t, r) + err = c.processContextFreeMsg(msg) if err != nil { return nil, err } @@ -1030,62 +1026,48 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag // meaningful in a given context. These messages can occur due to a context // deadline interrupting message processing. For example, an interrupted query // may have left DataRow messages on the wire. -func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) { - switch t { - case bindComplete: - case commandComplete: - case dataRow: - case emptyQueryResponse: - case errorResponse: - return c.rxErrorResponse(r) - case noData: - case noticeResponse: - case notificationResponse: - c.rxNotificationResponse(r) - case parameterDescription: - case parseComplete: - case readyForQuery: - c.rxReadyForQuery(r) - case rowDescription: - case 'S': - c.rxParameterStatus(r) - - default: - return fmt.Errorf("Received unknown message type: %c", t) +func (c *Conn) processContextFreeMsg(msg pgproto3.BackendMessage) (err error) { + switch msg := msg.(type) { + case *pgproto3.ErrorResponse: + return c.rxErrorResponse(msg) + case *pgproto3.NotificationResponse: + c.rxNotificationResponse(msg) + case *pgproto3.ReadyForQuery: + c.rxReadyForQuery(msg) + case *pgproto3.ParameterStatus: + c.rxParameterStatus(msg) } return nil } -func (c *Conn) rxMsg() (t byte, r *msgReader, err error) { +func (c *Conn) rxMsg() (pgproto3.BackendMessage, error) { if atomic.LoadInt32(&c.status) < connStatusIdle { - return 0, nil, ErrDeadConn + return nil, ErrDeadConn } - t, err = c.mr.rxMsg() + msg, err := c.frontend.Receive() if err != nil { if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { c.die(err) } + return nil, err } c.lastActivityTime = time.Now() - if c.shouldLog(LogLevelTrace) { - c.log(LogLevelTrace, "rxMsg", "type", string(t), "msgBodyLen", len(c.mr.msgBody)) - } + // fmt.Printf("rxMsg: %#v\n", msg) - return t, &c.mr, err + return msg, nil } -func (c *Conn) rxAuthenticationX(r *msgReader) (err error) { - switch r.readInt32() { - case 0: // AuthenticationOk - case 3: // AuthenticationCleartextPassword +func (c *Conn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { + switch msg.Type { + case pgproto3.AuthTypeOk: + case pgproto3.AuthTypeCleartextPassword: err = c.txPasswordMessage(c.config.Password) - case 5: // AuthenticationMD5Password - salt := r.readString(4) - digestedPassword := "md5" + hexMD5(hexMD5(c.config.Password+c.config.User)+salt) + case pgproto3.AuthTypeMD5Password: + digestedPassword := "md5" + hexMD5(hexMD5(c.config.Password+c.config.User)+string(msg.Salt[:])) err = c.txPasswordMessage(digestedPassword) default: err = errors.New("Received unknown authentication message") @@ -1100,115 +1082,75 @@ func hexMD5(s string) string { return hex.EncodeToString(hash.Sum(nil)) } -func (c *Conn) rxParameterStatus(r *msgReader) { - key := r.readCString() - value := r.readCString() - c.RuntimeParams[key] = value -} - -func (c *Conn) rxErrorResponse(r *msgReader) (err PgError) { - for { - switch r.readByte() { - case 'S': - err.Severity = r.readCString() - case 'C': - err.Code = r.readCString() - case 'M': - err.Message = r.readCString() - case 'D': - err.Detail = r.readCString() - case 'H': - err.Hint = r.readCString() - case 'P': - s := r.readCString() - n, _ := strconv.ParseInt(s, 10, 32) - err.Position = int32(n) - case 'p': - s := r.readCString() - n, _ := strconv.ParseInt(s, 10, 32) - err.InternalPosition = int32(n) - case 'q': - err.InternalQuery = r.readCString() - case 'W': - err.Where = r.readCString() - case 's': - err.SchemaName = r.readCString() - case 't': - err.TableName = r.readCString() - case 'c': - err.ColumnName = r.readCString() - case 'd': - err.DataTypeName = r.readCString() - case 'n': - err.ConstraintName = r.readCString() - case 'F': - err.File = r.readCString() - case 'L': - s := r.readCString() - n, _ := strconv.ParseInt(s, 10, 32) - err.Line = int32(n) - case 'R': - err.Routine = r.readCString() - - case 0: // End of error message - if err.Severity == "FATAL" { - c.die(err) - } - return - default: // Ignore other error fields - r.readCString() - } +func (c *Conn) rxParameterStatus(msg *pgproto3.ParameterStatus) { + c.RuntimeParams[msg.Name] = msg.Value +} + +func (c *Conn) rxErrorResponse(msg *pgproto3.ErrorResponse) PgError { + err := PgError{ + Severity: msg.Severity, + Code: msg.Code, + Message: msg.Message, + Detail: msg.Detail, + Hint: msg.Hint, + Position: msg.Position, + InternalPosition: msg.InternalPosition, + InternalQuery: msg.InternalQuery, + Where: msg.Where, + SchemaName: msg.SchemaName, + TableName: msg.TableName, + ColumnName: msg.ColumnName, + DataTypeName: msg.DataTypeName, + ConstraintName: msg.ConstraintName, + File: msg.File, + Line: msg.Line, + Routine: msg.Routine, + } + + if err.Severity == "FATAL" { + c.die(err) } + + return err } -func (c *Conn) rxBackendKeyData(r *msgReader) { - c.pid = r.readInt32() - c.secretKey = r.readInt32() +func (c *Conn) rxBackendKeyData(msg *pgproto3.BackendKeyData) { + c.pid = msg.ProcessID + c.secretKey = msg.SecretKey } -func (c *Conn) rxReadyForQuery(r *msgReader) { +func (c *Conn) rxReadyForQuery(msg *pgproto3.ReadyForQuery) { c.readyForQuery = true - c.txStatus = r.readByte() -} - -func (c *Conn) rxRowDescription(r *msgReader) (fields []FieldDescription) { - fieldCount := r.readInt16() - fields = make([]FieldDescription, fieldCount) - for i := int16(0); i < fieldCount; i++ { - f := &fields[i] - f.Name = r.readCString() - f.Table = pgtype.Oid(r.readUint32()) - f.AttributeNumber = r.readInt16() - f.DataType = pgtype.Oid(r.readUint32()) - f.DataTypeSize = r.readInt16() - f.Modifier = r.readInt32() - f.FormatCode = r.readInt16() - } - return + c.txStatus = msg.TxStatus } -func (c *Conn) rxParameterDescription(r *msgReader) (parameters []pgtype.Oid) { - // Internally, PostgreSQL supports greater than 64k parameters to a prepared - // statement. But the parameter description uses a 16-bit integer for the - // count of parameters. If there are more than 64K parameters, this count is - // wrong. So read the count, ignore it, and compute the proper value from - // the size of the message. - r.readInt16() - parameterCount := len(r.msgBody[r.rp:]) / 4 - - parameters = make([]pgtype.Oid, 0, parameterCount) +func (c *Conn) rxRowDescription(msg *pgproto3.RowDescription) []FieldDescription { + fields := make([]FieldDescription, len(msg.Fields)) + for i := 0; i < len(fields); i++ { + fields[i].Name = msg.Fields[i].Name + fields[i].Table = pgtype.Oid(msg.Fields[i].TableOID) + fields[i].AttributeNumber = msg.Fields[i].TableAttributeNumber + fields[i].DataType = pgtype.Oid(msg.Fields[i].DataTypeOID) + fields[i].DataTypeSize = msg.Fields[i].DataTypeSize + fields[i].Modifier = msg.Fields[i].TypeModifier + fields[i].FormatCode = msg.Fields[i].Format + } + return fields +} - for i := 0; i < parameterCount; i++ { - parameters = append(parameters, pgtype.Oid(r.readUint32())) +func (c *Conn) rxParameterDescription(msg *pgproto3.ParameterDescription) []pgtype.Oid { + parameters := make([]pgtype.Oid, len(msg.ParameterOIDs)) + for i := 0; i < len(parameters); i++ { + parameters[i] = pgtype.Oid(msg.ParameterOIDs[i]) } - return + return parameters } -func (c *Conn) rxNotificationResponse(r *msgReader) { +func (c *Conn) rxNotificationResponse(msg *pgproto3.NotificationResponse) { n := new(Notification) - n.PID = r.readInt32() - n.Channel = r.readCString() - n.Payload = r.readCString() + n.PID = msg.PID + n.Channel = msg.Channel + n.Payload = msg.Payload c.notifications = append(c.notifications, n) } @@ -1453,21 +1395,19 @@ func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, var softErr error for { - var t byte - var r *msgReader - t, r, err = c.rxMsg() + msg, err := c.rxMsg() if err != nil { return commandTag, err } - switch t { - case readyForQuery: - c.rxReadyForQuery(r) + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + c.rxReadyForQuery(msg) return commandTag, softErr - case commandComplete: - commandTag = CommandTag(r.readCString()) + case *pgproto3.CommandComplete: + commandTag = CommandTag(msg.CommandTag) default: - if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil { + if e := c.processContextFreeMsg(msg); e != nil && softErr == nil { softErr = e } } @@ -1545,19 +1485,19 @@ func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error { func (c *Conn) ensureConnectionReadyForQuery() error { for !c.readyForQuery { - t, r, err := c.rxMsg() + msg, err := c.rxMsg() if err != nil { return err } - switch t { - case errorResponse: - pgErr := c.rxErrorResponse(r) + switch msg := msg.(type) { + case *pgproto3.ErrorResponse: + pgErr := c.rxErrorResponse(msg) if pgErr.Severity == "FATAL" { return pgErr } default: - err = c.processContextFreeMsg(t, r) + err = c.processContextFreeMsg(msg) if err != nil { return err } diff --git a/conn_pool_test.go b/conn_pool_test.go index 825638b62..42f37eb1c 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -686,7 +686,7 @@ func TestConnPoolBeginRetry(t *testing.T) { } defer tx.Rollback() - var txPID int32 + var txPID uint32 err = tx.QueryRow("select pg_backend_pid()").Scan(&txPID) if err != nil { t.Fatalf("tx.QueryRow Scan failed: %v", err) diff --git a/copy_from.go b/copy_from.go index 9fc76a7b4..7d8dead1d 100644 --- a/copy_from.go +++ b/copy_from.go @@ -3,6 +3,8 @@ package pgx import ( "bytes" "fmt" + + "github.com/jackc/pgx/pgproto3" ) // CopyFromRows returns a CopyFromSource interface over the provided rows slice @@ -54,25 +56,25 @@ type copyFrom struct { func (ct *copyFrom) readUntilReadyForQuery() { for { - t, r, err := ct.conn.rxMsg() + msg, err := ct.conn.rxMsg() if err != nil { ct.readerErrChan <- err close(ct.readerErrChan) return } - switch t { - case readyForQuery: - ct.conn.rxReadyForQuery(r) + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + ct.conn.rxReadyForQuery(msg) close(ct.readerErrChan) return - case commandComplete: - case errorResponse: - ct.readerErrChan <- ct.conn.rxErrorResponse(r) + case *pgproto3.CommandComplete: + case *pgproto3.ErrorResponse: + ct.readerErrChan <- ct.conn.rxErrorResponse(msg) default: - err = ct.conn.processContextFreeMsg(t, r) + err = ct.conn.processContextFreeMsg(msg) if err != nil { - ct.readerErrChan <- ct.conn.processContextFreeMsg(t, r) + ct.readerErrChan <- ct.conn.processContextFreeMsg(msg) } } } @@ -190,18 +192,16 @@ func (ct *copyFrom) run() (int, error) { func (c *Conn) readUntilCopyInResponse() error { for { - var t byte - var r *msgReader - t, r, err := c.rxMsg() + msg, err := c.rxMsg() if err != nil { return err } - switch t { - case copyInResponse: + switch msg := msg.(type) { + case *pgproto3.CopyInResponse: return nil default: - err = c.processContextFreeMsg(t, r) + err = c.processContextFreeMsg(msg) if err != nil { return err } diff --git a/fastpath.go b/fastpath.go index 0caba9d34..75681c9cf 100644 --- a/fastpath.go +++ b/fastpath.go @@ -3,6 +3,7 @@ package pgx import ( "encoding/binary" + "github.com/jackc/pgx/pgproto3" "github.com/jackc/pgx/pgtype" ) @@ -71,23 +72,20 @@ func (f *fastpath) Call(oid pgtype.Oid, args []fpArg) (res []byte, err error) { } for { - var t byte - var r *msgReader - t, r, err = f.cn.rxMsg() + msg, err := f.cn.rxMsg() if err != nil { return nil, err } - switch t { - case 'V': // FunctionCallResponse - data := r.readBytes(r.readInt32()) - res = make([]byte, len(data)) - copy(res, data) - case 'Z': // Ready for query - f.cn.rxReadyForQuery(r) + switch msg := msg.(type) { + case *pgproto3.FunctionCallResponse: + res = make([]byte, len(msg.Result)) + copy(res, msg.Result) + case *pgproto3.ReadyForQuery: + f.cn.rxReadyForQuery(msg) // done - return + return res, err default: - if err := f.cn.processContextFreeMsg(t, r); err != nil { + if err := f.cn.processContextFreeMsg(msg); err != nil { return nil, err } } diff --git a/messages.go b/messages.go index 68faf14ca..e229367a7 100644 --- a/messages.go +++ b/messages.go @@ -58,11 +58,11 @@ func (s *startupMessage) Bytes() (buf []byte) { type FieldDescription struct { Name string Table pgtype.Oid - AttributeNumber int16 + AttributeNumber uint16 DataType pgtype.Oid DataTypeSize int16 DataTypeName string - Modifier int32 + Modifier uint32 FormatCode int16 } diff --git a/pgproto3/authentication.go b/pgproto3/authentication.go new file mode 100644 index 000000000..e265a2471 --- /dev/null +++ b/pgproto3/authentication.go @@ -0,0 +1,54 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "fmt" +) + +const ( + AuthTypeOk = 0 + AuthTypeCleartextPassword = 3 + AuthTypeMD5Password = 5 +) + +type Authentication struct { + Type uint32 + + // MD5Password fields + Salt [4]byte +} + +func (*Authentication) Backend() {} + +func (dst *Authentication) UnmarshalBinary(src []byte) error { + *dst = Authentication{Type: binary.BigEndian.Uint32(src[:4])} + + switch dst.Type { + case AuthTypeOk: + case AuthTypeCleartextPassword: + case AuthTypeMD5Password: + copy(dst.Salt[:], src[4:8]) + default: + return fmt.Errorf("unknown authentication type: %d", dst.Type) + } + + return nil +} + +func (src *Authentication) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + buf.WriteByte('R') + buf.Write(bigEndian.Uint32(0)) + buf.Write(bigEndian.Uint32(src.Type)) + + switch src.Type { + case AuthTypeMD5Password: + buf.Write(src.Salt[:]) + } + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} diff --git a/pgproto3/backend_key_data.go b/pgproto3/backend_key_data.go new file mode 100644 index 000000000..5d8eb4969 --- /dev/null +++ b/pgproto3/backend_key_data.go @@ -0,0 +1,47 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type BackendKeyData struct { + ProcessID uint32 + SecretKey uint32 +} + +func (*BackendKeyData) Backend() {} + +func (dst *BackendKeyData) UnmarshalBinary(src []byte) error { + if len(src) != 8 { + return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)} + } + + dst.ProcessID = binary.BigEndian.Uint32(src[:4]) + dst.SecretKey = binary.BigEndian.Uint32(src[4:]) + + return nil +} + +func (src *BackendKeyData) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + buf.WriteByte('K') + buf.Write(bigEndian.Uint32(12)) + buf.Write(bigEndian.Uint32(src.ProcessID)) + buf.Write(bigEndian.Uint32(src.SecretKey)) + return buf.Bytes(), nil +} + +func (src *BackendKeyData) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProcessID uint32 + SecretKey uint32 + }{ + Type: "BackendKeyData", + ProcessID: src.ProcessID, + SecretKey: src.SecretKey, + }) +} diff --git a/pgproto3/big_endian.go b/pgproto3/big_endian.go new file mode 100644 index 000000000..f7bdb97eb --- /dev/null +++ b/pgproto3/big_endian.go @@ -0,0 +1,37 @@ +package pgproto3 + +import ( + "encoding/binary" +) + +type BigEndianBuf [8]byte + +func (b BigEndianBuf) Int16(n int16) []byte { + buf := b[0:2] + binary.BigEndian.PutUint16(buf, uint16(n)) + return buf +} + +func (b BigEndianBuf) Uint16(n uint16) []byte { + buf := b[0:2] + binary.BigEndian.PutUint16(buf, n) + return buf +} + +func (b BigEndianBuf) Int32(n int32) []byte { + buf := b[0:4] + binary.BigEndian.PutUint32(buf, uint32(n)) + return buf +} + +func (b BigEndianBuf) Uint32(n uint32) []byte { + buf := b[0:4] + binary.BigEndian.PutUint32(buf, n) + return buf +} + +func (b BigEndianBuf) Int64(n int64) []byte { + buf := b[0:8] + binary.BigEndian.PutUint64(buf, uint64(n)) + return buf +} diff --git a/pgproto3/bind_complete.go b/pgproto3/bind_complete.go new file mode 100644 index 000000000..756a30e61 --- /dev/null +++ b/pgproto3/bind_complete.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type BindComplete struct{} + +func (*BindComplete) Backend() {} + +func (dst *BindComplete) UnmarshalBinary(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *BindComplete) MarshalBinary() ([]byte, error) { + return []byte{'2', 0, 0, 0, 4}, nil +} + +func (src *BindComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "BindComplete", + }) +} diff --git a/pgproto3/close_complete.go b/pgproto3/close_complete.go new file mode 100644 index 000000000..fd6ff1809 --- /dev/null +++ b/pgproto3/close_complete.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type CloseComplete struct{} + +func (*CloseComplete) Backend() {} + +func (dst *CloseComplete) UnmarshalBinary(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *CloseComplete) MarshalBinary() ([]byte, error) { + return []byte{'3', 0, 0, 0, 4}, nil +} + +func (src *CloseComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "CloseComplete", + }) +} diff --git a/pgproto3/command_complete.go b/pgproto3/command_complete.go new file mode 100644 index 000000000..ac60153ed --- /dev/null +++ b/pgproto3/command_complete.go @@ -0,0 +1,47 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" +) + +type CommandComplete struct { + CommandTag string +} + +func (*CommandComplete) Backend() {} + +func (dst *CommandComplete) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + dst.CommandTag = string(b[:len(b)-1]) + + return nil +} + +func (src *CommandComplete) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('C') + buf.Write(bigEndian.Uint32(uint32(4 + len(src.CommandTag) + 1))) + + buf.WriteString(src.CommandTag) + buf.WriteByte(0) + + return buf.Bytes(), nil +} + +func (src *CommandComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + CommandTag string + }{ + Type: "CommandComplete", + CommandTag: src.CommandTag, + }) +} diff --git a/pgproto3/copy_both_response.go b/pgproto3/copy_both_response.go new file mode 100644 index 000000000..2a4c58af6 --- /dev/null +++ b/pgproto3/copy_both_response.go @@ -0,0 +1,64 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type CopyBothResponse struct { + OverallFormat byte + ColumnFormatCodes []uint16 +} + +func (*CopyBothResponse) Backend() {} + +func (dst *CopyBothResponse) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 3 { + return &invalidMessageFormatErr{messageType: "CopyBothResponse"} + } + + overallFormat := buf.Next(1)[0] + + columnCount := int(binary.BigEndian.Uint16(buf.Next(2))) + if buf.Len() != columnCount*2 { + return &invalidMessageFormatErr{messageType: "CopyBothResponse"} + } + + columnFormatCodes := make([]uint16, columnCount) + for i := 0; i < columnCount; i++ { + columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) + } + + *dst = CopyBothResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes} + + return nil +} + +func (src *CopyBothResponse) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('W') + buf.Write(bigEndian.Uint32(uint32(4 + 1 + 2 + 2*len(src.ColumnFormatCodes)))) + + buf.Write(bigEndian.Uint16(uint16(len(src.ColumnFormatCodes)))) + + for _, fc := range src.ColumnFormatCodes { + buf.Write(bigEndian.Uint16(fc)) + } + + return buf.Bytes(), nil +} + +func (src *CopyBothResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ColumnFormatCodes []uint16 + }{ + Type: "CopyBothResponse", + ColumnFormatCodes: src.ColumnFormatCodes, + }) +} diff --git a/pgproto3/copy_data.go b/pgproto3/copy_data.go new file mode 100644 index 000000000..b9ea62729 --- /dev/null +++ b/pgproto3/copy_data.go @@ -0,0 +1,41 @@ +package pgproto3 + +import ( + "bytes" + "encoding/hex" + "encoding/json" +) + +type CopyData struct { + Data []byte +} + +func (*CopyData) Backend() {} +func (*CopyData) Frontend() {} + +func (dst *CopyData) UnmarshalBinary(src []byte) error { + dst.Data = make([]byte, len(src)) + copy(dst.Data, src) + return nil +} + +func (src *CopyData) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('d') + buf.Write(bigEndian.Uint32(uint32(4 + len(src.Data)))) + buf.Write(src.Data) + + return buf.Bytes(), nil +} + +func (src *CopyData) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data string + }{ + Type: "CopyData", + Data: hex.EncodeToString(src.Data), + }) +} diff --git a/pgproto3/copy_in_response.go b/pgproto3/copy_in_response.go new file mode 100644 index 000000000..63868c7ab --- /dev/null +++ b/pgproto3/copy_in_response.go @@ -0,0 +1,64 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type CopyInResponse struct { + OverallFormat byte + ColumnFormatCodes []uint16 +} + +func (*CopyInResponse) Backend() {} + +func (dst *CopyInResponse) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 3 { + return &invalidMessageFormatErr{messageType: "CopyInResponse"} + } + + overallFormat := buf.Next(1)[0] + + columnCount := int(binary.BigEndian.Uint16(buf.Next(2))) + if buf.Len() != columnCount*2 { + return &invalidMessageFormatErr{messageType: "CopyInResponse"} + } + + columnFormatCodes := make([]uint16, columnCount) + for i := 0; i < columnCount; i++ { + columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) + } + + *dst = CopyInResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes} + + return nil +} + +func (src *CopyInResponse) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('G') + buf.Write(bigEndian.Uint32(uint32(4 + 1 + 2 + 2*len(src.ColumnFormatCodes)))) + + buf.Write(bigEndian.Uint16(uint16(len(src.ColumnFormatCodes)))) + + for _, fc := range src.ColumnFormatCodes { + buf.Write(bigEndian.Uint16(fc)) + } + + return buf.Bytes(), nil +} + +func (src *CopyInResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ColumnFormatCodes []uint16 + }{ + Type: "CopyInResponse", + ColumnFormatCodes: src.ColumnFormatCodes, + }) +} diff --git a/pgproto3/copy_out_response.go b/pgproto3/copy_out_response.go new file mode 100644 index 000000000..e46d9e8f0 --- /dev/null +++ b/pgproto3/copy_out_response.go @@ -0,0 +1,64 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type CopyOutResponse struct { + OverallFormat byte + ColumnFormatCodes []uint16 +} + +func (*CopyOutResponse) Backend() {} + +func (dst *CopyOutResponse) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 3 { + return &invalidMessageFormatErr{messageType: "CopyOutResponse"} + } + + overallFormat := buf.Next(1)[0] + + columnCount := int(binary.BigEndian.Uint16(buf.Next(2))) + if buf.Len() != columnCount*2 { + return &invalidMessageFormatErr{messageType: "CopyOutResponse"} + } + + columnFormatCodes := make([]uint16, columnCount) + for i := 0; i < columnCount; i++ { + columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) + } + + *dst = CopyOutResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes} + + return nil +} + +func (src *CopyOutResponse) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('H') + buf.Write(bigEndian.Uint32(uint32(4 + 1 + 2 + 2*len(src.ColumnFormatCodes)))) + + buf.Write(bigEndian.Uint16(uint16(len(src.ColumnFormatCodes)))) + + for _, fc := range src.ColumnFormatCodes { + buf.Write(bigEndian.Uint16(fc)) + } + + return buf.Bytes(), nil +} + +func (src *CopyOutResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ColumnFormatCodes []uint16 + }{ + Type: "CopyOutResponse", + ColumnFormatCodes: src.ColumnFormatCodes, + }) +} diff --git a/pgproto3/data_row.go b/pgproto3/data_row.go new file mode 100644 index 000000000..c95861b92 --- /dev/null +++ b/pgproto3/data_row.go @@ -0,0 +1,103 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "encoding/json" +) + +type DataRow struct { + Values [][]byte +} + +func (*DataRow) Backend() {} + +func (dst *DataRow) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 2 { + return &invalidMessageFormatErr{messageType: "DataRow"} + } + fieldCount := int(binary.BigEndian.Uint16(buf.Next(2))) + + dst.Values = make([][]byte, fieldCount) + + for i := 0; i < fieldCount; i++ { + if buf.Len() < 4 { + return &invalidMessageFormatErr{messageType: "DataRow"} + } + + msgSize := int(int32(binary.BigEndian.Uint32(buf.Next(4)))) + + // null + if msgSize == -1 { + continue + } + + value := make([]byte, msgSize) + _, err := buf.Read(value) + if err != nil { + return err + } + + dst.Values[i] = value + } + + return nil +} + +func (src *DataRow) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('D') + buf.Write(bigEndian.Uint32(0)) + + buf.Write(bigEndian.Uint16(uint16(len(src.Values)))) + + for _, v := range src.Values { + if v == nil { + buf.Write(bigEndian.Int32(-1)) + continue + } + + buf.Write(bigEndian.Int32(int32(len(v)))) + buf.Write(v) + } + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} + +func (src *DataRow) MarshalJSON() ([]byte, error) { + formattedValues := make([]map[string]string, len(src.Values)) + for i, v := range src.Values { + if v == nil { + continue + } + + var hasNonPrintable bool + for _, b := range v { + if b < 32 { + hasNonPrintable = true + break + } + } + + if hasNonPrintable { + formattedValues[i] = map[string]string{"binary": hex.EncodeToString(v)} + } else { + formattedValues[i] = map[string]string{"text": string(v)} + } + } + + return json.Marshal(struct { + Type string + Values []map[string]string + }{ + Type: "DataRow", + Values: formattedValues, + }) +} diff --git a/pgproto3/empty_query_response.go b/pgproto3/empty_query_response.go new file mode 100644 index 000000000..de6e6272b --- /dev/null +++ b/pgproto3/empty_query_response.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type EmptyQueryResponse struct{} + +func (*EmptyQueryResponse) Backend() {} + +func (dst *EmptyQueryResponse) UnmarshalBinary(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *EmptyQueryResponse) MarshalBinary() ([]byte, error) { + return []byte{'I', 0, 0, 0, 4}, nil +} + +func (src *EmptyQueryResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "EmptyQueryResponse", + }) +} diff --git a/pgproto3/error_response.go b/pgproto3/error_response.go new file mode 100644 index 000000000..82e408d76 --- /dev/null +++ b/pgproto3/error_response.go @@ -0,0 +1,197 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "strconv" +) + +type ErrorResponse struct { + Severity string + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string + + UnknownFields map[byte]string +} + +func (*ErrorResponse) Backend() {} + +func (dst *ErrorResponse) UnmarshalBinary(src []byte) error { + *dst = ErrorResponse{} + + buf := bytes.NewBuffer(src) + + for { + k, err := buf.ReadByte() + if err != nil { + return err + } + if k == 0 { + break + } + + vb, err := buf.ReadBytes(0) + if err != nil { + return err + } + v := string(vb[:len(vb)-1]) + + switch k { + case 'S': + dst.Severity = v + case 'C': + dst.Code = v + case 'M': + dst.Message = v + case 'D': + dst.Detail = v + case 'H': + dst.Hint = v + case 'P': + s := v + n, _ := strconv.ParseInt(s, 10, 32) + dst.Position = int32(n) + case 'p': + s := v + n, _ := strconv.ParseInt(s, 10, 32) + dst.InternalPosition = int32(n) + case 'q': + dst.InternalQuery = v + case 'W': + dst.Where = v + case 's': + dst.SchemaName = v + case 't': + dst.TableName = v + case 'c': + dst.ColumnName = v + case 'd': + dst.DataTypeName = v + case 'n': + dst.ConstraintName = v + case 'F': + dst.File = v + case 'L': + s := v + n, _ := strconv.ParseInt(s, 10, 32) + dst.Line = int32(n) + case 'R': + dst.Routine = v + + default: + if dst.UnknownFields == nil { + dst.UnknownFields = make(map[byte]string) + } + dst.UnknownFields[k] = v + } + } + + return nil +} + +func (src *ErrorResponse) MarshalBinary() ([]byte, error) { + return src.marshalBinary('E') +} + +func (src *ErrorResponse) marshalBinary(typeByte byte) ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte(typeByte) + buf.Write(bigEndian.Uint32(0)) + + if src.Severity != "" { + buf.WriteString(src.Severity) + buf.WriteByte(0) + } + if src.Code != "" { + buf.WriteString(src.Code) + buf.WriteByte(0) + } + if src.Message != "" { + buf.WriteString(src.Message) + buf.WriteByte(0) + } + if src.Detail != "" { + buf.WriteString(src.Detail) + buf.WriteByte(0) + } + if src.Hint != "" { + buf.WriteString(src.Hint) + buf.WriteByte(0) + } + if src.Position != 0 { + buf.WriteString(strconv.Itoa(int(src.Position))) + buf.WriteByte(0) + } + if src.InternalPosition != 0 { + buf.WriteString(strconv.Itoa(int(src.InternalPosition))) + buf.WriteByte(0) + } + if src.InternalQuery != "" { + buf.WriteString(src.InternalQuery) + buf.WriteByte(0) + } + if src.Where != "" { + buf.WriteString(src.Where) + buf.WriteByte(0) + } + if src.SchemaName != "" { + buf.WriteString(src.SchemaName) + buf.WriteByte(0) + } + if src.TableName != "" { + buf.WriteString(src.TableName) + buf.WriteByte(0) + } + if src.ColumnName != "" { + buf.WriteString(src.ColumnName) + buf.WriteByte(0) + } + if src.DataTypeName != "" { + buf.WriteString(src.DataTypeName) + buf.WriteByte(0) + } + if src.ConstraintName != "" { + buf.WriteString(src.ConstraintName) + buf.WriteByte(0) + } + if src.File != "" { + buf.WriteString(src.File) + buf.WriteByte(0) + } + if src.Line != 0 { + buf.WriteString(strconv.Itoa(int(src.Line))) + buf.WriteByte(0) + } + if src.Routine != "" { + buf.WriteString(src.Routine) + buf.WriteByte(0) + } + + for k, v := range src.UnknownFields { + buf.WriteByte(k) + buf.WriteByte(0) + buf.WriteString(v) + buf.WriteByte(0) + } + buf.WriteByte(0) + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go new file mode 100644 index 000000000..c1dec461e --- /dev/null +++ b/pgproto3/frontend.go @@ -0,0 +1,70 @@ +package pgproto3 + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/jackc/pgx/chunkreader" +) + +type Frontend struct { + cr *chunkreader.ChunkReader + w io.Writer +} + +func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) { + cr := chunkreader.NewChunkReader(r) + return &Frontend{cr: cr, w: w}, nil +} + +func (b *Frontend) Send(msg FrontendMessage) error { + return errors.New("not implemented") +} + +func (b *Frontend) Receive() (BackendMessage, error) { + backendMessages := map[byte]BackendMessage{ + '1': &ParseComplete{}, + '2': &BindComplete{}, + '3': &CloseComplete{}, + 'A': &NotificationResponse{}, + 'C': &CommandComplete{}, + 'd': &CopyData{}, + 'D': &DataRow{}, + 'E': &ErrorResponse{}, + 'G': &CopyInResponse{}, + 'H': &CopyOutResponse{}, + 'I': &EmptyQueryResponse{}, + 'K': &BackendKeyData{}, + 'n': &NoData{}, + 'N': &NoticeResponse{}, + 'R': &Authentication{}, + 'S': &ParameterStatus{}, + 't': &ParameterDescription{}, + 'T': &RowDescription{}, + 'V': &FunctionCallResponse{}, + 'W': &CopyBothResponse{}, + 'Z': &ReadyForQuery{}, + } + + header, err := b.cr.Next(5) + if err != nil { + return nil, err + } + + msgType := header[0] + bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4 + + msgBody, err := b.cr.Next(bodyLen) + if err != nil { + return nil, err + } + + if msg, ok := backendMessages[msgType]; ok { + err = msg.UnmarshalBinary(msgBody) + return msg, err + } + + return nil, fmt.Errorf("unknown message type: %c", msgType) +} diff --git a/pgproto3/function_call_response.go b/pgproto3/function_call_response.go new file mode 100644 index 000000000..5c692b36c --- /dev/null +++ b/pgproto3/function_call_response.go @@ -0,0 +1,73 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "encoding/json" +) + +type FunctionCallResponse struct { + Result []byte +} + +func (*FunctionCallResponse) Backend() {} + +func (dst *FunctionCallResponse) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 4 { + return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} + } + resultSize := int(binary.BigEndian.Uint32(buf.Next(4))) + if buf.Len() != resultSize { + return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} + } + + dst.Result = make([]byte, resultSize) + copy(dst.Result, buf.Bytes()) + + return nil +} + +func (src *FunctionCallResponse) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('V') + buf.Write(bigEndian.Uint32(uint32(4 + 4 + len(src.Result)))) + + if src.Result == nil { + buf.Write(bigEndian.Int32(-1)) + } else { + buf.Write(bigEndian.Int32(int32(len(src.Result)))) + buf.Write(src.Result) + } + + return buf.Bytes(), nil +} + +func (src *FunctionCallResponse) MarshalJSON() ([]byte, error) { + var formattedValue map[string]string + var hasNonPrintable bool + for _, b := range src.Result { + if b < 32 { + hasNonPrintable = true + break + } + } + + if hasNonPrintable { + formattedValue = map[string]string{"binary": hex.EncodeToString(src.Result)} + } else { + formattedValue = map[string]string{"text": string(src.Result)} + } + + return json.Marshal(struct { + Type string + Result map[string]string + }{ + Type: "FunctionCallResponse", + Result: formattedValue, + }) +} diff --git a/pgproto3/no_data.go b/pgproto3/no_data.go new file mode 100644 index 000000000..47ebf28e0 --- /dev/null +++ b/pgproto3/no_data.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type NoData struct{} + +func (*NoData) Backend() {} + +func (dst *NoData) UnmarshalBinary(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *NoData) MarshalBinary() ([]byte, error) { + return []byte{'n', 0, 0, 0, 4}, nil +} + +func (src *NoData) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "NoData", + }) +} diff --git a/pgproto3/notice_response.go b/pgproto3/notice_response.go new file mode 100644 index 000000000..767c9a67f --- /dev/null +++ b/pgproto3/notice_response.go @@ -0,0 +1,13 @@ +package pgproto3 + +type NoticeResponse ErrorResponse + +func (*NoticeResponse) Backend() {} + +func (dst *NoticeResponse) UnmarshalBinary(src []byte) error { + return (*ErrorResponse)(dst).UnmarshalBinary(src) +} + +func (src *NoticeResponse) MarshalBinary() ([]byte, error) { + return (*ErrorResponse)(src).marshalBinary('N') +} diff --git a/pgproto3/notification_response.go b/pgproto3/notification_response.go new file mode 100644 index 000000000..4ae8bab33 --- /dev/null +++ b/pgproto3/notification_response.go @@ -0,0 +1,65 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type NotificationResponse struct { + PID uint32 + Channel string + Payload string +} + +func (*NotificationResponse) Backend() {} + +func (dst *NotificationResponse) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + pid := binary.BigEndian.Uint32(buf.Next(4)) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + channel := string(b[:len(b)-1]) + + b, err = buf.ReadBytes(0) + if err != nil { + return err + } + payload := string(b[:len(b)-1]) + + *dst = NotificationResponse{PID: pid, Channel: channel, Payload: payload} + return nil +} + +func (src *NotificationResponse) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('A') + buf.Write(bigEndian.Uint32(uint32(4 + 4 + len(src.Channel) + len(src.Payload)))) + + buf.WriteString(src.Channel) + buf.WriteByte(0) + buf.WriteString(src.Payload) + buf.WriteByte(0) + + return buf.Bytes(), nil +} + +func (src *NotificationResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + PID uint32 + Channel string + Payload string + }{ + Type: "NotificationResponse", + PID: src.PID, + Channel: src.Channel, + Payload: src.Payload, + }) +} diff --git a/pgproto3/parameter_description.go b/pgproto3/parameter_description.go new file mode 100644 index 000000000..40d92c50f --- /dev/null +++ b/pgproto3/parameter_description.go @@ -0,0 +1,60 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type ParameterDescription struct { + ParameterOIDs []uint32 +} + +func (*ParameterDescription) Backend() {} + +func (dst *ParameterDescription) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 2 { + return &invalidMessageFormatErr{messageType: "ParameterDescription"} + } + + // Reported parameter count will be incorrect when number of args is greater than uint16 + buf.Next(2) + // Instead infer parameter count by remaining size of message + parameterCount := buf.Len() / 4 + + *dst = ParameterDescription{ParameterOIDs: make([]uint32, parameterCount)} + + for i := 0; i < parameterCount; i++ { + dst.ParameterOIDs[i] = binary.BigEndian.Uint32(buf.Next(4)) + } + + return nil +} + +func (src *ParameterDescription) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('t') + buf.Write(bigEndian.Uint32(uint32(4 + 2 + 4*len(src.ParameterOIDs)))) + + buf.Write(bigEndian.Uint16(uint16(len(src.ParameterOIDs)))) + + for _, oid := range src.ParameterOIDs { + buf.Write(bigEndian.Uint32(oid)) + } + + return buf.Bytes(), nil +} + +func (src *ParameterDescription) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ParameterOIDs []uint32 + }{ + Type: "ParameterDescription", + ParameterOIDs: src.ParameterOIDs, + }) +} diff --git a/pgproto3/parameter_status.go b/pgproto3/parameter_status.go new file mode 100644 index 000000000..b8ce7f8d1 --- /dev/null +++ b/pgproto3/parameter_status.go @@ -0,0 +1,62 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type ParameterStatus struct { + Name string + Value string +} + +func (*ParameterStatus) Backend() {} + +func (dst *ParameterStatus) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + name := string(b[:len(b)-1]) + + b, err = buf.ReadBytes(0) + if err != nil { + return err + } + value := string(b[:len(b)-1]) + + *dst = ParameterStatus{Name: name, Value: value} + return nil +} + +func (src *ParameterStatus) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('S') + buf.Write(bigEndian.Uint32(0)) + + buf.WriteString(src.Name) + buf.WriteByte(0) + buf.WriteString(src.Value) + buf.WriteByte(0) + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} + +func (ps *ParameterStatus) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Name string + Value string + }{ + Type: "ParameterStatus", + Name: ps.Name, + Value: ps.Value, + }) +} diff --git a/pgproto3/parse_complete.go b/pgproto3/parse_complete.go new file mode 100644 index 000000000..24951e3de --- /dev/null +++ b/pgproto3/parse_complete.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type ParseComplete struct{} + +func (*ParseComplete) Backend() {} + +func (dst *ParseComplete) UnmarshalBinary(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "ParseComplete", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *ParseComplete) MarshalBinary() ([]byte, error) { + return []byte{'1', 0, 0, 0, 4}, nil +} + +func (src *ParseComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "ParseComplete", + }) +} diff --git a/pgproto3/pgproto3.go b/pgproto3/pgproto3.go new file mode 100644 index 000000000..a9221239d --- /dev/null +++ b/pgproto3/pgproto3.go @@ -0,0 +1,88 @@ +package pgproto3 + +import "fmt" + +type Message interface { + UnmarshalBinary(data []byte) error + MarshalBinary() (data []byte, err error) +} + +type FrontendMessage interface { + Message + Frontend() // no-op method to distinguish frontend from backend methods +} + +type BackendMessage interface { + Message + Backend() // no-op method to distinguish frontend from backend methods +} + +// func ParseBackend(typeByte byte, body []byte) (BackendMessage, error) { +// switch typeByte { +// case '1': +// return ParseParseComplete(body) +// case '2': +// return ParseBindComplete(body) +// case 'C': +// return ParseCommandComplete(body) +// case 'D': +// return ParseDataRow(body) +// case 'E': +// return ParseErrorResponse(body) +// case 'K': +// return ParseBackendKeyData(body) +// case 'R': +// return ParseAuthentication(body) +// case 'S': +// return ParseParameterStatus(body) +// case 'T': +// return ParseRowDescription(body) +// case 't': +// return ParseParameterDescription(body) +// case 'Z': +// return ParseReadyForQuery(body) +// default: +// return ParseUnknownMessage(typeByte, body) +// } +// } + +// func ParseFrontend(typeByte byte, body []byte) (FrontendMessage, error) { +// switch typeByte { +// case 'B': +// return ParseBind(body) +// case 'D': +// return ParseDescribe(body) +// case 'E': +// return ParseExecute(body) +// case 'P': +// return ParseParse(body) +// case 'p': +// return ParsePasswordMessage(body) +// case 'Q': +// return ParseQuery(body) +// case 'S': +// return ParseSync(body) +// case 'X': +// return ParseTerminate(body) +// default: +// return ParseUnknownMessage(typeByte, body) +// } +// } + +type invalidMessageLenErr struct { + messageType string + expectedLen int + actualLen int +} + +func (e *invalidMessageLenErr) Error() string { + return fmt.Sprintf("%s body must have length of %d, but it is %d", e.messageType, e.expectedLen, e.actualLen) +} + +type invalidMessageFormatErr struct { + messageType string +} + +func (e *invalidMessageFormatErr) Error() string { + return fmt.Sprintf("%s body is invalid", e.messageType) +} diff --git a/pgproto3/query.go b/pgproto3/query.go new file mode 100644 index 000000000..a3fc32eb5 --- /dev/null +++ b/pgproto3/query.go @@ -0,0 +1,43 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" +) + +type Query struct { + String string +} + +func (*Query) Frontend() {} + +func (dst *Query) UnmarshalBinary(src []byte) error { + i := bytes.IndexByte(src, 0) + if i != len(src)-1 { + return &invalidMessageFormatErr{messageType: "Query"} + } + + dst.String = string(src[:i]) + + return nil +} + +func (src *Query) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + buf.WriteByte('Q') + buf.Write(bigEndian.Uint32(uint32(4 + len(src.String) + 1))) + buf.WriteString(src.String) + buf.WriteByte(0) + return buf.Bytes(), nil +} + +func (src *Query) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + String string + }{ + Type: "Query", + String: src.String, + }) +} diff --git a/pgproto3/ready_for_query.go b/pgproto3/ready_for_query.go new file mode 100644 index 000000000..09005d000 --- /dev/null +++ b/pgproto3/ready_for_query.go @@ -0,0 +1,35 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type ReadyForQuery struct { + TxStatus byte +} + +func (*ReadyForQuery) Backend() {} + +func (dst *ReadyForQuery) UnmarshalBinary(src []byte) error { + if len(src) != 1 { + return &invalidMessageLenErr{messageType: "ReadyForQuery", expectedLen: 1, actualLen: len(src)} + } + + dst.TxStatus = src[0] + + return nil +} + +func (src *ReadyForQuery) MarshalBinary() ([]byte, error) { + return []byte{'Z', 0, 0, 0, 5, src.TxStatus}, nil +} + +func (src *ReadyForQuery) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + TxStatus string + }{ + Type: "ReadyForQuery", + TxStatus: string(src.TxStatus), + }) +} diff --git a/pgproto3/row_description.go b/pgproto3/row_description.go new file mode 100644 index 000000000..294a6aa9a --- /dev/null +++ b/pgproto3/row_description.go @@ -0,0 +1,101 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +const ( + TextFormat = 0 + BinaryFormat = 1 +) + +type FieldDescription struct { + Name string + TableOID uint32 + TableAttributeNumber uint16 + DataTypeOID uint32 + DataTypeSize int16 + TypeModifier uint32 + Format int16 +} + +type RowDescription struct { + Fields []FieldDescription +} + +func (*RowDescription) Backend() {} + +func (dst *RowDescription) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 2 { + return &invalidMessageFormatErr{messageType: "RowDescription"} + } + fieldCount := int(binary.BigEndian.Uint16(buf.Next(2))) + + *dst = RowDescription{Fields: make([]FieldDescription, fieldCount)} + + for i := 0; i < fieldCount; i++ { + var fd FieldDescription + bName, err := buf.ReadBytes(0) + if err != nil { + return err + } + fd.Name = string(bName[:len(bName)-1]) + + // Since buf.Next() doesn't return an error if we hit the end of the buffer + // check Len ahead of time + if buf.Len() < 18 { + return &invalidMessageFormatErr{messageType: "RowDescription"} + } + + fd.TableOID = binary.BigEndian.Uint32(buf.Next(4)) + fd.TableAttributeNumber = binary.BigEndian.Uint16(buf.Next(2)) + fd.DataTypeOID = binary.BigEndian.Uint32(buf.Next(4)) + fd.DataTypeSize = int16(binary.BigEndian.Uint16(buf.Next(2))) + fd.TypeModifier = binary.BigEndian.Uint32(buf.Next(4)) + fd.Format = int16(binary.BigEndian.Uint16(buf.Next(2))) + + dst.Fields[i] = fd + } + + return nil +} + +func (src *RowDescription) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('T') + buf.Write(bigEndian.Uint32(0)) + + buf.Write(bigEndian.Uint16(uint16(len(src.Fields)))) + + for _, fd := range src.Fields { + buf.WriteString(fd.Name) + buf.WriteByte(0) + + buf.Write(bigEndian.Uint32(fd.TableOID)) + buf.Write(bigEndian.Uint16(fd.TableAttributeNumber)) + buf.Write(bigEndian.Uint32(fd.DataTypeOID)) + buf.Write(bigEndian.Uint16(uint16(fd.DataTypeSize))) + buf.Write(bigEndian.Uint32(fd.TypeModifier)) + buf.Write(bigEndian.Uint16(uint16(fd.Format))) + } + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} + +func (src *RowDescription) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Fields []FieldDescription + }{ + Type: "RowDescription", + Fields: src.Fields, + }) +} diff --git a/query.go b/query.go index f7d8ed19e..04a870436 100644 --- a/query.go +++ b/query.go @@ -8,6 +8,7 @@ import ( "time" "github.com/jackc/pgx/internal/sanitize" + "github.com/jackc/pgx/pgproto3" "github.com/jackc/pgx/pgtype" ) @@ -41,7 +42,7 @@ func (r *Row) Scan(dest ...interface{}) (err error) { // calling Next() until it returns false, or when a fatal error occurs. type Rows struct { conn *Conn - mr *msgReader + values [][]byte fields []FieldDescription rowCount int columnIdx int @@ -115,15 +116,15 @@ func (rows *Rows) Next() bool { rows.columnIdx = 0 for { - t, r, err := rows.conn.rxMsg() + msg, err := rows.conn.rxMsg() if err != nil { rows.Fatal(err) return false } - switch t { - case rowDescription: - rows.fields = rows.conn.rxRowDescription(r) + switch msg := msg.(type) { + case *pgproto3.RowDescription: + rows.fields = rows.conn.rxRowDescription(msg) for i := range rows.fields { if dt, ok := rows.conn.ConnInfo.DataTypeForOid(rows.fields[i].DataType); ok { rows.fields[i].DataTypeName = dt.Name @@ -133,21 +134,20 @@ func (rows *Rows) Next() bool { return false } } - case dataRow: - fieldCount := r.readInt16() - if int(fieldCount) != len(rows.fields) { - rows.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), fieldCount))) + case *pgproto3.DataRow: + if len(msg.Values) != len(rows.fields) { + rows.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), len(msg.Values)))) return false } - rows.mr = r + rows.values = msg.Values return true - case commandComplete: + case *pgproto3.CommandComplete: rows.Close() return false default: - err = rows.conn.processContextFreeMsg(t, r) + err = rows.conn.processContextFreeMsg(msg) if err != nil { rows.Fatal(err) return false @@ -170,13 +170,9 @@ func (rows *Rows) nextColumn() ([]byte, *FieldDescription, bool) { return nil, nil, false } + buf := rows.values[rows.columnIdx] fd := &rows.fields[rows.columnIdx] rows.columnIdx++ - size := rows.mr.readInt32() - var buf []byte - if size >= 0 { - buf = rows.mr.readBytes(size) - } return buf, fd, true } diff --git a/replication.go b/replication.go index a251172df..ea768961b 100644 --- a/replication.go +++ b/replication.go @@ -2,9 +2,12 @@ package pgx import ( "context" + "encoding/binary" "errors" "fmt" "time" + + "github.com/jackc/pgx/pgproto3" ) const ( @@ -203,59 +206,64 @@ func (rc *ReplicationConn) CauseOfDeath() error { } func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err error) { - var t byte - var reader *msgReader - t, reader, err = rc.c.rxMsg() + msg, err := rc.c.rxMsg() if err != nil { return } - switch t { - case noticeResponse: - pgError := rc.c.rxErrorResponse(reader) + switch msg := msg.(type) { + case *pgproto3.NoticeResponse: + pgError := rc.c.rxErrorResponse((*pgproto3.ErrorResponse)(msg)) if rc.c.shouldLog(LogLevelInfo) { rc.c.log(LogLevelInfo, pgError.Error()) } - case errorResponse: - err = rc.c.rxErrorResponse(reader) + case *pgproto3.ErrorResponse: + err = rc.c.rxErrorResponse(msg) if rc.c.shouldLog(LogLevelError) { rc.c.log(LogLevelError, err.Error()) } return - case copyBothResponse: + case *pgproto3.CopyBothResponse: // This is the tail end of the replication process start, // and can be safely ignored return - case copyData: - var msgType byte - msgType = reader.readByte() + case *pgproto3.CopyData: + msgType := msg.Data[0] + rp := 1 + switch msgType { case walData: - walStart := reader.readInt64() - serverWalEnd := reader.readInt64() - serverTime := reader.readInt64() - walData := reader.readBytes(int32(len(reader.msgBody) - reader.rp)) - walMessage := WalMessage{WalStart: uint64(walStart), - ServerWalEnd: uint64(serverWalEnd), - ServerTime: uint64(serverTime), + walStart := binary.BigEndian.Uint64(msg.Data[rp:]) + rp += 8 + serverWalEnd := binary.BigEndian.Uint64(msg.Data[rp:]) + rp += 8 + serverTime := binary.BigEndian.Uint64(msg.Data[rp:]) + rp += 8 + walData := msg.Data[rp:] + walMessage := WalMessage{WalStart: walStart, + ServerWalEnd: serverWalEnd, + ServerTime: serverTime, WalData: walData, } return &ReplicationMessage{WalMessage: &walMessage}, nil case senderKeepalive: - serverWalEnd := reader.readInt64() - serverTime := reader.readInt64() - replyNow := reader.readByte() - h := &ServerHeartbeat{ServerWalEnd: uint64(serverWalEnd), ServerTime: uint64(serverTime), ReplyRequested: replyNow} + serverWalEnd := binary.BigEndian.Uint64(msg.Data[rp:]) + rp += 8 + serverTime := binary.BigEndian.Uint64(msg.Data[rp:]) + rp += 8 + replyNow := msg.Data[rp] + rp += 1 + h := &ServerHeartbeat{ServerWalEnd: serverWalEnd, ServerTime: serverTime, ReplyRequested: replyNow} return &ReplicationMessage{ServerHeartbeat: h}, nil default: if rc.c.shouldLog(LogLevelError) { - rc.c.log(LogLevelError, "Unexpected data playload message type %v", t) + rc.c.log(LogLevelError, "Unexpected data playload message type %v", msgType) } } default: if rc.c.shouldLog(LogLevelError) { - rc.c.log(LogLevelError, "Unexpected replication message type %v", t) + rc.c.log(LogLevelError, "Unexpected replication message type %T", msg) } } return @@ -325,21 +333,19 @@ func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) { rows.Fatal(err) } - var t byte - var r *msgReader - t, r, err = rc.c.rxMsg() + msg, err := rc.c.rxMsg() if err != nil { return nil, err } - switch t { - case rowDescription: - rows.fields = rc.c.rxRowDescription(r) + switch msg := msg.(type) { + case *pgproto3.RowDescription: + rows.fields = rc.c.rxRowDescription(msg) // We don't have c.PgTypes here because we're a replication // connection. This means the field descriptions will have // only Oids. Not much we can do about this. default: - if e := rc.c.processContextFreeMsg(t, r); e != nil { + if e := rc.c.processContextFreeMsg(msg); e != nil { rows.Fatal(e) return rows, e }