In [1]:
from data.table.Table import Table

from training import TrainingExample

from dcs.aggregation.Avg import Avg
from dcs.aggregation.Count import Count
from dcs.aggregation.Max import Max
from dcs.aggregation.Min import Min
from dcs.aggregation.Sum import Sum
from dcs.arithmetic.Add import Add

from dcs.comparitor.GreaterThan import GreaterThan
from dcs.comparitor.GreaterThanEq import GreaterThanEq
from dcs.comparitor.LessThan import LessThan
from dcs.comparitor.LessThanEq import LessThanEq
from dcs.comparitor.NEq import NEq
from dcs.comparitor.Eq import Eq


from dcs.base.Atom import Atom
from dcs.base.Entity import Entity
from dcs.base.Property import Property
from dcs.base.Unary import Unary
from dcs.base.NormalisationAtom import NormalisationAtom
from dcs.base.Record import Record
from dcs.base.DateAtom import DateAtom
from dcs.base.ComparableAtom import ComparableAtom
from dcs.comparitor.GreaterThan import GreaterThan
from dcs.relation.Chain import Chain
from dcs.relation.Intersection import Intersection
from dcs.relation.Join import Join
from dcs.relation.Negate import Negate
from dcs.relation.Reverse import Reverse
from dcs.relation.Union import Union
from dcs.superlative.ArgMax import ArgMax
from dcs.superlative.ArgMin import ArgMin

from parser.TokenSpan import TokenSpan
from parser.GroundingStep import get_table_properties
from parser.GroundingStep import get_question_properties

from util.dateutils import *
from util.numberutils import *

Loading spacy
Done


In [4]:
def check_produces(p):
    """Tell whether ``p.vals()`` yields anything.

    Returns:
        ``True`` when ``p.vals()`` is neither ``None`` nor empty,
        ``False`` otherwise.
    """
    produced = p.vals()
    if produced is None:
        return False
    return len(produced) > 0

def check_not_identical(p1,p2):
    """Return the result of ``p1 != p2`` (true when the operands differ)."""
    differs = p1 != p2
    return differs

def check_not_same_productions(p1,p2):
    """Return whether ``p1.vals()`` and ``p2.vals()`` compare unequal."""
    first = p1.vals()
    second = p2.vals()
    return first != second

def check_same_ret_type(p1,p2):
    """Return whether ``p1.get_types()`` equals ``p2.get_types()``."""
    first = p1.get_types()
    second = p2.get_types()
    return first == second

def check_reverse(b):
    """Return ``False`` when ``b`` is a ``Reverse`` instance, ``True`` otherwise."""
    if isinstance(b, Reverse):
        return False
    return True

def check_k_type(b1,b2):
    """Return whatever ``b1.chain(b2)`` reports for the pair of binaries."""
    verdict = b1.chain(b2)
    return verdict

def check_cmp(b,a):
    """Decide whether binary ``b`` may be compared against unary ``a``.

    Two conditions must hold:
      * ``a.get_types()`` contains ``Atom`` or ``DateAtom``;
      * the set of ``type(x.v)`` over ``b.vals()`` equals the set of
        ``type(y)`` over ``a.vals()``.

    Returns:
        ``True`` when both conditions hold, ``False`` otherwise.
    """
    atom_like = [kind in a.get_types() for kind in (Atom, DateAtom)]
    if not any(atom_like):
        return False

    produced_types = {type(item.v) for item in b.vals()}
    operand_types = {type(item) for item in a.vals()}
    return produced_types == operand_types
    

def check_single(u):
    """Return whether ``u.vals()`` holds exactly one value."""
    count = len(u.vals())
    return count == 1
# Example question and table used to seed the search.
training_ex = TrainingExample("Greece last held its Summer Olympics since 2004?", None, None)
table = Table("test.tsv")

# Each grounding step returns a (unaries, binaries) pair.
u1, b1 = get_table_properties(table)
u2, b2 = get_question_properties(training_ex)

