# Step 2: K-Nearest Neighbors Classifier

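This notebook builds a k-nearest neighbors (KNN) classifier for MNIST handwritten-digit recognition. It covers four stages:
- loading and preprocessing the data;
- selecting K by cross-validation;
- training and evaluating the final model on the held-out test set;
- analyzing the errors it makes.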

In [None]:
import struct
import time
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

%matplotlib inline

## 1. Load Dataset

In [None]:
def read_idx_images(filename):
    """Read an MNIST IDX3 image file into a ``(num_images, rows, cols)`` uint8 array.

    The file starts with a 16-byte header of four big-endian unsigned 32-bit
    integers (magic number, image count, row count, column count), followed by
    one unsigned byte per pixel.

    Parameters
    ----------
    filename : str or path-like
        Path to an ``*-images.idx3-ubyte`` file.

    Returns
    -------
    numpy.ndarray
        uint8 array of shape ``(num_images, rows, cols)``.

    Raises
    ------
    ValueError
        If the magic number is not 2051 (for example, a label file was passed
        by mistake), or the file holds fewer pixel bytes than the header declares.
    """
    expected_magic = 2051  # 0x00000803: unsigned-byte data with 3 dimensions
    with open(filename, 'rb') as f:
        magic, num_images, rows, cols = struct.unpack('>IIII', f.read(16))
        if magic != expected_magic:
            raise ValueError(
                f"{filename}: bad magic number {magic}, expected {expected_magic} "
                "(is this an IDX image file?)"
            )
        expected_size = num_images * rows * cols
        # Read exactly the declared number of pixels; trailing bytes are ignored.
        images = np.fromfile(f, dtype=np.uint8, count=expected_size)
    if images.size != expected_size:
        raise ValueError(
            f"{filename}: expected {expected_size} pixel bytes, "
            f"found {images.size} (truncated file?)"
        )
    return images.reshape(num_images, rows, cols)

def read_idx_labels(filename):
    """Read an MNIST IDX1 label file into a 1-D uint8 array.

    The file starts with an 8-byte header of two big-endian unsigned 32-bit
    integers (magic number, label count), followed by one unsigned byte per label.

    Parameters
    ----------
    filename : str or path-like
        Path to an ``*-labels.idx1-ubyte`` file.

    Returns
    -------
    numpy.ndarray
        uint8 array of shape ``(num_labels,)``.

    Raises
    ------
    ValueError
        If the magic number is not 2049 (for example, an image file was passed
        by mistake), or the file holds fewer labels than the header declares.
    """
    expected_magic = 2049  # 0x00000801: unsigned-byte data with 1 dimension
    with open(filename, 'rb') as f:
        magic, num_labels = struct.unpack('>II', f.read(8))
        if magic != expected_magic:
            raise ValueError(
                f"{filename}: bad magic number {magic}, expected {expected_magic} "
                "(is this an IDX label file?)"
            )
        # Read exactly the declared number of labels; trailing bytes are ignored.
        labels = np.fromfile(f, dtype=np.uint8, count=num_labels)
    if labels.size != num_labels:
        raise ValueError(
            f"{filename}: expected {num_labels} labels, "
            f"found {labels.size} (truncated file?)"
        )
    return labels

In [None]:
# Directory holding the raw MNIST IDX files; change this if the data lives elsewhere.
DATA_DIR = '../data'

train_images = read_idx_images(f'{DATA_DIR}/train-images.idx3-ubyte')
train_labels = read_idx_labels(f'{DATA_DIR}/train-labels.idx1-ubyte')
test_images = read_idx_images(f'{DATA_DIR}/t10k-images.idx3-ubyte')
test_labels = read_idx_labels(f'{DATA_DIR}/t10k-labels.idx1-ubyte')

# Sanity checks: every image needs exactly one label, and KNN needs train and
# test images with the same dimensions (same number of features after flattening).
assert len(train_images) == len(train_labels), "train images/labels count mismatch"
assert len(test_images) == len(test_labels), "test images/labels count mismatch"
assert train_images.shape[1:] == test_images.shape[1:], "train/test image sizes differ"

print(f"Training set: {train_images.shape}")
print(f"Test set: {test_images.shape}")

## 2. Data Preprocessing

### 2.1 Flatten Images
Convert 28×28 images to 784-dimensional vectors

In [None]:
# Collapse each 2-D image into one feature row: (n, rows, cols) -> (n, rows*cols).
X_train, X_test = (
    images.reshape(len(images), -1) for images in (train_images, test_images)
)

print(f"Flattened training data: {X_train.shape}")
print(f"Flattened test data: {X_test.shape}")

### 2.2 Normalize Pixel Values
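Scale pixel values from [0, 255] to [0, 1] and convert them to float32. Every pixel is divided by the same constant, so this does not change which neighbors are nearest under Euclidean distance (KNN's default metric). It simply puts the features in a conventional floating-point range.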

In [None]:
# Build the float features from the raw uint8 images instead of rescaling
# X_train/X_test in place. `X = X / 255` is not idempotent: re-running this
# cell would shrink the features by another factor of 255 each time.
X_train = train_images.reshape(train_images.shape[0], -1).astype('float32') / 255.0
X_test = test_images.reshape(test_images.shape[0], -1).astype('float32') / 255.0
y_train = train_labels
y_test = test_labels

print(f"Normalized range: [{X_train.min():.2f}, {X_train.max():.2f}]")

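### 2.3 Use a Subset for Faster Hyperparameter Tuning (Optional)
KNN's prediction cost grows with the size of the training set. Running 5-fold cross-validation for every candidate K on the full training set is therefore slow. We tune K on the first 10,000 training images only; the final model in Section 4 is still trained on the full training set.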

In [None]:
# Hyperparameter search only looks at the leading `subset_size` training rows.
subset_size = 10_000
X_train_subset, y_train_subset = X_train[:subset_size], y_train[:subset_size]

print(f"Using {subset_size} samples for hyperparameter tuning")

## 3. Hyperparameter Tuning

Select K using 5-fold cross-validation on the training subset. The test set is held out and is not used for tuning.

In [None]:
# Candidate neighbor counts, each scored by 5-fold CV accuracy on the subset.
k_values = [1, 3, 5, 7, 9, 11, 15]
cv_scores = []

print("Testing different K values...")
for k in k_values:
    fold_scores = cross_val_score(
        KNeighborsClassifier(n_neighbors=k, n_jobs=-1),
        X_train_subset,
        y_train_subset,
        cv=5,
        scoring='accuracy',
    )
    mean_score, std_score = fold_scores.mean(), fold_scores.std()
    cv_scores.append(mean_score)
    print(f"K={k:2d}: {mean_score:.4f} (+/- {std_score:.4f})")

# np.argmax returns the first maximum, so ties resolve to the smaller K.
best_idx = int(np.argmax(cv_scores))
best_k = k_values[best_idx]
print(f"\nBest K value: {best_k} with accuracy: {cv_scores[best_idx]:.4f}")

### Visualize K vs Accuracy

In [None]:
# Mean CV accuracy for each K, with the selected K marked by a dashed line.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(k_values, cv_scores, 'bo-', linewidth=2, markersize=8)
ax.axvline(x=best_k, color='r', linestyle='--', label=f'Best K={best_k}')
ax.set_xlabel('K (Number of Neighbors)', fontsize=12)
ax.set_ylabel('Cross-Validation Accuracy', fontsize=12)
ax.set_title('KNN Hyperparameter Tuning', fontsize=14)
ax.set_xticks(k_values)
ax.grid(alpha=0.3)
ax.legend()
fig.tight_layout()
plt.show()

## 4. Train Final Model

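Train KNN with the selected K on the full training set. For KNN, "training" essentially stores the training data (possibly building a neighbor-search index), so this step is fast. Most of the cost is paid at prediction time.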

In [None]:
print(f"Training KNN with K={best_k} on full training set...")
# perf_counter is a monotonic, high-resolution clock intended for measuring
# intervals; time.time() follows the wall clock, which can jump if adjusted.
start_time = time.perf_counter()

knn_final = KNeighborsClassifier(n_neighbors=best_k, n_jobs=-1)
knn_final.fit(X_train, y_train)

train_time = time.perf_counter() - start_time
print(f"Training completed in {train_time:.2f} seconds")

## 5. Model Evaluation

### 5.1 Make Predictions

In [None]:
print("Making predictions on test set...")
# Monotonic interval timer (see the training cell for why not time.time()).
start_time = time.perf_counter()

y_pred = knn_final.predict(X_test)

pred_time = time.perf_counter() - start_time
print(f"Predictions completed in {pred_time:.2f} seconds")
print(f"Average prediction time: {(pred_time / len(X_test)) * 1000:.2f} ms per sample")

### 5.2 Overall Accuracy

In [None]:
# Headline metrics on the held-out test set.
accuracy = accuracy_score(y_test, y_pred)
error_rate = 1 - accuracy
n_misclassified = np.count_nonzero(y_pred != y_test)

print(f"Test Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
print(f"Error Rate: {error_rate:.4f} ({error_rate*100:.2f}%)")
print(f"Misclassified samples: {n_misclassified} out of {len(y_test)}")

### 5.3 Confusion Matrix

In [None]:
# Pin the label set to digits 0-9 so `cm` is always 10x10, with row/column i
# corresponding to digit i even if some digit never occurs in y_test or y_pred.
# The per-class cells below rely on that digit indexing.
digit_labels = np.arange(10)
cm = confusion_matrix(y_test, y_pred, labels=digit_labels)

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', square=True,
            xticklabels=digit_labels, yticklabels=digit_labels)
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.title('Confusion Matrix', fontsize=14)
plt.tight_layout()
plt.show()

### 5.4 Per-Class Performance

In [None]:
# Per-digit precision / recall / F1, shown to four decimal places.
print("Classification Report:", "=" * 60, sep="\n")
print(classification_report(y_test, y_pred, digits=4))

### 5.5 Per-Class Accuracy Visualization

In [None]:
# Per-digit accuracy (recall): correct predictions on the diagonal divided by
# the number of true instances in each row. A digit with no test samples gets
# 0.0 instead of a 0/0 NaN.
class_totals = cm.sum(axis=1)
per_class_accuracy = np.divide(
    cm.diagonal(), class_totals,
    out=np.zeros(len(class_totals), dtype=float),
    where=class_totals > 0,
)
digits = np.arange(len(per_class_accuracy))

# Bars below this accuracy are highlighted in orange.
highlight_threshold = 0.95
bar_colors = ['orange' if acc < highlight_threshold else 'C0' for acc in per_class_accuracy]

plt.figure(figsize=(10, 6))
plt.bar(digits, per_class_accuracy, alpha=0.8, color=bar_colors)
plt.xlabel('Digit Class', fontsize=12)
plt.ylabel('Accuracy', fontsize=12)
plt.title('Per-Class Accuracy', fontsize=14)
plt.xticks(digits)
# Zoom in on the high-accuracy band, but lower the floor when needed so no bar
# is clipped (a fixed 0.9 floor would hide any digit scoring below 90%).
plt.ylim([max(0.0, min(0.9, per_class_accuracy.min() - 0.01)), 1.0])
plt.grid(axis='y', alpha=0.3)
plt.axhline(y=accuracy, color='r', linestyle='--', label=f'Overall: {accuracy:.4f}')
plt.legend()
plt.tight_layout()
plt.show()

print("\nPer-class accuracy:")
for digit, acc in zip(digits, per_class_accuracy):
    print(f"Digit {digit}: {acc:.4f}")

## 6. Error Analysis

### 6.1 Find Misclassified Examples

In [None]:
# Positions in the test set where the prediction disagrees with the label.
misclassified_idx = np.flatnonzero(y_pred != y_test)
print(f"Total misclassified: {len(misclassified_idx)}")

### 6.2 Visualize Correctly Classified Samples

In [None]:
# Seeded generator so the same random samples are shown on every run.
rng = np.random.default_rng(42)
correct_idx = np.where(y_pred == y_test)[0]
# Never ask for more samples than exist (choice without replacement would raise).
n_show = min(20, len(correct_idx))
sample_correct = rng.choice(correct_idx, n_show, replace=False)

fig, axes = plt.subplots(2, 10, figsize=(15, 3))
fig.suptitle('Correctly Classified Samples', fontsize=14, color='green')

# Hide every panel first, so panels left without an image stay blank
# instead of showing empty axes.
for ax in axes.flat:
    ax.axis('off')

for ax, idx in zip(axes.flat, sample_correct):
    ax.imshow(test_images[idx], cmap='gray')
    ax.set_title(f'True: {y_test[idx]}\nPred: {y_pred[idx]}', fontsize=8)

plt.tight_layout()
plt.show()

### 6.3 Visualize Misclassified Samples

In [None]:
# The first (up to) 20 errors in test-set order.
sample_errors = misclassified_idx[:20]

fig, axes = plt.subplots(2, 10, figsize=(15, 3))
fig.suptitle('Misclassified Samples', fontsize=14, color='red')

# Blank all 20 panels up front. If there are fewer than 20 errors, the
# leftover panels then stay empty instead of showing ticked, empty axes.
for ax in axes.flat:
    ax.axis('off')

for ax, idx in zip(axes.flat, sample_errors):
    ax.imshow(test_images[idx], cmap='gray')
    ax.set_title(f'True: {y_test[idx]}\nPred: {y_pred[idx]}', fontsize=8, color='red')

plt.tight_layout()
plt.show()

### 6.4 Most Common Misclassifications

In [None]:
# (true, predicted) label pairs for every misclassified test sample.
# .tolist() converts the uint8 labels to plain Python ints before counting.
error_pairs = list(zip(y_test[misclassified_idx].tolist(),
                       y_pred[misclassified_idx].tolist()))

most_common_errors = Counter(error_pairs).most_common(10)

print("Most common misclassifications (True → Predicted):")
print("=" * 50)
if not most_common_errors:
    print("No misclassifications.")
for (true_label, pred_label), count in most_common_errors:
    print(f"{true_label} → {pred_label}: {count} times")

## 7. Summary

### Model Performance Summary

In [None]:
# Compact recap of the tuned model's results from the cells above.
rule = "=" * 60
ms_per_sample = (pred_time / len(X_test)) * 1000

print(rule)
print("KNN CLASSIFIER - PERFORMANCE SUMMARY")
print(rule)
print(f"Optimal K value: {best_k}")
print(f"Test Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
print(f"Misclassified: {np.sum(y_pred != y_test)}/{len(y_test)}")
print(f"Training time: {train_time:.2f}s")
print(f"Prediction time: {pred_time:.2f}s ({ms_per_sample:.2f}ms per sample)")
print(rule)

# argsort is ascending: the tail holds the highest accuracies, the head the lowest.
ranked = np.argsort(per_class_accuracy)
best_digits = ranked[::-1][:3]
worst_digits = ranked[:3]

print("\nBest performing digits:")
for digit in best_digits:
    print(f"  Digit {digit}: {per_class_accuracy[digit]:.4f}")

print("\nWorst performing digits:")
for digit in worst_digits:
    print(f"  Digit {digit}: {per_class_accuracy[digit]:.4f}")

## Key Takeaways

1. **KNN Performance**: Reaches roughly 97% test accuracy on MNIST from raw pixels, with no feature engineering (see the summary above for this run's exact figure).
2. **Optimal K**: Cross-validation helped identify the best K value
3. **Trade-offs**: 
   - Simple to implement and understand
   - Negligible training time (fitting essentially just stores the data)
   - Slow prediction (each query must be compared against the stored training samples)
4. **Common Errors**: Visually similar digit pairs (e.g., 4 vs 9, 3 vs 5) tend to be confused; see Section 6.4 for the pairs observed in this run.
5. **Room for Improvement**: Neural networks can achieve 98-99%+ accuracy