Update: use unlabeled sample to train classifier by EM algorithm.

In [1]:
# Import packages and libraries
import numpy as np
import random as rnd
import nltk as nk

from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit
from sklearn.naive_bayes import MultinomialNB
from sklearn import metrics
from pprint import pprint

from Semi_EM_NB import Semi_EM_MultinomialNB

In [2]:
# Load train and test data set with class labels 
train_Xy = fetch_20newsgroups(subset='train', remove=('headers', 'footers', 'quotes'))
test_Xy = fetch_20newsgroups(subset='test', remove=('headers', 'footers', 'quotes'))

In [3]:
# Convert all text data into tf-idf vectors 
vectorizer = TfidfVectorizer(stop_words='english', min_df=3, max_df=0.9)
# vectorizer = TfidfVectorizer()
train_vec = vectorizer.fit_transform(train_Xy.data)
test_vec = vectorizer.transform(test_Xy.data)
print train_vec.shape, test_vec.shape

(11314, 26747) (7532, 26747)


In [4]:
# Divide train data set into labeled and unlabeled data sets
n_train_data = train_vec.shape[0]
split_ratio = 0.5 # labeled vs unlabeled
X_l, X_u, y_l, y_u = train_test_split(train_vec, train_Xy.target, train_size=split_ratio, stratify=train_Xy.target)
print X_l.shape, X_u.shape

(5657, 26747) (5657, 26747)


In [5]:
# Train Naive Bayes classifier (imported) 
# using labeled data set only
nb_clf = MultinomialNB(alpha=1)
nb_clf.fit(X_l, y_l)

MultinomialNB(alpha=1, class_prior=None, fit_prior=True)

In [6]:
# Train Naive Bayes classifier (imported) 
# using both labeled and unlabeled data set
em_nb_clf = Semi_EM_MultinomialNB(alpha=1e-8, max_iter=30) # semi supervised EM based Naive Bayes classifier
# em_nb_clf.fit(X_l, y_l, X_u)
em_nb_clf.fit_x(X_l, y_l, X_u)
# em_nb_clf.fit_with_clustering(X_l, y_l, X_u)
# em_nb_clf.partial_fit(X_l, y_l, X_u)

Initial expected log likelihood = -3800070.806

EM iteration #1
	Expected log likelihood = -2935788.693
EM iteration #2
	Expected log likelihood = -2935604.028
EM iteration #3
	Expected log likelihood = -2935569.932
EM iteration #4
	Expected log likelihood = -2935569.932
EM iteration #5
	Expected log likelihood = -2935569.932
EM iteration #6
	Expected log likelihood = -2935569.932
EM iteration #7
	Expected log likelihood = -2935569.932
EM iteration #8
	Expected log likelihood = -2935569.932
EM iteration #9
	Expected log likelihood = -2935569.932
EM iteration #10
	Expected log likelihood = -2935569.932


<Semi_EM_NB.Semi_EM_MultinomialNB instance at 0x11026e998>

In [7]:
# Evaluate original NB classifier using test data set
pred = nb_clf.predict(test_vec)
print(metrics.classification_report(test_Xy.target, pred, target_names=test_Xy.target_names))
# pprint(metrics.confusion_matrix(test_Xy.target, pred))
print(metrics.accuracy_score(test_Xy.target, pred))

                          precision    recall  f1-score   support

             alt.atheism       0.76      0.21      0.33       319
           comp.graphics       0.61      0.68      0.64       389
 comp.os.ms-windows.misc       0.66      0.60      0.63       394
comp.sys.ibm.pc.hardware       0.57      0.74      0.64       392
   comp.sys.mac.hardware       0.73      0.65      0.69       385
          comp.windows.x       0.74      0.73      0.73       395
            misc.forsale       0.76      0.74      0.75       390
               rec.autos       0.77      0.70      0.73       396
         rec.motorcycles       0.84      0.67      0.75       398
      rec.sport.baseball       0.88      0.79      0.83       397
        rec.sport.hockey       0.57      0.91      0.70       399
               sci.crypt       0.64      0.77      0.70       396
         sci.electronics       0.70      0.47      0.56       393
                 sci.med       0.82      0.71      0.76       396
         

