# Tutorial: iLTM - Integrated Large Tabular Model

This tutorial explains the key ideas from the **iLTM** paper and demonstrates how to use our simplified implementation.

**Paper:** [arXiv:2511.15941](https://arxiv.org/abs/2511.15941)  
**Authors:** David Bonet, Marçal Comajoan Cara, Alvaro Calafell, Daniel Mas Montserrat, Alexander G. Ioannidis  
**Venue:** arXiv 2025 (Stanford & UC Santa Cruz)

## Overview

iLTM is a **tabular foundation model** that integrates multiple approaches:

1. **Tree-Derived Embeddings**: Uses GBDT leaf indices as features
2. **Dimensionality-Agnostic Representations**: Random features + PCA for consistent embedding sizes
3. **Meta-Trained Hypernetwork**: Generates MLP weights from training data
4. **Retrieval-Augmented Predictions**: Soft k-NN blended with MLP output

The key insight is that **tree-based and neural methods are complementary**:
- GBDTs excel at capturing discrete feature interactions and are robust to uninformative features
- MLPs learn smooth functions and benefit from gradient-based optimization
- Retrieval provides local adaptivity and implicit ensembling

In [None]:
import sys
sys.path.insert(0, '..')  # make the local `models` package (one directory up) importable

import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_friedman1
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingRegressor

from models import iLTM, create_iltm
from models.iltm import TreeEmbedding, RandomFeatureProjection

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

## Part 1: Understanding GBDT Leaf Embeddings

The first key idea in iLTM is to use **GBDT leaf indices as embeddings**.

When a GBDT makes a prediction, each sample falls into exactly one leaf in each tree. By one-hot encoding these leaf indices, we get a sparse binary representation that captures:

- **Non-linear feature interactions** discovered by the tree
- **Robust feature selection** (uninformative features rarely appear in splits)
- **Discrete patterns** (regime-like behavior)
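
The encoding step itself takes only a few lines of NumPy. This is a minimal illustration under stated assumptions, not the `TreeEmbedding` class used below. It assumes you already have an `(n_samples, n_trees)` array of leaf ids, such as the one scikit-learn's `GradientBoostingRegressor.apply(X)` returns. Those ids are node numbers inside each tree, not contiguous leaf positions, so the sketch first builds a per-tree vocabulary from training data and then one-hot encodes each tree's column. The helper names `fit_leaf_vocab` and `leaf_indices_to_onehot` are hypothetical.

```python
import numpy as np


def fit_leaf_vocab(leaf_ids):
    """Record, for each tree (column), the sorted leaf ids seen in the training data."""
    leaf_ids = np.asarray(leaf_ids)
    return [np.unique(leaf_ids[:, t]) for t in range(leaf_ids.shape[1])]


def leaf_indices_to_onehot(leaf_ids, vocab):
    """One-hot encode leaf ids tree by tree and concatenate the per-tree blocks.

    leaf_ids: (n_samples, n_trees) array, one column of leaf/node ids per tree.
    vocab:    output of fit_leaf_vocab (one sorted id array per tree).
    Returns an (n_samples, total number of leaves) array of 0s and 1s.
    """
    leaf_ids = np.asarray(leaf_ids)
    rows = np.arange(leaf_ids.shape[0])
    blocks = []
    for t, leaves in enumerate(vocab):
        block = np.zeros((leaf_ids.shape[0], len(leaves)))
        # Position of each sample's leaf id inside this tree's vocabulary
        pos = np.clip(np.searchsorted(leaves, leaf_ids[:, t]), 0, len(leaves) - 1)
        known = leaves[pos] == leaf_ids[:, t]  # ids unseen during fitting leave the block all-zero
        block[rows[known], pos[known]] = 1.0
        blocks.append(block)
    return np.hstack(blocks)


# Toy leaf ids for 4 samples and 2 trees (non-contiguous, like sklearn node ids)
leaf_ids = np.array([[3, 7],
                     [4, 7],
                     [3, 9],
                     [6, 10]])
vocab = fit_leaf_vocab(leaf_ids)
emb = leaf_indices_to_onehot(leaf_ids, vocab)
print(emb.shape)  # (4, 6): 3 distinct leaves in tree 0, 3 in tree 1
print(emb.sum(axis=1))  # one active leaf per tree, so every row sums to 2
print(f"zeros: {(emb == 0).mean():.2%}")
```

Every sample activates exactly one leaf per tree, so the fraction of zeros is 1 − T/D for T trees and D total leaves. Assuming `max_depth=4` caps each tree at 2^4 = 16 leaves, the 20-tree configuration used in this part has D ≤ 320. At least 20/320 = 6.25% of the entries are therefore ones.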

In [None]:
# Create a dataset with non-linear interactions
X, y = make_friedman1(n_samples=1000, n_features=10, noise=0.5, random_state=42)

# Fit a GBDT and extract leaf embeddings
tree_emb = TreeEmbedding(n_estimators=20, max_depth=4, task='regression')
tree_emb.fit(X, y)

# Get embeddings
embeddings = tree_emb.transform(X)

print(f"Input shape: {X.shape}")
print(f"Embedding shape: {embeddings.shape}")
print(f"Embedding dimension: {tree_emb.embedding_dim} (sum of leaves across all trees)")
print(f"Sparsity: {(embeddings == 0).mean():.2%} zeros")

In [None]:
# Visualize the sparse embeddings for a few samples
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# Show embedding matrix for first 20 samples
ax = axes[0]
im = ax.imshow(embeddings[:20], aspect='auto', cmap='Blues')
ax.set_xlabel('Leaf index (across all trees)')
ax.set_ylabel('Sample')
ax.set_title('GBDT Leaf Embeddings (One-Hot Encoded)')
plt.colorbar(im, ax=ax)

# Show the number of leaves per tree
ax = axes[1]
ax.bar(range(len(tree_emb.n_leaves_per_tree)), tree_emb.n_leaves_per_tree)
ax.set_xlabel('Tree index')
ax.set_ylabel('Number of leaves')
ax.set_title('Leaves per Tree in GBDT')

plt.tight_layout()
plt.show()

## Part 2: Dimensionality-Agnostic Representation

Different datasets have different numbers of features. To build a foundation model, we need a **fixed-size representation** regardless of input dimension.

iLTM achieves this with:
1. **Random Feature Expansion**: Project to a high dimension with a random matrix followed by a nonlinearity (random features that approximate an arc-cosine kernel)
2. **PCA Reduction**: Reduce to a fixed dimension (512 in the paper; 64 in the demo below)

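The following sketch is a quick numerical check of the arc-cosine connection. It compares a random-feature estimate against the closed form. It assumes ReLU features with standard Gaussian weights, for which E[ReLU(w·x) ReLU(w·y)] = ‖x‖‖y‖(sin θ + (π − θ) cos θ) / (2π), with θ the angle between x and y. This is half of the first-order arc-cosine kernel as usually normalized. The tutorial's `RandomFeatureProjection` may use a different nonlinearity or scaling, and the function names here are hypothetical.

```python
import numpy as np


def arccos_expected_product(x, y):
    """Closed form of E[relu(w @ x) * relu(w @ y)] for w ~ N(0, I),
    i.e. the first-order arc-cosine kernel (usual normalization) divided by 2."""
    nx, ny = np.linalg.norm(x), np.linalg.norm(y)
    cos_t = np.clip(x @ y / (nx * ny), -1.0, 1.0)
    theta = np.arccos(cos_t)
    return nx * ny * (np.sin(theta) + (np.pi - theta) * cos_t) / (2 * np.pi)


def random_relu_features(X, n_features, rng):
    """Map each row of X to relu(X @ W) / sqrt(n_features) with Gaussian W."""
    W = rng.standard_normal((X.shape[1], n_features))
    return np.maximum(X @ W, 0.0) / np.sqrt(n_features)


rng = np.random.default_rng(0)
x, y = rng.standard_normal(5), rng.standard_normal(5)
phi = random_relu_features(np.stack([x, y]), n_features=400_000, rng=rng)
print("random-feature estimate:", phi[0] @ phi[1])  # Monte Carlo average over 400k features
print("closed form:            ", arccos_expected_product(x, y))
```

The inner product of the feature vectors is a sample mean over random directions w. It converges to the kernel value as the number of features grows, which is why a wide random expansion is used before reducing the dimension.
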
In [None]:
# Demonstrate dimensionality-agnostic projection
proj = RandomFeatureProjection(d_out=64, n_random_features=1024)

# Convert embeddings to tensor
X_tensor = torch.tensor(embeddings, dtype=torch.float32)

# Fit the projection
proj.fit(X_tensor)

# Transform to fixed dimension
fixed_emb = proj(X_tensor)

print(f"Input shape: {X_tensor.shape}")
print(f"Output shape: {fixed_emb.shape} (always d_out={proj.d_out} regardless of input dim)")
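
The following hypothetical NumPy sketch makes the shape bookkeeping concrete by showing what a projection like `RandomFeatureProjection` could do. It applies a random ReLU expansion whose weight matrix adapts to the input width, then uses PCA (via SVD) to reduce to `d_out` columns, whitened to unit variance. The class name, the ReLU, the `1/sqrt(d_in)` weight scaling, and the whitening are all assumptions, not the tutorial's implementation. Note that PCA needs at least `d_out` samples to produce `d_out` components.

```python
import numpy as np


class FixedDimProjector:
    """Hypothetical sketch: random ReLU features followed by PCA to d_out dims."""

    def __init__(self, d_out=64, n_random_features=1024, seed=0):
        self.d_out = d_out
        self.n_random_features = n_random_features
        self.rng = np.random.default_rng(seed)

    def fit(self, X):
        X = np.asarray(X, dtype=float)
        # Random expansion: W's shape depends on the input width, the output width does not
        self.W = self.rng.standard_normal((X.shape[1], self.n_random_features)) / np.sqrt(X.shape[1])
        H = np.maximum(X @ self.W, 0.0)
        self.mean = H.mean(axis=0)
        # PCA via SVD of the centered features; keep the top d_out principal directions
        _, S, Vt = np.linalg.svd(H - self.mean, full_matrices=False)
        self.components = Vt[: self.d_out]                   # (d_out, n_random_features)
        self.scale = S[: self.d_out] / np.sqrt(len(X) - 1)   # std of each component
        return self

    def transform(self, X):
        H = np.maximum(np.asarray(X, dtype=float) @ self.W, 0.0)
        Z = (H - self.mean) @ self.components.T
        return Z / np.maximum(self.scale, 1e-12)  # whiten: unit variance per component


rng = np.random.default_rng(1)
for width in (37, 320):
    X = rng.standard_normal((500, width))
    Z = FixedDimProjector(d_out=64).fit(X).transform(X)
    print(width, "->", Z.shape)
```

Only `W` depends on the input width, so the same recipe yields `d_out` columns whether the tree embedding has 37 or 320 leaves. Transforming new data reuses the fitted `W`, mean, and components, which is why the projection has to be fit per dataset.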

In [None]:
# Visualize the projection
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# Original sparse embeddings
ax = axes[0]
ax.hist(embeddings.flatten(), bins=3, edgecolor='black')
ax.set_xlabel('Value')
ax.set_ylabel('Count')
ax.set_title('GBDT Embeddings (Sparse Binary)')

# Projected dense embeddings
ax = axes[1]
ax.hist(fixed_emb.detach().numpy().flatten(), bins=50, edgecolor='black')
ax.set_xlabel('Value')
ax.set_ylabel('Count')
ax.set_title('Projected Embeddings (Dense, Normalized)')

plt.tight_layout()
plt.show()

## Part 3: Soft Retrieval Module

The retrieval component finds **similar training examples** for each query and aggregates their labels.

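Compared with hard k-NN (an unweighted vote or average over the k nearest points), it:
- Measures **cosine similarity** in the embedding space
- Applies a **softmax** over similarities, so closer neighbors receive smoothly larger weights
- Uses **temperature scaling** to control how sharp that weighting is
- **Blends** its output with the MLP prediction via the α parameter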

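The bullets above can be sketched as a standalone NumPy function. This minimal sketch assumes the module computes cosine similarities, keeps the top `k`, applies a temperature-scaled softmax over those `k` scores, and averages the neighbor labels with the resulting weights. `soft_knn_predict` is a hypothetical name, and the real `SoftRetrievalModule` may differ (for example, in whether the softmax covers all candidates or only the top `k`). The toy queries are noisy copies of the first three candidates, so each has one clear nearest neighbor.

```python
import numpy as np


def soft_knn_predict(query, candidates, labels, k=10, temperature=1.0):
    """Soft k-NN regression: softmax-weighted mean of the labels of the k most cosine-similar candidates."""
    q = query / np.linalg.norm(query, axis=1, keepdims=True)
    c = candidates / np.linalg.norm(candidates, axis=1, keepdims=True)
    sims = q @ c.T                                     # (n_query, n_candidates) cosine similarities
    top = np.argsort(-sims, axis=1)[:, :k]             # indices of the k most similar candidates
    logits = np.take_along_axis(sims, top, axis=1) / temperature
    logits -= logits.max(axis=1, keepdims=True)        # numerical stability
    weights = np.exp(logits)
    weights /= weights.sum(axis=1, keepdims=True)      # softmax over the k neighbors
    return (weights * labels[top]).sum(axis=1), weights


rng = np.random.default_rng(0)
cands = rng.standard_normal((200, 16))
labels = 2.0 * cands[:, 0]                             # toy regression target
queries = cands[:3] + 0.05 * rng.standard_normal((3, 16))

for tau in (1.0, 0.1, 0.01):
    pred, w = soft_knn_predict(queries, cands, labels, k=10, temperature=tau)
    print(f"tau={tau}: max weight per query = {w.max(axis=1).round(3)}")
```

Lowering the temperature pushes the largest weight toward 1, which recovers hard 1-NN in the limit. Raising it spreads the weight more evenly across the `k` neighbors.
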
In [None]:
from models.iltm import SoftRetrievalModule

# Create retrieval module
retrieval = SoftRetrievalModule(
    d_embedding=64,
    temperature=1.0,
    k_neighbors=10
)

# Use projected embeddings as our representation
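# Hold out the first 5 samples as queries so they cannot simply retrieve themselves.
# (The GBDT and projection were still fit on them, so this remains an optimistic check.)
query_embeddings = fixed_emb[:5]  # (5, 64)
candidate_embeddings = fixed_emb[5:]  # (995, 64)
candidate_labels = torch.tensor(y[5:], dtype=torch.float32)  # (995,)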

# Get retrieval predictions
retrieval_output = retrieval(
    query=query_embeddings,
    candidates=candidate_embeddings,
    candidate_labels=candidate_labels,
    n_classes=None  # Regression mode
)

print(f"Query shape: {query_embeddings.shape}")
print(f"Retrieval output shape: {retrieval_output.shape}")
print(f"\nRetrieval predictions vs actual:")
for i in range(5):
    print(f"  Sample {i}: predicted={retrieval_output[i].item():.3f}, actual={y[i]:.3f}")

## Part 4: Full iLTM Model

Now let's see the complete iLTM model in action. The model:

1. **Setup phase**: Fits GBDT and random projection on training data
2. **Forward pass**: Computes embeddings → conditioning → MLP → retrieval blend

In [None]:
# Create train/test split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Convert to tensors
X_train_t = torch.tensor(X_train, dtype=torch.float32)
y_train_t = torch.tensor(y_train, dtype=torch.float32)
X_test_t = torch.tensor(X_test, dtype=torch.float32)
y_test_t = torch.tensor(y_test, dtype=torch.float32)

print(f"Training set: {X_train_t.shape}")
print(f"Test set: {X_test_t.shape}")

In [None]:
# Create iLTM model
model = iLTM(
    d_in=10,
    d_out=1,
    d_main=128,
    n_blocks=2,
    use_tree_embedding=True,
    n_estimators=50,
    retrieval_alpha=0.3,  # 30% retrieval, 70% MLP
    k_neighbors=20,
)

# Setup the model (fits GBDT and projection)
model.setup(X_train_t, y_train_t)
print("Model setup complete!")

# Set candidates for retrieval
model.set_candidates(X_train_t, y_train_t)
print(f"Set {len(y_train_t)} candidates for retrieval")

# Count parameters
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {n_params:,}")

In [None]:
# Make predictions (without training the MLP)
model.eval()
with torch.no_grad():
    y_pred = model(X_test_t)

# Calculate metrics
mse = ((y_pred.flatten() - y_test_t) ** 2).mean().item()
rmse = np.sqrt(mse)

print(f"Test RMSE (before training the MLP): {rmse:.4f}")

In [None]:
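# Train the model briefly (full batch, so each "epoch" is a single gradient step)
# Note: the retrieval candidates are the training samples themselves, so unless iLTM
# excludes self-matches, each sample can retrieve its own label during training.
import torch.optim as optim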

model.train()
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
criterion = torch.nn.MSELoss()

losses = []
for epoch in range(100):
    optimizer.zero_grad()
    y_pred = model(X_train_t)
    loss = criterion(y_pred.flatten(), y_train_t)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    
    if (epoch + 1) % 20 == 0:
        print(f"Epoch {epoch+1}: loss = {loss.item():.4f}")

In [None]:
# Evaluate after training
model.eval()
with torch.no_grad():
    y_pred = model(X_test_t)

mse = ((y_pred.flatten() - y_test_t) ** 2).mean().item()
rmse = np.sqrt(mse)

print(f"Test RMSE (after training): {rmse:.4f}")

# Plot training loss
plt.figure(figsize=(8, 4))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.title('iLTM Training Progress')
plt.grid(True, alpha=0.3)
plt.show()

## Part 5: Interpretability via Nearest Neighbors

One advantage of retrieval-augmented models is **interpretability**. We can see which training examples influenced each prediction.
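
Assume the retrieval prediction is a temperature-scaled softmax average over the returned neighbors. Each neighbor's influence can then be read off as its weight times its label, and these terms sum to the retrieval part of the prediction. The sketch below is hypothetical: `neighbor_contributions` is not part of the tutorial's API, the temperature must match whatever the model actually uses, and the similarities and labels are made-up illustrative values, not outputs from the model above.

```python
import numpy as np


def neighbor_contributions(similarities, neighbor_labels, temperature=1.0):
    """Split a soft-retrieval prediction into per-neighbor terms w_j * y_j.

    similarities, neighbor_labels: (k,) arrays for one query, e.g. one row of the
    `similarities` and `neighbor_labels` returned by `model.get_nearest_neighbors`.
    Assumes the weights are a temperature-scaled softmax over these k similarities.
    """
    logits = np.asarray(similarities, dtype=float) / temperature
    w = np.exp(logits - logits.max())
    w /= w.sum()
    contrib = w * np.asarray(neighbor_labels, dtype=float)
    return w, contrib, contrib.sum()  # the terms add up to the retrieval prediction


# Illustrative inputs for a single query (not produced by the model above)
sims = np.array([0.98, 0.95, 0.90, 0.70, 0.60])
labs = np.array([14.2, 13.8, 15.1, 9.5, 20.0])
w, contrib, pred = neighbor_contributions(sims, labs, temperature=0.1)
for rank, j in enumerate(np.argsort(-w), start=1):
    print(f"rank {rank}: weight={w[j]:.3f}, label={labs[j]:.1f}, contribution={contrib[j]:.3f}")
print(f"retrieval prediction = {pred:.3f}")
```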

In [None]:
# Get nearest neighbors for test samples
model.eval()
indices, similarities, neighbor_labels = model.get_nearest_neighbors(X_test_t[:5], k=5)

print("Nearest neighbors for first 5 test samples:")
print("="*60)

for i in range(5):
    print(f"\nTest sample {i} (actual y = {y_test[i]:.3f})")
    print(f"  Nearest neighbors (from training set):")
    for j in range(5):
        idx = indices[i, j].item()
        sim = similarities[i, j].item()
        label = neighbor_labels[i, j].item()
        print(f"    Neighbor {j+1}: idx={idx:4d}, similarity={sim:.3f}, y={label:.3f}")

## Part 6: Effect of Retrieval Alpha

The `retrieval_alpha` parameter controls the blend between MLP and retrieval:
- α = 0: Pure MLP (no retrieval)
- α = 1: Pure retrieval (no MLP)
- α = 0.3: 30% retrieval, 70% MLP (the value used to build and train the model above)

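Assume α only enters through the final blend ŷ(α) = (1 − α)·ŷ_MLP + α·ŷ_ret. For fixed MLP and retrieval outputs, the test MSE is then a convex quadratic in α. The sweep below should therefore show a single valley, or a monotone curve if the best α sits at 0 or 1. Setting the derivative to zero gives α* = ⟨y − ŷ_MLP, ŷ_ret − ŷ_MLP⟩ / ‖ŷ_ret − ŷ_MLP‖², clipped to [0, 1]. The NumPy sketch below computes this on synthetic predictions. `blend` and `best_alpha` are hypothetical helpers, and applying them to iLTM would require the separate MLP and retrieval outputs, which the tutorial's `iLTM` class may not expose.

```python
import numpy as np


def blend(y_mlp, y_ret, alpha):
    """Convex blend: alpha=0 is pure MLP, alpha=1 is pure retrieval."""
    return (1.0 - alpha) * y_mlp + alpha * y_ret


def best_alpha(y_mlp, y_ret, y_true):
    """Alpha in [0, 1] minimizing the blend's MSE (closed-form 1-D least squares)."""
    d = y_ret - y_mlp
    denom = d @ d
    if denom == 0.0:
        return 0.0  # the two predictors agree, so every alpha gives the same error
    return float(np.clip((y_true - y_mlp) @ d / denom, 0.0, 1.0))


def rmse_against(y_true):
    def rmse(pred):
        return float(np.sqrt(np.mean((pred - y_true) ** 2)))
    return rmse


rng = np.random.default_rng(0)
y_true = rng.standard_normal(500)
y_mlp = y_true + 0.5 * rng.standard_normal(500)  # toy predictor with independent errors
y_ret = y_true + 0.5 * rng.standard_normal(500)  # second toy predictor, equally noisy
rmse = rmse_against(y_true)

for a in (0.0, 0.5, 1.0):
    print(f"alpha={a:.1f}: RMSE={rmse(blend(y_mlp, y_ret, a)):.3f}")
a_star = best_alpha(y_mlp, y_ret, y_true)
print(f"best alpha = {a_star:.3f}, RMSE = {rmse(blend(y_mlp, y_ret, a_star)):.3f}")
```

Here the two predictors are equally noisy and their errors are independent, so the optimum lands near α = 0.5. This shows why blending two complementary predictors can beat either one alone.
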
In [None]:
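# Test different alpha values at inference time.
# The MLP was trained with alpha = 0.3, so this sweeps the blend for a fixed MLP
# rather than retraining the model for each alpha.
alphas = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
results = []

model.eval()
for alpha in alphas:
    model.retrieval_alpha = alpha
    with torch.no_grad():
        y_pred = model(X_test_t)
    rmse = np.sqrt(((y_pred.flatten() - y_test_t) ** 2).mean().item())
    results.append(rmse)
    print(f"α = {alpha:.1f}: Test RMSE = {rmse:.4f}")

# Restore the alpha the model was trained with
model.retrieval_alpha = 0.3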

# Plot
plt.figure(figsize=(8, 4))
plt.plot(alphas, results, 'bo-', linewidth=2, markersize=8)
plt.xlabel('Retrieval Alpha (α)')
plt.ylabel('Test RMSE')
plt.title('Effect of Retrieval Weight on Performance')
plt.grid(True, alpha=0.3)
plt.show()

## Summary: Key Ideas from iLTM

### What We Learned

1. **Tree Embeddings**: GBDT leaf indices capture discrete patterns and feature interactions
2. **Dimensionality-Agnostic**: Random features + PCA create fixed-size representations
3. **Soft Retrieval**: k-NN with temperature-scaled softmax enables smooth blending
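4. **Complementary Methods**: Trees (discrete interactions, robustness to uninformative features), MLPs (smooth functions), and retrieval (local adaptivity) contribute different strengths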

### When to Use iLTM

✅ **Good for:**
- Datasets where GBDTs typically excel (structured/tabular)
- Finding similar historical patterns (e.g., market regimes)
- When interpretability matters (can inspect neighbors)
- Varying input dimensions across datasets

⚠️ **Limitations of this implementation:**
- No pretrained hypernetwork (the key innovation of the paper)
- Requires setup phase for each new dataset
- Tree fitting adds overhead

For the full pretrained model, see: https://github.com/AI-sandbox/iLTM