Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.
Download ZIP
Browse files

better corpus loading

  • Loading branch information...
commit 928d370b1478f44e260d96690a105c4d281fdba8 1 parent 5215d33
@japerk authored
View
17 nltk_trainer/__init__.py
@@ -38,15 +38,20 @@ def load_corpus_reader(corpus, reader=None, fileids=None, **kwargs):
raise ValueError('you must specify a corpus reader')
if not fileids:
- raise ValueError('you must specify the corpus fileids')
+ fileids = '.*'
- if os.path.isdir(corpus):
- root = corpus
- else:
+ root = os.path.expanduser(corpus)
+
+ if not os.path.isdir(root):
+ if not corpus.startswith('corpora/'):
+ path = 'corpora/%s' % corpus
+ else:
+ path = corpus
+
try:
- root = nltk.data.find(corpus)
+ root = nltk.data.find(path)
except LookupError:
- raise ValueError('cannot find corpus path %s' % corpus)
+ raise ValueError('cannot find corpus path for %s' % corpus)
reader_cls = import_attr(reader)
real_corpus = reader_cls(root, fileids, **kwargs)
View
10 train_chunker.py
@@ -69,14 +69,10 @@
## corpus reader ##
###################
-chunked_corpus = load_corpus_reader(args.corpus, reader=args.reader, fileids=args.fileids)
-
-if not chunked_corpus:
- raise ValueError('%s is an unknown corpus')
-
if args.trace:
- print 'loading nltk.corpus.%s' % args.corpus
-# trigger loading so it has its true class
+ print 'loading %s' % args.corpus
+
+chunked_corpus = load_corpus_reader(args.corpus, reader=args.reader, fileids=args.fileids)
chunked_corpus.fileids()
fileids = args.fileids
kwargs = {}
View
25 train_classifier.py
@@ -10,7 +10,7 @@
from nltk.metrics import BigramAssocMeasures, f_measure, masi_distance, precision, recall
from nltk.probability import FreqDist, ConditionalFreqDist
from nltk.util import ngrams
-from nltk_trainer import dump_object, import_attr
+from nltk_trainer import dump_object, import_attr, load_corpus_reader
from nltk_trainer.classification import corpus, scoring
from nltk_trainer.classification.featx import bag_of_words, bag_of_words_in_set, train_test_feats
from nltk_trainer.classification.multi import MultiBinaryClassifier
@@ -37,9 +37,9 @@
help='number of most informative features to show, works for all algorithms except DecisionTree')
corpus_group = parser.add_argument_group('Training Corpus')
-corpus_group.add_argument('--reader', choices=('plaintext', 'tagged'),
- default='plaintext',
- help='specify categorized plaintext or part-of-speech tagged corpus')
+corpus_group.add_argument('--reader',
+ default='nltk.corpus.reader.CategorizedPlaintextCorpusReader',
+ help='Full module path to a corpus reader class, such as %(default)s')
corpus_group.add_argument('--cat_pattern', default='(.+)/.+',
help='''A regular expression pattern to identify categories based on file paths.
If cat_file is also given, this pattern is used to identify corpus file ids.
@@ -129,11 +129,6 @@
## corpus reader ##
###################
-reader_class = {
- 'plaintext': CategorizedPlaintextCorpusReader,
- 'tagged': CategorizedTaggedCorpusReader
-}
-
reader_args = []
reader_kwargs = {}
@@ -160,8 +155,15 @@
if args.para_block_reader:
reader_kwargs['para_block_reader'] = import_attr(args.para_block_reader)
-categorized_corpus = LazyCorpusLoader(args.corpus, reader_class[args.reader],
+if args.trace:
+ print 'loading %s' % args.corpus
+
+categorized_corpus = load_corpus_reader(args.corpus, args.reader,
*reader_args, **reader_kwargs)
+
+if not hasattr(categorized_corpus, 'categories'):
+  raise ValueError('%s does not have categories for classification' % args.corpus)
+
labels = categorized_corpus.categories()
nlabels = len(labels)
@@ -291,9 +293,6 @@ def norm_words(words):
scoring.cross_fold(train_feats, trainf, accuracy, folds=args.cross_fold,
trace=args.trace, metrics=not args.no_eval, informative=args.show_most_informative)
else:
- if args.trace:
- print 'training %s classifier' % args.classifier
-
classifier = trainf(train_feats)
################
View
14 train_tagger.py
@@ -109,16 +109,10 @@
## corpus reader ##
###################
-tagged_corpus = load_corpus_reader(args.corpus, reader=args.reader, fileids=args.fileids)
-
-if not tagged_corpus:
- raise ValueError('%s is an unknown corpus')
-
if args.trace:
- print 'loading nltk.corpus.%s' % args.corpus
-# trigger loading so it has its true class
-tagged_corpus.fileids()
-# fileid is used for corpus naming, if it exists
+ print 'loading %s' % args.corpus
+
+tagged_corpus = load_corpus_reader(args.corpus, reader=args.reader, fileids=args.fileids)
fileids = args.fileids
kwargs = {}
@@ -130,7 +124,7 @@
raise ValueError('%s does not support simplify_tags' % args.corpus)
if isinstance(tagged_corpus, SwitchboardCorpusReader):
- if args.fileids:
+ if fileids:
raise ValueError('fileids cannot be used with switchboard')
tagged_sents = list(itertools.chain(*[[list(s) for s in d if s] for d in tagged_corpus.tagged_discourses(**kwargs)]))
Please sign in to comment.
Something went wrong with that request. Please try again.