forked from scikit-learn/scikit-learn
-
Notifications
You must be signed in to change notification settings - Fork 3
/
iforest.py
324 lines (255 loc) · 12 KB
/
iforest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
# Authors: Nicolas Goix <nicolas.goix@telecom-paristech.fr>
# Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
# License: BSD 3 clause
from __future__ import division
import numpy as np
import scipy as sp
from warnings import warn
from sklearn.utils.fixes import euler_gamma
from scipy.sparse import issparse
import numbers
from ..externals import six
from ..tree import ExtraTreeRegressor
from ..utils import check_random_state, check_array
from .bagging import BaseBagging
# Names exported by ``from <this module> import *``.
__all__ = ["IsolationForest"]
# Types accepted as "an integer" by the isinstance checks in this module:
# Python integral numbers and NumPy integer scalars.
INTEGER_TYPES = (numbers.Integral, np.integer)
class IsolationForest(BaseBagging):
    """Isolation Forest Algorithm

    Return the anomaly score of each sample using the IsolationForest
    algorithm.

    The IsolationForest 'isolates' observations by randomly selecting a
    feature and then randomly selecting a split value between the maximum
    and minimum values of the selected feature.

    Since recursive partitioning can be represented by a tree structure, the
    number of splittings required to isolate a sample is equivalent to the
    path length from the root node to the terminating node.

    This path length, averaged over a forest of such random trees, is a
    measure of normality and our decision function.

    Random partitioning produces noticeably shorter paths for anomalies.
    Hence, when a forest of random trees collectively produce shorter path
    lengths for particular samples, they are highly likely to be anomalies.

    Read more in the :ref:`User Guide <isolation_forest>`.

    .. versionadded:: 0.18

    Parameters
    ----------
    n_estimators : int, optional (default=100)
        The number of base estimators in the ensemble.

    max_samples : int, float or "auto", optional (default="auto")
        The number of samples to draw from X to train each base estimator.
            - If int, then draw `max_samples` samples.
            - If float, then draw `max_samples * X.shape[0]` samples.
            - If "auto", then `max_samples=min(256, n_samples)`.
        If max_samples is larger than the number of samples provided,
        all samples will be used for all trees (no sampling).

    contamination : float in (0., 0.5), optional (default=0.1)
        The amount of contamination of the data set, i.e. the proportion
        of outliers in the data set. Used when fitting to define the
        threshold on the decision function.

    max_features : int or float, optional (default=1.0)
        The number of features to draw from X to train each base estimator.
            - If int, then draw `max_features` features.
            - If float, then draw `max_features * X.shape[1]` features.

    bootstrap : boolean, optional (default=False)
        If True, individual trees are fit on random subsets of the training
        data sampled with replacement. If False, sampling without
        replacement is performed.

    n_jobs : integer, optional (default=1)
        The number of jobs to run in parallel for both `fit` and `predict`.
        If -1, then the number of jobs is set to the number of cores.

    random_state : int, RandomState instance or None, optional (default=None)
        If int, random_state is the seed used by the random number generator;
        If RandomState instance, random_state is the random number generator;
        If None, the random number generator is the RandomState instance used
        by `np.random`.

    verbose : int, optional (default=0)
        Controls the verbosity of the tree building process.

    Attributes
    ----------
    estimators_ : list of ExtraTreeRegressor
        The collection of fitted sub-estimators.

    estimators_samples_ : list of arrays
        The subset of drawn samples (i.e., the in-bag samples) for each base
        estimator.

    max_samples_ : integer
        The actual number of samples.

    threshold_ : float
        Cut-off on the decision function computed during ``fit`` from
        ``contamination``; samples whose decision function is less than or
        equal to it are predicted as outliers (-1).

    References
    ----------
    .. [1] Liu, Fei Tony, Ting, Kai Ming and Zhou, Zhi-Hua. "Isolation
           forest." Data Mining, 2008. ICDM'08. Eighth IEEE International
           Conference on.
    .. [2] Liu, Fei Tony, Ting, Kai Ming and Zhou, Zhi-Hua. "Isolation-based
           anomaly detection." ACM Transactions on Knowledge Discovery from
           Data (TKDD) 6.1 (2012): 3.
    """

    def __init__(self,
                 n_estimators=100,
                 max_samples="auto",
                 contamination=0.1,
                 max_features=1.,
                 bootstrap=False,
                 n_jobs=1,
                 random_state=None,
                 verbose=0):
        super(IsolationForest, self).__init__(
            base_estimator=ExtraTreeRegressor(
                max_features=1,
                splitter='random',
                random_state=random_state),
            # here above max_features has no links with self.max_features
            bootstrap=bootstrap,
            bootstrap_features=False,
            n_estimators=n_estimators,
            max_samples=max_samples,
            max_features=max_features,
            n_jobs=n_jobs,
            random_state=random_state,
            verbose=verbose)
        self.contamination = contamination

    def _set_oob_score(self, X, y):
        """Out-of-bag scoring is not available; always raises."""
        raise NotImplementedError("OOB score not supported by iforest")

    def fit(self, X, y=None, sample_weight=None):
        """Fit estimator.

        Parameters
        ----------
        X : array-like or sparse matrix, shape (n_samples, n_features)
            The input samples. Use ``dtype=np.float32`` for maximum
            efficiency. Sparse matrices are also supported, use sparse
            ``csc_matrix`` for maximum efficiency.

        y : ignored
            Any value passed in is discarded; it is replaced internally by
            uniform random targets used to fit the base regressors.

        sample_weight : array-like, shape = [n_samples] or None
            Sample weights. If None, then samples are equally weighted.

        Returns
        -------
        self : object
            Returns self.

        Raises
        ------
        ValueError
            If ``max_samples`` is a string other than "auto", or is neither
            a string nor an integer and lies outside (0, 1].
        """
        # Imported explicitly: ``import scipy`` alone does not guarantee
        # that the ``stats`` submodule is available as an attribute.
        from scipy.stats import scoreatpercentile

        X = check_array(X, accept_sparse=['csc'])
        if issparse(X):
            # Pre-sort indices to avoid that each individual tree of the
            # ensemble sorts the indices.
            X.sort_indices()

        rnd = check_random_state(self.random_state)
        y = rnd.uniform(size=X.shape[0])

        # ensure that max_sample is in [1, n_samples]:
        n_samples = X.shape[0]

        if isinstance(self.max_samples, six.string_types):
            if self.max_samples == 'auto':
                max_samples = min(256, n_samples)
            else:
                raise ValueError('max_samples (%s) is not supported. '
                                 'Valid choices are: "auto", int or '
                                 'float' % self.max_samples)

        elif isinstance(self.max_samples, INTEGER_TYPES):
            if self.max_samples > n_samples:
                warn("max_samples (%s) is greater than the "
                     "total number of samples (%s). max_samples "
                     "will be set to n_samples for estimation."
                     % (self.max_samples, n_samples))
                max_samples = n_samples
            else:
                max_samples = self.max_samples
        else:  # float
            if not (0. < self.max_samples <= 1.):
                raise ValueError("max_samples must be in (0, 1], got %r"
                                 % self.max_samples)
            max_samples = int(self.max_samples * X.shape[0])

        self.max_samples_ = max_samples
        # Tree depth is capped at ceil(log2(max_samples)), and at least 1.
        max_depth = int(np.ceil(np.log2(max(max_samples, 2))))
        super(IsolationForest, self)._fit(X, y, max_samples,
                                          max_depth=max_depth,
                                          sample_weight=sample_weight)

        # Threshold taken from the decision function on the training data,
        # using the percentile implied by ``contamination``.
        self.threshold_ = -scoreatpercentile(
            -self.decision_function(X), 100. * (1. - self.contamination))

        return self

    def predict(self, X):
        """Predict if a particular sample is an outlier or not.

        Parameters
        ----------
        X : array-like or sparse matrix, shape (n_samples, n_features)
            The input samples. Internally, it will be converted to
            ``dtype=np.float32`` and if a sparse matrix is provided
            to a sparse ``csr_matrix``.

        Returns
        -------
        is_inlier : array, shape (n_samples,)
            For each observation, tells whether or not (+1 or -1) it should
            be considered as an inlier according to the fitted model.
        """
        X = check_array(X, accept_sparse='csr')
        is_inlier = np.ones(X.shape[0], dtype=int)
        is_inlier[self.decision_function(X) <= self.threshold_] = -1
        return is_inlier

    def decision_function(self, X):
        """Average anomaly score of X of the base classifiers.

        The anomaly score of an input sample is computed as
        the mean anomaly score of the trees in the forest.

        The measure of normality of an observation given a tree is the depth
        of the leaf containing this observation, which is equivalent to
        the number of splittings required to isolate this point. In case of
        several observations n_left in the leaf, the average path length of
        a n_left samples isolation tree is added.

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape (n_samples, n_features)
            The input samples. Sparse matrices are accepted only if
            they are supported by the base estimator.

        Returns
        -------
        scores : array of shape (n_samples,)
            The anomaly score of the input samples.
            The lower, the more abnormal.
        """
        # code structure from ForestClassifier/predict_proba
        # Check data
        X = check_array(X, accept_sparse='csr')
        n_samples = X.shape[0]

        n_samples_leaf = np.zeros((n_samples, self.n_estimators), order="f")
        depths = np.zeros((n_samples, self.n_estimators), order="f")

        # Column selection is only needed when the trees were fitted on a
        # subset of the features.
        subsample_features = self._max_features != X.shape[1]

        for i, (tree, features) in enumerate(zip(self.estimators_,
                                                 self.estimators_features_)):
            X_subset = X[:, features] if subsample_features else X
            leaves_index = tree.apply(X_subset)
            node_indicator = tree.decision_path(X_subset)
            n_samples_leaf[:, i] = tree.tree_.n_node_samples[leaves_index]
            # Nodes on the path minus one.
            depths[:, i] = np.ravel(node_indicator.sum(axis=1))
            depths[:, i] -= 1

        depths += _average_path_length(n_samples_leaf)

        scores = 2 ** (-depths.mean(axis=1)
                       / _average_path_length(self.max_samples_))

        # Take the opposite of the scores as bigger is better (here less
        # abnormal) and add 0.5 (this value plays a special role as described
        # in the original paper) to give a sense to scores = 0:
        return 0.5 - scores
