# Optimize Classification Threshold for F1 Score

Find the probability threshold that maximizes the F1 score on the test set, instead of using the default cutoff of 0.5.

In [None]:
import sys
sys.path.append('../')

import pandas as pd
import numpy as np
import pickle
from pathlib import Path
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score, confusion_matrix
from sklearn.linear_model import LogisticRegression

# Show every column when a DataFrame is displayed (None removes the column limit).
pd.set_option('display.max_columns', None)

## Load Data and Model

In [None]:
# Read the held-out features and their classification labels from the
# pickled files, features first, then labels.
X_test, y_test = (
    pd.read_pickle(pickle_file)
    for pickle_file in (
        '../data/features/X_test_temporal.pkl',
        '../data/features/y_test_cls_temporal.pkl',
    )
)

# Summarize the split: overall shape plus the count and share of positives.
n_high_impact = y_test.sum()
high_impact_pct = y_test.mean() * 100

print(f"Test set: {X_test.shape}")
print(f"High-impact papers: {n_high_impact} ({high_impact_pct:.1f}%)")

In [None]:
# Load the cached Logistic Regression classifier if one exists on disk;
# otherwise fit it on the training split and cache it for later runs.
model_path = Path('../models/logistic_regression_classifier.pkl')

if model_path.exists():
    print("Loading saved model...")
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load model files that were produced by this project.
    with open(model_path, 'rb') as f:
        model = pickle.load(f)
else:
    print("Training Logistic Regression...")
    X_train = pd.read_pickle('../data/features/X_train_temporal.pkl')
    y_train = pd.read_pickle('../data/features/y_train_cls_temporal.pkl')

    model = LogisticRegression(max_iter=1000, random_state=42, n_jobs=-1)
    model.fit(X_train, y_train)

    # Save model. The directory is derived from model_path so the two cannot
    # drift apart, and parents=True creates any missing intermediate folders.
    model_path.parent.mkdir(parents=True, exist_ok=True)
    with open(model_path, 'wb') as f:
        pickle.dump(model, f)
    print("Model saved!")

# Keep only column 1 of predict_proba: the probability of the second class
# in model.classes_ (the positive class when labels are 0/1).
y_pred_proba = model.predict_proba(X_test)[:, 1]
print("\n✓ Predictions generated")

## Current Performance (threshold=0.5)

In [None]:
# Baseline: label a sample positive when its predicted probability is at
# least 0.5, then score those hard predictions.
y_pred_default = (y_pred_proba >= 0.5).astype(int)

default_f1, default_precision, default_recall = (
    scorer(y_test, y_pred_default)
    for scorer in (f1_score, precision_score, recall_score)
)
# ROC-AUC is computed from the raw probabilities, not the thresholded labels.
default_roc_auc = roc_auc_score(y_test, y_pred_proba)

print("CURRENT PERFORMANCE (threshold=0.5):")
for metric_name, metric_value in (
    ("ROC-AUC", default_roc_auc),
    ("F1 Score", default_f1),
    ("Precision", default_precision),
    ("Recall", default_recall),
):
    print(f"  {metric_name}: {metric_value:.4f} ({metric_value*100:.2f}%)")

## Find Optimal Threshold

In [None]:
# Sweep candidate thresholds and keep the one with the highest F1 score.
#
# NOTE(review): the threshold is selected on the same test set that is used
# to report the "optimized" metrics, so those numbers are optimistically
# biased. Prefer selecting the threshold on a separate validation split.

# Candidate grid 0.10, 0.11, ..., 0.89. linspace + round gives an exact
# element count and clean two-decimal values; np.arange with a float step
# accumulates rounding error and its length can vary with the endpoint.
thresholds = np.round(np.linspace(0.10, 0.89, 80), 2)

f1_scores = []
precision_scores = []
recall_scores = []

for threshold in thresholds:
    y_pred = (y_pred_proba >= threshold).astype(int)
    # zero_division=0 on every metric: a high threshold can yield no positive
    # predictions, which should score 0 without emitting a warning per step.
    f1_scores.append(f1_score(y_test, y_pred, zero_division=0))
    precision_scores.append(precision_score(y_test, y_pred, zero_division=0))
    recall_scores.append(recall_score(y_test, y_pred, zero_division=0))

# Find optimal threshold (np.argmax returns the first index on ties, i.e. the
# lowest threshold among equally good candidates).
optimal_idx = int(np.argmax(f1_scores))
optimal_threshold = float(thresholds[optimal_idx])
optimal_f1 = f1_scores[optimal_idx]
optimal_precision = precision_scores[optimal_idx]
optimal_recall = recall_scores[optimal_idx]

print("="*60)
print("OPTIMAL THRESHOLD FOUND")
print("="*60)
print(f"\nOptimal threshold: {optimal_threshold:.2f}")
print("\nOptimized performance:")
print(f"  ROC-AUC: {default_roc_auc:.4f} ({default_roc_auc*100:.2f}%) [unchanged]")
print(f"  F1 Score: {optimal_f1:.4f} ({optimal_f1*100:.2f}%)")
print(f"  Precision: {optimal_precision:.4f} ({optimal_precision*100:.2f}%)")
print(f"  Recall: {optimal_recall:.4f} ({optimal_recall*100:.2f}%)")

# The relative gain is undefined when the 0.5 baseline has F1 == 0 (no true
# positives predicted); report that instead of dividing by zero.
if default_f1 > 0:
    relative_gain = f"{((optimal_f1/default_f1 - 1)*100):.1f}% increase"
