<a href="https://colab.research.google.com/github/cicerohen/hands-on-ml-book/blob/main/hands_on_chapter_3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### **🛠️ Setup & Data Acquisition**
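In this section, we prepare the laboratory. We import the libraries used throughout the chapter and fetch MNIST: 70,000 handwritten digits, each a 28×28 image flattened into 784 pixel features.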

In [35]:
# Imports
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.linear_model import SGDClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score, cross_val_predict
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, roc_curve, roc_auc_score

# Load Data
mnist = fetch_openml('mnist_784', as_frame=False)
X, y = mnist.data, mnist.target.astype(np.uint8)

# Train/Test Split: MNIST comes pre-split (first 60,000 images for training, last 10,000 for testing)
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

### **🎯 Binary Classification (The "5-Detector")**

We start simple: can the model distinguish a '5' from everything else?

In [36]:
# Create binary labels
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)

# Train SGD
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_5)

# Performance Evaluation
# Note: accuracy is misleading here. Only about 10% of the images are 5s,
# so always guessing "not 5" would already score about 90%.
cv_accuracy = cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring="accuracy")
print(f"CV Accuracy: {cv_accuracy}")

CV Accuracy: [0.95035 0.96035 0.9604 ]

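The note in the cell above says accuracy is misleading for this skewed problem. A quick back-of-the-envelope check makes that concrete: a hypothetical classifier that never predicts "5" is right on every non-5 image. The sketch assumes, as the note does, that roughly 10% of the 60,000 training images are 5s. The count `n_fives = 6_000` is a round-number assumption, not a measured value.

```python
# Baseline for a skewed binary problem: a "classifier" that always answers "not 5".
# Assumption: about 10% of the 60,000 training images are 5s (round number, not measured).

def always_negative_accuracy(n_total: int, n_positive: int) -> float:
    """Accuracy of a model that predicts the negative class for every sample."""
    n_negative = n_total - n_positive
    # Every negative is classified correctly; every positive is missed.
    return n_negative / n_total

n_train = 60_000
n_fives = 6_000  # hypothetical: 10% of the training set

baseline = always_negative_accuracy(n_train, n_fives)
print(f"Always-'not 5' accuracy: {baseline:.2%}")
```

Against that roughly 90% floor, the 95% to 96% cross-validated accuracy printed above is less impressive than it first looks.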

### **⚖️ Metrics & The Precision/Recall Tradeoff**

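Accuracy alone hides which kind of mistake the model makes. We therefore look at false positives (non-5s flagged as 5s) and false negatives (5s the model misses), and summarize them with precision and recall.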
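Precision is the fraction of predicted 5s that really are 5s, TP / (TP + FP). Recall is the fraction of actual 5s that the model finds, TP / (TP + FN). The minimal sketch below computes both from a 2×2 confusion matrix. It uses the layout scikit-learn's `confusion_matrix` returns for boolean labels: rows are actual, columns are predicted, and the negative class comes first. The counts are made up for illustration.

```python
# Precision and recall from a binary confusion matrix.
# Layout (as in scikit-learn for labels [False, True]):
#   [[TN, FP],
#    [FN, TP]]
# The counts below are hypothetical, chosen only to make the arithmetic easy.

cm = [[860, 20],
      [40, 80]]

(tn, fp), (fn, tp) = cm

precision = tp / (tp + fp)  # of everything flagged as "5", how much was right
recall = tp / (tp + fn)     # of all real 5s, how many were flagged

print(f"TN={tn} FP={fp} FN={fn} TP={tp}")
print(f"Precision: {precision:.3f}")  # 80 / 100
print(f"Recall:    {recall:.3f}")     # 80 / 120
```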

In [37]:
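# Confusion matrix from out-of-fold predictions (rows = actual, columns = predicted)
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)
cm = confusion_matrix(y_train_5, y_train_pred)

# Precision and recall
print(f"Precision: {precision_score(y_train_5, y_train_pred)}")
print(f"Recall: {recall_score(y_train_5, y_train_pred)}")

# ROC Curve Comparison (SGD vs Random Forest)
# SGD: out-of-fold decision scores serve as the ranking signal
y_scores_sgd = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3, method="decision_function")
fpr_sgd, tpr_sgd, thresholds_sgd = roc_curve(y_train_5, y_scores_sgd)

# Random Forest has no decision_function, so use the positive-class probability as its score
forest_clf = RandomForestClassifier(random_state=42)
y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3, method="predict_proba")
y_scores_forest = y_probas_forest[:, 1]
fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5, y_scores_forest)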

Precision: 0.8370879772350012
Recall: 0.6511713705958311

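F1 combines the two numbers printed above into one score: their harmonic mean, F1 = 2PR / (P + R), which stays low unless both are high. The minimal check below uses the printed values. It should agree with `f1_score(y_train_5, y_train_pred)` (already imported) up to floating-point rounding. The helper name `f1_from_pr` is hypothetical.

```python
# F1 score = harmonic mean of precision and recall.
# Inputs copied from the printed cross-validated results above.

def f1_from_pr(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (0 if both are 0)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

precision = 0.8370879772350012
recall = 0.6511713705958311

f1 = f1_from_pr(precision, recall)
print(f"F1: {f1:.4f}")  # → F1: 0.7325
```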

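The tradeoff in the heading comes from the decision threshold. `SGDClassifier.predict` labels an image a 5 when its decision score is above 0. Raising that cutoff flags fewer images, so recall can only fall, while precision usually rises. Sweeping the cutoff past every score also traces the ROC curve computed in the cell above (true positive rate vs. false positive rate).

The sketch below is dependency-free and runs on made-up scores and labels, not MNIST output. `confusion_at` is a hypothetical helper that treats a sample as positive when its score is >= the threshold, which is the convention `roc_curve` documents.

```python
# Sweeping the decision threshold on hypothetical scores (stand-ins for
# decision_function output) to see the precision/recall tradeoff and the ROC curve.

scores = [-4.0, -2.5, -1.0, -0.5, 0.5, 1.5, 2.0, 3.5]
is_five = [False, False, True, False, False, True, True, True]

def confusion_at(threshold, scores, labels):
    """(TP, FP, FN, TN) when 'score >= threshold' is predicted positive."""
    tp = fp = fn = tn = 0
    for s, y in zip(scores, labels):
        pred = s >= threshold
        if pred and y:
            tp += 1
        elif pred:
            fp += 1
        elif y:
            fn += 1
        else:
            tn += 1
    return tp, fp, fn, tn

# 1) Precision/recall tradeoff: a higher threshold flags fewer images.
for t in (-3.0, 0.0, 1.75):
    tp, fp, fn, tn = confusion_at(t, scores, is_five)
    precision = tp / (tp + fp) if tp + fp else 1.0  # convention used here if nothing is flagged
    recall = tp / (tp + fn)
    print(f"threshold={t:5.2f}  precision={precision:.2f}  recall={recall:.2f}")

# 2) ROC curve: start above every score, then lower the threshold to each distinct score.
thresholds = [float("inf")] + sorted(set(scores), reverse=True)
roc = []  # (FPR, TPR) pairs, FPR non-decreasing
for t in thresholds:
    tp, fp, fn, tn = confusion_at(t, scores, is_five)
    roc.append((fp / (fp + tn), tp / (tp + fn)))

# Area under the curve by the trapezoid rule.
auc = sum((x1 - x0) * (y0 + y1) / 2 for (x0, y0), (x1, y1) in zip(roc, roc[1:]))
print(f"ROC points: {roc}")
print(f"AUC (trapezoid rule): {auc:.3f}")
```

The area comes out to 0.875. This equals the fraction of (5, non-5) pairs in which the 5 has the higher score (14 of 16 here), which is a handy way to read `roc_auc_score`. Only the ordering of the scores matters, which is why the Random Forest's class probabilities can stand in for decision scores.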
### **🔢 Multiclass Classification (0-9)**

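Now we challenge the model to identify all ten digits. `SGDClassifier` handles multiclass data automatically with the One-vs-Rest strategy (OvR, also called One-vs-All or OvA). It trains one binary classifier per digit and predicts the digit whose classifier gives the highest score.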
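Under One-vs-Rest, a problem with N classes becomes N binary problems, one per class ("is it a k?"). One-vs-One (OvO) instead trains a classifier for every pair of classes. The small sketch below shows the OvR relabeling and compares the classifier counts for the ten digits. The helper `ovr_targets` and the labels in `y_small` are hypothetical.

```python
from itertools import combinations

def ovr_targets(y, classes):
    """One boolean target vector per class: True where the label equals that class."""
    return {k: [label == k for label in y] for k in classes}

digits = list(range(10))
y_small = [5, 0, 4, 1, 9, 2]  # hypothetical labels

targets = ovr_targets(y_small, digits)
print(targets[5])  # the "5-detector" targets, like y_train_5 = (y_train == 5)

n_ovr = len(digits)                         # one binary classifier per class
n_ovo = len(list(combinations(digits, 2)))  # one per pair: N * (N - 1) / 2
print(f"OvR classifiers: {n_ovr}, OvO classifiers: {n_ovo}")  # → OvR classifiers: 10, OvO classifiers: 45
```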

In [38]:
# Training the full model
sgd_clf.fit(X_train, y_train)

# Test on a single image: the first sample, whose true label is 5
some_digit = X[0]
prediction = sgd_clf.predict([some_digit])
scores = sgd_clf.decision_function([some_digit])

print(f"Prediction: {prediction}")
print(f"Decision Scores: {scores}")

Prediction: [3]
Decision Scores: [[-31893.03095419 -34419.69069632  -9530.63950739   1823.73154031
  -22320.14822878  -1385.80478895 -26188.91070951 -16147.51323997
   -4604.35491274 -12050.767298  ]]

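With OvR, `predict` returns the class whose decision score is highest. Ranking the ten scores printed above explains why the model answered 3 for an image labeled 5. Class 3 has the only positive score, and class 5, the true label, comes second. The minimal check below uses the scores copied from the output.

```python
# Decision scores copied from the output above, one per class 0-9.
scores = [-31893.03095419, -34419.69069632, -9530.63950739, 1823.73154031,
          -22320.14822878, -1385.80478895, -26188.91070951, -16147.51323997,
          -4604.35491274, -12050.767298]

# OvR prediction: the class with the largest score (what np.argmax would return).
predicted = max(range(len(scores)), key=scores.__getitem__)

# Full ranking of classes from highest to lowest score.
ranking = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True)

print(f"Predicted: {predicted}")         # → Predicted: 3
print(f"Top-3 classes: {ranking[:3]}")  # → Top-3 classes: [3, 5, 8]
```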

### **⚙️ Feature Scaling & Performance Boost**

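SGD is sensitive to the scale of its input features. We therefore standardize every pixel feature to zero mean and unit variance, creating a "level playing field" across features. We then check whether cross-validated accuracy improves.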
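`StandardScaler` standardizes each feature (here, each of the 784 pixel columns) independently. It subtracts the column mean and divides by the column standard deviation: z = (x - mean) / std. The sketch below performs that computation without dependencies on a tiny made-up matrix. It uses the population standard deviation. Like scikit-learn, it divides constant columns (std = 0, for example a pixel that is blank in every image) by 1 instead of 0, so those columns end up all zeros. The helper `standardize_columns` is hypothetical.

```python
from statistics import fmean, pstdev

def standardize_columns(rows):
    """Return (scaled_rows, means, stds) with every column shifted to mean 0 and scaled to std 1.

    Population std (ddof=0) is used; a constant column (std == 0) is divided
    by 1.0 instead, so after centering it becomes all zeros.
    """
    columns = list(zip(*rows))
    means = [fmean(col) for col in columns]
    stds = [pstdev(col) for col in columns]
    scales = [s if s > 0 else 1.0 for s in stds]
    scaled = [[(x - m) / s for x, m, s in zip(row, means, scales)] for row in rows]
    return scaled, means, stds

# Hypothetical 4 "images" x 3 "pixels"; the last pixel is blank everywhere.
X_toy = [[0.0, 255.0, 0.0],
         [0.0, 128.0, 0.0],
         [64.0, 0.0, 0.0],
         [192.0, 64.0, 0.0]]

X_toy_scaled, means, stds = standardize_columns(X_toy)
for row in X_toy_scaled:
    print([round(v, 3) for v in row])
```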

In [None]:
# Applying StandardScaler
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train.astype(np.float64))

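# Verifying Improvement
# Note: the scaler above was fit on all of X_train before CV, so the validation folds
# influenced the scaling statistics; make_pipeline(StandardScaler(), SGDClassifier())
# would refit the scaler inside each fold instead.
scaled_cv_accuracy = cross_val_score(sgd_clf, X_train_scaled, y_train, cv=3, scoring="accuracy")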
print(f"Scaled CV Accuracy: {scaled_cv_accuracy}")

### **📉 Error Analysis**

Investigating the "why" behind the mistakes: which digits does the model confuse with which?

In [None]:
# Generate Confusion Matrix for all digits
y_train_pred_scaled = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)
conf_mx = confusion_matrix(y_train, y_train_pred_scaled)

# Plotting the Confusion Matrix Heatmap
plt.matshow(conf_mx, cmap=plt.cm.gray)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.show()

# Focus on errors: divide each row by the number of images in that (actual) class,
# so that error rates, not raw counts, are compared across classes
row_sums = conf_mx.sum(axis=1, keepdims=True)
norm_conf_mx = conf_mx / row_sums
np.fill_diagonal(norm_conf_mx, 0)  # zero the diagonal so only the errors remain visible
plt.matshow(norm_conf_mx, cmap=plt.cm.gray)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.show()
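The second plot divides each row of the confusion matrix by that row's total, turning counts into per-class error rates, and then zeroes the diagonal. The sketch below applies the same normalization without dependencies to a made-up 3-class matrix. It then looks up the largest off-diagonal rate, that is, the most frequent "actual X predicted as Y" confusion. The helper `normalized_errors` is hypothetical.

```python
# Hypothetical 3-class confusion matrix (rows = actual, columns = predicted).
conf = [[90, 4, 6],
        [4, 44, 2],
        [10, 5, 185]]

def normalized_errors(conf):
    """Divide each row by its total and zero the diagonal, leaving error rates."""
    out = []
    for i, row in enumerate(conf):
        total = sum(row)
        out.append([0.0 if i == j else count / total for j, count in enumerate(row)])
    return out

norm = normalized_errors(conf)

# Largest off-diagonal cell = the most common confusion relative to class size.
worst_rate, actual, predicted = max(
    (rate, i, j) for i, row in enumerate(norm) for j, rate in enumerate(row)
)
print(f"Most confused: actual {actual} predicted as {predicted} ({worst_rate:.1%})")
```

The largest raw off-diagonal count here (10, actual 2 predicted as 0) is not the largest rate, because class 2 has four times as many samples as class 1. That is why the notebook normalizes by row before plotting.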