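At the moment there are three transformer-based models available: the `TabTransformer`, the `FT-Transformer` (a `TabTransformer` whose continuous columns are also embedded) and `SAINT`.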

Below are examples of how to use them.

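The main thing to keep in mind is that transformer-based models need slightly different data preparation than other models. Therefore, the model setup must already be known at the preprocessing stage (e.g. `for_transformer=True`, `with_cls_token` and `shared_embed` are all `TabPreprocessor` arguments).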

Let's have a look, starting with the `TabTransformer`.

In [1]:
import numpy as np
import pandas as pd
import torch

from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.training import Trainer
from pytorch_widedeep.models import TabTransformer, SAINT, WideDeep
from pytorch_widedeep.metrics import Accuracy

  return f(*args, **kwds)


In [2]:
# Load the Adult census-income dataset straight from the zipped CSV
# (pandas infers the zip compression from the file extension).
df = pd.read_csv("data/adult/adult.csv.zip")
df.head()

Unnamed: 0,age,workclass,fnlwgt,education,educational-num,marital-status,occupation,relationship,race,gender,capital-gain,capital-loss,hours-per-week,native-country,income
0,25,Private,226802,11th,7,Never-married,Machine-op-inspct,Own-child,Black,Male,0,0,40,United-States,<=50K
1,38,Private,89814,HS-grad,9,Married-civ-spouse,Farming-fishing,Husband,White,Male,0,0,50,United-States,<=50K
2,28,Local-gov,336951,Assoc-acdm,12,Married-civ-spouse,Protective-serv,Husband,White,Male,0,0,40,United-States,>50K
3,44,Private,160323,Some-college,10,Married-civ-spouse,Machine-op-inspct,Husband,Black,Male,7688,0,40,United-States,>50K
4,18,?,103497,Some-college,10,Never-married,?,Own-child,White,Female,0,0,30,United-States,<=50K


In [3]:
# For convenience, replace '-' with '_' in the column names, build the binary
# target (1 when income is ">50K", 0 otherwise) and drop the raw `income`
# column. `educational_num` is also dropped: in the rows shown above it appears
# to be a numeric re-encoding of `education` (e.g. HS-grad -> 9).
# A new frame is built and re-bound to `df` (no `inplace=True`).
# NOTE: this cell reads `income`, so re-run the loading cell before re-running it.
df = (
    df.rename(columns=lambda c: c.replace("-", "_"))
    .assign(target=lambda d: d["income"].str.contains(">50K", regex=False).astype(int))
    .drop(columns=["income", "educational_num"])
)

df.head()

Unnamed: 0,age,workclass,fnlwgt,education,marital_status,occupation,relationship,race,gender,capital_gain,capital_loss,hours_per_week,native_country,target
0,25,Private,226802,11th,Never-married,Machine-op-inspct,Own-child,Black,Male,0,0,40,United-States,0
1,38,Private,89814,HS-grad,Married-civ-spouse,Farming-fishing,Husband,White,Male,0,0,50,United-States,0
2,28,Local-gov,336951,Assoc-acdm,Married-civ-spouse,Protective-serv,Husband,White,Male,0,0,40,United-States,1
3,44,Private,160323,Some-college,Married-civ-spouse,Machine-op-inspct,Husband,Black,Male,7688,0,40,United-States,1
4,18,?,103497,Some-college,Never-married,?,Own-child,White,Female,0,0,30,United-States,0


In [4]:
# Split the feature columns into categorical (to be embedded) and continuous.
# A column is categorical if it holds strings (object dtype) or has fewer than
# MAX_CAT_CARDINALITY distinct values. 50 is an arbitrary choice for this example.
target_col = "target"
MAX_CAT_CARDINALITY = 50

cat_cols, cont_cols = [], []
for col in df.columns:
    # Skip the target up front so it can never land in either feature list,
    # whatever its dtype or cardinality.
    if col == target_col:
        continue
    if df[col].dtype == "O" or df[col].nunique() < MAX_CAT_CARDINALITY:
        cat_cols.append(col)
    else:
        cont_cols.append(col)

### "Standard" `TabTransformer`

In [5]:
# Target labels as a numpy array, as passed to `trainer.fit`.
target = df[target_col].values

# Transformer-based models need their own preprocessing, so
# `for_transformer=True` is set here, when the data is prepared.
tab_preprocessor = TabPreprocessor(
    embed_cols=cat_cols,
    continuous_cols=cont_cols,
    for_transformer=True,
)
X_tab = tab_preprocessor.fit_transform(df)

In [6]:
# Every categorical column is encoded as a 32-dim embedding (see
# `Embedding(103, 32)` in the repr below) and passed through the transformer
# blocks. The result is concatenated with the continuous columns (normalised
# with a LayerNorm) and finally fed to an MLP.
tab_transformer = TabTransformer(
    column_idx=tab_preprocessor.column_idx,
    embed_input=tab_preprocessor.embeddings_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    cont_norm_layer="layernorm",
)

In [7]:
# Inspect the architecture: the categorical embedding table, the LayerNorm over
# the 5 continuous columns and the stack of transformer encoder blocks
# (the printed repr below is truncated).
tab_transformer

TabTransformer(
  (cat_embed): Embedding(103, 32, padding_idx=0)
  (embedding_dropout): Dropout(p=0.1, inplace=False)
  (cont_norm): LayerNorm((5,), eps=1e-05, elementwise_affine=True)
  (transformer_blks): Sequential(
    (block0): TransformerEncoder(
      (self_attn): MultiHeadedAttention(
        (dropout): Dropout(p=0.1, inplace=False)
        (inp_proj): Linear(in_features=32, out_features=96, bias=True)
        (out_proj): Linear(in_features=32, out_features=32, bias=True)
      )
      (ff): PositionwiseFF(
        (w_1): Linear(in_features=32, out_features=128, bias=True)
        (w_2): Linear(in_features=128, out_features=32, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (activation): GELU()
      )
      (attn_addnorm): AddNorm(
        (dropout): Dropout(p=0.1, inplace=False)
        (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      )
      (ff_addnorm): AddNorm(
        (dropout): Dropout(p=0.1, inplace=False)
        (ln): LayerNorm((3

In [8]:
# Wrap the transformer in WideDeep. Only the `deeptabular` component is used
# here (no wide, text or image components).
model = WideDeep(deeptabular=tab_transformer)

In [9]:
# Binary classification. Accuracy is reported for the train and validation
# splits during `fit`.
trainer = Trainer(model, objective='binary', metrics=[Accuracy])

In [10]:
# Train for a single epoch, holding out 20% of the rows (val_split=0.2) for
# validation.
trainer.fit(X_tab=X_tab, target=target, n_epochs=1, batch_size=256, val_split=0.2)

epoch 1: 100%|██████████| 153/153 [00:23<00:00,  6.54it/s, loss=0.375, metrics={'acc': 0.8247}]
valid: 100%|██████████| 39/39 [00:01<00:00, 19.74it/s, loss=0.356, metrics={'acc': 0.8375}]


We can also choose to use the `FT-Transformer` (just set `embed_continuous=True`), in which continuous columns are also represented by embeddings, via a one-layer MLP (with or without an activation function). When using the `FT-Transformer`, we can either use the `[CLS]` token as a pooling method or concatenate the output from the transformer blocks, as we did before. Here, let's use the `[CLS]` token, which has to be requested at preprocessing time with `with_cls_token=True`.

In [11]:
# Same preparation as before, plus a "[CLS]" token (`with_cls_token=True`) that
# the model can use for pooling. Note the extra row in the embedding table
# below (104 instead of 103).
tab_preprocessor = TabPreprocessor(
    embed_cols=cat_cols,
    continuous_cols=cont_cols,
    for_transformer=True,
    with_cls_token=True,
)
X_tab = tab_preprocessor.fit_transform(df)

In [12]:
# FT-Transformer: with `embed_continuous=True` the continuous columns are also
# turned into 32-dim embeddings. This is a plain linear projection because
# `embed_continuous_activation=None`. They go through the transformer blocks
# together with the categorical embeddings, instead of being concatenated
# afterwards. `cont_norm_layer` is not set, so the continuous values are not
# normalised (`cont_norm: Identity()` below). Since the preprocessor was
# fitted with `with_cls_token=True`, the [CLS] token is used for pooling.
ft_transformer = TabTransformer(
    column_idx=tab_preprocessor.column_idx,
    embed_input=tab_preprocessor.embeddings_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    embed_continuous=True,
    embed_continuous_activation=None,
)

In [13]:
# Inspect the FT-Transformer: note the `ContinuousEmbeddings` module and the
# extra [CLS] row in the embedding table (the printed repr below is truncated).
ft_transformer

TabTransformer(
  (cat_embed): Embedding(104, 32, padding_idx=0)
  (embedding_dropout): Dropout(p=0.1, inplace=False)
  (cont_norm): Identity()
  (cont_embed): ContinuousEmbeddings()
  (transformer_blks): Sequential(
    (block0): TransformerEncoder(
      (self_attn): MultiHeadedAttention(
        (dropout): Dropout(p=0.1, inplace=False)
        (inp_proj): Linear(in_features=32, out_features=96, bias=True)
        (out_proj): Linear(in_features=32, out_features=32, bias=True)
      )
      (ff): PositionwiseFF(
        (w_1): Linear(in_features=32, out_features=128, bias=True)
        (w_2): Linear(in_features=128, out_features=32, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (activation): GELU()
      )
      (attn_addnorm): AddNorm(
        (dropout): Dropout(p=0.1, inplace=False)
        (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      )
      (ff_addnorm): AddNorm(
        (dropout): Dropout(p=0.1, inplace=False)
        (ln): LayerNorm((32,

In [14]:
# Wrap the FT-Transformer in WideDeep (deeptabular component only).
model = WideDeep(deeptabular=ft_transformer)

In [15]:
# Same training setup as for the standard TabTransformer: binary objective,
# accuracy metric.
trainer = Trainer(model, objective='binary', metrics=[Accuracy])

In [16]:
# X_tab here comes from the preprocessor fitted with `with_cls_token=True`,
# matching the column_idx used to build `ft_transformer`.
trainer.fit(X_tab=X_tab, target=target, n_epochs=1, batch_size=256, val_split=0.2)

epoch 1: 100%|██████████| 153/153 [00:38<00:00,  4.02it/s, loss=0.412, metrics={'acc': 0.8002}]
valid: 100%|██████████| 39/39 [00:02<00:00, 14.69it/s, loss=0.326, metrics={'acc': 0.8524}]


Alternatively, we can use `SAINT`, which adds inter-sample attention (attention across the rows of a batch) to the usual attention across columns.

In [18]:
# SAINT reuses the preprocessor fitted with `with_cls_token=True`. Continuous
# columns are embedded (linearly, no activation) as in the FT-Transformer, and
# the feed-forward layers use a GEGLU activation.
saint = SAINT(
    column_idx=tab_preprocessor.column_idx,
    embed_input=tab_preprocessor.embeddings_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    transformer_activation="geglu",
    embed_continuous=True,
    embed_continuous_activation=None,
)

In [19]:
# Inspect SAINT's blocks (`SaintEncoder`). With GEGLU the first feed-forward
# projection is 256-wide and the second takes 128 inputs (printed repr is truncated).
saint

SAINT(
  (cat_embed): Embedding(104, 32, padding_idx=0)
  (embedding_dropout): Dropout(p=0.1, inplace=False)
  (cont_norm): LayerNorm((5,), eps=1e-05, elementwise_affine=True)
  (cont_embed): ContinuousEmbeddings()
  (transformer_blks): Sequential(
    (block0): SaintEncoder(
      (self_attn): MultiHeadedAttention(
        (dropout): Dropout(p=0.1, inplace=False)
        (inp_proj): Linear(in_features=32, out_features=96, bias=True)
        (out_proj): Linear(in_features=32, out_features=32, bias=True)
      )
      (self_attn_ff): PositionwiseFF(
        (w_1): Linear(in_features=32, out_features=256, bias=True)
        (w_2): Linear(in_features=128, out_features=32, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (activation): GEGLU()
      )
      (self_attn_addnorm): AddNorm(
        (dropout): Dropout(p=0.1, inplace=False)
        (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      )
      (self_attn_ff_addnorm): AddNorm(
        (dropout): Dropou

In [20]:
# Wrap SAINT in WideDeep and train it for one epoch. Unlike the previous runs,
# the batch size here is 128.
model = WideDeep(deeptabular=saint)
trainer = Trainer(model, objective='binary', metrics=[Accuracy])
trainer.fit(
    X_tab=X_tab,
    target=target,
    n_epochs=1,
    batch_size=128,
    val_split=0.2,
)
# NOTE(review): after one epoch, train and validation accuracy are identical
# (0.7607). This is presumably the share of the majority class (income <= 50K),
# i.e. SAINT has barely learned anything yet; more epochs are likely needed. Confirm.

epoch 1: 100%|██████████| 306/306 [02:34<00:00,  1.98it/s, loss=0.555, metrics={'acc': 0.7607}]
valid: 100%|██████████| 77/77 [00:08<00:00,  8.96it/s, loss=0.551, metrics={'acc': 0.7607}]


One final comment: all three transformer-based models have the option of using the so-called "Shared Embeddings". The idea behind shared embeddings is explained in the original TabTransformer paper and also in this [post](https://jrzaurin.github.io/infinitoml/2021/02/18/pytorch-widedeep_iii.html).

For transformer-based models this implies a slightly different data preparation process. Each column is encoded individually (which is programmatically much easier to implement), so the use of shared embeddings needs to be specified at the preprocessing stage (`shared_embed=True` in `TabPreprocessor`).

In [21]:
# Shared embeddings must be requested at preprocessing time (`shared_embed=True`).
# The [CLS] token is kept so the FT-Transformer below can pool with it.
tab_preprocessor = TabPreprocessor(
    embed_cols=cat_cols,
    continuous_cols=cont_cols,
    for_transformer=True,
    shared_embed=True,
    with_cls_token=True,
)
X_tab = tab_preprocessor.fit_transform(df)

In [22]:
# Same FT-Transformer as before, but with shared embeddings. `shared_embed=True`
# is set here as well as in the preprocessor above, and each categorical column
# now gets its own `SharedEmbeddings` layer (see the ModuleDict in the repr below).
ft_transformer = TabTransformer(
    column_idx=tab_preprocessor.column_idx,
    embed_input=tab_preprocessor.embeddings_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    embed_continuous=True,
    embed_continuous_activation=None,
    shared_embed=True,
)

In [23]:
# Inspect the shared-embeddings model: one `SharedEmbeddings` layer per column,
# including the [CLS] token (the printed repr below is truncated).
ft_transformer

TabTransformer(
  (cat_embed): ModuleDict(
    (emb_layer_cls_token): SharedEmbeddings(
      (embed): Embedding(1, 32, padding_idx=0)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (emb_layer_education): SharedEmbeddings(
      (embed): Embedding(17, 32, padding_idx=0)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (emb_layer_gender): SharedEmbeddings(
      (embed): Embedding(3, 32, padding_idx=0)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (emb_layer_marital_status): SharedEmbeddings(
      (embed): Embedding(8, 32, padding_idx=0)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (emb_layer_native_country): SharedEmbeddings(
      (embed): Embedding(43, 32, padding_idx=0)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (emb_layer_occupation): SharedEmbeddings(
      (embed): Embedding(16, 32, padding_idx=0)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (emb_layer_race): SharedEmbeddings(
      (embed): Embedding(6, 32, padd

In [25]:
# Train the shared-embeddings FT-Transformer with the same settings as the
# earlier FT-Transformer run: one epoch, batch size 256, 20% held out.
model = WideDeep(deeptabular=ft_transformer)
trainer = Trainer(model, objective='binary', metrics=[Accuracy])
trainer.fit(
    X_tab=X_tab,
    target=target,
    n_epochs=1,
    batch_size=256,
    val_split=0.2,
)

epoch 1: 100%|██████████| 153/153 [00:35<00:00,  4.32it/s, loss=0.411, metrics={'acc': 0.8036}]
valid: 100%|██████████| 39/39 [00:02<00:00, 15.03it/s, loss=0.319, metrics={'acc': 0.8583}]
