In [2]:
import sys
import math
from pqdict import pqdict

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Tokenizer and sequence-classification model, both loaded from the
# "textattack/roberta-base-MRPC" checkpoint. inference_mrpc uses them.
paraphraseTokenizer = AutoTokenizer.from_pretrained("textattack/roberta-base-MRPC")  
paraphraseModel = AutoModelForSequenceClassification.from_pretrained("textattack/roberta-base-MRPC")

from sentence_transformers import SentenceTransformer, util

# Display labels; inference_mrpc prints the score at index 1 under classes[1].
classes = ["not paraphrase", "is paraphrase"]

def inference_mrpc(seq1s, seq2s):
    """Print the paraphrase probability for each aligned sentence pair.

    Args:
        seq1s: Sequence of first sentences.
        seq2s: Sequence of second sentences, indexed in step with ``seq1s``
            (an ``IndexError`` is raised if it is shorter).

    Returns:
        None. One line per pair is printed, labelled with ``classes[1]`` and
        showing the softmax score at index 1 as a rounded percentage.
    """
    # Pure inference: disable autograd so no graph is built or retained for
    # the forward passes. The printed numbers are unaffected.
    with torch.no_grad():
        for i in range(len(seq1s)):
            paraphrase = paraphraseTokenizer.encode_plus(
                seq1s[i], seq2s[i], return_tensors="pt")
            # The first element of the model output holds the logits.
            logits = paraphraseModel(**paraphrase)[0]
            paraphrase_results = torch.softmax(logits, dim=1).tolist()[0]
            print(f"{classes[1]}: {round(paraphrase_results[1] * 100)}%")

In [4]:
from Udep2Mono.binarization import BinaryDependencyTree
from Udep2Mono.polarization import PolarizationPipeline
from Udep2Mono.util import btreeToList
from copy import deepcopy

from allennlp.predictors.predictor import Predictor
import allennlp_models.structured_prediction

# Predictor loaded from the AllenNLP OpenIE model archive.
# PhrasalGenerator.deptree_generate calls ie_extractor.predict(...) and reads
# the 'verbs' entries ('verb' and 'description') of the result.
ie_extractor = Predictor.from_path("https://storage.googleapis.com/allennlp-public-models/openie-model.2020.03.26.tar.gz")

2020-12-25 23:38:42 INFO: Loading these models for language: en (English):
| Processor | Package   |
-------------------------
| tokenize  | gum       |
| pos       | gum       |
| lemma     | gum       |
| depparse  | gum       |
| sentiment | sstplus   |
| ner       | ontonotes |

2020-12-25 23:38:42 INFO: Use device: cpu
2020-12-25 23:38:42 INFO: Loading: tokenize
2020-12-25 23:38:42 INFO: Loading: pos
2020-12-25 23:38:43 INFO: Loading: lemma
2020-12-25 23:38:43 INFO: Loading: depparse
2020-12-25 23:38:44 INFO: Loading: sentiment
2020-12-25 23:38:45 INFO: Loading: ner
2020-12-25 23:38:46 INFO: Done loading processors!


In [5]:
from pattern.en import pluralize, singularize

def fix_info(desc):
    """Turn a bracketed role description into a list of word lists.

    The labels ``"ARG0: "``, ``"ARG1: "`` and ``"V: "`` are removed, every
    ``"["`` is dropped and every ``"]"`` becomes a comma. The text is then
    split on commas, each piece is stripped, and each piece is split on
    single spaces.

    Args:
        desc: Description string, e.g.
            ``"[ARG0: All flowers] [V: need] [ARG1: light]"``.

    Returns:
        A list with one list of words per comma-separated piece. A closing
        bracket at the end of ``desc`` therefore yields a trailing ``[""]``.
    """
    # Substitutions are applied strictly in this order.
    substitutions = (
        ("ARG0: ", ""),
        ("ARG1: ", ""),
        ("V: ", ""),
        ("[", ""),
        ("]", ","),
    )
    cleaned = desc
    for old, new in substitutions:
        cleaned = cleaned.replace(old, new)

    return [piece.strip().split(" ") for piece in cleaned.split(",")]

