In [None]:
import os
import re
import sys

# Set the path manually to include the sources location, so that the project
# packages under `src/` can be imported by the cells below.
if 'src/' not in sys.path:
    sys.path.append('src/')


In [30]:
# Load the Java classpath for Stanford CoreNLP using Gradle. This will also install it if missing.
from subprocess import run, PIPE

CLASSPATH_FILE = 'build/classpath.txt'

if 'CLASSPATH' not in os.environ:
    # The classpath file is generated on demand; checking the file alone is enough,
    # because it cannot exist without its parent directory.
    if not os.path.exists(CLASSPATH_FILE):
        print("Generating classpath")
        r = run(["./gradlew", "writeClasspath"], stdout=PIPE, stderr=PIPE, universal_newlines=True)
        print(r.stdout)
        print(r.stderr)

    print("Loading classpath")
    # Use a context manager so the file handle is closed deterministically.
    with open(CLASSPATH_FILE, 'r') as classpath_file:
        os.environ['CLASSPATH'] = classpath_file.read()
    print("Done")

In [31]:
from classifier.LogisticRegressionClassifier import LogisticRegressionClassifier
from classifier.features.generate_features import FeatureGenerator, num
from distant_supervision.utterance_detection import f_threshold_match
from factchecking.question import Question
from tabular.filtering import load_collection

In [32]:
# Build the feature generator and produce the training data that is handed to
# classifier.train() in a later cell (presumably features `Xs` and labels `ys`).
# The recorded output shows one search per training query as it runs.
fg = FeatureGenerator()
Xs,ys = fg.generate_training()

Done: 0.0
Search for "Exxon Mobil" "Market Value"
Query already executed
Done: 6.25
Search for "Unaccompanied children" "claimed asylum"
Query already executed
Done: 12.5
Search for "Hamas" "Founded"
Query already executed
Done: 18.75
Search for "United States" "Average Temperature"
Query already executed
Done: 25.0
Search for "United States" "Life expectancy"
Query already executed
Done: 31.25
Search for "United States" "Number of abortions"
Query already executed
Done: 37.5
Search for "United States" "Abortion Rate per 1,000 births"
Query already executed
Done: 43.75
Search for "United States Teenagers" "Percentage Enrolled in education"
Query already executed
Done: 50.0
Search for "United States Teenagers" "Enrolled in education"
Query already executed
Done: 56.25
Search for "America" "bee colonies 201"
Query already executed
Done: 62.5
Search for "United States" "Financial Intermediary Funds 2016"
Query already executed
Done: 68.75
Search for "United States" "Homocides by firearm"


In [33]:
# Fit the logistic-regression classifier on the training data produced by
# fg.generate_training(); fact_check() later calls classifier.predict().
classifier = LogisticRegressionClassifier()
classifier.train(Xs,ys)


Training classifier
Trained


In [34]:
# Load the "herox" table collection; fact_check() passes `tables` to fg.generate_test().
tables = load_collection("herox")

register table herox/1.csv
register table herox/2.csv
register table herox/3.csv
register table herox/4.csv
register table herox/5.csv
register table herox/8.csv
register table herox/9.csv
register table herox/10.csv
register table herox/11.csv
register table herox/12.csv
register table herox/13.csv
register table herox/14.csv


In [35]:
def fact_check(q):
    """Fact-check a numeric claim against the loaded tables and print the outcome.

    The claim is wrapped in a ``Question`` of type ``"NUM"``; candidate tuples and
    their feature vectors come from ``fg.generate_test(tables, question)``. Every
    candidate whose value cell (``candidate[1][2]``) is non-empty and which the
    classifier labels ``1`` is printed as a "Possible Match". Its value cell is
    then compared against the numbers in the claim (``f_threshold_match`` with a
    tolerance of 0.05) and against the dates in the claim (exact equality).

    Relies on the notebook-level ``fg``, ``tables`` and ``classifier`` objects.

    Args:
        q: Text of the claim to check.

    Returns:
        None. The claim text and the overall verdict (``True``/``False``), or a
        "no supporting information" message when there are no candidates, are
        printed to stdout.
    """
    question = Question(text=q, type="NUM")
    candidates, q_features = fg.generate_test(tables, question)

    if len(candidates) == 0:
        print(question.text)
        print("No supporting information can be found in the knowledge base")
        print("\n\n")
        return

    q_match = False
    q_predicted = classifier.predict(q_features)

    for i, candidate in enumerate(candidates):
        cell = candidate[1][2]
        if len(cell) > 0:
            if q_predicted[i] == 1:
                print(str(candidate) + "\t\t" + "Possible Match")

                # Keep only digits and dots of the cell; this does not depend on the
                # number/date being compared, so it is computed once per candidate.
                digits = re.sub(r"[^0-9\.]+", "", cell.replace(",", ""))

                for number in question.numbers:
                    value = num(digits)

                    if value is None:
                        continue

                    if f_threshold_match(number, value, 0.05):
                        print(str(candidate) + "\t\t" + "Threshold Match to 5%")
                        q_match = True

                for date in question.dates:
                    if date == num(digits):
                        print(str(candidate) + "\t\t" + "Exact Match")
                        q_match = True

    print(question.text)
    print(q_match)
    print("\n\n")

