# Tutorial 3: Numerical Embeddings for Tabular Deep Learning

**Paper:** [On Embeddings for Numerical Features in Tabular Deep Learning](https://arxiv.org/abs/2203.05556)

**Authors:** Yury Gorishniy, Ivan Rubachev, Artem Babenko (Yandex Research)

**Venue:** NeurIPS 2022

---

## Key Insight

The paper shows that **transforming scalar numerical features into high-dimensional embeddings** before mixing them in the main backbone (MLP, Transformer, etc.) significantly improves performance.

Two main approaches:
1. **Piecewise Linear Encoding (PLE):** Uses bin boundaries computed from the training data (quantiles or target-aware decision trees)
2. **Periodic Embeddings:** Uses sin/cos functions with learnable frequencies

The upshot: simple MLPs with these embeddings can perform on par with Transformer-based models.

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

from models import MLPPLR, PeriodicEmbeddings, PiecewiseLinearEncoding, compute_bins

## 1. Why Do We Need Numerical Embeddings?

Consider a simple problem: predicting `y = sin(2πx)` where `x ∈ [0, 1]`.

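A standard MLP struggles with this because:
- ReLU networks compute piecewise linear functions, so a smooth oscillating target needs many linear pieces
- Getting enough pieces takes extra width or depth, and gradient-based training tends to fit low-frequency components first
- Real numerical features add a further difficulty: their distributions are often irregular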

**Solution:** Transform `x` into a higher-dimensional representation that makes the target function easier to learn.
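
One way to see why a richer representation helps is to compare plain linear least-squares fits on the raw feature and on a two-dimensional periodic encoding of it. This is only an illustrative sketch: the helper `linear_fit_mse` and the fixed frequency of 1 are assumptions made for this example, not part of the paper or of the `models` package.

```python
import numpy as np

x = np.linspace(0.0, 1.0, 200)
y = np.sin(2 * np.pi * x)

def linear_fit_mse(features, target):
    """Fit target ~ features @ w + b by least squares and return the MSE."""
    design = np.column_stack([features, np.ones(len(target))])
    coef, *_ = np.linalg.lstsq(design, target, rcond=None)
    return float(np.mean((design @ coef - target) ** 2))

# Raw scalar feature: a straight line cannot follow a full sine period.
mse_raw = linear_fit_mse(x[:, None], y)

# Periodic encoding with a hand-picked frequency of 1 (hypothetical choice).
encoded = np.column_stack([np.sin(2 * np.pi * x), np.cos(2 * np.pi * x)])
mse_periodic = linear_fit_mse(encoded, y)

print(f"Linear fit on raw x:        MSE = {mse_raw:.4f}")
print(f"Linear fit on [sin, cos]:   MSE = {mse_periodic:.2e}")
```

With the encoding, the target becomes a linear function of the inputs, so even the simplest model fits it. In practice the right frequency is not known in advance, which is why the periodic embeddings in Section 3 learn it.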

In [None]:
# Visualize the problem
x = torch.linspace(0, 1, 100).unsqueeze(-1)
y = torch.sin(2 * np.pi * x)

plt.figure(figsize=(10, 4))
plt.plot(x.numpy(), y.numpy(), 'b-', linewidth=2)
plt.xlabel('x')
plt.ylabel('y = sin(2πx)')
plt.title('Target Function: Hard to Approximate with ReLU MLPs')
plt.grid(True)
plt.show()

## 2. Piecewise Linear Encoding (PLE)

PLE encodes a scalar `x` into a vector where each component represents "how much" the value falls into each bin.

For bins `[b₀, b₁], [b₁, b₂], ...`:
- If `x ≤ bᵢ`: encoding[i] = 0
- If `x ≥ bᵢ₊₁`: encoding[i] = 1  
- Otherwise: encoding[i] = (x - bᵢ) / (bᵢ₊₁ - bᵢ)

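This creates an **ordered, interpretable** representation: components for bins entirely below `x` are 1, the component for the bin containing `x` holds a fraction, and components for bins above `x` are 0.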
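
The three rules above can be sketched in a few lines of NumPy. The function name `ple_encode` and the hand-picked bin edges are assumptions for illustration; the `PiecewiseLinearEncoding` class from `models` may differ in details. Note also that the paper's definition treats the first and last bins slightly differently for out-of-range values, whereas this sketch follows the simpler clipped rules listed above.

```python
import numpy as np

def ple_encode(x, edges):
    """Piecewise linear encoding of scalars x given sorted bin edges.

    Returns an array of shape (len(x), len(edges) - 1).
    """
    x = np.asarray(x, dtype=float)[:, None]   # (n, 1)
    left, right = edges[:-1], edges[1:]       # per-bin boundaries
    # Fraction of each bin covered by x, clipped to [0, 1].
    return np.clip((x - left) / (right - left), 0.0, 1.0)

# Hypothetical edges defining three bins: [0, 1], [1, 2], [2, 4]
edges = np.array([0.0, 1.0, 2.0, 4.0])
enc = ple_encode([-1.0, 0.5, 2.5, 10.0], edges)
print(enc)
# x = -1.0 -> [0.   0.   0.  ]  (below every bin)
# x =  0.5 -> [0.5  0.   0.  ]  (halfway through bin 0)
# x =  2.5 -> [1.   1.   0.25]  (a quarter of the way through bin 2)
# x = 10.0 -> [1.   1.   1.  ]  (above every bin)
```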

In [None]:
# Create sample data with 1 feature
X_train = torch.randn(1000, 1)

# Compute bins using quantiles
bins = compute_bins(X_train, n_bins=8)
print(f"Bin boundaries for feature 0: {bins[0].numpy()}")

# Create PLE
ple = PiecewiseLinearEncoding(bins)

# Encode the midpoints of bins 2, 4, and 6
edges = bins[0]  # 1-D tensor of bin boundaries for feature 0
test_values = torch.stack([
    (edges[2] + edges[3]) / 2,   # Middle of bin 2
    (edges[4] + edges[5]) / 2,   # Middle of bin 4
    (edges[6] + edges[7]) / 2,   # Middle of bin 6
]).unsqueeze(-1)                 # Shape: (3, 1)

encodings = ple(test_values)
print(f"\nEncoding shape: {encodings.shape}")
print("\nEncodings (one row per test value):")
print(encodings.numpy())

In [None]:
# Visualize PLE encoding
x_range = torch.linspace(X_train.min().item() - 0.5, X_train.max().item() + 0.5, 200).unsqueeze(-1)
encodings = ple(x_range).numpy()

plt.figure(figsize=(12, 5))

# Plot each bin's activation
for i in range(encodings.shape[1]):
    plt.plot(x_range.numpy(), encodings[:, i], label=f'Bin {i}')

# Mark bin boundaries
for b in bins[0].numpy():
    plt.axvline(x=b, color='gray', linestyle='--', alpha=0.3)

plt.xlabel('Input value x')
plt.ylabel('Encoding value')
plt.title('Piecewise Linear Encoding: Each bin activates in its range')
plt.legend(loc='upper left', bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.show()

## 3. Periodic Embeddings

Periodic embeddings use sin/cos functions with **learnable frequencies**:

```
embed(x) = ReLU(Linear([sin(2π·f₁·x), cos(2π·f₁·x), sin(2π·f₂·x), ..., cos(2π·fₖ·x)]))
```

Where `f₁, f₂, ..., fₖ` are learnable frequency parameters.

**Why it works:**
- Similar to the Fourier feature mappings of [Tancik et al. (2020)](https://arxiv.org/abs/2006.10739), which are analyzed there using neural tangent kernel theory
- Network can learn which frequencies are important for the task
- Captures periodic patterns naturally
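
A minimal PyTorch sketch of the formula above may help make the tensor shapes concrete. The class name `PeriodicSketch`, its argument names, and the weight initialization are assumptions for illustration; this is not the `PeriodicEmbeddings` implementation from `models`. It uses a separate linear layer per feature and concatenates all sines followed by all cosines, which matches the formula up to a permutation of the inputs to the linear layer.

```python
import math
import torch
import torch.nn as nn

class PeriodicSketch(nn.Module):
    """ReLU(Linear([sin(2*pi*f*x), cos(2*pi*f*x)])) applied per feature."""

    def __init__(self, n_features, n_frequencies, d_embedding, sigma=0.1):
        super().__init__()
        # Learnable frequencies, one set per feature, drawn from N(0, sigma^2).
        self.frequencies = nn.Parameter(torch.randn(n_features, n_frequencies) * sigma)
        # Per-feature linear layer stored as a batched weight tensor.
        scale = 1.0 / math.sqrt(2 * n_frequencies)
        self.weight = nn.Parameter(torch.randn(n_features, 2 * n_frequencies, d_embedding) * scale)
        self.bias = nn.Parameter(torch.zeros(n_features, d_embedding))

    def forward(self, x):
        # x: (batch, n_features) -> v: (batch, n_features, n_frequencies)
        v = 2 * math.pi * self.frequencies * x.unsqueeze(-1)
        periodic = torch.cat([torch.sin(v), torch.cos(v)], dim=-1)
        # Apply each feature's own linear map, then ReLU.
        out = torch.einsum("bfk,fkd->bfd", periodic, self.weight) + self.bias
        return torch.relu(out)

torch.manual_seed(0)
layer = PeriodicSketch(n_features=3, n_frequencies=4, d_embedding=8)
emb = layer(torch.randn(4, 3))
print(emb.shape)  # torch.Size([4, 3, 8])
```

Because `frequencies` is an `nn.Parameter`, the optimizer updates it together with the linear weights, which is what "learnable frequencies" means here.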

In [None]:
# Create periodic embeddings
periodic = PeriodicEmbeddings(
    n_features=1,
    d_embedding=8,
    n_frequencies=4,
    frequency_init_scale=0.1,  # Small init for stability
    lite=False,
)

# Inspect the initial (not yet trained) frequencies
print(f"Initialized frequencies: {periodic.frequencies.data.squeeze().numpy()}")
print(f"Output shape for single input: {periodic(torch.randn(1, 1)).shape}")

In [None]:
# Visualize periodic encoding before training
x_range = torch.linspace(-2, 2, 200).unsqueeze(-1)
with torch.no_grad():
    embeddings = periodic(x_range)  # Shape: (200, 1, 8)
    embeddings = embeddings.squeeze(1).numpy()  # Shape: (200, 8)

plt.figure(figsize=(12, 5))
for i in range(embeddings.shape[1]):
    plt.plot(x_range.numpy(), embeddings[:, i], label=f'Dim {i}')

plt.xlabel('Input value x')
plt.ylabel('Embedding value')
plt.title('Periodic Embeddings: Each dimension captures different patterns')
plt.legend(loc='upper left', bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.show()

## 4. MLPPLR: MLP with Numerical Embeddings

The full model architecture (in the paper, "PLR" denotes the Periodic → Linear → ReLU embedding from Section 3):

```
Input (batch, n_features)
    ↓
Embedding Layer (periodic or PLE)
    ↓
(batch, n_features, d_embedding)
    ↓
Flatten
    ↓
(batch, n_features × d_embedding)
    ↓
MLP Backbone (Linear → BN → ReLU → Dropout) × n_blocks
    ↓
Output (batch, d_out)
```
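
The shape changes in the diagram can be traced with random tensors standing in for the embedding output. This is a sketch under assumed sizes (10 features, `d_embedding=24`, `d_block=128`) and an assumed block layout taken from the diagram; the real `MLPPLR` internals may be organized differently.

```python
import torch
import torch.nn as nn

batch, n_features, d_embedding, d_block = 32, 10, 24, 128

# Stand-in for the embedding layer output: (batch, n_features, d_embedding)
embedded = torch.randn(batch, n_features, d_embedding)

# Flatten everything except the batch dimension: (batch, n_features * d_embedding)
flat = embedded.flatten(1)

# One backbone block as drawn in the diagram: Linear -> BN -> ReLU -> Dropout
block = nn.Sequential(
    nn.Linear(n_features * d_embedding, d_block),
    nn.BatchNorm1d(d_block),
    nn.ReLU(),
    nn.Dropout(0.1),
)
hidden = block(flat)

# The first linear layer grows with n_features * d_embedding:
first_linear_params = sum(p.numel() for p in block[0].parameters())

print(flat.shape)           # torch.Size([32, 240])
print(hidden.shape)         # torch.Size([32, 128])
print(first_linear_params)  # 240 * 128 + 128 = 30848
```

The parameter count shows the main cost of larger embeddings in an MLP backbone: the first linear layer scales with `n_features × d_embedding`.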

In [None]:
# Create MLPPLR with periodic embeddings
model = MLPPLR(
    d_in=10,
    d_out=1,
    d_embedding=24,
    embedding_type="periodic",
    n_blocks=3,
    d_block=128,
    dropout=0.1,
    n_frequencies=48,
    frequency_init_scale=0.01,
)

print(f"Total parameters: {model.count_parameters():,}")

# Test forward pass
x = torch.randn(32, 10)
out = model(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")

In [None]:
# Create MLPPLR with PLE (need to compute bins first)
X_train = torch.randn(1000, 10)  # Simulated training data

bins = compute_bins(X_train, n_bins=32)
print(f"Number of features: {len(bins)}")
print(f"Bins per feature: {[len(b) - 1 for b in bins]}")

model_ple = MLPPLR(
    d_in=10,
    d_out=1,
    d_embedding=16,
    embedding_type="ple",
    bins=bins,
    n_blocks=3,
    d_block=128,
)

print(f"\nPLE model parameters: {model_ple.count_parameters():,}")
out = model_ple(torch.randn(32, 10))
print(f"Output shape: {out.shape}")

## 5. HFT/MFT Trading Applications

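Numerical embeddings are a natural candidate for high-frequency and mid-frequency trading (HFT/MFT) models, whose inputs are mostly numerical features derived from financial time series: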

### Why Periodic Embeddings Help:
- **Intraday patterns:** Market behavior differs at open/close
- **Round number effects:** Support/resistance at \$100, \$50 levels
- **Oscillating indicators:** RSI and similar oscillators are bounded and can relate to returns non-monotonically
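
The intraday case can be illustrated with a fixed-frequency special case of a periodic encoding. The sketch assumes a `time_of_day` feature normalized to `[0, 1)` over a 24-hour market, so that the end of one day is adjacent to the start of the next; the helper name `cyclic_encode` is hypothetical. For a market with distinct open and close, the wrap-around would not apply in the same way.

```python
import numpy as np

def cyclic_encode(t):
    """Map normalized time t in [0, 1) to a point on the unit circle."""
    angle = 2 * np.pi * np.asarray(t, dtype=float)
    return np.stack([np.sin(angle), np.cos(angle)], axis=-1)

# Just after midnight, midday, just before midnight
t = np.array([0.001, 0.5, 0.999])
enc = cyclic_encode(t)

raw_gap = abs(t[2] - t[0])                     # distance on the raw scale
enc_gap = np.linalg.norm(enc[2] - enc[0])      # distance after encoding
midday_gap = np.linalg.norm(enc[1] - enc[0])   # midnight vs. midday

print(f"raw gap (23:59 vs 00:01):     {raw_gap:.3f}")
print(f"encoded gap (23:59 vs 00:01): {enc_gap:.3f}")
print(f"encoded gap (00:01 vs 12:00): {midday_gap:.3f}")
```

On the raw scale the two near-midnight timestamps look maximally far apart; after encoding they are neighbors, while midday stays far from both. Learnable frequencies generalize this idea by letting the model pick the period.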

### Why PLE Helps:
- **Price regimes:** Different behavior in different price ranges
- **Volume bins:** High/low volume environments
- **Volatility regimes:** Calm vs. turbulent markets

### Computational Efficiency:
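- Both methods add little overhead on top of the backbone: a cheap elementwise transform per feature, optionally followed by a small linear layer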
- Suitable for real-time inference (unlike retrieval-based methods)
- GPU-friendly operations

In [None]:
# Example: Financial features
feature_names = [
    'returns_1m', 'returns_5m', 'returns_15m',  # Returns at different horizons
    'volume_ratio',  # Volume vs average
    'rsi_14',        # RSI (bounded oscillator in [0, 100])
    'vwap_deviation',  # Distance from VWAP
    'bid_ask_spread',  # Liquidity measure
    'order_imbalance', # Order flow
    'volatility_5m',   # Recent volatility
    'time_of_day',     # Normalized time (periodic!)
]

# Create model tailored for HFT
hft_model = MLPPLR(
    d_in=len(feature_names),
    d_out=1,  # Predict next return
    d_embedding=24,
    embedding_type="periodic",  # Good for RSI, time_of_day
    n_blocks=2,  # Shallow for speed
    d_block=64,  # Small for speed
    dropout=0.05,  # Light regularization
    lite=True,  # Parameter efficient
)

print(f"HFT model parameters: {hft_model.count_parameters():,}")

# Inspect frequency statistics (initial values here, since the model is untrained)
print(f"\nFrequency stats: {hft_model.get_frequency_stats()}")

## 6. Benchmark Results

Results from our benchmarks (lower is better):

| Dataset | MLPPLR | MLP | Best Overall |
|---------|--------|-----|-------------|
| friedman | 1.67 | 1.23 | TabR (1.12) |
| nonlinear | **0.75** | 0.86 | **MLPPLR (0.75)** |
| high_dim | **1.17** | 1.37 | **MLPPLR (1.17)** |
| temporal | 1.22 | 1.24 | Temporal (0.61) |
| mixed | 1.97 | 1.98 | TabM (1.97) |

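**Key insights:**
- MLPPLR is the best model overall on **nonlinear** (0.75 vs. 0.86 for MLP) and **high_dim** (1.17 vs. 1.37)
- On **friedman**, the plain MLP (1.23) clearly beats MLPPLR (1.67): a simple MLP baseline is hard to beat on basic regression
- On **temporal** and **mixed**, the two are nearly tied (1.22 vs. 1.24 and 1.97 vs. 1.98)
- Embeddings shine when features have complex relationships with targets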
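
To put the table in relative terms, the short computation below derives MLPPLR's improvement over the plain MLP per dataset. It uses only the numbers from the table and assumes that lower scores are better, which the Best Overall column suggests.

```python
# (MLPPLR score, MLP score) per dataset, copied from the table above
results = {
    "friedman": (1.67, 1.23),
    "nonlinear": (0.75, 0.86),
    "high_dim": (1.17, 1.37),
    "temporal": (1.22, 1.24),
    "mixed": (1.97, 1.98),
}

# Positive values mean MLPPLR improves on the MLP baseline.
rel = {name: (mlp - plr) / mlp for name, (plr, mlp) in results.items()}

for name, gain in rel.items():
    print(f"{name:10s} {gain:+.1%}")
# friedman   -35.8%
# nonlinear  +12.8%
# high_dim   +14.6%
# temporal   +1.6%
# mixed      +0.5%
```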

## 7. Hyperparameter Tips

### Periodic Embeddings:
- `frequency_init_scale`: **Critical!** Start small (0.01), tune up to 1.0
- `n_frequencies`: 32-64 usually sufficient
- `lite=True`: Use for efficiency, minimal performance loss
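
To build intuition for why the initialization scale matters, the sketch below counts how often `sin(2π·f·x)` changes sign over `x ∈ [-2, 2]` when the frequencies `f` are drawn from a normal distribution with standard deviation `sigma`. The normal initialization follows the paper; whether `frequency_init_scale` in `models` maps exactly onto this `sigma` is an assumption, and the helper `mean_sign_changes` is hypothetical.

```python
import numpy as np

def mean_sign_changes(sigma, n_frequencies=48, seed=0):
    """Average number of sign changes of sin(2*pi*f*x) on [-2, 2], f ~ N(0, sigma^2)."""
    rng = np.random.default_rng(seed)
    freqs = rng.normal(0.0, sigma, size=n_frequencies)
    x = np.linspace(-2.0, 2.0, 2000)  # even count, so x = 0 is not on the grid
    waves = np.sin(2 * np.pi * freqs[:, None] * x[None, :])
    changes = (np.diff(np.sign(waves), axis=1) != 0).sum(axis=1)
    return float(changes.mean())

for sigma in (0.01, 0.1, 1.0):
    print(f"sigma = {sigma:<5} mean sign changes = {mean_sign_changes(sigma):.2f}")

low = mean_sign_changes(0.01)
high = mean_sign_changes(1.0)
```

With a small `sigma`, every component crosses zero only once (at `x = 0`) and is close to a linear function of `x`, so training starts from a smooth representation. With a large `sigma`, the components oscillate many times over the input range, which can make optimization noisier. This is one plausible reading of the advice to start small and tune upward.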

### PLE:
- `n_bins`: 32-64 for most tasks
- Use tree-based bins when you have `y` for supervision
- `activation=False` (version B) recommended
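
Tree-based bins can be sketched with scikit-learn by fitting a small regression tree on a single feature and reusing its split thresholds as bin edges. This is an illustrative approximation of the target-aware binning idea, not the `compute_bins` implementation; the function name `tree_bin_edges` and the synthetic step-shaped target are assumptions.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

def tree_bin_edges(x, y, max_leaf_nodes=8):
    """Derive bin edges for one feature from the splits of a regression tree."""
    tree = DecisionTreeRegressor(max_leaf_nodes=max_leaf_nodes, random_state=0)
    tree.fit(x.reshape(-1, 1), y)
    # Internal nodes have feature index >= 0; leaves are marked with -2.
    internal = tree.tree_.feature >= 0
    thresholds = tree.tree_.threshold[internal]
    # Add the observed range so the edges cover all training values.
    return np.unique(np.concatenate([[x.min()], thresholds, [x.max()]]))

rng = np.random.default_rng(0)
x = rng.uniform(0.0, 1.0, size=500)
# Synthetic target with a jump at x = 0.3 plus a little noise
y = (x > 0.3).astype(float) + 0.01 * rng.normal(size=500)

edges = tree_bin_edges(x, y)
print(np.round(edges, 3))
```

Unlike quantile bins, which depend only on the distribution of `x`, these edges concentrate where the target changes, so one of them lands close to the jump at 0.3.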

### General:
- `d_embedding`: 16-32 for MLP, larger for Transformers
- Standard MLP hyperparameters (layers, width, dropout) still matter

In [None]:
# Quick training example
from torch.utils.data import DataLoader, TensorDataset

# Generate synthetic data
X = torch.randn(1000, 10)
y = (X[:, 0] * X[:, 1] + torch.sin(X[:, 2] * 3)).unsqueeze(-1)  # Nonlinear target

# Create model
model = MLPPLR(d_in=10, d_out=1, embedding_type="periodic")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# Training loop
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

for epoch in range(5):
    total_loss = 0
    for batch_x, batch_y in loader:
        optimizer.zero_grad()
        pred = model(batch_x)
        loss = criterion(pred, batch_y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}: Loss = {total_loss/len(loader):.4f}")

## Summary

Key takeaways from **On Embeddings for Numerical Features** (NeurIPS 2022) and this tutorial:

1. ✅ Simple scalar→vector transformations significantly improve tabular DL models
2. ✅ Piecewise Linear Encoding (PLE) creates interpretable, bin-based features
3. ✅ Periodic embeddings capture complex patterns via learnable frequencies
4. ✅ MLPs with embeddings can match Transformers at a fraction of the cost
5. ✅ Low overhead makes them practical for real-time trading systems

**When to use:**
- Features have irregular distributions
- Feature-target relationships are complex and nonlinear
- Inference speed matters (prefer these embeddings over retrieval-based methods like TabR)
- Financial data shows periodic patterns