In [85]:
import nltk
from nltk.corpus import dependency_treebank
import graphviz
import os
import random
random.seed(10)
os.environ["PATH"] += os.pathsep + r'C:\Program Files\Graphviz\bin'

In [86]:
import numpy as np

In [87]:
nltk.download('dependency_treebank')

[nltk_data] Downloading package dependency_treebank to
[nltk_data]     C:\Users\eliav\AppData\Roaming\nltk_data...
[nltk_data]   Package dependency_treebank is already up-to-date!


True

In [88]:
from collections import Counter, defaultdict, namedtuple
from networkx import DiGraph

In [89]:
sentences = dependency_treebank.parsed_sents()
k = len(sentences)//10
train, test = sentences[:-k], sentences[-k:]
random.shuffle(train)
print(f"train size = {len(train)}, test size = {len(test)}, total = {len(train) + len(test)}")

train size = 3523, test size = 391, total = 3914


In [90]:
len(train), len(test), len(test) + len(train), len(sentences)

(3523, 391, 3914, 3914)

In [91]:
vocab = Counter()
pos_tags = Counter()
for dg in sentences:
    for i, node in dg.nodes.items():
        w = str(node['word'])
        pos = node['tag']
        # print(w, pos)
        # for f in dic:
        #     if f(w):
        #         print(node['tag'])
        #         w = dic[f]
        # vocab.add(str(w).lower())
        vocab[w] += 1
        # if not w.isalpha() and '_' not in w:
        #     print(w)
        pos_tags[pos] += 1
print(len(vocab), len(pos_tags))

11968 46


In [92]:
from collections import Counter

def tree_to_counter(triples, lr=1):
    c = Counter()
    for x in triples:
        z = zip(x[0], x[-1])
        for i in range(lr):
            c.update(z)
    return c
tree_to_counter(sentences[0].triples())

Counter({('will', 'Vinken'): 1,
         ('MD', 'NNP'): 1,
         ('Vinken', 'Pierre'): 1,
         ('NNP', 'NNP'): 1,
         ('Vinken', ','): 2,
         ('NNP', ','): 2,
         ('Vinken', 'old'): 1,
         ('NNP', 'JJ'): 1,
         ('old', 'years'): 1,
         ('JJ', 'NNS'): 1,
         ('years', '61'): 1,
         ('NNS', 'CD'): 1,
         ('will', 'join'): 1,
         ('MD', 'VB'): 1,
         ('join', 'board'): 1,
         ('VB', 'NN'): 1,
         ('board', 'the'): 1,
         ('NN', 'DT'): 2,
         ('join', 'as'): 1,
         ('VB', 'IN'): 1,
         ('as', 'director'): 1,
         ('IN', 'NN'): 1,
         ('director', 'a'): 1,
         ('director', 'nonexecutive'): 1,
         ('NN', 'JJ'): 1,
         ('join', 'Nov.'): 1,
         ('VB', 'NNP'): 1,
         ('Nov.', '29'): 1,
         ('NNP', 'CD'): 1,
         ('will', '.'): 1,
         ('MD', '.'): 1})

In [93]:
def score(features, weights):
    score = 0
    for f in features:
        score += weights[f]
    return score

In [94]:
from networkx.drawing.nx_agraph import write_dot, read_dot

def nx_graph_to_dot(G):
    write_dot(G, r'C:\Users\eliav\PycharmProjects\NLP\ex4\blah.gv')
    return read_dot(r'C:\Users\eliav\PycharmProjects\NLP\ex4\blah.gv')

In [95]:
from Chu_Liu_Edmonds_algorithm import min_spanning_arborescence_nx
from networkx import minimum_spanning_arborescence
from nltk.parse import DependencyGraph
import pygraphviz
from networkx.drawing.nx_agraph import to_agraph
import time

