In [41]:
from Udep2Mono.binarization import BinaryDependencyTree
from Udep2Mono import polarization
from Udep2Mono.util import btreeToList
from copy import deepcopy

In [48]:
class TreeFactory:
    """Build small dependency subtrees from a plain-dict description.

    ``tree_builder`` maps a relation name to the bound method that builds the
    subtree for that relation. Only ``"amod"`` is registered.
    """

    def __init__(self):
        self.tree_builder = {
            "amod": self.build_amod
        }

    def buildTree(self, config):
        """Dispatch on ``config['rel']`` and return the subtree built for it.

        Args:
            config: dict with at least a ``'rel'`` key; the remaining keys are
                whatever the selected builder reads.

        Raises:
            KeyError: if ``'rel'`` is missing from ``config`` or names a
                relation that has no registered builder.
        """
        rel = config['rel']
        try:
            builder = self.tree_builder[rel]
        except KeyError:
            # Still a KeyError (as before), but now it says what went wrong.
            raise KeyError(
                "unsupported relation %r; supported relations: %s"
                % (rel, sorted(self.tree_builder))) from None
        return builder(config)

    def build_amod(self, config):
        """Build an ``amod`` node whose left child is the modifier and whose
        right child is the head.

        Args:
            config: dict with keys ``'mod'`` (modifier word), ``'head'`` (head
                word), ``'mark'`` (mark copied onto all three nodes) and
                ``'wid'`` (numeric id given to the head).

        Returns:
            The new ``BinaryDependencyTree`` rooted at ``"amod"``.
        """
        wid = config['wid']
        # Previously written as ``wid - (wid-(wid-1))/2``, i.e. wid - 1/2.
        # The modifier gets an id half a step below the head's; presumably so
        # it sorts directly before the head when nodes are ordered by id
        # (TODO confirm against btreeToList).
        half_wid = wid - 0.5

        # "N" children and the 1024/1025 fourth positional argument are passed
        # exactly as before; their meaning is defined by BinaryDependencyTree,
        # which is not visible here (presumably "N" marks an absent child).
        left = BinaryDependencyTree(
            config['mod'], "N", "N", 1024,
            wid=half_wid, npos="JJ")
        right = BinaryDependencyTree(
            config['head'], "N", "N", 1024,
            wid=wid, npos="NN")
        tree = BinaryDependencyTree("amod", left, right, 1025)

        mark = config['mark']
        for node in (left, right, tree):
            node.mark = mark
        return tree

In [68]:
from database import *
from pattern.en import pluralize, singularize

