In [1]:
from pqdict import pqdict

import torch
#from transformers import AutoTokenizer, AutoModelForSequenceClassification

#paraphraseTokenizer = AutoTokenizer.from_pretrained("textattack/roberta-base-MRPC")  
#paraphraseModel = AutoModelForSequenceClassification.from_pretrained("textattack/#roberta-base-MRPC")

from sentence_transformers import SentenceTransformer, util
model_name = "roberta-large-nli-stsb-mean-tokens"
sentenceBERT = SentenceTransformer(model_name)
classes = ["not paraphrase", "is paraphrase"]

def inference_mrpc(seq1s, seq2s):
    for i in range(len(seq1s)):
        paraphrase = paraphraseTokenizer.encode_plus(
            seq1s[i], seq2s[i], return_tensors="pt")
        logits = paraphraseModel(**paraphrase)[0]
        paraphrase_results = torch.softmax(logits, dim=1).tolist()[0]
        print(f"{classes[1]}: {round(paraphrase_results[1] * 100)}%")

In [2]:
from copy import deepcopy
from Udep2Mono.util import btree2list
from Udep2Mono.binarization import BinaryDependencyTree
from Udep2Mono.polarization import PolarizationPipeline
from allennlp.predictors.predictor import Predictor
import allennlp_models.structured_prediction

ie_extractor = Predictor.from_path("https://storage.googleapis.com/allennlp-public-models/openie-model.2020.03.26.tar.gz")

INFO:stanza:Loading these models for language: en (English):
| Processor | Package   |
-------------------------
| tokenize  | gum       |
| pos       | gum       |
| lemma     | gum       |
| depparse  | gum       |
| sentiment | sstplus   |
| ner       | ontonotes |

INFO:stanza:Use device: gpu
INFO:stanza:Loading: tokenize
INFO:stanza:Loading: pos
INFO:stanza:Loading: lemma
INFO:stanza:Loading: depparse
INFO:stanza:Loading: sentiment
INFO:stanza:Loading: ner
INFO:stanza:Done loading processors!


In [3]:
from pattern.en import pluralize, singularize
from copy import copy

def fix_info(desc):
    out = desc.replace("ARG0: ", "")
    out = out.replace("ARG1: ", "")
    out = out.replace("V: ", "")
    out = out.replace("[", "")
    out = out.replace("]", ",")
    out = out.split(",")
    out = list(map(lambda x: x.strip(), out))
    return out

