# Imports

In [1]:
%load_ext autoreload
import sys

sys.path.append('..')

In [2]:
%autoreload

import itertools, os, pickle, pandas as pd
from collections import Counter

from chexpert_approximator.data_processor import *
from chexpert_approximator.run_classifier import *

from chexpert_approximator.reload_and_get_logits import *

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [3]:
# We can't display MIMIC-CXR Output: record ids and text are masked before any
# preview is rendered.

DO_BLIND = True
def blind_display(df, n_rows=5):
    """Display the first ``n_rows`` rows of ``df``, masked when ``DO_BLIND`` is set.

    When blinding, every value of the ``rad_id`` index level (if present) is
    replaced by 0, string columns are replaced by the placeholder ``'SAMPLE'``
    and all other columns by NaN. Index level names, their order and the other
    index levels are kept, so the preview still shows the layout of the data.
    ``df`` itself is never modified.

    Args:
        df: DataFrame to preview; may carry a ``rad_id`` index level.
        n_rows: number of leading rows to show (default 5, like ``DataFrame.head``).

    Returns:
        None; the preview is rendered with IPython's ``display``.
    """
    # Only the first rows are ever shown, so mask just those instead of
    # copying and rewriting the whole frame.
    preview = df.head(n_rows).copy()

    if DO_BLIND:
        import numpy as np  # local import: `np` is otherwise only reachable via a star import

        if 'rad_id' in preview.index.names:
            # Replace only the rad_id labels; works for single- and multi-level indexes.
            preview = preview.rename(index=lambda _: 0, level='rad_id')

        for c in preview.columns:
            if pd.api.types.is_string_dtype(preview[c]):
                preview[c] = 'SAMPLE'
            else:
                # np.nan, not np.NaN: the NaN alias was removed in NumPy 2.0.
                preview[c] = np.nan

    display(preview)

# Load the Data

In [4]:
# INSERT YOUR DATA DIR HERE, or set the CHEXPERT_DATA_DIR environment variable.
DATA_DIR = os.environ.get(
    'CHEXPERT_DATA_DIR', '/scratch/chexpert_approximator/processed_data/'
)
# DATA MUST BE STORED IN A FILE `inputs.hdf` under key `with_folds`.
INPUT_PATH, INPUT_KEY = os.path.join(DATA_DIR, 'inputs.hdf'), 'with_folds'

# YOUR CLINICAL BERT MODEL GOES HERE (or set CLINICAL_BERT_MODEL_PATH).
BERT_MODEL_PATH = os.environ.get(
    'CLINICAL_BERT_MODEL_PATH',
    '/data/medg/misc/clinical_BERT/cliniBERT/models/pretrained_bert_tf/bert_pretrain_output_all_notes_300000/',
)

# THIS IS WHERE YOUR PRE-TRAINED CHEXPERT++ MODEL WILL BE WRITTEN (or set CHEXPERT_OUT_DIR).
OUT_CXPPP_DIR = os.environ.get('CHEXPERT_OUT_DIR', '../out/run_1')

# DON'T MODIFY THESE
# Name of the index level holding each row's cross-validation fold.
FOLD = 'Fold'

# Integer class id <-> mention string used as label values in `inputs`.
KEY = {
    0: 'No Mention',
    1: 'Uncertain Mention',
    2: 'Negative Mention',
    3: 'Positive Mention',
}
INV_KEY = {v: k for k, v in KEY.items()}

In [5]:
inputs = pd.read_hdf(INPUT_PATH, key=INPUT_KEY)
# Every index level except the report id and the fold is a label column.
label_cols = [name for name in inputs.index.names if name not in {'rad_id', FOLD}]

In [6]:
# Preview the first rows; with DO_BLIND set, rad_id and text values are masked.
blind_display(inputs)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,Unnamed: 9_level_0,Unnamed: 10_level_0,Unnamed: 11_level_0,Unnamed: 12_level_0,Unnamed: 13_level_0,Unnamed: 14_level_0,Unnamed: 15_level_0,sentence
rad_id,No Finding,Enlarged Cardiomediastinum,Cardiomegaly,Lung Lesion,Airspace Opacity,Edema,Consolidation,Pneumonia,Atelectasis,Pneumothorax,Pleural Effusion,Pleural Other,Fracture,Support Devices,Fold,Unnamed: 16_level_1
0,Positive Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,6,SAMPLE
0,Positive Mention,Negative Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,6,SAMPLE
0,Positive Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,6,SAMPLE
0,Positive Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,4,SAMPLE
0,No Mention,No Mention,No Mention,No Mention,Positive Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,No Mention,4,SAMPLE


# Model
## Data Processor

