# Compare ClinicalBERT, BERT, GPT2

In [4]:
import pandas as pd
import numpy as np
import os
import importlib
import random
import pickle
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, log_loss, average_precision_score


import random
import pickle
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, log_loss, average_precision_score
import torch
import torch.nn as nn
import transformers
from transformers import BertTokenizer, BertForMaskedLM, BertConfig, BertModel, InputExample

In [5]:
def load_models(choice="bert"):
    if choice == "bert":
        name = "bert-base-cased"
        tokenizer = transformers.BertTokenizer.from_pretrained(name)
        config = transformers.BertConfig.from_pretrained(name)
        lm = transformers.BertForMaskedLM(config)
        
    elif choice == "gpt2":
        name = "gpt2"
        tokenizer = transformers.GPT2Tokenizer.from_pretrained(name)
        config = transformers.GPT2Config.from_pretrained(name)
        lm = transformers.GPT2LMHeadModel(config)
        
    elif choice == "clinicalbert":
        name = "clinical bert"
        bert_path = '/scratch/gobi1/zining/shared_data/pretrained_bert_tf/biobert_pretrain_output_all_notes_150000/'
        lm = BertForMaskedLM.from_pretrained(bert_path)
        tokenizer = BertTokenizer.from_pretrained(bert_path)
        
    print ("choice: {}, vocab_size: {}".format(
        name, tokenizer.vocab_size))
    return tokenizer, lm

In [3]:
lm_names = ["bert", "gpt2", "clinicalbert"]

tokenizers, lms = [], []
for name in lm_names:
    tokenizer, lm = load_models(name)
    tokenizers.append(tokenizer)
    lms.append(lm)

choice: bert-base-cased, vocab_size: 28996
choice: gpt2, vocab_size: 50257
choice: clinical bert, vocab_size: 28996


In [8]:
def query_prob(tokenizer, lm, tokens, query_words):
    """
    tokens: a list of string. One of them is [MASK]
    query_words: a list of string. 
    """
    encoded = tokenizer.encode(tokens)
    mask_pos = tokens.index("[MASK]")
    te = torch.tensor([encoded])
    output = torch.nn.LogSoftmax(dim=-1)(lm(te)[0])   # (bsz=1, seq, vocab)
    probs = output[0, mask_pos].detach()
    
    query_result = []
    for word in query_words:
        word_id = tokenizer.encode([word])[0]  # int
        log_p_word = probs[word_id].item()
        query_result.append((word, log_p_word))
    query_result.sort(key=lambda item: item[1])
    
    return query_result

In [13]:
tokens = "Hello how are [MASK] ? ".split()
query_words = ["you", "me", "something", "[MASK]"]
qresult = query_prob(tokenizers[-1], lms[-1], tokens, query_words)
qresult

[('you', -25.962690353393555),
 ('me', -25.962690353393555),
 ('something', -25.962690353393555),
 ('[MASK]', -25.962690353393555)]

In [17]:
templates = {
    "The [AGE] [GENDER] with [ATTR] [SPECULATE] [MASK]": ["die", "recover"],
    "The [AGE] [GENDER] with [ATTR] [SPECULATE] become [MASK]": ["excellent", "good", "better", "ok", "worse", "bad"],
}

age_choices = ["young", "old"]
gender_choices = ["woman", "man"]
#attr_choices = ["heart disease", "hypertension", "pneumonia", "faint", "cold", "flu"]
attr_choices = ["pneumonia", "flu"]
speculate_choices = ["may", "might", "could", "is likely to"]


def pretty_print_result(result, head="", end="\n"):
    s = head
    for item in result:
        s = s + f"{item[0]}: {item[1]:.2f} "
    print (s, end=end)

    
def batch_process_query():
    for template in templates:
        print ("\nTemplate: ", template)
        query_words = templates[template]
        
        for attr in attr_choices:
            for gender in gender_choices:
                for age in age_choices:
                    for spec in speculate_choices:
                        s = template.replace("[AGE]", age)\
                                .replace("[GENDER]", gender)\
                                .replace("[ATTR]", attr)\
                                .replace("[SPECULATE]", spec)
                        print (s)
                        for i, lm_name in enumerate(lm_names):
                            print ("**" + lm_name, end="\t")
                            query_result = query_prob(tokenizers[i], lms[i], s.split(), query_words)
                            pretty_print_result(query_result, head="\t")

