# Contrastive Table-to-Graph Learning Pipeline
## Table-Question Alignment with InfoNCE Loss

This notebook trains a contrastive learning model that:
- Encodes table structures as graph embeddings using GNN
- Encodes natural language questions using SentenceTransformer  
- Aligns table graphs with matching questions using InfoNCE loss
- Enables semantic table retrieval given a question

**UPDATED**: Column names are now embedded with the shared question encoder, so table-graph node features and questions start out in the same semantic space.

## 1. Setup & Installation

In [None]:
# Install dependencies
!pip install -q sentence-transformers torch-geometric scikit-learn tqdm

print("✓ Dependencies installed")

In [None]:
# Mount Google Drive for checkpoint storage
from google.colab import drive
drive.mount('/content/drive')

import os
CHECKPOINT_DIR = '/content/drive/MyDrive/contrastive_table2graph_checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

print(f"✓ Checkpoints will be saved to: {CHECKPOINT_DIR}")

In [None]:
# Verify required files are uploaded
import os

required_files = ['contrastive_table2graph.py', 'gcn_conv.py']
for f in required_files:
    if os.path.exists(f'/content/{f}'):
        print(f"✓ {f} found")
    else:
        print(f"✗ {f} MISSING - please upload!")

# Check for data directory
if os.path.exists('/content/data'):
    csv_count = len([f for f in os.listdir('/content/data') if f.endswith('.csv')])
    print(f"✓ data/ folder found with {csv_count} CSV files")
else:
    print("✗ data/ folder MISSING - please upload CSV files!")

## 2. Import Pipeline Components

In [None]:
import sys
sys.path.append('/content')

import torch
import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
import json

# Import pipeline components
from contrastive_table2graph import (
    DataProcessor,
    ColumnContentExtractor,
    LightweightFeatureTokenizer,
    RelationshipGenerator,
    SemanticLabelGenerator,
    GraphBuilder,
    QuestionEncoder,
    ContrastiveGNNEncoder,
    AttentionPooling,
    InfoNCELoss,
    TableSpecificQuestionGenerator,  # UPDATED: New question generator
    TableQuestionDataset,
    collate_fn
)

from torch.utils.data import DataLoader, random_split

print("✓ Pipeline components imported successfully")

In [None]:
# Set random seed for reproducibility
def set_seed(seed=42):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)
print("✓ Random seed set to 42")

## 3. Load Data

In [None]:
def load_tables_from_directory(data_dir, max_tables=None, max_rows=500):
    """
    Load CSV tables from directory.
    
    Args:
        data_dir: Path to directory with CSV files
        max_tables: Max number of tables to load (None = all)
        max_rows: Max rows per table (for memory management)
    
    Returns:
        List of (table_name, DataFrame) tuples
    """
    data_path = Path(data_dir)
    csv_files = sorted(data_path.glob('*.csv'))
    
    if max_tables is not None:
        csv_files = csv_files[:max_tables]
    
    print(f"Found {len(csv_files)} CSV files\n")
    
    tables = []
    
    for csv_file in csv_files:
        table_name = csv_file.stem
        
        try:
            df = pd.read_csv(csv_file, nrows=max_rows, low_memory=False)
            
            # Skip invalid tables
            if len(df.columns) < 2:
                print(f"  ✗ {table_name}: < 2 columns")
                continue
            
            if len(df) == 0:
                print(f"  ✗ {table_name}: empty table")
                continue
            
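            # Attach the table name for tracking. object.__setattr__ stores it as a
            # plain attribute, whereas `df.name = ...` would overwrite a column called "name".
            object.__setattr__(df, 'name', table_name)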
            
            tables.append((table_name, df))
            print(f"  ✓ {table_name}: {df.shape[0]} × {df.shape[1]}")
            
        except Exception as e:
            print(f"  ✗ {table_name}: {e}")
            continue
    
    return tables

# Load tables
DATA_DIR = '/content/data'
tables = load_tables_from_directory(DATA_DIR, max_tables=50, max_rows=500)

print(f"\n✓ Loaded {len(tables)} tables successfully")
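Assigning a table name with `df.name = ...` is fragile. pandas routes attribute assignment through `DataFrame.__setattr__`, and if the table already has a column literally called `name`, the assignment overwrites that column with the table name instead of attaching metadata. The sketch below reproduces the pitfall on a toy frame and shows two safer options. `tag_table` is a hypothetical helper, and whether the pipeline's question generator reads `df.name` or `df.attrs` is an assumption to check against `contrastive_table2graph.py`.

```python
import pandas as pd

def tag_table(df: pd.DataFrame, table_name: str) -> pd.DataFrame:
    """Attach a table name without touching any data column (hypothetical helper)."""
    # object.__setattr__ bypasses DataFrame.__setattr__, which would otherwise
    # write into a column if one is literally called "name".
    object.__setattr__(df, "name", table_name)
    # attrs is pandas' own metadata dictionary.
    df.attrs["table_name"] = table_name
    return df

# The pitfall: plain attribute assignment overwrites an existing "name" column.
naive = pd.DataFrame({"name": ["alice", "bob"], "age": [30, 40]})
naive.name = "patients"
print(naive["name"].tolist())   # ['patients', 'patients']

safe = tag_table(pd.DataFrame({"name": ["alice", "bob"], "age": [30, 40]}), "patients")
print(safe["name"].tolist())    # ['alice', 'bob']
print(safe.name, safe.attrs["table_name"])  # patients patients
```

Tables without a `name` column are unaffected either way, which is why this kind of bug tends to surface only on a few inputs.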

In [None]:
# Display table statistics
print("Table Statistics:")
print("=" * 70)

total_rows = sum(df.shape[0] for _, df in tables)
total_cols = sum(df.shape[1] for _, df in tables)
num_tables = max(len(tables), 1)  # avoid ZeroDivisionError if no table loaded
avg_rows = total_rows / num_tables
avg_cols = total_cols / num_tables

print(f"Total tables: {len(tables)}")
print(f"Total rows: {total_rows:,}")
print(f"Total columns: {total_cols}")
print(f"Avg rows per table: {avg_rows:.1f}")
print(f"Avg columns per table: {avg_cols:.1f}")

print("\nSample Tables:")
for name, df in tables[:5]:
    print(f"  {name}: {df.shape[0]} × {df.shape[1]} - {list(df.columns[:3])}...")

## 4. Initialize Pipeline Components

In [None]:
print("Initializing pipeline components...\n")

# Device detection
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"  GPU: {torch.cuda.get_device_name(0)}")
    print(f"  Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB\n")

# Core components
content_extractor = ColumnContentExtractor()
print("✓ ColumnContentExtractor")

# Column-name embeddings come from GraphBuilder (use_column_names=True), not from the tokenizer
feature_tokenizer = LightweightFeatureTokenizer(
    embedding_strategy='hybrid'
)
print("✓ LightweightFeatureTokenizer")
print(f"  - Statistical feature dimension: {feature_tokenizer.feature_dim}")

relationship_generator = RelationshipGenerator()
print("✓ RelationshipGenerator")

semantic_label_generator = SemanticLabelGenerator()
print("✓ SemanticLabelGenerator")

# CRITICAL: Initialize question encoder FIRST (will be shared with GraphBuilder)
question_encoder = QuestionEncoder(
    model_name='all-mpnet-base-v2',
    freeze=True
).to(device)
print("\n✓ QuestionEncoder: all-mpnet-base-v2 (frozen)")
print(f"  Output dim: {question_encoder.output_dim}")

# GraphBuilder reuses the question encoder to embed column names, so
# column-name node features live in the same space as question embeddings.
pyg_converter = GraphBuilder(
    content_extractor=content_extractor,
    feature_tokenizer=feature_tokenizer,
    relationship_generator=relationship_generator,
    semantic_label_generator=semantic_label_generator,
    mode='train',
    use_column_names=True,           # UPDATED: Enable column name embeddings
    question_encoder=question_encoder  # UPDATED: Share the question encoder!
)
print("\n✓ GraphBuilder (PyG converter)")
print("  [INFO] Column names embedded with the shared question encoder (same space as questions)")

# UPDATED: Use TableSpecificQuestionGenerator instead of QuestionGenerator
question_generator = TableSpecificQuestionGenerator()
print("\n✓ TableSpecificQuestionGenerator")
print("  - Generates unique questions per table that mention its column names")
print("  - 20 questions per table: 4 column-enumeration + 5 structural + 5 relationship + 3 hybrid + 3 domain")
print("  - 25% (5/20) are relationship questions about inter-column edges used in GNN message passing")

print("\n✓ All components initialized")

## 5. Generate Question-Table Pairs

In [None]:
print("Generating question-table pairs...\n")
print("=" * 70)

# Extract DataFrames
table_dfs = [df for _, df in tables]

# UPDATED: Use new question generator
question_data = question_generator.generate_dataset(
    tables=table_dfs,
    relationship_generator=relationship_generator,
    num_per_table=20  # 20 unique questions per table
)

print(f"\n✓ Generated {len(question_data)} question-table pairs")
print(f"  - All positive pairs (label=1)")
print(f"  - {len(question_data) / len(table_dfs):.1f} questions per table (avg)")
print(f"  - Each question mentions specific column names")
print(f"  - In-batch negatives will be used during training")
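Since `generate_dataset` may produce fewer than `num_per_table` questions for narrow tables (an assumption; the generator's internals are not shown here), it can be worth counting questions per table before training. The hypothetical helpers below group pairs by the `table_name` key the notebook already reads with `.get()`, falling back to the DataFrame's identity when the key is missing. Tables that received zero questions do not appear in the counts, so compare `len(counts)` against `len(tables)` as well.

```python
from collections import Counter

def questions_per_table(question_data):
    """Count question-table pairs per table (hypothetical helper)."""
    counts = Counter()
    for q in question_data:
        # Prefer the table_name key; otherwise identify the table by object identity
        key = q.get("table_name") or f"<table@{id(q['table'])}>"
        counts[key] += 1
    return counts

def report_shortfalls(counts, expected):
    """Return {table: count} for tables with fewer than `expected` questions."""
    return {name: n for name, n in counts.items() if n < expected}

# Toy data standing in for question_data
demo = [{"table_name": "admissions", "table": None}] * 20 + \
       [{"table_name": "icustays", "table": None}] * 17
counts = questions_per_table(demo)
print(report_shortfalls(counts, expected=20))  # {'icustays': 17}
```

On the real data, `report_shortfalls(questions_per_table(question_data), expected=20)` lists any table that fell short.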

In [None]:
# Display sample questions
print("\nSample Questions:")
print("=" * 70)

for i, q_data in enumerate(question_data[:5]):
    print(f"\n{i+1}. Question: {q_data['question']}")
    print(f"   Table name: {q_data.get('table_name', 'unknown')}")
    print(f"   Table shape: {q_data['table'].shape}")

## 6. Create Datasets and DataLoaders

In [None]:
print("Creating datasets...\n")

# Create full dataset
full_dataset = TableQuestionDataset(
    question_data=question_data,
    data_processor=DataProcessor(),
    pyg_converter=pyg_converter
)

# Train/val split (80/20)
total_size = len(full_dataset)
train_size = int(0.8 * total_size)
val_size = total_size - train_size

train_dataset, val_dataset = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)

print(f"Dataset Splits:")
print(f"  Train: {len(train_dataset)} pairs ({len(train_dataset)/total_size*100:.1f}%)")
print(f"  Val:   {len(val_dataset)} pairs ({len(val_dataset)/total_size*100:.1f}%)")
print(f"  Total: {total_size} pairs")
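`random_split` above splits at the pair level, so every validation table also appears in training (with different questions). That measures generalization to new questions about seen tables. If the goal is retrieval over tables the encoder has never seen, a table-level (grouped) split is the stricter test. The hypothetical helper below groups pairs by `table_name` (falling back to DataFrame identity, in which case the split is not reproducible across runs) and assigns whole tables to one side.

```python
import random
from collections import defaultdict

def split_by_table(question_data, val_frac=0.2, seed=42):
    """Split question-table pairs so no table appears in both splits (hypothetical helper)."""
    groups = defaultdict(list)
    for q in question_data:
        # Group by table_name when present, else by DataFrame identity
        groups[q.get("table_name") or id(q["table"])].append(q)
    table_keys = sorted(groups, key=str)   # fixed order before the seeded shuffle
    random.Random(seed).shuffle(table_keys)
    n_val = max(1, round(val_frac * len(table_keys)))
    val_keys = set(table_keys[:n_val])
    train = [q for k in table_keys if k not in val_keys for q in groups[k]]
    val = [q for k in table_keys if k in val_keys for q in groups[k]]
    return train, val

# Toy data: 10 tables with 3 questions each
demo = [{"table_name": f"t{i}", "table": None, "question": f"q{i}.{j}"}
        for i in range(10) for j in range(3)]
train_q, val_q = split_by_table(demo, val_frac=0.2)
print(len(train_q), len(val_q))  # 24 6
```

For example, `train_q, val_q = split_by_table(question_data)` could feed two separate `TableQuestionDataset` instances in place of `random_split`. The validation retrieval pool then contains only the held-out tables.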

In [None]:
# Create DataLoaders
BATCH_SIZE = 32

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,  # Colab compatibility
    collate_fn=collate_fn
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn
)

print(f"\nDataLoaders Created:")
print(f"  Train batches: {len(train_loader)} (batch_size={BATCH_SIZE})")
print(f"  Val batches:   {len(val_loader)} (batch_size={BATCH_SIZE})")
print(f"\n  In each full batch:")
print(f"    - {BATCH_SIZE} positive pairs (question[i] ↔ table[i])")
print(f"    - {BATCH_SIZE * (BATCH_SIZE - 1)} in-batch negatives (same-table pairs among them are false negatives)")
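With `shuffle=True`, each batch is a uniform sample without replacement, so questions about the same table regularly land in the same batch, and InfoNCE then treats those pairs as negatives even though they point at the same table. The arithmetic below estimates how often that happens, assuming (hypothetically) that all 50 tables load and each gets exactly 20 questions, i.e. 1000 pairs. The real training split has 800 pairs with uneven per-table counts, so treat the numbers as rough.

```python
from math import comb

def expected_false_negatives(batch_size, num_pairs, questions_per_table):
    """Expected count of ordered (i, j), i != j, same-table pairs in one batch.

    Assumes every table contributes `questions_per_table` pairs and batches are
    uniform samples without replacement.
    """
    p_same = (questions_per_table - 1) / (num_pairs - 1)
    return batch_size * (batch_size - 1) * p_same

def prob_batch_all_distinct(batch_size, num_tables, questions_per_table):
    """Probability that no two questions in a batch come from the same table."""
    n = num_tables * questions_per_table
    favorable = comb(num_tables, batch_size) * questions_per_table ** batch_size
    return favorable / comb(n, batch_size)

fn = expected_false_negatives(32, 1000, 20)
print(f"{fn:.1f} of {32 * 31} off-diagonal 'negatives' expected to share a table")  # 18.9 of 992 ...
print(f"P(batch has no duplicate table) = {prob_batch_all_distinct(32, 50, 20):.2e}")
```

Roughly 2% of the negatives are false, and almost every batch contains at least one such pair.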

## 7. Initialize Graph Encoder

In [None]:
print("\nInitializing graph encoder...\n")
print("=" * 70)

# UPDATED: Changed node_dim from 896 to 1280
# - 512-d statistical features
# - 768-d column name embeddings (from shared all-mpnet-base-v2)
# = 1280-d total node features

graph_encoder = ContrastiveGNNEncoder(
    node_dim=1280,       # UPDATED: Was 896, now 1280 (512 stats + 768 column names)
    hidden_dim=1280,     # Keep same to avoid bottleneck
    output_dim=768,      # Project to question space
    num_layers=2
)

# Count parameters
total_params = sum(p.numel() for p in graph_encoder.parameters())
trainable_params = sum(p.numel() for p in graph_encoder.parameters() if p.requires_grad)

print("✓ ContrastiveGNNEncoder: 2-layer GNN")
print(f"  Node dim: 1280 (512 statistical + 768 column name semantics)")
print(f"  Hidden dim: 1280 (preserves all features)")
print(f"  Output dim: 768 (matches question space)")
print(f"  Total params: {total_params:,}")
print(f"  Trainable params: {trainable_params:,}")

# Move to device
graph_encoder = graph_encoder.to(device)

print(f"\n✓ Model moved to {device}")

In [None]:
# Loss function and optimizer
loss_fn = InfoNCELoss(temperature=0.07)
print(f"✓ InfoNCELoss: temperature=0.07")

optimizer = torch.optim.AdamW(
    graph_encoder.parameters(),
    lr=5e-4,
    weight_decay=0.01
)
print(f"✓ AdamW optimizer: lr=5e-4, weight_decay=0.01")

# Cosine annealing, stepped once per batch; the 50 here must match CONFIG['num_epochs']
num_training_steps = len(train_loader) * 50
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_training_steps,
    eta_min=1e-6
)
print(f"✓ CosineAnnealingLR scheduler: T_max={num_training_steps}, eta_min=1e-6")
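For reference, a common one-directional (question-to-graph) form of InfoNCE is sketched below in NumPy: both sides are L2-normalized, cosine similarities are divided by the temperature (0.07 in this notebook), and each row is scored with cross-entropy against its diagonal positive. The optional `table_ids` argument is hypothetical; it masks same-table off-diagonal pairs so they are not counted as negatives. The actual `InfoNCELoss` in `contrastive_table2graph.py` may differ (for example, it may be symmetric), and the notebook's batches do not carry table ids, so this is a sketch rather than a drop-in replacement.

```python
import numpy as np

def info_nce(graph_emb, question_emb, temperature=0.07, table_ids=None):
    """One-directional (question -> graph) InfoNCE; row i of each input is a positive pair."""
    # L2-normalize so the dot product is cosine similarity
    g = graph_emb / np.linalg.norm(graph_emb, axis=1, keepdims=True)
    q = question_emb / np.linalg.norm(question_emb, axis=1, keepdims=True)
    logits = q @ g.T / temperature  # (B, B)
    if table_ids is not None:
        ids = np.asarray(table_ids)
        same_table = ids[:, None] == ids[None, :]
        np.fill_diagonal(same_table, False)              # keep each true positive
        logits = np.where(same_table, -np.inf, logits)   # drop false negatives
    # Numerically stable log-softmax over each row
    row_max = logits.max(axis=1, keepdims=True)
    log_norm = np.log(np.exp(logits - row_max).sum(axis=1, keepdims=True))
    log_probs = logits - row_max - log_norm
    # Cross-entropy against the diagonal (the matching graph)
    return float(-np.mean(np.diag(log_probs)))

g = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])  # rows 0 and 1: same table
print(info_nce(g, g, temperature=1.0))                       # duplicates act as negatives
print(info_nce(g, g, temperature=1.0, table_ids=[0, 0, 1]))  # duplicates masked
```

Masking lowers the loss here because the duplicate row no longer competes with the true positive in the softmax denominator.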

## 8. Training Configuration

In [None]:
CONFIG = {
    'num_epochs': 50,
    'gradient_clip': 1.0,
    'print_every': 1,
    'save_every': 10,
}

print("Training Configuration:")
print("=" * 70)
for k, v in CONFIG.items():
    print(f"  {k}: {v}")
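The scheduler cell builds `CosineAnnealingLR` with `T_max = len(train_loader) * 50` and steps it once per batch, so that 50 has to agree with `CONFIG['num_epochs']`. The closed-form schedule below is the formula from the PyTorch `CosineAnnealingLR` documentation, which the scheduler follows when it alone sets the learning rate; it makes the dependency explicit. `steps_per_epoch = 25` is a hypothetical value, corresponding to 800 training pairs at batch size 32.

```python
import math

def cosine_lr(step, base_lr=5e-4, eta_min=1e-6, t_max=1250):
    """Closed-form cosine annealing LR after `step` calls to scheduler.step()."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2

steps_per_epoch = 25          # hypothetical len(train_loader)
t_max = steps_per_epoch * 50  # T_max as built in the scheduler cell

for num_epochs in (30, 50, 70):
    final_lr = cosine_lr(num_epochs * steps_per_epoch, t_max=t_max)
    print(f"num_epochs={num_epochs}: final LR {final_lr:.2e}")
```

With 30 epochs, training stops while the LR is still well above `eta_min`; with 70, the curve passes its minimum at epoch 50 and climbs back up. Only 50 ends at `eta_min`.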

## 9. Training Utilities

In [None]:
def compute_recall_at_k(graph_embeddings, question_embeddings, k=1):
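    """Compute pair-level Recall@K: question i is a hit if graph i is in its top-k.

    When several pairs share a table, their graph embeddings are identical or
    nearly so and compete for the same ranks, so this underestimates
    table-level recall.
    """
    num_questions = question_embeddings.size(0)
    
    # Cosine similarity matrix (normalize is a no-op for unit-length embeddings)
    graph_embeddings = torch.nn.functional.normalize(graph_embeddings, dim=-1)
    question_embeddings = torch.nn.functional.normalize(question_embeddings, dim=-1)
    similarity = torch.matmul(question_embeddings, graph_embeddings.T)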
    
    # Top-k indices
    _, top_k_indices = torch.topk(similarity, k=min(k, similarity.size(1)), dim=1)
    
    # Check if correct graph in top-k
    correct_indices = torch.arange(num_questions, device=graph_embeddings.device).unsqueeze(1)
    
    # Check if correct index is in top-k predictions
    matches = (top_k_indices == correct_indices).any(dim=1)
    recall_at_k = matches.float().mean().item()
    
    return recall_at_k

def train_epoch(graph_encoder, question_encoder, loss_fn, optimizer, scheduler, train_loader, device, config):
    """Train for one epoch."""
    graph_encoder.train()
    question_encoder.eval()  # frozen encoder: keep dropout off
    epoch_loss = 0.0
    num_batches = 0
    
    pbar = tqdm(train_loader, desc="Training")
    
    for batch in pbar:
        batched_graphs = batch['graphs'].to(device)
        questions = batch['questions']
        
        # Forward pass
        graph_embeddings = graph_encoder(
            batched_graphs,
            batch=batched_graphs.batch
        )
        
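        # The frozen question encoder is not in the optimizer, so skip autograd for it
        with torch.no_grad():
            question_embeddings = question_encoder(questions)
        question_embeddings = question_embeddings.to(device)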
        
        # Compute loss
        loss = loss_fn(graph_embeddings, question_embeddings)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            graph_encoder.parameters(),
            max_norm=config['gradient_clip']
        )
        optimizer.step()
        scheduler.step()
        
        epoch_loss += loss.item()
        num_batches += 1
        
        # Show current LR in progress bar
        current_lr = optimizer.param_groups[0]['lr']
        pbar.set_postfix({'loss': f'{loss.item():.4f}', 'lr': f'{current_lr:.2e}'})
    
    avg_loss = epoch_loss / num_batches if num_batches > 0 else 0.0
    return avg_loss

@torch.no_grad()
def validate(graph_encoder, question_encoder, loss_fn, val_loader, device):
    """Validate on validation set."""
    graph_encoder.eval()
    question_encoder.eval()
    total_loss = 0.0
    num_batches = 0
    
    all_graph_embeddings = []
    all_question_embeddings = []
    
    for batch in val_loader:
        batched_graphs = batch['graphs'].to(device)
        questions = batch['questions']
        
        # Forward pass
        graph_embeddings = graph_encoder(
            batched_graphs,
            batch=batched_graphs.batch
        )
        
        question_embeddings = question_encoder(questions)
        question_embeddings = question_embeddings.to(device)
        
        # Compute loss
        loss = loss_fn(graph_embeddings, question_embeddings)
        total_loss += loss.item()
        num_batches += 1
        
        # Store for recall calculation
        all_graph_embeddings.append(graph_embeddings)
        all_question_embeddings.append(question_embeddings)
    
    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
    
    # Compute Recall@K
    graph_embs = torch.cat(all_graph_embeddings, dim=0)
    question_embs = torch.cat(all_question_embeddings, dim=0)
    
    recall_1 = compute_recall_at_k(graph_embs, question_embs, k=1)
    recall_5 = compute_recall_at_k(graph_embs, question_embs, k=5)
    
    return {
        'loss': avg_loss,
        'recall@1': recall_1,
        'recall@5': recall_5
    }

print("✓ Training utilities defined")
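`compute_recall_at_k` counts question i as correct only if graph i itself is retrieved. In the validation set each table appears several times (about 4 copies per table with 20 questions per table and an 80/20 split), and in eval mode those copies produce identical graph embeddings if `GraphBuilder` is deterministic. A question can therefore retrieve the right table through a different copy and still be scored as a miss. A table-level variant collapses duplicates and checks table identity instead. It needs per-pair table ids, which the notebook's batches (`'graphs'`, `'questions'`) do not expose, so `question_table_ids` and `graph_table_ids` below are hypothetical inputs.

```python
import numpy as np

def pair_level_recall_at_1(question_emb, graph_emb):
    """Analogue of the notebook's metric for k=1 (ties go to the first index here)."""
    sims = question_emb @ graph_emb.T
    return float((sims.argmax(axis=1) == np.arange(len(question_emb))).mean())

def table_level_recall_at_k(question_emb, graph_emb, question_table_ids, graph_table_ids, k=1):
    """Hit if any of the top-k candidate tables is the question's table."""
    graph_table_ids = np.asarray(graph_table_ids)
    # Keep the first embedding seen for each table as its single candidate
    uniq_ids, first_idx = np.unique(graph_table_ids, return_index=True)
    cand = graph_emb[first_idx]
    cand = cand / np.linalg.norm(cand, axis=1, keepdims=True)
    q = question_emb / np.linalg.norm(question_emb, axis=1, keepdims=True)
    sims = q @ cand.T                                    # (num_questions, num_tables)
    k = min(k, cand.shape[0])
    top_k = np.argsort(-sims, axis=1)[:, :k]             # indices into uniq_ids
    hits = (uniq_ids[top_k] == np.asarray(question_table_ids)[:, None]).any(axis=1)
    return float(hits.mean())

# Two tables, two validation questions each; duplicate graphs are identical
graphs = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
questions = np.array([[1.0, 0.1], [0.9, 0.0], [0.0, 1.0], [0.1, 1.0]])
ids = [0, 0, 1, 1]
print(pair_level_recall_at_1(questions, graphs))                  # 0.5
print(table_level_recall_at_k(questions, graphs, ids, ids, k=1))  # 1.0
```

Every question here retrieves the correct table, yet the pair-level metric reports only 50%.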

## 10. Training Loop

In [None]:
print("Starting training...\n")
print("=" * 70)

history = {
    'epoch': [],
    'train_loss': [],
    'val_loss': [],
    'val_recall@1': [],
    'val_recall@5': [],
    'learning_rate': [],
    'time': []
}

best_recall = -1.0  # below any real recall, so the first epoch always writes best_model.pt

for epoch in range(CONFIG['num_epochs']):
    epoch_start = time.time()
    
    # Train
    train_loss = train_epoch(
        graph_encoder, question_encoder, loss_fn, optimizer, scheduler,
        train_loader, device, CONFIG
    )
    
    # Validate
    val_metrics = validate(
        graph_encoder, question_encoder, loss_fn, val_loader, device
    )
    
    epoch_time = time.time() - epoch_start
    current_lr = optimizer.param_groups[0]['lr']
    
    # Record history
    history['epoch'].append(epoch + 1)
    history['train_loss'].append(float(train_loss))
    history['val_loss'].append(float(val_metrics['loss']))
    history['val_recall@1'].append(float(val_metrics['recall@1']))
    history['val_recall@5'].append(float(val_metrics['recall@5']))
    history['learning_rate'].append(float(current_lr))
    history['time'].append(float(epoch_time))
    
    # Print progress
    if (epoch + 1) % CONFIG['print_every'] == 0:
        print(f"\nEpoch {epoch+1}/{CONFIG['num_epochs']}")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Loss:   {val_metrics['loss']:.4f}")
        print(f"  Recall@1:   {val_metrics['recall@1']:.3f}")
        print(f"  Recall@5:   {val_metrics['recall@5']:.3f}")
        print(f"  LR:         {current_lr:.2e}")
        print(f"  Time:       {epoch_time:.2f}s")
    
    # Save best model
    if val_metrics['recall@1'] > best_recall:
        best_recall = val_metrics['recall@1']
        torch.save({
            'epoch': epoch + 1,
            'graph_encoder_state': graph_encoder.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'scheduler_state': scheduler.state_dict(),
            'metrics': val_metrics,
        }, f"{CHECKPOINT_DIR}/best_model.pt")
        print(f"  ✓ New best model saved (Recall@1: {best_recall:.3f})")
    
    # Periodic checkpoint
    if (epoch + 1) % CONFIG['save_every'] == 0:
        torch.save({
            'epoch': epoch + 1,
            'graph_encoder_state': graph_encoder.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'scheduler_state': scheduler.state_dict(),
            'metrics': val_metrics,
        }, f"{CHECKPOINT_DIR}/checkpoint_epoch_{epoch+1}.pt")
        print(f"  ✓ Checkpoint saved")

print("\n" + "=" * 70)
print(f"Training Complete!")
print(f"Best Recall@1: {best_recall:.3f}")
print("=" * 70)

## 11. Plot Training Curves

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Loss curves
axes[0].plot(history['epoch'], history['train_loss'], 'b-', label='Train Loss', linewidth=2)
axes[0].plot(history['epoch'], history['val_loss'], 'r-', label='Val Loss', linewidth=2)
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('InfoNCE Loss', fontsize=12)
axes[0].set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
axes[0].legend(fontsize=10)
axes[0].grid(True, alpha=0.3)

# Recall curves
axes[1].plot(history['epoch'], history['val_recall@1'], 'g-', label='Recall@1', linewidth=2)
axes[1].plot(history['epoch'], history['val_recall@5'], 'm-', label='Recall@5', linewidth=2)
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Recall', fontsize=12)
axes[1].set_title('Validation Recall Metrics', fontsize=14, fontweight='bold')
axes[1].legend(fontsize=10)
axes[1].grid(True, alpha=0.3)
axes[1].set_ylim([0, 1])

plt.tight_layout()
plt.savefig(f'{CHECKPOINT_DIR}/training_curves.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"✓ Training curves saved to {CHECKPOINT_DIR}/training_curves.png")

## 12. Test Table Retrieval

In [None]:
# Load best model
checkpoint = torch.load(f"{CHECKPOINT_DIR}/best_model.pt", map_location=device)
graph_encoder.load_state_dict(checkpoint['graph_encoder_state'])
graph_encoder.eval()

print(f"✓ Loaded best model (Epoch {checkpoint['epoch']})")
print(f"  Recall@1: {checkpoint['metrics']['recall@1']:.3f}")
print(f"  Recall@5: {checkpoint['metrics']['recall@5']:.3f}")

In [None]:
# Create embeddings for all tables
print("\nGenerating embeddings for all tables...")

table_embeddings_list = []
table_names_list = []

with torch.no_grad():
    for table_name, df in tqdm(tables, desc="Encoding tables"):
        # UPDATED: Use build_graph instead of convert_table
        pyg_data = pyg_converter.build_graph(df)
        pyg_data = pyg_data.to(device)
        
        # Encode
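        graph_emb = graph_encoder(pyg_data, batch=None)
        # One row per table whether the encoder returns shape [D] or [1, D]
        graph_emb = graph_emb.reshape(1, -1)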
        
        table_embeddings_list.append(graph_emb.cpu())
        table_names_list.append(table_name)

table_embeddings = torch.cat(table_embeddings_list, dim=0)
print(f"✓ Encoded {len(tables)} tables: {table_embeddings.shape}")

In [None]:
def retrieve_tables(query_question, top_k=5):
    """
    Given a question, retrieve top-k most relevant tables.
    
    Args:
        query_question: Natural language question
        top_k: Number of tables to retrieve
    
    Returns:
        List of dicts with keys 'table_name', 'similarity', 'table_shape'
    """
    # Encode question
    with torch.no_grad():
        question_emb = question_encoder([query_question])
        question_emb = question_emb.to(device)
    
    # Cosine similarities (normalize is a no-op for unit-length embeddings)
    table_embs_device = torch.nn.functional.normalize(table_embeddings.to(device), dim=-1)
    question_emb = torch.nn.functional.normalize(question_emb, dim=-1)
    # squeeze(0) keeps a 1-D score vector even when only one table is indexed
    similarities = torch.matmul(question_emb, table_embs_device.T).squeeze(0)
    
    # Get top-k
    top_k_scores, top_k_indices = torch.topk(similarities, k=min(top_k, len(tables)))
    
    results = []
    for score, idx in zip(top_k_scores.tolist(), top_k_indices.tolist()):
        results.append({
            'table_name': table_names_list[idx],
            'similarity': score,
            'table_shape': tables[idx][1].shape
        })
    
    return results

print("✓ Retrieval function ready")
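Scoring tables with a plain matrix product equals cosine similarity only if the encoders emit unit-length vectors, which the notebook does not show. If they do not, a table whose embedding simply has a large norm can outrank a better-aligned one. The toy NumPy example below (hand-picked vectors; `rank_tables` is a hypothetical helper) shows the ranking flipping once both sides are normalized.

```python
import numpy as np

def rank_tables(query, table_embs, normalize=True):
    """Return table indices sorted from most to least similar to `query`."""
    if normalize:
        query = query / np.linalg.norm(query)
        table_embs = table_embs / np.linalg.norm(table_embs, axis=1, keepdims=True)
    scores = table_embs @ query
    return np.argsort(-scores)

query = np.array([1.0, 0.0])
tables_demo = np.array([
    [0.9, 0.1],   # well aligned with the query, small norm
    [3.0, 3.0],   # 45 degrees away, large norm
])
print(rank_tables(query, tables_demo, normalize=False))  # [1 0]: norm wins
print(rank_tables(query, tables_demo, normalize=True))   # [0 1]: direction wins
```

If the encoder already normalizes its outputs, the extra normalization changes nothing, so it is a cheap safeguard.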

In [None]:
# Test retrieval with column-specific questions
test_questions = [
    "Which table contains patient_id and admission_date columns?",
    "Which table has foreign key hadm_id that references subject_id?",
    "Which table tracks time intervals between admittime and dischtime?",
    "Which table pairs measurement valuenum with dimension itemid?",
    "Which table has temporal columns charttime and storetime?",
    "Which table uses icustay_id as an identifier?"
]

for query in test_questions:
    print("\n" + "=" * 70)
    print(f"Query: {query}")
    print("=" * 70)
    
    results = retrieve_tables(query, top_k=5)
    
    print("\nTop 5 Retrieved Tables:")
    for i, result in enumerate(results):
        print(f"{i+1}. {result['table_name']}")
        print(f"   Similarity: {result['similarity']:.4f}")
        print(f"   Shape: {result['table_shape']}")
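Because every test question names specific columns, a crude keyword baseline can sanity-check the learned retriever: rank tables by how many of their column names appear verbatim in the question. This is not part of the pipeline. `lexical_matches` and the toy tables (modeled on the MIMIC-style names in the test questions) are hypothetical. If the GNN's top result regularly disagrees with an exact column match, that may be worth investigating.

```python
import re
import pandas as pd

def lexical_matches(query, tables):
    """Rank (table_name, DataFrame) pairs by column names that occur verbatim in `query`."""
    # Identifier-like tokens, so "hadm_id" stays one token and a trailing "?" is dropped
    tokens = set(re.findall(r"[a-z_][a-z0-9_]*", query.lower()))
    scored = []
    for name, df in tables:
        hits = sorted(c for c in map(str, df.columns) if c.lower() in tokens)
        if hits:
            scored.append((name, len(hits), hits))
    # Most matched columns first; ties broken alphabetically by table name
    return sorted(scored, key=lambda t: (-t[1], t[0]))

toy_tables = [
    ("admissions", pd.DataFrame(columns=["subject_id", "hadm_id", "admittime", "dischtime"])),
    ("chartevents", pd.DataFrame(columns=["subject_id", "itemid", "charttime", "storetime", "valuenum"])),
]
print(lexical_matches("Which table has temporal columns charttime and storetime?", toy_tables))
# [('chartevents', 2, ['charttime', 'storetime'])]
```

Applied to the notebook's data, `lexical_matches(query, tables)` can be printed next to `retrieve_tables(query)` for each test question.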

## 13. Export Results

In [None]:
# Save training history
with open(f'{CHECKPOINT_DIR}/training_history.json', 'w') as f:
    json.dump(history, f, indent=2)

print(f"✓ Training history saved")

# Save final metrics
final_metrics = {
    'best_recall@1': best_recall,
    'final_epoch': CONFIG['num_epochs'],
    'num_tables': len(tables),
    'num_questions': len(question_data),
    'train_size': len(train_dataset),
    'val_size': len(val_dataset),
    'model_params': total_params,
    'node_dim': 1280,
    'uses_shared_encoder': True,
}

with open(f'{CHECKPOINT_DIR}/final_metrics.json', 'w') as f:
    json.dump(final_metrics, f, indent=2)

print(f"✓ Final metrics saved")
print(f"\n✓ All results exported to: {CHECKPOINT_DIR}")

## 14. Summary

### Training Complete!

**Model**: Contrastive GNN for Table-Question Alignment (UPDATED VERSION)
- Architecture: 2-layer GNN + Attention Pooling + Projection Head
- Loss: InfoNCE with in-batch negatives
- Question Encoder: all-mpnet-base-v2 (frozen, **shared with column name encoder**)
- Graph Encoder: Trainable
- Node Features: **1280-d** (512 stats + 768 column names from shared encoder)

**Key Improvements**:
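1. ✅ **Shared encoder**: Column names and questions are embedded by the SAME encoder, so column-name features start out in the question embedding space
2. ✅ **Table-specific questions**: Each question mentions specific column names (e.g., "hadm_id", "subject_id")
3. ✅ **Relationship-aware questions**: 25% of questions (5 of 20 per table) target inter-column relationships that GNN message passing can capture
4. ✅ **No hidden-layer bottleneck**: hidden_dim=1280 matches the 1280-d input, so the GNN layers do not compress node features before the 768-d projection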

**Saved Artifacts**:
- `best_model.pt` - Best model checkpoint
- `checkpoint_epoch_*.pt` - Periodic checkpoints
- `training_curves.png` - Loss and recall curves
- `training_history.json` - Complete training history
- `final_metrics.json` - Summary statistics

**Expected Performance**:
- **Old approach**: Recall@1 < 5% (near chance level)
- **New approach**: Recall@1 = 60-80% within 10-20 epochs

**Why it works better**:
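- The column name "hadm_id" and questions that mention "hadm_id" are embedded by the same model, so they start out close in embedding space (not identical: a question embedding encodes the whole sentence)
- The graph encoder mostly needs to learn which columns to attend to, rather than a mapping between two unrelated embedding spaces
- Each question names columns from its own table, which reduces ambiguous supervision (tables that share column names can still compete)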