## RuCor to CoNLL-U

In [1]:
from corpuscula import Conllu
import csv
import difflib
import os
import pandas as pd
#import random
import re
#import textdistance
from toxine import TextPreprocessor
#from uuid import uuid4

#text_dist = textdistance.JaroWinkler().distance

#random.seed(42)  # for uuid

# Dataset locations: RuCor sources and the CoNLL-U output tree.
dataset_dir = '../_dataset'
rucor_dir, conllu_dir = (os.path.join(dataset_dir, sub_dir)
                         for sub_dir in ('rucoref', 'rucoref_conllu'))

# workaround for *mordl*: models should be placed in the ../_models/ directory
models_dir, models_root_dir = '..', '_models'
# if you don't have cdict, set cdict_path = None
cdict_path = os.path.join(models_dir, models_root_dir,
                          'upos-bert_model/cdict.pickle')

# log file written by the batch conversion cell
log_fn = os.path.join(dataset_dir, 'out.log')

In [None]:
# When True, mordl taggers add UPOS/FEATS to the CoNLL-U output and every
# NOUN/PRON/PROPN token is marked as a coreference head candidate.
TAG_COREF_HEADS = False

if TAG_COREF_HEADS:
    from mordl import UposTagger, FeatsTagger, LemmaTagger

    # The models are loaded with models_dir as the working directory. Always
    # return to the original directory, even if loading a model fails;
    # otherwise every later relative path in the notebook would break.
    _work_dir = os.path.abspath(os.getcwd())
    os.chdir(models_dir)
    try:
        tagger_u = UposTagger()
        tagger_u.load(os.path.join(models_root_dir, 'upos-bert_model'), device='cuda:0', dataset_device='cuda:0')
        tagger_f = FeatsTagger()
        tagger_f.load(os.path.join(models_root_dir, 'feats-bert_model'), device='cuda:0', dataset_device='cuda:0')
        #tagger_l = LemmaTagger()
        #tagger_l.load(os.path.join(models_root_dir, 'lemma-ft_model'), device='cuda:0', dataset_device='cuda:0')
    finally:
        os.chdir(_work_dir)
tp = TextPreprocessor(cdict_restore_from=cdict_path)

Fit corpus dict... 

In [None]:
# RuCor annotation tables: tab-separated, read without quote processing.
# (Tokens.txt only gets a path here; it is not loaded.)
rucor_docs_fn, rucor_groups_fn, rucor_tokens_fn = (
    os.path.join(rucor_dir, table_fn)
    for table_fn in ('Documents.txt', 'Groups.txt', 'Tokens.txt')
)

docs, groups = (
    pd.read_csv(table_fn, sep='\t', index_col=index_col,
                quoting=csv.QUOTE_NONE)
    for table_fn, index_col in ((rucor_docs_fn, 'doc_id'),
                                (rucor_groups_fn, 'group_id'))
)

In [None]:
# DOC_ID: doc_id of a single RuCor document to walk through in the cells
# guarded by `if DOC_ID:` below; 0 turns those cells off.
# NEED_SHIFT_ADJUST: adjust correct shifts to the wrong ones in RuCor.
DOC_ID, NEED_SHIFT_ADJUST = 0, True

In [None]:
# makedirs: also creates missing parents and tolerates an existing directory.
os.makedirs(conllu_dir, mode=0o755, exist_ok=True)


def get_fns(doc_id, conllu_dir=None):
    """Return ``(in_fn, out_fn)`` for the RuCor document *doc_id*.

    in_fn is ``<rucor_dir>/rucoref_texts/<docs.loc[doc_id, 'path']>``.
    out_fn is a ``.conllu`` file with the same base name. It is placed under
    *conllu_dir*, mirroring the text's sub-directory (created if missing),
    or directly in ``dataset_dir`` when *conllu_dir* is falsy.
    """
    rel_fn = docs.loc[doc_id, 'path']
    in_fn = os.path.join(rucor_dir, 'rucoref_texts', rel_fn)
    if conllu_dir:
        out_dir = os.path.join(conllu_dir, os.path.dirname(rel_fn))
        # the relative path may contain several levels: os.mkdir would fail
        os.makedirs(out_dir, mode=0o755, exist_ok=True)
    else:
        out_dir = dataset_dir
    base_name = os.path.splitext(os.path.basename(in_fn))[0]
    out_fn = os.path.join(out_dir, base_name + '.conllu')
    return in_fn, out_fn