class PhrasalGenerator:
    def __init__(self):
        self.deptree = None
        self.annotated = None
        self.original = None
        self.kb = {}
        self.tree_log = []
        self.sent_log = []
        self.mod_at_left = [
            "advmod", "amod", "advmod:count", "acl:relcl", 
            "acl", "advcl", "xcomp", "ccomp", "appos"]
        self.mod_at_right = []
        
        '''
            "ccomp": self.generate_ccomp,
            "compound": self.generate_inherite,
            "compound:prt": self.generate_inherite,
            "cop": self.generate_inherite,
            
            "expl": self.generate_expl,
            "nmod": self.generate_nmod,
            "nmod:npmod": self.generate_nmod,
            "nmod:tmod": self.generate_nmod,
            "nmod:poss": self.generate_nmod_poss,
            
            "nummod": self.generate_nummod,
            "obl": self.generate_obj,
            "obl:npmod": self.generate_oblnpmod,
            "obl:tmod": self.generate_inherite,
        '''

    def deptree_generate(self, tree, annotated, original):
        self.tree_log = []
        self.sent_log = []
        self.deptree = tree
        self.original = original
        self.annotated = annotated
        self.ie_pred = {}
        verbs = ie_extractor.predict(original)['verbs']
        for verb in verbs:
            self.ie_pred[verb['verb']] = fix_info(verb['description'])        
        self.generate(self.deptree)

    def generate(self, tree):
        if tree.val in self.mod_at_left:
            self.left_modifier_generate(tree)
        elif tree.val == "conj" and tree.mark == "+":
            self.generate_conj(tree)
        elif tree.is_tree:
            self.generate_default(tree)

    def delete_left_modifier(self, tree):
        tree.val = tree.right.val
        tree.mark = tree.right.mark
        tree.pos = tree.right.pos
        tree.id = tree.right.id
        
        tree.is_tree = tree.right.is_tree
        tree.is_root = tree.right.is_root

        tree.left = tree.right.left
        tree.right = tree.right.right

    def rollback(self, tree, backup):
        tree.val = backup.val
        tree.left = deepcopy(backup.left)
        tree.right = deepcopy(backup.right)
        tree.mark = backup.mark
        tree.pos = backup.pos
        tree.id = backup.id
        tree.is_tree = backup.is_tree
        tree.is_root = backup.is_root

    def left_modifier_generate(self, tree):
        left = tree.left
        right = tree.right
        backup = deepcopy(tree)

        if right.mark == "+" or right.mark == "=":
            self.delete_left_modifier(tree)
            self.save_tree(isTree=True)
            self.rollback(tree, backup)   
        
        self.generate(tree.left)
        self.generate(tree.right)

    def rollback_annotation(self, generated, original):
        word_id = self.annotated[generated]
        del self.annotated[generated]
        self.annotated[original] = word_id

    def generate_conj(self, tree):
        backup = deepcopy(tree)
        self.left_modifier_generate(tree)

        tree.val = tree.left.right.val
        tree.mark = tree.left.right.mark
        tree.pos = tree.left.right.pos
        tree.id = tree.left.right.id
        tree.left = backup.left.right.left
        tree.right = backup.left.right.right
        self.save_tree(isTree=True)
        self.rollback(tree, backup)

    def add_modifier(self, tree, mod, head, direct=0):
        if direct == 0:
            generated = ' '. join([mod, head])
        else:
            generated = ' '. join([head, mod])

        orig_key = (head, tree.pos, tree.mark)
        gen_key = (generated, tree.pos, tree.mark)

        word_id = self.annotated[orig_key]
        del self.annotated[orig_key]
        self.annotated[gen_key] = word_id
        self.save_tree(isTree=False)
        self.rollback_annotation(gen_key, orig_key)

    def generate_default(self, tree):
        left = tree.left
        right = tree.right

        if right.pos is not None:
            if "NN" in right.pos and right.mark == "-":
                for adj in self.kb["ADJ"]:
                    self.add_modifier(right, adj, right.val)
                for rel in self.kb["RCL"]:
                    self.add_modifier(right, rel, right.val, 1)
            elif "VB" in right.pos and right.mark == "-":
                for adv in self.kb["ADV"]:
                    self.add_modifier(right, adv, right.val)
                    self.add_modifier(right, adv, right.val, 1)
                    description = self.ie_pred[right.val]
                    arg1 = description[2]
                    self.sent_log.append(
                        self.original.replace(
                            arg1, ' '.join([arg1, adv])))

        self.generate(tree.left)
        self.generate(tree.right)  

    def save_tree(self, isTree):
        if isTree:
            #leaves = self.deptree.sorted_leaves().popkeys()
            #sentence = ' '.join([x[0] for x in leaves])
            #print(sentence)
            self.tree_log.append(self.deptree.copy())
            #leaves = copy(self.deptree).sorted_leaves().popkeys()
            #sentence = ' '.join([x[0] for x in leaves])
            #print(sentence)
        else:
            annotated_cp = deepcopy(self.annotated)
            self.sent_log.append(
                ' '.join([word[0] for word in list(annotated_cp.popkeys())]))

In [4]:
from pymongo import MongoClient
from wordnet import *
from Udep2Mono.util import det_mark, det_type

client = MongoClient('localhost', 27017)
db = client.UdepLog
quantifier = db.quantifier

