In [None]:
# Install Hugging Face Transformers (AutoTokenizer, BertTokenizer and TFBertModel are imported below).
!pip install transformers

In [None]:
# Install razdel; its `tokenize` and `sentenize` functions are imported below.
!pip install razdel

In [None]:
# Install Keras; its layers, Model, backend and metrics are imported below.
!pip install keras

In [None]:
import codecs
import copy
import gc
import json
import pickle
import random
import re
import time
import unicodedata as ud
import warnings
from typing import List, Tuple, Union, Dict

import numpy as np
import pandas as pd
from scipy.stats import norm
from keras.layers import Input, Dense, Lambda, Layer
from keras.layers.normalization import batch_normalization
from keras.models import Model
from keras import backend as K
from keras import metrics
import tensorflow as tf
import transformers
from transformers import AutoTokenizer, BertTokenizer, TFBertModel
from tqdm.auto import tqdm
from razdel import tokenize, sentenize

In [None]:
# Display the installed TensorFlow version as the cell output.
tf.__version__

In [None]:
# Pick the distribution strategy: use a TPU when one can be resolved and
# initialised, otherwise fall back to the default (CPU/GPU) strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception as tpu_error:
    # `Exception` (not a bare `except`) keeps KeyboardInterrupt/SystemExit
    # working, and the reason for the fallback is shown instead of being lost.
    print('TPU is not available, the default strategy is used:',
          repr(tpu_error))
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

In [None]:
# Download the Lenta.Ru news corpus and the RuNNE public data (train.jsonl, ners.txt).
# NOTE(review): the two RuNNE URLs end with `?raw=true` and no output name is given,
# so the saved file names presumably keep that suffix — confirm against the paths used later.
!wget https://github.com/yutkin/Lenta.Ru-News-Dataset/releases/download/v1.1/lenta-ru-news.csv.bz2
!wget https://github.com/dialogue-evaluation/RuNNE/blob/main/public_data/train.jsonl?raw=true
!wget https://github.com/dialogue-evaluation/RuNNE/blob/main/public_data/ners.txt?raw=true

Функция для разбора одной JSON-строки из набора данных RuNNE

In [None]:
#!c1.32
def parse_runne_json_string(s: str) -> Tuple[int, str, List[Tuple[str, int, int]]]:
    """Parse a single JSON line of the RuNNE data.

    The line must be a JSON object with the keys "id" and "sentences";
    the optional key "ners" holds a list of ``[start, end, label]`` items
    where ``end`` is the index of the last character (inclusive).

    Returns:
        A tuple ``(identifier, text, entities)``; every entity is
        ``(label, start, end)`` with an exclusive ``end``.

    Raises:
        ValueError: if a mandatory key is missing, if entities are given for
            a blank text, or if any entity item is malformed (wrong type or
            length, bounds outside of the text, start after end, or a span
            that starts/ends with a whitespace character).
    """
    record = json.loads(s)
    if 'id' not in record:
        raise ValueError(f'The "id" key is not found in the string "{s}".')
    if 'sentences' not in record:
        raise ValueError(f'The "sentences" key is not found in the string "{s}".')
    identifier, text = record['id'], record['sentences']
    entities = []
    if 'ners' not in record:
        return identifier, text, entities
    if len(text.strip()) == 0:
        raise ValueError(f'The named entities are specified incorrectly '
                         f'in the string "{s}".')
    for idx, info in enumerate(record['ners']):
        # The checks are ordered so that every later one may rely on the
        # earlier ones (short-circuit evaluation): type and length first,
        # then the bounds, and only then the characters at the bounds.
        is_malformed = (
            (not isinstance(info, list)) or
            (len(info) != 3) or
            (not isinstance(info[0], int)) or
            (not isinstance(info[1], int)) or
            (not isinstance(info[2], str)) or
            (info[0] > info[1]) or
            (info[0] < 0) or
            (info[1] >= len(text)) or
            text[info[0]].isspace() or
            text[info[1]].isspace()
        )
        if is_malformed:
            raise ValueError(f'Named entity {idx} is specified incorrectly '
                             f'in the string "{s}".')
        # Convert the inclusive end position into an exclusive one.
        entities.append((info[2], info[0], info[1] + 1))
    return identifier, text, entities

In [None]:
#!g1.1
def load_runne_data(fname: str) -> Dict[int, Tuple[str, List[Tuple[str, int, int]]]]:
    """Load a RuNNE JSON-lines file.

    Every non-blank line is parsed with ``parse_runne_json_string``. The
    entities of a text are de-duplicated and sorted by start position, end
    position and label.

    Args:
        fname: path to the UTF-8 encoded ``*.jsonl`` file.

    Returns:
        A dictionary that maps a text identifier to ``(text, entities)``.

    Raises:
        ValueError: if an identifier occurs more than once or a line
            cannot be parsed.
    """
    texts_and_annotations = dict()
    # The built-in `open` splits lines only on "\n", "\r" and "\r\n".
    # `codecs.open(...).readline()` additionally splits on Unicode line
    # boundaries such as U+2028 or U+0085, which may legally occur inside
    # a JSON string and would cut a record in the middle.
    with open(fname, mode='r', encoding='utf-8') as fp:
        for cur_line in fp:
            prep_line = cur_line.strip()
            if len(prep_line) == 0:
                continue
            identifier, text, ners = parse_runne_json_string(prep_line)
            if identifier in texts_and_annotations:
                err_msg = f'Identifier {identifier} is duplicated!'
                raise ValueError(err_msg)
            ners = sorted(
                list(set(ners)),
                key=lambda it: (it[1], it[2], it[0])
            )
            texts_and_annotations[identifier] = (text, ners)
    return texts_and_annotations

In [None]:
#!g1.1
def calc_entity_freqs(data: Dict[int, Tuple[str, List[Tuple[str, int, int]]]],
                      id_list: Union[List[int], None] = None) -> Dict[str, int]:
    """Count how many times every named entity class is annotated.

    Note: the function relies on the standard `re` module, which has to be
    imported in the notebook's import cell.

    Args:
        data: mapping from a text identifier to ``(text, entities)``, where
            every entity is ``(label, start, end)``.
        id_list: identifiers of the texts to be counted; all texts of
            ``data`` are used when it is ``None``.

    Returns:
        A dictionary that maps an entity label to its frequency.

    Raises:
        ValueError: if a label looks like a BIO tag (``B-...``, ``I-...``,
            ``O``) or does not consist of upper-case Latin letters and
            underscores (at least two characters, no leading or trailing
            underscore).
    """
    frequencies = dict()
    re_for_entity = re.compile(r'^[A-Z]+[_A-Z]*[A-Z]+$')
    id_list_ = list(data.keys()) if id_list is None else id_list
    for identifier in id_list_:
        _, ners = data[identifier]
        for ne_type, _, _ in ners:
            err_msg = f'{ne_type} is inadmissible named entity class!'
            if ne_type.startswith('B-') or ne_type.startswith('I-') or \
                    (ne_type == 'O'):
                raise ValueError(err_msg)
            if re_for_entity.search(ne_type) is None:
                raise ValueError(err_msg)
            frequencies[ne_type] = frequencies.get(ne_type, 0) + 1
    return frequencies

