diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/ShuffleUtils.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/ShuffleUtils.java index bf58172ef1..df4281a94e 100644 --- a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/ShuffleUtils.java +++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/ShuffleUtils.java @@ -125,7 +125,7 @@ public static void shuffleToMemory(byte[] shuffleData, LOG.debug("Read " + shuffleData.length + " bytes from input for " + identifier); } - } catch (InternalError | IOException e) { + } catch (InternalError | Exception e) { // Close the streams LOG.info("Failed to read data to memory for " + identifier + ". len=" + compressedLength + ", decomp=" + decompressedLength + ". ExceptionMessage=" + e.getMessage()); @@ -135,9 +135,12 @@ public static void shuffleToMemory(byte[] shuffleData, // on decompression failures. Catching and re-throwing as IOException // to allow fetch failure logic to be processed. throw new IOException(e); + } else if (e instanceof IOException) { + throw e; + } else { + // Re-throw as an IOException + throw new IOException(e); } - // Re-throw - throw e; } } diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/common/shuffle/TestShuffleUtils.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/common/shuffle/TestShuffleUtils.java index 1d2d4280b5..f61c7e5f66 100644 --- a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/common/shuffle/TestShuffleUtils.java +++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/common/shuffle/TestShuffleUtils.java @@ -41,6 +41,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.net.SocketTimeoutException; import java.nio.ByteBuffer; import java.util.Arrays; import java.util.BitSet; @@ -300,6 +301,58 @@ public void testInternalErrorTranslation() throws Exception { } } + @Test + public void testExceptionTranslation() throws Exception { + String codecErrorMsg = "codec failure"; + CompressionInputStream mockCodecStream = mock(CompressionInputStream.class); + when(mockCodecStream.read(any(byte[].class), anyInt(), anyInt())) + .thenThrow(new IllegalArgumentException(codecErrorMsg)); + Decompressor mockDecoder = mock(Decompressor.class); + CompressionCodec mockCodec = mock(CompressionCodec.class); + when(mockCodec.createDecompressor()).thenReturn(mockDecoder); + when(mockCodec.createInputStream(any(InputStream.class), any(Decompressor.class))) + .thenReturn(mockCodecStream); + byte[] header = new byte[] { (byte) 'T', (byte) 'I', (byte) 'F', (byte) 1}; + try { + ShuffleUtils.shuffleToMemory(new byte[1024], new ByteArrayInputStream(header), + 1024, 128, mockCodec, false, 0, mock(Logger.class), null); + Assert.fail("shuffle was supposed to throw!"); + } catch (IOException e) { + Assert.assertTrue(e.getCause() instanceof IllegalArgumentException); + Assert.assertTrue(e.getMessage().contains(codecErrorMsg)); + } + CompressionInputStream mockCodecStream1 = mock(CompressionInputStream.class); + when(mockCodecStream1.read(any(byte[].class), anyInt(), anyInt())) + .thenThrow(new SocketTimeoutException(codecErrorMsg)); + CompressionCodec mockCodec1 = mock(CompressionCodec.class); + when(mockCodec1.createDecompressor()).thenReturn(mockDecoder); + when(mockCodec1.createInputStream(any(InputStream.class), any(Decompressor.class))) + .thenReturn(mockCodecStream1); + try { + ShuffleUtils.shuffleToMemory(new byte[1024], new ByteArrayInputStream(header), + 1024, 128, mockCodec1, false, 0, mock(Logger.class), null); + Assert.fail("shuffle was supposed to throw!"); + } catch (IOException e) { + Assert.assertTrue(e instanceof SocketTimeoutException); + Assert.assertTrue(e.getMessage().contains(codecErrorMsg)); + } + CompressionInputStream mockCodecStream2 = mock(CompressionInputStream.class); + when(mockCodecStream2.read(any(byte[].class), anyInt(), anyInt())) + .thenThrow(new InternalError(codecErrorMsg)); + CompressionCodec mockCodec2 = mock(CompressionCodec.class); + when(mockCodec2.createDecompressor()).thenReturn(mockDecoder); + when(mockCodec2.createInputStream(any(InputStream.class), any(Decompressor.class))) + .thenReturn(mockCodecStream2); + try { + ShuffleUtils.shuffleToMemory(new byte[1024], new ByteArrayInputStream(header), + 1024, 128, mockCodec2, false, 0, mock(Logger.class), null); + Assert.fail("shuffle was supposed to throw!"); + } catch (IOException e) { + Assert.assertTrue(e.getCause() instanceof InternalError); + Assert.assertTrue(e.getMessage().contains(codecErrorMsg)); + } + } + @Test public void testShuffleToDiskChecksum() throws Exception { // verify sending a stream of zeroes without checksum validation