In [57]:
# Example claim. In the recorded run the classifier flags the herox/11.csv value 11078 as a
# possible match, but no "Threshold Match"/"Exact Match" line is printed, so the verdict is False.
# NOTE(review): presumably 11078 falls outside the 5% tolerance around 10,000 — confirm.
fact_check("In the USA 2014, the number of homicides by firearm was almost 10,000")

('herox/11.csv', ('Firearm', 'USA', '11078'))		Possible Match
('herox/11.csv', ('Firearm', 'USA', '11078'))		Possible Match
In the USA 2014, the number of homicides by firearm was almost 10,000
False





In [67]:
import csv
import numpy as np
import re

def read_table(filename,base="data/WikiTableQuestions"):
    """Read a tab-separated table file into a header row and a list of data rows.

    Any ``.csv`` in ``filename`` is replaced by ``.tsv`` before the file is opened
    at ``base + "/" + filename`` (UTF-8).

    Args:
        filename: Table path relative to ``base``.
        base: Directory that contains the table files.

    Returns:
        ``{"header": [...], "rows": [[...], ...]}``. ``header`` is the first row
        when ``csv.Sniffer().has_header`` reports one, otherwise ``[]`` and the
        first row is kept in ``rows``. An empty file yields an empty header and
        no rows.

    Raises:
        csv.Error: If the sniffer cannot analyse the (non-empty) first line.
    """
    header = []
    rows = []

    filename = filename.replace(".csv",".tsv")
    with open(base+"/"+filename,"r",encoding='utf-8') as table:
        first_line = table.readline()
        if not first_line:
            # Empty file: the sniffer would raise on an empty sample, and there
            # is nothing to parse anyway.
            return {"header": header, "rows": rows}

        # NOTE(review): has_header() only sees the first line here, so it has no
        # data rows to compare the header against — confirm the detection is
        # meaningful or feed it a larger sample.
        has_header = csv.Sniffer().has_header(first_line)
        table.seek(0)

        reader = csv.reader(table, delimiter="\t")
        if has_header:
            header = next(reader, [])
        rows.extend(reader)
    return {"header": header, "rows":rows}

def table_nes(table):
    """Collect the cell texts of the table columns that look like named-entity columns.

    Each column's cells are joined with ``". "`` and annotated with the shared NER
    pipeline; every sentence of the annotated text is then treated as one cell. A
    cell counts as an entity cell when at least one of its tokens has an NE tag
    other than ``O``, ``NUMBER`` or ``NUMERIC``. A column is kept when at least
    half of its collected cell texts come from entity cells.

    Depends on ``Annotation``, ``SharedNERPipeline``, ``CoreAnnotations`` and
    ``transpose`` being available in the notebook namespace.

    Args:
        table: Dict with a ``'rows'`` list of rows (as returned by ``read_table``).

    Returns:
        Flat list of cell texts from all kept columns, in column order.
    """
    rows = table['rows']

    ret_tokens = []
    for column in transpose(rows):
        doc = Annotation(". ".join(column))
        SharedNERPipeline().getInstance().annotate(doc)
        sentences = doc.get(CoreAnnotations.SentencesAnnotation)

        num_ne_cell = 0
        tokens = []
        for s in range(sentences.size()):
            sentence_tokens = sentences.get(s).get(CoreAnnotations.TokensAnnotation)

            words = []
            has_ne = False
            for i in range(sentence_tokens.size()):
                corelabel = sentence_tokens.get(i)
                ne = corelabel.get(CoreAnnotations.NamedEntityTagAnnotation)

                words.append(corelabel.get(CoreAnnotations.TextAnnotation))
                if ne not in ['O','NUMBER','NUMERIC']:
                    has_ne = True

            # The last token of every sentence is dropped (presumably the "."
            # separator added when joining) and single-token sentences are skipped.
            # NOTE(review): the final cell of a column gets no trailing separator —
            # confirm it is handled as intended.
            if len(words) > 1:
                tokens.append(" ".join(words[:-1]))

            if has_ne:
                num_ne_cell += 1

        if tokens and num_ne_cell >= len(tokens)/2:
            ret_tokens.extend(tokens)

    return ret_tokens


