In [5]:
import sys
import math
from pqdict import pqdict

import torch
#from transformers import AutoTokenizer, AutoModelForSequenceClassification

#paraphraseTokenizer = AutoTokenizer.from_pretrained("textattack/roberta-base-MRPC")  
#paraphraseModel = AutoModelForSequenceClassification.from_pretrained("textattack/roberta-base-MRPC")

from sentence_transformers import SentenceTransformer, util
# Sentence-embedding model; AStarSearch stores it as `self.sbert` and uses it
# in `inference_sts` to embed premise/hypothesis pairs.
model_name = "roberta-base-nli-stsb-mean-tokens"
sentenceBERT = SentenceTransformer(model_name)
# Label names for the paraphrase classifier; `inference_mrpc` reports index 1.
classes = ["not paraphrase", "is paraphrase"]

def inference_mrpc(seq1s, seq2s):
    """Print the "is paraphrase" probability for each aligned sentence pair.

    Pair ``i`` is ``(seq1s[i], seq2s[i])``; the number of pairs is
    ``len(seq1s)``. Nothing is returned, the percentages are only printed.

    NOTE(review): this relies on the globals ``paraphraseTokenizer`` and
    ``paraphraseModel``, whose definitions are commented out in this cell, so
    calling it with a non-empty input raises ``NameError`` on a fresh kernel.
    """
    for idx in range(len(seq1s)):
        first, second = seq1s[idx], seq2s[idx]
        encoded = paraphraseTokenizer.encode_plus(
            first, second, return_tensors="pt")
        # The first element of the model output holds the logits.
        probabilities = torch.softmax(
            paraphraseModel(**encoded)[0], dim=1).tolist()[0]
        percent = round(probabilities[1] * 100)
        print(f"{classes[1]}: {percent}%")

In [8]:
from Udep2Mono.binarization import BinaryDependencyTree
from Udep2Mono.polarization import PolarizationPipeline
from Udep2Mono.util import btreeToList
from copy import deepcopy

from allennlp.predictors.predictor import Predictor
import allennlp_models.structured_prediction

# Open IE predictor; PhrasalGenerator.deptree_generate calls
# `ie_extractor.predict(sentence)['verbs']` on it. Loading fetches the model
# archive from the URL below, so this cell needs network access on first run.
ie_extractor = Predictor.from_path("https://storage.googleapis.com/allennlp-public-models/openie-model.2020.03.26.tar.gz")

ERROR:allennlp.common.plugins:Plugin allennlp_models could not be loaded: No module named 'transformers.models.bart'


HBox(children=(HTML(value='downloading'), FloatProgress(value=0.0, max=54185577.0), HTML(value='')))


✔ Download and installation successful
You can now load the model via spacy.load('en_core_web_sm')
  0%|          | 933k/1.31G [49:37<1164:01:41, 313B/s]


In [38]:
from pattern.en import pluralize, singularize

def fix_info(desc):
    """Turn a bracketed Open IE description into a list of token lists.

    The labels ``ARG0: ``, ``ARG1: `` and ``V: `` are removed, ``[`` is
    dropped and ``]`` becomes a comma; the text is then split on commas and
    every stripped piece is split on single spaces.

    Example: ``"[ARG0: no flowers] [V: need] [ARG1: light]"`` becomes
    ``[["no", "flowers"], ["need"], ["light"], [""]]`` (the final ``[""]``
    comes from the comma that replaces the last ``]``).
    """
    # Applied strictly in this order, exactly as chained replace() calls would.
    substitutions = (
        ("ARG0: ", ""),
        ("ARG1: ", ""),
        ("V: ", ""),
        ("[", ""),
        ("]", ","),
    )
    cleaned = desc
    for old, new in substitutions:
        cleaned = cleaned.replace(old, new)
    return [piece.strip().split(" ") for piece in cleaned.split(",")]

