In [1]:
import os
import conllu
from collections import Counter, OrderedDict
from estnltk.converters.conll_importer import conll_to_text

In [2]:
import sys
sys.path.append('../')
from src.syntax_sketch import clean_clause
from src.syntax_sketch import syntax_sketch
from src.clause_export import export_cleaned_clause

In [None]:
# Load the CoNLL-U test corpus into an estnltk Text as the 'ud_syntax' layer
# and tag a 'clauses' layer on top of it.
text = (
    conll_to_text('./data/syntax-trees.conllu', 'ud_syntax')
    .tag_layer('clauses')
)
# Guard against the fixture changing underneath the regression test.
assert len(text.sentences) == 3214, "Unexpected change of a test data"
# Peek at the data: the first three sentences, then the first clause.
display(text.sentences[:3], text.clauses[0])

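### I. Old implementation

Reference versions of the sketch builder (`subtree_len`, `sketch_unsorted`, `create_sketch`) and of the CoNLL-U exporter (`export_clause`). They are kept so that the refactored code imported from `src/` can be compared against them; section II compares `export_clause` with `export_cleaned_clause`.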

In [None]:
def subtree_len(heads, ids, head_idx):
    """Count the tokens in the subtree headed by the token with id ``head_idx``.

    ``heads[k]`` is the head id of the token ``ids[k]``. The head token itself
    is included in the count, so a leaf yields 1. The recursion has no cycle
    detection: a cycle in ``heads`` reachable from ``head_idx`` recurses
    until Python's recursion limit is hit.
    """
    if head_idx not in heads:
        # Nothing depends on head_idx: it is a leaf.
        return 1

    size = 1  # the head token itself
    for position, parent in enumerate(heads):
        if parent == head_idx:
            size += subtree_len(heads, ids, ids[position])
    return size

def sketch_unsorted(roots, subtrees):
    """Render a sketch string: ``[root1,root2,...]`` followed by the subtrees.

    Despite the name, the subtree strings ARE sorted before being
    concatenated, so the result does not depend on the input order.
    """
    return '[' + ','.join(roots) + ']' + ''.join(sorted(subtrees))

def create_sketch(ids, heads, deprels, pos):
    """Build the old-style syntax sketch of a single-rooted dependency tree.

    A token is a root when its head is not one of ``ids``. Each direct
    dependent of the root contributes ``deprel(CAT)``. CAT depends on the size
    of the dependent's subtree as computed by ``subtree_len``:
    'L' for fewer than 3 tokens, 'P' for fewer than 10, 'ÜP' otherwise.
    The root is rendered as 'V' if its POS tag is 'V', as 'S' for
    'S'/'P'/'A'/'Y'/'N', and as 'X' for anything else.

    Args:
        ids: token ids of the tree.
        heads: head id of each token, parallel to ``ids``.
        deprels: dependency relation of each token, parallel to ``ids``.
        pos: POS tag of each token, parallel to ``ids``.

    Returns:
        The sketch string produced by ``sketch_unsorted``, or None when the
        tree does not have exactly one root. That covers several roots, no
        tokens at all, and a cycle with no head outside ``ids``.
    """
    root_positions = [i for i, head in enumerate(heads) if head not in ids]
    # Zero roots previously fell through to root_positions[0] -> IndexError.
    if len(root_positions) != 1:
        return None

    root_position = root_positions[0]
    root = ids[root_position]
    root_pos = pos[root_position]

    first_level = []
    for i, head in enumerate(heads):
        if head != root:
            continue
        length = subtree_len(heads, ids, ids[i])
        if length < 3:
            subtree_cat = 'L'
        elif length < 10:
            subtree_cat = 'P'
        else:
            subtree_cat = 'ÜP'
        first_level.append(deprels[i] + '({})'.format(subtree_cat))

    if root_pos == 'V':
        sketch_root = root_pos
    elif root_pos in ('S', 'P', 'A', 'Y', 'N'):
        sketch_root = 'S'
    else:
        sketch_root = 'X'
    return sketch_unsorted([sketch_root], first_level)
    

In [None]:
def export_clause(clause):
    """Export a clause's ``ud_syntax`` tokens as CoNLL-U lines (old implementation).

    Tokens at the start and at the end of the clause whose xpostag contains
    'J' or 'Z' are stripped (original Estonian note: remove punctuation and
    conjunctions from the beginning and end of the clause). The remaining
    tokens are renumbered consecutively from 1. Every token whose head is not
    among the remaining tokens becomes a root (head 0, deprel 'root').

    Args:
        clause: object whose ``ud_syntax`` attribute provides parallel
            ``text``, ``lemma``, ``feats``, ``deprel``, ``head``, ``id`` and
            ``xpostag`` sequences.

    Returns:
        A string with one tab-separated 10-column line per kept token, each
        ending in a newline. The xpostag fills both POS columns. Returns None
        when every token is stripped.
    """
    syntax = clause.ud_syntax
    wordforms = list(syntax.text)
    lemmas = list(syntax.lemma)
    feats = list(syntax.feats)
    deprels = list(syntax.deprel)
    heads = list(syntax.head)
    ids = list(syntax.id)
    pos = list(syntax.xpostag)

    def is_edge_token(tag):
        # Tags containing 'J' or 'Z' are trimmed from both clause edges.
        return 'J' in tag or 'Z' in tag

    start = 0
    while start < len(pos) and is_edge_token(pos[start]):
        start += 1
    if start == len(pos):
        return None

    stop = len(pos)
    # pos[start] is not an edge token, so this loop stops at start at the latest.
    while is_edge_token(pos[stop - 1]):
        stop -= 1

    wordforms, lemmas, feats, deprels, heads, ids, pos = [
        column[start:stop]
        for column in (wordforms, lemmas, feats, deprels, heads, ids, pos)
    ]

    assert len(wordforms) == len(lemmas) == len(feats) == len(ids) == len(deprels) == len(heads) == len(pos)

    # Original token id -> consecutive 1-based id inside the clause.
    id_map = {old_id: new_id for new_id, old_id in enumerate(ids, start=1)}

    rows = []
    for old_id, form, lemma, postag, feat, deprel, head in zip(ids, wordforms, lemmas, pos, feats, deprels, heads):
        if head in ids:
            new_head = id_map[head]
        else:
            # The head lies outside the kept tokens: re-attach as the root.
            new_head, deprel = 0, 'root'
        # NOTE(review): each element of `feat` is joined with ITSELF ("x=x").
        # If `feat` is a feature->value mapping, this writes "Case=Case"
        # instead of the CoNLL-U "Case=Nom". It is kept verbatim because the
        # regression test compares this output byte-for-byte with
        # export_cleaned_clause. Confirm before fixing either side.
        feat_text = '|'.join(name + '=' + name for name in feat) if feat else '_'
        fields = (id_map[old_id], form, lemma, postag, postag,
                  feat_text, new_head, deprel, '_', '_')
        rows.append('\t'.join(map('{}'.format, fields)) + '\n')

    return ''.join(rows)

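### II. Regression test

Every clause that `clean_clause` reduces to exactly one root must be exported identically by the old `export_clause` and the new `export_cleaned_clause`. Clauses with zero or several roots are only counted.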

In [None]:
# Compare the old exporter with the refactored one on every single-rooted clause.
# An explicit `raise` is used instead of `assert False`: asserts are stripped
# under `python -O`, which previously let the loop silently `break` on a mismatch.
valid_clauses = 0
invalid_clauses = 0
for clause in text.clauses:
    cleaned_clause = clean_clause(clause)
    # Only clauses with exactly one root are compared; the rest are just counted.
    if len(cleaned_clause['root_loc']) != 1:
        invalid_clauses += 1
        continue

    # Compute each export once and reuse it for both the comparison and the report.
    new_export = export_cleaned_clause(cleaned_clause) + '\n'
    old_export = export_clause(clause)
    if new_export != old_export:
        print(clause.text)
        print(new_export)
        print('---------------------------')
        print(old_export)
        print('---------------------------')
        raise AssertionError("Implementations differ")
    valid_clauses += 1
print('Valid clauses:   {}'.format(valid_clauses))
print('Invalid clauses: {}'.format(invalid_clauses))