fix: local explain all batches
Optimox authored and eduardocarvp committed Dec 5, 2019
1 parent 269b4c5 commit 91461fb
Showing 4 changed files with 73 additions and 8 deletions.
3 changes: 2 additions & 1 deletion census_example.ipynb
@@ -141,7 +141,8 @@
"metadata": {},
"outputs": [],
"source": [
"clf = TabNetClassifier()"
"clf = TabNetClassifier(cat_idxs=cat_idxs, cat_dims=cat_dims,\n",
" cat_emb_dim=[2, 2, 3, 2, 2, 3, 2, 2])"
]
},
{
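For readers following along in the notebook: the new call assumes cat_idxs and cat_dims are built earlier from the census data. A minimal sketch of how such values are commonly derived with sklearn's LabelEncoder (the data here is illustrative, not taken from this commit):

    import pandas as pd
    from sklearn.preprocessing import LabelEncoder

    # Illustrative data; the real notebook uses the census dataset.
    train = pd.DataFrame({"age": [25, 32], "workclass": ["Private", "State-gov"]})
    categorical_columns = ["workclass"]

    cat_idxs, cat_dims = [], []
    for idx, col in enumerate(train.columns):
        if col in categorical_columns:
            encoder = LabelEncoder()
            train[col] = encoder.fit_transform(train[col].values)
            cat_idxs.append(idx)                    # column position in the features
            cat_dims.append(len(encoder.classes_))  # number of distinct categories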
25 changes: 19 additions & 6 deletions pytorch_tabnet/tab_model.py
@@ -6,7 +6,8 @@
from pytorch_tabnet.multiclass_utils import unique_labels
from sklearn.metrics import roc_auc_score, mean_squared_error, accuracy_score
from torch.nn.utils import clip_grad_norm_
-from pytorch_tabnet.utils import PredictDataset, plot_losses, create_dataloaders
+from pytorch_tabnet.utils import (PredictDataset, plot_losses,
+                                  create_dataloaders, create_explain_matrix)
from torch.utils.data import DataLoader


@@ -126,6 +127,11 @@ def fit(self, X_train, y_train, X_valid=None, y_valid=None,
momentum=self.momentum,
device_name=self.device_name).to(self.device)

+self.reducing_matrix = create_explain_matrix(self.network.input_dim,
+                                             self.network.cat_emb_dim,
+                                             self.network.cat_idxs,
+                                             self.network.post_embed_dim)
+
self.optimizer = self.optimizer_fn(self.network.parameters(),
**self.opt_params)

@@ -294,17 +300,20 @@ def explain(self, X):

output, M_loss, M_explain, masks = self.network(data)
for key, value in masks.items():
-masks[key] = value.cpu().detach().numpy()
+masks[key] = np.matmul(value.cpu().detach().numpy(),
+                       self.reducing_matrix)

if batch_nb == 0:
-res_explain = M_explain.cpu().detach().numpy()
+res_explain = np.matmul(M_explain.cpu().detach().numpy(),
+                        self.reducing_matrix)
res_masks = masks
else:
res_explain = np.vstack([res_explain,
-                         M_explain.cpu().detach().numpy()])
+                         np.matmul(M_explain.cpu().detach().numpy(),
+                                   self.reducing_matrix)])
for key, value in masks.items():
res_masks[key] = np.vstack([res_masks[key], value])
-return M_explain, res_masks
+return res_explain, res_masks


class TabNetClassifier(TabModel):
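The hunk above fixes two things at once: explain previously returned M_explain from only the last batch, and mask importances lived in the post-embedding space. After this change the method stacks every batch and multiplies by reducing_matrix, so each row of the result is one sample and each column one original input feature. A minimal usage sketch (clf and X_test are illustrative names, not from this diff):

    # Sketch: per-sample explanations over the whole input, not just the last batch.
    res_explain, res_masks = clf.explain(X_test)
    assert res_explain.shape == (X_test.shape[0], X_test.shape[1])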
@@ -436,7 +445,7 @@ def train_epoch(self, train_loader):
y_preds = []
ys = []
total_loss = 0
-feature_importances_ = np.zeros((self.input_dim))
+feature_importances_ = np.zeros((self.network.post_embed_dim))
with tqdm() as pbar:
for data, targets in train_loader:
batch_outs = self.train_batch(data, targets)
@@ -450,6 +459,8 @@
feature_importances_ += batch_outs['batch_importance']
pbar.update(1)

+# Reduce to initial input_dim
+feature_importances_ = np.matmul(feature_importances_, self.reducing_matrix)
# Normalize feature_importances_
feature_importances_ = feature_importances_ / np.sum(feature_importances_)

@@ -725,6 +736,8 @@ def train_epoch(self, train_loader):
feature_importances_ += batch_outs['batch_importance']
pbar.update(1)

+# Reduce to initial input_dim
+feature_importances_ = np.matmul(feature_importances_, self.reducing_matrix)
# Normalize feature_importances_
feature_importances_ = feature_importances_ / np.sum(feature_importances_)

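The same two added lines appear in both the classifier and regressor train_epoch: importances are accumulated at post-embedding width, folded back to input_dim by the matmul, then normalized so the shares sum to 1. A small numeric sketch of that fold-then-normalize step (all values made up):

    import numpy as np

    # Three post-embedding columns map back onto two original features.
    post_embed_importance = np.array([1.0, 2.0, 1.0])
    reducing_matrix = np.array([[1.0, 0.0],   # column 0 -> feature 0
                                [0.0, 1.0],   # columns 1-2 -> feature 1
                                [0.0, 1.0]])
    folded = np.matmul(post_embed_importance, reducing_matrix)  # [1.0, 3.0]
    print(folded / np.sum(folded))                              # [0.25 0.75]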
6 changes: 5 additions & 1 deletion pytorch_tabnet/tab_network.py
@@ -119,7 +119,11 @@ def __init__(self, input_dim, output_dim, n_d=8, n_a=8,
# record continuous indices
self.continuous_idx = torch.ones(self.input_dim, dtype=torch.bool)
self.continuous_idx[self.cat_idxs] = 0
-self.post_embed_dim = self.input_dim + (cat_emb_dim - 1)*len(self.cat_idxs)
+
+if isinstance(cat_emb_dim, int):
+    self.post_embed_dim = self.input_dim + (cat_emb_dim - 1)*len(self.cat_idxs)
+else:
+    self.post_embed_dim = self.input_dim + np.sum(cat_emb_dim) - len(cat_emb_dim)
self.initial_bn = BatchNorm1d(self.post_embed_dim, momentum=0.01)

if self.n_shared > 0:
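A quick check of the two branches (numbers illustrative): an int cat_emb_dim adds (cat_emb_dim - 1) columns per categorical feature, while a list adds each embedding's width minus the one column it replaces:

    # Illustrative arithmetic for post_embed_dim; values are made up.
    input_dim, cat_idxs = 5, [1, 3]
    assert input_dim + (2 - 1) * len(cat_idxs) == 7       # int branch, cat_emb_dim=2
    assert input_dim + sum([2, 3]) - len([2, 3]) == 8     # list branch, cat_emb_dim=[2, 3]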
47 changes: 47 additions & 0 deletions pytorch_tabnet/utils.py
@@ -135,3 +135,50 @@ def plot_losses(losses_train, losses_valid, metrics_train, metrics_valid):
plt.title('Training Metrics')
plt.legend()
plt.show()


+def create_explain_matrix(input_dim, cat_emb_dim, cat_idxs, post_embed_dim):
+    """
+    This is a computational trick used to rapidly sum importances
+    from the same embedding back to the initial index.
+    Parameters
+    ----------
+    input_dim : int
+        Initial input dim
+    cat_emb_dim : int or list of int
+        if int : size of embedding for all categorical features
+        if list of int : size of embedding for each categorical feature
+    cat_idxs : list of int
+        Initial positions of categorical features
+    post_embed_dim : int
+        Post embedding inputs dimension
+    Returns
+    -------
+    reducing_matrix : np.array
+        Matrix of dim (post_embed_dim, input_dim) to perform the reduction
+    """
+
+    if isinstance(cat_emb_dim, int):
+        all_emb_impact = [cat_emb_dim-1]*len(cat_idxs)
+    else:
+        all_emb_impact = [emb_dim-1 for emb_dim in cat_emb_dim]
+
+    acc_emb = 0
+    nb_emb = 0
+    indices_trick = []
+    for i in range(input_dim):
+        if i not in cat_idxs:
+            indices_trick.append([i+acc_emb])
+        else:
+            indices_trick.append(range(i+acc_emb, i+acc_emb+all_emb_impact[nb_emb]+1))
+            acc_emb += all_emb_impact[nb_emb]
+            nb_emb += 1
+
+    reducing_matrix = np.zeros((post_embed_dim, input_dim))
+    for i, cols in enumerate(indices_trick):
+        reducing_matrix[cols, i] = 1
+
+    return reducing_matrix
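To see the trick concretely, here is a small sketch of calling the new helper: with input_dim=3, one categorical feature at index 1 and an embedding of size 2, the post-embedding width is 4 and both embedding columns route back to original column 1:

    from pytorch_tabnet.utils import create_explain_matrix

    M = create_explain_matrix(input_dim=3, cat_emb_dim=2,
                              cat_idxs=[1], post_embed_dim=4)
    print(M)
    # [[1. 0. 0.]
    #  [0. 1. 0.]
    #  [0. 1. 0.]
    #  [0. 0. 1.]]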