In [5]:
class LexicalGenerator:
    def __init__(self):
        self.postags = None
        self.hypothesis = ""
        self.sent_log = []
        self.replacement_log = []
        self.key_tokens = [
            'NN','NNS','NNP','NNPS','VBD',
            'VBG','VBN','VBZ','VB']

    def filter_words(self, word_set):
        filtered = []
        for word in word_set:
            if word in self.hypothesis:
                filtered.append(word)
        return filtered

    def get_word_knowledge(self, word):
        hyper = [] 
        hypo = [] 
        syn = [] 
        ant = []
        hyper, hypo, syn, ant = get_word_sets(
            singularize(word[0]), word[1].lower())
        hyper_fil = self.filter_words(hyper)
        hypo_fil = self.filter_words(hypo)
        syn_fil = self.filter_words(syn)
        ant_fil = self.filter_words(ant)
        return hyper_fil, hypo_fil, syn_fil, ant_fil

    def replace_token(self, orig, word_set):
        prev = orig
        for word in word_set:
            if prev[0] == word[0]:
                continue
            item = self.postags[prev]
            del self.postags[prev]
            self.postags[(word, prev[1], prev[2])] = item
            postags_cp = deepcopy(self.postags)
            self.sent_log.append(
                ' '.join([word[0] for word in list(postags_cp.popkeys())]))
            self.replacement_log.append(
                "{} => {}".format(prev[0], word))
            del self.postags[(word, prev[1], prev[2])]
            self.postags[prev] = item
            prev = (word, prev[1], prev[2])

    def generate(self, postags):  
        self.postags = deepcopy(postags)  
        for word in postags:    
            if word[1] in self.key_tokens:
                hyper, hypo, syn, ant = self.get_word_knowledge(word)
                self.replace_token(word, syn)
                self.replace_token(word, ant)

                if word[2] == "+":                
                    self.replace_token(word, hyper)
                if word[2] == "-":
                    self.replace_token(word, hypo)
        
            elif word[1] == "DET":    
                kb = quantifier.find({"word": tree.left.val.lower()})[0]
                self.replace_token(tree.left.val, kb["="])
                if word[2] == "+":
                    self.replace_token(tree.left.val, kb["<"])   
                if word[2] == "-":
                    self.replace_token(tree.left.val, kb[">"])

In [6]:
from pqdict import nlargest
import time

