# Scam Token Detection — Balanced Random Forest Hyperparameter Tuning

This notebook implements a complete hyperparameter tuning pipeline for detecting scam tokens.

**Key components:**
- **Dataset**: ChainAbuse (scam) + CoinMarketCap (licit)
- **Model**: Balanced Random Forest
- **GridSearch**: Hyperparameter optimization via cross-validation
- **Threshold Tuning**: Decision threshold optimization on validation set
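- **SMOTE**: Synthetic minority oversampling to handle class imbalance, applied only to the training folds (it is a step inside the imblearn pipeline, so validation and test data are never resampled)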

## Goal
Find the optimal Balanced Random Forest configuration to maximize balanced accuracy on imbalanced scam detection data.

## 1. Imports

In [None]:
import os
import time
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.feature_selection import VarianceThreshold
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    roc_auc_score,
    roc_curve,
    precision_recall_curve,
    balanced_accuracy_score,
    matthews_corrcoef,
    f1_score,
    average_precision_score,
    make_scorer,
)

from imblearn.over_sampling import SMOTE
from imblearn.ensemble import BalancedRandomForestClassifier
from imblearn.pipeline import Pipeline as ImbPipeline

import joblib

# Plotting settings
%matplotlib inline
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (14, 10)

warnings.filterwarnings('ignore')

print("✓ Libraries imported")

## 2. Configuration

In [None]:
# Paths — the notebook may be launched from the repo root or a subdirectory,
# so probe a few relative locations for the feature dataset.
DATA_DIR_CANDIDATES = [
    '../data/dataset_with_features',
    './data/dataset_with_features',
    'data/dataset_with_features',
]
DATA_DIR = next((p for p in DATA_DIR_CANDIDATES if os.path.exists(p)), None)
if DATA_DIR is None:
    raise FileNotFoundError(f"Could not find dataset_with_features. Tried: {DATA_DIR_CANDIDATES}")

OUTPUT_DIR = './results_brf_hyperparameter_tuning_final'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Data paths
CHAINABUSE_PATH = f'{DATA_DIR}/chainabuse_scam_tokens/features.csv'
CMC_PATH = f'{DATA_DIR}/cmc_licit_tokens/features.csv'

# Random seed (also passed explicitly as random_state to splits/SMOTE/BRF)
SEED = 42
np.random.seed(SEED)

# Train/Val/Test split.
# TEST_SIZE is a fraction of the full dataset; VAL_SIZE is a fraction of the
# remaining (non-test) data, so validation is (1 - TEST_SIZE) * VAL_SIZE of the total.
TEST_SIZE = 0.2
VAL_SIZE = 0.1

# Feature selection
VARIANCE_THRESHOLD = 0.001
CORRELATION_THRESHOLD = 0.95

# SMOTE configuration.
# NOTE(review): a float sampling_strategy is the target minority/majority ratio
# after resampling; imblearn rejects it if a training fold is already more
# balanced than this — confirm against the observed class ratio.
SMOTE_PARAMS = {
    'sampling_strategy': 0.8,
    'random_state': SEED,
    'k_neighbors': 5,
}

# GridSearch CV configuration
CV_FOLDS = 5
# Use all CPU cores. NOTE(review): this value is used by both GridSearchCV and
# the forest itself; check for CPU oversubscription from nested parallelism.
N_JOBS = -1

print("="*80)
print("CONFIGURATION — BALANCED RANDOM FOREST HYPERPARAMETER TUNING")
print("="*80)
print(f"Data Directory: {DATA_DIR}")
print(f"Output Directory: {OUTPUT_DIR}")
print(f"Random Seed: {SEED}")
print(f"Test Size: {TEST_SIZE*100:.0f}%")
print(f"Validation Size: {VAL_SIZE*100:.0f}% of non-test data "
      f"({(1 - TEST_SIZE) * VAL_SIZE * 100:.0f}% of total)")
print(f"Train Size: {(1 - TEST_SIZE) * (1 - VAL_SIZE) * 100:.0f}% of total")
print(f"GridSearch CV Folds: {CV_FOLDS}")
print(f"SMOTE strategy: {SMOTE_PARAMS['sampling_strategy']}")
print("="*80)

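## 3. Load Data

> **Note:** The data-loading cell (which defines `df`) currently sits at the very end of this notebook. Run it before Section 4; otherwise the feature-engineering cell fails with `NameError: name 'df' is not defined`.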

## 4. Feature Engineering & Selection

In [None]:
print("\nFeature Engineering & Selection...\n")

# Extract numeric features and the binary target. Label, provenance and
# identifier columns are excluded explicitly; other non-numeric columns are
# dropped by select_dtypes.
exclude_cols = {'target', 'label', 'source', 'source_directory', 'token_address', 'token_file'}
drop_cols = [c for c in df.columns if c in exclude_cols]
X_df = df.drop(columns=drop_cols, errors='ignore')
X_df = X_df.select_dtypes(include=[np.number]).copy()
y = df['target'].astype(int).values

print(f"Initial numeric features: {X_df.shape[1]}")

# Train/Val/Test split (stratified). VAL_SIZE is taken from the non-test part.
X_temp, X_test_df, y_temp, y_test = train_test_split(
    X_df, y, test_size=TEST_SIZE, random_state=SEED, stratify=y
)
X_train_df, X_val_df, y_train, y_val = train_test_split(
    X_temp, y_temp, test_size=VAL_SIZE, random_state=SEED, stratify=y_temp
)

print("\nTrain/Val/Test split:")
print(f"  Train: {len(y_train):,} samples (scam={y_train.sum():,}, licit={len(y_train)-y_train.sum():,})")
print(f"  Val:   {len(y_val):,} samples (scam={y_val.sum():,}, licit={len(y_val)-y_val.sum():,})")
print(f"  Test:  {len(y_test):,} samples (scam={y_test.sum():,}, licit={len(y_test)-y_test.sum():,})")

# 1) Bound infinite / extreme values BEFORE imputing: SimpleImputer raises a
#    ValueError on +/-inf, so values must be finite (or NaN) when it runs.
#    np.clip maps +/-inf to +/-MAX_VAL and leaves NaN untouched for the imputer.
MAX_VAL = 1e10
X_train_fin = np.clip(X_train_df.to_numpy(dtype=float, na_value=np.nan), -MAX_VAL, MAX_VAL)
X_val_fin = np.clip(X_val_df.to_numpy(dtype=float, na_value=np.nan), -MAX_VAL, MAX_VAL)
X_test_fin = np.clip(X_test_df.to_numpy(dtype=float, na_value=np.nan), -MAX_VAL, MAX_VAL)

# 2) Impute missing values with medians learned on the training split only.
imputer = SimpleImputer(strategy='median')
X_train_imp = imputer.fit_transform(X_train_fin)
X_val_imp = imputer.transform(X_val_fin)
X_test_imp = imputer.transform(X_test_fin)

# SimpleImputer discards columns that were entirely NaN in the training split
# (their fitted median is NaN), so derive the surviving names from statistics_
# instead of assuming the original column list is still aligned.
imputed_feature_names = X_train_df.columns[~np.isnan(imputer.statistics_)]
n_all_nan = X_train_df.shape[1] - len(imputed_feature_names)
if n_all_nan:
    print(f"\nDropped {n_all_nan} all-NaN features during imputation")

# 3) Variance threshold (fit on training data only)
var_sel = VarianceThreshold(threshold=VARIANCE_THRESHOLD)
X_train_var = var_sel.fit_transform(X_train_imp)
X_val_var = var_sel.transform(X_val_imp)
X_test_var = var_sel.transform(X_test_imp)

var_feature_names = list(imputed_feature_names[var_sel.get_support()])
print(f"\nAfter variance threshold ({VARIANCE_THRESHOLD}): {len(var_feature_names)} features")

# 4) Correlation filter: look only at the strict upper triangle so each pair is
#    considered once; a column is dropped when it correlates above the
#    threshold with any earlier column.
df_train_var = pd.DataFrame(X_train_var, columns=var_feature_names)
corr = df_train_var.corr().abs().fillna(0.0)
upper = corr.where(np.triu(np.ones(corr.shape, dtype=bool), k=1))
to_drop = list(upper.columns[(upper > CORRELATION_THRESHOLD).any(axis=0)])
drop_set = set(to_drop)

keep_idx = [i for i, name in enumerate(var_feature_names) if name not in drop_set]
keep_features = [var_feature_names[i] for i in keep_idx]

X_train = X_train_var[:, keep_idx]
X_val = X_val_var[:, keep_idx]
X_test = X_test_var[:, keep_idx]

print(f"After correlation filter (>{CORRELATION_THRESHOLD}): {len(keep_features)} features")
print(f"Total removed: {X_df.shape[1] - len(keep_features)} features")
print("\n✓ Feature engineering complete")

## 5. GridSearch Hyperparameter Tuning

In [None]:
print(f"\n{'='*80}")
print("GRIDSEARCH HYPERPARAMETER TUNING")
print(f"{'='*80}\n")

# SMOTE lives inside the imblearn pipeline, so during cross-validation it is
# fit/applied on each training fold only and never on the held-out fold.
pipeline = ImbPipeline([
    ('smote', SMOTE(**SMOTE_PARAMS)),
    ('classifier', BalancedRandomForestClassifier(random_state=SEED, n_jobs=N_JOBS))
])

# Define hyperparameter grid
param_grid = {
    'classifier__n_estimators': [200, 400, 600],
    'classifier__max_depth': [None, 20, 30, 50],
    'classifier__min_samples_split': [2, 5, 10],
    'classifier__min_samples_leaf': [1, 2, 4],
    'classifier__max_features': ['sqrt', 'log2', None],
}

print("Hyperparameter grid:")
for param, values in param_grid.items():
    print(f"  {param}: {values}")

total_combinations = np.prod([len(v) for v in param_grid.values()])
print(f"\nTotal combinations: {total_combinations}")
print(f"Total fits: {total_combinations * CV_FOLDS}")
print(f"\nUsing {CV_FOLDS}-fold cross-validation with balanced_accuracy scoring...\n")

# GridSearch with balanced_accuracy scoring. With the default refit=True the
# best configuration is refit on the whole training split afterwards.
grid_search = GridSearchCV(
    pipeline,
    param_grid,
    cv=CV_FOLDS,
    scoring='balanced_accuracy',
    n_jobs=N_JOBS,
    verbose=2,
    return_train_score=True,
)

print("Starting GridSearch (this may take 30-60+ minutes)...\n")
start_time = time.time()

grid_search.fit(X_train, y_train)

elapsed = time.time() - start_time

print(f"\n{'='*80}")
print(f"GridSearch complete in {elapsed/60:.1f} minutes")
print(f"{'='*80}\n")

# Best parameters
print("Best parameters:")
for param, value in grid_search.best_params_.items():
    print(f"  {param}: {value}")

print(f"\nBest CV balanced accuracy: {grid_search.best_score_:.4f}")

# Save results
results_df = pd.DataFrame(grid_search.cv_results_)
results_path = os.path.join(OUTPUT_DIR, 'gridsearch_results.csv')
results_df.to_csv(results_path, index=False)
print(f"\n✓ GridSearch results saved to: {results_path}")

# Save best model. It consumes the already-preprocessed feature matrix
# (imputed, variance- and correlation-filtered), not raw CSV columns.
best_model_path = os.path.join(OUTPUT_DIR, 'best_brf_model.pkl')
joblib.dump(grid_search.best_estimator_, best_model_path)
print(f"✓ Best model saved to: {best_model_path}")

# Persist the fitted preprocessing artifacts from Section 4 so the saved model
# can be applied to new data without re-running this notebook.
preprocessing_path = os.path.join(OUTPUT_DIR, 'preprocessing.pkl')
joblib.dump(
    {
        'input_columns': list(X_train_df.columns),
        'max_val': MAX_VAL,
        'imputer': imputer,
        'variance_selector': var_sel,
        'keep_idx': list(keep_idx),
        'keep_features': list(keep_features),
    },
    preprocessing_path,
)
print(f"✓ Preprocessing artifacts saved to: {preprocessing_path}")

## 6. Threshold Tuning on Validation Set

In [None]:
print(f"\n{'='*80}")
print("THRESHOLD TUNING ON VALIDATION SET")
print(f"{'='*80}\n")

# Positive-class (scam, label 1) probabilities from the refit best pipeline.
best_model = grid_search.best_estimator_
y_val_proba = best_model.predict_proba(X_val)[:, 1]

# Sweep candidate cut-offs and score each one by validation balanced accuracy.
thresholds = np.arange(0.1, 0.91, 0.01)
val_scores = [
    balanced_accuracy_score(y_val, (y_val_proba >= cutoff).astype(int))
    for cutoff in thresholds
]

# np.argmax returns the first maximum, so ties resolve to the lowest cut-off.
best_idx = np.argmax(val_scores)
best_threshold, best_val_score = thresholds[best_idx], val_scores[best_idx]

print(f"Optimal threshold: {best_threshold:.4f}")
print(f"Validation balanced accuracy: {best_val_score:.4f}")

# Visualise the sweep and mark the chosen threshold.
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(thresholds, val_scores, linewidth=2, color='steelblue')
ax.axvline(best_threshold, color='red', linestyle='--', linewidth=2,
           label=f'Optimal={best_threshold:.3f}')
ax.set_xlabel('Threshold', fontsize=12)
ax.set_ylabel('Balanced Accuracy', fontsize=12)
ax.set_title('Threshold Tuning on Validation Set', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
plt.tight_layout()

threshold_plot_path = os.path.join(OUTPUT_DIR, 'threshold_tuning.png')
plt.savefig(threshold_plot_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"\n✓ Threshold tuning plot saved to: {threshold_plot_path}")

## 7. Test Set Evaluation & Summary

In [None]:
import json

print(f"\n{'='*80}")
print("TEST SET EVALUATION")
print(f"{'='*80}\n")

# Score the held-out test set with the threshold chosen on the validation set.
y_test_proba = best_model.predict_proba(X_test)[:, 1]
y_test_pred = (y_test_proba >= best_threshold).astype(int)

# Threshold-dependent metrics use the hard predictions; ranking metrics use
# the probabilities. "PR-AUC" here is sklearn's average precision (AP).
test_bal_acc = balanced_accuracy_score(y_test, y_test_pred)
test_roc_auc = roc_auc_score(y_test, y_test_proba)
test_pr_auc = average_precision_score(y_test, y_test_proba)
test_mcc = matthews_corrcoef(y_test, y_test_pred)
test_f1 = f1_score(y_test, y_test_pred)

print("Test Set Metrics:")
print(f"  Balanced Accuracy: {test_bal_acc:.4f}")
print(f"  ROC-AUC:          {test_roc_auc:.4f}")
print(f"  PR-AUC:           {test_pr_auc:.4f}")
print(f"  MCC:              {test_mcc:.4f}")
print(f"  F1 Score:         {test_f1:.4f}")

print("\nClassification Report:")
print(classification_report(y_test, y_test_pred, labels=[0, 1], target_names=['Licit', 'Scam']))

# Fix the label order ([licit, scam]) so the matrix is always 2x2, even if a
# class is never predicted; ravel() then yields TN, FP, FN, TP.
cm = confusion_matrix(y_test, y_test_pred, labels=[0, 1])
tn, fp, fn, tp = cm.ravel()
print("Confusion Matrix:")
print(f"  TN={tn:,}, FP={fp:,}")
print(f"  FN={fn:,}, TP={tp:,}")

# Save metrics together with the configuration that produced them.
metrics = {
    'best_threshold': float(best_threshold),
    'val_balanced_accuracy': float(best_val_score),
    'test_balanced_accuracy': float(test_bal_acc),
    'test_roc_auc': float(test_roc_auc),
    'test_pr_auc': float(test_pr_auc),
    'test_mcc': float(test_mcc),
    'test_f1': float(test_f1),
    'test_tn': int(tn),
    'test_fp': int(fp),
    'test_fn': int(fn),
    'test_tp': int(tp),
    'best_params': dict(grid_search.best_params_),
    'n_features': len(keep_features),
}

metrics_path = os.path.join(OUTPUT_DIR, 'final_metrics.json')
with open(metrics_path, 'w') as f:
    json.dump(metrics, f, indent=2)

print(f"\n✓ Final metrics saved to: {metrics_path}")

# Plot confusion matrix + ROC + PR curves
fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Confusion Matrix
ax1 = axes[0]
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False, ax=ax1,
            xticklabels=['Licit', 'Scam'], yticklabels=['Licit', 'Scam'], annot_kws={'fontsize': 14})
ax1.set_title('Confusion Matrix (Test Set)', fontsize=14, fontweight='bold')
ax1.set_ylabel('True Label', fontsize=12)
ax1.set_xlabel('Predicted Label', fontsize=12)

# ROC Curve
ax2 = axes[1]
fpr, tpr, _ = roc_curve(y_test, y_test_proba)
ax2.plot(fpr, tpr, linewidth=2.5, color='steelblue', label=f'ROC (AUC={test_roc_auc:.3f})')
ax2.plot([0, 1], [0, 1], 'k--', linewidth=1.5, alpha=0.3)
ax2.set_xlabel('False Positive Rate', fontsize=12)
ax2.set_ylabel('True Positive Rate', fontsize=12)
ax2.set_title('ROC Curve', fontsize=14, fontweight='bold')
ax2.legend(fontsize=11, loc='lower right')
ax2.grid(True, alpha=0.3)

# Precision-Recall Curve (legend reports average precision, the value computed above)
ax3 = axes[2]
precision, recall, _ = precision_recall_curve(y_test, y_test_proba)
ax3.plot(recall, precision, linewidth=2.5, color='coral', label=f'PR (AP={test_pr_auc:.3f})')
ax3.set_xlabel('Recall', fontsize=12)
ax3.set_ylabel('Precision', fontsize=12)
ax3.set_title('Precision-Recall Curve', fontsize=14, fontweight='bold')
ax3.legend(fontsize=11, loc='upper right')
ax3.grid(True, alpha=0.3)

plt.tight_layout()

viz_path = os.path.join(OUTPUT_DIR, 'evaluation_visualizations.png')
plt.savefig(viz_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"✓ Visualizations saved to: {viz_path}")

# FINAL SUMMARY
print(f"\n{'='*80}")
print("FINAL SUMMARY")
print(f"{'='*80}\n")

print("Dataset: ChainAbuse (scam) + CoinMarketCap (licit)")
print(f"Total samples: {len(df):,} | Features (after selection): {len(keep_features)}")

print("\nBest Hyperparameters:")
for param, value in grid_search.best_params_.items():
    print(f"  {param}: {value}")

print(f"\nOptimal Decision Threshold: {best_threshold:.4f}")

print("\nPerformance:")
print(f"  Validation Balanced Accuracy: {best_val_score:.4f}")
print(f"  Test Balanced Accuracy:       {test_bal_acc:.4f}")
print(f"  Test ROC-AUC:                 {test_roc_auc:.4f}")
print(f"  Test PR-AUC:                  {test_pr_auc:.4f}")
print(f"  Test F1 Score:                {test_f1:.4f}")
print(f"  Test MCC:                     {test_mcc:.4f}")

print(f"\n{'='*80}")
print("✓ TRAINING COMPLETE")
print(f"{'='*80}")

In [None]:
# Load the scam and legitimate token datasets.
# NOTE: this cell defines `df`, which the Section 4 feature-engineering cell
# reads — it must be executed before that cell.
print("\nLoading ChainAbuse + CoinMarketCap datasets...\n")


def read_features_csv(path: str) -> pd.DataFrame:
    """Read a features CSV, preferring pandas' pyarrow engine for speed.

    Falls back to the default pandas engine if the pyarrow engine is not
    available or fails on this file, and reports the fallback rather than
    hiding it.
    """
    try:
        return pd.read_csv(path, engine='pyarrow')
    except Exception as err:  # deliberate best-effort: any pyarrow failure falls back
        print(f"  pyarrow engine unavailable/failed ({type(err).__name__}); using default engine")
        return pd.read_csv(path)


def _load_labelled(description: str, path: str, source_tag: str, target: int) -> pd.DataFrame:
    """Load one features CSV and tag every row with its provenance and class label."""
    print(f"Loading {description}: {path}")
    frame = read_features_csv(path)
    frame['source'] = source_tag
    frame['target'] = target
    print(f"  Loaded: {len(frame):,} rows")
    return frame


# Scam tokens from ChainAbuse (label 1), legitimate tokens from CoinMarketCap (label 0)
df_chainabuse = _load_labelled('ChainAbuse (scam)', CHAINABUSE_PATH, 'chainabuse_scam', 1)
df_cmc = _load_labelled('CoinMarketCap (licit)', CMC_PATH, 'cmc_licit', 0)

# The binary classifier and the stratified splits downstream need both classes.
if df_chainabuse.empty or df_cmc.empty:
    raise ValueError(
        f"Both datasets must be non-empty (scam rows={len(df_chainabuse)}, licit rows={len(df_cmc)})"
    )

# Combine and shuffle
df = pd.concat([df_chainabuse, df_cmc], ignore_index=True)
df = df.sample(frac=1, random_state=SEED).reset_index(drop=True)

n_scam = int(df['target'].sum())
n_licit = len(df) - n_scam

# Print dataset summary
print(f"\n{'='*80}")
print("DATASET OVERVIEW")
print(f"{'='*80}")
print(f"Total samples: {len(df):,}")
print("\nClass distribution:")
print(df['target'].value_counts().sort_index())
print(f"\nImbalance ratio (scam/licit): {n_scam / n_licit:.2f}")
print("\nSource distribution:")
print(df['source'].value_counts())
print(f"{'='*80}")