Skip to content

Commit

Permalink
#518 refactored, #489 related
Browse files: browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Aug 19, 2023
1 parent 0b5a044 commit 5457ec9
Show file tree
Hide file tree
Showing 10 changed files with 26 additions and 128 deletions.
32 changes: 0 additions & 32 deletions arekit/common/folding/fixed.py

This file was deleted.

12 changes: 0 additions & 12 deletions arekit/common/folding/nofold.py

This file was deleted.

6 changes: 1 addition & 5 deletions arekit/contrib/utils/io_utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from os.path import join, exists

from arekit.common.experiment.data_type import DataType
from arekit.common.folding.base import BaseDataFolding
from arekit.contrib.utils.utils_folding import experiment_iter_index


logger = logging.getLogger(__name__)
Expand All @@ -23,9 +21,7 @@ def join_dir_with_subfolder_name(subfolder_name, dir):

def filename_template(data_type, data_folding):
assert(isinstance(data_type, DataType))
assert(isinstance(data_folding, BaseDataFolding))
return "{data_type}-{iter_index}".format(data_type=data_type.name.lower(),
iter_index=experiment_iter_index(data_folding))
return "{data_type}-0".format(data_type=data_type.name.lower())


def check_targets_existence(targets):
Expand Down
31 changes: 21 additions & 10 deletions arekit/contrib/utils/pipelines/items/sampling/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.common.experiment.api.base_samples_io import BaseSamplesIO
from arekit.common.experiment.data_type import DataType
from arekit.common.folding.base import BaseDataFolding
from arekit.common.folding.nofold import NoFolding
from arekit.common.pipeline.base import BasePipeline
from arekit.common.pipeline.context import PipelineContext
from arekit.common.pipeline.items.base import BasePipelineItem
Expand Down Expand Up @@ -33,8 +31,11 @@ def __init__(self, rows_provider, samples_io, save_labels_func, storage):
self._storage = storage

def _serialize_iteration(self, data_type, pipeline, data_folding, doc_ids):
assert (isinstance(data_type, DataType))
assert (isinstance(pipeline, BasePipeline))
assert(isinstance(data_type, DataType))
assert(isinstance(pipeline, BasePipeline))
assert(isinstance(data_folding, dict) or data_folding is None)
assert(isinstance(doc_ids, list) or doc_ids is None)
assert(doc_ids is not None or data_folding is not None)

