Skip to content

Commit

Permalink
Merge pull request #68 from charleskorn/master
Browse files Browse the repository at this point in the history
Fix #60 (calling close on a UnixSocket with a pending read waits for the read to time out on Linux)
  • Loading branch information
headius committed Jan 9, 2020
2 parents b8c5ac2 + 3320077 commit bab2435
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.nio.channels.SocketChannel;
import java.nio.channels.spi.SelectorProvider;

import jnr.constants.platform.Errno;
import jnr.constants.platform.Shutdown;
import jnr.enxio.channels.Native;
import jnr.enxio.channels.NativeException;
Expand Down Expand Up @@ -54,6 +55,11 @@ public final int getFD() {

@Override
protected void implCloseSelectableChannel() throws IOException {
if (this.isConnected()) {
this.shutdownInput();
this.shutdownOutput();
}

Native.close(common.getFD());
}

Expand Down Expand Up @@ -85,7 +91,7 @@ public long write(ByteBuffer[] srcs, int offset,
@Override
public SocketChannel shutdownInput() throws IOException {
int n = Native.shutdown(common.getFD(), SHUT_RD);
if (n < 0) {
if (n < 0 && Native.getLastError() != Errno.ENOTCONN) {
throw new NativeException(Native.getLastErrorString(), Native.getLastError());
}
return this;
Expand All @@ -94,7 +100,7 @@ public SocketChannel shutdownInput() throws IOException {
@Override
public SocketChannel shutdownOutput() throws IOException {
int n = Native.shutdown(common.getFD(), SHUT_WR);
if (n < 0) {
if (n < 0 && Native.getLastError() != Errno.ENOTCONN) {
throw new NativeException(Native.getLastErrorString(), Native.getLastError());
}
return this;
Expand Down
84 changes: 84 additions & 0 deletions src/test/java/jnr/unixsocket/UnixSocketChannelTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
package jnr.unixsocket;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.concurrent.CountDownLatch;

import org.junit.Test;
import org.junit.Assume;

Expand Down Expand Up @@ -48,4 +53,83 @@ public void testAbstractNamespace() throws Exception {
assertEquals("local socket path", ABSTRACT, ch.getLocalSocketAddress().path());
}

@Test
public void testInterruptRead() throws Exception {
Path socketPath = getTemporarySocketFileName();
startServer(socketPath);

int readTimeoutInMilliseconds = 5000;

UnixSocket socket = createClient(socketPath, readTimeoutInMilliseconds);
CountDownLatch readStartLatch = new CountDownLatch(1);
ReadFromSocketRunnable runnable = new ReadFromSocketRunnable(readStartLatch, socket);

Thread readThread = new Thread(runnable);

readThread.setDaemon(true);

long startTime = System.nanoTime();
readThread.start();
readStartLatch.await();
Thread.sleep(100); // Wait for the thread to call read()
socket.close();
readThread.join();
long stopTime = System.nanoTime();

long duration = stopTime - startTime;
long durationInMilliseconds = duration / 1_000_000;

assertTrue("read() was not interrupted by close() before read() timed out", durationInMilliseconds < readTimeoutInMilliseconds);
assertEquals("read() threw an exception", null, runnable.getThrownOnThread());
}

private Path getTemporarySocketFileName() throws IOException {
Path socketPath = Files.createTempFile("jnr-unixsocket-tests", ".sock");
Files.delete(socketPath);
socketPath.toFile().deleteOnExit();

return socketPath;
}

private void startServer(Path socketPath) throws IOException {
UnixServerSocketChannel serverChannel = UnixServerSocketChannel.open();
serverChannel.configureBlocking(false);
serverChannel.socket().bind(new UnixSocketAddress(socketPath.toFile()));
}

private UnixSocket createClient(Path socketPath, int readTimeoutInMilliseconds) throws IOException {
UnixSocketChannel clientChannel = UnixSocketChannel.open(new UnixSocketAddress(socketPath.toFile()));
UnixSocket socket = new UnixSocket(clientChannel);
socket.setSoTimeout(readTimeoutInMilliseconds);

return socket;
}

private class ReadFromSocketRunnable implements Runnable {
private CountDownLatch readStartLatch;
private UnixSocket socket;
private IOException thrownOnThread;

private ReadFromSocketRunnable(CountDownLatch readStartLatch, UnixSocket socket) {
this.readStartLatch = readStartLatch;
this.socket = socket;
}

@Override
public void run() {
try {
readStartLatch.countDown();
socket.getInputStream().read();
} catch (IOException e) {
// EBADF (bad file descriptor) is thrown when read() is interrupted
if (!e.getMessage().equals("Bad file descriptor")) {
thrownOnThread = e;
}
}
}

private IOException getThrownOnThread() {
return thrownOnThread;
}
}
}

0 comments on commit bab2435

Please sign in to comment.