Skip to content

Commit

Permalink
feat: include support for multiple statements
Browse files Browse the repository at this point in the history
  • Loading branch information
jeroenrinzema committed Dec 19, 2023
1 parent 949b563 commit f4fcef6
Show file tree
Hide file tree
Showing 10 changed files with 134 additions and 83 deletions.
8 changes: 3 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,11 @@ func main() {
wire.ListenAndServe("127.0.0.1:5432", handler)
}

func handler(ctx context.Context, query string) (*wire.PreparedStatement, error) {
statement := wire.NewPreparedStatement(func(ctx context.Context, writer wire.DataWriter, parameters []wire.Parameter) error {
func handler(ctx context.Context, query string) (wire.PreparedStatements, error) {
return wire.Prepared(wire.NewStatement(func(ctx context.Context, writer wire.DataWriter, parameters []wire.Parameter) error {
fmt.Println(query)
return writer.Complete("OK")
})

return statement, nil
})), nil
}
```

Expand Down
56 changes: 44 additions & 12 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@ func NewErrUnkownStatement(name string) error {
return psqlerr.WithSeverity(psqlerr.WithCode(err, codes.InvalidPreparedStatementDefinition), psqlerr.LevelFatal)
}

// NewErrUndefinedStatement is returned whenever no statement has been defined
// within the incoming query.
func NewErrUndefinedStatement() error {
err := errors.New("no statement has been defined")
return psqlerr.WithSeverity(psqlerr.WithCode(err, codes.Syntax), psqlerr.LevelError)
}

// NewErrMultipleCommandsStatements is returned whenever multiple statements have been
// given within a single query during the extended query protocol.
func NewErrMultipleCommandsStatements() error {
err := errors.New("cannot insert multiple commands into a prepared statement")
return psqlerr.WithSeverity(psqlerr.WithCode(err, codes.Syntax), psqlerr.LevelError)
}

// consumeCommands consumes incoming commands send over the Postgres wire connection.
// Commands consumed from the connection are returned through a go channel.
// Responses for the given message type are written back to the client.
Expand Down Expand Up @@ -237,24 +251,26 @@ func (srv *Server) handleSimpleQuery(ctx context.Context, reader *buffer.Reader,
return readyForQuery(writer, types.ServerIdle)
}

statement, err := srv.parse(ctx, query)
statements, err := srv.parse(ctx, query)
if err != nil {
return ErrorCode(writer, err)
}

if err != nil {
return ErrorCode(writer, err)
if len(statements) == 0 {
return ErrorCode(writer, NewErrUndefinedStatement())
}

// NOTE: we have to define the column definitions before executing a simple query
err = statement.columns.Define(ctx, writer, nil)
if err != nil {
return ErrorCode(writer, err)
}
// NOTE: it is possible to send multiple statements in one simple query.
for index := range statements {
err = statements[index].columns.Define(ctx, writer, nil)
if err != nil {
return ErrorCode(writer, err)
}

err = statement.fn(ctx, NewDataWriter(ctx, statement.columns, nil, writer), nil)
if err != nil {
return ErrorCode(writer, err)
err = statements[index].fn(ctx, NewDataWriter(ctx, statements[index].columns, nil, writer), nil)
if err != nil {
return ErrorCode(writer, err)
}
}

return readyForQuery(writer, types.ServerIdle)
Expand Down Expand Up @@ -294,7 +310,7 @@ func (srv *Server) handleParse(ctx context.Context, reader *buffer.Reader, write
// `reader.GetUint32()`
}

statement, err := srv.parse(ctx, query)
statement, err := singleStatement(srv.parse(ctx, query))
if err != nil {
return ErrorCode(writer, err)
}
Expand Down Expand Up @@ -543,3 +559,19 @@ func (srv *Server) handleConnTerminate(ctx context.Context) error {

return srv.TerminateConn(ctx)
}

func singleStatement(stmts PreparedStatements, err error) (*PreparedStatement, error) {
if err != nil {
return nil, err
}

if len(stmts) > 1 {
return nil, NewErrMultipleCommandsStatements()
}

if len(stmts) == 0 {
return nil, NewErrUndefinedStatement()
}

return stmts[0], nil
}
10 changes: 4 additions & 6 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ func TestBindMessageParameters(t *testing.T) {
},
}

handler := func(ctx context.Context, query string) (*PreparedStatement, error) {
statement := NewPreparedStatement(func(ctx context.Context, writer DataWriter, parameters []Parameter) error {
handler := func(ctx context.Context, query string) (PreparedStatements, error) {
handle := func(ctx context.Context, writer DataWriter, parameters []Parameter) error {
t.Log("serving query")

if len(parameters) != 2 {
Expand All @@ -79,11 +79,9 @@ func TestBindMessageParameters(t *testing.T) {

writer.Row([]any{first, second}) //nolint:errcheck
return writer.Complete("SELECT 1")
})
}

statement.WithParameters(ParseParameters(query))
statement.WithColumns(columns)
return statement, nil
return Prepared(NewStatement(handle, WithColumns(columns), WithParameters(ParseParameters(query)))), nil
}

server, err := NewServer(handler, Logger(slogt.New(t)))
Expand Down
6 changes: 3 additions & 3 deletions error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ import (
)

func TestErrorCode(t *testing.T) {
handler := func(ctx context.Context, query string) (*PreparedStatement, error) {
statement := NewPreparedStatement(func(ctx context.Context, writer DataWriter, parameters []Parameter) error {
handler := func(ctx context.Context, query string) (PreparedStatements, error) {
stmt := NewStatement(func(ctx context.Context, writer DataWriter, parameters []Parameter) error {
return psqlerr.WithSeverity(psqlerr.WithCode(errors.New("unimplemented feature"), codes.FeatureNotSupported), psqlerr.LevelFatal)
})

return statement, nil
return Prepared(stmt), nil
}

server, err := NewServer(handler, Logger(slogt.New(t)))
Expand Down
2 changes: 1 addition & 1 deletion examples/error/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func main() {
wire.ListenAndServe("127.0.0.1:5432", handler)
}

func handler(ctx context.Context, query string) (*wire.PreparedStatement, error) {
func handler(ctx context.Context, query string) (wire.PreparedStatements, error) {
log.Println("incoming SQL query:", query)

err := errors.New("unimplemented feature")
Expand Down
8 changes: 4 additions & 4 deletions examples/session/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ func session(ctx context.Context) (context.Context, error) {
return context.WithValue(ctx, id, counter), nil
}

func handler(ctx context.Context, query string) (*wire.PreparedStatement, error) {
func handler(ctx context.Context, query string) (wire.PreparedStatements, error) {
log.Println("incoming SQL query:", query)

statement := wire.NewPreparedStatement(func(ctx context.Context, writer wire.DataWriter, parameters []wire.Parameter) error {
handle := func(ctx context.Context, writer wire.DataWriter, parameters []wire.Parameter) error {
session := ctx.Value(id).(int)
return writer.Complete(fmt.Sprintf("OK, session: %d", session))
})
}

return statement, nil
return wire.Prepared(wire.NewStatement(handle)), nil
}
9 changes: 4 additions & 5 deletions examples/simple/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,14 @@ var table = wire.Columns{
},
}

func handler(ctx context.Context, query string) (*wire.PreparedStatement, error) {
func handler(ctx context.Context, query string) (wire.PreparedStatements, error) {
log.Println("incoming SQL query:", query)

statement := wire.NewPreparedStatement(func(ctx context.Context, writer wire.DataWriter, parameters []wire.Parameter) error {
handle := func(ctx context.Context, writer wire.DataWriter, parameters []wire.Parameter) error {
writer.Row([]any{"John", true, 29})
writer.Row([]any{"Marry", false, 21})
return writer.Complete("SELECT 2")
})
}

statement.WithColumns(table)
return statement, nil
return wire.Prepared(wire.NewStatement(handle, wire.WithColumns(table))), nil
}
8 changes: 4 additions & 4 deletions examples/tls/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ func run() error {
return server.ListenAndServe("127.0.0.1:5432")
}

func handler(ctx context.Context, query string) (*wire.PreparedStatement, error) {
func handler(ctx context.Context, query string) (wire.PreparedStatements, error) {
slog.Info("incoming SQL query", slog.String("query", query))

statement := wire.NewPreparedStatement(func(ctx context.Context, writer wire.DataWriter, parameters []wire.Parameter) error {
handle := func(ctx context.Context, writer wire.DataWriter, parameters []wire.Parameter) error {
return writer.Complete("OK")
})
}

return statement, nil
return wire.Prepared(wire.NewStatement(handle)), nil
}
52 changes: 40 additions & 12 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,60 @@ import (

// ParseFn parses the given query and returns a prepared statement which could
// be used to execute at a later point in time.
type ParseFn func(ctx context.Context, query string) (*PreparedStatement, error)
type ParseFn func(ctx context.Context, query string) (PreparedStatements, error)

// PreparedStatementFn represents a query of which a statement has been
// prepared. The statement could be executed at any point in time with the given
// arguments and data writer.
type PreparedStatementFn func(ctx context.Context, writer DataWriter, parameters []Parameter) error

// NewPreparedStatement constructs a new prepared statement for the given function.
func NewPreparedStatement(fn PreparedStatementFn) *PreparedStatement {
return &PreparedStatement{
// Prepared is a small wrapper function returning a list of prepared statements.
// More then one prepared statement could be returned within the simple query
// protocol. An error is returned when more then one prepared statement is
// returned in the extended query protocol.
// https://www.postgresql.org/docs/15/protocol-flow.html#PROTOCOL-FLOW-MULTI-STATEMENT
func Prepared(stmts ...*PreparedStatement) PreparedStatements {
return stmts
}

// NewStatement constructs a new prepared statement for the given function.
func NewStatement(fn PreparedStatementFn, options ...PreparedOptionFn) *PreparedStatement {
stmt := &PreparedStatement{
fn: fn,
}

for _, option := range options {
option(stmt)
}

return stmt
}

type PreparedStatement struct {
fn PreparedStatementFn
parameters []oid.Oid
columns Columns
// PreparedOptionFn options pattern used to define options while preparing a new statement.
type PreparedOptionFn func(*PreparedStatement)

// WithColumns sets the given columns as the columns which are returned by the
// prepared statement.
func WithColumns(columns Columns) PreparedOptionFn {
return func(stmt *PreparedStatement) {
stmt.columns = columns
}
}

func (stmt *PreparedStatement) WithParameters(parameters []oid.Oid) {
stmt.parameters = parameters
// WithParameters sets the given parameters as the parameters which are expected
// by the prepared statement.
func WithParameters(parameters []oid.Oid) PreparedOptionFn {
return func(stmt *PreparedStatement) {
stmt.parameters = parameters
}
}

func (stmt *PreparedStatement) WithColumns(columns Columns) {
stmt.columns = columns
type PreparedStatements []*PreparedStatement

type PreparedStatement struct {
fn PreparedStatementFn
parameters []oid.Oid
columns Columns
}

// SessionHandler represents a wrapper function defining the state of a single
Expand Down
Loading

0 comments on commit f4fcef6

Please sign in to comment.