Skip to content

Commit

Permalink
Merge branch 'refactor'
Browse files Browse the repository at this point in the history
Conflicts:
	imblearn/ensemble/balance_cascade.py
  • Loading branch information
Guillaume Lemaitre committed Jul 9, 2016
2 parents 5f20c3d + 721bc2a commit 45e6457
Show file tree
Hide file tree
Showing 78 changed files with 646 additions and 2,111 deletions.
120 changes: 77 additions & 43 deletions imblearn/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Base class for sampling"""
"""Base class for sampling"""

from __future__ import division
from __future__ import print_function

import warnings
import logging

import numpy as np

Expand All @@ -19,14 +20,16 @@


class SamplerMixin(six.with_metaclass(ABCMeta, BaseEstimator)):

"""Mixin class for samplers with abstact method.
Warning: This class should not be used directly. Use the derive classes
instead.
"""

@abstractmethod
def __init__(self, ratio='auto', random_state=None, verbose=True):
_estimator_type = "sampler"

def __init__(self, ratio='auto'):
"""Initialize this object and its instance variables.
Parameters
Expand All @@ -37,45 +40,15 @@ def __init__(self, ratio='auto', random_state=None, verbose=True):
of samples in the minority class over the number of samples
in the majority class.
random_state : int or None, optional (default=None)
Seed for random number generation.
verbose : bool, optional (default=True)
Boolean to either or not print information about the processing
Returns
-------
None
"""
# The ratio correspond to the number of samples in the minority class
# over the number of samples in the majority class. Thus, the ratio
# cannot be greater than 1.0
if isinstance(ratio, float):
if ratio > 1:
raise ValueError('Ration cannot be greater than one.')
elif ratio <= 0:
raise ValueError('Ratio cannot be negative.')
else:
self.ratio = ratio
elif isinstance(ratio, string_types):
if ratio == 'auto':
self.ratio = ratio
else:
raise ValueError('Unknown string for the parameter ratio.')
else:
raise ValueError('Unknown parameter type for ratio.')

self.random_state = random_state
self.verbose = verbose
self.ratio = ratio
self.logger = logging.getLogger(__name__)

# Create the member variables regarding the classes statistics
self.min_c_ = None
self.maj_c_ = None
self.stats_c_ = {}
self.X_shape_ = None

@abstractmethod
def fit(self, X, y):
"""Find the classes statistics before to perform sampling.
Expand All @@ -97,8 +70,15 @@ def fit(self, X, y):
# Check the consistency of X and y
X, y = check_X_y(X, y)

if self.verbose:
print("Determining classes statistics... ", end="")
self.min_c_ = None
self.maj_c_ = None
self.stats_c_ = {}
self.X_shape_ = None

if hasattr(self, 'ratio'):
self._validate_ratio()

self.logger.info('Compute classes statistics ...')

# Get all the unique elements in the target array
uniques = np.unique(y)
Expand All @@ -122,9 +102,8 @@ def fit(self, X, y):
self.min_c_ = min(self.stats_c_, key=self.stats_c_.get)
self.maj_c_ = max(self.stats_c_, key=self.stats_c_.get)

if self.verbose:
print('{} classes detected: {}'.format(uniques.size,
self.stats_c_))
self.logger.info('%s classes detected: %s', uniques.size,
self.stats_c_)

# Check if the ratio provided at initialisation make sense
if isinstance(self.ratio, float):
Expand All @@ -136,7 +115,6 @@ def fit(self, X, y):

return self

@abstractmethod
def sample(self, X, y):
"""Resample the dataset.
Expand All @@ -158,8 +136,11 @@ def sample(self, X, y):
"""

# Check the consistency of X and y
X, y = check_X_y(X, y)

# Check that the data have been fitted
if not self.stats_c_:
if not hasattr(self, 'stats_c_'):
raise RuntimeError('You need to fit the data, first!!!')

