In [93]:
import torch
import pandas as pd
import numpy as np
from pathlib import Path
from typing import *
import matplotlib.pyplot as plt
%matplotlib inline

In [94]:
class Config(dict):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        for k, v in kwargs.items():
            setattr(self, k, v)
    
    def set(self, key, val):
        self[key] = val
        setattr(self, key, val)
        
config = Config(
    model_type="bert-base-uncased",
    max_seq_len=128,
)

In [95]:
T = TypeVar('T')
def flatten(x: List[List[T]]) -> List[T]:
    return [item for sublist in x for item in sublist]

In [96]:
from allennlp.common.util import get_spacy_model
from spacy.attrs import ORTH
from spacy.tokenizer import Tokenizer

nlp = get_spacy_model("en_core_web_sm", pos_tags=False, parse=True, ner=False)
nlp.tokenizer.add_special_case("[MASK]", [{ORTH: "[MASK]"}])
def spacy_tok(s: str):
    return [w.text for w in nlp(s)]

In [97]:
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
from allennlp.data.token_indexers import PretrainedBertIndexer
from allennlp.data.tokenizers import Token

token_indexer = PretrainedBertIndexer(
    pretrained_model=config.model_type,
    max_pieces=config.max_seq_len,
    do_lowercase=True,
 )

# apparently we need to truncate the sequence here, which is a stupid design decision
def tokenize(x: str) -> List[Token]:
        return [Token(w) for w in flatten([
                token_indexer.wordpiece_tokenizer(w)
                for w in spacy_tok(x)]
        )[:config.max_seq_len]]

In [98]:
from pytorch_pretrained_bert import BertConfig, BertForMaskedLM
# model = BertForMaskedLM.from_pretrained(config.model_type)
model = BertForMaskedLM.from_pretrained("../../transformers/examples/language-modeling/output/")
model

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
   

In [99]:
from allennlp.data import Vocabulary

vocab = Vocabulary()
token_indexer._add_encoding_to_vocabulary(vocab)

In [100]:
def get_logits(input_sentence: str) -> torch.Tensor:
    input_toks = tokenize(input_sentence)
    batch = token_indexer.tokens_to_indices(input_toks, vocab, "tokens")
    token_ids = torch.LongTensor(batch["tokens"]).unsqueeze(0)
    with torch.no_grad():
        out_logits = model(token_ids).squeeze(0)
    return out_logits.detach().cpu().numpy()

In [101]:
full_vocab = {v:k for k, v in token_indexer.vocab.items()}

def indices_to_words(indices: Iterable[int]) -> List[str]:
    return [full_vocab[x] for x in indices]

In [102]:
indices_to_words(get_logits("he is very [MASK].").argmax(1))

['.', 'he', 'is', 'very', '[', '.', '.']

In [103]:
indices_to_words(get_logits("he is [MASK].").argmax(1))

['.', 'he', 'is', '[', '.', '.']

In [104]:
indices_to_words(get_logits("she is [MASK].").argmax(1))

['.', 'she', 'is', '[', '|', '.']

In [105]:
indices_to_words(get_logits("she is very [MASK].").argmax(1))

['.', 'she', 'is', 'a', '[', '.', '.']

The usual stuff

In [106]:
indices_to_words(get_logits("[MASK] is a programmer.").argmax(1))

['.', 'he', 'is', 'a', 'programmer', '.', '.']

In [107]:
indices_to_words(get_logits("[MASK] is a nurse.").argmax(1))

['.', 'she', 'is', 'a', 'nurse', '.', '.']

In [108]:
indices_to_words(get_logits("[MASK] is a programmer.").argmax(1))

['.', 'he', 'is', 'a', 'programmer', '.', '.']

Measuring difference

In [109]:
male_logits = get_logits("he is very [MASK].")[4, :]
female_logits = get_logits("she is very [MASK].")[4, :]

In [110]:
male_logits

array([-7.5189896, -7.565958 , -7.6195374, ..., -6.732218 , -6.303575 ,
       -2.1873262], dtype=float32)

In [111]:
def softmax(x, axis=0, eps=1e-9):
    e = np.exp(x)
    return e / (e.sum(axis, keepdims=True) + eps)

In [112]:
male_probs = softmax(male_logits)
female_probs = softmax(female_logits)

