-
Notifications
You must be signed in to change notification settings - Fork 15
/
polars.py
168 lines (134 loc) · 5.06 KB
/
polars.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
"""Support for pola.rs library.
pola.rs Series and DataFrames can be passed to h1, ..., h
in the same way as their pandas equivalents.
Note that by default, NAs are dropped but nulls are not:
histogramming a series that contains nulls raises a ValueError.
Examples:
>>> import polars, physt
>>> series = polars.Series("x", range(100))
>>> physt.h1(series)
Histogram1D(bins=(10,), total=100, dtype=int64)
"""
# TODO: Support structures with numerical items
from typing import Any, Collection, Iterable, NoReturn, Optional, Tuple, Union
import numpy as np
import pandas as pd
import polars
import physt
from physt._construction import (
extract_1d_array,
extract_axis_name,
extract_axis_names,
extract_nd_array,
extract_weights,
)
from physt.types import Histogram1D, HistogramND
# Polars dtypes accepted as histogram input: signed integers, unsigned
# integers and floats. The Series overload of extract_1d_array in this
# module rejects any series whose dtype is not listed here.
NUMERIC_POLARS_DTYPES = [
    # Signed integers
    polars.Int8,
    polars.Int16,
    polars.Int32,
    polars.Int64,
    # Unsigned integers
    polars.UInt8,
    polars.UInt16,
    polars.UInt32,
    polars.UInt64,
    # Floating point
    polars.Float32,
    polars.Float64,
]
@extract_axis_name.register
def _(data: polars.Series, *, axis_name: Optional[str] = None) -> Optional[str]:
    """Resolve the axis name for a polars Series.

    An explicitly passed ``axis_name`` takes precedence; when it is None,
    the name of the series itself is returned.
    """
    return data.name if axis_name is None else axis_name
@extract_axis_name.register
def _(data: polars.DataFrame, **kwargs) -> NoReturn:
    """Reject polars DataFrames when a single axis name is requested.

    Raises:
        ValueError: Always.
    """
    message = "Cannot extract axis name from a polars DataFrame."
    raise ValueError(message)
