diff --git a/nltk_trainer/__init__.py b/nltk_trainer/__init__.py index 448bda2..d66e209 100644 --- a/nltk_trainer/__init__.py +++ b/nltk_trainer/__init__.py @@ -38,15 +38,20 @@ def load_corpus_reader(corpus, reader=None, fileids=None, **kwargs): raise ValueError('you must specify a corpus reader') if not fileids: - raise ValueError('you must specify the corpus fileids') + fileids = '.*' - if os.path.isdir(corpus): - root = corpus - else: + root = os.path.expanduser(corpus) + + if not os.path.isdir(root): + if not corpus.startswith('corpora/'): + path = 'corpora/%s' % corpus + else: + path = corpus + try: - root = nltk.data.find(corpus) + root = nltk.data.find(path) except LookupError: - raise ValueError('cannot find corpus path %s' % corpus) + raise ValueError('cannot find corpus path for %s' % corpus) reader_cls = import_attr(reader) real_corpus = reader_cls(root, fileids, **kwargs) diff --git a/train_chunker.py b/train_chunker.py index 778c202..ed0718e 100755 --- a/train_chunker.py +++ b/train_chunker.py @@ -69,14 +69,10 @@ ## corpus reader ## ################### -chunked_corpus = load_corpus_reader(args.corpus, reader=args.reader, fileids=args.fileids) - -if not chunked_corpus: - raise ValueError('%s is an unknown corpus') - if args.trace: - print 'loading nltk.corpus.%s' % args.corpus -# trigger loading so it has its true class + print 'loading %s' % args.corpus + +chunked_corpus = load_corpus_reader(args.corpus, reader=args.reader, fileids=args.fileids) chunked_corpus.fileids() fileids = args.fileids kwargs = {} diff --git a/train_classifier.py b/train_classifier.py index 5711c42..697aae2 100755 --- a/train_classifier.py +++ b/train_classifier.py @@ -10,7 +10,7 @@ from nltk.metrics import BigramAssocMeasures, f_measure, masi_distance, precision, recall from nltk.probability import FreqDist, ConditionalFreqDist from nltk.util import ngrams -from nltk_trainer import dump_object, import_attr +from nltk_trainer import dump_object, import_attr, load_corpus_reader 
from nltk_trainer.classification import corpus, scoring from nltk_trainer.classification.featx import bag_of_words, bag_of_words_in_set, train_test_feats from nltk_trainer.classification.multi import MultiBinaryClassifier @@ -37,9 +37,9 @@ help='number of most informative features to show, works for all algorithms except DecisionTree') corpus_group = parser.add_argument_group('Training Corpus') -corpus_group.add_argument('--reader', choices=('plaintext', 'tagged'), - default='plaintext', - help='specify categorized plaintext or part-of-speech tagged corpus') +corpus_group.add_argument('--reader', + default='nltk.corpus.reader.CategorizedPlaintextCorpusReader', + help='Full module path to a corpus reader class, such as %(default)s') corpus_group.add_argument('--cat_pattern', default='(.+)/.+', help='''A regular expression pattern to identify categories based on file paths. If cat_file is also given, this pattern is used to identify corpus file ids. @@ -129,11 +129,6 @@ ## corpus reader ## ################### -reader_class = { - 'plaintext': CategorizedPlaintextCorpusReader, - 'tagged': CategorizedTaggedCorpusReader } - reader_args = [] reader_kwargs = {} @@ -160,8 +155,15 @@ if args.para_block_reader: reader_kwargs['para_block_reader'] = import_attr(args.para_block_reader) -categorized_corpus = LazyCorpusLoader(args.corpus, reader_class[args.reader], +if args.trace: + print 'loading %s' % args.corpus + +categorized_corpus = load_corpus_reader(args.corpus, args.reader, *reader_args, **reader_kwargs) + +if not hasattr(categorized_corpus, 'categories'): + raise ValueError('%s does not have categories for classification' % args.corpus) + labels = categorized_corpus.categories() nlabels = len(labels) @@ -291,9 +293,6 @@ def norm_words(words): scoring.cross_fold(train_feats, trainf, accuracy, folds=args.cross_fold, trace=args.trace, metrics=not args.no_eval, informative=args.show_most_informative) else: - if args.trace: - print 'training %s classifier' % args.classifier - 
classifier = trainf(train_feats) ################ diff --git a/train_tagger.py b/train_tagger.py index 3666a78..f7e74c7 100755 --- a/train_tagger.py +++ b/train_tagger.py @@ -109,16 +109,10 @@ ## corpus reader ## ################### -tagged_corpus = load_corpus_reader(args.corpus, reader=args.reader, fileids=args.fileids) - -if not tagged_corpus: - raise ValueError('%s is an unknown corpus') - if args.trace: - print 'loading nltk.corpus.%s' % args.corpus -# trigger loading so it has its true class -tagged_corpus.fileids() -# fileid is used for corpus naming, if it exists + print 'loading %s' % args.corpus + +tagged_corpus = load_corpus_reader(args.corpus, reader=args.reader, fileids=args.fileids) fileids = args.fileids kwargs = {} @@ -130,7 +124,7 @@ raise ValueError('%s does not support simplify_tags' % args.corpus) if isinstance(tagged_corpus, SwitchboardCorpusReader): - if args.fileids: + if fileids: raise ValueError('fileids cannot be used with switchboard') tagged_sents = list(itertools.chain(*[[list(s) for s in d if s] for d in tagged_corpus.tagged_discourses(**kwargs)]))