In [None]:
#!g1.1
def tokenize_text_with_ners(s: str, tokenizer: BertTokenizer,
                            ners: List[Tuple[str, int, int]],
                            ne_vocabulary: List[str]) \
        -> Tuple[List[str], List[List[int]]]:
    """Tokenize a text and mark its named entities on the subtoken level.

    Args:
        s: source text.
        tokenizer: tokenizer passed to ``tokenize_text``.
        ners: entities as ``(label, start, end)`` character spans in ``s``
            (``end`` is exclusive).
        ne_vocabulary: list of known entity labels; its order defines the
            order of the returned indicator rows.

    Returns:
        ``(subtokens, ne_indicators)`` where ``ne_indicators[k][j]`` is 2 for
        the first subtoken of an entity of class ``ne_vocabulary[k]``, 1 for
        its other subtokens and 0 otherwise.

    Raises:
        ValueError: if an entity label is unknown or an entity bound does
            not fall inside any word of the text.
    """
    words, subtokens, subtoken_bounds = tokenize_text(s, tokenizer)
    # Character span of every word, taken from its first and last subtokens.
    word_bounds = [
        (subtoken_bounds[first_subtoken][0],
         subtoken_bounds[last_subtoken - 1][1])
        for _, first_subtoken, last_subtoken in words
    ]
    ne_indicators = [[0] * len(subtokens) for _ in ne_vocabulary]
    ne_set = {entity[0] for entity in ners}
    if not ne_set:
        return subtokens, ne_indicators
    diff_ne = ne_set - set(ne_vocabulary)
    if diff_ne:
        raise ValueError(
            f'The annotation {ners} is wrong because '
            f'it contains unknown entities! '
            f'They are: {sorted(list(diff_ne))}.'
        )
    for ne_class, ne_start, ne_end in ners:
        ne_id = ne_vocabulary.index(ne_class)

        # The word that contains the first character of the entity.
        start_word_idx = -1
        for word_idx, (word_start, word_end) in enumerate(word_bounds):
            if not (word_start <= ne_start < word_end):
                continue
            if ne_start != word_start:
                warnings.warn(
                    f'The annotation {ners} can have errors. '
                    f'The entity {(ne_class, ne_start, ne_end)} '
                    f'is inexactly found in the text "{s}". '
                    f'{ne_start} != {word_start}'
                )
            start_word_idx = word_idx
            break
        if start_word_idx < 0:
            raise ValueError(
                f'The annotation {ners} is wrong. '
                f'The entity {(ne_class, ne_start, ne_end)} '
                f'is not found in the text {s}.'
            )

        # The word that contains the last character of the entity.
        end_word_idx = -1
        for word_idx, (word_start, word_end) in enumerate(word_bounds):
            if not (word_start < ne_end <= word_end):
                continue
            if ne_end != word_end:
                warnings.warn(
                    f'The annotation {ners} can have errors. '
                    f'The entity {(ne_class, ne_start, ne_end)} '
                    f'is inexactly found in the text "{s}". '
                    f'{ne_end} != {word_end}'
                )
            end_word_idx = word_idx
            break
        if end_word_idx < 0:
            raise ValueError(
                f'The annotation {ners} is wrong. '
                f'The entity {(ne_class, ne_start, ne_end)} '
                f'is not found in the text {s}.'
            )

        init_ne_subtoken = words[start_word_idx][1]
        fin_ne_subtoken = words[end_word_idx][2]
        class_indicator = ne_indicators[ne_id]
        for subtoken_idx in range(init_ne_subtoken, fin_ne_subtoken):
            class_indicator[subtoken_idx] = 1
        # The first subtoken of the entity gets the "begin" mark.
        class_indicator[init_ne_subtoken] = 2
    return subtokens, ne_indicators

In [None]:
#!g1.1
def tokenize_text(s: str, tokenizer: BertTokenizer) \
        -> Tuple[List[Tuple[str, int, int]], List[str],
                 List[Union[None, Tuple[int, int]]]]:
    """Split a text into words and sub-word tokens with character bounds.

    The text is split into words with ``razdel.tokenize``; every word is
    additionally split at punctuation characters, and each resulting part is
    tokenized with ``tokenizer.tokenize``. A part that produces the unknown
    token is represented by a single unknown token.

    Args:
        s: source text.
        tokenizer: object that provides ``tokenize`` as well as the
            ``cls_token``, ``sep_token`` and ``unk_token`` attributes.

    Returns:
        A tuple ``(words, subtokens, subtoken_bounds)``:
        ``words`` holds ``(word, first_subtoken_idx, end_subtoken_idx)``
        (the end index is exclusive); ``subtokens`` starts with the CLS token
        and ends with the SEP token; ``subtoken_bounds`` holds the character
        span of every subtoken in ``s`` (``None`` for CLS and SEP).

    Raises:
        ValueError: if a word produces no subtokens or a subtoken cannot be
            aligned with the word it was produced from.
    """
    words: List[Tuple[str, int, int]] = []
    subtokens: List[str] = []
    subtoken_bounds: List[Union[None, Tuple[int, int]]] = []
    subtokens.append(tokenizer.cls_token)
    subtoken_bounds.append(None)
    n_bpe = 1
    # A zero-width space is replaced with an ordinary space (one character
    # for one character, so all offsets stay valid), and tokens that consist
    # of whitespace only are dropped.
    tokenization_iterator = filter(
        lambda it2: len(s[it2[0]:it2[1]].strip()) > 0,
        map(
            lambda it1: (tuple(it1)[0], tuple(it1)[1]),
            tokenize(s.replace('\u200b', ' '))
        )
    )
    word_bounds = []
    # Every member must be a single character because the members are
    # compared with single characters of a word. The former '\]' entry was
    # a two-character string, so the closing bracket was never split.
    punctuation = {',', '-', ':', ';', '.', ')', '(', ']', '[', '<', '>',
                   '=', '+', '?', '!'}
    for start_word_pos, end_word_pos in tokenization_iterator:
        cur_word = s[start_word_pos:end_word_pos]
        if len(cur_word.strip()) > 0:
            # Split the word at punctuation: every punctuation character
            # becomes a separate word, the text between them is kept whole.
            wordpart_start = -1
            for char_idx, char_val in enumerate(cur_word):
                if char_val in punctuation:
                    if wordpart_start >= 0:
                        word_bounds.append((
                            start_word_pos + wordpart_start,
                            start_word_pos + char_idx
                        ))
                        wordpart_start = -1
                    word_bounds.append((
                        start_word_pos + char_idx,
                        start_word_pos + char_idx + 1
                    ))
                else:
                    if wordpart_start < 0:
                        wordpart_start = char_idx
            if wordpart_start >= 0:
                word_bounds.append((
                    start_word_pos + wordpart_start,
                    start_word_pos + len(cur_word)
                ))
    for start_word_pos, end_word_pos in word_bounds:
        cur_word = s[start_word_pos:end_word_pos]
        bpe = tokenizer.tokenize(cur_word)
        if len(bpe) == 0:
            err_msg = f'The word "{cur_word}" cannot be tokenized!'
            raise ValueError(err_msg)
        if tokenizer.unk_token in bpe:
            # The whole word is represented by a single unknown token.
            subtokens.append(tokenizer.unk_token)
            subtoken_bounds.append((start_word_pos, end_word_pos))
            words.append((cur_word, n_bpe, n_bpe + 1))
            n_bpe += 1
        elif len(bpe) > 1:
            # Align every sub-word with the accent-free lower-cased word in
            # order to restore its character span.
            prep_word = remove_accents(cur_word.lower())
            subword_start_pos = 0
            for src in bpe:
                if src.startswith('##'):
                    prep = src[2:]
                else:
                    prep = src
                prep = remove_accents(prep.lower()).replace('`', '')
                found_start, found_end = find_substring(
                    s=prep_word[subword_start_pos:],
                    substring=prep
                )
                if (found_start < 0) or (found_end < 0):
                    err_msg = f'The text {s} cannot be tokenized! "{prep}" is ' \
                              f'not found in the "{prep_word}" from ' \
                              f'{subword_start_pos}. Subwords are: {bpe}'
                    raise ValueError(err_msg)
                subword_start_pos += found_start
                subword_end_pos = subword_start_pos + (found_end - found_start)
                subtokens.append(src)
                subtoken_bounds.append((
                    start_word_pos + subword_start_pos,
                    start_word_pos + subword_end_pos
                ))
                subword_start_pos = subword_end_pos
            # The last sub-word is stretched up to the end of the word.
            if (subtoken_bounds[-1][1] - start_word_pos) < len(prep_word):
                subtoken_bounds[-1] = (
                    subtoken_bounds[-1][0],
                    len(prep_word) + start_word_pos
                )
            words.append((cur_word, n_bpe, n_bpe + len(bpe)))
            n_bpe += len(bpe)
        else:
            subtokens.append(bpe[0])
            subtoken_bounds.append((start_word_pos, end_word_pos))
            words.append((cur_word, n_bpe, n_bpe + 1))
            n_bpe += 1
    subtokens.append(tokenizer.sep_token)
    subtoken_bounds.append(None)
    return words, subtokens, subtoken_bounds

