In [1]:
from transformers import pipeline
import torch

In [14]:
class Ner_Extractor:
    """
    Labeling each token in sentence as named entity

    :param model_checkpoint: name or path to model
    :type model_checkpoint: string
    """

    def __init__(self, model_checkpoint: str):
        self.token_pred_pipeline = pipeline("token-classification",
                                            model=model_checkpoint,
                                            aggregation_strategy="average")

    @staticmethod
    def text_color(txt, txt_c="blue", txt_hglt="on_yellow"):
        """
        Coloring part of text

        :param txt: part of text from sentence
        :type txt: string
        :param txt_c: text color
        :type txt_c: string
        :param txt_hglt: color of text highlighting
        :type txt_hglt: string
        :return: string with color labeling
        :rtype: string
        """
        return colored(txt, txt_c, txt_hglt)

    @staticmethod
    def concat_entities(ner_result):
        """
        Concatenation entities from model output on grouped entities

        :param ner_result: output from model pipeline
        :type ner_result: list
        :return: list of grouped entities with start - end position in text
        :rtype: list
        """
        entities = []
        prev_entity = None
        prev_end = 0
        for i in range(len(ner_result)):

            if (ner_result[i]["entity_group"] == prev_entity) &\
               (ner_result[i]["start"] == prev_end):

                entities[i-1][2] = ner_result[i]["end"]
                prev_entity = ner_result[i]["entity_group"]
                prev_end = ner_result[i]["end"]
            else:
                entities.append([ner_result[i]["entity_group"],
                                 ner_result[i]["start"],
                                 ner_result[i]["end"]])
                prev_entity = ner_result[i]["entity_group"]
                prev_end = ner_result[i]["end"]

        return entities


    def colored_text(self, text: str, entities: list):
        """
        Highlighting in the text named entities

        :param text: sentence or a part of corpus
        :type text: string
        :param entities: concated entities on groups with start - end position in text
        :type entities: list
        :return: Highlighted sentence
        :rtype: string
        """
        colored_text = ""
        init_pos = 0
        for ent in entities:
            if ent[1] > init_pos:
                colored_text += text[init_pos: ent[1]]
                colored_text += self.text_color(text[ent[1]: ent[2]]) + f"({ent[0]})"
                init_pos = ent[2]
            else:
                colored_text += self.text_color(text[ent[1]: ent[2]]) + f"({ent[0]})"
                init_pos = ent[2]

        return colored_text


    def get_entities(self, text: str):
        """
        Extracting entities from text with them position in text

        :param text: input sentence for preparing
        :type text: string
        :return: list with entities from text
        :rtype: list
        """
        assert len(text) > 0, text
        entities = self.token_pred_pipeline(text)
        # concat_ent = self.concat_entities(entities)

        return entities


    def show_ents_on_text(self, text: str):
        """
        Highlighting named entities in input text

        :param text: input sentence for preparing
        :type text: string
        :return: Highlighting text
        :rtype: string
        """
        assert len(text) > 0, text
        entities = self.get_entities(text)

        return self.colored_text(text, entities)

In [15]:
def unite_entities(extractor, text):
  entity_list = extractor.get_entities(text)

  output = []
  for i in range(len(entity_list)-1):
      if entity_list[i+1][1] - entity_list[i][2] == 1:
          output.append([entity_list[i][0], entity_list[i][1], entity_list[i+1][2]])
      else:
          output.append(entity_list[i])
  return output

In [16]:
extractor = Ner_Extractor(model_checkpoint = "surdan/LaBSE_ner_nerel")

In [17]:
text_exmp = "'Добрый день. Подскажите пожалуйста по такому вопросу. Подала на выплату с 3-7. Есть старший сын студент очник. Предоставила в соц защиту справку о стипендии в электронном варианте. Соц.защита требует в бумажном виде . Сказали , что будет отказ. Дети на практике и взять справку на руки не предоставляется возможным. В прошлом году сама соц.защита предложила в электронном варианте справку послать, а в этом нужна бумага. Правомерно ли поступают в нашей соц.защите г. Чусового? Спасибо."

In [20]:
text_exmp[75:78]

'3-7'

In [22]:
result = extractor.get_entities(text_exmp)
result

[{'entity_group': 'ORDINAL',
  'score': 0.9697609,
  'word': '3',
  'start': 75,
  'end': 76},
 {'entity_group': 'ORDINAL',
  'score': 0.64828557,
  'word': '7',
  'start': 77,
  'end': 78},
 {'entity_group': 'PROFESSION',
  'score': 0.9998497,
  'word': 'студент',
  'start': 97,
  'end': 104},
 {'entity_group': 'PROFESSION',
  'score': 0.99975973,
  'word': 'очник',
  'start': 105,
  'end': 110},
 {'entity_group': 'ORGANIZATION',
  'score': 0.999985,
  'word': 'соц защиту',
  'start': 127,
  'end': 137},
 {'entity_group': 'ORGANIZATION',
  'score': 0.9999734,
  'word': 'Соц',
  'start': 182,
  'end': 185},
 {'entity_group': 'ORGANIZATION',
  'score': 0.9999631,
  'word': '. защита',
  'start': 185,
  'end': 192},
 {'entity_group': 'DATE',
  'score': 0.9999862,
  'word': 'В прошлом году',
  'start': 317,
  'end': 331},
 {'entity_group': 'ORGANIZATION',
  'score': 0.9999831,
  'word': 'соц',
  'start': 337,
  'end': 340},
 {'entity_group': 'ORGANIZATION',
  'score': 0.9999793,
  'word':

In [24]:
res_dict = []
for i in range(1, len(result)):
    if result[i]['start'] - result[i-1]['end'] == 1:
        res_dict.append({
            'entity_group': result[i]['entity_group'],
            'start': result[i-1]['start'],
            'end': result[i]['end']
        })
    else:
        res_dict.append({
            'entity_group': result[i]['entity_group'],
            'start': result[i]['start'],
            'end': result[i]['end']
        })

In [25]:
res_dict

[{'entity_group': 'ORDINAL', 'start': 75, 'end': 78},
 {'entity_group': 'PROFESSION', 'start': 97, 'end': 104},
 {'entity_group': 'PROFESSION', 'start': 97, 'end': 110},
 {'entity_group': 'ORGANIZATION', 'start': 127, 'end': 137},
 {'entity_group': 'ORGANIZATION', 'start': 182, 'end': 185},
 {'entity_group': 'ORGANIZATION', 'start': 185, 'end': 192},
 {'entity_group': 'DATE', 'start': 317, 'end': 331},
 {'entity_group': 'ORGANIZATION', 'start': 337, 'end': 340},
 {'entity_group': 'ORGANIZATION', 'start': 340, 'end': 347},
 {'entity_group': 'ORGANIZATION', 'start': 454, 'end': 457},
 {'entity_group': 'ORGANIZATION', 'start': 457, 'end': 464},
 {'entity_group': 'CITY', 'start': 468, 'end': 476}]

In [19]:
unite_entities(extractor, text_exmp)

KeyError: 1