# UdepLog Neural-Logical Inference System

## 1. Visualization

In [1]:
import os
from nltk.tree import Tree
from nltk.draw import TreeWidget
from nltk.draw.util import CanvasFrame
from IPython.display import Image, display

def jupyter_draw_nltk_tree(tree):
    cf = CanvasFrame()
    tc = TreeWidget(cf.canvas(), tree)
    tc['node_font'] = 'arial 14 bold'
    tc['leaf_font'] = 'arial 14'
    tc['node_color'] = '#005990'
    tc['leaf_color'] = '#3F8F57'
    tc['line_color'] = '#175252'
    cf.add_widget(tc, 20, 20)
    os.system('rm -rf ../data/tree.png')
    os.system('rm -rf ../data/tree.ps')
    cf.print_to_file('../data/tree.ps')
    cf.destroy()
    os.system('convert ../data/tree.ps ../data/tree.png')
    display(Image(filename='../data/tree.png'))

## 2. BERT Model for Pharaphrase

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

roberta_MRPC = "textattack/roberta-base-MRPC"
bert_MRPC = "bert-base-cased-finetuned-mrpc"

paraphraseTokenizer = AutoTokenizer.from_pretrained(roberta_MRPC)  
paraphraseModel = AutoModelForSequenceClassification.from_pretrained(roberta_MRPC)

Some weights of the model checkpoint at textattack/roberta-base-MRPC were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## 3. UD Parser and RoBERTa Semantic Similarity

In [3]:
from copy import deepcopy
from Udep2Mono.util import btree2list
from Udep2Mono.dependency_parse import tokenizer
from Udep2Mono.dependency_parse import dependency_parse
from Udep2Mono.binarization import BinaryDependencyTree
from Udep2Mono.polarization import PolarizationPipeline

from sentence_transformers import SentenceTransformer, util
sentenceTransformer = SentenceTransformer("roberta-base-nli-stsb-mean-tokens")

def inference_sts(seq1s, seq2s):
    embeddings1 = sentenceTransformer.encode(seq1s, convert_to_tensor=True)
    embeddings2 = sentenceTransformer.encode(seq2s, convert_to_tensor=True)
    cosine_scores = util.pytorch_cos_sim(embeddings1, embeddings2)
    return cosine_scores[0][0]

2021-03-12 23:01:45 INFO: Loading these models for language: en (English):
| Processor | Package                  |
----------------------------------------
| tokenize  | ../model/e...ize/gum.pt  |
| pos       | ../model/en/pos/ewt.pt   |
| lemma     | ../model/en/lemma/gum.pt |
| depparse  | ../model/e...rse/gum.pt  |

2021-03-12 23:01:45 INFO: Use device: gpu
2021-03-12 23:01:45 INFO: Loading: tokenize
2021-03-12 23:01:46 INFO: Loading: pos
2021-03-12 23:01:47 INFO: Loading: lemma
2021-03-12 23:01:47 INFO: Loading: depparse
2021-03-12 23:01:47 INFO: Done loading processors!
2021-03-12 23:01:47 INFO: Loading these models for language: en (English):
| Processor | Package                 |
---------------------------------------
| tokenize  | ../model/e...ize/gum.pt |

2021-03-12 23:01:47 INFO: Use device: cpu
2021-03-12 23:01:47 INFO: Loading: tokenize
2021-03-12 23:01:47 INFO: Done loading processors!


## 4. Phrasal Monotonicity Inference

In [23]:
from pattern.en import pluralize, singularize
from copy import copy
import re