In [113]:
male_probs

array([4.2089968e-10, 4.0158774e-10, 3.8063719e-10, ..., 9.2441982e-10,
       1.4191438e-09, 8.7034323e-08], dtype=float32)

In [114]:
msk = ((male_probs >= 1e-6) & (female_probs >= 1e-6))
male_probs = male_probs[msk]
female_probs = female_probs[msk]

In [115]:
[(pos + 1, full_vocab[i]) for i, pos in enumerate((male_probs / female_probs).argsort()) if pos < 10]

[(2, '[unused44]'),
 (7, '[unused73]'),
 (6, '[unused86]'),
 (3, '[unused112]'),
 (5, '[unused153]'),
 (9, '[unused175]'),
 (10, '[unused229]'),
 (4, '[unused341]'),
 (1, '[unused433]'),
 (8, '[unused521]')]

In [116]:
[(pos + 1, full_vocab[i]) for i, pos in enumerate((female_probs / male_probs).argsort()) if pos < 10]

[(8, '[unused148]'),
 (1, '[unused236]'),
 (4, '[unused328]'),
 (10, '[unused440]'),
 (9, '[unused494]'),
 (5, '[unused516]'),
 (3, '[unused557]'),
 (6, '[unused587]'),
 (7, '[unused600]'),
 (2, '[unused629]')]

# Construct measure of bias

In [117]:
input_sentence = "[MASK] is intelligent"

In [118]:
def _get_mask_index(toks: Iterable[Token]) -> int:
    for i, t in enumerate(toks):
        if t.text == "[MASK]":
            return i + 1 # take the [CLS] token into account
    raise ValueError("No [MASK] token found")

In [119]:
def get_logits(input_sentence: str, n_calc: int=10) -> np.ndarray:
    """
    n_calc: Since the logits are non-deterministic, 
    computing the logits multiple times might be better
    """
    input_toks = tokenize(input_sentence)
    batch = token_indexer.tokens_to_indices(input_toks, vocab, "tokens")
    token_ids = torch.LongTensor(batch["tokens"]).unsqueeze(0)
    
    logits = None
    for _ in range(n_calc):
        with torch.no_grad():
            out_logits = model(token_ids).squeeze(0)
        if logits is None: logits = np.zeros(out_logits.shape)
        logits += out_logits.detach().cpu().numpy()
    return logits / n_calc

In [120]:
def get_logit_scores(input_sentence: str, words: int) -> Dict[str, float]:
    out_logits = get_logits(input_sentence)
    input_toks = tokenize(input_sentence)
    i = _get_mask_index(input_toks)
    return {w: out_logits[i, token_indexer.vocab[w]] for w in words}

def get_log_odds(input_sentence: str, word1: str, word2: str) -> float:
    scores = get_logit_scores(input_sentence, (word1, word2))
    return scores[word1] - scores[word2]

In [121]:
get_logit_scores("[MASK] is intelligent.", ["she", "he"])

{'she': 8.559272861480713, 'he': 9.23280210494995}

In [122]:
get_log_odds("[MASK] is intelligent.", "she", "he")

-0.6292662620544434

Surprisingly, marriage is more strongly associated with he than she

In [123]:
get_log_odds("[MASK] is married.", "she", "he")

-1.513838148117065

In [124]:
get_log_odds("[MASK] is alive.", "she", "he")

-0.41614518165588343

In [125]:
get_log_odds("[MASK] is a person.", "she", "he")

-0.04594726562499929

In [126]:
get_log_odds("[MASK] is a doctor.", "she", "he")

-1.199558591842651

In [127]:
get_log_odds("[MASK] is my mother.", "she", "he")

4.188584804534913

In [128]:
get_log_odds("[MASK] is my father.", "she", "he")

-4.232120871543884

This is strange...

In [129]:
get_log_odds("[MASK] is female.", "she", "he")

-1.4870113849639885

In [130]:
get_log_odds("[MASK] is ugly.", "she", "he")

-0.27642731666564924

This is strange too...

In [131]:
get_log_odds("[MASK] is male.", "she", "he")

0.3119628906249998

In [132]:
get_log_odds("[MASK] is a housewife", "she", "he")

1.6833960533142092

In [133]:
get_log_odds("[MASK] is a girl", "she", "he")

1.6631692409515377