def number_tuples(table):
    """Pair every entity column with every numeric/date column of a table.

    Each column's cells are joined with ``". "`` and annotated with the shared NER
    pipeline; every sentence of the annotated text is treated as one cell. Per
    cell, the NE tags of its tokens decide whether it holds an entity (any tag
    other than ``O``/``NUMBER``/``NUMERIC``/``DATE``/``YEAR``), a date
    (``YEAR``/``DATE``) or a number (``NUMBER``/``NUMERIC``/``PERCENTAGE``/
    ``ORDINAL``). A column qualifies for a category when at least half of its
    collected cell texts belong to it. Date columns are treated as numeric
    columns here.

    Depends on ``Annotation``, ``SharedNERPipeline``, ``CoreAnnotations`` and
    ``transpose`` being available in the notebook namespace.

    Args:
        table: Dict with ``'header'`` and ``'rows'`` (as returned by ``read_table``).

    Returns:
        List of ``(numeric column header, entity cell, numeric cell)`` tuples, one
        per row for every (entity column, numeric column) pair.
    """
    header = table['header']
    rows = table['rows']

    entity_col = []
    number_col = []

    table_trans = transpose(rows)
    for col_id, column in enumerate(table_trans):
        doc = Annotation(". ".join(column))
        SharedNERPipeline().getInstance().annotate(doc)
        sentences = doc.get(CoreAnnotations.SentencesAnnotation)

        num_ne_cell = 0
        num_date_cell = 0
        num_number_cell = 0

        tokens = []
        for s in range(sentences.size()):
            sentence_tokens = sentences.get(s).get(CoreAnnotations.TokensAnnotation)

            words = []
            has_ne = False
            has_date = False
            has_number = False
            for i in range(sentence_tokens.size()):
                corelabel = sentence_tokens.get(i)
                ne = corelabel.get(CoreAnnotations.NamedEntityTagAnnotation)

                words.append(corelabel.get(CoreAnnotations.TextAnnotation))
                if ne not in ['O','NUMBER','NUMERIC','DATE','YEAR']:
                    has_ne = True

                if ne in ['YEAR','DATE']:
                    has_date = True

                if ne in ['NUMBER','NUMERIC','PERCENTAGE','ORDINAL']:
                    has_number = True

            # The last token of every sentence is dropped (presumably the "."
            # separator added when joining) and single-token sentences are skipped.
            if len(words) > 1:
                tokens.append(" ".join(words[:-1]))

            num_ne_cell += has_ne
            num_date_cell += has_date
            num_number_cell += has_number

        if not tokens:
            continue

        half = len(tokens)/2
        if num_ne_cell >= half:
            entity_col.append(col_id)

        # Register the column once even when it passes both the date and the
        # number threshold; adding it twice would duplicate every tuple below.
        if num_date_cell >= half or num_number_cell >= half:
            number_col.append(col_id)

    tuples = []
    for ecol in entity_col:
        for ncol in number_col:
            tuples.extend(zip([header[ncol]] * len(rows), table_trans[ecol], table_trans[ncol]))

    return tuples