if DOC_ID:
    in_fn, out_fn = get_fns(DOC_ID)
    print(in_fn, out_fn, sep=', ')

In [None]:
def norm_punct(punct):
    """Map punctuation variants to a small canonical set.

    Replacements are applied one after another, in this order, so a later
    rule sees the output of the earlier ones. For example, ';;;' first becomes
    '...' and then '.'.
    """
    for variant, canonical in (('—', '-'), (';', '.'), ('...', '.'),
                               ('…', '.'), ('«', '"'), ('„', '"'),
                               ('»', '"'), ('“', '"'), ('``', '"'),
                               ("''", '"')):
        punct = punct.replace(variant, canonical)
    return punct

In [None]:
def get_raw(in_fn):
    """Read *in_fn* and split it into alphanumeric runs and punctuation marks.

    The file is decoded as UTF-8 (a leading BOM is dropped) and lower-cased.
    Named HTML entities (``&name;``) are replaced with ``tp._unescape_html``
    and left-padded with spaces up to the entity's length. As long as the
    replacement is not longer than the entity, later offsets do not move.
    All offsets refer to positions in this decoded, lower-cased text.

    Returns:
        A pair ``(forms, puncts)`` of lists of tuples:
        forms -- ``(run, offset)`` for every maximal run of ``str.isalnum``
        characters;
        puncts -- ``(norm_punct(char), offset)`` for every other
        non-whitespace character.
    """
    with open(in_fn, 'rb') as f:
        text = f.read().decode('utf-8-sig').lower()

    def unescape_keep_width(m):
        entity = m.group(0)
        plain = tp._unescape_html(entity)
        return ' ' * (len(entity) - len(plain)) + plain

    text = re.sub('&[a-z]+;', unescape_keep_width, text)

    forms, puncts = [], []
    run_start = None  # offset where the current alnum run began
    for pos, ch in enumerate(text):
        if ch.isalnum():
            if run_start is None:
                run_start = pos
            continue
        if run_start is not None:
            forms.append((text[run_start:pos], run_start))
            run_start = None
        if not ch.isspace():
            puncts.append((norm_punct(ch), pos))
    if run_start is not None:
        forms.append((text[run_start:], run_start))
    return forms, puncts

if DOC_ID:
    # Per-document run: split the selected text into (word, offset) and
    # (punct, offset) lists.
    raw_corpus, raw_puncts = get_raw(in_fn)

In [None]:
# Adjust correct shifts to the wrong ones in RuCor. For each listed document:
# (threshold, delta) pairs. A raw offset >= threshold is increased by the delta
# of the largest threshold it reaches.
RUCOR_SHIFT_ADJUSTMENTS = {
    115: ((1288, 2), (771, 1)),
    116: ((1545, 7), (884, 6), (858, 5), (394, 4), (388, 3), (386, 2),
          (165, 1)),
}


def adjust_raw_corpus(doc_id, raw_corpus, adjustments=None):
    """Shift the offsets in *raw_corpus* in place to match RuCor's offsets.

    Args:
        doc_id: RuCor document id. Documents without an entry are untouched.
        raw_corpus: list of ``(text, offset)`` pairs (e.g. from ``get_raw``);
            modified in place.
        adjustments: mapping ``doc_id -> iterable of (threshold, delta)``.
            Defaults to ``RUCOR_SHIFT_ADJUSTMENTS``. An offset gets the delta
            of the largest threshold it is ``>=``; offsets below every
            threshold stay unchanged.
    """
    if adjustments is None:
        adjustments = RUCOR_SHIFT_ADJUSTMENTS
    table = adjustments.get(doc_id)
    if not table:
        return  # nothing to adjust: don't walk the whole list
    # largest threshold first, so the first hit is the applicable one
    table = sorted(table, reverse=True)
    for pos, (text, offset) in enumerate(raw_corpus):
        for threshold, delta in table:
            if offset >= threshold:
                raw_corpus[pos] = text, offset + delta
                break

if DOC_ID and NEED_SHIFT_ADJUST:
    # Apply the same per-document offset correction to words and punctuation.
    adjust_raw_corpus(DOC_ID, raw_corpus)
    adjust_raw_corpus(DOC_ID, raw_puncts)

