Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
[CONJ-375] LOAD DATA LOCAL INFILE correction for file length > 16mb
  • Loading branch information
rusher committed Oct 25, 2016
1 parent 3180799 commit 01ad520
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 60 deletions.
161 changes: 105 additions & 56 deletions src/main/java/org/mariadb/jdbc/internal/stream/PacketOutputStream.java
Expand Up @@ -190,76 +190,125 @@ public void setCompressSeqNo(int compressSeqNo) {
*/
public void sendFile(InputStream is, int seq) throws IOException {
this.seqNo = seq;

if (!useCompression) {
buffer.clear();
//reserve the 4th first bytes for header
buffer.position(4);
this.checkPacketLength = false;
byte[] buffer = new byte[BUFFER_DEFAULT_SIZE];
int len;
while ((len = is.read(buffer)) > 0) {
write(buffer, 0, len);
}
finishPacketWithoutRelease(true);
releaseBuffer();
writeEmptyPacket(this.seqNo++);
} else {
buffer.clear();
buffer.position(4);
this.checkPacketLength = false;
//No compression
//According to protocol, buffer can be up to max_allowed_packet, but if max_allowed_packet size > a packet :
// - it may take a lot of memory client side
// - it will be faster to send packet directly
//so, reserve the 4th first bytes for packet header to permit writing buffer immediately buffer to socket

int bufLength = Math.min(maxAllowedPacket, MAX_PACKET_LENGTH) - 4;
byte[] buf = new byte[bufLength + 4];

//write file into buffer
byte[] readFileBuffer = new byte[BUFFER_DEFAULT_SIZE];
int len;
while ((len = is.read(readFileBuffer)) > 0) {
write(readFileBuffer, 0, len);
while ((len = is.read(buf, 4, bufLength)) > 0) {
buf[0] = (byte) ((len) & 0xff);
buf[1] = (byte) ((len) >>> 8);
buf[2] = (byte) ((len) >>> 16);
buf[3] = (byte) seqNo++;
outputStream.write(buf, 0, len + 4);

if (logger.isTraceEnabled() && logQuery) {
logger.trace("send packet local file packet seq:" + (seqNo - 1) + " length:" + (len));
}
}

if (buffer.position() > 4) {
checkPacketMaxSize(buffer.position());
//send empty packet when finish
buf[0] = ((byte) 0);
buf[1] = ((byte) 0);
buf[2] = ((byte) 0);
buf[3] = ((byte) seqNo);
outputStream.write(buf, 0, 4);

buffer.flip();
int limit = buffer.limit();
buffer.position(4);
int position = 0;
int expectedPacketSize = limit + HEADER_LENGTH * ((limit / maxPacketSize) + 1);
byte[] bufferBytes = new byte[expectedPacketSize];
} else {
//compression
//reserve 11 byte for header (7 bytes for compression header + 4 byte packet header)
int bufLength = Math.min(maxAllowedPacket - 11, MAX_PACKET_LENGTH - 11);
byte[] buf = new byte[bufLength + 11];
int len;

//write first packet
while (position < expectedPacketSize - 4) {
int length = buffer.remaining();
if (length > maxPacketSize) {
length = maxPacketSize;
while ((len = is.read(buf, 11, bufLength)) > 0) {
boolean compressedPacketSend = false;

if (len > MIN_COMPRESSION_SIZE) {
buf[7] = (byte) ((len) & 0xff);
buf[8] = (byte) (((len) >> 8) & 0xff);
buf[9] = (byte) (((len) >> 16) & 0xff);
buf[10] = (byte) this.seqNo++;
ByteArrayOutputStream baos = new ByteArrayOutputStream();
DeflaterOutputStream deflater = new DeflaterOutputStream(baos);
deflater.write(buf, 7, len + 4);
deflater.finish();
deflater.close();

byte[] compressedBytes = baos.toByteArray();
baos.close();

if (compressedBytes.length < (int) (MIN_COMPRESSION_RATIO * len)) {

int compressedLength = compressedBytes.length;

buf[0] = (byte) ((compressedLength) & 0xff);
buf[1] = (byte) (((compressedLength) >> 8) & 0xff);
buf[2] = (byte) (((compressedLength) >> 16) & 0xff);
buf[3] = (byte) this.compressSeqNo++;
buf[4] = (byte) ((len + 4) & 0xff);
buf[5] = (byte) (((len + 4) >> 8) & 0xff);
buf[6] = (byte) (((len + 4) >> 16) & 0xff);
outputStream.write(buf, 0, 7);
outputStream.write(compressedBytes, 0, compressedLength);
compressedPacketSend = true;
if (logger.isTraceEnabled()) {
logger.trace("send compress packet seq:" + compressSeqNo + " length:" + compressedLength
+ " gzip data");
}
}
bufferBytes[position++] = (byte) (length & 0xff);
bufferBytes[position++] = (byte) (length >>> 8);
bufferBytes[position++] = (byte) (length >>> 16);
bufferBytes[position++] = (byte) seqNo++;

if (length > 0) {
buffer.get(bufferBytes, position, length);
position += length;
}

if (!compressedPacketSend) {
//uncompress packet : 7 bytes compress packet header + standard packet header

buf[0] = (byte) ((len + 4) & 0xff);
buf[1] = (byte) (((len + 4) >> 8) & 0xff);
buf[2] = (byte) (((len + 4) >> 16) & 0xff);
buf[3] = (byte) this.compressSeqNo++;
buf[4] = 0;
buf[5] = 0;
buf[6] = 0;
buf[7] = (byte) ((len) & 0xff);
buf[8] = (byte) (((len) >> 8) & 0xff);;
buf[9] = (byte) (((len) >> 16) & 0xff);
buf[10] = (byte) this.seqNo++;

outputStream.write(buf, 0, len + 11);

if (logger.isTraceEnabled()) {
logger.trace("send compress packet seq:" + compressSeqNo + " length:" + len
+ " data:" + Utils.hexdump(buf, maxQuerySizeToLog, 7, len));
}
}
//write second packet (empty packet)
bufferBytes[position++] = (byte) 0;
bufferBytes[position++] = (byte) 0;
bufferBytes[position++] = (byte) 0;
bufferBytes[position++] = (byte) seqNo++;

//send data
compressedAndSend(position, bufferBytes, true);
} else {
writeEmptyPacket(seqNo);

}

//save big buffer next query to avoid new allocation if next query size is similar
if (buffer.capacity() > BUFFER_DEFAULT_SIZE) {
buffer = firstBuffer;
//write empty packet
buf[0] = (byte) 4;
buf[1] = (byte) 0;
buf[2] = (byte) 0;
buf[3] = (byte) compressSeqNo++;
buf[4] = (byte) 0;
buf[5] = (byte) 0;
buf[6] = (byte) 0;
buf[7] = (byte) 0;
buf[8] = (byte) 0;
buf[9] = (byte) 0;
buf[10] = (byte) seqNo++;
outputStream.write(buf, 0, 11);

if (logger.isTraceEnabled()) {
logger.trace("send compress empty packet seq:" + compressSeqNo);
}

buffer.clear();
buffer.position(4);
}
}

Expand Down
77 changes: 73 additions & 4 deletions src/test/java/org/mariadb/jdbc/LocalInfileInputStreamTest.java
Expand Up @@ -4,10 +4,7 @@
import org.junit.BeforeClass;
import org.junit.Test;

import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.*;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
Expand All @@ -25,6 +22,8 @@ public static void initClass() throws SQLException {
createTable("LocalInfileInputStreamTest", "id int, test varchar(100)");
createTable("ttlocal", "id int, test varchar(100)");
createTable("ldinfile", "a varchar(10)");
createTable("`infile`", "`a` varchar(50) DEFAULT NULL, `b` varchar(50) DEFAULT NULL",
"ENGINE=InnoDB DEFAULT CHARSET=latin1");
}

@Test
Expand Down Expand Up @@ -166,4 +165,74 @@ private void validateRecord(ResultSet rs, int expectedId, String expectedTest) t
Assert.assertEquals(expectedId, id);
Assert.assertEquals(expectedTest, test);
}

private File createTmpData(int recordNumber) throws Exception {
File file = File.createTempFile("./infile" + recordNumber, ".tmp");

//write it
try (BufferedWriter writer = new BufferedWriter(new FileWriter(file))) {
// Every row is 8 bytes to make counting easier
for (long i = 0; i < recordNumber; i++) {
writer.write("\"a\",\"b\"");
writer.write("\n");
}
}

return file;
}

private void checkBigLocalInfile(int fileSize) throws Exception {
int recordNumber = fileSize / 8;

try (Statement statement = sharedConnection.createStatement()) {
statement.execute("truncate `infile`");
File file = createTmpData(recordNumber);

try (InputStream is = new BufferedInputStream(new FileInputStream(file))) {
MariaDbStatement stmt = statement.unwrap(MariaDbStatement.class);
stmt.setLocalInfileInputStream(is);
int insertNumber = stmt.executeUpdate("LOAD DATA LOCAL INFILE 'ignoredFileName' "
+ "INTO TABLE `infile` "
+ "COLUMNS TERMINATED BY ',' ENCLOSED BY '\\\"' ESCAPED BY '\\\\' "
+ "LINES TERMINATED BY '\\n' (`a`, `b`)");
Assert.assertEquals(insertNumber, recordNumber);
}

statement.setFetchSize(1000); //to avoid using too much memory for tests
try (ResultSet rs = statement.executeQuery("SELECT * FROM `infile`")) {
for (int i = 0; i < recordNumber; i++) {
Assert.assertTrue("record " + i + " doesn't exist",rs.next());
Assert.assertEquals("a", rs.getString(1));
Assert.assertEquals("b", rs.getString(2));
}
Assert.assertFalse(rs.next());
}

}
}

/**
* CONJ-375 : error with local infile with size > 16mb.
*
* @throws Exception if error occus
*/
@Test
public void testSmallBigLocalInfileInputStream() throws Exception {
checkBigLocalInfile(256);
}

@Test
public void test2xBigLocalInfileInputStream() throws Exception {
checkBigLocalInfile(16777216 * 2);
}

@Test
public void test2xMaxAllowedPacketLocalInfileInputStream() throws Exception {
ResultSet rs = sharedConnection.createStatement().executeQuery("select @@max_allowed_packet");
rs.next();
int maxAllowedPacket = rs.getInt(1);

checkBigLocalInfile(maxAllowedPacket * 2);
}

}

0 comments on commit 01ad520

Please sign in to comment.