In [1]:
from parser.TokenSpan import TokenSpan
from training import *
from dcs.dcs import *
from table.TableToGraph import *
from dateutil.parser import parse as dparse
import unicodedata
import itertools

Loading spacy
Done


In [12]:
atomic_types = ["DATE","TIME","PERCENT","MONEY","QUANTITY","ORDINAL","CARDINAL"]
entity_types = ["PERSON","NORP","FACILITY","ORG","GPE","LOC","PRODUCT","EVENT","LANGUAGE"]

class TrainingExample():
    def __init__(self,sentence, answer, table):
        doc = nlp(sentence)

        tokens = []
        espans = []
        nspans = []
        ne_start = -1
        for t in range(len(doc)):
            tokens.append(doc[t].text)
            if doc[t].ent_iob == 3:
                if ne_start >= 0:
                    if str(doc[ne_start].ent_type_) in entity_types:
                        espans.append((ne_start,t))
                    elif str(doc[ne_start].ent_type_) in atomic_types:
                        nspans.append((ne_start,t))
                else:
                    ne_start = t
            elif doc[t].ent_iob == 2:
                if ne_start >= 0:
                    if str(doc[ne_start].ent_type_) in entity_types:
                        espans.append((ne_start,t))
                    elif str(doc[ne_start].ent_type_) in atomic_types:
                        nspans.append((ne_start,t))
                    ne_start = -1

        token_spans = []
        for espan in espans:
            start = espan[0]
            end = espan[1]

            ts = TokenSpan(tokens,start,end)
            token_spans.append(ts)
        
        self.nes = token_spans
        token_spans = []
        for espan in nspans:
            start = espan[0]
            end = espan[1]

            ts = TokenSpan(tokens,start,end)
            token_spans.append(ts)
            
            
        self.words = tokens
        self.answer = answer
        self.table = table
        self.nums = token_spans

t = TrainingExample("Greece last held its Summer Olympics since 2004?",None,None)

def ground_entity(ts):
    s = " ".join(ts.tokens)
    return [Entity(s)]


def ground_atom(ts):
    s = " ".join(ts.tokens)
    return [Atom(s)]

def is_date(string):
    try: 
        dparse(string)
        return True
    except ValueError:
        return False
    
def is_number(s):
    try:
        float(s)
        return True
    except ValueError:
        pass
 
    try:
        unicodedata.numeric(s)
        return True
    except (TypeError, ValueError):
        pass
 
    return False



def cross_product(list1,list2):
    return [(a,b) for a in list1 for b in list2]

def act_union(bits):
    return [Union(bits[0],bits[1])]

def act_intersection(bits):
    return [Intersection(bits[0],bits[1])]

def act_join(bits):
    return [Join(bits[0],bits[1])]

def act_agg(u):
    if isinstance(u,Atom):
        return [Min(u), Max(u), Count(u), Avg(u), Sum(u)]
    return [Count(u)]
        
def act_sup(bits):
    u = bits[0]
    b = bits[1]
    return [ArgMin(u,b), ArgMax(u,b)]

def act_reverse(b):
    if isinstance(b,Reverse):
        return [b.b]
    return [Reverse(b)]

def act_chain(bs):
    b1 = bs[0]
    b2 = bs[1]
    return [Chain(b1,b2)]
    
def act_cmp(ua):
    u = ua[0]
    a = ua[1]
    
    return [Join(u, GreaterThan(a)),
            Join(u, GreaterThanEq(a)),
            Join(u, LessThan(a)),
            Join(u, LessThanEq(a)),
            Join(u, Eq(a)),
            Join(u, NEq(a)),
           ]

table = Table("test.tsv")


properties = []
for prop in table.read()['header']:
    properties.append(Property(prop))


       
atoms = set()
entities = set()

row_id = 0
last_row = None

next_row = Property("$next")
index = Property("$index")

records = Unary("$records") 

properties.append(next_row)
properties.append(index)

for row in table.read()['rows']:
    r = Record("$"+str(row_id))
    
    index.add(r,Atom(row_id))
    
    col_id = 0 
    for col in row:

        if len(col.strip()) == 0:
            col_id +=1
            continue
            
        if is_number(col) or is_date(col):
            a = Atom(float(col))
            properties[col_id].add(r,a)
            atoms.add(a)
        else:
            e = Entity(col.strip())
            
            properties[col_id].add(r,e)
            entities.add(e)
        col_id += 1
 
    records.add(r)
    if last_row is not None:
        next_row.add(last_row,r)
    row_id += 1 
    last_row = r    


actions = [( ground_entity , t.nes ),
           ( ground_atom , t.nums ),
          ]

u = set()
b = set()

for action,data in actions:
    for datum in data:
        u.update(action(datum))

        
for p in properties:
    b.add(p)
    
