# Baseline k-Nearest Neighbors Classifier

We first build a baseline k-Nearest Neighbors classifier to serve as the reference point for our classification task. The model is a pipeline of `StandardScaler` followed by `KNeighborsClassifier`. The hyperparameter grid covers `n_neighbors` values of 3, 5, and 7 and the `weights` options `'uniform'` and `'distance'`. `GridSearchCV` tunes this pipeline with 5-fold cross-validation and the `'f1_weighted'` scoring metric.

In [1]:
import os
import joblib
import logging

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, label_binarize
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    classification_report,
    confusion_matrix,
    roc_curve,
    auc,
)
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV

In [2]:
# Emit INFO-and-above records, each prefixed with a timestamp and its level name.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)

In [3]:
# Artifact locations, all relative to the current working directory.
OUT_DIR = "results/models"   # serialized estimators
OUT_VIS = "results/figures"  # saved plots
OUT_CSV = "results/csv"      # tabular reports

# Make sure every location exists before anything is written to it;
# a directory that is already present is not an error.
for d in [OUT_DIR, OUT_VIS, OUT_CSV]:
    os.makedirs(name=d, exist_ok=True)

In [4]:
def load_data(data_dir="data/processed", threshold=50):
    """Load the train/validation splits and merge rare classes into ``"Other"``.

    Features are read from ``{data_dir}/X/train.csv`` and ``{data_dir}/X/val.csv``;
    labels are the first column of ``{data_dir}/Y/train.csv`` and
    ``{data_dir}/Y/val.csv``. Non-numeric feature columns are dropped, and the
    validation frame is then restricted to the training columns (in training
    order) so both splits expose exactly the same features.

    Args:
        data_dir: Root directory that contains the ``X`` and ``Y`` sub-directories.
        threshold: Classes with fewer than this many *training* samples are
            relabelled ``"Other"`` in both the training and validation labels.

    Returns:
        Tuple ``(X_train, y_train, X_val, y_val)``.

    Raises:
        KeyError: If the validation features lack a numeric training column.
    """
    X_train = pd.read_csv(f"{data_dir}/X/train.csv")
    X_val = pd.read_csv(f"{data_dir}/X/val.csv")
    logging.info(f"Loaded features: X_train {X_train.shape}, X_val {X_val.shape}")

    y_train_df = pd.read_csv(f"{data_dir}/Y/train.csv")
    y_val_df = pd.read_csv(f"{data_dir}/Y/val.csv")
    logging.info(f"Loaded labels: y_train {y_train_df.shape}, y_val {y_val_df.shape}")

    # Labels live in the first column of the label files.
    y_train = y_train_df.iloc[:, 0]
    y_val = y_val_df.iloc[:, 0]
    logging.info(f"Initial y_train classes: {sorted(y_train.unique())}")

    X_train = X_train.select_dtypes(include=[np.number])
    X_val = X_val.select_dtypes(include=[np.number])
    logging.info(f"Numeric filter: X_train {X_train.shape}, X_val {X_val.shape}")

    # Guarantee identical feature sets and column order across splits, the same
    # way the test features are aligned before prediction.
    X_val = X_val[X_train.columns]

    # Rarity is decided on the training labels only, then the same mapping is
    # applied to the validation labels so both splits share one label space.
    counts = y_train.value_counts()
    logging.info(f"Pre-merge class counts: {counts.to_dict()}")
    rare = counts[counts < threshold].index.tolist()
    if rare:
        mapping = dict.fromkeys(rare, "Other")
        y_train = y_train.replace(mapping)
        y_val = y_val.replace(mapping)
        logging.info(f"Merged rare classes: {rare} -> 'Other'")
    else:
        logging.info("No rare classes to merge.")
    logging.info(f"Post-merge classes: {sorted(y_train.unique())}")

    return X_train, y_train, X_val, y_val

In [5]:
# Load the training and validation features/labels; load_data() drops
# non-numeric feature columns and merges rare classes into "Other".
X_train, y_train, X_val, y_val = load_data()