# TODO: Verb Phrase Patterns
# 1.Intransitive: subject + VI
# 2.Linking: subject + VL + NP/AdjP
# 3.Transitive: subject + VT + NP
# 4.Ditransitive: subject + VD + NP(indirect) + NP(direct)
# 5.Complex Transitive: subject + VC + NP(direct) + NP/AdjP
# 6.Open clausal complement: VB + to/that VP
class PhrasalGenerator:
    """Generate variants of a dependency tree by deleting or inserting modifiers.

    Each variant is stored as a deep copy of the whole tree in ``treeLog``.
    The tree nodes are expected to expose ``val``, ``left``, ``right``,
    ``mark``, ``npos`` and ``id`` attributes (these are the only ones read or
    written here).

    Attributes (all set in ``__init__``):
        deptree: root of the tree currently being processed.
        length: value passed to ``deptree_generate``; stored, not used here.
        kb: knowledge base dict; ``kb["ADJ"]`` is the list of adjectives that
            ``generate_default`` may insert.
        tree_factory: ``TreeFactory`` used to build inserted subtrees.
        treeLog: list of generated tree snapshots.
        polarLog: list; not written by any method of this class.
        lexical_generation: relation name -> handler method.
    """

    def __init__(self):
        self.deptree = None
        self.length = 0
        self.kb = {}
        self.tree_factory = TreeFactory()
        self.treeLog = []
        self.polarLog = []
        self.lexical_generation = {
            "advmod": self.generate_advmod,
            "advmod:count": self.generate_advmod,
            "amod": self.generate_amod,
            "acl": self.generate_acl_relcl,
            "acl:relcl": self.generate_acl_relcl,
            "advcl": self.generate_acl_relcl,
            "cc:preconj": self.generate_default,
            "det": self.generate_default,
            "det:predet": self.generate_default,
            "nsubj": self.generate_default,
            "nsubj:pass": self.generate_default,
        }
        # Relations that are NOT handled yet. This list was previously kept in
        # a dead string literal; the handlers named on the right are not
        # defined in this class.
        #   appos, compound, compound:prt, conj, cop, fixed, flat, goeswith,
        #   iobj, mark, obl:tmod, parataxis      -> generate_inherite
        #   aux, aux:pass                        -> generate_aux
        #   case                                 -> generate_case
        #   cc                                   -> generate_cc
        #   ccomp                                -> generate_ccomp
        #   csubj, csubj:pass                    -> generate_nsubj
        #   dep                                  -> generate_dep
        #   discourse                            -> generate_discourse
        #   expl                                 -> generate_expl
        #   nmod, nmod:npmod, nmod:tmod          -> generate_nmod
        #   nmod:poss                            -> generate_nmod_poss
        #   nummod                               -> generate_nummod
        #   obj, obl, xcomp                      -> generate_obj
        #   obl:npmod                            -> generate_oblnpmod

    def deptree_generate(self, length, tree):
        """Reset ``treeLog`` and generate all variants of ``tree`` into it.

        Args:
            length: stored on ``self.length``.
            tree: root node; becomes ``self.deptree`` and is what every
                snapshot in ``treeLog`` is copied from.
        """
        self.treeLog = []
        self.deptree = tree
        self.length = length
        self.generate(self.deptree)

    def generate(self, tree):
        """Run the handler registered for ``tree.val``, if there is one.

        Nodes whose ``val`` is not a registered relation are ignored, which is
        also what ends the recursion started by ``generate_default``.
        """
        handler = self.lexical_generation.get(tree.val)
        if handler is not None:
            handler(tree)

    def save_tree(self, tree=None):
        """Return a deep copy of ``tree``, or of ``self.deptree`` when omitted."""
        return deepcopy(self.deptree if tree is None else tree)

    def delete_modifier(self, tree):
        """Overwrite ``tree`` in place with its right child.

        The left child is dropped; ``tree`` takes over the right child's
        ``val``, ``mark``, ``npos``, ``id`` and children.
        """
        kept = tree.right
        tree.val = kept.val
        tree.mark = kept.mark
        tree.npos = kept.npos
        tree.id = kept.id
        tree.left = kept.left
        tree.right = kept.right

    def rollback(self, tree, backup):
        """Restore ``tree`` in place from ``backup``.

        The children are restored as deep copies of ``backup``'s children, so
        node objects previously reachable through ``tree.left``/``tree.right``
        are no longer part of the tree afterwards.
        """
        tree.val = backup.val
        tree.left = deepcopy(backup.left)
        tree.right = deepcopy(backup.right)
        tree.mark = backup.mark
        tree.npos = backup.npos
        tree.id = backup.id

    def _generate_by_deletion(self, tree):
        """Log the whole tree with ``tree``'s left child removed, then undo.

        Only fires when the right child is marked "+" (presumably the
        upward-monotone polarity; the notebook output renders it as an
        up-arrow).
        """
        if tree.right.mark != "+":
            return
        backup = deepcopy(tree)
        try:
            self.delete_modifier(tree)
            self.treeLog.append(self.save_tree())
        finally:
            # Always undo the in-place edit, even if the snapshot fails.
            self.rollback(tree, backup)

    def generate_advmod(self, tree):
        "adv + VB | VB + adv => VB"
        self._generate_by_deletion(tree)

    def generate_amod(self, tree):
        "amod + Noun => Noun"
        self._generate_by_deletion(tree)

    def generate_acl_relcl(self, tree):
        "Noun + relcl => Noun"
        self._generate_by_deletion(tree)

    def generate_default(self, tree):
        """Insert adjectives before a "-"-marked noun, then recurse.

        If the right child has an ``npos`` containing "NN" and is marked "-",
        one variant of the whole tree is logged per adjective in
        ``self.kb["ADJ"]``, with the right child replaced by an ``amod``
        subtree (adjective + the noun's word). Afterwards both children are
        passed to ``generate``.
        """
        head = tree.right

        if head.npos is not None and "NN" in head.npos and head.mark == "-":
            # Read the instance's knowledge base, not a notebook global named
            # ``kb``; a knowledge base without adjectives yields no variants.
            for adj in self.kb.get("ADJ", ()):
                tree.right = self.tree_factory.buildTree(
                    {'rel': "amod",
                     'mod': adj,
                     'head': head.val,
                     'mark': "-",
                     'wid': head.id})
                try:
                    self.treeLog.append(self.save_tree())
                finally:
                    # Put the original node object back (not a copy) so the
                    # recursion below still operates on the live tree.
                    tree.right = head

        # Recurse through the tree's current children so that any edit made
        # further down is visible to save_tree(), which copies self.deptree.
        self.generate(tree.left)
        self.generate(tree.right)

In [69]:
import heapq

# Demo inputs: two sentences to polarize, and the adjectives that may be
# inserted. NOTE(review): "fragret" looks like a typo for "fragrant"; it is
# kept verbatim because the recorded output of this cell echoes it.
sentences = [
    "Some red flowers need light",
    "All flowers need light",
]
kb = {"ADJ": ["beautiful", "red", "fragret"]}

annotations, _ = polarization.run_polarize_pipeline(
    sentences, verbose=2, parser="stanza")

phrasalGenerator = PhrasalGenerator()
for annotation in annotations:
    # Each annotation unpacks into five fields; only the annotated string,
    # the original token sequence and the polarized tree are used below.
    annotated, original, polarized, postags, polarized_tree = annotation
    print("\n" + annotated)

    phrasalGenerator.kb = kb
    phrasalGenerator.deptree_generate(len(original), polarized_tree)

    for gen_tree in phrasalGenerator.treeLog:
        generated, queue, _, _ = btreeToList(gen_tree, len(original), {}, 0)
        # Drain the heap in priority order, collecting element 1 of each
        # popped item (the text that gets printed).
        annotated = []
        while len(queue) > 0:
            next_item = heapq.heappop(queue)
            annotated += [next_item[1]]
        print(" ".join(annotated))

100%|██████████| 2/2 [00:00<00:00,  7.04it/s]


some↑ red↑ flowers↑ need↑ light↑
some↑ flowers↑ need↑ light↑

all↑ flowers↓ need↑ light↑
all↑ beautiful↓ flowers↓ need↑ light↑
all↑ red↓ flowers↓ need↑ light↑
all↑ fragret↓ flowers↓ need↑ light↑





(1, 'all↑')