-
Notifications
You must be signed in to change notification settings - Fork 469
/
utils.py
268 lines (226 loc) · 7.75 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, WeightedRandomSampler
import torch
import numpy as np
import scipy
class TorchDataset(Dataset):
    """
    Expose an input matrix and its targets as an indexable torch Dataset.

    Parameters
    ----------
    x : 2D array
        The input matrix; ``len(x)`` defines the dataset length.
    y : array
        The targets, indexed in lockstep with ``x``.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        sample = self.x[index]
        target = self.y[index]
        return sample, target
class PredictDataset(Dataset):
    """
    Expose an input matrix (no targets) as an indexable torch Dataset.

    Parameters
    ----------
    x : 2D array
        The input matrix; each item is one row ``x[index]``.
    """

    def __init__(self, x):
        self.x = x

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        return self.x[index]
def create_dataloaders(
    X_train, y_train, eval_set, weights, batch_size, num_workers, drop_last
):
    """
    Create dataloaders with or without weighted subsampling depending on weights.

    Parameters
    ----------
    X_train : np.ndarray
        Training data
    y_train : np.array
        Mapped training targets
    eval_set : list of tuple
        List of (X, y) evaluation pairs; one validation dataloader is built
        for each pair, in order.
    weights : either 0, 1, dict or iterable
        if 0 : no weights will be applied, training batches are shuffled
        if 1 : classification only, classes are balanced with inverse frequency
        if dict : keys are class values, values are the sample weight given to
            every training sample of that class
        if iterable : list or np array of length equal to the number of
            elements in the training set
    batch_size : int
        Batch size used for the training and validation dataloaders
    num_workers : int
        Number of DataLoader worker processes
    drop_last : bool
        Whether to drop the last incomplete training batch (validation
        dataloaders never drop data)

    Returns
    -------
    train_dataloader, valid_dataloaders : torch.DataLoader, list of torch.DataLoader
        Training dataloader and one validation dataloader per eval_set entry

    Raises
    ------
    ValueError
        If an integer ``weights`` is neither 0 nor 1, or if custom per-sample
        weights do not match the number of training samples.
    """
    # numpy integer scalars (e.g. np.int64(1)) are accepted like plain ints.
    if isinstance(weights, (int, np.integer)):
        if weights == 0:
            need_shuffle = True
            sampler = None
        elif weights == 1:
            need_shuffle = False
            # Map every target to its position among the sorted unique classes so
            # the inverse-frequency lookup is correct even when class labels are
            # not exactly 0..k-1.
            _, class_positions, class_sample_count = np.unique(
                y_train, return_inverse=True, return_counts=True
            )
            class_weights = 1.0 / class_sample_count
            samples_weight = torch.from_numpy(
                class_weights[np.ravel(class_positions)]
            ).double()
            sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
        else:
            raise ValueError("Weights should be either 0, 1, dictionnary or list.")
    elif isinstance(weights, dict):
        # custom weights per class
        need_shuffle = False
        samples_weight = np.array([weights[t] for t in y_train])
        sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    else:
        # custom weights, one per training sample
        if len(weights) != len(y_train):
            raise ValueError("Custom weights should match number of train samples.")
        need_shuffle = False
        samples_weight = np.array(weights)
        sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

    # DataLoader forbids combining shuffle=True with a sampler, hence need_shuffle
    # is only True when no sampler was built.
    train_dataloader = DataLoader(
        TorchDataset(X_train, y_train),
        batch_size=batch_size,
        sampler=sampler,
        shuffle=need_shuffle,
        num_workers=num_workers,
        drop_last=drop_last,
        pin_memory=True,
    )

    valid_dataloaders = [
        DataLoader(
            TorchDataset(X, y),
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
        )
        for X, y in eval_set
    ]

    return train_dataloader, valid_dataloaders
def create_explain_matrix(input_dim, cat_emb_dim, cat_idxs, post_embed_dim):
    """
    This is a computational trick.
    In order to rapidly sum importances from same embeddings
    to the initial index.

    Each initial feature ``i`` owns one post-embedding column if it is not
    categorical, or ``emb_dim`` consecutive columns if it is; the returned
    matrix has a 1 at every (owned post-embedding column, i) position, so
    multiplying post-embedding importances by it sums them per initial feature.

    Parameters
    ----------
    input_dim : int
        Initial input dim
    cat_emb_dim : int or list of int
        if int : size of embedding for all categorical feature
        if list of int : size of embedding for each categorical feature
        NOTE(review): embeddings are consumed in ascending feature-index order,
        so a list is assumed to be ordered to match sorted cat_idxs — confirm
        with callers.
    cat_idxs : list of int
        Initial position of categorical features
    post_embed_dim : int
        Post embedding inputs dimension

    Returns
    -------
    reducing_matrix : scipy.sparse.csc_matrix
        Float matrix of dim (post_embed_dim, input_dim) to perform the reduce.
    """
    if isinstance(cat_emb_dim, int):
        all_emb_impact = [cat_emb_dim - 1] * len(cat_idxs)
    else:
        all_emb_impact = [emb_dim - 1 for emb_dim in cat_emb_dim]

    # Set membership is O(1) instead of scanning the list once per feature.
    cat_idx_set = set(cat_idxs)

    rows = []
    cols = []
    acc_emb = 0  # extra columns added by embeddings of earlier features
    nb_emb = 0  # number of categorical features consumed so far
    for i in range(input_dim):
        start = i + acc_emb
        if i in cat_idx_set:
            width = all_emb_impact[nb_emb] + 1
            acc_emb += all_emb_impact[nb_emb]
            nb_emb += 1
        else:
            width = 1
        rows.extend(range(start, start + width))
        cols.extend([i] * width)

    # Build the sparse matrix directly from coordinates instead of allocating a
    # dense (post_embed_dim, input_dim) array first.
    data = np.ones(len(rows), dtype=np.float64)
    return scipy.sparse.csc_matrix(
        (data, (rows, cols)), shape=(post_embed_dim, input_dim)
    )
def filter_weights(weights):
    """
    Reject weight formats that only make sense for classification.

    Regression and multitask TabNet need explicit per-sample weights, so the
    classification shortcuts (the integer ``1`` and a per-class dict) are
    refused here.

    Parameters
    ----------
    weights : int, dict or list
        Initial weights parameters given by user

    Returns
    -------
    None
        The function only raises when the format is wrong.

    Raises
    ------
    ValueError
        If ``weights`` is a dict, or an int equal to 1.
    """
    err_msg = "Please provide a list of weights for regression or multitask : "
    if isinstance(weights, dict):
        raise ValueError(err_msg + "Dict given.")
    if isinstance(weights, int) and weights == 1:
        raise ValueError(err_msg + "1 given.")
def validate_eval_set(eval_set, eval_name, X_train, y_train):
    """Check if the shapes of eval_set are compatible with (X_train, y_train).

    Parameters
    ----------
    eval_set : list of tuple
        List of eval tuple set (X, y).
        The last one is presumably used for early stopping by the caller.
    eval_name : list of str or None
        List of eval set names. When empty or None, names ``val_0``,
        ``val_1``, ... are generated, one per eval set.
    X_train : np.ndarray
        Training features
    y_train : np.array
        Training targets

    Returns
    -------
    eval_name : list of str
        Validated list of eval_name.
    eval_set : list of tuple
        Validated list of eval_set.

    Raises
    ------
    AssertionError
        If eval_set and eval_name differ in length, an eval tuple does not
        have two elements, or an eval set's shapes disagree with the training
        data. Raised explicitly rather than via ``assert`` so the checks are
        not stripped under ``python -O``.
    ValueError
        Propagated from ``check_nans`` when an eval set contains NaN or inf.
    """
    eval_name = eval_name or [f"val_{i}" for i in range(len(eval_set))]

    if len(eval_set) != len(eval_name):
        raise AssertionError("eval_set and eval_name have not the same length")
    if not all(len(elem) == 2 for elem in eval_set):
        raise AssertionError("Each tuple of eval_set need to have two elements")

    for name, (X, y) in zip(eval_name, eval_set):
        check_nans(X)
        check_nans(y)

        if X.shape[1] != X_train.shape[1]:
            raise AssertionError(
                f"Number of columns is different between X_{name} "
                f"({X.shape[1]}) and X_train ({X_train.shape[1]})"
            )

        # Multi-output targets: column counts must also agree.
        if len(y_train.shape) == 2 and y.shape[1] != y_train.shape[1]:
            raise AssertionError(
                f"Number of columns is different between y_{name} "
                f"({y.shape[1]}) and y_train ({y_train.shape[1]})"
            )

        if X.shape[0] != y.shape[0]:
            raise AssertionError(
                f"You need the same number of rows between X_{name} "
                f"({X.shape[0]}) and y_{name} ({y.shape[0]})"
            )

    return eval_name, eval_set
def check_nans(array):
    """
    Refuse arrays containing NaN or infinite values.

    NaN is checked first, so an array holding both NaN and inf reports NaN.

    Parameters
    ----------
    array : array-like
        Values to inspect with ``np.isnan`` / ``np.isinf``.

    Raises
    ------
    ValueError
        If any NaN or infinite value is found.
    """
    checks = (
        (np.isnan, "NaN were found, TabNet does not allow nans."),
        (np.isinf, "Infinite values were found, TabNet does not allow inf."),
    )
    for predicate, message in checks:
        if predicate(array).any():
            raise ValueError(message)