# 3. Classifying Successful Prospects

In this notebook, we'll use the new text classification library `gobbli` to classify each prospect using the labels we previously generated.

For this task, we're going to use some fairly memory-intensive deep learning models. For best results, run this notebook on a GPU; Google Colab works well if no other options are available.

In [1]:
import pandas as pd
import numpy as np

from gobbli.model import MTDNN
from gobbli.io import TrainInput, PredictInput

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score, balanced_accuracy_score, accuracy_score

import logging
logging.basicConfig(level=logging.DEBUG)

DEBUG:matplotlib.pyplot:Loaded backend module://ipykernel.pylab.backend_inline version unknown.
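Setting the root logger to `DEBUG` also surfaces chatter from `docker`, `urllib3` and `matplotlib`, which accounts for many of the log lines in the cells below. The sketch below turns those third-party loggers down to `WARNING` while leaving gobbli's own messages (including the container output it echoes at `DEBUG`) untouched. The helper name `configure_logging` and the choice of which loggers count as "noisy" are my own assumptions, not part of gobbli.

```python
import logging

# Assumption: these third-party loggers produce most of the noise seen in this notebook.
NOISY_LOGGERS = ("docker", "urllib3", "matplotlib")


def configure_logging(root_level=logging.DEBUG, noisy=NOISY_LOGGERS, noisy_level=logging.WARNING):
    """Hypothetical helper: verbose root logger, quieter third-party libraries."""
    logging.basicConfig(level=root_level)
    # basicConfig is a no-op if the root logger already has handlers, so set the level explicitly too.
    logging.getLogger().setLevel(root_level)
    for name in noisy:
        # Child loggers such as "urllib3.connectionpool" inherit this level
        # as long as their own level is left unset.
        logging.getLogger(name).setLevel(noisy_level)


configure_logging()
```

If the echoed model architecture is also unwanted, adding `"gobbli"` to the noisy list at `logging.INFO` should keep the `INFO:gobbli...` progress and `RESULTS` lines while hiding its `DEBUG` output.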


In [2]:
df = pd.read_csv('labelled.csv')[['report', 'label']] # preprocessed.csv

X = df.report
y = df.label.astype('str')

print(df.shape)
df.head()

(5824, 2)


|   | report | label |
|---|--------|-------|
| 0 | Heimlich is a Level 1 sex offender and wouldn'... | 0 |
| 1 | Alexy made headlines for all the wrong reasons... | 0 |
| 2 | The Nationals have acquired Cole twice, first ... | 1 |
| 3 | Signed for an above-slot $2 million as a Natio... | 1 |
| 4 | It often takes time for those high-ceilinged, ... | 0 |
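The test-set report at the end of this notebook shows a strong skew (470 reports labelled `0` versus 113 labelled `1`), so it is worth checking the label balance of the full dataset before splitting. Below is a minimal sketch using pandas' `value_counts`; the helper name `label_balance` and the toy labels are hypothetical, and in the notebook you would pass `df.label` instead.

```python
import pandas as pd


def label_balance(labels: pd.Series) -> pd.DataFrame:
    """Hypothetical helper: count and share of each label, largest class first."""
    counts = labels.value_counts()                # absolute counts, sorted descending
    shares = labels.value_counts(normalize=True)  # same counts as fractions of the total
    return pd.DataFrame({"count": counts, "share": shares})


# Made-up toy labels for illustration; in the notebook, use label_balance(df.label).
toy = pd.Series(["0", "0", "0", "1"])
print(label_balance(toy))
```

If the imbalance is large, passing `stratify=y` to `train_test_split` is one option for keeping the class ratio the same across the training, validation, and test splits; the notebook as written does not do this.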


We prepare the training, validation, and testing input with `sklearn.model_selection.train_test_split`.

In [3]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=1)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.05, random_state=1)

print(X_train.shape, X_test.shape, X_val.shape)

train_input = TrainInput(
    # X_train: A list of strings to classify
    X_train=X_train.tolist(),
    # y_train: The true class for each string in X_train
    y_train=y_train.tolist(),
    # And likewise for validation
    X_valid=X_val.tolist(),
    y_valid=y_val.tolist(),
    # Number of documents to train on at once
    train_batch_size=16,
    # Number of documents to evaluate at once
    valid_batch_size=16,
    # Number of times to iterate over the training set
    num_train_epochs=10
)

(4978,) (583,) (263,)
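The split sizes printed above follow from the two calls: the first holds out 10% of all 5,824 reports for testing, and the second carves 5% of the remainder out for validation. The sketch below reproduces the arithmetic, assuming (as I understand sklearn's behaviour for a float `test_size` with no `train_size`) that the test share is rounded up and the training share gets whatever is left. `split_sizes` is a hypothetical helper, not an sklearn function.

```python
import math


def split_sizes(n_samples: int, test_size: float) -> tuple:
    """Hypothetical helper: (n_train, n_test) for a float test_size.

    Assumption: mirrors train_test_split, which rounds the test split up
    and assigns the remainder to the training split.
    """
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test


n_rest, n_test = split_sizes(5824, 0.1)     # first split: hold out the test set
n_train, n_val = split_sizes(n_rest, 0.05)  # second split: validation from the remainder
print(n_train, n_test, n_val)  # 4978 583 263
```

Note that the validation set is therefore only about 4.5% of the full dataset (263 of 5,824), so the validation metrics logged after training rest on fairly few examples.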


Then, we build and train the classifier. We'll use the `MT-DNN` architecture, which extends the popular `BERT` model with multi-task learning. In my experience, it trains a good deal faster than `BERT` and performs similarly well.

This cell will take a while to run: in the run below, training took about 39 minutes (2339 s).

In [4]:
clf = MTDNN(use_gpu=True)
clf.build()

train_output = clf.train(train_input)

DEBUG:docker.utils.config:Trying paths: ['/home/jacobgdt/.docker/config.json', '/home/jacobgdt/.dockercfg']
DEBUG:docker.utils.config:No config file found
DEBUG:docker.utils.config:Trying paths: ['/home/jacobgdt/.docker/config.json', '/home/jacobgdt/.dockercfg']
DEBUG:docker.utils.config:No config file found
INFO:gobbli.model.base:Starting build.
INFO:gobbli.model.base:Downloading pre-trained weights.
DEBUG:gobbli.util:Download for URL 'https://mrc.blob.core.windows.net/mt-dnn-model/mt_dnn_base.pt' already exists at '/home/jacobgdt/.gobbli/download/mt_dnn_base.pt'
INFO:gobbli.model.base:Weights downloaded.
DEBUG:docker.api.build:Looking for auth config
DEBUG:docker.api.build:No auth config in memory - loading from filesystem
DEBUG:docker.utils.config:Trying paths: ['/home/jacobgdt/.docker/config.json', '/home/jacobgdt/.dockercfg']
DEBUG:docker.utils.config:No config file found
DEBUG:docker.api.build:Sending auth config ()
DEBUG:urllib3.connectionpool:http://localhost:None "POST /v1.35/containers/c9b2060910704b9d14613ff4ac2b8ac8de9ca2d040156a1cef2b77b3a5b47398/wait HTTP/1.1" 200 None
DEBUG:urllib3.connectionpool:http://localhost:None "POST /v1.35/containers/c9b2060910704b9d14613ff4ac2b8ac8de9ca2d040156a1cef2b77b3a5b47398/stop HTTP/1.1" 304 0
DEBUG:urllib3.connectionpool:http://localhost:None "DELETE /v1.35/containers/c9b2060910704b9d14613ff4ac2b8ac8de9ca2d040156a1cef2b77b3a5b47398?v=False&link=False&force=False HTTP/1.1" 204 0
INFO:gobbli.model.base:Training finished in 2339.27 sec.
INFO:gobbli.model.base:RESULTS:
INFO:gobbli.model.base:  Validation loss: 0.5079846978187561
INFO:gobbli.model.base:  Validation accuracy: 0.8022813688212928
INFO:gobbli.model.base:  Training loss: 0.3200165331363678


Once training completes, gobbli logs the training and validation metrics shown above. To evaluate how well the model does on the held-out test set, we use `sklearn`'s metrics.

Note that these results aren't significantly better than those of a simple `TF-IDF + LogisticRegression` approach with `sklearn`; the problem we've defined is quite challenging.
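For reference, a TF-IDF + logistic regression baseline of the kind mentioned above can be sketched with an `sklearn` pipeline. The exact baseline settings aren't given here, so the choices below (unigrams and bigrams, sublinear term frequency, `class_weight="balanced"` to counter the label skew) are my own assumptions, and `tfidf_logreg_baseline` is a hypothetical helper. Labels are cast to `int` because the notebook's `y_train` holds strings while `y_test` is converted to integers.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import balanced_accuracy_score, f1_score
from sklearn.pipeline import make_pipeline


def tfidf_logreg_baseline(X_train, y_train, X_test, y_test):
    """Hypothetical helper: fit TF-IDF + logistic regression, return (y_pred, f1, balanced acc)."""
    # Cast labels to int so f1_score's default pos_label=1 applies.
    y_train = [int(v) for v in y_train]
    y_test = [int(v) for v in y_test]
    pipe = make_pipeline(
        # Assumption: word unigrams and bigrams with log-scaled term frequencies.
        TfidfVectorizer(ngram_range=(1, 2), sublinear_tf=True),
        # Assumption: reweight classes inversely to their frequency.
        LogisticRegression(class_weight="balanced", max_iter=1000),
    )
    pipe.fit(list(X_train), y_train)
    y_pred = pipe.predict(list(X_test))
    return y_pred, f1_score(y_test, y_pred), balanced_accuracy_score(y_test, y_pred)


# In the notebook: tfidf_logreg_baseline(X_train, y_train, X_test, y_test)
```

A baseline like this trains in seconds on CPU, which makes it a useful sanity check before committing to a 40-minute deep learning run.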

In [5]:
predict_input = PredictInput(
    X=X_test.tolist(),
    labels=train_output.labels,
    checkpoint=train_output.checkpoint,
    predict_batch_size=64
)

predict_output = clf.predict(predict_input)

y_pred = np.array(predict_output.y_pred).astype('int')
y_test = y_test.astype('int')

print(classification_report(y_test, y_pred))
print(f'f1: {f1_score(y_test, y_pred):0.4f}')
print(f'balanced acc: {balanced_accuracy_score(y_test, y_pred):0.4f}')

INFO:gobbli.model.base:Starting prediction.
DEBUG:gobbli.model.base:Running container for image 'gobbli-mt-dnn-classifier' with command 'python gobbli_train.py --data_dir=data/mt_dnn --init_checkpoint=/gobbli/checkpoint.pt --batch_size=64 --output_dir=/gobbli/output --log_file=/gobbli/output/log.log --optimizer=adamax --grad_clipping=0 --global_grad_clipping=1 --lr=2e-5 --test_file=/gobbli/input/test.csv --label_file=/gobbli/input/labels.csv --max_seq_len=128'
DEBUG:gobbli.model.base:Container volumes: {'/home/jacobgdt/.gobbli/model/MTDNN/53f6f40d2ad34da793985f9ae987bfd2/predict/705033286b374ae5b5a7b220535598f4': {'bind': '/gobbli', 'mode': 'rw'}, '/home/jacobgdt/.gobbli/model/MTDNN/53f6f40d2ad34da793985f9ae987bfd2/weights': {'bind': '/model/weights', 'mode': 'rw'}, '/home/jacobgdt/.gobbli/model/MTDNN/53f6f40d2ad34da793985f9ae987bfd2/train/5ba51f33c5154580ab4d74cf70a7abde/output/model_9.pt': {'bind': '/gobbli/checkpoint.pt', 'mode': 'rw'}}

              precision    recall  f1-score   support

           0       0.86      0.85      0.86       470
           1       0.41      0.42      0.41       113

    accuracy                           0.77       583
   macro avg       0.63      0.63      0.63       583
weighted avg       0.77      0.77      0.77       583

f1: 0.4105
balanced acc: 0.6346
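The summary numbers can be checked against the per-class rows of the report. Balanced accuracy is the mean of the per-class recalls, and the accuracy of a trivial classifier that always predicts the majority class `0` follows from the support column:

$$
\text{balanced acc} = \frac{\text{recall}_0 + \text{recall}_1}{2} = \frac{0.85 + 0.42}{2} \approx 0.635,
\qquad
\text{acc}_{\text{majority}} = \frac{470}{583} \approx 0.806
$$

The first value agrees with the printed 0.6346 up to the rounding in the report. The second shows why plain accuracy is misleading here: always predicting `0` would beat the model's 0.77 accuracy, yet it would score an F1 of 0 on class `1` and a balanced accuracy of exactly 0.5, which is why F1 and balanced accuracy are the metrics worth watching.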
