In [14]:
import torch
import pandas as pd
import numpy as np
from pathlib import Path
from typing import *
import matplotlib.pyplot as plt
%matplotlib inline

In [15]:
class Config(dict):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        for k, v in kwargs.items():
            setattr(self, k, v)
    
    def set(self, key, val):
        self[key] = val
        setattr(self, key, val)
        
config = Config(
    model_type="bert-base-uncased",
    max_seq_len=128,
)

In [16]:
T = TypeVar('T')
def flatten(x: List[List[T]]) -> List[T]:
    return [item for sublist in x for item in sublist]

In [17]:
from allennlp.common.util import get_spacy_model
from spacy.attrs import ORTH
from spacy.tokenizer import Tokenizer

nlp = get_spacy_model("en_core_web_sm", pos_tags=False, parse=True, ner=False)
nlp.tokenizer.add_special_case("[MASK]", [{ORTH: "[MASK]"}])
def spacy_tok(s: str):
    return [w.text for w in nlp(s)]

In [18]:
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
from allennlp.data.token_indexers import PretrainedBertIndexer
from allennlp.data.tokenizers import Token

token_indexer = PretrainedBertIndexer(
    pretrained_model=config.model_type,
    max_pieces=config.max_seq_len,
    do_lowercase=True,
 )

# apparently we need to truncate the sequence here, which is a stupid design decision
def tokenize(x: str) -> List[Token]:
        return [Token(w) for w in flatten([
                token_indexer.wordpiece_tokenizer(w)
                for w in spacy_tok(x)]
        )[:config.max_seq_len]]

In [55]:
from pytorch_pretrained_bert import BertConfig, BertForMaskedLM
model = BertForMaskedLM.from_pretrained(config.model_type)
model.eval()# https://jamesmccaffrey.wordpress.com/2019/01/23/pytorch-train-vs-eval-mode/

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
   

In [56]:
from allennlp.data import Vocabulary

vocab = Vocabulary()
token_indexer._add_encoding_to_vocabulary(vocab)

In [57]:
def get_logits(input_sentence: str) -> torch.Tensor:
    input_toks = tokenize(input_sentence)
    batch = token_indexer.tokens_to_indices(input_toks, vocab, "tokens")
    token_ids = torch.LongTensor(batch["tokens"]).unsqueeze(0)
    with torch.no_grad():
        out_logits = model(token_ids).squeeze(0)
    return out_logits.detach().cpu().numpy()

In [58]:
full_vocab = {v:k for k, v in token_indexer.vocab.items()}

def indices_to_words(indices: Iterable[int]) -> List[str]:
    return [full_vocab[x] for x in indices]

In [59]:
indices_to_words(get_logits("he is very [MASK].").argmax(1))

['.', 'he', 'is', 'very', '[', '.', '.']

In [60]:
indices_to_words(get_logits("he is [MASK].").argmax(1))

['.', 'he', 'is', '[', '.', '.']

In [61]:
indices_to_words(get_logits("she is [MASK].").argmax(1))

['.', 'she', 'is', '[', '.', '.']

In [62]:
indices_to_words(get_logits("she is very [MASK].").argmax(1))

['.', 'she', 'is', 'very', '[', '.', '.']

The usual stuff

In [63]:
indices_to_words(get_logits("[MASK] is a programmer.").argmax(1))

['.', 'he', 'is', 'a', 'programmer', '.', '.']

In [64]:
indices_to_words(get_logits("[MASK] is a nurse.").argmax(1))

['.', 'she', 'is', 'a', 'nurse', '.', '.']

In [65]:
indices_to_words(get_logits("[MASK] is a programmer.").argmax(1))

['.', 'he', 'is', 'a', 'programmer', '.', '.']

Measuring difference

In [66]:
male_logits = get_logits("he is very [MASK].")[4, :]
female_logits = get_logits("she is very [MASK].")[4, :]

In [67]:
male_logits

array([-8.156933 , -8.080509 , -8.314136 , ..., -6.6777625, -8.160287 ,
       -3.6486151], dtype=float32)

In [68]:
def softmax(x, axis=0, eps=1e-9):
    e = np.exp(x)
    return e / (e.sum(axis, keepdims=True) + eps)

In [69]:
male_probs = softmax(male_logits)
female_probs = softmax(female_logits)