class PhrasalGenerator:
    def __init__(self):
        self.deptree = None
        self.annotated = None
        self.original = None
        self.kb = {}
        self.hypothesis = ""
        self.tree_log = []
        self.sent_log = []
        self.stop_critarion = False
        self.mod_at_left = [
            "advmod", "amod", "advmod:count", 
            "acl:relcl", "obl", 'obl:npmod', 
            "obl:tmod", "nmod", "nmod:npmod", 
            "nmod:poss", "nmod:tmod", "obl:npmod",
            "acl", "advcl", "xcomp", "ccomp", 
            "appos", 'compound:ptr']
        self.mod_at_right = []
        self.mod_symmetric = ["conj", "compound"]
        
        '''  
            "cop": self.generate_inherite, 
            "expl": self.generate_expl,
            "nummod": self.generate_nummod,
        '''

    def deptree_generate(self, tree, annotated, original):
        self.stop_critarion = False
        self.tree_log = []
        self.sent_log = []
        self.deptree = tree.copy()
        self.original = original  
        self.annotated = deepcopy(annotated)
        self.sentence = original
        self.generate(self.deptree)

    def generate(self, tree):
        if self.stop_critarion:
            return
        if not tree.is_tree:
            self.generate_default(tree)
        else:
            generation = self.get_generation_type(tree)
            generation(tree)

    def get_generation_type(self, tree):
        disjunction = False
        if tree.val == "conj":
            disjunction = self.search_dependency('or', tree.left)

        
        left_mod = tree.left.mark == "+"
        left_mod = left_mod or tree.left.mark == "=" or disjunction
        left_mod = left_mod and tree.val in self.mod_at_left

        right_mod = tree.right.mark == "+" or tree.right.mark == "="
        right_mod = right_mod or disjunction 

        sym_mod = tree.val in self.mod_symmetric and left_mod and right_mod

        if left_mod:
            return self.left_modifier_generate
        elif sym_mod:
            return self.symmetric_generate
        else:
            return self.generate_default

    def delete_cc(self, tree):
        if tree.val == "cc" and tree.left.val != "but":
            self.delete_modifier(tree, tree.right)

        if tree.is_tree:
            self.delete_cc(tree.left)
            self.delete_cc(tree.right)

    def delete_modifier(self, tree, modifier):
        tree.val = modifier.val
        tree.mark = modifier.mark
        tree.pos = modifier.pos
        tree.id = modifier.id
        
        tree.is_tree = modifier.is_tree
        tree.is_root = modifier.is_root

        tree.left = modifier.left
        tree.right = modifier.right

        self.delete_cc(tree)
        self.save_tree()

    def delete_left_modifier(self, tree):
        self.delete_modifier(tree, tree.right)

    def delete_right_modifier(self, tree):
        self.delete_modifier(tree, tree.left)

    def rollback(self, tree, backup):
        tree.val = backup.val
        tree.left = deepcopy(backup.left)
        tree.right = deepcopy(backup.right)
        tree.mark = backup.mark
        tree.pos = backup.pos
        tree.id = backup.id
        tree.is_tree = backup.is_tree
        tree.is_root = backup.is_root

    def symmetric_generate(self, tree):
        self.right_modifier_generate(tree)
        self.left_modifier_generate(tree)
        self.delete_cc(tree)

    def right_modifier_generate(self, tree):
        left = tree.left
        right = tree.right
        backup = deepcopy(tree)

        self.delete_right_modifier(tree)
        self.save_tree()
        self.rollback(tree, backup)    
        
        self.generate(tree.left)
        self.generate(tree.right)

    def left_modifier_generate(self, tree):
        left = tree.left
        right = tree.right
        backup = deepcopy(tree)

        self.delete_left_modifier(tree)
        self.save_tree()
        self.rollback(tree, backup)   

        self.generate(tree.left)
        self.generate(tree.right)
    
    def return_last_leaf(self, tree):
        max_id = 0
        max_id_l = 0
        max_id_r = 0

        if tree.id != None:
            max_id = int(tree.id)
    
        if tree.left.is_tree:
            max_id_l = self.return_last_leaf(tree.left)
        else:
            max_id_l = tree.left.id

        if tree.right.is_tree:
            max_id_r = self.return_last_leaf(tree.right)
        else:
            max_id_r = tree.right.id

        return max(max_id, max(max_id_l, max_id_r))

    def return_first_leaf(self, tree):
        min_id = 100
        min_id_l = 100
        min_id_r = 100

        if tree.id != None:
            min_id = int(tree.id)
    
        if tree.left.is_tree:
            min_id_l = self.return_last_leaf(tree.left)
        else:
            min_id_l = tree.left.id

        if tree.right.is_tree:
            min_id_r = self.return_last_leaf(tree.right)
        else:
            min_id_r = tree.right.id

        return min(min_id, min(min_id_l, min_id_r))

    def add_modifier_sent(self, tree, modifier, direct=0): 
        sentence = deepcopy(self.sentence)
        if direct == 0:
            last_leaf = self.return_first_leaf(tree)
            sentence.insert(last_leaf-1, modifier)
        elif direct == 1:
            last_leaf = self.return_last_leaf(tree)
            sentence.insert(last_leaf, modifier)        

        self.remove_adjcent_duplicate(sentence)
        sentence = ' '.join(sentence)
        sentence = sentence.replace("-", " ")
        sentence = sentence.replace(" 's", "'s")

        if abs(len(sentence) - len(self.hypothesis)) < 15:
            re.sub(r'((\b\w+\b.{1,2}\w+\b)+).+\1', r'\1', sentence, flags = re.I)
            sentence = sentence.strip() 
            
            if sentence.lower() == self.hypothesis.lower():
                self.stop_critarion = True
                self.sent_log.append((sentence, 1.0))
                return
                
            similarity = inference_sts([sentence], [self.hypothesis])
            if similarity > 0.90:
                self.sent_log.append((sentence, similarity))
            if similarity > 0.97:
                self.sent_log = []
                self.sent_log.append((sentence, similarity))
                self.stop_critarion = True

    def add_modifier_lexical(self, tree, modifier, head, word_id, direct=0):
        if direct == 0:
            generated = ' '. join([modifier, head])
        else:
            generated = ' '. join([head, modifier])
        
        sentence = deepcopy(self.sentence)
        diff = 0
        if word_id > len(sentence):
            diff = word_id - len(sentence)

        goal = word_id-1-diff
        sentence[goal] = "DEL"
        sentence[goal:goal] = generated.split(' ')

        if abs(len(sentence) - len(self.hypothesis.split(' '))) < 7:
            self.remove_adjcent_duplicate(sentence)
            sentence = ' '.join(sentence)
            sentence = sentence.replace("DEL ", "")
            sentence = sentence.replace("DEL", "")
            sentence = sentence.replace("-", " ")
            sentence = sentence.replace(" 's", "'s")
            re.sub(r'((\b\w+\b.{1,2}\w+\b)+).+\1', r'\1', sentence, flags = re.I)
            sentence = sentence.strip()

            if sentence.lower() == self.hypothesis.lower():
                self.stop_critarion = True
                self.sent_log.append((sentence, 1.0))
                return
            
            similarity = inference_sts([sentence], [self.hypothesis])
            if similarity > 0.9:
                self.sent_log.append((sentence, similarity))
            if similarity > 0.97:
                self.sent_log = []
                self.sent_log.append((sentence, similarity))
                self.stop_critarion = True

    def generate_default(self, tree):
        VP_rel = {
            "aux":1, 
            "obj":1, 
            "obl":1, 
            "xcomp":1, 
            "ccomp":1,
            "aux:pass":1, 
            "obl:tmod":1, 
            "obl:npmod":1
        }

        VP_mod = {
            "advcl":1, 
            "xcomp":1, 
            "ccomp":1,
            "obj":1, 
            "advmod":1, 
            "obl":1, 
            "obl:tmod":1,
            "obl:nmod":1, 
            "parataxis":1, 
            "conj":1
        }

        NP_rel = {
            "amod":1,
            "compound":1,
            "det":1,
            "mark":1,
            "nmod:poss":1,
            "flat":1,
            "acl:relcl":1,
            "acl":1,
            "nmod":1
        }

        NP_mod = {
            "amod":1,
            "compound":1,
            "det":1,
            "mark":1,
            "nmod:poss":1,
            "flat":1,
        }

        if tree.pos is not None:
            if "NN" in tree.pos and tree.mark == "-":
                for rel in ["amod", "compound", "det", "mark", "nmod:poss", "flat", "conj"]:
                    if rel in self.kb:
                        for phrase in self.kb[rel]:
                            if phrase['head'] == tree.val:
                                self.add_modifier_lexical(tree, phrase['mod'], tree.val, tree.id)
                for rel in ["amod", "acl:relcl", "compound", "acl", "nmod"]:
                    if rel in self.kb:
                        for phrase in self.kb[rel]:
                            if phrase['head'] == tree.val:
                                self.add_modifier_lexical(tree, phrase['mod'], tree.val, tree.id, 1)
                
            elif "VB" in tree.pos and tree.mark == "-":
                for rel in ["advmod"]:
                    if rel in self.kb:
                        for phrase in self.kb[rel]:
                            self.add_modifier_lexical(tree, phrase['mod'], tree.val, tree.id)
                            self.add_modifier_lexical(tree, phrase['mod'], tree.val, tree.id, 1)

        elif VP_rel.get(tree.val, 0) and tree.mark == "-":
            for rel in VP_mod:
                if rel in self.kb:
                    for phrase in self.kb[rel]:
                        self.add_modifier_sent(tree, phrase['mod'], direct=1)

        elif NP_rel.get(tree.val, 0) and tree.mark == "-":
            for rel in NP_mod:
                if rel in self.kb:
                    for phrase in self.kb[rel]:
                        self.add_modifier_sent(tree, phrase['mod'], direct=0)

        if tree.is_tree:
            self.generate(tree.left)
            self.generate(tree.right)  

    def save_tree(self):
        leaves = self.deptree.sorted_leaves().popkeys()
        sentence = ' '.join([x[0] for x in leaves])

        if sentence.lower() == self.hypothesis.lower():
            self.tree_log = []
            self.stop_critarion = True
            self.tree_log.append((self.deptree.copy(), sentence, 1.0))
            return
        
        similarity = inference_sts([sentence], [self.hypothesis])
        if similarity > 0.9:
            self.tree_log.append((self.deptree.copy(), sentence, similarity))
        if similarity > 0.97:
            self.tree_log = []
            self.tree_log.append((self.deptree.copy(), sentence, similarity))
            self.stop_critarion = True
    
    def remove_adjcent_duplicate(self, string):
        to_remove = -1
        for i in range(len(string)-1):
            if string[i] == string[i+1]:
                to_remove = i
        if to_remove > -1:
            del string[to_remove]

    def search_dependency(self, deprel, tree):
        if tree.val == deprel:
            return True
        else:
            right = tree.right
            left = tree.left

            left_found = False
            right_found = False

            if right is not None:
                right_found = self.search_dependency(deprel, right)

            if left is not None:
                left_found = self.search_dependency(deprel, left)

            return left_found or right_found
    
    def Diff(self, li1, li2):
        return (list(list(set(li1)-set(li2)) + list(set(li2)-set(li1))))    
    
    def preprocess(self, sentence):
        preprocessed = sentence.replace(".", "").replace("!", "").replace("?", "")
        preprocessed = preprocessed.replace("can't", "can not")
        preprocessed = preprocessed.replace("couldn't", "could not")
        preprocessed = preprocessed.replace("don't", "do not")
        preprocessed = preprocessed.replace("doesn't", "does not")
        preprocessed = preprocessed.replace("isn't", "is not")
        preprocessed = preprocessed.replace("won't", "will not")
        preprocessed = preprocessed.replace("wasn't", "was not")
        preprocessed = preprocessed.replace("weren't", "were not")
        preprocessed = preprocessed.replace("didn't", "did not")
        preprocessed = preprocessed.replace("aren't", "are not")
        preprocessed = preprocessed.replace("it's", "it is")
        preprocessed = preprocessed.replace("wouldn't", "would not")
        preprocessed = preprocessed.replace("There's", "There is")
        return preprocessed

