From 5b190916515793114ffa1a9ac4f3869222a14c11 Mon Sep 17 00:00:00 2001 From: Optimox Date: Tue, 22 Mar 2022 14:31:23 +0100 Subject: [PATCH] fix: feature importance not dependent from dataloader --- pytorch_tabnet/abstract_model.py | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/pytorch_tabnet/abstract_model.py b/pytorch_tabnet/abstract_model.py index 9f5cc023..8f48e29f 100644 --- a/pytorch_tabnet/abstract_model.py +++ b/pytorch_tabnet/abstract_model.py @@ -251,7 +251,7 @@ def fit( self.network.eval() # compute feature importance once the best model is defined - self._compute_feature_importances(train_dataloader) + self.feature_importances_ = self._compute_feature_importances(X_train) def predict(self, X): """ @@ -283,7 +283,7 @@ def predict(self, X): res = np.vstack(results) return self.predict_func(res) - def explain(self, X): + def explain(self, X, normalize=False): """ Return local explanation @@ -291,6 +291,8 @@ def explain(self, X): ---------- X : tensor: `torch.Tensor` Input data + normalize : bool (default False) + Whether to normalize so that the sum of features is equal to 1 Returns ------- @@ -318,9 +320,9 @@ def explain(self, X): value.cpu().detach().numpy(), self.reducing_matrix ) - res_explain.append( - csc_matrix.dot(M_explain.cpu().detach().numpy(), self.reducing_matrix) - ) + original_feat_explain = csc_matrix.dot(M_explain.cpu().detach().numpy(), + self.reducing_matrix) + res_explain.append(original_feat_explain) if batch_nb == 0: res_masks = masks @@ -330,6 +332,9 @@ def explain(self, X): res_explain = np.vstack(res_explain) + if normalize: + res_explain /= np.sum(res_explain, axis=1)[:, None] + return res_explain, res_masks def load_weights_from_unsupervised(self, unsupervised_model): @@ -697,7 +702,7 @@ def _construct_loaders(self, X_train, y_train, eval_set): ) return train_dataloader, valid_dataloaders - def _compute_feature_importances(self, loader): + def _compute_feature_importances(self, X): 
"""Compute global feature importance. Parameters @@ -706,17 +711,10 @@ def _compute_feature_importances(self, loader): Pytorch dataloader. """ - self.network.eval() - feature_importances_ = np.zeros((self.network.post_embed_dim)) - for data, targets in loader: - data = data.to(self.device).float() - M_explain, masks = self.network.forward_masks(data) - feature_importances_ += M_explain.sum(dim=0).cpu().detach().numpy() - - feature_importances_ = csc_matrix.dot( - feature_importances_, self.reducing_matrix - ) - self.feature_importances_ = feature_importances_ / np.sum(feature_importances_) + M_explain, _ = self.explain(X, normalize=False) + sum_explain = M_explain.sum(axis=0) + feature_importances_ = sum_explain / np.sum(sum_explain) + return feature_importances_ def _update_network_params(self): self.network.virtual_batch_size = self.virtual_batch_size