-
Notifications
You must be signed in to change notification settings - Fork 2
/
serialize_bert.py
116 lines (98 loc) · 5.75 KB
/
serialize_bert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import argparse
from os.path import join, dirname, basename
from arekit.common.experiment.data_type import DataType
from arekit.common.folding.nofold import NoFolding
from arekit.common.labels.base import NoLabel
from arekit.common.labels.provider.constant import ConstantLabelProvider
from arekit.common.labels.scaler.single import SingleLabelScaler
from arekit.common.labels.str_fmt import StringLabelsFormatter
from arekit.common.news.entities_grouping import EntitiesGroupingPipelineItem
from arekit.common.opinions.annot.algo.pair_based import PairBasedOpinionAnnotationAlgorithm
from arekit.common.opinions.annot.base import BaseOpinionAnnotator
from arekit.common.pipeline.base import BasePipeline
from arekit.common.synonyms.grouping import SynonymsCollectionValuesGroupingProviders
from arekit.common.text.parser import BaseTextParser
from arekit.contrib.utils.data.storages.row_cache import RowCacheStorage
from arekit.contrib.utils.data.writers.csv_native import NativeCsvWriter
from arekit.contrib.utils.io_utils.samples import SamplesIO
from arekit.contrib.utils.pipelines.items.sampling.bert import BertExperimentInputSerializerPipelineItem
from arekit.contrib.utils.pipelines.items.text.terms_splitter import TermsSplitterParser
from arelight.doc_ops import InMemoryDocOperations
from arelight.pipelines.annot_nolabel import create_neutral_annotation_pipeline
from arelight.pipelines.items.utils import input_to_docs
from arelight.samplers.bert import create_bert_sample_provider
from arelight.samplers.types import BertSampleProviderTypes
from examples.args import const, common
from examples.entities.factory import create_entity_formatter
from examples.utils import read_synonyms_collection
if __name__ == '__main__':

    # Entry point: serialize raw input texts into BERT-ready sample files
    # (native CSV), as required for later inference and training.
    parser = argparse.ArgumentParser(description="Serialization script for obtaining sources, "
                                                 "required for inference and training.")

    # Provide arguments.
    common.InputTextArg.add_argument(parser, default=None)
    common.FromFilesArg.add_argument(parser)
    common.EntitiesParserArg.add_argument(parser, default="bert-ontonotes")
    common.TermsPerContextArg.add_argument(parser, default=const.TERMS_PER_CONTEXT)
    common.EntityFormatterTypesArg.add_argument(parser, default="hidden-bert-styled")
    common.FromDataframeArg.add_argument(parser)
    common.SynonymsCollectionFilepathArg.add_argument(parser, default=join(const.DATA_DIR, "synonyms.txt"))
    common.PredictOutputFilepathArg.add_argument(parser, default=const.OUTPUT_TEMPLATE)
    common.BertTextBFormatTypeArg.add_argument(parser, default='nli_m')
    common.SentenceParserArg.add_argument(parser)

    # Parsing arguments.
    args = parser.parse_args()
    text_from_arg = common.InputTextArg.read_argument(args)
    texts_from_files = common.FromFilesArg.read_argument(args)
    texts_from_dataframe = common.FromDataframeArg.read_argument(args)
    entities_parser = common.EntitiesParserArg.read_argument(args)
    sentence_parser = common.SentenceParserArg.read_argument(args)
    entity_fmt = create_entity_formatter(common.EntityFormatterTypesArg.read_argument(args))

    # Pick the first provided input source: direct text, files, or dataframe.
    input_texts = text_from_arg if text_from_arg is not None else \
        texts_from_files if texts_from_files is not None else texts_from_dataframe

    # Convert the raw inputs into documents once, so that the document count
    # used for folding below always matches what doc_ops actually holds.
    docs = input_to_docs(input_texts, sentence_parser=sentence_parser)
    doc_ops = InMemoryDocOperations(docs=docs)

    label_scaler = SingleLabelScaler(NoLabel())
    backend_template = common.PredictOutputFilepathArg.read_argument(args)

    synonyms = read_synonyms_collection(
        filepath=common.SynonymsCollectionFilepathArg.read_argument(args))

    # Declare text parser: split terms, extract entities, then group entity
    # values via the synonyms collection (registering unseen values).
    text_parser = BaseTextParser(pipeline=[
        TermsSplitterParser(),
        entities_parser,
        EntitiesGroupingPipelineItem(lambda value:
            SynonymsCollectionValuesGroupingProviders.provide_existed_or_register_missed_value(
                synonyms=synonyms, value=value))
    ])

    terms_per_context = common.TermsPerContextArg.read_argument(args)

    # Initialize data processing pipeline.
    test_pipeline = create_neutral_annotation_pipeline(synonyms=synonyms,
                                                       dist_in_terms_bound=terms_per_context,
                                                       terms_per_context=terms_per_context,
                                                       doc_ops=doc_ops,
                                                       text_parser=text_parser,
                                                       dist_in_sentences=0)

    # NOTE(review): BertTextBFormatTypeArg is parsed above but not consulted
    # here -- the provider type is fixed to NLI_M; confirm this is intended.
    rows_provider = create_bert_sample_provider(
        label_scaler=label_scaler,
        provider_type=BertSampleProviderTypes.NLI_M,
        entity_formatter=entity_fmt)

    pipeline = BasePipeline([
        BertExperimentInputSerializerPipelineItem(
            rows_provider=rows_provider,
            storage=RowCacheStorage(),
            samples_io=SamplesIO(target_dir=dirname(backend_template),
                                 prefix=basename(backend_template),
                                 writer=NativeCsvWriter(delimiter=',')),
            save_labels_func=lambda data_type: data_type != DataType.Test,
            balance_func=lambda data_type: data_type == DataType.Train)
    ])

    # FIX: fold over the documents actually registered in doc_ops, rather than
    # len(texts_from_files), which is None (TypeError) when input comes from
    # the --text argument or a dataframe.
    no_folding = NoFolding(doc_ids=list(range(len(docs))), supported_data_type=DataType.Test)

    pipeline.run(input_data=None,
                 params_dict={
                     "data_folding": no_folding,
                     "data_type_pipelines": {DataType.Test: test_pipeline}
                 })