In [44]:
class CheXpertProcessor(DataProcessor):
    """Processor for the CheXpert approximator.

    Turns rows of the `inputs` frame (labels, `rad_id` and fold stored as index
    levels, text in the `sentence` column) into `InputExample`s, split by the
    `FOLD` index level: `tuning_fold` forms the dev set, `held_out_fold` is
    excluded from training and dev, and every other fold is training data.

    Args:
        tuning_fold: fold returned by `get_dev_examples`.
        held_out_fold: fold excluded from both training and dev examples.
        n_folds: optional fold count; folds are then assumed to be numbered
            `0 .. n_folds - 1`. When None (default), the training folds are
            taken from the fold values present in the frame itself.
    """

    def __init__(self, tuning_fold, held_out_fold, n_folds=None):
        super().__init__()
        self.tuning_fold, self.held_out_fold = tuning_fold, held_out_fold
        self.n_folds = n_folds

    def get_train_examples(self, df):
        """Examples from every fold except the tuning and held-out folds."""
        # Fold ids come from `n_folds` or from the data itself, rather than from
        # a global fold count that this notebook never defines.
        if self.n_folds is None:
            all_folds = set(df.index.get_level_values(FOLD).unique())
        else:
            all_folds = set(range(self.n_folds))
        return self._create_examples(df, all_folds - {self.tuning_fold, self.held_out_fold})

    def get_dev_examples(self, df):
        """Examples from the tuning fold only."""
        return self._create_examples(df, {self.tuning_fold})

    def get_examples(self, df, folds=()):
        """Examples from an arbitrary collection of folds (none by default)."""
        return self._create_examples(df, set(folds))

    def get_labels(self):
        """Map every label column to its class ids (one per entry of `KEY`)."""
        return {label: list(range(len(KEY))) for label in label_cols}

    def _create_examples(self, df, folds):
        """Build one `InputExample` per row of `df` whose fold is in `folds`.

        `guid` is the row's `rad_id` as a string, `text_a` its `sentence`, and
        `label` maps each name in `label_cols` to its integer id via `INV_KEY`.
        """
        df = df[df.index.get_level_values(FOLD).isin(folds)]
        level_of = {name: i for i, name in enumerate(df.index.names)}
        rad_id_pos = level_of['rad_id']
        label_pos = [(label, level_of[label]) for label in label_cols]

        # Zip over index tuples and the text column instead of iterrows():
        # same values, without building a Series for every row.
        return [
            InputExample(
                guid=str(idx[rad_id_pos]),
                text_a=sentence,
                text_b=None,
                label={label: INV_KEY[idx[pos]] for label, pos in label_pos},
            )
            for idx, sentence in zip(df.index, df['sentence'])
        ]

## Running the Model

In [45]:
# Fold 8 is the dev (tuning) fold; fold 9 is excluded from both training and dev.
TUNING_FOLD, HELD_OUT_FOLD = 8, 9
processor = CheXpertProcessor(tuning_fold=TUNING_FOLD, held_out_fold=HELD_OUT_FOLD)

In [13]:
# Restrict the GPUs visible to this kernel; the training call below uses the same list.
os.environ.update(CUDA_VISIBLE_DEVICES='0,1,2')

In [None]:
# Fine-tune Clinical BERT on `inputs`; the trained model is written to OUT_CXPPP_DIR.
out = build_and_train(
    inputs,
    bert_model                  = BERT_MODEL_PATH,
    processor                   = processor,
    # Each label gets one output per mention class in KEY.
    task_dimensions             = {label: len(KEY) for label in label_cols},
    output_dir                  = OUT_CXPPP_DIR,
    gradient_accumulation_steps = 1,
    # Reuse the device list already set in CUDA_VISIBLE_DEVICES instead of repeating it.
    gpu                         = os.environ.get('CUDA_VISIBLE_DEVICES', '0,1,2'),
    do_train                    = True,
    do_eval                     = True,
    seed                        = 42,
    do_lower_case               = False,
    max_seq_length              = 128,
    train_batch_size            = 32,
    eval_batch_size             = 8,
    learning_rate               = 5e-5,
    num_train_epochs            = 5,
    warmup_proportion           = 0.1,
)

/data/medg/misc/clinical_BERT/cliniBERT/models/pretrained_bert_tf/bert_pretrain_output_all_notes_300000/
Max Sequence Length: 112


HBox(children=(IntProgress(value=0, description='Epoch', max=5, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Iteration', max=18840, style=ProgressStyle(description_width=…



HBox(children=(IntProgress(value=0, description='Iteration', max=18840, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Iteration', max=18840, style=ProgressStyle(description_width=…

In [16]:
# Metrics dict returned by build_and_train (eval loss/accuracy, global step, training loss).
out

{'eval_loss': 0.04080065189753456,
 'eval_accuracy': 0.9992550486952994,
 'global_step': 94200,
 'loss': 0.0014591591281715949}