In [2]:
import stanza

# Full GUM pipeline (tokenize -> pos -> lemma -> depparse) on GPU.
depparse_gum_config = dict(
    lang="en",
    processors="tokenize,pos,lemma,depparse",
    tokenize_model_path='./model/en/tokenize/gum.pt',
    pos_model_path='./model/en/pos/gum.pt',
    depparse_model_path='./model/en/depparse/gum.pt',
    lemma_model_path='./model/en/lemma/gum.pt',
    # Disable stanza's sentence splitting so each input stays one sentence
    # (per stanza docs, only blank lines then separate sentences).
    tokenize_no_ssplit=True,
    use_gpu=True,
    pos_batch_size=2000,
)

# Tokenizer-only pipeline on CPU, reusing the same GUM tokenizer checkpoint.
token_config = dict(
    lang="en",
    processors="tokenize",
    tokenize_model_path=depparse_gum_config['tokenize_model_path'],
    tokenize_no_ssplit=True,
    use_gpu=False,
    # NOTE(review): presumably has no effect since this pipeline has no pos processor.
    pos_batch_size=3000,
)

# gum_depparse is the parser used by stanza_parse; tokenizer is kept for callers elsewhere.
gum_depparse = stanza.Pipeline(**depparse_gum_config)
tokenizer = stanza.Pipeline(**token_config)

2021-06-10 01:52:21 INFO: Loading these models for language: en (English):
| Processor | Package                 |
---------------------------------------
| tokenize  | ./model/en...ize/gum.pt |
| pos       | ./model/en/pos/gum.pt   |
| lemma     | ./model/en/lemma/gum.pt |
| depparse  | ./model/en...rse/gum.pt |

2021-06-10 01:52:21 INFO: Use device: gpu
2021-06-10 01:52:21 INFO: Loading: tokenize
2021-06-10 01:52:21 INFO: Loading: pos
2021-06-10 01:52:22 INFO: Loading: lemma
2021-06-10 01:52:22 INFO: Loading: depparse
2021-06-10 01:52:23 INFO: Done loading processors!
2021-06-10 01:52:23 INFO: Loading these models for language: en (English):
| Processor | Package                 |
---------------------------------------
| tokenize  | ./model/en...ize/gum.pt |

2021-06-10 01:52:23 INFO: Use device: cpu
2021-06-10 01:52:23 INFO: Loading: tokenize
2021-06-10 01:52:23 INFO: Done loading processors!


In [3]:
def dependency_parse(sentence, parser="gum"):
    """Preprocess ``sentence`` and dependency-parse the result.

    Returns:
        ``(stanza_parse(...) result, replacements)`` where ``replacements`` is
        the second value returned by ``preprocess`` (defined elsewhere).
    """
    normalized, replacements = preprocess(sentence)
    tree_info = stanza_parse(normalized, parser=parser)
    return tree_info, replacements


def stanza_parse(sentence, parser="gum"):
    """Dependency-parse ``sentence`` with the GUM stanza pipeline.

    Args:
        sentence: text to parse; a trailing newline is appended before parsing.
        parser: parser name. Currently ignored: only ``gum_depparse`` is wired up.

    Returns:
        ``(parse_tree, postags, words)``:
        - ``parse_tree``: list of ``[deprel, dependent_id, head_id]`` nodes,
          punctuation excluded; the root node's head is the string "root".
          Labels are refined in place by ``enhance_parse``.
        - ``postags``: word text -> ``(word_id, xpos)``.
        - ``words``: word id -> ``(text, xpos)``.
    """
    postags = {}
    words = {}
    parse_tree = []
    head_log = {}     # head id -> deprels of its non-punct dependents
    depdent_log = {}  # dependent id -> deprels attaching it to its head

    # NOTE(review): an "ewt" parser branch was sketched here, but no ewt
    # pipeline exists in this notebook, so `parser` has no effect.
    parsed = gum_depparse(sentence + "\n")

    for sent in parsed.sentences:
        for word in sent.words:
            tree_node = post_process(sent, word, postags, words)
            if not tree_node:
                continue  # punctuation is not part of the tree

            relation, dependent, head = tree_node
            head_log.setdefault(head, []).append(relation)
            depdent_log.setdefault(dependent, []).append(relation)
            parse_tree.append(tree_node)

    # Refine labels once over the complete tree. This used to run after every
    # sentence, which re-processed earlier nodes each time.
    # NOTE(review): word ids restart per stanza sentence, so multi-sentence
    # input would collide in `words`/`head_log`; tokenize_no_ssplit keeps
    # ordinary single-line input to one sentence.
    enhance_parse(parse_tree, head_log, depdent_log, words)
    return parse_tree, postags, words


