Simply debias BERT by optimizing the log odds ratio

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import pandas as pd
import numpy as np
from pathlib import Path
from typing import *
import matplotlib.pyplot as plt
from overrides import overrides
%matplotlib inline

In [None]:
import sys
sys.path.append("../lib")

In [None]:
from bert_utils import Config, BertPreprocessor
# Shared run settings for the notebook: the pretrained model name, the
# wordpiece budget per sentence and the batch size. `testing` is also what
# sets `biggest_batch_first` on the BucketIterator further down.
config = Config(
    model_type="bert-base-uncased",
    max_seq_len=128,
    batch_size=64,
    testing=True,
)

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Generic element type used in the signature of the `flatten` helper.
T = TypeVar("T")
# A batch as handed to the model: field name -> tensor, or field name -> a
# nested dict of tensors (e.g. batch["input"]["tokens"]).
TensorDict = Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]

In [None]:
# Project-local preprocessor (from bert_utils) built for the same model name and
# sequence length as the rest of the notebook.
# NOTE(review): not referenced again in the cells shown here — confirm it is still needed.
processor = BertPreprocessor(config.model_type, config.max_seq_len)

In [None]:
# Directory the CSV data files (e.g. sample.csv) are read from, relative to the notebook.
DATA_ROOT = Path("../data")

# The Dataset

In [None]:
from allennlp.data.token_indexers import PretrainedBertIndexer

def flatten(x: List[List[T]]) -> List[T]:
    """Concatenate the inner lists of ``x`` into a single list, keeping their order."""
    flat = []
    for sublist in x:
        flat.extend(sublist)
    return flat

# Indexer for the pretrained BERT wordpiece vocabulary. It is shared by the
# `tokenizer` helper (which calls its wordpiece tokenizer) and by the dataset
# reader (which looks target words up in its `vocab`).
token_indexer = PretrainedBertIndexer(
    pretrained_model=config.model_type,
    max_pieces=config.max_seq_len,
    do_lowercase=True,
 )

def tokenizer(s: str):
    """Split ``s`` into wordpieces and truncate them to ``config.max_seq_len - 2``.

    The two reserved slots are presumably for the [CLS] and [SEP] tokens that
    are added around the sentence later — TODO confirm.
    """
    pieces = token_indexer.wordpiece_tokenizer(s)
    max_pieces = config.max_seq_len - 2
    return pieces[:max_pieces]

### Dataset

In [None]:
import csv
from allennlp.data import DatasetReader, Instance, Token
from allennlp.data.fields import (TextField, SequenceLabelField, LabelField, 
                                  MetadataField, ArrayField)

class LongArrayField(ArrayField):
    """ArrayField whose tensor is created straight from the wrapped numpy array.

    ``torch.from_numpy`` keeps the array's dtype, so an int64 array becomes a
    long tensor. ``padding_lengths`` is ignored: no padding is applied.
    """
    @overrides
    def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:
        return torch.from_numpy(self.array)

class DebiasingDatasetReader(DatasetReader):
    """Reads ``sentence, w1, w2`` CSV rows into instances for the bias loss.

    Every sentence must contain a ``[MASK]`` token; ``w1`` and ``w2`` are the
    two target words whose scores at the masked position are compared.
    """
    def __init__(self, tokenizer, token_indexers) -> None:
        super().__init__(lazy=False)
        self.tokenizer = tokenizer
        self.token_indexers = token_indexers
        # Target words are converted to ids with the "tokens" indexer's vocabulary.
        self.vocab = token_indexers["tokens"].vocab

    def _proc(self, x):
        """Lowercase a token, leaving the ``[MASK]`` marker untouched."""
        return x if x == "[MASK]" else x.lower()

    @overrides
    def text_to_instance(self, tokens: List[str], w1: str, w2: str) -> Instance:
        """Build one instance from a tokenized sentence and its two target words.

        Raises:
            ValueError: if ``tokens`` contains no ``[MASK]`` token.
            KeyError: if ``w1`` or ``w2`` is not in the vocabulary.
        """
        try:
            mask_index = tokens.index("[MASK]")
        except ValueError:
            # Same exception type as before, but say what went wrong: the mask
            # may be absent from the sentence or cut off by truncation.
            raise ValueError(
                f"no [MASK] token in {tokens!r} (missing from the sentence, "
                "or removed by truncation)"
            ) from None

        fields = {}
        fields["input"] = TextField([Token(self._proc(x)) for x in tokens],
                                    self.token_indexers)
        # take [CLS] token into account
        fields["mask_positions"] = LongArrayField(
            np.array(mask_index + 1, dtype=np.int64))
        fields["target_ids"] = LongArrayField(
            np.array([self.vocab[w1], self.vocab[w2]], dtype=np.int64))
        return Instance(fields)

    @overrides
    def _read(self, file_path: str) -> Iterator[Instance]:
        """Yield one instance per non-empty CSV row of ``file_path``."""
        # newline="" lets the csv module do its own newline handling, as its
        # documentation requires for file objects passed to csv.reader.
        with open(file_path, "rt", newline="") as f:
            for row in csv.reader(f):
                if not row:
                    # csv.reader yields [] for blank lines (e.g. a trailing
                    # newline); skip them instead of failing the unpack below.
                    continue
                sentence, w1, w2 = row
                yield self.text_to_instance(
                    self.tokenizer(sentence), w1, w2,
                )

