# TabR: Retrieval-Augmented Tabular Deep Learning

This tutorial explains TabR, a model that combines deep learning with k-NN-style retrieval. For each prediction, TabR retrieves similar training examples and uses attention to aggregate their information.

**Paper:** [TabR: Tabular Deep Learning Meets Nearest Neighbors (ICLR 2024)](https://arxiv.org/abs/2307.14338)

## Key Ideas

1. **Retrieval-Augmented Prediction**: Instead of relying solely on learned weights, TabR retrieves similar examples from the training set to inform predictions.

2. **Soft Attention Retrieval**: Unlike hard k-NN, TabR uses soft attention over candidates, making it end-to-end differentiable.

3. **Label Integration**: Retrieved neighbors contribute both their features AND labels to the prediction.

## Why TabR for Trading?

- **Regime Detection**: Finds similar historical market patterns
- **Explainability**: "This prediction is based on these similar historical examples"
- **Noise Robustness**: Averaging over similar examples provides implicit ensembling
- **Non-stationarity**: Can adapt by finding recent similar patterns

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from models import TabR

## 1. Architecture Overview

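TabR has three main components: feature embeddings that turn the input into a query, an attention-based retrieval module that builds a context from stored candidates, and an MLP that maps the query and context to a prediction: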

```
Input (x) → Embeddings → Query
                           ↓
Candidates ────────────→ Retrieval (Attention) → Context
                           ↓
              [Query + Context] → MLP → Prediction
```

Let's create a model and examine its structure:

In [None]:
# Create a TabR model
model = TabR(
    d_in=10,           # 10 input features
    d_out=1,           # 1 output (regression)
    d_embedding=24,    # Embedding dimension per feature
    d_block=128,       # MLP hidden dimension
    n_blocks=2,        # Number of MLP blocks
    n_heads=4,         # Attention heads in retrieval
    k_neighbors=64,    # Number of neighbors to retrieve
    max_candidates=1000,  # Max candidates to store
)

print(model)
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")

## 2. How Retrieval Works

The retrieval module scores every stored candidate against the query, then aggregates only the best matches:

1. **Embed query and candidates** into a shared space
2. **Compute attention scores** using dot product
3. **Select top-k neighbors** for efficiency
4. **Aggregate with softmax attention**

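The key insight is that the softmax aggregation is differentiable: gradients flow through the attention weights back to the embeddings, so training improves the space in which neighbors are found. The top-k selection itself is a discrete step and only determines which candidates receive gradients.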
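
The four steps above can be sketched in a few lines of PyTorch. This is a minimal single-head illustration, not the code inside `models.TabR`: the function name `soft_topk_retrieve`, the scaled dot-product scoring, and the plain weighted average used for aggregation are all assumptions made for this example.

```python
import torch
import torch.nn.functional as F


def soft_topk_retrieve(query, cand_keys, cand_labels, k):
    """Hypothetical single-head retrieval: score, keep top-k, softmax, aggregate.

    query:       (batch, d) embedded queries
    cand_keys:   (n_candidates, d) embedded candidates
    cand_labels: (n_candidates, d_out) candidate labels
    """
    d = query.shape[-1]
    # Step 2: scaled dot-product score between every query and every candidate
    scores = (query @ cand_keys.T) / d ** 0.5            # (batch, n_candidates)
    # Step 3: keep only the k highest-scoring candidates per query
    k = min(k, cand_keys.shape[0])
    top_scores, top_idx = scores.topk(k, dim=-1)          # (batch, k)
    # Step 4: softmax over the selected candidates, then a weighted average
    weights = F.softmax(top_scores, dim=-1)               # (batch, k)
    context_keys = (weights.unsqueeze(-1) * cand_keys[top_idx]).sum(dim=1)
    context_labels = (weights.unsqueeze(-1) * cand_labels[top_idx]).sum(dim=1)
    return context_keys, context_labels, weights, top_idx


torch.manual_seed(0)
query = torch.randn(4, 8, requires_grad=True)
cand_keys = torch.randn(100, 8)
cand_labels = torch.randn(100, 1)

ctx_k, ctx_y, w, idx = soft_topk_retrieve(query, cand_keys, cand_labels, k=16)
ctx_y.sum().backward()  # gradients reach the query through the softmax weights
print(ctx_k.shape, ctx_y.shape, w.sum(dim=-1))
```

Because the labels are averaged with the same weights as the keys, the context carries both feature and label information from the selected neighbors.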

In [None]:
# Generate synthetic training data
torch.manual_seed(42)
n_samples = 500

# Create data with a non-linear pattern
X_train = torch.randn(n_samples, 10)
y_train = (
    2 * torch.sin(X_train[:, 0] * 2) +
    X_train[:, 1] ** 2 +
    X_train[:, 2] * X_train[:, 3] +
    0.5 * torch.randn(n_samples)  # noise
).unsqueeze(1)

print(f"Training data: X={X_train.shape}, y={y_train.shape}")

## 3. Training with Candidate Accumulation

During training, TabR accumulates candidates from training batches. This is done automatically when you pass `y_for_candidates`:

In [None]:
# Set up training
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

# Training loop
model.train()
losses = []
batch_size = 32

for epoch in range(50):
    epoch_loss = 0
    
    # Random batches
    perm = torch.randperm(n_samples)
    for i in range(0, n_samples, batch_size):
        idx = perm[i:i+batch_size]
        x_batch = X_train[idx]
        y_batch = y_train[idx]
        
        optimizer.zero_grad()
        
        # Pass y_for_candidates to accumulate candidates during training
        pred = model(x_batch, y_for_candidates=y_batch)
        
        loss = criterion(pred, y_batch)
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    
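    # Average over the true number of batches; n_samples // batch_size
    # would ignore the final partial batch (16 batches here, not 15)
    n_batches = (n_samples + batch_size - 1) // batch_size
    losses.append(epoch_loss / n_batches)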
    
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}: Loss = {losses[-1]:.4f}")

# Plot training loss
plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.title('Training Loss')
plt.grid(True)
plt.show()
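
How `max_candidates` bounds the stored examples is not shown in this tutorial. One plausible scheme is a first-in, first-out buffer that keeps only the most recent rows. The class below is a hypothetical sketch of that idea; the name `CandidateBuffer` and its `add` method are illustrative and are not part of `models.TabR`.

```python
import torch


class CandidateBuffer:
    """Hypothetical FIFO store for candidate features and labels."""

    def __init__(self, max_candidates):
        self.max_candidates = max_candidates
        self.x = None
        self.y = None

    @torch.no_grad()
    def add(self, x_batch, y_batch):
        # Detach so stored candidates do not keep the autograd graph alive
        x_batch, y_batch = x_batch.detach(), y_batch.detach()
        if self.x is None:
            self.x, self.y = x_batch, y_batch
        else:
            self.x = torch.cat([self.x, x_batch])
            self.y = torch.cat([self.y, y_batch])
        # Drop the oldest rows once the buffer exceeds its capacity
        self.x = self.x[-self.max_candidates:]
        self.y = self.y[-self.max_candidates:]

    def __len__(self):
        return 0 if self.x is None else self.x.shape[0]


buffer = CandidateBuffer(max_candidates=100)
for step in range(5):
    # Label every row with its step so the surviving rows are easy to identify
    buffer.add(torch.randn(32, 10), torch.full((32, 1), float(step)))

print(len(buffer))          # 100: five batches of 32 exceed the cap
print(buffer.y[0].item())   # 1.0: the oldest surviving rows come from step 1
```

Under a scheme like this, neighbor indices refer to positions in the buffer rather than to rows of the original training array, which matters when mapping neighbors back to training examples.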

## 4. Interpretability: Finding Similar Examples

One of TabR's key advantages is interpretability. We can see which training examples influenced a prediction:

In [None]:
# Switch to eval mode
model.eval()

# Create a test sample
x_test = torch.randn(5, 10)

# Make prediction
with torch.no_grad():
    pred = model(x_test)
    
print("Predictions:", pred.squeeze().numpy())

In [None]:
# Get nearest neighbors for interpretability
k = 5
indices, distances, neighbor_labels = model.get_nearest_neighbors(x_test, k=k)

print(f"\nFor each test sample, showing {k} nearest neighbors:")
print("="*60)

for i in range(len(x_test)):
    print(f"\nTest sample {i}: Predicted = {pred[i].item():.3f}")
    print(f"  Neighbors (labels): {neighbor_labels[i].numpy()}")
    print(f"  Mean neighbor label: {neighbor_labels[i].mean():.3f}")
    print(f"  Distances: {distances[i].numpy()[:3]}...")  # Show first 3
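
A simple sanity check is to compare these neighbors with plain Euclidean k-NN in the raw feature space. If the two neighbor sets largely coincide, the learned embeddings may be adding little beyond distance in the inputs. The helper below is a sketch that relies only on `torch.cdist` and `torch.topk`; `raw_knn` is a name introduced for this example.

```python
import torch


def raw_knn(x_query, x_candidates, y_candidates, k):
    """Plain Euclidean k-NN in input space (illustrative baseline)."""
    dists = torch.cdist(x_query, x_candidates)             # (n_query, n_candidates)
    # largest=False returns the k smallest distances, sorted ascending
    top_dists, top_idx = dists.topk(k, dim=-1, largest=False)
    return top_idx, top_dists, y_candidates[top_idx]


torch.manual_seed(0)
x_cand = torch.randn(200, 10)
y_cand = x_cand[:, :1] ** 2
# Queries are slightly perturbed copies of the first three candidates
x_query = x_cand[:3] + 0.01 * torch.randn(3, 10)

idx, dists, labels = raw_knn(x_query, x_cand, y_cand, k=5)
print(idx[:, 0])              # index of the nearest candidate for each query
print(labels.mean(dim=1))     # k-NN regression estimate per query
```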

## 5. Visualizing the Embedding Space

Let's visualize how TabR embeds examples and finds neighbors:

In [None]:
# Get embeddings for visualization
from sklearn.decomposition import PCA

with torch.no_grad():
    # Embed training data
    train_emb = model.embeddings(X_train).view(n_samples, -1).numpy()
    test_emb = model.embeddings(x_test).view(len(x_test), -1).numpy()

# Reduce to 2D for visualization
pca = PCA(n_components=2)
all_emb = np.vstack([train_emb, test_emb])
all_2d = pca.fit_transform(all_emb)

train_2d = all_2d[:n_samples]
test_2d = all_2d[n_samples:]

# Plot
fig, ax = plt.subplots(figsize=(10, 8))

# Plot training points colored by label
scatter = ax.scatter(
    train_2d[:, 0], train_2d[:, 1],
    c=y_train.squeeze().numpy(),
    cmap='viridis',
    alpha=0.6,
    s=30,
    label='Training'
)
plt.colorbar(scatter, label='Target value')

# Plot test points
ax.scatter(
    test_2d[:, 0], test_2d[:, 1],
    c='red',
    marker='*',
    s=200,
    edgecolors='black',
    label='Test queries'
)

# Draw lines to neighbors for first test point
test_idx = 0
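# Assumes `indices` refer to rows of X_train; if the model indexes its own
# candidate store instead, map them back to training rows first
for neighbor_idx in indices[test_idx].tolist():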
    ax.plot(
        [test_2d[test_idx, 0], train_2d[neighbor_idx, 0]],
        [test_2d[test_idx, 1], train_2d[neighbor_idx, 1]],
        'r--', alpha=0.3
    )

ax.set_xlabel('PCA Component 1')
ax.set_ylabel('PCA Component 2')
ax.set_title('TabR Embedding Space (2D PCA projection)')
ax.legend()
plt.tight_layout()
plt.show()

## 6. Trading Application: Market Regime Detection

TabR is particularly useful for identifying similar historical market conditions. Here's a conceptual example:

In [None]:
# Simulate market features
np.random.seed(42)
n_days = 1000

# Features: [volatility, momentum, volume, spread, ...]
market_features = np.column_stack([
    np.random.exponential(0.02, n_days),  # volatility
    np.random.randn(n_days) * 0.01,       # momentum
    np.random.lognormal(0, 0.5, n_days),  # volume
    np.random.exponential(0.001, n_days), # spread
    np.random.randn(n_days, 6) * 0.1      # other features
])

# Simulate returns (depends on features with noise)
returns = (
    0.1 * market_features[:, 1] -         # momentum effect
    0.5 * market_features[:, 0] +         # volatility drag
    0.02 * np.random.randn(n_days)        # noise
)

# Create TabR model for market prediction
market_model = TabR(
    d_in=10,
    d_out=1,
    d_embedding=16,
    k_neighbors=20,
    max_candidates=500,
)

X_market = torch.tensor(market_features, dtype=torch.float32)
y_market = torch.tensor(returns, dtype=torch.float32).unsqueeze(1)

print(f"Market data: {n_days} days, {market_features.shape[1]} features")
print(f"Return range: [{returns.min():.4f}, {returns.max():.4f}]")

In [None]:
# Train on historical data
market_model.train()
optimizer = torch.optim.Adam(market_model.parameters(), lr=0.001)

train_size = 800
X_hist = X_market[:train_size]
y_hist = y_market[:train_size]

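criterion = nn.MSELoss()  # build the loss once instead of on every step

# Full-batch training: one optimizer step per epoch
for epoch in range(30):
    optimizer.zero_grad()
    pred = market_model(X_hist, y_for_candidates=y_hist)
    loss = criterion(pred, y_hist)
    loss.backward()
    optimizer.step()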

print(f"Training complete. Final loss: {loss.item():.6f}")

In [None]:
# Analyze current market conditions
market_model.eval()
current_day = X_market[850:851]  # Pick a test day

with torch.no_grad():
    prediction = market_model(current_day)
    indices, distances, similar_returns = market_model.get_nearest_neighbors(current_day, k=10)

print("="*60)
print("MARKET REGIME ANALYSIS")
print("="*60)
print(f"\nPredicted return: {prediction.item():.4f}")
print(f"Actual return: {y_market[850].item():.4f}")
print(f"\nSimilar historical days (returns): {similar_returns.squeeze().numpy()[:5]}")
print(f"Average return in similar conditions: {similar_returns.mean():.4f}")
print(f"Std of returns in similar conditions: {similar_returns.std():.4f}")

# This information can be used for:
# 1. Confidence estimation (low std = high confidence)
# 2. Risk management (check similar days for drawdowns)
# 3. Explainability ("prediction based on these similar days")
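
The first of these uses, confidence estimation, can be sketched as follows. The mapping from dispersion to a score in (0, 1] is an arbitrary choice for illustration, and both `neighbor_confidence` and its `scale` parameter are hypothetical. With the tensors from the cell above, one would pass `similar_returns.flatten().tolist()`.

```python
import statistics


def neighbor_confidence(neighbor_returns, scale=0.02):
    """Map the dispersion of neighbor returns to a score in (0, 1].

    scale is a hypothetical reference volatility: a standard deviation
    equal to scale gives a confidence of 0.5.
    """
    mean = statistics.fmean(neighbor_returns)
    std = statistics.stdev(neighbor_returns)   # sample standard deviation
    confidence = 1.0 / (1.0 + std / scale)
    return mean, std, confidence


# Two made-up neighbor sets: one tightly clustered, one widely spread
tight = [0.010, 0.011, 0.009, 0.010, 0.012]
wide = [0.050, -0.040, 0.020, -0.030, 0.060]

for name, rets in [("tight", tight), ("wide", wide)]:
    mean, std, conf = neighbor_confidence(rets)
    print(f"{name}: mean={mean:.4f} std={std:.4f} confidence={conf:.2f}")
```

The tightly clustered set receives the higher score, which matches the "low std = high confidence" rule of thumb.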

## 7. Key Takeaways

1. **TabR combines learned representations with k-NN-style retrieval**: The network learns the embedding space in which neighbors are found

2. **Interpretable**: You can always explain predictions by showing similar training examples

3. **Works best with moderate dimensions**: For very high-dimensional data (50+ features), consider dimensionality reduction first

4. **Trading applications**:
   - Market regime detection
   - Confidence estimation from neighbor variance
   - Risk management via historical analogs
   - Regulatory-compliant explainability

5. **Limitations**:
   - Memory scales with number of candidates
   - Retrieval is slower for high-dimensional embeddings
   - Curse of dimensionality affects k-NN component
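
To make the memory limitation concrete, here is a back-of-the-envelope estimate. It assumes each candidate is stored as float32 values, `d_in * d_embedding` for the embedded features plus `d_out` for the label, which may not match how `models.TabR` actually lays out its buffers; `candidate_memory_mb` is a helper written for this example.

```python
def candidate_memory_mb(n_candidates, d_in, d_embedding, d_out=1, bytes_per_value=4):
    """Rough size of a candidate store under the stated float32 assumption."""
    values_per_candidate = d_in * d_embedding + d_out
    return n_candidates * values_per_candidate * bytes_per_value / 1024 ** 2


# Same feature and embedding sizes as the first model in this tutorial
for n in (1_000, 100_000, 1_000_000):
    mb = candidate_memory_mb(n, d_in=10, d_embedding=24)
    print(f"{n:>9,} candidates: {mb:8.1f} MB")
```

Under this assumption the footprint grows linearly with the number of candidates, so the estimate is mainly useful for choosing `max_candidates`.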