In [1]:
import torch
import pandas as pd
import numpy as np
from pathlib import Path
from typing import *
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
class Config(dict):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        for k, v in kwargs.items():
            setattr(self, k, v)
    
    def set(self, key, val):
        self[key] = val
        setattr(self, key, val)
        
config = Config(
    model_type="bert-base-uncased",
    max_seq_len=128,
)

In [3]:
T = TypeVar('T')
def flatten(x: List[List[T]]) -> List[T]:
    return [item for sublist in x for item in sublist]

In [4]:
from allennlp.common.util import get_spacy_model
from spacy.attrs import ORTH
from spacy.tokenizer import Tokenizer

nlp = get_spacy_model("en_core_web_sm", pos_tags=False, parse=True, ner=False)
nlp.tokenizer.add_special_case("[MASK]", [{ORTH: "[MASK]"}])
def spacy_tok(s: str):
    return [w.text for w in nlp(s)]o

In [5]:
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
from allennlp.data.token_indexers import PretrainedBertIndexer
from allennlp.data.tokenizers import Token

token_indexer = PretrainedBertIndexer(
    pretrained_model=config.model_type,
    max_pieces=config.max_seq_len,
    do_lowercase=True,
 )

# apparently we need to truncate the sequence here, which is a stupid design decision
def tokenize(x: str) -> List[Token]:
        return [Token(w) for w in flatten([
                token_indexer.wordpiece_tokenizer(w)
                for w in spacy_tok(x)]
        )[:config.max_seq_len]]

In [42]:
from pytorch_pretrained_bert import BertConfig, BertForMaskedLM
model = BertForMaskedLM.from_pretrained(config.model_type)
model

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
   

In [7]:
from allennlp.data import Vocabulary

vocab = Vocabulary()
token_indexer._add_encoding_to_vocabulary(vocab)

In [8]:
def get_logits(input_sentence: str) -> torch.Tensor:
    input_toks = tokenize(input_sentence)
    batch = token_indexer.tokens_to_indices(input_toks, vocab, "tokens")
    token_ids = torch.LongTensor(batch["tokens"]).unsqueeze(0)
    with torch.no_grad():
        out_logits = model(token_ids).squeeze(0)
    return out_logits.detach().cpu().numpy()

In [9]:
full_vocab = {v:k for k, v in token_indexer.vocab.items()}

def indices_to_words(indices: Iterable[int]) -> List[str]:
    return [full_vocab[x] for x in indices]

In [10]:
indices_to_words(get_logits("he is very [MASK].").argmax(1))

['.', 'he', 'is', 'very', '[', '.', '.']

In [11]:
indices_to_words(get_logits("he is [MASK].").argmax(1))

['.', 'he', 'is', '[', '.', '.']

In [12]:
indices_to_words(get_logits("she is [MASK].").argmax(1))

['.', 'she', 'is', '[', '.', '.']

In [13]:
indices_to_words(get_logits("she is very [MASK].").argmax(1))

['.', 'she', 'is', 'very', '[', '.', '.']

The usual stuff

In [14]:
indices_to_words(get_logits("[MASK] is a programmer.").argmax(1))

['.', 'he', 'is', 'a', 'programmer', '.', '.']

In [15]:
indices_to_words(get_logits("[MASK] is a nurse.").argmax(1))

['.', 'she', 'is', 'a', 'nurse', '.', '.']

In [16]:
indices_to_words(get_logits("[MASK] is a programmer.").argmax(1))

['.', 'he', 'is', 'a', 'programmer', '.', '.']

Measuring difference

In [17]:
male_logits = get_logits("he is very [MASK].")[4, :]
female_logits = get_logits("she is very [MASK].")[4, :]

In [18]:
male_logits

array([-7.765704, -7.8001  , -8.097227, ..., -6.347876, -8.464294,
       -4.750696], dtype=float32)

In [19]:
def softmax(x, axis=0, eps=1e-9):
    e = np.exp(x)
    return e / (e.sum(axis, keepdims=True) + eps)

In [20]:
male_probs = softmax(male_logits)
female_probs = softmax(female_logits)

In [21]:
male_probs