In [None]:
reader = DebiasingDatasetReader(
    tokenizer=tokenizer,
    token_indexers={"tokens": token_indexer},
)
# Training and validation currently read the same sample file.
train_ds = reader.read(DATA_ROOT / "sample.csv")
val_ds = reader.read(DATA_ROOT / "sample.csv")

### Data Iterator

In [None]:
from allennlp.data.iterators import BucketIterator
from allennlp.data.vocabulary import Vocabulary

# Batch iterator that sorts instances by the number of tokens in their "input"
# field; batch size and biggest-batch-first both come from `config`.
iterator = BucketIterator(
        batch_size=config.batch_size, 
        biggest_batch_first=config.testing,
        sorting_keys=[("input", "num_tokens")],
    )
# The iterator is indexed with a fresh, empty Vocabulary.
# NOTE(review): this Vocabulary is not bound to a name, so later cells cannot reuse it.
iterator.index_with(Vocabulary())

Sanity check

In [None]:
# Pull the first batch from the training set so its tensors can be inspected.
batch = next(iter(iterator(train_ds)))

In [None]:
# Display the batch as the cell output.
batch

# Model and Loss

### The model

In [None]:
from pytorch_pretrained_bert import BertConfig, BertForMaskedLM
# Load the pretrained masked-language-model checkpoint named by config.model_type.
masked_lm = BertForMaskedLM.from_pretrained(config.model_type)

Sanity check

In [None]:
# Run the sample batch through the model and show the output shape
# (expected to be (batch, seq, V), the layout BiasLoss.forward documents for logits).
masked_lm(batch["input"]["tokens"]).shape

### The loss function

In [None]:
def mse_loss(x, y):
    """Return the mean squared difference between ``x`` and ``y``."""
    return ((x - y) ** 2).mean()


class BiasLoss(nn.Module):
    """
    Returns the deviation of the log odds ratio from its desired value.

    For each example the logits of two target tokens are read at the masked
    position and compared with ``loss_func``. Denoting the probs as p and q
    there are several options available:
        - MSE(log p, log q)
        - Max-margin loss
    TODO: Add option to set the optimal log odds ratio
    """
    def __init__(self, loss_func: Callable=mse_loss):
        super().__init__()
        self.loss_func = loss_func

    def forward(self, logits: torch.FloatTensor, # (batch, seq, V)
                mask_positions: torch.LongTensor, # (batch, )
                target_ids: torch.LongTensor, # (batch, 2)
               ) -> torch.FloatTensor:
        """
        logits: Vocabulary scores for every position of every sequence
        mask_positions: Position of the mask token in each sequence
        target_ids: Ids of the two target tokens to compute log odds on

        Returns ``loss_func(first_target_logits, second_target_logits)``, where
        both arguments have shape (batch, ).
        """
        # Pick the masked position of every sequence first, so only a
        # (batch, V) slice is touched instead of gathering the target logits
        # for all `seq` positions and discarding most of them.
        batch_index = torch.arange(logits.size(0), device=logits.device)
        logits_at_mask = logits[batch_index, mask_positions] # (batch, V)

        # Then keep only the two target tokens' logits.
        target_logits_at_masked_positions = logits_at_mask.gather(1, target_ids) # (batch, 2)

        return self.loss_func(
            target_logits_at_masked_positions[:, 0], # male logits
            target_logits_at_masked_positions[:, 1], # female logits
        )

### The allennlp model (for training)

In [None]:
from allennlp.models import Model

class BERT(Model):
    """AllenNLP model that wraps a BERT masked LM and trains it with BiasLoss."""
    def __init__(self, vocab, bert_for_masked_lm):
        super().__init__(vocab)
        self.bert_for_masked_lm = bert_for_masked_lm
        self.loss = BiasLoss()

    def forward(self,
                input: TensorDict,
                mask_positions: torch.LongTensor,
                target_ids: torch.LongTensor,
            ) -> TensorDict:
        """Run the masked LM on ``input["tokens"]`` and compute the bias loss.

        Returns a dict with the scalar ``"loss"`` and the raw ``"logits"``.
        """
        input_ids = input["tokens"]
        # Batches are padded to a common length. Without an attention mask the
        # model attends to the padding as if it were real tokens.
        # Assumes padding uses id 0 (the [PAD] id in the BERT vocabulary) — TODO confirm.
        attention_mask = (input_ids != 0).long()
        logits = self.bert_for_masked_lm(input_ids, attention_mask=attention_mask)
        out_dict = {"loss": self.loss(logits, mask_positions, target_ids)}
        out_dict["logits"] = logits
        return out_dict

In [None]:
# `vocab` was never defined in this notebook (the iterator was indexed with an
# unnamed `Vocabulary()`), so build an equivalent empty vocabulary here.
model = BERT(Vocabulary(), masked_lm)

# Training Loop

In [None]:
from allennlp.training import Trainer

# Optimize the masked LM's parameters directly.
# NOTE(review): lr=0.01 is far above the ~2e-5 to 5e-5 range usually used when
# fine-tuning BERT — confirm this value is intentional.
optimizer = torch.optim.Adam(masked_lm.parameters(), lr=0.01)

# One epoch over train_ds with validation on val_ds; nothing is serialized to
# disk (serialization_dir=None). Uses GPU 0 when CUDA is available, else CPU (-1).
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_ds,
    validation_dataset=val_ds,
    serialization_dir=None,
    cuda_device=0 if torch.cuda.is_available() else -1,
    num_epochs=1,
)

In [None]:
# Run the training loop configured above.
trainer.train()

# Evaluate