# Entities as Experts

This notebook is a code implementation of the paper "Entities as Experts: Sparse Memory Access with Entity Supervision" by Févry, Baldini Soares, FitzGerald, Choi, Kwiatowski.

## Problem definition and high-level model description

We want to perform question answering on typical one-shot questions that require external knowledge or context. For example, in order to answer the question "Which country was Charles Darwin born in?" one needs some text providing answers on typical structured scenarios.

In this case, however, we want to rely on knowledge-graph extracted information. For example, in the question given here, we can prune out unrelated to the antropologist and evolution theorist Charles Darwins, e.g. Charles River, Darwin City etc. 

In the paper, the authors propose to augment BERT in the task of cloze-type question answering by leveraging an Entity Memory extracted from e.g. a Knoweldge Graph.

![Entity as Experts description](images/eae_highlevel.png)

The Entity Memory is a simple bunch of embeddings of entities extracted from a Knowledge Graph. Relationships are ignored (see the Facts as Experts paper and notebook to see how they could be used).

## Datasets

> We assume access to a corpus $D={(xi,mi)}$,where all entity mentions are detected but not necessarily  all  linked  to  entities.   We  use  English Wikipedia as our corpus, with a vocabulary of 1m entities. Entity links come from hyperlinks, leading to 32m 128 byte contexts containing 17m entity links.

In the appendix B, it is explained that:

> We build our training corpus of contexts paired with entity mention labels from the 2019-04-14 dump of English Wikipedia. We first divide each article into chunks of 500 bytes,resulting in a corpus of 32 million contexts withover 17 million entity mentions. We restrict our-selves  to  the  one  million  most  frequent  entities
(86% of the linked mentions).

Given that the dump 2019-04-14 is not available at the time of writing, we will adopt the revision 2020-11-01.

Entities are thus partially extracted by link annotations (e.g. they associate with each token a mention if that token belongs to a wikipedia url).

## Mention Detection

> In addition to the Wikipedia links, we annotaten each sentence with unlinked mention spans using the mention detector from Section 2.2

The mention detection head discussed in Section 2.2 is a simple BIO sequence: each token is annotated with a B (beginning), I (inside) or O (outside) if they are respectivelly beginning, inside or outside of a mention. The reason why we use both BIO and EL is to avoid inconsistencies.

There is a catch. In the paper, they explain they used Google NLP APIs to perform entity detection and linking on large-scale Wikipedia entries, that is, to have a properly annotated Wikipedia dataset.

Since we cannot technically afford this, we will use spacy's entity detection and linking capabilities as a baseline. Data quality 

## Chunking

- Split articles by chunks of 500 bytes (assuming unicode encoding).
- We will elide sentences till the last period to make sure they reach such limit without giving weird effects.

## Tokenization:

- BERT Tokenizer (e.g. Wordpiece) using lowercase vocabulary, limited to 128 distinct word-piece tokens.

#### Wikipedia

In [2]:
from pyspark import SparkContext
from pyspark.sql import SparkSession
import pandas as pd
import numpy as np

spark = SparkSession.builder \
                        .getOrCreate()

# Generate a Pandas DataFrame
pdf = pd.DataFrame(np.random.rand(100, 3))

# Create a Spark DataFrame from a Pandas DataFrame using Arrow
df = spark.createDataFrame(pdf)

# Convert the Spark DataFrame back to a Pandas DataFrame using Arrow
result_pdf = df.select("*").toPandas()

In [3]:
# from tools.providers import WikipediaProvider
# WikipediaProvider.dump_full_dataset(revision="20201101")

from trec_car import read_data
from tools.dumps import wrap_open
from collections import defaultdict
#import spacy
import torch
tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'bert-base-uncased')    # Download vocabulary from S3 and cache.

from trec_car.read_data import Page, Section, List, Para, ParaLink, ParaText, ParaBody

#nlp = spacy.load("en_core_web_md", pipeline=["tokenizer"])

Using cache found in /root/.cache/torch/hub/huggingface_pytorch-transformers_master


In [27]:
def handle_section(skel, toks, links, tokenize):
    for subskel in skel.children:
        visit_section(subskel, toks, links, tokenize)

def handle_list(skel, toks, links, tokenize):
    visit_section(skel.body, toks, links, tokenize)

def handle_para(skel: Para, toks, links, tokenize):
    paragraph = skel.paragraph
    bodies = paragraph.bodies

    for body in bodies:
        visit_section(body, toks, links, tokenize)

