Skip to content

Commit

Permalink
#280 done
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Feb 14, 2022
1 parent 7536ad9 commit 0d201cc
Show file tree
Hide file tree
Showing 14 changed files with 84 additions and 87 deletions.
7 changes: 1 addition & 6 deletions arekit/common/entities/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
class Entity(object):

def __init__(self, value, e_type, id_in_doc, group_index=None):
def __init__(self, value, e_type, group_index=None):
assert(isinstance(value, str) and len(value) > 0)
assert(isinstance(e_type, str) or e_type is None)
assert(isinstance(group_index, int) or group_index is None)
self.__value = value.lower()
self.__id = id_in_doc
self.__type = e_type
self.__group_index = group_index

Expand All @@ -17,10 +16,6 @@ def GroupIndex(self):
def Value(self):
return self.__value

@property
def IdInDocument(self):
return self.__id

@property
def Type(self):
return self.__type
Expand Down
10 changes: 5 additions & 5 deletions arekit/common/experiment/annot/algo/pair_based.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from arekit.common.entities.base import Entity
from arekit.common.experiment.annot.algo.base import BaseAnnotationAlgorithm
from arekit.common.labels.provider.base import BasePairLabelProvider
from arekit.common.news.entity import DocumentEntity
from arekit.common.news.parsed.base import ParsedNews
from arekit.common.news.parsed.providers.entity_service import EntityServiceProvider, DistanceType
from arekit.common.news.parsed.providers.opinion_pairs import OpinionPairsProvider
Expand Down Expand Up @@ -31,8 +31,8 @@ def __init__(self, dist_in_terms_bound, label_provider, dist_in_sents=0, ignored

@staticmethod
def __create_key_by_entity_pair(e1, e2):
assert(isinstance(e1, Entity))
assert(isinstance(e2, Entity))
assert(isinstance(e1, DocumentEntity))
assert(isinstance(e2, DocumentEntity))
return "{}_{}".format(e1.IdInDocument, e2.IdInDocument)

def __is_ignored_entity_value(self, entity_value):
Expand All @@ -41,8 +41,8 @@ def __is_ignored_entity_value(self, entity_value):

def __try_create_pair_key(self, entity_service, e1, e2, existed_opinions):
assert(isinstance(entity_service, EntityServiceProvider))
assert(isinstance(e1, Entity))
assert(isinstance(e2, Entity))
assert(isinstance(e1, DocumentEntity))
assert(isinstance(e2, DocumentEntity))

if e1.IdInDocument == e2.IdInDocument:
return
Expand Down
14 changes: 14 additions & 0 deletions arekit/common/news/entity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from arekit.common.entities.base import Entity


class DocumentEntity(Entity):

def __init__(self, value, e_type, id_in_doc, group_index):
super(DocumentEntity, self).__init__(value=value,
e_type=e_type,
group_index=group_index)
self.__id = id_in_doc

@property
def IdInDocument(self):
return self.__id
12 changes: 11 additions & 1 deletion arekit/common/news/parsed/providers/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
from arekit.common.news.entity import DocumentEntity
from arekit.common.news.parsed.base import ParsedNews


class BaseParsedNewsServiceProvider(object):

def __init__(self):
self._doc_entities = None

@property
def Name(self):
raise NotImplementedError()

def init_parsed_news(self, parsed_news):
raise NotImplementedError()
assert(isinstance(parsed_news, ParsedNews))
self._doc_entities = [DocumentEntity(id_in_doc=doc_id, value=entity.Value,
e_type=entity.Type, group_index=entity.GroupIndex)
for doc_id, entity in enumerate(parsed_news.iter_entities())]
31 changes: 14 additions & 17 deletions arekit/common/news/parsed/providers/base_pairs.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,33 @@
from arekit.common.entities.base import Entity
import collections
from arekit.common.labels.provider.base import BasePairLabelProvider
from arekit.common.news.parsed.base import ParsedNews
from arekit.common.news.parsed.providers.base import BaseParsedNewsServiceProvider


class BasePairProvider(BaseParsedNewsServiceProvider):

def __init__(self):
self._entities = None

@property
def Name(self):
raise NotImplementedError()

def init_parsed_news(self, parsed_news):
assert(isinstance(parsed_news, ParsedNews))
self._entities = list(parsed_news.iter_entities())

def _create_pair(self, source_entity, target_entity, label):
raise NotImplementedError()

# region private methods

def _iter_from_entities(self, source_entities, target_entities, label_provider, filter_func=None):
def _iter_from_entities(self, src_entity_doc_ids, tgt_entity_doc_ids, label_provider, filter_func=None):
assert(isinstance(src_entity_doc_ids, list))
assert(isinstance(tgt_entity_doc_ids, list))
assert(isinstance(label_provider, BasePairLabelProvider))
assert(callable(filter_func) or filter_func is None)

for source_entity in source_entities:
for target_entity in target_entities:
assert(isinstance(source_entity, Entity))
assert(isinstance(target_entity, Entity))
for src_e_doc_id in src_entity_doc_ids:
for tgt_e_doc_id in tgt_entity_doc_ids:
assert(isinstance(src_e_doc_id, int))
assert(isinstance(tgt_e_doc_id, int))

# Extract entities by doc_id.
source_entity = self._doc_entities[src_e_doc_id]
target_entity = self._doc_entities[tgt_e_doc_id]

if filter_func is not None and not filter_func(source_entity, target_entity):
continue
Expand All @@ -48,8 +46,7 @@ def _iter_from_entities(self, source_entities, target_entities, label_provider,

def iter_from_all(self, label_provider, filter_func):
assert(isinstance(label_provider, BasePairLabelProvider))

return self._iter_from_entities(source_entities=self._entities,
target_entities=self._entities,
return self._iter_from_entities(src_entity_doc_ids=list(map(lambda e: e.IdInDocument, self._doc_entities)),
tgt_entity_doc_ids=list(map(lambda e: e.IdInDocument, self._doc_entities)),
label_provider=label_provider,
filter_func=filter_func)
32 changes: 13 additions & 19 deletions arekit/common/news/parsed/providers/entity_service.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from enum import Enum

from arekit.common.entities.base import Entity
from arekit.common.news.entity import DocumentEntity
from arekit.common.news.parsed.base import ParsedNews
from arekit.common.news.parsed.providers.base import BaseParsedNewsServiceProvider
from arekit.common.news.parsed.term_position import TermPositionTypes, TermPosition
from arekit.common.text.enums import TermFormat
from arekit.common.text.parsed import BaseParsedText
from arekit.common.text_opinions.base import TextOpinion


Expand Down Expand Up @@ -43,9 +42,9 @@ class EntityServiceProvider(BaseParsedNewsServiceProvider):
NAME = "entity-service-provider"

def __init__(self):
super(EntityServiceProvider, self).__init__()
# Initialize API.
self.__iter_raw_terms_func = None
self.__get_sent = None
# Initialize entity positions.
self.__entity_positions = None

Expand All @@ -54,9 +53,9 @@ def Name(self):
return self.NAME

def init_parsed_news(self, parsed_news):
super(EntityServiceProvider, self).init_parsed_news(parsed_news)
assert(isinstance(parsed_news, ParsedNews))
self.__iter_raw_terms_func = lambda: parsed_news.iter_terms(filter_func=None, term_only=False)
self.__get_sent = parsed_news.get_sentence
self.__init_entity_positions()

# region public 'extract' methods
Expand Down Expand Up @@ -90,8 +89,8 @@ def calc_dist_between_text_opinion_ends(self, text_opinion, distance_type):
position_type=DistanceType.to_position_type(distance_type))

def calc_dist_between_entities(self, e1, e2, distance_type):
assert(isinstance(e1, Entity))
assert(isinstance(e2, Entity))
assert(isinstance(e1, DocumentEntity))
assert(isinstance(e2, DocumentEntity))
assert(isinstance(distance_type, DistanceType))

return self.__calc_distance(
Expand All @@ -113,15 +112,7 @@ def get_entity_position(self, id_in_document, position_type=None):
return e_pos.get_index(position_type)

def get_entity_value(self, id_in_document):
position = self.__entity_positions[id_in_document]
assert(isinstance(position, TermPosition))

sent_ind = position.get_index(position_type=TermPositionTypes.SentenceIndex)
sentence = self.__get_sent(sent_ind)

assert(isinstance(sentence, BaseParsedText))
entity = sentence.get_term(position.get_index(position_type=TermPositionTypes.IndexInSentence),
term_format=TermFormat.Raw)
entity = self._doc_entities[id_in_document]
assert(isinstance(entity, Entity))
return entity.Value

Expand Down Expand Up @@ -159,15 +150,18 @@ def __init_entity_positions(self):
self.__entity_positions = self.__calculate_entity_positions()

def __calculate_entity_positions(self):
positions = {}
""" Note: here we consider the same order as in self._entities.
"""
positions = []
t_ind_in_doc = 0

for s_ind, t_ind_in_sent, term in self.__iter_raw_terms_func():

if isinstance(term, Entity):
positions[term.IdInDocument] = TermPosition(term_ind_in_doc=t_ind_in_doc,
term_ind_in_sent=t_ind_in_sent,
s_ind=s_ind)
position = TermPosition(term_ind_in_doc=t_ind_in_doc,
term_ind_in_sent=t_ind_in_sent,
s_ind=s_ind)
positions.append(position)

t_ind_in_doc += 1

Expand Down
12 changes: 6 additions & 6 deletions arekit/common/news/parsed/providers/text_opinion_pairs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging

from arekit.common.entities.base import Entity
from arekit.common.entities.collection import EntityCollection
from arekit.common.news.entity import DocumentEntity
from arekit.common.news.parsed.providers.base_pairs import BasePairProvider
from arekit.common.opinions.base import Opinion
from arekit.common.text_opinions.base import TextOpinion
Expand All @@ -27,8 +27,8 @@ def Name(self):
return self.NAME

def _create_pair(self, source_entity, target_entity, label):
assert(isinstance(source_entity, Entity))
assert(isinstance(target_entity, Entity))
assert(isinstance(source_entity, DocumentEntity))
assert(isinstance(target_entity, DocumentEntity))

return TextOpinion(doc_id=self.__doc_id,
source_id=source_entity.IdInDocument,
Expand All @@ -41,7 +41,7 @@ def init_parsed_news(self, parsed_news):
super(TextOpinionPairsProvider, self).init_parsed_news(parsed_news)
self.__doc_id = parsed_news.RelatedDocID
self.__entities_collection = EntityCollection(
entities=list(self._entities),
entities=list(self._doc_entities),
value_to_group_id_func=self.__value_to_group_id_func)

def iter_from_opinion(self, opinion, debug=False):
Expand Down Expand Up @@ -71,8 +71,8 @@ def iter_from_opinion(self, opinion, debug=False):

label_provider = PairSingleLabelProvider(label_instance=opinion.Sentiment)

pairs_it = self._iter_from_entities(source_entities=source_entities,
target_entities=target_entities,
pairs_it = self._iter_from_entities(src_entity_doc_ids=list(map(lambda e: e.IdInDocument, source_entities)),
tgt_entity_doc_ids=list(map(lambda e: e.IdInDocument, target_entities)),
label_provider=label_provider)

for pair in pairs_it:
Expand Down
10 changes: 6 additions & 4 deletions arekit/contrib/source/rusentrel/entities/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,15 @@ class RuSentRelEntity(Entity):
"""

def __init__(self, id_in_doc, e_type, char_index_begin, char_index_end, value):
assert(isinstance(id_in_doc, int))
assert(isinstance(e_type, str))
assert(isinstance(char_index_begin, int))
assert(isinstance(char_index_end, int))
super(RuSentRelEntity, self).__init__(value=value,
e_type=e_type,
id_in_doc=id_in_doc)
super(RuSentRelEntity, self).__init__(value=value, e_type=e_type)

self.__e_type = e_type
self.__begin = char_index_begin
self.__end = char_index_end
self.__id = id_in_doc

@property
def CharIndexBegin(self):
Expand All @@ -30,3 +28,7 @@ def CharIndexEnd(self):
@property
def Type(self):
return self.__e_type

@property
def ID(self):
return self.__id
23 changes: 5 additions & 18 deletions arekit/processing/text/pipeline_entities_bert_ontonotes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,23 @@
class BertOntonotesNERPipelineItem(SentenceObjectsParserPipelineItem):

KEY = "src"
NEXT_ENTITY_ID_KEY = "next_entity_id"

def __init__(self):
# Initialize bert-based model instance.
self.__ontonotes_ner = BertOntonotesNER()
super(BertOntonotesNERPipelineItem, self).__init__(TermsPartitioning())

def _get_parts_provider_func(self, pipeline_ctx):
return self.__iter_subs_values_with_bounds(pipeline_ctx)
assert(isinstance(pipeline_ctx, PipelineContext))
terms_list = self._get_text(pipeline_ctx)
return self.__iter_subs_values_with_bounds(terms_list)

def _get_text(self, pipeline_ctx):
assert(isinstance(pipeline_ctx, PipelineContext))
assert(self.KEY in pipeline_ctx)
return pipeline_ctx.provide(self.KEY)

def __get_and_register_next_entity_id(self, pipeline_ctx):
assert(isinstance(pipeline_ctx, PipelineContext))
target_id = pipeline_ctx.provide(param=self.NEXT_ENTITY_ID_KEY) \
if self.NEXT_ENTITY_ID_KEY in pipeline_ctx else 0
pipeline_ctx.update(param=self.NEXT_ENTITY_ID_KEY, value=target_id + 1)
return target_id

def __iter_subs_values_with_bounds(self, pipeline_ctx):
assert(isinstance(pipeline_ctx, PipelineContext))
terms_list = self._get_text(pipeline_ctx)
def __iter_subs_values_with_bounds(self, terms_list):
assert(isinstance(terms_list, list))

single_sequence = [terms_list]
Expand All @@ -43,11 +35,6 @@ def __iter_subs_values_with_bounds(self, pipeline_ctx):
for p_sequence in processed_sequences:
for s_obj in p_sequence:
assert(isinstance(s_obj, NerObjectDescriptor))

value = " ".join(terms_list[s_obj.Position:s_obj.Position + s_obj.Length])

entity = Entity(value=value,
e_type=s_obj.ObjectType,
id_in_doc=self.__get_and_register_next_entity_id(pipeline_ctx))

entity = Entity(value=value, e_type=s_obj.ObjectType)
yield entity, Bound(pos=s_obj.Position, length=s_obj.Length)
1 change: 0 additions & 1 deletion examples/network/args/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def read_argument(args):

@staticmethod
def add_argument(parser, default):
assert(isinstance(default, str))
parser.add_argument('--text',
dest='input_text',
type=str,
Expand Down
3 changes: 1 addition & 2 deletions examples/pipelines/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@

def run_data_serialization_pipeline(sentences, terms_per_context, entities_parser,
embedding_path, entity_fmt_type, stemmer):
assert(isinstance(entities_parser, BasePipelineItem) or entities_parser is None)
assert(isinstance(sentences, list))
assert(isinstance(entities_parser, BasePipelineItem) or entities_parser is None)
assert(isinstance(terms_per_context, int))
assert(isinstance(embedding_path, str))
assert(isinstance(entity_fmt_type, EntityFormatterTypes))
Expand All @@ -48,7 +48,6 @@ def run_data_serialization_pipeline(sentences, terms_per_context, entities_parse

label_provider = MultipleLabelProvider(label_scaler=labels_scaler)

# TODO. split text onto sentences.
sentences = list(map(lambda text: BaseNewsSentence(text), sentences))

annot_algo = PairBasedAnnotationAlgorithm(
Expand Down
Loading

0 comments on commit 0d201cc

Please sign in to comment.