# Entities as Experts

This notebook is a code implementation of the paper "Entities as Experts: Sparse Memory Access with Entity Supervision" by Févry, Baldini Soares, FitzGerald, Choi, Kwiatowski.

## Problem definition and high-level model description

We want to perform question answering on typical one-shot questions that require external knowledge or context. For example, in order to answer the question "Which country was Charles Darwin born in?" one needs some text providing answers on typical structured scenarios.

In this case, however, we want to rely on knowledge-graph extracted information. For example, in the question given here, we can prune out unrelated to the antropologist and evolution theorist Charles Darwins, e.g. Charles River, Darwin City etc. 

In the paper, the authors propose to augment BERT in the task of cloze-type question answering by leveraging an Entity Memory extracted from e.g. a Knoweldge Graph.

![Entity as Experts description](images/eae_highlevel.png)

The Entity Memory is a simple bunch of embeddings of entities extracted from a Knowledge Graph. Relationships are ignored (see the Facts as Experts paper and notebook to see how they could be used).

## Datasets

> We assume access to a corpus $D={(xi,mi)}$,where all entity mentions are detected but not necessarily  all  linked  to  entities.   We  use  English Wikipedia as our corpus, with a vocabulary of 1m entities. Entity links come from hyperlinks, leading to 32m 128 byte contexts containing 17m entity links.

In the appendix B, it is explained that:

> We build our training corpus of contexts paired with entity mention labels from the 2019-04-14 dump of English Wikipedia. We first divide each article into chunks of 500 bytes,resulting in a corpus of 32 million contexts withover 17 million entity mentions. We restrict our-selves  to  the  one  million  most  frequent  entities
(86% of the linked mentions).

Given that the dump 2019-04-14 is not available at the time of writing, we will adopt the revision 2020-11-01.

Entities are thus partially extracted by link annotations (e.g. they associate with each token a mention if that token belongs to a wikipedia url).

## Mention Detection

> In addition to the Wikipedia links, we annotaten each sentence with unlinked mention spans using the mention detector from Section 2.2

The mention detection head discussed in Section 2.2 is a simple BIO sequence: each token is annotated with a B (beginning), I (inside) or O (outside) if they are respectivelly beginning, inside or outside of a mention. The reason why we use both BIO and EL is to avoid inconsistencies.

There is a catch. In the paper, they explain they used Google NLP APIs to perform entity detection and linking on large-scale Wikipedia entries, that is, to have a properly annotated Wikipedia dataset.

Since we cannot technically afford this, we will use spacy's entity detection and linking capabilities as a baseline. Data quality 

## Chunking

- In theory we should split articles by chunks of 500 bytes (assuming unicode encoding), and contexts are only 128 tokens long. For simplicity by now we only limit ourselves to the first paragraph only.

## Tokenization:

- BERT Tokenizer (e.g. Wordpiece) using lowercase vocabulary, limited to 128 distinct word-piece tokens.

## Learning hyperparameters

For pretraining:

> We use ADAM with a learning rate of 1e-4.  We apply warmup for the first 5% of training, decaying the learning rate afterwards.  We also apply gradient clipping with a norm of 1.0

Since the decaying rate is not provided, we test with 3e-5 which seems quite standard.

## Evaluation

To evaluate:

- TriviaQA
- MetaQA
- (Colla?)

#### Wikipedia

In [8]:
from tools.dataloaders import WikipediaCBOR, BIO

In [9]:
wikipedia_cbor = WikipediaCBOR("wikipedia/car-wiki2020-01-01/enwiki2020.cbor", "wikipedia/car-wiki2020-01-01/partitioned",
                                        num_partitions=100,
                                        max_entity_num=100_000,
                                        token_length=128
                                    )

Using cache found in /root/.cache/torch/hub/huggingface_pytorch-transformers_master


Loaded from cache


In [10]:
wikipedia_cbor[100]