def number_entity_tuples(table):
    """Pair entity columns with number columns using one NER pass over the whole table.

    The cells of each column are joined with spaces and the columns are joined
    with ``". "``, so that sentence ``k`` of the annotated text is taken to be
    column ``k``. A sentence marks an entity column when none of its tags is in
    ``number_ne_types`` and at least one tag is a non-numeric, non-``O`` entity
    tag. Any other sentence marks a number column when at least half of its tags
    are in ``number_ne_types``.

    Depends on ``Annotation``, ``SharedNERPipeline``, ``CoreAnnotations``,
    ``transpose`` and ``number_ne_types`` being available in the notebook
    namespace (``number_ne_types`` is not defined in this part of the notebook).

    NOTE(review): the sentence-index == column-index assumption only holds if the
    sentence splitter cuts exactly at the ``". "`` separators — confirm.

    Args:
        table: Dict with ``'header'`` and ``'rows'`` (as returned by ``read_table``).

    Returns:
        List of ``(number column header, entity cell, number cell)`` tuples, one
        per row for every (entity column, number column) pair.
    """
    header = table['header']
    rows = table['rows']
    transposed = transpose(rows)

    doc = Annotation(". ".join(" ".join(column) for column in transposed))
    SharedNERPipeline().getInstance().annotate(doc)
    sentences = doc.get(CoreAnnotations.SentencesAnnotation)

    # Tags that never make a column an entity column.
    non_entity_tags = ('NUMBER','NUMERIC','YEAR','DATE','DURATION','TIME','ORDINAL','O')

    ne_columns = []
    number_columns = []
    for column in range(sentences.size()):
        sentence_tokens = sentences.get(column).get(CoreAnnotations.TokensAnnotation)
        tags = [
            sentence_tokens.get(i).get(CoreAnnotations.NamedEntityTagAnnotation)
            for i in range(sentence_tokens.size())
        ]

        # The overlap with number_ne_types is the same for every tag of the
        # sentence, so it is evaluated once instead of once per tag.
        is_entity = (
            bool(tags)
            and not set(tags).intersection(set(number_ne_types))
            and any(tag not in non_entity_tags for tag in tags)
        )

        if is_entity:
            ne_columns.append(column)
        elif sum(1 for tag in tags if tag in number_ne_types) >= len(tags)/2:
            number_columns.append(column)

    # Sentence indices beyond the real column count cannot be paired.
    width = len(transposed)
    tuples = []
    for column in ne_columns:
        if column >= width:
            continue
        for ncolumn in number_columns:
            if ncolumn >= width:
                continue
            tuples.extend(zip([header[ncolumn]] * len(rows), transposed[column], transposed[ncolumn]))

    return tuples