def enhance_parse(tree, heads, deps, words):
    """Refine UD relation labels in ``tree`` in place.

    Args:
        tree: list of mutable ``[deprel, dependent_id, head_id]`` nodes.
        heads: head id -> list of deprels of that head's (non-punct) dependents.
            Words with no such dependents have no entry.
        deps: dependent id -> list of deprels (accepted for interface
            compatibility; not read here).
        words: word id -> ``(text, xpos)``.

    Rewrites applied:
        conj   -> conj-sent (both conjuncts have an nsubj), conj-adj (both JJ),
                  conj-n / conj-np (both nominal; -np when both carry NP
                  modifiers), conj-vb / conj-vp (both verbal; -vp when both
                  carry obj/xcomp/obl)
        advcl  -> advcl-sent when word 1 is "if"
        advmod -> advmod-sent when the dependent is "not" at position 1
        case   -> case-after when the dependent follows its head
        any node whose dependent is at-most/at-least/more-than/less-than -> det
    """
    # Dependent relations that make a nominal conjunct a full noun phrase.
    np_rels = {"amod", "compound", "compound:prt", "det",
               "nummod", "appos", "advmod", "nmod", "nmod:poss"}
    # Dependent relations that make a verbal conjunct a full verb phrase.
    vp_rels = {"obj", "xcomp", "obl"}

    for node in tree:
        if node[0] == "conj":
            # `heads` only has entries for words with non-punct dependents, so
            # a bare conjunct (e.g. "blue" in "red, blue") is absent; treat it
            # as having no dependents instead of raising KeyError.
            dep_children = heads.get(node[1], [])
            head_children = heads.get(node[2], [])
            dep_tag = words[node[1]][1]
            head_tag = words[node[2]][1]

            if "nsubj" in dep_children and "nsubj" in head_children:
                node[0] = "conj-sent"
            elif dep_tag == "JJ" and head_tag == "JJ":
                node[0] = "conj-adj"
            elif "NN" in dep_tag and "NN" in head_tag:
                node[0] = "conj-n"
                if set(dep_children) & np_rels and set(head_children) & np_rels:
                    node[0] = "conj-np"
            elif "VB" in dep_tag and "VB" in head_tag:
                node[0] = "conj-vb"
                if set(dep_children) & vp_rels and set(head_children) & vp_rels:
                    node[0] = "conj-vp"

        # NOTE(review): the "if"/"not" matches below are case-sensitive; confirm
        # that `preprocess` lowercases sentence-initial words.
        if node[0] == "advcl" and words[1][0] == "if":
            node[0] = "advcl-sent"
        if node[0] == "advmod" and words[node[1]][0] == "not" and node[1] == 1:
            node[0] = "advmod-sent"
        if node[0] == "case" and node[1] - node[2] > 0:
            node[0] = "case-after"
        if words[node[1]][0] in ("at-most", "at-least", "more-than", "less-than"):
            node[0] = "det"


def post_process(sent, word, postag, words):
    """Record ``word`` in the lookup tables and build its tree node.

    On the first sighting of a word id, stores ``words[id] = (text, xpos)`` and
    ``postag[text] = (id, xpos)``. ``sent`` is not used.

    Returns:
        ``[deprel, id, head]`` (head is the string "root" when stanza's head is
        not positive), or ``[]`` for punctuation.
    """
    index = int(word.id)
    if index not in words:
        postag[word.text] = (index, word.xpos)
        words[index] = (word.text, word.xpos)

    if word.deprel == "punct":
        return []

    head = "root" if not word.head > 0 else word.head
    return [word.deprel, index, head]


def printTree(tree, tag, word):
    """Print one ``[deprel, dependent_id, head_id]`` node as a tab-separated line.

    Root nodes are skipped (their head is the string "root", which is not a
    key of ``word``). ``tag`` is not used.
    """
    if tree[0] == "root":
        return
    print(f"word: {word[tree[1]][0]}\thead: {word[tree[2]][0]}\tdeprel: {tree[0]}")

In [4]:
from pqdict import pqdict

