# Training an intent classification model

Define constants

In [1]:
from os import getcwd, path
import sys
import matplotlib.pyplot as plt
import numpy as np

# Project root: the parent of the current working directory.
BASE_PATH = path.dirname(getcwd())
# Make the project's packages importable. The membership check keeps the cell
# idempotent, so re-running it does not pile up duplicate sys.path entries.
if BASE_PATH not in sys.path:
    sys.path.append(BASE_PATH)

# Files used by the rest of the notebook, all resolved against the project root.
# Path components are passed separately so no separator is hard-coded.
DATA_UTILS = path.join(BASE_PATH, 'common', 'data_utils.py')
TRAIN_PATH = path.join(BASE_PATH, 'kc_data.json')
CLASSES_FILE = path.join(BASE_PATH, 'classes.json')

In [2]:
# Run the helper script in the notebook's namespace so the names it defines can
# be used directly by later cells.
# The `with` block closes the file handle deterministically, and compiling with
# the real file name makes tracebacks point at DATA_UTILS instead of "<string>".
# NOTE(review): exec runs the file's contents verbatim; only use on trusted project files.
with open(DATA_UTILS) as utils_source:
    exec(compile(utils_source.read(), DATA_UTILS, 'exec'))

Use functions from the utils module to extract and preprocess the training data.
Refer to `kc_data.json` for the sample data format.
`get_data_pairs` is then used to parse the data into a tuple of `([list_of_sentences], [list_of_labels])`.

In [3]:
# Read the raw training file first, then turn it into parallel
# (sentences, labels) collections for the learner.
training_json = data_from_json(TRAIN_PATH)
X_train, y_train = get_data_pairs(training_json)

In [4]:
import torch
# Display the installed PyTorch version so the environment that produced the
# results in this notebook is on record.
torch.__version__

'0.5.0a0+a24163a'

Start training the classification model and save it

In [5]:
import torch.optim as optim
from text_classification.sif_starspace.model import StarspaceClassifierWrapper
from text_classification.sif_starspace.train import StarspaceClassifierLearner
from common.callbacks import PrintLoggerCallback, EarlyStoppingCallback

# Fix the random seeds before the model is built so repeated runs are comparable.
# NOTE(review): assumes the project's randomness comes from torch and numpy;
# seed other sources (e.g. Python's `random`) too if the training code uses them.
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# Build the classifier wrapper with its default settings.
model = StarspaceClassifierWrapper()
# The learner is created with its default settings. The commented-out variant
# below shows how an optimizer and its arguments can be passed explicitly.
# learner = StarspaceClassifierLearner(model,
#     optimizer_fn=optim.SGD,
#     optimizer_kwargs={'lr': 0.01, 'momentum': 0.9}
# )
learner = StarspaceClassifierLearner(model)

In [6]:
# Callbacks used during training: log progress every epoch and stop early
# with a tolerance of 0.
training_callbacks = [
    PrintLoggerCallback(log_every=1),
    EarlyStoppingCallback(tolerance=0),
]

# Train on the prepared (sentences, labels) pair for at most 300 epochs in
# batches of 64.
learner.fit(
    training_data=(X_train, y_train),
    batch_size=64,
    epochs=300,
    callbacks=training_callbacks,
)

1m 53s (- 563m 22s) (1 0%) - loss: 914.5382 - accuracy: 0.0923
3m 46s (- 562m 19s) (2 0%) - loss: 734.0002 - accuracy: 0.3298
5m 28s (- 542m 3s) (3 1%) - loss: 649.9254 - accuracy: 0.4224
7m 19s (- 542m 32s) (4 1%) - loss: 596.3769 - accuracy: 0.4661
9m 5s (- 535m 55s) (5 1%) - loss: 555.8648 - accuracy: 0.4962
11m 33s (- 565m 59s) (6 2%) - loss: 521.5937 - accuracy: 0.5206
13m 18s (- 557m 4s) (7 2%) - loss: 491.3354 - accuracy: 0.5337
15m 6s (- 551m 39s) (8 2%) - loss: 462.9839 - accuracy: 0.5519
16m 47s (- 543m 11s) (9 3%) - loss: 437.1207 - accuracy: 0.5641
18m 28s (- 535m 51s) (10 3%) - loss: 412.1628 - accuracy: 0.5780
20m 7s (- 528m 51s) (11 3%) - loss: 388.2298 - accuracy: 0.5929
21m 56s (- 526m 34s) (12 4%) - loss: 366.4929 - accuracy: 0.5995
23m 36s (- 521m 18s) (13 4%) - loss: 345.0561 - accuracy: 0.6071
25m 23s (- 518m 38s) (14 4%) - loss: 325.2444 - accuracy: 0.6194
27m 6s (- 514m 59s) (15 5%) - loss: 307.4852 - accuracy: 0.6314
28m 44s (- 510m 15s) (16 5%) - loss: 290.1102

In [7]:
# Spot-check the model on a single health-related sentence. The model is called
# with a list of sentences; the recorded output lists candidate intents with
# their confidence scores, highest first.
# NOTE(review): the training log in this notebook ends at epoch 16 of 300, so
# these predictions may come from a partially trained model — confirm.
model(['I\'m having diahrea'])

[[{'confidence': 0.7043687105178833,
   'intent': 'General - Depression - Generic'},
  {'confidence': 0.6343507170677185, 'intent': 'BabyGender - Unhappy'},
  {'confidence': 0.619273841381073,
   'intent': 'General - I Hate Being Pregnant'},
  {'confidence': 0.5730706453323364, 'intent': 'SmallTalk - UserIsAnnoyed'},
  {'confidence': 0.5696707367897034,
   'intent': 'Medical - PostPartum Depression - What - Lvl3'}]]

In [8]:
# Probe the model with a sentence that is presumably unrelated to the trained
# intents, to see what confidence it assigns; in the recorded output every
# candidate scores below 0.5.
model(['the earth is flat'])

[[{'confidence': 0.4830198585987091,
   'intent': 'Trimester - Third - How big IT'},
  {'confidence': 0.4711604416370392, 'intent': 'Trimester - How Big'},
  {'confidence': 0.43384605646133423,
   'intent': 'Birth - Labour - Epidural - What is it IC'},
  {'confidence': 0.4299301505088806, 'intent': 'Trimester - First - How big'},
  {'confidence': 0.4291103482246399,
   'intent': 'Trimester - Second - how big'}]]