In [None]:
#!g1.1
def train_test_split(data: Dict[int, Tuple[str, List[Tuple[str, int, int]]]]) \
        -> Tuple[Dict[int, Tuple[str, List[Tuple[str, int, int]]]],
                 Dict[int, Tuple[str, List[Tuple[str, int, int]]]]]:
    """Randomly split annotated texts into training and test parts.

    About 15% of the texts go to the test part. A split is accepted only if
    both parts contain the same entity classes and each part holds at least
    10% of the occurrences of every class; up to 1001 random shuffles are
    tried. Entity statistics are printed for the whole data and both parts.

    Args:
        data: mapping from a text identifier to ``(text, entities)``.

    Returns:
        ``(data_for_training, data_for_testing)``.

    Raises:
        ValueError: if some entity class occurs less than 15 times or no
            acceptable split is found.
    """
    frequencies = calc_entity_freqs(data)
    identifiers = sorted(list(data.keys()))
    print(f'There are {len(frequencies)} named entity classes:')
    max_txt_width = max(len(label) for label in frequencies)
    max_num_width = max(len(str(frequencies[label])) for label in frequencies)
    sorted_entity_list = sorted(
        list(frequencies.keys()),
        key=lambda it: (-frequencies[it], it)
    )

    def print_frequency(label: str, freq: int) -> None:
        print('  {0:<{1}} {2:>{3}}'.format(label, max_txt_width,
                                           freq, max_num_width))

    def split_is_acceptable(train_freqs: Dict[str, int],
                            test_freqs: Dict[str, int]) -> bool:
        # Both parts must cover the same classes, and every class must keep
        # at least 10% of its occurrences in each part.
        if set(train_freqs.keys()) != set(test_freqs.keys()):
            return False
        for part_freqs in (train_freqs, test_freqs):
            for label in part_freqs:
                if part_freqs[label] / float(frequencies[label]) < 0.1:
                    return False
        return True

    for named_entity in sorted_entity_list:
        freq = frequencies[named_entity]
        if freq < 15:
            raise ValueError(
                f'The data cannot be splitted because '
                f'the entity {named_entity} is too rare '
                f'(its frequency is {freq}).'
            )
        print_frequency(named_entity, freq)
    print('')

    n = int(round(0.15 * float(len(identifiers))))
    ok = False
    # One initial attempt plus up to 1000 retries.
    for _ in range(1001):
        random.shuffle(identifiers)
        training_frequencies = calc_entity_freqs(data, identifiers[n:])
        test_frequencies = calc_entity_freqs(data, identifiers[:n])
        ok = split_is_acceptable(training_frequencies, test_frequencies)
        if ok:
            break
    if not ok:
        raise ValueError('The data cannot be splitted.')

    data_for_training = {it: data[it] for it in identifiers[n:]}
    data_for_testing = {it: data[it] for it in identifiers[:n]}
    print('For training:')
    for named_entity in sorted_entity_list:
        print_frequency(named_entity, training_frequencies[named_entity])
    print('')
    print('For testing:')
    for named_entity in sorted_entity_list:
        print_frequency(named_entity, test_frequencies[named_entity])
    print('')
    return data_for_training, data_for_testing

In [None]:
#!g1.1
# Lower-case abbreviations that end with a period. `is_exclusion` checks
# whether a text ends with one of them, and `sentenize_with_exclusions`
# glues such a text to the following sentence.
SENTENIZE_EXCLUSIONS = [
    'st.',
    'св.',
    'г.',
    'с.',
    'род.',
    'рожд.'
]

In [None]:
#!g1.1
def find_quoted_substrings(s: str) -> List[Tuple[int, int]]:
    """Find the character spans of quoted fragments in a text.

    A fragment is opened by ``"``, ``'`` or ``«`` and closed by ``"``, ``'``
    or ``»``; quotes are not nested, and an opening quote without a closing
    one is ignored.

    Returns:
        A list of ``(start, end)`` spans that include both quote characters
        (``end`` is exclusive).
    """
    straight_quotes = ('"', '\'')
    spans = []
    opened_at = None
    for position, symbol in enumerate(s):
        if opened_at is None:
            if (symbol in straight_quotes) or (symbol == '«'):
                opened_at = position
        elif (symbol in straight_quotes) or (symbol == '»'):
            spans.append((opened_at, position + 1))
            opened_at = None
    return spans

In [None]:
#!g1.1
def find_span(spans: List[Tuple[int, int]], char_position: int) -> int:
    """Return the index of the first span that contains a character position.

    A span ``(start, end)`` contains the position if
    ``start <= char_position < end``; -1 is returned if there is no such span.
    """
    return next(
        (span_idx for span_idx, (span_start, span_end) in enumerate(spans)
         if span_start <= char_position < span_end),
        -1
    )

In [None]:
#!g1.1
def is_exclusion(s: str) -> bool:
    """Check whether a text ends with an abbreviation from SENTENIZE_EXCLUSIONS.

    The comparison is case-insensitive, and the abbreviation must be a
    separate token: it either starts the text or follows a character that is
    not alphanumeric.
    """
    lowered = s.lower()
    for abbreviation in SENTENIZE_EXCLUSIONS:
        if not lowered.endswith(abbreviation):
            continue
        abbreviation_start = len(lowered) - len(abbreviation)
        if abbreviation_start == 0:
            return True
        if not lowered[abbreviation_start - 1].isalnum():
            return True
    return False


