Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change inner dtypes of structs to tuple lists #851

Merged
merged 4 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/explorer/backend/lazy_series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1090,7 +1090,7 @@ defmodule Explorer.Backend.LazySeries do

@impl true
def field(%Series{dtype: {:struct, inner_dtype}} = series, name) do
dtype = inner_dtype[name]
{^name, dtype} = List.keyfind!(inner_dtype, name, 0)
data = new(:field, [lazy_series!(series), name], dtype)

Backend.Series.new(data, dtype)
Expand Down
4 changes: 2 additions & 2 deletions lib/explorer/data_frame.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5612,8 +5612,8 @@ defmodule Explorer.DataFrame do
columns
|> Enum.zip(dtypes)
|> Enum.reduce({%{}, %{}}, fn {column, {:struct, inner_dtypes}}, {new_dtypes, new_names} ->
new_dtypes = Map.merge(new_dtypes, inner_dtypes)
new_names = Map.put(new_names, column, Map.keys(inner_dtypes))
new_dtypes = Map.merge(new_dtypes, Map.new(inner_dtypes))
new_names = Map.put(new_names, column, Enum.map(inner_dtypes, &elem(&1, 0)))

{new_dtypes, new_names}
end)
Expand Down
2 changes: 1 addition & 1 deletion lib/explorer/polars_backend/shared.ex
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ defmodule Explorer.PolarsBackend.Shared do
series =
for {column, values} <- Table.to_columns(list) do
column = to_string(column)
inner_type = Map.fetch!(fields, column)
{^column, inner_type} = List.keyfind!(fields, column, 0)
from_list(values, inner_type, column)
end

