Simply debias BERT by optimizing the log odds ratio

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim

import pandas as pd
import numpy as np
from pathlib import Path
from typing import *
import matplotlib.pyplot as plt
from overrides import overrides
%matplotlib inline

In [3]:
import sys
sys.path.append("../lib")

In [4]:
from bert_utils import Config, BertPreprocessor
config = Config(
    model_type="bert-base-uncased",
    max_seq_len=128,
    batch_size=64,
    consistency_weight=1.,
    prior_precomputed=True,
    testing=True,
)

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [5]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
T = TypeVar("T")
TensorDict = Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]

In [7]:
processor = BertPreprocessor(config.model_type, config.max_seq_len)

In [8]:
DATA_ROOT = Path("../data")
MODEL_SAVE_DIR = Path("../weights")

Read the model in here

In [9]:
from pytorch_pretrained_bert import BertConfig, BertForMaskedLM
masked_lm = BertForMaskedLM.from_pretrained(config.model_type)
masked_lm.eval()

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1)
            )
          )
          (intermediate): BertIntermediate(
       

# The Dataset

In [10]:
from allennlp.data.token_indexers import PretrainedBertIndexer

def flatten(x: List[List[T]]) -> List[T]:
        return [item for sublist in x for item in sublist]

token_indexer = PretrainedBertIndexer(
    pretrained_model=config.model_type,
    max_pieces=config.max_seq_len,
    do_lowercase=True,
 )

def tokenizer(s: str):
    return token_indexer.wordpiece_tokenizer(s)[:config.max_seq_len - 2]

In [11]:
from allennlp.data.vocabulary import Vocabulary
global_vocab = Vocabulary()

In [12]:
# # record the prior
# with torch.no_grad():
#     bert_input = (self.token_indexers["tokens"]
#                   .tokens_to_indices(input_toks, global_vocab, "tokens"))
#     token_ids = torch.LongTensor(bert_input["tokens"]).unsqueeze(0)

### Dataset

In [13]:
import csv
from allennlp.data import DatasetReader, Instance, Token
from allennlp.data.fields import (TextField, SequenceLabelField, LabelField, 
                                  MetadataField, ArrayField)

class LongArrayField(ArrayField):
    @overrides
    def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:
        tensor = torch.from_numpy(self.array)
        return tensor
    
class FloatArrayField(ArrayField):
    @overrides
    def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.FloatTensor:
        tensor = torch.FloatTensor(self.array)
        return tensor

class DebiasingDatasetReader(DatasetReader):
    def __init__(self, tokenizer, token_indexers, 
                 prior_precomputed: bool=False) -> None:
        super().__init__(lazy=False)
        self.tokenizer = tokenizer
        self.token_indexers = token_indexers
        self.vocab = token_indexers["tokens"].vocab
        self._prior_precomputed = prior_precomputed

    def _proc(self, x):
        if x == "[MASK]": return x
        else: return x.lower()
        
    @overrides
    def text_to_instance(self, tokens: List[str], w1: str, w2: str, 
                         p1: Optional[float], p2: Optional[float]) -> Instance:
        fields = {}
        input_toks = [Token(self._proc(x)) for x in tokens]
        fields["input"] = TextField(input_toks, self.token_indexers)        
        # take [CLS] token into account
        mask_position = tokens.index("[MASK]") + 1
        fields["mask_positions"] = LongArrayField(
            np.array(mask_position, dtype=np.int64),
         )
        fields["target_ids"] = LongArrayField(np.array([
            self.vocab[w1], self.vocab[w2],
        ], dtype=np.int64))
                
        if self._prior_precomputed:
            fields["prior_prob_sum"] = FloatArrayField(np.array(p1 + p2, dtype=np.float32))
        else:
            with torch.no_grad():
                bert_input = (self.token_indexers["tokens"]
                              .tokens_to_indices(input_toks, global_vocab, "tokens"))
                token_ids = torch.LongTensor(bert_input["tokens"]).unsqueeze(0)
                probs = masked_lm(token_ids)[0, mask_position, :].detach().numpy()
                probs = (probs - probs.max())
                probs = probs.exp() / probs.exp().sum()
                fields["prior_prob_sum"] = \
                    FloatArrayField(np.array(probs[self.vocab[w1]] + probs[self.vocab[w2]],
                               dtype=np.float32))
            
        return Instance(fields)
    
    @overrides
    def _read(self, file_path: str) -> Iterator[Instance]:
        p1, p2 = 0., 0.
        with open(file_path, "rt") as f:
            reader = csv.reader(f)
            for row in reader:
                if self._prior_precomputed: sentence, w1, w2, p1, p2 = row
                else: sentence, w1, w2 = row
                yield self.text_to_instance(
                    self.tokenizer(sentence), 
                    w1, w2, # words
                    float(p1), float(p2), # prior probs
                )

In [14]:
reader = DebiasingDatasetReader(tokenizer=tokenizer, 
                                token_indexers={"tokens": token_indexer},
                                prior_precomputed=config.prior_precomputed)
train_ds, val_ds = (reader.read(DATA_ROOT / fname) for fname in ["sample_w_probs.csv", "sample_w_probs.csv"])

10000it [00:00, 16097.23it/s]
10000it [00:00, 18546.61it/s]


### Data Iterator

In [15]:
from allennlp.data.iterators import BucketIterator

iterator = BucketIterator(
        batch_size=config.batch_size, 
        biggest_batch_first=config.testing,
        sorting_keys=[("input", "num_tokens")],
    )
iterator.index_with(global_vocab)

Sanity check

In [16]:
batch = next(iter(iterator(train_ds)))

In [17]:
batch

{'input': {'tokens': tensor([[  101,   103,  2003,  1037,  2160, 28478,   102],
          [  101,   103,  2003,  1037,  2160, 28478,   102],
          [  101,   103,  2003,  1037,  2160, 28478,   102],
          [  101,   103,  2003,  1037,  2160, 28478,   102],
          [  101,   103,  2003,  1037,  2160, 28478,   102],
          [  101,   103,  2003,  1037,  2160, 28478,   102],
          [  101,   103,  2003,  1037,  2160, 28478,   102],
          [  101,   103,  2003,  1037,  2160, 28478,   102],
          [  101,   103,  2003,  1037,  2160, 28478,   102],
          [  101,   103,  2003,  1037,  2160, 28478,   102],
          [  101,   103,  2003,  1037,  2160, 28478,   102],
          [  101,   103,  2003,  1037,  2160, 28478,   102],
          [  101,   103,  2003,  1037,  2160, 28478,   102],
          [  101,   103,  2003,  1037,  2160, 28478,   102],
          [  101,   103,  2003,  1037,  2160, 28478,   102],
          [  101,   103,  2003,  1037,  2160, 28478,   102]]),
  '

# Model and Loss

### The loss function

In [18]:
def mse_loss(x, y): return ((x - y) ** 2).mean()
def mae_loss(x, y): return (x - y).abs().mean()
class HingeLoss(nn.Module):
    def __init__(self, margin: float=0.1):
        super().__init__()
        self.margin = margin
    def forward(self, x, y):
        return torch.relu((x - y).abs().mean() - self.margin)

In [19]:
def _log_likelihood(logits, target_logits) -> torch.FloatTensor:
    max_logits = logits.max(1, keepdim=True)[0]
    log_exp_sum_logits = (logits - max_logits).exp().sum(1).log()
    log_exp_sum_correct_logits = (target_logits - max_logits).exp().sum(1).log()
    return log_exp_sum_logits - log_exp_sum_correct_logits

def likelihood(logits, # (batch, V)
               target_logits, # (batch, 2)
               prior_prob_sum, # (batch, )
     ):
    """log likelihood of either of the target ids being chosen"""
    return _log_likelihood(logits, target_logits).mean()

class Consistency(nn.Module):
    def __init__(self, distance: Callable):
        super().__init__()
        self._distance = distance
    
    def forward(self, logits, # (batch, V)
                target_logits, # (batch, 2)
                prior_prob_sum, # (batch, )
               ):
        """
        Constrains prob sum put on two words to be roughly equal
        TODO: Provide some probabilistic/statistical interpretation
        """
        l = _log_likelihood(logits, target_logits)
        return self._distance(l, prior_prob_sum.log())

In [20]:
class BiasLoss(nn.Module):
    """
    Returns the deviation of the log odds ratio from its desired value.
    Denoting the probs as p and q there are several options available:
        - MSE(log p, log q)
        - Max-margin loss
    TODO: Add option to set the optimal log odds ratio
    TODO: Ensure the logits do not change significantly
    """
    def __init__(self, loss_func: Callable=mae_loss,
                 consistency_loss_func: Callable=likelihood,
                 consistency_weight: float=1.):
        super().__init__()
        self.loss_func = loss_func
        self._consistency_loss = consistency_loss_func
        self.consistency_weight = consistency_weight
        
    def forward(self, logits: torch.FloatTensor, # (batch, seq, V)
                mask_positions: torch.LongTensor, # (batch, )
                target_ids: torch.LongTensor, # (batch, 2)
                prior_prob_sum: torch.FloatTensor, # (batch, )
               ) -> torch.FloatTensor:
        """
        input_ids: Numericalized tokens
        mask_position: Positions of mask tokens
        target_ids: Ids of target tokens to compute log odds on
        """
        bs, seq = logits.size(0), logits.size(1)

        # Gather the logits for at the masked positions
        # TODO: More efficient implementation?
        # Gather copies the data to create a new tensor which we would rather avoid
        sel = (mask_positions.unsqueeze(1)
                .unsqueeze(2).expand(bs, 1, logits.size(2))) # (batch, 1, V)
        logits_at_masked_positions = logits.gather(1, sel).squeeze(1) # (batch, V)
        
        # Gather the logits for the target ids
        sel = target_ids
        target_logits_at_masked_positions = logits_at_masked_positions.gather(1, sel).squeeze(1) # (batch, 2)
        
        bias_loss = self.loss_func(
            target_logits_at_masked_positions[:, 0], # male logits
            target_logits_at_masked_positions[:, 1], # female logits
         )
        consistency_loss = self._consistency_loss(
            logits_at_masked_positions, 
            target_logits_at_masked_positions, # pass target logits
            prior_prob_sum,
         )
        return bias_loss + consistency_loss * self.consistency_weight

### The allennlp model (for training)

In [21]:
from allennlp.models import Model

class BERT(Model):
    def __init__(self, vocab, bert_for_masked_lm, loss: nn.Module=BiasLoss()):
        super().__init__(vocab)
        self.bert_for_masked_lm = bert_for_masked_lm
        self.loss = loss
    
    def forward(self, 
                input: TensorDict,
                mask_positions: torch.LongTensor,
                target_ids: torch.LongTensor,
                prior_prob_sum: torch.FloatTensor,
            ) -> TensorDict:
        logits = self.bert_for_masked_lm(input["tokens"])
        out_dict = {"loss": self.loss(logits, mask_positions, 
                                      target_ids, prior_prob_sum)}
        out_dict["logits"] = logits
        return out_dict

In [22]:
logit_distance = mae_loss

loss = BiasLoss(
    loss_func=logit_distance,
    consistency_loss_func=Consistency(logit_distance),
    consistency_weight=config.consistency_weight,
)
model = BERT(global_vocab, masked_lm, loss=loss)

In [23]:
init_dict = dict(model.state_dict())

In [24]:
model.load_state_dict(init_dict)

### Bias scores before

In [25]:
rev_vocab = {v: k for k, v in token_indexer.vocab.items()}

def ttoi(t: str): return token_indexer.vocab[t]
def itot(i: int): return rev_vocab[i]

In [26]:
masked_lm.eval()
logits = masked_lm(processor.to_bert_model_input("[MASK] is a housemaid"))[0, 1]

In [27]:
logits[ttoi("he")]

tensor(8.1598, grad_fn=<SelectBackward>)

In [28]:
logits[ttoi("she")]

tensor(8.8144, grad_fn=<SelectBackward>)

# Training Loop

In [29]:
from allennlp.training import Trainer

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_ds,
    validation_dataset=val_ds,
    serialization_dir=None,
    cuda_device=0 if torch.cuda.is_available() else -1,
    num_epochs=1,
)

You provided a validation dataset but patience was set to None, meaning that early stopping is disabled


In [30]:
trainer.train()

loss: 1.1806 ||: 100%|██████████| 157/157 [11:31<00:00,  4.60s/it]
loss: 1.0594 ||: 100%|██████████| 157/157 [02:57<00:00,  1.25s/it]


{'best_epoch': 0,
 'peak_cpu_memory_MB': 1209.479168,
 'training_duration': '00:14:29',
 'training_start_epoch': 0,
 'training_epochs': 0,
 'epoch': 0,
 'training_loss': 1.180585659233628,
 'training_cpu_memory_MB': 1209.479168,
 'validation_loss': 1.059408180083439,
 'best_validation_loss': 1.059408180083439}

# Evaluate

Simple prediction

In [31]:
def to_np(t): return t.detach().cpu().numpy()

In [32]:
def to_words(arr):
    if len(arr.shape) > 1:
        return [to_words(a) for a in arr]
    else:
        arr = to_np(arr)
        return " ".join([itot(i) for i in arr])

In [33]:
def get_preds(model, batch: TensorDict):
    return model(**batch)["logits"].argmax(2)

In [34]:
to_words(batch["input"]["tokens"])

['[CLS] [MASK] is a house ##maid [SEP]',
 '[CLS] [MASK] is a house ##maid [SEP]',
 '[CLS] [MASK] is a house ##maid [SEP]',
 '[CLS] [MASK] is a house ##maid [SEP]',
 '[CLS] [MASK] is a house ##maid [SEP]',
 '[CLS] [MASK] is a house ##maid [SEP]',
 '[CLS] [MASK] is a house ##maid [SEP]',
 '[CLS] [MASK] is a house ##maid [SEP]',
 '[CLS] [MASK] is a house ##maid [SEP]',
 '[CLS] [MASK] is a house ##maid [SEP]',
 '[CLS] [MASK] is a house ##maid [SEP]',
 '[CLS] [MASK] is a house ##maid [SEP]',
 '[CLS] [MASK] is a house ##maid [SEP]',
 '[CLS] [MASK] is a house ##maid [SEP]',
 '[CLS] [MASK] is a house ##maid [SEP]',
 '[CLS] [MASK] is a house ##maid [SEP]']

In [35]:
to_words(get_preds(model, batch))

['. he she she he she .',
 '. he she she he she .',
 '. he she she he she .',
 '. he she she he she .',
 '. he she she he she .',
 '. he she she he she .',
 '. he she she he she .',
 '. he she she he she .',
 '. he she she he she .',
 '. he she she he she .',
 '. he she she he she .',
 '. he she she he she .',
 '. he she she he she .',
 '. he she she he she .',
 '. he she she he she .',
 '. he she she he she .']

### Logits and bias

In [36]:
masked_lm.eval()
logits = masked_lm(processor.to_bert_model_input("[MASK] is a housemaid"))[0, 1]

In [37]:
logits[ttoi("he")]

tensor(12.8591, grad_fn=<SelectBackward>)

In [38]:
logits[ttoi("she")]

tensor(12.8306, grad_fn=<SelectBackward>)

# Export Weights

As PyTorch state dict

In [39]:
torch.save(masked_lm.state_dict(), MODEL_SAVE_DIR / "state_dict.pth")

TODO: Export as tensorflow checkpoint?