# Mark obtained by flipping a node's mark: "+" <-> "-", while "=" stays "=".
negate_mark = dict(zip("+-=", "-+="))

class BinaryDependencyTree:
    """Binary tree built from a dependency parse.

    Internal nodes (``is_tree`` True) hold a relation label in ``val`` and have
    ``left``/``right`` subtrees. Leaves (``is_tree`` False, via
    ``set_not_tree``) hold a word in ``val`` plus its word ``id`` and ``pos``.
    ``mark`` starts as "0"; ``negate_mark`` handles the "+", "-" and "=" values.
    """

    def __init__(self, val, left, right, key, counter, id=None, pos=None):
        self.val = val
        self.parent = None
        self.left = left
        self.right = right
        self.mark = "0"
        self.id = id
        self.pos = pos
        self.key = key
        self.is_root = False
        self.is_tree = True
        self.length = 0
        # (word, pos, mark, id) -> word id used as the queue priority.
        self.leaves = pqdict({})
        self.counter = counter
        # str((val, id)) -> replacement text applied to leaves by traverse().
        self.replaced = {}

    def sorted_leaves(self):
        """Collect this tree's leaves into ``self.leaves`` and return it.

        ``self.leaves`` is a pqdict prioritised by word id. It is not cleared
        between calls.
        NOTE(review): pqdict iteration follows heap order; sort by priority
        (or use popkeys()) to read leaves in word order.
        """
        self.traverse(self)
        return self.leaves

    def traverse(self, tree, multi_word=False):
        """Add the leaves under ``tree`` to ``self.leaves``.

        Leaves whose ``str((val, id))`` appears in ``self.replaced`` have
        ``val`` rewritten in place. When ``multi_word`` is true, a replaced,
        hyphenated value is split into one leaf per word. The ids step down by
        0.1 from the right, and a "not" in a two-word value gets the negated mark.
        ``multi_word`` is now passed through the recursion, so it also applies
        when traversal starts at an internal node.
        """
        if tree.is_tree:
            self.traverse(tree.left, multi_word)
            self.traverse(tree.right, multi_word)
            return

        lookup = str((tree.val, tree.id))
        replacement = lookup in self.replaced
        if replacement:
            tree.val = self.replaced[lookup]

        if "-" in tree.val and replacement and multi_word:
            words = tree.val.split('-')
            words.reverse()
            for i, part in enumerate(words):
                word_id = tree.id - i * 0.1
                mark = tree.mark
                if part.lower() == "not" and len(words) == 2:
                    mark = negate_mark[tree.mark]
                self.leaves[(part, tree.pos, mark, word_id)] = word_id
        else:
            self.leaves[(tree.val, tree.pos, tree.mark, tree.id)] = tree.id

    def copy(self):
        """Return a structural copy of this subtree.

        The copy shares ``counter`` and ``replaced`` with the original and
        starts with an empty ``leaves`` queue. Its own ``parent`` is the
        original's parent, while copied children point at the new node.
        """
        left = self.left.copy() if self.left is not None else None
        right = self.right.copy() if self.right is not None else None
        new_tree = BinaryDependencyTree(
            self.val, left, right, self.key, self.counter, self.id, self.pos)
        new_tree.mark = self.mark
        new_tree.parent = self.parent
        new_tree.is_tree = self.is_tree
        new_tree.is_root = self.is_root
        new_tree.length = self.length
        # Keep the replacement map so sorted_leaves() on a copy still applies it.
        new_tree.replaced = self.replaced
        # child.copy() set each copied child's parent to the *original* node;
        # re-point them at the new one.
        if left is not None:
            left.parent = new_tree
        if right is not None:
            right.parent = new_tree
        return new_tree

    def set_length(self, lth):
        """Set ``length`` (Binarizer stores the word count on the root)."""
        self.length = lth

    def set_root(self):
        """Flag this node as the tree root."""
        self.is_root = True

    def set_not_tree(self):
        """Flag this node as a leaf (word) rather than a relation node."""
        self.is_tree = False


