Skip to content

Commit

Permalink
#489 simplified API
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Jul 3, 2023
1 parent b38995e commit 2f9c150
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 29 deletions.
4 changes: 0 additions & 4 deletions arekit/common/folding/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@ class BaseDataFolding(object):
and how to convert such state into a string.
"""

def __init__(self, supported_data_types=None):
assert(isinstance(supported_data_types, list) or supported_data_types is None)
self._supported_data_types = supported_data_types

def fold_doc_ids_set(self, doc_ids):
""" Perform the doc_ids folding process onto provided data_types
"""
Expand Down
14 changes: 3 additions & 11 deletions arekit/common/folding/nofold.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import collections

from arekit.common.experiment.data_type import DataType
from arekit.common.folding.base import BaseDataFolding

Expand All @@ -8,13 +6,7 @@ class NoFolding(BaseDataFolding):
""" The case of absent folding in experiment.
"""

def __init__(self, data_type):
assert(isinstance(data_type, DataType))
super(NoFolding, self).__init__(supported_data_types=[data_type])
self.__data_type = data_type

def fold_doc_ids_set(self, doc_ids):
assert(isinstance(doc_ids, collections.Iterable))
return {
self.__data_type: list(set(doc_ids))
}
assert(isinstance(doc_ids, dict) and len(doc_ids) == 1)
assert(isinstance(list(doc_ids.keys())[0], DataType))
return doc_ids
6 changes: 4 additions & 2 deletions arekit/contrib/utils/cv/two_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@ class TwoClassCVFolding(BaseDataFolding):
"""

def __init__(self, supported_data_types, cv_count, splitter):
assert(isinstance(supported_data_types, list))
assert(isinstance(splitter, CrossValidationSplitter))
assert(isinstance(cv_count, int) and cv_count > 0)

if len(supported_data_types) > 2:
raise NotImplementedError("Experiments with such amount of data-types are not supported!")

super(TwoClassCVFolding, self).__init__(supported_data_types=supported_data_types)
super(TwoClassCVFolding, self).__init__()
self._supported_data_types = supported_data_types

self.__cv_count = cv_count
self.__splitter = splitter
Expand Down Expand Up @@ -59,7 +61,7 @@ def fold_doc_ids_set(self, doc_ids):
}

if self.__splitter is None:
raise NotImplementedError("Splitter has not been intialized!")
raise NotImplementedError("Splitter has not been initialized!")

it = self.__splitter.items_to_cv_pairs(doc_ids=set(doc_ids),
cv_count=self.__cv_count)
Expand Down
4 changes: 2 additions & 2 deletions tests/contrib/utils/test_csv_stream_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __launch(self, writer, target_extention):
#####
# Declaring pipeline related context parameters.
#####
no_folding = NoFolding(data_type=DataType.Train)
no_folding = NoFolding()
doc_provider = FooDocumentProvider()
text_parser = BaseTextParser(pipeline=[BratTextEntitiesParser(), DefaultTextTokenizer(keep_tokens=True)])
train_pipeline = text_opinion_extraction_pipeline(
Expand All @@ -86,7 +86,7 @@ def __launch(self, writer, target_extention):
params_dict={
"data_folding": no_folding,
"data_type_pipelines": {DataType.Train: train_pipeline},
"doc_ids": [0, 1]
"doc_ids": {DataType.Train: [0, 1]}
})

def test_csv_native(self):
Expand Down
6 changes: 2 additions & 4 deletions tests/tutorials/test_tutorial_data_foldings.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,11 @@ def test(self):
DataType.Test: [4, 5, 6, 7]
}

fixed_folding = FixedFolding()
print("Fixed folding:")
self.show_folding(fixed_folding, doc_ids=parts)
self.show_folding(FixedFolding(), doc_ids=parts)

no_folding = NoFolding(data_type=DataType.Train)
print("No folding:")
self.show_folding(no_folding, doc_ids=parts[DataType.Train])
self.show_folding(NoFolding(), doc_ids={DataType.Train: parts[DataType.Train]})

splitter_simple = SimpleCrossValidationSplitter(shuffle=True, seed=1)

Expand Down
5 changes: 2 additions & 3 deletions tests/tutorials/test_tutorial_pipeline_sampling_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ def test(self):
#####
# Declaring pipeline related context parameters.
#####
no_folding = NoFolding(data_type=DataType.Train)
doc_provider = FooDocumentProvider()
text_parser = BaseTextParser(pipeline=[BratTextEntitiesParser(), DefaultTextTokenizer(keep_tokens=True)])
train_pipeline = text_opinion_extraction_pipeline(
Expand All @@ -123,9 +122,9 @@ def test(self):

pipeline.run(input_data=None,
params_dict={
"data_folding": no_folding,
"data_folding": NoFolding(),
"data_type_pipelines": {DataType.Train: train_pipeline},
"doc_ids": [0, 1]
"doc_ids": {DataType.Train: [0, 1]}
})

reader = PandasCsvReader()
Expand Down
5 changes: 2 additions & 3 deletions tests/tutorials/test_tutorial_pipeline_sampling_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ def test(self):
#####
# Declaring pipeline related context parameters.
#####
no_folding = NoFolding(data_type=DataType.Train)
doc_provider = FooDocumentProvider()
text_parser = BaseTextParser(pipeline=[
BratTextEntitiesParser(),
Expand All @@ -127,9 +126,9 @@ def test(self):

pipeline.run(input_data=None,
params_dict={
"data_folding": no_folding,
"data_folding": NoFolding(),
"data_type_pipelines": {DataType.Train: train_pipeline},
"doc_ids": [0, 1]
"doc_ids": {DataType.Train: [0, 1]}
})

reader = PandasCsvReader()
Expand Down

0 comments on commit 2f9c150

Please sign in to comment.