nltk.tag tests
dimazest committed Jul 18, 2014
1 parent 1903316 commit c56787f
Showing 4 changed files with 72 additions and 479 deletions.
63 changes: 37 additions & 26 deletions nltk/tag/brill.py
@@ -75,6 +75,7 @@ def nltkdemo18():
Template(Word([-1]), Word([1])),
]


def nltkdemo18plus():
"""
Return 18 templates, from the original nltk demo, and additionally a few
@@ -88,6 +89,7 @@ def nltkdemo18plus():
Template(Pos([-1]), Word([0]), Pos([1])),
]


def fntbl37():
"""
Return 37 templates taken from the postagging task of the
@@ -135,6 +137,7 @@ def fntbl37():
Template(Pos([1]), Pos([2]), Word([1]))
]


def brill24():
"""
Return 24 templates of the seminal TBL paper, Brill (1995)
@@ -171,7 +174,8 @@ def describe_template_sets():
"""
Print the available template sets in this demo, with a short description
"""
import inspect, sys
import inspect
import sys

# a bit of magic to get all functions in this module
templatesets = inspect.getmembers(sys.modules[__name__], inspect.isfunction)
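Each template set above is a plain module-level function, which is why describe_template_sets can enumerate them with inspect.getmembers. As a minimal usage sketch (assumptions: the treebank sample corpus is installed, the train/test split sizes are arbitrary, and a unigram tagger stands in for whatever baseline you prefer), one of these sets feeds straight into the Brill trainer:

    from nltk.corpus import treebank
    from nltk.tag import UnigramTagger
    from nltk.tag.brill import brill24
    from nltk.tag.brill_trainer import BrillTaggerTrainer

    # Arbitrary small split, for illustration only.
    tagged_sents = treebank.tagged_sents()
    train_sents, test_sents = tagged_sents[:500], tagged_sents[500:600]

    # Any initial tagger can serve as the baseline whose errors the learned
    # transformation rules will patch; a unigram tagger is a cheap choice.
    baseline = UnigramTagger(train_sents)

    # brill24() returns the 24 templates of Brill (1995).
    trainer = BrillTaggerTrainer(baseline, brill24(), trace=0)
    tagger = trainer.train(train_sents, max_rules=20)
    print(tagger.evaluate(test_sents))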
@@ -201,7 +205,7 @@ class BrillTagger(TaggerI):
of the TaggerTrainers available.
"""

json_tag='nltk.tag.BrillTagger'
json_tag = 'nltk.tag.BrillTagger'

def __init__(self, initial_tagger, rules, training_stats=None):
"""
@@ -318,27 +322,35 @@ def det_tplsort(tpl_value):
return (tpl_value[1], repr(tpl_value[0]))

def print_train_stats():
print("TEMPLATE STATISTICS (TRAIN) {0} templates, {1} rules)".format(
len(template_counts),len(tids)))
print(
"TEMPLATE STATISTICS (TRAIN) {0} templates, {1} rules)".format(
len(template_counts),
len(tids)),
)
print("TRAIN ({tokencount:7d} tokens) initial {initialerrors:5d} {initialacc:.4f} "
"final: {finalerrors:5d} {finalacc:.4f} ".format(**train_stats))
head = "#ID | Score (train) | #Rules | Template"
print(head, "\n", "-" * len(head), sep="")
train_tplscores = sorted(weighted_traincounts.items(), key=det_tplsort, reverse=True)
for (tid, trainscore) in train_tplscores:
s = "{0:s} | {1:5d} {2:5.3f} |{3:4d} {4:.3f} | {5:s}".format(
tid,
trainscore,
trainscore/tottrainscores,
template_counts[tid],
template_counts[tid]/len(tids),
Template.ALLTEMPLATES[int(tid)])
s = "{0} | {1:5d} {2:5.3f} |{3:4d} {4:.3f} | {5}".format(
tid,
trainscore,
trainscore / tottrainscores,
template_counts[tid],
template_counts[tid] / len(tids),
Template.ALLTEMPLATES[int(tid)],
)
print(s)

def print_testtrain_stats():
testscores = test_stats['rulescores']
print("TEMPLATE STATISTICS (TEST AND TRAIN) ({0} templates, {1} rules)".format(
len(template_counts),len(tids)))
print(
"TEMPLATE STATISTICS (TEST AND TRAIN) ({0} templates, {1} rules)".format(
len(template_counts),
len(tids),
),
)
print("TEST ({tokencount:7d} tokens) initial {initialerrors:5d} {initialacc:.4f} "
"final: {finalerrors:5d} {finalacc:.4f} ".format(**test_stats))
print("TRAIN ({tokencount:7d} tokens) initial {initialerrors:5d} {initialacc:.4f} "
@@ -352,14 +364,15 @@ def print_testtrain_stats():
test_tplscores = sorted(weighted_testcounts.items(), key=det_tplsort, reverse=True)
for (tid, testscore) in test_tplscores:
s = "{0:s} |{1:5d} {2:6.3f} | {3:4d} {4:.3f} |{5:4d} {6:.3f} | {7:s}".format(
tid,
testscore,
testscore/tottestscores,
weighted_traincounts[tid],
weighted_traincounts[tid]/tottrainscores,
template_counts[tid],
template_counts[tid]/len(tids),
Template.ALLTEMPLATES[int(tid)])
tid,
testscore,
testscore / tottestscores,
weighted_traincounts[tid],
weighted_traincounts[tid] / tottrainscores,
template_counts[tid],
template_counts[tid] / len(tids),
Template.ALLTEMPLATES[int(tid)],
)
print(s)

def print_unused_templates():
@@ -395,15 +408,13 @@ def batch_tag_incremental(self, sequences, gold):
:returns: tuple of (tagged_sequences, ordered list of rule scores (one for each rule))
"""
def counterrors(xs):
return sum(t[1] != g[1]
for pair in zip(xs, gold)
for (t, g) in zip(*pair))
return sum(t[1] != g[1] for pair in zip(xs, gold) for (t, g) in zip(*pair))
testing_stats = {}
testing_stats['tokencount'] = sum(len(t) for t in sequences)
testing_stats['sequencecount'] = len(sequences)
tagged_tokenses = [self._initial_tagger.tag(tokens) for tokens in sequences]
testing_stats['initialerrors'] = counterrors(tagged_tokenses)
testing_stats['initialacc'] = 1- testing_stats['initialerrors']/testing_stats['tokencount']
testing_stats['initialacc'] = 1 - testing_stats['initialerrors'] / testing_stats['tokencount']
# Apply each rule to the entire corpus, in order
errors = [testing_stats['initialerrors']]
for rule in self._rules:
@@ -412,7 +423,7 @@ def counterrors(xs):
errors.append(counterrors(tagged_tokenses))
testing_stats['rulescores'] = [err0 - err1 for (err0, err1) in zip(errors, errors[1:])]
testing_stats['finalerrors'] = errors[-1]
testing_stats['finalacc'] = 1 - testing_stats['finalerrors']/testing_stats['tokencount']
testing_stats['finalacc'] = 1 - testing_stats['finalerrors'] / testing_stats['tokencount']
return (tagged_tokenses, testing_stats)
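The nested counterrors above zips each tagged sequence against its gold counterpart and counts token-level tag mismatches; the per-rule scores in testing_stats['rulescores'] are then just the successive drops in that error count as each rule is applied. A self-contained toy illustration of the same counting logic, on made-up data:

    # Toy data (hypothetical: two sentences of two tokens each).
    gold   = [[('the', 'DT'), ('cat', 'NN')], [('dogs', 'NNS'), ('run', 'VBP')]]
    tagged = [[('the', 'DT'), ('cat', 'VB')], [('dogs', 'NNS'), ('run', 'NN')]]

    def counterrors(xs, gold):
        # One error per token whose predicted tag differs from the gold tag.
        return sum(t[1] != g[1] for pair in zip(xs, gold) for (t, g) in zip(*pair))

    errors = counterrors(tagged, gold)
    accuracy = 1 - errors / sum(len(s) for s in gold)
    print(errors, accuracy)  # 2 0.5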


27 changes: 9 additions & 18 deletions nltk/tag/brill_trainer.py
@@ -91,7 +91,6 @@ def __init__(self, initial_tagger, templates, trace=0,
if the rule applies. This records the next position we
need to check to see if the rule messed anything up."""


#////////////////////////////////////////////////////////////
# Training
#////////////////////////////////////////////////////////////
@@ -163,41 +162,33 @@ def train(self, train_sents, max_rules=200, min_score=2, min_acc=None):
22 27 5 24 | NN->-NONE- if Pos:VBD@[-1]
17 17 0 0 | NN->CC if Pos:NN@[-1] & Word:and@[0]
>>> tagger1.rules()[1:3]
(Rule('001', 'NN', ',', [(Pos([-1]),'NN'), (Word([0]),',')]), Rule('001', 'NN', '.', [(Pos([-1]),'NN'), (Word([0]),'.')]))
>>> train_stats = tagger1.train_stats()
>>> [train_stats[stat] for stat in ['initialerrors', 'finalerrors', 'rulescores']]
[1775, 1269, [132, 85, 69, 51, 47, 33, 26, 24, 22, 17]]
##FIXME: the following test fails -- why?
#
#>>> tagger1.print_template_statistics(printunused=False)
#TEMPLATE STATISTICS (TRAIN) (2 templates, 10 rules)
#TRAIN ( 3163 tokens) initial 2358 0.2545 final: 1719 0.4565
##ID | Score (train) | #Rules | Template
#--------------------------------------------
#001 | 404 0.632 | 7 0.700 | Template(Pos([-1]),Word([0]))
#000 | 235 0.368 | 3 0.300 | Template(Pos([-1]))
#<BLANKLINE>
#<BLANKLINE>
>>> tagger1.print_template_statistics(printunused=False)
TEMPLATE STATISTICS (TRAIN) (2 templates, 10 rules)
TRAIN ( 2417 tokens) initial 1775 0.2656 final: 1269 0.4750
#ID | Score (train) | #Rules | Template
--------------------------------------------
001 | 305 0.603 | 7 0.700 | Template(Pos([-1]),Word([0]))
000 | 201 0.397 | 3 0.300 | Template(Pos([-1]))
<BLANKLINE>
<BLANKLINE>
>>> tagger1.evaluate(gold_data) # doctest: +ELLIPSIS
0.43996...
>>> (tagged, test_stats) = tagger1.batch_tag_incremental(testing_data, gold_data)
>>> tagged[33][12:] == [('foreign', 'IN'), ('debt', 'NN'), ('of', 'IN'), ('$', 'NN'), ('64', 'CD'),
... ('billion', 'NN'), ('*U*', 'NN'), ('--', 'NN'), ('the', 'DT'), ('third-highest', 'NN'), ('in', 'NN'),
... ('the', 'DT'), ('developing', 'VBG'), ('world', 'NN'), ('.', '.')]
True
>>> [test_stats[stat] for stat in ['initialerrors', 'finalerrors', 'rulescores']]
[1855, 1376, [100, 85, 67, 58, 27, 36, 27, 16, 31, 32]]
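The doctest preamble that builds tagger1, training_data, testing_data and gold_data sits above this excerpt and is not shown. A rough sketch of that kind of setup follows (hypothetical: the corpus, split sizes and baseline tagger here are guesses, so the exact numbers above will not reproduce), using the two templates named in the statistics table:

    # Hypothetical setup sketch -- not the elided doctest preamble itself.
    from nltk.corpus import treebank
    from nltk.tag import UnigramTagger, untag
    from nltk.tag.brill import Pos, Word
    from nltk.tag.brill_trainer import BrillTaggerTrainer
    from nltk.tbl.template import Template

    # The statistics table above names exactly these two templates.
    templates = [Template(Pos([-1])), Template(Pos([-1]), Word([0]))]

    tagged_sents = treebank.tagged_sents()
    training_data, gold_data = tagged_sents[:100], tagged_sents[100:150]
    testing_data = [untag(s) for s in gold_data]  # strip tags for re-tagging

    baseline = UnigramTagger(training_data)
    trainer = BrillTaggerTrainer(baseline, templates, trace=3)
    tagger1 = trainer.train(training_data, max_rules=10)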