In [None]:
#!g1.1
def sentenize_with_exclusions(s: str) -> List[Tuple[int, int]]:
    """Split a text into sentences and repair two kinds of wrong splits.

    The text is split with ``razdel.sentenize``. After that a sentence is
    glued to the previous one if the previous one ends with an abbreviation
    (see ``is_exclusion``), and all sentences crossed by one quoted fragment
    (see ``find_quoted_substrings``) are merged into a single sentence.

    Returns:
        A list of ``(start, end)`` character spans (``end`` is exclusive).

    Raises:
        ValueError: if a bound of a quoted fragment does not belong to any
            sentence.
    """
    raw_bounds = [(tuple(sent)[0], tuple(sent)[1]) for sent in sentenize(s)]
    if not raw_bounds:
        return raw_bounds

    # Step 1: undo the splits made right after a known abbreviation.
    merged_bounds = [raw_bounds[0]]
    for previous, current in zip(raw_bounds, raw_bounds[1:]):
        if is_exclusion(s[previous[0]:previous[1]].lower()):
            merged_bounds[-1] = (merged_bounds[-1][0], current[1])
        else:
            merged_bounds.append(current)

    # Step 2: a quoted fragment must not be divided between sentences.
    for quote_start, quote_end in find_quoted_substrings(s):
        first_sent_idx = find_span(merged_bounds, quote_start)
        last_sent_idx = find_span(merged_bounds, quote_end - 1)
        if (first_sent_idx < 0) or (last_sent_idx < 0):
            raise ValueError(f'The sentence "{s}" has incorrect spans!')
        if first_sent_idx < last_sent_idx:
            merged_bounds[first_sent_idx:(last_sent_idx + 1)] = [(
                merged_bounds[first_sent_idx][0],
                merged_bounds[last_sent_idx][1]
            )]
    return merged_bounds

In [None]:
#!g1.1
def sentenize_text(s: str) -> List[Tuple[int, int]]:
    """Split a multi-line text into sentences.

    The text is first divided into fragments at line breaks: a new fragment
    starts after a line break if the last non-space character before the
    break is ``?``, ``!`` or ``.``, or if the first non-space character after
    the break is title-cased. Every fragment is then split into sentences
    with ``sentenize_with_exclusions``.

    Returns:
        A list of ``(start, end)`` character spans in ``s`` (``end`` is
        exclusive).
    """
    sentence_bounds = []

    def _emit_fragment(fragment_start: int, fragment_end: int) -> None:
        """Split the fragment s[fragment_start:fragment_end] into sentences."""
        # Drop trailing whitespace of the fragment.
        while fragment_end > fragment_start:
            if not s[fragment_end - 1].isspace():
                break
            fragment_end -= 1
        if fragment_end <= fragment_start:
            return
        # A zero-width space is replaced with an ordinary space; the length
        # of the fragment is kept, so the offsets stay valid.
        text = s[fragment_start:fragment_end].replace('\u200b', ' ')
        if len(text.strip()) == 0:
            return
        for it in sentenize_with_exclusions(text):
            sentence_bounds.append((
                fragment_start + it[0],
                fragment_start + it[1]
            ))

    sent_start = -1
    newline_counter = 0
    last_char = ''
    for char_idx, char_val in enumerate(s):
        if char_val in {'\n', '\r'}:
            newline_counter += 1
            continue
        if char_val.isspace():
            continue
        if sent_start < 0:
            sent_start = char_idx
            # Line breaks before the very first character do not separate
            # anything. Without this reset they were applied to the second
            # character of the text, so "\nUSA ..." lost its first letter
            # into a separate one-character sentence.
            newline_counter = 0
        elif newline_counter > 0:
            if (last_char in {'?', '!', '.'}) or char_val.istitle():
                _emit_fragment(sent_start, char_idx)
                sent_start = char_idx
            newline_counter = 0
        last_char = char_val
    if sent_start >= 0:
        _emit_fragment(sent_start, len(s))
    return sentence_bounds
    

In [None]:
#!g1.1
def find_substring(s: str, substring: str) -> Tuple[int, int]:
    if '`' in substring:
        err_msg = f'"{substring}" is a wrong sub-word, because ' \
                  f'it contains "`". It cannot be found in the string "{s}".'
        raise ValueError(err_msg)
    if substring != substring.strip():
        err_msg = f'"{substring}" is a wrong sub-word, because ' \
                  f'it includes initial and/or final spaces.'
        raise ValueError(err_msg)
    if len(substring) == 0:
        return -1, -1
    if '`' not in s:
        start_pos = s.find(substring)
        if start_pos < 0:
            return -1, -1
        return start_pos, start_pos + len(substring)
    found_idx = s.find(substring[0])
    if found_idx < 0:
        return -1, -1
    idx1 = found_idx + 1
    if found_idx > 0:
        while found_idx > 0:
            if s[found_idx - 1] != '`':
                break
            found_idx -= 1
    for idx2 in range(1, len(substring)):
        while idx1 < len(s):
            if s[idx1] != '`':
                break
            idx1 += 1
        if idx1 >= len(s):
            break
        if s[idx1] != substring[idx2]:
            break
        idx1 += 1
    if s[found_idx:idx1].replace('`', '') != substring:
        return -1, -1
    while idx1 < len(s):
        if s[idx1] != '`':
            break
        idx1 += 1
    return found_idx, idx1

In [None]:
#!g1.1
def remove_accents(s: str) -> str:
    """Strip diacritics from a string without changing its length.

    Every character is decomposed with the NFKD normalization and replaced
    with the first non-combining character of its decomposition. A character
    without such a part (for example, a stand-alone combining mark) is
    replaced with the "`" placeholder, so the result always has as many
    characters as the source string.
    """
    stripped = []
    for source_char in s:
        decomposed = ud.normalize('NFKD', source_char)
        base_chars = [part for part in decomposed if ud.combining(part) == 0]
        stripped.append(base_chars[0] if base_chars else '`')
    return "".join(stripped)

In [None]:
#!g1.1
def sentenize_text_with_ners(s: str, tokenizer: BertTokenizer,
                             ners: List[Tuple[str, int, int]],
                             ne_vocabulary: List[str]) \
        -> List[Tuple[List[str], List[List[int]]]]:
    """Split an annotated text into sentences and tokenize each of them.

    Every entity is assigned to the sentence that contains it, its bounds
    are shifted to the sentence start, and the sentence is processed with
    ``tokenize_text_with_ners``.

    Args:
        s: source text.
        tokenizer: tokenizer passed to ``tokenize_text_with_ners``.
        ners: entities as ``(label, start, end)`` character spans in ``s``.
        ne_vocabulary: list of known entity labels.

    Returns:
        A list with one ``(subtokens, ne_indicators)`` pair per sentence.

    Raises:
        ValueError: if entities are duplicated, if an entity crosses a
            sentence bound or is met in more than one sentence, or if some
            entity does not belong to any sentence.
    """
    if len(set(ners)) != len(ners):
        raise ValueError('Some entities are duplicated!')
    res = []
    used_entities = set()
    for sent_start, sent_end in sentenize_text(s):
        ners_for_sent = []
        for entity in ners:
            ne_type, ne_start, ne_end = entity
            # Entities that lie completely outside the sentence are skipped.
            if (ne_end <= sent_start) or (ne_start >= sent_end):
                continue
            crosses_bounds = (ne_start < sent_start) or (ne_end > sent_end)
            if crosses_bounds or (entity in used_entities):
                raise ValueError(
                    f'The entity ({ne_type}, {ne_start}, {ne_end}) '
                    f'is wrong! It is located in more than '
                    f'a single sentence. More probably sentence is '
                    f'"{s[sent_start:sent_end]}"'
                )
            # Entity bounds become relative to the sentence start.
            ners_for_sent.append(
                (ne_type, ne_start - sent_start, ne_end - sent_start)
            )
            used_entities.add(entity)
        res.append(tokenize_text_with_ners(s[sent_start:sent_end], tokenizer,
                                           ners_for_sent, ne_vocabulary))
    if len(used_entities) != len(ners):
        raise ValueError(
            f'Some entities are not used! They are: '
            f'{sorted(list(set(ners) - used_entities))}'
        )
    return res

