In [None]:
import pandas as pd
import nltk,re
from nltk.tokenize import word_tokenize
from nltk import pos_tag, word_tokenize
from amr_logic_converter import types
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from word_forms.word_forms import get_word_forms
from sympy.logic.boolalg import to_cnf
from sympy.abc import A, B,C
import sympy
from sklearn.metrics import classification_report
from sympy import Symbol,simplify_logic
from pysat.formula import CNF
from pysat.solvers import Solver
from pattern.en import conjugate, lemma, lexeme, PRESENT, SG
import inflect
import numpy as np
from tqdm import tqdm
import spacy
# Load the spaCy English pipeline once at import time; get_substring() uses it
# to read dependency labels (tok.dep_) from a sentence.
nlp = spacy.load("en_core_web_sm")

In [None]:
from transition_amr_parser.parse import AMRParser
from amr_logic_converter import AmrLogicConverter
from transformers import pipeline
# Load the pretrained AMR parser checkpoint named below. Per the original note
# this downloads the model and caches it (presumably only on first use —
# TODO confirm against the transition_amr_parser docs).
parser = AMRParser.from_pretrained('AMR3-joint-ontowiki-seed43')
# AMR -> logic converter shared by generate_logic(); instances are not
# existentially quantified and relations are inverted (see the keyword flags).
converter = AmrLogicConverter(existentially_quantify_instances=False,invert_relations=True)

In [None]:
from sentence_transformers import SentenceTransformer, util
# Sentence-embedding model used by score() to compare two strings.
model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')


In [None]:
# generate PL
def generate_logic(data):
    """Parse every sentence in ``data`` and convert the parses to logic.

    Each sentence is tokenized with ``parser.tokenize`` and the whole batch is
    handed to ``parser.parse_sentences``, which yields ``annotations`` and
    ``machines``. Every machine is additionally rendered through
    ``get_amr().to_penman(jamr=False, isi=False)``. Both renderings are then
    passed through ``converter.convert``.

    Args:
        data: iterable of sentence strings.

    Returns:
        A 4-tuple ``(annotations, penman_graphs, logic_from_annotations,
        logic_from_penman)``; the last two hold one converted formula per
        sentence, in input order.
    """
    token_batches = []
    for sentence in data:
        tokens, _positions = parser.tokenize(sentence)
        token_batches.append(tokens)

    annotations, machines = parser.parse_sentences(token_batches)
    penman_graphs = [
        machine.get_amr().to_penman(jamr=False, isi=False) for machine in machines
    ]

    logic_from_annotations = []
    logic_from_penman = []
    # Convert pairwise so the two result lists stay aligned with ``data``.
    for idx, _ in enumerate(data):
        logic_from_annotations.append(converter.convert(annotations[idx]))
        logic_from_penman.append(converter.convert(penman_graphs[idx]))
    return annotations, penman_graphs, logic_from_annotations, logic_from_penman

In [None]:
# get substring for dyadic predicate
def get_substring(s,w1,w2):
    """Return the part of sentence ``s`` spanning a mention of ``w1`` and of ``w2``.

    For each word a set of accepted surface tokens is built from:
      * the word itself (numeric strings are first spelled out with inflect),
      * every form returned by ``get_word_forms(word, 0.7)``,
      * the prepositions "on"/"at"/"in" when the word is "be-located-at",
      * the sentence's ``nsubj`` tokens when the word is "person".

    The lower-cased, tokenized sentence is then scanned; the LAST position
    matching each word wins, and a token that matches ``w1`` is never also
    counted for ``w2``.

    Args:
        s: the raw sentence.
        w1: first predicate/argument name.
        w2: second predicate/argument name.

    Returns:
        The space-joined tokens from the earlier match to the later match
        (inclusive), or ``False`` when either word has no match.
    """
    def _candidates(word):
        """Collect the surface tokens accepted as a mention of ``word``."""
        if word.isnumeric():
            word = inflect.engine().number_to_words(word)
        extra = []
        if word == "be-located-at":
            extra += ["on", "at", "in"]
        if word == "person":
            # The sentence is matched in lower case below, so the subject
            # tokens must be lower-cased too or "He"/"John" never match.
            extra += [str(tok).lower() for tok in nlp(s) if tok.dep_ == "nsubj"]
        # One lookup per word (the result is a mapping of part-of-speech -> forms).
        forms_by_pos = get_word_forms(word, 0.7)
        forms = [form for forms in forms_by_pos.values() for form in forms]
        return set(forms + [word] + extra)

    sub1 = _candidates(w1)
    sub2 = _candidates(w2)

    token = word_tokenize(s.lower())
    search1 = None
    search2 = None
    for idx, tok in enumerate(token):
        if tok in sub1:
            search1 = idx
        elif tok in sub2:
            search2 = idx

    if search1 is None or search2 is None:
        return False

    start, end = sorted((search1, search2))
    return " ".join(token[start:end + 1])
        
 


