diff --git a/examples/ipc_pyarrow/.gitignore b/examples/ipc_pyarrow/.gitignore new file mode 100644 index 00000000000..7d92770a26f --- /dev/null +++ b/examples/ipc_pyarrow/.gitignore @@ -0,0 +1 @@ +data.arrows diff --git a/examples/ipc_pyarrow/Cargo.toml b/examples/ipc_pyarrow/Cargo.toml new file mode 100644 index 00000000000..47bc82181b9 --- /dev/null +++ b/examples/ipc_pyarrow/Cargo.toml @@ -0,0 +1,7 @@ +[package] +name = "ipc_stream" +version = "0.1.0" +edition = "2018" + +[dependencies] +arrow2 = { path = "../../", default-features = false, features = ["io_ipc"] } diff --git a/examples/ipc_pyarrow/main.py b/examples/ipc_pyarrow/main.py new file mode 100644 index 00000000000..c3b1468dea8 --- /dev/null +++ b/examples/ipc_pyarrow/main.py @@ -0,0 +1,24 @@ +import pyarrow as pa +from time import sleep +import socket + +# Set up the data exchange socket +sk = socket.socket() +sk.bind(("127.0.0.1", 12989)) +sk.listen() + +data = [ + pa.array([1, 2, 3, 4]), + pa.array(["foo", "bar", "baz", None]), + pa.array([True, None, False, True]), +] + +batch = pa.record_batch(data, names=["f0", "f1", "f2"]) + +# Accept incoming connection and stream the data away +connection, address = sk.accept() +dummy_socket_file = connection.makefile("wb") +with pa.RecordBatchStreamWriter(dummy_socket_file, batch.schema) as writer: + for i in range(50): + writer.write_batch(batch) + sleep(1) diff --git a/examples/ipc_pyarrow/run.sh b/examples/ipc_pyarrow/run.sh new file mode 100755 index 00000000000..193fc1c4183 --- /dev/null +++ b/examples/ipc_pyarrow/run.sh @@ -0,0 +1,7 @@ +python main.py & +PRODUCER_PID=$! + +sleep 1 # wait for metadata to be available. +cargo run + +kill $PRODUCER_PID diff --git a/examples/ipc_pyarrow/src/main.rs b/examples/ipc_pyarrow/src/main.rs new file mode 100644 index 00000000000..ce92e4e1b21 --- /dev/null +++ b/examples/ipc_pyarrow/src/main.rs @@ -0,0 +1,33 @@ +use std::net::TcpStream; +use std::thread; +use std::time::Duration; + +use arrow2::array::{Array, Int64Array}; +use arrow2::datatypes::DataType; +use arrow2::error::Result; +use arrow2::io::ipc::read; + +fn main() -> Result<()> { + const ADDRESS: &str = "127.0.0.1:12989"; + + let mut reader = TcpStream::connect(ADDRESS)?; + let metadata = read::read_stream_metadata(&mut reader)?; + let mut stream = read::StreamReader::new(&mut reader, metadata); + + let mut idx = 0; + loop { + match stream.next() { + Some(x) => match x { + Ok(read::StreamState::Some(b)) => { + idx += 1; + println!("batch: {:?}", idx) + } + Ok(read::StreamState::Waiting) => thread::sleep(Duration::from_millis(2000)), + Err(l) => println!("{:?} ({})", l, idx), + }, + None => break, + }; + } + + Ok(()) +} diff --git a/guide/src/SUMMARY.md b/guide/src/SUMMARY.md index cdf5aec37ee..6ba760bec79 100644 --- a/guide/src/SUMMARY.md +++ b/guide/src/SUMMARY.md @@ -13,4 +13,5 @@ - [Read Parquet](./io/parquet_read.md) - [Write Parquet](./io/parquet_write.md) - [Read Arrow](./io/ipc_read.md) + - [Read Arrow stream](./io/ipc_stream_read.md) - [Write Arrow](./io/ipc_write.md) diff --git a/guide/src/io/ipc_stream_read.md b/guide/src/io/ipc_stream_read.md new file mode 100644 index 00000000000..0ab872ac986 --- /dev/null +++ b/guide/src/io/ipc_stream_read.md @@ -0,0 +1,21 @@ +# Read Arrow streams + +When compiled with feature `io_ipc`, this crate can be used to read Arrow streams. + +The example below shows how to read from a stream: + +```rust +{{#include ../../../examples/ipc_pyarrow/src/main.rs}} +``` + +e.g. written by pyarrow: + +```python,ignore +{{#include ../../../examples/ipc_pyarrow/main.py}} +``` + +via + +```bash,ignore +{{#include ../../../examples/ipc_pyarrow/run.sh}} +``` diff --git a/integration-testing/src/bin/arrow-file-to-stream.rs b/integration-testing/src/bin/arrow-file-to-stream.rs index 9b1ae23a0b5..34da2bdf075 100644 --- a/integration-testing/src/bin/arrow-file-to-stream.rs +++ b/integration-testing/src/bin/arrow-file-to-stream.rs @@ -30,7 +30,7 @@ fn main() -> Result<()> { let mut reader = read::FileReader::new(&mut f, metadata, None); let schema = reader.schema(); - let mut writer = StreamWriter::try_new(std::io::stdout(), &schema)?; + let mut writer = StreamWriter::try_new(std::io::stdout(), schema)?; reader.try_for_each(|batch| { let batch = batch?; diff --git a/integration-testing/src/bin/arrow-stream-to-file.rs b/integration-testing/src/bin/arrow-stream-to-file.rs index c9db624d64b..c17ff7d9e91 100644 --- a/integration-testing/src/bin/arrow-stream-to-file.rs +++ b/integration-testing/src/bin/arrow-stream-to-file.rs @@ -29,9 +29,9 @@ fn main() -> Result<()> { let mut writer = io::stdout(); - let mut writer = FileWriter::try_new(&mut writer, &schema)?; + let mut writer = FileWriter::try_new(&mut writer, schema)?; - arrow_stream_reader.try_for_each(|batch| writer.write(&batch?))?; + arrow_stream_reader.try_for_each(|batch| writer.write(&batch?.unwrap()))?; writer.finish()?; Ok(()) diff --git a/integration-testing/src/flight_client_scenarios/integration_test.rs b/integration-testing/src/flight_client_scenarios/integration_test.rs index 7cb128cdf67..820c82114b1 100644 --- a/integration-testing/src/flight_client_scenarios/integration_test.rs +++ b/integration-testing/src/flight_client_scenarios/integration_test.rs @@ -131,7 +131,7 @@ async fn send_batch( options: &write::IpcWriteOptions, ) -> Result { let (dictionary_flight_data, mut batch_flight_data) = - arrow_flight::utils::flight_data_from_arrow_batch(batch, &options); + arrow_flight::utils::flight_data_from_arrow_batch(batch, options); upload_tx .send_all(&mut stream::iter(dictionary_flight_data).map(Ok)) @@ -169,7 +169,7 @@ async fn verify_data( consume_flight_location( location, ticket.clone(), - &expected_data, + expected_data, expected_schema.clone(), ) .await?; diff --git a/integration-testing/src/flight_server_scenarios/integration_test.rs b/integration-testing/src/flight_server_scenarios/integration_test.rs index de6951fee43..16954647090 100644 --- a/integration-testing/src/flight_server_scenarios/integration_test.rs +++ b/integration-testing/src/flight_server_scenarios/integration_test.rs @@ -295,7 +295,7 @@ async fn record_batch_from_message( schema_ref, None, true, - &dictionaries_by_field, + dictionaries_by_field, MetadataVersion::V5, &mut reader, 0, diff --git a/src/io/ipc/read/mod.rs b/src/io/ipc/read/mod.rs index 8ed127e2698..76b139a589f 100644 --- a/src/io/ipc/read/mod.rs +++ b/src/io/ipc/read/mod.rs @@ -24,4 +24,4 @@ mod stream; pub use common::{read_dictionary, read_record_batch}; pub use reader::{read_file_metadata, FileMetadata, FileReader}; -pub use stream::{read_stream_metadata, StreamMetadata, StreamReader}; +pub use stream::{read_stream_metadata, StreamMetadata, StreamReader, StreamState}; diff --git a/src/io/ipc/read/stream.rs b/src/io/ipc/read/stream.rs index 5018f7953be..ed7f5955dfe 100644 --- a/src/io/ipc/read/stream.rs +++ b/src/io/ipc/read/stream.rs @@ -23,7 +23,7 @@ use gen::Schema::MetadataVersion; use crate::array::*; use crate::datatypes::Schema; use crate::error::{ArrowError, Result}; -use crate::record_batch::{RecordBatch, RecordBatchReader}; +use crate::record_batch::RecordBatch; use super::super::CONTINUATION_MARKER; use super::super::{convert, gen}; @@ -76,12 +76,27 @@ pub fn read_stream_metadata(reader: &mut R) -> Result { }) } +pub enum StreamState { + Waiting, + Some(RecordBatch), +} + +impl StreamState { + pub fn unwrap(self) -> RecordBatch { + if let StreamState::Some(batch) = self { + batch + } else { + panic!("The batch is not available") + } + } +} + /// Reads the next item pub fn read_next( reader: &mut R, metadata: &StreamMetadata, dictionaries_by_field: &mut Vec>, -) -> Result> { +) -> Result> { // determine metadata length let mut meta_size: [u8; 4] = [0; 4]; @@ -92,7 +107,7 @@ pub fn read_next( // Handle EOF without the "0xFFFFFFFF 0x00000000" // valid according to: // https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format - Ok(None) + Ok(Some(StreamState::Waiting)) } else { Err(ArrowError::from(e)) }; @@ -144,7 +159,7 @@ pub fn read_next( &mut reader, 0, ) - .map(Some) + .map(|x| Some(StreamState::Some(x))) } gen::Message::MessageHeader::DictionaryBatch => { let batch = message.header_as_dictionary_batch().ok_or_else(|| { @@ -168,7 +183,7 @@ pub fn read_next( // read the next message until we encounter a RecordBatch read_next(reader, metadata, dictionaries_by_field) } - gen::Message::MessageHeader::NONE => Ok(None), + gen::Message::MessageHeader::NONE => Ok(Some(StreamState::Waiting)), t => Err(ArrowError::Ipc(format!( "Reading types other than record batches not yet supported, unable to read {:?} ", t @@ -210,32 +225,26 @@ impl StreamReader { self.finished } - fn maybe_next(&mut self) -> Result> { + fn maybe_next(&mut self) -> Result> { + if self.finished { + return Ok(None); + } let batch = read_next( &mut self.reader, &self.metadata, &mut self.dictionaries_by_field, )?; if batch.is_none() { - self.finished = false; - } - if self.finished { - return Ok(None); + self.finished = true; } Ok(batch) } } impl Iterator for StreamReader { - type Item = Result; + type Item = Result; fn next(&mut self) -> Option { self.maybe_next().transpose() } } - -impl RecordBatchReader for StreamReader { - fn schema(&self) -> &Schema { - self.metadata.schema.as_ref() - } -} diff --git a/tests/it/io/ipc/common.rs b/tests/it/io/ipc/common.rs index f9f2be91fe5..6a46c03393e 100644 --- a/tests/it/io/ipc/common.rs +++ b/tests/it/io/ipc/common.rs @@ -62,6 +62,9 @@ pub fn read_arrow_stream(version: &str, file_name: &str) -> (Schema, Vec>().unwrap(), + reader + .map(|x| x.map(|x| x.unwrap())) + .collect::>() + .unwrap(), ) } diff --git a/tests/it/io/ipc/read/stream.rs b/tests/it/io/ipc/read/stream.rs index 9aa5e31a1bc..d6ad005e78d 100644 --- a/tests/it/io/ipc/read/stream.rs +++ b/tests/it/io/ipc/read/stream.rs @@ -22,7 +22,7 @@ fn test_file(version: &str, file_name: &str) -> Result<()> { batches .iter() - .zip(reader.map(|x| x.unwrap())) + .zip(reader.map(|x| x.unwrap().unwrap())) .for_each(|(lhs, rhs)| { assert_eq!(lhs, &rhs); }); diff --git a/tests/it/io/ipc/write/stream.rs b/tests/it/io/ipc/write/stream.rs index c13831839bf..e9019d01f9b 100644 --- a/tests/it/io/ipc/write/stream.rs +++ b/tests/it/io/ipc/write/stream.rs @@ -34,7 +34,10 @@ fn test_file(version: &str, file_name: &str) { assert_eq!(schema.as_ref(), &expected_schema); - let batches = reader.collect::>>().unwrap(); + let batches = reader + .map(|x| x.map(|x| x.unwrap())) + .collect::>>() + .unwrap(); assert_eq!(batches, expected_batches); }