diff --git a/driver.go b/driver.go index d7b30871..ae0bfd69 100644 --- a/driver.go +++ b/driver.go @@ -17,6 +17,7 @@ package dqlite import ( "context" "database/sql/driver" + "io" "net" "reflect" "time" @@ -300,7 +301,7 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam return nil, driverError(err) } - return &Rows{ctx: ctx, response: &c.response, client: c.client, rows: rows}, nil + return &Rows{ctx: ctx, request: &c.request, response: &c.response, client: c.client, rows: rows}, nil } // Exec is an optional interface that may be implemented by a Conn. @@ -461,7 +462,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv return nil, driverError(err) } - return &Rows{ctx: ctx, response: s.response, client: s.client, rows: rows}, nil + return &Rows{ctx: ctx, request: s.request, response: s.response, client: s.client, rows: rows}, nil } // Query executes a query that may return rows, such as a @@ -491,8 +492,10 @@ func (r *Result) RowsAffected() (int64, error) { type Rows struct { ctx context.Context client *client.Client + request *client.Message response *client.Message rows client.Rows + consumed bool } // Columns returns the names of the columns. The number of @@ -506,6 +509,21 @@ func (r *Rows) Columns() []string { // Close closes the rows iterator. func (r *Rows) Close() error { r.rows.Close() + + // If we consumed the whole result set, there's nothing to do as + // there's no pending response from the server. + if r.consumed { + return nil + } + + r.rows.Close() + + // Let's issue an interrupt request and wait until we get an empty + // response, signalling that the query was interrupted. + if err := r.client.Interrupt(r.ctx, r.request, r.response); err != nil { + return driverError(err) + } + return nil } @@ -516,7 +534,8 @@ func (r *Rows) Close() error { // Next should return io.EOF when there are no more rows. func (r *Rows) Next(dest []driver.Value) error { err := r.rows.Next(dest) - if err != nil && err == client.ErrRowsPart { + + if err == client.ErrRowsPart { r.rows.Close() if err := r.client.More(r.ctx, r.response); err != nil { return driverError(err) @@ -528,6 +547,11 @@ func (r *Rows) Next(dest []driver.Value) error { r.rows = rows return r.rows.Next(dest) } + + if err == io.EOF { + r.consumed = true + } + return err } diff --git a/integration_test.go b/integration_test.go index 5758df9d..df9140bc 100644 --- a/integration_test.go +++ b/integration_test.go @@ -448,6 +448,50 @@ func TestIntegration_EmptyTimestamp(t *testing.T) { require.NoError(t, db.Close()) } +func TestIntegration_QueryInterrupt(t *testing.T) { + db, _, cleanup := newDB(t) + defer cleanup() + + _, err := db.Exec("CREATE TABLE test (n INT)") + require.NoError(t, err) + + tx, err := db.Begin() + require.NoError(t, err) + + stmt, err := tx.Prepare("INSERT INTO test(n) VALUES(?)") + require.NoError(t, err) + + for i := 0; i < 512; i++ { + _, err = stmt.Exec(int64(i)) + require.NoError(t, err) + } + + require.NoError(t, stmt.Close()) + + require.NoError(t, tx.Commit()) + + tx, err = db.Begin() + require.NoError(t, err) + + // This query will yield a multi-response result set, which needs to be + // cancelled because Rows.Next() will be called only for one row. + row := tx.QueryRow("SELECT n FROM test") + + var n int64 + err = row.Scan(&n) + require.NoError(t, err) + + require.NoError(t, tx.Rollback()) + + tx, err = db.Begin() + require.NoError(t, err) + + _, err = tx.Exec("INSERT INTO test(n) VALUES(1)") + require.NoError(t, err) + + require.NoError(t, tx.Rollback()) +} + func newServers(t *testing.T, listeners []net.Listener) (*rafttest.Control, func()) { t.Helper() diff --git a/internal/bindings/server.go b/internal/bindings/server.go index 3cdeb2e2..ce8e34b3 100644 --- a/internal/bindings/server.go +++ b/internal/bindings/server.go @@ -37,6 +37,7 @@ const ( RequestFinalize = C.DQLITE_REQUEST_FINALIZE RequestExecSQL = C.DQLITE_REQUEST_EXEC_SQL RequestQuerySQL = C.DQLITE_REQUEST_QUERY_SQL + RequestInterrupt = C.DQLITE_REQUEST_INTERRUPT ) // Response types. diff --git a/internal/client/client.go b/internal/client/client.go index b9483bef..61f56ac4 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -68,6 +68,46 @@ func (c *Client) More(ctx context.Context, response *Message) error { return c.recv(response) } +// Interrupt sends an interrupt request and awaits for the server's empty +// response. +func (c *Client) Interrupt(ctx context.Context, request *Message, response *Message) error { + // We need to take a lock since the dqlite server currently does not + // support concurrent requests. + c.mu.Lock() + defer c.mu.Unlock() + + // Honor the ctx deadline, if present, or use a default. + deadline, ok := ctx.Deadline() + if !ok { + deadline = time.Now().Add(2 * time.Second) + } + c.conn.SetDeadline(deadline) + + defer request.Reset() + + EncodeInterrupt(request, 0) + + if err := c.send(request); err != nil { + return errors.Wrap(err, "failed to send interrupt request") + } + + for { + if err := c.recv(response); err != nil { + response.Reset() + return errors.Wrap(err, "failed to receive response") + } + + mtype, _ := response.getHeader() + response.Reset() + + if mtype == bindings.ResponseEmpty { + break + } + } + + return nil +} + // Close the client connection. func (c *Client) Close() error { c.log(bindings.LogInfo, "closing client") diff --git a/internal/client/request.go b/internal/client/request.go index 77e56deb..1cfe9a55 100644 --- a/internal/client/request.go +++ b/internal/client/request.go @@ -89,3 +89,10 @@ func EncodeQuerySQL(request *Message, db uint64, sql string, values NamedValues) request.putHeader(bindings.RequestQuerySQL) } + +// EncodeInterrupt encodes a Interrupt request. +func EncodeInterrupt(request *Message, db uint64) { + request.putUint64(db) + + request.putHeader(bindings.RequestInterrupt) +} diff --git a/internal/client/schema.go b/internal/client/schema.go index 8dac96b4..bc7b6d4f 100644 --- a/internal/client/schema.go +++ b/internal/client/schema.go @@ -12,6 +12,7 @@ package client //go:generate ./schema.sh --request Finalize db:uint32 stmt:uint32 //go:generate ./schema.sh --request ExecSQL db:uint64 sql:string values:NamedValues //go:generate ./schema.sh --request QuerySQL db:uint64 sql:string values:NamedValues +//go:generate ./schema.sh --request Interrupt db:uint64 //go:generate ./schema.sh --response init //go:generate ./schema.sh --response Failure code:uint64 message:string