# Attachment priority per relation label, used by Binarizer.compose. A head's
# dependents are sorted ascending, and the lowest value is split off first,
# so it ends up highest in that head's binary subtree.
hierarchy = {
    "conj-sent": 0,
    "advcl-sent": 1,
    "advmod-sent": 2,
    "case": 10,
    "case-after": 75,
    "mark": 10,
    "expl": 10,
    "discourse": 10,
    "nsubj": 20,
    "csubj": 20,
    "nsubj:pass": 20,
    "conj-vp": 25,
    "ccomp": 30,
    "advcl": 30,
    "advmod": 30,
    "nmod": 30,
    "nmod:tmod": 30,
    "nmod:npmod": 30,
    "nmod:poss": 30,
    "xcomp": 40,
    "aux": 40,
    "aux:pass": 40,
    "obj": 60,
    "iobj": 60,
    "obl": 50,
    "obl:tmod": 50,
    "obl:npmod": 50,
    "cop": 50,
    "acl": 60,
    "acl:relcl": 60,
    "appos": 60,
    "conj": 60,
    "conj-np": 60,
    "conj-adj": 60,
    "det": 55,
    "det:predet": 55,
    "cc": 70,
    "cc:preconj": 70,
    "nummod": 75,
    "fixed": 80,
    "compound": 80,
    "compound:prt": 80,
    "amod": 75,
    "conj-n": 90,
    "conj-vb": 90,
    "dep": 100,
    "flat": 100,
    "goeswith": 100,
    "parataxis": 100
}


class UnifiedCounter:
    """Mutable counters and flags shared by every node of one binarized tree.

    Binarizer hands the same instance to all nodes it creates. ``addi_negates``
    and ``unifies`` start at ``initial_val``; the three flags start False.
    """

    def __init__(self, initial_val=0):
        self.addi_negates = self.unifies = initial_val
        self.nsubjLeft = self.expl = self.willing_verb = False

    def add_negates(self):
        """Increment ``addi_negates`` by one."""
        self.addi_negates = self.addi_negates + 1

    def add_unifies(self):
        """Increment ``unifies`` by one."""
        self.unifies = self.unifies + 1

    def is_unified_clause_subj(self):
        """Return ``nsubjLeft`` when ``unifies`` is odd, otherwise False."""
        odd_unifies = self.unifies % 2 == 1
        return odd_unifies and self.nsubjLeft


class Binarizer:
    """Convert a flat dependency parse into a BinaryDependencyTree.

    ``parse_table`` is a list of ``[deprel, dependent_id, head_id]`` nodes and
    ``words`` maps word ids to ``(text, xpos)``. For each head, its dependents
    are peeled off in ``hierarchy`` order. The lowest-valued relation becomes
    the top of the head's subtree, with the dependent's subtree on the left
    and the remainder of the head's subtree on the right.
    """

    # Priority for relation labels missing from ``hierarchy`` (same as "dep").
    DEFAULT_PRIORITY = 100

    def __init__(self, parse_table=None, postag=None, words=None):
        self.postag = postag
        self.parse_table = parse_table
        self.words = words
        self.id = 0
        self.counter = UnifiedCounter(0)
        self.replaced = {}

    def _priority(self, relation):
        """Return the ``hierarchy`` rank of ``relation`` (lower = split off first).

        Unlisted UD subtypes (e.g. "obl:agent", "csubj:pass") fall back to their
        base relation, then to DEFAULT_PRIORITY, instead of raising KeyError.
        """
        if relation in hierarchy:
            return hierarchy[relation]
        return hierarchy.get(relation.split(":")[0], self.DEFAULT_PRIORITY)

    def process_not(self, children):
        """Prefer a "not" dependent over a leading advmod.

        If the first (highest-ranked) child is an advmod and the second child's
        dependent word is "not", return just that second child; otherwise
        return ``children`` unchanged.
        """
        if len(children) > 1:
            if children[0][0] == "advmod":
                if self.words[children[1][1]][0] == "not":
                    return [children[1]]
        return children

    def compose(self, head):
        """Recursively build the binary subtree rooted at word ``head``.

        Consumes the nodes it uses by removing them from ``self.parse_table``.

        Returns:
            ``(tree, keys)`` where ``keys`` lists node keys in in-order
            (left subtree, node, right subtree) sequence.
        """
        children = [node for node in self.parse_table if node[2] == head]
        children.sort(key=lambda node: self._priority(node[0]))
        children = self.process_not(children)

        if not children:
            # No dependents left: the head word becomes a leaf.
            word = self.words[head][0]
            tag = self.words[head][1]
            leaf = BinaryDependencyTree(
                word, None, None, self.id, self.counter, head, tag)
            leaf.replaced = self.replaced
            self.id += 1
            leaf.set_not_tree()
            return leaf, [leaf.key]

        top_dep = children[0]
        self.parse_table.remove(top_dep)

        left, left_rel = self.compose(top_dep[1])
        right, right_rel = self.compose(top_dep[2])

        # Collapse enhanced labels (conj-np, case-after, advcl-sent, ...) back
        # to their base relation by substring match.
        # NOTE(review): "cc:preconj" also contains "conj" and is therefore
        # labelled "conj" here; confirm this is intended.
        for base in ("conj", "case", "advcl", "advmod"):
            if base in top_dep[0]:
                dep_rel = base
                break
        else:
            dep_rel = top_dep[0]

        binary_tree = BinaryDependencyTree(
            dep_rel, left, right, self.id, self.counter)
        binary_tree.left.parent = binary_tree
        binary_tree.right.parent = binary_tree
        binary_tree.replaced = self.replaced

        left_rel.append(binary_tree.key)
        self.id += 1
        return binary_tree, left_rel + right_rel

    def binarization(self):
        """Binarize the whole parse starting from its "root" node.

        Resets the key counter and the shared UnifiedCounter. Marks the result
        as root and sets its ``length`` to the number of entries in ``words``.

        Returns:
            ``(tree, keys)`` as returned by ``compose``.
        """
        self.id = 0
        self.relation = []
        root = list(filter(lambda x: x[0] == "root", self.parse_table))[0][1]
        self.counter = UnifiedCounter(0)
        binary_tree, relation = self.compose(root)
        binary_tree.set_root()
        binary_tree.length = len(self.words)
        return binary_tree, relation

