-
Notifications
You must be signed in to change notification settings - Fork 389
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
418 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import os | ||
import sys | ||
import unittest | ||
|
||
from preprocessing.tests.run import * | ||
from validation.tests.run import * | ||
|
||
# When executed as a script, run the test cases made available by the
# star-imports from the preprocessing and validation test packages.
if __name__ == '__main__':
    unittest.main()
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import os | ||
import sys | ||
import unittest | ||
|
||
# Parent of the directory that contains this file.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Appended to sys.path so that modules located in BASE_DIR can be imported
# by their bare module name.
sys.path.append(BASE_DIR)

# The test cases are imported into this namespace so that a single
# unittest.main() call can load all of them.
from .test_validator_kfold import KFoldValidatorTest
from .test_validator_split import SplitValidatorTest
from .test_validator_with_dataset import WithDatasetValidatorTest
from .test_validation_step import ValidationStepTest

if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import unittest | ||
import numpy as np | ||
from validation_step import ValidationStep | ||
from validation_step import ValidationStepException | ||
|
||
class ValidationStepTest(unittest.TestCase):
    """Tests for the validator dispatch performed by ValidationStep."""

    def test_create(self):
        """A 'split' step with ratio 0.5 gives one split of 2 + 2 rows."""
        features = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
        target = np.array([0, 0, 1, 1])
        data = {'train': {'X': features, 'y': target}}
        params = {
            'validator_type': 'split',
            'shuffle': False,
            'stratify': False,
            'train_ratio': 0.5,
        }

        step = ValidationStep(data, params)

        self.assertEqual(1, step.get_n_splits())
        for parts in step.split():
            X_train, y_train, X_validation, y_validation = parts
            # Four input rows split in half: every part holds two rows.
            for part in (X_train, y_train, X_validation, y_validation):
                self.assertEqual(part.shape[0], 2)

    def test_wrong_validator_type(self):
        """An unrecognised validator_type is reported with 'Unknown'."""
        params = {'validator_type': 'no_such_validator'}
        with self.assertRaises(ValidationStepException) as context:
            ValidationStep({}, params)

        self.assertIn('Unknown', str(context.exception))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import unittest | ||
import numpy as np | ||
from validator_kfold import KFoldValidator | ||
from validator_base import BaseValidatorException | ||
|
||
class KFoldValidatorTest(unittest.TestCase):
    """Tests for KFoldValidator on a four-row toy dataset."""

    def _check_two_fold_split(self, data, params):
        """Build a validator and check every fold holds 2 train / 2 validation rows."""
        validator = KFoldValidator(data, params)
        self.assertEqual(params['k_folds'], validator.get_n_splits())
        for fold in validator.split():
            X_train, y_train, X_validation, y_validation = fold
            for part in (X_train, y_train, X_validation, y_validation):
                self.assertEqual(part.shape[0], 2)

    def test_create(self):
        """Plain (unshuffled, unstratified) 2-fold split with numeric target."""
        data = {
            'train': {
                'X': np.array([[0, 0], [0, 1], [1, 0], [1, 1]]),
                'y': np.array([0, 0, 1, 1]),
            }
        }
        params = {'shuffle': False, 'stratify': False, 'k_folds': 2}
        self._check_two_fold_split(data, params)

    def test_create_with_target_as_labels(self):
        """Shuffled, stratified 2-fold split with string labels as target."""
        data = {
            'train': {
                'X': np.array([[0, 0], [0, 1], [1, 0], [1, 1]]),
                'y': np.array(['a', 'b', 'a', 'b']),
            }
        }
        params = {'shuffle': True, 'stratify': True, 'k_folds': 2}
        self._check_two_fold_split(data, params)

    def test_missing_data(self):
        """Train data without 'y' is rejected with a 'Missing' message."""
        data = {
            'train': {
                'X': np.array([[0, 0], [0, 1], [1, 0], [1, 1]]),
            }
        }
        params = {'shuffle': True, 'stratify': True, 'k_folds': 2}
        with self.assertRaises(BaseValidatorException) as context:
            KFoldValidator(data, params)

        self.assertIn('Missing', str(context.exception))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import unittest | ||