In [None]:
def _find_overflowing_rows(token_ids: np.ndarray, max_seq_len: int,
                           pad_token_id: int) -> List[int]:
    """Return indices of rows with non-padding tokens beyond max_seq_len."""
    has_overflow = np.any(token_ids[:, max_seq_len:] != pad_token_id, axis=1)
    return np.flatnonzero(has_overflow).tolist()


def build_trainset_for_ner(data: Dict[int,
                                      Tuple[str, List[Tuple[str, int, int]]]],
                           tokenizer: BertTokenizer, max_seq_len: int,
                           entities: List[str]) \
        -> Tuple[np.ndarray, List[np.ndarray]]:
    """Build padded token-id and label arrays for NER training.

    Every text is split into sentences with ``sentenize_text_with_ners``;
    every sentence becomes one row. The part of a sentence that does not fit
    into ``max_seq_len`` subtokens is moved to an additional row (repeatedly,
    while some row is still too long), after which all rows are cut to
    ``max_seq_len``.

    Args:
        data: mapping from a text identifier to ``(text, entities)``.
        tokenizer: tokenizer that provides ``pad_token``, ``pad_token_id``
            and ``convert_tokens_to_ids``.
        max_seq_len: number of subtokens per row in the result.
        entities: entity labels; one label array is built for each of them.

    Returns:
        ``(X, y)`` where ``X`` is an int32 array of token ids and ``y`` holds
        one 3-D array per entity label, built from the outputs of
        ``transform_indicator_to_classmatrix``.

    Raises:
        ValueError: if ``entities`` contains ``O`` or the data contain no
            sentences.
    """
    if 'O' in entities:
        err_msg = f'The entities list {entities} is wrong ' \
                  f'because it contains the `O` entity.'
        raise ValueError(err_msg)
    list_of_tokenized_texts = []
    list_of_ne_indicators = []
    max_seq_len_ = max_seq_len
    print(f'Number of texts is {len(data)}.')
    for cur_id in tqdm(sorted(list(data.keys()))):
        text, ners = data[cur_id]
        batch = sentenize_text_with_ners(
            s=text,
            tokenizer=tokenizer,
            ners=ners,
            ne_vocabulary=entities
        )
        for tokenized_text, ne_indicators in batch:
            list_of_tokenized_texts.append(tokenized_text)
            list_of_ne_indicators.append(ne_indicators)
            if len(tokenized_text) > max_seq_len_:
                max_seq_len_ = len(tokenized_text)
    print(f'Number of sentences is {len(list_of_tokenized_texts)}.')
    if len(list_of_tokenized_texts) == 0:
        # Otherwise the failure would come from an array operation below
        # with a message that says nothing about the real cause.
        raise ValueError('The data contain no sentences.')
    X = []
    y = [[] for _ in range(len(entities))]
    for tokenized_text, ne_indicators in zip(list_of_tokenized_texts,
                                             list_of_ne_indicators):
        # Padded copies are created here; the collected lists stay intact
        # (a shallow copy of `ne_indicators` would still share and mutate
        # the inner lists).
        n_pad = max_seq_len_ - len(tokenized_text)
        padded_text = tokenized_text + [tokenizer.pad_token] * n_pad
        X.append(tokenizer.convert_tokens_to_ids(padded_text))
        for ne_id in range(len(entities)):
            y[ne_id].append(
                transform_indicator_to_classmatrix(
                    ne_indicators[ne_id] + [0] * n_pad
                )
            )
    X = np.array(X, dtype=np.int32)
    y = [np.concatenate(cur, axis=0) for cur in y]
    if X.shape[1] == max_seq_len:
        return X, y
    ndiff = max_seq_len_ - max_seq_len
    indices_of_long_texts = _find_overflowing_rows(X, max_seq_len,
                                                   tokenizer.pad_token_id)
    iteration = 1
    while len(indices_of_long_texts) > 0:
        print(f'Iter {iteration}: '
              f'there are {len(indices_of_long_texts)} very long texts!')
        new_X = np.full(
            shape=(len(indices_of_long_texts), max_seq_len_),
            fill_value=tokenizer.pad_token_id,
            dtype=np.int32
        )
        # The number of classes is taken from the existing label arrays
        # instead of being hard-coded.
        new_y = [
            np.zeros(
                (len(indices_of_long_texts), max_seq_len_, cur.shape[2]),
                dtype=np.float32
            )
            for cur in y
        ]
        for local_idx, global_idx in enumerate(indices_of_long_texts):
            # Move the overflowing tail to the beginning of a new row and
            # blank it in the source row.
            new_X[local_idx, 0:ndiff] = X[global_idx, max_seq_len:]
            X[global_idx, max_seq_len:] = tokenizer.pad_token_id
            for output_idx in range(len(y)):
                new_y[output_idx][local_idx, 0:ndiff, :] = \
                    y[output_idx][global_idx, max_seq_len:, :]
                y[output_idx][global_idx, max_seq_len:, :] = 0.0
        X = np.concatenate((X, new_X), axis=0)
        y = [np.concatenate((y[output_idx], new_y[output_idx]), axis=0)
             for output_idx in range(len(y))]
        indices_of_long_texts = _find_overflowing_rows(
            X, max_seq_len, tokenizer.pad_token_id
        )
        iteration += 1
    X = X[:, :max_seq_len]
    y = [cur[:, :max_seq_len, :] for cur in y]
    print(f'Number of sentences after cutting is {X.shape[0]}.')
    return X, y

