**References**
- [Evaluation metrics](https://www.kaggle.com/code/theoviel/evaluation-metric-folds-baseline#Metric)
- [Bert](https://www.kaggle.com/code/tomohiroh/nbme-bert-for-beginners)


In [1]:
import torch
import numpy as np
import pandas as pd
import plotly.express as px

from itertools import chain
from ast import literal_eval
from sklearn.metrics import f1_score
from sklearn.metrics import precision_recall_fscore_support
from tqdm.notebook import tqdm, trange
from sklearn.model_selection import StratifiedKFold
from transformers import AutoModel, AutoTokenizer

In [2]:
# Seed everything
SEED = 13
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)

In [3]:
# Load prepaired training csv
train_df = pd.read_csv('train_ready.csv',
                       converters={'annotation': lambda x: literal_eval(x),
                                   'location': lambda x: literal_eval(x)})
train_df

Unnamed: 0,id,case_num,pn_num,feature_num,annotation,location,pn_history,feature_text
0,00016_000,0,16,0,[dad with recent heart attcak],[696 724],hpi: 17yo m presents with palpitations. patien...,family history of mi; family history of myocar...
1,00016_001,0,16,1,"[mom with ""thyroid disease]",[668 693],hpi: 17yo m presents with palpitations. patien...,family history of thyroid disorder
2,00016_002,0,16,2,[chest pressure],[203 217],hpi: 17yo m presents with palpitations. patien...,chest pressure
3,00016_003,0,16,3,"[intermittent episodes, episode]","[70 91, 176 183]",hpi: 17yo m presents with palpitations. patien...,intermittent symptoms
4,00016_004,0,16,4,[felt as if he were going to pass out],[222 258],hpi: 17yo m presents with palpitations. patien...,lightheaded
...,...,...,...,...,...,...,...,...
14295,95333_912,9,95333,912,[],[],stephanie madden is a 20 year old woman compla...,family history of migraines
14296,95333_913,9,95333,913,[],[],stephanie madden is a 20 year old woman compla...,female
14297,95333_914,9,95333,914,[photobia],[274 282],stephanie madden is a 20 year old woman compla...,photophobia
14298,95333_915,9,95333,915,[no sick contacts],[421 437],stephanie madden is a 20 year old woman compla...,no known illness contacts


# KFold

In [4]:
# Add kfold markup
skf = StratifiedKFold(n_splits=10)
train_df['stratify_on'] = train_df['case_num'].astype(str) + train_df['feature_num'].astype(str)
train_df['fold'] = -1

for fold, (_, valid_idx) in enumerate(skf.split(train_df['id'], y=train_df['stratify_on'])):
    train_df.loc[valid_idx, 'fold'] = fold
    
train_df.shape

(14300, 10)

# Tokenizer

In [5]:
# Init tokenizer
MODEL_NAME = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [6]:
def loc_list_to_ints(loc_list) -> list:
    """ 
    Convert type of 'location' column to int.
    
    Args:
        loc_list (list): list with locations.
        
    Returns:
        list: converted to int results.
    """
    
    to_return = []
    for loc_str in loc_list:
        loc_strs = loc_str.split(';')
        for loc in loc_strs:
            start, end = loc.split(' ')
            to_return.append((int(start), int(end)))
    return to_return

In [7]:
def tokenize_and_add_labels(tokenizer_, row) -> dict:
    """
    Tokenize and add labels.
    
    Args:
        tokenizer_ (transformers bert-base-uncased): tokenizer for the BERT model.
        row (dict): one row from the dataframe.
        
    Returns:
        dict: tokenized data.
    
    """
    
    # Tokenize input data, and get 'input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping'
    tokenized_inputs = tokenizer_(row['feature_text'],
                                  row['pn_history'],
                                  truncation = 'only_second',
                                  max_length = 416,
                                  padding = 'max_length',
                                  return_offsets_mapping = True)
    
    # Create zero float list 
    labels = [0.0] * len(tokenized_inputs['input_ids']) 
    tokenized_inputs['location_int'] = loc_list_to_ints(row['location'])  # add converted to int locations
    tokenized_inputs['sequence_ids'] = tokenized_inputs.sequence_ids()  # add sequence ids
    
    # Fill labels
    for idx, (seq_id, offsets) in enumerate(zip(tokenized_inputs['sequence_ids'], tokenized_inputs['offset_mapping'])):
        # -100 is our fill value 
        if seq_id is None or seq_id == 0:
            labels[idx] = -100.0
            continue
            
        # Based on location, mark with 1 the righ position of the feature 
        exit = False
        token_start, token_end = offsets
        for feature_start, feature_end in tokenized_inputs['location_int']:
            if exit:
                break
            if token_start >= feature_start and token_end <= feature_end:
                labels[idx] = 1.0
                exit = True
    tokenized_inputs['labels'] = labels
    
    return tokenized_inputs

### Pipeline example

In [8]:
# This is how the data will look like in pipeline
example_df = train_df.copy()

first = example_df.loc[100]
example = {'feature_text': first.feature_text,
           'pn_history': first.pn_history,
           'location': first.location,
           'annotation': first.annotation}

for key in example.keys():
    print(key)
    print(example[key])
    print('=' * 100)

feature_text
heart pounding; heart racing
pn_history
hpi: patient is a 17 yo m with a c/o of palpitations.  palpitations began a few months ago. states that palpitations are sudden, unpredictable and feel like his heart is pounding fast/jumping out of his chest. typically these episodes last 3-4 minutes and resolve on their own. his most recent episode was 2 days ago and lasted about 10 minutes. during this epsiode he felt lightheaded, short of breath and had chest pressure located in the middle of his chest. denies any sweating, changes in hair or bowel movements.
ros: negative except as stated above
pmh: none
meds: takes his roommates adderall to help study
allergies: nkda
pshx: none
fh: mother has a thyroid problem, father had a mi this past year at age 53
sh: denies to
location
['40 52', '55 67', '104 116', '161 178', '161 169;179 183']
annotation
['palpitations', 'Palpitations', 'palpitations', 'heart is pounding', 'heart is fast']


In [9]:
tokenized_inputs = tokenize_and_add_labels(tokenizer, example)
for key in tokenized_inputs.keys():
    print(key)
    print(tokenized_inputs[key])
    print("=" * 100)

input_ids
[101, 2540, 9836, 1025, 2540, 3868, 102, 6522, 2072, 1024, 5776, 2003, 1037, 2459, 10930, 1049, 2007, 1037, 1039, 1013, 1051, 1997, 14412, 23270, 10708, 1012, 14412, 23270, 10708, 2211, 1037, 2261, 2706, 3283, 1012, 2163, 2008, 14412, 23270, 10708, 2024, 5573, 1010, 21446, 1998, 2514, 2066, 2010, 2540, 2003, 9836, 3435, 1013, 8660, 2041, 1997, 2010, 3108, 1012, 4050, 2122, 4178, 2197, 1017, 1011, 1018, 2781, 1998, 10663, 2006, 2037, 2219, 1012, 2010, 2087, 3522, 2792, 2001, 1016, 2420, 3283, 1998, 6354, 2055, 2184, 2781, 1012, 2076, 2023, 20383, 3695, 3207, 2002, 2371, 2422, 4974, 2098, 1010, 2460, 1997, 3052, 1998, 2018, 3108, 3778, 2284, 1999, 1996, 2690, 1997, 2010, 3108, 1012, 23439, 2151, 18972, 1010, 3431, 1999, 2606, 2030, 6812, 2884, 5750, 1012, 20996, 2015, 1024, 4997, 3272, 2004, 3090, 2682, 7610, 2232, 1024, 3904, 19960, 2015, 1024, 3138, 2010, 18328, 2015, 5587, 21673, 2140, 2000, 2393, 2817, 2035, 2121, 17252, 1024, 25930, 2850, 8827, 2232, 2595, 1024, 3904, 1042

# Dataset

In [10]:
class NBMEData(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = self.data.loc[idx]
        tokenized = tokenize_and_add_labels(self.tokenizer, sample)
        
        input_ids = np.array(tokenized['input_ids'])  # for input BERT
        attention_mask = np.array(tokenized['attention_mask'])  # for input BERT
        labels = np.array(tokenized['labels'])
        
        offset_mapping = np.array(tokenized['offset_mapping'])  
        sequence_ids = np.array(tokenized['sequence_ids']).astype('float16') 
        
        return input_ids, attention_mask, labels, offset_mapping, sequence_ids

In [11]:
class NBMETestData(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = self.data.loc[idx]
        tokenized = self.tokenizer(sample['feature_text'],
                                   sample['pn_history'],
                                   truncation = 'only_second',
                                   max_length = 416,
                                   padding = 'max_length',
                                   return_offsets_mapping=True)
        tokenized['sequence_ids'] = tokenized.sequence_ids()

        input_ids = np.array(tokenized['input_ids'])
        attention_mask = np.array(tokenized['attention_mask'])
        offset_mapping = np.array(tokenized['offset_mapping'])
        sequence_ids = np.array(tokenized['sequence_ids']).astype('float16')

        return input_ids, attention_mask, offset_mapping, sequence_ids

In [12]:
class NBMEModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(MODEL_NAME) # BERT model
        self.dropout = torch.nn.Dropout(p=0.2)
        self.classifier = torch.nn.Linear(768, 1) 
    
    def forward(self, input_ids, attention_mask):
        last_hidden_state = self.backbone(input_ids=input_ids, 
                                          attention_mask=attention_mask)[0] 
        logits = self.classifier(self.dropout(last_hidden_state)).squeeze(-1)
        return logits

# Train

In [13]:
class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n = 1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [14]:
BATCH_SIZE = 16
EPOCHS = 3
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

model = NBMEModel().to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

test_folds = (0,)
train_folds = (1, 2, 3, 4, 5, 6, 7)
valid_folds = (8, 9)

# Split training data frame into test, train and valid data frames
test = train_df.loc[train_df['fold'] == test_folds[0]]
for fold in test_folds[1:]:
    test = test.append(train_df.loc[train_df['fold'] == fold]) 
test.reset_index(inplace=True, drop=True)    
        
train = train_df.loc[train_df['fold'] == train_folds[0]]
for fold in train_folds[1:]:
    train = train.append(train_df.loc[train_df['fold'] == fold])
train.reset_index(inplace=True, drop=True)  

valid = train_df.loc[train_df['fold'] == valid_folds[0]]
for fold in valid_folds[1:]:
    valid = valid.append(train_df.loc[train_df['fold'] == fold])
valid.reset_index(inplace=True, drop=True)  

# Init datasets    
test_ds = NBMETestData(test, tokenizer)    
train_ds = NBMEData(train, tokenizer)
valid_ds = NBMEData(valid, tokenizer)

# Init data loaders
test_dl = torch.utils.data.DataLoader(test_ds, 
                                      batch_size=BATCH_SIZE * 2,
                                      pin_memory=True, 
                                      shuffle=False, 
                                      drop_last=False)
train_dl = torch.utils.data.DataLoader(train_ds, 
                                       batch_size=BATCH_SIZE,
                                       pin_memory=True, 
                                       shuffle=True, 
                                       drop_last=True)
valid_dl = torch.utils.data.DataLoader(valid_ds,
                                       batch_size=BATCH_SIZE * 2, 
                                       pin_memory=True, 
                                       shuffle=False, 
                                       drop_last=False)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [15]:
def train_fn(t_dataloader, model_, optimizer_) -> float:
    """
    Args:
        t_dataloader (torch.utils.data.DataLoader): training data loader.
        model_ (bert-base-uncased): in this case it is BERT model.
        optimizer_ (torch.optim.{}): torch.optim.AdamW for example.
        
    Returns:
        float: average training loss.
    """
    
    model_.train()
    t_loss = AverageMeter()
    
    progress_bar = tqdm(t_dataloader)
    for batch in progress_bar:
        optimizer.zero_grad()
        
        input_ids = batch[0].to(DEVICE)
        attention_mask = batch[1].to(DEVICE)
        labels = batch[2].to(DEVICE)
        
        logits = model_(input_ids, attention_mask)
        loss_fct = torch.nn.BCEWithLogitsLoss(reduction='none')
        loss = loss_fct(logits, labels)
        loss = torch.masked_select(loss, labels > -1).mean() 
        loss.backward()
        
        optimizer_.step()
        t_loss.update(val=loss.item(), n=len(input_ids))
        progress_bar.set_postfix(Loss=t_loss.avg)
    
    return t_loss.avg

In [16]:
def valid_fn(v_dataloader, model_) -> float:
    """
    Args:
        v_dataloader (torch.utils.data.DataLoader): validation data loader.
        model_ (bert-base-uncased): in this case it is BERT model.
        
    Returns:
        float: average validation loss.
    """
    
    model_.eval()
    v_loss = AverageMeter()
    
    with torch.no_grad():
        process_bar = tqdm(v_dataloader)
        for batch in process_bar:
            input_ids = batch[0].to(DEVICE)
            attention_mask = batch[1].to(DEVICE)
            labels = batch[2].to(DEVICE)
            
            logits = model_(input_ids, attention_mask)
            loss_fct = torch.nn.BCEWithLogitsLoss(reduction='none')
            loss = loss_fct(logits, labels)
            loss = torch.masked_select(loss, labels > -1).mean()
    
            v_loss.update(val=loss.item(), n=len(input_ids))
            process_bar.set_postfix(Loss=v_loss.avg)
            
    return v_loss.avg

In [17]:
TRAIN = False

# Training process
history = {'train': [], 'valid': []}
if TRAIN:
    best_loss = np.inf

    for epoch in range(EPOCHS):
        print("Epoch: {}/{}".format(epoch + 1, EPOCHS))
        
        train_loss = train_fn(train_dl, model, optimizer)
        history['train'].append(train_loss)

        valid_loss = valid_fn(valid_dl, model)
        history['valid'].append(valid_loss)

        # Save model
        if valid_loss < best_loss:
            best_loss = valid_loss
            torch.save(model.state_dict(), 'nbme.pth')

# Evaluation


In [18]:
def micro_f1(preds, truths) -> float:
    """
    Micro f1 on binary arrays.

    Args:
        preds (list of lists of ints): Predictions.
        truths (list of lists of ints): Ground truths.

    Returns:
        float: f1 score.
        
    Example:
        preds = [[0, 0, 1], [0, 0, 0]]
        truths = [[0, 0, 1], [1, 0, 0]]
        >>> 0.6666666666
    """
    
    preds = np.concatenate(preds)
    truths = np.concatenate(truths)
    return f1_score(truths, preds)


def spans_to_binary(spans, length=None) -> np.array:
    """
    Converts spans to a binary array indicating whether each character is in the span.

    Args:
        spans (list of lists of two ints): Spans.

    Returns:
        np array [length]: Binarized spans.
        
    Example:
        spans = [(0, 5), (10, 15)]
        >>> array([1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1.])
    """
    
    length = np.max(spans) if length is None else length
    binary = np.zeros(length)
    for start, end in spans:
        binary[start:end] = 1
    return binary


def span_micro_f1(preds, truths) -> float:
    """
    Micro f1 on spans.

    Args:
        preds (list of lists of two ints): Prediction spans.
        truths (list of lists of two ints): Ground truth spans.

    Returns:
        float: f1 score.
        
    Example:
        preds = [[(0, 1)], [(2, 3), (4, 5)], [(6, 7)], [(10, 11), (12, 13), (14, 15)]]
        truths = [[(0, 1)], [(2, 3), (4, 5)], [(6, 7)], [(8, 11), (12, 13), (14, 15)]]
        >>> 0.8750000000000001
    """
    
    bin_preds = []
    bin_truths = []
    for pred, truth in zip(preds, truths):
        if not len(pred) and not len(truth):
            continue
        length = max(np.max(pred) if len(pred) else 0, np.max(truth) if len(truth) else 0)
        bin_preds.append(spans_to_binary(pred, length))
        bin_truths.append(spans_to_binary(truth, length))
    return micro_f1(bin_preds, bin_truths)


def create_spans_for_scoring(locations) -> list:
    """
    Args:
        locations (pd.Series): raw locations 
    
    Example:
        locations = [['0 1'], ['2 3', '4 5'], ['6 7'], [], ['10 11;12 13', '14 15']]
        >>> [[(0, 1)], [(2, 3), (4, 5)], [(6, 7)], [], [(10, 11), (12, 13), (14, 15)]]
    """
    
    spans = []
    for row in locations:
        span = []
        for loc_c in row:
            for loc in loc_c.split(';'):
                s, e = loc.split(' ')
                span.append((int(s), int(e)))
        spans.append(span)
    
    return spans

In [19]:
def sigmoid(z):
    """
    Standard sigmoid function.
    """
    
    return 1 / (1 + np.exp(-z))


def get_location_predictions(predictions, offset_mapping, sequence_ids, for_submission=False) -> list:
    """
    Convert raw predictions to understandable location points.
    
    Args:
        predictions (list): model output.
        offset_mapping (list): tokenizer product.
        sequence_ids (list): tokenizer product.
        for_submission (bool): set True if converting predictions fro submission.
        
    Returns:
        list: converted locations.
    """
    
    all_predictions = []
    for preds, offsets, seq_ids in zip(predictions, offset_mapping, sequence_ids):
        preds = sigmoid(preds)
        start_idx = None
        
        current_preds = []
        for p, o, s_id in zip(preds, offsets, seq_ids):
            if not s_id:
                continue
            if p > 0.5:
                if start_idx is None:
                    start_idx = o[0]
                end_idx = o[1]
            elif start_idx is not None:
                if for_submission:
                    current_preds.append(f'{start_idx} {end_idx}')
                else:
                    current_preds.append((start_idx, end_idx))
                start_idx = None
        if for_submission:
            all_predictions.append('; '.join(current_preds))
        else:
            all_predictions.append(current_preds)
    return all_predictions

In [20]:
model.load_state_dict(torch.load('nbme.pth', map_location=DEVICE))

def get_predictions(model_, dataloader, for_submission=False) -> list:
    """
    Input test data into the model, convert outputs and return ready results.
    
    Args:
        model_ (bert-base-uncased): in this case it is BERT model.
        dataloader (torch.utils.data.DataLoader): data loader.
        for_submission (bool): flag for get_location_predictions() function.
        
    Returns:
        float: converted locations.
    """
    
    model_.eval()
    preds = []
    offsets = []
    seq_ids = []

    with torch.no_grad():
        for batch in tqdm(test_dl):
            input_ids = batch[0].to(DEVICE)
            attention_mask = batch[1].to(DEVICE)
            offset_mapping = batch[2]
            sequence_ids = batch[3]

            logits = model(input_ids, attention_mask)
            preds.append(logits.cpu().numpy())
            offsets.append(offset_mapping.numpy())
            seq_ids.append(sequence_ids.numpy())

    preds = np.concatenate(preds, axis=0)
    offsets = np.concatenate(offsets, axis=0)
    seq_ids = np.concatenate(seq_ids, axis=0)
    return get_location_predictions(preds, offsets, seq_ids, for_submission=for_submission) # convert predicted locations to spans

pred_spans = get_predictions(model, test_dl)
truth_df = test['location']  # get true locations from the test dataset 
truth_spans = create_spans_for_scoring(truth_df)  # convert locations to spans

  0%|          | 0/45 [00:00<?, ?it/s]

In [21]:
# Calculate model score
span_micro_f1(pred_spans, truth_spans)

0.8219291014014839

# Submission

In [22]:
def create_test_df() -> pd.DataFrame:
    """
    Creates test dataframe for submission.
    
    Returns:
        pd.DataFrame: merged and preprocessed dataframe.
    """
    
    feats = pd.read_csv('./data/features.csv')
    notes = pd.read_csv('./data/patient_notes.csv')
    test = pd.read_csv('./data/test.csv')

    merged = test.merge(notes, how='left')
    merged = merged.merge(feats, how='left')

    def process_feature_text(text):
        # Add here new prepocessing functions
        return text.replace('-OR-', ';-').replace('-', ' ')
    
    merged['feature_text'] = [process_feature_text(x) for x in merged['feature_text']]
    
    print(merged.shape)
    return merged

In [25]:
test_df = create_test_df()
test_ds = NBMETestData(test_df, tokenizer)
test_dl = torch.utils.data.DataLoader(test_ds, 
                                      batch_size=BATCH_SIZE * 2, 
                                      pin_memory=True, 
                                      shuffle=False, 
                                      drop_last=False)

location_preds = get_predictions(model, test_dl, for_submission=True)
test_df['location'] = location_preds
test_df[['id', 'location']].to_csv('bert_submission.csv', index=False)
pd.read_csv('bert_submission.csv').head()

(5, 6)


  0%|          | 0/1 [00:00<?, ?it/s]

Unnamed: 0,id,location
0,00016_000,696 699
1,00016_001,668 693
2,00016_002,203 217
3,00016_003,70 91
4,00016_004,
