/
abstract_model.py
681 lines (575 loc) · 20.6 KB
/
abstract_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
from dataclasses import dataclass, field
from typing import List, Any, Dict
import torch
from torch.nn.utils import clip_grad_norm_
import numpy as np
from scipy.sparse import csc_matrix
from abc import abstractmethod
from pytorch_tabnet import tab_network
from pytorch_tabnet.utils import (
PredictDataset,
create_explain_matrix,
validate_eval_set,
create_dataloaders,
check_nans,
define_device,
)
from pytorch_tabnet.callbacks import (
CallbackContainer,
History,
EarlyStopping,
LRSchedulerCallback,
)
from pytorch_tabnet.metrics import MetricContainer, check_metrics
from sklearn.base import BaseEstimator
from torch.utils.data import DataLoader
import io
import json
from pathlib import Path
import shutil
import zipfile
@dataclass
class TabModel(BaseEstimator):
    """Base class for TabNet models.

    Stores the network hyper-parameters as dataclass fields and implements the
    training loop, inference, explanation and persistence shared by concrete
    TabNet estimators.

    Subclasses must implement ``update_fit_params``, ``compute_loss`` and
    ``prepare_target``. This class also uses ``predict_func``,
    ``stack_batches``, ``_default_loss``, ``_default_metric`` and
    ``updated_weights`` without defining them, so subclasses are expected to
    provide those as well.

    Note: the class does not inherit from ``abc.ABC``, so the
    ``@abstractmethod`` markers are not enforced at instantiation time. Calling
    an unimplemented hook raises ``NotImplementedError``.
    """

    n_d: int = 8
    n_a: int = 8
    n_steps: int = 3
    gamma: float = 1.3
    cat_idxs: List[int] = field(default_factory=list)
    cat_dims: List[int] = field(default_factory=list)
    cat_emb_dim: int = 1
    n_independent: int = 2
    n_shared: int = 2
    epsilon: float = 1e-15
    momentum: float = 0.02
    lambda_sparse: float = 1e-3
    seed: int = 0
    clip_value: int = 1
    verbose: int = 1
    optimizer_fn: Any = torch.optim.Adam
    optimizer_params: Dict = field(default_factory=lambda: dict(lr=2e-2))
    scheduler_fn: Any = None
    scheduler_params: Dict = field(default_factory=dict)
    mask_type: str = "sparsemax"
    input_dim: int = None
    output_dim: int = None
    device_name: str = "auto"

    def __post_init__(self):
        """Set default batch sizes, seed torch and resolve the compute device."""
        # Defaults used by predict/explain (and by load_model, which re-runs
        # __init__) until fit() overrides them with its own arguments.
        self.batch_size = 1024
        self.virtual_batch_size = 1024
        torch.manual_seed(self.seed)
        # Defining device
        self.device = torch.device(define_device(self.device_name))
        print(f"Device used : {self.device}")

    def fit(
        self,
        X_train,
        y_train,
        eval_set=None,
        eval_name=None,
        eval_metric=None,
        loss_fn=None,
        weights=0,
        max_epochs=100,
        patience=10,
        batch_size=1024,
        virtual_batch_size=128,
        num_workers=0,
        drop_last=False,
        callbacks=None,
        pin_memory=True,
    ):
        """Build and train the network stored in ``self.network``.

        Parameters
        ----------
        X_train : np.ndarray
            Train set.
        y_train : np.array
            Train targets.
        eval_set : list of tuple
            List of eval tuple set (X, y).
            The last one is used for early stopping.
        eval_name : list of str
            List of eval set names.
        eval_metric : list of str
            List of evaluation metrics.
            The last metric is used for early stopping.
        loss_fn : callable or None
            Stored as ``self.loss_fn``; defaults to ``self._default_loss``
            when None (presumably consumed by the subclass's ``compute_loss``).
        weights : bool or dictionary
            0 for no balancing
            1 for automated balancing
            dict for custom weights per class
        max_epochs : int
            Maximum number of epochs during training.
        patience : int
            Number of consecutive non improving epochs before early stopping.
        batch_size : int
            Training batch size.
        virtual_batch_size : int
            Batch size for Ghost Batch Normalization (virtual_batch_size < batch_size).
        num_workers : int
            Number of workers used in torch.utils.data.DataLoader.
        drop_last : bool
            Whether to drop last batch during training.
        callbacks : list of callback function
            List of custom callbacks.
        pin_memory : bool
            Whether to set pin_memory to True or False during training.
        """
        # Store training hyper-parameters on the instance.
        self.max_epochs = max_epochs
        self.patience = patience
        self.batch_size = batch_size
        self.virtual_batch_size = virtual_batch_size
        self.num_workers = num_workers
        self.drop_last = drop_last
        self.input_dim = X_train.shape[1]
        self._stop_training = False
        self.pin_memory = pin_memory

        eval_set = eval_set if eval_set else []
        self.loss_fn = self._default_loss if loss_fn is None else loss_fn

        check_nans(X_train)
        check_nans(y_train)
        self.update_fit_params(X_train, y_train, eval_set, weights)

        # Validate and reformat eval set depending on training data
        eval_names, eval_set = validate_eval_set(eval_set, eval_name, X_train, y_train)
        train_dataloader, valid_dataloaders = self._construct_loaders(
            X_train, y_train, eval_set
        )

        self._set_network()
        self._set_metrics(eval_metric, eval_names)
        self._set_optimizer()
        self._set_callbacks(callbacks)

        self._callback_container.on_train_begin()
        for epoch_idx in range(self.max_epochs):
            self._callback_container.on_epoch_begin(epoch_idx)
            self._train_epoch(train_dataloader)
            # Evaluate every eval set; a distinct loop name avoids shadowing
            # the ``eval_name`` argument.
            for valid_name, valid_dataloader in zip(eval_names, valid_dataloaders):
                self._predict_epoch(valid_name, valid_dataloader)
            self._callback_container.on_epoch_end(
                epoch_idx, logs=self.history.epoch_metrics
            )
            # Callbacks (e.g. EarlyStopping) may set this flag.
            if self._stop_training:
                break
        self._callback_container.on_train_end()
        self.network.eval()

        # Compute feature importance once the final model is defined.
        self._compute_feature_importances(train_dataloader)

    def predict(self, X):
        """Make predictions on ``X``.

        Parameters
        ----------
        X : array-like
            Input data (wrapped in ``PredictDataset``).

        Returns
        -------
        predictions : np.array
            Result of ``self.predict_func`` applied to the stacked raw network
            outputs.
        """
        self.network.eval()
        dataloader = DataLoader(
            PredictDataset(X),
            batch_size=self.batch_size,
            shuffle=False,
        )

        results = []
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for data in dataloader:
                data = data.to(self.device).float()
                output, _ = self.network(data)
                results.append(output.cpu().detach().numpy())
        res = np.vstack(results)
        return self.predict_func(res)

    def explain(self, X):
        """Return local explanations for ``X``.

        Parameters
        ----------
        X : array-like
            Input data (wrapped in ``PredictDataset``).

        Returns
        -------
        M_explain : np.ndarray
            Importance per sample, per column (projected through
            ``self.reducing_matrix``).
        masks : dict
            Attention masks per key returned by ``network.forward_masks``, each
            stacked over all samples and projected through
            ``self.reducing_matrix``.
        """
        self.network.eval()
        dataloader = DataLoader(
            PredictDataset(X),
            batch_size=self.batch_size,
            shuffle=False,
        )

        res_explain = []
        # Collect per-key batches and stack once at the end instead of
        # re-stacking the growing array on every batch.
        res_masks = {}
        with torch.no_grad():
            for data in dataloader:
                data = data.to(self.device).float()
                M_explain, masks = self.network.forward_masks(data)
                for key, value in masks.items():
                    res_masks.setdefault(key, []).append(
                        csc_matrix.dot(value.cpu().detach().numpy(), self.reducing_matrix)
                    )
                res_explain.append(
                    csc_matrix.dot(M_explain.cpu().detach().numpy(), self.reducing_matrix)
                )

        res_explain = np.vstack(res_explain)
        res_masks = {key: np.vstack(chunks) for key, chunks in res_masks.items()}
        return res_explain, res_masks

    def save_model(self, path):
        """Save the model as a ``<path>.zip`` archive.

        The archive contains ``model_params.json`` (constructor parameters) and
        ``network.pt`` (the network ``state_dict``). The temporary directory
        ``path`` is removed after zipping.

        Parameters
        ----------
        path : str
            Path of the model, without the ``.zip`` extension.

        Returns
        -------
        str
            Path of the created archive.
        """
        # Class-valued params (e.g. optimizer_fn) are not JSON-serializable and
        # are skipped; they fall back to their defaults on load.
        saved_params = {
            key: val
            for key, val in self.get_params().items()
            if not isinstance(val, type)
        }

        model_dir = Path(path)
        model_dir.mkdir(parents=True, exist_ok=True)

        with open(model_dir.joinpath("model_params.json"), "w", encoding="utf8") as f:
            json.dump(saved_params, f)

        torch.save(self.network.state_dict(), model_dir.joinpath("network.pt"))
        shutil.make_archive(path, "zip", path)
        shutil.rmtree(path)
        print(f"Successfully saved model at {path}.zip")
        return f"{path}.zip"

    def load_model(self, filepath):
        """Load a model previously written by ``save_model``.

        Parameters
        ----------
        filepath : str
            Path of the ``.zip`` archive.

        Raises
        ------
        KeyError
            If the archive lacks ``model_params.json`` or ``network.pt``.
        """
        try:
            with zipfile.ZipFile(filepath) as z:
                with z.open("model_params.json") as f:
                    loaded_params = json.load(f)
                with z.open("network.pt") as f:
                    # NOTE(security): torch.load unpickles the file; only load
                    # archives from trusted sources.
                    try:
                        saved_state_dict = torch.load(f, map_location=self.device)
                    except io.UnsupportedOperation:
                        # In Python <3.7, the returned file object is not seekable (which at least
                        # some versions of PyTorch require) - so we'll try buffering it in to a
                        # BytesIO instead:
                        saved_state_dict = torch.load(
                            io.BytesIO(f.read()),
                            map_location=self.device,
                        )
        except KeyError as err:
            raise KeyError("Your zip file is missing at least one component") from err

        # Re-initialise with the saved params (this also re-runs __post_init__).
        self.__init__(**loaded_params)

        self._set_network()
        self.network.load_state_dict(saved_state_dict)
        self.network.eval()
        return

    def _train_epoch(self, train_loader):
        """Train the network for one epoch.

        Parameters
        ----------
        train_loader : torch.utils.data.DataLoader
            DataLoader with the train set.
        """
        self.network.train()

        for batch_idx, (X, y) in enumerate(train_loader):
            self._callback_container.on_batch_begin(batch_idx)
            batch_logs = self._train_batch(X, y)
            self._callback_container.on_batch_end(batch_idx, batch_logs)

        # Record the learning rate of the last param group for this epoch.
        epoch_logs = {"lr": self._optimizer.param_groups[-1]["lr"]}
        self.history.epoch_metrics.update(epoch_logs)
        return

    def _train_batch(self, X, y):
        """Run one optimisation step on a batch.

        Parameters
        ----------
        X : torch.Tensor
            Train matrix.
        y : torch.Tensor
            Target matrix.

        Returns
        -------
        batch_logs : dict
            Dictionary with "batch_size" and "loss".
        """
        batch_logs = {"batch_size": X.shape[0]}

        X = X.to(self.device).float()
        y = y.to(self.device).float()

        # Reset gradients (equivalent to zero_grad(set_to_none=True)).
        for param in self.network.parameters():
            param.grad = None

        output, M_loss = self.network(X)
        loss = self.compute_loss(output, y)
        # Add the sparsity regularisation: M_loss comes from the network and is
        # subtracted scaled by lambda_sparse (sign convention defined in
        # tab_network — TODO confirm).
        loss -= self.lambda_sparse * M_loss

        loss.backward()
        if self.clip_value:
            clip_grad_norm_(self.network.parameters(), self.clip_value)
        self._optimizer.step()

        batch_logs["loss"] = loss.item()
        return batch_logs

    def _predict_epoch(self, name, loader):
        """Predict a whole eval set and update the epoch metrics.

        Parameters
        ----------
        name : str
            Name of the validation set.
        loader : torch.utils.data.DataLoader
            DataLoader with the validation set.
        """
        # Setting network on evaluation mode (no dropout etc...)
        self.network.eval()

        list_y_true = []
        list_y_score = []
        for X, y in loader:
            scores = self._predict_batch(X)
            list_y_true.append(y)
            list_y_score.append(scores)

        y_true, scores = self.stack_batches(list_y_true, list_y_score)
        metrics_logs = self._metric_container_dict[name](y_true, scores)
        # Restore training mode for the next epoch.
        self.network.train()
        self.history.epoch_metrics.update(metrics_logs)
        return

    def _predict_batch(self, X):
        """Compute model scores for one batch.

        Parameters
        ----------
        X : torch.Tensor
            Input batch.

        Returns
        -------
        np.array or list of np.array
            Model scores (a list when the network returns one output per task).
        """
        X = X.to(self.device).float()
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            scores, _ = self.network(X)

        if isinstance(scores, list):
            scores = [x.cpu().detach().numpy() for x in scores]
        else:
            scores = scores.cpu().detach().numpy()
        return scores

    def _set_network(self):
        """Build the TabNet network and the explanation reducing matrix."""
        self.network = tab_network.TabNet(
            self.input_dim,
            self.output_dim,
            n_d=self.n_d,
            n_a=self.n_a,
            n_steps=self.n_steps,
            gamma=self.gamma,
            cat_idxs=self.cat_idxs,
            cat_dims=self.cat_dims,
            cat_emb_dim=self.cat_emb_dim,
            n_independent=self.n_independent,
            n_shared=self.n_shared,
            epsilon=self.epsilon,
            virtual_batch_size=self.virtual_batch_size,
            momentum=self.momentum,
            device_name=self.device_name,
            mask_type=self.mask_type,
        ).to(self.device)

        # Matrix used to map post-embedding importances back to input columns.
        self.reducing_matrix = create_explain_matrix(
            self.network.input_dim,
            self.network.cat_emb_dim,
            self.network.cat_idxs,
            self.network.post_embed_dim,
        )

    def _set_metrics(self, metrics, eval_names):
        """Set attributes relative to the metrics.

        Parameters
        ----------
        metrics : list of str
            List of eval metric names (defaults to ``[self._default_metric]``).
        eval_names : list of str
            List of eval set names.
        """
        metrics = metrics or [self._default_metric]
        metrics = check_metrics(metrics)

        # One metric container per eval set, prefixing metric names with the set name.
        self._metric_container_dict = {
            name: MetricContainer(metrics, prefix=f"{name}_") for name in eval_names
        }

        self._metrics = []
        self._metrics_names = []
        for metric_container in self._metric_container_dict.values():
            self._metrics.extend(metric_container.metrics)
            self._metrics_names.extend(metric_container.names)

        # Early stopping metric is the last eval metric
        self.early_stopping_metric = (
            self._metrics_names[-1] if len(self._metrics_names) > 0 else None
        )

    def _set_callbacks(self, custom_callbacks):
        """Set up the default and custom callbacks.

        Parameters
        ----------
        custom_callbacks : list of callback or None
            User callbacks appended after the default ones.
        """
        # Default callbacks: history, early stopping and scheduler.
        callbacks = []
        self.history = History(self, verbose=self.verbose)
        callbacks.append(self.history)
        if (self.early_stopping_metric is not None) and (self.patience > 0):
            early_stopping = EarlyStopping(
                early_stopping_metric=self.early_stopping_metric,
                is_maximize=(
                    self._metrics[-1]._maximize if len(self._metrics) > 0 else None
                ),
                patience=self.patience,
            )
            callbacks.append(early_stopping)
        else:
            print("No early stopping will be performed, last training weights will be used.")
        if self.scheduler_fn is not None:
            # Work on a copy: popping "is_batch_level" from self.scheduler_params
            # would mutate the user's dict and drop the flag, so a second fit()
            # would silently fall back to epoch-level scheduling.
            scheduler_params = dict(self.scheduler_params)
            is_batch_level = scheduler_params.pop("is_batch_level", False)
            scheduler = LRSchedulerCallback(
                scheduler_fn=self.scheduler_fn,
                scheduler_params=scheduler_params,
                optimizer=self._optimizer,
                early_stopping_metric=self.early_stopping_metric,
                is_batch_level=is_batch_level,
            )
            callbacks.append(scheduler)

        if custom_callbacks:
            callbacks.extend(custom_callbacks)
        self._callback_container = CallbackContainer(callbacks)
        self._callback_container.set_trainer(self)

    def _set_optimizer(self):
        """Instantiate the optimizer over the network parameters."""
        self._optimizer = self.optimizer_fn(
            self.network.parameters(), **self.optimizer_params
        )

    def _construct_loaders(self, X_train, y_train, eval_set):
        """Generate dataloaders for train and eval sets.

        Parameters
        ----------
        X_train : np.array
            Train set.
        y_train : np.array
            Train targets.
        eval_set : list of tuple
            List of eval tuple set (X, y). Not modified.

        Returns
        -------
        train_dataloader : torch.utils.data.DataLoader
            Training dataloader.
        valid_dataloaders : list of torch.utils.data.DataLoader
            List of validation dataloaders.
        """
        # Map targets through the subclass hook. Build a new eval list rather
        # than assigning into the caller's list (which would overwrite the
        # user's targets and fail for tuple input).
        y_train_mapped = self.prepare_target(y_train)
        mapped_eval_set = [(X, self.prepare_target(y)) for X, y in eval_set]

        train_dataloader, valid_dataloaders = create_dataloaders(
            X_train,
            y_train_mapped,
            mapped_eval_set,
            self.updated_weights,
            self.batch_size,
            self.num_workers,
            self.drop_last,
            self.pin_memory,
        )
        return train_dataloader, valid_dataloaders

    def _compute_feature_importances(self, loader):
        """Compute normalised global feature importances into ``feature_importances_``.

        Parameters
        ----------
        loader : torch.utils.data.DataLoader
            Loader whose batches are explained (fit() passes the train loader).
        """
        self.network.eval()
        feature_importances_ = np.zeros(self.network.post_embed_dim)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for data, _ in loader:
                data = data.to(self.device).float()
                M_explain, _ = self.network.forward_masks(data)
                feature_importances_ += M_explain.sum(dim=0).cpu().detach().numpy()

        # Project back to input columns, then normalise to sum to 1.
        feature_importances_ = csc_matrix.dot(
            feature_importances_, self.reducing_matrix
        )
        self.feature_importances_ = feature_importances_ / np.sum(feature_importances_)

    @abstractmethod
    def update_fit_params(self, X_train, y_train, eval_set, weights):
        """Set attributes relative to the fit function.

        Called by ``fit`` before the dataloaders are built. Presumably sets the
        attributes this class reads but does not define (e.g. ``output_dim``,
        ``updated_weights``, ``_default_loss``, ``_default_metric``) — TODO
        confirm against subclasses.

        Parameters
        ----------
        X_train : np.ndarray
            Train set.
        y_train : np.array
            Train targets.
        eval_set : list of tuple
            List of eval tuple set (X, y).
        weights : bool or dictionary
            0 for no balancing
            1 for automated balancing
        """
        raise NotImplementedError(
            "users must define update_fit_params to use this base class"
        )

    @abstractmethod
    def compute_loss(self, y_score, y_true):
        """Compute the training loss.

        Parameters
        ----------
        y_score : torch.Tensor
            Network output for the batch.
        y_true : torch.Tensor
            Target matrix.

        Returns
        -------
        torch.Tensor
            Scalar loss; ``backward()`` is called on it during training.
        """
        raise NotImplementedError(
            "users must define compute_loss to use this base class"
        )

    @abstractmethod
    def prepare_target(self, y):
        """Prepare targets before the dataloaders are built.

        Parameters
        ----------
        y : array-like
            Targets as passed to ``fit`` (train targets or an eval set's targets).

        Returns
        -------
        array-like
            Converted targets fed to ``create_dataloaders``.
        """
        raise NotImplementedError(
            "users must define prepare_target to use this base class"
        )