#496 done (version 0.24.0) #156 related, added unit test and prompting
nicolay-r committed Jul 16, 2023
1 parent 5ad9545 commit f047f03
Showing 8 changed files with 158 additions and 10 deletions.
2 changes: 1 addition & 1 deletion arekit/common/data/const.py
@@ -1,7 +1,7 @@
 ID = 'id'
 DOC_ID = 'doc_id'
 TEXT = 'text_a'
-LABEL = 'label'
+LABEL_UINT = 'label'

 # Global identifier of the opinion in the sampled data.
 OPINION_ID = "opinion_id"
2 changes: 1 addition & 1 deletion arekit/common/data/input/providers/columns/sample.py
@@ -37,7 +37,7 @@ def get_columns_list_with_types(self):

         # insert labels
         if self.__store_labels:
-            dtypes_list.append((const.LABEL, 'int32'))
+            dtypes_list.append((const.LABEL_UINT, 'int32'))

         # insert text columns
         for col_name in self.__text_column_names:
2 changes: 1 addition & 1 deletion arekit/common/data/input/providers/rows/samples.py
@@ -75,7 +75,7 @@ def __assign_value(column, value):
         expected_label = text_opinion_linkage.get_linked_label()

         if self.__store_labels:
-            row[const.LABEL] = self._label_provider.calculate_output_uint_label(
+            row[const.LABEL_UINT] = self._label_provider.calculate_output_uint_label(
                 expected_uint_label=self._label_provider.LabelScaler.label_to_uint(expected_label),
                 etalon_uint_label=self._label_provider.LabelScaler.label_to_uint(etalon_label))
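For context: the value written into the `label_uint` column is the label scaled to an unsigned integer. A minimal, dictionary-based sketch of that idea for a three-class sentiment scheme, matching the scaler defined in the new test below (an illustration only, not the AREkit `BaseLabelScaler` API):

# Hypothetical, dictionary-based illustration of label scaling (not the AREkit API).
# Signed ints (-1, 0, 1) are convenient for polarity arithmetic, while the unsigned
# ints (0, 1, 2) are what lands in the label_uint column of the sampled data.
LABEL_TO_INT = {"neutral": 0, "positive": 1, "negative": -1}
LABEL_TO_UINT = {"neutral": 0, "positive": 1, "negative": 2}


def label_to_uint(label_name):
    """Toy counterpart of LabelScaler.label_to_uint for this mapping."""
    return LABEL_TO_UINT[label_name]


def uint_to_label(uint_value):
    """Reverse lookup, as used when a prompt is rendered from label_uint."""
    return {v: k for k, v in LABEL_TO_UINT.items()}[uint_value]


assert label_to_uint("negative") == 2
assert uint_to_label(2) == "negative"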

2 changes: 1 addition & 1 deletion arekit/contrib/networks/input/rows_parser.py
@@ -67,7 +67,7 @@ def __init__(self, row):

         for key, value in row.items():

-            if key == const.LABEL:
+            if key == const.LABEL_UINT:
                 self.__uint_label = int(value)
             # TODO: To be adopted in future instead of __uint_label
             self.__params[key] = value
22 changes: 18 additions & 4 deletions arekit/contrib/prompt/sample.py
@@ -1,28 +1,31 @@
 from arekit.common.data import const
 from arekit.common.data.input.providers.sample.cropped import CroppedSampleRowProvider
 from arekit.common.data.input.providers.text.single import BaseSingleTextProvider
+from arekit.common.labels.str_fmt import StringLabelsFormatter


 class PromptedSampleRowProvider(CroppedSampleRowProvider):
     """ Sample, enriched with the prompt technique.
     """

-    def __init__(self, crop_window_size, label_scaler, text_provider, prompt):
+    def __init__(self, crop_window_size, label_scaler, text_provider, prompt, label_fmt=None):
         """ crop_window_size: int
                 crop window size for the original text.
             prompt: str
                 text which wraps the original (optionally cropped) text.
-                this string suppose to include the following parameters:
-                text, s_ind, r_ind, label (optional)
+                this string is supposed to include the following (optional) parameters:
+                text, s_ind, t_ind, s_val, t_val, label_uint, label_val
         """
         assert(isinstance(prompt, str))
         assert(isinstance(text_provider, BaseSingleTextProvider))
+        assert(isinstance(label_fmt, StringLabelsFormatter) or label_fmt is None)

         super(PromptedSampleRowProvider, self).__init__(crop_window_size=crop_window_size,
                                                         label_scaler=label_scaler,
                                                         text_provider=text_provider)

         self.__prompt = prompt
+        self.__labels_fmt = label_fmt

     def _fill_row_core(self, row, text_opinion_linkage, index_in_linked, etalon_label,
                        parsed_doc, sentence_ind, s_ind, t_ind):
@@ -36,10 +39,21 @@ def _fill_row_core(self, row, text_opinion_linkage, index_in_linked, etalon_label,
             s_ind=s_ind,
             t_ind=t_ind)
         original_text = row[BaseSingleTextProvider.TEXT_A]

+        sentence_terms, actual_s_ind, actual_t_ind = self._provide_sentence_terms(
+            parsed_doc=parsed_doc, sentence_ind=sentence_ind, s_ind=s_ind, t_ind=t_ind)
+
+        label_uint = row[const.LABEL_UINT] if const.LABEL_UINT in row else None
+        label_val = str(label_uint) if label_uint is None or self.__labels_fmt is None else \
+            self.__labels_fmt.label_to_str(self._label_provider.LabelScaler.uint_to_label(row[const.LABEL_UINT]))
+
         row[BaseSingleTextProvider.TEXT_A] = self.__prompt.format(
             text=original_text,
             s_ind=row[const.S_IND],
             t_ind=row[const.T_IND],
-            label=row[const.LABEL] if const.LABEL in row else None)
+            s_val=sentence_terms[actual_s_ind].Value,
+            t_val=sentence_terms[actual_t_ind].Value,
+            label_uint=label_uint,
+            label_val=label_val)

         return row
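To make the new placeholders concrete, below is a small, self-contained sketch of the formatting step that `_fill_row_core` now performs with `str.format`; the template text and values are invented for illustration and are not part of the commit:

# Hypothetical template and values; only the placeholder names come from the commit.
prompt = ("For the text `{text}`, the attitude of `{s_val}` (term {s_ind}) "
          "towards `{t_val}` (term {t_ind}) is {label_val}")

formatted = prompt.format(text="#S criticized #O yesterday",
                          s_ind=0, t_ind=2,
                          s_val="#S", t_val="#O",
                          label_uint=2, label_val="negative")
print(formatted)

Unused keyword arguments such as label_uint are simply ignored by str.format, which is why a prompt template may reference any subset of the documented parameters.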
2 changes: 1 addition & 1 deletion arekit/contrib/utils/data/writers/json_opennre.py
@@ -55,7 +55,7 @@ def __format_row(row, text_columns):
             "token": tokens,
             "h": {"pos": [s_ind, s_ind + 1], "id": str(bag_id + "s")},
             "t": {"pos": [t_ind, t_ind + 1], "id": str(bag_id + "t")},
-            "relation": str(int(row[const.LABEL])) if const.LABEL in row else "NA"
+            "relation": str(int(row[const.LABEL_UINT])) if const.LABEL_UINT in row else "NA"
         }

     def open_target(self, target):
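For reference, the writer above produces OpenNRE-style rows, and the `relation` field is now taken from the `label_uint` column. A hand-written example of one such row (invented values, not repository output):

# Illustrative example of one serialized row (invented values, not repository output).
example_row = {
    "token": ["#S", "criticized", "#O", "yesterday"],
    "h": {"pos": [0, 1], "id": "b0s"},  # subject position and bag-derived id
    "t": {"pos": [2, 3], "id": "b0t"},  # object position and bag-derived id
    "relation": "2"  # str(int(row[const.LABEL_UINT])), or "NA" when labels are not stored
}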
2 changes: 1 addition & 1 deletion tests/contrib/utils/test_balance.py
@@ -16,7 +16,7 @@ def test(self):

         balanced_storage = PandasBasedStorageBalancing.create_balanced_from(
             storage=reader.read(target=join(self.__output_dir, "sample-train-0.csv")),
-            column_name=const.LABEL,
+            column_name=const.LABEL_UINT,
             free_origin=True)

         print(balanced_storage.DataFrame)
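The balancing test now groups rows by the `label_uint` column. As a rough, pandas-only sketch of class balancing over a label column (an assumed illustration of the idea, not `PandasBasedStorageBalancing`'s actual implementation):

import pandas as pd


def balance_by_column(df, column_name):
    """Oversample each class so every label value occurs as often as the largest one."""
    target_size = df[column_name].value_counts().max()
    parts = [group.sample(n=target_size, replace=True, random_state=42)
             for _, group in df.groupby(column_name)]
    return pd.concat(parts).reset_index(drop=True)


df = pd.DataFrame({"label_uint": [0, 0, 0, 1, 2], "text_a": ["a", "b", "c", "d", "e"]})
balanced = balance_by_column(df, column_name="label_uint")
assert balanced["label_uint"].value_counts().nunique() == 1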
134 changes: 134 additions & 0 deletions tests/tutorials/test_tutorial_pipeline_sampling_prompt.py
@@ -0,0 +1,134 @@
import unittest
from collections import OrderedDict
from os.path import join, dirname

from arekit.common.data.input.providers.text.single import BaseSingleTextProvider
from arekit.common.entities.base import Entity
from arekit.common.entities.str_fmt import StringEntitiesFormatter
from arekit.common.entities.types import OpinionEntityType
from arekit.common.experiment.data_type import DataType
from arekit.common.folding.nofold import NoFolding
from arekit.common.labels.base import NoLabel, Label
from arekit.common.labels.scaler.base import BaseLabelScaler
from arekit.common.labels.str_fmt import StringLabelsFormatter
from arekit.common.pipeline.base import BasePipeline
from arekit.common.text.parser import BaseTextParser
from arekit.contrib.bert.terms.mapper import BertDefaultStringTextTermsMapper
from arekit.contrib.prompt.sample import PromptedSampleRowProvider
from arekit.contrib.source.brat.entities.parser import BratTextEntitiesParser
from arekit.contrib.source.rusentrel.labels_fmt import RuSentRelLabelsFormatter
from arekit.contrib.utils.data.readers.csv_pd import PandasCsvReader
from arekit.contrib.utils.data.storages.pandas_based import PandasBasedRowsStorage
from arekit.contrib.utils.data.writers.csv_pd import PandasCsvWriter
from arekit.contrib.utils.io_utils.samples import SamplesIO
from arekit.contrib.utils.pipelines.items.sampling.base import BaseSerializerPipelineItem
from arekit.contrib.utils.pipelines.items.text.tokenizer import DefaultTextTokenizer
from arekit.contrib.utils.pipelines.text_opinion.annot.predefined import PredefinedTextOpinionAnnotator
from arekit.contrib.utils.pipelines.text_opinion.extraction import text_opinion_extraction_pipeline
from arekit.contrib.utils.pipelines.text_opinion.filters.distance_based import DistanceLimitedTextOpinionFilter
from tests.tutorials.test_tutorial_pipeline_text_opinion_annotation import FooDocumentProvider


class Positive(Label):
    pass


class Negative(Label):
    pass


class SentimentLabelScaler(BaseLabelScaler):

    def __init__(self):
        int_to_label = OrderedDict([(NoLabel(), 0), (Positive(), 1), (Negative(), -1)])
        uint_to_label = OrderedDict([(NoLabel(), 0), (Positive(), 1), (Negative(), 2)])
        super(SentimentLabelScaler, self).__init__(int_dict=int_to_label,
                                                   uint_dict=uint_to_label)


class CustomLabelsFormatter(StringLabelsFormatter):

    def __init__(self, pos_label_type, neg_label_type):
        stol = {"POSITIVE_TO": neg_label_type, "NEGATIVE_TO": pos_label_type}
        super(CustomLabelsFormatter, self).__init__(stol=stol)


class CustomEntitiesFormatter(StringEntitiesFormatter):

    def __init__(self, subject_fmt="[subject]", object_fmt="[object]"):
        self.__subj_fmt = subject_fmt
        self.__obj_fmt = object_fmt

    def to_string(self, original_value, entity_type):
        assert(isinstance(original_value, Entity))
        if entity_type == OpinionEntityType.Other:
            return original_value.Value
        elif entity_type == OpinionEntityType.Object or entity_type == OpinionEntityType.SynonymObject:
            return self.__obj_fmt
        elif entity_type == OpinionEntityType.Subject or entity_type == OpinionEntityType.SynonymSubject:
            return self.__subj_fmt
        return None


class TestPromptSerialization(unittest.TestCase):

    __output_dir = join(dirname(__file__), "out")

    def test(self):
        terms_mapper = BertDefaultStringTextTermsMapper(
            entity_formatter=CustomEntitiesFormatter(subject_fmt="#S", object_fmt="#O"))

        text_provider = BaseSingleTextProvider(terms_mapper)

        # The Russian prompt template roughly reads: "For the text `{text}`, the relation
        # in it between the object `{t_val}` (word `{t_ind}`) and the subject `{s_val}`
        # (word `{s_ind}`) has the label {label_val}".
        rows_provider = PromptedSampleRowProvider(
            crop_window_size=1000,
            label_scaler=SentimentLabelScaler(),
            prompt="Для текста: `{text}` отношение в нём между " +
                   "объектом `{t_val}` (слово `{t_ind}`) " +
                   "и субъектом `{s_val}` (слово `{s_ind}`) " +
                   "имеет оценку {label_val}",
            label_fmt=RuSentRelLabelsFormatter(pos_label_type=Positive, neg_label_type=Negative),
            text_provider=text_provider)

        writer = PandasCsvWriter(write_header=True)
        samples_io = SamplesIO(self.__output_dir, writer, prefix="prompt-sample", target_extension=".tsv.gz")

        pipeline_item = BaseSerializerPipelineItem(
            rows_provider=rows_provider,
            samples_io=samples_io,
            save_labels_func=lambda data_type: True,
            storage=PandasBasedRowsStorage())

        pipeline = BasePipeline([
            pipeline_item
        ])

        #####
        # Declaring pipeline related context parameters.
        #####
        doc_provider = FooDocumentProvider()
        text_parser = BaseTextParser(pipeline=[BratTextEntitiesParser(), DefaultTextTokenizer(keep_tokens=True)])
        train_pipeline = text_opinion_extraction_pipeline(
            annotators=[
                PredefinedTextOpinionAnnotator(doc_provider,
                                               label_formatter=CustomLabelsFormatter(
                                                   pos_label_type=Positive,
                                                   neg_label_type=Negative))
            ],
            text_opinion_filters=[
                DistanceLimitedTextOpinionFilter(terms_per_context=50)
            ],
            get_doc_by_id_func=doc_provider.by_id,
            text_parser=text_parser)
        #####

        pipeline.run(input_data=None,
                     params_dict={
                         "data_folding": NoFolding(),
                         "data_type_pipelines": {DataType.Train: train_pipeline},
                         "doc_ids": {DataType.Train: [0, 1]}
                     })

        reader = PandasCsvReader()
        source = join(self.__output_dir, "prompt-sample-train-0.tsv.gz")
        storage = reader.read(source)
        self.assertEqual(20, len(storage), "Amount of rows is not equal!")
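Given the `#S`/`#O` entity masking and the Russian prompt template above, a serialized `text_a` value could look roughly like the constructed example below; the exact label string depends on `RuSentRelLabelsFormatter` and the sampled document, so this is an illustration rather than an actual row from the test output:

# Constructed illustration of a single prompted text_a value (not an actual output row).
example_text_a = ("Для текста: `... #S осудил #O ...` отношение в нём между "
                  "объектом `#O` (слово `12`) и субъектом `#S` (слово `10`) "
                  "имеет оценку pos")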
