<a href="https://colab.research.google.com/github/marcomoldovan/hierarchical-text-encoder/blob/master/hierarchical_transformer_based_document_encoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Hierarchical Transformer-based Document Encoder** 

We present a multi-purpose hierarchical document encoder (HATE) that stacks a sentence-level transformer and a document-level transformer on top of each other. The goal is to capture as many semantic facets as possible of a given, arbitrarily long document. The model implicitly learns contextualized word and sentence representations as well as a robust document representation. All encodings live in the same representation space, so similarity metrics such as cosine similarity can be applied interchangeably at every level of representation. As a result, only one model has to be trained to encode a query, a candidate document, and all of the document's contextualized sentences in a single space, where similarity measures can be used both to retrieve the most relevant documents and to score candidate answer passages by their likely relevance. The robust sentence embeddings also make the model usable for document segmentation: intuitively, a semantic break between paragraphs should be reflected in the representations of the adjacent sentences. One could therefore apply an algorithm that splits unstructured text into paragraphs at the points where the likelihood of a semantic break, and hence of a new paragraph, is high.
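
The segmentation idea can be sketched in a few lines. This is only an illustration under stated assumptions: the sentence embeddings are taken as given (here they are hand-written toy vectors, not model output), the function names are hypothetical, and the fixed `threshold` is a placeholder for whatever break criterion is eventually chosen. Plain Python is used so the sketch runs without dependencies.

```python
import math
from typing import List, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def segment_by_similarity(sentence_embeddings: Sequence[Sequence[float]],
                          threshold: float = 0.5) -> List[List[int]]:
    """Group consecutive sentence indices into segments.

    A new segment starts whenever the similarity between a sentence and its
    predecessor drops below `threshold` (a hypothetical, untuned value).
    """
    if not sentence_embeddings:
        return []
    segments = [[0]]
    for i in range(1, len(sentence_embeddings)):
        similarity = cosine_similarity(sentence_embeddings[i - 1], sentence_embeddings[i])
        if similarity < threshold:
            segments.append([i])      # semantic break: open a new paragraph
        else:
            segments[-1].append(i)    # same topic: extend the current paragraph
    return segments


# Toy embeddings: sentences 0-1 point one way, sentences 2-3 another.
toy_embeddings = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
print(segment_by_similarity(toy_embeddings, threshold=0.5))  # [[0, 1], [2, 3]]
```

Because all representations share one space, the same `cosine_similarity` helper could in principle also compare a query vector with document or sentence vectors for retrieval.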

# **1.** Introduction



# **2.** Background

## 2.1 Statistical foundations of machine learning

(1) https://towardsdatascience.com/the-statistical-foundations-of-machine-learning-973c356a95f

More on regression, classification, etc. in a probabilistic context. Show that these problems are essentially MAP.

## 2.2 Mathematics of optimization for deep learning

(1) https://towardsdatascience.com/the-mathematics-of-optimization-for-deep-learning-11af2b1fda30

- Optimization: (1) visualizing loss landscape https://arxiv.org/pdf/1712.09913.pdf
- Momentum based optimizers
- Dropout (1) https://arxiv.org/pdf/1207.0580.pdf
- Batch normalization (1) https://arxiv.org/abs/1502.03167 (2) https://papers.nips.cc/paper/7515-how-does-batch-normalization-help-optimization.pdf
- Weight initialization
- Regularization (1) NLP specific: https://mlexplained.com/2018/03/02/regularization-techniques-for-natural-language-processing-with-code-examples/

## 2.3 Representation Learning

Touch briefly on the theory of representation learning, independent of language. Main focus on the Bengio paper: https://arxiv.org/pdf/1206.5538.pdf Also refer to representation learning slides from DL&AI course from last semester as an appropriate introduction.

## 2.4 Language Models

- Language modeling via auto-encoding or auto-regressive methods in general


- Embeddings in Language Models (1) https://jalammar.github.io/skipgram-recommender-talk/Text (2) https://dspace.mit.edu/handle/1721.1/118079
-  Word embeddings (1) https://ruder.io/word-embeddings-1/index.html (2) https://ruder.io/word-embeddings-softmax/index.html (3) https://ruder.io/secret-word2vec/index.html (4) https://ruder.io/word-embeddings-2017/index.html (5) https://jalammar.github.io/illustrated-word2vec/ (6) Glove https://mlexplained.com/2018/04/29/paper-dissected-glove-global-vectors-for-word-representation-explained/ (7) ELMo https://mlexplained.com/2018/06/15/paper-dissected-deep-contextualized-word-representations-explained/ (8) https://p.migdal.pl/2017/01/06/king-man-woman-queen-why.html (9) https://lilianweng.github.io/lil-log/2017/10/15/learning-word-embedding.html#loss-functions
- Sentence embeddings (1) https://supernlp.github.io/2018/11/26/sentreps/ (2) https://mlexplained.com/2017/12/28/an-overview-of-sentence-embedding-methods/ (3) https://medium.com/huggingface/universal-word-sentence-embeddings-ce48ddc8fc3a
- Document Embeddings (1) https://towardsdatascience.com/document-embedding-techniques-fed3e7a6a25d#1409 (2) https://graphaware.com/nlp/2018/09/03/advanced-document-representation.html

- General early language models that are based on DL: RNNs/LSTM (touch briefly) (1) https://arxiv.org/pdf/1312.6026.pdf (2) https://distill.pub/2019/memorization-in-rnns/
- On the difficulty of training recurrent neural networks https://arxiv.org/pdf/1211.5063.pdf
- Transition to attention mechanism, at first in RNNs
- Why these large DL-based models are so important for transfer learning in NLP (1) https://ruder.io/transfer-learning/ (2) https://thegradient.pub/nlp-imagenet/ (3) very linguistic study of word embeddings for transfer tasks https://arxiv.org/pdf/1903.08855.pdf

# **3.** Related Work

## 3.1 Attention and why it's all you need

- Attention is all you need (1) https://mlexplained.com/2017/12/29/attention-is-all-you-need-explained/ (2) https://arxiv.org/pdf/1706.03762.pdf (3) https://towardsdatascience.com/how-to-code-the-transformer-in-pytorch-24db27c8f9ec (4) https://blog.floydhub.com/attention-mechanism/ (5) https://jalammar.github.io/illustrated-transformer/ (6) https://distill.pub/2016/augmented-rnns/ (7) https://lilianweng.github.io/lil-log/2018/06/24/attention-attention.html
- BERT (1) https://mlexplained.com/2019/01/07/paper-dissected-bert-pre-training-of-deep-bidirectional-transformers-for-language-understanding-explained/ (2) https://arxiv.org/pdf/1810.04805.pdf (3) https://jalammar.github.io/a-visual-guide-to-using-bert-for-the-first-time/ (4) https://jalammar.github.io/illustrated-bert/
- What Does BERT Look At? An Analysis of BERT’s Attention 
https://arxiv.org/pdf/1906.04341.pdf https://arxiv.org/abs/1909.10430v2
- BERT raw embeddings https://arxiv.org/abs/1909.00512v1 https://towardsdatascience.com/examining-berts-raw-embeddings-fd905cb22df7
- Sentence embeddings revisited: BERT methods -> transition to intuition for my document representation model
- Call-back to section 2.4 where I touch on classical embeddings and compare to deep, high-parameter, attention based models such as BERT
- Why these large pretrained models are so important for transfer learning in NLP (1) https://ruder.io/transfer-learning/ (2) https://thegradient.pub/nlp-imagenet/


## 3.2 Hierarchical Attention-Models

- SMITH (1) https://arxiv.org/pdf/2004.12297v1.pdf (2) https://github.com/dmolony3/SMITH
- HIBERT (1) https://arxiv.org/pdf/1905.06566.pdf (2) https://github.com/liangsi03/hibert_model

## 3.3 Information Retrieval

- TANDA: Transfer and Adapt Pre-Trained Transformer Models for Answer Sentence Selection https://arxiv.org/pdf/1911.04118.pdf

# **4.** The Model

Install the `transformers`, `datasets` and `tokenizers` packages from Hugging Face, then introduce all necessary top-level imports.

In [2]:
!pip install transformers #installs transformer module from huggingface
!pip install datasets #installs dataset module from huggingface
!pip install tokenizers #installs tokenizer module from huggingface

Collecting transformers
  Downloading https://files.pythonhosted.org/packages/99/84/7bc03215279f603125d844bf81c3fb3f2d50fe8e511546eb4897e4be2067/transformers-4.0.0-py3-none-any.whl (1.4MB)

In [6]:
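import re
import math
import torch
import torch.utils.checkpoint  # used for gradient checkpointing in EncoderStack
import nltk
from nltk import sent_tokenize
nltk.download('punkt')
import torch.nn as nn
import torch.nn.functional as f
import torchtext.datasets
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import transformers
from transformers import BertConfig, BertModel, BertForMaskedLM
from transformers.activations import ACT2FN  # maps activation names such as "gelu" to functions
from transformers.modeling_outputs import (BaseModelOutputWithCrossAttentions,
                                           BaseModelOutputWithPoolingAndCrossAttentions)
# The location of these helpers differs between transformers releases (newer ones use
# pytorch_utils, the 4.0.0 release uses modeling_utils); adjust if neither import works.
try:
  from transformers.pytorch_utils import (apply_chunking_to_forward,
                                          find_pruneable_heads_and_indices,
                                          prune_linear_layer)
except ImportError:
  from transformers.modeling_utils import (apply_chunking_to_forward,
                                           find_pruneable_heads_and_indices,
                                           prune_linear_layer)
from pprint import pprint
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union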

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


## 4.1 Transformer

#### Internal Embedding Lookup

In [8]:
class EmbeddingsLookup(nn.Module):
  def __init__(self, config):
    super().__init__()
    # Initialize the lookup matrix for input IDs, positional embeddings and token types
    self.word_embeddings = nn.Embedding(num_embeddings=config.vocab_size,
                                        embedding_dim=config.hidden_size, 
                                        padding_idx=config.pad_token_id)
    self.position_embeddings = nn.Embedding(num_embeddings=config.max_position_embeddings,
                                            embedding_dim=config.hidden_size)
    self.token_type_embeddings = nn.Embedding(num_embeddings=config.type_vocab_size,
                                              embedding_dim=config.hidden_size)
    
    # Layer normalization and dropout applied to the summed input embeddings
    self.LayerNorm = nn.LayerNorm(normalized_shape=config.hidden_size, 
                                  eps=config.layer_norm_eps)
    self.dropout = nn.Dropout(p=config.hidden_dropout_prob)

    # position_ids (1, len position emb) is contiguous in memory and exported when serialized
    self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
    self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

  def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
    if input_ids is not None:
      input_shape = input_ids.size()
    else:
      input_shape = inputs_embeds.size()[:-1]

    seq_length = input_shape[1]

    if position_ids is None:
      position_ids = self.position_ids[:, :seq_length]

    if token_type_ids is None:
      token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

    if inputs_embeds is None:
      inputs_embeds = self.word_embeddings(input_ids)
    # Token type embeddings are needed whether or not inputs_embeds was passed in
    token_type_embeddings = self.token_type_embeddings(token_type_ids)

    embeddings = inputs_embeds + token_type_embeddings
    
    if self.position_embedding_type == "absolute":
      position_embeddings = self.position_embeddings(position_ids)
      embeddings += position_embeddings

    embeddings = self.LayerNorm(embeddings)
    embeddings = self.dropout(embeddings)
    return embeddings
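
What the lookup computes can be illustrated with a toy, dependency-free sketch: each position's vector is the sum of a word, a position and a token-type embedding, followed by layer normalization. The tables below are hand-written assumptions (not trained weights), the function names are hypothetical, and dropout as well as the learnable scale and shift of `nn.LayerNorm` are left out for brevity.

```python
import math


def layer_norm(vector, eps=1e-12):
    """Normalize a vector to zero mean and unit (biased) variance."""
    mean = sum(vector) / len(vector)
    variance = sum((x - mean) ** 2 for x in vector) / len(vector)
    return [(x - mean) / math.sqrt(variance + eps) for x in vector]


def embed(input_ids, token_type_ids, word_table, position_table, type_table):
    """Sum word, position and token-type embeddings, then layer-normalize."""
    embeddings = []
    for position, (token_id, type_id) in enumerate(zip(input_ids, token_type_ids)):
        summed = [w + p + t for w, p, t in zip(word_table[token_id],
                                               position_table[position],
                                               type_table[type_id])]
        embeddings.append(layer_norm(summed))
    return embeddings


# Toy lookup tables with hidden size 4 (illustrative values only).
word_table = [[0.0, 0.0, 0.0, 0.0], [0.1, 0.2, 0.3, 0.4], [0.5, -0.5, 0.5, -0.5]]
position_table = [[0.01, 0.02, 0.03, 0.04], [0.04, 0.03, 0.02, 0.01]]
type_table = [[0.0, 0.1, 0.0, 0.1], [0.1, 0.0, 0.1, 0.0]]

input_ids = [1, 2]
token_type_ids = [0, 0]  # default when no token types are given: all zeros
embedded = embed(input_ids, token_type_ids, word_table, position_table, type_table)
print(len(embedded), len(embedded[0]))  # 2 4
```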

#### Encoder Stack

In [12]:
class SelfAttention(nn.Module):
  def __init__(self, config):
    super().__init__()
    if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
      raise ValueError(
        "The hidden size (%d) is not a multiple of the number of attention "
        "heads (%d)" % (config.hidden_size, config.num_attention_heads)
      )

    self.num_attention_heads = config.num_attention_heads
    self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
    self.all_head_size = self.num_attention_heads * self.attention_head_size

    self.query = nn.Linear(config.hidden_size, self.all_head_size)
    self.key = nn.Linear(config.hidden_size, self.all_head_size)
    self.value = nn.Linear(config.hidden_size, self.all_head_size)

    self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
    self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
    if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
      self.max_position_embeddings = config.max_position_embeddings
      self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

  def transpose_for_scores(self, x):
    new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
    x = x.view(*new_x_shape)
    return x.permute(0, 2, 1, 3)

  def forward(self,
              hidden_states,
              attention_mask=None,
              head_mask=None,
              encoder_hidden_states=None,
              encoder_attention_mask=None,
              output_attentions=False):
    
    mixed_query_layer = self.query(hidden_states)

    # If this is instantiated as a cross-attention module, the keys
    # and values come from an encoder; the attention mask needs to be
    # such that the encoder's padding tokens are not attended to.
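    if encoder_hidden_states is not None:
      mixed_key_layer = self.key(encoder_hidden_states)
      mixed_value_layer = self.value(encoder_hidden_states)
      attention_mask = encoder_attention_mask
    else:
      mixed_key_layer = self.key(hidden_states)
      mixed_value_layer = self.value(hidden_states)

    # Split the projections into heads: (batch, num_heads, seq_length, head_size)
    query_layer = self.transpose_for_scores(mixed_query_layer)
    key_layer = self.transpose_for_scores(mixed_key_layer)
    value_layer = self.transpose_for_scores(mixed_value_layer)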

    # Take the dot product between "query" and "key" to get the raw attention scores.
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

    if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
      seq_length = hidden_states.size()[1]
      position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
      position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
      distance = position_ids_l - position_ids_r
      positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
      positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

      if self.position_embedding_type == "relative_key":
        relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
        attention_scores = attention_scores + relative_position_scores
      elif self.position_embedding_type == "relative_key_query":
        relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
        relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
        attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

    attention_scores = attention_scores / math.sqrt(self.attention_head_size)
    if attention_mask is not None:
      # Apply the attention mask (precomputed once for all layers in the model's forward() function)
      attention_scores = attention_scores + attention_mask

    # Normalize the attention scores to probabilities.
    attention_probs = nn.Softmax(dim=-1)(attention_scores)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = self.dropout(attention_probs)

    # Mask heads if we want to
    if head_mask is not None:
      attention_probs = attention_probs * head_mask

    context_layer = torch.matmul(attention_probs, value_layer)

    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
    context_layer = context_layer.view(*new_context_layer_shape)

    outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
    return outputs
  

class SelfOutput(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.dense = nn.Linear(config.hidden_size, config.hidden_size)
    self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
    self.dropout = nn.Dropout(config.hidden_dropout_prob)

  def forward(self, hidden_states, input_tensor):
    hidden_states = self.dense(hidden_states)
    hidden_states = self.dropout(hidden_states)
    hidden_states = self.LayerNorm(hidden_states + input_tensor)
    return hidden_states


class AttentionModule(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.self = SelfAttention(config)
    self.output = SelfOutput(config)
    self.pruned_heads = set()

  def prune_heads(self, heads):
    if len(heads) == 0:
      return
    # find_pruneable_heads_and_indices and prune_linear_layer are transformers helpers that must be imported
    heads, index = find_pruneable_heads_and_indices(
      heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
    )
    # Prune linear layers
    self.self.query = prune_linear_layer(self.self.query, index)
    self.self.key = prune_linear_layer(self.self.key, index)
    self.self.value = prune_linear_layer(self.self.value, index)
    self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

    # Update hyper params and store pruned heads
    self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
    self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
    self.pruned_heads = self.pruned_heads.union(heads)

  def forward(self,
              hidden_states,
              attention_mask=None,
              head_mask=None,
              encoder_hidden_states=None,
              encoder_attention_mask=None,
              output_attentions=False):
    self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            output_attentions)
    attention_output = self.output(self_outputs[0], hidden_states)
    outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
    return outputs





class FeedForward(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
    if isinstance(config.hidden_act, str):
      self.intermediate_act_fn = ACT2FN[config.hidden_act]
    else:
      self.intermediate_act_fn = config.hidden_act

  def forward(self, hidden_states):
    hidden_states = self.dense(hidden_states)
    hidden_states = self.intermediate_act_fn(hidden_states)
    return hidden_states


class Output(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
    self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
    self.dropout = nn.Dropout(config.hidden_dropout_prob)

  def forward(self, hidden_states, input_tensor):
    hidden_states = self.dense(hidden_states)
    hidden_states = self.dropout(hidden_states)
    hidden_states = self.LayerNorm(hidden_states + input_tensor)
    return hidden_states





class EncoderLayer(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.chunk_size_feed_forward = config.chunk_size_feed_forward
    self.seq_len_dim = 1
    self.attention = AttentionModule(config)
    self.is_decoder = config.is_decoder
    self.add_cross_attention = config.add_cross_attention
    if self.add_cross_attention:
      assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
      self.crossattention = AttentionModule(config)
    self.intermediate = FeedForward(config)
    self.output = Output(config)

  def forward(self,
              hidden_states,
              attention_mask=None,
              head_mask=None,
              encoder_hidden_states=None,
              encoder_attention_mask=None,
              output_attentions=False):
    
    self_attention_outputs = self.attention(hidden_states,
                                            attention_mask,
                                            head_mask,
                                            output_attentions=output_attentions)
    attention_output = self_attention_outputs[0]
    outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
    if self.is_decoder and encoder_hidden_states is not None:
      assert hasattr(self, "crossattention"), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
      cross_attention_outputs = self.crossattention(attention_output,
                                                    attention_mask,
                                                    head_mask,
                                                    encoder_hidden_states,
                                                    encoder_attention_mask,
                                                    output_attentions)
      attention_output = cross_attention_outputs[0]
      outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights

    # apply_chunking_to_forward is a transformers helper that must be imported
    layer_output = apply_chunking_to_forward(self.feed_forward_chunk,
                                             self.chunk_size_feed_forward, 
                                             self.seq_len_dim, attention_output)
      
    outputs = (layer_output,) + outputs
    return outputs

  def feed_forward_chunk(self, attention_output):
    intermediate_output = self.intermediate(attention_output)
    layer_output = self.output(intermediate_output, attention_output)
    return layer_output




class EncoderStack(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.config = config
    self.layer = nn.ModuleList([EncoderLayer(config) for _ in range(config.num_hidden_layers)])

  def forward(self,
              hidden_states,
              attention_mask=None,
              head_mask=None,
              encoder_hidden_states=None,
              encoder_attention_mask=None,
              output_attentions=False,
              output_hidden_states=False,
              return_dict=True):
    all_hidden_states = () if output_hidden_states else None
    all_self_attentions = () if output_attentions else None
    all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
    for i, layer_module in enumerate(self.layer):
      if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)
      
      layer_head_mask = head_mask[i] if head_mask is not None else None

      if getattr(self.config, "gradient_checkpointing", False):
        def create_custom_forward(module):
          def custom_forward(*inputs):
            return module(*inputs, output_attentions)
          return custom_forward
        layer_outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(layer_module),
                                                          hidden_states,
                                                          attention_mask,
                                                          layer_head_mask,
                                                          encoder_hidden_states,
                                                          encoder_attention_mask,)
      else:
        layer_outputs = layer_module(hidden_states,
                                     attention_mask,
                                     layer_head_mask,
                                     encoder_hidden_states,
                                     encoder_attention_mask,
                                     output_attentions,)
      
      hidden_states = layer_outputs[0]
      if output_attentions:
        all_self_attentions = all_self_attentions + (layer_outputs[1],)
        if self.config.add_cross_attention:
          all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

    if output_hidden_states:
      all_hidden_states = all_hidden_states + (hidden_states,)

    if not return_dict:
      return tuple(v
                   for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
                   if v is not None)
    
    return BaseModelOutputWithCrossAttentions(last_hidden_state=hidden_states,
                                              hidden_states=all_hidden_states,
                                              attentions=all_self_attentions,
                                              cross_attentions=all_cross_attentions,)
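
The core computation inside `SelfAttention.forward` is scaled dot-product attention with an additive mask. The sketch below reproduces it for a single head and a single sequence in plain Python so it can be run and inspected without PyTorch. It is a simplified illustration, not the module above: head splitting, dropout, head masks and relative position scores are omitted, and the function names are hypothetical.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    highest = max(scores)
    exps = [math.exp(s - highest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(queries, keys, values, additive_mask=None):
    """Single-head attention: softmax(q.k / sqrt(d) + mask) applied to the values.

    `additive_mask` has one entry per key: 0.0 to keep it, a large negative
    number to effectively remove it from the softmax.
    """
    head_size = len(queries[0])
    context = []
    for query in queries:
        scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(head_size)
                  for key in keys]
        if additive_mask is not None:
            scores = [s + m for s, m in zip(scores, additive_mask)]
        probs = softmax(scores)
        context.append([sum(p * value[d] for p, value in zip(probs, values))
                        for d in range(len(values[0]))])
    return context


queries = [[1.0, 0.0]]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
values = [[1.0, 0.0], [0.0, 1.0], [100.0, 100.0]]
mask = [0.0, 0.0, -10000.0]  # the third key is treated as padding

context = scaled_dot_product_attention(queries, keys, values, additive_mask=mask)
# The masked third value (100, 100) contributes essentially nothing.
print(context)
```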

In [11]:
class Pooler(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.dense = nn.Linear(config.hidden_size, config.hidden_size)
    self.activation = nn.Tanh()

  def forward(self, hidden_states):
    # We "pool" the model by simply taking the hidden state corresponding
    # to the first token.
    first_token_tensor = hidden_states[:, 0]
    pooled_output = self.dense(first_token_tensor)
    pooled_output = self.activation(pooled_output)
    return pooled_output

#### Transformer

In [None]:
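# Inherits from BertPreTrainedModel rather than nn.Module so that super().__init__(config),
# init_weights(), get_extended_attention_mask(), invert_attention_mask() and get_head_mask()
# used below are available.
class TransformerBase(transformers.BertPreTrainedModel):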
  """
  The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
  cross-attention is added between the self-attention layers, following the architecture described in `Attention is
  all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
  Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
  To behave as a decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration
  set to :obj:`True`. To be used in a Seq2Seq model, the model needs to be initialized with both :obj:`is_decoder`
  argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
  input to the forward pass.
  """
  def __init__(self, config, add_pooling_layer=True):
    super().__init__(config)
    self.config = config

    self.embeddings = EmbeddingsLookup(config)
    self.encoder = EncoderStack(config)

    self.pooler = Pooler(config) if add_pooling_layer else None

    self.init_weights() # Don't forget this

  def get_input_embeddings(self):
    return self.embeddings.word_embeddings

  def set_input_embeddings(self, value):
    self.embeddings.word_embeddings = value

  def _prune_heads(self, heads_to_prune):
    """
    Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
    class PreTrainedModel
    """
    for layer, heads in heads_to_prune.items():
      self.encoder.layer[layer].attention.prune_heads(heads)

  def forward(self,
              input_ids=None,
              attention_mask=None,
              token_type_ids=None,
              position_ids=None,
              head_mask=None,
              inputs_embeds=None,
              encoder_hidden_states=None,
              encoder_attention_mask=None,
              output_attentions=None,
              output_hidden_states=None,
              return_dict=None,):
    """
    encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
      Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
      the model is configured as a decoder.
    encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
      Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
      the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
      - 1 for tokens that are **not masked**,
      - 0 for tokens that are **masked**.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if input_ids is not None and inputs_embeds is not None:
      raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
      input_shape = input_ids.size()
    elif inputs_embeds is not None:
      input_shape = inputs_embeds.size()[:-1]
    else:
      raise ValueError("You have to specify either input_ids or inputs_embeds")

    device = input_ids.device if input_ids is not None else inputs_embeds.device

    if attention_mask is None:
      attention_mask = torch.ones(input_shape, device=device)
    if token_type_ids is None:
      token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

    # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
    # ourselves in which case we just need to make it broadcastable to all heads.
    extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) # can we throw this away?

    # If a 2D or 3D attention mask is provided for the cross-attention,
    # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
    if self.config.is_decoder and encoder_hidden_states is not None:
      encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
      encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
      if encoder_attention_mask is None:
        encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
      encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
    else:
      encoder_extended_attention_mask = None

    # Prepare head mask if needed
    # 1.0 in head_mask indicate we keep the head
    # attention_probs has shape bsz x n_heads x N x N
    # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
    # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
    head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

    embedding_output = self.embeddings(input_ids=input_ids,
                                       position_ids=position_ids,
                                       token_type_ids=token_type_ids,
                                       inputs_embeds=inputs_embeds)
    encoder_outputs = self.encoder(embedding_output,
                                   attention_mask=extended_attention_mask,
                                   head_mask=head_mask,
                                   encoder_hidden_states=encoder_hidden_states,
                                   encoder_attention_mask=encoder_extended_attention_mask,
                                   output_attentions=output_attentions,
                                   output_hidden_states=output_hidden_states,
                                   return_dict=return_dict,)
    sequence_output = encoder_outputs[0]
    pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

    if not return_dict:
      return (sequence_output, pooled_output) + encoder_outputs[1:]

    return BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=sequence_output,
                                                        pooler_output=pooled_output,
                                                        hidden_states=encoder_outputs.hidden_states,
                                                        attentions=encoder_outputs.attentions,
                                                        cross_attentions=encoder_outputs.cross_attentions,)
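
The attention layers add the mask to the raw scores, so the 0/1 padding mask passed to `forward` first has to be turned into an additive one. A minimal sketch of that conversion follows. The function name is hypothetical, and the fill value of `-10000.0` is an assumption based on what older transformers releases used; other versions may use a different large negative number.

```python
def to_additive_mask(padding_mask, masked_value=-10000.0):
    """Convert a 0/1 padding mask into an additive attention mask.

    1 (attend) becomes 0.0, so the score is unchanged; 0 (padding) becomes a
    large negative number, so the softmax assigns it almost no probability.
    """
    return [[0.0 if keep else masked_value for keep in row] for row in padding_mask]


# One sequence of length 3 whose last position is padding.
print(to_additive_mask([[1, 1, 0]]))  # [[0.0, 0.0, -10000.0]]
```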

## 4.2 Language Modeling

### 4.2.1 Token Masking
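
Before the tensor-based implementation in the next cell, the 80/10/10 masking rule can be shown on plain lists. This is a hedged sketch and not the notebook's function: `mask_tokens` and its parameters are hypothetical, the label value `-100` for unselected positions is an assumption that follows PyTorch's default `ignore_index` for cross-entropy, and the token IDs in the example are illustrative only.

```python
import random


def mask_tokens(token_ids, special_tokens_mask, mask_token_id, vocab_size,
                mask_probability=0.15, rng=None):
    """Return (masked_ids, labels) following the 80/10/10 masking rule.

    Labels hold the original token at selected positions and -100 elsewhere,
    so that a loss with ignore_index=-100 only scores the selected tokens.
    """
    rng = rng or random.Random()
    masked_ids = list(token_ids)
    labels = [-100] * len(token_ids)
    for i, (token_id, is_special) in enumerate(zip(token_ids, special_tokens_mask)):
        # Special tokens are never masked; other tokens are selected with
        # probability `mask_probability`.
        if is_special or rng.random() >= mask_probability:
            continue
        labels[i] = token_id
        roll = rng.random()
        if roll < 0.8:
            masked_ids[i] = mask_token_id              # 80%: replace with [MASK]
        elif roll < 0.9:
            masked_ids[i] = rng.randrange(vocab_size)  # 10%: random token
        # remaining 10%: keep the original token
    return masked_ids, labels


token_ids = [101, 7592, 2088, 102]   # illustrative IDs: [CLS] w1 w2 [SEP]
special_tokens_mask = [1, 0, 0, 1]
masked_ids, labels = mask_tokens(token_ids, special_tokens_mask,
                                 mask_token_id=103, vocab_size=30522,
                                 rng=random.Random(0))
print(masked_ids, labels)
```

Passing a seeded `random.Random` keeps the example reproducible; in training the masking would normally be re-sampled for every batch.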

In [None]:
# TODO: how do we keep this usable for both sentences and words?
# Technically we just need to adjust the vocabulary over which the sentences
# are predicted: it is the collection of sentence representations from the
# doc/batch calculated by the word-level model.
# Where exactly does this function get applied? In the Dataset class? (Presumably,
# so we can pass proper labels and inputs to the model.)
# But does it also get applied again inside the model?
mask_probability = 0.15
replace_with_mask_probability = 0.8
replace_randomly_probability = 0.1
keep_token_probability = 0.1

def mask_input_ids(inputs: torch.Tensor,
                   tokenizer: transformers.BertTokenizerFast,
                   special_tokens_mask: Optional[torch.Tensor] = None
                   ) -> Tuple[torch.Tensor, torch.Tensor]:
  """
  We specify the probability with which to mask tokens for the language modeling
  task. Generally 15% of tokens are selected for masking. If we just mask 
  naively then a problem arises: the [MASK] token never appears at fine-tuning,
  so pretraining and fine-tuning inputs would be mismatched. The solution is to
  not replace a selected token with [MASK] 100% of the time, instead:
  - 80% of the time, replace the token with [MASK]
    went to the store -> went to the [MASK]
  - 10% of the time, replace the token with a random token
    went to the store -> went to the running
  - 10% of the time, keep the token unchanged
    went to the store -> went to the store
  The same principle is also applicable to masked sentence prediction, only
  that we have to establish a sentence vocabulary beforehand

  Args:
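    inputs: tensor, containing all the token IDs (modified in place)
    tokenizer: tokenizer that provides the [MASK] token ID, the vocabulary
      size and, if no special_tokens_mask is given, the special tokens mask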
    special_tokens_mask: tensor, denotes whether a token is a word [0] or a 
      special token [1], [CLS] tokens and padding tokens are all counted as 
      special tokens. This will be used to create a mask so that only actual
      words are considered for random masking

  Returns:
    (masked_inputs, labels): the masked token IDs and the labels, which are
      set to -100 at every position that was not selected for masking
  """
  labels = inputs.clone()
  # Tensor that holds the probability values for the Bernoulli function
  probability_matrix = torch.full(inputs.shape, mask_probability)

  # Get special token indices in order to exclude special tokens from masking
  if special_tokens_mask is None:
    special_tokens_mask = [
      tokenizer.get_special_tokens_mask(entry, already_has_special_tokens=True) for entry in labels.tolist()
    ]
    special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
  else:
    special_tokens_mask = special_tokens_mask.bool()

  # Fill the probability matrix with 0.0 values where there are special tokens
  probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
  # Draws from a Bernoulli distribution where probability_matrix holds the 
  # probabilities for drawing the binary random number. The probability matrix
  # was previously filled with 0.0 values where special tokens are present so
  # that only tokens containing words/sentences are considered
  masked_indices = torch.bernoulli(probability_matrix).bool()
  # In order to compute the loss only on the masked indices all the unmasked
  # tokens in the label tensor are set to -100
  labels[~masked_indices] = -100

  # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
  indices_replaced = torch.bernoulli(torch.full(labels.shape, replace_with_mask_probability)).bool() & masked_indices
  # Since we're dealing with tensors with numerical values we convert the [MASK]
  # token right back to its token_id representation
  inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

  # 10% of the time, we replace masked input tokens with a random word. Drawing
  # with probability 0.5 from the remaining 20% yields 10% overall.
  indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
  random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
  inputs[indices_random] = random_words[indices_random]

  return (inputs, labels)
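
The 80/10/10 scheme can be checked on a toy batch without loading a tokenizer. The following is a minimal sketch that repeats the same steps in plain PyTorch; `MASK_ID`, `VOCAB_SIZE`, `SPECIAL_IDS` and `toy_mask` are hypothetical names and values chosen for illustration, not part of the notebook.

```python
import torch

# Hypothetical ids for illustration only; a real tokenizer supplies these.
MASK_ID, VOCAB_SIZE = 103, 1000
SPECIAL_IDS = (0, 101, 102)  # stand-ins for [PAD], [CLS], [SEP]

def toy_mask(inputs, mask_probability=0.15, generator=None):
    inputs = inputs.clone()
    labels = inputs.clone()
    # Mark special tokens so that they are never selected for masking
    special = torch.zeros_like(inputs, dtype=torch.bool)
    for special_id in SPECIAL_IDS:
        special |= inputs == special_id
    probabilities = torch.full(inputs.shape, mask_probability)
    probabilities.masked_fill_(special, 0.0)
    selected = torch.bernoulli(probabilities, generator=generator).bool()
    labels[~selected] = -100  # loss is only computed on selected positions
    # 80% of the selected tokens become the mask id
    replaced = torch.bernoulli(torch.full(inputs.shape, 0.8), generator=generator).bool() & selected
    inputs[replaced] = MASK_ID
    # Half of the remaining 20% (10% overall) become a random id
    randomized = torch.bernoulli(torch.full(inputs.shape, 0.5), generator=generator).bool() & selected & ~replaced
    random_ids = torch.randint(VOCAB_SIZE, inputs.shape, generator=generator)
    inputs[randomized] = random_ids[randomized]
    return inputs, labels, special

gen = torch.Generator().manual_seed(0)
batch = torch.randint(200, 1000, (4, 16), generator=gen)
batch[:, 0], batch[:, -1] = 101, 102  # toy [CLS] and [SEP]
masked_inputs, labels, special = toy_mask(batch, generator=gen)
```

Positions with label -100 keep their original id, and special tokens are never touched, which is the property the loss computation relies on.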

In [None]:
def mask_sentence_representations(num_masked_sentences, document_batch, num_sentences_per_doc, sentence_mask_vector):
  """
  Args:
    num_masked_sentences: number of masked sentences in each document
      (TODO: derive this automatically instead of hardcoding it?)
    document_batch: tensor, represents the whole batch of documents
    num_sentences_per_doc: list, each entry represents the number of sentences per document
    sentence_mask_vector: tensor, randomly initialized vector that denotes a masked sentence

    TODO: where does the attention mask come into play?

  Returns:
    masked_document_batch: tensor

  Notes:
  - Combine the strengths of both HIBERT and SMITH so that the decoding scheme can be exchanged seamlessly
  - Implement the encoder and sentence masks in such a way that the model can learn via either masked sentence prediction or a decoder
  - Combine dynamic sentence block matching (SMITH) with the random masking process of HIBERT
  """
  raise NotImplementedError

- Algorithm for masking sentences
- Extracting the masked token embedding from the DocumentEncoder
- Compare vs ground truth and define a loss function on it (use the huggingface optimizer?), e.g. `from transformers import AdamW` followed by `optimizer = AdamW(...)`, see https://huggingface.co/transformers/training.html
- SMITH adds the losses from the sentence encoder and the document encoder; make both trainable simultaneously?
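
One possible shape for the sentence masking step is sketched below. It is an assumption, not the notebook's final algorithm: instead of a fixed number of masked sentences it draws each real (non-padded) sentence with a probability, and the name `mask_sentence_embeddings` and the padded `(num_docs, max_sentences, hidden)` layout are hypothetical.

```python
import torch

def mask_sentence_embeddings(sentence_embeddings, sentence_mask_vector,
                             num_sentences_per_doc, mask_probability=0.15,
                             generator=None):
    """
    sentence_embeddings: (num_docs, max_sentences, hidden), padded batch
    sentence_mask_vector: (hidden,), vector that stands in for a masked sentence
    num_sentences_per_doc: list with the number of real sentences per document
    Returns the masked embeddings and a boolean tensor of the masked positions.
    """
    num_docs, max_sentences, _ = sentence_embeddings.shape
    lengths = torch.tensor(num_sentences_per_doc).unsqueeze(1)
    # True where a position holds a real sentence, False where it is padding
    is_real = torch.arange(max_sentences).unsqueeze(0) < lengths
    probabilities = torch.full((num_docs, max_sentences), mask_probability)
    probabilities.masked_fill_(~is_real, 0.0)
    masked_positions = torch.bernoulli(probabilities, generator=generator).bool()
    masked = sentence_embeddings.clone()
    masked[masked_positions] = sentence_mask_vector
    return masked, masked_positions

gen = torch.Generator().manual_seed(0)
embeddings = torch.randn(3, 5, 8, generator=gen)
mask_vector = torch.zeros(8)
masked, positions = mask_sentence_embeddings(
    embeddings, mask_vector, [5, 3, 2], mask_probability=0.5, generator=gen)
```

The returned `positions` tensor can double as the index from which the ground-truth sentence representations are gathered for the loss.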

### 4.2.2 Language Modeling Head

Define the LM Head(s) and its loss function
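
As a rough orientation for what such a head and its loss could look like, here is a minimal sketch in plain PyTorch. The class name `TinyLMHead` and all sizes are hypothetical; the structure (dense layer, activation, layer normalization, projection to the vocabulary) is a common choice for masked language modeling heads and is assumed here rather than taken from the notebook.

```python
import torch
from torch import nn

class TinyLMHead(nn.Module):
    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.GELU()
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.decoder = nn.Linear(hidden_size, vocab_size)

    def forward(self, sequence_output):
        # Transform the encoder output, then project onto the vocabulary
        hidden = self.layer_norm(self.activation(self.dense(sequence_output)))
        return self.decoder(hidden)

torch.manual_seed(0)
head = TinyLMHead(hidden_size=16, vocab_size=50)
sequence_output = torch.randn(2, 6, 16)   # (batch, sequence, hidden)
labels = torch.full((2, 6), -100)         # -100 marks positions without loss
labels[0, 2] = 7
labels[1, 4] = 11
logits = head(sequence_output)
# CrossEntropyLoss ignores targets equal to -100 by default
loss = nn.CrossEntropyLoss()(logits.view(-1, 50), labels.view(-1))
```

Because the labels produced by the token masking step are -100 everywhere except at the masked positions, the default `ignore_index` of `nn.CrossEntropyLoss` restricts the loss to the masked tokens.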

#### Word Level Language Modeling Head

In [None]:
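class PredictionHeadTransform(nn.Module):
  def __init__(self, config):
    super().__init__()
    # TODO: define layers

  def forward(self, hidden_states):
    raise NotImplementedError


class LMPredictionHead(nn.Module):
  def __init__(self, config):
    super().__init__()
    # TODO: define layers

  def forward(self, hidden_states):
    raise NotImplementedError


class OnlyLMHead(nn.Module):
  def __init__(self, config):
    super().__init__()
    # TODO: define layers

  def forward(self, sequence_output):
    raise NotImplementedError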

#### Sentence Level Language Modeling Head

In [None]:
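class SentenceSimilarityMatrix(nn.Module):
  def __init__(self, config):
    super().__init__()
    # TODO: define layers

  def forward(self):
    raise NotImplementedError


class SentencePredictionHead(nn.Module):
  def __init__(self, config):
    super().__init__()
    # TODO: define layers

  def forward(self):
    raise NotImplementedError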
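
One way to predict masked sentences without a fixed sentence vocabulary is to score each predicted vector against all sentence representations in the batch and apply a cross-entropy loss over these candidates. The sketch below assumes exactly that; `contrastive_sentence_loss` is a hypothetical name, and whether the similarity should be a plain dot product is an open design choice.

```python
import torch
import torch.nn.functional as F

def contrastive_sentence_loss(predicted, candidates, target_indices):
    """
    predicted: (num_masked, hidden), document-model outputs at masked positions
    candidates: (num_candidates, hidden), sentence representations of the batch
    target_indices: (num_masked,), index of the true sentence among candidates
    """
    # Similarity matrix between every prediction and every candidate sentence
    similarity = predicted @ candidates.t()
    return F.cross_entropy(similarity, target_indices), similarity

torch.manual_seed(0)
candidates = F.normalize(torch.randn(6, 8), dim=-1)
target_indices = torch.tensor([1, 4])
# Toy predictions that point strongly at the correct candidates
predicted = 10.0 * candidates[target_indices]
loss, similarity = contrastive_sentence_loss(predicted, candidates, target_indices)
```

The candidates play the role of the "vocab" over which masked sentences are predicted, so no tokenizer is needed at the sentence level.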

### 4.2.3 Alternative: Learning via Decoding

Write a Decoder like in Hibert to have an alternative to the LM Head

## 4.3 Sentence Model

In [None]:
class SentenceModel(torch.nn.Module):
  def __init__(self):
    super().__init__()
    # same as document encoder: could this just be a one-liner if EncoderLayer is general enough??

  def forward(self):
    raise NotImplementedError

## 4.4 Document Model

In [None]:
class DocumentModel(torch.nn.Module):
  def __init__(self):
    super().__init__()
    # isn't this redundant? aren't both the sentence as well as the document encoder the same?

  def forward(self):
    raise NotImplementedError

## 4.5 Hierarchical Attention-Based Document Encoder (HATE)

Quick side note: call it "hierarchical attention-based text encoder" (HATE)

In [None]:
class ModelConfig (object):
  def __init__(self,
               vocab_size,
               hidden_size=768,
               num_hidden_layers=12,
               num_attention_heads=12,
               intermediate_size=3072,
               hidden_act="gelu",
               hidden_dropout_prob=0.1,
               attention_probs_dropout_prob=0.1,
               max_position_embeddings=512,
               type_vocab_size=16,
               initializer_range=0.02):
    
    """
    Constructs ModelConfig. Can be instantiated with different values for the
    sentence-level as well as the document-level BERT.
    Taken from https://github.com/google-research/google-research/blob/master/smith/bert/modeling.py
      Args:
      vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
      hidden_size: Size of the encoder layers and the pooler layer.
      num_hidden_layers: Number of hidden layers in the Transformer encoder.
      num_attention_heads: Number of attention heads for each attention layer in
        the Transformer encoder.
      intermediate_size: The size of the "intermediate" (i.e., feed-forward)
        layer in the Transformer encoder.
      hidden_act: The non-linear activation function (function or string) in the
        encoder and pooler.
      hidden_dropout_prob: The dropout probability for all fully connected
        layers in the embeddings, encoder, and pooler.
      attention_probs_dropout_prob: The dropout ratio for the attention
        probabilities.
      max_position_embeddings: The maximum sequence length that this model might
        ever be used with. Typically set this to something large just in case
        (e.g., 512 or 1024 or 2048).
      type_vocab_size: The vocabulary size of the `token_type_ids` passed into
        `BertModel`.
      initializer_range: The stdev of the truncated_normal_initializer for
        initializing all weight matrices.
    """
    self.vocab_size = vocab_size
    self.hidden_size = hidden_size
    self.num_hidden_layers = num_hidden_layers
    self.num_attention_heads = num_attention_heads
    self.hidden_act = hidden_act
    self.intermediate_size = intermediate_size
    self.hidden_dropout_prob = hidden_dropout_prob
    self.attention_probs_dropout_prob = attention_probs_dropout_prob
    self.max_position_embeddings = max_position_embeddings
    self.type_vocab_size = type_vocab_size
    self.initializer_range = initializer_range

  @classmethod
  def from_dict(cls, json_object):
    """Constructs a `ModelConfig` from a Python dictionary of parameters."""
    config = cls(vocab_size=None)
    for (key, value) in json_object.items():
      config.__dict__[key] = value
    return config

  @classmethod
  def from_json_file(cls, json_file):
    """Constructs a `ModelConfig` from a json file of parameters."""
    with open(json_file, "r") as reader:
      text = reader.read()
    return cls.from_dict(json.loads(text))

  def to_dict(self):
    """Serializes this instance to a Python dictionary."""
    output = copy.deepcopy(self.__dict__)
    return output

  def to_json_string(self):
    """Serializes this instance to a JSON string."""
    return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

In [None]:
class HATEModelPretrain (torch.nn.Module):
  def __init__(self, sent_model_config, doc_model_config):
    super().__init__()
    # keep the following two, apparently will be referenced in forward function
    self.sentence_hidden_size = sent_model_config.hidden_size 
    self.document_hidden_size = doc_model_config.hidden_size

    # initiate sentence and document encoder with huggingface modules
    self.sentence_model = BertForMaskedLM(sent_model_config)
    self.document_model = BertForMaskedLM(doc_model_config)

    self.sentence_embedding = torch.nn.Linear(sent_model_config.hidden_size, sent_model_config.hidden_size)
    self.document_embedding = torch.nn.Linear(doc_model_config.hidden_size, doc_model_config.hidden_size)

  def forward(self,
              document,
              pretraining=False):
    # TODO: in training mode the model should accept a batch by default, else it
    # accepts a document or a sentence for inference (hint: a sentence is just a
    # document with length=1)
    # Assumed layout: document is an iterable of (input_ids, attention_mask, labels)
    # tensors, one triple per sentence
    # TODO: use a tensor instead of a list; the huggingface tokenizer can return tensors
    sentence_representation_list = []

    for input_ids, attention_mask, labels in document:
      # Apply the word-level BERT model for masked token prediction
      sentence_output = self.sentence_model(input_ids=input_ids,
                                            attention_mask=attention_mask,
                                            labels=labels,
                                            output_hidden_states=True)
      # The output object can already return the loss
      sentence_loss = sentence_output.loss
      # Takes the [CLS] vector of the last layer, has been shown to produce solid sentence representations
      sentence_representation = sentence_output.hidden_states[-1][:, 0]
      # The following could be replaced, possibly by an S-BERT approach,
      # check out transformers.modeling_utils.SequenceSummary
      # Dense layer as in SMITH
      sentence_representation = self.sentence_embedding(sentence_representation)
      # Normalization as in SMITH
      sentence_representation = torch.nn.functional.normalize(sentence_representation, dim=-1)
      # Append the features to list of sentence features for the according doc
      sentence_representation_list.append((sentence_output, sentence_loss, sentence_representation))
    
    # TODO: after the sentence model has looped through all sentences in the batch and stored the intermediate representations
    #       in a list we use them as a "vocab" over which to predict masked sentences in the document model.
    #       A tokenizer might not be necessary. We predict over sentence representations right away
    # Hand the sentence embeddings to the custom sentence masking function and return the masked sentences, their respective indices and labels
    masked_sentence_embeddings, sentence_labels, mask_indices = mask_sentences()

    
    # Apply the sentence-level BERT model
    # This one is a classical BERT model with no masked LM head because we use a custom sentence masking function
    document_output = self.document_model(inputs_embeds=masked_sentence_embeddings)
    document_loss = masked_sentence_loss()

    # IMPORTANT TODO: don't have the model output the loss, why would we want that?
    # model should output the embeddings and predictions (???)
    # calculate loss via standard torch criterion for prediction against ground truth
    # once model is trained properly we can access the embeddings (which will hopefully make sense by then)
    outputs = (sentence_loss, document_loss)

    # prune layers? transformers.modeling_utils.find_pruneable_heads_and_indices

    return outputs
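
The overall data flow (words to sentence vectors, sentence vectors to contextualized document-level vectors) can be illustrated with a much smaller stand-in model. The sketch below uses `nn.TransformerEncoder` instead of BERT, omits position embeddings and masking, and treats the first token of each sentence as a [CLS] stand-in; `TinyHierarchicalEncoder` and all sizes are hypothetical.

```python
import torch
from torch import nn
import torch.nn.functional as F

class TinyHierarchicalEncoder(nn.Module):
    def __init__(self, vocab_size=100, hidden_size=32, num_heads=4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        word_layer = nn.TransformerEncoderLayer(
            hidden_size, num_heads, dim_feedforward=64, batch_first=True)
        sentence_layer = nn.TransformerEncoderLayer(
            hidden_size, num_heads, dim_feedforward=64, batch_first=True)
        self.word_encoder = nn.TransformerEncoder(word_layer, num_layers=1)
        self.sentence_encoder = nn.TransformerEncoder(sentence_layer, num_layers=1)
        self.sentence_projection = nn.Linear(hidden_size, hidden_size)

    def forward(self, input_ids):
        # input_ids: (num_sentences, sentence_length) for a single document
        word_states = self.word_encoder(self.embed(input_ids))
        # First position of every sentence serves as its representation,
        # followed by a dense layer and L2 normalization
        sentence_reps = F.normalize(
            self.sentence_projection(word_states[:, 0]), dim=-1)
        # The sentences of the document form one sequence for the upper encoder
        contextual = self.sentence_encoder(sentence_reps.unsqueeze(0)).squeeze(0)
        return sentence_reps, contextual

torch.manual_seed(0)
encoder = TinyHierarchicalEncoder().eval()
document = torch.randint(0, 100, (5, 12))  # 5 sentences with 12 token ids each
with torch.no_grad():
    sentence_reps, contextual = encoder(document)
```

Both outputs have the same hidden size, which is the precondition for comparing word-, sentence- and document-level vectors in one representation space.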

# **5.** Pretraining and Finetuning

## 5.1 Pretraining

### 5.1.1 Evidence for the need for Pretraining







- HIBERT https://arxiv.org/pdf/1905.06566.pdf
- Language Model Pre-training for Hierarchical Document Representations https://arxiv.org/pdf/1901.09128.pdf
- Pre-training Tasks for Embedding-based Large-scale Retrieval https://arxiv.org/pdf/2002.03932.pdf

### 5.1.2 Pretraining Data

Describe Dataset used for pretraining, where to get it and how to load it
https://dumps.wikimedia.org/enwiki/20200920/


In [None]:
from datasets import list_datasets, load_dataset
datasets_list = list_datasets()
print("This is how many datasets are available: ", len(datasets_list))
# Load the Wikipedia dump into an OrderedDict
wiki_data_raw = load_dataset('wikipedia', '20200501.en', split='train[:1%]')

This is how many datasets are available:  183



Downloading and preparing dataset wikipedia/20200501.en (download: 16.99 GiB, generated: 17.07 GiB, post-processed: Unknown size, total: 34.06 GiB) to /root/.cache/huggingface/datasets/wikipedia/20200501.en/1.0.0/f92599dfccab29832c442b82870fa8f6983e5b4ebbf5e6e2dcbe894e325339cd...



Dataset wikipedia downloaded and prepared to /root/.cache/huggingface/datasets/wikipedia/20200501.en/1.0.0/f92599dfccab29832c442b82870fa8f6983e5b4ebbf5e6e2dcbe894e325339cd. Subsequent calls will reuse this data.


In [None]:
# Print out some basic info about the dataset to better understand the structure
print("size: ", wiki_data_raw.dataset_size)
print("column names: ", wiki_data_raw.column_names)
print("shape: ", wiki_data_raw.shape)
print("format: ", wiki_data_raw.format)
print("description: ", wiki_data_raw.description)

size:  18330235071
column names:  ['title', 'text']
shape:  (60784, 2)
format:  {'type': None, 'format_kwargs': {}, 'columns': ['title', 'text'], 'output_all_columns': False}
description:  Wikipedia dataset containing cleaned articles of all languages.
The datasets are built from the Wikipedia dump
(https://dumps.wikimedia.org/) with one split per language. Each example
contains the content of one full Wikipedia article with cleaning to strip
markdown and unwanted sections (references, etc.).



In [None]:
# Create short version dataset containing 10 documents for quick testing
wiki_data_test = wiki_data_raw[:10]
print("first 10 lines of OrderedDict: ", wiki_data_test)
print("first 10 titles: ", wiki_data_test['title'])
print("first 10 articles: ", wiki_data_test['text'])
print("first article: ", wiki_data_test['text'][0])

first 10 lines of OrderedDict:  OrderedDict([('title', ['Yangliuqing', 'Orana Australia Ltd', "St. Mary's Church, Sønderborg", 'Kalitta', 'Where Is Freedom?', 'Latin liturgical rites', 'Fernaldia pandurata', 'Chester Earl Merrow', 'Hightech Information System', 'AD 47']), ('text', ['Yangliuqing () is a market town in Xiqing District, in the western suburbs of Tianjin, People\'s Republic of China. Despite its relatively small size, it has been named since 2006 in the "famous historical and cultural market towns in China".\n\nIt is best known in China for creating nianhua or Yangliuqing nianhua. For more than 400 years, Yangliuqing has in effect specialised in the creation of these woodcuts for the New Year.  wood block prints using vivid colourschemes to portray traditional scenes of children\'s games often interwoven with auspiciouse objects.\n\n, it had 27 residential communities () and 25 villages under its administration.\n\nShi Family Grand Courtyard\n\nShi Family Grand Courtyard (

### 5.1.3 Process Pretraining Data

Some theory on preprocessing: https://mlexplained.com/2019/11/06/a-deep-dive-into-the-wonderful-world-of-preprocessing-in-nlp/

In [None]:
# A likely source of weak model performance might be the tokenization step in 
# split_sentences_by_words() because it does not strip punctuation and creates
# separate tokens for all punctuation symbols. This might not be anything the
# model can't handle, as the punctuation symbols each have their own token ID,
# so they are part of the vocab and might even be expected to appear during
# model training. It might be of use to experiment with a batch where all
# punctuation is removed, especially if the loss doesn't decrease as desired
# during training. Maybe use NLTK for punctuation removal!

class PretrainingData(Dataset):
  
  def __init__(self, tokenizer):
    # Initialize pretrained tokenizer
    self.tokenizer = tokenizer #transformers.BertTokenizerFast.from_pretrained('bert-base-uncased')
    # Download and process the English Wikipedia dump
    self.wiki_dump = wiki_data_test # dataset #load_dataset('wikipedia', '20200501.en', split='train')
    self.wiki_dump_sentences_split = self.split_documents_by_sentences(self.wiki_dump)
    self.wiki_dump_words_split = self.split_sentences_by_words(self.wiki_dump_sentences_split)
    self.max_sentence_length = self.get_max_sent_length(self.wiki_dump_words_split)


  ##############################################################################
  # Preprocessing steps
  ##############################################################################

  def split_documents_by_sentences(self, data_ordered_dict):
    """
    Splits every document into sentences

    Args:
      data_ordered_dict: A Python orderedDict containing all training documents
        where the raw text dumps are indexed by 'text'

    Returns:
      doc_split_by_sentences: A list of lists where list[_] contains a full
        document and list[_][_] contains a sentence from a given document
    """
    docs_split_by_sentences = []
    for doc in data_ordered_dict['text']:
      sentences = sent_tokenize(doc)
      docs_split_by_sentences.append(sentences)
    return docs_split_by_sentences


  def split_documents_by_sentences_regex(self, data_ordered_dict):
    """
    +++ Remains unused, just kept as a neat backup solution +++

    Splits every document into sentences but using RegEx based on the following
    rules for punctuation marks. 
      1. Preceding character is not a number
      2. Preceding character is not a capital letter
      3. Following character is a space

    Args:
      data_ordered_dict: A Python orderedDict containing all training documents
        where the raw text dumps are indexed by 'text'

    Returns:
      doc_split_by_sentences: A list of lists where list[_] contains a full
        document and list[_][_] contains a sentence from a given document
    """
    docs_split_by_sentences = [re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', doc) for doc in data_ordered_dict['text']]
    return docs_split_by_sentences


  def split_sentences_by_words(self, docsSplitBySentences):
    """
    Reads the output of split_documents_by_sentences() and splits every
      sentence (list[_][_]) into individual words. The pre-tokenization makes it
      easier to determine the max sentence length in the dataset to which the
      other sentences should be padded

    Args:
      docsSplitBySentences: List of lists
        docsSplitBySentences[_] returns a document
        docsSplitBySentences[_][_] returns a sentence from a document

    Returns:
      tokenized_wiki_data: List of list of lists
        tokenized_wiki_data[_] returns a document
        tokenized_wiki_data[_][_] returns a sentence from a document
        tokenized_wiki_data[_][_][_] returns a word from a sentence from a doc
    """
    # Constructs a list of lists where the length of the first dimension is the
    # number of documents in the batch for use in split_sentences_by_words()
    tokenized_wiki_data = [[] for _ in range(len(docsSplitBySentences))]
    # The new empty list gets zipped with our documents batch, nothing gets cut 
    # off as the new list was made to have the same length as the batch.
    # We iterate over the zipped list and for every entry (representing a document)
    # we iterate again over all sentences in that document in order to tokenize
    # each one and append the sentences (now lists of individual tokens) to the
    # respective entry (the according document).
    for doc, entry in zip(docsSplitBySentences, tokenized_wiki_data):
      for sent in doc:
        entry.append(self.tokenizer.tokenize(sent))
    return tokenized_wiki_data


  # TODO can we replace this with batch_encode_plus/encode_plus/prepare_for_model and move it into __getitem__() for more memory efficiency?
  # Also add return_tensors='pt'
  # https://huggingface.co/transformers/internal/tokenization_utils.html
  def preprocess_dataset(self, batch, max_sequence_length):
    preprocessed_batch = []
    for doc in batch:
      preprocessed_batch.append(self.tokenizer(text=doc, 
                                               padding='max_length', 
                                               max_length=max_sequence_length,
                                               return_special_tokens_mask=True))
    return preprocessed_batch


  ##############################################################################
  # Helper functions
  ##############################################################################
  
  def get_batch_size(self):
    """
    Returns the number of documents in a batch
    """
    return len(self.wiki_dump_words_split)

  def get_doc_from_batch(self):
    """
    Returns a document
    """
    return 0


  # TODO get_max_doc_length() and get_max_sent_length() can be made into one function

  def get_max_doc_length(self, tokenized_batch):
    """
    Returns the number of sentences the longest document in the batch has

    Args:
      tokenized_batch:

    Returns:
      longest_doc_len: The number of sentences of the longest document, when 
      measured by how many sentences it consists of.
    """
    longest_doc_len = 0
    for doc in tokenized_batch:
      if len(doc) >= longest_doc_len:
        longest_doc_len = len(doc)
      else:
        continue
    return longest_doc_len


  def get_max_sent_length(self, tokenized_batch):
    longest_sentence_len = 0
    for doc in tokenized_batch:
      # Iterate over every sentence; split_sentences_by_words() appends no
      # trailing marker, so slicing with doc[:-1] would skip the last sentence
      for sent in doc:
        if len(sent) >= longest_sentence_len:
          longest_sentence_len = len(sent)
        else:
          continue
    return longest_sentence_len

  ##############################################################################
  # Interface functions
  ##############################################################################

  def __len__(self):
    # References a different function with a better name for more readability
    return self.get_batch_size()

  def __getitem__(self, index):
    # TODO: get_doc_from_batch() is still a stub, so the document is indexed directly
    return self.wiki_dump_words_split[index]

### 5.1.4 Training Loop

In [None]:
doc1sent1 = "this is supposed to be the very first sentence in the dummy dataset for fast experimentation"
doc1sent2 = "we're using this as a small test case to see how the DataLoader and Dataset class work in detail"
doc1sent3 = "I'm going to add a few more sentences at the end of every document to inspect the batching a bit more"
doc1sent4 = "let's also make the docuemnts of unequal size to see if that's a problem"
doc2sent1 = "also we have two documents here to see which encoding function from huggingface works best"
doc2sent2 = "let's write a very rudimentary and fast tokenization pipeline and encode these texts into tensor outputs"
doc2sent3 = "hopefully the fact that the documents have an unequal number of sentences is not a problem"
doc3sent1 = "and what happens if we even add a third docuemnt"
doc3sent2 = "wow insanity I wonder if our DataLoader can handle such vast amounts of data"

dummy_data = [[doc1sent1, doc1sent2], [doc2sent1, doc2sent2], [doc3sent1, doc3sent2]]

dummy_data_pretokenized = [[] for _ in range(len(dummy_data))]

for doc_id, doc in enumerate(dummy_data):
  for sent in doc:
    dummy_data_pretokenized[doc_id].append(sent.split())


class DummyData(Dataset):
  def __init__(self, tokenizer):
    self.tokenizer = tokenizer #transformers.BertTokenizerFast.from_pretrained('bert-base-uncased')
    self.dataset = dummy_data_pretokenized # has the exact same format as our wiki dataset so everything that works here should work there too
    self.max_length = 30
    self.padding_strategy = transformers.tokenization_utils_base.PaddingStrategy('max_length')
  
  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, id): 
    return self.tokenizer.batch_encode_plus(self.dataset[id], # batch-encodes a whole document at [id] at once
                                            padding=self.padding_strategy, 
                                            is_split_into_words=True, 
                                            return_tensors='pt')

tokenizer = transformers.BertTokenizerFast.from_pretrained('bert-base-uncased')

# The following will happen inside train()
dummy_data_processed = DummyData(tokenizer)
dataloader = DataLoader(dummy_data_processed, batch_size=2, shuffle=False)

document_inputs_labels = []
document_attention_masks = []

for batch in dataloader:
  # A batch is a Python Dict with the keys 'input_ids', 'token_type_ids', 
  # 'attention_mask' where each key holds as its value a tensor with the 
  # respective entries for each document
  for doc_input_ids, doc_attention_mask in zip(batch['input_ids'], batch['attention_mask']):
    (inputs, labels) = mask_input_ids(doc_input_ids, tokenizer) # inputs and labels for all sentences in a doc
    document_inputs_labels.append((inputs, labels))
    document_attention_masks.append(doc_attention_mask)
    # TODO: hand inputs, labels and attention masks to the model once it is defined;
    # wrapping the zipped lists in torch.Tensor does not work

print(document_inputs_labels[0][0]) # from the first document [0] it prints all its sentences' input ids [0]
#print(document_inputs_labels[0][1][0]) # from the first document [0] it prints the labels [1] for the first sentence [0]
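
Documents with an unequal number of sentences cannot be stacked by the default collate function of `DataLoader`. A custom `collate_fn` that pads every document to the longest one in the batch is one way around this; the sketch below assumes each dataset item is already a `(num_sentences, sentence_length)` tensor of token ids, and `collate_documents` is a hypothetical name.

```python
import torch
from torch.utils.data import DataLoader

def collate_documents(batch, pad_id=0):
    # batch: list of tensors with shape (num_sentences_i, sentence_length)
    max_sentences = max(doc.shape[0] for doc in batch)
    sentence_length = batch[0].shape[1]
    input_ids = torch.full(
        (len(batch), max_sentences, sentence_length), pad_id, dtype=torch.long)
    # True where a real sentence is present, False where a document was padded
    sentence_mask = torch.zeros(len(batch), max_sentences, dtype=torch.bool)
    for i, doc in enumerate(batch):
        input_ids[i, :doc.shape[0]] = doc
        sentence_mask[i, :doc.shape[0]] = True
    return {"input_ids": input_ids, "sentence_mask": sentence_mask}

# Three toy documents with 4, 3 and 2 sentences of 6 token ids each
docs = [torch.ones(4, 6, dtype=torch.long),
        torch.ones(3, 6, dtype=torch.long),
        torch.ones(2, 6, dtype=torch.long)]
loader = DataLoader(docs, batch_size=2, shuffle=False, collate_fn=collate_documents)
batches = list(loader)
```

The `sentence_mask` can later serve as the attention mask of the document-level model, so that padded sentences are ignored.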

In [None]:
def train(sentence_configuration, document_configuration, lr):

  # Raw string, otherwise "\U" is parsed as a unicode escape and raises a SyntaxError
  file_path = r"C:\Users\Marco Moldovan\Documents\Studium\Bachelorarbeit\datasets\enwiki-20200920-pages-articles-multistream1.xml-p1p41242"

  # BertTokenizer() cannot be constructed without a vocabulary
  tokenizer = transformers.BertTokenizerFast.from_pretrained('bert-base-uncased')

  # TODO: PretrainingData does not read from file_path yet, it only takes a tokenizer
  dataset = PretrainingData(tokenizer)

  dataloader = DataLoader(dataset)

  model = HATEModelPretrain(sentence_configuration, document_configuration)

  model.train()

  criterion = nn.CrossEntropyLoss()

  optimizer = torch.optim.SGD(model.parameters(), lr=lr)

  # TODO: the dataset does not return (data, ground_truth) pairs yet
  for data, ground_truth in dataloader:

    optimizer.zero_grad()

    outputs = model(data)

    loss = criterion(outputs, ground_truth)

    loss.backward()
  
    # possibly add gradient clipping here: torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    # https://huggingface.co/transformers/training.html
  
    optimizer.step()

  """
  Training logic works as follows:
  - the dataset consists of documents which consist of sentences which consist
    of words
  - for every sentence some words are masked and for every document some sentences
    are masked
  - for every datapoint (which is a document) the model predicts some masked words
    and some masked sentences
  - these predictions are compared against the ground truth by the criterion 
    and result in some losses
  - the losses for both the masked words and masked sentences are added together
  - the combined loss gets backpropagated through the model

  Important things to figure out:
  - how to build a proper vocabulary for sentences that we can predict over --> take elements from __getitem__().
  - how to store ground truth for sentences --> need masked_sentence_loss() function
  """
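
The training logic described in the docstring (two losses, added together, one backward pass) can be tried out on a toy model before the real model exists. The sketch below is only an illustration of that control flow: `ToyTwoLevelModel`, the random features and all sizes are hypothetical and unrelated to the actual encoder.

```python
import torch
from torch import nn

class ToyTwoLevelModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.word_head = nn.Linear(8, 20)      # predicts over 20 toy "words"
        self.sentence_head = nn.Linear(8, 5)   # predicts over 5 toy "sentences"

    def forward(self, word_features, sentence_features):
        return self.word_head(word_features), self.sentence_head(sentence_features)

torch.manual_seed(0)
model = ToyTwoLevelModel()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

word_features, word_targets = torch.randn(12, 8), torch.randint(20, (12,))
sentence_features, sentence_targets = torch.randn(4, 8), torch.randint(5, (4,))

losses = []
model.train()
for step in range(20):
    optimizer.zero_grad()
    word_logits, sentence_logits = model(word_features, sentence_features)
    # Add the word-level and sentence-level losses and backpropagate the sum
    loss = criterion(word_logits, word_targets) + criterion(sentence_logits, sentence_targets)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    losses.append(loss.item())
```

Calling `optimizer.zero_grad()` at the start of every step matters here, because PyTorch accumulates gradients across `backward()` calls.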

### 5.1.5 Perform Pretraining

Possibly train here: https://www.rz.ifi.lmu.de/infos/slurm_de.html

In [None]:
# TensorBoard of loss and all other sorts of important info

### 5.1.6 Save Pretrained Model

Imagine not saving your model and losing all training progress #uff
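
A minimal sketch of saving and restoring weights through the `state_dict` of a module follows. The tiny `nn.Linear` model and the file name `hate_pretrained.pt` are placeholders for illustration; the same two calls apply to any `nn.Module`.

```python
import os
import tempfile
import torch
from torch import nn

model = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp_dir:
    checkpoint_path = os.path.join(tmp_dir, "hate_pretrained.pt")  # hypothetical name
    # Save only the parameters, not the pickled module object
    torch.save(model.state_dict(), checkpoint_path)
    # Restoring requires a module with the same architecture
    restored = nn.Linear(4, 2)
    restored.load_state_dict(torch.load(checkpoint_path))
```

For long pretraining runs it is common to also store the optimizer's `state_dict` and the current step in the same checkpoint so that training can resume.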

### 5.1.7 Inference

Write all the trained modules in such a compact way that we can just hand it a document or sentence and it will infer its representation

In [None]:
class HATEModel(nn.Module):
  pass # TODO: compact inference wrapper around the trained modules

## 5.2 Finetuning on MS MARCO Document Ranking

### 5.2.1 Details on the MS MARCO Dataset

- https://microsoft.github.io/msmarco (has list of other document reranking models)

- MS MARCO: A Human Generated MAchine Reading COmprehension Dataset https://arxiv.org/pdf/1611.09268.pdf

- https://microsoft.github.io/TREC-2020-Deep-Learning/



also relevant for finetuning:

- RepBERT: Contextualized Text Embeddings for First-Stage Retrieval https://arxiv.org/pdf/2006.15498.pdf
- TwinBERT: Distilling Knowledge to Twin-Structured BERT Models for Efficient Retrieval (1) https://arxiv.org/pdf/2002.06275.pdf (2) https://github.com/deepampatel/TwinBert/blob/master/TwinBert.ipynb

Alternative Tasks:
- https://research.google/tools/datasets/ Wikipedia and arXiv similarity triplets
- https://github.com/LiqunW/Long-document-dataset with this paper: Long Document Classification From Local Word Glimpses via Recurrent Attention Learning https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=8675939
- https://datasets.quantumstat.com/

### 5.2.2 Load and Tokenize Finetuning Dataset

### 5.2.3 Document Ranking Task

- https://github.com/microsoft/MSMARCO-Document-Ranking

- Baseline: Longformer: The Long-Document Transformer (1) https://arxiv.org/pdf/2004.05150.pdf (2) https://github.com/isekulic/longformer-marco/blob/master/src/TransformerMarco.py
- Baseline: Conformer-Kernel with Query Term Independence for Document Retrieval (1) https://arxiv.org/pdf/2007.10434.pdf 

### 5.2.4 *TODO* 
- loss function for ranking task -> cosine contrastive loss? check this TwinBERT implementation https://github.com/deepampatel/TwinBert/blob/master/TwinBert.ipynb
- what training objective is actually used in the finetuning here? is it already a ranking task? could the loss function be the same as in pretraining? 
- ranking algorithm for the document representations
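
For the last point, the simplest candidate is to sort documents by the cosine similarity between their embedding and the query embedding. The sketch below assumes both embeddings come from the same encoder and live in the same space; `rank_documents` is a hypothetical helper and the vectors are toy values.

```python
import torch
import torch.nn.functional as F

def rank_documents(query_embedding, document_embeddings):
    """
    query_embedding: (hidden,)
    document_embeddings: (num_documents, hidden)
    Returns document indices sorted from most to least similar, and the scores.
    """
    scores = F.cosine_similarity(
        query_embedding.unsqueeze(0), document_embeddings, dim=-1)
    order = torch.argsort(scores, descending=True)
    return order, scores

query = torch.tensor([1.0, 0.0])
documents = torch.tensor([[0.0, 1.0], [2.0, 0.1], [1.0, 1.0]])
order, scores = rank_documents(query, documents)
print(order.tolist())  # [1, 2, 0]
```

The same scoring function could be applied to sentence representations to score answer passage candidates.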

# **6.** Experiments

## 6.1 Document Ranking

## 6.2 Answer Passage Highlighting

# **7.** Conclusion

# **8.** Outlook

- Deeper look at attention: transformers are graph neural networks https://thegradient.pub/transformers-are-graph-neural-networks/
- Generalization to Hopfield Nets: (1) https://ml-jku.github.io/hopfield-layers/ (2) http://franksworld.com/2020/08/10/explaining-the-paper-hopfield-networks-is-all-you-need/ (3) https://analyticsindiamag.com/modern-hopfield-network-transformers-attention/ (4) https://towardsdatascience.com/hopfield-networks-are-useless-heres-why-you-should-learn-them-f0930ebeadcd
- Implications of these findings for language representation?
- Check http://nlp.seas.harvard.edu/code/ for more ideas