def _average_path_length(n_samples_leaf):
    """Average path length of an isolation tree built on n samples.

    This equals the average path length of an unsuccessful BST search,
    since the latter has the same structure as an isolation tree.

    Parameters
    ----------
    n_samples_leaf : array-like of shape (n_samples, n_estimators), or int.
        The number of training samples in each test sample leaf, for
        each estimator. Lists, tuples and arrays of any shape are accepted.

    Returns
    -------
    average_path_length : float or array
        A float when an integer is given, otherwise an array with the same
        shape as ``n_samples_leaf``. Counts less than or equal to 1 map
        to 1.
    """
    # NOTE(review): the referenced isolation-forest papers appear to define
    # this quantity as 0 for n <= 1 and 1 for n == 2; the values below are
    # kept unchanged for backward compatibility - confirm before altering.
    if isinstance(n_samples_leaf, INTEGER_TYPES):
        if n_samples_leaf <= 1:
            return 1.
        return (2. * (np.log(n_samples_leaf - 1.) + euler_gamma)
                - 2. * (n_samples_leaf - 1.) / n_samples_leaf)

    # Coerce so that plain sequences and non-integer scalars work too,
    # rather than failing on the missing ``.shape`` attribute.
    leaf_counts = np.asarray(n_samples_leaf)
    original_shape = leaf_counts.shape
    flat_counts = leaf_counts.reshape((1, -1))

    average_path_length = np.zeros(flat_counts.shape)
    mask = (flat_counts <= 1)
    not_mask = np.logical_not(mask)

    average_path_length[mask] = 1.
    several = flat_counts[not_mask]
    average_path_length[not_mask] = (
        2. * (np.log(several - 1.) + euler_gamma)
        - 2. * (several - 1.) / several)

    return average_path_length.reshape(original_shape)