Skip to content

Commit

Permalink
Feature mrmr (#622)
Browse files Browse the repository at this point in the history
* Wip 1

* new ideas

* new ideas

* commit

* commit

* exclude venv

* Wip2

* Wip3

* Wip 3.5

* Pushing some optim

* Doc string and examples

* Docstring WIP

* Docstring WIP2

* Adding something

* Bugfix

* Mkdocs and small fixes

* Added tests

* Added scripts

* Wip4

* removing tests

* Added tests and some bugifx

* revert pandastransformer

* Update sklego/feature_selection/mrmr.py

Fbruzzesi suggestion

Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com>

* Resolving comments on PR

* features

* venv

* Add missing file

* Wip userguide

* typing

* Update sklego/feature_selection/mrmr.py

Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com>

* Update sklego/feature_selection/mrmr.py

Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com>

* Update docs/user-guide/feature-selection.md

Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com>

* Update sklego/feature_selection/mrmr.py

Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com>

* resolve comments

* clean

* suggestions + general rephrase

* Typo

* Update docs/user-guide/feature-selection.md

Co-authored-by: vincent d warmerdam  <vincentwarmerdam@gmail.com>

---------

Co-authored-by: Fabio Scantamburlo <fabio.scantamburlo@nielsen.com>
Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com>
Co-authored-by: vincent d warmerdam <vincentwarmerdam@gmail.com>
  • Loading branch information
4 people committed Mar 9, 2024
1 parent a3c8c75 commit 64485a9
Show file tree
Hide file tree
Showing 11 changed files with 553 additions and 1 deletion.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ venv/
ENV/
env.bak/
venv.bak/
venv*/

# Spyder project settings
.spyderproject
Expand All @@ -120,4 +121,4 @@ dmypy.json
.DS_Store

# Local Netlify folder
.netlify
.netlify
122 changes: 122 additions & 0 deletions docs/_scripts/feature-selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from pathlib import Path

_file = Path(__file__)
print(f"Executing {_file}")


_static_path = Path("_static") / _file.stem
_static_path.mkdir(parents=True, exist_ok=True)

# --8<-- [start:mrmr-commonimports]
from sklearn.datasets import fetch_openml
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.feature_selection import f_classif, mutual_info_classif
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from sklego.feature_selection import MaximumRelevanceMinimumRedundancy
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
# --8<-- [end:mrmr-commonimports]

# --8<-- [start:mrmr-intro]

# Download MNIST dataset using scikit-learn
mnist = fetch_openml("mnist_784", cache=True)

# Assign features and labels
X_pd, y_pd = mnist["data"], mnist["target"].astype(int)

X, y = X_pd.to_numpy(), y_pd.to_numpy()
t_t_s_params = {'test_size': 10000, 'random_state': 42}
X_train, X_test, y_train, y_test = train_test_split(X, y, **t_t_s_params)
X_train = X_train.reshape(60000, 28 * 28)
X_test = X_test.reshape(10000, 28 * 28)
# --8<-- [end:mrmr-intro]

# --8<-- [start:mrmr-smile]
def smile_relevance(X, y):
rows = 28
cols = 28
smiling_face = np.zeros((rows, cols), dtype=int)

# Set the values for the eyes, nose,
# and mouth with adjusted positions and sizes
# Left eye
smiling_face[10:13, 8:10] = 1
# Right eye
smiling_face[10:13, 18:20] = 1
# Upper part of the mouth
smiling_face[18:20, 10:18] = 1
# Left edge of the open mouth
smiling_face[16:18, 8:10] = 1
# Right edge of the open mouth
smiling_face[16:18, 18:20] = 1

# Add the nose as four pixels one pixel higher
smiling_face[14, 13:15] = 1
smiling_face[27, :] = 1
return smiling_face.reshape(rows * cols,)


def smile_redundancy(X, selected, left):
return np.ones(len(left))
# --8<-- [end:mrmr-smile]

# --8<-- [start:mrmr-core]
K = 38
mrmr = MaximumRelevanceMinimumRedundancy(k=K,
kind="auto",
redundancy_func="p",
relevance_func="f")
mrmr_s = MaximumRelevanceMinimumRedundancy(k=K,
redundancy_func=smile_redundancy,
relevance_func=smile_relevance)

f = f_classif(X_train ,y_train.reshape(60000,))[0]
f_features = np.argsort(np.nan_to_num(f, nan=np.finfo(float).eps))[-K:]
mi = mutual_info_classif(X_train, y_train.reshape(60000,))
mi_features = np.argsort(np.nan_to_num(mi, nan=np.finfo(float).eps))[-K:]
mrmr_features = mrmr.fit(X_train, y_train).selected_features_
mrmr_smile_features = mrmr_s.fit(X_train, y_train).selected_features_

# --8<-- [end:mrmr-core]
# --8<-- [start:mrmr-selected-features]
# Define features dictionary
features = {
"f_classif": f_features,
"mutual_info": mi_features,
"mrmr": mrmr_features,
"mrmr_smile": mrmr_smile_features,
}
for name, s_f in features.items():
model = HistGradientBoostingClassifier(random_state=42)
model.fit(X_train[:, s_f], y_train.squeeze())
y_pred = model.predict(X_test[:, s_f])
print(f"Feature selection method: {name}")
print(f"F1 score: {round(f1_score(y_test, y_pred, average="weighted"), 3)}")

# --8<-- [end:mrmr-selected-features]

# --8<-- [start:mrmr-plots]
# Create figure and axes for the plots
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Iterate through the features dictionary and plot the images
for idx, (name, s_f) in enumerate(features.items()):
row = idx // 2
col = idx % 2

a = np.zeros(28 * 28)
a[s_f] = 1
ax = axes[row, col]
plot_= sns.heatmap(a.reshape(28, 28), cmap="binary", ax=ax, cbar=False)
ax.set_title(name)




# --8<-- [end:mrmr-plots]
plt.tight_layout()
plt.savefig(_static_path / "mrmr-feature-selection-mnist.png")
plt.clf()
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 6 additions & 0 deletions docs/api/feature-selection.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Features Selection

:::sklego.feature_selection.mrmr.MaximumRelevanceMinimumRedundancy
options:
show_root_full_path: true
show_root_heading: true
6 changes: 6 additions & 0 deletions docs/api/features-selection.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Features Selection

:::sklego.feature_selection.mrmr.MaximumRelevanceMinimumRedundancy
options:
show_root_full_path: true
show_root_heading: true
71 changes: 71 additions & 0 deletions docs/user-guide/feature-selection.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Feature Selection

## Maximum Relevance Minimum Redundancy

The [`Maximum Relevance Minimum Redundancy`][MaximumRelevanceMinimumRedundancy-api] (MRMR) is an iterative feature selection method commonly used in data science to select a subset of features from a larger feature set. The goal of MRMR is to choose features that have high *relevance* to the target variable while minimizing *redundancy* among the already selected features.

MRMR is heavily dependent on the two functions used to determine relevace and redundancy. However, the paper [Maximum Relevanceand Minimum Redundancy Feature Selection Methods for a Marketing Machine Learning Platform](https://arxiv.org/pdf/1908.05376.pdf) shows that using [f_classif](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.f_classif.html) or [f_regression](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.f_regression.html) as relevance function and Pearson correlation as redundancy function is the best choice for a variety of different problems and in general is a good choice.

Inspired by the Medium article [Feature Selection: How To Throw Away 95% of Your Data and Get 95% Accuracy](https://towardsdatascience.com/feature-selection-how-to-throw-away-95-of-your-data-and-get-95-accuracy-ad41ca016877) we showcase a practical application using the well known mnist dataset.

Note that although the default scikit-lego MRMR implementation uses redundancy and relevance as defined in [Maximum Relevanceand Minimum Redundancy Feature Selection Methods for a Marketing Machine Learning Platform](https://arxiv.org/pdf/1908.05376.pdf), our implementation offers the possibility of defining custom functions, that may be necessary in different scenarios depending on the data.

We will compare this list of well known filters method:

- F statistical test ([ANOVA F-test](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.f_classif.html)).
- Mutual information approximation based on sklearn implementation.

Against the default scikit-lego MRMR implementation and a custom MRMR implementation aimed to select features in order to draw a smiling face on the plot showing the minst letters.



??? example "MRMR imports"
```py
--8<-- "docs/_scripts/feature-selection.py:mrmr-commonimports"
```

```py title="MRMR mnist"
--8<-- "docs/_scripts/feature-selection.py:mrmr-intro"
```

As custom functions, we implemented the smile redundancy and smile relevance.

```py title="MRMR smile functions"
--8<-- "docs/_scripts/feature-selection.py:mrmr-smile"
```

Then we execute the main code part.

```py title="MRMR core"
--8<-- "docs/_scripts/feature-selection.py:mrmr-core"
```

After the execution it is possible to inspect the F1-score for the selected features:

```py title="MRMR mnist selected features"
--8<-- "docs/_scripts/feature-selection.py:mrmr-selected-features"
```

```console hl_lines="5-6"
Feature selection method: f_classif
F1 score: 0.854
Feature selection method: mutual_info
F1 score: 0.879
Feature selection method: mrmr
F1 score: 0.925
Feature selection method: mrmr_smile
F1 score: 0.849
```

The MRMR feature selection model provides better results compared against the other methods, although the smile technique performs rather good as well.

Finally, we can take a look at the selected features.

??? example "MRMR generate plots"
```py
--8<-- "docs/_scripts/feature-selection.py:mrmr-plots"
```

![selected-features-mrmr](../_static/feature-selection/mrmr-feature-selection-mnist.png)

[MaximumRelevanceMinimumRedundancy-api]: ../../api/feature-selection#sklego.feature_selection.mrmr.MaximumRelevanceMinimumRedundancy
2 changes: 2 additions & 0 deletions mkdocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ nav:
- Datasets: user-guide/datasets.md
- Linear Models: user-guide/linear-models.md
- Mixture Methods: user-guide/mixture-methods.md
- Feature Selection: user-guide/feature-selection.md
- Naive Bayes: user-guide/naive-bayes.md
- Meta Models: user-guide/meta-models.md
- Fairness: user-guide/fairness.md
Expand All @@ -147,6 +148,7 @@ nav:
- Meta: api/meta.md
- Metrics: api/metrics.md
- Mixture: api/mixture.md
- Feature Selection: api/feature-selection.md
- Model Selection: api/model-selection.md
- Naive Bayes: api/naive-bayes.md
- Neighbors: api/neighbors.md
Expand Down
5 changes: 5 additions & 0 deletions sklego/feature_selection/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
__all__ = [
"MaximumRelevanceMinimumRedundancy",
]

from sklego.feature_selection.mrmr import MaximumRelevanceMinimumRedundancy
Loading

0 comments on commit 64485a9

Please sign in to comment.