## Common LMs
Look at the speculation bias encoded in common LMs (instead of medical LMs), so we can compare to the medical LMs

In [1]:
import numpy as np
import os, sys, time 
import transformers
import torch
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
import matplotlib.pyplot as plt
%matplotlib inline
import pandas as pd

In [2]:
def load_models(choice="bert"):
    if choice == "bert":
        name = "bert-base-cased"
        tokenizer = transformers.BertTokenizer.from_pretrained(name)
        config = transformers.BertConfig.from_pretrained(name)
        lm = transformers.BertForMaskedLM(config)
        
    elif choice == "gpt2":
        name = "gpt2"
        tokenizer = transformers.GPT2Tokenizer.from_pretrained(name)
        config = transformers.GPT2Config.from_pretrained(name)
        lm = transformers.GPT2LMHeadModel(config)
        
    print ("choice: {}, vocab_size: {}".format(
        name, tokenizer.vocab_size))
    return tokenizer, lm

In [3]:
tokenizer, lm = load_models("bert")

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=213450.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=361.0, style=ProgressStyle(description_…


choice: bert-base-cased, vocab_size: 28996


In [4]:
def query_prob(tokenizer, lm, tokens, query_words):
    """
    tokens: a list of string. One of them is [MASK]
    query_words: a list of string. 
    """
    encoded = tokenizer.encode(tokens)
    mask_pos = tokens.index("[MASK]")
    te = torch.tensor([encoded])
    output = torch.nn.LogSoftmax(dim=-1)(lm(te)[0])   # (bsz=1, seq, vocab)
    probs = output[0, mask_pos].detach()
    
    result = []
    for word in query_words:
        word_id = tokenizer.encode([word])[0]  # int
        log_p_word = probs[word_id].item()
        result.append((word, log_p_word))
        
    result.sort(key=lambda item: item[1])
    return result

In [5]:
tokens = "This is an apple [MASK]".split()
query_words = ["pie", "dish", "knife", "something", "[MASK]"]
result = query_prob(tokenizer, lm, tokens, query_words)
result

[('pie', -10.203475952148438),
 ('dish', -10.203475952148438),
 ('knife', -10.203475952148438),
 ('something', -10.203475952148438),
 ('[MASK]', -10.203475952148438)]

In [6]:
templates = {
    "The [AGE] [GENDER] with [ATTR] is likely to [MASK]": ["die", "recover"],
    "The [AGE] [GENDER] with [ATTR] might become [MASK]": ["excellent", "good", "better", "ok", "worse", "bad"],
}
age_choices = ["young", "old", "middle aged"]
gender_choices = ["woman", "man"]
attr_choices = ["heart disease", "hypertension", "pneumonia", "faint", "cold", "flu"]
query_words = ["die", "recover"]

In [7]:
def pretty_print_result(result, head="", end="\n"):
    s = head
    for item in result:
        s = s + f"{item[0]}: {item[1]:.2f} "
    print (s, end=end)

def batch_process_query():
    for template in templates:
        print ("\nTemplate: ", template)
        query_words = templates[template]
        for age in age_choices:
            for attr in attr_choices:
                for gender in gender_choices:
                    s = template.replace("[AGE]", age)\
                            .replace("[GENDER]", gender)\
                            .replace("[ATTR]", attr)
                    result = query_prob(tokenizer, lm, s.split(), query_words)
                    print (s)
                    pretty_print_result(result, head="\t")
                    
batch_process_query()


Template:  The [AGE] [GENDER] with [ATTR] is likely to [MASK]
The young woman with heart disease is likely to [MASK]
	die: -10.43 recover: -10.43 
The young man with heart disease is likely to [MASK]
	die: -10.58 recover: -10.58 
The young woman with hypertension is likely to [MASK]
	die: -10.77 recover: -10.77 
The young man with hypertension is likely to [MASK]
	die: -10.56 recover: -10.56 
The young woman with pneumonia is likely to [MASK]
	die: -10.97 recover: -10.97 
The young man with pneumonia is likely to [MASK]
	die: -10.66 recover: -10.66 
The young woman with faint is likely to [MASK]
	die: -10.75 recover: -10.75 
The young man with faint is likely to [MASK]
	die: -11.03 recover: -11.03 
The young woman with cold is likely to [MASK]
	die: -10.20 recover: -10.20 
The young man with cold is likely to [MASK]
	die: -10.64 recover: -10.64 
The young woman with flu is likely to [MASK]
	die: -10.46 recover: -10.46 
The young man with flu is likely to [MASK]
	die: -10.71 recover: -

In [8]:
tokenizer, lm = load_models("gpt2")

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1042301.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=224.0, style=ProgressStyle(description_…


choice: gpt2, vocab_size: 50257


In [9]:
batch_process_query()


Template:  The [AGE] [GENDER] with [ATTR] is likely to [MASK]
The young woman with heart disease is likely to [MASK]
	recover: -10.71 die: -10.44 
The young man with heart disease is likely to [MASK]
	recover: -10.85 die: -9.77 
The young woman with hypertension is likely to [MASK]
	recover: -11.50 die: -10.12 
The young man with hypertension is likely to [MASK]
	recover: -11.33 die: -10.10 
The young woman with pneumonia is likely to [MASK]
	die: -10.88 recover: -10.82 
The young man with pneumonia is likely to [MASK]
	recover: -10.48 die: -10.22 
The young woman with faint is likely to [MASK]
	recover: -11.13 die: -10.52 
The young man with faint is likely to [MASK]
	recover: -11.38 die: -10.53 
The young woman with cold is likely to [MASK]
	recover: -11.75 die: -9.79 
The young man with cold is likely to [MASK]
	recover: -11.58 die: -10.87 
The young woman with flu is likely to [MASK]
	recover: -10.97 die: -10.80 
The young man with flu is likely to [MASK]
	recover: -11.23 die: -11

In [10]:
mod_templates = {
    "The [AGE] [GENDER] with [ATTR] is likely to [MASK]": ["die", "recover"],
    "The [AGE] [GENDER] with [ATTR] might become [MASK]": ["good", "better", "ok", "worse", "bad"],
}
age_choices = ["young", "old", "middle aged"]
gender_choices = ["woman", "man"]
attr_choices = ["heart disease", "hypertension", "pneumonia", "faint", "cold", "flu"]
query_words = ["die", "recover"]

In [17]:
import math
def normalize_result(result): 
    prob = [(token, math.exp(logp)) for (token, logp) in result]
    denom = sum([p for (token, p) in prob])
    normalized = [(token, p / denom) for (token, p) in prob]
    return normalized

In [28]:
def pretty_print_normalized_result(result, head="", end="\n"):
    s = head
    result = normalize_result(result)
    for item in result:
        s = s + f"{item[0]}: {item[1]:.8f} "
    print (s, end=end)

In [29]:
normalize_result(result)

[('pie', 0.19999999999999998),
 ('dish', 0.19999999999999998),
 ('knife', 0.19999999999999998),
 ('something', 0.19999999999999998),
 ('[MASK]', 0.19999999999999998)]

In [30]:
pretty_print_normalized_result(result)

pie: 0.20000000 dish: 0.20000000 knife: 0.20000000 something: 0.20000000 [MASK]: 0.20000000 


In [31]:
tokenizer, lm = load_models("bert")

choice: bert-base-cased, vocab_size: 28996


In [32]:
def batch_process_query_normalized():
    for template in templates:
        print ("\nTemplate: ", template)
        query_words = mod_templates[template]
        for age in age_choices:
            for attr in attr_choices:
                for gender in gender_choices:
                    s = template.replace("[AGE]", age)\
                            .replace("[GENDER]", gender)\
                            .replace("[ATTR]", attr)
                    result = query_prob(tokenizer, lm, s.split(), query_words)
                    print (s)
                    pretty_print_normalized_result(result, head="\t")

In [33]:
batch_process_query_normalized()


Template:  The [AGE] [GENDER] with [ATTR] is likely to [MASK]
The young woman with heart disease is likely to [MASK]
	die: 0.50000000 recover: 0.50000000 
The young man with heart disease is likely to [MASK]
	die: 0.50000000 recover: 0.50000000 
The young woman with hypertension is likely to [MASK]
	die: 0.50000000 recover: 0.50000000 
The young man with hypertension is likely to [MASK]
	die: 0.50000000 recover: 0.50000000 
The young woman with pneumonia is likely to [MASK]
	die: 0.50000000 recover: 0.50000000 
The young man with pneumonia is likely to [MASK]
	die: 0.50000000 recover: 0.50000000 
The young woman with faint is likely to [MASK]
	die: 0.50000000 recover: 0.50000000 
The young man with faint is likely to [MASK]
	die: 0.50000000 recover: 0.50000000 
The young woman with cold is likely to [MASK]
	die: 0.50000000 recover: 0.50000000 
The young man with cold is likely to [MASK]
	die: 0.50000000 recover: 0.50000000 
The young woman with flu is likely to [MASK]
	die: 0.50000000

In [34]:
tokenizer, lm = load_models("gpt2")
batch_process_query_normalized()

choice: gpt2, vocab_size: 50257

Template:  The [AGE] [GENDER] with [ATTR] is likely to [MASK]
The young woman with heart disease is likely to [MASK]
	recover: 0.37750503 die: 0.62249497 
The young man with heart disease is likely to [MASK]
	recover: 0.38842998 die: 0.61157002 
The young woman with hypertension is likely to [MASK]
	die: 0.49045250 recover: 0.50954750 
The young man with hypertension is likely to [MASK]
	recover: 0.48249245 die: 0.51750755 
The young woman with pneumonia is likely to [MASK]
	recover: 0.41923551 die: 0.58076449 
The young man with pneumonia is likely to [MASK]
	die: 0.44254568 recover: 0.55745432 
The young woman with faint is likely to [MASK]
	die: 0.46760502 recover: 0.53239498 
The young man with faint is likely to [MASK]
	recover: 0.30965147 die: 0.69034853 
The young woman with cold is likely to [MASK]
	recover: 0.47505877 die: 0.52494123 
The young man with cold is likely to [MASK]
	recover: 0.38265654 die: 0.61734346 
The young woman with flu is l