diff --git a/serial.go b/serial.go index 3e4f3b1..3ee6c33 100644 --- a/serial.go +++ b/serial.go @@ -6,7 +6,10 @@ package serial -import "time" +import ( + "context" + "time" +) //go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go @@ -22,6 +25,13 @@ type Port interface { // the serial port or an error occurs. Read(p []byte) (n int, err error) + // Stores data received from the serial port into the provided byte array + // buffer. The function returns the number of bytes read. + // + // The Read function blocks until (at least) one byte is received from + // the serial port, an error occurs, or ctx is canceled. + ReadContext(ctx context.Context, p []byte) (n int, err error) + // Send the content of the data byte array to the serial port. // Returns the number of bytes written. Write(p []byte) (n int, err error) @@ -142,6 +152,8 @@ const ( PortClosed // FunctionNotImplemented the requested function is not implemented FunctionNotImplemented + // ReadCanceled the read was canceled + ReadCanceled ) // EncodedErrorString returns a string explaining the error code @@ -171,6 +183,8 @@ func (e PortError) EncodedErrorString() string { return "Port has been closed" case FunctionNotImplemented: return "Function not implemented" + case ReadCanceled: + return "Read was canceled" default: return "Other error" } diff --git a/serial_linux_test.go b/serial_linux_test.go index 2eda7f3..b327925 100644 --- a/serial_linux_test.go +++ b/serial_linux_test.go @@ -11,6 +11,9 @@ package serial import ( "context" + "errors" + "io" + "os" "os/exec" "testing" "time" @@ -18,18 +21,41 @@ import ( "github.com/stretchr/testify/require" ) -func startSocatAndWaitForPort(t *testing.T, ctx context.Context) *exec.Cmd { - cmd := exec.CommandContext(ctx, "socat", "-D", "STDIO", "pty,link=/tmp/faketty") - r, err := cmd.StderrPipe() - require.NoError(t, err) +const ttyPath = "/tmp/faketty" + +type ttyProc struct { + t *testing.T + cmd *exec.Cmd +} + +func (tp *ttyProc) Close() error { + err := tp.cmd.Process.Signal(os.Interrupt) + require.NoError(tp.t, err) + return tp.cmd.Wait() +} + +func (tp *ttyProc) waitForPort() { + for { + _, err := os.Stat(ttyPath) + if err == nil { + return + } + if !errors.Is(err, os.ErrNotExist) { + require.NoError(tp.t, err) + } + time.Sleep(time.Millisecond) + } +} + +func startSocatAndWaitForPort(t *testing.T, ctx context.Context) io.Closer { + cmd := exec.CommandContext(ctx, "socat", "STDIO", "pty,link="+ttyPath) require.NoError(t, cmd.Start()) - // Let our fake serial port node appear. - // socat will write to stderr before starting transfer phase; - // we don't really care what, just that it did, because then it's ready. - buf := make([]byte, 1024) - _, err = r.Read(buf) - require.NoError(t, err) - return cmd + socat := &ttyProc{ + t: t, + cmd: cmd, + } + socat.waitForPort() + return socat } func TestSerialReadAndCloseConcurrency(t *testing.T) { @@ -39,26 +65,71 @@ func TestSerialReadAndCloseConcurrency(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - cmd := startSocatAndWaitForPort(t, ctx) - go cmd.Wait() + socat := startSocatAndWaitForPort(t, ctx) + defer socat.Close() - port, err := Open("/tmp/faketty", &Mode{}) + port, err := Open(ttyPath, &Mode{}) require.NoError(t, err) + defer port.Close() + buf := make([]byte, 100) go port.Read(buf) // let port.Read to start time.Sleep(time.Millisecond * 1) - port.Close() } func TestDoubleCloseIsNoop(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - cmd := startSocatAndWaitForPort(t, ctx) - go cmd.Wait() + socat := startSocatAndWaitForPort(t, ctx) + defer socat.Close() - port, err := Open("/tmp/faketty", &Mode{}) + port, err := Open(ttyPath, &Mode{}) require.NoError(t, err) require.NoError(t, port.Close()) require.NoError(t, port.Close()) } + +func TestCancelStopsRead(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + socat := startSocatAndWaitForPort(t, ctx) + defer socat.Close() + + port, err := Open(ttyPath, &Mode{}) + require.NoError(t, err) + defer port.Close() + + readCtx, readCancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + var readErr error + go func() { + buf := make([]byte, 100) + _, readErr = port.ReadContext(readCtx, buf) + close(done) + }() + + time.Sleep(time.Millisecond) + select { + case <-done: + require.NoError(t, readErr) + require.Fail(t, "expected reading to be in-progress") + default: + } + + readCancel() + + time.Sleep(time.Millisecond) + select { + case <-done: + default: + require.Fail(t, "expected reading to be finished") + + } + + var portErr *PortError + if !errors.As(readErr, &portErr) { + require.Fail(t, "expected read error to be a port error") + } + require.Equal(t, portErr.Code(), ReadCanceled) +} diff --git a/serial_unix.go b/serial_unix.go index 541660c..0df5efb 100644 --- a/serial_unix.go +++ b/serial_unix.go @@ -4,11 +4,13 @@ // license that can be found in the LICENSE file. // +//go:build linux || darwin || freebsd || openbsd // +build linux darwin freebsd openbsd package serial import ( + "context" "io/ioutil" "regexp" "strings" @@ -57,18 +59,36 @@ func (port *unixPort) Close() error { } func (port *unixPort) Read(p []byte) (int, error) { + return port.ReadContext(context.Background(), p) +} + +func (port *unixPort) ReadContext(ctx context.Context, p []byte) (int, error) { port.closeLock.RLock() defer port.closeLock.RUnlock() if atomic.LoadUint32(&port.opened) != 1 { return 0, &PortError{code: PortClosed} } + cancelSignal := &unixutils.Pipe{} + if err := cancelSignal.Open(); err != nil { + port.Close() + return 0, &PortError{code: PortClosed, causedBy: err} + } + defer cancelSignal.Close() + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { + <-ctx.Done() + cancelSignal.Write([]byte{0}) + }() + var deadline time.Time if port.readTimeout != NoTimeout { deadline = time.Now().Add(port.readTimeout) } - fds := unixutils.NewFDSet(port.handle, port.closeSignal.ReadFD()) + fds := unixutils.NewFDSet(port.handle, port.closeSignal.ReadFD(), cancelSignal.ReadFD()) for { timeout := time.Duration(-1) if port.readTimeout != NoTimeout { @@ -84,6 +104,9 @@ func (port *unixPort) Read(p []byte) (int, error) { if res.IsReadable(port.closeSignal.ReadFD()) { return 0, &PortError{code: PortClosed} } + if res.IsReadable(cancelSignal.ReadFD()) { + return 0, &PortError{code: ReadCanceled, causedBy: ctx.Err()} + } if !res.IsReadable(port.handle) { // Timeout happened return 0, nil @@ -247,7 +270,8 @@ func nativeOpen(portName string, mode *Mode) (*unixPort, error) { port.acquireExclusiveAccess() - // This pipe is used as a signal to cancel blocking Read + // This pipe is used as a signal to cancel blocking Read when the port is + // closed pipe := &unixutils.Pipe{} if err := pipe.Open(); err != nil { port.Close() diff --git a/serial_windows.go b/serial_windows.go index af9a620..0b04d8c 100644 --- a/serial_windows.go +++ b/serial_windows.go @@ -18,6 +18,7 @@ package serial */ import ( + "context" "sync" "syscall" "time" @@ -76,6 +77,10 @@ func (port *windowsPort) Close() error { } func (port *windowsPort) Read(p []byte) (int, error) { + return port.ReadContext(context.Background(), p) +} + +func (port *windowsPort) ReadContext(ctx context.Context, p []byte) (int, error) { var readed uint32 ev, err := createOverlappedEvent() if err != nil { @@ -83,6 +88,13 @@ func (port *windowsPort) Read(p []byte) (int, error) { } defer syscall.CloseHandle(ev.HEvent) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { + <-ctx.Done() + syscall.CancelIoEx(port.handle, ev) + }() + cycles := int64(0) for { err := syscall.ReadFile(port.handle, p, &readed, ev) @@ -93,8 +105,12 @@ func (port *windowsPort) Read(p []byte) (int, error) { case nil: // operation completed successfully case syscall.ERROR_OPERATION_ABORTED: - // port may have been closed - return int(readed), &PortError{code: PortClosed, causedBy: err} + if port.handle == 0 { + // port may have been closed + return int(readed), &PortError{code: PortClosed, causedBy: err} + } + // read was canceled + return int(readed), &PortError{code: ReadCanceled, causedBy: err} default: // error happened return int(readed), err diff --git a/serial_windows_test.go b/serial_windows_test.go new file mode 100644 index 0000000..af9bf9c --- /dev/null +++ b/serial_windows_test.go @@ -0,0 +1,115 @@ +package serial + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func openTestPort(t *testing.T) (Port, error) { + ports, err := GetPortsList() + if err != nil || len(ports) == 0 { + t.SkipNow() + } + + mode := Mode{ + BaudRate: 115200, + DataBits: 8, + Parity: NoParity, + StopBits: OneStopBit, + } + return Open(ports[0], &mode) +} + +func TestOpenClose(t *testing.T) { + // prevent port from being busy in other tests + defer time.Sleep(time.Millisecond) + + port, err := openTestPort(t) + require.NoError(t, err) + port.Close() +} + +func TestOpenReadClosed(t *testing.T) { + // prevent port from being busy in other tests + defer time.Sleep(time.Millisecond) + + port, err := openTestPort(t) + require.NoError(t, err) + defer port.Close() + + done := make(chan struct{}) + var readErr error + go func() { + buf := make([]byte, 100) + _, readErr = port.ReadContext(context.Background(), buf) + close(done) + }() + + time.Sleep(time.Millisecond) + select { + case <-done: + require.NoError(t, readErr) + require.Fail(t, "expected reading to be in-progress") + default: + } + + port.Close() + + time.Sleep(time.Millisecond) + select { + case <-done: + default: + require.Fail(t, "expected reading to be done") + } + + var portErr *PortError + if !errors.As(readErr, &portErr) { + require.Fail(t, "expected read error to be a port error") + } + require.Equal(t, portErr.Code(), PortClosed) +} + +func TestOpenReadCanceled(t *testing.T) { + // prevent port from being busy in other tests + defer time.Sleep(time.Millisecond) + + port, err := openTestPort(t) + require.NoError(t, err) + defer port.Close() + + readCtx, readCancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + var readErr error + go func() { + buf := make([]byte, 100) + _, readErr = port.ReadContext(readCtx, buf) + close(done) + }() + + time.Sleep(time.Millisecond) + select { + case <-done: + require.NoError(t, readErr) + require.Fail(t, "expected reading to be in-progress") + default: + } + + readCancel() + + time.Sleep(time.Millisecond) + select { + case <-done: + default: + require.Fail(t, "expected reading to be done") + } + + var portErr *PortError + if !errors.As(readErr, &portErr) { + require.Fail(t, "expected read error to be a port error") + } + require.Equal(t, portErr.Code(), ReadCanceled) +}