# Step 1: preparing the data structure.

## Loading the data
The examples we use in this project come from this repository: https://github.com/kovvalsky/LangPro
They were saved to a txt file after a slight modification (replacing "\" with "\_"), which makes them easier to load into Python.

The "examples" list holds the individual examples read from that file, one derivation per entry; examples are separated by blank lines.

In [2]:
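examples = []
with open("LOLA_SICK_corrected_derivations.txt", "r") as f:
    curr_example = ""
    for line in f:
        if line == '\n':
            # a blank line closes the current example
            examples.append(curr_example)
            curr_example = ""
        else:
            curr_example += line
    # keep the last example if the file does not end with a blank line
    if curr_example:
        examples.append(curr_example)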
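
The blank-line splitting can be tried on a small in-memory string before touching the real file. The sketch below is only an illustration: `split_examples` is a hypothetical helper that is not part of the project, and the sample text is invented. Unlike the cell above, it drops the newline characters and skips empty blocks.

```python
def split_examples(text):
    """Split text into blocks separated by blank lines (hypothetical helper)."""
    blocks, current = [], []
    for line in text.splitlines():
        if line.strip() == "":
            # a blank line closes the current block, if there is one
            if current:
                blocks.append("\n".join(current))
                current = []
        else:
            current.append(line)
    if current:  # keep the last block even without a trailing blank line
        blocks.append("\n".join(current))
    return blocks

sample = "first line a\nfirst line b\n\nsecond line a\n"
print(split_examples(sample))
```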

## Building the tree structure

In [1]:
from TreeNode import TreeNode
from build_tree_utils import get_line_dict, build_the_tree

The function "reload_tree" builds trees from the list of examples. It is also useful for debugging, because it creates a bare tree with no parameters specified (only the parent and child relationships are set).
Parameter "n" specifies the number of examples that are loaded.

In [3]:
# input: ccg derivation in string format
# output: tree
def reload_tree(n=15):
    # n: number of examples to load
    examples_roots = {}
    for i in range(n):
        line_dict = get_line_dict(examples[i]) # ccg derivation -> lines
        roots, leaf = build_the_tree(line_dict) # lines -> tree
        examples_roots[i] = (examples[i], roots, leaf)
    return examples_roots

examples_roots = reload_tree()

At this point, the representation of each example is saved in the examples_roots dictionary:
* the keys are derivation ids;
* the first value is the Prolog derivation in string format;
* the second is the roots of the derivation; and
* the third is the leaf (which should, in the end, represent the final meaning of the derivation).
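
As a rough illustration of this layout, the sketch below unpacks one entry. The values are placeholder strings chosen for this example; in the notebook the roots and the leaf are TreeNode objects, and `describe` is a hypothetical helper.

```python
# Hypothetical stand-ins: in the notebook, roots and leaf are TreeNode objects.
examples_roots = {
    0: ("<derivation text>", ["<root 1>", "<root 2>"], "<leaf>"),
}

def describe(entry):
    """Unpack one (derivation, roots, leaf) tuple into a short summary."""
    derivation, roots, leaf = entry
    return {"n_roots": len(roots), "leaf": leaf}

for deriv_id, entry in examples_roots.items():
    print(deriv_id, describe(entry))
```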

# Step 2: Getting First Order Logic Meaning Representation

We defined the lexical semantics in a dictionary that uses POS tags, types and some exceptions as keys. See the file lexical_semantic_rules.py for details.
The function "assign_compositional_semantics" traverses the tree and sets the compositional semantics of each node.

In [7]:
from lexical_semantic_rules import lexical_semantic_rules 
from get_semantics import assign_compositional_semantics, get_lexical_semantics 

from nltk.sem.logic import *
read_expr = Expression.fromstring

The function "derivation_to_FOL" takes the id of a derivation as its argument and returns the First Order Logic formula representing the meaning of that derivation.

In [8]:
# input: id of derivation
# output: the FOL representation of the derivation
def derivation_to_FOL(id):
    line_dict = get_line_dict(examples[id - 1]) # ccg derivation -> lines
    roots, leaf = build_the_tree(line_dict) # lines -> tree
    examples_roots = (examples[id - 1], roots, leaf)
    
    # get lexical semantics for every word
    for i in range(0,len(examples_roots[1])):
        get_lexical_semantics(examples_roots[1][i])
    
    # get lambda representation
    assign_compositional_semantics(examples_roots[2])

    # return FOL
    lr = examples_roots[2].get_lexical_semantics()
    fol = read_expr(f"({lr})(True)").simplify()
    return str(fol)

derivation_to_FOL(15)

'exists x.(exists z1.(two(z1) & argu1(x,z1) & dog(x)) & exists z3.(play(z3) & argu1(z3,x) & exists z2.(tree(z2) & by(z3,z2) & True(z3))))'
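
The last step of derivation_to_FOL relies on NLTK's `simplify()`, which beta-reduces a lambda term into a plain formula. The sketch below shows that mechanism in isolation. The two lexical entries are invented for this example and are not taken from lexical_semantic_rules.py.

```python
from nltk.sem.logic import Expression

read_expr = Expression.fromstring

# Invented lexical entries, for illustration only.
a_dog = read_expr(r'\P.exists x.(dog(x) & P(x))')
barks = read_expr(r'\y.bark(y)')

# Function application followed by beta-reduction.
sentence = a_dog.applyto(barks).simplify()
print(sentence)
```

The same pattern appears in the cell above, where the lambda term stored on the leaf is applied to `True` and then simplified.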

# Step 3: Get relations from WordNet and model them as FOL formulas


In [9]:
import nltk
from nltk.corpus import wordnet as wn
from nltk.corpus.reader.wordnet import Synset
from typing import List
nltk.download('wordnet')
nltk.download('omw-1.4')
import re

[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\weron\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     C:\Users\weron\AppData\Roaming\nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


Define the is_a relationship based on the hypernym relations from WordNet:

In [10]:
from collections import deque

def is_a(w1: str, w2: str) -> bool:
    '''
    return True if w1 is more specific than w2, i.e. some synset of w2
    is a (transitive) hypernym of some synset of w1
    return False otherwise
    '''
    targets = set(wn.synsets(w2))
    for syn in wn.synsets(w1):
        # breadth-first search up the hypernym hierarchy
        queue = deque([syn])
        seen = {syn}
        while queue:
            current = queue.popleft()
            for hyper in current.hypernyms():
                if hyper in targets:
                    return True
                if hyper not in seen:
                    seen.add(hyper)
                    queue.append(hyper)
    return False
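
The search itself can be followed on a toy hypernym graph without downloading WordNet. The sketch below is a simplified stand-in: the graph is invented, words map directly to words rather than to synsets, and `is_a_toy` is a hypothetical name.

```python
from collections import deque

# Toy hypernym graph (invented): each word maps to its direct hypernyms.
HYPERNYMS = {
    "oak": ["tree"],
    "tree": ["plant"],
    "plant": ["organism"],
    "dog": ["animal"],
    "animal": ["organism"],
}

def is_a_toy(specific, general):
    """Return True if `general` is reachable from `specific` via hypernym links."""
    queue = deque([specific])
    seen = {specific}
    while queue:
        node = queue.popleft()
        for parent in HYPERNYMS.get(node, []):
            if parent == general:
                return True
            if parent not in seen:
                seen.add(parent)
                queue.append(parent)
    return False

print(is_a_toy("oak", "plant"))   # True
print(is_a_toy("plant", "oak"))   # False
print(is_a_toy("dog", "plant"))   # False
```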

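The function "fol_to_word" extracts the content words (predicate names) from a FOL formula string. It removes digits and the quantifier keywords "exists" and "all", splits the string at "(" and "&", and discards every fragment that contains a variable binding ("."), a closing bracket, "argu" or "True". What remains are predicate names such as two, dog, play, tree and by, returned as a list without duplicates.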

In [11]:
# extract words from FOL
def fol_to_word(s):
  # remove numbers
  s = re.sub(r'[0-9]+', '', s)
  # remove the keywords 'exists' and 'all' (whole words only,
  # so that predicates such as 'ball' are left intact)
  s = re.sub(r'\b(exists|all)\b', '', s)
  # treat "&" like "(" and split on "("
  fragments = s.replace("&", "(").split("(")
  words = set()
  for frag in fragments:
    frag = frag.strip()
    # skip variable bindings, arguments, argument roles and the True placeholder
    if not frag or any(mark in frag for mark in ('.', ')', 'argu', 'True')):
      continue
    words.add(frag)
  return list(words)
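
An alternative way to pull out the predicate names is a single regular expression that matches every identifier directly followed by "(". This is a sketch of a different technique, not the notebook's method. `predicate_names` and its `skip_prefixes` parameter are hypothetical, and unlike fol_to_word it keeps digits inside names and preserves the order of first appearance.

```python
import re

def predicate_names(fol, skip_prefixes=("argu", "True")):
    """Return predicate names in order of first appearance, without duplicates."""
    names = re.findall(r'([A-Za-z_]\w*)\(', fol)
    kept = []
    for name in names:
        # drop argument-role predicates, the True placeholder and repeats
        if name.startswith(skip_prefixes) or name in kept:
            continue
        kept.append(name)
    return kept

formula = ('exists x.(exists z1.(two(z1) & argu1(x,z1) & dog(x)) & '
           'exists z3.(play(z3) & argu1(z3,x) & '
           'exists z2.(tree(z2) & by(z3,z2) & True(z3))))')
print(predicate_names(formula))
```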

In [12]:
def get_knowledge(hypothesis: str, premises: List[str]):
  add_premises = []
  add_fol = []
  for premise in premises:
    p_word = premise.split(' ')
    h_word = hypothesis.split(' ')
    same_word = set(p_word).intersection(set(h_word))
    p_word = set(p_word) - same_word
    h_word = set(h_word) - same_word
    # print(p_word,h_word)
    for w1 in p_word:
      for w2 in h_word:
        if is_a(w1, w2):
          add_premises.append(f'{w1} is a {w2}')
          add_fol.append(f'all x.({w1}(x) -> {w2}(x))')
        # check knowledge base for is_a relations for w1 and w2
  return add_premises, add_fol
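
The pairing logic can be exercised without WordNet by passing in a stub relation. In the sketch below, `axioms_from_pairs` and `stub_is_a` are hypothetical names, and the single known pair is an assumption made for this example. Words are sorted only to make the output order predictable.

```python
def axioms_from_pairs(premise_words, hypothesis_words, is_a):
    """Build 'all x.(w1(x) -> w2(x))' axioms for every pair accepted by is_a."""
    shared = set(premise_words) & set(hypothesis_words)
    axioms = []
    for w1 in sorted(set(premise_words) - shared):
        for w2 in sorted(set(hypothesis_words) - shared):
            if is_a(w1, w2):
                axioms.append(f"all x.({w1}(x) -> {w2}(x))")
    return axioms

# Stub relation standing in for the WordNet lookup (assumption for this sketch).
KNOWN_PAIRS = {("tree", "plant")}

def stub_is_a(w1, w2):
    return (w1, w2) in KNOWN_PAIRS

print(axioms_from_pairs(["two", "dog", "play", "by", "tree"],
                        ["two", "dog", "play", "by", "plant"],
                        stub_is_a))
```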

# Step 4: Use a theorem prover to detect the inference label

We decided to use Prover9, which has to be downloaded separately.

In [13]:
!curl -O "https://www.cs.unm.edu/~mccune/prover9/gui/p9m4-v05.tar.gz"


In [14]:
!tar -xvzf "p9m4-v05.tar.gz"

x p9m4-v05/
x p9m4-v05/bin/
x p9m4-v05/bin/prover9
x p9m4-v05/bin/mace4
x p9m4-v05/bin/interpformat
x p9m4-v05/bin/prooftrans
x p9m4-v05/bin/isofilter
x p9m4-v05/bin/isofilter2
x p9m4-v05/bin-win32/
x p9m4-v05/control.py
x p9m4-v05/Mac.build
x p9m4-v05/err
x p9m4-v05/files.py
x p9m4-v05/Mac.README
x p9m4-v05/Images/
x p9m4-v05/Images/mace4-90t.gif
x p9m4-v05/Images/p9.ico
x p9m4-v05/Images/prover9-5a-128t.gif
x p9m4-v05/Images/prover9-splash.gif
x p9m4-v05/Mac.setup
x p9m4-v05/Mac-setup.py
x p9m4-v05/options.pyc
x p9m4-v05/my_setup.py
x p9m4-v05/Win32.README
x p9m4-v05/options.py
x p9m4-v05/p9-48.ico
x p9m4-v05/p9.icns
x p9m4-v05/partition_input.py
x p9m4-v05/Win32-setup.py
x p9m4-v05/platforms.py
x p9m4-v05/prover9-mace4.py
x p9m4-v05/Samples/
x p9m4-v05/Samples/Equality/
x p9m4-v05/Samples/Equality/Mace4/
x p9m4-v05/Samples/Equality/Mace4/BA-Sheffer-counterexample.in
x p9m4-v05/Samples/Equality/Mace4/CL-QL.in
x p9m4-v05/Samples/Equality/Mace4/Megill-68.in
x p9m4-v05/Samples/Equality/

In [15]:
def prover9_prove(conclusion: str, premises: List[str] = [], path=r"/content/prover9/bin") -> bool:
    """
    Given a conclusion and a list of premises, calls Prover9 through NLTK
    and checks whether the premises entail the conclusion.
    Returns True if a proof is found and False otherwise.
    """
    str2exp = nltk.sem.Expression.fromstring
    c = str2exp(conclusion)
    ps = [ str2exp(p) for p in premises ]
    prover9 = nltk.Prover9()
    # path must point to the directory that contains the prover9 binary
    if path: prover9.config_prover9(path)
    return prover9.prove(c, ps)
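
When the Prover9 binary is not available, the same kind of entailment check can be tried with one of the provers that ship with NLTK. The sketch below swaps in `ResolutionProver`, which is written in Python. It is not the prover used in this project, it may not cope with the larger formulas produced above, and the premises are a textbook example rather than project data.

```python
from nltk.sem.logic import Expression
from nltk.inference import ResolutionProver

read_expr = Expression.fromstring

premises = [
    read_expr('all x.(man(x) -> mortal(x))'),
    read_expr('man(socrates)'),
]
conclusion = read_expr('mortal(socrates)')

# ResolutionProver is implemented in Python, so no external binary is needed.
result = ResolutionProver().prove(conclusion, premises)
print(result)
```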

## A short test:

In [16]:
premise = derivation_to_FOL(15)
hypothesis = derivation_to_FOL(16)

In [17]:
def get_inference_label(hypothesis: str, premise: List[str]):
  try:
    # the premises contradict the hypothesis if they entail its negation
    contradiction = prover9_prove('-(' + hypothesis + ')', premise)
    # the premises entail the hypothesis if it can be proved directly
    entailment = prover9_prove(hypothesis, premise)
    # print(f'Contradiction is {contradiction} and entailment is {entailment}')
    if contradiction and not entailment:
      return 'no'
    if entailment and not contradiction:
      return 'yes'
    return 'unknown'
  except Exception:
    # any prover or parsing error is treated as "no proof found"
    print("Could not find a proof: return \'neutral\'")
    return 'unknown'
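
The labelling rule boils down to a small decision table over the two prover calls. The sketch below isolates it so that it can be checked without a prover. `decide_label` is a hypothetical helper, and the label names are the ones used in this notebook.

```python
def decide_label(entailed: bool, contradicted: bool) -> str:
    """Map the outcome of the two proof attempts to an inference label."""
    if entailed and not contradicted:
        return "yes"       # the hypothesis was proved
    if contradicted and not entailed:
        return "no"        # the negated hypothesis was proved
    return "unknown"       # neither was proved, or both were

for entailed in (True, False):
    for contradicted in (True, False):
        print(entailed, contradicted, decide_label(entailed, contradicted))
```

If both proofs succeed, the premises are themselves inconsistent, which is why that case also falls back to "unknown".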

In [18]:
get_inference_label(hypothesis,[premise,"all x.(tree(x) -> plant(x))"])

Could not find a proof: return 'neutral'


'unknown'

# Step 5: evaluation

In [19]:
import pandas as pd
import traceback
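
Scoring in this step comes down to comparing a gold label with a predicted label for every problem. The sketch below shows that comparison on invented rows. The column names `label` and `result` follow the evaluation table used in this step, and the confusion table is an extra that the notebook itself does not compute.

```python
import pandas as pd

# Invented rows, for illustration only.
toy = pd.DataFrame({
    "label":  ["yes", "no", "unknown", "yes"],
    "result": ["yes", "unknown", "unknown", "unknown"],
})

# Share of problems whose predicted label matches the gold label.
accuracy = (toy["label"] == toy["result"]).mean()
# Rows are gold labels, columns are predicted labels.
confusion = pd.crosstab(toy["label"], toy["result"])

print(accuracy)
print(confusion)
```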

In [22]:
# run on SICK trial set. catch the error when running.
# input: sick dataset file name + ccg derivation file name
# output: results and error
# data(dataframe): problem_id(int), p_sen(str), h_sen(str), p_dev(int), h_dev(int), label(str), p_fol(str), h_fol(str), add_p(list of str), add_fol(list of fol), result(str)
data = pd.DataFrame(columns=['problem_id','p_sen','h_sen', 'p_dev', 'h_dev', 'label', 'add_p', 'add_fol', 'p_fol', 'h_fol', 'result'])
# problem_id stays a regular column; rows are also keyed by problem_id through data.loc below
error_cnt = {'LogicalExpressionException':0, "IndexError":0, "UnboundLocalError":0}

# ccg_derivation txt file ---> examples
# example[x] ---> the xth derivation in LOLA_SICK_corrected_derivations.txt
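examples = []
with open("LOLA_SICK_corrected_derivations.txt", "r") as f:
    curr_example = ""
    for line in f:
        if line == '\n':
            # a blank line closes the current example
            examples.append(curr_example)
            curr_example = ""
        else:
            curr_example += line
    # keep the last example if the file does not end with a blank line
    if curr_example:
        examples.append(curr_example)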

with open("SICK_trial_sen.txt.pl", "r") as file:
    line = [l for l in file]
    for i in range(0,len(line),3):
        problem_id = int(line[i].split('=')[1].strip())
        p_sen = line[i+1].split('(')[1].split(',')[4][2:-4]
        h_sen = line[i+2].split('(')[1].split(',')[4][2:-4]
        p_dev = int(line[i+1].split('(')[1].split(',')[0])
        h_dev = int(line[i+2].split('(')[1].split(',')[0])
        label = line[i+1].split('(')[1].split(',')[3][2:-1]

        # add_p, add_fol = get_knowledge(h_sen, [p_sen])
        try: 
            p_fol = derivation_to_FOL(p_dev)
            h_fol = derivation_to_FOL(h_dev)
        except LogicalExpressionException:
            print(f"{problem_id}: can't get a FOL formula based on the given lexical semantics")
            error_cnt["LogicalExpressionException"] += 1
        except IndexError:
            # print(f"{problem_id}: can't handle token ','")
            error_cnt["IndexError"] += 1
        except UnboundLocalError:
            print(f"{problem_id}: missing lexical semantics")
            error_cnt["UnboundLocalError"] += 1
        else:
            add_p, add_fol = get_knowledge(' '.join(fol_to_word(h_fol)), [' '.join(fol_to_word(p_fol))])
            # when 'patient' and 'by' both occur in one of the formulas, link them to the argument roles
            if ('patient' in p_fol and 'by' in p_fol) or ('patient' in h_fol and 'by' in h_fol):
                add_fol.append('all x all y (patient(x,y) <-> argu2(x,y))')
                add_fol.append('all x all y (by(x,y) <-> argu1(x,y))')
            result = get_inference_label(h_fol,[p_fol] + add_fol)
            data.loc[problem_id] = [problem_id,p_sen,h_sen,p_dev,h_dev,label,add_p,add_fol,p_fol,h_fol,result]


4: can't get a FOL formula based on the given lexical semantics
Could not find a proof: return 'neutral'
105: can't get a FOL formula based on the given lexical semantics
Could not find a proof: return 'neutral'
Could not find a proof: return 'neutral'
Could not find a proof: return 'neutral'
197: can't get a FOL formula based on the given lexical semantics
Could not find a proof: return 'neutral'
218: can't get a FOL formula based on the given lexical semantics
219: can't get a FOL formula based on the given lexical semantics
Could not find a proof: return 'neutral'
Could not find a proof: return 'neutral'
236: can't get a FOL formula based on the given lexical semantics
253: can't get a FOL formula based on the given lexical semantics
Could not find a proof: return 'neutral'
285: can't get a FOL formula based on the given lexical semantics
317: can't get a FOL formula based on the given lexical semantics
Could not find a proof: return 'neutral'
384: can't get a FOL formula based on t

In [23]:
# result analysis
print("correct problems: ",len(data[data['label'] == data['result']]))
print("total problems: ",len(data))
print("accuracy: ",len(data[data['label'] == data['result']])/len(data))
print("wrong problems: \n")
wrong_data = data[data['label'] != data['result']]
for i in range(0,len(wrong_data)):
    print(wrong_data.iloc[i]['problem_id'], "label: ", wrong_data.iloc[i]['label'], "   result: ", wrong_data.iloc[i]['result'])
    print("premise: ", wrong_data.iloc[i]['p_sen'])
    print("add_premise: ", wrong_data.iloc[i]['add_fol'])
    print("hypothesis: ", wrong_data.iloc[i]['h_sen'])
    print("\n")

correct problems:  149
total problems:  263
accuracy:  0.5665399239543726
wrong problems: 

211 label:  yes    result:  unknown
premise:  Two dogs are playing by a tree
add_premise:  ['all x.(tree(x) -> plant(x))']
hypothesis:  Two dogs are playing by a plant


717 label:  no    result:  unknown
premise:  A few men in a competition are running outside
add_premise:  []
hypothesis:  A few men in a competition are running indoors


1190 label:  yes    result:  unknown
premise:  A man and a woman are hiking through a wooded area
add_premise:  ['all x.(hike(x) -> walk(x))']
hypothesis:  A man and a woman are walking together through the woods


1192 label:  yes    result:  unknown
premise:  A man and a woman are walking together through the woods
add_premise:  []
hypothesis:  A man and a woman are walking through a wooded area


1266 label:  yes    result:  unknown
premise:  A band is playing on a stage
add_premise:  []
hypothesis:  A band is playing onstage


1333 label:  yes    result:  u