In [None]:
#!g1.1
def calc_features(tokenizer: BertTokenizer, feature_extractor: TFBertModel,
                  max_sent_len: int, source_text: str) -> \
        Tuple[List[Tuple[str, int, int]], np.ndarray]:
    """Compute one averaged feature vector per word of ``source_text``.

    The text is split into sentences with ``sentenize_text``; every sentence
    is tokenized with ``tokenize_text``, padded to a multiple of
    ``max_sent_len`` subtokens and passed through ``feature_extractor`` in
    chunks of ``max_sent_len`` subtokens. The vectors of the subtokens that
    belong to one word are averaged.

    Args:
        tokenizer: tokenizer providing ``pad_token`` and
            ``convert_tokens_to_ids``.
        feature_extractor: model whose ``predict(...)[0]`` is expected to be
            a 3-D array (or an object with ``.numpy()``).
        max_sent_len: number of subtokens per chunk fed to the model.
        source_text: the text to process.

    Returns:
        A pair ``(words, features)``: ``words`` is a list of
        ``(word, char_start, char_end)`` with character offsets relative to
        ``source_text``; ``features`` stacks one row per word.

    Raises:
        ValueError: if the model output is not 3-D or its first dimension
            does not match the number of chunks.
    """
    collected_words = []
    collected_vectors = []
    for sent_start, sent_end in sentenize_text(source_text):
        words, subtokens, subtoken_bounds = tokenize_text(
            s=source_text[sent_start:sent_end],
            tokenizer=tokenizer
        )
        # Pad the subtoken sequence so that it splits into whole chunks.
        while len(subtokens) % max_sent_len:
            subtokens.append(tokenizer.pad_token)
            subtoken_bounds.append(None)
        n_chunks = len(subtokens) // max_sent_len
        chunks = [
            np.array(
                tokenizer.convert_tokens_to_ids(
                    subtokens[offset:(offset + max_sent_len)]
                ),
                dtype=np.int32
            ).reshape((1, max_sent_len))
            for offset in range(0, n_chunks * max_sent_len, max_sent_len)
        ]
        predicted = feature_extractor.predict(np.vstack(chunks),
                                              batch_size=1)[0]
        if not isinstance(predicted, np.ndarray):
            predicted = predicted.numpy()
        if len(predicted.shape) != 3:
            err_msg = f'The predicted feature matrix is wrong! ' \
                      f'Expected 3-D array, got {len(predicted.shape)}-D one.'
            raise ValueError(err_msg)
        if predicted.shape[0] != n_chunks:
            err_msg = f'The predicted feature matrix does not correspond to' \
                      f' the input data! {predicted.shape[0]} != ' \
                      f'{len(subtokens) // max_sent_len}.'
            raise ValueError(err_msg)
        # Glue the chunks back together: one row per subtoken of the sentence.
        subtoken_features = np.concatenate(list(predicted), axis=0)
        del predicted
        for cur_word, first_subtoken, subtoken_stop in words:
            collected_vectors.append(
                subtoken_features[first_subtoken:subtoken_stop].mean(
                    axis=0, keepdims=True
                )
            )
            # Subtoken bounds are sentence-relative; shift them to text level.
            collected_words.append((
                cur_word,
                subtoken_bounds[first_subtoken][0] + sent_start,
                subtoken_bounds[subtoken_stop - 1][1] + sent_start
            ))
    return collected_words, np.vstack(collected_vectors)

In [None]:
#!g1.1
def find_entity_words(words: List[Tuple[str, int, int]],
                      entity_start: int, entity_end: int) -> Tuple[int, int]:
    """Map a character span of an entity onto a range of word indices.

    Args:
        words: ``(word, char_start, char_end)`` triples; the scan stops at
            the first word that starts at or after ``entity_end``.
        entity_start: character offset where the entity begins.
        entity_end: character offset where the entity ends (exclusive).

    Returns:
        ``(first_word_idx, last_word_idx + 1)`` of the words overlapping the
        span, or ``(-1, -1)`` if no word overlaps it.
    """
    first_idx, last_idx = -1, -1
    for idx, (_, left, right) in enumerate(words):
        # The first word that ends after the entity start opens the range.
        if (first_idx < 0) and (right > entity_start):
            first_idx = idx
        if left < entity_end:
            last_idx = idx
        else:
            # This word already lies to the right of the entity.
            break
    if min(first_idx, last_idx) < 0:
        return -1, -1
    covered_from = words[first_idx][1]
    covered_to = words[last_idx][2]
    # Reject spans that fall entirely into a gap between words.
    if (entity_end <= covered_from) or (entity_start >= covered_to):
        return -1, -1
    return first_idx, last_idx + 1

In [None]:
#!g1.1
def calc_features_and_labels(tokenizer: BertTokenizer,
                             feature_extractor: TFBertModel, max_sent_len: int,
                             ne_list: List[str], source_text: str,
                             annotation: List[Tuple[str, int, int]]) -> \
        Tuple[np.ndarray, List[np.ndarray]]:
    """Compute word-level features and per-class label matrices for a text.

    For every entity class in ``ne_list`` a ``(n_words, 5)`` float32 matrix
    is built. Column 0 is set for words not covered by an entity of that
    class, column 1 for the first word of a multi-word entity, column 3 for
    its inner words, column 2 for its last word, and column 4 for a
    single-word entity.

    Args:
        tokenizer: forwarded to ``calc_features``.
        feature_extractor: forwarded to ``calc_features``.
        max_sent_len: forwarded to ``calc_features``.
        ne_list: known entity classes; the position in the list is the
            index of the corresponding label matrix.
        source_text: the text to process.
        annotation: ``(entity_class, char_start, char_end)`` triples.

    Returns:
        ``(features, labels)`` where ``features`` comes from
        ``calc_features`` and ``labels`` has one matrix per entity class.

    Raises:
        ValueError: if an entity class is not in ``ne_list`` or an entity
            span does not overlap any word of the tokenized text.
    """
    words, features = calc_features(tokenizer, feature_extractor, max_sent_len,
                                    source_text)
    named_entities = []
    for _ in range(len(ne_list)):
        # Every word starts as "outside" (column 0) for every class.
        labels = np.zeros((len(words), 5), dtype=np.float32)
        labels[:, 0] = 1.0
        named_entities.append(labels)
    for entity_class, entity_char_start, entity_char_end in annotation:
        try:
            named_entity_id = ne_list.index(entity_class)
        except ValueError:
            # Only a failed lookup is translated; other errors propagate.
            err_msg = f'The entity class "{entity_class}" is unknown!'
            raise ValueError(err_msg) from None
        entity_start, entity_end = find_entity_words(words, entity_char_start,
                                                     entity_char_end)
        if (entity_start < 0) or (entity_end < 0):
            unknown_entity = (entity_class, entity_char_start, entity_char_end)
            input_text = source_text.replace("\n", " ").replace("\r", " ")
            err_msg = f'The entity {unknown_entity} is not found in the text ' \
                      f'"{input_text}", tokenized by the following words: ' \
                      f'{words}.'
            raise ValueError(err_msg)
        class_labels = named_entities[named_entity_id]
        if entity_end - entity_start > 1:
            # Multi-word entity: first word, inner words, last word.
            class_labels[entity_start, 0] = 0.0
            class_labels[entity_start, 1] = 1.0
            for word_idx in range(entity_start + 1, entity_end - 1):
                class_labels[word_idx, 0] = 0.0
                class_labels[word_idx, 3] = 1.0
            class_labels[entity_end - 1, 0] = 0.0
            class_labels[entity_end - 1, 2] = 1.0
        else:
            # Single-word entity.
            class_labels[entity_start, 0] = 0.0
            class_labels[entity_start, 4] = 1.0
    return features, named_entities

In [None]:
#!g1.1
# Identifier of the pretrained BERT checkpoint; the same value is reused
# below when the feature-extraction model is loaded.
model_path = 'DeepPavlov/rubert-base-cased'
# Tokenizer matching the checkpoint above.
tokenizer = AutoTokenizer.from_pretrained(model_path)

In [None]:
#!g1.1
# Load the BERT feature extractor for the checkpoint in `model_path`.
# `from_pt=True` asks transformers to convert PyTorch weights, and
# `output_hidden_states=True` makes the model also return hidden states.
model = transformers.TFBertModel.from_pretrained(
    model_path,
    output_hidden_states=True,
    from_pt=True,
)

In [None]:
#!g1.1

# Read the list of entity classes: one class per line, blank lines skipped.
with codecs.open(ners_fname, mode='r', encoding='utf-8') as fp:
    possible_named_entities = [
        stripped
        for stripped in (line.strip() for line in fp.readlines())
        if len(stripped) > 0
    ]

In [None]:
# Load the RuNNE training data. The cells below index the result by id and
# unpack each value as (text, ners).
# NOTE(review): the odd file name presumably is what `wget` produced for the
# "...train.jsonl?raw=true" URL downloaded above - confirm.
train_runne_text_and_ners = load_runne_data('train.jsonl?raw=true')

