<a href="https://colab.research.google.com/github/faraway1nspace/AnathemTransformer/blob/main/dev/notebooks/dev_anathem_transformer_base_layers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Development Notebook: build and test base layers for Anathem Transformer (aka Silo'd Transformer)

### Notes
- the google-minature models have the same vocab size and heads as bert-large-ucased
- the minature-google papers discusses the classification and distallation tasks & corpus's including:  
    - NLI
    - sentiment analysis
- the MTEB leader best model is e5-large (24 layers) which uses the CLS token. It is also "instruction fine-tuned", requiring query and passage prefixes.
- distillation example: https://github.com/philschmid/knowledge-distillation-transformers-pytorch-sagemaker/blob/master/knowledge-distillation.ipynb  
    - they set temperature to 2: which results in a flatter probability distribution. I could make this dynamic -> start 0.5 progress to 1
    - they set alpha to 0.5, which balances label-loss vs distil-loss



#### Playing Around with novel architectures

In [1]:
%pip install torch transformers datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.30.1-py3-none-any.whl (7.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m62.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets
  Downloading datasets-2.12.0-py3-none-any.whl (474 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.6/474.6 kB[0m [31m42.7 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.14.1 (from transformers)
  Downloading huggingface_hub-0.15.1-py3-none-any.whl (236 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m236.8/236.8 kB[0m [31m27.4 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m11

In [None]:
from transformers import AutoModel, AutoTokenizer
import torch
from typing import List, Optional
from torch import nn
import torch.nn.functional as F
from torch.cuda import is_available
if is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

from transformers.models.bert.modeling_bert import BertEncoder
from transformers.activations import ACT2FN
import copy

model_string = 'google/bert_uncased_L-12_H-512_A-8' # 'distilroberta-base
tokenizer = AutoTokenizer.from_pretrained(model_string)
basemod = AutoModel.from_pretrained(model_string)
basemod.to(device)

Downloading (…)lve/main/config.json:   0%|          | 0.00/384 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/217M [00:00<?, ?B/s]

Some weights of the model checkpoint at google/bert_uncased_L-12_H-512_A-8 were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 512, padding_idx=0)
    (position_embeddings): Embedding(512, 512)
    (token_type_embeddings): Embedding(2, 512)
    (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=512, out_features=512, bias=True)
            (key): Linear(in_features=512, out_features=512, bias=True)
            (value): Linear(in_features=512, out_features=512, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=512, out_features=512, bias=True)
            (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

In [None]:
text = [
    "A standard indemnity clause is a waiver clause that states that one party won't hold the other liable for damages, losses, or costs associated with issues.",
    "It usually consists of two elements: a trigger event or circumstance and a payment obligation2. The trigger event or circumstance is the breach of the agreement, misconduct, or negligence of the indemnifying party or its affiliates"
]

In [None]:
from transformers import BertTokenizer


class CustomTokenizer:
    def __init__(self, model_string='google/bert_uncased_L-12_H-512_A-8', n_cls_prepend = 4, n_pad_to_multiple_of=4):
        self.base_tokenizer = AutoTokenizer.from_pretrained(model_string)
        self.n_cls_prepend = n_cls_prepend
        self.n_pad_to_multiple_of = n_pad_to_multiple_of
        for k in dir(self.base_tokenizer):
            if not (k[0]=='_' or k=='tokenize' or k=='encode' or k=='build_inputs_with_special_tokens' or k == 'batch_encode_plus'):
                setattr(self,k,getattr(self.base_tokenizer, k))
    
    def __call__(self, text, pad_to_multiple_of=None, add_special_tokens = True, return_tensors=None, *args, **kwargs):
        if pad_to_multiple_of is None:
            pad_to_multiple_of = self.n_pad_to_multiple_of

        # run through base tokenizer
        tokens = self.base_tokenizer(
            text, 
            pad_to_multiple_of=(pad_to_multiple_of if not add_special_tokens else False), 
            add_special_tokens=add_special_tokens, 
            return_tensors=return_tensors if (not add_special_tokens) else None,
            *args, 
            **kwargs
        )
        if add_special_tokens:
            tokens = self._prepend_extra_cls_tokens_because_of_maxpooling(tokens, return_tensors)
        
        return tokens
    
    def _num_pad_tokens(self, token_list):
        """Calculates how many PAD tokens to append to sequence to make a multiple of X"""
        return (self.n_pad_to_multiple_of - ((len(token_list)+(self.n_cls_prepend-1)) % self.n_pad_to_multiple_of)) % self.n_pad_to_multiple_of

    def _prepend_extra_cls_tokens_because_of_maxpooling(self, tokens, return_tensors=None):
        n_cls_prepend = self.n_cls_prepend
        # prepend (n-1) CLS tokens to the front of the token_ids (because of maxpooling)
        # also pad so that the total length is a multiple of n_cls_prepend
        #num_pad_tokens = (self.n_pad_to_multiple_of - ((len_tokens+(n_cls_prepend-1)) % self.n_pad_to_multiple_of)) % self.n_pad_to_multiple_of
        tokens['input_ids'] = [
            [self.cls_token_id]*(n_cls_prepend-1)+input_id + [self.pad_token_id]*self._num_pad_tokens(input_id)
            for input_id 
            in tokens['input_ids']
        ]
        tokens['attention_mask'] = [
            [1]*(n_cls_prepend-1)+attnmask +[0]*self._num_pad_tokens(attnmask)
            for attnmask
            in tokens['attention_mask']
        ]
        if 'token_type_ids' in tokens.keys():
            tokens['token_type_ids'] = [
                [toktypeid[0]]*(n_cls_prepend-1)+toktypeid +[toktypeid[-1]]*self._num_pad_tokens(toktypeid)
                for toktypeid
                in tokens['token_type_ids']
            ]
        if return_tensors == 'pt':
            for k,v in tokens.items():
                tokens[k] = torch.LongTensor(v)
        return tokens

    def encode(self, text, pad_to_multiple_of=4, add_special_tokens = True, *args, **kwargs):
        encoded = self.base_tokenizer.encode(text, pad_to_multiple_of=False, add_special_tokens=add_special_tokens, *args, **kwargs)
        if add_special_tokens:
            encoded = [self.cls_token_id]*(pad_to_multiple_of-1) + encoded
        if bool(pad_to_multiple_of):
            num_pad_tokens = (pad_to_multiple_of - (len(encoded) % pad_to_multiple_of)) % pad_to_multiple_of
            encoded += [self.pad_token_id] * num_pad_tokens
        return encoded

    def tokenize(self, text, add_special_tokens=True, *args, **kwargs):
        toks = self.base_tokenizer.tokenize(text, add_special_tokens=add_special_tokens, *args, **kwargs)
        if add_special_tokens:
            toks = [self.cls_token] * (self.n_cls_prepend-1) + toks
        return toks

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ):
        out = self.base_tokenizer.build_inputs_with_special_tokens(token_ids_0, token_ids_1)
        return [self.cls_token_id]*3 + out

    def batch_encode_plus(self, batch_text_or_text_pairs, *args, **kwargs):
        batched_encoded = self.base_tokenizer.batch_encode_plus( batch_text_or_text_pairs, *args, **kwargs)
        #batched_encoded.update({'foo':'bar'})
        return batched_encoded

    

# Note, if I use the vanilla LineByLineTextDataset, it just calls tokenizer.__call__ turns on the `use_special_tokens`, and it pads to a multiple of optional
# .. so somehow I need to ensure that, whatever base function it calls as part of the tokenizer pipeline, it will continue using MY new function
# the tokenizer.__call__ DOES NOT use `encode` nor `tokenize` otherwise my modifications would manifest
# looks like `prepare_for_model` (and maybe `batch_prepare_for_model`) is what adds special tokens?
# looks like `prepare_for_model` just calls `build_inputs_with_special_tokens`, so maybe intervene there?
#         if add_special_tokens:
#            sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
#            token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
# editing `build_inputs_with_special_tokens` didn't work either

In [None]:
tokenizer2 = CustomTokenizer()
tokenizer2.pad_token_id

Using bos_token, but it is not set yet.
Using eos_token, but it is not set yet.


0

In [None]:
#toks = tokenizer2.encode(text[0], add_special_tokens=True)
#print(len(toks)) # works
#print(toks[:10])

tokens = tokenizer2(text, padding='longest', return_tensors=None) # doesn't work, obviously 
#print(tokens)
print(len(tokens['input_ids'][0]))
print(len(tokens['attention_mask'][0]))

print(len(tokens['input_ids'][1]))
print(len(tokens['attention_mask'][1]))

tokens

#tokenizer2.batch_encode_plus(text, add_special_tokens=True) # doesn't work 


52
52
52
52


{'input_ids': [[101, 101, 101, 101, 1037, 3115, 27427, 6633, 22758, 11075, 2003, 1037, 23701, 6299, 11075, 2008, 2163, 2008, 2028, 2283, 2180, 1005, 1056, 2907, 1996, 2060, 20090, 2005, 12394, 1010, 6409, 1010, 2030, 5366, 3378, 2007, 3314, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [101, 101, 101, 101, 2009, 2788, 3774, 1997, 2048, 3787, 1024, 1037, 9495, 2724, 2030, 25652, 1998, 1037, 7909, 14987, 2475, 1012, 1996, 9495, 2724, 2030, 25652, 2003, 1996, 12510, 1997, 1996, 3820, 1010, 23337, 1010, 2030, 27988, 1997, 1996, 27427, 6633, 3490, 14116, 2283, 2030, 2049, 18460, 102, 0, 0, 0]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [None]:
dir(basemod)
# base embedding layers
layer_emb = copy.deepcopy(basemod._modules['embeddings'])


In [None]:
# base trasnformers (full)
layer_basetransformer = copy.deepcopy(basemod._modules['encoder']._modules['layer']._modules['0'])

In [None]:
# text 
text = [
    "A standard indemnity clause is a waiver clause that states that one party won't hold the other liable for damages, losses, or costs associated with legal issues1.",
    "It usually consists of two elements: a trigger event or circumstance and a payment obligation2. The trigger event or circumstance is the breach of the agreement, willful misconduct, or negligence of the indemnifying party or its affiliates"
]

import math

#padding_length = int(math.ceil(max_length / 4)) * 4
tokens = tokenizer(text,padding=True, return_tensors='pt', pad_to_multiple_of=4)
input_shape = tokens['input_ids'].size()

# change token padding to be multiple of 4
#ideal_length = int(math.ceil(input_shape[-1] / 4)) * 4 # should be a multiple of 4
#if input_shape[-1]!=ideal_length:
#  tokens = tokenizer(text,padding='max_length', max_length = ideal_length, return_tensors='pt')
#  input_shape = tokens['input_ids'].size()

token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
tokens['token_type_ids'] = token_type_ids
past_key_values_length =0

# need to extend attention mask
extended_attention_mask = basemod.get_extended_attention_mask(tokens['attention_mask'], input_shape)
tokens['extended_attention_mask'] = extended_attention_mask
print(tokens.keys())
print(tokens['input_ids'].shape)


dict_keys(['input_ids', 'attention_mask', 'token_type_ids', 'extended_attention_mask'])
torch.Size([2, 48])


In [None]:
silo_dimensions = {0:basemod.config.hidden_size,
                  1:basemod.config.hidden_size//2,
                  2:basemod.config.hidden_size//4,
                  }
reintegration_dim = silo_dimensions[1] + silo_dimensions[2]


In [None]:
embedding_output = layer_emb(
            input_ids=tokens['input_ids'],
            position_ids=tokens.get('position_ids',None),
            token_type_ids=tokens['token_type_ids'],
            inputs_embeds=None,
            past_key_values_length=past_key_values_length
)
print(embedding_output.shape)

RuntimeError: ignored

In [None]:
# basemodel transformer outputs: *full bert model
out_l1 = layer_basetransformer(
    hidden_states = embedding_output,
    attention_mask = tokens['extended_attention_mask'],#tokens['attention_mask'],
    head_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    #past_key_values=0,
    #use_cache=None,
    output_attentions=True,
    #output_hidden_states=True,
    #return_dict=True
)

hidden_states_l1 = out_l1[0]
self_attention_l1 = out_l1[1]

In [None]:
# Next Layer: 
# Query -> max pool and reduce  hidden dimension // 2
# Key -> reduce hidden_dim // 2
# value -> reduce hidden_dim //2
#maxpool_l2 = nn.MaxPool2d((2,1), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True)

maxpool_l2 = nn.Sequential(
    nn.Dropout(0.05),
    nn.MaxPool2d((2,1), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True),
)

maxpool_l2_attn = nn.MaxPool1d((2), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True)

In [None]:
# reduce dimension of hidden states
hiddens_states_l1_reduced = maxpool_l2(hidden_states_l1)
print(hidden_states_l1.shape)
print(hiddens_states_l1_reduced.shape)

# reduce dimension of attention mask
attention_mask_l1_reduced = maxpool_l2_attn(tokens['attention_mask'].float())
print(attention_mask_l1_reduced.shape)

# extend the dimension of the reduced attention_mask
print(input_shape)
extended_attention_mask_l1_reduced = basemod.get_extended_attention_mask(attention_mask_l1_reduced, attention_mask_l1_reduced.shape)
print(tokens['extended_attention_mask'].shape)
print(extended_attention_mask_l1_reduced.shape)

torch.Size([2, 48, 768])
torch.Size([2, 24, 768])
torch.Size([2, 24])
torch.Size([2, 48])
torch.Size([2, 1, 1, 48])
torch.Size([2, 1, 1, 24])


In [None]:
# Try to do Multi Headed attenion with differently sized query and value 

In [None]:
import torch
import torch.nn as nn
import math
from typing import Optional, Tuple
import copy

class BertSelfAttnDimensionReduction(nn.Module):
    """Bert Attention Layer that uses a dimension-reduced version of the query, so to reduce the dimension of the outputs"""
    def __init__(
        self, 
        config, 
        hidden_size_input=768, 
        hidden_size_query = None,
        position_embedding_type=None, 
        dim_reduction = 2
    ):
        """Special type of Bert Self attention that reduces the dimension of the inputs by half"""
        super().__init__()
        if (config.hidden_size // dim_reduction) % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.dim_reduction = dim_reduction
        self.hidden_size_input = hidden_size_input
        self.hidden_size_reduced = hidden_size_input // dim_reduction
        if hidden_size_query is None:
            hidden_size_query = hidden_size_input
        self.hidden_size_query = hidden_size_query
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(self.hidden_size_reduced / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(self.hidden_size_query, self.all_head_size)
        self.key = nn.Linear(self.hidden_size_input, self.all_head_size)
        self.value = nn.Linear(self.hidden_size_input, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.

        key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
        value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if use_cache:
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                    -1, 1
                )
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if encoder_attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            #print(attention_scores.shape)
            #print(attention_scores.shape)
            attention_scores = attention_scores + encoder_attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs

bertlayer_l2_reduction = BertSelfAttnDimensionReduction(
    config=basemod.config,
    hidden_size_input=basemod.config.hidden_size, 
    position_embedding_type=basemod.config.position_embedding_type,
    dim_reduction = 2
)

bertlayer_l3_reduction = BertSelfAttnDimensionReduction(
    config=basemod.config,
    hidden_size_input=basemod.config.hidden_size // 2, 
    position_embedding_type=basemod.config.position_embedding_type,
    dim_reduction = 2
)

In [None]:
out_l2 = bertlayer_l2_reduction(
        hidden_states = hiddens_states_l1_reduced,
        attention_mask = extended_attention_mask_l1_reduced,
        head_mask=None,
        encoder_hidden_states = hidden_states_l1,
        encoder_attention_mask= tokens['extended_attention_mask'],
        past_key_value=None,
        output_attentions=False
    )
hidden_states_l2 = out_l2[0]
print(hidden_states_l2.shape)

torch.Size([2, 24, 384])


In [None]:
# Next dimension reduction: 
hiddens_states_l2_reduced = maxpool_l2(hidden_states_l2)
print(hidden_states_l2.shape)
print(hiddens_states_l2_reduced.shape)

# reduce dimension of attention mask
attention_mask_l2_reduced = maxpool_l2_attn(attention_mask_l1_reduced.float())
print(attention_mask_l2_reduced.shape)

# extend the dimension of the reduced attention_mask
extended_attention_mask_l2_reduced = basemod.get_extended_attention_mask(attention_mask_l2_reduced, attention_mask_l2_reduced.shape)
print(extended_attention_mask_l2_reduced.shape)

if True:
  out_l3 = bertlayer_l3_reduction(
        hidden_states = hiddens_states_l2_reduced, # input has been maxpooled
        attention_mask = extended_attention_mask_l2_reduced,
        head_mask=None,
        encoder_hidden_states = hidden_states_l2,
        encoder_attention_mask= extended_attention_mask_l1_reduced,
        past_key_value=None,
        output_attentions=False
    )
  hidden_states_l3 = out_l3[0]
  print(hidden_states_l3.shape)


# The outputs of the bertlayer_l3_reduction can now run through a usual BertLayer for 3 times

torch.Size([2, 24, 384])
torch.Size([2, 12, 384])
torch.Size([2, 12])
torch.Size([2, 1, 1, 12])
torch.Size([2, 12, 192])


In [None]:
# The outputs of the bertlayer_l3_reduction can now run through a usual BertLayer for 3 times

config_lowres_encoder = copy.deepcopy(basemod.config)
config_lowres_encoder.hidden_size = config_lowres_encoder.hidden_size//4
config_lowres_encoder.num_hidden_layers = 3
print(config_lowres_encoder)

# The outputs of the bertlayer_l3_reduction can now run through a usual BertLayer for 3 times
encoder_lowres = BertEncoder(config_lowres_encoder)

RobertaConfig {
  "_name_or_path": "distilroberta-base",
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 192,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 3,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.29.2",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 50265
}



In [None]:
out_encoder_lowres = encoder_lowres(
    hidden_states=hidden_states_l3,
    attention_mask=extended_attention_mask_l2_reduced,
    head_mask = None,
    return_dict=True,
)
hidden_states_lowres = out_encoder_lowres[0]
print(hidden_states_lowres.shape)

torch.Size([2, 12, 192])


In [None]:
## Upresolution Layer: up-resolution from dim-3 to dim-2 is as follows:
# hs_l3 -> upsampled sequence-length as hs-l2
# -> could have another attention-based mechanism that expands dimension of hs-l2

class InterpolateCombo(nn.Module):
    """there could also be an attentive way to do this"""
    def __init__(self, scale_factor=2, dropout=0.05, alpha=0.667):
        """Arguments:
        :param scaler_factor: float, multiple of up-scaling
        :param dropout: float, dropout proportion
        :param alpha: float, mixture weight between nearest-neighbor vs linear-interpolation
        """
        super(InterpolateCombo, self).__init__()
        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.dropout = nn.Dropout(dropout)
        self.a = alpha
        
    def forward(self, x):
        x_trans = x.transpose(-2,-1)
        z = self.a*self.interp(x_trans, mode='nearest',scale_factor=self.scale_factor) + (1-self.a)*self.interp(x_trans, mode='linear',scale_factor=self.scale_factor)
        z = self.dropout(z)
        return z.transpose(-2,-1)

#hidden_states_upscaled_3to2_nearest = nn.functional.interpolate(hidden_states_rowres.transpose(-2,-1), scale_factor=2, mode='nearest').transpose(-2,-1)
#hidden_states_upscaled_3to2_linear = nn.functional.interpolate(hidden_states_rowres.transpose(-2,-1), scale_factor=2, mode='linear').transpose(-2,-1)

upscaler_x2 = InterpolateCombo(scale_factor=2)

In [None]:
hidden_states_upscaled3to2 = upscaler_x2(hidden_states_lowres)


In [None]:
## BertAttentiveIntegrator

class BertCrossAttention(nn.Module):
    def __init__(
        self, 
        config, 
        hidden_size,
        hidden_size_query,
        hidden_size_keyvalue=None,
        position_embedding_type=None
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.hidden_size_query = hidden_size_query
        if hidden_size_keyvalue is None:
            hidden_size_keyvalue = hidden_size
        self.hidden_size_keyvalue = hidden_size_keyvalue
        if self.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(self.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(self.hidden_size_query, self.all_head_size)
        self.key = nn.Linear(self.hidden_size_keyvalue, self.all_head_size)
        self.value = nn.Linear(self.hidden_size_keyvalue, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        query_hidden_states: Optional[torch.FloatTensor] = None,
        query_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        mixed_query_layer = self.query(query_hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)
        
        use_cache = past_key_value is not None
        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if use_cache:
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                    -1, 1
                )
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs

In [None]:
bertlayer_l3_to_l2_crossattn = BertCrossAttention(
        config=basemod.config, 
        hidden_size=silo_dimensions[1],
        hidden_size_query=silo_dimensions[2],
        position_embedding_type=None
    )

In [None]:
print(hidden_states_upscaled3to2.shape)
print(hidden_states_l2.shape)
print(attention_mask_l1_reduced.shape)
print(extended_attention_mask_l1_reduced.shape)

torch.Size([2, 24, 192])
torch.Size([2, 24, 384])
torch.Size([2, 24])
torch.Size([2, 1, 1, 24])


In [None]:
out_l2_postencode = bertlayer_l3_to_l2_crossattn(
    hidden_states = hidden_states_l2,
    attention_mask = extended_attention_mask_l1_reduced,
    head_mask = None,
    query_hidden_states = hidden_states_upscaled3to2,
    query_attention_mask = attention_mask_l1_reduced
)
hidden_states_l2_postencode = out_l2_postencode[0]
print(hidden_states_l2_postencode.shape)
assert hidden_states_l2_postencode.shape == hidden_states_l2.shape

torch.Size([2, 24, 384])


In [None]:
print(basemod.config.hidden_size)
print(basemod.config.intermediate_size)
print(basemod.config.intermediate_size/basemod.config.hidden_size)

768
3072
4.0


In [None]:
# how does bert actually work?
"""
input = x

BertLayer:
- BertAttention
--- x2 = BertSelfAttention(x)
--- x3 = BertSelfOutput(x2,x) -> lnorm(drop(f(x2)) + x)
- BertIntermediate (expension:  4*hidden_size)
--- x4_ex = activation(f(x3)) # expansion (4*)
- BertOutput
--- x5 = lnorm(drop(f(x4_ex)) + x3 )


inputs = x_l2, x_l3_up

BertIntegrativeLayer:
- x2 = BertCrossAttention(k,v=x_l2, q=x_l3_up)
- x3 = lnorm(drop(f(x2)) + x_l2)
- x4_ex = activation( f(cat(x3, x_l3_up))  )
- x5 = lnorm(drop(f(x4_ex)) + x3)
"""


class BertIntegrativeLayer(nn.Module):
    """Vanilla Bert Layer, but integrates other hiddens states from a parallel transformers stack typically low-re"""
    def __init__(
            self, 
            config, 
            hidden_size, 
            hidden_size_query,
            intermediate_size=None
        ):
        super().__init__()
        #self.chunk_size_feed_forward = config.chunk_size_feed_forward
        #self.seq_len_dim = 1
        self.cat = torch.cat
        if intermediate_size is None:
            intermediate_size = int(4*hidden_size)
        self.intermediate_size = intermediate_size
        self.hidden_size = hidden_size
        self.hidden_size_query = hidden_size_query
        self.hidden_size_concat = int(hidden_size + hidden_size_query)

        # cross attention between (low-res) query and hidden layers below
        self.attention = BertCrossAttention(
            config, 
            hidden_size,
            hidden_size_query,
            position_embedding_type="absolute"
        )
        self.is_decoder = config.is_decoder
        #self.intermediate = BertIntermediate(config)
        #self.output = BertOutput(config)
        #- x2 = BertCrossAttention(k,v=x_l2, q=x_l3_up)
        #- x3 = lnorm(drop(f(x2)) + x_l2)
        #- x4_ex = activation( f(cat(x3, x_l3_up))  )
        #- x5 = lnorm(drop(f(x4_ex)) + x3)

        # corresponds to BertAttention SelfOutput
        self.output_attn = nn.Linear(self.hidden_size, self.hidden_size)
        self.lnorm_attn = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
        self.dropout_attn = nn.Dropout(config.hidden_dropout_prob)

        # corresponds to BertIntermediate
        self.intermediate = nn.Linear(self.hidden_size_concat, self.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

        # corresponds to BertOutput
        self.output_intm = nn.Linear(self.intermediate_size, self.hidden_size)
        self.lnorm_intm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
        self.dropout_intm = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        query_hidden_states: Optional[torch.FloatTensor] = None,
        query_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None

        # cross attn between hiddens states and (low-res) query vector
        cross_attn_outputs = self.attention(
            hidden_states = hidden_states,
            attention_mask = attention_mask,
            head_mask = head_mask,
            query_hidden_states = query_hidden_states,
            query_attention_mask = query_attention_mask
        )
        cross_hidden_states = cross_attn_outputs[0]

        # first Add+Norm skip connection (BertSelfOutput)
        cross_hidden_states = self.dropout_attn(self.output_attn(cross_hidden_states))
        hidden_states = self.lnorm_attn(cross_hidden_states + hidden_states)

        # intermediate expension
        intermediate_states = self.intermediate_act_fn(self.intermediate(
            self.cat((hidden_states, query_hidden_states),axis=2)
        ))
        assert intermediate_states.shape[0]==hidden_states.shape[0]
        assert intermediate_states.shape[1]==hidden_states.shape[1]

        # BertOutput
        intermediate_states = self.dropout_intm(self.output_intm(intermediate_states))
        out_states = self.lnorm_intm(intermediate_states + hidden_states)

        #- x2 = BertCrossAttention(k,v=x_l2, q=x_l3_up)
        #- x3 = lnorm(drop(f(x2)) + x_l2)
        #- x4_ex = activation( f(cat(x3, x_l3_up))  )
        #- x5 = lnorm(drop(f(x4_ex)) + x3)
        return out_states


In [None]:

# from low-res to mid-res
bert_integrative_layer_midres = BertIntegrativeLayer(
    basemod.config, 
    hidden_size=silo_dimensions[1], 
    hidden_size_query=silo_dimensions[2],
    intermediate_size=silo_dimensions[1]*4,
)

# from mid-res to high-res
bert_integrative_layer_hires = BertIntegrativeLayer(
    basemod.config, 
    hidden_size=silo_dimensions[0], 
    hidden_size_query=reintegration_dim,
    intermediate_size=silo_dimensions[0]*4,
)

In [None]:
hidden_states_midres = bert_integrative_layer_midres(
    hidden_states = hidden_states_l2,
    attention_mask = extended_attention_mask_l1_reduced,
    head_mask = None,
    query_hidden_states = hidden_states_upscaled3to2,
    query_attention_mask = attention_mask_l1_reduced
)
print(hidden_states_midres.shape)
assert hidden_states_midres.shape == hidden_states_l2.shape

torch.Size([2, 24, 384])


In [None]:
# upscale the l2 and l3 to the full dimension
upscaler_x4 = InterpolateCombo(scale_factor=4)
hidden_states_upscaled3to1 = upscaler_x4(hidden_states_lowres)
hidden_states_upscaled2to1 = upscaler_x2(hidden_states_midres)

hidden_states_upscaled = torch.cat(
    (hidden_states_upscaled2to1, hidden_states_upscaled3to1),
    axis=2)

print(hidden_states_upscaled.shape)

torch.Size([2, 48, 576])


In [None]:
# final layer to bring it up to full dimension
hidden_states_hires = bert_integrative_layer_hires(
    hidden_states = hidden_states_l1,
    attention_mask = extended_attention_mask,
    head_mask = None,
    query_hidden_states = hidden_states_upscaled,
    query_attention_mask = extended_attention_mask
)
print(hidden_states_hires.shape)
assert hidden_states_hires.shape == hidden_states_l1.shape

torch.Size([2, 48, 768])


In [None]:
hidden_states_hires.shape

torch.Size([2, 48, 768])

In [None]:
attention_mask_l1_reduced.shape

torch.Size([2, 24])

### The Reduce and Integrate layer:
- this is like a Transformer block, but:
- does dimension reduction along sequence and embedding-dim
- includes a skip connection from previous hidden-states of the same dimension

In [None]:



# this is the layer that just does cross-attention between a seq-reduced query and full-size value and key


"""
input = x

BertLayer:
- BertAttention
--- x2 = BertSelfAttention(x)
--- x3 = BertSelfOutput(x2,x) -> lnorm(drop(f(x2)) + x)
- BertIntermediate (expension:  4*hidden_size)
--- x4_ex = activation(f(x3)) # expansion (4*)
- BertOutput
--- x5 = lnorm(drop(f(x4_ex)) + x3 )


inputs = x_l2, x_l3_up

BertIntegrativeLayer:
- x2 = BertCrossAttention(k,v=x_l2, q=x_l3_up)
- x3 = lnorm(drop(f(x2)) + x_l2)
- x4_ex = activation( f(cat(x3, x_l3_up))  )
- x5 = lnorm(drop(f(x4_ex)) + x3)


BertReduceAddIntegrativeLayer
inputs = x_l1, x_l1_reduced, x_l2_prev
- x2 = BertCrossAttention(k,v=x_l1, q= cat(x_l1_reduced, x_l2_prev) ) -notice three inputs
- x3 = lnorm(drop(f(x2)) + x_l2_prev)
- x4_ex = activation( f(cat(x3, x_l1_reduced))  )
- x5 = lnorm(drop(f(x4_ex)) + x3)
"""


class BertReduceAddIntegrativeLayer(nn.Module):
    """Bert Layer that does dimenion reduction along embedding-dimenion and integrations a skip connection"""
    def __init__(
            self, 
            config, 
            hidden_size,
            hidden_size_input=None,
            hidden_size_query=None,
            intermediate_size=None,
            dim_reduction=2,
            do_concat_hidden_and_query = True
        ):
        super().__init__()
        #self.chunk_size_feed_forward = config.chunk_size_feed_forward
        #self.seq_len_dim = 1
        self.cat = torch.cat
        self.do_concat_hidden_and_query = do_concat_hidden_and_query
        assert bool(do_concat_hidden_and_query), 'not implemented: concatenation of query and hidden-states must happen'
        self.hidden_size = hidden_size
        if dim_reduction is None:
            dim_reduction = 2
        self.dim_reduction = dim_reduction
        if intermediate_size is None:
            intermediate_size = int(4*hidden_size)
        self.intermediate_size = intermediate_size
        if hidden_size_input is None:
            hidden_size_input = hidden_size
        self.hidden_size_input = hidden_size_input
        if hidden_size_query is None:
            hidden_size_query = hidden_size_input
        self.hidden_size_query = hidden_size_query + do_concat_hidden_and_query*hidden_size
        self.hidden_size_concat = int(hidden_size + hidden_size_input)

        # cross attention between (low-res) query and hidden layers below
        self.attention = BertSelfAttnDimensionReduction( 
            config, 
            hidden_size_input=self.hidden_size_input, 
            hidden_size_query = self.hidden_size_query,
            position_embedding_type="absolute", 
            dim_reduction = self.dim_reduction
        )
        self.is_decoder = config.is_decoder
        #inputs = x_l1, x_l1_reduced, x_l2_prev
        #- x2 = BertCrossAttention(k,v=x_l1, q= cat(x_l1_reduced, x_l2_prev) ) -notice three inputs
        #- x3 = lnorm(drop(f(x2)) + x_l2_prev)
        #- x4_ex = activation( f(cat(x3, x_l1_reduced))  )
        #- x5 = lnorm(drop(f(x4_ex)) + x3)

        # corresponds to BertAttention SelfOutput
        self.output_attn = nn.Linear(self.hidden_size, self.hidden_size)
        self.lnorm_attn = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
        self.dropout_attn = nn.Dropout(config.hidden_dropout_prob)

        # corresponds to BertIntermediate
        self.intermediate = nn.Linear(self.hidden_size_concat, self.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

        # corresponds to BertOutput
        self.output_intm = nn.Linear(self.intermediate_size, self.hidden_size)
        self.lnorm_intm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
        self.dropout_intm = nn.Dropout(config.hidden_dropout_prob)
    
    def forward(
        self,
        inputs: torch.Tensor, # higher-resolution inputs for key and values (long sequence dimension)
        hidden_states: torch.Tensor, # previous hidden-states for skip connection (short squence-dim, low-res)
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        query_hidden_states: torch.FloatTensor = None, # hidden-states for query (short squence-dim, low-res)
        query_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None

        if self.do_concat_hidden_and_query:
            query_hidden_states_plus = torch.cat((query_hidden_states, hidden_states),axis=2)
        # cross attn between (low-res) query vector and (high-res) key-values
        cross_attn_outputs = self.attention(
            query_hidden_states_plus, # query (short seq-dim, high-res)
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states = inputs, # for key/value (longer sequence dimension, high-res)
            past_key_value=past_key_value,
            output_attentions=output_attentions,
        )
        cross_hidden_states = cross_attn_outputs[0]

        # first Add+Norm skip connection (BertSelfOutput)
        cross_hidden_states = self.dropout_attn(self.output_attn(cross_hidden_states))
        hidden_states = self.lnorm_attn(cross_hidden_states + hidden_states)

        # intermediate expension
        intermediate_states = self.intermediate_act_fn(self.intermediate(
            self.cat((hidden_states, query_hidden_states),axis=2)
        ))
        assert intermediate_states.shape[0]==hidden_states.shape[0]
        assert intermediate_states.shape[1]==hidden_states.shape[1]

        # BertOutput
        intermediate_states = self.dropout_intm(self.output_intm(intermediate_states))
        out_states = self.lnorm_intm(intermediate_states + hidden_states)

        #inputs = x_l1, x_l1_reduced, x_l2_prev
        #- x2 = BertCrossAttention(k,v=x_l1, q= cat(x_l1_reduced, x_l2_prev) ) -notice three inputs
        #- x3 = lnorm(drop(f(x2)) + x_l2_prev)
        #- x4_ex = activation( f(cat(x3, x_l1_reduced))  )
        #- x5 = lnorm(drop(f(x4_ex)) + x3)
        return out_states


In [None]:
# initialize the mid-resolution BertReduceAndIntegrate layer
bert_reduce_add_integrate_midres = BertReduceAddIntegrativeLayer(
    config, 
    hidden_size = silo_dimensions[1], # size of mid-res
    hidden_size_input=silo_dimensions[0],
    hidden_size_query=silo_dimensions[0],
    intermediate_size=silo_dimensions[1]*3,
    dim_reduction=2,
    do_concat_hidden_and_query = True
)

bert_reduce_add_integrate_lowres = BertReduceAddIntegrativeLayer(
    config, 
    hidden_size = silo_dimensions[2], # size of mid-res
    hidden_size_input=silo_dimensions[1],
    hidden_size_query=silo_dimensions[1],
    intermediate_size=silo_dimensions[2]*3,
    dim_reduction=2,
    do_concat_hidden_and_query = True
)

In [None]:
# Reduce sequence-dim from l1->l2, and from high-res->mid-res
hidden_states_hires_reduced = maxpool_l2(hidden_states_hires)
assert hidden_states_hires_reduced.shape[1] == hidden_states_midres.shape[1] # reduced-seq-dim should be same as mid-res hidden-states
print(hidden_states_midres.shape)
hidden_states_midres = bert_reduce_add_integrate_midres(
    inputs = hidden_states_hires, # from highres outputs previous layer (key, values)
    hidden_states = hidden_states_midres, # previous hidden-states for skip connection (short squence-dim, low-res)
    attention_mask = extended_attention_mask_l1_reduced,
    head_mask=None,
    query_hidden_states = hidden_states_hires_reduced # reduced version of high-res inputs (reduced along sequence dimenion)
)
print(hidden_states_midres.shape)

torch.Size([2, 24, 384])
torch.Size([2, 24, 384])


In [None]:
# Reduce sequence-dim from l1->l2, and from high-res->mid-res
hidden_states_midres_reduced = maxpool_l2(hidden_states_midres)
assert hidden_states_midres_reduced.shape[1] == hidden_states_lowres.shape[1] # reduced-seq-dim should be same as mid-res hidden-states
print(hidden_states_midres_reduced.shape)

if True:
  print(hidden_states_lowres.shape)
  hidden_states_lowres = bert_reduce_add_integrate_lowres(
      inputs = hidden_states_midres, # from highres outputs previous layer (key, values)
      hidden_states = hidden_states_lowres, # previous hidden-states for skip connection (short squence-dim, low-res)
      attention_mask = extended_attention_mask_l2_reduced,
      head_mask=None,
      query_hidden_states = hidden_states_midres_reduced # reduced version of high-res inputs (reduced along sequence dimenion)
  )
  print(hidden_states_lowres.shape)

torch.Size([2, 12, 384])
torch.Size([2, 12, 192])
torch.Size([2, 12, 192])


In [None]:
try:
    from transformers.modeling_utiles import get_extended_attention_mask
except:
    def get_extended_attention_mask(self, attention_mask: torch.Tensor, input_shape: Tuple[int], device: device) -> torch.Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model.
            device: (:obj:`torch.device`):
                The device of the input to the model.

        Returns:
            :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if self.config.is_decoder:
                batch_size, seq_length = input_shape
                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                # in case past_key_values are used we need to add a prefix ones mask to the causal mask
                # causal and attention masks must have same type with pytorch version < 1.3
                causal_mask = causal_mask.to(attention_mask.dtype)

                if causal_mask.shape[1] < attention_mask.shape[1]:
                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
                    causal_mask = torch.cat(
                        [
                            torch.ones(
                                (batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype
                            ),
                            causal_mask,
                        ],
                        axis=-1,
                    )

                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

### Base-Layer nn.Module

In [None]:
from transformers import AutoModel, AutoTokenizer, AutoConfig
import torch
from torch import nn
from transformers.models.bert.modeling_bert import BertEncoder
from transformers.activations import ACT2FN
from typing import List, Optional, Tuple, Union

def make_config(
    modelstring = "distilroberta-base",
    num_transformer_stacks = 2, # number of transformer stacks
    scale_ratio2 = 0.5, # reduce sequence-length by X, from high-res to mid-res
    scale_ratio3 = 0.25, # reduce sequence-length by Y, from high-res to low-res
    multipler_intermediate2 = 4.0, # intermeidate size is a multiple of hidden size
    multipler_intermediate3 = 4.0, # intermeidate size is a multiple of hidden size
    num_layers_l2 = 1, # mid-res encoder
    num_layers_l3 = 3, # low-res encoder
    dropout_scaling = 0.05, # dropout when performing downscaling from one-sequence length to next
    use_cheap_integrator_for_stacks = []
):
    #if True:
    #modelstring = "distilroberta-base"
    #scale_ratio2 = 0.5
    #scale_ratio3 = 0.25
    #scale_intermediate2 = 4
    #scale_intermediate3 = 4
    base_config = AutoConfig.from_pretrained(modelstring)
    config_l2 = copy.deepcopy(base_config)
    config_l3 = copy.deepcopy(base_config)
    setattr(base_config,'model_string', modelstring)
    setattr(base_config,'num_transformer_stacks',num_transformer_stacks)
    setattr(base_config,'num_layers_l2', num_layers_l2)
    setattr(base_config,'num_layers_l3', num_layers_l3)
    setattr(base_config,'scale_ratio2', scale_ratio2)
    setattr(base_config,'scale_ratio3', scale_ratio3)
    setattr(base_config,'scale_factor2', int(1/base_config.scale_ratio2))
    setattr(base_config,'scale_factor3', int(1/base_config.scale_ratio3*base_config.scale_ratio2))
    setattr(base_config,"hidden_size_l2", int(base_config.hidden_size * scale_ratio2))
    setattr(base_config,"hidden_size_l3", int(base_config.hidden_size * scale_ratio3))
    setattr(base_config,"intermediate_size_l1", int(base_config.hidden_size_l2*multipler_intermediate2))
    setattr(base_config,"intermediate_size_l2", int(base_config.hidden_size_l3*multipler_intermediate3))
    setattr(base_config,"query_size1", base_config.hidden_size_l2 + base_config.hidden_size_l3)
    setattr(base_config,"query_size2", base_config.hidden_size_l3)
    setattr(base_config,"dropout_scaling", dropout_scaling)
    setattr(base_config,"use_cheap_integrator_for_stacks", use_cheap_integrator_for_stacks)

    # make the configuration for the l2 mid-res encoder
    config_l2.hidden_size = base_config.hidden_size_l2
    config_l2.num_hidden_layers = num_layers_l2
    setattr(base_config, 'config_l2', config_l2)

    # make the configuration for the l3 encoder
    config_l3.hidden_size = base_config.hidden_size_l3
    config_l3.num_hidden_layers = num_layers_l3
    setattr(base_config, 'config_l3', config_l3)
    return base_config


def initialize_baselayers(config, basemod = None, tokenizer=None, stack_id=0):
    """Initializes the embeddings and first stack of layers for the Anathem transformers"""
    # initialize the basemodel
    if basemod is None:
        basemod = AutoModel.from_pretrained(config.model_string)
    if tokenizer is None:
        # download pretrained tokenizer
        tokenizer = AutoTokenizer.from_pretrained(config.model_string)
  
    device = basemod.device
    setattr(config, 'device', device)

    # get basemodel's embeddings
    layer_embedding = copy.deepcopy(basemod._modules['embeddings'])

    # get basemodel's first transformer block
    layer_basetransformer = copy.deepcopy(basemod._modules['encoder']._modules['layer']._modules['0'])

    # initialize the maxpooling downsamplers
    maxpool = nn.Sequential(
        nn.Dropout(config.dropout_scaling),
        nn.MaxPool2d((2,1), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True)
    )
    # pooling the attention has no dropout
    maxpool_attn = nn.MaxPool1d((2), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True)

    # initialize downsampling attention layers
    bert_reducer_l2 = BertSelfAttnDimensionReduction(
        config=config,
        hidden_size_input=config.hidden_size, 
        position_embedding_type=config.position_embedding_type,
        dim_reduction = config.scale_factor2
    )
    # 1/4 hidden size
    bert_reducer_l3 = BertSelfAttnDimensionReduction(
        config=config,
        hidden_size_input=config.hidden_size_l2, 
        position_embedding_type=config.position_embedding_type,
        dim_reduction = config.scale_factor3
    )

    # initialize the mid-resolution BertEncoder
    bert_encoder_midres = BertEncoder(config.config_l2)
    # initialize the low-resolution BertEncoder
    bert_encoder_lowres = BertEncoder(config.config_l3)

    # initailize the upscalers
    upscaler_x2 = InterpolateCombo(scale_factor=config.scale_factor3, dropout=config.dropout_scaling)
    upscaler_x4 = InterpolateCombo(scale_factor=int(1/config.scale_ratio3), dropout=config.dropout_scaling)

    # initialize the BertIntegrative Layers: low res to mid res
    bert_integrative_layer_2 = BertIntegrativeLayer(
        config, 
        hidden_size=config.hidden_size_l2,
        hidden_size_query=config.hidden_size_l3,
        intermediate_size=config.intermediate_size_l2
    )

    do_cheap_integrator = (stack_id in config.use_cheap_integrator_for_stacks)
    # from mid-res to high-res
    if not do_cheap_integrator:
        # cheap (non-transformer) method to integrate high- and mid-res hidden states
        bert_integrative_layer_1 = CheapMLPIntegrativeLayer(
            config, 
            hidden_size=config.hidden_size,
            hidden_size_query=config.query_size1,
            intermediate_size=config.intermediate_size_l1
        )
    else:
        # full Transformer layer as mid-to-highres upscaling
        BertIntegrativeLayer(
            config, 
            hidden_size=config.hidden_size,
            hidden_size_query=config.query_size1,
            intermediate_size=config.intermediate_size_l1//2
        )

    return (
        tokenizer,
        basemod,
        layer_embedding,
        layer_basetransformer,
        maxpool,
        maxpool_attn,
        bert_reducer_l2,
        bert_reducer_l3,
        bert_encoder_midres,
        bert_encoder_lowres,
        upscaler_x2,
        upscaler_x4,
        bert_integrative_layer_2,
        bert_integrative_layer_1
    )

def initialize_midlayers(config, basemod=None, tokenizer=None):
    """Initializes all the intermediate layers for the Anathem transformers"""
    # initialize the maxpooling downsamplers
    maxpool = nn.Sequential(
        nn.Dropout(config.dropout_scaling),
        nn.MaxPool2d((2,1), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True)
    )
    # pooling the attention has no dropout
    maxpool_attn = nn.MaxPool1d((2), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True)

    # initialize bert attentive downsampling and skipconnection (1/2 embedding dim)
    bert_reduceintegrator_l2 = BertReduceAddIntegrativeLayer(
        config, 
        config.hidden_size_l2, # size of mid-res
        hidden_size_input=config.hidden_size, # size full-resolution
        hidden_size_query=config.hidden_size, # size full-resolution
        intermediate_size=config.intermediate_size_l1, # BertIntermediate dimension (expansion *4 the hiddensize)
        dim_reduction=config.scale_factor2, # reduce embedding dimension by factor of 2
        do_concat_hidden_and_query = True
    )

    # 1/4 the size
    bert_reduceintegrator_l3 = BertReduceAddIntegrativeLayer(
        config, 
        config.hidden_size_l3, # size of mid-res
        hidden_size_input=config.hidden_size_l2, # size full-resolution
        hidden_size_query=config.hidden_size_l2, # size full-resolution
        intermediate_size=config.intermediate_size_l2, # BertIntermediate dimension
        dim_reduction=config.scale_factor3, # reduce embedding dimension by factor of 2
        do_concat_hidden_and_query = True
    )

    # initialize the low-resolution BertEncoder
    bert_encoder_midres = BertEncoder(config.config_l2)
    bert_encoder_lowres = BertEncoder(config.config_l3)

    # initailize the upscalers
    upscaler_x2 = InterpolateCombo(scale_factor=config.scale_factor3, dropout=config.dropout_scaling)
    upscaler_x4 = InterpolateCombo(scale_factor=int(1/config.scale_ratio3), dropout=config.dropout_scaling)

    # initialize the BertIntegrative Layers: low res to mid res
    bert_integrative_layer_2 = BertIntegrativeLayer(
        config, 
        hidden_size=config.hidden_size_l2,
        hidden_size_query=config.hidden_size_l3,
        intermediate_size=config.intermediate_size_l2
    )
    
    # from mid-res to high-res
    bert_integrative_layer_1 = BertIntegrativeLayer(
        config, 
        hidden_size=config.hidden_size,
        hidden_size_query=config.query_size1,
        intermediate_size=config.intermediate_size_l1
    )

    return (
        maxpool,
        maxpool_attn,
        bert_reduceintegrator_l2,
        bert_reduceintegrator_l3,
        bert_encoder_midres,
        bert_encoder_lowres,
        upscaler_x2,
        upscaler_x4,
        bert_integrative_layer_2,
        bert_integrative_layer_1
    )


class AnathemBaseModule(nn.Module):
    """First Sstack of layers with embeddings, that go full circle form high-res to low-res back to high res"""
    def __init__(
            self, 
            config,
            basemod=None,
            tokenizer=None,
            past_key_values_length = None,
            device = None
        ):
        super().__init__()
        self.config = config

        # initalize the layers
        (
            tokenizer, basemod, 
            layer_embedding,
            layer_basetransformer,
            maxpool,
            maxpool_attn,
            bert_reducer_l2,
            bert_reducer_l3,
            bert_encoder_midres,
            bert_encoder_lowres,
            upscaler_x2,
            upscaler_x4,
            bert_integrative_layer_2,
            bert_integrative_layer_1
        ) = initialize_baselayers(config, basemod, tokenizer)

        self.get_extended_attention_mask = basemod.get_extended_attention_mask
        self.embedding = layer_embedding
        self.layer_basetransformer = layer_basetransformer
        self.maxpool = maxpool
        self.maxpool_attn = maxpool_attn
        self.bert_reducer_l2 = bert_reducer_l2
        self.bert_reducer_l3 = bert_reducer_l3
        self.bert_encoder_midres = bert_encoder_midres
        self.bert_encoder_lowres = bert_encoder_lowres
        self.upscaler_x2 = upscaler_x2
        self.upscaler_x4 = upscaler_x4
        self.bert_integrative_layer_2 = bert_integrative_layer_2
        self.bert_integrative_layer_1 = bert_integrative_layer_1
        if device is None:
            self.to(basemod.device)
            #print(self.device)
            self.device = basemod.device
        else:
            self.to(device)
            self.device = device
    
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = False
    ):
        input_shape = input_ids
        past_key_values_length =0 if past_key_values is None else len(past_key_values)

        # extend attention mask
        extended_attention_mask_l1 = self.get_extended_attention_mask(attention_mask, input_shape, self.device)
        # downsample the attention mask to l2 dimension
        attention_mask_l2 = self.maxpool_attn(attention_mask.float())
        extended_attention_mask_l2 = self.get_extended_attention_mask(attention_mask_l2,attention_mask_l2.shape, self.device)
        # downsample the attention mask to l3 dimension
        attention_mask_l3 = self.maxpool_attn(attention_mask_l2.float())
        extended_attention_mask_l3 = self.get_extended_attention_mask(attention_mask_l3,attention_mask_l3.shape, self.device)
        
        # embed
        embedding_output = self.embedding(
            input_ids = input_ids,
            position_ids = position_ids,
            token_type_ids = token_type_ids,
            #input_embeds=None,
            past_key_values_length = past_key_values_length
        )

        # first transformer block (vanilla transformer)
        out_l1 = self.layer_basetransformer(
            hidden_states = embedding_output,
            attention_mask = extended_attention_mask_l1,
            head_mask=head_mask,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            output_attentions=output_attentions
        )
        hidden_states_l1 = out_l1[0]

        # downsample to sequence 1 to length sequence 2
        hiddens_states_l1_reduced = self.maxpool(hidden_states_l1)

        # reduce dimenion on sequence 2
        out_l2 = self.bert_reducer_l2(
            hidden_states = hiddens_states_l1_reduced,
            attention_mask = extended_attention_mask_l2,
            head_mask=head_mask,
            encoder_hidden_states = hidden_states_l1,
            encoder_attention_mask= extended_attention_mask_l1,
            past_key_value=past_key_values,
            output_attentions=output_attentions,
        )
        hidden_states_l2 = out_l2[0]

        # Vanilla transformers block at mid-resolution (1/2 seq-length)
        out_encoder = self.bert_encoder_midres(
            hidden_states=hidden_states_l2,
            attention_mask=extended_attention_mask_l2,
            head_mask = head_mask,
            return_dict=return_dict
        )
        hidden_states_l2 = out_encoder[0]
        
        # reduce sequence length (1/4 seq-length)
        hiddens_states_l2_reduced = self.maxpool(hidden_states_l2)

        # reduce dimenion on sequence 2
        out_l3 = self.bert_reducer_l3(
            hidden_states = hiddens_states_l2_reduced,
            attention_mask = extended_attention_mask_l3,
            head_mask=head_mask,
            encoder_hidden_states = hidden_states_l2,
            encoder_attention_mask= extended_attention_mask_l2,
            past_key_value=past_key_values,
            output_attentions=output_attentions,
        )
        hidden_states_l3 = out_l3[0]

        #print(hidden_states_l3.shape)
        #print(extended_attention_mask_l3.shape)
        # BertEncoder at low-res
        out_encoder = self.bert_encoder_lowres(
            hidden_states=hidden_states_l3,
            attention_mask=extended_attention_mask_l3,
            head_mask = head_mask,
            return_dict=return_dict
        )
        hidden_states_l3 = out_encoder[0]
        
        # upscaling: l3 to l2
        hidden_states_upscaled3to2 = self.upscaler_x2(hidden_states_l3)

        # integrate sequence-2 and upscaled sequence-3
        hidden_states_l2 = self.bert_integrative_layer_2(
            hidden_states = hidden_states_l2,
            attention_mask = extended_attention_mask_l2,
            head_mask = head_mask,
            query_hidden_states = hidden_states_upscaled3to2,
            query_attention_mask = attention_mask_l2
        )
        
        # upscaling: l3/l2 to l1 sequence length
        hidden_states_upscaled3to1 = self.upscaler_x4(hidden_states_l3)
        hidden_states_upscaled2to1 = self.upscaler_x2(hidden_states_l2)
        hidden_states_upscaled = torch.cat((
            hidden_states_upscaled2to1, hidden_states_upscaled3to1
        ),axis=2)        

        # integrate low-resolution information back to original dimension
        hidden_states_l1 = self.bert_integrative_layer_1(
            hidden_states = hidden_states_l1,
            attention_mask = extended_attention_mask_l1,
            head_mask = head_mask,
            query_hidden_states = hidden_states_upscaled,
            query_attention_mask = extended_attention_mask_l1
        )
        if not return_dict:
            return (
                (hidden_states_l1, hidden_states_l2, hidden_states_l3),
                (extended_attention_mask_l1, extended_attention_mask_l2, extended_attention_mask_l3)
            )
        return {
            "hidden_states": (hidden_states_l1, hidden_states_l2, hidden_states_l3),
            "attention":(extended_attention_mask_l1, extended_attention_mask_l2, extended_attention_mask_l3)
        }


class AnathemMidModule(nn.Module):
    """Stack of layers that go full circle form high-res to low-res back to high res"""
    def __init__(
            self, 
            config,
            basemod=None,
            tokenizer=None,
            past_key_values_length = None,
            device=None,
        ):
        super().__init__()
        self.config = config

        # initalize the layers
        (
            maxpool,
            maxpool_attn,
            bert_reducerintegrator_l2,
            bert_reducerintegrator_l3,
            bert_encoder_midres,
            bert_encoder_lowres,
            upscaler_x2,
            upscaler_x4,
            bert_integrative_layer_2,
            bert_integrative_layer_1
        ) = initialize_midlayers(config, basemod, tokenizer)

        self.get_extended_attention_mask = get_extended_attention_mask
        self.maxpool = maxpool
        self.maxpool_attn = maxpool_attn
        self.bert_reducerintegrator_l2 = bert_reducerintegrator_l2
        self.bert_reducerintegrator_l3 = bert_reducerintegrator_l3
        self.bert_encoder_midres = bert_encoder_midres
        self.bert_encoder_lowres = bert_encoder_lowres
        self.upscaler_x2 = upscaler_x2
        self.upscaler_x4 = upscaler_x4
        self.bert_integrative_layer_2 = bert_integrative_layer_2
        self.bert_integrative_layer_1 = bert_integrative_layer_1
        if device is None:
            self.to(basemod.device)
            #print(self.device)
            self.device = basemod.device
        else:
            self.to(device)
            self.device = device
    
    def forward(
        self,
        hidden_states_highres: torch.Tensor,
        hidden_states_midres: torch.Tensor,
        hidden_states_lowres: torch.Tensor,
        attention_mask: Optional[List[torch.FloatTensor]] = None,
        extended_attention_mask_highres: Optional[List[torch.FloatTensor]] = None,
        extended_attention_mask_midres: Optional[List[torch.FloatTensor]] = None,
        extended_attention_mask_lowres: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = False
    ):
        input_shape = hidden_states_highres.shape[:2]
        past_key_values_length =0 if past_key_values is None else len(past_key_values)

        # extend attention mask
        if extended_attention_mask_highres is None:
            extended_attention_mask_highres = self.get_extended_attention_mask(attention_mask, input_shape, self.device)
        if extended_attention_mask_midres is None:
            attention_mask_midres = self.maxpool_attn(attention_mask.float())
            extended_attention_mask_midres = self.get_extended_attention_mask(attention_mask_midres,attention_mask_midres.shape, self.device)
        if extended_attention_mask_lowres is None:
           attention_mask_lowres = self.maxpool_attn(attention_mask_midres.float())
           extended_attention_mask_lowres = self.get_extended_attention_mask(attention_mask_lowres,attention_mask_lowres.shape, self.device)

        # downsample to sequence 1 to length sequence 2
        hiddens_states_l1_reduced = self.maxpool(hidden_states_highres)

        # reduce dimenion on sequence 2
        hidden_states_l2 = self.bert_reducerintegrator_l2(
            inputs = hidden_states_highres, # from highres outputs previous layer (key, values)
            hidden_states = hidden_states_midres, # previous hidden-states for skip connection (short squence-dim, low-res)
            attention_mask = extended_attention_mask_midres,
            head_mask=None,
            query_hidden_states = hiddens_states_l1_reduced
        )

        # Vanilla transformers at mid-resolution (1/2 sequence-length)
        out_encoder = self.bert_encoder_midres(
            hidden_states=hidden_states_l2,
            attention_mask=extended_attention_mask_midres,
            head_mask = None,
            return_dict=return_dict
        )
        hidden_states_l2 = out_encoder[0]
        
        # reduce sequence length (to 1/4 sequence-length)
        hiddens_states_l2_reduced = self.maxpool(hidden_states_l2)

        # reduce dimenion on sequence 2
        hidden_states_l3 = self.bert_reducerintegrator_l3(
            inputs = hidden_states_midres, # from highres outputs previous layer (key, values)
            hidden_states = hidden_states_lowres, # previous hidden-states for skip connection (short squence-dim, low-res)
            attention_mask = extended_attention_mask_lowres,
            head_mask=None,
            query_hidden_states = hiddens_states_l2_reduced
        )

        # BertEncoder at low-res
        out_encoder = self.bert_encoder_lowres(
            hidden_states=hidden_states_l3,
            attention_mask=extended_attention_mask_lowres,
            head_mask = None,
            return_dict=return_dict
        )
        hidden_states_lowres = out_encoder[0]
        
        # upscaling: l3 to l2
        hidden_states_upscaled3to2 = self.upscaler_x2(hidden_states_lowres)

        # integrate sequence-2 and upscaled sequence-3
        hidden_states_midres = self.bert_integrative_layer_2(
            hidden_states = hidden_states_l2,
            attention_mask = extended_attention_mask_midres,
            head_mask = None,
            query_hidden_states = hidden_states_upscaled3to2        )
        
        # upscaling: l3/l2 to l1 sequence length
        hidden_states_upscaled3to1 = self.upscaler_x4(hidden_states_lowres)
        hidden_states_upscaled2to1 = self.upscaler_x2(hidden_states_midres)
        hidden_states_upscaled = torch.cat((
            hidden_states_upscaled2to1, hidden_states_upscaled3to1
        ),axis=2)        

        # integrate low-resolution information back to original dimension
        hidden_states_highres = self.bert_integrative_layer_1(
            hidden_states = hidden_states_highres,
            attention_mask = extended_attention_mask_highres,
            head_mask = None,
            query_hidden_states = hidden_states_upscaled,
            query_attention_mask = extended_attention_mask_highres
        )
        if not return_dict:
            return (
                (hidden_states_highres, hidden_states_midres, hidden_states_lowres),
                (extended_attention_mask_highres, extended_attention_mask_midres, extended_attention_mask_lowres)
            )
        return {
            "hidden_states": (hidden_states_highres, hidden_states_midres, hidden_states_lowres),
            "attention":(extended_attention_mask_highres, extended_attention_mask_midres, extended_attention_mask_lowres)
        }

class BertClassificationHead(nn.Module):
    def __init__(self, config, n_classes = 1, activation = 'sigmoid', device=None):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size*2, n_classes)
        if activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'sigmoid':
            self.activation = torch.sigmoid
        elif activation == 'none':
            self.activation = lambda x: x
        if device is not None:
            self.to(device)

    def forward(self, hidden_states, attention_mask) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        output_vectors=[]
        first_token_tensor = hidden_states[:, 0]
        output_vectors.append(first_token_tensor)
        # mean pooling
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        sum_embeddings = torch.sum(hidden_states * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        output_vectors.append(sum_embeddings / sum_mask)
        # concatenate
        pooled_output = torch.concat(output_vectors, axis=1)
        #print(pooled_output.shape)
        logits = self.dense(pooled_output)
        return self.activation(logits)


def tokenize_anathem(text, device=device):
    #padding_length = int(math.ceil(max_length / 4)) * 
    tokens = tokenizer(text,padding=True, return_tensors='pt', pad_to_multiple_of=4)
    input_shape = tokens['input_ids'].size()

    # change token padding to be multiple of 4
    #ideal_length = int(math.ceil(input_shape[-1] / 4)) * 4 # should be a multiple of 4
    #if input_shape[-1]!=ideal_length:
    #  tokens = tokenizer(text,padding='max_length', max_length = ideal_length, return_tensors='pt')
    #  input_shape = tokens['input_ids'].size()
    
    token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
    tokens['token_type_ids'] = token_type_ids
    for k,v in tokens.items():
        tokens[k] = v.to(device)
    
    return tokens

In [None]:
#config = make_config('distilroberta-base')
#config = make_config('t5-small') # can't use t5 because it uses relative
config = make_config('google/bert_uncased_L-12_H-512_A-8') # 

if False:
  (tokenizer,basemod,layer_embedding,layer_basetransformer,maxpool,maxpool_attn,bert_reducer_l2,
   bert_reducer_l3,bert_encoder_lowres,upscaler_x2,upscaler_x4,bert_integrative_layer_2,bert_integrative_layer_1) = initialize(config)

# make the basemod and tokenizer
basemod = AutoModel.from_pretrained(config.model_string)
basemod.to(device)
tokenizer = AutoTokenizer.from_pretrained(config.model_string)



Some weights of the model checkpoint at google/bert_uncased_L-12_H-512_A-8 were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
# the Anathem encoder includes the embeddings and first transformer block
anathem_encoder1 = AnathemBaseModule(config, basemod, tokenizer)
anathem_encoder2 = AnathemMidModule(config, basemod)

In [None]:
cls_head = BertClassificationHead(config, n_classes = 3, activation = 'none',device=device)


In [None]:
text = [
    "* Welcome home to this gorgeously upgraded, beautifully maintained, three-bedroom home with double attached garage. Drive up to this quiet cul-de-sac and let the experience begin. On the main floor, you’ll notice the abundance of natural light. There is a separate office with view over the front of the property. The layout was customized, with a great open living space. The kitchen is a chef’s dream, with a breakfast bar, granite countertops, stainless steel appliance package, a pantry, and a view out to the sunny west facing yard.",
    "There’s room for formal dining and the family room has a gas fireplace to relax by on the cooler nights. Out back, there’s a stunner of a deck, perfect for BBQ season! Upstairs, you’ll find a massive bonus room with tons of windows. There are two, secondary bedrooms and the master suite is amazing",
]

In [None]:
tokens = tokenize_anathem(text,device)

In [None]:
#stack 1
out1 = anathem_encoder1(
      input_ids = tokens['input_ids'],
      attention_mask = tokens['attention_mask'],
      token_type_ids = tokens['token_type_ids']
)
(hidden_states, extended_attention_masks) = out1



In [None]:
# stack2 
out2 = anathem_encoder2(
      hidden_states_highres = hidden_states[0],
      hidden_states_midres = hidden_states[1],
      hidden_states_lowres = hidden_states[2],
      extended_attention_mask_highres = extended_attention_masks[0],
      extended_attention_mask_midres = extended_attention_masks[1],
      extended_attention_mask_lowres = extended_attention_masks[2]
)
(hidden_states, extended_attention_masks) = out2

cls_head(hidden_states[0], tokens['attention_mask'])



tensor([[-0.8376, -0.3891, -0.6668],
        [-0.8747, -0.3621, -0.7735]], device='cuda:0',
       grad_fn=<AddmmBackward0>)

In [None]:
out1[0][0].shape

torch.Size([2, 48, 768])

In [None]:
####

In [None]:
## Next steps, do something simple like sentiment analysis

In [None]:
from datasets import list_datasets, load_dataset
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
import numpy as np
from tqdm import tqdm
from torch.optim import AdamW
from sklearn.metrics import precision_recall_fscore_support
from scipy.special import softmax
#datasets_list = list_datasets()
#[k for k in datasets_list if 'phrasebank' in k]


In [None]:
#[k for k in datasets_list if 'phrasebank' in k]

dataset = load_dataset('financial_phrasebank', 'sentences_75agree')

# split
idx_train, idx_val = train_test_split(np.arange(len(dataset['train']['sentence'])), test_size=0.1)
dataset_train = [{'text':dataset['train']['sentence'][idx], 'label':dataset['train']['label'][idx]}  for idx in idx_train]
dataset_val = [{'text':dataset['train']['sentence'][idx], 'label':dataset['train']['label'][idx]} for idx in idx_val]



  0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
print(len(dataset_train)); print(len(dataset_val))

3107
346


In [None]:
class MyDataset(Dataset):
    """torch dataset."""

    def __init__(self, dataset):
        self.data = dataset
        self.n = len(self.data)
    
    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        unit = self.data[idx]
        return unit

In [None]:
ds_train = MyDataset(dataset_train)
ds_val = MyDataset(dataset_val)

In [None]:
batch_size_train = 12
batch_size_val = 36
lr = 0.00005
eval_iter = 20
n_epochs = 1

In [None]:
dl_train = DataLoader(ds_train, batch_size=batch_size_train, shuffle=True)
dl_val = DataLoader(ds_val, batch_size=batch_size_val, shuffle=False)

In [None]:
optimizer = AdamW(list(anathem_encoder1.parameters()) + list(anathem_encoder2.parameters()) + list(cls_head.parameters()), lr=lr)

In [None]:

optimizer.zero_grad()
anathem_encoder1.train()
anathem_encoder2.train()
cls_head.train()
for epoch in range(n_epochs):

  for iteration, batch in enumerate(tqdm(dl_train, disable=True)):
      
      # tokenize the batch
      tokens = tokenize_anathem(batch['text'],device)
      target = batch['label'].to(device)
      
      optimizer.zero_grad()

      out1 = anathem_encoder1(
        input_ids = tokens['input_ids'],
        attention_mask = tokens['attention_mask'],
        token_type_ids = tokens['token_type_ids']
      )
      (hidden_states, extended_attention_masks) = out1

      features,_ = anathem_encoder2(
          hidden_states_highres = hidden_states[0],
          hidden_states_midres = hidden_states[1],
          hidden_states_lowres = hidden_states[2],
          extended_attention_mask_highres = extended_attention_masks[0],
          extended_attention_mask_midres = extended_attention_masks[1],
          extended_attention_mask_lowres = extended_attention_masks[2]
      )

      # prediction
      preds = cls_head(features[0], tokens['attention_mask'])

      # loss
      loss = nn.functional.cross_entropy(preds, target)
      loss.backward()
      optimizer.step()

      # do evaluation
      if ((iteration+1) % eval_iter)==0:  
          anathem_encoder1.eval()
          anathem_encoder2.eval()
          cls_head.eval()
          # tokenize the eval
          eval_logits = []
          eval_targets = []
          for i, batch_eval in enumerate(tqdm(dl_val, disable=True)):
              with torch.no_grad():
                  # tokenize the batch
                  tokens_eval = tokenize_anathem(batch_eval['text'], device)
                  labels_eval = batch_eval['label'].to(device)
                  out_eval1 = anathem_encoder1(
                      input_ids = tokens_eval['input_ids'],
                      attention_mask = tokens_eval['attention_mask'],
                      token_type_ids = tokens_eval['token_type_ids']
                  )
                  (hidden_states, extended_attention_masks) = out_eval1
                  features,_ = anathem_encoder2(
                      hidden_states_highres = hidden_states[0],
                      hidden_states_midres = hidden_states[1],
                      hidden_states_lowres = hidden_states[2],
                      extended_attention_mask_highres = extended_attention_masks[0],
                      extended_attention_mask_midres = extended_attention_masks[1],
                      extended_attention_mask_lowres = extended_attention_masks[2]
                  )
                  # prediction
                  batch_logits = cls_head(features[0], tokens_eval['attention_mask'])
                  eval_logits+=batch_logits.detach().tolist()
                  eval_targets+=labels_eval.detach().tolist()

          eval_prec,eval_recall,eval_f1,eval_support = precision_recall_fscore_support(eval_targets, np.array(eval_logits).argmax(axis=1),zero_division=0)
          print('E:%d; i:%d: f1:%0.3f (%0.3f); prec:%0.3f (%0.3f); rec:%0.3f (%0.3f)' % (epoch, iteration, eval_f1.mean(), eval_f1.min(), eval_prec.mean(), eval_prec.min(), eval_recall.mean(), eval_recall.min()))
          cls_head.train()
          anathem_encoder1.train()
          anathem_encoder2.train()






E:0; i:19: f1:0.402 (0.000); prec:0.352 (0.000); rec:0.469 (0.000)




E:0; i:39: f1:0.326 (0.000); prec:0.400 (0.000); rec:0.372 (0.000)




E:0; i:59: f1:0.459 (0.158); prec:0.531 (0.405); rec:0.485 (0.095)




E:0; i:79: f1:0.506 (0.305); prec:0.583 (0.450); rec:0.494 (0.231)




E:0; i:99: f1:0.499 (0.190); prec:0.555 (0.383); rec:0.551 (0.116)




E:0; i:119: f1:0.552 (0.280); prec:0.663 (0.568); rec:0.534 (0.179)




E:0; i:139: f1:0.661 (0.469); prec:0.708 (0.600); rec:0.636 (0.385)




KeyboardInterrupt: ignored

In [None]:
target

tensor([1, 0, 2, 1, 1, 0, 1, 1, 1, 1, 1, 2])

## Test performance speed

In [None]:
# how many parameters in the model in total
from math import prod
nparam = 0
for encoder in [anathem_encoder1, anathem_encoder2]:
    for na,l in encoder.named_parameters():
        nparam+=prod(l.data.shape)
print('Number of parameters for anathem: %d' % nparam)
# 33676544

Number of parameters for anathem: 33283328


In [None]:
# compare this to distilbert
#other_mod = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
other_mod = AutoModel.from_pretrained('google/bert_uncased_L-12_H-512_A-8')

Some weights of the model checkpoint at google/bert_uncased_L-12_H-512_A-8 were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
nparam = 0
for na,l in other_mod.named_parameters():
    nparam+=prod(l.data.shape)

print('Number of parameters for other-mod: %d' % nparam)

# number of parameters for anathem-trans: 33676544 (google/bert_uncased_L-12_H-512_A-8)
# number of parametres for anathem-trans: 78973824 (includng 2 more mid-res encoders)
# number of parameters for anathem-trans: 73062528 (with a 768 dimension)
# Number of parameters for distilroberta: 82118400 (with a 768 dimension)
# Number of parameters  all-MiniLM-L6-v2: 22713216
# Number of parameters google/bert_uncased_L-12_H-512_A-8: 53982720 (512 dim, 12L)


Number of parameters for other-mod: 53982720


## Test Performance Speed at inference (CPU)
- distilroberta-base: 10 batches: 23.517s , CPU
- oogle/bert_uncased_L-12_H-512_A-8: 10 batches: 12.44s, CPU
- anathem (distilroberta-768): 10 batches, 23.23s, 
- anathem ((google/bert_uncased_L-12_H-512_A-8)): 10 batches, ~7.5s, CPU

## Test Performance Speed at inference (GPU)
- anathem ((google/bert_uncased_L-12_H-512_A-8)): 30 batches, 0.79s, GPU
- google/bert_uncased_L-12_H-512_A-8: 30 batches: 0.8 GPU


In [None]:
import time

In [None]:
time1 = time.time()
for iteration, batch in enumerate(tqdm(dl_train, disable=True)):
    if iteration>30:
        time2 = time.time()
        print(time2-time1)
        break
    with torch.no_grad():
        tokens = tokenize_anathem(batch['text'])  
        (hidden_states, extended_attention_masks) = anathem_encoder1(
            input_ids = tokens['input_ids'],
            attention_mask = tokens['attention_mask'],
            token_type_ids = tokens['token_type_ids']
        )
        features,_ = anathem_encoder2(
            hidden_states_highres = hidden_states[0],
            hidden_states_midres = hidden_states[1],
            hidden_states_lowres = hidden_states[2],
            extended_attention_mask_highres = extended_attention_masks[0],
            extended_attention_mask_midres = extended_attention_masks[1],
            extended_attention_mask_lowres = extended_attention_masks[2]
        )

0.8027215003967285


In [None]:
time3 = time.time()
for iteration, batch in enumerate(tqdm(dl_train, disable=True)):
    if iteration>30:
        time4 = time.time()
        print(time4-time3)
        break
    with torch.no_grad():
        tokens = tokenize_anathem(batch['text'])  
        out = basemod(
            input_ids = tokens['input_ids'],
            attention_mask = tokens['attention_mask'],
            token_type_ids = tokens['token_type_ids']
        )

0.7066085338592529


In [None]:
eval

array([0.        , 0.86464646, 0.52173913])

In [None]:
eval_prec,eval_recall,eval_f1,eval_support = precision_recall_fscore_support(eval_targets, np.array(eval_logits).argmax(axis=1),zero_division=0)

## Variant: Possibly Faster Integrative Layer

The above version uses a BertIntegrativeLayer that uses the high-res hidden-states as the key/values, and the upscaled-low res as the query

This variant flips it: the high-res is the query (thereby upscaling via attention) and the low-res are the value and keys

#### Varient #2 has slightly fewer parameters: 33283328 vs 336

In [None]:
%pip install torch transformers datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.29.2-py3-none-any.whl (7.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.1/7.1 MB[0m [31m72.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets
  Downloading datasets-2.12.0-py3-none-any.whl (474 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.6/474.6 kB[0m [31m38.5 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.14.1 (from transformers)
  Downloading huggingface_hub-0.15.1-py3-none-any.whl (236 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m236.8/236.8 kB[0m [31m26.7 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m55

In [3]:
from transformers import AutoModel, AutoTokenizer, AutoConfig, AutoModelForMaskedLM
import torch
from typing import List, Optional, Tuple, Union
from torch import nn
import torch.nn.functional as F
from torch.cuda import is_available
if is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

from transformers.models.bert.modeling_bert import BertEncoder
from transformers.activations import ACT2FN
import copy
import math


from transformers import BertTokenizer

In [45]:
class CustomTokenizer:
    def __init__(
        self, 
        model_string='google/bert_uncased_L-12_H-512_A-8', 
        n_cls_prepend = 4, 
        n_pad_to_multiple_of=4, 
        downscale_multiple=2

    ):
        # initialize the tokenizer from the base model
        self.base_tokenizer = AutoTokenizer.from_pretrained(model_string)
        # how many cls tokens to prepend to the fullsize data
        self.n_cls_prepend = n_cls_prepend
        self.n_pad_to_multiple_of = n_pad_to_multiple_of
        for k in dir(self.base_tokenizer):
            if not (k[0]=='_' or k=='tokenize' or k=='encode' or k=='build_inputs_with_special_tokens' or k == 'batch_encode_plus'):
                setattr(self,k,getattr(self.base_tokenizer, k))
        self.downscale_multiple = downscale_multiple
        # downscale attention
        self.maxpool_attn = nn.MaxPool1d(
            (self.downscale_multiple), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True
        )

    def __call__(self, text, pad_to_multiple_of=None, add_special_tokens = True, return_tensors=None, *args, **kwargs):
        if pad_to_multiple_of is None:
            pad_to_multiple_of = self.n_pad_to_multiple_of
        tokens = self.base_tokenizer(
            text, 
            pad_to_multiple_of=(pad_to_multiple_of if not add_special_tokens else False), 
            add_special_tokens=add_special_tokens, 
            return_tensors=return_tensors if (not add_special_tokens) else None,
            *args, 
            **kwargs
        )
        if add_special_tokens:
            tokens = self._prepend_extra_cls_tokens_because_of_maxpooling(tokens, return_tensors)

        # downscale the attention, add to tokens
        tokens = self.downscale_attention(
            tokens, downscale_multiple=[self.downscale_multiple, self.downscale_multiple],name='attention_mask'
        )
        # dowscale the excess_cls_tokens, add to tokens
        tokens = self.downscale_attention(
            tokens, downscale_multiple=[self.downscale_multiple, self.downscale_multiple],name='excess_cls_ids'
        )        
        return tokens
    
    def _num_pad_tokens(self, token_list):
        """Calculates how many PAD tokens to append to sequence to make a multiple of X"""
        return (self.n_pad_to_multiple_of - ((len(token_list)+(self.n_cls_prepend-1)) % self.n_pad_to_multiple_of)) % self.n_pad_to_multiple_of

    def _prepend_extra_cls_tokens_because_of_maxpooling(self, tokens,return_tensors=None):
        n_cls_prepend = self.n_cls_prepend
        # prepend (n-1) CLS tokens to the front of the token_ids (because of maxpooling)
        # also pad so that the total length is a multiple of n_cls_prepend
        #num_pad_tokens = (self.n_pad_to_multiple_of - ((len_tokens+(n_cls_prepend-1)) % self.n_pad_to_multiple_of)) % self.n_pad_to_multiple_of
        tokens['input_ids'] = [
            [self.cls_token_id]*(n_cls_prepend-1)+input_id + [self.pad_token_id]*self._num_pad_tokens(input_id)
            for input_id 
            in tokens['input_ids']
        ]
        tokens['excess_cls_ids'] = [
            [0]*(n_cls_prepend)+attnmask[1:] +[0]*self._num_pad_tokens(attnmask)
            for attnmask
            in tokens['attention_mask']
        ]
        tokens['attention_mask'] = [
            [1]*(n_cls_prepend-1)+attnmask +[0]*self._num_pad_tokens(attnmask)
            for attnmask
            in tokens['attention_mask']
        ]
        if 'token_type_ids' in tokens.keys():
            tokens['token_type_ids'] = [
                # we use the token_type_ids 
                [toktypeid[0]]*(n_cls_prepend-1)+toktypeid +[toktypeid[-1]]*self._num_pad_tokens(toktypeid)
                for toktypeid
                in tokens['token_type_ids']
            ]
        if return_tensors == 'pt':
            for k,v in tokens.items():
                #print(k)
                #print(v)
                tokens[k] = torch.LongTensor(v)
        return tokens

    def encode(self, text, pad_to_multiple_of=4, add_special_tokens = True, *args, **kwargs):
        encoded = self.base_tokenizer.encode(text, pad_to_multiple_of=False, add_special_tokens=add_special_tokens, *args, **kwargs)
        if add_special_tokens:
            encoded = [self.cls_token_id]*(pad_to_multiple_of-1) + encoded
        if bool(pad_to_multiple_of):
            num_pad_tokens = (pad_to_multiple_of - (len(encoded) % pad_to_multiple_of)) % pad_to_multiple_of
            encoded += [self.pad_token_id] * num_pad_tokens
        return encoded

    def tokenize(self, text, add_special_tokens=True, *args, **kwargs):
        toks = self.base_tokenizer.tokenize(text, add_special_tokens=add_special_tokens, *args, **kwargs)
        if add_special_tokens:
            toks = [self.cls_token] * (self.n_cls_prepend-1) + toks
        return toks

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ):
        out = self.base_tokenizer.build_inputs_with_special_tokens(token_ids_0, token_ids_1)
        return [self.cls_token_id]*3 + out

    def batch_encode_plus(self, batch_text_or_text_pairs, *args, **kwargs):
        batched_encoded = self.base_tokenizer.batch_encode_plus( batch_text_or_text_pairs, *args, **kwargs)
        batched_encoded.update({'foo':'bar'})
        return batched_encoded

    def downscale_attention(self, tokens, downscale_multiple=None, name = 'attention_mask'):
        """
        Reduces the sequence-dimenion by self.downscale_multiple using nn.maxpool
        Adds the downscale attention to the tokens dictionary
        """
        if downscale_multiple is None:
            downscale_multiple = [self.downscale_multiple, self.downscale_multiple]
        
        # fullsize attention
        attn = tokens[name]
        if not isinstance(attn, torch.Tensor):
            attn = torch.Tensor(attn)
        
        for i, mult in enumerate(downscale_multiple):
            
            name_of_downsized_attn = '%s_l%d' % (name, i+2)
            with torch.no_grad():
                attn = self.maxpool_attn(attn.float())
            tokens[name_of_downsized_attn] = attn
        return tokens

In [46]:
tokenizer = CustomTokenizer(
        model_string='google/bert_uncased_L-12_H-512_A-8', 
        n_cls_prepend = 4, 
        n_pad_to_multiple_of=4, 
        downscale_multiple=2
    )

Using bos_token, but it is not set yet.
Using eos_token, but it is not set yet.


In [47]:
text = [
    "A standard [MASK] clause is a waiver clause that states that one party won't hold the other liable for damages, losses, or costs associated with issues.",
    "It usually consists of two elements: a trigger event or circumstance and a [MASK] obligation. The trigger event or circumstance is the [MASK] of the agreement, misconduct, or negligence of the indemnifying party or its affiliates"
]

tokens = tokenizer(text, return_tensors='pt', padding=True)

In [54]:
tokens['excess_cls_ids_l2']

tensor([[0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 0., 0., 0., 0., 0.],
        [0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1.]])

In [None]:
class BertSelfAttnDimensionReduction(nn.Module):
    """Bert Attention Layer that uses a dimension-reduced version of the query, so to reduce the dimension of the outputs"""
    def __init__(
        self, 
        config, 
        hidden_size_input=768, 
        hidden_size_query = None,
        position_embedding_type=None, 
        dim_reduction = 2
    ):
        """Special type of Bert Self attention that reduces the dimension of the inputs by half"""
        super().__init__()
        if (config.hidden_size // dim_reduction) % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.dim_reduction = dim_reduction
        self.hidden_size_input = hidden_size_input
        self.hidden_size_reduced = hidden_size_input // dim_reduction
        if hidden_size_query is None:
            hidden_size_query = hidden_size_input
        self.hidden_size_query = hidden_size_query
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(self.hidden_size_reduced / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(self.hidden_size_query, self.all_head_size)
        self.key = nn.Linear(self.hidden_size_input, self.all_head_size)
        self.value = nn.Linear(self.hidden_size_input, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.

        key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
        value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if use_cache:
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                    -1, 1
                )
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if encoder_attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            #print(attention_scores.shape)
            #print(attention_scores.shape)
            attention_scores = attention_scores + encoder_attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs    


class InterpolateCombo(nn.Module):
    """there could also be an attentive way to do this"""
    def __init__(self, scale_factor=2, dropout=0.05, alpha=0.667):
        """Arguments:
        :param scaler_factor: float, multiple of up-scaling
        :param dropout: float, dropout proportion
        :param alpha: float, mixture weight between nearest-neighbor vs linear-interpolation
        """
        super(InterpolateCombo, self).__init__()
        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.dropout = nn.Dropout(dropout)
        self.a = alpha
        
    def forward(self, x):
        x_trans = x.transpose(-2,-1)
        z = self.a*self.interp(x_trans, mode='nearest',scale_factor=self.scale_factor) + (1-self.a)*self.interp(x_trans, mode='linear',scale_factor=self.scale_factor)
        z = self.dropout(z)
        return z.transpose(-2,-1)


class BertCrossAttention(nn.Module):
    def __init__(
        self, 
        config, 
        hidden_size,
        hidden_size_query,
        hidden_size_keyvalue=None,
        position_embedding_type=None
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.hidden_size_query = hidden_size_query
        if hidden_size_keyvalue is None:
            hidden_size_keyvalue = hidden_size
        self.hidden_size_keyvalue = hidden_size_keyvalue
        if self.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(self.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(self.hidden_size_query, self.all_head_size)
        self.key = nn.Linear(self.hidden_size_keyvalue, self.all_head_size)
        self.value = nn.Linear(self.hidden_size_keyvalue, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        query_hidden_states: Optional[torch.FloatTensor] = None,
        query_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        mixed_query_layer = self.query(query_hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)
        
        use_cache = past_key_value is not None
        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if use_cache:
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                    -1, 1
                )
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs


class BertReduceAddIntegrativeLayer(nn.Module):
    """Bert Layer that does dimenion reduction along embedding-dimenion and integrations a skip connection"""
    def __init__(
            self, 
            config, 
            hidden_size,
            hidden_size_input=None,
            hidden_size_query=None,
            intermediate_size=None,
            dim_reduction=2,
            do_concat_hidden_and_query = True
        ):
        super().__init__()
        #self.chunk_size_feed_forward = config.chunk_size_feed_forward
        #self.seq_len_dim = 1
        self.cat = torch.cat
        self.do_concat_hidden_and_query = do_concat_hidden_and_query
        assert bool(do_concat_hidden_and_query), 'not implemented: concatenation of query and hidden-states must happen'
        self.hidden_size = hidden_size
        if dim_reduction is None:
            dim_reduction = 2
        self.dim_reduction = dim_reduction
        if intermediate_size is None:
            intermediate_size = int(4*hidden_size)
        self.intermediate_size = intermediate_size
        if hidden_size_input is None:
            hidden_size_input = hidden_size
        self.hidden_size_input = hidden_size_input
        if hidden_size_query is None:
            hidden_size_query = hidden_size_input
        self.hidden_size_query = hidden_size_query + do_concat_hidden_and_query*hidden_size
        self.hidden_size_concat = int(hidden_size + hidden_size_input)

        # cross attention between (low-res) query and hidden layers below
        self.attention = BertSelfAttnDimensionReduction( 
            config, 
            hidden_size_input=self.hidden_size_input, 
            hidden_size_query = self.hidden_size_query,
            position_embedding_type="absolute", 
            dim_reduction = self.dim_reduction
        )
        self.is_decoder = config.is_decoder
        #inputs = x_l1, x_l1_reduced, x_l2_prev
        #- x2 = BertCrossAttention(k,v=x_l1, q= cat(x_l1_reduced, x_l2_prev) ) -notice three inputs
        #- x3 = lnorm(drop(f(x2)) + x_l2_prev)
        #- x4_ex = activation( f(cat(x3, x_l1_reduced))  )
        #- x5 = lnorm(drop(f(x4_ex)) + x3)

        # corresponds to BertAttention SelfOutput
        self.output_attn = nn.Linear(self.hidden_size, self.hidden_size)
        self.lnorm_attn = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
        self.dropout_attn = nn.Dropout(config.hidden_dropout_prob)

        # corresponds to BertIntermediate
        self.intermediate = nn.Linear(self.hidden_size_concat, self.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

        # corresponds to BertOutput
        self.output_intm = nn.Linear(self.intermediate_size, self.hidden_size)
        self.lnorm_intm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
        self.dropout_intm = nn.Dropout(config.hidden_dropout_prob)
    
    def forward(
        self,
        inputs: torch.Tensor, # higher-resolution inputs for key and values (long sequence dimension)
        hidden_states: torch.Tensor, # previous hidden-states for skip connection (short squence-dim, low-res)
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        query_hidden_states: torch.FloatTensor = None, # hidden-states for query (short squence-dim, low-res)
        query_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None

        if self.do_concat_hidden_and_query:
            query_hidden_states_plus = torch.cat((query_hidden_states, hidden_states),axis=2)
        # cross attn between (low-res) query vector and (high-res) key-values
        cross_attn_outputs = self.attention(
            query_hidden_states_plus, # query (short seq-dim, high-res)
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states = inputs, # for key/value (longer sequence dimension, high-res)
            past_key_value=past_key_value,
            output_attentions=output_attentions,
        )
        cross_hidden_states = cross_attn_outputs[0]

        # first Add+Norm skip connection (BertSelfOutput)
        cross_hidden_states = self.dropout_attn(self.output_attn(cross_hidden_states))
        hidden_states = self.lnorm_attn(cross_hidden_states + hidden_states)

        # intermediate expension
        intermediate_states = self.intermediate_act_fn(self.intermediate(
            self.cat((hidden_states, query_hidden_states),axis=2)
        ))
        assert intermediate_states.shape[0]==hidden_states.shape[0]
        assert intermediate_states.shape[1]==hidden_states.shape[1]

        # BertOutput
        intermediate_states = self.dropout_intm(self.output_intm(intermediate_states))
        out_states = self.lnorm_intm(intermediate_states + hidden_states)

        #inputs = x_l1, x_l1_reduced, x_l2_prev
        #- x2 = BertCrossAttention(k,v=x_l1, q= cat(x_l1_reduced, x_l2_prev) ) -notice three inputs
        #- x3 = lnorm(drop(f(x2)) + x_l2_prev)
        #- x4_ex = activation( f(cat(x3, x_l1_reduced))  )
        #- x5 = lnorm(drop(f(x4_ex)) + x3)
        return out_states

try:
    from transformers.modeling_utils import get_extended_attention_mask
except:
    def get_extended_attention_mask(self, attention_mask: torch.Tensor, input_shape: Tuple[int], device: device) -> torch.Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model.
            device: (:obj:`torch.device`):
                The device of the input to the model.

        Returns:
            :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if self.config.is_decoder:
                batch_size, seq_length = input_shape
                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                # in case past_key_values are used we need to add a prefix ones mask to the causal mask
                # causal and attention masks must have same type with pytorch version < 1.3
                causal_mask = causal_mask.to(attention_mask.dtype)

                if causal_mask.shape[1] < attention_mask.shape[1]:
                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
                    causal_mask = torch.cat(
                        [
                            torch.ones(
                                (batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype
                            ),
                            causal_mask,
                        ],
                        axis=-1,
                    )

                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask



In [None]:


# how does bert actually work?
"""
input = x

BertLayer:
- BertAttention
--- x2 = BertSelfAttention(x)
--- x3 = BertSelfOutput(x2,x) -> lnorm(drop(f(x2)) + x)
- BertIntermediate (expension:  4*hidden_size)
--- x4_ex = activation(f(x3)) # expansion (4*)
- BertOutput
--- x5 = lnorm(drop(f(x4_ex)) + x3 )


inputs = x_l2, x_l3_up

BertIntegrativeLayer:
- x2 = BertCrossAttention(k,v=x_l2, q=x_l3_up)
- x3 = lnorm(drop(f(x2)) + x_l2)
- x4_ex = activation( f(cat(x3, x_l3_up))  )
- x5 = lnorm(drop(f(x4_ex)) + x3)
"""


class BertIntegrativeLayer(nn.Module):
    """Vanilla Bert Layer, but integrates other hiddens states from a parallel transformers stack typically low-re"""
    def __init__(
            self, 
            config, 
            hidden_size, # dimensions of the (high-res) hiddens states; same dimension as output
            hidden_size_keyvalues, # dimensions of (low-res) states used as key/values; 1/2 sequence-length and dim
            hidden_size_query_to_concat=None, # dimensions of (low-res) to concat to hidden_states; 1/2 sequence-length and dim
            intermediate_size=None
        ):
        super().__init__()
        #self.chunk_size_feed_forward = config.chunk_size_feed_forward
        #self.seq_len_dim = 1
        self.cat = torch.cat
        self.hidden_size = hidden_size
        self.hidden_size_keyvalues = hidden_size_keyvalues
        if hidden_size_query_to_concat is None:
            hidden_size_query_to_concat = hidden_size_keyvalues
        self.hidden_size_query_to_concat = hidden_size_query_to_concat
        self.hidden_size_query = int(hidden_size + hidden_size_query_to_concat)
        self.hidden_size_concat = int(hidden_size + hidden_size_query_to_concat)
        if intermediate_size is None:
            intermediate_size = int(4*hidden_size)
        self.intermediate_size = intermediate_size

        # cross attention between (low-res) query and hidden layers below
        self.attention = BertCrossAttention(
            config, 
            hidden_size= self.hidden_size, # high dim output
            hidden_size_query = self.hidden_size_query, # high dim query
            hidden_size_keyvalue = self.hidden_size_keyvalues, # low-dim keyvalues
            position_embedding_type="absolute"
        )
        self.is_decoder = config.is_decoder
        #self.intermediate = BertIntermediate(config)
        #self.output = BertOutput(config)
        #- x2 = BertCrossAttention(k,v=x_l2, q=x_l3_up)
        #- x3 = lnorm(drop(f(x2)) + x_l2)
        #- x4_ex = activation( f(cat(x3, x_l3_up))  )
        #- x5 = lnorm(drop(f(x4_ex)) + x3)

        # corresponds to BertAttention SelfOutput
        self.output_attn = nn.Linear(self.hidden_size, self.hidden_size)
        self.lnorm_attn = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
        self.dropout_attn = nn.Dropout(config.hidden_dropout_prob)

        # corresponds to BertIntermediate
        self.intermediate = nn.Linear(self.hidden_size_concat, self.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

        # corresponds to BertOutput
        self.output_intm = nn.Linear(self.intermediate_size, self.hidden_size)
        self.lnorm_intm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
        self.dropout_intm = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor, # high-res hidden states (same dimensions as output), used as query
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        keyvalue_hidden_states: torch.Tensor=None, # low-res hidden-states (1/2 seq-dim) used for key-value pairs
        query_to_concat_hidden_states: torch.Tensor=None, # to concatenate to query
        query_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None

        # cross attn between hiddens states and (low-res) query vector
        cross_attn_outputs = self.attention(
            hidden_states = keyvalue_hidden_states,
            attention_mask = attention_mask,
            head_mask = head_mask,
            query_hidden_states = torch.cat((hidden_states, query_to_concat_hidden_states),axis=2),
            query_attention_mask = query_attention_mask
        )
        cross_hidden_states = cross_attn_outputs[0]
        assert cross_hidden_states.shape[1]==hidden_states.shape[1], f"{cross_hidden_states.shape[1]},{cross_hidden_states.shape[2]} vs {hidden_states.shape[1]},{hidden_states[2]}"
        assert cross_hidden_states.shape[2]==hidden_states.shape[2]


        # first Add+Norm skip connection (BertSelfOutput)
        cross_hidden_states = self.output_attn(cross_hidden_states)
        cross_hidden_states = self.dropout_attn(cross_hidden_states)
        hidden_states = self.lnorm_attn(cross_hidden_states + hidden_states)

        # intermediate expension
        intermediate_states = self.cat((hidden_states, query_to_concat_hidden_states),axis=2)
        intermediate_states = self.intermediate(intermediate_states)
        intermediate_states = self.intermediate_act_fn(intermediate_states)
        assert intermediate_states.shape[0]==hidden_states.shape[0]
        assert intermediate_states.shape[1]==hidden_states.shape[1]

        # BertOutput
        out_states = self.output_intm(intermediate_states)
        out_states = self.dropout_intm(out_states)
        out_states = self.lnorm_intm(out_states + hidden_states)

        #- x2 = BertCrossAttention(k,v=x_l2, q=x_l3_up)
        #- x3 = lnorm(drop(f(x2)) + x_l2)
        #- x4_ex = activation( f(cat(x3, x_l3_up))  )
        #- x5 = lnorm(drop(f(x4_ex)) + x3)
        return out_states



In [None]:


# how does bert actually work?
"""
input = x

BertLayer:
- BertAttention
--- x2 = BertSelfAttention(x)
--- x3 = BertSelfOutput(x2,x) -> lnorm(drop(f(x2)) + x)
- BertIntermediate (expension:  4*hidden_size)
--- x4_ex = activation(f(x3)) # expansion (4*)
- BertOutput
--- x5 = lnorm(drop(f(x4_ex)) + x3 )


inputs = x_l2, x_l3_up

BertIntegrativeLayer:
- x2 = BertCrossAttention(k,v=x_l2, q=x_l3_up)
- x3 = lnorm(drop(f(x2)) + x_l2)
- x4_ex = activation( f(cat(x3, x_l3_up))  )
- x5 = lnorm(drop(f(x4_ex)) + x3)
"""


class CheapMLPIntegrativeLayer(nn.Module):
    """Cheap (non-transformer) Integrator layer that merges a (low-res) layers with higher-res"""
    def __init__(
            self, 
            config, 
            hidden_size, # dimensions of the (high-res) hiddens states; same dimension as output
            hidden_size_keyvalues=None, # dimensions of (low-res) states used as key/values; 1/2 sequence-length and dim
            hidden_size_query_to_concat=None, # dimensions of (low-res) to concat to hidden_states; 1/2 sequence-length and dim
            intermediate_size=None
        ):
        super().__init__()
        #self.chunk_size_feed_forward = config.chunk_size_feed_forward
        #self.seq_len_dim = 1
        self.cat = torch.cat
        self.hidden_size = hidden_size
        if hidden_size_keyvalues is None:
            hidden_size_keyvalues = hidden_size
        self.hidden_size_keyvalues = hidden_size_keyvalues
        if hidden_size_query_to_concat is None:
            hidden_size_query_to_concat = hidden_size_keyvalues
        self.hidden_size_query_to_concat = hidden_size_query_to_concat
        self.hidden_size_query = int(hidden_size + hidden_size_query_to_concat)
        if intermediate_size is None:
            intermediate_size = int(2*hidden_size)
        self.intermediate_size = intermediate_size

        # expand hidden-size to a multiple
        self.dense_expander = nn.Linear(
            self.hidden_size_query, 
            self.intermediate_size
        ) # deflate back to same size as hidden-state
        self.dense_deflator = nn.Linear(
            self.intermediate_size,
            self.hidden_size
        )

        # intermediate activation function
        self.intermediate_act_fn = nn.RReLU(0.0625, 0.125)

        # corresponds to BertOutput
        self.lnorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor, # high-res hidden states (same dimensions as output), used as query
        attention_mask = None, # ignored
        head_mask = None, # ignored
        keyvalue_hidden_states =None, # ignored
        query_to_concat_hidden_states: torch.Tensor=None, # to concatenate to hidden_states
        query_attention_mask = None, # ignored
        past_key_value = None, # ignored
        output_attentions = False, # ignored
    ) -> torch.Tensor:

        # concat (lowres) to hidden-states
        inputs = self.cat((hidden_states, query_to_concat_hidden_states),axis=2)
        # expand x2 dimension
        intermediate_states = self.dense_expander(inputs)
        # activation (leaky relue)
        intermediate_states = self.intermediate_act_fn(intermediate_states)
        # like BertOutput
        out_states = self.dense_deflator(intermediate_states)
        # dropout
        out_states = self.dropout(out_states)
        # combine with hidden-state inputs
        out_states = self.lnorm(out_states + hidden_states)

        return out_states



In [None]:

def make_config(
    modelstring = "distilroberta-base",
    num_transformer_stacks = 3,
    scale_ratio2 = 0.5,
    scale_ratio3 = 0.25,
    multiplier_intermediate2 = 4.0,
    multiplier_intermediate3 = 4.0,
    num_layers_l2 = 1, # mid-res encoder
    num_layers_l3 = 3, # low-res encoder
    dropout_scaling = 0.05,
    do_cheap_integrator = [1]
):
    #if True:
    #modelstring = "distilroberta-base"
    #scale_ratio2 = 0.5
    #scale_ratio3 = 0.25
    #scale_intermediate2 = 4
    #scale_intermediate3 = 4
    base_config = AutoConfig.from_pretrained(modelstring)
    config_l2 = copy.deepcopy(base_config)
    config_l3 = copy.deepcopy(base_config)
    setattr(base_config, 'model_string', modelstring)
    setattr(base_config,'num_transformer_stacks', num_transformer_stacks)
    setattr(base_config,'num_layers_l2', num_layers_l2)
    setattr(base_config,'num_layers_l3', num_layers_l3)
    setattr(base_config,'scale_ratio2', scale_ratio2)
    setattr(base_config,'scale_ratio3', scale_ratio3)
    setattr(base_config,'scale_factor2', int(1/base_config.scale_ratio2))
    setattr(base_config,'scale_factor3', int(1/base_config.scale_ratio3*base_config.scale_ratio2))
    setattr(base_config,"hidden_size_l2", int(base_config.hidden_size * scale_ratio2))
    setattr(base_config,"hidden_size_l3", int(base_config.hidden_size * scale_ratio3))
    setattr(base_config,"intermediate_size_l1", int(base_config.hidden_size_l2*multiplier_intermediate2))
    setattr(base_config,"intermediate_size_l2", int(base_config.hidden_size_l3*multiplier_intermediate3))
    setattr(base_config,"query_size1", base_config.hidden_size_l2 + base_config.hidden_size_l3)
    setattr(base_config,"query_size2", base_config.hidden_size_l3)
    setattr(base_config,"dropout_scaling", dropout_scaling)
    setattr(base_config,"use_cheap_integrator_for_stacks", do_cheap_integrator)

    # make the configuration for the l2 mid-res encoder
    config_l2.hidden_size = base_config.hidden_size_l2
    config_l2.num_hidden_layers = num_layers_l2
    setattr(base_config, 'config_l2', config_l2)

    # make the configuration for the l3 encoder
    config_l3.hidden_size = base_config.hidden_size_l3
    config_l3.num_hidden_layers = num_layers_l3
    setattr(base_config, 'config_l3', config_l3)
    return base_config

def initialize_baselayers(config, basemod = None, tokenizer=None, stack_id=0):
    """Initializes the embeddings and first stack of layers for the Anathem transformers"""
    # initialize the basemodel
    if basemod is None:
        basemod = AutoModel.from_pretrained(config.model_string)
    if tokenizer is None:
        # download pretrained tokenizer
        tokenizer = AutoTokenizer.from_pretrained(config.model_string)
    
    device = basemod.device
    setattr(config, 'device', device)

    # get basemodel's embeddings
    layer_embedding = copy.deepcopy(basemod._modules['embeddings'])

    # get basemodel's first transformer block
    layer_basetransformer = copy.deepcopy(basemod._modules['encoder']._modules['layer']._modules['0'])

    # initialize the maxpooling downsamplers
    maxpool = nn.Sequential(
        nn.Dropout(config.dropout_scaling),
        nn.MaxPool2d((2,1), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True)
    )
    # pooling the attention has no dropout
    maxpool_attn = nn.MaxPool1d((2), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True)

    # initialize downsampling attention layers
    bert_reducer_l2 = BertSelfAttnDimensionReduction(
        config=config,
        hidden_size_input=config.hidden_size, 
        position_embedding_type=config.position_embedding_type,
        dim_reduction = config.scale_factor2
    )
    # 1/4 hidden size
    bert_reducer_l3 = BertSelfAttnDimensionReduction(
        config=config,
        hidden_size_input=config.hidden_size_l2, 
        position_embedding_type=config.position_embedding_type,
        dim_reduction = config.scale_factor3
    )

    # initialize the mid-resolution BertEncoder
    bert_encoder_midres = BertEncoder(config.config_l2)
    # initialize the low-resolution BertEncoder
    bert_encoder_lowres = BertEncoder(config.config_l3)

    # initailize the upscalers
    upscaler_x2 = InterpolateCombo(scale_factor=config.scale_factor3, dropout=config.dropout_scaling)
    upscaler_x4 = InterpolateCombo(scale_factor=int(1/config.scale_ratio3), dropout=config.dropout_scaling)

    # initialize the BertIntegrative Layers: low res to mid res
    bert_integrater_l2 = BertIntegrativeLayer(
        config, 
        hidden_size=config.hidden_size_l2,
        hidden_size_keyvalues = config.hidden_size_l3,
        hidden_size_query_to_concat=config.hidden_size_l3,
        intermediate_size=config.intermediate_size_l2
    )

    # from mid-res to high-res
    do_cheap_integrator = (stack_id in config.use_cheap_integrator_for_stacks)
    # from mid-res to high-res
    if not do_cheap_integrator:
        bert_integrater_l1 = BertIntegrativeLayer(
            config, 
            hidden_size=config.hidden_size,
            hidden_size_keyvalues = config.hidden_size_l2,
            hidden_size_query_to_concat=config.hidden_size_l2,
            intermediate_size=config.intermediate_size_l1
        )
    else:
        bert_integrater_l1 = CheapMLPIntegrativeLayer(
            config, 
            hidden_size=config.hidden_size,
            hidden_size_query_to_concat=config.hidden_size_l2,
            intermediate_size=config.hidden_size*2
        )

    return (
        tokenizer,
        basemod,
        layer_embedding,
        layer_basetransformer,
        maxpool,
        maxpool_attn,
        bert_reducer_l2,
        bert_reducer_l3,
        bert_encoder_midres,
        bert_encoder_lowres,
        upscaler_x2,
        upscaler_x4,
        bert_integrater_l2,
        bert_integrater_l1
    )

def initialize_midlayers(config, basemod=None, tokenizer=None, stack_id=1):
    """Initializes all the intermediate layers for the Anathem transformers"""
    # initialize the maxpooling downsamplers
    maxpool = nn.Sequential(
        nn.Dropout(config.dropout_scaling),
        nn.MaxPool2d((2,1), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True)
    )
    # pooling the attention has no dropout
    maxpool_attn = nn.MaxPool1d((2), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True)

    # initialize bert attentive downsampling and skipconnection (1/2 embedding dim)
    bert_reduceintegrator_l2 = BertReduceAddIntegrativeLayer(
        config, 
        config.hidden_size_l2, # size of mid-res
        hidden_size_input=config.hidden_size, # size full-resolution
        hidden_size_query=config.hidden_size, # size full-resolution
        intermediate_size=config.intermediate_size_l1, # BertIntermediate dimension (expansion *4 the hiddensize)
        dim_reduction=config.scale_factor2, # reduce embedding dimension by factor of 2
        do_concat_hidden_and_query = True
    )

    # 1/4 the size
    bert_reduceintegrator_l3 = BertReduceAddIntegrativeLayer(
        config, 
        config.hidden_size_l3, # size of mid-res
        hidden_size_input=config.hidden_size_l2, # size full-resolution
        hidden_size_query=config.hidden_size_l2, # size full-resolution
        intermediate_size=config.intermediate_size_l2, # BertIntermediate dimension
        dim_reduction=config.scale_factor3, # reduce embedding dimension by factor of 2
        do_concat_hidden_and_query = True
    )

    # initialize the low-resolution BertEncoder
    bert_encoder_midres = BertEncoder(config.config_l2)
    bert_encoder_lowres = BertEncoder(config.config_l3)

    # initailize the upscalers
    upscaler_x2 = InterpolateCombo(scale_factor=config.scale_factor3, dropout=config.dropout_scaling)
    upscaler_x4 = InterpolateCombo(scale_factor=int(1/config.scale_ratio3), dropout=config.dropout_scaling)

    # initialize the BertIntegrative Layers: from low-res to mide-res
    bert_integrater_l2 = BertIntegrativeLayer(
        config, 
        hidden_size=config.hidden_size_l2,
        hidden_size_keyvalues = config.hidden_size_l3,
        hidden_size_query_to_concat=config.hidden_size_l3,
        intermediate_size=config.intermediate_size_l2
    )
    
    do_cheap_integrator = (stack_id in config.use_cheap_integrator_for_stacks)
    if not do_cheap_integrator:
        # from mid-res to high-res
        bert_integrater_l1 = BertIntegrativeLayer(
            config, 
            hidden_size=config.hidden_size,
            hidden_size_keyvalues = config.hidden_size_l2,
            hidden_size_query_to_concat=config.hidden_size_l2,
            intermediate_size=config.intermediate_size_l1
        )
    else:
        bert_integrater_l1 = CheapMLPIntegrativeLayer(
            config, 
            hidden_size=config.hidden_size,
            hidden_size_query_to_concat=config.hidden_size_l2,
            intermediate_size=config.hidden_size*2
        )

    return (
        maxpool,
        maxpool_attn,
        bert_reduceintegrator_l2,
        bert_reduceintegrator_l3,
        bert_encoder_midres,
        bert_encoder_lowres,
        upscaler_x2,
        upscaler_x4,
        bert_integrater_l2,
        bert_integrater_l1
    )


def initialize_finaltransformerlayers(config, basemod=None, tokenizer=None, names_encoder_module = 'encoder', stack_id=3):
    """Initializes the final BertLayer before output, but copying the final BertLayer from `Basemod`"""
    # initialize the maxpooling downsamplers
    assert basemod is not None, "`initialize_finaltransformerlayers` requires the basemod to instantiate the final transformer block"

    # get the Encoder stacks
    assert names_encoder_module in basemod._modules.keys(), 'expected %s in basemod._modules' % names_encoder_module
    basemod_encoder_stack = get_to_bertlayer(basemod, target_layer_name = names_encoder_module)

    # get the name of the final transformer block (-1) in encoder
    names_of_final_transformer_block = list(basemod_encoder_stack._modules['layer']._modules.keys())[-1]

    # get the final transformer block (NN weights pretrained)
    bert_finaltransformer_block = basemod_encoder_stack._modules['layer']._modules[
        names_of_final_transformer_block
    ]

    return copy.deepcopy(bert_finaltransformer_block)

def get_to_bertlayer(basemod, target_layer_name = 'encoder', model_string = None):
    """Clumsily locates a particular layer within a pretrained bert model"""
    if  target_layer_name in basemod._modules.keys():
        return basemod._modules[target_layer_name]
    elif target_layer_name in basemod._modules['bert']._modules.keys():
        return basemod._modules['bert']

In [None]:

class AnathemBaseModule(nn.Module):
    """First Sstack of layers with embeddings, that go full circle form high-res to low-res back to high res"""
    def __init__(
            self, 
            config,
            basemod=None,
            tokenizer=None,
            past_key_values_length = None,
            device = None,
            stack_id=0
        ):
        super().__init__()
        self.config = config

        # initalize the layers
        (
            tokenizer, basemod, 
            layer_embedding,
            layer_basetransformer,
            maxpool,
            maxpool_attn,
            bert_reducer_l2,
            bert_reducer_l3,
            bert_encoder_midres,
            bert_encoder_lowres,
            upscaler_x2,
            upscaler_x4,
            bert_integrater_l2,
            bert_integrater_l1
        ) = initialize_baselayers(config, basemod, tokenizer, stack_id=0)

        self.get_extended_attention_mask = basemod.get_extended_attention_mask
        self.embedding = layer_embedding
        self.layer_basetransformer = layer_basetransformer
        self.maxpool = maxpool
        self.maxpool_attn = maxpool_attn
        self.bert_reducer_l2 = bert_reducer_l2
        self.bert_reducer_l3 = bert_reducer_l3
        self.bert_encoder_midres = bert_encoder_midres
        self.bert_encoder_lowres = bert_encoder_lowres
        self.upscaler_x2 = upscaler_x2
        self.upscaler_x4 = upscaler_x4
        self.bert_integrater_l2 = bert_integrater_l2
        self.bert_integrater_l1 = bert_integrater_l1
        self.stack_id = 0
        if device is None:
            self.to(basemod.device)
            #print(self.device)
            self.device = basemod.device
        else:
            self.to(device)
            self.device = device
    
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        attention_mask_l2: Optional[torch.Tensor] = None,
        attention_mask_l3: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = False
    ):
        input_shape = input_ids
        past_key_values_length =0 if past_key_values is None else len(past_key_values)

        # extend attention mask
        extended_attention_mask_l1 = self.get_extended_attention_mask(attention_mask, input_shape, self.device)
        # downsample the attention mask to l2 dimension
        if attention_mask_l2 is None:
            attention_mask_l2 = self.maxpool_attn(attention_mask.float())
        extended_attention_mask_l2 = self.get_extended_attention_mask(attention_mask_l2,attention_mask_l2.shape, self.device)
        # downsample the attention mask to l3 dimension
        if attention_mask_l2 is None:
            attention_mask_l3 = self.maxpool_attn(attention_mask_l2.float())
        extended_attention_mask_l3 = self.get_extended_attention_mask(attention_mask_l3,attention_mask_l3.shape, self.device)
        
        # embed
        embedding_output = self.embedding(
            input_ids = input_ids,
            position_ids = position_ids,
            token_type_ids = token_type_ids,
            #input_embeds=None,
            past_key_values_length = past_key_values_length
        )

        # first transformer block (vanilla transformer)
        out_l1 = self.layer_basetransformer(
            hidden_states = embedding_output,
            attention_mask = extended_attention_mask_l1,
            head_mask=head_mask,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            output_attentions=output_attentions
        )
        hidden_states_l1 = out_l1[0]

        # downsample to sequence 1 to length sequence 2
        hiddens_states_l1_reduced = self.maxpool(hidden_states_l1)

        # reduce dimenion on sequence 2
        out_l2 = self.bert_reducer_l2(
            hidden_states = hiddens_states_l1_reduced,
            attention_mask = extended_attention_mask_l2,
            head_mask=head_mask,
            encoder_hidden_states = hidden_states_l1,
            encoder_attention_mask= extended_attention_mask_l1,
            past_key_value=past_key_values,
            output_attentions=output_attentions,
        )
        hidden_states_l2 = out_l2[0]

        # Vanilla transformers block at mid-resolution (1/2 seq-length)
        out_encoder = self.bert_encoder_midres(
            hidden_states=hidden_states_l2,
            attention_mask=extended_attention_mask_l2,
            head_mask = head_mask,
            return_dict=return_dict
        )
        hidden_states_l2 = out_encoder[0]
        
        # reduce sequence length (1/4 seq-length)
        hiddens_states_l2_reduced = self.maxpool(hidden_states_l2)

        # reduce dimenion on sequence 2
        out_l3 = self.bert_reducer_l3(
            hidden_states = hiddens_states_l2_reduced,
            attention_mask = extended_attention_mask_l3,
            head_mask=head_mask,
            encoder_hidden_states = hidden_states_l2,
            encoder_attention_mask= extended_attention_mask_l2,
            past_key_value=past_key_values,
            output_attentions=output_attentions,
        )
        hidden_states_l3 = out_l3[0]

        #print(hidden_states_l3.shape)
        #print(extended_attention_mask_l3.shape)
        # BertEncoder at low-res
        out_encoder = self.bert_encoder_lowres(
            hidden_states=hidden_states_l3,
            attention_mask=extended_attention_mask_l3,
            head_mask = head_mask,
            return_dict=return_dict
        )
        hidden_states_l3 = out_encoder[0]
        
        # upscaling: l3 to l2
        hidden_states_upscaled3to2 = self.upscaler_x2(hidden_states_l3)

        # integrate sequence-2 and upscaled sequence-3
        hidden_states_l2 = self.bert_integrater_l2(
            hidden_states = hidden_states_l2,
            attention_mask = extended_attention_mask_l3,
            head_mask = head_mask,
            keyvalue_hidden_states = hidden_states_l3,
            query_to_concat_hidden_states = hidden_states_upscaled3to2,
            query_attention_mask = attention_mask_l2
        )
        
        # upscaling: l3/l2 to l1 sequence length
        #hidden_states_upscaled3to1 = self.upscaler_x4(hidden_states_l3)
        hidden_states_upscaled2to1 = self.upscaler_x2(hidden_states_l2)
        #hidden_states_upscaled = torch.cat((
        #    hidden_states_upscaled2to1, hidden_states_upscaled3to1
        #),axis=2)        

        # integrate low-resolution information back to original dimension
        hidden_states_l1 = self.bert_integrater_l1(
            hidden_states = hidden_states_l1,
            attention_mask = extended_attention_mask_l2,
            head_mask = head_mask,
            keyvalue_hidden_states = hidden_states_l2,
            query_to_concat_hidden_states = hidden_states_upscaled2to1,
            query_attention_mask = extended_attention_mask_l2
        )
        if not return_dict:
            return (
                (hidden_states_l1, hidden_states_l2, hidden_states_l3),
                (extended_attention_mask_l1, extended_attention_mask_l2, extended_attention_mask_l3),
                (attention_mask, attention_mask_l2, attention_mask_l3)
            )
        return {
            "hidden_states": (hidden_states_l1, hidden_states_l2, hidden_states_l3),
            "extended_attention_masks":(extended_attention_mask_l1, extended_attention_mask_l2, extended_attention_mask_l3),
            "attention_masks":(attention_mask, attention_mask_l2, attention_mask_l3)
        }


class AnathemMidModule(nn.Module):
    """Stack of layers that go full circle form high-res to low-res back to high res"""
    def __init__(
            self, 
            config,
            basemod=None,
            tokenizer=None,
            past_key_values_length = None,
            device=None,
            stack_id = 1
        ):
        super().__init__()
        self.config = config

        # initalize the layers
        (
            maxpool,
            maxpool_attn,
            bert_reducerintegrator_l2,
            bert_reducerintegrator_l3,
            bert_encoder_midres,
            bert_encoder_lowres,
            upscaler_x2,
            upscaler_x4,
            bert_integrater_l2,
            bert_integrater_l1
        ) = initialize_midlayers(config, basemod, tokenizer, stack_id)

        self.get_extended_attention_mask = get_extended_attention_mask
        self.maxpool = maxpool
        self.maxpool_attn = maxpool_attn
        self.bert_reducerintegrator_l2 = bert_reducerintegrator_l2
        self.bert_reducerintegrator_l3 = bert_reducerintegrator_l3
        self.bert_encoder_midres = bert_encoder_midres
        self.bert_encoder_lowres = bert_encoder_lowres
        self.upscaler_x2 = upscaler_x2
        self.upscaler_x4 = upscaler_x4
        self.bert_integrater_l2 = bert_integrater_l2
        self.bert_integrater_l1 = bert_integrater_l1
        if device is None:
            self.to(basemod.device)
            #print(self.device)
            self.device = basemod.device
        else:
            self.to(device)
            self.device = device
    
    def forward(
        self,
        hidden_states_highres: torch.Tensor,
        hidden_states_midres: torch.Tensor,
        hidden_states_lowres: torch.Tensor,
        attention_mask: Optional[List[torch.FloatTensor]] = None,
        extended_attention_mask_highres: Optional[List[torch.FloatTensor]] = None,
        extended_attention_mask_midres: Optional[List[torch.FloatTensor]] = None,
        extended_attention_mask_lowres: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = False
    ):
        input_shape = hidden_states_highres.shape[:2]
        past_key_values_length =0 if past_key_values is None else len(past_key_values)

        # extend attention mask
        if extended_attention_mask_highres is None:
            extended_attention_mask_highres = self.get_extended_attention_mask(attention_mask, input_shape, self.device)
        if extended_attention_mask_midres is None:
            attention_mask_midres = self.maxpool_attn(attention_mask.float())
            extended_attention_mask_midres = self.get_extended_attention_mask(attention_mask_midres,attention_mask_midres.shape, self.device)
        if extended_attention_mask_lowres is None:
           attention_mask_lowres = self.maxpool_attn(attention_mask_midres.float())
           extended_attention_mask_lowres = self.get_extended_attention_mask(attention_mask_lowres,attention_mask_lowres.shape, self.device)

        # downsample to sequence 1 to length sequence 2
        hiddens_states_l1_reduced = self.maxpool(hidden_states_highres)

        # reduce dimenion on sequence 2
        hidden_states_l2 = self.bert_reducerintegrator_l2(
            inputs = hidden_states_highres, # from highres outputs previous layer (key, values)
            hidden_states = hidden_states_midres, # previous hidden-states for skip connection (short squence-dim, low-res)
            attention_mask = extended_attention_mask_midres,
            head_mask=None,
            query_hidden_states = hiddens_states_l1_reduced
        )

        # Vanilla transformers at mid-resolution (1/2 sequence-length)
        out_encoder = self.bert_encoder_midres(
            hidden_states=hidden_states_l2,
            attention_mask=extended_attention_mask_midres,
            head_mask = None,
            return_dict=return_dict
        )
        hidden_states_l2 = out_encoder[0]
        
        # reduce sequence length (to 1/4 sequence-length)
        hiddens_states_l2_reduced = self.maxpool(hidden_states_l2)

        # reduce dimenion on sequence 2
        hidden_states_l3 = self.bert_reducerintegrator_l3(
            inputs = hidden_states_midres, # from highres outputs previous layer (key, values)
            hidden_states = hidden_states_lowres, # previous hidden-states for skip connection (short squence-dim, low-res)
            attention_mask = extended_attention_mask_lowres,
            head_mask=None,
            query_hidden_states = hiddens_states_l2_reduced
        )

        # BertEncoder at low-res
        out_encoder = self.bert_encoder_lowres(
            hidden_states=hidden_states_l3,
            attention_mask=extended_attention_mask_lowres,
            head_mask = None,
            return_dict=return_dict
        )
        hidden_states_lowres = out_encoder[0]
        
        # upscaling: l3 to l2
        hidden_states_upscaled3to2 = self.upscaler_x2(hidden_states_lowres)

        # integrate sequence-2 and upscaled sequence-3
        hidden_states_midres = self.bert_integrater_l2(
            hidden_states = hidden_states_l2,
            attention_mask = extended_attention_mask_lowres,
            head_mask = None,
            keyvalue_hidden_states = hidden_states_lowres,
            query_to_concat_hidden_states = hidden_states_upscaled3to2
        )
        #hidden_states_midres = self.bert_integrative_layer_2(
        #    hidden_states = hidden_states_l2,
        #    attention_mask = extended_attention_mask_midres,
        #    head_mask = None,
        #    query_hidden_states = hidden_states_upscaled3to2)
        
        # upscaling: l3/l2 to l1 sequence length
        #hidden_states_upscaled3to1 = self.upscaler_x4(hidden_states_lowres)
        hidden_states_upscaled2to1 = self.upscaler_x2(hidden_states_midres)
        #hidden_states_upscaled = torch.cat((hidden_states_upscaled2to1, hidden_states_upscaled3to1),axis=2)        

        # integrate low-resolution information back to original dimension
        hidden_states_highres = self.bert_integrater_l1(
            hidden_states = hidden_states_highres,
            attention_mask = extended_attention_mask_midres,
            head_mask = None,
            keyvalue_hidden_states = hidden_states_midres,
            query_to_concat_hidden_states = hidden_states_upscaled2to1
        )

        if not return_dict:
            return (
                (hidden_states_highres, hidden_states_midres, hidden_states_lowres),
                (extended_attention_mask_highres, extended_attention_mask_midres, extended_attention_mask_lowres)
            )
        return {
            "hidden_states": (hidden_states_highres, hidden_states_midres, hidden_states_lowres),
            "attention":(extended_attention_mask_highres, extended_attention_mask_midres, extended_attention_mask_lowres)
        }


class AnathemEncoder(nn.Module):
    """Anathem cores stacks of layers, from embeddings to final transformer block"""
    def __init__(
            self, 
            config,
            basemod=None,
            tokenizer=None,
            past_key_values_length = None,
            device=None,
        ):
        super().__init__()
        self.config = config
        self.device = device

        # initialize embedings and first stack
        self.anathem_base_stack = AnathemBaseModule(
            config,
            basemod,
            tokenizer,
            past_key_values_length,
            device,
        )

        # initialize all subsequence stacks
        self.anathem_mid_stack = nn.ModuleList([
            AnathemMidModule(
                config,
                basemod,
                tokenizer,
                past_key_values_length,
                device,
                stack_id = i
            ) for i in range(1, self.config.num_transformer_stacks)
        ])
        
        # initialize the final transformer modules
        self.final_transformer_block = initialize_finaltransformerlayers(
            config, 
            basemod, 
            tokenizer, 
            stack_id=self.config.num_transformer_stacks+1
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = False
    ):
        
        # embed and run through first stack of transformers
        hidden_states, extended_attention_masks, attention_masks = self.anathem_base_stack(
            input_ids=input_ids,
            attention_mask=attention_mask,
            attention_mask_l2=attention_mask_l2,
            attention_mask_l3=attention_mask_l3,
            token_type_ids=token_type_ids, #: Optional[torch.Tensor] = None,
            position_ids=position_ids,#: Optional[torch.Tensor] = None,
            head_mask=head_mask,#: Optional[torch.Tensor] = None,
            inputs_embeds=None,#: Optional[torch.Tensor] = None,
            encoder_hidden_states=None,#: Optional[torch.Tensor] = None,
            encoder_attention_mask=None,#: Optional[torch.Tensor] = None,
            past_key_values=past_key_values,#: Optional[List[torch.FloatTensor]] = None,
            use_cache=use_cache,#: Optional[bool] = None,
            output_attentions=output_attentions,#: Optional[bool] = None,
            output_hidden_states=output_hidden_states,#: Optional[bool] = None,
            return_dict=return_dict
        )

        # middle stack of transformers
        for i, anathem_stack in enumerate(self.anathem_mid_stack):
            
            # run through each stack (1-2)
            hidden_states, extended_attention_masks = anathem_stack(
                hidden_states_highres = hidden_states[0],
                hidden_states_midres = hidden_states[1],
                hidden_states_lowres = hidden_states[2],
                extended_attention_mask_highres = extended_attention_masks[0],
                extended_attention_mask_midres = extended_attention_masks[1],
                extended_attention_mask_lowres = extended_attention_masks[2]
            )

        # hidden states (high,med,low resolution)
        hidden_states_highres, hidden_states_midres, hidden_states_lowres = hidden_states

        # run through final transformer block (pretrained)
        out_final = self.final_transformer_block(
            hidden_states = hidden_states_highres,
            attention_mask = extended_attention_masks[0],
            head_mask=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            output_attentions=output_attentions
        )
        #print(type(out_final))
        #print(len(out_final))
        hidden_states_highres = out_final[0]
        if not output_attentions:
            return (hidden_states_highres, hidden_states_midres, hidden_states_lowres), attention_masks
        
        attention_final = out_final[1]
        return (hidden_states_highres, hidden_states_midres, hidden_states_lowres), attention_masks, attention_final
            

class BertGenericClassificationHead(nn.Module):
    """Instantiates a basic classification head that takes the CLS token and mean of the final layer for classification"""
    def __init__(self, config, n_classes = 1, activation = 'sigmoid', device=None):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size*2, n_classes)
        if activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'sigmoid':
            self.activation = torch.sigmoid
        elif activation == 'none':
            self.activation = lambda x: x
        if device is not None:
            self.to(device)

    def forward(self, hidden_states, attention_mask) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        output_vectors=[]
        first_token_tensor = hidden_states[:, 0]
        output_vectors.append(first_token_tensor)
        # mean pooling
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        sum_embeddings = torch.sum(hidden_states * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        output_vectors.append(sum_embeddings / sum_mask)
        # concatenate
        pooled_output = torch.concat(output_vectors, axis=1)
        #print(pooled_output.shape)
        logits = self.dense(pooled_output)
        return self.activation(logits)


class AnathemMultiSiloPooler(nn.Module):
    """Outputs a paragraph/sentence vector from token embeddings from multiple 'silos'"""
    def __init__(
        self, 
        config, 
        dim_out = None, 
        out_activation = None,
        dims_in = None, 
        p_dropout=None, 
        device=None
    ):
        super().__init__()

        # dimensions of the hiddens states being processed as inputs
        if dims_in is None:
            dims_in = [dim_out, dim_out//2, dim_out//4]
        self.dims_in = dims_in
        self.dim_in = sum(dims_in)
        self.dim_out = dim_out

        self.dense = nn.Linear(config.hidden_size*2, n_classes)
        if activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'sigmoid':
            self.activation = torch.sigmoid
        elif activation == 'none':
            self.activation = lambda x: x
        if device is not None:
            self.to(device)

        # linear layer operating on the concatenated CLS tokens from all silos
        self.cls_pooler = nn.Sequential(
            nn.Dropout(p_dropout),
            nn.Linear(self.dim_in, self.dim_out),
        )

        # pre-mean-pooling (one for each silo)
        self.pre_poolers = [nn.Sequential(
            nn.Dropout(p_dropout),
            nn.Linear(dim,dim)
            ) for dim in self.dims_in
         ]

        # sequential layer to concatenate the mean tokens from multiple tokens
        self.mean_pooler = nn.Linear(self.dim_in, self.dim_out)

    def forward(self, hidden_states, attention_masks, excess_cls_ids=None) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.

        output_vectors=[]

        # CLS/first-tokens from all silos, all concatenated together
        first_token_tensors = self._get_cls_tokens_all_silos(hidden_states, attention_masks)

        # mean pooling

        first_token_tensor = hidden_states[:, 0]
        output_vectors.append(first_token_tensor)
        # mean pooling
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        sum_embeddings = torch.sum(hidden_states * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        output_vectors.append(sum_embeddings / sum_mask)
        # concatenate
        pooled_output = torch.concat(output_vectors, axis=1)
        #print(pooled_output.shape)
        logits = self.dense(pooled_output)
        return self.activation(logits)

    def _get_cls_token(self, hidden_state, attention_mask):
        """Grabs the CLS token from a hidden-states"""
        return hidden_states[:, 0]
    
    def _get_cls_tokens_all_silos(self, hidden_states, attention_masks):
        """Grabs the CLS token from all hidden_states"""
        first_tokens = [
            self._get_cls_token(hidden_state)
            for hidden_state
            in hidden_states
        ]
        # concat all first tokens
        all_first_tokens_cat = torch.cat(first_tokens,axis=2)
        # run the concatenated first-tokens through Dense
        all_first_tokens_out = self.cls_pooler(all_first_tokens_cat)
        return all_first_tokens_out

    def _mean_pool(self, hidden_state, attention_mask=None, excess_cls_id=None):
        """Pool along a sequence dimension (for just one silo)"""
        if excess_cls_id is None:
            excess_cls_id = attention_mask
        input_mask_expanded = excess_cls_id.unsqueeze(-1).expand(hidden_states.size()).float()
        sum_embeddings = torch.sum(hidden_states * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        return sum_embeddings / sum_mask

    def _mean_pool_all_silos(self, hidden_states, attention_masks=None, excess_cls_ids=None):
        """Pool along a sequence dimension (for all silos)"""
        if excess_cls_ids is None:
            excess_cls_ids = attention_masks

        # pre-pool: dense-layer before pooling
        hidden_states = [
            l(hidden_state)
            in hidden_state, l
            zip(hidden_states, self.pre_poolers)
        ]

        # mean pool each silo
        mean_pooled_states = [
            self._mean_pool(
                hidden_state=hidden_state, excess_cls_id=excess_cls_id
            ) for hidden_state, excess_cls_id 
            in zip(hidden_states, excess_cls_ids)
        ]

        # concat all mean-pooled states
        all_mean_pooled_states = torch.cat(mean_pooled_states,axis=2)
        # run the concatenated meanpooled states through Dense
        all_mean_pooled_states = self.mean_pooler(all_mean_pooled_states)
        return all_mean_pooled_states


# FOOFu
# need to:
# -- make the entire module (3 stacks)
# -- add final transformer layer (full) at top
# -- update classification-head for multi scale vector
# -- add option to output MLM head
# -- remove/update the `tokenize anathem`

def tokenize_anathem(text, device=device):
    #padding_length = int(math.ceil(max_length / 4)) * 
    tokens = tokenizer(text,padding=True, return_tensors='pt')
    input_shape = tokens['input_ids'].size()

    # change token padding to be multiple of 4
    ideal_length = int(math.ceil(input_shape[-1] / 4)) * 4 # should be a multiple of 4
    if input_shape[-1]!=ideal_length:
      tokens = tokenizer(text,padding='max_length', max_length = ideal_length, return_tensors='pt')
      input_shape = tokens['input_ids'].size()

    token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
    tokens['token_type_ids'] = token_type_ids
    for k,v in tokens.items():
        tokens[k] = v.to(device)
    
    return tokens

In [None]:
class AnathemTransformer:
    def __init__(
        self, 
        config=None,
        device=None,
        do_mlm = True,
        do_cls = True
    ):
        # default config
        if config is None:
            config = make_config()
        self.config = config
        self.do_mlm = do_mlm
        self.do_cls = do_cls

        # device
        if device is None:
            if torch.cuda.is_available():
                device = torch.device('cuda')
            else:
                device = torch.device('cpu')
        self.device= device

        # get the basemodel (and its masked LM head
        self.model_string = self.config.model_string
        basemodelLM_pretrained = AutoModelForMaskedLM.from_pretrained(self.model_string)

        # get the Pretrained BertEncoder
        basemod_pretrained = get_to_bertlayer(
            basemodelLM_pretrained, 
            target_layer_name = 'encoder'
        )

        # make the tokenizer (based on pretrained)
        self.tokenizer = CustomTokenizer(
            model_string=self.config.model_string, 
            n_cls_prepend = int(1/config.scale_ratio3), 
            n_pad_to_multiple_of= int(1/config.scale_ratio3)
        )

        # make the Embedding and first layers (pretrained)
        self.encoder = AnathemEncoder(
            self.config,
            basemod=basemod_pretrained,
            tokenizer=self.tokenizer ,
            past_key_values_length = None,
            device=self.device,
        )

        # get the Pretrained maskedLM head
        if self.do_mlm:
            # perform maskedLM
            self.mlm = get_to_bertlayer(
                basemodelLM_pretrained, 
                target_layer_name = 'cls'
            )
        else:
            self.mlm = lambda x : x

        # make the sequence-classification head
        # self.cls = ?
        # stop for now
    
    def _get_name(self):
        return 'ANATHEM_MODEL_FOR_MLM'
    
    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        attention_mask_l2: Optional[torch.Tensor] = None,
        attention_mask_l3: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        excess_cls_ids: Optional[torch.Tensor] = None,
        excess_cls_ids_l2: Optional[torch.Tensor] = None,
        excess_cls_ids_l3: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = False
    ):

        # run through base-layer (embeddings, transformer-block, 1 anathem stack)
        outputs_encoder = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            attention_mask_l2=attention_mask_l2, # optional downsized attention mask for sequence-dim 1/2
            attention_mask_l3=attention_mask_l3, # optional downsized attention mask for sequence-dim 1/4
            token_type_ids=token_type_ids, #: Optional[torch.Tensor] = None,
            position_ids=position_ids,#: Optional[torch.Tensor] = None,
            head_mask=head_mask,#: Optional[torch.Tensor] = None,
            inputs_embeds=None,#: Optional[torch.Tensor] = None,
            encoder_hidden_states=None,#: Optional[torch.Tensor] = None,
            encoder_attention_mask=None,#: Optional[torch.Tensor] = None,
            past_key_values=past_key_values,#: Optional[List[torch.FloatTensor]] = None,
            use_cache=use_cache,#: Optional[bool] = None,
            output_attentions=output_attentions,#: Optional[bool] = None,
            output_hidden_states=output_hidden_states,#: Optional[bool] = None,
            return_dict=False
        )
        if output_attentions:
            hidden_states, extended_attention_masks, attention = outputs_encoder
        else:
            hidden_states, extended_attention_masks = outputs_encoder
            attention = None

        out_mlm = None
        out_cls = None
        hidden_states_highres, hidden_states_midres, hiddenstates_lowres = hidden_states

        # MLM outputs
        if self.do_mlm:
            out_mlm = self.mlm(hidden_states_highres)

        
        if return_dict:
            return {
                'logits':out_mlm['logits'],
                'hidden_states':(hidden_states_highres, hidden_states_midres, hiddenstates_lowres),
                'attention':attention,
                'extended_attention_masks':extended_attention_masks
            }
        return hidden_states, attention, out_mlm, extended_attention_masks



In [None]:
modelstring_teacher_mlm = 'bert-base-uncased'
model_string = "google/bert_uncased_L-4_H-512_A-8"

config = make_config(
    modelstring = model_string,
    num_transformer_stacks = 3,
    scale_ratio2 = 0.5,
    scale_ratio3 = 0.25,
    multiplier_intermediate2 = 4.0,
    multiplier_intermediate3 = 4.0,
    num_layers_l2 = 1, # mid-res encoder
    num_layers_l3 = 3, # low-res encoder
    dropout_scaling = 0.05,
    do_cheap_integrator = [1]
)

In [None]:

anamod = AnathemTransformer(
        config=config,
        device=None,
        do_mlm = True,
        do_cls = True
    )

teacher_mlm = AutoModelForMaskedLM.from_pretrained(modelstring_teacher_mlm)

Some weights of the model checkpoint at google/bert_uncased_L-4_H-512_A-8 were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Using bos_token, but it is not set yet.
Using eos_token, but it is not set yet.
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on an

In [None]:

print(anamod.mlm) # MLM head

BertOnlyMLMHead(
  (predictions): BertLMPredictionHead(
    (transform): BertPredictionHeadTransform(
      (dense): Linear(in_features=512, out_features=512, bias=True)
      (transform_act_fn): GELUActivation()
      (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
    )
    (decoder): Linear(in_features=512, out_features=30522, bias=True)
  )
)


In [None]:
text = [
    "A standard [MASK] clause is a waiver clause that states that one party won't hold the other liable for damages, losses, or costs associated with issues.",
    "It usually consists of two elements: a trigger event or circumstance and a [MASK] obligation. The trigger event or circumstance is the [MASK] of the agreement, misconduct, or negligence of the indemnifying party or its affiliates"
]

inputs = anamod.tokenizer(text, add_special_tokens=True, return_tensors='pt', padding='longest')

print(inputs.keys())
inputs

hidden_states,attention,out_mlm, attention_masks = anamod.forward(**inputs)
outputs_teacher_mlm = teacher_mlm(input_ids = inputs['input_ids'], attention_mask=inputs['attention_mask'])

#outputs_teacher_mlm['logits'].shape



dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'excess_cls_ids'])


In [None]:
attention_masks[1]

tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1.]])

In [None]:

attention_mask = inputs['attention_mask']
if inputs['excess_cls_ids'] is not None:
    excess_cls_ids = inputs['excess_cls_ids']
    attention_mask =attention_mask*excess_cls_ids

input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states[0].size()).float()
#input_mask_expanded.shape
sum_embeddings = torch.sum(hidden_states[0] * input_mask_expanded, 1)
sum_mask = input_mask_expanded.sum(1)
sum_mask = torch.clamp(sum_mask, min=1e-9)

In [None]:
sum_mask

tensor([[33., 33., 33.,  ..., 33., 33., 33.],
        [44., 44., 44.,  ..., 44., 44., 44.]])

## Vanilla Trainer on HuggingFace

Just to train and use huggingface as is Trainer

In [None]:
## Vanilla Trainer on HuggingFace?

torch.Size([2, 48, 30522])

In [None]:
outputs_teacher_mlm['logits']

tensor([[[ -6.7448,  -6.7016,  -6.6852,  ...,  -5.9516,  -5.8558,  -4.1969],
         [ -6.7993,  -6.7697,  -6.7379,  ...,  -6.0041,  -5.9220,  -4.2666],
         [ -6.7738,  -6.7394,  -6.7160,  ...,  -5.9772,  -5.8928,  -4.2304],
         ...,
         [ -4.1397,  -4.3376,  -4.3621,  ...,  -3.6482,  -2.7141,  -4.0353],
         [ -6.5960,  -6.8189,  -6.6282,  ...,  -5.5340,  -4.6214,  -7.2814],
         [ -4.1222,  -4.3778,  -4.4345,  ...,  -3.3695,  -2.9633,  -5.1123]],

        [[ -6.6936,  -6.6243,  -6.5915,  ...,  -5.8993,  -5.7458,  -4.4250],
         [ -6.7312,  -6.6699,  -6.6184,  ...,  -5.9545,  -5.7618,  -4.5720],
         [ -6.7255,  -6.6614,  -6.6163,  ...,  -5.9390,  -5.7634,  -4.5208],
         ...,
         [ -7.5988,  -7.5956,  -7.3957,  ...,  -6.5094,  -5.0794,  -6.6433],
         [ -5.5828,  -5.4577,  -5.5788,  ...,  -4.9769,  -4.6903,  -2.6268],
         [-12.7831, -13.0169, -12.9553,  ..., -11.2401,  -8.5196, -11.0713]]],
       grad_fn=<ViewBackward0>)

In [None]:
predicted_token_ids1 = outputs_teacher_mlm[0][1].argmax(dim=-1)
predicted_token_ids2 = out_mlm[1].argmax(dim=-1)

print(anamod.tokenizer.convert_ids_to_tokens(predicted_token_ids1))
print(anamod.tokenizer.convert_ids_to_tokens(predicted_token_ids2))

['.', '.', '.', '.', 'it', 'usually', 'consists', 'of', 'two', 'elements', ':', 'a', 'trigger', 'event', 'or', 'circumstance', 'and', 'a', 'trigger', 'obligation', '.', 'the', 'trigger', 'event', 'or', 'circumstance', 'is', 'the', 'violation', 'of', 'the', 'agreement', ',', 'misconduct', ',', 'or', 'negligence', 'of', 'the', 'ind', '##em', '##ni', '##fying', 'party', 'or', 'its', '.', 's']
[',', 'the', ',', ',', 'it', 'also', 'consists', 'of', 'two', 'elements', ':', 'a', ',', ',', 'or', 'circumstance', 'and', 'a', 'state', ',', '.', 'the', 'trigger', ',', 'or', 'the', ',', 'the', 'product', 'of', 'the', ',', ',', 'the', ',', 'or', 'that', 'of', 'the', ',', '##em', 'of', ',', ',', 'or', 'its', ',', ',']


In [None]:
predicted_token_ids1

tensor([ 1012,  1012,  1012,  1012,  2009,  2788,  3774,  1997,  2048,  3787,
         1024,  1037,  9495,  2724,  2030, 25652,  1998,  1037,  9495, 14987,
         1012,  1996,  9495,  2724,  2030, 25652,  2003,  1996, 11371,  1997,
         1996,  3820,  1010, 23337,  1010,  2030, 27988,  1997,  1996, 27427,
         6633,  3490, 14116,  2283,  2030,  2049,  1012,  1055])

In [None]:
from transformers import AutoModelForMaskedLM
basemodelLM = AutoModelForMaskedLM.from_pretrained("google/bert_uncased_L-4_H-512_A-8")

Downloading pytorch_model.bin:   0%|          | 0.00/116M [00:00<?, ?B/s]

Some weights of the model checkpoint at google/bert_uncased_L-4_H-512_A-8 were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
basemodelLM

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 512, padding_idx=0)
      (position_embeddings): Embedding(512, 512)
      (token_type_embeddings): Embedding(2, 512)
      (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-3): 4 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=512, out_features=512, bias=True)
              (key): Linear(in_features=512, out_features=512, bias=True)
              (value): Linear(in_features=512, out_features=512, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=512, out_features=512, bias=True)
              (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_aff

In [None]:
## try to grab a MLM classification head
## let's verify they have the same vocabulary
# models from: https://arxiv.org/pdf/1908.08962.pdf
from transformers import AutoModelForMaskedLM, AutoConfig
modelstring_base = "google/bert_uncased_L-12_H-512_A-8" #
modelstring_base = "google/bert_uncased_L-4_H-512_A-8"
#modelstring_base = 'google/bert_uncased_L-6_H-512_A-8'
basemod = AutoModelForMaskedLM.from_pretrained(modelstring_base)
basemod_tokenizer = AutoTokenizer.from_pretrained(modelstring_base)
# the minatoure googles have a vocab size of: 30522

modelstring_lg = 'bert-large-uncased' # I think the google-team used this for the miniature models
# bert-large uncased has a vocab size of: 30522
#modelstring_lg = "google/bert_uncased_L-12_H-768_A-12"
largmod = AutoModelForMaskedLM.from_pretrained(modelstring_lg) #
largmod_tokenizer = AutoTokenizer.from_pretrained(modelstring_lg)#"google/bert_uncased_L-12_H-768_A-12")


# note: which datasets used to train large
# wikipedia
# bookcorpus
# ... but seem more about datasets and models from: https://arxiv.org/pdf/1908.08962.pdf


text = "For Ex Works (EXW) terms, the Supplier will [MASK] all risk and liability for the Delivered [MASK] up until delivering the goods to the nominated Carrier."
with torch.no_grad():
    inputs1 = basemod_tokenizer(text, return_tensors='pt')
    outputs1 = basemod(**inputs1)
    preds1 = outputs1.logits
    inputs2 = largmod_tokenizer(text, return_tensors='pt')
    outputs2 = largmod(**inputs2)
    preds2 = outputs2.logits

    assert (inputs1['input_ids']-inputs2['input_ids']).sum()==0, 'ids are different'

    predicted_token_ids1 = preds1[0].argmax(dim=-1)
    predicted_token_ids2 = preds2[0].argmax(dim=-1)
    
    print(basemod_tokenizer.convert_ids_to_tokens(predicted_token_ids1))
    print(basemod_tokenizer.convert_ids_to_tokens(predicted_token_ids2))


# confirmation: the minature berts and the 

Downloading (…)lve/main/config.json:   0%|          | 0.00/383 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/116M [00:00<?, ?B/s]

Some weights of the model checkpoint at google/bert_uncased_L-4_H-512_A-8 were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


['delivery', 'for', 'ex', '##w', '(', 'ex', '##w', ')', 'terms', ',', 'the', 'supplier', 'will', 'reduce', 'all', 'risk', 'and', 'liability', 'for', 'the', 'delivered', 'goods', 'up', 'until', 'delivering', 'the', 'goods', 'to', 'the', 'delivery', 'carrier', '.', 'is']
['.', 'for', 'ex', 'works', '(', 'ex', '##w', ')', 'terms', ',', 'the', 'supplier', 'will', 'assume', 'all', 'risk', ',', 'liability', 'for', 'the', 'delivered', 'goods', 'up', 'until', 'delivering', 'the', 'goods', 'to', 'the', 'responsible', 'carrier', '.', '.']


['the', 'for', 'ex', '##w', '(', 'ex', '##w', ')', 'terms', ',', 'the', 'supplier', 'will', 'cover', 'all', 'risk', 'and', 'liability', 'for', 'the', 'delivered', 'goods', 'up', 'until', 'delivering', 'the', 'goods', 'to', 'the', 'delivered', 'carrier', '.', '.']
['.', 'for', 'ex', 'works', '(', 'ex', '##w', ')', 'terms', ',', 'the', 'supplier', 'will', 'assume', 'all', 'risk', ',', 'liability', 'for', 'the', 'delivered', 'goods', 'up', 'until', 'delivering', 'the', 'goods', 'to', 'the', 'responsible', 'carrier', '.', '.']


In [None]:
predicted_token_ids1


tensor([ 1996,  2005,  4654,  2860,  1006,  4654,  2860,  1007,  3408,  1010,
         1996, 17024,  2097,  3104,  2035,  3891,  1998, 14000,  2005,  1996,
         5359,  5350,  2039,  2127, 12771,  1996,  5350,  2000,  1996,  5359,
         6839,  1012,  1012])

In [None]:
inputs1['input_ids']-inputs2['input_ids']

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0]])

In [None]:
basemod._modules['cls']

BertOnlyMLMHead(
  (predictions): BertLMPredictionHead(
    (transform): BertPredictionHeadTransform(
      (dense): Linear(in_features=512, out_features=512, bias=True)
      (transform_act_fn): GELUActivation()
      (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
    )
    (decoder): Linear(in_features=512, out_features=30522, bias=True)
  )
)

In [None]:
config = make_config('google/bert_uncased_L-12_H-512_A-8') # 

# make the basemod and tokenizer
basemod = AutoModel.from_pretrained(config.model_string)
basemod.to(device)
tokenizer = AutoTokenizer.from_pretrained(config.model_string)


Some weights of the model checkpoint at google/bert_uncased_L-12_H-512_A-8 were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
anathem_encoder1 = AnathemBaseModule(config, basemod, tokenizer)
anathem_encoder2 = AnathemMidModule(config, basemod)

In [None]:
time1 = time.time()
for iteration, batch in enumerate(tqdm(dl_train, disable=True)):
    if iteration>30:
        time2 = time.time()
        print(time2-time1)
        break
    with torch.no_grad():
        tokens = tokenize_anathem(batch['text'])  
        (hidden_states, extended_attention_masks) = anathem_encoder1(
            input_ids = tokens['input_ids'],
            attention_mask = tokens['attention_mask'],
            token_type_ids = tokens['token_type_ids']
        )
        features,_ = anathem_encoder2(
            hidden_states_highres = hidden_states[0],
            hidden_states_midres = hidden_states[1],
            hidden_states_lowres = hidden_states[2],
            extended_attention_mask_highres = extended_attention_masks[0],
            extended_attention_mask_midres = extended_attention_masks[1],
            extended_attention_mask_lowres = extended_attention_masks[2]
        )

1.2566087245941162


In [None]:
# the new method takes: 3.198051929473877 / 200 iterations (I can't really te)