In [None]:
# fold a nested literal list into one boolean expression (nested lists are negated)
def combine(final,f = False):
    """AND together the items of ``final``; a nested list contributes its negation.

    Plain items are folded in with ``&``. A nested list is combined
    recursively (``f`` is not forwarded) and then inverted with ``~``:
      * inversion equal to -1 (nested value was False/0) contributes True,
      * inversion equal to -2 (nested value was True/1) contributes False,
        or True when ``f`` is truthy,
      * any other inversion is ANDed in as is.

    Args:
        final: list of operands and/or nested lists.
        f: when truthy, a nested list that evaluates to True is ignored
           instead of making the result False.

    Returns:
        The accumulated ``&`` of all contributions, starting from True.
    """
    result = True
    for item in final:
        if not type(item) == list:
            result = result & item
            continue

        nested = True & combine(item)
        inverted = ~nested
        if inverted == -1:
            result = result & True
        elif inverted == -2:
            result = result & (True if f else False)
        else:
            result = result & inverted
    return result
    

In [None]:
import copy
# replace predicates and relations in a nested formula with their assigned symbols
def transform(formula,Var,X):
    """Return a deep copy of ``formula`` with known entries replaced via ``X``.

    Args:
        formula: nested list as produced by extract(): plain entries, nested
            lists, and relation entries of the form ``["ARG", a, b, symbol]``.
        Var: mapping used to translate the two relation arguments ``a`` and
            ``b`` before the lookup.
        X: substitution table. Plain entries are looked up directly;
            relation entries are looked up under
            ``"<Var[a]> <Var[b]> <symbol>"``.

    Returns:
        The copied structure. Entries without a match in ``X`` are kept
        unchanged; nested non-relation lists are transformed recursively.
        ``formula`` itself is not modified.
    """
    result = copy.deepcopy(formula)
    for idx, node in enumerate(result):
        if type(node) == list:
            if node[0] == "ARG":
                key = " ".join([Var[node[1]], Var[node[2]], node[3]])
                if key in X:
                    result[idx] = X[key]
            else:
                result[idx] = transform(node, Var, X)
        elif node in X:
            result[idx] = X[node]
    return result
        


In [None]:
# extract predicates from PL 
def extract(formula):
    """Flatten a logic formula into predicate names, a variable map and relations.

    A negated (sub-)formula (``types.Not``) is extracted recursively and its
    predicate list is kept as a nested list. Each negated body is extracted
    exactly once, so the cost stays linear in the formula size.

    For every other term of ``formula.args``:
      * if the predicate symbol starts with ":" it is treated as a relation
        and recorded as ``["ARG", *term_values, symbol]`` in the predicate
        list and as ``[*term_values, symbol]`` in the relation list;
      * otherwise the symbol, with digits (and hyphens directly before a
        digit) stripped, is recorded and the first term's value is mapped to
        that cleaned name.

    Args:
        formula: a ``types.Not`` or an object exposing ``.args``.

    Returns:
        ``(and_list, var, arg)``: the (nested) predicate list, the mapping
        from term value to cleaned predicate name (entries already present
        win over those coming from a negated sub-formula), and the relations.
    """
    if type(formula) == types.Not:
        sub_list, sub_var, sub_arg = extract(formula.body)
        return [sub_list], {**sub_var}, list(sub_arg)

    and_list = []
    var = {}
    arg = []
    for term in formula.args:
        if type(term) == types.Not:
            sub_list, sub_var, sub_arg = extract(term.body)
            and_list.append(sub_list)
            var = {**sub_var, **var}
            arg = arg + sub_arg
            continue

        symbol = term.predicate.symbol
        if symbol[0] == ":":
            values = [t.value for t in term.terms]
            # Two independent lists: callers mutate the entries of ``arg``.
            and_list.append(["ARG"] + values + [symbol])
            arg.append(values + [symbol])
        else:
            name = re.sub(r'\-*[0-9]', "", symbol)
            and_list.append(name)
            var[term.terms[0].value] = name
    return and_list, var, arg
    

