### Evidence Inference

This task takes two inputs: (1) the full text of an article describing a clinical trial, and (2) an "evidence prompt" that specifies an intervention, a comparator, and an outcome of interest. For example, if the trial compared over-the-counter headache medicines, the prompt might read:

> With respect to *duration of headache* (outcome), characterize the difference between *aspirin* (intervention) and *placebo* (comparator).

A prompt for a different trial follows the same template:

> With respect to *frequency of contractions*, characterize the difference between *antibiotics* and *placebo*.

The task is then to infer the reported finding for this triplet. This is framed as a classification task with three categories: *significantly decreased*, *no significant difference*, and *significantly increased*.
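As a rough illustration of the two pieces described above, the sketch below renders a prompt from an (outcome, intervention, comparator) triplet and encodes the answer as an integer class. The helper names `render_prompt` and `encode_label` are hypothetical and not part of the project code. The codes 0 and 1 match the `Answer_Vals` values that appear in the pilot data; the code -1 for *significantly decreased* is an assumption, since no such row is shown here.

```python
# Hypothetical helpers for illustration; not taken from the project code.
LABEL_TO_INT = {
    "significantly decreased": -1,   # assumption: this code is not shown in the pilot data
    "no significant difference": 0,  # matches Answer_Vals in the pilot data
    "significantly increased": 1,    # matches Answer_Vals in the pilot data
}


def render_prompt(outcome: str, intervention: str, comparator: str) -> str:
    """Fill the evidence-prompt template with one triplet."""
    return (f"With respect to {outcome}, characterize the difference "
            f"between {intervention} and {comparator}.")


def encode_label(label: str) -> int:
    """Map an answer string to its class id, ignoring case and outer whitespace."""
    return LABEL_TO_INT[label.strip().lower()]


# Triplet values copied from the first pilot prompt.
example = {
    "Outcome": "venous ulcer healing",
    "Intervention": "two-layer bandage",
    "Comparator": "compression stocking",
}
prompt_text = render_prompt(example["Outcome"], example["Intervention"], example["Comparator"])
label_id = encode_label("No significant difference")
print(prompt_text)
print(label_id)
```

Normalizing case before the lookup matters because the prompt file and the annotation file capitalize the answer strings differently.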



In [43]:
import random

In [44]:
import numpy as np
import pandas as pd
import torch

In [45]:
import preprocessor

<module 'preprocessor' from '/Users/byron/dev/evidence-inference/preprocessor.py'>

In [62]:
PROMPT_ID_COL_NAME = "RowID"

In [46]:
prompts = pd.read_csv("pilot_run_data/prompts.csv")
annotations = pd.read_csv("pilot_run_data/annotations.csv")

In [47]:
prompts.head()

| Unnamed: 0 | RowID | Outcome | Intervention | Comparator | Answer | Reasoning | XML | Answer_Vals |
|---|---|---|---|---|---|---|---|---|
| 0 | 1 | venous ulcer healing | two-layer bandage | compression stocking | no significant difference | However, the observed differences were not sta... | 3298351 | 0 |
| 1 | 2 | venous ulcer healing | four-layer bandage | compression stocking | no significant difference | However, the observed differences were not sta... | 3298351 | 0 |
| 2 | 3 | venous ulcer healing | compression therapy | no treatment | significantly increased | We compared the dynamics of ulcer healing of p... | 3298351 | 1 |
| 3 | 4 | venous ulcer healing | four-layer bandage | two-layer bandage | no significant difference | What is more, no statistically significant adv... | 3298351 | 0 |
| 4 | 5 | CEAP scores | four-layer bandage | two-layer bandage | no significant difference | The statistical analysis did not reveal any st... | 3298351 | 0 |


In [48]:
annotations.head()

| Unnamed: 0 | UserId | RowID | PMCID | Selection | Annotation | Outcome | Comparator | Intervention | Invalid Prompt | Prompt Reason | Answer_Val | In Abstract |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 82.0 | 2875419 | Significantly increased | Weight gain was higher for glargine (differenc... | Weight Gain | detemir | glargine | 0 |  | 1 | 1 |
| 1 | 0 | 75.0 | 3281242 | Significantly increased | Continuous abstinence was higher for varenicli... | continuous abstienence at 24 week follow up | placebo | varenicline | 0 |  | 1 | 1 |
| 2 | 0 | 64.0 | 1764008 | No significant difference | Differences in the duration of the active-firs... | duration of labor stages | epidural analgesia | Meperidine analgesia | 0 |  | 0 | 1 |
| 3 | 0 | 79.0 | 2875419 | No significant difference | The proportions of patients achieving A1C <7% ... | number of patients with A1C <7% | glargine | detemir | 0 |  | 0 | 0 |
| 4 | 0 | 13.0 | 2366143 | No significant difference | No significant differences in the interval to ... | neonatal outcomes | antibiotic therapy and tocolysis | tocolysis | 0 |  | 0 | 1 |


Create a temporary validation set for experimentation. ***TODO***: once we collect a reasonable amount of data, we will define official train/val/test splits.

In [34]:
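# split at the article level so that no article ends up in both train and validation
unique_articles = annotations["PMCID"].unique()
train_set_size = int(0.8 * len(unique_articles))
# sample 80% of the articles without replacement (note: no seed is set, so the split varies per run)
train_doc_ids = np.random.choice(unique_articles, train_set_size, replace=False)
val_doc_ids = [doc_id for doc_id in unique_articles if doc_id not in train_doc_ids]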
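Because the split above is drawn without a fixed seed, the validation set changes every time the cell is rerun. The following is a minimal sketch of a reproducible article-level split using only the standard library. The function `split_articles` and its `seed` parameter are illustrative assumptions rather than part of the notebook, and the example ids are simply PMCIDs that appear in the pilot data.

```python
import random


def split_articles(article_ids, train_fraction=0.8, seed=0):
    """Split unique article ids into train/validation lists, reproducibly."""
    # Sort first so the result does not depend on the input order.
    unique_ids = sorted(set(article_ids))
    rng = random.Random(seed)  # local generator; leaves global random state untouched
    train_size = int(train_fraction * len(unique_ids))
    train_ids = set(rng.sample(unique_ids, train_size))
    val_ids = [a for a in unique_ids if a not in train_ids]
    return sorted(train_ids), val_ids


# Article ids with a repeat, as happens when one article has several annotations.
pmcids = [2875419, 3281242, 1764008, 2875419, 2366143, 3298351]
train_ids, val_ids = split_articles(pmcids)
print(len(train_ids), len(val_ids))
```

Splitting on article ids rather than on prompt rows keeps every prompt about a given article on the same side of the split.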

In [39]:
inference_vectorizer = preprocessor.get_inference_vectorizer(article_ids=train_doc_ids)

In [74]:
training_prompts = prompts[prompts['XML'].isin(train_doc_ids)]

# Filter out prompts for which we have no annotations.
# This was actually just one case; not sure what was going on there.
def have_annotations_for_prompt(prompt_id):
    return (annotations[PROMPT_ID_COL_NAME] == prompt_id).any()

# A boolean mask keeps the result a DataFrame, so there is no need to rebuild it from rows.
has_annotations = training_prompts[PROMPT_ID_COL_NAME].apply(have_annotations_for_prompt)
training_prompts = training_prompts[has_annotations]

In [75]:
type(training_prompts)

pandas.core.frame.DataFrame

In [76]:
# vectorize each training prompt into a dict of token-id lists (keys include 'C', 'I', 'O', 'article')
train_Xy = []
for prompt_id in training_prompts[PROMPT_ID_COL_NAME].values:
    Xy_dict = inference_vectorizer.vectorize(training_prompts, prompt_id, include_lbls=True, annotations_df=annotations)
    train_Xy.append(Xy_dict)

In [77]:
train_Xy[0]

{'C': [240, 855],
 'I': [385, 481, 168],
 'O': [825, 805, 411],
 'article': [782, 234, 563, 769, 324, 563, 240, 737, 137, 481,
  240, 760, 437, 825, 806, 795, 94, 94, 470, 769,
  121, 563, 769, 680, 829, 783, 232, 769, 317, 563,
  825, 805, 411, 835, 794, 841, 769, 816, 563, 240,
  737, 158, 833, 158, 573, 803, 137, 385, 481, 168,
  760, 94, 504, 137, 519, 400, 563, 58, 590, 748,
  388, 825, 807, 829, 739, 776, 400, 252, 563, 47,
  78, 846, 137, 5, 30, 76, 515, 119, 178, 54,
  137, 83, 850, 769, 166, 118, 829, 72, 850, 137,
  769, 511, 829, 73, 590, 834, 645, 468, 779, 401,
  383, 795, 841, 769, 629, 803, 481, 758, 627, 385,
  481, 240, 137, 841, 769, 816, 563, 240, 737, 220,
  430, 437, 769, 200, 563, 534, 481, 240, 240, 337,
  53, 527, 182, 617,
  ... (output truncated)
  

In [83]:
# every prompt whose article is outside the training articles goes to validation
val_prompts = prompts[~prompts['XML'].isin(train_doc_ids)]
val_Xy = []
for prompt_id in val_prompts[PROMPT_ID_COL_NAME].values:
    Xy_dict = inference_vectorizer.vectorize(val_prompts, prompt_id, include_lbls=True, annotations_df=annotations)
    val_Xy.append(Xy_dict)


In [86]:
val_Xy[0]

{'C': [855, 855, 563, 803, 855, 855],
 'I': [802, 855, 716, 855, 855],
 'O': [855],
 'article': [782, 855, 563, 855, 855, 437, 855, 855, 137, 367,
  577, 94, 94, 167, 769, 264, 855, 263, 855, 855,
  855, 855, 315, 855, 855, 257, 776, 741, 829, 855,
  783, 232, 769, 326, 563, 263, 855, 803, 295, 855,
  855, 855, 855, 565, 855, 855, 137, 855, 578, 437,
  855, 611, 846, 94, 519, 803, 425, 137, 855, 846,
  841, 717, 855, 550, 855, 304, 137, 855, 855, 7,
  855, 834, 646, 161, 783, 566, 563, 779, 401, 855,
  400, 855, 803, 855, 855, 855, 855, 855, 6, 521,
  855, 591, 831, 76, 400, 855, 566, 855, 802, 855,
  6, 521, 855, 591, 831, 76, 137, 400, 855, 566,
  855, 263, 62, 521, 855, 591, 267, 76, 550, 107,
  855, 834, 855, 855, 137, 855, 855, 855, 834, 508,
  163,
  ... (output truncated)
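The id 855 recurs throughout the validation example. That would be consistent with a single out-of-vocabulary index, given that the vectorizer was built from the training articles only. This is a guess, since the internals of `preprocessor` are not shown here. The toy sketch below illustrates the effect under that assumption; `build_vocab` and `encode` are hypothetical stand-ins, not the project's functions.

```python
# Toy illustration of out-of-vocabulary handling; not the project's preprocessor.
def build_vocab(train_tokens):
    """Assign an id to each distinct training token and reserve one extra id for unknowns."""
    vocab = {tok: i for i, tok in enumerate(sorted(set(train_tokens)))}
    unk_id = len(vocab)  # assumption: unknown tokens share the last index
    return vocab, unk_id


def encode(tokens, vocab, unk_id):
    """Map tokens to ids, sending anything unseen at training time to unk_id."""
    return [vocab.get(tok, unk_id) for tok in tokens]


vocab, unk_id = build_vocab("the ulcer healed after compression therapy".split())
val_ids = encode("the headache improved after aspirin".split(), vocab, unk_id)
print(val_ids)  # tokens never seen in training all collapse to unk_id
```

If the same mechanism is at work in the notebook, a small training vocabulary would explain why so much of the validation article collapses to one id, and the effect should shrink as more training articles are collected.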
 