# It's a Machine and Natural Language Tagger

In [1]:
from src.IaMaN.base import LM
from src.utils.data import load_ud
from collections import defaultdict
from collections import Counter
from tqdm import tqdm
import numpy as np
import os, re

seed = 691

print("Loading pre-training data...")
pretrain_path = '/data/newstweet/week_2019-40_article_texts/'
total_pretrain = len([pretrain_file for pretrain_file in os.listdir(pretrain_path) if re.search("^\d+.txt$", pretrain_file)])
num_pretrain = 0

all_pretrain_files = [pretrain_file for pretrain_file in os.listdir(pretrain_path) if re.search("^\d+.txt$", pretrain_file)]
if num_pretrain:
    np.random.seed(seed)
    pretrain_files = np.random.choice(all_pretrain_files, size=num_pretrain, replace=False)
else:
    pretrain_files = np.array([])

ptdocs = [[[open(pretrain_path+pretrain_file).read()]] for pretrain_file in tqdm(pretrain_files)]

max_char = 200_000_000
m = 10; space = True; fine_tune = False; fine_tune_post_pretrain = False
positional = 'independent'; positionally_encode = 't'; do_ife = False; update_ife = False; update_bow = False; 
runners = 10; gpu = False

print("Loading gold-tagged UDs data...")
load_set = 'GUM'; tokenizer = 'hr-bpe'
all_docs = load_ud("English", num_articles = 0, seed = 691, load_set = load_set, rebuild = True, space = space)
test_docs = [doc for doc in all_docs if 'test' in doc['id'] and len(doc['text']) <= max_char]# [:1]
train_docs = [doc for doc in all_docs if 'test' not in doc['id'] and len(doc['text']) <= max_char]# [:4]
nsamp = len(test_docs)
print('Avail. pre-train, total pre-train, Avail. gold, total gold-train, total test-gold: ', 
      total_pretrain, len(ptdocs), len(all_docs), len(train_docs), len(test_docs))

Loading pre-training data...


0it [00:00, ?it/s]

Loading gold-tagged UDs data...





Avail. pre-train, total pre-train, Avail. gold, total gold-train, total test-gold:  14198 0 150 132 18


In [2]:
docs = [["".join([row[1] for row in s]) for s in d['conllu']] for d in train_docs]
tdocs = [["".join([row[1] for row in s]) for s in d['conllu']] for d in test_docs]
covering = [[[row[1] for row in s] for s in d['conllu']] for d in train_docs]
tcovering = [[[row[1] for row in s] for s in d['conllu']] for d in test_docs]
covering_vocab = set([t for d in covering for s in d for t in s])

all_layers = {d_i: {# 'lem': [[row[2] for row in s] for s in d['conllu']], 
                    # 'sty': [[d['s_type'][s_i] for row in s] for s_i, s in enumerate(d['conllu'])], 
                    'pos': [[row[3] for row in s] for s in d['conllu']], 
                    'sup': [[(str(int(row[6]) - int(row[0])) if int(row[6]) else row[6]) for row in s] for s in d['conllu']], 
                    'dep': [[row[7] for row in s] for s in d['conllu']]}
              for d_i, d in enumerate(train_docs)}

model = LM(m = m, tokenizer = tokenizer, noise = 0.001, seed = seed, space = space, positional = positional,
           positionally_encode = positionally_encode, do_ife = do_ife, runners = runners, gpu = gpu)
model.fit(docs, f'{load_set}-{nsamp}', covering = covering, all_layers = all_layers, 
          fine_tune = fine_tune)
model.pre_train(ptdocs, update_ife = update_ife)
if fine_tune_post_pretrain:
    model.fine_tune(docs, covering = covering, all_layers = all_layers)

Training tokenizer...


Initializing: 100%|██████████| 6503/6503 [00:01<00:00, 4471.64it/s]
Fitting:  20%|██        | 20/100 [00:33<02:14,  1.68s/it]


Built a vocabulary of 10606 types
Tokenizing documents...


100%|██████████| 132/132 [00:15<00:00,  8.61it/s]


Counting documents and aggregating counts...


4758045it [05:29, 14435.65it/s] 


Collecting metadata...


100%|██████████| 132/132 [00:06<00:00, 19.97it/s]


Aggregating metadata...


100%|██████████| 132/132 [00:00<00:00, 160.16it/s]


Encoding parameters...


100%|██████████| 4758045/4758045 [00:43<00:00, 109826.91it/s]


Computing marginal statistics...


100%|██████████| 22/22 [00:06<00:00,  3.37it/s]


Building dense output heads...


  X[X==0] = self._noise; X /= X.sum(axis = 1)[:,None]; X = np.nan_to_num(-np.log10(X))
100%|██████████| 22/22 [00:14<00:00,  1.53it/s]


Counting for transition matrices...


100%|██████████| 132/132 [00:06<00:00, 19.25it/s]


Building transition matrices for Viterbi tag decoding...


100%|██████████| 9/9 [00:00<00:00, 49.84it/s]

Done.
Model params, types, encoding size, contexts, vec dim, max sent, and % capacity used: 507580 10607 10607 10607 10898 178 0.451





__Currently__: ordering for the current fine tuning process:
1. train tokenizer and fit model to GUM
2. process NewsTweet documents to integrate sparse post-training statistics (requires mr implementation and updates to the vocabularies/indices)
3. update the ife and dense model, i.e., produce new statistics and dimensionalities
4. fine tune output heads to GUM, and _combine_ them with the dense model from (3), i.e., don't just replace as is current.

__Preliminarily__: this does seem to present performance benefits, but as is usual will require 'big data' statistics to become competitive. In particular, the (tokenization, least of all), counting, sorting, and aggregation of co-occurrence counts must all be distributed for the statistical resolution required to approach performance gains aking to more-advanced systems. Currently, a spark-based MR system is implemented for these (all but tokenization).

In [3]:
interpret_docs = list([docs[3][0:1]])
print(interpret_docs)
model.interpret(interpret_docs, seed = 691)
for doc in model._documents:
    print('opening next doc:')
    for s in doc._sentences:
        print(f'opening next sent: {s._sty}')
        for t in s._tokens:
            print(f'opening next token: {t._form}, {t._sep}, {t._pos}, {t._sup}, {t._dep}')

[[' Emperor Norton ']]


100%|██████████| 1/1 [00:00<00:00,  3.73it/s]

opening next doc:
opening next sent: None
opening next token:  , False, SPACE, 0, root
opening next token: Emperor, False, PROPN, 1, space
opening next token:  , False, NOUN, 1, space
opening next token: Nort, False, PUNCT, 1, space
opening next token: o, False, SPACE, -4, space
opening next token: n, False, PROPN, -2, space
opening next token:  , True, SPACE, -2, space





In [4]:
interpret_docs = list([tdocs[0][1:2]])
print(interpret_docs)
model.interpret(interpret_docs, seed = 691)
for doc in model._documents:
    print('opening next doc:')
    for s in doc._sentences:
        print(f'opening next sent: {s._sty}')
        for t in s._tokens:
            print(f'opening next token: {t._form}, {t._sep}, {t._pos}, {t._sup}, {t._dep}')

[[' Results from a nationally representative sample of adults ']]


100%|██████████| 1/1 [00:00<00:00,  1.14it/s]

opening next doc:
opening next sent: None
opening next token:  Results from, False, SPACE, 1, space
opening next token:  , False, NOUN, 0, root
opening next token: a, False, NOUN, 1, space
opening next token:  , False, SPACE, 1, space
opening next token: n, False, NOUN, 1, space
opening next token: a, False, NOUN, -4, space
opening next token: t, False, NOUN, 1, space
opening next token: io, False, NOUN, 1, space
opening next token: n, False, NOUN, 1, space
opening next token: ally, False, SPACE, 1, space
opening next token:  , False, NOUN, 1, space
opening next token: representative, False, NOUN, -7, space
opening next token:  , False, SPACE, 1, space
opening next token: sample, False, NOUN, 1, space
opening next token:  , False, SPACE, 1, space
opening next token: of, False, NOUN, -7, space
opening next token:  , False, SPACE, 1, space
opening next token: adults, False, NOUN, -7, space
opening next token:  , True, SPACE, -4, space





In [5]:
interpret_docs = list([tdocs[0][1:2]])
print(interpret_docs)
interpret_covering =  [[[row[1] for row in s] for s in d['conllu']][1:2] for d in test_docs[0:1]]
print(interpret_covering)
model.interpret(interpret_docs, seed = 691, covering = interpret_covering)

accuracy = defaultdict(list)
accuracy_nsp = defaultdict(list)
accuracy_all, accuracy_all_nsp, = [], []
sup_accuracy, sup_accuracy_nsp, = 0, 0
accuracy_sty = defaultdict(list)
accuracy_all_sty = []

for d_i, doc in enumerate(model._documents):
    for s_i, s in enumerate(doc._sentences):
        if s._sty is not None:
            result = s._sty == test_docs[0:1][d_i]['s_type'][s_i+1]
            accuracy_sty[test_docs[0:1][d_i]['s_type'][s_i+1]].append(result)
            accuracy_all_sty.append(result)

