In [1]:
import pandas as pd
import numpy as np
import sklearn

In [4]:
data = pd.read_csv('~/Desktop/expansion_inclusion_dataset.csv')
print(data.shape)

(300, 10)


In [11]:
# Filter for data that's been labeled already
labeled_data = data[data['is_accurate'].notnull()].copy()
l1 = 'is_accurate'
l2 = 'is_clinical_relevant'
labeled_data[l1] = labeled_data[l1].apply(int)
labeled_data[l2] = labeled_data[l2].apply(int)
print(labeled_data.shape)

(74, 10)


In [14]:
l1 = 'is_accurate'
l2 = 'is_clinical_relevant'
labeled_data['y'] = labeled_data[l1].combine(labeled_data[l2], lambda a, b: 1 if a == 1 and b == 1 else 0)

In [16]:
neg_examples = labeled_data[labeled_data['y'] == 0].shape[0]
pos_examples = labeled_data[labeled_data['y'] == 1].shape[0]

print('Negative Examples={}, Positive Examples={}'.format(neg_examples, pos_examples))

Negative Examples=62, Positive Examples=12
     sf                                                cui  \
0   HNE  C0064833|C0205177|C0443467|C1414369|C3853793|C...   
1   LDR       C0013621|C0023185|C0680238|C0681814|C1706386   
2   ADP  C0032400|C0071360|C0332265|C0335296|C0747726|C...   
3   BPL                         C0023047|C0325581|C0998565   
4    GE                                           C0200044   
5   TBD                                  C0075804|C1511331   
6   PLT  C0030705|C0042444|C0243148|C0332293|C0348005|C...   
7  PARS                C0023689|C0035549|C0137710|C4522002   
8    DE                                  C0441471|C4019010   
9   DSW                         C0205148|C0678544|C1421479   

                                                lf  source  \
0        high levels of active neutrophil elastase  pubmed   
1  learner disagreement from experiment resampling  pubmed   
2                  activation demonstrated by poly  pubmed   
3                   boettc

In [98]:
# Generate features
X_str = []
y = []
for _, row in labeled_data.iterrows():
    row = row.to_dict()
    x = {}
    sem_groups = row['semgroups'].split('|')
    for sem_group in sem_groups:
        x[sem_group] = 1.0
    lfs = row['lf'].split('|')
    x['support'] = len(lfs)
    y.append(row['y'])
    
    sem_types = row['semtypes'].split('|')
    for sem_type in sem_types:
        x[sem_type] = 1.0
    X_str.append(x)

In [103]:
from sklearn.feature_extraction import DictVectorizer
dv = DictVectorizer(sparse=False)
X = dv.fit_transform(X_str, y)

In [104]:
from sklearn.linear_model import LogisticRegression

from sklearn.ensemble import RandomForestClassifier

USE_RF = True
if USE_RF:
    estimator = RandomForestClassifier(n_estimators=10)
else:
    estimator = LogisticRegression(solver='liblinear')
estimator.fit(X, y)

RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
                       max_depth=None, max_features='auto', max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=10,
                       n_jobs=None, oob_score=False, random_state=None,
                       verbose=0, warm_start=False)

In [107]:
y_proba = estimator.predict_proba(X)[:, 1]
y_pred = (y_proba >= 0.5).astype(int)

In [108]:
from sklearn.metrics import classification_report, roc_auc_score

roc_auc = roc_auc_score(y, y_proba)
print('ROC AUC={}'.format(roc_auc))

cr = classification_report(y, y_pred)
print(cr)

ROC AUC=1.0
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        62
           1       1.00      1.00      1.00        12

    accuracy                           1.00        74
   macro avg       1.00      1.00      1.00        74
weighted avg       1.00      1.00      1.00        74



array([0. , 0. , 0.2, 0.2, 0.7, 0. , 0. , 0. , 0. , 0. , 0. , 0.2, 0.8,
       0.1, 0. , 0. , 0.7, 0.1, 0.7, 1. , 0.9, 0.1, 0. , 0.1, 0. , 0. ,
       0.1, 0. , 0. , 0. , 0.2, 0. , 0. , 0.1, 0. , 0. , 0.3, 0.6, 0.7,
       0.1, 0. , 0.5, 0.5, 0. , 0.1, 0.1, 0.1, 0.9, 0. , 0. , 0.1, 0.2,
       0. , 0. , 0.6, 0.2, 0.1, 0. , 0. , 0.1, 0. , 0. , 0.1, 0. , 0.2,
       0.1, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.1])