Skip to content

Commit

Permalink
* Fix missing root labels bug identified in Issue #57
Browse files Browse the repository at this point in the history
  • Loading branch information
syllog1sm committed Apr 28, 2015
1 parent 693c5a1 commit b3fd48c
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 11 deletions.
8 changes: 7 additions & 1 deletion spacy/syntax/arc_eager.pyx
Expand Up @@ -88,9 +88,15 @@ cdef class ArcEager(TransitionSystem):
t.get_cost = get_cost_funcs[move]
return t

cdef int first_state(self, State* state) except -1:
cdef int initialize_state(self, State* state) except -1:
push_stack(state)

cdef int finalize_state(self, State* state) except -1:
cdef int root_label = self.strings['ROOT']
for i in range(state.sent_len):
if state.sent[i].head == 0 and state.sent[i].dep == 0:
state.sent[i].dep = root_label

cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
cdef bint[N_MOVES] is_valid
is_valid[SHIFT] = _can_shift(s)
Expand Down
3 changes: 0 additions & 3 deletions spacy/syntax/ner.pyx
Expand Up @@ -124,9 +124,6 @@ cdef class BiluoPushDown(TransitionSystem):
t.get_cost = _get_cost
return t

cdef int first_state(self, State* state) except -1:
pass

cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
cdef int best = -1
cdef weight_t score = -90000
Expand Down
8 changes: 4 additions & 4 deletions spacy/syntax/parser.pyx
Expand Up @@ -83,23 +83,22 @@ cdef class GreedyParser:
cdef int n_feats
cdef Pool mem = Pool()
cdef State* state = new_state(mem, tokens.data, tokens.length)
self.moves.first_state(state)
self.moves.initialize_state(state)
cdef Transition guess
while not is_final(state):
fill_context(context, state)
scores = self.model.score(context)
guess = self.moves.best_valid(scores, state)
#print self.moves.move_name(guess.move, guess.label),
#print print_state(state, [w.orth_ for w in tokens])
guess.do(&guess, state)
self.moves.finalize_state(state)
tokens.set_parse(state.sent)
return 0

def train(self, Tokens tokens, GoldParse gold):
self.moves.preprocess_gold(gold)
cdef Pool mem = Pool()
cdef State* state = new_state(mem, tokens.data, tokens.length)
self.moves.first_state(state)
self.moves.initialize_state(state)

cdef int cost
cdef const Feature* feats
Expand All @@ -117,3 +116,4 @@ cdef class GreedyParser:
self.model.update(context, guess.clas, best.clas, cost)

guess.do(&guess, state)
self.moves.finalize_state(state)
3 changes: 2 additions & 1 deletion spacy/syntax/transition_system.pxd
Expand Up @@ -30,7 +30,8 @@ cdef class TransitionSystem:
cdef const Transition* c
cdef readonly int n_moves

cdef int first_state(self, State* state) except -1
cdef int initialize_state(self, State* state) except -1
cdef int finalize_state(self, State* state) except -1

cdef int preprocess_gold(self, GoldParse gold) except -1

Expand Down
7 changes: 5 additions & 2 deletions spacy/syntax/transition_system.pyx
Expand Up @@ -26,8 +26,11 @@ cdef class TransitionSystem:
i += 1
self.c = moves

cdef int first_state(self, State* state) except -1:
raise NotImplementedError
cdef int initialize_state(self, State* state) except -1:
pass

cdef int finalize_state(self, State* state) except -1:
pass

cdef int preprocess_gold(self, GoldParse gold) except -1:
raise NotImplementedError
Expand Down

0 comments on commit b3fd48c

Please sign in to comment.