In [1]:
import pandas as pd
import numpy as np
from Reuters import *

In [2]:
# the downloaded dataset
!ls -la reuters/*.sgm

-rw-r--r-- 1 eriza eriza 1324350 Dec  4  1996 reuters/reut2-000.sgm
-rw-r--r-- 1 eriza eriza 1254440 Dec  4  1996 reuters/reut2-001.sgm
-rw-r--r-- 1 eriza eriza 1217495 Dec  4  1996 reuters/reut2-002.sgm
-rw-r--r-- 1 eriza eriza 1298721 Dec  4  1996 reuters/reut2-003.sgm
-rw-r--r-- 1 eriza eriza 1321623 Dec  4  1996 reuters/reut2-004.sgm
-rw-r--r-- 1 eriza eriza 1388644 Dec  4  1996 reuters/reut2-005.sgm
-rw-r--r-- 1 eriza eriza 1254765 Dec  4  1996 reuters/reut2-006.sgm
-rw-r--r-- 1 eriza eriza 1256772 Dec  4  1996 reuters/reut2-007.sgm
-rw-r--r-- 1 eriza eriza 1410117 Dec  4  1996 reuters/reut2-008.sgm
-rw-r--r-- 1 eriza eriza 1338903 Dec  4  1996 reuters/reut2-009.sgm
-rw-r--r-- 1 eriza eriza 1371071 Dec  4  1996 reuters/reut2-010.sgm
-rw-r--r-- 1 eriza eriza 1304117 Dec  4  1996 reuters/reut2-011.sgm
-rw-r--r-- 1 eriza eriza 1323584 Dec  4  1996 reuters/reut2-012.sgm
-rw-r--r-- 1 eriza eriza 1129687 Dec  4  1996 reuters/reut2-013.sgm
-rw-r--r-- 1 eriza eriza 1128671 Dec  4  1996 re

In [3]:
!grep \<TOPICS\>\<D\> reuters/*.sgm | wc -l

11367


In [4]:
# read and parse the data
# this will download the data if it's not yet available locally
data_streamer = ReutersStreamReader('reuters').iterdocs()
data = get_minibatch(data_streamer, 50000)
data

In [5]:
# 'text' is combined document title and body
data.head()

Unnamed: 0,text,tags
0,SANDOZ PLANS WEEDKILLER JOINT VENTURE IN USSR\...,"[usa, ussr]"
1,TAIWAN REJECTS TEXTILE MAKERS EXCHANGE RATE PL...,"[usa, taiwan]"
2,NATIONAL FSI INC <NFSI> 4TH QTR LOSS\n\nShr lo...,"[earn, usa]"
3,OCCIDENTAL <OXY> OFFICIAL RESIGNS\n\nMidCon Co...,[usa]
4,ITALY'S BNL TO ISSUE 120 MLN DLR CONVERTIBLE B...,[italy]


In [6]:
from sklearn.preprocessing import LabelBinarizer

# binary encode the tags
lb = LabelBinarizer()
Y = lb.fit_transform(data.tags)
Y

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ..., 
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])

In [7]:
from sklearn.feature_extraction.text import TfidfVectorizer

# get the TF-IDF of the text
vec = TfidfVectorizer(min_df=2, sublinear_tf=True, decode_error='ignore')
X = vec.fit_transform(data.text)
X

<19716x25497 sparse matrix of type '<type 'numpy.float64'>'
	with 1509007 stored elements in Compressed Sparse Row format>

In [8]:
# split into train and test set
N = int(.8 * X.shape[0])
Xtr, ytr = X[:N,:], Y[:N,:]
Xte, yte = X[N:,:], Y[N:,:]

In [19]:
# there are warnings of ill-defined recall/precision etc.
# just ignore them for now
import warnings
warnings.filterwarnings("ignore")

In [27]:
from sklearn.multiclass import OneVsRestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.grid_search import GridSearchCV

# logistic regression parameter to optimise
#params = {"estimator__C": np.logspace(1, 1.5, num=5)}
# use OneVsRestClassifier for multiclass learning
model = OneVsRestClassifier(LogisticRegression())
# do a grid search on the params, with 5-fold cross-validation
# optimise for F1-score
clf = GridSearchCV(model, param_grid=params, scoring='f1', n_jobs=-1, cv=5)
clf.fit(Xtr, ytr)
clf.best_score_, clf.best_params_

(0.83655035338102735, {'estimator__C': 31.622776601683793})

In [28]:
from sklearn.metrics import f1_score

# compute predictions on test set
pred = clf.predict(Xte)
# compute F1-score on test set
f1_score(yte, pred)

0.77208685519257036

In [29]:
# a quick look into some example predictions
# compare with tags in test data
tags = []
for n in xrange(20):
    tags.append((lb.classes_[yte[n]==1], lb.classes_[pred[n]==1]))
    
pd.DataFrame(tags, columns=['actual tags', 'predicted tags'])

Unnamed: 0,actual tags,predicted tags
0,"[earn, usa]","[earn, usa]"
1,"[gnp, west-germany]",[west-germany]
2,"[earn, usa]","[earn, usa]"
3,"[acq, usa]","[acq, usa]"
4,"[earn, uk, usa]",[]
5,[usa],[usa]
6,"[brazil, gnp, imf, trade]","[brazil, imf]"
7,[usa],[usa]
8,"[cpu, usa]",[]
9,"[earn, usa]","[earn, usa]"
