Skip to content

Commit

Permalink
#496 done (version 0.23.1) #156 related, added unit test and prompting
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Jul 16, 2023
1 parent 6b31098 commit d7b7b1c
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 4 deletions.
22 changes: 18 additions & 4 deletions arekit/contrib/prompt/sample.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,31 @@
from arekit.common.data import const
from arekit.common.data.input.providers.sample.cropped import CroppedSampleRowProvider
from arekit.common.data.input.providers.text.single import BaseSingleTextProvider
from arekit.common.labels.str_fmt import StringLabelsFormatter


class PromptedSampleRowProvider(CroppedSampleRowProvider):
""" Sample, enriched with the prompt technique.
"""

def __init__(self, crop_window_size, label_scaler, text_provider, prompt):
def __init__(self, crop_window_size, label_scaler, text_provider, prompt, label_fmt=None):
""" crop_window_size: int
crop window size for the original text.
prompt: str
text which wraps the original cropped (optionally text).
this string suppose to include the following parameters:
text, s_ind, r_ind, label (optional)
this string suppose to include the following parameters (optional):
text, s_ind, t_ind, s_val, t_val, label_uint
"""
assert(isinstance(prompt, str))
assert(isinstance(text_provider, BaseSingleTextProvider))
assert(isinstance(label_fmt, StringLabelsFormatter) or label_fmt is None)

super(PromptedSampleRowProvider, self).__init__(crop_window_size=crop_window_size,
label_scaler=label_scaler,
text_provider=text_provider)

self.__prompt = prompt
self.__labels_fmt = label_fmt

def _fill_row_core(self, row, text_opinion_linkage, index_in_linked, etalon_label,
parsed_news, sentence_ind, s_ind, t_ind):
Expand All @@ -36,10 +39,21 @@ def _fill_row_core(self, row, text_opinion_linkage, index_in_linked, etalon_labe
s_ind=s_ind,
t_ind=t_ind)
original_text = row[BaseSingleTextProvider.TEXT_A]

sentence_terms, actual_s_ind, actual_t_ind = self._provide_sentence_terms(
parsed_news=parsed_news, sentence_ind=sentence_ind, s_ind=s_ind, t_ind=t_ind)

label_uint = row[const.LABEL] if const.LABEL in row else None
label_val = str(label_uint) if label_uint is None or self.__labels_fmt is None else \
self.__labels_fmt.label_to_str(self._label_provider.LabelScaler.uint_to_label(row[const.LABEL]))

row[BaseSingleTextProvider.TEXT_A] = self.__prompt.format(
text=original_text,
s_ind=row[const.S_IND],
t_ind=row[const.T_IND],
label=row[const.LABEL] if const.LABEL in row else None)
s_val=sentence_terms[actual_s_ind].Value,
t_val=sentence_terms[actual_t_ind].Value,
label_uint=label_uint,
label_val=label_val)

return row
135 changes: 135 additions & 0 deletions tests/tutorials/test_tutorial_pipeline_sampling_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import unittest
from collections import OrderedDict
from os.path import join, dirname

from arekit.common.data.input.providers.text.single import BaseSingleTextProvider
from arekit.common.entities.base import Entity
from arekit.common.entities.str_fmt import StringEntitiesFormatter
from arekit.common.entities.types import OpinionEntityType
from arekit.common.experiment.data_type import DataType
from arekit.common.folding.nofold import NoFolding
from arekit.common.labels.base import NoLabel, Label
from arekit.common.labels.scaler.base import BaseLabelScaler
from arekit.common.labels.str_fmt import StringLabelsFormatter
from arekit.common.pipeline.base import BasePipeline
from arekit.common.text.parser import BaseTextParser
from arekit.contrib.bert.terms.mapper import BertDefaultStringTextTermsMapper
from arekit.contrib.prompt.sample import PromptedSampleRowProvider
from arekit.contrib.source.brat.entities.parser import BratTextEntitiesParser
from arekit.contrib.source.rusentrel.labels_fmt import RuSentRelLabelsFormatter
from arekit.contrib.utils.data.readers.csv_pd import PandasCsvReader
from arekit.contrib.utils.data.storages.pandas_based import PandasBasedRowsStorage
from arekit.contrib.utils.data.writers.csv_pd import PandasCsvWriter
from arekit.contrib.utils.io_utils.samples import SamplesIO
from arekit.contrib.utils.pipelines.items.sampling.base import BaseSerializerPipelineItem
from arekit.contrib.utils.pipelines.items.text.tokenizer import DefaultTextTokenizer
from arekit.contrib.utils.pipelines.text_opinion.annot.predefined import PredefinedTextOpinionAnnotator
from arekit.contrib.utils.pipelines.text_opinion.extraction import text_opinion_extraction_pipeline
from arekit.contrib.utils.pipelines.text_opinion.filters.distance_based import DistanceLimitedTextOpinionFilter
from tests.tutorials.test_tutorial_pipeline_text_opinion_annotation import FooDocumentOperations