array([1.13422584e-10, 1.09587665e-10, 8.14181014e-11, ...,
       4.68224848e-10, 5.64034479e-11, 2.31260189e-09], dtype=float32)

In [22]:
msk = ((male_probs >= 1e-6) & (female_probs >= 1e-6))
male_probs = male_probs[msk]
female_probs = female_probs[msk]

In [23]:
[(pos + 1, full_vocab[i]) for i, pos in enumerate((male_probs / female_probs).argsort()) if pos < 10]

[(3, '[unused5]'),
 (6, '[unused6]'),
 (7, '[unused9]'),
 (4, '[unused18]'),
 (1, '[unused32]'),
 (8, '[unused38]'),
 (9, '[unused47]'),
 (5, '[unused51]'),
 (2, '[unused54]'),
 (10, '[unused69]')]

In [24]:
[(pos + 1, full_vocab[i]) for i, pos in enumerate((female_probs / male_probs).argsort()) if pos < 10]

[(10, '[unused7]'),
 (2, '[unused22]'),
 (5, '[unused25]'),
 (9, '[unused29]'),
 (8, '[unused38]'),
 (1, '[unused44]'),
 (4, '[unused58]'),
 (7, '[unused67]'),
 (6, '[unused70]'),
 (3, '[unused71]')]

# Construct measure of bias

In [25]:
input_sentence = "[MASK] is intelligent"

In [26]:
def _get_mask_index(toks: Iterable[Token]) -> int:
    for i, t in enumerate(toks):
        if t.text == "[MASK]":
            return i + 1 # take the [CLS] token into account
    raise ValueError("No [MASK] token found")

In [27]:
def get_logits(input_sentence: str, n_calc: int=10) -> np.ndarray:
    """
    n_calc: Since the logits are non-deterministic, 
    computing the logits multiple times might be better
    """
    input_toks = tokenize(input_sentence)
    batch = token_indexer.tokens_to_indices(input_toks, vocab, "tokens")
    token_ids = torch.LongTensor(batch["tokens"]).unsqueeze(0)
    
    logits = None
    for _ in range(n_calc):
        with torch.no_grad():
            out_logits = model(token_ids).squeeze(0)
        if logits is None: logits = np.zeros(out_logits.shape)
        logits += out_logits.detach().cpu().numpy()
    return logits / n_calc

In [28]:
def get_logit_scores(input_sentence: str, words: int) -> Dict[str, float]:
    out_logits = get_logits(input_sentence)
    input_toks = tokenize(input_sentence)
    i = _get_mask_index(input_toks)
    return {w: out_logits[i, token_indexer.vocab[w]] for w in words}

def get_log_odds(input_sentence: str, word1: str, word2: str) -> float:
    scores = get_logit_scores(input_sentence, (word1, word2))
    return scores[word1] - scores[word2]

In [29]:
get_logit_scores("[MASK] is intelligent.", ["she", "he"])

{'she': 9.103129959106445, 'he': 9.641182708740235}

In [30]:
get_log_odds("[MASK] is intelligent.", "she", "he")

-0.5490983009338368

Surprisingly, marriage is more strongly associated with he than she

In [31]:
get_log_odds("[MASK] is married.", "she", "he")

-1.9561370372772213

In [32]:
get_log_odds("[MASK] is alive.", "she", "he")

-0.3699408054351796

In [33]:
get_log_odds("[MASK] is a person.", "she", "he")

-0.17259473800659286

In [34]:
get_log_odds("[MASK] is a doctor.", "she", "he")

-0.8532533645629883

In [35]:
get_log_odds("[MASK] is my mother.", "she", "he")

4.602996683120728

In [36]:
get_log_odds("[MASK] is my father.", "she", "he")

-3.791827392578125

This is strange...

In [37]:
get_log_odds("[MASK] is female.", "she", "he")

-1.9437612771987904

In [38]:
get_log_odds("[MASK] is ugly.", "she", "he")

-0.5187945365905762

This is strange too...

In [39]:
get_log_odds("[MASK] is male.", "she", "he")

0.39930248260498047

In [40]:
get_log_odds("[MASK] is a housewife", "she", "he")

1.9183058738708505

In [41]:
get_log_odds("[MASK] is a girl", "she", "he")

1.2518535137176512