Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements Series.split_into/3 #873

Merged
merged 8 commits into from
Mar 5, 2024
Merged
8 changes: 8 additions & 0 deletions lib/explorer/backend/lazy_series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ defmodule Explorer.Backend.LazySeries do
downcase: 1,
substring: 3,
split: 2,
split_into: 3,
json_decode: 2,
json_path_match: 2,
# Float round
Expand Down Expand Up @@ -1053,6 +1054,13 @@ defmodule Explorer.Backend.LazySeries do
Backend.Series.new(data, {:list, :string})
end

@impl true
def split_into(series, by, fields) do
data = new(:split_into, [lazy_series!(series), by, fields], :string)

Backend.Series.new(data, {:struct, Enum.map(fields, &{&1, :string})})
end

@impl true
def round(series, decimals) when is_integer(decimals) and decimals >= 0 do
data = new(:round, [lazy_series!(series), decimals], {:f, 64})
Expand Down
1 change: 1 addition & 0 deletions lib/explorer/backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ defmodule Explorer.Backend.Series do
@callback rstrip(s, String.t() | nil) :: s
@callback substring(s, integer(), non_neg_integer() | nil) :: s
@callback split(s, String.t()) :: s
@callback split_into(s, String.t(), list(String.t() | atom())) :: s
@callback json_decode(s, dtype()) :: s
@callback json_path_match(s, String.t()) :: s

Expand Down
1 change: 1 addition & 0 deletions lib/explorer/polars_backend/expression.ex
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ defmodule Explorer.PolarsBackend.Expression do
upcase: 1,
substring: 3,
split: 2,
split_into: 3,
json_decode: 2,
json_path_match: 2,

Expand Down
1 change: 1 addition & 0 deletions lib/explorer/polars_backend/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ defmodule Explorer.PolarsBackend.Native do
def s_cut(_s, _bins, _labels, _break_point_label, _category_label), do: err()
def s_substring(_s, _offset, _length), do: err()
def s_split(_s, _by), do: err()
def s_split_into(_s, _by, _num_fields), do: err()

def s_qcut(_s, _quantiles, _labels, _break_point_label, _category_label),
do: err()
Expand Down
4 changes: 4 additions & 0 deletions lib/explorer/polars_backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,10 @@ defmodule Explorer.PolarsBackend.Series do
def split(series, by),
do: Shared.apply_series(series, :s_split, [by])

@impl true
def split_into(series, by, fields),
do: Shared.apply_series(series, :s_split_into, [by, fields])

# Float round
@impl true
def round(series, decimals),
Expand Down
26 changes: 26 additions & 0 deletions lib/explorer/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5651,6 +5651,32 @@ defmodule Explorer.Series do
def split(%Series{dtype: dtype}, _by),
do: dtype_error("split/2", dtype, [:string])

@doc """
Split a string Series into a struct of string `fields`.

The length of the field names list determines how many times the
string will be split at most. If the string cannot be split into that
many separate strings, null values will be provided for the
remaining fields.

## Examples

iex> s = Series.from_list(["Smith, John", "Jones, Jane"])
iex> Series.split_into(s, ", ", ["Last Name", "First Name"])
#Explorer.Series<
Polars[2]
struct[2] [%{"First Name" => "John", "Last Name" => "Smith"}, %{"First Name" => "Jane", "Last Name" => "Jones"}]
>

"""
@doc type: :string_wise
@spec split_into(Series.t(), String.t(), list(String.t() | atom())) :: Series.t()
def split_into(%Series{dtype: :string} = series, by, [_ | _] = fields) when is_binary(by),
do: apply_series(series, :split_into, [by, Enum.map(fields, &to_string/1)])

def split_into(%Series{dtype: dtype}, by, [_ | _]) when is_binary(by),
do: dtype_error("split_into/3", dtype, [:string])

# Float

