Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ NiaAML
.. image:: https://coveralls.io/repos/github/lukapecnik/NiaAML/badge.svg?branch=travisCI_integration
:target: https://coveralls.io/github/lukapecnik/NiaAML?branch=travisCI_integration

.. image:: https://img.shields.io/pypi/v/niaaml.svg
:target: https://pypi.python.org/pypi/niaaml

.. image:: https://img.shields.io/pypi/pyversions/niaaml.svg
:target: https://pypi.org/project/NiaPy/

Expand Down
10 changes: 6 additions & 4 deletions niaaml/tests/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from unittest import TestCase
from niaaml import Pipeline, OptimizationStats
from niaaml.classifiers import Bagging, AdaBoost
from niaaml.classifiers import RandomForest, AdaBoost
from niaaml.preprocessing.feature_selection import SelectKBest, SelectPercentile
from niaaml.preprocessing.feature_transform import StandardScaler, Normalizer
from niaaml.data import CSVDataReader
Expand All @@ -13,13 +13,13 @@ def setUp(self):
self.__pipeline = Pipeline(
feature_selection_algorithm=SelectKBest(),
feature_transform_algorithm=Normalizer(),
classifier=Bagging()
classifier=RandomForest()
)

def test_pipeline_optimize_works_fine(self):
data_reader = CSVDataReader(src=os.path.dirname(os.path.abspath(__file__)) + '/tests_files/dataset_header_classes.csv', has_header=True, contains_classes=True)

self.assertIsInstance(self.__pipeline.get_classifier(), Bagging)
self.assertIsInstance(self.__pipeline.get_classifier(), RandomForest)
self.assertIsInstance(self.__pipeline.get_feature_selection_algorithm(), SelectKBest)
self.assertIsInstance(self.__pipeline.get_feature_transform_algorithm(), Normalizer)

Expand All @@ -28,12 +28,14 @@ def test_pipeline_optimize_works_fine(self):
self.assertGreaterEqual(accuracy, -1.0)
self.assertLessEqual(accuracy, 0.0)

self.assertIsInstance(self.__pipeline.get_classifier(), Bagging)
self.assertIsInstance(self.__pipeline.get_classifier(), RandomForest)
self.assertIsInstance(self.__pipeline.get_feature_selection_algorithm(), SelectKBest)
self.assertIsInstance(self.__pipeline.get_feature_transform_algorithm(), Normalizer)

def test_pipeline_run_works_fine(self):
data_reader = CSVDataReader(src=os.path.dirname(os.path.abspath(__file__)) + '/tests_files/dataset_header_classes.csv', has_header=True, contains_classes=True)
print(data_reader.get_x())
print(data_reader.get_y())
self.__pipeline.optimize(data_reader.get_x(), data_reader.get_y(), 20, 40, 'ParticleSwarmAlgorithm', 'Accuracy')
predicted = self.__pipeline.run(numpy.random.uniform(low=0.0, high=15.0, size=(30, data_reader.get_x().shape[1])))

Expand Down