<a href="https://colab.research.google.com/github/marcomoldovan/hierarchical-text-encoder/blob/master/hierarchical_transformer_based_document_encoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Hierarchical Transformer-based Document Encoder**

We present a multi-purpose hierarchical document encoder (HATE) based on stacking sentence-level and document-level transformers on top of each other. In doing so, we try to capture as many semantic facets as possible of a given, arbitrarily long document. The model implicitly learns contextualized word and sentence representations as well as a robust document representation. All encodings live in the same representation space, so similarity metrics such as cosine similarity can be applied interchangeably at all levels of representation. Consequently, only one model has to be trained to encode a query, a candidate document and all of its contextualized sentences in one representation space, where similarity measures can be used both to retrieve the most relevant documents and to score candidate answer passages by their likely relevance. Because the model also produces robust sentence embeddings, it could be used for document segmentation as well: intuitively, a semantic break between paragraphs should be reflected in the representations of the adjacent sentences. One could apply an algorithm that splits unstructured text into paragraphs at points where the likelihood of a semantic break, and therefore of the start of a new paragraph, is high.
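
The retrieval and segmentation ideas above can be illustrated with a minimal sketch. It assumes that embeddings are already available as plain lists of floats; the toy vectors, the helper names (`rank_by_similarity`, `segment_boundaries`) and the fixed similarity threshold are hypothetical choices made for illustration, not part of the model described in this notebook.

```python
import math
from typing import List, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # convention chosen here for zero vectors
    return dot / (norm_a * norm_b)


def rank_by_similarity(query: Sequence[float],
                       candidates: List[Sequence[float]]) -> List[int]:
    """Indices of candidates, sorted from most to least similar to the query."""
    scores = [cosine_similarity(query, c) for c in candidates]
    return sorted(range(len(candidates)), key=lambda i: scores[i], reverse=True)


def segment_boundaries(sentence_vectors: List[Sequence[float]],
                       threshold: float = 0.5) -> List[int]:
    """Indices i at which a new paragraph is assumed to start.

    A boundary is placed before sentence i when its similarity to
    sentence i - 1 drops below the (hypothetical) threshold.
    """
    return [
        i
        for i in range(1, len(sentence_vectors))
        if cosine_similarity(sentence_vectors[i - 1], sentence_vectors[i]) < threshold
    ]


# Toy 3-dimensional "embeddings", made up for illustration only.
query = [1.0, 0.0, 0.0]
documents = [[0.0, 1.0, 0.0], [0.9, 0.1, 0.0], [0.5, 0.5, 0.0]]
ranking = rank_by_similarity(query, documents)

sentences = [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.0, 1.0, 0.0], [0.1, 0.9, 0.0]]
boundaries = segment_boundaries(sentences, threshold=0.5)

print(ranking)     # [1, 2, 0]
print(boundaries)  # [2]
```

Because queries, documents and sentences are assumed to share one space, the same `cosine_similarity` function serves both document ranking and the detection of semantic breaks between neighbouring sentences. A learned or adaptive threshold would likely work better than a fixed one.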

# **1.** Introduction



# **2.** Background

## 2.1 Statistical foundations of machine learning

(1) https://towardsdatascience.com/the-statistical-foundations-of-machine-learning-973c356a95f

More on regression, classification, etc. in a probabilistic context. Show that these problems are essentially instances of maximum a posteriori (MAP) estimation.
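
One way to make the MAP claim concrete is the following sketch. It assumes i.i.d. data $\mathcal{D} = \{(x_i, y_i)\}_{i=1}^{N}$, a model $f_\theta$, Gaussian observation noise with variance $\sigma^2$ for regression, a categorical likelihood for classification, and an isotropic Gaussian prior with variance $\tau^2$ on the parameters. These symbols and modelling assumptions are introduced here for illustration and are not taken from the reference above.

```latex
\begin{align}
\theta_{\mathrm{MAP}}
  &= \arg\max_{\theta}\; p(\theta \mid \mathcal{D})
   = \arg\max_{\theta}\; \bigl[\log p(\mathcal{D} \mid \theta) + \log p(\theta)\bigr] \\
\text{Regression: }\quad
  & y_i \sim \mathcal{N}\bigl(f_\theta(x_i), \sigma^2\bigr), \qquad
    \theta \sim \mathcal{N}(0, \tau^2 I) \\
\theta_{\mathrm{MAP}}
  &= \arg\min_{\theta}\; \sum_{i=1}^{N} \bigl(y_i - f_\theta(x_i)\bigr)^2
     + \frac{\sigma^2}{\tau^2}\,\lVert \theta \rVert_2^2 \\
\text{Classification: }\quad
  & y_i \sim \mathrm{Cat}\bigl(f_\theta(x_i)\bigr), \qquad
    \theta \sim \mathcal{N}(0, \tau^2 I) \\
\theta_{\mathrm{MAP}}
  &= \arg\min_{\theta}\; -\sum_{i=1}^{N} \log f_\theta(x_i)_{y_i}
     + \frac{1}{2\tau^2}\,\lVert \theta \rVert_2^2
\end{align}
```

Under these assumptions, least-squares regression with L2 regularization and cross-entropy classification with weight decay are both MAP estimates. With a flat prior the regularization term vanishes and MAP reduces to maximum likelihood.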

## 2.2 Mathematics of optimization for deep learning

(1) https://towardsdatascience.com/the-mathematics-of-optimization-for-deep-learning-11af2b1fda30

- Optimization: (1) visualizing loss landscape https://arxiv.org/pdf/1712.09913.pdf
- Momentum based optimizers
- Dropout (1) https://arxiv.org/pdf/1207.0580.pdf
- Batch normalization (1) https://arxiv.org/abs/1502.03167 (2) https://papers.nips.cc/paper/7515-how-does-batch-normalization-help-optimization.pdf
- Weight initialization
- Regularization (1) NLP specific: https://mlexplained.com/2018/03/02/regularization-techniques-for-natural-language-processing-with-code-examples/

## 2.3 Representation Learning

Touch briefly on the theory of representation learning, independent of language. Main focus on the Bengio paper: https://arxiv.org/pdf/1206.5538.pdf Also refer to representation learning slides from DL&AI course from last semester as an appropriate introduction.

## 2.4 Language Models

- Language modeling via auto-encoding or auto-regressive methods in general


- Embeddings in Language Models (1) https://jalammar.github.io/skipgram-recommender-talk/Text (2) https://dspace.mit.edu/handle/1721.1/118079
-  Word embeddings (1) https://ruder.io/word-embeddings-1/index.html (2) https://ruder.io/word-embeddings-softmax/index.html (3) https://ruder.io/secret-word2vec/index.html (4) https://ruder.io/word-embeddings-2017/index.html (5) https://jalammar.github.io/illustrated-word2vec/ (6) Glove https://mlexplained.com/2018/04/29/paper-dissected-glove-global-vectors-for-word-representation-explained/ (7) ELMo https://mlexplained.com/2018/06/15/paper-dissected-deep-contextualized-word-representations-explained/ (8) https://p.migdal.pl/2017/01/06/king-man-woman-queen-why.html (9) https://lilianweng.github.io/lil-log/2017/10/15/learning-word-embedding.html#loss-functions
- Sentence embeddings (1) https://supernlp.github.io/2018/11/26/sentreps/ (2) https://mlexplained.com/2017/12/28/an-overview-of-sentence-embedding-methods/ (3) https://medium.com/huggingface/universal-word-sentence-embeddings-ce48ddc8fc3a
- Document Embeddings (1) https://towardsdatascience.com/document-embedding-techniques-fed3e7a6a25d#1409 (2) https://graphaware.com/nlp/2018/09/03/advanced-document-representation.html

- General early language models that are based on DL: RNNs/LSTM (touch shortly) (1) https://arxiv.org/pdf/1312.6026.pdf (2) https://distill.pub/2019/memorization-in-rnns/
- On the difficulty of training recurrent neural networks https://arxiv.org/pdf/1211.5063.pdf
- Transition to attention mechanism, at first in RNNs
- Why these large DL-based models are so important for transfer learning in NLP (1) https://ruder.io/transfer-learning/ (2) https://thegradient.pub/nlp-imagenet/ (3) very linguistic study of word embeddings for transfer tasks https://arxiv.org/pdf/1903.08855.pdf

# **3.** Related Work

## 3.1 Attention and why it's all you need

- Attention is all you need (1) https://mlexplained.com/2017/12/29/attention-is-all-you-need-explained/ (2) https://arxiv.org/pdf/1706.03762.pdf (3) https://towardsdatascience.com/how-to-code-the-transformer-in-pytorch-24db27c8f9ec (4) https://blog.floydhub.com/attention-mechanism/ (5) https://jalammar.github.io/illustrated-transformer/ (6) https://distill.pub/2016/augmented-rnns/ (7) https://lilianweng.github.io/lil-log/2018/06/24/attention-attention.html
- BERT (1) https://mlexplained.com/2019/01/07/paper-dissected-bert-pre-training-of-deep-bidirectional-transformers-for-language-understanding-explained/ (2) https://arxiv.org/pdf/1810.04805.pdf (3) https://jalammar.github.io/a-visual-guide-to-using-bert-for-the-first-time/ (4) https://jalammar.github.io/illustrated-bert/
- What Does BERT Look At? An Analysis of BERT’s Attention (1) https://arxiv.org/pdf/1906.04341.pdf (2) https://arxiv.org/abs/1909.10430v2
- BERT raw embeddings https://arxiv.org/abs/1909.00512v1 https://towardsdatascience.com/examining-berts-raw-embeddings-fd905cb22df7
- Sentence embeddings revisited: BERT methods -> transition to intuition for my document representation model
- Call-back to section 2.4 where I touch on classical embeddings and compare to deep, high-parameter, attention based models such as BERT
- Why these large pretrained models are so important for transfer learning in NLP (1) https://ruder.io/transfer-learning/ (2) https://thegradient.pub/nlp-imagenet/


## 3.2 Hierarchical Attention-Models

- SMITH (1) https://arxiv.org/pdf/2004.12297v1.pdf (2) https://github.com/dmolony3/SMITH
- HIBERT (1) https://arxiv.org/pdf/1905.06566.pdf (2) https://github.com/liangsi03/hibert_model

## 3.3 Information Retrieval

- TANDA: Transfer and Adapt Pre-Trained Transformer Models for Answer Sentence Selection https://arxiv.org/pdf/1911.04118.pdf

# **4.** The Model

Install the `transformers`, `datasets` and `tokenizers` packages from Hugging Face and collect all necessary top-level imports.

In [1]:
!pip install transformers #installs transformer module from huggingface
!pip install datasets #installs dataset module from huggingface
!pip install tokenizers #installs tokenizer module from huggingface

import re
import math

import nltk
from nltk import sent_tokenize
nltk.download('punkt')

import torch
import torch.utils.checkpoint  # used for gradient checkpointing in EncoderStack
import torch.nn as nn
import torch.nn.functional as f
import torchtext.datasets
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import transformers
from transformers import BertConfig, BertModel, BertForMaskedLM
from transformers import configuration_utils
from transformers.activations import ACT2FN 
from transformers.file_utils import ModelOutput
from transformers.modeling_utils import PreTrainedModel, apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from transformers.configuration_utils import PretrainedConfig


from pprint import pprint
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


## 4.1 Transformer

#### Internal Embedding Lookup

In [2]:
class EmbeddingsLookup(nn.Module):
  def __init__(self, config):
    super().__init__()
    # Initialize the lookup matrix for input IDs, positional embeddings and token types
    self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
    self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
    self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
    
    # Layer normalization and dropout applied to the summed input embeddings
    self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
    self.dropout = nn.Dropout(config.hidden_dropout_prob)

    # position_ids (1, len position emb) is contiguous in memory and exported when serialized
    self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
    self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

  def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
    if input_ids is not None:
      input_shape = input_ids.size()
    else:
      input_shape = inputs_embeds.size()[:-1]

    seq_length = input_shape[1]

    if position_ids is None:
      position_ids = self.position_ids[:, :seq_length]

    if token_type_ids is None:
      token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

    if inputs_embeds is None:
      inputs_embeds = self.word_embeddings(input_ids)
    token_type_embeddings = self.token_type_embeddings(token_type_ids)

    embeddings = inputs_embeds + token_type_embeddings
    if self.position_embedding_type == "absolute":
      position_embeddings = self.position_embeddings(position_ids)
      embeddings += position_embeddings
    embeddings = self.LayerNorm(embeddings)
    embeddings = self.dropout(embeddings)
    return embeddings

#### Encoder Stack

In [3]:
class BaseModelOutputWithCrossAttentions(ModelOutput):
    """
    Base class for model's outputs, with potential hidden states and attentions.
    Args:
        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` and ``config.add_cross_attention=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.
            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
    """

    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None

In [4]:
class SelfAttention(nn.Module):
  def __init__(self, config):
    super().__init__()
    if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
      raise ValueError(
        "The hidden size (%d) is not a multiple of the number of attention "
        "heads (%d)" % (config.hidden_size, config.num_attention_heads)
      )

    self.num_attention_heads = config.num_attention_heads
    self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
    self.all_head_size = self.num_attention_heads * self.attention_head_size

    self.query = nn.Linear(config.hidden_size, self.all_head_size)
    self.key = nn.Linear(config.hidden_size, self.all_head_size)
    self.value = nn.Linear(config.hidden_size, self.all_head_size)

    self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
    self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
    if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
      self.max_position_embeddings = config.max_position_embeddings
      self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

  def transpose_for_scores(self, x):
    new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
    x = x.view(*new_x_shape)
    return x.permute(0, 2, 1, 3)

  def forward(self,
              hidden_states,
              attention_mask=None,
              head_mask=None,
              encoder_hidden_states=None,
              encoder_attention_mask=None,
              output_attentions=False):
    
    mixed_query_layer = self.query(hidden_states)

    # If this is instantiated as a cross-attention module, the keys
    # and values come from an encoder; the attention mask needs to be
    # such that the encoder's padding tokens are not attended to.
    if encoder_hidden_states is not None:
        mixed_key_layer = self.key(encoder_hidden_states)
        mixed_value_layer = self.value(encoder_hidden_states)
        attention_mask = encoder_attention_mask
    else:
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

    query_layer = self.transpose_for_scores(mixed_query_layer)    
    key_layer = self.transpose_for_scores(mixed_key_layer)
    value_layer = self.transpose_for_scores(mixed_value_layer)

    # Take the dot product between "query" and "key" to get the raw attention scores.
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

    if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
      seq_length = hidden_states.size()[1]
      position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
      position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
      distance = position_ids_l - position_ids_r
      positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
      positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

      if self.position_embedding_type == "relative_key":
        relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
        attention_scores = attention_scores + relative_position_scores
      elif self.position_embedding_type == "relative_key_query":
        relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
        relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
        attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

    attention_scores = attention_scores / math.sqrt(self.attention_head_size)
    if attention_mask is not None:
      # Apply the additive attention mask (precomputed once for all layers in the model's forward() function)
      attention_scores = attention_scores + attention_mask

    # Normalize the attention scores to probabilities.
    attention_probs = nn.Softmax(dim=-1)(attention_scores)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = self.dropout(attention_probs)

    # Mask heads if we want to
    if head_mask is not None:
      attention_probs = attention_probs * head_mask

    context_layer = torch.matmul(attention_probs, value_layer)

    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
    context_layer = context_layer.view(*new_context_layer_shape)

    outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
    return outputs
  

class SelfOutput(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.dense = nn.Linear(config.hidden_size, config.hidden_size)
    self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
    self.dropout = nn.Dropout(config.hidden_dropout_prob)

  def forward(self, hidden_states, input_tensor):
    hidden_states = self.dense(hidden_states)
    hidden_states = self.dropout(hidden_states)
    hidden_states = self.LayerNorm(hidden_states + input_tensor)
    return hidden_states


class AttentionModule(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.self = SelfAttention(config)
    self.output = SelfOutput(config)
    self.pruned_heads = set()

  def prune_head(self, heads):
    if len(heads) == 0:
      return
    heads, index = find_pruneable_heads_and_indices(                            # needs the pruning helpers from transformers.modeling_utils
      heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
    )
    # Prune linear layers
    self.self.query = prune_linear_layer(self.self.query, index)
    self.self.key = prune_linear_layer(self.self.key, index)
    self.self.value = prune_linear_layer(self.self.value, index)
    self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

    # Update hyper params and store pruned heads
    self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
    self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
    self.pruned_heads = self.pruned_heads.union(heads)

  def forward(self,
              hidden_states,
              attention_mask=None,
              head_mask=None,
              encoder_hidden_states=None,
              encoder_attention_mask=None,
              output_attentions=False):
    self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            output_attentions)
    attention_output = self.output(self_outputs[0], hidden_states)
    outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
    return outputs





class FeedForward(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
    if isinstance(config.hidden_act, str):
      self.intermediate_act_fn = ACT2FN[config.hidden_act]
    else:
      self.intermediate_act_fn = config.hidden_act

  def forward(self, hidden_states):
    hidden_states = self.dense(hidden_states)
    hidden_states = self.intermediate_act_fn(hidden_states)
    return hidden_states


class Output(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
    self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
    self.dropout = nn.Dropout(config.hidden_dropout_prob)

  def forward(self, hidden_states, input_tensor):
    hidden_states = self.dense(hidden_states)
    hidden_states = self.dropout(hidden_states)
    hidden_states = self.LayerNorm(hidden_states + input_tensor)
    return hidden_states





class EncoderLayer(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.chunk_size_feed_forward = config.chunk_size_feed_forward
    self.seq_len_dim = 1
    self.attention = AttentionModule(config)
    self.is_decoder = config.is_decoder
    self.add_cross_attention = config.add_cross_attention
    if self.add_cross_attention:
      assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
      self.crossattention = AttentionModule(config)
    self.intermediate = FeedForward(config)
    self.output = Output(config)

  def forward(self,
              hidden_states,
              attention_mask=None,
              head_mask=None,
              encoder_hidden_states=None,
              encoder_attention_mask=None,
              output_attentions=False):
    
    self_attention_outputs = self.attention(hidden_states,
                                            attention_mask,
                                            head_mask,
                                            output_attentions=output_attentions)
    attention_output = self_attention_outputs[0]
    outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
    if self.is_decoder and encoder_hidden_states is not None:
      assert hasattr(self, "crossattention"), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
      cross_attention_outputs = self.crossattention(attention_output,
                                                    attention_mask,
                                                    head_mask,
                                                    encoder_hidden_states,
                                                    encoder_attention_mask,
                                                    output_attentions)
      attention_output = cross_attention_outputs[0]
      outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights

    layer_output = apply_chunking_to_forward(self.feed_forward_chunk,           # feed-forward block, optionally chunked along seq_len_dim to save memory
                                             self.chunk_size_feed_forward, 
                                             self.seq_len_dim, attention_output)
      
    outputs = (layer_output,) + outputs
    return outputs

  def feed_forward_chunk(self, attention_output):
    intermediate_output = self.intermediate(attention_output)
    layer_output = self.output(intermediate_output, attention_output)
    return layer_output




class EncoderStack(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.config = config
    self.layer = nn.ModuleList([EncoderLayer(config) for _ in range(config.num_hidden_layers)])

  def forward(self,
              hidden_states,
              attention_mask=None,
              head_mask=None,
              encoder_hidden_states=None,
              encoder_attention_mask=None,
              output_attentions=False,
              output_hidden_states=False,
              return_dict=True):
    all_hidden_states = () if output_hidden_states else None
    all_self_attentions = () if output_attentions else None
    all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
    for i, layer_module in enumerate(self.layer):
      if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)
      
      layer_head_mask = head_mask[i] if head_mask is not None else None

      if getattr(self.config, "gradient_checkpointing", False):
        def create_custom_forward(module):
          def custom_forward(*inputs):
            return module(*inputs, output_attentions)
          return custom_forward
        layer_outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(layer_module),
                                                          hidden_states,
                                                          attention_mask,
                                                          layer_head_mask,
                                                          encoder_hidden_states,
                                                          encoder_attention_mask,)
      else:
        layer_outputs = layer_module(hidden_states,
                                     attention_mask,
                                     layer_head_mask,
                                     encoder_hidden_states,
                                     encoder_attention_mask,
                                     output_attentions,)
      
      hidden_states = layer_outputs[0]
      if output_attentions:
        all_self_attentions = all_self_attentions + (layer_outputs[1],)
        if self.config.add_cross_attention:
          all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

    if output_hidden_states:
      all_hidden_states = all_hidden_states + (hidden_states,)

    if not return_dict:
      return tuple(v
                   for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
                   if v is not None)
    
    return BaseModelOutputWithCrossAttentions(last_hidden_state=hidden_states,
                                              hidden_states=all_hidden_states,
                                              attentions=all_self_attentions,
                                              cross_attentions=all_cross_attentions,)
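
The core computation in `SelfAttention.forward` above (scaled dot product, additive mask, softmax, weighted sum of values) can be sketched for a single head and a single sequence in dependency-free Python. This is a simplified illustration, not the implementation used in this notebook: it omits the learned query/key/value projections, multiple heads, dropout, head masking and relative position scores. The mask value `-10000.0` is an assumption standing in for "a large negative number".

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax over one row of scores."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(queries: Matrix, keys: Matrix, values: Matrix,
                                 additive_mask: List[float]) -> Matrix:
    """Single-head attention for one sequence.

    additive_mask holds one entry per key position: 0.0 keeps the position,
    a large negative number (assumed here: -10000.0) hides it.
    """
    head_size = len(queries[0])
    value_size = len(values[0])
    context = []
    for q in queries:
        # Raw score = dot(query, key) / sqrt(head size), plus the mask entry.
        scores = [
            sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(head_size) + m
            for k, m in zip(keys, additive_mask)
        ]
        probs = softmax(scores)
        # Each output row is the probability-weighted average of the values.
        context.append([
            sum(p * v[d] for p, v in zip(probs, values))
            for d in range(value_size)
        ])
    return context


# Toy example: three positions, the last one treated as padding.
queries = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]
mask = [0.0, 0.0, -10000.0]

context = scaled_dot_product_attention(queries, keys, values, mask)
print(context[2])  # [0.5, 0.5]: the padded value [5.0, 5.0] is ignored
```

Adding the mask before the softmax, rather than zeroing probabilities afterwards, keeps each row of attention probabilities normalized over the visible positions only.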

In [5]:
class Pooler(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.dense = nn.Linear(config.hidden_size, config.hidden_size)
    self.activation = nn.Tanh()

  def forward(self, hidden_states):
    # We "pool" the model by simply taking the hidden state corresponding
    # to the first token.
    first_token_tensor = hidden_states[:, 0]
    pooled_output = self.dense(first_token_tensor)
    pooled_output = self.activation(pooled_output)
    return pooled_output

#### Transformer

In [6]:
class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
  """
  Base class for model's outputs that also contains a pooling of the last hidden states.
  Args:
    last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
      Sequence of hidden-states at the output of the last layer of the model.
    pooler_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`):
      Last layer hidden-state of the first token of the sequence (classification token) further processed by a
      Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
      prediction (classification) objective during pretraining.
    hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
      Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
      of shape :obj:`(batch_size, sequence_length, hidden_size)`.
      Hidden-states of the model at the output of each layer plus the initial embedding outputs.
    attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
      Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
      sequence_length, sequence_length)`.
      Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
      heads.
    cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` and ``config.add_cross_attention=True`` is passed or when ``config.output_attentions=True``):
      Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
      sequence_length, sequence_length)`.
      Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
      weighted average in the cross-attention heads.
  """
  last_hidden_state: torch.FloatTensor = None
  pooler_output: torch.FloatTensor = None
  hidden_states: Optional[Tuple[torch.FloatTensor]] = None
  attentions: Optional[Tuple[torch.FloatTensor]] = None
  cross_attentions: Optional[Tuple[torch.FloatTensor]] = None

In [7]:
class TransformerPreTrainedModel(PreTrainedModel):
  """
  An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  models.
  """
  _keys_to_ignore_on_load_missing = [r"position_ids"]

  def _init_weights(self, module):
    """ Initialize the weights """
    if isinstance(module, (nn.Linear, nn.Embedding)):
      # Slightly different from the TF version which uses truncated_normal for initialization
      # cf https://github.com/pytorch/pytorch/pull/5617
      module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
    elif isinstance(module, nn.LayerNorm):
      module.bias.data.zero_()
      module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

In [8]:
class TransformerConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a TransformerModel. 
    It is used to instantiate a BERT model according to the specified arguments,
    defining the model architecture. Instantiating a configuration with the defaults 
    will yield a similar configuration to that of the BERT `bert-base-uncased 
    <https://huggingface.co/bert-base-uncased>`__ architecture. Configuration objects 
    inherit from :class:`~transformers.PretrainedConfig` and can be used to control 
    the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` 
    for more information.
    Args:
        vocab_size (:obj:`int`, `optional`, defaults to 30522):
            Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
            :obj:`inputs_ids` passed when calling :class:`~transformers.BertModel` or
            :class:`~transformers.TFBertModel`.
        hidden_size (:obj:`int`, `optional`, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (:obj:`int`, `optional`, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (:obj:`int`, `optional`, defaults to 3072):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string,
            :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
        hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        type_vocab_size (:obj:`int`, `optional`, defaults to 2):
            The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.BertModel` or
            :class:`~transformers.TFBertModel`.
        initializer_range (:obj:`float`, `optional`, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
        position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
            Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
            :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
            :obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.)
            <https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
            `Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
            <https://arxiv.org/abs/2009.13658>`__.
    Examples::
        >>> from transformers import BertModel, BertConfig
        >>> # Initializing a BERT bert-base-uncased style configuration
        >>> configuration = BertConfig()
        >>> # Initializing a model from the bert-base-uncased style configuration
        >>> model = BertModel(configuration)
        >>> # Accessing the model configuration
        >>> configuration = model.config
    """
    model_type = "bert"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        gradient_checkpointing=False,
        position_embedding_type="absolute",
        **kwargs
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.gradient_checkpointing = gradient_checkpointing
        self.position_embedding_type = position_embedding_type

In [9]:
class TransformerBase(TransformerPreTrainedModel):
  """
  The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
  cross-attention is added between the self-attention layers, following the architecture described in `Attention is
  all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
  Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
  To behave as a decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration
  set to :obj:`True`. To be used in a Seq2Seq model, the model needs to be initialized with both the :obj:`is_decoder`
  argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
  input to the forward pass.
  """
  def __init__(self, config, add_pooling_layer=True):
    super().__init__(config)
    self.config = config

    self.embeddings = EmbeddingsLookup(config)
    self.encoder = EncoderStack(config)

    self.pooler = Pooler(config) if add_pooling_layer else None

    #self.init_weights() # Don't forget this

  def get_input_embeddings(self):
    return self.embeddings.word_embeddings

  def set_input_embeddings(self, value):
    self.embeddings.word_embeddings = value

  def _prune_heads(self, heads_to_prune):
    """
    Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
    class PreTrainedModel
    """
    for layer, heads in heads_to_prune.items():
      self.encoder.layer[layer].attention.prune_heads(heads)

  def forward(self,
              input_ids=None,
              attention_mask=None,
              token_type_ids=None,
              position_ids=None,
              head_mask=None,
              inputs_embeds=None,
              encoder_hidden_states=None,
              encoder_attention_mask=None,
              output_attentions=None,
              output_hidden_states=None,
              return_dict=None,):
    """
    encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
      Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
      the model is configured as a decoder.
    encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
      Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
      the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
      - 1 for tokens that are **not masked**,
      - 0 for tokens that are **masked**.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if input_ids is not None and inputs_embeds is not None:
      raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
      input_shape = input_ids.size()
    elif inputs_embeds is not None:
      input_shape = inputs_embeds.size()[:-1]
    else:
      raise ValueError("You have to specify either input_ids or inputs_embeds")

    device = input_ids.device if input_ids is not None else inputs_embeds.device

    if attention_mask is None:
      attention_mask = torch.ones(input_shape, device=device)
    if token_type_ids is None:
      token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

    # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
    # ourselves in which case we just need to make it broadcastable to all heads.
    extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) # can we throw this away?

    # If a 2D or 3D attention mask is provided for the cross-attention
    # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
    if self.config.is_decoder and encoder_hidden_states is not None:
      encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
      encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
      if encoder_attention_mask is None:
        encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
      encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
    else:
      encoder_extended_attention_mask = None

    # Prepare head mask if needed
    # 1.0 in head_mask indicate we keep the head
    # attention_probs has shape bsz x n_heads x N x N
    # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
    # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
    head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

    embedding_output = self.embeddings(input_ids=input_ids,
                                       position_ids=position_ids,
                                       token_type_ids=token_type_ids,
                                       inputs_embeds=inputs_embeds)
    encoder_outputs = self.encoder(embedding_output,
                                   attention_mask=extended_attention_mask,
                                   head_mask=head_mask,
                                   encoder_hidden_states=encoder_hidden_states,
                                   encoder_attention_mask=encoder_extended_attention_mask,
                                   output_attentions=output_attentions,
                                   output_hidden_states=output_hidden_states,
                                   return_dict=return_dict,)
    sequence_output = encoder_outputs[0]
    pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

    if not return_dict:
      return (sequence_output, pooled_output) + encoder_outputs[1:]

    return BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=sequence_output,
                                                        pooler_output=pooled_output,
                                                        hidden_states=encoder_outputs.hidden_states,
                                                        attentions=encoder_outputs.attentions,
                                                        cross_attentions=encoder_outputs.cross_attentions,)
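
The attention-mask handling in `forward` can be illustrated without any framework. The sketch below assumes the additive-mask convention (0 for positions to keep, a large negative number for positions to ignore) in which the mask is added to the raw attention scores before the softmax. The helper names `to_additive_mask` and `masked_softmax` and the value `-10000.0` are assumptions made for illustration and are not part of the model code.

```python
import math

def to_additive_mask(attention_mask, masked_value=-10000.0):
    """Turn a 0/1 padding mask into an additive mask (assumed convention)."""
    # 1 -> 0.0 (keep the position), 0 -> large negative number (ignore it)
    return [(1.0 - m) * masked_value for m in attention_mask]

def masked_softmax(scores, attention_mask):
    """Softmax over one row of attention scores after adding the mask."""
    additive = to_additive_mask(attention_mask)
    shifted = [s + a for s, a in zip(scores, additive)]
    # Subtract the maximum for numerical stability
    top = max(shifted)
    exps = [math.exp(x - top) for x in shifted]
    total = sum(exps)
    return [e / total for e in exps]

# Three real tokens followed by two padding tokens
scores = [2.0, 1.0, 0.5, 3.0, 3.0]
mask = [1, 1, 1, 0, 0]
probs = masked_softmax(scores, mask)
```

Even though the padding positions carry the highest raw scores, they receive no attention weight once the mask has been added, and the remaining weights still sum to one.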

## 4.2 Language Modeling

### 4.2.1 Input Masking Algorithms

#### Token Masking

In [10]:
# word_mask_probability = 0.15
# replace_with_mask_probability = 0.8
# replace_randomly_probability = 0.1
# keep_token_probability = 0.1

def mask_input_ids(inputs: torch.Tensor,
                   tokenizer: transformers.BertTokenizerFast,
                   special_tokens_mask: Optional[torch.Tensor] = None,
                   word_mask_probability = 0.15,
                   replace_with_mask_probability = 0.8,
                   replace_randomly_probability = 0.1,
                   keep_token_probability = 0.1
                   ) -> Tuple[torch.Tensor, torch.Tensor]:
  """
  We specify the probability with which to mask tokens for the language modeling
  task. Generally 15% of tokens are considered for masking. If we just mask
  naively then a problem arises: the [MASK] token is never actually seen at
  fine-tuning. The solution is to not replace the token with [MASK] 100%
  of the time, instead:
  - 80% of the time, replace the token with [MASK]
    went to the store -> went to the [MASK]
  - 10% of the time, replace it with a random token
    went to the store -> went to the running
  - 10% of the time, keep the token unchanged
    went to the store -> went to the store
  The same principle is also applicable to masked sentence prediction, only
  that we have to establish a sentence vocabulary beforehand.

  Args:
    inputs: tensor, containing all the token IDs. Note that it is modified in
      place.
    tokenizer: tokenizer that provides the [MASK] token ID, the vocabulary
      size and, if needed, the special tokens mask
    special_tokens_mask: tensor, denotes whether a token is a word [0] or a 
      special token [1], [CLS] tokens and padding tokens are all counted as 
      special tokens. This will be used to create a mask so that only actual
      words are considered for random masking

  Returns:
    masked_inputs: tensor of the same shape as inputs holding the token IDs
      after masking
    labels: tensor of the same shape as inputs holding the original token IDs
      at the positions selected for masking and -100 everywhere else
  """
  labels = inputs.clone()
  # Tensor that holds the probability values for the Bernoulli draw
  probability_matrix = torch.full(inputs.shape, word_mask_probability)

  # Get special token indices in order to exclude special tokens from masking
  if special_tokens_mask is None:
    special_tokens_mask = [
      tokenizer.get_special_tokens_mask(entry, already_has_special_tokens=True) for entry in labels.tolist()
    ]
    special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
  else:
    special_tokens_mask = special_tokens_mask.bool()

  # Fill the probability matrix with 0.0 values where there are special tokens
  probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
  # Draws from a Bernoulli distribution where probability_matrix holds the
  # probabilities for drawing the binary random number. The probability matrix
  # was previously filled with 0.0 values where special tokens are present so
  # that only tokens containing words/sentences are considered
  masked_indices = torch.bernoulli(probability_matrix).bool()
  # In order to compute the loss only on the masked indices all the unmasked
  # tokens in the label tensor are set to -100
  labels[~masked_indices] = -100

  # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
  indices_replaced = torch.bernoulli(torch.full(labels.shape, replace_with_mask_probability)).bool() & masked_indices
  # Since we're dealing with tensors with numerical values we convert the [MASK]
  # token right back to its token_id representation
  inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

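  # 10% of the time, we replace masked input tokens with a random word. The draw below uses 0.5
  # because it only applies to the 20% of masked tokens that were not replaced with [MASK]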
  indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
  random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
  inputs[indices_random] = random_words[indices_random]

  return (inputs, labels)
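
The second Bernoulli draw in `mask_input_ids` is conditional on a token not having been replaced with [MASK], so its probability is not 10% itself but 10% divided by the remaining 20%. The sketch below works through that arithmetic in plain Python and checks it with a small simulation. The helpers `split_probabilities` and `simulate` are hypothetical and only illustrate the reasoning; they are not part of the masking code.

```python
import random

def split_probabilities(replace_with_mask=0.8, second_draw=0.5):
    """Overall shares of [MASK], random and unchanged among selected tokens."""
    p_mask = replace_with_mask
    # The second draw only applies to tokens that were not replaced by [MASK]
    p_random = (1.0 - replace_with_mask) * second_draw
    p_keep = 1.0 - p_mask - p_random
    return p_mask, p_random, p_keep

def simulate(n_tokens=100_000, replace_with_mask=0.8, second_draw=0.5, seed=0):
    """Empirical check of the split using two independent draws per token."""
    rng = random.Random(seed)
    counts = {"mask": 0, "random": 0, "keep": 0}
    for _ in range(n_tokens):
        replaced = rng.random() < replace_with_mask
        randomized = rng.random() < second_draw and not replaced
        if replaced:
            counts["mask"] += 1
        elif randomized:
            counts["random"] += 1
        else:
            counts["keep"] += 1
    return {k: v / n_tokens for k, v in counts.items()}

p_mask, p_random, p_keep = split_probabilities()
empirical = simulate()
```

With the default values both the closed-form and the simulated shares come out close to 80/10/10. Changing the first probability alone shifts the random and unchanged shares unless the second draw is adjusted to match.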

#### Embedding Masking

In [11]:
def mask_input_embeddings(input_embeddings: torch.Tensor,
                          special_embeddings_mask: torch.Tensor,
                          sentence_mask_probability = 0.15):
  """
  Randomly masks sentences with a probability of 15%. The masked sentence
  embeddings are replaced with a random tensor and the original embedding will
  be stored in a labels tensor that has the same size as the input tensor. The
  ground truth embedding will sit at the same position as it did in the input
  tensor to make it easier to identify the correct ground truth for loss
  computing.

  Args:
    input_embeddings: A torch.Tensor containing all sentence embeddings computed
      by the Sentence Model for a given batch. The size of the tensor is
      [batch_size, max_doc_length, embedding_size]. Note that the documents are
      already padded to the length of the longest document in the batch.
    special_embeddings_mask: A torch.Tensor of size
      [batch_size, max_doc_length] which holds 0s where there is a real sentence
      present and 1s where there is a special token embedding; that includes
      CLS, SEP and PAD tokens.
    sentence_mask_probability: Probability with which each real sentence is
      selected for masking.
  Returns:
    input_embeddings: The unmodified input tensor.
    masked_input_embeddings: Same shape as input embeddings, only that it holds
      a random tensor wherever a sentence embedding was masked.
    label_embeddings: Same shape as the masked_input_embeddings but all entries
      are filled with 0s except where there is a masked sentence embedding. That
      entry will be filled with the original input embedding.
    label_mask: torch.BoolTensor of size [batch_size, max_doc_length] that is
      True at the masked positions.
  """
  masked_input_embeddings = input_embeddings.clone()

  # Probability of masking each position; special embeddings are never masked
  probability_matrix = torch.full(special_embeddings_mask.shape,
                                  sentence_mask_probability,
                                  device=special_embeddings_mask.device)
  probability_matrix.masked_fill_(special_embeddings_mask.bool(), value=0.0)

  # True wherever a sentence embedding was selected for masking
  masked_indices = torch.bernoulli(probability_matrix).bool()

  # Keep the original embedding at masked positions only, zeros everywhere else
  label_embeddings = torch.zeros_like(input_embeddings)
  label_embeddings[masked_indices] = input_embeddings[masked_indices]

  # Replace the masked sentence embeddings with random noise
  masked_input_embeddings[masked_indices] = torch.randn_like(input_embeddings[masked_indices])

  # The label mask is exactly the boolean mask of the selected positions
  label_mask = masked_indices

  return (input_embeddings, masked_input_embeddings, label_embeddings, label_mask)

- Algorithm for masking sentences
- Extracting the masked token embedding from the DocumentEncoder
- Compare vs. ground truth and define a loss function on it (use the Hugging Face optimizer? e.g. `from transformers import AdamW`, `optimizer = AdamW(...)`, see https://huggingface.co/transformers/training.html)
- SMITH adds the losses from the sentence encoder and the document encoder; make both trainable simultaneously?

### 4.2.2 Language Modeling Head

Define the LM Head(s) and its loss function

#### Word Level Language Modeling Head

In [12]:
class PredictionHeadTransform(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.dense = nn.Linear(config.hidden_size, config.hidden_size)
    if isinstance(config.hidden_act, str):
      self.transform_act_fn = ACT2FN[config.hidden_act]
    else:
      self.transform_act_fn = config.hidden_act
    self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

  def forward(self, hidden_states):
    hidden_states = self.dense(hidden_states)
    hidden_states = self.transform_act_fn(hidden_states)
    hidden_states = self.LayerNorm(hidden_states)
    return hidden_states



class LMPredictionHead(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.transform = PredictionHeadTransform(config)

    # The output weights are the same as the input embeddings, but there is
    # an output-only bias for each token.
    self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
    self.decoder.bias = self.bias

  def forward(self, hidden_states):
    hidden_states = self.transform(hidden_states)
    hidden_states = self.decoder(hidden_states)
    return hidden_states



class OnlyLMHead(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.predictions = LMPredictionHead(config)

  def forward(self, sequence_output):
    prediction_scores = self.predictions(sequence_output)
    return prediction_scores

#### Sentence Level Language Modeling Head

In [14]:
class SentencePredictionHead(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.dense = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size)
    # Same normalization as in the word-level LM head (PredictionHeadTransform)
    self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

  def forward(self, masked_sentence_prediction, label_embeddings, label_mask):
    """
    In order to compute the sentence-level prediction loss we apply a similar
    loss function as during the word-level masked word prediction task. Since
    we don't have a fixed-size vocabulary over the training sentences we have
    to build a dynamic sentence vocabulary from the masked sentences of the
    current batch.
    Args:
      masked_sentence_prediction [batch_size, max_doc_length, hidden_size]:
        Output of the document-level transformer.
      label_embeddings [batch_size, max_doc_length, hidden_size]:
        Ground truth sentence embeddings.
      label_mask [batch_size, max_doc_length]:
        True at the positions of masked sentences.
    Returns:
      logits [num_masked, num_masked]:
      per_batch_sentence_loss: scalar, mean over all masked sentences
      per_example_sentence_loss [num_masked]:
    """
    label_mask = label_mask.bool()

    # Keep only the embeddings at masked positions.
    # Both tensors will have size [num_masked, hidden_size]
    output_embeddings = masked_sentence_prediction[label_mask]
    label_embeddings = label_embeddings[label_mask]

    output_embeddings = self.dense(output_embeddings)
    output_embeddings = self.LayerNorm(output_embeddings)

    # TODO add bias like in SMITH?

    # Score every prediction against every masked sentence in the batch
    logits = torch.matmul(output_embeddings, torch.transpose(input=label_embeddings, dim0=0, dim1=1))
    log_probabilities = nn.functional.log_softmax(logits, dim=1)
    # The correct "class" of prediction i is label i, i.e. the diagonal
    labels_one_hot = torch.eye(log_probabilities.size(0), device=log_probabilities.device)

    # Assumed completion of the unfinished loss: cross-entropy against the
    # one-hot labels, averaged over all masked sentences in the batch
    per_example_sentence_loss = -torch.sum(labels_one_hot * log_probabilities, dim=1)
    per_batch_sentence_loss = torch.mean(per_example_sentence_loss)

    return (logits, per_batch_sentence_loss, per_example_sentence_loss)
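
The "dynamic sentence vocabulary" idea from the docstring above can be worked through on a tiny example. The sketch below, written in plain Python to stay dependency-free, treats the masked sentences of a batch as the candidate set: prediction i is scored against every label embedding by dot product, and the loss is the negative log-probability of the matching label i. The function name `in_batch_sentence_loss` and the toy vectors are hypothetical; this illustrates the general in-batch softmax technique and is not necessarily the exact loss intended here.

```python
import math

def in_batch_sentence_loss(predictions, labels):
    """Cross-entropy where the i-th label is the correct class for the i-th prediction."""
    per_example = []
    for i, pred in enumerate(predictions):
        # Dot-product score of this prediction against every candidate label
        logits = [sum(p * l for p, l in zip(pred, label)) for label in labels]
        # Numerically stable log-sum-exp over the candidates
        top = max(logits)
        log_norm = top + math.log(sum(math.exp(x - top) for x in logits))
        per_example.append(-(logits[i] - log_norm))
    per_batch = sum(per_example) / len(per_example)
    return per_batch, per_example

# Toy label embeddings of three masked sentences (illustrative values)
labels = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
# Predictions that point in the direction of their own label
good = [[4.0, 0.0], [0.0, 4.0], [-4.0, 0.0]]
# Predictions that point in the direction of a different label
bad = [[0.0, 4.0], [-4.0, 0.0], [4.0, 0.0]]

good_loss, good_per_example = in_batch_sentence_loss(good, labels)
bad_loss, bad_per_example = in_batch_sentence_loss(bad, labels)
```

Predictions aligned with their own label give a loss close to zero, while predictions aligned with another sentence of the batch are penalized heavily. One consequence of this design is that the difficulty of the task depends on how many sentences are masked in a batch.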

### 4.2.3 Alternative: Learning via Decoding

Write a decoder like the one in HIBERT as an alternative to the LM head.

## 4.3 Sentence Model

In [13]:
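from dataclasses import dataclass  # ModelOutput subclasses must be decorated with @dataclass

@dataclass
class SentenceModelingOutput(ModelOutput): #inherits from the huggingface class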
  """
    Return object for Sentence Model.

    Args:
      loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
        Masked language modeling (MLM) loss.
      logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
      hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
        Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
        of shape :obj:`(batch_size, sequence_length, hidden_size)`.
        Hidden-states of the model at the output of each layer plus the initial embedding outputs.
      last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the model.
      attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
        Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
        sequence_length, sequence_length)`.
        Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
        heads.
  """
  loss: Optional[torch.FloatTensor] = None
  logits: torch.FloatTensor = None
  hidden_states: Optional[Tuple[torch.FloatTensor]] = None
  last_hidden_state: torch.FloatTensor = None
  attentions: Optional[Tuple[torch.FloatTensor]] = None

In [14]:
class HATESentenceModel(TransformerPreTrainedModel):
  
  _keys_to_ignore_on_load_unexpected = [r"pooler"]
  _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

  
  def __init__(self, config):
    super().__init__(config)

    if config.is_decoder:
      logger.warning(
        "If you want to use `HATESentenceModel` make sure `config.is_decoder=False` for "
        "bi-directional self-attention."
        )

    self.transformer = TransformerBase(config, add_pooling_layer=False)
    self.lmhead = OnlyLMHead(config)

  def get_output_embeddings(self):
    return self.lmhead.predictions.decoder

  def set_output_embeddings(self, new_embeddings):
    self.lmhead.predictions.decoder = new_embeddings

  def forward(self,
              input_ids=None,
              attention_mask=None,
              token_type_ids=None,
              position_ids=None,
              head_mask=None,
              inputs_embeds=None,
              encoder_hidden_states=None,
              encoder_attention_mask=None,
              labels=None,
              output_attentions=None,
              output_hidden_states=None,
              return_dict=None,):
    # TODO replace batch_size with document_length in here in the docfile
    """
    Args:
      input_ids (torch.LongTensor of shape (batch_size, sequence_length)):
        Indices of input sequence tokens in the vocabulary.
      attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional):
        Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]:
        - 1 for tokens that are not masked,
        - 0 for tokens that are masked.
      token_type_ids  (torch.LongTensor of shape (batch_size, sequence_length), optional):
        Segment token indices to indicate first and second portions of the inputs. Indices are selected in [0, 1]:
        - 0 corresponds to a sentence A token,
        - 1 corresponds to a sentence B token.
      position_ids (torch.LongTensor of shape (batch_size, sequence_length), optional):
        Indices of positions of each input sequence tokens in the position embeddings. 
        Selected in the range [0, config.max_position_embeddings - 1].
      head_mask (torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads), optional):
        Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]:
        - 1 indicates the head is not masked,
        - 0 indicates the head is masked.
      inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional):
        Optionally, instead of passing input_ids you can choose to directly pass
         an embedded representation. This is useful if you want more control over 
         how to convert input_ids indices into associated vectors than the model’s 
         internal embedding lookup matrix.
      encoder_hidden_states (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional):
        Sequence of hidden-states at the output of the last layer of the encoder. 
        Used in the cross-attention if the model is configured as a decoder.
      encoder_attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional):
        Mask to avoid performing attention on the padding token indices of the encoder 
        input. This mask is used in the cross-attention if the model is configured 
        as a decoder. Mask values selected in [0, 1]:
        - 1 for tokens that are not masked,
        - 0 for tokens that are masked.
      labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
        Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
        config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
        (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
      output_attentions (bool, optional): 
        Whether or not to return the attentions tensors of all attention layers. 
        See attentions under returned tensors for more detail.
      output_hidden_states (bool, optional):
        Whether or not to return the hidden states of all layers. See hidden_states 
        under returned tensors for more detail.
      return_dict (bool, optional):
        Whether or not to return a ModelOutput instead of a plain tuple.
    Returns:
      SentenceModelingOutput:
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    outputs = self.transformer(input_ids,
                               attention_mask=attention_mask,
                               token_type_ids=token_type_ids,
                               position_ids=position_ids,
                               head_mask=head_mask,
                               inputs_embeds=inputs_embeds,
                               encoder_hidden_states=encoder_hidden_states,
                               encoder_attention_mask=encoder_attention_mask,
                               output_attentions=output_attentions,
                               output_hidden_states=output_hidden_states,
                               return_dict=return_dict)
    
    sequence_output = outputs[0]
    prediction_scores = self.lmhead(sequence_output)

    masked_lm_loss = None
    if labels is not None:
      loss_fct = nn.CrossEntropyLoss()  # labels of -100 are ignored (default ignore_index)
      masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

    if not return_dict:
      output = (prediction_scores,) + outputs[2:]
      return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

    return SentenceModelingOutput(loss=masked_lm_loss,
                                  logits=prediction_scores,
                                  hidden_states=outputs.hidden_states,
                                  last_hidden_state=sequence_output,
                                  attentions=outputs.attentions,)

## 4.4 Document Model

In [21]:
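from dataclasses import dataclass  # ModelOutput subclasses must be decorated with @dataclass

@dataclass
class DocumentModelingOutput(ModelOutput):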
  """
    Return object for Document Model.

    Args:
      loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
        Masked language modeling (MLM) loss.
      logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
      hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
        Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
        of shape :obj:`(batch_size, sequence_length, hidden_size)`.
        Hidden-states of the model at the output of each layer plus the initial embedding outputs.
      last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the model.
      attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
        Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
        sequence_length, sequence_length)`.
        Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
        heads.
  """
  loss: Optional[torch.FloatTensor] = None
  logits: torch.FloatTensor = None
  hidden_states: Optional[Tuple[torch.FloatTensor]] = None
  last_hidden_state: torch.FloatTensor = None
  attentions: Optional[Tuple[torch.FloatTensor]] = None

In [16]:
class HATEDocumentModel(TransformerPreTrainedModel):
  def __init__(self, config):
    super().__init__(config)
    self.transformer = TransformerBase(config, add_pooling_layer=False)
    self.lmhead = SentencePredictionHead(config)

  def forward(self,
              input_ids=None,
              attention_mask=None,
              token_type_ids=None,
              position_ids=None,
              head_mask=None,
              inputs_embeds=None,
              encoder_hidden_states=None,
              encoder_attention_mask=None,
              labels_embeddings=None,
              labels_mask=None,
              output_attentions=None,
              output_hidden_states=None,
              return_dict=None,):
    
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    outputs = self.transformer(input_ids,
                               attention_mask=attention_mask,
                               token_type_ids=token_type_ids,
                               position_ids=position_ids,
                               head_mask=head_mask,
                               inputs_embeds=inputs_embeds,
                               encoder_hidden_states=encoder_hidden_states,
                               encoder_attention_mask=encoder_attention_mask,
                               output_attentions=output_attentions,
                               output_hidden_states=output_hidden_states,
                               return_dict=return_dict,)
    
    sequence_output = outputs[0]
    sentence_prediction_output = self.lmhead(sequence_output, labels_embeddings, labels_mask)

    return DocumentModelingOutput(loss=sentence_prediction_output[1],
                                  logits=sentence_prediction_output[0],
                                  hidden_states=outputs.hidden_states,
                                  last_hidden_state=sequence_output,
                                  attentions=outputs.attentions,)

## 4.5 Hierarchical Attention-Based Document Encoder (HATE)

In [None]:
import copy
import json

class HATEConfig():
  def __init__(self,
               sentence_model_config,
               document_model_config,
               is_pretraining=False):
    
    """
    Constructs HATEConfig.
    Args:
      sentence_model_config: Config of the sentence-level model (HATESentenceModel)
      document_model_config: Config of the document-level model (HATEDocumentModel)
      is_pretraining: Whether the model runs in pretraining mode
    """
    self.sentence_model_config = sentence_model_config
    self.document_model_config = document_model_config
    self.is_pretraining = is_pretraining
    
  @classmethod
  def from_dict(cls, json_object):
    """Constructs a `HATEConfig` from a Python dictionary of parameters."""
    config = cls(sentence_model_config=None, document_model_config=None)
    for (key, value) in json_object.items():
      config.__dict__[key] = value
    return config

  @classmethod
  def from_json_file(cls, json_file):
    """Constructs a `HATEConfig` from a json file of parameters."""
    with open(json_file, "r") as reader:
      text = reader.read()
    return cls.from_dict(json.loads(text))

  def to_dict(self):
    """Serializes this instance to a Python dictionary."""
    output = copy.deepcopy(self.__dict__)
    return output

  def to_json_string(self):
    """Serializes this instance to a JSON string."""
    return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

In [None]:
class HATEOutput():
  """
  Class for the whole model output
  """

In [None]:
class HATEModel(torch.nn.Module):
  def __init__(self, hate_config):
    super().__init__()

    self.is_pretraining = hate_config.is_pretraining
    self.sentence_model = HATESentenceModel(hate_config.sentence_model_config)
    self.document_model = HATEDocumentModel(hate_config.document_model_config)

  # TODO write a config that implements the functionality that the forward
  # function can take either an entire batch or a single document as input, so 
  # it has a mode for pretrain or inference
  # Or is it all the same and a single document is just a batch with length one?
  # Make it so that when the model is in training mode it accepts a batch by
  # default, else it accepts a document or a sentence for inference (hint: a
  # sentence is just a document with length=1)
  def forward(self,
              batch_token_ids: torch.Tensor,
              batch_attention_mask: torch.Tensor,
              batch_labels: torch.Tensor,
              max_doc_length: int,
              pretraining: Optional[bool] = None):
    """
    Args:
      batch_token_ids:
      batch_attention_mask:
      batch_labels:
      max_doc_length (int): The number of sentences in the longest document of
        the batch in order to pad the intermediary embedding tensor accordingly.
      pretraining: Overrides hate_config.is_pretraining if given.
    Returns:
    """
    # A default argument cannot refer to hate_config, so fall back to the value
    # stored in __init__
    if pretraining is None:
      pretraining = self.is_pretraining

    if pretraining:

      sentence_model_embeddings = []

      for input_ids, attention_mask, labels in zip(batch_token_ids, batch_attention_mask, batch_labels):
        sentence_model_outputs = self.sentence_model(input_ids=input_ids,
                                                     attention_mask=attention_mask,
                                                     labels=labels)
        sentence_model_embeddings.append(sentence_model_outputs.hidden_states)

      document_model_input = prepare_for_document_model(sentence_model_embeddings)

      # TODO also pass the sentence-level attention mask and the masked-sentence labels
      document_model_output = self.document_model(inputs_embeds=document_model_input)

      return HATEOutput()

    else:
      # inference routine, include checking for correct tensor sizes, etc.
      raise NotImplementedError("inference routine is not implemented yet")

    
    for sentence in document:
      sentence_model(input_ids=sentence[0], attention_mask=sentence[1], )

    for sent in doc:
      # Apply the word-level BERT model for masked token prediction
      sentence_output = self.sentence_model(input_ids, labels, out)
      # The output object can already return the loss
      sentence_loss = sentence_output.loss
      # Returns the CLS token, has been shown to produce solid sentence representations
      sentence_representation = sentence_output.hidden_states[0]
      # Dense layer as in SMITH
      sentence_representation = self.sentence_embedding(sentence_representation)
      # Normalization as in SMITH
      sentence_representation = f.normalize(sentence_representation)
      # Append the features to list of sentence features for the according doc
      sentence_representation_list.append((sentence_output, sentence_loss, sentence_representation))
    
    
    masked_sentence_embeddings, sentence_labels, mask_indices = mask_sentences()

    
    # TODO after looping through all sentences: create sentence level attention mask, special token mask, concatenate sentence model outputs, pad them to longest doc in batch

  
    # prune layers? transformers.modeling_utils.find_pruneable_heads_and_indices

    return HATEOutput()
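
The TODO about padding the sentence model outputs to the longest document in the batch could be approached roughly as follows. This is only a sketch under assumptions: `pad_sentence_embeddings` is a hypothetical helper name, the sentence embeddings of each document are assumed to arrive as one tensor of shape `(num_sentences, hidden_size)`, and zero vectors are assumed to be an acceptable padding value.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_sentence_embeddings(docs):
    """Pads per-document sentence embeddings to the longest document.

    docs: list of tensors, each of shape (num_sentences_i, hidden_size).
    Returns the padded tensor and a sentence-level attention mask.
    """
    lengths = torch.tensor([doc.shape[0] for doc in docs])
    # Shape (batch_size, max_doc_length, hidden_size), padded with zeros
    padded = pad_sequence(docs, batch_first=True)
    # 1 marks a real sentence, 0 marks a padding position
    positions = torch.arange(padded.shape[1]).unsqueeze(0)
    attention_mask = (positions < lengths.unsqueeze(1)).long()
    return padded, attention_mask


# Toy batch: two documents with 3 and 5 sentences, hidden size 8
docs = [torch.randn(3, 8), torch.randn(5, 8)]
padded, attention_mask = pad_sentence_embeddings(docs)
```

The resulting mask has the same role as the token-level `attention_mask` of the sentence model, only one level up, so the document model can ignore the padded sentence slots.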

# **5.** Pretraining and Finetuning

## 5.1 Pretraining

### 5.1.1 Evidence for the need for Pretraining







- HIBERT https://arxiv.org/pdf/1905.06566.pdf
- Language Model Pre-training for Hierarchical Document Representations https://arxiv.org/pdf/1901.09128.pdf
- Pre-training Tasks for Embedding-based Large-scale Retrieval https://arxiv.org/pdf/2002.03932.pdf

### 5.1.2 Pretraining Data

Describe Dataset used for pretraining, where to get it and how to load it
https://dumps.wikimedia.org/enwiki/20200920/


In [None]:
from datasets import list_datasets, load_dataset
datasets_list = list_datasets()
print("This is how many datasets are available: ", len(datasets_list))
# Load the first 1% of the English Wikipedia dump as a Dataset; slicing it returns an OrderedDict of columns
wiki_data_raw = load_dataset('wikipedia', '20200501.en', split='train[:1%]')

This is how many datasets are available:  183




Downloading and preparing dataset wikipedia/20200501.en (download: 16.99 GiB, generated: 17.07 GiB, post-processed: Unknown size, total: 34.06 GiB) to /root/.cache/huggingface/datasets/wikipedia/20200501.en/1.0.0/f92599dfccab29832c442b82870fa8f6983e5b4ebbf5e6e2dcbe894e325339cd...




Dataset wikipedia downloaded and prepared to /root/.cache/huggingface/datasets/wikipedia/20200501.en/1.0.0/f92599dfccab29832c442b82870fa8f6983e5b4ebbf5e6e2dcbe894e325339cd. Subsequent calls will reuse this data.


In [None]:
# Print out some basic info about the dataset to better understand the structure
print("size: ", wiki_data_raw.dataset_size)
print("column names: ", wiki_data_raw.column_names)
print("shape: ", wiki_data_raw.shape)
print("format: ", wiki_data_raw.format)
print("description: ", wiki_data_raw.description)

size:  18330235071
column names:  ['title', 'text']
shape:  (60784, 2)
format:  {'type': None, 'format_kwargs': {}, 'columns': ['title', 'text'], 'output_all_columns': False}
description:  Wikipedia dataset containing cleaned articles of all languages.
The datasets are built from the Wikipedia dump
(https://dumps.wikimedia.org/) with one split per language. Each example
contains the content of one full Wikipedia article with cleaning to strip
markdown and unwanted sections (references, etc.).



In [None]:
# Create short version dataset containing 10 documents for quick testing
wiki_data_test = wiki_data_raw[:10]
print("first 10 lines of OrderedDict: ", wiki_data_test)
print("first 10 titles: ", wiki_data_test['title'])
print("first 10 articles: ", wiki_data_test['text'])
print("first article: ", wiki_data_test['text'][0])

first 10 lines of OrderedDict:  OrderedDict([('title', ['Yangliuqing', 'Orana Australia Ltd', "St. Mary's Church, Sønderborg", 'Kalitta', 'Where Is Freedom?', 'Latin liturgical rites', 'Fernaldia pandurata', 'Chester Earl Merrow', 'Hightech Information System', 'AD 47']), ('text', ['Yangliuqing () is a market town in Xiqing District, in the western suburbs of Tianjin, People\'s Republic of China. Despite its relatively small size, it has been named since 2006 in the "famous historical and cultural market towns in China".\n\nIt is best known in China for creating nianhua or Yangliuqing nianhua. For more than 400 years, Yangliuqing has in effect specialised in the creation of these woodcuts for the New Year.  wood block prints using vivid colourschemes to portray traditional scenes of children\'s games often interwoven with auspiciouse objects.\n\n, it had 27 residential communities () and 25 villages under its administration.\n\nShi Family Grand Courtyard\n\nShi Family Grand Courtyard (

### 5.1.3 Process Pretraining Data

Some theory on preprocessing: https://mlexplained.com/2019/11/06/a-deep-dive-into-the-wonderful-world-of-preprocessing-in-nlp/

In [None]:
# A likely source of weak model performance might be the tokenization step in 
# split_sentences_by_words() because it keeps punctuation and creates a separate
# token for every punctuation symbol. This might not be anything the model can't
# handle, as the punctuation symbols each have their own token ID, so they are
# part of the vocab and might even be expected to appear during model training.
# It might be of use to experiment with a batch where all punctuation is
# removed, especially if the loss doesn't decrease as desired during training.
# Maybe use NLTK for punctuation removal!

class PretrainingData(Dataset):
  
  def __init__(self, tokenizer):
    # Initialize pretrained tokenizer
    self.tokenizer = tokenizer #transformers.BertTokenizerFast.from_pretrained('bert-base-uncased')
    # Download and process the English Wikipedia dump
    self.wiki_dump = wiki_data_test # dataset #load_dataset('wikipedia', '20200501.en', split='train')
    self.wiki_dump_sentences_split = self.split_documents_by_sentences(self.wiki_dump)
    self.wiki_dump_words_split = self.split_sentences_by_words(self.wiki_dump_sentences_split)
    self.max_sentence_length = self.get_max_sent_length(self.wiki_dump_words_split)


  ##############################################################################
  # Preprocessing steps
  ##############################################################################

  def split_documents_by_sentences(self, data_ordered_dict):
    """
    Splits every document into sentences

    Args:
      data_ordered_dict: A Python OrderedDict containing all training documents,
        where the raw text dumps are indexed by 'text'

    Returns:
      docs_split_by_sentences: A list of lists where list[_] contains a full
        document and list[_][_] contains a sentence from a given document
    """
    docs_split_by_sentences = []
    for doc in data_ordered_dict['text']:
      sentences = sent_tokenize(doc)
      docs_split_by_sentences.append(sentences)
    return docs_split_by_sentences


  def split_sentences_by_words(self, docsSplitBySentences):
    """
    Reads the output of split_documents_by_sentences() and splits every
      sentence (list[_][_]) into individual words. The pre-tokenization makes it
      easier to determine the max sentence length in the dataset to which the
      other sentences should be padded

    Args:
      docsSplitBySentences: List of lists
        docsSplitBySentences[_] returns a document
        docsSplitBySentences[_][_] returns a sentence from a document

    Returns:
      tokenized_wiki_data: List of list of lists
        tokenized_wiki_data[_] returns a document
        tokenized_wiki_data[_][_] returns a sentence from a document
        tokenized_wiki_data[_][_][_] returns a word from a sentence from a doc
    """
    # Constructs a list of empty lists, one per document in the batch, which
    # are filled with the tokenized sentences below
    tokenized_wiki_data = [[] for _ in range(len(docsSplitBySentences))]
    # The new empty list gets zipped with our documents batch, nothing gets cut 
    # off as the new list was made to have the same length as the batch.
    # We iterate over the zipped list and for every entry (representing a document)
    # we iterate again over all sentences in that document in order to tokenize
    # each one and append the sentences (now lists of individual tokens) to the
    # respective entry (the according document).
    for doc, entry in zip(docsSplitBySentences, tokenized_wiki_data):
      for sent in doc:
        entry.append(self.tokenizer.tokenize(sent))
    return tokenized_wiki_data


  # TODO can we replace this with batch_encode_plus/encode_plus/prepare_for_model and move it into __getitem__() for more memory efficiency?
  # Also add return_tensors='pt'
  # https://huggingface.co/transformers/internal/tokenization_utils.html
  def preprocess_dataset(self, batch, max_sequence_length):
    preprocessed_batch = []
    for doc in batch:
      # Each doc is a list of sentences and is encoded in a single tokenizer call
      preprocessed_batch.append(self.tokenizer(text=doc, 
                                               padding='max_length', 
                                               max_length=max_sequence_length,
                                               return_special_tokens_mask=True))
    return preprocessed_batch


  ##############################################################################
  # Helper functions
  ##############################################################################
  
  def get_batch_size(self):
    """
    Returns the number of documents in a batch
    """
    return len(self.wiki_dump_words_split)


  # TODO get_max_doc_length() and get_max_sent_length() can be made into one function

  def get_max_doc_length(self, tokenized_batch):
    """
    Returns the number of sentences the longest document in the batch has

    Args:
      tokenized_batch: List of documents, each a list of tokenized sentences

    Returns:
      longest_doc_len: The number of sentences of the longest document, when 
      measured by how many sentences it consists of.
    """
    return max((len(doc) for doc in tokenized_batch), default=0)

  # Don't compute this for the whole dataset, we just need it for a batch so we can get a tensor of uniform size
  def get_max_sent_length(self, tokenized_batch):
    # Iterate over every sentence of every document: split_sentences_by_words()
    # appends no trailing marker, so slicing with doc[:-1] would skip the last
    # sentence of each document
    return max((len(sent) for doc in tokenized_batch for sent in doc), default=0)

  ##############################################################################
  # Interface functions
  ##############################################################################

  def __len__(self):
    # References a different function with a better name for more readability
    return self.get_batch_size()

  def __getitem__(self, idx):
    # get_doc_from_batch() is not defined yet, so index the tokenized documents directly
    return self.wiki_dump_words_split[idx]
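
The TODO in the helper section suggests merging `get_max_doc_length()` and `get_max_sent_length()` into one function. A minimal sketch of what that could look like, assuming the nested list format described in `split_sentences_by_words()` (documents, then sentences, then tokens); `get_max_lengths` is a hypothetical name, not part of the class above:

```python
def get_max_lengths(tokenized_batch):
    """Returns (max sentences per document, max tokens per sentence)."""
    # default=0 keeps the function well defined for an empty batch
    max_doc_length = max((len(doc) for doc in tokenized_batch), default=0)
    max_sent_length = max(
        (len(sent) for doc in tokenized_batch for sent in doc), default=0
    )
    return max_doc_length, max_sent_length


# Toy batch: first document has 2 sentences, second has 3
toy_batch = [
    [["a", "b", "c"], ["d"]],
    [["e", "f"], ["g", "h", "i", "j"], ["k"]],
]
max_doc_length, max_sent_length = get_max_lengths(toy_batch)
```

Both values are only needed per batch, so calling this on the batch rather than on the whole dataset keeps padding as small as possible.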

### 5.1.4 Training Loop

In [21]:
doc1sent1 = "this is supposed to be the very first sentence in the dummy dataset for fast experimentation"
doc1sent2 = "we're using this as a small test case to see how the DataLoader and Dataset class work in detail"
doc1sent3 = "I'm going to add a few more sentences at the end of every document to inspect the batching a bit more"
doc1sent4 = "let's also make the docuemnts of unequal size to see if that's a problem"
doc2sent1 = "also we have two documents here to see which encoding function from huggingface works best"
doc2sent2 = "let's write a very rudimentary and fast tokenization pipeline and encode these texts into tensor outputs"
doc2sent3 = "hopefully the fact that the documents have an unequal number of sentences is not a problem"
doc3sent1 = "and what happens if we even add a third docuemnt"
doc3sent2 = "wow insanity I wonder if our DataLoader can handle such vast amounts of data"
doc3sent3 = "ok let us add another sentence here for ultimate clarity"

dummy_data = [[doc1sent1, doc1sent2, doc1sent3], [doc2sent1, doc2sent2, doc2sent3], [doc3sent1, doc3sent2, doc3sent3]]

dummy_data_pretokenized = [[] for _ in range(len(dummy_data))]

for doc_id, doc in enumerate(dummy_data):
  for sent in doc:
    dummy_data_pretokenized[doc_id].append(sent.split())


class DummyData(Dataset):
  def __init__(self, tokenizer):
    self.tokenizer = tokenizer
    self.dataset = dummy_data_pretokenized # has the exact same format as our wiki dataset so everything that works here should work there too
    self.max_length = 30
    self.padding_strategy = transformers.tokenization_utils_base.PaddingStrategy('max_length')
  
  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, id): 
    return self.tokenizer.batch_encode_plus(self.dataset[id], # batch-encodes a whole document at [id] at once
                                            padding=self.padding_strategy, 
                                            is_split_into_words=True, 
                                            return_tensors='pt',
                                            return_special_tokens_mask=True)

tokenizer = transformers.BertTokenizerFast.from_pretrained('bert-base-uncased')

# The following will happen inside train()
dummy_data_processed = DummyData(tokenizer)
dataloader = DataLoader(dummy_data_processed, batch_size=2, shuffle=False)

document_inputs_labels = []
document_attention_masks = []

sentence_config = BertConfig()
sentence_model = HATESentenceModel(sentence_config)
document_config = BertConfig()
#document_model = HATEDocumentModel(document_config)
sentence_model_bert = transformers.BertForMaskedLM(sentence_config)



# What if we tokenize only on the fly for every batch? I think we don't need universal sentence lengths anymore...
for batch in dataloader:
  # Get number of sentences in longest doc in the batch
  max_doc_length = get_max_doc_length(batch)

  # Everything between the two hashtag lines should happen inside HATEModel
  ##############################################################################
  intermediary_embeddings = []
  intermediary_attention_mask = []
  intermediary_special_tokens_mask = []
  # Iterate over documents and apply the Sentence Model per doc
  for input_ids, token_type_ids, attention_mask, special_tokens_mask in zip(batch['input_ids'], batch['token_type_ids'], batch['attention_mask'], batch['special_tokens_mask']):
    # Position 0 of every document holds a random placeholder embedding
    sentence_embeddings_per_doc = [torch.randn(768)]
    attention_mask_per_doc = [0]
    special_tokens_mask_per_doc = [1]

    # Mask input IDs for one document
    masking_output = mask_input_ids(input_ids, tokenizer, special_tokens_mask)
    sentence_model_output = sentence_model(input_ids=masking_output[0], 
                                         attention_mask=attention_mask, 
                                         token_type_ids=token_type_ids,
                                         inputs_embeds=None,
                                         labels=masking_output[1],
                                         output_attentions=True,
                                         output_hidden_states=True)
    # Iterate over sentence embeddings returned by the model for one document
    for sentence_hidden_states in sentence_model_output['last_hidden_state']:
      # CLS embedding of the current sentence
      # CLS is at position 0 out of 512 of the sentence's hidden states
      sentence_embeddings_per_doc.append(sentence_hidden_states[0])
      # Assumption: real sentences are attended to (1) and are not special tokens (0)
      attention_mask_per_doc.append(1)
      special_tokens_mask_per_doc.append(0)


    # TODO need a function to tell us longest doc in batch in order to pad inputs for doc model
    intermediary_embeddings.append(torch.stack(sentence_embeddings_per_doc))
    intermediary_attention_mask.append(torch.tensor(attention_mask_per_doc))
    intermediary_special_tokens_mask.append(torch.tensor(special_tokens_mask_per_doc))


  # Stacking only works while all documents in the batch have the same number of sentences
  intermediary_embeddings = torch.stack(intermediary_embeddings)
  intermediary_attention_mask = torch.stack(intermediary_attention_mask)
  intermediary_special_tokens_mask = torch.stack(intermediary_special_tokens_mask)

  input_embeddings, masked_input_embeddings, label_embeddings, label_mask = mask_input_embeddings(intermediary_embeddings, intermediary_special_tokens_mask)
  # Requires document_model (commented out above) to be instantiated
  document_model_output = document_model(inputs_embeds=masked_input_embeddings,
                                         attention_mask=intermediary_attention_mask,
                                         labels_embeddings=label_embeddings,
                                         labels_mask=label_mask)
  ##############################################################################

{'input_ids': tensor([[[  101,  2023,  2003,  ...,     0,     0,     0],
         [  101,  2057,  1005,  ...,     0,     0,     0],
         [  101,  1045,  1005,  ...,     0,     0,     0]],

        [[  101,  2036,  2057,  ...,     0,     0,     0],
         [  101,  2292,  1005,  ...,     0,     0,     0],
         [  101, 11504,  1996,  ...,     0,     0,     0]]]), 'token_type_ids': tensor([[[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]]), 'attention_mask': tensor([[[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]],

        [[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]]), 'special_tokens_mask': tensor([[[1, 0, 0,  ..., 1, 1, 1],
         [1, 0, 0,  ..., 1, 1, 1],
         [1, 0, 0,  ..., 1, 1, 1]],

        [[1, 0, 0,  ..., 1, 1, 

In [None]:
def train(hate_config, lr):

  # Raw string, otherwise "\U" in the Windows path is read as a Unicode escape
  file_path = r"C:\Users\Marco Moldovan\Documents\Studium\Bachelorarbeit\datasets\enwiki-20200920-pages-articles-multistream1.xml-p1p41242"

  tokenizer = transformers.BertTokenizerFast.from_pretrained('bert-base-uncased')

  # TODO PretrainingData does not accept file_path yet and reads wiki_data_test instead
  dataset = PretrainingData(tokenizer)

  dataloader = DataLoader(dataset)

  model = HATEModel(hate_config) # what if we use a pretrained BertModel as the Sentence Model and fix the weights and only train the Document Model?

  model.train()

  criterion = nn.CrossEntropyLoss()

  optimizer = torch.optim.SGD(model.parameters(), lr=lr)

  # TODO the dataloader does not yield (data, ground_truth) pairs yet
  for data, ground_truth in dataloader:

    optimizer.zero_grad()

    outputs = model(data)

    loss = criterion(outputs, ground_truth)

    loss.backward()
  
    # possibly add gradient clipping here: torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    # https://huggingface.co/transformers/training.html
  
    optimizer.step()

  """
  Training logic works as follows:
  - the dataset consists of documents which consist of sentences which consist
    of words
  - for every sentence some words are masked and for every document some sentences
    are masked
  - for every datapoint (which is a document) the model predicts some masked words
    and some masked sentences
  - these predictions are compared against the ground truth by the criterion 
    and result in some losses
  - the losses for both the masked words and masked sentences are added together
  - the combined loss gets backpropagated through the model

  Important things to figure out:
  - how to build a proper vocabulary for sentences that we can predict over --> take elements from __getitem__().
  - how to store ground truth for sentences --> need masked_sentence_loss() function
  """
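
The training logic described in the docstring (two masked prediction losses that are added and backpropagated together) can be illustrated with a toy example. Everything here is a stand-in and an assumption: `ToyTwoLevelModel` is a hypothetical placeholder for the real model, and the mean squared error used for the masked sentences is only one possible choice, since how to define the sentence-level target is still listed as an open question.

```python
import torch
from torch import nn

torch.manual_seed(0)


class ToyTwoLevelModel(nn.Module):
    """Hypothetical stand-in that returns a word loss and a sentence loss."""

    def __init__(self, vocab_size=20, hidden_size=8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.word_head = nn.Linear(hidden_size, vocab_size)
        self.sentence_head = nn.Linear(hidden_size, hidden_size)

    def forward(self, token_ids, word_labels, sentence_targets):
        hidden = self.embed(token_ids)  # (sentences, tokens, hidden)
        word_logits = self.word_head(hidden)
        # Masked word loss: positions labelled -100 are ignored
        word_loss = nn.functional.cross_entropy(
            word_logits.reshape(-1, word_logits.size(-1)),
            word_labels.reshape(-1),
            ignore_index=-100,
        )
        # First token of each sentence serves as its representation
        sentence_repr = self.sentence_head(hidden[:, 0])
        # Assumption: regress onto target sentence embeddings
        sentence_loss = nn.functional.mse_loss(sentence_repr, sentence_targets)
        return word_loss, sentence_loss


model = ToyTwoLevelModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

token_ids = torch.randint(0, 20, (4, 6))           # 4 sentences, 6 tokens each
word_labels = torch.full((4, 6), -100, dtype=torch.long)
word_labels[:, 2] = token_ids[:, 2]                 # pretend token 2 was masked
sentence_targets = torch.randn(4, 8)

model.train()
losses = []
for step in range(5):
    optimizer.zero_grad()
    word_loss, sentence_loss = model(token_ids, word_labels, sentence_targets)
    loss = word_loss + sentence_loss                # combined loss
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    losses.append(loss.item())
```

Summing the two losses means a single backward pass updates both levels at once; weighting the two terms differently would be a straightforward variation.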

### 5.1.5 Perform Pretraining

Possibly train here: https://www.rz.ifi.lmu.de/infos/slurm_de.html

In [None]:
# TensorBoard of loss and all other sorts of important info

### 5.1.6 Save Pretrained Model

Imagine not saving your model and losing all training progress #uff
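
A minimal sketch of checkpointing with `torch.save` and `load_state_dict`, assuming the model is a regular `torch.nn.Module`. The `nn.Linear` layer is only a stand-in for the trained model, and the file name and the stored step number are hypothetical.

```python
import os
import tempfile

import torch
from torch import nn

model = nn.Linear(4, 2)  # stand-in for the trained model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Hypothetical checkpoint location inside a temporary directory
checkpoint_path = os.path.join(tempfile.mkdtemp(), "hate_checkpoint.pt")

# Save model and optimizer state together so training can be resumed
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "step": 100,
    },
    checkpoint_path,
)

# Restore into freshly constructed objects of the same architecture
restored_model = nn.Linear(4, 2)
restored_optimizer = torch.optim.SGD(restored_model.parameters(), lr=0.1)
checkpoint = torch.load(checkpoint_path)
restored_model.load_state_dict(checkpoint["model_state_dict"])
restored_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
```

Saving the state dictionaries rather than the whole pickled model keeps the checkpoint independent of the exact class definition in the notebook at the time of saving.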

### 5.1.7 Inference

Write all the trained modules in such a compact way that we can just hand it a document or sentence and it will infer its representation

In [None]:
class HATEModel(nn.Module):
  pass  # TODO compact inference wrapper as described above

## 5.2 Finetuning on MS MARCO Document Ranking

### 5.2.1 Details on the MS MARCO Dataset

- https://microsoft.github.io/msmarco (has list of other document reranking models)

- MS MARCO: A Human Generated MAchine Reading COmprehension Dataset https://arxiv.org/pdf/1611.09268.pdf

- https://microsoft.github.io/TREC-2020-Deep-Learning/



also relevant for finetuning:

- RepBERT: Contextualized Text Embeddings for First-Stage Retrieval https://arxiv.org/pdf/2006.15498.pdf
- TwinBERT: Distilling Knowledge to Twin-Structured BERT Models for Efficient Retrieval (1) https://arxiv.org/pdf/2002.06275.pdf (2) https://github.com/deepampatel/TwinBert/blob/master/TwinBert.ipynb

Alternative Tasks:
- https://research.google/tools/datasets/ Wikipedia and arXiv similarity triplets
- https://github.com/LiqunW/Long-document-dataset with this paper: Long Document Classification From Local Word Glimpses via Recurrent Attention Learning https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=8675939
- https://datasets.quantumstat.com/

### 5.2.2 Load and Tokenize Finetuning Dataset

### 5.2.3 Document Ranking Task

- https://github.com/microsoft/MSMARCO-Document-Ranking

- Baseline: Longformer: The Long-Document Transformer (1) https://arxiv.org/pdf/2004.05150.pdf (2) https://github.com/isekulic/longformer-marco/blob/master/src/TransformerMarco.py
- Baseline: Conformer-Kernel with Query Term Independence for Document Retrieval (1) https://arxiv.org/pdf/2007.10434.pdf

### 5.2.4 *TODO* 
- loss function for ranking task -> cosine contrastive loss? check this TwinBERT implementation https://github.com/deepampatel/TwinBert/blob/master/TwinBert.ipynb
- what training objective is actually used in the finetuning here? is it already a ranking task? could the loss function be the same as in pretraining? 
- ranking algorithm for the document representations
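
As a starting point for the first and last item, the following sketch combines PyTorch's `nn.CosineEmbeddingLoss` with ranking by cosine similarity. It is not taken from TwinBERT or any of the baselines above. The random tensors stand in for query and document representations, and the margin of 0.2 is an arbitrary assumption.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-ins for encoded queries and documents (batch of 4, hidden size 8)
query_embeddings = torch.randn(4, 8)
document_embeddings = torch.randn(4, 8)
# +1 marks a relevant query-document pair, -1 a non-relevant one
targets = torch.tensor([1.0, -1.0, 1.0, -1.0])

# Pulls relevant pairs together and pushes non-relevant pairs below the margin
criterion = nn.CosineEmbeddingLoss(margin=0.2)
loss = criterion(query_embeddings, document_embeddings, targets)

# Ranking: score every document against the first query and sort by similarity
scores = nn.functional.cosine_similarity(query_embeddings[:1], document_embeddings)
ranking = torch.argsort(scores, descending=True)
```

Because queries and documents are meant to live in the same representation space, the same similarity function can serve both as training signal and as ranking score at inference time.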

# **6.** Experiments

## 6.1 Document Ranking

## 6.2 Answer Passage Highlighting

# **7.** Conclusion

# **8.** Outlook

- Deeper look at attention: transformers are graph neural networks https://thegradient.pub/transformers-are-graph-neural-networks/
- Generalization to Hopfield Nets: (1) https://ml-jku.github.io/hopfield-layers/ (2) http://franksworld.com/2020/08/10/explaining-the-paper-hopfield-networks-is-all-you-need/ (3) https://analyticsindiamag.com/modern-hopfield-network-transformers-attention/ (4) https://towardsdatascience.com/hopfield-networks-are-useless-heres-why-you-should-learn-them-f0930ebeadcd
- Implications of these findings for language representation?
- Check http://nlp.seas.harvard.edu/code/ for more ideas