-
Notifications
You must be signed in to change notification settings - Fork 75
/
transform.py
65 lines (54 loc) · 2.46 KB
/
transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"Cleaning and feature engineering functions for structured data"
from ..torch_core import *
__all__ = ['Categorify', 'FillMissing', 'FillStrategy', 'TabularTransform']
@dataclass
class TabularTransform():
    "Base class for a transform run on a tabular `DataFrame`."
    # Names of the columns handled as categorical / continuous variables.
    cat_names:StrList
    cont_names:StrList

    def __call__(self, df:DataFrame, test:bool=False):
        "Run `apply_test` on `df` when `test` is truthy, `apply_train` otherwise."
        # Nothing is returned: whatever the hooks do has to happen on `df` itself
        # (or on the transform's own state).
        if test: self.apply_test(df)
        else: self.apply_train(df)

    def apply_train(self, df:DataFrame):
        "Hook run on the training set; must be overridden by subclasses."
        raise NotImplementedError

    def apply_test(self, df:DataFrame):
        "Hook run on the test set; defers to `apply_train` unless overridden."
        self.apply_train(df)
class Categorify(TabularTransform):
    "Turn the columns listed in `cat_names` into ordered pandas categoricals."

    def apply_train(self, df:DataFrame):
        "Convert each categorical column of `df` and store its categories in `self.categories`."
        # Reset on every call so the mapping always reflects the latest training frame.
        self.categories = {}
        for col in self.cat_names:
            as_ordered = df[col].astype('category').cat.as_ordered()
            df[col] = as_ordered
            # Kept so that `apply_test` can reuse exactly the same categories.
            self.categories[col] = df[col].cat.categories

    def apply_test(self, df:DataFrame):
        "Re-encode each categorical column of `df` with the categories stored by `apply_train`."
        # Requires a previous `apply_train` call: `self.categories` is only created there.
        for col in self.cat_names:
            known = self.categories[col]
            df[col] = pd.Categorical(df[col], categories=known, ordered=True)
class FillStrategy(IntEnum):
    "Ways `FillMissing` can choose the replacement for missing values."
    MEDIAN = 1    # use the column's median
    COMMON = 2    # use the column's most frequent non-null value
    CONSTANT = 3  # use the user-supplied `fill_val`
@dataclass
class FillMissing(TabularTransform):
    "Fill the missing values in continuous columns."
    # Strategy used on the training set to compute each column's replacement value.
    fill_strategy:FillStrategy=FillStrategy.MEDIAN
    # When True, add a boolean `<name>_na` column marking the rows that were missing.
    add_col:bool=True
    # Replacement value used when `fill_strategy` is `FillStrategy.CONSTANT`.
    fill_val:float=0.

    def _add_na_col(self, df:DataFrame, name:str):
        "Add the boolean `<name>_na` indicator column to `df` and register it in `cat_names`."
        na_name = name+'_na'
        df[na_name] = pd.isnull(df[name])
        # Appended in place: any other holder of this same list object sees the new name too.
        if na_name not in self.cat_names: self.cat_names.append(na_name)

    def apply_train(self, df:DataFrame):
        "Fill the missing values of `cont_names` in `df` and record each filler in `self.na_dict`."
        # Only columns that actually contain missing values get an entry; `apply_test`
        # relies on that to decide which columns to process.
        self.na_dict = {}
        for name in self.cont_names:
            if not pd.isnull(df[name]).any(): continue
            # The indicator has to be built before the column is filled.
            if self.add_col: self._add_na_col(df, name)
            if self.fill_strategy == FillStrategy.MEDIAN: filler = df[name].median()
            elif self.fill_strategy == FillStrategy.CONSTANT: filler = self.fill_val
            else: filler = df[name].dropna().value_counts().idxmax()  # most frequent non-null value
            df[name] = df[name].fillna(filler)
            self.na_dict[name] = filler

    def apply_test(self, df:DataFrame):
        "Fill the missing values of `df` with the fillers recorded by `apply_train`."
        import warnings
        for name in self.cont_names:
            if name in self.na_dict:
                # Add the indicator even if this frame has no missing value in the column,
                # so the test frame gets the same `_na` columns as the training frame.
                if self.add_col: self._add_na_col(df, name)
                df[name] = df[name].fillna(self.na_dict[name])
            elif name in df.columns and pd.isnull(df[name]).any():
                # No filler was learned for this column, so its missing values cannot be
                # replaced here. Warn instead of letting them pass through silently; the
                # data itself is left exactly as before.
                warnings.warn(f"Column `{name}` has missing values in the test set but had none "
                              "in the training set; they are left unfilled.")