In [45]:
import os

import joblib
import numpy as np
from sklearn.base import clone
from sklearn.datasets import fetch_openml
from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, classification_report
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

In [46]:
# Data: load MNIST, carve out train / validation / test splits, standardise features.
mnist = fetch_openml("mnist_784", version=1, cache=True, as_frame=False)
# Only the first 10,000 samples are used; labels are cast to uint8 integers.
X = mnist["data"][:10000]
y = mnist["target"][:10000].astype(np.uint8)

print("Data shape:", X.shape, "Labels shape:", y.shape)

# Two stratified splits: 2,000 samples are held out for testing first, then
# another 2,000 of the remainder for validation (leaving 6,000 for training).
X_train_val, X_test, y_train_val, y_test = train_test_split(
    X, y, test_size=2000, random_state=42, stratify=y
)
X_train, X_val, y_train, y_val = train_test_split(
    X_train_val, y_train_val, test_size=2000, random_state=42, stratify=y_train_val
)

# Fit the scaler exactly once, on the training split only, so validation and
# test statistics never influence the preprocessing.
scaler = StandardScaler()
X_train_s = scaler.fit_transform(X_train)
X_val_s = scaler.transform(X_val)
X_test_s = scaler.transform(X_test)
# The full data set is scaled with the same train-fitted statistics. Previously
# X_s came from a separate fit on all of X that was immediately overwritten by
# the fit on X_train, so X_s and `scaler` disagreed; now they are consistent
# wherever both are used downstream.
X_s = scaler.transform(X)

Data shape: (10000, 784) Labels shape: (10000,)


In [47]:
# Models: base estimators, keyed by the short name used to look up their grid.
# Both ensembles share every setting except the number of trees.
_shared_settings = dict(max_depth=None, n_jobs=-1, random_state=42)
models = {
    "extratrees": ExtraTreesClassifier(n_estimators=200, **_shared_settings),
    "randomforest": RandomForestClassifier(n_estimators=300, **_shared_settings),
}
# GridSearch: hyper-parameter candidates per model name. The depth and
# feature-subset options are the same for both ensembles; only the tree
# counts differ, so the grids are generated from the per-model counts.
_tree_counts = {
    "extratrees": [100, 200, 400],
    "randomforest": [100, 300, 500],
}
param_grid = {
    model_key: {
        "n_estimators": counts,
        "max_depth": [None, 20, 40],
        "max_features": ["sqrt", "log2", None],
    }
    for model_key, counts in _tree_counts.items()
}

# Train - Val: tune each candidate with a 3-fold grid search on the training
# split, then score the refit best estimator on the validation split.
best_models, val_scores = {}, {}
search_settings = dict(scoring="f1_macro", cv=3, n_jobs=-1, verbose=1)

for name, model in models.items():
    print(f"Grid search på {name}…")
    grid = GridSearchCV(model, param_grid[name], **search_settings)
    grid.fit(X_train_s, y_train)

    tuned = grid.best_estimator_
    best_models[name] = tuned

    # Macro-averaged F1 on the validation split is the score stored per model.
    val_preds = tuned.predict(X_val_s)
    val_f1 = f1_score(y_val, val_preds, average="macro")
    val_scores[name] = val_f1

    print(f"Best {name}: {grid.best_params_} (F1= {val_f1:.4f})")

Grid search på extratrees…
Fitting 3 folds for each of 27 candidates, totalling 81 fits
Best extratrees: {'max_depth': None, 'max_features': 'sqrt', 'n_estimators': 400} (F1= 0.9488)
Grid search på randomforest…
Fitting 3 folds for each of 27 candidates, totalling 81 fits
Best randomforest: {'max_depth': None, 'max_features': 'sqrt', 'n_estimators': 500} (F1= 0.9470)


In [48]:
# Val: pick the model name with the highest validation score (the first one
# wins on a tie) and fetch the corresponding tuned estimator.
best_name = max(val_scores.items(), key=lambda item: item[1])[0]
best_model = best_models[best_name]

print(f"\nBest Model: {best_name}")


Best Model: extratrees


In [49]:
# Test: evaluate the selected model on the held-out test split and print the
# per-class precision / recall / F1 table.
test_preds = best_model.predict(X_test_s)

print("\nClassification report:")
report_text = classification_report(y_test, test_preds)
print(report_text)


Classification report:
              precision    recall  f1-score   support

           0       0.96      0.99      0.98       200
           1       0.97      0.98      0.98       225
           2       0.93      0.92      0.93       198
           3       0.95      0.90      0.92       206
           4       0.98      0.93      0.95       196
           5       0.95      0.95      0.95       173
           6       0.95      0.99      0.97       203
           7       0.95      0.97      0.96       214
           8       0.95      0.93      0.94       189
           9       0.90      0.92      0.91       196

    accuracy                           0.95      2000
   macro avg       0.95      0.95      0.95      2000
weighted avg       0.95      0.95      0.95      2000



In [50]:
# Save: refit the winning configuration on all 10,000 samples and persist it
# together with the scaler that produced its input features.
#
# The features are scaled here with `scaler` itself, so the persisted scaler is
# guaranteed to carry the same statistics the persisted model was trained on.
X_full_s = scaler.transform(X)

# Refit a clone rather than `best_model`: fitting in place would replace the
# evaluated model with one that has seen the test split, so re-running the
# test cell afterwards would report leaked, over-optimistic scores.
final_model = clone(best_model)
final_model.fit(X_full_s, y)

os.makedirs("models", exist_ok=True)
model_path = f"models/mnist_{best_name}.joblib"

joblib.dump(final_model, model_path)
joblib.dump(scaler, "models/mnist_scaler.joblib")

print(f"Model Saved as: {model_path}")

Model Saved as: models/mnist_extratrees.joblib