In [None]:
if DOC_ID:
    # Reset the toxine preprocessor and load the selected text into it
    # (eop=r'\n' presumably makes every line a paragraph; confirm in toxine).
    tp.clear_corpus()
    tp.load_pars(in_fn, eop=r'\n')

In [None]:
if DOC_ID:
    # Run toxine's processing pipeline on the loaded text.
    tp.do_all(tag_date=False, norm_punct=True)

In [None]:
if DOC_ID:
    # Write the CoNLL-U file. With TAG_COREF_HEADS it goes to out_fn + '$',
    # an intermediate file that the taggers read and later cells delete.
    _ = tp.save(out_fn + ('$' if TAG_COREF_HEADS else ''))

In [None]:
if DOC_ID and TAG_COREF_HEADS:
    # UPOS tagging of the intermediate file, then FEATS tagging saved to out_fn.
    tagger_f.predict(tagger_u.predict(out_fn + '$'), save_to=out_fn)

In [None]:
if DOC_ID and TAG_COREF_HEADS:
    # The intermediate untagged file is no longer needed.
    os.remove(out_fn + '$')

In [None]:
# Runs of non-alphanumeric characters. '_' is a word character for \w, so it
# is added explicitly; this keeps the split in line with str.isalnum(), which
# get_raw() uses.
re_nonalnum = re.compile(r'(?:\W|_)+')
# Runs of alphanumeric characters (word characters except '_').
re_alnum = re.compile(r'[^\W_]+')

if DOC_ID:
    # Parsed CoNLL-U sentences; corpus is filled by make_corpus() below.
    corpus_orig = list(Conllu.load(out_fn))
    corpus = []

In [None]:
def make_corpus():
    """Append one alignment record per token of ``corpus_orig`` to ``corpus``.

    Each record is ``(alnum_form, misc, upos, punct)``:
    alnum_form -- lower-cased form with all non-alphanumerics removed;
    misc -- the token's MISC object itself (not a copy);
    upos -- the token's UPOS;
    punct -- the form with all alphanumerics removed.
    If MISC has a key equal to ``tp.TAG_SHORTCUT[2:]`` or to one of the
    ``tp.TAG_MASKS`` keys minus its first character, that MISC value is used
    as the form instead of FORM (the last such key wins). These presumably
    hold the original text of tokens toxine replaced; confirm in toxine.
    """
    shortcut_key = tp.TAG_SHORTCUT[2:]
    mask_keys = {tag[1:] for tag in tp.TAG_MASKS.keys()}
    for sentence in corpus_orig:
        for token in sentence[0]:
            misc = token['MISC']
            form = token['FORM']
            for key in misc:
                if key == shortcut_key or key in mask_keys:
                    form = misc[key]
            corpus.append((
                re_nonalnum.sub('', form.lower()),
                misc,
                token['UPOS'],
                re_alnum.sub('', form),
            ))

if DOC_ID:
    # Build the alignment records for the loaded document.
    make_corpus()

In [None]:
def validate_corpus():
    """Check that the tokenized and the raw text contain the same letters.

    Concatenates the alphanumeric forms of the global ``corpus`` (stripped
    once more with ``re_nonalnum``) and those of the global ``raw_corpus``.
    If the two strings differ, prints both and raises ValueError.
    """
    joined_tokens = re_nonalnum.sub('', ''.join(rec[0] for rec in corpus))
    joined_raw = ''.join(rec[0] for rec in raw_corpus)
    if joined_tokens == joined_raw:
        return
    for title, text in (('                CORPUS_:', joined_tokens),
                        ('                RAW_CORPUS_:', joined_raw)):
        print(title)
        print(text)
    raise ValueError('The corpus is not the same as the raw corpus!')

if DOC_ID:
    # Abort early if tokenization lost or added alphanumeric characters.
    validate_corpus()

