[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/merantix-momentum/squirrel-datasets-core/blob/main/examples/05.XGBoost.ipynb)

# Install Squirrel and Squirrel Datasets

In [None]:
try:
    import squirrel
    import squirrel_datasets_core
    import numpy as np
    import xgboost
    import sklearn
    import matplotlib
except ImportError:
    !pip install -q --ignore-requires-python --upgrade squirrel-datasets-core numpy xgboost scikit-learn matplotlib # noqa
    import squirrel
    import squirrel_datasets_core

print(squirrel.__version__)
print(squirrel_datasets_core.__version__)

In this tutorial, we will fit an XGBClassifier on the MNIST dataset.

First, we define how we will transform the samples and compose our batches:

In [None]:
import numpy as np


def transform(sample):
    x, y = sample["image"], sample["label"]
    x = np.array(x).reshape(1, -1)
    y = np.array(y).reshape(-1)
    return x, y


def collation_fn(batch):
    x, y = zip(*batch)
    return (np.concatenate(x, axis=0), np.concatenate(y, axis=0))
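As a quick sanity check, the two helpers can be run on synthetic data before touching the real dataset. The sketch below repeats the definitions so it runs on its own, and it assumes each sample is a dict with a 28x28 `image` array and an integer `label`; the `fake_samples` list is a hypothetical stand-in, not actual MNIST data.

```python
import numpy as np


def transform(sample):
    x, y = sample["image"], sample["label"]
    x = np.array(x).reshape(1, -1)  # flatten the image into one row of features
    y = np.array(y).reshape(-1)  # turn the scalar label into a 1-element vector
    return x, y


def collation_fn(batch):
    x, y = zip(*batch)
    return (np.concatenate(x, axis=0), np.concatenate(y, axis=0))


# Hypothetical stand-in samples (random pixels), used only to inspect shapes.
rng = np.random.default_rng(0)
fake_samples = [
    {"image": rng.integers(0, 256, size=(28, 28), dtype=np.uint8), "label": i % 10}
    for i in range(5)
]

x_one, y_one = transform(fake_samples[0])
x_batch, y_batch = collation_fn([transform(s) for s in fake_samples])

print(x_one.shape, y_one.shape)
print(x_batch.shape, y_batch.shape)
```

Each transformed sample is a single row of 784 features, so concatenating along axis 0 yields the `(n_samples, n_features)` matrix and the `(n_samples,)` label vector that a scikit-learn style `fit` expects.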

Then, we can construct our training set:

In [None]:
from squirrel.catalog import Catalog

cat = Catalog.from_plugins()

it = (
    cat["mnist"]
    .get_driver()
    .get_iter("train")
    .async_map(transform)  # uses threadpool to parallelize
    .batched(1000, collation_fn=collation_fn)
    .take(4)  # only 4 batches for demonstration, comment out this line to train on the whole dataset
    .tqdm()  # so that we can monitor the progress of data loading
)
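To make the roles of `batched` and `take` more concrete, here is a rough plain-Python sketch of the same idea. It is not Squirrel's implementation: the `batched` helper below is a hypothetical illustration, it runs sequentially, and it keeps a final partial batch, which may differ from how Squirrel treats incomplete batches.

```python
from itertools import islice


def batched(iterable, batch_size, collation_fn=None):
    """Group items into lists of batch_size and optionally collate each group."""
    it = iter(iterable)
    while True:
        chunk = list(islice(it, batch_size))
        if not chunk:
            return
        yield collation_fn(chunk) if collation_fn else chunk


# Toy data: batches of 4 integers, each collated by summing its items.
# islice(..., 2) plays the role of .take(2): only the first two batches are consumed.
batches = list(islice(batched(range(10), 4, collation_fn=sum), 2))
print(batches)
```

Because everything is a lazy generator, items beyond the requested batches are never read, which is why limiting the number of batches keeps a demonstration run short.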

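Ready for training! We fit the classifier on the first batch, then pass the existing booster via `xgb_model` so that each later batch continues training instead of starting from scratch: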

In [None]:
import xgboost as xgb

clf = xgb.XGBClassifier()
for idx, batch in enumerate(it):
    if idx == 0:
        # first batch: train a fresh model
        clf.fit(*batch)
    else:
        # later batches: continue boosting from the current model
        clf.fit(*batch, xgb_model=clf.get_booster())

Let's visualize the predictions:

In [None]:
import matplotlib.pyplot as plt

N = 10
test_samples = cat["mnist"].get_driver().get_iter("test").shuffle(100).map(transform).take(N).collect()

fig, axs = plt.subplots(1, N, figsize=(20, 10))
for i, (x, y) in enumerate(test_samples):
    y_hat = clf.predict(x)
    axs[i].imshow(x.reshape(28, 28), cmap=plt.cm.gray)
    axs[i].set_title(f"GT: {y[0]}, Pred: {y_hat[0]}")  # unpack the 1-element arrays