Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(planner): support independent right join #7634

Merged
merged 11 commits into from Sep 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions src/query/service/src/pipelines/processors/mod.rs
Expand Up @@ -54,6 +54,7 @@ pub use transforms::KeyU64HashTable;
pub use transforms::KeyU8HashTable;
pub use transforms::MarkJoinCompactor;
pub use transforms::ProjectionTransform;
pub use transforms::RightJoinCompactor;
pub use transforms::SerializerHashTable;
pub use transforms::SinkBuildHashTable;
pub use transforms::SortMergeCompactor;
Expand Down
Expand Up @@ -12,17 +12,36 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;

use common_exception::Result;
use common_functions::scalars::FunctionFactory;
use parking_lot::RwLock;

use crate::evaluator::EvalNode;
use crate::evaluator::Evaluator;
use crate::pipelines::processors::transforms::hash_join::row::RowPtr;
use crate::pipelines::processors::transforms::hash_join::MarkJoinDesc;
use crate::sql::executor::HashJoin;
use crate::sql::executor::PhysicalScalar;
use crate::sql::plans::JoinType;

/// Bookkeeping shared by the right-join code paths of the hash join.
///
/// Each field is wrapped in its own `RwLock`, so it can be read or updated
/// through a shared reference to the descriptor.
pub struct RightJoinDesc {
/// Build-side rows that have been matched by rows from the probe side.
pub(crate) build_indexes: RwLock<Vec<RowPtr>>,
/// For each recorded build-side row, the number of probe-side rows it matched.
/// Rows found to be unmatched are recorded with a count of zero.
pub(crate) row_state: RwLock<HashMap<RowPtr, usize>>,
}

impl RightJoinDesc {
    /// Builds a descriptor in which no build-side row has been recorded yet:
    /// both the matched-row list and the per-row match counters start empty.
    pub fn create() -> Self {
        let build_indexes = RwLock::new(vec![]);
        let row_state = RwLock::new(HashMap::new());
        Self {
            build_indexes,
            row_state,
        }
    }
}

pub struct HashJoinDesc {
pub(crate) build_keys: Vec<EvalNode>,
pub(crate) probe_keys: Vec<EvalNode>,
Expand All @@ -31,6 +50,7 @@ pub struct HashJoinDesc {
pub(crate) marker_join_desc: MarkJoinDesc,
/// Whether the join is derived from a correlated subquery.
pub(crate) from_correlated_subquery: bool,
pub(crate) right_join_desc: RightJoinDesc,
}

impl HashJoinDesc {
Expand All @@ -50,6 +70,7 @@ impl HashJoinDesc {
marker_index: join.marker_index,
},
from_correlated_subquery: join.from_correlated_subquery,
right_join_desc: RightJoinDesc::create(),
})
}

Expand Down
Expand Up @@ -44,4 +44,7 @@ pub trait HashJoinState: Send + Sync {

/// Get mark join results
fn mark_join_blocks(&self) -> Result<Vec<DataBlock>>;

/// Get right join results
fn right_join_blocks(&self, blocks: &[DataBlock]) -> Result<Vec<DataBlock>>;
}
Expand Up @@ -13,10 +13,12 @@
// limitations under the License.

use std::borrow::BorrowMut;
use std::collections::HashSet;
use std::fmt::Debug;
use std::sync::Arc;
use std::sync::Mutex;

use common_arrow::arrow::bitmap::Bitmap;
use common_arrow::arrow::bitmap::MutableBitmap;
use common_base::base::tokio::sync::Notify;
use common_datablocks::DataBlock;
Expand All @@ -35,7 +37,9 @@ use common_datavalues::DataField;
use common_datavalues::DataSchema;
use common_datavalues::DataSchemaRef;
use common_datavalues::DataSchemaRefExt;
use common_datavalues::DataType;
use common_datavalues::DataTypeImpl;
use common_datavalues::DataValue;
use common_datavalues::NullableType;
use common_exception::ErrorCode;
use common_exception::Result;
Expand Down Expand Up @@ -108,7 +112,7 @@ pub enum HashTable {
KeyU512HashTable(KeyU512HashTable),
}

