In [1]:
from pprint import pprint

def print_sentence_tags(tagged_sentence, tag_filter):
    """Print one row per word: its index, the word, and its tags found in ``tag_filter``.

    Args:
        tagged_sentence: iterable of ``(word, tags)`` pairs where ``tags`` is a set.
        tag_filter: iterable of tags to show; any other tag on a word is hidden.

    The matching tags are sorted before being joined so that the printed rows
    are identical from run to run (set iteration order is not stable for strings).
    """
    for i, (wd, tags) in enumerate(tagged_sentence):
        shown = sorted(tags.intersection(tag_filter))
        # An empty selection joins to "", which leaves the tag column blank.
        print(str(i).ljust(3), wd.ljust(30), ",".join(shown))

In [2]:
class Stack(object):
    """A list-backed LIFO stack with optional tracing of pushes and pops."""

    def __init__(self, verbose=False):
        # Items are stored bottom-to-top; the last element is the top of the stack.
        self.stack = []
        # When True, every push/pop prints the item and the resulting length.
        self.verbose = verbose

    def tos(self):
        """Return the top-of-stack item without removing it, or None when empty."""
        if self.len() == 0:
            return None
        return self.stack[-1]

    def pop(self):
        """Remove and return the top item; asserts that the stack is not empty."""
        assert self.len() > 0, "Can't pop when stack is empty"
        item = self.stack.pop()
        if self.verbose:
            print("POPPING: %s" % item)
            print("LEN:     %i" % len(self.stack))
        return item

    def push(self, item):
        """Place ``item`` on top of the stack."""
        self.stack.append(item)
        if self.verbose:
            print("PUSHING: %s" % item)
            print("LEN:     %i" % len(self.stack))

    def len(self):
        """Return the number of items currently on the stack."""
        return len(self.stack)

    def contains(self, item):
        """Return True if ``item`` is anywhere on the stack."""
        return item in self.stack

    def clone(self):
        """Return an independent copy with the same items and verbosity.

        Parser.clone() calls this method; the item list is copied so that
        pushing or popping on the clone leaves this stack untouched.
        """
        duplicate = Stack(self.verbose)
        duplicate.stack = list(self.stack)
        return duplicate

    def __repr__(self):
        return "|".join(map(str, self.stack))

In [30]:
# Sentinel node: test_oracle pushes it onto the stack before parsing, and
# Parser.get_dependencies drops any arc that has it at either end.
ROOT = "root"

def norm_arc(arc):
    """Return ``arc`` as a tuple in a canonical order, ignoring its direction.

    Both ``(a, b)`` and ``(b, a)`` normalise to the same tuple, so a normalised
    arc can be used for direction-insensitive membership tests.

    Nodes that Python cannot order against each other (for example the string
    ROOT against a ``(tag, index)`` tuple, or ``None`` from an empty stack)
    would make ``sorted`` raise TypeError; those are ordered by type name and
    ``repr`` instead, which is still the same for both directions.
    """
    try:
        return tuple(sorted(arc))
    except TypeError:
        return tuple(sorted(arc, key=lambda node: (type(node).__name__, repr(node))))

def norm_arcs(arcs):
    """Return the set of normalised (direction-insensitive) forms of ``arcs``."""
    return {norm_arc(pair) for pair in arcs}

class Parser(object):
    """Stack-based parser state: the stack, the arcs built so far and an action log."""

    def __init__(self, stack):
        self.stack = stack
        # Arcs in the order (and direction) they were created.
        self.arcs = []
        # Direction-insensitive forms of the arcs, used to reject duplicates.
        self.normed_arcs = set()
        # nodes with heads
        self.children = set()
        # Human-readable log, one entry per executed action.
        self.actions = []

    def get_dependencies(self):
        """Return the recorded arcs, leaving out any arc that touches ROOT."""
        return [(l, r) for (l, r) in self.arcs if r != ROOT and l != ROOT]

    def left_arc(self, buffer):
        """Pop the top of the stack and record the arc ``(tos, buffer)``."""
        tos = self.stack.pop()
        # Pre-condition
        # assert self.has_head(tos) == False
        arc = (tos, buffer)
        n_arc = norm_arc(arc)
        assert n_arc not in self.normed_arcs, "Arc already processed %s" % str(n_arc)
        self.arcs.append(arc)
        # Store the normalised form (as right_arc does) so the duplicate check
        # above works no matter which direction the arc was first added in.
        self.normed_arcs.add(n_arc)
        self.children.add(tos)
        self.actions.append("L ARC   : " + str(tos) + "->" + str(buffer))

    def right_arc(self, buffer):
        """Record the arc ``(buffer, tos)`` and push ``buffer`` onto the stack."""
        tos = self.stack.tos()
        arc = (buffer, tos)
        n_arc = norm_arc(arc)
        assert n_arc not in self.normed_arcs, "Arc already processed %s" % str(n_arc)
        self.arcs.append(arc)
        self.normed_arcs.add(n_arc)
        self.actions.append("R ARC   : " + str(tos) + "<-" + str(buffer))
        self.children.add(buffer)
        self.stack.push(buffer)

    def reduce(self):
        """Pop the top of the stack without creating an arc."""
        tos = self.stack.pop()
        # assert self.has_head(tos) == True
        self.actions.append("REDUCE  : Pop  %s" % str(tos))

    def shift(self, buffer):
        """Push ``buffer`` onto the stack without creating an arc."""
        self.stack.push(buffer)
        self.actions.append("SHIFT   : Push %s" % str(buffer))

    def skip(self, buffer):
        """Log that ``buffer`` was passed over; the stack and arcs are unchanged."""
        self.actions.append("SKIP    : item %s" % str(buffer))

    def has_head(self, item):
        """Return True if ``item`` has been recorded as a child by an arc action."""
        return item in self.children

    def in_stack(self, item):
        """Return True if ``item`` is currently on the stack."""
        return self.stack.contains(item)

    def clone(self):
        """Return a copy whose stack, arcs, children and log are independent of this one."""
        import copy

        # Use the stack's own clone() when it offers one; otherwise fall back
        # to a deep copy so cloning works with any stack object.
        if hasattr(self.stack, "clone"):
            stack_copy = self.stack.clone()
        else:
            stack_copy = copy.deepcopy(self.stack)
        cloney = Parser(stack_copy)
        cloney.arcs = list(self.arcs)
        cloney.normed_arcs = set(self.normed_arcs)
        # nodes with heads
        cloney.children = set(self.children)
        cloney.actions = list(self.actions)
        return cloney

In [31]:
from collections import defaultdict

# Action labels returned by Oracle.consult and dispatched on by Oracle.execute.
SHIFT = "Shift"
REDUCE = "Reduce"
LARC = "LArc"
RARC = "Rarc"
SKIP = "Skip"