def handle_paratext(body: ParaBody, toks, links, tokenize):
    if tokenize:
        lemmas = tokenizer.tokenize(body.get_text())
        toks.extend(lemmas)
        links.extend(["PAD"] * len(lemmas))

def handle_paralink(body: ParaLink, toks, links, tokenize):
    lemmas = tokenizer.tokenize(body.get_text())
    if tokenize:
        toks.extend(lemmas)
        links.extend([body.page] + ["PAD"] * (len(lemmas) - 1))
    else:
        links.append(body.page)
    pass

def nothing():
    return lambda body, toks, links, tokenize: None

handler = defaultdict(nothing, {Section: handle_section,
                     Para: handle_para,
                     List: handle_list,
                     ParaLink: handle_paralink,
                     ParaText: handle_paratext})


def visit_section(skel, toks, links, tokenize=True):
    # Recur on the sections
    handler[type(skel)](skel, toks, links, tokenize)

In [32]:
# That's a small example to see if it's working.
# It will likely take the first page available, Anarchism

from collections import Counter

with wrap_open("wikipedia/car-wiki2020-01-01/enwiki2020.cbor", "rb") as toc:
    for idx, page in enumerate(read_data.iter_annotations(toc)):
        links = []
        toks = []
        for skel in page.skeleton:
            visit_section(skel, toks, links, False)
            
        # print(Counter(links))
        break

In [6]:
from tools.dumps import get_filename_path

cbor_path = get_filename_path("wikipedia/car-wiki2020-01-01/enwiki2020.cbor")

from trec_car.read_data import AnnotationsFile, ParagraphsFile

cbor_toc_annotations = AnnotationsFile(cbor_path)
cbor_toc_paragraphs = ParagraphsFile(cbor_path)

In [7]:
import numpy as np

keys = list(cbor_toc_annotations.keys())

In [8]:
len(keys)

7893216

In [9]:
cbor_toc_annotations.get(keys[0])

<trec_car.read_data.Page at 0x7f9133788f98>

In [10]:
toks, links = [], []

page = cbor_toc_annotations.get(keys[0])

In [11]:
key_to_use = cbor_toc_paragraphs.toc.get(b'enwiki:Touch%20(TV%20series)')

In [12]:
#page = cbor_toc_paragraphs.get(b'enwiki:U.S.%20Route%20277')
from sklearn.preprocessing import LabelEncoder
import tqdm

page = cbor_toc_annotations.get(b'enwiki:U.S.%20Route%20277')

# cbor_toc_paragraphs.get(46102636493)
#print(page)

In [13]:
import cbor
import mmap



In [14]:
values = list(cbor_toc_annotations.toc.values())
values.sort()

In [15]:
key_title = list()

with mmap.mmap(cbor_toc_annotations.cbor.fileno(), 0, mmap.MAP_PRIVATE) as cbor_file:
    for offset in tqdm.tqdm(values):
        key_title.append(extract_from_key(offset))

100%|██████████| 7893216/7893216 [01:04<00:00, 121955.38it/s]


In [19]:
keys = np.array(keys)
key_title_set = set(key_title)

In [21]:
"Wrestling at the 1912 Summer Olympics – Men's Greco-Roman middle" in key_title_set

False

In [47]:
import itertools

key_encoder = dict(zip(key_title, itertools.count()))

In [79]:
max_entity_num = len(key_encoder)
print(max_entity_num)

from typing import List, Union, Iterable
import torch
from torch.utils.data import Dataset
import tqdm
from transformers import BertTokenizer, BertConfig
from tools.dumps import get_filename_path

from keras.preprocessing.sequence import pad_sequences

import numpy as np
import random

from trec_car.read_data import AnnotationsFile, ParagraphsFile, Page

import concurrent.futures as futures

from sklearn.preprocessing import LabelEncoder



extracted_links = extract_links()
#v1 = extract_links_monothreaded(get_pages(np.arange(10)))
#v2 = extract_links_monothreaded(get_pages(np.arange(10, 20)))

7893216


In [80]:
extracted_links

tensor(indices=tensor([[      0,       0,       0,  ...,       9,       9,
                              9],
                       [     41,      91,     226,  ..., 7890336, 7890753,
                        7893063]]),
       values=tensor([68, 34,  1,  ..., 61, 40, 40]),
       size=(10, 7893216), nnz=20020, dtype=torch.int32, layout=torch.sparse_coo)