In [5]:
modifier_relation = {
    "NN": ["amod", "nmod", "acl:relcl", "fixed", "compound", "det", "nmod:poss", "conj"],
    "VB": ["advmod", "acl", "obl", "xcomp", "advcl", "obl:tmod", "parataxis", "obj","ccomp"]
}

def down_right(tree):
    if(tree.right == None):
        return tree
    return down_right(tree.right)

def down_left(tree):
    if(tree.left == None):
        return tree
    return down_left(tree.left)

def collect_modifiers(tree, sent_set, mod_type="NN"):
    leaves = []
    if tree.is_tree:
        if tree.val in ["mark", "case", "compound", "flat"]:
            leaves.append(
                (list(tree.right.sorted_leaves().popkeys()),
                down_right(tree.left).val)
            )
        if tree.val in modifier_relation[mod_type]:
            leaves.append(
                (list(tree.left.sorted_leaves().popkeys()),
                down_right(tree.right).val)
            )

        for leave in leaves:
            if len(leave) > 0 and len(leave) < 10:
                head = leave[1]
                modifier = ' '.join([x[0] for x in leave[0]])
                if tree.val in sent_set:
                    sent_set[tree.val].append({'head': head,'mod': modifier})
                else:
                    sent_set[tree.val] = [{'head': head,'mod': modifier}]
        
        collect_modifiers(tree.left, sent_set, mod_type)
        collect_modifiers(tree.right, sent_set, mod_type)

