-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathbert.py
115 lines (85 loc) · 4.32 KB
/
bert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import logging
from os.path import join, exists
from arekit.common.data.input.writers.tsv import TsvWriter
from arekit.common.data.row_ids.multiple import MultipleIDProvider
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.common.data.views.samples import BaseSampleStorageView
from arekit.common.experiment.api.io_utils import BaseIOUtils
from arekit.contrib.utils.data.views.opinions import BaseOpinionStorageView
from arekit.contrib.utils.model_io.utils import join_dir_with_subfolder_name, filename_template
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# NOTE: this configures the root logger at INFO level as an import-time side effect.
logging.basicConfig(level=logging.INFO)
class DefaultBertIOUtils(BaseIOUtils):
    """ This is a default file-based Input-output utils,
        which describes file-paths towards the resources, required
        for BERT-related data preparation.

        Subclasses must implement `_get_experiment_sources_dir`.
        NOTE(review): the methods below read `self._exp_ctx`, which is
        presumably initialized by `BaseIOUtils` -- confirm against the base class.
    """

    def _get_experiment_sources_dir(self):
        """ Provides directory for samples.
            This base version always raises NotImplementedError,
            so it has to be overridden.
        """
        raise NotImplementedError()

    def check_targets_existed(self):
        """ Checks whether the directories for serialized data are present.

            Returns:
                True when both the target (model) directory and the
                experiment directory exist, and False otherwise
                (the missing directory is reported via the logger).
        """
        model_dir = self.__get_target_dir()
        if not exists(model_dir):
            logger.info("Model dir does not exist. Skipping")
            return False

        exp_dir = self.__get_experiment_dir()
        if not exists(exp_dir):
            logger.info("Experiment dir: %s", exp_dir)
            logger.info("Experiment dir does not exist. Skipping")
            return False

        # Explicit positive answer: a bare `return` would yield None,
        # which is falsy and indistinguishable from "targets are missing".
        return True

    def get_target_dir(self):
        """ Public accessor for the target directory path.
        """
        return self.__get_target_dir()

    # region experiment dir related

    def __get_experiment_dir(self):
        """ Joins the experiment sources directory with the
            experiment folder name.
        """
        return join_dir_with_subfolder_name(
            subfolder_name=self.__get_experiment_folder_name(),
            dir=self._get_experiment_sources_dir())

    def __get_target_dir(self):
        """ Provides a main directory for input

            NOTE:
            We consider to save serialized results into model dir,
            rather than experiment dir in a base implementation,
            as model affects on text_b, entities representation, etc.
        """
        return join(self.__get_experiment_dir(),
                    self._exp_ctx.ModelIO.get_model_name())

    def __get_experiment_folder_name(self):
        """ Composes the folder name as "<experiment name>_<labels count>l".
        """
        return "{name}_{scale}l".format(name=self._exp_ctx.Name,
                                        scale=str(self._exp_ctx.LabelsCount))

    # endregion

    # region public methods

    def create_samples_view(self, data_type, data_folding):
        """ Creates a samples view over the rows storage read from
            the sample filepath of the given data type and folding.
        """
        filepath = self.__get_input_sample_filepath(
            data_type=data_type, data_folding=data_folding)
        return BaseSampleStorageView(
            storage=BaseRowsStorage.from_tsv(filepath=filepath),
            row_ids_provider=MultipleIDProvider())

    def create_opinions_view(self, target):
        """ Creates an opinions view over the rows storage read from `target`.
            The `compression='infer'` argument is forwarded to the storage reader.
        """
        storage = BaseRowsStorage.from_tsv(filepath=target, compression='infer')
        return BaseOpinionStorageView(storage=storage)

    def create_opinions_writer_target(self, data_type, data_folding):
        """ Provides the filepath the opinions writer should write to.
        """
        return self.__get_input_opinions_filepath(data_type, data_folding=data_folding)

    def create_samples_writer_target(self, data_type, data_folding):
        """ Provides the filepath the samples writer should write to.
        """
        return self.__get_input_sample_filepath(data_type, data_folding=data_folding)

    def create_samples_writer(self):
        """ Creates a TSV writer for samples (created with write_header=True).
        """
        return TsvWriter(write_header=True)

    def create_opinions_writer(self):
        """ Creates a TSV writer for opinions (created with write_header=False).
        """
        return TsvWriter(write_header=False)

    # endregion

    # region private methods (filepaths)

    def __get_input_opinions_filepath(self, data_type, data_folding):
        """ Builds the "opinion"-prefixed filepath inside the target directory.
        """
        template = filename_template(data_type=data_type, data_folding=data_folding)
        return self.__get_filepath(out_dir=self.__get_target_dir(), template=template, prefix="opinion")

    def __get_input_sample_filepath(self, data_type, data_folding):
        """ Builds the "sample"-prefixed filepath inside the target directory.
        """
        template = filename_template(data_type=data_type, data_folding=data_folding)
        return self.__get_filepath(out_dir=self.__get_target_dir(), template=template, prefix="sample")

    @staticmethod
    def __get_filepath(out_dir, template, prefix):
        """ Joins `out_dir` with the archive filename generated
            from `prefix` and `template` (both must be strings).
        """
        assert isinstance(template, str)
        assert isinstance(prefix, str)
        return join(out_dir, DefaultBertIOUtils.__generate_tsv_archive_filename(template=template, prefix=prefix))

    @staticmethod
    def __generate_tsv_archive_filename(template, prefix):
        """ Formats the filename as "<prefix>-<template>.tsv.gz".
        """
        return "{prefix}-{template}.tsv.gz".format(prefix=prefix, template=template)

    # endregion