In [None]:
# calculate sentences/words similarity
def score(s1,s2):
    """Return the cosine similarity between the embeddings of ``s1`` and ``s2``.

    Both strings are encoded separately with the module-level ``model``
    (as tensors, progress bar disabled); the result is element ``[0][0]`` of
    ``util.pytorch_cos_sim``.
    """
    first_embedding, second_embedding = (
        model.encode(text, convert_to_tensor=True, show_progress_bar=False)
        for text in (s1, s2)
    )
    similarity = util.pytorch_cos_sim(first_embedding, second_embedding)
    return similarity[0][0]

In [None]:
# generate PYSAT formula
def pysat_formula(formula):
    """Turn the string form of a CNF formula into a list of integer clauses.

    The text of ``formula`` is expected to look like
    ``"x1 & ~x2 & (x3 | ~x4)"``: clauses joined by " & ", literals joined by
    " | ", variables named ``x<number>`` and negation written ``~x<number>``.

    Handled edge cases:
      * a formula that is one bare disjunction (``"x1 | ~x2"``) is parsed as
        a single clause instead of failing in ``int()``;
      * the constant ``"True"`` yields no clauses (always satisfiable);
      * any other text without literals (e.g. ``"False"``) yields one empty
        clause.

    Args:
        formula: any object whose ``str()`` has the format above.

    Returns:
        A list of clauses; each clause is a list of non-zero ints, negative
        for negated variables.
    """
    text = str(formula)
    if text == "True":
        return []

    clauses = []
    for clause_text in text.split(" & "):
        clause = []
        for literal in clause_text.replace("(", "").replace(")", "").split(" | "):
            literal = literal.strip()
            if literal.startswith("~"):
                # Skip the "~x" prefix and negate the variable index.
                clause.append(-int(literal[2:]))
            elif literal.startswith("x"):
                clause.append(int(literal[1:]))
        clauses.append(clause)
    return clauses

In [None]:
# substitute variables according to similarity
def subsitute(x,y,replaceX,replaceXX,maxx,i,j,thre):
    """Adopt ``replaceX[j]`` for key ``i`` when ``x`` and ``y`` are similar enough.

    The similarity ``score(x, y)`` must reach ``thre`` and be strictly
    greater than the best similarity stored so far in ``maxx[i]``.

    Args:
        x, y: the two strings to compare.
        replaceX: table providing the replacement value under key ``j``.
        replaceXX: table updated in place under key ``i``.
        maxx: best similarity seen so far per key; updated in place.
        i: key into ``replaceXX`` and ``maxx``.
        j: key into ``replaceX``.
        thre: minimum similarity required.

    Returns:
        True when the tables were updated, False otherwise.
    """
    similarity = score(x, y)
    if similarity >= thre and similarity > maxx[i]:
        maxx[i] = similarity
        replaceXX[i] = replaceX[j]
        return True
    return False
            

In [None]:
# relaxed proposition entailment/contradiction to classify RTE dataset
def _as_bool(value):
    """Collapse the integer sentinels -1/1 to True and -2/0 to False.

    Any other value (for example a SymPy expression) is returned unchanged.
    """
    if value == -1 or value == 1:
        return True
    if value == -2 or value == 0:
        return False
    return value


def _resolve_relation_args(arg_rows, variables):
    """Replace, in place, every relation entry by its name from ``variables``.

    Entries without a name yet (constants, relation symbols) are registered
    in ``variables`` as mapping to themselves.
    """
    for row in arg_rows:
        for pos in range(len(row)):
            if row[pos] in variables:
                row[pos] = variables[row[pos]]
            else:
                variables[row[pos]] = row[pos]


