-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathpre_treatment.py
More file actions
123 lines (85 loc) · 3.36 KB
/
pre_treatment.py
File metadata and controls
123 lines (85 loc) · 3.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import re
from abc import ABC, abstractmethod
from typing import Any
import attr
import numpy as np
from pandas import DataFrame
@attr.s(auto_attribs=True)
class PreTreatment(ABC):
    """
    Abstract base class for a pre-treatment step: a transformation applied
    to the data before statistics are calculated.

    Concrete subclasses must implement ``apply``.
    """

    # Keyword-only attrs field defaulting to 1; this base class does not
    # read it itself.
    analysis_period_length: int = attr.ib(default=1, kw_only=True)

    @classmethod
    def name(cls) -> str:
        """Return the snake-cased name of the pre-treatment class."""
        # Put an underscore in front of every capital letter that is not at
        # the very start of the class name, then lowercase the whole thing.
        underscored = re.sub(r"(?<!^)(?=[A-Z])", "_", cls.__name__)
        return underscored.lower()

    @abstractmethod
    def apply(self, df: DataFrame, col: str) -> DataFrame:
        """
        Apply the pre-treatment transformation to column ``col`` of ``df``
        and return the resulting DataFrame.
        """
        raise NotImplementedError

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]):
        """
        Build an instance of this class, passing the entries of
        ``config_dict`` as keyword arguments to the constructor.
        """
        instance = cls(**config_dict)  # type: ignore
        return instance
class RemoveNulls(PreTreatment):
    """Drops every row whose value in the target column is null."""

    def apply(self, df: DataFrame, col: str) -> DataFrame:
        """Return ``df`` without the rows where ``col`` is null."""
        required_columns = [col]
        return df.dropna(subset=required_columns)
class RemoveIndefinites(PreTreatment):
    """Removes rows with null or infinite (positive or negative) values."""

    def apply(self, df: DataFrame, col: str) -> DataFrame:
        """
        Return the rows of ``df`` whose ``col`` value is neither null nor
        ``inf`` / ``-inf``.

        The input DataFrame is not modified; a filtered DataFrame is
        returned instead.
        """
        values = df[col]
        # Both signs of infinity are dropped: e.g. a logarithm of zero
        # yields -inf, which is just as indefinite as +inf.
        is_infinite = values.isin([np.inf, -np.inf])
        keep = values.notna() & ~is_infinite
        return df.loc[keep, :]
@attr.s(auto_attribs=True)
class CensorHighestValues(PreTreatment):
    """
    Removes rows with the highest values: only rows strictly below the
    ``fraction`` quantile of the column are kept.
    """

    # Quantile (between 0 and 1) used as the exclusive upper cutoff.
    fraction: float = 1 - 1e-5

    def apply(self, df: DataFrame, col: str) -> DataFrame:
        """Return the rows of ``df`` whose ``col`` value is below the cutoff."""
        values = df[col]
        upper_cutoff = values.quantile(self.fraction)
        below_cutoff = values < upper_cutoff
        return df.loc[below_cutoff, :]
@attr.s(auto_attribs=True)
class CensorLowestValues(PreTreatment):
    """
    Removes rows with the lowest values: only rows strictly above the
    ``fraction`` quantile of the column are kept.
    """

    # Quantile (between 0 and 1) used as the exclusive lower cutoff.
    fraction: float = 1e-5

    def apply(self, df: DataFrame, col: str) -> DataFrame:
        """Return the rows of ``df`` whose ``col`` value is above the cutoff."""
        values = df[col]
        lower_cutoff = values.quantile(self.fraction)
        above_cutoff = values > lower_cutoff
        return df.loc[above_cutoff, :]
@attr.s(auto_attribs=True)
class CensorValuesBelowThreshold(PreTreatment):
    """
    Removes rows with values below the provided threshold.

    The comparison is strict, so rows whose value equals the threshold are
    removed as well.
    """

    # Exclusive lower bound; required, there is no default.
    threshold: float

    def apply(self, df: DataFrame, col: str) -> DataFrame:
        """Return the rows of ``df`` whose ``col`` value exceeds the threshold."""
        exceeds_threshold = df[col] > self.threshold
        return df.loc[exceeds_threshold, :]
@attr.s(auto_attribs=True)
class CensorValuesAboveThreshold(PreTreatment):
    """
    Removes rows with values above the provided threshold.

    The comparison is strict, so rows whose value equals the threshold are
    removed as well.
    """

    # Exclusive upper bound; required, there is no default.
    threshold: float

    def apply(self, df: DataFrame, col: str) -> DataFrame:
        """Return the rows of ``df`` whose ``col`` value is under the threshold."""
        under_threshold = df[col] < self.threshold
        return df.loc[under_threshold, :]
@attr.s(auto_attribs=True)
class NormalizeOverAnalysisPeriod(PreTreatment):
    """Normalizes the row values over a given analysis period (number of days)."""

    # Divisor applied to every value of the target column.
    analysis_period_length: int = 1

    def apply(self, df: DataFrame, col: str) -> DataFrame:
        """
        Return a copy of ``df`` in which ``col`` has been divided by
        ``analysis_period_length``.

        The input DataFrame is not modified.
        """
        normalized = df[col] / self.analysis_period_length
        # Use assign() rather than item assignment so the caller's frame is
        # left untouched.
        return df.assign(**{col: normalized})
@attr.s(auto_attribs=True)
class Log(PreTreatment):
    """
    Replaces the values of the target column with their logarithm.

    ``base`` selects the logarithm base; when it is ``None`` (or otherwise
    falsy) the natural logarithm is used.
    """

    base: float | None = 10.0

    def apply(self, df: DataFrame, col: str) -> DataFrame:
        """Return a copy of ``df`` with ``col`` replaced by its logarithm."""
        # Zero and negative inputs produce -inf / NaN; suppress the numpy
        # divide-by-zero and invalid-value warnings that come with them.
        with np.errstate(divide="ignore", invalid="ignore"):
            logged = np.log(df[col])

        if self.base:
            # Change of base: log_b(x) = ln(x) / ln(b).
            logged /= np.log(self.base)

        updated_column = {col: logged}
        return df.assign(**updated_column)
class ZeroFill(PreTreatment):
    """Replaces null values in the target column with zero."""

    def apply(self, df: DataFrame, col: str) -> DataFrame:
        """Return ``df`` with the nulls of ``col`` filled with 0."""
        replacements = {col: 0}
        return df.fillna(value=replacements)