In [6]:
!pip install -Uq polars

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m32.5/32.5 MB[0m [31m41.7 MB/s[0m eta [36m0:00:00[0m
[?25h

In [16]:
import polars as pl
from datetime import datetime
from typing import List, Dict, Any, Union, Tuple
from dataclasses import dataclass, field
from collections import OrderedDict
import warnings


def singleton(class_):
    """Class decorator that makes ``class_`` produce one shared instance.

    The first call of the returned factory builds the instance with the
    arguments it was given; every later call returns that same object and
    ignores its arguments.

    Note: the decorated name is bound to the factory function, not to the
    class itself.
    """
    cache = {}

    def getinstance(*args, **kwargs):
        try:
            return cache[class_]
        except KeyError:
            pass

        # Not created yet: build it once and remember it.
        created = class_(*args, **kwargs)
        cache[class_] = created
        return created

    return getinstance


@dataclass
class Filter:
    """A single row filter: one comparison evaluated on one column of a frame.

    Attributes:
        source: Label of the data source the filter belongs to (stored only;
            none of the methods below read it).
        column: Name of the column the operator is evaluated on.
        operator: Name of the comparison. Supported: 'eq', 'eq_missing', 'ne',
            'ne_missing', 'ge', 'le', 'gt', 'lt', 'is_in', 'is_nan',
            'is_not_nan', 'is_between'.
        value: Operand of the comparison. Ignored by 'is_nan' / 'is_not_nan'.
            For 'is_between' either a mapping of keyword arguments or a
            sequence of positional arguments such as ``(lower, upper)``.
        name: Identifier of the filter; used as dictionary key and for sorting.
        mask: Boolean mask cached by `fit`; ``None`` while the filter is not
            fitted. It is derived state and therefore excluded from equality.
        order_key: Primary sort key of the filter inside a filter set.

    Caveat: a fitted filter applies the cached mask as is, so `transform`
    must be given the same frame (same rows, same order) that was passed to
    `fit`. Call `purge` before reusing the filter on other data.
    """
    source: str
    column: str
    operator: str
    value: Any
    name: str
    # compare=False: the generated __eq__ compares field tuples, and an
    # element-wise comparison of two mask Series cannot be reduced to one bool.
    mask: pl.Series = field(default=None, compare=False)
    order_key: int = 0

    # Series methods that receive `value` as their only argument.
    _VALUE_OPERATORS = (
        'eq', 'eq_missing', 'ne', 'ne_missing',
        'ge', 'le', 'gt', 'lt', 'is_in',
    )
    # Series methods that are called without arguments.
    _NO_VALUE_OPERATORS = ('is_nan', 'is_not_nan')

    def calc(self, df: pl.DataFrame) -> pl.Series:
        """Evaluate the filter on ``df`` and return the resulting mask.

        The cached `mask` is neither read nor updated.

        Raises:
            NotImplementedError: if `operator` is not one of the supported names.
        """
        operator = self.operator

        if operator in self._VALUE_OPERATORS:
            return getattr(df[self.column], operator)(self.value)

        if operator in self._NO_VALUE_OPERATORS:
            return getattr(df[self.column], operator)()

        if operator == 'is_between':
            column = df[self.column]
            bounds = self.value
            # Mappings are passed as keyword arguments (original behaviour);
            # any other iterable is passed positionally, e.g. (lower, upper).
            if hasattr(bounds, 'keys'):
                return column.is_between(**bounds)
            return column.is_between(*bounds)

        raise NotImplementedError(f'Оператор {self.operator} не поддерживается')

    def fit(self, X: pl.DataFrame, y: pl.Series = None):
        """Compute the mask for ``X`` and cache it. ``y`` is ignored. Returns self."""
        self.mask = self.calc(X)

        return self

    def transform(self, X: pl.DataFrame, y: pl.Series = None) -> pl.DataFrame:
        """Return the rows of ``X`` selected by the filter.

        Uses the cached mask when the filter is fitted, otherwise evaluates
        the filter on ``X`` without caching the result. ``y`` is ignored.
        """
        mask = self.mask if self.is_fitted else self.calc(X)

        return X.filter(mask)

    def purge(self):
        """Drop the cached mask so the filter counts as not fitted. Returns self."""
        self.mask = None

        return self

    @property
    def is_fitted(self):
        """True when a mask is cached."""
        return self.mask is not None


@dataclass
class FilterSet:
    """A chain of filters applied together as a logical AND.

    On assignment the filters are sorted by ``(order_key, name)`` and stored
    in an ordered mapping keyed by filter name (a repeated name keeps only the
    last filter with that name).

    arguments:
        source: Data source.
        filters: A single filter, a list of filters, or a mapping of filters.
        name: Name of the filter chain. When omitted, the name is built by
            joining the filter names with '_'.
            `name` is used for sorting filter chains in `FilteredChain`.
    """
    source: str
    filters: Union[Filter, List[Filter], OrderedDict[str, Filter], Dict[str, Filter]]
    name: str = None

    # NOTE: the properties below rebind the class attributes `filters` and
    # `name`, so @dataclass sees the property objects as the field defaults.
    # An omitted argument therefore reaches the setter as a `property` instance.

    @property
    def filters(self):
        """Ordered mapping ``filter name -> Filter``."""
        return self._filters

    @property
    def filters_list(self):
        """The filters as a list, in the same order as `filters`."""
        return self._filters_list

    @filters.setter
    def filters(self, filters: Union[Filter, List[Filter], OrderedDict[str, Filter], Dict[str, Filter]]):
        if isinstance(filters, list):
            ordered = sorted(filters, key=lambda f: (f.order_key, f.name))
            self._filters = OrderedDict((f.name, f) for f in ordered)
        elif isinstance(filters, dict):  # OrderedDict is a dict subclass
            ordered = sorted(filters.items(), key=lambda kv: (kv[1].order_key, kv[1].name))
            self._filters = OrderedDict(ordered)
        elif isinstance(filters, Filter):
            self._filters = OrderedDict({filters.name: filters})
        else:
            raise NotImplementedError(f'Неправильный аргумент `filters`: {filters}')

        self._filters_list = list(self._filters.values())

    @property
    def name(self):
        """Name of the chain (explicit, or the filter names joined with '_')."""
        return self._name

    @name.setter
    def name(self, name: str):
        # `property` instance == the argument was omitted (see NOTE above).
        if name is None or isinstance(name, property):
            self._name = '_'.join(f.name for f in self.filters.values())
        else:
            self._name = name

    def __getitem__(self, value):
        """Look a filter up by position (int / slice) or by name (str).

        Raises:
            IndexError: position out of range, or unsupported index type.
            KeyError: unknown filter name.
        Both carry extra args describing the set to ease debugging.
        """
        try:
            if isinstance(value, (int, slice)):
                return self.filters_list[value]
            if isinstance(value, str):
                return self.filters[value]
            raise IndexError(f'incorrect index value: {value}')

        # KeyError is included so that a misspelled filter name also gets
        # the list of available keys attached.
        except (IndexError, KeyError) as e:
            e.args += (
                f'provided value: {value}',
                f'length of filters: {len(self)}',
                f'top 10 available keys: {list(self.keys())[:10]}',
            )
            raise

    def __len__(self):
        return len(self.filters)

    def values(self):
        """Filters, in chain order."""
        return self.filters.values()

    def keys(self):
        """Filter names, in chain order."""
        return self.filters.keys()

    def items(self):
        """``(name, filter)`` pairs, in chain order."""
        return self.filters.items()

    def __iter__(self):
        yield from self._filters_list

    def fit(self, X: pl.DataFrame, y: pl.Series = None):
        """Fit every filter that is not fitted yet on ``X``. Returns self."""
        for f in self:
            if not f.is_fitted:
                f.fit(X)

        return self

    def transform(self, X: pl.DataFrame, y: pl.Series = None) -> pl.DataFrame:
        """Return the rows of ``X`` that pass all filters of the chain.

        An empty chain returns ``X`` unchanged. Cached masks are reused for
        fitted filters; the others are evaluated on ``X`` without caching.
        """
        if self.is_empty:
            return X

        masks = [f.mask if f.is_fitted else f.calc(X) for f in self]

        return X.filter(self.merge_mask_list(masks))

    def purge_all(self):
        """Drop the cached mask of every fitted filter. Returns self."""
        for f in self:
            if f.is_fitted:
                f.purge()

        return self

    @property
    def is_fitted(self):
        """True if all filters are fitted, False if any is not, None if empty."""
        if self.is_empty:
            return None
        return all(f.is_fitted for f in self)

    @property
    def mask_list(self):
        """Cached masks of the filters (``None`` entries for unfitted ones)."""
        return [f.mask for f in self]

    @property
    def mask(self):
        """Combined mask of the chain, or ``None`` unless every filter is fitted."""
        if self.is_fitted:
            return self.merge_mask_list(self.mask_list)
        return None

    @staticmethod
    def merge_mask_list(mask_list):
        """AND all masks together; ``None`` for an empty list."""
        if len(mask_list) == 0:
            return None

        merged = mask_list[0]
        for m in mask_list[1:]:
            # `merged & m` instead of `&=`: never modify the first filter's
            # cached mask in place, whatever the mask type is.
            merged = merged & m
        return merged

    @property
    def is_empty(self) -> bool:
        """True when the chain holds no filters."""
        return len(self) == 0


@dataclass
class FilteredChain:
    """Applies many filter sets to one frame while caching fitted filter masks.

    Attributes:
        filter_sets: A list of filter sets (or a single one). They are stored
            sorted by name in a mapping keyed by name, so sets with equal
            names collapse into one entry.
        do_fit: When True, filters are fitted (their masks cached) before a
            set is applied; when False masks are computed on the fly.
        n_cache_filters: Maximum number of fitted filters kept in
            `fitted_filters`; the oldest entries are purged first. A negative
            value disables eviction.
        fitted_filters: Registry ``filter name -> Filter`` of fitted filters,
            in insertion order.
        empty_filter_nodes: Names of filter sets / tree nodes that produced
            no rows during `apply_filter_tree`.

    Caveat: cached masks and `empty_filter_nodes` persist between calls, so
    call `purge_all` before applying the chain to a different frame.
    """
    filter_sets: List[FilterSet]
    do_fit: bool = True
    n_cache_filters: int = 50
    fitted_filters: OrderedDict[str, Filter] = field(default_factory=OrderedDict)
    empty_filter_nodes: List[str] = field(default_factory=list)

    @property
    def filter_sets(self):
        """Ordered mapping ``filter set name -> FilterSet`` (sorted by name)."""
        return self._filter_sets

    @property
    def filter_sets_list(self):
        """The filter sets as a list, in the same order as `filter_sets`."""
        return self._filter_sets_list

    @property
    def filter_sets_tree(self):
        """Prefix tree of the filter sets, see `build_filter_tree`."""
        return self._filter_sets_tree

    @filter_sets.setter
    def filter_sets(self, filter_sets: List[FilterSet]):
        if isinstance(filter_sets, list):
            ordered = sorted(filter_sets, key=lambda fs: fs.name)
            self._filter_sets = OrderedDict((fs.name, fs) for fs in ordered)
        elif isinstance(filter_sets, FilterSet):
            self._filter_sets = OrderedDict({filter_sets.name: filter_sets})
        else:
            raise NotImplementedError(f'Неправильный аргумент `filter_sets`: {filter_sets}')

        self._filter_sets_list = list(self._filter_sets.values())
        self._filter_sets_tree = self.build_filter_tree(self._filter_sets_list)

    def fit_filter_set(self, filter_set: FilterSet, data: pl.DataFrame):
        """Fit ``filter_set`` on ``data`` and register its filters (only if `do_fit`)."""
        if self.do_fit:
            filter_set.fit(data)

            for f in filter_set:
                self.fitted_filters[f.name] = f

        return self

    def purge_n_fitted_filters(self, n):
        """Purge and unregister the ``n`` oldest fitted filters.

        ``n <= 0`` is a no-op; ``n`` larger than the registry purges everything.
        """
        if n <= 0:
            return

        for filter_name in list(self.fitted_filters)[:n]:
            self.fitted_filters.pop(filter_name).purge()

    def _evict_excess_filters(self):
        """Shrink `fitted_filters` to `n_cache_filters` entries (if non-negative)."""
        if len(self.fitted_filters) > self.n_cache_filters >= 0:
            self.purge_n_fitted_filters(len(self.fitted_filters) - self.n_cache_filters)

    def apply(self, data: pl.DataFrame):
        """Yield ``(filter_set, filtered_frame)`` for every filter set, in name order."""
        for filter_set in self.filter_sets_list:
            self.fit_filter_set(filter_set, data)

            # Filter first, evict afterwards: the masks fitted just above are
            # still cached while they are needed.
            output = filter_set.transform(data)
            self._evict_excess_filters()

            yield filter_set, output

    def purge_all(self):
        """Drop all cached masks and reset the bookkeeping. Returns self."""
        for f in self.filter_sets_list:
            f.purge_all()

        self.fitted_filters = OrderedDict()
        self.empty_filter_nodes = list()

        return self

    @staticmethod
    def build_filter_tree(filter_sets_list):
        """Arrange filter sets in a nested dict keyed by cumulative filter names.

        Each level key is the '_'-joined names of the filters up to that
        depth; the filter set itself is stored under the reserved key
        'FilterSet' of its deepest node.
        """

        def insert_into_tree(tree, filter_set):
            current_level = tree
            node_name = ''
            for f in filter_set:
                node_name += f'_{f.name}' if len(node_name) > 0 else f.name # TBD
                current_level = current_level.setdefault(node_name, {})
            current_level['FilterSet'] = filter_set

        filter_tree = {}
        for filter_set in filter_sets_list:
            insert_into_tree(filter_tree, filter_set)

        return filter_tree

    def apply_filter_tree(self, data: pl.DataFrame):
        """Like `apply`, but walks the prefix tree and skips empty results.

        Filter sets that select no rows are not yielded, and tree nodes whose
        key contains the name of an already-empty node are skipped entirely.
        """
        yield from self._iterate_filter_tree(self.filter_sets_tree.copy(), data)

    def _iterate_filter_tree(self, filter_tree, data: pl.DataFrame):
        """Depth-first walk over ``filter_tree``; see `apply_filter_tree`."""
        for k, v in filter_tree.items():
            if k == 'FilterSet':
                self.fit_filter_set(v, data)

                # Emptiness is judged on the filtered frame rather than on
                # `v.mask`: the mask is None when `do_fit` is False, when the
                # cache eviction has purged it, or when the set has no filters.
                output = v.transform(data)
                self._evict_excess_filters()

                if output.height == 0:
                    self.empty_filter_nodes.append(v.name)
                    continue

                yield v, output
            else:
                # NOTE(review): this is a substring match on names, so a
                # filter name contained in another one (e.g. 'A' in 'AB')
                # prunes too much - confirm naming conventions rule that out.
                # Empty names are skipped: '' is a substring of every key.
                if any(name and name in k for name in self.empty_filter_nodes):
                    self.empty_filter_nodes.append(k)
                else:
                    yield from self._iterate_filter_tree(v, data)


@dataclass
class FeatureAggregate:
    """Configuration of one aggregated feature and the code that computes it.

    Attributes:
        source: Label of the data source (stored only).
        id_column: Column the frame is grouped by.
        column: Column that is aggregated.
        dtype: Declared type of the feature (stored only; not applied here).
        aggregations: Name of a polars expression method (e.g. 'mean') or a
            list/tuple of such names. Defaults to 'mean'.
        name: Base name of the output columns; defaults to `column`.
    """
    source: str
    id_column: str
    column: str
    dtype: Union[str, type]
    aggregations: Union[str, Any, List[Union[str, Any]]] = 'mean'
    name: str = None

    # NOTE: the properties below rebind the class attributes `aggregations`
    # and `name`, so @dataclass sees the property objects as the field
    # defaults. An omitted argument reaches the setter as a `property`
    # instance, which each setter maps back to the intended default.

    @property
    def aggregations(self):
        """Aggregations as a list."""
        return self._aggregations

    @aggregations.setter
    def aggregations(self, aggregations: Union[str, Any, List[Union[str, Any]]]):
        if isinstance(aggregations, property):
            # Argument omitted: restore the documented default.
            aggregations = 'mean'

        if isinstance(aggregations, (list, tuple)):
            self._aggregations = list(aggregations)
        else:
            self._aggregations = [aggregations]

    @property
    def name(self):
        """Base name used in the output column aliases."""
        return self._name

    @name.setter
    def name(self, name: str):
        self._name = self.column if name is None or isinstance(name, property) else name

    def fit(self, X: pl.DataFrame, y: pl.Series = None):
        """No-op kept for a fit/transform style interface. Returns self."""
        return self

    def aggregate_functions(self, suffixes: str = '') -> List[pl.Expr]:
        """Build one expression per aggregation, aliased '<agg>_<name><suffixes>'."""
        source_column = pl.col(self.column)
        return [
            getattr(source_column, agg)().alias(f'{agg}_{self.name}{suffixes}')
            for agg in self.aggregations
        ]

    def calc(self, df: pl.DataFrame, filter_set: FilterSet = None, suffixes: str = '') -> pl.DataFrame:
        """Compute the aggregations per `id_column`.

        When ``filter_set`` is given, ``df`` is filtered by it first and, if
        ``suffixes`` is empty, the columns are suffixed with '_<filter set name>'.
        The group-by is called without ordering options, so do not rely on
        the order of the result rows.
        """
        if filter_set is not None:
            if suffixes == '':
                suffixes = f'_{filter_set.name}'
            df = filter_set.transform(df)

        return df.group_by(self.id_column).agg(*self.aggregate_functions(suffixes=suffixes))

In [13]:
# Toy loan-level table: two loans for each of three requests.
data = dict(
    REQUESTID=[1, 1, 2, 2, 3, 3],
    CREDITSUM=[1000, 2000, 1500, 3000, 1000, 500],
    OVERDUEDEBT=[100, 0, 200, 500, 0, 50],
    IS_OWN=[1, 1, 0, 1, 0, 0],
    LOAN_TYPE=[7, 9, 10, 9, 9, 10],
)

df = pl.DataFrame(data)

# Features: aggregations of CREDITSUM per REQUESTID.
feat_0 = FeatureAggregate(
    source='cr_loan',
    id_column='REQUESTID',
    column='CREDITSUM',
    dtype=float,
    aggregations=['mean', 'max'],
)

feat_1 = FeatureAggregate(
    source='cr_loan',
    id_column='REQUESTID',
    column='CREDITSUM',
    dtype=float,
    aggregations='max',
)

# Elementary filters. `order_key` fixes their position inside a filter set.
# f_0 has an empty name; its `value` is not used by the 'is_not_nan' operator.
f_0 = Filter(
    name='',
    column='REQUESTID',
    operator='is_not_nan',
    value=-1,
    source='cr_loan',
    order_key=0,
)
f_1 = Filter(
    name='MTSB',
    column='IS_OWN',
    operator='eq',
    value=1,
    source='cr_loan',
    order_key=0,
)
f_2 = Filter(
    name='POTREB',
    column='LOAN_TYPE',
    operator='is_in',
    value=[9],
    source='cr_loan',
    order_key=1,
)
# LOAN_TYPE 21 does not occur in `data`, so this filter selects no rows.
f_3 = Filter(
    name='MICRO',
    column='LOAN_TYPE',
    operator='is_in',
    value=[21],
    source='cr_loan',
    order_key=1,
)
f_4 = Filter(
    name='CREDITSUM_GE_2000',
    column='CREDITSUM',
    operator='ge',
    value=2000,
    source='cr_loan',
    order_key=2,
)

# Filter sets; f_s_1 and f_s_4 contain the same filters in a different order.
f_s_0 = FilterSet(source='cr_loan', filters=[f_0])
f_s_1 = FilterSet(source='cr_loan', filters=[f_2, f_1])
f_s_2 = FilterSet(source='cr_loan', filters=[f_1])
f_s_3 = FilterSet(source='cr_loan', filters=[f_2])
f_s_4 = FilterSet(source='cr_loan', filters=[f_1, f_2])
f_s_5 = FilterSet(source='cr_loan', filters=[f_3, f_1])
f_s_6 = FilterSet(source='cr_loan', filters=[f_3])
f_s_7 = FilterSet(source='cr_loan', filters=[f_3, f_1, f_4])

# n_cache_filters=-1: fitted masks are never evicted from the cache.
filtered_chain = FilteredChain(
    [f_s_0, f_s_1, f_s_2, f_s_3, f_s_4, f_s_5, f_s_6, f_s_7],
    do_fit=True,
    n_cache_filters=-1,
)

In [14]:
# Aggregate only the rows passing f_s_1; output columns get the set name as suffix.
feat_0.calc(df, filter_set=f_s_1)

REQUESTID,mean_CREDITSUM_MTSB_POTREB,max_CREDITSUM_MTSB_POTREB
i64,f64,i64
1,2000.0,2000
2,3000.0,3000


In [15]:
# Walk the filter tree: show each non-empty filtered frame and its aggregates.
for chain_item in filtered_chain.apply_filter_tree(df):
    filter_set, filtered_df = chain_item
    display(filtered_df)
    display(feat_0.calc(filtered_df))

REQUESTID,CREDITSUM,OVERDUEDEBT,IS_OWN,LOAN_TYPE
i64,i64,i64,i64,i64
1,1000,100,1,7
1,2000,0,1,9
2,1500,200,0,10
2,3000,500,1,9
3,1000,0,0,9
3,500,50,0,10


REQUESTID,mean_CREDITSUM,max_CREDITSUM
i64,f64,i64
1,1500.0,2000
3,750.0,1000
2,2250.0,3000


REQUESTID,CREDITSUM,OVERDUEDEBT,IS_OWN,LOAN_TYPE
i64,i64,i64,i64,i64
1,1000,100,1,7
1,2000,0,1,9
2,3000,500,1,9


REQUESTID,mean_CREDITSUM,max_CREDITSUM
i64,f64,i64
1,1500.0,2000
2,3000.0,3000


REQUESTID,CREDITSUM,OVERDUEDEBT,IS_OWN,LOAN_TYPE
i64,i64,i64,i64,i64
1,2000,0,1,9
2,3000,500,1,9


REQUESTID,mean_CREDITSUM,max_CREDITSUM
i64,f64,i64
1,2000.0,2000
2,3000.0,3000


REQUESTID,CREDITSUM,OVERDUEDEBT,IS_OWN,LOAN_TYPE
i64,i64,i64,i64,i64
1,2000,0,1,9
2,3000,500,1,9
3,1000,0,0,9


REQUESTID,mean_CREDITSUM,max_CREDITSUM
i64,f64,i64
2,3000.0,3000
3,1000.0,1000
1,2000.0,2000
