Skip to content

Commit

Permalink
FIX imports in Geometric-SMOTE examples and format code (scikit-learn…
Browse files Browse the repository at this point in the history
  • Loading branch information
joaopfonseca committed Dec 18, 2021
1 parent fa3ffe5 commit 467c557
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 56 deletions.
48 changes: 23 additions & 25 deletions examples/over-sampling/plot_geometric_smote_generation_mechanism.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
Data generation mechanism
=========================
This example illustrates the Geometric SMOTE data
generation mechanism and the usage of its
This example illustrates the Geometric SMOTE data
generation mechanism and the usage of its
hyperparameters.
"""
Expand All @@ -16,9 +16,7 @@
import matplotlib.pyplot as plt

from sklearn.datasets import make_blobs
from imblearn.over_sampling import SMOTE

from gsmote import GeometricSMOTE
from imblearn.over_sampling import SMOTE, GeometricSMOTE

print(__doc__)

Expand Down Expand Up @@ -47,11 +45,11 @@ def generate_imbalanced_data(
def plot_scatter(X, y, title):
"""Function to plot some data as a scatter plot."""
plt.figure()
plt.scatter(X[y == 1, 0], X[y == 1, 1], label='Positive Class')
plt.scatter(X[y == 0, 0], X[y == 0, 1], label='Negative Class')
plt.scatter(X[y == 1, 0], X[y == 1, 1], label="Positive Class")
plt.scatter(X[y == 0, 0], X[y == 0, 1], label="Negative Class")
plt.xlim(*XLIM)
plt.ylim(*YLIM)
plt.gca().set_aspect('equal', adjustable='box')
plt.gca().set_aspect("equal", adjustable="box")
plt.legend()
plt.title(title)

Expand All @@ -66,9 +64,9 @@ def plot_hyperparameters(oversampler, X, y, param, vals, n_subplots):
for ax, val in zip(ax_arr, vals):
oversampler.set_params(**{param: val})
X_res, y_res = oversampler.fit_resample(X, y)
ax.scatter(X_res[y_res == 1, 0], X_res[y_res == 1, 1], label='Positive Class')
ax.scatter(X_res[y_res == 0, 0], X_res[y_res == 0, 1], label='Negative Class')
ax.set_title(f'{val}')
ax.scatter(X_res[y_res == 1, 0], X_res[y_res == 1, 1], label="Positive Class")
ax.scatter(X_res[y_res == 0, 0], X_res[y_res == 0, 1], label="Negative Class")
ax.set_title(f"{val}")
ax.set_xlim(*XLIM)
ax.set_ylim(*YLIM)

Expand All @@ -79,8 +77,8 @@ def plot_comparison(oversamplers, X, y):
fig, ax_arr = plt.subplots(1, 2, figsize=(15, 5))
for ax, (name, ovs) in zip(ax_arr, oversamplers):
X_res, y_res = ovs.fit_resample(X, y)
ax.scatter(X_res[y_res == 1, 0], X_res[y_res == 1, 1], label='Positive Class')
ax.scatter(X_res[y_res == 0, 0], X_res[y_res == 0, 1], label='Negative Class')
ax.scatter(X_res[y_res == 1, 0], X_res[y_res == 1, 1], label="Positive Class")
ax.scatter(X_res[y_res == 0, 0], X_res[y_res == 0, 1], label="Negative Class")
ax.set_title(name)
ax.set_xlim(*XLIM)
ax.set_ylim(*YLIM)
Expand All @@ -98,7 +96,7 @@ def plot_comparison(oversamplers, X, y):
X, y = generate_imbalanced_data(
200, 2, [(-2.0, 2.25), (1.0, 2.0)], 0.25, [-0.7, 2.3], [-0.5, 3.1]
)
plot_scatter(X, y, 'Imbalanced data')
plot_scatter(X, y, "Imbalanced data")

###############################################################################
# Geometric hyperparameters
Expand Down Expand Up @@ -133,13 +131,13 @@ def plot_comparison(oversamplers, X, y):
gsmote = GeometricSMOTE(
k_neighbors=1,
deformation_factor=0.0,
selection_strategy='minority',
selection_strategy="minority",
random_state=RANDOM_STATE,
)
truncation_factors = np.array([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
n_subplots = [2, 3]
plot_hyperparameters(gsmote, X, y, 'truncation_factor', truncation_factors, n_subplots)
plot_hyperparameters(gsmote, X, y, 'truncation_factor', -truncation_factors, n_subplots)
plot_hyperparameters(gsmote, X, y, "truncation_factor", truncation_factors, n_subplots)
plot_hyperparameters(gsmote, X, y, "truncation_factor", -truncation_factors, n_subplots)

###############################################################################
# Deformation factor
Expand All @@ -151,12 +149,12 @@ def plot_comparison(oversamplers, X, y):
gsmote = GeometricSMOTE(
k_neighbors=1,
truncation_factor=0.0,
selection_strategy='minority',
selection_strategy="minority",
random_state=RANDOM_STATE,
)
deformation_factors = np.array([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
n_subplots = [2, 3]
plot_hyperparameters(gsmote, X, y, 'deformation_factor', truncation_factors, n_subplots)
plot_hyperparameters(gsmote, X, y, "deformation_factor", truncation_factors, n_subplots)

###############################################################################
# Selection strategy
Expand All @@ -177,10 +175,10 @@ def plot_comparison(oversamplers, X, y):
deformation_factor=0.5,
random_state=RANDOM_STATE,
)
selection_strategies = np.array(['minority', 'majority', 'combined'])
selection_strategies = np.array(["minority", "majority", "combined"])
n_subplots = [1, 3]
plot_hyperparameters(
gsmote, X, y, 'selection_strategy', selection_strategies, n_subplots
gsmote, X, y, "selection_strategy", selection_strategies, n_subplots
)

###############################################################################
Expand All @@ -193,7 +191,7 @@ def plot_comparison(oversamplers, X, y):

X_new = np.vstack([X, np.array([2.0, 2.0])])
y_new = np.hstack([y, np.ones(1, dtype=np.int8)])
plot_scatter(X_new, y_new, 'Imbalanced data')
plot_scatter(X_new, y_new, "Imbalanced data")

###############################################################################
# When the number of ``k_neighbors`` is increased, SMOTE results to the
Expand All @@ -202,11 +200,11 @@ def plot_comparison(oversamplers, X, y):
# ``majority``.

oversamplers = [
('SMOTE', SMOTE(k_neighbors=2, random_state=RANDOM_STATE)),
("SMOTE", SMOTE(k_neighbors=2, random_state=RANDOM_STATE)),
(
'Geometric SMOTE',
"Geometric SMOTE",
GeometricSMOTE(
k_neighbors=2, selection_strategy='combined', random_state=RANDOM_STATE
k_neighbors=2, selection_strategy="combined", random_state=RANDOM_STATE
),
),
]
Expand Down
36 changes: 18 additions & 18 deletions examples/over-sampling/plot_geometric_smote_validation_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import LinearSVC
from sklearn.model_selection import validation_curve
from sklearn.metrics import make_scorer, cohen_kappa_score
from sklearn.metrics import make_scorer
from sklearn.datasets import make_classification
from imblearn.pipeline import make_pipeline
from imblearn.metrics import geometric_mean_score

from gsmote import GeometricSMOTE
from imblearn.over_sampling import GeometricSMOTE

print(__doc__)

Expand Down Expand Up @@ -81,12 +81,12 @@ def plot_validation_curve(validation_curve_info, scoring_name, title):
plt.scatter(param_range[idx_max], test_scores_mean[idx_max])
plt.title(title)
plt.ylabel(scoring_name)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.get_xaxis().tick_bottom()
ax.get_yaxis().tick_left()
ax.spines['left'].set_position(('outward', 10))
ax.spines['bottom'].set_position(('outward', 10))
ax.spines["left"].set_position(("outward", 10))
ax.spines["bottom"].set_position(("outward", 10))
plt.ylim([0.9, 1.0])


Expand All @@ -107,11 +107,11 @@ def plot_validation_curve(validation_curve_info, scoring_name, title):
DecisionTreeClassifier(random_state=RANDOM_STATE),
)

scoring_name = 'Geometric Mean Score'
scoring_name = "Geometric Mean Score"
validation_curve_info = generate_validation_curve_info(
gsmote_gbc, X, y, range(1, 8), "geometricsmote__k_neighbors", SCORER
)
plot_validation_curve(validation_curve_info, scoring_name, 'K Neighbors')
plot_validation_curve(validation_curve_info, scoring_name, "K Neighbors")

validation_curve_info = generate_validation_curve_info(
gsmote_gbc,
Expand All @@ -121,7 +121,7 @@ def plot_validation_curve(validation_curve_info, scoring_name, title):
"geometricsmote__truncation_factor",
SCORER,
)
plot_validation_curve(validation_curve_info, scoring_name, 'Truncation Factor')
plot_validation_curve(validation_curve_info, scoring_name, "Truncation Factor")

validation_curve_info = generate_validation_curve_info(
gsmote_gbc,
Expand All @@ -131,17 +131,17 @@ def plot_validation_curve(validation_curve_info, scoring_name, title):
"geometricsmote__deformation_factor",
SCORER,
)
plot_validation_curve(validation_curve_info, scoring_name, 'Deformation Factor')
plot_validation_curve(validation_curve_info, scoring_name, "Deformation Factor")

validation_curve_info = generate_validation_curve_info(
gsmote_gbc,
X,
y,
['minority', 'majority', 'combined'],
["minority", "majority", "combined"],
"geometricsmote__selection_strategy",
SCORER,
)
plot_validation_curve(validation_curve_info, scoring_name, 'Selection Strategy')
plot_validation_curve(validation_curve_info, scoring_name, "Selection Strategy")

###############################################################################
# High Imbalance Ratio or low Samples to Features Ratio
Expand All @@ -158,11 +158,11 @@ def plot_validation_curve(validation_curve_info, scoring_name, title):
LinearSVC(random_state=RANDOM_STATE, max_iter=1e5),
)

scoring_name = 'Geometric Mean Score'
scoring_name = "Geometric Mean Score"
validation_curve_info = generate_validation_curve_info(
gsmote_gbc, X, y, range(1, 8), "geometricsmote__k_neighbors", SCORER
)
plot_validation_curve(validation_curve_info, scoring_name, 'K Neighbors')
plot_validation_curve(validation_curve_info, scoring_name, "K Neighbors")

validation_curve_info = generate_validation_curve_info(
gsmote_gbc,
Expand All @@ -172,7 +172,7 @@ def plot_validation_curve(validation_curve_info, scoring_name, title):
"geometricsmote__truncation_factor",
SCORER,
)
plot_validation_curve(validation_curve_info, scoring_name, 'Truncation Factor')
plot_validation_curve(validation_curve_info, scoring_name, "Truncation Factor")

validation_curve_info = generate_validation_curve_info(
gsmote_gbc,
Expand All @@ -182,14 +182,14 @@ def plot_validation_curve(validation_curve_info, scoring_name, title):
"geometricsmote__deformation_factor",
SCORER,
)
plot_validation_curve(validation_curve_info, scoring_name, 'Deformation Factor')
plot_validation_curve(validation_curve_info, scoring_name, "Deformation Factor")

validation_curve_info = generate_validation_curve_info(
gsmote_gbc,
X,
y,
['minority', 'majority', 'combined'],
["minority", "majority", "combined"],
"geometricsmote__selection_strategy",
SCORER,
)
plot_validation_curve(validation_curve_info, scoring_name, 'Selection Strategy')
plot_validation_curve(validation_curve_info, scoring_name, "Selection Strategy")
26 changes: 13 additions & 13 deletions imblearn/over_sampling/_smote/tests/test_geometric_smote.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@


@pytest.mark.parametrize(
'center,surface_point',
"center,surface_point",
[
(CENTERS[0], SURFACE_POINTS[0]),
(CENTERS[1], SURFACE_POINTS[1]),
Expand All @@ -48,7 +48,7 @@ def test_make_geometric_sample_hypersphere(center, surface_point):


@pytest.mark.parametrize(
'surface_point,deformation_factor',
"surface_point,deformation_factor",
[
(np.array([1.0, 0.0]), 0.0),
(2.6 * np.array([0.0, 1.0]), 0.25),
Expand All @@ -68,7 +68,7 @@ def test_make_geometric_sample_half_hypersphere(surface_point, deformation_facto


@pytest.mark.parametrize(
'center,surface_point,truncation_factor',
"center,surface_point,truncation_factor",
[
(center, surface_point, truncation_factor)
for center, surface_point in zip(CENTERS, SURFACE_POINTS)
Expand All @@ -94,11 +94,11 @@ def test_make_geometric_sample_line_segment(center, surface_point, truncation_fa
def test_gsmote_default_init():
"""Test the intialization with default parameters."""
gsmote = GeometricSMOTE()
assert gsmote.sampling_strategy == 'auto'
assert gsmote.sampling_strategy == "auto"
assert gsmote.random_state is None
assert gsmote.truncation_factor == 1.0
assert gsmote.deformation_factor == 0.0
assert gsmote.selection_strategy == 'combined'
assert gsmote.selection_strategy == "combined"
assert gsmote.k_neighbors == 5
assert gsmote.n_jobs == 1

Expand All @@ -119,12 +119,12 @@ def test_gsmote_invalid_selection_strategy():
X, y = make_classification(
random_state=RND_SEED, n_samples=n_samples, weights=weights
)
gsmote = GeometricSMOTE(random_state=RANDOM_STATE, selection_strategy='Minority')
gsmote = GeometricSMOTE(random_state=RANDOM_STATE, selection_strategy="Minority")
with pytest.raises(ValueError):
gsmote.fit_resample(X, y)


@pytest.mark.parametrize('selection_strategy', ['combined', 'minority', 'majority'])
@pytest.mark.parametrize("selection_strategy", ["combined", "minority", "majority"])
def test_gsmote_nn(selection_strategy):
"""Test nearest neighbors object."""
n_samples, weights = 200, [0.6, 0.4]
Expand All @@ -135,14 +135,14 @@ def test_gsmote_nn(selection_strategy):
random_state=RANDOM_STATE, selection_strategy=selection_strategy
)
_ = gsmote.fit_resample(X, y)
if selection_strategy in ('minority', 'combined'):
if selection_strategy in ("minority", "combined"):
assert gsmote.nns_pos_.n_neighbors == gsmote.k_neighbors + 1
if selection_strategy in ('majority', 'combined'):
if selection_strategy in ("majority", "combined"):
assert gsmote.nn_neg_.n_neighbors == 1


@pytest.mark.parametrize(
'selection_strategy, truncation_factor, deformation_factor',
"selection_strategy, truncation_factor, deformation_factor",
[
(selection_strategy, truncation_factor, deformation_factor)
for selection_strategy in SELECTION_STRATEGY
Expand All @@ -160,7 +160,7 @@ def test_gsmote_fit_resample_binary(
radius = np.sqrt(0.5) * step
k_neighbors = 1
gsmote = GeometricSMOTE(
'auto',
"auto",
RANDOM_STATE,
truncation_factor,
deformation_factor,
Expand All @@ -174,7 +174,7 @@ def test_gsmote_fit_resample_binary(


@pytest.mark.parametrize(
'selection_strategy, truncation_factor, deformation_factor',
"selection_strategy, truncation_factor, deformation_factor",
[
(selection_strategy, truncation_factor, deformation_factor)
for selection_strategy in SELECTION_STRATEGY
Expand All @@ -196,7 +196,7 @@ def test_gsmote_fit_resample_multiclass(
)
k_neighbors, majority_label = 1, 0
gsmote = GeometricSMOTE(
'auto',
"auto",
RANDOM_STATE,
truncation_factor,
deformation_factor,
Expand Down

0 comments on commit 467c557

Please sign in to comment.