In [None]:
# Extract word-level features and per-class labels for every RuNNE text.
# `runne_labels[k]` collects the label matrices of entity class k.
runne_features = []
runne_labels = [[] for _ in possible_named_entities]
for cur_id in tqdm(sorted(train_runne_text_and_ners.keys())):
    text, ners = train_runne_text_and_ners[cur_id]
    X, y = calc_features_and_labels(
        tokenizer=tokenizer,
        feature_extractor=model,
        max_sent_len=128,
        ne_list=possible_named_entities,
        source_text=text,
        annotation=ners,
    )
    runne_features.append(X)
    for idx in range(len(possible_named_entities)):
        runne_labels[idx].append(y[idx])

In [None]:
# Stack the per-text matrices into one feature matrix and one label matrix
# per entity class.
runne_concat_features = np.vstack(runne_features)
# NOTE: this is an alias, not a copy - the loop below also replaces the
# per-text lists inside `runne_labels` with the stacked arrays.
runne_concat_labels = runne_labels
print('')
print(f'X.shape = {runne_concat_features.shape}')
for idx in range(len(possible_named_entities)):
    runne_concat_labels[idx] = np.vstack(runne_labels[idx])
print('')
for ne_id, ne_cls in enumerate(possible_named_entities):
    print(f'y[{ne_cls}].shape = {runne_concat_labels[ne_id].shape}')

In [None]:
# Persist the stacked RuNNE features and labels as a single pickled tuple
# (features, labels) so the expensive extraction need not be repeated.
with open("runne_features_and_labels_train", 'wb') as fp:
    pickle.dump(
        obj=(runne_concat_features, runne_concat_labels),
        file=fp,
        protocol=pickle.HIGHEST_PROTOCOL
    )

In [None]:
# `train_runne_text_and_ners` is deliberately kept alive: a later cell
# passes it to `build_trainset_for_ner`, so deleting it here makes that
# cell fail with a NameError on a fresh top-to-bottom run.
gc.collect()

In [None]:
# Read only the `text` column of the first 1000 Lenta.ru news items and
# report how long the read took (in seconds).
start_read = time.time()
train = pd.read_csv("lenta-ru-news.csv.bz2", usecols=['text'], nrows=1000, dtype={'text': str})
end_read = time.time() - start_read
print("Reading time: ", end_read)

In [None]:
# Extract word-level features for every Lenta text. Each list item is the
# (words, features) pair returned by `calc_features`.
lenta_features = []
for cur_id in tqdm(sorted(train.text.keys())):
    # Drop soft hyphens and line breaks before tokenization.
    # NOTE(review): removing '\n' without a replacement glues together the
    # words around a line break - confirm this is intended.
    text = train.text[cur_id].replace('\xad', '').replace('\n', '')
    X = calc_features(
        tokenizer=tokenizer,
        feature_extractor=model,
        max_sent_len=128,
        source_text=text,
    )
    lenta_features.append(X)

In [None]:
# Build the token-level NER training set from the RuNNE texts.
# NOTE(review): this cell needs `train_runne_text_and_ners` to still exist
# when it runs - verify no earlier cell has deleted it.
prep_data = build_trainset_for_ner(
            data=train_runne_text_and_ners,
            tokenizer=tokenizer,
            entities=possible_named_entities,
            max_seq_len=128,
        )

In [None]:
# The raw Lenta DataFrame is no longer needed once `lenta_features` exists.
del train
gc.collect()

In [None]:
# Stack the per-text feature matrices (second item of every pair returned
# by `calc_features`) into one matrix with one row per word.
concat_features = np.vstack([cur[1] for cur in lenta_features])
print('')
print(f'X.shape = {concat_features.shape}')

In [None]:
# The VAE is trained to reconstruct its own input, so every dataset element
# must be an (inputs, targets) pair with the feature vector in both places;
# a dataset of bare inputs gives the reconstruction loss no target.
training_set = tf.data.Dataset.from_tensor_slices(
        (concat_features, concat_features),
    ).batch(16)

In [None]:
# Stack the per-text word lists (first item of every pair returned by
# `calc_features`, i.e. (word, char_start, char_end) triples) into one array.
# NOTE(review): the printed label says "X.shape" although this is the shape
# of the word/bounds array, not of the feature matrix.
concat_lenta_token_bounds = np.vstack([cur[0] for cur in lenta_features])
print('')
print(f'X.shape = {concat_lenta_token_bounds.shape}')

In [None]:
# Persist the Lenta word bounds and feature matrix as one pickled tuple.
# The feature matrix is `concat_features`, computed above; the previously
# referenced `concat_features_2` is not defined anywhere in this notebook.
with open("lenta_concat_features_and_token_bounds_clear", 'wb') as fp:
    pickle.dump(
        obj=(concat_lenta_token_bounds, concat_features),
        file=fp,
        protocol=pickle.HIGHEST_PROTOCOL
    )

In [None]:
!pip install tensorflow_addons

In [None]:
import tensorflow_addons as tfa
import tensorflow_probability as tfp

