From f4fcef6a61aa23e7fe28dae312cf086753dcbbaa Mon Sep 17 00:00:00 2001 From: Jeroen Rinzema Date: Tue, 19 Dec 2023 21:36:23 +0100 Subject: [PATCH] feat: include support for multiple statements --- README.md | 8 +++--- command.go | 56 +++++++++++++++++++++++++++++--------- command_test.go | 10 +++---- error_test.go | 6 ++--- examples/error/main.go | 2 +- examples/session/main.go | 8 +++--- examples/simple/main.go | 9 +++---- examples/tls/main.go | 8 +++--- options.go | 52 ++++++++++++++++++++++++++--------- wire_test.go | 58 +++++++++++++++++++--------------------- 10 files changed, 134 insertions(+), 83 deletions(-) diff --git a/README.md b/README.md index 85b36f6..c6083a4 100644 --- a/README.md +++ b/README.md @@ -24,13 +24,11 @@ func main() { wire.ListenAndServe("127.0.0.1:5432", handler) } -func handler(ctx context.Context, query string) (*wire.PreparedStatement, error) { - statement := wire.NewPreparedStatement(func(ctx context.Context, writer wire.DataWriter, parameters []wire.Parameter) error { +func handler(ctx context.Context, query string) (wire.PreparedStatements, error) { + return wire.Prepared(wire.NewStatement(func(ctx context.Context, writer wire.DataWriter, parameters []wire.Parameter) error { fmt.Println(query) return writer.Complete("OK") - }) - - return statement, nil + })), nil } ``` diff --git a/command.go b/command.go index f8854fb..625f81f 100644 --- a/command.go +++ b/command.go @@ -31,6 +31,20 @@ func NewErrUnkownStatement(name string) error { return psqlerr.WithSeverity(psqlerr.WithCode(err, codes.InvalidPreparedStatementDefinition), psqlerr.LevelFatal) } +// NewErrUndefinedStatement is returned whenever no statement has been defined +// within the incoming query. +func NewErrUndefinedStatement() error { + err := errors.New("no statement has been defined") + return psqlerr.WithSeverity(psqlerr.WithCode(err, codes.Syntax), psqlerr.LevelError) +} + +// NewErrMultipleCommandsStatements is returned whenever multiple statements have been +// given within a single query during the extended query protocol. +func NewErrMultipleCommandsStatements() error { + err := errors.New("cannot insert multiple commands into a prepared statement") + return psqlerr.WithSeverity(psqlerr.WithCode(err, codes.Syntax), psqlerr.LevelError) +} + // consumeCommands consumes incoming commands send over the Postgres wire connection. // Commands consumed from the connection are returned through a go channel. // Responses for the given message type are written back to the client. @@ -237,24 +251,26 @@ func (srv *Server) handleSimpleQuery(ctx context.Context, reader *buffer.Reader, return readyForQuery(writer, types.ServerIdle) } - statement, err := srv.parse(ctx, query) + statements, err := srv.parse(ctx, query) if err != nil { return ErrorCode(writer, err) } - if err != nil { - return ErrorCode(writer, err) + if len(statements) == 0 { + return ErrorCode(writer, NewErrUndefinedStatement()) } - // NOTE: we have to define the column definitions before executing a simple query - err = statement.columns.Define(ctx, writer, nil) - if err != nil { - return ErrorCode(writer, err) - } + // NOTE: it is possible to send multiple statements in one simple query. + for index := range statements { + err = statements[index].columns.Define(ctx, writer, nil) + if err != nil { + return ErrorCode(writer, err) + } - err = statement.fn(ctx, NewDataWriter(ctx, statement.columns, nil, writer), nil) - if err != nil { - return ErrorCode(writer, err) + err = statements[index].fn(ctx, NewDataWriter(ctx, statements[index].columns, nil, writer), nil) + if err != nil { + return ErrorCode(writer, err) + } } return readyForQuery(writer, types.ServerIdle) @@ -294,7 +310,7 @@ func (srv *Server) handleParse(ctx context.Context, reader *buffer.Reader, write // `reader.GetUint32()` } - statement, err := srv.parse(ctx, query) + statement, err := singleStatement(srv.parse(ctx, query)) if err != nil { return ErrorCode(writer, err) } @@ -543,3 +559,19 @@ func (srv *Server) handleConnTerminate(ctx context.Context) error { return srv.TerminateConn(ctx) } + +func singleStatement(stmts PreparedStatements, err error) (*PreparedStatement, error) { + if err != nil { + return nil, err + } + + if len(stmts) > 1 { + return nil, NewErrMultipleCommandsStatements() + } + + if len(stmts) == 0 { + return nil, NewErrUndefinedStatement() + } + + return stmts[0], nil +} diff --git a/command_test.go b/command_test.go index 2568041..2f0157f 100644 --- a/command_test.go +++ b/command_test.go @@ -66,8 +66,8 @@ func TestBindMessageParameters(t *testing.T) { }, } - handler := func(ctx context.Context, query string) (*PreparedStatement, error) { - statement := NewPreparedStatement(func(ctx context.Context, writer DataWriter, parameters []Parameter) error { + handler := func(ctx context.Context, query string) (PreparedStatements, error) { + handle := func(ctx context.Context, writer DataWriter, parameters []Parameter) error { t.Log("serving query") if len(parameters) != 2 { @@ -79,11 +79,9 @@ func TestBindMessageParameters(t *testing.T) { writer.Row([]any{first, second}) //nolint:errcheck return writer.Complete("SELECT 1") - }) + } - statement.WithParameters(ParseParameters(query)) - statement.WithColumns(columns) - return statement, nil + return Prepared(NewStatement(handle, WithColumns(columns), WithParameters(ParseParameters(query)))), nil } server, err := NewServer(handler, Logger(slogt.New(t))) diff --git a/error_test.go b/error_test.go index acbb669..d7a0d2d 100644 --- a/error_test.go +++ b/error_test.go @@ -15,12 +15,12 @@ import ( ) func TestErrorCode(t *testing.T) { - handler := func(ctx context.Context, query string) (*PreparedStatement, error) { - statement := NewPreparedStatement(func(ctx context.Context, writer DataWriter, parameters []Parameter) error { + handler := func(ctx context.Context, query string) (PreparedStatements, error) { + stmt := NewStatement(func(ctx context.Context, writer DataWriter, parameters []Parameter) error { return psqlerr.WithSeverity(psqlerr.WithCode(errors.New("unimplemented feature"), codes.FeatureNotSupported), psqlerr.LevelFatal) }) - return statement, nil + return Prepared(stmt), nil } server, err := NewServer(handler, Logger(slogt.New(t))) diff --git a/examples/error/main.go b/examples/error/main.go index 0b6e1fe..3d5588d 100644 --- a/examples/error/main.go +++ b/examples/error/main.go @@ -15,7 +15,7 @@ func main() { wire.ListenAndServe("127.0.0.1:5432", handler) } -func handler(ctx context.Context, query string) (*wire.PreparedStatement, error) { +func handler(ctx context.Context, query string) (wire.PreparedStatements, error) { log.Println("incoming SQL query:", query) err := errors.New("unimplemented feature") diff --git a/examples/session/main.go b/examples/session/main.go index 421fca2..6228dad 100644 --- a/examples/session/main.go +++ b/examples/session/main.go @@ -32,13 +32,13 @@ func session(ctx context.Context) (context.Context, error) { return context.WithValue(ctx, id, counter), nil } -func handler(ctx context.Context, query string) (*wire.PreparedStatement, error) { +func handler(ctx context.Context, query string) (wire.PreparedStatements, error) { log.Println("incoming SQL query:", query) - statement := wire.NewPreparedStatement(func(ctx context.Context, writer wire.DataWriter, parameters []wire.Parameter) error { + handle := func(ctx context.Context, writer wire.DataWriter, parameters []wire.Parameter) error { session := ctx.Value(id).(int) return writer.Complete(fmt.Sprintf("OK, session: %d", session)) - }) + } - return statement, nil + return wire.Prepared(wire.NewStatement(handle)), nil } diff --git a/examples/simple/main.go b/examples/simple/main.go index d677b33..a67c7ad 100644 --- a/examples/simple/main.go +++ b/examples/simple/main.go @@ -34,15 +34,14 @@ var table = wire.Columns{ }, } -func handler(ctx context.Context, query string) (*wire.PreparedStatement, error) { +func handler(ctx context.Context, query string) (wire.PreparedStatements, error) { log.Println("incoming SQL query:", query) - statement := wire.NewPreparedStatement(func(ctx context.Context, writer wire.DataWriter, parameters []wire.Parameter) error { + handle := func(ctx context.Context, writer wire.DataWriter, parameters []wire.Parameter) error { writer.Row([]any{"John", true, 29}) writer.Row([]any{"Marry", false, 21}) return writer.Complete("SELECT 2") - }) + } - statement.WithColumns(table) - return statement, nil + return wire.Prepared(wire.NewStatement(handle, wire.WithColumns(table))), nil } diff --git a/examples/tls/main.go b/examples/tls/main.go index 9a789a6..627de54 100644 --- a/examples/tls/main.go +++ b/examples/tls/main.go @@ -33,12 +33,12 @@ func run() error { return server.ListenAndServe("127.0.0.1:5432") } -func handler(ctx context.Context, query string) (*wire.PreparedStatement, error) { +func handler(ctx context.Context, query string) (wire.PreparedStatements, error) { slog.Info("incoming SQL query", slog.String("query", query)) - statement := wire.NewPreparedStatement(func(ctx context.Context, writer wire.DataWriter, parameters []wire.Parameter) error { + handle := func(ctx context.Context, writer wire.DataWriter, parameters []wire.Parameter) error { return writer.Complete("OK") - }) + } - return statement, nil + return wire.Prepared(wire.NewStatement(handle)), nil } diff --git a/options.go b/options.go index d093286..29b6b9b 100644 --- a/options.go +++ b/options.go @@ -16,32 +16,60 @@ import ( // ParseFn parses the given query and returns a prepared statement which could // be used to execute at a later point in time. -type ParseFn func(ctx context.Context, query string) (*PreparedStatement, error) +type ParseFn func(ctx context.Context, query string) (PreparedStatements, error) // PreparedStatementFn represents a query of which a statement has been // prepared. The statement could be executed at any point in time with the given // arguments and data writer. type PreparedStatementFn func(ctx context.Context, writer DataWriter, parameters []Parameter) error -// NewPreparedStatement constructs a new prepared statement for the given function. -func NewPreparedStatement(fn PreparedStatementFn) *PreparedStatement { - return &PreparedStatement{ +// Prepared is a small wrapper function returning a list of prepared statements. +// More then one prepared statement could be returned within the simple query +// protocol. An error is returned when more then one prepared statement is +// returned in the extended query protocol. +// https://www.postgresql.org/docs/15/protocol-flow.html#PROTOCOL-FLOW-MULTI-STATEMENT +func Prepared(stmts ...*PreparedStatement) PreparedStatements { + return stmts +} + +// NewStatement constructs a new prepared statement for the given function. +func NewStatement(fn PreparedStatementFn, options ...PreparedOptionFn) *PreparedStatement { + stmt := &PreparedStatement{ fn: fn, } + + for _, option := range options { + option(stmt) + } + + return stmt } -type PreparedStatement struct { - fn PreparedStatementFn - parameters []oid.Oid - columns Columns +// PreparedOptionFn options pattern used to define options while preparing a new statement. +type PreparedOptionFn func(*PreparedStatement) + +// WithColumns sets the given columns as the columns which are returned by the +// prepared statement. +func WithColumns(columns Columns) PreparedOptionFn { + return func(stmt *PreparedStatement) { + stmt.columns = columns + } } -func (stmt *PreparedStatement) WithParameters(parameters []oid.Oid) { - stmt.parameters = parameters +// WithParameters sets the given parameters as the parameters which are expected +// by the prepared statement. +func WithParameters(parameters []oid.Oid) PreparedOptionFn { + return func(stmt *PreparedStatement) { + stmt.parameters = parameters + } } -func (stmt *PreparedStatement) WithColumns(columns Columns) { - stmt.columns = columns +type PreparedStatements []*PreparedStatement + +type PreparedStatement struct { + fn PreparedStatementFn + parameters []oid.Oid + columns Columns } // SessionHandler represents a wrapper function defining the state of a single diff --git a/wire_test.go b/wire_test.go index 7297b6d..c81f7e8 100644 --- a/wire_test.go +++ b/wire_test.go @@ -41,13 +41,13 @@ func TListenAndServe(t *testing.T, server *Server) *net.TCPAddr { func TestClientConnect(t *testing.T) { t.Parallel() - handler := func(ctx context.Context, query string) (*PreparedStatement, error) { - statement := NewPreparedStatement(func(ctx context.Context, writer DataWriter, parameters []Parameter) error { + handler := func(ctx context.Context, query string) (PreparedStatements, error) { + statement := NewStatement(func(ctx context.Context, writer DataWriter, parameters []Parameter) error { t.Log("serving query") return writer.Complete("OK") }) - return statement, nil + return Prepared(statement), nil } server, err := NewServer(handler, Logger(slogt.New(t))) @@ -111,23 +111,22 @@ func TestClientConnect(t *testing.T) { func TestClientParameters(t *testing.T) { t.Parallel() - handler := func(ctx context.Context, query string) (*PreparedStatement, error) { - statement := NewPreparedStatement(func(ctx context.Context, writer DataWriter, parameters []Parameter) error { + handler := func(ctx context.Context, query string) (PreparedStatements, error) { + handle := func(ctx context.Context, writer DataWriter, parameters []Parameter) error { writer.Row([]any{"John Doe"}) //nolint:errcheck return writer.Complete("SELECT 1") - }) + } - statement.WithParameters(ParseParameters(query)) - statement.WithColumns(Columns{ + columns := Columns{ { Table: 0, Name: "full_name", Oid: oid.T_text, Width: 256, }, - }) + } - return statement, nil + return Prepared(NewStatement(handle, WithColumns(columns), WithParameters(ParseParameters(query)))), nil } server, err := NewServer(handler, Logger(slogt.New(t))) @@ -185,16 +184,15 @@ func TestClientParameters(t *testing.T) { func TestServerWritingResult(t *testing.T) { t.Parallel() - handler := func(ctx context.Context, query string) (*PreparedStatement, error) { - statement := NewPreparedStatement(func(ctx context.Context, writer DataWriter, parameters []Parameter) error { + handler := func(ctx context.Context, query string) (PreparedStatements, error) { + handle := func(ctx context.Context, writer DataWriter, parameters []Parameter) error { t.Log("serving query") writer.Row([]any{"John", true, 28}) //nolint:errcheck writer.Row([]any{"Marry", false, 21}) //nolint:errcheck return writer.Complete("SELECT 2") - }) + } - statement.WithParameters(ParseParameters(query)) - statement.WithColumns(Columns{ //nolint:errcheck + columns := Columns{ //nolint:errcheck { Table: 0, Name: "name", @@ -213,9 +211,9 @@ func TestServerWritingResult(t *testing.T) { Oid: oid.T_int4, Width: 1, }, - }) + } - return statement, nil + return Prepared(NewStatement(handle, WithColumns(columns))), nil } server, err := NewServer(handler, Logger(slogt.New(t))) @@ -338,24 +336,23 @@ func TestServerHandlingMultipleConnections(t *testing.T) { func TOpenMockServer(t *testing.T) *net.TCPAddr { t.Helper() - handler := func(ctx context.Context, query string) (*PreparedStatement, error) { - statement := NewPreparedStatement(func(ctx context.Context, writer DataWriter, parameters []Parameter) error { + handler := func(ctx context.Context, query string) (PreparedStatements, error) { + handle := func(ctx context.Context, writer DataWriter, parameters []Parameter) error { t.Log("serving query") writer.Row([]any{20}) //nolint:errcheck return writer.Complete("SELECT 1") - }) + } - statement.WithParameters(ParseParameters(query)) - statement.WithColumns(Columns{ + columns := Columns{ { Table: 0, Name: "age", Oid: oid.T_int4, Width: 1, }, - }) + } - return statement, nil + return Prepared(NewStatement(handle, WithColumns(columns), WithParameters(ParseParameters(query)))), nil } server, err := NewServer(handler, Logger(slogt.New(t))) @@ -373,25 +370,24 @@ func TestServerNULLValues(t *testing.T) { nil, } - handler := func(ctx context.Context, query string) (*PreparedStatement, error) { - statement := NewPreparedStatement(func(ctx context.Context, writer DataWriter, parameters []Parameter) error { + handler := func(ctx context.Context, query string) (PreparedStatements, error) { + handle := func(ctx context.Context, writer DataWriter, parameters []Parameter) error { t.Log("serving query") writer.Row([]any{"John"}) //nolint:errcheck writer.Row([]any{nil}) //nolint:errcheck return writer.Complete("SELECT 2") - }) + } - statement.WithParameters(ParseParameters(query)) - statement.WithColumns(Columns{ + columns := Columns{ { Table: 0, Name: "name", Oid: oid.T_text, Width: 256, }, - }) + } - return statement, nil + return Prepared(NewStatement(handle, WithColumns(columns))), nil } server, err := NewServer(handler, Logger(slogt.New(t)))