From 91461fbcd4b8c806e920936e0154258b2dc02373 Mon Sep 17 00:00:00 2001
From: Sebastien Fischman
Date: Wed, 4 Dec 2019 11:26:05 +0100
Subject: [PATCH] fix: local explain all batches

---
 census_example.ipynb          |  3 ++-
 pytorch_tabnet/tab_model.py   | 25 +++++++++++++++++++------
 pytorch_tabnet/tab_network.py |  6 +++++-
 pytorch_tabnet/utils.py       | 47 +++++++++++++++++++++++++++++++++++
 4 files changed, 73 insertions(+), 8 deletions(-)

diff --git a/census_example.ipynb b/census_example.ipynb
index 2e65af95..78b49dd5 100644
--- a/census_example.ipynb
+++ b/census_example.ipynb
@@ -141,7 +141,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "clf = TabNetClassifier()"
+    "clf = TabNetClassifier(cat_idxs=cat_idxs, cat_dims=cat_dims,\n",
+    "                       cat_emb_dim=[2, 2, 3, 2, 2, 3, 2, 2])"
    ]
   },
   {
diff --git a/pytorch_tabnet/tab_model.py b/pytorch_tabnet/tab_model.py
index 97bd37c2..b04c09b8 100644
--- a/pytorch_tabnet/tab_model.py
+++ b/pytorch_tabnet/tab_model.py
@@ -6,7 +6,8 @@
 from pytorch_tabnet.multiclass_utils import unique_labels
 from sklearn.metrics import roc_auc_score, mean_squared_error, accuracy_score
 from torch.nn.utils import clip_grad_norm_
-from pytorch_tabnet.utils import PredictDataset, plot_losses, create_dataloaders
+from pytorch_tabnet.utils import (PredictDataset, plot_losses,
+                                  create_dataloaders, create_explain_matrix)
 from torch.utils.data import DataLoader
 
 
@@ -126,6 +127,11 @@ def fit(self, X_train, y_train, X_valid=None, y_valid=None,
                                momentum=self.momentum,
                                device_name=self.device_name).to(self.device)
 
+        self.reducing_matrix = create_explain_matrix(self.network.input_dim,
+                                                     self.network.cat_emb_dim,
+                                                     self.network.cat_idxs,
+                                                     self.network.post_embed_dim)
+
         self.optimizer = self.optimizer_fn(self.network.parameters(),
                                            **self.opt_params)
 
@@ -294,17 +300,20 @@ def explain(self, X):
             output, M_loss, M_explain, masks = self.network(data)
             for key, value in masks.items():
-                masks[key] = value.cpu().detach().numpy()
+                masks[key] = np.matmul(value.cpu().detach().numpy(),
+                                       self.reducing_matrix)
 
             if batch_nb == 0:
-                res_explain = M_explain.cpu().detach().numpy()
+                res_explain = np.matmul(M_explain.cpu().detach().numpy(),
+                                        self.reducing_matrix)
                 res_masks = masks
             else:
                 res_explain = np.vstack([res_explain,
-                                         M_explain.cpu().detach().numpy()])
+                                         np.matmul(M_explain.cpu().detach().numpy(),
+                                                   self.reducing_matrix)])
                 for key, value in masks.items():
                     res_masks[key] = np.vstack([res_masks[key], value])
 
-        return M_explain, res_masks
+        return res_explain, res_masks
 
 
 class TabNetClassifier(TabModel):
@@ -436,7 +445,7 @@ def train_epoch(self, train_loader):
         y_preds = []
         ys = []
         total_loss = 0
-        feature_importances_ = np.zeros((self.input_dim))
+        feature_importances_ = np.zeros((self.network.post_embed_dim))
         with tqdm() as pbar:
             for data, targets in train_loader:
                 batch_outs = self.train_batch(data, targets)
@@ -450,6 +459,8 @@
                 feature_importances_ += batch_outs['batch_importance']
                 pbar.update(1)
 
+        # Reduce to initial input_dim
+        feature_importances_ = np.matmul(feature_importances_, self.reducing_matrix)
         # Normalize feature_importances_
         feature_importances_ = feature_importances_ / np.sum(feature_importances_)
 
@@ -725,6 +736,8 @@
                 feature_importances_ += batch_outs['batch_importance']
                 pbar.update(1)
 
+        # Reduce to initial input_dim
+        feature_importances_ = np.matmul(feature_importances_, self.reducing_matrix)
         # Normalize feature_importances_
         feature_importances_ = feature_importances_ / np.sum(feature_importances_)
 
diff --git a/pytorch_tabnet/tab_network.py b/pytorch_tabnet/tab_network.py
index eb2b0bcd..7135ccd7 100644
--- a/pytorch_tabnet/tab_network.py
+++ b/pytorch_tabnet/tab_network.py
@@ -119,7 +119,11 @@ def __init__(self, input_dim, output_dim, n_d=8, n_a=8,
         # record continuous indices
         self.continuous_idx = torch.ones(self.input_dim, dtype=torch.bool)
         self.continuous_idx[self.cat_idxs] = 0
-        self.post_embed_dim = self.input_dim + (cat_emb_dim - 1)*len(self.cat_idxs)
+
+        if isinstance(cat_emb_dim, int):
+            self.post_embed_dim = self.input_dim + (cat_emb_dim - 1)*len(self.cat_idxs)
+        else:
+            self.post_embed_dim = self.input_dim + np.sum(cat_emb_dim) - len(cat_emb_dim)
         self.initial_bn = BatchNorm1d(self.post_embed_dim, momentum=0.01)
 
         if self.n_shared > 0:
diff --git a/pytorch_tabnet/utils.py b/pytorch_tabnet/utils.py
index 45fd8d06..ab2bbec2 100644
--- a/pytorch_tabnet/utils.py
+++ b/pytorch_tabnet/utils.py
@@ -135,3 +135,50 @@ def plot_losses(losses_train, losses_valid, metrics_train, metrics_valid):
     plt.title('Training Metrics')
     plt.legend()
     plt.show()
+
+
+def create_explain_matrix(input_dim, cat_emb_dim, cat_idxs, post_embed_dim):
+    """
+    This is a computational trick:
+    a matrix that rapidly sums importances from the same
+    embedding back to their initial feature index.
+
+    Parameters
+    ----------
+    input_dim : int
+        Initial input dim
+    cat_emb_dim : int or list of int
+        if int : size of embedding for all categorical features
+        if list of int : size of embedding for each categorical feature
+    cat_idxs : list of int
+        Initial position of categorical features
+    post_embed_dim : int
+        Post embedding inputs dimension
+
+    Returns
+    -------
+    reducing_matrix : np.array
+        Matrix of dim (post_embed_dim, input_dim) to perform the reduction
+    """
+
+    if isinstance(cat_emb_dim, int):
+        all_emb_impact = [cat_emb_dim-1]*len(cat_idxs)
+    else:
+        all_emb_impact = [emb_dim-1 for emb_dim in cat_emb_dim]
+
+    acc_emb = 0
+    nb_emb = 0
+    indices_trick = []
+    for i in range(input_dim):
+        if i not in cat_idxs:
+            indices_trick.append([i+acc_emb])
+        else:
+            indices_trick.append(range(i+acc_emb, i+acc_emb+all_emb_impact[nb_emb]+1))
+            acc_emb += all_emb_impact[nb_emb]
+            nb_emb += 1
+
+    reducing_matrix = np.zeros((post_embed_dim, input_dim))
+    for i, cols in enumerate(indices_trick):
+        reducing_matrix[cols, i] = 1
+
+    return reducing_matrix
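
For intuition, here is a small worked example of the reducing matrix added by this patch. The configuration is illustrative, not taken from the patch: 4 raw features, with the features at indices 1 and 3 categorical and embedded with sizes 2 and 3, so post_embed_dim = 4 + (2 - 1) + (3 - 1) = 7. Right-multiplying a post-embedding importance vector by the matrix sums each embedding's columns back onto its original feature index:

import numpy as np

from pytorch_tabnet.utils import create_explain_matrix

# Illustrative configuration (hypothetical values, not from the patch)
input_dim = 4
cat_idxs = [1, 3]
cat_emb_dim = [2, 3]
post_embed_dim = input_dim + sum(cat_emb_dim) - len(cat_emb_dim)  # 7

reducing_matrix = create_explain_matrix(input_dim, cat_emb_dim,
                                        cat_idxs, post_embed_dim)
print(reducing_matrix.shape)  # (7, 4)

# Post-embedding columns are laid out in feature order:
# [f0, f1_emb0, f1_emb1, f2, f3_emb0, f3_emb1, f3_emb2]
importance = np.array([1., 2., 3., 4., 5., 6., 7.])

# Right-multiplication collapses embedding columns onto the original
# feature index: f1 gets 2 + 3, f3 gets 5 + 6 + 7.
print(np.matmul(importance, reducing_matrix))  # [ 1.  5.  4. 18.]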
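
The other half of the fix is the accumulation in explain(): before the patch, only the last batch's M_explain was returned. A minimal sketch of the corrected pattern, with a hypothetical helper name and plain numpy arrays standing in for network outputs:

def accumulate_explanations(batch_explains, reducing_matrix):
    # Hypothetical helper mirroring the patched explain() loop.
    # batch_explains: iterable of (batch_size, post_embed_dim) arrays,
    # one per DataLoader batch, standing in for M_explain outputs.
    res_explain = None
    for m_explain in batch_explains:
        # Project each batch back to input_dim before stacking...
        reduced = np.matmul(m_explain, reducing_matrix)
        # ...then stack so every batch is kept, not just the last one.
        res_explain = (reduced if res_explain is None
                       else np.vstack([res_explain, reduced]))
    return res_explain

# Continuing the toy setup above (reducing_matrix of shape (7, 4)):
batches = [np.ones((2, 7)), np.ones((2, 7))]
print(accumulate_explanations(batches, reducing_matrix).shape)  # (4, 4)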