def number_entity_date_tuples(table):
    """Extract (header, entity, number[, dates]) tuples, or a flattened year-series table.

    First every header cell is annotated; a header cell counts as a date header
    when any of its tokens is tagged ``YEAR`` or ``DATE``. If at least half of the
    header cells are date headers, the table is treated as a year series.

    Year-series tables: returns a flat list of cell strings — for each non-date
    column its cells, followed by the cells of every date-header column.
    NOTE(review): the date-header columns are appended once per non-date column
    and the result is a flat list rather than tuples — confirm this is intended.

    Other tables: each column is annotated (cells joined with ``". "``, one
    sentence per cell) and classified as entity, date and/or number column when
    at least half of its collected cell texts belong to that category. Every
    (entity column, number column) pair yields one tuple per row:
    ``(number header, entity cell, number cell)``, extended by a list with the
    row's values from all date columns when date columns exist.

    Depends on ``Annotation``, ``SharedNERPipeline``, ``CoreAnnotations`` and
    ``transpose`` being available in the notebook namespace.

    Args:
        table: Dict with ``'header'`` and ``'rows'`` (as returned by ``read_table``).

    Returns:
        A list of tuples, or (for year-series tables) a flat list of cell strings.
    """
    header = table['header']
    rows = table['rows']

    table_trans = transpose(rows)

    # Indices of the header cells that contain a YEAR/DATE token.
    hnums = set()
    for hidx, h in enumerate(header):
        doc = Annotation(h)
        SharedNERPipeline().getInstance().annotate(doc)
        sentences = doc.get(CoreAnnotations.SentencesAnnotation)

        for s in range(sentences.size()):
            sentence_tokens = sentences.get(s).get(CoreAnnotations.TokensAnnotation)
            for i in range(sentence_tokens.size()):
                ne = sentence_tokens.get(i).get(CoreAnnotations.NamedEntityTagAnnotation)

                if ne in ['YEAR','DATE']:
                    hnums.add(hidx)

    hseries = len(hnums) >= len(header)/2

    if hseries:
        # The per-column classification below is never used for year-series
        # tables, so return before running the NER pipeline on every column.
        nh = set(range(len(header))).difference(hnums)
        tr = []
        for col in nh:
            tr.extend(table_trans[col])
            for date_idx in hnums:
                tr.extend(table_trans[date_idx])
        return tr

    entity_col = []
    date_col = []
    number_col = []

    for col_id, column in enumerate(table_trans):
        doc = Annotation(". ".join(column))
        SharedNERPipeline().getInstance().annotate(doc)
        sentences = doc.get(CoreAnnotations.SentencesAnnotation)

        num_ne_cell = 0
        num_date_cell = 0
        num_number_cell = 0

        tokens = []
        for s in range(sentences.size()):
            sentence_tokens = sentences.get(s).get(CoreAnnotations.TokensAnnotation)

            words = []
            has_ne = False
            has_date = False
            has_number = False
            for i in range(sentence_tokens.size()):
                corelabel = sentence_tokens.get(i)
                ne = corelabel.get(CoreAnnotations.NamedEntityTagAnnotation)

                words.append(corelabel.get(CoreAnnotations.TextAnnotation))
                if ne not in ['O','NUMBER','NUMERIC','DATE','YEAR']:
                    has_ne = True

                if ne in ['YEAR','DATE']:
                    has_date = True

                if ne in ['NUMBER','NUMERIC','PERCENTAGE','ORDINAL']:
                    has_number = True

            # The last token of every sentence is dropped (presumably the "."
            # separator added when joining) and single-token sentences are skipped.
            if len(words) > 1:
                tokens.append(" ".join(words[:-1]))

            num_ne_cell += has_ne
            num_date_cell += has_date
            num_number_cell += has_number

        if not tokens:
            continue

        half = len(tokens)/2
        if num_ne_cell >= half:
            entity_col.append(col_id)

        if num_date_cell >= half:
            date_col.append(col_id)

        if num_number_cell >= half:
            number_col.append(col_id)

    date_columns = [table_trans[dc] for dc in date_col]

    tuples = []
    for ecol in entity_col:
        for ncol in number_col:
            fields = [[header[ncol]] * len(rows), table_trans[ecol], table_trans[ncol]]
            if date_columns:
                # One list per row holding that row's value from every date column;
                # transposed per pair so tuples never share these lists.
                fields.append(transpose(date_columns))
            tuples.extend(zip(*fields))

    return tuples


def transpose(l):
    """Swap rows and columns of a list of rows.

    Returns a new list of lists; the result is cut to the length of the
    shortest row, and an empty input gives an empty list.
    """
    return [list(column) for column in zip(*l)]

# Smoke test: read_table() turns "herox/2.csv" into data/WikiTableQuestions/herox/2.tsv;
# print whatever number_entity_date_tuples() extracts from that table.
# NOTE(review): the full result is printed, which makes for a very long cell output.
t = read_table("herox/2.csv")
print(number_entity_date_tuples(t))

['EU (28 countries)', 'Belgium', 'Bulgaria', 'Czech Republic', 'Denmark', 'Germany', 'Estonia', 'Ireland', 'Greece', 'Spain', 'France', 'Croatia', 'Italy', 'Cyprus', 'Latvia', 'Lithuania', 'Luxembourg', 'Hungary', 'Malta', 'Netherlands', 'Austria', 'Poland', 'Portugal', 'Romania', 'Slovenia', 'Slovakia', 'Finland', 'Sweden', 'United Kingdom', 'Iceland', 'Liechtenstein', 'Norway', 'Switzerland', 'Montenegro', 'Former Yugoslav Republic of Macedonia, the', 'Albania', 'Serbia', 'Turkey', '11695', '470', '15', '35', '300', '765', '0', '100', '295', '10', '410', ':', '575', '70', '5', '0', '0', '175', '20', '725', '695', '375', '5', '55', '20', '70', '705', '1510', '4285', '0', '0', '1045', '595', ':', ':', ':', ':', ':', '12190', '705', '10', '10', '520', '1305', '0', '55', '40', '20', '445', ':', '415', '20', '0', '5', '10', '270', '45', '1040', '1040', '360', '0', '40', '25', '30', '535', '2250', '2990', '0', '15', '1820', '415', ':', ':', ':', ':', ':', '10610', '860', '20', '5', '410', 