# This notebook enables training and testing of Sherlock.
The procedure is:
- Load the train, validation and test datasets (features must already be preprocessed).
- Initialize the model, either by loading the pretrained weights or by training one from scratch.
- Evaluate and analyse the model predictions.

In [1]:
# Automatically reload edited project modules (e.g. the sherlock package) before running cells.
%load_ext autoreload
%autoreload 2

In [2]:
model_id = "retrained_sherlock"  # identifier passed to fit/store_weights and used in the classes_<model_id>.npy file name

In [3]:
from ast import literal_eval
from collections import Counter
from datetime import datetime

import numpy as np
import pandas as pd

from sklearn.metrics import f1_score, classification_report

from sherlock.deploy.model import SherlockModel

## Load datasets for training, validation, testing

In [4]:
start = datetime.now()
print(f'Started at {start}')

# Features are read already preprocessed; the raw labels are lower-cased
# (predictions are lower-cased the same way before scoring).
X_train = pd.read_parquet('../data/data/processed/train.parquet')
y_train = np.array([
    label.lower()
    for label in pd.read_parquet('../data/data/raw/train_labels.parquet').to_numpy().ravel()
])

elapsed = datetime.now() - start
print(f'Load data (train) process took {elapsed} seconds.')

Started at 2022-02-10 18:47:37.581529
Load data (train) process took 0:00:06.286789 seconds.


In [5]:
print('Distinct types for columns in the Dataframe (should be all float32):')
print(set(X_train.dtypes))

# Fail fast rather than relying on a reader to notice a non-float32 column in the printout.
assert set(X_train.dtypes) == {np.dtype('float32')}, f'Expected only float32 columns, got {set(X_train.dtypes)}'

Distinct types for columns in the Dataframe (should be all float32):
{dtype('float32')}


In [6]:
start = datetime.now()
print(f'Started at {start}')

# Preprocessed validation features plus raw labels, lower-cased to match the other splits.
X_validation = pd.read_parquet('../data/data/processed/validation.parquet')
y_validation = np.array([
    label.lower()
    for label in pd.read_parquet('../data/data/raw/val_labels.parquet').to_numpy().ravel()
])

elapsed = datetime.now() - start
print(f'Load data (validation) process took {elapsed} seconds.')

Started at 2022-02-10 18:47:44.189431
Load data (validation) process took 0:00:01.713314 seconds.


In [7]:
start = datetime.now()
print(f'Started at {start}')

# Preprocessed test features plus raw labels, lower-cased to match the predictions.
X_test = pd.read_parquet('../data/data/processed/test.parquet')
y_test = pd.read_parquet('../data/data/raw/test_labels.parquet').values.flatten()

y_test = np.array([x.lower() for x in y_test])

# Take a single timestamp so the reported finish time and duration agree.
end = datetime.now()
print(f'Finished at {end}, took {end - start} seconds')

Started at 2022-02-10 18:47:45.963358
Finished at 2022-02-10 18:47:48.373023, took 0:00:02.409678 seconds


## Initialize the model
Two options:
- Load Sherlock model with pretrained weights
- Fit Sherlock model from scratch

### Load Sherlock with pretrained weights

In [25]:
start = datetime.now()
print(f'Started at {start}')

# Option 1: initialize the model via initialize_model_from_json, loading its pretrained weights.
model = SherlockModel()
model.initialize_model_from_json(with_weights=True)

print('Initialized model.')
# Take a single timestamp so the reported finish time and duration agree.
end = datetime.now()
print(f'Finished at {end}, took {end - start} seconds')

Started at 2022-02-10 18:55:58.217413
Initialized model.
Finished at 2022-02-10 18:55:59.172945, took 0:00:00.955544 seconds


### Fit Sherlock from scratch (and save for later use)

In [41]:
start = datetime.now()
print(f'Started at {start}')

# Option 2: train a brand-new model rather than continuing to train whatever
# `model` already holds (e.g. the pretrained one). Binding it to `model` makes
# the following cells (store_weights, predict) use the freshly trained model.
model = SherlockModel()
model.fit(X_train, y_train, X_validation, y_validation, model_id=model_id)

print('Trained and saved new model.')
# Take a single timestamp so the reported finish time and duration agree.
end = datetime.now()
print(f'Finished at {end}, took {end - start} seconds')

Started at 2022-02-10 20:48:03.023090
Train on 412059 samples, validate on 137353 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Trained and saved new model.
Finished at 2022-02-10 21:37:31.052443, took 0:49:28.030848 seconds


In [42]:
# Store the current model's weights under `model_id` so it can be reused later
# (presumably written next to the classes_<model_id>.npy file loaded below — TODO confirm).
model.store_weights(model_id=model_id)

### Make prediction

