In [1]:
import numpy as np
import random
import pymorphy3
import numpy as np
import math
import pickle
from razdel import tokenize

from dataset_builder import calculate_word_features_for_tokens, PAD_TOKEN,get_word_features
from inference import torch_model_runner, onnx_model_runner, infer

onnx_model = onnx_model_runner("results writers big/model.onnx")
with open("params.pickle", "rb") as f:
    params = pickle.load(f)

class jsinfer:
    async def infer(arr):
        class wrapper:
            def to_py():
                return onnx_model(arr)
        return wrapper

from stream import Stream
import functools
from collections import deque
import random
random.seed(42)

@functools.lru_cache(maxsize=128)
def get_word_features_cached(word):
    return get_word_features(word, params).numpy()

class Substr:
    def __init__(self, text):
        self.text = text
    def __repr__(self) -> str:
        return f"Substring(-1, -1, {self.text})"

def d_as_str(d):
  return "<" + " ".join(map(lambda text: text.text, d))+ ">"
    
async def infer_optimal(params, text): 
  # print("INFERCENC IS WIERD\n" * 10)
  res = []
  last_inserted_pos = 0
  def sink(token, log=False):
    nonlocal last_inserted_pos
    if token.text == "PAD": return
    if log: print('sink', token)
    if isinstance(token, Substr):
      res.append(token.text)
      if log: print("added1 ", f"`{token.text}`", token)
    else:
      if last_inserted_pos != token.start:
        res.append(text[last_inserted_pos: token.start])
        if log: print("added2 ", f"`{text[last_inserted_pos: token.start]}`", last_inserted_pos, token.start)
      last_inserted_pos = token.stop
      res.append(token.text)
      if log: print("added3 ", f"`{token.text}`", token)

  def skip(token, log=False):
    nonlocal last_inserted_pos
    last_inserted_pos = token.stop
    if log: print('skip', token)

  def sink_remaining():
     res.append(text[last_inserted_pos:])


  async def predict_on_tokens(window_left, window_right, return_probas):
    features = [get_word_features_cached(i.text) for i in Stream(window_left).chain(window_right)]
    features_for_batch = np.stack((features, ))
    arr = np.ascontiguousarray(features_for_batch, dtype=np.float32)
    output_probas = np.array((await jsinfer.infer(arr)).to_py())
    # output_probas[0][0] += 2.
    if return_probas:
      return params["ID_TO_PUNCTUATION"], output_probas 
    punct_idx = np.argmax(output_probas).item()
    punct = params["ID_TO_PUNCTUATION"][punct_idx]
    return punct


  window_left = deque()
  window_right = deque()
  log = False
  skip_next = False
  for i in Stream.repeat(Substr(PAD_TOKEN), params['INPUT_WORDS_CNT_LEFT']) \
      .chain(Stream(tokenize(text))) \
      .chain(Stream.repeat(Substr(PAD_TOKEN), params["INPUT_WORDS_CNT_RIGHT"])):
    window_right.append(i)
    if len(window_right) <= params["INPUT_WORDS_CNT_RIGHT"]:
        continue
    assert len(window_right) == params["INPUT_WORDS_CNT_RIGHT"] + 1

    next_ = window_right.popleft()
    sink(next_)
    window_left.append(next_)
    if len(window_left) < params['INPUT_WORDS_CNT_LEFT']: 
      continue

    assert len(window_left) == params["INPUT_WORDS_CNT_LEFT"]
    assert len(window_right) == params["INPUT_WORDS_CNT_RIGHT"]

    if skip_next:
      prediction = "$skip" 
    else:
      # params["ID_TO_PUNCTUATION"], output_probas
      prediction = await predict_on_tokens(window_left, window_right, return_probas=False) 


    #random.choice([" ", "."])
    if log: print(d_as_str(window_left).rjust(100), prediction.center(6), d_as_str(window_right))

    def is_replaceable_punct(punct):
      return punct in ',.'

    if prediction == "$skip":
      pass
    elif prediction != "$empty":
      if is_replaceable_punct(window_right[0].text):
        if window_right[0].text != prediction:
          window_right[0].text = prediction
      else:
        window_left.append(Substr(prediction))
        sink(window_left[-1])
    else:
      if is_replaceable_punct(window_right[0].text):
          skip(window_right.popleft())

    skip_next = is_replaceable_punct(window_right[0].text)

    while len(window_left) != params['INPUT_WORDS_CNT_LEFT'] - 1: 
      token = window_left.popleft()

    if log: print(d_as_str(window_left).rjust(100), "      ", d_as_str(window_right))

  for i in window_right:
    sink(i)
  sink_remaining()
  ress = "".join(res)
  return ress


In [64]:
await infer_optimal(params, "кек\n")

'кек.\n'

In [4]:
import diff_match_patch as dmp_module
from collections import defaultdict, Counter

