Commit

fix: feature importance not dependent on dataloader
Optimox committed Mar 23, 2022
1 parent 95e0e9a commit 5b19091
Showing 1 changed file with 15 additions and 17 deletions.
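In practice, the fix means feature_importances_ is now derived from the raw training matrix X_train at the end of fit, instead of being accumulated over batches drawn from the training dataloader. A minimal usage sketch, assuming the standard TabNetClassifier API; the toy data and hyperparameters are illustrative only, not part of this commit:

import numpy as np
from pytorch_tabnet.tab_model import TabNetClassifier

X_train = np.random.rand(256, 10).astype(np.float32)  # toy data, made up for illustration
y_train = np.random.randint(0, 2, 256)

clf = TabNetClassifier()
clf.fit(X_train, y_train, max_epochs=2, batch_size=64)

# After this commit, feature_importances_ is computed from X_train itself (via explain),
# not from batches yielded by the shuffled training dataloader.
print(clf.feature_importances_.shape)  # (10,)
print(clf.feature_importances_.sum())  # ~1.0, normalized to sum to one
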
32 changes: 15 additions & 17 deletions pytorch_tabnet/abstract_model.py
@@ -251,7 +251,7 @@ def fit(
self.network.eval()

# compute feature importance once the best model is defined
- self._compute_feature_importances(train_dataloader)
+ self.feature_importances_ = self._compute_feature_importances(X_train)

def predict(self, X):
"""
@@ -283,14 +283,16 @@ def predict(self, X):
res = np.vstack(results)
return self.predict_func(res)

- def explain(self, X):
+ def explain(self, X, normalize=False):
"""
Return local explanation
Parameters
----------
X : tensor: `torch.Tensor`
Input data
+ normalize : bool (default False)
+     Whether to normalize so that each sample's feature attributions sum to 1
Returns
-------
@@ -318,9 +320,9 @@ def explain(self, X):
value.cpu().detach().numpy(), self.reducing_matrix
)

- res_explain.append(
-     csc_matrix.dot(M_explain.cpu().detach().numpy(), self.reducing_matrix)
- )
+ original_feat_explain = csc_matrix.dot(M_explain.cpu().detach().numpy(),
+                                        self.reducing_matrix)
+ res_explain.append(original_feat_explain)

if batch_nb == 0:
res_masks = masks
@@ -330,6 +332,9 @@

res_explain = np.vstack(res_explain)

+ if normalize:
+     res_explain /= np.sum(res_explain, axis=1)[:, None]

return res_explain, res_masks

def load_weights_from_unsupervised(self, unsupervised_model):
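As a side note on the normalize flag introduced above, a hedged sketch of what normalize=True does to the returned local explanations; the toy array is made up, and only the behaviour of the two added lines is being illustrated:

import numpy as np

# Toy (n_samples, n_features) attribution matrix, values invented for the example.
res_explain = np.array([[2.0, 1.0, 1.0],
                        [0.0, 3.0, 1.0]])

# normalize=True divides each row by its own sum, so every sample's
# feature attributions add up to 1.
res_explain = res_explain / np.sum(res_explain, axis=1)[:, None]
print(res_explain.sum(axis=1))  # [1. 1.]
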
@@ -697,7 +702,7 @@ def _construct_loaders(self, X_train, y_train, eval_set):
)
return train_dataloader, valid_dataloaders

- def _compute_feature_importances(self, loader):
+ def _compute_feature_importances(self, X):
"""Compute global feature importance.
Parameters
@@ -706,17 +711,10 @@
Pytorch dataloader.
"""
self.network.eval()
- feature_importances_ = np.zeros((self.network.post_embed_dim))
- for data, targets in loader:
-     data = data.to(self.device).float()
-     M_explain, masks = self.network.forward_masks(data)
-     feature_importances_ += M_explain.sum(dim=0).cpu().detach().numpy()

- feature_importances_ = csc_matrix.dot(
-     feature_importances_, self.reducing_matrix
- )
- self.feature_importances_ = feature_importances_ / np.sum(feature_importances_)
+ M_explain, _ = self.explain(X, normalize=False)
+ sum_explain = M_explain.sum(axis=0)
+ feature_importances_ = sum_explain / np.sum(sum_explain)
+ return feature_importances_

def _update_network_params(self):
self.network.virtual_batch_size = self.virtual_batch_size
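To make the new computation concrete: global importances are now the column-wise sums of the local explanations, normalized to sum to 1. A hedged equivalence check, reusing the fitted clf and X_train assumed in the sketch near the top of this page:

import numpy as np

M_explain, _ = clf.explain(X_train, normalize=False)

# Column-wise sum of local attributions, normalized to sum to 1;
# this mirrors what _compute_feature_importances now returns.
manual = M_explain.sum(axis=0) / M_explain.sum()
assert np.allclose(manual, clf.feature_importances_)
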

0 comments on commit 5b19091
