# How to work with grouped data

A common property of data science problems is that the data is naturally grouped. For instance, you could have multiple samples for the same customer. In such a case, you need to make sure that all samples from a given group end up in the same fold, e.g. in cross-validation. Otherwise the model is validated on groups it has already seen during training, which inflates the validation scores.

Let's prepare a dataset with groups.

```python
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=100, n_features=10, random_state=42)
# Assign each sample to one of 5 groups (e.g. customers) in round-robin order.
groups = [i % 5 for i in range(100)]
groups[:10]
```

```
[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
```

The integers in the `groups` variable indicate the ID of the group to which each sample belongs.
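To see why grouping matters, the sketch below contrasts a plain `KFold`, which ignores `groups`, with `GroupKFold` on the same toy data. For each fold it counts how many group IDs occur in both the training and the test indices; any non-zero count means samples from the same group leak across the split. The helper `leaked_groups_per_fold` is a hypothetical name used only for this illustration.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GroupKFold, KFold

X, y = make_classification(n_samples=100, n_features=10, random_state=42)
groups = [i % 5 for i in range(100)]


def leaked_groups_per_fold(splits, groups):
    """Hypothetical helper: per fold, count groups present in both train and test."""
    counts = []
    for train_idx, test_idx in splits:
        train_groups = {groups[i] for i in train_idx}
        test_groups = {groups[i] for i in test_idx}
        counts.append(len(train_groups & test_groups))
    return counts


# KFold splits contiguous blocks of samples and never looks at the groups.
kfold_leaks = leaked_groups_per_fold(KFold(n_splits=5).split(X), groups)
# GroupKFold keeps every group entirely on one side of each split.
group_kfold_leaks = leaked_groups_per_fold(
    GroupKFold(n_splits=5).split(X, y, groups=groups), groups
)
print("KFold:", kfold_leaks)             # [5, 5, 5, 5, 5]
print("GroupKFold:", group_kfold_leaks)  # [0, 0, 0, 0, 0]
```

With `groups = [i % 5 ...]`, every contiguous block of 20 samples contains all five groups, so plain `KFold` leaks every group in every fold.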

One of the easiest ways to ensure that the data is split according to the groups is `GroupKFold` from `sklearn.model_selection`. You can read more about other ways of splitting grouped data in scikit-learn [here](https://scikit-learn.org/stable/modules/cross_validation.html#cross-validation-iterators-for-grouped-data).
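Conceptually, a group-aware splitter assigns whole groups, never individual samples, to folds. The sketch below is a simplified, dependency-free illustration of that idea, not scikit-learn's actual implementation: it places the largest groups first, each into the fold that currently holds the fewest samples, and then yields train/test indices. The function name `group_k_fold` is hypothetical.

```python
from collections import Counter


def group_k_fold(groups, n_splits):
    """Yield (train_indices, test_indices) pairs with each group confined to one fold.

    Simplified illustration only: greedily assigns the largest groups first to
    the fold that currently contains the fewest samples.
    """
    sizes = Counter(groups)
    if len(sizes) < n_splits:
        raise ValueError("n_splits cannot exceed the number of distinct groups")

    fold_sizes = [0] * n_splits
    group_to_fold = {}
    # Largest groups first; ties broken by group id so the result is deterministic.
    for group, size in sorted(sizes.items(), key=lambda item: (-item[1], item[0])):
        fold = fold_sizes.index(min(fold_sizes))  # currently lightest fold
        group_to_fold[group] = fold
        fold_sizes[fold] += size

    for fold in range(n_splits):
        test = [i for i, g in enumerate(groups) if group_to_fold[g] == fold]
        train = [i for i, g in enumerate(groups) if group_to_fold[g] != fold]
        yield train, test


groups = [i % 5 for i in range(100)]
splits = list(group_k_fold(groups, n_splits=5))
for train, test in splits:
    # No group id appears on both sides of a split.
    assert not {groups[i] for i in train} & {groups[i] for i in test}
print([len(test) for _, test in splits])  # [20, 20, 20, 20, 20]
```

The greedy size balancing matters when groups differ in size: with groups of 6, 3 and 3 samples and two folds, the large group fills one fold and the two small groups share the other, giving 6 samples per fold.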

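```python
from sklearn.model_selection import GroupKFold

# Materialize the splits as a list of (train_indices, test_indices) pairs
# so that they can be reused by several consumers.
cv = list(GroupKFold(n_splits=5).split(X, y, groups=groups))
```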
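Each element of `cv` is a `(train_indices, test_indices)` pair of NumPy index arrays. A quick sanity check, sketched below on the same toy data, confirms that no group is shared between the training and test part of any split and that every sample lands in exactly one test fold.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import GroupKFold

X, y = make_classification(n_samples=100, n_features=10, random_state=42)
groups = np.array([i % 5 for i in range(100)])
cv = list(GroupKFold(n_splits=5).split(X, y, groups=groups))

for fold, (train_idx, test_idx) in enumerate(cv):
    shared = set(groups[train_idx]) & set(groups[test_idx])
    assert not shared, f"fold {fold} shares groups {shared}"

# Concatenating all test folds should cover every sample exactly once.
all_test = np.sort(np.concatenate([test_idx for _, test_idx in cv]))
assert np.array_equal(all_test, np.arange(len(X)))
print(f"{len(cv)} folds, no group shared between train and test")
```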

The resulting list of `(train_indices, test_indices)` pairs can be passed to the `cv` parameter in `probatus`, as well as to hyperparameter optimization classes such as `RandomizedSearchCV`.
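The call to `list(...)` matters here: `GroupKFold.split` returns a generator, which is exhausted after a single pass. Because the same `cv` object is handed to both `RandomizedSearchCV` and `ShapRFECV` and may be iterated more than once, materializing the splits as a list lets every consumer see all folds. A small demonstration:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GroupKFold

X, y = make_classification(n_samples=100, n_features=10, random_state=42)
groups = [i % 5 for i in range(100)]

splits_gen = GroupKFold(n_splits=5).split(X, y, groups=groups)
first_pass = len(list(splits_gen))   # consumes the generator
second_pass = len(list(splits_gen))  # nothing left to yield
print(first_pass, second_pass)       # 5 0

# A list can be iterated any number of times.
splits_list = list(GroupKFold(n_splits=5).split(X, y, groups=groups))
print(len(splits_list), len(list(iter(splits_list))))  # 5 5
```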

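```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV

from probatus.feature_elimination import ShapRFECV

clf = RandomForestClassifier(random_state=42)

# Hyperparameter search that reuses the same group-aware splits.
param_grid = {
    "n_estimators": [5, 7, 10],
    "max_leaf_nodes": [3, 5, 7, 10],
}
search = RandomizedSearchCV(clf, param_grid, cv=cv, n_iter=1, random_state=42)

# SHAP-based recursive feature elimination validated on the group-aware splits;
# step=0.2 removes 20% of the remaining features per round (at least one).
# Note: other probatus versions may name the `clf` argument differently.
shap_elimination = ShapRFECV(clf=search, step=0.2, cv=cv, scoring="roc_auc", n_jobs=3, random_state=42)
report = shap_elimination.fit_compute(X, y)
```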
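Outside of `probatus`, scikit-learn offers two equivalent ways to respect the groups, sketched below with a plain `cross_val_score` on the same toy data: pass the precomputed `cv` list, or pass a `GroupKFold` splitter together with `groups=`. Since `GroupKFold` without shuffling and the seeded forest are deterministic, both calls should return identical scores. The smaller `n_estimators=10` and the default accuracy scoring are choices made only to keep the sketch fast and simple.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GroupKFold, cross_val_score

X, y = make_classification(n_samples=100, n_features=10, random_state=42)
groups = [i % 5 for i in range(100)]
clf = RandomForestClassifier(n_estimators=10, random_state=42)

# Option 1: reuse precomputed (train, test) index pairs.
cv = list(GroupKFold(n_splits=5).split(X, y, groups=groups))
scores_from_list = cross_val_score(clf, X, y, cv=cv)

# Option 2: let scikit-learn split, passing the group labels explicitly.
scores_from_groups = cross_val_score(clf, X, y, cv=GroupKFold(n_splits=5), groups=groups)

print(scores_from_list)
print(np.allclose(scores_from_list, scores_from_groups))  # True
```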

```python
report
```

| | num_features | features_set | eliminated_features | train_metric_mean | train_metric_std | val_metric_mean | val_metric_std |
|---|---|---|---|---|---|---|---|
| 1 | 10 | [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] | [8, 7] | 1.0 | 0.001 | 0.957 | 0.086 |
| 2 | 8 | [0, 1, 2, 3, 4, 5, 6, 9] | [5] | 0.999 | 0.001 | 0.966 | 0.055 |
| 3 | 7 | [0, 1, 2, 3, 4, 6, 9] | [4] | 1.0 | 0.0 | 0.942 | 0.114 |
| 4 | 6 | [0, 1, 2, 3, 6, 9] | [9] | 0.999 | 0.001 | 0.98 | 0.032 |
| 5 | 5 | [0, 1, 2, 3, 6] | [6] | 1.0 | 0.0 | 0.96 | 0.073 |
| 6 | 4 | [0, 1, 2, 3] | [1] | 0.999 | 0.001 | 0.951 | 0.091 |
| 7 | 3 | [0, 2, 3] | [3] | 0.999 | 0.001 | 0.971 | 0.052 |
| 8 | 2 | [0, 2] | [0] | 0.998 | 0.002 | 0.925 | 0.122 |
| 9 | 1 | [2] | [] | 0.998 | 0.002 | 0.938 | 0.098 |
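One simple way to use the report is to pick the number of features with the highest mean validation AUC. The sketch below rebuilds the relevant columns from the table above as a pandas DataFrame (in the notebook, `report` already is one) and selects that row. Given the fairly large `val_metric_std` values, the differences between rows are small compared with the fold-to-fold variation, so this rule is only a starting point.

```python
import pandas as pd

# Columns copied from the report above; in the notebook, `report` is already a DataFrame.
report = pd.DataFrame(
    {
        "num_features": [10, 8, 7, 6, 5, 4, 3, 2, 1],
        "features_set": [
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            [0, 1, 2, 3, 4, 5, 6, 9],
            [0, 1, 2, 3, 4, 6, 9],
            [0, 1, 2, 3, 6, 9],
            [0, 1, 2, 3, 6],
            [0, 1, 2, 3],
            [0, 2, 3],
            [0, 2],
            [2],
        ],
        "val_metric_mean": [0.957, 0.966, 0.942, 0.98, 0.96, 0.951, 0.971, 0.925, 0.938],
        "val_metric_std": [0.086, 0.055, 0.114, 0.032, 0.073, 0.091, 0.052, 0.122, 0.098],
    },
    index=range(1, 10),
)

# Row label of the highest mean validation score.
best = report.loc[report["val_metric_mean"].idxmax()]
print(best["num_features"], best["features_set"])  # 6 [0, 1, 2, 3, 6, 9]
```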
