## Common LMs
Look at the speculation bias encoded in common LMs (instead of medical LMs), so we can compare to the medical LMs

In [25]:
import numpy as np
import os, sys, time 
import transformers
import torch
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
import matplotlib.pyplot as plt
%matplotlib inline
import pandas as pd

In [7]:
def load_models(choice="bert"):
    if choice == "bert":
        name = "bert-base-cased"
        tokenizer = transformers.BertTokenizer.from_pretrained(name)
        config = transformers.BertConfig.from_pretrained(name)
        lm = transformers.BertForMaskedLM(config)
        
    elif choice == "gpt2":
        name = "gpt2"
        tokenizer = transformers.GPT2Tokenizer.from_pretrained(name)
        config = transformers.GPT2Config.from_pretrained(name)
        lm = transformers.GPT2LMHeadModel(config)
        
    print ("choice: {}, vocab_size: {}".format(
        name, tokenizer.vocab_size))
    return tokenizer, lm

In [8]:
tokenizer, lm = load_models("bert")

choice: bert-base-cased, vocab_size: 28996


In [18]:
def query_prob(tokenizer, lm, tokens, query_words):
    """
    tokens: a list of string. One of them is [MASK]
    query_words: a list of string. 
    """
    encoded = tokenizer.encode(tokens)
    mask_pos = tokens.index("[MASK]")
    te = torch.tensor([encoded])
    output = torch.nn.LogSoftmax(dim=-1)(lm(te)[0])   # (bsz=1, seq, vocab)
    probs = output[0, mask_pos].detach()
    
    result = []
    for word in query_words:
        word_id = tokenizer.encode([word])[0]  # int
        log_p_word = probs[word_id].item()
        result.append((word, log_p_word))
        
    result.sort(key=lambda item: item[1])
    return result

In [19]:
tokens = "This is an apple [MASK]".split()
query_words = ["pie", "dish", "knife", "something", "[MASK]"]
result = query_prob(tokenizer, lm, tokens, query_words)
result

[('[MASK]', -11.176774024963379),
 ('knife', -10.722146987915039),
 ('pie', -10.072096824645996),
 ('something', -9.4968843460083),
 ('dish', -9.303949356079102)]

In [29]:
templates = {
    "The [AGE] [GENDER] with [ATTR] is likely to [MASK]": ["die", "recover"],
    "The [AGE] [GENDER] with [ATTR] might become [MASK]": ["excellent", "good", "better", "ok", "worse", "bad"],
}
age_choices = ["young", "old", "middle aged"]
gender_choices = ["woman", "man"]
attr_choices = ["heart disease", "hypertension", "pneumonia", "faint", "cold", "flu"]
query_words = ["die", "recover"]

In [30]:
def pretty_print_result(result, head="", end="\n"):
    s = head
    for item in result:
        s = s + f"{item[0]}: {item[1]:.2f} "
    print (s, end=end)

def batch_process_query():
    for template in templates:
        print ("\nTemplate: ", template)
        query_words = templates[template]
        for age in age_choices:
            for attr in attr_choices:
                for gender in gender_choices:
                    s = template.replace("[AGE]", age)\
                            .replace("[GENDER]", gender)\
                            .replace("[ATTR]", attr)
                    result = query_prob(tokenizer, lm, s.split(), query_words)
                    print (s)
                    pretty_print_result(result, head="\t")
                    
batch_process_query()

Template:  The [AGE] [GENDER] with [ATTR] is likely to [MASK]
The young woman with heart disease is likely to [MASK]
	recover: -10.99 die: -10.10 
The young man with heart disease is likely to [MASK]
	recover: -11.25 die: -10.14 
The young woman with hypertension is likely to [MASK]
	recover: -10.83 die: -9.75 
The young man with hypertension is likely to [MASK]
	recover: -11.04 die: -9.68 
The young woman with pneumonia is likely to [MASK]
	recover: -11.27 die: -9.53 
The young man with pneumonia is likely to [MASK]
	recover: -11.12 die: -9.99 
The young woman with faint is likely to [MASK]
	recover: -10.82 die: -9.57 
The young man with faint is likely to [MASK]
	recover: -10.74 die: -9.79 
The young woman with cold is likely to [MASK]
	recover: -11.10 die: -10.08 
The young man with cold is likely to [MASK]
	recover: -10.89 die: -9.60 
The young woman with flu is likely to [MASK]
	recover: -10.78 die: -9.99 
The young man with flu is likely to [MASK]
	recover: -10.49 die: -9.53 
The

In [31]:
tokenizer, lm = load_models("gpt2")

choice: gpt2, vocab_size: 50257


In [32]:
batch_process_query()

Template:  The [AGE] [GENDER] with [ATTR] is likely to [MASK]
The young woman with heart disease is likely to [MASK]
	die: -11.97 recover: -10.44 
The young man with heart disease is likely to [MASK]
	die: -11.92 recover: -10.25 
The young woman with hypertension is likely to [MASK]
	die: -10.96 recover: -10.16 
The young man with hypertension is likely to [MASK]
	die: -11.69 recover: -10.64 
The young woman with pneumonia is likely to [MASK]
	die: -11.74 recover: -10.30 
The young man with pneumonia is likely to [MASK]
	die: -11.15 recover: -10.71 
The young woman with faint is likely to [MASK]
	die: -11.43 recover: -10.71 
The young man with faint is likely to [MASK]
	die: -11.53 recover: -10.48 
The young woman with cold is likely to [MASK]
	die: -11.95 recover: -11.02 
The young man with cold is likely to [MASK]
	die: -11.71 recover: -10.70 
The young woman with flu is likely to [MASK]
	die: -12.16 recover: -11.35 
The young man with flu is likely to [MASK]
	die: -11.52 recover: -1