def MST_helper(lr=1, max_iter=2):
    c = Counter()
    sum_c = Counter()
    num_empty_G, st = 0, time.time()
    for r in range(max_iter):
        for i, dg in enumerate(train):
            Arc = namedtuple("Arc", "head tail weight")
            arcs = [Arc(u,v,score(zip(u, v), c)) for u, _, v in dg.triples()]
            G = DiGraph()
            for arc in arcs:
                G.add_edge(arc.head, arc.tail, weight=arc.weight)
            if not len(G):
                num_empty_G += 1
                continue
            T_prime = minimum_spanning_arborescence(G)
            c += tree_to_counter(dg.triples(), lr) - tree_to_counter(T_prime.edges, lr)
            sum_c += c
    et = time.time()
    print(f"runtime: {et - st} seconds\n"
          f"encountered {num_empty_G} empty graphs out of {max_iter*len(train)}")
    return sum_c

In [96]:
# mst_counter = MST_helper()
# mst_counter

In [97]:
def MST(lr=1, max_iter=2):
    mst = MST_helper(lr, max_iter)
    for key in mst:
        mst[key] /= max_iter*len(train)
    return mst
mst_counter = MST()

runtime: 28.21900177001953 seconds
encountered 2 empty graphs out of 7046


In [98]:
lst = [1,2,3]
random.shuffle(lst)
lst

[3, 1, 2]

In [99]:
len(mst_counter)

5292

In [100]:
dg = sentences[0]
dot = dg.to_dot()
print(dot)
src = graphviz.Source(dot)
# dot.render('sent0-dep-treebank.gv', view=True)
# doctest_mark_exe()
# src.render('sent0-dep-treebank.gv', view=True).replace('\\', '/')

digraph G{
edge [dir=forward]
node [shape=plaintext]

0 [label="0 (None)"]
0 -> 8 [label="ROOT"]
1 [label="1 (Pierre)"]
2 [label="2 (Vinken)"]
2 -> 1 [label=""]
2 -> 3 [label=""]
2 -> 6 [label=""]
2 -> 7 [label=""]
3 [label="3 (,)"]
4 [label="4 (61)"]
5 [label="5 (years)"]
5 -> 4 [label=""]
6 [label="6 (old)"]
6 -> 5 [label=""]
7 [label="7 (,)"]
8 [label="8 (will)"]
8 -> 2 [label=""]
8 -> 9 [label=""]
8 -> 18 [label=""]
9 [label="9 (join)"]
9 -> 11 [label=""]
9 -> 12 [label=""]
9 -> 16 [label=""]
10 [label="10 (the)"]
11 [label="11 (board)"]
11 -> 10 [label=""]
12 [label="12 (as)"]
12 -> 15 [label=""]
13 [label="13 (a)"]
14 [label="14 (nonexecutive)"]
15 [label="15 (director)"]
15 -> 13 [label=""]
15 -> 14 [label=""]
16 [label="16 (Nov.)"]
16 -> 17 [label=""]
17 [label="17 (29)"]
18 [label="18 (.)"]
}


In [101]:
for dg in sentences[:10]:
    for head, rel, dep in dg.triples():
        print('({h[0]}, {h[1]}), {r}, ({d[0]}, {d[1]})'
              .format(h=head, r=rel, d=dep))