2025-07-13 20:53:22,211 INFO Loaded features: X_train (88307, 425), X_val (18923, 425)
2025-07-13 20:53:22,223 INFO Loaded labels: y_train (88307, 1), y_val (18923, 1)
2025-07-13 20:53:22,230 INFO Initial y_train classes: ['Cell Junction', 'Cell Projection', 'Cell Surface', 'Cytoplasm', 'Endoplasmic Reticulum', 'Endosome', 'Golgi Apparatus', 'Lysosome', 'Membrane', 'Mitochondrion', 'Nucleus', 'Periplasm', 'Peroxisome', 'Plastid', 'Secreted', 'Vacuole', 'Virion']
2025-07-13 20:53:22,321 INFO Numeric filter: X_train (88307, 424), X_val (18923, 424)
2025-07-13 20:53:22,327 INFO Pre-merge class counts: {'Cytoplasm': 49960, 'Membrane': 16586, 'Secreted': 8934, 'Nucleus': 7277, 'Mitochondrion': 2519, 'Periplasm': 1019, 'Virion': 1007, 'Endoplasmic Reticulum': 308, 'Peroxisome': 159, 'Lysosome': 115, 'Vacuole': 112, 'Plastid': 102, 'Golgi Apparatus': 96, 'Cell Surface': 58, 'Endosome': 48, 'Cell Junction': 6, 'Cell Projection': 1}
2025-07-13 20:53:22,344 INFO Merged rare classes: ['Endosome', 'Cell Junction', 'Cell Projection'] -> 'Other'

In [6]:
def train_knn(X_train, y_train):
    """Tune a scaled k-NN pipeline with an exhaustive grid search.

    A ``StandardScaler`` -> ``KNeighborsClassifier`` pipeline is searched over
    ``n_neighbors`` in {3, 5, 7} and ``weights`` in {"uniform", "distance"},
    using 5-fold cross-validation scored with weighted F1.

    Args:
        X_train: Training feature matrix.
        y_train: Training labels.

    Returns:
        Tuple ``(best_estimator_, cv_results_)`` taken from the fitted search.
    """
    steps = [
        ("scaler", StandardScaler()),
        ("knn", KNeighborsClassifier()),
    ]
    search_space = {
        "knn__n_neighbors": [3, 5, 7],
        "knn__weights": ["uniform", "distance"],
    }
    # error_score="raise" makes a failing candidate fit abort the search
    # instead of being recorded with a placeholder score.
    grid = GridSearchCV(
        estimator=Pipeline(steps),
        param_grid=search_space,
        scoring="f1_weighted",
        cv=5,
        n_jobs=-1,
        verbose=1,
        error_score="raise",
    )

    logging.info("Starting k-NN GridSearchCV...")
    grid.fit(X_train, y_train)
    logging.info(f"Best params: {grid.best_params_}")
    logging.info(f"Best CV F1-weighted: {grid.best_score_:.4f}")

    return grid.best_estimator_, grid.cv_results_

In [7]:
def evaluate_and_save(model, X, y, out_csv=OUT_CSV, out_vis=OUT_VIS):
    """Score ``model`` on ``(X, y)`` and persist the report, confusion matrix and ROC data.

    Files written:
        * ``{out_csv}/knn_classification_report.csv``
        * ``{out_csv}/knn_confusion_matrix.csv``
        * ``{out_csv}/knn_roc_curve_{cls}.csv`` (one per class)
        * ``{out_vis}/knn_roc_curves.png``

    ROC curves are one-vs-rest. They are built from ``model.predict_proba``
    scores when the model exposes that method, so the curve has one point per
    distinct score; otherwise the hard predictions are used as 0/1 scores.

    Args:
        model: Fitted classifier exposing ``predict`` and ``classes_``.
        X: Feature matrix to evaluate on.
        y: True labels for ``X``.
        out_csv: Directory for CSV artifacts.
        out_vis: Directory for figures.

    Returns:
        Tuple ``(acc, f1w, classes, roc_auc, fpr, tpr)``. ``roc_auc``, ``fpr``
        and ``tpr`` are dicts keyed by the position of the class in ``classes``.
    """
    y_pred = model.predict(X)

    acc = accuracy_score(y, y_pred)
    f1w = f1_score(y, y_pred, average="weighted")
    logging.info(f"Validation accuracy: {acc:.4f}")
    logging.info(f"Validation F1-weighted: {f1w:.4f}")

    # Save classification report
    report = classification_report(y, y_pred, output_dict=True, zero_division=0)
    report_df = pd.DataFrame(report).transpose()
    report_df.to_csv(f"{out_csv}/knn_classification_report.csv")

    # Save confusion matrix (rows: true class, columns: predicted class)
    classes = model.classes_
    cm = confusion_matrix(y, y_pred, labels=classes)
    cm_df = pd.DataFrame(cm, index=classes, columns=classes)
    cm_df.to_csv(f"{out_csv}/knn_confusion_matrix.csv")

    # Per-class scores for the ROC curves. Column i corresponds to classes[i].
    y_true = np.asarray(y)
    if hasattr(model, "predict_proba"):
        scores = np.asarray(model.predict_proba(X))
    else:
        # Fallback: hard predictions as one-hot 0/1 scores.
        scores = (
            np.asarray(y_pred)[:, None] == np.asarray(classes)[None, :]
        ).astype(float)

    # Save ROC data. The one-vs-rest indicator is built per class directly, so
    # the two-class case works as well (label_binarize would return a single
    # column there and break the per-class indexing).
    fpr, tpr, roc_auc = {}, {}, {}
    for i, cls in enumerate(classes):
        is_cls = (y_true == cls).astype(int)
        fpr[i], tpr[i], _ = roc_curve(is_cls, scores[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])
        pd.DataFrame({"fpr": fpr[i], "tpr": tpr[i]}).to_csv(
            f"{out_csv}/knn_roc_curve_{cls}.csv", index=False
        )

    # Plot overall ROC
    plt.figure(figsize=(8, 6))
    for i, cls in enumerate(classes):
        plt.plot(fpr[i], tpr[i], label=f"{cls} (AUC={roc_auc[i]:.2f})")
    plt.plot([0, 1], [0, 1], "k--")  # chance diagonal
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("k-NN ROC Curves (One-vs-Rest)")
    plt.legend(loc="lower right", fontsize="small")
    plt.tight_layout()
    plt.savefig(f"{out_vis}/knn_roc_curves.png", dpi=150)
    plt.close()

    return acc, f1w, classes, roc_auc, fpr, tpr