In [6]:
from pattern.en import conjugate
from nltk.tree import Tree
from nltk.draw import TreeWidget
from nltk.draw.util import CanvasFrame
from IPython.display import Image, display

# Glyph displayed for each node mark; "0" (unmarked) renders as nothing.
arrows = {"+": "\u2191", "-": "\u2193", "=": "=", "0": ""}

# Reverse lookup from a displayed glyph to a signed integer (no entry for "").
arrow2int = dict([(arrows["+"], 1), (arrows["-"], -1), (arrows["="], 0)])

def btree2list(binaryDepdency, verbose=0):
    """Convert a BinaryDependencyTree into nested lists for drawing.

    Internal nodes become ``[label, left, right]`` and leaves become
    ``[pos, label]``. A label is the node value plus its mark glyph from
    ``arrows``; in leaves, '-' is shown as a space. With ``verbose == 2`` each
    label is also suffixed with the node's ``key``.
    """
    def label_of(node):
        text = node.val if node.is_tree else node.val.replace('-', ' ')
        text = text + arrows[node.mark]
        if verbose == 2:
            text += str(node.key)
        return text

    def convert(node):
        parts = [] if node.is_tree else [node.pos]
        parts.append(label_of(node))
        parts.extend(convert(child)
                     for child in (node.left, node.right)
                     if child is not None)
        return parts

    return convert(binaryDepdency)

def jupyter_draw_nltk_tree(tree):
    """Render an nltk Tree to ../data/tree.png and display it inline.

    Draws via a Tk CanvasFrame to PostScript (../data/tree.ps), then shells
    out to `convert` (presumably ImageMagick) to produce the PNG.
    """
    frame = CanvasFrame()
    widget = TreeWidget(frame.canvas(), tree)
    styling = {
        'node_font': 'arial 14 bold',
        'leaf_font': 'arial 14',
        'node_color': '#005990',
        'leaf_color': '#3F8F57',
        'line_color': '#175252',
    }
    for option, value in styling.items():
        widget[option] = value
    frame.add_widget(widget, 20, 20)
    frame.print_to_file('../data/tree.ps')
    frame.destroy()
    os.system('convert ../data/tree.ps ../data/tree.png')
    display(Image(filename='../data/tree.png'))
    
def jupyter_draw_rsyntax_tree(tree):
    """Render a bracketed tree string with the ``rsyntaxtree`` CLI and show it.

    The tree text is passed as its own argv element instead of being spliced
    into a shell command line. Quotes, ``$`` or backticks coming from the
    sentence therefore cannot break, or inject into, the command. As with the
    previous ``os.system`` call, a non-zero exit status is not raised. A
    missing ``rsyntaxtree`` executable now raises FileNotFoundError.
    """
    import subprocess

    font_size = '8'
    subprocess.run(['rsyntaxtree', '-s', font_size, str(tree)], check=False)
    # The image is read back from ./syntree.png (presumably rsyntaxtree's
    # default output file in the working directory).
    display(Image(filename='./syntree.png'))

