In [2]:
from sklearn.datasets import fetch_20newsgroups

# Data Selection (no preprocessing)

In [9]:
# The four newsgroups used for this 4-class task. Note that fetch_20newsgroups
# reports target_names in alphabetical order (see the "classes" output below),
# not in the order listed here.
categories = [
    'alt.atheism',
    'soc.religion.christian',
    'comp.graphics',
    'sci.med',
]

In [18]:
# Both splits use the same category filter and the same deterministic shuffle;
# only the `subset` differs.
_split_options = dict(categories=categories, shuffle=True, random_state=42)
train_b = fetch_20newsgroups(subset='train', **_split_options)
test_b = fetch_20newsgroups(subset='test', **_split_options)

In [22]:
# Quick sanity check on split sizes and the label set.
for label, value in (
    ('size of training set', len(train_b.data)),
    ('size of testing set', len(test_b.data)),
    ('classes', train_b.target_names),
):
    print(f'{label}: {value}')

size of training set: 2257
size of testing set: 1502
classes: ['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']


In [16]:
# Peek at the raw text of the first training post (mail headers included).
train_b.data[0]

'From: sd345@city.ac.uk (Michael Collier)\nSubject: Converting images to HP LaserJet III?\nNntp-Posting-Host: hampton\nOrganization: The City University\nLines: 14\n\nDoes anyone know of a good way (standard PC application/PD utility) to\nconvert tif/img/tga files into LaserJet III format.  We would also like to\ndo the same, converting to HPGL (HP plotter) files.\n\nPlease email any response.\n\nIs this the correct group?\n\nThanks in advance.  Michael.\n-- \nMichael Collier (Programmer)                 The Computer Unit,\nEmail: M.P.Collier@uk.ac.city                The City University,\nTel: 071 477-8000 x3769                      London,\nFax: 071 477-8565                            EC1V 0HB.\n'

In [17]:
# Integer label of that same post, an index into train_b.target_names.
train_b.target[0]

1

In [23]:
# Unpack raw texts (x) and integer labels (y) for each split.
x_train, y_train = train_b.data, train_b.target
x_test, y_test = test_b.data, test_b.target

# Model

In [26]:
# Pretrained checkpoint name handed to text.Transformer in STEP 1.
MODEL_NAME = "distilbert-base-uncased"

### STEP 1: Create a transformer instance

In [24]:
import ktrain
from ktrain import text

In [None]:
# `ktrain` and `ktrain.text` are already imported in the previous cell, so no
# further import is needed here. An incomplete `from ktrain import` line here
# would raise a SyntaxError and break Restart & Run All.

In [27]:
# Wrap the pretrained checkpoint. `maxlen` is the sequence length used for every
# document; the preprocessing output below shows many training posts are longer
# (95th percentile 837). `classes` pins the label order to the dataset's
# target_names.
t = text.Transformer(
    MODEL_NAME,
    maxlen=500,
    classes=train_b.target_names,
)

Downloading: 100%|██████████| 442/442 [00:00<00:00, 441kB/s]


### STEP 2: Preprocess the Datasets

In [28]:
# Tokenize both splits with the same Transformer preprocessor (train first, then test).
trn, val = (
    t.preprocess_train(x_train, y_train),
    t.preprocess_test(x_test, y_test),
)

preprocessing train...
language: en
train sequence lengths:
	mean : 308
	95percentile : 837
	99percentile : 1938
Downloading: 100%|██████████| 232k/232k [00:00<00:00, 629kB/s] 
Downloading: 100%|██████████| 466k/466k [00:00<00:00, 1.15MB/s]


Is Multi-Label? False
preprocessing test...
language: en
test sequence lengths:
	mean : 343
	95percentile : 979
	99percentile : 2562


### STEP 3: Create a Model and Wrap It in a Learner

In [29]:
# Build the classifier model, then wrap it together with the preprocessed
# train/validation sets in a ktrain Learner.
model = t.get_classifier()
learner = ktrain.get_learner(
    model,
    train_data=trn,
    val_data=val,
    batch_size=6,
)

Downloading: 100%|██████████| 363M/363M [00:18<00:00, 20.0MB/s]


### STEP 4 (optional): Estimate the Learning Rate

In [30]:
# Optional and slow: the recorded output below reached only 15/376 batches of
# the first epoch with an ETA of ~42 min. The plot helps choose `lr` for STEP 5.
learner.lr_find(max_epochs=2, show_plot=True)

simulating training for different learning rates... this may take a few moments...
Epoch 1/2
 15/376 [>.............................] - ETA: 42:08 - loss: 1.3949 - accuracy: 0.1341

### STEP 5: Train the Model

