In [None]:
import pickle

import numpy
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_curve, auc, precision_recall_curve
from sklearn.svm import SVC
from lightgbm import LGBMClassifier
import matplotlib.pyplot as plt

%matplotlib inline
plt.rcParams["figure.figsize"] = (15,10)

In [4]:
# Minimum number of occurrences a MedDRA term needs in order to be kept
# (used further down to filter out rare terms).
min_occurence = 10

# Term counts, loaded from a pickle file.
with open('pt_counts.pkl', 'rb') as counts_file:
    tag_counts = pickle.load(counts_file)

# Spell checker object, loaded from a pickle file.
with open("spell_checker.pkl", "rb") as checker_file:
    spell = pickle.load(checker_file)

def get_features(case):
    """
    Placeholder feature extractor, to be adapted to your own data format.

    The intended return value is ``effect_description, drugname, [sex, age, imc]``:
        - effect_description: the free text written by the patient; it must be tokenized
            (the original note points to a ``tokenize`` helper that is not shown in this
            part of the notebook)
        - drugname: the drug name; the authors used a spell checker trained on specific
            text (python library pyspellchecker==0.5.0)
        - sex: can be encoded as an int (0-1)
        - age and imc: numerical values (int or float for age, float for imc)

    As written, this stub ignores ``case`` and returns None.
    """
    return None

## data processing

In [None]:
"""
'dataset.pkl' contains your dataset of features, it is a dict with unique key corresponding to each case.
"""
with open('dataset.pkl', 'rb') as file:
    data = pickle.load(file)
    
"""
'regex_match.pkl' contains meddra terms matched to each case using a regex engine.
    it is a dict with the same key as for 'dataset.pkl'.
"""
with open('regex_match.pkl', 'rb') as file:
    regex_match = pickle.load(file)

"""
'tags.pkl' contains the meddra tags that correspond to your dataset, it is a dict with the same key as for 'dataset.pkl'.
    We only keep the most common terms (i.e with number of occurences greater than the min_occurence parameter.)
""" 
with open('tags.pkl', 'rb') as file:
    tags = pickle.load(file)
    
X = []
Y = []
re_match = []
"""
    We build the X and Y arrays from our features and tags.
    X components are numeric vectors of features, it can be a mixture of text
    vectorisation (using TF-IDF or any text embedding algorithm), numerical
    features (age, weigh,...) and one hot encoding of categorical features (gender).
"""
"""
    /!\ If you use a non pre trained text vectorization model, you should compute it on the train
    sample after train-test split (next cell) to avoid introducing bias in your evaluation. Indeed,
    if you compute for instance TF-IDF on the whole dataset (ie before splitting) test data will be
    used for word frequency computation.
"""
for key, value in data.items():
    X.append(value)
    Y.append(tags[key])
    re_match.append(regex_match[key])

## Test - train split

In [None]:
# Hold out 20% of the cases for testing. Features, tags and regex matches are split
# together so they stay aligned, and the fixed seed keeps the split reproducible.
(X_train, X_test,
 Y_train, Y_test,
 regex_train, regex_test) = train_test_split(
    X, Y, re_match,
    test_size=0.2,
    random_state=42,
)

## Tags binarization and TF-IDF

In [None]:
# The binarizer was used without ever being created; it is fitted on the training
# tags only, so the label set (and column order) comes from the train sample.
tag_binarizer = MultiLabelBinarizer()

# NOTE: this cell overwrites Y_train / Y_test with their binary indicator matrices,
# so it must be run exactly once after the train-test split.
Y_train = tag_binarizer.fit_transform(Y_train)
Y_test = tag_binarizer.transform(Y_test)
# Regex matches encoded with the same columns as Y_test (terms unknown to the
# binarizer are ignored by transform, with a warning).
regex_test_bin = tag_binarizer.transform(regex_test)
# we remove rare matched terms to be consistent with the ML approach
regex_test_filtered = [
    [term for term in matched_terms
     if term in tag_counts and tag_counts[term] >= min_occurence]
    for matched_terms in regex_test
]

