# The Transformer Family
At the moment there are 5 transformer-based algorithms available. 

Here are examples of how to use them.

Perhaps the main point is that, when using transformer-based models, the data preparation is a bit different from that of other models. Therefore, one needs to decide on the set-up (e.g. `for_transformer=True`) at the pre-processing stage.

Let's have a look, starting with the `TabTransformer`:

In [1]:
import numpy as np
import pandas as pd
import torch

from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.training import Trainer
from pytorch_widedeep.models import (
    TabTransformer,
    SAINT,
    FTTransformer,
    TabFastFormer,
    TabPerceiver,
    WideDeep,
)
from pytorch_widedeep.metrics import Accuracy

  return f(*args, **kwds)


In [2]:
# Load the Adult census dataset straight from the zipped CSV (pandas infers the
# zip compression from the file extension) and show the first five rows.
df = pd.read_csv("data/adult/adult.csv.zip", compression="infer")
df.head(n=5)

Unnamed: 0,age,workclass,fnlwgt,education,educational-num,marital-status,occupation,relationship,race,gender,capital-gain,capital-loss,hours-per-week,native-country,income
0,25,Private,226802,11th,7,Never-married,Machine-op-inspct,Own-child,Black,Male,0,0,40,United-States,<=50K
1,38,Private,89814,HS-grad,9,Married-civ-spouse,Farming-fishing,Husband,White,Male,0,0,50,United-States,<=50K
2,28,Local-gov,336951,Assoc-acdm,12,Married-civ-spouse,Protective-serv,Husband,White,Male,0,0,40,United-States,>50K
3,44,Private,160323,Some-college,10,Married-civ-spouse,Machine-op-inspct,Husband,Black,Male,7688,0,40,United-States,>50K
4,18,?,103497,Some-college,10,Never-married,?,Own-child,White,Female,0,0,30,United-States,<=50K


In [3]:
# Normalise the column names: '-' becomes '_' so columns are easier to reference.
df.columns = [name.replace("-", "_") for name in df.columns]
# Binary target: 1 when the income label contains ">50K", otherwise 0.
df["target"] = df["income"].map(lambda income: ">50K" in income).astype(int)
# The raw income label (now encoded in "target") and educational_num are dropped.
df.drop(columns=["income", "educational_num"], inplace=True)

df.head()

Unnamed: 0,age,workclass,fnlwgt,education,marital_status,occupation,relationship,race,gender,capital_gain,capital_loss,hours_per_week,native_country,target
0,25,Private,226802,11th,Never-married,Machine-op-inspct,Own-child,Black,Male,0,0,40,United-States,0
1,38,Private,89814,HS-grad,Married-civ-spouse,Farming-fishing,Husband,White,Male,0,0,50,United-States,0
2,28,Local-gov,336951,Assoc-acdm,Married-civ-spouse,Protective-serv,Husband,White,Male,0,0,40,United-States,1
3,44,Private,160323,Some-college,Married-civ-spouse,Machine-op-inspct,Husband,Black,Male,7688,0,40,United-States,1
4,18,?,103497,Some-college,Never-married,?,Own-child,White,Female,0,0,30,United-States,0


In [4]:
# Split the feature columns into categorical and continuous ones.
# - A column is categorical when it has object dtype or fewer than 50 unique
#   values (50 is an arbitrary threshold chosen for this example).
# - Every other column is continuous.
# - The target column is excluded from both lists.
target_col = "target"
cat_cols, cont_cols = [], []
for col in df.columns:
    # Skip the target up front so it can never leak into either feature list,
    # regardless of its dtype or cardinality.
    if col == target_col:
        continue
    if df[col].dtype == "O" or df[col].nunique() < 50:
        cat_cols.append(col)
    else:
        cont_cols.append(col)

"Standard" `TabTransformer`

In [5]:
# Labels as a NumPy array.
target = df[target_col].to_numpy()

