In [None]:
import sys
from io import StringIO

import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import accuracy_score, classification_report, log_loss
from sklearn.utils import shuffle


In [11]:
def load_mnist(class_0="3", class_1="8", random_state=42, split=0.8):
    """Load MNIST, select two classes, shuffle and return only n_samples."""
    # Load data from http://openml.org/d/554
    mnist = fetch_openml("mnist_784", version=1, as_frame=False)

    # take only two classes for binary classification
    mask = np.logical_or(mnist.target == class_0, mnist.target == class_1)

    X, y = shuffle(mnist.data[mask], mnist.target[mask], random_state=random_state)

    y_bin = (y == class_1).astype(int)  # 3 → 0 , 8 → 1
    X = X.astype(np.float32) / 255.0  # [0,255] → [0,1]

    n_obs = X.shape[0]
    n_train = int(n_obs * split)
    print(n_train)

    # 5) Re-create the original MNIST train/test split -------------------
    X_train, X_test = X[:n_train], X[n_train:]
    y_train, y_test = y_bin[:n_train], y_bin[n_train:]

    return X_train, y_train, X_test, y_test

In [12]:
X_train, y_train, X_test, y_test = load_mnist()

11172


In [None]:
# # -------------------------------------------------------------------
# # 1) Load & prepare the data  (same steps as before)
# # -------------------------------------------------------------------
# mnist = fetch_openml("mnist_784", version=1, as_frame=False)
# X, y = mnist["data"], mnist["target"].astype(str)

# mask = np.isin(y, ["3", "8"])
# X, y = X[mask], y[mask]

# y_bin = (y == "8").astype(int)  # 3 → 0 , 8 → 1
# X = X.astype(np.float32) / 255.0  # [0,255] → [0,1]

# n_obs = 13966
# n_train = 11000
# n_test = n_obs - n_train

# # 5) Re-create the original MNIST train/test split -------------------
# X_train, X_test = X[:n_train], X[n_train:]
# y_train, y_test = y_bin[:n_train], y_bin[n_train:]


In [13]:
# -------------------------------------------------------------------
# 2) Build the SGDClassifier  (same hyper-parameters as requested)
# -------------------------------------------------------------------
clf = SGDClassifier(
    loss="log_loss",  # logistic regression
    penalty="l2",
    alpha=1.0,  # μ = 1
    learning_rate="constant",
    eta0=0.001,  # γ = 0.001
    shuffle=False,  # we will shuffle manually each epoch
    tol=1e-6,  # stop when optimisation converges
    random_state=3317,
)


In [17]:
old_stdout = sys.stdout
sys.stdout = mystdout = StringIO()

In [22]:
# -------------------------------------------------------------------
# 3) Mini-batch training with partial_fit
# -------------------------------------------------------------------
batch_size = 10
n_epochs = 14  # feel free to increase if needed

for epoch in range(n_epochs):
    X_train, y_train = shuffle(X_train, y_train, random_state=epoch)

    for start in range(0, len(X_train), batch_size):
        end = start + batch_size
        X_batch = X_train[start:end]
        y_batch = y_train[start:end]

        # classes must be supplied ONLY the first time
        if epoch == 0 and start == 0:
            clf.partial_fit(X_batch, y_batch, classes=np.array([0, 1]))
        else:
            clf.partial_fit(X_batch, y_batch)

        probas = clf.predict_proba(X_batch)
        data_loss = log_loss(y_batch, probas, labels=[0, 1])

        # regularisation term  (½ * α * ||w||²)
        reg_loss = 0.5 * clf.alpha * np.sum(clf.coef_**2)
        print(f"loss_history: {(data_loss + reg_loss):.4f}")

    print(f"Epoch {epoch + 1:2d}/{n_epochs} done")


loss_history: 0.4574
loss_history: 0.6443
loss_history: 0.5615
loss_history: 0.5387
loss_history: 0.4898
loss_history: 0.5473
loss_history: 0.4311
loss_history: 0.4977
loss_history: 0.5581
loss_history: 0.4435
loss_history: 0.4835
loss_history: 0.5249
loss_history: 0.5260
loss_history: 0.4709
loss_history: 0.3793
loss_history: 0.4156
loss_history: 0.4563
loss_history: 0.4606
loss_history: 0.4718
loss_history: 0.4613
loss_history: 0.5336
loss_history: 0.4605
loss_history: 0.5478
loss_history: 0.4975
loss_history: 0.5438
loss_history: 0.4303
loss_history: 0.4651
loss_history: 0.5028
loss_history: 0.4945
loss_history: 0.4575
loss_history: 0.5180
loss_history: 0.4625
loss_history: 0.5334
loss_history: 0.3992
loss_history: 0.5887
loss_history: 0.4695
loss_history: 0.3906
loss_history: 0.3955
loss_history: 0.4688
loss_history: 0.5509
loss_history: 0.5346
loss_history: 0.5397
loss_history: 0.4207
loss_history: 0.5284
loss_history: 0.4968
loss_history: 0.4710
loss_history: 0.4942
loss_history:

In [19]:
sys.stdout = old_stdout
loss_history = mystdout.getvalue()

loss_list = []
for line in loss_history.split("\n"):
    if len(line.split("loss: ")) == 1:
        continue
    loss_list.append(float(line.split("loss: ")[-1]))

len(loss_list)

0

In [15]:
# -------------------------------------------------------------------
# 4) Evaluate
# -------------------------------------------------------------------
y_pred = clf.predict(X_test)
print("\nTest accuracy:", accuracy_score(y_test, y_pred))
print(classification_report(y_test, y_pred, target_names=["digit 3", "digit 8"]))


Test accuracy: 0.9108804581245526
              precision    recall  f1-score   support

     digit 3       0.88      0.96      0.92      1422
     digit 8       0.95      0.86      0.90      1372

    accuracy                           0.91      2794
   macro avg       0.91      0.91      0.91      2794
weighted avg       0.91      0.91      0.91      2794