class PhrasalGenerator:
    """Generate sentence variants from a polarized binary dependency tree.

    Two kinds of variants are produced while walking the tree:

    * deletions - a modifier subtree is removed and a snapshot of the whole
      tree is appended to ``tree_log``;
    * insertions - a modifier from ``kb`` is attached to a word and a snapshot
      of the word dictionary is appended to ``sent_log``.

    Tree nodes are expected to expose ``val``, ``mark``, ``npos``, ``id``,
    ``left`` and ``right``. ``mark`` is compared against ``"+"`` and ``"-"``
    (presumably the polarity assigned by the polarization pipeline - confirm
    against Udep2Mono).
    """

    def __init__(self):
        self.deptree = None      # root of the tree currently being processed
        self.annotated = None    # word -> word id mapping for the sentence
        self.kb = {}             # modifier lists under "ADJ", "ADV", "RCL"
        # Created here so generate() never hits a missing attribute when it is
        # called before deptree_generate().
        self.ie_pred = {}        # verb -> fix_info() output of its description
        self.tree_log = []       # deep copies of the tree after deletions
        self.sent_log = []       # deep copies of `annotated` after insertions
        self.mod_at_left = [
            "advmod", "amod", "advmod:count", "acl:relcl",
            "acl", "advcl", "xcomp", "ccomp", "appos"]
        self.mod_at_right = []

        # Relations noted for future handling (not implemented in this class):
        #   "ccomp": self.generate_ccomp,
        #   "compound": self.generate_inherite,
        #   "compound:prt": self.generate_inherite,
        #   "cop": self.generate_inherite,
        #
        #   "expl": self.generate_expl,
        #   "nmod": self.generate_nmod,
        #   "nmod:npmod": self.generate_nmod,
        #   "nmod:tmod": self.generate_nmod,
        #   "nmod:poss": self.generate_nmod_poss,
        #
        #   "nummod": self.generate_nummod,
        #   "obl": self.generate_obj,
        #   "obl:npmod": self.generate_oblnpmod,
        #   "obl:tmod": self.generate_inherite,

    def deptree_generate(self, tree, annotated, original):
        """Reset the logs and generate every variant of ``tree``.

        Args:
            tree: root node of the polarized tree.
            annotated: mapping from word to word id for the same sentence.
            original: the raw sentence, passed to the Open IE extractor.
        """
        self.tree_log = []
        self.sent_log = []
        self.deptree = tree
        self.annotated = annotated
        self.ie_pred = {}
        verbs = ie_extractor.predict(original)['verbs']
        for verb in verbs:
            self.ie_pred[verb['verb']] = fix_info(verb['description'])
        self.generate(self.deptree)

    def generate(self, tree):
        """Dispatch on the node's relation label."""
        if tree.val in self.mod_at_left:
            self.left_modifier_generate(tree)
        elif tree.val == "conj" and tree.mark == "+":
            self.generate_conj(tree)
        elif tree.left != "N":
            # Nodes whose left child is the string "N" are not expanded
            # (presumably leaves - confirm against BinaryDependencyTree).
            self.generate_default(tree)

    def delete_left_modifier(self, tree):
        """Overwrite ``tree`` in place with its right child, dropping the left."""
        tree.val = tree.right.val
        tree.mark = tree.right.mark
        tree.npos = tree.right.npos
        tree.id = tree.right.id
        tree.left = tree.right.left
        tree.right = tree.right.right

    def rollback(self, tree, backup):
        """Restore ``tree`` in place from ``backup`` (children are deep-copied)."""
        tree.val = backup.val
        tree.left = deepcopy(backup.left)
        tree.right = deepcopy(backup.right)
        tree.mark = backup.mark
        tree.npos = backup.npos
        tree.id = backup.id

    def left_modifier_generate(self, tree):
        """Log the tree with the left modifier removed, then recurse.

        Patterns covered:
            adv + VB | VB + adv => VB
            amod + Noun         => Noun
            Noun + relcl        => Noun
        """
        backup = deepcopy(tree)

        if tree.right.mark == "+":
            self.delete_left_modifier(tree)
            self.save_tree(isTree=True)
            self.rollback(tree, backup)

        self.generate(tree.left)
        self.generate(tree.right)

    def rollback_annotation(self, generated, original):
        """Rename the key ``generated`` back to ``original`` in ``annotated``."""
        word_id = self.annotated[generated]
        del self.annotated[generated]
        self.annotated[original] = word_id

    def generate_conj(self, tree):
        """Handle a ``conj`` node marked "+".

        First behaves like a left-modifier node (logging the tree reduced to
        its right child), then logs the tree with this node replaced by the
        right child of its left child, and finally restores the node.
        """
        backup = deepcopy(tree)
        self.left_modifier_generate(tree)

        tree.val = tree.left.right.val
        tree.mark = tree.left.right.mark
        tree.npos = tree.left.right.npos
        tree.id = tree.left.right.id
        tree.left = backup.left.right.left
        tree.right = backup.left.right.right
        self.save_tree(isTree=True)
        self.rollback(tree, backup)

    def add_modifier(self, tree, mod, head, direct=0):
        """Log the sentence with ``mod`` attached to ``head``.

        Args:
            tree: unused; kept for interface compatibility.
            mod: modifier text to attach.
            head: word (a key of ``annotated``) that receives the modifier.
            direct: 0 puts the modifier before the head, anything else after.

        Raises:
            KeyError: if ``head`` is not a key of ``annotated``.
        """
        if direct == 0:
            generated = ' '.join([mod, head])
        else:
            generated = ' '.join([head, mod])
        # Temporarily rename the head's entry, snapshot, then undo the rename.
        word_id = self.annotated[head]
        del self.annotated[head]
        self.annotated[generated] = word_id
        self.save_tree(isTree=False)
        self.rollback_annotation(generated, head)

    def _arg1_head(self, verb):
        """Return the last ARG1 word recorded for ``verb``, or None.

        None is returned when the verb is unknown to the Open IE output, when
        its description has no third chunk, or when the word is not a key of
        ``annotated`` (e.g. the empty string produced for a verb without an
        object) - in all of those cases there is nothing to attach to.
        """
        description = self.ie_pred.get(verb)
        if description is None or len(description) < 3 or not description[2]:
            return None
        word = description[2][-1]
        return word if word in self.annotated else None

    def generate_default(self, tree):
        """Insert knowledge-base modifiers at the right child, then recurse.

        Only a right child marked "-" is modified: nouns ("NN" in ``npos``)
        get every "ADJ" in front and every "RCL" behind; verbs ("VB" in
        ``npos``) get every "ADV" in front of the verb and, when an ARG1 word
        is available, behind that word. Missing ``kb`` categories are treated
        as empty.
        """
        left = tree.left
        right = tree.right

        if right.npos is not None and right.mark == "-":
            if "NN" in right.npos:
                for adj in self.kb.get("ADJ", ()):
                    self.add_modifier(tree, adj, right.val)
                for rel in self.kb.get("RCL", ()):
                    self.add_modifier(tree, rel, right.val, 1)
            elif "VB" in right.npos:
                # Same for every adverb, so resolve it once up front.
                arg1 = self._arg1_head(right.val)
                for adv in self.kb.get("ADV", ()):
                    self.add_modifier(tree, adv, right.val)
                    if arg1 is not None:
                        self.add_modifier(tree, adv, arg1, 1)

        self.generate(left)
        self.generate(right)

    def save_tree(self, isTree):
        """Append a deep copy of the tree (``isTree``) or of ``annotated``."""
        if isTree:
            self.tree_log.append(deepcopy(self.deptree))
        else:
            self.sent_log.append(deepcopy(self.annotated))