# Pool the table-derived and question-derived unaries.
u = set(u1)
u.update(u2)

# Pool the table-derived and question-derived binaries.
b = set(b1)
b.update(b2)




def cross_product(list1,list2):
    """Return every ``(a, b)`` pair with ``a`` from ``list1`` and ``b`` from ``list2``.

    Pairs are ordered by ``list1`` first, then by ``list2``.
    """
    pairs = []
    for left in list1:
        for right in list2:
            pairs.append((left, right))
    return pairs



def act_union(bits):
    """Propose the union of a pair of operands.

    Args:
        bits: Pair ``(left, right)``.

    Returns:
        ``[Union(left, right)]`` when both operands produce values, are not
        equal, share the same ``get_types()`` and the union itself produces
        values; otherwise an empty list.
    """
    left, right = bits[0], bits[1]

    if not (check_produces(left) and check_produces(right)):
        return []
    if not check_not_identical(left, right):
        return []
    if not check_same_ret_type(left, right):
        return []

    candidate = Union(left, right)
    return [candidate] if check_produces(candidate) else []
    
    
    

def act_intersection(bits):
    """Propose the intersection of a pair of operands.

    Args:
        bits: Pair ``(left, right)``.

    Returns:
        ``[Intersection(left, right)]`` when both operands produce values,
        are not equal, share the same ``get_types()`` and the intersection
        itself produces values; otherwise an empty list.
    """
    left, right = bits[0], bits[1]

    operands_ok = check_produces(left) and check_produces(right)
    if not operands_ok:
        return []
    if not check_not_identical(left, right):
        return []
    if not check_same_ret_type(left, right):
        return []

    candidate = Intersection(left, right)
    if not check_produces(candidate):
        return []
    return [candidate]




def act_join(bits):
    """Wrap the pair ``bits`` in a single ``Join`` and return it in a list."""
    first, second = bits[0], bits[1]
    joined = Join(first, second)
    return [joined]

def act_agg(u):
    """Return the aggregations applicable to ``u``.

    An ``Atom`` gets ``Min``, ``Max``, ``Count``, ``Avg`` and ``Sum`` (in
    that order); anything else only gets ``Count``.
    """
    if not isinstance(u, Atom):
        return [Count(u)]
    return [aggregate(u) for aggregate in (Min, Max, Count, Avg, Sum)]
        
def act_sup(bits):
    """Propose superlatives for a ``(unary, binary)`` pair.

    Returns:
        ``[ArgMin(u, b), ArgMax(u, b)]`` when both operands produce values
        and ``b.compatible(u)`` is truthy; otherwise an empty list.
    """
    u, b = bits[0], bits[1]

    # The binary is checked before the unary.
    if not check_produces(b) or not check_produces(u):
        return []
    if not b.compatible(u):
        return []

    return [ArgMin(u, b), ArgMax(u, b)]

def act_reverse(b):
    """Propose ``Reverse(b)`` for a binary that produces values.

    Returns:
        ``[Reverse(b)]``, or an empty list when ``b`` produces nothing or
        ``check_reverse(b)`` rejects it.
    """
    reversible = check_produces(b) and check_reverse(b)
    if not reversible:
        return []
    return [Reverse(b)]

def act_chain(bs):
    """Propose chaining a pair of binaries.

    Returns:
        ``[Chain(b1, b2)]`` when both binaries produce values and
        ``check_k_type(b1, b2)`` is truthy; otherwise an empty list.
    """
    b1, b2 = bs[0], bs[1]

    for operand in (b1, b2):
        if not check_produces(operand):
            return []

    if check_k_type(b1, b2):
        return [Chain(b1, b2)]
    return []
    