In [81]:
def b2i(x):
    return int.from_bytes(x, "big")

class WikipediaCBOR(Dataset):
    """
    This is a simple CBOR loader.
    """
    def __init__(self, cbor_path, max_entity_num=1_000_000):
        """
        :param cbor_path the path of the wikipedia cbor export
        """
        # Let trec deal with that in my place
        
        self.cbor_path = get_filename_path(cbor_path)
        self.cbor_toc_annotations = AnnotationsFile(self.cbor_path)
        self.cbor_toc_paragraphs = ParagraphsFile(self.cbor_path)
        
        self.keys = np.fromiter(cbor_toc_paragraphs.keys(), dtype='<U64')
        self.key_encoder = LabelEncoder().fit(self.keys)
        
        # preprocess and find the top k unique wikipedia links
        self.key_titles = self.extract_radable_key_titles()
        
        # page frequencies
        self.total_freqs = self.__extract_links()
        
    def extract_readable_key_titles(self):
        """
        Build a list of human-readable names of CBOR entries.
        Compared to self.keys, these keys are not binary encoded formats.
        """
        def extract_from_key(offset):
            cbor_file.seek(offset)

            # We refer to the RFC 8949 about the CBOR structure
            # See https://tools.ietf.org/html/rfc8949 for details
            len_first_field = b2i(cbor_file.read(1))
            field_type = (len_first_field & (0b11100000)) >> 5

            # array
            if field_type == 0b100:
                # ignore the next byte
                cbor_file.read(1)
                first_elem_header = b2i(cbor_file.read(1))
                first_elem_len = first_elem_header & 31
                # first_elem_tag = first_elem_header >> 5

                if first_elem_len > 23:
                    first_elem_len = b2i(cbor_file.read(first_elem_len - 23))

                return cbor_file.read(first_elem_len).decode('utf-8')

            else:
                raise Exception("Wrong header")
                
        # Sorted seeks should make the OS scheduler less confused, hopefully
        values = list(self.cbor_toc_annotations.toc.values())
        values.sort()
        
        key_titles = set()
        
        # If reloaded for a second time, this should be way faster.
        with mmap.mmap(self.cbor_toc_annotations.cbor.fileno(), 0, mmap.MAP_PRIVATE) as cbor_file:
            for offset in tqdm.tqdm(values):
                key_titles.add(extract_from_key(offset))
                
        return key_titles
     

    def self.__get_pages (self, idx: Union[int, Iterable[int]]) -> List[Page]:
        """
        Extract some Page's from the CBOR file.
        
        :arg idx the index of the pages to extract
        :returns a list of pages
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        elif type(idx) != list and type(idx) != np.ndarray:
            idx = [idx]

        return [self.cbor_toc_annotations.get(k) for k in self.keys[idx]]

    def __extract_links_monothreaded(self, pages: List[Page]) -> torch.sparse.LongTensor:
        """
        Calculate the frequency of each mention in wikipedia.
        
        :arg parges the list of Wikipedia pages
        :returns a sparse torch tensor
        """
        
        toks = []

        for page in pages:
            # remove spurious None elements
            if page is None:
                continue
            for skel in page.skeleton:
                visit_section(skel, [], links, False)

        freqs = Counter(links)

        # remove mentions that do not have an associated wikipedia page
        keys = list(freqs.keys())
        for key in keys:
            if key not in key_title_set:
                del freqs[key]

        keys = np.array([[0, key_encoder.get(k, -1)] for k in freqs.keys()]).T.reshape(2, -1)
        values = np.fromiter(freqs.values(), dtype=np.int32)

        return torch.sparse_coo_tensor(keys, values,
                                             size=(1, max_entity_num))

    def __extract_links(self, pages_per_worker=100, page_lim=1000):
        """
        Create some page batches and count mention occurrences for each batch.
        Summate results.

        This method is threaded (hopefully for the common good).
        """

        if page_lim is None:
            page_lim = len(keys)

        starting_tensor = torch.sparse.LongTensor(1, max_entity_num)
        tensors = []
        with futures.ThreadPoolExecutor() as executor:
            promises = []
            for i in range(0, page_lim, pages_per_worker):
                pages = self.__get_pages(np.arange(idx, min(idx+pages_per_worker, page_lim)))
                promises.append(executor.submit(self.__extract_links_monothreaded, self, pages))
            for promise in futures.as_completed(promises):
                tensors.append(promise.result())

        return torch.sparse.sum(torch.stack(tensors), [1])
    
    
    def __extract_links_monothreaded(self, pages: List[Page]):
        """
        Calculate the frequency of each mention in wikipedia.
        Return a sparse torch tensor
        """
        toks = []
        
        for skel in page.skeleton:
            visit_section(skel, [], links, False)
        
        keys = self.key_ecnoder.transform(np.fromiter(freqs.keys(), dtype='<U64'))
        
        return torch.sparse_coo_tensor(keys, freqs.values(), self.max_entity_num)
    
    def __get_pages(self, idx: Union[int, Iterable[int]]):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        elif type(idx) != list and type(idx) != np.array:
            idx = [idx]
        
        return [self.cbor_toc_annotations.get(k.encode('ascii')) for k in self.keys[idx]]
        
    def __extract_links(self, pages_per_worker=100, page_lim=1000):
        """
        Create some page batches and count mention occurrences for each batch.
        Summate results.
        
        This method is threaded (hopefully for the common good).
        """
        
        if page_lim is None:
            page_lim = len(self)
            
        starting_tensor = torch.sparse.LongTensor(1, page_lim)
        tensors = []
        with futures.ThreadPoolExecutor() as executor:
            promises = []
            for i in range(0, len(self), pages_per_worker):
                pages = self.__get_pages(np.arange(idx, min(idx+pages_per_worker, page_len)))
                promises.append(executor.submit(self.__extract_links_monothreaded, self, pages))
            for promise in promises.as_completed(futures):
                tensor.append(promise.result())
                
        return torch.tensor(tensors).sum(1)
                
        
    def __len__(self):
        return len(self.keys)
    
    def __tokenize(self, page: Page):
        toks = []
        links = []
        for skel in page.skeleton:
            visit_section(skel, toks, links)
        return toks, links
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        elif type(idx) != list:
            idx = [idx]
        
        pages = [self.cbor_toc_annotations.get(k.encode('ascii')) for k in self.keys[idx]]
        
        # can we parallelize this?
        result = [self.__tokenize(page) for page in pages]
        print(result)
        
        return torch.tensor(result)

SyntaxError: invalid syntax (<ipython-input-81-071f57c40757>, line 60)

In [22]:
cbor_dataloader = WikipediaCBOR("wikipedia/car-wiki2020-01-01/enwiki2020.cbor")

for idx, batch in enumerate(cbor_dataloader):
    print(batch)
    
    if idx >= 0:
        break

UnboundLocalError: local variable 'futures' referenced before assignment

In [93]:
with wrap_open("wikipedia/car-wiki2020-01-01/enwiki2020.cbor", "rb") as toc:
    #toc.seek(46102636493)
    
    for idx, page in enumerate(read_data.iter_pages(toc)):
        if idx >= 10:
            break
        print(page)
        page.

Page(Anarchism)
Page(Albedo)
Page(Autism)
Page(A)
Page(Achilles)
Page(Alabama)
Page(Abraham Lincoln)
Page(An American in Paris)
Page(Aristotle)
Page(Academy Award for Best Production Design)


## BIO

In [22]:
import torch

from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split

In [23]:
model = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-uncased')    # Download vocabulary from S3 and cache.

Using cache found in /root/.cache/torch/hub/huggingface_pytorch-transformers_master
Using cache found in /root/.cache/torch/hub/huggingface_pytorch-transformers_master


In [3]:
tokenizer.tokenize("this is a test, [we might even exclude that](nice link)")

['this',
 'is',
 'a',
 'test',
 ',',
 '[',
 'we',
 'might',
 'even',
 'exclude',
 'that',
 ']',
 '(',
 'nice',
 'link',
 ')']

In [6]:
from zipfile import ZipFile
from tools.dumps import wrap_open
import pandas as pd
from tqdm import tqdm, trange

!cd data && unzip -f ner.csv.zip

Archive:  ner.csv.zip


In [7]:
# The columns are a bit irregular.
names = []
with wrap_open("ner.csv", "r", encoding="latin1") as f:
    print(f.readline())
    f.seek(0)
    names = ["index"] + f.readline().strip().split(",")[1:]
    names = names + list(range(34 - len(names)))

print(names)

with wrap_open("ner.csv", "rb") as f:
    f.readline() # skip the first line
    data = pd.read_csv(f, encoding="latin1", names=names).fillna(method="ffill")
data.tail(10)

,lemma,next-lemma,next-next-lemma,next-next-pos,next-next-shape,next-next-word,next-pos,next-shape,next-word,pos,prev-iob,prev-lemma,prev-pos,prev-prev-iob,prev-prev-lemma,prev-prev-pos,prev-prev-shape,prev-prev-word,prev-shape,prev-word,sentence_idx,shape,word,tag

['index', 'lemma', 'next-lemma', 'next-next-lemma', 'next-next-pos', 'next-next-shape', 'next-next-word', 'next-pos', 'next-shape', 'next-word', 'pos', 'prev-iob', 'prev-lemma', 'prev-pos', 'prev-prev-iob', 'prev-prev-lemma', 'prev-prev-pos', 'prev-prev-shape', 'prev-prev-word', 'prev-shape', 'prev-word', 'sentence_idx', 'shape', 'word', 'tag', 0, 1, 2, 3, 4, 5, 6, 7, 8]


  interactivity=interactivity, compiler=compiler, result=result)


Unnamed: 0,index,lemma,next-lemma,next-next-lemma,next-next-pos,next-next-shape,next-next-word,next-pos,next-shape,next-word,...,tag,0,1,2,3,4,5,6,7,8
1050786,1048565,impact,.,__end1__,__END1__,wildcard,__END1__,.,punct,.,...,O,prev-prev-pos,prev-prev-shape,prev-prev-word,prev-shape,prev-word,sentence_idx,shape,word,tag
1050787,1048566,.,__end1__,__end2__,__END2__,wildcard,__END2__,__END1__,wildcard,__END1__,...,O,prev-prev-pos,prev-prev-shape,prev-prev-word,prev-shape,prev-word,sentence_idx,shape,word,tag
1050788,1048567,indian,forc,said,VBD,lowercase,said,NNS,lowercase,forces,...,B-gpe,prev-prev-pos,prev-prev-shape,prev-prev-word,prev-shape,prev-word,sentence_idx,shape,word,tag
1050789,1048568,forc,said,they,PRP,lowercase,they,VBD,lowercase,said,...,O,prev-prev-pos,prev-prev-shape,prev-prev-word,prev-shape,prev-word,sentence_idx,shape,word,tag
1050790,1048569,said,they,respond,VBD,lowercase,responded,PRP,lowercase,they,...,O,prev-prev-pos,prev-prev-shape,prev-prev-word,prev-shape,prev-word,sentence_idx,shape,word,tag
1050791,1048570,they,respond,to,TO,lowercase,to,VBD,lowercase,responded,...,O,prev-prev-pos,prev-prev-shape,prev-prev-word,prev-shape,prev-word,sentence_idx,shape,word,tag
1050792,1048571,respond,to,the,DT,lowercase,the,TO,lowercase,to,...,O,prev-prev-pos,prev-prev-shape,prev-prev-word,prev-shape,prev-word,sentence_idx,shape,word,tag
1050793,1048572,to,the,attack,NN,lowercase,attack,DT,lowercase,the,...,O,prev-prev-pos,prev-prev-shape,prev-prev-word,prev-shape,prev-word,sentence_idx,shape,word,tag
1050794,1048573,the,attack,with,IN,lowercase,with,NN,lowercase,attack,...,O,prev-prev-pos,prev-prev-shape,prev-prev-word,prev-shape,prev-word,sentence_idx,shape,word,tag
1050795,1048574,attack,with,machine-gun,JJ,contains-hyphen,machine-gun,IN,lowercase,with,...,O,prev-prev-pos,prev-prev-shape,prev-prev-word,prev-shape,prev-word,sentence_idx,shape,word,tag


In [8]:
def simplify_bio(column):
    return column[0]

data["bio"] = data["tag"].apply(simplify_bio)
data["bio"]

0          O
1          O
2          O
3          O
4          O
          ..
1050791    O
1050792    O
1050793    O
1050794    O
1050795    O
Name: bio, Length: 1050796, dtype: object

In [9]:
MAX_LEN = 75 ## can replace with 512 as per the original paper
bs = 32

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = "cpu"
n_gpu = torch.cuda.device_count()
torch.cuda.get_device_name(0)

'GeForce RTX 2080 Ti'

In [10]:
def aggregate(s):
    return [(w, t) for w, t in zip(s["word"].values.tolist(), s["bio"].values.tolist())]

sentences = [s for s in data.groupby("sentence_idx").apply(aggregate)]

In [11]:
sentences[2]

[('They', 'O'),
 ('marched', 'O'),
 ('from', 'O'),
 ('the', 'O'),
 ('Houses', 'O'),
 ('of', 'O'),
 ('Parliament', 'O'),
 ('to', 'O'),
 ('a', 'O'),
 ('rally', 'O'),
 ('in', 'O'),
 ('Hyde', 'B'),
 ('Park', 'I'),
 ('.', 'O')]

In [12]:
utterances = [[w[0] for w in s] for s in sentences]
labels = [[w[1] for w in s] for s in sentences]

bio_values = list(set(data["bio"].values))
bio_values.append("PAD")
# Apparently one row is misclassified.
bio2idx = {t: i for i, t in enumerate(bio_values)}
bio2idx['p'] = bio2idx['O']

In [13]:
def tokenize_preserve_labels(sentence, text_labels):
    """
    Tokenize the given sentence. Extend the corresponding label
    for all the tokens the word is made of.
    
    Assumption: len(sentence) == len(text_labels)
    """
    
    tokenized_sentence = []
    labels = []
    
    for word, label in zip(sentence, text_labels):
        tokenized_word = tokenizer.tokenize(word)
        n_subwords = len(tokenized_word)
        
        tokenized_sentence.extend(tokenized_word)
        labels.extend([label] * n_subwords)
    
    return tokenized_sentence, labels

In [14]:
tokenized_texts_labels = [
    tokenize_preserve_labels(sent, labs) for sent, labs in zip(utterances, labels)
]

In [15]:
tokenized_texts = [token_label_pair[0] for token_label_pair in tokenized_texts_labels]
tokenized_labels = [token_label_pair[1] for token_label_pair in tokenized_texts_labels]

In [16]:
input_ids = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in tokenized_texts],
                                maxlen=MAX_LEN, dtype="long", value=0.0, truncating="post", padding="post")
labels = pad_sequences([[bio2idx.get(l) for l in lab] for lab in tokenized_labels],
                               maxlen=MAX_LEN, value=bio2idx["PAD"], padding="post", dtype="long", truncating="post")

In [17]:
# In this classification task we want to classify every token, thus mask everything else
attention_masks = [[float(i != 0.0) for i in ii] for ii in input_ids]

In [18]:
train_inputs, validation_inputs, train_labels, validation_labels = train_test_split(input_ids, labels, random_state=42, test_size=0.1)
train_masks, validation_masks, _, _ = train_test_split(attention_masks, input_ids, random_state=42, test_size=0.1)

train_inputs = torch.tensor(train_inputs)
validation_inputs = torch.tensor(validation_inputs)
train_masks = torch.tensor(train_masks)
validation_masks = torch.tensor(validation_masks)
train_labels = torch.tensor(train_labels)
validation_labels = torch.tensor(validation_labels)

In [19]:
# Frankly this code looks horrible - need to delve into pytorch's dataloader tools API
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

train_data = TensorDataset(train_inputs, train_masks, train_labels)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=bs)

validation_data = TensorDataset(validation_inputs, validation_masks, validation_labels)
validation_sampler = RandomSampler(validation_data)
validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=bs)

In [68]:
def transform_sentence(sentence: str):
    tokens = tokenizer.tokenize(sentence)
    print(tokens)
    padded = pad_sequences([tokenizer.convert_tokens_to_ids(tokens)], maxlen=MAX_LEN,
                  dtype="long", value=0.0, truncating="post", padding="post")
    
    attention_mask = [[float(tok != 0.0) for tok in padded_] for padded_ in padded]
    
    return padded, attention_mask

# bioclassifier.forward(tokens,)

padded, attention = transform_sentence("Hello world, this is Spongebob!")

bioclassifier.eval()
res = bioclassifier.forward(torch.tensor(padded).to(device), token_type_ids=None, attention_mask=torch.tensor(attention).to(device), labels=None)

['hello', 'world', ',', 'this', 'is', 'sponge', '##bo', '##b', '!']


In [74]:
predictions = np.array(bio_values)[np.argmax(res[0].detach().cpu().numpy(), axis=2)]

In [75]:
predictions

array([['B', 'O', 'O', 'O', 'O', 'B', 'B', 'B', 'O', 'O', 'O', 'O', 'O',
        'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O',
        'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B', 'O', 'O', 'O',
        'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O',
        'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B', 'B', 'B', 'O', 'O',
        'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']], dtype='<U3')

## Model

In the paper, the authors explain they used a modified BERT.

### Load and finetune the model

In [48]:
# I guess I need to use pytorch hooks to do what I
# want. Yeeee, yet one more thing to study

from torch.nn import Module, Linear, Dropout
from transformers.modeling_bert import BertEncoder, BertModel, BertForTokenClassification
from copy import deepcopy



class TruncatedEncoder(Module):
    def __init__(self, encoder: BertEncoder, l0: int):
        super().__init__()
        __doc__ = encoder.__doc__
        self.encoder = deepcopy(encoder)
        self.encoder.layer = self.encoder.layer[:l0]
        
        
    def forward(self, *args, **kwargs):
        __doc__ = self.encoder.forward.__doc__
        return self.encoder(*args, **kwargs)

class TruncatedModel(Module):
    def __init__(self, model: BertModel, l0: int = 4):
        super().__init__()
        self.model = deepcopy(model)
        self.model.encoder = TruncatedEncoder(self.model.encoder, l0)
    
    def forward(self, *args, **kwargs):
        __doc__ = self.model.forward.__doc__
        return self.model(*args, **kwargs)


class BioClassifier(Module):
    def __init__(self,  bertmodel: TruncatedEncoder):
        super().__init__()
        self.bert = bertmodel
        self.dropout = Dropout(p=0.1)
        self.classifier = Linear(in_features=768, out_features=4, bias=True)
        self.num_labels = 4
    
    def forward(self, *args, **kwargs):
        return BertForTokenClassification.forward(self, *args, **kwargs)

In [49]:
bioclassifier = BioClassifier(TruncatedModel(model)).cuda()

In [50]:
FULL_FINETUNING = True
param_optimizer = list(bioclassifier.named_parameters())

from transformers import AdamW

if FULL_FINETUNING:
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.0}
    ]
else:
    optimizer_grouped_parameters = [{'params': [p for n, p in param_optimizer]}]
    
optimizer = AdamW(
    optimizer_grouped_parameters,
    lr=3e-5,
    eps=1e-8
)

In [76]:
from transformers import get_linear_schedule_with_warmup

epochs = 10
max_grad_norm = 1.0

total_steps = len(train_dataloader) * epochs

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

In [None]:
from seqeval.metrics import f1_score, accuracy_score

import numpy as np
from torch.nn.utils import clip_grad_norm_

loss_values, validation_loss_values = [], []

for epoch in range(epochs):
    bioclassifier.train()
    total_loss = 0
    
    for step, batch in enumerate(tqdm(train_dataloader)):
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        
        bioclassifier.zero_grad()
            
        
        outputs = bioclassifier(b_input_ids, token_type_ids=None,
                                attention_mask=b_input_mask, labels=b_labels)
        
        # Someone has to explain to me why someone put the loss function inside a module
        loss = outputs[0]
        loss.backward()
        total_loss += loss.item()
        clip_grad_norm_(parameters=bioclassifier.parameters(),
                        max_norm=max_grad_norm)
    
        optimizer.step()
        scheduler.step()
    
    avg_train_loss = total_loss / len(train_dataloader)
    print("Average train loss: {}".format(avg_train_loss))
    
    loss_values.append(avg_train_loss)
    
    model.eval()
    
    eval_loss, eval_accuracy = 0.0, 0.0
    number_eval_steps, number_eval_examples = 0, 0
    predictions, true_labels = [], []
    
    for batch in validation_dataloader:
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        
        with torch.no_grad():
            outputs = bioclassifier(b_input_ids, token_type_ids=None,
                            attention_mask=b_input_mask, labels=b_labels)
        
        logits = outputs[1].detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()
        
        eval_loss += outputs[0].mean().item()
        
        predictions.extend([list(p) for p in np.argmax(logits, axis=2)])
        true_labels.extend(label_ids)
        
    eval_loss = eval_loss / len(validation_dataloader)
    validation_loss_values.append(eval_loss)
    pred_tags = [[bio_values[p_i] for p_i, l_i in zip(p, l)]
                    for p, l in zip(predictions, true_labels)]
    
    true_tags = [[bio_values[l_i] for l_i in l] for l in true_labels]
    
    print(f"Validation Accuracy: {accuracy_score(pred_tags, true_tags)}")
    print(f"Validation F1-Score: {f1_score(pred_tags, true_tags)}")
    print(f"Validation loss: {eval_loss}")
    print()

100%|██████████| 1011/1011 [00:58<00:00, 17.37it/s]


Average train loss: 0.056947913965989645
Validation Accuracy: 0.4191223675665646


  0%|          | 2/1011 [00:00<00:58, 17.14it/s]

Validation F1-Score: 0.654957695598441
Validation loss: 0.084311828248005



100%|██████████| 1011/1011 [01:01<00:00, 16.33it/s]


Average train loss: 0.04006332797058199
Validation Accuracy: 0.41927451526115594


  0%|          | 2/1011 [00:00<00:59, 16.96it/s]

Validation F1-Score: 0.6568595041322314
Validation loss: 0.08999692602495177



100%|██████████| 1011/1011 [01:01<00:00, 16.35it/s]


Average train loss: 0.028179564128208878
Validation Accuracy: 0.41935244456814175


  0%|          | 2/1011 [00:00<00:59, 16.94it/s]

Validation F1-Score: 0.6896914182746453
Validation loss: 0.10549780948961203



100%|██████████| 1011/1011 [01:01<00:00, 16.37it/s]


Average train loss: 0.02010758651828561
Validation Accuracy: 0.4192374060673532


  0%|          | 2/1011 [00:00<01:01, 16.46it/s]

Validation F1-Score: 0.6787399252997838
Validation loss: 0.11534326749367524



100%|██████████| 1011/1011 [01:01<00:00, 16.46it/s]


Average train loss: 0.014377029335612669
Validation Accuracy: 0.4194118192782262


  0%|          | 2/1011 [00:00<00:58, 17.31it/s]

Validation F1-Score: 0.6663807072729006
Validation loss: 0.12534657229496315



100%|██████████| 1011/1011 [01:01<00:00, 16.50it/s]


Average train loss: 0.010751042502718534
Validation Accuracy: 0.4195157250208739


  0%|          | 2/1011 [00:00<01:00, 16.57it/s]

Validation F1-Score: 0.6649154647053205
Validation loss: 0.13663531210174604



100%|██████████| 1011/1011 [01:01<00:00, 16.50it/s]


Average train loss: 0.008416197471756813
Validation Accuracy: 0.419838575006958


  0%|          | 2/1011 [00:00<00:59, 16.91it/s]

Validation F1-Score: 0.6667144529663346
Validation loss: 0.141331840603225



100%|██████████| 1011/1011 [01:01<00:00, 16.52it/s]


Average train loss: 0.006647726145021859
Validation Accuracy: 0.419704981909268


  0%|          | 2/1011 [00:00<00:59, 17.02it/s]

Validation F1-Score: 0.6782409662023959
Validation loss: 0.15386315876931217



100%|██████████| 1011/1011 [01:01<00:00, 16.57it/s]


Average train loss: 0.005283215134538172
Validation Accuracy: 0.4195231468596345


  0%|          | 2/1011 [00:00<00:58, 17.23it/s]

Validation F1-Score: 0.6670636149284301
Validation loss: 0.15581438350862106



 32%|███▏      | 328/1011 [00:19<00:41, 16.41it/s]

In [12]:
from torch.nn import Module, Embedding, Dropout, ModuleList, Linear
import torch.nn as nn
import torch
import math

GELU = torch.nn.GELU
LayerNorm = torch.nn.LayerNorm

l0 = 4
l1 = 8

    
class EntityMemory(Module):
    """
    Entity Memory, as described in the paper
    """
    def __init__(self, embedding_size: int, entity_size: int,
                   entity_embedding_size: int):
        """
        :param embedding_size the size of an embedding. In the EaE paper it is called d_emb, previously as d_k
            (attention_heads * embedding_per_head)
        :param entity_size also known as N in the EaE paper, the maximum number of entities we store
        :param entity_embedding_size also known as d_ent in the EaE paper, the embedding of each entity
        
        """
        self.N = entity_size
        self.d_ent = entity_embedding_size
        self.w_f = Linear(d_ent, 2*embedding_size)
        
    def forward(self, x, entity_spans, num_entities, k=None):
        """
        :param x the (raw) output of the first transformer block. It has a shape:
                B x N x (embed_size)
        :param entity_spans entities and spans of such entities.
                Shape: B x C x 3. Each "row" contains a triple (e_k, s_mi, t_mi)
                where e_k is an (encoded) entity id, s_mi and t_mi are indices.
        :param num_entities the number of found entities for each batch.
        :param k the number of nearest entities to consider when softmax-ing.
                if k = None, all the entities are used.
                In the paper, one should set k for when running the inference
        """
        
        mentions = [entity_spans[:, :mentions_per_batch] for mentions_per_batch in num_entities]
        pass
        