In [3]:
import os
import conllu
from random import Random, seed, choice, shuffle
from collections import Counter, OrderedDict
from estnltk.converters.conll_importer import conll_to_text

from pandas import read_csv 


from src.syntax_sketch import clean_clause
from src.syntax_sketch import syntax_sketch
from src.clause_export import export_cleaned_clause

### I. Load syntax analysis data and top 50 sketches

In [4]:
# Top-50 sketches with their support (columns `sketch` and `support`).
KNOCK_OUT_LIST_PATH = './experiments/knock_out_list.csv'
top50 = read_csv(KNOCK_OUT_LIST_PATH, index_col=0)
top50_preview = top50.head()
display(top50_preview)

Unnamed: 0,sketch,support
0,[S],960
1,[V]nsubj(L),792
2,[V],416
3,[V]nsubj(L)obl(L),372
4,[S]cop(L)nsubj:cop(L),348


### II. Create a special 50 x 50 test file

We need a test file to estimate the tagger performance on the top 50 sketches.
As the support of each sketch in the test set is rather low, we make sure there are at least 50 clauses for each sketch.
If the test set does not contain enough instances, we take the missing ones from the training set.
There are enough clauses in the training set, so this does not affect the overall performance.

In [None]:
# Write into the file test_panels/combined_panel.conllu
# intividual panels as well
# it can be done in memory in clean way
# we do it differently for ??
# compute sketch for each 
def save_sketches(text, remove_sketch, amount=0):

    removed_count = 0
    conllus = list()

    for clause in text.clauses:
        cleaned_clause = clean_clause
        
        # ??
        sketch = syntax_sketch(cleaned_clause)
        
        if sketch == remove_sketch and removed_count < amount:
            removed_count += 1
        
            conllu = export_cleaned_clause(cleaned_clause) + '\n\n'
            conllus.append(conllu)

    print('Saved {} instances of sketch {}'. format(removed_count, remove_sketch))
            
    return conllus, removed_count



In [None]:
# create test set
train_to_remove = list()
test_data = list()

for test_sketch, _ in top50: 
    
    data, saved = save_sketches(whole_data[0], test_sketch, amount = 10000)
    test_data.extend(data)
    
    if saved < 50:
        data_train, saved_train = save_sketches(whole_data[1], test_sketch, amount = 50-saved)
        test_data.extend(data_train)
        train_to_remove.extend(data_train)

print(len(train_to_remove))

In [None]:
# Koosta 50x50 test fail
"""with open('test_exp2.conllu', 'w', encoding='utf-8') as fout:
    fout.write(''.join(test_data))"""  --> combined_panel





In [None]:
whole_train, _ = remove_sketches(whole_data[1], '', amount = 0)

# Kogu treeningset, kus pole alles testi liikunud lauseid
pure_train = list()
for conllu_clause in whole_train:
    if conllu_clause in train_to_remove:
        continue
    pure_train.append(conllu_clause)
    
"""with open('pure_train_exp2.conllu', 'w', encoding='utf-8') as fout:
    fout.write(''.join(pure_train))"""

In [8]:
def save_sketches(text, remove_sketch, amount=0):

    removed_count = 0
    conllus = list()

    for clause in text.clauses:
        wordforms = list(clause.ud_syntax.text)
        lemmas = list(clause.ud_syntax.lemma)
        feats = list(clause.ud_syntax.feats)
        deprels = list(clause.ud_syntax.deprel)
        heads = list(clause.ud_syntax.head)
        ids = list(clause.ud_syntax.id)
        pos = list(clause.ud_syntax.xpostag)

        # punktuatsioon ja sidesõnad lause algusest-lõpust eemaldada
        while pos and ('J' in pos[0] or 'Z' in pos[0]):
            wordforms.pop(0)
            lemmas.pop(0)
            feats.pop(0)
            heads.pop(0)
            ids.pop(0)
            deprels.pop(0)
            pos.pop(0)

        if not pos:
            continue

        while 'J' in pos[-1] or 'Z' in pos[-1]:
            wordforms.pop()
            lemmas.pop()
            feats.pop()
            heads.pop()
            ids.pop()
            deprels.pop()
            pos.pop()
        
        assert len(wordforms) == len(lemmas) == len(feats) == len(ids) == len(deprels) == len(heads) == len(pos)
        
        sketch = create_sketch(ids, heads, deprels, pos)
        if not sketch:
            continue
        
        if sketch == remove_sketch and removed_count < amount:
            removed_count += 1
        
            conllu = ''

            id_map = dict()
            for i in range(1, len(ids)+1):
                id_map[ids[i-1]] = i


            for old_id, form, lemma, postag, feat, deprel, head in zip(ids, wordforms, lemmas, pos, feats, deprels, heads):
                new_id = id_map[old_id]
                if head not in ids:
                    head = 0
                    deprel = 'root'
                else:
                    head = id_map[head]

                if feat:
                    f = '|'.join([f + '=' + f for f in feat])
                else:
                    f = '_'


                conllu += '{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format(new_id,
                                                                           form,
                                                                           lemma,
                                                                           postag,
                                                                            postag,
                                                                           f,
                                                                            head,
                                                                            deprel,
                                                                            '_',
                                                                            '_')
            conllu += '\n'
            conllus.append(conllu)

    print('Saved {} instances of sketch {}'. format(removed_count, remove_sketch))
            
    return conllus, removed_count


