In [183]:
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix

# Load the MNIST dataset from OpenML.
mnist = fetch_openml('mnist_784', version=1)
X, y = mnist["data"], mnist["target"].astype(int)

# Check the distribution of digits in the full dataset (all rows, before the train/test split).
# Keys and counts are cast to built-in int so the dict prints as {0: 6903, ...} on every
# NumPy version; NumPy >= 2 would otherwise render {np.int64(0): np.int64(6903), ...}.
unique, counts = np.unique(y, return_counts=True)
print({int(digit): int(count) for digit, count in zip(unique, counts)})

# Split the data into training and testing sets: the first N_TRAIN rows are used
# for training and every remaining row for testing.
N_TRAIN = 60_000
X_train, y_train = X[:N_TRAIN], y[:N_TRAIN]
X_test, y_test = X[N_TRAIN:], y[N_TRAIN:]

# Create a binary target variable to identify the digit 5 (True = "is a 5").
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)

# Check the distribution of the binary variable. The positive class is a small
# minority of the training set, so accuracy alone would be a misleading metric.
n_train_5 = int(np.sum(y_train_5))
print(f"Number of True in y_train_5: {n_train_5}")
print(f"Number of False in y_train_5: {len(y_train_5) - n_train_5}")

# A binary classifier needs both classes present; fail early with a clear message
# instead of an opaque error from the estimator later on.
if n_train_5 == 0:
    raise ValueError("The training set does not contain any instances of the digit 5.")
if n_train_5 == len(y_train_5):
    raise ValueError("The training set contains only instances of the digit 5.")

# Create the 5-detector: standardize the pixel features, then fit an SVC with its
# default hyperparameters. `decision_function_shape` is intentionally not set:
# scikit-learn ignores it for binary targets such as y_train_5, so passing 'ovo'
# only suggested a multiclass setup that is not happening here.
svm_clf = Pipeline([
    ("scaler", StandardScaler()),
    ("svc", SVC()),
])

# Out-of-fold predictions: with cv=3 each training sample is predicted by a model
# fitted on the other two folds, so the confusion matrix below reflects
# generalization rather than fit on already-seen data.
y_train_predict = cross_val_predict(svm_clf, X_train, y_train_5, cv=3)

# Compute the confusion matrix: rows are the actual class (False, True),
# columns the predicted class.
conf_matrix = confusion_matrix(y_train_5, y_train_predict)
print(conf_matrix)


{0: 6903, 1: 7877, 2: 6990, 3: 7141, 4: 6824, 5: 6313, 6: 6876, 7: 7293, 8: 6825, 9: 6958}
Number of True in y_train_5: 5421
Number of False in y_train_5: 54579
[[54481    98]
 [  465  4956]]


In [187]:
from sklearn.metrics import f1_score, precision_score, recall_score

# Precision, recall and F1 of the 5-detector, scored on the cross-validated
# predictions with True ("is a 5") as the positive class.
precision, recall, f1score = (
    metric(y_train_5, y_train_predict)
    for metric in (precision_score, recall_score, f1_score)
)

for label, value in zip(("Precision:", "Recall:", "f1_score:"), (precision, recall, f1score)):
    print(label, value)

Precision: 0.9806094182825484
Recall: 0.9142224681793027
f1_score: 0.9462529832935561