In [8]:
def benchmark_k_values(
    X_train, y_train, X_val, y_val, ks=(3, 5, 7), out_csv=OUT_CSV, out_vis=OUT_VIS
):
    """Fit one scaled k-NN pipeline per ``k`` and compare validation scores.

    Each pipeline is ``StandardScaler`` -> ``KNeighborsClassifier(n_neighbors=k)``
    fitted on the training split and scored on the validation split.

    Files written:
        * ``{out_vis}/knn_benchmark_ks.png`` (grouped bar chart)
        * ``{out_csv}/knn_benchmark_ks.csv``

    Args:
        X_train, y_train: Training features and labels.
        X_val, y_val: Validation features and labels.
        ks: Iterable of neighbour counts to benchmark.
        out_csv: Directory for the CSV table.
        out_vis: Directory for the figure.

    Returns:
        DataFrame with columns ``k``, ``accuracy`` and ``f1_weighted``.

    Raises:
        ValueError: If ``ks`` is empty.
    """
    ks = list(ks)
    if not ks:
        raise ValueError("ks must contain at least one value of k")

    records = []
    for k in ks:
        model = Pipeline(
            [("scaler", StandardScaler()), ("knn", KNeighborsClassifier(n_neighbors=k))]
        )
        model.fit(X_train, y_train)
        y_pred = model.predict(X_val)
        records.append(
            {
                "k": k,
                "accuracy": accuracy_score(y_val, y_pred),
                "f1_weighted": f1_score(y_val, y_pred, average="weighted"),
            }
        )
    df = pd.DataFrame(records)

    # Work on the axes pandas returns so the figure that is saved and closed
    # is the one just drawn, whatever matplotlib's "current figure" is.
    ax = df.set_index("k").plot.bar(rot=0)
    ax.set_ylabel("Score")
    ax.set_title("k-NN Benchmark: Accuracy & Weighted-F1 by k")
    fig = ax.get_figure()
    fig.tight_layout()
    fig.savefig(f"{out_vis}/knn_benchmark_ks.png", dpi=150)
    plt.close(fig)

    df.to_csv(f"{out_csv}/knn_benchmark_ks.csv", index=False)
    return df

In [9]:
def plot_confusion_subset(
    model, X, y, classes, out_vis=OUT_VIS, n_worst=2, min_support=5
):
    """Plot a class-vs-rest confusion matrix for the lowest-recall classes.

    The ``n_worst`` classes with the lowest recall (among those in ``classes``
    with at least ``min_support`` true samples) are each drawn as a 2x2 heatmap
    and saved to ``{out_vis}/knn_confusion_binary_{cls}.png``.

    Args:
        model: Fitted classifier exposing ``predict``.
        X: Feature matrix to evaluate on.
        y: True labels for ``X``.
        classes: Class labels eligible for plotting.
        out_vis: Directory for the figures.
        n_worst: Number of lowest-recall classes to plot.
        min_support: Minimum number of true samples a class needs to be eligible.
    """
    y_true = np.asarray(y)
    y_pred = np.asarray(model.predict(X))

    rpt = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
    df_rpt = pd.DataFrame(rpt).transpose()

    # Restricting to `classes` also removes the summary rows of the report
    # (accuracy / macro avg / weighted avg), which are not real classes.
    eligible = df_rpt.loc[
        df_rpt.index.isin(classes) & (df_rpt["support"] >= min_support)
    ]
    worst = eligible.sort_values("recall").head(n_worst).index.tolist()

    for cls in worst:
        # Collapse the multi-class problem to "cls" (1) versus everything else (0).
        binary_true = (y_true == cls).astype(int)
        binary_pred = (y_pred == cls).astype(int)
        # Pin labels so the matrix is always 2x2 and matches the tick labels,
        # even when one of the two outcomes never occurs.
        cm = confusion_matrix(binary_true, binary_pred, labels=[0, 1])

        plt.figure()  # start from a clean figure for every class
        sns.heatmap(
            cm,
            annot=True,
            fmt="d",
            cmap="Blues",
            xticklabels=[f"!{cls}", cls],
            yticklabels=[f"!{cls}", cls],
        )
        plt.title(f"Binary Confusion: {cls} vs Rest")
        plt.ylabel("True")
        plt.xlabel("Pred")
        plt.tight_layout()
        plt.savefig(f"{out_vis}/knn_confusion_binary_{cls}.png", dpi=150)
        plt.close()

