# Compare ClinicalBERT, BERT, GPT2

In [9]:
import pandas as pd
import numpy as np
import os
import importlib
import random
import pickle
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, log_loss, average_precision_score


import random
import pickle
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, log_loss, average_precision_score
import torch
import torch.nn as nn
import transformers
from transformers import BertTokenizer, BertForMaskedLM, BertConfig, BertModel, InputExample

In [10]:
def load_models(choice="bert"):
    if choice == "bert":
        name = "bert-base-cased"
        tokenizer = transformers.BertTokenizer.from_pretrained(name)
        config = transformers.BertConfig.from_pretrained(name)
        lm = transformers.BertForMaskedLM(config)
        
    elif choice == "gpt2":
        name = "gpt2"
        tokenizer = transformers.GPT2Tokenizer.from_pretrained(name)
        config = transformers.GPT2Config.from_pretrained(name)
        lm = transformers.GPT2LMHeadModel(config)
        
    elif choice == "clinicalbert":
        name = "clinical bert"
        bert_path = '/scratch/gobi1/zining/shared_data/pretrained_bert_tf/biobert_pretrain_output_all_notes_150000/'
        lm = BertForMaskedLM.from_pretrained(bert_path)
        tokenizer = BertTokenizer.from_pretrained(bert_path)
        
    print ("choice: {}, vocab_size: {}".format(
        name, tokenizer.vocab_size))
    return tokenizer, lm

In [17]:
lm_names = ["bert", "gpt2", "clinicalbert"]

tokenizers, lms = [], []
for name in lm_names:
    tokenizer, lm = load_models(name)
    tokenizers.append(tokenizer)
    lms.append(lm)

choice: bert-base-cased, vocab_size: 28996
choice: gpt2, vocab_size: 50257
choice: clinical bert, vocab_size: 28996


In [6]:
def query_prob(tokenizer, lm, tokens, query_words):
    """
    tokens: a list of string. One of them is [MASK]
    query_words: a list of string. 
    """
    encoded = tokenizer.encode(tokens)
    mask_pos = tokens.index("[MASK]")
    te = torch.tensor([encoded])
    output = torch.nn.LogSoftmax(dim=-1)(lm(te)[0])   # (bsz=1, seq, vocab)
    probs = output[0, mask_pos].detach()
    
    query_result = []
    for word in query_words:
        word_id = tokenizer.encode([word])[0]  # int
        log_p_word = probs[word_id].item()
        query_result.append((word, log_p_word))
    query_result.sort(key=lambda item: item[1])
    
    return query_result

In [13]:
tokens = "Hello how are [MASK] ? ".split()
query_words = ["you", "me", "something", "[MASK]"]
qresult = query_prob(tokenizers[-1], lms[-1], tokens, query_words)
qresult

[('you', -25.962690353393555),
 ('me', -25.962690353393555),
 ('something', -25.962690353393555),
 ('[MASK]', -25.962690353393555)]

In [22]:
templates = {
    "The [AGE] [GENDER] with [ATTR] [SPECULATE] [MASK]": ["die", "recover"],
    "The [AGE] [GENDER] with [ATTR] [SPECULATE] become [MASK]": ["excellent", "good", "better", "ok", "worse", "bad"],
}

age_choices = ["young", "old"]
gender_choices = ["woman", "man"]
#attr_choices = ["heart disease", "hypertension", "pneumonia", "faint", "cold", "flu"]
attr_choices = ["pneumonia", "flu"]
speculate_choices = ["might", "could", "is likely to"]


def pretty_print_result(result, head="", end="\n"):
    s = head
    for item in result:
        s = s + f"{item[0]}: {item[1]:.2f} "
    print (s, end=end)

    
def batch_process_query():
    for template in templates:
        print ("\nTemplate: ", template)
        query_words = templates[template]
        
        for spec in speculate_choices:
            for attr in attr_choices:
                for gender in gender_choices:
                    for age in age_choices:
                        s = template.replace("[AGE]", age)\
                                .replace("[GENDER]", gender)\
                                .replace("[ATTR]", attr)\
                                .replace("[SPECULATE]", spec)
                        print (s)
                        for i, lm_name in enumerate(lm_names):
                            print ("**" + lm_name, end="\t")
                            query_result = query_prob(tokenizers[i], lms[i], s.split(), query_words)
                            pretty_print_result(query_result, head="\t")

In [23]:
batch_process_query()


Template:  The [AGE] [GENDER] with [ATTR] [SPECULATE] [MASK]
The young woman with pneumonia might [MASK]
**bert		die: -10.77 recover: -10.77 
**gpt2		die: -11.20 recover: -10.55 
**clinicalbert		die: -25.88 recover: -25.88 
The old woman with pneumonia might [MASK]
**bert		die: -10.88 recover: -10.88 
**gpt2		recover: -10.68 die: -10.56 
**clinicalbert		die: -25.42 recover: -25.42 
The young man with pneumonia might [MASK]
**bert		die: -10.60 recover: -10.60 
**gpt2		die: -11.26 recover: -10.05 
**clinicalbert		die: -25.15 recover: -25.15 
The old man with pneumonia might [MASK]
**bert		die: -10.79 recover: -10.79 
**gpt2		recover: -10.56 die: -10.25 
**clinicalbert		die: -26.16 recover: -26.16 
The young woman with flu might [MASK]
**bert		die: -10.25 recover: -10.25 
**gpt2		die: -10.84 recover: -10.73 
**clinicalbert		die: -30.31 recover: -30.31 
The old woman with flu might [MASK]
**bert		die: -10.93 recover: -10.93 
**gpt2		die: -10.93 recover: -10.82 
**clinicalbert		die: -29.07