u.add(records)
        


def transition(u,b,v_prods):
    atoms = [au for au in u if isinstance(au,Atom)]
    
    ab_actions = [ (act_cmp, cross_product(b,atoms)) ]
    
    u_actions = [( act_union , cross_product(u,u) ),
             ( act_intersection, cross_product(u,u) ),
             ( act_agg, u )
            ]

    bu_actions = [( act_join , cross_product(b,u) ),
                  ( act_sup, cross_product(u,b) )
                 ]

    b_actions = [ ( act_reverse, b),
                  ( act_chain, cross_product(b,b)) 
                ]


    print("Starting with " + str(len(u)) + " unaries")
    print("Starting with " + str(len(b)) + " binaries")
    
    u_next = set()
    b_next = set()
    
    for action,data in ab_actions:
        for datum in data:
            for a in action(datum):
                u_next.add((datum,a))
                print(a)

                
    for action,data in u_actions:
        for datum in data:
            for a in action(datum):
                u_next.add((datum,a))

    for action,data in bu_actions:
        for datum in data:
            for a in action(datum):
                u_next.add((datum,a))

    for action,data in b_actions:
        for datum in data:
            for a in action(datum):
                b_next.add((datum,a))

    print("Generated " + str(len(u_next)) + " unaries")
    print("Generated " + str(len(b_next)) + " binaries")    
    
    
    u_ret = set()
    b_ret = set()
    
    u_ret.update(u)
    b_ret.update(b)
    
    pruned = 0  
    
    for d,a in u_next:
        v = a.vals()
        

            
        if v is None:
            pruned += 1
            continue
        elif len(v) == 0:
            pruned += 1
            continue

        if Atom(2004.0) in v and len(v) == 1:
            v_prods.add(a)
            print("2004!")
            
        if hasattr(d,"__iter__"):
            if hasattr(v,"__iter__"):
                if set(d)==set(v):
                    pruned +=1
                    continue
        u_ret.add(a)
            
    print("Pruned "+ str(pruned) + " unaries")

    pruned = 0
    for d,a in b_next:
        v = a.vals()
            
        if v is None:
            pruned += 1
            continue
        elif len(v) == 0:
            pruned += 1
            continue
            
        if Atom(2004.0) in v and len(v) == 1:
            v_prods.add(a)
            print("2004!")
            
        if hasattr(d,"__iter__"):
            if hasattr(v,"__iter__"):
                if set(d)==set(v):
                    pruned +=1
                    continue

        b_ret.add(a)
    
    print("Pruned "+ str(pruned) + " binaries")
    return u_ret,b_ret,v_prods


u_,b_ = u,b
v_prods = set()
for i in range(2):
    print("Iteration " + str(i))
    u_, b_,v_prods = transition(u_,b_,v_prods)
    print("")
    
for d in v_prods:
    print(d)