In [50]:
# Predict a label for every test sample and lower-case it to match the lower-cased y_test.
raw_predictions = model.predict(X_test)
predicted_labels = np.array([label.lower() for label in raw_predictions])

In [51]:
print(f'prediction count {len(predicted_labels)}, type = {type(predicted_labels)}')

size = len(y_test)
# Exactly one prediction per test label is required. Slicing with [:size] would
# silently hide extra predictions, so check the lengths explicitly.
assert len(predicted_labels) == size, f'Got {len(predicted_labels)} predictions for {size} test labels'

# Weighted F1: per-class F1 averaged using each class's support as its weight.
f1_score(y_test, predicted_labels, average="weighted")

prediction count 137353, type = <class 'numpy.ndarray'>


0.8937685721454983

In [52]:
# If using the original model, model_id should be replaced with "sherlock"
# NOTE(review): allow_pickle permits loading object arrays; only load trusted model files.
classes = np.load(f"../model_files/classes_{model_id}.npy", allow_pickle=True)

report = classification_report(y_test, predicted_labels, output_dict=True)

# Keep only per-class entries: the value must be a dict holding an 'f1-score' and
# the key must be one of the model's classes. Then order them from best to worst F1.
class_scores = sorted(
    [
        (label, scores)
        for label, scores in report.items()
        if isinstance(scores, dict) and 'f1-score' in scores and label in classes
    ],
    key=lambda entry: entry[1]['f1-score'],
    reverse=True,
)

### Top 5 Types

In [53]:
print("\t\tf1-score\tprecision\trecall\t\tsupport")

# class_scores is sorted best-first, so its head holds the 5 strongest classes.
for label, scores in class_scores[:5]:
    # Labels shorter than 8 characters get a second tab to keep the columns aligned.
    column_gap = '\t' if len(label) >= 8 else '\t' * 2
    print(
        f"{label}{column_gap}{scores['f1-score']:.3f}\t\t"
        f"{scores['precision']:.3f}\t\t{scores['recall']:.3f}\t\t{scores['support']}"
    )

		f1-score	precision	recall		support
grades		0.993		0.994		0.993		1765
isbn		0.990		0.992		0.987		1430
jockey		0.986		0.980		0.992		2819
currency	0.980		0.987		0.973		405
industry	0.980		0.975		0.985		2958


### Bottom 5 Types

In [54]:
print("\t\tf1-score\tprecision\trecall\t\tsupport")

# class_scores is sorted best-first, so its tail holds the 5 weakest classes.
# `[-5:]` also returns every entry when there are fewer than 5 classes.
for key, value in class_scores[-5:]:
    # Labels shorter than 8 characters get a second tab to keep the columns aligned.
    tabs = '\t' if len(key) >= 8 else '\t\t'

    print(f"{key}{tabs}{value['f1-score']:.3f}\t\t{value['precision']:.3f}\t\t{value['recall']:.3f}\t\t{value['support']}")

		f1-score	precision	recall		support
rank		0.706		0.631		0.802		2983
person		0.665		0.721		0.617		579
director	0.581		0.673		0.511		225
sales		0.555		0.554		0.556		322
ranking		0.486		0.722		0.367		439


### All Scores

In [55]:
# Complete per-class precision / recall / F1 table, 3 decimal places.
print(classification_report(y_true=y_test, y_pred=predicted_labels, digits=3))

                precision    recall  f1-score   support

       address      0.944     0.936     0.940      3003
     affiliate      0.922     0.814     0.865       204
   affiliation      0.978     0.951     0.964      1768
           age      0.847     0.960     0.900      3033
         album      0.863     0.904     0.883      3035
          area      0.911     0.802     0.853      1987
        artist      0.810     0.870     0.839      3043
    birth date      0.979     0.977     0.978       479
   birth place      0.970     0.916     0.942       418
         brand      0.779     0.707     0.742       574
      capacity      0.773     0.715     0.743       362
      category      0.915     0.886     0.900      3087
          city      0.856     0.889     0.872      2966
         class      0.920     0.911     0.915      2971
classification      0.924     0.848     0.885       587
          club      0.962     0.962     0.962      2977
          code      0.931     0.902     0.916  

## Review errors

In [56]:
# Collect the true label of every misclassified test sample, print the individual
# errors for the labels in INSPECT_LABELS, then count mismatches per true label.
# The trailing comma matters: ('address') is just the str 'address', and `in` on a
# str is a substring test, not membership.
INSPECT_LABELS = ('address',)

size = len(y_test)
mismatches = list()

for idx, k1 in enumerate(y_test[:size]):
    k2 = predicted_labels[idx]

    if k1 != k2:
        mismatches.append(k1)

        # zoom in to specific errors. Use the index in the next step
        if k1 in INSPECT_LABELS:
            print(f'[{idx}] expected "{k1}" but predicted "{k2}"')