# Feature matrices consumed by every model below. They were referenced but never
# defined. The data-processing cell states that X already holds numeric feature
# vectors, so they are simply stacked here.
# NOTE(review): if your features still contain raw text, fit your vectorizer
# (e.g. TF-IDF) on X_train only and transform both samples here instead.
X_train_vec = numpy.asarray(X_train)
X_test_vec = numpy.asarray(X_test)

## Models training

In [None]:
"""
    Training a random forest classifier.
    The parameters were obtained through grid search tuning method (see a few cells below).
"""
clf = RandomForestClassifier(
    n_estimators=200,
    max_depth=4,
    n_jobs=8
)
clf.fit(X_train_vec, Y_train)

pred_test = clf.predict_proba(X_test_vec)
"""
    As we have several labels to predict, we flatten the prediction output to compute the evaluation metrics.
"""
pred_test_flat = numpy.vstack(pred_test)

In [None]:
# One binary logistic regression per label (one-vs-rest).
# `multi_class='ovr'` was dropped: each sub-problem is already binary, so the option
# had no effect here, and it is deprecated in recent scikit-learn releases.
logit = OneVsRestClassifier(
    LogisticRegression(),
    n_jobs=8
)
logit.fit(X_train_vec, Y_train)
# (n_samples, n_labels) matrix of per-label probabilities.
pred_test_logit = logit.predict_proba(X_test_vec)
# Row-major flattening, aligned with Y_test.flatten('C').
pred_test_logit_flat = numpy.hstack(pred_test_logit)

In [None]:
# One SVM per label (one-vs-rest). probability=True enables predict_proba; the
# probability estimates rely on an internal, randomised cross-validation, so the
# estimator is seeded to keep the reported metrics reproducible.
svc = OneVsRestClassifier(
    SVC(probability=True, random_state=42),
    n_jobs=8
)
svc.fit(X_train_vec, Y_train)
# (n_samples, n_labels) matrix of per-label probabilities.
pred_test_svc = svc.predict_proba(X_test_vec)
# Row-major flattening, aligned with Y_test.flatten('C').
pred_test_svc_flat = numpy.hstack(pred_test_svc)

In [None]:
# One LightGBM classifier per label (one-vs-rest), 200 trees of depth at most 4.
lgbm = OneVsRestClassifier(
    LGBMClassifier(n_estimators=200, max_depth=4),
    n_jobs=10,
)
lgbm.fit(X_train_vec, Y_train)

# Per-label probabilities for the test sample, then the same values as one
# row-major 1-D vector (aligned with Y_test.flatten('C')).
pred_test_lgbm = lgbm.predict_proba(X_test_vec)
pred_test_lgbm_flat = numpy.concatenate(list(pred_test_lgbm))

In [None]:
"""
    Here we perform a basic hyperparameters tuning (grid search approach) to find a good set
    of parameters for the LGBM model. We can do the same for Random Forests.
    For a more sophisticated tuning method, we could use dedicated sklearn modules 
    (RandomizedSearchCV, GridSearchCV). 
"""
import itertools
from tqdm import tqdm

depth = [2, 4, 6, 10, 15]
n_est = [10, 50, 100, 150, 200]
best_auc = 0
best_params = None
best_lgbm = None
for x in tqdm(itertools.product(depth, n_est)):
    lgbm = OneVsRestClassifier(
        LGBMClassifier(
            max_depth=x[0],
            n_estimators=x[1]
        ),
        n_jobs=10
    )
    lgbm.fit(X_train_vec, Y_train)
    pred_test_lgbm = lgbm.predict_proba(X_test_vec)
    # lgbm + regex
    pred_test_lgbm_regex = pred_test_lgbm + regex_test_bin
    pred_test_lgbm_regex = numpy.minimum(pred_test_lgbm_regex, numpy.ones(pred_test_lgbm_regex.shape))
    pred_test_lgbm_regex_flat = numpy.hstack(pred_test_lgbm_regex)
    
    fpr_lgbm, tpr_lgbm, _ = roc_curve(Y_test.flatten('C'), pred_test_lgbm_regex_flat)
    roc_auc_lgbm = auc(fpr_lgbm, tpr_lgbm)
    if roc_auc_lgbm > best_auc:
        best_auc = roc_auc_lgbm
        best_params = x
        best_lgbm = lgbm