Iteration 0
Starting with 3 unaries
Starting with 6 binaries
[JOIN: [PROPERTY: $index] x [GT [ATOM:2004]]]
[JOIN: [PROPERTY: $index] x [GTE [ATOM:2004]]]
[JOIN: [PROPERTY: $index] x [LT [ATOM:2004]]]
[JOIN: [PROPERTY: $index] x [LTE [ATOM:2004]]]
[JOIN: [PROPERTY: $index] x [EQ [ATOM:2004]]]
[JOIN: [PROPERTY: $index] x [NEQ [ATOM:2004]]]
[JOIN: [PROPERTY: Country] x [GT [ATOM:2004]]]
[JOIN: [PROPERTY: Country] x [GTE [ATOM:2004]]]
[JOIN: [PROPERTY: Country] x [LT [ATOM:2004]]]
[JOIN: [PROPERTY: Country] x [LTE [ATOM:2004]]]
[JOIN: [PROPERTY: Country] x [EQ [ATOM:2004]]]
[JOIN: [PROPERTY: Country] x [NEQ [ATOM:2004]]]
[JOIN: [PROPERTY: $next] x [GT [ATOM:2004]]]
[JOIN: [PROPERTY: $next] x [GTE [ATOM:2004]]]
[JOIN: [PROPERTY: $next] x [LT [ATOM:2004]]]
[JOIN: [PROPERTY: $next] x [LTE [ATOM:2004]]]
[JOIN: [PROPERTY: $next] x [EQ [ATOM:2004]]]
[JOIN: [PROPERTY: $next] x [NEQ [ATOM:2004]]]
[JOIN: [PROPERTY: Nations] x [GT [ATOM:2004]]]
[JOIN: [PROPERTY: Nations] x [GTE [ATOM:2004]]]
[JOIN: 

TypeError: unorderable types: int() >= str()

In [3]:
print([str(a) for a in Join(Reverse(properties[0]), ArgMax(Join(properties[2],Entity("Greece")),index)).vals()])
print(Join(Reverse(properties[0]), ArgMax(Join(properties[2],Entity("Greece")),index)))

['[ATOM:2004.0]']
[JOIN: R[[PROPERTY: Year]] x [ARGMAX: [JOIN: [PROPERTY: Country] x [ENTITY:Greece]] [PROPERTY: $index]]]


In [6]:
table = Table("towns.tsv")


properties = []
for prop in table.read()['header']:
    properties.append(Property(prop))


       
atoms = set()
entities = set()

row_id = 0
last_row = None

next_row = Property("$next")
index = Property("$index")

records = Unary("$records") 

properties.append(next_row)
properties.append(index)

for row in table.read()['rows']:
    r = Record("$"+str(row_id))
    
    index.add(r,Atom(row_id))
    
    col_id = 0 
    for col in row:

        if len(col.strip()) == 0:
            col_id +=1
            continue
            
        if is_number(col) or is_date(col):
            a = Atom(float(col))
            properties[col_id].add(r,a)
            atoms.add(a)
        else:
            e = Entity(col.strip())
            
            properties[col_id].add(r,e)
            entities.add(e)
        col_id += 1
 
    records.add(r)
    if last_row is not None:
        next_row.add(last_row,r)
    row_id += 1 
    last_row = r    


actions = [( ground_entity , t.nes ),
           ( ground_atom , t.nums ),
          ]

u = set()
b = set()

for action,data in actions:
    for datum in data:
        u.update(action(datum))

        
for p in properties:
    b.add(p)
    
u.add(records)
        


def transition(u,b,v_prods):
    
    u_actions = [( act_union , cross_product(u,u) ),
             ( act_intersection, cross_product(u,u) ),
             ( act_agg, u )
            ]

    bu_actions = [( act_join , cross_product(b,u) ),
                  ( act_sup, cross_product(u,b) )
                 ]

    b_actions = [ ( act_reverse, b),
                  ( act_chain, cross_product(b,b)) 
                ]


    print("Starting with " + str(len(u)) + " unaries")
    print("Starting with " + str(len(b)) + " binaries")
    
    u_next = set()
    b_next = set()
    
    for action,data in u_actions:
        for datum in data:
            for a in action(datum):
                u_next.add((datum,a))

    for action,data in bu_actions:
        for datum in data:
            for a in action(datum):
                u_next.add((datum,a))

    for action,data in b_actions:
        for datum in data:
            for a in action(datum):
                b_next.add((datum,a))

    print("Generated " + str(len(u_next)) + " unaries")
    print("Generated " + str(len(b_next)) + " binaries")    
    
    
    u_ret = set()
    b_ret = set()
    
    u_ret.update(u)
    b_ret.update(b)
    
    pruned = 0  
    
    for d,a in u_next:
        v = a.vals()
        

            
        if v is None:
            pruned += 1
            continue
        elif len(v) == 0:
            pruned += 1
            continue

        if Atom(2004.0) in v and len(v) == 1:
            v_prods.add(a)
            print("2004!")
            
        if hasattr(d,"__iter__"):
            if hasattr(v,"__iter__"):
                if set(d)==set(v):
                    pruned +=1
                    continue
        u_ret.add(a)
            
    print("Pruned "+ str(pruned) + " unaries")

    pruned = 0
    for d,a in b_next:
        v = a.vals()
            
        if v is None:
            pruned += 1
            continue
        elif len(v) == 0:
            pruned += 1
            continue
            
        if Atom(2004.0) in v and len(v) == 1:
            v_prods.add(a)
            print("2004!")
            
        if hasattr(d,"__iter__"):
            if hasattr(v,"__iter__"):
                if set(d)==set(v):
                    pruned +=1
                    continue

        b_ret.add(a)
    
    print("Pruned "+ str(pruned) + " binaries")
    return u_ret,b_ret,v_prods



    

print(GreaterThan(Atom(100)).compile()(Atom(10)))
print(GreaterThan(Atom(100)).compile()(Atom(110)))


#print([str(a) for a in Join(properties[1],GreaterThan(Atom(1000))).vals()])

print([str(a) for a in Join(Reverse(properties[0]), Join(properties[1],GreaterThan(Atom(1000)))).vals()])

#print(Join(Reverse(properties[1]),list(records.vals())[0]).vals())
#print([str(a) for a in Join22(Reverse(properties[3]),records).vals()])

False
True
['[ENTITY:Tórshavn]', '[ENTITY:Tvøroyri]', '[ENTITY:Miðvágur]', '[ENTITY:Vestmanna]', '[ENTITY:Fuglafjørður]', '[ENTITY:Argir]', '[ENTITY:Klaksvík]', '[ENTITY:Hoyvík]', '[ENTITY:Vágur]']