In [10]:
def plot_top_bottom_roc(roc_auc, fpr, tpr, classes, out_vis=OUT_VIS):
    """Draw the ROC curves of the lowest- and highest-AUC classes on one figure.

    Args:
        roc_auc: Dict mapping class position -> AUC.
        fpr: Dict mapping class position -> false-positive-rate array.
        tpr: Dict mapping class position -> true-positive-rate array.
        classes: Sequence of class labels, indexed by the same positions.
        out_vis: Directory where ``knn_roc_worst_best.png`` is saved.
    """
    # Keys in ascending AUC order; sorted() is stable, so ties keep dict order.
    ranked = sorted(roc_auc, key=roc_auc.get)
    bottom, top = ranked[0], ranked[-1]

    plt.figure(figsize=(6, 5))
    for idx in (bottom, top):
        cls = classes[idx]
        plt.plot(fpr[idx], tpr[idx], label=f"{cls} (AUC={roc_auc[idx]:.2f})")
    plt.plot([0, 1], [0, 1], "k--")  # chance diagonal
    plt.xlabel("FPR")
    plt.ylabel("TPR")
    plt.title("ROC: Worst vs Best Classes")
    plt.legend()
    plt.tight_layout()
    plt.savefig(f"{out_vis}/knn_roc_worst_best.png", dpi=150)
    plt.close()

In [11]:
def show_top5_neighbors(model, X_train, y_train, X_val, sample_idx=None, n_neighbors=5):
    """Log the nearest training neighbours of a single validation sample.

    The sample is scaled with the pipeline's ``"scaler"`` step and queried
    against its ``"knn"`` step. For each neighbour the training index label,
    the training class label and the distance are logged.

    Args:
        model: Pipeline with ``"scaler"`` and ``"knn"`` entries in ``named_steps``.
        X_train: Training features; only its index is read here.
        y_train: Training labels, looked up by position.
        X_val: Validation features the sample is taken from.
        sample_idx: Index label of the sample in ``X_val``; defaults to the
            first index label of ``X_val``.
        n_neighbors: How many neighbours to report.
    """
    if sample_idx is None:
        sample_idx = X_val.index[0]

    # Double brackets keep the row as a one-row DataFrame.
    query_row = X_val.loc[[sample_idx]]
    steps = model.named_steps
    knn, scaler = steps["knn"], steps["scaler"]
    distances, neighbors = knn.kneighbors(
        scaler.transform(query_row), n_neighbors=n_neighbors
    )

    logging.info(f"Sample index: {sample_idx}")
    # kneighbors returns one row per query sample; there is exactly one here.
    for nbr, dist in zip(neighbors[0], distances[0]):
        seq_id = X_train.index[nbr]
        label = y_train.iloc[nbr]
        logging.info(f"Neighbor: {seq_id}, Label: {label}, Distance: {dist:.3f}")

In [12]:
# Run the grid search; train_knn returns the search's best estimator and its
# cross-validation results.
best_knn, cv_results = train_knn(X_train, y_train)
# Persist the tuned pipeline under the models directory.
joblib.dump(best_knn, os.path.join(OUT_DIR, "knn_baseline_best.pkl"))
# Keep the cross-validation results as a CSV table.
pd.DataFrame(cv_results).to_csv(f"{OUT_CSV}/knn_cv_results.csv", index=False)

2025-07-13 20:53:22,397 INFO Starting k-NN GridSearchCV...


Fitting 5 folds for each of 6 candidates, totalling 30 fits