In [None]:
def process_corpus(corpus, raw_corpus):
    """Align the word tokens of *corpus* with *raw_corpus* and record offsets.

    *corpus* holds ``(alnum_form, misc, upos, punct)`` records (as built by
    ``make_corpus``); *raw_corpus* holds ``(form, offset)`` pairs (as built by
    ``get_raw``). Both lists are walked in parallel. The first corpus token of
    every aligned group gets ``misc['Shift'] = str(offset)`` of the group's
    first raw form. If ``TAG_COREF_HEADS`` is true, every NOUN/PRON/PROPN
    token also gets ``misc['Coref_<n>'] = 'Head'``, where n is a running
    counter starting at 1.

    Raises:
        ValueError: if a group of corpus forms does not spell the same string
            as the corresponding group of raw forms.
    """

    len_corpus, len_raw_corpus = len(corpus), len(raw_corpus)

    def find_next(i, j):
        # Align corpus[i] with raw_corpus[j]; return the indices just after
        # the aligned group in both lists.
        form, misc = corpus[i][:2]
        if not form:
            # pure punctuation token: skip it without consuming a raw form
            return i + 1, j

        raw_form, raw_form_idx = raw_corpus[j]
        form_, raw_form_ = form, raw_form
        len_form, len_raw_form = len(form), len(raw_form)

        misc['Shift'] = str(raw_form_idx)

        i_, j_ = i + 1, j + 1
        if len_form < len_raw_form:
            # one raw form covers several corpus tokens: glue corpus forms
            # until they reach the raw form's length
            while i_ < len_corpus and len(form_) < len_raw_form:
                form_ += corpus[i_][0]
                i_ += 1
            form = form_
        elif len_form > len_raw_form:
            # one corpus token covers several raw forms: glue raw forms
            while j_ < len_raw_corpus and len(raw_form_) < len_form:
                raw_form_ += raw_corpus[j_][0]
                j_ += 1
            raw_form = raw_form_

        if form != raw_form:
            raise ValueError('form [{}] is not equal to raw_form [{}]!'.format(form, raw_form))

        return i_, j_,

    # running mention counter (a dict so the closure can update it)
    mid_ = {'mid': 0}
    def get_mention_id():
        #mid = uuid.uuid4()
        mid_['mid'] += 1
        mid = mid_['mid']
        return str(mid)

    i = j = 0
    while i < len_corpus:
        i_, j_ = find_next(i, j)
        # every token of the aligned group is a head candidate by UPOS
        for ii in range(i, i_):  # TODO
            form, misc, upos = corpus[ii][:3]
            if TAG_COREF_HEADS and upos in ['NOUN', 'PRON', 'PROPN']:
                misc['Coref_' + get_mention_id()] = 'Head'
        i, j = i_, j_

if DOC_ID:
    # Assign raw-text offsets ('Shift') to word tokens.
    process_corpus(corpus, raw_corpus)

In [None]:
def process_puncts(corpus, raw_puncts):
    """Copy raw-text offsets onto the punctuation tokens of *corpus*.

    Word tokens must already carry ``misc['Shift']`` (see process_corpus).
    The document is processed one punctuation run at a time; a run is a
    sequence of records whose alphanumeric form is empty. Each run lies
    between two bounds:

    * start -- the offset of the last word with a Shift before the run
      (0 at the beginning);
    * stop -- the offset of the next word with a Shift after the run
      (end of text if there is none).

    The run is matched with ``difflib.SequenceMatcher`` against the raw marks
    whose offsets lie in ``[start, stop)``. Every matched token gets
    ``misc['Shift']``.

    Args:
        corpus: ``(alnum_form, misc, upos, punct)`` records (see make_corpus).
        raw_puncts: ``(normalized_char, offset)`` pairs (see get_raw).
    """
    sm = difflib.SequenceMatcher()
    len_corpus, len_raw_puncts = len(corpus), len(raw_puncts)

    def find_next(i, j):
        # Skip the word tokens at i; the last Shift seen is the left bound.
        start = 0
        while i < len_corpus and corpus[i][0]:
            shift = corpus[i][1].get('Shift')
            if shift:
                start = int(shift)
            i += 1

        # Collect punctuation up to the next word that has a Shift; that
        # word's offset is the right bound (None: run reaches end of text).
        puncts, miscs = [], []
        stop = None
        i_ = i
        while i_ < len_corpus:
            form, misc, _, punct = corpus[i_]
            if not form:
                puncts.append(norm_punct(punct))
                miscs.append(misc)
            else:
                shift = misc.get('Shift')
                if shift:
                    stop = int(shift)
                    break
            i_ += 1

        # Raw window [j, j_): offsets in [start, stop). The while loops stop
        # *past* the last candidate, so no element is dropped or wrongly
        # included at the ends of raw_puncts.
        while j < len_raw_puncts and raw_puncts[j][1] < start:
            j += 1
        j_ = j
        while j_ < len_raw_puncts and (stop is None
                                       or raw_puncts[j_][1] < stop):
            j_ += 1

        if puncts and j_ > j:
            raws, shifts = map(list, zip(*raw_puncts[j:j_]))
            # A raw '...' is three entries, but the corpus holds it as one
            # token: keep the first dot. range(len - 2) includes the last
            # window.
            for ir in range(len(raws) - 2):
                if raws[ir:ir + 3] == ['.'] * 3:
                    raws[ir + 1] = raws[ir + 2] = ''
            sm.set_seqs(puncts, raws)
            for a, b, size in sm.get_matching_blocks():
                for k in range(size):
                    miscs[a + k]['Shift'] = str(shifts[b + k])

        if stop is None:
            # no further word: this run was the document's tail
            return len_corpus, len_raw_puncts
        # i_ points at a word with a Shift and is always > i on entry,
        # so the outer loop makes progress
        return i_, j_

    i = j = 0
    while i < len_corpus:
        i, j = find_next(i, j)

