# Structured data learning with TabTransformer

## Introduction

This example demonstrates how to do structured data classification using [TabTransformer](https://arxiv.org/pdf/2012.06678.pdf), a deep tabular data modeling architecture for supervised and semi-supervised learning. The TabTransformer is built upon self-attention based Transformers. The Transformer layers transform the embeddings of categorical features into robust contextual embeddings to achieve higher predictive accuracy.

**TabTransformer: Tabular Data Modeling Using Contextual Embeddings (2020 arXiv)**

<img src='./images/tabtransformer.png' width='400'> 

## Setup

In [35]:
import numpy as np
import pandas as pd
import torch
from torch import einsum
from einops import rearrange
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchinfo import summary
from matplotlib import pyplot as plt
from collections import Counter
from sklearn.metrics import confusion_matrix
import os
import time
import sys
from urllib import request
import zipfile

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Enable cuDNN's benchmark mode.
torch.backends.cudnn.benchmark = True

## Prepare the data

- Dataset : United States Census Income Dataset
- Info
    - Provided by the UC Irvine Machine Learning Repository
    - Binary classification to predict whether a person is likely to be making over USD 50,000 a year
    - 48,842 instances with 14 input columns; this notebook feeds 5 numerical and 8 categorical features to the models and sets `fnlwgt` aside as the instance-weight column

In [4]:
# Column names for the raw UCI "adult" files, which are read with header=None.
CSV_HEADER = [
    "age",
    "workclass",
    "fnlwgt",
    "education",
    "education_num",
    "marital_status",
    "occupation",
    "relationship",
    "race",
    "gender",
    "capital_gain",
    "capital_loss",
    "hours_per_week",
    "native_country",
    "income_bracket",
]

# Local cache directory for the downloaded splits.
annotation_folder = os.path.join(os.path.abspath("."), "dataset", "USCI")

train_data_file = os.path.join(annotation_folder, "train_data.csv")
test_data_file = os.path.join(annotation_folder, "test_data.csv")

if not os.path.exists(annotation_folder):
    os.makedirs(annotation_folder)
    print('Folder creation complete!')
else:
    print('The folder already exists.')

# Training split: download once, then reuse the cached CSV on later runs.
if not os.path.isfile(train_data_file):
    train_data_url = (
        "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
    )
    train_data = pd.read_csv(train_data_url, header=None, names=CSV_HEADER)
    train_data.to_csv(train_data_file, index=False)
    print('train_data.csv creation complete!')
else:
    train_data = pd.read_csv(train_data_file)
    print('train_data.csv already exists.')
print('train_data load complete!')


# Test split: same caching scheme as the training split.
if not os.path.isfile(test_data_file):
    test_data_url = (
        "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test"
    )
    # The first line of adult.test is not a data record (it used to be dropped
    # after parsing). Skipping it while parsing keeps it from influencing the
    # inferred column dtypes, so the freshly downloaded frame matches what a
    # later run loads from the cached CSV.
    test_data = pd.read_csv(
        test_data_url, header=None, names=CSV_HEADER, skiprows=1
    )
    # Remove the "." characters from the test labels so they compare equal to
    # the training labels.
    test_data["income_bracket"] = test_data["income_bracket"].str.replace(
        ".", "", regex=False
    )
    test_data.to_csv(test_data_file, index=False)
    print('test_data.csv creation complete!')
else:
    test_data = pd.read_csv(test_data_file)
    print('test_data.csv already exists.')
print('test_data load complete!')

print(f"Train dataset shape: {train_data.shape}")
print(f"Test dataset shape: {test_data.shape}")

The folder already exists.
train_data.csv already exists.
train_data load complete!
test_data.csv already exists.
test_data load complete!
Train dataset shape: (32561, 15)
Test dataset shape: (16281, 15)


In [5]:
# Preview the first rows of the training frame to sanity-check the parsed columns.
train_data.head()

Unnamed: 0,age,workclass,fnlwgt,education,education_num,marital_status,occupation,relationship,race,gender,capital_gain,capital_loss,hours_per_week,native_country,income_bracket
0,39,State-gov,77516,Bachelors,13,Never-married,Adm-clerical,Not-in-family,White,Male,2174,0,40,United-States,<=50K
1,50,Self-emp-not-inc,83311,Bachelors,13,Married-civ-spouse,Exec-managerial,Husband,White,Male,0,0,13,United-States,<=50K
2,38,Private,215646,HS-grad,9,Divorced,Handlers-cleaners,Not-in-family,White,Male,0,0,40,United-States,<=50K
3,53,Private,234721,11th,7,Married-civ-spouse,Handlers-cleaners,Husband,Black,Male,0,0,40,United-States,<=50K
4,28,Private,338409,Bachelors,13,Married-civ-spouse,Prof-specialty,Wife,Black,Female,0,0,40,Cuba,<=50K


## Define dataset metadata

In [6]:
# Numerical input columns.
NUMERIC_FEATURE_NAMES = [
    "age",
    "education_num",
    "capital_gain",
    "capital_loss",
    "hours_per_week",
]

# Categorical input columns, each mapped to the sorted list of distinct values
# found in the training frame.
CATEGORICAL_FEATURES_WITH_VOCABULARY = {
    column: sorted(train_data[column].unique())
    for column in (
        "workclass",
        "education",
        "marital_status",
        "occupation",
        "relationship",
        "race",
        "gender",
        "native_country",
    )
}

# Column holding the per-instance weight.
WEIGHT_COLUMN_NAME = "fnlwgt"
# Names of the categorical features, in vocabulary order.
CATEGORICAL_FEATURE_NAMES = list(CATEGORICAL_FEATURES_WITH_VOCABULARY)
# All model input features: numerical first, then categorical.
FEATURE_NAMES = NUMERIC_FEATURE_NAMES + CATEGORICAL_FEATURE_NAMES
# Per-column default: 0.0 for numerical columns and the weight column, "NA" otherwise.
COLUMN_DEFAULTS = [
    0.0
    if (name == WEIGHT_COLUMN_NAME or name in NUMERIC_FEATURE_NAMES)
    else "NA"
    for name in CSV_HEADER
]
# Column holding the prediction target.
TARGET_FEATURE_NAME = "income_bracket"
# Target label strings; a label's position in this list is its class index.
TARGET_LABELS = [" <=50K", " >50K"]

In [7]:
# Display the vocabulary collected for each categorical feature.
CATEGORICAL_FEATURES_WITH_VOCABULARY

{'workclass': [' ?',
  ' Federal-gov',
  ' Local-gov',
  ' Never-worked',
  ' Private',
  ' Self-emp-inc',
  ' Self-emp-not-inc',
  ' State-gov',
  ' Without-pay'],
 'education': [' 10th',
  ' 11th',
  ' 12th',
  ' 1st-4th',
  ' 5th-6th',
  ' 7th-8th',
  ' 9th',
  ' Assoc-acdm',
  ' Assoc-voc',
  ' Bachelors',
  ' Doctorate',
  ' HS-grad',
  ' Masters',
  ' Preschool',
  ' Prof-school',
  ' Some-college'],
 'marital_status': [' Divorced',
  ' Married-AF-spouse',
  ' Married-civ-spouse',
  ' Married-spouse-absent',
  ' Never-married',
  ' Separated',
  ' Widowed'],
 'occupation': [' ?',
  ' Adm-clerical',
  ' Armed-Forces',
  ' Craft-repair',
  ' Exec-managerial',
  ' Farming-fishing',
  ' Handlers-cleaners',
  ' Machine-op-inspct',
  ' Other-service',
  ' Priv-house-serv',
  ' Prof-specialty',
  ' Protective-serv',
  ' Sales',
  ' Tech-support',
  ' Transport-moving'],
 'relationship': [' Husband',
  ' Not-in-family',
  ' Other-relative',
  ' Own-child',
  ' Unmarried',
  ' Wife'],
 'r

## Configure the hyperparameters

In [8]:
# Hyperparameters and runtime settings shared by the models, loaders and loops.
params = dict(
    LEARNING_RATE=0.001,
    WEIGHT_DECAY=0.0001,
    DROPOUT_RATE=0.2,
    BATCH_SIZE=265,
    NUM_WORKERS=4,
    PIN_MEMORY=True,
    NUM_EPOCHS=15,
    # Number of transformer blocks.
    NUM_TRANSFORMER_BLOCKS=3,
    # Number of attention heads.
    NUM_HEADS=4,
    # Embedding dimensions of the categorical features.
    EMBEDDING_DIMS=16,
    # MLP hidden layer units, as factors of the number of inputs.
    MLP_HIDDEN_UNITS_FACTORS=[2, 1],
    # Number of MLP blocks in the baseline model.
    NUM_MLP_BLOCKS=2,
    DEVICE=device,
)

## Implement data reading pipeline

In [9]:
class CustomDataset(Dataset):
    """Map-style dataset over a census-income DataFrame.

    Rows containing missing values are dropped. Categorical columns are
    replaced by the position of each value in its vocabulary, and the target
    column by the position of its label in ``TARGET_LABELS``.

    Args:
        dataframe: Source frame; it is copied, never modified in place.
        VOCABULARY: Mapping of categorical column name to its list of values.
        FEATURE_NAMES: Ordered input columns returned as the feature vector.
        NUMERIC_FEATURE_NAMES: Columns passed through without index encoding.
        TARGET_FEATURE_NAME: Name of the label column.
        TARGET_LABELS: Label strings; the list position is the class index.
        WEIGHT_COLUMN_NAME: Name of the instance-weight column.
        COLUMN_DEFAULTS: Accepted for interface compatibility; not used.

    Raises:
        ValueError: If a categorical value or a label is missing from its
            vocabulary.
    """

    def __init__(self,
                 dataframe,
                 VOCABULARY=CATEGORICAL_FEATURES_WITH_VOCABULARY,
                 FEATURE_NAMES=FEATURE_NAMES,
                 NUMERIC_FEATURE_NAMES=NUMERIC_FEATURE_NAMES,
                 TARGET_FEATURE_NAME=TARGET_FEATURE_NAME,
                 TARGET_LABELS=TARGET_LABELS,
                 WEIGHT_COLUMN_NAME=WEIGHT_COLUMN_NAME,
                 COLUMN_DEFAULTS=COLUMN_DEFAULTS):

        data = dataframe.copy().dropna()

        for feature_name in FEATURE_NAMES:
            if feature_name not in NUMERIC_FEATURE_NAMES:
                # Bind this column's vocabulary explicitly instead of closing
                # over the loop variable inside a lambda.
                vocabulary = VOCABULARY[feature_name]
                data[feature_name] = data[feature_name].map(vocabulary.index)

        data[TARGET_FEATURE_NAME] = data[TARGET_FEATURE_NAME].map(TARGET_LABELS.index)
        self.weights = np.array(data.pop(WEIGHT_COLUMN_NAME), dtype=np.float32)
        # np.int was removed in NumPy 1.24; use an explicit fixed-width integer.
        self.targets = np.array(data.pop(TARGET_FEATURE_NAME), dtype=np.int64)
        self.data = np.array(data[FEATURE_NAMES])

    def __len__(self):
        """Return the number of rows kept after dropping missing values."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(features, target_index, weight)`` tensors for row ``idx``."""
        features = torch.from_numpy(self.data[idx])
        target_index = torch.as_tensor(self.targets[idx])
        weight = torch.as_tensor(self.weights[idx])
        return features, target_index, weight

In [10]:
# Training split: encoded dataset plus a shuffling loader.
train_dataset = CustomDataset(train_data)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=params['BATCH_SIZE'],
    shuffle=True,
    num_workers=params['NUM_WORKERS'],
    pin_memory=params['PIN_MEMORY'],
)

# Test split: same settings, but served in file order.
test_dataset = CustomDataset(test_data)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=params['BATCH_SIZE'],
    shuffle=False,
    num_workers=params['NUM_WORKERS'],
    pin_memory=params['PIN_MEMORY'],
)

In [11]:
# Class balance of the training labels.
Counter(train_data[TARGET_FEATURE_NAME])

Counter({' <=50K': 24720, ' >50K': 7841})

In [12]:
# Class balance of the test labels.
Counter(test_data[TARGET_FEATURE_NAME])

Counter({' <=50K': 12435, ' >50K': 3846})

In [13]:
# Fetch ONE batch and report both shapes from it. Calling next(iter(...)) twice
# builds two loader iterators (two rounds of worker start-up) and, with
# shuffling on, prints the shapes of two different batches.
sample_batch = next(iter(train_loader))
print(sample_batch[0].shape, sample_batch[1].shape)

torch.Size([265, 13]) torch.Size([265])


In [14]:
# Shape of one feature batch; reused as the input size for the summary() calls.
input_shape = next(iter(train_loader))[0].shape

## Experiment 1: a baseline model

<img src='./images/tab_baseline.png' width='1000'> 

In [15]:
class create_mlp(nn.Module):
    """Stack of normalization -> linear -> activation -> dropout stages.

    Args:
        dims: Layer widths; each consecutive pair gives one stage's input and
            output size.
        dropout_rate: Dropout probability applied at the end of every stage.
        activation: Activation module; the same instance is reused by every
            stage.
        normalization_layer: 'LayerNorm' selects ``nn.LayerNorm``; any other
            value selects ``nn.BatchNorm1d``.
    """

    def __init__(self, dims, dropout_rate, activation, normalization_layer):
        super().__init__()
        use_layer_norm = normalization_layer == 'LayerNorm'

        stages = []
        for width_in, width_out in zip(dims, dims[1:]):
            if use_layer_norm:
                norm = nn.LayerNorm(width_in, eps=1e-6)
            else:
                norm = nn.BatchNorm1d(width_in)
            stages.append(
                nn.Sequential(
                    norm,
                    nn.Linear(width_in, width_out),
                    activation,
                    nn.Dropout(dropout_rate),
                )
            )
        self.mlp_layers = nn.Sequential(*stages)

    def forward(self, inputs):
        """Run ``inputs`` through every stage in order."""
        return self.mlp_layers(inputs)

In [16]:
# Inspect a single 13 -> 13 LayerNorm/GELU stage using the feature-batch shape.
summary(create_mlp([13, 13], 0.01, nn.GELU(), 'LayerNorm'), input_shape)

Layer (type:depth-idx)                   Output Shape              Param #
create_mlp                               --                        --
├─Sequential: 1-1                        [265, 13]                 --
│    └─Sequential: 2-1                   [265, 13]                 --
│    │    └─LayerNorm: 3-1               [265, 13]                 26
│    │    └─Linear: 3-2                  [265, 13]                 182
│    │    └─GELU: 3-3                    [265, 13]                 --
│    │    └─Dropout: 3-4                 [265, 13]                 --
Total params: 208
Trainable params: 208
Non-trainable params: 0
Total mult-adds (M): 0.06
Input size (MB): 0.01
Forward/backward pass size (MB): 0.06
Params size (MB): 0.00
Estimated Total Size (MB): 0.07

In [17]:
class create_baseline_model(nn.Module):
    """Baseline classifier: embedded categoricals + raw numerics through MLPs.

    Each categorical feature gets its own embedding table. The embeddings are
    concatenated with the numerical features, passed through
    ``NUM_MLP_BLOCKS`` same-width LayerNorm/GELU blocks, then a
    BatchNorm/SELU MLP, and finally a linear layer producing one logit.

    Args:
        params: Dict providing 'EMBEDDING_DIMS', 'NUM_MLP_BLOCKS',
            'MLP_HIDDEN_UNITS_FACTORS', 'DROPOUT_RATE' and 'DEVICE'.
        FEATURE_NAMES: Ordered input feature names. ``forward`` reads the
            first ``num_numeric_features`` input columns as numerical and
            the remaining ones as categorical indices.
        CATEGORICAL_FEATURES_WITH_VOCABULARY: Mapping of categorical feature
            name to its vocabulary; the vocabulary length sizes the embedding.
    """

    def __init__(self,
                 params,
                 FEATURE_NAMES=FEATURE_NAMES,
                 CATEGORICAL_FEATURES_WITH_VOCABULARY=CATEGORICAL_FEATURES_WITH_VOCABULARY, ):
        super().__init__()
        self.embedding_dims = params['EMBEDDING_DIMS']
        self.num_mlp_blocks = params['NUM_MLP_BLOCKS']
        self.mlp_hidden_units_factors = params['MLP_HIDDEN_UNITS_FACTORS']
        self.dropout_rate = params['DROPOUT_RATE']
        self.device = params['DEVICE']

        self.num_features = len(FEATURE_NAMES)
        self.num_cat_features = len(CATEGORICAL_FEATURES_WITH_VOCABULARY)
        self.num_numeric_features = self.num_features - self.num_cat_features

        # One embedding table per categorical feature, in FEATURE_NAMES order.
        self.embedding_layers = nn.ModuleList([
            nn.Embedding(len(CATEGORICAL_FEATURES_WITH_VOCABULARY[feature]), self.embedding_dims)
            for feature in FEATURE_NAMES
            if feature in CATEGORICAL_FEATURES_WITH_VOCABULARY
        ])

        input_size = (self.embedding_dims * self.num_cat_features) + self.num_numeric_features

        self.mlp_blocks = nn.Sequential(*[
            create_mlp([input_size, input_size],
                       self.dropout_rate,
                       nn.GELU(),
                       'LayerNorm')
            for _ in range(self.num_mlp_blocks)
        ])

        mlp_hidden_units = [factor * input_size for factor in self.mlp_hidden_units_factors]
        mlp_hidden_units.insert(0, input_size)

        self.MLP = create_mlp(mlp_hidden_units,
                              self.dropout_rate,
                              nn.SELU(),
                              'BatchNorm')

        self.classifier = nn.Linear(mlp_hidden_units[-1], 1)

    def forward(self, inputs):
        """Return one float32 logit per input row, shape ``(batch, 1)``."""
        # Move the batch to the model device once instead of once per column.
        inputs = inputs.to(self.device)
        split = self.num_numeric_features

        # (batch, num_numeric). Plain slicing keeps the batch axis even for a
        # single-row batch, which a dimension-less squeeze() would remove.
        numeric_vectors = inputs[:, :split].float()

        # Each lookup yields (batch, 1, embedding_dims).
        categorical_vectors = [
            self.embedding_layers[i](inputs[:, split + i:split + i + 1].long())
            for i in range(self.num_cat_features)
        ]
        # Concatenate along the embedding axis, then drop only the length-1 axis.
        categorical_vectors = torch.cat(categorical_vectors, dim=2).squeeze(1)

        features = torch.cat([numeric_vectors, categorical_vectors], dim=1)
        features = self.mlp_blocks(features)
        features = self.MLP(features)
        output = self.classifier(features)

        # Cast in place on the current device; .type(torch.FloatTensor) would
        # copy a CUDA tensor to the CPU and back.
        return output.float()


In [18]:
# Build the baseline network on the selected device, then create the optimizer
# over its parameters.
baseline_model = create_baseline_model(params).to(device)
# NOTE: params['WEIGHT_DECAY'] is defined but not passed to the optimizer here.
optimizer = optim.Adam(baseline_model.parameters(), lr=params['LEARNING_RATE'])
# The model ends in a single-unit linear layer, so the loss takes raw logits.
loss_fn = nn.BCEWithLogitsLoss()

In [19]:
# Layer-by-layer overview of the baseline model for one feature batch.
summary(baseline_model, input_shape)

Layer (type:depth-idx)                   Output Shape              Param #
create_baseline_model                    --                        --
├─ModuleList: 1-1                        --                        --
│    └─Embedding: 2-1                    [265, 1, 16]              144
│    └─Embedding: 2-2                    [265, 1, 16]              256
│    └─Embedding: 2-3                    [265, 1, 16]              112
│    └─Embedding: 2-4                    [265, 1, 16]              240
│    └─Embedding: 2-5                    [265, 1, 16]              96
│    └─Embedding: 2-6                    [265, 1, 16]              80
│    └─Embedding: 2-7                    [265, 1, 16]              32
│    └─Embedding: 2-8                    [265, 1, 16]              672
├─Sequential: 1-2                        [265, 133]                --
│    └─create_mlp: 2-9                   [265, 133]                --
│    │    └─Sequential: 3-1              [265, 133]                18,088
│    └

In [20]:
class AverageMeter(object):
    """Keeps the latest value of a metric together with its weighted running mean.

    Args:
        name: Label used when the meter is printed.
        fmt: Format spec (including the leading ':') applied to the latest
            value and the average in ``__str__``.
    """

    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Zero the latest value, average, weighted sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        # Doubled braces survive the first format() call as the named fields
        # that the second call fills in.
        template = '{{name}} {{val{0}}} ({{avg{0}}})'.format(self.fmt)
        return template.format(name=self.name, val=self.val, avg=self.avg)

class ProgressMeter(object):
    """Prints a tab-separated progress line: prefix, batch counter, then meters.

    Args:
        num_batches: Total number of batches; sets the counter's field width.
        meters: Objects whose ``str()`` is appended to each progress line.
        prefix: Text placed directly before the batch counter.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the progress line for batch index ``batch``."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        """Build a '[{:Nd}/total]' template padded to the width of ``num_batches``."""
        width = len(str(num_batches // 1))
        field = '{:%dd}' % width
        return '[%s/%s]' % (field, field.format(num_batches))

In [21]:
def binary_acc(y_pred, y_test):
    """Return the batch accuracy in percent, rounded to a whole number.

    Args:
        y_pred: Raw logits; they are passed through a sigmoid and rounded to
            obtain 0/1 predictions.
        y_test: Ground-truth labels compared element-wise with the predictions.

    Returns:
        A scalar tensor holding the rounded percentage of matching entries,
        computed relative to ``y_test.shape[0]``.
    """
    predicted_labels = torch.sigmoid(y_pred).round()
    num_correct = predicted_labels.eq(y_test).sum().float()
    percent = num_correct / y_test.shape[0] * 100
    return torch.round(percent)

In [22]:
def train(model, train_data, optimizer, loss_fn, use_fp16=True, max_norm=None,
          progress_display=False, epoch=None):
    """Run one training epoch and return ``(mean_loss, mean_accuracy)``.

    Args:
        model: Network to train; it receives each feature batch unchanged.
        train_data: Iterable yielding ``(features, targets, weights)`` batches.
            The weights are unpacked but not used.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Callable taking ``(predictions, targets)``.
        use_fp16: Run the forward pass under CUDA autocast and scale the loss
            with a ``GradScaler``.
        max_norm: If not None, clip the gradient norm to this value before the
            optimizer step.
        progress_display: If truthy, print a progress line every 50 batches.
        epoch: Epoch number shown in the progress prefix. When omitted, a
            module-level ``epoch`` variable is used if one exists (earlier
            callers relied on that global), otherwise '?'.

    Returns:
        Sample-weighted running averages of the batch loss and of the batch
        accuracy reported by ``binary_acc``.
    """
    if epoch is None:
        # Backward compatibility with callers that set a global loop variable;
        # no longer a NameError when that global is missing.
        epoch = globals().get('epoch', '?')

    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    Acc = AverageMeter('Acc', ':6.2f')
    progress = ProgressMeter(
        len(train_data),
        [batch_time, losses],
        prefix="Epoch: [{}]".format(epoch))

    # Create the scaler once per epoch: re-creating it for every batch resets
    # its dynamic loss scale, so the scale could never adapt.
    scaler = torch.cuda.amp.GradScaler() if use_fp16 else None

    model.train()
    end = time.time()
    for idx, (features, labels, _weights) in enumerate(train_data):
        optimizer.zero_grad(set_to_none=True)

        target = labels.unsqueeze(1).to(device, dtype=torch.float32)

        with torch.cuda.amp.autocast(enabled=use_fp16):
            predictions = model(features)
            train_loss = loss_fn(predictions, target)

        if use_fp16:
            scaler.scale(train_loss).backward()
            if max_norm is not None:
                # Gradients must be unscaled before their norm is clipped.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            train_loss.backward()
            if max_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()

        batch_size = features.size(0)
        Acc.update(binary_acc(predictions, target), batch_size)
        losses.update(train_loss.item(), batch_size)
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if progress_display and idx % 50 == 0:
            progress.display(idx)

    return losses.avg, Acc.avg

In [23]:
def validation(model, val_data, loss_fn):
    """Evaluate ``model`` on ``val_data`` and return ``(mean_loss, mean_accuracy)``.

    The model is switched to eval mode and the forward passes run without
    gradient tracking. Both averages are weighted by batch size.
    """
    loss_meter = AverageMeter('Loss', ':.4e')
    acc_meter = AverageMeter('Acc', ':6.2f')

    model.eval()
    with torch.no_grad():
        for features, labels, _weights in val_data:
            target = labels.unsqueeze(1).to(device, dtype=torch.float32)
            predictions = model(features)
            batch_loss = loss_fn(predictions, target)

            batch_size = features.size(0)
            acc_meter.update(binary_acc(predictions, target), batch_size)
            loss_meter.update(batch_loss.item(), batch_size)

    return loss_meter.avg, acc_meter.avg

In [24]:
class EarlyStopping():
    """Signals a stop once the loss has been worse than the best for too long.

    Args:
        patience: Number of consecutive worse-than-best checks tolerated
            before a stop is signalled.
        verbose: If truthy, print a message when the stop is signalled.
    """

    def __init__(self, patience=0, verbose=0):
        self._step = 0
        self._loss = float('inf')
        self.patience = patience
        self.verbose = verbose

    def validate(self, loss):
        """Record ``loss`` and return True when training should stop."""
        if not self._loss < loss:
            # Not worse than the best seen so far: reset the counter and keep
            # this loss as the new reference.
            self._step = 0
            self._loss = loss
            return False

        self._step += 1
        should_stop = self._step > self.patience
        if should_stop and self.verbose:
            print('\n Training process is stopped early....')
        return should_stop

In [25]:
%%time
# Best validation result so far. -inf is a true lower bound, whereas
# sys.float_info.min is the smallest POSITIVE float.
best = {"acc": float("-inf")}
history = dict()
early_stopping = EarlyStopping(patience=5, verbose=1)

for epoch in range(1, params['NUM_EPOCHS']+1):
    epoch_loss, epoch_acc = train(baseline_model, train_loader, optimizer, loss_fn, use_fp16=False)
    val_loss, val_acc = validation(baseline_model, test_loader, loss_fn)

    history.setdefault('loss', []).append(epoch_loss)
    history.setdefault('val_loss', []).append(val_loss)
    history.setdefault('accuracy', []).append(epoch_acc)
    history.setdefault('val_accuracy', []).append(val_acc)

    print(f"[Train] Epoch : {epoch:^3}"
          f"  Train Loss: {epoch_loss:.4}"
          f"  Train Acc: {epoch_acc:.4}"
          f"  Valid Loss: {val_loss:.4}"
          f"  Valid Acc: {val_acc:.4}")

    if val_acc > best["acc"]:
        # state_dict() returns references to the live parameters, which later
        # epochs keep updating; clone them so the snapshot stays frozen.
        best["state"] = {
            name: tensor.detach().clone()
            for name, tensor in baseline_model.state_dict().items()
        }
        best["acc"] = val_acc
        best["epoch"] = epoch
    if early_stopping.validate(val_loss):
        break

[Train] Epoch :  1   Train Loss: 0.4283  Train Acc: 78.47  Valid Loss: 0.441  Valid Acc: 75.83
[Train] Epoch :  2   Train Loss: 0.3779  Train Acc: 80.88  Valid Loss: 0.3644  Valid Acc: 82.16
[Train] Epoch :  3   Train Loss: 0.3681  Train Acc: 81.07  Valid Loss: 0.3639  Valid Acc: 81.25
[Train] Epoch :  4   Train Loss: 0.3654  Train Acc: 81.43  Valid Loss: 0.357  Valid Acc: 82.07
[Train] Epoch :  5   Train Loss: 0.3615  Train Acc: 81.82  Valid Loss: 0.3581  Valid Acc: 81.83
[Train] Epoch :  6   Train Loss: 0.3611  Train Acc: 81.76  Valid Loss: 0.3557  Valid Acc: 81.84
[Train] Epoch :  7   Train Loss: 0.3583  Train Acc: 81.92  Valid Loss: 0.3579  Valid Acc: 81.36
[Train] Epoch :  8   Train Loss: 0.3543  Train Acc: 82.36  Valid Loss: 0.3601  Valid Acc: 82.27
[Train] Epoch :  9   Train Loss: 0.3442  Train Acc: 83.11  Valid Loss: 0.3346  Valid Acc: 83.22
[Train] Epoch : 10   Train Loss: 0.336  Train Acc: 83.63  Valid Loss: 0.3342  Valid Acc: 83.87
[Train] Epoch : 11   Train Loss: 0.3316  Tr

## Test

In [37]:
# Evaluate the baseline model on the test loader and collect its 0/1 predictions.
losses = AverageMeter('Loss', ':.4e')
Acc = AverageMeter('Acc', ':6.2f')

baseline_model.eval()
test_loss = 0
predic = []
with torch.no_grad():
    for features, labels, _weights in test_loader:
        target = labels.unsqueeze(1).to(device, dtype=torch.float32)
        predictions = baseline_model(features)
        test_loss = loss_fn(predictions, target)
        # Flatten explicitly: squeeze() turns a one-row batch into a 0-d tensor,
        # whose tolist() is a bare float that extend() cannot consume.
        predic.extend(torch.round(torch.sigmoid(predictions)).cpu().reshape(-1).tolist())
        batch_size = features.size(0)
        Acc.update(binary_acc(predictions, target), batch_size)
        losses.update(test_loss.item(), batch_size)

print(f'Loss : {losses.avg}, Acc : {Acc.avg}')

Loss : 0.33143243851323995, Acc : 84.24415588378906


In [39]:
# Use the labels held by test_dataset (the rows the loader actually served, after
# its dropna()) so the label list always lines up with `predic`.
confusion_matrix(test_dataset.targets.tolist(), predic)

array([[11798,   637],
       [ 1933,  1913]])

## Experiment 2: TabTransformer

The TabTransformer architecture works as follows:

1. All the categorical features are encoded as embeddings, using the same embedding_dims. This means that each value in each categorical feature will have its own embedding vector.
2. A column embedding, one embedding vector for each categorical feature, is added (point-wise) to the categorical feature embedding.
3. The embedded categorical features are fed into a stack of Transformer blocks. Each Transformer block consists of a multi-head self-attention layer followed by a feed-forward layer.
4. The outputs of the final Transformer layer, which are the contextual embeddings of the categorical features, are concatenated with the input numerical features, and fed into a final MLP block.
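5. A softmax classifier is applied at the end of the model (for the binary task in this notebook, the implementation below uses a single logit with a sigmoid instead).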

<img src='./images/tabtransformer.png' width='400'> 

<img src='./images/tabtransformer_1.png' width='1000'> 

Reference : https://github.com/lucidrains/tab-transformer-pytorch

In [26]:
# classes

class Residual(nn.Module):
    """Wraps ``fn`` with a skip connection: the output is ``fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # Extra keyword arguments are forwarded to the wrapped module untouched.
        transformed = self.fn(x, **kwargs)
        return transformed + x

class PreNorm(nn.Module):
    """Applies ``nn.LayerNorm(dim)`` to the input before calling ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        # Only the input is normalized; keyword arguments pass straight through.
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)

In [27]:
# attention

class GEGLU(nn.Module):
    """GELU-gated linear unit: halves the last axis and gates one half with the other."""

    def forward(self, x):
        # First half carries the values, second half the gate pre-activations.
        halves = x.chunk(2, dim = -1)
        values, gates = halves[0], halves[1]
        return values * F.gelu(gates)

class FeedForward(nn.Module):
    """Position-wise feed-forward block with a GEGLU activation.

    Args:
        dim: Input and output width.
        mult: Hidden width as a multiple of ``dim``.
        dropout: Dropout probability applied after the gated activation.
    """

    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        hidden = dim * mult
        # The first projection is twice the hidden width because GEGLU splits
        # its input into a value half and a gate half.
        expand = nn.Linear(dim, hidden * 2)
        project = nn.Linear(hidden, dim)
        self.net = nn.Sequential(expand, GEGLU(), nn.Dropout(dropout), project)

    def forward(self, x, **kwargs):
        # Keyword arguments are accepted for interface parity and ignored.
        return self.net(x)

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: Width of the input and output tokens.
        heads: Number of attention heads.
        dim_head: Width of each head.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(
            self,
            dim,
            heads = 8,
            dim_head = 16,
            dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5

        inner_dim = dim_head * heads
        # A single projection produces queries, keys and values together.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        heads = self.heads
        # Split the joint projection into q, k, v and move the head axis forward.
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h = heads)
            for part in self.to_qkv(x).chunk(3, dim = -1)
        )
        scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        weights = self.dropout(scores.softmax(dim = -1))

        mixed = einsum('b h i j, b h j d -> b h i d', weights, v)
        # Merge the heads back into a single feature axis.
        merged = rearrange(mixed, 'b h n d -> b n (h d)', h = heads)
        return self.to_out(merged)

In [28]:
class create_tabtransformer_classifier(nn.Module):
    """TabTransformer-style classifier producing one logit per row.

    Categorical features are embedded (one table per feature), optionally
    offset by a learned per-column embedding, refined by a stack of
    pre-norm residual attention / feed-forward blocks, flattened,
    concatenated with the raw numerical features and fed to a
    BatchNorm/SELU MLP followed by a linear output layer.

    Args:
        params: Dict providing 'NUM_TRANSFORMER_BLOCKS', 'NUM_HEADS',
            'EMBEDDING_DIMS', 'MLP_HIDDEN_UNITS_FACTORS', 'DROPOUT_RATE' and
            'DEVICE'.
        FEATURE_NAMES: Ordered input feature names. ``forward`` reads the
            first ``num_numeric_features`` input columns as numerical and
            the remaining ones as categorical indices.
        CATEGORICAL_FEATURES_WITH_VOCABULARY: Mapping of categorical feature
            name to its vocabulary; the vocabulary length sizes the embedding.
        USE_COLUMN_EMBEDDING: If True, add a learned embedding per categorical
            column to that column's value embedding.
    """

    def __init__(self,
                 params,
                 FEATURE_NAMES=FEATURE_NAMES,
                 CATEGORICAL_FEATURES_WITH_VOCABULARY=CATEGORICAL_FEATURES_WITH_VOCABULARY,
                 USE_COLUMN_EMBEDDING=False):
        super().__init__()

        self.num_transformer_blocks = params['NUM_TRANSFORMER_BLOCKS']
        self.num_heads = params['NUM_HEADS']
        self.embedding_dims = params['EMBEDDING_DIMS']
        self.mlp_hidden_units_factors = params['MLP_HIDDEN_UNITS_FACTORS']
        self.dropout_rate = params['DROPOUT_RATE']
        self.device = params['DEVICE']

        self.num_features = len(FEATURE_NAMES)
        self.num_cat_features = len(CATEGORICAL_FEATURES_WITH_VOCABULARY)
        self.use_column_embedding = USE_COLUMN_EMBEDDING
        self.num_numeric_features = self.num_features - self.num_cat_features
        self.embedding_layers = nn.ModuleList()
        self.transformer_blocks = nn.ModuleList()

        # One embedding table per categorical feature, in FEATURE_NAMES order.
        for feature in FEATURE_NAMES:
            if feature in CATEGORICAL_FEATURES_WITH_VOCABULARY:
                emb = nn.Embedding(len(CATEGORICAL_FEATURES_WITH_VOCABULARY[feature]), self.embedding_dims)
                self.embedding_layers.append(emb)

        if self.use_column_embedding:
            self.column_emb = nn.Embedding(self.num_cat_features, self.embedding_dims)

        for _ in range(self.num_transformer_blocks):
            self.transformer_blocks.append(nn.ModuleList([
                Residual(PreNorm(self.embedding_dims, Attention(self.embedding_dims, heads = self.num_heads, dim_head = self.embedding_dims, dropout = self.dropout_rate))),
                Residual(PreNorm(self.embedding_dims, FeedForward(self.embedding_dims, dropout = self.dropout_rate))),
            ]))

        input_size = (self.embedding_dims * self.num_cat_features) + self.num_numeric_features
        mlp_hidden_units = [factor * input_size for factor in self.mlp_hidden_units_factors]
        mlp_hidden_units.insert(0, input_size)

        self.MLP = create_mlp(mlp_hidden_units,
                              self.dropout_rate,
                              nn.SELU(),
                              'BatchNorm')

        self.classifier = nn.Linear(mlp_hidden_units[-1], 1)

    def forward(self, inputs):
        """Return one float32 logit per input row, shape ``(batch, 1)``."""
        # Move the batch to the model device once instead of once per column.
        inputs = inputs.to(self.device)
        split = self.num_numeric_features

        # (batch, num_numeric). Plain slicing keeps the batch axis even for a
        # single-row batch, which a dimension-less squeeze() would remove.
        numeric_vectors = inputs[:, :split].float()

        # Stack the per-feature embeddings as tokens:
        # (batch, num_cat_features, embedding_dims).
        categorical_vectors = torch.cat([
            self.embedding_layers[i](inputs[:, split + i:split + i + 1].long())
            for i in range(self.num_cat_features)
        ], dim=1)

        if self.use_column_embedding:
            # Previously column_emb was created but never applied. Add one
            # learned vector per column, broadcast over the batch.
            column_ids = torch.arange(self.num_cat_features, device=categorical_vectors.device)
            categorical_vectors = categorical_vectors + self.column_emb(column_ids)

        for attn, ff in self.transformer_blocks:
            categorical_vectors = attn(categorical_vectors)
            categorical_vectors = ff(categorical_vectors)

        # Flatten the contextual token embeddings into one vector per row.
        categorical_vectors = rearrange(categorical_vectors, 'b n d -> b (n d)')
        features = torch.cat([numeric_vectors, categorical_vectors], dim=1)

        features = self.MLP(features)
        output = self.classifier(features)

        # Cast in place on the current device; .type(torch.FloatTensor) would
        # copy a CUDA tensor to the CPU and back.
        return output.float()


In [29]:
# Build the TabTransformer classifier on the selected device, then create a new
# optimizer over its parameters (this rebinds `optimizer` and `loss_fn`).
tabtransformer = create_tabtransformer_classifier(params).to(device)
# NOTE: params['WEIGHT_DECAY'] is defined but not passed to the optimizer here.
optimizer = optim.Adam(tabtransformer.parameters(), lr=params['LEARNING_RATE'])
# The model ends in a single-unit linear layer, so the loss takes raw logits.
loss_fn = nn.BCEWithLogitsLoss()

In [30]:
# Layer-by-layer overview of the TabTransformer classifier for one feature batch.
summary(tabtransformer, input_shape)

Layer (type:depth-idx)                             Output Shape              Param #
create_tabtransformer_classifier                   --                        --
├─ModuleList: 1-1                                  --                        --
├─ModuleList: 1-2                                  --                        --
│    └─ModuleList: 2-1                             --                        --
│    └─ModuleList: 2-2                             --                        --
│    └─ModuleList: 2-3                             --                        --
├─ModuleList: 1-1                                  --                        --
│    └─Embedding: 2-4                              [265, 1, 16]              144
│    └─Embedding: 2-5                              [265, 1, 16]              256
│    └─Embedding: 2-6                              [265, 1, 16]              112
│    └─Embedding: 2-7                              [265, 1, 16]              240
│    └─Embedding: 2-8          

In [31]:
%%time
# Best validation result so far. -inf is a true lower bound, whereas
# sys.float_info.min is the smallest POSITIVE float.
best = {"acc": float("-inf")}
history = dict()
early_stopping = EarlyStopping(patience=5, verbose=1)

for epoch in range(1, params['NUM_EPOCHS']+1):
    epoch_loss, epoch_acc = train(tabtransformer, train_loader, optimizer, loss_fn, use_fp16=False)
    val_loss, val_acc = validation(tabtransformer, test_loader, loss_fn)

    history.setdefault('loss', []).append(epoch_loss)
    history.setdefault('val_loss', []).append(val_loss)
    history.setdefault('accuracy', []).append(epoch_acc)
    history.setdefault('val_accuracy', []).append(val_acc)

    print(f"[Train] Epoch : {epoch:^3}"
          f"  Train Loss: {epoch_loss:.4}"
          f"  Train Acc: {epoch_acc:.4}"
          f"  Valid Loss: {val_loss:.4}"
          f"  Valid Acc: {val_acc:.4}")

    if val_acc > best["acc"]:
        # state_dict() returns references to the live parameters, which later
        # epochs keep updating; clone them so the snapshot stays frozen.
        best["state"] = {
            name: tensor.detach().clone()
            for name, tensor in tabtransformer.state_dict().items()
        }
        best["acc"] = val_acc
        best["epoch"] = epoch
    if early_stopping.validate(val_loss):
        break

[Train] Epoch :  1   Train Loss: 0.4034  Train Acc: 80.9  Valid Loss: 0.3334  Valid Acc: 84.64
[Train] Epoch :  2   Train Loss: 0.3411  Train Acc: 84.07  Valid Loss: 0.3298  Valid Acc: 84.34
[Train] Epoch :  3   Train Loss: 0.3351  Train Acc: 84.36  Valid Loss: 0.3233  Valid Acc: 84.8
[Train] Epoch :  4   Train Loss: 0.3299  Train Acc: 84.58  Valid Loss: 0.3233  Valid Acc: 84.78
[Train] Epoch :  5   Train Loss: 0.3291  Train Acc: 84.58  Valid Loss: 0.3204  Valid Acc: 85.04
[Train] Epoch :  6   Train Loss: 0.3261  Train Acc: 84.65  Valid Loss: 0.3205  Valid Acc: 84.97
[Train] Epoch :  7   Train Loss: 0.3256  Train Acc: 84.69  Valid Loss: 0.3175  Valid Acc: 85.22
[Train] Epoch :  8   Train Loss: 0.3237  Train Acc: 84.68  Valid Loss: 0.3213  Valid Acc: 84.87
[Train] Epoch :  9   Train Loss: 0.3241  Train Acc: 84.9  Valid Loss: 0.3179  Valid Acc: 85.16
[Train] Epoch : 10   Train Loss: 0.3193  Train Acc: 84.94  Valid Loss: 0.3218  Valid Acc: 84.66
[Train] Epoch : 11   Train Loss: 0.3204  Tr

In [40]:
# Evaluate the TabTransformer classifier on the test loader and collect its 0/1 predictions.
losses = AverageMeter('Loss', ':.4e')
Acc = AverageMeter('Acc', ':6.2f')

tabtransformer.eval()
test_loss = 0
predic = []
with torch.no_grad():
    for features, labels, _weights in test_loader:
        target = labels.unsqueeze(1).to(device, dtype=torch.float32)
        predictions = tabtransformer(features)
        test_loss = loss_fn(predictions, target)
        # Flatten explicitly: squeeze() turns a one-row batch into a 0-d tensor,
        # whose tolist() is a bare float that extend() cannot consume.
        predic.extend(torch.round(torch.sigmoid(predictions)).cpu().reshape(-1).tolist())
        batch_size = features.size(0)
        Acc.update(binary_acc(predictions, target), batch_size)
        losses.update(test_loss.item(), batch_size)

print(f'Loss : {losses.avg}, Acc : {Acc.avg}')

Loss : 0.31577428002485913, Acc : 85.18414306640625


In [41]:
# Use the labels held by test_dataset (the rows the loader actually served, after
# its dropna()) so the label list always lines up with `predic`.
confusion_matrix(test_dataset.targets.tolist(), predic)

array([[11452,   983],
       [ 1434,  2412]])

# Compare with official TabTransformer code

In [32]:
def exists(val):
    """Return True when `val` is anything other than None (falsy values count as existing)."""
    return not (val is None)

def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`."""
    if exists(val):
        return val
    return d

In [33]:
# transformer

class Transformer(nn.Module):
    """Embedding table followed by `depth` (attention, feed-forward) layer pairs.

    Every pair consists of `Residual(PreNorm(dim, Attention(...)))` and
    `Residual(PreNorm(dim, FeedForward(...)))`, applied in that order.
    """

    def __init__(self, num_tokens, dim, depth, heads, dim_head, attn_dropout, ff_dropout):
        super().__init__()
        self.embeds = nn.Embedding(num_tokens, dim)

        def build_pair():
            # Attention block is created before the feed-forward block,
            # one pair at a time, so parameters are initialised in layer order.
            attention = Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout)))
            feed_forward = Residual(PreNorm(dim, FeedForward(dim, dropout = ff_dropout)))
            return nn.ModuleList([attention, feed_forward])

        self.layers = nn.ModuleList([build_pair() for _ in range(depth)])

    def forward(self, x):
        """Embed the token ids in `x` and run the result through every layer pair."""
        hidden = self.embeds(x)

        for attention, feed_forward in self.layers:
            hidden = feed_forward(attention(hidden))

        return hidden
# mlp

class MLP(nn.Module):
    """Stack of `nn.Linear` layers with an activation after every layer but the last.

    `dims` lists the layer widths: consecutive entries form the
    (in_features, out_features) of each linear layer. When `act` is None a
    single `nn.ReLU()` instance is created; either way the same activation
    module instance is reused between all layers.
    """

    def __init__(self, dims, act = None):
        super().__init__()
        widths = list(zip(dims[:-1], dims[1:]))
        final_index = len(widths) - 1

        stages = []
        for position, (width_in, width_out) in enumerate(widths):
            stages.append(nn.Linear(width_in, width_out))

            if position < final_index:
                # Resolve the activation once; later layers share this instance.
                if act is None:
                    act = nn.ReLU()
                stages.append(act)

        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the stacked layers to `x`."""
        return self.mlp(x)

In [34]:
# main class

class TabTransformer(nn.Module):
    """Transformer over categorical embeddings, concatenated with continuous features, then an MLP.

    All categorical columns share one embedding table: the ids of each column
    are shifted by a per-column offset so that they address disjoint rows.
    The embedded columns go through `Transformer`, are flattened, concatenated
    with the layer-normalised continuous features and passed to `MLP`.

    Keyword-only arguments:
        categories: sequence with the number of distinct values of each
            categorical column; every entry must be positive.
        num_continuous: number of continuous features.
        dim: embedding width per categorical column.
        depth, heads, dim_head, attn_dropout, ff_dropout: forwarded to `Transformer`.
        dim_out: width of the final MLP output.
        mlp_hidden_mults: multipliers applied to `input_size // 8` to obtain
            the hidden layer widths of the MLP.
        mlp_act: activation module forwarded to `MLP` (None uses its default).
        num_special_tokens: number of embedding rows reserved in front of the
            category ids.
        continuous_mean_std: optional tensor of shape `(num_continuous, 2)`
            holding the mean and standard deviation of each continuous
            feature; when given, continuous inputs are standardised with it.
    """

    def __init__(
            self,
            *,
            categories,
            num_continuous,
            dim,
            depth,
            heads,
            dim_head = 16,
            dim_out = 1,
            mlp_hidden_mults = (4, 2),
            mlp_act = None,
            num_special_tokens = 2,
            continuous_mean_std = None,
            attn_dropout = 0.,
            ff_dropout = 0.
    ):
        super().__init__()
        assert all(map(lambda n: n > 0, categories)), 'number of each category must be positive'

        # categories related calculations

        self.num_categories = len(categories)
        self.num_unique_categories = sum(categories)

        # create category embeddings table

        self.num_special_tokens = num_special_tokens
        total_tokens = self.num_unique_categories + num_special_tokens

        # for automatically offsetting unique category ids to the correct position in the categories embedding table:
        # [special, c0, c1, ...] -> cumulative sum -> drop the last entry, leaving one start offset per column

        categories_offset = F.pad(torch.tensor(list(categories)), (1, 0), value = num_special_tokens)
        categories_offset = categories_offset.cumsum(dim = -1)[:-1]
        self.register_buffer('categories_offset', categories_offset)

        # continuous

        if exists(continuous_mean_std):
            # The second column is used as a divisor in forward(), i.e. it is a standard deviation.
            assert continuous_mean_std.shape == (num_continuous, 2), f'continuous_mean_std must have a shape of ({num_continuous}, 2) where the last dimension contains the mean and standard deviation respectively'
        self.register_buffer('continuous_mean_std', continuous_mean_std)

        self.norm = nn.LayerNorm(num_continuous)
        self.num_continuous = num_continuous

        # transformer

        self.transformer = Transformer(
            num_tokens = total_tokens,
            dim = dim,
            depth = depth,
            heads = heads,
            dim_head = dim_head,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout
        )

        # mlp to logits

        input_size = (dim * self.num_categories) + num_continuous
        base_width = input_size // 8

        hidden_dimensions = [base_width * mult for mult in mlp_hidden_mults]
        all_dimensions = [input_size, *hidden_dimensions, dim_out]

        self.mlp = MLP(all_dimensions, act = mlp_act)

    def forward(self, x_categ, x_cont):
        """Compute the model output for one batch.

        Args:
            x_categ: integer tensor whose last dimension has `num_categories`
                entries, each a per-column category id.
            x_cont: tensor whose dimension 1 has `num_continuous` entries.

        Returns:
            The output of the final MLP applied to the concatenated features.

        Raises:
            AssertionError: if either input has the wrong number of columns.
        """
        assert x_categ.shape[-1] == self.num_categories, f'you must pass in {self.num_categories} values for your categories input'
        # Out-of-place add: `+=` would modify the caller's tensor, so a second
        # forward pass on the same batch would shift the ids twice.
        x_categ = x_categ + self.categories_offset

        x = self.transformer(x_categ)

        flat_categ = x.flatten(1)

        assert x_cont.shape[1] == self.num_continuous, f'you must pass in {self.num_continuous} values for your continuous input'

        if exists(self.continuous_mean_std):
            mean, std = self.continuous_mean_std.unbind(dim = -1)
            x_cont = (x_cont - mean) / std

        normed_cont = self.norm(x_cont)

        x = torch.cat((flat_categ, normed_cont), dim = -1)
        return self.mlp(x)