def act_cmp(ba):
    """Propose comparison joins for a ``(binary, atom)`` pair.

    Returns:
        An empty list unless ``check_single(a)`` and ``check_cmp(b, a)`` both
        pass; otherwise one ``Join(b, cmp(a))`` per comparator, ordered
        ``GreaterThan``, ``GreaterThanEq``, ``LessThan``, ``LessThanEq``,
        ``Eq``, ``NEq``.
    """
    b, a = ba[0], ba[1]

    if not (check_single(a) and check_cmp(b, a)):
        return []

    comparators = (GreaterThan, GreaterThanEq, LessThan, LessThanEq, Eq, NEq)
    return [Join(b, comparator(a)) for comparator in comparators]
       


def transition(u,b,v_prods):
    """Grow the sets of unaries and binaries by one round of composition.

    Every ``act_*`` generator is applied to its operands (single items, or
    pairs built with ``cross_product``). Results that pass ``check_produces``
    are merged with the inputs, and the merged sets are printed.

    Args:
        u: Collection of unaries to combine.
        b: Collection of binaries to combine.
        v_prods: Not used by this version of the function.

    Returns:
        Tuple ``(u_ret, b_ret)`` of sets holding the inputs plus the newly
        generated unaries and binaries.
    """
    atoms = [candidate for candidate in u if isinstance(candidate, Atom)]

    # Generators whose results are collected as unaries, in application order.
    unary_rules = [
        (act_cmp, cross_product(b, atoms)),
        (act_union, cross_product(u, u)),
        (act_intersection, cross_product(u, u)),
        (act_agg, u),
        (act_join, cross_product(b, u)),
        (act_sup, cross_product(u, b)),
    ]

    # Generators whose results are collected as binaries.
    binary_rules = [
        (act_reverse, b),
        (act_chain, cross_product(b, b)),
    ]

    print("Starting with " + str(len(u)) + " unaries")
    print("Starting with " + str(len(b)) + " binaries")

    def productive(rules):
        # Collect (operands, result) for every result that produces values.
        found = set()
        for action, data in rules:
            for datum in data:
                for made in action(datum):
                    if check_produces(made):
                        found.add((datum, made))
        return found

    u_next = productive(unary_rules)
    b_next = productive(binary_rules)

    u_ret = set(u)
    b_ret = set(b)
    u_ret.update(made for _, made in u_next)
    b_ret.update(made for _, made in b_next)

    for a in u_ret:
        print(a)

    for a in b_ret:
        print(a)

    return u_ret, b_ret
        

# Run two rounds of expansion, feeding each round's output into the next.
for _ in range(2):
    u, b = transition(u, b, None)

