Skip to content

Commit

Permalink
ENH stratified label deletion in semisup example + stricter tol in EM NB
Browse files Browse the repository at this point in the history
  • Loading branch information
larsmans committed Dec 20, 2011
1 parent e0d3808 commit 5fa51b9
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
18 changes: 11 additions & 7 deletions examples/semisupervised_document_classification.py
Expand Up @@ -26,6 +26,7 @@
import sys
from time import time

from sklearn.cross_validation import StratifiedKFold
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import Vectorizer
from sklearn.naive_bayes import BernoulliNB, SemisupervisedNB, MultinomialNB
Expand All @@ -44,7 +45,7 @@
help="Print the confusion matrix.")
op.add_option("--labeled",
action="store", type="float", dest="labeled_fraction",
help="Fraction of labels to retain.")
help="Fraction of labels to retain (roughly).")
op.add_option("--report",
action="store_true", dest="print_report",
help="Print a detailed classification report.")
Expand All @@ -63,11 +64,14 @@
print


def split_indices(n, fraction):
    """Shuffle the indices 0..n-1 and cut them into two parts.

    The first returned array holds ``int(fraction * n)`` of the shuffled
    indices; the second holds all the remaining ones. The shuffle is drawn
    from the module-level ``rng``.
    """
    cut = int(fraction * n)
    shuffled = rng.permutation(np.arange(n))
    head = shuffled[:cut]
    tail = shuffled[cut:]
    return head, tail
def split_indices(y, fraction):
    """Random stratified split of indices into y.

    The labels are cut into ``k = round(1 / fraction)`` stratified folds and
    one fold is picked at random using the module-level ``rng``, so the
    retained share of labels is only roughly ``fraction`` (about ``1 / k``).

    Parameters
    ----------
    y : array-like
        Class labels, handed to StratifiedKFold.
    fraction : float
        Approximate fraction of labels to retain; must be positive.

    Returns
    -------
    (unlabeled, labeled) : tuple of two integer index arrays

    Raises
    ------
    ValueError
        If ``fraction`` is not a positive number.
    """
    # Reject zero/negative/NaN up front instead of failing with a bare
    # ZeroDivisionError or a nonsensical negative fold count.
    if not fraction > 0:
        raise ValueError("fraction must be positive, got %r" % (fraction,))

    # 1.0 forces true division (an int fraction would truncate under
    # Python 2); the max() keeps k from rounding down to zero folds.
    k = max(1, int(round(1.0 / fraction)))
    folds = list(StratifiedKFold(y, k))
    chosen = folds[rng.randint(k)]

    # NOTE(review): some StratifiedKFold versions appear to yield boolean
    # masks rather than index arrays -- confirm against the installed
    # release. Normalise so callers really get indices and len() is a count.
    parts = []
    for part in chosen:
        part = np.asarray(part)
        if part.dtype == bool:
            part = np.flatnonzero(part)
        parts.append(part)
    unlabeled, labeled = parts
    return unlabeled, labeled


def trim(s):
Expand Down Expand Up @@ -131,7 +135,7 @@ def trim(s):
print "n_samples: %d, n_features: %d" % X_test.shape
print

labeled, unlabeled = split_indices(len(y_train), fraction)
unlabeled, labeled = split_indices(y_train, fraction)
print "Removing labels of %d random training documents" % len(unlabeled)
print
X_labeled = X_train[labeled]
Expand Down
2 changes: 1 addition & 1 deletion sklearn/naive_bayes.py
Expand Up @@ -497,7 +497,7 @@ class SemisupervisedNB(BaseNB):
Whether to print progress information.
"""

def __init__(self, estimator, n_iter=10, relabel_all=True, tol=1e-3,
def __init__(self, estimator, n_iter=10, relabel_all=True, tol=1e-5,
verbose=False):
self.estimator = estimator
self.n_iter = n_iter
Expand Down

0 comments on commit 5fa51b9

Please sign in to comment.