pred_toks = [t._form for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_arcs = set([(ix, str(t._sup), t._dep) for doc in model._documents for s in doc._sentences for ix, t in enumerate(s._tokens)])
pred_spans = list(np.cumsum([len(t) for t in pred_toks]))
pred_stream = [t._pos for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_spans = {(sh-len(gt), sh): (gl, gt)
              for sh, gt, gl in zip(pred_spans, pred_toks, pred_stream)}

gold_toks = [row[1] for d in test_docs[:1] for s in d['conllu'][1:2] for row in s]
gold_arcs = set([(ix, (str(int(row[6]) - int(row[0])) if int(row[6]) else row[6]),
                  row[7]) for d in test_docs[:1] for s in d['conllu'][1:2] for ix, row in enumerate(s)])
gold_spans = list(np.cumsum([len(t) for t in gold_toks]))
gold_stream = [row[3] for d in test_docs[:1] for s in d['conllu'][1:2] for row in s]
gold_spans = {(sh-len(gt), sh): (gl, gt)
              for sh, gt, gl in zip(gold_spans, gold_toks, gold_stream)}

for gold_span in gold_spans:
    if gold_span in pred_spans:
        result = gold_spans[gold_span] == pred_spans[gold_span]
    else:
        result = False
    accuracy[gold_spans[gold_span][0]].append(result)
    accuracy_all.append(result)
    if gold_spans[gold_span][1] != ' ':
        accuracy_nsp[gold_spans[gold_span][0]].append(result)
        accuracy_all_nsp.append(result)
        
for ptok, parc in zip(pred_toks, pred_arcs):
    if parc in gold_arcs:
        sup_accuracy += 1
        if ptok != ' ':
            sup_accuracy_nsp += 1
sup_accuracy /= len(pred_toks)
sup_accuracy_nsp /= len([x for x in pred_toks if x != ' '])

print("Tag-wise POS accuracy with/out space", {tag: sum(accuracy[tag])/len(accuracy[tag]) for tag in accuracy}, 
                                          {tag: sum(accuracy_nsp[tag])/len(accuracy_nsp[tag]) for tag in accuracy_nsp})
print("Overall POS accuracy with/out space", sum(accuracy_all)/len(accuracy_all), sum(accuracy_all_nsp)/len(accuracy_all_nsp))
print("Overall SUP:DEP accuracy with/out space", sup_accuracy, sup_accuracy_nsp)
if len(accuracy_all_sty):
    print("Overall s_type accuracy: ", sum(accuracy_all_sty)/len(accuracy_all_sty))
print("Tag-wise accuracy", list(Counter({tag: (sum(accuracy_sty[tag])/len(accuracy_sty[tag]), len(accuracy_sty[tag])) 
                                         for tag in accuracy_sty}).most_common()))

[[' Results from a nationally representative sample of adults ']]
[[[' ', 'Results', ' ', 'from', ' ', 'a', ' ', 'nationally', ' ', 'representative', ' ', 'sample', ' ', 'of', ' ', 'adults', ' ']]]


100%|██████████| 1/1 [00:00<00:00,  1.22it/s]

Tag-wise POS accuracy with/out space {'SPACE': 0.7777777777777778, 'NOUN': 1.0, 'ADP': 0.0, 'DET': 0.0, 'ADV': 0.0, 'ADJ': 0.0} {'NOUN': 1.0, 'ADP': 0.0, 'DET': 0.0, 'ADV': 0.0, 'ADJ': 0.0}
Overall POS accuracy with/out space 0.5882352941176471 0.375
Overall SUP:DEP accuracy with/out space 0.35294117647058826 0.375
Tag-wise accuracy []





In [6]:
interpret_docs = list([tdocs[0][:2]])
print(interpret_docs)
interpret_covering =  [[[row[1] for row in s] for s in d['conllu']][:2] for d in test_docs[0:1]]
print(interpret_covering)
model.interpret(interpret_docs, seed = 691, covering = interpret_covering)

accuracy = defaultdict(list)
accuracy_nsp = defaultdict(list)
accuracy_all, accuracy_all_nsp, = [], []
sup_accuracy, sup_accuracy_nsp, = 0, 0
accuracy_sty = defaultdict(list)
accuracy_all_sty = []

for d_i, doc in enumerate(model._documents):
    for s_i, s in enumerate(doc._sentences):
        if s._sty is not None:
            result = s._sty == test_docs[0:1][d_i]['s_type'][s_i]
            accuracy_sty[test_docs[0:1][d_i]['s_type'][s_i]].append(result)
            accuracy_all_sty.append(result)

pred_toks = [t._form for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_arcs = set([(ix, str(t._sup), t._dep, s_i, d_i) for d_i, doc in enumerate(model._documents) for s_i, s in enumerate(doc._sentences) for ix, t in enumerate(s._tokens)])
pred_spans = list(np.cumsum([len(t) for t in pred_toks]))
pred_stream = [t._pos for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_spans = {(sh-len(gt), sh): (gl, gt)
              for sh, gt, gl in zip(pred_spans, pred_toks, pred_stream)}

gold_toks = [row[1] for d in test_docs[:1] for s in d['conllu'][:2] for row in s]
gold_arcs = set([(ix, (str(int(row[6]) - int(row[0])) if int(row[6]) else row[6]), row[7], s_i, d_i) 
                 for d_i, d in enumerate(test_docs[:1]) for s_i, s in enumerate(d['conllu'][:2]) for ix, row in enumerate(s)])
gold_spans = list(np.cumsum([len(t) for t in gold_toks]))
gold_stream = [row[3] for d in test_docs[:1] for s in d['conllu'][:2] for row in s]
gold_spans = {(sh-len(gt), sh): (gl, gt)
              for sh, gt, gl in zip(gold_spans, gold_toks, gold_stream)}

for gold_span in gold_spans:
    if gold_span in pred_spans:
        result = gold_spans[gold_span] == pred_spans[gold_span]
    else:
        result = False
    accuracy[gold_spans[gold_span][0]].append(result)
    accuracy_all.append(result)
    if gold_spans[gold_span][1] != ' ':
        accuracy_nsp[gold_spans[gold_span][0]].append(result)
        accuracy_all_nsp.append(result)
        
for ptok, parc in zip(pred_toks, pred_arcs):
    if parc in gold_arcs:
        sup_accuracy += 1
        if ptok != ' ':
            sup_accuracy_nsp += 1
sup_accuracy /= len(pred_toks)
sup_accuracy_nsp /= len([x for x in pred_toks if x != ' '])

print("Tag-wise POS accuracy with/out space", {tag: sum(accuracy[tag])/len(accuracy[tag]) for tag in accuracy}, 
                                          {tag: sum(accuracy_nsp[tag])/len(accuracy_nsp[tag]) for tag in accuracy_nsp})
print("Overall POS accuracy with/out space", sum(accuracy_all)/len(accuracy_all), sum(accuracy_all_nsp)/len(accuracy_all_nsp))
print("Overall SUP:DEP accuracy with/out space", sup_accuracy, sup_accuracy_nsp)
if len(accuracy_all_sty):
    print("Overall s_type accuracy: ", sum(accuracy_all_sty)/len(accuracy_all_sty))
print("Tag-wise accuracy", list(Counter({tag: (sum(accuracy_sty[tag])/len(accuracy_sty[tag]), len(accuracy_sty[tag])) 
                                         for tag in accuracy_sty}).most_common()))

[[' The prevalence of discrimination across racial groups in contemporary America: ', ' Results from a nationally representative sample of adults ']]
[[[' ', 'The', ' ', 'prevalence', ' ', 'of', ' ', 'discrimination', ' ', 'across', ' ', 'racial', ' ', 'groups', ' ', 'in', ' ', 'contemporary', ' ', 'America', ':', ' '], [' ', 'Results', ' ', 'from', ' ', 'a', ' ', 'nationally', ' ', 'representative', ' ', 'sample', ' ', 'of', ' ', 'adults', ' ']]]


100%|██████████| 1/1 [00:01<00:00,  1.80s/it]

Tag-wise POS accuracy with/out space {'SPACE': 0.7, 'DET': 0.0, 'NOUN': 0.8333333333333334, 'ADP': 0.0, 'ADJ': 0.0, 'PROPN': 0.0, 'PUNCT': 0.0, 'ADV': 0.0} {'DET': 0.0, 'NOUN': 0.8333333333333334, 'ADP': 0.0, 'ADJ': 0.0, 'PROPN': 0.0, 'PUNCT': 0.0, 'ADV': 0.0}
Overall POS accuracy with/out space 0.48717948717948717 0.2631578947368421
Overall SUP:DEP accuracy with/out space 0.38461538461538464 0.47368421052631576
Tag-wise accuracy []





In [7]:
from tqdm import tqdm

confusion = Counter()
confusion_nsp = Counter()
accuracy = defaultdict(list)
accuracy_nsp = defaultdict(list)
accuracy_all, accuracy_all_nsp, = [], []
sup_accuracy, sup_accuracy_nsp, = 0, 0
accuracy_sty = defaultdict(list)
accuracy_all_sty = []

model.interpret(tdocs, seed = 691) 

pred_toks = [t._form for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_spans = list(np.cumsum([len(t) for t in pred_toks]))
pred_spans = set([(sh-len(gt), sh) for sh, gt in zip(pred_spans, pred_toks)])

gold_toks = [row[1] for d in test_docs for s in d['conllu'] for row in s]
gold_spans = list(np.cumsum([len(t) for t in gold_toks]))
gold_spans = set([(sh-len(gt), sh) for sh, gt in zip(gold_spans, gold_toks)])

for pred_span, pred_tok in zip(pred_spans, pred_toks):
    if pred_span in gold_spans:
        confusion['TP'] += 1
        if pred_tok != ' ':
            confusion_nsp['TP'] += 1
    else:
        confusion['FP'] += 1
        if pred_tok != ' ':
            confusion_nsp['FP'] += 1
confusion['FN'] = len(gold_spans) - confusion['TP']
confusion_nsp['FN'] = len([t for t in gold_toks if t != ' ']) - confusion_nsp['TP']

confusion['P'] = round(confusion['TP']/(confusion['TP'] + confusion['FP']), 3)
confusion['R'] = round(confusion['TP']/(confusion['TP'] + confusion['FN']), 3)
confusion['F'] = round(2*confusion['P']*confusion['R']/(confusion['P']+confusion["R"]), 3)
confusion_nsp['P'] = round(confusion_nsp['TP']/(confusion_nsp['TP'] + confusion_nsp['FP']), 3)
confusion_nsp['R'] = round(confusion_nsp['TP']/(confusion_nsp['TP'] + confusion_nsp['FN']), 3)
confusion_nsp['F'] = round(2*confusion_nsp['P']*confusion_nsp['R']/(confusion_nsp['P']+confusion_nsp["R"]), 3)

for d_i, doc in enumerate(model._documents):
    for s_i, s in enumerate(doc._sentences):
        if s._sty is not None:
            result = s._sty == test_docs[d_i]['s_type'][s_i]
            accuracy_sty[test_docs[d_i]['s_type'][s_i]].append(result)
            accuracy_all_sty.append(result)

pred_toks = [t._form for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_arcs = set([(ix, str(t._sup), t._dep, d_i, s_i) for d_i, doc in enumerate(model._documents) 
                 for s_i, s in enumerate(doc._sentences) for ix, t in enumerate(s._tokens)])
pred_spans = list(np.cumsum([len(t) for t in pred_toks]))
pred_stream = [t._pos for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_spans = {(sh-len(gt), sh): (gl, gt)
              for sh, gt, gl in zip(pred_spans, pred_toks, pred_stream)}

gold_toks = [row[1] for d in test_docs for s in d['conllu'] for row in s]
gold_arcs = set([(ix, (str(int(row[6]) - int(row[0])) if int(row[6]) else row[6]), row[7], d_i, s_i) 
                 for d_i, d in enumerate(test_docs) for s_i, s in enumerate(d['conllu']) for ix, row in enumerate(s)])
gold_spans = list(np.cumsum([len(t) for t in gold_toks]))
gold_stream = [row[3] for d in test_docs for s in d['conllu'] for row in s]
gold_spans = {(sh-len(gt), sh): (gl, gt)
              for sh, gt, gl in zip(gold_spans, gold_toks, gold_stream)}

for gold_span in gold_spans:
    if gold_span in pred_spans:
        result = gold_spans[gold_span] == pred_spans[gold_span]
    else:
        result = False
    accuracy[gold_spans[gold_span][0]].append(result)
    accuracy_all.append(result)
    if gold_spans[gold_span][1] != ' ':
        accuracy_nsp[gold_spans[gold_span][0]].append(result)
        accuracy_all_nsp.append(result)
        
for ptok, parc in zip(pred_toks, pred_arcs):
    if parc in gold_arcs:
        sup_accuracy += 1
        if ptok != ' ':
            sup_accuracy_nsp += 1
sup_accuracy /= len(pred_toks)
sup_accuracy_nsp /= len([x for x in pred_toks if x != ' '])

print("Token segmentation performance with/out space", confusion, confusion_nsp)
print("Overall POS accuracy with/out space", sum(accuracy_all)/len(accuracy_all), sum(accuracy_all_nsp)/len(accuracy_all_nsp))
print("Overall SUP:DEP accuracy with/out space", sup_accuracy, sup_accuracy_nsp)
if len(accuracy_all_sty):
    print("Overall s_type accuracy: ", sum(accuracy_all_sty)/len(accuracy_all_sty))
"Tag-wise accuracy", list(Counter({tag: (sum(accuracy_sty[tag])/len(accuracy_sty[tag]), len(accuracy_sty[tag])) 
                                   for tag in accuracy_sty}).most_common()), list(Counter({tag: (sum(accuracy[tag])/len(accuracy[tag]), len(accuracy[tag])) 
                                                                                           for tag in accuracy}).most_common())

100%|██████████| 18/18 [32:58<00:00, 109.91s/it]


Token segmentation performance with/out space Counter({'TP': 20687, 'FP': 12487, 'FN': 10339, 'R': 0.667, 'F': 0.645, 'P': 0.624}) Counter({'TP': 13418, 'FP': 8065, 'FN': 2798, 'R': 0.827, 'F': 0.712, 'P': 0.625})
Overall POS accuracy with/out space 0.1950944369238703 0.09176122348297977
Overall SUP:DEP accuracy with/out space 0.18776752878760475 0.1894521249359959


('Tag-wise accuracy',
 [],
 [('SPACE', (0.30823767724510465, 14810)),
  ('INTJ', (0.29545454545454547, 88)),
  ('NOUN', (0.18882531134298217, 2971)),
  ('PUNCT', (0.1878640776699029, 2060)),
  ('NUM', (0.13994169096209913, 343)),
  ('PROPN', (0.1335478680611424, 1243)),
  ('PRON', (0.07613741875580315, 1077)),
  ('AUX', (0.036061026352288486, 721)),
  ('ADV', (0.03588907014681892, 613)),
  ('VERB', (0.032432432432432434, 1665)),
  ('ADP', (0.029359953024075163, 1703)),
  ('DET', (0.02145922746781116, 1398)),
  ('CCONJ', (0.01870748299319728, 588)),
  ('ADJ', (0.015370705244122965, 1106)),
  ('PART', (0.014925373134328358, 335)),
  ('SCONJ', (0.012295081967213115, 244)),
  ('SYM', (0.0, 35)),
  ('X', (0.0, 26))])

```
--- start small training/test
--- ife, sparse, uniform-positional encoding
100%|██████████| 1/1 [00:18<00:00, 18.28s/it]

Token segmentation performance with/out space 
Counter({'TP': 1660, 'FP': 874, 'FN': 328, 'R': 0.835, 'F': 0.734, 'P': 0.655}) 
Counter({'TP': 1024, 'FP': 573, 'FN': 27, 'R': 0.974, 'F': 0.773, 'P': 0.641})
Overall POS accuracy with/out space 0.6403420523138833 0.3824928639391056
Overall SUP:DEP accuracy with/out space 0.19179163378058406 0.1959924859110833
Overall s_type accuracy:  0.7037037037037037

('Tag-wise accuracy',
 [('decl', (1.0, 37)), ('frag', (0.0625, 16)), ('wh', (0.0, 1))],
 [('SPACE', (0.9295624332977588, 937)),
  ('PUNCT', (0.68125, 160)),
  ('DET', (0.5876288659793815, 97)),
  ('ADP', (0.5765765765765766, 111)),
  ('CCONJ', (0.56, 25)),
  ('NOUN', (0.3884297520661157, 242)),
  ('NUM', (0.32558139534883723, 43)),
  ('AUX', (0.2682926829268293, 41)),
  ('VERB', (0.26548672566371684, 113)),
  ('PRON', (0.22727272727272727, 22)),
  ('PROPN', (0.09375, 32)),
  ('ADV', (0.02564102564102564, 39)),
  ('ADJ', (0.0, 83)),
  ('SCONJ', (0.0, 16)),
  ('PART', (0.0, 13)),
  ('SYM', (0.0, 12)),
  ('X', (0.0, 2))])
--- ife, sparse, token-positional encoding
100%|██████████| 1/1 [00:20<00:00, 20.01s/it]

Token segmentation performance with/out space 
Counter({'TP': 1658, 'FP': 876, 'FN': 330, 'R': 0.834, 'F': 0.733, 'P': 0.654}) 
Counter({'TP': 1047, 'FP': 550, 'FN': 4, 'R': 0.996, 'F': 0.791, 'P': 0.656})
Overall POS accuracy with/out space 0.6740442655935613 0.3939105613701237
Overall SUP:DEP accuracy with/out space 0.1910023677979479 0.19724483406386975
Overall s_type accuracy:  0.7037037037037037

('Tag-wise accuracy',
 [('decl', (1.0, 37)), ('frag', (0.0625, 16)), ('wh', (0.0, 1))],
 [('SPACE', (0.9882604055496265, 937)),
  ('PUNCT', (0.70625, 160)),
  ('DET', (0.6288659793814433, 97)),
  ('ADP', (0.6126126126126126, 111)),
  ('CCONJ', (0.6, 25)),
  ('NOUN', (0.384297520661157, 242)),
  ('NUM', (0.32558139534883723, 43)),
  ('AUX', (0.2682926829268293, 41)),
  ('VERB', (0.26548672566371684, 113)),
  ('PRON', (0.22727272727272727, 22)),
  ('PROPN', (0.09375, 32)),
  ('ADV', (0.02564102564102564, 39)),
  ('ADJ', (0.0, 83)),
  ('SCONJ', (0.0, 16)),
  ('PART', (0.0, 13)),
  ('SYM', (0.0, 12)),
  ('X', (0.0, 2))])
--- ife, fine tuned, uniform-positional encoding
100%|██████████| 1/1 [00:19<00:00, 19.83s/it]

Token segmentation performance with/out space 
Counter({'TP': 1611, 'FP': 1089, 'FN': 377, 'R': 0.81, 'F': 0.687, 'P': 0.597}) 
Counter({'TP': 1055, 'FP': 716, 'R': 1.004, 'F': 0.748, 'P': 0.596, 'FN': -4})
Overall POS accuracy with/out space 0.619215291750503 0.3368220742150333
Overall SUP:DEP accuracy with/out space 0.1925925925925926 0.19875776397515527
Overall s_type accuracy:  0.6851851851851852

('Tag-wise accuracy',
 [('decl', (1.0, 37)), ('frag', (0.0, 16)), ('wh', (0.0, 1))],
 [('SPACE', (0.935965848452508, 937)),
  ('DET', (0.6082474226804123, 97)),
  ('PUNCT', (0.6, 160)),
  ('ADP', (0.5945945945945946, 111)),
  ('CCONJ', (0.52, 25)),
  ('AUX', (0.36585365853658536, 41)),
  ('NUM', (0.27906976744186046, 43)),
  ('NOUN', (0.256198347107438, 242)),
  ('PART', (0.23076923076923078, 13)),
  ('VERB', (0.168141592920354, 113)),
  ('PRON', (0.09090909090909091, 22)),
  ('PROPN', (0.0625, 32)),
  ('ADJ', (0.04819277108433735, 83)),
  ('ADV', (0.02564102564102564, 39)),
  ('SCONJ', (0.0, 16)),
  ('SYM', (0.0, 12)),
  ('X', (0.0, 2))])
--- ife, fine tuned, token-positional encoding
100%|██████████| 1/1 [00:22<00:00, 22.50s/it]

Token segmentation performance with/out space 
Counter({'TP': 1666, 'FP': 870, 'FN': 322, 'R': 0.838, 'F': 0.737, 'P': 0.657}) 
Counter({'TP': 1043, 'FP': 556, 'FN': 8, 'R': 0.992, 'F': 0.787, 'P': 0.652})
Overall POS accuracy with/out space 0.6745472837022133 0.384395813510942
Overall SUP:DEP accuracy with/out space 0.16876971608832808 0.17886178861788618
Overall s_type accuracy:  0.6851851851851852

('Tag-wise accuracy',
 [('decl', (1.0, 37)), ('frag', (0.0, 16)), ('wh', (0.0, 1))],
 [('SPACE', (1.0, 937)),
  ('PUNCT', (0.74375, 160)),
  ('DET', (0.6597938144329897, 97)),
  ('CCONJ', (0.6, 25)),
  ('ADP', (0.5945945945945946, 111)),
  ('AUX', (0.34146341463414637, 41)),
  ('NOUN', (0.28512396694214875, 242)),
  ('VERB', (0.2831858407079646, 113)),
  ('NUM', (0.27906976744186046, 43)),
  ('PRON', (0.18181818181818182, 22)),
  ('PART', (0.15384615384615385, 13)),
  ('PROPN', (0.0625, 32)),
  ('ADJ', (0.04819277108433735, 83)),
  ('ADV', (0.02564102564102564, 39)),
  ('SCONJ', (0.0, 16)),
  ('SYM', (0.0, 12)),
  ('X', (0.0, 2))])
--- form, sparse, uniform-positional encoding
Token segmentation performance with/out space: 
Counter({'TP': 1567, 'FP': 1035, 'FN': 367, 'R': 0.81, 'F': 0.691, 'P': 0.602}) 
Counter({'TP': 1013, 'FP': 706, 'FN': 38, 'R': 0.964, 'F': 0.731, 'P': 0.589})
Overall POS accuracy with/out space 0.655635987590486 0.36631779257849667
Overall SUP:DEP accuracy with/out space 0.1848578016910069 0.1878999418266434
Overall s_type accuracy:  0.6851851851851852

('Tag-wise accuracy',
 [('decl', (1.0, 37)), ('frag', (0.0, 16)), ('wh', (0.0, 1))],
 [('PUNCT', (0.9674017257909875, 1043)),
  ('ADP', (0.6846846846846847, 111)),
  ('DET', (0.5463917525773195, 97)),
  ('NUM', (0.37209302325581395, 43)),
  ('NOUN', (0.28512396694214875, 242)),
  ('PART', (0.23076923076923078, 13)),
  ('VERB', (0.23008849557522124, 113)),
  ('CCONJ', (0.2, 25)),
  ('PRON', (0.13636363636363635, 22)),
  ('PROPN', (0.0625, 32)),
  ('ADV', (0.05128205128205128, 39)),
  ('AUX', (0.04878048780487805, 41)),
  ('ADJ', (0.024096385542168676, 83)),
  ('SCONJ', (0.0, 16)),
  ('SYM', (0.0, 12)),
  ('X', (0.0, 2))])
--- end small training/test
--- start full training/test
--- ife, sparse, uniform-positional encoding
100%|██████████| 18/18 [43:41<00:00, 145.63s/it]

Token segmentation performance with/out space 
Counter({'TP': 28408, 'FP': 3461, 'FN': 1724, 'R': 0.943, 'F': 0.916, 'P': 0.891}) 
Counter({'TP': 15983, 'FP': 1974, 'FN': 233, 'R': 0.986, 'F': 0.936, 'P': 0.89})
Overall POS accuracy with/out space 0.7840501792114696 0.6013196842624569
Overall SUP:DEP accuracy with/out space 0.34704571840973986 0.3442111711310353
Overall s_type accuracy:  0.7237136465324385

('Tag-wise accuracy',
 [('decl', (0.966044142614601, 589)),
  ('intj', (0.5769230769230769, 26)),
  ('q', (0.5, 16)),
  ('inf', (0.4, 5)),
  ('wh', (0.38095238095238093, 21)),
  ('frag', (0.3225806451612903, 93)),
  ('imp', (0.20408163265306123, 49)),
  ('sub', (0.12195121951219512, 41)),
  ('multiple', (0.0, 32)),
  ('other', (0.0, 14)),
  ('ger', (0.0, 8))],
 [('SPACE', (0.9969818913480886, 13916)),
  ('PUNCT', (0.8689320388349514, 2060)),
  ('DET', (0.8190271816881259, 1398)),
  ('ADP', (0.7428068115091015, 1703)),
  ('NOUN', (0.7381353079771121, 2971)),
  ('CCONJ', (0.7380952380952381, 588)),
  ('PRON', (0.6666666666666666, 1077)),
  ('SCONJ', (0.6024590163934426, 244)),
  ('AUX', (0.5242718446601942, 721)),
  ('VERB', (0.48768768768768767, 1665)),
  ('INTJ', (0.375, 88)),
  ('PART', (0.3074626865671642, 335)),
  ('NUM', (0.22448979591836735, 343)),
  ('PROPN', (0.22123893805309736, 1243)),
  ('ADJ', (0.2206148282097649, 1106)),
  ('ADV', (0.2137030995106036, 613)),
  ('SYM', (0.17142857142857143, 35)),
  ('X', (0.0, 26))])
--- ife, sparse, token-positional encoding
100%|██████████| 18/18 [07:05<00:00, 23.61s/it]

Token segmentation performance with/out space 
Counter({'TP': 29317, 'FP': 3473, 'FN': 1709, 'R': 0.945, 'F': 0.919, 'P': 0.894}) 
Counter({'TP': 16083, 'FP': 1897, 'FN': 133, 'R': 0.992, 'F': 0.94, 'P': 0.894})
Overall POS accuracy with/out space 0.795687487913363 0.6090897878638382
Overall SUP:DEP accuracy with/out space 0.3339737724916133 0.33403781979977754
Overall s_type accuracy:  0.7136465324384788

('Tag-wise accuracy',
 [('decl', (0.9864176570458404, 589)),
  ('intj', (0.46153846153846156, 26)),
  ('q', (0.375, 16)),
  ('wh', (0.2857142857142857, 21)),
  ('frag', (0.26881720430107525, 93)),
  ('inf', (0.2, 5)),
  ('imp', (0.12244897959183673, 49)),
  ('sub', (0.024390243902439025, 41)),
  ('multiple', (0.0, 32)),
  ('other', (0.0, 14)),
  ('ger', (0.0, 8))],
 [('SPACE', (1.0, 14810)),
  ('PUNCT', (0.8941747572815534, 2060)),
  ('DET', (0.8183118741058655, 1398)),
  ('ADP', (0.7680563711098062, 1703)),
  ('CCONJ', (0.7517006802721088, 588)),
  ('NOUN', (0.7485695052170986, 2971)),
  ('PRON', (0.7056638811513464, 1077)),
  ('AUX', (0.5312066574202496, 721)),
  ('VERB', (0.4924924924924925, 1665)),
  ('SCONJ', (0.3770491803278688, 244)),
  ('INTJ', (0.3068181818181818, 88)),
  ('PART', (0.29850746268656714, 335)),
  ('PROPN', (0.2333065164923572, 1243)),
  ('ADJ', (0.22423146473779385, 1106)),
  ('NUM', (0.21865889212827988, 343)),
  ('ADV', (0.19086460032626426, 613)),
  ('SYM', (0.14285714285714285, 35)),
  ('X', (0.0, 26))])
--- ife, fine tuned, uniform-positional encoding
100%|██████████| 18/18 [2:34:46<00:00, 515.94s/it]

Token segmentation performance with/out space:
Counter({'TP': 27967, 'FP': 4668, 'FN': 2165, 'R': 0.928, 'F': 0.891, 'P': 0.857}) 
Counter({'TP': 16028, 'FP': 2718, 'FN': 188, 'R': 0.988, 'F': 0.917, 'P': 0.855})
Overall POS accuracy with/out space 0.7831541218637993 0.5993463246176616
Overall SUP:DEP accuracy with/out space 0.30378428068025126 0.30390483303104665
Overall s_type accuracy:  0.7125279642058165

('Tag-wise accuracy',
 [('decl', (0.9643463497453311, 589)),
  ('intj', (0.6153846153846154, 26)),
  ('q', (0.5, 16)),
  ('frag', (0.3010752688172043, 93)),
  ('wh', (0.23809523809523808, 21)),
  ('imp', (0.16326530612244897, 49)),
  ('sub', (0.07317073170731707, 41)),
  ('multiple', (0.03125, 32)),
  ('other', (0.0, 14)),
  ('ger', (0.0, 8)),
  ('inf', (0.0, 5))],
 [('SPACE', (0.9973411899971256, 13916)),
  ('PUNCT', (0.9242718446601942, 2060)),
  ('DET', (0.9062947067238912, 1398)),
  ('CCONJ', (0.8690476190476191, 588)),
  ('ADP', (0.7862595419847328, 1703)),
  ('PRON', (0.7520891364902507, 1077)),
  ('AUX', (0.7128987517337032, 721)),
  ('PART', (0.6746268656716418, 335)),
  ('SCONJ', (0.5245901639344263, 244)),
  ('NOUN', (0.49814877145742176, 2971)),
  ('INTJ', (0.36363636363636365, 88)),
  ('VERB', (0.35315315315315315, 1665)),
  ('PROPN', (0.30973451327433627, 1243)),
  ('ADJ', (0.26763110307414106, 1106)),
  ('NUM', (0.2478134110787172, 343)),
  ('ADV', (0.2463295269168026, 613)),
  ('SYM', (0.08571428571428572, 35)),
  ('X', (0.0, 26))])
--- ife, fine tuned, token-positional encoding
100%|██████████| 18/18 [18:42<00:00, 62.39s/it]

Token segmentation performance with/out space 
Counter({'TP': 29101, 'FP': 3984, 'FN': 1925, 'R': 0.938, 'F': 0.908, 'P': 0.88}) 
Counter({'TP': 16028, 'FP': 2247, 'FN': 188, 'R': 0.988, 'F': 0.929, 'P': 0.877})
Overall POS accuracy with/out space 0.7975568877715464 0.6126665022200296
Overall SUP:DEP accuracy with/out space 0.28958742632612966 0.28891928864569083
Overall s_type accuracy:  0.6778523489932886

('Tag-wise accuracy',
 [('decl', (0.99830220713073, 589)),
  ('intj', (0.34615384615384615, 26)),
  ('frag', (0.08602150537634409, 93)),
  ('q', (0.0625, 16)),
  ('imp', (0.0, 49)),
  ('sub', (0.0, 41)),
  ('multiple', (0.0, 32)),
  ('wh', (0.0, 21)),
  ('other', (0.0, 14)),
  ('ger', (0.0, 8)),
  ('inf', (0.0, 5))],
 [('SPACE', (1.0, 14810)),
  ('PUNCT', (0.9529126213592233, 2060)),
  ('DET', (0.9213161659513591, 1398)),
  ('CCONJ', (0.8401360544217688, 588)),
  ('ADP', (0.806224310041104, 1703)),
  ('PRON', (0.7325905292479109, 1077)),
  ('PART', (0.6985074626865672, 335)),
  ('AUX', (0.6463245492371706, 721)),
  ('NOUN', (0.5412319084483339, 2971)),
  ('SCONJ', (0.45491803278688525, 244)),
  ('VERB', (0.4096096096096096, 1665)),
  ('PROPN', (0.31053901850362026, 1243)),
  ('INTJ', (0.3068181818181818, 88)),
  ('ADJ', (0.2585895117540687, 1106)),
  ('NUM', (0.24489795918367346, 343)),
  ('ADV', (0.22838499184339314, 613)),
  ('SYM', (0.11428571428571428, 35)),
  ('X', (0.0, 26))])
--- ife, fine tuned using all pre-training data, token-positional encoding
100%|██████████| 18/18 [12:42<00:00, 42.36s/it]

Token segmentation performance with/out space 
Counter({'TP': 28946, 'FP': 4430, 'FN': 2080, 'R': 0.933, 'F': 0.899, 'P': 0.867}) 
Counter({'TP': 16129, 'FP': 2437, 'FN': 87, 'R': 0.995, 'F': 0.928, 'P': 0.869})
Overall POS accuracy with/out space 0.7999742151743698 0.6172915638875185
Overall SUP:DEP accuracy with/out space 0.28577420901246403 0.28374447915544543
Overall s_type accuracy:  0.6789709172259508

('Tag-wise accuracy',
 [('decl', (0.99830220713073, 589)),
  ('intj', (0.19230769230769232, 26)),
  ('frag', (0.15053763440860216, 93)),
  ('imp', (0.0, 49)),
  ('sub', (0.0, 41)),
  ('multiple', (0.0, 32)),
  ('wh', (0.0, 21)),
  ('q', (0.0, 16)),
  ('other', (0.0, 14)),
  ('ger', (0.0, 8)),
  ('inf', (0.0, 5))],
 [('SPACE', (1.0, 14810)),
  ('PUNCT', (0.9601941747572815, 2060)),
  ('DET', (0.9277539341917024, 1398)),
  ('CCONJ', (0.8537414965986394, 588)),
  ('ADP', (0.8027011156782149, 1703)),
  ('PART', (0.7850746268656716, 335)),
  ('PRON', (0.7678737233054782, 1077)),
  ('AUX', (0.6907073509015257, 721)),
  ('SCONJ', (0.6229508196721312, 244)),
  ('NOUN', (0.47088522383036013, 2971)),
  ('VERB', (0.4114114114114114, 1665)),
  ('INTJ', (0.4090909090909091, 88)),
  ('NUM', (0.35276967930029157, 343)),
  ('PROPN', (0.3153660498793242, 1243)),
  ('ADV', (0.2969004893964111, 613)),
  ('ADJ', (0.2730560578661845, 1106)),
  ('SYM', (0.2571428571428571, 35)),
  ('X', (0.0, 26))])
--- ife, fine tuned using all pre-training data, token-positional encoding, no spaces
Token segmentation performance with/out space 
Counter({'TP': 12109, 'FP': 4632, 'FN': 4107, 'R': 0.747, 'F': 0.735, 'P': 0.723}) 
Counter({'TP': 12109, 'FP': 4632, 'FN': 4107, 'R': 0.747, 'F': 0.735, 'P': 0.723})
Overall POS accuracy with/out space 0.5549457326097681 0.5549457326097681
Overall SUP:DEP accuracy with/out space 0.1356549787945762 0.1356549787945762

('Tag-wise accuracy',
 [],
 [('PUNCT', (0.9097087378640777, 2060)),
  ('DET', (0.8497854077253219, 1398)),
  ('CCONJ', (0.8180272108843537, 588)),
  ('ADP', (0.7146212566059894, 1703)),
  ('INTJ', (0.5113636363636364, 88)),
  ('AUX', (0.5104022191400832, 721)),
  ('SCONJ', (0.5081967213114754, 244)),
  ('PRON', (0.5051067780872794, 1077)),
  ('NOUN', (0.4867048131942107, 2971)),
  ('ADJ', (0.4240506329113924, 1106)),
  ('VERB', (0.34654654654654654, 1665)),
  ('PART', (0.3283582089552239, 335)),
  ('ADV', (0.27569331158238175, 613)),
  ('NUM', (0.2478134110787172, 343)),
  ('PROPN', (0.24215607401448108, 1243)),
  ('SYM', (0.02857142857142857, 35)),
  ('X', (0.0, 26))])
--- form, sparse, uniform-positional encoding
100%|██████████| 18/18 [9:55:08<00:00, 1983.80s/it]

Token segmentation performance with/out space 
Counter({'TP': 29528, 'FP': 3009, 'FN': 1498, 'R': 0.952, 'F': 0.929, 'P': 0.908}) 
Counter({'TP': 16051, 'FP': 1676, 'FN': 165, 'R': 0.99, 'F': 0.946, 'P': 0.905})
Overall POS accuracy with/out space 0.8271127441500676 0.6697705969412926
Overall SUP:DEP accuracy with/out space 0.3659218735593325 0.36684154115191514
Overall s_type accuracy:  0.7304250559284117

('Tag-wise accuracy',
 [('decl', (0.9881154499151104, 589)),
  ('intj', (0.5384615384615384, 26)),
  ('q', (0.5, 16)),
  ('imp', (0.3673469387755102, 49)),
  ('frag', (0.27956989247311825, 93)),
  ('wh', (0.19047619047619047, 21)),
  ('sub', (0.024390243902439025, 41)),
  ('multiple', (0.0, 32)),
  ('other', (0.0, 14)),
  ('ger', (0.0, 8)),
  ('inf', (0.0, 5))],
 [('SPACE', (0.9993923024983119, 14810)),
  ('PUNCT', (0.9271844660194175, 2060)),
  ('NOUN', (0.8391114102995625, 2971)),
  ('DET', (0.7982832618025751, 1398)),
  ('ADP', (0.7639459776864357, 1703)),
  ('PRON', (0.754874651810585, 1077)),
  ('VERB', (0.6216216216216216, 1665)),
  ('CCONJ', (0.5952380952380952, 588)),
  ('AUX', (0.5409153952843273, 721)),
  ('SCONJ', (0.4344262295081967, 244)),
  ('ADJ', (0.4204339963833635, 1106)),
  ('ADV', (0.3866231647634584, 613)),
  ('NUM', (0.36443148688046645, 343)),
  ('PROPN', (0.3153660498793242, 1243)),
  ('INTJ', (0.3068181818181818, 88)),
  ('PART', (0.29253731343283584, 335)),
  ('SYM', (0.08571428571428572, 35)),
  ('X', (0.0, 26))])
```

In [8]:
from tqdm import tqdm

accuracy = defaultdict(list)
accuracy_nsp = defaultdict(list)
accuracy_all, accuracy_all_nsp, = [], []
sup_accuracy, sup_accuracy_nsp, = 0, 0
accuracy_sty = defaultdict(list)
accuracy_all_sty = []

model.interpret(tdocs, seed = 691, covering = [[[row[1] for row in s] for s in d['conllu']] for d in test_docs])

for d_i, doc in enumerate(model._documents):
    for s_i, s in enumerate(doc._sentences):
        if s._sty is not None:
            result = s._sty == test_docs[d_i]['s_type'][s_i]
            accuracy_sty[test_docs[d_i]['s_type'][s_i]].append(result)
            accuracy_all_sty.append(result)

pred_toks = [t._form for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_arcs = set([(ix, str(t._sup), t._dep, d_i, s_i) for d_i, doc in enumerate(model._documents) 
                 for s_i, s in enumerate(doc._sentences) for ix, t in enumerate(s._tokens)])
pred_spans = list(np.cumsum([len(t) for t in pred_toks]))
pred_stream = [t._pos for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_spans = {(sh-len(gt), sh): (gl, gt)
              for sh, gt, gl in zip(pred_spans, pred_toks, pred_stream)}

gold_toks = [row[1] for d in test_docs for s in d['conllu'] for row in s]
gold_arcs = set([(ix, (str(int(row[6]) - int(row[0])) if int(row[6]) else row[6]), row[7], d_i, s_i) 
                 for d_i, d in enumerate(test_docs) for s_i, s in enumerate(d['conllu']) for ix, row in enumerate(s)])
gold_spans = list(np.cumsum([len(t) for t in gold_toks]))
gold_stream = [row[3] for d in test_docs for s in d['conllu'] for row in s]
gold_spans = {(sh-len(gt), sh): (gl, gt)
              for sh, gt, gl in zip(gold_spans, gold_toks, gold_stream)}

for gold_span in gold_spans:
    if gold_span in pred_spans:
        result = gold_spans[gold_span] == pred_spans[gold_span]
    else:
        result = False
    accuracy[gold_spans[gold_span][0]].append(result)
    accuracy_all.append(result)
    if gold_spans[gold_span][1] != ' ':
        accuracy_nsp[gold_spans[gold_span][0]].append(result)
        accuracy_all_nsp.append(result)
        
for ptok, parc in zip(pred_toks, pred_arcs):
    if parc in gold_arcs:
        sup_accuracy += 1
        if ptok != ' ':
            sup_accuracy_nsp += 1
sup_accuracy /= len(pred_toks)
sup_accuracy_nsp /= len([x for x in pred_toks if x != ' '])

print("Overall POS accuracy with/out space", sum(accuracy_all)/len(accuracy_all), sum(accuracy_all_nsp)/len(accuracy_all_nsp))
print("Overall SUP:DEP accuracy with/out space", sup_accuracy, sup_accuracy_nsp)
if len(accuracy_all_sty):
    print("Overall s_type accuracy: ", sum(accuracy_all_sty)/len(accuracy_all_sty))
"Tag-wise accuracy", list(Counter({tag: (sum(accuracy_sty[tag])/len(accuracy_sty[tag]), len(accuracy_sty[tag])) 
                                   for tag in accuracy_sty}).most_common()), list(Counter({tag: (sum(accuracy[tag])/len(accuracy[tag]), len(accuracy[tag])) 
                                                                                           for tag in accuracy}).most_common())

100%|██████████| 18/18 [33:13<00:00, 110.74s/it]


Overall POS accuracy with/out space 0.2919164571649584 0.20072767636901825
Overall SUP:DEP accuracy with/out space 0.3028582489192851 0.30219915987150975


('Tag-wise accuracy',
 [],
 [('NOUN', (0.568158869067654, 2971)),
  ('SPACE', (0.39176232275489536, 14810)),
  ('INTJ', (0.36363636363636365, 88)),
  ('PROPN', (0.25422365245374096, 1243)),
  ('PRON', (0.2414113277623027, 1077)),
  ('NUM', (0.2303206997084548, 343)),
  ('PUNCT', (0.19223300970873786, 2060)),
  ('VERB', (0.13753753753753753, 1665)),
  ('ADV', (0.09135399673735727, 613)),
  ('PART', (0.05373134328358209, 335)),
  ('AUX', (0.052704576976421634, 721)),
  ('ADP', (0.040516735173223725, 1703)),
  ('ADJ', (0.023508137432188065, 1106)),
  ('DET', (0.022889842632331903, 1398)),
  ('SCONJ', (0.020491803278688523, 244)),
  ('CCONJ', (0.01870748299319728, 588)),
  ('SYM', (0.0, 35)),
  ('X', (0.0, 26))])

```
--- start small training/test
--- ife, sparse, uniform-positional encoding
100%|██████████| 1/1 [00:13<00:00, 13.63s/it]

Overall POS accuracy with/out space 0.7208249496981891 0.5242626070409134
Overall SUP:DEP accuracy with/out space 0.499496475327291 0.4918970448045758
Overall s_type accuracy:  0.7037037037037037

('Tag-wise accuracy',
 [('decl', (1.0, 37)), ('frag', (0.0625, 16)), ('wh', (0.0, 1))],
 [('SPACE', (0.9413020277481323, 937)),
  ('NOUN', (0.871900826446281, 242)),
  ('PUNCT', (0.70625, 160)),
  ('DET', (0.5876288659793815, 97)),
  ('ADP', (0.5855855855855856, 111)),
  ('CCONJ', (0.56, 25)),
  ('VERB', (0.46017699115044247, 113)),
  ('PRON', (0.36363636363636365, 22)),
  ('NUM', (0.32558139534883723, 43)),
  ('AUX', (0.2682926829268293, 41)),
  ('PROPN', (0.15625, 32)),
  ('ADV', (0.02564102564102564, 39)),
  ('ADJ', (0.0, 83)),
  ('SCONJ', (0.0, 16)),
  ('PART', (0.0, 13)),
  ('SYM', (0.0, 12)),
  ('X', (0.0, 2))])
--- ife, sparse, token-positional encoding
100%|██████████| 1/1 [00:12<00:00, 12.83s/it]

Overall POS accuracy with/out space 0.7540241448692153 0.5394862036156042
Overall SUP:DEP accuracy with/out space 0.5161127895266868 0.5147759771210677
Overall s_type accuracy:  0.7037037037037037

('Tag-wise accuracy',
 [('decl', (1.0, 37)), ('frag', (0.0625, 16)), ('wh', (0.0, 1))],
 [('SPACE', (0.9946638207043756, 937)),
  ('NOUN', (0.8801652892561983, 242)),
  ('PUNCT', (0.7375, 160)),
  ('ADP', (0.6306306306306306, 111)),
  ('DET', (0.6288659793814433, 97)),
  ('CCONJ', (0.6, 25)),
  ('VERB', (0.4690265486725664, 113)),
  ('NUM', (0.32558139534883723, 43)),
  ('PRON', (0.2727272727272727, 22)),
  ('AUX', (0.2682926829268293, 41)),
  ('PROPN', (0.15625, 32)),
  ('ADV', (0.02564102564102564, 39)),
  ('ADJ', (0.0, 83)),
  ('SCONJ', (0.0, 16)),
  ('PART', (0.0, 13)),
  ('SYM', (0.0, 12)),
  ('X', (0.0, 2))])
--- ife, fine tuned, uniform-positional encoding 
100%|██████████| 1/1 [00:12<00:00, 12.06s/it]

Overall POS accuracy with/out space 0.6896378269617707 0.4652711703139867
Overall SUP:DEP accuracy with/out space 0.5115810674723061 0.5081029551954243
Overall s_type accuracy:  0.6851851851851852

('Tag-wise accuracy',
 [('decl', (1.0, 37)), ('frag', (0.0, 16)), ('wh', (0.0, 1))],
 [('SPACE', (0.9413020277481323, 937)),
  ('NOUN', (0.6570247933884298, 242)),
  ('ADP', (0.6306306306306306, 111)),
  ('DET', (0.6185567010309279, 97)),
  ('PUNCT', (0.6125, 160)),
  ('CCONJ', (0.52, 25)),
  ('AUX', (0.3902439024390244, 41)),
  ('VERB', (0.3893805309734513, 113)),
  ('NUM', (0.27906976744186046, 43)),
  ('PART', (0.23076923076923078, 13)),
  ('PRON', (0.18181818181818182, 22)),
  ('ADJ', (0.08433734939759036, 83)),
  ('PROPN', (0.0625, 32)),
  ('ADV', (0.02564102564102564, 39)),
  ('SCONJ', (0.0, 16)),
  ('SYM', (0.0, 12)),
  ('X', (0.0, 2))])
--- ife, fine tuned, token-positional encoding
100%|██████████| 1/1 [00:15<00:00, 15.76s/it]

Overall POS accuracy with/out space 0.7364185110663984 0.5014272121788773
Overall SUP:DEP accuracy with/out space 0.5518630412890232 0.5643469971401335
Overall s_type accuracy:  0.6851851851851852

('Tag-wise accuracy',
 [('decl', (1.0, 37)), ('frag', (0.0, 16)), ('wh', (0.0, 1))],
 [('SPACE', (1.0, 937)),
  ('PUNCT', (0.75625, 160)),
  ('NOUN', (0.6694214876033058, 242)),
  ('DET', (0.6597938144329897, 97)),
  ('ADP', (0.6396396396396397, 111)),
  ('CCONJ', (0.6, 25)),
  ('VERB', (0.45132743362831856, 113)),
  ('AUX', (0.34146341463414637, 41)),
  ('NUM', (0.27906976744186046, 43)),
  ('PRON', (0.18181818181818182, 22)),
  ('PROPN', (0.15625, 32)),
  ('PART', (0.15384615384615385, 13)),
  ('ADJ', (0.060240963855421686, 83)),
  ('ADV', (0.02564102564102564, 39)),
  ('SCONJ', (0.0, 16)),
  ('SYM', (0.0, 12)),
  ('X', (0.0, 2))])
--- form, sparse, uniform-positional encoding
Overall POS accuracy with/out space 0.7471561530506722 0.5347288296860133
Overall SUP:DEP accuracy with/out space 0.5377846790890269 0.5243088655862727
Overall s_type accuracy:  0.6851851851851852

('Tag-wise accuracy',
 [('decl', (1.0, 37)), ('frag', (0.0, 16)), ('wh', (0.0, 1))],
 [('PUNCT', (0.9731543624161074, 1043)),
  ('NOUN', (0.8842975206611571, 242)),
  ('ADP', (0.7027027027027027, 111)),
  ('DET', (0.5567010309278351, 97)),
  ('NUM', (0.3953488372093023, 43)),
  ('VERB', (0.3893805309734513, 113)),
  ('PART', (0.23076923076923078, 13)),
  ('PRON', (0.22727272727272727, 22)),
  ('CCONJ', (0.2, 25)),
  ('PROPN', (0.09375, 32)),
  ('ADV', (0.07692307692307693, 39)),
  ('AUX', (0.04878048780487805, 41)),
  ('ADJ', (0.024096385542168676, 83)),
  ('SCONJ', (0.0, 16)),
  ('SYM', (0.0, 12)),
  ('X', (0.0, 2))])
--- end small training/test
--- start full training/test
--- ife, sparse, uniform-positional encoding
100%|██████████| 18/18 [39:12<00:00, 130.72s/it]

Overall POS accuracy with/out space 0.8118279569892473 0.6528737049827331
Overall SUP:DEP accuracy with/out space 0.6086566569226681 0.6032863849765259
Overall s_type accuracy:  0.7248322147651006

('Tag-wise accuracy',
 [('decl', (0.9643463497453311, 589)),
  ('intj', (0.5769230769230769, 26)),
  ('q', (0.5625, 16)),
  ('inf', (0.4, 5)),
  ('wh', (0.38095238095238093, 21)),
  ('frag', (0.3225806451612903, 93)),
  ('imp', (0.20408163265306123, 49)),
  ('sub', (0.12195121951219512, 41)),
  ('other', (0.07142857142857142, 14)),
  ('multiple', (0.0, 32)),
  ('ger', (0.0, 8))],
 [('SPACE', (0.997053751077896, 13916)),
  ('PUNCT', (0.8820388349514563, 2060)),
  ('NOUN', (0.8801750252440256, 2971)),
  ('DET', (0.8204577968526466, 1398)),
  ('ADP', (0.7563123899001761, 1703)),
  ('CCONJ', (0.7380952380952381, 588)),
  ('PRON', (0.7214484679665738, 1077)),
  ('SCONJ', (0.6065573770491803, 244)),
  ('AUX', (0.5866851595006934, 721)),
  ('VERB', (0.5321321321321322, 1665)),
  ('INTJ', (0.38636363636363635, 88)),
  ('PART', (0.382089552238806, 335)),
  ('NUM', (0.3119533527696793, 343)),
  ('PROPN', (0.3049074818986323, 1243)),
  ('ADV', (0.24796084828711257, 613)),
  ('ADJ', (0.22151898734177214, 1106)),
  ('SYM', (0.17142857142857143, 35)),
  ('X', (0.038461538461538464, 26))])
--- ife, sparse, token-positional encoding
100%|██████████| 18/18 [06:18<00:00, 21.02s/it]

Overall POS accuracy with/out space 0.8213756204473667 0.6582387765170202
Overall SUP:DEP accuracy with/out space 0.5965546164268662 0.5992092908327156
Overall s_type accuracy:  0.7125279642058165

('Tag-wise accuracy',
 [('decl', (0.9847198641765704, 589)),
  ('intj', (0.46153846153846156, 26)),
  ('q', (0.375, 16)),
  ('wh', (0.2857142857142857, 21)),
  ('frag', (0.26881720430107525, 93)),
  ('inf', (0.2, 5)),
  ('imp', (0.12244897959183673, 49)),
  ('sub', (0.024390243902439025, 41)),
  ('multiple', (0.0, 32)),
  ('other', (0.0, 14)),
  ('ger', (0.0, 8))],
 [('SPACE', (1.0, 14810)),
  ('PUNCT', (0.9053398058252428, 2060)),
  ('NOUN', (0.8872433524065971, 2971)),
  ('DET', (0.8190271816881259, 1398)),
  ('ADP', (0.7821491485613623, 1703)),
  ('PRON', (0.7595171773444754, 1077)),
  ('CCONJ', (0.7517006802721088, 588)),
  ('AUX', (0.5936199722607489, 721)),
  ('VERB', (0.5315315315315315, 1665)),
  ('SCONJ', (0.38114754098360654, 244)),
  ('PART', (0.3611940298507463, 335)),
  ('PROPN', (0.31938857602574416, 1243)),
  ('INTJ', (0.3181818181818182, 88)),
  ('NUM', (0.26239067055393583, 343)),
  ('ADJ', (0.22694394213381555, 1106)),
  ('ADV', (0.22512234910277323, 613)),
  ('SYM', (0.14285714285714285, 35)),
  ('X', (0.0, 26))])
--- ife, fine tuned, uniform-positional encoding
100%|██████████| 18/18 [2:29:47<00:00, 499.28s/it]

Overall POS accuracy with/out space 0.816938802601885 0.6610754810064134
Overall SUP:DEP accuracy with/out space 0.5741097528567632 0.573943661971831
Overall s_type accuracy:  0.70917225950783

('Tag-wise accuracy',
 [('decl', (0.9609507640067911, 589)),
  ('intj', (0.6153846153846154, 26)),
  ('q', (0.5, 16)),
  ('frag', (0.3010752688172043, 93)),
  ('wh', (0.23809523809523808, 21)),
  ('imp', (0.14285714285714285, 49)),
  ('sub', (0.07317073170731707, 41)),
  ('multiple', (0.03125, 32)),
  ('other', (0.0, 14)),
  ('ger', (0.0, 8)),
  ('inf', (0.0, 5))],
 [('SPACE', (0.9985628054038517, 13916)),
  ('PUNCT', (0.9266990291262136, 2060)),
  ('DET', (0.9084406294706724, 1398)),
  ('CCONJ', (0.8707482993197279, 588)),
  ('PRON', (0.8142989786443825, 1077)),
  ('ADP', (0.8038755137991779, 1703)),
  ('AUX', (0.7572815533980582, 721)),
  ('PART', (0.7522388059701492, 335)),
  ('NOUN', (0.6893301918545944, 2971)),
  ('SCONJ', (0.5245901639344263, 244)),
  ('VERB', (0.42162162162162165, 1665)),
  ('PROPN', (0.4006436041834272, 1243)),
  ('INTJ', (0.375, 88)),
  ('NUM', (0.2915451895043732, 343)),
  ('ADV', (0.27569331158238175, 613)),
  ('ADJ', (0.2739602169981917, 1106)),
  ('SYM', (0.08571428571428572, 35)),
  ('X', (0.038461538461538464, 26))])
--- ife, fine tuned, token-positional encoding
100%|██████████| 18/18 [10:14<00:00, 34.12s/it]

Overall POS accuracy with/out space 0.8244697995229807 0.6641588554514061
Overall SUP:DEP accuracy with/out space 0.5768114071875605 0.5762910798122066
Overall s_type accuracy:  0.6778523489932886

('Tag-wise accuracy',
 [('decl', (0.99830220713073, 589)),
  ('intj', (0.34615384615384615, 26)),
  ('frag', (0.08602150537634409, 93)),
  ('q', (0.0625, 16)),
  ('imp', (0.0, 49)),
  ('sub', (0.0, 41)),
  ('multiple', (0.0, 32)),
  ('wh', (0.0, 21)),
  ('other', (0.0, 14)),
  ('ger', (0.0, 8)),
  ('inf', (0.0, 5))],
 [('SPACE', (1.0, 14810)),
  ('PUNCT', (0.954368932038835, 2060)),
  ('DET', (0.9227467811158798, 1398)),
  ('CCONJ', (0.8401360544217688, 588)),
  ('ADP', (0.8244274809160306, 1703)),
  ('PRON', (0.7883008356545961, 1077)),
  ('PART', (0.7820895522388059, 335)),
  ('NOUN', (0.6886570178391114, 2971)),
  ('AUX', (0.6851595006934813, 721)),
  ('VERB', (0.48228228228228226, 1665)),
  ('SCONJ', (0.45491803278688525, 244)),
  ('PROPN', (0.38777152051488334, 1243)),
  ('INTJ', (0.32954545454545453, 88)),
  ('NUM', (0.29737609329446063, 343)),
  ('ADJ', (0.2603978300180832, 1106)),
  ('ADV', (0.23817292006525284, 613)),
  ('SYM', (0.11428571428571428, 35)),
  ('X', (0.0, 26))])
--- ife, fine tuned using all pre-training data, token-positional encoding
100%|██████████| 18/18 [12:25<00:00, 41.44s/it]

Overall POS accuracy with/out space 0.8316895506994134 0.6779723729649728
Overall SUP:DEP accuracy with/out space 0.5690367120459384 0.5663454410674573
Overall s_type accuracy:  0.6789709172259508

('Tag-wise accuracy',
 [('decl', (0.99830220713073, 589)),
  ('intj', (0.19230769230769232, 26)),
  ('frag', (0.15053763440860216, 93)),
  ('imp', (0.0, 49)),
  ('sub', (0.0, 41)),
  ('multiple', (0.0, 32)),
  ('wh', (0.0, 21)),
  ('q', (0.0, 16)),
  ('other', (0.0, 14)),
  ('ger', (0.0, 8)),
  ('inf', (0.0, 5))],
 [('SPACE', (1.0, 14810)),
  ('PUNCT', (0.9616504854368932, 2060)),
  ('DET', (0.9284692417739628, 1398)),
  ('PART', (0.8776119402985074, 335)),
  ('CCONJ', (0.8537414965986394, 588)),
  ('ADP', (0.815619495008808, 1703)),
  ('PRON', (0.8133704735376045, 1077)),
  ('AUX', (0.723994452149792, 721)),
  ('NOUN', (0.6560080780881858, 2971)),
  ('SCONJ', (0.6229508196721312, 244)),
  ('VERB', (0.4864864864864865, 1665)),
  ('NUM', (0.4518950437317784, 343)),
  ('INTJ', (0.42045454545454547, 88)),
  ('PROPN', (0.415124698310539, 1243)),
  ('ADV', (0.32300163132137033, 613)),
  ('ADJ', (0.2766726943942134, 1106)),
  ('SYM', (0.2571428571428571, 35)),
  ('X', (0.0, 26))])
--- ife, fine tuned using all pre-training data, token-positional encoding, no spaces
100%|██████████| 18/18 [02:36<00:00,  8.67s/it]

Overall POS accuracy with/out space 0.712074494326591 0.712074494326591
Overall SUP:DEP accuracy with/out space 0.2855201383741043 0.2855201383741043

('Tag-wise accuracy',
 [],
 [('PUNCT', (0.9169902912621359, 2060)),
  ('DET', (0.9120171673819742, 1398)),
  ('ADP', (0.8461538461538461, 1703)),
  ('CCONJ', (0.8384353741496599, 588)),
  ('AUX', (0.8377253814147018, 721)),
  ('PRON', (0.8217270194986073, 1077)),
  ('PART', (0.7283582089552239, 335)),
  ('NOUN', (0.6974082800403905, 2971)),
  ('SCONJ', (0.6311475409836066, 244)),
  ('VERB', (0.5897897897897898, 1665)),
  ('INTJ', (0.5681818181818182, 88)),
  ('ADJ', (0.5289330922242315, 1106)),
  ('ADV', (0.41435562805872755, 613)),
  ('PROPN', (0.40788415124698313, 1243)),
  ('NUM', (0.3119533527696793, 343)),
  ('X', (0.15384615384615385, 26)),
  ('SYM', (0.02857142857142857, 35))])
--- form, sparse, uniform-positional encoding
100%|██████████| 18/18 [10:36:22<00:00, 2121.24s/it]

Overall POS accuracy with/out space 0.8492232321278927 0.7121361618154909
Overall SUP:DEP accuracy with/out space 0.6333957029485773 0.6320731405979738
Overall s_type accuracy:  0.727069351230425

('Tag-wise accuracy',
 [('decl', (0.9830220713073005, 589)),
  ('intj', (0.5384615384615384, 26)),
  ('q', (0.5, 16)),
  ('imp', (0.3673469387755102, 49)),
  ('frag', (0.27956989247311825, 93)),
  ('wh', (0.19047619047619047, 21)),
  ('sub', (0.024390243902439025, 41)),
  ('multiple', (0.0, 32)),
  ('other', (0.0, 14)),
  ('ger', (0.0, 8)),
  ('inf', (0.0, 5))],
 [('SPACE', (0.9993247805536799, 14810)),
  ('NOUN', (0.9447997307303938, 2971)),
  ('PUNCT', (0.9373786407766991, 2060)),
  ('PRON', (0.8050139275766016, 1077)),
  ('DET', (0.7982832618025751, 1398)),
  ('ADP', (0.7780387551379918, 1703)),
  ('VERB', (0.6582582582582582, 1665)),
  ('CCONJ', (0.5952380952380952, 588)),
  ('AUX', (0.5880721220527045, 721)),
  ('SCONJ', (0.4385245901639344, 244)),
  ('NUM', (0.4314868804664723, 343)),
  ('ADJ', (0.42495479204339964, 1106)),
  ('ADV', (0.42251223491027734, 613)),
  ('PROPN', (0.4022526146419952, 1243)),
  ('PART', (0.3492537313432836, 335)),
  ('INTJ', (0.3181818181818182, 88)),
  ('SYM', (0.08571428571428572, 35)),
  ('X', (0.0, 26))])
```

In [9]:
model.fine_tune(docs, covering = covering, all_layers = all_layers)

Fine-tuning dense output heads...


100%|██████████| 132/132 [6:31:15<00:00, 177.84s/it]


In [10]:
from tqdm import tqdm

confusion = Counter()
confusion_nsp = Counter()
accuracy = defaultdict(list)
accuracy_nsp = defaultdict(list)
accuracy_all, accuracy_all_nsp, = [], []
sup_accuracy, sup_accuracy_nsp, = 0, 0
accuracy_sty = defaultdict(list)
accuracy_all_sty = []

model.interpret(tdocs, seed = 691) 

pred_toks = [t._form for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_spans = list(np.cumsum([len(t) for t in pred_toks]))
pred_spans = set([(sh-len(gt), sh) for sh, gt in zip(pred_spans, pred_toks)])

gold_toks = [row[1] for d in test_docs for s in d['conllu'] for row in s]
gold_spans = list(np.cumsum([len(t) for t in gold_toks]))
gold_spans = set([(sh-len(gt), sh) for sh, gt in zip(gold_spans, gold_toks)])

for pred_span, pred_tok in zip(pred_spans, pred_toks):
    if pred_span in gold_spans:
        confusion['TP'] += 1
        if pred_tok != ' ':
            confusion_nsp['TP'] += 1
    else:
        confusion['FP'] += 1
        if pred_tok != ' ':
            confusion_nsp['FP'] += 1
confusion['FN'] = len(gold_spans) - confusion['TP']
confusion_nsp['FN'] = len([t for t in gold_toks if t != ' ']) - confusion_nsp['TP']

confusion['P'] = round(confusion['TP']/(confusion['TP'] + confusion['FP']), 3)
confusion['R'] = round(confusion['TP']/(confusion['TP'] + confusion['FN']), 3)
confusion['F'] = round(2*confusion['P']*confusion['R']/(confusion['P']+confusion["R"]), 3)
confusion_nsp['P'] = round(confusion_nsp['TP']/(confusion_nsp['TP'] + confusion_nsp['FP']), 3)
confusion_nsp['R'] = round(confusion_nsp['TP']/(confusion_nsp['TP'] + confusion_nsp['FN']), 3)
confusion_nsp['F'] = round(2*confusion_nsp['P']*confusion_nsp['R']/(confusion_nsp['P']+confusion_nsp["R"]), 3)

for d_i, doc in enumerate(model._documents):
    for s_i, s in enumerate(doc._sentences):
        if s._sty is not None:
            result = s._sty == test_docs[d_i]['s_type'][s_i]
            accuracy_sty[test_docs[d_i]['s_type'][s_i]].append(result)
            accuracy_all_sty.append(result)

pred_toks = [t._form for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_arcs = set([(ix, str(t._sup), t._dep, d_i, s_i) for d_i, doc in enumerate(model._documents) 
                 for s_i, s in enumerate(doc._sentences) for ix, t in enumerate(s._tokens)])
pred_spans = list(np.cumsum([len(t) for t in pred_toks]))
pred_stream = [t._pos for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_spans = {(sh-len(gt), sh): (gl, gt)
              for sh, gt, gl in zip(pred_spans, pred_toks, pred_stream)}

gold_toks = [row[1] for d in test_docs for s in d['conllu'] for row in s]
gold_arcs = set([(ix, (str(int(row[6]) - int(row[0])) if int(row[6]) else row[6]), row[7], d_i, s_i) 
                 for d_i, d in enumerate(test_docs) for s_i, s in enumerate(d['conllu']) for ix, row in enumerate(s)])
gold_spans = list(np.cumsum([len(t) for t in gold_toks]))
gold_stream = [row[3] for d in test_docs for s in d['conllu'] for row in s]
gold_spans = {(sh-len(gt), sh): (gl, gt)
              for sh, gt, gl in zip(gold_spans, gold_toks, gold_stream)}

for gold_span in gold_spans:
    if gold_span in pred_spans:
        result = gold_spans[gold_span] == pred_spans[gold_span]
    else:
        result = False
    accuracy[gold_spans[gold_span][0]].append(result)
    accuracy_all.append(result)
    if gold_spans[gold_span][1] != ' ':
        accuracy_nsp[gold_spans[gold_span][0]].append(result)
        accuracy_all_nsp.append(result)
        
for ptok, parc in zip(pred_toks, pred_arcs):
    if parc in gold_arcs:
        sup_accuracy += 1
        if ptok != ' ':
            sup_accuracy_nsp += 1
sup_accuracy /= len(pred_toks)
sup_accuracy_nsp /= len([x for x in pred_toks if x != ' '])

print("Token segmentation performance with/out space", confusion, confusion_nsp)
print("Overall POS accuracy with/out space", sum(accuracy_all)/len(accuracy_all), sum(accuracy_all_nsp)/len(accuracy_all_nsp))
print("Overall SUP:DEP accuracy with/out space", sup_accuracy, sup_accuracy_nsp)
if len(accuracy_all_sty):
    print("Overall s_type accuracy: ", sum(accuracy_all_sty)/len(accuracy_all_sty))
"Tag-wise accuracy", list(Counter({tag: (sum(accuracy_sty[tag])/len(accuracy_sty[tag]), len(accuracy_sty[tag])) 
                                   for tag in accuracy_sty}).most_common()), list(Counter({tag: (sum(accuracy[tag])/len(accuracy[tag]), len(accuracy[tag])) 
                                                                                           for tag in accuracy}).most_common())

100%|██████████| 18/18 [1:28:36<00:00, 295.37s/it]

Token segmentation performance with/out space Counter({'FN': 17674, 'TP': 13352, 'FP': 6633, 'P': 0.668, 'F': 0.523, 'R': 0.43}) Counter({'TP': 8521, 'FN': 7695, 'FP': 4203, 'P': 0.67, 'F': 0.589, 'R': 0.525})
Overall POS accuracy with/out space 0.08824856571907433 0.05346571287617168
Overall SUP:DEP accuracy with/out space 0.17062797097823368 0.16842187991197738





('Tag-wise accuracy',
 [],
 [('INTJ', (0.3068181818181818, 88)),
  ('SPACE', (0.1263335584064821, 14810)),
  ('PUNCT', (0.12524271844660195, 2060)),
  ('NUM', (0.11078717201166181, 343)),
  ('DET', (0.0765379113018598, 1398)),
  ('PRON', (0.06963788300835655, 1077)),
  ('PROPN', (0.061946902654867256, 1243)),
  ('NOUN', (0.05856613934702121, 2971)),
  ('CCONJ', (0.02891156462585034, 588)),
  ('SYM', (0.02857142857142857, 35)),
  ('AUX', (0.027739251040221916, 721)),
  ('ADP', (0.023487962419260128, 1703)),
  ('ADV', (0.022838499184339316, 613)),
  ('SCONJ', (0.00819672131147541, 244)),
  ('VERB', (0.007207207207207207, 1665)),
  ('ADJ', (0.0045207956600361665, 1106)),
  ('PART', (0.0, 335)),
  ('X', (0.0, 26))])

In [11]:
from tqdm import tqdm

accuracy = defaultdict(list)
accuracy_nsp = defaultdict(list)
accuracy_all, accuracy_all_nsp, = [], []
sup_accuracy, sup_accuracy_nsp, = 0, 0
accuracy_sty = defaultdict(list)
accuracy_all_sty = []

model.interpret(tdocs, seed = 691, covering = [[[row[1] for row in s] for s in d['conllu']] for d in test_docs])

for d_i, doc in enumerate(model._documents):
    for s_i, s in enumerate(doc._sentences):
        if s._sty is not None:
            result = s._sty == test_docs[d_i]['s_type'][s_i]
            accuracy_sty[test_docs[d_i]['s_type'][s_i]].append(result)
            accuracy_all_sty.append(result)

pred_toks = [t._form for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_arcs = set([(ix, str(t._sup), t._dep, d_i, s_i) for d_i, doc in enumerate(model._documents) 
                 for s_i, s in enumerate(doc._sentences) for ix, t in enumerate(s._tokens)])
pred_spans = list(np.cumsum([len(t) for t in pred_toks]))
pred_stream = [t._pos for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_spans = {(sh-len(gt), sh): (gl, gt)
              for sh, gt, gl in zip(pred_spans, pred_toks, pred_stream)}

gold_toks = [row[1] for d in test_docs for s in d['conllu'] for row in s]
gold_arcs = set([(ix, (str(int(row[6]) - int(row[0])) if int(row[6]) else row[6]), row[7], d_i, s_i) 
                 for d_i, d in enumerate(test_docs) for s_i, s in enumerate(d['conllu']) for ix, row in enumerate(s)])
gold_spans = list(np.cumsum([len(t) for t in gold_toks]))
gold_stream = [row[3] for d in test_docs for s in d['conllu'] for row in s]
gold_spans = {(sh-len(gt), sh): (gl, gt)
              for sh, gt, gl in zip(gold_spans, gold_toks, gold_stream)}

for gold_span in gold_spans:
    if gold_span in pred_spans:
        result = gold_spans[gold_span] == pred_spans[gold_span]
    else:
        result = False
    accuracy[gold_spans[gold_span][0]].append(result)
    accuracy_all.append(result)
    if gold_spans[gold_span][1] != ' ':
        accuracy_nsp[gold_spans[gold_span][0]].append(result)
        accuracy_all_nsp.append(result)
        
for ptok, parc in zip(pred_toks, pred_arcs):
    if parc in gold_arcs:
        sup_accuracy += 1
        if ptok != ' ':
            sup_accuracy_nsp += 1
sup_accuracy /= len(pred_toks)
sup_accuracy_nsp /= len([x for x in pred_toks if x != ' '])

print("Overall POS accuracy with/out space", sum(accuracy_all)/len(accuracy_all), sum(accuracy_all_nsp)/len(accuracy_all_nsp))
print("Overall SUP:DEP accuracy with/out space", sup_accuracy, sup_accuracy_nsp)
if len(accuracy_all_sty):
    print("Overall s_type accuracy: ", sum(accuracy_all_sty)/len(accuracy_all_sty))
"Tag-wise accuracy", list(Counter({tag: (sum(accuracy_sty[tag])/len(accuracy_sty[tag]), len(accuracy_sty[tag])) 
                                   for tag in accuracy_sty}).most_common()), list(Counter({tag: (sum(accuracy[tag])/len(accuracy[tag]), len(accuracy[tag])) 
                                                                                           for tag in accuracy}).most_common())

100%|██████████| 18/18 [1:29:42<00:00, 299.00s/it]

Overall POS accuracy with/out space 0.22706762070521497 0.17834237789837198
Overall SUP:DEP accuracy with/out space 0.28269565778437317 0.28181368915245864





('Tag-wise accuracy',
 [],
 [('NOUN', (0.46987546280713566, 2971)),
  ('INTJ', (0.3977272727272727, 88)),
  ('PRON', (0.3045496750232126, 1077)),
  ('SPACE', (0.2804186360567184, 14810)),
  ('NUM', (0.1749271137026239, 343)),
  ('VERB', (0.15915915915915915, 1665)),
  ('PUNCT', (0.13349514563106796, 2060)),
  ('PROPN', (0.12308930008045052, 1243)),
  ('ADV', (0.10440456769983687, 613)),
  ('AUX', (0.08876560332871013, 721)),
  ('DET', (0.07939914163090128, 1398)),
  ('PART', (0.07164179104477612, 335)),
  ('SCONJ', (0.04918032786885246, 244)),
  ('CCONJ', (0.03571428571428571, 588)),
  ('ADP', (0.03523194362889019, 1703)),
  ('SYM', (0.02857142857142857, 35)),
  ('ADJ', (0.020795660036166366, 1106)),
  ('X', (0.0, 26))])