<a href="https://colab.research.google.com/github/edypidy/SkyElephant-not-a-FlyingElephant/blob/main/FTTransformer/FTTransformer_StudyNote_Pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
!pip install einops
from einops import repeat

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
[K     |████████████████████████████████| 41 kB 493 kB/s 
[?25hInstalling collected packages: einops
Successfully installed einops-0.6.0


# Model

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# numerical embedder

class NumericalEmbedder(nn.Module):
    """Embed each continuous feature as a learned affine map of its scalar value.

    Feature ``j`` with value ``v`` becomes ``v * weights[j] + biases[j]``, so every
    continuous column owns its own (weight, bias) pair of length ``dim``.

    Args:
        dim: embedding size of each feature token.
        num_numerical_types: number of continuous columns.
    """

    def __init__(self, dim, num_numerical_types):
        super().__init__()
        param_shape = (num_numerical_types, dim)
        self.weights = nn.Parameter(torch.randn(*param_shape))
        self.biases = nn.Parameter(torch.randn(*param_shape))

    def forward(self, x):
        """Map ``x`` of shape (..., num_numerical_types) to (..., num_numerical_types, dim)."""
        scalars = x[..., None]  # trailing singleton axis so each value broadcasts over dim
        scaled = scalars * self.weights
        return scaled + self.biases


# Feedforward

class GEGLU(nn.Module):
    """Gated GELU: split the last axis in half and gate the first half with GELU of the second."""

    def forward(self, x):
        # first half is the value path, second half is the gate
        value, gate = torch.chunk(x, 2, dim=-1)
        return value * F.gelu(gate)

class FeedForward(nn.Module):
    """Position-wise GEGLU feed-forward sublayer with a residual connection.

    Computes ``x + norm(Layer2(Layer1(x)))``, where ``Layer1`` is
    LayerNorm -> Linear(in_dim, 2 * hidden) -> GEGLU -> Dropout and ``Layer2``
    projects ``hidden = in_dim * hidden_mult`` back to ``in_dim``.

    Args:
        in_dim: size of the input's last axis (token embedding size).
        hidden_mult: width multiplier of the hidden layer.
        dropout: dropout probability applied after the GEGLU activation.
    """

    def __init__(self, in_dim, hidden_mult = 4, dropout = 0.):
        super().__init__()
        hidden_dim = in_dim * hidden_mult
        # GEGLU halves the last axis, so the first projection must produce 2 * hidden_dim
        self.Layer1 = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, hidden_dim * 2),
            GEGLU(),
            nn.Dropout(dropout),
        )
        self.Layer2 = nn.Linear(hidden_dim, in_dim)
        self.norm = nn.LayerNorm(in_dim)

    def forward(self, x):
        # NOTE(review): Layer1 starts with a LayerNorm, and the sublayer output is normalised
        # again before the residual add. This differs from a plain pre-norm block x + f(norm(x)).
        transformed = self.norm(self.Layer2(self.Layer1(x)))
        return transformed + x


# Attention

class SelfAttention(nn.Module):
    """Multi-head self-attention sublayer with a residual connection.

    Computes ``x + norm(MHA(x, x, x))`` with ``nn.MultiheadAttention(batch_first=True)``,
    so inputs are (batch, tokens, embed_dim) and the output has the same shape.

    Args:
        embed_dim: token embedding size; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability on the attention weights.
    """

    def __init__(self, embed_dim, num_heads=8, dropout=0.):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, batch_first=True)

    def forward(self, x):
        # Only the attended values are used. Skip building the head-averaged attention
        # weight matrix that nn.MultiheadAttention returns by default.
        attended, _ = self.attn(x, x, x, need_weights=False)
        # NOTE(review): normalisation is applied to the attention output, not its input
        # (post-sublayer norm). This differs from a pre-norm block x + attn(norm(x)).
        return self.norm(attended) + x


# Transformer

class Transformer(nn.Module):
    """Stack of ``depth`` (SelfAttention, FeedForward) blocks applied in sequence.

    Each sublayer adds its own residual, so ``forward`` simply chains them.

    Args:
        embed_dim: token embedding size.
        depth: number of attention + feed-forward blocks.
        num_heads: attention heads per block.
        attn_dropout: dropout inside each SelfAttention.
        ff_dropout: dropout inside each FeedForward.
    """

    def __init__(self, embed_dim, depth, num_heads, attn_dropout, ff_dropout):
        super().__init__()
        # Built block by block (attention, then feed-forward), so parameter names are
        # layers.<i>.0 / layers.<i>.1
        self.layers = nn.ModuleList(
            nn.ModuleList([
                SelfAttention(embed_dim, num_heads=num_heads, dropout=attn_dropout),
                FeedForward(embed_dim, dropout=ff_dropout),
            ])
            for _ in range(depth)
        )

    def forward(self, x):
        for attention, feed_forward in self.layers:
            x = feed_forward(attention(x))
        return x

In [94]:
class FTTransformer(nn.Module):
    """FT-Transformer for tabular data with categorical and continuous columns.

    Every categorical column is embedded through one shared ``nn.Embedding`` table.
    Columns are kept apart by per-column row offsets. Every continuous column goes
    through ``NumericalEmbedder``. A learned CLS token is prepended, the token sequence
    runs through ``Transformer``, and the CLS output is mapped to ``dim_out`` values.

    Args:
        categories: number of unique values of each categorical column (all > 0).
        num_continuous: number of continuous columns.
        embed_dim: size of every token embedding.
        depth: number of attention + feed-forward blocks.
        heads: attention heads; ``embed_dim`` must be divisible by ``heads``.
        dim_out: output size per sample.
        num_special_tokens: embedding rows reserved before the first column's rows.
        attn_dropout: dropout inside attention.
        ff_dropout: dropout inside the feed-forward layers.
    """

    def __init__(self, *,
        categories,
        num_continuous,
        embed_dim = 16,
        depth = 2,
        heads = 8,
        dim_out = 1,
        num_special_tokens = 2,
        attn_dropout = 0.,
        ff_dropout = 0.):

        super().__init__()
        # A column with zero possible values has no rows of its own in the shared table.
        assert all(n > 0 for n in categories), 'number of each category must be positive'

        # Treat Categories

        self.num_categories = len(categories)
        self.num_unique_categories = sum(categories)

        # Create category embeddings table

        # The first `num_special_tokens` rows are not reached by non-negative offset ids
        # (intended as NA / special tokens).
        self.num_special_tokens = num_special_tokens
        total_tokens = self.num_unique_categories + num_special_tokens
        self.categorical_embeds = nn.Embedding(total_tokens, embed_dim)  # lookup table: total_tokens x embed_dim

        # Start row of each column in the shared table: [s, s + c0, s + c0 + c1, ...] with
        # s = num_special_tokens, so every (column, value) pair maps to a distinct row.
        categories_offset = F.pad(torch.tensor(list(categories)), (1, 0), value = num_special_tokens)
        categories_offset = categories_offset.cumsum(dim = -1)[:-1]
        self.register_buffer('categories_offset', categories_offset)  # module state, not a learnable parameter

        # Treat Continuous

        self.numerical_embedder = NumericalEmbedder(embed_dim, num_continuous)

        # cls token

        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))

        # Transformer

        self.transformer = Transformer(embed_dim=embed_dim,
                                       depth=depth,
                                       num_heads=heads,
                                       attn_dropout=attn_dropout,
                                       ff_dropout=ff_dropout,
                                       )

        # To logits: linear(relu(ln(cls))) as in the paper

        self.to_logits = nn.Sequential(nn.LayerNorm(embed_dim),
                                       nn.ReLU(),
                                       nn.Linear(embed_dim, dim_out)
                                       )

    def forward(self, x_categ, x_numer):
        """Run the model on one batch.

        Args:
            x_categ: integer tensor (batch, len(categories)). Column j is expected to hold
                a raw id in [0, categories[j]). The tensor is NOT modified.
            x_numer: continuous features, presumably (batch, num_continuous).

        Returns:
            Tensor of shape (batch, dim_out).
        """
        b = x_categ.shape[0]  # batch size

        assert x_categ.shape[-1] == self.num_categories, f'you must pass in {self.num_categories} values for your categories input'
        # Out-of-place add: `x_categ += ...` would mutate the caller's tensor, so running
        # the model twice on the same batch would apply the offsets twice.
        x_categ = x_categ + self.categories_offset

        x_categ = self.categorical_embeds(x_categ)  # (batch, num_categories, embed_dim)

        # add numerically embedded tokens

        x_numer = self.numerical_embedder(x_numer)

        # concat categorical and numerical along the token axis

        x = torch.cat((x_categ, x_numer), dim = 1)

        # prepend one CLS token per sample

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)

        # attend

        x = self.transformer(x)

        # only the CLS position feeds the output head

        x = x[:, 0]

        return self.to_logits(x)

# GEGLU

In [86]:
# GEGLU by hand: split a length-10 vector into value/gate halves and gate the values with GELU
x = torch.randn(10)
x, gate = torch.chunk(x, 2, dim=-1)
F.gelu(gate) * x

tensor([ 0.1978, -0.1254, -0.0500,  0.1201,  0.0209])

# Category embedding table & Offset encoding

In [58]:
categories = (2, 3, 4, 5, 6)
# one row per (column, value) pair, plus 2 reserved special-token rows
em = nn.Embedding(num_embeddings=sum(categories) + 2, embedding_dim=4)
em.weight

Parameter containing:
tensor([[-0.1422, -0.0644, -0.2687, -0.3519],
        [-0.1310, -0.7452, -0.0613, -0.0983],
        [-0.6265,  1.6764,  2.5666,  0.0101],
        [ 1.1247, -0.1997, -0.7096, -0.4328],
        [-0.7631,  0.0071,  0.1174,  1.2644],
        [-0.2029, -0.7141,  1.2456, -1.0946],
        [ 0.7132, -0.0879,  0.8890, -0.1016],
        [-1.2408,  0.6409, -1.3812,  1.7102],
        [ 0.1079,  0.0172, -0.3947,  0.1761],
        [ 1.1205, -0.8397,  0.2906,  0.0784],
        [-0.3582, -0.1701, -0.2558, -1.2307],
        [ 0.0965, -1.2816, -2.9810,  1.9548],
        [ 0.1950, -0.0039, -1.6652, -0.2576],
        [ 2.2781, -0.7847,  0.7122,  1.2124],
        [ 0.8438,  0.3787, -2.6487,  3.0309],
        [ 1.2570, -0.1070,  0.6101,  1.7393],
        [-0.3522,  1.4458, -0.2538,  1.2323],
        [ 1.8391,  1.2671, -0.8121,  1.0957],
        [-0.7916, -1.7335, -0.7643,  0.7543],
        [ 0.3516,  0.2313,  0.7469,  0.3150],
        [-0.2621,  1.8219,  1.4400,  0.9879],
        [ 2.

In [59]:
# Start row of each categorical column in the shared table. Prepend the 2 special-token
# rows, cumulative-sum the column sizes, and drop the final total (same code as in FTTransformer.__init__).
categories_offset = F.pad(torch.tensor(list(categories)), (1, 0), value = 2)
categories_offset = categories_offset.cumsum(dim = -1)[:-1]
categories_offset

tensor([ 2,  4,  7, 11, 16])

In [66]:
# Two example rows; column j holds a raw category id in [0, categories[j]).
x_categ = torch.tensor([[0,2,2,3,5],
                        [1,2,3,4,4]])
# Shift each column's ids into that column's rows of the shared table.
# (The in-place add is harmless here because the tensor is rebuilt every time this cell runs.)
x_categ += categories_offset
x_categ

tensor([[ 2,  6,  9, 14, 21],
        [ 3,  6, 10, 15, 20]])

In [67]:
# Look up every shifted id -> (batch, num_categories, embed_dim) = (2, 5, 4).
# NOTE(review): rebinds x_categ from integer ids to float embeddings, so re-running this cell alone fails.
x_categ = em(x_categ)
x_categ

tensor([[[-0.6265,  1.6764,  2.5666,  0.0101],
         [ 0.7132, -0.0879,  0.8890, -0.1016],
         [ 1.1205, -0.8397,  0.2906,  0.0784],
         [ 0.8438,  0.3787, -2.6487,  3.0309],
         [ 2.8616,  1.8804,  1.7830,  1.1806]],

        [[ 1.1247, -0.1997, -0.7096, -0.4328],
         [ 0.7132, -0.0879,  0.8890, -0.1016],
         [-0.3582, -0.1701, -0.2558, -1.2307],
         [ 1.2570, -0.1070,  0.6101,  1.7393],
         [-0.2621,  1.8219,  1.4400,  0.9879]]], grad_fn=<EmbeddingBackward0>)

In [68]:
x_categ.size()

torch.Size([2, 5, 4])

# Numerical embedding

In [69]:
# 6 continuous columns, each embedded into a 4-dim token
nem = NumericalEmbedder(dim=4, num_numerical_types=6)

In [71]:
# Continuous features are real-valued, so build them as floats instead of relying on
# int64 -> float promotion inside NumericalEmbedder (the resulting embeddings are identical).
x_numeric = torch.tensor([[1., 2., 3., 4., 5., 6.],
                          [7., 8., 9., 10., 11., 12.]])
x_numeric = nem(x_numeric)  # (batch, num_continuous, embed_dim) = (2, 6, 4)
x_numeric

tensor([[[  3.0601,   0.9302,   0.4125,   0.9413],
         [ -3.4176,  -0.5423,   2.9324,  -0.9961],
         [ -0.3026,   0.6032,  -2.0968,   0.1400],
         [  0.7046,   1.7296,   7.3469,  -4.1400],
         [ -0.9279,  -6.4806,  -5.2516,   2.1788],
         [  2.8452,  -2.9124,  -2.3786,  -0.4056]],

        [[  8.8174,   8.0591,   4.0979,  -3.1194],
         [ -8.4295,  -5.6186,  10.4920,  -4.2528],
         [ -0.6172,   2.5088,  -4.8487,  -0.9211],
         [  0.2880,   4.7461,  20.1742,  -9.9391],
         [ -1.9728, -11.7453, -11.0497,   4.9542],
         [  7.3594,  -6.0567,  -3.6522,   0.0626]]], grad_fn=<AddBackward0>)

In [72]:
x_numeric.size()

torch.Size([2, 6, 4])

# Cat(x_categ, x_numer)

In [75]:
# Concatenate along the token axis: 5 categorical + 6 continuous tokens -> (2, 11, 4)
x = torch.cat((x_categ, x_numeric), dim = 1)
x

tensor([[[-6.2654e-01,  1.6764e+00,  2.5666e+00,  1.0077e-02],
         [ 7.1324e-01, -8.7891e-02,  8.8896e-01, -1.0161e-01],
         [ 1.1205e+00, -8.3974e-01,  2.9055e-01,  7.8428e-02],
         [ 8.4383e-01,  3.7866e-01, -2.6487e+00,  3.0309e+00],
         [ 2.8616e+00,  1.8804e+00,  1.7830e+00,  1.1806e+00],
         [ 3.0601e+00,  9.3024e-01,  4.1254e-01,  9.4132e-01],
         [-3.4176e+00, -5.4230e-01,  2.9324e+00, -9.9614e-01],
         [-3.0264e-01,  6.0320e-01, -2.0968e+00,  1.3997e-01],
         [ 7.0458e-01,  1.7296e+00,  7.3469e+00, -4.1400e+00],
         [-9.2787e-01, -6.4806e+00, -5.2516e+00,  2.1788e+00],
         [ 2.8452e+00, -2.9124e+00, -2.3786e+00, -4.0556e-01]],

        [[ 1.1247e+00, -1.9965e-01, -7.0964e-01, -4.3275e-01],
         [ 7.1324e-01, -8.7891e-02,  8.8896e-01, -1.0161e-01],
         [-3.5822e-01, -1.7007e-01, -2.5581e-01, -1.2307e+00],
         [ 1.2570e+00, -1.0699e-01,  6.1011e-01,  1.7393e+00],
         [-2.6215e-01,  1.8219e+00,  1.4400e+00,  9.8

In [76]:
x.size()

torch.Size([2, 11, 4])

# Cls Tokens

In [78]:
# Read batch size and embedding width from x so they can never disagree with the data.
b = x.shape[0]  # batch size
cls_token = nn.Parameter(torch.randn(1, 1, x.shape[-1]))
cls_tokens = repeat(cls_token, '1 1 d -> b 1 d', b = b)
# NOTE(review): rebinds x, so running this cell twice prepends a second CLS token.
x = torch.cat((cls_tokens, x), dim = 1)
x

tensor([[[-6.8002e-01, -2.5018e-01,  2.7479e-01,  1.7927e-01],
         [-6.2654e-01,  1.6764e+00,  2.5666e+00,  1.0077e-02],
         [ 7.1324e-01, -8.7891e-02,  8.8896e-01, -1.0161e-01],
         [ 1.1205e+00, -8.3974e-01,  2.9055e-01,  7.8428e-02],
         [ 8.4383e-01,  3.7866e-01, -2.6487e+00,  3.0309e+00],
         [ 2.8616e+00,  1.8804e+00,  1.7830e+00,  1.1806e+00],
         [ 3.0601e+00,  9.3024e-01,  4.1254e-01,  9.4132e-01],
         [-3.4176e+00, -5.4230e-01,  2.9324e+00, -9.9614e-01],
         [-3.0264e-01,  6.0320e-01, -2.0968e+00,  1.3997e-01],
         [ 7.0458e-01,  1.7296e+00,  7.3469e+00, -4.1400e+00],
         [-9.2787e-01, -6.4806e+00, -5.2516e+00,  2.1788e+00],
         [ 2.8452e+00, -2.9124e+00, -2.3786e+00, -4.0556e-01]],

        [[-6.8002e-01, -2.5018e-01,  2.7479e-01,  1.7927e-01],
         [ 1.1247e+00, -1.9965e-01, -7.0964e-01, -4.3275e-01],
         [ 7.1324e-01, -8.7891e-02,  8.8896e-01, -1.0161e-01],
         [-3.5822e-01, -1.7007e-01, -2.5581e-01, -1.2

In [79]:
x.size()

torch.Size([2, 12, 4])

# Transformer

In [80]:
# Tiny 2-block transformer for the 4-dim demo tokens; num_heads must divide embed_dim.
trfm = Transformer(embed_dim=4,
                   depth=2,
                   num_heads=1,
                   attn_dropout=0.,
                   ff_dropout=0.,
                   )

In [82]:
# Shape is preserved: (2, 12, 4) in, (2, 12, 4) out.
# NOTE(review): rebinds x, so running this cell twice applies the transformer twice.
x = trfm(x)
x

tensor([[[ 1.7170e+00,  6.0308e-01, -7.7584e-01, -2.0204e+00],
         [-3.1903e-03,  4.8822e+00,  7.0637e-01, -1.9588e+00],
         [ 2.5718e+00,  1.6446e+00, -1.7760e-01, -2.6262e+00],
         [ 3.5948e+00,  7.6540e-01, -6.5508e-01, -3.0554e+00],
         [ 4.3630e+00, -1.5198e+00, -6.1122e-01, -6.2730e-01],
         [ 6.7756e+00,  1.8411e+00,  2.1694e+00, -3.0805e+00],
         [ 6.6662e+00, -3.9355e-01,  2.1617e+00, -3.0901e+00],
         [-1.3417e+00,  1.5901e+00,  1.0613e+00, -3.3334e+00],
         [ 3.1904e+00, -1.0991e+00, -4.7316e-01, -3.2745e+00],
         [ 3.9085e+00,  3.5467e+00,  4.3235e+00, -6.1377e+00],
         [ 1.1349e+00, -5.5581e+00, -5.1775e+00, -8.8057e-01],
         [ 5.7184e+00, -2.2019e+00, -2.4690e+00, -3.8989e+00]],

        [[ 1.6682e+00,  1.4527e+00, -3.0288e+00, -5.6816e-01],
         [ 5.6886e+00, -5.2240e-01, -1.1869e+00, -4.1966e+00],
         [ 5.2694e+00, -6.0032e-02,  1.3743e-01, -3.9340e+00],
         [ 1.5544e+00,  2.5287e+00, -3.8554e+00, -2.2

In [83]:
x.size()

torch.Size([2, 12, 4])

# Get cls token

In [87]:
# Keep only the CLS position (token 0) -> (batch, embed_dim)
# NOTE(review): rebinds x; re-running this cell would index away another axis.
x = x[:, 0]
x

tensor([[ 1.7170,  0.6031, -0.7758, -2.0204],
        [ 1.6682,  1.4527, -3.0288, -0.5682]], grad_fn=<SelectBackward0>)

In [88]:
x.size()

torch.Size([2, 4])

# To logits (output head)

In [92]:
dim_out = 1
# output head as in the paper: Linear(ReLU(LayerNorm(cls)))
to_logits = nn.Sequential(nn.LayerNorm(4), nn.ReLU(), nn.Linear(4, dim_out))

In [93]:
# One logit per sample: (2, 4) -> (2, 1)
output = to_logits(x)
output

tensor([[-0.9282],
        [-0.8931]], grad_fn=<AddmmBackward0>)

# APPENDIX 1 : Self Attention

In [65]:
# Self-attention uses the same tensor as query, key and value:
#   input  : bs x d_L x embed_dim
#   output : bs x d_L x embed_dim (token count and width are preserved)
bs = 16
embed_dim = 32
d_L = 13  # number of tokens (columns)

q = torch.randn(bs, d_L, embed_dim)
attn = SelfAttention(embed_dim=embed_dim, num_heads=8, dropout=0.)
attn(q).shape

torch.Size([16, 13, 32])

# APPENDIX 2 : Transformer

In [70]:
# The stacked transformer preserves shape: (1, 15, embed_dim) in and out (embed_dim comes from the previous cell).
q = torch.randn(1, 15, embed_dim)
trfm = Transformer(embed_dim=embed_dim, depth=3, num_heads=8, attn_dropout=0.1, ff_dropout=0.1)
trfm(q).shape

torch.Size([1, 15, 32])

In [82]:
bs = 1
embed_dim = 32
col_dim = 8

q = torch.randn(bs, col_dim)
nembd = NumericalEmbedder(embed_dim, col_dim)
q = nembd(q)  # bs x col_dim x embed_dim

# Both sublayers preserve the token shape (bs x col_dim x embed_dim).
num_heads = 16  # must satisfy embed_dim % num_heads == 0
attn = SelfAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=0.)
# Only a cell's last expression is displayed, so check the attention output explicitly.
assert attn(q).shape == (bs, col_dim, embed_dim)

fdfd = FeedForward(embed_dim, hidden_mult=4, dropout = 0.)
fdfd(q).shape  # bs x col_dim x embed_dim

torch.Size([1, 8, 32])

# APPENDIX 3 : FTTransformer GitHub

https://github.com/lucidrains/tab-transformer-pytorch/blob/main/tab_transformer_pytorch/ft_transformer.py

In [5]:
import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat

# feedforward and attention

class GEGLU(nn.Module):
    """Reference (upstream) gated GELU: first half of the last axis gated by GELU of the second half.

    NOTE: running this cell rebinds the name GEGLU defined in the Model section.
    """
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

def FeedForward(dim, mult = 4, dropout = 0.):
    """Build the reference GEGLU feed-forward as an ``nn.Sequential``.

    LayerNorm -> Linear(dim, 2*dim*mult) -> GEGLU -> Dropout -> Linear(dim*mult, dim).
    No residual here; the reference Transformer adds it around this sublayer.
    NOTE: running this cell rebinds FeedForward (a class in the Model section) to this function.
    """
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, dim * mult * 2),  # doubled width because GEGLU halves it
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(dim * mult, dim)
    )

class Attention(nn.Module):
    """Reference multi-head self-attention (pre-norm; the residual is added by the caller).

    Args:
        dim: token embedding size.
        heads: number of attention heads.
        dim_head: per-head size; the projected width is heads * dim_head.
        dropout: dropout on the attention probabilities.
    """
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5  # 1/sqrt(dim_head) scaling of the dot products

        self.norm = nn.LayerNorm(dim)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)  # fused Q, K, V projection
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        h = self.heads

        x = self.norm(x)  # normalise the input before attending (pre-norm)

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        # split the fused width into heads: (b, n, h*d) -> (b, h, n, d)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
        q = q * self.scale

        sim = einsum('b h i d, b h j d -> b h i j', q, k)  # token-to-token scores per head

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)  # merge heads back
        return self.to_out(out)

# transformer

class Transformer(nn.Module):
    """Reference transformer: ``depth`` (Attention, FeedForward) blocks, each wrapped in a residual.

    NOTE: running this cell rebinds Transformer from the Model section, which has a different signature.
    """
    def __init__(
        self,
        dim,
        depth,
        heads,
        dim_head,
        attn_dropout,
        ff_dropout
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout),
                FeedForward(dim, dropout = ff_dropout),
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            # the residual is added here, around each (pre-normed) sublayer
            x = attn(x) + x
            x = ff(x) + x

        return x

# numerical embedder

class NumericalEmbedder(nn.Module):
    """Reference continuous-feature embedder: value * weights[j] + biases[j] for each column j.

    Same math as the Model-section NumericalEmbedder. Here the trailing axis is added with
    einops ``rearrange``, whose pattern requires a rank-2 (b, n) input.
    """
    def __init__(self, dim, num_numerical_types):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(num_numerical_types, dim))
        self.biases = nn.Parameter(torch.randn(num_numerical_types, dim))

    def forward(self, x):
        x = rearrange(x, 'b n -> b n 1')
        return x * self.weights + self.biases

# main class

class FTTransformer(nn.Module):
    """Reference (upstream) FT-Transformer, kept for comparison with the Model section.

    Categorical columns share one embedding table (separated by per-column offsets),
    continuous columns go through NumericalEmbedder, a CLS token is prepended, and
    the CLS output of the transformer is mapped to ``dim_out`` values.
    NOTE: running this cell rebinds FTTransformer from the Model section.
    """
    def __init__(
        self,
        *,
        categories,
        num_continuous,
        dim,
        depth,
        heads,
        dim_head = 16,
        dim_out = 1,
        num_special_tokens = 2,
        attn_dropout = 0.,
        ff_dropout = 0.
    ):
        super().__init__()
        assert all(map(lambda n: n > 0, categories)), 'number of each category must be positive'

        # categories related calculations

        self.num_categories = len(categories)
        self.num_unique_categories = sum(categories)

        # create category embeddings table

        self.num_special_tokens = num_special_tokens
        total_tokens = self.num_unique_categories + num_special_tokens

        # for automatically offsetting unique category ids to the correct position in the categories embedding table

        categories_offset = F.pad(torch.tensor(list(categories)), (1, 0), value = num_special_tokens)
        categories_offset = categories_offset.cumsum(dim = -1)[:-1]
        self.register_buffer('categories_offset', categories_offset)

        # categorical embedding

        self.categorical_embeds = nn.Embedding(total_tokens, dim)

        # continuous

        self.numerical_embedder = NumericalEmbedder(dim, num_continuous)

        # cls token

        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        # transformer

        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            heads = heads,
            dim_head = dim_head,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout
        )

        # to logits

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.ReLU(),
            nn.Linear(dim, dim_out)
        )

    def forward(self, x_categ, x_numer):
        """Return (batch, dim_out) outputs; ``x_categ`` (raw per-column ids) is not modified."""
        b = x_categ.shape[0]

        assert x_categ.shape[-1] == self.num_categories, f'you must pass in {self.num_categories} values for your categories input'
        # Out-of-place add: `x_categ += ...` would mutate the caller's tensor, so running
        # the model twice on the same batch would apply the offsets twice.
        x_categ = x_categ + self.categories_offset

        x_categ = self.categorical_embeds(x_categ)

        # add numerically embedded tokens

        x_numer = self.numerical_embedder(x_numer)

        # concat categorical and numerical

        x = torch.cat((x_categ, x_numer), dim = 1)

        # append cls tokens

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)

        # attend

        x = self.transformer(x)

        # get cls token

        x = x[:, 0]

        # out in the paper is linear(relu(ln(cls)))

        return self.to_logits(x)

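# To Do

* **TODO:** Why does this model need a CLS token? (Starting point: only the transformer output at the CLS position, `x[:, 0]`, is passed to `to_logits`, so the CLS token serves as a learned summary of all feature tokens.)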