## 5. Lexical Monotonicity Inference

In [6]:
class LexicalGenerator:
    def __init__(self):
        self.postags = None
        self.hypothesis = ""
        self.sent_log = []
        self.replacement_log = []
        self.key_tokens = [
            'NN','NNS','NNP','NNPS','VBD',
            'VBG','VBN','VBZ','VB']

    def filter_words(self, word_set):
        filtered = []
        for word in word_set:
            if word in self.hypothesis:
                filtered.append(word)
        return filtered

    def get_word_knowledge(self, word):
        hyper = [] 
        hypo = [] 
        syn = [] 
        ant = []
        hyper, hypo, syn, ant = get_word_sets(
            singularize(word[0]), word[1].lower())
        hyper_fil = self.filter_words(hyper)
        hypo_fil = self.filter_words(hypo)
        syn_fil = self.filter_words(syn)
        ant_fil = self.filter_words(ant)
        return hyper_fil, hypo_fil, syn_fil, ant_fil

    def replace_token(self, orig, word_set):
        prev = orig
        for word in word_set:
            if prev[0] == word[0]:
                continue
            item = self.postags[prev]
            del self.postags[prev]
            self.postags[(word, prev[1], prev[2])] = item
            postags_cp = deepcopy(self.postags)
            self.sent_log.append(
                ' '.join([word[0] for word in list(postags_cp.popkeys())]))
            self.replacement_log.append(
                "{} => {}".format(prev[0], word))
            del self.postags[(word, prev[1], prev[2])]
            self.postags[prev] = item
            prev = (word, prev[1], prev[2])

    def generate(self, postags):  
        self.postags = deepcopy(postags)  
        for word in postags:    
            if word[1] in self.key_tokens:
                hyper, hypo, syn, ant = self.get_word_knowledge(word)
                self.replace_token(word, syn)
                self.replace_token(word, ant)

                if word[2] == "+":                
                    self.replace_token(word, hyper)
                if word[2] == "-":
                    self.replace_token(word, hypo)
        
            elif word[1] == "DET":    
                kb = quantifier.find({"word": tree.left.val.lower()})[0]
                self.replace_token(tree.left.val, kb["="])
                if word[2] == "+":
                    self.replace_token(tree.left.val, kb["<"])   
                if word[2] == "-":
                    self.replace_token(tree.left.val, kb[">"])

