/
normalize.py
101 lines (82 loc) · 3.66 KB
/
normalize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from paips.core import Task
import tqdm
import pandas as pd
import numpy as np
from IPython import embed
class BatchedMVN:
    """Accumulate a running mean and variance over a stream of batches.

    Each call to ``update`` merges the statistics of one batch into the
    running statistics using the pairwise (parallel) variance update, so
    the full data set never has to be held in memory at once.

    Attributes:
        moving_mean: Running mean over everything seen so far (0 before
            the first non-empty batch).
        moving_var: Running variance over everything seen so far.
        data_size: Number of samples accumulated along ``axis``.
        axis: Axis passed to ``np.mean`` / ``np.var``; ``None`` reduces
            over all elements of each batch.
    """

    def __init__(self, axis=0):
        self.moving_mean = 0
        self.moving_var = 0
        self.data_size = 0
        self.axis = axis

    def update(self, data):
        """Merge one batch into the running mean and variance.

        Args:
            data: Batch exposing ``.size`` and ``.shape`` and accepted by
                ``np.mean`` / ``np.var`` (assumed to be a numpy array --
                confirm for other containers). Batches with zero samples
                along ``axis`` are ignored.
        """
        if self.axis is None:
            n = data.size
        else:
            n = data.shape[self.axis]
        if n == 0:
            # An empty batch carries no information; merging it would
            # divide 0 by 0 and turn the running statistics into NaN.
            return

        batch_mean = np.mean(data, axis=self.axis)
        batch_var = np.var(data, axis=self.axis)

        seen = self.data_size
        total = seen + n
        # The between-batch term must use the mean from *before* this
        # batch is merged, so the variance is updated first.
        delta = batch_mean - self.moving_mean
        self.moving_var = (
            (seen * self.moving_var + n * batch_var) / total
            + (seen * n * delta ** 2) / (total ** 2)
        )
        self.moving_mean = (seen * self.moving_mean + n * batch_mean) / total
        self.data_size = total

    def get_mean_and_std(self):
        """Return ``(mean, std)`` of all data merged so far."""
        return self.moving_mean, self.moving_var ** 0.5
class NormalizationStatistics(Task):
    """Task that computes mean/std normalization statistics per column.

    Values read from ``self.parameters``:
        in: Table holding the features. It is indexed as ``data[col]``
            and with boolean masks via ``data.loc`` (presumably a pandas
            DataFrame -- confirm against the upstream task).
        by: ``'global'`` (default) for one set of statistics per column,
            ``None`` to compute nothing, or the name of a column whose
            unique values define the groups.
        column: Column name or list of column names to process.
        mode: Statistic type; only ``'mvn'`` is implemented (default).
        axis: Reduction axis forwarded to ``BatchedMVN`` (default 0).
    """

    def process(self):
        """Compute the requested statistics.

        Returns:
            ``{}`` when ``by`` is ``None``. Otherwise a dict keyed by
            column name whose values are
            ``{'global': {'mean': ..., 'std': ...}}`` for global
            statistics, or ``{by: {group: {'mean': ..., 'std': ...}}}``
            for grouped statistics.

        Raises:
            ValueError: If ``mode`` is not ``'mvn'``.
        """
        data = self.parameters['in']
        normalization_by = self.parameters.get('by', 'global')
        column = self.parameters.get('column', None)
        if not isinstance(column, list):
            column = [column]
        mode = self.parameters.get('mode', 'mvn')
        axis = self.parameters.get('axis', 0)

        if normalization_by is None:
            return {}

        statistics = {}
        if normalization_by == 'global':
            for col in column:
                statistics[col] = {
                    'global': self._compute_stats(data[col], mode, axis)
                }
            return statistics

        # Look the grouping column up once instead of once per
        # (column, group) pair.
        group_keys = data[normalization_by]
        groups = group_keys.unique()
        for col in column:
            per_group = {}
            for g in groups:
                group_values = data.loc[group_keys == g, col]
                per_group[g] = self._compute_stats(group_values, mode, axis)
            statistics[col] = {normalization_by: per_group}
        return statistics

    def _compute_stats(self, values, mode, axis):
        """Return ``dict(mean=..., std=...)`` accumulated over ``values``."""
        if mode != 'mvn':
            raise ValueError(
                "Unsupported normalization mode {!r}; only 'mvn' is "
                "implemented".format(mode)
            )
        accumulator = BatchedMVN(axis=axis)
        # Iterate over the values directly rather than looking each row
        # up by index label, so duplicated labels are still visited once
        # per row.
        for item in values:
            accumulator.update(self._load_feature(item))
        mean, std = accumulator.get_mean_and_std()
        return dict(mean=mean, std=std)

    def _load_feature(self, item):
        """Resolve one table cell to the in-memory data it refers to."""
        kind = type(item).__name__
        # Matched by class name, as the original code did, so the project
        # type does not have to be imported here.
        if kind == 'GenericFile':
            return item.load()
        if kind in ('PosixPath', 'WindowsPath', 'str'):
            # Imported here because the module never imports joblib at
            # the top level and only path inputs need it.
            # NOTE(review): joblib.load unpickles the file -- only use it
            # on trusted paths.
            import joblib
            return joblib.load(item)
        return item