In [None]:
# Predictions of the model selected by the grid search, as a matrix and as a
# row-major 1-D vector.
pred_test_lgbm = best_lgbm.predict_proba(X_test_vec)
pred_test_lgbm_flat = numpy.concatenate(list(pred_test_lgbm))

# lgbm + regex
# We use a very simple ensembling method for lgbm and regex: we add the prediction
# vectors, then cap the result so that it never exceeds 1.
pred_test_lgbm_regex = numpy.minimum(pred_test_lgbm + regex_test_bin, 1.0)
pred_test_lgbm_regex_flat = numpy.concatenate(list(pred_test_lgbm_regex))

## ROC curves and AUC

In [None]:
"""
    Plotting the curves.
"""

fpr, tpr, _ = roc_curve(Y_test.flatten('F'), pred_test_flat)
roc_auc = auc(fpr, tpr)

fpr_logit, tpr_logit, _ = roc_curve(Y_test.flatten('C'), pred_test_logit_flat)
roc_auc_logit = auc(fpr_logit, tpr_logit)

fpr_svc, tpr_svc, _ = roc_curve(Y_test.flatten('C'), pred_test_svc_flat)
roc_auc_svc = auc(fpr_svc, tpr_svc)

fpr_lgbm, tpr_lgbm, _ = roc_curve(Y_test.flatten('C'), pred_test_lgbm_flat)
roc_auc_lgbm = auc(fpr_lgbm, tpr_lgbm)

fpr_lgbm_re, tpr_lgbm_re, _ = roc_curve(Y_test.flatten('C'), pred_test_lgbm_regex_flat)
roc_auc_lgbm_re = auc(fpr_lgbm_re, tpr_lgbm_re)

fpr_re, tpr_re, _ = roc_curve(Y_test.flatten('C'), numpy.hstack(regex_test_bin))
roc_auc_re = auc(fpr_re, tpr_re)

plt.figure()
lw = 2
plt.plot(fpr, tpr, color='darkorange',
         lw=lw, label='ROC curve random forest (area = %0.2f)' % roc_auc)
plt.plot(fpr_logit, tpr_logit, color='green',
         lw=lw, label='ROC curve logit (area = %0.2f)' % roc_auc_logit)
plt.plot(fpr_svc, tpr_svc, color='red',
         lw=lw, label='ROC curve SVM (area = %0.2f)' % roc_auc_svc)
plt.plot(fpr_lgbm, tpr_lgbm, color='black',
         lw=lw, label='ROC curve LGBM (area = %0.2f)' % roc_auc_lgbm)
plt.plot(fpr_lgbm_re, tpr_lgbm_re, color='blue',
         lw=lw, label='ROC curve LGBM regex (area = %0.2f)' % roc_auc_lgbm_re)
plt.plot(fpr_re, tpr_re, color='gray',
         lw=lw, label='ROC curve pure regex (area = %0.2f)' % roc_auc_re)
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.0])
plt.xlabel('False Positive Rate')

plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic ({})'.format(level))
plt.legend(loc="lower right")
plt.savefig('roc_auc', dpi=None, facecolor='w', edgecolor='w')
plt.show()

## Precision - recall per model

We plot precision, recall and F1 score for each model.

We also compute and print contingency tables.

In [None]:
# Precision and recall of the pure regex approach as a function of the decision threshold.
p, r, t = precision_recall_curve(Y_test.flatten('C'), numpy.hstack(regex_test_bin))
plt.figure()
lw = 2
# precision_recall_curve returns one more precision/recall value than thresholds:
# p[i] and r[i] are obtained with "score >= t[i]", and the final (1, 0) point has no
# threshold. It is dropped so that every value is drawn at its own threshold
# (prepending 0 to the thresholds shifted the curves by one position).
plt.plot(t, p[:-1], color='darkorange',
         lw=lw, label='precision')
plt.plot(t, r[:-1], color='green',
         lw=lw, label='recall')
plt.xlabel('Threshold')
plt.ylabel('Precision / recall')
plt.title('Precision - recall')
plt.legend(loc="lower right")
plt.show()