In [39]:
## 6. Syntactic Variational Inference

In [7]:
from chunker import Chunker

class SyntacticVariator:
    
    def __init__(self):
        self.chunker = Chunker()
        self.paraphraseTokenizer = paraphraseTokenizer
        self.paraphraseModel = paraphraseModel

    def chunking(self, tree):
        return self.chunker.get_chunks_byDepTree(tree)

    def build_pairs(self, chunks1, chunks2):
        chunk_pairs = []
        for chunk1 in chunks1:
            for chunk2 in chunks2:
                if len(set(chunk1.split(' ')).intersection(chunk2.split(' '))) > 0:
                     chunk_pairs.append((chunk1, chunk2))

        return chunk_pairs

    def inference_mrpc(self, seq1, seq2):
        paraphrase = paraphraseTokenizer.encode_plus(
            seq1, seq2, return_tensors="pt")
        logits = paraphraseModel(**paraphrase)[0]
        paraphrase_results = torch.softmax(logits, dim=1).tolist()[0]
        return paraphrase_results[1]

    def phrase_alignment(self, chunk_pairs):
        alignments = []
        for pair in chunk_pairs:
            score = self.inference_mrpc(pair[0], pair[1])
            if score > 0.85:
                alignments.append(pair)

        return alignments

    def variate(self, sentence, p_tree, h_tree):
        p_chunks = self.chunking(p_tree)
        h_chunks = self.chunking(h_tree)

        chunk_pairs = self.build_pairs(p_chunks, h_chunks)
        alignments = self.phrase_alignment(chunk_pairs)

        var_sentence = copy(sentence)
        for align in alignments:
            var_sentence = var_sentence.replace(align[0], align[1])

        return var_sentence

## 7. A* Inference Search Engine

In [35]:
from pqdict import pqdict

