Skip to content

Commit

Permalink
#233 done
Browse files: browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Dec 30, 2021
1 parent 375483c commit ff91670
Show file tree
Hide file tree
Showing 17 changed files with 287 additions and 303 deletions.
4 changes: 3 additions & 1 deletion arekit/common/data/input/providers/opinions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from arekit.common.data.input.sample import InputSampleBase
from arekit.common.linkage.text_opinions import TextOpinionsLinkage
from arekit.common.news.parsed.providers.entity_service import EntityServiceProvider
from arekit.common.news.parsed.providers.text_opinion_pairs import TextOpinionPairsProvider


Expand Down Expand Up @@ -58,6 +59,7 @@ def __iter_linked_text_opins(news_opins_for_extraction_func, parse_news_func,
for doc_id in doc_ids_it:

parsed_news = parse_news_func(doc_id)
entity_service = EntityServiceProvider(parsed_news)

linked_text_opinion_lists = OpinionProvider.__iter_linked_text_opinion_lists(
# TODO. To be refactored.
Expand All @@ -66,7 +68,7 @@ def __iter_linked_text_opins(news_opins_for_extraction_func, parse_news_func,
value_to_group_id_func=value_to_group_id_func),
iter_opins_for_extraction=news_opins_for_extraction_func(doc_id=parsed_news.RelatedDocID),
filter_text_opinion_func=lambda text_opinion: InputSampleBase.check_ability_to_create_sample(
parsed_news=parsed_news,
entity_service=entity_service,
text_opinion=text_opinion,
window_size=terms_per_context))

Expand Down
4 changes: 3 additions & 1 deletion arekit/common/data/input/providers/rows/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging

from arekit.common.data.input.providers.opinions import OpinionProvider
from arekit.common.news.parsed.providers.entity_service import EntityServiceProvider

logger = logging.getLogger(__name__)

Expand All @@ -12,7 +13,7 @@ class BaseRowProvider(object):

# region protected methods

def _provide_rows(self, parsed_news, text_opinion_linkage, idle_mode):
def _provide_rows(self, parsed_news, entity_service, text_opinion_linkage, idle_mode):
raise NotImplementedError()

# endregion
Expand All @@ -24,6 +25,7 @@ def iter_by_rows(self, opinion_provider, doc_ids_iter, idle_mode):
for parsed_news, linkage in opinion_provider.iter_linked_opinions(doc_ids_iter):

rows_it = self._provide_rows(parsed_news=parsed_news,
entity_service=EntityServiceProvider(parsed_news),
text_opinion_linkage=linkage,
idle_mode=idle_mode)

Expand Down
18 changes: 7 additions & 11 deletions arekit/common/data/input/providers/rows/opinions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,27 @@
from arekit.common.data import const
from arekit.common.data.input.providers.rows.base import BaseRowProvider
from arekit.common.data.row_ids.multiple import MultipleIDProvider
from arekit.common.dataset.text_opinions.enums import EntityEndType
from arekit.common.dataset.text_opinions.helper import TextOpinionHelper
from arekit.common.linkage.text_opinions import TextOpinionsLinkage
from arekit.common.news.parsed.base import ParsedNews
from arekit.common.news.parsed.providers.entity_service import EntityEndType, EntityServiceProvider


class BaseOpinionsRowProvider(BaseRowProvider):

@staticmethod
def __create_opinion_row(parsed_news, text_opinions_linkage):
def __create_opinion_row(entity_service, text_opinions_linkage):
"""
row format: [id, src, target, label]
"""
assert(isinstance(parsed_news, ParsedNews))
assert(isinstance(entity_service, EntityServiceProvider))
assert(isinstance(text_opinions_linkage, TextOpinionsLinkage))

row = OrderedDict()

src_value = TextOpinionHelper.extract_entity_value(
parsed_news=parsed_news,
src_value = entity_service.extract_entity_value(
text_opinion=text_opinions_linkage.First,
end_type=EntityEndType.Source)

target_value = TextOpinionHelper.extract_entity_value(
parsed_news=parsed_news,
target_value = entity_service.extract_entity_value(
text_opinion=text_opinions_linkage.First,
end_type=EntityEndType.Target)

Expand All @@ -42,9 +38,9 @@ def __create_opinion_row(parsed_news, text_opinions_linkage):

return row

def _provide_rows(self, parsed_news, text_opinion_linkage, idle_mode):
def _provide_rows(self, parsed_news, entity_service, text_opinion_linkage, idle_mode):
if idle_mode:
yield None
else:
yield BaseOpinionsRowProvider.__create_opinion_row(parsed_news=parsed_news,
yield BaseOpinionsRowProvider.__create_opinion_row(entity_service=entity_service,
text_opinions_linkage=text_opinion_linkage)
28 changes: 14 additions & 14 deletions arekit/common/data/input/providers/rows/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@
from arekit.common.data.input.providers.rows.base import BaseRowProvider
from arekit.common.data.input.providers.text.single import BaseSingleTextProvider
from arekit.common.data.row_ids.multiple import MultipleIDProvider
from arekit.common.dataset.text_opinions.enums import EntityEndType
from arekit.common.dataset.text_opinions.helper import TextOpinionHelper
from arekit.common.labels.base import Label

from arekit.common.linkage.text_opinions import TextOpinionsLinkage
from arekit.common.news.parsed.base import ParsedNews
from arekit.common.news.parsed.providers.entity_service import EntityEndType, EntityServiceProvider
from arekit.common.news.parsed.term_position import TermPositionTypes
from arekit.common.text_opinions.base import TextOpinion
from arekit.contrib.bert.input.providers.label_binary import BinaryLabelProvider
Expand Down Expand Up @@ -84,7 +83,7 @@ def __assign_value(column, value):
row[const.S_IND] = s_ind
row[const.T_IND] = t_ind

def _provide_rows(self, parsed_news, text_opinion_linkage, idle_mode):
def _provide_rows(self, parsed_news, entity_service, text_opinion_linkage, idle_mode):
assert(isinstance(idle_mode, bool))

row_dict = OrderedDict()
Expand All @@ -93,6 +92,7 @@ def _provide_rows(self, parsed_news, text_opinion_linkage, idle_mode):

rows_it = self.__provide_rows(
parsed_news=parsed_news,
entity_service=entity_service,
row_dict=row_dict,
text_opinion_linkage=text_opinion_linkage,
index_in_linked=index_in_linked,
Expand All @@ -119,7 +119,8 @@ def __create_instances_provider(label_provider):
if isinstance(label_provider, MultipleLabelProvider):
return SingleInstanceTextOpinionsLinkageProvider()

def __provide_rows(self, row_dict, parsed_news, text_opinion_linkage, index_in_linked, idle_mode):
def __provide_rows(self, row_dict, parsed_news, entity_service,
text_opinion_linkage, index_in_linked, idle_mode):
"""
Providing Rows depending on row_id_formatter type
"""
Expand All @@ -131,13 +132,14 @@ def __provide_rows(self, row_dict, parsed_news, text_opinion_linkage, index_in_l
for instance in self.__instances_provider.iter_instances(text_opinion_linkage):
yield self.__create_row(row=row_dict,
parsed_news=parsed_news,
entity_service=entity_service,
text_opinions_linkage=instance,
index_in_linked=index_in_linked,
# TODO. provide uint_label
etalon_label=etalon_label,
idle_mode=idle_mode)

def __create_row(self, row, parsed_news, text_opinions_linkage, index_in_linked, etalon_label, idle_mode):
def __create_row(self, row, parsed_news, entity_service, text_opinions_linkage, index_in_linked, etalon_label, idle_mode):
"""
Composing row in following format:
[id, label, type, text_a]
Expand All @@ -146,7 +148,6 @@ def __create_row(self, row, parsed_news, text_opinions_linkage, index_in_linked,
row with key values
"""
assert(isinstance(row, OrderedDict))
assert(isinstance(parsed_news, ParsedNews))
assert(isinstance(text_opinions_linkage, TextOpinionsLinkage))
assert(isinstance(index_in_linked, int))
assert(isinstance(etalon_label, Label))
Expand All @@ -157,14 +158,13 @@ def __create_row(self, row, parsed_news, text_opinions_linkage, index_in_linked,

text_opinion = text_opinions_linkage[index_in_linked]

s_ind, t_ind = self.__get_opinion_end_indices(parsed_news, text_opinion)
s_ind, t_ind = self.__get_opinion_end_indices(entity_service, text_opinion)

row.clear()

self._fill_row_core(row=row,
parsed_news=parsed_news,
sentence_ind=TextOpinionHelper.extract_entity_position(
parsed_news=parsed_news,
sentence_ind=entity_service.extract_entity_position(
text_opinion=text_opinion,
end_type=EntityEndType.Source,
position_type=TermPositionTypes.SentenceIndex),
Expand All @@ -176,17 +176,17 @@ def __create_row(self, row, parsed_news, text_opinions_linkage, index_in_linked,
return row

@staticmethod
def __get_opinion_end_indices(parsed_news, text_opinion):
assert(isinstance(parsed_news, ParsedNews))
def __get_opinion_end_indices(service, text_opinion):
assert(isinstance(service, EntityServiceProvider))
assert(isinstance(text_opinion, TextOpinion))

s_ind = parsed_news.get_entity_position(text_opinion.SourceId).get_index(
s_ind = service.get_entity_position(text_opinion.SourceId).get_index(
position_type=TermPositionTypes.IndexInSentence)

t_ind = parsed_news.get_entity_position(text_opinion.TargetId).get_index(
t_ind = service.get_entity_position(text_opinion.TargetId).get_index(
position_type=TermPositionTypes.IndexInSentence)

return (s_ind, t_ind)
return s_ind, t_ind

# endregion

Expand Down
14 changes: 5 additions & 9 deletions arekit/common/data/input/sample.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from collections import OrderedDict

from arekit.common.dataset.text_opinions.enums import DistanceType
from arekit.common.dataset.text_opinions.helper import TextOpinionHelper
from arekit.common.news.parsed.base import ParsedNews
from arekit.common.news.parsed.providers.entity_service import EntityServiceProvider, DistanceType
from arekit.common.text_opinions.base import TextOpinion


Expand All @@ -28,11 +26,11 @@ def ID(self):
# endregion

@staticmethod
def check_ability_to_create_sample(parsed_news, window_size, text_opinion):
def check_ability_to_create_sample(entity_service, window_size, text_opinion):
"""
Main text_opinion filtering rules
"""
assert(isinstance(parsed_news, ParsedNews))
assert(isinstance(entity_service, EntityServiceProvider))
assert(isinstance(text_opinion, TextOpinion))
assert(isinstance(window_size, int) and window_size > 0)

Expand All @@ -43,16 +41,14 @@ def check_ability_to_create_sample(parsed_news, window_size, text_opinion):
if text_opinion.SourceId != text_opinion.TargetId:
is_not_same_ends = True

dist_between_entities = TextOpinionHelper.calc_dist_between_text_opinion_ends(
parsed_news=parsed_news,
dist_between_entities = entity_service.calc_dist_between_text_opinion_ends(
text_opinion=text_opinion,
distance_type=DistanceType.InTerms)

if InputSampleBase._check_ends_could_be_fitted_in_window(dist_between_entities, window_size):
is_in_window = True

dist_in_sents = TextOpinionHelper.calc_dist_between_text_opinion_ends(
parsed_news=parsed_news,
dist_in_sents = entity_service.calc_dist_between_text_opinion_ends(
text_opinion=text_opinion,
distance_type=DistanceType.InSentences)

Expand Down
Empty file removed arekit/common/dataset/__init__.py
Empty file.
Empty file.
23 changes: 0 additions & 23 deletions arekit/common/dataset/text_opinions/enums.py

This file was deleted.

109 changes: 0 additions & 109 deletions arekit/common/dataset/text_opinions/helper.py

This file was deleted.

Loading

0 comments on commit ff91670

Please sign in to comment.