# BERT representation storage and loading (with sharding)

In other examples, the data stored with `seqp` were discrete symbols from a closed set, such as text tokens (from a finite, word-based vocabulary) and DNA nucleotide strings.

In this notebook, we are going to store floating-point numbers instead. More specifically, we are going to use [BERT](https://github.com/huggingface/pytorch-pretrained-BERT) to encode pieces of text as sequences of contextual token representations, that is, one vector per token.


## Interface to BERT

We will be using Hugging Face's port of BERT to PyTorch, wrapped in a small class that gives BERT a convenient interface. The wrapper uses `seqp`'s `TextCodec` as its base class and implements `tokenize`, `detokenize`, `encode` and `decode`.

In [1]:
import numpy as np
from pytorch_pretrained_bert import BertTokenizer, BertForMaskedLM
from seqp.encoding import TextCodec
import torch
from typing import Union, List, Optional


DEFAULT_BERT_WEIGHTS = 'bert-base-multilingual-cased'

class BertInterface(TextCodec):
    def __init__(self, use_gpu=False):
        self.tokenizer = BertTokenizer.from_pretrained(DEFAULT_BERT_WEIGHTS)
        self.model = BertForMaskedLM.from_pretrained(DEFAULT_BERT_WEIGHTS)
        self.model.eval()
        use_gpu = use_gpu and torch.cuda.is_available()
        self.device = torch.device("cuda" if use_gpu else "cpu")
        self.model.to(self.device)

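    def decode(self, embedded: Union[np.ndarray, torch.Tensor]) -> List[str]:
        # Accept a NumPy array or a tensor of shape
        # (seq_length, emb_dim) or (batch, seq_length, emb_dim)
        if isinstance(embedded, np.ndarray):
            embedded = torch.from_numpy(embedded)
        if embedded.dim() == 2:  # seq_length x emb_dim
            embedded = embedded.unsqueeze(0)  # add batch dimension
        assert embedded.dim() == 3
        embedded = embedded.to(self.device)
        with torch.no_grad():
            # The masked-LM head maps each vector to one score per vocabulary entry
            predictions = self.model.cls(embedded)
        # Most likely token id at each position of the first sequence in the batch
        predicted_indexes = torch.argmax(predictions, dim=2).cpu().numpy()
        predicted_tokens = self.tokenizer.convert_ids_to_tokens(predicted_indexes[0].tolist())
        return predicted_tokens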

    def detokenize(self, tokens: List[str]) -> str:
        return " ".join(tokens).replace(" ##", "")

    def tokenize(self, sentence: str) -> List[str]:
        return self.tokenizer.tokenize(sentence)

    def encode(self, tokens: List[str]) -> Optional[np.ndarray]:
        # Prepend [CLS]; sequences longer than BERT's maximum input
        # length cannot be encoded, so None is returned for them
        tokenized_text = ['[CLS]'] + tokens
        if len(tokenized_text) > self.tokenizer.max_len:
            return None

        indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text)
        tokens_tensor = torch.LongTensor([indexed_tokens]).to(self.device)
        with torch.no_grad():
            # Keep only the last encoder layer: (1, seq_length, hidden_size)
            sequence_output, _ = self.model.bert(tokens_tensor, output_all_encoded_layers=False)
        return sequence_output[0].cpu().numpy()
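
The `detokenize` method relies on the WordPiece convention that a piece continuing the previous word is prefixed with `##`, so joining the pieces with spaces and deleting every `" ##"` restores whole words. The stand-alone sketch below copies that one-liner into a plain function so it can be tried without downloading BERT; the function name `wordpiece_detokenize` and the sample tokens are made up for illustration.

```python
from typing import List


def wordpiece_detokenize(tokens: List[str]) -> str:
    """Merge WordPiece sub-tokens back into words.

    Same logic as BertInterface.detokenize: continuation pieces start
    with '##', so removing ' ##' glues them onto the previous piece.
    """
    return " ".join(tokens).replace(" ##", "")


# Hypothetical tokenization of a short phrase
pieces = ["the", "pre", "##amble", "of", "freedom", ",", "justice"]
print(wordpiece_detokenize(pieces))  # prints: the preamble of freedom , justice
```

Punctuation stays separated by spaces, which is why the decoded sentence printed at the end of this notebook reads `freedom , justice`.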



There are many details about BERT and PyTorch encapsulated in this wrapper. They are not important for this example, but if you want to learn more and don't know where to start:
- About BERT: read [the original article](https://arxiv.org/abs/1810.04805) or one of the many tutorials about it, like [The Illustrated BERT](http://jalammar.github.io/illustrated-bert/), and have a look at [Hugging Face's wonderful port](https://github.com/huggingface/pytorch-pretrained-BERT).
- About PyTorch: I recommend the book [Natural Language Processing with PyTorch](http://shop.oreilly.com/product/0636920063445.do).


## Encode BERT representations and persist them to sharded HDF5 files

First, we will download a text file with the English text of the Universal Declaration of Human Rights:

In [2]:
!wget -q http://research.ics.aalto.fi/cog/data/udhr/txt/eng.txt

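Now we iterate over the lines of the file, encode each line as a sequence of contextual token vectors, and save each sequence as a `seqp` record. `ShardedWriter` spreads the records over several HDF5 files, at most 10 records per file (`max_records_per_shard=10`), naming each file after `output_file_template`: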

In [3]:
from seqp.hdf5 import Hdf5RecordWriter
from seqp.record import ShardedWriter

input_file = 'eng.txt'
output_file_template = "bert_example_{:02d}.hdf5"

bert = BertInterface()

with ShardedWriter(Hdf5RecordWriter,
                   output_file_template,
                   max_records_per_shard=10) as writer, open(input_file) as f:
    for line in f:
        line = line.strip()
        tokens = bert.tokenize(line)
        ctx_representations = bert.encode(tokens)
        if ctx_representations is None:
            continue  # longer than BERT's maximum input length; skip it
        writer.write(ctx_representations)
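
With `max_records_per_shard=10`, the writer should open a new file every ten records and name it from `output_file_template`. Assuming the shards are numbered from 0 and each is filled completely before the next one is opened (an assumption about `ShardedWriter`, not something this notebook checks), the file holding a given record can be computed as below. `shard_file_for` is a hypothetical helper, not part of `seqp`.

```python
def shard_file_for(record_number: int,
                   template: str = "bert_example_{:02d}.hdf5",
                   max_records_per_shard: int = 10) -> str:
    """Return the shard file name expected to hold a record.

    ASSUMPTION: shards are numbered from 0 and each one is filled
    completely before the next is opened.
    """
    shard_index = record_number // max_records_per_shard
    return template.format(shard_index)


for n in (0, 9, 10, 25):
    print(n, shard_file_for(n))
# 0 bert_example_00.hdf5
# 9 bert_example_00.hdf5
# 10 bert_example_01.hdf5
# 25 bert_example_02.hdf5
```

The zero padding in `{:02d}` keeps the file names in shard order when sorted alphabetically, but only up to 100 shards: `bert_example_100.hdf5` would sort before `bert_example_11.hdf5`.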

## Read representations back and decode them

Now we use an `Hdf5RecordReader` to read the token vector representations back and decode them into tokens with our BERT interface.
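
Decoding works because `BertForMaskedLM` carries a masked-language-model head (`self.model.cls` in the wrapper) that maps every 768-dimensional vector to one score per vocabulary entry; `decode` then keeps the highest-scoring entry at each position. The sketch below reproduces only that last step in plain Python. The three-word vocabulary and the scores are made up and merely stand in for BERT's real vocabulary and the head's output.

```python
from typing import List, Sequence


def argmax_decode(scores: Sequence[Sequence[float]],
                  vocab: Sequence[str]) -> List[str]:
    """Pick the highest-scoring vocabulary entry at each position.

    `scores` has shape (seq_length, vocab_size), like one sequence of
    the masked-LM head output; `vocab` maps token ids to strings.
    """
    tokens = []
    for position_scores in scores:
        # Id with the largest score (the first one if several are equal)
        best_id = max(range(len(position_scores)), key=lambda i: position_scores[i])
        tokens.append(vocab[best_id])
    return tokens


toy_vocab = ["human", "rights", "dignity"]  # hypothetical vocabulary
toy_scores = [
    [2.1, 0.3, -1.0],  # position 0: "human" scores highest
    [0.2, 3.4, 0.5],   # position 1: "rights" scores highest
]
print(argmax_decode(toy_scores, toy_vocab))  # ['human', 'rights']
```

The round trip is approximate: the vectors come from BERT's last layer and the head only predicts the most likely token at each position, so the decoded text need not match the input exactly (note the start of the sentence printed below).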

In [4]:
from glob import glob
from seqp.hdf5 import Hdf5RecordReader

with Hdf5RecordReader(glob('bert_example_*.hdf5')) as reader:
    indexes = list(reader.indexes())
    ctx_rep = reader.retrieve(indexes[9])
    print("Vector representation shape: {}".format(ctx_rep.shape))
    sentence = bert.detokenize(bert.decode(ctx_rep))
    print("Sentence: {}".format(sentence))
    

Vector representation shape: (338, 768)
Sentence: . .amble and recognition of the inherent dignity and of the equal and inalienable rights of all members of the human family is the foundation of freedom , justice and peace in the world , whereas disregard and contempt for human rights have resulted in barbarous acts which have outraged the conscience of mankind , and the advent of a world in which human beings shall enjoy freedom of speech and belief and freedom from fear and want has been proclaimed as the highest aspiration of the common people , whereas it is essential , if man is not to be compelled to have recourse , as a last resort , to rebellion against tyranny and oppression , that human rights should be protected by the rule of law , whereas it is essential to promote the development of friendly relations between nations , whereas the peoples of the united nations have in the charter reaffirmed their faith in fundamental human rights , in the dignity and worth of the human pe
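
The printed shape is (number of positions, hidden size): 338 positions (the `[CLS]` token prepended by `encode` plus the line's WordPiece tokens) by 768, the hidden size of BERT-base. Each token therefore costs 768 floating-point values instead of the single integer id stored in the discrete examples. A quick size estimate for this record, assuming 32-bit floats stored without compression (an assumption about how the HDF5 files are written, not something checked here):

```python
# Shape printed above: (seq_length, hidden_size)
seq_length, hidden_size = 338, 768
# ASSUMPTION: float32 values (4 bytes each), stored without compression
bytes_per_value = 4

num_values = seq_length * hidden_size
record_bytes = num_values * bytes_per_value
print(num_values, record_bytes)           # 259584 1038336
print(f"{record_bytes / 2**20:.2f} MiB")  # 0.99 MiB
```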