@extract_1d_array.register
def _(
    data: polars.Series, *, dropna: bool = True
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """Convert a numeric, null-free polars Series for 1D histogramming.

    The series is converted to a numpy array, which is then passed on to
    ``extract_1d_array`` again together with ``dropna``.

    Raises:
        ValueError: If the dtype of the series is not one of
            NUMERIC_POLARS_DTYPES, or if the series contains a null.
    """
    is_numeric = data.dtype in NUMERIC_POLARS_DTYPES
    if not is_numeric:
        raise ValueError(
            f"Cannot extract float array from type {data.dtype}, must be int-like or float-like"
        )

    contains_nulls = data.is_null().any()
    if contains_nulls:
        raise ValueError("Cannot create histogram from series with nulls")

    numpy_values = data.to_numpy(allow_copy=True)
    return extract_1d_array(numpy_values, dropna=dropna)  # type: ignore
@extract_1d_array.register
def _(data: polars.DataFrame, **kwargs) -> NoReturn:
    """Reject polars DataFrames when 1D data is requested.

    Raises:
        ValueError: Always.
    """
    message = (
        "Cannot extract 1D array suitable for histogramming from a polars dataframe. "
        "Either select a Series or extract multidimensional data."
    )
    raise ValueError(message)
@extract_nd_array.register
def _(data: polars.Series, **kwargs) -> NoReturn:
    """Reject a polars Series when multidimensional data is requested.

    Raises:
        ValueError: Always.
    """
    message = (
        "Cannot extract multidimensional array suitable for histogramming from a polars series. "
        "Either select a DataFrame or extract 1D data."
    )
    raise ValueError(message)
@extract_nd_array.register
def _(
    data: polars.DataFrame, *, dim: Optional[int] = None, dropna: bool = True
) -> Tuple[int, np.ndarray, Optional[np.ndarray]]:
    """Convert a polars DataFrame for multidimensional histogramming.

    Each column is converted with ``extract_1d_array`` (without dropping
    anything), the results are assembled into a pandas DataFrame, and that
    frame is passed on to ``extract_nd_array`` with ``dim`` and ``dropna``.

    Raises:
        ValueError: If the frame has no columns. Errors raised while
            converting the individual columns propagate unchanged.
    """
    column_count = data.shape[1]
    if column_count == 0:
        raise ValueError("Must have at least one column.")

    # TODO: This is not very optimized
    converted_columns = {}
    for column_name in data.columns:
        # Only the array part of the result is needed; the second item is ignored.
        converted_columns[column_name] = extract_1d_array(
            data[column_name], dropna=False
        )[0]
    pandas_df = pd.DataFrame(converted_columns)
    return extract_nd_array(pandas_df, dim=dim, dropna=dropna)  # type: ignore
@extract_axis_names.register
def _(
    data: polars.DataFrame, *, axis_names: Optional[Iterable[str]] = None
) -> Optional[Tuple[str, ...]]:
    """Resolve the axis names for a polars DataFrame.

    Without explicit ``axis_names`` the column names of the frame are used.
    Explicit names are returned as a tuple after checking that there is
    exactly one name per column.

    Raises:
        ValueError: If the number of explicit names differs from the number
            of columns.
    """
    if axis_names is None:
        return tuple(data.columns)

    result = tuple(axis_names)
    given_length = len(result)
    expected_length = data.shape[1]
    if given_length != expected_length:
        raise ValueError(
            f"Explicit {axis_names=} has invalid length {given_length}, {expected_length} expected."
        )
    return result
@extract_axis_names.register
def _(data: polars.Series, **kwargs) -> NoReturn:
    """Reject a single polars Series when multiple axis names are requested.

    Raises:
        ValueError: Always.
    """
    message = "Cannot extract axis names from a single polars Series."
    raise ValueError(message)
@extract_weights.register
def _(data: polars.Series, array_mask: Optional[np.ndarray] = None) -> np.ndarray:
    """Extract histogram weights from a polars Series.

    The series is first converted with ``extract_1d_array`` (nothing is
    dropped), then the resulting array is passed on to ``extract_weights``
    together with ``array_mask``.
    """
    # The second item of the conversion result is not needed here.
    weights_array, _unused = extract_1d_array(data, dropna=False)
    return extract_weights(weights_array, array_mask=array_mask)  # type: ignore
@extract_weights.register
def _(data: polars.DataFrame, **kwargs) -> NoReturn:
    """Reject polars DataFrames as a source of weights.

    Raises:
        ValueError: Always.
    """
    message = "Cannot extract weights from a polars DataFrame."
    raise ValueError(message)
@polars.api.register_series_namespace("physt")
class PhystSeries:
    """Accessor registered under the ``physt`` namespace of polars Series."""

    def __init__(self, series: polars.Series):
        """Remember the series the accessor was created for."""
        # TODO: Check numeric dtypes!
        self._series = series

    def h1(self, bins: Any = None, **kwargs) -> Histogram1D:
        """Build a 1D histogram of the wrapped series.

        ``bins`` and all extra keyword arguments are forwarded to ``physt.h1``.
        """
        histogram = physt.h1(self._series, bins=bins, **kwargs)
        return histogram
@polars.api.register_dataframe_namespace("physt")
class PhystFrame:
    """Accessor registered under the ``physt`` namespace of polars DataFrames."""

    def __init__(self, df: polars.DataFrame):
        """Remember the frame the accessor was created for."""
        self._df = df

    def h(
        self,
        columns: Union[str, Collection[str], None] = None,
        bins: Any = None,
        **kwargs,
    ) -> Union[Histogram1D, HistogramND]:
        """Build a histogram from one or more columns of the wrapped frame.

        Args:
            columns: A single column name, a collection of column names, or
                None to use all columns of the frame.
            bins: Binning specification, forwarded to physt unchanged.
            **kwargs: Forwarded to ``physt.h1`` / ``physt.h2`` / ``physt.h``.

        Returns:
            The result of ``physt.h1`` for one column, of ``physt.h2`` for
            two columns, and of ``physt.h`` otherwise.

        Raises:
            KeyError: If selecting the columns raises a KeyError.
        """
        if columns is None:
            columns = self._df.columns
        elif isinstance(columns, str):
            columns = [columns]
        # Normalise to a list: the names are indexed by position below, which
        # a generic Collection (e.g. a set) does not support, and a tuple key
        # may be treated by polars as a (rows, columns) selector rather than
        # as a list of column names.
        columns = list(columns)

        try:
            data = self._df[columns]
        except KeyError as exc:
            # NOTE(review): polars may signal a missing column with its own
            # ColumnNotFoundError instead of KeyError - confirm whether this
            # handler can still fire.
            raise KeyError(
                f"At least one of the columns '{columns}' could not be found."
            ) from exc

        if len(columns) == 1:
            # Pass the Series, not the one-column frame: the DataFrame
            # overloads of extract_1d_array / extract_axis_name in this
            # module raise ValueError.
            return physt.h1(data[columns[0]], bins=bins, **kwargs)
        if len(columns) == 2:
            return physt.h2(data[columns[0]], data[columns[1]], bins=bins, **kwargs)
        # TODO: Check numeric dtypes ?
        return physt.h(data, bins=bins, **kwargs)