Skip to content

Commit

Permalink
feat: Expressify pattern of str.extract (pola-rs#13607)
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa authored and bushuyev committed Jan 11, 2024
1 parent 8b9c816 commit 930bfc7
Show file tree
Hide file tree
Showing 11 changed files with 169 additions and 32 deletions.
38 changes: 38 additions & 0 deletions crates/polars-core/src/chunked_array/ops/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,21 @@ where
ChunkedArray::from_chunk_iter(ca.name(), ca.downcast_iter().map(op))
}

/// Runs a fallible, possibly stateful kernel over every chunk of `ca`.
///
/// `op` receives each downcast chunk of `ca` in iteration order and returns either the
/// output array for that chunk or an error of type `E`. The per-chunk results are handed to
/// `ChunkedArray::try_from_chunk_iter`, together with the name of `ca`, to build the output.
#[inline]
pub fn try_unary_mut_with_options<T, V, F, Arr, E>(
    ca: &ChunkedArray<T>,
    op: F,
) -> Result<ChunkedArray<V>, E>
where
    T: PolarsDataType,
    V: PolarsDataType<Array = Arr>,
    Arr: Array + StaticArray,
    F: FnMut(&T::Array) -> Result<Arr, E>,
    E: Error,
{
    // Lazily map the kernel over the chunks; nothing is evaluated until the
    // constructor below drives the iterator.
    let mapped_chunks = ca.downcast_iter().map(op);
    ChunkedArray::try_from_chunk_iter(ca.name(), mapped_chunks)
}

#[inline]
pub fn binary_elementwise<T, U, V, F>(
lhs: &ChunkedArray<T>,
Expand Down Expand Up @@ -381,6 +396,29 @@ where
ChunkedArray::from_chunk_iter(name, iter)
}

/// Runs a fallible, possibly stateful binary kernel over the paired chunks of two arrays.
///
/// Both inputs are first passed through `align_chunks_binary`, then their downcast chunks
/// are zipped and fed pairwise to `op`. The per-pair results are handed to
/// `ChunkedArray::try_from_chunk_iter` under the given `name`.
#[inline]
pub fn try_binary_mut_with_options<T, U, V, F, Arr, E>(
    lhs: &ChunkedArray<T>,
    rhs: &ChunkedArray<U>,
    mut op: F,
    name: &str,
) -> Result<ChunkedArray<V>, E>
where
    T: PolarsDataType,
    U: PolarsDataType,
    V: PolarsDataType<Array = Arr>,
    Arr: Array,
    F: FnMut(&T::Array, &U::Array) -> Result<Arr, E>,
    E: Error,
{
    let (left, right) = align_chunks_binary(lhs, rhs);
    // Pair up the chunks of both sides so the kernel sees them side by side.
    let chunk_pairs = left.downcast_iter().zip(right.downcast_iter());
    ChunkedArray::try_from_chunk_iter(name, chunk_pairs.map(|(l, r)| op(l, r)))
}

