### Feature Importance via the attention weights

I will start by saying that I consider this feature of the library purely experimental. First of all, I think there are multiple ways one could compute feature importances for these models. More importantly, one has to bear in mind that even tree-based algorithms trained on the same dataset produce different feature importances. The differences are even more "dramatic" if one uses different techniques, such as SHAP or permutation importance (see for example [this](https://reneelin2019.medium.com/calculating-feature-importance-with-permutation-to-explain-the-model-income-prediction-example-38a52e67441d) and references therein). All this to say that, sometimes, a feature importance is only meaningful within the experiment that was run and for the model that was used.

With that in mind, each instantiation of a deep tabular model, which can have millions of trainable parameters, will potentially produce a different set of feature importances, even if the architecture is the same. Moreover, this effect becomes more apparent if the dataset is relatively easy and there are dependent/related columns, so that one can reach the same success metric with different parameters.
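To make the point about dependent columns concrete, here is a minimal sketch that is not part of the original notebook. It assumes a synthetic dataset in which `x1` is a near-duplicate of `x0`, and it uses scikit-learn's `RandomForestClassifier` with three different seeds. All variable names are made up for the illustration.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(0)
n = 500
x0 = rng.normal(size=n)
x1 = x0 + 0.05 * rng.normal(size=n)  # near-duplicate of x0 (dependent column)
x2 = rng.normal(size=n)              # independent noise column
X = np.column_stack([x0, x1, x2])
y = (x0 > 0).astype(int)             # the target depends on x0 only

# Same data, same architecture, different random seeds
importances = []
for seed in range(3):
    clf = RandomForestClassifier(n_estimators=20, random_state=seed).fit(X, y)
    importances.append(clf.feature_importances_)
importances = np.vstack(importances)

# One row per seed, one column per feature
print(importances.round(3))
```

Since `x0` and `x1` carry almost the same information, one would expect the way the importance is shared between them to change from seed to seed, while the predictive performance stays essentially the same.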

In summary, feature importances are implemented in this library for all attention-based models for tabular data, with the exception of the `TabPerceiver`. However, this functionality has to be used and interpreted with care, and should be considered of value only within the 'universe' (or context) of the model with which these feature importances were produced.

Nonetheless, let's have a look at how one would access the feature importances when using this library.

In [1]:
import torch

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score


from pytorch_widedeep import Trainer
from pytorch_widedeep.models import TabTransformer, ContextAttentionMLP, WideDeep
from pytorch_widedeep.callbacks import EarlyStopping
from pytorch_widedeep.metrics import Accuracy
from pytorch_widedeep.datasets import load_adult
from pytorch_widedeep.preprocessing import TabPreprocessor

In [2]:
# use_cuda = torch.cuda.is_available()
df = load_adult(as_frame=True)
df.columns = [c.replace("-", "_") for c in df.columns]
df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
df.drop(["income", "fnlwgt", "educational_num"], axis=1, inplace=True)
target_colname = "income_label"

In [3]:
df.head()

| | age | workclass | education | marital_status | occupation | relationship | race | gender | capital_gain | capital_loss | hours_per_week | native_country | income_label |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 25 | Private | 11th | Never-married | Machine-op-inspct | Own-child | Black | Male | 0 | 0 | 40 | United-States | 0 |
| 1 | 38 | Private | HS-grad | Married-civ-spouse | Farming-fishing | Husband | White | Male | 0 | 0 | 50 | United-States | 0 |
| 2 | 28 | Local-gov | Assoc-acdm | Married-civ-spouse | Protective-serv | Husband | White | Male | 0 | 0 | 40 | United-States | 1 |
| 3 | 44 | Private | Some-college | Married-civ-spouse | Machine-op-inspct | Husband | Black | Male | 7688 | 0 | 40 | United-States | 1 |
| 4 | 18 | ? | Some-college | Never-married | ? | Own-child | White | Female | 0 | 0 | 30 | United-States | 0 |


In [4]:
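cat_embed_cols = []
for col in df.columns:
    # treat object columns and low-cardinality numeric columns as categorical;
    # the parentheses make the exclusion of the target column explicit
    if (df[col].dtype == "O" or df[col].nunique() < 200) and col != target_colname:
        cat_embed_cols.append(col)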

In [5]:
# all cols will be categorical
assert len(cat_embed_cols) == df.shape[1] - 1

In [6]:
train, test = train_test_split(
    df, test_size=0.1, random_state=1, stratify=df[[target_colname]]
)

In [7]:
tab_preprocessor = TabPreprocessor(cat_embed_cols=cat_embed_cols, with_attention=True)

In [8]:
X_tab_train = tab_preprocessor.fit_transform(train)
X_tab_test = tab_preprocessor.transform(test)
target = train[target_colname].values

In [9]:
tab_transformer = TabTransformer(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    cat_embed_dropout=0.0,
    input_dim=8,
    n_heads=2,
    n_blocks=1,
    attn_dropout=0.1,
    transformer_activation="relu",
)

In [10]:
model = WideDeep(deeptabular=tab_transformer)

In [11]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0)

In [12]:
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    threshold=0.001,
    threshold_mode="abs",
    patience=10,
)

In [13]:
early_stopping = EarlyStopping(
    min_delta=0.001, patience=30, restore_best_weights=True, verbose=True
)

In [14]:
trainer = Trainer(
    model,
    objective="binary",
    optimizers=optimizer,
    lr_schedulers=lr_scheduler,
    reducelronplateau_criterion="loss",
    callbacks=[early_stopping],
    metrics=[Accuracy],
)

The feature importances will be computed after training, using a sample of the training dataset of size `feature_importance_sample_size`.
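As a rough intuition for how attention weights can be turned into feature importances, here is a minimal NumPy sketch. It is an assumption made for illustration and not necessarily what the library does internally: the attention weights are random, their shape `(samples, heads, queries, keys)` is hypothetical, and they are simply averaged over heads and query positions.

```python
import numpy as np

rng = np.random.default_rng(42)
n_samples, n_heads, n_feats = 1000, 2, 12  # sample size, heads and columns as in this example

# Hypothetical raw attention scores: (samples, heads, queries, keys)
scores = rng.normal(size=(n_samples, n_heads, n_feats, n_feats))

# Softmax over the key axis, so each attention row sums to 1
scores = scores - scores.max(axis=-1, keepdims=True)
attn = np.exp(scores)
attn /= attn.sum(axis=-1, keepdims=True)

# Average over heads and query positions: how much attention each
# feature (key) receives for every sample -> shape (samples, features)
per_sample_importance = attn.mean(axis=(1, 2))

# Average over the sample to get one score per feature
global_importance = per_sample_importance.mean(axis=0)
print(global_importance.round(3), global_importance.sum())
```

With this kind of aggregation the scores add up to 1 by construction, because every attention row does.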

In [15]:
trainer.fit(
    X_tab=X_tab_train,
    target=target,
    val_split=0.2,
    n_epochs=100,
    batch_size=128,
    validation_freq=1,
    feature_importance_sample_size=1000,
)

epoch 1: 100%|████| 275/275 [00:03<00:00, 90.98it/s, loss=0.332, metrics={'acc': 0.847}]
valid: 100%|███████| 69/69 [00:00<00:00, 148.31it/s, loss=0.292, metrics={'acc': 0.866}]
epoch 2: 100%|████| 275/275 [00:02<00:00, 96.35it/s, loss=0.289, metrics={'acc': 0.868}]
valid: 100%|██████| 69/69 [00:00<00:00, 147.03it/s, loss=0.278, metrics={'acc': 0.8717}]
epoch 3: 100%|████| 275/275 [00:02<00:00, 94.23it/s, loss=0.28, metrics={'acc': 0.8719}]
valid: 100%|██████| 69/69 [00:00<00:00, 139.13it/s, loss=0.275, metrics={'acc': 0.8732}]
epoch 4: 100%|████| 275/275 [00:03<00:00, 90.15it/s, loss=0.276, metrics={'acc': 0.872}]
valid: 100%|██████| 69/69 [00:00<00:00, 133.21it/s, loss=0.275, metrics={'acc': 0.8706}]
epoch 5: 100%|███| 275/275 [00:03<00:00, 86.75it/s, loss=0.274, metrics={'acc': 0.8736}]
valid: 100%|██████| 69/69 [00:00<00:00, 132.85it/s, loss=0.275, metrics={'acc': 0.8717}]
epoch 6: 100%|███| 275/275 [00:03<00:00, 89.16it/s, loss=0.272, metrics={'acc': 0.8742}]
valid: 100%|██████| 6

Best Epoch: 9. Best val_loss: 0.27098
Restoring model weights from the end of the best epoch


In [17]:
trainer.feature_importance

{'age': 0.098023,
 'workclass': 0.07621125,
 'education': 0.07414728,
 'marital_status': 0.113280274,
 'occupation': 0.07292068,
 'relationship': 0.08008792,
 'race': 0.104180396,
 'gender': 0.07037963,
 'capital_gain': 0.06584223,
 'capital_loss': 0.07647487,
 'hours_per_week': 0.09369389,
 'native_country': 0.0747586}
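The dictionary is easier to read once sorted. The short sketch below copies the values printed above into a plain dict (in the notebook one would use `trainer.feature_importance` directly), ranks the features and checks that the scores add up to approximately 1.

```python
# Values copied from the output above
feature_importance = {
    "age": 0.098023,
    "workclass": 0.07621125,
    "education": 0.07414728,
    "marital_status": 0.113280274,
    "occupation": 0.07292068,
    "relationship": 0.08008792,
    "race": 0.104180396,
    "gender": 0.07037963,
    "capital_gain": 0.06584223,
    "capital_loss": 0.07647487,
    "hours_per_week": 0.09369389,
    "native_country": 0.0747586,
}

# Sort from most to least important
ranked = sorted(feature_importance.items(), key=lambda kv: kv[1], reverse=True)
for name, score in ranked:
    print(f"{name:<15} {score:.3f}")

total = sum(feature_importance.values())
print(f"total: {total:.4f}")
```

Note how narrow the range is: the scores go from roughly 0.066 to 0.113, so the ranking for this model should not be over-interpreted.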

In [18]:
preds = trainer.predict(X_tab=X_tab_test)

predict: 100%|█████████████████████████████████████████| 39/39 [00:00<00:00, 213.15it/s]


In [19]:
accuracy_score(test.income_label, preds)  # scikit-learn expects (y_true, y_pred)

0.8734902763561925

In [20]:
test.reset_index(drop=True, inplace=True)

In [21]:
test[test.income_label == 0].head(1)

| | age | workclass | education | marital_status | occupation | relationship | race | gender | capital_gain | capital_loss | hours_per_week | native_country | income_label |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 26 | Private | Some-college | Never-married | Exec-managerial | Not-in-family | White | Male | 0 | 0 | 60 | United-States | 0 |


In [22]:
test[test.income_label == 1].head(1)

| | age | workclass | education | marital_status | occupation | relationship | race | gender | capital_gain | capital_loss | hours_per_week | native_country | income_label |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 3 | 36 | Local-gov | Doctorate | Married-civ-spouse | Prof-specialty | Husband | White | Male | 0 | 1887 | 50 | United-States | 1 |


To get the feature importances for each sample of a test dataset, simply use the `explain` method.

In [23]:
feat_imp_per_sample = trainer.explain(X_tab_test, save_step_masks=False)

In [24]:
list(test.iloc[0].index[np.argsort(-feat_imp_per_sample[0])])

['marital_status',
 'race',
 'age',
 'capital_loss',
 'occupation',
 'native_country',
 'workclass',
 'education',
 'gender',
 'relationship',
 'hours_per_week',
 'capital_gain']

In [25]:
list(test.iloc[3].index[np.argsort(-feat_imp_per_sample[3])])

['marital_status',
 'race',
 'capital_loss',
 'occupation',
 'education',
 'native_country',
 'hours_per_week',
 'relationship',
 'age',
 'workclass',
 'gender',
 'capital_gain']
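If one only cares about the few most important features per sample, the `np.argsort` pattern used above can be wrapped in a small helper. This is a sketch with a hypothetical function name (`top_k_features`) and a toy array. In the notebook one would pass `feat_imp_per_sample` and the 12 feature names instead.

```python
import numpy as np


def top_k_features(importances, feature_names, k=3):
    """Return, for each row, the names of the k highest-scoring features."""
    importances = np.asarray(importances)
    names = np.asarray(feature_names)
    # argsort on the negated scores gives descending order per row
    order = np.argsort(-importances, axis=1)[:, :k]
    return [names[row].tolist() for row in order]


# Toy data: 2 samples, 4 features
toy_names = ["age", "workclass", "education", "marital_status"]
toy_importances = np.array(
    [
        [0.1, 0.2, 0.3, 0.4],
        [0.4, 0.3, 0.2, 0.1],
    ]
)

top = top_k_features(toy_importances, toy_names, k=2)
print(top)
```

For the two test samples shown above, the top two features (`marital_status` and `race`) coincide, and the rankings start to differ from the third position onwards.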

We could do the same with the `ContextAttentionMLP`.

In [26]:
context_attn_mlp = ContextAttentionMLP(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    cat_embed_dropout=0.0,
    input_dim=16,
    attn_dropout=0.1,
    attn_activation="relu",
)

In [27]:
mlp_model = WideDeep(deeptabular=context_attn_mlp)

In [28]:
mlp_optimizer = torch.optim.Adam(mlp_model.parameters(), lr=0.01, weight_decay=0.0)

In [29]:
mlp_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    mlp_optimizer,
    threshold=0.001,
    threshold_mode="abs",
    patience=10,
)

In [30]:
mlp_early_stopping = EarlyStopping(
    min_delta=0.001, patience=30, restore_best_weights=True, verbose=True
)

In [31]:
mlp_trainer = Trainer(
    mlp_model,
    objective="binary",
    optimizers=mlp_optimizer,
    lr_schedulers=mlp_lr_scheduler,
    reducelronplateau_criterion="loss",
    callbacks=[mlp_early_stopping],
    metrics=[Accuracy],
)

In [32]:
mlp_trainer.fit(
    X_tab=X_tab_train,
    target=target,
    val_split=0.2,
    n_epochs=100,
    batch_size=128,
    validation_freq=1,
    feature_importance_sample_size=1000,
)

epoch 1: 100%|███| 275/275 [00:03<00:00, 91.33it/s, loss=0.395, metrics={'acc': 0.8139}]
valid: 100%|██████| 69/69 [00:00<00:00, 125.66it/s, loss=0.306, metrics={'acc': 0.8577}]
epoch 2: 100%|███| 275/275 [00:03<00:00, 87.31it/s, loss=0.333, metrics={'acc': 0.8396}]
valid: 100%|██████| 69/69 [00:00<00:00, 140.84it/s, loss=0.291, metrics={'acc': 0.8631}]
epoch 3: 100%|███| 275/275 [00:03<00:00, 84.57it/s, loss=0.323, metrics={'acc': 0.8494}]
valid: 100%|██████| 69/69 [00:00<00:00, 141.55it/s, loss=0.293, metrics={'acc': 0.8632}]
epoch 4: 100%|███| 275/275 [00:03<00:00, 84.62it/s, loss=0.312, metrics={'acc': 0.8518}]
valid: 100%|████████| 69/69 [00:00<00:00, 156.92it/s, loss=0.3, metrics={'acc': 0.8543}]
epoch 5: 100%|████| 275/275 [00:03<00:00, 85.74it/s, loss=0.31, metrics={'acc': 0.8552}]
valid: 100%|██████| 69/69 [00:00<00:00, 130.66it/s, loss=0.303, metrics={'acc': 0.8545}]
epoch 6: 100%|███| 275/275 [00:02<00:00, 93.70it/s, loss=0.303, metrics={'acc': 0.8579}]
valid: 100%|██████| 6

epoch 47: 100%|██| 275/275 [00:03<00:00, 70.92it/s, loss=0.279, metrics={'acc': 0.8693}]
valid: 100%|██████| 69/69 [00:00<00:00, 121.47it/s, loss=0.276, metrics={'acc': 0.8699}]
epoch 48: 100%|██| 275/275 [00:03<00:00, 75.09it/s, loss=0.277, metrics={'acc': 0.8701}]
valid: 100%|██████| 69/69 [00:00<00:00, 126.88it/s, loss=0.276, metrics={'acc': 0.8706}]
epoch 49: 100%|██| 275/275 [00:03<00:00, 74.75it/s, loss=0.279, metrics={'acc': 0.8694}]
valid: 100%|██████| 69/69 [00:00<00:00, 130.57it/s, loss=0.276, metrics={'acc': 0.8705}]
epoch 50: 100%|██| 275/275 [00:03<00:00, 81.06it/s, loss=0.278, metrics={'acc': 0.8704}]
valid: 100%|████████| 69/69 [00:00<00:00, 137.71it/s, loss=0.276, metrics={'acc': 0.87}]
epoch 51: 100%|██| 275/275 [00:03<00:00, 79.21it/s, loss=0.278, metrics={'acc': 0.8689}]
valid: 100%|██████| 69/69 [00:00<00:00, 134.51it/s, loss=0.276, metrics={'acc': 0.8702}]
epoch 52: 100%|███| 275/275 [00:03<00:00, 81.95it/s, loss=0.279, metrics={'acc': 0.869}]
valid: 100%|██████| 6

Best Epoch: 30. Best val_loss: 0.27563
Restoring model weights from the end of the best epoch


In [33]:
mlp_trainer.feature_importance

{'age': 0.103683405,
 'workclass': 0.066264994,
 'education': 0.10014994,
 'marital_status': 0.1235957,
 'occupation': 0.12825337,
 'relationship': 0.15234835,
 'race': 0.061964743,
 'gender': 0.05328226,
 'capital_gain': 0.03052448,
 'capital_loss': 0.037544865,
 'hours_per_week': 0.07689079,
 'native_country': 0.0654971}
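To quantify how much the two models agree, one option is a rank correlation between the two sets of importances. The sketch below copies the two dictionaries printed in this notebook and computes a Spearman correlation with NumPy only. The helper `ranks` is made up for this example and assumes there are no ties, which holds for these values.

```python
import numpy as np

# Values copied from the two `feature_importance` outputs in this notebook
tab_transformer_imp = {
    "age": 0.098023, "workclass": 0.07621125, "education": 0.07414728,
    "marital_status": 0.113280274, "occupation": 0.07292068,
    "relationship": 0.08008792, "race": 0.104180396, "gender": 0.07037963,
    "capital_gain": 0.06584223, "capital_loss": 0.07647487,
    "hours_per_week": 0.09369389, "native_country": 0.0747586,
}
context_attn_mlp_imp = {
    "age": 0.103683405, "workclass": 0.066264994, "education": 0.10014994,
    "marital_status": 0.1235957, "occupation": 0.12825337,
    "relationship": 0.15234835, "race": 0.061964743, "gender": 0.05328226,
    "capital_gain": 0.03052448, "capital_loss": 0.037544865,
    "hours_per_week": 0.07689079, "native_country": 0.0654971,
}

features = list(tab_transformer_imp)
a = np.array([tab_transformer_imp[f] for f in features])
b = np.array([context_attn_mlp_imp[f] for f in features])


def ranks(values):
    # rank 0 = least important; valid because all values are distinct
    return np.argsort(np.argsort(values))


# Spearman correlation = Pearson correlation of the ranks
spearman = np.corrcoef(ranks(a), ranks(b))[0, 1]

print(f"Spearman rank correlation: {spearman:.2f}")
print("Top feature (TabTransformer):     ", features[int(a.argmax())])
print("Top feature (ContextAttentionMLP):", features[int(b.argmax())])
```

Both models reach a very similar accuracy, yet they do not rank the features identically, which is the caveat raised at the beginning of this section.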

In [34]:
mlp_preds = mlp_trainer.predict(X_tab=X_tab_test)

predict: 100%|█████████████████████████████████████████| 39/39 [00:00<00:00, 211.89it/s]


In [35]:
accuracy_score(test.income_label, mlp_preds)  # scikit-learn expects (y_true, y_pred)

0.873899692937564