Starting with 5 unaries
Starting with 11 binaries
[JOIN: [PROPERTY: $number] x [ATOM:2004.0]]
[JOIN: [PROPERTY: $date$year] x [GTE [ATOM:2004.0]]]
[COUNT: [NORMALIZE 2004]]
[ARGMAX: [NORMALIZE 2004] [PROPERTY: $number]]
[UNARY: $records]
[JOIN: [PROPERTY: $date$year] x [NEQ [ATOM:2004.0]]]
[JOIN: [PROPERTY: $index] x [LTE [ATOM:2004.0]]]
[COUNT: [UNARY: $records]]
[ARGMAX: [UNARY: $records] [PROPERTY: $index]]
[JOIN: [PROPERTY: $index] x [LT [ATOM:2004.0]]]
[JOIN: [PROPERTY: $date$year] x [ATOM:2004.0]]
[JOIN: [PROPERTY: $date$year] x [EQ [ATOM:2004.0]]]
[MAX: [ATOM:2004.0]]
[JOIN: [PROPERTY: $number] x [LTE [ATOM:2004.0]]]
[ARGMIN: [DATE Y: 2004 M: None D:None]  [PROPERTY: $date$year]]
[JOIN: [PROPERTY: $number] x [EQ [ATOM:2004.0]]]
[JOIN: [PROPERTY: $next] x [UNARY: $records]]
[JOIN: [PROPERTY: Year] x [NORMALIZE 2004]]
[JOIN: [PROPERTY: $date$year] x [GT [ATOM:2004.0]]]
[JOIN: [PROPERTY: $number] x [GTE [ATOM:2004.0]]]
[JOIN: [PROPERTY: $index] x [NEQ [ATOM:2004.0]]]
[COUNT: [ATOM:

In [None]:



def cross_product(list1,list2):
    """Pair every element of ``list1`` with every element of ``list2``.

    Returns:
        List of ``(a, b)`` tuples, ordered by ``list1`` first, then ``list2``.
    """
    combined = []
    for left in list1:
        for right in list2:
            combined.append((left, right))
    return combined

def act_union(bits):
    """Propose the union of a pair of operands.

    Args:
        bits: Pair ``(left, right)``.

    Returns:
        ``[Union(left, right)]`` when the operands differ, otherwise an empty
        list. A list is always returned because the caller iterates over the
        result.
    """
    left, right = bits[0], bits[1]
    if left != right:
        return [Union(left, right)]
    return []

def act_intersection(bits):
    """Wrap the pair ``bits`` in an ``Intersection`` and return it in a list."""
    left, right = bits[0], bits[1]
    return [Intersection(left, right)]

def act_join(bits):
    """Return a one-element list holding ``Join(bits[0], bits[1])``."""
    head, tail = bits[0], bits[1]
    return [Join(head, tail)]

def act_agg(u):
    """List the aggregations to try on ``u``.

    Only ``Count`` is offered unless ``u`` is an ``Atom``, in which case
    ``Min``, ``Max``, ``Count``, ``Avg`` and ``Sum`` are returned in that order.
    """
    if not isinstance(u, Atom):
        return [Count(u)]
    return [Min(u), Max(u), Count(u), Avg(u), Sum(u)]
        
def act_sup(bits):
    """Return ``[ArgMin(u, b), ArgMax(u, b)]`` for the pair ``bits = (u, b)``."""
    u, b = bits[0], bits[1]
    lowest = ArgMin(u, b)
    highest = ArgMax(u, b)
    return [lowest, highest]

def act_reverse(b):
    """Reverse a binary, unwrapping instead of double-wrapping.

    Returns:
        ``[Reverse(b)]`` for an ordinary binary. When ``b`` is already a
        ``Reverse``, its ``b`` attribute (presumably the wrapped binary) is
        returned in a list instead.
    """
    if not isinstance(b, Reverse):
        return [Reverse(b)]
    return [b.b]

def act_chain(bs):
    """Return a one-element list holding ``Chain(bs[0], bs[1])``."""
    first, second = bs[0], bs[1]
    return [Chain(first, second)]
    
def act_cmp(ua):
    """Build one comparison join per comparator for the pair ``ua = (u, a)``.

    Returns:
        ``Join(u, cmp(a))`` for ``GreaterThan``, ``GreaterThanEq``,
        ``LessThan``, ``LessThanEq``, ``Eq`` and ``NEq``, in that order.
    """
    u, a = ua[0], ua[1]
    comparators = (GreaterThan, GreaterThanEq, LessThan, LessThanEq, Eq, NEq)
    return [Join(u, comparator(a)) for comparator in comparators]
       


def transition(u,b,v_prods):
    """Run one round of composition and prune uninformative results.

    Every ``act_*`` generator is applied to its operands. Each generated
    item is then dropped when its ``vals()`` is ``None`` or empty, or when
    the set of its operands equals the set of its values. Items whose values
    are exactly the single ``Atom(2004.0)`` are additionally recorded in
    ``v_prods`` (and announced with ``"2004!"``).

    Args:
        u: Collection of unaries to combine.
        b: Collection of binaries to combine.
        v_prods: Set that receives the items producing only ``Atom(2004.0)``;
            it is mutated in place.

    Returns:
        Tuple ``(u_ret, b_ret, v_prods)``: the inputs plus the surviving
        unaries and binaries, and the same ``v_prods`` object.
    """
    atoms = [candidate for candidate in u if isinstance(candidate, Atom)]

    # Comparison expansion is switched off here:
    # (act_cmp, cross_product(b,atoms))
    ab_actions = []

    u_actions = [
        (act_union, cross_product(u, u)),
        (act_intersection, cross_product(u, u)),
        (act_agg, u),
    ]
    bu_actions = [
        (act_join, cross_product(b, u)),
        (act_sup, cross_product(u, b)),
    ]
    b_actions = [
        (act_reverse, b),
        (act_chain, cross_product(b, b)),
    ]

    print("Starting with " + str(len(u)) + " unaries")
    print("Starting with " + str(len(b)) + " binaries")

    def expand(groups, into):
        # Record (operands, result) for everything the generators emit.
        for action, data in groups:
            for datum in data:
                for made in action(datum):
                    into.add((datum, made))

    u_next = set()
    b_next = set()
    expand(ab_actions, u_next)
    expand(u_actions, u_next)
    expand(bu_actions, u_next)
    expand(b_actions, b_next)

    print("Generated " + str(len(u_next)) + " unaries")
    print("Generated " + str(len(b_next)) + " binaries")

    def keep_useful(candidates, kept):
        # Add the informative results to ``kept``; return how many were dropped.
        dropped = 0
        for source, made in candidates:
            vals = made.vals()

            if vals is None or len(vals) == 0:
                dropped += 1
                continue

            if Atom(2004.0) in vals and len(vals) == 1:
                v_prods.add(made)
                print("2004!")

            # A result that merely reproduces its own operands adds nothing.
            if (hasattr(source, "__iter__")
                    and hasattr(vals, "__iter__")
                    and set(source) == set(vals)):
                dropped += 1
                continue

            kept.add(made)
        return dropped

    u_ret = set(u)
    b_ret = set(b)

    pruned = keep_useful(u_next, u_ret)
    print("Pruned "+ str(pruned) + " unaries")

    pruned = keep_useful(b_next, b_ret)
    print("Pruned "+ str(pruned) + " binaries")

    return u_ret, b_ret, v_prods


# Start from the current unaries/binaries and collect matches in v_prods.
v_prods = set()
u_, b_ = u, b

# Two rounds of expansion, each fed with the previous round's output.
for i in range(2):
    print("Iteration " + str(i))
    u_, b_, v_prods = transition(u_, b_, v_prods)
    print("")

# Show every item that transition() recorded in v_prods.
for d in v_prods:
    print(d)

In [None]:
# Print the values of the hand-built query, then the query itself.
print(list(map(str, Join(
    Reverse(properties[0]),
    ArgMax(Join(properties[2], Entity("Greece")), index),
).vals())))
print(Join(
    Reverse(properties[0]),
    ArgMax(Join(properties[2], Entity("Greece")), index),
))

In [None]:
# Load the table and create one Property per header column, indexed by
# column position.
table = Table("towns.tsv")

properties = []
for prop in table.read()['header']:
    properties.append(Property(prop))

atoms = set()
entities = set()

row_id = 0
last_row = None

# Synthetic properties: "$next" links a row to the following row and
# "$index" maps a row to its position.
next_row = Property("$next")
index = Property("$index")

records = Unary("$records")

properties.append(next_row)
properties.append(index)

for row in table.read()['rows']:
    r = Record("$"+str(row_id))

    index.add(r,Atom(row_id))

    col_id = 0
    for col in row:
        # Blank cells contribute nothing but still occupy a column.
        if len(col.strip()) != 0:
            if is_number(col):
                a = Atom(float(col))
                properties[col_id].add(r,a)
                atoms.add(a)
            elif is_date(col):
                # float() would raise ValueError on a date that is not also
                # a number, so such cells are wrapped in a NormalisationAtom.
                a = NormalisationAtom(col)
                properties[col_id].add(r,a)
                atoms.add(a)
            else:
                e = Entity(col.strip())
                properties[col_id].add(r,e)
                entities.add(e)
        col_id += 1

    records.add(r)
    if last_row is not None:
        next_row.add(last_row,r)
    row_id += 1
    last_row = r


# Ground the question's entities (t.nes) and numbers (t.nums) into unaries.
actions = [( ground_entity , t.nes ),
           ( ground_atom , t.nums ),
          ]

u = set()
b = set()

for action,data in actions:
    for datum in data:
        u.update(action(datum))

# Every property is a binary; the set of all records is a unary.
for p in properties:
    b.add(p)

u.add(records)


def transition(u,b,v_prods):
    """Run one round of composition and prune uninformative results.

    Each generated item is dropped when its ``vals()`` is ``None`` or empty,
    or when the set of its operands equals the set of its values. Items
    whose values are exactly the single ``Atom(2004.0)`` are also added to
    ``v_prods`` and announced with ``"2004!"``.

    Args:
        u: Collection of unaries to combine.
        b: Collection of binaries to combine.
        v_prods: Set receiving the items that produce only ``Atom(2004.0)``;
            mutated in place.

    Returns:
        Tuple ``(u_ret, b_ret, v_prods)``: inputs plus surviving unaries and
        binaries, and the same ``v_prods`` object.
    """
    # Generators whose output is collected as unaries, in application order.
    unary_sources = [
        (act_union, cross_product(u, u)),
        (act_intersection, cross_product(u, u)),
        (act_agg, u),
        (act_join, cross_product(b, u)),
        (act_sup, cross_product(u, b)),
    ]
    # Generators whose output is collected as binaries.
    binary_sources = [
        (act_reverse, b),
        (act_chain, cross_product(b, b)),
    ]

    print("Starting with " + str(len(u)) + " unaries")
    print("Starting with " + str(len(b)) + " binaries")

    # Each entry pairs the operands with the item generated from them.
    u_next = {
        (datum, made)
        for action, data in unary_sources
        for datum in data
        for made in action(datum)
    }
    b_next = {
        (datum, made)
        for action, data in binary_sources
        for datum in data
        for made in action(datum)
    }

    print("Generated " + str(len(u_next)) + " unaries")
    print("Generated " + str(len(b_next)) + " binaries")

    def screen(candidates, survivors):
        # Move informative items into ``survivors``; return the drop count.
        rejected = 0
        for operands, made in candidates:
            values = made.vals()

            if values is None or len(values) == 0:
                rejected += 1
                continue

            if Atom(2004.0) in values and len(values) == 1:
                v_prods.add(made)
                print("2004!")

            # Drop items that only reproduce their own operands.
            echoes_operands = (
                hasattr(operands, "__iter__")
                and hasattr(values, "__iter__")
                and set(operands) == set(values)
            )
            if echoes_operands:
                rejected += 1
                continue

            survivors.add(made)
        return rejected

    u_ret = set(u)
    b_ret = set(b)

    pruned = screen(u_next, u_ret)
    print("Pruned "+ str(pruned) + " unaries")

    pruned = screen(b_next, b_ret)
    print("Pruned "+ str(pruned) + " binaries")

    return u_ret, b_ret, v_prods



    

# Apply the compiled "greater than 100" comparator to 10, then to 110.
for probe in (10, 110):
    print(GreaterThan(Atom(100)).compile()(Atom(probe)))


# Disabled experiment:
#print([str(a) for a in Join(properties[1],GreaterThan(Atom(1000))).vals()])

# Values of column 0 (reversed) joined with the rows whose column 1 exceeds 1000.
print(list(map(str, Join(
    Reverse(properties[0]),
    Join(properties[1], GreaterThan(Atom(1000))),
).vals())))

# Disabled experiments:
#print(Join(Reverse(properties[1]),list(records.vals())[0]).vals())
#print([str(a) for a in Join22(Reverse(properties[3]),records).vals()])

In [None]:
# Question used to exercise the normalisation code in this cell.
t = TrainingExample("Greece last held its Summer Olympics since 2004?", None, None)

# Sample strings mixing years, month names, full dates and plain numbers.
examples = [
    "2004",
    "September",
    "93",
    "September 2004",
    "5th September 2005",
    "3503432.3",
]

from dateutil.parser import parser, parse



    
class NormalisationAtom():
    """Raw value that is linked to its normalised date or number form.

    On construction the instance registers itself with one of two
    class-level properties shared by all instances: ``dateprops`` when
    ``is_date(value)`` holds, otherwise ``numberprops`` when
    ``is_number(value)`` holds. A value that is neither is left unlinked.
    """

    # Shared by every instance of the class.
    dateprops = Property("$date")
    numberprops = Property("$number")

    def __init__(self,value):
        """Store ``value`` and link it to a ``DateAtom`` or a float ``Atom``."""
        self.value = value

        # The date check wins when a value is both a date and a number.
        if is_date(value):
            self.dateprops.add(self, DateAtom(value))
            return
        if is_number(value):
            self.numberprops.add(self, Atom(float(value)))

    def __str__(self):
        """Render as ``[NORMALIZE <value>]``."""
        return "[NORMALIZE {!s}]".format(self.value)

    @staticmethod
    def allprops():
        """Return the set of the two class-level properties."""
        return {NormalisationAtom.dateprops, NormalisationAtom.numberprops}

class DateAtom():
    """Date value split into optional year, month and day components.

    Each component found by the parser is stored as an attribute
    (``year``, ``month``, ``day``) and registered with the matching
    class-level property. Components the parser did not find leave the
    attribute unset.
    """

    # Shared by every instance of the class.
    yearprops = Property("$date$year")
    monthprops = Property("$date$month")
    dayprops = Property("$date$day")

    def __init__(self,value):
        """Parse ``value`` and register whichever components are present."""
        self.value = value

        # NOTE(review): relies on dateutil's private ``parser._parse``, which
        # is expected to return a (result, extra) pair - confirm for the
        # installed dateutil version.
        parsed, _ = parser()._parse(value)

        # Components are processed in the order day, month, year.
        components = (
            ("day", self.dayprops),
            ("month", self.monthprops),
            ("year", self.yearprops),
        )
        for field, prop in components:
            found = getattr(parsed, field, None)
            if found is None:
                continue
            setattr(self, field, found)
            prop.add(self, Atom(found))

    def __str__(self):
        """Render as ``[DATE Y: <y> M: <m> D:<d>] `` with ``None`` for gaps."""
        year = getattr(self, "year", "None")
        month = getattr(self, "month", "None")
        day = getattr(self, "day", "None")
        return "[DATE Y: {!s} M: {!s} D:{!s}] ".format(year, month, day)

    @staticmethod
    def allprops():
        """Return the set of the three class-level properties."""
        return {DateAtom.yearprops, DateAtom.monthprops, DateAtom.dayprops}
    
    
    
    
# Load the table and create one Property per header column, indexed by
# column position.
table = Table("test.tsv")

properties = [Property(name) for name in table.read()['header']]

atoms = set()
entities = set()

# Synthetic properties: "$next" links a row to the following row and
# "$index" maps a row to its position.
next_row = Property("$next")
index = Property("$index")

records = Unary("$records")

row_id = 0
last_row = None

for row in table.read()['rows']:
    r = Record("$" + str(row_id))
    index.add(r, Atom(row_id))

    for col_id, col in enumerate(row):
        # Blank cells contribute nothing.
        if not col.strip():
            continue

        if is_number(col) or is_date(col):
            # Numbers and dates are wrapped so they can be normalised later.
            a = NormalisationAtom(col)
            properties[col_id].add(r, a)
            atoms.add(a)
        else:
            e = Entity(col.strip())
            properties[col_id].add(r, e)
            entities.add(e)

    records.add(r)
    if last_row is not None:
        next_row.add(last_row, r)
    last_row = r
    row_id += 1


# Append the synthetic and normalisation properties after the column ones so
# that column positions keep indexing the header properties.
properties.append(next_row)
properties.append(index)

properties.extend(NormalisationAtom.allprops())
properties.extend(DateAtom.allprops())


# Show the normalised form of every number found in the question.
for num in t.nums:
    print(NormalisationAtom(" ".join(num.tokens)))