In [0]:
# install ktrain on Google Colab
!pip3 install ktrain



In [0]:
# import ktrain and the ktrain.text modules
import ktrain
from ktrain import text

In [0]:
ktrain.__version__

'0.5.2'

# Multiclass Text Classification Using BERT and Keras
In this example, we will use ***ktrain*** ([a lightweight wrapper around Keras](https://github.com/amaiya/ktrain)) to build a **BERT** text classifier. The workflow follows the **scikit-learn** tutorial [Working with Text Data](https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html), but instead of the 20 Newsgroups dataset used there, this notebook works with Portuguese product reviews labeled as positive (`pos`) or negative (`neg`).  The objective is to accurately classify each review into one of these two sentiment categories.  This gives us an opportunity to see **BERT** in action on a relatively small training set.  Let's load the reviews from the local `.parsed` files.

In [0]:
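# load the pre-parsed reviews from local files (one review per line)
# NOTE: ktrain appears to map class_names by label index, so as written label 0 -> 'pos'
# and label 1 -> 'neg', while the positive reviews below carry label 1; use ['neg', 'pos']
# if the names should match the labels (the outputs in this notebook use the order below)
categories = ['pos', 'neg']

# test set
pos = open('positive.parsed').read().split('\n')[:-1]
neg = open('negative.parsed').read().split('\n')[:-1]
print(pos)
x_test = pos+neg
y_test = [1]*len(pos) + [0]*len(neg)

# training set
pos = open('g_positive.parsed').read().split('\n')[:-1]
neg = open('g_negative.parsed').read().split('\n')[:-1]
print(pos)
x_train = pos+neg
y_train = [1]*len(pos) + [0]*len(neg)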

['muuuito bommm prós o produto é de boa qualide é muito barato é ideal para quem tem um psp ou camera digital e é original contras podem comprar por que vocês não vão se arrepender conclusão os concorrentes morram de inveja', 'ótimo prós ótimo produto recomendo a todos contras não tenho conclusão adorei o produto estou usando em meu celular ficou ótimo cabe muitas músicas videos e ainda tem muito espaço adorei', 'nokia e NUMBER prós tudo nele é bom bateria dura muito teclas de facil acesso contras sistema operacional é fraco a navegaçao na net é um pouco lenta e algumas pag nao carregam conclusão é um smartphone barato pelo que ele oferece', 'não posso opinar sobre a x- NUMBER pois o vendedor não cumpriu a entrega até NUMBER h do desta data prós não posso opinar pois até esta data não a recebi contras não posso opinar pois até esta data não a recebi conclusão infelizmente n posso opinar sobre o produto pois o vendedor zeca parati@bol com br de nome paulo sérgio dos santos não cumpriu a

## STEP 1:  Load and Preprocess the Data
Preprocess the data using the `texts_from_array` function (since the data resides in an array).
If your documents are stored in folders or a CSV file you can use the `texts_from_folder` or `texts_from_csv` functions, respectively.
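
Before calling `texts_from_array`, the documents and labels must be available as two parallel lists. The following is a minimal sketch of how such lists could be built from one-document-per-line files. The helper names `load_lines` and `build_arrays` are hypothetical and not part of ktrain, and the sketch assumes UTF-8 files; unlike the loading cell in this notebook, it skips every blank line rather than only the trailing one.

```python
from pathlib import Path


def load_lines(path, encoding="utf-8"):
    """Return the non-blank lines of a text file, one document per line."""
    text = Path(path).read_text(encoding=encoding)
    return [line for line in text.splitlines() if line.strip()]


def build_arrays(pos_docs, neg_docs, pos_label=1, neg_label=0):
    """Concatenate two document lists and build the matching label list."""
    x = list(pos_docs) + list(neg_docs)
    # Labels are derived from the actual list lengths, not hard-coded counts.
    y = [pos_label] * len(pos_docs) + [neg_label] * len(neg_docs)
    return x, y
```

With these helpers, `x_train, y_train = build_arrays(load_lines('g_positive.parsed'), load_lines('g_negative.parsed'))` would yield lists of equal length even if the two files do not contain the same number of reviews.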

In [0]:
(x_train,  y_train), (x_test, y_test), preproc = text.texts_from_array(x_train=x_train, y_train=y_train,
                                                                       x_test=x_test, y_test=y_test,
                                                                       class_names=categories,
                                                                       preprocess_mode='bert',
                                                                       maxlen=350, 
                                                                       max_features=35000)

preprocessing train...
language: pt


preprocessing test...
language: pt


## STEP 2:  Load the BERT Model and Instantiate a Learner object

In [0]:
# you can disregard the deprecation warnings arising from using Keras 2.2.4 with TensorFlow 1.14.
model = text.text_classifier('bert', train_data=(x_train, y_train), preproc=preproc)
learner = ktrain.get_learner(model, train_data=(x_train, y_train), batch_size=6)

Is Multi-Label? False
maxlen is 350
done.


## STEP 3: Train the Model

We train using one of the three learning rates recommended in the BERT paper: *5e-5*, *3e-5*, or *2e-5*.
Alternatively, the ktrain Learning Rate Finder can be used to find a good learning rate by invoking `learner.lr_find()` and `learner.lr_plot()`, prior to training.
The `learner.fit_onecycle` method employs a [1cycle learning rate policy](https://arxiv.org/pdf/1803.09820.pdf).
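
To give a rough intuition for the 1cycle policy, the sketch below computes a simplified triangular schedule: the learning rate rises linearly from a low value to the maximum and then falls back. This is an illustration only and not ktrain's implementation; the function name `one_cycle_lr` and the `div_factor` parameter (the ratio between the maximum and the starting learning rate) are assumptions made for this example.

```python
def one_cycle_lr(step, total_steps, max_lr, div_factor=10.0):
    """Simplified triangular 1cycle schedule (illustrative, not ktrain's code)."""
    if total_steps < 2:
        return max_lr
    base_lr = max_lr / div_factor      # assumed starting and ending learning rate
    half = (total_steps - 1) / 2.0     # step index at which the peak is reached
    # 0.0 at the first and last step, 1.0 at the midpoint
    frac = 1.0 - abs(step - half) / half
    return base_lr + (max_lr - base_lr) * frac


# Nine steps with the maximum learning rate used in this notebook
schedule = [one_cycle_lr(s, 9, 2e-5) for s in range(9)]
for step, lr in enumerate(schedule):
    print(f"step {step}: lr = {lr:.2e}")
```

In this sketch the learning rate peaks at the middle step and is symmetric around it; real implementations often add a final annealing phase and may vary momentum as well.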



In [0]:
learner.fit_onecycle(2e-5, 4)



begin training using onecycle policy with max lr of 2e-05...
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


<keras.callbacks.History at 0x7fc667ce2a58>

We can use the `learner.validate` method to evaluate our model on the held-out test set.
As the report below shows, BERT reaches **88%** accuracy on the 2,000 test reviews after four epochs.

In [0]:
learner.validate(val_data=(x_test, y_test), class_names=categories)

              precision    recall  f1-score   support

         pos       0.88      0.89      0.89      1000
         neg       0.89      0.88      0.88      1000

    accuracy                           0.88      2000
   macro avg       0.88      0.88      0.88      2000
weighted avg       0.88      0.88      0.88      2000



array([[894, 106],
       [125, 875]])
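
The figures in the report can be re-derived from the confusion matrix. The sketch below does this in plain Python, assuming the matrix follows the scikit-learn convention (rows are true classes, columns are predicted classes); the helper name `per_class_metrics` is hypothetical.

```python
# Confusion matrix copied from the output above
confusion = [[894, 106],
             [125, 875]]


def per_class_metrics(cm):
    """Return overall accuracy and (precision, recall, f1) for each class."""
    n = len(cm)
    total = sum(sum(row) for row in cm)
    accuracy = sum(cm[i][i] for i in range(n)) / total
    rows = []
    for i in range(n):
        tp = cm[i][i]
        predicted = sum(cm[r][i] for r in range(n))  # column sum: predicted as class i
        actual = sum(cm[i])                          # row sum: truly class i
        precision = tp / predicted if predicted else 0.0
        recall = tp / actual if actual else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        rows.append((precision, recall, f1))
    return accuracy, rows


accuracy, rows = per_class_metrics(confusion)
print(f"accuracy: {accuracy:.4f}")
for i, (p, r, f) in enumerate(rows):
    print(f"class {i}: precision={p:.2f} recall={r:.2f} f1={f:.2f}")
```

The diagonal holds the correctly classified reviews (894 + 875 out of 2,000), which is where the overall accuracy in the report comes from.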

## How to Use Our Trained BERT Model

We can call the `ktrain.get_predictor` function with the trained model and the `preproc` object to obtain a Predictor object capable of making predictions on new raw data.

In [0]:
import ktrain
predictor = ktrain.get_predictor(learner.model, preproc)

In [0]:
predictor.get_classes()

['pos', 'neg']

In [0]:
predictor.predict('')

'pos'

In [0]:
predictor.save('/tmp/g_e_predictor')



The `predictor.save` and `ktrain.load_predictor` methods can be used to save the Predictor object to disk and reload it at a later time to make predictions on new data.
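
As a sketch of the reload step, the helper below wraps `ktrain.load_predictor` and `predict`. The function name `reload_and_predict` is hypothetical, and the example review in the comment is made up for illustration; the path is the one used in the `predictor.save` call above.

```python
def reload_and_predict(path, docs):
    """Reload a saved ktrain Predictor and return its labels for docs."""
    # Imported lazily so the helper can be defined without ktrain installed.
    import ktrain

    predictor = ktrain.load_predictor(path)
    return predictor.predict(docs)


# Example call, to be run in an environment where the predictor was saved:
# labels = reload_and_predict('/tmp/g_e_predictor', ['ótimo produto, recomendo'])
```

Reloading in a fresh session avoids retraining, since the saved Predictor bundles the model together with the preprocessing steps.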