diff --git a/Cargo.lock b/Cargo.lock index 243e70208a..77384ee699 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -450,6 +450,16 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3b5ca7a04898ad4bcd41c90c5285445ff5b791899bb1b0abdd2a2aa791211d7" +[[package]] +name = "bytelines" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "784face321c535fcd9a1456632fa720aa53ea0640b57341d961c8c09de2da59f" +dependencies = [ + "futures", + "tokio", +] + [[package]] name = "byteorder" version = "1.4.3" @@ -663,6 +673,7 @@ name = "connector_proxy" version = "0.0.0" dependencies = [ "async-trait", + "bytelines", "byteorder", "bytes", "clap 3.1.8", @@ -685,6 +696,7 @@ dependencies = [ "tempfile", "thiserror", "tokio", + "tokio-stream", "tokio-util", "tracing", "validator", diff --git a/crates/connector_proxy/Cargo.toml b/crates/connector_proxy/Cargo.toml index ed5fe19d9d..3651d0ed5e 100644 --- a/crates/connector_proxy/Cargo.toml +++ b/crates/connector_proxy/Cargo.toml @@ -33,5 +33,7 @@ tempfile="*" thiserror = "*" tokio = { version = "1.15.0", features = ["full"] } tokio-util = { version = "*", features = ["io"] } +tokio-stream = { version = "*", features = ["io-util"] } +bytelines = "*" tracing="*" validator = { version = "*", features = ["derive"] } \ No newline at end of file diff --git a/crates/connector_proxy/src/interceptors/network_tunnel_capture_interceptor.rs b/crates/connector_proxy/src/interceptors/network_tunnel_capture_interceptor.rs index ece41962c0..2d05e61b72 100644 --- a/crates/connector_proxy/src/interceptors/network_tunnel_capture_interceptor.rs +++ b/crates/connector_proxy/src/interceptors/network_tunnel_capture_interceptor.rs @@ -2,14 +2,14 @@ use crate::apis::{FlowCaptureOperation, InterceptorStream}; use crate::errors::{Error, Must}; use crate::libs::network_tunnel::NetworkTunnel; use crate::libs::protobuf::{decode_message, encode_message}; -use crate::libs::stream::{get_decoded_message, stream_all_bytes}; +use crate::libs::stream::get_decoded_message; use futures::{future, stream, StreamExt, TryStreamExt}; use protocol::capture::{ ApplyRequest, DiscoverRequest, PullRequest, SpecResponse, ValidateRequest, }; use serde_json::value::RawValue; -use tokio_util::io::StreamReader; +use tokio_util::io::{ReaderStream, StreamReader}; pub struct NetworkTunnelCaptureInterceptor {} @@ -81,7 +81,7 @@ impl NetworkTunnelCaptureInterceptor { } let first = stream::once(future::ready(encode_message(&request))); - let rest = stream_all_bytes(reader); + let rest = ReaderStream::new(reader); // We need to set explicit error type, see https://github.com/rust-lang/rust/issues/63502 Ok::<_, std::io::Error>(first.chain(rest)) diff --git a/crates/connector_proxy/src/interceptors/network_tunnel_materialize_interceptor.rs b/crates/connector_proxy/src/interceptors/network_tunnel_materialize_interceptor.rs index fe1ead6840..13ba1f3272 100644 --- a/crates/connector_proxy/src/interceptors/network_tunnel_materialize_interceptor.rs +++ b/crates/connector_proxy/src/interceptors/network_tunnel_materialize_interceptor.rs @@ -2,13 +2,13 @@ use crate::apis::{FlowMaterializeOperation, InterceptorStream}; use crate::errors::{Error, Must}; use crate::libs::network_tunnel::NetworkTunnel; use crate::libs::protobuf::{decode_message, encode_message}; -use crate::libs::stream::{get_decoded_message, stream_all_bytes}; +use crate::libs::stream::get_decoded_message; use futures::{future, stream, StreamExt, TryStreamExt}; use protocol::materialize::{ApplyRequest, SpecResponse, TransactionRequest, ValidateRequest}; use serde_json::value::RawValue; -use tokio_util::io::StreamReader; +use tokio_util::io::{ReaderStream, StreamReader}; pub struct NetworkTunnelMaterializeInterceptor {} @@ -63,7 +63,7 @@ impl NetworkTunnelMaterializeInterceptor { } } let first = stream::once(future::ready(encode_message(&request))); - let rest = stream_all_bytes(reader); + let rest = ReaderStream::new(reader); // We need to set explicit error type, see https://github.com/rust-lang/rust/issues/63502 Ok::<_, std::io::Error>(first.chain(rest)) diff --git a/crates/connector_proxy/src/libs/stream.rs b/crates/connector_proxy/src/libs/stream.rs index a38c553395..b9ef51ab75 100644 --- a/crates/connector_proxy/src/libs/stream.rs +++ b/crates/connector_proxy/src/libs/stream.rs @@ -2,111 +2,70 @@ use crate::libs::airbyte_catalog::Message; use crate::{apis::InterceptorStream, errors::create_custom_error}; use crate::errors::raise_err; -use bytes::{Buf, Bytes, BytesMut}; -use futures::{stream, StreamExt, TryStream, TryStreamExt}; -use serde_json::{Deserializer, Value}; -use tokio::io::{AsyncRead, AsyncReadExt}; +use bytelines::AsyncByteLines; +use bytes::Bytes; +use futures::{StreamExt, TryStream, TryStreamExt}; use tokio_util::io::StreamReader; use validator::Validate; +use super::airbyte_catalog::{Log, LogLevel, MessageType}; use super::protobuf::decode_message; -pub fn stream_all_bytes( - reader: R, +// Creates a stream of bytes of lines from the given stream +// This allows our other methods such as stream_airbyte_messages to operate +// on lines, simplifying their logic +// Note that we keep the newline character \n at the end of each byte +// to avoid inconsistencies among different streams with regards to the \n character +pub fn stream_lines( + in_stream: InterceptorStream, ) -> impl TryStream, Error = std::io::Error, Ok = bytes::Bytes> { - stream::try_unfold(reader, |mut r| async { - // consistent with the default capacity of ReaderStream. - // https://github.com/tokio-rs/tokio/blob/master/tokio-util/src/io/reader_stream.rs#L8 - let mut buf = BytesMut::with_capacity(4096); - match r.read_buf(&mut buf).await { - Ok(0) => Ok(None), - Ok(_) => Ok(Some((Bytes::from(buf), r))), - Err(e) => raise_err(&format!("error during streaming {:?}.", e)), - } - }) + AsyncByteLines::new(StreamReader::new(in_stream)) + .into_stream() + .map_ok(|mut vec| { + vec.push(b'\n'); + Bytes::from(vec) + }) } -/// Given a stream of bytes, try to deserialize them into Airbyte Messages. +/// Given a stream of lines, try to deserialize them into Airbyte Messages. /// This can be used when reading responses from the Airbyte connector, and will /// handle validation of messages as well as handling of AirbyteLogMessages. -/// Will ignore* messages that cannot be parsed to an AirbyteMessage. +/// Will ignore* lines that cannot be parsed to an AirbyteMessage. /// * See https://docs.airbyte.com/understanding-airbyte/airbyte-specification#the-airbyte-protocol pub fn stream_airbyte_responses( in_stream: InterceptorStream, ) -> impl TryStream, Ok = Message, Error = std::io::Error> { - stream::once(async { - let mut buf = BytesMut::new(); - let items = in_stream - .map(move |bytes| { - let b = bytes?; - buf.extend_from_slice(b.chunk()); - let chunk = buf.chunk(); - - // Deserialize to Value first, instead of Message, to avoid missing 'is_eof' signals in error. - let deserializer = Deserializer::from_slice(chunk); - let mut value_stream = deserializer.into_iter::(); - - // Turn Values into Messages and validate them - let values: Vec> = value_stream - .by_ref() - .map_while(|value| match value { - Ok(v) => Some(Ok(v)), - Err(e) => { - // we must stop as soon as we hit EOF to avoid - // progressing value_stream.byte_offset() so that we can - // safely drop the buffer up to byte_offset() and pick up the leftovers - // when working with the next bytes - if e.is_eof() { - return None; - } - - Some(raise_err(&format!( - "error in decoding JSON: {:?}, {:?}", - e, - std::str::from_utf8(chunk) - ))) - } + stream_lines(in_stream).try_filter_map(|line| async move { + let message: Message = match serde_json::from_slice(&line) { + Ok(m) => m, + Err(e) => { + // It is currently ambiguous for us whether Airbyte protocol specification + // mandates that there must be no plaintext or not, as such we handle all + // errors in parsing of stdout lines by logging the issue, but not failing + Message { + message_type: MessageType::Log, + connection_status: None, + state: None, + record: None, + spec: None, + catalog: None, + log: Some(Log { + level: LogLevel::Debug, + message: format!("Encountered error while trying to parse Airbyte Message: {:?} in line {:?}", e, line) }) - .map(|value| match value { - Ok(v) => { - let message: Message = match serde_json::from_value(v) { - Ok(m) => m, - // We ignore JSONs that are not Airbyte Messages according - // to the specification: - // https://docs.airbyte.com/understanding-airbyte/airbyte-specification#the-airbyte-protocol - Err(_) => return Ok(None), - }; - - message.validate().map_err(|e| { - create_custom_error(&format!("error in validating message {:?}", e)) - })?; - - tracing::debug!("read message:: {:?}", &message); - Ok(Some(message)) - } - Err(e) => Err(e), - }) - // Flipping the Option and Result to filter out the None values - .filter_map(|value| match value { - Ok(Some(v)) => Some(Ok(v)), - Ok(None) => None, - Err(e) => Some(Err(e)), - }) - .collect(); - - let byte_offset = value_stream.byte_offset(); - drop(buf.split_to(byte_offset)); + } + } + }; - Ok::<_, std::io::Error>(stream::iter(values)) - }) - .try_flatten(); + message + .validate() + .map_err(|e| create_custom_error(&format!("error in validating message {:?}", e)))?; - // We need to set explicit error type, see https://github.com/rust-lang/rust/issues/63502 - Ok::<_, std::io::Error>(items) + Ok(Some(message)) }) - .try_flatten() - // Handle logs here so we don't have to worry about them everywhere else .try_filter_map(|message| async { + // For AirbyteLogMessages, log them and then filter them out + // so that we don't have to handle them elsewhere if let Some(log) = message.log { log.log(); Ok(None) @@ -159,21 +118,43 @@ where #[cfg(test)] mod test { - use futures::future; + use std::{collections::HashMap, pin::Pin}; + + use bytes::BytesMut; + use futures::stream; + use protocol::{ + flow::EndpointType, + materialize::{validate_request, ValidateRequest}, + }; + use tokio_util::io::ReaderStream; - use crate::libs::airbyte_catalog::{ConnectionStatus, MessageType, Status}; + use crate::libs::{ + airbyte_catalog::{ConnectionStatus, MessageType, Status}, + protobuf::encode_message, + }; use super::*; + fn create_stream( + input: Vec, + ) -> Pin, Ok = T, Error = std::io::Error>>> { + Box::pin(stream::iter(input.into_iter().map(Ok::))) + } + #[tokio::test] - async fn test_stream_all_bytes() { - let input = "{\"test\": \"hello\"}".as_bytes(); - let stream = stream::once(future::ready(Ok::<_, std::io::Error>(input))); - let reader = StreamReader::new(stream); - let mut all_bytes = Box::pin(stream_all_bytes(reader)); - - let result = all_bytes.next().await.unwrap().unwrap(); - assert_eq!(result.chunk(), input); + async fn test_stream_lines() { + let line_0 = "{\"test\": \"hello\"}\n".as_bytes(); + let line_1 = "other\n".as_bytes(); + let line_2 = "{\"object\": {}}\n".as_bytes(); + let mut input = BytesMut::new(); + input.extend_from_slice(line_0); + input.extend_from_slice(line_1); + input.extend_from_slice(line_2); + let stream = create_stream(vec![Bytes::from(input)]); + let all_bytes = Box::pin(stream_lines(stream)); + + let result: Vec = all_bytes.try_collect::>().await.unwrap(); + assert_eq!(result, vec![line_0, line_1, line_2]); } #[tokio::test] @@ -191,16 +172,12 @@ mod test { }), }; let input = vec![ - Ok::<_, std::io::Error>( - "{\"type\": \"CONNECTION_STATUS\", \"connectionStatus\": {".as_bytes(), - ), - Ok::<_, std::io::Error>("\"status\": \"SUCCEEDED\",\"message\":\"test\"}}".as_bytes()), + Bytes::from("{\"type\": \"CONNECTION_STATUS\", \"connectionStatus\": {"), + Bytes::from("\"status\": \"SUCCEEDED\",\"message\":\"test\"}}"), ]; - let stream = stream::iter(input); - let reader = StreamReader::new(stream); + let stream = create_stream(input); - let byte_stream = Box::pin(stream_all_bytes(reader)); - let mut messages = Box::pin(stream_airbyte_responses(byte_stream)); + let mut messages = Box::pin(stream_airbyte_responses(stream)); let result = messages.next().await.unwrap().unwrap(); assert_eq!( @@ -224,16 +201,43 @@ mod test { }), }; let input = vec![ - Ok::<_, std::io::Error>( - "{}\n{\"type\": \"CONNECTION_STATUS\", \"connectionStatus\": {".as_bytes(), + Bytes::from("{}\n{\"type\": \"CONNECTION_STATUS\", \"connectionStatus\": {"), + Bytes::from("\"status\": \"SUCCEEDED\",\"message\":\"test\"}}"), + ]; + let stream = create_stream(input); + + let mut messages = Box::pin(stream_airbyte_responses(stream)); + + let result = messages.next().await.unwrap().unwrap(); + assert_eq!( + result.connection_status.unwrap(), + input_message.connection_status.unwrap() + ); + } + + #[tokio::test] + async fn test_stream_airbyte_responses_plaintext_mixed() { + let input_message = Message { + message_type: MessageType::ConnectionStatus, + log: None, + state: None, + record: None, + spec: None, + catalog: None, + connection_status: Some(ConnectionStatus { + status: Status::Succeeded, + message: Some("test".to_string()), + }), + }; + let input = vec![ + Bytes::from( + "I am plaintext!\n{\"type\": \"CONNECTION_STATUS\", \"connectionStatus\": {", ), - Ok::<_, std::io::Error>("\"status\": \"SUCCEEDED\",\"message\":\"test\"}}".as_bytes()), + Bytes::from("\"status\": \"SUCCEEDED\",\"message\":\"test\"}}"), ]; - let stream = stream::iter(input); - let reader = StreamReader::new(stream); + let stream = create_stream(input); - let byte_stream = Box::pin(stream_all_bytes(reader)); - let mut messages = Box::pin(stream_airbyte_responses(byte_stream)); + let mut messages = Box::pin(stream_airbyte_responses(stream)); let result = messages.next().await.unwrap().unwrap(); assert_eq!( @@ -241,4 +245,29 @@ mod test { input_message.connection_status.unwrap() ); } + + // TODO: this test fails for some reason + // it throws an UnexpectedEof error, but my manual tests have been okay + #[tokio::test] + async fn test_get_decoded_message() { + let msg = ValidateRequest { + materialization: "materialization".to_string(), + endpoint_type: EndpointType::AirbyteSource.into(), + endpoint_spec_json: "{}".to_string(), + bindings: vec![validate_request::Binding { + resource_spec_json: "{}".to_string(), + collection: None, + field_config_json: HashMap::new(), + }], + }; + + let msg_buf = encode_message(&msg).unwrap(); + + let stream = Box::pin(ReaderStream::new(std::io::Cursor::new(msg_buf))); + let result = get_decoded_message::(stream) + .await + .unwrap(); + + assert_eq!(result, msg); + } }