# Train and save text classification model

### Import modules

In [8]:
import logging

from src.data import make_dataset

# Raise the root logger to INFO so that INFO-level records (such as the
# "INFO:root:..." progress lines shown under the data-loading cell) are emitted.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

### Load and process raw data files

In [2]:
# Build the train/validation splits and the class list from the raw files.
# Per the log output below, this step reports the sample counts, shuffles the
# data and builds the categories list.
RAW_DATA_DIR = "../data/raw/"
train_data, valid_data, classes = make_dataset.load_data(RAW_DATA_DIR)

INFO:root:Number of training samples:    3543746
INFO:root:Number of test     samples:     151860
INFO:root:Completed data shuffling.
INFO:root:Completed categories list.


### Load preprocessed data files

In [11]:
# Load the already-processed splits. This rebinds the same three names that
# the raw-data cell assigns, so whichever of the two cells ran last wins.
PROCESSED_DATA_DIR = "../data/processed/"
train_data, valid_data, classes = make_dataset.load_processed(PROCESSED_DATA_DIR)

Print a data sample with its labels.

In [10]:
# train_data[0] is indexed for the text and train_data[1] for its label dict;
# show both for one arbitrary training example.
sample_index = 15431
(train_data[0][sample_index], train_data[1][sample_index])

('Hypocalcimia after thydroidectomy - Want new management?',
 {'Society & Culture': False,
  'Science & Mathematics': False,
  'Health': True,
  'Education & Reference': False,
  'Computers & Internet': False,
  'Sports': False,
  'Business & Finance': False,
  'Entertainment & Music': False,
  'Family & Relationships': False,
  'Politics & Government': False})

### Model training

In [11]:
from src.models import train_model

# Fit the text classifier on the loaded splits; the per-epoch
# loss / precision / recall / F-score table is printed by train().
train_model.train(
    train_data,
    valid_data,
    classes,
)

Loading Yahoo! Q&A data...
Loaded model 'en_core_web_sm'
Using 2000 examples (2000 training, 5000 evaluation)
Training the model...
LOSS 	  P  	  R  	  F  
20.005	0.545	0.020	0.039
5.555	0.538	0.044	0.082
1.986	0.459	0.108	0.175
0.736	0.420	0.172	0.245
0.306	0.408	0.220	0.286
0.143	0.391	0.247	0.303
0.071	0.381	0.265	0.313
0.045	0.370	0.273	0.314
0.021	0.364	0.281	0.317
0.011	0.359	0.285	0.318
0.007	0.355	0.289	0.319
0.005	0.353	0.295	0.321
0.003	0.351	0.296	0.321
0.005	0.349	0.299	0.322
0.002	0.345	0.297	0.319
0.002	0.345	0.301	0.321
0.002	0.347	0.305	0.324
0.002	0.349	0.308	0.327
0.001	0.348	0.309	0.327
0.001	0.343	0.305	0.323
This movie sucked {'Society & Culture': 3.4277999105825074e-08, 'Science & Mathematics': 1.4359703476209823e-13, 'Health': 7.139125401955937e-11, 'Education & Reference': 4.144410979324716e-16, 'Computers & Internet': 1.8301327120440192e-10, 'Sports': 5.0861096584364773e-11, 'Business & Finance': 4.719486696558306e-06, 'Entertainment & Music': 0.999992966651916

### Test model loading and prediction

In [5]:
import pandas as pd
from src.models.predict_model import load, predict

# Load the saved pipeline from disk.
nlp = load(model_dir="../models/")

# Call `predict` through the name imported above. The previous
# `predict_model.predict(...)` spelling referenced a module name that is never
# bound in this notebook and raised NameError on a fresh kernel.
query, cats = predict("Where to go to school.", nlp=nlp)
print("\n>" + query)

# Show the per-category scores as percentages, highest first.
pd.Series(cats, name="Probability").sort_values(ascending=False).to_frame() * 100

Loading from ../models/

>Where to go to school.


Unnamed: 0,Probability
Education & Reference,83.057827
Sports,16.575634
Business & Finance,0.297126
Entertainment & Music,0.062902
Computers & Internet,0.002717
Health,0.001683
Society & Culture,0.00114
Science & Mathematics,0.000558
Politics & Government,0.000292
Family & Relationships,0.000131


### Evaluate model performance

In [21]:
from src.models.train_model import evaluate

# valid_data unpacks into the texts and their category dicts.
dev_texts, dev_cats = valid_data

# Score only the first n validation examples to keep the cell fast.
n = 100
textcat = nlp.get_pipe("textcat")
scores = evaluate(nlp.tokenizer, textcat, dev_texts[:n], dev_cats[:n])
scores

{'textcat_p': 0.36585365849196905,
 'textcat_r': 0.29999999997,
 'textcat_f': 0.32967032963410214}

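Build the confusion matrix (rows: gold class, columns: predicted class), then compute per-class precision and recall and report their unweighted (macro) averages for the multiclass setting.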

In [70]:
import numpy as np

# Give every class label a fixed row/column index. Looking labels up by name,
# rather than by their position inside each dict, keeps gold and predicted
# indices aligned even if the two dicts list the categories in different orders.
# NOTE(review): assumes `classes` iterates over the same label strings that
# are used as keys in the gold dicts and in doc.cats -- confirm in make_dataset.
class_index = {label: idx for idx, label in enumerate(classes)}

# Slice the evaluation subset once instead of re-slicing on every iteration.
eval_texts = dev_texts[:n]
eval_cats = dev_cats[:n]
docs = (nlp.tokenizer(text) for text in eval_texts)

# Rows are gold classes, columns are predicted classes.
confusion_matrix = np.zeros([len(classes), len(classes)])
for gold, doc in zip(eval_cats, nlp.get_pipe("textcat").pipe(docs)):
    # Pick the label with the highest value (first one wins on ties).
    g = class_index[max(gold, key=gold.get)]
    a = class_index[max(doc.cats, key=doc.cats.get)]
    confusion_matrix[g, a] += 1
print(confusion_matrix)

true_positives = confusion_matrix.diagonal()
predicted_totals = confusion_matrix.sum(axis=0)  # column sums: times each class was predicted
gold_totals = confusion_matrix.sum(axis=1)       # row sums: gold examples per class

# Guarded division: a class that was never predicted (or has no gold examples)
# is scored 0 instead of producing NaN and poisoning the mean.
precisions = np.divide(
    true_positives,
    predicted_totals,
    out=np.zeros_like(true_positives),
    where=predicted_totals > 0,
)
recalls = np.divide(
    true_positives,
    gold_totals,
    out=np.zeros_like(true_positives),
    where=gold_totals > 0,
)

# Macro-averaged precision and recall.
precisions.mean(), recalls.mean()

[[4. 1. 0. 0. 0. 1. 1. 0. 0. 0.]
 [4. 2. 3. 1. 0. 0. 1. 0. 2. 0.]
 [2. 2. 2. 2. 2. 0. 0. 1. 3. 0.]
 [2. 1. 1. 2. 0. 2. 0. 0. 0. 1.]
 [0. 2. 0. 0. 6. 0. 1. 0. 0. 0.]
 [0. 0. 1. 0. 0. 4. 1. 0. 2. 3.]
 [1. 0. 0. 1. 0. 0. 3. 1. 0. 2.]
 [3. 0. 1. 0. 0. 2. 0. 0. 2. 0.]
 [1. 0. 1. 0. 1. 1. 1. 0. 6. 0.]
 [1. 1. 0. 1. 1. 0. 0. 0. 1. 5.]]


(0.3156926406926407, 0.3541111666111666)