f1 = f1_score(y_test[:size], predicted_labels[:size], average="weighted")
print(f'Total mismatches: {len(mismatches)} (F1 score: {f1})')

data = Counter(mismatches)
data.most_common()   # Returns all unique items and their counts

[1116] expected "address" but predicted "name"
[1578] expected "address" but predicted "language"
[2420] expected "address" but predicted "club"
[2616] expected "address" but predicted "city"
[3398] expected "address" but predicted "city"
[4380] expected "address" but predicted "county"
[4422] expected "address" but predicted "city"
[5112] expected "address" but predicted "location"
[5546] expected "address" but predicted "name"
[5647] expected "address" but predicted "team name"
[7119] expected "address" but predicted "day"
[8797] expected "address" but predicted "location"
[9354] expected "address" but predicted "location"
[9574] expected "address" but predicted "location"
[9806] expected "address" but predicted "city"
[10035] expected "address" but predicted "creator"
[10067] expected "address" but predicted "education"
[11055] expected "address" but predicted "city"
[11902] expected "address" but predicted "location"
[12072] expected "address" but predicted "artist"
[12639] expecte

Total mismatches: 14623 (F1 score: 0.8937685721454983)


[('name', 761),
 ('rank', 592),
 ('position', 551),
 ('location', 489),
 ('region', 473),
 ('team', 428),
 ('description', 422),
 ('artist', 395),
 ('area', 394),
 ('notes', 364),
 ('product', 361),
 ('category', 353),
 ('type', 342),
 ('company', 335),
 ('city', 330),
 ('day', 326),
 ('album', 292),
 ('code', 290),
 ('team name', 282),
 ('ranking', 278),
 ('class', 264),
 ('order', 254),
 ('sex', 254),
 ('person', 222),
 ('gender', 219),
 ('status', 217),
 ('owner', 211),
 ('weight', 206),
 ('result', 194),
 ('year', 193),
 ('address', 193),
 ('duration', 191),
 ('country', 177),
 ('service', 176),
 ('manufacturer', 171),
 ('brand', 168),
 ('origin', 162),
 ('plays', 152),
 ('credit', 151),
 ('component', 149),
 ('sales', 143),
 ('range', 135),
 ('format', 133),
 ('age', 122),
 ('county', 118),
 ('state', 117),
 ('club', 113),
 ('director', 110),
 ('nationality', 107),
 ('publisher', 105),
 ('capacity', 103),
 ('classification', 89),
 ('affiliation', 87),
 ('command', 85),
 ('symbol',

In [57]:
# Raw (unprocessed) test values, used below to inspect individual predictions.
test_samples = pd.read_parquet(path='../data/data/raw/test_values.parquet')

In [58]:
# Inspect a single test sample: its prediction, its label, and its raw values.
idx = 1001
original = test_samples.iloc[idx]
# Each stored value is the string repr of a Python literal; literal_eval parses it
# without executing arbitrary code.
converted = [literal_eval(raw) for raw in original]

print(f'Predicted "{predicted_labels[idx]}", actual label "{y_test[idx]}". Actual values:\n{converted}')

Predicted "address", actual label "address". Actual values:
[['Cabot House', 'Cabot House', '5 Hill Rd.', '5 Hill Rd.', '9 Cabot Rd.', '9 Cabot Rd.', 'Cabot House', '22 Bank Rd.', '22 Bank Rd.', 'Cabot House', '31 Bank Rd.', '31 Bank Rd.', 'Bairds Hotel', '11 Cabot Rd.', '11 Cabot Rd.', '10 Hill Rd.', '10 Hill Rd.', '10 Hill Rd.', '10 Hill Rd.', '7A Church Rd.', '1 Cabot Rd.', '1 Cabot Rd.', '1 Cabot Rd.', '1 Cabot Rd.', '2 Coronation St.', '2 Coronation St.', '7A Church Rd.', '12 Hill Rd.', '12 Hill Rd.', '12 Hill Rd.', 'Cabot House', '19 Bank Rd.', '19 Bank Rd.', '19 Bank Rd.', '19 Bank Rd.', '19 Bank Rd.', '7A Church Rd.', '18 Mill Rd.', '17 Hill Rd.', '17 Hill Rd.', 'Cabot House', 'Cabot House', '25 Bank Rd.', '10 Coronation St.', '6 Cabot Rd.', '6 Cabot Rd.', '8 Hill Rd.', '8 Hill Rd.', '4 Mill Rd.', '4 Mill Rd.', '12 Sulva Rd.', '4 Haig Rd.', '13 Botwood Rd.', '13 Botwood Rd.', '8 Botwood Rd.', '8 Botwood Rd.', '16 Botwood Rd.', '16 Botwood Rd.', '16 Botwood Rd.', '16 Botwood Rd.