In [None]:
# F1 score of the pure regex approach for every threshold (p, r, t come from the
# precision-recall cell just above).
F1 = 2 * (p * r) / (p + r)
# 0/0 (precision and recall both null) yields NaN; NaN != NaN, so it is replaced by 0.
F1 = [score if score == score else 0 for score in F1]
plt.figure()
lw = 2
# F1 has one more value than t (the last one has no threshold), so it is dropped to
# draw F1[i] at t[i]. This is the same alignment as the threshold selection
# t[argmax(F1)] used in the next cell.
plt.plot(t, F1[:-1], color='blue',
         lw=lw, label='regex')
plt.xlabel('Threshold')
plt.ylabel('F1 score')
plt.title('F1 score')
plt.legend(loc="lower right")
plt.show()

In [None]:
# Threshold with the highest F1 for the regex approach, and the confusion matrix
# obtained when predicting 1 for every score that is not below it.
th_regex = t[numpy.argmax(F1)]
tn, fp, fn, tp = confusion_matrix(
    Y_test.flatten('C'),
    numpy.where(numpy.hstack(regex_test_bin) < th_regex, 0, 1),
).ravel()
max(F1), tn, fp, fn, tp, th_regex

### Random forests

In [None]:
# Precision and recall of the random forest as a function of the decision threshold.
# Y_test is flattened column-major ('F') because pred_test_flat is built by stacking
# the per-label outputs of the forest.
p, r, t = precision_recall_curve(Y_test.flatten('F'), pred_test_flat)
plt.figure()
lw = 2
# precision_recall_curve returns one more precision/recall value than thresholds:
# p[i] and r[i] are obtained with "score >= t[i]", and the final (1, 0) point has no
# threshold. It is dropped so that every value is drawn at its own threshold
# (prepending 0 to the thresholds shifted the curves by one position).
plt.plot(t, p[:-1], color='darkorange',
         lw=lw, label='precision')
plt.plot(t, r[:-1], color='green',
         lw=lw, label='recall')
plt.xlabel('Threshold')
plt.ylabel('Precision / recall')
plt.title('Precision - recall')
plt.legend(loc="lower right")
plt.show()

In [None]:
# F1 score of the random forest for every threshold (p, r, t come from the
# precision-recall cell just above).
F1 = 2 * (p * r) / (p + r)
# 0/0 (precision and recall both null) yields NaN; NaN != NaN, so it is replaced by 0.
F1 = [score if score == score else 0 for score in F1]
plt.figure()
lw = 2
# F1 has one more value than t (the last one has no threshold), so it is dropped to
# draw F1[i] at t[i]. This is the same alignment as the threshold selection
# t[argmax(F1)] used a few cells below.
plt.plot(t, F1[:-1], color='blue',
         lw=lw, label='rf')
plt.xlabel('Threshold')
plt.ylabel('F1 score')
plt.title('F1 score')
plt.legend(loc="lower right")
plt.show()

In [None]:
# Position of the best F1 score of the random forest, and that score.
best_f1_position = numpy.argmax(F1)
best_f1_position, F1[best_f1_position]

In [None]:
# Threshold with the highest F1 for the random forest, and the confusion matrix
# obtained when predicting 1 for every score that is not below it.
th_rf = t[numpy.argmax(F1)]
# The random forest scores are ordered label by label, so the ground truth must be
# flattened column-major ('F'), as in the ROC and precision-recall cells. Using 'C'
# here compared each prediction with the wrong (case, label) pair.
tn, fp, fn, tp = confusion_matrix(
    Y_test.flatten('F'),
    [0 if score < th_rf else 1 for score in pred_test_flat],
).ravel()
max(F1), tn, fp, fn, tp, th_rf

### Logistic regression

In [None]:
# Precision and recall of the logistic regression as a function of the decision threshold.
p, r, t = precision_recall_curve(Y_test.flatten('C'), pred_test_logit_flat)
plt.figure()
lw = 2
# precision_recall_curve returns one more precision/recall value than thresholds:
# p[i] and r[i] are obtained with "score >= t[i]", and the final (1, 0) point has no
# threshold. It is dropped so that every value is drawn at its own threshold
# (prepending 0 to the thresholds shifted the curves by one position).
plt.plot(t, p[:-1], color='darkorange',
         lw=lw, label='precision')
