In [1]:
import os
import conllu
from collections import Counter, OrderedDict
from estnltk.converters.conll_importer import conll_to_text

In [2]:
import sys
sys.path.append('../')
from src.syntax_sketch import clean_clause
from src.syntax_sketch import syntax_sketch

In [3]:
# Load the syntax trees into the 'ud_syntax' layer, then tag the 'clauses' layer on top.
text = conll_to_text('./data/syntax-trees.conllu', 'ud_syntax')
text = text.tag_layer('clauses')
# Guard against silent changes to the test data file.
assert 3214 == len(text.sentences), "Unexpected change of a test data"
display(text.sentences[:3])
display(text.clauses[0])

layer name,attributes,parent,enveloping,ambiguous,span count
sentences,,,words,False,3

text
"['Palju', 'olulisi', 'komponente', ',', 'nagu', 'liha', 'ja', 'kala', ',', 'hang ..., type: <class 'list'>, length: 13"
"['Loomulikult', 'kuuluvad', 'meie', 'kohalikku', 'ostusedelisse', 'ka', 'aedviljad', '.']"
"['Meie', 'peremehe', 'Gesualdo', 'Nava', 'arvates', 'saab', 'hea', 'roa', 'ka', ..., type: <class 'list'>, length: 20"


text,clause_type
"Palju olulisi komponente ,",regular


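### I. Old implementation

Reference implementation of clause syntax sketches. The sketches it produces serve as the expected values in the regression test below.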

In [4]:
def subtree_len(heads, ids, head_idx):
    """Return the number of tokens in the subtree rooted at the token with id ``head_idx``.

    ``heads`` and ``ids`` are parallel lists: ``heads[k]`` is the head id of the
    token whose id is ``ids[k]``. The count includes the root token itself, so a
    token with no dependents yields 1.
    """
    # Every position whose head is `head_idx` is a direct dependent; recurse into it
    # by its own id, then add 1 for the root itself.
    return 1 + sum(
        subtree_len(heads, ids, ids[position])
        for position, parent in enumerate(heads)
        if parent == head_idx
    )

def sketch_sorted(roots, subtrees):
    """Build a sketch string: ``[r1,r2,...]`` followed by all subtree strings.

    NOTE(review): despite its name, this function does NOT sort ``subtrees``. They
    are appended in the order given. ``sketch_unsorted`` is the variant that sorts,
    so the two names look swapped. They are kept as-is because the sketch-extraction
    cell calls ``sketch_unsorted`` and relies on its sorting.

    Args:
        roots: strings joined with ',' inside the square brackets.
        subtrees: strings appended after the brackets, in their given order.

    Returns:
        The concatenated sketch string.
    """
    # ''.join avoids building the result with repeated += in a loop.
    return '[{}]'.format(','.join(roots)) + ''.join(subtrees)

def sketch_unsorted(roots, subtrees):
    """Build a sketch string: ``[r1,r2,...]`` followed by the subtree strings in sorted order.

    NOTE(review): despite its name, this function DOES sort ``subtrees``. The order-
    preserving variant is ``sketch_sorted``, so the two names look swapped. The name is
    kept because the sketch-extraction cell calls ``sketch_unsorted`` and relies on the
    sorting. Sorting makes the sketch independent of dependent order.

    Args:
        roots: strings joined with ',' inside the square brackets.
        subtrees: strings appended after the brackets in ascending sorted order.

    Returns:
        The concatenated sketch string.
    """
    # ''.join avoids building the result with repeated += in a loop.
    return '[{}]'.format(','.join(roots)) + ''.join(sorted(subtrees))

In [5]:
# Reference sketches: one entry per clause that is non-empty after trimming and has
# exactly one root. A sketch is "[R]" followed by the sorted "deprel(SIZE)" entries of the
# root's direct dependents, where:
#   R    is 'V' if the root's xpostag is 'V', 'S' if it is one of S/P/A/Y/N, else 'X';
#   SIZE is 'L' (subtree < 3 tokens), 'P' (< 10 tokens) or 'ÜP' (10+ tokens).
sketch_counter = Counter()  # NOTE(review): unused in this cell; kept in case other cells rely on it
sketches = list()

assert text.layers == {
    'clauses', 'compound_tokens', 'morph_analysis',
    'sentences', 'tokens', 'ud_syntax','words'}, "Unexpected layers in the test data"
for clause in text.clauses:
    deprels = list(clause.ud_syntax.deprel)
    heads = list(clause.ud_syntax.head)
    ids = list(clause.ud_syntax.id)
    pos = list(clause.ud_syntax.xpostag)

    # Remove punctuation and conjunctions (xpostag containing 'Z' or 'J')
    # from the start and the end of the clause.
    start = 0
    while start < len(pos) and ('J' in pos[start] or 'Z' in pos[start]):
        start += 1
    if start == len(pos):
        # Nothing left after trimming: no sketch for this clause.
        continue
    # pos[start] is neither 'J' nor 'Z', so this loop always stops at or after `start`.
    end = len(pos)
    while 'J' in pos[end - 1] or 'Z' in pos[end - 1]:
        end -= 1
    deprels, heads, ids, pos = deprels[start:end], heads[start:end], ids[start:end], pos[start:end]

    # Find the root positions: tokens whose head lies outside the trimmed clause.
    root_ids = [i for i, head in enumerate(heads) if head not in ids]

    # Only single-rooted clauses get a sketch (same criterion as the `!= 1` check in the
    # regression test). Zero-root clauses are skipped too instead of failing below.
    if len(root_ids) != 1:
        continue

    root_loc = root_ids[0]
    root = ids[root_loc]

    # Describe every direct dependent of the root by its deprel and subtree size class.
    first_level = list()
    for i, head in enumerate(heads):
        if head == root:
            length = subtree_len(heads, ids, ids[i])
            if length < 3:
                subtree_cat = 'L'
            elif length < 10:
                subtree_cat = 'P'
            else:
                subtree_cat = 'ÜP'
            first_level.append(deprels[i] + '({})'.format(subtree_cat))

    # Collapse the root's xpostag into one of 'V', 'S' or 'X'.
    root_tag = pos[root_loc]
    if root_tag == 'V':
        sketch_root = 'V'
    elif root_tag in ['S', 'P', 'A', 'Y', 'N']:
        sketch_root = 'S'
    else:
        sketch_root = 'X'

    # `sketch_unsorted` sorts the dependent entries (its name is misleading).
    sketches.append(sketch_unsorted([sketch_root], first_level))

assert len(sketches) == 6036, "Unexpected  number of extracted sketches"

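### II. Regression test

Checks that `clean_clause` and `syntax_sketch` from `src.syntax_sketch` reproduce the reference sketches for every clause that has exactly one root.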

In [6]:
# Compare the library implementation (`clean_clause` + `syntax_sketch`) with the
# reference sketches computed above. Both implementations must accept the same
# clauses in the same order, so `i` counts valid clauses and indexes `sketches`.
i = 0
invalid_clauses = 0
for clause in text.clauses:
    cleaned_clause = clean_clause(clause)
    # Clauses without exactly one root have no reference sketch.
    if len(cleaned_clause['root_loc']) != 1:
        invalid_clauses += 1
        continue

    if i >= len(sketches):
        raise AssertionError("New implementation accepts more clauses than the old one")

    new_sketch = syntax_sketch(cleaned_clause)
    if new_sketch != sketches[i]:
        print(clause.text)
        print(cleaned_clause['root_loc'])
        print(new_sketch)
        print(sketches[i])
        # Explicit raise: an `assert False` would be stripped under `python -O`.
        raise AssertionError("Implementations differ")
    i += 1
print('Valid clauses:   {}'.format(i))
print('Invalid clauses: {}'.format(invalid_clauses))
# Every reference sketch must have been matched; otherwise the new implementation
# rejected clauses that the old one accepted.
assert i == len(sketches), "New implementation accepts fewer clauses than the old one"

Valid clauses:   6036
Invalid clauses: 126
