From 87d082477762fa9b720f5772c5331bd629ed90ba Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Mon, 5 Jun 2023 14:01:07 -0500 Subject: [PATCH 01/31] GH-34749 : [Java] Make Zstd compression level configurable (#34873) ### Rationale for this change Closes: https://github.com/apache/arrow/issues/34749 ### What changes are included in this PR? Make compression level configurable for Zstd ### Are these changes tested? Yes ### Are there any user-facing changes? No * Closes: #34749 Lead-authored-by: david dali susanibar arce Co-authored-by: Gang Wu Signed-off-by: David Li --- .../CommonsCompressionFactory.java | 12 +++++++++++ .../compression/ZstdCompressionCodec.java | 13 +++++++++++- .../TestArrowReaderWriterWithCompression.java | 3 ++- .../compression/TestCompressionCodec.java | 8 +++++-- .../vector/compression/CompressionCodec.java | 5 +++++ .../compression/NoCompressionCodec.java | 5 +++++ .../arrow/vector/ipc/ArrowFileWriter.java | 9 +++++++- .../arrow/vector/ipc/ArrowStreamWriter.java | 21 ++++++++++++++++++- .../apache/arrow/vector/ipc/ArrowWriter.java | 14 ++++++++++--- 9 files changed, 81 insertions(+), 9 deletions(-) diff --git a/java/compression/src/main/java/org/apache/arrow/compression/CommonsCompressionFactory.java b/java/compression/src/main/java/org/apache/arrow/compression/CommonsCompressionFactory.java index 867e9f418b2b3..45d8c7d443252 100644 --- a/java/compression/src/main/java/org/apache/arrow/compression/CommonsCompressionFactory.java +++ b/java/compression/src/main/java/org/apache/arrow/compression/CommonsCompressionFactory.java @@ -40,4 +40,16 @@ public CompressionCodec createCodec(CompressionUtil.CodecType codecType) { throw new IllegalArgumentException("Compression type not supported: " + codecType); } } + + @Override + public CompressionCodec createCodec(CompressionUtil.CodecType codecType, int compressionLevel) { + switch (codecType) { + case LZ4_FRAME: + return new Lz4CompressionCodec(); + case ZSTD: + return new ZstdCompressionCodec(compressionLevel); + default: + throw new IllegalArgumentException("Compression type not supported: " + codecType); + } + } } diff --git a/java/compression/src/main/java/org/apache/arrow/compression/ZstdCompressionCodec.java b/java/compression/src/main/java/org/apache/arrow/compression/ZstdCompressionCodec.java index 38717843ef86d..c26519055b1fe 100644 --- a/java/compression/src/main/java/org/apache/arrow/compression/ZstdCompressionCodec.java +++ b/java/compression/src/main/java/org/apache/arrow/compression/ZstdCompressionCodec.java @@ -30,6 +30,17 @@ */ public class ZstdCompressionCodec extends AbstractCompressionCodec { + private int compressionLevel; + private static final int DEFAULT_COMPRESSION_LEVEL = 3; + + public ZstdCompressionCodec() { + this.compressionLevel = DEFAULT_COMPRESSION_LEVEL; + } + + public ZstdCompressionCodec(int compressionLevel) { + this.compressionLevel = compressionLevel; + } + @Override protected ArrowBuf doCompress(BufferAllocator allocator, ArrowBuf uncompressedBuffer) { long maxSize = Zstd.compressBound(uncompressedBuffer.writerIndex()); @@ -38,7 +49,7 @@ protected ArrowBuf doCompress(BufferAllocator allocator, ArrowBuf uncompressedBu long bytesWritten = Zstd.compressUnsafe( compressedBuffer.memoryAddress() + CompressionUtil.SIZE_OF_UNCOMPRESSED_LENGTH, dstSize, /*src*/uncompressedBuffer.memoryAddress(), /*srcSize=*/uncompressedBuffer.writerIndex(), - /*level=*/3); + /*level=*/this.compressionLevel); if (Zstd.isError(bytesWritten)) { compressedBuffer.close(); throw new RuntimeException("Error 
compressing: " + Zstd.getErrorName(bytesWritten)); diff --git a/java/compression/src/test/java/org/apache/arrow/compression/TestArrowReaderWriterWithCompression.java b/java/compression/src/test/java/org/apache/arrow/compression/TestArrowReaderWriterWithCompression.java index d8e51d82020e3..6104cb1a132e4 100644 --- a/java/compression/src/test/java/org/apache/arrow/compression/TestArrowReaderWriterWithCompression.java +++ b/java/compression/src/test/java/org/apache/arrow/compression/TestArrowReaderWriterWithCompression.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; +import java.util.Optional; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; @@ -57,7 +58,7 @@ public void testArrowFileZstdRoundTrip() throws Exception { ByteArrayOutputStream out = new ByteArrayOutputStream(); try (final ArrowFileWriter writer = new ArrowFileWriter(root, null, Channels.newChannel(out), new HashMap<>(), - IpcOption.DEFAULT, CommonsCompressionFactory.INSTANCE, CompressionUtil.CodecType.ZSTD)) { + IpcOption.DEFAULT, CommonsCompressionFactory.INSTANCE, CompressionUtil.CodecType.ZSTD, Optional.of(7))) { writer.start(); writer.writeBatch(); writer.end(); diff --git a/java/compression/src/test/java/org/apache/arrow/compression/TestCompressionCodec.java b/java/compression/src/test/java/org/apache/arrow/compression/TestCompressionCodec.java index a1d5000daac82..7db00cfde485d 100644 --- a/java/compression/src/test/java/org/apache/arrow/compression/TestCompressionCodec.java +++ b/java/compression/src/test/java/org/apache/arrow/compression/TestCompressionCodec.java @@ -32,6 +32,7 @@ import java.util.Collection; import java.util.HashMap; import java.util.List; +import java.util.Optional; import java.util.function.BiConsumer; import java.util.stream.Stream; @@ -93,6 +94,9 @@ static Collection codecs() { CompressionCodec zstdCodec = new ZstdCompressionCodec(); params.add(Arguments.arguments(len, zstdCodec)); + + CompressionCodec zstdCodecAndCompressionLevel = new ZstdCompressionCodec(7); + params.add(Arguments.arguments(len, zstdCodecAndCompressionLevel)); } return params; } @@ -235,7 +239,7 @@ void testReadWriteStream(CompressionUtil.CodecType codec) throws Exception { try (final ArrowStreamWriter writer = new ArrowStreamWriter( root, new DictionaryProvider.MapDictionaryProvider(), Channels.newChannel(compressedStream), - IpcOption.DEFAULT, factory, codec)) { + IpcOption.DEFAULT, factory, codec, Optional.of(7))) { writer.start(); writer.writeBatch(); writer.end(); @@ -262,7 +266,7 @@ void testReadWriteFile(CompressionUtil.CodecType codec) throws Exception { try (final ArrowFileWriter writer = new ArrowFileWriter( root, new DictionaryProvider.MapDictionaryProvider(), Channels.newChannel(compressedStream), - new HashMap<>(), IpcOption.DEFAULT, factory, codec)) { + new HashMap<>(), IpcOption.DEFAULT, factory, codec, Optional.of(7))) { writer.start(); writer.writeBatch(); writer.end(); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/compression/CompressionCodec.java b/java/vector/src/main/java/org/apache/arrow/vector/compression/CompressionCodec.java index a6dd8b51fe554..8d235f713a8fa 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/compression/CompressionCodec.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/compression/CompressionCodec.java @@ -58,5 +58,10 @@ interface Factory { * Creates the codec based on the codec type. 
*/ CompressionCodec createCodec(CompressionUtil.CodecType codecType); + + /** + * Creates the codec based on the codec type and compression level. + */ + CompressionCodec createCodec(CompressionUtil.CodecType codecType, int compressionLevel); } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/compression/NoCompressionCodec.java b/java/vector/src/main/java/org/apache/arrow/vector/compression/NoCompressionCodec.java index 9b3feb7321687..f0e969508e0e0 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/compression/NoCompressionCodec.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/compression/NoCompressionCodec.java @@ -72,5 +72,10 @@ public CompressionCodec createCodec(CompressionUtil.CodecType codecType) { throw new IllegalArgumentException("Unsupported codec type: " + codecType); } } + + @Override + public CompressionCodec createCodec(CompressionUtil.CodecType codecType, int compressionLevel) { + return createCodec(codecType); + } } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java index 4b41d0ab61c16..0b0931f7bb57c 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Optional; import org.apache.arrow.util.VisibleForTesting; import org.apache.arrow.vector.VectorSchemaRoot; @@ -74,7 +75,13 @@ public ArrowFileWriter(VectorSchemaRoot root, DictionaryProvider provider, Writa public ArrowFileWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out, Map metaData, IpcOption option, CompressionCodec.Factory compressionFactory, CompressionUtil.CodecType codecType) { - super(root, provider, out, option, compressionFactory, codecType); + this(root, provider, out, metaData, option, compressionFactory, codecType, Optional.empty()); + } + + public ArrowFileWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out, + Map metaData, IpcOption option, CompressionCodec.Factory compressionFactory, + CompressionUtil.CodecType codecType, Optional compressionLevel) { + super(root, provider, out, option, compressionFactory, codecType, compressionLevel); this.metaData = metaData; } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java index 60230d5a9b910..7200851620f74 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java @@ -21,6 +21,7 @@ import java.io.OutputStream; import java.nio.channels.Channels; import java.nio.channels.WritableByteChannel; +import java.util.Optional; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.compression.CompressionCodec; @@ -81,7 +82,25 @@ public ArrowStreamWriter(VectorSchemaRoot root, DictionaryProvider provider, Wri public ArrowStreamWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out, IpcOption option, CompressionCodec.Factory compressionFactory, CompressionUtil.CodecType codecType) { - super(root, provider, out, option, compressionFactory, codecType); + this(root, provider, out, option, compressionFactory, codecType, Optional.empty()); + } + + /** + * Construct an 
ArrowStreamWriter with compression enabled. + * + * @param root Existing VectorSchemaRoot with vectors to be written. + * @param provider DictionaryProvider for any vectors that are dictionary encoded. + * (Optional, can be null) + * @param option IPC write options + * @param compressionFactory Compression codec factory + * @param codecType Codec type + * @param compressionLevel Compression level + * @param out WritableByteChannel for writing. + */ + public ArrowStreamWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out, + IpcOption option, CompressionCodec.Factory compressionFactory, + CompressionUtil.CodecType codecType, Optional compressionLevel) { + super(root, provider, out, option, compressionFactory, codecType, compressionLevel); } /** diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java index 13e313ab4d97c..2c524b81b7008 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java @@ -23,6 +23,7 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Optional; import java.util.Set; import org.apache.arrow.util.AutoCloseables; @@ -72,7 +73,8 @@ protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, Writab } protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out, IpcOption option) { - this(root, provider, out, option, NoCompressionCodec.Factory.INSTANCE, CompressionUtil.CodecType.NO_COMPRESSION); + this(root, provider, out, option, NoCompressionCodec.Factory.INSTANCE, CompressionUtil.CodecType.NO_COMPRESSION, + Optional.empty()); } /** @@ -84,11 +86,17 @@ protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, Writab * @param option IPC write options * @param compressionFactory Compression codec factory * @param codecType Compression codec + * @param compressionLevel Compression level */ protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out, IpcOption option, - CompressionCodec.Factory compressionFactory, CompressionUtil.CodecType codecType) { + CompressionCodec.Factory compressionFactory, CompressionUtil.CodecType codecType, + Optional compressionLevel) { this.unloader = new VectorUnloader( - root, /*includeNullCount*/ true, compressionFactory.createCodec(codecType), /*alignBuffers*/ true); + root, /*includeNullCount*/ true, + compressionLevel.isPresent() ? + compressionFactory.createCodec(codecType, compressionLevel.get()) : + compressionFactory.createCodec(codecType), + /*alignBuffers*/ true); this.out = new WriteChannel(out); this.option = option; From e2ae492b63245a8bc6330b10785777ab04659798 Mon Sep 17 00:00:00 2001 From: Bryce Mecum Date: Mon, 5 Jun 2023 19:58:15 -0800 Subject: [PATCH 02/31] MINOR: [Documentation][Python] Minor tweaks to docs around IPC messages (#35880) ### Rationale for this change @ jorisvandenbossche pointed out that the docstring for RecordBatch.serialize might not be clear enough and that users may be surprised to find out a Schema isn't included. This PR makes that clearer and improves up related documentation. ### Are these changes tested? Yes. I built the docs and verified the rendering was correct and I ran the changed example locally. I didn't --doctest-cython it because it looks like that hasn't been set up for this module. 
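For illustration, the behavior the updated docstring describes amounts to the following minimal sketch (assuming a recent pyarrow where `RecordBatch.serialize` and `pa.ipc.read_record_batch` behave as documented here):

```python
import pyarrow as pa

# serialize() returns an encapsulated IPC message Buffer that does NOT carry the
# Schema, so the schema must be supplied separately when reconstructing the batch.
batch = pa.RecordBatch.from_arrays(
    [pa.array([2, 2, 4]), pa.array(["Flamingo", "Parrot", "Dog"])],
    names=["n_legs", "animals"],
)
buf = batch.serialize()
restored = pa.ipc.read_record_batch(buf, batch.schema)  # schema passed explicitly
assert restored.equals(batch)
```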
### Are there any user-facing changes? No, these are just docs improvements. Lead-authored-by: Bryce Mecum Co-authored-by: Alenka Frim Signed-off-by: AlenkaF --- docs/source/format/Glossary.rst | 7 +++++-- docs/source/format/Other.rst | 10 ++++++---- python/pyarrow/table.pxi | 20 ++++++++++++++++++-- 3 files changed, 29 insertions(+), 8 deletions(-) diff --git a/docs/source/format/Glossary.rst b/docs/source/format/Glossary.rst index ac18c1618bceb..65d6e0afa4557 100644 --- a/docs/source/format/Glossary.rst +++ b/docs/source/format/Glossary.rst @@ -151,8 +151,11 @@ Glossary IPC message message - The IPC representation of a particular in-memory structure, - like a record batch or schema. + The IPC representation of a particular in-memory structure, like a :term:`record + batch` or :term:`schema`. Will always be one of the members of ``MessageHeader`` + in the `Flatbuffers protocol file + `_. + IPC streaming format streaming format diff --git a/docs/source/format/Other.rst b/docs/source/format/Other.rst index 9504998d62234..cb5234e0c2bfe 100644 --- a/docs/source/format/Other.rst +++ b/docs/source/format/Other.rst @@ -18,10 +18,10 @@ Other Data Structures ===================== -Our Flatbuffers protocol files have metadata for some other data -structures defined to allow other kinds of applications to take -advantage of common interprocess communication machinery. These data -structures are not considered to be part of the columnar format. +Our `Flatbuffers protocol definition files`_ have metadata for some other data +structures defined to allow other kinds of applications to take advantage of +common interprocess communication machinery. These data structures are not +considered to be part of the columnar format. An Arrow columnar implementation is not required to implement these types. @@ -61,3 +61,5 @@ region) to be multiples of 64 bytes: :: The contents of the sparse tensor index depends on what kind of sparse format is used. + +.. _Flatbuffers protocol definition files: https://github.com/apache/arrow/tree/main/format diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index 5f1ee00201589..2c3092d0932f3 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -2274,7 +2274,12 @@ cdef class RecordBatch(_Tabular): def serialize(self, memory_pool=None): """ - Write RecordBatch to Buffer as encapsulated IPC message. + Write RecordBatch to Buffer as encapsulated IPC message, which does not + include a Schema. + + To reconstruct a RecordBatch from the encapsulated IPC message Buffer + returned by this function, a Schema must be passed separately. See + Examples. Parameters ---------- @@ -2292,8 +2297,19 @@ cdef class RecordBatch(_Tabular): >>> animals = pa.array(["Flamingo", "Parrot", "Dog", "Horse", "Brittle stars", "Centipede"]) >>> batch = pa.RecordBatch.from_arrays([n_legs, animals], ... 
names=["n_legs", "animals"]) - >>> batch.serialize() + >>> buf = batch.serialize() + >>> buf + + Reconstruct RecordBatch from IPC message Buffer and original Schema + + >>> pa.ipc.read_record_batch(buf, batch.schema) + pyarrow.RecordBatch + n_legs: int64 + animals: string + ---- + n_legs: [2,2,4,4,5,100] + animals: ["Flamingo","Parrot","Dog","Horse","Brittle stars","Centipede"] """ cdef shared_ptr[CBuffer] buffer cdef CIpcWriteOptions options = CIpcWriteOptions.Defaults() From c2536253885eff81f258ccaa3a6834dce0bf192c Mon Sep 17 00:00:00 2001 From: rtpsw Date: Tue, 6 Jun 2023 11:16:28 +0300 Subject: [PATCH 03/31] GH-35868: [C++] Occasional TSAN failure on asof-join-node-test (#35904) ### Rationale for this change `AsofJoinNode` may run into a data race when invalidating the key hasher. The key hasher queried from one thread but invalidated from another. This might be simplified so that the key hasher would only be used from one thread, but this is out of scope for this PR. ### What changes are included in this PR? The invalidated member of the key hasher is made atomic. ### Are these changes tested? Yes, by existing testing. ### Are there any user-facing changes? No. * Closes: #35868 Authored-by: Yaron Gvili Signed-off-by: Antoine Pitrou --- cpp/src/arrow/acero/asof_join_node.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/acero/asof_join_node.cc b/cpp/src/arrow/acero/asof_join_node.cc index b92339b951bcf..f8dee5aac8815 100644 --- a/cpp/src/arrow/acero/asof_join_node.cc +++ b/cpp/src/arrow/acero/asof_join_node.cc @@ -529,7 +529,7 @@ class KeyHasher { size_t index_; std::vector indices_; std::vector metadata_; - const RecordBatch* batch_; + std::atomic batch_; std::vector hashes_; LightContext ctx_; std::vector column_arrays_; From 68cbc6fe79203be597d4b274a62de6966bf9181c Mon Sep 17 00:00:00 2001 From: abandy Date: Tue, 6 Jun 2023 06:27:33 -0400 Subject: [PATCH 04/31] GH-35803: [Doc] Add columns to the Implementation Status tables for Swift (#35862) Added swift columns to implementation status tables * Closes: #35803 Authored-by: Alva Bandy Signed-off-by: Joris Van den Bossche --- docs/source/status.rst | 444 ++++++++++++++++++++--------------------- 1 file changed, 222 insertions(+), 222 deletions(-) diff --git a/docs/source/status.rst b/docs/source/status.rst index 21afc32efe16b..a73de815b7798 100644 --- a/docs/source/status.rst +++ b/docs/source/status.rst @@ -28,76 +28,76 @@ Arrow library. 
Data Types ========== -+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| Data type | C++ | Java | Go | JavaScript | C# | Rust | Julia | -| (primitive) | | | | | | | | -+===================+=======+=======+=======+============+=======+=======+=======+ -| Null | ✓ | ✓ | ✓ | | ✓ | ✓ | ✓ | -+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| Boolean | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | -+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| Int8/16/32/64 | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | -+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| UInt8/16/32/64 | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | -+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| Float16 | | | ✓ | | ✓ (1)| ✓ | ✓ | -+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| Float32/64 | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | -+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| Decimal128 | ✓ | ✓ | ✓ | | ✓ | ✓ | ✓ | -+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| Decimal256 | ✓ | ✓ | ✓ | | ✓ | ✓ | ✓ | -+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| Date32/64 | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | -+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| Time32/64 | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | -+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| Timestamp | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | -+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| Duration | ✓ | ✓ | ✓ | | | ✓ | ✓ | -+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| Interval | ✓ | ✓ | ✓ | | | ✓ | ✓ | -+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| Fixed Size Binary | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | -+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| Binary | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | -+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| Large Binary | ✓ | ✓ | ✓ | ✓ | | ✓ | ✓ | -+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| Utf8 | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | -+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| Large Utf8 | ✓ | ✓ | ✓ | ✓ | | ✓ | ✓ | -+-------------------+-------+-------+-------+------------+-------+-------+-------+ - -+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| Data type | C++ | Java | Go | JavaScript | C# | Rust | Julia | -| (nested) | | | | | | | | -+===================+=======+=======+=======+============+=======+=======+=======+ -| Fixed Size List | ✓ | ✓ | ✓ | ✓ | | ✓ | ✓ | -+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| List | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | -+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| Large List | ✓ | ✓ | ✓ | | | ✓ | ✓ | -+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| Struct | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | -+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| Map | ✓ | ✓ | ✓ | ✓ | | ✓ | ✓ | -+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| Dense Union | ✓ | ✓ | ✓ | | | ✓ | ✓ | 
-+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| Sparse Union | ✓ | ✓ | ✓ | | | ✓ | ✓ | -+-------------------+-------+-------+-------+------------+-------+-------+-------+ - -+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| Data type | C++ | Java | Go | JavaScript | C# | Rust | Julia | -| (special) | | | | | | | | -+===================+=======+=======+=======+============+=======+=======+=======+ -| Dictionary | ✓ | ✓ (2) | ✓ | ✓ (2) | ✓ (2) | ✓ (2) | ✓ | -+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| Extension | ✓ | ✓ | ✓ | | | ✓ | ✓ | -+-------------------+-------+-------+-------+------------+-------+-------+-------+ -| Run-End Encoded | | | ✓ | | | | | -+-------------------+-------+-------+-------+------------+-------+-------+-------+ ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Data type | C++ | Java | Go | JavaScript | C# | Rust | Julia | Swift | +| (primitive) | | | | | | | | | ++===================+=======+=======+=======+============+=======+=======+=======+=======+ +| Null | ✓ | ✓ | ✓ | | ✓ | ✓ | ✓ | | ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Boolean | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Int8/16/32/64 | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| UInt8/16/32/64 | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Float16 | | | ✓ | | ✓ (1)| ✓ | ✓ | | ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Float32/64 | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Decimal128 | ✓ | ✓ | ✓ | | ✓ | ✓ | ✓ | | ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Decimal256 | ✓ | ✓ | ✓ | | ✓ | ✓ | ✓ | | ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Date32/64 | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Time32/64 | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Timestamp | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Duration | ✓ | ✓ | ✓ | | | ✓ | ✓ | | ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Interval | ✓ | ✓ | ✓ | | | ✓ | ✓ | | ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Fixed Size Binary | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Binary | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Large Binary | ✓ | ✓ | ✓ | ✓ | | ✓ | ✓ | | ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Utf8 | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Large Utf8 | ✓ | ✓ | ✓ | ✓ | | ✓ | ✓ | | 
++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ + ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Data type | C++ | Java | Go | JavaScript | C# | Rust | Julia | Swift | +| (nested) | | | | | | | | | ++===================+=======+=======+=======+============+=======+=======+=======+=======+ +| Fixed Size List | ✓ | ✓ | ✓ | ✓ | | ✓ | ✓ | | ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| List | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Large List | ✓ | ✓ | ✓ | | | ✓ | ✓ | | ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Struct | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Map | ✓ | ✓ | ✓ | ✓ | | ✓ | ✓ | | ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Dense Union | ✓ | ✓ | ✓ | | | ✓ | ✓ | | ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Sparse Union | ✓ | ✓ | ✓ | | | ✓ | ✓ | | ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ + ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Data type | C++ | Java | Go | JavaScript | C# | Rust | Julia | Swift | +| (special) | | | | | | | | | ++===================+=======+=======+=======+============+=======+=======+=======+=======+ +| Dictionary | ✓ | ✓ (2) | ✓ | ✓ (2) | ✓ (2) | ✓ (2) | ✓ | | ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Extension | ✓ | ✓ | ✓ | | | ✓ | ✓ | | ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Run-End Encoded | | | ✓ | | | | | | ++-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ Notes: @@ -111,32 +111,32 @@ Notes: IPC Format ========== -+-----------------------------+-------+-------+-------+------------+-------+-------+-------+ -| IPC Feature | C++ | Java | Go | JavaScript | C# | Rust | Julia | -| | | | | | | | | -+=============================+=======+=======+=======+============+=======+=======+=======+ -| Arrow stream format | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | -+-----------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Arrow file format | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | -+-----------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Record batches | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | -+-----------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Dictionaries | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | -+-----------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Replacement dictionaries | ✓ | ✓ | ✓ | | | | ✓ | -+-----------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Delta dictionaries | ✓ (1) | | ✓ (1) | | ✓ | | ✓ | -+-----------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Tensors | ✓ | | | | | | | -+-----------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Sparse tensors | ✓ | | | | | | | -+-----------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Buffer compression | ✓ | ✓ (3) | ✓ | | ✓ (4) | ✓ | ✓ | 
-+-----------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Endianness conversion | ✓ (2) | | ✓ (2) | | | | | -+-----------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Custom schema metadata | ✓ | ✓ | ✓ | | ✓ | ✓ | ✓ | -+-----------------------------+-------+-------+-------+------------+-------+-------+-------+ ++-----------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| IPC Feature | C++ | Java | Go | JavaScript | C# | Rust | Julia | Swift | +| | | | | | | | | | ++=============================+=======+=======+=======+============+=======+=======+=======+=======+ +| Arrow stream format | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ++-----------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Arrow file format | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ++-----------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Record batches | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ++-----------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Dictionaries | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | ++-----------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Replacement dictionaries | ✓ | ✓ | ✓ | | | | ✓ | | ++-----------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Delta dictionaries | ✓ (1) | | ✓ (1) | | ✓ | | ✓ | | ++-----------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Tensors | ✓ | | | | | | | | ++-----------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Sparse tensors | ✓ | | | | | | | | ++-----------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Buffer compression | ✓ | ✓ (3) | ✓ | | ✓ (4) | ✓ | ✓ | | ++-----------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Endianness conversion | ✓ (2) | | ✓ (2) | | | | | | ++-----------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Custom schema metadata | ✓ | ✓ | ✓ | | ✓ | ✓ | ✓ | | ++-----------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ Notes: @@ -156,57 +156,57 @@ Notes: Flight RPC ========== -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Flight RPC Transport | C++ | Java | Go | JavaScript | C# | Rust | Julia | -+============================================+=======+=======+=======+============+=======+=======+=======+ -| gRPC_ transport (grpc:, grpc+tcp:) | ✓ | ✓ | ✓ | | ✓ | ✓ | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| gRPC domain socket transport (grpc+unix:) | ✓ | ✓ | ✓ | | ✓ | ✓ | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| gRPC + TLS transport (grpc+tls:) | ✓ | ✓ | ✓ | | ✓ | ✓ | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| UCX_ transport (ucx:) | ✓ | | | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Flight RPC Transport | C++ | 
Java | Go | JavaScript | C# | Rust | Julia | Swift | ++============================================+=======+=======+=======+============+=======+=======+=======+=======+ +| gRPC_ transport (grpc:, grpc+tcp:) | ✓ | ✓ | ✓ | | ✓ | ✓ | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| gRPC domain socket transport (grpc+unix:) | ✓ | ✓ | ✓ | | ✓ | ✓ | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| gRPC + TLS transport (grpc+tls:) | ✓ | ✓ | ✓ | | ✓ | ✓ | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| UCX_ transport (ucx:) | ✓ | | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ Supported features in the gRPC transport: -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Flight RPC Feature | C++ | Java | Go | JavaScript | C# | Rust | Julia | -+============================================+=======+=======+=======+============+=======+=======+=======+ -| All RPC methods | ✓ | ✓ | ✓ | | × (1) | ✓ | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Authentication handlers | ✓ | ✓ | ✓ | | ✓ (2) | ✓ | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Call timeouts | ✓ | ✓ | ✓ | | | ✓ | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Call cancellation | ✓ | ✓ | ✓ | | | ✓ | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Concurrent client calls (3) | ✓ | ✓ | ✓ | | ✓ | ✓ | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Custom middleware | ✓ | ✓ | ✓ | | | ✓ | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| RPC error codes | ✓ | ✓ | ✓ | | ✓ | ✓ | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Flight RPC Feature | C++ | Java | Go | JavaScript | C# | Rust | Julia | Swift | ++============================================+=======+=======+=======+============+=======+=======+=======+=======+ +| All RPC methods | ✓ | ✓ | ✓ | | × (1) | ✓ | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Authentication handlers | ✓ | ✓ | ✓ | | ✓ (2) | ✓ | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Call timeouts | ✓ | ✓ | ✓ | | | ✓ | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Call cancellation | ✓ | ✓ | ✓ | | | ✓ | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Concurrent client calls (3) | ✓ | ✓ | ✓ | | ✓ | ✓ | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Custom middleware | ✓ | ✓ | ✓ | | | ✓ | | | 
++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| RPC error codes | ✓ | ✓ | ✓ | | ✓ | ✓ | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ Supported features in the UCX transport: -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Flight RPC Feature | C++ | Java | Go | JavaScript | C# | Rust | Julia | -+============================================+=======+=======+=======+============+=======+=======+=======+ -| All RPC methods | × (4) | | | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Authentication handlers | | | | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Call timeouts | | | | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Call cancellation | | | | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Concurrent client calls | ✓ (5) | | | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Custom middleware | | | | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| RPC error codes | ✓ | | | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Flight RPC Feature | C++ | Java | Go | JavaScript | C# | Rust | Julia | Swift | ++============================================+=======+=======+=======+============+=======+=======+=======+=======+ +| All RPC methods | × (4) | | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Authentication handlers | | | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Call timeouts | | | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Call cancellation | | | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Concurrent client calls | ✓ (5) | | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Custom middleware | | | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| RPC error codes | ✓ | | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ Notes: @@ -234,55 +234,55 @@ The feature support refers to the client/server libraries only; databases which implement the Flight SQL protocol in turn will support/not support individual features. 
-+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| Feature | C++ | Java | Go | JavaScript | C# | Rust | Julia | -+============================================+=======+=======+=======+============+=======+=======+=======+ -| BeginSavepoint | ✓ | ✓ | | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| BeginTransaction | ✓ | ✓ | | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| CancelQuery | ✓ | ✓ | | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| ClosePreparedStatement | ✓ | ✓ | ✓ | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| CreatePreparedStatement | ✓ | ✓ | ✓ | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| CreatePreparedSubstraitPlan | ✓ | ✓ | | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| EndSavepoint | ✓ | ✓ | | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| EndTransaction | ✓ | ✓ | | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| GetCatalogs | ✓ | ✓ | ✓ | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| GetCrossReference | ✓ | ✓ | ✓ | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| GetDbSchemas | ✓ | ✓ | ✓ | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| GetExportedKeys | ✓ | ✓ | ✓ | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| GetImportedKeys | ✓ | ✓ | ✓ | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| GetPrimaryKeys | ✓ | ✓ | ✓ | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| GetSqlInfo | ✓ | ✓ | ✓ | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| GetTables | ✓ | ✓ | ✓ | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| GetTableTypes | ✓ | ✓ | ✓ | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| GetXdbcTypeInfo | ✓ | ✓ | ✓ | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| PreparedStatementQuery | ✓ | ✓ | ✓ | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| PreparedStatementUpdate | ✓ | ✓ | ✓ | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| StatementSubstraitPlan | ✓ | ✓ | | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| StatementQuery | ✓ | ✓ | ✓ | | | | | 
-+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ -| StatementUpdate | ✓ | ✓ | ✓ | | | | | -+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| Feature | C++ | Java | Go | JavaScript | C# | Rust | Julia | Swift | ++============================================+=======+=======+=======+============+=======+=======+=======+=======+ +| BeginSavepoint | ✓ | ✓ | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| BeginTransaction | ✓ | ✓ | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| CancelQuery | ✓ | ✓ | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| ClosePreparedStatement | ✓ | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| CreatePreparedStatement | ✓ | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| CreatePreparedSubstraitPlan | ✓ | ✓ | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| EndSavepoint | ✓ | ✓ | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| EndTransaction | ✓ | ✓ | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| GetCatalogs | ✓ | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| GetCrossReference | ✓ | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| GetDbSchemas | ✓ | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| GetExportedKeys | ✓ | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| GetImportedKeys | ✓ | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| GetPrimaryKeys | ✓ | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| GetSqlInfo | ✓ | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| GetTables | ✓ | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| GetTableTypes | ✓ | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| GetXdbcTypeInfo | ✓ | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| PreparedStatementQuery | ✓ | ✓ | ✓ | | | | | | 
++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| PreparedStatementUpdate | ✓ | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| StatementSubstraitPlan | ✓ | ✓ | | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| StatementQuery | ✓ | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| StatementUpdate | ✓ | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+-------+ .. seealso:: The :doc:`./format/FlightSql` specification. @@ -290,18 +290,18 @@ support/not support individual features. C Data Interface ================ -+-----------------------------+-----+--------+---+------+----+------+--------+------+-------+-----+ -| Feature | C++ | Python | R | Rust | Go | Java | C/GLib | Ruby | Julia | C# | -| | | | | | | | | | | | -+=============================+=====+========+===+======+====+======+========+======+=======+=====+ -| Schema export | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | ✓ | -+-----------------------------+-----+--------+---+------+----+------+--------+------+-------+-----+ -| Array export | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | ✓ | -+-----------------------------+-----+--------+---+------+----+------+--------+------+-------+-----+ -| Schema import | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | ✓ | -+-----------------------------+-----+--------+---+------+----+------+--------+------+-------+-----+ -| Array import | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | ✓ | -+-----------------------------+-----+--------+---+------+----+------+--------+------+-------+-----+ ++-----------------------------+-----+--------+---+------+----+------+--------+------+-------+-----+-------+ +| Feature | C++ | Python | R | Rust | Go | Java | C/GLib | Ruby | Julia | C# | Swift | +| | | | | | | | | | | | | ++=============================+=====+========+===+======+====+======+========+======+=======+=====+=======+ +| Schema export | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | ✓ | | ++-----------------------------+-----+--------+---+------+----+------+--------+------+-------+-----+-------+ +| Array export | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | ✓ | | ++-----------------------------+-----+--------+---+------+----+------+--------+------+-------+-----+-------+ +| Schema import | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | ✓ | | ++-----------------------------+-----+--------+---+------+----+------+--------+------+-------+-----+-------+ +| Array import | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | ✓ | | ++-----------------------------+-----+--------+---+------+----+------+--------+------+-------+-----+-------+ .. seealso:: The :ref:`C Data Interface ` specification. 
@@ -310,14 +310,14 @@ C Data Interface C Stream Interface ================== -+-----------------------------+-----+--------+---+------+----+------+--------+------+-------+-----+ -| Feature | C++ | Python | R | Rust | Go | Java | C/GLib | Ruby | Julia | C# | -| | | | | | | | | | | | -+=============================+=====+========+===+======+====+======+========+======+=======+=====+ -| Stream export | ✓ | ✓ | ✓ | ✓ | ✓ | | ✓ | ✓ | | ✓ | -+-----------------------------+-----+--------+---+------+----+------+--------+------+-------+-----+ -| Stream import | ✓ | ✓ | ✓ | ✓ | ✓ | | ✓ | ✓ | | ✓ | -+-----------------------------+-----+--------+---+------+----+------+--------+------+-------+-----+ ++-----------------------------+-----+--------+---+------+----+------+--------+------+-------+-----+-------+ +| Feature | C++ | Python | R | Rust | Go | Java | C/GLib | Ruby | Julia | C# | Swift | +| | | | | | | | | | | | | ++=============================+=====+========+===+======+====+======+========+======+=======+=====+=======+ +| Stream export | ✓ | ✓ | ✓ | ✓ | ✓ | | ✓ | ✓ | | ✓ | | ++-----------------------------+-----+--------+---+------+----+------+--------+------+-------+-----+-------+ +| Stream import | ✓ | ✓ | ✓ | ✓ | ✓ | | ✓ | ✓ | | ✓ | | ++-----------------------------+-----+--------+---+------+----+------+--------+------+-------+-----+-------+ .. seealso:: The :ref:`C Stream Interface ` specification. @@ -326,18 +326,18 @@ C Stream Interface Third-Party Data Formats ======================== -+-----------------------------+---------+---------+-------+------------+-------+-------+-------+ -| Format | C++ | Java | Go | JavaScript | C# | Rust | Julia | -| | | | | | | | | -+=============================+=========+=========+=======+============+=======+=======+=======+ -| Avro | | R | | | | | | -+-----------------------------+---------+---------+-------+------------+-------+-------+-------+ -| CSV | R/W | R (2) | R/W | | | R/W | R/W | -+-----------------------------+---------+---------+-------+------------+-------+-------+-------+ -| ORC | R/W | R (1) | | | | | | -+-----------------------------+---------+---------+-------+------------+-------+-------+-------+ -| Parquet | R/W | R (2) | R/W | | | R/W | | -+-----------------------------+---------+---------+-------+------------+-------+-------+-------+ ++-----------------------------+---------+---------+-------+------------+-------+-------+-------+-------+ +| Format | C++ | Java | Go | JavaScript | C# | Rust | Julia | Swift | +| | | | | | | | | | ++=============================+=========+=========+=======+============+=======+=======+=======+=======+ +| Avro | | R | | | | | | | ++-----------------------------+---------+---------+-------+------------+-------+-------+-------+-------+ +| CSV | R/W | R (2) | R/W | | | R/W | R/W | | ++-----------------------------+---------+---------+-------+------------+-------+-------+-------+-------+ +| ORC | R/W | R (1) | | | | | | | ++-----------------------------+---------+---------+-------+------------+-------+-------+-------+-------+ +| Parquet | R/W | R (2) | R/W | | | R/W | | | ++-----------------------------+---------+---------+-------+------------+-------+-------+-------+-------+ Notes: From 7f8ccb5ff17980a7ca9c1668f33689738c634144 Mon Sep 17 00:00:00 2001 From: Gang Wu Date: Tue, 6 Jun 2023 21:06:44 +0900 Subject: [PATCH 05/31] GH-34375: [C++][Parquet] Ignore page header stats when page index enabled (#35455) ### Rationale for this change Page-level statistics are probably not used in production, and after 
adding column indexes they are useless. parquet-mr already stopped writing them in https://issues.apache.org/jira/browse/PARQUET-1365. ### What changes are included in this PR? Once page index is enabled for one column, it does not write page stats to the header any more. ### Are these changes tested? Added a test to check page stats have been skipped. ### Are there any user-facing changes? Yes (behavior change when page index is enabled). * Closes: #34375 Authored-by: Gang Wu Signed-off-by: Antoine Pitrou --- .../parquet/arrow/arrow_reader_writer_test.cc | 21 +++++++++++++++++-- cpp/src/parquet/column_writer.cc | 12 +++++++++-- cpp/src/parquet/properties.h | 7 +++++-- docs/source/cpp/parquet.rst | 11 +++++----- 4 files changed, 39 insertions(+), 12 deletions(-) diff --git a/cpp/src/parquet/arrow/arrow_reader_writer_test.cc b/cpp/src/parquet/arrow/arrow_reader_writer_test.cc index ad33ca296a283..ba3702ede4de7 100644 --- a/cpp/src/parquet/arrow/arrow_reader_writer_test.cc +++ b/cpp/src/parquet/arrow/arrow_reader_writer_test.cc @@ -5252,17 +5252,34 @@ class ParquetPageIndexRoundTripTest : public ::testing::Test { auto row_group_index_reader = page_index_reader->RowGroup(rg); ASSERT_NE(row_group_index_reader, nullptr); + auto row_group_reader = reader->RowGroup(rg); + ASSERT_NE(row_group_reader, nullptr); + for (int col = 0; col < metadata->num_columns(); ++col) { auto column_index = row_group_index_reader->GetColumnIndex(col); column_indexes_.emplace_back(column_index.get()); + bool expect_no_page_index = + expect_columns_without_index.find(col) != expect_columns_without_index.cend(); + auto offset_index = row_group_index_reader->GetOffsetIndex(col); - if (expect_columns_without_index.find(col) != - expect_columns_without_index.cend()) { + if (expect_no_page_index) { ASSERT_EQ(offset_index, nullptr); } else { CheckOffsetIndex(offset_index.get(), expect_num_pages, &offset_lower_bound); } + + // Verify page stats are not written to page header if page index is enabled. + auto page_reader = row_group_reader->GetColumnPageReader(col); + ASSERT_NE(page_reader, nullptr); + std::shared_ptr page = nullptr; + while ((page = page_reader->NextPage()) != nullptr) { + if (page->type() == PageType::DATA_PAGE || + page->type() == PageType::DATA_PAGE_V2) { + ASSERT_EQ(std::static_pointer_cast(page)->statistics().is_set(), + expect_no_page_index); + } + } } } } diff --git a/cpp/src/parquet/column_writer.cc b/cpp/src/parquet/column_writer.cc index f298f282fe76a..33e9f8f6658ae 100644 --- a/cpp/src/parquet/column_writer.cc +++ b/cpp/src/parquet/column_writer.cc @@ -460,7 +460,11 @@ class SerializedPageWriter : public PageWriter { ToThrift(page.definition_level_encoding())); data_page_header.__set_repetition_level_encoding( ToThrift(page.repetition_level_encoding())); - data_page_header.__set_statistics(ToThrift(page.statistics())); + + // Write page statistics only when page index is not enabled. + if (column_index_builder_ == nullptr) { + data_page_header.__set_statistics(ToThrift(page.statistics())); + } page_header.__set_type(format::PageType::DATA_PAGE); page_header.__set_data_page_header(data_page_header); @@ -479,7 +483,11 @@ class SerializedPageWriter : public PageWriter { page.repetition_levels_byte_length()); data_page_header.__set_is_compressed(page.is_compressed()); - data_page_header.__set_statistics(ToThrift(page.statistics())); + + // Write page statistics only when page index is not enabled. 
+ if (column_index_builder_ == nullptr) { + data_page_header.__set_statistics(ToThrift(page.statistics())); + } page_header.__set_type(format::PageType::DATA_PAGE_V2); page_header.__set_data_page_header_v2(data_page_header); diff --git a/cpp/src/parquet/properties.h b/cpp/src/parquet/properties.h index 0a9864de6266a..d1012fccf0ca5 100644 --- a/cpp/src/parquet/properties.h +++ b/cpp/src/parquet/properties.h @@ -524,8 +524,11 @@ class PARQUET_EXPORT WriterProperties { /// Enable writing page index in general for all columns. Default disabled. /// - /// Page index contains statistics for data pages and can be used to skip pages - /// when scanning data in ordered and unordered columns. + /// Writing statistics to the page index disables the old method of writing + /// statistics to each data page header. + /// The page index makes filtering more efficient than the page header, as + /// it gathers all the statistics for a Parquet file in a single place, + /// avoiding scattered I/O. /// /// Please check the link below for more details: /// https://github.com/apache/parquet-format/blob/master/PageIndex.md diff --git a/docs/source/cpp/parquet.rst b/docs/source/cpp/parquet.rst index 0ea6063b2a276..95f2d8d98dc0a 100644 --- a/docs/source/cpp/parquet.rst +++ b/docs/source/cpp/parquet.rst @@ -304,6 +304,8 @@ Statistics are enabled by default for all columns. You can disable statistics fo all columns or specific columns using ``disable_statistics`` on the builder. There is a ``max_statistics_size`` which limits the maximum number of bytes that may be used for min and max values, useful for types like strings or binary blobs. +If a column has enabled page index using ``enable_write_page_index``, then it does +not write statistics to the page header because it is duplicated in the ColumnIndex. There are also Arrow-specific settings that can be configured with :class:`parquet::ArrowWriterProperties`: @@ -573,13 +575,13 @@ Miscellaneous +--------------------------+----------+----------+---------+ | Feature | Reading | Writing | Notes | +==========================+==========+==========+=========+ -| Column Index | ✓ | | \(1) | +| Column Index | ✓ | ✓ | \(1) | +--------------------------+----------+----------+---------+ -| Offset Index | ✓ | | \(1) | +| Offset Index | ✓ | ✓ | \(1) | +--------------------------+----------+----------+---------+ | Bloom Filter | ✓ | ✓ | \(2) | +--------------------------+----------+----------+---------+ -| CRC checksums | ✓ | ✓ | \(3) | +| CRC checksums | ✓ | ✓ | | +--------------------------+----------+----------+---------+ * \(1) Access to the Column and Offset Index structures is provided, but @@ -587,6 +589,3 @@ Miscellaneous * \(2) APIs are provided for creating, serializing and deserializing Bloom Filters, but they are not integrated into data read APIs. - -* \(3) For now, only the checksums of V1 Data Pages and Dictionary Pages - are computed. From 3d0172d40dfcf934308e6e1f4249a854004fe824 Mon Sep 17 00:00:00 2001 From: Nic Crane Date: Tue, 6 Jun 2023 13:48:12 +0100 Subject: [PATCH 06/31] GH-35937: [R] Update NEWS for 12.0.1 (#35938) Lead-authored-by: Nic Crane Co-authored-by: Dewey Dunnington Signed-off-by: Nic Crane --- r/NEWS.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/r/NEWS.md b/r/NEWS.md index bc130249c0824..e36b42e5389ee 100644 --- a/r/NEWS.md +++ b/r/NEWS.md @@ -19,6 +19,10 @@ # arrow 12.0.0.9000 +* Update the version of the date library vendored with Arrow C++ library + for compatibility with tzdb 0.4.0 (#35594, #35612). 
+* Update some tests for compatibility with waldo 0.5.1 (#35131, #35308). + # arrow 12.0.0 ## New features From 4ec231b23432e9af340ed1a10d914f9bee84aca7 Mon Sep 17 00:00:00 2001 From: Igor Izvekov Date: Tue, 6 Jun 2023 17:45:06 +0300 Subject: [PATCH 07/31] GH-35911: [Go] Fix method CastToBytes of decimal256Traits (#35912) ### Rationale for this change ### What changes are included in this PR? ### Are these changes tested? Yes ### Are there any user-facing changes? Yes * Closes: #35911 Authored-by: izveigor Signed-off-by: Matt Topol --- go/arrow/type_traits_decimal256.go | 2 +- go/arrow/type_traits_test.go | 45 ++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/go/arrow/type_traits_decimal256.go b/go/arrow/type_traits_decimal256.go index b213df32fbcf3..38d26ce1002a2 100644 --- a/go/arrow/type_traits_decimal256.go +++ b/go/arrow/type_traits_decimal256.go @@ -59,7 +59,7 @@ func (decimal256Traits) CastToBytes(b []decimal256.Num) []byte { h := (*reflect.SliceHeader)(unsafe.Pointer(&b)) var res []byte - s := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + s := (*reflect.SliceHeader)(unsafe.Pointer(&res)) s.Data = h.Data s.Len = h.Len * Decimal256SizeBytes s.Cap = h.Cap * Decimal256SizeBytes diff --git a/go/arrow/type_traits_test.go b/go/arrow/type_traits_test.go index 3b9571d3d3d49..a54d2d7c696c4 100644 --- a/go/arrow/type_traits_test.go +++ b/go/arrow/type_traits_test.go @@ -24,6 +24,7 @@ import ( "github.com/apache/arrow/go/v13/arrow" "github.com/apache/arrow/go/v13/arrow/decimal128" + "github.com/apache/arrow/go/v13/arrow/decimal256" "github.com/apache/arrow/go/v13/arrow/float16" ) @@ -133,6 +134,50 @@ func TestDecimal128Traits(t *testing.T) { } } +func TestDecimal256Traits(t *testing.T) { + const N = 10 + nbytes := arrow.Decimal256Traits.BytesRequired(N) + b1 := arrow.Decimal256Traits.CastToBytes([]decimal256.Num{ + decimal256.New(0, 0, 0, 10), + decimal256.New(1, 1, 1, 10), + decimal256.New(2, 2, 2, 10), + decimal256.New(3, 3, 3, 10), + decimal256.New(4, 4, 4, 10), + decimal256.New(5, 5, 5, 10), + decimal256.New(6, 6, 6, 10), + decimal256.New(7, 7, 7, 10), + decimal256.New(8, 8, 8, 10), + decimal256.New(9, 9, 9, 10), + }) + + b2 := make([]byte, nbytes) + for i := 0; i < N; i++ { + beg := i * arrow.Decimal256SizeBytes + end := (i + 1) * arrow.Decimal256SizeBytes + arrow.Decimal256Traits.PutValue(b2[beg:end], decimal256.New(uint64(i), uint64(i), uint64(i), 10)) + } + + if !reflect.DeepEqual(b1, b2) { + v1 := arrow.Decimal256Traits.CastFromBytes(b1) + v2 := arrow.Decimal256Traits.CastFromBytes(b2) + t.Fatalf("invalid values:\nb1=%v\nb2=%v\nv1=%v\nv2=%v\n", b1, b2, v1, v2) + } + + v1 := arrow.Decimal256Traits.CastFromBytes(b1) + for i, v := range v1 { + if got, want := v, decimal256.New(uint64(i), uint64(i), uint64(i), 10); got != want { + t.Fatalf("invalid value[%d]. 
got=%v, want=%v", i, got, want) + } + } + + v2 := make([]decimal256.Num, N) + arrow.Decimal256Traits.Copy(v2, v1) + + if !reflect.DeepEqual(v1, v2) { + t.Fatalf("invalid values:\nv1=%v\nv2=%v\n", v1, v2) + } +} + func TestMonthIntervalTraits(t *testing.T) { const N = 10 b1 := arrow.MonthIntervalTraits.CastToBytes([]arrow.MonthInterval{ From c78ef0f1a4adc9211aef78345415faa00cefa0bf Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Tue, 6 Jun 2023 17:24:04 +0200 Subject: [PATCH 08/31] GH-35040: [Python] Skip test_cast_timestamp_to_string on Windows because it requires tz database (#35735) ### Rationale for this change Fix up of https://github.com/apache/arrow/pull/35395, skipping one of the tests added in that PR on Windows, because the test requires access to a tz database. Authored-by: Joris Van den Bossche Signed-off-by: Joris Van den Bossche --- python/pyarrow/tests/test_compute.py | 44 +++++++++++++++------------- python/pyarrow/tests/test_scalars.py | 4 +++ python/pyarrow/tests/util.py | 9 ++++++ 3 files changed, 36 insertions(+), 21 deletions(-) diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 875d0e613b6ca..f934edd3c3bcb 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -36,6 +36,8 @@ import pyarrow as pa import pyarrow.compute as pc from pyarrow.lib import ArrowNotImplementedError +from pyarrow.tests import util + all_array_types = [ ('bool', [True, False, False, True, True]), @@ -180,17 +182,19 @@ def test_option_class_equality(): pc.WeekOptions(week_starts_monday=True, count_from_zero=False, first_week_is_fully_in_year=False), ] - # TODO: We should test on windows once ARROW-13168 is resolved. - # Timezone database is not available on Windows yet - if sys.platform != 'win32': + # Timezone database might not be installed on Windows + if sys.platform != "win32" or util.windows_has_tzdata(): options.append(pc.AssumeTimezoneOptions("Europe/Ljubljana")) classes = {type(option) for option in options} for cls in exported_option_classes: - # Timezone database is not available on Windows yet - if cls not in classes and sys.platform != 'win32' and \ - cls != pc.AssumeTimezoneOptions: + # Timezone database might not be installed on Windows + if ( + cls not in classes + and (sys.platform != "win32" or util.windows_has_tzdata()) + and cls != pc.AssumeTimezoneOptions + ): try: options.append(cls()) except TypeError: @@ -1846,17 +1850,18 @@ def test_strptime(): assert got == pa.array([None, None, None], type=pa.timestamp('s')) -# TODO: We should test on windows once ARROW-13168 is resolved. 
@pytest.mark.pandas -@pytest.mark.skipif(sys.platform == 'win32', - reason="Timezone database is not available on Windows yet") +@pytest.mark.skipif(sys.platform == "win32" and not util.windows_has_tzdata(), + reason="Timezone database is not installed on Windows") def test_strftime(): times = ["2018-03-10 09:00", "2038-01-31 12:23", None] timezones = ["CET", "UTC", "Europe/Ljubljana"] - formats = ["%a", "%A", "%w", "%d", "%b", "%B", "%m", "%y", "%Y", "%H", - "%I", "%p", "%M", "%z", "%Z", "%j", "%U", "%W", "%c", "%x", - "%X", "%%", "%G", "%V", "%u"] + formats = ["%a", "%A", "%w", "%d", "%b", "%B", "%m", "%y", "%Y", "%H", "%I", + "%p", "%M", "%z", "%Z", "%j", "%U", "%W", "%%", "%G", "%V", "%u"] + if sys.platform != "win32": + # Locale-dependent formats don't match on Windows + formats.extend(["%c", "%x", "%X"]) for timezone in timezones: ts = pd.to_datetime(times).tz_localize(timezone) @@ -2029,18 +2034,16 @@ def test_extract_datetime_components(): _check_datetime_components(timestamps) # Test timezone aware timestamp array - if sys.platform == 'win32': - # TODO: We should test on windows once ARROW-13168 is resolved. - pytest.skip('Timezone database is not available on Windows yet') + if sys.platform == "win32" and not util.windows_has_tzdata(): + pytest.skip('Timezone database is not installed on Windows') else: for timezone in timezones: _check_datetime_components(timestamps, timezone) -# TODO: We should test on windows once ARROW-13168 is resolved. @pytest.mark.pandas -@pytest.mark.skipif(sys.platform == 'win32', - reason="Timezone database is not available on Windows yet") +@pytest.mark.skipif(sys.platform == "win32" and not util.windows_has_tzdata(), + reason="Timezone database is not installed on Windows") def test_assume_timezone(): ts_type = pa.timestamp("ns") timestamps = pd.to_datetime(["1970-01-01T00:00:59.123456789", @@ -2235,9 +2238,8 @@ def _check_temporal_rounding(ts, values, unit): np.testing.assert_array_equal(result, expected) -# TODO: We should test on windows once ARROW-13168 is resolved. 
-@pytest.mark.skipif(sys.platform == 'win32', - reason="Timezone database is not available on Windows yet") +@pytest.mark.skipif(sys.platform == "win32" and not util.windows_has_tzdata(), + reason="Timezone database is not installed on Windows") @pytest.mark.parametrize('unit', ("nanosecond", "microsecond", "millisecond", "second", "minute", "hour", "day")) @pytest.mark.pandas diff --git a/python/pyarrow/tests/test_scalars.py b/python/pyarrow/tests/test_scalars.py index 1e6a3f29e0d7e..b7180e5250fdf 100644 --- a/python/pyarrow/tests/test_scalars.py +++ b/python/pyarrow/tests/test_scalars.py @@ -19,12 +19,14 @@ import decimal import pickle import pytest +import sys import weakref import numpy as np import pyarrow as pa import pyarrow.compute as pc +from pyarrow.tests import util @pytest.mark.parametrize(['value', 'ty', 'klass'], [ @@ -304,6 +306,8 @@ def test_cast(): pa.scalar('foo').cast('int32') +@pytest.mark.skipif(sys.platform == "win32" and not util.windows_has_tzdata(), + reason="Timezone database is not installed on Windows") def test_cast_timestamp_to_string(): # GH-35370 pytest.importorskip("pytz") diff --git a/python/pyarrow/tests/util.py b/python/pyarrow/tests/util.py index df7936371ee8f..0b69deb73ba28 100644 --- a/python/pyarrow/tests/util.py +++ b/python/pyarrow/tests/util.py @@ -448,3 +448,12 @@ def _configure_s3_limited_user(s3_server, policy): except FileNotFoundError: pytest.skip("Configuring limited s3 user failed") + + +def windows_has_tzdata(): + """ + This is the default location where tz.cpp will look for (until we make + this configurable at run-time) + """ + tzdata_path = os.path.expandvars(r"%USERPROFILE%\Downloads\tzdata") + return os.path.exists(tzdata_path) From 9fb8697dcb442f63317c7d6046393fb74842e0ae Mon Sep 17 00:00:00 2001 From: Igor Izvekov Date: Tue, 6 Jun 2023 18:36:24 +0300 Subject: [PATCH 09/31] MINOR: [Go] Use `decimal256` in arrdata (#35913) ### Rationale for this change ### What changes are included in this PR? ### Are these changes tested? ### Are there any user-facing changes? 
No Authored-by: izveigor Signed-off-by: Matt Topol --- go/arrow/internal/arrdata/arrdata.go | 56 ++++++++++++++ go/arrow/internal/arrjson/arrjson_test.go | 92 +++++++++++++++++++++++ 2 files changed, 148 insertions(+) diff --git a/go/arrow/internal/arrdata/arrdata.go b/go/arrow/internal/arrdata/arrdata.go index f488724e940ca..3f023e40d67dd 100644 --- a/go/arrow/internal/arrdata/arrdata.go +++ b/go/arrow/internal/arrdata/arrdata.go @@ -24,6 +24,7 @@ import ( "github.com/apache/arrow/go/v13/arrow" "github.com/apache/arrow/go/v13/arrow/array" "github.com/apache/arrow/go/v13/arrow/decimal128" + "github.com/apache/arrow/go/v13/arrow/decimal256" "github.com/apache/arrow/go/v13/arrow/float16" "github.com/apache/arrow/go/v13/arrow/ipc" "github.com/apache/arrow/go/v13/arrow/memory" @@ -47,6 +48,7 @@ func init() { Records["intervals"] = makeIntervalsRecords() Records["durations"] = makeDurationsRecords() Records["decimal128"] = makeDecimal128sRecords() + Records["decimal256"] = makeDecimal256sRecords() Records["maps"] = makeMapsRecords() Records["extension"] = makeExtensionRecords() Records["union"] = makeUnionRecords() @@ -688,6 +690,7 @@ func makeDurationsRecords() []arrow.Record { var ( decimal128Type = &arrow.Decimal128Type{Precision: 10, Scale: 1} + decimal256Type = &arrow.Decimal256Type{Precision: 72, Scale: 2} ) func makeDecimal128sRecords() []arrow.Record { @@ -735,6 +738,51 @@ func makeDecimal128sRecords() []arrow.Record { return recs } +func makeDecimal256sRecords() []arrow.Record { + mem := memory.NewGoAllocator() + schema := arrow.NewSchema( + []arrow.Field{ + {Name: "dec256s", Type: decimal256Type, Nullable: true}, + }, nil, + ) + + dec256s := func(vs []uint64) []decimal256.Num { + o := make([]decimal256.Num, len(vs)) + for i, v := range vs { + o[i] = decimal256.New(v, v, v, v) + } + return o + } + + mask := []bool{true, false, false, true, true} + chunks := [][]arrow.Array{ + { + arrayOf(mem, dec256s([]uint64{21, 22, 23, 24, 25}), mask), + }, + { + arrayOf(mem, dec256s([]uint64{31, 32, 33, 34, 35}), mask), + }, + { + arrayOf(mem, dec256s([]uint64{41, 42, 43, 44, 45}), mask), + }, + } + + defer func() { + for _, chunk := range chunks { + for _, col := range chunk { + col.Release() + } + } + }() + + recs := make([]arrow.Record, len(chunks)) + for i, chunk := range chunks { + recs[i] = array.NewRecord(schema, chunk, -1) + } + + return recs +} + func makeMapsRecords() []arrow.Record { mem := memory.NewGoAllocator() dtype := arrow.MapOf(arrow.PrimitiveTypes.Int32, arrow.BinaryTypes.String) @@ -1167,6 +1215,14 @@ func arrayOf(mem memory.Allocator, a interface{}, valids []bool) arrow.Array { aa := bldr.NewDecimal128Array() return aa + case []decimal256.Num: + bldr := array.NewDecimal256Builder(mem, decimal256Type) + defer bldr.Release() + + bldr.AppendValues(a, valids) + aa := bldr.NewDecimal256Array() + return aa + case []string: bldr := array.NewStringBuilder(mem) defer bldr.Release() diff --git a/go/arrow/internal/arrjson/arrjson_test.go b/go/arrow/internal/arrjson/arrjson_test.go index bba361982fdb9..e969e553a1394 100644 --- a/go/arrow/internal/arrjson/arrjson_test.go +++ b/go/arrow/internal/arrjson/arrjson_test.go @@ -41,6 +41,7 @@ func TestReadWrite(t *testing.T) { wantJSONs["intervals"] = makeIntervalsWantJSONs() wantJSONs["durations"] = makeDurationsWantJSONs() wantJSONs["decimal128"] = makeDecimal128sWantJSONs() + wantJSONs["decimal256"] = makeDecimal256sWantJSONs() wantJSONs["maps"] = makeMapsWantJSONs() wantJSONs["extension"] = makeExtensionsWantJSONs() wantJSONs["dictionary"] = 
makeDictionaryWantJSONs() @@ -3430,6 +3431,97 @@ func makeDecimal128sWantJSONs() string { }` } +func makeDecimal256sWantJSONs() string { + return `{ + "schema": { + "fields": [ + { + "name": "dec256s", + "type": { + "name": "decimal", + "scale": 2, + "precision": 72, + "bitWidth": 256 + }, + "nullable": true, + "children": [] + } + ] + }, + "batches": [ + { + "count": 5, + "columns": [ + { + "name": "dec256s", + "count": 5, + "VALIDITY": [ + 1, + 0, + 0, + 1, + 1 + ], + "DATA": [ + "131819136443120296047697507592700702471267712715359757795349", + "138096238178506976811873579382829307350851889511329270071318", + "144373339913893657576049651172957912230436066307298782347287", + "150650441649280338340225722963086517110020243103268294623256", + "156927543384667019104401794753215121989604419899237806899225" + ] + } + ] + }, + { + "count": 5, + "columns": [ + { + "name": "dec256s", + "count": 5, + "VALIDITY": [ + 1, + 0, + 0, + 1, + 1 + ], + "DATA": [ + "194590153796987103689458225493986751267109480675054880555039", + "200867255532373784453634297284115356146693657471024392831008", + "207144357267760465217810369074243961026277834266993905106977", + "213421459003147145981986440864372565905862011062963417382946", + "219698560738533826746162512654501170785446187858932929658915" + ] + } + ] + }, + { + "count": 5, + "columns": [ + { + "name": "dec256s", + "count": 5, + "VALIDITY": [ + 1, + 0, + 0, + 1, + 1 + ], + "DATA": [ + "257361171150853911331218943395272800062951248634750003314729", + "263638272886240592095395015185401404942535425430719515590698", + "269915374621627272859571086975530009822119602226689027866667", + "276192476357013953623747158765658614701703779022658540142636", + "282469578092400634387923230555787219581287955818628052418605" + ] + } + ] + } + ] +}` +} + func makeMapsWantJSONs() string { return `{ "schema": { From 105b9df0d504f25901cb89f49d365ec46b6de87b Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Tue, 6 Jun 2023 12:31:15 -0400 Subject: [PATCH 10/31] GH-34971: [Format] Add non-CPU version of C Data Interface (#34972) ### Rationale for this change In order to support non-cpu devices and memory usage, we can add new `ArrowDeviceArray` and `ArrowDeviceArrayStream` structs to the C Data Interface in order to allow for handling these types of memory. ### What changes are included in this PR? Definitions for a new `ArrowDeviceArray`, `ArrowDeviceArrayStream` and `ArrowDeviceType` enums. * Closes: #34971 Lead-authored-by: Matt Topol Co-authored-by: David Li Co-authored-by: Antoine Pitrou Co-authored-by: John Zedlewski <904524+JohnZed@users.noreply.github.com> Co-authored-by: Gang Wu Signed-off-by: Matt Topol --- cpp/src/arrow/c/abi.h | 122 ++++ docs/source/format/CDataInterface.rst | 4 + docs/source/format/CDeviceDataInterface.rst | 693 ++++++++++++++++++++ docs/source/index.rst | 1 + 4 files changed, 820 insertions(+) create mode 100644 docs/source/format/CDeviceDataInterface.rst diff --git a/cpp/src/arrow/c/abi.h b/cpp/src/arrow/c/abi.h index d58417e6fbcf2..6abe866b5f6f6 100644 --- a/cpp/src/arrow/c/abi.h +++ b/cpp/src/arrow/c/abi.h @@ -15,10 +15,27 @@ // specific language governing permissions and limitations // under the License. +/// \file abi.h Arrow C Data Interface +/// +/// The Arrow C Data interface defines a very small, stable set +/// of C definitions which can be easily copied into any project's +/// source code and vendored to be used for columnar data interchange +/// in the Arrow format. 
For non-C/C++ languages and runtimes, +/// it should be almost as easy to translate the C definitions into +/// the corresponding C FFI declarations. +/// +/// Applications and libraries can therefore work with Arrow memory +/// without necessarily using the Arrow libraries or reinventing +/// the wheel. Developers can choose between tight integration +/// with the Arrow software project or minimal integration with +/// the Arrow format only. + #pragma once #include +// Spec and documentation: https://arrow.apache.org/docs/format/CDataInterface.html + #ifdef __cplusplus extern "C" { #endif @@ -65,6 +82,61 @@ struct ArrowArray { #endif // ARROW_C_DATA_INTERFACE +#ifndef ARROW_C_DEVICE_DATA_INTERFACE +#define ARROW_C_DEVICE_DATA_INTERFACE + +// Spec and Documentation: https://arrow.apache.org/docs/format/CDeviceDataInterface.html + +// DeviceType for the allocated memory +typedef int32_t ArrowDeviceType; + +// CPU device, same as using ArrowArray directly +#define ARROW_DEVICE_CPU 1 +// CUDA GPU Device +#define ARROW_DEVICE_CUDA 2 +// Pinned CUDA CPU memory by cudaMallocHost +#define ARROW_DEVICE_CUDA_HOST 3 +// OpenCL Device +#define ARROW_DEVICE_OPENCL 4 +// Vulkan buffer for next-gen graphics +#define ARROW_DEVICE_VULKAN 7 +// Metal for Apple GPU +#define ARROW_DEVICE_METAL 8 +// Verilog simulator buffer +#define ARROW_DEVICE_VPI 9 +// ROCm GPUs for AMD GPUs +#define ARROW_DEVICE_ROCM 10 +// Pinned ROCm CPU memory allocated by hipMallocHost +#define ARROW_DEVICE_ROCM_HOST 11 +// Reserved for extension +#define ARROW_DEVICE_EXT_DEV 12 +// CUDA managed/unified memory allocated by cudaMallocManaged +#define ARROW_DEVICE_CUDA_MANAGED 13 +// unified shared memory allocated on a oneAPI non-partitioned device. +#define ARROW_DEVICE_ONEAPI 14 +// GPU support for next-gen WebGPU standard +#define ARROW_DEVICE_WEBGPU 15 +// Qualcomm Hexagon DSP +#define ARROW_DEVICE_HEXAGON 16 + +struct ArrowDeviceArray { + // the Allocated Array + // + // the buffers in the array (along with the buffers of any + // children) are what is allocated on the device. + struct ArrowArray array; + // The device id to identify a specific device + int64_t device_id; + // The type of device which can access this memory. + ArrowDeviceType device_type; + // An event-like object to synchronize on if needed. + void* sync_event; + // Reserved bytes for future expansion. + int64_t reserved[3]; +}; + +#endif // ARROW_C_DEVICE_DATA_INTERFACE + #ifndef ARROW_C_STREAM_INTERFACE #define ARROW_C_STREAM_INTERFACE @@ -106,6 +178,56 @@ struct ArrowArrayStream { #endif // ARROW_C_STREAM_INTERFACE +#ifndef ARROW_C_DEVICE_STREAM_INTERFACE +#define ARROW_C_DEVICE_STREAM_INTERFACE + +// Equivalent to ArrowArrayStream, but for ArrowDeviceArrays. +// +// This stream is intended to provide a stream of data on a single +// device, if a producer wants data to be produced on multiple devices +// then multiple streams should be provided. One per device. +struct ArrowDeviceArrayStream { + // The device that this stream produces data on. + ArrowDeviceType device_type; + + // Callback to get the stream schema + // (will be the same for all arrays in the stream). + // + // Return value 0 if successful, an `errno`-compatible error code otherwise. + // + // If successful, the ArrowSchema must be released independently from the stream. + // The schema should be accessible via CPU memory. 
+ int (*get_schema)(struct ArrowDeviceArrayStream* self, struct ArrowSchema* out); + + // Callback to get the next array + // (if no error and the array is released, the stream has ended) + // + // Return value: 0 if successful, an `errno`-compatible error code otherwise. + // + // If successful, the ArrowDeviceArray must be released independently from the stream. + int (*get_next)(struct ArrowDeviceArrayStream* self, struct ArrowDeviceArray* out); + + // Callback to get optional detailed error information. + // This must only be called if the last stream operation failed + // with a non-0 return code. + // + // Return value: pointer to a null-terminated character array describing + // the last error, or NULL if no description is available. + // + // The returned pointer is only valid until the next operation on this stream + // (including release). + const char* (*get_last_error)(struct ArrowDeviceArrayStream* self); + + // Release callback: release the stream's own resources. + // Note that arrays returned by `get_next` must be individually released. + void (*release)(struct ArrowDeviceArrayStream* self); + + // Opaque producer-specific data + void* private_data; +}; + +#endif // ARROW_C_DEVICE_STREAM_INTERFACE + #ifdef __cplusplus } #endif diff --git a/docs/source/format/CDataInterface.rst b/docs/source/format/CDataInterface.rst index fde872ff9711a..5d3747bd3aaa2 100644 --- a/docs/source/format/CDataInterface.rst +++ b/docs/source/format/CDataInterface.rst @@ -246,6 +246,7 @@ Examples has format string ``+us:4,5``; its two children have names ``ints`` and ``floats``, and format strings ``i`` and ``f`` respectively. +.. _c-data-interface-struct-defs: Structure definitions ===================== @@ -531,6 +532,7 @@ parameterized extension types). The ``ArrowArray`` structure exported from an extension array simply points to the storage data of the extension array. +.. _c-data-interface-semantics: Semantics ========= @@ -703,6 +705,8 @@ C producer examples Exporting a simple ``int32`` array ---------------------------------- +.. _c-data-interface-export-int32-schema: + Export a non-nullable ``int32`` type with empty metadata. In this case, all ``ArrowSchema`` members point to statically-allocated data, so the release callback is trivial. diff --git a/docs/source/format/CDeviceDataInterface.rst b/docs/source/format/CDeviceDataInterface.rst new file mode 100644 index 0000000000000..ed3f3fb16d496 --- /dev/null +++ b/docs/source/format/CDeviceDataInterface.rst @@ -0,0 +1,693 @@ +.. Licensed to the Apache Software Foundation (ASF) under one +.. or more contributor license agreements. See the NOTICE file +.. distributed with this work for additional information +.. regarding copyright ownership. The ASF licenses this file +.. to you under the Apache License, Version 2.0 (the +.. "License"); you may not use this file except in compliance +.. with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, +.. software distributed under the License is distributed on an +.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +.. KIND, either express or implied. See the License for the +.. specific language governing permissions and limitations +.. under the License. + +.. highlight:: c + +.. _c-device-data-interface: + +================================= +The Arrow C Device data interface +================================= + +.. 
warning:: The Arrow C Device Data Interface should be considered experimental + +Rationale +========= + +The current :ref:`C Data Interface `, and most +implementations of it, make the assumption that all data buffers provided +are CPU buffers. Since Apache Arrow is designed to be a universal in-memory +format for representing tabular ("columnar") data, there will be the desire +to leverage this data on non-CPU hardware such as GPUs. One example of such +a case is the `RAPIDS cuDF library`_ which uses the Arrow memory format with +CUDA for NVIDIA GPUs. Since copying data from host to device and back is +expensive, the ideal would be to be able to leave the data on the device +for as long as possible, even when passing it between runtimes and +libraries. + +The Arrow C Device data interface builds on the existing C data interface +by adding a very small, stable set of C definitions to it. These definitions +are equivalents to the ``ArrowArray`` and ``ArrowArrayStream`` structures +from the C Data Interface which add members to allow specifying the device +type and pass necessary information to synchronize with the producer. +For non-C/C++ languages and runtimes, translating the C definitions to +corresponding C FFI declarations should be just as simple as with the +current C data interface. + +Applications and libraries can then use Arrow schemas and Arrow formatted +memory on non-CPU devices to exchange data just as easily as they do +now with CPU data. This will enable leaving data on those devices longer +and avoiding costly copies back and forth between the host and device +just to leverage new libraries and runtimes. + +Goals +----- + +* Expose an ABI-stable interface built on the existing C data interface. +* Make it easy for third-party projects to implement support with little + initial investment. +* Allow zero-copy sharing of Arrow formatted device memory between + independant runtimes and components running in the same process. +* Avoid the need for one-to-one adaptation layers such as the + `CUDA Array Interface`_ for Python processes to pass CUDA data. +* Enable integration without explicit dependencies (either at compile-time + or runtime) on the Arrow software project itself. + +The intent is for the Arrow C Device data interface to expand the reach +of the current C data interface, allowing it to also become the standard +low-level building block for columnar processing on devices like GPUs or +FPGAs. + +Structure definitions +===================== + +Because this is built on the C data interface, the C Device data interface +uses the ``ArrowSchema`` and ``ArrowArray`` structures as defined in the +:ref:`C data interface spec `. It then adds the +following free-standing definitions. Like the rest of the Arrow project, +they are available under the Apache License 2.0. + +.. 
code-block:: c + + #ifndef ARROW_C_DEVICE_DATA_INTERFACE + #define ARROW_C_DEVICE_DATA_INTERFACE + + // Device type for the allocated memory + typedef int32_t ArrowDeviceType; + + // CPU device, same as using ArrowArray directly + #define ARROW_DEVICE_CPU 1 + // CUDA GPU Device + #define ARROW_DEVICE_CUDA 2 + // Pinned CUDA CPU memory by cudaMallocHost + #define ARROW_DEVICE_CUDA_HOST 3 + // OpenCL Device + #define ARROW_DEVICE_OPENCL 4 + // Vulkan buffer for next-gen graphics + #define ARROW_DEVICE_VULKAN 7 + // Metal for Apple GPU + #define ARROW_DEVICE_METAL 8 + // Verilog simulator buffer + #define ARROW_DEVICE_VPI 9 + // ROCm GPUs for AMD GPUs + #define ARROW_DEVICE_ROCM 10 + // Pinned ROCm CPU memory allocated by hipMallocHost + #define ARROW_DEVICE_ROCM_HOST 11 + // Reserved for extension + // + // used to quickly test extension devices, semantics + // can differ based on implementation + #define ARROW_DEVICE_EXT_DEV 12 + // CUDA managed/unified memory allocated by cudaMallocManaged + #define ARROW_DEVICE_CUDA_MANAGED 13 + // Unified shared memory allocated on a oneAPI + // non-partitioned device. + // + // A call to the oneAPI runtime is required to determine the + // device type, the USM allocation type and the sycl context + // that it is bound to. + #define ARROW_DEVICE_ONEAPI 14 + // GPU support for next-gen WebGPU standard + #define ARROW_DEVICE_WEBGPU 15 + // Qualcomm Hexagon DSP + #define ARROW_DEVICE_HEXAGON 16 + + struct ArrowDeviceArray { + struct ArrowArray array; + int64_t device_id; + ArrowDeviceType device_type; + void* sync_event; + + // reserved bytes for future expansion + int64_t reserved[3]; + }; + + #endif // ARROW_C_DEVICE_DATA_INTERFACE + +.. note:: + The canonical guard ``ARROW_C_DEVICE_DATA_INTERFACE`` is meant to avoid + duplicate definitions if two projects copy the definitions in their own + headers, and a third-party project includes from these two projects. It + is therefore important that this guard is kept exactly as-is when these + definitions are copied. + +ArrowDeviceType +--------------- + +The ``ArrowDeviceType`` typedef is used to indicate what type of device the +provided memory buffers were allocated on. This, in conjunction with the +``device_id``, should be sufficient to reference the correct data buffers. + +We then use macros to define values for different device types. The provided +macro values are compatible with the widely used `dlpack`_ ``DLDeviceType`` +definition values, using the same value for each as the equivalent +``kDL`` enum from ``dlpack.h``. The list will be kept in sync with those +equivalent enum values over time to ensure compatibility, rather than +potentially diverging. To avoid the Arrow project having to be in the +position of vetting new hardware devices, new additions should first be +added to dlpack before we add a corresponding macro here. + +To ensure predictability with the ABI, we use macros instead of an ``enum`` +so the storage type is not compiler dependent. + +.. c:macro:: ARROW_DEVICE_CPU + + CPU Device, equivalent to just using ``ArrowArray`` directly instead of + using ``ArrowDeviceArray``. + +.. c:macro:: ARROW_DEVICE_CUDA + + A `CUDA`_ GPU Device. This could represent data allocated either with the + runtime library (``cudaMalloc``) or the device driver (``cuMemAlloc``). + +.. c:macro:: ARROW_DEVICE_CUDA_HOST + + CPU memory that was pinned and page-locked by CUDA by using + ``cudaMallocHost`` or ``cuMemAllocHost``. + +.. 
c:macro:: ARROW_DEVICE_OPENCL + + Data allocated on the device by using the `OpenCL (Open Computing Language)`_ + framework. + +.. c:macro:: ARROW_DEVICE_VULKAN + + Data allocated by the `Vulkan`_ framework and libraries. + +.. c:macro:: ARROW_DEVICE_METAL + + Data on Apple GPU devices using the `Metal`_ framework and libraries. + +.. c:macro:: ARROW_DEVICE_VPI + + Indicates usage of a Verilog simulator buffer. + +.. c:macro:: ARROW_DEVICE_ROCM + + An AMD device using the `ROCm`_ stack. + +.. c:macro:: ARROW_DEVICE_ROCM_HOST + + CPU memory that was pinned and page-locked by ROCm by using ``hipMallocHost``. + +.. c:macro:: ARROW_DEVICE_EXT_DEV + + This value is an escape-hatch for devices to extend which aren't + currently represented otherwise. Producers would need to provide + additional information/context specific to the device if using + this device type. This is used to quickly test extension devices + and semantics can differ based on the implementation. + +.. c:macro:: ARROW_DEVICE_CUDA_MANAGED + + CUDA managed/unified memory which is allocated by ``cudaMallocManaged``. + +.. c:macro:: ARROW_DEVICE_ONEAPI + + Unified shared memory allocated on an Intel `oneAPI`_ non-partitioned + device. A call to the ``oneAPI`` runtime is required to determine + the specific device type, the USM allocation type and the sycl context + that it is bound to. + +.. c:macro:: ARROW_DEVICE_WEBGPU + + GPU support for next-gen WebGPU standards + +.. c:macro:: ARROW_DEVICE_HEXAGON + + Data allocated on a Qualcomm Hexagon DSP device. + +The ArrowDeviceArray structure +------------------------------ + +The ``ArrowDeviceArray`` structure embeds the C data ``ArrowArray`` structure +and adds additional information necessary for consumers to use the data. It +has the following fields: + +.. c:member:: struct ArrowArray ArrowDeviceArray.array + + *Mandatory.* The allocated array data. The values in the ``void**`` buffers (along + with the buffers of any children) are what is allocated on the device. + The buffer values should be device pointers. The rest of the structure + should be accessible to the CPU. + + The ``private_data`` and ``release`` callback of this structure should + contain any necessary information and structures related to freeing + the array according to the device it is allocated on, rather than + having a separate release callback and ``private_data`` pointer here. + +.. c:member:: int64_t ArrowDeviceArray.device_id + + *Mandatory.* The device id to identify a specific device if multiple devices of this + type are on the system. The semantics of the id will be hardware dependent, + but we use an ``int64_t`` to future-proof the id as devices change over time. + +.. c:member:: ArrowDeviceType ArrowDeviceArray.device_type + + *Mandatory.* The type of the device which can access the buffers in the array. + +.. c:member:: void* ArrowDeviceArray.sync_event + + *Optional.* An event-like object to synchronize on if needed. + + Many devices, like GPUs, are primarily asynchronous with respect to + CPU processing. As such, in order to safely access device memory, it is often + necessary to have an object to synchronize processing with. Since + different devices will use different types to specify this, we use a + ``void*`` which can be coerced into a pointer to whatever the device + appropriate type is. + + If synchronization is not needed, this can be null. If this is non-null + then it MUST be used to call the appropriate sync method for the device + (e.g. 
``cudaStreamWaitEvent`` or ``hipStreamWaitEvent``) before attempting + to access the memory in the buffers. + + If an event is provided, then the producer MUST ensure that the exported + data is available on the device before the event is triggered. The + consumer SHOULD wait on the event before trying to access the exported + data. + +.. seealso:: + The :ref:`synchronization event types <_c-device-data-interface-event-types>` + section below. + +.. c:member:: int64_t ArrowDeviceArray.reserved[3] + + As non-CPU development expands, there may be a need to expand this + structure. In order to do so without potentially breaking ABI changes, + we reserve 24 bytes at the end of the object. These bytes MUST be zero'd + out after initialization by the producer in order to ensure safe + evolution of the ABI in the future. + +.. _c-device-data-interface-event-types: + +Synchronization event types +--------------------------- + +The table below lists the expected event types for each device type. +If no event type is supported ("N/A"), then the ``sync_event`` member +should always be null. + +Remember that the event *CAN* be null if synchronization is not needed +to access the data. + ++---------------------------+--------------------+---------+ +| Device Type | Actual Event Type | Notes | ++===========================+====================+=========+ +| ARROW_DEVICE_CPU | N/A | | ++---------------------------+--------------------+---------+ +| ARROW_DEVICE_CUDA | ``cudaEvent_t*`` | | ++---------------------------+--------------------+---------+ +| ARROW_DEVICE_CUDA_HOST | ``cudaEvent_t*`` | | ++---------------------------+--------------------+---------+ +| ARROW_DEVICE_OPENCL | ``cl_event*`` | | ++---------------------------+--------------------+---------+ +| ARROW_DEVICE_VULKAN | ``VkEvent*`` | | ++---------------------------+--------------------+---------+ +| ARROW_DEVICE_METAL | ``MTLEvent*`` | | ++---------------------------+--------------------+---------+ +| ARROW_DEVICE_VPI | N/A | (1) | ++---------------------------+--------------------+---------+ +| ARROW_DEVICE_ROCM | ``hipEvent_t*`` | | ++---------------------------+--------------------+---------+ +| ARROW_DEVICE_ROCM_HOST | ``hipEvent_t*`` | | ++---------------------------+--------------------+---------+ +| ARROW_DEVICE_EXT_DEV | | (2) | ++---------------------------+--------------------+---------+ +| ARROW_DEVICE_CUDA_MANAGED | ``cudaEvent_t*`` | | ++---------------------------+--------------------+---------+ +| ARROW_DEVICE_ONEAPI | ``sycl::event*`` | | ++---------------------------+--------------------+---------+ +| ARROW_DEVICE_WEBGPU | N/A | (1) | ++---------------------------+--------------------+---------+ +| ARROW_DEVICE_HEXAGON | N/A | (1) | ++---------------------------+--------------------+---------+ + +Notes: + +* \(1) Currently unknown if framework has an event type to support. +* \(2) Extension Device has producer defined semantics and thus if + synchronization is needed for an extension device, the producer + should document the type. + + +Semantics +========= + +Memory management +----------------- + +First and foremost: Out of everything in this interface, it is *only* the +data buffers themselves which reside in device memory (i.e. the ``buffers`` +member of the ``ArrowArray`` struct). Everything else should be in CPU +memory. + +The ``ArrowDeviceArray`` structure contains an ``ArrowArray`` object which +itself has :ref:`specific semantics ` for releasing +memory. 
The term *"base structure"* below refers to the ``ArrowDeviceArray`` +object that is passed directly between the producer and consumer -- not any +child structure thereof. + +It is intended for the base structure to be stack- or heap-allocated by the +*consumer*. In this case, the producer API should take a pointer to the +consumer-allocated structure. + +However, any data pointed to by the struct MUST be allocated and maintained +by the producer. This includes the ``sync_event`` member if it is not null, +along with any pointers in the ``ArrowArray`` object as usual. Data lifetime +is managed through the ``release`` callback of the ``ArrowArray`` member. + +For an ``ArrowDeviceArray``, the semantics of a released structure and the +callback semantics are identical to those for +:ref:`ArrowArray itself `. Any producer specific context +information necessary for releasing the device data buffers, in addition to +any allocated event, should be stored in the ``private_data`` member of +the ``ArrowArray`` and managed by the ``release`` callback. + +Moving an array +''''''''''''''' + +The consumer can *move* the ``ArrowDeviceArray`` structure by bitwise copying +or shallow member-wise copying. Then it MUST mark the source structure released +by setting the ``release`` member of the embedded ``ArrowArray`` structure to +``NULL``, but *without* calling that release callback. This ensures that only +one live copy of the struct is active at any given time and that lifetime is +correctly communicated to the producer. + +As usual, the release callback will be called on the destination structure +when it is not needed anymore. + +Record batches +-------------- +As with the C data interface itself, a record batch can be trivially considered +as an equivalent struct array. In this case the metadata of the top-level +``ArrowSchema`` can be used for schema-level metadata of the record batch. + +Mutability +---------- + +Both the producer and the consumer SHOULD consider the exported data (that +is, the data reachable on the device through the ``buffers`` member of +the embedded ``ArrowArray``) to be immutable, as either party could otherwise +see inconsistent data while the other is mutating it. + +Synchronization +--------------- + +If the ``sync_event`` member is non-NULL, the consumer should not attempt +to access or read the data until they have synchronized on that event. If +the ``sync_event`` member is NULL, then it MUST be safe to access the data +without any synchronization necessary on the part of the consumer. + +C producer example +==================== + +Exporting a simple ``int32`` device array +----------------------------------------- + +Export a non-nullable ``int32`` type with empty metadata. An example of this +can be seen in the :ref:`C data interface docs directly `. + +To export the data itself, we transfer ownership to the consumer through +the release callback. This example will use CUDA, but the equivalent calls +could be used for any device: + +.. 
code-block:: c + + static void release_int32_device_array(struct ArrowArray* array) { + assert(array->n_buffers == 2); + // destroy the event + cudaEvent_t* ev_ptr = (cudaEvent_t*)(array->private_data); + cudaError_t status = cudaEventDestroy(*ev_ptr); + assert(status == cudaSuccess); + free(ev_ptr); + + // free the buffers and the buffers array + status = cudaFree(array->buffers[1]); + assert(status == cudaSuccess); + free(array->buffers); + + // mark released + array->release = NULL; + } + + void export_int32_device_array(void* cudaAllocdPtr, + cudaStream_t stream, + int64_t length, + struct ArrowDeviceArray* array) { + // get device id + int device; + cudaError_t status; + status = cudaGetDevice(&device); + assert(status == cudaSuccess); + + cudaEvent_t* ev_ptr = (cudaEvent_t*)malloc(sizeof(cudaEvent_t)); + assert(ev_ptr != NULL); + status = cudaEventCreate(ev_ptr); + assert(status == cudaSuccess); + + // record event on the stream, assuming that the passed in + // stream is where the work to produce the data will be processing. + status = cudaEventRecord(*ev_ptr, stream); + assert(status == cudaSuccess); + + memset(array, 0, sizeof(struct ArrowDeviceArray)); + // initialize fields + *array = (struct ArrowDeviceArray) { + .array = (struct ArrowArray) { + .length = length, + .null_count = 0, + .offset = 0, + .n_buffers = 2, + .n_children = 0, + .children = NULL, + .dictionary = NULL, + // bookkeeping + .release = &release_int32_device_array, + // store the event pointer as private data in the array + // so that we can access it in the release callback. + .private_data = (void*)(ev_ptr), + }, + .device_id = (int64_t)(device), + .device_type = ARROW_DEVICE_CUDA, + // pass the event pointer to the consumer + .sync_event = (void*)(ev_ptr), + }; + + // allocate list of buffers + array->array.buffers = (const void**)malloc(sizeof(void*) * array->array.n_buffers); + assert(array->array.buffers != NULL); + array->array.buffers[0] = NULL; + array->array.buffers[1] = cudaAllocdPtr; + } + + // calling the release callback should be done using the array member + // of the device array. + static void release_device_array_helper(struct ArrowDeviceArray* arr) { + arr->array.release(&arr->array); + } + +======================= +Device Stream Interface +======================= + +Like the :ref:`C stream interface `, the C Device data +interface also specifies a higher-level structure for easing communication +of streaming data within a single process. + +Semantics +========= + +An Arrow C device stream exposes a streaming source of data chunks, each with +the same schema. Chunks are obtained by calling a blocking pull-style iteration +function. It is expected that all chunks should be providing data on the same +device type (but not necessarily the same device id). If it is necessary +to provide a stream of data on multiple device types, a producer should +provide a separate stream object for each device type. + +Structure definition +==================== + +The C device stream interface is defined by a single ``struct`` definition: + +.. 
code-block:: c + + #ifndef ARROW_C_DEVICE_STREAM_INTERFACE + #define ARROW_C_DEVICE_STREAM_INTERFACE + + struct ArrowDeviceArrayStream { + // device type that all arrays will be accessible from + ArrowDeviceType device_type; + // callbacks + int (*get_schema)(struct ArrowDeviceArrayStream*, struct ArrowSchema*); + int (*get_next)(struct ArrowDeviceArrayStream*, struct ArrowDeviceArray*); + const char* (*get_last_error)(struct ArrowDeviceArrayStream*); + + // release callback + void (*release)(struct ArrowDeviceArrayStream*); + + // opaque producer-specific data + void* private_data; + }; + + #endif // ARROW_C_DEVICE_STREAM_INTERFACE + +.. note:: + The canonical guard ``ARROW_C_DEVICE_STREAM_INTERFACE`` is meant to avoid + duplicate definitions if two projects copy the C device stream interface + definitions into their own headers, and a third-party project includes + from these two projects. It is therefore important that this guard is + kept exactly as-is when these definitions are copied. + +The ArrowDeviceArrayStream structure +------------------------------------ + +The ``ArrowDeviceArrayStream`` provides a device type that can access the +resulting data along with the required callbacks to interact with a +streaming source of Arrow arrays. It has the following fields: + +.. c:member:: ArrowDeviceType device_type + + *Mandatory.* The device type that this stream produces data on. All + ``ArrowDeviceArray``s that are produced by this stream should have the + same device type as is set here. This is a convenience for the consumer + to not have to check every array that is retrieved and instead allows + higher-level coding constructs for streams. + +.. c:member:: int (*ArrowDeviceArrayStream.get_schema)(struct ArrowDeviceArrayStream*, struct ArrowSchema* out) + + *Mandatory.* This callback allows the consumer to query the schema of + the chunks of data in the stream. The schema is the same for all data + chunks. + + This callback must NOT be called on a released ``ArrowDeviceArrayStream``. + + *Return value:* 0 on success, a non-zero + :ref:`error code ` otherwise. + +.. c:member:: int (*ArrowDeviceArrayStream.get_next)(struct ArrowDeviceArrayStream*, struct ArrowDeviceArray* out) + + *Mandatory.* This callback allows the consumer to get the next chunk of + data in the stream. + + This callback must NOT be called on a released ``ArrowDeviceArrayStream``. + + The next chunk of data MUST be accessible from a device type matching the + :c:member:`ArrowDeviceArrayStream.device_type`. + + *Return value:* 0 on success, a non-zero + :ref:`error code ` otherwise. + + On success, the consumer must check whether the ``ArrowDeviceArray``'s + embedded ``ArrowArray`` is marked :ref:`released `. + If the embedded ``ArrowDeviceArray.array`` is released, then the end of the + stream has been reached. Otherwise, the ``ArrowDeviceArray`` contains a + valid data chunk. + +.. c:member:: const char* (*ArrowDeviceArrayStream.get_last_error)(struct ArrowDeviceArrayStream*) + + *Mandatory.* This callback allows the consumer to get a textual description + of the last error. + + This callback must ONLY be called if the last operation on the + ``ArrowDeviceArrayStream`` returned an error. It must NOT be called on a + released ``ArrowDeviceArrayStream``. + + *Return value:* a pointer to a NULL-terminated character string + (UTF8-encoded). NULL can also be returned if no detailed description is + available. + + The returned pointer is only guaranteed to be valid until the next call + of one of the stream's callbacks. 
The character string it points to should + be copied to consumer-managed storage if it is intended to survive longer. + +.. c:member:: void (*ArrowDeviceArrayStream.release)(struct ArrowDeviceArrayStream*) + + *Mandatory.* A pointer to a producer-provided release callback. + +.. c:member:: void* ArrowDeviceArrayStream.private_data + + *Optional.* An opaque pointer to producer-provided private data. + + Consumers MUST NOT process this member. Lifetime of this member is + handled by the producer, and especially by the release callback. + +Result lifetimes +---------------- + +The data returned by the ``get_schema`` and ``get_next`` callbacks must be +released independantly. Their lifetimes are not tied to that of +``ArrowDeviceArrayStream``. + +Stream lifetime +--------------- + +Lifetime of the C stream is managed using a release callback with similar +usage as in :ref:`C data interface `. + +Thread safety +------------- + +The stream source is not assumed to be thread-safe. Consumers wanting to +call ``get_next`` from several threads should ensure those calls are +serialized. + +Interoperability with other interchange formats +=============================================== + +Other interchange APIs, such as the `CUDA Array Interface`_, include +members to pass the shape and the data types of the data buffers being +exported. This information is necessary to interpret the raw bytes in the +device data buffers that are being shared. Rather than store the +shape / types of the data alongside the ``ArrowDeviceArray``, users +should utilize the existing ``ArrowSchema`` structure to pass any data +type and shape information. + +Updating this specification +=========================== + +.. note:: + Since this specification is still considered experimental, there is the + (still very low) possibility it might change slightly. The reason for + tagging this as "experimental" is because we don't know what we don't know. + Work and research was done to ensure a generic ABI compatible with many + different frameworks, but it is always possible something was missed. + Once this is supported in an official Arrow release and usage is observed + to confirm there aren't any modifications necessary, the "experimental" + tag will be removed and the ABI frozen. + +Once this specification is supported in an official Arrow release, the C ABI +is frozen. This means that the ``ArrowDeviceArray`` structure definition +should not change in any way -- including adding new members. + +Backwards-compatible changes are allowed, for example new macro values for +:c:typedef:`ArrowDeviceType` or converting the reserved 24 bytes into a +different type/member without changing the size of the structure. + +Any incompatible changes should be part of a new specification, for example +``ArrowDeviceArrayV2``. + + +.. _RAPIDS cuDF library: https://docs.rapids.ai/api/cudf/stable/ +.. _CUDA Array Interface: https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html +.. _dlpack: https://dmlc.github.io/dlpack/latest/c_api.html#c-api +.. _CUDA: https://developer.nvidia.com/cuda-toolkit +.. _OpenCL (Open Computing Language): https://www.khronos.org/opencl/ +.. _Vulkan: https://www.vulkan.org/ +.. _Metal: https://developer.apple.com/metal/ +.. _ROCm: https://www.amd.com/en/graphics/servers-solutions-rocm +.. 
_oneAPI: https://www.intel.com/content/www/us/en/developer/tools/oneapi/overview.html \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index 1813c201d3f99..bcae36a116564 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -80,6 +80,7 @@ target environment.** format/Integration format/CDataInterface format/CStreamInterface + format/CDeviceDataInterface format/ADBC format/Other format/Changing From 745fa94f68f200d2a703b864d4b29ec8ee98a64d Mon Sep 17 00:00:00 2001 From: Erez Rokah Date: Tue, 6 Jun 2023 18:39:01 +0200 Subject: [PATCH 11/31] GH-35948: [Go] Only cast `int8` and `unit8` to `float64` when JSON marshaling arrays (#35950) ### Rationale for this change See [issue](https://github.com/apache/arrow/issues/35948). Based on https://github.com/apache/arrow/blob/9fb8697dcb442f63317c7d6046393fb74842e0ae/go/arrow/array/numeric.gen.go.tmpl#L120 it seems we should only do the casting for `unit8` and `int8` types. ### What changes are included in this PR? This PR fixes the numeric template so it only casts `int8` and `unit8` to `float64` when JSON marshaling arrays to avoid losing data for other types. ### Are these changes tested? Yes ### Are there any user-facing changes? The user facing change is that if a user was JSON marshaling arrays that has `uint64` or `int64` values they can now get different (more accurate) values than before * Closes: #35948 Authored-by: erezrokah Signed-off-by: Matt Topol --- go/arrow/array/numeric.gen.go | 16 ++++----- go/arrow/array/numeric.gen.go.tmpl | 2 +- go/arrow/array/numeric_test.go | 58 ++++++++++++++++++++++++++++++ 3 files changed, 67 insertions(+), 9 deletions(-) diff --git a/go/arrow/array/numeric.gen.go b/go/arrow/array/numeric.gen.go index aa14ca0ea2844..4ceb63d2c4926 100644 --- a/go/arrow/array/numeric.gen.go +++ b/go/arrow/array/numeric.gen.go @@ -101,7 +101,7 @@ func (a *Int64) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { if a.IsValid(i) { - vals[i] = float64(a.values[i]) // prevent uint8 from being seen as binary data + vals[i] = a.values[i] } else { vals[i] = nil } @@ -196,7 +196,7 @@ func (a *Uint64) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { if a.IsValid(i) { - vals[i] = float64(a.values[i]) // prevent uint8 from being seen as binary data + vals[i] = a.values[i] } else { vals[i] = nil } @@ -291,7 +291,7 @@ func (a *Float64) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { if a.IsValid(i) { - vals[i] = float64(a.values[i]) // prevent uint8 from being seen as binary data + vals[i] = a.values[i] } else { vals[i] = nil } @@ -386,7 +386,7 @@ func (a *Int32) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { if a.IsValid(i) { - vals[i] = float64(a.values[i]) // prevent uint8 from being seen as binary data + vals[i] = a.values[i] } else { vals[i] = nil } @@ -481,7 +481,7 @@ func (a *Uint32) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { if a.IsValid(i) { - vals[i] = float64(a.values[i]) // prevent uint8 from being seen as binary data + vals[i] = a.values[i] } else { vals[i] = nil } @@ -576,7 +576,7 @@ func (a *Float32) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { if a.IsValid(i) { - vals[i] = float64(a.values[i]) // prevent uint8 from being seen as binary data + vals[i] = a.values[i] } 
else { vals[i] = nil } @@ -671,7 +671,7 @@ func (a *Int16) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { if a.IsValid(i) { - vals[i] = float64(a.values[i]) // prevent uint8 from being seen as binary data + vals[i] = a.values[i] } else { vals[i] = nil } @@ -766,7 +766,7 @@ func (a *Uint16) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { if a.IsValid(i) { - vals[i] = float64(a.values[i]) // prevent uint8 from being seen as binary data + vals[i] = a.values[i] } else { vals[i] = nil } diff --git a/go/arrow/array/numeric.gen.go.tmpl b/go/arrow/array/numeric.gen.go.tmpl index b141276d756ac..44b353069fd05 100644 --- a/go/arrow/array/numeric.gen.go.tmpl +++ b/go/arrow/array/numeric.gen.go.tmpl @@ -134,7 +134,7 @@ func (a *{{.Name}}) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { if a.IsValid(i) { - vals[i] = float64(a.values[i]) // prevent uint8 from being seen as binary data + {{ if (eq .Size "1") }}vals[i] = float64(a.values[i]) // prevent uint8 from being seen as binary data{{ else }}vals[i] = a.values[i]{{ end }} } else { vals[i] = nil } diff --git a/go/arrow/array/numeric_test.go b/go/arrow/array/numeric_test.go index 73513374a6566..962b4eb6a598f 100644 --- a/go/arrow/array/numeric_test.go +++ b/go/arrow/array/numeric_test.go @@ -632,3 +632,61 @@ func TestDate64SliceDataWithNull(t *testing.T) { t.Fatalf("got=%v, want=%v", got, want) } } + +func TestInt64MarshalJSON(t *testing.T) { + pool := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer pool.AssertSize(t, 0) + + var ( + vs = []int64{-5474557666971701248} + ) + + b := array.NewInt64Builder(pool) + defer b.Release() + + for _, v := range vs { + b.Append(v) + } + + arr := b.NewArray().(*array.Int64) + defer arr.Release() + + jsonBytes, err := json.Marshal(arr) + if err != nil { + t.Fatal(err) + } + got := string(jsonBytes) + want := `[-5474557666971701248]` + if got != want { + t.Fatalf("got=%s, want=%s", got, want) + } +} + +func TestUInt64MarshalJSON(t *testing.T) { + pool := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer pool.AssertSize(t, 0) + + var ( + vs = []uint64{14697929703826477056} + ) + + b := array.NewUint64Builder(pool) + defer b.Release() + + for _, v := range vs { + b.Append(v) + } + + arr := b.NewArray().(*array.Uint64) + defer arr.Release() + + jsonBytes, err := json.Marshal(arr) + if err != nil { + t.Fatal(err) + } + got := string(jsonBytes) + want := `[14697929703826477056]` + if got != want { + t.Fatalf("got=%s, want=%s", got, want) + } +} From e7a9b29b50c27e7239bfdcaff683b41069c36910 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 6 Jun 2023 12:40:31 -0400 Subject: [PATCH 12/31] GH-35932: [Java] Make JDBC test less brittle (#35940) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Rationale for this change A JDBC test is brittle because it compares string representations. ### What changes are included in this PR? Compare values directly. ### Are these changes tested? N/A ### Are there any user-facing changes? No. 
* Closes: #35932 Authored-by: David Li Signed-off-by: Raúl Cumplido --- dev/release/verify-release-candidate.sh | 5 +- .../jdbc/JdbcToArrowCommentMetadataTest.java | 80 +++++++++++-- .../h2/expectedSchemaWithComments.json | 51 -------- ...expectedSchemaWithCommentsAndJdbcMeta.json | 112 ------------------ .../arrow/flight/TestBasicOperation.java | 3 + .../ArrowFlightPreparedStatementTest.java | 2 + 6 files changed, 79 insertions(+), 174 deletions(-) delete mode 100644 java/adapter/jdbc/src/test/resources/h2/expectedSchemaWithComments.json delete mode 100644 java/adapter/jdbc/src/test/resources/h2/expectedSchemaWithCommentsAndJdbcMeta.json diff --git a/dev/release/verify-release-candidate.sh b/dev/release/verify-release-candidate.sh index 638f48aaecb99..4183520a06e6b 100755 --- a/dev/release/verify-release-candidate.sh +++ b/dev/release/verify-release-candidate.sh @@ -563,7 +563,10 @@ test_package_java() { show_header "Build and test Java libraries" # Build and test Java (Requires newer Maven -- I used 3.3.9) - maybe_setup_conda maven || exit 1 + # Pin OpenJDK 17 since OpenJDK 20 is incompatible with our versions + # of things like Mockito, and we also can't update Mockito due to + # not supporting Java 8 anymore + maybe_setup_conda maven openjdk=17.0.3 || exit 1 pushd java mvn test diff --git a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/JdbcToArrowCommentMetadataTest.java b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/JdbcToArrowCommentMetadataTest.java index 8d3e59955488e..dc52210d6c7ab 100644 --- a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/JdbcToArrowCommentMetadataTest.java +++ b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/JdbcToArrowCommentMetadataTest.java @@ -29,6 +29,7 @@ import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Statement; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -37,18 +38,18 @@ import java.util.Set; import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.types.pojo.Schema; -import org.apache.arrow.vector.util.ObjectMapperFactory; import org.junit.After; import org.junit.Before; import org.junit.Test; -import com.fasterxml.jackson.databind.ObjectWriter; - public class JdbcToArrowCommentMetadataTest { private static final String COMMENT = "comment"; //use this metadata key for interoperability with Spark StructType - private final ObjectWriter schemaSerializer = ObjectMapperFactory.newObjectMapper().writerWithDefaultPrettyPrinter(); private Connection conn = null; /** @@ -73,26 +74,85 @@ public void tearDown() throws SQLException { } } + private static Field field(String name, boolean nullable, ArrowType type, Map metadata) { + return new Field(name, new FieldType(nullable, type, null, metadata), Collections.emptyList()); + } + + private static Map metadata(String... 
entries) { + if (entries.length % 2 != 0) { + throw new IllegalArgumentException("Map must have equal number of keys and values"); + } + + final Map result = new HashMap<>(); + for (int i = 0; i < entries.length; i += 2) { + result.put(entries[i], entries[i + 1]); + } + return result; + } + @Test public void schemaComment() throws Exception { boolean includeMetadata = false; - String schemaJson = schemaSerializer.writeValueAsString(getSchemaWithCommentFromQuery(includeMetadata)); - String expectedSchema = getExpectedSchema("/h2/expectedSchemaWithComments.json"); - assertThat(schemaJson).isEqualTo(expectedSchema); + Schema schema = getSchemaWithCommentFromQuery(includeMetadata); + Schema expectedSchema = new Schema(Arrays.asList( + field("ID", false, Types.MinorType.BIGINT.getType(), + metadata("comment", "Record identifier")), + field("NAME", true, Types.MinorType.VARCHAR.getType(), + metadata("comment", "Name of record")), + field("COLUMN1", true, Types.MinorType.BIT.getType(), + metadata()), + field("COLUMNN", true, Types.MinorType.INT.getType(), + metadata("comment", "Informative description of columnN")) + ), metadata("comment", "This is super special table with valuable data")); + assertThat(schema).isEqualTo(expectedSchema); } @Test public void schemaCommentWithDatabaseMetadata() throws Exception { boolean includeMetadata = true; - String schemaJson = schemaSerializer.writeValueAsString(getSchemaWithCommentFromQuery(includeMetadata)); - String expectedSchema = getExpectedSchema("/h2/expectedSchemaWithCommentsAndJdbcMeta.json"); + Schema schema = getSchemaWithCommentFromQuery(includeMetadata); + Schema expectedSchema = new Schema(Arrays.asList( + field("ID", false, Types.MinorType.BIGINT.getType(), + metadata( + "SQL_CATALOG_NAME", "JDBCTOARROWTEST?CHARACTERENCODING=UTF-8", + "SQL_SCHEMA_NAME", "PUBLIC", + "SQL_TABLE_NAME", "TABLE1", + "SQL_COLUMN_NAME", "ID", + "SQL_TYPE", "BIGINT", + "comment", "Record identifier" + )), + field("NAME", true, Types.MinorType.VARCHAR.getType(), + metadata( + "SQL_CATALOG_NAME", "JDBCTOARROWTEST?CHARACTERENCODING=UTF-8", + "SQL_SCHEMA_NAME", "PUBLIC", + "SQL_TABLE_NAME", "TABLE1", + "SQL_COLUMN_NAME", "NAME", + "SQL_TYPE", "VARCHAR", + "comment", "Name of record")), + field("COLUMN1", true, Types.MinorType.BIT.getType(), + metadata( + "SQL_CATALOG_NAME", "JDBCTOARROWTEST?CHARACTERENCODING=UTF-8", + "SQL_SCHEMA_NAME", "PUBLIC", + "SQL_TABLE_NAME", "TABLE1", + "SQL_COLUMN_NAME", "COLUMN1", + "SQL_TYPE", "BOOLEAN")), + field("COLUMNN", true, Types.MinorType.INT.getType(), + metadata( + "SQL_CATALOG_NAME", "JDBCTOARROWTEST?CHARACTERENCODING=UTF-8", + "SQL_SCHEMA_NAME", "PUBLIC", + "SQL_TABLE_NAME", "TABLE1", + "SQL_COLUMN_NAME", "COLUMNN", + "SQL_TYPE", "INTEGER", + "comment", "Informative description of columnN")) + ), metadata("comment", "This is super special table with valuable data")); + assertThat(schema).isEqualTo(expectedSchema); /* corresponding Apache Spark DDL after conversion: ID BIGINT NOT NULL COMMENT 'Record identifier', NAME STRING COMMENT 'Name of record', COLUMN1 BOOLEAN, COLUMNN INT COMMENT 'Informative description of columnN' */ - assertThat(schemaJson).isEqualTo(expectedSchema); + assertThat(schema).isEqualTo(expectedSchema); } private Schema getSchemaWithCommentFromQuery(boolean includeMetadata) throws SQLException { diff --git a/java/adapter/jdbc/src/test/resources/h2/expectedSchemaWithComments.json b/java/adapter/jdbc/src/test/resources/h2/expectedSchemaWithComments.json deleted file mode 100644 index cfdd00fdff4e5..0000000000000 
--- a/java/adapter/jdbc/src/test/resources/h2/expectedSchemaWithComments.json +++ /dev/null @@ -1,51 +0,0 @@ -{ - "fields" : [ { - "name" : "ID", - "nullable" : false, - "type" : { - "name" : "int", - "bitWidth" : 64, - "isSigned" : true - }, - "children" : [ ], - "metadata" : [ { - "value" : "Record identifier", - "key" : "comment" - } ] - }, { - "name" : "NAME", - "nullable" : true, - "type" : { - "name" : "utf8" - }, - "children" : [ ], - "metadata" : [ { - "value" : "Name of record", - "key" : "comment" - } ] - }, { - "name" : "COLUMN1", - "nullable" : true, - "type" : { - "name" : "bool" - }, - "children" : [ ] - }, { - "name" : "COLUMNN", - "nullable" : true, - "type" : { - "name" : "int", - "bitWidth" : 32, - "isSigned" : true - }, - "children" : [ ], - "metadata" : [ { - "value" : "Informative description of columnN", - "key" : "comment" - } ] - } ], - "metadata" : [ { - "value" : "This is super special table with valuable data", - "key" : "comment" - } ] -} \ No newline at end of file diff --git a/java/adapter/jdbc/src/test/resources/h2/expectedSchemaWithCommentsAndJdbcMeta.json b/java/adapter/jdbc/src/test/resources/h2/expectedSchemaWithCommentsAndJdbcMeta.json deleted file mode 100644 index 9b25d635d4bc2..0000000000000 --- a/java/adapter/jdbc/src/test/resources/h2/expectedSchemaWithCommentsAndJdbcMeta.json +++ /dev/null @@ -1,112 +0,0 @@ -{ - "fields" : [ { - "name" : "ID", - "nullable" : false, - "type" : { - "name" : "int", - "bitWidth" : 64, - "isSigned" : true - }, - "children" : [ ], - "metadata" : [ { - "value" : "PUBLIC", - "key" : "SQL_SCHEMA_NAME" - }, { - "value" : "JDBCTOARROWTEST?CHARACTERENCODING=UTF-8", - "key" : "SQL_CATALOG_NAME" - }, { - "value" : "ID", - "key" : "SQL_COLUMN_NAME" - }, { - "value" : "BIGINT", - "key" : "SQL_TYPE" - }, { - "value" : "Record identifier", - "key" : "comment" - }, { - "value" : "TABLE1", - "key" : "SQL_TABLE_NAME" - } ] - }, { - "name" : "NAME", - "nullable" : true, - "type" : { - "name" : "utf8" - }, - "children" : [ ], - "metadata" : [ { - "value" : "PUBLIC", - "key" : "SQL_SCHEMA_NAME" - }, { - "value" : "JDBCTOARROWTEST?CHARACTERENCODING=UTF-8", - "key" : "SQL_CATALOG_NAME" - }, { - "value" : "NAME", - "key" : "SQL_COLUMN_NAME" - }, { - "value" : "VARCHAR", - "key" : "SQL_TYPE" - }, { - "value" : "Name of record", - "key" : "comment" - }, { - "value" : "TABLE1", - "key" : "SQL_TABLE_NAME" - } ] - }, { - "name" : "COLUMN1", - "nullable" : true, - "type" : { - "name" : "bool" - }, - "children" : [ ], - "metadata" : [ { - "value" : "PUBLIC", - "key" : "SQL_SCHEMA_NAME" - }, { - "value" : "TABLE1", - "key" : "SQL_TABLE_NAME" - }, { - "value" : "JDBCTOARROWTEST?CHARACTERENCODING=UTF-8", - "key" : "SQL_CATALOG_NAME" - }, { - "value" : "COLUMN1", - "key" : "SQL_COLUMN_NAME" - }, { - "value" : "BOOLEAN", - "key" : "SQL_TYPE" - } ] - }, { - "name" : "COLUMNN", - "nullable" : true, - "type" : { - "name" : "int", - "bitWidth" : 32, - "isSigned" : true - }, - "children" : [ ], - "metadata" : [ { - "value" : "PUBLIC", - "key" : "SQL_SCHEMA_NAME" - }, { - "value" : "JDBCTOARROWTEST?CHARACTERENCODING=UTF-8", - "key" : "SQL_CATALOG_NAME" - }, { - "value" : "COLUMNN", - "key" : "SQL_COLUMN_NAME" - }, { - "value" : "INTEGER", - "key" : "SQL_TYPE" - }, { - "value" : "Informative description of columnN", - "key" : "comment" - }, { - "value" : "TABLE1", - "key" : "SQL_TABLE_NAME" - } ] - } ], - "metadata" : [ { - "value" : "This is super special table with valuable data", - "key" : "comment" - } ] -} \ No newline at end of file diff --git 
a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java index 40337b2de5a59..260ea4a0e3fed 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java @@ -56,6 +56,8 @@ import org.apache.arrow.vector.types.pojo.Schema; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledOnOs; +import org.junit.jupiter.api.condition.OS; import com.google.common.base.Charsets; import com.google.protobuf.ByteString; @@ -285,6 +287,7 @@ public void getStream() throws Exception { /** Ensure the client is configured to accept large messages. */ @Test + @DisabledOnOs(value = {OS.WINDOWS}, disabledReason = "https://github.com/apache/arrow/issues/33237: flaky test") public void getStreamLargeBatch() throws Exception { test(c -> { try (final FlightStream stream = c.getStream(new Ticket(Producer.TICKET_LARGE_BATCH))) { diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java index 42fb31e811be3..df2577e955881 100644 --- a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java +++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java @@ -32,6 +32,7 @@ import org.junit.Before; import org.junit.BeforeClass; import org.junit.ClassRule; +import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ErrorCollector; @@ -73,6 +74,7 @@ public void testSimpleQueryNoParameterBinding() throws SQLException { } @Test + @Ignore("https://github.com/apache/arrow/issues/34741: flaky test") public void testPreparedStatementExecutionOnce() throws SQLException { final PreparedStatement statement = connection.prepareStatement(CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD); // Expect that there is one entry in the map -- {prepared statement action type, invocation count}. From fe6c067e690ab28392d19388810725ce9f7ad646 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 6 Jun 2023 13:36:43 -0400 Subject: [PATCH 13/31] GH-35952: [R] Ensure that schema metadata can actually be set as a named character vector (#35954) This wasn't necessarily a regression (reprex fails in 12.0.0 as well), although the comments suggest that assigning a named character vector will work when assigning schema metadata and this feature appears to be used by at least one of our dependencies (sfarrow). Given that the sfarrow check passes on 12.0.0, there is possibly also a place in our code that returns a named character vector rather than a list. I've confirmed that this fix solves the reverse dependency failure building the sfarrow example vignette. 
``` r library(arrow, warn.conflicts = FALSE) schema <- schema(x = int32()) schema$metadata <- c("name" = "value") schema$metadata #> $name #> [1] "value" ``` Created on 2023-06-06 with [reprex v2.0.2](https://reprex.tidyverse.org) * Closes: #35952 Authored-by: Dewey Dunnington Signed-off-by: Nic Crane --- r/R/schema.R | 4 ++++ r/tests/testthat/test-schema.R | 16 ++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/r/R/schema.R b/r/R/schema.R index 9ff38487c4b6f..dc0b4ba81fbbc 100644 --- a/r/R/schema.R +++ b/r/R/schema.R @@ -229,9 +229,13 @@ prepare_key_value_metadata <- function(metadata) { call. = FALSE ) } + + metadata <- as.list(metadata) + if (!is_empty(metadata) && is.list(metadata[["r"]])) { metadata[["r"]] <- .serialize_arrow_r_metadata(metadata[["r"]]) } + map_chr(metadata, as.character) } diff --git a/r/tests/testthat/test-schema.R b/r/tests/testthat/test-schema.R index 24776e6d0c199..a6a0555eaca0e 100644 --- a/r/tests/testthat/test-schema.R +++ b/r/tests/testthat/test-schema.R @@ -140,6 +140,22 @@ test_that("Schema modification", { expect_error(schm[[c(2, 4)]] <- int32(), "length(i) not equal to 1", fixed = TRUE) }) +test_that("Metadata can be reassigned as a whole", { + schm <- schema(b = double(), c = string(), d = int8()) + + # Check named character vector + schm$metadata <- c("foo" = "bar") + expect_identical(schm$metadata, list(foo = "bar")) + + # Check list() + schm$metadata <- list("foo" = "bar") + expect_identical(schm$metadata, list(foo = "bar")) + + # Check NULL for removal + schm$metadata <- NULL + expect_identical(schm$metadata, set_names(list(), character())) +}) + test_that("Metadata is preserved when modifying Schema", { schm <- schema(b = double(), c = string(), d = int8()) schm$metadata$foo <- "bar" From 77d8bc59fbbd44b285d4f1abf02f215ec33d8ad8 Mon Sep 17 00:00:00 2001 From: Dominik Moritz Date: Tue, 6 Jun 2023 13:44:27 -0400 Subject: [PATCH 14/31] MINOR: clarify when to use minor (#35293) This makes the instructions more clear imo. Authored-by: Dominik Moritz Signed-off-by: AlenkaF --- CONTRIBUTING.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a1c473a24bfa1..5ad7ca248625a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -50,7 +50,7 @@ makes it easier for us to process the backlog of submitted Pull Requests. Any functionality change should have a GitHub issue opened. For minor changes that affect documentation, you do not need to open up a GitHub issue. Instead you can -prefix the title of your PR with "MINOR: " if meets the following guidelines: +prefix the title of your PR with "MINOR: " if meets one of the following: * Grammar, usage and spelling fixes that affect no more than 2 files * Documentation updates affecting no more than 2 files and not more From daacbcc4c5f0e435b2158896584a4385dbf38986 Mon Sep 17 00:00:00 2001 From: Alex Shcherbakov Date: Tue, 6 Jun 2023 21:17:20 +0300 Subject: [PATCH 15/31] GH-35909: [Go] Deprecate `arrow.MapType.ValueField` & `arrow.MapType.ValueType` methods (#35899) ### Rationale for this change Follow-up for #35885 ### What changes are included in this PR? * Added `ElemField() Field` to `arrow.ListLikeType` interface * Added `ElemField() Field` to `arrow.MapType` implementation * Added deprecation notice to `arrow.MapType.ValueField` & `arrow.MapType.ValueType` * Fixed a bug in `go/arrow/array/map.go` (`NewMapBuilderWithType` used `ValueType` instead of `ItemType`) ### Are these changes tested? Compile-time assertion for corresponding types. 
### Are there any user-facing changes? * Added `ElemField() Field` to `arrow.ListLikeType` interface * Added `ElemField() Field` to `arrow.MapType` implementation * Added deprecation notice to `arrow.MapType.ValueField` & `arrow.MapType.ValueType` * Closes: #35909 Authored-by: candiduslynx Signed-off-by: Matt Topol --- go/arrow/array/map.go | 6 +++--- go/arrow/cdata/cdata.go | 2 +- go/arrow/datatype_nested.go | 24 ++++++++++++++++-------- go/arrow/datatype_nested_test.go | 12 ++++++------ go/arrow/internal/arrdata/arrdata.go | 8 ++++---- go/arrow/internal/arrjson/arrjson.go | 4 ++-- go/arrow/ipc/file_reader.go | 2 +- go/arrow/ipc/metadata.go | 2 +- go/arrow/scalar/nested.go | 15 +-------------- go/arrow/scalar/parse.go | 2 +- go/parquet/pqarrow/encode_arrow.go | 22 +++++++--------------- go/parquet/pqarrow/schema.go | 14 +++----------- 12 files changed, 46 insertions(+), 67 deletions(-) diff --git a/go/arrow/array/map.go b/go/arrow/array/map.go index 815465c1e5ec9..1e1ba59f0d442 100644 --- a/go/arrow/array/map.go +++ b/go/arrow/array/map.go @@ -150,7 +150,7 @@ type MapBuilder struct { func NewMapBuilder(mem memory.Allocator, keytype, itemtype arrow.DataType, keysSorted bool) *MapBuilder { etype := arrow.MapOf(keytype, itemtype) etype.KeysSorted = keysSorted - listBldr := NewListBuilder(mem, etype.ValueType()) + listBldr := NewListBuilder(mem, etype.Elem()) keyBldr := listBldr.ValueBuilder().(*StructBuilder).FieldBuilder(0) keyBldr.Retain() itemBldr := listBldr.ValueBuilder().(*StructBuilder).FieldBuilder(1) @@ -167,7 +167,7 @@ func NewMapBuilder(mem memory.Allocator, keytype, itemtype arrow.DataType, keysS } func NewMapBuilderWithType(mem memory.Allocator, dt *arrow.MapType) *MapBuilder { - listBldr := NewListBuilder(mem, dt.ValueType()) + listBldr := NewListBuilder(mem, dt.Elem()) keyBldr := listBldr.ValueBuilder().(*StructBuilder).FieldBuilder(0) keyBldr.Retain() itemBldr := listBldr.ValueBuilder().(*StructBuilder).FieldBuilder(1) @@ -178,7 +178,7 @@ func NewMapBuilderWithType(mem memory.Allocator, dt *arrow.MapType) *MapBuilder itemBuilder: itemBldr, etype: dt, keytype: dt.KeyType(), - itemtype: dt.ValueType(), + itemtype: dt.ItemType(), keysSorted: dt.KeysSorted, } } diff --git a/go/arrow/cdata/cdata.go b/go/arrow/cdata/cdata.go index 638d39b02e258..e20a12bbdf61c 100644 --- a/go/arrow/cdata/cdata.go +++ b/go/arrow/cdata/cdata.go @@ -359,7 +359,7 @@ func (imp *cimporter) doImportChildren() error { imp.children[i].importChild(imp, c) } case arrow.MAP: // only one child to import, it's a struct array - imp.children[0].dt = imp.dt.(*arrow.MapType).ValueType() + imp.children[0].dt = imp.dt.(*arrow.MapType).Elem() if err := imp.children[0].importChild(imp, children[0]); err != nil { return err } diff --git a/go/arrow/datatype_nested.go b/go/arrow/datatype_nested.go index 314172a21b47f..b2abfafddabb8 100644 --- a/go/arrow/datatype_nested.go +++ b/go/arrow/datatype_nested.go @@ -37,6 +37,7 @@ type ( ListLikeType interface { DataType Elem() DataType + ElemField() Field } ) @@ -391,15 +392,22 @@ func (t *MapType) String() string { return o.String() } -func (t *MapType) KeyField() Field { return t.value.Elem().(*StructType).Field(0) } -func (t *MapType) KeyType() DataType { return t.KeyField().Type } -func (t *MapType) ItemField() Field { return t.value.Elem().(*StructType).Field(1) } -func (t *MapType) ItemType() DataType { return t.ItemField().Type } -func (t *MapType) ValueType() *StructType { return t.value.Elem().(*StructType) } -func (t *MapType) ValueField() Field { return Field{Name: 
"entries", Type: t.ValueType()} } +func (t *MapType) KeyField() Field { return t.value.Elem().(*StructType).Field(0) } +func (t *MapType) KeyType() DataType { return t.KeyField().Type } +func (t *MapType) ItemField() Field { return t.value.Elem().(*StructType).Field(1) } +func (t *MapType) ItemType() DataType { return t.ItemField().Type } + +// Deprecated: use MapType.Elem().(*StructType) instead +func (t *MapType) ValueType() *StructType { return t.Elem().(*StructType) } + +// Deprecated: use MapType.ElemField() instead +func (t *MapType) ValueField() Field { return t.ElemField() } // Elem returns the MapType's element type (if treating MapType as ListLikeType) -func (t *MapType) Elem() DataType { return t.ValueType() } +func (t *MapType) Elem() DataType { return t.value.Elem() } + +// ElemField returns the MapType's element field (if treating MapType as ListLikeType) +func (t *MapType) ElemField() Field { return Field{Name: "entries", Type: t.Elem()} } func (t *MapType) SetItemNullable(nullable bool) { t.value.Elem().(*StructType).fields[1].Nullable = nullable @@ -419,7 +427,7 @@ func (t *MapType) Fingerprint() string { return fingerprint + "{" + keyFingerprint + itemFingerprint + "}" } -func (t *MapType) Fields() []Field { return []Field{t.ValueField()} } +func (t *MapType) Fields() []Field { return []Field{t.ElemField()} } func (t *MapType) Layout() DataTypeLayout { return t.value.Layout() diff --git a/go/arrow/datatype_nested_test.go b/go/arrow/datatype_nested_test.go index 15d7b2e0e37a7..17d530ba85b90 100644 --- a/go/arrow/datatype_nested_test.go +++ b/go/arrow/datatype_nested_test.go @@ -417,7 +417,7 @@ func TestMapOf(t *testing.T) { t.Fatalf("invalid item type. got=%q, want=%q", got, want) } - if got, want := got.ValueType(), StructOf(got.KeyField(), got.ItemField()); !TypeEqual(got, want) { + if got, want := got.Elem(), StructOf(got.KeyField(), got.ItemField()); !TypeEqual(got, want) { t.Fatalf("invalid value type. got=%q, want=%q", got, want) } @@ -477,7 +477,7 @@ func TestMapOfWithMetadata(t *testing.T) { t.Fatalf("invalid item type. got=%q, want=%q", got, want) } - if got, want := got.ValueType(), StructOf(got.KeyField(), got.ItemField()); !TypeEqual(got, want) { + if got, want := got.Elem(), StructOf(got.KeyField(), got.ItemField()); !TypeEqual(got, want) { t.Fatalf("invalid value type. got=%q, want=%q", got, want) } @@ -485,11 +485,11 @@ func TestMapOfWithMetadata(t *testing.T) { t.Fatalf("invalid String() result. got=%q, want=%q", got, want) } - if !reflect.DeepEqual(got.ValueType().fields[0].Metadata, tc.keyMetadata) { - t.Fatalf("invalid key metadata. got=%v, want=%v", got.ValueType().fields[0].Metadata, tc.keyMetadata) + if !reflect.DeepEqual(got.Elem().(*StructType).fields[0].Metadata, tc.keyMetadata) { + t.Fatalf("invalid key metadata. got=%v, want=%v", got.Elem().(*StructType).fields[0].Metadata, tc.keyMetadata) } - if !reflect.DeepEqual(got.ValueType().fields[1].Metadata, tc.itemMetadata) { - t.Fatalf("invalid item metadata. got=%v, want=%v", got.ValueType().fields[1].Metadata, tc.itemMetadata) + if !reflect.DeepEqual(got.Elem().(*StructType).fields[1].Metadata, tc.itemMetadata) { + t.Fatalf("invalid item metadata. 
got=%v, want=%v", got.Elem().(*StructType).fields[1].Metadata, tc.itemMetadata) } }) } diff --git a/go/arrow/internal/arrdata/arrdata.go b/go/arrow/internal/arrdata/arrdata.go index 3f023e40d67dd..311645fd3ccba 100644 --- a/go/arrow/internal/arrdata/arrdata.go +++ b/go/arrow/internal/arrdata/arrdata.go @@ -793,7 +793,7 @@ func makeMapsRecords() []arrow.Record { chunks := [][]arrow.Array{ { mapOf(mem, dtype.KeysSorted, []arrow.Array{ - structOf(mem, dtype.ValueType(), [][]arrow.Array{ + structOf(mem, dtype.Elem().(*arrow.StructType), [][]arrow.Array{ { arrayOf(mem, []int32{-1, -2, -3, -4, -5}, nil), arrayOf(mem, []string{"111", "222", "333", "444", "555"}, mask[:5]), @@ -815,7 +815,7 @@ func makeMapsRecords() []arrow.Record { arrayOf(mem, []string{"4111", "4222", "4333", "4444", "4555"}, mask[:5]), }, }, nil), - structOf(mem, dtype.ValueType(), [][]arrow.Array{ + structOf(mem, dtype.Elem().(*arrow.StructType), [][]arrow.Array{ { arrayOf(mem, []int32{1, 2, 3, 4, 5}, nil), arrayOf(mem, []string{"-111", "-222", "-333", "-444", "-555"}, mask[:5]), @@ -841,7 +841,7 @@ func makeMapsRecords() []arrow.Record { }, { mapOf(mem, dtype.KeysSorted, []arrow.Array{ - structOf(mem, dtype.ValueType(), [][]arrow.Array{ + structOf(mem, dtype.Elem().(*arrow.StructType), [][]arrow.Array{ { arrayOf(mem, []int32{1, 2, 3, 4, 5}, nil), arrayOf(mem, []string{"-111", "-222", "-333", "-444", "-555"}, mask[:5]), @@ -863,7 +863,7 @@ func makeMapsRecords() []arrow.Record { arrayOf(mem, []string{"-4111", "-4222", "-4333", "-4444", "-4555"}, mask[:5]), }, }, nil), - structOf(mem, dtype.ValueType(), [][]arrow.Array{ + structOf(mem, dtype.Elem().(*arrow.StructType), [][]arrow.Array{ { arrayOf(mem, []int32{-1, -2, -3, -4, -5}, nil), arrayOf(mem, []string{"111", "222", "333", "444", "555"}, mask[:5]), diff --git a/go/arrow/internal/arrjson/arrjson.go b/go/arrow/internal/arrjson/arrjson.go index 93e3039e98d48..f25691f0cf681 100644 --- a/go/arrow/internal/arrjson/arrjson.go +++ b/go/arrow/internal/arrjson/arrjson.go @@ -1098,7 +1098,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr case *arrow.MapType: valids := validsFromJSON(arr.Valids) - elems := arrayFromJSON(mem, dt.ValueType(), arr.Children[0]) + elems := arrayFromJSON(mem, dt.Elem(), arr.Children[0]) defer elems.Release() bitmap := validsToBitmap(valids, mem) @@ -1429,7 +1429,7 @@ func arrayToJSON(field arrow.Field, arr arrow.Array) Array { Valids: validsToJSON(arr), Offset: arr.Offsets(), Children: []Array{ - arrayToJSON(arrow.Field{Name: "entries", Type: arr.DataType().(*arrow.MapType).ValueType()}, arr.ListValues()), + arrayToJSON(arrow.Field{Name: "entries", Type: arr.DataType().(*arrow.MapType).Elem()}, arr.ListValues()), }, } return o diff --git a/go/arrow/ipc/file_reader.go b/go/arrow/ipc/file_reader.go index a561352960a66..d12f57324daf7 100644 --- a/go/arrow/ipc/file_reader.go +++ b/go/arrow/ipc/file_reader.go @@ -589,7 +589,7 @@ func (ctx *arrayLoaderContext) loadMap(dt *arrow.MapType) arrow.ArrayData { buffers = append(buffers, ctx.buffer()) defer releaseBuffers(buffers) - sub := ctx.loadChild(dt.ValueType()) + sub := ctx.loadChild(dt.Elem()) defer sub.Release() return array.NewData(dt, int(field.Length()), buffers, []arrow.ArrayData{sub}, int(field.NullCount()), 0) diff --git a/go/arrow/ipc/metadata.go b/go/arrow/ipc/metadata.go index a489c94ceb0c7..eaa04beeb0d7d 100644 --- a/go/arrow/ipc/metadata.go +++ b/go/arrow/ipc/metadata.go @@ -420,7 +420,7 @@ func (fv *fieldVisitor) visit(field arrow.Field) { case *arrow.MapType: 
fv.dtype = flatbuf.TypeMap - fv.kids = append(fv.kids, fieldToFB(fv.b, fv.pos.Child(0), dt.ValueField(), fv.memo)) + fv.kids = append(fv.kids, fieldToFB(fv.b, fv.pos.Child(0), dt.ElemField(), fv.memo)) flatbuf.MapStart(fv.b) flatbuf.MapAddKeysSorted(fv.b, dt.KeysSorted) fv.offset = flatbuf.MapEnd(fv.b) diff --git a/go/arrow/scalar/nested.go b/go/arrow/scalar/nested.go index d777727e11f44..38fe3298e5cb4 100644 --- a/go/arrow/scalar/nested.go +++ b/go/arrow/scalar/nested.go @@ -69,20 +69,7 @@ func (l *List) Validate() (err error) { return } - var ( - valueType arrow.DataType - ) - - switch dt := l.Type.(type) { - case *arrow.ListType: - valueType = dt.Elem() - case *arrow.LargeListType: - valueType = dt.Elem() - case *arrow.FixedSizeListType: - valueType = dt.Elem() - case *arrow.MapType: - valueType = dt.ValueType() - } + valueType := l.Type.(arrow.ListLikeType).Elem() listType := l.Type if !arrow.TypeEqual(l.Value.DataType(), valueType) { diff --git a/go/arrow/scalar/parse.go b/go/arrow/scalar/parse.go index 1d50c5f9b972e..2051c54b397d0 100644 --- a/go/arrow/scalar/parse.go +++ b/go/arrow/scalar/parse.go @@ -512,7 +512,7 @@ func MakeScalarParam(val interface{}, dt arrow.DataType) (Scalar, error) { } return NewFixedSizeListScalarWithType(v, dt), nil case arrow.MAP: - if !arrow.TypeEqual(dt.(*arrow.MapType).ValueType(), v.DataType()) { + if !arrow.TypeEqual(dt.(*arrow.MapType).Elem(), v.DataType()) { return nil, fmt.Errorf("inconsistent type for map scalar type") } return NewMapScalar(v), nil diff --git a/go/parquet/pqarrow/encode_arrow.go b/go/parquet/pqarrow/encode_arrow.go index dea412c8b9e9c..a91cee3a7307a 100644 --- a/go/parquet/pqarrow/encode_arrow.go +++ b/go/parquet/pqarrow/encode_arrow.go @@ -36,25 +36,17 @@ import ( // get the count of the number of leaf arrays for the type func calcLeafCount(dt arrow.DataType) int { - switch dt.ID() { - case arrow.EXTENSION: - return calcLeafCount(dt.(arrow.ExtensionType).StorageType()) - case arrow.SPARSE_UNION, arrow.DENSE_UNION: - panic("arrow type not implemented") - case arrow.DICTIONARY: - return calcLeafCount(dt.(*arrow.DictionaryType).ValueType) - case arrow.LIST: - return calcLeafCount(dt.(*arrow.ListType).Elem()) - case arrow.FIXED_SIZE_LIST: - return calcLeafCount(dt.(*arrow.FixedSizeListType).Elem()) - case arrow.MAP: - return calcLeafCount(dt.(*arrow.MapType).ValueType()) - case arrow.STRUCT: + switch dt := dt.(type) { + case arrow.ExtensionType: + return calcLeafCount(dt.StorageType()) + case arrow.NestedType: nleaves := 0 - for _, f := range dt.(*arrow.StructType).Fields() { + for _, f := range dt.Fields() { nleaves += calcLeafCount(f.Type) } return nleaves + case *arrow.DictionaryType: + return calcLeafCount(dt.ValueType) default: return 1 } diff --git a/go/parquet/pqarrow/schema.go b/go/parquet/pqarrow/schema.go index 73eece19c79f1..2e83d3926a48f 100644 --- a/go/parquet/pqarrow/schema.go +++ b/go/parquet/pqarrow/schema.go @@ -1002,7 +1002,7 @@ func applyOriginalStorageMetadata(origin arrow.Field, inferred *SchemaField) (mo } inferred.Field.Type = factory(modifiedChildren) } - case arrow.FIXED_SIZE_LIST, arrow.LIST, arrow.MAP: + case arrow.FIXED_SIZE_LIST, arrow.LIST, arrow.LARGE_LIST, arrow.MAP: // arrow.ListLike if nchildren != 1 { return } @@ -1012,17 +1012,9 @@ func applyOriginalStorageMetadata(origin arrow.Field, inferred *SchemaField) (mo } modified = origin.Type.ID() != inferred.Field.Type.ID() - var childModified bool - switch typ := origin.Type.(type) { - case *arrow.FixedSizeListType: - childModified, err = 
applyOriginalMetadata(arrow.Field{Type: typ.Elem()}, &inferred.Children[0]) - case *arrow.ListType: - childModified, err = applyOriginalMetadata(arrow.Field{Type: typ.Elem()}, &inferred.Children[0]) - case *arrow.MapType: - childModified, err = applyOriginalMetadata(arrow.Field{Type: typ.ValueType()}, &inferred.Children[0]) - } + childModified, err := applyOriginalMetadata(arrow.Field{Type: origin.Type.(arrow.ListLikeType).Elem()}, &inferred.Children[0]) if err != nil { - return + return modified, err } modified = modified || childModified if modified { From c62ce6b179f67ff57fa1571b01f356739727c0c9 Mon Sep 17 00:00:00 2001 From: Bryce Mecum Date: Tue, 6 Jun 2023 17:31:26 -0800 Subject: [PATCH 16/31] GH-35601: [R][Documentation] Add missing docs to fileysystem.R (#35895) ### What changes are included in this PR? Just new docstrings to cover missing ones. ### Are these changes tested? I built the package locally and previewed the new help pages. ### Are there any user-facing changes? No * Closes: #35601 Authored-by: Bryce Mecum Signed-off-by: Dewey Dunnington --- r/R/filesystem.R | 11 +++++++++++ r/man/FileSystem.Rd | 11 +++++++++++ 2 files changed, 22 insertions(+) diff --git a/r/R/filesystem.R b/r/R/filesystem.R index ee5e5a62e248c..f028e57425ba1 100644 --- a/r/R/filesystem.R +++ b/r/R/filesystem.R @@ -148,6 +148,8 @@ FileSelector$create <- function(base_dir, allow_not_found = FALSE, recursive = F #' such as "localhost:9000". This is useful for connecting to file systems #' that emulate S3. #' - `scheme`: S3 connection transport (default "https") +#' - `proxy_options`: optional string, URI of a proxy to use when connecting +#' to S3 #' - `background_writes`: logical, whether `OutputStream` writes will be issued #' in the background, without blocking (default `TRUE`) #' - `allow_bucket_creation`: logical, if TRUE, the filesystem will create @@ -182,6 +184,13 @@ FileSelector$create <- function(base_dir, allow_not_found = FALSE, recursive = F #' #' @section Methods: #' +#' - `path(x)`: Create a `SubTreeFileSystem` from the current `FileSystem` +#' rooted at the specified path `x`. +#' - `cd(x)`: Create a `SubTreeFileSystem` from the current `FileSystem` +#' rooted at the specified path `x`. +#' - `ls(path, ...)`: List files or objects at the given path or from the root +#' of the `FileSystem` if `path` is not provided. Additional arguments passed +#' to `FileSelector$create`, see [FileSelector][FileSelector]. #' - `$GetFileInfo(x)`: `x` may be a [FileSelector][FileSelector] or a character #' vector of paths. Returns a list of [FileInfo][FileInfo] #' - `$CreateDir(path, recursive = TRUE)`: Create a directory and subdirectories. @@ -217,6 +226,8 @@ FileSelector$create <- function(base_dir, allow_not_found = FALSE, recursive = F #' - `$base_fs`: for `SubTreeFileSystem`, the `FileSystem` it contains #' - `$base_path`: for `SubTreeFileSystem`, the path in `$base_fs` which is considered #' root in this `SubTreeFileSystem`. +#' - `$options`: for `GcsFileSystem`, the options used to create the +#' `GcsFileSystem` instance as a `list` #' #' @section Notes: #' diff --git a/r/man/FileSystem.Rd b/r/man/FileSystem.Rd index 38c694af995c6..3e15a020275c9 100644 --- a/r/man/FileSystem.Rd +++ b/r/man/FileSystem.Rd @@ -49,6 +49,8 @@ to "us-east-1" if no other alternatives are found. such as "localhost:9000". This is useful for connecting to file systems that emulate S3. 
\item \code{scheme}: S3 connection transport (default "https") +\item \code{proxy_options}: optional string, URI of a proxy to use when connecting +to S3 \item \code{background_writes}: logical, whether \code{OutputStream} writes will be issued in the background, without blocking (default \code{TRUE}) \item \code{allow_bucket_creation}: logical, if TRUE, the filesystem will create @@ -87,6 +89,13 @@ the filesystem encounters errors. Default is 15 seconds. \section{Methods}{ \itemize{ +\item \code{path(x)}: Create a \code{SubTreeFileSystem} from the current \code{FileSystem} +rooted at the specified path \code{x}. +\item \code{cd(x)}: Create a \code{SubTreeFileSystem} from the current \code{FileSystem} +rooted at the specified path \code{x}. +\item \code{ls(path, ...)}: List files or objects at the given path or from the root +of the \code{FileSystem} if \code{path} is not provided. Additional arguments passed +to \code{FileSelector$create}, see \link{FileSelector}. \item \verb{$GetFileInfo(x)}: \code{x} may be a \link{FileSelector} or a character vector of paths. Returns a list of \link{FileInfo} \item \verb{$CreateDir(path, recursive = TRUE)}: Create a directory and subdirectories. @@ -125,6 +134,8 @@ containing a \code{S3FileSystem} \item \verb{$base_fs}: for \code{SubTreeFileSystem}, the \code{FileSystem} it contains \item \verb{$base_path}: for \code{SubTreeFileSystem}, the path in \verb{$base_fs} which is considered root in this \code{SubTreeFileSystem}. +\item \verb{$options}: for \code{GcsFileSystem}, the options used to create the +\code{GcsFileSystem} instance as a \code{list} } } From a0d28deefd660adc6b14c4ac5fe4ed4175d4fedb Mon Sep 17 00:00:00 2001 From: eitsupi <50911393+eitsupi@users.noreply.github.com> Date: Wed, 7 Jun 2023 15:40:43 +0900 Subject: [PATCH 17/31] GH-33987: [R] Support new dplyr .by/by argument (#35667) ### Rationale for this change Implement the `.by` argument for `mutate`, `summarise`, `filter` and `slice_*` family. ### What changes are included in this PR? The `.by` argument that matches `dplyr` has been added to some functions. Most of the internal functions, such as `compute_by`, are copied from the existing `dplyr` backends, `dbplyr` and `dtplyr`. ### Are these changes tested? Yes. ### Are there any user-facing changes? Yes. 
* Closes: #33987 Authored-by: SHIMA Tatsuya Signed-off-by: Nic Crane --- r/DESCRIPTION | 1 + r/NAMESPACE | 1 + r/R/arrow-package.R | 2 +- r/R/dplyr-by.R | 77 +++++++++++++++++++++++++ r/R/dplyr-filter.R | 35 +++++++---- r/R/dplyr-mutate.R | 39 ++++++++----- r/R/dplyr-slice.R | 37 ++++++------ r/R/dplyr-summarize.R | 35 ++++++----- r/tests/testthat/test-dplyr-filter.R | 53 +++++++++++++++++ r/tests/testthat/test-dplyr-mutate.R | 72 +++++++++++++++++++++++ r/tests/testthat/test-dplyr-slice.R | 22 +++++++ r/tests/testthat/test-dplyr-summarize.R | 23 ++++++++ 12 files changed, 340 insertions(+), 57 deletions(-) create mode 100644 r/R/dplyr-by.R diff --git a/r/DESCRIPTION b/r/DESCRIPTION index a65e03a466a53..867829c7f80c6 100644 --- a/r/DESCRIPTION +++ b/r/DESCRIPTION @@ -95,6 +95,7 @@ Collate: 'dictionary.R' 'dplyr-across.R' 'dplyr-arrange.R' + 'dplyr-by.R' 'dplyr-collect.R' 'dplyr-count.R' 'dplyr-datetime-helpers.R' diff --git a/r/NAMESPACE b/r/NAMESPACE index ba791cb415728..06edcb31a8a29 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -436,6 +436,7 @@ importFrom(rlang,call2) importFrom(rlang,call_args) importFrom(rlang,caller_env) importFrom(rlang,check_dots_empty) +importFrom(rlang,check_dots_empty0) importFrom(rlang,dots_list) importFrom(rlang,dots_n) importFrom(rlang,enexpr) diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index 76c420e21fac6..6105a062a879e 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -26,7 +26,7 @@ #' @importFrom rlang expr caller_env is_character quo_name is_quosure enexpr enexprs as_quosure #' @importFrom rlang is_list call2 is_empty as_function as_label arg_match is_symbol is_call call_args #' @importFrom rlang quo_set_env quo_get_env is_formula quo_is_call f_rhs parse_expr f_env new_quosure -#' @importFrom rlang new_quosures expr_text caller_env check_dots_empty dots_list is_string inform +#' @importFrom rlang new_quosures expr_text caller_env check_dots_empty check_dots_empty0 dots_list is_string inform #' @importFrom tidyselect vars_pull eval_select eval_rename #' @importFrom glue glue #' @useDynLib arrow, .registration = TRUE diff --git a/r/R/dplyr-by.R b/r/R/dplyr-by.R new file mode 100644 index 0000000000000..ac80cd2ea00c2 --- /dev/null +++ b/r/R/dplyr-by.R @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +compute_by <- function(by, data, ..., by_arg = "by", data_arg = "data", error_call = caller_env()) { + check_dots_empty0(...) 
+ + by <- enquo(by) + check_by(by, data, by_arg = by_arg, data_arg = data_arg, error_call = error_call) + + if (is_grouped_adq(data)) { + names <- data$group_by_vars + from_by <- FALSE + } else { + names <- eval_select_by(by, data, error_call = error_call) + from_by <- TRUE + } + + new_by(from_by = from_by, names = names) +} + +is_grouped_adq <- function(data) { + !is_empty(data$group_by_vars) +} + +check_by <- function(by, + data, + ..., + by_arg = "by", + data_arg = "data", + error_call = caller_env()) { + check_dots_empty0(...) + + if (quo_is_null(by)) { + return(invisible(NULL)) + } + + if (is_grouped_adq(data)) { + message <- paste0( + "Can't supply `", by_arg, "` when `", + data_arg, "` is grouped data." + ) + abort(message) + } + + invisible(NULL) +} + +eval_select_by <- function(by, + data, + error_call = caller_env()) { + sim_df <- as.data.frame(implicit_schema(data)) + out <- eval_select( + expr = by, + data = sim_df, + allow_rename = FALSE, + error_call = error_call + ) + names(out) +} + +new_by <- function(from_by, names) { + structure(list(from_by = from_by, names = names), class = "arrow_dplyr_query_by") +} diff --git a/r/R/dplyr-filter.R b/r/R/dplyr-filter.R index a864f1e9ce9ca..c14c67e70168c 100644 --- a/r/R/dplyr-filter.R +++ b/r/R/dplyr-filter.R @@ -18,25 +18,31 @@ # The following S3 methods are registered on load if dplyr is present -filter.arrow_dplyr_query <- function(.data, ..., .preserve = FALSE) { - .data <- as_adq(.data) +filter.arrow_dplyr_query <- function(.data, ..., .by = NULL, .preserve = FALSE) { # TODO something with the .preserve argument - filts <- expand_across(.data, quos(...)) + out <- as_adq(.data) + + by <- compute_by({{ .by }}, out, by_arg = ".by", data_arg = ".data") + + if (by$from_by) { + out$group_by_vars <- by$names + } + + filts <- expand_across(out, quos(...)) if (length(filts) == 0) { # Nothing to do - return(.data) + return(as_adq(.data)) } - .data <- as_adq(.data) # tidy-eval the filter expressions inside an Arrow data_mask - filters <- lapply(filts, arrow_eval, arrow_mask(.data)) + filters <- lapply(filts, arrow_eval, arrow_mask(out)) bad_filters <- map_lgl(filters, ~ inherits(., "try-error")) if (any(bad_filters)) { # This is similar to abandon_ship() except that the filter eval is # vectorized, and we apply filters that _did_ work before abandoning ship # with the rest expr_labs <- map_chr(filts[bad_filters], format_expr) - if (query_on_dataset(.data)) { + if (query_on_dataset(out)) { # Abort. We don't want to auto-collect if this is a Dataset because that # could blow up, too big. stop( @@ -61,12 +67,21 @@ filter.arrow_dplyr_query <- function(.data, ..., .preserve = FALSE) { call. 
= FALSE ) # Set any valid filters first, then collect and then apply the invalid ones in R - .data <- set_filters(.data, filters[!bad_filters]) - return(dplyr::filter(dplyr::collect(.data), !!!filts[bad_filters])) + out <- dplyr::collect(set_filters(out, filters[!bad_filters])) + if (by$from_by) { + out <- dplyr::ungroup(out) + } + return(dplyr::filter(out, !!!filts[bad_filters], .by = {{ .by }})) } } - set_filters(.data, filters) + out <- set_filters(out, filters) + + if (by$from_by) { + out$group_by_vars <- character() + } + + out } filter.Dataset <- filter.ArrowTabular <- filter.RecordBatchReader <- filter.arrow_dplyr_query diff --git a/r/R/dplyr-mutate.R b/r/R/dplyr-mutate.R index 638a4566ab1fe..287532dee08a9 100644 --- a/r/R/dplyr-mutate.R +++ b/r/R/dplyr-mutate.R @@ -20,14 +20,20 @@ mutate.arrow_dplyr_query <- function(.data, ..., + .by = NULL, .keep = c("all", "used", "unused", "none"), .before = NULL, .after = NULL) { call <- match.call() - .data <- as_adq(.data) - grv <- .data$group_by_vars + out <- as_adq(.data) - expression_list <- expand_across(.data, quos(...), exclude_cols = grv) + by <- compute_by({{ .by }}, out, by_arg = ".by", data_arg = ".data") + + if (by$from_by) { + out$group_by_vars <- by$names + } + grv <- out$group_by_vars + expression_list <- expand_across(out, quos(...), exclude_cols = grv) exprs <- ensure_named_exprs(expression_list) .keep <- match.arg(.keep) @@ -36,7 +42,7 @@ mutate.arrow_dplyr_query <- function(.data, if (.keep %in% c("all", "unused") && length(exprs) == 0) { # Nothing to do - return(.data) + return(out) } # Restrict the cases we support for now @@ -49,7 +55,7 @@ mutate.arrow_dplyr_query <- function(.data, return(abandon_ship(call, .data, "window functions not currently supported in Arrow")) } - mask <- arrow_mask(.data) + mask <- arrow_mask(out) results <- list() for (i in seq_along(exprs)) { # Iterate over the indices and not the names because names may be repeated @@ -75,23 +81,23 @@ mutate.arrow_dplyr_query <- function(.data, mask[[new_var]] <- mask$.data[[new_var]] <- results[[new_var]] } - old_vars <- names(.data$selected_columns) + old_vars <- names(out$selected_columns) # Note that this is names(exprs) not names(results): # if results$new_var is NULL, that means we are supposed to remove it new_vars <- names(exprs) - # Assign the new columns into the .data$selected_columns + # Assign the new columns into the out$selected_columns for (new_var in new_vars) { - .data$selected_columns[[new_var]] <- results[[new_var]] + out$selected_columns[[new_var]] <- results[[new_var]] } # Deduplicate new_vars and remove NULL columns from new_vars - new_vars <- intersect(union(new_vars, grv), names(.data$selected_columns)) + new_vars <- intersect(union(new_vars, grv), names(out$selected_columns)) # Respect .before and .after if (!quo_is_null(.before) || !quo_is_null(.after)) { new <- setdiff(new_vars, old_vars) - .data <- dplyr::relocate(.data, all_of(new), .before = !!.before, .after = !!.after) + out <- dplyr::relocate(out, all_of(new), .before = !!.before, .after = !!.after) } # Respect .keep @@ -99,19 +105,24 @@ mutate.arrow_dplyr_query <- function(.data, ## for consistency with dplyr, this appends new columns after existing columns ## by specifying the order new_cols_last <- c(intersect(old_vars, new_vars), setdiff(new_vars, old_vars)) - .data$selected_columns <- .data$selected_columns[new_cols_last] + out$selected_columns <- out$selected_columns[new_cols_last] } else if (.keep != "all") { # "used" or "unused" used_vars <- unlist(lapply(exprs, 
all.vars), use.names = FALSE) if (.keep == "used") { - .data$selected_columns[setdiff(old_vars, used_vars)] <- NULL + out$selected_columns[setdiff(old_vars, used_vars)] <- NULL } else { # "unused" - .data$selected_columns[intersect(old_vars, used_vars)] <- NULL + out$selected_columns[intersect(old_vars, used_vars)] <- NULL } } + + if (by$from_by) { + out$group_by_vars <- character() + } + # Even if "none", we still keep group vars - ensure_group_vars(.data) + ensure_group_vars(out) } mutate.Dataset <- mutate.ArrowTabular <- mutate.RecordBatchReader <- mutate.arrow_dplyr_query diff --git a/r/R/dplyr-slice.R b/r/R/dplyr-slice.R index ba7ec5fc44aa7..bcb6547f7c8e9 100644 --- a/r/R/dplyr-slice.R +++ b/r/R/dplyr-slice.R @@ -18,10 +18,8 @@ # The following S3 methods are registered on load if dplyr is present -slice_head.arrow_dplyr_query <- function(.data, ..., n, prop) { - if (length(dplyr::group_vars(.data)) > 0) { - arrow_not_supported("Slicing grouped data") - } +slice_head.arrow_dplyr_query <- function(.data, ..., n, prop, by = NULL) { + check_not_grouped(.data, {{ by }}) check_dots_empty() if (missing(n)) { @@ -32,10 +30,8 @@ slice_head.arrow_dplyr_query <- function(.data, ..., n, prop) { } slice_head.Dataset <- slice_head.ArrowTabular <- slice_head.RecordBatchReader <- slice_head.arrow_dplyr_query -slice_tail.arrow_dplyr_query <- function(.data, ..., n, prop) { - if (length(dplyr::group_vars(.data)) > 0) { - arrow_not_supported("Slicing grouped data") - } +slice_tail.arrow_dplyr_query <- function(.data, ..., n, prop, by = NULL) { + check_not_grouped(.data, {{ by }}) check_dots_empty() if (missing(n)) { @@ -46,10 +42,8 @@ slice_tail.arrow_dplyr_query <- function(.data, ..., n, prop) { } slice_tail.Dataset <- slice_tail.ArrowTabular <- slice_tail.RecordBatchReader <- slice_tail.arrow_dplyr_query -slice_min.arrow_dplyr_query <- function(.data, order_by, ..., n, prop, with_ties = TRUE) { - if (length(dplyr::group_vars(.data)) > 0) { - arrow_not_supported("Slicing grouped data") - } +slice_min.arrow_dplyr_query <- function(.data, order_by, ..., n, prop, by = NULL, with_ties = TRUE) { + check_not_grouped(.data, {{ by }}) if (with_ties) { arrow_not_supported("with_ties = TRUE") } @@ -63,10 +57,8 @@ slice_min.arrow_dplyr_query <- function(.data, order_by, ..., n, prop, with_ties } slice_min.Dataset <- slice_min.ArrowTabular <- slice_min.RecordBatchReader <- slice_min.arrow_dplyr_query -slice_max.arrow_dplyr_query <- function(.data, order_by, ..., n, prop, with_ties = TRUE) { - if (length(dplyr::group_vars(.data)) > 0) { - arrow_not_supported("Slicing grouped data") - } +slice_max.arrow_dplyr_query <- function(.data, order_by, ..., n, prop, by = NULL, with_ties = TRUE) { + check_not_grouped(.data, {{ by }}) if (with_ties) { arrow_not_supported("with_ties = TRUE") } @@ -91,11 +83,10 @@ slice_sample.arrow_dplyr_query <- function(.data, ..., n, prop, + by = NULL, weight_by = NULL, replace = FALSE) { - if (length(dplyr::group_vars(.data)) > 0) { - arrow_not_supported("Slicing grouped data") - } + check_not_grouped(.data, {{ by }}) if (replace) { arrow_not_supported("Sampling with replacement") } @@ -168,3 +159,11 @@ n_to_prop <- function(.data, n) { } n / nrows } + + +check_not_grouped <- function(.data, by) { + by <- enquo(by) + if (length(dplyr::group_vars(.data)) > 0 || !quo_is_null(by)) { + arrow_not_supported("Slicing grouped data") + } +} diff --git a/r/R/dplyr-summarize.R b/r/R/dplyr-summarize.R index c02b6ee52210c..4eb24ab8b86eb 100644 --- a/r/R/dplyr-summarize.R +++ b/r/R/dplyr-summarize.R @@ 
-169,14 +169,23 @@ agg_funcs[["::"]] <- function(lhs, rhs) { # The following S3 methods are registered on load if dplyr is present -summarise.arrow_dplyr_query <- function(.data, ..., .groups = NULL) { +summarise.arrow_dplyr_query <- function(.data, ..., .by = NULL, .groups = NULL) { call <- match.call() - .data <- as_adq(.data) - exprs <- expand_across(.data, quos(...), exclude_cols = .data$group_by_vars) + out <- as_adq(.data) + + by <- compute_by({{ .by }}, out, by_arg = ".by", data_arg = ".data") + + if (by$from_by) { + out$group_by_vars <- by$names + .groups <- "drop" + } + + exprs <- expand_across(out, quos(...), exclude_cols = out$group_by_vars) + # Only retain the columns we need to do our aggregations vars_to_keep <- unique(c( unlist(lapply(exprs, all.vars)), # vars referenced in summarise - dplyr::group_vars(.data) # vars needed for grouping + dplyr::group_vars(out) # vars needed for grouping )) # If exprs rely on the results of previous exprs # (total = sum(x), mean = total / n()) @@ -185,15 +194,15 @@ summarise.arrow_dplyr_query <- function(.data, ..., .groups = NULL) { # Note that this select() isn't useful for the Arrow summarize implementation # because it will effectively project to keep what it needs anyway, # but the data.frame fallback version does benefit from select here - .data <- dplyr::select(.data, intersect(vars_to_keep, names(.data))) + out <- dplyr::select(out, intersect(vars_to_keep, names(out))) # Try stuff, if successful return() - out <- try(do_arrow_summarize(.data, !!!exprs, .groups = .groups), silent = TRUE) + out <- try(do_arrow_summarize(out, !!!exprs, .groups = .groups), silent = TRUE) if (inherits(out, "try-error")) { - return(abandon_ship(call, .data, format(out))) - } else { - return(out) + out <- abandon_ship(call, .data, format(out)) } + + out } summarise.Dataset <- summarise.ArrowTabular <- summarise.RecordBatchReader <- summarise.arrow_dplyr_query @@ -315,7 +324,7 @@ summarize_projection <- function(.data) { c( unlist(unname(imap( .data$aggregations, - ~set_names( + ~ set_names( .x$data, aggregate_target_names(.x$data, .y) ) @@ -345,7 +354,7 @@ aggregate_types <- function(.data, hash, schema = NULL) { if (hash) dummy_groups <- Scalar$create(1L, uint32()) map( .data$aggregations, - ~if (hash) { + ~ if (hash) { Expression$create( paste0("hash_", .$fun), # hash aggregate kernels must be passed an additional argument @@ -367,11 +376,11 @@ aggregate_types <- function(.data, hash, schema = NULL) { # This function returns a named list of the data types of the group columns # returned by an aggregation group_types <- function(.data, schema = NULL) { - map(.data$selected_columns[.data$group_by_vars], ~.$type(schema)) + map(.data$selected_columns[.data$group_by_vars], ~ .$type(schema)) } format_aggregation <- function(x) { - paste0(x$fun, "(", paste(map(x$data, ~.$ToString()), collapse = ","), ")") + paste0(x$fun, "(", paste(map(x$data, ~ .$ToString()), collapse = ","), ")") } # This function handles each summarize expression and turns it into the diff --git a/r/tests/testthat/test-dplyr-filter.R b/r/tests/testthat/test-dplyr-filter.R index 8b144f47852f0..724b93c96609f 100644 --- a/r/tests/testthat/test-dplyr-filter.R +++ b/r/tests/testthat/test-dplyr-filter.R @@ -425,3 +425,56 @@ test_that("filter() with across()", { tbl ) }) + +test_that(".by argument", { + compare_dplyr_binding( + .input %>% + filter(is.na(lgl), .by = chr) %>% + select(chr, int, lgl) %>% + collect(), + tbl + ) + compare_dplyr_binding( + .input %>% + filter(is.na(lgl), .by = 
starts_with("chr")) %>% + select(chr, int, lgl) %>% + collect(), + tbl + ) + compare_dplyr_binding( + .input %>% + filter(.by = chr) %>% + select(chr, int, lgl) %>% + collect(), + tbl + ) + compare_dplyr_binding( + .input %>% + filter(.by = c(int, chr)) %>% + select(chr, int, lgl) %>% + collect(), + tbl + ) + compare_dplyr_binding( + .input %>% + filter(.by = c("int", "chr")) %>% + select(chr, int, lgl) %>% + collect(), + tbl + ) + # filter should pulling not grouped data into R when using the .by argument + compare_dplyr_binding( + .input %>% + filter(int > 2, pnorm(dbl) > .99, .by = chr) %>% + collect(), + tbl, + warning = "Expression pnorm\\(dbl\\) > 0.99 not supported in Arrow; pulling data into R" + ) + expect_error( + tbl %>% + arrow_table() %>% + group_by(chr) %>% + filter(is.na(lgl), .by = chr), + "Can't supply `\\.by` when `\\.data` is grouped data" + ) +}) diff --git a/r/tests/testthat/test-dplyr-mutate.R b/r/tests/testthat/test-dplyr-mutate.R index 79554059cb38b..0889fffedd508 100644 --- a/r/tests/testthat/test-dplyr-mutate.R +++ b/r/tests/testthat/test-dplyr-mutate.R @@ -441,6 +441,78 @@ test_that("Can mutate after group_by as long as there are no aggregations", { ) }) +test_that("Can mutate with .by argument as long as there are no aggregations", { + compare_dplyr_binding( + .input %>% + select(int, chr) %>% + mutate(int = int + 6L, .by = chr) %>% + collect(), + tbl + ) + compare_dplyr_binding( + .input %>% + select(int, chr) %>% + mutate(int = int + 6L, .by = starts_with("chr")) %>% + collect(), + tbl + ) + compare_dplyr_binding( + .input %>% + select(int, chr) %>% + mutate(new_col = int + 6L, .by = c(chr, int)) %>% + collect(), + tbl + ) + compare_dplyr_binding( + .input %>% + select(int, chr) %>% + mutate(new_col = int + 6L, .by = c("chr", "int")) %>% + collect(), + tbl + ) + compare_dplyr_binding( + .input %>% + select(mean = int, chr) %>% + # rename `int` to `mean` and use `mean` in `mutate()` to test that + # `all_funs()` does not incorrectly identify it as an aggregate function + mutate(mean = mean + 6L, .by = chr) %>% + collect(), + tbl + ) + expect_warning( + tbl %>% + Table$create() %>% + select(int, chr) %>% + mutate(avg_int = mean(int), .by = chr) %>% + collect(), + "window functions not currently supported in Arrow; pulling data into R", + fixed = TRUE + ) + expect_warning( + tbl %>% + Table$create() %>% + select(mean = int, chr) %>% + # rename `int` to `mean` and use `mean(mean)` in `mutate()` to test that + # `all_funs()` detects `mean()` despite the collision with a column name + mutate(avg_int = mean(mean), .by = chr) %>% + collect(), + "window functions not currently supported in Arrow; pulling data into R", + fixed = TRUE + ) +}) + +test_that("Can't supply .by after group_by", { + expect_error( + tbl %>% + arrow_table() %>% + select(int, chr) %>% + group_by(chr) %>% + mutate(int = int + 6L, .by = chr) %>% + collect(), + "Can't supply `\\.by` when `\\.data` is grouped data" + ) +}) + test_that("handle bad expressions", { # TODO: search for functions other than mean() (see above test) # that need to be forced to fail because they error ambiguously diff --git a/r/tests/testthat/test-dplyr-slice.R b/r/tests/testthat/test-dplyr-slice.R index 6d0711589c22c..3b103d2e3cd8a 100644 --- a/r/tests/testthat/test-dplyr-slice.R +++ b/r/tests/testthat/test-dplyr-slice.R @@ -158,6 +158,28 @@ test_that("slice_* not supported with groups", { slice_sample(grouped, n = 5), "Slicing grouped data not supported in Arrow" ) + + # with the by argument + expect_error( + 
slice_head(arrow_table(tbl), n = 5, by = lgl), + "Slicing grouped data not supported in Arrow" + ) + expect_error( + slice_tail(arrow_table(tbl), n = 5, by = lgl), + "Slicing grouped data not supported in Arrow" + ) + expect_error( + slice_min(arrow_table(tbl), int, n = 5, by = lgl), + "Slicing grouped data not supported in Arrow" + ) + expect_error( + slice_max(arrow_table(tbl), int, n = 5, by = lgl), + "Slicing grouped data not supported in Arrow" + ) + expect_error( + slice_sample(arrow_table(tbl), n = 5, by = lgl), + "Slicing grouped data not supported in Arrow" + ) }) test_that("input validation", { diff --git a/r/tests/testthat/test-dplyr-summarize.R b/r/tests/testthat/test-dplyr-summarize.R index 09f50986d7510..e2fb9841e72ad 100644 --- a/r/tests/testthat/test-dplyr-summarize.R +++ b/r/tests/testthat/test-dplyr-summarize.R @@ -1194,3 +1194,26 @@ test_that("across() does not select grouping variables within summarise()", { "Column `int` doesn't exist" ) }) + +test_that(".by argument", { + compare_dplyr_binding( + .input %>% + summarize(total = sum(int, na.rm = TRUE), .by = some_grouping) %>% + collect(), + tbl + ) + compare_dplyr_binding( + .input %>% + summarize(total = sum(int, na.rm = TRUE), .by = starts_with("dbl")) %>% + arrange(dbl) %>% + collect(), + tbl + ) + expect_error( + tbl %>% + arrow_table() %>% + group_by(some_grouping) %>% + summarize(total = sum(int, na.rm = TRUE), .by = starts_with("dbl")), + "Can't supply `\\.by` when `\\.data` is grouped data" + ) +}) From dd267572a65272dd01c30a8df15c3b43cf6ba007 Mon Sep 17 00:00:00 2001 From: David Greiss Date: Wed, 7 Jun 2023 04:21:51 -0400 Subject: [PATCH 18/31] GH-35709: [R][Documentation] Document passing data to duckdb for windowed aggregates (#35882) ### Rationale for this change #35702 documents how to use joins for computing windowed aggregates. This documents an alternative solution by passing data to duckdb. This use case was also mentioned on the [duckdb blog](https://duckdb.org/2021/12/03/duck-arrow.html). ### What changes are included in this PR? Changes to vignette. * Closes: #35709 Authored-by: David Greiss Signed-off-by: Nic Crane --- r/vignettes/data_wrangling.Rmd | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/r/vignettes/data_wrangling.Rmd b/r/vignettes/data_wrangling.Rmd index bad1d4bd58f35..e3d5b306f3e71 100644 --- a/r/vignettes/data_wrangling.Rmd +++ b/r/vignettes/data_wrangling.Rmd @@ -1,7 +1,7 @@ --- title: "Data analysis with dplyr syntax" description: > - Learn how to use the dplyr backend supplied by arrow + Learn how to use the dplyr backend supplied by arrow output: rmarkdown::html_vignette --- @@ -61,7 +61,7 @@ sw %>% collect() ``` -Note, however, that window functions such as `ntile()` are not yet supported. +Note, however, that window functions such as `ntile()` are not yet supported. ## Two-table dplyr verbs @@ -109,7 +109,7 @@ register_scalar_function( ) ``` -In this expression, the `name` argument specifies the name by which it will be recognized in the context of the arrow/dplyr pipeline and `fun` is the function itself. The `in_type` and `out_type` arguments are used to specify the expected data type for the input and output, and `auto_convert` specifies whether arrow should automatically convert any R inputs to their Arrow equivalents. +In this expression, the `name` argument specifies the name by which it will be recognized in the context of the arrow/dplyr pipeline and `fun` is the function itself. 
The `in_type` and `out_type` arguments are used to specify the expected data type for the input and output, and `auto_convert` specifies whether arrow should automatically convert any R inputs to their Arrow equivalents.
 
 Once registered, the following works:
 
@@ -119,7 +119,7 @@ sw %>%
   collect()
 ```
 
-To learn more, see `help("register_scalar_function", package = "arrow")`. 
+To learn more, see `help("register_scalar_function", package = "arrow")`.
 
 ## Handling unsupported expressions
 
@@ -127,7 +127,7 @@ For dplyr queries on Table objects, which are held in memory and should
 usually be representable as data frames, if the arrow package detects
 an unimplemented function within a dplyr verb, it automatically calls
 `collect()` to return the data as an R data frame before processing
-that dplyr verb. As an example, neither `lm()` nor `residuals()` are 
+that dplyr verb. As an example, neither `lm()` nor `residuals()` are
 implemented, so if we write code that computes the residuals for a
 linear regression model, this automatic collection takes place:
 
@@ -139,9 +139,9 @@ sw %>%
 
 For queries on `Dataset` objects -- which can be larger than memory -- arrow
 is more conservative and always raises an
-error if it detects an unsupported expression. To illustrate this 
+error if it detects an unsupported expression. To illustrate this
 behavior, we can write the `starwars` data to disk and then open
-it as a Dataset. When we use the same pipeline on the Dataset, 
+it as a Dataset. When we use the same pipeline on the Dataset,
 we obtain an error:
 
 ```{r, error=TRUE}
@@ -165,7 +165,7 @@ sw2 %>%
   transmute(name, height, mass, res = residuals(lm(mass ~ height)))
 ```
 
-Because window functions are not supported, computing an aggregation like `mean()` on a grouped table or within a rowwise opertation like `filter()` is not supported: 
+Because window functions are not supported, computing an aggregation like `mean()` on a grouped table or within a rowwise operation like `filter()` is not supported:
 
 ```{r}
 sw %>%
@@ -175,7 +175,7 @@ sw %>%
   filter(height < mean(height, na.rm = TRUE))
 ```
 
-This operation can be accomplished in arrow by computing the aggregation separately, for example within a join operation: 
+This operation is sometimes referred to as a windowed aggregate and can be accomplished in Arrow by computing the aggregation separately, for example within a join operation:
 
 ```{r}
 sw %>%
@@ -191,9 +191,22 @@ sw %>%
   collect()
 ```
 
+Alternatively, [DuckDB](https://duckdb.org) supports Arrow natively, so you can pass the `Table` object to DuckDB using the helper function `to_duckdb()` without paying a performance penalty, and pass the object back to Arrow with `to_arrow()`:
+
+```{r}
+sw %>%
+  select(1:4) %>%
+  filter(!is.na(hair_color)) %>%
+  to_duckdb() %>%
+  group_by(hair_color) %>%
+  filter(height < mean(height, na.rm = TRUE)) %>%
+  to_arrow() %>%
+  # perform other arrow operations...
+  collect()
+```
+
 ## Further reading
 
 - To learn more about multi-file datasets, see the [dataset article](./dataset.html).
 - To learn more about user-registered functions, see `help("register_scalar_function", package = "arrow")`.
-- To learn more about writing dplyr bindings as an arrow developer, see the [article on writing bindings](./developers/writing_bindings.html). 
+- To learn more about writing dplyr bindings as an arrow developer, see the [article on writing bindings](./developers/writing_bindings.html).
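As a further illustration of the `to_duckdb()`/`to_arrow()` hand-off documented above, the following is a minimal sketch of the same idea applied to a grouped `mutate()` instead of a grouped `filter()`. It is an illustrative sketch only, not part of the patch: it assumes the duckdb and dbplyr packages are installed, that `sw` is the same Arrow Table used throughout the vignette, and the column name `mean_height` is invented for the example.

```r
sw %>%
  select(name, species, height) %>%
  filter(!is.na(height)) %>%       # this filter is still evaluated by arrow
  to_duckdb() %>%                  # hand the data to DuckDB without copying it
  group_by(species) %>%
  # dbplyr translates this grouped mutate() into a SQL window function
  # (roughly AVG(height) OVER (PARTITION BY species)), which DuckDB evaluates
  mutate(mean_height = mean(height, na.rm = TRUE)) %>%
  ungroup() %>%
  to_arrow() %>%                   # hand the result back to arrow
  collect()
```

As with the `filter()` example in the vignette, the window expression is evaluated by DuckDB rather than by arrow; calling `ungroup()` before `to_arrow()` returns an ungrouped object to arrow for any further processing.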
From 1d758162044f806cd10d2e7a0953ea31f48a8594 Mon Sep 17 00:00:00 2001 From: Alenka Frim Date: Wed, 7 Jun 2023 14:02:08 +0200 Subject: [PATCH 19/31] GH-33980: [Docs][Python] Document DataFrame Interchange Protocol implementation and usage (#35835) _edit: just added something_ * Closes: #33980 Lead-authored-by: AlenkaF Co-authored-by: Alenka Frim Co-authored-by: Joris Van den Bossche Signed-off-by: Joris Van den Bossche --- docs/source/conf.py | 1 + docs/source/python/api/tables.rst | 8 ++ docs/source/python/index.rst | 1 + docs/source/python/interchange_protocol.rst | 119 +++++++++++++++++++ python/pyarrow/interchange/from_dataframe.py | 25 ++++ 5 files changed, 154 insertions(+) create mode 100644 docs/source/python/interchange_protocol.rst diff --git a/docs/source/conf.py b/docs/source/conf.py index 19b0c353bdc71..8a05641525ead 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -79,6 +79,7 @@ # Show members for classes in .. autosummary autodoc_default_options = { 'members': None, + 'special-members': '__dataframe__', 'undoc-members': None, 'show-inheritance': None, 'inherited-members': None diff --git a/docs/source/python/api/tables.rst b/docs/source/python/api/tables.rst index eadf40cb759b3..ae9f5de127dfd 100644 --- a/docs/source/python/api/tables.rst +++ b/docs/source/python/api/tables.rst @@ -46,6 +46,14 @@ Classes TableGroupBy RecordBatchReader +Dataframe Interchange Protocol +------------------------------ + +.. autosummary:: + :toctree: ../generated/ + + interchange.from_dataframe + .. _api.tensor: Tensors diff --git a/docs/source/python/index.rst b/docs/source/python/index.rst index 77cfaef4a408a..b80cbc7de594e 100644 --- a/docs/source/python/index.rst +++ b/docs/source/python/index.rst @@ -47,6 +47,7 @@ files into Arrow structures. filesystems_deprecated numpy pandas + interchange_protocol timestamps orc csv diff --git a/docs/source/python/interchange_protocol.rst b/docs/source/python/interchange_protocol.rst new file mode 100644 index 0000000000000..7784d78619e6e --- /dev/null +++ b/docs/source/python/interchange_protocol.rst @@ -0,0 +1,119 @@ +.. Licensed to the Apache Software Foundation (ASF) under one +.. or more contributor license agreements. See the NOTICE file +.. distributed with this work for additional information +.. regarding copyright ownership. The ASF licenses this file +.. to you under the Apache License, Version 2.0 (the +.. "License"); you may not use this file except in compliance +.. with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, +.. software distributed under the License is distributed on an +.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +.. KIND, either express or implied. See the License for the +.. specific language governing permissions and limitations +.. under the License. + +Dataframe Interchange Protocol +============================== + +The interchange protocol is implemented for ``pa.Table`` and +``pa.RecordBatch`` and is used to interchange data between +PyArrow and other dataframe libraries that also have the +protocol implemented. The data structures that are supported +in the protocol are primitive data types plus the dictionary +data type. The protocol also has missing data support and +it supports chunking, meaning accessing the +data in “batches” of rows. 
+ + +The Python dataframe interchange protocol is designed by the +`Consortium for Python Data API Standards `_ +in order to enable data interchange between dataframe +libraries in the Python ecosystem. See more about the +standard in the +`protocol documentation `_. + +From pyarrow to other libraries: ``__dataframe__()`` method +----------------------------------------------------------- + +The ``__dataframe__()`` method creates a new exchange object that +the consumer library can take and construct an object of it's own. + +.. code-block:: + + >>> import pyarrow as pa + >>> table = pa.table({"n_atendees": [100, 10, 1]}) + >>> table.__dataframe__() + + +This is meant to be used by the consumer library when calling +the ``from_dataframe()`` function and is not meant to be used manually +by the user. + +From other libraries to pyarrow: ``from_dataframe()`` +----------------------------------------------------- + +With the ``from_dataframe()`` function, we can construct a :class:`pyarrow.Table` +from any dataframe object that implements the +``__dataframe__()`` method via the dataframe interchange +protocol. + +We can for example take a pandas dataframe and construct a +pyarrow table with the use of the interchange protocol: + +.. code-block:: + + >>> import pyarrow + >>> from pyarrow.interchange import from_dataframe + + >>> import pandas as pd + >>> df = pd.DataFrame({ + ... "n_atendees": [100, 10, 1], + ... "country": ["Italy", "Spain", "Slovenia"], + ... }) + >>> df + n_atendees country + 0 100 Italy + 1 10 Spain + 2 1 Slovenia + >>> from_dataframe(df) + pyarrow.Table + n_atendees: int64 + country: large_string + ---- + n_atendees: [[100,10,1]] + country: [["Italy","Spain","Slovenia"]] + +We can do the same with a polars dataframe: + +.. code-block:: + + >>> import polars as pl + >>> from datetime import datetime + >>> arr = [datetime(2023, 5, 20, 10, 0), + ... datetime(2023, 5, 20, 11, 0), + ... datetime(2023, 5, 20, 13, 30)] + >>> df = pl.DataFrame({ + ... 'Talk': ['About Polars','Intro into PyArrow','Coding in Rust'], + ... 'Time': arr, + ... }) + >>> df + shape: (3, 2) + ┌────────────────────┬─────────────────────┐ + │ Talk ┆ Time │ + │ --- ┆ --- │ + │ str ┆ datetime[μs] │ + ╞════════════════════╪═════════════════════╡ + │ About Polars ┆ 2023-05-20 10:00:00 │ + │ Intro into PyArrow ┆ 2023-05-20 11:00:00 │ + │ Coding in Rust ┆ 2023-05-20 13:30:00 │ + └────────────────────┴─────────────────────┘ + >>> from_dataframe(df) + pyarrow.Table + Talk: large_string + Time: timestamp[us] + ---- + Talk: [["About Polars","Intro into PyArrow","Coding in Rust"]] + Time: [[2023-05-20 10:00:00.000000,2023-05-20 11:00:00.000000,2023-05-20 13:30:00.000000]] diff --git a/python/pyarrow/interchange/from_dataframe.py b/python/pyarrow/interchange/from_dataframe.py index 801d0dd452a97..1d41aa8d7eef3 100644 --- a/python/pyarrow/interchange/from_dataframe.py +++ b/python/pyarrow/interchange/from_dataframe.py @@ -74,6 +74,31 @@ def from_dataframe(df: DataFrameObject, allow_copy=True) -> pa.Table: Returns ------- pa.Table + + Examples + -------- + >>> import pyarrow + >>> from pyarrow.interchange import from_dataframe + + Convert a pandas dataframe to a pyarrow table: + + >>> import pandas as pd + >>> df = pd.DataFrame({ + ... "n_atendees": [100, 10, 1], + ... "country": ["Italy", "Spain", "Slovenia"], + ... 
}) + >>> df + n_atendees country + 0 100 Italy + 1 10 Spain + 2 1 Slovenia + >>> from_dataframe(df) + pyarrow.Table + n_atendees: int64 + country: large_string + ---- + n_atendees: [[100,10,1]] + country: [["Italy","Spain","Slovenia"]] """ if isinstance(df, pa.Table): return df From 96302584c849dcf6091eb2fe2078351d030f0dc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ra=C3=BAl=20Cumplido?= Date: Wed, 7 Jun 2023 16:40:04 +0200 Subject: [PATCH 20/31] GH-35946: [CI][Packaging] Free up more disk space for Linux packages (#35947) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Rationale for this change Fix some Linux packaging jobs that fail due to out of space ### What changes are included in this PR? Cleaning up some more cache ### Are these changes tested? With crossbow jobs ### Are there any user-facing changes? No * Closes: #35946 Authored-by: Raúl Cumplido Signed-off-by: Raúl Cumplido --- dev/tasks/linux-packages/github.linux.yml | 24 +++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/dev/tasks/linux-packages/github.linux.yml b/dev/tasks/linux-packages/github.linux.yml index d055c439cd17c..3a00849a4d1a0 100644 --- a/dev/tasks/linux-packages/github.linux.yml +++ b/dev/tasks/linux-packages/github.linux.yml @@ -38,13 +38,37 @@ jobs: env.ARCHITECTURE == 'amd64' run: | df -h + du -hsc /opt/* /usr/local/* du -hsc /opt/hostedtoolcache/* + du -hs /usr/local/bin + # ~1GB (From 1.2GB to 214MB) + sudo rm -rf /usr/local/bin/aliyun \ + /usr/local/bin/azcopy \ + /usr/local/bin/bicep \ + /usr/local/bin/cmake-gui \ + /usr/local/bin/cpack \ + /usr/local/bin/helm \ + /usr/local/bin/hub \ + /usr/local/bin/kubectl \ + /usr/local/bin/minikube \ + /usr/local/bin/node \ + /usr/local/bin/packer \ + /usr/local/bin/pulumi* \ + /usr/local/bin/stack \ + /usr/local/bin/terraform + du -hs /usr/local/bin + du -hs /usr/local/share + # 1.3GB + sudo rm -rf /usr/local/share/powershell + du -hs /usr/local/share # 5.3GB sudo rm -rf /opt/hostedtoolcache/CodeQL || : # 1.4GB sudo rm -rf /opt/hostedtoolcache/go || : # 489MB sudo rm -rf /opt/hostedtoolcache/PyPy || : + # 1.2GB + sudo rm -rf /opt/hostedtoolcache/Python || : # 376MB sudo rm -rf /opt/hostedtoolcache/node || : df -h From cd7f431d8cd2808f8b49847fce0df0cc4270229b Mon Sep 17 00:00:00 2001 From: Igor Izvekov Date: Wed, 7 Jun 2023 18:20:43 +0300 Subject: [PATCH 21/31] GH-35965: [Go] Fix `Decimal256DictionaryBuilder` (#35966) ### Rationale for this change `Decimal256DictionaryBuilder` cannot append `decimal256` value and prefers `decimal128` instead. ``` cannot use v (variable of type decimal256.Num) as decimal128.Num value in argument to builder.Append ``` ### What changes are included in this PR? ### Are these changes tested? Yes ### Are there any user-facing changes? 
Yes * Closes: #35965 Authored-by: izveigor Signed-off-by: Matt Topol --- go/arrow/array/dictionary.go | 14 ++++++++++---- go/arrow/array/dictionary_test.go | 31 ++++++++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/go/arrow/array/dictionary.go b/go/arrow/array/dictionary.go index 1cbaac5a2cab7..48f37e91d1012 100644 --- a/go/arrow/array/dictionary.go +++ b/go/arrow/array/dictionary.go @@ -28,6 +28,7 @@ import ( "github.com/apache/arrow/go/v13/arrow" "github.com/apache/arrow/go/v13/arrow/bitutil" "github.com/apache/arrow/go/v13/arrow/decimal128" + "github.com/apache/arrow/go/v13/arrow/decimal256" "github.com/apache/arrow/go/v13/arrow/float16" "github.com/apache/arrow/go/v13/arrow/internal/debug" "github.com/apache/arrow/go/v13/arrow/memory" @@ -45,12 +46,12 @@ import ( // // For example, the array: // -// ["foo", "bar", "foo", "bar", "foo", "bar"] +// ["foo", "bar", "foo", "bar", "foo", "bar"] // // with dictionary ["bar", "foo"], would have the representation of: // -// indices: [1, 0, 1, 0, 1, 0] -// dictionary: ["bar", "foo"] +// indices: [1, 0, 1, 0, 1, 0] +// dictionary: ["bar", "foo"] // // The indices in principle may be any integer type. type Dictionary struct { @@ -883,6 +884,11 @@ func getvalFn(arr arrow.Array) func(i int) interface{} { val := typedarr.Value(i) return (*(*[arrow.Decimal128SizeBytes]byte)(unsafe.Pointer(&val)))[:] } + case *Decimal256: + return func(i int) interface{} { + val := typedarr.Value(i) + return (*(*[arrow.Decimal256SizeBytes]byte)(unsafe.Pointer(&val)))[:] + } case *DayTimeInterval: return func(i int) interface{} { val := typedarr.Value(i) @@ -1373,7 +1379,7 @@ type Decimal256DictionaryBuilder struct { dictionaryBuilder } -func (b *Decimal256DictionaryBuilder) Append(v decimal128.Num) error { +func (b *Decimal256DictionaryBuilder) Append(v decimal256.Num) error { return b.appendValue((*(*[arrow.Decimal256SizeBytes]byte)(unsafe.Pointer(&v)))[:]) } func (b *Decimal256DictionaryBuilder) InsertDictValues(arr *Decimal256) (err error) { diff --git a/go/arrow/array/dictionary_test.go b/go/arrow/array/dictionary_test.go index 23b21e5aba7a5..230cc706a848d 100644 --- a/go/arrow/array/dictionary_test.go +++ b/go/arrow/array/dictionary_test.go @@ -27,6 +27,7 @@ import ( "github.com/apache/arrow/go/v13/arrow/array" "github.com/apache/arrow/go/v13/arrow/bitutil" "github.com/apache/arrow/go/v13/arrow/decimal128" + "github.com/apache/arrow/go/v13/arrow/decimal256" "github.com/apache/arrow/go/v13/arrow/memory" "github.com/apache/arrow/go/v13/internal/types" "github.com/stretchr/testify/assert" @@ -878,7 +879,7 @@ func TestFixedSizeBinaryDictionaryStringRoundTrip(t *testing.T) { assert.True(t, array.Equal(arr, arr1)) } -func TestDecimalDictionaryBuilderBasic(t *testing.T) { +func TestDecimal128DictionaryBuilderBasic(t *testing.T) { mem := memory.NewCheckedAllocator(memory.DefaultAllocator) defer mem.AssertSize(t, 0) @@ -906,6 +907,34 @@ func TestDecimalDictionaryBuilderBasic(t *testing.T) { assert.True(t, array.ArrayApproxEqual(expected, result)) } +func TestDecimal256DictionaryBuilderBasic(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + test := []decimal256.Num{decimal256.FromI64(12), decimal256.FromI64(12), decimal256.FromI64(11), decimal256.FromI64(12)} + dictType := &arrow.DictionaryType{IndexType: &arrow.Int8Type{}, ValueType: &arrow.Decimal256Type{Precision: 2, Scale: 0}} + bldr := array.NewDictionaryBuilder(mem, dictType) + defer bldr.Release() + + builder := 
bldr.(*array.Decimal256DictionaryBuilder) + for _, v := range test { + assert.NoError(t, builder.Append(v)) + } + + result := bldr.NewDictionaryArray() + defer result.Release() + + indices, _, _ := array.FromJSON(mem, dictType.IndexType, strings.NewReader("[0, 0, 1, 0]")) + defer indices.Release() + dict, _, _ := array.FromJSON(mem, dictType.ValueType, strings.NewReader("[12, 11]")) + defer dict.Release() + + expected := array.NewDictionaryArray(dictType, indices, dict) + defer expected.Release() + + assert.True(t, array.ArrayApproxEqual(expected, result)) +} + func TestNullDictionaryBuilderBasic(t *testing.T) { mem := memory.NewCheckedAllocator(memory.DefaultAllocator) defer mem.AssertSize(t, 0) From acf3cbac6bdddb8e2b60e1095f591ee84b17f00c Mon Sep 17 00:00:00 2001 From: Igor Izvekov Date: Wed, 7 Jun 2023 18:22:07 +0300 Subject: [PATCH 22/31] MINOR: [Go] add `TestDecimal256JSON` (#35968) ### Rationale for this change Like `TestDecimal128JSON` ### What changes are included in this PR? ### Are these changes tested? ### Are there any user-facing changes? Authored-by: izveigor Signed-off-by: Matt Topol --- go/arrow/array/util_test.go | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/go/arrow/array/util_test.go b/go/arrow/array/util_test.go index 8b8ddfb464db9..8822c13bf5525 100644 --- a/go/arrow/array/util_test.go +++ b/go/arrow/array/util_test.go @@ -28,6 +28,7 @@ import ( "github.com/apache/arrow/go/v13/arrow" "github.com/apache/arrow/go/v13/arrow/array" "github.com/apache/arrow/go/v13/arrow/decimal128" + "github.com/apache/arrow/go/v13/arrow/decimal256" "github.com/apache/arrow/go/v13/arrow/internal/arrdata" "github.com/apache/arrow/go/v13/arrow/memory" "github.com/goccy/go-json" @@ -417,16 +418,36 @@ func TestDecimal128JSON(t *testing.T) { assert.JSONEq(t, `["123.4567", null, "-78.9"]`, string(data)) } +func TestDecimal256JSON(t *testing.T) { + dt := &arrow.Decimal256Type{Precision: 10, Scale: 4} + bldr := array.NewDecimal256Builder(memory.DefaultAllocator, dt) + defer bldr.Release() + + bldr.AppendValues([]decimal256.Num{decimal256.FromU64(1234567), {}, decimal256.FromI64(-789000)}, []bool{true, false, true}) + expected := bldr.NewArray() + defer expected.Release() + + arr, _, err := array.FromJSON(memory.DefaultAllocator, dt, strings.NewReader(`["123.4567", null, "-78.9000"]`)) + assert.NoError(t, err) + defer arr.Release() + + assert.Truef(t, array.ArrayEqual(expected, arr), "expected: %s\ngot: %s\n", expected, arr) + + data, err := json.Marshal(arr) + assert.NoError(t, err) + assert.JSONEq(t, `["123.4567", null, "-78.9"]`, string(data)) +} + func TestArrRecordsJSONRoundTrip(t *testing.T) { for k, v := range arrdata.Records { - if k == "decimal128" || k == "fixed_width_types" { + if k == "decimal128" || k == "decimal256" || k == "fixed_width_types" { // test these separately since the sample data in the arrdata // records doesn't lend itself to exactness when going to/from // json. The fixed_width_types one uses negative values for // time32 and time64 which correctly get interpreted into times, // but re-encoding them in json produces the normalized positive // values instead of re-creating negative ones. - // the decimal128 values don't get parsed *exactly* due to fun + // the decimal128/decimal256 values don't get parsed *exactly* due to fun // float weirdness due to their size, so smaller tests will work fine. 
continue } From 9be7074f85d6057a92bc008766e8f30d58b30cf7 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 7 Jun 2023 12:32:59 -0400 Subject: [PATCH 23/31] GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute (#35654) ### Rationale for this change Providing the ability to execute more complex expressions than single operations by leveraging Substrait's expression objects and deprecating the existing separate Expression interfaces in Go Arrow compute. This provides a quick integration with Substrait Expressions and ExtendedExpressions to start building more integrations. ### What changes are included in this PR? This PR provides: * an extension registry for Go arrow to provide mappings between Arrow and substrait for functions and for types along with other custom mappings if necessary. * Facilities to convert between Arrow data types and Substrait types * Functions to evaluate Substrait expression objects with Arrow data as the input * Functions to evaluate Substrait field references against Arrow data and arrow schemas ### Are these changes tested? Yes, unit tests are included. ### Are there any user-facing changes? Existing `compute.Expression` and its friends are being marked as deprecated. * Closes: #35652 Authored-by: Matt Topol Signed-off-by: Matt Topol --- .gitignore | 2 + go/arrow/compute/arithmetic.go | 6 +- go/arrow/compute/exec.go | 6 +- go/arrow/compute/executor.go | 24 +- go/arrow/compute/expression.go | 11 +- go/arrow/compute/exprs/builders.go | 445 +++++++++++ go/arrow/compute/exprs/builders_test.go | 92 +++ go/arrow/compute/exprs/exec.go | 620 +++++++++++++++ go/arrow/compute/exprs/exec_internal_test.go | 114 +++ go/arrow/compute/exprs/exec_test.go | 461 ++++++++++++ go/arrow/compute/exprs/extension_types.go | 149 ++++ go/arrow/compute/exprs/field_refs.go | 254 +++++++ go/arrow/compute/exprs/types.go | 745 +++++++++++++++++++ go/arrow/compute/internal/kernels/cast.go | 1 + go/arrow/errors.go | 1 + go/go.mod | 39 +- go/go.sum | 159 ++-- go/internal/hashing/hash_funcs.go | 4 +- go/internal/hashing/hash_string.go | 2 +- go/internal/hashing/hash_string_go1.19.go | 2 +- go/internal/hashing/hashing_test.go | 4 +- go/internal/hashing/xxh3_memo_table.go | 6 +- 22 files changed, 3016 insertions(+), 131 deletions(-) create mode 100644 go/arrow/compute/exprs/builders.go create mode 100644 go/arrow/compute/exprs/builders_test.go create mode 100644 go/arrow/compute/exprs/exec.go create mode 100644 go/arrow/compute/exprs/exec_internal_test.go create mode 100644 go/arrow/compute/exprs/exec_test.go create mode 100644 go/arrow/compute/exprs/extension_types.go create mode 100644 go/arrow/compute/exprs/field_refs.go create mode 100644 go/arrow/compute/exprs/types.go diff --git a/.gitignore b/.gitignore index be4b19a89397c..5d51d258bd8b7 100644 --- a/.gitignore +++ b/.gitignore @@ -95,6 +95,8 @@ swift/Arrow/.build # Go dependencies go/vendor +# go debug binaries +__debug_bin # direnv .envrc diff --git a/go/arrow/compute/arithmetic.go b/go/arrow/compute/arithmetic.go index c53b4dd10fca4..2df547e5b4fa4 100644 --- a/go/arrow/compute/arithmetic.go +++ b/go/arrow/compute/arithmetic.go @@ -627,6 +627,8 @@ func RegisterScalarArithmetic(reg FunctionRegistry) { }{ {"sub_unchecked", kernels.OpSub, decPromoteAdd, subUncheckedDoc}, {"sub", kernels.OpSubChecked, decPromoteAdd, subDoc}, + {"subtract_unchecked", kernels.OpSub, decPromoteAdd, subUncheckedDoc}, + {"subtract", kernels.OpSubChecked, decPromoteAdd, subDoc}, } for _, o := range ops { @@ -1088,8 +1090,8 @@ func 
Negate(ctx context.Context, opts ArithmeticOptions, input Datum) (Datum, er // input. For x in the input: // // if x > 0: 1 -// if x < 0: -1 -// if x == 0: 0 +// if x < 0: -1 +// if x == 0: 0 func Sign(ctx context.Context, input Datum) (Datum, error) { return CallFunction(ctx, "sign", nil, input) } diff --git a/go/arrow/compute/exec.go b/go/arrow/compute/exec.go index 1f6c030f641eb..6dbef8cdfbbd9 100644 --- a/go/arrow/compute/exec.go +++ b/go/arrow/compute/exec.go @@ -77,20 +77,20 @@ func execInternal(ctx context.Context, fn Function, opts FunctionOptions, passed var ( k exec.Kernel - executor kernelExecutor + executor KernelExecutor ) switch fn.Kind() { case FuncScalar: executor = scalarExecPool.Get().(*scalarExecutor) defer func() { - executor.clear() + executor.Clear() scalarExecPool.Put(executor.(*scalarExecutor)) }() case FuncVector: executor = vectorExecPool.Get().(*vectorExecutor) defer func() { - executor.clear() + executor.Clear() vectorExecPool.Put(executor.(*vectorExecutor)) }() default: diff --git a/go/arrow/compute/executor.go b/go/arrow/compute/executor.go index c341e99cb42d9..d3f1a1fd41d4c 100644 --- a/go/arrow/compute/executor.go +++ b/go/arrow/compute/executor.go @@ -88,11 +88,11 @@ var ( // then be modified to set into a context. // // The default exec context uses the following values: -// - ChunkSize = DefaultMaxChunkSize (MaxInt64) -// - PreallocContiguous = true -// - Registry = GetFunctionRegistry() -// - ExecChannelSize = 10 -// - NumParallel = runtime.NumCPU() +// - ChunkSize = DefaultMaxChunkSize (MaxInt64) +// - PreallocContiguous = true +// - Registry = GetFunctionRegistry() +// - ExecChannelSize = 10 +// - NumParallel = runtime.NumCPU() func DefaultExecCtx() ExecCtx { return defaultExecCtx } func init() { @@ -131,7 +131,7 @@ type ExecBatch struct { Values []Datum // Guarantee is a predicate Expression guaranteed to evaluate to true for // all rows in this batch. - Guarantee Expression + // Guarantee Expression // Len is the semantic length of this ExecBatch. When the values are // all scalars, the length should be set to 1 for non-aggregate kernels. // Otherwise the length is taken from the array values. Aggregate kernels @@ -384,9 +384,9 @@ func inferBatchLength(values []Datum) (length int64, allSame bool) { return } -// kernelExecutor is the interface for all executors to initialize and +// KernelExecutor is the interface for all executors to initialize and // call kernel execution functions on batches. -type kernelExecutor interface { +type KernelExecutor interface { // Init must be called *after* the kernel's init method and any // KernelState must be set into the KernelCtx *before* calling // this Init method. This is to faciliate the case where @@ -407,8 +407,8 @@ type kernelExecutor interface { // CheckResultType checks the actual result type against the resolved // output type. If the types don't match an error is returned CheckResultType(out Datum) error - - clear() + // Clear resets the state in the executor so that it can be reused. + Clear() } // the base implementation for executing non-aggregate kernels. 
@@ -422,7 +422,7 @@ type nonAggExecImpl struct { preallocValidity bool } -func (e *nonAggExecImpl) clear() { +func (e *nonAggExecImpl) Clear() { e.ctx, e.kernel, e.outType = nil, nil, nil if e.dataPrealloc != nil { e.dataPrealloc = e.dataPrealloc[:0] @@ -479,6 +479,8 @@ func (e *nonAggExecImpl) CheckResultType(out Datum) error { type spanIterator func() (exec.ExecSpan, int64, bool) +func NewScalarExecutor() KernelExecutor { return &scalarExecutor{} } + type scalarExecutor struct { nonAggExecImpl diff --git a/go/arrow/compute/expression.go b/go/arrow/compute/expression.go index b1a9946b0bf3b..b01c3b67133ad 100644 --- a/go/arrow/compute/expression.go +++ b/go/arrow/compute/expression.go @@ -42,9 +42,12 @@ var hashSeed = maphash.MakeSeed() // Expression is an interface for mapping one datum to another. An expression // is one of: +// // A literal Datum -// A reference to a single (potentially nested) field of an input Datum +// A reference to a single (potentially nested) field of an input Datum // A call to a compute function, with arguments specified by other Expressions +// +// Deprecated: use substrait-go expressions instead. type Expression interface { fmt.Stringer // IsBound returns true if this expression has been bound to a particular @@ -95,6 +98,8 @@ func printDatum(datum Datum) string { // Literal is an expression denoting a literal Datum which could be any value // as a scalar, an array, or so on. +// +// Deprecated: use substrait-go expressions Literal instead. type Literal struct { Literal Datum } @@ -144,6 +149,8 @@ func (l *Literal) Release() { // Parameter represents a field reference and needs to be bound in order to determine // its type and shape. +// +// Deprecated: use substrait-go field references instead. type Parameter struct { ref *FieldRef @@ -265,6 +272,8 @@ func optionsToString(fn FunctionOptions) string { // Call is a function call with specific arguments which are themselves other // expressions. A call can also have options that are specific to the function // in question. It must be bound to determine the shape and type. +// +// Deprecated: use substrait-go expression functions instead. type Call struct { funcName string args []Expression diff --git a/go/arrow/compute/exprs/builders.go b/go/arrow/compute/exprs/builders.go new file mode 100644 index 0000000000000..28e0b06fc7763 --- /dev/null +++ b/go/arrow/compute/exprs/builders.go @@ -0,0 +1,445 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +//go:build go1.18 + +package exprs + +import ( + "fmt" + "strconv" + "strings" + "unicode" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/compute" + "github.com/substrait-io/substrait-go/expr" + "github.com/substrait-io/substrait-go/extensions" + "github.com/substrait-io/substrait-go/types" +) + +// NewDefaultExtensionSet constructs an empty extension set using the default +// Arrow Extension registry and the default collection of substrait extensions +// from the Substrait-go repo. +func NewDefaultExtensionSet() ExtensionIDSet { + return NewExtensionSetDefault(expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection)) +} + +// NewScalarCall constructs a substrait ScalarFunction expression with the provided +// options and arguments. +// +// The function name (fn) is looked up in the internal Arrow DefaultExtensionIDRegistry +// to ensure it exists and to convert from the Arrow function name to the substrait +// function name. It is then looked up using the DefaultCollection from the +// substrait extensions module to find the declaration. If it cannot be found, +// we try constructing the compound signature name by getting the types of the +// arguments which were passed and appending them to the function name appropriately. +// +// An error is returned if the function cannot be resolved. +func NewScalarCall(reg ExtensionIDSet, fn string, opts []*types.FunctionOption, args ...types.FuncArg) (*expr.ScalarFunction, error) { + conv, ok := reg.GetArrowRegistry().GetArrowToSubstrait(fn) + if !ok { + return nil, arrow.ErrNotFound + } + + id, convOpts, err := conv(fn) + if err != nil { + return nil, err + } + + opts = append(opts, convOpts...) + return expr.NewScalarFunc(reg.GetSubstraitRegistry(), id, opts, args...) +} + +// NewFieldRefFromDotPath constructs a substrait reference segment from +// a dot path and the base schema. +// +// dot_path = '.' name +// +// | '[' digit+ ']' +// | dot_path+ +// +// # Examples +// +// Assume root schema of {alpha: i32, beta: struct>, delta: map} +// +// ".alpha" => StructFieldRef(0) +// "[2]" => StructFieldRef(2) +// ".beta[0]" => StructFieldRef(1, StructFieldRef(0)) +// "[1].gamma[3]" => StructFieldRef(1, StructFieldRef(0, ListElementRef(3))) +// ".delta.foobar" => StructFieldRef(2, MapKeyRef("foobar")) +// +// Note: when parsing a name, a '\' preceding any other character +// will be dropped from the resulting name. Therefore if a name must +// contain the characters '.', '\', '[', or ']' then they must be escaped +// with a preceding '\'. +func NewFieldRefFromDotPath(dotpath string, rootSchema *arrow.Schema) (expr.ReferenceSegment, error) { + if len(dotpath) == 0 { + return nil, fmt.Errorf("%w dotpath was empty", arrow.ErrInvalid) + } + + parseName := func() string { + var name string + for { + idx := strings.IndexAny(dotpath, `\[.`) + if idx == -1 { + name += dotpath + dotpath = "" + break + } + + if dotpath[idx] != '\\' { + // subscript for a new field ref + name += dotpath[:idx] + dotpath = dotpath[idx:] + break + } + + if len(dotpath) == idx+1 { + // dotpath ends with a backslash; consume it all + name += dotpath + dotpath = "" + break + } + + // append all characters before backslash, then the character which follows it + name += dotpath[:idx] + string(dotpath[idx+1]) + dotpath = dotpath[idx+2:] + } + return name + } + + var curType arrow.DataType = arrow.StructOf(rootSchema.Fields()...) 
+ children := make([]expr.ReferenceSegment, 0) + + for len(dotpath) > 0 { + subscript := dotpath[0] + dotpath = dotpath[1:] + switch subscript { + case '.': + // next element is a name + n := parseName() + switch ct := curType.(type) { + case *arrow.StructType: + idx, found := ct.FieldIdx(n) + if !found { + return nil, fmt.Errorf("%w: dot path '%s' referenced invalid field", arrow.ErrInvalid, dotpath) + } + children = append(children, &expr.StructFieldRef{Field: int32(idx)}) + curType = ct.Field(idx).Type + case *arrow.MapType: + curType = ct.KeyType() + switch ct.KeyType().ID() { + case arrow.BINARY, arrow.LARGE_BINARY: + children = append(children, &expr.MapKeyRef{MapKey: expr.NewByteSliceLiteral([]byte(n), false)}) + case arrow.STRING, arrow.LARGE_STRING: + children = append(children, &expr.MapKeyRef{MapKey: expr.NewPrimitiveLiteral(n, false)}) + default: + return nil, fmt.Errorf("%w: MapKeyRef to non-binary/string map not supported", arrow.ErrNotImplemented) + } + default: + return nil, fmt.Errorf("%w: dot path names must refer to struct fields or map keys", arrow.ErrInvalid) + } + case '[': + subend := strings.IndexFunc(dotpath, func(r rune) bool { return !unicode.IsDigit(r) }) + if subend == -1 || dotpath[subend] != ']' { + return nil, fmt.Errorf("%w: dot path '%s' contained an unterminated index", arrow.ErrInvalid, dotpath) + } + idx, _ := strconv.Atoi(dotpath[:subend]) + switch ct := curType.(type) { + case *arrow.StructType: + if idx > len(ct.Fields()) { + return nil, fmt.Errorf("%w: field out of bounds in dotpath", arrow.ErrIndex) + } + curType = ct.Field(idx).Type + children = append(children, &expr.StructFieldRef{Field: int32(idx)}) + case *arrow.MapType: + curType = ct.KeyType() + var keyLiteral expr.Literal + // TODO: implement user defined types and variations + switch ct.KeyType().ID() { + case arrow.INT8: + keyLiteral = expr.NewPrimitiveLiteral(int8(idx), false) + case arrow.INT16: + keyLiteral = expr.NewPrimitiveLiteral(int16(idx), false) + case arrow.INT32: + keyLiteral = expr.NewPrimitiveLiteral(int32(idx), false) + case arrow.INT64: + keyLiteral = expr.NewPrimitiveLiteral(int64(idx), false) + case arrow.FLOAT32: + keyLiteral = expr.NewPrimitiveLiteral(float32(idx), false) + case arrow.FLOAT64: + keyLiteral = expr.NewPrimitiveLiteral(float64(idx), false) + default: + return nil, fmt.Errorf("%w: dotpath ref to map key type %s", arrow.ErrNotImplemented, ct.KeyType()) + } + children = append(children, &expr.MapKeyRef{MapKey: keyLiteral}) + case *arrow.ListType: + curType = ct.Elem() + children = append(children, &expr.ListElementRef{Offset: int32(idx)}) + case *arrow.LargeListType: + curType = ct.Elem() + children = append(children, &expr.ListElementRef{Offset: int32(idx)}) + case *arrow.FixedSizeListType: + curType = ct.Elem() + children = append(children, &expr.ListElementRef{Offset: int32(idx)}) + default: + return nil, fmt.Errorf("%w: %s type not supported for dotpath ref", arrow.ErrInvalid, ct) + } + dotpath = dotpath[subend+1:] + default: + return nil, fmt.Errorf("%w: dot path must begin with '[' or '.' 
got '%s'", + arrow.ErrInvalid, dotpath) + } + } + + out := children[0] + if len(children) > 1 { + cur := out + for _, c := range children[1:] { + switch r := cur.(type) { + case *expr.StructFieldRef: + r.Child = c + case *expr.MapKeyRef: + r.Child = c + case *expr.ListElementRef: + r.Child = c + } + cur = c + } + } + + return out, nil +} + +// RefFromFieldPath constructs a substrait field reference segment +// from a compute.FieldPath which should be a slice of integers +// indicating nested field paths to travel. This will return a +// series of StructFieldRef's whose child is the next element in +// the field path. +func RefFromFieldPath(field compute.FieldPath) expr.ReferenceSegment { + if len(field) == 0 { + return nil + } + + seg := expr.NewStructFieldRef(int32(field[0])) + parent := seg + for _, ref := range field[1:] { + next := expr.NewStructFieldRef(int32(ref)) + parent.Child = next + parent = next + } + + return seg +} + +// NewFieldRef constructs a properly typed substrait field reference segment, +// from a given arrow field reference, schema and extension set (for resolving +// substrait types). +func NewFieldRef(ref compute.FieldRef, schema *arrow.Schema, ext ExtensionIDSet) (*expr.FieldReference, error) { + path, err := ref.FindOne(schema) + if err != nil { + return nil, err + } + + st, err := ToSubstraitType(arrow.StructOf(schema.Fields()...), false, ext) + if err != nil { + return nil, err + } + + return expr.NewRootFieldRef(RefFromFieldPath(path), st.(*types.StructType)) +} + +// Builder wraps the substrait-go expression Builder and FuncArgBuilder +// interfaces for a simple interface that can be passed around to build +// substrait expressions from Arrow data. +type Builder interface { + expr.Builder + expr.FuncArgBuilder +} + +// ExprBuilder is the parent for building substrait expressions +// via Arrow types and functions. +// +// The expectation is that it should be utilized like so: +// +// bldr := NewExprBuilder(extSet) +// bldr.SetInputSchema(arrowschema) +// call, err := bldr.CallScalar("equal", nil, +// bldr.FieldRef("i32"), +// bldr.Literal(expr.NewPrimitiveLiteral( +// int32(0), false))) +// ex, err := call.BuildExpr() +// ... +// result, err := exprs.ExecuteScalarExpression(ctx, arrowschema, +// ex, input) +type ExprBuilder struct { + b expr.ExprBuilder + extSet ExtensionIDSet + inputSchema *arrow.Schema +} + +// NewExprBuilder constructs a new Expression Builder that will use the +// provided extension set and registry. +func NewExprBuilder(extSet ExtensionIDSet) ExprBuilder { + return ExprBuilder{ + b: expr.ExprBuilder{Reg: extSet.GetSubstraitRegistry()}, + extSet: extSet, + } +} + +// SetInputSchema sets the current Arrow schema that will be utilized +// for performing field reference and field type resolutions. +func (e *ExprBuilder) SetInputSchema(s *arrow.Schema) error { + st, err := ToSubstraitType(arrow.StructOf(s.Fields()...), false, e.extSet) + if err != nil { + return err + } + + e.inputSchema = s + e.b.BaseSchema = st.(*types.StructType) + return nil +} + +// MustCallScalar is like CallScalar, but will panic on error rather than +// return it. +func (e *ExprBuilder) MustCallScalar(fn string, opts []*types.FunctionOption, args ...expr.FuncArgBuilder) Builder { + b, err := e.CallScalar(fn, opts, args...) + if err != nil { + panic(err) + } + return b +} + +// CallScalar constructs a builder for a scalar function call. 
The function +// name is expected to be valid in the Arrow function registry which will +// map it properly to a substrait expression by resolving the types of +// the arguments. Examples are: "greater", "multiply", "equal", etc. +// +// Can return arrow.ErrNotFound if there is no function mapping found. +// Or will forward any error encountered when converting from an Arrow +// function to a substrait one. +func (e *ExprBuilder) CallScalar(fn string, opts []*types.FunctionOption, args ...expr.FuncArgBuilder) (Builder, error) { + conv, ok := e.extSet.GetArrowRegistry().GetArrowToSubstrait(fn) + if !ok { + return nil, arrow.ErrNotFound + } + + id, convOpts, err := conv(fn) + if err != nil { + return nil, err + } + + opts = append(opts, convOpts...) + return e.b.ScalarFunc(id, opts...).Args(args...), nil +} + +// FieldPath uses a field path to construct a Field Reference +// expression. +func (e *ExprBuilder) FieldPath(path compute.FieldPath) Builder { + segments := make([]expr.ReferenceSegment, len(path)) + for i, p := range path { + segments[i] = expr.NewStructFieldRef(int32(p)) + } + + return e.b.RootRef(expr.FlattenRefSegments(segments...)) +} + +// FieldIndex is shorthand for creating a single field reference +// to the struct field index provided. +func (e *ExprBuilder) FieldIndex(i int) Builder { + return e.b.RootRef(expr.NewStructFieldRef(int32(i))) +} + +// FieldRef constructs a field reference expression to the field with +// the given name from the input. It will be resolved to a field +// index when calling BuildExpr. +func (e *ExprBuilder) FieldRef(field string) Builder { + return &refBuilder{eb: e, fieldRef: compute.FieldRefName(field)} +} + +// FieldRefList accepts a list of either integers or strings to +// construct a field reference expression from. This will panic +// if any of elems are not a string or int. +// +// Field names will be resolved to their indexes when BuildExpr is called +// by using the provided Arrow schema. +func (e *ExprBuilder) FieldRefList(elems ...any) Builder { + return &refBuilder{eb: e, fieldRef: compute.FieldRefList(elems...)} +} + +// Literal wraps a substrait literal to be used as an argument to +// building other expressions. +func (e *ExprBuilder) Literal(l expr.Literal) Builder { + return e.b.Literal(l) +} + +// WrapLiteral is a convenience for accepting functions like NewLiteral +// which can potentially return an error. If an error is encountered, +// it will be surfaced when BuildExpr is called. +func (e *ExprBuilder) WrapLiteral(l expr.Literal, err error) Builder { + return e.b.Wrap(l, err) +} + +// Must is a convenience wrapper for any method that returns a Builder +// and error, panic'ing if it received an error or otherwise returning +// the Builder. +func (*ExprBuilder) Must(b Builder, err error) Builder { + if err != nil { + panic(err) + } + return b +} + +// Cast returns a Cast expression with the FailBehavior of ThrowException, +// erroring for invalid casts. 
+func (e *ExprBuilder) Cast(from Builder, to arrow.DataType) (Builder, error) { + t, err := ToSubstraitType(to, true, e.extSet) + if err != nil { + return nil, err + } + + return e.b.Cast(from, t).FailBehavior(types.BehaviorThrowException), nil +} + +type refBuilder struct { + eb *ExprBuilder + + fieldRef compute.FieldRef +} + +func (r *refBuilder) BuildFuncArg() (types.FuncArg, error) { + return r.BuildExpr() +} + +func (r *refBuilder) BuildExpr() (expr.Expression, error) { + if r.eb.inputSchema == nil { + return nil, fmt.Errorf("%w: no input schema specified for ref", arrow.ErrInvalid) + } + + path, err := r.fieldRef.FindOne(r.eb.inputSchema) + if err != nil { + return nil, err + } + + segments := make([]expr.ReferenceSegment, len(path)) + for i, p := range path { + segments[i] = expr.NewStructFieldRef(int32(p)) + } + + return r.eb.b.RootRef(expr.FlattenRefSegments(segments...)).Build() +} diff --git a/go/arrow/compute/exprs/builders_test.go b/go/arrow/compute/exprs/builders_test.go new file mode 100644 index 0000000000000..7b6f3ecac7c4a --- /dev/null +++ b/go/arrow/compute/exprs/builders_test.go @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +//go:build go1.18 + +package exprs_test + +import ( + "testing" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/compute/exprs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/substrait-io/substrait-go/expr" +) + +func TestNewScalarFunc(t *testing.T) { + reg := exprs.NewDefaultExtensionSet() + + fn, err := exprs.NewScalarCall(reg, "add", nil, + expr.NewPrimitiveLiteral(int32(1), false), + expr.NewPrimitiveLiteral(int32(10), false)) + require.NoError(t, err) + + assert.Equal(t, "add(i32(1), i32(10), {overflow: [ERROR]}) => i32", fn.String()) + assert.Equal(t, "add:i32_i32", fn.Name()) +} + +func TestFieldRefDotPath(t *testing.T) { + f0 := arrow.Field{Name: "alpha", Type: arrow.PrimitiveTypes.Int32} + f1_0 := arrow.Field{Name: "be.ta", Type: arrow.PrimitiveTypes.Int32} + f1 := arrow.Field{Name: "beta", Type: arrow.StructOf(f1_0)} + f2_0 := arrow.Field{Name: "alpha", Type: arrow.PrimitiveTypes.Int32} + f2_1_0 := arrow.Field{Name: "[alpha]", Type: arrow.MapOf(arrow.BinaryTypes.String, arrow.PrimitiveTypes.Int32)} + f2_1_1 := arrow.Field{Name: "beta", Type: arrow.ListOf(arrow.PrimitiveTypes.Int32)} + f2_1 := arrow.Field{Name: "gamma", Type: arrow.StructOf(f2_1_0, f2_1_1)} + f2 := arrow.Field{Name: "gamma", Type: arrow.StructOf(f2_0, f2_1)} + s := arrow.NewSchema([]arrow.Field{f0, f1, f2}, nil) + + tests := []struct { + dotpath string + shouldErr bool + expected expr.ReferenceSegment + }{ + {".alpha", false, &expr.StructFieldRef{Field: 0}}, + {"[2]", false, &expr.StructFieldRef{Field: 2}}, + {".beta[0]", false, &expr.StructFieldRef{Field: 1, Child: &expr.StructFieldRef{Field: 0}}}, + {"[2].gamma[1][5]", false, &expr.StructFieldRef{Field: 2, + Child: &expr.StructFieldRef{Field: 1, + Child: &expr.StructFieldRef{Field: 1, + Child: &expr.ListElementRef{Offset: 5}}}}}, + {"[2].gamma[0].foobar", false, &expr.StructFieldRef{Field: 2, + Child: &expr.StructFieldRef{Field: 1, + Child: &expr.StructFieldRef{Field: 0, + Child: &expr.MapKeyRef{MapKey: expr.NewPrimitiveLiteral("foobar", false)}}}}}, + {`[1].be\.ta`, false, &expr.StructFieldRef{Field: 1, Child: &expr.StructFieldRef{Field: 0}}}, + {`[2].gamma.\[alpha\]`, false, &expr.StructFieldRef{Field: 2, + Child: &expr.StructFieldRef{Field: 1, + Child: &expr.StructFieldRef{Field: 0}}}}, + {`[5]`, true, nil}, // bad struct index + {``, true, nil}, // empty + {`delta`, true, nil}, // not found + {`[1234`, true, nil}, // bad syntax + {`[1stuf]`, true, nil}, // bad syntax + } + + for _, tt := range tests { + t.Run(tt.dotpath, func(t *testing.T) { + ref, err := exprs.NewFieldRefFromDotPath(tt.dotpath, s) + if tt.shouldErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Truef(t, tt.expected.Equals(ref), "expected: %s\ngot: %s", tt.expected, ref) + } + }) + } +} diff --git a/go/arrow/compute/exprs/exec.go b/go/arrow/compute/exprs/exec.go new file mode 100644 index 0000000000000..74e6435b8dafb --- /dev/null +++ b/go/arrow/compute/exprs/exec.go @@ -0,0 +1,620 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build go1.18 + +package exprs + +import ( + "context" + "fmt" + "unsafe" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/array" + "github.com/apache/arrow/go/v13/arrow/compute" + "github.com/apache/arrow/go/v13/arrow/compute/internal/exec" + "github.com/apache/arrow/go/v13/arrow/decimal128" + "github.com/apache/arrow/go/v13/arrow/endian" + "github.com/apache/arrow/go/v13/arrow/internal/debug" + "github.com/apache/arrow/go/v13/arrow/memory" + "github.com/apache/arrow/go/v13/arrow/scalar" + "github.com/substrait-io/substrait-go/expr" + "github.com/substrait-io/substrait-go/extensions" + "github.com/substrait-io/substrait-go/types" +) + +func makeExecBatch(ctx context.Context, schema *arrow.Schema, partial compute.Datum) (out compute.ExecBatch, err error) { + // cleanup if we get an error + defer func() { + if err != nil { + for _, v := range out.Values { + if v != nil { + v.Release() + } + } + } + }() + + if partial.Kind() == compute.KindRecord { + partialBatch := partial.(*compute.RecordDatum).Value + batchSchema := partialBatch.Schema() + + out.Values = make([]compute.Datum, len(schema.Fields())) + out.Len = partialBatch.NumRows() + + for i, field := range schema.Fields() { + idxes := batchSchema.FieldIndices(field.Name) + switch len(idxes) { + case 0: + out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type)) + case 1: + col := partialBatch.Column(idxes[0]) + if !arrow.TypeEqual(col.DataType(), field.Type) { + // referenced field was present but didn't have expected type + // we'll cast this case for now + col, err = compute.CastArray(ctx, col, compute.SafeCastOptions(field.Type)) + if err != nil { + return compute.ExecBatch{}, err + } + defer col.Release() + } + out.Values[i] = compute.NewDatum(col) + default: + err = fmt.Errorf("%w: exec batch field '%s' ambiguous, more than one match", + arrow.ErrInvalid, field.Name) + return compute.ExecBatch{}, err + } + } + return + } + + part, ok := partial.(compute.ArrayLikeDatum) + if !ok { + return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial) + } + + // wasteful but useful for testing + if part.Type().ID() == arrow.STRUCT { + switch part := part.(type) { + case *compute.ArrayDatum: + arr := part.MakeArray().(*array.Struct) + defer arr.Release() + + batch := array.RecordFromStructArray(arr, nil) + defer batch.Release() + return makeExecBatch(ctx, schema, compute.NewDatumWithoutOwning(batch)) + case *compute.ScalarDatum: + out.Len = 1 + out.Values = make([]compute.Datum, len(schema.Fields())) + + s := part.Value.(*scalar.Struct) + dt := s.Type.(*arrow.StructType) + + for i, field := range schema.Fields() { + idx, found := dt.FieldIdx(field.Name) + if !found { + out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type)) + continue + } + + val := s.Value[idx] + if !arrow.TypeEqual(val.DataType(), field.Type) { + // referenced field was present but didn't have the expected + // type. 
for now we'll cast this + val, err = val.CastTo(field.Type) + if err != nil { + return compute.ExecBatch{}, err + } + } + out.Values[i] = compute.NewDatum(val) + } + return + } + } + + return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial) +} + +// ToArrowSchema takes a substrait NamedStruct and an extension set (for +// type resolution mapping) and creates the equivalent Arrow Schema. +func ToArrowSchema(base types.NamedStruct, ext ExtensionIDSet) (*arrow.Schema, error) { + fields := make([]arrow.Field, len(base.Names)) + for i, typ := range base.Struct.Types { + dt, nullable, err := FromSubstraitType(typ, ext) + if err != nil { + return nil, err + } + fields[i] = arrow.Field{ + Name: base.Names[i], + Type: dt, + Nullable: nullable, + } + } + + return arrow.NewSchema(fields, nil), nil +} + +type ( + regCtxKey struct{} + extCtxKey struct{} +) + +func WithExtensionRegistry(ctx context.Context, reg *ExtensionIDRegistry) context.Context { + return context.WithValue(ctx, regCtxKey{}, reg) +} + +func GetExtensionRegistry(ctx context.Context) *ExtensionIDRegistry { + v, ok := ctx.Value(regCtxKey{}).(*ExtensionIDRegistry) + if !ok { + v = DefaultExtensionIDRegistry + } + return v +} + +func WithExtensionIDSet(ctx context.Context, ext ExtensionIDSet) context.Context { + return context.WithValue(ctx, extCtxKey{}, ext) +} + +func GetExtensionIDSet(ctx context.Context) ExtensionIDSet { + v, ok := ctx.Value(extCtxKey{}).(ExtensionIDSet) + if !ok { + return NewExtensionSet( + expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection), + GetExtensionRegistry(ctx)) + } + return v +} + +func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet) (compute.Datum, error) { + switch v := lit.(type) { + case *expr.PrimitiveLiteral[bool]: + return compute.NewDatum(scalar.NewBooleanScalar(v.Value)), nil + case *expr.PrimitiveLiteral[int8]: + return compute.NewDatum(scalar.NewInt8Scalar(v.Value)), nil + case *expr.PrimitiveLiteral[int16]: + return compute.NewDatum(scalar.NewInt16Scalar(v.Value)), nil + case *expr.PrimitiveLiteral[int32]: + return compute.NewDatum(scalar.NewInt32Scalar(v.Value)), nil + case *expr.PrimitiveLiteral[int64]: + return compute.NewDatum(scalar.NewInt64Scalar(v.Value)), nil + case *expr.PrimitiveLiteral[float32]: + return compute.NewDatum(scalar.NewFloat32Scalar(v.Value)), nil + case *expr.PrimitiveLiteral[float64]: + return compute.NewDatum(scalar.NewFloat64Scalar(v.Value)), nil + case *expr.PrimitiveLiteral[string]: + return compute.NewDatum(scalar.NewStringScalar(v.Value)), nil + case *expr.PrimitiveLiteral[types.Timestamp]: + return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value), &arrow.TimestampType{Unit: arrow.Microsecond})), nil + case *expr.PrimitiveLiteral[types.TimestampTz]: + return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value), + &arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: TimestampTzTimezone})), nil + case *expr.PrimitiveLiteral[types.Date]: + return compute.NewDatum(scalar.NewDate32Scalar(arrow.Date32(v.Value))), nil + case *expr.PrimitiveLiteral[types.Time]: + return compute.NewDatum(scalar.NewTime64Scalar(arrow.Time64(v.Value), &arrow.Time64Type{Unit: arrow.Microsecond})), nil + case *expr.PrimitiveLiteral[types.FixedChar]: + length := int(v.Type.(*types.FixedCharType).Length) + return compute.NewDatum(scalar.NewExtensionScalar( + scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes([]byte(v.Value)), + &arrow.FixedSizeBinaryType{ByteWidth: length}), 
fixedChar(int32(length)))), nil + case *expr.ByteSliceLiteral[[]byte]: + return compute.NewDatum(scalar.NewBinaryScalar(memory.NewBufferBytes(v.Value), arrow.BinaryTypes.Binary)), nil + case *expr.ByteSliceLiteral[types.UUID]: + return compute.NewDatum(scalar.NewExtensionScalar(scalar.NewFixedSizeBinaryScalar( + memory.NewBufferBytes(v.Value), uuid().(arrow.ExtensionType).StorageType()), uuid())), nil + case *expr.ByteSliceLiteral[types.FixedBinary]: + return compute.NewDatum(scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes(v.Value), + &arrow.FixedSizeBinaryType{ByteWidth: int(v.Type.(*types.FixedBinaryType).Length)})), nil + case *expr.NullLiteral: + dt, _, err := FromSubstraitType(v.Type, ext) + if err != nil { + return nil, err + } + return compute.NewDatum(scalar.MakeNullScalar(dt)), nil + case *expr.ListLiteral: + var elemType arrow.DataType + + values := make([]scalar.Scalar, len(v.Value)) + for i, val := range v.Value { + d, err := literalToDatum(mem, val, ext) + if err != nil { + return nil, err + } + defer d.Release() + values[i] = d.(*compute.ScalarDatum).Value + if elemType != nil { + if !arrow.TypeEqual(values[i].DataType(), elemType) { + return nil, fmt.Errorf("%w: %s has a value whose type doesn't match the other list values", + arrow.ErrInvalid, v) + } + } else { + elemType = values[i].DataType() + } + } + + bldr := array.NewBuilder(memory.DefaultAllocator, elemType) + defer bldr.Release() + if err := scalar.AppendSlice(bldr, values); err != nil { + return nil, err + } + arr := bldr.NewArray() + defer arr.Release() + return compute.NewDatum(scalar.NewListScalar(arr)), nil + case *expr.MapLiteral: + dt, _, err := FromSubstraitType(v.Type, ext) + if err != nil { + return nil, err + } + + mapType, ok := dt.(*arrow.MapType) + if !ok { + return nil, fmt.Errorf("%w: map literal with non-map type", arrow.ErrInvalid) + } + + keys, values := make([]scalar.Scalar, len(v.Value)), make([]scalar.Scalar, len(v.Value)) + for i, kv := range v.Value { + k, err := literalToDatum(mem, kv.Key, ext) + if err != nil { + return nil, err + } + defer k.Release() + scalarKey := k.(*compute.ScalarDatum).Value + + v, err := literalToDatum(mem, kv.Value, ext) + if err != nil { + return nil, err + } + defer v.Release() + scalarValue := v.(*compute.ScalarDatum).Value + + if !arrow.TypeEqual(mapType.KeyType(), scalarKey.DataType()) { + return nil, fmt.Errorf("%w: key type mismatch for %s, got key with type %s", + arrow.ErrInvalid, mapType, scalarKey.DataType()) + } + if !arrow.TypeEqual(mapType.ValueType(), scalarValue.DataType()) { + return nil, fmt.Errorf("%w: value type mismatch for %s, got key with type %s", + arrow.ErrInvalid, mapType, scalarValue.DataType()) + } + + keys[i], values[i] = scalarKey, scalarValue + } + + keyBldr, valBldr := array.NewBuilder(mem, mapType.KeyType()), array.NewBuilder(mem, mapType.ValueType()) + defer keyBldr.Release() + defer valBldr.Release() + + if err := scalar.AppendSlice(keyBldr, keys); err != nil { + return nil, err + } + if err := scalar.AppendSlice(valBldr, values); err != nil { + return nil, err + } + + keyArr, valArr := keyBldr.NewArray(), valBldr.NewArray() + defer keyArr.Release() + defer valArr.Release() + + kvArr, err := array.NewStructArray([]arrow.Array{keyArr, valArr}, []string{"key", "value"}) + if err != nil { + return nil, err + } + defer kvArr.Release() + + return compute.NewDatumWithoutOwning(scalar.NewMapScalar(kvArr)), nil + case *expr.StructLiteral: + fields := make([]scalar.Scalar, len(v.Value)) + names := make([]string, len(v.Value)) + + for 
i, l := range v.Value { + lit, err := literalToDatum(mem, l, ext) + if err != nil { + return nil, err + } + fields[i] = lit.(*compute.ScalarDatum).Value + } + + s, err := scalar.NewStructScalarWithNames(fields, names) + return compute.NewDatum(s), err + case *expr.ProtoLiteral: + switch v := v.Value.(type) { + case *types.Decimal: + if len(v.Value) != arrow.Decimal128SizeBytes { + return nil, fmt.Errorf("%w: decimal literal had %d bytes (expected %d)", + arrow.ErrInvalid, len(v.Value), arrow.Decimal128SizeBytes) + } + + var val decimal128.Num + data := (*(*[arrow.Decimal128SizeBytes]byte)(unsafe.Pointer(&val)))[:] + copy(data, v.Value) + if endian.IsBigEndian { + // reverse the bytes + for i := len(data)/2 - 1; i >= 0; i-- { + opp := len(data) - 1 - i + data[i], data[opp] = data[opp], data[i] + } + } + + return compute.NewDatum(scalar.NewDecimal128Scalar(val, + &arrow.Decimal128Type{Precision: v.Precision, Scale: v.Scale})), nil + case *types.UserDefinedLiteral: // not yet implemented + case *types.IntervalYearToMonth: + bldr := array.NewInt32Builder(memory.DefaultAllocator) + defer bldr.Release() + typ := intervalYear() + bldr.Append(v.Years) + bldr.Append(v.Months) + arr := bldr.NewArray() + defer arr.Release() + return &compute.ScalarDatum{Value: scalar.NewExtensionScalar( + scalar.NewFixedSizeListScalar(arr), typ)}, nil + case *types.IntervalDayToSecond: + bldr := array.NewInt32Builder(memory.DefaultAllocator) + defer bldr.Release() + typ := intervalDay() + bldr.Append(v.Days) + bldr.Append(v.Seconds) + arr := bldr.NewArray() + defer arr.Release() + return &compute.ScalarDatum{Value: scalar.NewExtensionScalar( + scalar.NewFixedSizeListScalar(arr), typ)}, nil + case *types.VarChar: + return compute.NewDatum(scalar.NewExtensionScalar( + scalar.NewStringScalar(v.Value), varChar(int32(v.Length)))), nil + } + } + + return nil, arrow.ErrNotImplemented +} + +// ExecuteScalarExpression executes the given substrait expression using the provided datum as input. +// It will first create an exec batch using the input schema and the datum. +// The datum may have missing or incorrectly ordered columns while the input schema +// should describe the expected input schema for the expression. Missing fields will +// be replaced with null scalars and incorrectly ordered columns will be re-ordered +// according to the schema. +// +// You can provide an allocator to use through the context via compute.WithAllocator. +// +// You can provide the ExtensionIDSet to use through the context via WithExtensionIDSet. +func ExecuteScalarExpression(ctx context.Context, inputSchema *arrow.Schema, expression expr.Expression, partialInput compute.Datum) (compute.Datum, error) { + if expression == nil { + return nil, arrow.ErrInvalid + } + + batch, err := makeExecBatch(ctx, inputSchema, partialInput) + if err != nil { + return nil, err + } + defer func() { + for _, v := range batch.Values { + v.Release() + } + }() + + return executeScalarBatch(ctx, batch, expression, GetExtensionIDSet(ctx)) +} + +// ExecuteScalarSubstrait uses the provided Substrait extended expression to +// determine the expected input schema (replacing missing fields in the partial +// input datum with null scalars and re-ordering columns if necessary) and +// ExtensionIDSet to use. You can provide the extension registry to use +// through the context via WithExtensionRegistry, otherwise the default +// Arrow registry will be used. You can provide a memory.Allocator to use +// the same way via compute.WithAllocator. 
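For orientation, a minimal caller-side sketch of the entry point documented above. The helper name evalExtended and its inputs are illustrative placeholders, not identifiers added by this patch; the call shapes follow the signatures defined in this file.

	import (
		"context"

		"github.com/apache/arrow/go/v13/arrow"
		"github.com/apache/arrow/go/v13/arrow/compute"
		"github.com/apache/arrow/go/v13/arrow/compute/exprs"
		"github.com/substrait-io/substrait-go/expr"
	)

	// evalExtended evaluates a decoded Substrait extended expression against a
	// record batch. Without exprs.WithExtensionRegistry on the context, the
	// default Arrow registry is used, as described above.
	func evalExtended(extended *expr.Extended, rec arrow.Record) (compute.Datum, error) {
		ctx := context.Background()
		return exprs.ExecuteScalarSubstrait(ctx, extended, compute.NewDatumWithoutOwning(rec))
	}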
+func ExecuteScalarSubstrait(ctx context.Context, expression *expr.Extended, partialInput compute.Datum) (compute.Datum, error) { + if expression == nil { + return nil, arrow.ErrInvalid + } + + var toExecute expr.Expression + + switch len(expression.ReferredExpr) { + case 0: + return nil, fmt.Errorf("%w: no referred expression to execute", arrow.ErrInvalid) + case 1: + if toExecute = expression.ReferredExpr[0].GetExpr(); toExecute == nil { + return nil, fmt.Errorf("%w: measures not implemented", arrow.ErrNotImplemented) + } + default: + return nil, fmt.Errorf("%w: only single referred expression implemented", arrow.ErrNotImplemented) + } + + reg := GetExtensionRegistry(ctx) + set := NewExtensionSet(expr.NewExtensionRegistry(expression.Extensions, &extensions.DefaultCollection), reg) + sc, err := ToArrowSchema(expression.BaseSchema, set) + if err != nil { + return nil, err + } + + return ExecuteScalarExpression(WithExtensionIDSet(ctx, set), sc, toExecute, partialInput) +} + +func execFieldRef(ctx context.Context, e *expr.FieldReference, input compute.ExecBatch, ext ExtensionIDSet) (compute.Datum, error) { + if e.Root != expr.RootReference { + return nil, fmt.Errorf("%w: only RootReference is implemented", arrow.ErrNotImplemented) + } + + ref, ok := e.Reference.(expr.ReferenceSegment) + if !ok { + return nil, fmt.Errorf("%w: only direct references are implemented", arrow.ErrNotImplemented) + } + + expectedType, _, err := FromSubstraitType(e.GetType(), ext) + if err != nil { + return nil, err + } + + var param compute.Datum + if sref, ok := ref.(*expr.StructFieldRef); ok { + if sref.Field < 0 || sref.Field >= int32(len(input.Values)) { + return nil, arrow.ErrInvalid + } + param = input.Values[sref.Field] + ref = ref.GetChild() + } + + out, err := GetReferencedValue(compute.GetAllocator(ctx), ref, param, ext) + if err == compute.ErrEmpty { + out = compute.NewDatum(param) + } else if err != nil { + return nil, err + } + if !arrow.TypeEqual(out.(compute.ArrayLikeDatum).Type(), expectedType) { + return nil, fmt.Errorf("%w: referenced field %s was %s, but should have been %s", + arrow.ErrInvalid, ref, out.(compute.ArrayLikeDatum).Type(), expectedType) + } + + return out, nil +} + +func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.Expression, ext ExtensionIDSet) (compute.Datum, error) { + if !exp.IsScalar() { + return nil, fmt.Errorf("%w: ExecuteScalarExpression cannot execute non-scalar expressions", + arrow.ErrInvalid) + } + + switch e := exp.(type) { + case expr.Literal: + return literalToDatum(compute.GetAllocator(ctx), e, ext) + case *expr.FieldReference: + return execFieldRef(ctx, e, input, ext) + case *expr.Cast: + if e.Input == nil { + return nil, fmt.Errorf("%w: cast without argument to cast", arrow.ErrInvalid) + } + + arg, err := executeScalarBatch(ctx, input, e.Input, ext) + if err != nil { + return nil, err + } + defer arg.Release() + + dt, _, err := FromSubstraitType(e.Type, ext) + if err != nil { + return nil, fmt.Errorf("%w: could not determine type for cast", err) + } + + var opts *compute.CastOptions + switch e.FailureBehavior { + case types.BehaviorThrowException: + opts = compute.UnsafeCastOptions(dt) + case types.BehaviorUnspecified: + return nil, fmt.Errorf("%w: cast behavior unspecified", arrow.ErrInvalid) + case types.BehaviorReturnNil: + return nil, fmt.Errorf("%w: cast behavior return nil", arrow.ErrNotImplemented) + } + return compute.CastDatum(ctx, arg, opts) + case *expr.ScalarFunction: + var ( + err error + allScalar = true + args = 
make([]compute.Datum, e.NArgs()) + argTypes = make([]arrow.DataType, e.NArgs()) + ) + for i := 0; i < e.NArgs(); i++ { + switch v := e.Arg(i).(type) { + case types.Enum: + args[i] = compute.NewDatum(scalar.NewStringScalar(string(v))) + case expr.Expression: + args[i], err = executeScalarBatch(ctx, input, v, ext) + if err != nil { + return nil, err + } + defer args[i].Release() + + if args[i].Kind() != compute.KindScalar { + allScalar = false + } + default: + return nil, arrow.ErrNotImplemented + } + + argTypes[i] = args[i].(compute.ArrayLikeDatum).Type() + } + + _, conv, ok := ext.DecodeFunction(e.FuncRef()) + if !ok { + return nil, arrow.ErrNotImplemented + } + + fname, opts, err := conv(e) + if err != nil { + return nil, err + } + + ectx := compute.GetExecCtx(ctx) + fn, ok := ectx.Registry.GetFunction(fname) + if !ok { + return nil, arrow.ErrInvalid + } + + if fn.Kind() != compute.FuncScalar { + return nil, arrow.ErrInvalid + } + + k, err := fn.DispatchBest(argTypes...) + if err != nil { + return nil, err + } + + kctx := &exec.KernelCtx{Ctx: ctx, Kernel: k} + init := k.GetInitFn() + kinitArgs := exec.KernelInitArgs{Kernel: k, Inputs: argTypes, Options: opts} + if init != nil { + kctx.State, err = init(kctx, kinitArgs) + if err != nil { + return nil, err + } + } + + executor := compute.NewScalarExecutor() + if err := executor.Init(kctx, kinitArgs); err != nil { + return nil, err + } + + batch := compute.ExecBatch{Values: args} + if allScalar { + batch.Len = 1 + } else { + batch.Len = input.Len + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ch := make(chan compute.Datum, ectx.ExecChannelSize) + go func() { + defer close(ch) + if err = executor.Execute(ctx, &batch, ch); err != nil { + cancel() + } + }() + + result := executor.WrapResults(ctx, ch, false) + if err == nil { + debug.Assert(executor.CheckResultType(result) == nil, "invalid result type") + } + + if ctx.Err() == context.Canceled && result != nil { + result.Release() + } + + return result, nil + } + + return nil, arrow.ErrNotImplemented +} diff --git a/go/arrow/compute/exprs/exec_internal_test.go b/go/arrow/compute/exprs/exec_internal_test.go new file mode 100644 index 0000000000000..a41a21ef59d11 --- /dev/null +++ b/go/arrow/compute/exprs/exec_internal_test.go @@ -0,0 +1,114 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +//go:build go1.18 + +package exprs + +import ( + "context" + "strings" + "testing" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/array" + "github.com/apache/arrow/go/v13/arrow/compute" + "github.com/apache/arrow/go/v13/arrow/memory" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + boringArrowSchema = arrow.NewSchema([]arrow.Field{ + {Name: "bool", Type: arrow.FixedWidthTypes.Boolean, Nullable: true}, + {Name: "i8", Type: arrow.PrimitiveTypes.Int8, Nullable: true}, + {Name: "i32", Type: arrow.PrimitiveTypes.Int32, Nullable: true}, + {Name: "i32_req", Type: arrow.PrimitiveTypes.Int32}, + {Name: "u32", Type: arrow.PrimitiveTypes.Uint32, Nullable: true}, + {Name: "i64", Type: arrow.PrimitiveTypes.Int64, Nullable: true}, + {Name: "f32", Type: arrow.PrimitiveTypes.Float32, Nullable: true}, + {Name: "f32_req", Type: arrow.PrimitiveTypes.Float32}, + {Name: "f64", Type: arrow.PrimitiveTypes.Float64, Nullable: true}, + {Name: "date32", Type: arrow.FixedWidthTypes.Date32, Nullable: true}, + {Name: "str", Type: arrow.BinaryTypes.String, Nullable: true}, + {Name: "bin", Type: arrow.BinaryTypes.Binary, Nullable: true}, + }, nil) +) + +func TestMakeExecBatch(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + const numRows = 3 + var ( + ctx = compute.WithAllocator(context.Background(), mem) + i32, _, _ = array.FromJSON(mem, arrow.PrimitiveTypes.Int32, strings.NewReader(`[1, 2, 3]`)) + f32, _, _ = array.FromJSON(mem, arrow.PrimitiveTypes.Float32, strings.NewReader(`[1.5, 2.25, 3.125]`)) + empty, _, _ = array.RecordFromJSON(mem, boringArrowSchema, strings.NewReader(`[]`)) + ) + defer i32.Release() + defer f32.Release() + + getField := func(n string) arrow.Field { + f, _ := boringArrowSchema.FieldsByName(n) + return f[0] + } + + tests := []struct { + name string + batch arrow.Record + }{ + {"empty", empty}, + {"subset", array.NewRecord(arrow.NewSchema([]arrow.Field{getField("i32"), getField("f32")}, nil), + []arrow.Array{i32, f32}, numRows)}, + {"flipped subset", array.NewRecord(arrow.NewSchema([]arrow.Field{getField("f32"), getField("i32")}, nil), + []arrow.Array{f32, i32}, numRows)}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer tt.batch.Release() + batch, err := makeExecBatch(ctx, boringArrowSchema, compute.NewDatumWithoutOwning(tt.batch)) + require.NoError(t, err) + require.Equal(t, tt.batch.NumRows(), batch.Len) + + defer func() { + for _, v := range batch.Values { + v.Release() + } + }() + + for i, field := range boringArrowSchema.Fields() { + typ := batch.Values[i].(compute.ArrayLikeDatum).Type() + assert.Truef(t, arrow.TypeEqual(typ, field.Type), + "expected: %s\ngot: %s", field.Type, typ) + + idxes := tt.batch.Schema().FieldIndices(field.Name) + if batch.Values[i].Kind() == compute.KindScalar { + assert.False(t, batch.Values[i].(*compute.ScalarDatum).Value.IsValid(), + "null placeholder should be injected") + assert.Len(t, idxes, 0, "should only happen when column isn't found") + } else { + col := tt.batch.Column(idxes[0]) + val := batch.Values[i].(*compute.ArrayDatum).MakeArray() + defer val.Release() + + assert.Truef(t, array.Equal(col, val), "expected: %s\ngot: %s", col, val) + } + } + }) + } +} diff --git a/go/arrow/compute/exprs/exec_test.go b/go/arrow/compute/exprs/exec_test.go new file mode 100644 index 0000000000000..9cadcb37ad21a --- /dev/null +++ b/go/arrow/compute/exprs/exec_test.go @@ -0,0 +1,461 @@ +// Licensed 
to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build go1.18 + +package exprs_test + +import ( + "context" + "strings" + "testing" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/array" + "github.com/apache/arrow/go/v13/arrow/compute" + "github.com/apache/arrow/go/v13/arrow/compute/exprs" + "github.com/apache/arrow/go/v13/arrow/memory" + "github.com/apache/arrow/go/v13/arrow/scalar" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/substrait-io/substrait-go/expr" + "github.com/substrait-io/substrait-go/types" +) + +var ( + extSet = exprs.NewDefaultExtensionSet() + _, u32TypeRef, _ = extSet.EncodeTypeVariation(arrow.PrimitiveTypes.Uint32) + + boringSchema = types.NamedStruct{ + Names: []string{ + "bool", "i8", "i32", "i32_req", + "u32", "i64", "f32", "f32_req", + "f64", "date32", "str", "bin"}, + Struct: types.StructType{ + Nullability: types.NullabilityRequired, + Types: []types.Type{ + &types.BooleanType{}, + &types.Int8Type{}, + &types.Int32Type{}, + &types.Int32Type{Nullability: types.NullabilityRequired}, + &types.Int32Type{ + TypeVariationRef: u32TypeRef, + }, + &types.Int64Type{}, + &types.Float32Type{}, + &types.Float32Type{Nullability: types.NullabilityRequired}, + &types.Float64Type{}, + &types.DateType{}, + &types.StringType{}, + &types.BinaryType{}, + }, + }, + } + + boringArrowSchema = arrow.NewSchema([]arrow.Field{ + {Name: "bool", Type: arrow.FixedWidthTypes.Boolean, Nullable: true}, + {Name: "i8", Type: arrow.PrimitiveTypes.Int8, Nullable: true}, + {Name: "i32", Type: arrow.PrimitiveTypes.Int32, Nullable: true}, + {Name: "u32", Type: arrow.PrimitiveTypes.Uint32, Nullable: true}, + {Name: "i64", Type: arrow.PrimitiveTypes.Int64, Nullable: true}, + {Name: "f32", Type: arrow.PrimitiveTypes.Float32, Nullable: true}, + {Name: "f64", Type: arrow.PrimitiveTypes.Float64, Nullable: true}, + {Name: "date32", Type: arrow.FixedWidthTypes.Date32, Nullable: true}, + {Name: "str", Type: arrow.BinaryTypes.String, Nullable: true}, + {Name: "bin", Type: arrow.BinaryTypes.Binary, Nullable: true}, + }, nil) +) + +func TestToArrowSchema(t *testing.T) { + expectedSchema := arrow.NewSchema([]arrow.Field{ + {Name: "bool", Type: arrow.FixedWidthTypes.Boolean, Nullable: true}, + {Name: "i8", Type: arrow.PrimitiveTypes.Int8, Nullable: true}, + {Name: "i32", Type: arrow.PrimitiveTypes.Int32, Nullable: true}, + {Name: "i32_req", Type: arrow.PrimitiveTypes.Int32}, + {Name: "u32", Type: arrow.PrimitiveTypes.Uint32, Nullable: true}, + {Name: "i64", Type: arrow.PrimitiveTypes.Int64, Nullable: true}, + {Name: "f32", Type: arrow.PrimitiveTypes.Float32, Nullable: true}, + {Name: "f32_req", Type: arrow.PrimitiveTypes.Float32}, + {Name: "f64", Type: arrow.PrimitiveTypes.Float64, Nullable: true}, 
+ {Name: "date32", Type: arrow.FixedWidthTypes.Date32, Nullable: true}, + {Name: "str", Type: arrow.BinaryTypes.String, Nullable: true}, + {Name: "bin", Type: arrow.BinaryTypes.Binary, Nullable: true}, + }, nil) + + sc, err := exprs.ToArrowSchema(boringSchema, extSet) + assert.NoError(t, err) + + assert.Truef(t, expectedSchema.Equal(sc), "expected: %s\ngot: %s", expectedSchema, sc) +} + +func assertEqual(t *testing.T, expected, actual any) bool { + switch e := expected.(type) { + case compute.Datum: + return assert.Truef(t, e.Equals(compute.NewDatumWithoutOwning(actual)), + "expected: %s\ngot: %s", e, actual) + case arrow.Array: + switch a := actual.(type) { + case compute.Datum: + if a.Kind() == compute.KindArray { + actual := a.(*compute.ArrayDatum).MakeArray() + defer actual.Release() + return assert.Truef(t, array.Equal(e, actual), "expected: %s\ngot: %s", + e, actual) + } + case arrow.Array: + return assert.Truef(t, array.Equal(e, a), "expected: %s\ngot: %s", + e, actual) + } + t.Errorf("expected arrow Array, got %s", actual) + return false + } + panic("unimplemented comparison") +} + +func TestComparisons(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + var ( + ctx = compute.WithAllocator(context.Background(), mem) + zero = scalar.MakeScalar(int32(0)) + one = scalar.MakeScalar(int32(1)) + two = scalar.MakeScalar(int32(2)) + + str = scalar.MakeScalar("hello") + bin = scalar.MakeScalar([]byte("hello")) + ) + + getArgType := func(dt arrow.DataType) types.Type { + switch dt.ID() { + case arrow.INT32: + return &types.Int32Type{} + case arrow.STRING: + return &types.StringType{} + case arrow.BINARY: + return &types.BinaryType{} + } + panic("wtf") + } + + expect := func(t *testing.T, fn string, arg1, arg2 scalar.Scalar, res bool) { + baseStruct := types.NamedStruct{ + Names: []string{"arg1", "arg2"}, + Struct: types.StructType{ + Types: []types.Type{getArgType(arg1.DataType()), getArgType(arg2.DataType())}, + }, + } + + ex, err := exprs.NewScalarCall(extSet, fn, nil, + expr.MustExpr(expr.NewRootFieldRef(expr.NewStructFieldRef(0), &baseStruct.Struct)), + expr.MustExpr(expr.NewRootFieldRef(expr.NewStructFieldRef(1), &baseStruct.Struct))) + require.NoError(t, err) + + expression := &expr.Extended{ + Extensions: extSet.GetSubstraitRegistry().Set, + ReferredExpr: []expr.ExpressionReference{ + expr.NewExpressionReference([]string{"out"}, ex), + }, + BaseSchema: baseStruct, + } + + input, _ := scalar.NewStructScalarWithNames([]scalar.Scalar{arg1, arg2}, []string{"arg1", "arg2"}) + out, err := exprs.ExecuteScalarSubstrait(ctx, expression, compute.NewDatum(input)) + require.NoError(t, err) + require.Equal(t, compute.KindScalar, out.Kind()) + + result := out.(*compute.ScalarDatum).Value + assert.Equal(t, res, result.(*scalar.Boolean).Value) + } + + expect(t, "equal", one, one, true) + expect(t, "equal", one, two, false) + expect(t, "less", one, two, true) + expect(t, "less", one, zero, false) + expect(t, "greater", one, zero, true) + expect(t, "greater", one, two, false) + + expect(t, "equal", str, bin, true) + expect(t, "equal", bin, str, true) +} + +func TestExecuteFieldRef(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + fromJSON := func(ty arrow.DataType, json string) arrow.Array { + arr, _, err := array.FromJSON(mem, ty, strings.NewReader(json)) + require.NoError(t, err) + return arr + } + + scalarFromJSON := func(ty arrow.DataType, json string) scalar.Scalar { + arr, _, err := array.FromJSON(mem, ty, 
strings.NewReader(json)) + require.NoError(t, err) + defer arr.Release() + s, err := scalar.GetScalar(arr, 0) + require.NoError(t, err) + return s + } + + tests := []struct { + testName string + ref compute.FieldRef + input compute.Datum + expected compute.Datum + }{ + {"basic ref", compute.FieldRefName("a"), compute.NewDatumWithoutOwning(fromJSON( + arrow.StructOf(arrow.Field{Name: "a", Type: arrow.PrimitiveTypes.Float64, Nullable: true}), + `[ + {"a": 6.125}, + {"a": 0.0}, + {"a": -1} + ]`)), compute.NewDatumWithoutOwning(fromJSON( + arrow.PrimitiveTypes.Float64, `[6.125, 0.0, -1]`))}, + {"ref one field", compute.FieldRefName("a"), compute.NewDatumWithoutOwning(fromJSON( + arrow.StructOf( + arrow.Field{Name: "a", Type: arrow.PrimitiveTypes.Float64, Nullable: true}, + arrow.Field{Name: "b", Type: arrow.PrimitiveTypes.Float64, Nullable: true}), + `[ + {"a": 6.125, "b": 7.5}, + {"a": 0.0, "b": 2.125}, + {"a": -1, "b": 4.0} + ]`)), compute.NewDatumWithoutOwning(fromJSON( + arrow.PrimitiveTypes.Float64, `[6.125, 0.0, -1]`))}, + {"second field", compute.FieldRefName("b"), compute.NewDatumWithoutOwning(fromJSON( + arrow.StructOf( + arrow.Field{Name: "a", Type: arrow.PrimitiveTypes.Float64, Nullable: true}, + arrow.Field{Name: "b", Type: arrow.PrimitiveTypes.Float64, Nullable: true}), + `[ + {"a": 6.125, "b": 7.5}, + {"a": 0.0, "b": 2.125}, + {"a": -1, "b": 4.0} + ]`)), compute.NewDatumWithoutOwning(fromJSON( + arrow.PrimitiveTypes.Float64, `[7.5, 2.125, 4.0]`))}, + {"nested field by path", compute.FieldRefPath(compute.FieldPath{0, 0}), compute.NewDatumWithoutOwning(fromJSON( + arrow.StructOf( + arrow.Field{Name: "a", Type: arrow.StructOf( + arrow.Field{Name: "b", Type: arrow.PrimitiveTypes.Float64, Nullable: true}), + Nullable: true}), + `[ + {"a": {"b": 6.125}}, + {"a": {"b": 0.0}}, + {"a": {"b": -1}} + ]`)), compute.NewDatumWithoutOwning(fromJSON( + arrow.PrimitiveTypes.Float64, `[6.125, 0.0, -1]`))}, + {"nested field by name", compute.FieldRefList("a", "b"), compute.NewDatumWithoutOwning(fromJSON( + arrow.StructOf( + arrow.Field{Name: "a", Type: arrow.StructOf( + arrow.Field{Name: "b", Type: arrow.PrimitiveTypes.Float64, Nullable: true}), + Nullable: true}), + `[ + {"a": {"b": 6.125}}, + {"a": {"b": 0.0}}, + {"a": {"b": -1}} + ]`)), compute.NewDatumWithoutOwning(fromJSON( + arrow.PrimitiveTypes.Float64, `[6.125, 0.0, -1]`))}, + {"nested field with nulls", compute.FieldRefList("a", "b"), compute.NewDatumWithoutOwning(fromJSON( + arrow.StructOf( + arrow.Field{Name: "a", Type: arrow.StructOf( + arrow.Field{Name: "b", Type: arrow.PrimitiveTypes.Float64, Nullable: true}), + Nullable: true}), + `[ + {"a": {"b": 6.125}}, + {"a": null}, + {"a": {"b": null}} + ]`)), compute.NewDatumWithoutOwning(fromJSON( + arrow.PrimitiveTypes.Float64, `[6.125, null, null]`))}, + {"nested scalar", compute.FieldRefList("a", "b"), compute.NewDatumWithoutOwning( + scalarFromJSON(arrow.StructOf( + arrow.Field{Name: "a", Type: arrow.StructOf( + arrow.Field{Name: "b", Type: arrow.PrimitiveTypes.Float64, Nullable: true}), + Nullable: true}), `[{"a": {"b": 64.0}}]`)), + compute.NewDatum(scalar.NewFloat64Scalar(64.0))}, + {"nested scalar with null", compute.FieldRefList("a", "b"), compute.NewDatumWithoutOwning( + scalarFromJSON(arrow.StructOf( + arrow.Field{Name: "a", Type: arrow.StructOf( + arrow.Field{Name: "b", Type: arrow.PrimitiveTypes.Float64, Nullable: true}), + Nullable: true}), `[{"a": {"b": null}}]`)), + compute.NewDatum(scalar.MakeNullScalar(arrow.PrimitiveTypes.Float64))}, + {"nested scalar null", 
compute.FieldRefList("a", "b"), compute.NewDatumWithoutOwning( + scalarFromJSON(arrow.StructOf( + arrow.Field{Name: "a", Type: arrow.StructOf( + arrow.Field{Name: "b", Type: arrow.PrimitiveTypes.Float64, Nullable: true}), + Nullable: true}), `[{"a": null}]`)), + compute.NewDatum(scalar.MakeNullScalar(arrow.PrimitiveTypes.Float64))}, + } + + for _, tt := range tests { + t.Run(tt.testName, func(t *testing.T) { + scoped := memory.NewCheckedAllocatorScope(mem) + defer scoped.CheckSize(t) + + ctx := exprs.WithExtensionIDSet(compute.WithAllocator(context.Background(), mem), extSet) + dt := tt.input.(compute.ArrayLikeDatum).Type().(arrow.NestedType) + schema := arrow.NewSchema(dt.Fields(), nil) + ref, err := exprs.NewFieldRef(tt.ref, schema, extSet) + require.NoError(t, err) + assert.NotNil(t, ref) + + actual, err := exprs.ExecuteScalarExpression(ctx, schema, ref, tt.input) + require.NoError(t, err) + defer actual.Release() + + assert.Truef(t, tt.expected.Equals(actual), "expected: %s\ngot: %s", tt.expected, actual) + }) + } +} + +func TestExecuteScalarFuncCall(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + fromJSON := func(ty arrow.DataType, json string) arrow.Array { + arr, _, err := array.FromJSON(mem, ty, strings.NewReader(json)) + require.NoError(t, err) + return arr + } + + basicSchema := arrow.NewSchema([]arrow.Field{ + {Name: "a", Type: arrow.PrimitiveTypes.Float64, Nullable: true}, + {Name: "b", Type: arrow.PrimitiveTypes.Float64, Nullable: true}, + }, nil) + + nestedSchema := arrow.NewSchema([]arrow.Field{ + {Name: "a", Type: arrow.StructOf(basicSchema.Fields()...), Nullable: false}, + }, nil) + + bldr := exprs.NewExprBuilder(extSet) + + tests := []struct { + name string + ex exprs.Builder + sc *arrow.Schema + input compute.Datum + expected compute.Datum + }{ + {"add", bldr.MustCallScalar("add", nil, bldr.FieldRef("a"), + bldr.Literal(expr.NewPrimitiveLiteral(float64(3.5), false))), + basicSchema, + compute.NewDatumWithoutOwning(fromJSON(arrow.StructOf(basicSchema.Fields()...), + `[ + {"a": 6.125, "b": 3.375}, + {"a": 0.0, "b": 1}, + {"a": -1, "b": 4.75} + ]`)), compute.NewDatumWithoutOwning(fromJSON(arrow.PrimitiveTypes.Float64, + `[9.625, 3.5, 2.5]`))}, + {"add sub", bldr.MustCallScalar("add", nil, bldr.FieldRef("a"), + bldr.MustCallScalar("subtract", nil, + bldr.WrapLiteral(expr.NewLiteral(float64(3.5), false)), + bldr.FieldRef("b"))), + basicSchema, + compute.NewDatumWithoutOwning(fromJSON(arrow.StructOf(basicSchema.Fields()...), + `[ + {"a": 6.125, "b": 3.375}, + {"a": 0.0, "b": 1}, + {"a": -1, "b": 4.75} + ]`)), compute.NewDatumWithoutOwning(fromJSON(arrow.PrimitiveTypes.Float64, + `[6.25, 2.5, -2.25]`))}, + {"add nested", bldr.MustCallScalar("add", nil, + bldr.FieldRefList("a", "a"), bldr.FieldRefList("a", "b")), nestedSchema, + compute.NewDatumWithoutOwning(fromJSON(arrow.StructOf(nestedSchema.Fields()...), + `[ + {"a": {"a": 6.125, "b": 3.375}}, + {"a": {"a": 0.0, "b": 1}}, + {"a": {"a": -1, "b": 4.75}} + ]`)), compute.NewDatumWithoutOwning(fromJSON(arrow.PrimitiveTypes.Float64, + `[9.5, 1, 3.75]`))}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + scoped := memory.NewCheckedAllocatorScope(mem) + defer scoped.CheckSize(t) + + bldr.SetInputSchema(tt.sc) + ex, err := tt.ex.BuildExpr() + require.NoError(t, err) + + ctx := exprs.WithExtensionIDSet(compute.WithAllocator(context.Background(), mem), extSet) + dt := tt.input.(compute.ArrayLikeDatum).Type().(arrow.NestedType) + schema := arrow.NewSchema(dt.Fields(), nil) + + 
actual, err := exprs.ExecuteScalarExpression(ctx, schema, ex, tt.input) + require.NoError(t, err) + defer actual.Release() + + assert.Truef(t, tt.expected.Equals(actual), "expected: %s\ngot: %s", tt.expected, actual) + }) + } +} + +func TestGenerateMask(t *testing.T) { + sc, err := boringArrowSchema.AddField(0, arrow.Field{ + Name: "in", Type: arrow.FixedWidthTypes.Boolean, Nullable: true}) + require.NoError(t, err) + + bldr := exprs.NewExprBuilder(extSet) + require.NoError(t, bldr.SetInputSchema(sc)) + + tests := []struct { + name string + json string + filter exprs.Builder + }{ + {"simple", `[ + {"i32": 0, "f32": -0.1, "in": true}, + {"i32": 0, "f32": 0.3, "in": true}, + {"i32": 1, "f32": 0.2, "in": false}, + {"i32": 2, "f32": -0.1, "in": false}, + {"i32": 0, "f32": 0.1, "in": true}, + {"i32": 0, "f32": null, "in": true}, + {"i32": 0, "f32": 1.0, "in": true} + ]`, bldr.MustCallScalar("equal", nil, + bldr.FieldRef("i32"), bldr.Literal(expr.NewPrimitiveLiteral(int32(0), false)))}, + {"complex", `[ + {"f64": 0.3, "f32": 0.1, "in": true}, + {"f64": -0.1, "f32": 0.3, "in": false}, + {"f64": 0.1, "f32": 0.2, "in": true}, + {"f64": 0.0, "f32": -0.1, "in": false}, + {"f64": 1.0, "f32": 0.1, "in": true}, + {"f64": -2.0, "f32": null, "in": null}, + {"f64": 3.0, "f32": 1.0, "in": true} + ]`, bldr.MustCallScalar("greater", nil, + bldr.MustCallScalar("multiply", nil, + bldr.Must(bldr.Cast(bldr.FieldRef("f32"), arrow.PrimitiveTypes.Float64)), + bldr.FieldRef("f64")), + bldr.Literal(expr.NewPrimitiveLiteral(float64(0), false)))}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + ctx := exprs.WithExtensionIDSet(compute.WithAllocator(context.Background(), mem), extSet) + + rec, _, err := array.RecordFromJSON(mem, sc, strings.NewReader(tt.json)) + require.NoError(t, err) + defer rec.Release() + + input := compute.NewDatumWithoutOwning(rec) + expectedMask := rec.Column(0) + + mask, err := exprs.ExecuteScalarExpression(ctx, sc, + expr.MustExpr(tt.filter.BuildExpr()), input) + require.NoError(t, err) + defer mask.Release() + + assertEqual(t, expectedMask, mask) + }) + } +} diff --git a/go/arrow/compute/exprs/extension_types.go b/go/arrow/compute/exprs/extension_types.go new file mode 100644 index 0000000000000..dd753727f59c8 --- /dev/null +++ b/go/arrow/compute/exprs/extension_types.go @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +//go:build go1.18 + +package exprs + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/array" +) + +type simpleExtensionTypeFactory[P comparable] struct { + arrow.ExtensionBase + + params P + name string + getStorage func(P) arrow.DataType +} + +func (ef *simpleExtensionTypeFactory[P]) String() string { return "extension<" + ef.Serialize() + ">" } +func (ef *simpleExtensionTypeFactory[P]) ExtensionName() string { return ef.name } +func (ef *simpleExtensionTypeFactory[P]) Serialize() string { + s, _ := json.Marshal(ef.params) + return ef.name + string(s) +} +func (ef *simpleExtensionTypeFactory[P]) Deserialize(storage arrow.DataType, data string) (arrow.ExtensionType, error) { + if !strings.HasPrefix(data, ef.name) { + return nil, fmt.Errorf("%w: invalid deserialization of extension type %s", arrow.ErrInvalid, ef.name) + } + + data = strings.TrimPrefix(data, ef.name) + if err := json.Unmarshal([]byte(data), &ef.params); err != nil { + return nil, fmt.Errorf("%w: failed parsing parameters for extension type", err) + } + + if !arrow.TypeEqual(storage, ef.getStorage(ef.params)) { + return nil, fmt.Errorf("%w: invalid storage type for %s: %s (expected: %s)", + arrow.ErrInvalid, ef.name, storage, ef.getStorage(ef.params)) + } + + return &simpleExtensionTypeFactory[P]{ + name: ef.name, + params: ef.params, + getStorage: ef.getStorage, + ExtensionBase: arrow.ExtensionBase{ + Storage: storage, + }, + }, nil +} +func (ef *simpleExtensionTypeFactory[P]) ExtensionEquals(other arrow.ExtensionType) bool { + if ef.name != other.ExtensionName() { + return false + } + + rhs := other.(*simpleExtensionTypeFactory[P]) + return ef.params == rhs.params +} +func (ef *simpleExtensionTypeFactory[P]) ArrayType() reflect.Type { + return reflect.TypeOf(array.ExtensionArrayBase{}) +} + +func (ef *simpleExtensionTypeFactory[P]) CreateType(params P) arrow.DataType { + storage := ef.getStorage(params) + + return &simpleExtensionTypeFactory[P]{ + name: ef.name, + params: params, + getStorage: ef.getStorage, + ExtensionBase: arrow.ExtensionBase{ + Storage: storage, + }, + } +} + +type uuidExtParams struct{} + +var uuidType = simpleExtensionTypeFactory[uuidExtParams]{ + name: "uuid", getStorage: func(uuidExtParams) arrow.DataType { + return &arrow.FixedSizeBinaryType{ByteWidth: 16} + }} + +type fixedCharExtensionParams struct { + Length int32 `json:"length"` +} + +var fixedCharType = simpleExtensionTypeFactory[fixedCharExtensionParams]{ + name: "fixed_char", getStorage: func(p fixedCharExtensionParams) arrow.DataType { + return &arrow.FixedSizeBinaryType{ByteWidth: int(p.Length)} + }, +} + +type varCharExtensionParams struct { + Length int32 `json:"length"` +} + +var varCharType = simpleExtensionTypeFactory[varCharExtensionParams]{ + name: "varchar", getStorage: func(varCharExtensionParams) arrow.DataType { + return arrow.BinaryTypes.String + }, +} + +type intervalYearExtensionParams struct{} + +var intervalYearType = simpleExtensionTypeFactory[intervalYearExtensionParams]{ + name: "interval_year", getStorage: func(intervalYearExtensionParams) arrow.DataType { + return arrow.FixedSizeListOf(2, arrow.PrimitiveTypes.Int32) + }, +} + +type intervalDayExtensionParams struct{} + +var intervalDayType = simpleExtensionTypeFactory[intervalDayExtensionParams]{ + name: "interval_day", getStorage: func(intervalDayExtensionParams) arrow.DataType { + return arrow.FixedSizeListOf(2, arrow.PrimitiveTypes.Int32) + }, +} + +func uuid() 
arrow.DataType { return uuidType.CreateType(uuidExtParams{}) } +func fixedChar(length int32) arrow.DataType { + return fixedCharType.CreateType(fixedCharExtensionParams{Length: length}) +} +func varChar(length int32) arrow.DataType { + return varCharType.CreateType(varCharExtensionParams{Length: length}) +} +func intervalYear() arrow.DataType { + return intervalYearType.CreateType(intervalYearExtensionParams{}) +} +func intervalDay() arrow.DataType { + return intervalDayType.CreateType(intervalDayExtensionParams{}) +} diff --git a/go/arrow/compute/exprs/field_refs.go b/go/arrow/compute/exprs/field_refs.go new file mode 100644 index 0000000000000..3a08d519fbebe --- /dev/null +++ b/go/arrow/compute/exprs/field_refs.go @@ -0,0 +1,254 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build go1.18 + +package exprs + +import ( + "fmt" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/array" + "github.com/apache/arrow/go/v13/arrow/compute" + "github.com/apache/arrow/go/v13/arrow/memory" + "github.com/apache/arrow/go/v13/arrow/scalar" + "github.com/substrait-io/substrait-go/expr" +) + +func getFields(typ arrow.DataType) []arrow.Field { + if nested, ok := typ.(arrow.NestedType); ok { + return nested.Fields() + } + return nil +} + +// GetRefField evaluates the substrait field reference to retrieve the +// referenced field or return an error. +func GetRefField(ref expr.ReferenceSegment, fields []arrow.Field) (*arrow.Field, error) { + if ref == nil { + return nil, compute.ErrEmpty + } + + var ( + out *arrow.Field + ) + + for ref != nil { + if len(fields) == 0 { + return nil, fmt.Errorf("%w: %s", compute.ErrNoChildren, out.Type) + } + + switch f := ref.(type) { + case *expr.StructFieldRef: + if f.Field < 0 || f.Field >= int32(len(fields)) { + return nil, fmt.Errorf("%w: indices=%s", compute.ErrIndexRange, f) + } + + out = &fields[f.Field] + fields = getFields(out.Type) + default: + return nil, arrow.ErrNotImplemented + } + + ref = ref.GetChild() + } + + return out, nil +} + +// GetRefSchema evaluates the provided substrait field reference against +// the schema to retrieve the referenced (potentially nested) field. +func GetRefSchema(ref expr.ReferenceSegment, schema *arrow.Schema) (*arrow.Field, error) { + return GetRefField(ref, schema.Fields()) +} + +// GetScalar returns the evaluated referenced scalar value from the provided +// scalar which must be appropriate to the type of reference. +// +// A StructFieldRef can only reference against a Struct-type scalar, a +// ListElementRef can only reference against a List or LargeList scalar, +// and a MapKeyRef will only reference against a Map scalar. An error is +// returned if following the reference children ends up with an invalid +// nested reference object. 
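A short sketch of the scalar resolution described above, using only identifiers that already appear in this package; the example struct values and the extSet variable (an ExtensionIDSet) are assumptions for illustration, and the imports are those of this file.

	// Resolve field index 1 ("b") of the struct scalar {a: 1, b: "x"}.
	st, _ := scalar.NewStructScalarWithNames(
		[]scalar.Scalar{scalar.MakeScalar(int32(1)), scalar.MakeScalar("x")},
		[]string{"a", "b"})
	val, err := GetScalar(expr.NewStructFieldRef(1), st, memory.DefaultAllocator, extSet)
	// On success, val is the string scalar "x" and err is nil.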
+func GetScalar(ref expr.ReferenceSegment, s scalar.Scalar, mem memory.Allocator, ext ExtensionIDSet) (scalar.Scalar, error) { + if ref == nil { + return nil, compute.ErrEmpty + } + + var out scalar.Scalar + for ref != nil { + switch f := ref.(type) { + case *expr.StructFieldRef: + if s.DataType().ID() != arrow.STRUCT { + return nil, fmt.Errorf("%w: attempting to reference field from non-struct scalar %s", + arrow.ErrInvalid, s) + } + + st := s.(*scalar.Struct) + if f.Field < 0 || f.Field >= int32(len(st.Value)) { + return nil, fmt.Errorf("%w: indices=%s", compute.ErrIndexRange, ref) + } + + out = st.Value[f.Field] + case *expr.ListElementRef: + switch v := s.(type) { + case *scalar.List: + sc, err := scalar.GetScalar(v.Value, int(f.Offset)) + if err != nil { + return nil, err + } + out = sc + case *scalar.LargeList: + sc, err := scalar.GetScalar(v.Value, int(f.Offset)) + if err != nil { + return nil, err + } + out = sc + default: + return nil, fmt.Errorf("%w: cannot get ListElementRef from non-list scalar %s", + arrow.ErrInvalid, v) + } + case *expr.MapKeyRef: + v, ok := s.(*scalar.Map) + if !ok { + return nil, arrow.ErrInvalid + } + + dt, _, err := FromSubstraitType(f.MapKey.GetType(), ext) + if err != nil { + return nil, err + } + + if !arrow.TypeEqual(dt, v.Type.(*arrow.MapType).KeyType()) { + return nil, arrow.ErrInvalid + } + + keyvalDatum, err := literalToDatum(mem, f.MapKey, ext) + if err != nil { + return nil, err + } + + var ( + keyval = keyvalDatum.(*compute.ScalarDatum) + m = v.Value.(*array.Struct) + keys = m.Field(0) + valueScalar scalar.Scalar + ) + for i := 0; i < v.Value.Len(); i++ { + kv, err := scalar.GetScalar(keys, i) + if err != nil { + return nil, err + } + if scalar.Equals(kv, keyval.Value) { + valueScalar, err = scalar.GetScalar(m.Field(1), i) + if err != nil { + return nil, err + } + break + } + } + + if valueScalar == nil { + return nil, arrow.ErrNotFound + } + + out = valueScalar + } + s = out + ref = ref.GetChild() + } + + return out, nil +} + +// GetReferencedValue retrieves the referenced (potentially nested) value from +// the provided datum which may be a scalar, array, or record batch. 
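For example, a sketch assuming structArr is a struct-typed arrow.Array and extSet is an ExtensionIDSet already in scope:

	// Pull child field 0 out of a struct array as its own datum.
	datum := compute.NewDatumWithoutOwning(structArr)
	col, err := GetReferencedValue(memory.DefaultAllocator, expr.NewStructFieldRef(0), datum, extSet)
	if err == nil {
		defer col.Release()
	}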
+func GetReferencedValue(mem memory.Allocator, ref expr.ReferenceSegment, value compute.Datum, ext ExtensionIDSet) (compute.Datum, error) { + if ref == nil { + return nil, compute.ErrEmpty + } + + for ref != nil { + // process the rest of the refs for the scalars + // since arrays can go down to a scalar, but you + // won't get an array from a scalar via ref + if v, ok := value.(*compute.ScalarDatum); ok { + out, err := GetScalar(ref, v.Value, mem, ext) + if err != nil { + return nil, err + } + + return &compute.ScalarDatum{Value: out}, nil + } + + switch r := ref.(type) { + case *expr.MapKeyRef: + return nil, arrow.ErrNotImplemented + case *expr.StructFieldRef: + switch v := value.(type) { + case *compute.ArrayDatum: + if v.Type().ID() != arrow.STRUCT { + return nil, fmt.Errorf("%w: struct field ref for non struct type %s", + arrow.ErrInvalid, v.Type()) + } + + if r.Field < 0 || r.Field >= int32(len(v.Value.Children())) { + return nil, fmt.Errorf("%w: indices=%s", compute.ErrIndexRange, ref) + } + + value = &compute.ArrayDatum{Value: v.Value.Children()[r.Field]} + case *compute.RecordDatum: + if r.Field < 0 || r.Field >= int32(v.Value.NumCols()) { + return nil, fmt.Errorf("%w: indices=%s", compute.ErrIndexRange, ref) + } + + value = &compute.ArrayDatum{Value: v.Value.Column(int(r.Field)).Data()} + default: + return nil, arrow.ErrNotImplemented + } + case *expr.ListElementRef: + switch v := value.(type) { + case *compute.ArrayDatum: + switch v.Type().ID() { + case arrow.LIST, arrow.LARGE_LIST, arrow.FIXED_SIZE_LIST: + arr := v.MakeArray() + defer arr.Release() + + sc, err := scalar.GetScalar(arr, int(r.Offset)) + if err != nil { + return nil, err + } + if s, ok := sc.(scalar.Releasable); ok { + defer s.Release() + } + + value = &compute.ScalarDatum{Value: sc} + default: + return nil, fmt.Errorf("%w: cannot reference list element in non-list array type %s", + arrow.ErrInvalid, v.Type()) + } + + default: + return nil, arrow.ErrNotImplemented + } + } + + ref = ref.GetChild() + } + + return value, nil +} diff --git a/go/arrow/compute/exprs/types.go b/go/arrow/compute/exprs/types.go new file mode 100644 index 0000000000000..f169f18d4bca7 --- /dev/null +++ b/go/arrow/compute/exprs/types.go @@ -0,0 +1,745 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +//go:build go1.18 + +package exprs + +import ( + "fmt" + "hash/maphash" + "strconv" + "strings" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/compute" + "github.com/substrait-io/substrait-go/expr" + "github.com/substrait-io/substrait-go/extensions" + "github.com/substrait-io/substrait-go/types" +) + +const ( + // URI for official Arrow Substrait Extension Types + ArrowExtTypesUri = "https://github.com/apache/arrow/blob/main/format/substrait/extension_types.yaml" + SubstraitDefaultURIPrefix = extensions.SubstraitDefaultURIPrefix + // URI for official Substrait Arithmetic funcs extensions + SubstraitArithmeticFuncsURI = SubstraitDefaultURIPrefix + "functions_arithmetic.yaml" + // URI for official Substrait Comparison funcs extensions + SubstraitComparisonFuncsURI = SubstraitDefaultURIPrefix + "functions_comparison.yaml" + + TimestampTzTimezone = "UTC" +) + +var hashSeed maphash.Seed + +// the default extension registry that will contain the Arrow extension +// type variations and types. +var DefaultExtensionIDRegistry = NewExtensionIDRegistry() + +func init() { + hashSeed = maphash.MakeSeed() + + types := []struct { + dt arrow.DataType + name string + }{ + {arrow.PrimitiveTypes.Uint8, "u8"}, + {arrow.PrimitiveTypes.Uint16, "u16"}, + {arrow.PrimitiveTypes.Uint32, "u32"}, + {arrow.PrimitiveTypes.Uint64, "u64"}, + {arrow.FixedWidthTypes.Float16, "fp16"}, + {arrow.Null, "null"}, + {arrow.FixedWidthTypes.MonthInterval, "interval_month"}, + {arrow.FixedWidthTypes.DayTimeInterval, "interval_day_milli"}, + {arrow.FixedWidthTypes.MonthDayNanoInterval, "interval_month_day_nano"}, + } + + for _, t := range types { + err := DefaultExtensionIDRegistry.RegisterType(extensions.ID{ + URI: ArrowExtTypesUri, Name: t.name}, t.dt) + if err != nil { + panic(err) + } + } + + for _, fn := range []string{"add", "subtract", "multiply", "divide", "power", "sqrt", "abs"} { + err := DefaultExtensionIDRegistry.AddSubstraitScalarToArrow( + extensions.ID{URI: SubstraitArithmeticFuncsURI, Name: fn}, + decodeOptionlessOverflowableArithmetic(fn)) + if err != nil { + panic(err) + } + } + + for _, fn := range []string{"add", "subtract", "multiply", "divide"} { + err := DefaultExtensionIDRegistry.AddArrowToSubstrait(fn, + encodeOptionlessOverflowableArithmetic(extensions.ID{ + URI: SubstraitArithmeticFuncsURI, Name: fn})) + if err != nil { + panic(err) + } + } + + for _, fn := range []string{"equal", "not_equal", "lt", "lte", "gt", "gte"} { + err := DefaultExtensionIDRegistry.AddSubstraitScalarToArrow( + extensions.ID{URI: SubstraitComparisonFuncsURI, Name: fn}, + simpleMapSubstraitToArrowFunc) + if err != nil { + panic(err) + } + } + + for _, fn := range []string{"equal", "not_equal", "less", "less_equal", "greater", "greater_equal"} { + err := DefaultExtensionIDRegistry.AddArrowToSubstrait(fn, + simpleMapArrowToSubstraitFunc(SubstraitComparisonFuncsURI)) + if err != nil { + panic(err) + } + } +} + +type overflowBehavior string + +const ( + overflowSILENT = "SILENT" + overflowSATURATE = "SATURATE" + overflowERROR = "ERROR" +) + +type enumParser[typ ~string] struct { + values map[typ]struct{} +} + +func (e *enumParser[typ]) parse(v string) (typ, error) { + out := typ(v) + if _, ok := e.values[out]; ok { + return out, nil + } + return "", arrow.ErrNotFound +} + +var overflowParser = enumParser[overflowBehavior]{ + values: map[overflowBehavior]struct{}{ + overflowSILENT: {}, + overflowSATURATE: {}, + overflowERROR: {}, + }, +} + +func parseOption[typ ~string](sf *expr.ScalarFunction, 
optionName string, parser *enumParser[typ], implemented []typ, def typ) (typ, error) { + opts := sf.GetOption(optionName) + if len(opts) == 0 { + return def, nil + } + + for _, o := range opts { + p, err := parser.parse(o) + if err != nil { + return def, arrow.ErrInvalid + } + for _, i := range implemented { + if i == p { + return p, nil + } + } + } + + return def, arrow.ErrNotImplemented +} + +type substraitToArrow = func(*expr.ScalarFunction) (fname string, opts compute.FunctionOptions, err error) +type arrowToSubstrait = func(fname string) (extensions.ID, []*types.FunctionOption, error) + +var substraitToArrowFuncMap = map[string]string{ + "lt": "less", + "gt": "greater", + "lte": "less_equal", + "gte": "greater_equal", +} + +var arrowToSubstraitFuncMap = map[string]string{ + "less": "lt", + "greater": "gt", + "less_equal": "lte", + "greater_equal": "gte", +} + +func simpleMapSubstraitToArrowFunc(sf *expr.ScalarFunction) (fname string, opts compute.FunctionOptions, err error) { + fname, _, _ = strings.Cut(sf.Name(), ":") + f, ok := substraitToArrowFuncMap[fname] + if ok { + fname = f + } + return +} + +func simpleMapArrowToSubstraitFunc(uri string) arrowToSubstrait { + return func(fname string) (extensions.ID, []*types.FunctionOption, error) { + f, ok := arrowToSubstraitFuncMap[fname] + if ok { + fname = f + } + return extensions.ID{URI: uri, Name: fname}, nil, nil + } +} + +func decodeOptionlessOverflowableArithmetic(n string) substraitToArrow { + return func(sf *expr.ScalarFunction) (fname string, opts compute.FunctionOptions, err error) { + overflow, err := parseOption(sf, "overflow", &overflowParser, []overflowBehavior{overflowSILENT, overflowERROR}, overflowSILENT) + if err != nil { + return n, nil, err + } + + switch overflow { + case overflowSILENT: + return n + "_unchecked", nil, nil + case overflowERROR: + return n, nil, nil + default: + return n, nil, arrow.ErrNotImplemented + } + } +} + +func encodeOptionlessOverflowableArithmetic(id extensions.ID) arrowToSubstrait { + return func(fname string) (extensions.ID, []*types.FunctionOption, error) { + fn, _, ok := strings.Cut(fname, ":") + if ok { + id.Name = fname + fname = fn + } + + opts := make([]*types.FunctionOption, 0, 1) + if strings.HasSuffix(fname, "_unchecked") { + opts = append(opts, &types.FunctionOption{ + Name: "overflow", Preference: []string{"SILENT"}}) + } else { + opts = append(opts, &types.FunctionOption{ + Name: "overflow", Preference: []string{"ERROR"}}) + } + + return id, opts, nil + } +} + +// NewExtensionSetDefault is a convenience function to create a new extension +// set using the Default arrow extension ID registry. +// +// See NewExtensionSet for more info. +func NewExtensionSetDefault(set expr.ExtensionRegistry) ExtensionIDSet { + return &extensionSet{ExtensionRegistry: set, reg: DefaultExtensionIDRegistry} +} + +// NewExtensionSet creates a new extension set given a substrait extension registry, +// and an Arrow <--> Substrait registry for mapping substrait extensions to +// their Arrow equivalents. This extension set can then be used to manage a +// particular set of extensions in use by an expression or plan, so when +// serializing you only need to serialize the extensions that have been +// inserted into the extension set. 
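As a sketch of how such a set is typically constructed (mirroring what ExecuteScalarSubstrait does internally; extended here is an assumed, already decoded *expr.Extended):

	// Pair the extension declarations of a decoded expression with the
	// Arrow-side registry, then derive the Arrow input schema from it.
	set := NewExtensionSet(
		expr.NewExtensionRegistry(extended.Extensions, &extensions.DefaultCollection),
		DefaultExtensionIDRegistry)
	sc, err := ToArrowSchema(extended.BaseSchema, set)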
+func NewExtensionSet(set expr.ExtensionRegistry, reg *ExtensionIDRegistry) ExtensionIDSet { + return &extensionSet{ExtensionRegistry: set, reg: reg} +} + +type extensionSet struct { + expr.ExtensionRegistry + reg *ExtensionIDRegistry +} + +func (e *extensionSet) GetArrowRegistry() *ExtensionIDRegistry { return e.reg } +func (e *extensionSet) GetSubstraitRegistry() expr.ExtensionRegistry { return e.ExtensionRegistry } + +func (e *extensionSet) DecodeTypeArrow(anchor uint32) (extensions.ID, arrow.DataType, bool) { + id, ok := e.Set.DecodeType(anchor) + if !ok { + if id, ok = e.Set.DecodeTypeVariation(anchor); !ok { + return id, nil, false + } + } + + dt, ok := e.reg.GetTypeByID(id) + return id, dt, ok +} + +func (e *extensionSet) DecodeFunction(ref uint32) (extensions.ID, substraitToArrow, bool) { + id, ok := e.Set.DecodeFunc(ref) + if !ok { + return id, nil, false + } + + conv, ok := e.reg.GetSubstraitScalarToArrow(id) + if !ok { + id.Name, _, ok = strings.Cut(id.Name, ":") + if ok { + conv, ok = e.reg.GetSubstraitScalarToArrow(id) + } + } + return id, conv, ok +} + +func (e *extensionSet) EncodeTypeVariation(dt arrow.DataType) (extensions.ID, uint32, bool) { + id, ok := e.reg.GetIDByType(dt) + if !ok { + return extensions.ID{}, 0, false + } + + return id, e.Set.GetTypeVariationAnchor(id), true +} + +func (e *extensionSet) EncodeType(dt arrow.DataType) (extensions.ID, uint32, bool) { + id, ok := e.reg.GetIDByType(dt) + if !ok { + return extensions.ID{}, 0, false + } + + return id, e.Set.GetTypeAnchor(id), true +} + +func (e *extensionSet) EncodeFunction(id extensions.ID) uint32 { + return e.Set.GetFuncAnchor(id) +} + +// ExtensionIDRegistry manages a set of mappings between Arrow types +// and functions and their substrait equivalents. +type ExtensionIDRegistry struct { + typeList []arrow.DataType + ids []extensions.ID + + substraitToIdx map[extensions.ID]int + arrowToIdx map[uint64]int + + substraitToArrowFn map[extensions.ID]substraitToArrow + arrowToSubstrait map[string]arrowToSubstrait +} + +// NewExtensionIDRegistry initializes a new registry for use. +func NewExtensionIDRegistry() *ExtensionIDRegistry { + return &ExtensionIDRegistry{ + typeList: make([]arrow.DataType, 0), + ids: make([]extensions.ID, 0), + substraitToIdx: make(map[extensions.ID]int), + arrowToIdx: make(map[uint64]int), + substraitToArrowFn: make(map[extensions.ID]substraitToArrow), + arrowToSubstrait: make(map[string]arrowToSubstrait), + } +} + +// RegisterType creates a mapping between the given extension ID and the +// provided Arrow data type. If this extension ID or arrow type are already +// registered, an arrow.ErrInvalid error will be returned. +func (e *ExtensionIDRegistry) RegisterType(id extensions.ID, dt arrow.DataType) error { + if _, ok := e.substraitToIdx[id]; ok { + return fmt.Errorf("%w: type id already registered", arrow.ErrInvalid) + } + + dthash := arrow.HashType(hashSeed, dt) + if _, ok := e.arrowToIdx[dthash]; ok { + return fmt.Errorf("%w: type already registered", arrow.ErrInvalid) + } + + idx := len(e.ids) + e.typeList = append(e.typeList, dt) + e.ids = append(e.ids, id) + e.substraitToIdx[id] = idx + e.arrowToIdx[dthash] = idx + return nil +} + +// AddSubstraitScalarToArrow creates a mapping between a given extension ID +// and a function which should return the corresponding Arrow compute function +// name along with any relevant FunctionOptions based on the ScalarFunction +// instance passed to it. 
+// +// Any relevant options should be parsed from the ScalarFunction's options +// and used to ensure the correct arrow compute function is used and necessary +// options are passed. +func (e *ExtensionIDRegistry) AddSubstraitScalarToArrow(id extensions.ID, toArrow substraitToArrow) error { + if _, ok := e.substraitToArrowFn[id]; ok { + return fmt.Errorf("%w: extension id already registered as function", arrow.ErrInvalid) + } + + e.substraitToArrowFn[id] = toArrow + return nil +} + +// AddArrowToSubstrait creates a mapping between the provided arrow compute function +// and a function which should provide the correct substrait ExtensionID and function +// options from that name. +func (e *ExtensionIDRegistry) AddArrowToSubstrait(name string, fn arrowToSubstrait) error { + if _, ok := e.arrowToSubstrait[name]; ok { + return fmt.Errorf("%w: function name '%s' already registered for conversion to substrait", arrow.ErrInvalid, name) + } + + e.arrowToSubstrait[name] = fn + return nil +} + +// GetTypeByID returns the mapped arrow data type from the provided substrait +// extension id. If no mapping exists for this substrait extension id, +// the second return value will be false. +func (e *ExtensionIDRegistry) GetTypeByID(id extensions.ID) (arrow.DataType, bool) { + idx, ok := e.substraitToIdx[id] + if !ok { + return nil, false + } + + return e.typeList[idx], true +} + +// GetIDByType is the inverse of GetTypeByID, returning the mapped substrait +// extension ID corresponding to the provided arrow data type. The second +// return is false if there is no mapping found. +func (e *ExtensionIDRegistry) GetIDByType(typ arrow.DataType) (extensions.ID, bool) { + dthash := arrow.HashType(hashSeed, typ) + idx, ok := e.arrowToIdx[dthash] + if !ok { + return extensions.ID{}, false + } + + return e.ids[idx], true +} + +// GetSubstraitScalarToArrow returns the mapped conversion function for a +// given substrait extension ID to convert a substrait ScalarFunction to +// the corresponding Arrow compute function call. False is returned as +// the second value if there is no mapping available. +func (e *ExtensionIDRegistry) GetSubstraitScalarToArrow(id extensions.ID) (substraitToArrow, bool) { + conv, ok := e.substraitToArrowFn[id] + if !ok { + return nil, ok + } + + return conv, true +} + +// GetArrowToSubstrait returns the mapped function to convert an arrow compute +// function to the corresponding Substrait ScalarFunction extension ID and options. +// False is returned as the second value if there is no mapping found. +func (e *ExtensionIDRegistry) GetArrowToSubstrait(name string) (conv arrowToSubstrait, ok bool) { + conv, ok = e.arrowToSubstrait[name] + if !ok { + fn, _, found := strings.Cut(name, ":") + if found { + conv, ok = e.arrowToSubstrait[fn] + } + } + return +} + +// ExtensionIDSet is an interface for managing the mapping between arrow +// and substrait types and function extensions. +type ExtensionIDSet interface { + GetArrowRegistry() *ExtensionIDRegistry + GetSubstraitRegistry() expr.ExtensionRegistry + + DecodeTypeArrow(anchor uint32) (extensions.ID, arrow.DataType, bool) + DecodeFunction(ref uint32) (extensions.ID, substraitToArrow, bool) + + EncodeType(dt arrow.DataType) (extensions.ID, uint32, bool) + EncodeTypeVariation(dt arrow.DataType) (extensions.ID, uint32, bool) +} + +// IsNullable is a convenience method to return whether or not +// a substrait type has Nullability set to NullabilityRequired or not. 
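A brief sketch of extending a private registry with the calls documented above; the registrations shown mirror the ones performed in init() and are purely illustrative here.

	reg := NewExtensionIDRegistry()
	// Map an Arrow type to a Substrait extension type ID.
	if err := reg.RegisterType(extensions.ID{URI: ArrowExtTypesUri, Name: "u8"}, arrow.PrimitiveTypes.Uint8); err != nil {
		panic(err)
	}
	// Map a Substrait scalar function ID to its Arrow compute equivalent.
	if err := reg.AddSubstraitScalarToArrow(
		extensions.ID{URI: SubstraitComparisonFuncsURI, Name: "equal"},
		simpleMapSubstraitToArrowFunc); err != nil {
		panic(err)
	}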
+func IsNullable(t types.Type) bool { + return t.GetNullability() != types.NullabilityRequired +} + +// FieldsFromSubstrait produces a list of arrow fields from a list of +// substrait types (such as the fields of a StructType) using nextName +// to determine the names for the fields. +func FieldsFromSubstrait(typeList []types.Type, nextName func() string, ext ExtensionIDSet) (out []arrow.Field, err error) { + out = make([]arrow.Field, len(typeList)) + for i, t := range typeList { + out[i].Name = nextName() + out[i].Nullable = IsNullable(t) + + if st, ok := t.(*types.StructType); ok { + fields, err := FieldsFromSubstrait(st.Types, nextName, ext) + if err != nil { + return nil, err + } + out[i].Type = arrow.StructOf(fields...) + } else { + out[i].Type, _, err = FromSubstraitType(t, ext) + if err != nil { + return nil, err + } + } + } + return +} + +// ToSubstraitType converts an arrow data type to a Substrait Type. Since +// arrow types don't have a nullable flag (it is in the arrow.Field) but +// Substrait types do, the nullability must be passed in here. +func ToSubstraitType(dt arrow.DataType, nullable bool, ext ExtensionIDSet) (types.Type, error) { + var nullability types.Nullability + if nullable { + nullability = types.NullabilityNullable + } else { + nullability = types.NullabilityRequired + } + + switch dt.ID() { + case arrow.BOOL: + return &types.BooleanType{Nullability: nullability}, nil + case arrow.INT8: + return &types.Int8Type{Nullability: nullability}, nil + case arrow.INT16: + return &types.Int16Type{Nullability: nullability}, nil + case arrow.INT32: + return &types.Int32Type{Nullability: nullability}, nil + case arrow.INT64: + return &types.Int64Type{Nullability: nullability}, nil + case arrow.UINT8: + _, anchor, ok := ext.EncodeTypeVariation(dt) + if !ok { + return nil, arrow.ErrNotFound + } + return &types.Int8Type{ + Nullability: nullability, + TypeVariationRef: anchor, + }, nil + case arrow.UINT16: + _, anchor, ok := ext.EncodeTypeVariation(dt) + if !ok { + return nil, arrow.ErrNotFound + } + return &types.Int16Type{ + Nullability: nullability, + TypeVariationRef: anchor, + }, nil + case arrow.UINT32: + _, anchor, ok := ext.EncodeTypeVariation(dt) + if !ok { + return nil, arrow.ErrNotFound + } + return &types.Int32Type{ + Nullability: nullability, + TypeVariationRef: anchor, + }, nil + case arrow.UINT64: + _, anchor, ok := ext.EncodeTypeVariation(dt) + if !ok { + return nil, arrow.ErrNotFound + } + return &types.Int64Type{ + Nullability: nullability, + TypeVariationRef: anchor, + }, nil + case arrow.FLOAT16: + _, anchor, ok := ext.EncodeTypeVariation(dt) + if !ok { + return nil, arrow.ErrNotFound + } + return &types.Int16Type{ + Nullability: nullability, + TypeVariationRef: anchor, + }, nil + case arrow.FLOAT32: + return &types.Float32Type{Nullability: nullability}, nil + case arrow.FLOAT64: + return &types.Float64Type{Nullability: nullability}, nil + case arrow.STRING: + return &types.StringType{Nullability: nullability}, nil + case arrow.BINARY: + return &types.BinaryType{Nullability: nullability}, nil + case arrow.DATE32: + return &types.DateType{Nullability: nullability}, nil + case arrow.EXTENSION: + dt := dt.(arrow.ExtensionType) + switch dt.ExtensionName() { + case "uuid": + return &types.UUIDType{Nullability: nullability}, nil + case "fixed_char": + return &types.FixedCharType{ + Nullability: nullability, + Length: int32(dt.StorageType().(*arrow.FixedSizeBinaryType).ByteWidth), + }, nil + case "varchar": + return &types.VarCharType{Nullability: nullability, 
Length: -1}, nil + case "interval_year": + return &types.IntervalYearType{Nullability: nullability}, nil + case "interval_day": + return &types.IntervalDayType{Nullability: nullability}, nil + default: + _, anchor, ok := ext.EncodeType(dt) + if !ok { + return nil, arrow.ErrNotFound + } + return &types.UserDefinedType{ + Nullability: nullability, + TypeReference: anchor, + }, nil + } + case arrow.FIXED_SIZE_BINARY: + return &types.FixedBinaryType{Nullability: nullability, + Length: int32(dt.(*arrow.FixedSizeBinaryType).ByteWidth)}, nil + case arrow.DECIMAL128, arrow.DECIMAL256: + dt := dt.(arrow.DecimalType) + return &types.DecimalType{Nullability: nullability, + Precision: dt.GetPrecision(), Scale: dt.GetScale()}, nil + case arrow.STRUCT: + dt := dt.(*arrow.StructType) + fields := make([]types.Type, len(dt.Fields())) + var err error + for i, f := range dt.Fields() { + fields[i], err = ToSubstraitType(f.Type, f.Nullable, ext) + if err != nil { + return nil, err + } + } + + return &types.StructType{ + Nullability: nullability, + Types: fields, + }, nil + case arrow.LIST, arrow.FIXED_SIZE_LIST, arrow.LARGE_LIST: + dt := dt.(arrow.NestedType) + elemType, err := ToSubstraitType(dt.Fields()[0].Type, dt.Fields()[0].Nullable, ext) + if err != nil { + return nil, err + } + return &types.ListType{ + Nullability: nullability, + Type: elemType, + }, nil + case arrow.MAP: + dt := dt.(*arrow.MapType) + keyType, err := ToSubstraitType(dt.KeyType(), false, ext) + if err != nil { + return nil, err + } + valueType, err := ToSubstraitType(dt.ValueType(), dt.ValueField().Nullable, ext) + if err != nil { + return nil, err + } + + return &types.MapType{ + Nullability: nullability, + Key: keyType, + Value: valueType, + }, nil + } + + return nil, arrow.ErrNotImplemented +} + +// FromSubstraitType returns the appropriate Arrow data type for the given +// substrait type, using the extension set if necessary. +// Since Substrait types contain their nullability also, the nullability +// returned along with the data type. 
+func FromSubstraitType(t types.Type, ext ExtensionIDSet) (arrow.DataType, bool, error) { + nullable := IsNullable(t) + + if t.GetTypeVariationReference() > 0 { + _, dt, ok := ext.DecodeTypeArrow(t.GetTypeVariationReference()) + if ok { + return dt, nullable, nil + } + } + + switch t := t.(type) { + case *types.BooleanType: + return arrow.FixedWidthTypes.Boolean, nullable, nil + case *types.Int8Type: + return arrow.PrimitiveTypes.Int8, nullable, nil + case *types.Int16Type: + return arrow.PrimitiveTypes.Int16, nullable, nil + case *types.Int32Type: + return arrow.PrimitiveTypes.Int32, nullable, nil + case *types.Int64Type: + return arrow.PrimitiveTypes.Int64, nullable, nil + case *types.Float32Type: + return arrow.PrimitiveTypes.Float32, nullable, nil + case *types.Float64Type: + return arrow.PrimitiveTypes.Float64, nullable, nil + case *types.StringType: + return arrow.BinaryTypes.String, nullable, nil + case *types.BinaryType: + return arrow.BinaryTypes.Binary, nullable, nil + case *types.TimestampType: + return &arrow.TimestampType{Unit: arrow.Microsecond}, nullable, nil + case *types.TimestampTzType: + return &arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: TimestampTzTimezone}, + nullable, nil + case *types.DateType: + return arrow.FixedWidthTypes.Date32, nullable, nil + case *types.TimeType: + return &arrow.Time64Type{Unit: arrow.Microsecond}, nullable, nil + case *types.IntervalYearType: + return intervalYear(), nullable, nil + case *types.IntervalDayType: + return intervalDay(), nullable, nil + case *types.UUIDType: + return uuid(), nullable, nil + case *types.FixedCharType: + return fixedChar(t.Length), nullable, nil + case *types.VarCharType: + return varChar(t.Length), nullable, nil + case *types.FixedBinaryType: + return &arrow.FixedSizeBinaryType{ByteWidth: int(t.Length)}, nullable, nil + case *types.DecimalType: + return &arrow.Decimal128Type{ + Precision: t.Precision, + Scale: t.Scale, + }, nullable, nil + case *types.StructType: + i := 0 + fields, err := FieldsFromSubstrait(t.Types, func() string { + i++ + return strconv.Itoa(i) + }, ext) + if err != nil { + return nil, false, err + } + + return arrow.StructOf(fields...), nullable, nil + case *types.ListType: + elem, elemNullable, err := FromSubstraitType(t.Type, ext) + if err != nil { + return nil, false, err + } + return arrow.ListOfField(arrow.Field{Name: "item", Type: elem, Nullable: elemNullable}), + nullable, nil + case *types.MapType: + key, keyNullable, err := FromSubstraitType(t.Key, ext) + if err != nil { + return nil, false, err + } + if keyNullable { + return nil, false, fmt.Errorf("%w: encountered nullable key field when converting to arrow.Map", + arrow.ErrInvalid) + } + + value, valueNullable, err := FromSubstraitType(t.Value, ext) + if err != nil { + return nil, false, err + } + ret := arrow.MapOf(key, value) + ret.SetItemNullable(valueNullable) + return ret, nullable, nil + case *types.UserDefinedType: + anchor := t.TypeReference + _, dt, ok := ext.DecodeTypeArrow(anchor) + if !ok { + return nil, false, arrow.ErrNotImplemented + } + return dt, nullable, nil + } + + return nil, false, arrow.ErrNotImplemented +} diff --git a/go/arrow/compute/internal/kernels/cast.go b/go/arrow/compute/internal/kernels/cast.go index f1cdb799a68aa..80be6ca15cc25 100644 --- a/go/arrow/compute/internal/kernels/cast.go +++ b/go/arrow/compute/internal/kernels/cast.go @@ -46,6 +46,7 @@ type CastState = CastOptions // This can be used for casting a type to itself, or for casts between // equivalent representations such as 
Int32 and Date32. func ZeroCopyCastExec(_ *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error { + out.Release() dt := out.Type *out = batch.Values[0].Array out.Type = dt diff --git a/go/arrow/errors.go b/go/arrow/errors.go index b4a11b952c059..72e6fd8bf934e 100644 --- a/go/arrow/errors.go +++ b/go/arrow/errors.go @@ -24,4 +24,5 @@ var ( ErrType = errors.New("type error") ErrKey = errors.New("key error") ErrIndex = errors.New("index error") + ErrNotFound = errors.New("not found") ) diff --git a/go/go.mod b/go/go.mod index 619efd38e189f..9dd89db8aa404 100644 --- a/go/go.mod +++ b/go/go.mod @@ -23,45 +23,52 @@ require ( github.com/andybalholm/brotli v1.0.4 github.com/apache/thrift v0.16.0 github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815 - github.com/goccy/go-json v0.9.11 + github.com/goccy/go-json v0.10.0 github.com/golang/snappy v0.0.4 - github.com/google/flatbuffers v2.0.8+incompatible - github.com/google/uuid v1.3.0 + github.com/google/flatbuffers v23.1.21+incompatible github.com/klauspost/asmfmt v1.3.2 - github.com/klauspost/compress v1.15.9 - github.com/klauspost/cpuid/v2 v2.0.9 + github.com/klauspost/compress v1.15.15 + github.com/klauspost/cpuid/v2 v2.2.3 github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 - github.com/pierrec/lz4/v4 v4.1.15 - github.com/stretchr/testify v1.8.0 + github.com/pierrec/lz4/v4 v4.1.17 + github.com/stretchr/testify v1.8.1 github.com/zeebo/xxh3 v1.0.2 - golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 + golang.org/x/exp v0.0.0-20230206171751-46f607a40771 golang.org/x/sync v0.1.0 golang.org/x/sys v0.5.0 golang.org/x/tools v0.6.0 - golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f - gonum.org/v1/gonum v0.11.0 - google.golang.org/grpc v1.49.0 + golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 + gonum.org/v1/gonum v0.12.0 + google.golang.org/grpc v1.53.0 google.golang.org/protobuf v1.28.1 - modernc.org/sqlite v1.18.2 + modernc.org/sqlite v1.20.4 +) + +require ( + github.com/google/uuid v1.3.0 + github.com/substrait-io/substrait-go v0.2.1-0.20230517203920-30fa08bd57d0 ) require ( + github.com/alecthomas/participle/v2 v2.0.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect + github.com/fatih/color v1.13.0 // indirect + github.com/goccy/go-yaml v1.9.8 // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect - github.com/kr/pretty v0.3.0 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.17 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rogpeppe/go-internal v1.9.0 // indirect - github.com/stretchr/objx v0.4.0 // indirect + github.com/stretchr/objx v0.5.0 // indirect golang.org/x/mod v0.8.0 // indirect golang.org/x/net v0.7.0 // indirect golang.org/x/text v0.7.0 // indirect - google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 // indirect - gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect + google.golang.org/genproto v0.0.0-20230209215440-0dfe4f8abfcc // indirect gopkg.in/yaml.v3 v3.0.1 // indirect lukechampine.com/uint128 v1.2.0 // indirect modernc.org/cc/v3 v3.40.0 // indirect diff --git a/go/go.sum b/go/go.sum index 75fafd93a4530..0ccd809f50fae 100644 --- a/go/go.sum +++ b/go/go.sum @@ -1,13 +1,13 @@ 
-cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c h1:RGWPOewvKIROun94nF7v2cua9qP+thov/7M50KEoeSU= github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c/go.mod h1:X0CRv0ky0k6m906ixxpzmDRLvX58TFUKS2eePweuyxk= +github.com/alecthomas/assert/v2 v2.2.2 h1:Z/iVC0xZfWTaFNE6bA3z07T86hd45Xe2eLt6WVy2bbk= +github.com/alecthomas/participle/v2 v2.0.0 h1:Fgrq+MbuSsJwIkw3fEj9h75vDP0Er5JzepJ0/HNHv0g= +github.com/alecthomas/participle/v2 v2.0.0/go.mod h1:rAKZdJldHu8084ojcWevWAL8KmEU+AT+Olodb+WoN2Y= +github.com/alecthomas/repr v0.2.0 h1:HAzS41CIzNW5syS8Mf9UwXhNH1J9aix/BvDRf1Ml2Yk= github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/apache/thrift v0.16.0 h1:qEy6UW60iVOlUy+b9ZR0d5WzUWYGOo4HfopoyBaNmoY= github.com/apache/thrift v0.16.0/go.mod h1:PHK3hniurgQaNMZYaCLEqXKsYK8upmhPbmdP2FXSqgU= -github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -16,52 +16,54 @@ github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815 h1:bWDMxwH3px2JBh github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= -github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/goccy/go-json v0.9.11 h1:/pAaQDLHEoCq/5FFmSKBswWmK6H0e8g4159Kc/X/nqk= -github.com/goccy/go-json v0.9.11/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM= +github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= +github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= +github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q= +github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= +github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= +github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= +github.com/go-playground/validator/v10 v10.4.1 h1:pH2c5ADXtd66mxoE0Zm9SUhxE20r7aM3F26W0hOn+GE= +github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= 
+github.com/goccy/go-json v0.10.0 h1:mXKd9Qw4NuzShiRlOXKews24ufknHO7gx30lsDyokKA= +github.com/goccy/go-json v0.10.0/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/goccy/go-yaml v1.9.8 h1:5gMyLUeU1/6zl+WFfR1hN7D2kf+1/eRGa7DFtToiBvQ= +github.com/goccy/go-yaml v1.9.8/go.mod h1:JubOolP3gh0HpiBc4BLRD4YmjEjHAmIIB2aaXKkTfoE= github.com/golang/mock v1.5.0/go.mod h1:CWnOUgYIOo4TcNZ0wHX3YZCqsaM1I1Jvs6v3mP3KVu8= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= -github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= -github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= -github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= -github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/google/flatbuffers v2.0.8+incompatible h1:ivUb1cGomAB101ZM1T0nOiWz9pSrTMoa9+EiY7igmkM= -github.com/google/flatbuffers v2.0.8+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/flatbuffers v23.1.21+incompatible h1:bUqzx/MXCDxuS0hRJL2EfjyZL3uQrPbMocUa8zGqsTA= +github.com/google/flatbuffers v23.1.21+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= github.com/klauspost/asmfmt v1.3.2 h1:4Ri7ox3EwapiOjCki+hw14RyKk201CN4rzyCJRFLpK4= github.com/klauspost/asmfmt v1.3.2/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j0HLHbNSE= -github.com/klauspost/compress v1.15.9 h1:wKRjX6JRtDdrE9qwa4b/Cip7ACOshUI4smpCQanqjSY= -github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= 
-github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= -github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= -github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/klauspost/compress v1.15.15 h1:EF27CXIuDsYJ6mmvtBRlEuB2UVOqHG1tAXgZ7yIO+lw= +github.com/klauspost/compress v1.15.15/go.mod h1:ZcK2JAFqKOpnBlxcLsJzYfrS9X1akm9fHZNnD9+Vo/4= +github.com/klauspost/cpuid/v2 v2.2.3 h1:sxCkb+qR91z4vsqw4vGGZlDgPz3G7gjaLyK3V8y70BU= +github.com/klauspost/cpuid/v2 v2.2.3/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= +github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= +github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng= github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI= @@ -69,106 +71,85 @@ github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpsp github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8/go.mod h1:mC1jAcsrzbxHt8iiaC+zU4b1ylILSosueou12R++wfY= github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 h1:+n/aFZefKZp7spd8DFdX7uMikMLXX4oubIzJF4kv/wI= github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3/go.mod h1:RagcQ7I8IeTMnF8JTXieKnO4Z6JCsikNEzj0DwauVzE= -github.com/pierrec/lz4/v4 v4.1.15 h1:MO0/ucJhngq7299dKLwIMtgTfbkoSPF6AoMYDd8Q4q0= -github.com/pierrec/lz4/v4 v4.1.15/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= +github.com/pierrec/lz4/v4 v4.1.17 h1:kV4Ip+/hUBC+8T6+2EgburRtkE9ef4nbY3f4dFhGjMc= +github.com/pierrec/lz4/v4 v4.1.17/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= 
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0 h1:M2gUjqZET1qApGOWNSnZ49BAIMX4F/1plDv3+l31EJ4= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/substrait-io/substrait-go v0.2.1-0.20230517203920-30fa08bd57d0 h1:ULhfcCHY7uxA133qmInVpNpqfjyicryPXIaxCjbDVbw= +github.com/substrait-io/substrait-go v0.2.1-0.20230517203920-30fa08bd57d0/go.mod h1:qhpnLmrcvAnlZsUyPXZRqldiHapPTXC3t7xFgDi3aQg= github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 h1:tnebWN09GYg9OLPss1KXj8txwZc6X6uMr6VFdcGNbHw= -golang.org/x/exp v0.0.0-20220827204233-334a2380cb91/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU= +golang.org/x/exp v0.0.0-20230206171751-46f607a40771 h1:xP7rWLUr1e1n2xkK5YB4LI0hPEy3LJC6Wk+D4pGlOJg= +golang.org/x/exp v0.0.0-20230206171751-46f607a40771/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod 
h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM= golang.org/x/tools v0.6.0/go.mod 
h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f h1:uF6paiQQebLeSXkrTqHqz0MXhXXS1KgF41eUdBNvxK0= -golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= -gonum.org/v1/gonum v0.11.0 h1:f1IJhK4Km5tBJmaiJXtk/PkL4cdVX6J+tGiM187uT5E= -gonum.org/v1/gonum v0.11.0/go.mod h1:fSG4YDCxxUZQJ7rKsQrj0gMOg00Il0Z96/qMA4bVQhA= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY= -google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.49.0 h1:WTLtQzmQori5FUH25Pq4WT22oCsv8USpQ+F6rqtsmxw= -google.golang.org/grpc v1.49.0/go.mod h1:ZgQEeidpAuNRZ8iRrlBKXZQP1ghovWIVhdJRyCDK+GI= -google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= -google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= -google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= -google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= -google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= -google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 h1:H2TDz8ibqkAF6YGhCdN3jS9O0/s90v0rJh3X/OLHEUk= +golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= +gonum.org/v1/gonum v0.12.0 h1:xKuo6hzt+gMav00meVPUlXwSdoEJP46BR+wdxQEFK2o= +gonum.org/v1/gonum v0.12.0/go.mod h1:73TDxJfAAHeA8Mk9mf8NlIppyhQNo5GLTcYeqgo2lvY= +google.golang.org/genproto v0.0.0-20230209215440-0dfe4f8abfcc h1:ijGwO+0vL2hJt5gaygqP2j6PfflOBrRot0IczKbmtio= +google.golang.org/genproto v0.0.0-20230209215440-0dfe4f8abfcc/go.mod h1:RGgjbofJ8xD9Sq1VVhDM1Vok1vRONV+rg+CjzG4SZKM= +google.golang.org/grpc v1.53.0 h1:LAv2ds7cmFV/XTS3XG1NneeENYrXGmorPxsBbptIjNc= +google.golang.org/grpc v1.53.0/go.mod h1:OnIrk0ipVdj4N5d9IUoFUx72/VlD7+jUsHwZgwSMQpw= google.golang.org/protobuf v1.26.0-rc.1/go.mod 
h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= lukechampine.com/uint128 v1.2.0 h1:mBi/5l91vocEN8otkC5bDLhi2KdCticRiwbdB0O+rjI= lukechampine.com/uint128 v1.2.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= modernc.org/cc/v3 v3.40.0 h1:P3g79IUS/93SYhtoeaHW+kRCIrYaxJ27MFPv+7kaTOw= @@ -185,11 +166,11 @@ modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds= modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU= modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4= modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= -modernc.org/sqlite v1.18.2 h1:S2uFiaNPd/vTAP/4EmyY8Qe2Quzu26A2L1e25xRNTio= -modernc.org/sqlite v1.18.2/go.mod h1:kvrTLEWgxUcHa2GfHBQtanR1H9ht3hTJNtKpzH9k1u0= +modernc.org/sqlite v1.20.4 h1:J8+m2trkN+KKoE7jglyHYYYiaq5xmz2HoHJIiBlRzbE= +modernc.org/sqlite v1.20.4/go.mod h1:zKcGyrICaxNTMEHSr1HQ2GUraP0j+845GYw37+EyT6A= modernc.org/strutil v1.1.3 h1:fNMm+oJklMGYfU9Ylcywl0CO5O6nTfaowNsh2wpPjzY= modernc.org/strutil v1.1.3/go.mod h1:MEHNA7PdEnEwLvspRMtWTNnp2nnyvMfkimT1NKNAGbw= -modernc.org/tcl v1.13.2 h1:5PQgL/29XkQ9wsEmmNPjzKs+7iPCaYqUJAhzPvQbjDA= +modernc.org/tcl v1.15.0 h1:oY+JeD11qVVSgVvodMJsu7Edf8tr5E/7tuhF5cNYz34= modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= -modernc.org/z v1.5.1 h1:RTNHdsrOpeoSeOF4FbzTo8gBYByaJ5xT7NgZ9ZqRiJM= +modernc.org/z v1.7.0 h1:xkDw/KepgEjeizO2sNco+hqYkU12taxQFqPEmgm1GWE= diff --git a/go/internal/hashing/hash_funcs.go b/go/internal/hashing/hash_funcs.go index 1a859198e95da..c1bdfeb6ddf6e 100644 --- a/go/internal/hashing/hash_funcs.go +++ b/go/internal/hashing/hash_funcs.go @@ -53,9 +53,9 @@ func hashFloat64(val float64, alg uint64) uint64 { var exprimes = [2]uint64{1609587929392839161, 9650029242287828579} // for smaller amounts of bytes this is faster than even calling into -// xxh3 to do the hash, so we specialize in order to get the benefits +// xxh3 to do the Hash, so we specialize in order to get the benefits // of that performance. 
-func hash(b []byte, alg uint64) uint64 { +func Hash(b []byte, alg uint64) uint64 { n := uint32(len(b)) if n <= 16 { switch { diff --git a/go/internal/hashing/hash_string.go b/go/internal/hashing/hash_string.go index 6630010ba04a5..6cd49517184c3 100644 --- a/go/internal/hashing/hash_string.go +++ b/go/internal/hashing/hash_string.go @@ -22,5 +22,5 @@ import "unsafe" func hashString(val string, alg uint64) uint64 { buf := unsafe.Slice(unsafe.StringData(val), len(val)) - return hash(buf, alg) + return Hash(buf, alg) } diff --git a/go/internal/hashing/hash_string_go1.19.go b/go/internal/hashing/hash_string_go1.19.go index 8a799062938e6..a421d28409b5e 100644 --- a/go/internal/hashing/hash_string_go1.19.go +++ b/go/internal/hashing/hash_string_go1.19.go @@ -26,5 +26,5 @@ import ( func hashString(val string, alg uint64) uint64 { buf := *(*[]byte)(unsafe.Pointer(&val)) (*reflect.SliceHeader)(unsafe.Pointer(&buf)).Cap = len(val) - return hash(buf, alg) + return Hash(buf, alg) } diff --git a/go/internal/hashing/hashing_test.go b/go/internal/hashing/hashing_test.go index 875424a9d494f..4527f5f8196b7 100644 --- a/go/internal/hashing/hashing_test.go +++ b/go/internal/hashing/hashing_test.go @@ -89,11 +89,11 @@ func TestHashingBoundsStrings(t *testing.T) { str[idx] = uint8(idx) } - h := hash(str, 1) + h := Hash(str, 1) diff := 0 for i := 0; i < 120; i++ { str[len(str)-1] = uint8(i) - if hash(str, 1) != h { + if Hash(str, 1) != h { diff++ } } diff --git a/go/internal/hashing/xxh3_memo_table.go b/go/internal/hashing/xxh3_memo_table.go index 9e0cc96a04c93..5ec4d80d4bea4 100644 --- a/go/internal/hashing/xxh3_memo_table.go +++ b/go/internal/hashing/xxh3_memo_table.go @@ -195,11 +195,11 @@ func (BinaryMemoTable) getHash(val interface{}) uint64 { case string: return hashString(v, 0) case []byte: - return hash(v, 0) + return Hash(v, 0) case parquet.ByteArray: - return hash(*(*[]byte)(unsafe.Pointer(&v)), 0) + return Hash(*(*[]byte)(unsafe.Pointer(&v)), 0) case parquet.FixedLenByteArray: - return hash(*(*[]byte)(unsafe.Pointer(&v)), 0) + return Hash(*(*[]byte)(unsafe.Pointer(&v)), 0) default: panic("invalid type for binarymemotable") } From 08caf008b1bf2e93c70ae7db7531425bb9d4592e Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 7 Jun 2023 14:43:04 -0400 Subject: [PATCH 24/31] GH-35982: [Go] Fix go1.18 broken builds (#35983) ### Rationale for this change Fix usage of deprecated functions to fix the Go 1.18 builds. 
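The deprecated calls in question are the `arrow.MapType` value accessors: the diff below swaps `ValueType()`/`ValueField()` for the accessors this and the following patch settle on (`ItemType()`/`ItemField()`). A minimal sketch of the resulting usage, not taken from the patch itself; the `github.com/apache/arrow/go/v13` module path/version is an assumption:

```go
package main

import (
	"fmt"

	"github.com/apache/arrow/go/v13/arrow" // assumed module path/version
)

func main() {
	// Build a map<utf8, int64> type and read it back through the
	// non-deprecated accessors used after these patches.
	mapType := arrow.MapOf(arrow.BinaryTypes.String, arrow.PrimitiveTypes.Int64)

	fmt.Println("key type:      ", mapType.KeyType())
	fmt.Println("item type:     ", mapType.ItemType())
	fmt.Println("item nullable: ", mapType.ItemField().Nullable)
}
```
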
* Closes: #35982 Authored-by: Matt Topol Signed-off-by: Matt Topol --- go/arrow/compute/exprs/exec.go | 4 ++-- go/arrow/compute/exprs/types.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go/arrow/compute/exprs/exec.go b/go/arrow/compute/exprs/exec.go index 74e6435b8dafb..a8305cade2059 100644 --- a/go/arrow/compute/exprs/exec.go +++ b/go/arrow/compute/exprs/exec.go @@ -285,7 +285,7 @@ func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet) return nil, fmt.Errorf("%w: key type mismatch for %s, got key with type %s", arrow.ErrInvalid, mapType, scalarKey.DataType()) } - if !arrow.TypeEqual(mapType.ValueType(), scalarValue.DataType()) { + if !arrow.TypeEqual(mapType.ItemType(), scalarValue.DataType()) { return nil, fmt.Errorf("%w: value type mismatch for %s, got key with type %s", arrow.ErrInvalid, mapType, scalarValue.DataType()) } @@ -293,7 +293,7 @@ func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet) keys[i], values[i] = scalarKey, scalarValue } - keyBldr, valBldr := array.NewBuilder(mem, mapType.KeyType()), array.NewBuilder(mem, mapType.ValueType()) + keyBldr, valBldr := array.NewBuilder(mem, mapType.KeyType()), array.NewBuilder(mem, mapType.ItemType()) defer keyBldr.Release() defer valBldr.Release() diff --git a/go/arrow/compute/exprs/types.go b/go/arrow/compute/exprs/types.go index f169f18d4bca7..8e6934ae50cfc 100644 --- a/go/arrow/compute/exprs/types.go +++ b/go/arrow/compute/exprs/types.go @@ -623,7 +623,7 @@ func ToSubstraitType(dt arrow.DataType, nullable bool, ext ExtensionIDSet) (type if err != nil { return nil, err } - valueType, err := ToSubstraitType(dt.ValueType(), dt.ValueField().Nullable, ext) + valueType, err := ToSubstraitType(dt.Elem(), dt.ElemField().Nullable, ext) if err != nil { return nil, err } From 9dd4c525289da9c65987fea99865579bba0e5bd0 Mon Sep 17 00:00:00 2001 From: Alex Shcherbakov Date: Wed, 7 Jun 2023 22:09:28 +0300 Subject: [PATCH 25/31] MINOR: [Go] Use proper methods in compute (#35976) Blocks #35973 @ zeroshade I believe this is a proper code (use Item instead of Value for map type). I'll leave the filling in & the issue creation up to you. ### Rationale for this change ### What changes are included in this PR? ### Are these changes tested? ### Are there any user-facing changes? 
Lead-authored-by: candiduslynx Co-authored-by: Alex Shcherbakov Signed-off-by: Matt Topol --- go/arrow/compute/exprs/exec.go | 2 +- go/arrow/compute/exprs/types.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/go/arrow/compute/exprs/exec.go b/go/arrow/compute/exprs/exec.go index a8305cade2059..97b16ede11464 100644 --- a/go/arrow/compute/exprs/exec.go +++ b/go/arrow/compute/exprs/exec.go @@ -286,7 +286,7 @@ func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet) arrow.ErrInvalid, mapType, scalarKey.DataType()) } if !arrow.TypeEqual(mapType.ItemType(), scalarValue.DataType()) { - return nil, fmt.Errorf("%w: value type mismatch for %s, got key with type %s", + return nil, fmt.Errorf("%w: value type mismatch for %s, got value with type %s", arrow.ErrInvalid, mapType, scalarValue.DataType()) } diff --git a/go/arrow/compute/exprs/types.go b/go/arrow/compute/exprs/types.go index 8e6934ae50cfc..b9bb9f492a752 100644 --- a/go/arrow/compute/exprs/types.go +++ b/go/arrow/compute/exprs/types.go @@ -623,7 +623,7 @@ func ToSubstraitType(dt arrow.DataType, nullable bool, ext ExtensionIDSet) (type if err != nil { return nil, err } - valueType, err := ToSubstraitType(dt.Elem(), dt.ElemField().Nullable, ext) + valueType, err := ToSubstraitType(dt.ItemType(), dt.ItemField().Nullable, ext) if err != nil { return nil, err } From 5e5002de6ed1ef0f6e430a7188f41d5779171ca3 Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 7 Jun 2023 15:13:42 -0400 Subject: [PATCH 26/31] GH-35975: [Go] Support importing decimal256 (#35981) ### Rationale for this change We didn't implement this type. ### What changes are included in this PR? We now import this type. (And handle some other error cases.) ### Are these changes tested? Yes. ### Are there any user-facing changes? Yes. 
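For reference, the C Data Interface encodes decimals as `d:<precision>,<scale>[,<bitwidth>]`, with the bitwidth defaulting to 128 when omitted; the diff below adds the 256-bit branch and proper error reporting for malformed specs. A standalone sketch of that parsing rule (an illustration rather than the importer itself; the module path/version is an assumption):

```go
package main

import (
	"fmt"
	"strconv"
	"strings"

	"github.com/apache/arrow/go/v13/arrow" // assumed module path/version
)

// decimalTypeFromFormat maps a "d:<precision>,<scale>[,<bitwidth>]" format
// string onto an Arrow decimal type, defaulting the bitwidth to 128.
func decimalTypeFromFormat(f string) (arrow.DataType, error) {
	props := strings.Split(strings.TrimPrefix(f, "d:"), ",")
	if len(props) < 2 || len(props) > 3 {
		return nil, fmt.Errorf("invalid decimal spec %q: wrong number of properties", f)
	}
	precision, err := strconv.Atoi(props[0])
	if err != nil {
		return nil, fmt.Errorf("could not parse decimal precision in %q: %w", f, err)
	}
	scale, err := strconv.Atoi(props[1])
	if err != nil {
		return nil, fmt.Errorf("could not parse decimal scale in %q: %w", f, err)
	}
	bitwidth := 128
	if len(props) == 3 {
		if bitwidth, err = strconv.Atoi(props[2]); err != nil {
			return nil, fmt.Errorf("could not parse decimal bitwidth in %q: %w", f, err)
		}
	}
	switch bitwidth {
	case 128:
		return &arrow.Decimal128Type{Precision: int32(precision), Scale: int32(scale)}, nil
	case 256:
		return &arrow.Decimal256Type{Precision: int32(precision), Scale: int32(scale)}, nil
	default:
		return nil, fmt.Errorf("only decimal128 and decimal256 are supported, got %q", f)
	}
}

func main() {
	for _, f := range []string{"d:16,4", "d:15,-4,256", "d:1,2,384"} {
		dt, err := decimalTypeFromFormat(f)
		fmt.Println(f, "=>", dt, err)
	}
}
```

So `d:15,-4,256` maps to a `Decimal256Type` with precision 15 and scale -4, matching the new test case added to `cdata_test.go` below.
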
* Closes: #35975 Authored-by: David Li Signed-off-by: Matt Topol --- go/arrow/cdata/cdata.go | 33 +++++++++++++++++++++++++++------ go/arrow/cdata/cdata_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 6 deletions(-) diff --git a/go/arrow/cdata/cdata.go b/go/arrow/cdata/cdata.go index e20a12bbdf61c..3756bd2b533ec 100644 --- a/go/arrow/cdata/cdata.go +++ b/go/arrow/cdata/cdata.go @@ -232,14 +232,35 @@ func importSchema(schema *CArrowSchema) (ret arrow.Field, err error) { case "d": // decimal types are d:,[,] size is assumed 128 if left out props := typs[1] propList := strings.Split(props, ",") - if len(propList) == 3 { - err = xerrors.New("only decimal128 is supported") - return + bitwidth := 128 + var precision, scale int + + if len(propList) < 2 || len(propList) > 3 { + return ret, xerrors.Errorf("invalid decimal spec '%s': wrong number of properties", f) + } else if len(propList) == 3 { + bitwidth, err = strconv.Atoi(propList[2]) + if err != nil { + return ret, xerrors.Errorf("could not parse decimal bitwidth in '%s': %s", f, err.Error()) + } + } + + precision, err = strconv.Atoi(propList[0]) + if err != nil { + return ret, xerrors.Errorf("could not parse decimal precision in '%s': %s", f, err.Error()) } - precision, _ := strconv.Atoi(propList[0]) - scale, _ := strconv.Atoi(propList[1]) - dt = &arrow.Decimal128Type{Precision: int32(precision), Scale: int32(scale)} + scale, err = strconv.Atoi(propList[1]) + if err != nil { + return ret, xerrors.Errorf("could not parse decimal scale in '%s': %s", f, err.Error()) + } + + if bitwidth == 128 { + dt = &arrow.Decimal128Type{Precision: int32(precision), Scale: int32(scale)} + } else if bitwidth == 256 { + dt = &arrow.Decimal256Type{Precision: int32(precision), Scale: int32(scale)} + } else { + return ret, xerrors.Errorf("only decimal128 and decimal256 are supported, got '%s'", f) + } } if f[0] == '+' { // types with children diff --git a/go/arrow/cdata/cdata_test.go b/go/arrow/cdata/cdata_test.go index f0eb4b7e84bdf..8acdbfbf57f58 100644 --- a/go/arrow/cdata/cdata_test.go +++ b/go/arrow/cdata/cdata_test.go @@ -122,6 +122,7 @@ func TestPrimitiveSchemas(t *testing.T) { {&arrow.Decimal128Type{Precision: 16, Scale: 4}, "d:16,4"}, {&arrow.Decimal128Type{Precision: 15, Scale: 0}, "d:15,0"}, {&arrow.Decimal128Type{Precision: 15, Scale: -4}, "d:15,-4"}, + {&arrow.Decimal256Type{Precision: 15, Scale: -4}, "d:15,-4,256"}, } for _, tt := range tests { @@ -138,6 +139,31 @@ func TestPrimitiveSchemas(t *testing.T) { } } +func TestDecimalSchemaErrors(t *testing.T) { + tests := []struct { + fmt string + errorMessage string + }{ + {"d:", "invalid decimal spec 'd:': wrong number of properties"}, + {"d:1", "invalid decimal spec 'd:1': wrong number of properties"}, + {"d:1,2,3,4", "invalid decimal spec 'd:1,2,3,4': wrong number of properties"}, + {"d:a,2,3", "could not parse decimal precision in 'd:a,2,3':"}, + {"d:1,a,3", "could not parse decimal scale in 'd:1,a,3':"}, + {"d:1,2,a", "could not parse decimal bitwidth in 'd:1,2,a':"}, + {"d:1,2,384", "only decimal128 and decimal256 are supported, got 'd:1,2,384'"}, + } + + for _, tt := range tests { + t.Run(tt.fmt, func(t *testing.T) { + sc := testPrimitive(tt.fmt) + + _, err := ImportCArrowField(&sc) + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMessage) + }) + } +} + func TestImportTemporalSchema(t *testing.T) { tests := []struct { typ arrow.DataType From 16328f0ccc73b7df665b4a18feb6adf26b7aa0e2 Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 7 Jun 2023 15:31:06 
-0400 Subject: [PATCH 27/31] GH-35974: [Go] Don't panic if importing C Array Stream fails (#35978) ### Rationale for this change Panicking is rude. ### What changes are included in this PR? Import C Array Stream schemas up front and report the error. ### Are these changes tested? Yes. ### Are there any user-facing changes? `ImportCArrayStream` (which cannot fail, only panic) is deprecated in favor of `ImportCRecordReader` (which can return an error). * Closes: #35974 Authored-by: David Li Signed-off-by: David Li --- go/arrow/cdata/cdata.go | 30 +++++++++++------------ go/arrow/cdata/cdata_fulltest.c | 32 +++++++++++++++++++++++++ go/arrow/cdata/cdata_test.go | 16 +++++++++++++ go/arrow/cdata/cdata_test_framework.go | 33 ++++++++++++++++++++++++-- go/arrow/cdata/interface.go | 26 ++++++++++++++++++-- 5 files changed, 118 insertions(+), 19 deletions(-) diff --git a/go/arrow/cdata/cdata.go b/go/arrow/cdata/cdata.go index 3756bd2b533ec..3d209e65a5d69 100644 --- a/go/arrow/cdata/cdata.go +++ b/go/arrow/cdata/cdata.go @@ -738,7 +738,7 @@ func importCArrayAsType(arr *CArrowArray, dt arrow.DataType) (imp *cimporter, er return } -func initReader(rdr *nativeCRecordBatchReader, stream *CArrowArrayStream) { +func initReader(rdr *nativeCRecordBatchReader, stream *CArrowArrayStream) error { rdr.stream = C.get_stream() C.ArrowArrayStreamMove(stream, rdr.stream) rdr.arr = C.get_arr() @@ -751,6 +751,20 @@ func initReader(rdr *nativeCRecordBatchReader, stream *CArrowArrayStream) { C.free(unsafe.Pointer(r.stream)) C.free(unsafe.Pointer(r.arr)) }) + + var sc CArrowSchema + errno := C.stream_get_schema(rdr.stream, &sc) + if errno != 0 { + return rdr.getError(int(errno)) + } + defer C.ArrowSchemaRelease(&sc) + s, err := ImportCArrowSchema((*CArrowSchema)(&sc)) + if err != nil { + return err + } + rdr.schema = s + + return nil } // Record Batch reader that conforms to arrio.Reader for the ArrowArrayStream interface @@ -823,20 +837,6 @@ func (n *nativeCRecordBatchReader) next() error { } func (n *nativeCRecordBatchReader) Schema() *arrow.Schema { - if n.schema == nil { - var sc CArrowSchema - errno := C.stream_get_schema(n.stream, &sc) - if errno != 0 { - panic(n.getError(int(errno))) - } - defer C.ArrowSchemaRelease(&sc) - s, err := ImportCArrowSchema((*CArrowSchema)(&sc)) - if err != nil { - panic(err) - } - - n.schema = s - } return n.schema } diff --git a/go/arrow/cdata/cdata_fulltest.c b/go/arrow/cdata/cdata_fulltest.c index 4731c0ef398ad..b85e1e8310f94 100644 --- a/go/arrow/cdata/cdata_fulltest.c +++ b/go/arrow/cdata/cdata_fulltest.c @@ -18,6 +18,7 @@ // +build test #include +#include #include #include #include @@ -415,3 +416,34 @@ int test_exported_stream(struct ArrowArrayStream* stream) { } return 0; } + +struct FallibleStream { + // empty structs are a GNU extension + int dummy; +}; + +const char* FallibleGetLastError(struct ArrowArrayStream* stream) { + return "Expected error message"; +} + +int FallibleGetSchema(struct ArrowArrayStream* stream, struct ArrowSchema* schema) { + return EINVAL; +} + +int FallibleGetNext(struct ArrowArrayStream* stream, struct ArrowArray* array) { + return EINVAL; +} + +void FallibleRelease(struct ArrowArrayStream* stream) { + memset(stream, 0, sizeof(*stream)); +} + +static struct FallibleStream kFallibleStream; + +void test_stream_schema_fallible(struct ArrowArrayStream* stream) { + stream->get_last_error = FallibleGetLastError; + stream->get_schema = FallibleGetSchema; + stream->get_next = FallibleGetNext; + stream->private_data = &kFallibleStream; + stream->release = 
FallibleRelease; +} diff --git a/go/arrow/cdata/cdata_test.go b/go/arrow/cdata/cdata_test.go index 8acdbfbf57f58..6dcabf39c5536 100644 --- a/go/arrow/cdata/cdata_test.go +++ b/go/arrow/cdata/cdata_test.go @@ -924,3 +924,19 @@ func TestRecordReaderError(t *testing.T) { } assert.Contains(t, err.Error(), "Expected error message") } + +func TestRecordReaderImportError(t *testing.T) { + // Regression test for apache/arrow#35974 + + err := fallibleSchemaTestDeprecated() + if err == nil { + t.Fatalf("Expected error but got nil") + } + assert.Contains(t, err.Error(), "Expected error message") + + err = fallibleSchemaTest() + if err == nil { + t.Fatalf("Expected error but got nil") + } + assert.Contains(t, err.Error(), "Expected error message") +} diff --git a/go/arrow/cdata/cdata_test_framework.go b/go/arrow/cdata/cdata_test_framework.go index 131733ad90315..fb6122964168b 100644 --- a/go/arrow/cdata/cdata_test_framework.go +++ b/go/arrow/cdata/cdata_test_framework.go @@ -55,6 +55,7 @@ package cdata // struct ArrowSchema** test_schema(const char** fmts, const char** names, int64_t* flags, const int n); // struct ArrowSchema** test_union(const char** fmts, const char** names, int64_t* flags, const int n); // int test_exported_stream(struct ArrowArrayStream* stream); +// void test_stream_schema_fallible(struct ArrowArrayStream* stream); import "C" import ( "errors" @@ -309,10 +310,14 @@ func exportedStreamTest(reader array.RecordReader) error { func roundTripStreamTest(reader array.RecordReader) error { out := C.get_test_stream() ExportRecordReader(reader, out) - rdr := ImportCArrayStream(out, nil) + rdr, err := ImportCRecordReader(out, nil) + + if err != nil { + return err + } for { - _, err := rdr.Read() + _, err = rdr.Read() if errors.Is(err, io.EOF) { break } else if err != nil { @@ -321,3 +326,27 @@ func roundTripStreamTest(reader array.RecordReader) error { } return nil } + +func fallibleSchemaTestDeprecated() (err error) { + stream := CArrowArrayStream{} + C.test_stream_schema_fallible(&stream) + + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("Panicked: %#v", r) + } + }() + _ = ImportCArrayStream(&stream, nil) + return nil +} + +func fallibleSchemaTest() error { + stream := CArrowArrayStream{} + C.test_stream_schema_fallible(&stream) + + _, err := ImportCRecordReader(&stream, nil) + if err != nil { + return err + } + return nil +} diff --git a/go/arrow/cdata/interface.go b/go/arrow/cdata/interface.go index 4ac8cfdf31ffd..64b8176ad221a 100644 --- a/go/arrow/cdata/interface.go +++ b/go/arrow/cdata/interface.go @@ -164,10 +164,32 @@ func ImportCRecordBatch(arr *CArrowArray, sc *CArrowSchema) (arrow.Record, error // // NOTE: The reader takes ownership of the underlying memory buffers via ArrowArrayStreamMove, // it does not take ownership of the actual stream object itself. +// +// Deprecated: This will panic if importing the schema fails (which is possible). +// Prefer ImportCRecordReader instead. func ImportCArrayStream(stream *CArrowArrayStream, schema *arrow.Schema) arrio.Reader { + reader, err := ImportCRecordReader(stream, schema) + if err != nil { + panic(err) + } + return reader +} + +// ImportCStreamReader creates an arrio.Reader from an ArrowArrayStream taking ownership +// of the underlying stream object via ArrowArrayStreamMove. +// +// The records returned by this reader must be released manually after they are returned. +// The reader itself will release the stream via SetFinalizer when it is garbage collected. 
+// It will return (nil, io.EOF) from the Read function when there are no more records to return. +// +// NOTE: The reader takes ownership of the underlying memory buffers via ArrowArrayStreamMove, +// it does not take ownership of the actual stream object itself. +func ImportCRecordReader(stream *CArrowArrayStream, schema *arrow.Schema) (arrio.Reader, error) { out := &nativeCRecordBatchReader{schema: schema} - initReader(out, stream) - return out + if err := initReader(out, stream); err != nil { + return nil, err + } + return out, nil } // ExportArrowSchema populates the passed in CArrowSchema with the schema passed in so From 407423394cb988669a0b5716899810a924c0aa4c Mon Sep 17 00:00:00 2001 From: sgilmore10 <74676073+sgilmore10@users.noreply.github.com> Date: Wed, 7 Jun 2023 22:01:48 -0400 Subject: [PATCH 28/31] GH-35914: [MATLAB] Integrate the latest libmexclass changes to support error-handling (#35918) ### Rationale for this change This change integrates the latest version of `mathworks/libmexclass` into the MATLAB interface which enables throwing MATLAB errors. The [77f3d72](https://github.com/mathworks/libmexclass/commit/77f3d72c22a9ddab7b54ba325d757c3e82e57987) in `libmexclass` introduced the following changes: 1. Added a new class called `libmexclass::error::Error`. 2. Added a new field called `error` on the `libmexclass::proxy::method::Context` which is a `std::optional`. By default, this is a `std::nullopt`. To make MATLAB throw an error, set this optional. 3. To support throwing errors at construction, `libmexclass` now requires all `Proxy` subclasses to define a static `make` method: `static libmexclass::proxy::MakeResult make(const libmexclass::proxy::FunctionArguments& constructor_arguments)`. This workflow is similar to using an `arrow::Result`object. Examples of throwing errors in MATLAB can be found on lines [45](https://github.com/mathworks/libmexclass/blob/77f3d72c22a9ddab7b54ba325d757c3e82e57987/example/proxy/Car.cpp#L45) and [94](https://github.com/mathworks/libmexclass/blob/77f3d72c22a9ddab7b54ba325d757c3e82e57987/example/proxy/Car.cpp#L94) in [libmexclass/example/proxy/Car.cpp](https://github.com/mathworks/libmexclass/blob/main/example/proxy/Car.cpp) ### What changes are included in this PR? 1. Pulled in the latest version of `libmexclass`: [77f3d72](https://github.com/mathworks/libmexclass/commit/77f3d72c22a9ddab7b54ba325d757c3e82e57987). 2. Added `static libmexclass::proxy::MakeResult make(const libmexclass::proxy::FunctionArguments& constructor_arguments)` to `arrow::matlab::proxy::NumericArray`. 3. Throw an error when trying to create an unknown proxy class. ### Are these changes tested? 1. Added a new test class called `tGateway.m` that verifies `libmexclass.proxy.gateway("Create", ...)` errors if given the name of an unknown proxy class. ### Are there any user-facing changes? No. 
* Closes: #35914 Lead-authored-by: Sarah Gilmore Co-authored-by: sgilmore10 <74676073+sgilmore10@users.noreply.github.com> Co-authored-by: Sutou Kouhei Co-authored-by: Kevin Gurney Signed-off-by: Sutou Kouhei --- .../src/cpp/arrow/matlab/array/proxy/array.cc | 2 +- .../src/cpp/arrow/matlab/array/proxy/array.h | 2 +- .../arrow/matlab/array/proxy/numeric_array.h | 41 +++++++++++-------- matlab/src/cpp/arrow/matlab/error/error.h | 40 ++++++++++++++++++ matlab/src/cpp/arrow/matlab/proxy/factory.cc | 8 ++-- matlab/src/cpp/arrow/matlab/proxy/factory.h | 2 +- matlab/test/arrow/gateway/tGateway.m | 28 +++++++++++++ .../cmake/BuildMatlabArrowInterface.cmake | 6 ++- 8 files changed, 103 insertions(+), 26 deletions(-) create mode 100644 matlab/src/cpp/arrow/matlab/error/error.h create mode 100644 matlab/test/arrow/gateway/tGateway.m diff --git a/matlab/src/cpp/arrow/matlab/array/proxy/array.cc b/matlab/src/cpp/arrow/matlab/array/proxy/array.cc index fc1d66ae244a3..56115f5a379d9 100644 --- a/matlab/src/cpp/arrow/matlab/array/proxy/array.cc +++ b/matlab/src/cpp/arrow/matlab/array/proxy/array.cc @@ -21,7 +21,7 @@ namespace arrow::matlab::array::proxy { - Array::Array(const libmexclass::proxy::FunctionArguments& constructor_arguments) { + Array::Array() { // Register Proxy methods. REGISTER_METHOD(Array, toString); diff --git a/matlab/src/cpp/arrow/matlab/array/proxy/array.h b/matlab/src/cpp/arrow/matlab/array/proxy/array.h index 0a69f6fcad900..859dc48571e00 100644 --- a/matlab/src/cpp/arrow/matlab/array/proxy/array.h +++ b/matlab/src/cpp/arrow/matlab/array/proxy/array.h @@ -25,7 +25,7 @@ namespace arrow::matlab::array::proxy { class Array : public libmexclass::proxy::Proxy { public: - Array(const libmexclass::proxy::FunctionArguments& constructor_arguments); + Array(); virtual ~Array() {} diff --git a/matlab/src/cpp/arrow/matlab/array/proxy/numeric_array.h b/matlab/src/cpp/arrow/matlab/array/proxy/numeric_array.h index ad2242a7559c2..019add312db5f 100644 --- a/matlab/src/cpp/arrow/matlab/array/proxy/numeric_array.h +++ b/matlab/src/cpp/arrow/matlab/array/proxy/numeric_array.h @@ -17,7 +17,6 @@ #pragma once - #include "arrow/array.h" #include "arrow/array/data.h" #include "arrow/array/util.h" @@ -26,6 +25,7 @@ #include "arrow/type_traits.h" #include "arrow/matlab/array/proxy/array.h" +#include "arrow/matlab/error/error.h" #include "arrow/matlab/bit/bit_pack_matlab_logical_array.h" #include "libmexclass/proxy/Proxy.h" @@ -42,8 +42,12 @@ const uint8_t* getUnpackedValidityBitmap(const ::matlab::data::TypedArray& template class NumericArray : public arrow::matlab::array::proxy::Array { public: - NumericArray(const libmexclass::proxy::FunctionArguments& constructor_arguments) - : arrow::matlab::array::proxy::Array(constructor_arguments) { + NumericArray(const std::shared_ptr numeric_array) + : arrow::matlab::array::proxy::Array() { + array = numeric_array; + } + + static libmexclass::proxy::MakeResult make(const libmexclass::proxy::FunctionArguments& constructor_arguments) { using ArrowType = typename arrow::CTypeTraits::ArrowType; using BuilderType = typename arrow::CTypeTraits::BuilderType; @@ -64,15 +68,16 @@ class NumericArray : public arrow::matlab::array::proxy::Array { auto unpacked_validity_bitmap = has_validity_bitmap ? 
getUnpackedValidityBitmap(constructor_arguments[2]) : nullptr; BuilderType builder; - auto st = builder.AppendValues(dt, numeric_mda.getNumberOfElements(), unpacked_validity_bitmap); - - // TODO: handle error case - if (st.ok()) { - auto maybe_array = builder.Finish(); - if (maybe_array.ok()) { - array = *maybe_array; - } - } + + + auto status = builder.AppendValues(dt, numeric_mda.getNumberOfElements(), unpacked_validity_bitmap); + MATLAB_ERROR_IF_NOT_OK(status, error::APPEND_VALUES_ERROR_ID); + + auto maybe_array = builder.Finish(); + MATLAB_ERROR_IF_NOT_OK(maybe_array.status(), error::BUILD_ARRAY_ERROR_ID); + + return std::make_shared>(std::move(maybe_array).ValueUnsafe()); + } else { const auto data_type = arrow::CTypeTraits::type_singleton(); const auto length = static_cast(numeric_mda.getNumberOfElements()); // cast size_t to int64_t @@ -81,11 +86,15 @@ class NumericArray : public arrow::matlab::array::proxy::Array { auto data_buffer = std::make_shared(reinterpret_cast(dt), sizeof(CType) * numeric_mda.getNumberOfElements()); - // Pack the validity bitmap values. - auto packed_validity_bitmap = has_validity_bitmap ? arrow::matlab::bit::bitPackMatlabLogicalArray(constructor_arguments[2]).ValueOrDie() : nullptr; - + std::shared_ptr packed_validity_bitmap; + if (has_validity_bitmap) { + // Pack the validity bitmap values. + auto maybe_buffer = arrow::matlab::bit::bitPackMatlabLogicalArray(constructor_arguments[2]); + MATLAB_ERROR_IF_NOT_OK(maybe_buffer.status(), error::BITPACK_VALIDITY_BITMAP_ERROR_ID); + packed_validity_bitmap = std::move(maybe_buffer).ValueUnsafe(); + } auto array_data = arrow::ArrayData::Make(data_type, length, {packed_validity_bitmap, data_buffer}); - array = arrow::MakeArray(array_data); + return std::make_shared>(arrow::MakeArray(array_data)); } } diff --git a/matlab/src/cpp/arrow/matlab/error/error.h b/matlab/src/cpp/arrow/matlab/error/error.h new file mode 100644 index 0000000000000..d54d276735269 --- /dev/null +++ b/matlab/src/cpp/arrow/matlab/error/error.h @@ -0,0 +1,40 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + + +#include "arrow/status.h" +#include "libmexclass/error/Error.h" + +#include + +#define MATLAB_ERROR_IF_NOT_OK(expr, id) \ + do { \ + arrow::Status _status = (expr); \ + if (!_status.ok()) { \ + return libmexclass::error::Error{(id), _status.message()}; \ + } \ + } while (0) + +namespace arrow::matlab::error { + // TODO: Make Error ID Enum class to avoid defining static constexpr + static const char* APPEND_VALUES_ERROR_ID = "arrow:matlab:proxy:make:FailedToAppendValues"; + static const char* BUILD_ARRAY_ERROR_ID = "arrow:matlab:proxy:make:FailedToAppendValues"; + static const char* BITPACK_VALIDITY_BITMAP_ERROR_ID = "arrow:matlab:proxy:make:FailedToBitPackValidityBitmap"; + static const char* UNKNOWN_PROXY_ERROR_ID = "arrow:matlab:proxy:UnknownProxy"; +} diff --git a/matlab/src/cpp/arrow/matlab/proxy/factory.cc b/matlab/src/cpp/arrow/matlab/proxy/factory.cc index b7f86b9fcf9b1..e159c0ea37962 100644 --- a/matlab/src/cpp/arrow/matlab/proxy/factory.cc +++ b/matlab/src/cpp/arrow/matlab/proxy/factory.cc @@ -18,12 +18,12 @@ #include "arrow/matlab/array/proxy/numeric_array.h" #include "factory.h" - +#include "arrow/matlab/error/error.h" #include namespace arrow::matlab::proxy { -std::shared_ptr Factory::make_proxy(const ClassName& class_name, const FunctionArguments& constructor_arguments) { +libmexclass::proxy::MakeResult Factory::make_proxy(const ClassName& class_name, const FunctionArguments& constructor_arguments) { // Register MATLAB Proxy classes with corresponding C++ Proxy classes. REGISTER_PROXY(arrow.array.proxy.Float32Array, arrow::matlab::array::proxy::NumericArray); REGISTER_PROXY(arrow.array.proxy.Float64Array, arrow::matlab::array::proxy::NumericArray); @@ -38,9 +38,7 @@ std::shared_ptr Factory::make_proxy(const ClassName& class_name, const Fu REGISTER_PROXY(arrow.array.proxy.Int32Array , arrow::matlab::array::proxy::NumericArray); REGISTER_PROXY(arrow.array.proxy.Int64Array , arrow::matlab::array::proxy::NumericArray); - // TODO: Decide what to do in the case that there isn't a Proxy match. - std::cout << "Did not find a matching C++ proxy for: " + class_name << std::endl; - return nullptr; + return libmexclass::error::Error{error::UNKNOWN_PROXY_ERROR_ID, "Did not find matching C++ proxy for " + class_name}; }; } diff --git a/matlab/src/cpp/arrow/matlab/proxy/factory.h b/matlab/src/cpp/arrow/matlab/proxy/factory.h index 6f68ff4ac9cba..fd41a1c10aca7 100644 --- a/matlab/src/cpp/arrow/matlab/proxy/factory.h +++ b/matlab/src/cpp/arrow/matlab/proxy/factory.h @@ -26,7 +26,7 @@ using namespace libmexclass::proxy; class Factory : public libmexclass::proxy::Factory { public: Factory() { } - virtual std::shared_ptr make_proxy(const ClassName& class_name, const FunctionArguments& constructor_arguments); + virtual libmexclass::proxy::MakeResult make_proxy(const ClassName& class_name, const FunctionArguments& constructor_arguments); }; } diff --git a/matlab/test/arrow/gateway/tGateway.m b/matlab/test/arrow/gateway/tGateway.m new file mode 100644 index 0000000000000..c2b9ef9d68c99 --- /dev/null +++ b/matlab/test/arrow/gateway/tGateway.m @@ -0,0 +1,28 @@ +% Licensed to the Apache Software Foundation (ASF) under one or more +% contributor license agreements. See the NOTICE file distributed with +% this work for additional information regarding copyright ownership. +% The ASF licenses this file to you under the Apache License, Version +% 2.0 (the "License"); you may not use this file except in compliance +% with the License. 
You may obtain a copy of the License at +% +% http://www.apache.org/licenses/LICENSE-2.0 +% +% Unless required by applicable law or agreed to in writing, software +% distributed under the License is distributed on an "AS IS" BASIS, +% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +% implied. See the License for the specific language governing +% permissions and limitations under the License. + +classdef tGateway < matlab.unittest.TestCase + % Tests for libmexclass.proxy.gateway error conditions. + + methods (Test) + function UnknownProxyError(testCase) + % Verify the gateway function errors if given the name of an + % unknown proxy class. + id = "arrow:matlab:proxy:UnknownProxy"; + fcn = @()libmexclass.proxy.gateway("Create", "NotAProxyClass", {}); + testCase.verifyError(fcn, id); + end + end +end \ No newline at end of file diff --git a/matlab/tools/cmake/BuildMatlabArrowInterface.cmake b/matlab/tools/cmake/BuildMatlabArrowInterface.cmake index 0dda3fb770997..1f4ab05b0678f 100644 --- a/matlab/tools/cmake/BuildMatlabArrowInterface.cmake +++ b/matlab/tools/cmake/BuildMatlabArrowInterface.cmake @@ -24,7 +24,8 @@ set(MATLAB_ARROW_LIBMEXCLASS_CLIENT_FETCH_CONTENT_NAME libmexclass) # libmexclass is accessible for CI without permission issues. set(MATLAB_ARROW_LIBMEXCLASS_CLIENT_FETCH_CONTENT_GIT_REPOSITORY "https://github.com/mathworks/libmexclass.git") # Use a specific Git commit hash to avoid libmexclass version changing unexpectedly. -set(MATLAB_ARROW_LIBMEXCLASS_CLIENT_FETCH_CONTENT_GIT_TAG "44c15d0") +set(MATLAB_ARROW_LIBMEXCLASS_CLIENT_FETCH_CONTENT_GIT_TAG "77f3d72") + set(MATLAB_ARROW_LIBMEXCLASS_CLIENT_FETCH_CONTENT_SOURCE_SUBDIR "libmexclass/cpp") # ------------------------------------------ @@ -34,7 +35,8 @@ set(MATLAB_ARROW_LIBMEXCLASS_CLIENT_FETCH_CONTENT_SOURCE_SUBDIR "libmexclass/cpp set(MATLAB_ARROW_LIBMEXCLASS_CLIENT_PROXY_LIBRARY_NAME arrowproxy) set(MATLAB_ARROW_LIBMEXCLASS_CLIENT_PROXY_LIBRARY_ROOT_INCLUDE_DIR "${CMAKE_SOURCE_DIR}/src/cpp") set(MATLAB_ARROW_LIBMEXCLASS_CLIENT_PROXY_INCLUDE_DIR "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/array/proxy" - "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/bit") + "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/bit" + "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/error") set(MATLAB_ARROW_LIBMEXCLASS_CLIENT_PROXY_SOURCES "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/array/proxy/array.cc" "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/bit/bit_pack_matlab_logical_array.cc" "${CMAKE_SOURCE_DIR}/src/cpp/arrow/matlab/bit/bit_unpack_arrow_buffer.cc") From a2e1f6648f807e00fa144094222061c1b3bcb6e1 Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Thu, 8 Jun 2023 16:39:12 +0900 Subject: [PATCH 29/31] GH-35961: [C++][FlightSQL] Accept Protobuf 3.12.0 or later (#35962) ### Rationale for this change We can use `optional` with Protobuf 3.12.0 by specifying `--experimental_allow_proto3_optional` explicitly. If we accept Protobuf 3.12.0 or later for Flight SQL, we can use system Protobuf on Ubuntu 22.04. Because Ubuntu 22.04 ships Protobuf 3.12.4. ### What changes are included in this PR? Specify `--experimental_allow_proto3_optional` explicitly when Protobuf is < 3.15.0. ### Are these changes tested? Yes. ### Are there any user-facing changes? Yes. 
* Closes: #35961 Authored-by: Sutou Kouhei Signed-off-by: Sutou Kouhei --- ci/docker/debian-11-cpp.dockerfile | 3 ++- ci/docker/ubuntu-22.04-cpp.dockerfile | 5 ----- cpp/cmake_modules/ThirdpartyToolchain.cmake | 6 ++++-- cpp/src/arrow/flight/sql/CMakeLists.txt | 12 +++++++++--- cpp/src/arrow/flight/sql/client.cc | 16 ++++++++++++++++ cpp/src/arrow/flight/sql/server.cc | 8 ++++++++ dev/tasks/linux-packages/apache-arrow/Rakefile | 16 +++------------- 7 files changed, 42 insertions(+), 24 deletions(-) diff --git a/ci/docker/debian-11-cpp.dockerfile b/ci/docker/debian-11-cpp.dockerfile index e6ac9e6071db3..00adc6bd6b3c9 100644 --- a/ci/docker/debian-11-cpp.dockerfile +++ b/ci/docker/debian-11-cpp.dockerfile @@ -65,6 +65,8 @@ RUN apt-get update -y -q && \ libgoogle-glog-dev \ libgrpc++-dev \ liblz4-dev \ + libprotobuf-dev \ + libprotoc-dev \ libre2-dev \ libsnappy-dev \ libssl-dev \ @@ -121,5 +123,4 @@ ENV absl_SOURCE=BUNDLED \ GTest_SOURCE=BUNDLED \ ORC_SOURCE=BUNDLED \ PATH=/usr/lib/ccache/:$PATH \ - Protobuf_SOURCE=BUNDLED \ xsimd_SOURCE=BUNDLED diff --git a/ci/docker/ubuntu-22.04-cpp.dockerfile b/ci/docker/ubuntu-22.04-cpp.dockerfile index 03a6bb9c7293c..e6fd44ff2d26e 100644 --- a/ci/docker/ubuntu-22.04-cpp.dockerfile +++ b/ci/docker/ubuntu-22.04-cpp.dockerfile @@ -160,11 +160,7 @@ RUN /arrow/ci/scripts/install_sccache.sh unknown-linux-musl /usr/local/bin # provided by the distribution: # - Abseil is old # - libc-ares-dev does not install CMake config files -# - flatbuffer is not packaged # - libgtest-dev only provide sources -# - libprotobuf-dev only provide sources -# ARROW-17051: this build uses static Protobuf, so we must also use -# static Arrow to run Flight/Flight SQL tests ENV absl_SOURCE=BUNDLED \ ARROW_ACERO=ON \ ARROW_BUILD_STATIC=ON \ @@ -199,6 +195,5 @@ ENV absl_SOURCE=BUNDLED \ PARQUET_BUILD_EXAMPLES=ON \ PARQUET_BUILD_EXECUTABLES=ON \ PATH=/usr/lib/ccache/:$PATH \ - Protobuf_SOURCE=BUNDLED \ PYTHON=python3 \ xsimd_SOURCE=BUNDLED diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index d0398d2d0d30e..9e8ecb5ceb369 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -1758,8 +1758,10 @@ endmacro() if(ARROW_WITH_PROTOBUF) if(ARROW_FLIGHT_SQL) - # Flight SQL uses proto3 optionals, which require 3.15 or later. - set(ARROW_PROTOBUF_REQUIRED_VERSION "3.15.0") + # Flight SQL uses proto3 optionals, which require 3.12 or later. 
+ # 3.12.0-3.14.0: need --experimental_allow_proto3_optional + # 3.15.0-: don't need --experimental_allow_proto3_optional + set(ARROW_PROTOBUF_REQUIRED_VERSION "3.12.0") elseif(ARROW_SUBSTRAIT) # Substrait protobuf files use proto3 syntax set(ARROW_PROTOBUF_REQUIRED_VERSION "3.0.0") diff --git a/cpp/src/arrow/flight/sql/CMakeLists.txt b/cpp/src/arrow/flight/sql/CMakeLists.txt index 628b02b9d2811..0f12dbfdf90be 100644 --- a/cpp/src/arrow/flight/sql/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/CMakeLists.txt @@ -27,10 +27,16 @@ set(FLIGHT_SQL_GENERATED_PROTO_FILES "${CMAKE_CURRENT_BINARY_DIR}/FlightSql.pb.c set(PROTO_DEPENDS ${FLIGHT_SQL_PROTO} ${ARROW_PROTOBUF_LIBPROTOBUF}) +set(FLIGHT_SQL_PROTOC_COMMAND + ${ARROW_PROTOBUF_PROTOC} "-I${FLIGHT_SQL_PROTO_PATH}" + "--cpp_out=dllexport_decl=ARROW_FLIGHT_SQL_EXPORT:${CMAKE_CURRENT_BINARY_DIR}") +if(Protobuf_VERSION VERSION_LESS 3.15) + list(APPEND FLIGHT_SQL_PROTOC_COMMAND "--experimental_allow_proto3_optional") +endif() +list(APPEND FLIGHT_SQL_PROTOC_COMMAND "${FLIGHT_SQL_PROTO}") + add_custom_command(OUTPUT ${FLIGHT_SQL_GENERATED_PROTO_FILES} - COMMAND ${ARROW_PROTOBUF_PROTOC} "-I${FLIGHT_SQL_PROTO_PATH}" - "--cpp_out=dllexport_decl=ARROW_FLIGHT_SQL_EXPORT:${CMAKE_CURRENT_BINARY_DIR}" - "${FLIGHT_SQL_PROTO}" + COMMAND ${FLIGHT_SQL_PROTOC_COMMAND} DEPENDS ${PROTO_DEPENDS}) set_source_files_properties(${FLIGHT_SQL_GENERATED_PROTO_FILES} PROPERTIES GENERATED TRUE) diff --git a/cpp/src/arrow/flight/sql/client.cc b/cpp/src/arrow/flight/sql/client.cc index 25bf8e384ef06..b0d77563bc330 100644 --- a/cpp/src/arrow/flight/sql/client.cc +++ b/cpp/src/arrow/flight/sql/client.cc @@ -41,14 +41,22 @@ namespace { arrow::Result GetFlightDescriptorForCommand( const google::protobuf::Message& command) { google::protobuf::Any any; +#if PROTOBUF_VERSION >= 3015000 if (!any.PackFrom(command)) { return Status::SerializationError("Failed to pack ", command.GetTypeName()); } +#else + any.PackFrom(command); +#endif std::string buf; +#if PROTOBUF_VERSION >= 3015000 if (!any.SerializeToString(&buf)) { return Status::SerializationError("Failed to serialize ", command.GetTypeName()); } +#else + any.SerializeToString(&buf); +#endif return FlightDescriptor::Command(buf); } @@ -71,16 +79,24 @@ arrow::Result> GetSchemaForCommand( ::arrow::Result PackAction(const std::string& action_type, const google::protobuf::Message& message) { google::protobuf::Any any; +#if PROTOBUF_VERSION >= 3015000 if (!any.PackFrom(message)) { return Status::SerializationError("Could not pack ", message.GetTypeName(), " into Any"); } +#else + any.PackFrom(message); +#endif std::string buffer; +#if PROTOBUF_VERSION >= 3015000 if (!any.SerializeToString(&buffer)) { return Status::SerializationError("Could not serialize packed ", message.GetTypeName()); } +#else + any.SerializeToString(&buffer); +#endif Action action; action.type = action_type; diff --git a/cpp/src/arrow/flight/sql/server.cc b/cpp/src/arrow/flight/sql/server.cc index 7621711308cd4..3975a0135331a 100644 --- a/cpp/src/arrow/flight/sql/server.cc +++ b/cpp/src/arrow/flight/sql/server.cc @@ -361,14 +361,22 @@ arrow::Result ParseActionEndTransactionRequest( arrow::Result PackActionResult(const google::protobuf::Message& message) { google::protobuf::Any any; +#if PROTOBUF_VERSION >= 3015000 if (!any.PackFrom(message)) { return Status::IOError("Failed to pack ", message.GetTypeName()); } +#else + any.PackFrom(message); +#endif std::string buffer; +#if PROTOBUF_VERSION >= 3015000 if (!any.SerializeToString(&buffer)) { return 
Status::IOError("Failed to serialize packed ", message.GetTypeName()); } +#else + any.SerializeToString(&buffer); +#endif return Result{Buffer::FromString(std::move(buffer))}; } diff --git a/dev/tasks/linux-packages/apache-arrow/Rakefile b/dev/tasks/linux-packages/apache-arrow/Rakefile index 962d34340a08c..cdc6d2cf35b66 100644 --- a/dev/tasks/linux-packages/apache-arrow/Rakefile +++ b/dev/tasks/linux-packages/apache-arrow/Rakefile @@ -107,21 +107,11 @@ class ApacheArrowPackageTask < PackageTask end def apt_prepare_debian_control_protobuf(control, target) - # Flight requires Protobuf 3.15.0 or later but Ubuntu 22.04 - # doesn't provide Protobuf 3.15.0 or later yet. - # - # See also: - # * cpp/cmake_modules/ThirdpartyToolchain.cmake - # * https://packages.debian.org/search?keywords=libprotobuf-dev - # * https://packages.ubuntu.com/search?keywords=libprotobuf-dev - # - # We can use system Protobuf without Flight because we can use - # Protobuf 3.0.0 or later without Flight. case target - when /\Adebian-bookworm/ - use_system_protobuf = "" - else + when /\Aubuntu-focal/ use_system_protobuf = "#" + else + use_system_protobuf = "" end control.gsub(/@USE_SYSTEM_PROTOBUF@/, use_system_protobuf) end From e920bed4cc0f7826cf979a36283fceed403d2860 Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Fri, 9 Jun 2023 01:02:30 +0900 Subject: [PATCH 30/31] GH-35990: [CI][C++][Windows] Don't use -l for "choco list" (#35991) ### Rationale for this change Because it's removed and needless now. https://docs.chocolatey.org/en-us/guides/upgrading-to-chocolatey-v2-v6#the-list-command-now-lists-local-packages-only-and-the-local-only-and-lo-options-have-been-removed > The List Command Now Lists Local Packages Only and the --local-only and -lo Options Have Been Removed > > In version 1.0.0 of Chocolatey CLI, we added notices that the choco list command will list only local packages, and deprecated the -l and it's alias options. See this [GitHub issue for more information](https://github.com/chocolatey/choco/issues/158). We have also removed the -a and it's alias options from the list command as it no longer made sense to have that option once side-by-side installs were removed. ### What changes are included in this PR? Just removed "-l". ### Are these changes tested? Yes. ### Are there any user-facing changes? No. * Closes: #35990 Authored-by: Sutou Kouhei Signed-off-by: Antoine Pitrou --- .github/workflows/cpp.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/cpp.yml b/.github/workflows/cpp.yml index b5631ca3bd82b..28cb16f354bad 100644 --- a/.github/workflows/cpp.yml +++ b/.github/workflows/cpp.yml @@ -267,7 +267,7 @@ jobs: /d 1 ` /f - name: Installed Packages - run: choco list -l + run: choco list - name: Install Dependencies run: choco install -y --no-progress openssl - name: Checkout Arrow From 8b5919d886125c3dae9dd5484f7e9e45ae8580d3 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 8 Jun 2023 14:12:49 -0400 Subject: [PATCH 31/31] GH-35515: [C++][Python] Add non decomposable aggregation UDF (#35514) ### Rationale for this change Non decomposable aggregation is aggregation that cannot be split into consume/merge/finalize. This is often when the logic rewritten with external python libraries (numpy, pandas, statmodels, etc) and those either cannot be decomposed or not worthy the effect (these are often one-off function instead of reusable one). This PR implements the support for non decomposable aggregation UDFs. 
The major issue with a non decomposable UDF is that it needs to see all of the data at once, unlike a scalar UDF, which only needs to see one batch at a time. On its own this is not very useful, since it is equivalent to collecting all of the data into a pd.DataFrame and applying the UDF to it. However, one very useful application of non decomposable UDFs is in combination with segmented aggregation. To recap, segmented aggregation works on ordered data and passes one logical chunk at a time (e.g., all data with the same date). With segmented aggregation and a non decomposable aggregation UDF, the user can apply custom aggregation logic over a large stream of ordered data with only the memory overhead of a single segment. ### What changes are included in this PR? This PR is currently WIP and not ready for review. So far I have implemented the minimal amount of code to make a basic test work, but it still needs cleanup, error handling, etc. * [x] First round of self review * [x] Second round of self review * [x] Implement and test unary * [x] Implement and test varargs * [x] Implement and test Acero support with segmented aggregation ### Are these changes tested? Added new tests that call the aggregation with compute and Acero. The compute test calls the aggregation on the full array. The Acero test calls the aggregation with segmented aggregation. ### Are there any user-facing changes? * Closes: #35515 Lead-authored-by: Li Jin Co-authored-by: Weston Pace Signed-off-by: Li Jin --- .../arrow/engine/substrait/extension_set.cc | 27 +-- python/pyarrow/_compute.pxd | 6 +- python/pyarrow/_compute.pyx | 161 +++++++++++---- python/pyarrow/compute.py | 3 +- python/pyarrow/conftest.py | 56 ++++++ python/pyarrow/includes/libarrow.pxd | 8 +- python/pyarrow/src/arrow/python/udf.cc | 188 +++++++++++++++++- python/pyarrow/src/arrow/python/udf.h | 11 +- python/pyarrow/tests/test_substrait.py | 156 ++++++++++++++- python/pyarrow/tests/test_udf.py | 113 ++++++++++- 10 files changed, 652 insertions(+), 77 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index 5501889d7a20f..d89248383b722 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -954,7 +954,9 @@ ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate( return Status::Invalid("Expected aggregate call ", call.id().uri, "#", call.id().name, " to have at least one argument"); } - case 1: { + default: { + // Handles all arity > 0 + std::shared_ptr options = nullptr; if (arrow_function_name == "stddev" || arrow_function_name == "variance") { // See the following URL for the spec of stddev and variance: @@ -981,21 +983,22 @@ ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate( } fixed_arrow_func += arrow_function_name; - ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(0)); - const FieldRef* arg_ref = arg.field_ref(); - if (!arg_ref) { - return Status::Invalid("Expected an aggregate call ", call.id().uri, "#", - call.id().name, " to have a direct reference"); + std::vector target; + for (int i = 0; i < call.size(); i++) { + ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(i)); + const FieldRef* arg_ref = arg.field_ref(); + if (!arg_ref) { + return Status::Invalid("Expected an aggregate call ", call.id().uri, "#", + call.id().name, " to have a direct reference"); + } + // Copy arg_ref here because field_ref() return const FieldRef* + target.emplace_back(*arg_ref); } - return compute::Aggregate{std::move(fixed_arrow_func), -
options ? std::move(options) : nullptr, *arg_ref, ""}; + options ? std::move(options) : nullptr, + std::move(target), ""}; } - default: - break; } - return Status::NotImplemented( - "Only nullary and unary aggregate functions are currently supported"); }; } diff --git a/python/pyarrow/_compute.pxd b/python/pyarrow/_compute.pxd index 2dc0de2d0bfec..29b37da3ac4ef 100644 --- a/python/pyarrow/_compute.pxd +++ b/python/pyarrow/_compute.pxd @@ -21,11 +21,11 @@ from pyarrow.lib cimport * from pyarrow.includes.common cimport * from pyarrow.includes.libarrow cimport * -cdef class ScalarUdfContext(_Weakrefable): +cdef class UdfContext(_Weakrefable): cdef: - CScalarUdfContext c_context + CUdfContext c_context - cdef void init(self, const CScalarUdfContext& c_context) + cdef void init(self, const CUdfContext& c_context) cdef class FunctionOptions(_Weakrefable): diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index a5db5be551456..eaf9d1dfb65cb 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2559,7 +2559,7 @@ cdef CExpression _bind(Expression filter, Schema schema) except *: deref(pyarrow_unwrap_schema(schema).get()))) -cdef class ScalarUdfContext: +cdef class UdfContext: """ Per-invocation function context/state. @@ -2571,7 +2571,7 @@ cdef class ScalarUdfContext: raise TypeError("Do not call {}'s constructor directly" .format(self.__class__.__name__)) - cdef void init(self, const CScalarUdfContext &c_context): + cdef void init(self, const CUdfContext &c_context): self.c_context = c_context @property @@ -2620,26 +2620,26 @@ cdef inline CFunctionDoc _make_function_doc(dict func_doc) except *: return f_doc -cdef object box_scalar_udf_context(const CScalarUdfContext& c_context): - cdef ScalarUdfContext context = ScalarUdfContext.__new__(ScalarUdfContext) +cdef object box_udf_context(const CUdfContext& c_context): + cdef UdfContext context = UdfContext.__new__(UdfContext) context.init(c_context) return context -cdef _udf_callback(user_function, const CScalarUdfContext& c_context, inputs): +cdef _udf_callback(user_function, const CUdfContext& c_context, inputs): """ - Helper callback function used to wrap the ScalarUdfContext from Python to C++ + Helper callback function used to wrap the UdfContext from Python to C++ execution. """ - context = box_scalar_udf_context(c_context) + context = box_udf_context(c_context) return user_function(context, *inputs) -def _get_scalar_udf_context(memory_pool, batch_length): - cdef CScalarUdfContext c_context +def _get_udf_context(memory_pool, batch_length): + cdef CUdfContext c_context c_context.pool = maybe_unbox_memory_pool(memory_pool) c_context.batch_length = batch_length - context = box_scalar_udf_context(c_context) + context = box_udf_context(c_context) return context @@ -2665,11 +2665,19 @@ cdef get_register_tabular_function(): return reg +cdef get_register_aggregate_function(): + cdef RegisterUdf reg = RegisterUdf.__new__(RegisterUdf) + reg.register_func = RegisterAggregateFunction + return reg + + def register_scalar_function(func, function_name, function_doc, in_types, out_type, func_registry=None): """ Register a user-defined scalar function. + This API is EXPERIMENTAL. + A scalar function is a function that executes elementwise operations on arrays or scalars, i.e. 
a scalar function must be computed row-by-row with no state where each output row @@ -2684,17 +2692,18 @@ def register_scalar_function(func, function_name, function_doc, in_types, out_ty func : callable A callable implementing the user-defined function. The first argument is the context argument of type - ScalarUdfContext. + UdfContext. Then, it must take arguments equal to the number of in_types defined. It must return an Array or Scalar matching the out_type. It must return a Scalar if all arguments are scalar, else it must return an Array. To define a varargs function, pass a callable that takes - varargs. The last in_type will be the type of all varargs + *args. The last in_type will be the type of all varargs arguments. function_name : str - Name of the function. This name must be globally unique. + Name of the function. There should only be one function + registered with this name in the function registry. function_doc : dict A dictionary object with keys "summary" (str), and "description" (str). @@ -2738,9 +2747,86 @@ def register_scalar_function(func, function_name, function_doc, in_types, out_ty 21 ] """ - return _register_scalar_like_function(get_register_scalar_function(), - func, function_name, function_doc, in_types, - out_type, func_registry) + return _register_user_defined_function(get_register_scalar_function(), + func, function_name, function_doc, in_types, + out_type, func_registry) + + +def register_aggregate_function(func, function_name, function_doc, in_types, out_type, + func_registry=None): + """ + Register a user-defined non-decomposable aggregate function. + + This API is EXPERIMENTAL. + + A non-decomposable aggregation function is a function that executes + aggregate operations on the whole data that it is aggregating. + In other words, non-decomposable aggregate function cannot be + split into consume/merge/finalize steps. + + This is often used with ordered or segmented aggregation where groups + can be emit before accumulating all of the input data. + + Parameters + ---------- + func : callable + A callable implementing the user-defined function. + The first argument is the context argument of type + UdfContext. + Then, it must take arguments equal to the number of + in_types defined. It must return a Scalar matching the + out_type. + To define a varargs function, pass a callable that takes + *args. The in_type needs to match in type of inputs when + the function gets called. + function_name : str + Name of the function. This name must be unique, i.e., + there should only be one function registered with + this name in the function registry. + function_doc : dict + A dictionary object with keys "summary" (str), + and "description" (str). + in_types : Dict[str, DataType] + A dictionary mapping function argument names to + their respective DataType. + The argument names will be used to generate + documentation for the function. The number of + arguments specified here determines the function + arity. + out_type : DataType + Output type of the function. + func_registry : FunctionRegistry + Optional function registry to use instead of the default global one. + + Examples + -------- + >>> import numpy as np + >>> import pyarrow as pa + >>> import pyarrow.compute as pc + >>> + >>> func_doc = {} + >>> func_doc["summary"] = "simple median udf" + >>> func_doc["description"] = "compute median" + >>> + >>> def compute_median(ctx, array): + ... 
return pa.scalar(np.median(array)) + >>> + >>> func_name = "py_compute_median" + >>> in_types = {"array": pa.int64()} + >>> out_type = pa.float64() + >>> pc.register_aggregate_function(compute_median, func_name, func_doc, + ... in_types, out_type) + >>> + >>> func = pc.get_function(func_name) + >>> func.name + 'py_compute_median' + >>> answer = pc.call_function(func_name, [pa.array([20, 40])]) + >>> answer + + """ + return _register_user_defined_function(get_register_aggregate_function(), + func, function_name, function_doc, in_types, + out_type, func_registry) def register_tabular_function(func, function_name, function_doc, in_types, out_type, @@ -2748,8 +2834,10 @@ def register_tabular_function(func, function_name, function_doc, in_types, out_t """ Register a user-defined tabular function. + This API is EXPERIMENTAL. + A tabular function is one accepting a context argument of type - ScalarUdfContext and returning a generator of struct arrays. + UdfContext and returning a generator of struct arrays. The in_types argument must be empty and the out_type argument specifies a schema. Each struct array must have field types correspoding to the schema. @@ -2759,11 +2847,12 @@ def register_tabular_function(func, function_name, function_doc, in_types, out_t func : callable A callable implementing the user-defined function. The only argument is the context argument of type - ScalarUdfContext. It must return a callable that + UdfContext. It must return a callable that returns on each invocation a StructArray matching the out_type, where an empty array indicates end. function_name : str - Name of the function. This name must be globally unique. + Name of the function. There should only be one function + registered with this name in the function registry. function_doc : dict A dictionary object with keys "summary" (str), and "description" (str). @@ -2783,46 +2872,34 @@ def register_tabular_function(func, function_name, function_doc, in_types, out_t with nogil: c_type = make_shared[CStructType](deref(c_schema).fields()) out_type = pyarrow_wrap_data_type(c_type) - return _register_scalar_like_function(get_register_tabular_function(), - func, function_name, function_doc, in_types, - out_type, func_registry) + return _register_user_defined_function(get_register_tabular_function(), + func, function_name, function_doc, in_types, + out_type, func_registry) -def _register_scalar_like_function(register_func, func, function_name, function_doc, in_types, - out_type, func_registry=None): +def _register_user_defined_function(register_func, func, function_name, function_doc, in_types, + out_type, func_registry=None): """ - Register a user-defined scalar-like function. + Register a user-defined function. - A scalar-like function is a callable accepting a first - context argument of type ScalarUdfContext as well as - possibly additional Arrow arguments, and returning a - an Arrow result appropriate for the kind of function. - A scalar function and a tabular function are examples - for scalar-like functions. - This function is normally not called directly but via - register_scalar_function or register_tabular_function. + This method itself doesn't care about the type of the UDF + (i.e., scalar vs tabular vs aggregate) Parameters ---------- register_func: object - An object holding a CRegisterUdf in a "register_func" attribute. Use - get_register_scalar_function() for a scalar function and - get_register_tabular_function() for a tabular function. + An object holding a CRegisterUdf in a "register_func" attribute. 
func : callable A callable implementing the user-defined function. - See register_scalar_function and - register_tabular_function for details. - function_name : str - Name of the function. This name must be globally unique. + Name of the function. There should only be one function + registered with this name in the function registry. function_doc : dict A dictionary object with keys "summary" (str), and "description" (str). in_types : Dict[str, DataType] A dictionary mapping function argument names to their respective DataType. - See register_scalar_function and - register_tabular_function for details. out_type : DataType Output type of the function. func_registry : FunctionRegistry diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index e299d44c04e16..e92f09354771f 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -84,7 +84,8 @@ call_tabular_function, register_scalar_function, register_tabular_function, - ScalarUdfContext, + register_aggregate_function, + UdfContext, # Expressions Expression, ) diff --git a/python/pyarrow/conftest.py b/python/pyarrow/conftest.py index ef09393cfbd6a..f32cbf01efcd6 100644 --- a/python/pyarrow/conftest.py +++ b/python/pyarrow/conftest.py @@ -278,3 +278,59 @@ def unary_function(ctx, x): {"array": pa.int64()}, pa.int64()) return unary_function, func_name + + +@pytest.fixture(scope="session") +def unary_agg_func_fixture(): + """ + Register a unary aggregate function + """ + from pyarrow import compute as pc + import numpy as np + + def func(ctx, x): + return pa.scalar(np.nanmean(x)) + + func_name = "y=avg(x)" + func_doc = {"summary": "y=avg(x)", + "description": "find mean of x"} + + pc.register_aggregate_function(func, + func_name, + func_doc, + { + "x": pa.float64(), + }, + pa.float64() + ) + return func, func_name + + +@pytest.fixture(scope="session") +def varargs_agg_func_fixture(): + """ + Register a unary aggregate function + """ + from pyarrow import compute as pc + import numpy as np + + def func(ctx, *args): + sum = 0.0 + for arg in args: + sum += np.nanmean(arg) + return pa.scalar(sum) + + func_name = "y=sum_mean(x...)" + func_doc = {"summary": "Varargs aggregate", + "description": "Varargs aggregate"} + + pc.register_aggregate_function(func, + func_name, + func_doc, + { + "x": pa.int64(), + "y": pa.float64() + }, + pa.float64() + ) + return func, func_name diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 3190877ea0997..86f21f4b528e8 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2775,7 +2775,7 @@ cdef extern from "arrow/util/byte_size.h" namespace "arrow::util" nogil: int64_t TotalBufferSize(const CRecordBatch& record_batch) int64_t TotalBufferSize(const CTable& table) -ctypedef PyObject* CallbackUdf(object user_function, const CScalarUdfContext& context, object inputs) +ctypedef PyObject* CallbackUdf(object user_function, const CUdfContext& context, object inputs) cdef extern from "arrow/api.h" namespace "arrow" nogil: @@ -2786,7 +2786,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: cdef extern from "arrow/python/udf.h" namespace "arrow::py" nogil: - cdef cppclass CScalarUdfContext" arrow::py::ScalarUdfContext": + cdef cppclass CUdfContext" arrow::py::UdfContext": CMemoryPool *pool int64_t batch_length @@ -2805,5 +2805,9 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py" nogil: function[CallbackUdf] wrapper, const CUdfOptions& options, CFunctionRegistry* registry) + CStatus 
RegisterAggregateFunction(PyObject* function, + function[CallbackUdf] wrapper, const CUdfOptions& options, + CFunctionRegistry* registry) + CResult[shared_ptr[CRecordBatchReader]] CallTabularFunction( const c_string& func_name, const vector[CDatum]& args, CFunctionRegistry* registry) diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 7d63adb8352e8..06c116af820db 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -16,14 +16,16 @@ // under the License. #include "arrow/python/udf.h" +#include "arrow/compute/api_aggregate.h" #include "arrow/compute/function.h" #include "arrow/compute/kernel.h" #include "arrow/python/common.h" +#include "arrow/table.h" #include "arrow/util/checked_cast.h" namespace arrow { +using internal::checked_cast; namespace py { - namespace { struct PythonUdfKernelState : public compute::KernelState { @@ -65,6 +67,26 @@ struct PythonUdfKernelInit { std::shared_ptr function; }; +struct ScalarUdfAggregator : public compute::KernelState { + virtual Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) = 0; + virtual Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) = 0; + virtual Status Finalize(compute::KernelContext* ctx, Datum* out) = 0; +}; + +arrow::Status AggregateUdfConsume(compute::KernelContext* ctx, + const compute::ExecSpan& batch) { + return checked_cast(ctx->state())->Consume(ctx, batch); +} + +arrow::Status AggregateUdfMerge(compute::KernelContext* ctx, compute::KernelState&& src, + compute::KernelState* dst) { + return checked_cast(dst)->MergeFrom(ctx, std::move(src)); +} + +arrow::Status AggregateUdfFinalize(compute::KernelContext* ctx, arrow::Datum* out) { + return checked_cast(ctx->state())->Finalize(ctx, out); +} + struct PythonTableUdfKernelInit { PythonTableUdfKernelInit(std::shared_ptr function_maker, UdfWrapperCallback cb) @@ -82,12 +104,12 @@ struct PythonTableUdfKernelInit { Result> operator()( compute::KernelContext* ctx, const compute::KernelInitArgs&) { - ScalarUdfContext scalar_udf_context{ctx->memory_pool(), /*batch_length=*/0}; + UdfContext udf_context{ctx->memory_pool(), /*batch_length=*/0}; std::unique_ptr function; - RETURN_NOT_OK(SafeCallIntoPython([this, &scalar_udf_context, &function] { + RETURN_NOT_OK(SafeCallIntoPython([this, &udf_context, &function] { OwnedRef empty_tuple(PyTuple_New(0)); function = std::make_unique( - cb(function_maker->obj(), scalar_udf_context, empty_tuple.obj())); + cb(function_maker->obj(), udf_context, empty_tuple.obj())); RETURN_NOT_OK(CheckPyError()); return Status::OK(); })); @@ -101,6 +123,105 @@ struct PythonTableUdfKernelInit { UdfWrapperCallback cb; }; +struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator { + PythonUdfScalarAggregatorImpl(UdfWrapperCallback agg_cb, + std::shared_ptr agg_function, + std::vector> input_types, + std::shared_ptr output_type) + : agg_cb(std::move(agg_cb)), + agg_function(agg_function), + output_type(std::move(output_type)) { + Py_INCREF(agg_function->obj()); + std::vector> fields; + for (size_t i = 0; i < input_types.size(); i++) { + fields.push_back(field("", input_types[i])); + } + input_schema = schema(std::move(fields)); + }; + + ~PythonUdfScalarAggregatorImpl() override { + if (_Py_IsFinalizing()) { + agg_function->detach(); + } + } + + Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) override { + ARROW_ASSIGN_OR_RAISE( + auto rb, batch.ToExecBatch().ToRecordBatch(input_schema, 
ctx->memory_pool())); + values.push_back(std::move(rb)); + return Status::OK(); + } + + Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) override { + auto& other_values = checked_cast(src).values; + values.insert(values.end(), std::make_move_iterator(other_values.begin()), + std::make_move_iterator(other_values.end())); + + other_values.erase(other_values.begin(), other_values.end()); + return Status::OK(); + } + + Status Finalize(compute::KernelContext* ctx, Datum* out) override { + auto state = + arrow::internal::checked_cast(ctx->state()); + std::shared_ptr& function = state->agg_function; + const int num_args = input_schema->num_fields(); + + // Note: The way that batches are concatenated together + // would result in using double amount of the memory. + // This is OK for now because non decomposable aggregate + // UDF is supposed to be used with segmented aggregation + // where the size of the segment is more or less constant + // so doubling that is not a big deal. This can be also + // improved in the future to use more efficient way to + // concatenate. + ARROW_ASSIGN_OR_RAISE(auto table, + arrow::Table::FromRecordBatches(input_schema, values)); + ARROW_ASSIGN_OR_RAISE(table, table->CombineChunks(ctx->memory_pool())); + UdfContext udf_context{ctx->memory_pool(), table->num_rows()}; + + if (table->num_rows() == 0) { + return Status::Invalid("Finalized is called with empty inputs"); + } + + RETURN_NOT_OK(SafeCallIntoPython([&] { + std::unique_ptr result; + OwnedRef arg_tuple(PyTuple_New(num_args)); + RETURN_NOT_OK(CheckPyError()); + + for (int arg_id = 0; arg_id < num_args; arg_id++) { + // Since we combined chunks there is only one chunk + std::shared_ptr c_data = table->column(arg_id)->chunk(0); + PyObject* data = wrap_array(c_data); + PyTuple_SetItem(arg_tuple.obj(), arg_id, data); + } + result = std::make_unique( + agg_cb(function->obj(), udf_context, arg_tuple.obj())); + RETURN_NOT_OK(CheckPyError()); + // unwrapping the output for expected output type + if (is_scalar(result->obj())) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr val, unwrap_scalar(result->obj())); + if (*output_type != *val->type) { + return Status::TypeError("Expected output datatype ", output_type->ToString(), + ", but function returned datatype ", + val->type->ToString()); + } + out->value = std::move(val); + return Status::OK(); + } + return Status::TypeError("Unexpected output type: ", + Py_TYPE(result->obj())->tp_name, " (expected Scalar)"); + })); + return Status::OK(); + } + + UdfWrapperCallback agg_cb; + std::vector> values; + std::shared_ptr agg_function; + std::shared_ptr input_schema; + std::shared_ptr output_type; +}; + struct PythonUdf : public PythonUdfKernelState { PythonUdf(std::shared_ptr function, UdfWrapperCallback cb, std::vector input_types, compute::OutputType output_type) @@ -130,7 +251,7 @@ struct PythonUdf : public PythonUdfKernelState { auto state = arrow::internal::checked_cast(ctx->state()); std::shared_ptr& function = state->function; const int num_args = batch.num_values(); - ScalarUdfContext scalar_udf_context{ctx->memory_pool(), batch.length}; + UdfContext udf_context{ctx->memory_pool(), batch.length}; OwnedRef arg_tuple(PyTuple_New(num_args)); RETURN_NOT_OK(CheckPyError()); @@ -146,7 +267,7 @@ struct PythonUdf : public PythonUdfKernelState { } } - OwnedRef result(cb(function->obj(), scalar_udf_context, arg_tuple.obj())); + OwnedRef result(cb(function->obj(), udf_context, arg_tuple.obj())); RETURN_NOT_OK(CheckPyError()); // unwrapping the output for expected 
output type if (is_array(result.obj())) { @@ -234,6 +355,61 @@ Status RegisterTabularFunction(PyObject* user_function, UdfWrapperCallback wrapp wrapper, options, registry); } +Status AddAggKernel(std::shared_ptr sig, + compute::KernelInit init, compute::ScalarAggregateFunction* func) { + compute::ScalarAggregateKernel kernel(std::move(sig), std::move(init), + AggregateUdfConsume, AggregateUdfMerge, + AggregateUdfFinalize, /*ordered=*/false); + RETURN_NOT_OK(func->AddKernel(std::move(kernel))); + return Status::OK(); +} + +Status RegisterAggregateFunction(PyObject* agg_function, UdfWrapperCallback agg_wrapper, + const UdfOptions& options, + compute::FunctionRegistry* registry) { + if (!PyCallable_Check(agg_function)) { + return Status::TypeError("Expected a callable Python object."); + } + + if (registry == NULLPTR) { + registry = compute::GetFunctionRegistry(); + } + + // Py_INCREF here so that once a function is registered + // its refcount gets increased by 1 and doesn't get gced + // if all existing refs are gone + Py_INCREF(agg_function); + + static auto default_scalar_aggregate_options = + compute::ScalarAggregateOptions::Defaults(); + auto aggregate_func = std::make_shared( + options.func_name, options.arity, options.func_doc, + &default_scalar_aggregate_options); + + std::vector input_types; + for (const auto& in_dtype : options.input_types) { + input_types.emplace_back(in_dtype); + } + compute::OutputType output_type(options.output_type); + + compute::KernelInit init = [agg_wrapper, agg_function, options]( + compute::KernelContext* ctx, + const compute::KernelInitArgs& args) + -> Result> { + return std::make_unique( + agg_wrapper, std::make_shared(agg_function), options.input_types, + options.output_type); + }; + + RETURN_NOT_OK(AddAggKernel( + compute::KernelSignature::Make(std::move(input_types), std::move(output_type), + options.arity.is_varargs), + init, aggregate_func.get())); + + RETURN_NOT_OK(registry->AddFunction(std::move(aggregate_func))); + return Status::OK(); +} + Result> CallTabularFunction( const std::string& func_name, const std::vector& args, compute::FunctionRegistry* registry) { diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h index b3dcc9ccf44e9..682cbb2ffe8d5 100644 --- a/python/pyarrow/src/arrow/python/udf.h +++ b/python/pyarrow/src/arrow/python/udf.h @@ -43,14 +43,14 @@ struct ARROW_PYTHON_EXPORT UdfOptions { std::shared_ptr output_type; }; -/// \brief A context passed as the first argument of scalar UDF functions. -struct ARROW_PYTHON_EXPORT ScalarUdfContext { +/// \brief A context passed as the first argument of UDF functions. 
+struct ARROW_PYTHON_EXPORT UdfContext { MemoryPool* pool; int64_t batch_length; }; using UdfWrapperCallback = std::function; + PyObject* user_function, const UdfContext& context, PyObject* inputs)>; /// \brief register a Scalar user-defined-function from Python Status ARROW_PYTHON_EXPORT RegisterScalarFunction( @@ -62,6 +62,11 @@ Status ARROW_PYTHON_EXPORT RegisterTabularFunction( PyObject* user_function, UdfWrapperCallback wrapper, const UdfOptions& options, compute::FunctionRegistry* registry = NULLPTR); +/// \brief register a Aggregate user-defined-function from Python +Status ARROW_PYTHON_EXPORT RegisterAggregateFunction( + PyObject* user_function, UdfWrapperCallback wrapper, const UdfOptions& options, + compute::FunctionRegistry* registry = NULLPTR); + Result> ARROW_PYTHON_EXPORT CallTabularFunction(const std::string& func_name, const std::vector& args, compute::FunctionRegistry* registry = NULLPTR); diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py index d0da517ea7f12..34faaa157af4d 100644 --- a/python/pyarrow/tests/test_substrait.py +++ b/python/pyarrow/tests/test_substrait.py @@ -34,9 +34,9 @@ pytestmark = [pytest.mark.dataset, pytest.mark.substrait] -def mock_scalar_udf_context(batch_length=10): - from pyarrow._compute import _get_scalar_udf_context - return _get_scalar_udf_context(pa.default_memory_pool(), batch_length) +def mock_udf_context(batch_length=10): + from pyarrow._compute import _get_udf_context + return _get_udf_context(pa.default_memory_pool(), batch_length) def _write_dummy_data_to_disk(tmpdir, file_name, table): @@ -442,7 +442,7 @@ def table_provider(names, _): function, name = unary_func_fixture expected_tb = test_table.add_column(1, 'y', function( - mock_scalar_udf_context(10), test_table['x'])) + mock_udf_context(10), test_table['x'])) assert res_tb == expected_tb @@ -605,3 +605,151 @@ def table_provider(names, schema): expected = pa.Table.from_pydict({"out": [1, 2, 3]}) assert res_tb == expected + + +def test_aggregate_udf_basic(varargs_agg_func_fixture): + + test_table = pa.Table.from_pydict( + {"k": [1, 1, 2, 2], "v1": [1, 2, 3, 4], + "v2": [1.0, 1.0, 1.0, 1.0]} + ) + + def table_provider(names, _): + return test_table + + substrait_query = b""" +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "urn:arrow:substrait_simple_extension_function" + }, + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "y=sum_mean(x...)" + } + } + ], + "relations": [ + { + "root": { + "input": { + "extensionSingle": { + "common": { + "emit": { + "outputMapping": [ + 0, + 1 + ] + } + }, + "input": { + "read": { + "baseSchema": { + "names": [ + "k", + "v1", + "v2", + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["t1"] + } + } + }, + "detail": { + "@type": "/arrow.substrait_ext.SegmentedAggregateRel", + "segmentKeys": [ + { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + ], + "measures": [ + { + "measure": { + "functionReference": 1, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + 
"rootReference": {} + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": {} + } + } + } + ] + } + } + ] + } + } + }, + "names": [ + "k", + "v_avg" + ] + } + } + ], +} +""" + buf = pa._substrait._parse_json_plan(substrait_query) + reader = pa.substrait.run_query( + buf, table_provider=table_provider, use_threads=False) + res_tb = reader.read_all() + + expected_tb = pa.Table.from_pydict({ + 'k': [1, 2], + 'v_avg': [2.5, 4.5] + }) + + assert res_tb == expected_tb diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index 0f336555f7647..c0cfd3d26e800 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -24,21 +24,82 @@ # UDFs are all tested with a dataset scan pytestmark = pytest.mark.dataset +# For convience, most of the test here doesn't care about udf func docs +empty_udf_doc = {"summary": "", "description": ""} + try: import pyarrow.dataset as ds except ImportError: ds = None -def mock_scalar_udf_context(batch_length=10): - from pyarrow._compute import _get_scalar_udf_context - return _get_scalar_udf_context(pa.default_memory_pool(), batch_length) +def mock_udf_context(batch_length=10): + from pyarrow._compute import _get_udf_context + return _get_udf_context(pa.default_memory_pool(), batch_length) class MyError(RuntimeError): pass +@pytest.fixture(scope="session") +def exception_agg_func_fixture(): + def func(ctx, x): + raise RuntimeError("Oops") + return pa.scalar(len(x)) + + func_name = "y=exception_len(x)" + func_doc = empty_udf_doc + + pc.register_aggregate_function(func, + func_name, + func_doc, + { + "x": pa.int64(), + }, + pa.int64() + ) + return func, func_name + + +@pytest.fixture(scope="session") +def wrong_output_dtype_agg_func_fixture(scope="session"): + def func(ctx, x): + return pa.scalar(len(x), pa.int32()) + + func_name = "y=wrong_output_dtype(x)" + func_doc = empty_udf_doc + + pc.register_aggregate_function(func, + func_name, + func_doc, + { + "x": pa.int64(), + }, + pa.int64() + ) + return func, func_name + + +@pytest.fixture(scope="session") +def wrong_output_type_agg_func_fixture(scope="session"): + def func(ctx, x): + return len(x) + + func_name = "y=wrong_output_type(x)" + func_doc = empty_udf_doc + + pc.register_aggregate_function(func, + func_name, + func_doc, + { + "x": pa.int64(), + }, + pa.int64() + ) + return func, func_name + + @pytest.fixture(scope="session") def binary_func_fixture(): """ @@ -228,11 +289,11 @@ def check_scalar_function(func_fixture, if all_scalar: batch_length = 1 - expected_output = function(mock_scalar_udf_context(batch_length), *inputs) func = pc.get_function(name) assert func.name == name result = pc.call_function(name, inputs, length=batch_length) + expected_output = function(mock_udf_context(batch_length), *inputs) assert result == expected_output # At the moment there is an issue when handling nullary functions. # See: ARROW-15286 and ARROW-16290. 
@@ -593,3 +654,47 @@ def test_udt_datasource1_generator(): def test_udt_datasource1_exception(): with pytest.raises(RuntimeError, match='datasource1_exception'): _test_datasource1_udt(datasource1_exception) + + +def test_agg_basic(unary_agg_func_fixture): + arr = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64()) + result = pc.call_function("y=avg(x)", [arr]) + expected = pa.scalar(30.0) + assert result == expected + + +def test_agg_empty(unary_agg_func_fixture): + empty = pa.array([], pa.float64()) + + with pytest.raises(pa.ArrowInvalid, match='empty inputs'): + pc.call_function("y=avg(x)", [empty]) + + +def test_agg_wrong_output_dtype(wrong_output_dtype_agg_func_fixture): + arr = pa.array([10, 20, 30, 40, 50], pa.int64()) + with pytest.raises(pa.ArrowTypeError, match="output datatype"): + pc.call_function("y=wrong_output_dtype(x)", [arr]) + + +def test_agg_wrong_output_type(wrong_output_type_agg_func_fixture): + arr = pa.array([10, 20, 30, 40, 50], pa.int64()) + with pytest.raises(pa.ArrowTypeError, match="output type"): + pc.call_function("y=wrong_output_type(x)", [arr]) + + +def test_agg_varargs(varargs_agg_func_fixture): + arr1 = pa.array([10, 20, 30, 40, 50], pa.int64()) + arr2 = pa.array([1.0, 2.0, 3.0, 4.0, 5.0], pa.float64()) + + result = pc.call_function( + "y=sum_mean(x...)", [arr1, arr2] + ) + expected = pa.scalar(33.0) + assert result == expected + + +def test_agg_exception(exception_agg_func_fixture): + arr = pa.array([10, 20, 30, 40, 50, 60], pa.int64()) + + with pytest.raises(RuntimeError, match='Oops'): + pc.call_function("y=exception_len(x)", [arr])
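As a closing usage note, the varargs pattern exercised by `test_agg_varargs` and the `varargs_agg_func_fixture` above can also be written as a standalone script. This is a hedged sketch with illustrative names; it mirrors the fixture registered in this diff rather than adding anything new to the patch.

```python
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc

# Each declared input arrives as a full column; the UDF is free to combine them.
def sum_of_means(ctx, *args):
    return pa.scalar(float(sum(np.nanmean(a) for a in args)))

# "example_sum_of_means" is a hypothetical name used only for this sketch.
pc.register_aggregate_function(
    sum_of_means,
    "example_sum_of_means",
    {"summary": "sum of column means",
     "description": "Adds np.nanmean of every input column"},
    {"x": pa.int64(), "y": pa.float64()},
    pa.float64(),
)

v1 = pa.array([1, 2, 3, 4], pa.int64())
v2 = pa.array([1.0, 1.0, 1.0, 1.0], pa.float64())
print(pc.call_function("example_sum_of_means", [v1, v2]))  # mean(v1) + mean(v2) = 3.5
```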