<a href="https://colab.research.google.com/github/edypidy/SkyElephant-not-a-FlyingElephant/blob/main/CustomModel/BiconFTTransformer_CustomModel_Pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install einops
from einops import repeat

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


# Model

In [69]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# numerical embedder

class NumericalEmbedder(nn.Module):
    """Embed numerical features with a per-feature learned scale and shift.

    Each of the ``num_numerical_types`` features owns one weight row and one
    bias row of length ``dim``; a feature value ``v`` is mapped to
    ``v * weight_row + bias_row``.
    """

    def __init__(self, dim, num_numerical_types):
        super().__init__()
        table_shape = (num_numerical_types, dim)
        # Weights are drawn before biases, which fixes the random draws each one gets.
        self.weights = nn.Parameter(torch.randn(*table_shape))
        self.biases = nn.Parameter(torch.randn(*table_shape))

    def forward(self, x):
        """Return ``x[..., None] * weights + biases``.

        Assumes the last axis of ``x`` indexes the numerical features
        (e.g. batch x num_numerical_types) -- TODO confirm against callers.
        """
        # The new trailing axis lets every value broadcast over its feature's row.
        scaled = x[..., None] * self.weights
        return scaled + self.biases


# Feedforward

class GEGLU(nn.Module):
    """Gated GELU: halve the last axis and gate one half with GELU of the other."""

    def forward(self, x):
        # First half carries the values, second half becomes the gate.
        values, gate_logits = x.chunk(2, dim=-1)
        activated_gate = F.gelu(gate_logits)
        return values * activated_gate

class FeedForward(nn.Module):
    """GEGLU feed-forward block with an output LayerNorm and a residual add.

    ``Layer1`` normalises the input, projects it to ``2 * in_dim * hidden_mult``
    features, applies GEGLU (which halves that width) and dropout. ``Layer2``
    projects back to ``in_dim``; the result is normalised and added to the input.
    """

    def __init__(self, in_dim, hidden_mult = 4, dropout = 0.):
        super().__init__()
        hidden_dim = in_dim * hidden_mult
        # The expansion is twice as wide as hidden_dim because GEGLU splits it in two.
        expansion = [
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, hidden_dim * 2),
            GEGLU(),
            nn.Dropout(dropout),
        ]
        self.Layer1 = nn.Sequential(*expansion)
        self.Layer2 = nn.Linear(hidden_dim, in_dim)
        self.norm = nn.LayerNorm(in_dim)

    def forward(self, x):
        hidden = self.Layer1(x)
        projected = self.Layer2(hidden)
        # The residual is added after the LayerNorm, so the skip path is un-normalised.
        return self.norm(projected) + x


# Attention for Binary Conditions

class Attention(nn.Module):
    """Multi-head attention with separate key/value inputs, LayerNorm and residual.

    The query ``x`` attends to ``k`` / ``v``; the attention output is
    normalised and then added back onto ``x``.
    """

    def __init__(self, embed_dim, num_heads=8, dropout=0.):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

    def forward(self, x, k, v):
        # MultiheadAttention returns (output, weights); only the output is kept.
        attended, _ = self.attn(x, k, v)
        return self.norm(attended) + x


# Self Attention

class SelfAttention(nn.Module):
    """Multi-head self-attention followed by LayerNorm and a residual add."""

    def __init__(self, embed_dim, num_heads=8, dropout=0.):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

    def forward(self, x):
        # The same tensor serves as query, key and value; the weights are discarded.
        attended, _ = self.attn(x, x, x)
        return self.norm(attended) + x


# Transformer

class Transformer(nn.Module):
    """Stack of ``depth`` (SelfAttention, FeedForward) pairs applied in sequence."""

    def __init__(self, embed_dim, depth, num_heads, attn_dropout, ff_dropout):
        super().__init__()
        # Each entry is a two-element ModuleList: [self-attention, feed-forward].
        stacked = [
            nn.ModuleList([
                SelfAttention(embed_dim, num_heads=num_heads, dropout=attn_dropout),
                FeedForward(embed_dim, dropout=ff_dropout),
            ])
            for _ in range(depth)
        ]
        self.layers = nn.ModuleList(stacked)

    def forward(self, x):
        for block in self.layers:
            attn, ff = block
            x = ff(attn(x))
        return x