def _assign_symbols(variables, arg_rows, table, n):
    """Give every unseen predicate name and relation a fresh symbol ``x<n>``.

    Names starting with ":" are skipped. Relations are keyed by their
    space-joined entries. Returns the next unused symbol index.
    """
    for key in variables:
        name = variables[key]
        if name[0] == ":":
            continue
        if name not in table:
            table[name] = Symbol('x'+str(n))
            n += 1
    for row in arg_rows:
        joined = " ".join(row)
        if joined not in table:
            table[joined] = Symbol('x'+str(n))
            n += 1
    return n


# relaxed proposition entailment/contradiction to classify RTE dataset
def prove(data,sent):
    """Classify a premise/claim pair (optionally with an explanation).

    Predicates and relations of the premise (and explanation) receive
    propositional symbols. Claim predicates reuse the symbol of an identical
    premise predicate, or of the most similar one when the similarity
    reaches the threshold 0.55; otherwise they get a fresh symbol. Two SAT
    checks are then run with Minisat22:
      * premise (and explanation) AND NOT claim  -> unsatisfiable suggests entailment
      * reduced premise (and explanation) AND claim -> unsatisfiable suggests contradiction

    Args:
        data: logic formulas ``[premise, claim]`` or
            ``[premise, claim, explanation]`` as accepted by extract().
        sent: the raw sentences in the same order; the trailing label some
            callers leave in the list is not read.

    Returns:
        "ent", "con", "both" (both checks unsatisfiable) or "neu".
    """
    # extract predicate from premise
    for0, checkVaribale0, checkArg0 = extract(data[0])
    _resolve_relation_args(checkArg0, checkVaribale0)

    replaceX = {}
    n = _assign_symbols(checkVaribale0, checkArg0, replaceX, 1)

    explanation = len(data) > 2
    if explanation:
        # extract predicate from explanation
        for2, checkVaribale2, checkArg2 = extract(data[2])
        _resolve_relation_args(checkArg2, checkVaribale2)
        n = _assign_symbols(checkVaribale2, checkArg2, replaceX, n)

    # extract predicate from claim
    for1, checkVaribale11, checkArg11 = extract(data[1])

    replaceXX = {}
    thre = 0.55
    _resolve_relation_args(checkArg11, checkVaribale11)

    # Best similarity found so far for every claim predicate / relation.
    maxx = {}
    for i in checkArg11:
        maxx[" ".join(i)] = 0
    for i in checkVaribale11:
        maxx[checkVaribale11[i]] = 0

    # Map claim predicates onto premise symbols.
    for i in checkVaribale11:
        name = checkVaribale11[i]
        if name[0] == ":":
            continue
        if name in replaceX:
            replaceXX[name] = replaceX[name]
        else:
            for j in replaceX:
                # Only single-word premise entries are predicate candidates.
                if len(j.split())>1:
                    continue
                subsitute(name,j,replaceX,replaceXX,maxx,name,j,thre)

            if maxx[name] == 0:
                replaceXX[name] = Symbol('x'+str(n))
                n += 1

    # Map claim relations onto premise symbols.
    for i in checkArg11:
        key = " ".join(i)
        if key in replaceXX:
            continue
        if key in replaceX:
            replaceXX[key] = replaceX[key]
        else:
            for j in replaceX:
                pair = " ".join([i[0],i[-2]])
                j_parts = j.split()
                if len(j_parts)<3:
                    subsitute(j,pair,replaceX,replaceXX,maxx,key,j,thre)
                else:
                    # Compare the sentence spans covering the two relations
                    # when they can be located, else the argument names.
                    tems3 = False
                    tems1 = get_substring(sent[1],i[0],i[-2])
                    tems2 = get_substring(sent[0],j_parts[0],j_parts[-2])
                    if explanation:
                        tems3 = get_substring(sent[2],j_parts[0],j_parts[-2])

                    if not tems1:
                        if tems2:
                            subsitute(pair,tems2,replaceX,replaceXX,maxx,key,j,thre)
                        if tems3:
                            subsitute(pair,tems3,replaceX,replaceXX,maxx,key,j,thre)
                        subsitute(pair," ".join([j_parts[0],j_parts[1]]),replaceX,replaceXX,maxx,key,j,thre)

                    if tems2 and tems1:
                        subsitute(tems1,tems2,replaceX,replaceXX,maxx,key,j,thre)

                    if tems3 and tems1:
                        subsitute(tems1,tems3,replaceX,replaceXX,maxx,key,j,thre)

            if maxx[key] == 0:
                replaceXX[key] = Symbol('x'+str(n))
                n += 1
            else:
                # The relation was matched as a whole, so its two endpoints
                # no longer constrain the claim on their own.
                if i[0] in replaceXX:
                    replaceXX[i[0]] = True
                if i[-2] in replaceXX:
                    replaceXX[i[-2]] = True

    new_rex = {}
    for i in replaceXX:
        parts = i.split()
        if i == "and":
            new_rex[i] = True
        # ``elif`` keeps the True stored for a bare "and"; with a plain
        # ``if`` the final ``else`` branch overwrote it on the same pass.
        elif parts[0]=="and" and parts[-1][:3]==":op":
            new_rex[i] = new_rex[parts[1]]
        else:
            new_rex[i] = replaceXX[i]

    # Keep only the premise symbols that also occur in the claim; every other
    # premise entry is replaced by True.
    new_re = {}
    for i in replaceX:
        shared = False
        for value in new_rex.values():
            target = ~value if isinstance(value, sympy.Not) else value
            if target == replaceX[i]:
                shared = True
        new_re[i] = replaceX[i] if shared else True

    formula0 = combine(transform(for0,checkVaribale0, replaceX))
    formula11 = _as_bool(combine(transform(for1,checkVaribale11, new_rex)))

    # sympy.Not also negates plain Python booleans; ``~True`` would be the
    # integer -2 rather than a logical negation.
    negated_claim = sympy.Not(formula11)
    if explanation:
        formula2 = combine(transform(for2,checkVaribale2, replaceX))
        final_formula = to_cnf( (formula0 & formula2 ) & negated_claim)
    else:
        final_formula = to_cnf( formula0 & negated_claim)
    cnf = CNF(from_clauses=pysat_formula(final_formula))

    formula00 = _as_bool(combine(transform(for0,checkVaribale0, new_re)))
    if explanation:
        formula22 = _as_bool(combine(transform(for2,checkVaribale2, new_re)))
        final_formula11 = to_cnf( formula00 & formula22 & formula11)
    else:
        final_formula11 = to_cnf( formula00 & formula11)
    cnf11 = CNF(from_clauses=pysat_formula(final_formula11))

    with Solver(name = "Minisat22",bootstrap_with=cnf) as solver:
        check_ent = solver.solve()

    with Solver(name = "Minisat22",bootstrap_with=cnf11) as solver:
        check_con1 = solver.solve()

    # do classification
    if not check_ent and check_con1:
        return "ent"
    elif not check_con1 and check_ent:
        return "con"
    elif not check_con1 and not check_ent:
        return "both"
    else:
        return "neu"
    
   