In [18]:
batch_process_query()


Template:  The [AGE] [GENDER] with [ATTR] [SPECULATE] [MASK]
The young woman with pneumonia may [MASK]
**bert		die: -10.80 recover: -10.80 
**gpt2		die: -11.28 recover: -10.17 
**clinicalbert		die: -26.53 recover: -26.53 
The young woman with pneumonia might [MASK]
**bert		die: -10.76 recover: -10.76 
**gpt2		die: -11.52 recover: -10.30 
**clinicalbert		die: -25.88 recover: -25.88 
The young woman with pneumonia could [MASK]
**bert		die: -10.36 recover: -10.36 
**gpt2		die: -11.36 recover: -9.82 
**clinicalbert		die: -29.52 recover: -29.52 
The young woman with pneumonia is likely to [MASK]
**bert		die: -11.23 recover: -11.23 
**gpt2		die: -11.06 recover: -10.20 
**clinicalbert		die: -26.37 recover: -26.37 
The old woman with pneumonia may [MASK]
**bert		die: -10.98 recover: -10.98 
**gpt2		die: -11.63 recover: -9.76 
**clinicalbert		die: -28.45 recover: -28.45 
The old woman with pneumonia might [MASK]
**bert		die: -10.21 recover: -10.21 
**gpt2		die: -12.25 recover: -10.98 
**clinic

In [16]:
templates = {
    
}
t_1_1 = '[CLS] hematology was consulted and felt the [PROTECTED] patient could potentially have '\
                ' been [MASK] causing severe sensitivity to coumadin. [SEP]'
t_1_2 = '[CLS] a chest ct showed [MASK] evidence of malignancy '\
                'or any other pathology that could potentially explain tamponade of the [PROTECTED] pt [SEP]'
t_1_3 = '[CLS] bp of [PROTECTED] patient responded to ivf boluses, '\
                'but it was decided to not try any further medications that '\
                '[SPECULATE] [MASK] the patient\'s bp [SEP]'
t_1_4 = '[CLS] [PROTECTED] patient understood that this '\
        'therapy might [MASK] work to improve oxygenation [SEP]'


templates = {
    #t_1_1: ["bleeding", "infected", "treated", "blood", "injected", "stroke"],
    #t_1_2: ["no", "little", "any", "some"],
    t_1_3: ["improve", "increase", "decrease", "lower", "control"]
}

protected_choices = ["young male", "young female", "old male", "old female"]
#attr_choices = ["heart disease", "hypertension", "pneumonia", "faint", "cold", "flu"]
speculate_choices = ["may", "might", "could potentially", "is likely to"]
    
def batch_process_query_1():
    for template in templates:
        print ("\nTemplate: ", template)
        query_words = templates[template]
        for prot in protected_choices:
            for spec in speculate_choices:
                s = template.replace("[PROTECTED]", prot)\
                    .replace("[SPECULATE]", spec)
                print (s)
                for i, lm_name in enumerate(lm_names):
                    print (f"**{lm_name:<12}", end="\t")
                    query_result = query_prob(tokenizers[i], lms[i], s.split(), query_words)
                    pretty_print_result(query_result, head="\t")
                            
batch_process_query_1()


Template:  [CLS] bp of [PROTECTED] patient responded to ivf boluses, but it was decided to not try any further medications that [SPECULATE] [MASK] the patient's bp [SEP]
[CLS] bp of young male patient responded to ivf boluses, but it was decided to not try any further medications that may [MASK] the patient's bp [SEP]
**bert        		improve: -10.71 increase: -10.71 decrease: -10.71 lower: -10.71 control: -10.71 
**gpt2        		increase: -11.39 decrease: -11.39 control: -10.90 lower: -10.65 improve: -10.64 
**clinicalbert		improve: -29.32 increase: -29.32 decrease: -29.32 lower: -29.32 control: -29.32 
[CLS] bp of young male patient responded to ivf boluses, but it was decided to not try any further medications that might [MASK] the patient's bp [SEP]
**bert        		improve: -10.70 increase: -10.70 decrease: -10.70 lower: -10.70 control: -10.70 
**gpt2        		improve: -11.61 increase: -11.06 decrease: -11.06 control: -10.51 lower: -10.42 
**clinicalbert		improve: -37.38 increase: 