In [209]:
class AStarSearch:
    """Best-first search from a premise sentence towards a hypothesis.

    Nodes are sentences. A node's cost is ``1 - similarity(node, hypothesis)``
    where the similarity is the product of an embedding cosine score and a
    word-overlap ratio (see ``inference_sts``). Successors come from
    ``PhrasalGenerator`` applied to the polarized parse of the node.
    """

    # A popped node whose cost is below this value counts as the goal.
    GOAL_TOLERANCE = 9.0e-07

    # Knowledge base installed when query() is called without ``kb``.
    # Spellings are kept verbatim: they are runtime data that example
    # hypotheses are written against.
    DEFAULT_KB = {
        "ADJ": ["beautiful", "red", "fragret"],
        "ADV": ["ergently", "clearly", "neccesaraly"],
        "RCL": ["which is beautiful", "which opens at night"],
    }

    def __init__(self):
        # (sentence, cost) pairs of expanded nodes; kept in this shape because
        # it is inspected directly after a query.
        self.closed_forward = set()
        # Bare sentences of the same nodes, for O(1) "already expanded" checks.
        self._closed_premises = set()
        self.entailments = set()
        self.contradictions = set()
        self.hypothesis = ""
        self.phrasalGenerator = PhrasalGenerator()
        self.sbert = sentenceBERT

    def generate_premises(self, start):
        """Fill ``self.entailments`` with the sentences generated from ``start``."""
        self.entailments.clear()
        self.contradictions.clear()
        pipeline = PolarizationPipeline(
            [start], verbose=0, parser="stanza")
        pipeline.run_polarize_pipeline()

        for annotation in pipeline.annotations:
            self.phrasalGenerator.deptree_generate(
                annotation['polarized_tree'],
                annotation['word_dict'],
                annotation['original'])
            for gen_tree in self.phrasalGenerator.tree_log:
                generated, queue, _, _ = btreeToList(
                    gen_tree, len(annotation['original']), {}, 0)
                # NOTE(review): the sentence rebuilt from a deletion tree is
                # computed but never added to `entailments`; confirm whether
                # `self.entailments.add(' '.join(annotated))` was intended.
                annotated = list(queue.popkeys())
            for gen_sent in self.phrasalGenerator.sent_log:
                # popkeys() yields the words in priority (word id) order.
                self.entailments.add(' '.join(gen_sent.popkeys()))

    def word_similarity(self, s1, s2):
        """Return the fraction of words of ``s1`` that occur as words in ``s2``.

        Both sentences are split on single spaces. Matching is on whole words,
        so "a" does not match "cat".
        """
        seq1 = s1.split(" ")
        vocabulary = set(s2.split(" "))
        num_sim = sum(1 for w in seq1 if w in vocabulary)
        # split(" ") never returns an empty list, so the division is safe.
        return num_sim / len(seq1)

    def inference_sts(self, seqs1, seqs2):
        """Score the first pair ``(seqs1[0], seqs2[0])``.

        The score is the cosine similarity of the two sentence embeddings
        multiplied by ``word_similarity``. Only the first pair is scored,
        because callers pass one-element lists and expect a single number.

        Returns:
            The score for the first pair, or None when ``seqs1`` is empty.
        """
        embeddings1 = self.sbert.encode(seqs1, convert_to_tensor=True)
        embeddings2 = self.sbert.encode(seqs2, convert_to_tensor=True)
        cosine_scores = util.pytorch_cos_sim(embeddings1, embeddings2)
        if len(seqs1) == 0:
            return None
        semantic = cosine_scores[0][0].data.cpu().numpy()
        lexical = self.word_similarity(seqs1[0], seqs2[0])
        return semantic * lexical

    def generate_motion(self, start, opened):
        """Expand ``start`` and merge its successors into ``opened``.

        Successors that were already expanded are skipped; for the rest the
        lowest cost seen so far is kept in ``opened``.
        """
        self.generate_premises(start)
        for premise in self.entailments:
            if premise in self._closed_premises:
                continue
            new_cost = 1 - self.inference_sts([premise], [self.hypothesis])
            if premise not in opened or new_cost < opened[premise]:
                opened[premise] = new_cost

    def query(self, premises, hypothesis, kb=None):
        """Search for a path from ``premises`` to ``hypothesis``.

        Args:
            premises: the starting sentence (a single string).
            hypothesis: the target sentence; it is lower-cased before use.
            kb: optional modifier lists under the keys "ADJ", "ADV" and
                "RCL". Defaults to a copy of ``DEFAULT_KB``.

        Returns:
            True if a node within ``GOAL_TOLERANCE`` of the hypothesis is
            reached, False once the open list is exhausted.
        """
        self.closed_forward.clear()
        self._closed_premises.clear()
        self.hypothesis = hypothesis.lower()

        if kb is None:
            kb = {name: list(words) for name, words in self.DEFAULT_KB.items()}
        self.phrasalGenerator.kb = kb

        open_lists = pqdict({})
        # Score against the normalised hypothesis, like every later node.
        open_lists[premises] = 1 - self.inference_sts(
            [premises], [self.hypothesis])

        while open_lists:
            optimal = open_lists.popitem()
            print("\nOptimal: ", optimal)
            if optimal[1] < self.GOAL_TOLERANCE:
                self.closed_forward.add((self.hypothesis, 0.0000))
                self._closed_premises.add(self.hypothesis)
                return True
            # Close before expanding so the node cannot re-open itself.
            self.closed_forward.add(optimal)
            self._closed_premises.add(optimal[0])
            self.generate_motion(optimal[0], open_lists)
        return False

In [210]:
# One searcher instance; it keeps a reference to the already-loaded
# sentenceBERT model and owns its own PhrasalGenerator.
search = AStarSearch()

In [211]:
# Demo: can the premise be rewritten into the hypothesis? The modifiers "red"
# and "ergently" come from the knowledge base that AStarSearch.query installs.
search.query("no flowers need light", "no red flowers need light ergently")

100%|██████████| 1/1 [00:00<00:00, 10.64it/s]
Optimal:  ('no flowers need light', 0.19805026054382324)

100%|██████████| 1/1 [00:00<00:00, 12.20it/s]
Optimal:  ('no flowers need light ergently', 0.07331925630569458)


Optimal:  ('no red flowers need light ergently', 2.384185791015625e-07)


True

In [212]:
# Nodes recorded during the last query, as (sentence, cost) tuples.
search.closed_forward

{('no flowers need light', 0.19805026054382324),
 ('no flowers need light ergently', 0.07331925630569458),
 ('no red flowers need light ergently', 0.0)}