/// Applies a kernel that produces `Array` types.
pub fn binary<T, U, V, F, Arr>(
lhs: &ChunkedArray<T>,
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-lazy/src/tests/predicate_queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ fn test_issue_2472() -> PolarsResult<()> {
let extract = col("group")
.cast(DataType::String)
.str()
.extract(r"(\d+-){4}(\w+)-", 2)
.extract(lit(r"(\d+-){4}(\w+)-"), 2)
.cast(DataType::Int32)
.alias("age");
let predicate = col("age").is_in(lit(Series::new("", [2i32])));
Expand Down
88 changes: 81 additions & 7 deletions crates/polars-ops/src/chunked_array/strings/extract.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::iter::zip;

#[cfg(feature = "extract_groups")]
use arrow::array::{Array, StructArray};
use arrow::array::{MutableArray, MutableUtf8Array, Utf8Array};
use polars_core::export::regex::Regex;
use polars_core::prelude::arity::{try_binary_mut_with_options, try_unary_mut_with_options};

use super::*;

Expand Down Expand Up @@ -72,7 +75,7 @@ pub(super) fn extract_groups(
Series::try_from((ca.name(), chunks))
}

fn extract_group_array(
fn extract_group_reg_lit(
arr: &Utf8Array<i64>,
reg: &Regex,
group_index: usize,
Expand All @@ -95,14 +98,85 @@ fn extract_group_array(
Ok(builder.into())
}

/// Extracts capture group `group_index` from the single string `s`, once per pattern in `pat`.
///
/// Produces one output row per row of `pat`:
/// * a null pattern yields null;
/// * a pattern that does not match `s` yields null;
/// * a matching pattern yields the text of the requested group, or null when
///   `CaptureLocations::get` reports no span for `group_index`.
///
/// An error is returned as soon as a pattern fails to compile.
fn extract_group_array_lit(
    s: &str,
    pat: &Utf8Array<i64>,
    group_index: usize,
) -> PolarsResult<Utf8Array<i64>> {
    let mut builder = MutableUtf8Array::<i64>::with_capacity(pat.len());

    for opt_pat in pat {
        match opt_pat {
            Some(pat) => {
                // Every row may carry a different pattern, so it is compiled per row.
                let reg = Regex::new(pat)?;
                let mut locs = reg.capture_locations();
                match reg.captures_read(&mut locs, s) {
                    Some(_) => {
                        let group = locs.get(group_index).map(|(start, stop)| &s[start..stop]);
                        builder.push(group);
                    },
                    // The pattern did not match the string.
                    None => {
                        builder.push_null();
                    },
                }
            },
            // The pattern itself is null.
            None => {
                builder.push_null();
            },
        }
    }

    Ok(builder.into())
}

/// Extracts capture group `group_index` row by row, pairing each string in `arr` with the
/// pattern at the same position in `pat`.
///
/// A row is null when the string or the pattern is null, when the pattern does not match,
/// or when `CaptureLocations::get` reports no span for `group_index`. The builder is sized
/// from `arr.len()`, and rows are produced by zipping the two arrays.
///
/// An error is returned as soon as a pattern fails to compile.
fn extract_group_binary(
    arr: &Utf8Array<i64>,
    pat: &Utf8Array<i64>,
    group_index: usize,
) -> PolarsResult<Utf8Array<i64>> {
    let mut builder = MutableUtf8Array::<i64>::with_capacity(arr.len());

    for (opt_s, opt_pat) in zip(arr, pat) {
        if let (Some(s), Some(pat)) = (opt_s, opt_pat) {
            // Every row may carry a different pattern, so it is compiled per row.
            let reg = Regex::new(pat)?;
            let mut locs = reg.capture_locations();
            if reg.captures_read(&mut locs, s).is_some() {
                let group = locs.get(group_index).map(|(start, stop)| &s[start..stop]);
                builder.push(group);
                continue;
            }
        }

        // Reached when either side is null or the pattern did not match.
        builder.push_null();
    }

    Ok(builder.into())
}

/// Extracts capture group `group_index` from `ca` using the pattern(s) in `pat`.
///
/// The `match` on `(ca.len(), pat.len())` selects one of four paths:
/// * `pat` has length 1: the single pattern is compiled once and applied to every chunk of
///   `ca`; a null pattern gives an all-null result of length `ca.len()`. This arm is listed
///   first, so it also handles the case where both inputs have length 1.
/// * `ca` has length 1: the single string is matched against every pattern in `pat`; a null
///   string gives an all-null result of length `pat.len()`.
/// * equal lengths: strings and patterns are paired row by row.
/// * anything else: a `ComputeError` is raised.
///
/// NOTE(review): this rendering interleaves the pre-change signature/body (`pat: &str`, the
/// single `Regex::new(pat)?` + `try_from_chunk_iter` path) with the new `pat: &StringChunked`
/// version — the duplicated `pat` parameter below is diff residue, not valid Rust.
pub(super) fn extract_group(
ca: &StringChunked,
pat: &str,
pat: &StringChunked,
group_index: usize,
) -> PolarsResult<StringChunked> {
let reg = Regex::new(pat)?;
let chunks = ca
.downcast_iter()
.map(|array| extract_group_array(array, &reg, group_index));
ChunkedArray::try_from_chunk_iter(ca.name(), chunks)
match (ca.len(), pat.len()) {
(_, 1) => {
if let Some(pat) = pat.get(0) {
// Literal pattern: compile once, reuse for every chunk.
let reg = Regex::new(pat)?;
try_unary_mut_with_options(ca, |arr| extract_group_reg_lit(arr, &reg, group_index))
} else {
Ok(StringChunked::full_null(ca.name(), ca.len()))
}
},
(1, _) => {
// NOTE(review): the non-null path below is driven by `pat`, so the result is built
// from `pat` (try_unary_mut_with_options names its output after its input array),
// while the null path uses `ca.name()` — confirm the naming difference is intended.
if let Some(s) = ca.get(0) {
try_unary_mut_with_options(pat, |pat| extract_group_array_lit(s, pat, group_index))
} else {
Ok(StringChunked::full_null(ca.name(), pat.len()))
}
},
(len_ca, len_pat) if len_ca == len_pat => try_binary_mut_with_options(
ca,
pat,
|ca, pat| extract_group_binary(ca, pat, group_index),
ca.name(),
),
_ => {
polars_bail!(ComputeError: "ca(len: {}) and pat(len: {}) should either broadcast or have the same length", ca.len(), pat.len())
},
}
}
2 changes: 1 addition & 1 deletion crates/polars-ops/src/chunked_array/strings/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ pub trait StringNameSpaceImpl: AsString {
}

/// Extract the nth capture group from pattern.
fn extract(&self, pat: &str, group_index: usize) -> PolarsResult<StringChunked> {
fn extract(&self, pat: &StringChunked, group_index: usize) -> PolarsResult<StringChunked> {
let ca = self.as_string();
super::extract::extract_group(ca, pat, group_index)
}
Expand Down
22 changes: 8 additions & 14 deletions crates/polars-plan/src/dsl/function_expr/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,7 @@ pub enum StringFunction {
CountMatches(bool),
EndsWith,
Explode,
Extract {
pat: String,
group_index: usize,
},
Extract(usize),
ExtractAll,
#[cfg(feature = "extract_groups")]
ExtractGroups {
Expand Down Expand Up @@ -133,7 +130,7 @@ impl StringFunction {
CountMatches(_) => mapper.with_dtype(DataType::UInt32),
EndsWith | StartsWith => mapper.with_dtype(DataType::Boolean),
Explode => mapper.with_same_dtype(),
Extract { .. } => mapper.with_same_dtype(),
Extract(_) => mapper.with_same_dtype(),
ExtractAll => mapper.with_dtype(DataType::List(Box::new(DataType::String))),
#[cfg(feature = "extract_groups")]
ExtractGroups { dtype, .. } => mapper.with_dtype(dtype.clone()),
Expand Down Expand Up @@ -201,7 +198,7 @@ impl Display for StringFunction {
Contains { .. } => "contains",
CountMatches(_) => "count_matches",
EndsWith { .. } => "ends_with",
Extract { .. } => "extract",
Extract(_) => "extract",
#[cfg(feature = "concat_str")]
ConcatHorizontal(_) => "concat_horizontal",
#[cfg(feature = "concat_str")]
Expand Down Expand Up @@ -287,9 +284,7 @@ impl From<StringFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
},
EndsWith { .. } => map_as_slice!(strings::ends_with),
StartsWith { .. } => map_as_slice!(strings::starts_with),
Extract { pat, group_index } => {
map!(strings::extract, &pat, group_index)
},
Extract(group_index) => map_as_slice!(strings::extract, group_index),
ExtractAll => {
map_as_slice!(strings::extract_all)
},
Expand Down Expand Up @@ -457,11 +452,10 @@ pub(super) fn starts_with(s: &[Series]) -> PolarsResult<Series> {
}

/// Extract a regex pattern from a string value.
pub(super) fn extract(s: &Series, pat: &str, group_index: usize) -> PolarsResult<Series> {
let pat = pat.to_string();

let ca = s.str()?;
ca.extract(&pat, group_index).map(|ca| ca.into_series())
pub(super) fn extract(s: &[Series], group_index: usize) -> PolarsResult<Series> {
let ca = s[0].str()?;
let pat = s[1].str()?;
ca.extract(pat, group_index).map(|ca| ca.into_series())
}

#[cfg(feature = "extract_groups")]
Expand Down
11 changes: 7 additions & 4 deletions crates/polars-plan/src/dsl/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,13 @@ impl StringNameSpace {
}

/// Extract a regex pattern from a string value. If `group_index` is out of bounds, null is returned.
pub fn extract(self, pat: &str, group_index: usize) -> Expr {
let pat = pat.to_string();
self.0
.map_private(StringFunction::Extract { pat, group_index }.into())
pub fn extract(self, pat: Expr, group_index: usize) -> Expr {
self.0.map_many_private(
StringFunction::Extract(group_index).into(),
&[pat],
false,
true,
)
}

#[cfg(feature = "extract_groups")]
Expand Down
2 changes: 1 addition & 1 deletion docs/src/rust/user-guide/expressions/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let out = df
.clone()
.lazy()
.select([col("a").str().extract(r"candidate=(\w+)", 1)])
.select([col("a").str().extract(lit(r"candidate=(\w+)"), 1)])
.collect()?;
println!("{}", &out);
// --8<-- [end:extract]
Expand Down
3 changes: 2 additions & 1 deletion py-polars/polars/expr/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -1384,7 +1384,7 @@ def encode(self, encoding: TransferEncoding) -> Expr:
msg = f"`encoding` must be one of {{'hex', 'base64'}}, got {encoding!r}"
raise ValueError(msg)

def extract(self, pattern: str, group_index: int = 1) -> Expr:
def extract(self, pattern: IntoExprColumn, group_index: int = 1) -> Expr:
r"""
Extract the target capture group from provided patterns.
Expand Down Expand Up @@ -1464,6 +1464,7 @@ def extract(self, pattern: str, group_index: int = 1) -> Expr:
│ ronaldo ┆ polars ┆ null │
└───────────┴─────────┴───────┘
"""
pattern = parse_as_expression(pattern, str_as_lit=True)
return wrap_expr(self._pyexpr.str_extract(pattern, group_index))

def extract_all(self, pattern: str | Expr) -> Expr:
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/series/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,7 @@ def json_path_match(self, json_path: str) -> Series:
]
"""

def extract(self, pattern: str, group_index: int = 1) -> Series:
def extract(self, pattern: IntoExprColumn, group_index: int = 1) -> Series:
r"""
Extract the target capture group from provided patterns.
Expand Down
8 changes: 6 additions & 2 deletions py-polars/src/expr/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,12 @@ impl PyExpr {
.into()
}

fn str_extract(&self, pat: &str, group_index: usize) -> Self {
self.inner.clone().str().extract(pat, group_index).into()
fn str_extract(&self, pat: Self, group_index: usize) -> Self {
self.inner
.clone()
.str()
.extract(pat.inner, group_index)
.into()
}

fn str_extract_all(&self, pat: Self) -> Self {
Expand Down
23 changes: 23 additions & 0 deletions py-polars/tests/unit/namespaces/string/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,29 @@ def test_extract_regex() -> None:
assert_series_equal(s.str.extract(r"candidate=(\w+)", 1), expected)


def test_extract() -> None:
    """Check `str.extract` with a column pattern, a literal pattern, and a literal subject."""
    frame = pl.DataFrame(
        {
            "s": ["aron123", "12butler", "charly*", "~david", None],
            "pat": [r"^([a-zA-Z]+)", r"^(\d+)", None, "^(da)", r"(.*)"],
        }
    )
    subject = pl.col("s")
    patterns = pl.col("pat")

    result = frame.select(
        # Row-wise: each string is paired with the pattern in the same row.
        all_expr=subject.str.extract(patterns, 1),
        # A plain string pattern applied to every row.
        str_expr=subject.str.extract("^([a-zA-Z]+)", 1),
        # A literal subject matched against every pattern; group_index is left at its default.
        pat_expr=pl.lit("aron123").str.extract(patterns),
    )

    assert_frame_equal(
        result,
        pl.DataFrame(
            {
                "all_expr": ["aron", "12", None, None, None],
                "str_expr": ["aron", None, "charly", None, None],
                "pat_expr": ["aron", None, None, None, "aron123"],
            }
        ),
    )


def test_extract_binary() -> None:
df = pl.DataFrame({"foo": ["aron", "butler", "charly", "david"]})
out = df.filter(pl.col("foo").str.extract("^(a)", 1) == "a").to_series()
Expand Down

0 comments on commit 930bfc7

Please sign in to comment.