def remove_sketches(text, remove_sketch, amount=0):

    removed_count = 0
    conllus = list()

    for clause in text.clauses:
        wordforms = list(clause.ud_syntax.text)
        lemmas = list(clause.ud_syntax.lemma)
        feats = list(clause.ud_syntax.feats)
        deprels = list(clause.ud_syntax.deprel)
        heads = list(clause.ud_syntax.head)
        ids = list(clause.ud_syntax.id)
        pos = list(clause.ud_syntax.xpostag)

        # punktuatsioon ja sidesõnad lause algusest-lõpust eemaldada
        while pos and ('J' in pos[0] or 'Z' in pos[0]):
            wordforms.pop(0)
            lemmas.pop(0)
            feats.pop(0)
            heads.pop(0)
            ids.pop(0)
            deprels.pop(0)
            pos.pop(0)

        if not pos:
            continue

        while 'J' in pos[-1] or 'Z' in pos[-1]:
            wordforms.pop()
            lemmas.pop()
            feats.pop()
            heads.pop()
            ids.pop()
            deprels.pop()
            pos.pop()
        
        assert len(wordforms) == len(lemmas) == len(feats) == len(ids) == len(deprels) == len(heads) == len(pos)
        
        sketch = create_sketch(ids, heads, deprels, pos)
        if not sketch:
            continue
        
        if sketch == remove_sketch and removed_count < amount:
            removed_count += 1
            continue
        
        conllu = ''

        id_map = dict()
        for i in range(1, len(ids)+1):
            id_map[ids[i-1]] = i


        for old_id, form, lemma, postag, feat, deprel, head in zip(ids, wordforms, lemmas, pos, feats, deprels, heads):
            new_id = id_map[old_id]
            if head not in ids:
                head = 0
                deprel = 'root'
            else:
                head = id_map[head]

            if feat:
                f = '|'.join([f + '=' + f for f in feat])
            else:
                f = '_'


            conllu += '{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format(new_id,
                                                                       form,
                                                                       lemma,
                                                                       postag,
                                                                        postag,
                                                                       f,
                                                                        head,
                                                                        deprel,
                                                                        '_',
                                                                        '_')
        conllu += '\n'
        conllus.append(conllu)
    
    print('Removed {} instances of sketch {}'. format(removed_count, remove_sketch))
    print('No of clauses:', len(conllus))
    return conllus, removed_count

## experiment 2

In [9]:
# create test set
train_to_remove = list()
test_data = list()

for test_sketch, _ in top50: 
    
    data, saved = save_sketches(whole_data[0], test_sketch, amount = 10000)
    test_data.extend(data)
    
    if saved < 50:
        data_train, saved_train = save_sketches(whole_data[1], test_sketch, amount = 50-saved)
        test_data.extend(data_train)
        train_to_remove.extend(data_train)

print(len(train_to_remove))

