# Trying out a Regular Expression as a Baseline for High Recall

Kaggle Model 2 uses the Schwartz-Hearst (SW) algorithm to extract candidate
entities and then classifies them with a binary classifier (see
`explore_schwartz_heart_baseline.ipynb` for more info). The SW algorithm
misses any entity that doesn't match the pattern LONG FORM (ACRONYM). In the
evaluation on the Kaggle private data set, this produced a recall of 0.65, so
a model that relies on the SW algorithm for its candidates can reach a recall
of at most 0.65 on that data.
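
To make the limitation concrete, here is a minimal sketch of the LONG FORM (ACRONYM) constraint. It is not the Schwartz-Hearst algorithm itself, which also aligns the acronym's letters with the preceding words; the pattern and the function name `find_parenthesized_acronyms` are assumptions made for illustration.

```python
import re

# Hypothetical, simplified stand-in for the "LONG FORM (ACRONYM)" shape.
# It only looks for a short token in parentheses that starts with a capital.
ACRONYM_IN_PARENS = re.compile(r"\(([A-Z][A-Za-z0-9]{1,9})\)")


def find_parenthesized_acronyms(text):
    """Return every parenthesized acronym-like token found in `text`."""
    return ACRONYM_IN_PARENS.findall(text)


with_acronym = "Expression data came from Genotype-Tissue Expression (GTEx)."
without_acronym = "Samples came from the 1000 Genomes Project."

print(find_parenthesized_acronyms(with_acronym))
print(find_parenthesized_acronyms(without_acronym))
```

A mention with no parenthesized acronym produces no candidate at all, so no downstream classifier can recover it.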

This notebook tries a regular-expression-based extraction method for generating
candidates, which is more flexible than the SW algorithm.
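
As a rough illustration of what regex-based candidate extraction can look like, the sketch below matches runs of capitalized words joined by short lowercase connectors. The pattern, the connector list, and `extract_candidates` are assumptions for this example; the actual rules live in `src.models.regex_model` and may differ.

```python
import re

# Hypothetical pattern, NOT the one used by `src.models.regex_model`:
# a capitalized word followed by one or more capitalized words, optionally
# separated by short lowercase connectors such as "of" or "and".
CANDIDATE = re.compile(
    r"\b[A-Z][A-Za-z]+(?:\s+(?:(?:of|and|the|for|in|on)\s+)*[A-Z][A-Za-z]+)+"
)


def extract_candidates(text):
    """Return all capitalized multi-word phrases found in `text`."""
    return [match.group(0) for match in CANDIDATE.finditer(text)]


sentence = "We used the Database of Genotypes and Phenotypes for this analysis."
print(extract_candidates(sentence))
```

A pattern like this needs no parenthesized acronym, but it depends on capitalization, so it trades the SW algorithm's blind spot for a different one.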


In [1]:
from itertools import chain
import json

import pandas as pd
from thefuzz import fuzz, process

import src.models.regex_model as rm
from src.data.kaggle_repository import KaggleRepository
from src.evaluate.model import evaluate_model, evaluate_kaggle_private

In [4]:
repo = KaggleRepository()

In [24]:
# the `scorer` and `processor` arguments are explained in the notebook
# `defining_a_match_1.ipynb`

evaluation = evaluate_kaggle_private(
    rm.RegexModel(dict()),
    dict(),  # this model doesn't have any configuration params
    scorer=fuzz.partial_ratio,  # use fuzzy string matching
    processor=lambda s: s.lower(),  # convert to lowercase
)
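
For readers without the matching notebook at hand, the sketch below illustrates the general idea behind a partial fuzzy scorer combined with a lowercasing processor. It uses the standard library's `difflib` as a stand-in and is not `thefuzz`'s implementation, so its scores may differ from `fuzz.partial_ratio`; the name `partial_similarity` is hypothetical.

```python
from difflib import SequenceMatcher


def partial_similarity(candidate, label, processor=str.lower):
    """Score 0-100: best similarity between the shorter string and any
    equally long window of the longer string, after applying `processor`."""
    a, b = processor(candidate), processor(label)
    shorter, longer = (a, b) if len(a) <= len(b) else (b, a)
    if not shorter:
        return 0
    best = 0.0
    # Slide a window of the shorter string's length across the longer string.
    for start in range(len(longer) - len(shorter) + 1):
        window = longer[start:start + len(shorter)]
        best = max(best, SequenceMatcher(None, shorter, window).ratio())
    return round(100 * best)


print(partial_similarity("Database of Genotypes and Phenotypes (dbGaP)", "dbgap"))
print(partial_similarity("dbGaP", "gtex"))
```

With this kind of scoring, a prediction counts as a full match whenever the label appears somewhere inside it, regardless of case.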

In [25]:
evaluation = evaluate_model(
    repo, 
    rm.RegexModel(dict()), 
    dict(),
)

evaluation


        Model Evaluation:

        - Run time: 52.208805322647095 seconds, avg: 0.00647510918053418 seconds per sample
        - True Postive Count: 24139, avg: 2.993798834180826 per sample
        - Precision: 0.07442429773418881
        - Recall: 0.6834758480095136
        

In [26]:
stats = evaluation.output_statistics
all_labels = list(chain(*list(map(lambda x: x["labels"], stats["statistics"].values))))
global_stats = list(chain(*list(map(lambda x: x["stats"], stats["statistics"].values))))

In [27]:
stats_df = pd.DataFrame({"labels": all_labels, "stats": global_stats})
stats_df.loc[stats_df["stats"] == "FN", :].groupby("labels").count().sort_values(
    "stats", ascending=False
)

labels,stats
dbgap,3710
database of genotypes and phenotypes,129
gtex,112
1000 genomes project,110
database of genotypes and phenotypes dbgap,83
...,...
genemania,1
genenetwork,1
genenetwork org,1
generation scotland,1


Let's try adding some explicit keywords that we want to include but that don't always match the rules.

In [31]:
keywords = [
    "database of genotypes and phenotypes",
    "dbgap",
    "DART buoy",
    "pisa"
]

evaluation_with_keywords = evaluate_model(
    repo, 
    rm.RegexModel(dict(keywords=keywords)), 
    dict(),
)

evaluation_with_keywords


        Model Evaluation:

        - Run time: 100.09639191627502 seconds, avg: 0.012414286483476996 seconds per sample
        - True Postive Count: 28030, avg: 3.4763735582289472 per sample
        - Precision: 0.0853339990379756
        - Recall: 0.7985299982906957
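
One plausible way to support explicit keywords is a single case-insensitive alternation that runs alongside the main pattern. This is only a sketch under that assumption: `build_keyword_pattern` and `find_keywords` are hypothetical names, and how `RegexModel` actually handles its `keywords` argument is not shown in this notebook.

```python
import re


def build_keyword_pattern(keywords):
    """Compile one case-insensitive pattern that matches any keyword."""
    # Longest first, so a longer keyword wins over a shorter one it contains.
    ordered = sorted(keywords, key=len, reverse=True)
    alternation = "|".join(re.escape(keyword) for keyword in ordered)
    return re.compile(rf"\b(?:{alternation})\b", re.IGNORECASE)


keyword_pattern = build_keyword_pattern(
    ["database of genotypes and phenotypes", "dbgap", "DART buoy", "pisa"]
)


def find_keywords(text):
    """Return every keyword occurrence in `text`, as written in the text."""
    return [match.group(0) for match in keyword_pattern.finditer(text)]


print(find_keywords("Genotypes were downloaded from dbGaP."))
print(find_keywords("taken from the database of genotypes and phenotypes."))
```

The word boundaries keep short keywords such as `pisa` from matching inside longer words.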
        

In [29]:
stats = evaluation_with_keywords.output_statistics
all_labels = list(chain(*list(map(lambda x: x["labels"], stats["statistics"].values))))
global_stats = list(chain(*list(map(lambda x: x["stats"], stats["statistics"].values))))

In [30]:
stats_df = pd.DataFrame({"labels": all_labels, "stats": global_stats})
stats_df.loc[stats_df["stats"] == "FN", :].groupby("labels").count().sort_values(
    "stats", ascending=False
)

labels,stats
gtex,112
1000 genomes project,110
business r d and innovation survey,79
dbsnp,75
dart buoy,73
...,...
generation scotland geneva,1
genereviews,1
generif,1
genes and genomes database,1


The regex doesn't match acronyms, so dbgap was expected to be missed. Let's look at `women s interagency hiv study`, searching by its acronym `wihs`.

In [16]:
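def model_missed_label(label, row):
    """Return True if `label` is in the row's labels and was scored as a false negative (FN)."""
    if label in row["labels"]:
        # "labels" and "stats" are parallel lists, so the same index applies to both
        lbl_idx = row["labels"].index(label)
        return row["stats"][lbl_idx] == "FN"
    return False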


missed_mask = stats["statistics"].apply(lambda row: model_missed_label("wihs", row))
missed = stats.loc[missed_mask, :]
missed

,id,label,statistics
38,407584a01,hers|hiv epidemiology research study|wihs|wome...,"{'labels': ['hiv epidemiology research study',..."
53,5b470b7ef,wihs|women s interagency hiv study|women s int...,"{'labels': ['women s interagency hiv study', '..."
106,9dec63cff,wihs|women s interagency hiv study|women s int...,"{'labels': ['women s interagency hiv study', '..."
114,3a3e7e66c,genbank database|hiv blast|rega hiv 1 automate...,"{'labels': ['women s interagency hiv study', '..."
230,90426633c,wihs|women s interagency hiv study|women s int...,"{'labels': ['women s interagency hiv study', '..."
...,...,...,...
7783,14e5b1a69,wihs|women s interagency hiv study|women s int...,"{'labels': ['women s interagency hiv study', '..."
7815,605a4cb71,wihs|women s interagency hiv study|women s int...,"{'labels': ['women s interagency hiv study', '..."
7820,6295e8bb6,aric|atherosclerosis risk in communities|d a d...,{'labels': ['atherosclerosis risk in communiti...
7829,9755b1f7c,macs|multicenter aids cohort study|multicenter...,"{'labels': ['multicenter aids cohort study', '..."


The missed sample from `b043c048c.json` looks like this:

*The fi ve included datasets were from the USA National Institute on Aging, UK, Germany, France, and the USA **database of genotypes and phenotypes**.*

The dataset name appears entirely in lower case here, so this is an expected miss.