class AStarSearch:
    def __init__(self):    
        self.closed_forward = set()                   
        self.entailments = set()
        self.contradictions = set()
        self.hypothesis = ""
        self.sbert = sentenceBERT
        self.pipeline = PolarizationPipeline()
        self.phrasalGenerator = PhrasalGenerator()
        self.lexicalGenerator = LexicalGenerator()

    def generate_premises(self, start):
        self.entailments.clear()
        self.contradictions.clear()

        #tic = time.perf_counter()
        annotation = self.pipeline.single_polarization(start)
        #toc = time.perf_counter()
        #print(f"Polarize 1 Premise: {toc - tic:0.4f} seconds")
        #print("\nPolarization Complete")
        
        #print("\n====================================")
        #print("\nInit Premise: " + annotation['annotated'])
        #polarized = annotation['polarized']
        #btreeViz = Tree.fromstring(polarized.replace('[', '(').replace(']', ')'))
        #jupyter_draw_nltk_tree(btreeViz) 
        #print(annotation['word_dict'])

        #self.lexicalGenerator.generate(annotation['annotated'])
        #for gen_sent in self.lexicalGenerator.sent_log:
        #    self.entailments.add(gen_sent)
        #toc = time.perf_counter()
        #print(f"Lexical Generate 1 Premise: {toc - tic:0.4f} seconds")
        
        self.phrasalGenerator.deptree_generate(
            annotation['polarized_tree'], 
            annotation['annotated'], 
            annotation['original'])
        for gen_tree in self.phrasalGenerator.tree_log:
            leaves = gen_tree.sorted_leaves().popkeys()
            sentence = ' '.join([x[0] for x in leaves])
            print(sentence)
            self.entailments.add(sentence)
            #print("\nNext Premise: " + ' '.join(annotated))
        for gen_sent in self.phrasalGenerator.sent_log:
            #print(gen_sent)
            self.entailments.add(gen_sent)
            #print("\nNext Premise: " + ' '.join(gen_sent.popkeys()))
        #toc = time.perf_counter()
        #print(f"Phrasla Generate 1 Premise: {toc - tic:0.4f} seconds")

    def word_similarity(self, s1, s2):
        num_sim = 0
        seq1 = s1.split(" ")
        for w in seq1:
            if w in s2:
                num_sim += 1
        return num_sim / len(seq1)

    def inference_sts(self, seqs1, seqs2):
        embeddings1 = self.sbert.encode(seqs1, convert_to_tensor=True)
        embeddings2 = self.sbert.encode(seqs2, convert_to_tensor=True)
        cosine_scores = util.pytorch_cos_sim(embeddings1, embeddings2)
        for i in range(len(seqs1)):
            cost1 = cosine_scores[i][i].data.cpu().numpy()
            cost2 = self.word_similarity(seqs1[i], seqs2[i])
            cost = cost1 * cost2
            return cost

    def generate_motion(self, starts, opened):
        for start in starts:
            #tic = time.perf_counter()
            self.generate_premises(start)
            #toc = time.perf_counter()
            #print(f"Generate Premises: {toc - tic:0.4f} seconds")

            for premise in self.entailments:
                #print(premise)
                if premise in self.closed_forward:
                    continue
                cost = self.inference_sts([premise], [self.hypothesis])
                if premise not in opened:
                    opened[premise] = 1 - cost
                if (1-cost) < opened[premise]:
                    opened[premise] = 1-cost

            #toc = time.perf_counter()
            #print(f"Handled Premises: {toc - tic:0.4f} seconds")

    def query(self, premises, hypothesis):
        self.closed_forward.clear()
        self.hypothesis = hypothesis.lower()

        kb = {"ADJ": ["beautiful", "red", "fragret"], 
              "ADV": ["ergently", "clearly", "neccesaraly"],
              "RCL": ["which is beautiful", "which opens at night"]}
        self.phrasalGenerator.kb = kb
        self.lexicalGenerator.hypothesis = self.hypothesis

        open_lists = pqdict({})
        open_lists[premises] = 1 - self.inference_sts([premises], [hypothesis])

        while open_lists:
            top_k = 5
            optimals = []
            for _ in range(5):
                if len(open_lists) > 0:
                    optimals.append(open_lists.popitem())
            print("\nOptimals: ", optimals)
            for optimal in optimals:
                if optimal[1] < 9.0e-07:
                    self.closed_forward.add((self.hypothesis, 0.0000))
                    return True
            #tic = time.perf_counter()
            self.generate_motion([x[0] for x in optimals], open_lists)
            #toc = time.perf_counter()
            #print(f"Handle New Premises: {toc - tic:0.4f} seconds")
            self.closed_forward.add(optimal)

In [7]:
search = AStarSearch()

In [8]:
search.query("A brown dog is attacking another animal in front of the tall man in pants", "A dog is attacking another animal in front of the man in pants")


Optimals:  [('A brown dog is attacking another animal in front of the tall man in pants', 0.2218152602513631)]
a dog is attacking another animal in front of the tall man in pants
a brown dog is attacking another animal in front of the man in pants

Optimals:  [('a dog is attacking another animal in front of the tall man in pants', 0.12686695371355328), ('a brown dog is attacking another animal in front of the man in pants', 0.1375848778656551)]
a dog is attacking another animal in front of the man in pants
a dog is attacking another animal in front of the man in pants

Optimals:  [('a dog is attacking another animal in front of the man in pants', 2.384185791015625e-07)]


True

In [9]:
search.closed_forward

{('A brown dog is attacking another animal in front of the tall man in pants',
  0.2218152602513631),
 ('a brown dog is attacking another animal in front of the man in pants',
  0.1375848778656551),
 ('a dog is attacking another animal in front of the man in pants', 0.0)}