def calculate_diff(text, text_res):
  dmp = dmp_module.diff_match_patch()
  diff = dmp.diff_main(text, text_res)

  diff_aggregate = defaultdict(lambda : 0)
  def sink2():
    nonlocal cur_add, cur_remove
    if cur_add == "" and cur_remove == "":
      return
    if cur_add == "":
      diff_aggregate['remove ' + cur_remove] += 1
      cur_remove = ""
      return
    if cur_remove == "":
      diff_aggregate['add ' + cur_add] += 1
      cur_add = ""
      return
    
    diff_aggregate['replace ' + cur_remove + " with " + cur_add] += 1
    cur_add = ""
    cur_remove = ""

  cur_remove = ""
  cur_add = ""

  UNCHANGED = 0
  ADD = 1
  REMOVE = -1

  for change in diff:
      if change[0] == UNCHANGED:
        c = Counter(change[1])
        diff_aggregate['unchanged .'] += c['.']
        diff_aggregate['unchanged ,'] += c[',']
        sink2() 
      elif change[0] == ADD:
        cur_add += change[1]
      elif change[0] == REMOVE:
        cur_remove += change[1]
      else:
        raise Exception("Unknown format")
      
  sink2()

  return diff_aggregate
 
text = "кек"
calculate_diff(text, await infer_optimal(params, text))

defaultdict(<function __main__.calculate_diff.<locals>.<lambda>()>,
            {'unchanged .': 0, 'unchanged ,': 0, 'add .': 1})

In [66]:
# text = txt
# text_res = await infer_optimal(params, text)
# print(text)
# print("="* 50)
# print(text_res)
# diff = calculate_diff2(text, text_res)
# diff

In [2]:
from collections import defaultdict
text = "Тест. тест, тест. Тест"
text_res = await infer_optimal(params, text)

def calculate_diff2(text, text_res):
  res = defaultdict(lambda: 0)

  def is_punctuation(c):
      return c in ".,"

  def sink_add(c):
    nonlocal res
    res['added ' + c] += 1

  def sink_remove(c):
    nonlocal res
    res['removed ' + c] += 1

  def sink_change(c1, c2):
    nonlocal res
    res['changed ' + c1 + " with " + c2] += 1

  i = 0
  j = 0
  while True:
      if i >= len(text): break
      if j >= len(text_res): break
      # print(text[i], text_res[j])
      if text[i] == text_res[j]:
          if is_punctuation(text[i]):
             res['not changed ' + text[i]] += 1
          i += 1
          j += 1
          continue
      
      if is_punctuation(text[i]) and is_punctuation(text_res[j]):
        sink_change(text[i], text_res[j])
        i += 1
        j += 1
        continue
      
      if is_punctuation(text[i]):
        sink_remove(text[i])
        i += 1
        continue
      
      if is_punctuation(text_res[j]):
        sink_add(text_res[j])
        j += 1
        continue
      
      raise Exception("Change not in punctuation", text[i], text_res[j], "at ", i, j)

  while i < len(text):
    # print("remaining: ", text[i])
    assert is_punctuation(text[i])
    sink_remove(text[i])
    i += 1

  while j < len(text_res):
    # print("remaining(2): ",text_res[j])
    assert is_punctuation(text_res[j])
    sink_add(text_res[j])
    j += 1

  res['possible punctuation places'] = len(list(tokenize(text)))

  return res

calculate_diff2(text, text_res)

defaultdict(<function __main__.calculate_diff2.<locals>.<lambda>()>,
            {'changed . with ,': 1,
             'not changed ,': 1,
             'not changed .': 1,
             'added .': 1,
             'possible punctuation places': 7})

In [5]:
text, text_res

('Тест. тест, тест. Тест', 'Тест, тест, тест. Тест.')

In [3]:
import glob
from striprtf.striprtf import rtf_to_text
from tqdm.notebook import tqdm

def dicts_sum(dict1, dict2):
  for key in dict2:
    dict1[key] += dict2[key]
  return dict1

res = defaultdict(lambda: 0)

i = 0
for rtf_path in tqdm(glob.glob("../validation/Mark Tven/Mark Tven rtf/*.rtf")):
  with open(rtf_path, "rb") as rtf_file:
    encoded = rtf_file.read()
    try:
      rtf = encoded.decode('cp1251')
      txt = rtf_to_text(rtf)
      diff = calculate_diff2(txt, await infer_optimal(params, txt))
      res = dicts_sum(res, diff)
    except Exception as ex:
      print("skipped ", rtf_path, len(encoded), ex)
      raise
    i += 1
    # if i> 2: break

res

  0%|          | 0/148 [00:00<?, ?it/s]

skipped  ../validation/Mark Tven/Mark Tven rtf/Tven_Iz_zapisnyih_knizhek_1865-1905.131699.rtf 125718 ('Change not in punctuation', ' ', '\n', 'at ', 23961, 24102)


Exception: ('Change not in punctuation', ' ', '\n', 'at ', 23961, 24102)

In [14]:
res

defaultdict(<function __main__.<lambda>()>, {})