# **XGBoost**

In [None]:
# Install the packages listed in the requirements file one directory above this
# notebook. The %pip magic (rather than !pip) targets the running kernel's environment.
%pip install -r ../requirements.txt

In [16]:
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from scipy.io import loadmat
import seaborn as sns
import numpy as np
import pandas as pd
import xgboost as xgb

# **LOAD MNIST**

In [21]:
# Read the MATLAB-format MNIST archive from the data directory.
mnist = loadmat('../data/MNIST.mat')

# Pull out the feature matrices as stored and flatten the label arrays to 1-D.
X_train_full, X_test = mnist['train_fea'], mnist['test_fea']
y_train_full = mnist['train_gnd'].ravel()
y_test = mnist['test_gnd'].ravel()

# Hold out 5,000 rows of the training data for validation (55k train / 5k val);
# the fixed seed keeps the split reproducible across runs.
X_train, X_val, y_train, y_val = train_test_split(
    X_train_full,
    y_train_full,
    test_size=5000,
    random_state=42,
)

# Move every label vector down by one so classes run 0-9
# (assumes the stored labels are 1-10 -- TODO confirm against the .mat file).
y_train, y_val, y_test = (labels - 1 for labels in (y_train, y_val, y_test))

# Report the shape of each split as a quick sanity check.
print(f"Training set: {X_train.shape}, {y_train.shape}")
print(f"Validation set: {X_val.shape}, {y_val.shape}")
print(f"Test set: {X_test.shape}, {y_test.shape}")




Training set: (55000, 784), (55000,)
Validation set: (5000, 784), (5000,)
Test set: (10000, 784), (10000,)


# **MODEL FITTING AND EARLY STOPPING**
- May take a minute to run.

In [27]:
# Patience for early stopping: boosting halts once the validation metric has
# not improved for this many consecutive rounds.
EARLY_STOPPING_ROUNDS = 2

# Use "hist" for constructing the trees, with early stopping enabled through the
# estimator's own parameter (no separate callback object is needed). The metric
# is named explicitly: this multi-class run is scored with "mlogloss", which is
# what the training log reports as validation_0-mlogloss.
clf = xgb.XGBClassifier(
    tree_method="hist",
    eval_metric="mlogloss",
    early_stopping_rounds=EARLY_STOPPING_ROUNDS,
)
# Fit on the training split. The held-out validation split (not the test set)
# is the one monitored for early stopping.
clf.fit(X_train, y_train, eval_set=[(X_val, y_val)])
# Save the fitted model as JSON in the current working directory.
clf.save_model("clf.json")

[0]	validation_0-mlogloss:1.38809
[1]	validation_0-mlogloss:1.06407
[2]	validation_0-mlogloss:0.85474
[3]	validation_0-mlogloss:0.70024
[4]	validation_0-mlogloss:0.58811
[5]	validation_0-mlogloss:0.50251
[6]	validation_0-mlogloss:0.43551
[7]	validation_0-mlogloss:0.38226
[8]	validation_0-mlogloss:0.33958
[9]	validation_0-mlogloss:0.30392
[10]	validation_0-mlogloss:0.27506
[11]	validation_0-mlogloss:0.25179
[12]	validation_0-mlogloss:0.23136
[13]	validation_0-mlogloss:0.21501
[14]	validation_0-mlogloss:0.20192
[15]	validation_0-mlogloss:0.18927
[16]	validation_0-mlogloss:0.17882
[17]	validation_0-mlogloss:0.16995
[18]	validation_0-mlogloss:0.16228
[19]	validation_0-mlogloss:0.15561
[20]	validation_0-mlogloss:0.14925
[21]	validation_0-mlogloss:0.14333
[22]	validation_0-mlogloss:0.13818
[23]	validation_0-mlogloss:0.13392
[24]	validation_0-mlogloss:0.12966
[25]	validation_0-mlogloss:0.12582
[26]	validation_0-mlogloss:0.12249
[27]	validation_0-mlogloss:0.11902
[28]	validation_0-mlogloss:0.1