In [70]:
male_probs

array([2.7336199e-11, 2.9507230e-11, 2.3359623e-11, ..., 1.1998684e-10,
       2.7244661e-11, 2.4812792e-09], dtype=float32)

In [71]:
msk = ((male_probs >= 1e-6) & (female_probs >= 1e-6))
male_probs = male_probs[msk]
female_probs = female_probs[msk]

In [72]:
[(pos + 1, full_vocab[i]) for i, pos in enumerate((male_probs / female_probs).argsort()) if pos < 10]

[(1, '[unused1]'),
 (2, '[unused7]'),
 (5, '[unused74]'),
 (9, '[unused77]'),
 (3, '[unused90]'),
 (10, '[unused91]'),
 (6, '[unused92]'),
 (7, '[unused94]'),
 (4, '[unused111]'),
 (8, '[unused115]')]

In [73]:
[(pos + 1, full_vocab[i]) for i, pos in enumerate((female_probs / male_probs).argsort()) if pos < 10]

[(8, '[unused0]'),
 (4, '[unused4]'),
 (7, '[unused25]'),
 (6, '[unused27]'),
 (10, '[unused28]'),
 (3, '[unused29]'),
 (9, '[unused42]'),
 (5, '[unused45]'),
 (2, '[unused108]'),
 (1, '[unused114]')]

# Construct measure of bias

In [74]:
input_sentence = "[MASK] is intelligent"

In [75]:
def _get_mask_index(toks: Iterable[Token]) -> int:
    for i, t in enumerate(toks):
        if t.text == "[MASK]":
            return i + 1 # take the [CLS] token into account
    raise ValueError("No [MASK] token found")

In [76]:
def get_logits(input_sentence: str, n_calc: int=10) -> np.ndarray:
    """
    n_calc: Since the logits are non-deterministic, 
    computing the logits multiple times might be better
    """
    input_toks = tokenize(input_sentence)
    batch = token_indexer.tokens_to_indices(input_toks, vocab, "tokens")
    token_ids = torch.LongTensor(batch["tokens"]).unsqueeze(0)
    
    logits = None
    for _ in range(n_calc):
        with torch.no_grad():
            out_logits = model(token_ids).squeeze(0)
        if logits is None: logits = np.zeros(out_logits.shape)
        logits += out_logits.detach().cpu().numpy()
    return logits / n_calc

In [77]:
def get_logit_scores(input_sentence: str, words: int) -> Dict[str, float]:
    out_logits = get_logits(input_sentence)
    input_toks = tokenize(input_sentence)
    i = _get_mask_index(input_toks)
    return {w: out_logits[i, token_indexer.vocab[w]] for w in words}

def get_log_odds(input_sentence: str, word1: str, word2: str) -> float:
    scores = get_logit_scores(input_sentence, (word1, word2))
    return scores[word1] - scores[word2]

In [78]:
get_logit_scores("[MASK] is intelligent.", ["she", "he"])

{'she': 9.330872535705566, 'he': 9.864638328552246}

In [79]:
get_log_odds("[MASK] is intelligent.", "she", "he")

-0.5337657928466797

Surprisingly, marriage is more strongly associated with he than she

In [80]:
get_log_odds("[MASK] is married.", "she", "he")

-2.4174156188964844

In [81]:
get_log_odds("[MASK] is alive.", "she", "he")

-0.39806365966796875

In [82]:
get_log_odds("[MASK] is a person.", "she", "he")

-0.058249473571777344

In [83]:
get_log_odds("[MASK] is a doctor.", "she", "he")

-0.9748430252075195

In [48]:
get_log_odds("[MASK] is my mother.", "she", "he")

3.96632866859436

In [49]:
get_log_odds("[MASK] is my father.", "she", "he")

-4.462116336822509

This is strange...

In [50]:
get_log_odds("[MASK] is female.", "she", "he")

-2.151384162902832

In [51]:
get_log_odds("[MASK] is ugly.", "she", "he")

-0.46579909324645996

This is strange too...

In [52]:
get_log_odds("[MASK] is male.", "she", "he")

0.3173854351043701

In [53]:
get_log_odds("[MASK] is a housewife", "she", "he")

2.246202230453491

In [54]:
get_log_odds("[MASK] is a girl", "she", "he")

1.3769546985626224