class AStarPlanner:
    def __init__(self):    
        self.closed = []                  
        self.entailments = set()
        self.contradictions = set()
        self.hypothesis = ""
        self.h_tree = None

        self.pipeline = PolarizationPipeline()
        self.phrasalGenerator = PhrasalGenerator()
        self.lexicalGenerator = LexicalGenerator()
        self.syntacticVariator = SyntacticVariator() 

    def hypothesis_kb(self):
        self.hypothesis = self.phrasalGenerator.preprocess(self.hypothesis)
        h_parsed, replaced = dependency_parse(self.hypothesis, parser="stanza")
        h_tree, _ = self.pipeline.run_binarization(h_parsed, self.hypothesis, {})
        self.pipeline.modify_replacement(h_tree, replaced)
        phrases = {} 
        collect_modifiers(h_tree, phrases, mod_type="NN")
        collect_modifiers(h_tree, phrases, mod_type="VB")
        self.phrasalGenerator.kb = phrases
        self.h_tree = h_tree

    def generate_premises(self, start):
        self.entailments.clear()
        self.contradictions.clear()

        start = self.phrasalGenerator.preprocess(start)
        annotation = self.pipeline.single_polarization(start)
        self.phrasalGenerator.hypothesis = self.hypothesis.replace(',', '')

        tokenized = tokenizer(start).sentences[0].words
        tokens = [tok.text for tok in tokenized]

        self.phrasalGenerator.deptree_generate(
        annotation['polarized_tree'], 
        annotation['annotated'], tokens)
        
        #print(start)
        #print(*self.phrasalGenerator.sent_log, sep='\n')
        #print("============================")
        
        #print(self.phrasalGenerator.sent_log)
        
        if self.phrasalGenerator.stop_critarion:
            return True

        self.entailments |= set(self.phrasalGenerator.sent_log)

        variate = self.syntacticVariator.variate(start, annotation['polarized_tree'], self.h_tree)
        similarity = inference_sts([variate], [self.hypothesis])
        if similarity > 0.98:
            return True
        
        self.entailments.add((variate, similarity))

        #self.lexicalGenerator.generate(annotation['annotated'])
        #for gen_sent in self.lexicalGenerator.sent_log:
        #    self.entailments.add(gen_sent)
        #toc = time.perf_counter()
        #print(f"Lexical Generate 1 Premise: {toc - tic:0.4f} seconds")

        return False

    def generate(self, start, opened):
        terminate = self.generate_premises(start)
        if terminate:
            return True

        for premise in self.entailments:
            if premise in self.closed:
                continue
            cost = premise[1]
            if premise[0] not in opened:
                opened[premise[0]] = cost
            if cost < opened[premise[0]]:
                opened[premise[0]] = cost
        return False

    def search(self, premises, hypothesis):
        self.closed = pqdict({})
        self.hypothesis = hypothesis

        self.hypothesis_kb()
        self.phrasalGenerator.hypothesis = self.hypothesis
        self.lexicalGenerator.hypothesis = self.hypothesis

        open_lists = pqdict({}, reverse=True)
        open_lists[premises] = inference_sts([premises], [hypothesis])

        hop = 0
        top_k = 2

        while open_lists:
            for _ in range(top_k):
                if len(open_lists) > 0:
                    optimal = open_lists.popitem()
                    goal_found = self.generate(optimal[0], open_lists)
                    self.closed[optimal] = len(self.closed) + 1
                    if goal_found:
                        self.closed[(self.hypothesis, 1.0)] = len(self.closed) + 1
                        return True
                else: break   
            hop += 1
            if hop > 5: break

In [36]:
planner = AStarPlanner()

In [37]:
entail = planner.search("A brown dog is attacking another animal in front of the tall man in pants", 
             "A dog is attacking another animal in front of the man in pants")
print(*planner.closed, sep=" =>\n")
print(entail)

('A brown dog is attacking another animal in front of the tall man in pants', tensor(0.8783)) =>
('A brown dog is attacking another animal in front of the man in pants', tensor(0.9273)) =>
('A dog is attacking another animal in front of the man in pants', 1.0)
True


In [30]:
entail = planner.search("A family is watching a little boy who is hitting a baseball", 
             "A family is watching a boy who is hitting a baseball")
print(*planner.closed, sep=" =>\n")
print(entail)

('A family is watching a little boy who is hitting a baseball', tensor(0.9915)) =>
('A family is watching a boy who is hitting a baseball', 1.0)
True


In [34]:
entail = planner.search("You can't park in front of my house on weekends.", 
             "You can't park in front of my large house on weekends.")

print(*planner.closed, sep=" =>\n")
print(entail)

("You can't park in front of my house on weekends.", tensor(0.9529)) =>
('You can not park in front of my large house on weekends', 1.0)
True