plt.plot(t, r[:-1], color='green',
         lw=lw, label='recall')
plt.xlabel('Threshold')
plt.ylabel('Precision / recall')
plt.title('Precision - recall')
plt.legend(loc="lower right")
plt.show()

In [None]:
# F1 score of the logistic regression for every threshold (p, r, t come from the
# precision-recall cell just above).
F1 = 2 * (p * r) / (p + r)
# 0/0 (precision and recall both null) yields NaN; NaN != NaN, so it is replaced by 0.
F1 = [score if score == score else 0 for score in F1]
plt.figure()
lw = 2
# F1 has one more value than t (the last one has no threshold), so it is dropped to
# draw F1[i] at t[i]. This is the same alignment as the threshold selection
# t[argmax(F1)] used in the next cell.
plt.plot(t, F1[:-1], color='blue',
         lw=lw, label='logit')
plt.xlabel('Threshold')
plt.ylabel('F1 score')
plt.title('F1 score')
plt.legend(loc="lower right")
plt.show()

In [None]:
# Threshold with the highest F1 for the logistic regression, and the confusion
# matrix obtained when predicting 1 for every score that is not below it.
th_logit = t[numpy.argmax(F1)]
tn, fp, fn, tp = confusion_matrix(
    Y_test.flatten('C'),
    numpy.where(numpy.asarray(pred_test_logit_flat) < th_logit, 0, 1),
).ravel()
max(F1), tn, fp, fn, tp, th_logit

### SVM

In [None]:
# Precision and recall of the SVM as a function of the decision threshold.
p, r, t = precision_recall_curve(Y_test.flatten('C'), pred_test_svc_flat)
plt.figure()
lw = 2
# precision_recall_curve returns one more precision/recall value than thresholds:
# p[i] and r[i] are obtained with "score >= t[i]", and the final (1, 0) point has no
# threshold. It is dropped so that every value is drawn at its own threshold
# (prepending 0 to the thresholds shifted the curves by one position).
plt.plot(t, p[:-1], color='darkorange',
         lw=lw, label='precision')
plt.plot(t, r[:-1], color='green',
         lw=lw, label='recall')
plt.xlabel('Threshold')
plt.ylabel('Precision / recall')
plt.title('Precision - recall')
plt.legend(loc="lower right")
plt.show()

In [None]:
# F1 score of the SVM for every threshold (p, r, t come from the precision-recall
# cell just above).
F1 = 2 * (p * r) / (p + r)
# 0/0 (precision and recall both null) yields NaN; NaN != NaN, so it is replaced by 0.
F1 = [score if score == score else 0 for score in F1]
plt.figure()
lw = 2
# F1 has one more value than t (the last one has no threshold), so it is dropped to
# draw F1[i] at t[i]. This is the same alignment as the threshold selection
# t[argmax(F1)] used in the next cell.
plt.plot(t, F1[:-1], color='blue',
         lw=lw, label='svc')
plt.xlabel('Threshold')
plt.ylabel('F1 score')
plt.title('F1 score')
plt.legend(loc="lower right")
plt.show()

In [None]:
# Threshold with the highest F1 for the SVM, and the confusion matrix obtained when
# predicting 1 for every score that is not below it.
th_svc = t[numpy.argmax(F1)]
tn, fp, fn, tp = confusion_matrix(
    Y_test.flatten('C'),
    numpy.where(numpy.asarray(pred_test_svc_flat) < th_svc, 0, 1),
).ravel()
max(F1), tn, fp, fn, tp, th_svc

### LGBM