Saved 198 instances of sketch [V]nsubj(L)
Saved 240 instances of sketch [S]
Saved 93 instances of sketch [V]nsubj(L)obl(L)
Saved 104 instances of sketch [V]
Saved 87 instances of sketch [S]cop(L)nsubj:cop(L)
Saved 65 instances of sketch [V]nsubj(L)obj(L)
Saved 70 instances of sketch [V]advmod(L)nsubj(L)
Saved 66 instances of sketch [V]obj(L)
Saved 54 instances of sketch [V]obl(L)
Saved 45 instances of sketch [V]nsubj(L)obj(L)obl(L)
Saved 5 instances of sketch [V]nsubj(L)obj(L)obl(L)
Saved 35 instances of sketch [S]advmod(L)cop(L)nsubj:cop(L)
Saved 15 instances of sketch [S]advmod(L)cop(L)nsubj:cop(L)
Saved 53 instances of sketch [V]obj(L)obl(L)
Saved 37 instances of sketch [V]nsubj(L)obl(P)
Saved 13 instances of sketch [V]nsubj(L)obl(P)
Saved 52 instances of sketch [V]advmod(L)nsubj(L)obl(L)
Saved 51 instances of sketch [V]nsubj(P)
Saved 37 instances of sketch [S]nmod(L)
Saved 13 instances of sketch [S]nmod(L)
Saved 43 instances of sketch [V]nsubj(P)obl(L)
Saved 7 instances of sketch [

In [13]:
# Koosta 50x50 test fail
"""with open('test_exp2.conllu', 'w', encoding='utf-8') as fout:
    fout.write(''.join(test_data))"""


whole_train, _ = remove_sketches(whole_data[1], '', amount = 0)

# Kogu treeningset, kus pole alles testi liikunud lauseid
pure_train = list()
for conllu_clause in whole_train:
    if conllu_clause in train_to_remove:
        continue
    pure_train.append(conllu_clause)
    
"""with open('pure_train_exp2.conllu', 'w', encoding='utf-8') as fout:
    fout.write(''.join(pure_train))"""

Removed 0 instances of sketch 
No of clauses: 43980


"with open('pure_train_exp2.conllu', 'w', encoding='utf-8') as fout:\n    fout.write(''.join(pure_train))"

In [None]:
# Number of training clauses left once the clauses borrowed for the test set are excluded.
len(pure_train)

In [None]:
# create test set
train_to_remove = list()

for test_sketch, _ in top50:
    test_data = list()
    
    data, saved = save_sketches(whole_data[0], test_sketch, amount = 10000)
    test_data.extend(data)
    
    if saved < 50:
        data_train, saved_train = save_sketches(whole_data[1], test_sketch, amount = 50-saved)
        test_data.extend(data_train)
        train_to_remove.extend(data_train)
        if saved + saved_train < 50:
            print('!!!')
    
    fname = test_sketch.replace(':', '_').replace(')', '').replace('(', '').replace('[', '').replace(']', '')
    
    #with open('test_50/test_{}.conllu'.format(fname), 'w', encoding='utf-8') as fout:
     #   fout.write(''.join(test_data))

In [None]:
top50 = sketch_counter.most_common(50)
rand = Random()
rand.seed(3)
rand.shuffle(top50)
chosen_sketches = top50[:10]
print(chosen_sketches)
removed_lengths = list() # (sketch, no of clauses removed from train, no of clauses removed from dev)

# Removing each sketch (and clauses that were added to test set) from train and dev
for sketch, _ in chosen_sketches:
    trainset = list()
    removed = 0
    
    sketches_to_remove, _ = save_sketches(whole_data[1], sketch, amount = 10000)
    
    for conllu_clause in pure_train:
        if conllu_clause in sketches_to_remove:
            removed += 1
            continue
        trainset.append(conllu_clause)
        
    devset, removed_dev_count = remove_sketches(whole_data[2], sketch, amount=10000)
    print(sketch, '- removed', removed_dev_count, 'clauses from dev,', removed, 'from train')
    
    removed_lengths.append((sketch, removed, removed_dev_count))
        
    
"""filename_sketch = sketch.replace(':', '_').replace(')', '').replace('(', '').replace('[', '').replace(']', '')
    
    with open('experiment_2/splits/train_{}.conllu'.format(filename_sketch), 'w', encoding='utf-8') as fout:
        fout.write(''.join(trainset))
    
    with open('experiment_2/splits/dev_{}.conllu'.format(filename_sketch), 'w', encoding='utf-8') as fout:
        fout.write(''.join(devset))"""


In [None]:
# Per chosen sketch: (sketch, no. of clauses removed from train, no. of clauses removed from dev).
removed_lengths

In [None]:
rand.seed(5)

for sketch, removed_train_count, removed_dev_count in removed_lengths:        
    train_sample = rand.sample(range(0, len(pure_train)+1), removed_train_count) #(range(0, 43981), removed_train_count)
    dev_sample = rand.sample(range(0, 5709), removed_dev_count)

    ablation_train = list()
    ablation_dev = list()

    for i, conllu in enumerate(pure_train): #enumerate(remove_sketches(whole_data[1], '', amount=0)[0]):
        if i in train_sample:
            continue
        ablation_train.append(conllu)

    for i, conllu in enumerate(remove_sketches(whole_data[2], '', amount=0)[0]):
        if i in dev_sample:
            continue
        ablation_dev.append(conllu)
        
    filename_sketch = sketch.replace(':', '_').replace(')', '').replace('(', '').replace('[', '').replace(']', '')
        
    with open('experiment_2/random_splits2/train_ablation_{}.conllu'.format(filename_sketch), 'w', encoding='utf-8') as fout:
        fout.write(''.join(ablation_train))
    
    with open('experiment_2/random_splits2/dev_ablation_{}.conllu'.format(filename_sketch), 'w', encoding='utf-8') as fout:
        fout.write(''.join(ablation_dev))

In [None]:
def _write_conllu(path, clauses):
    """Write already-serialised CoNLL-U clauses to `path` as UTF-8."""
    with open(path, 'w', encoding='utf-8') as out_file:
        out_file.write(''.join(clauses))


# Un-ablated reference splits: the full pure train set and the full dev set.
_write_conllu('experiment_2/random_splits2/train_ablation_total.conllu', pure_train)

full_dev, _ = remove_sketches(whole_data[2], '', amount=0)

_write_conllu('experiment_2/random_splits2/dev_ablation_total.conllu', full_dev)

### Experiment
1. Remove all clauses of the 10 chosen sketches from train and dev.
2. Remove the same number of random sentences from train and dev.

In [37]:
# 1)
top50 = sketch_counter.most_common(50)
rand = Random()
rand.seed(3)
rand.shuffle(top50)
chosen_sketches = top50[:10]
print(chosen_sketches)

sketches_to_remove_train = list()
sketches_to_remove_dev = list()
removed_train = 0
removed_dev = 0
minimal_train = list()
minimal_dev = list()


for sketch, _ in chosen_sketches:    
    sketches_to_remove_train.extend(save_sketches(whole_data[1], sketch, amount = 10000)[0])
    sketches_to_remove_dev.extend(save_sketches(whole_data[2], sketch, amount=10000)[0])
    
for conllu_clause in pure_train:
    if conllu_clause in sketches_to_remove_train:
        removed_train += 1
        continue
    minimal_train.append(conllu_clause)
    
for conllu_clause in remove_sketches(whole_data[2], sketch, amount=0)[0]:
    if conllu_clause in sketches_to_remove_dev:
        removed_dev += 1
        continue
    minimal_dev.append(conllu_clause)
        
    
print('- removed', removed_dev, 'clauses from dev,', removed_train, 'from train')
        
        
with open('experiment_2/no_sketch_splits/train.conllu'.format(filename_sketch), 'w', encoding='utf-8') as fout:
    fout.write(''.join(minimal_train))

with open('experiment_2/no_sketch_splits/dev.conllu'.format(filename_sketch), 'w', encoding='utf-8') as fout:
    fout.write(''.join(minimal_dev))

[('[V]obj(P)', 293), ('[V]obj(L)obl(P)', 188), ('[V]obj(L)obl(L)', 505), ('[V]nsubj(L)obj(P)', 401), ('[V]', 913), ('[V]aux(L)nsubj(L)', 335), ('[V]advmod(L)', 372), ('[S]cop(L)nmod(L)nsubj:cop(L)', 191), ('[S]advmod(L)cop(L)nsubj:cop(L)', 506), ('[V]advmod(L)nsubj(L)', 698)]
Saved 229 instances of sketch [V]obj(P)
Saved 36 instances of sketch [V]obj(P)
Saved 158 instances of sketch [V]obj(L)obl(P)
Saved 13 instances of sketch [V]obj(L)obl(P)
Saved 408 instances of sketch [V]obj(L)obl(L)
Saved 44 instances of sketch [V]obj(L)obl(L)
Saved 303 instances of sketch [V]nsubj(L)obj(P)
Saved 54 instances of sketch [V]nsubj(L)obj(P)
Saved 699 instances of sketch [V]
Saved 110 instances of sketch [V]
Saved 259 instances of sketch [V]aux(L)nsubj(L)
Saved 46 instances of sketch [V]aux(L)nsubj(L)
Saved 290 instances of sketch [V]advmod(L)
Saved 40 instances of sketch [V]advmod(L)
Saved 142 instances of sketch [S]cop(L)nmod(L)nsubj:cop(L)
Saved 28 instances of sketch [S]cop(L)nmod(L)nsubj:cop(L)
Sa

In [44]:
# 2)
rand.seed(5)

train_sample = rand.sample(range(0, len(pure_train)), removed_train)
dev_sample = rand.sample(range(0, 5708), removed_dev)

ablation_train = list()
ablation_dev = list()

for i, conllu in enumerate(pure_train): #enumerate(remove_sketches(whole_data[1], '', amount=0)[0]):
    if i in train_sample:
        continue
    ablation_train.append(conllu)

for i, conllu in enumerate(remove_sketches(whole_data[2], '', amount=0)[0]):
    if i in dev_sample:
        continue
    ablation_dev.append(conllu)

with open('experiment_2/no_sketch_splits/train_random.conllu', 'w', encoding='utf-8') as fout:
    fout.write(''.join(ablation_train))

with open('experiment_2/no_sketch_splits/dev_random.conllu', 'w', encoding='utf-8') as fout:
    fout.write(''.join(ablation_dev))

Removed 0 instances of sketch 
No of clauses: 5708


### check

In [None]:
## sanity check
import conllu

# NOTE(review): this reads experiment_2/random_splits/, whereas the random-ablation
# cell writes experiment_2/random_splits2/ -- confirm which directory is meant.
split_templates = ('experiment_2/random_splits/train_ablation_{}.conllu',
                   'experiment_2/random_splits/dev_ablation_{}.conllu')

for sketch, removed_train_count, removed_dev_count in removed_lengths:
    filename_sketch = (sketch.replace(':', '_')
                             .replace(')', '')
                             .replace('(', '')
                             .replace('[', '')
                             .replace(']', ''))

    # print the sentence count of the train split, then of the dev split
    for template in split_templates:
        with open(template.format(filename_sketch), 'r', encoding='utf-8') as fin:
            c = conllu.parse(fin.read())
            print(sketch, len(c))


In [None]:
## sanity check
import conllu


def _parse_conllu_file(path):
    """Parse the CoNLL-U file at `path` and return its list of sentences."""
    with open(path, 'r', encoding='utf-8') as fin:
        return conllu.parse(fin.read())


# Each random-ablation split must contain exactly as many sentences as the
# matching sketch-removal split.
for sketch, removed_train_count, removed_dev_count in removed_lengths:
    filename_sketch = (sketch.replace(':', '_')
                             .replace(')', '')
                             .replace('(', '')
                             .replace('[', '')
                             .replace(']', ''))

    c1 = _parse_conllu_file('experiment_2/random_splits2/train_ablation_{}.conllu'.format(filename_sketch))
    c2 = _parse_conllu_file('experiment_2/splits/train_{}.conllu'.format(filename_sketch))
    assert len(c1) == len(c2)

    print()

    c1 = _parse_conllu_file('experiment_2/random_splits2/dev_ablation_{}.conllu'.format(filename_sketch))
    c2 = _parse_conllu_file('experiment_2/splits/dev_{}.conllu'.format(filename_sketch))
    assert len(c1) == len(c2)


In [None]:
with open('experiment_2/random_splits2/train_ablation_total.conllu'.format(filename_sketch), 'r', encoding='utf-8') as fin:
        c1 = conllu.parse(fin.read())
print(len(c1))
print()

for sketch, removed_train_count, removed_dev_count in removed_lengths:        
        
    filename_sketch = sketch.replace(':', '_').replace(')', '').replace('(', '').replace('[', '').replace(']', '')

    with open('experiment_2/splits/train_{}.conllu'.format(filename_sketch), 'r', encoding='utf-8') as fin:
        c2 = conllu.parse(fin.read())
        print(sketch, len(c2))

    

In [45]:
## sanity check
import conllu

# (sketch-free split, random-ablation split) pairs that must have equal sizes.
split_pairs = [
    ('experiment_2/no_sketch_splits/train.conllu',
     'experiment_2/no_sketch_splits/train_random.conllu'),
    ('experiment_2/no_sketch_splits/dev.conllu',
     'experiment_2/no_sketch_splits/dev_random.conllu'),
]

for pair_index, (sketch_free_path, random_path) in enumerate(split_pairs):
    if pair_index > 0:
        print()

    with open(sketch_free_path, 'r', encoding='utf-8') as fin:
        c1 = conllu.parse(fin.read())

    with open(random_path, 'r', encoding='utf-8') as fin:
        c2 = conllu.parse(fin.read())

    print(len(c1), len(c2))
    assert len(c1) == len(c2)


39813 39813

5223 5223