# TODO: Verb Phrase Patterns
# 1.Intransitive: subject + VI
# 2.Linking: subject + VL + NP/AdjP
# 3.Transitive: subject + VT + NP
# 4.Ditransitive: subject + VD + NP(indirect) + NP(direct)
# 5.Complex Transitive: subject + VC + NP(direct) + NP/AdjP
# 6.Open clausal complement: VB + to/that VP
class PhrasalGenerator:
    """Generates variants of a polarized dependency tree / annotated sentence.

    Two logs are filled while walking the tree:

    * ``tree_log`` -- deep copies of the whole tree (``self.deptree``) taken
      right after a node was temporarily replaced by one of its descendants
      (modifier or conjunct removal on nodes whose mark is ``"+"``).
    * ``sent_log`` -- deep copies of ``self.annotated`` taken right after a
      knowledge-base modifier was temporarily attached to a head word whose
      node mark is ``"-"``.

    Nodes are expected to expose ``val``, ``left``, ``right``, ``mark``,
    ``npos`` and ``id``. ``generate`` stops descending when ``left`` equals
    the string ``"N"`` (presumably the leaf sentinel of Udep2Mono's
    ``BinaryDependencyTree`` -- TODO confirm).
    """

    def __init__(self):
        self.deptree = None     # root of the tree currently being processed
        self.annotated = None   # mapping whose keys are words/phrases
        self.kb = {}            # modifier lists under "ADJ", "ADV", "RCL"
        self.ie_pred = {}       # verb -> fix_info(...) of its OpenIE description
        self.tree_log = []
        self.sent_log = []
        # Relation labels whose node is handled by left_modifier_generate.
        self.mod_at_left = [
            "advmod", "amod", "advmod:count", "acl:relcl",
            "acl", "advcl", "xcomp", "ccomp", "appos"]
        self.mod_at_right = []

        # Planned relation -> handler table (none of these handlers exist yet):
        #     "ccomp": self.generate_ccomp,
        #     "compound": self.generate_inherite,
        #     "compound:prt": self.generate_inherite,
        #     "cop": self.generate_inherite,
        #
        #     "expl": self.generate_expl,
        #     "nmod": self.generate_nmod,
        #     "nmod:npmod": self.generate_nmod,
        #     "nmod:tmod": self.generate_nmod,
        #     "nmod:poss": self.generate_nmod_poss,
        #
        #     "nummod": self.generate_nummod,
        #     "obl": self.generate_obj,
        #     "obl:npmod": self.generate_oblnpmod,
        #     "obl:tmod": self.generate_inherite,

    def deptree_generate(self, tree, annotated, original):
        """Reset the logs and generate variants for one sentence.

        Args:
            tree: Root node of the polarized tree.
            annotated: Mapping keyed by word; it is mutated and restored
                while variants are produced.
            original: Raw sentence text, passed to ``ie_extractor.predict``.
        """
        self.tree_log = []
        self.sent_log = []
        self.deptree = tree
        self.annotated = annotated
        self.ie_pred = {}
        verbs = ie_extractor.predict(original)['verbs']
        for verb in verbs:
            self.ie_pred[verb['verb']] = fix_info(verb['description'])
        self.generate(self.deptree)

    def generate(self, tree):
        """Dispatch on the node label; nodes whose ``left`` is "N" are skipped."""
        if tree.val in self.mod_at_left:
            self.left_modifier_generate(tree)
        elif tree.val == "conj" and tree.mark == "+":
            self.generate_conj(tree)
        elif tree.left != "N":
            self.generate_default(tree)

    def delete_left_modifier(self, tree):
        """Overwrite ``tree`` in place with its right child (left child is dropped)."""
        tree.val = tree.right.val
        tree.mark = tree.right.mark
        tree.npos = tree.right.npos
        tree.id = tree.right.id
        tree.left = tree.right.left
        # Must be last: every line above still reads the old tree.right.
        tree.right = tree.right.right

    def rollback(self, tree, backup):
        """Restore ``tree`` from ``backup``; children are deep-copied so the
        backup itself is never aliased into the live tree."""
        tree.val = backup.val
        tree.left = deepcopy(backup.left)
        tree.right = deepcopy(backup.right)
        tree.mark = backup.mark
        tree.npos = backup.npos
        tree.id = backup.id

    def left_modifier_generate(self, tree):
        """Log the tree with this node's left child removed, then recurse.

        The removal is only performed when the right child's mark is "+",
        and it is undone immediately after the snapshot is saved.
        """
        # adv + VB | VB + adv => VB
        # amod + Noun => Noun
        # Noun + relcl => Noun
        backup = deepcopy(tree)

        if tree.right.mark == "+":
            self.delete_left_modifier(tree)
            self.save_tree(isTree=True)
            self.rollback(tree, backup)

        self.generate(tree.left)
        self.generate(tree.right)

    def rollback_annotation(self, generated, original):
        """Rename the key ``generated`` in ``self.annotated`` back to ``original``."""
        word_id = self.annotated[generated]
        del self.annotated[generated]
        self.annotated[original] = word_id

    def generate_conj(self, tree):
        """Handle a "conj" node marked "+".

        The node is first processed like a left-modifier node. It is then
        temporarily replaced by ``tree.left.right``, the tree is logged, and
        the node is restored from the backup.
        """
        backup = deepcopy(tree)
        self.left_modifier_generate(tree)

        tree.val = tree.left.right.val
        tree.mark = tree.left.right.mark
        tree.npos = tree.left.right.npos
        tree.id = tree.left.right.id
        tree.left = backup.left.right.left
        tree.right = backup.left.right.right
        self.save_tree(isTree=True)
        self.rollback(tree, backup)

    def add_modifier(self, tree, mod, head, direct=0):
        """Log ``self.annotated`` with ``head`` temporarily renamed to include ``mod``.

        Args:
            tree: Unused; kept for interface compatibility.
            mod: Modifier text.
            head: Existing key of ``self.annotated`` (``KeyError`` if absent).
            direct: 0 places the modifier before the head, anything else after.
        """
        if direct == 0:
            generated = ' '.join([mod, head])
        else:
            generated = ' '.join([head, mod])
        word_id = self.annotated[head]
        del self.annotated[head]
        self.annotated[generated] = word_id
        self.save_tree(isTree=False)
        self.rollback_annotation(generated, head)

    def _ie_object_head(self, verb):
        """Return the last word of the third OpenIE slot for ``verb``, or None.

        None is returned when the verb has no extraction, the extraction has
        fewer than three slots, the word is empty, or the word is not a key
        of ``self.annotated`` (so it could not be modified anyway).
        """
        description = self.ie_pred.get(verb)
        if not description or len(description) < 3 or not description[2]:
            return None
        head = description[2][-1]
        if not head or head not in self.annotated:
            return None
        return head

    def generate_default(self, tree):
        """Attach knowledge-base modifiers to the right child, then recurse.

        Only right children whose mark is "-" are modified:
        * POS containing "NN": every "ADJ" entry is placed before the word and
          every "RCL" entry after it.
        * POS containing "VB": every "ADV" entry is placed before the verb
          and, when an object head is available from the OpenIE output, after
          that head as well.
        Missing knowledge-base categories are treated as empty.
        """
        left = tree.left
        right = tree.right

        if right.npos is not None and right.mark == "-":
            if "NN" in right.npos:
                # self.kb, not a bare ``kb``: the knowledge base is instance state.
                for adj in self.kb.get("ADJ", ()):
                    self.add_modifier(tree, adj, right.val)
                for rel in self.kb.get("RCL", ()):
                    self.add_modifier(tree, rel, right.val, 1)
            elif "VB" in right.npos:
                # Loop-invariant: resolve the object head once per verb.
                arg1 = self._ie_object_head(right.val)
                for adv in self.kb.get("ADV", ()):
                    self.add_modifier(tree, adv, right.val)
                    if arg1 is not None:
                        self.add_modifier(tree, adv, arg1, 1)

        self.generate(left)
        self.generate(right)

    def save_tree(self, isTree):
        """Append a deep copy of the tree (``isTree``) or of the annotation mapping."""
        if isTree:
            self.tree_log.append(deepcopy(self.deptree))
        else:
            self.sent_log.append(deepcopy(self.annotated))

