Skip to content

Commit

Permalink
#528: simplified external API for pipeline.run. Now we need to provide parameters only via `input_data`.
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Oct 23, 2023
1 parent 012f6de commit 3e3af5b
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 27 deletions.
10 changes: 5 additions & 5 deletions arekit/contrib/utils/pipelines/items/sampling/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,11 @@ def apply_core(self, input_data, pipeline_ctx):
doc_ids: optional
this parameter allows limiting the number of documents considered for sampling
"""
assert (isinstance(pipeline_ctx, PipelineContext))
assert ("data_type_pipelines" in pipeline_ctx)
assert(isinstance(input_data, PipelineContext))
assert("data_type_pipelines" in input_data)

data_folding = pipeline_ctx.provide_or_none("data_folding")
data_folding = input_data.provide_or_none("data_folding")

self._handle_iteration(data_type_pipelines=pipeline_ctx.provide("data_type_pipelines"),
doc_ids=pipeline_ctx.provide_or_none("doc_ids"),
self._handle_iteration(data_type_pipelines=input_data.provide("data_type_pipelines"),
doc_ids=input_data.provide_or_none("doc_ids"),
data_folding=data_folding)
10 changes: 5 additions & 5 deletions tests/contrib/utils/test_csv_stream_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from arekit.common.data.input.providers.text.single import BaseSingleTextProvider
from arekit.common.experiment.data_type import DataType
from arekit.common.pipeline.base import BasePipeline
from arekit.common.pipeline.context import PipelineContext
from arekit.common.text.parser import BaseTextParser
from arekit.contrib.bert.input.providers.text_pair import PairTextProvider
from arekit.contrib.bert.terms.mapper import BertDefaultStringTextTermsMapper
Expand Down Expand Up @@ -80,11 +81,10 @@ def __launch(self, writer):
text_parser=text_parser)
#####

pipeline.run(input_data=None,
params_dict={
"data_type_pipelines": {DataType.Train: train_pipeline},
"data_folding": {DataType.Train: [0, 1]}
})
pipeline.run(input_data=PipelineContext(d={
"data_type_pipelines": {DataType.Train: train_pipeline},
"data_folding": {DataType.Train: [0, 1]}
}))

def test_csv_native(self):
""" Testing writing into CSV format
Expand Down
10 changes: 5 additions & 5 deletions tests/tutorials/test_tutorial_pipeline_sampling_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from arekit.common.labels.scaler.base import BaseLabelScaler
from arekit.common.labels.str_fmt import StringLabelsFormatter
from arekit.common.pipeline.base import BasePipeline
from arekit.common.pipeline.context import PipelineContext
from arekit.common.text.parser import BaseTextParser
from arekit.contrib.bert.input.providers.text_pair import PairTextProvider
from arekit.contrib.bert.terms.mapper import BertDefaultStringTextTermsMapper
Expand Down Expand Up @@ -120,11 +121,10 @@ def test(self):
text_parser=text_parser)
#####

pipeline.run(input_data=None,
params_dict={
"data_type_pipelines": {DataType.Train: train_pipeline},
"data_folding": {DataType.Train: [0, 1]}
})
pipeline.run(input_data=PipelineContext(d={
"data_type_pipelines": {DataType.Train: train_pipeline},
"data_folding": {DataType.Train: [0, 1]}
}))

reader = PandasCsvReader()
source = join(self.__output_dir, "sample-train-0" + writer.extension())
Expand Down
10 changes: 5 additions & 5 deletions tests/tutorials/test_tutorial_pipeline_sampling_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from arekit.common.labels.scaler.sentiment import SentimentLabelScaler
from arekit.common.labels.str_fmt import StringLabelsFormatter
from arekit.common.pipeline.base import BasePipeline
from arekit.common.pipeline.context import PipelineContext
from arekit.common.text.parser import BaseTextParser
from arekit.contrib.networks.input.ctx_serialization import NetworkSerializationContext
from arekit.contrib.source.brat.entities.parser import BratTextEntitiesParser
Expand Down Expand Up @@ -124,11 +125,10 @@ def test(self):
text_parser=text_parser)
#####

pipeline.run(input_data=None,
params_dict={
"data_type_pipelines": {DataType.Train: train_pipeline},
"data_folding": {DataType.Train: [0, 1]}
})
pipeline.run(input_data=PipelineContext(d={
"data_type_pipelines": {DataType.Train: train_pipeline},
"data_folding": {DataType.Train: [0, 1]}
}))

reader = PandasCsvReader()
source = join(self.__output_dir, "sample-train-0" + writer.extension())
Expand Down
10 changes: 5 additions & 5 deletions tests/tutorials/test_tutorial_pipeline_sampling_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from arekit.common.labels.scaler.base import BaseLabelScaler
from arekit.common.labels.str_fmt import StringLabelsFormatter
from arekit.common.pipeline.base import BasePipeline
from arekit.common.pipeline.context import PipelineContext
from arekit.common.text.parser import BaseTextParser
from arekit.contrib.bert.terms.mapper import BertDefaultStringTextTermsMapper
from arekit.contrib.prompt.sample import PromptedSampleRowProvider
Expand Down Expand Up @@ -121,11 +122,10 @@ def test(self):
text_parser=text_parser)
#####

pipeline.run(input_data=None,
params_dict={
"data_type_pipelines": {DataType.Train: train_pipeline},
"data_folding": {DataType.Train: [0, 1]}
})
pipeline.run(input_data=PipelineContext(d={
"data_type_pipelines": {DataType.Train: train_pipeline},
"data_folding": {DataType.Train: [0, 1]}
}))

reader = PandasCsvReader()
source = join(self.__output_dir, "prompt-sample-train-0" + writer.extension())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from arekit.common.labels.base import Label, NoLabel
from arekit.common.labels.provider.constant import ConstantLabelProvider
from arekit.common.labels.str_fmt import StringLabelsFormatter
from arekit.common.linkage.meta import MetaEmptyLinkedDataWrapper
from arekit.common.linkage.text_opinions import TextOpinionsLinkage
from arekit.common.docs.parsed.providers.entity_service import EntityServiceProvider, EntityEndType
from arekit.common.docs.parsed.service import ParsedDocumentService
Expand Down Expand Up @@ -85,8 +86,11 @@ def test(self):
text_parser=text_parser)

# Running the pipeline.
for linked in pipeline.run(input_data=[0], params_dict={}):
assert(isinstance(linked, TextOpinionsLinkage))
for linked in pipeline.run(input_data=[0]):
assert(isinstance(linked, TextOpinionsLinkage) or isinstance(linked, MetaEmptyLinkedDataWrapper))

if isinstance(linked, MetaEmptyLinkedDataWrapper):
continue

pns = linked.Tag
assert(isinstance(pns, ParsedDocumentService))
Expand Down

0 comments on commit 3e3af5b

Please sign in to comment.