import numpy as np | ||
from validator_split import SplitValidator | ||
from validator_split import SplitValidatorException | ||
|
||
class SplitValidatorTest(unittest.TestCase):
    """Tests for SplitValidator on a four-row toy dataset."""

    def test_create(self):
        """A 0.75 train ratio gives exactly one split of 3 + 1 rows."""
        data = {
            'train': {
                'X': np.array([[0, 0], [0, 1], [1, 0], [1, 1]]),
                'y': np.array([0, 0, 1, 1])
            }
        }
        params = {
            'shuffle': False,
            'stratify': False,
            'train_ratio': 0.75
        }
        vl = SplitValidator(data, params)
        self.assertEqual(1, vl.get_n_splits())
        cnt = 0
        for X_train, y_train, X_validation, y_validation in vl.split():
            self.assertEqual(X_train.shape[0], 3)
            self.assertEqual(y_train.shape[0], 3)
            self.assertEqual(X_validation.shape[0], 1)
            self.assertEqual(X_validation.shape[1], 2)
            self.assertEqual(y_validation.shape[0], 1)
            cnt += 1
        # The generator must yield exactly once.
        self.assertEqual(cnt, 1)

    def wrong_split_value(self, split_value):
        """Assert that splitting with ``split_value`` as train_ratio fails.

        Helper (not collected by unittest) used by test_wrong_split_values.
        """
        data = {
            'train': {
                'X': np.array([[0, 0], [0, 1], [1, 0], [1, 1]]),
                'y': np.array([0, 0, 1, 1])
            }
        }
        params = {
            'shuffle': True,
            'stratify': True,
            'train_ratio': split_value
        }
        with self.assertRaises(ValueError) as context:
            vl = SplitValidator(data, params)
            # split() is a generator: advance it once so the split is really
            # computed. Tuple-unpacking the generator itself would raise
            # ValueError regardless of the ratio and hide a broken check.
            next(vl.split())
        # NOTE(review): the expected wording comes from scikit-learn and may
        # differ between versions - confirm against the pinned release.
        self.assertTrue('should be' in str(context.exception))

    def test_wrong_split_values(self):
        """Each listed ratio is expected to be rejected for this dataset."""
        for i in [0.1, 0.9, 1.1, -0.1]:
            self.wrong_split_value(i)
43 changes: 43 additions & 0 deletions
43
supervised/validation/tests/test_validator_with_dataset.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import unittest | ||
import numpy as np | ||
from validator_with_dataset import WithDatasetValidator | ||
from validator_with_dataset import WithDatasetValidatorException | ||
|
||
class WithDatasetValidatorTest(unittest.TestCase):
    """Tests for WithDatasetValidator, which takes an explicit validation set."""

    def test_create(self):
        """Train and validation sets are yielded once, four rows each."""
        features = [[0, 0], [0, 1], [1, 0], [1, 1]]
        labels = [0, 0, 1, 1]
        data = {
            'train': {'X': np.array(features), 'y': np.array(labels)},
            'validation': {'X': np.array(features), 'y': np.array(labels)},
        }

        validator = WithDatasetValidator(data, {})

        self.assertEqual(1, validator.get_n_splits())
        yielded = 0
        for parts in validator.split():
            X_train, y_train, X_validation, y_validation = parts
            for part in (X_train, y_train, X_validation, y_validation):
                self.assertEqual(part.shape[0], 4)
            yielded += 1
        self.assertEqual(yielded, 1)

    def test_missing_data(self):
        """A validation set without 'y' is rejected with a 'Missing' message."""
        data = {
            'train': {
                'X': np.array([[0, 0], [0, 1], [1, 0], [1, 1]]),
                'y': np.array([0, 0, 1, 1]),
            },
            'validation': {
                'X': np.array([[0, 0], [0, 1], [1, 0], [1, 1]]),
            },
        }
        with self.assertRaises(WithDatasetValidatorException) as context:
            WithDatasetValidator(data, {})

        self.assertIn('Missing', str(context.exception))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import logging | ||
