# It's a Machine and Natural Language Tagger

In [1]:
from src.IaMaN.base import LM
from src.utils.data import load_ud
from collections import defaultdict
from collections import Counter
from tqdm import tqdm
import numpy as np
import os, re

seed = 691; max_char = 200_000_000
m = 10; space = True; fine_tune = False; fine_tune_post_pretrain = False
positional = 'independent'; positionally_encode = True; bits = 500; num_pretrain = 0
do_ife = False; update_ife = False; update_bow = False; 
runners = 25; gpu = True

print("Loading pre-training data...")
pretrain_path = '/local-data/newstweet/week_2019-40_article_texts/'
total_pretrain = len([pretrain_file for pretrain_file in os.listdir(pretrain_path) if re.search("^\d+.txt$", pretrain_file)])
all_pretrain_files = [pretrain_file for pretrain_file in os.listdir(pretrain_path) if re.search("^\d+.txt$", pretrain_file)]
if num_pretrain:
    np.random.seed(seed)
    pretrain_files = np.random.choice(all_pretrain_files, size=num_pretrain, replace=False)
else:
    pretrain_files = np.array([])
ptdocs = [[[open(pretrain_path+pretrain_file).read()]] for pretrain_file in tqdm(pretrain_files)]

print("Loading gold-tagged UDs data...")
load_set = 'GUM'; tokenizer = 'hr-bpe'
all_docs = load_ud("English", num_articles = 0, seed = 691, load_set = load_set, rebuild = True, space = space)
test_docs = [doc for doc in all_docs if 'test' in doc['id'] and len(doc['text']) <= max_char]# [:2]
train_docs = [doc for doc in all_docs if 'test' not in doc['id'] and len(doc['text']) <= max_char]# [:4]
nsamp = len(test_docs)
print('Avail. pre-train, total pre-train, Avail. gold, total gold-train, total test-gold: ', 
      total_pretrain, len(ptdocs), len(all_docs), len(train_docs), len(test_docs))

Loading pre-training data...


0it [00:00, ?it/s]


Loading gold-tagged UDs data...
Avail. pre-train, total pre-train, Avail. gold, total gold-train, total test-gold:  14198 0 150 132 18


In [2]:
docs = [["".join([row[1] for row in s]) for s in d['conllu']] for d in train_docs]
tdocs = [["".join([row[1] for row in s]) for s in d['conllu']] for d in test_docs]
covering = [[[row[1] for row in s] for s in d['conllu']] for d in train_docs]
tcovering = [[[row[1] for row in s] for s in d['conllu']] for d in test_docs]
covering_vocab = set([t for d in covering for s in d for t in s])

all_layers = {d_i: {# 'lem': [[row[2] for row in s] for s in d['conllu']], 
                    # 'sty': [[d['s_type'][s_i] for row in s] for s_i, s in enumerate(d['conllu'])], 
                    'pos': [[row[3] for row in s] for s in d['conllu']], 
                    'sup': [[(str(int(row[6]) - int(row[0])) if int(row[6]) else row[6]) for row in s] for s in d['conllu']], 
                    'dep': [[row[7] for row in s] for s in d['conllu']]}
              for d_i, d in enumerate(train_docs)}

model = LM(m = m, tokenizer = tokenizer, noise = 0.001, seed = seed, space = space, positional = positional,
           positionally_encode = positionally_encode, do_ife = do_ife, runners = runners, gpu = gpu, bits = bits)
data_streams = model.fit(docs, f'{load_set}-{nsamp}', covering = covering, all_layers = all_layers, fine_tune = fine_tune)
model.pre_train(ptdocs, update_ife = update_ife, fine_tune = fine_tune)
if fine_tune_post_pretrain:
    model.fine_tune(docs, covering = covering, all_layers = all_layers)

Initializing:   0%|          | 0/6503 [00:00<?, ?it/s]

Training tokenizer...