#[derive(Clone, Copy, Eq, PartialEq, Debug)]
#[derive(Clone, Copy, Eq, PartialEq, Debug, Hash)]
pub enum MarkerKind {
True,
False,
Expand All @@ -130,6 +134,7 @@ pub struct JoinHashTable {
pub(crate) row_space: RowSpace,
pub(crate) hash_join_desc: HashJoinDesc,
pub(crate) row_ptrs: RwLock<Vec<RowPtr>>,
pub(crate) probe_schema: DataSchemaRef,
finished_notify: Arc<Notify>,
}

Expand All @@ -138,6 +143,7 @@ impl JoinHashTable {
ctx: Arc<QueryContext>,
build_keys: &[PhysicalScalar],
build_schema: DataSchemaRef,
probe_schema: DataSchemaRef,
hash_join_desc: HashJoinDesc,
) -> Result<Arc<JoinHashTable>> {
let hash_key_types: Vec<DataTypeImpl> =
Expand All @@ -151,6 +157,7 @@ impl JoinHashTable {
hash_method: HashMethodSerializer::default(),
}),
build_schema,
probe_schema,
hash_join_desc,
)?),
HashMethodKind::KeysU8(hash_method) => Arc::new(JoinHashTable::try_create(
Expand All @@ -160,6 +167,7 @@ impl JoinHashTable {
hash_method,
}),
build_schema,
probe_schema,
hash_join_desc,
)?),
HashMethodKind::KeysU16(hash_method) => Arc::new(JoinHashTable::try_create(
Expand All @@ -169,6 +177,7 @@ impl JoinHashTable {
hash_method,
}),
build_schema,
probe_schema,
hash_join_desc,
)?),
HashMethodKind::KeysU32(hash_method) => Arc::new(JoinHashTable::try_create(
Expand All @@ -178,6 +187,7 @@ impl JoinHashTable {
hash_method,
}),
build_schema,
probe_schema,
hash_join_desc,
)?),
HashMethodKind::KeysU64(hash_method) => Arc::new(JoinHashTable::try_create(
Expand All @@ -187,6 +197,7 @@ impl JoinHashTable {
hash_method,
}),
build_schema,
probe_schema,
hash_join_desc,
)?),
HashMethodKind::KeysU128(hash_method) => Arc::new(JoinHashTable::try_create(
Expand All @@ -196,6 +207,7 @@ impl JoinHashTable {
hash_method,
}),
build_schema,
probe_schema,
hash_join_desc,
)?),
HashMethodKind::KeysU256(hash_method) => Arc::new(JoinHashTable::try_create(
Expand All @@ -205,6 +217,7 @@ impl JoinHashTable {
hash_method,
}),
build_schema,
probe_schema,
hash_join_desc,
)?),
HashMethodKind::KeysU512(hash_method) => Arc::new(JoinHashTable::try_create(
Expand All @@ -214,6 +227,7 @@ impl JoinHashTable {
hash_method,
}),
build_schema,
probe_schema,
hash_join_desc,
)?),
})
Expand All @@ -223,6 +237,7 @@ impl JoinHashTable {
ctx: Arc<QueryContext>,
hash_table: HashTable,
mut build_data_schema: DataSchemaRef,
mut probe_data_schema: DataSchemaRef,
hash_join_desc: HashJoinDesc,
) -> Result<Self> {
if hash_join_desc.join_type == JoinType::Left
Expand All @@ -237,6 +252,16 @@ impl JoinHashTable {
}
build_data_schema = DataSchemaRefExt::create(nullable_field);
};
if hash_join_desc.join_type == JoinType::Right {
let mut nullable_field = Vec::with_capacity(probe_data_schema.fields().len());
for field in probe_data_schema.fields().iter() {
nullable_field.push(DataField::new_nullable(
field.name(),
field.data_type().clone(),
));
}
probe_data_schema = DataSchemaRefExt::create(nullable_field);
}
Ok(Self {
row_space: RowSpace::new(build_data_schema),
ref_count: Mutex::new(0),
Expand All @@ -245,6 +270,7 @@ impl JoinHashTable {
ctx,
hash_table: RwLock::new(hash_table),
row_ptrs: RwLock::new(vec![]),
probe_schema: probe_data_schema,
finished_notify: Arc::new(Notify::new()),
})
}
Expand Down Expand Up @@ -407,6 +433,31 @@ impl JoinHashTable {
}
}
}

/// Returns the build-side rows that were never matched by a probe-side row.
///
/// In a right join every build-side row must appear at least once in the
/// output, so the rows returned here still have to be emitted (padded with
/// NULL probe columns by the caller).
///
/// Side effect: every unmatched row is registered in
/// `right_join_desc.row_state` with a match count of `0`, unless an entry
/// for that row already exists.
///
/// The result is always `Ok`; the `Result` wrapper is kept for interface
/// compatibility with callers.
fn find_unmatched_build_indexes(&self) -> Result<Vec<RowPtr>> {
    let right_join_desc = &self.hash_join_desc.right_join_desc;

    // Set view over the matched rows so the membership test below is O(1).
    let build_indexes = right_join_desc.build_indexes.read();
    let matched_rows: HashSet<&RowPtr> = build_indexes.iter().collect();

    // Bind both guards to locals before looping: the chunk guard no longer
    // lives as a temporary in the `for` scrutinee, and the `row_state` write
    // lock is taken once instead of once per unmatched row. Acquisition order
    // (build_indexes -> chunks -> row_state) is the same as before.
    let chunks = self.row_space.chunks.read().unwrap();
    let mut row_state = right_join_desc.row_state.write();

    let mut unmatched_build_indexes = vec![];
    for (chunk_index, chunk) in chunks.iter().enumerate() {
        for row_index in 0..chunk.num_rows() {
            // `RowPtr` stores its coordinates as `u32`.
            let row_ptr = RowPtr {
                chunk_index: chunk_index as u32,
                row_index: row_index as u32,
                marker: None,
            };
            if !matched_rows.contains(&row_ptr) {
                row_state.entry(row_ptr).or_insert(0_usize);
                unmatched_build_indexes.push(row_ptr);
            }
        }
    }
    Ok(unmatched_build_indexes)
}
}