log = logging.getLogger(__name__) | ||
|
||
|
||
from validator_kfold import KFoldValidator | ||
from validator_split import SplitValidator | ||
from validator_with_dataset import WithDatasetValidator | ||
|
||
class ValidationStepException(Exception):
    """Error raised by ValidationStep; the message is logged on creation."""

    def __init__(self, message):
        super(ValidationStepException, self).__init__(message)
        # Logging happens at construction time, before the exception is raised.
        log.error(message)
|
||
class ValidationStep():
    """Select a validator from ``params['validator_type']`` and delegate to it.

    Supported types are 'kfold' (the default when the key is absent),
    'split' and 'with_dataset'.
    """

    def __init__(self, data, params):
        """Build the validator named by ``params['validator_type']``.

        Both ``data`` and ``params`` are passed unchanged to the chosen
        validator's constructor.

        Raises:
            ValidationStepException: if the validator type is not recognised.
        """
        self.data = data
        self.params = params

        # kfold is default validation technique
        self.validator_type = self.params.get('validator_type', 'kfold')

        if self.validator_type == 'kfold':
            self.validator = KFoldValidator(data, params)
        elif self.validator_type == 'split':
            self.validator = SplitValidator(data, params)
        elif self.validator_type == 'with_dataset':
            # This branch previously repeated the 'split' test, which made
            # WithDatasetValidator unreachable.
            self.validator = WithDatasetValidator(data, params)
        else:
            msg = 'Unknown validation type: {0}'.format(self.validator_type)
            raise ValidationStepException(msg)

    def split(self):
        """Return whatever the selected validator's ``split()`` returns."""
        return self.validator.split()

    def get_n_splits(self):
        """Return the selected validator's number of splits."""
        return self.validator.get_n_splits()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import logging | ||
log = logging.getLogger(__name__) | ||
|
||
class BaseValidatorException(Exception):
    """Error raised when validator input is incomplete; logged on creation."""

    def __init__(self, message):
        super(BaseValidatorException, self).__init__(message)
        # Logging happens at construction time, before the exception is raised.
        log.error(message)
|
||
class BaseValidator(object):
    """Common base for validators: stores the inputs and checks train data."""

    def __init__(self, data, params):
        """Keep ``data`` and ``params`` and validate ``data`` immediately."""
        self.data = data
        self.params = params
        self.validate()

    def validate(self):
        """Check that ``data['train']`` exists and holds both 'X' and 'y'.

        Raises:
            BaseValidatorException: if 'train', or 'X' / 'y' inside it, is
                absent or None.
        """
        train = self.data.get('train')
        if train is None:
            raise BaseValidatorException('Missing train data')

        for key in ('X', 'y'):
            if train.get(key) is None:
                raise BaseValidatorException(
                    'Missing {0} in train data'.format(key)
                )

    def split(self):
        """Placeholder; does nothing and returns None in the base class."""
        return None

    def get_n_splits(self):
        """Placeholder; does nothing and returns None in the base class."""
        return None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import logging | ||