In [70]:
class BiconFTTransformer(nn.Module):
    """FT-Transformer-style tabular model conditioned on binary flags.

    Categorical and numerical columns are embedded into tokens, prefixed with
    a learned CLS token and passed through a self-attention ``Transformer``.
    The resulting tokens then attend to embeddings of the binary conditions,
    go through one more feed-forward block, and the CLS position is projected
    to ``dim_out`` outputs.

    Args:
        categories: number of distinct values of each categorical column.
        num_continuous: number of numerical columns.
        num_bicons: number of binary-condition inputs.
        embed_dim: width of every token embedding.
        depth: number of (self-attention, feed-forward) layers.
        heads: number of attention heads; ``embed_dim`` must be divisible by it.
        dim_out: size of the output vector per sample.
        num_special_tokens: embedding rows reserved in front of the real
            categories.
        attn_dropout: dropout probability inside the attention modules.
        ff_dropout: dropout probability inside the feed-forward modules.
    """

    def __init__(self, *,
        categories,
        num_continuous,
        num_bicons, # Number of Binary Conditions (Input)
        embed_dim = 16,
        depth = 2,
        heads = 8,
        dim_out = 1,
        num_special_tokens = 2,
        attn_dropout = 0.,
        ff_dropout = 0.):

        super().__init__()

        # Treat Categories

        self.num_categories = len(categories)
        self.num_unique_categories = sum(categories)

        # Create category embeddings table

        # The per-column offsets start at num_special_tokens, so the first
        # 'num_special_tokens' rows of the table are never hit by real categories.
        self.num_special_tokens = num_special_tokens
        total_tokens = self.num_unique_categories + num_special_tokens
        # Lookup table: total_tokens x embed_dim
        self.categorical_embeds = nn.Embedding(total_tokens, embed_dim)

        # Per-column offsets into the shared table: prepend num_special_tokens,
        # cumsum, drop the last entry -> first row of every column's slice.
        # The explicit long dtype keeps the offsets integral even when
        # `categories` is empty.
        category_sizes = torch.tensor(list(categories), dtype=torch.long)
        categories_offset = F.pad(category_sizes, (1, 0), value = num_special_tokens)
        categories_offset = categories_offset.cumsum(dim = -1)[:-1]
        # Registered as a buffer so it follows the module's device but is not trained.
        self.register_buffer('categories_offset', categories_offset)


        # Treat Continuous

        self.numerical_embedder = NumericalEmbedder(embed_dim, num_continuous)


        # Treat Binary Condition

        # Two rows (value 0 / value 1) per binary condition; condition i owns rows 2i and 2i+1.
        self.bicon_embeds = nn.Embedding(2*num_bicons, embed_dim)
        bicon_offset = torch.arange(0,2*num_bicons,2)
        self.register_buffer('bicon_offset', bicon_offset) # not trained


        # cls token

        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))


        # FeedForward & Attention for Bicon

        # These use the constructor's heads / dropout settings rather than fixed
        # values, so any embed_dim divisible by `heads` works (a fixed 8 heads
        # would reject e.g. embed_dim=12, heads=4). Defaults are unchanged.
        self.feedfoward = FeedForward(in_dim=embed_dim,
                                      hidden_mult = 4,
                                      dropout = ff_dropout)

        self.attention = Attention(embed_dim=embed_dim,
                                   num_heads=heads,
                                   dropout=attn_dropout)


        # Transformer

        self.transformer = Transformer(embed_dim=embed_dim,
                                       depth=depth,
                                       num_heads=heads,
                                       attn_dropout=attn_dropout,
                                       ff_dropout=ff_dropout,
                                       )


        # To logits

        self.to_logits = nn.Sequential(nn.LayerNorm(embed_dim),
                                       nn.ReLU(),
                                       nn.Linear(embed_dim, dim_out)
                                       )

    def forward(self, x_categ, x_numer, x_bicon):
        """Compute the model output for one batch.

        Args:
            x_categ: integer category indices, one column per entry of
                ``categories`` (last axis must have ``num_categories`` entries).
            x_numer: numerical features, forwarded to the numerical embedder.
            x_bicon: binary condition values; presumably 0/1 integers, one per
                condition -- values outside {0, 1} would index another
                condition's rows.

        Returns:
            Tensor produced by ``to_logits`` from the CLS position
            (batch x dim_out).

        The input tensors are not modified.
        """
        b = x_categ.shape[0] # batch size

        assert x_categ.shape[-1] == self.num_categories, f'you must pass in {self.num_categories} values for your categories input'

        # Out-of-place add: `+=` would write the offsets into the caller's
        # tensor, so a second call with the same input would index wrong rows.
        x_categ = x_categ + self.categories_offset

        # Lookup => batch x num_categories x embed_dim
        x_categ = self.categorical_embeds(x_categ)

        # add numerically embedded tokens

        x_numer = self.numerical_embedder(x_numer)

        # concat categorical and numerical

        x = torch.cat((x_categ, x_numer), dim = 1)

        # Prepend one CLS token per sample. expand() broadcasts the single
        # learned token over the batch without copying; cat() materialises it.
        cls_tokens = self.cls_token.expand(b, -1, -1)
        x = torch.cat((cls_tokens, x), dim = 1)

        # Tabular transformer

        x = self.transformer(x)

        # bicon: tokens are the queries, condition embeddings are keys and values

        x_bicon = x_bicon + self.bicon_offset # out-of-place for the same reason as x_categ
        x_bicon = self.bicon_embeds(x_bicon)
        x = self.attention(x, x_bicon, x_bicon)
        x = self.feedfoward(x)

        # get cls token

        x = x[:, 0]

        return self.to_logits(x)

