# whole_dataset_drift.py — 56 lines (48 loc) · 1.97 KB
# ----------------------------------------------------------------------------
# Copyright (C) 2021-2023 Deepchecks (https://www.deepchecks.com)
#
# This file is part of Deepchecks.
# Deepchecks is distributed under the terms of the GNU Affero General
# Public License (version 3 or later).
# You should have received a copy of the GNU Affero General Public License
# along with Deepchecks. If not, see <http://www.gnu.org/licenses/>.
# ----------------------------------------------------------------------------
#
"""Module contains the WholeDatasetDrift check - deprecated."""
import warnings
from deepchecks.tabular.checks.train_test_validation import MultivariateDrift
class WholeDatasetDrift(MultivariateDrift):
    """
    Calculate drift between the entire train and test datasets using a model trained to distinguish between them.

    .. deprecated:: 0.9
        The WholeDatasetDrift check is deprecated and will be removed in the 0.11 version. Please use the
        MultivariateDrift check instead.

    This class is a backward-compatibility alias. Constructing it emits a ``DeprecationWarning`` and
    then forwards every argument unchanged to :class:`MultivariateDrift`; refer to that class for the
    meaning of each parameter.

    Parameters
    ----------
    n_top_columns : int , default: 3
        Forwarded unchanged to ``MultivariateDrift``.
    min_feature_importance : float , default: 0.05
        Forwarded unchanged to ``MultivariateDrift``.
    max_num_categories_for_display : int , default: 10
        Forwarded unchanged to ``MultivariateDrift``.
    show_categories_by : str , default: 'largest_difference'
        Forwarded unchanged to ``MultivariateDrift``.
    n_samples : int , default: 10_000
        Forwarded unchanged to ``MultivariateDrift``.
    random_state : int , default: 42
        Forwarded unchanged to ``MultivariateDrift``.
    test_size : float , default: 0.3
        Forwarded unchanged to ``MultivariateDrift``.
    min_meaningful_drift_score : float , default: 0.05
        Forwarded unchanged to ``MultivariateDrift``.
    **kwargs
        Any additional keyword arguments, forwarded unchanged to ``MultivariateDrift``.
    """

    def __init__(
        self,
        n_top_columns: int = 3,
        min_feature_importance: float = 0.05,
        max_num_categories_for_display: int = 10,
        show_categories_by: str = 'largest_difference',
        n_samples: int = 10_000,
        random_state: int = 42,
        test_size: float = 0.3,
        min_meaningful_drift_score: float = 0.05,
        **kwargs
    ):
        warnings.warn(
            'The WholeDatasetDrift check is deprecated and will be removed in the 0.11 version. '
            'Please use the MultivariateDrift check instead.',
            DeprecationWarning,
            # stacklevel=2 attributes the warning to the code that instantiated this check,
            # rather than to this __init__ itself.
            stacklevel=2,
        )
        # Forward by keyword, not by position: if the parent's parameter list is ever reordered
        # or gains a parameter in the middle, positional forwarding would silently bind values
        # to the wrong parameters. super() also keeps this cooperative under multiple inheritance.
        super().__init__(
            n_top_columns=n_top_columns,
            min_feature_importance=min_feature_importance,
            max_num_categories_for_display=max_num_categories_for_display,
            show_categories_by=show_categories_by,
            n_samples=n_samples,
            random_state=random_state,
            test_size=test_size,
            min_meaningful_drift_score=min_meaningful_drift_score,
            **kwargs,
        )