In [None]:
class GraphFactoryPipeline:
    """Parse a sentence, binarize the dependency parse, then polarize it.

    NOTE(review): ``self.polarizer`` must expose ``dependtree``, ``relation``,
    ``replaced`` and ``polarize_deptree()``. Its class is not defined in this
    part of the notebook, so pass it as ``polarizer=`` or assign it before
    calling ``single_polarization``.
    """

    def __init__(self, verbose=0, parser="gum", polarizer=None):
        self.parser = parser
        self.binarizer = Binarizer()
        self.exceptioned = []
        self.verbose = verbose
        # Annotated sentences collected by single_polarization when verbose == 2
        # (previously never initialised, so verbose == 2 raised AttributeError).
        self.annotated_sentences = []
        # Only set when given, so an existing external assignment is not shadowed.
        if polarizer is not None:
            self.polarizer = polarizer

    def run_binarization(self, parsed, replaced, sentence):
        """Binarize ``parsed`` = ``(parse_table, postags, words)``.

        Returns:
            ``(binary_tree, relation_keys)`` from ``Binarizer.binarization``.
        """
        self.binarizer.parse_table = parsed[0]
        self.binarizer.postag = parsed[1]
        self.binarizer.words = parsed[2]

        if self.verbose == 2:
            print()
            print(parsed[0])
            print()
            print(parsed[1])
            print()
            print(replaced)

        self.binarizer.replaced = replaced
        binary_dep, relation = self.binarizer.binarization()
        if self.verbose == 2:
            self.postprocess(binary_dep)
        return binary_dep, relation

    def postprocess(self, tree, svg=False):
        """Return ``tree`` as nested lists (``svg=True``) or a bracketed string."""
        sexpression = btree2list(tree, 0)
        if not svg:
            sexpression = '[%s]' % ', '.join(
                map(str, sexpression)).replace(",", " ").replace("'", "")
        return sexpression

    def run_polarization(self, binary_dep, relation, replaced, sentence):
        """Run the polarizer on ``binary_dep``; with verbose == 1, draw the result."""
        self.polarizer.dependtree = binary_dep
        self.polarizer.relation = relation
        self.polarizer.replaced = replaced

        self.polarizer.polarize_deptree()
        if self.verbose == 2:
            self.postprocess(binary_dep)
        elif self.verbose == 1:
            # svgling draws nested [label, child, ...] lists, so request the
            # list form. The drawing must also be explicitly displayed, because
            # a value created inside a method is not auto-rendered by Jupyter.
            # NOTE(review): `svgling` is not imported in the visible cells.
            polarized = self.postprocess(binary_dep, svg=True)
            display(svgling.draw_tree(polarized))

    def modify_replacement(self, tree, replace):
        """Recursively rewrite node values whose ``str((val, id))`` is in ``replace``."""
        if str((tree.val, tree.id)) in replace:
            tree.val = replace[str((tree.val, tree.id))]

        if tree.is_tree:
            self.modify_replacement(tree.left, replace)
            self.modify_replacement(tree.right, replace)

    def single_polarization(self, sentence):
        """Parse, binarize and polarize one sentence.

        Returns:
            dict with ``original`` (input), ``annotated`` (the pqdict of
            leaves from ``sorted_leaves``) and ``polarized_tree``.
        """
        parsed, replaced = dependency_parse(sentence, self.parser)
        binary_dep, relation = self.run_binarization(
            parsed, replaced, sentence)
        self.run_polarization(binary_dep, relation, replaced, sentence)
        annotated = self.polarizer.dependtree.sorted_leaves()

        if self.verbose == 2:
            # pqdict iteration follows heap order, not priority order; sort by
            # the stored word id so the sentence reads left to right.
            ordered = sorted(annotated.items(), key=lambda item: item[1])
            annotated_sent = ' '.join(key[0] for key, _ in ordered)
            self.annotated_sentences.append(annotated_sent)

        self.modify_replacement(self.polarizer.dependtree, replaced)

        return {
            'original': sentence,
            'annotated': annotated,
            'polarized_tree': self.polarizer.dependtree,
        }