In [75]:
# Toy configuration: five categorical columns, six numerical columns and
# five binary conditions.
model = BiconFTTransformer(
    categories=(2, 3, 4, 5, 6),
    num_continuous=6,
    num_bicons=5,  # Number of Binary Conditions (Input)
    embed_dim=16,
    depth=2,
    heads=8,
    dim_out=1,
    num_special_tokens=2,
    attn_dropout=0.,
    ff_dropout=0.,
)

In [76]:
# Toy batch of two samples: 5 categorical indices, 6 numerical values and
# 5 binary flags per sample.
x_categ = torch.tensor([
    [0, 2, 2, 3, 5],
    [1, 2, 3, 4, 4],
])

x_numer = torch.tensor([
    [1, 2, 3, 4, 5, 6],
    [7, 8, 9, 10, 11, 12],
])

x_bicon = torch.tensor([
    [1, 0, 1, 1, 0],
    [0, 1, 0, 1, 1],
])

In [77]:
# Forward pass on the toy batch; with dim_out = 1 the result is one value per sample.
pred = model(x_categ, x_numer, x_bicon)
pred

tensor([[-0.1343],
        [-0.2419]], grad_fn=<AddmmBackward0>)

# GEGLU

In [5]:
# GEGLU by hand: split a length-10 vector into a value half and a gate half,
# then multiply the values by GELU(gate).
x = torch.randn(5*2)
x, gate = x.chunk(2, dim = -1)
x * F.gelu(gate)

tensor([-0.0279,  0.0397, -0.6922,  0.1041,  0.0402])

# Category embedding table & Offset encoding

In [6]:
# One shared embedding table for all categorical columns: sum(categories) rows
# for the real categories plus 2 extra rows at the front (special tokens).
categories = (2,3,4,5,6)
em = nn.Embedding(sum(categories)+2, 4)
em.weight