if DOC_ID:
    # Assign raw-text offsets ('Shift') to punctuation tokens.
    process_puncts(corpus, raw_puncts)

In [None]:
if DOC_ID:
    # Debug view: each token's FORM (empty when None) and its 'Shift' from
    # MISC (blank if none was assigned).
    for sent in corpus_orig:
        for tok in sent[0]:
            print('{:20s}{}'.format(tok['FORM'] or '', tok['MISC'].get('Shift', '')))

In [None]:
if DOC_ID:
    # Index tokens by offset: shifts[Shift string] = (normalized FORM, MISC).
    # The MISC objects are the ones held in corpus_orig, so annotating them
    # changes what the save cell writes.
    shifts = {}
    for sent in corpus_orig:
        for tok in sent[0]:
            form, misc = tok['FORM'] or '', tok['MISC']
            shift = misc.get('Shift', '')
            if shift:
                shifts[shift] = (norm_punct(form), misc)

In [None]:
# Inspect the raw word list. Jupyter only displays a top-level final
# expression (a bare name inside an `if` body is silently discarded); the
# None of the else-branch is not displayed.
raw_corpus if DOC_ID else None

In [None]:
# Inspect the raw punctuation list. Jupyter only displays a top-level final
# expression (a bare name inside an `if` body is silently discarded); the
# None of the else-branch is not displayed.
raw_puncts if DOC_ID else None

In [None]:
if DOC_ID:
    # Transfer RuCor group annotation onto the tokens matched by offset.
    for group_id, chain_id, link, tks, tk_shifts, attrs, hd_shift in (
            groups[groups['doc_id'] == DOC_ID].reset_index()[
                ['group_id', 'chain_id', 'link', 'content',
                 'tk_shifts', 'attributes', 'hd_shifts']].values):
        tks = tks.split()
        tk_shifts = tk_shifts.split(',')
        # "name:value|name:value". Lists (not one-shot generators) can be
        # re-iterated; maxsplit=1 lets a value itself contain ':'.
        attrs = [] if pd.isnull(attrs) else \
                [x.split(':', 1) for x in attrs.split('|') if x]
        # groups is read without a dtype, so pandas may infer a numeric
        # hd_shifts column while tk_shifts are strings: compare as strings.
        if pd.isnull(hd_shift):
            hd_shift = None
        elif isinstance(hd_shift, float) and hd_shift.is_integer():
            hd_shift = str(int(hd_shift))
        else:
            hd_shift = str(hd_shift)
        if len(tks) != len(tk_shifts):
            raise ValueError('len({}) != len({})'.format(tks, tk_shifts))
        for tk, tk_shift in zip(tks, tk_shifts):
            form, misc = shifts.get(tk_shift, (None, None))
            if form is None:
                print('token {} ({}) is not found'.format(tk_shift, tk))
            elif norm_punct(form) != norm_punct(tk):
                print('token {}: {} != {}'.format(tk_shift, form, tk))
            else:
                misc['RuCor_group_id'] = str(group_id)
                misc['RuCor_chain_id'] = str(chain_id)
                # link id and attributes go only on the head token
                # (or on the only token of a one-token group)
                if len(tks) == 1 or hd_shift == tk_shift:
                    misc['RuCor_link_id'] = str(link)
                    for attr, val in attrs:
                        misc['RuCor_attrs_' + attr] = val