(tensor([ 1999,  2375,  1010,  2019, 10061,  2003,  1037,  4766,  4861,  2008,
          3397,  1996,  2087,  2590,  2685,  1997,  1037,  2146,  3423,  6254,
          2030,  1997,  2195,  3141,  3423,  4981,  1012,  1996, 10061,  1997,
          2516,  1010,  2109,  1999,  2613,  3776, 11817,  1010,  2003,  1996,
          2062,  2691,  2433,  1997, 10061,  1012,  2019, 10061,  1997,  2516,
          7201,  2035,  1996,  5608,  1997,  1037,  3538,  1997,  2455,  1010,
          1037,  2160,  1010,  2030,  1037,  2311,  2077,  2009,  2234,  2046,
          6664,  1997,  1996,  2556,  3954,  1012,  1996, 10061,  2036,  2636,
          2035, 15046,  1055,  1010, 17234,  1010, 14344,  2015,  1010,  1998,
          2060,  5491,  2008,  7461,  6095,  1997,  1996,  3200,  1012,  2019,
         10061,  5577,  1037,  4677,  1997, 15210,  2013,  3954,  2000,  3954,
          1998,  2151, 10540,  2011,  2280,  5608,  2008,  2024,  8031,  2006,
          2101,  5608,  1012,  1037,  3154,  2516,  

In [11]:
import numpy as np

In [12]:
bio_dataset = BIO("ner.csv", 75)

Using cache found in /root/.cache/torch/hub/huggingface_pytorch-transformers_master


In [13]:
FULL_FINETUNING=False

In [14]:
# Frankly this code looks horrible - need to delve into pytorch's dataloader tools API
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, random_split

bs = 32

#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"
n_gpu = torch.cuda.device_count()
torch.cuda.get_device_name(0)

if not FULL_FINETUNING:
    # use only 0.1%  of Wikipedia's dataset
    wiki_use_size = int(0.001 * len(wikipedia_cbor))
    wikipedia_cbor_limited, _ = random_split(wikipedia_cbor,
                                             [wiki_use_size, len(wikipedia_cbor) - wiki_use_size],
                                             generator=torch.Generator().manual_seed(42))
    
    wiki_train_size = int(0.8*len(wikipedia_cbor_limited))
    wiki_validation_size = len(wikipedia_cbor_limited) - wiki_train_size
    
    wikipedia_cbor_train, wikipedia_cbor_validation = random_split(wikipedia_cbor_limited,
                                                                   [wiki_train_size, wiki_validation_size], generator=torch.Generator().manual_seed(42))

else:
    wiki_train_size = int(0.8*len(wikipedia_cbor))
    wiki_validation_size = len(wikipedia_cbor) - wiki_train_size

    wikipedia_cbor_train, wikipedia_cbor_validation = random_split(wikipedia_cbor,
                                                                   [wiki_train_size, wiki_validation_size], generator=torch.Generator().manual_seed(42))

wiki_train_sampler = RandomSampler(wikipedia_cbor_train)
wiki_train_dataloader = DataLoader(wikipedia_cbor_train, sampler=wiki_train_sampler, batch_size=bs)

wiki_validation_sampler = RandomSampler(wikipedia_cbor_validation)
wiki_validation_dataloader = DataLoader(wikipedia_cbor_validation, sampler=wiki_validation_sampler, batch_size=bs)

In [15]:
bio_data = bio_dataset.get_pytorch_dataset()

bio_train_size = int(0.7 * len(bio_data))
bio_validation_size = len(bio_data) - bio_train_size

bio_train, bio_validation = random_split(bio_data, [bio_train_size, bio_validation_size])

bio_train_sampler = RandomSampler(bio_train)
bio_train_dataloader = DataLoader(bio_train, sampler=bio_train_sampler, batch_size=bs)

validation_sampler = RandomSampler(bio_validation)
validation_dataloader = DataLoader(bio_validation, sampler=validation_sampler, batch_size=bs)

## Model

In the paper, the authors explain they used a modified BERT.

In [16]:
for x in bio_train_dataloader:
    print(x[0].dtype)
    print(x[1].dtype)
    print(x[2].dtype)
    
    break

torch.int64
torch.float32
torch.int64


In [17]:
for x in wiki_train_dataloader:
    print(x[0].dtype)
    print(x[1].dtype)
    print(x[2].dtype)
    
    break

torch.int64
torch.float32
torch.int64


### Load and finetune the model

In [18]:
from torch.nn import Module, Linear, Dropout
from transformers.modeling_bert import BertEncoder, BertModel, BertForTokenClassification
from copy import deepcopy

class TruncatedEncoder(Module):
    def __init__(self, encoder: BertEncoder, l0: int):
        super().__init__()
        __doc__ = encoder.__doc__
        self.encoder = deepcopy(encoder)
        self.encoder.layer = self.encoder.layer[:l0]
        
        
    def forward(self, *args, **kwargs):
        __doc__ = self.encoder.forward.__doc__
        return self.encoder(*args, **kwargs)

class TruncatedModel(Module):
    def __init__(self, model: BertModel, l0: int = 4):
        super().__init__()
        self.model = deepcopy(model)
        self.model.encoder = TruncatedEncoder(self.model.encoder, l0)
    
    def forward(self, *args, **kwargs):
        __doc__ = self.model.forward.__doc__
        return self.model(*args, **kwargs)

class BioClassifier(Module):
    """
    BIO classifier head
    """
    def __init__(self,  bertmodel: TruncatedModel):
        super().__init__()
        self.bert = bertmodel
        self.dropout = Dropout(p=0.1)
        self.classifier = Linear(in_features=768, out_features=4, bias=True)
        self.num_labels = 4
    
    def forward(self, *args, **kwargs):
        return BertForTokenClassification.forward(self, *args, **kwargs)
    
class LinkPredictorClassifier(Module):
    """
    Link predictor classifier head
    """
    def __init__(self,  bertmodel: TruncatedModel):
        super().__init__()
        self.bert = bertmodel
        self.dropout = Dropout(p=0.1)
        self.classifier = Linear(in_features=768, out_features=wikipedia_cbor.max_entity_num, bias=True)
        self.num_labels = wikipedia_cbor.max_entity_num
    
    def forward(self, *args, **kwargs):
        return BertForTokenClassification.forward(self, *args, **kwargs)


In [19]:
model = torch.hub.load('huggingface/pytorch-transformers',
                            'model', 'bert-base-uncased')
 
common_model = TruncatedModel(model)
bioclassifier = BioClassifier(common_model).to(device)
linkpredictorclassifier = LinkPredictorClassifier(common_model).to(device)

Using cache found in /root/.cache/torch/hub/huggingface_pytorch-transformers_master


In [22]:
#FULL_FINETUNING = True

def get_optimizer(finetuning_level, model):
    """
    Get an optimizer
    """
    param_optimizer = list(model.named_parameters())

    from transformers import AdamW

    if FULL_FINETUNING:
        no_decay = ['bias', 'gamma', 'beta']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
             'weight_decay_rate': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
             'weight_decay_rate': 0.0}
        ]
    else:
        optimizer_grouped_parameters = [{'params': [p for n, p in param_optimizer]}]

    return AdamW(
        optimizer_grouped_parameters,
        lr=1e-4,
        eps=1e-8
    )

In [23]:
optimizer_bio = get_optimizer(True, bioclassifier)
optimizer_lp = get_optimizer(True, linkpredictorclassifier)

In [24]:
from transformers import get_linear_schedule_with_warmup

max_grad_norm = 1.0

def get_schedule(epochs, optimizer, train_dataloader):
    total_steps = len(train_dataloader) * epochs

    return get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.05*total_steps),
        num_training_steps=total_steps
    )

