In [1]:
import pandas as pd

In [2]:
# from tab_transformer_pytorch import FTTransformer

In [3]:
# NOTE: I copy-pasted the source code directly because I wanted to modify it
# such that the FTTransformer does not use the final MLP for any classification
# In my modified version, we get the output of the transformer during feedforward, NOT MLP output

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat

# feedforward and attention

class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

def FeedForward(dim, mult = 4, dropout = 0.):
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, dim * mult * 2),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(dim * mult, dim)
    )

In [37]:
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        print("to_qkv layer # output nodes:", inner_dim)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, explainable=False):
        h = self.heads
        
        print("x shape before self.norm(x):", x.shape)
        x = self.norm(x)
        print("x shape after self.norm(x):", x.shape)

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        print("q shape after self.to_qkv(x):", q.shape)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
        print("q shape after rearrange:", q.shape)
        q = q * self.scale

        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        
        if explainable:
            return self.to_out
        return self.to_out(out)

In [5]:
# transformer

class Transformer(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        heads,
        dim_head,
        attn_dropout,
        ff_dropout
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout),
                FeedForward(dim, dropout = ff_dropout),
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return x

In [6]:
# numerical embedder

class NumericalEmbedder(nn.Module):
    def __init__(self, dim, num_numerical_types):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(num_numerical_types, dim))
        self.biases = nn.Parameter(torch.randn(num_numerical_types, dim))

    def forward(self, x):
        x = rearrange(x, 'b n -> b n 1')
        return x * self.weights + self.biases

In [7]:
# main class

class FTTransformer(nn.Module):
    def __init__(
        self,
        *,
        categories,
        num_continuous,
        dim,
        depth,
        heads,
        dim_head = 16,
        dim_out = 1,
        num_special_tokens = 2,
        attn_dropout = 0.,
        ff_dropout = 0.
    ):
        super().__init__()
        assert all(map(lambda n: n > 0, categories)), 'number of each category must be positive'

        # categories related calculations

        self.num_categories = len(categories)
        self.num_unique_categories = sum(categories)

        # create category embeddings table

        self.num_special_tokens = num_special_tokens
        total_tokens = self.num_unique_categories + num_special_tokens

        # for automatically offsetting unique category ids to the correct position in the categories embedding table

        categories_offset = F.pad(torch.tensor(list(categories)), (1, 0), value = num_special_tokens)
        categories_offset = categories_offset.cumsum(dim = -1)[:-1]
        self.register_buffer('categories_offset', categories_offset)

        # categorical embedding

        self.categorical_embeds = nn.Embedding(total_tokens, dim)

        # continuous

        self.numerical_embedder = NumericalEmbedder(dim, num_continuous)

        # cls token

        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        # transformer

        self.transformer = Transformer(            
            dim = dim,
            depth = depth,
            heads = heads,
            dim_head = dim_head,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout
        )

        # to logits

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.ReLU(),
            nn.Linear(dim, dim_out)
        )

    def forward(self, x_categ, x_numer):
        b = x_categ.shape[0]

        assert x_categ.shape[-1] == self.num_categories, f'you must pass in {self.num_categories} values for your categories input'
        x_categ += self.categories_offset

        x_categ = self.categorical_embeds(x_categ)

        # add numerically embedded tokens

        x_numer = self.numerical_embedder(x_numer)

        # concat categorical and numerical

        x = torch.cat((x_categ, x_numer), dim = 1)

        # append cls tokens

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)

        # attend

        x = self.transformer(x)

        # get cls token

        x = x[:, 0]

        # out in the paper is linear(relu(ln(cls)))

#         return self.to_logits(x)
        return x

### Load data

In [29]:
data_df = pd.read_csv("matt_metadata_norm.csv", sep='\t')

In [30]:
data_df.head()

Unnamed: 0,Patient,Age,Gender,Body temperature,Underlying diseases,MCHC,MCH,MCV,HCT,HGB,...,FDG,LPS,U,UALB,BCF8,ASO,PS,RF,PC,LAC
0,Patient 1,81.0,0,0.983871,1,0.736842,0.6,0.483333,-0.42,-0.155556,...,0.5,0.5,0.5,0.5,-0.86875,0.5,-0.991667,0.5,-0.992857,0.5
1,Patient 2,50.0,0,1.040323,1,0.868421,0.642857,0.455556,0.15,0.333333,...,0.5,0.5,0.5,0.5,-0.86875,0.5,-0.991667,0.5,-0.992857,0.5
2,Patient 3,65.0,1,1.034946,1,0.368421,0.585714,0.677778,-0.96,-0.644444,...,0.5,0.5,0.5,0.5,-0.86875,0.5,-0.991667,0.5,-0.992857,0.5
3,Patient 4,73.0,0,1.034946,1,0.552632,0.528571,0.516667,0.09,0.177778,...,0.5,0.5,0.5,0.5,-0.86875,0.5,-0.991667,0.5,-0.992857,0.5
4,Patient 5,64.0,1,1.021505,1,0.342105,0.485714,0.577778,0.15,0.133333,...,0.5,0.5,0.5,0.5,-0.86875,0.5,-0.991667,0.5,-0.992857,0.5


In [31]:
cat_features = ["Gender", "Underlying diseases"]
cont_features = [col for col in list(data_df.columns) if col not in cat_features+["Patient"]]

In [32]:
print("Num cat_features:", len(cat_features))
print("Num cont_features:", len(cont_features))

Num cat_features: 2
Num cont_features: 125


In [38]:
X_cat = torch.Tensor(data_df[cat_features].values).to(torch.int64)
X_cont = torch.Tensor(data_df[cont_features].values)

In [39]:
print("X_cat shape:", X_cat.shape)
print("X_cont shape:", X_cont.shape)

X_cat shape: torch.Size([1521, 2])
X_cont shape: torch.Size([1521, 125])


### Model

In [40]:
model = FTTransformer(
    categories = (2, 2),      # Gender and Udis
    num_continuous = len(cont_features),     # number of continuous values
    dim = 32,                 # dimension of transformer input and output, paper set at 32
    dim_out = 1,              # dimension of MLP output (ignored here)
    depth = 1,                # depth, paper recommended 6
    heads = 8,                # heads, paper recommends 8
    attn_dropout = 0.1,       # post-attention dropout
    ff_dropout = 0.1          # feed forward dropout
)

to_qkv layer # output nodes: 128


### Feedforward

In [41]:
out = model(X_cat, X_cont)

x shape before self.norm(x): torch.Size([1521, 128, 32])
x shape after self.norm(x): torch.Size([1521, 128, 32])
q shape after self.to_qkv(x): torch.Size([1521, 128, 128])
q shape after rearrange: torch.Size([1521, 8, 128, 16])


In [44]:
out

tensor([[-0.5913,  0.6638, -0.4280,  ..., -0.2831, -0.8897,  0.3718],
        [-0.5915,  0.7084, -0.3349,  ..., -0.1663, -0.8588,  0.4356],
        [-0.7374,  0.6315, -0.4598,  ...,  0.0166, -0.8086,  0.3691],
        ...,
        [-0.6397,  0.7095, -0.3896,  ..., -0.1789, -0.8906,  0.4230],
        [-0.6200,  0.6721, -0.4066,  ..., -0.1645, -0.8418,  0.3589],
        [-0.5875,  0.7461, -0.3216,  ..., -0.1899, -0.8700,  0.4087]],
       grad_fn=<SelectBackward0>)

In [45]:
out.shape

torch.Size([1521, 32])