class Oracle(object):
    """Chooses parser actions from a known set of relations and applies them."""

    def __init__(self, crels, parser):
        self.parser = parser
        self.raw_crels = crels
        self.crels = norm_arcs(crels)  # type: Set[Tuple[str,str]]
        # node -> set of nodes it still has an unprocessed relation with
        self.mapping = self.build_mappings(crels)

    def build_mappings(self, pairs):
        """Return a symmetric adjacency map: each node maps to the set of its partners."""
        mapping = defaultdict(set)
        for c, res in pairs:
            mapping[c].add(res)
            mapping[res].add(c)
        return mapping

    def should_continue(self, action):
        """Return True for the actions after which the same buffer item is consulted again."""
        # continue parsing if REDUCE or LARC
        return action in (REDUCE, LARC)

    def remove_relation(self, a, b):
        """Forget the relation between ``a`` and ``b``; unknown relations are ignored.

        Nodes left without any remaining relation are removed from the mapping.
        """
        # as we can force it to execute actions that are invalid, we have to see if this is a valid relation to remove
        if a in self.mapping and b in self.mapping[a]:
            self.mapping[a].discard(b)
            if len(self.mapping[a]) == 0:
                del self.mapping[a]
            # For a self-relation (a == b) the entry may already be gone, so
            # re-check instead of indexing the defaultdict (which would
            # re-create the key and then fail to remove from an empty set).
            if b in self.mapping:
                self.mapping[b].discard(a)
                if len(self.mapping[b]) == 0:
                    del self.mapping[b]

    def consult(self, tos, buffer):
        """
        Choose the action for the current top of stack and buffer item.

        Returns one of the action labels LARC, RARC, SKIP, REDUCE or SHIFT.
        """
        parser = self.parser
        a, b = norm_arc((tos, buffer))
        if (a, b) in self.crels:
            # TOS has arcs remaining? If so, we need RARC, else LARC
            # .get() keeps this a pure query: indexing the defaultdict would
            # insert an empty entry for a node that has no relations left.
            if len(self.mapping.get(tos, ())) == 1:
                return LARC
            else:
                return RARC
        else:
            if buffer not in self.mapping:
                return SKIP
            # If the buffer has relations further down in the stack, we need to POP the TOS
            for item in self.mapping[buffer]:
                if item == tos:
                    continue
                if parser.in_stack(item):
                    return REDUCE
            # end for
            # ELSE
            return SHIFT

    def execute(self, action, tos, buffer):
        """
        Apply ``action`` to the parser and update the remaining relations.

        Returns True if the same buffer item should be consulted again
        (see should_continue), False if the buffer item has been consumed.
        Raises Exception for an action label that is not recognised.
        """
        parser = self.parser
        if action == LARC:
            parser.left_arc(buffer)
            self.remove_relation(tos, buffer)
        elif action == RARC:
            parser.right_arc(buffer)
            self.remove_relation(tos, buffer)
        elif action == REDUCE:
            parser.reduce()
        elif action == SHIFT:
            parser.shift(buffer)
        elif action == SKIP:
            parser.skip(buffer)
        else:
            raise Exception("Unknown parsing action %s" % action)
        return self.should_continue(action)

    def tos(self):
        """Return the parser's current top-of-stack item."""
        return self.parser.stack.tos()

    def is_stack_empty(self):
        """Return True if the parser's stack holds no items."""
        return self.parser.stack.len() == 0

    def clone(self):
        """Return a copy with its own parser clone and its own relation mapping."""
        cloney = Oracle(set(self.raw_crels), self.parser.clone())
        # Need to ensure a deep clone of the mappings dict
        cloney.mapping = defaultdict(set)
        for key, set_vals in self.mapping.items():
            cloney.mapping[key].update(set_vals)
        return cloney

In [73]:
def test_oracle(codes, crels, orcl_fact, verbose=False):
    """Parse ``codes`` with an oracle built from ``crels`` and check the arcs it recovers.

    Args:
        codes: sequence of nodes fed to the parser in order (a string of
            single-character codes, or a list of nodes).
        crels: iterable of relation pairs the parse is expected to reproduce.
        orcl_fact: callable ``(crels, parser) -> oracle`` (for example ``Oracle``).
        verbose: when True, print a trace of every action and a final report.

    Returns:
        True if the recovered dependencies equal ``crels`` when arc direction
        is ignored, otherwise False.
    """
    crels = set(crels)
    if verbose:
        prn_fun = lambda s="": print(s)
    else:
        prn_fun = lambda s="": None

    stack = Stack(False)
    stack.push(ROOT)
    parser = Parser(stack)
    oracle = orcl_fact(crels, parser)

    prn_fun("DEPS")
    for crel in sorted(crels):
        prn_fun("\t" + str(crel))
    prn_fun()

    # Column widths for the trace output.
    PAD = 20
    STACK_PAD = 2*len(codes) + len(ROOT)
    LINE = PAD + 12 + STACK_PAD  + 13 + len(codes)

    for ix, buffer in enumerate(codes):
        prn_fun("-" * LINE)
        prn_fun(buffer)
        prn_fun("-" * LINE)

        while True:
            tos = stack.tos()
            action = oracle.consult(tos, buffer)
            # these actions don't advance the buffer
            if action in (LARC, REDUCE):
                remaining_buffer = codes[ix:]
            else:
                remaining_buffer = codes[ix+1:]
            keep_going = oracle.execute(action, tos, buffer)
            # str() is a no-op for string codes and lets list codes be traced
            # too; the message is built even when verbose is False, so a bare
            # "+" on a list would raise TypeError on every call.
            prn_fun(parser.actions[-1].ljust(PAD) + " || STACK : " + str(stack).ljust(STACK_PAD)
                    + " || BUFFER : " + str(remaining_buffer))
            if not keep_going:
                break
            if stack.len() == 0:
                prn_fun("Empty stack, stopping")
                break

    prn_fun()
    prn_fun("*" * LINE)
    prn_fun("Stack")
    prn_fun("\t" + str(stack))
    deps = parser.get_dependencies()
    prn_fun("DEPS Actual")
    for crel in sorted(crels):
        prn_fun("\t" + str(crel))
    prn_fun("DEPS Pred")
    for dep in sorted(deps):
        prn_fun("\t" + str(dep))
    prn_fun("Actions")
    for a in parser.actions:
        prn_fun("\t" + a)
    prn_fun()
    prn_fun("Ordered Match?    " + str(set(deps) == crels))

    # Compare with arc direction ignored: symmetric difference of the normalised sets.
    ndeps = norm_arcs(deps)
    ncrels = norm_arcs(crels)
    diff = (ndeps - ncrels).union(ncrels - ndeps)
    success = (len(diff) == 0)
    prn_fun("Un Ordered Match? " + str(success))
    if not success:
        prn_fun("Matches")
        prn_fun(ndeps.intersection(ncrels))
        prn_fun("Differences")
        prn_fun(diff)
    return success

In [71]:
# Relation sets used to exercise the oracle over the code sequence "ABCDEF".
test_pairs = [
    [("A", "B")],
    [("A", "B"), ("B", "C")],
    # C->B->A
    [("C", "B"), ("B", "A")],
    [("A", "C"), ("B", "C")],
    [("A", "B"), ("C", "B")],
    [("B", "A"), ("B", "C")],
    [("A", "C"), ("C", "B")],
    # Hard - has to flip relation
    [("A", "D"), ("D", "B"), ("B", "C")],
    [("D", "A"), ("D", "B"), ("B", "C")],
    [("D", "A"), ("B", "D"), ("B", "C")],
    [("A", "E"), ("E", "B"), ("B", "D"), ("D", "C")],
    [("A", "D"), ("D", "B"), ("B", "C"), ("A", "F"), ("A", "E")],
    [("A", "D"), ("D", "B"), ("B", "C"), ("A", "F"), ("E", "F")],
]

# Run every relation set quietly; re-run the failing ones verbosely to show the trace.
oracle_fact = Oracle
for pairs in test_pairs:
    try:
        success = test_oracle("ABCDEF", pairs, oracle_fact, verbose=False)
    except Exception:
        # A parse that raises counts as a failure. Only Exception is caught so
        # that interrupting the kernel (KeyboardInterrupt) still stops the loop.
        success = False

    if not success:
        print("Error for relations:")
        pprint(pairs)
        print()
        success = test_oracle("ABCDEF", pairs, oracle_fact, verbose=True)

## Visualize Parse for Trickier Graphs

### <span style="color:red">Doesn't Handle Cycles</span>
- So we remove the condition about only having a single parent

In [74]:
# NOTE(review): the two commented lists below look like a real-data relation
# set and tag order that this toy case stands in for — confirm.
#[('1', '3'), ('1', '50'), ('3', '50')]
#['50', '1', '3']
# A, B and C are each related to the other two.
pairs =[
    ("B","A"),
    ("B","C"),
    ("C","A"),
]
test_oracle("ABCDEF", pairs, Oracle, verbose=True)