# for_transformer=True selects the pre-processing set-up used by the
# transformer-based models.
tab_preprocessor = TabPreprocessor(
    continuous_cols=cont_cols,
    embed_cols=cat_cols,
    for_transformer=True,
)
X_tab = tab_preprocessor.fit_transform(df)

In [6]:
# Every categorical column is encoded as a 32-dim embedding and run through the
# transformer blocks. The result is concatenated with the (batch-normalised)
# continuous columns and finally passed through an MLP.
tab_transformer = TabTransformer(
    column_idx=tab_preprocessor.column_idx,
    embed_input=tab_preprocessor.embeddings_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    n_blocks=4,
    n_heads=4,
    cont_norm_layer="batchnorm",
)

In [7]:
# Display the model's module tree (a cell's last expression is rendered).
tab_transformer

TabTransformer(
  (cat_and_cont_embed): CatAndContEmbeddings(
    (cat_embed): CategoricalEmbeddings(
      (embed): Embedding(103, 32, padding_idx=0)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (cont_norm): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (transformer_blks): Sequential(
    (transformer_block0): TransformerEncoder(
      (attn): MultiHeadedAttention(
        (dropout): Dropout(p=0.2, inplace=False)
        (q_proj): Linear(in_features=32, out_features=32, bias=False)
        (kv_proj): Linear(in_features=32, out_features=64, bias=False)
        (out_proj): Linear(in_features=32, out_features=32, bias=False)
      )
      (ff): PositionwiseFF(
        (w_1): Linear(in_features=32, out_features=128, bias=True)
        (w_2): Linear(in_features=128, out_features=32, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (activation): GELU()
      )
      (attn_addnorm): AddNorm(
        (dropout): Dropout(p=0.

In [8]:
# Wrap the model in WideDeep; only the deeptabular component is used here.
model = WideDeep(deeptabular=tab_transformer)

In [9]:
# Binary-classification trainer that also reports accuracy.
trainer = Trainer(model, objective="binary", metrics=[Accuracy])

In [10]:
# Train for a single epoch with batches of 256 rows; val_split=0.2 holds out
# 20% of the rows for the validation loop.
trainer.fit(X_tab=X_tab, target=target, n_epochs=1, batch_size=256, val_split=0.2)

epoch 1: 100%|██████████| 153/153 [00:14<00:00, 10.56it/s, loss=0.356, metrics={'acc': 0.8321}]
valid: 100%|██████████| 39/39 [00:01<00:00, 36.52it/s, loss=0.336, metrics={'acc': 0.8465}]


We can also choose to use the `FT-Transformer`, where continuous columns are also represented by "embeddings", via a one-layer MLP (with or without an activation function). When using the `FT-Transformer`, we can choose to use the `[CLS]` token as a pooling method or to concatenate the output from the transformer blocks, as we did before. Here, let's use the `[CLS]` token. Also note that, under the hood, the `FT-Transformer` uses linear attention; see [Linformer: Self-Attention with Linear Complexity](https://arxiv.org/pdf/2006.04768.pdf).

In [11]:
# Same transformer-style pre-processing as before, plus a "[CLS]" token that
# the FT-Transformer below uses for pooling.
tab_preprocessor = TabPreprocessor(
    with_cls_token=True,
    for_transformer=True,
    embed_cols=cat_cols,
    continuous_cols=cont_cols,
)
X_tab = tab_preprocessor.fit_transform(df)

In [12]:
# FT-Transformer: 3 encoder blocks with 6 attention heads over 36-dim
# embeddings (input_dim sets the embedding size).
ft_transformer = FTTransformer(
    column_idx=tab_preprocessor.column_idx,
    embed_input=tab_preprocessor.embeddings_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    input_dim=36,
    n_heads=6,
    n_blocks=3,
)

In [13]:
# Display the FT-Transformer's module tree.
ft_transformer

FTTransformer(
  (cat_and_cont_embed): CatAndContEmbeddings(
    (cat_embed): CategoricalEmbeddings(
      (embed): Embedding(104, 36, padding_idx=0)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (cont_norm): Identity()
    (cont_embed): ContinuousEmbeddings()
  )
  (transformer_blks): Sequential(
    (fttransformer_block0): FTTransformerEncoder(
      (attn): LinearAttention(
        (dropout): Dropout(p=0.2, inplace=False)
        (qkv_proj): Linear(in_features=36, out_features=108, bias=False)
        (out_proj): Linear(in_features=36, out_features=36, bias=False)
      )
      (ff): PositionwiseFF(
        (w_1): Linear(in_features=36, out_features=94, bias=True)
        (w_2): Linear(in_features=47, out_features=36, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (activation): REGLU()
      )
      (attn_normadd): NormAdd(
        (dropout): Dropout(p=0.2, inplace=False)
        (ln): LayerNorm((36,), eps=1e-05, elementwise_affine=True)
      )
     

In [14]:
# Wrap the FT-Transformer as the only (deeptabular) component of WideDeep.
model = WideDeep(deeptabular=ft_transformer)

In [15]:
# Fresh binary-classification trainer for the new model, reporting accuracy.
trainer = Trainer(model, objective="binary", metrics=[Accuracy])

In [16]:
# One epoch, batches of 256 rows, 20% of the rows held out for validation.
trainer.fit(X_tab=X_tab, target=target, n_epochs=1, batch_size=256, val_split=0.2)

epoch 1: 100%|██████████| 153/153 [00:15<00:00,  9.62it/s, loss=0.382, metrics={'acc': 0.8167}]
valid: 100%|██████████| 39/39 [00:01<00:00, 28.84it/s, loss=0.317, metrics={'acc': 0.8566}]


Or we can choose to use SAINT, with its inter-sample attention:

In [17]:
# SAINT with GEGLU activations in its feed-forward layers: 2 blocks, 4 heads.
saint = SAINT(
    column_idx=tab_preprocessor.column_idx,
    embed_input=tab_preprocessor.embeddings_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    n_heads=4,
    n_blocks=2,
    transformer_activation="geglu",
)

In [18]:
# Display SAINT's module tree.
saint

SAINT(
  (cat_and_cont_embed): CatAndContEmbeddings(
    (cat_embed): CategoricalEmbeddings(
      (embed): Embedding(104, 32, padding_idx=0)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (cont_norm): Identity()
    (cont_embed): ContinuousEmbeddings()
  )
  (transformer_blks): Sequential(
    (saint_block0): SaintEncoder(
      (col_attn): MultiHeadedAttention(
        (dropout): Dropout(p=0.1, inplace=False)
        (q_proj): Linear(in_features=32, out_features=32, bias=False)
        (kv_proj): Linear(in_features=32, out_features=64, bias=False)
        (out_proj): Linear(in_features=32, out_features=32, bias=False)
      )
      (col_attn_ff): PositionwiseFF(
        (w_1): Linear(in_features=32, out_features=256, bias=True)
        (w_2): Linear(in_features=128, out_features=32, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
        (activation): GEGLU()
      )
      (col_attn_addnorm): AddNorm(
        (dropout): Dropout(p=0.1, inplace=False)
        (ln)

In [19]:
# Train SAINT for one epoch; note the batch size is 128 here rather than 256.
model = WideDeep(deeptabular=saint)
trainer = Trainer(
    model,
    objective="binary",
    metrics=[Accuracy],
)
trainer.fit(
    X_tab=X_tab,
    target=target,
    n_epochs=1,
    batch_size=128,
    val_split=0.2,
)

epoch 1: 100%|██████████| 306/306 [00:47<00:00,  6.42it/s, loss=0.377, metrics={'acc': 0.8224}]
valid: 100%|██████████| 77/77 [00:02<00:00, 32.20it/s, loss=0.338, metrics={'acc': 0.8529}]


The previous models have all been published. The following two are adaptations of existing Transformer models for tabular data and, at the time of writing, they are only available in this library. If I have time, I will write a post about their implementation. Nonetheless, all the details can be found in the [docs](https://pytorch-widedeep.readthedocs.io/en/latest/index.html).

The first one is an adaptation of [Fastformer: Additive Attention Can Be All You Need](https://arxiv.org/pdf/2108.09084.pdf). I have mixed feelings about that paper, which I will not cover here, but you can watch [Yannic's video](https://www.youtube.com/watch?v=qgUegkefocg&t=1s), since most of my opinions are also explained there. Nonetheless, the reason to bring this model to the library is that, in essence, the `FastFormer` is an "elaborated MLP" with an "interesting" aggregated attention mechanism. Since MLPs work really well for tabular data compared to other, more complex models, why not add it to the library?

To use it, just follow the same routine as with any other transformer-based model:

In [20]:
# FastFormer adapted to tabular data: 2 blocks with 4 heads each.
tabfastformer = TabFastFormer(
    column_idx=tab_preprocessor.column_idx,
    embed_input=tab_preprocessor.embeddings_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    n_heads=4,
    n_blocks=2,
)

In [21]:
# Display TabFastFormer's module tree.
tabfastformer

TabFastFormer(
  (cat_and_cont_embed): CatAndContEmbeddings(
    (cat_embed): CategoricalEmbeddings(
      (embed): Embedding(104, 32, padding_idx=0)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (cont_norm): Identity()
    (cont_embed): ContinuousEmbeddings()
  )
  (transformer_blks): Sequential(
    (fastformer_block0): FastFormerEncoder(
      (attn): AdditiveAttention(
        (dropout): Dropout(p=0.1, inplace=False)
        (q_proj): Linear(in_features=32, out_features=32, bias=False)
        (v_proj): Linear(in_features=32, out_features=32, bias=False)
        (k_proj): Linear(in_features=32, out_features=32, bias=False)
        (W_q): Linear(in_features=8, out_features=1, bias=False)
        (W_k): Linear(in_features=8, out_features=1, bias=False)
        (r_out): Linear(in_features=8, out_features=8, bias=True)
      )
      (ff): PositionwiseFF(
        (w_1): Linear(in_features=32, out_features=128, bias=True)
        (w_2): Linear(in_features=128, out_features=32,

In [22]:
# Wrap, build a binary-classification trainer and fit for one epoch.
model = WideDeep(deeptabular=tabfastformer)
trainer = Trainer(
    model,
    objective="binary",
    metrics=[Accuracy],
)
trainer.fit(
    X_tab=X_tab,
    target=target,
    n_epochs=1,
    batch_size=256,
    val_split=0.2,
)

epoch 1: 100%|██████████| 153/153 [00:10<00:00, 14.58it/s, loss=0.46, metrics={'acc': 0.7867}] 
valid: 100%|██████████| 39/39 [00:00<00:00, 48.19it/s, loss=0.342, metrics={'acc': 0.8443}]


And finally, the last of the transformer-based models currently available in the library is DeepMind's [Perceiver](https://arxiv.org/pdf/2103.03206.pdf). The reason to add this model to the library is the following: the Perceiver is meant to be an architecture agnostic to the nature of the input data, i.e. it is meant to work with audio, images, text... So why not tabular data, right?

To use it... you guessed it!

In [23]:
# Rebuild the transformer-style preprocessor; with_cls_token is not passed
# this time.
tab_preprocessor = TabPreprocessor(
    for_transformer=True,
    continuous_cols=cont_cols,
    embed_cols=cat_cols,
)
X_tab = tab_preprocessor.fit_transform(df)

In [24]:
# Perceiver for tabular data:
# - 1 perceiver block;
# - 3 latent blocks with 2 heads;
# - 6 latents of dimension 32.
tabperceiver = TabPerceiver(
    column_idx=tab_preprocessor.column_idx,
    embed_input=tab_preprocessor.embeddings_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    n_latents=6,
    latent_dim=32,
    n_perceiver_blocks=1,
    n_latent_blocks=3,
    n_latent_heads=2,
)

In [25]:
# Wrap, build a binary-classification trainer and fit the Perceiver for one epoch.
model = WideDeep(deeptabular=tabperceiver)
trainer = Trainer(
    model,
    objective="binary",
    metrics=[Accuracy],
)
trainer.fit(
    X_tab=X_tab,
    target=target,
    n_epochs=1,
    batch_size=256,
    val_split=0.2,
)

epoch 1: 100%|██████████| 153/153 [00:16<00:00,  9.45it/s, loss=0.4, metrics={'acc': 0.81}]    
valid: 100%|██████████| 39/39 [00:01<00:00, 37.95it/s, loss=0.323, metrics={'acc': 0.8542}]


One final comment: all transformer-based models have the option of using so-called "Shared Embeddings". The idea behind the shared embeddings is explained in the original TabTransformer paper and also in this [post](https://jrzaurin.github.io/infinitoml/2021/02/18/pytorch-widedeep_iii.html).

For transformer-based models, this implies a slightly different data preparation process. Each column will be encoded individually, since that is much easier to implement programmatically. In addition, the use of shared embeddings needs to be specified at the pre-processing stage (`shared_embed=True`).

In [26]:
# Shared embeddings must be requested at pre-processing time; a "[CLS]" token
# is added as well.
tab_preprocessor = TabPreprocessor(
    embed_cols=cat_cols,
    continuous_cols=cont_cols,
    shared_embed=True,
    with_cls_token=True,
    for_transformer=True,
)
X_tab = tab_preprocessor.fit_transform(df)

In [27]:
# Despite the variable name, this is a TabTransformer configured in an
# FT-Transformer-like way. Points to note:
# - continuous columns are embedded too (embed_continuous=True, no activation);
# - shared_embed=True is set here as well as in the preprocessor.
ft_transformer = TabTransformer(
    column_idx=tab_preprocessor.column_idx,
    embed_input=tab_preprocessor.embeddings_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    shared_embed=True,
    embed_continuous=True,
    embed_continuous_activation=None,
    cont_norm_layer="batchnorm",
    n_heads=4,
    n_blocks=4,
)

In [28]:
# Display the module tree; each column gets its own SharedEmbeddings layer.
ft_transformer

TabTransformer(
  (cat_and_cont_embed): CatAndContEmbeddings(
    (cat_embed): CategoricalEmbeddings(
      (embed): ModuleDict(
        (emb_layer_cls_token): SharedEmbeddings(
          (embed): Embedding(1, 32, padding_idx=0)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (emb_layer_education): SharedEmbeddings(
          (embed): Embedding(17, 32, padding_idx=0)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (emb_layer_gender): SharedEmbeddings(
          (embed): Embedding(3, 32, padding_idx=0)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (emb_layer_marital_status): SharedEmbeddings(
          (embed): Embedding(8, 32, padding_idx=0)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (emb_layer_native_country): SharedEmbeddings(
          (embed): Embedding(43, 32, padding_idx=0)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (emb_layer_occupation): SharedEmbeddings(
       

In [29]:
# Wrap, build a binary-classification trainer and fit for one epoch.
model = WideDeep(deeptabular=ft_transformer)
trainer = Trainer(
    model,
    objective="binary",
    metrics=[Accuracy],
)
trainer.fit(
    X_tab=X_tab,
    target=target,
    n_epochs=1,
    batch_size=256,
    val_split=0.2,
)

epoch 1: 100%|██████████| 153/153 [00:20<00:00,  7.62it/s, loss=0.4, metrics={'acc': 0.8061}]  
valid: 100%|██████████| 39/39 [00:01<00:00, 30.53it/s, loss=0.324, metrics={'acc': 0.8551}]