Parameter containing:
tensor([[-1.3067e+00,  1.2984e+00, -1.0242e+00, -1.7189e+00],
        [ 6.5156e-01,  6.2607e-01, -1.4641e+00,  8.3954e-01],
        [ 4.6663e-01,  1.3844e+00, -1.3162e+00,  1.2584e-01],
        [-1.3158e+00, -4.1214e-01,  5.9935e-01,  1.0766e+00],
        [-6.8095e-01, -5.2712e-01,  3.7826e-01,  3.9366e-01],
        [ 8.1831e-01, -5.4963e-01, -1.7696e+00, -3.1709e-01],
        [ 2.5596e-01,  1.3004e+00,  5.5314e-02,  2.1055e-01],
        [-1.4878e+00,  7.4846e-01, -1.3628e+00, -7.1766e-01],
        [-7.4579e-02,  2.0124e-01,  6.7340e-01,  8.0523e-01],
        [-8.3029e-01, -3.1396e-01, -2.5810e-01, -1.2058e-01],
        [ 8.4641e-01, -2.5466e+00,  1.3709e+00,  1.9676e+00],
        [-7.5616e-01, -6.6290e-01,  7.3224e-01,  1.8029e+00],
        [ 1.1537e+00,  1.2258e+00,  1.5889e+00,  3.6194e-01],
        [-4.6227e-01,  6.4212e-01,  1.4638e+00, -7.1945e-01],
        [-9.2603e-01,  1.0327e+00, -2.1214e-01, -1.8899e+00],
        [-8.1364e-01, -2.2753e-01,  1.5726e+00, 

In [7]:
# Prepend the number of special tokens (2), take the running sum and drop the
# last entry: the result is the first table row that belongs to each column.
categories_offset = F.pad(torch.tensor(list(categories)), (1, 0), value = 2)
categories_offset = categories_offset.cumsum(dim = -1)[:-1]
categories_offset

tensor([ 2,  4,  7, 11, 16])

In [8]:
# Shift each column's local category index by that column's offset so that
# every (column, category) pair maps to its own row of the shared table.
x_categ = torch.tensor([[0,2,2,3,5],
                        [1,2,3,4,4]])
x_categ += categories_offset
x_categ

tensor([[ 2,  6,  9, 14, 21],
        [ 3,  6, 10, 15, 20]])

In [9]:
x_categ = em(x_categ)
x_categ

tensor([[[ 4.6663e-01,  1.3844e+00, -1.3162e+00,  1.2584e-01],
         [ 2.5596e-01,  1.3004e+00,  5.5314e-02,  2.1055e-01],
         [-8.3029e-01, -3.1396e-01, -2.5810e-01, -1.2058e-01],
         [-9.2603e-01,  1.0327e+00, -2.1214e-01, -1.8899e+00],
         [ 5.4358e-01,  7.4395e-01,  3.9582e-04, -2.5787e+00]],

        [[-1.3158e+00, -4.1214e-01,  5.9935e-01,  1.0766e+00],
         [ 2.5596e-01,  1.3004e+00,  5.5314e-02,  2.1055e-01],
         [ 8.4641e-01, -2.5466e+00,  1.3709e+00,  1.9676e+00],
         [-8.1364e-01, -2.2753e-01,  1.5726e+00, -4.5318e-01],
         [-9.9903e-01, -9.2784e-02,  3.9141e-01,  2.9928e-01]]],
       grad_fn=<EmbeddingBackward0>)

In [10]:
x_categ.shape

torch.Size([2, 5, 4])

# Numerical embedding

In [11]:
nem = NumericalEmbedder(dim=4, num_numerical_types=6)

In [12]:
# Two samples with six numerical features each; the embedder turns every
# scalar into a length-4 vector (value * weight + bias).
x_numeric = torch.tensor([[1,2,3,4,5,6],
                          [7,8,9,10,11,12]])
x_numeric = nem(x_numeric)
x_numeric

tensor([[[ -1.8792,   0.4936,  -0.9040,   0.2564],
         [  0.7676,   2.1136,  -2.5433,   4.9248],
         [ -1.5781,  -1.3108,  -1.8789,  -3.2565],
         [ -0.3217,  -2.3346,   1.2943,  -5.1166],
         [ -4.1135,   8.9746,   3.7493,  -4.6333],
         [  8.1482,  -5.4053,   1.7416,  -4.0964]],

        [[ -2.3383,   3.7103,  -7.4790,   2.6496],
         [  0.4137,   9.6544, -11.0666,  17.7024],
         [ -4.5528,  -5.2269,  -9.9656, -11.5588],
         [ -2.3793,  -5.4869,   1.8560,  -9.6100],
         [-11.9771,  18.2670,   7.4803, -10.0707],
         [ 16.9347, -11.5030,   3.9160,  -8.8078]]], grad_fn=<AddBackward0>)

In [13]:
x_numeric.shape

torch.Size([2, 6, 4])

# Cat(x_categ, x_numer)

In [14]:
x = torch.cat((x_categ, x_numeric), dim = 1)
x

tensor([[[ 4.6663e-01,  1.3844e+00, -1.3162e+00,  1.2584e-01],
         [ 2.5596e-01,  1.3004e+00,  5.5314e-02,  2.1055e-01],
         [-8.3029e-01, -3.1396e-01, -2.5810e-01, -1.2058e-01],
         [-9.2603e-01,  1.0327e+00, -2.1214e-01, -1.8899e+00],
         [ 5.4358e-01,  7.4395e-01,  3.9582e-04, -2.5787e+00],
         [-1.8792e+00,  4.9361e-01, -9.0405e-01,  2.5636e-01],
         [ 7.6763e-01,  2.1136e+00, -2.5433e+00,  4.9248e+00],
         [-1.5781e+00, -1.3108e+00, -1.8789e+00, -3.2565e+00],
         [-3.2170e-01, -2.3346e+00,  1.2943e+00, -5.1166e+00],
         [-4.1135e+00,  8.9746e+00,  3.7493e+00, -4.6333e+00],
         [ 8.1482e+00, -5.4053e+00,  1.7416e+00, -4.0964e+00]],

        [[-1.3158e+00, -4.1214e-01,  5.9935e-01,  1.0766e+00],
         [ 2.5596e-01,  1.3004e+00,  5.5314e-02,  2.1055e-01],
         [ 8.4641e-01, -2.5466e+00,  1.3709e+00,  1.9676e+00],
         [-8.1364e-01, -2.2753e-01,  1.5726e+00, -4.5318e-01],
         [-9.9903e-01, -9.2784e-02,  3.9141e-01,  2.9

In [15]:
x.shape

torch.Size([2, 11, 4])

# Cls Tokens

In [16]:
# The CLS token is a learned vector that acts as a latent representation of
# the target: one copy is prepended to every sample's token sequence.
b = 2 # batch size
cls_token = nn.Parameter(torch.randn(1, 1, 4))
# Repeat the single token along the batch axis: (1, 1, d) -> (b, 1, d)
cls_tokens = repeat(cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim = 1)
x

tensor([[[-1.4014e+00,  1.2479e+00,  1.7844e+00,  4.9004e-01],
         [ 4.6663e-01,  1.3844e+00, -1.3162e+00,  1.2584e-01],
         [ 2.5596e-01,  1.3004e+00,  5.5314e-02,  2.1055e-01],
         [-8.3029e-01, -3.1396e-01, -2.5810e-01, -1.2058e-01],
         [-9.2603e-01,  1.0327e+00, -2.1214e-01, -1.8899e+00],
         [ 5.4358e-01,  7.4395e-01,  3.9582e-04, -2.5787e+00],
         [-1.8792e+00,  4.9361e-01, -9.0405e-01,  2.5636e-01],
         [ 7.6763e-01,  2.1136e+00, -2.5433e+00,  4.9248e+00],
         [-1.5781e+00, -1.3108e+00, -1.8789e+00, -3.2565e+00],
         [-3.2170e-01, -2.3346e+00,  1.2943e+00, -5.1166e+00],
         [-4.1135e+00,  8.9746e+00,  3.7493e+00, -4.6333e+00],
         [ 8.1482e+00, -5.4053e+00,  1.7416e+00, -4.0964e+00]],

        [[-1.4014e+00,  1.2479e+00,  1.7844e+00,  4.9004e-01],
         [-1.3158e+00, -4.1214e-01,  5.9935e-01,  1.0766e+00],
         [ 2.5596e-01,  1.3004e+00,  5.5314e-02,  2.1055e-01],
         [ 8.4641e-01, -2.5466e+00,  1.3709e+00,  1.9

In [17]:
x.shape

torch.Size([2, 12, 4])

# Transformer

In [18]:
# Small transformer matching the 4-wide demo embeddings: two layers, one head.
trfm = Transformer(embed_dim=4,
                   depth=2,
                   num_heads=1,
                   attn_dropout=0.,
                   ff_dropout=0.,
                   )

In [19]:
x = trfm(x)
x

tensor([[[ -0.7888,   2.6764,   1.2417,  -1.0084],
         [  3.7996,  -0.1730,  -1.8198,  -1.1460],
         [  2.3503,   1.2890,  -0.0215,  -1.7956],
         [  1.8441,  -0.5871,  -1.8705,  -0.9094],
         [  1.4295,  -2.1434,  -1.7860,   0.5047],
         [  2.5898,  -2.5680,   0.3917,  -1.7043],
         [  0.3994,   0.7377,  -3.1928,   0.0224],
         [  4.6120,   1.3486,  -3.8567,   3.1588],
         [ -1.3121,  -3.9626,  -1.7074,  -1.0421],
         [  0.8802,  -2.8357,  -0.2089,  -4.3142],
         [ -2.4442,   7.1592,   5.5957,  -6.3336],
         [ 11.0505,  -6.6187,   1.6975,  -5.7412]],

        [[  0.8180,  -1.1675,   3.1470,  -0.6766],
         [  2.0998,  -0.5965,   0.1763,  -1.7317],
         [  3.3142,  -0.2662,   0.4821,  -1.7079],
         [  2.4479,   0.8118,   0.2092,  -1.8307],
         [  1.4139,  -1.1606,   1.9425,  -2.1176],
         [  0.3630,  -2.0077,   2.0484,  -0.8048],
         [ -1.4120,   2.3614,  -6.8060,   2.3992],
         [  3.6653,   9.8020,

In [20]:
x.shape

torch.Size([2, 12, 4])

# Binary Condition ATTENTION

In [21]:
# Embedding table for the binary conditions: two rows per condition
# (one for value 0, one for value 1), so 2 * bin_size rows in total.
embed_dim = 4
bin_size = 5

bem = nn.Embedding(2*bin_size, embed_dim)
bem.weight

Parameter containing:
tensor([[-1.0635,  0.8810, -1.7549,  0.2084],
        [ 0.7246,  0.1679, -0.7772, -0.7525],
        [ 0.1409,  0.5818, -0.9637,  0.2260],
        [-0.3102,  0.6790,  0.6046,  0.0792],
        [ 0.2471,  0.5555,  1.8950,  0.8578],
        [-0.5641,  0.8336,  0.5755, -0.8082],
        [-0.4365,  0.2985,  0.9872,  0.4378],
        [ 0.1914,  0.0812, -1.0502, -0.7422],
        [-0.0071, -0.3820,  1.1363, -0.9915],
        [ 0.5141,  0.3512,  0.1439,  1.1197]], requires_grad=True)

In [22]:
# Condition i uses table rows 2*i (value 0) and 2*i + 1 (value 1); adding the
# offsets [0, 2, 4, 6, 8] turns each 0/1 flag into its row index.
bin_condition_tensor = torch.tensor([[1,0,1,1,0],
                                     [0,1,0,1,1]])
bin_offset = torch.arange(0,2*bin_size,2)
bin_condition_tensor += bin_offset
bin_k = bem(bin_condition_tensor)

# The same embeddings are used as both keys and values.
bin_v = bin_k
bin_k

tensor([[[ 0.7246,  0.1679, -0.7772, -0.7525],
         [ 0.1409,  0.5818, -0.9637,  0.2260],
         [-0.5641,  0.8336,  0.5755, -0.8082],
         [ 0.1914,  0.0812, -1.0502, -0.7422],
         [-0.0071, -0.3820,  1.1363, -0.9915]],

        [[-1.0635,  0.8810, -1.7549,  0.2084],
         [-0.3102,  0.6790,  0.6046,  0.0792],
         [ 0.2471,  0.5555,  1.8950,  0.8578],
         [ 0.1914,  0.0812, -1.0502, -0.7422],
         [ 0.5141,  0.3512,  0.1439,  1.1197]]], grad_fn=<EmbeddingBackward0>)

In [23]:
# Raw cross-attention: the transformer tokens are the queries, the binary
# condition embeddings are keys and values. Element [0] is the attention
# output; no LayerNorm or residual is applied in this cell.
bcattn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=2, dropout=0., batch_first=True)
x = bcattn(x, bin_k, bin_v)[0]
x

tensor([[[ 4.6725e-01, -1.6088e-01, -1.8428e-01,  1.3721e-01],
         [ 2.6983e-01, -1.4563e-01, -2.1253e-01, -3.9766e-02],
         [ 4.4512e-01, -2.0049e-01, -2.6601e-01,  7.0584e-03],
         [ 8.0603e-02, -5.0830e-02, -7.4027e-02, -6.6657e-03],
         [-5.2986e-02,  9.8208e-03,  1.2880e-02, -7.9237e-03],
         [ 3.4141e-01, -1.5935e-01, -2.1669e-01,  5.0772e-03],
         [-9.1156e-02,  2.8843e-02,  4.2193e-02, -4.9003e-03],
         [-2.8066e-02, -2.7779e-02, -6.2709e-02, -7.9504e-02],
         [-1.2866e-01,  6.4153e-02,  1.0513e-01,  4.1912e-02],
         [ 2.5326e-01, -1.1290e-01, -1.5010e-01,  3.2531e-02],
         [ 7.1787e-01, -2.2565e-01, -2.3679e-01,  2.2749e-01],
         [ 4.7098e-01, -2.4602e-01, -3.4683e-01, -1.0469e-01]],

        [[ 3.7504e-02, -8.9181e-02, -1.6738e-01, -3.9273e-01],
         [ 6.8236e-03, -6.3006e-02, -4.3001e-02, -2.3421e-01],
         [ 8.8369e-02, -9.7678e-02, -4.6880e-02, -2.7342e-01],
         [ 7.5746e-02, -8.8275e-02, -5.5150e-02, -2.6

In [24]:
x.shape

torch.Size([2, 12, 4])

## FeedForward

In [25]:
# Feed-forward block with default hidden_mult and dropout; the output keeps
# the input's shape.
fdfd = FeedForward(in_dim=embed_dim)
x = fdfd(x)
x

tensor([[[-1.2004, -0.0074,  0.4194,  1.0477],
         [-1.2976, -0.2404,  0.3447,  1.0652],
         [-1.1598, -0.2166,  0.3126,  1.0496],
         [-1.5231, -0.0636,  0.4903,  1.0455],
         [-1.0616,  1.1393,  0.8740, -0.9899],
         [-1.2654, -0.1694,  0.3590,  1.0463],
         [-1.1129,  1.1745,  0.8843, -0.9709],
         [-0.9264, -0.9203,  0.2057,  1.4429],
         [-1.2527,  1.1561,  0.9995, -0.8203],
         [-1.3827, -0.0493,  0.4347,  1.0201],
         [-0.9534, -0.0618,  0.3724,  1.1257],
         [-1.0659, -0.4066,  0.2115,  1.0345]],

        [[-0.4282, -1.2404, -0.1149,  1.1718],
         [ 0.2498, -1.6024,  0.0045,  1.0147],
         [ 0.2840, -1.7548,  0.4178,  0.7234],
         [ 0.1985, -1.7062,  0.3438,  0.8310],
         [-0.2567, -1.4072,  0.0759,  1.1413],
         [-0.7325, -0.5967, -0.6253,  1.4166],
         [-1.6178,  1.0785,  1.0334, -0.8505],
         [-1.6273,  1.0800,  1.0337, -0.8501],
         [-1.6277,  1.0790,  1.0339, -0.8425],
         [-

# Get cls token

In [26]:
x = x[:, 0]
x

tensor([[-1.2004, -0.0074,  0.4194,  1.0477],
        [-0.4282, -1.2404, -0.1149,  1.1718]], grad_fn=<SelectBackward0>)

In [27]:
x.shape

torch.Size([2, 4])

# To Logits(Output)

In [61]:
# Output head: LayerNorm -> ReLU -> Linear from the 4-wide CLS vector to dim_out values.
dim_out=1
to_logits = nn.Sequential(nn.LayerNorm(4),
                         nn.ReLU(),
                         nn.Linear(4, dim_out)
                         )

In [62]:
output = to_logits(x)
output

tensor([[0.3770],
        [0.5120]], grad_fn=<AddmmBackward0>)

In [63]:
# Float binary targets, one row per sample, matching the (2, 1) shape of `output`.
label = torch.tensor([[1.],
                      [0.]])

In [64]:
output

tensor([[0.3770],
        [0.5120]], grad_fn=<AddmmBackward0>)

In [65]:
label

tensor([[1.],
        [0.]])

In [66]:
# BCEWithLogitsLoss applies the sigmoid itself, so the raw model output is passed in.
loss_fn = nn.BCEWithLogitsLoss()

In [67]:
loss_fn(output, label)

tensor(0.7519, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)

# APPENDIX 1 : Self Attention

In [38]:
# Shape conventions for attention inputs:
#   q : bs x d_L x embed_dim
#   k : bs x d_s x embed_dim
#   v : bs x d_s x embed_dim
bs = 2
embed_dim = 4
d_L = 12 # number of query tokens (columns)
d_s = 5 # number of key/value tokens; unused by the self-attention below

x = torch.randn(bs, d_L, embed_dim)
attn = SelfAttention(embed_dim=embed_dim, num_heads=2, dropout=0.)
x = attn(x)
x

tensor([[[ 0.0579,  0.0130, -0.7549,  0.5313],
         [ 0.7225,  0.6496, -1.9422,  0.8639],
         [-0.1123, -0.6458, -1.5604,  0.9912],
         [ 1.2144,  0.1313,  1.0603, -0.8106],
         [ 0.4380,  0.1359,  1.4092, -1.1720],
         [-1.7602, -1.0861, -0.4067,  0.2237],
         [ 0.0419, -0.0641,  0.1427, -0.1083],
         [-0.7174, -0.1175, -0.4785,  0.0666],
         [ 1.5394,  1.0056, -0.9245,  0.8754],
         [ 0.4132,  0.5687,  0.4832, -0.8695],
         [-0.2170,  0.3749,  1.6618, -0.9050],
         [ 0.3322, -1.4531,  1.6286, -0.6463]],

        [[ 0.1636, -0.0983,  0.4336, -2.2646],
         [ 1.3706,  0.7240, -1.0046, -1.0719],
         [ 1.5137,  1.8197,  0.0781, -1.5781],
         [ 0.3187, -1.1351, -1.2591,  0.0212],
         [ 0.2089,  1.5636,  0.8218, -2.4446],
         [-0.2044,  0.1289,  0.1894, -1.5802],
         [ 1.9724,  1.1960,  1.0150, -1.2885],
         [-0.4001, -0.4109,  0.6811, -3.2625],
         [ 0.9830, -0.1506,  1.8850, -2.2997],
         [ 

In [39]:
x.shape

torch.Size([2, 12, 4])

In [40]:
# Apply a fresh feed-forward block to the self-attention output; shape is preserved.
fdfd = FeedForward(in_dim=embed_dim)
x = fdfd(x)
x

tensor([[[-0.0191,  0.2112, -2.2198,  1.8751],
         [ 0.2210,  1.8677, -3.3013,  1.5063],
         [ 0.3734, -1.0315, -2.9535,  2.2842],
         [-0.2879,  1.1304,  1.8668, -1.1138],
         [-1.0593,  0.8167,  2.5105, -1.4567],
         [-0.7525, -1.1819, -1.9925,  0.8977],
         [-1.0067,  0.7492,  1.3117, -1.0420],
         [ 0.4734, -1.1125, -1.4627,  0.8550],
         [ 1.4386,  2.2276, -2.4494,  1.2791],
         [-1.0966,  1.6847,  0.2358, -0.2282],
         [-1.6550,  0.8717,  2.9204, -1.2225],
         [-0.0924, -0.5290,  2.5706, -2.0877]],

        [[-1.5330,  0.7939,  0.8502, -1.8769],
         [ 1.9390,  1.5469, -2.7069, -0.7610],
         [ 1.7843,  2.9085, -1.5542, -1.3051],
         [ 1.5474, -2.2210, -2.1354,  0.7546],
         [-1.1386,  2.9425,  0.4342, -2.0884],
         [-1.8980,  1.0335,  0.5967, -1.1986],
         [ 0.7988,  2.1834,  0.2061, -0.2935],
         [-2.0653,  0.3151,  1.5099, -3.1521],
         [-0.5341,  0.6537,  2.8725, -2.5745],
         [-

In [41]:
x.shape

torch.Size([2, 12, 4])

# APPENDIX 2 : Transformer

In [None]:
# Define embed_dim here instead of relying on a value left over from another
# cell: with 8 heads it must be a multiple of 8 (a leftover embed_dim = 4
# would make MultiheadAttention fail).
embed_dim = 32

q = torch.randn(1, 15, embed_dim)
trfm = Transformer(embed_dim=embed_dim, depth=3, num_heads=8, attn_dropout=0.1, ff_dropout=0.1)
trfm(q).shape

torch.Size([1, 15, 32])

In [None]:
# Shape check of the individual building blocks on random numerical input.
bs = 1
embed_dim = 32
col_dim = 8

q = torch.randn(bs, col_dim)
nembd = NumericalEmbedder(embed_dim, col_dim)
q = nembd(q) # bs x col_dim x embed_dim

# Self-attention and feed-forward, each applied separately to q

num_heads = 16 # must satisfy embed_dim % num_heads == 0
attn = SelfAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=0.)
attn(q).shape # bs x col_dim x embed_dim (not displayed: only the cell's last expression is shown)

fdfd = FeedForward(embed_dim, hidden_mult=4, dropout = 0.)
fdfd(q).shape # bs x col_dim x embed_dim

torch.Size([1, 8, 32])