# Check if the size of the data is identical to the one used at fitting
Expand All @@ -168,7 +149,10 @@ def sample(self, X, y):
' seem to be the one earlier fitted. Use the'
' fitted data.')

return self
if hasattr(self, 'ratio'):
self._validate_ratio()

return self._sample(X, y)

def fit_sample(self, X, y):
"""Fit the statistics and resample the data directly.
Expand All @@ -192,3 +176,53 @@ def fit_sample(self, X, y):
"""

return self.fit(X, y).sample(X, y)

def _validate_ratio(self):
    """Check that ``self.ratio`` is a valid ratio specification.

    The ratio corresponds to the number of samples in the minority
    class over the number of samples in the majority class; a float
    ratio must therefore lie in the interval (0, 1]. The only string
    accepted is ``'auto'``.

    Raises
    ------
    ValueError
        If the ratio is a float outside (0, 1], a string other than
        ``'auto'``, or any other type.
    """
    if isinstance(self.ratio, float):
        if self.ratio > 1:
            # Fixed typo in the original message ('Ration' -> 'Ratio').
            raise ValueError('Ratio cannot be greater than one.')
        elif self.ratio <= 0:
            # Zero is also rejected, so the message says so explicitly.
            raise ValueError('Ratio has to be strictly positive.')

    elif isinstance(self.ratio, string_types):
        if self.ratio != 'auto':
            raise ValueError('Unknown string for the parameter ratio.')
    else:
        raise ValueError('Unknown parameter type for ratio.')

@abstractmethod
def _sample(self, X, y):
    """Resample the dataset.

    Core resampling routine that every concrete sampler must
    implement; the public ``sample`` method delegates to it after
    validating its inputs.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
        Matrix containing the data which have to be sampled.
    y : ndarray, shape (n_samples, )
        Corresponding label for each sample in X.

    Returns
    -------
    X_resampled : ndarray, shape (n_samples_new, n_features)
        The array containing the resampled data.
    y_resampled : ndarray, shape (n_samples_new)
        The corresponding label of `X_resampled`
    """
    pass

def __getstate__(self):
    """Return the instance state for pickling.

    The ``logger`` attribute holds a ``logging.Logger``, which cannot
    be pickled, so it is dropped from a copy of the state dictionary
    (``__setstate__`` re-creates it on unpickling).
    """
    state = dict(self.__dict__)
    state.pop('logger')
    return state

def __setstate__(self, state):
    """Restore the instance state and re-open the logger.

    ``__getstate__`` strips the non-picklable ``logger`` attribute
    before pickling, so a fresh module-level logger is attached after
    the remaining state has been restored.

    Parameters
    ----------
    state : dict
        The unpickled instance state. (Renamed from ``dict``, which
        shadowed the builtin; pickle calls this method positionally,
        so the rename is backward compatible.)
    """
    logger = logging.getLogger(__name__)
    self.__dict__.update(state)
    self.logger = logger
101 changes: 14 additions & 87 deletions imblearn/combine/smote_enn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from __future__ import print_function
from __future__ import division

from sklearn.utils import check_X_y