In [None]:
# Precision and recall as a function of the decision threshold, for LGBM alone
# (dashed lines) and for the LGBM + regex ensemble (solid lines).
p, r, t = precision_recall_curve(Y_test.flatten('C'), pred_test_lgbm_flat)
p_re, r_re, t_re = precision_recall_curve(Y_test.flatten('C'), pred_test_lgbm_regex_flat)
plt.figure()
lw = 2
# precision_recall_curve returns one more precision/recall value than thresholds:
# the i-th value is obtained with "score >= thresholds[i]" and the final (1, 0) point
# has no threshold. It is dropped so that every value is drawn at its own threshold.
plt.plot(t, p[:-1], color='darkorange',
         lw=lw, label='precision (lgbm)', linestyle='--')
plt.plot(t, r[:-1], color='green',
         lw=lw, label='recall (lgbm)', linestyle='--')
plt.plot(t_re, p_re[:-1], color='darkorange',
         lw=lw, label='precision (lgbm + regex)')
plt.plot(t_re, r_re[:-1], color='green',
         lw=lw, label='recall (lgbm + regex)')
# A third pair of (dotted) curves used p_rem / r_rem / t_rem, which are never defined
# in the notebook and raised NameError; it was removed.
plt.xlabel('Threshold')
plt.ylabel('Precision / recall')
plt.title('Precision - recall')
plt.legend(loc="lower right")
plt.show()

In [None]:
# Precision and recall of the LGBM + regex ensemble alone, saved to disk.
p_re, r_re, t_re = precision_recall_curve(Y_test.flatten('C'), pred_test_lgbm_regex_flat)
plt.figure()
lw = 2
# p_re[i] and r_re[i] are obtained with "score >= t_re[i]"; the final (1, 0) point has
# no threshold and is dropped. Each value is now drawn at its own threshold
# (prepending 0 to the thresholds shifted the curves by one position).
plt.plot(t_re, p_re[:-1], color='darkorange',
         lw=lw, label='precision')
plt.plot(t_re, r_re[:-1], color='green',
         lw=lw, label='recall')
plt.xlabel('Threshold')
plt.ylabel('Precision / recall')
plt.title('Precision - recall')
plt.legend(loc="lower right")
plt.savefig('precision_recall', dpi=None, facecolor='w', edgecolor='w')
plt.show()

In [None]:
# F1 score for every threshold, for LGBM alone (p, r, t) and for the LGBM + regex
# ensemble (p_re, r_re, t_re); the inputs come from the precision-recall cells above.
F1 = 2 * (p * r) / (p + r)
# 0/0 (precision and recall both null) yields NaN; NaN != NaN, so it is replaced by 0.
F1 = [score if score == score else 0 for score in F1]
F1_re = 2 * (p_re * r_re) / (p_re + r_re)
F1_re = [score if score == score else 0 for score in F1_re]
plt.figure()
lw = 2
# Each F1 list has one more value than its thresholds (the last one has no
# threshold), so it is dropped to draw the i-th score at the i-th threshold. This is
# the same alignment as the threshold selections in the next cells.
plt.plot(t, F1[:-1], color='blue',
         lw=lw, label='lgbm', linestyle='--')
plt.plot(t_re, F1_re[:-1], color='blue',
         lw=lw, label='lgbm + regex')
plt.xlabel('Threshold')
plt.ylabel('F1 score')
plt.title('F1 score')
plt.legend(loc="lower right")
plt.show()

In [None]:
# Threshold with the highest F1 for LGBM alone, and the confusion matrix obtained
# when predicting 1 for every score that is not below it.
# Note the display order of this cell: tp, fn, fp, tn.
th_lgbm = t[numpy.argmax(F1)]
tn, fp, fn, tp = confusion_matrix(
    Y_test.flatten('C'),
    numpy.where(numpy.asarray(pred_test_lgbm_flat) < th_lgbm, 0, 1),
).ravel()
max(F1), tp, fn, fp, tn, th_lgbm

In [None]:
# Threshold with the highest F1 for the LGBM + regex ensemble, and the confusion
# matrix obtained when predicting 1 for every score that is not below it.
# Note the display order of this cell: tp, fn, fp, tn.
th_re = t_re[numpy.argmax(F1_re)]
tn, fp, fn, tp = confusion_matrix(
    Y_test.flatten('C'),
    numpy.where(numpy.asarray(pred_test_lgbm_regex_flat) < th_re, 0, 1),
).ravel()
max(F1_re), tp, fn, fp, tn, th_re