class Positive(Label):
pass


class Negative(Label):
pass


class SentimentLabelScaler(BaseLabelScaler):

def __init__(self):
int_to_label = OrderedDict([(NoLabel(), 0), (Positive(), 1), (Negative(), -1)])
uint_to_label = OrderedDict([(NoLabel(), 0), (Positive(), 1), (Negative(), 2)])
super(SentimentLabelScaler, self).__init__(int_dict=int_to_label,
uint_dict=uint_to_label)


class CustomLabelsFormatter(StringLabelsFormatter):
def __init__(self, pos_label_type, neg_label_type):
stol = {"POSITIVE_TO": neg_label_type, "NEGATIVE_TO": pos_label_type}
super(CustomLabelsFormatter, self).__init__(stol=stol)


class CustomEntitiesFormatter(StringEntitiesFormatter):

def __init__(self, subject_fmt="[subject]", object_fmt="[object]"):
self.__subj_fmt = subject_fmt
self.__obj_fmt = object_fmt

def to_string(self, original_value, entity_type):
assert(isinstance(original_value, Entity))
if entity_type == OpinionEntityType.Other:
return original_value.Value
elif entity_type == OpinionEntityType.Object or entity_type == OpinionEntityType.SynonymObject:
return self.__obj_fmt
elif entity_type == OpinionEntityType.Subject or entity_type == OpinionEntityType.SynonymSubject:
return self.__subj_fmt
return None


class TestPromptSerialization(unittest.TestCase):

__output_dir = join(dirname(__file__), "out")

def test(self):
terms_mapper = BertDefaultStringTextTermsMapper(
entity_formatter=CustomEntitiesFormatter(subject_fmt="#S", object_fmt="#O"))

text_provider = BaseSingleTextProvider(terms_mapper)

rows_provider = PromptedSampleRowProvider(
crop_window_size=1000,
label_scaler=SentimentLabelScaler(),
prompt="Для текста: `{text}` отношении в нём между " +
"объектом `{s_val}` (слово `{s_ind}`) " +
"и субъектом `{s_val}` (слово `{s_ind}`) " +
"имеет оценку {label_val}",
label_fmt=RuSentRelLabelsFormatter(pos_label_type=Positive, neg_label_type=Negative),
text_provider=text_provider)

writer = PandasCsvWriter(write_header=True)
samples_io = SamplesIO(self.__output_dir, writer, prefix="prompt-sample", target_extension=".tsv.gz")

pipeline_item = BaseSerializerPipelineItem(
rows_provider=rows_provider,
samples_io=samples_io,
save_labels_func=lambda data_type: True,
storage=PandasBasedRowsStorage(),
balance_func=lambda _: True)

pipeline = BasePipeline([
pipeline_item
])

#####
# Declaring pipeline related context parameters.
#####
doc_ops = FooDocumentOperations()
text_parser = BaseTextParser(pipeline=[BratTextEntitiesParser(), DefaultTextTokenizer(keep_tokens=True)])
train_pipeline = text_opinion_extraction_pipeline(
annotators=[
PredefinedTextOpinionAnnotator(doc_ops,
label_formatter=CustomLabelsFormatter(
pos_label_type=Positive,
neg_label_type=Negative))
],
text_opinion_filters=[
DistanceLimitedTextOpinionFilter(terms_per_context=50)
],
get_doc_by_id_func=doc_ops.by_id,
text_parser=text_parser)
#####

pipeline.run(input_data=None,
params_dict={
"data_folding": NoFolding(doc_ids=[0, 1], supported_data_type=DataType.Train),
"data_type_pipelines": {DataType.Train: train_pipeline},
"doc_ids": {DataType.Train: [0, 1]}
})

reader = PandasCsvReader()
source = join(self.__output_dir, "prompt-sample-train-0.tsv.gz")
storage = reader.read(source)
self.assertEqual(28, len(storage), "Amount of rows is non equal!")

0 comments on commit d7b7b1c

Please sign in to comment.