scheduler_bio = get_schedule(10, optimizer_bio, bio_train_dataloader)
scheduler_lp = get_schedule(10, optimizer_lp, wiki_train_dataloader)

In [21]:
from seqeval.metrics import f1_score, accuracy_score

import numpy as np
from torch.nn.utils import clip_grad_norm_

loss_values, validation_loss_values = [], []

for epoch in range(epochs):
    bioclassifier.train()
    total_loss = 0
    
    for step, batch in enumerate(tqdm(bio_train_dataloader)):
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        
        bioclassifier.zero_grad()
            
        
        outputs = bioclassifier(b_input_ids, token_type_ids=None,
                                attention_mask=b_input_mask, labels=b_labels)
        
        # Someone has to explain to me why someone put the loss function inside a module
        loss = outputs[0]
        loss.backward()
        total_loss += loss.item()
        clip_grad_norm_(parameters=bioclassifier.parameters(),
                        max_norm=max_grad_norm)
    
        optimizer.step()
        scheduler.step()
    
    avg_train_loss = total_loss / len(train_dataloader)
    print("Average train loss: {}".format(avg_train_loss))
    
    loss_values.append(avg_train_loss)
    
    model.eval()
    
    eval_loss, eval_accuracy = 0.0, 0.0
    number_eval_steps, number_eval_examples = 0, 0
    predictions, true_labels = [], []
    
    for batch in bio_validation_dataloader:
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        
        with torch.no_grad():
            outputs = bioclassifier(b_input_ids, token_type_ids=None,
                            attention_mask=b_input_mask, labels=b_labels)
        
        logits = outputs[1].detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()
        
        eval_loss += outputs[0].mean().item()
        
        predictions.extend([list(p) for p in np.argmax(logits, axis=2)])
        true_labels.extend(label_ids)
        
    eval_loss = eval_loss / len(validation_dataloader)
    validation_loss_values.append(eval_loss)
    pred_tags = [[bio_values[p_i] for p_i, l_i in zip(p, l)]
                    for p, l in zip(predictions, true_labels)]
    
    true_tags = [[bio_values[l_i] for l_i in l] for l in true_labels]
    
    print(f"Validation Accuracy: {accuracy_score(pred_tags, true_tags)}")
    print(f"Validation F1-Score: {f1_score(pred_tags, true_tags)}")
    print(f"Validation loss: {eval_loss}")
    print()

