In [1]:
#
# Inspired by https://arxiv.org/abs/1702.08835 and https://github.com/STO-OTZ/my_gcForest/
#
import numpy as np
import random
import uuid

from sklearn.datasets import fetch_mldata
from sklearn.ensemble import ExtraTreesClassifier, RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split

from deep_forest import MGCForest

# The MNIST dataset

In [2]:
# Load the MNIST digits (downloaded on first use, then cached under data_home).
# NOTE(review): `fetch_mldata` was deprecated in scikit-learn 0.20 and removed
# in 0.22; on newer versions the replacement is `fetch_openml('mnist_784')`,
# whose labels come back as strings and would need converting to numbers.
mnist = fetch_mldata('MNIST original', data_home='~/scikit-learn-datasets')

# Sanity check: one row of pixel values per image and one label per row.
print('Data: {}, target: {}'.format(mnist.data.shape, mnist.target.shape))

Data: (70000, 784), target: (70000,)


In [3]:
# Hold out 20% of the samples for evaluation; the fixed seed makes the split
# itself reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    mnist.data,
    mnist.target,
    test_size=0.2,
    random_state=42,
)

# Restore each flattened row to its 2-D image layout, which is the input shape
# handed to the MGCForest further down.
IMAGE_SHAPE = (28, 28)
X_train = X_train.reshape((-1,) + IMAGE_SHAPE)
X_test = X_test.reshape((-1,) + IMAGE_SHAPE)

#
# Limit the size of the dataset
#
# Single knob for the subsample size, applied to both the train and test sets.
N_SAMPLES = 2000
X_train, y_train = X_train[:N_SAMPLES], y_train[:N_SAMPLES]
X_test, y_test = X_test[:N_SAMPLES], y_test[:N_SAMPLES]

# Samples and labels must stay aligned after slicing.
assert len(X_train) == len(y_train)
assert len(X_test) == len(y_test)

print('X_train:', X_train.shape, X_train.dtype)
print('y_train:', y_train.shape, y_train.dtype)
print('X_test:', X_test.shape)
print('y_test:', y_test.shape)

X_train: (2000, 28, 28) uint8
y_train: (2000,) float64
X_test: (2000, 28, 28)
y_test: (2000,)


## Using the MGCForest

Creates a simple *MGCForest* with 2 forests (one extra-trees, one random forest) for the *Multi-Grained-Scanning* process and 4 forests (two extra-trees, two random forests) for the *Cascade* process.

In [4]:
# None of the estimators below receives an explicit `random_state`, so seed the
# global RNGs to make repeated runs of this cell comparable.
# NOTE(review): this assumes deep_forest draws its randomness from the global
# `random` / `numpy.random` generators (directly or via scikit-learn) — confirm.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)


def forest_config(estimator_class, **estimator_params):
    """Build one estimator entry for the ``estimators_config`` of ``MGCForest``.

    Args:
        estimator_class: The forest class to use for this entry.
        **estimator_params: Parameters stored for that estimator.

    Returns:
        A dict with the keys ``'estimator_class'`` and ``'estimator_params'``;
        ``n_jobs=-1`` is always added to the parameters.
    """
    params = dict(estimator_params)
    params['n_jobs'] = -1
    return {
        'estimator_class': estimator_class,
        'estimator_params': params,
    }


mgc_forest = MGCForest(
    estimators_config={
        # Multi-grained scanning: one extra-trees forest and one random forest.
        'mgs': [
            forest_config(ExtraTreesClassifier, n_estimators=30, min_samples_split=21),
            forest_config(RandomForestClassifier, n_estimators=30, min_samples_split=21),
        ],
        # Cascade: each forest type in two variants (max_features=1 and 'sqrt'),
        # ordered as extra-trees x2 followed by random forest x2.
        'cascade': [
            forest_config(
                forest_class,
                n_estimators=1000,
                min_samples_split=11,
                max_features=max_features,
            )
            for forest_class in (ExtraTreesClassifier, RandomForestClassifier)
            for max_features in (1, 'sqrt')
        ],
    },
    # One scanner is created per stride ratio (see the fit log).
    stride_ratios=[1.0 / 4, 1.0 / 9, 1.0 / 16],
)

mgc_forest.fit(X_train, y_train)

<MultiGrainedScanner stride_ratio=0.25> - Scanning and fitting for X ((2000, 28, 28)) and y ((2000,)) started
<MultiGrainedScanner stride_ratio=0.25> - Scanning turned X ((2000, 28, 28)) into sliced_X ((22, 2000, 1078)). 484 new instances were added per sample
<MultiGrainedScanner stride_ratio=0.25> - Finished fitting X ((2000, 28, 28)) and got predictions with shape (44, 2000, 10)
<MultiGrainedScanner stride_ratio=0.1111111111111111> - Scanning and fitting for X ((2000, 28, 28)) and y ((2000,)) started
<MultiGrainedScanner stride_ratio=0.1111111111111111> - Scanning turned X ((2000, 28, 28)) into sliced_X ((26, 2000, 234)). 676 new instances were added per sample
<MultiGrainedScanner stride_ratio=0.1111111111111111> - Finished fitting X ((2000, 28, 28)) and got predictions with shape (52, 2000, 10)
<MultiGrainedScanner stride_ratio=0.0625> - Scanning and fitting for X ((2000, 28, 28)) and y ((2000,)) started
<MultiGrainedScanner stride_ratio=0.0625> - Scanning turned X ((2000, 28, 28)

In [5]:
# Predict a label for every held-out image.
y_pred = mgc_forest.predict(X_test)
print('Prediction shape:', y_pred.shape)

# Compare against the true labels; the F1 score uses the 'weighted' average
# across classes.
accuracy = accuracy_score(y_test, y_pred)
weighted_f1 = f1_score(y_test, y_pred, average='weighted')
print('Accuracy:', accuracy, 'F1 score:', weighted_f1)

<MultiGrainedScanner stride_ratio=0.25> - Predicting X ((2000, 28, 28))
<MultiGrainedScanner stride_ratio=0.25> - Scanning turned X ((2000, 28, 28)) into sliced_X ((22, 2000, 1078)). 484 new instances were added per sample
<MultiGrainedScanner stride_ratio=0.1111111111111111> - Predicting X ((2000, 28, 28))
<MultiGrainedScanner stride_ratio=0.1111111111111111> - Scanning turned X ((2000, 28, 28)) into sliced_X ((26, 2000, 234)). 676 new instances were added per sample
<MultiGrainedScanner stride_ratio=0.0625> - Predicting X ((2000, 28, 28))
<MultiGrainedScanner stride_ratio=0.0625> - Scanning turned X ((2000, 28, 28)) into sliced_X ((28, 2000, 28)). 784 new instances were added per sample
<CascadeForest forests=4> - Shape of predictions: (4, 2000, 10) shape of X: (2000, 1520)
<CascadeForest forests=4> - Shape of predictions: (4, 2000, 10) shape of X: (2000, 1560)


Prediction shape: (2000,)
Accuracy: 0.9365 F1 score: 0.936298471966
