In [None]:
!pip install ktrain

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
import os
os.environ['CUDA_DEVICE_ORDER']="PCI_BUS_ID";
os.environ['CUDA_VISIBLE_DEVICES']="0";

In [None]:
import ktrain
from ktrain import text
from sklearn.datasets import fetch_20newsgroups

In [None]:
categories = ['alt.atheism', 
              'soc.religion.christian', 
              'comp.graphics',
              'sci.med',
              'rec.sport.baseball']

In [None]:
train = fetch_20newsgroups(
    subset = 'train',
    categories = categories,
    shuffle = True,
    random_state =0
)

In [None]:
test = fetch_20newsgroups(
    subset = 'test',
    categories = categories,
    shuffle = True,
    random_state=0
)

In [None]:
print(test)

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [None]:
test.keys()

dict_keys(['data', 'filenames', 'target_names', 'target', 'DESCR'])

In [None]:
test.target

array([0, 4, 2, ..., 2, 3, 0])

In [None]:
test.target_names

['alt.atheism',
 'comp.graphics',
 'rec.sport.baseball',
 'sci.med',
 'soc.religion.christian']

In [None]:
X_train = train.data
y_train = train.target

X_test = test.data
y_test = test.target

In [None]:
len(X_train), len(X_test)

(2854, 1899)

In [None]:
#X_test

# Build ML model with Transformer

In [None]:
model_name = "distilbert-base-uncased"

trans = text.Transformer(model_name, maxlen=512, class_names=categories)

In [None]:
train_data = trans.preprocess_train(X_train,y_train)
test_data = trans.preprocess_test(X_test, y_test)

preprocessing train...
language: en
train sequence lengths:
	mean : 291
	95percentile : 820
	99percentile : 1757


Is Multi-Label? False
preprocessing test...
language: en
test sequence lengths:
	mean : 323
	95percentile : 894
	99percentile : 2394


In [None]:
model = trans.get_classifier()

In [None]:
learner = ktrain.get_learner(model, train_data=train_data, 
                             val_data = test_data,
                             batch_size = 16)

# to find the best learning rate

In [None]:
learner.lr_find(show_plot=True, max_epochs=10)

In [None]:
# only fit one epoch

learner.fit_onecycle(1e-4,1)
#.fit_onecycle(1e-4, 1)
# 0.004
# 1 epoch



begin training using onecycle policy with max lr of 0.0001...


<keras.callbacks.History at 0x7f4187f48910>

In [None]:
# confusion matrix
learner.validate()

              precision    recall  f1-score   support

           0       0.85      0.93      0.89       319
           1       0.91      0.98      0.95       389
           2       1.00      0.96      0.98       397
           3       0.97      0.93      0.95       396
           4       0.96      0.90      0.93       398

    accuracy                           0.94      1899
   macro avg       0.94      0.94      0.94      1899
weighted avg       0.94      0.94      0.94      1899



array([[296,   8,   0,   4,  11],
       [  1, 382,   0,   5,   1],
       [  6,   9, 380,   1,   1],
       [  9,  18,   1, 367,   1],
       [ 36,   2,   0,   0, 360]])

In [None]:
learner.view_top_losses(n=5, preproc=trans)

----------
id:787 | loss:6.06 | true:comp.graphics | pred:soc.religion.christian)

----------
id:908 | loss:5.61 | true:comp.graphics | pred:soc.religion.christian)

----------
id:562 | loss:5.54 | true:rec.sport.baseball | pred:soc.religion.christian)

----------
id:238 | loss:5.54 | true:comp.graphics | pred:soc.religion.christian)

----------
id:170 | loss:5.52 | true:comp.graphics | pred:alt.atheism)



# predict on new data

In [None]:
predictor = ktrain.get_predictor(learner.model, preproc = trans)

In [None]:
x = 'Jesus Christ is the central figure of Christianity'

In [None]:
predictor.predict(x)

'rec.sport.baseball'

In [None]:
predictor.explain(x)



In [None]:
predictor.save('model')