(will, MD), , (Vinken, NNP)
(Vinken, NNP), , (Pierre, NNP)
(Vinken, NNP), , (,, ,)
(Vinken, NNP), , (old, JJ)
(old, JJ), , (years, NNS)
(years, NNS), , (61, CD)
(Vinken, NNP), , (,, ,)
(will, MD), , (join, VB)
(join, VB), , (board, NN)
(board, NN), , (the, DT)
(join, VB), , (as, IN)
(as, IN), , (director, NN)
(director, NN), , (a, DT)
(director, NN), , (nonexecutive, JJ)
(join, VB), , (Nov., NNP)
(Nov., NNP), , (29, CD)
(will, MD), , (., .)
(is, VBZ), , (Vinken, NNP)
(Vinken, NNP), , (Mr., NNP)
(is, VBZ), , (chairman, NN)
(chairman, NN), , (of, IN)
(of, IN), , (group, NN)
(group, NN), , (N.V., NNP)
(N.V., NNP), , (Elsevier, NNP)
(group, NN), , (,, ,)
(group, NN), , (the, DT)
(group, NN), , (Dutch, NNP)
(group, NN), , (publishing, VBG)
(is, VBZ), , (., .)
(was, VBD), , (Agnew, NNP)
(Agnew, NNP), , (Rudolph, NNP)
(Agnew, NNP), , (,, ,)
(Agnew, NNP), , (old, JJ)
(old, JJ), , (years, NNS)
(years, NNS), , (55, CD)
(old, JJ), , (and, CC)
(old, JJ), , (chairman, NN)
(chairman, NN), , (former,

In [102]:
def view_tree(dg):
    src = graphviz.Source(dg.to_dot)
    src.render('sent-dep-treebank.gv', view=True).replace('\\', '/')

In [103]:
list(dg.triples())

[(('is', 'VBZ'), '', ('There', 'EX')),
 (('is', 'VBZ'), '', ('asbestos', 'NN')),
 (('asbestos', 'NN'), '', ('no', 'DT')),
 (('is', 'VBZ'), '', ('in', 'IN')),
 (('in', 'IN'), '', ('products', 'NNS')),
 (('products', 'NNS'), '', ('our', 'PRP$')),
 (('is', 'VBZ'), '', ('now', 'RB')),
 (('is', 'VBZ'), '', ('.', '.')),
 (('is', 'VBZ'), '', ("''", "''"))]

In [104]:
len(mst_counter)

5292

In [105]:
print(sentences[0].tree().pprint())

(will
  (Vinken Pierre , (old (years 61)) ,)
  (join (board the) (as (director a nonexecutive)) (Nov. 29))
  .)
None


In [106]:
dg.triples()

<generator object DependencyGraph.triples at 0x000001EB8AD9A270>

In [107]:
for dg in sentences[:10]:
    for head, rel, dep in dg.triples():
        print('({h[0]}, {h[1]}), {r}, ({d[0]}, {d[1]})'
              .format(h=head, r=rel, d=dep))

(will, MD), , (Vinken, NNP)
(Vinken, NNP), , (Pierre, NNP)
(Vinken, NNP), , (,, ,)
(Vinken, NNP), , (old, JJ)
(old, JJ), , (years, NNS)
(years, NNS), , (61, CD)
(Vinken, NNP), , (,, ,)
(will, MD), , (join, VB)
(join, VB), , (board, NN)
(board, NN), , (the, DT)
(join, VB), , (as, IN)
(as, IN), , (director, NN)
(director, NN), , (a, DT)
(director, NN), , (nonexecutive, JJ)
(join, VB), , (Nov., NNP)
(Nov., NNP), , (29, CD)
(will, MD), , (., .)
(is, VBZ), , (Vinken, NNP)
(Vinken, NNP), , (Mr., NNP)
(is, VBZ), , (chairman, NN)
(chairman, NN), , (of, IN)
(of, IN), , (group, NN)
(group, NN), , (N.V., NNP)
(N.V., NNP), , (Elsevier, NNP)
(group, NN), , (,, ,)
(group, NN), , (the, DT)
(group, NN), , (Dutch, NNP)
(group, NN), , (publishing, VBG)
(is, VBZ), , (., .)
(was, VBD), , (Agnew, NNP)
(Agnew, NNP), , (Rudolph, NNP)
(Agnew, NNP), , (,, ,)
(Agnew, NNP), , (old, JJ)
(old, JJ), , (years, NNS)
(years, NNS), , (55, CD)
(old, JJ), , (and, CC)
(old, JJ), , (chairman, NN)
(chairman, NN), , (former,

In [108]:
two_digit_num = lambda w: w.isnumeric() and len(w) == 2

four_digit_num = lambda w: w.isnumeric() and len(w) == 4

def decorated_contains(char):
  def contains_digit_and_char(w):
    ret = char in w
    for c in w:
      if c.isnumeric():
        return ret
    return False
  return contains_digit_and_char

contains_digit_and_dash = decorated_contains('-')
contains_digit_and_comma = decorated_contains(',')
contains_digit_and_period = decorated_contains('.')
contains_digit_and_slash = decorated_contains('/')

all_caps = lambda w: 'all_caps' if w.isupper() else False
othernum = lambda w: 'othernum' if w.isnumeric() else False
not_anum = lambda w: 'not_anum' if not w.isalnum() else False
cap_period = lambda w:"cap_period" if  len(w) == 2 and w.isupper() and w[1] == '.' else False
init_cap = lambda w: 'init_cap' if w[0].isupper() and all(a.isalpha() and a.islower() for a in w[1:]) else False
lowercase = lambda w: 'lowercase' if w.isalpha() and w.islower() else False
other = lambda w: True

In [109]:
dic = {two_digit_num: 'two_digit_num',
     four_digit_num:'four_digit_num',
     contains_digit_and_dash:'contains_digit_and_dash',
     contains_digit_and_comma:'contains_digit_and_comma',
     contains_digit_and_period:'contains_digit_and_period',
     contains_digit_and_slash:'contains_digit_and_slash',
     all_caps:'all_caps',
     othernum:'othernum',
     not_anum:"not_anum",
     cap_period:"cap_period",
     init_cap:"init_cap"}



In [110]:
vocab = set()
pos_tags = set()
for dg in sentences:
    for i, node in dg.nodes.items():
        w = str(node['word'])
        pos = node['tag']
        # print(w, pos)
        # for f in dic:
        #     if f(w):
        #         print(node['tag'])
        #         w = dic[f]
        # vocab.add(str(w).lower())
        vocab.add(str(w))
        # if not w.isalpha() and '_' not in w:
        #     print(w)
        pos_tags.add(pos)
# vocab
# pos_tags
# vocab

In [111]:
len(vocab)

11968

In [112]:
d = len(vocab)*len(vocab) + len(pos_tags)**2

In [113]:
vec = np.zeros(d)
vec

array([0., 0., 0., ..., 0., 0., 0.])

In [114]:
pairs_dict = {}
i = 0
for w1 in vocab:
    for w2 in vocab:
        i += 1
        pairs_dict[(w1, w2)] = i


pairs_dict

{('gallium', 'gallium'): 1,
 ('gallium', 'Mac'): 2,
 ('gallium', 'complicate'): 3,
 ('gallium', 'driver'): 4,
 ('gallium', 'Mehta'): 5,
 ('gallium', 'Deane'): 6,
 ('gallium', 'citizen'): 7,
 ('gallium', '2.65'): 8,
 ('gallium', 'varied'): 9,
 ('gallium', 'dean'): 10,
 ('gallium', 'Nissho-Iwai'): 11,
 ('gallium', 'Del'): 12,
 ('gallium', 'authorized'): 13,
 ('gallium', 'Grgich'): 14,
 ('gallium', '8.04'): 15,
 ('gallium', '234.4'): 16,
 ('gallium', 'diversify'): 17,
 ('gallium', 'celebrate'): 18,
 ('gallium', 'Sierra'): 19,
 ('gallium', 'Pencil'): 20,
 ('gallium', 'Though'): 21,
 ('gallium', 'Frankfurt'): 22,
 ('gallium', 'ratified'): 23,
 ('gallium', 'inverted'): 24,
 ('gallium', 'Improvement'): 25,
 ('gallium', 'casts'): 26,
 ('gallium', 'hair'): 27,
 ('gallium', 'frozen'): 28,
 ('gallium', '6.4'): 29,
 ('gallium', '139'): 30,
 ('gallium', 'offense'): 31,
 ('gallium', 'Filter'): 32,
 ('gallium', '90'): 33,
 ('gallium', 'P.'): 34,
 ('gallium', '1,880'): 35,
 ('gallium', 'old'): 36,
 ('

In [115]:
n = int(len(sentences)*0.1)
train, test = sentences[:-n], sentences[-n:]