In [None]:
# np.random.seed(66)

In [None]:
# e-SNLI dataset
# Later cells read column 1 as the gold label and columns 2-4 as the three
# sentences of each example (presumably premise, hypothesis and explanation —
# TODO confirm against the CSV header).
df = pd.read_csv('esnli_train_1.csv')

In [None]:

# Bucket the e-SNLI rows by gold label. Every kept example is
# [column 2, column 3, column 4, short label].
sent_ent = []
sent_con = []
sent_neu = []
# set max length
length = 2999999
# Look at no more than the first 259999 rows, and never past the end of df.
for i in range(0, min(len(df), 259999)):
    try:
        texts = [df.iloc[i,2], df.iloc[i,3], df.iloc[i,4]]
        if any(len(word_tokenize(text)) > length for text in texts):
            continue
    except Exception:
        # Rows whose text cannot be tokenized (e.g. missing values) are
        # skipped. ``Exception`` rather than a bare ``except`` keeps Ctrl-C
        # working during the scan.
        continue

    label = df.iloc[i,1]
    if label == "entailment":
        sent_ent.append(texts + ['ent'])
    elif label == "contradiction":
        sent_con.append(texts + ['con'])
    elif label == "neutral":
        # Explicit check: rows with a missing or unexpected label are dropped
        # instead of being counted as neutral.
        sent_neu.append(texts + ['neu'])