In [6]:
class AStarSearch:
    """Two-sided best-first search between a premise and a hypothesis.

    The search keeps one closed set per direction (``closed_forward``,
    ``closed_backward``) and, inside ``query``, one ``pqdict`` frontier per
    direction. Candidates are scored against the hypothesis with
    ``inference_sts``.
    """

    def __init__(self):
        self.closed_forward = set()
        self.closed_backward = set()
        self.entailments = set()
        self.contradictions = set()
        self.hypothesis = ""
        self.phrasalGenerator = PhrasalGenerator()

        model_name = "roberta-large-nli-stsb-mean-tokens"
        self.sbert = SentenceTransformer(model_name)

    def generate_premises(self, start):
        """Polarize ``start`` and print every variant the phrasal generator logs.

        NOTE(review): the generated sentences are only printed; nothing is
        added to ``self.entailments`` here, so ``generate_motion`` has no new
        candidates to enqueue unless that set is filled elsewhere.

        Args:
            start: Sentence to polarize and expand.
        """
        pipeline = PolarizationPipeline([start], verbose=0, parser="stanza")
        pipeline.run_polarize_pipeline()
        print("\nPolarization Complete")

        # The generator lives on the instance; a bare ``phrasalGenerator``
        # name does not exist at module level.
        generator = self.phrasalGenerator
        for annotation in pipeline.annotations:
            print("\n====================================")
            print("\nInit Premise: " + annotation['annotated'])
            #polarized = annotation['polarized']
            #btreeViz = Tree.fromstring(polarized.replace('[', '(').replace(']', ')'))
            #jupyter_draw_nltk_tree(btreeViz)
            generator.deptree_generate(
                annotation['polarized_tree'], annotation['word_dict'], annotation['original'])
            for gen_tree in generator.tree_log:
                _, queue, _, _ = btreeToList(gen_tree, len(annotation['original']), {}, 0)
                annotated = list(queue.popkeys())
                print("\nNext Premise: " + ' '.join(annotated))
            for gen_sent in generator.sent_log:
                # NOTE(review): relies on the copied word_dict offering
                # popkeys() (pqdict-like) -- confirm against Udep2Mono.
                print("\nNext Premise: " + ' '.join(gen_sent.popkeys()))

    def word_similarity(self, s1, s2):
        """Return the fraction of space-separated tokens of ``s1`` found in ``s2``.

        NOTE(review): ``s2`` is a string, so ``w in s2`` is a substring test
        ("he" matches "the"). Kept as-is; confirm whether token-level
        matching was intended.
        """
        num_sim = 0
        seq1 = s1.split(" ")
        for w in seq1:
            if w in s2:
                num_sim += 1
        return num_sim / len(seq1)

    def inference_sts(self, seqs1, seqs2):
        """Score the first sentence pair of ``seqs1`` / ``seqs2``.

        The score is the SBERT cosine similarity of the two sentences
        multiplied by ``word_similarity``. Only the first pair is scored:
        the original loop returned during its first iteration, and every
        caller passes single-element lists.

        Args:
            seqs1: List of sentences; only ``seqs1[0]`` is used.
            seqs2: List of sentences; only ``seqs2[0]`` is used.

        Returns:
            The score as a Python float, or None when either list is empty.
        """
        if not seqs1 or not seqs2:
            return None
        first1, first2 = seqs1[0], seqs2[0]
        embeddings1 = self.sbert.encode([first1], convert_to_tensor=True)
        embeddings2 = self.sbert.encode([first2], convert_to_tensor=True)
        cosine_scores = util.pytorch_cos_sim(embeddings1, embeddings2)
        # float(): a plain number is simpler to compare and store as a
        # priority than a 0-d tensor.
        return float(cosine_scores[0][0]) * self.word_similarity(first1, first2)

    def clear(self):
        """Empty both closed sets."""
        self.closed_forward.clear()
        self.closed_backward.clear()

    def generate_motion(self, open_set, side, start=None):
        """Expand ``start`` and push scored candidates onto one frontier.

        Args:
            open_set: Pair of frontier mappings, indexed by ``side``.
            side: 0 for the forward direction, anything else for backward.
            start: Sentence to expand via ``generate_premises``. When None,
                no expansion is run and only the current contents of
                ``self.entailments`` are considered.
        """
        closed = self.closed_forward if side == 0 else self.closed_backward
        opened = open_set[side]
        if start is not None:
            self.generate_premises(start)
        for premise in self.entailments:
            if premise in closed:
                continue
            cost = self.inference_sts([premise], [self.hypothesis])
            # Insert new candidates; for known ones keep the higher score.
            if premise not in opened or cost > opened[premise]:
                opened[premise] = cost

    def query(self, premises, hypothesis):
        """Search from ``premises`` toward ``hypothesis`` and back.

        Each round pops one item from the forward frontier and one from the
        backward frontier, expands it, and stops as soon as a popped item is
        already closed by the opposite direction. Afterwards
        ``closed_forward`` holds the union of both closed sets.

        Args:
            premises: Premise sentence (a single string).
            hypothesis: Hypothesis sentence.
        """
        self.clear()
        self.hypothesis = hypothesis
        # NOTE(review): several entries are misspelled ("fragret",
        # "ergently", "neccesaraly"); left untouched because they are
        # runtime data that ends up in generated sentences.
        kb = {"ADJ": ["beautiful", "red", "fragret"], 
              "ADV": ["ergently", "clearly", "neccesaraly"],
              "RCL": ["which is beautiful", "which opens at night"]}
        self.phrasalGenerator.kb = kb

        # Index 0: forward frontier (from the premise);
        # index 1: backward frontier (from the hypothesis).
        # NOTE(review): pqdict pops the smallest priority first by default,
        # while generate_motion keeps the *higher* score -- confirm whether a
        # max-priority queue was intended.
        open_lists = [pqdict({}), pqdict({})]
        open_lists[0][premises] = self.inference_sts([premises], [hypothesis])
        open_lists[1][hypothesis] = self.inference_sts([hypothesis], [hypothesis])

        while open_lists[0] or open_lists[1]:
            if open_lists[0]:
                optimal = open_lists[0].pop()
                self.generate_motion(open_lists, 0, optimal)
                if optimal in self.closed_backward:
                    break
                self.closed_forward.add(optimal)

            if open_lists[1]:
                optimal = open_lists[1].pop()
                self.generate_motion(open_lists, 1, optimal)
                if optimal in self.closed_forward:
                    break
                self.closed_backward.add(optimal)

        self.closed_forward = self.closed_forward | self.closed_backward

In [7]:
# Build the search driver. AStarSearch.__init__ constructs a SentenceTransformer,
# so the cell that imports `sentence_transformers` must have been executed in
# this kernel first; the NameError recorded below is what happens otherwise.
search = AStarSearch()

NameError: name 'SentenceTransformer' is not defined

In [None]:
# Sample run: the first argument is the premise, the second the hypothesis.
search.query("All flowers need light", "All red flowers need light")