Expand Down
14 changes: 7 additions & 7 deletions lib/explorer/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ defmodule Explorer.Series do
* `:time` - Time type that unwraps to `Elixir.Time`
* `{:list, dtype}` - A recursive dtype that can store lists. Examples: `{:list, :boolean}` or
a nested list dtype like `{:list, {:list, :boolean}}`.
* `{:struct, %{key => dtype}}` - A recursive dtype that can store Arrow/Polars structs (not to be
* `{:struct, [{key, dtype}]}` - A recursive dtype that can store Arrow/Polars structs (not to be
confused with Elixir's struct). This type unwraps to Elixir maps with string keys. Examples:
`{:struct, %{"a" => :string}}` or a nested struct dtype like `{:struct, %{"a" => {:struct, %{"b" => :string}}}}`.
`{:struct, [{"a", :string}]}` or a nested struct dtype like `{:struct, [{"a", {:struct, [{"b", :string}]}}]}`.

When passing a dtype as argument, aliases are supported for convenience
and compatibility with the Elixir ecosystem:
Expand Down Expand Up @@ -151,7 +151,7 @@ defmodule Explorer.Series do
@type datetime_dtype :: {:datetime, time_unit}
@type duration_dtype :: {:duration, time_unit}
@type list_dtype :: {:list, dtype()}
@type struct_dtype :: {:struct, %{String.t() => dtype()}}
@type struct_dtype :: {:struct, [{String.t(), dtype()}]}

@type signed_integer_dtype :: {:s, 8} | {:s, 16} | {:s, 32} | {:s, 64}
@type unsigned_integer_dtype :: {:u, 8} | {:u, 16} | {:u, 32} | {:u, 64}
Expand Down Expand Up @@ -6069,12 +6069,12 @@ defmodule Explorer.Series do
"""
@doc type: :struct_wise
@spec field(Series.t(), String.t()) :: Series.t()
def field(%Series{dtype: {:struct, dtype}} = series, name) when is_binary(name) do
if Map.has_key?(dtype, name) do
def field(%Series{dtype: {:struct, dtypes}} = series, name) when is_binary(name) do
if List.keymember?(dtypes, name, 0) do
apply_series(series, :field, [name])
else
raise ArgumentError,
"field #{inspect(name)} not found in fields #{inspect(Map.keys(dtype))}"
"field #{inspect(name)} not found in fields #{inspect(Enum.map(dtypes, &elem(&1, 0)))}"
end
end

Expand All @@ -6091,7 +6091,7 @@ defmodule Explorer.Series do
>

iex> s = Series.from_list(["{\\"a\\":1}"])
iex> Series.json_decode(s, {:struct, %{"a" => {:s, 64}}})
iex> Series.json_decode(s, {:struct, [{"a", {:s, 64}}]})
#Explorer.Series<
Polars[1]
struct[1] [%{"a" => 1}]
Expand Down
49 changes: 37 additions & 12 deletions lib/explorer/shared.ex
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,23 @@ defmodule Explorer.Shared do

def normalise_dtype({:struct, inner_types}) do
inner_types
|> Enum.reduce_while(%{}, fn {key, dtype}, normalized_dtypes ->
|> Enum.reduce_while([], fn {key, dtype}, normalized_dtypes ->
case normalise_dtype(dtype) do
nil -> {:halt, nil}
dtype -> {:cont, Map.put(normalized_dtypes, key, dtype)}
nil ->
{:halt, nil}

dtype ->
key = to_string(key)
{:cont, List.keystore(normalized_dtypes, key, 0, {key, dtype})}
end
end)
|> then(fn
nil -> nil
normalized_dtypes -> {:struct, normalized_dtypes}
nil ->
nil

normalized_dtypes ->
{:struct,
if(is_map(inner_types), do: Enum.sort(normalized_dtypes), else: normalized_dtypes)}
end)
end

Expand All @@ -74,6 +82,7 @@ defmodule Explorer.Shared do
def normalise_dtype(:u16), do: {:u, 16}
def normalise_dtype(:u32), do: {:u, 32}
def normalise_dtype(:u64), do: {:u, 64}

def normalise_dtype(_dtype), do: nil

@doc """
Expand Down Expand Up @@ -316,14 +325,15 @@ defmodule Explorer.Shared do

defp infer_struct(%{} = map, types) do
types =
for {key, value} <- map, into: %{} do
for {key, value} <- map do
key = to_string(key)

cond do
types == nil ->
{key, infer_type(value, :null)}

type = types[key] ->
result = List.keyfind(types, key, 0) ->
{^key, type} = result
{key, infer_type(value, type)}

true ->
Expand All @@ -332,7 +342,7 @@ defmodule Explorer.Shared do
end
end

{:struct, types}
{:struct, List.keysort(types, 0)}
end

defp merge_preferred(type, type), do: type
Expand All @@ -348,7 +358,19 @@ defmodule Explorer.Shared do
end

defp merge_preferred({:struct, inferred}, {:struct, preferred}) do
{:struct, Map.merge(inferred, preferred, fn _, v1, v2 -> merge_preferred(v1, v2) end)}
{remaining, all_merged} =
Enum.reduce(preferred, {inferred, []}, fn {col, dtype}, {inferred_rest, merged} ->
case List.keytake(inferred_rest, col, 0) do
{{^col, inferred_dtype}, rest} ->
solved = merge_preferred(inferred_dtype, dtype)
{rest, List.keystore(merged, col, 0, {col, solved})}

nil ->
{inferred, List.keystore(merged, col, 0, {col, dtype})}
end
end)

{:struct, all_merged ++ remaining}
end

defp merge_preferred(inferred, _preferred) do
Expand All @@ -366,8 +388,11 @@ defmodule Explorer.Shared do
"""
def cast_numerics(list, {:struct, dtypes}) when is_list(list) do
Enum.map(list, fn item ->
Map.new(item, fn {field, inner_value} ->
inner_dtype = Map.fetch!(dtypes, to_string(field))
Enum.map(item, fn {field, inner_value} ->
column = to_string(field)

{^column, inner_dtype} = List.keyfind!(dtypes, column, 0)

[casted_value] = cast_numerics([inner_value], inner_dtype)
{field, casted_value}
end)
Expand Down Expand Up @@ -537,7 +562,7 @@ defmodule Explorer.Shared do
def dtype_to_string({:duration, :microsecond}), do: "duration[μs]"
def dtype_to_string({:duration, :nanosecond}), do: "duration[ns]"
def dtype_to_string({:list, dtype}), do: "list[" <> dtype_to_string(dtype) <> "]"
def dtype_to_string({:struct, fields}), do: "struct[#{map_size(fields)}]"
def dtype_to_string({:struct, fields}), do: "struct[#{length(fields)}]"
def dtype_to_string({:f, size}), do: "f" <> Integer.to_string(size)
def dtype_to_string({:s, size}), do: "s" <> Integer.to_string(size)
def dtype_to_string({:u, size}), do: "u" <> Integer.to_string(size)
Expand Down
7 changes: 3 additions & 4 deletions native/explorer/src/datatypes/ex_dtypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use polars::datatypes::DataType;
use polars::datatypes::Field;
use polars::datatypes::TimeUnit;
use rustler::NifTaggedEnum;
use std::collections::HashMap;
use std::ops::Deref;

impl rustler::Encoder for Box<ExSeriesDtype> {
Expand Down Expand Up @@ -56,7 +55,7 @@ pub enum ExSeriesDtype {
Datetime(ExTimeUnit),
Duration(ExTimeUnit),
List(Box<ExSeriesDtype>),
Struct(HashMap<String, ExSeriesDtype>),
Struct(Vec<(String, ExSeriesDtype)>),
}

impl TryFrom<&DataType> for ExSeriesDtype {
Expand Down Expand Up @@ -108,11 +107,11 @@ impl TryFrom<&DataType> for ExSeriesDtype {
)?))),

DataType::Struct(fields) => {
let mut struct_fields = HashMap::new();
let mut struct_fields = Vec::new();

for field in fields {
struct_fields
.insert(field.name().to_string(), Self::try_from(field.data_type())?);
.push((field.name().to_string(), Self::try_from(field.data_type())?));
}

Ok(ExSeriesDtype::Struct(struct_fields))
Expand Down
18 changes: 18 additions & 0 deletions test/explorer/data_frame/ndjson_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,24 @@ defmodule Explorer.DataFrame.NDJSONTest do
assert_ndjson({:struct, %{"a" => {:s, 64}}}, [%{a: 1}], %{"a" => 1})
end

test "infers correctly ordered dtype from ordered source" do
df =
"""
{"col": {"b": "b", "a": "a"}}
"""
|> DF.load_ndjson!()

assert df["col"].dtype == {:struct, [{"b", :string}, {"a", :string}]}

df1 =
"""
{"col": {"a": "a", "b": "b"}}
"""
|> DF.load_ndjson!()

assert df1["col"].dtype == {:struct, [{"a", :string}, {"b", :string}]}
end

# test "date" do
# assert_ndjson(:date, "19327", ~D[2022-12-01])
# assert_ndjson(:date, "-3623", ~D[1960-01-31])
Expand Down
6 changes: 3 additions & 3 deletions test/explorer/data_frame_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -664,8 +664,8 @@ defmodule Explorer.DataFrameTest do
test "extracts a field from struct to new column" do
df = DF.new([%{a: %{n: 1}}, %{a: %{n: 1}}])
df2 = DF.mutate(df, n: field(a, "n"))
assert df.dtypes == %{"a" => {:struct, %{"n" => {:s, 64}}}}
assert df2.dtypes == %{"a" => {:struct, %{"n" => {:s, 64}}}, "n" => {:s, 64}}
assert df.dtypes == %{"a" => {:struct, [{"n", {:s, 64}}]}}
assert df2.dtypes == %{"a" => {:struct, [{"n", {:s, 64}}]}, "n" => {:s, 64}}
end

test "throws error when a field is not found in struct" do
Expand Down Expand Up @@ -4297,7 +4297,7 @@ defmodule Explorer.DataFrameTest do
"dt" => {:datetime, :microsecond},
"f" => {:f, 64},
"l" => {:list, {:s, 64}},
"st" => {:struct, %{"n" => {:s, 64}}}
"st" => {:struct, [{"n", {:s, 64}}]}
}

assert df1 |> DF.collect() |> DF.to_columns() == %{
Expand Down
18 changes: 14 additions & 4 deletions test/explorer/series/list_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -171,19 +171,29 @@ defmodule Explorer.Series.ListTest do

test "list of structs" do
series =
Series.from_list([[%{"a" => 42}], []], dtype: {:list, {:struct, %{"a" => :integer}}})
Series.from_list([[%{"a" => 42}], []], dtype: {:list, {:struct, [{"a", :integer}]}})

assert Series.dtype(series) == {:list, {:struct, %{"a" => {:s, 64}}}}
assert Series.dtype(series) == {:list, {:struct, [{"a", {:s, 64}}]}}
assert Series.to_list(series) == [[%{"a" => 42}], []]
end

test "list of structs with first empty" do
series =
Series.from_list([[], [%{"a" => 42}], []], dtype: {:list, {:struct, %{"a" => :integer}}})
Series.from_list([[], [%{"a" => 42}], []], dtype: {:list, {:struct, [{"a", :integer}]}})

assert Series.dtype(series) == {:list, {:struct, %{"a" => {:s, 64}}}}
assert Series.dtype(series) == {:list, {:struct, [{"a", {:s, 64}}]}}
assert Series.to_list(series) == [[], [%{"a" => 42}], []]
end

test "list of structs and multiple fields" do
series =
Series.from_list([[], [%{"a" => 42, "b" => "f"}], []],
dtype: {:list, {:struct, [{"a", :integer}, {"b", :string}]}}
)

assert Series.dtype(series) == {:list, {:struct, [{"a", {:s, 64}}, {"b", :string}]}}
assert Series.to_list(series) == [[], [%{"a" => 42, "b" => "f"}], []]
end
end

describe "cast/2" do
Expand Down
Loading
Loading