In [None]:
if DOC_ID:
    # Overwrite out_fn with the offset- and RuCor-annotated corpus.
    Conllu.save(corpus_orig, out_fn)

In [None]:
# Batch conversion: run the whole per-document pipeline for every document
# that has groups, logging problems to log_fn.
with open(log_fn, 'wt', encoding='utf-8') as f_log:

    def log(text=''):
        """Print *text* and also write it to the log file."""
        print(text)
        print(text, file=f_log)

    for doc_id in groups['doc_id'].unique():
        in_fn, out_fn = get_fns(doc_id, conllu_dir=conllu_dir)
        log('{}: {}'.format(doc_id, in_fn))

        raw_corpus, raw_puncts = get_raw(in_fn)
        if NEED_SHIFT_ADJUST:
            adjust_raw_corpus(doc_id, raw_corpus)
            adjust_raw_corpus(doc_id, raw_puncts)

        tp.clear_corpus()
        tp.load_pars(in_fn, eop=r'\n')
        tp.do_all(tag_date=False, norm_punct=True)
        tp.save(out_fn + ('$' if TAG_COREF_HEADS else ''))

        if TAG_COREF_HEADS:
            tagger_f.predict(tagger_u.predict(out_fn + '$'), save_to=out_fn)
            # Remove the intermediate file for every document, not only for
            # the last one as a removal after the loop would.
            os.remove(out_fn + '$')

        # module-level names: make_corpus/validate_corpus read these globals
        corpus_orig = list(Conllu.load(out_fn))
        corpus = []
        make_corpus()
        validate_corpus()
        process_corpus(corpus, raw_corpus)
        process_puncts(corpus, raw_puncts)

        # offset -> (normalized form, MISC shared with corpus_orig)
        shifts = {}
        for sent in corpus_orig:
            for tok in sent[0]:
                form, misc = tok['FORM'] or '', tok['MISC']
                shift = misc.get('Shift', '')
                if shift:
                    shifts[shift] = (norm_punct(form), misc)

        for group_id, chain_id, link, tks, tk_shifts, attrs, hd_shift in (
                groups[groups['doc_id'] == doc_id].reset_index()[
                    ['group_id', 'chain_id', 'link', 'content',
                     'tk_shifts', 'attributes', 'hd_shifts']].values):
            tks = tks.split()
            tk_shifts = tk_shifts.split(',')
            # "name:value|name:value". Lists (not one-shot generators) can be
            # re-iterated; maxsplit=1 lets a value itself contain ':'.
            attrs = [] if pd.isnull(attrs) else \
                    [x.split(':', 1) for x in attrs.split('|') if x]
            # groups is read without a dtype, so pandas may infer a numeric
            # hd_shifts column while tk_shifts are strings: compare as strings.
            if pd.isnull(hd_shift):
                hd_shift = None
            elif isinstance(hd_shift, float) and hd_shift.is_integer():
                hd_shift = str(int(hd_shift))
            else:
                hd_shift = str(hd_shift)
            if len(tks) != len(tk_shifts):
                raise ValueError('len({}) != len({})'.format(tks, tk_shifts))
            for tk, tk_shift in zip(tks, tk_shifts):
                form, misc = shifts.get(tk_shift, (None, None))
                # log() already prints; a separate print() would duplicate
                if form is None:
                    log('token {} ({}) is not found'.format(tk_shift, tk))
                elif norm_punct(form) != norm_punct(tk):
                    log('token {}: {} != {}'.format(tk_shift, form, tk))
                else:
                    misc['RuCor_group_id'] = str(group_id)
                    misc['RuCor_chain_id'] = str(chain_id)
                    if len(tks) == 1 or hd_shift == tk_shift:
                        misc['RuCor_link_id'] = str(link)
                        for attr, val in attrs:
                            misc['RuCor_attrs_' + attr] = val

        log()
        Conllu.save(corpus_orig, out_fn)