diff --git a/src/jni/bindings_vector.cpp b/src/jni/bindings_vector.cpp index d401efc66..56876d68e 100644 --- a/src/jni/bindings_vector.cpp +++ b/src/jni/bindings_vector.cpp @@ -107,16 +107,24 @@ JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1vector_1get_1da */ JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1vector_1get_1validity(JNIEnv *env, jclass, jobject vector, - jlong array_size) { + jlong vector_size_elems) { duckdb_vector vec = vector_buf_to_vector(env, vector); if (env->ExceptionCheck()) { return nullptr; } + idx_t vector_size = jlong_to_idx(env, vector_size_elems); + if (env->ExceptionCheck()) { + return nullptr; + } + uint64_t *mask = duckdb_vector_get_validity(vec); - idx_t vec_len = duckdb_vector_size(); - idx_t mask_len = vec_len * sizeof(uint64_t) * array_size / 64; + idx_t vector_size_rounded = vector_size; + if (vector_size % 64 != 0) { + vector_size_rounded += 64 - (vector_size % 64); + } + idx_t mask_len = vector_size_rounded * sizeof(uint64_t) / 64; return make_data_buf(env, mask, mask_len); } diff --git a/src/main/java/org/duckdb/DuckDBAppender.java b/src/main/java/org/duckdb/DuckDBAppender.java index 063429781..2ca7819df 100644 --- a/src/main/java/org/duckdb/DuckDBAppender.java +++ b/src/main/java/org/duckdb/DuckDBAppender.java @@ -74,14 +74,14 @@ public class DuckDBAppender implements AutoCloseable { private static final LocalDateTime EPOCH_DATE_TIME = LocalDateTime.ofEpochSecond(0, 0, UTC); + private static final long MAX_TOP_LEVEL_ROWS = duckdb_vector_size(); + private final DuckDBConnection conn; private final String catalog; private final String schema; private final String table; - private final long maxRows; - private ByteBuffer appenderRef; private final Lock appenderRefLock = new ReentrantLock(); @@ -101,8 +101,6 @@ public class DuckDBAppender implements AutoCloseable { this.schema = schema; this.table = table; - this.maxRows = duckdb_vector_size(); - ByteBuffer appenderRef = null; ByteBuffer[] colTypes = null; ByteBuffer chunkRef = null; @@ -163,7 +161,7 @@ public DuckDBAppender endRow() throws SQLException { rowIdx++; Column prev = prevColumn; this.prevColumn = null; - if (rowIdx >= maxRows) { + if (rowIdx >= MAX_TOP_LEVEL_ROWS) { try { flush(); } catch (SQLException e) { @@ -2325,8 +2323,10 @@ private Column(Column parent, int idx, ByteBuffer colTypeRef, ByteBuffer vector, this.arraySize = duckdb_array_type_array_size(parent.colTypeRef); } + long maxElems = maxElementsCount(); if (colType.widthBytes > 0 || colType == DUCKDB_TYPE_DECIMAL) { - this.data = duckdb_vector_get_data(vectorRef, vectorSize()); + long vectorSizeBytes = maxElems * widthBytes(); + this.data = duckdb_vector_get_data(vectorRef, vectorSizeBytes); if (null == this.data) { throw new SQLException("cannot initialize data chunk vector data"); } @@ -2335,7 +2335,7 @@ private Column(Column parent, int idx, ByteBuffer colTypeRef, ByteBuffer vector, } duckdb_vector_ensure_validity_writable(vectorRef); - this.validity = duckdb_vector_get_validity(vectorRef, arraySize * parentArraySize()); + this.validity = duckdb_vector_get_validity(vectorRef, maxElems); if (null == this.validity) { throw new SQLException("cannot initialize data chunk vector validity"); } @@ -2353,15 +2353,18 @@ void reset(long listSize) throws SQLException { } void reset() throws SQLException { + long maxElems = maxElementsCount(); + if (null != this.data) { - this.data = duckdb_vector_get_data(vectorRef, vectorSize()); + long vectorSizeBytes = maxElems * widthBytes(); + this.data = duckdb_vector_get_data(vectorRef, vectorSizeBytes); if (null == this.data) { throw new SQLException("cannot reset data chunk vector data"); } } duckdb_vector_ensure_validity_writable(vectorRef); - this.validity = duckdb_vector_get_validity(vectorRef, arraySize * parentArraySize()); + this.validity = duckdb_vector_get_validity(vectorRef, maxElems); if (null == this.validity) { throw new SQLException("cannot reset data chunk vector validity"); } @@ -2432,12 +2435,17 @@ long parentArraySize() { return parent.arraySize; } - long vectorSize() { - if (null != parent && (parent.colType == DUCKDB_TYPE_LIST || parent.colType == DUCKDB_TYPE_MAP)) { - return listSize * widthBytes(); - } else { - return duckdb_vector_size() * widthBytes() * arraySize * parentArraySize(); + long maxElementsCount() { + Column ancestor = this; + while (null != ancestor) { + if (null != ancestor.parent && + (ancestor.parent.colType == DUCKDB_TYPE_LIST || ancestor.parent.colType == DUCKDB_TYPE_MAP)) { + break; + } + ancestor = ancestor.parent; } + long maxEntries = null != ancestor ? ancestor.listSize : DuckDBAppender.MAX_TOP_LEVEL_ROWS; + return maxEntries * arraySize * parentArraySize(); } } } diff --git a/src/main/java/org/duckdb/DuckDBBindings.java b/src/main/java/org/duckdb/DuckDBBindings.java index a8174d5a9..df0674c41 100644 --- a/src/main/java/org/duckdb/DuckDBBindings.java +++ b/src/main/java/org/duckdb/DuckDBBindings.java @@ -53,7 +53,7 @@ public class DuckDBBindings { static native ByteBuffer duckdb_vector_get_data(ByteBuffer vector, long size_bytes); - static native ByteBuffer duckdb_vector_get_validity(ByteBuffer vector, long array_size); + static native ByteBuffer duckdb_vector_get_validity(ByteBuffer vector, long vector_size_elems); static native void duckdb_vector_ensure_validity_writable(ByteBuffer vector); diff --git a/src/test/java/org/duckdb/TestAppenderCollection.java b/src/test/java/org/duckdb/TestAppenderCollection.java index 58f13459b..a6e38aace 100644 --- a/src/test/java/org/duckdb/TestAppenderCollection.java +++ b/src/test/java/org/duckdb/TestAppenderCollection.java @@ -1410,93 +1410,55 @@ public static void test_appender_list_basic_nested_list() throws Exception { } } - private static void assertMapsEqual(Object obj1, Map map2) throws Exception { - Map map1 = (Map) obj1; - assertEquals(map1.size(), map2.size()); - List> list2 = new ArrayList<>(map2.entrySet()); - int i = 0; - for (Map.Entry en : map1.entrySet()) { - assertEquals(en.getKey(), list2.get(i).getKey()); - assertEquals(en.getValue(), list2.get(i).getValue()); - i += 1; - } - } + public static void test_appender_list_bigint() throws Exception { + int count = 1 << 12; // auto flush twice + int tail = 7; // flushed on close + int listLen = (1 << 6) + 7; // increase this for stress tests - public static void test_appender_map_basic() throws Exception { - Map map1 = createMap(41, "foo", 42, "bar"); - Map map2 = createMap(41, "foo", 42, null, 43, "baz"); try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement()) { - stmt.execute("CREATE TABLE tab1(col1 INTEGER, col2 MAP(INTEGER, VARCHAR))"); + stmt.execute("CREATE TABLE tab1(col1 INTEGER, col2 BIGINT[])"); try (DuckDBAppender appender = conn.createAppender("tab1")) { - appender.beginRow() - .append(41) - .append(map1) - .endRow() - .beginRow() - .append(42) - .append(map2) - .endRow() - .flush(); + for (int i = 0; i < count + tail; i++) { + List list = new ArrayList<>(); + for (long j = 0; j < Math.min(i, listLen); j++) { + if (0 == (i + j) % 13) { + list.add(null); + } else { + list.add(i + j); + } + } + appender.beginRow().append(i).append(list).endRow(); + } } - try (ResultSet rs = stmt.executeQuery("SELECT col2 FROM tab1 ORDER BY col1")) { - assertTrue(rs.next()); - assertMapsEqual(rs.getObject(1), map1); + try (ResultSet rs = stmt.executeQuery("SELECT count(*) FROM tab1")) { assertTrue(rs.next()); - assertMapsEqual(rs.getObject(1), map2); + assertEquals(rs.getInt(1), count + tail); assertFalse(rs.next()); } - } - } - - public static void test_appender_list_basic_map() throws Exception { - Map map1 = createMap(41, "foo1", 42, "bar1", 43, "baz1"); - Map map2 = createMap(44, null, 45, "bar2"); - Map map3 = new LinkedHashMap<>(); - Map map4 = createMap(46, "foo3"); - try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); - Statement stmt = conn.createStatement()) { - stmt.execute("CREATE TABLE tab1(col1 INT, col2 MAP(INTEGER, VARCHAR)[])"); - try (DuckDBAppender appender = conn.createAppender("tab1")) { - appender.beginRow() - .append(42) - .append(asList(map1, map2, map3)) - .endRow() - .beginRow() - .append(43) - .append((List) null) - .endRow() - .beginRow() - .append(44) - .append(asList(null, map4)) - .endRow() - .flush(); - } - try (ResultSet rs = stmt.executeQuery("SELECT unnest(col2) from tab1 WHERE col1 = 42")) { - assertTrue(rs.next()); - assertMapsEqual(rs.getObject(1), map1); + try (ResultSet rs = stmt.executeQuery( + "SELECT count(*) FROM (SELECT unnest(col2) FROM tab1 WHERE col1 = " + (listLen - 7) + ")")) { assertTrue(rs.next()); - assertMapsEqual(rs.getObject(1), map2); - assertTrue(rs.next()); - assertMapsEqual(rs.getObject(1), map3); - assertFalse(rs.next()); - } - try (ResultSet rs = stmt.executeQuery("SELECT col2 from tab1 WHERE col1 = 43")) { - assertTrue(rs.next()); - assertNull(rs.getObject(1)); - assertTrue(rs.wasNull()); + assertEquals(rs.getInt(1), listLen - 7); assertFalse(rs.next()); } - try (ResultSet rs = stmt.executeQuery("SELECT unnest(col2) from tab1 WHERE col1 = 44")) { - assertTrue(rs.next()); - assertNull(rs.getObject(1)); - assertTrue(rs.wasNull()); - assertTrue(rs.next()); - assertMapsEqual(rs.getObject(1), map4); - assertFalse(rs.next()); + + try (ResultSet rs = stmt.executeQuery("SELECT col1, unnest(col2) FROM tab1 ORDER BY col1")) { + for (int i = 0; i < count + tail; i++) { + for (long j = 0; j < Math.min(i, listLen); j++) { + assertTrue(rs.next()); + assertEquals(rs.getInt(1), i); + if (0 == (i + j) % 13) { + assertNull(rs.getObject(2)); + assertTrue(rs.wasNull()); + } else { + assertEquals(rs.getLong(2), i + j); + } + } + } } } } diff --git a/src/test/java/org/duckdb/TestAppenderComposite.java b/src/test/java/org/duckdb/TestAppenderComposite.java index 5fa0aedb6..5930b7b4c 100644 --- a/src/test/java/org/duckdb/TestAppenderComposite.java +++ b/src/test/java/org/duckdb/TestAppenderComposite.java @@ -670,4 +670,216 @@ public static void test_appender_list_basic_union() throws Exception { } } } + + private static void assertMapsEqual(Object obj1, Map map2) throws Exception { + Map map1 = (Map) obj1; + assertEquals(map1.size(), map2.size()); + List> list2 = new ArrayList<>(map2.entrySet()); + int i = 0; + for (Map.Entry en : map1.entrySet()) { + assertEquals(en.getKey(), list2.get(i).getKey()); + assertEquals(en.getValue(), list2.get(i).getValue()); + i += 1; + } + } + + public static void test_appender_map_basic() throws Exception { + Map map1 = createMap(41, "foo", 42, "bar"); + Map map2 = createMap(41, "foo", 42, null, 43, "baz"); + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + stmt.execute("CREATE TABLE tab1(col1 INTEGER, col2 MAP(INTEGER, VARCHAR))"); + + try (DuckDBAppender appender = conn.createAppender("tab1")) { + appender.beginRow() + .append(41) + .append(map1) + .endRow() + .beginRow() + .append(42) + .append(map2) + .endRow() + .flush(); + } + + try (ResultSet rs = stmt.executeQuery("SELECT col2 FROM tab1 ORDER BY col1")) { + assertTrue(rs.next()); + assertMapsEqual(rs.getObject(1), map1); + assertTrue(rs.next()); + assertMapsEqual(rs.getObject(1), map2); + assertFalse(rs.next()); + } + } + } + + public static void test_appender_map_string_long() throws Exception { + int count = 1 << 12; // auto flush twice + int tail = 7; // flushed on close + int mapSize = (1 << 6) + 7; // increase this for stress tests + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + stmt.execute("CREATE TABLE tab1(col1 INTEGER, col2 MAP(VARCHAR, BIGINT))"); + + try (DuckDBAppender appender = conn.createAppender("tab1")) { + for (int i = 0; i < count + tail; i++) { + Map map = new LinkedHashMap<>(); + for (long j = 0; j < Math.min(i, mapSize); j++) { + String key = "foo_" + i + "_" + j; + if (0 == (i + j) % 13) { + map.put(key, null); + } else { + map.put(key, i + j); + } + } + appender.beginRow().append(i).append(map).endRow(); + } + } + + try (ResultSet rs = stmt.executeQuery("SELECT count(*) FROM tab1")) { + assertTrue(rs.next()); + assertEquals(rs.getInt(1), count + tail); + assertFalse(rs.next()); + } + + try (ResultSet rs = stmt.executeQuery( + "SELECT count(*) FROM (SELECT unnest(map_keys(col2)) FROM tab1 WHERE col1 = " + (mapSize - 7) + + ")")) { + assertTrue(rs.next()); + assertEquals(rs.getInt(1), mapSize - 7); + assertFalse(rs.next()); + } + + try (ResultSet rs = stmt.executeQuery( + "SELECT col1, unnest(map_keys(col2)), unnest(map_values(col2)) FROM tab1 ORDER BY col1")) { + for (int i = 0; i < count + tail; i++) { + for (long j = 0; j < Math.min(i, mapSize); j++) { + assertTrue(rs.next()); + assertEquals(rs.getInt(1), i); + assertEquals(rs.getString(2), "foo_" + i + "_" + j); + if (0 == (i + j) % 13) { + assertNull(rs.getObject(3)); + assertTrue(rs.wasNull()); + } else { + assertEquals(rs.getLong(3), i + j); + } + } + } + } + } + } + + public static void test_appender_map_string_struct() throws Exception { + int count = 1 << 12; // auto flush twice + int tail = 7; // flushed on close + int mapSize = (1 << 5) + 7; // increase this for stress tests + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + stmt.execute("CREATE TABLE tab1(col1 INTEGER, col2 MAP(VARCHAR, STRUCT(s1 INTEGER, s2 BIGINT)))"); + + try (DuckDBAppender appender = conn.createAppender("tab1")) { + for (int i = 0; i < count + tail; i++) { + Map map = new LinkedHashMap<>(); + for (long j = 0; j < Math.min(i, mapSize); j++) { + String key = "foo_" + i + "_" + j; + if (0 == (i + j) % 13) { + map.put(key, null); + } else if (0 == (i + j) % 17) { + map.put(key, asList(null, j)); + } else { + map.put(key, asList(i, j)); + } + } + appender.beginRow().append(i).append(map).endRow(); + } + } + + try (ResultSet rs = stmt.executeQuery("SELECT count(*) FROM tab1")) { + assertTrue(rs.next()); + assertEquals(rs.getInt(1), count + tail); + assertFalse(rs.next()); + } + + try (ResultSet rs = stmt.executeQuery( + "SELECT count(*) FROM (SELECT unnest(map_keys(col2)) FROM tab1 WHERE col1 = " + (mapSize - 7) + + ")")) { + assertTrue(rs.next()); + assertEquals(rs.getInt(1), mapSize - 7); + assertFalse(rs.next()); + } + + try (ResultSet rs = stmt.executeQuery( + "SELECT col1, unnest(map_keys(col2)), unnest(map_values(col2)) FROM tab1 ORDER BY col1")) { + for (int i = 0; i < count + tail; i++) { + for (long j = 0; j < Math.min(i, mapSize); j++) { + assertTrue(rs.next()); + assertEquals(rs.getInt(1), i); + assertEquals(rs.getString(2), "foo_" + i + "_" + j); + if (0 == (i + j) % 13) { + assertNull(rs.getObject(3)); + assertTrue(rs.wasNull()); + } else { + DuckDBStruct struct = (DuckDBStruct) rs.getObject(3); + Map map = struct.getMap(); + if (0 == (i + j) % 17) { + assertNull(map.get("s1")); + } else { + assertEquals(map.get("s1"), i); + } + assertEquals(map.get("s2"), j); + } + } + } + } + } + } + + public static void test_appender_list_basic_map() throws Exception { + Map map1 = createMap(41, "foo1", 42, "bar1", 43, "baz1"); + Map map2 = createMap(44, null, 45, "bar2"); + Map map3 = new LinkedHashMap<>(); + Map map4 = createMap(46, "foo3"); + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + stmt.execute("CREATE TABLE tab1(col1 INT, col2 MAP(INTEGER, VARCHAR)[])"); + try (DuckDBAppender appender = conn.createAppender("tab1")) { + appender.beginRow() + .append(42) + .append(asList(map1, map2, map3)) + .endRow() + .beginRow() + .append(43) + .append((List) null) + .endRow() + .beginRow() + .append(44) + .append(asList(null, map4)) + .endRow() + .flush(); + } + + try (ResultSet rs = stmt.executeQuery("SELECT unnest(col2) from tab1 WHERE col1 = 42")) { + assertTrue(rs.next()); + assertMapsEqual(rs.getObject(1), map1); + assertTrue(rs.next()); + assertMapsEqual(rs.getObject(1), map2); + assertTrue(rs.next()); + assertMapsEqual(rs.getObject(1), map3); + assertFalse(rs.next()); + } + try (ResultSet rs = stmt.executeQuery("SELECT col2 from tab1 WHERE col1 = 43")) { + assertTrue(rs.next()); + assertNull(rs.getObject(1)); + assertTrue(rs.wasNull()); + assertFalse(rs.next()); + } + try (ResultSet rs = stmt.executeQuery("SELECT unnest(col2) from tab1 WHERE col1 = 44")) { + assertTrue(rs.next()); + assertNull(rs.getObject(1)); + assertTrue(rs.wasNull()); + assertTrue(rs.next()); + assertMapsEqual(rs.getObject(1), map4); + assertFalse(rs.next()); + } + } + } } diff --git a/src/test/java/org/duckdb/TestBindings.java b/src/test/java/org/duckdb/TestBindings.java index 51f1f1d34..cc6ef7338 100644 --- a/src/test/java/org/duckdb/TestBindings.java +++ b/src/test/java/org/duckdb/TestBindings.java @@ -112,11 +112,11 @@ public static void test_bindings_vector_validity() throws Exception { ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_VARCHAR.typeId); ByteBuffer vec = duckdb_create_vector(lt); - ByteBuffer emptyValidity = duckdb_vector_get_validity(vec, 1); + ByteBuffer emptyValidity = duckdb_vector_get_validity(vec, duckdb_vector_size()); assertNull(emptyValidity); duckdb_vector_ensure_validity_writable(vec); - ByteBuffer validity = duckdb_vector_get_validity(vec, 1); + ByteBuffer validity = duckdb_vector_get_validity(vec, duckdb_vector_size()); assertNotNull(validity); assertEquals(validity.capacity(), (int) duckdb_vector_size() / 8); @@ -198,7 +198,7 @@ public static void test_bindings_validity() throws Exception { ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_VARCHAR.typeId); ByteBuffer vec = duckdb_create_vector(lt); duckdb_vector_ensure_validity_writable(vec); - ByteBuffer validity = duckdb_vector_get_validity(vec, 1); + ByteBuffer validity = duckdb_vector_get_validity(vec, duckdb_vector_size()); long row = 7; assertTrue(duckdb_validity_row_is_valid(validity, row)); @@ -232,9 +232,9 @@ public static void test_bindings_data_chunk() throws Exception { checkVectorInsertString(vec); duckdb_vector_ensure_validity_writable(vec); - assertNotNull(duckdb_vector_get_validity(vec, 1)); + assertNotNull(duckdb_vector_get_validity(vec, duckdb_vector_size())); duckdb_data_chunk_reset(chunk); - assertNull(duckdb_vector_get_validity(vec, 1)); + assertNull(duckdb_vector_get_validity(vec, duckdb_vector_size())); duckdb_destroy_data_chunk(chunk); duckdb_destroy_logical_type(varcharType);