NameError: name 'epochs' is not defined

In [35]:
for x in wiki_train_dataloader:
    print(x[2].shape)
    break

torch.Size([32, 128])


In [28]:
linkpredictorclassifier.classifier

Linear(in_features=768, out_features=100000, bias=True)

In [33]:
from seqeval.metrics import f1_score, accuracy_score

import numpy as np
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm

loss_values, validation_loss_values = [], []

for epoch in range(10):
    """
    linkpredictorclassifier.train()
    total_loss = 0
    
    for batch in tqdm(wiki_train_dataloader):
        #batch = tuple(t.to(device) for t in batch)
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        
        linkpredictorclassifier.zero_grad()
            
        
        outputs = linkpredictorclassifier(b_input_ids, token_type_ids=None,
                                attention_mask=b_input_mask, labels=b_labels)
        
        # Someone has to explain to me why someone put the loss function inside a module
        loss = outputs[0]
        loss.backward()
        total_loss += loss.item()
        clip_grad_norm_(parameters=linkpredictorclassifier.parameters(),
                        max_norm=max_grad_norm)
    
        optimizer_lp.step()
        scheduler_lp.step()
    
    #avg_train_loss = total_loss / len(wiki_train_dataloader)
    avg_train_loss = total_loss / len(wiki_train_dataloader)
    print("Average train loss: {}".format(avg_train_loss))
    
    loss_values.append(avg_train_loss)
    
    linkpredictorclassifier.eval()
    
    eval_loss, eval_accuracy = 0.0, 0.0
    number_eval_steps, number_eval_examples = 0, 0
    predictions, true_labels = [], []
    """
    
    linkpredictorclassifier.eval()
    
    for step, batch in enumerate(wiki_validation_dataloader):
        print(step)
        
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        
        print(b_labels)
        print(b_labels.shape)
        
        with torch.no_grad():
            outputs = bioclassifier(b_input_ids, token_type_ids=None,
                            attention_mask=b_input_mask, labels=b_labels)
        
        logits = outputs[1].detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()
        
        eval_loss += outputs[0].mean().item()
        
        predictions.extend([list(p) for p in np.argmax(logits, axis=2)])
        true_labels.extend(label_ids)
        
        # This is still quite slow, we need to get it to work fast
        if step >= 100:
            break
        
    eval_loss = eval_loss / len(wiki_validation_dataloader)
    validation_loss_values.append(eval_loss)
    pred_tags = [[bio_values[p_i] for p_i, l_i in zip(p, l)]
                    for p, l in zip(predictions, true_labels)]
    
    true_tags = [[bio_values[l_i] for l_i in l] for l in true_labels]
    
    tqdm.tqdm.write(f"Validation Accuracy: {accuracy_score(pred_tags, true_tags)}")
    tqdm.tqdm.write(f"Validation F1-Score: {f1_score(pred_tags, true_tags)}")
    tqdm.tqdm.write(f"Validation loss: {eval_loss}")
    tqdm.tqdm.write()

