From 5fa51b9bd9ae5e152efdc4744356c3d62c5a9042 Mon Sep 17 00:00:00 2001
From: Lars Buitinck
Date: Tue, 20 Dec 2011 14:12:40 +0100
Subject: [PATCH] ENH stratified label deletion in semisup example + stricter tol in EM NB

---
 .../semisupervised_document_classification.py | 18 +++++++++++-------
 sklearn/naive_bayes.py                        |  2 +-
 2 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/examples/semisupervised_document_classification.py b/examples/semisupervised_document_classification.py
index c00a81f59a904..1f0b752e0ccca 100644
--- a/examples/semisupervised_document_classification.py
+++ b/examples/semisupervised_document_classification.py
@@ -26,6 +26,7 @@
 import sys
 from time import time
 
+from sklearn.cross_validation import StratifiedKFold
 from sklearn.datasets import fetch_20newsgroups
 from sklearn.feature_extraction.text import Vectorizer
 from sklearn.naive_bayes import BernoulliNB, SemisupervisedNB, MultinomialNB
@@ -44,7 +45,7 @@
               help="Print the confusion matrix.")
 op.add_option("--labeled",
               action="store", type="float", dest="labeled_fraction",
-              help="Fraction of labels to retain.")
+              help="Fraction of labels to retain (roughly).")
 op.add_option("--report",
               action="store_true", dest="print_report",
               help="Print a detailed classification report.")
@@ -63,11 +64,14 @@
 print
 
 
-def split_indices(n, fraction):
-    """Randomly split indices"""
-    k = int(fraction * n)
-    a = rng.permutation(np.arange(n))
-    return a[:k], a[k:]
+def split_indices(y, fraction):
+    """Random stratified split of indices into y
+
+    Returns (unlabeled, labeled)
+    """
+    k = int(round(1 / fraction))
+    folds = list(StratifiedKFold(y, k))
+    return folds[rng.randint(k)]
 
 
 def trim(s):
@@ -131,7 +135,7 @@ def trim(s):
 print "n_samples: %d, n_features: %d" % X_test.shape
 print
 
-labeled, unlabeled = split_indices(len(y_train), fraction)
+unlabeled, labeled = split_indices(y_train, fraction)
 print "Removing labels of %d random training documents" % len(unlabeled)
 print
 X_labeled = X_train[labeled]
diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py
index d2fef10c1cccd..240bd90a7d475 100644
--- a/sklearn/naive_bayes.py
+++ b/sklearn/naive_bayes.py
@@ -497,7 +497,7 @@ class SemisupervisedNB(BaseNB):
         Whether to print progress information.
     """
 
-    def __init__(self, estimator, n_iter=10, relabel_all=True, tol=1e-3,
+    def __init__(self, estimator, n_iter=10, relabel_all=True, tol=1e-5,
                  verbose=False):
         self.estimator = estimator
         self.n_iter = n_iter