else:
    relative_gain = "relative increase undefined: default F1 is 0"

print("\nImprovement over default (0.5):")
print(f"  F1: {(optimal_f1 - default_f1)*100:.2f} points ({relative_gain})")
print(f"  Precision: {(optimal_precision - default_precision)*100:.2f} points")
print(f"  Recall: {(optimal_recall - default_recall)*100:.2f} points")

## Visualize Threshold Impact

In [None]:
# Plot F1, precision and recall as functions of the decision threshold.
fig, ax = plt.subplots(figsize=(12, 6))

ax.plot(thresholds, f1_scores, label='F1 Score', linewidth=2)
for curve, curve_label in (
    (precision_scores, 'Precision'),
    (recall_scores, 'Recall'),
):
    ax.plot(thresholds, curve, label=curve_label, linewidth=2, linestyle='--')

# Vertical markers: the selected threshold (red) and the 0.5 default (gray).
ax.axvline(
    optimal_threshold,
    color='red',
    linestyle=':',
    linewidth=2,
    label=f'Optimal ({optimal_threshold:.2f})',
)
ax.axvline(
    0.5,
    color='gray',
    linestyle=':',
    linewidth=1,
    label='Default (0.5)',
    alpha=0.5,
)

ax.set_xlabel('Classification Threshold', fontsize=12)
ax.set_ylabel('Score', fontsize=12)
ax.set_title('Classification Metrics vs Threshold', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
fig.tight_layout()

# Write the figure to the reports folder, creating the folder if needed.
figures_dir = Path('../reports/figures')
figures_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(figures_dir / 'threshold_optimization.png', dpi=300, bbox_inches='tight')

plt.show()

## Confusion Matrix Comparison

In [None]:
# Compare confusion matrices at the default (0.5) and optimal thresholds,
# side by side, and save the figure next to the other report figures.

# Get predictions with optimal threshold
y_pred_optimal = (y_pred_proba >= optimal_threshold).astype(int)

# Confusion matrices. labels=[0, 1] pins the output to 2x2 even if one set of
# predictions contains a single class, so the cell annotations below cannot
# index out of range. Rows are actual classes, columns are predicted classes.
cm_default = confusion_matrix(y_test, y_pred_default, labels=[0, 1])
cm_optimal = confusion_matrix(y_test, y_pred_optimal, labels=[0, 1])

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

class_names = ['Low', 'High']
panels = [
    (cm_default, f'Default Threshold (0.5)\nF1={default_f1:.4f}'),
    (cm_optimal, f'Optimal Threshold ({optimal_threshold:.2f})\nF1={optimal_f1:.4f}'),
]

# Both panels share the same layout, so draw them with one loop.
for ax, (cm, title) in zip(axes, panels):
    ax.imshow(cm, cmap='Blues')
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')
    ax.set_xticks([0, 1])
    ax.set_yticks([0, 1])
    ax.set_xticklabels(class_names)
    ax.set_yticklabels(class_names)

    # Annotate each cell with its count; imshow puts column j on the x-axis
    # and row i on the y-axis.
    for (i, j), count in np.ndenumerate(cm):
        ax.text(j, i, str(count), ha='center', va='center', fontsize=14)

plt.tight_layout()
plt.savefig(figures_dir / 'confusion_matrix_threshold_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

## Save Optimal Threshold

In [None]:
# Save optimal threshold for deployment, together with the metrics measured
# at that threshold. Values are cast to built-in floats so the pickle holds no
# NumPy scalar objects and can be loaded without NumPy installed.
threshold_info = {
    'optimal_threshold': float(optimal_threshold),
    'f1_score': float(optimal_f1),
    'precision': float(optimal_precision),
    'recall': float(optimal_recall),
    'roc_auc': float(default_roc_auc),
}

# parents=True matches how the figures directory is created and avoids a
# FileNotFoundError when an intermediate folder is missing.
models_dir = Path('../models')
models_dir.mkdir(parents=True, exist_ok=True)

with open(models_dir / 'optimal_threshold.pkl', 'wb') as f:
    pickle.dump(threshold_info, f)

print("✓ Optimal threshold saved to models/optimal_threshold.pkl")

## Summary

In [None]:
# Print a before/after summary of the metrics at the default threshold (0.5)
# and at the selected threshold.
print("="*60)
print("THRESHOLD OPTIMIZATION SUMMARY")
print("="*60)

print("\nBEFORE (threshold=0.5):")
print(f"  ROC-AUC: {default_roc_auc*100:.2f}%")
print(f"  F1 Score: {default_f1*100:.2f}%")
print(f"  Precision: {default_precision*100:.2f}%")
print(f"  Recall: {default_recall*100:.2f}%")

# Deltas use the '+' format flag so the sign is always correct. Moving the
# threshold usually trades precision against recall, so one delta is often
# negative and a hard-coded "+" prefix would render it as "+-x.xx".
print(f"\nAFTER (threshold={optimal_threshold:.2f}):")
print(f"  ROC-AUC: {default_roc_auc*100:.2f}% [unchanged]")
print(f"  F1 Score: {optimal_f1*100:.2f}% ({(optimal_f1-default_f1)*100:+.2f})")
print(f"  Precision: {optimal_precision*100:.2f}% ({(optimal_precision-default_precision)*100:+.2f})")
print(f"  Recall: {optimal_recall*100:.2f}% ({(optimal_recall-default_recall)*100:+.2f})")

print(f"\n💡 Use threshold={optimal_threshold:.2f} for deployment!")