0
tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])
torch.Size([32, 128])


IndexError: Target 4920 is out of bounds.

In [23]:
def transform_sentence(sentence: str):
    tokens = tokenizer.tokenize(sentence)
    print(tokens)
    padded = pad_sequences([tokenizer.convert_tokens_to_ids(tokens)], maxlen=MAX_LEN,
                  dtype="long", value=0.0, truncating="post", padding="post")
    
    attention_mask = [[float(tok != 0.0) for tok in padded_] for padded_ in padded]
    
    return padded, attention_mask

# bioclassifier.forward(tokens,)

padded, attention = transform_sentence("Hello world, this is Spongebob!")

bioclassifier.eval()
res = bioclassifier.forward(torch.tensor(padded).to(device), token_type_ids=None, attention_mask=torch.tensor(attention).to(device), labels=None)

['hello', 'world', ',', 'this', 'is', 'sponge', '##bo', '##b', '!']


NameError: name 'bioclassifier' is not defined

In [12]:
from torch.nn import Module, Embedding, Dropout, ModuleList, Linear
import torch.nn as nn
import torch
import math

GELU = torch.nn.GELU
LayerNorm = torch.nn.LayerNorm

l0 = 4
l1 = 8

    
class EntityMemory(Module):
    """
    Entity Memory, as described in the paper
    """
    def __init__(self, embedding_size: int, entity_size: int,
                   entity_embedding_size: int):
        """
        :param embedding_size the size of an embedding. In the EaE paper it is called d_emb, previously as d_k
            (attention_heads * embedding_per_head)
        :param entity_size also known as N in the EaE paper, the maximum number of entities we store
        :param entity_embedding_size also known as d_ent in the EaE paper, the embedding of each entity
        
        """
        self.N = entity_size
        self.d_ent = entity_embedding_size
        self.w_f = Linear(d_ent, 2*embedding_size)
        
    def forward(self, x, entity_spans, num_entities, k=None):
        """
        :param x the (raw) output of the first transformer block. It has a shape:
                B x N x (embed_size)
        :param entity_spans entities and spans of such entities.
                Shape: B x C x 3. Each "row" contains a triple (e_k, s_mi, t_mi)
                where e_k is an (encoded) entity id, s_mi and t_mi are indices.
        :param num_entities the number of found entities for each batch.
        :param k the number of nearest entities to consider when softmax-ing.
                if k = None, all the entities are used.
                In the paper, one should set k for when running the inference
        """
        
        mentions = [entity_spans[:, :mentions_per_batch] for mentions_per_batch in num_entities]
        pass
        