from ..over_sampling import SMOTE
from ..under_sampling import EditedNearestNeighbours
from ..base import SamplerMixin
Expand All @@ -22,11 +20,11 @@ class SMOTEENN(SamplerMixin):
number of samples in the minority class over the number of
samples in the majority class.
random_state : int or None, optional (default=None)
Seed for random number generation.
verbose : bool, optional (default=True)
Whether or not to print information about the processing.
random_state : int, RandomState instance or None, optional (default=None)
If int, random_state is the seed used by the random number generator;
If RandomState instance, random_state is the random number generator;
If None, the random number generator is the RandomState instance used
by np.random.
k : int, optional (default=5)
Number of nearest neighbours to used to construct synthetic
Expand Down Expand Up @@ -60,15 +58,6 @@ class SMOTEENN(SamplerMixin):
Attributes
----------
ratio : str or float
If 'auto', the ratio will be defined automatically to balance
the dataset. Otherwise, the ratio is defined as the
number of samples in the minority class over the the number of
samples in the majority class.
random_state : int or None
Seed for random number generation.
min_c_ : str or int
The identifier of the minority class.
Expand Down Expand Up @@ -96,81 +85,25 @@ class SMOTEENN(SamplerMixin):
"""

def __init__(self, ratio='auto', random_state=None, verbose=True,
def __init__(self, ratio='auto', random_state=None,
k=5, m=10, out_step=0.5, kind_smote='regular',
size_ngh=3, kind_enn='all', n_jobs=-1, **kwargs):

"""Initialise the SMOTE ENN object.
Parameters
----------
ratio : str or float, optional (default='auto')
If 'auto', the ratio will be defined automatically to balance
the dataset. Otherwise, the ratio is defined as the
number of samples in the minority class over the the number of
samples in the majority class.
random_state : int or None, optional (default=None)
Seed for random number generation.
verbose : bool, optional (default=True)
Whether or not to print information about the processing.
k : int, optional (default=5)
Number of nearest neighbours to used to construct synthetic
samples.
m : int, optional (default=10)
Number of nearest neighbours to use to determine if a minority
sample is in danger.
out_step : float, optional (default=0.5)
Step size when extrapolating.
kind_smote : str, optional (default='regular')
The type of SMOTE algorithm to use one of the following
options: 'regular', 'borderline1', 'borderline2', 'svm'.
size_ngh : int, optional (default=3)
Size of the neighbourhood to consider to compute the average
distance to the minority point samples.
kind_sel : str, optional (default='all')
Strategy to use in order to exclude samples.
- If 'all', all neighbours will have to agree with the samples of
interest to not be excluded.
- If 'mode', the majority vote of the neighbours will be used in
order to exclude a sample.
n_jobs : int, optional (default=-1)
The number of threads to open if possible.
Returns
-------
None
"""
super(SMOTEENN, self).__init__(ratio=ratio, random_state=random_state,
verbose=verbose)

super(SMOTEENN, self).__init__(ratio=ratio)
self.random_state = random_state
self.k = k
self.m = m
self.out_step = out_step
self.kind_smote = kind_smote
self.size_ngh = size_ngh
self.kind_enn = kind_enn
self.n_jobs = n_jobs
self.kwargs = kwargs

self.sm = SMOTE(ratio=self.ratio, random_state=self.random_state,
verbose=self.verbose, k=self.k, m=self.m,
out_step=self.out_step, kind=self.kind_smote,
n_jobs=self.n_jobs, **self.kwargs)

self.size_ngh = size_ngh
self.kind_enn = kind_enn

k=self.k, m=self.m, out_step=self.out_step,
kind=self.kind_smote, n_jobs=self.n_jobs,
**self.kwargs)
self.enn = EditedNearestNeighbours(random_state=self.random_state,
verbose=self.verbose,
size_ngh=self.size_ngh,
kind_sel=self.kind_enn,
n_jobs=self.n_jobs)
Expand All @@ -192,8 +125,6 @@ def fit(self, X, y):
Return self.
"""
# Check the consistency of X and y
X, y = check_X_y(X, y)

super(SMOTEENN, self).fit(X, y)

Expand All @@ -202,7 +133,7 @@ def fit(self, X, y):

return self

def sample(self, X, y):
def _sample(self, X, y):
"""Resample the dataset.
Parameters
Expand All @@ -222,10 +153,6 @@ def sample(self, X, y):
The corresponding label of `X_resampled`
"""
# Check the consistency of X and y
X, y = check_X_y(X, y)

super(SMOTEENN, self).sample(X, y)

# Transform using SMOTE
X, y = self.sm.sample(X, y)
Expand Down
Loading

0 comments on commit 45e6457

Please sign in to comment.