In [8]:
# Evaluate semi-supervised EM NB classifier using test data set
pred = em_nb_clf.predict(test_vec)
print(metrics.classification_report(test_Xy.target, pred, target_names=test_Xy.target_names))
# pprint(metrics.confusion_matrix(test_Xy.target, pred))
print(metrics.accuracy_score(test_Xy.target, pred))

                          precision    recall  f1-score   support

             alt.atheism       0.62      0.19      0.29       319
           comp.graphics       0.52      0.60      0.56       389
 comp.os.ms-windows.misc       0.60      0.29      0.40       394
comp.sys.ibm.pc.hardware       0.48      0.67      0.56       392
   comp.sys.mac.hardware       0.73      0.41      0.52       385
          comp.windows.x       0.63      0.71      0.67       395
            misc.forsale       0.76      0.51      0.61       390
               rec.autos       0.71      0.57      0.63       396
         rec.motorcycles       0.72      0.49      0.59       398
      rec.sport.baseball       0.95      0.69      0.80       397
        rec.sport.hockey       0.54      0.89      0.67       399
               sci.crypt       0.55      0.71      0.62       396
         sci.electronics       0.62      0.39      0.48       393
                 sci.med       0.74      0.69      0.72       396
         

In [9]:
# find the most informative features 
import numpy as np
def show_topK(classifier, vectorizer, categories, K=10):
    feature_names = np.asarray(vectorizer.get_feature_names())
    for i, category in enumerate(categories):
        topK = np.argsort(classifier.coef_[i])[-K:]
        print("%s: %s" % (category, " ".join(feature_names[topK])))

In [10]:
show_topK(nb_clf, vectorizer, train_Xy.target_names, K=10) # keywords for each class by original NB classifier

alt.atheism: does atheism islam atheists religion say think don people god
comp.graphics: code does know program looking file thanks image files graphics
comp.os.ms-windows.misc: using cica problem thanks driver use files dos file windows
comp.sys.ibm.pc.hardware: does pc disk monitor ide controller bus scsi card drive
comp.sys.mac.hardware: use lc quadra monitor thanks know does drive apple mac
comp.windows.x: program windows thanks application x11r5 xterm widget motif server window
misc.forsale: asking interested email sell new condition offer shipping 00 sale
rec.autos: good know new dealer engine ford like just cars car
rec.motorcycles: dog bmw riding motorcycle just like bikes ride dod bike
rec.sport.baseball: pitching braves players hit games runs game team baseball year
rec.sport.hockey: league play teams season games players nhl game team hockey
sci.crypt: escrow crypto use nsa government keys chip clipper encryption key
sci.electronics: phone current want don know circuit use 

In [11]:
show_topK(em_nb_clf, vectorizer, train_Xy.target_names, K=10) # keywords for each class by semisupervised EM NB classifier

alt.atheism: atheists morality religion just say atheism think don people god
comp.graphics: windows does format know looking file files image thanks graphics
comp.os.ms-windows.misc: ax program drivers use thanks driver files dos file windows
comp.sys.ibm.pc.hardware: monitor pc thanks disk ide controller bus scsi card drive
comp.sys.mac.hardware: use does simms know problem quadra thanks drive apple mac
comp.windows.x: using application program use widget thanks windows motif server window
misc.forsale: price asking sell new email condition offer 00 shipping sale
rec.autos: don good ford new dealer like engine just cars car
rec.motorcycles: riding don helmet motorcycle just like ride bikes dod bike
rec.sport.baseball: players braves hit pitching runs games game baseball team year
rec.sport.hockey: teams year nhl season players games play hockey team game
sci.crypt: people escrow use nsa keys government chip clipper encryption key
sci.electronics: good ground used does know voltage us

In [12]:
print nb_clf.class_log_prior_, em_nb_clf.clf.class_log_prior_

[-3.16001007 -2.96389519 -2.95028954 -2.95367364 -2.97422231 -2.94691686
 -2.96389519 -2.94691686 -2.94020542 -2.94355551 -2.93686652 -2.94355551
 -2.95367364 -2.94691686 -2.94691686 -2.93686652 -3.0311772  -2.99874192
 -3.19391162 -3.40420703] [-3.35238196 -2.98815982 -3.06659563 -2.88174722 -3.06659563 -2.87859761
 -3.05340034 -3.02024813 -3.02934737 -3.05715269 -2.67450226 -2.82205085
 -3.08382093 -2.93520124 -2.85528653 -2.73528715 -3.06469989 -2.82651846
 -3.29831474 -3.67780437]
