# sklearn TF-IDF vs SPLADE Classifier Comparison

This notebook compares two text classification approaches on the AG News dataset:

1. **sklearn TF-IDF + LogisticRegression** - Traditional sparse bag-of-words approach
2. **SPLADE Neural Classifier** - Neural sparse representations with interpretability

We evaluate both on:
- Accuracy and F1 scores
- Training and inference time
- Sparsity of representations
- Interpretability

## 1. Setup & Imports

In [None]:
import time
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from datasets import load_dataset
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix

from src.models import SPLADEClassifier

# Set random seeds for reproducibility
import torch
np.random.seed(42)
torch.manual_seed(42)

# Display settings
plt.rcParams['figure.facecolor'] = 'white'
plt.rcParams['axes.grid'] = True
plt.rcParams['grid.alpha'] = 0.3
%matplotlib inline

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Load AG News Dataset

AG News is a 4-class news topic classification dataset with 120,000 training samples.

**Classes:**
- 0: World
- 1: Sports
- 2: Business
- 3: Sci/Tech

In [None]:
# Fetch the AG News dataset through the HuggingFace `datasets` loader
print("Loading AG News dataset...")
dataset = load_dataset("ag_news")

# Print a heading followed by the dataset object's own summary
print("\nDataset structure:", dataset, sep="\n")

In [None]:
# Human-readable names for label ids 0-3 (same order as the class list above)
CLASS_NAMES = ["World", "Sports", "Business", "Sci/Tech"]

# Materialise the text and label columns of each split as plain Python lists
train_split = dataset["train"]
test_split = dataset["test"]
train_texts, train_labels = list(train_split["text"]), list(train_split["label"])
test_texts, test_labels = list(test_split["text"]), list(test_split["label"])

# Quick size summary
n_train, n_test, n_classes = len(train_texts), len(test_texts), len(CLASS_NAMES)
print(f"Training samples: {n_train:,}")
print(f"Test samples: {n_test:,}")
print(f"Number of classes: {n_classes}")

## 3. Exploratory Data Analysis

In [None]:
# Per-class sample counts, ordered by label id
train_label_counts = pd.Series(train_labels).value_counts().sort_index()

# Word counts for the first 5,000 training texts only (kept small for speed)
text_lengths = [len(doc.split()) for doc in train_texts[:5000]]
median_len = np.median(text_lengths)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
ax_classes, ax_lengths = axes

# Left panel: one bar per class, annotated with its count just above the bar
class_palette = ['#e74c3c', '#3498db', '#2ecc71', '#9b59b6']
ax_classes.bar(CLASS_NAMES, train_label_counts.values, color=class_palette)
ax_classes.set_ylabel('Number of samples')
ax_classes.set_title('Training Set Class Distribution')
for bar_pos, count in enumerate(train_label_counts.values):
    ax_classes.text(bar_pos, count + 500, f'{count:,}', ha='center')

# Right panel: histogram of word counts with the median marked
ax_lengths.hist(text_lengths, bins=50, color='steelblue', edgecolor='white')
ax_lengths.set_xlabel('Number of words')
ax_lengths.set_ylabel('Frequency')
ax_lengths.set_title('Text Length Distribution (sample)')
ax_lengths.axvline(median_len, color='red', linestyle='--', label=f'Median: {median_len:.0f}')
ax_lengths.legend()

plt.tight_layout()
plt.show()

print("\nText length statistics (words):")
print(f"  Min: {np.min(text_lengths)}, Max: {np.max(text_lengths)}, Median: {median_len:.0f}")

In [None]:
# Show one training example per class (the first one encountered for each label)
print("Sample examples from each class:\n")
for class_idx, class_name in enumerate(CLASS_NAMES):
    # Position of the first training example carrying this label, or None
    first_pos = next(
        (pos for pos, lab in enumerate(train_labels) if lab == class_idx),
        None,
    )
    if first_pos is None:
        # No example of this class: print nothing for it
        continue
    print(f"=== {class_name} ===")
    # Only the first 300 characters are shown, always followed by an ellipsis
    print(train_texts[first_pos][:300] + "...")
    print()

## 4. sklearn TF-IDF + LogisticRegression Baseline

In [None]:
print("Training sklearn TF-IDF + LogisticRegression...\n")

# TF-IDF vectorizer over unigrams and bigrams, capped at 30k features
vectorizer = TfidfVectorizer(
    max_features=30000,  # Match SPLADE vocab size roughly
    ngram_range=(1, 2),
    min_df=2,            # ignore terms that occur in fewer than 2 documents
    max_df=0.95,         # ignore terms that occur in more than 95% of documents
    sublinear_tf=True    # use 1 + log(tf) instead of the raw term frequency
)

# Time only the fit on the training split. `tfidf_time` is added to the sklearn
# *training* time further down, so it must not include any test-set work.
tfidf_start = time.time()
X_train_tfidf = vectorizer.fit_transform(train_texts)
tfidf_time = time.time() - tfidf_start

# Vectorizing the test split is inference-side work; keep it outside the timer.
X_test_tfidf = vectorizer.transform(test_texts)

# Share of zero entries in the training matrix, in percent
n_train_docs, n_features = X_train_tfidf.shape
train_sparsity_pct = (1 - X_train_tfidf.nnz / (n_train_docs * n_features)) * 100

print(f"TF-IDF vectorization: {tfidf_time:.2f}s")
print(f"TF-IDF vocabulary size: {len(vectorizer.vocabulary_):,}")
print(f"TF-IDF matrix shape: {X_train_tfidf.shape}")
print(f"TF-IDF sparsity: {train_sparsity_pct:.2f}%")

In [None]:
# Train Logistic Regression on the TF-IDF features
lr_start = time.time()

# `multi_class` is deliberately not passed: the parameter is deprecated in
# recent scikit-learn releases, and with the lbfgs solver a multi-class target
# is already fitted as a single multinomial (softmax) model by default
# (the 'auto' default, scikit-learn >= 0.22).
lr_clf = LogisticRegression(
    max_iter=1000,
    solver='lbfgs',
    n_jobs=-1,
    random_state=42
)

lr_clf.fit(X_train_tfidf, train_labels)

# Total training cost = TF-IDF vectorization time + classifier fit time
sklearn_train_time = tfidf_time + (time.time() - lr_start)
print(f"\nTotal sklearn training time: {sklearn_train_time:.2f}s")

In [None]:
# Evaluate sklearn model
# The timer covers the full raw-text -> prediction path (TF-IDF transform plus
# classifier predict). The SPLADE inference timer wraps `predict(test_texts)`
# on raw text, so timing only `lr_clf.predict` on a pre-computed matrix would
# understate sklearn's inference cost in the comparison.
sklearn_inference_start = time.time()
X_test_timed = vectorizer.transform(test_texts)
sklearn_preds = lr_clf.predict(X_test_timed)
sklearn_inference_time = time.time() - sklearn_inference_start

sklearn_accuracy = accuracy_score(test_labels, sklearn_preds)
sklearn_f1 = f1_score(test_labels, sklearn_preds, average='macro')

print("sklearn TF-IDF + LogisticRegression Results:")
print(f"  Accuracy: {sklearn_accuracy:.4f}")
print(f"  F1 (macro): {sklearn_f1:.4f}")
print(f"  Inference time: {sklearn_inference_time:.2f}s")
print()
print(classification_report(test_labels, sklearn_preds, target_names=CLASS_NAMES))

## 5. SPLADE Neural Classifier

In [None]:
# Build the SPLADE classifier for the 4 AG News classes
print("Initializing SPLADE classifier...")

# Constructor arguments gathered in one place so they are easy to tweak.
# NOTE(review): parameter semantics are defined by src.models.SPLADEClassifier,
# which is not shown here — confirm there (e.g. what `flops_lambda` weights).
splade_config = dict(
    model_name="distilbert-base-uncased",
    max_length=128,
    batch_size=32,
    learning_rate=2e-5,
    flops_lambda=1e-4,
    num_labels=4,
    class_names=CLASS_NAMES,
    verbose=True,
)
splade_clf = SPLADEClassifier(**splade_config)

print(f"Device: {splade_clf.device}")

In [None]:
# Train SPLADE model
# Note: For faster demo, you can use a subset of the data
# For full comparison, use all training data (takes ~1-2 hours on GPU)
#
# Caveat: the sklearn baseline above is fitted on the full training set, so
# with a subset here the accuracy/time numbers are not a like-for-like
# comparison.

# Subsample for faster training (optional - set to None for full training)
TRAIN_SUBSET = 10000  # Use None for full dataset
SUBSET_SEED = 42      # fixed seed for the subset draw

# Only subsample when the requested size is smaller than the dataset;
# `choice(..., replace=False)` raises if asked for more items than exist.
if TRAIN_SUBSET and TRAIN_SUBSET < len(train_texts):
    # A dedicated, locally seeded generator makes this cell idempotent:
    # re-running it selects the same subset regardless of how much of the
    # global NumPy random state other cells have consumed.
    subset_rng = np.random.default_rng(SUBSET_SEED)
    indices = subset_rng.choice(len(train_texts), TRAIN_SUBSET, replace=False)
    splade_train_texts = [train_texts[i] for i in indices]
    splade_train_labels = [train_labels[i] for i in indices]
    print(f"Training on subset: {len(splade_train_texts):,} samples")
else:
    splade_train_texts = train_texts
    splade_train_labels = train_labels
    print(f"Training on full dataset: {len(splade_train_texts):,} samples")

# Wall-clock time of the whole fit call
splade_train_start = time.time()

splade_clf.fit(
    splade_train_texts,
    splade_train_labels,
    epochs=3
)

splade_train_time = time.time() - splade_train_start
print(f"\nSPLADE training time: {splade_train_time:.2f}s")

In [None]:
# Score the trained SPLADE classifier on the held-out test split
print("Evaluating SPLADE on test set...")

# Wall-clock time of predicting every test text from raw strings
splade_inference_start = time.time()
splade_preds = splade_clf.predict(test_texts)
splade_inference_time = time.time() - splade_inference_start

splade_accuracy = accuracy_score(test_labels, splade_preds)
splade_f1 = f1_score(test_labels, splade_preds, average='macro')

# Headline metrics, followed by a blank line and the per-class report
summary_lines = [
    "\nSPLADE Classifier Results:",
    f"  Accuracy: {splade_accuracy:.4f}",
    f"  F1 (macro): {splade_f1:.4f}",
    f"  Inference time: {splade_inference_time:.2f}s",
    "",
]
print("\n".join(summary_lines))
print(classification_report(test_labels, splade_preds, target_names=CLASS_NAMES))

In [None]:
# Estimate SPLADE sparsity from the first 100 test texts
# NOTE(review): `get_sparsity` is assumed to return a percentage, since the
# value is printed with a trailing '%' — confirm in src.models.
print("Computing SPLADE sparsity on sample...")
sample_texts = test_texts[0:100]
splade_sparsity = splade_clf.get_sparsity(sample_texts)
print("SPLADE vector sparsity: {:.2f}%".format(splade_sparsity))

## 6. Comparison & Visualization

In [None]:
# Side-by-side summary table of the headline metrics

# Share of zero entries in the TF-IDF test matrix, in percent
tfidf_rows, tfidf_cols = X_test_tfidf.shape
tfidf_test_sparsity = (1 - X_test_tfidf.nnz / (tfidf_rows * tfidf_cols)) * 100

# One format spec per metric row: 4 decimals for scores, 2 for times/sparsity
metric_names = ['Accuracy', 'F1 (macro)', 'Training Time (s)', 'Inference Time (s)', 'Sparsity (%)']
metric_formats = ["{:.4f}", "{:.4f}", "{:.2f}", "{:.2f}", "{:.2f}"]
sklearn_values = [sklearn_accuracy, sklearn_f1, sklearn_train_time, sklearn_inference_time, tfidf_test_sparsity]
splade_values = [splade_accuracy, splade_f1, splade_train_time, splade_inference_time, splade_sparsity]

comparison_data = {
    'Metric': metric_names,
    'sklearn TF-IDF': [fmt.format(val) for fmt, val in zip(metric_formats, sklearn_values)],
    'SPLADE': [fmt.format(val) for fmt, val in zip(metric_formats, splade_values)],
}

comparison_df = pd.DataFrame(comparison_data)

# Plain-text rendering framed by horizontal rules
rule = "=" * 60
print("\n" + rule)
print("COMPARISON SUMMARY")
print(rule)
print(comparison_df.to_string(index=False))
print(rule)

In [None]:
# Three-panel bar chart: accuracy, macro-F1 and training time for both models
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

models = ['sklearn\nTF-IDF', 'SPLADE']
colors = ['#3498db', '#e74c3c']
accuracies = [sklearn_accuracy, splade_accuracy]
f1_scores = [sklearn_f1, splade_f1]
train_times = [sklearn_train_time, splade_train_time]

# The two score panels share the same layout, so draw them in one loop
score_panels = [
    (axes[0], accuracies, 'Accuracy', 'Accuracy Comparison'),
    (axes[1], f1_scores, 'F1 Score (macro)', 'F1 Score Comparison'),
]
for score_ax, scores, y_label, panel_title in score_panels:
    score_ax.bar(models, scores, color=colors)
    score_ax.set_ylim(0.7, 1.0)  # y-axis is zoomed to the 0.7-1.0 range
    score_ax.set_ylabel(y_label)
    score_ax.set_title(panel_title)
    for bar_pos, score in enumerate(scores):
        score_ax.text(bar_pos, score + 0.01, f'{score:.4f}', ha='center', fontweight='bold')

# Training time panel; labels sit 2% of the tallest bar above each bar
time_ax = axes[2]
time_ax.bar(models, train_times, color=colors)
time_ax.set_ylabel('Time (seconds)')
time_ax.set_title('Training Time Comparison')
for bar_pos, seconds in enumerate(train_times):
    time_ax.text(bar_pos, seconds + max(train_times)*0.02, f'{seconds:.1f}s', ha='center', fontweight='bold')

plt.tight_layout()
plt.show()

In [None]:
# Confusion matrices for both models, drawn side by side
sklearn_cm = confusion_matrix(test_labels, sklearn_preds)
splade_cm = confusion_matrix(test_labels, splade_preds)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# (matrix, colour map, title) for each panel, left to right
cm_panels = [
    (sklearn_cm, 'Blues', 'sklearn TF-IDF Confusion Matrix'),
    (splade_cm, 'Reds', 'SPLADE Confusion Matrix'),
]
for cm_ax, (matrix, cmap_name, panel_title) in zip(axes, cm_panels):
    sns.heatmap(
        matrix,
        annot=True,
        fmt='d',
        cmap=cmap_name,
        xticklabels=CLASS_NAMES,
        yticklabels=CLASS_NAMES,
        ax=cm_ax,
    )
    cm_ax.set_title(panel_title)
    cm_ax.set_xlabel('Predicted')
    cm_ax.set_ylabel('True')

plt.tight_layout()
plt.show()

## 7. Interpretability Demo

One of the key advantages of SPLADE is its inherent interpretability. Each dimension in the sparse vector corresponds to a vocabulary token, making it easy to understand *why* the model made a prediction.

In [None]:
# Hand-written example sentences, one per news topic, for the explanation demo
example_texts = [
    "Apple stock surged 5% after announcing record iPhone sales and strong quarterly earnings.",
    "The Lakers defeated the Celtics 112-98 in an exciting NBA playoff game last night.",
    "Scientists discovered a new exoplanet that could potentially support life in a distant solar system.",
    "Political tensions rise as world leaders meet at the United Nations summit to discuss climate change."
]

# Topic each sentence was written to represent (same order as above)
expected_classes = ["Business", "Sports", "Sci/Tech", "World"]

print("SPLADE Interpretability Demo\n")
print("=" * 60)

# Print the classifier's explanation (top 10 entries) for each example in turn
for example_no, (text, expected) in enumerate(zip(example_texts, expected_classes), start=1):
    print(f"\nExample {example_no} (Expected: {expected})")
    splade_clf.print_explanation(text, top_k=10)
    print("\n" + "-" * 60)

In [None]:
# Compare with sklearn TF-IDF weights for interpretation
# Note: these are the TF-IDF weights of the input text, not the classifier's
# learned coefficients.
print("TF-IDF Top Terms (for comparison)\n")

# Handle both old and new sklearn API
try:
    feature_names = vectorizer.get_feature_names_out()
except AttributeError:
    feature_names = vectorizer.get_feature_names()

TOP_TERMS = 10  # number of highest-weighted terms to list per example

for i, text in enumerate(example_texts[:2]):  # Just first 2 examples
    print(f"Example {i+1}: {text[:80]}...")

    # Single-row sparse TF-IDF vector for this text. In CSR form, `.indices`
    # and `.data` hold the column ids and weights of the stored entries, so
    # the non-zeros can be ranked without densifying the full vocabulary row.
    tfidf_vec = vectorizer.transform([text]).tocsr()
    term_ids, term_weights = tfidf_vec.indices, tfidf_vec.data

    # Positions of the stored entries, heaviest first, limited to TOP_TERMS.
    # A distinct name is used so the training-subset `indices` array from the
    # SPLADE training cell is not overwritten.
    top_positions = np.argsort(term_weights)[::-1][:TOP_TERMS]

    print("Top TF-IDF terms:")
    for pos in top_positions:
        weight = term_weights[pos]
        if weight > 0:  # skip any explicitly stored zeros
            print(f"  {feature_names[term_ids[pos]]:<20} {weight:.3f}")
    print()

## 8. Sparsity Analysis

Both methods produce sparse representations, but with different characteristics.

In [None]:
# Compare per-document non-zero counts of SPLADE and TF-IDF on the same sample
sample_size = 100
sample_texts_analysis = test_texts[:sample_size]

# SPLADE vectors: count non-zero dimensions per document.
# The counts are detached and moved to the CPU because `.numpy()` (used for
# the histogram below) fails on CUDA tensors and on tensors that track
# gradients; both calls are no-ops for a plain CPU tensor.
splade_vectors = splade_clf.transform(sample_texts_analysis)
splade_nnz = (splade_vectors != 0).sum(dim=1).float().detach().cpu()

# TF-IDF vectors: in CSR form, consecutive `indptr` differences give the
# number of stored entries per row, with no Python-level loop and no
# dependence on `sample_size` matching the actual number of rows.
tfidf_vectors = vectorizer.transform(sample_texts_analysis)
tfidf_nnz = np.diff(tfidf_vectors.tocsr().indptr)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Non-zero elements histogram
axes[0].hist(tfidf_nnz, bins=30, alpha=0.7, label='TF-IDF', color='#3498db')
axes[0].hist(splade_nnz.numpy(), bins=30, alpha=0.7, label='SPLADE', color='#e74c3c')
axes[0].set_xlabel('Number of non-zero elements')
axes[0].set_ylabel('Frequency')
axes[0].set_title('Non-zero Elements per Document')
axes[0].legend()

# Sparsity comparison: percentage of zero dimensions, averaged over the sample
tfidf_total = tfidf_vectors.shape[1]
splade_total = splade_vectors.shape[1]

tfidf_sparsity = (1 - tfidf_nnz.mean() / tfidf_total) * 100
splade_sparsity_val = (1 - splade_nnz.mean().item() / splade_total) * 100

axes[1].bar(['TF-IDF', 'SPLADE'], [tfidf_sparsity, splade_sparsity_val],
            color=['#3498db', '#e74c3c'])
axes[1].set_ylabel('Sparsity (%)')
axes[1].set_title('Vector Sparsity Comparison')
axes[1].set_ylim(90, 100)
for i, v in enumerate([tfidf_sparsity, splade_sparsity_val]):
    # Labels go just inside the top of each bar: with values close to 100%,
    # a label placed above the bar would fall outside the 90-100 axis range.
    axes[1].text(i, v - 0.3, f'{v:.2f}%', ha='center', va='top',
                 fontweight='bold', color='white')

plt.tight_layout()
plt.show()

print(f"\nAverage non-zero elements:")
print(f"  TF-IDF: {tfidf_nnz.mean():.1f} / {tfidf_total} ({tfidf_sparsity:.2f}% sparse)")
print(f"  SPLADE: {splade_nnz.mean().item():.1f} / {splade_total} ({splade_sparsity_val:.2f}% sparse)")

## 9. Summary & Conclusions

### Key Findings:

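1. **Accuracy**: Both methods achieve strong performance on AG News (see Section 6 for the measured accuracy and macro-F1). Note that with the default `TRAIN_SUBSET = 10000`, SPLADE is trained on a 10,000-sample subset while the sklearn baseline uses the full training set, so the scores and timings are not a like-for-like comparison.
2. **Training Time**: sklearn is faster (CPU-based, no backprop); SPLADE fine-tunes a transformer, so a GPU is strongly recommended
3. **Sparsity**: Both methods produce highly sparse vectors (a useful property for retrieval); compare the measured non-zero counts in Section 8 before concluding which is sparser
4. **Interpretability**: SPLADE provides semantic term weights, TF-IDF provides lexical weights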

### When to use each:

- **sklearn TF-IDF**: Quick prototyping, limited compute, large-scale batch processing
- **SPLADE**: When interpretability matters, semantic matching needed, or for retrieval tasks

In [None]:
# Save the SPLADE model
model_path = "../models/splade_ag_news.pt"

# Create the target directory first so the save does not depend on
# ../models already existing (nothing else in this notebook creates it).
Path(model_path).parent.mkdir(parents=True, exist_ok=True)

splade_clf.save(model_path)
print(f"Model saved to {model_path}")