diff --git a/rust/cubestore/cubestore/src/streaming/kafka.rs b/rust/cubestore/cubestore/src/streaming/kafka.rs
index 6f745f84ad37c..22a6b9babb392 100644
--- a/rust/cubestore/cubestore/src/streaming/kafka.rs
+++ b/rust/cubestore/cubestore/src/streaming/kafka.rs
@@ -2,12 +2,30 @@ use crate::config::injection::DIService;
 use crate::config::ConfigObj;
 use crate::metastore::table::StreamOffset;
 use crate::metastore::Column;
+use crate::sql::MySqlDialectWithBackTicks;
 use crate::streaming::{parse_json_payload_and_key, StreamingSource};
 use crate::table::{Row, TableValue};
 use crate::CubeError;
+use arrow::array::ArrayRef;
+use arrow::record_batch::RecordBatch;
+use arrow::{datatypes::Schema, datatypes::SchemaRef};
 use async_std::stream;
 use async_trait::async_trait;
+use datafusion::catalog::TableReference;
 use datafusion::cube_ext;
+use datafusion::datasource::datasource::Statistics;
+use datafusion::datasource::TableProvider;
+use datafusion::error::DataFusionError;
+use datafusion::logical_plan::Expr as DExpr;
+use datafusion::logical_plan::LogicalPlan;
+use datafusion::physical_plan::empty::EmptyExec;
+use datafusion::physical_plan::memory::MemoryExec;
+use datafusion::physical_plan::udaf::AggregateUDF;
+use datafusion::physical_plan::udf::ScalarUDF;
+use datafusion::physical_plan::{collect, ExecutionPlan};
+use datafusion::prelude::ExecutionContext;
+use datafusion::sql::parser::Statement as DFStatement;
+use datafusion::sql::planner::{ContextProvider, SqlToRel};
 use futures::Stream;
 use json::object::Object;
 use json::JsonValue;
@@ -16,6 +34,10 @@ use rdkafka::error::KafkaResult;
 use rdkafka::message::BorrowedMessage;
 use rdkafka::util::Timeout;
 use rdkafka::{ClientConfig, Message, Offset, TopicPartitionList};
+use sqlparser::ast::{Query, SetExpr, Statement};
+use sqlparser::parser::Parser;
+use sqlparser::tokenizer::Tokenizer;
+use std::any::Any;
 use std::pin::Pin;
 use std::sync::Arc;
 use std::time::Duration;