In [None]:
# Draw 50 examples per class without replacement, in the order ent, neu, con.
# np.random is not seeded in this cell, so the sample differs between runs.
esnli_sent = [
    bucket[pick]
    for bucket in (sent_ent, sent_neu, sent_con)
    for pick in np.random.choice(len(bucket), 50, replace=False)
]

In [None]:
# Sanity check: number of sampled e-SNLI examples.
len(esnli_sent)

In [None]:
# Convert every sampled example to logic up front. Index -2 of the
# generate_logic() result is the list converted from the parser annotations;
# the trailing label of each example is not parsed.
pre_data = []
for example in tqdm(esnli_sent):
    sentences = example[:-1]
    logic_forms = generate_logic(sentences)[-2]
    pre_data.append(logic_forms)

In [None]:
# Run the prover on every sampled e-SNLI example.
# ll collects predictions and gl the matching gold labels; examples that
# raise or come back as "both" are left out of both lists.
ll = []
gl = []
for i in tqdm(range(len(esnli_sent))):
    try:
        tem = prove(pre_data[i][:],esnli_sent[i])
    except Exception as err:
        # ``Exception`` (not a bare ``except``) so the run can still be
        # interrupted; the error is shown so failures can be diagnosed.
        print("exception", repr(err))
        continue
    if tem == "both":
        print("both")
        continue
    ll.append(tem)
    gl.append(esnli_sent[i][-1])

    

In [None]:
# get classification/evaluation report
# Gold labels vs. predictions for the e-SNLI run.
testy, yhat_classes = gl, ll

accuracy = accuracy_score(testy, yhat_classes)
print('Accuracy: %f' % accuracy)

report = classification_report(testy, yhat_classes)
print(report)

In [None]:
# Confusion matrix with classes fixed to the order ent, con, neu
# (scikit-learn puts true labels on the rows and predictions on the columns).
confusion_matrix(testy, yhat_classes,labels = ["ent","con","neu"])

In [None]:
# SICK dataset
from datasets import load_dataset

# Load the three SICK splits in the order train, test, validation.
dataset = [
    load_dataset("sick", split=split_name)
    for split_name in ("train", "test", "validation")
]

In [None]:
# Bucket the SICK pairs by label (0 -> ent, 1 -> neu, 2 -> con).
# Every kept example is [first sentence, second sentence, short label].
sent_ent = []
sent_con = []
sent_neu = []
length = 10
for j in dataset:
    for i in j:
        # Drop the pair when EITHER sentence is longer than `length` tokens.
        if len(word_tokenize(i["sentence_A"]))>length or len(word_tokenize(i["sentence_B"]))>length:
            continue

        # NOTE(review): the direction field is compared after strip() so the
        # match works whether or not the stored value has a leading space —
        # confirm the raw format of `entailment_AB`.
        direction = i["entailment_AB"].strip()
        if i["label"] == 0:
            if direction == 'A_entails_B':
                sent_ent.append([i["sentence_A"],i["sentence_B"],'ent'])
            else:
                sent_ent.append([i["sentence_B"],i["sentence_A"],'ent'])
        elif i["label"] == 1:
            if direction == 'A_neutral_B':
                sent_neu.append([i["sentence_A"],i["sentence_B"],'neu'])
            else:
                sent_neu.append([i["sentence_B"],i["sentence_A"],'neu'])
        elif i["label"] == 2:
            if direction == 'A_contradicts_B':
                sent_con.append([i["sentence_A"],i["sentence_B"],'con'])
            else:
                sent_con.append([i["sentence_B"],i["sentence_A"],'con'])


In [None]:

# Draw 10 examples per class without replacement, in the order ent, neu, con.
# np.random is not seeded in this cell, so the sample differs between runs.
sick_sent = [
    bucket[pick]
    for bucket in (sent_ent, sent_neu, sent_con)
    for pick in np.random.choice(len(bucket), 10, replace=False)
]




In [None]:
# Run the prover on every sampled SICK pair.
# ll_sick collects predictions and gl_sick the matching gold labels; pairs
# that raise or come back as "both" are left out of both lists.
ll_sick = []
gl_sick = []
for i in tqdm(sick_sent):
    try:
        # Local name on purpose: the e-SNLI evaluation cell indexes the
        # global `pre_data` list, which must not be overwritten here.
        logic_forms = generate_logic(i[:-1])[-2]
        tem = prove(logic_forms,i)
    except Exception as err:
        # ``Exception`` (not a bare ``except``) so the run can still be
        # interrupted; the error is shown so failures can be diagnosed.
        print("exception", repr(err))
        continue
    if tem == "both":
        print("both")
        continue
    ll_sick.append(tem)
    gl_sick.append(i[-1])

In [None]:
from sklearn.metrics import classification_report

# Gold labels vs. predictions for the SICK run.
testy, yhat_classes = gl_sick, ll_sick

accuracy = accuracy_score(testy, yhat_classes)
print('Accuracy: %f' % accuracy)

report = classification_report(testy, yhat_classes)
print(report)

In [None]:
# multiNLI datasets
import json

# Read the MultiNLI training file: one JSON object per line.
data_mnli = []
# JSON text is UTF-8; naming the encoding avoids depending on the platform
# default when the file contains non-ASCII characters.
with open('multinli_1.0_train.jsonl', encoding='utf-8') as f:
    for line in f:
        # Tolerate blank lines (e.g. a trailing newline at end of file).
        if line.strip():
            data_mnli.append(json.loads(line))

In [None]:
# Bucket the MultiNLI pairs by gold label.
# Every kept example is [sentence1, sentence2, short label].
sent_ent = []
sent_con = []
sent_neu = []
length = 10
for i in data_mnli:
    # Drop the pair when either sentence is longer than `length` tokens.
    if len(word_tokenize(i["sentence1"]))>length or len(word_tokenize(i["sentence2"]))>length:
        continue

    if i["gold_label"] == "entailment":
        sent_ent.append([i["sentence1"],i["sentence2"],'ent'])
    elif i["gold_label"] == "contradiction":
        sent_con.append([i["sentence1"],i["sentence2"],'con'])
    elif i["gold_label"] == "neutral":
        # Explicit check: records with any other gold_label value are dropped
        # instead of being counted as neutral.
        sent_neu.append([i["sentence1"],i["sentence2"],'neu'])


In [None]:
# Draw 10 examples per class without replacement, in the order ent, neu, con.
# np.random is not seeded in this cell, so the sample differs between runs.
mnli_sent = [
    bucket[pick]
    for bucket in (sent_ent, sent_neu, sent_con)
    for pick in np.random.choice(len(bucket), 10, replace=False)
]

In [None]:
# Run the prover on every sampled MultiNLI pair.
# ll_mn collects predictions and gl_mn the matching gold labels; pairs that
# raise or come back as "both" are left out of both lists.
ll_mn = []
gl_mn = []
for i in tqdm(mnli_sent):
    try:
        # Local name on purpose: the e-SNLI evaluation cell indexes the
        # global `pre_data` list, which must not be overwritten here.
        logic_forms = generate_logic(i[:-1])[-2]
        tem = prove(logic_forms,i)
    except Exception as err:
        # ``Exception`` (not a bare ``except``) so the run can still be
        # interrupted; the error is shown so failures can be diagnosed.
        print("exception", repr(err))
        continue
    if tem == "both":
        print("both")
        continue
    ll_mn.append(tem)
    gl_mn.append(i[-1])

In [None]:
from sklearn.metrics import classification_report

# Gold labels vs. predictions for the MultiNLI run.
testy, yhat_classes = gl_mn, ll_mn

accuracy = accuracy_score(testy, yhat_classes)
print('Accuracy: %f' % accuracy)

report = classification_report(testy, yhat_classes)
print(report)