In [31]:
# One-cycle schedule with a maximum learning rate of 5e-5, trained for 4 epochs;
# the cell displays the returned Keras History object.
learner.fit_onecycle(epochs=4, lr=5e-5)



begin training using onecycle policy with max lr of 5e-05...
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


<tensorflow.python.keras.callbacks.History at 0x1ea98fcb910>

### STEP 6: Inspect the model

In [32]:
# Show the single validation example with the highest loss.
learner.view_top_losses(preproc=t, n=1)

----------
id:1330 | loss:7.13 | true:comp.graphics | pred:sci.med)



In [43]:
# Print the raw text of the highest-loss validation example. The id is looked up
# instead of hard-coding the value printed above (1330), because that id depends
# on the trained weights and can change on a fresh Restart & Run All.
# top_losses returns a list of tuples whose first element is the example id
# (the same id view_top_losses prints).
top_loss_id = learner.top_losses(n=1, preproc=t)[0][0]
print(x_test[top_loss_id])

From: madler@cco.caltech.edu (Mark Adler)
Subject: gamma correction
Organization: California Institute of Technology, Pasadena
Lines: 5
NNTP-Posting-Host: sandman.caltech.edu


Can someone who knows what they're talking about add a FAQ entry
on gamma correction?  Thanks.

mark



### STEP 7: Make Predictions on New Data

In [34]:
# Pair the trained model with its preprocessor so raw strings can be passed
# straight to predict()/explain().
trained_model = learner.model
predictor = ktrain.get_predictor(trained_model, preproc=t)

In [35]:
# Classify a single unseen sentence; the result is a class name.
sample_text = "Jesus Christ is the central figure of Christianity"
predictor.predict(sample_text)

'soc.religion.christian'

### Explaining Predictions

In [44]:
!pip3 install git+https://github.com/amaiya/eli5@tfkeras_0_10_1

Collecting git+https://github.com/amaiya/eli5@tfkeras_0_10_1
  Cloning https://github.com/amaiya/eli5 (to revision tfkeras_0_10_1) to c:\users\gustavo\appdata\local\temp\pip-req-build-w7_e6q71
Collecting tabulate>=0.7.7
  Downloading tabulate-0.8.7-py3-none-any.whl (24 kB)
Collecting graphviz
  Downloading graphviz-0.16-py2.py3-none-any.whl (19 kB)
Building wheels for collected packages: eli5
  Building wheel for eli5 (setup.py): started
  Building wheel for eli5 (setup.py): finished with status 'done'
  Created wheel for eli5: filename=eli5-0.10.1-py2.py3-none-any.whl size=107645 sha256=9e9e50ca2d01e336a65a4f035b200a1b94ec65b167f5baf0dd59a7e61797bad7
  Stored in directory: C:\Users\Gustavo\AppData\Local\Temp\pip-ephem-wheel-cache-6i3sojn4\wheels\92\5c\70\2de39262143de9d4f8990bd79d5ce380697535833ceb70b595
Successfully built eli5
Installing collected packages: tabulate, graphviz, eli5
Successfully installed eli5-0.10.1 graphviz-0.16 tabulate-0.8.7


In [45]:
# Word-level explanation of the prediction (presumably relies on the eli5 fork
# installed in the previous cell).
explain_text = "Jesus Christ is the central figure of Christianity"
predictor.explain(explain_text)

Contribution?,Feature
9.535,Highlighted in text (sum)
-0.057,<BIAS>


In [47]:
# A second explanation, on a sentence that is not about religion.
explain_text = "This year's pandemic will bring lots of challenges regarding vaccination"
predictor.explain(explain_text)

Contribution?,Feature
7.553,Highlighted in text (sum)
-0.433,<BIAS>


### STEP 8: Saving / Loading Model

In [38]:
# Persist the predictor so it can be reloaded later without retraining.
save_path = '/tmp/my_20newsgroup_predictor'
predictor.save(save_path)

In [39]:
# Load the predictor back from the directory written in the previous cell.
load_path = '/tmp/my_20newsgroup_predictor'
reloaded_predictor = ktrain.load_predictor(load_path)

In [40]:
# Sanity check: the reloaded predictor should reproduce the earlier prediction.
query = "Jesus Christ is the central figure of Christianity"
reloaded_predictor.predict(query)

'soc.religion.christian'

In [41]:
# Per-class probabilities for the same sentence, in the order returned by get_classes().
query = "Jesus Christ is the central figure of Christianity"
reloaded_predictor.predict_proba(query)

array([1.8159728e-03, 4.5640234e-04, 5.0473597e-04, 9.9722290e-01],
      dtype=float32)

In [42]:
# Class names, in the column order of predict_proba's output.
class_names = reloaded_predictor.get_classes()
class_names

['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']