# Entities as Experts

This notebook is a code implementation of the paper "Entities as Experts: Sparse Memory Access with Entity Supervision" by Févry, Baldini Soares, FitzGerald, Choi and Kwiatkowski.

## Problem definition and high-level model description

We want to perform question answering on factoid questions that require external knowledge or context. For example, to answer the question "Which country was Charles Darwin born in?" one needs access to a text or a structured knowledge source that contains the answer.

In this case, however, we want to rely on information extracted from a knowledge graph. For the question above, this lets us prune out entities unrelated to the naturalist and evolutionary theorist Charles Darwin, e.g. the Charles River or the city of Darwin.

In the paper, the authors propose to augment BERT for cloze-type question answering by leveraging an Entity Memory extracted from, e.g., a Knowledge Graph.

![Entity as Experts description](images/eae_highlevel.png)

The Entity Memory is simply a table of embeddings, one per entity extracted from the Knowledge Graph. Relationships are ignored (see the Facts as Experts paper and notebook for how they could be used).

## Datasets

> We assume access to a corpus $D=\{(x_i, m_i)\}$, where all entity mentions are detected but not necessarily all linked to entities. We use English Wikipedia as our corpus, with a vocabulary of 1m entities. Entity links come from hyperlinks, leading to 32m 128 byte contexts containing 17m entity links.

Appendix B explains that:

> We build our training corpus of contexts paired with entity mention labels from the 2019-04-14 dump of English Wikipedia. We first divide each article into chunks of 500 bytes, resulting in a corpus of 32 million contexts with over 17 million entity mentions. We restrict ourselves to the one million most frequent entities (86% of the linked mentions).

Since the 2019-04-14 dump is not available at the time of writing, we use the 2020-11-01 revision instead.

Entity mentions are therefore only partially annotated, through hyperlinks: a token is associated with a mention if it belongs to the anchor text of a Wikipedia link.
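To make the $(x_i, m_i)$ pairs concrete, here is a minimal sketch of how hyperlink annotations could be turned into token-level mention spans. It assumes each link is given as a character span plus an entity title, and it uses a toy whitespace tokenizer instead of WordPiece; `Mention` and `links_to_mentions` are hypothetical names, not part of the `WikipediaCBOR` API.

```python
import re
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class Mention:
    entity: Optional[str]  # linked entity id, or None for an unlinked mention
    start: int             # first token of the mention (s_m)
    end: int               # last token of the mention (t_m), inclusive


def links_to_mentions(text: str, links: List[Tuple[int, int, str]]):
    """Turn hyperlink character spans (start, end_exclusive, entity) into token-level mentions."""
    # Hypothetical whitespace tokenizer with character offsets; the real pipeline uses WordPiece.
    offsets = [(m.start(), m.end()) for m in re.finditer(r"\S+", text)]
    tokens = [text[s:e] for s, e in offsets]
    mentions = []
    for link_start, link_end, entity in links:
        # A token belongs to the link if its characters overlap the anchor text.
        covered = [i for i, (s, e) in enumerate(offsets) if s < link_end and e > link_start]
        if covered:
            mentions.append(Mention(entity, covered[0], covered[-1]))
    return tokens, mentions


text = "Charles Darwin was born in Shrewsbury, England."
links = [(0, 14, "Charles_Darwin"), (27, 37, "Shrewsbury")]
tokens, mentions = links_to_mentions(text, links)
print(mentions)
# [Mention(entity='Charles_Darwin', start=0, end=1), Mention(entity='Shrewsbury', start=5, end=5)]
```

Mentions found by the mention detector but absent from the hyperlinks would be added with `entity=None`.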

## Mention Detection

> In addition to the Wikipedia links, we annotate each sentence with unlinked mention spans using the mention detector from Section 2.2.

The mention detection head discussed in Section 2.2 is a simple BIO sequence tagger: each token is annotated with B (beginning), I (inside) or O (outside) depending on whether it is at the beginning of, inside, or outside a mention. We use both BIO tagging and entity linking (EL) to avoid inconsistencies.
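A minimal sketch of the BIO encoding and its inverse, assuming non-overlapping mentions given as inclusive `(start, end)` token spans; `bio_tags` and `spans_from_bio` are hypothetical helpers, not functions from the notebook's `models` package.

```python
from typing import List, Optional, Tuple


def bio_tags(num_tokens: int, spans: List[Tuple[int, int]]) -> List[str]:
    """Label each token B, I or O from inclusive, non-overlapping (start, end) mention spans."""
    tags = ["O"] * num_tokens
    for start, end in spans:
        tags[start] = "B"
        for i in range(start + 1, end + 1):
            tags[i] = "I"
    return tags


def spans_from_bio(tags: List[str]) -> List[Tuple[int, int]]:
    """Recover inclusive mention spans from a BIO sequence (a stray I opens a new mention)."""
    spans: List[Tuple[int, int]] = []
    start: Optional[int] = None
    for i, tag in enumerate(tags):
        if tag == "B" or (tag == "I" and start is None):
            if start is not None:
                spans.append((start, i - 1))
            start = i
        elif tag == "O" and start is not None:
            spans.append((start, i - 1))
            start = None
    if start is not None:
        spans.append((start, len(tags) - 1))
    return spans


tokens = ["Charles", "Darwin", "was", "born", "in", "Shrewsbury", "."]
tags = bio_tags(len(tokens), [(0, 1), (5, 5)])
print(tags)                  # ['B', 'I', 'O', 'O', 'O', 'B', 'O']
print(spans_from_bio(tags))  # [(0, 1), (5, 5)]
```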

There is a catch. In the paper, the authors explain that they used the Google NLP APIs to perform entity detection and linking on Wikipedia at scale, that is, to obtain a properly annotated Wikipedia dataset. We are going to use plain Wikipedia hyperlinks instead (TODO: consider adding spaCy annotation?).

NOTE FOR MYSELF: we don't *actually* perform entity linking here, that is, we don't train a classifier on top of the first $l_0$ layers. Instead, we build pseudo entity embeddings and look for the entity that best matches each pseudo embedding (see later).

HOWEVER, we do need entity linking supervision during training. The whole idea is that supervised data is available at training time, but this is not always the case for downstream datasets (see: TriviaQA, WebQuestions, ...).
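The matching step mentioned in the note can be sketched as a dot-product search over the entity embedding table. This is a NumPy sketch under our own assumptions: `entity_embeddings` stands for the `EntEmbed` matrix (one row per entity), `h` for a single pseudo entity embedding, and `best_matching_entities` is a hypothetical helper.

```python
import numpy as np


def best_matching_entities(h: np.ndarray, entity_embeddings: np.ndarray, k: int = 1):
    """Return the indices and scores of the k entities with the largest dot product with h."""
    scores = entity_embeddings @ h      # one score per entity, shape (N,)
    top = np.argsort(-scores)[:k]       # indices sorted by decreasing score
    return top, scores[top]


entity_embeddings = np.array([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
h = np.array([0.2, 1.0])
top, scores = best_matching_entities(h, entity_embeddings, k=2)
print(top)  # [1 2]
```

For a memory of around a million entities, `np.argpartition` would avoid the full sort; `argsort` keeps the sketch short.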

## Entity Memory

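The idea is fairly simple: we take as input $X_1$ (the hidden states after the first $l_0$ transformer layers) and the mention spans $m_i = (e_i, s_{m_i}, t_{m_i})$, where $e_i$ is the linked entity and $s_{m_i}$, $t_{m_i}$ are the first and last tokens of the mention. We don't need $e_i$ to compute the embeddings, but we DO need it for the loss.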

(Glossing over the Entity Memory calculation...)
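For reference, here is a minimal NumPy sketch of the Entity Memory layer as we read it from Section 2 of the paper: the pseudo entity embedding is a projection $W_f$ of the concatenated hidden states at the mention's first and last tokens, attention over `EntEmbed` yields a weighted entity embedding, and a projection $W_b$ writes it back at the mention's first token; the result is added to $X_1$ and layer-normalized before the remaining $l_1$ layers. Shapes, function names, random weights and the optional top-k restriction are assumptions for illustration; this is not the `EntitiesAsExperts` implementation in `models`.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    z = z - z.max()  # numerical stability
    e = np.exp(z)
    return e / e.sum()


def layer_norm(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    # Parameter-free layer norm over the hidden dimension (no learned gain/bias).
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def entity_memory(X1, mentions, ent_embed, W_f, W_b, top_k=None):
    """
    X1:        (seq_len, d_emb) hidden states after the first l0 layers.
    mentions:  list of inclusive (start, end) token spans.
    ent_embed: (N, d_ent) entity embedding table (EntEmbed).
    W_f:       (d_ent, 2 * d_emb) projection to pseudo entity embeddings.
    W_b:       (d_emb, d_ent) projection back to the hidden size.
    Returns X2 (zero except at mention starts) and the pseudo embeddings h, shape (M, d_ent).
    """
    X2 = np.zeros_like(X1)
    pseudo = []
    for start, end in mentions:
        h = W_f @ np.concatenate([X1[start], X1[end]])  # pseudo entity embedding h_{m_i}
        scores = ent_embed @ h                           # E · h_{m_i}, shape (N,)
        if top_k is None:
            keep = np.arange(len(ent_embed))             # dense: attend over all entities
        else:
            keep = np.argsort(-scores)[:top_k]           # sparse: only the k best entities
        alpha = softmax(scores[keep])
        e = alpha @ ent_embed[keep]                      # weighted sum of entity embeddings
        X2[start] = W_b @ e                              # written at the mention's first token
        pseudo.append(h)
    return X2, np.array(pseudo).reshape(len(mentions), ent_embed.shape[1])


rng = np.random.default_rng(0)
seq_len, d_emb, d_ent, N = 6, 8, 4, 10
X1 = rng.normal(size=(seq_len, d_emb))
ent_embed = rng.normal(size=(N, d_ent))
W_f = rng.normal(size=(d_ent, 2 * d_emb))
W_b = rng.normal(size=(d_emb, d_ent))
X2, h = entity_memory(X1, [(0, 1), (4, 4)], ent_embed, W_f, W_b)
X3 = layer_norm(X1 + X2)  # residual connection, then the remaining l1 layers
print(X2.shape, h.shape)  # (6, 8) (2, 4)
```

Setting `top_k` to a small number mimics the sparse memory access of the paper's title: only the best-scoring entities contribute to the weighted sum.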

When entity linking is supervised, the pseudo entity embedding $h_{m_i}$ should be close to the embedding of the gold entity $e_{m_i}$.

$$
\mathrm{ELLoss} = \sum_{m_i} \alpha_i \cdot \mathbb{1}_{e_{m_i} \ne e_{\emptyset}}, \qquad \alpha = \operatorname{softmax}(E \cdot h_{m_i})
$$

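($E$ is our `EntEmbed`, with one row per entity, so $E \cdot h_{m_i}$ is a vector of shape $N$, the number of entities, and so is $\alpha$; $h_{m_i}$ is the "pseudo entity embedding" of mention $m_i$, and $e_{\emptyset}$ denotes the absence of a linked entity.)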
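The loss above is written compactly; in code we read it as the usual cross-entropy: for every mention with a gold entity ($e_{m_i} \ne e_{\emptyset}$), take $-\log \alpha$ at the gold entity's index and sum over mentions. This interpretation, the `NO_ENTITY = -1` convention for $e_{\emptyset}$ and the name `el_loss` are our assumptions, sketched in NumPy.

```python
import numpy as np

NO_ENTITY = -1  # hypothetical id standing for e_∅ (a mention without a gold link)


def el_loss(pseudo: np.ndarray, ent_embed: np.ndarray, gold: np.ndarray) -> float:
    """
    pseudo:    (M, d_ent) pseudo entity embeddings h_{m_i}.
    ent_embed: (N, d_ent) entity embedding table E.
    gold:      (M,) gold entity ids, NO_ENTITY for unlinked mentions.
    """
    scores = pseudo @ ent_embed.T                        # (M, N): E · h_{m_i} for each mention
    scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability
    log_alpha = scores - np.log(np.exp(scores).sum(axis=1, keepdims=True))  # log softmax
    linked = np.flatnonzero(gold != NO_ENTITY)           # indicator 1_{e_{m_i} ≠ e_∅}
    return float(-log_alpha[linked, gold[linked]].sum())


ent_embed = np.eye(3)
pseudo = np.array([[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
gold = np.array([0, NO_ENTITY])  # the second mention is unlinked and ignored
print(el_loss(pseudo, ent_embed, gold))  # equals log(e**2 + 2) - 2
```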

## Chunking
- In principle, articles should be split into chunks of 500 bytes (assuming UTF-8 encoding), and contexts are only 128 tokens long. For simplicity, for now we only use the first paragraph of each article.
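If we later move from first paragraphs to the paper's 500-byte chunks, a minimal sketch could look like this. It assumes UTF-8 and cuts only at character boundaries, so a chunk may end mid-word; whether the authors aligned chunks to words or sentences is not stated in the quoted text. `chunk_utf8` is a hypothetical helper.

```python
from typing import List


def chunk_utf8(text: str, max_bytes: int = 500) -> List[str]:
    """Split text into chunks of at most max_bytes UTF-8 bytes, never splitting a character."""
    chunks, current, size = [], [], 0
    for ch in text:
        n = len(ch.encode("utf-8"))  # 1 to 4 bytes per character
        if size + n > max_bytes and current:
            chunks.append("".join(current))
            current, size = [], 0
        current.append(ch)
        size += n
    if current:
        chunks.append("".join(current))
    return chunks


print([len(c.encode("utf-8")) for c in chunk_utf8("é" * 300)])  # [500, 100]
```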

## Tokenization

- BERT's WordPiece tokenizer with the lowercase (uncased) vocabulary; sequences are limited to 128 word-piece tokens.

## Learning hyperparameters

For pretraining:

> We use ADAM with a learning rate of 1e-4. We apply warmup for the first 5% of training, decaying the learning rate afterwards. We also apply gradient clipping with a norm of 1.0.

Since the decay rate is not provided, we test with 3e-5, which seems fairly standard.
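The quoted schedule can be sketched as a learning-rate multiplier. The paper only says the rate decays after warmup, so the linear decay to zero below is an assumption, as is the name `warmup_linear_decay`; the notebook's own `get_schedule` in `models.training` may behave differently.

```python
def warmup_linear_decay(step: int, total_steps: int, warmup_fraction: float = 0.05) -> float:
    """Multiplier applied to the base learning rate (e.g. 1e-4) at a given optimizer step."""
    warmup_steps = max(1, int(warmup_fraction * total_steps))
    if step < warmup_steps:
        # Linear warmup over the first 5% of training.
        return (step + 1) / warmup_steps
    # Linear decay from 1 down to 0 at the last step (assumed decay shape).
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


base_lr = 1e-4
total_steps = 100
learning_rates = [base_lr * warmup_linear_decay(s, total_steps) for s in range(total_steps)]
```

With PyTorch, the multiplier can be plugged into `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: warmup_linear_decay(step, total_steps))`, and gradient clipping corresponds to calling `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)` between `loss.backward()` and `optimizer.step()`. `transformers.get_linear_schedule_with_warmup` offers a similar linear schedule out of the box.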

## Evaluation

To evaluate:

- [X] SQuAD
- [ ] TriviaQA
- [ ] MetaQA
- [ ] WebQuestions
- [ ] Colla?

#### Wikipedia

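```python
import wandb

wandb.init(project="EntitiesAsExperts")
wandb.config.device = "cuda"
# The eae_* hyperparameters used below are expected to be present in wandb.config,
# for instance via a config-defaults.yaml file, which wandb.init loads if present.

from tools.dataloaders import WikipediaCBOR
from models import EntitiesAsExperts, EaEForQuestionAnswering
```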

[34m[1mwandb[0m: Currently logged in as: [33merolm_a[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.17 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


```python
wikipedia_cbor = WikipediaCBOR("wikipedia/car-wiki2020-01-01/enwiki2020.cbor",
                               "wikipedia/car-wiki2020-01-01/partitions",
                               # top 2% most frequent items, roughly at least 100 occurrences,
                               # with a total of ~20000 entities
                               # cutoff_frequency=0.02, recount=True
                               # TODO: is this representative enough?
                               )
```

Loaded from cache


```python
NUM_WORKERS = 16
```

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler

np.random.seed(42)

# Hold out 10% of the corpus for development. Sampling without replacement
# (https://discuss.pytorch.org/t/random-subset-from-dataloader-unique-random-numbers/43430/6)
# guarantees unique indices, so no article is duplicated or shared between splits.
wiki_dev_size = int(0.1 * len(wikipedia_cbor))
wiki_dev_indices = np.random.choice(len(wikipedia_cbor), size=wiki_dev_size, replace=False)

# 80/20 train/validation split
wiki_train_size = int(0.8 * wiki_dev_size)
wiki_validation_size = wiki_dev_size - wiki_train_size

wiki_train_indices = wiki_dev_indices[:wiki_train_size]
wiki_validation_indices = wiki_dev_indices[wiki_train_size:]
wiki_train_sampler = SubsetRandomSampler(wiki_train_indices,
                                         generator=torch.Generator().manual_seed(42))
wiki_validation_sampler = SubsetRandomSampler(wiki_validation_indices,
                                              generator=torch.Generator().manual_seed(42))

wiki_train_dataloader = DataLoader(wikipedia_cbor, sampler=wiki_train_sampler,
                                   batch_size=wandb.config.eae_batch_size,
                                   num_workers=NUM_WORKERS)
wiki_validation_dataloader = DataLoader(wikipedia_cbor,
                                        sampler=wiki_validation_sampler,
                                        batch_size=wandb.config.eae_batch_size,
                                        num_workers=NUM_WORKERS)
# wikipedia_cbor_train, wikipedia_cbor_validation = random_split(wikipedia_
```


## Model

### Load and finetune the model

```python
from transformers import BertForMaskedLM, BertForTokenClassification
from models import EntitiesAsExperts, EaEForQuestionAnswering
from models.training import load_model, save_models, train_model, get_optimizer, get_schedule

model_masked_lm = BertForMaskedLM.from_pretrained('bert-base-uncased')
```

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


```python
from models.device import get_available_device

pretraining_model = EntitiesAsExperts(model_masked_lm,
                                      wandb.config.eae_l0,
                                      wandb.config.eae_l1,
                                      wikipedia_cbor.max_entity_num,
                                      wandb.config.eae_entity_embedding_size).to(get_available_device())

wandb.watch(pretraining_model)
```


```python
pretraining_optimizer = get_optimizer(pretraining_model)
pretraining_schedule = get_schedule(wandb.config.eae_pretraining_epochs,
                                    pretraining_optimizer, wiki_train_dataloader)

# TODO: automatically send the model to the device AND provide on/off switches for it

def wiki_load_batch(batch):
    # Nothing interesting...
    return batch, tuple()

train_model(pretraining_model, wiki_train_dataloader,
            wiki_validation_dataloader, wiki_load_batch, pretraining_optimizer,
            pretraining_schedule, wandb.config.eae_pretraining_epochs, None, 8)
```

100%|██████████| 2888/2888 [4:09:58<00:00,  5.19s/it]  
  0%|          | 0/722 [00:05<?, ?it/s]


TypeError: can only concatenate list (not "NoneType") to list

## Save the model!

```python
from models.training import save_models

save_models(pretraining_eae_one_epoch=pretraining_model)
```

