diff --git a/driver/driver.go b/driver/driver.go index a0fafc32..4b46fdad 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -335,6 +335,14 @@ type Conn struct { tracing client.LogLevel } +func (c *Conn) ResetSession(ctx context.Context) error { + err := c.protocol.ConnCheck() + if err != nil { + return driver.ErrBadConn + } + return nil +} + // PrepareContext returns a prepared statement, bound to this connection. // context is for the preparation of the statement, it must not store the // context within the statement itself. @@ -353,7 +361,7 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e if c.tracing != client.LogNone { start = time.Now() } - err := c.protocol.Call(ctx, &c.request, &c.response); + err := c.protocol.Call(ctx, &c.request, &c.response) if c.tracing != client.LogNone { c.log(c.tracing, "%.3fs request prepared: %q", time.Since(start).Seconds(), query) } @@ -392,7 +400,7 @@ func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.Name if c.tracing != client.LogNone { start = time.Now() } - err := c.protocol.Call(ctx, &c.request, &c.response); + err := c.protocol.Call(ctx, &c.request, &c.response) if c.tracing != client.LogNone { c.log(c.tracing, "%.3fs request exec: %q", time.Since(start).Seconds(), query) } @@ -428,7 +436,7 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam if c.tracing != client.LogNone { start = time.Now() } - err := c.protocol.Call(ctx, &c.request, &c.response); + err := c.protocol.Call(ctx, &c.request, &c.response) if c.tracing != client.LogNone { c.log(c.tracing, "%.3fs request query: %q", time.Since(start).Seconds(), query) } @@ -588,7 +596,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive if s.tracing != client.LogNone { start = time.Now() } - err := s.protocol.Call(ctx, s.request, s.response); + err := s.protocol.Call(ctx, s.request, s.response) if s.tracing != client.LogNone { s.log(s.tracing, "%.3fs request prepared: %q", time.Since(start).Seconds(), s.sql) } @@ -627,7 +635,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv if s.tracing != client.LogNone { start = time.Now() } - err := s.protocol.Call(ctx, s.request, s.response); + err := s.protocol.Call(ctx, s.request, s.response) if s.tracing != client.LogNone { s.log(s.tracing, "%.3fs request prepared: %q", time.Since(start).Seconds(), s.sql) } diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index 8a48ffcf..961acd1e 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -6,6 +6,7 @@ import ( "io" "net" "sync" + "syscall" "time" "github.com/pkg/errors" @@ -30,6 +31,44 @@ func newProtocol(version uint64, conn net.Conn) *Protocol { return protocol } +func connCheck(conn net.Conn) error { + var sysErr error + + sysConn, ok := conn.(syscall.Conn) + if !ok { + return nil + } + rawConn, err := sysConn.SyscallConn() + if err != nil { + return err + } + + err = rawConn.Read(func(fd uintptr) bool { + var buf [1]byte + n, err := syscall.Read(int(fd), buf[:]) + switch { + case n == 0 && err == nil: + sysErr = io.EOF + case n > 0: + sysErr = syscall.EBADFD // TODO assign a sensible error here. + case err == syscall.EAGAIN || err == syscall.EWOULDBLOCK: + sysErr = nil + default: + sysErr = err + } + return true + }) + if err != nil { + return err + } + + return sysErr +} + +func (p *Protocol) ConnCheck() error { + return connCheck(p.conn) +} + // Call invokes a dqlite RPC, sending a request message and receiving a // response message. func (p *Protocol) Call(ctx context.Context, request, response *Message) (err error) {