#[async_trait::async_trait]
Expand All @@ -429,7 +480,8 @@ impl HashJoinState for JoinHashTable {
| JoinType::Anti
| JoinType::Left
| Mark
| JoinType::Single => self.probe_join(input, probe_state),
| JoinType::Single
| JoinType::Right => self.probe_join(input, probe_state),
JoinType::Cross => self.probe_cross_join(input, probe_state),
_ => unimplemented!("{} is unimplemented", self.hash_join_desc.join_type),
}
Expand Down Expand Up @@ -661,4 +713,73 @@ impl HashJoinState for JoinHashTable {
let build_block = self.row_space.gather(&row_ptrs)?;
Ok(vec![self.merge_eq_block(&marker_block, &build_block)?])
}

/// Produces the final right-join output from the already joined `blocks`.
///
/// Build-side rows that never matched a probe row are gathered, padded with
/// NULL constant columns for every field of `probe_schema`, and concatenated
/// after `blocks`. When `other_predicate` is set, it is evaluated on the
/// concatenated block and its result is applied as validity to the
/// probe-side columns; redundant build-side rows may then be filtered out.
///
/// Returns the input blocks unchanged (cloned) when there is nothing
/// unmatched and no `other_predicate`; otherwise a single-element vector.
fn right_join_blocks(&self, blocks: &[DataBlock]) -> Result<Vec<DataBlock>> {
let unmatched_build_indexes = self.find_unmatched_build_indexes()?;
// Fast path: every build row matched and there is no extra predicate to apply.
if unmatched_build_indexes.is_empty() && self.hash_join_desc.other_predicate.is_none() {
return Ok(blocks.to_vec());
}

let unmatched_build_block = self.row_space.gather(&unmatched_build_indexes)?;
// Create null block for unmatched rows in probe side
let null_probe_block = DataBlock::create(
self.probe_schema.clone(),
self.probe_schema
.fields()
.iter()
.map(|df| {
df.data_type()
.clone()
.create_constant_column(&DataValue::Null, unmatched_build_indexes.len())
})
.collect::<Result<Vec<_>>>()?,
);
// Pair the unmatched build rows with the NULL probe rows, then append them
// after the blocks passed in by the caller.
let mut merged_block = self.merge_eq_block(&unmatched_build_block, &null_probe_block)?;
merged_block = DataBlock::concat_blocks(&[blocks, &[merged_block]].concat())?;

// Without an extra predicate the concatenated block is the final result.
if self.hash_join_desc.other_predicate.is_none() {
return Ok(vec![merged_block]);
}

let (bm, all_true, all_false) = self.get_other_filters(
&merged_block,
self.hash_join_desc.other_predicate.as_ref().unwrap(),
)?;

// The predicate holds for every row: nothing has to be nulled out.
if all_true {
return Ok(vec![merged_block]);
}

// Use the returned bitmap when there is one; otherwise the predicate was
// reported as all-false, so build an all-zero bitmap of the block's length.
let validity = match (bm, all_false) {
(Some(b), _) => b,
(None, true) => Bitmap::new_zeroed(merged_block.num_rows()),
// must be one of above
_ => unreachable!(),
};
// Split the merged block back into its probe and build parts and apply
// `validity` to the probe columns only.
// NOTE(review): this assumes `merge_eq_block` lays out the probe columns
// first, followed by the build columns — confirm against `merge_eq_block`.
let probe_column_len = self.probe_schema.fields().len();
let probe_columns = merged_block.columns()[0..probe_column_len]
.iter()
.map(|c| Self::set_validity(c, &validity))
.collect::<Result<Vec<_>>>()?;
let probe_block = DataBlock::create(self.probe_schema.clone(), probe_columns);
let build_block = DataBlock::create(
self.row_space.data_schema.clone(),
merged_block.columns()[probe_column_len..].to_vec(),
);
merged_block = self.merge_eq_block(&build_block, &probe_block)?;

// If build_indexes holds more entries than the build table has rows, the
// redundant build-side rows have to be filtered out.
let mut build_indexes = self.hash_join_desc.right_join_desc.build_indexes.write();
let mut row_state = self.hash_join_desc.right_join_desc.row_state.write();
build_indexes.extend(&unmatched_build_indexes);
if build_indexes.len() > self.row_space.rows_number() {
// NOTE(review): `.right().unwrap()` panics if `into_mut()` does not yield
// the mutable variant — presumably when the bitmap's storage is still
// shared (e.g. with the columns produced by `set_validity`); confirm.
let mut bm = validity.into_mut().right().unwrap();
// `filter_rows_for_right_join` receives the bitmap mutably; presumably it
// clears the bits of the redundant rows — confirm against the helper.
Self::filter_rows_for_right_join(&mut bm, &build_indexes, &mut row_state);
let predicate = BooleanColumn::from_arrow_data(bm.into()).arc();
let filtered_block = DataBlock::filter_block(merged_block, &predicate)?;
return Ok(vec![filtered_block]);
}

Ok(vec![merged_block])
}
}