log = logging.getLogger(__name__) | ||
|
||
from sklearn.model_selection import StratifiedKFold | ||
from sklearn.model_selection import KFold | ||
|
||
from validator_base import BaseValidator | ||
|
||
class KFoldValidator(BaseValidator):
    """K-fold cross-validation over ``data['train']``.

    Recognised ``params`` keys and their defaults:
        k_folds (5), shuffle (True), stratify (False), random_seed (1706).
    """

    def __init__(self, data, params):
        """Validate the input and build the scikit-learn fold splitter."""
        BaseValidator.__init__(self, data, params)

        self.k_folds = self.params.get('k_folds', 5)
        self.shuffle = self.params.get('shuffle', True)
        self.stratify = self.params.get('stratify', False)
        self.random_seed = self.params.get('random_seed', 1706)

        # The seed only has an effect when shuffling. Recent scikit-learn
        # releases raise ValueError when random_state is set together with
        # shuffle=False, so pass it only when it is actually used.
        random_state = self.random_seed if self.shuffle else None
        splitter_cls = StratifiedKFold if self.stratify else KFold
        self.skf = splitter_cls(n_splits=self.k_folds,
                                shuffle=self.shuffle,
                                random_state=random_state)

    def split(self):
        """Yield ``(X_train, y_train, X_validation, y_validation)`` per fold.

        X and y are indexed with the index arrays produced by the splitter.
        """
        X = self.data['train']['X']
        y = self.data['train']['y']
        for train_index, validation_index in self.skf.split(X, y):
            X_train, y_train = X[train_index], y[train_index]
            X_validation, y_validation = X[validation_index], y[validation_index]
            yield X_train, y_train, X_validation, y_validation

    def get_n_splits(self):
        """Return the configured number of folds."""
        return self.k_folds
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
|
||
from sklearn.model_selection import train_test_split | ||
|
||
from validator_base import BaseValidator | ||
|
||
class SplitValidatorException(Exception):
    """Error raised by SplitValidator; the message is logged on creation."""

    def __init__(self, message):
        # No module-level ``log`` is visible in this module, so the logger is
        # resolved here; otherwise constructing the exception would fail
        # with NameError instead of reporting the real problem.
        import logging

        Exception.__init__(self, message)
        logging.getLogger(__name__).error(message)
|
||
|
||
class SplitValidator(BaseValidator):
    """Single train/validation split of ``data['train']``.

    Recognised ``params`` keys and their defaults:
        train_ratio (0.8), shuffle (True), stratify (False),
        random_seed (1706).
    """

    def __init__(self, data, params):
        """Validate the input and read the split configuration."""
        BaseValidator.__init__(self, data, params)

        self.train_ratio = self.params.get('train_ratio', 0.8)
        self.shuffle = self.params.get('shuffle', True)
        self.stratify = self.params.get('stratify', False)
        self.random_seed = self.params.get('random_seed', 1706)

    def split(self):
        """Yield one ``(X_train, y_train, X_validation, y_validation)`` tuple.

        This is a generator, so the split is computed (and any ValueError
        from scikit-learn is raised) only when it is iterated.
        """
        X = self.data['train']['X']
        y = self.data['train']['y']

        # ``shuffle`` was read from params but never forwarded, so
        # shuffle=False was silently ignored. Note that scikit-learn rejects
        # stratify combined with shuffle=False.
        X_train, X_validation, y_train, y_validation = train_test_split(
            X, y,
            train_size=self.train_ratio,
            test_size=1.0 - self.train_ratio,
            shuffle=self.shuffle,
            stratify=y if self.stratify else None,
            random_state=self.random_seed)
        yield X_train, y_train, X_validation, y_validation

    def get_n_splits(self):
        """Return 1: this validator always produces a single split."""
        return 1
|
||
|
||
''' | ||
import numpy as np | ||
import pandas as pd | ||
from sklearn.utils.fixes import bincount | ||
from sklearn.model_selection import train_test_split | ||
import logging | ||
logger = logging.getLogger('mljar') | ||
def validation_split(train, validation_train_split, stratify, shuffle, random_seed): | ||
if shuffle: | ||
else: | ||
if stratify is None: | ||
train, validation = data_split(validation_train_split, train) | ||
else: | ||
train, validation = data_split_stratified(validation_train_split, train, stratify) | ||
return train, validation | ||
''' |
Oops, something went wrong.