repos = {
"sample": InputDataSerializationHelper.create_samples_repo(
Expand All @@ -50,10 +51,21 @@ def _serialize_iteration(self, data_type, pipeline, data_folding, doc_ids):
}

for description, repo in repos.items():

if data_folding is None:
# Consider only the predefined doc_ids.
doc_ids_iter = doc_ids
else:
# Take particular data_type.
doc_ids_iter = data_folding[data_type]
# Consider only predefined doc_ids.
if doc_ids is not None:
doc_ids_iter = set(doc_ids_iter).intersection(doc_ids)

InputDataSerializationHelper.fill_and_write(
repo=repo,
pipeline=pipeline,
doc_ids_iter=data_folding.fold_doc_ids_set(doc_ids=doc_ids)[data_type],
doc_ids_iter=doc_ids_iter,
desc="{desc} [{data_type}]".format(desc=description, data_type=data_type),
writer=writer_and_targets[description][0],
target=writer_and_targets[description][1])
Expand All @@ -62,7 +74,6 @@ def _handle_iteration(self, data_type_pipelines, data_folding, doc_ids):
""" Performing data serialization for a particular iteration
"""
assert(isinstance(data_type_pipelines, dict))
assert(isinstance(data_folding, BaseDataFolding))
for data_type, pipeline in data_type_pipelines.items():
self._serialize_iteration(data_type=data_type, pipeline=pipeline, data_folding=data_folding,
doc_ids=doc_ids)
Expand All @@ -75,17 +86,17 @@ def apply_core(self, input_data, pipeline_ctx):
DataType.Test: BasePipeline
}
pipeline: doc_id -> parsed_doc -> annot -> opinion linkages
data_type_pipelines: doc_id -> parsed_doc -> annot -> opinion linkages
for example, function: sentiment_attitude_extraction_default_pipeline
doc_ids: optional
this parameter allows limiting the number of documents considered for sampling
"""
assert (isinstance(pipeline_ctx, PipelineContext))
assert ("data_type_pipelines" in pipeline_ctx)
assert ("doc_ids" in pipeline_ctx)

data_folding = pipeline_ctx.provide_or_none("data_folding")
data_folding = NoFolding() if data_folding is None else data_folding

for _ in folding_iter_states(data_folding):
self._handle_iteration(data_type_pipelines=pipeline_ctx.provide("data_type_pipelines"),
doc_ids=pipeline_ctx.provide("doc_ids"),
doc_ids=pipeline_ctx.provide_or_none("doc_ids"),
data_folding=data_folding)
1 change: 0 additions & 1 deletion arekit/contrib/utils/pipelines/items/sampling/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def _handle_iteration(self, data_type_pipelines, data_folding, doc_ids):
""" Performing data serialization for a particular iteration
"""
assert(isinstance(data_type_pipelines, dict))
assert(isinstance(data_folding, BaseDataFolding))

# Prepare for the present iteration.
self._rows_provider.clear_embedding_pairs()
Expand Down
5 changes: 1 addition & 4 deletions tests/contrib/utils/test_csv_stream_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from arekit.common.data.input.providers.rows.samples import BaseSampleRowProvider
from arekit.common.data.input.providers.text.single import BaseSingleTextProvider
from arekit.common.experiment.data_type import DataType
from arekit.common.folding.nofold import NoFolding
from arekit.common.pipeline.base import BasePipeline
from arekit.common.text.parser import BaseTextParser
from arekit.contrib.bert.input.providers.text_pair import PairTextProvider
Expand Down Expand Up @@ -65,7 +64,6 @@ def __launch(self, writer, target_extention):
#####
# Declaring pipeline related context parameters.
#####
no_folding = NoFolding()
doc_provider = FooDocumentProvider()
text_parser = BaseTextParser(pipeline=[BratTextEntitiesParser(), DefaultTextTokenizer(keep_tokens=True)])
train_pipeline = text_opinion_extraction_pipeline(
Expand All @@ -84,9 +82,8 @@ def __launch(self, writer, target_extention):

pipeline.run(input_data=None,
params_dict={
"data_folding": no_folding,
"data_type_pipelines": {DataType.Train: train_pipeline},
"doc_ids": {DataType.Train: [0, 1]}
"data_folding": {DataType.Train: [0, 1]}
})

def test_csv_native(self):
Expand Down
55 changes: 0 additions & 55 deletions tests/tutorials/test_tutorial_data_foldings.py

This file was deleted.

4 changes: 1 addition & 3 deletions tests/tutorials/test_tutorial_pipeline_sampling_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from arekit.common.entities.str_fmt import StringEntitiesFormatter
from arekit.common.entities.types import OpinionEntityType
from arekit.common.experiment.data_type import DataType
from arekit.common.folding.nofold import NoFolding
from arekit.common.labels.base import NoLabel, Label
from arekit.common.labels.scaler.base import BaseLabelScaler
from arekit.common.labels.str_fmt import StringLabelsFormatter
Expand Down Expand Up @@ -122,9 +121,8 @@ def test(self):

pipeline.run(input_data=None,
params_dict={
"data_folding": NoFolding(),
"data_type_pipelines": {DataType.Train: train_pipeline},
"doc_ids": {DataType.Train: [0, 1]}
"data_folding": {DataType.Train: [0, 1]}
})

reader = PandasCsvReader()
Expand Down
4 changes: 1 addition & 3 deletions tests/tutorials/test_tutorial_pipeline_sampling_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from os.path import dirname, join

from arekit.common.experiment.data_type import DataType
from arekit.common.folding.nofold import NoFolding
from arekit.common.frames.variants.collection import FrameVariantsCollection
from arekit.common.labels.base import Label, NoLabel
from arekit.common.labels.scaler.sentiment import SentimentLabelScaler
Expand Down Expand Up @@ -126,9 +125,8 @@ def test(self):

pipeline.run(input_data=None,
params_dict={
"data_folding": NoFolding(),
"data_type_pipelines": {DataType.Train: train_pipeline},
"doc_ids": {DataType.Train: [0, 1]}
"data_folding": {DataType.Train: [0, 1]}
})

reader = PandasCsvReader()
Expand Down
4 changes: 1 addition & 3 deletions tests/tutorials/test_tutorial_pipeline_sampling_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from arekit.common.entities.str_fmt import StringEntitiesFormatter
from arekit.common.entities.types import OpinionEntityType
from arekit.common.experiment.data_type import DataType
from arekit.common.folding.nofold import NoFolding
from arekit.common.labels.base import NoLabel, Label
from arekit.common.labels.scaler.base import BaseLabelScaler
from arekit.common.labels.str_fmt import StringLabelsFormatter
Expand Down Expand Up @@ -123,9 +122,8 @@ def test(self):

pipeline.run(input_data=None,
params_dict={
"data_folding": NoFolding(),
"data_type_pipelines": {DataType.Train: train_pipeline},
"doc_ids": {DataType.Train: [0, 1]}
"data_folding": {DataType.Train: [0, 1]}
})

reader = PandasCsvReader()
Expand Down

0 comments on commit 5457ec9

Please sign in to comment.