2025-07-13 20:56:13,562 INFO Best params: {'knn__n_neighbors': 3, 'knn__weights': 'distance'}
2025-07-13 20:56:13,564 INFO Best CV F1-weighted: 0.7892


In [13]:
# Score the tuned pipeline on the validation split; this also writes the
# classification report, confusion matrix and ROC artifacts to disk.
acc, f1w, classes, roc_auc, fpr, tpr = evaluate_and_save(best_knn, X_val, y_val)

2025-07-13 20:56:20,435 INFO Validation accuracy: 0.8087
2025-07-13 20:56:20,436 INFO Validation F1-weighted: 0.8039


In [14]:
# Compare k = 3, 5, 7 on the validation split (table + bar chart are saved).
df_bench = benchmark_k_values(X_train, y_train, X_val, y_val)
# Binary confusion heatmaps for the two lowest-recall classes.
plot_confusion_subset(best_knn, X_val, y_val, classes)
# ROC curves of the lowest- and highest-AUC classes, reusing the ROC data
# returned by evaluate_and_save.
plot_top_bottom_roc(roc_auc, fpr, tpr, classes)
# Log the nearest training neighbours of the first validation sample.
show_top5_neighbors(best_knn, X_train, y_train, X_val)

logging.info("All analyses complete.")

2025-07-13 20:56:55,092 INFO Sample index: 0
2025-07-13 20:56:55,093 INFO Neighbor: 62172, Label: Membrane, Distance: 11.433
2025-07-13 20:56:55,093 INFO Neighbor: 46286, Label: Membrane, Distance: 14.830
2025-07-13 20:56:55,094 INFO Neighbor: 44674, Label: Membrane, Distance: 16.192
2025-07-13 20:56:55,094 INFO Neighbor: 82422, Label: Membrane, Distance: 16.492
2025-07-13 20:56:55,094 INFO Neighbor: 20952, Label: Membrane, Distance: 16.528
2025-07-13 20:56:55,095 INFO All analyses complete.


In [15]:
import time

# Load the held-out test split and give it the same feature layout as training.
X_test = pd.read_csv("data/processed/X/test.csv")
# Labels are the first column of the label file (same convention as load_data).
y_test = pd.read_csv("data/processed/Y/test.csv").iloc[:, 0]
X_test = X_test.select_dtypes(include=[np.number])
X_test = X_test[X_train.columns]

# Put the test labels in the label space the model was trained on: any label
# the fitted model does not know (e.g. rare classes that were merged during
# training) is scored as "Other" instead of as a class that can never be
# predicted.
known_classes = set(best_knn.classes_)
if "Other" in known_classes:
    y_test = y_test.where(y_test.isin(known_classes), "Other")

# Measure training time (this refits the tuned pipeline in place).
# perf_counter is a monotonic, high-resolution clock meant for intervals.
start_train = time.perf_counter()
best_knn.fit(X_train, y_train)
train_time = time.perf_counter() - start_train
print(f"Training time: {train_time:.2f} seconds")

# Predict and measure throughput
start_pred = time.perf_counter()
y_test_pred = best_knn.predict(X_test)
pred_time = time.perf_counter() - start_pred
throughput = len(X_test) / pred_time if pred_time > 0 else float("inf")
print(f"Prediction throughput: {throughput:.0f} samples/second")

# Evaluate performance
accuracy = accuracy_score(y_test, y_test_pred)
f1w = f1_score(y_test, y_test_pred, average="weighted")
print(f"Test Accuracy: {accuracy:.3f}")
print(f"Weighted F1-Score: {f1w:.3f}")

# Per-class F1 for classes >=5% support
report = classification_report(y_test, y_test_pred, output_dict=True, zero_division=0)
report_df = pd.DataFrame(report).transpose()
threshold_count = len(y_test) * 0.05
# Drop the report's summary rows before filtering: "macro avg" and
# "weighted avg" carry the total sample count as support, so they would
# otherwise always pass the >=5% filter and could be reported as a "class".
per_class_df = report_df.drop(
    index=["accuracy", "micro avg", "macro avg", "weighted avg"], errors="ignore"
)
selected = per_class_df.loc[per_class_df['support'] >= threshold_count]
min_f1 = selected['f1-score'].min() if not selected.empty else None
print(f"Minimum per-class F1 for classes ≥5%: {min_f1:.3f}" if min_f1 is not None else "No classes ≥5% support")

# Save detailed test report
report_df.to_csv(f"{OUT_CSV}/knn_test_classification_report.csv")

Training time: 1.17 seconds
Prediction throughput: 2436 samples/second
Test Accuracy: 0.814
Weighted F1-Score: 0.809
Minimum per-class F1 for classes ≥5%: 0.440