In [None]:
def build_neural_network(n_features: int, n_classes: int,
                         n_latent: int, n_hidden: int, depth: int,
                         nn_name: str) -> Tuple[tf.keras.Model, tf.keras.Model,
                                                tf.keras.Model, tf.keras.Model]:
    """Build a variational autoencoder with a classifier on its latent code.

    The encoder is a stack of ``depth`` SELU dense layers followed by a
    layer that parameterizes an independent normal distribution over
    ``n_latent`` dimensions. A KL-divergence regularizer (weight 1e-3)
    against a standard normal prior is attached to the sampled code ``z``.
    The decoder mirrors the encoder and reconstructs the input features.
    A separate classifier (one hidden SELU layer plus softmax) maps a
    latent vector to ``n_classes`` probabilities.

    Args:
        n_features: size of the input feature vector.
        n_classes: number of classes predicted by the classifier.
        n_latent: dimensionality of the latent code.
        n_hidden: number of units in every hidden dense layer.
        depth: number of dense layers in the encoder and in the decoder
            (at least one layer is always created).
        nn_name: prefix used for all layer and model names.

    Returns:
        ``(vae_model, vae_cls_model, encoder_model, cls_model)``:
        features -> reconstruction (compiled with a log-cosh loss),
        features -> class probabilities (compiled with label-smoothed
        categorical cross-entropy), features -> sampled latent code,
        and latent code -> class probabilities.

    Raises:
        ValueError: if ``n_hidden`` is less than 1.
    """
    if n_hidden < 1:
        err_msg = f'The hidden layer size = {n_hidden} is too small!'
        raise ValueError(err_msg)
    info_msg = f'There are {n_features} features.'
    print(info_msg)

    # ---- encoder ----
    feature_vector = tf.keras.layers.Input(
        shape=(n_features,), dtype=tf.float32,
        name=f'{nn_name}_feature_vector'
    )
    encoder_layer = _dense_layer(
        feature_vector, units=n_hidden, activation='selu',
        name=f'{nn_name}_enc_dense1'
    )
    for layer_idx in range(1, depth):
        encoder_layer = _dense_layer(
            encoder_layer, units=n_hidden, activation='selu',
            name=f'{nn_name}_enc_dense{layer_idx + 1}'
        )

    # ---- latent distribution with a KL penalty towards N(0, I) ----
    prior = tfp.distributions.Independent(
        distribution=tfp.distributions.Normal(
            loc=tf.zeros(n_latent),
            scale=1
        ),
        reinterpreted_batch_ndims=1
    )
    latent_layer = _dense_layer(
        encoder_layer,
        units=tfp.layers.IndependentNormal.params_size(n_latent),
        activation=None,
        name=f'{nn_name}_latent'
    )
    z = tfp.layers.IndependentNormal(
        event_shape=n_latent,
        convert_to_tensor_fn=tfp.distributions.Distribution.sample,
        activity_regularizer=tfp.layers.KLDivergenceRegularizer(
            distribution_b=prior,
            weight=1e-3
        ),
        name=f'{nn_name}_z'
    )(latent_layer)

    # ---- stand-alone classifier over latent vectors ----
    # Its input gets its own name; it used to duplicate the name of the
    # feature input above.
    classifier_input = tf.keras.layers.Input(
        shape=(n_latent,), dtype=tf.float32,
        name=f'{nn_name}_cls_input'
    )
    # Both classifier layers share one initializer object, as before.
    cls_initializer = _make_lecun_initializer()
    hidden_layer = _dense_layer(
        classifier_input, units=n_hidden, activation='selu',
        name=f'{nn_name}_cls_hidden', kernel_initializer=cls_initializer
    )
    cls_layer = _dense_layer(
        hidden_layer, units=n_classes, activation='softmax',
        name=f'{nn_name}_cls_output', kernel_initializer=cls_initializer
    )
    cls_name = f'{nn_name}_cls'
    cls_model = tf.keras.Model(
        inputs=classifier_input,
        outputs=cls_layer,
        name=cls_name
    )
    cls_model.build(input_shape=[None, n_latent])

    # ---- decoder ----
    decoder_layer = _dense_layer(
        z, units=n_hidden, activation='selu',
        name=f'{nn_name}_dec_dense1'
    )
    for layer_idx in range(1, depth):
        decoder_layer = _dense_layer(
            decoder_layer, units=n_hidden, activation='selu',
            name=f'{nn_name}_dec_dense{layer_idx + 1}'
        )
    reconstruction_layer = _dense_layer(
        decoder_layer, units=n_features, activation=None,
        name=f'{nn_name}_reconstruction'
    )

    # ---- assemble the models ----
    encoder_model = tf.keras.Model(
        inputs=feature_vector,
        outputs=z,
        name=f'{nn_name}_enc'
    )
    vae_cls_model = tf.keras.Model(
        inputs=feature_vector,
        outputs=cls_model(z),
        name=f'{nn_name}_vae_cls'
    )
    vae_model = tf.keras.Model(
        inputs=feature_vector,
        outputs=reconstruction_layer,
        name=f'{nn_name}_vae',
    )
    encoder_model.build(input_shape=[None, n_features])
    vae_cls_model.build(input_shape=[None, n_features])

    # ---- compile the two trainable end-to-end models ----
    loss_vae = tf.keras.losses.LogCosh()
    loss_weights_vae = 1.5
    vae_model.compile(
        optimizer=tf.optimizers.Adam(learning_rate=0.01),
        loss=loss_vae,
        loss_weights=loss_weights_vae
    )
    # The classifier model is compiled too, so that `fit` can be called on
    # it; each model gets its own optimizer instance.
    loss_cls = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1)
    vae_cls_model.compile(
        optimizer=tf.optimizers.Adam(learning_rate=0.01),
        loss=loss_cls,
        metrics=[tf.keras.metrics.CategoricalAccuracy()]
    )
    return vae_model, vae_cls_model, encoder_model, cls_model


def _make_lecun_initializer():
    """Create a LeCun-normal kernel initializer with a random seed."""
    seed = random.randint(0, 2147483647)
    try:
        return tf.keras.initializers.LecunNormal(seed=seed)
    except AttributeError:
        # TensorFlow builds without `LecunNormal` in the v2 namespace.
        return tf.compat.v1.keras.initializers.lecun_normal(seed=seed)


def _dense_layer(inputs, units, activation, name, kernel_initializer=None):
    """Apply a Dense layer with LeCun-normal kernel and zero bias to `inputs`."""
    if kernel_initializer is None:
        kernel_initializer = _make_lecun_initializer()
    return tf.keras.layers.Dense(
        units=units,
        activation=activation,
        kernel_initializer=kernel_initializer,
        bias_initializer='zeros',
        name=name
    )(inputs)

In [None]:
# Instantiate the models; the input size is taken from the Lenta feature
# matrix.
# NOTE(review): n_classes=29 is hard-coded - presumably it should agree
# with the label data used for the classifier; confirm.
vae_nn, vae_cls_nn, encoding_nn, classification_nn = build_neural_network(
    n_features=concat_features.shape[1], n_classes=29,
    n_latent=64, n_hidden=200, depth=6, nn_name='deeploma_ner'
)

In [None]:
# Show the layer structure of the reconstruction (VAE) model.
vae_nn.summary()

In [None]:
# File that receives the best VAE weights (used by the checkpoint callback).
vae_nn_fname = 'deeploma_vae.h5'

In [None]:
# Keep only the weights of the epoch with the lowest validation loss.
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        filepath=vae_nn_fname,
        monitor="val_loss",
        mode="min",
        save_best_only=True,
        save_weights_only=True
    ),
]

In [None]:
# Train the VAE on the Lenta features.
# NOTE(review): the training set is also passed as validation data, so
# `val_loss` (monitored by the checkpoint callback) is not measured on
# held-out data - confirm this is intended.
model_history = vae_nn.fit(training_set, validation_data=training_set,
                                 epochs=100, callbacks=callbacks,
                                 verbose=1)

In [None]:
# The training history is not used afterwards; release it.
del model_history
gc.collect()

In [None]:
# Build shuffled, batched datasets of (inputs, tuple-of-label-arrays).
# `.shuffle(...).batch(...)` must be chained on the Dataset returned by
# `from_tensor_slices`; with the previous parenthesization it was called on
# the plain tuple of arrays and raised AttributeError.
# NOTE(review): `prep_data_train` / `prep_data_val` are not defined in the
# visible part of the notebook - presumably a train/validation split of
# `prep_data`; verify that such a cell exists.
training_set = tf.data.Dataset.from_tensor_slices(
    (
        prep_data_train[0],
        tuple(prep_data_train[1])
    )
).shuffle(prep_data_train[0].shape[0]).batch(16)

validation_set = tf.data.Dataset.from_tensor_slices(
    (
        prep_data_val[0],
        tuple(prep_data_val[1])
    )
).shuffle(prep_data_val[0].shape[0]).batch(16)

In [None]:
# Show the layer structure of the features -> class probabilities model.
vae_cls_nn.summary()

In [None]:
# File that receives the best classifier weights.
vae_cls_nn_fname = "deeploma_vae_cls.h5"

In [None]:
# Replaces the VAE callbacks: checkpoint the classifier weights of the
# epoch with the lowest validation loss.
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        filepath=vae_cls_nn_fname,
        monitor="val_loss",
        mode="min",
        save_best_only=True,
        save_weights_only=True
    ),
]

In [None]:
# Train the classifier model on the NER datasets built above.
# NOTE(review): `vae_cls_nn` has a single softmax output, while the datasets
# carry a tuple of label arrays - confirm the targets match the model output.
model_history = vae_cls_nn.fit(training_set, validation_data=validation_set,
                                 epochs=100, callbacks=callbacks,
                                 verbose=1)