In [None]:
import torch
import torch.nn as nn

class FTTransformer(nn.Module):
    """Feature Tokenizer + Transformer model for tabular inputs.

    Every categorical and every numerical column is turned into its own
    ``emb_dim``-wide token. A learnable CLS token is placed in front, the
    sequence runs through ``depth`` ``nn.TransformerEncoderLayer`` blocks, and
    the CLS position's output is projected by a linear head.

    Args:
        category_sizes: number of embedding rows for each categorical column,
            in column order.
        num_continuous: number of numerical columns.
        emb_dim: token width (``d_model`` of the encoder); nn.MultiheadAttention
            requires it to be divisible by ``n_heads``.
        n_heads: attention heads per encoder layer.
        depth: number of encoder layers.
        dim_out: output width of the prediction head.
    """

    def __init__(self, category_sizes, num_continuous, emb_dim=32, n_heads=8, depth=6, dim_out=1):
        super().__init__()
        self.n_cat = len(category_sizes)
        self.n_num = num_continuous
        # Stored because forward() needs the token width to shape the CLS token.
        self.emb_dim = emb_dim
        # One embedding table per categorical column.
        # Registration order below matches the original so seeded init and
        # parameters() ordering (e.g. optimizer state) are unchanged.
        self.cat_embeddings = nn.ModuleList([
            nn.Embedding(num_cat, emb_dim) for num_cat in category_sizes
        ])
        # Numerical column j is tokenized as x[:, j] * num_weight[j] + num_bias[j].
        self.num_weight = nn.Parameter(torch.randn(num_continuous, emb_dim))
        self.num_bias = nn.Parameter(torch.zeros(num_continuous, emb_dim))
        # Per-column bias added on top of each categorical embedding.
        self.cat_bias = nn.Parameter(torch.zeros(self.n_cat, emb_dim))
        # Learnable CLS token; its encoded state feeds the prediction head.
        self.cls_token = nn.Parameter(torch.zeros(1, emb_dim))
        # batch_first=True keeps tokens as [batch, seq, emb_dim] throughout;
        # it only changes the expected input layout, not the parameters.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=emb_dim, nhead=n_heads, dim_feedforward=4 * emb_dim, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        self.head = nn.Linear(emb_dim, dim_out)

    def _tokenize(self, x_categ, x_cont):
        """Return the token sequence ``[CLS, cat_0.., num_0..]``.

        Output shape is ``[batch, 1 + n_cat + n_num, emb_dim]``.

        Raises:
            ValueError: if both inputs are None, since the batch size cannot
                be inferred.
        """
        source = x_cont if x_cont is not None else x_categ
        if source is None:
            raise ValueError("FTTransformer needs x_categ or x_cont to infer the batch size")
        batch_size = source.size(0)

        # CLS token (1, emb_dim) -> (batch, 1, emb_dim); expand is a view, no copy.
        tokens = [self.cls_token.expand(batch_size, -1, -1)]
        if self.n_cat:
            # Each column has its own table, so embed column by column, stack to
            # [batch, n_cat, emb_dim], then add the per-column bias by broadcasting.
            cat_tokens = torch.stack(
                [embed(x_categ[:, j]) for j, embed in enumerate(self.cat_embeddings)],
                dim=1,
            )
            tokens.append(cat_tokens + self.cat_bias)
        if self.n_num:
            # [batch, n_num, 1] * [n_num, emb_dim] + [n_num, emb_dim]
            # -> [batch, n_num, emb_dim]: scalar j scales weight row j.
            tokens.append(x_cont.unsqueeze(-1) * self.num_weight + self.num_bias)
        return torch.cat(tokens, dim=1)

    def forward(self, x_categ, x_cont):
        """Predict from categorical indices and numerical values.

        Args:
            x_categ: index tensor ``[batch, n_cat]``. Column j indexes
                ``cat_embeddings[j]``, so it needs an integer dtype as
                nn.Embedding requires. May be None when the model has no
                categorical columns.
            x_cont: tensor ``[batch, n_num]`` of numerical values. May be None
                when the model has no numerical columns.

        Returns:
            Tensor ``[batch, dim_out]`` computed from the encoded CLS token.
        """
        tokens = self._tokenize(x_categ, x_cont)  # [batch, seq_len, emb_dim]
        encoded = self.transformer(tokens)        # [batch, seq_len, emb_dim]
        return self.head(encoded[:, 0])           # CLS is at position 0
