Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Read cancellation #121

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 15 additions & 1 deletion serial.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@

package serial

import "time"
import (
"context"
"time"
)

//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go

Expand All @@ -22,6 +25,13 @@ type Port interface {
// the serial port or an error occurs.
Read(p []byte) (n int, err error)

// Stores data received from the serial port into the provided byte array
// buffer. The function returns the number of bytes read.
//
// The Read function blocks until (at least) one byte is received from
// the serial port, an error occurs, or ctx is canceled.
ReadContext(ctx context.Context, p []byte) (n int, err error)

// Send the content of the data byte array to the serial port.
// Returns the number of bytes written.
Write(p []byte) (n int, err error)
Expand Down Expand Up @@ -142,6 +152,8 @@ const (
PortClosed
// FunctionNotImplemented the requested function is not implemented
FunctionNotImplemented
// ReadCanceled the read was canceled
ReadCanceled
)

// EncodedErrorString returns a string explaining the error code
Expand Down Expand Up @@ -171,6 +183,8 @@ func (e PortError) EncodedErrorString() string {
return "Port has been closed"
case FunctionNotImplemented:
return "Function not implemented"
case ReadCanceled:
return "Read was canceled"
default:
return "Other error"
}
Expand Down
107 changes: 89 additions & 18 deletions serial_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,51 @@ package serial

import (
"context"
"errors"
"io"
"os"
"os/exec"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func startSocatAndWaitForPort(t *testing.T, ctx context.Context) *exec.Cmd {
cmd := exec.CommandContext(ctx, "socat", "-D", "STDIO", "pty,link=/tmp/faketty")
r, err := cmd.StderrPipe()
require.NoError(t, err)
const ttyPath = "/tmp/faketty"

type ttyProc struct {
t *testing.T
cmd *exec.Cmd
}

func (tp *ttyProc) Close() error {
err := tp.cmd.Process.Signal(os.Interrupt)
require.NoError(tp.t, err)
return tp.cmd.Wait()
}

func (tp *ttyProc) waitForPort() {
for {
_, err := os.Stat(ttyPath)
if err == nil {
return
}
if !errors.Is(err, os.ErrNotExist) {
require.NoError(tp.t, err)
}
time.Sleep(time.Millisecond)
}
}

func startSocatAndWaitForPort(t *testing.T, ctx context.Context) io.Closer {
cmd := exec.CommandContext(ctx, "socat", "STDIO", "pty,link="+ttyPath)
require.NoError(t, cmd.Start())
// Let our fake serial port node appear.
// socat will write to stderr before starting transfer phase;
// we don't really care what, just that it did, because then it's ready.
buf := make([]byte, 1024)
_, err = r.Read(buf)
require.NoError(t, err)
return cmd
socat := &ttyProc{
t: t,
cmd: cmd,
}
socat.waitForPort()
return socat
}

func TestSerialReadAndCloseConcurrency(t *testing.T) {
Expand All @@ -39,26 +65,71 @@ func TestSerialReadAndCloseConcurrency(t *testing.T) {

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cmd := startSocatAndWaitForPort(t, ctx)
go cmd.Wait()
socat := startSocatAndWaitForPort(t, ctx)
defer socat.Close()

port, err := Open("/tmp/faketty", &Mode{})
port, err := Open(ttyPath, &Mode{})
require.NoError(t, err)
defer port.Close()

buf := make([]byte, 100)
go port.Read(buf)
// let port.Read to start
time.Sleep(time.Millisecond * 1)
port.Close()
}

func TestDoubleCloseIsNoop(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cmd := startSocatAndWaitForPort(t, ctx)
go cmd.Wait()
socat := startSocatAndWaitForPort(t, ctx)
defer socat.Close()

port, err := Open("/tmp/faketty", &Mode{})
port, err := Open(ttyPath, &Mode{})
require.NoError(t, err)
require.NoError(t, port.Close())
require.NoError(t, port.Close())
}

func TestCancelStopsRead(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
socat := startSocatAndWaitForPort(t, ctx)
defer socat.Close()

port, err := Open(ttyPath, &Mode{})
require.NoError(t, err)
defer port.Close()

readCtx, readCancel := context.WithCancel(context.Background())
done := make(chan struct{})
var readErr error
go func() {
buf := make([]byte, 100)
_, readErr = port.ReadContext(readCtx, buf)
close(done)
}()

time.Sleep(time.Millisecond)
select {
case <-done:
require.NoError(t, readErr)
require.Fail(t, "expected reading to be in-progress")
default:
}

readCancel()

time.Sleep(time.Millisecond)
select {
case <-done:
default:
require.Fail(t, "expected reading to be finished")

}

var portErr *PortError
if !errors.As(readErr, &portErr) {
require.Fail(t, "expected read error to be a port error")
}
require.Equal(t, portErr.Code(), ReadCanceled)
}
28 changes: 26 additions & 2 deletions serial_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
// license that can be found in the LICENSE file.
//

//go:build linux || darwin || freebsd || openbsd
// +build linux darwin freebsd openbsd

package serial

import (
"context"
"io/ioutil"
"regexp"
"strings"
Expand Down Expand Up @@ -57,18 +59,36 @@ func (port *unixPort) Close() error {
}

func (port *unixPort) Read(p []byte) (int, error) {
return port.ReadContext(context.Background(), p)
}

func (port *unixPort) ReadContext(ctx context.Context, p []byte) (int, error) {
port.closeLock.RLock()
defer port.closeLock.RUnlock()
if atomic.LoadUint32(&port.opened) != 1 {
return 0, &PortError{code: PortClosed}
}

cancelSignal := &unixutils.Pipe{}
if err := cancelSignal.Open(); err != nil {
port.Close()
return 0, &PortError{code: PortClosed, causedBy: err}
}
defer cancelSignal.Close()

ctx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
<-ctx.Done()
cancelSignal.Write([]byte{0})
}()

var deadline time.Time
if port.readTimeout != NoTimeout {
deadline = time.Now().Add(port.readTimeout)
}

fds := unixutils.NewFDSet(port.handle, port.closeSignal.ReadFD())
fds := unixutils.NewFDSet(port.handle, port.closeSignal.ReadFD(), cancelSignal.ReadFD())
for {
timeout := time.Duration(-1)
if port.readTimeout != NoTimeout {
Expand All @@ -84,6 +104,9 @@ func (port *unixPort) Read(p []byte) (int, error) {
if res.IsReadable(port.closeSignal.ReadFD()) {
return 0, &PortError{code: PortClosed}
}
if res.IsReadable(cancelSignal.ReadFD()) {
return 0, &PortError{code: ReadCanceled, causedBy: ctx.Err()}
}
if !res.IsReadable(port.handle) {
// Timeout happened
return 0, nil
Expand Down Expand Up @@ -247,7 +270,8 @@ func nativeOpen(portName string, mode *Mode) (*unixPort, error) {

port.acquireExclusiveAccess()

// This pipe is used as a signal to cancel blocking Read
// This pipe is used as a signal to cancel blocking Read when the port is
// closed
pipe := &unixutils.Pipe{}
if err := pipe.Open(); err != nil {
port.Close()
Expand Down
20 changes: 18 additions & 2 deletions serial_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package serial
*/

import (
"context"
"sync"
"syscall"
"time"
Expand Down Expand Up @@ -76,13 +77,24 @@ func (port *windowsPort) Close() error {
}

func (port *windowsPort) Read(p []byte) (int, error) {
return port.ReadContext(context.Background(), p)
}

func (port *windowsPort) ReadContext(ctx context.Context, p []byte) (int, error) {
var readed uint32
ev, err := createOverlappedEvent()
if err != nil {
return 0, err
}
defer syscall.CloseHandle(ev.HEvent)

ctx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
<-ctx.Done()
syscall.CancelIoEx(port.handle, ev)
}()

cycles := int64(0)
for {
err := syscall.ReadFile(port.handle, p, &readed, ev)
Expand All @@ -93,8 +105,12 @@ func (port *windowsPort) Read(p []byte) (int, error) {
case nil:
// operation completed successfully
case syscall.ERROR_OPERATION_ABORTED:
// port may have been closed
return int(readed), &PortError{code: PortClosed, causedBy: err}
if port.handle == 0 {
// port may have been closed
return int(readed), &PortError{code: PortClosed, causedBy: err}
}
// read was canceled
return int(readed), &PortError{code: ReadCanceled, causedBy: err}
default:
// error happened
return int(readed), err
Expand Down