DEPS
	('B', 'A')
	('B', 'C')
	('C', 'A')

-------------------------------------------------------------------
A
-------------------------------------------------------------------
SHIFT   : Push A     || STACK : root|A           || BUFFER : BCDEF
-------------------------------------------------------------------
B
-------------------------------------------------------------------
R ARC   : A<-B       || STACK : root|A|B         || BUFFER : CDEF
-------------------------------------------------------------------
C
-------------------------------------------------------------------
L ARC   : B->C       || STACK : root|A           || BUFFER : CDEF
L ARC   : A->C       || STACK : root             || BUFFER : CDEF
SKIP    : item C     || STACK : root             || BUFFER : DEF
-------------------------------------------------------------------
D
-------------------------------------------------------------------
SKIP    : item D     || STACK : root             || BUFFER : EF
------------

True

In [64]:
# NOTE(review): these commented lists are identical to the ones in the cycle
# example cell and do not match the pairs below — presumably a copy-paste leftover.
#[('1', '3'), ('1', '50'), ('3', '50')]
#['50', '1', '3']
# B takes part in every relation: with A, C, D and E.
pairs =[
    ("B","A"),
    ("B","C"),
    ("D","B"),
    ("E","B"),
]
test_oracle("ABCDEF", pairs, Oracle, verbose=True)

DEPS
	('B', 'A')
	('B', 'C')
	('D', 'B')
	('E', 'B')

-------------------------------------------------------------------
A
-------------------------------------------------------------------
SHIFT   : Push A     || STACK : root|A           || BUFFER : ABCDEF
-------------------------------------------------------------------
B
-------------------------------------------------------------------
L ARC   : A->B       || STACK : root             || BUFFER : BCDEF
SHIFT   : Push B     || STACK : root|B           || BUFFER : BCDEF
-------------------------------------------------------------------
C
-------------------------------------------------------------------
R ARC   : B<-C       || STACK : root|B|C         || BUFFER : CDEF
-------------------------------------------------------------------
D
-------------------------------------------------------------------
REDUCE  : Pop  C     || STACK : root|B           || BUFFER : DEF
R ARC   : B<-D       || STACK : root|B|D         || BUFFER : 

True

In [65]:
# Same relation set as the first "Hard - has to flip relation" entry of test_pairs.
pairs =[
    ("A","D"),
    ("D","B"),
    ("B","C"),
]
test_oracle("ABCDEF", pairs, Oracle, verbose=True)

DEPS
	('A', 'D')
	('B', 'C')
	('D', 'B')

-------------------------------------------------------------------
A
-------------------------------------------------------------------
SHIFT   : Push A     || STACK : root|A           || BUFFER : ABCDEF
-------------------------------------------------------------------
B
-------------------------------------------------------------------
SHIFT   : Push B     || STACK : root|A|B         || BUFFER : BCDEF
-------------------------------------------------------------------
C
-------------------------------------------------------------------
R ARC   : B<-C       || STACK : root|A|B|C       || BUFFER : CDEF
-------------------------------------------------------------------
D
-------------------------------------------------------------------
REDUCE  : Pop  C     || STACK : root|A|B         || BUFFER : DEF
L ARC   : B->D       || STACK : root|A           || BUFFER : DEF
L ARC   : A->D       || STACK : root             || BUFFER : DEF
SKIP    : 

True

## Non-Projective Parses Should Fail the Test

In [75]:
# In the order A B C D E F the relations A-C and B-E cross each other.
pairs =[
    ("A","C"),
    ("B","E"),
]
try:
    success = test_oracle("ABCDEF", pairs, Oracle, verbose=True)
except Exception as e:
    # NOTE(review): the exception is re-raised on the next line, so this
    # assignment is never seen by the assert below.
    success = False
    raise e
# The parse is expected to report a mismatch (return False).
assert success == False

DEPS
	('A', 'C')
	('B', 'E')

-------------------------------------------------------------------
A
-------------------------------------------------------------------
SHIFT   : Push A     || STACK : root|A           || BUFFER : BCDEF
-------------------------------------------------------------------
B
-------------------------------------------------------------------
SHIFT   : Push B     || STACK : root|A|B         || BUFFER : CDEF
-------------------------------------------------------------------
C
-------------------------------------------------------------------
REDUCE  : Pop  B     || STACK : root|A           || BUFFER : CDEF
L ARC   : A->C       || STACK : root             || BUFFER : CDEF
SKIP    : item C     || STACK : root             || BUFFER : DEF
-------------------------------------------------------------------
D
-------------------------------------------------------------------
SKIP    : item D     || STACK : root             || BUFFER : EF
------------------------

In [76]:
# In the order A B C D the relations A-C and B-D cross each other.
pairs =[
    ("A","C"),
    ("B","D"),
]
try:
    success = test_oracle("ABCD", pairs, Oracle, verbose=True)
except Exception as e:
    # NOTE(review): the exception is re-raised on the next line, so this
    # assignment is never seen by the assert below.
    success = False
    raise e
# The parse is expected to report a mismatch (return False).
assert success == False

DEPS
	('A', 'C')
	('B', 'D')

-------------------------------------------------------------
A
-------------------------------------------------------------
SHIFT   : Push A     || STACK : root|A       || BUFFER : BCD
-------------------------------------------------------------
B
-------------------------------------------------------------
SHIFT   : Push B     || STACK : root|A|B     || BUFFER : CD
-------------------------------------------------------------
C
-------------------------------------------------------------
REDUCE  : Pop  B     || STACK : root|A       || BUFFER : CD
L ARC   : A->C       || STACK : root         || BUFFER : CD
SKIP    : item C     || STACK : root         || BUFFER : D
-------------------------------------------------------------
D
-------------------------------------------------------------
SHIFT   : Push D     || STACK : root|D       || BUFFER : 

*************************************************************
Stack
	root|D
