<a href="https://colab.research.google.com/github/cindyyj/NLP_examples/blob/main/20newsgroups_BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# install ktrain on Google Colab
!pip3 install ktrain

Collecting ktrain
  Downloading ktrain-0.29.2.tar.gz (25.3 MB)
Collecting scikit-learn==0.24.2
  Downloading scikit_learn-0.24.2-cp37-cp37m-manylinux2010_x86_64.whl (22.3 MB)
Collecting langdetect
  Downloading langdetect-1.0.9.tar.gz (981 kB)
Collecting cchardet
  Downloading cchardet-2.1.7-cp37-cp37m-manylinux2010_x86_64.whl (263 kB)
Collecting syntok==1.3.3
  Downloading syntok-1.3.3-py3-none-any.whl (22 kB)
Collecting seqeval==0.0.19
  Downloading seqeval-0.0.19.tar.gz (30 kB)
Collecting transformers==4.10.3
  Downloading transformers-4.10.3-py3-none-any.whl (2.8 MB)
Collecting sentencepiece
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux201

In [2]:
# import ktrain and the ktrain.text modules
import ktrain
from ktrain import text

In [3]:
ktrain.__version__

'0.29.2'

# Multiclass Text Classification Using BERT and Keras
In this example, we will use ***ktrain*** ([a lightweight wrapper around Keras](https://github.com/amaiya/ktrain)) to build a model on the dataset used in the **scikit-learn** tutorial [Working with Text Data](https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html). As in the tutorial, we will sample 4 of the 20 newsgroups to create a relatively small multiclass text classification dataset. The objective is to classify each document into one of these four newsgroup topic categories. This gives us an opportunity to see **BERT** in action on a relatively small training set. Let's fetch the [20newsgroups dataset](http://qwone.com/~jason/20Newsgroups/) using scikit-learn.

In [4]:
# fetch the dataset using scikit-learn
categories = ['alt.atheism', 'soc.religion.christian',
             'comp.graphics', 'sci.med']
from sklearn.datasets import fetch_20newsgroups
train_b = fetch_20newsgroups(subset='train',
   categories=categories, shuffle=True, random_state=42)
test_b = fetch_20newsgroups(subset='test',
   categories=categories, shuffle=True, random_state=42)

print('size of training set: %s' % (len(train_b['data'])))
print('size of validation set: %s' % (len(test_b['data'])))
print('classes: %s' % (train_b.target_names))

x_train = train_b.data
y_train = train_b.target
x_test = test_b.data
y_test = test_b.target

size of training set: 2257
size of validation set: 1502
classes: ['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']
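Note that the printed class list is in alphabetical order, not in the order of the `categories` list, so the integer labels in `y_train` and `y_test` index into `train_b.target_names`. That is why the preprocessing step passes `class_names=train_b.target_names` rather than `categories`. A minimal sketch of decoding labels under that ordering, assuming (as the output above shows) that `target_names` is simply the sorted category list; the `decode` helper is hypothetical:

```python
# Categories in the order they were requested in the fetch cell.
categories = ['alt.atheism', 'soc.religion.christian',
              'comp.graphics', 'sci.med']

# The printed class list is the alphabetical order of the categories, so the
# integer labels index into this sorted list, not into `categories`.
target_names = sorted(categories)

def decode(label_ids, names=target_names):
    """Hypothetical helper: map integer labels back to newsgroup names."""
    return [names[i] for i in label_ids]

print(target_names)
print(decode([2, 0, 3]))
```

Indexing `categories` with the same integers would silently mislabel documents, since, for example, position 2 is `comp.graphics` there but `sci.med` in `target_names`.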


## STEP 1:  Load and Preprocess the Data
Preprocess the data using the `texts_from_array` function (since the data resides in arrays).
If your documents are stored in folders or in a CSV file, you can use the `texts_from_folder` or `texts_from_csv` function, respectively.
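To try `texts_from_folder` on this same data, the arrays would first have to be written to disk. A minimal sketch, assuming the one-subfolder-per-class layout (`<datadir>/train/<class_name>/*.txt` and `<datadir>/test/<class_name>/*.txt`) that `texts_from_folder` reads in ktrain's own examples. The helper `write_folder_dataset` and the path `/tmp/newsgroups4` are hypothetical, and the commented ktrain call should be checked against the ktrain documentation for your version:

```python
import os

def write_folder_dataset(datadir, splits, class_names):
    """Hypothetical helper: write documents as datadir/<split>/<class>/<i>.txt.

    `splits` maps a split name (e.g. 'train', 'test') to a (texts, labels)
    tuple, where each label is an integer index into `class_names`.
    """
    for split, (texts, labels) in splits.items():
        for i, (doc, label) in enumerate(zip(texts, labels)):
            class_dir = os.path.join(datadir, split, class_names[label])
            os.makedirs(class_dir, exist_ok=True)
            with open(os.path.join(class_dir, f'{i}.txt'), 'w', encoding='utf-8') as f:
                f.write(doc)

# Hypothetical usage with the arrays fetched earlier:
# write_folder_dataset('/tmp/newsgroups4',
#                      {'train': (train_b.data, train_b.target),
#                       'test': (test_b.data, test_b.target)},
#                      train_b.target_names)
# then, assuming texts_from_folder accepts these arguments:
# text.texts_from_folder('/tmp/newsgroups4', maxlen=350,
#                        preprocess_mode='bert',
#                        train_test_names=['train', 'test'],
#                        classes=train_b.target_names)
```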

In [5]:
(x_train,  y_train), (x_test, y_test), preproc = text.texts_from_array(x_train=x_train, y_train=y_train,
                                                                       x_test=x_test, y_test=y_test,
                                                                       class_names=train_b.target_names,
                                                                       preprocess_mode='bert',
                                                                       maxlen=350, 
                                                                       max_features=35000)

downloading pretrained BERT model (uncased_L-12_H-768_A-12.zip)...
[██████████████████████████████████████████████████]
extracting pretrained BERT model...
done.

cleanup downloaded zip...
done.

preprocessing train...
language: en


Is Multi-Label? False
preprocessing test...
language: en


task: text classification
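With `preprocess_mode='bert'`, every document becomes a fixed-length BERT input, and `maxlen=350` decides where long posts are cut off. The sketch below illustrates the general idea for a single-segment input, assuming the text has already been WordPiece-tokenized into ids. It is not ktrain's internal code. The special-token ids 101 (`[CLS]`), 102 (`[SEP]`) and 0 (`[PAD]`) are those of the uncased BERT-Base vocabulary, and the helper `encode_for_bert` is hypothetical:

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # special tokens in the uncased BERT-Base vocab

def encode_for_bert(token_ids, maxlen=350):
    """Hypothetical helper: build fixed-length (token_ids, segment_ids).

    token_ids: WordPiece ids for one document, without special tokens.
    Long documents are truncated so that [CLS] + tokens + [SEP] fits in
    maxlen; short ones are right-padded with [PAD].
    """
    body = list(token_ids[:maxlen - 2])    # leave room for [CLS] and [SEP]
    ids = [CLS_ID] + body + [SEP_ID]
    ids += [PAD_ID] * (maxlen - len(ids))  # pad to exactly maxlen
    segments = [0] * maxlen                # single-segment input: all zeros
    return ids, segments

# Hypothetical ids for a three-token document, padded to length 8.
ids, segments = encode_for_bert([7592, 2088, 1012], maxlen=8)
print(ids)
print(segments)
```

Under this scheme, anything past the first 348 WordPiece tokens is dropped, so only the beginning of a long post (headers included, since the raw posts keep their `From:`/`Subject:` lines) reaches the model.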


## STEP 2:  Load the BERT Model and Instantiate a Learner Object

In [6]:
# you can disregard any deprecation warnings that TensorFlow/Keras prints in this cell
model = text.text_classifier('bert', train_data=(x_train, y_train), preproc=preproc)
learner = ktrain.get_learner(model, train_data=(x_train, y_train), batch_size=6)

Is Multi-Label? False
maxlen is 350
done.


## STEP 3: Train the Model

We train using one of the three learning rates recommended in the BERT paper: *5e-5*, *3e-5*, or *2e-5*.
Alternatively, before training, you can use the ktrain Learning Rate Finder to pick a good learning rate by invoking `learner.lr_find()` and then `learner.lr_plot()`.
The `learner.fit_onecycle` method employs a [1cycle learning rate policy](https://arxiv.org/pdf/1803.09820.pdf).
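The 1cycle policy raises the learning rate from a low value up to the chosen maximum and then lowers it again over the course of training. Here is a rough sketch, assuming a purely linear triangular shape and a hypothetical starting value of `max_lr / 10`. ktrain's actual `fit_onecycle` schedule may differ in its details, for example by also cycling momentum. Assuming the final partial batch counts as a step, 2,257 training documents with `batch_size=6` give ceil(2257 / 6) = 377 steps per epoch, so 4 epochs take 1,508 steps:

```python
import math

def one_cycle_lrs(max_lr, epochs, n_train, batch_size, div_factor=10.0):
    """Simplified triangular 1cycle schedule: one learning rate per batch.

    Rises linearly from max_lr / div_factor to max_lr over the first half of
    training, then falls linearly back to max_lr / div_factor.
    """
    steps_per_epoch = math.ceil(n_train / batch_size)  # partial last batch counts
    total = epochs * steps_per_epoch
    if total < 3:
        raise ValueError("need at least 3 training steps")
    peak = total // 2          # step at which max_lr is reached
    low = max_lr / div_factor  # hypothetical starting/ending value
    lrs = []
    for step in range(total):
        if step <= peak:
            frac = step / peak                              # warm-up phase
        else:
            frac = (total - 1 - step) / (total - 1 - peak)  # cool-down phase
        lrs.append(low + frac * (max_lr - low))
    return lrs

# Settings used in this notebook: max lr 2e-5, 4 epochs, batch size 6.
lrs = one_cycle_lrs(2e-5, epochs=4, n_train=2257, batch_size=6)
print(len(lrs), lrs[0], max(lrs))
```

The warm-up half keeps early updates small while the freshly added classification layer is still random, and the cool-down half lets the weights settle at a low learning rate by the end of training.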



In [None]:
learner.fit_onecycle(2e-5, 4)



begin training using onecycle policy with max lr of 2e-05...
Epoch 1/4
Epoch 2/4

We can use the `learner.validate` method to evaluate our model on the validation set (here, the 20newsgroups test split).
As we can see, BERT achieves about **96%** accuracy, roughly five percentage points higher than the 91% accuracy achieved by the SVM in the [scikit-learn tutorial](https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html).

In [None]:
learner.validate(val_data=(x_test, y_test), class_names=train_b.target_names)

                        precision    recall  f1-score   support

           alt.atheism       0.94      0.91      0.92       319
         comp.graphics       0.96      0.96      0.96       389
               sci.med       0.98      0.96      0.97       396
soc.religion.christian       0.94      0.99      0.96       398

              accuracy                           0.96      1502
             macro avg       0.96      0.95      0.95      1502
          weighted avg       0.96      0.96      0.96      1502



array([[289,   6,   3,  21],
       [ 10, 374,   4,   1],
       [  5,   8, 379,   4],
       [  3,   1,   1, 393]])
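The per-class numbers in the report follow directly from the confusion matrix, whose rows are the true classes and whose columns are the predicted classes (both in `target_names` order). The short sketch below recomputes them in plain Python. `per_class_metrics` is a hypothetical helper, and the matrix is copied from the output above:

```python
# Confusion matrix from learner.validate(): rows = true class, columns = predicted.
classes = ['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']
cm = [[289,   6,   3,  21],
      [ 10, 374,   4,   1],
      [  5,   8, 379,   4],
      [  3,   1,   1, 393]]

def per_class_metrics(cm, i):
    """Hypothetical helper: precision, recall, F1 and support for class i."""
    tp = cm[i][i]
    support = sum(cm[i])                   # documents whose true class is i
    predicted = sum(row[i] for row in cm)  # documents predicted as class i
    precision = tp / predicted
    recall = tp / support
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1, support

n = sum(sum(row) for row in cm)
correct = sum(cm[i][i] for i in range(len(cm)))
accuracy = correct / n

for i, name in enumerate(classes):
    p, r, f, s = per_class_metrics(cm, i)
    print(f'{name:>24} {p:.2f} {r:.2f} {f:.2f} {s}')
print(f'accuracy: {correct}/{n} = {accuracy:.4f}')
```

Overall accuracy works out to 1435/1502 ≈ 0.955, which the report rounds to 0.96. The largest single source of error is the 21 `alt.atheism` posts predicted as `soc.religion.christian`, which accounts for most of that class's lower recall.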

## How to Use Our Trained BERT Model

We can call the `ktrain.get_predictor` function to obtain a Predictor object capable of making predictions on new raw text.

In [None]:
predictor = ktrain.get_predictor(learner.model, preproc)

In [None]:
predictor.get_classes()

['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']

In [None]:
predictor.predict(test_b.data[0:1])

['sci.med']
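`predictor.predict` returns class names. To get per-class scores as well, ktrain's text predictor accepts a `return_proba=True` argument (worth confirming in the ktrain documentation for your version). Its probabilities are assumed here to follow the order of `predictor.get_classes()`. Turning such probabilities back into labels is just an argmax. Here is a minimal sketch with made-up probabilities; `probs_to_labels` and `example_probs` are hypothetical:

```python
classes = ['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']

def probs_to_labels(prob_rows, class_names):
    """Hypothetical helper: pick the highest-probability class for each row."""
    labels = []
    for probs in prob_rows:
        best = max(range(len(probs)), key=lambda j: probs[j])  # argmax index
        labels.append(class_names[best])
    return labels

# Made-up probabilities for two documents (not real model output).
example_probs = [[0.01, 0.02, 0.95, 0.02],
                 [0.30, 0.05, 0.05, 0.60]]
print(probs_to_labels(example_probs, classes))
```

Keeping the probabilities around is useful when you want to inspect borderline cases such as the second row, where the top class wins by a much smaller margin.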

In [None]:
# we can visually verify that our prediction of 'sci.med' for this document is correct
print(test_b.data[0])

From: brian@ucsd.edu (Brian Kantor)
Subject: Re: HELP for Kidney Stones ..............
Organization: The Avant-Garde of the Now, Ltd.
Lines: 12
NNTP-Posting-Host: ucsd.edu

As I recall from my bout with kidney stones, there isn't any
medication that can do anything about them except relieve the pain.

Either they pass, or they have to be broken up with sound, or they have
to be extracted surgically.

When I was in, the X-ray tech happened to mention that she'd had kidney
stones and children, and the childbirth hurt less.

Demerol worked, although I nearly got arrested on my way home when I barfed
all over the police car parked just outside the ER.
	- Brian



In [None]:
# the true label matches our prediction
print(test_b.target_names[test_b.target[0]])

sci.med


The `predictor.save` method and the `ktrain.load_predictor` function can be used to save the Predictor object to disk and reload it later to make predictions on new data.

In [None]:
# let's save the predictor for later use
predictor.save('/tmp/my_predictor')

In [None]:
# reload the predictor
reloaded_predictor = ktrain.load_predictor('/tmp/my_predictor')

In [None]:
# make a prediction on the same document to verify it still works
reloaded_predictor.predict(test_b.data[0:1])

['sci.med']