@@ -28,18 +50,18 @@ pub struct KafkaStreamingSource {
     password: Option<String>,
     topic: String,
     host: String,
-    // TODO Support parsing of filters and applying before insert
-    _select_statement: Option<String>,
     offset: Option<StreamOffset>,
     partition: usize,
     kafka_client: Arc<dyn KafkaClientService>,
     use_ssl: bool,
+    post_filter: Option<Arc<dyn ExecutionPlan>>,
 }
 
 impl KafkaStreamingSource {
     pub fn new(
         table_id: u64,
         unique_key_columns: Vec<Column>,
+        columns: Vec<Column>,
         user: Option<String>,
         password: Option<String>,
         topic: String,
@@ -50,6 +72,26 @@ impl KafkaStreamingSource {
         kafka_client: Arc<dyn KafkaClientService>,
         use_ssl: bool,
     ) -> Self {
+        let post_filter = if let Some(select_statement) = select_statement {
+            let planner = KafkaFilterPlanner {
+                topic: topic.clone(),
+                columns,
+            };
+            match planner.parse_select_statement(select_statement.clone()) {
+                Ok(p) => p,
+                Err(e) => {
+                    // FIXME: maybe we should stop execution here
+                    log::error!(
+                        "Error while parsing `select_statement`: {}. Select statement ignored",
+                        e
+                    );
+                    None
+                }
+            }
+        } else {
+            None
+        };
+
         KafkaStreamingSource {
             table_id,
             unique_key_columns,
@@ -57,11 +99,129 @@ impl KafkaStreamingSource {
             password,
             topic,
             host,
-            _select_statement: select_statement,
             offset,
             partition,
             kafka_client,
             use_ssl,
+            post_filter,
+        }
+    }
+}
+
+pub struct KafkaFilterPlanner {
+    topic: String,
+    columns: Vec<Column>,
+}
+
+impl KafkaFilterPlanner {
+    fn parse_select_statement(
+        &self,
+        select_statement: String,
+    ) -> Result<Option<Arc<dyn ExecutionPlan>>, CubeError> {
+        let dialect = &MySqlDialectWithBackTicks {};
+        let mut tokenizer = Tokenizer::new(dialect, &select_statement);
+        let tokens = tokenizer.tokenize().unwrap();
+        let statement = Parser::new(tokens, dialect).parse_statement()?;
+
+        match &statement {
+            Statement::Query(box Query {
+                body: SetExpr::Select(s),
+                ..
+            }) => {
+                if s.selection.is_none() {
+                    return Ok(None);
+                }
+                let provider = TopicTableProvider::new(self.topic.clone(), &self.columns);
+                let query_planner = SqlToRel::new(&provider);
+                let logical_plan =
+                    query_planner.statement_to_plan(&DFStatement::Statement(statement.clone()))?;
+                let physical_filter = Self::make_physical_filter(&logical_plan)?;
+                Ok(physical_filter)
+            }
+            _ => Err(CubeError::user(format!(
+                "{} is not a valid select query",
+                select_statement
+            ))),
+        }
+    }
+
+    /// Only Projection > Filter > TableScan plans are allowed
+    fn make_physical_filter(
+        plan: &LogicalPlan,
+    ) -> Result<Option<Arc<dyn ExecutionPlan>>, CubeError> {
+        match plan {
+            LogicalPlan::Projection { input, .. } => match input.as_ref() {
+                filter_plan @ LogicalPlan::Filter { input, .. } => match input.as_ref() {
+                    LogicalPlan::TableScan { .. } => {
+                        let plan_ctx = Arc::new(ExecutionContext::new());
+                        let phys_plan = plan_ctx.create_physical_plan(&filter_plan)?;
+                        Ok(Some(phys_plan))
+                    }
+                    _ => Ok(None),
+                },
+                _ => Ok(None),
+            },
+            _ => Ok(None),
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TopicTableProvider {
+    topic: String,
+    schema: SchemaRef,
+}
+
+impl TopicTableProvider {
+    pub fn new(topic: String, columns: &Vec<Column>) -> Self {
+        let schema = Arc::new(Schema::new(
+            columns.iter().map(|c| c.clone().into()).collect::<Vec<_>>(),
+        ));
+        Self { topic, schema }
+    }
+}
+
+impl ContextProvider for TopicTableProvider {
+    fn get_table_provider(&self, name: TableReference) -> Option<Arc<dyn TableProvider>> {
+        match name {
+            TableReference::Bare { table } if table == self.topic => Some(Arc::new(self.clone())),
+            _ => None,
+        }
+    }
+
+    fn get_function_meta(&self, _name: &str) -> Option<Arc<ScalarUDF>> {
+        None
+    }
+
+    fn get_aggregate_meta(&self, _name: &str) -> Option<Arc<AggregateUDF>> {
+        None
+    }
+}
+
+impl TableProvider for TopicTableProvider {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+
+    fn scan(
+        &self,
+        _projection: &Option<Vec<usize>>,
+        _batch_size: usize,
+        _filters: &[DExpr],
+        _limit: Option<usize>,
+    ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
+        Ok(Arc::new(EmptyExec::new(false, self.schema())))
+    }
+
+    fn statistics(&self) -> Statistics {
+        Statistics {
+            num_rows: None,
+            total_byte_size: None,
+            column_statistics: None,
         }
     }
 }
@@ -297,6 +457,25 @@ impl StreamingSource for KafkaStreamingSource {
         Ok(stream)
     }
 
+    async fn apply_post_filter(&self, data: Vec<ArrayRef>) -> Result<Vec<ArrayRef>, CubeError> {
+        if let Some(post_filter) = &self.post_filter {
+            let schema = post_filter.children()[0].schema();
+            let batch = RecordBatch::try_new(schema.clone(), data)?;
+            let input = Arc::new(MemoryExec::try_new(&[vec![batch]], schema.clone(), None)?);
+            let filter = post_filter.with_new_children(vec![input])?;
+            let mut out_batches = collect(filter).await?;
+            let res = if out_batches.len() == 1 {
+                out_batches.pop().unwrap()
+            } else {
+                RecordBatch::concat(&schema, &out_batches)?
+            };
+
+            Ok(res.columns().to_vec())
+        } else {
+            Ok(data)
+        }
+    }
+
     fn validate_table_location(&self) -> Result<(), CubeError> {
         // TODO
         // self.query(None)?;
diff --git a/rust/cubestore/cubestore/src/streaming/mod.rs b/rust/cubestore/cubestore/src/streaming/mod.rs
index fd7815903b4c0..7437427d8ddfa 100644
--- a/rust/cubestore/cubestore/src/streaming/mod.rs
+++ b/rust/cubestore/cubestore/src/streaming/mod.rs
@@ -14,6 +14,7 @@ use crate::table::{Row, TableValue, TimestampValue};
 use crate::util::decimal::Decimal;
 use crate::CubeError;
 use arrow::array::ArrayBuilder;
+use arrow::array::ArrayRef;
 use async_trait::async_trait;
 use chrono::Utc;
 use datafusion::cube_ext::ordfloat::OrdF64;
@@ -133,6 +134,7 @@ impl StreamingServiceImpl {
                 table.get_row().unique_key_columns()
                     .ok_or_else(|| CubeError::internal(format!("Streaming table without unique key columns: {:?}", table)))?
                     .into_iter().cloned().collect(),
+                table.get_row().get_columns().clone(),
                 user.clone(),
                 password.clone(),
                 table_name,
@@ -307,11 +309,14 @@ impl StreamingService for StreamingServiceImpl {
             .meta_store
             .create_replay_handle(table.get_id(), location_index, seq_pointer)
             .await?;
+        let data = finish(builders);
+        let data = source.apply_post_filter(data).await?;
+
         let new_chunks = self
             .chunk_store
             .partition_data(
                 table.get_id(),
-                finish(builders),
+                data,
                 table.get_row().get_columns().as_slice(),
                 true,
             )
@@ -355,6 +360,10 @@ pub trait StreamingSource: Send + Sync {
         initial_seq_value: Option<i64>,
     ) -> Result<Pin<Box<dyn Stream<Item = Result<Row, CubeError>> + Send>>, CubeError>;
 
+    async fn apply_post_filter(&self, data: Vec<ArrayRef>) -> Result<Vec<ArrayRef>, CubeError> {
+        Ok(data)
+    }
+
     fn validate_table_location(&self) -> Result<(), CubeError>;
 }
 
@@ -1067,7 +1076,8 @@ mod tests {
                        serde_json::json!({ "MESSAGEID": i.to_string() }).to_string()
                    )),
                    payload: Some(
-                        serde_json::json!({ "ANONYMOUSID": j.to_string() }).to_string(),
+                        serde_json::json!({ "ANONYMOUSID": j.to_string(), "TIMESTAMP": i })
+                            .to_string(),
                    ),
                    offset: i,
                });
@@ -1368,4 +1378,67 @@ mod tests {
        })
        .await;
    }
+
+    #[tokio::test]
+    async fn streaming_filter_kafka() {
+        Config::test("streaming_filter_kafka").update_config(|mut c| {
+            c.stream_replay_check_interval_secs = 1;
+            c.compaction_in_memory_chunks_max_lifetime_threshold = 8;
+            c.partition_split_threshold = 1000000;
+            c.max_partition_split_threshold = 1000000;
+            c.compaction_chunks_count_threshold = 100;
+            c.compaction_chunks_total_size_threshold = 100000;
+            c.stale_stream_timeout = 1;
+            c.wal_split_threshold = 16384;
+            c
+        }).start_with_injector_override(async move |injector| {
+            injector.register_typed::<dyn KafkaClientService, _, _, _>(async move |_| {
+                Arc::new(MockKafkaClient)
+            })
+            .await
+        }, async move |services| {
+            let service = services.sql_service;
+
+            let _ = service.exec_query("CREATE SCHEMA test").await.unwrap();
+
+            service
+                .exec_query("CREATE SOURCE OR UPDATE kafka AS 'kafka' VALUES (user = 'foo', password = 'bar', host = 'localhost:9092')")
+                .await
+                .unwrap();
+
+            let listener = services.cluster.job_result_listener();
+
+            let _ = service
+                .exec_query("CREATE TABLE test.events_by_type_1 (`ANONYMOUSID` text, `MESSAGEID` text, `TIMESTAMP` int) \
+                            WITH (stream_offset = 'earliest', select_statement = 'SELECT * FROM EVENTS_BY_TYPE WHERE TIMESTAMP >= 10000 and TIMESTAMP < 14000') \
+                            unique key (`ANONYMOUSID`, `MESSAGEID`, `TIMESTAMP`) INDEX by_anonymous(`ANONYMOUSID`, `TIMESTAMP`) location 'stream://kafka/EVENTS_BY_TYPE/0', 'stream://kafka/EVENTS_BY_TYPE/1'")
+                .await
+                .unwrap();
+
+            let wait = listener.wait_for_job_results(vec![
+                (RowKey::Table(TableId::Tables, 1), JobType::TableImportCSV("stream://kafka/EVENTS_BY_TYPE/0".to_string())),
+                (RowKey::Table(TableId::Tables, 1), JobType::TableImportCSV("stream://kafka/EVENTS_BY_TYPE/1".to_string())),
+            ]);
+            timeout(Duration::from_secs(15), wait).await.unwrap().unwrap();
+
+            let result = service
+                .exec_query("SELECT COUNT(*) FROM test.events_by_type_1")
+                .await
+                .unwrap();
+            assert_eq!(result.get_rows(), &vec![Row::new(vec![TableValue::Int(8000)])]);
+
+            let result = service
+                .exec_query("SELECT min(TIMESTAMP) FROM test.events_by_type_1")
+                .await
+                .unwrap();
+            assert_eq!(result.get_rows(), &vec![Row::new(vec![TableValue::Int(10000)])]);
+
+            let result = service
+                .exec_query("SELECT max(TIMESTAMP) FROM test.events_by_type_1")
+                .await
+                .unwrap();
+            assert_eq!(result.get_rows(), &vec![Row::new(vec![TableValue::Int(13999)])]);
+        })
+        .await;
+    }
 }