@doc """
Expand Down
12 changes: 12 additions & 0 deletions native/explorer/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1115,6 +1115,18 @@ pub fn expr_json_path_match(expr: ExExpr, json_path: &str) -> ExExpr {
ExExpr::new(expr)
}

#[rustler::nif]
pub fn expr_split_into(expr: ExExpr, by: String, names: Vec<String>) -> ExExpr {
let expr = expr
.clone_inner()
.str()
.splitn(by.lit(), names.len())
.struct_()
.rename_fields(names);

ExExpr::new(expr)
}

#[rustler::nif]
pub fn expr_struct(ex_exprs: Vec<ExExpr>) -> ExExpr {
let exprs = ex_exprs.iter().map(|e| e.clone_inner()).collect();
Expand Down
2 changes: 2 additions & 0 deletions native/explorer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ rustler::init!(
expr_substring,
expr_replace,
expr_json_path_match,
expr_split_into,
// float round expressions
expr_round,
expr_floor,
Expand Down Expand Up @@ -456,6 +457,7 @@ rustler::init!(
s_strip,
s_substring,
s_split,
s_split_into,
s_subtract,
s_sum,
s_tail,
Expand Down
19 changes: 19 additions & 0 deletions native/explorer/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1567,6 +1567,25 @@ pub fn s_split(s1: ExSeries, by: &str) -> Result<ExSeries, ExplorerError> {
Ok(ExSeries::new(s2))
}

#[rustler::nif(schedule = "DirtyCpu")]
ryancurtin marked this conversation as resolved.
Show resolved Hide resolved
pub fn s_split_into(s1: ExSeries, by: &str, names: Vec<String>) -> Result<ExSeries, ExplorerError> {
let s2 = s1
.clone_inner()
.into_frame()
.lazy()
.select([col(s1.name())
.str()
.splitn(by.lit(), names.len())
.struct_()
.rename_fields(names)
.alias(s1.name())])
.collect()?
.column(s1.name())?
.clone();

Ok(ExSeries::new(s2))
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_round(s: ExSeries, decimals: u32) -> Result<ExSeries, ExplorerError> {
Ok(ExSeries::new(s.round(decimals)?.into_series()))
Expand Down
33 changes: 25 additions & 8 deletions test/explorer/data_frame_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1982,6 +1982,27 @@ defmodule Explorer.DataFrameTest do
member?: [true, false]
}
end

test "splits a string column into multiple new columns" do
new_column_names = ["Last Name", "First Name"]
df = DF.new(%{names: ["Smith, John", "Jones, Jane"]})

df =
DF.mutate_with(df, fn ldf ->
%{names: Series.split_into(ldf[:names], ", ", new_column_names)}
end)
|> DF.unnest(:names)

assert DF.dtypes(df) == %{
"Last Name" => :string,
"First Name" => :string
}

assert DF.to_columns(df) == %{
"Last Name" => ["Smith", "Jones"],
"First Name" => ["John", "Jane"]
}
end
end

describe "sort_by/3" do
Expand Down Expand Up @@ -2618,17 +2639,13 @@ defmodule Explorer.DataFrameTest do
end

test "mixing nulls, signed, unsigned integers, and floats" do
df1 =
DF.new(x: Series.from_list([1, 2], dtype: :u16), y: Series.from_list(["a", "b"]))
df1 = DF.new(x: Series.from_list([1, 2], dtype: :u16), y: Series.from_list(["a", "b"]))

df2 =
DF.new(x: Series.from_list([3.0, 4.0], dtype: :f32), y: Series.from_list(["c", "d"]))
df2 = DF.new(x: Series.from_list([3.0, 4.0], dtype: :f32), y: Series.from_list(["c", "d"]))

df3 =
DF.new(x: [nil, nil], y: [nil, nil])
df3 = DF.new(x: [nil, nil], y: [nil, nil])

df4 =
DF.new(x: Series.from_list([5, 6], dtype: :s16), y: Series.from_list(["e", "f"]))
df4 = DF.new(x: Series.from_list([5, 6], dtype: :s16), y: Series.from_list(["e", "f"]))

df5 = DF.concat_rows([df1, df2, df3, df4])

Expand Down
22 changes: 22 additions & 0 deletions test/explorer/series_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -5320,6 +5320,28 @@ defmodule Explorer.SeriesTest do
end
end

describe "split_into" do
test "split_into/3 produces the correct number of fields in a struct" do
series = Series.from_list(["Smith, John", "Jones, Jane"])
split_series = series |> Series.split_into(", ", ["Last Name", "First Name"])

assert Series.to_list(split_series) == [
%{"First Name" => "John", "Last Name" => "Smith"},
%{"First Name" => "Jane", "Last Name" => "Jones"}
]
end

test "split_into/3 produces a nil field when string cannot be split for every field" do
series = Series.from_list(["Smith-John", "Jones-Jane"])
split_series = series |> Series.split_into("-", ["Last Name", "First Name", "Middle Name"])

assert Series.to_list(split_series) == [
%{"First Name" => "John", "Last Name" => "Smith", "Middle Name" => nil},
%{"First Name" => "Jane", "Last Name" => "Jones", "Middle Name" => nil}
]
end
end

describe "strptime/2 and strftime/2" do
test "parse datetime from string" do
series = Series.from_list(["2023-01-05 12:34:56", "XYZ", nil])
Expand Down