Skip to content

Commit

Permalink
Address feedback and add 'try_xxxxxxx' functions
Browse files Browse the repository at this point in the history
  • Loading branch information
junli1026 committed Feb 2, 2022
1 parent 60744da commit a3d8a0f
Show file tree
Hide file tree
Showing 8 changed files with 330 additions and 66 deletions.
6 changes: 3 additions & 3 deletions common/functions/src/scalars/function2_adapter.rs
Expand Up @@ -113,9 +113,9 @@ impl Function2 for Function2Adapter {

let col = self.eval(&columns, input_rows)?;

// Some functions always returns a nullable column because invalid input(not Nulls) may return a null output.
// For example, inet_aton("helloworld") will return Null.
// In this case, we need to merge the validity.
// The 'try' series of functions always return Null when the attempted operation fails.
// For example, try_inet_aton("helloworld") returns Null because "helloworld" cannot be parsed as a valid IP address.
// The same can happen with other 'try' functions, so we need to merge the validity.
if col.is_nullable() {
let (_, bitmap) = col.validity();
validity = match validity {
Expand Down
121 changes: 98 additions & 23 deletions common/functions/src/scalars/others/inet_aton.rs
Expand Up @@ -13,15 +13,19 @@
// limitations under the License.

use std::fmt;
use std::net::Ipv4Addr;
use std::str;
use std::sync::Arc;

use common_datavalues2::remove_nullable;
use common_datavalues2::type_primitive;
use common_datavalues2::ColumnBuilder;
use common_datavalues2::ColumnRef;
use common_datavalues2::ColumnViewer;
use common_datavalues2::ColumnsWithField;
use common_datavalues2::DataTypePtr;
use common_datavalues2::DataValue;
use common_datavalues2::NullType;
use common_datavalues2::NullableColumnBuilder;
use common_datavalues2::NullableType;
use common_datavalues2::TypeID;
Expand All @@ -32,15 +36,20 @@ use crate::scalars::function_factory::FunctionFeatures;
use crate::scalars::Function2;
use crate::scalars::Function2Description;

#[derive(Clone)]
/// `InetAtonFunctionImpl` instantiated with `SUPPRESS_PARSE_ERROR = true`:
/// an input string that fails to parse as an IPv4 address yields Null for
/// that row instead of an error.
#[doc(alias = "TryIPv4StringToNumFunction")]
pub type TryInetAtonFunction = InetAtonFunctionImpl<true>;

#[doc(alias = "IPv4StringToNumFunction")]
pub struct InetAtonFunction {
pub type InetAtonFunction = InetAtonFunctionImpl<false>;

/// Shared implementation behind `InetAtonFunction` and `TryInetAtonFunction`.
///
/// `SUPPRESS_PARSE_ERROR` selects what happens when an input string is not a
/// valid IPv4 address: `true` produces Null for that row, `false` returns a
/// `StrParseError`.
#[derive(Clone)]
pub struct InetAtonFunctionImpl<const SUPPRESS_PARSE_ERROR: bool> {
// Name this function instance was created with; returned by `name()` and
// upper-cased by the `Display` implementation.
display_name: String,
}

impl InetAtonFunction {
impl<const SUPPRESS_PARSE_ERROR: bool> InetAtonFunctionImpl<SUPPRESS_PARSE_ERROR> {
pub fn try_create(display_name: &str) -> Result<Box<dyn Function2>> {
Ok(Box::new(InetAtonFunction {
Ok(Box::new(InetAtonFunctionImpl::<SUPPRESS_PARSE_ERROR> {
display_name: display_name.to_string(),
}))
}
Expand All @@ -51,7 +60,7 @@ impl InetAtonFunction {
}
}

impl Function2 for InetAtonFunction {
impl<const SUPPRESS_PARSE_ERROR: bool> Function2 for InetAtonFunctionImpl<SUPPRESS_PARSE_ERROR> {
/// Returns the display name this function instance was created with.
fn name(&self) -> &str {
    self.display_name.as_str()
}
Expand All @@ -67,35 +76,101 @@ impl Function2 for InetAtonFunction {
))),
}?;

// For invalid input, the function should return null. So the return type must be nullable.
Ok(Arc::new(NullableType::create(output_type)))
if SUPPRESS_PARSE_ERROR {
// For invalid input, we suppress parse error and return null. So the return type must be nullable.
return Ok(Arc::new(NullableType::create(output_type)));
}

if args[0].is_nullable() {
Ok(Arc::new(NullableType::create(output_type)))
} else {
Ok(output_type)
}
}

fn eval(&self, columns: &ColumnsWithField, input_rows: usize) -> Result<ColumnRef> {
let mut builder: NullableColumnBuilder<u32> =
NullableColumnBuilder::with_capacity(input_rows);
if SUPPRESS_PARSE_ERROR {
let viewer = ColumnViewer::<Vec<u8>>::create(columns[0].column())?;

let mut builder = NullableColumnBuilder::<u32>::with_capacity(input_rows);

for i in 0..input_rows {
// We skip the per-row null check because passthrough_null is true for this function.
// Whether this is the right trade-off is debatable: the address parsing is not SIMD-optimized, so it is unclear how much we gain from avoiding the extra branch.
// Consider the case of 1000 rows of which 999 are Nulls: every row is still parsed.
let input = viewer.value(i);
let addr_str = String::from_utf8_lossy(input);
match addr_str.parse::<Ipv4Addr>() {
Ok(addr) => {
let addr_binary: u32 = u32::from(addr);
builder.append(addr_binary, viewer.valid_at(i));
}
Err(_) => builder.append_null(),
}
}
return Ok(builder.build(input_rows));
}

if columns[0].column().data_type_id() == TypeID::Null {
return NullType::arc().create_constant_column(&DataValue::Null, input_rows);
}

let viewer = ColumnViewer::<Vec<u8>>::create(columns[0].column())?;

for i in 0..input_rows {
// We skip the null check because the function has passthrough_null is true.
// This is arguably correct because the address parsing is not optimized by SIMD, not quite sure how much we can gain from skipping branch prediction.
// Think about the case if we have 1000 rows and 999 are Nulls.
let input = viewer.value(i);
let parsed_addr = String::from_utf8_lossy(input).parse::<std::net::Ipv4Addr>();

match parsed_addr {
Ok(addr) => {
let addr_binary: u32 = u32::from(addr);
builder.append(addr_binary, viewer.valid_at(i));
if columns[0].column().is_nullable() || columns[0].column().data_type_id() == TypeID::Null {
let mut builder = NullableColumnBuilder::<u32>::with_capacity(input_rows);
for i in 0..input_rows {
if viewer.null_at(i) {
builder.append_null();
continue;
}

let input = viewer.value(i);
let addr_str = String::from_utf8_lossy(input);
match addr_str.parse::<Ipv4Addr>() {
Ok(addr) => {
let addr_binary: u32 = u32::from(addr);
builder.append(addr_binary, viewer.valid_at(i));
}
Err(err) => {
return Err(ErrorCode::StrParseError(format!(
"Failed to parse '{}' into a IPV4 address, {}",
addr_str, err
)));
}
}
}
Ok(builder.build(input_rows))
} else {
let mut builder = ColumnBuilder::<u32>::with_capacity(input_rows);
for i in 0..input_rows {
let input = viewer.value(i);
let addr_str = String::from_utf8_lossy(input);
match addr_str.parse::<Ipv4Addr>() {
Ok(addr) => {
let addr_binary: u32 = u32::from(addr);
builder.append(addr_binary);
}
Err(err) => {
return Err(ErrorCode::StrParseError(format!(
"Failed to parse '{}' into a IPV4 address, {}",
addr_str, err
)));
}
}
Err(_) => builder.append_null(),
}
Ok(builder.build(input_rows))
}
Ok(builder.build(input_rows))
}

/// Returns `SUPPRESS_PARSE_ERROR`: `true` for the 'try' variant, `false` for
/// the variant that reports parse failures as errors.
fn passthrough_null(&self) -> bool {
// When SUPPRESS_PARSE_ERROR is false, a Null row would cause a parse error.
// In that case eval has to check for Nulls itself and skip parsing them,
// so passthrough_null must be false.
// When SUPPRESS_PARSE_ERROR is true, a failed parse simply becomes Null.
SUPPRESS_PARSE_ERROR
}
}

impl fmt::Display for InetAtonFunction {
impl<const SUPPRESS_PARSE_ERROR: bool> fmt::Display for InetAtonFunctionImpl<SUPPRESS_PARSE_ERROR> {
/// Writes the function's display name in upper case.
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
    let upper_name = self.display_name.to_uppercase();
    f.write_str(&upper_name)
}
Expand Down
108 changes: 71 additions & 37 deletions common/functions/src/scalars/others/inet_ntoa.rs
Expand Up @@ -17,8 +17,8 @@ use std::net::Ipv4Addr;
use std::str;
use std::sync::Arc;

use common_datavalues2::remove_nullable;
use common_datavalues2::types::type_string::StringType;
use common_datavalues2::ColumnBuilder;
use common_datavalues2::ColumnRef;
use common_datavalues2::ColumnViewer;
use common_datavalues2::ColumnsWithField;
Expand All @@ -27,6 +27,7 @@ use common_datavalues2::Float64Type;
use common_datavalues2::NullableColumnBuilder;
use common_datavalues2::NullableType;
use common_datavalues2::TypeID;
use common_datavalues2::UInt32Type;
use common_exception::ErrorCode;
use common_exception::Result;

Expand All @@ -38,15 +39,20 @@ use crate::scalars::Function2;
use crate::scalars::Function2Description;
use crate::scalars::ParsingMode;

/// `InetNtoaFunctionImpl` instantiated with `SUPPRESS_CAST_ERROR = true`:
/// the return type is always nullable, and an input that cannot be turned
/// into an IPv4 address yields Null instead of an error.
///
/// The doc alias matches the `TryIPv4NumToString` name this function is
/// registered under (number -> string, not string -> number).
#[doc(alias = "TryIPv4NumToStringFunction")]
pub type TryInetNtoaFunction = InetNtoaFunctionImpl<true>;

/// `InetNtoaFunctionImpl` instantiated with `SUPPRESS_CAST_ERROR = false`:
/// a cast failure on the input is reported as an error rather than being
/// turned into Null.
///
/// The doc alias matches the `IPv4NumToString` name this function is
/// registered under (number -> string, not string -> number).
#[doc(alias = "IPv4NumToStringFunction")]
pub type InetNtoaFunction = InetNtoaFunctionImpl<false>;

#[derive(Clone)]
#[doc(alias = "IPv4NumToStringFunction")]
pub struct InetNtoaFunction {
pub struct InetNtoaFunctionImpl<const SUPPRESS_CAST_ERROR: bool> {
display_name: String,
}

impl InetNtoaFunction {
impl<const SUPPRESS_CAST_ERROR: bool> InetNtoaFunctionImpl<SUPPRESS_CAST_ERROR> {
pub fn try_create(display_name: &str) -> Result<Box<dyn Function2>> {
Ok(Box::new(InetNtoaFunction {
Ok(Box::new(InetNtoaFunctionImpl::<SUPPRESS_CAST_ERROR> {
display_name: display_name.to_string(),
}))
}
Expand All @@ -57,14 +63,13 @@ impl InetNtoaFunction {
}
}

impl Function2 for InetNtoaFunction {
impl<const SUPPRESS_CAST_ERROR: bool> Function2 for InetNtoaFunctionImpl<SUPPRESS_CAST_ERROR> {
/// Returns the display name this function instance was created with.
fn name(&self) -> &str {
    self.display_name.as_str()
}

fn return_type(&self, args: &[&DataTypePtr]) -> Result<DataTypePtr> {
let input_type = remove_nullable(args[0]);

let input_type = args[0];
let output_type = if input_type.data_type_id().is_numeric()
|| input_type.data_type_id().is_string()
|| input_type.data_type_id() == TypeID::Null
Expand All @@ -77,43 +82,72 @@ impl Function2 for InetNtoaFunction {
)))
}?;

// For invalid input, the function should return null. So the return type must be nullable.
Ok(Arc::new(NullableType::create(output_type)))
if SUPPRESS_CAST_ERROR {
// For invalid input, the function should return null. So the return type must be nullable.
Ok(Arc::new(NullableType::create(output_type)))
} else {
Ok(output_type)
}
}

fn eval(&self, columns: &ColumnsWithField, input_rows: usize) -> Result<ColumnRef> {
let cast_to: DataTypePtr = Arc::new(NullableType::create(Float64Type::arc()));
let cast_option = CastOptions {
// we allow cast failure
exception_mode: ExceptionMode::Zero,
parsing_mode: ParsingMode::Partial,
};
let column = cast_with_type(
columns[0].column(),
columns[0].data_type(),
&cast_to,
&cast_option,
)?;
let viewer = ColumnViewer::<f64>::create(&column)?;

let mut builder: NullableColumnBuilder<Vec<u8>> =
NullableColumnBuilder::with_capacity(input_rows);

for i in 0..input_rows {
let val = viewer.value(i);

if val.is_nan() || val < 0.0 || val > u32::MAX as f64 {
builder.append_null();
} else {
let addr_str = Ipv4Addr::from((val as u32).to_be_bytes()).to_string();
builder.append(addr_str.as_bytes(), viewer.valid_at(i));
if SUPPRESS_CAST_ERROR {
let cast_to: DataTypePtr = Arc::new(NullableType::create(Float64Type::arc()));

let cast_options = CastOptions {
// we allow cast failure
exception_mode: ExceptionMode::Zero,
parsing_mode: ParsingMode::Partial,
};
let column = cast_with_type(
columns[0].column(),
columns[0].data_type(),
&cast_to,
&cast_options,
)?;
let viewer = ColumnViewer::<f64>::create(&column)?;

let mut builder: NullableColumnBuilder<Vec<u8>> =
NullableColumnBuilder::with_capacity(input_rows);

for i in 0..input_rows {
let val = viewer.value(i);

if val.is_nan() || val < 0.0 || val > u32::MAX as f64 {
builder.append_null();
} else {
let addr_str = Ipv4Addr::from((val as u32).to_be_bytes()).to_string();
builder.append(addr_str.as_bytes(), viewer.valid_at(i));
}
}
Ok(builder.build(input_rows))
} else {
let cast_to: DataTypePtr = UInt32Type::arc();
let cast_options = CastOptions {
exception_mode: ExceptionMode::Throw,
parsing_mode: ParsingMode::Strict,
};
let column = cast_with_type(
columns[0].column(),
columns[0].data_type(),
&cast_to,
&cast_options,
)?;
let viewer = ColumnViewer::<u32>::create(&column)?;

let mut builder = ColumnBuilder::<Vec<u8>>::with_capacity(input_rows);

for i in 0..input_rows {
let val = viewer.value(i);
let addr_str = Ipv4Addr::from((val).to_be_bytes()).to_string();
builder.append(addr_str.as_bytes());
}
Ok(builder.build(input_rows))
}
Ok(builder.build(input_rows))
}
}

impl fmt::Display for InetNtoaFunction {
impl<const SUPPRESS_CAST_ERROR: bool> fmt::Display for InetNtoaFunctionImpl<SUPPRESS_CAST_ERROR> {
/// Writes the function's display name in upper case.
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
    let upper_name = self.display_name.to_uppercase();
    f.write_str(&upper_name)
}
Expand Down
2 changes: 2 additions & 0 deletions common/functions/src/scalars/others/mod.rs
Expand Up @@ -20,6 +20,8 @@ mod running_difference_function;

pub use ignore::IgnoreFunction;
pub use inet_aton::InetAtonFunction;
pub use inet_aton::TryInetAtonFunction;
pub use inet_ntoa::InetNtoaFunction;
pub use inet_ntoa::TryInetNtoaFunction;
pub use other::OtherFunction;
pub use running_difference_function::RunningDifferenceFunction;
16 changes: 15 additions & 1 deletion common/functions/src/scalars/others/other.rs
Expand Up @@ -13,7 +13,9 @@
// limitations under the License.

use super::inet_aton::InetAtonFunction;
use super::inet_aton::TryInetAtonFunction;
use super::inet_ntoa::InetNtoaFunction;
use super::inet_ntoa::TryInetNtoaFunction;
use super::running_difference_function::RunningDifferenceFunction;
use super::IgnoreFunction;
use crate::scalars::Function2Factory;
Expand All @@ -23,11 +25,23 @@ pub struct OtherFunction {}

impl OtherFunction {
    /// Registers the miscellaneous scalar functions with `factory`.
    ///
    /// Every IP-address function is registered under two names that share the
    /// same description: a snake_case name and a CamelCase name. Each name is
    /// registered exactly once.
    pub fn register(factory: &mut Function2Factory) {
        factory.register("runningDifference", RunningDifferenceFunction::desc());
        factory.register("ignore", IgnoreFunction::desc());

        // inet_aton
        factory.register("inet_aton", InetAtonFunction::desc());
        factory.register("IPv4StringToNum", InetAtonFunction::desc());

        // try_inet_aton
        factory.register("try_inet_aton", TryInetAtonFunction::desc());
        factory.register("TryIPv4StringToNum", TryInetAtonFunction::desc());

        // inet_ntoa
        factory.register("inet_ntoa", InetNtoaFunction::desc());
        factory.register("IPv4NumToString", InetNtoaFunction::desc());

        // try_inet_ntoa
        factory.register("try_inet_ntoa", TryInetNtoaFunction::desc());
        factory.register("TryIPv4NumToString", TryInetNtoaFunction::desc());
    }
}

0 comments on commit a3d8a0f

Please sign in to comment.