Initializing: 100%|██████████| 6503/6503 [00:01<00:00, 4546.06it/s]
Fitting:  20%|██        | 20/100 [00:32<02:16,  1.70s/it]
  0%|          | 0/132 [00:00<?, ?it/s][A
  1%|          | 1/132 [00:00<00:17,  7.63it/s][A

Built a vocabulary of 10606 types
Tokenizing documents...



  2%|▏         | 2/132 [00:00<00:17,  7.33it/s][A
  2%|▏         | 3/132 [00:00<00:16,  7.78it/s][A
  3%|▎         | 4/132 [00:00<00:16,  7.64it/s][A
  5%|▍         | 6/132 [00:00<00:14,  8.42it/s][A
  6%|▌         | 8/132 [00:00<00:13,  8.91it/s][A
  8%|▊         | 10/132 [00:01<00:12,  9.77it/s][A
  8%|▊         | 11/132 [00:01<00:14,  8.31it/s][A
  9%|▉         | 12/132 [00:01<00:15,  7.80it/s][A
 10%|▉         | 13/132 [00:01<00:17,  6.96it/s][A
 11%|█▏        | 15/132 [00:01<00:16,  7.21it/s][A
 13%|█▎        | 17/132 [00:01<00:14,  8.01it/s][A
 14%|█▍        | 19/132 [00:02<00:13,  8.59it/s][A
 15%|█▌        | 20/132 [00:02<00:15,  7.16it/s][A
 16%|█▌        | 21/132 [00:02<00:16,  6.86it/s][A
 17%|█▋        | 22/132 [00:02<00:17,  6.33it/s][A
 17%|█▋        | 23/132 [00:02<00:17,  6.24it/s][A
 18%|█▊        | 24/132 [00:03<00:17,  6.31it/s][A
 19%|█▉        | 25/132 [00:03<00:16,  6.47it/s][A
 20%|█▉        | 26/132 [00:03<00:15,  6.95it/s][A
 20%|██        |

Counting documents and aggregating counts...



0it [00:00, ?it/s][A
1it [00:56, 56.27s/it][A
35840it [00:56, 39.39s/it][A
73728it [00:56, 27.57s/it][A
112640it [00:56, 19.30s/it][A
153600it [00:56, 13.51s/it][A
192512it [00:56,  9.46s/it][A
234496it [00:56,  6.62s/it][A
276480it [00:56,  4.63s/it][A
318464it [00:57,  3.24s/it][A
356635it [01:01,  2.27s/it][A
396669it [01:01,  1.59s/it][A
437151it [01:02,  1.11s/it][A
476541it [01:02,  1.28it/s][A
517501it [01:02,  1.83it/s][A
557437it [01:02,  2.62it/s][A
596349it [01:02,  3.74it/s][A
638333it [01:02,  5.35it/s][A
678269it [01:02,  7.64it/s][A
717316it [01:07, 10.91it/s][A
756936it [01:07, 15.58it/s][A
797896it [01:07, 22.26it/s][A
839880it [01:07, 31.80it/s][A
881864it [01:07, 45.43it/s][A
923019it [01:07, 64.90it/s][A
961743it [01:07, 92.71it/s][A
1000648it [01:07, 132.43it/s][A
1039197it [01:12, 187.82it/s][A
1081085it [01:12, 268.26it/s][A
1121021it [01:13, 383.11it/s][A
1160957it [01:13, 547.07it/s][A
1198757it [01:13, 781.05it/s][A
1238781it [

Collecting pre-processed data...



1it [00:00,  1.08it/s][A
2it [00:01,  1.22it/s][A
3it [00:01,  1.38it/s][A
4it [00:02,  1.39it/s][A
5it [00:03,  1.30it/s][A
6it [00:04,  1.32it/s][A
7it [00:04,  1.45it/s][A
8it [00:05,  1.49it/s][A
9it [00:05,  1.62it/s][A
10it [00:06,  1.75it/s][A
11it [00:07,  1.58it/s][A
12it [00:08,  1.47it/s][A
13it [00:08,  1.38it/s][A
14it [00:09,  1.43it/s][A
15it [00:10,  1.40it/s][A
16it [00:10,  1.61it/s][A
17it [00:11,  1.52it/s][A
18it [00:11,  1.63it/s][A
19it [00:12,  1.67it/s][A
20it [00:13,  1.55it/s][A
21it [00:13,  1.57it/s][A
22it [00:15,  1.20it/s][A
23it [00:15,  1.27it/s][A
24it [00:16,  1.30it/s][A
25it [00:17,  1.46it/s][A

0it [00:00, ?it/s][A
1it [00:00,  1.87it/s][A
2it [00:01,  1.70it/s][A
3it [00:01,  1.62it/s][A
4it [00:02,  1.36it/s][A
5it [00:03,  1.40it/s][A
6it [00:04,  1.42it/s][A
7it [00:04,  1.62it/s][A
8it [00:05,  1.78it/s][A
9it [00:05,  1.60it/s][A
10it [00:06,  1.43it/s][A
11it [00:07,  1.60it/s][A
12it [00:07,  1.56it/s

Aggregating metadata...



 59%|█████▉    | 78/132 [00:00<00:00, 397.57it/s][A
 79%|███████▉  | 104/132 [00:00<00:00, 342.58it/s][A
100%|██████████| 132/132 [00:00<00:00, 300.21it/s][A

  0%|          | 0/4758045 [00:00<?, ?it/s][A

Building cipher...  done.
Encoding parameters...



  0%|          | 6196/4758045 [00:00<01:16, 61959.76it/s][A
  0%|          | 12348/4758045 [00:00<01:16, 61816.53it/s][A
  0%|          | 18399/4758045 [00:00<01:17, 61417.01it/s][A
  1%|          | 24902/4758045 [00:00<01:15, 62456.08it/s][A
  1%|          | 31154/4758045 [00:00<01:15, 62472.54it/s][A
  1%|          | 37721/4758045 [00:00<01:14, 63396.43it/s][A
  1%|          | 43958/4758045 [00:00<01:14, 63084.94it/s][A
  1%|          | 50041/4758045 [00:00<01:15, 62390.70it/s][A
  1%|          | 56445/4758045 [00:00<01:14, 62874.62it/s][A
  1%|▏         | 62504/4758045 [00:01<01:15, 62169.13it/s][A
  1%|▏         | 68551/4758045 [00:01<01:17, 60462.40it/s][A
  2%|▏         | 74535/4758045 [00:01<01:17, 60273.65it/s][A
  2%|▏         | 80557/4758045 [00:01<01:17, 60256.35it/s][A
  2%|▏         | 86785/4758045 [00:01<01:16, 60848.20it/s][A
  2%|▏         | 92948/4758045 [00:01<01:16, 61078.76it/s][A
  2%|▏         | 99033/4758045 [00:01<01:16, 61008.96it/s][A
  2%|▏  

Building target vocabularies...
Pre-computing BOW probabilities...


100%|██████████| 2/2 [00:00<00:00, 3207.88it/s]

100%|██████████| 10/10 [00:00<00:00, 47934.90it/s]

 done.
Building context vocabularies...
Pre-computing wave amplitudes... done.
Stacking output vocabularies for decoders...
Encoding data streams for torch processing...




  0%|          | 0/22 [00:00<?, ?it/s][A

 done.
Computing marginal statistics...



  5%|▍         | 1/22 [00:01<00:28,  1.34s/it][A
  9%|▉         | 2/22 [00:01<00:21,  1.08s/it][A
 14%|█▎        | 3/22 [00:02<00:18,  1.02it/s][A
100%|██████████| 22/22 [00:02<00:00,  7.73it/s][A

  0%|          | 0/22 [00:00<?, ?it/s][A

Building dense output heads...



  5%|▍         | 1/22 [00:01<00:30,  1.45s/it][A
  9%|▉         | 2/22 [00:02<00:24,  1.22s/it][A
 14%|█▎        | 3/22 [00:03<00:21,  1.14s/it][A
100%|██████████| 22/22 [00:03<00:00,  6.52it/s][A

  0%|          | 0/132 [00:00<?, ?it/s][A

Counting for transition matrices...



  1%|          | 1/132 [00:00<00:26,  4.90it/s][A
  2%|▏         | 2/132 [00:00<00:25,  5.06it/s][A
  2%|▏         | 3/132 [00:00<00:24,  5.33it/s][A
  3%|▎         | 4/132 [00:00<00:25,  4.95it/s][A
  4%|▍         | 5/132 [00:01<00:31,  4.01it/s][A
  5%|▍         | 6/132 [00:01<00:32,  3.93it/s][A
  5%|▌         | 7/132 [00:01<00:28,  4.34it/s][A
  6%|▌         | 8/132 [00:01<00:27,  4.43it/s][A
  7%|▋         | 9/132 [00:01<00:24,  4.93it/s][A
  8%|▊         | 10/132 [00:02<00:22,  5.42it/s][A
  8%|▊         | 11/132 [00:02<00:25,  4.73it/s][A
  9%|▉         | 12/132 [00:02<00:28,  4.19it/s][A
 10%|▉         | 13/132 [00:02<00:31,  3.78it/s][A
 11%|█         | 14/132 [00:03<00:30,  3.93it/s][A
 11%|█▏        | 15/132 [00:03<00:29,  3.94it/s][A
 12%|█▏        | 16/132 [00:03<00:24,  4.71it/s][A
 13%|█▎        | 17/132 [00:03<00:26,  4.38it/s][A
 14%|█▎        | 18/132 [00:03<00:23,  4.95it/s][A
 14%|█▍        | 19/132 [00:04<00:45,  2.50it/s][A
 15%|█▌        | 20/

Building transition matrices for Viterbi tag decoding...
Done.
Model params, types, encoding size, contexts, vec dim, max sent, and % capacity used: 176130 10607 500 500 791 178 3.321





In [3]:
interpret_docs = list([docs[3][0:1]])
print(interpret_docs)
model.interpret(interpret_docs, seed = 691)
for doc in model._documents:
    print('opening next doc:')
    for s in doc._sentences:
        print(f'opening next sent: {s._sty}')
        for t in s._tokens:
            print(f'opening next token: {t._form}, {t._sep}, {t._pos}, {t._sup}, {t._dep}')

[[' Emperor Norton ']]



0it [00:00, ?it/s][A
1it [00:02,  2.90s/it][A

100%|██████████| 1/1 [00:00<00:00, 14.03it/s]

Interpreting documents...
opening next doc:
opening next sent: None
opening next token:  , False, SPACE, 1, space
opening next token: Emperor, False, NOUN, 0, root
opening next token:  , False, SPACE, 1, space
opening next token: Nort, False, NOUN, 1, space
opening next token: o, False, SPACE, 1, space
opening next token: n, False, NOUN, -4, space
opening next token:  , True, SPACE, -4, space





In [4]:
interpret_docs = list([tdocs[0][1:2]])
print(interpret_docs)
model.interpret(interpret_docs, seed = 691)
for doc in model._documents:
    print('opening next doc:')
    for s in doc._sentences:
        print(f'opening next sent: {s._sty}')
        for t in s._tokens:
            print(f'opening next token: {t._form}, {t._sep}, {t._pos}, {t._sup}, {t._dep}')

[[' Results from a nationally representative sample of adults ']]



0it [00:00, ?it/s][A
1it [00:02,  2.87s/it][A

100%|██████████| 1/1 [00:00<00:00, 15.75it/s]

Interpreting documents...
opening next doc:
opening next sent: None
opening next token:  Results from a nat, False, SPACE, 0, root
opening next token: io, False, NOUN, 1, nmod
opening next token: nally, False, NOUN, 1, nmod
opening next token:  representative , False, NOUN, 1, nmod
opening next token: sample, False, NOUN, -4, space
opening next token:  , False, NOUN, 1, space
opening next token: of, False, NOUN, -4, space
opening next token:  , False, SPACE, -4, space
opening next token: adults, False, NOUN, -4, space
opening next token:  , True, SPACE, -4, space





In [5]:
interpret_docs = list([tdocs[0][1:2]])
print(interpret_docs)
interpret_covering =  [[[row[1] for row in s] for s in d['conllu']][1:2] for d in test_docs[0:1]]
print(interpret_covering)
model.interpret(interpret_docs, seed = 691, covering = interpret_covering)

accuracy = defaultdict(list)
accuracy_nsp = defaultdict(list)
accuracy_all, accuracy_all_nsp, = [], []
sup_accuracy, sup_accuracy_nsp, = 0, 0
accuracy_sty = defaultdict(list)
accuracy_all_sty = []

for d_i, doc in enumerate(model._documents):
    for s_i, s in enumerate(doc._sentences):
        if s._sty is not None:
            result = s._sty == test_docs[0:1][d_i]['s_type'][s_i+1]
            accuracy_sty[test_docs[0:1][d_i]['s_type'][s_i+1]].append(result)
            accuracy_all_sty.append(result)

pred_toks = [t._form for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_arcs = set([(ix, str(t._sup), t._dep) for doc in model._documents for s in doc._sentences for ix, t in enumerate(s._tokens)])
pred_spans = list(np.cumsum([len(t) for t in pred_toks]))
pred_stream = [t._pos for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_spans = {(sh-len(gt), sh): (gl, gt)
              for sh, gt, gl in zip(pred_spans, pred_toks, pred_stream)}

gold_toks = [row[1] for d in test_docs[:1] for s in d['conllu'][1:2] for row in s]
gold_arcs = set([(ix, (str(int(row[6]) - int(row[0])) if int(row[6]) else row[6]),
                  row[7]) for d in test_docs[:1] for s in d['conllu'][1:2] for ix, row in enumerate(s)])
gold_spans = list(np.cumsum([len(t) for t in gold_toks]))
gold_stream = [row[3] for d in test_docs[:1] for s in d['conllu'][1:2] for row in s]
gold_spans = {(sh-len(gt), sh): (gl, gt)
              for sh, gt, gl in zip(gold_spans, gold_toks, gold_stream)}

for gold_span in gold_spans:
    if gold_span in pred_spans:
        result = gold_spans[gold_span] == pred_spans[gold_span]
    else:
        result = False
    accuracy[gold_spans[gold_span][0]].append(result)
    accuracy_all.append(result)
    if gold_spans[gold_span][1] != ' ':
        accuracy_nsp[gold_spans[gold_span][0]].append(result)
        accuracy_all_nsp.append(result)
        
for ptok, parc in zip(pred_toks, pred_arcs):
    if parc in gold_arcs:
        sup_accuracy += 1
        if ptok != ' ':
            sup_accuracy_nsp += 1
sup_accuracy /= len(pred_toks)
sup_accuracy_nsp /= len([x for x in pred_toks if x != ' '])

print("Tag-wise POS accuracy with/out space", {tag: sum(accuracy[tag])/len(accuracy[tag]) for tag in accuracy}, 
                                          {tag: sum(accuracy_nsp[tag])/len(accuracy_nsp[tag]) for tag in accuracy_nsp})
print("Overall POS accuracy with/out space", sum(accuracy_all)/len(accuracy_all), sum(accuracy_all_nsp)/len(accuracy_all_nsp))
print("Overall SUP:DEP accuracy with/out space", sup_accuracy, sup_accuracy_nsp)
if len(accuracy_all_sty):
    print("Overall s_type accuracy: ", sum(accuracy_all_sty)/len(accuracy_all_sty))
print("Tag-wise accuracy", list(Counter({tag: (sum(accuracy_sty[tag])/len(accuracy_sty[tag]), len(accuracy_sty[tag])) 
                                         for tag in accuracy_sty}).most_common()))

[[' Results from a nationally representative sample of adults ']]
[[[' ', 'Results', ' ', 'from', ' ', 'a', ' ', 'nationally', ' ', 'representative', ' ', 'sample', ' ', 'of', ' ', 'adults', ' ']]]



0it [00:00, ?it/s][A
1it [00:02,  2.94s/it][A

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:00<00:00,  7.32it/s][A

Interpreting documents...
Tag-wise POS accuracy with/out space {'SPACE': 0.3333333333333333, 'NOUN': 1.0, 'ADP': 0.0, 'DET': 0.0, 'ADV': 0.0, 'ADJ': 0.0} {'NOUN': 1.0, 'ADP': 0.0, 'DET': 0.0, 'ADV': 0.0, 'ADJ': 0.0}
Overall POS accuracy with/out space 0.35294117647058826 0.375
Overall SUP:DEP accuracy with/out space 0.23529411764705882 0.25
Tag-wise accuracy []





In [6]:
interpret_docs = list([tdocs[0][:2]])
print(interpret_docs)
interpret_covering =  [[[row[1] for row in s] for s in d['conllu']][:2] for d in test_docs[0:1]]
print(interpret_covering)
model.interpret(interpret_docs, seed = 691, covering = interpret_covering)

accuracy = defaultdict(list)
accuracy_nsp = defaultdict(list)
accuracy_all, accuracy_all_nsp, = [], []
sup_accuracy, sup_accuracy_nsp, = 0, 0
accuracy_sty = defaultdict(list)
accuracy_all_sty = []

for d_i, doc in enumerate(model._documents):
    for s_i, s in enumerate(doc._sentences):
        if s._sty is not None:
            result = s._sty == test_docs[0:1][d_i]['s_type'][s_i]
            accuracy_sty[test_docs[0:1][d_i]['s_type'][s_i]].append(result)
            accuracy_all_sty.append(result)

pred_toks = [t._form for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_arcs = set([(ix, str(t._sup), t._dep, s_i, d_i) for d_i, doc in enumerate(model._documents) for s_i, s in enumerate(doc._sentences) for ix, t in enumerate(s._tokens)])
pred_spans = list(np.cumsum([len(t) for t in pred_toks]))
pred_stream = [t._pos for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_spans = {(sh-len(gt), sh): (gl, gt)
              for sh, gt, gl in zip(pred_spans, pred_toks, pred_stream)}

gold_toks = [row[1] for d in test_docs[:1] for s in d['conllu'][:2] for row in s]
gold_arcs = set([(ix, (str(int(row[6]) - int(row[0])) if int(row[6]) else row[6]), row[7], s_i, d_i) 
                 for d_i, d in enumerate(test_docs[:1]) for s_i, s in enumerate(d['conllu'][:2]) for ix, row in enumerate(s)])
gold_spans = list(np.cumsum([len(t) for t in gold_toks]))
gold_stream = [row[3] for d in test_docs[:1] for s in d['conllu'][:2] for row in s]
gold_spans = {(sh-len(gt), sh): (gl, gt)
              for sh, gt, gl in zip(gold_spans, gold_toks, gold_stream)}

for gold_span in gold_spans:
    if gold_span in pred_spans:
        result = gold_spans[gold_span] == pred_spans[gold_span]
    else:
        result = False
    accuracy[gold_spans[gold_span][0]].append(result)
    accuracy_all.append(result)
    if gold_spans[gold_span][1] != ' ':
        accuracy_nsp[gold_spans[gold_span][0]].append(result)
        accuracy_all_nsp.append(result)
        
for ptok, parc in zip(pred_toks, pred_arcs):
    if parc in gold_arcs:
        sup_accuracy += 1
        if ptok != ' ':
            sup_accuracy_nsp += 1
sup_accuracy /= len(pred_toks)
sup_accuracy_nsp /= len([x for x in pred_toks if x != ' '])

print("Tag-wise POS accuracy with/out space", {tag: sum(accuracy[tag])/len(accuracy[tag]) for tag in accuracy}, 
                                          {tag: sum(accuracy_nsp[tag])/len(accuracy_nsp[tag]) for tag in accuracy_nsp})
print("Overall POS accuracy with/out space", sum(accuracy_all)/len(accuracy_all), sum(accuracy_all_nsp)/len(accuracy_all_nsp))
print("Overall SUP:DEP accuracy with/out space", sup_accuracy, sup_accuracy_nsp)
if len(accuracy_all_sty):
    print("Overall s_type accuracy: ", sum(accuracy_all_sty)/len(accuracy_all_sty))
print("Tag-wise accuracy", list(Counter({tag: (sum(accuracy_sty[tag])/len(accuracy_sty[tag]), len(accuracy_sty[tag])) 
                                         for tag in accuracy_sty}).most_common()))

[[' The prevalence of discrimination across racial groups in contemporary America: ', ' Results from a nationally representative sample of adults ']]
[[[' ', 'The', ' ', 'prevalence', ' ', 'of', ' ', 'discrimination', ' ', 'across', ' ', 'racial', ' ', 'groups', ' ', 'in', ' ', 'contemporary', ' ', 'America', ':', ' '], [' ', 'Results', ' ', 'from', ' ', 'a', ' ', 'nationally', ' ', 'representative', ' ', 'sample', ' ', 'of', ' ', 'adults', ' ']]]



0it [00:00, ?it/s][A
1it [00:02,  2.88s/it][A

  0%|          | 0/1 [00:00<?, ?it/s][A

Interpreting documents...



100%|██████████| 1/1 [00:00<00:00,  4.52it/s][A

Tag-wise POS accuracy with/out space {'SPACE': 0.7, 'DET': 0.0, 'NOUN': 1.0, 'ADP': 0.0, 'ADJ': 0.0, 'PROPN': 0.0, 'PUNCT': 0.0, 'ADV': 0.0} {'DET': 0.0, 'NOUN': 1.0, 'ADP': 0.0, 'ADJ': 0.0, 'PROPN': 0.0, 'PUNCT': 0.0, 'ADV': 0.0}
Overall POS accuracy with/out space 0.5128205128205128 0.3157894736842105
Overall SUP:DEP accuracy with/out space 0.38461538461538464 0.42105263157894735
Tag-wise accuracy []





In [7]:
# model.interpret(tdocs, seed = 691, predict_tags = False)

In [8]:
from tqdm import tqdm

confusion = Counter()
confusion_nsp = Counter()
accuracy = defaultdict(list)
accuracy_nsp = defaultdict(list)
accuracy_all, accuracy_all_nsp, = [], []
sup_accuracy, sup_accuracy_nsp, = 0, 0
accuracy_sty = defaultdict(list)
accuracy_all_sty = []

model.interpret(tdocs, seed = 691) 

pred_toks = [t._form for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_spans = list(np.cumsum([len(t) for t in pred_toks]))
pred_spans = set([(sh-len(gt), sh) for sh, gt in zip(pred_spans, pred_toks)])

gold_toks = [row[1] for d in test_docs for s in d['conllu'] for row in s]
gold_spans = list(np.cumsum([len(t) for t in gold_toks]))
# gold_spans = set([(sh-len(gt), sh) for sh, gt in zip(gold_spans, gold_toks)])
gold_spans_nsp = {(sh-len(gt), sh): 1 for sh, gt in zip(gold_spans, gold_toks)}
gold_spans = {(sh-len(gt), sh): 1 for sh, gt in zip(gold_spans, gold_toks)}

confusion = {"TP": 0, "FP": 0, "FN": 0, "P": 0, "R": 0, "F": 0}
confusion_nsp = {"TP": 0, "FP": 0, "FN": 0, "P": 0, "R": 0, "F": 0}
for pred_span, pred_tok in zip(pred_spans, pred_toks):
    if pred_span not in gold_spans:
        confusion['FP'] += 1
        if pred_tok != ' ':
            confusion_nsp['FP'] += 1
    else:
        confusion['TP'] += 1
        if pred_tok != ' ':
            confusion_nsp['TP'] += 1
for gold_span, gold_tok in zip(gold_spans, gold_toks):
    if gold_span not in pred_spans:
        confusion['FN'] += 1
        if gold_tok != ' ':
            confusion_nsp['FN'] += 1
confusion['P'] = confusion['TP']/(confusion['TP'] + confusion['FP']) if (confusion['TP'] + confusion['FP']) else 0
confusion['R'] = confusion['TP']/(confusion['TP'] + confusion['FN']) if (confusion['TP'] + confusion['FN']) else 0
confusion['F'] = 2*confusion['P']*confusion['R']/(confusion['P']+confusion['R']) if (confusion['P']+confusion['R']) else 0
confusion_nsp['P'] = confusion_nsp['TP']/(confusion_nsp['TP'] + confusion_nsp['FP']) if (confusion_nsp['TP'] + confusion_nsp['FP']) else 0
confusion_nsp['R'] = confusion_nsp['TP']/(confusion_nsp['TP'] + confusion_nsp['FN']) if (confusion_nsp['TP'] + confusion_nsp['FN']) else 0
confusion_nsp['F'] = 2*confusion_nsp['P']*confusion_nsp['R']/(confusion_nsp['P']+confusion_nsp['R']) if (confusion_nsp['P']+confusion_nsp['R']) else 0

# for pred_span, pred_tok in zip(pred_spans, pred_toks):
#     if pred_span in gold_spans:
#         confusion['TP'] += 1
#         if pred_tok != ' ':
#             confusion_nsp['TP'] += 1
#             gold_spans_nsp[pred_span] = 0
#         else:
#             del(gold_spans_nsp[pred_span])
#         gold_spans[pred_span] = 0
#     else:
#         confusion['FP'] += 1
#         if pred_tok != ' ':
#             confusion_nsp['FP'] += 1
# confusion['FN'] = len(gold_spans) - confusion['TP']
# confusion_nsp['FN'] = len([t for t in gold_toks if t != ' ']) - confusion_nsp['TP']
# confusion['FN'] = sum(gold_spans.values())
# confusion_nsp['FN'] = sum(gold_spans_nsp.values())

# confusion['P'] = round(confusion['TP']/(confusion['TP'] + confusion['FP']), 4)
# confusion['R'] = round(confusion['TP']/(confusion['TP'] + confusion['FN']), 4)
# confusion['F'] = round(2*confusion['P']*confusion['R']/(confusion['P']+confusion["R"]), 4)
# confusion_nsp['P'] = round(confusion_nsp['TP']/(confusion_nsp['TP'] + confusion_nsp['FP']), 4)
# confusion_nsp['R'] = round(confusion_nsp['TP']/(confusion_nsp['TP'] + confusion_nsp['FN']), 4)
# confusion_nsp['F'] = round(2*confusion_nsp['P']*confusion_nsp['R']/(confusion_nsp['P']+confusion_nsp["R"]), 4)

for d_i, doc in enumerate(model._documents):
    for s_i, s in enumerate(doc._sentences):
        if s._sty is not None:
            result = s._sty == test_docs[d_i]['s_type'][s_i]
            accuracy_sty[test_docs[d_i]['s_type'][s_i]].append(result)
            accuracy_all_sty.append(result)

pred_toks = [t._form for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_arcs = set([(ix, str(t._sup), t._dep, d_i, s_i) for d_i, doc in enumerate(model._documents) 
                 for s_i, s in enumerate(doc._sentences) for ix, t in enumerate(s._tokens)])
pred_spans = list(np.cumsum([len(t) for t in pred_toks]))
pred_stream = [t._pos for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_spans = {(sh-len(gt), sh): (gl, gt)
              for sh, gt, gl in zip(pred_spans, pred_toks, pred_stream)}

gold_toks = [row[1] for d in test_docs for s in d['conllu'] for row in s]
gold_arcs = set([(ix, (str(int(row[6]) - int(row[0])) if int(row[6]) else row[6]), row[7], d_i, s_i) 
                 for d_i, d in enumerate(test_docs) for s_i, s in enumerate(d['conllu']) for ix, row in enumerate(s)])
gold_spans = list(np.cumsum([len(t) for t in gold_toks]))
gold_stream = [row[3] for d in test_docs for s in d['conllu'] for row in s]
gold_spans = {(sh-len(gt), sh): (gl, gt)
              for sh, gt, gl in zip(gold_spans, gold_toks, gold_stream)}

for gold_span in gold_spans:
    if gold_span in pred_spans:
        result = gold_spans[gold_span] == pred_spans[gold_span]
    else:
        result = False
    accuracy[gold_spans[gold_span][0]].append(result)
    accuracy_all.append(result)
    if gold_spans[gold_span][1] != ' ':
        accuracy_nsp[gold_spans[gold_span][0]].append(result)
        accuracy_all_nsp.append(result)
        
for ptok, parc in zip(pred_toks, pred_arcs):
    if parc in gold_arcs:
        sup_accuracy += 1
        if ptok != ' ':
            sup_accuracy_nsp += 1
sup_accuracy /= len(pred_toks)
sup_accuracy_nsp /= len([x for x in pred_toks if x != ' '])

print("Token segmentation performance with/out space", confusion, confusion_nsp)
print("Overall POS accuracy with/out space", sum(accuracy_all)/len(accuracy_all), sum(accuracy_all_nsp)/len(accuracy_all_nsp))
print("Overall SUP:DEP accuracy with/out space", sup_accuracy, sup_accuracy_nsp)
if len(accuracy_all_sty):
    print("Overall s_type accuracy: ", sum(accuracy_all_sty)/len(accuracy_all_sty))
"Tag-wise accuracy", list(Counter({tag: (sum(accuracy_sty[tag])/len(accuracy_sty[tag]), len(accuracy_sty[tag])) 
                                   for tag in accuracy_sty}).most_common()), list(Counter({tag: (sum(accuracy[tag])/len(accuracy[tag]), len(accuracy[tag])) 
                                                                                           for tag in accuracy}).most_common())


0it [00:00, ?it/s][A
1it [00:01,  1.32s/it][A
2it [00:01,  1.12s/it][A
3it [00:02,  1.05it/s][A
4it [00:03,  1.14it/s][A
5it [00:03,  1.31it/s][A
6it [00:04,  1.31it/s][A
7it [00:05,  1.32it/s][A
8it [00:05,  1.41it/s][A
9it [00:06,  1.49it/s][A
10it [00:07,  1.37it/s][A
11it [00:08,  1.13it/s][A
12it [00:09,  1.22it/s][A
13it [00:09,  1.26it/s][A
14it [00:10,  1.41it/s][A
15it [00:11,  1.42it/s][A
16it [00:11,  1.57it/s][A
17it [00:11,  1.78it/s][A
18it [00:12,  1.42it/s][A

  0%|          | 0/18 [00:00<?, ?it/s][A

Interpreting documents...



  6%|▌         | 1/18 [00:18<05:06, 18.05s/it][A
 11%|█         | 2/18 [00:40<05:07, 19.24s/it][A
 17%|█▋        | 3/18 [00:54<04:26, 17.75s/it][A
 22%|██▏       | 4/18 [01:12<04:08, 17.77s/it][A
 28%|██▊       | 5/18 [01:22<03:20, 15.40s/it][A
 33%|███▎      | 6/18 [01:36<03:01, 15.09s/it][A
 39%|███▉      | 7/18 [01:51<02:46, 15.11s/it][A
 44%|████▍     | 8/18 [02:02<02:19, 13.98s/it][A
 50%|█████     | 9/18 [02:21<02:17, 15.29s/it][A
 56%|█████▌    | 10/18 [02:46<02:25, 18.19s/it][A
 61%|██████    | 11/18 [02:56<01:49, 15.69s/it][A
 67%|██████▋   | 12/18 [03:11<01:34, 15.75s/it][A
 72%|███████▏  | 13/18 [03:25<01:15, 15.18s/it][A
 78%|███████▊  | 14/18 [03:40<00:59, 14.90s/it][A
 83%|████████▎ | 15/18 [04:00<00:49, 16.55s/it][A
 89%|████████▉ | 16/18 [04:09<00:28, 14.24s/it][A
 94%|█████████▍| 17/18 [04:18<00:12, 12.71s/it][A
100%|██████████| 18/18 [04:30<00:00, 15.02s/it][A


Token segmentation performance with/out space {'TP': 19091, 'FP': 9791, 'FN': 11935, 'P': 0.6609999307527179, 'R': 0.6153226326306969, 'F': 0.6373439273552781} {'TP': 12097, 'FP': 6217, 'FN': 7693, 'P': 0.660532925630665, 'R': 0.6112683173319858, 'F': 0.6349464623136678}
Overall POS accuracy with/out space 0.18839038226003996 0.08744449925999014
Overall SUP:DEP accuracy with/out space 0.1912609930060245 0.19302173200829967


('Tag-wise accuracy',
 [],
 [('SPACE', (0.2989196488858879, 14810)),
  ('INTJ', (0.25, 88)),
  ('PUNCT', (0.23592233009708738, 2060)),
  ('NOUN', (0.21676203298552676, 2971)),
  ('NUM', (0.10204081632653061, 343)),
  ('PRON', (0.09099350046425256, 1077)),
  ('ADV', (0.03262642740619902, 613)),
  ('PROPN', (0.02815768302493966, 1243)),
  ('AUX', (0.020804438280166437, 721)),
  ('CCONJ', (0.01870748299319728, 588)),
  ('VERB', (0.014414414414414415, 1665)),
  ('PART', (0.008955223880597015, 335)),
  ('DET', (0.008583690987124463, 1398)),
  ('ADP', (0.0064591896652965355, 1703)),
  ('ADJ', (0.0018083182640144665, 1106)),
  ('SCONJ', (0.0, 244)),
  ('SYM', (0.0, 35)),
  ('X', (0.0, 26))])

```
Interpreting documents...

100%|██████████| 2/2 [00:29<00:00, 14.76s/it]

Token segmentation performance with/out space 
Counter({'TP': 3235, 'FP': 787, 'FN': 461, 'R': 0.875, 'F': 0.838, 'P': 0.804}) 
Counter({'TP': 1864, 'FP': 436, 'FN': 86, 'R': 0.956, 'F': 0.877, 'P': 0.81})
Overall POS accuracy with/out space 0.6734307359307359 0.4328205128205128
Overall SUP:DEP accuracy with/out space 0.24540029835902535 0.24565217391304348

('Tag-wise accuracy',
 [],
 [('SPACE', (0.9421534936998854, 1746)),
  ('ADP', (0.7459677419354839, 248)),
  ('CCONJ', (0.7142857142857143, 49)),
  ('DET', (0.6847826086956522, 184)),
  ('NOUN', (0.5588972431077694, 399)),
  ('PUNCT', (0.5018181818181818, 275)),
  ('VERB', (0.36363636363636365, 176)),
  ('AUX', (0.3157894736842105, 76)),
  ('PRON', (0.2641509433962264, 53)),
  ('NUM', (0.2222222222222222, 63)),
  ('SCONJ', (0.16, 25)),
  ('PART', (0.15, 20)),
  ('PROPN', (0.05202312138728324, 173)),
  ('ADV', (0.03508771929824561, 57)),
  ('ADJ', (0.023622047244094488, 127)),
  ('SYM', (0.0, 20)),
  ('X', (0.0, 5))])
--- start small training/test
--- ife, sparse, uniform-positional encoding
100%|██████████| 1/1 [00:18<00:00, 18.28s/it]

Token segmentation performance with/out space 
Counter({'TP': 1660, 'FP': 874, 'FN': 328, 'R': 0.835, 'F': 0.734, 'P': 0.655}) 
Counter({'TP': 1024, 'FP': 573, 'FN': 27, 'R': 0.974, 'F': 0.773, 'P': 0.641})
Overall POS accuracy with/out space 0.6403420523138833 0.3824928639391056
Overall SUP:DEP accuracy with/out space 0.19179163378058406 0.1959924859110833
Overall s_type accuracy:  0.7037037037037037

('Tag-wise accuracy',
 [('decl', (1.0, 37)), ('frag', (0.0625, 16)), ('wh', (0.0, 1))],
 [('SPACE', (0.9295624332977588, 937)),
  ('PUNCT', (0.68125, 160)),
  ('DET', (0.5876288659793815, 97)),
  ('ADP', (0.5765765765765766, 111)),
  ('CCONJ', (0.56, 25)),
  ('NOUN', (0.3884297520661157, 242)),
  ('NUM', (0.32558139534883723, 43)),
  ('AUX', (0.2682926829268293, 41)),
  ('VERB', (0.26548672566371684, 113)),
  ('PRON', (0.22727272727272727, 22)),
  ('PROPN', (0.09375, 32)),
  ('ADV', (0.02564102564102564, 39)),
  ('ADJ', (0.0, 83)),
  ('SCONJ', (0.0, 16)),
  ('PART', (0.0, 13)),
  ('SYM', (0.0, 12)),
  ('X', (0.0, 2))])
--- ife, sparse, token-positional encoding
100%|██████████| 1/1 [00:20<00:00, 20.01s/it]

Token segmentation performance with/out space 
Counter({'TP': 1658, 'FP': 876, 'FN': 330, 'R': 0.834, 'F': 0.733, 'P': 0.654}) 
Counter({'TP': 1047, 'FP': 550, 'FN': 4, 'R': 0.996, 'F': 0.791, 'P': 0.656})
Overall POS accuracy with/out space 0.6740442655935613 0.3939105613701237
Overall SUP:DEP accuracy with/out space 0.1910023677979479 0.19724483406386975
Overall s_type accuracy:  0.7037037037037037

('Tag-wise accuracy',
 [('decl', (1.0, 37)), ('frag', (0.0625, 16)), ('wh', (0.0, 1))],
 [('SPACE', (0.9882604055496265, 937)),
  ('PUNCT', (0.70625, 160)),
  ('DET', (0.6288659793814433, 97)),
  ('ADP', (0.6126126126126126, 111)),
  ('CCONJ', (0.6, 25)),
  ('NOUN', (0.384297520661157, 242)),
  ('NUM', (0.32558139534883723, 43)),
  ('AUX', (0.2682926829268293, 41)),
  ('VERB', (0.26548672566371684, 113)),
  ('PRON', (0.22727272727272727, 22)),
  ('PROPN', (0.09375, 32)),
  ('ADV', (0.02564102564102564, 39)),
  ('ADJ', (0.0, 83)),
  ('SCONJ', (0.0, 16)),
  ('PART', (0.0, 13)),
  ('SYM', (0.0, 12)),
  ('X', (0.0, 2))])
--- ife, fine tuned, uniform-positional encoding
100%|██████████| 1/1 [00:19<00:00, 19.83s/it]

Token segmentation performance with/out space 
Counter({'TP': 1611, 'FP': 1089, 'FN': 377, 'R': 0.81, 'F': 0.687, 'P': 0.597}) 
Counter({'TP': 1055, 'FP': 716, 'R': 1.004, 'F': 0.748, 'P': 0.596, 'FN': -4})
Overall POS accuracy with/out space 0.619215291750503 0.3368220742150333
Overall SUP:DEP accuracy with/out space 0.1925925925925926 0.19875776397515527
Overall s_type accuracy:  0.6851851851851852

('Tag-wise accuracy',
 [('decl', (1.0, 37)), ('frag', (0.0, 16)), ('wh', (0.0, 1))],
 [('SPACE', (0.935965848452508, 937)),
  ('DET', (0.6082474226804123, 97)),
  ('PUNCT', (0.6, 160)),
  ('ADP', (0.5945945945945946, 111)),
  ('CCONJ', (0.52, 25)),
  ('AUX', (0.36585365853658536, 41)),
  ('NUM', (0.27906976744186046, 43)),
  ('NOUN', (0.256198347107438, 242)),
  ('PART', (0.23076923076923078, 13)),
  ('VERB', (0.168141592920354, 113)),
  ('PRON', (0.09090909090909091, 22)),
  ('PROPN', (0.0625, 32)),
  ('ADJ', (0.04819277108433735, 83)),
  ('ADV', (0.02564102564102564, 39)),
  ('SCONJ', (0.0, 16)),
  ('SYM', (0.0, 12)),
  ('X', (0.0, 2))])
--- ife, fine tuned, token-positional encoding
100%|██████████| 1/1 [00:22<00:00, 22.50s/it]

Token segmentation performance with/out space 
Counter({'TP': 1666, 'FP': 870, 'FN': 322, 'R': 0.838, 'F': 0.737, 'P': 0.657}) 
Counter({'TP': 1043, 'FP': 556, 'FN': 8, 'R': 0.992, 'F': 0.787, 'P': 0.652})
Overall POS accuracy with/out space 0.6745472837022133 0.384395813510942
Overall SUP:DEP accuracy with/out space 0.16876971608832808 0.17886178861788618
Overall s_type accuracy:  0.6851851851851852

('Tag-wise accuracy',
 [('decl', (1.0, 37)), ('frag', (0.0, 16)), ('wh', (0.0, 1))],
 [('SPACE', (1.0, 937)),
  ('PUNCT', (0.74375, 160)),
  ('DET', (0.6597938144329897, 97)),
  ('CCONJ', (0.6, 25)),
  ('ADP', (0.5945945945945946, 111)),
  ('AUX', (0.34146341463414637, 41)),
  ('NOUN', (0.28512396694214875, 242)),
  ('VERB', (0.2831858407079646, 113)),
  ('NUM', (0.27906976744186046, 43)),
  ('PRON', (0.18181818181818182, 22)),
  ('PART', (0.15384615384615385, 13)),
  ('PROPN', (0.0625, 32)),
  ('ADJ', (0.04819277108433735, 83)),
  ('ADV', (0.02564102564102564, 39)),
  ('SCONJ', (0.0, 16)),
  ('SYM', (0.0, 12)),
  ('X', (0.0, 2))])
--- form, sparse, uniform-positional encoding
Token segmentation performance with/out space: 
Counter({'TP': 1567, 'FP': 1035, 'FN': 367, 'R': 0.81, 'F': 0.691, 'P': 0.602}) 
Counter({'TP': 1013, 'FP': 706, 'FN': 38, 'R': 0.964, 'F': 0.731, 'P': 0.589})
Overall POS accuracy with/out space 0.655635987590486 0.36631779257849667
Overall SUP:DEP accuracy with/out space 0.1848578016910069 0.1878999418266434
Overall s_type accuracy:  0.6851851851851852

('Tag-wise accuracy',
 [('decl', (1.0, 37)), ('frag', (0.0, 16)), ('wh', (0.0, 1))],
 [('PUNCT', (0.9674017257909875, 1043)),
  ('ADP', (0.6846846846846847, 111)),
  ('DET', (0.5463917525773195, 97)),
  ('NUM', (0.37209302325581395, 43)),
  ('NOUN', (0.28512396694214875, 242)),
  ('PART', (0.23076923076923078, 13)),
  ('VERB', (0.23008849557522124, 113)),
  ('CCONJ', (0.2, 25)),
  ('PRON', (0.13636363636363635, 22)),
  ('PROPN', (0.0625, 32)),
  ('ADV', (0.05128205128205128, 39)),
  ('AUX', (0.04878048780487805, 41)),
  ('ADJ', (0.024096385542168676, 83)),
  ('SCONJ', (0.0, 16)),
  ('SYM', (0.0, 12)),
  ('X', (0.0, 2))])
--- end small training/test
--- start full training/test
--- ife, sparse, uniform-positional encoding
100%|██████████| 18/18 [43:41<00:00, 145.63s/it]

Token segmentation performance with/out space 
Counter({'TP': 28408, 'FP': 3461, 'FN': 1724, 'R': 0.943, 'F': 0.916, 'P': 0.891}) 
Counter({'TP': 15983, 'FP': 1974, 'FN': 233, 'R': 0.986, 'F': 0.936, 'P': 0.89})
Overall POS accuracy with/out space 0.7840501792114696 0.6013196842624569
Overall SUP:DEP accuracy with/out space 0.34704571840973986 0.3442111711310353
Overall s_type accuracy:  0.7237136465324385

('Tag-wise accuracy',
 [('decl', (0.966044142614601, 589)),
  ('intj', (0.5769230769230769, 26)),
  ('q', (0.5, 16)),
  ('inf', (0.4, 5)),
  ('wh', (0.38095238095238093, 21)),
  ('frag', (0.3225806451612903, 93)),
  ('imp', (0.20408163265306123, 49)),
  ('sub', (0.12195121951219512, 41)),
  ('multiple', (0.0, 32)),
  ('other', (0.0, 14)),
  ('ger', (0.0, 8))],
 [('SPACE', (0.9969818913480886, 13916)),
  ('PUNCT', (0.8689320388349514, 2060)),
  ('DET', (0.8190271816881259, 1398)),
  ('ADP', (0.7428068115091015, 1703)),
  ('NOUN', (0.7381353079771121, 2971)),
  ('CCONJ', (0.7380952380952381, 588)),
  ('PRON', (0.6666666666666666, 1077)),
  ('SCONJ', (0.6024590163934426, 244)),
  ('AUX', (0.5242718446601942, 721)),
  ('VERB', (0.48768768768768767, 1665)),
  ('INTJ', (0.375, 88)),
  ('PART', (0.3074626865671642, 335)),
  ('NUM', (0.22448979591836735, 343)),
  ('PROPN', (0.22123893805309736, 1243)),
  ('ADJ', (0.2206148282097649, 1106)),
  ('ADV', (0.2137030995106036, 613)),
  ('SYM', (0.17142857142857143, 35)),
  ('X', (0.0, 26))])
--- ife, sparse, token-positional encoding
100%|██████████| 18/18 [07:05<00:00, 23.61s/it]

Token segmentation performance with/out space 
Counter({'TP': 29317, 'FP': 3473, 'FN': 1709, 'R': 0.945, 'F': 0.919, 'P': 0.894}) 
Counter({'TP': 16083, 'FP': 1897, 'FN': 133, 'R': 0.992, 'F': 0.94, 'P': 0.894})
Overall POS accuracy with/out space 0.795687487913363 0.6090897878638382
Overall SUP:DEP accuracy with/out space 0.3339737724916133 0.33403781979977754
Overall s_type accuracy:  0.7136465324384788

('Tag-wise accuracy',
 [('decl', (0.9864176570458404, 589)),
  ('intj', (0.46153846153846156, 26)),
  ('q', (0.375, 16)),
  ('wh', (0.2857142857142857, 21)),
  ('frag', (0.26881720430107525, 93)),
  ('inf', (0.2, 5)),
  ('imp', (0.12244897959183673, 49)),
  ('sub', (0.024390243902439025, 41)),
  ('multiple', (0.0, 32)),
  ('other', (0.0, 14)),
  ('ger', (0.0, 8))],
 [('SPACE', (1.0, 14810)),
  ('PUNCT', (0.8941747572815534, 2060)),
  ('DET', (0.8183118741058655, 1398)),
  ('ADP', (0.7680563711098062, 1703)),
  ('CCONJ', (0.7517006802721088, 588)),
  ('NOUN', (0.7485695052170986, 2971)),
  ('PRON', (0.7056638811513464, 1077)),
  ('AUX', (0.5312066574202496, 721)),
  ('VERB', (0.4924924924924925, 1665)),
  ('SCONJ', (0.3770491803278688, 244)),
  ('INTJ', (0.3068181818181818, 88)),
  ('PART', (0.29850746268656714, 335)),
  ('PROPN', (0.2333065164923572, 1243)),
  ('ADJ', (0.22423146473779385, 1106)),
  ('NUM', (0.21865889212827988, 343)),
  ('ADV', (0.19086460032626426, 613)),
  ('SYM', (0.14285714285714285, 35)),
  ('X', (0.0, 26))])
--- ife, fine tuned, uniform-positional encoding
100%|██████████| 18/18 [2:34:46<00:00, 515.94s/it]

Token segmentation performance with/out space:
Counter({'TP': 27967, 'FP': 4668, 'FN': 2165, 'R': 0.928, 'F': 0.891, 'P': 0.857}) 
Counter({'TP': 16028, 'FP': 2718, 'FN': 188, 'R': 0.988, 'F': 0.917, 'P': 0.855})
Overall POS accuracy with/out space 0.7831541218637993 0.5993463246176616
Overall SUP:DEP accuracy with/out space 0.30378428068025126 0.30390483303104665
Overall s_type accuracy:  0.7125279642058165

('Tag-wise accuracy',
 [('decl', (0.9643463497453311, 589)),
  ('intj', (0.6153846153846154, 26)),
  ('q', (0.5, 16)),
  ('frag', (0.3010752688172043, 93)),
  ('wh', (0.23809523809523808, 21)),
  ('imp', (0.16326530612244897, 49)),
  ('sub', (0.07317073170731707, 41)),
  ('multiple', (0.03125, 32)),
  ('other', (0.0, 14)),
  ('ger', (0.0, 8)),
  ('inf', (0.0, 5))],
 [('SPACE', (0.9973411899971256, 13916)),
  ('PUNCT', (0.9242718446601942, 2060)),
  ('DET', (0.9062947067238912, 1398)),
  ('CCONJ', (0.8690476190476191, 588)),
  ('ADP', (0.7862595419847328, 1703)),
  ('PRON', (0.7520891364902507, 1077)),
  ('AUX', (0.7128987517337032, 721)),
  ('PART', (0.6746268656716418, 335)),
  ('SCONJ', (0.5245901639344263, 244)),
  ('NOUN', (0.49814877145742176, 2971)),
  ('INTJ', (0.36363636363636365, 88)),
  ('VERB', (0.35315315315315315, 1665)),
  ('PROPN', (0.30973451327433627, 1243)),
  ('ADJ', (0.26763110307414106, 1106)),
  ('NUM', (0.2478134110787172, 343)),
  ('ADV', (0.2463295269168026, 613)),
  ('SYM', (0.08571428571428572, 35)),
  ('X', (0.0, 26))])
--- ife, fine tuned, token-positional encoding
100%|██████████| 18/18 [18:42<00:00, 62.39s/it]

Token segmentation performance with/out space 
Counter({'TP': 29101, 'FP': 3984, 'FN': 1925, 'R': 0.938, 'F': 0.908, 'P': 0.88}) 
Counter({'TP': 16028, 'FP': 2247, 'FN': 188, 'R': 0.988, 'F': 0.929, 'P': 0.877})
Overall POS accuracy with/out space 0.7975568877715464 0.6126665022200296
Overall SUP:DEP accuracy with/out space 0.28958742632612966 0.28891928864569083
Overall s_type accuracy:  0.6778523489932886

('Tag-wise accuracy',
 [('decl', (0.99830220713073, 589)),
  ('intj', (0.34615384615384615, 26)),
  ('frag', (0.08602150537634409, 93)),
  ('q', (0.0625, 16)),
  ('imp', (0.0, 49)),
  ('sub', (0.0, 41)),
  ('multiple', (0.0, 32)),
  ('wh', (0.0, 21)),
  ('other', (0.0, 14)),
  ('ger', (0.0, 8)),
  ('inf', (0.0, 5))],
 [('SPACE', (1.0, 14810)),
  ('PUNCT', (0.9529126213592233, 2060)),
  ('DET', (0.9213161659513591, 1398)),
  ('CCONJ', (0.8401360544217688, 588)),
  ('ADP', (0.806224310041104, 1703)),
  ('PRON', (0.7325905292479109, 1077)),
  ('PART', (0.6985074626865672, 335)),
  ('AUX', (0.6463245492371706, 721)),
  ('NOUN', (0.5412319084483339, 2971)),
  ('SCONJ', (0.45491803278688525, 244)),
  ('VERB', (0.4096096096096096, 1665)),
  ('PROPN', (0.31053901850362026, 1243)),
  ('INTJ', (0.3068181818181818, 88)),
  ('ADJ', (0.2585895117540687, 1106)),
  ('NUM', (0.24489795918367346, 343)),
  ('ADV', (0.22838499184339314, 613)),
  ('SYM', (0.11428571428571428, 35)),
  ('X', (0.0, 26))])
--- ife, fine tuned using all pre-training data, token-positional encoding
100%|██████████| 18/18 [12:42<00:00, 42.36s/it]

Token segmentation performance with/out space 
Counter({'TP': 28946, 'FP': 4430, 'FN': 2080, 'R': 0.933, 'F': 0.899, 'P': 0.867}) 
Counter({'TP': 16129, 'FP': 2437, 'FN': 87, 'R': 0.995, 'F': 0.928, 'P': 0.869})
Overall POS accuracy with/out space 0.7999742151743698 0.6172915638875185
Overall SUP:DEP accuracy with/out space 0.28577420901246403 0.28374447915544543
Overall s_type accuracy:  0.6789709172259508

('Tag-wise accuracy',
 [('decl', (0.99830220713073, 589)),
  ('intj', (0.19230769230769232, 26)),
  ('frag', (0.15053763440860216, 93)),
  ('imp', (0.0, 49)),
  ('sub', (0.0, 41)),
  ('multiple', (0.0, 32)),
  ('wh', (0.0, 21)),
  ('q', (0.0, 16)),
  ('other', (0.0, 14)),
  ('ger', (0.0, 8)),
  ('inf', (0.0, 5))],
 [('SPACE', (1.0, 14810)),
  ('PUNCT', (0.9601941747572815, 2060)),
  ('DET', (0.9277539341917024, 1398)),
  ('CCONJ', (0.8537414965986394, 588)),
  ('ADP', (0.8027011156782149, 1703)),
  ('PART', (0.7850746268656716, 335)),
  ('PRON', (0.7678737233054782, 1077)),
  ('AUX', (0.6907073509015257, 721)),
  ('SCONJ', (0.6229508196721312, 244)),
  ('NOUN', (0.47088522383036013, 2971)),
  ('VERB', (0.4114114114114114, 1665)),
  ('INTJ', (0.4090909090909091, 88)),
  ('NUM', (0.35276967930029157, 343)),
  ('PROPN', (0.3153660498793242, 1243)),
  ('ADV', (0.2969004893964111, 613)),
  ('ADJ', (0.2730560578661845, 1106)),
  ('SYM', (0.2571428571428571, 35)),
  ('X', (0.0, 26))])
--- ife, fine tuned using all pre-training data, token-positional encoding, no spaces
Token segmentation performance with/out space 
Counter({'TP': 12109, 'FP': 4632, 'FN': 4107, 'R': 0.747, 'F': 0.735, 'P': 0.723}) 
Counter({'TP': 12109, 'FP': 4632, 'FN': 4107, 'R': 0.747, 'F': 0.735, 'P': 0.723})
Overall POS accuracy with/out space 0.5549457326097681 0.5549457326097681
Overall SUP:DEP accuracy with/out space 0.1356549787945762 0.1356549787945762

('Tag-wise accuracy',
 [],
 [('PUNCT', (0.9097087378640777, 2060)),
  ('DET', (0.8497854077253219, 1398)),
  ('CCONJ', (0.8180272108843537, 588)),
  ('ADP', (0.7146212566059894, 1703)),
  ('INTJ', (0.5113636363636364, 88)),
  ('AUX', (0.5104022191400832, 721)),
  ('SCONJ', (0.5081967213114754, 244)),
  ('PRON', (0.5051067780872794, 1077)),
  ('NOUN', (0.4867048131942107, 2971)),
  ('ADJ', (0.4240506329113924, 1106)),
  ('VERB', (0.34654654654654654, 1665)),
  ('PART', (0.3283582089552239, 335)),
  ('ADV', (0.27569331158238175, 613)),
  ('NUM', (0.2478134110787172, 343)),
  ('PROPN', (0.24215607401448108, 1243)),
  ('SYM', (0.02857142857142857, 35)),
  ('X', (0.0, 26))])
--- form, sparse, uniform-positional encoding
100%|██████████| 18/18 [9:55:08<00:00, 1983.80s/it]

Token segmentation performance with/out space 
Counter({'TP': 29528, 'FP': 3009, 'FN': 1498, 'R': 0.952, 'F': 0.929, 'P': 0.908}) 
Counter({'TP': 16051, 'FP': 1676, 'FN': 165, 'R': 0.99, 'F': 0.946, 'P': 0.905})
Overall POS accuracy with/out space 0.8271127441500676 0.6697705969412926
Overall SUP:DEP accuracy with/out space 0.3659218735593325 0.36684154115191514
Overall s_type accuracy:  0.7304250559284117

('Tag-wise accuracy',
 [('decl', (0.9881154499151104, 589)),
  ('intj', (0.5384615384615384, 26)),
  ('q', (0.5, 16)),
  ('imp', (0.3673469387755102, 49)),
  ('frag', (0.27956989247311825, 93)),
  ('wh', (0.19047619047619047, 21)),
  ('sub', (0.024390243902439025, 41)),
  ('multiple', (0.0, 32)),
  ('other', (0.0, 14)),
  ('ger', (0.0, 8)),
  ('inf', (0.0, 5))],
 [('SPACE', (0.9993923024983119, 14810)),
  ('PUNCT', (0.9271844660194175, 2060)),
  ('NOUN', (0.8391114102995625, 2971)),
  ('DET', (0.7982832618025751, 1398)),
  ('ADP', (0.7639459776864357, 1703)),
  ('PRON', (0.754874651810585, 1077)),
  ('VERB', (0.6216216216216216, 1665)),
  ('CCONJ', (0.5952380952380952, 588)),
  ('AUX', (0.5409153952843273, 721)),
  ('SCONJ', (0.4344262295081967, 244)),
  ('ADJ', (0.4204339963833635, 1106)),
  ('ADV', (0.3866231647634584, 613)),
  ('NUM', (0.36443148688046645, 343)),
  ('PROPN', (0.3153660498793242, 1243)),
  ('INTJ', (0.3068181818181818, 88)),
  ('PART', (0.29253731343283584, 335)),
  ('SYM', (0.08571428571428572, 35)),
  ('X', (0.0, 26))])
```

In [9]:
from tqdm import tqdm

accuracy = defaultdict(list)
accuracy_nsp = defaultdict(list)
accuracy_all, accuracy_all_nsp, = [], []
sup_accuracy, sup_accuracy_nsp, = 0, 0
accuracy_sty = defaultdict(list)
accuracy_all_sty = []

model.interpret(tdocs, seed = 691, covering = [[[row[1] for row in s] for s in d['conllu']] for d in test_docs])

for d_i, doc in enumerate(model._documents):
    for s_i, s in enumerate(doc._sentences):
        if s._sty is not None:
            result = s._sty == test_docs[d_i]['s_type'][s_i]
            accuracy_sty[test_docs[d_i]['s_type'][s_i]].append(result)
            accuracy_all_sty.append(result)

pred_toks = [t._form for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_arcs = set([(ix, str(t._sup), t._dep, d_i, s_i) for d_i, doc in enumerate(model._documents) 
                 for s_i, s in enumerate(doc._sentences) for ix, t in enumerate(s._tokens)])
pred_spans = list(np.cumsum([len(t) for t in pred_toks]))
pred_stream = [t._pos for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_spans = {(sh-len(gt), sh): (gl, gt)
              for sh, gt, gl in zip(pred_spans, pred_toks, pred_stream)}

gold_toks = [row[1] for d in test_docs for s in d['conllu'] for row in s]
gold_arcs = set([(ix, (str(int(row[6]) - int(row[0])) if int(row[6]) else row[6]), row[7], d_i, s_i) 
                 for d_i, d in enumerate(test_docs) for s_i, s in enumerate(d['conllu']) for ix, row in enumerate(s)])
gold_spans = list(np.cumsum([len(t) for t in gold_toks]))
gold_stream = [row[3] for d in test_docs for s in d['conllu'] for row in s]
gold_spans = {(sh-len(gt), sh): (gl, gt)
              for sh, gt, gl in zip(gold_spans, gold_toks, gold_stream)}

for gold_span in gold_spans:
    if gold_span in pred_spans:
        result = gold_spans[gold_span] == pred_spans[gold_span]
    else:
        result = False
    accuracy[gold_spans[gold_span][0]].append(result)
    accuracy_all.append(result)
    if gold_spans[gold_span][1] != ' ':
        accuracy_nsp[gold_spans[gold_span][0]].append(result)
        accuracy_all_nsp.append(result)
        
for ptok, parc in zip(pred_toks, pred_arcs):
    if parc in gold_arcs:
        sup_accuracy += 1
        if ptok != ' ':
            sup_accuracy_nsp += 1
sup_accuracy /= len(pred_toks)
sup_accuracy_nsp /= len([x for x in pred_toks if x != ' '])

print("Overall POS accuracy with/out space", sum(accuracy_all)/len(accuracy_all), sum(accuracy_all_nsp)/len(accuracy_all_nsp))
print("Overall SUP:DEP accuracy with/out space", sup_accuracy, sup_accuracy_nsp)
if len(accuracy_all_sty):
    print("Overall s_type accuracy: ", sum(accuracy_all_sty)/len(accuracy_all_sty))
"Tag-wise accuracy", list(Counter({tag: (sum(accuracy_sty[tag])/len(accuracy_sty[tag]), len(accuracy_sty[tag])) 
                                   for tag in accuracy_sty}).most_common()), list(Counter({tag: (sum(accuracy[tag])/len(accuracy[tag]), len(accuracy[tag])) 
                                                                                           for tag in accuracy}).most_common())


0it [00:00, ?it/s][A
1it [00:01,  1.41s/it][A
2it [00:02,  1.19s/it][A
3it [00:02,  1.01it/s][A
4it [00:03,  1.10it/s][A
5it [00:03,  1.24it/s][A
6it [00:04,  1.27it/s][A
7it [00:05,  1.30it/s][A
8it [00:06,  1.18it/s][A
9it [00:06,  1.31it/s][A
10it [00:07,  1.24it/s][A
11it [00:08,  1.43it/s][A
12it [00:08,  1.47it/s][A
13it [00:09,  1.44it/s][A
14it [00:10,  1.53it/s][A
15it [00:10,  1.46it/s][A
16it [00:11,  1.58it/s][A
17it [00:11,  1.70it/s][A
18it [00:12,  1.41it/s][A

  0%|          | 0/18 [00:00<?, ?it/s][A

Interpreting documents...



  6%|▌         | 1/18 [00:16<04:45, 16.78s/it][A
 11%|█         | 2/18 [00:32<04:25, 16.57s/it][A
 17%|█▋        | 3/18 [00:44<03:44, 14.99s/it][A
 22%|██▏       | 4/18 [01:01<03:39, 15.67s/it][A
 28%|██▊       | 5/18 [01:10<02:58, 13.76s/it][A
 33%|███▎      | 6/18 [01:24<02:45, 13.77s/it][A
 39%|███▉      | 7/18 [01:41<02:42, 14.76s/it][A
 44%|████▍     | 8/18 [01:52<02:16, 13.61s/it][A
 50%|█████     | 9/18 [02:06<02:03, 13.69s/it][A
 56%|█████▌    | 10/18 [02:28<02:09, 16.14s/it][A
 61%|██████    | 11/18 [02:38<01:39, 14.25s/it][A
 67%|██████▋   | 12/18 [02:53<01:27, 14.54s/it][A
 72%|███████▏  | 13/18 [03:08<01:13, 14.65s/it][A
 78%|███████▊  | 14/18 [03:21<00:57, 14.26s/it][A
 83%|████████▎ | 15/18 [03:43<00:49, 16.42s/it][A
 89%|████████▉ | 16/18 [03:52<00:28, 14.32s/it][A
 94%|█████████▍| 17/18 [04:01<00:12, 12.79s/it][A
100%|██████████| 18/18 [04:14<00:00, 14.14s/it][A

Overall POS accuracy with/out space 0.2796042029265777 0.2038727183029107
Overall SUP:DEP accuracy with/out space 0.23217626943673786 0.23085001235483074





('Tag-wise accuracy',
 [],
 [('NOUN', (0.6442275328172332, 2971)),
  ('SPACE', (0.362525320729237, 14810)),
  ('INTJ', (0.3068181818181818, 88)),
  ('PRON', (0.30269266480965645, 1077)),
  ('PUNCT', (0.2383495145631068, 2060)),
  ('NUM', (0.1661807580174927, 343)),
  ('VERB', (0.15555555555555556, 1665)),
  ('ADV', (0.11745513866231648, 613)),
  ('PART', (0.07462686567164178, 335)),
  ('PROPN', (0.05148833467417538, 1243)),
  ('AUX', (0.02912621359223301, 721)),
  ('CCONJ', (0.02040816326530612, 588)),
  ('ADP', (0.012918379330593071, 1703)),
  ('DET', (0.009298998569384835, 1398)),
  ('ADJ', (0.0027124773960217, 1106)),
  ('SCONJ', (0.0, 244)),
  ('SYM', (0.0, 35)),
  ('X', (0.0, 26))])

```
--- start small training/test
--- ife, sparse, uniform-positional encoding
100%|██████████| 1/1 [00:13<00:00, 13.63s/it]

Overall POS accuracy with/out space 0.7208249496981891 0.5242626070409134
Overall SUP:DEP accuracy with/out space 0.499496475327291 0.4918970448045758
Overall s_type accuracy:  0.7037037037037037

('Tag-wise accuracy',
 [('decl', (1.0, 37)), ('frag', (0.0625, 16)), ('wh', (0.0, 1))],
 [('SPACE', (0.9413020277481323, 937)),
  ('NOUN', (0.871900826446281, 242)),
  ('PUNCT', (0.70625, 160)),
  ('DET', (0.5876288659793815, 97)),
  ('ADP', (0.5855855855855856, 111)),
  ('CCONJ', (0.56, 25)),
  ('VERB', (0.46017699115044247, 113)),
  ('PRON', (0.36363636363636365, 22)),
  ('NUM', (0.32558139534883723, 43)),
  ('AUX', (0.2682926829268293, 41)),
  ('PROPN', (0.15625, 32)),
  ('ADV', (0.02564102564102564, 39)),
  ('ADJ', (0.0, 83)),
  ('SCONJ', (0.0, 16)),
  ('PART', (0.0, 13)),
  ('SYM', (0.0, 12)),
  ('X', (0.0, 2))])
--- ife, sparse, token-positional encoding
100%|██████████| 1/1 [00:12<00:00, 12.83s/it]

Overall POS accuracy with/out space 0.7540241448692153 0.5394862036156042
Overall SUP:DEP accuracy with/out space 0.5161127895266868 0.5147759771210677
Overall s_type accuracy:  0.7037037037037037

('Tag-wise accuracy',
 [('decl', (1.0, 37)), ('frag', (0.0625, 16)), ('wh', (0.0, 1))],
 [('SPACE', (0.9946638207043756, 937)),
  ('NOUN', (0.8801652892561983, 242)),
  ('PUNCT', (0.7375, 160)),
  ('ADP', (0.6306306306306306, 111)),
  ('DET', (0.6288659793814433, 97)),
  ('CCONJ', (0.6, 25)),
  ('VERB', (0.4690265486725664, 113)),
  ('NUM', (0.32558139534883723, 43)),
  ('PRON', (0.2727272727272727, 22)),
  ('AUX', (0.2682926829268293, 41)),
  ('PROPN', (0.15625, 32)),
  ('ADV', (0.02564102564102564, 39)),
  ('ADJ', (0.0, 83)),
  ('SCONJ', (0.0, 16)),
  ('PART', (0.0, 13)),
  ('SYM', (0.0, 12)),
  ('X', (0.0, 2))])
--- ife, fine tuned, uniform-positional encoding 
100%|██████████| 1/1 [00:12<00:00, 12.06s/it]

Overall POS accuracy with/out space 0.6896378269617707 0.4652711703139867
Overall SUP:DEP accuracy with/out space 0.5115810674723061 0.5081029551954243
Overall s_type accuracy:  0.6851851851851852

('Tag-wise accuracy',
 [('decl', (1.0, 37)), ('frag', (0.0, 16)), ('wh', (0.0, 1))],
 [('SPACE', (0.9413020277481323, 937)),
  ('NOUN', (0.6570247933884298, 242)),
  ('ADP', (0.6306306306306306, 111)),
  ('DET', (0.6185567010309279, 97)),
  ('PUNCT', (0.6125, 160)),
  ('CCONJ', (0.52, 25)),
  ('AUX', (0.3902439024390244, 41)),
  ('VERB', (0.3893805309734513, 113)),
  ('NUM', (0.27906976744186046, 43)),
  ('PART', (0.23076923076923078, 13)),
  ('PRON', (0.18181818181818182, 22)),
  ('ADJ', (0.08433734939759036, 83)),
  ('PROPN', (0.0625, 32)),
  ('ADV', (0.02564102564102564, 39)),
  ('SCONJ', (0.0, 16)),
  ('SYM', (0.0, 12)),
  ('X', (0.0, 2))])
--- ife, fine tuned, token-positional encoding
100%|██████████| 1/1 [00:15<00:00, 15.76s/it]

Overall POS accuracy with/out space 0.7364185110663984 0.5014272121788773
Overall SUP:DEP accuracy with/out space 0.5518630412890232 0.5643469971401335
Overall s_type accuracy:  0.6851851851851852

('Tag-wise accuracy',
 [('decl', (1.0, 37)), ('frag', (0.0, 16)), ('wh', (0.0, 1))],
 [('SPACE', (1.0, 937)),
  ('PUNCT', (0.75625, 160)),
  ('NOUN', (0.6694214876033058, 242)),
  ('DET', (0.6597938144329897, 97)),
  ('ADP', (0.6396396396396397, 111)),
  ('CCONJ', (0.6, 25)),
  ('VERB', (0.45132743362831856, 113)),
  ('AUX', (0.34146341463414637, 41)),
  ('NUM', (0.27906976744186046, 43)),
  ('PRON', (0.18181818181818182, 22)),
  ('PROPN', (0.15625, 32)),
  ('PART', (0.15384615384615385, 13)),
  ('ADJ', (0.060240963855421686, 83)),
  ('ADV', (0.02564102564102564, 39)),
  ('SCONJ', (0.0, 16)),
  ('SYM', (0.0, 12)),
  ('X', (0.0, 2))])
--- form, sparse, uniform-positional encoding
Overall POS accuracy with/out space 0.7471561530506722 0.5347288296860133
Overall SUP:DEP accuracy with/out space 0.5377846790890269 0.5243088655862727
Overall s_type accuracy:  0.6851851851851852

('Tag-wise accuracy',
 [('decl', (1.0, 37)), ('frag', (0.0, 16)), ('wh', (0.0, 1))],
 [('PUNCT', (0.9731543624161074, 1043)),
  ('NOUN', (0.8842975206611571, 242)),
  ('ADP', (0.7027027027027027, 111)),
  ('DET', (0.5567010309278351, 97)),
  ('NUM', (0.3953488372093023, 43)),
  ('VERB', (0.3893805309734513, 113)),
  ('PART', (0.23076923076923078, 13)),
  ('PRON', (0.22727272727272727, 22)),
  ('CCONJ', (0.2, 25)),
  ('PROPN', (0.09375, 32)),
  ('ADV', (0.07692307692307693, 39)),
  ('AUX', (0.04878048780487805, 41)),
  ('ADJ', (0.024096385542168676, 83)),
  ('SCONJ', (0.0, 16)),
  ('SYM', (0.0, 12)),
  ('X', (0.0, 2))])
--- end small training/test
--- start full training/test
--- ife, sparse, uniform-positional encoding
100%|██████████| 18/18 [39:12<00:00, 130.72s/it]

Overall POS accuracy with/out space 0.8118279569892473 0.6528737049827331
Overall SUP:DEP accuracy with/out space 0.6086566569226681 0.6032863849765259
Overall s_type accuracy:  0.7248322147651006

('Tag-wise accuracy',
 [('decl', (0.9643463497453311, 589)),
  ('intj', (0.5769230769230769, 26)),
  ('q', (0.5625, 16)),
  ('inf', (0.4, 5)),
  ('wh', (0.38095238095238093, 21)),
  ('frag', (0.3225806451612903, 93)),
  ('imp', (0.20408163265306123, 49)),
  ('sub', (0.12195121951219512, 41)),
  ('other', (0.07142857142857142, 14)),
  ('multiple', (0.0, 32)),
  ('ger', (0.0, 8))],
 [('SPACE', (0.997053751077896, 13916)),
  ('PUNCT', (0.8820388349514563, 2060)),
  ('NOUN', (0.8801750252440256, 2971)),
  ('DET', (0.8204577968526466, 1398)),
  ('ADP', (0.7563123899001761, 1703)),
  ('CCONJ', (0.7380952380952381, 588)),
  ('PRON', (0.7214484679665738, 1077)),
  ('SCONJ', (0.6065573770491803, 244)),
  ('AUX', (0.5866851595006934, 721)),
  ('VERB', (0.5321321321321322, 1665)),
  ('INTJ', (0.38636363636363635, 88)),
  ('PART', (0.382089552238806, 335)),
  ('NUM', (0.3119533527696793, 343)),
  ('PROPN', (0.3049074818986323, 1243)),
  ('ADV', (0.24796084828711257, 613)),
  ('ADJ', (0.22151898734177214, 1106)),
  ('SYM', (0.17142857142857143, 35)),
  ('X', (0.038461538461538464, 26))])
--- ife, sparse, token-positional encoding
100%|██████████| 18/18 [06:18<00:00, 21.02s/it]

Overall POS accuracy with/out space 0.8213756204473667 0.6582387765170202
Overall SUP:DEP accuracy with/out space 0.5965546164268662 0.5992092908327156
Overall s_type accuracy:  0.7125279642058165

('Tag-wise accuracy',
 [('decl', (0.9847198641765704, 589)),
  ('intj', (0.46153846153846156, 26)),
  ('q', (0.375, 16)),
  ('wh', (0.2857142857142857, 21)),
  ('frag', (0.26881720430107525, 93)),
  ('inf', (0.2, 5)),
  ('imp', (0.12244897959183673, 49)),
  ('sub', (0.024390243902439025, 41)),
  ('multiple', (0.0, 32)),
  ('other', (0.0, 14)),
  ('ger', (0.0, 8))],
 [('SPACE', (1.0, 14810)),
  ('PUNCT', (0.9053398058252428, 2060)),
  ('NOUN', (0.8872433524065971, 2971)),
  ('DET', (0.8190271816881259, 1398)),
  ('ADP', (0.7821491485613623, 1703)),
  ('PRON', (0.7595171773444754, 1077)),
  ('CCONJ', (0.7517006802721088, 588)),
  ('AUX', (0.5936199722607489, 721)),
  ('VERB', (0.5315315315315315, 1665)),
  ('SCONJ', (0.38114754098360654, 244)),
  ('PART', (0.3611940298507463, 335)),
  ('PROPN', (0.31938857602574416, 1243)),
  ('INTJ', (0.3181818181818182, 88)),
  ('NUM', (0.26239067055393583, 343)),
  ('ADJ', (0.22694394213381555, 1106)),
  ('ADV', (0.22512234910277323, 613)),
  ('SYM', (0.14285714285714285, 35)),
  ('X', (0.0, 26))])
--- ife, fine tuned, uniform-positional encoding
100%|██████████| 18/18 [2:29:47<00:00, 499.28s/it]

Overall POS accuracy with/out space 0.816938802601885 0.6610754810064134
Overall SUP:DEP accuracy with/out space 0.5741097528567632 0.573943661971831
Overall s_type accuracy:  0.70917225950783

('Tag-wise accuracy',
 [('decl', (0.9609507640067911, 589)),
  ('intj', (0.6153846153846154, 26)),
  ('q', (0.5, 16)),
  ('frag', (0.3010752688172043, 93)),
  ('wh', (0.23809523809523808, 21)),
  ('imp', (0.14285714285714285, 49)),
  ('sub', (0.07317073170731707, 41)),
  ('multiple', (0.03125, 32)),
  ('other', (0.0, 14)),
  ('ger', (0.0, 8)),
  ('inf', (0.0, 5))],
 [('SPACE', (0.9985628054038517, 13916)),
  ('PUNCT', (0.9266990291262136, 2060)),
  ('DET', (0.9084406294706724, 1398)),
  ('CCONJ', (0.8707482993197279, 588)),
  ('PRON', (0.8142989786443825, 1077)),
  ('ADP', (0.8038755137991779, 1703)),
  ('AUX', (0.7572815533980582, 721)),
  ('PART', (0.7522388059701492, 335)),
  ('NOUN', (0.6893301918545944, 2971)),
  ('SCONJ', (0.5245901639344263, 244)),
  ('VERB', (0.42162162162162165, 1665)),
  ('PROPN', (0.4006436041834272, 1243)),
  ('INTJ', (0.375, 88)),
  ('NUM', (0.2915451895043732, 343)),
  ('ADV', (0.27569331158238175, 613)),
  ('ADJ', (0.2739602169981917, 1106)),
  ('SYM', (0.08571428571428572, 35)),
  ('X', (0.038461538461538464, 26))])
--- ife, fine tuned, token-positional encoding
100%|██████████| 18/18 [10:14<00:00, 34.12s/it]

Overall POS accuracy with/out space 0.8244697995229807 0.6641588554514061
Overall SUP:DEP accuracy with/out space 0.5768114071875605 0.5762910798122066
Overall s_type accuracy:  0.6778523489932886

('Tag-wise accuracy',
 [('decl', (0.99830220713073, 589)),
  ('intj', (0.34615384615384615, 26)),
  ('frag', (0.08602150537634409, 93)),
  ('q', (0.0625, 16)),
  ('imp', (0.0, 49)),
  ('sub', (0.0, 41)),
  ('multiple', (0.0, 32)),
  ('wh', (0.0, 21)),
  ('other', (0.0, 14)),
  ('ger', (0.0, 8)),
  ('inf', (0.0, 5))],
 [('SPACE', (1.0, 14810)),
  ('PUNCT', (0.954368932038835, 2060)),
  ('DET', (0.9227467811158798, 1398)),
  ('CCONJ', (0.8401360544217688, 588)),
  ('ADP', (0.8244274809160306, 1703)),
  ('PRON', (0.7883008356545961, 1077)),
  ('PART', (0.7820895522388059, 335)),
  ('NOUN', (0.6886570178391114, 2971)),
  ('AUX', (0.6851595006934813, 721)),
  ('VERB', (0.48228228228228226, 1665)),
  ('SCONJ', (0.45491803278688525, 244)),
  ('PROPN', (0.38777152051488334, 1243)),
  ('INTJ', (0.32954545454545453, 88)),
  ('NUM', (0.29737609329446063, 343)),
  ('ADJ', (0.2603978300180832, 1106)),
  ('ADV', (0.23817292006525284, 613)),
  ('SYM', (0.11428571428571428, 35)),
  ('X', (0.0, 26))])
--- ife, fine tuned using all pre-training data, token-positional encoding
100%|██████████| 18/18 [12:25<00:00, 41.44s/it]

Overall POS accuracy with/out space 0.8316895506994134 0.6779723729649728
Overall SUP:DEP accuracy with/out space 0.5690367120459384 0.5663454410674573
Overall s_type accuracy:  0.6789709172259508

('Tag-wise accuracy',
 [('decl', (0.99830220713073, 589)),
  ('intj', (0.19230769230769232, 26)),
  ('frag', (0.15053763440860216, 93)),
  ('imp', (0.0, 49)),
  ('sub', (0.0, 41)),
  ('multiple', (0.0, 32)),
  ('wh', (0.0, 21)),
  ('q', (0.0, 16)),
  ('other', (0.0, 14)),
  ('ger', (0.0, 8)),
  ('inf', (0.0, 5))],
 [('SPACE', (1.0, 14810)),
  ('PUNCT', (0.9616504854368932, 2060)),
  ('DET', (0.9284692417739628, 1398)),
  ('PART', (0.8776119402985074, 335)),
  ('CCONJ', (0.8537414965986394, 588)),
  ('ADP', (0.815619495008808, 1703)),
  ('PRON', (0.8133704735376045, 1077)),
  ('AUX', (0.723994452149792, 721)),
  ('NOUN', (0.6560080780881858, 2971)),
  ('SCONJ', (0.6229508196721312, 244)),
  ('VERB', (0.4864864864864865, 1665)),
  ('NUM', (0.4518950437317784, 343)),
  ('INTJ', (0.42045454545454547, 88)),
  ('PROPN', (0.415124698310539, 1243)),
  ('ADV', (0.32300163132137033, 613)),
  ('ADJ', (0.2766726943942134, 1106)),
  ('SYM', (0.2571428571428571, 35)),
  ('X', (0.0, 26))])
--- ife, fine tuned using all pre-training data, token-positional encoding, no spaces
100%|██████████| 18/18 [02:36<00:00,  8.67s/it]

Overall POS accuracy with/out space 0.712074494326591 0.712074494326591
Overall SUP:DEP accuracy with/out space 0.2855201383741043 0.2855201383741043

('Tag-wise accuracy',
 [],
 [('PUNCT', (0.9169902912621359, 2060)),
  ('DET', (0.9120171673819742, 1398)),
  ('ADP', (0.8461538461538461, 1703)),
  ('CCONJ', (0.8384353741496599, 588)),
  ('AUX', (0.8377253814147018, 721)),
  ('PRON', (0.8217270194986073, 1077)),
  ('PART', (0.7283582089552239, 335)),
  ('NOUN', (0.6974082800403905, 2971)),
  ('SCONJ', (0.6311475409836066, 244)),
  ('VERB', (0.5897897897897898, 1665)),
  ('INTJ', (0.5681818181818182, 88)),
  ('ADJ', (0.5289330922242315, 1106)),
  ('ADV', (0.41435562805872755, 613)),
  ('PROPN', (0.40788415124698313, 1243)),
  ('NUM', (0.3119533527696793, 343)),
  ('X', (0.15384615384615385, 26)),
  ('SYM', (0.02857142857142857, 35))])
--- form, sparse, uniform-positional encoding
100%|██████████| 18/18 [10:36:22<00:00, 2121.24s/it]

Overall POS accuracy with/out space 0.8492232321278927 0.7121361618154909
Overall SUP:DEP accuracy with/out space 0.6333957029485773 0.6320731405979738
Overall s_type accuracy:  0.727069351230425

('Tag-wise accuracy',
 [('decl', (0.9830220713073005, 589)),
  ('intj', (0.5384615384615384, 26)),
  ('q', (0.5, 16)),
  ('imp', (0.3673469387755102, 49)),
  ('frag', (0.27956989247311825, 93)),
  ('wh', (0.19047619047619047, 21)),
  ('sub', (0.024390243902439025, 41)),
  ('multiple', (0.0, 32)),
  ('other', (0.0, 14)),
  ('ger', (0.0, 8)),
  ('inf', (0.0, 5))],
 [('SPACE', (0.9993247805536799, 14810)),
  ('NOUN', (0.9447997307303938, 2971)),
  ('PUNCT', (0.9373786407766991, 2060)),
  ('PRON', (0.8050139275766016, 1077)),
  ('DET', (0.7982832618025751, 1398)),
  ('ADP', (0.7780387551379918, 1703)),
  ('VERB', (0.6582582582582582, 1665)),
  ('CCONJ', (0.5952380952380952, 588)),
  ('AUX', (0.5880721220527045, 721)),
  ('SCONJ', (0.4385245901639344, 244)),
  ('NUM', (0.4314868804664723, 343)),
  ('ADJ', (0.42495479204339964, 1106)),
  ('ADV', (0.42251223491027734, 613)),
  ('PROPN', (0.4022526146419952, 1243)),
  ('PART', (0.3492537313432836, 335)),
  ('INTJ', (0.3181818181818182, 88)),
  ('SYM', (0.08571428571428572, 35)),
  ('X', (0.0, 26))])
```

In [10]:
model.fine_tune(docs, covering = covering, all_layers = all_layers, streams = data_streams)


  0%|          | 0/132 [00:00<?, ?it/s][A

Fine-tuning dense output heads...



  1%|          | 1/132 [00:04<10:07,  4.64s/it][A
  2%|▏         | 2/132 [00:08<09:47,  4.52s/it][A
  2%|▏         | 3/132 [00:12<09:19,  4.34s/it][A
  3%|▎         | 4/132 [00:17<09:37,  4.51s/it][A
  4%|▍         | 5/132 [00:24<10:44,  5.08s/it][A
  5%|▍         | 6/132 [00:29<10:53,  5.19s/it][A
  5%|▌         | 7/132 [00:33<10:05,  4.85s/it][A
  6%|▌         | 8/132 [00:38<09:58,  4.82s/it][A
  7%|▋         | 9/132 [00:42<09:14,  4.51s/it][A
  8%|▊         | 10/132 [00:45<08:36,  4.23s/it][A
  8%|▊         | 11/132 [00:51<09:20,  4.63s/it][A
  9%|▉         | 12/132 [00:57<10:07,  5.06s/it][A
 10%|▉         | 13/132 [01:03<10:39,  5.38s/it][A
 11%|█         | 14/132 [01:08<10:24,  5.30s/it][A
 11%|█▏        | 15/132 [01:13<10:21,  5.31s/it][A
 12%|█▏        | 16/132 [01:16<08:56,  4.62s/it][A
 13%|█▎        | 17/132 [01:22<09:26,  4.92s/it][A
 14%|█▎        | 18/132 [01:26<08:37,  4.54s/it][A
 14%|█▍        | 19/132 [01:30<08:22,  4.45s/it][A
 15%|█▌        | 20/

In [11]:
from tqdm import tqdm

confusion = Counter()
confusion_nsp = Counter()
accuracy = defaultdict(list)
accuracy_nsp = defaultdict(list)
accuracy_all, accuracy_all_nsp, = [], []
sup_accuracy, sup_accuracy_nsp, = 0, 0
accuracy_sty = defaultdict(list)
accuracy_all_sty = []

model.interpret(tdocs, seed = 691)

pred_toks = [t._form for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_spans = list(np.cumsum([len(t) for t in pred_toks]))
pred_spans = set([(sh-len(gt), sh) for sh, gt in zip(pred_spans, pred_toks)])

gold_toks = [row[1] for d in test_docs for s in d['conllu'] for row in s]
gold_spans = list(np.cumsum([len(t) for t in gold_toks]))
# gold_spans = set([(sh-len(gt), sh) for sh, gt in zip(gold_spans, gold_toks)])
gold_spans_nsp = {(sh-len(gt), sh): 1 for sh, gt in zip(gold_spans, gold_toks)}
gold_spans = {(sh-len(gt), sh): 1 for sh, gt in zip(gold_spans, gold_toks)}

confusion = {"TP": 0, "FP": 0, "FN": 0, "P": 0, "R": 0, "F": 0}
confusion_nsp = {"TP": 0, "FP": 0, "FN": 0, "P": 0, "R": 0, "F": 0}
for pred_span, pred_tok in zip(pred_spans, pred_toks):
    if pred_span not in gold_spans:
        confusion['FP'] += 1
        if pred_tok != ' ':
            confusion_nsp['FP'] += 1
    else:
        confusion['TP'] += 1
        if pred_tok != ' ':
            confusion_nsp['TP'] += 1
for gold_span, gold_tok in zip(gold_spans, gold_toks):
    if gold_span not in pred_spans:
        confusion['FN'] += 1
        if gold_tok != ' ':
            confusion_nsp['FN'] += 1
confusion['P'] = confusion['TP']/(confusion['TP'] + confusion['FP']) if (confusion['TP'] + confusion['FP']) else 0
confusion['R'] = confusion['TP']/(confusion['TP'] + confusion['FN']) if (confusion['TP'] + confusion['FN']) else 0
confusion['F'] = 2*confusion['P']*confusion['R']/(confusion['P']+confusion['R']) if (confusion['P']+confusion['R']) else 0
confusion_nsp['P'] = confusion_nsp['TP']/(confusion_nsp['TP'] + confusion_nsp['FP']) if (confusion_nsp['TP'] + confusion_nsp['FP']) else 0
confusion_nsp['R'] = confusion_nsp['TP']/(confusion_nsp['TP'] + confusion_nsp['FN']) if (confusion_nsp['TP'] + confusion_nsp['FN']) else 0
confusion_nsp['F'] = 2*confusion_nsp['P']*confusion_nsp['R']/(confusion_nsp['P']+confusion_nsp['R']) if (confusion_nsp['P']+confusion_nsp['R']) else 0

# for pred_span, pred_tok in zip(pred_spans, pred_toks):
#     if pred_span in gold_spans:
#         confusion['TP'] += 1
#         if pred_tok != ' ':
#             confusion_nsp['TP'] += 1
#             gold_spans_nsp[pred_span] = 0
#         else:
#             del(gold_spans_nsp[pred_span])
#         gold_spans[pred_span] = 0
#     else:
#         confusion['FP'] += 1
#         if pred_tok != ' ':
#             confusion_nsp['FP'] += 1
# confusion['FN'] = len(gold_spans) - confusion['TP']
# confusion_nsp['FN'] = len([t for t in gold_toks if t != ' ']) - confusion_nsp['TP']
# confusion['FN'] = sum(gold_spans.values())
# confusion_nsp['FN'] = sum(gold_spans_nsp.values())

# confusion['P'] = round(confusion['TP']/(confusion['TP'] + confusion['FP']), 4)
# confusion['R'] = round(confusion['TP']/(confusion['TP'] + confusion['FN']), 4)
# confusion['F'] = round(2*confusion['P']*confusion['R']/(confusion['P']+confusion["R"]), 4)
# confusion_nsp['P'] = round(confusion_nsp['TP']/(confusion_nsp['TP'] + confusion_nsp['FP']), 4)
# confusion_nsp['R'] = round(confusion_nsp['TP']/(confusion_nsp['TP'] + confusion_nsp['FN']), 4)
# confusion_nsp['F'] = round(2*confusion_nsp['P']*confusion_nsp['R']/(confusion_nsp['P']+confusion_nsp["R"]), 4)

for d_i, doc in enumerate(model._documents):
    for s_i, s in enumerate(doc._sentences):
        if s._sty is not None:
            result = s._sty == test_docs[d_i]['s_type'][s_i]
            accuracy_sty[test_docs[d_i]['s_type'][s_i]].append(result)
            accuracy_all_sty.append(result)

pred_toks = [t._form for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_arcs = set([(ix, str(t._sup), t._dep, d_i, s_i) for d_i, doc in enumerate(model._documents) 
                 for s_i, s in enumerate(doc._sentences) for ix, t in enumerate(s._tokens)])
pred_spans = list(np.cumsum([len(t) for t in pred_toks]))
pred_stream = [t._pos for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_spans = {(sh-len(gt), sh): (gl, gt)
              for sh, gt, gl in zip(pred_spans, pred_toks, pred_stream)}

gold_toks = [row[1] for d in test_docs for s in d['conllu'] for row in s]
gold_arcs = set([(ix, (str(int(row[6]) - int(row[0])) if int(row[6]) else row[6]), row[7], d_i, s_i) 
                 for d_i, d in enumerate(test_docs) for s_i, s in enumerate(d['conllu']) for ix, row in enumerate(s)])
gold_spans = list(np.cumsum([len(t) for t in gold_toks]))
gold_stream = [row[3] for d in test_docs for s in d['conllu'] for row in s]
gold_spans = {(sh-len(gt), sh): (gl, gt)
              for sh, gt, gl in zip(gold_spans, gold_toks, gold_stream)}

for gold_span in gold_spans:
    if gold_span in pred_spans:
        result = gold_spans[gold_span] == pred_spans[gold_span]
    else:
        result = False
    accuracy[gold_spans[gold_span][0]].append(result)
    accuracy_all.append(result)
    if gold_spans[gold_span][1] != ' ':
        accuracy_nsp[gold_spans[gold_span][0]].append(result)
        accuracy_all_nsp.append(result)
        
for ptok, parc in zip(pred_toks, pred_arcs):
    if parc in gold_arcs:
        sup_accuracy += 1
        if ptok != ' ':
            sup_accuracy_nsp += 1
sup_accuracy /= len(pred_toks)
sup_accuracy_nsp /= len([x for x in pred_toks if x != ' '])

print("Token segmentation performance with/out space", confusion, confusion_nsp)
print("Overall POS accuracy with/out space", sum(accuracy_all)/len(accuracy_all), sum(accuracy_all_nsp)/len(accuracy_all_nsp))
print("Overall SUP:DEP accuracy with/out space", sup_accuracy, sup_accuracy_nsp)
if len(accuracy_all_sty):
    print("Overall s_type accuracy: ", sum(accuracy_all_sty)/len(accuracy_all_sty))
"Tag-wise accuracy", list(Counter({tag: (sum(accuracy_sty[tag])/len(accuracy_sty[tag]), len(accuracy_sty[tag])) 
                                   for tag in accuracy_sty}).most_common()), list(Counter({tag: (sum(accuracy[tag])/len(accuracy[tag]), len(accuracy[tag])) 
                                                                                           for tag in accuracy}).most_common())


0it [00:00, ?it/s][A
1it [00:01,  1.32s/it][A
2it [00:01,  1.12s/it][A
3it [00:02,  1.06it/s][A
4it [00:03,  1.17it/s][A
5it [00:03,  1.33it/s][A
6it [00:04,  1.32it/s][A
7it [00:05,  1.35it/s][A
8it [00:06,  1.27it/s][A
9it [00:06,  1.38it/s][A
10it [00:07,  1.29it/s][A
11it [00:07,  1.46it/s][A
12it [00:08,  1.47it/s][A
13it [00:09,  1.46it/s][A
14it [00:09,  1.58it/s][A
15it [00:10,  1.53it/s][A
16it [00:11,  1.63it/s][A
17it [00:11,  1.78it/s][A
18it [00:12,  1.47it/s][A

  0%|          | 0/18 [00:00<?, ?it/s][A

Interpreting documents...



  6%|▌         | 1/18 [00:29<08:24, 29.67s/it][A
 11%|█         | 2/18 [01:02<08:11, 30.72s/it][A
 17%|█▋        | 3/18 [01:24<06:59, 27.95s/it][A
 22%|██▏       | 4/18 [01:53<06:36, 28.34s/it][A
 28%|██▊       | 5/18 [02:12<05:31, 25.51s/it][A
 33%|███▎      | 6/18 [02:37<05:04, 25.34s/it][A
 39%|███▉      | 7/18 [03:03<04:41, 25.62s/it][A
 44%|████▍     | 8/18 [03:23<03:58, 23.88s/it][A
 50%|█████     | 9/18 [03:49<03:40, 24.49s/it][A
 56%|█████▌    | 10/18 [04:27<03:49, 28.67s/it][A
 61%|██████    | 11/18 [04:44<02:55, 25.11s/it][A
 67%|██████▋   | 12/18 [05:09<02:30, 25.10s/it][A
 72%|███████▏  | 13/18 [05:33<02:02, 24.55s/it][A
 78%|███████▊  | 14/18 [05:56<01:36, 24.10s/it][A
 83%|████████▎ | 15/18 [06:28<01:19, 26.64s/it][A
 89%|████████▉ | 16/18 [06:44<00:46, 23.36s/it][A
 94%|█████████▍| 17/18 [06:59<00:21, 21.02s/it][A
100%|██████████| 18/18 [07:22<00:00, 24.57s/it][A


Token segmentation performance with/out space {'TP': 19105, 'FP': 9854, 'FN': 11921, 'P': 0.6597258192617148, 'R': 0.6157738670792239, 'F': 0.636992581478703} {'TP': 12149, 'FP': 6191, 'FN': 7730, 'P': 0.6624318429661941, 'R': 0.6111474420242466, 'F': 0.6357570841727935}
Overall POS accuracy with/out space 0.16205762908528332 0.07461766156882092
Overall SUP:DEP accuracy with/out space 0.1793570219966159 0.17895310796074154


('Tag-wise accuracy',
 [],
 [('INTJ', (0.3181818181818182, 88)),
  ('SPACE', (0.2577987846049966, 14810)),
  ('NOUN', (0.18646920228879166, 2971)),
  ('PUNCT', (0.16310679611650486, 2060)),
  ('NUM', (0.15451895043731778, 343)),
  ('PRON', (0.07613741875580315, 1077)),
  ('PROPN', (0.04424778761061947, 1243)),
  ('ADV', (0.03099510603588907, 613)),
  ('SYM', (0.02857142857142857, 35)),
  ('AUX', (0.020804438280166437, 721)),
  ('CCONJ', (0.02040816326530612, 588)),
  ('DET', (0.012160228898426323, 1398)),
  ('VERB', (0.012012012012012012, 1665)),
  ('PART', (0.008955223880597015, 335)),
  ('ADP', (0.007633587786259542, 1703)),
  ('ADJ', (0.0018083182640144665, 1106)),
  ('SCONJ', (0.0, 244)),
  ('X', (0.0, 26))])

In [12]:
from tqdm import tqdm

accuracy = defaultdict(list)
accuracy_nsp = defaultdict(list)
accuracy_all, accuracy_all_nsp, = [], []
sup_accuracy, sup_accuracy_nsp, = 0, 0
accuracy_sty = defaultdict(list)
accuracy_all_sty = []

model.interpret(tdocs, seed = 691,
                covering = [[[row[1] for row in s] for s in d['conllu']] for d in test_docs])

for d_i, doc in enumerate(model._documents):
    for s_i, s in enumerate(doc._sentences):
        if s._sty is not None:
            result = s._sty == test_docs[d_i]['s_type'][s_i]
            accuracy_sty[test_docs[d_i]['s_type'][s_i]].append(result)
            accuracy_all_sty.append(result)

pred_toks = [t._form for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_arcs = set([(ix, str(t._sup), t._dep, d_i, s_i) for d_i, doc in enumerate(model._documents) 
                 for s_i, s in enumerate(doc._sentences) for ix, t in enumerate(s._tokens)])
pred_spans = list(np.cumsum([len(t) for t in pred_toks]))
pred_stream = [t._pos for doc in model._documents for s in doc._sentences for t in s._tokens]
pred_spans = {(sh-len(gt), sh): (gl, gt)
              for sh, gt, gl in zip(pred_spans, pred_toks, pred_stream)}

gold_toks = [row[1] for d in test_docs for s in d['conllu'] for row in s]
gold_arcs = set([(ix, (str(int(row[6]) - int(row[0])) if int(row[6]) else row[6]), row[7], d_i, s_i) 
                 for d_i, d in enumerate(test_docs) for s_i, s in enumerate(d['conllu']) for ix, row in enumerate(s)])
gold_spans = list(np.cumsum([len(t) for t in gold_toks]))
gold_stream = [row[3] for d in test_docs for s in d['conllu'] for row in s]
gold_spans = {(sh-len(gt), sh): (gl, gt)
              for sh, gt, gl in zip(gold_spans, gold_toks, gold_stream)}

for gold_span in gold_spans:
    if gold_span in pred_spans:
        result = gold_spans[gold_span] == pred_spans[gold_span]
    else:
        result = False
    accuracy[gold_spans[gold_span][0]].append(result)
    accuracy_all.append(result)
    if gold_spans[gold_span][1] != ' ':
        accuracy_nsp[gold_spans[gold_span][0]].append(result)
        accuracy_all_nsp.append(result)
        
for ptok, parc in zip(pred_toks, pred_arcs):
    if parc in gold_arcs:
        sup_accuracy += 1
        if ptok != ' ':
            sup_accuracy_nsp += 1
sup_accuracy /= len(pred_toks)
sup_accuracy_nsp /= len([x for x in pred_toks if x != ' '])

print("Overall POS accuracy with/out space", sum(accuracy_all)/len(accuracy_all), sum(accuracy_all_nsp)/len(accuracy_all_nsp))
print("Overall SUP:DEP accuracy with/out space", sup_accuracy, sup_accuracy_nsp)
if len(accuracy_all_sty):
    print("Overall s_type accuracy: ", sum(accuracy_all_sty)/len(accuracy_all_sty))
"Tag-wise accuracy", list(Counter({tag: (sum(accuracy_sty[tag])/len(accuracy_sty[tag]), len(accuracy_sty[tag])) 
                                   for tag in accuracy_sty}).most_common()), list(Counter({tag: (sum(accuracy[tag])/len(accuracy[tag]), len(accuracy[tag])) 
                                                                                           for tag in accuracy}).most_common())


0it [00:00, ?it/s][A
1it [00:01,  1.36s/it][A
2it [00:02,  1.16s/it][A
3it [00:02,  1.03it/s][A
4it [00:03,  1.13it/s][A
5it [00:03,  1.28it/s][A
6it [00:04,  1.28it/s][A
7it [00:05,  1.31it/s][A
8it [00:06,  1.19it/s][A
9it [00:06,  1.31it/s][A
10it [00:07,  1.25it/s][A
11it [00:08,  1.43it/s][A
12it [00:08,  1.47it/s][A
13it [00:09,  1.44it/s][A
14it [00:10,  1.55it/s][A
15it [00:10,  1.49it/s][A
16it [00:11,  1.59it/s][A
17it [00:11,  1.74it/s][A
18it [00:12,  1.43it/s][A

  0%|          | 0/18 [00:00<?, ?it/s][A

Interpreting documents...



  6%|▌         | 1/18 [00:26<07:38, 27.00s/it][A
 11%|█         | 2/18 [00:52<07:05, 26.61s/it][A
 17%|█▋        | 3/18 [01:11<06:02, 24.13s/it][A
 22%|██▏       | 4/18 [01:38<05:50, 25.03s/it][A
 28%|██▊       | 5/18 [01:55<04:54, 22.68s/it][A
 33%|███▎      | 6/18 [02:19<04:36, 23.01s/it][A
 39%|███▉      | 7/18 [02:47<04:29, 24.53s/it][A
 44%|████▍     | 8/18 [03:05<03:46, 22.69s/it][A
 50%|█████     | 9/18 [03:28<03:24, 22.73s/it][A
 56%|█████▌    | 10/18 [04:03<03:32, 26.51s/it][A
 61%|██████    | 11/18 [04:20<02:44, 23.52s/it][A
 67%|██████▋   | 12/18 [04:45<02:23, 23.98s/it][A
 72%|███████▏  | 13/18 [05:11<02:02, 24.48s/it][A
 78%|███████▊  | 14/18 [05:32<01:34, 23.62s/it][A
 83%|████████▎ | 15/18 [06:05<01:19, 26.41s/it][A
 89%|████████▉ | 16/18 [06:21<00:46, 23.37s/it][A
 94%|█████████▍| 17/18 [06:37<00:21, 21.16s/it][A
100%|██████████| 18/18 [07:02<00:00, 23.45s/it][A


Overall POS accuracy with/out space 0.2470508605685554 0.18561914158855453
Overall SUP:DEP accuracy with/out space 0.2172398219239951 0.22090437361008153


('Tag-wise accuracy',
 [],
 [('NOUN', (0.6085493099966341, 2971)),
  ('INTJ', (0.38636363636363635, 88)),
  ('SPACE', (0.31431465226198513, 14810)),
  ('PRON', (0.2906220984215413, 1077)),
  ('NUM', (0.23323615160349853, 343)),
  ('PUNCT', (0.15485436893203883, 2060)),
  ('VERB', (0.12072072072072072, 1665)),
  ('ADV', (0.10440456769983687, 613)),
  ('PROPN', (0.06999195494770716, 1243)),
  ('PART', (0.06865671641791045, 335)),
  ('X', (0.038461538461538464, 26)),
  ('SYM', (0.02857142857142857, 35)),
  ('AUX', (0.022191400832177532, 721)),
  ('CCONJ', (0.02040816326530612, 588)),
  ('ADP', (0.015854374633000587, 1703)),
  ('DET', (0.011444921316165951, 1398)),
  ('SCONJ', (0.00819672131147541, 244)),
  ('ADJ', (0.0054249547920434, 1106))])