DEPS Actual
	('A', 'C')
	('B', 

## Test on Real Causal Relations (Limit to 2 or More Relations in a Sentence)

In [10]:
def normalize(code):
    """Return ``code`` with every "Causer:" and "Result:" role prefix removed."""
    for role_prefix in ("Causer:", "Result:"):
        code = code.replace(role_prefix, "")
    return code

def normalize_cr(cr):
    """Split a causal-relation tag on "->" into a tuple of codes with role prefixes removed."""
    stripped = normalize(cr)
    return tuple(stripped.split("->"))

In [11]:
normalize("Causer:14"),normalize("Result:50")

('14', '50')

In [12]:
normalize_cr('Causer:14->Result:50')

('14', '50')

In [13]:
import pickle 

# NOTE(review): absolute, machine-specific path — the notebook only runs where this file exists.
training_pickled = "/Users/simon.hughes/Google Drive/Phd/Data/CoralBleaching/Thesis_Dataset/training.pl"
# SECURITY: pickle.load can execute arbitrary code; only load files from a trusted source.
# Opened read-only ("rb"): loading needs no write access, and "rb+" fails on read-only files.
with open(training_pickled, "rb") as f:
    tagged_essays = pickle.load(f)
len(tagged_essays)

902

In [14]:
from collections import defaultdict

# Count how often each tag occurs over every word of every sentence, and
# collect the distinct words seen.
tag_freq = defaultdict(int)
unique_words = set()
for essay in tagged_essays:
    for sentence in essay.sentences:
        for word, tags in sentence:
            unique_words.add(word)
            for tag in tags:
                tag_freq[tag] += 1

EMPTY_TAG = "Empty"
#TODO - don't ignore Anaphor, other and rhetoricals here
# Causal-relation tags: contain "->" and none of "Anaphor", "other", "rhetorical".
cr_tags  = list((t for t in tag_freq.keys() if ( "->" in t) and not "Anaphor" in t and not "other" in t and not "rhetorical" in t))
# Regular tags: no "->", and either exactly "explicit" or starting with a digit.
reg_tags = set((t for t in tag_freq.keys() if ( "->" not in t) and (t == "explicit" or t[0].isdigit())))

## Parse Causal Relations Using Position Information to Differentiate Codes and Crels (So can have multiple of the same type)

In [286]:
import string

def get_tags_relations_for(tagged_sentence, tag_freq, reg_tags, cr_tags):
    """Extract position-tagged codes and causal relations from one tagged sentence.

    Args:
        tagged_sentence: iterable of ``(word, tags)`` pairs; ``tags`` must be a
            set (``tags.intersection`` is called on it).
        tag_freq: mapping from tag to a count; when a word carries several
            regular tags, the one with the highest count is used.
        reg_tags: the regular tags to consider (compared after ``normalize``).
        cr_tags: the causal-relation tags to consider.

    Returns:
        A 5-tuple:
        - pos_tag_seq: list of ``(tag, word_index)``, one entry per start of a
          run of the same regular tag.
        - pos_crels: list of ``((l_tag, l_index), (r_tag, r_index))`` pairs built
          from the children of each relation.
        - crel_child_tags: defaultdict mapping ``(relation, start_index)`` to the
          set of ``(tag, start_index)`` entries seen under that relation.
        - sent_reg_predicted_tags: set of regular tags chosen for any word.
        - sent_act_cr_tags: set of causal-relation tags seen on any word.
    """
    sent_reg_predicted_tags = set()
    sent_act_cr_tags = set()
    
    # One entry per non-punctuation word: the chosen regular tag, or None.
    tag_seq  = [None] # seed with None
    # One entry per non-punctuation word: the set of relation tags on that word.
    crel_set_seq = [set()]

    pos_tag_seq = []
    # Most recent (name, start_index) for each regular tag and each relation tag.
    latest_tag_posns = {}
    crel_child_tags = defaultdict(set)
    for i, (wd,tags) in enumerate(tagged_sentence):
        # Punctuation words are ignored entirely: they add nothing to either sequence.
        if wd in string.punctuation:
            continue
        # Get tag seq
        active_tag = None
        rtags = set([normalize(t) for t in tags])
        rtags = rtags.intersection(reg_tags)
        if rtags:
            # only use explicit tag if it's the only tag (prefer concept code tags if both present)
            if len(rtags) > 1 and "explicit" in rtags:
                rtags.remove("explicit")
            active_tag = max(rtags, key = lambda t: tag_freq[t])
            sent_reg_predicted_tags.add(active_tag)
            # if no prev tag and the current matches -2 (a gap of one), skip over
            if active_tag != tag_seq[-1] and \
                    not(tag_seq[-1] is None and (len(tag_seq) > 2) and active_tag == tag_seq[-2]):
                # A new run of this tag starts at word index i.
                latest_tag_posns[active_tag] = (active_tag,i)
                pos_tag_seq.append((active_tag,i))
        tag_seq.append(active_tag)

        active_crels = tags.intersection(cr_tags)
        for cr in sorted(active_crels):
            sent_act_cr_tags.add(cr)
            # NOTE(review): the inner `cr not in crel_set_seq[-1]` repeats the outer
            # test, so this reduces to: cr is absent from the previous word's set,
            # and not (more than two entries so far and cr in the set before that).
            if cr not in crel_set_seq[-1] \
                    and not(cr not in crel_set_seq[-1] and (len(crel_set_seq) > 2) and cr in crel_set_seq[-2]):
                latest_tag_posns[cr] = (cr, i)
        crel_set_seq.append(active_crels)

        # to have child tags, need a tag sequence and a current valid regular tag
        if not active_tag or len(active_crels) == 0:
            continue

        # Attach the current tag run to every active relation that names this tag
        # on either side.
        for crel in active_crels:
            l,r = normalize_cr(crel)
            if active_tag in (l,r):
                crel_child_tags[latest_tag_posns[crel]].add(latest_tag_posns[active_tag])

    pos_crels = []
    for (crelation,crix), tag_pairs in crel_child_tags.items():
        l,r = normalize_cr(crelation)
        #unsupported relation
        if l not in sent_reg_predicted_tags or r not in sent_reg_predicted_tags:
            continue
        # Group this relation's children by tag name.
        tag2pair = defaultdict(list)
        for taga,ixa in tag_pairs:
            tag2pair[taga].append((taga,ixa))
        # un-supported relation
        if l not in tag2pair or r not in tag2pair:
            continue

        l_pairs = tag2pair[l]
        r_pairs = tag2pair[r]
        # Emit one positional relation per (left child, right child) combination,
        # skipping a child paired with itself (possible when l == r).
        for pairsa in l_pairs:
            for pairsb in r_pairs:
                if pairsa != pairsb:            
                    pos_crels.append((pairsa,pairsb))
    return pos_tag_seq, pos_crels, crel_child_tags, sent_reg_predicted_tags, sent_act_cr_tags

In [290]:
from pprint import pprint

# Count, over every sentence, how many distinct causal-relation tags are present
# (num_csl) and how many of them have both of their codes among the relation's
# child tags (num_supported).
relations = []
skipped_sent = 0
skipped_crels = 0
num_sents = 0
num_csl = 0
num_supported = 0

# Sentences where the supported relations differ in number from those present.
diffs = []

# Relations that collected more than two positional child tags.
too_many_kids = []

for essay_ix, essay in enumerate(tagged_essays):
    for sent_ix, tagged_sentence in enumerate(essay.sentences):
        
        tag_seq, _, crel_child_tags,_,_ = get_tags_relations_for(tagged_sentence, tag_freq, reg_tags, cr_tags)
        
        num_sents += 1
        # Distinct causal-relation tags appearing on any word of the sentence.
        un_csl  = set()
        for i, (wd,tags) in enumerate(tagged_sentence):
            csl = tags.intersection(cr_tags)
            un_csl.update(csl)
        
        num_csl += len(un_csl)
        
        # Don't count sentences without any relations as skipped
        if un_csl:        
            supported_causal = set()
            for (crel,ix), posn_tags in crel_child_tags.items():
                l,r = normalize_cr(crel)
                unique_child_tags = set()
                # NOTE: this loop rebinds `ix` from the enclosing loop; neither value is used afterwards.
                for tag, ix in posn_tags:
                    unique_child_tags.add(tag)                    
                if len(unique_child_tags) < 2:                    
                    # if l == r then we want to keep these
                    if len(unique_child_tags) == 0 or (len(unique_child_tags) == 1 and l != r):
                        skipped_crels += 1
                        continue
                if len(posn_tags) > 2:
                    too_many_kids.append((essay_ix, sent_ix, crel, list(tag_seq), dict(crel_child_tags)))
                assert (l in unique_child_tags and r in unique_child_tags), "Error - child tags are not supported"
                # Relations are collected by tag name, so the same relation at two
                # positions in one sentence is counted once.
                supported_causal.add(crel)

            if not supported_causal:
                skipped_sent += 1
                continue
                
            num_supported += len(supported_causal)
            # filter out any tags that were only part of unsupported causal relations
            #tag_seq = [tag for tag in tag_seq if tag in supported_codes]
            relations.append((essay_ix,sent_ix,supported_causal,tag_seq))
            
            if len(supported_causal) != len(un_csl):
                diffs.append((essay_ix, sent_ix, un_csl, supported_causal))
        #else:
        #    if un_csl:
        #        diffs.append((essay_ix, sent_ix, un_csl, set()))
        
num_sents, num_supported, num_csl #skipped_sent, skipped_crels, 
# NOTE(review): this comment used to give (8292, 2217, 3006) as the expected
# counts, but the saved output of this cell is (8292, 2868, 3006) — confirm
# which baseline is intended before relying on it as a regression check.

(8292, 2868, 3006)

In [291]:
# expected recall
2868/3006

0.9540918163672655

In [261]:
# NOTE(review): `tagged_sentence` is not assigned in this cell; it is whatever
# value another cell last left in the kernel.
tag_seq, _, crel_child_tags,_,_ = get_tags_relations_for(tagged_sentence, tag_freq, reg_tags, cr_tags)

In [263]:
import string

# Scratch copy of the body of get_tags_relations_for, run at top level so that
# the intermediate values (tag_seq, crel_set_seq, crel_child_tags, pos_crels)
# can be inspected in the following cells.
# NOTE(review): `tagged_sentence` is whatever value another cell last left in the kernel.
sent_reg_predicted_tags = set()
sent_act_cr_tags = set()

tag_seq  = [None] # seed with None
crel_set_seq = [set()]

pos_tag_seq = []
latest_tag_posns = {}

crel_child_tags = defaultdict(set)

for i, (wd,tags) in enumerate(tagged_sentence):
    # (`string` is imported once at the top of this cell rather than on every iteration.)
    if wd in string.punctuation:
        continue
    # Get tag seq
    active_tag = None
    rtags = set([normalize(t) for t in tags])
    rtags = rtags.intersection(reg_tags)
    if rtags:
        # only use explicit tag if it's the only tag (prefer concept code tags if both present)
        if len(rtags) > 1 and "explicit" in rtags:
            rtags.remove("explicit")
        active_tag = max(rtags, key = lambda t: tag_freq[t])
        sent_reg_predicted_tags.add(active_tag)
        # if no prev tag and the current matches -2 (a gap of one), skip over
        if active_tag != tag_seq[-1] and \
                not(tag_seq[-1] is None and (len(tag_seq) > 2) and active_tag == tag_seq[-2]):
            latest_tag_posns[active_tag] = (active_tag,i)
            pos_tag_seq.append((active_tag,i))
    tag_seq.append(active_tag)

    active_crels = tags.intersection(cr_tags)
    for cr in sorted(active_crels):
        sent_act_cr_tags.add(cr)
        if cr not in crel_set_seq[-1] \
                and not(cr not in crel_set_seq[-1] and (len(crel_set_seq) > 2) and cr in crel_set_seq[-2]):
            latest_tag_posns[cr] = (cr, i)
    crel_set_seq.append(active_crels)

    # to have child tags, need a tag sequence and a current valid regular tag
    if not active_tag or len(active_crels) == 0:
        continue

    for crel in active_crels:
        l,r = normalize_cr(crel)
        if active_tag in (l,r):
            crel_child_tags[latest_tag_posns[crel]].add(latest_tag_posns[active_tag])

pos_crels = []
for (crelation,crix), tag_pairs in crel_child_tags.items():
    l,r = normalize_cr(crelation)
    #unsupported relation
    if l not in sent_reg_predicted_tags or r not in sent_reg_predicted_tags:
        continue
    tag2pair = defaultdict(list)
    for taga,ixa in tag_pairs:
        tag2pair[taga].append((taga,ixa))
    if l not in tag2pair or r not in tag2pair:
        #raise Exception("Missing children %s" % crelation)
        continue

    l_pairs = tag2pair[l]
    r_pairs = tag2pair[r]
    for pairsa in l_pairs:
        a_tag,a_ix = pairsa
        for pairsb in r_pairs:
            # Could be the same tag
            if pairsa != pairsb:
                b_tag, b_ix = pairsb
                pos_crels.append((pairsa,pairsb))

In [264]:
# Show, word by word, the relation-tag set and chosen regular tag wherever either is present.
[(crel_set, tag) for crel_set, tag in zip(crel_set_seq, tag_seq) if crel_set or tag]

[({'Causer:7->Result:50'}, '50'),
 ({'Causer:7->Result:50'}, '50'),
 ({'Causer:7->Result:50'}, '50'),
 ({'Causer:7->Result:50'}, '50'),
 ({'Causer:7->Result:50'}, 'explicit'),
 ({'Causer:7->Result:50'}, 'explicit'),
 ({'Causer:7->Result:50'}, None),
 ({'Causer:7->Result:50'}, '7'),
 ({'Causer:7->Result:50'}, '7'),
 ({'Causer:7->Result:50'}, '7'),
 ({'Causer:7->Result:50'}, '7'),
 ({'Causer:7->Result:50'}, '7'),
 ({'Causer:7->Result:50'}, '7'),
 ({'Causer:7->Result:50'}, '7')]

In [265]:
# pprint() prints and returns None, so the cell's value is the tuple (None, pos_crels).
pprint(crel_child_tags), pos_crels

defaultdict(<class 'set'>, {('Causer:7->Result:50', 1): {('7', 9), ('50', 1)}})


(None, [(('7', 9), ('50', 1))])

In [266]:
print_sentence_tags(tagged_sentence, reg_tags.union(cr_tags))

0   during                         
1   bleaching                      Causer:7->Result:50,50
2   ,                              
3   corals                         Causer:7->Result:50,50
4   turn                           Causer:7->Result:50,50
5   white                          Causer:7->Result:50,50
6   due                            Causer:7->Result:50,explicit
7   to                             Causer:7->Result:50,explicit
8   the                            Causer:7->Result:50
9   ejection                       Causer:7->Result:50,7
10  or                             Causer:7->Result:50,7
11  death                          Causer:7->Result:50,7
12  ,                              
13  of                             Causer:7->Result:50,7
14  the                            Causer:7->Result:50,7
15  zooxanthellae                  Causer:7->Result:50,7
16  algae                          Causer:7->Result:50,7
17  .                              


### How Many of the Relations Have More than 2 Child Tags?

In [17]:
len(too_many_kids)

6

## Do Any of these Have More than 3 Children?

In [18]:
for (essay_ix, sent_ix, crel, tag_seq, crel_child_tags) in too_many_kids:
    #for (crel,ix), posn_tags in crel_child_tags.items():
    pprint(crel_child_tags)

{('Causer:3->Result:50', 3): {('50', 3), ('3', 7), ('3', 15)}}
{('Causer:1->Result:5', 6): {('1', 6), ('5', 9), ('5', 20)}}
{('Causer:7->Result:50', 0): {('50', 0), ('50', 16), ('7', 11)}}
{('Causer:1->Result:3', 0): {('3', 12), ('1', 0), ('3', 5)}}
{('Causer:1->Result:50', 6): {('50', 15), ('1', 6), ('50', 10)}}
{('Causer:3->Result:50', 1): {('3', 1), ('50', 19), ('50', 12)}}


**>>> Create separate causal relations for each combo of children where l != r (unless a self to self relation)**

In [19]:
e_ix, s_ix, crel, tseq, cr_kids = too_many_kids[2]
crel, tseq, cr_kids

('Causer:7->Result:50',
 [('50', 0), ('explicit', 8), ('7', 11), ('50', 16)],
 {('Causer:7->Result:50', 0): {('50', 0), ('50', 16), ('7', 11)}})

In [20]:
# NOTE(review): `sentence` is only assigned in a later cell, so this cell
# depends on leftover kernel state. The saved output below shows a 2-value
# result, whereas get_tags_relations_for as defined in this notebook returns 5 values.
get_tags_relations_for(sentence, tag_freq, reg_tags, cr_tags)

([('50', 1), ('explicit', 6), ('7', 9)],
 defaultdict(set, {('Causer:7->Result:50', 1): {('50', 1), ('7', 9)}}))

In [21]:
sentence = tagged_essays[e_ix].sentences[s_ix]
print_sentence_tags(sentence, [k for k in tag_freq.keys() if (":" in k and not "rhetorical" in k and not "->" in k) or k[0].isdigit()])
print()
print_sentence_tags(sentence, cr_tags)

0   coral                          50,Result:50
1   bleaching                      50,Result:50
2   is                             50,Result:50
3   when                           50,Result:50
4   coral                          50,Result:50
5   loses                          50,Result:50
6   it                             50,Result:50
7   color                          50,Result:50
8   due                            
9   too                            
10  the                            
11  algae                          Causer:7
12  that                           Causer:7
13  lives                          Causer:7
14  on                             Causer:7
15  the                            Causer:7
16  coral                          50,Causer:7
17  INFREQUENT                     50
18  .                              

0   coral                          Causer:7->Result:50
1   bleaching                      Causer:7->Result:50
2   is                             Causer:7->Result:50
3

In [22]:
2217/3006

0.7375249500998003

In [23]:
#normalize_cr("Causer:5->Result:10")

# NOTE(review): essay_ix / sent_ix are the values left behind by the last
# iteration of the counting loop in an earlier cell.
e_ix = essay_ix
s_ix = sent_ix
sentence = tagged_essays[e_ix].sentences[s_ix]
# get_tags_relations_for returns five values; only the positional tag sequence
# (1st) and the relation -> child-tags mapping (3rd) are used here. Unpacking
# into two names raises ValueError with the current definition.
tag_seq, _, crel_children, _, _ = get_tags_relations_for(sentence, tag_freq, reg_tags, set(cr_tags))
for pair in tag_seq:
    print(pair)
print("*" * 30)
for crel, kids in crel_children.items():
    print(crel)
    for k in kids:
        print(str(k))
    print()

('3', 1)
('11', 4)
('13', 7)
('50', 12)
('explicit', 14)
('50', 19)
******************************
('Causer:3->Result:50', 1)
('3', 1)
('50', 19)
('50', 12)



## Test New Parsing Logic

In [24]:
get_tags_relations_for(sentence, tag_freq, reg_tags, cr_tags)

([('3', 1), ('11', 4), ('13', 7), ('50', 12), ('explicit', 14), ('50', 19)],
 defaultdict(set,
             {('Causer:3->Result:50', 1): {('3', 1), ('50', 12), ('50', 19)}}))

In [267]:
# Run the oracle on every sentence that has both a positional tag sequence and
# positional relations, tallying parses that match and parses that do not.
errors = 0
successes= 0
exs = []
for e_ix, essay in enumerate(tagged_essays):
    for s_ix, sentence in enumerate(essay.sentences):

        tag_seq, crels,_,all_rtags,all_crel_tags = get_tags_relations_for(sentence, tag_freq, reg_tags, cr_tags)
        if not tag_seq or not crels:
            continue
        try:
            success = test_oracle(tag_seq, crels, Oracle, verbose=False)
        except Exception as e:
            exs.append(e)
            success = False

        # Count a success only when the parse actually matched; previously any
        # call that did not raise was counted, so a parse returning False was
        # added to both `successes` and `errors`.
        if success:
            successes += 1
        else:
            errors += 1
            print("Error for relations:", e_ix, ",", s_ix)
            pprint(crels)
            pprint(tag_seq)
            #print()
            #success = test_oracle(tag_seq, crels, Oracle, verbose=True)
            #break

Error for relations: 427 , 2
[(('1', 19), ('4', 27)),
 (('1', 19), ('50', 33)),
 (('3', 23), ('4', 27)),
 (('3', 23), ('50', 33))]
[('explicit', 17), ('1', 19), ('3', 23), ('4', 27), ('50', 33)]
Error for relations: 597 , 3
[(('3', 6), ('6', 1)),
 (('13', 11), ('6', 1)),
 (('3', 6), ('7', 18)),
 (('13', 11), ('7', 18)),
 (('7', 18), ('50', 31))]
[('explicit', 0),
 ('6', 1),
 ('3', 6),
 ('13', 11),
 ('explicit', 16),
 ('7', 18),
 ('explicit', 29),
 ('50', 31)]
Error for relations: 813 , 0
[(('3', 1), ('50', 19)),
 (('3', 1), ('50', 12)),
 (('11', 4), ('50', 19)),
 (('11', 4), ('50', 12)),
 (('13', 7), ('50', 19)),
 (('13', 7), ('50', 12))]
[('3', 1), ('11', 4), ('13', 7), ('50', 12), ('explicit', 14), ('50', 19)]
Error for relations: 890 , 2
[(('12', 1), ('11', 13)), (('13', 9), ('14', 20)), (('13', 9), ('11', 13))]
[('12', 1),
 ('13', 9),
 ('explicit', 11),
 ('11', 13),
 ('explicit', 18),
 ('14', 20)]


In [268]:
# These are all non-projective
errors, successes

(4, 2217)

In [188]:
print_sentence_tags(tagged_sentence, cr_tags)

0   the                            
1   temperature                    Causer:3->Result:50
2   change                         Causer:3->Result:50
3   ,                              
4   extreme                        Causer:11->Result:50,Causer:3->Result:50
5   storms                         Causer:11->Result:50,Causer:3->Result:50
6   ,                              
7   salinity                       Causer:11->Result:50,Causer:3->Result:50,Causer:13->Result:50
8   ,                              
9   zooxnathellae                  Causer:11->Result:50,Causer:3->Result:50,Causer:13->Result:50
10  ,                              
11  and                            Causer:11->Result:50,Causer:3->Result:50,Causer:13->Result:50
12  coral                          Causer:11->Result:50,Causer:3->Result:50,Causer:13->Result:50
13  bleaching                      Causer:11->Result:50,Causer:3->Result:50,Causer:13->Result:50
14  causes                         Causer:11->Result:50,Causer:3->Result:

In [191]:
pprint(crel_child_tags)

defaultdict(<class 'set'>,
            {('Causer:11->Result:50', 4): {('11', 4), ('50', 19), ('50', 12)},
             ('Causer:13->Result:50', 7): {('13', 7), ('50', 19), ('50', 12)},
             ('Causer:3->Result:50', 1): {('3', 1), ('50', 19), ('50', 12)}})


In [285]:
# Trace the oracle on one sentence, selected by essay file name and sentence index.
# (Alternative selection by position: tagged_essay = tagged_essays[813])
s_ix = 9
matching_essays = [e for e in tagged_essays if e.name == 'EBA1415_BLRW_3_CB_ES-05177.ann']
tagged_essay = matching_essays[0]
tagged_sentence = tagged_essay.sentences[s_ix]

# The fourth return value is not needed here
tag_seq, crel_seq, crel_child_tags, _, crel_tags = get_tags_relations_for(
    tagged_sentence, tag_freq, reg_tags, cr_tags)

# Show the extracted structures before the verbose oracle trace
for extracted in (tag_seq, crel_seq, crel_tags, crel_child_tags):
    pprint(extracted)
print()
test_oracle(tag_seq, crel_seq, Oracle, verbose=True)

[('3', 10), ('explicit', 22), ('5', 24), ('3', 30)]
[(('3', 10), ('5', 24))]
{'Causer:3->Result:5'}
defaultdict(<class 'set'>, {('Causer:3->Result:5', 10): {('5', 24), ('3', 10)}})

DEPS
	(('3', 10), ('5', 24))

---------------------------------
('3', 10)
---------------------------------
SHIFT   : Push ('3', 10) || STACK : root|('3', 10)
---------------------------------
('explicit', 22)
---------------------------------
SKIP    : item ('explicit', 22) || STACK : root|('3', 10)
---------------------------------
('5', 24)
---------------------------------
L ARC   : ('3', 10)->('5', 24) || STACK : root
SKIP    : item ('5', 24) || STACK : root
---------------------------------
('3', 30)
---------------------------------
SKIP    : item ('3', 30) || STACK : root

*********************************
Stack
	root
DEPS Actual
	(('3', 10), ('5', 24))
DEPS Pred
	(('3', 10), ('5', 24))
Actions
	SHIFT   : Push ('3', 10)
	SKIP    : item ('explicit', 22)
	L ARC   : ('3', 10)->('5', 24)
	SKIP    : item

True

In [157]:
crel_child_tags

defaultdict(set,
            {('Causer:1->Result:4', 19): {('1', 19), ('4', 27)},
             ('Causer:1->Result:50', 19): {('1', 19), ('50', 33)},
             ('Causer:3->Result:4', 23): {('3', 23), ('4', 27)},
             ('Causer:3->Result:50', 23): {('3', 23), ('50', 33)}})

In [160]:
tag2pairs

defaultdict(set, {'3': {('3', 23)}, '50': {('50', 33)}})

In [164]:
# Candidate relations: within each entry of crel_child_tags, pair every (tag, index)
# with every other (tag, index) in the same set that has a strictly larger index.
pos_crels = [
    (earlier, later)
    for tag_pairs in crel_child_tags.values()
    for earlier in tag_pairs
    for later in tag_pairs
    if earlier != later and earlier[1] < later[1]
]
pos_crels

[(('1', 19), ('4', 27)),
 (('1', 19), ('50', 33)),
 (('3', 23), ('4', 27)),
 (('3', 23), ('50', 33))]

In [156]:
errors, successes

(4, 2208)

In [67]:
print(tag_seq)
print(crels)

["('50', 1)", "('explicit', 6)", "('7', 9)"]
[("('7', 9)", "('50', 1)"), ("('50', 1)", "('7', 9)")]


In [69]:
test_oracle(tag_seq, crels, Oracle, verbose=True)

DEPS
	("('50', 1)", "('7', 9)")
	("('7', 9)", "('50', 1)")

-------------------------------
('50', 1)
-------------------------------
SHIFT   : Push ('50', 1) || STACK : root|('50', 1)
-------------------------------
('explicit', 6)
-------------------------------
SKIP    : item ('explicit', 6) || STACK : root|('50', 1)
-------------------------------
('7', 9)
-------------------------------
L ARC   : ('50', 1)->('7', 9) || STACK : root
SKIP    : item ('7', 9) || STACK : root

*******************************
Stack
	root
DEPS Actual
	("('50', 1)", "('7', 9)")
	("('7', 9)", "('50', 1)")
DEPS Pred
	("('50', 1)", "('7', 9)")
Actions
	SHIFT   : Push ('50', 1)
	SKIP    : item ('explicit', 6)
	L ARC   : ('50', 1)->('7', 9)
	SKIP    : item ('7', 9)

Ordered Match?    False
Un Ordered Match? True


True

### <span style="color:red">TODO - call success if we recover the original crel, as we are generating a lot of crels</span>
### <span style="color:red">TODO - Also check against fully supported as we want to get that to 100%</span>

In [26]:
errors

2302

## For the Unsupported Relations, are the Missing Tags in the Previous or Subsequent Sentence?

In [277]:
from pprint import pprint
def print_sentence(sentence):
    """Print one row per word: the word, its code tags and its relation tags.

    Args:
        sentence: iterable of (word, tags) pairs, where tags is an iterable
            (typically a set) of tag strings.

    A tag is listed in the second column when its first character is a digit
    (e.g. '3', '5b') and in the third column when it contains '->'
    (e.g. 'Causer:3->Result:4'). Returns None; output goes to stdout.
    """
    for wd, tags in sentence:
        # Sort first: set iteration order for strings changes between kernel runs
        # (hash randomization), which would make the printed rows irreproducible.
        ordered_tags = sorted(tags)
        # t[:1] instead of t[0] so an empty tag string is skipped rather than raising IndexError
        code_tags = [t for t in ordered_tags if t[:1].isdigit()]
        relation_tags = [t for t in ordered_tags if "->" in t]
        print(wd.ljust(20), str(code_tags).ljust(30), relation_tags)

# For the first 10 entries of `diffs`, print the relation tags that are missing from the
# supported set, followed by the previous, current and next sentences of the essay.
for essay_ix, sent_ix, un_csl, supported_causal in diffs[:10]:
    essay_sentences = tagged_essays[essay_ix].sentences
    sentence = essay_sentences[sent_ix]

    pprint(un_csl)
    pprint(supported_causal)
    print("Missing")
    pprint(un_csl - supported_causal)

    # Neighbouring sentences are only shown when they exist
    context = []
    if sent_ix > 0:
        context.append(("--Previous--", essay_sentences[sent_ix - 1]))
    context.append(("--Sentence--", sentence))
    if sent_ix < len(essay_sentences) - 1:
        context.append(("--Next--", essay_sentences[sent_ix + 1]))

    for header, context_sentence in context:
        print(header)
        print_sentence(context_sentence)

    print()

{'Causer:3->Result:4', 'Causer:4->Result:14'}
{'Causer:4->Result:14'}
Missing
{'Causer:3->Result:4'}
--Previous--
as                   []                             []
the                  []                             []
temperature          ['3']                          ['Causer:3->Result:4']
of                   ['3']                          ['Causer:3->Result:4']
water                ['3']                          ['Causer:3->Result:4']
increases            ['3']                          ['Causer:3->Result:4']
.                    []                             []
--Sentence--
the                  []                             ['Causer:3->Result:4']
amount               ['4']                          ['Causer:3->Result:4', 'Causer:4->Result:14']
of                   ['4']                          ['Causer:3->Result:4', 'Causer:4->Result:14']
co2                  ['4']                          ['Causer:3->Result:4', 'Causer:4->Result:14']
can                  []                

## TODO 
- Re-train tagging model, adding tags where reg tag is missing but is included in a causer or result tag. 
- Also include explicit in the predicted tags.
- Need to handle relations where same code -> same code

## The 4 Errors Below Look Like They Are from Non-Projective Parses
**NOTES**
With only 4 errors (i.e. 4 missed relations), this is hardly worth worrying about. 
One solution would be to train a forward and a backward parser, parse the sentence in both directions and merge the deps. In each case that would pick up all deps.

In [106]:
# List every relation whose two normalised codes are identical (e.g. 50 -> 50).
for i, relation in enumerate(relations):
    essay_ix, sent_ix, supported_causal, tag_seq = relation
    supported_causal = sorted(supported_causal)
    crels = [normalize_cr(crel) for crel in supported_causal]
    for l, r in crels:
        if l != r:
            continue
        # Show the position in `relations`, the repeated code and the full record
        print(i, l, r)
        print(relation)
        print()

37 50 50
(24, 13, {'Causer:1->Result:50', 'Causer:50->Result:50'}, ['explicit', '1', '50', 'explicit', '50'])

493 50 50
(189, 4, {'Causer:13->Result:50', 'Causer:50->Result:50'}, ['explicit', '13', 'explicit', '50', 'explicit', '50'])

527 50 50
(197, 10, {'Causer:5b->Result:50', 'Causer:50->Result:50'}, ['5b', 'explicit', '50', 'explicit', '50'])

766 11 11
(276, 12, {'Causer:11->Result:11', 'Causer:3->Result:4'}, ['4', '3', 'explicit', '11'])



In [107]:
# Why does the last case listed above (essay 276, sentence 12) have only a single '11'
# in its tag sequence even though it carries the relation Causer:11->Result:11?
tagged_essays[276].sentences[12]
# The same words carry both the Causer:11 and Result:11 tags - looks to be an unsupported relation

[('balance', set()),
 ('between', set()),
 ('co2', {'4', 'Causer:3->Result:4', 'Result', 'Result:4'}),
 ('and', {'Causer:3->Result:4'}),
 ('water', {'3', 'Causer', 'Causer:3', 'Causer:3->Result:4'}),
 ('temperature', {'3', 'Causer', 'Causer:3', 'Causer:3->Result:4'}),
 ('is', set()),
 ('also', set()),
 ('threaten', {'explicit'}),
 ('by', {'explicit'}),
 ('extreme',
  {'11',
   'Causer',
   'Causer:11',
   'Causer:11->Result:11',
   'Result',
   'Result:11'}),
 ('storms',
  {'11',
   'Causer',
   'Causer:11',
   'Causer:11->Result:11',
   'Result',
   'Result:11'}),
 ('.', set())]

In [203]:
# Run the Oracle over every entry in `relations` and count those it fails to reproduce.
errors = 0
exs = []  # exceptions raised by test_oracle, kept for later inspection
for e_ix, s_ix, supported_causal, tag_seq in relations:
    # tag_seq holds (tag, index) pairs; keep only the tags
    tag_seq = list(zip(*tag_seq))[0]

    supported_causal = sorted(supported_causal)
    crels = [normalize_cr(crel) for crel in supported_causal]

    try:
        success = test_oracle(tag_seq, crels, Oracle, verbose=False)
    except Exception as e:
        # An exception from the oracle counts as a failure
        exs.append(e)
        success = False

    if success:
        continue

    errors += 1
    print("Error for relations:", e_ix, ",", s_ix)
    pprint(crels)
    pprint(tag_seq)
    # To debug a failure, re-run it verbosely and stop at the first error:
    #     success = test_oracle(tag_seq, crels, Oracle, verbose=True)
    #     break

Error for relations: 8 , 0
[('7', '50')]
('50', '50', 'explicit', '7', '50')
Error for relations: 23 , 3
[('11', '12'), ('12', '13'), ('13', '50')]
('11', '11', 'explicit', '12', 'explicit', '13', 'explicit', '50')
Error for relations: 24 , 13
[('1', '50'), ('50', '50')]
('explicit', '1', '50', 'explicit', '50')
Error for relations: 33 , 3
[('1', '50'), ('1', '7'), ('3', '50'), ('3', '7')]
('1', '3', 'explicit', '50', 'explicit', '1', '3', '7')
Error for relations: 50 , 7
[('11', '13')]
('11', '11', 'explicit', '13')
Error for relations: 51 , 6
[('3', '50')]
('3', '3', 'explicit', '50')
Error for relations: 61 , 0
[('5', '7'), ('7', '50')]
('50', '50', 'explicit', '7', 'explicit', '5')
Error for relations: 68 , 1
[('6', '50'), ('6', '7')]
('6', 'explicit', '50', 'explicit', '6', 'explicit', '7')
Error for relations: 75 , 2
[('11', '12'), ('11', '13'), ('12', '13'), ('13', '14')]
('13', 'explicit', '11', 'explicit', '12', 'explicit', '13', 'explicit', '14')
Error for relations: 126 , 1


In [204]:
errors

89

In [279]:
# Minimal reproduction of a failing case: the tag '11' occurs twice before '13'.
# The verbose trace below stops with "AssertionError: Arc already processed ('11', '13')".
crels     = [('11', '13')]
tag_seq   = ('11', '11', 'explicit', '13')
test_oracle(tag_seq, crels, Oracle, verbose=True)

DEPS
	('11', '13')

---------------------------------
11
---------------------------------
SHIFT   : Push 11    || STACK : root|11
---------------------------------
11
---------------------------------
SHIFT   : Push 11    || STACK : root|11|11
---------------------------------
explicit
---------------------------------
SKIP    : item explicit || STACK : root|11|11
---------------------------------
13
---------------------------------
L ARC   : 11->13     || STACK : root|11


AssertionError: Arc already processed ('11', '13')

In [119]:
# Minimal reproduction of a failing case: the tag '1' occurs twice and takes part in
# two relations. The verbose trace below stops at the second '1' with
# "AssertionError: Arc already processed ('1', '50')".
crels = [('1', '3'), ('1', '50')]
tag_seq = ['1', 'explicit', '50', '1', 'explicit', '3']
test_oracle(tag_seq, crels, Oracle, verbose=True)

DEPS
	('1', '3')
	('1', '50')

-------------------------------------
1
-------------------------------------
SHIFT   : Push 1     || STACK : root|1
-------------------------------------
explicit
-------------------------------------
SKIP    : item explicit || STACK : root|1
-------------------------------------
50
-------------------------------------
R ARC   : 1<-50      || STACK : root|1|50
-------------------------------------
1
-------------------------------------


AssertionError: Arc already processed ('1', '50')

## <span style="color:red">NEED to determine if all errors are non-projective</span>

In [189]:
# Smallest possible check of Oracle2: two tags and the single relation between them.
# NOTE(review): the verbose trace below prints the "L ARC : 5->50" line twice, although
# only one L ARC appears in the recorded Actions - worth confirming this is only a logging duplicate.
test_oracle(['5', '50'], [('5', '50')], Oracle2, verbose=True)

DEPS
	('5', '50')

-----------------------------
5
-----------------------------
SHIFT   : Push 5     || STACK : root|5
-----------------------------
50
-----------------------------
L ARC   : 5->50      || STACK : root
L ARC   : 5->50      || STACK : root

*****************************
Stack
	root
DEPS Actual
	('5', '50')
DEPS Pred
	('5', '50')
Actions
	SHIFT   : Push 5
	L ARC   : 5->50

Ordered Match?    True
Un Ordered Match? True


True

In [199]:
# Run Oracle2 over every entry in `relations` and count those it fails to reproduce.
errors = 0
exs = []  # exceptions raised by test_oracle, kept for later inspection
for e_ix, s_ix, supported_causal, tag_seq in relations:

    # GET INITIAL TAGS (ignore indexes): tag_seq holds (tag, index) pairs
    tag_seq = list(zip(*tag_seq))[0]

    supported_causal = sorted(supported_causal)
    crels = [normalize_cr(crel) for crel in supported_causal]

    try:
        success = test_oracle(tag_seq, crels, Oracle2, verbose=False)
    except Exception as e:
        # An exception from the oracle counts as a failure
        exs.append(e)
        success = False

    if not success:
        errors += 1
        # Include the essay and sentence indexes (as the Oracle sweep does) so that
        # each failure can be traced back to its sentence.
        print("Error for relations:", e_ix, ",", s_ix)
        pprint(crels)
        pprint(tag_seq)
        # To debug a failure, re-run it verbosely and stop at the first error:
        #     success = test_oracle(tag_seq, crels, Oracle2, verbose=True)
        #     break
errors

Error for relations:
[('3', '4')]
('explicit', '3', 'explicit', '4')
Error for relations:
[('1', '50'), ('11', '50'), ('13', '50')]
('explicit', '50', '13', '1', '11')
Error for relations:
[('7', '50')]
('50', '50', 'explicit', '7', '50')
Error for relations:
[('3', '4')]
('explicit', '3', 'explicit', '4')
Error for relations:
[('1', '50'), ('3', '1')]
('explicit', '50', '1', 'explicit', '3', 'explicit', '1')
Error for relations:
[('1', '3')]
('explicit', '1', '3')
Error for relations:
[('3', '4')]
('explicit', '3', 'explicit', '4')
Error for relations:
[('3', '50')]
('explicit', '50', '3')
Error for relations:
[('3', '50')]
('explicit', '3', 'explicit', '50')
Error for relations:
[('13', '50')]
('explicit', '13', 'explicit', '50')
Error for relations:
[('3', '7'), ('7', '50')]
('explicit', '3', 'explicit', '7', 'explicit', '50')
Error for relations:
[('11', '12'), ('12', '13'), ('13', '50')]
('11', '11', 'explicit', '12', 'explicit', '13', 'explicit', '50')
Error for relations:
[('3',

660

In [192]:
relations[0]

(1, 1, {'Causer:5->Result:50'}, [('5', 2), ('explicit', 9), ('50', 11)])