<a href="https://colab.research.google.com/github/faraway1nspace/AnathemTransformer/blob/main/dev/notebooks/dev_anathem_transformer_base_layers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Development Notebook: build and test base layers

#### Playing Around with novel architectures

In [174]:
%pip install torch transformers datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [175]:
from transformers import AutoModel, AutoTokenizer
import torch
from torch import nn
import torch.nn.functional as F
from transformers.models.bert.modeling_bert import BertEncoder
from transformers.activations import ACT2FN


tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
basemod = AutoModel.from_pretrained("distilroberta-base")
device = basemod.device

Some weights of the model checkpoint at distilroberta-base were not used when initializing RobertaModel: ['lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.bias', 'lm_head.dense.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [181]:
dir(basemod)
# base embedding layers
layer_emb = copy.deepcopy(basemod._modules['embeddings'])


In [182]:
# base trasnformers (full)
layer_basetransformer = copy.deepcopy(basemod._modules['encoder']._modules['layer']._modules['0'])

In [183]:
# text 
text = [
    "A standard indemnity clause is a waiver clause that states that one party won't hold the other liable for damages, losses, or costs associated with legal issues1.",
    "It usually consists of two elements: a trigger event or circumstance and a payment obligation2. The trigger event or circumstance is the breach of the agreement, willful misconduct, or negligence of the indemnifying party or its affiliates"
]

import math

#padding_length = int(math.ceil(max_length / 4)) * 4
tokens = tokenizer(text,padding=True, return_tensors='pt')
input_shape = tokens['input_ids'].size()

# change token padding to be multiple of 4
ideal_length = int(math.ceil(input_shape[-1] / 4)) * 4 # should be a multiple of 4
if input_shape[-1]!=ideal_length:
  tokens = tokenizer(text,padding='max_length', max_length = ideal_length, return_tensors='pt')
  input_shape = tokens['input_ids'].size()

token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
tokens['token_type_ids'] = token_type_ids
past_key_values_length =0

# need to extend attention mask
extended_attention_mask = basemod.get_extended_attention_mask(tokens['attention_mask'], input_shape)
tokens['extended_attention_mask'] = extended_attention_mask
print(tokens.keys())
print(tokens['input_ids'].shape)


dict_keys(['input_ids', 'attention_mask', 'token_type_ids', 'extended_attention_mask'])
torch.Size([2, 48])


In [184]:
silo_dimensions = {0:basemod.config.hidden_size,
                  1:basemod.config.hidden_size//2,
                  2:basemod.config.hidden_size//4,
                  }
reintegration_dim = silo_dimensions[1] + silo_dimensions[2]


In [185]:
embedding_output = layer_emb(
            input_ids=tokens['input_ids'],
            position_ids=tokens.get('position_ids',None),
            token_type_ids=tokens['token_type_ids'],
            inputs_embeds=None,
            past_key_values_length=past_key_values_length
)
print(embedding_output.shape)

torch.Size([2, 48, 768])


In [186]:
# basemodel transformer outputs: *full bert model
out_l1 = layer_basetransformer(
    hidden_states = embedding_output,
    attention_mask = tokens['extended_attention_mask'],#tokens['attention_mask'],
    head_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    #past_key_values=0,
    #use_cache=None,
    output_attentions=True,
    #output_hidden_states=True,
    #return_dict=True
)

hidden_states_l1 = out_l1[0]
self_attention_l1 = out_l1[1]

In [187]:
# Next Layer: 
# Query -> max pool and reduce  hidden dimension // 2
# Key -> reduce hidden_dim // 2
# value -> reduce hidden_dim //2
#maxpool_l2 = nn.MaxPool2d((2,1), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True)

maxpool_l2 = nn.Sequential(
    nn.Dropout(0.05),
    nn.MaxPool2d((2,1), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True),
)

maxpool_l2_attn = nn.MaxPool1d((2), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True)

In [188]:
# reduce dimension of hidden states
hiddens_states_l1_reduced = maxpool_l2(hidden_states_l1)
print(hidden_states_l1.shape)
print(hiddens_states_l1_reduced.shape)

# reduce dimension of attention mask
attention_mask_l1_reduced = maxpool_l2_attn(tokens['attention_mask'].float())
print(attention_mask_l1_reduced.shape)

# extend the dimension of the reduced attention_mask
print(input_shape)
extended_attention_mask_l1_reduced = basemod.get_extended_attention_mask(attention_mask_l1_reduced, attention_mask_l1_reduced.shape)
print(tokens['extended_attention_mask'].shape)
print(extended_attention_mask_l1_reduced.shape)

torch.Size([2, 48, 768])
torch.Size([2, 24, 768])
torch.Size([2, 24])
torch.Size([2, 48])
torch.Size([2, 1, 1, 48])
torch.Size([2, 1, 1, 24])


In [189]:
# Try to do Multi Headed attenion with differently sized query and value 

In [214]:
import torch
import torch.nn as nn
import math
from typing import Optional, Tuple
import copy

class BertSelfAttnDimensionReduction(nn.Module):
    """Bert Attention Layer that uses a dimension-reduced version of the query, so to reduce the dimension of the outputs"""
    def __init__(
        self, 
        config, 
        hidden_size_input=768, 
        hidden_size_query = None,
        position_embedding_type=None, 
        dim_reduction = 2
    ):
        """Special type of Bert Self attention that reduces the dimension of the inputs by half"""
        super().__init__()
        if (config.hidden_size // dim_reduction) % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.dim_reduction = dim_reduction
        self.hidden_size_input = hidden_size_input
        self.hidden_size_reduced = hidden_size_input // dim_reduction
        if hidden_size_query is None:
            hidden_size_query = hidden_size_input
        self.hidden_size_query = hidden_size_query
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(self.hidden_size_reduced / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(self.hidden_size_query, self.all_head_size)
        self.key = nn.Linear(self.hidden_size_input, self.all_head_size)
        self.value = nn.Linear(self.hidden_size_input, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.

        key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
        value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if use_cache:
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                    -1, 1
                )
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if encoder_attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            #print(attention_scores.shape)
            #print(attention_scores.shape)
            attention_scores = attention_scores + encoder_attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs

bertlayer_l2_reduction = BertSelfAttnDimensionReduction(
    config=basemod.config,
    hidden_size_input=basemod.config.hidden_size, 
    position_embedding_type=basemod.config.position_embedding_type,
    dim_reduction = 2
)

bertlayer_l3_reduction = BertSelfAttnDimensionReduction(
    config=basemod.config,
    hidden_size_input=basemod.config.hidden_size // 2, 
    position_embedding_type=basemod.config.position_embedding_type,
    dim_reduction = 2
)

In [215]:
out_l2 = bertlayer_l2_reduction(
        hidden_states = hiddens_states_l1_reduced,
        attention_mask = extended_attention_mask_l1_reduced,
        head_mask=None,
        encoder_hidden_states = hidden_states_l1,
        encoder_attention_mask= tokens['extended_attention_mask'],
        past_key_value=None,
        output_attentions=False
    )
hidden_states_l2 = out_l2[0]
print(hidden_states_l2.shape)

torch.Size([2, 24, 384])


In [192]:
# Next dimension reduction: 
hiddens_states_l2_reduced = maxpool_l2(hidden_states_l2)
print(hidden_states_l2.shape)
print(hiddens_states_l2_reduced.shape)

# reduce dimension of attention mask
attention_mask_l2_reduced = maxpool_l2_attn(attention_mask_l1_reduced.float())
print(attention_mask_l2_reduced.shape)

# extend the dimension of the reduced attention_mask
extended_attention_mask_l2_reduced = basemod.get_extended_attention_mask(attention_mask_l2_reduced, attention_mask_l2_reduced.shape)
print(extended_attention_mask_l2_reduced.shape)

if True:
  out_l3 = bertlayer_l3_reduction(
        hidden_states = hiddens_states_l2_reduced, # input has been maxpooled
        attention_mask = extended_attention_mask_l2_reduced,
        head_mask=None,
        encoder_hidden_states = hidden_states_l2,
        encoder_attention_mask= extended_attention_mask_l1_reduced,
        past_key_value=None,
        output_attentions=False
    )
  hidden_states_l3 = out_l3[0]
  print(hidden_states_l3.shape)


# The outputs of the bertlayer_l3_reduction can now run through a usual BertLayer for 3 times

torch.Size([2, 24, 384])
torch.Size([2, 12, 384])
torch.Size([2, 12])
torch.Size([2, 1, 1, 12])
torch.Size([2, 12, 192])


In [193]:
# The outputs of the bertlayer_l3_reduction can now run through a usual BertLayer for 3 times

config_lowres_encoder = copy.deepcopy(basemod.config)
config_lowres_encoder.hidden_size = config_lowres_encoder.hidden_size//4
config_lowres_encoder.num_hidden_layers = 3
print(config_lowres_encoder)

# The outputs of the bertlayer_l3_reduction can now run through a usual BertLayer for 3 times
encoder_lowres = BertEncoder(config_lowres_encoder)

RobertaConfig {
  "_name_or_path": "distilroberta-base",
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 192,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 3,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.29.2",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 50265
}



In [194]:
out_encoder_lowres = encoder_lowres(
    hidden_states=hidden_states_l3,
    attention_mask=extended_attention_mask_l2_reduced,
    head_mask = None,
    return_dict=True,
)
hidden_states_lowres = out_encoder_lowres[0]
print(hidden_states_lowres.shape)

torch.Size([2, 12, 192])


In [195]:
## Upresolution Layer: up-resolution from dim-3 to dim-2 is as follows:
# hs_l3 -> upsampled sequence-length as hs-l2
# -> could have another attention-based mechanism that expands dimension of hs-l2

class InterpolateCombo(nn.Module):
    """there could also be an attentive way to do this"""
    def __init__(self, scale_factor=2, dropout=0.05, alpha=0.667):
        """Arguments:
        :param scaler_factor: float, multiple of up-scaling
        :param dropout: float, dropout proportion
        :param alpha: float, mixture weight between nearest-neighbor vs linear-interpolation
        """
        super(InterpolateCombo, self).__init__()
        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.dropout = nn.Dropout(dropout)
        self.a = alpha
        
    def forward(self, x):
        x_trans = x.transpose(-2,-1)
        z = self.a*self.interp(x_trans, mode='nearest',scale_factor=self.scale_factor) + (1-self.a)*self.interp(x_trans, mode='linear',scale_factor=self.scale_factor)
        z = self.dropout(z)
        return z.transpose(-2,-1)

#hidden_states_upscaled_3to2_nearest = nn.functional.interpolate(hidden_states_rowres.transpose(-2,-1), scale_factor=2, mode='nearest').transpose(-2,-1)
#hidden_states_upscaled_3to2_linear = nn.functional.interpolate(hidden_states_rowres.transpose(-2,-1), scale_factor=2, mode='linear').transpose(-2,-1)

upscaler_x2 = InterpolateCombo(scale_factor=2)

In [197]:
hidden_states_upscaled3to2 = upscaler_x2(hidden_states_lowres)


In [198]:
## BertAttentiveIntegrator

class BertCrossAttention(nn.Module):
    def __init__(
        self, 
        config, 
        hidden_size,
        hidden_size_query,
        position_embedding_type=None
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.hidden_size_query = hidden_size_query
        if self.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(self.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(self.hidden_size_query, self.all_head_size)
        self.key = nn.Linear(self.hidden_size, self.all_head_size)
        self.value = nn.Linear(self.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        query_hidden_states: Optional[torch.FloatTensor] = None,
        query_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        mixed_query_layer = self.query(query_hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)
        
        use_cache = past_key_value is not None
        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if use_cache:
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                    -1, 1
                )
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs

In [199]:
bertlayer_l3_to_l2_crossattn = BertCrossAttention(
        config=basemod.config, 
        hidden_size=silo_dimensions[1],
        hidden_size_query=silo_dimensions[2],
        position_embedding_type=None
    )

In [200]:
print(hidden_states_upscaled3to2.shape)
print(hidden_states_l2.shape)
print(attention_mask_l1_reduced.shape)
print(extended_attention_mask_l1_reduced.shape)

torch.Size([2, 24, 192])
torch.Size([2, 24, 384])
torch.Size([2, 24])
torch.Size([2, 1, 1, 24])


In [201]:
out_l2_postencode = bertlayer_l3_to_l2_crossattn(
    hidden_states = hidden_states_l2,
    attention_mask = extended_attention_mask_l1_reduced,
    head_mask = None,
    query_hidden_states = hidden_states_upscaled3to2,
    query_attention_mask = attention_mask_l1_reduced
)
hidden_states_l2_postencode = out_l2_postencode[0]
print(hidden_states_l2_postencode.shape)
assert hidden_states_l2_postencode.shape == hidden_states_l2.shape

torch.Size([2, 24, 384])


In [202]:
print(basemod.config.hidden_size)
print(basemod.config.intermediate_size)
print(basemod.config.intermediate_size/basemod.config.hidden_size)

768
3072
4.0


In [203]:
# how does bert actually work?
"""
input = x

BertLayer:
- BertAttention
--- x2 = BertSelfAttention(x)
--- x3 = BertSelfOutput(x2,x) -> lnorm(drop(f(x2)) + x)
- BertIntermediate (expension:  4*hidden_size)
--- x4_ex = activation(f(x3)) # expansion (4*)
- BertOutput
--- x5 = lnorm(drop(f(x4_ex)) + x3 )


inputs = x_l2, x_l3_up

BertIntegrativeLayer:
- x2 = BertCrossAttention(k,v=x_l2, q=x_l3_up)
- x3 = lnorm(drop(f(x2)) + x_l2)
- x4_ex = activation( f(cat(x3, x_l3_up))  )
- x5 = lnorm(drop(f(x4_ex)) + x3)
"""


class BertIntegrativeLayer(nn.Module):
    """Vanilla Bert Layer, but integrates other hiddens states from a parallel transformers stack typically low-re"""
    def __init__(
            self, 
            config, 
            hidden_size, 
            hidden_size_query,
            intermediate_size=None
        ):
        super().__init__()
        #self.chunk_size_feed_forward = config.chunk_size_feed_forward
        #self.seq_len_dim = 1
        self.cat = torch.cat
        if intermediate_size is None:
            intermediate_size = int(4*hidden_size)
        self.intermediate_size = intermediate_size
        self.hidden_size = hidden_size
        self.hidden_size_query = hidden_size_query
        self.hidden_size_concat = int(hidden_size + hidden_size_query)

        # cross attention between (low-res) query and hidden layers below
        self.attention = BertCrossAttention(
            config, 
            hidden_size,
            hidden_size_query,
            position_embedding_type="absolute"
        )
        self.is_decoder = config.is_decoder
        #self.intermediate = BertIntermediate(config)
        #self.output = BertOutput(config)
        #- x2 = BertCrossAttention(k,v=x_l2, q=x_l3_up)
        #- x3 = lnorm(drop(f(x2)) + x_l2)
        #- x4_ex = activation( f(cat(x3, x_l3_up))  )
        #- x5 = lnorm(drop(f(x4_ex)) + x3)

        # corresponds to BertAttention SelfOutput
        self.output_attn = nn.Linear(self.hidden_size, self.hidden_size)
        self.lnorm_attn = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
        self.dropout_attn = nn.Dropout(config.hidden_dropout_prob)

        # corresponds to BertIntermediate
        self.intermediate = nn.Linear(self.hidden_size_concat, self.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

        # corresponds to BertOutput
        self.output_intm = nn.Linear(self.intermediate_size, self.hidden_size)
        self.lnorm_intm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
        self.dropout_intm = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        query_hidden_states: Optional[torch.FloatTensor] = None,
        query_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None

        # cross attn between hiddens states and (low-res) query vector
        cross_attn_outputs = self.attention(
            hidden_states = hidden_states,
            attention_mask = attention_mask,
            head_mask = head_mask,
            query_hidden_states = query_hidden_states,
            query_attention_mask = query_attention_mask
        )
        cross_hidden_states = cross_attn_outputs[0]

        # first Add+Norm skip connection (BertSelfOutput)
        cross_hidden_states = self.dropout_attn(self.output_attn(cross_hidden_states))
        hidden_states = self.lnorm_attn(cross_hidden_states + hidden_states)

        # intermediate expension
        intermediate_states = self.intermediate_act_fn(self.intermediate(
            self.cat((hidden_states, query_hidden_states),axis=2)
        ))
        assert intermediate_states.shape[0]==hidden_states.shape[0]
        assert intermediate_states.shape[1]==hidden_states.shape[1]

        # BertOutput
        intermediate_states = self.dropout_intm(self.output_intm(intermediate_states))
        out_states = self.lnorm_intm(intermediate_states + hidden_states)

        #- x2 = BertCrossAttention(k,v=x_l2, q=x_l3_up)
        #- x3 = lnorm(drop(f(x2)) + x_l2)
        #- x4_ex = activation( f(cat(x3, x_l3_up))  )
        #- x5 = lnorm(drop(f(x4_ex)) + x3)
        return out_states


In [204]:

# from low-res to mid-res
bert_integrative_layer_midres = BertIntegrativeLayer(
    basemod.config, 
    hidden_size=silo_dimensions[1], 
    hidden_size_query=silo_dimensions[2],
    intermediate_size=silo_dimensions[1]*4,
)

# from mid-res to high-res
bert_integrative_layer_hires = BertIntegrativeLayer(
    basemod.config, 
    hidden_size=silo_dimensions[0], 
    hidden_size_query=reintegration_dim,
    intermediate_size=silo_dimensions[0]*4,
)

In [205]:
hidden_states_midres = bert_integrative_layer_midres(
    hidden_states = hidden_states_l2,
    attention_mask = extended_attention_mask_l1_reduced,
    head_mask = None,
    query_hidden_states = hidden_states_upscaled3to2,
    query_attention_mask = attention_mask_l1_reduced
)
print(hidden_states_midres.shape)
assert hidden_states_midres.shape == hidden_states_l2.shape

torch.Size([2, 24, 384])


In [206]:
# upscale the l2 and l3 to the full dimension
upscaler_x4 = InterpolateCombo(scale_factor=4)
hidden_states_upscaled3to1 = upscaler_x4(hidden_states_lowres)
hidden_states_upscaled2to1 = upscaler_x2(hidden_states_midres)

hidden_states_upscaled = torch.cat(
    (hidden_states_upscaled2to1, hidden_states_upscaled3to1),
    axis=2)

print(hidden_states_upscaled.shape)

torch.Size([2, 48, 576])


In [207]:
# final layer to bring it up to full dimension
hidden_states_hires = bert_integrative_layer_hires(
    hidden_states = hidden_states_l1,
    attention_mask = extended_attention_mask,
    head_mask = None,
    query_hidden_states = hidden_states_upscaled,
    query_attention_mask = extended_attention_mask
)
print(hidden_states_hires.shape)
assert hidden_states_hires.shape == hidden_states_l1.shape

torch.Size([2, 48, 768])


In [208]:
hidden_states_hires.shape

torch.Size([2, 48, 768])

In [210]:
attention_mask_l1_reduced.shape

torch.Size([2, 24])

### The Reduce and Integrate layer:
- this is like a Transformer block, but:
- does dimension reduction along sequence and embedding-dim
- includes a skip connection from previous hidden-states of the same dimension

In [254]:



# this is the layer that just does cross-attention between a seq-reduced query and full-size value and key


"""
input = x

BertLayer:
- BertAttention
--- x2 = BertSelfAttention(x)
--- x3 = BertSelfOutput(x2,x) -> lnorm(drop(f(x2)) + x)
- BertIntermediate (expension:  4*hidden_size)
--- x4_ex = activation(f(x3)) # expansion (4*)
- BertOutput
--- x5 = lnorm(drop(f(x4_ex)) + x3 )


inputs = x_l2, x_l3_up

BertIntegrativeLayer:
- x2 = BertCrossAttention(k,v=x_l2, q=x_l3_up)
- x3 = lnorm(drop(f(x2)) + x_l2)
- x4_ex = activation( f(cat(x3, x_l3_up))  )
- x5 = lnorm(drop(f(x4_ex)) + x3)


BertReduceAddIntegrativeLayer
inputs = x_l1, x_l1_reduced, x_l2_prev
- x2 = BertCrossAttention(k,v=x_l1, q= cat(x_l1_reduced, x_l2_prev) ) -notice three inputs
- x3 = lnorm(drop(f(x2)) + x_l2_prev)
- x4_ex = activation( f(cat(x3, x_l1_reduced))  )
- x5 = lnorm(drop(f(x4_ex)) + x3)
"""


class BertReduceAddIntegrativeLayer(nn.Module):
    """Bert Layer that does dimenion reduction along embedding-dimenion and integrations a skip connection"""
    def __init__(
            self, 
            config, 
            hidden_size,
            hidden_size_input=None,
            hidden_size_query=None,
            intermediate_size=None,
            dim_reduction=2,
            do_concat_hidden_and_query = True
        ):
        super().__init__()
        #self.chunk_size_feed_forward = config.chunk_size_feed_forward
        #self.seq_len_dim = 1
        self.cat = torch.cat
        self.do_concat_hidden_and_query = do_concat_hidden_and_query
        assert bool(do_concat_hidden_and_query), 'not implemented: concatenation of query and hidden-states must happen'
        self.hidden_size = hidden_size
        if dim_reduction is None:
            dim_reduction = 2
        self.dim_reduction = dim_reduction
        if intermediate_size is None:
            intermediate_size = int(4*hidden_size)
        self.intermediate_size = intermediate_size
        if hidden_size_input is None:
            hidden_size_input = hidden_size
        self.hidden_size_input = hidden_size_input
        if hidden_size_query is None:
            hidden_size_query = hidden_size_input
        self.hidden_size_query = hidden_size_query + do_concat_hidden_and_query*hidden_size
        self.hidden_size_concat = int(hidden_size + hidden_size_input)

        # cross attention between (low-res) query and hidden layers below
        self.attention = BertSelfAttnDimensionReduction( 
            config, 
            hidden_size_input=self.hidden_size_input, 
            hidden_size_query = self.hidden_size_query,
            position_embedding_type="absolute", 
            dim_reduction = self.dim_reduction
        )
        self.is_decoder = config.is_decoder
        #inputs = x_l1, x_l1_reduced, x_l2_prev
        #- x2 = BertCrossAttention(k,v=x_l1, q= cat(x_l1_reduced, x_l2_prev) ) -notice three inputs
        #- x3 = lnorm(drop(f(x2)) + x_l2_prev)
        #- x4_ex = activation( f(cat(x3, x_l1_reduced))  )
        #- x5 = lnorm(drop(f(x4_ex)) + x3)

        # corresponds to BertAttention SelfOutput
        self.output_attn = nn.Linear(self.hidden_size, self.hidden_size)
        self.lnorm_attn = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
        self.dropout_attn = nn.Dropout(config.hidden_dropout_prob)

        # corresponds to BertIntermediate
        self.intermediate = nn.Linear(self.hidden_size_concat, self.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

        # corresponds to BertOutput
        self.output_intm = nn.Linear(self.intermediate_size, self.hidden_size)
        self.lnorm_intm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
        self.dropout_intm = nn.Dropout(config.hidden_dropout_prob)
    
    def forward(
        self,
        inputs: torch.Tensor, # higher-resolution inputs for key and values (long sequence dimension)
        hidden_states: torch.Tensor, # previous hidden-states for skip connection (short squence-dim, low-res)
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        query_hidden_states: torch.FloatTensor = None, # hidden-states for query (short squence-dim, low-res)
        query_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None

        if self.do_concat_hidden_and_query:
            query_hidden_states_plus = torch.cat((query_hidden_states, hidden_states),axis=2)
        # cross attn between (low-res) query vector and (high-res) key-values
        cross_attn_outputs = self.attention(
            query_hidden_states_plus, # query (short seq-dim, high-res)
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states = inputs, # for key/value (longer sequence dimension, high-res)
            past_key_value=past_key_value,
            output_attentions=output_attentions,
        )
        cross_hidden_states = cross_attn_outputs[0]

        # first Add+Norm skip connection (BertSelfOutput)
        cross_hidden_states = self.dropout_attn(self.output_attn(cross_hidden_states))
        hidden_states = self.lnorm_attn(cross_hidden_states + hidden_states)

        # intermediate expension
        intermediate_states = self.intermediate_act_fn(self.intermediate(
            self.cat((hidden_states, query_hidden_states),axis=2)
        ))
        assert intermediate_states.shape[0]==hidden_states.shape[0]
        assert intermediate_states.shape[1]==hidden_states.shape[1]

        # BertOutput
        intermediate_states = self.dropout_intm(self.output_intm(intermediate_states))
        out_states = self.lnorm_intm(intermediate_states + hidden_states)

        #inputs = x_l1, x_l1_reduced, x_l2_prev
        #- x2 = BertCrossAttention(k,v=x_l1, q= cat(x_l1_reduced, x_l2_prev) ) -notice three inputs
        #- x3 = lnorm(drop(f(x2)) + x_l2_prev)
        #- x4_ex = activation( f(cat(x3, x_l1_reduced))  )
        #- x5 = lnorm(drop(f(x4_ex)) + x3)
        return out_states


In [256]:
# initialize the mid-resolution BertReduceAndIntegrate layer
bert_reduce_add_integrate_midres = BertReduceAddIntegrativeLayer(
    config, 
    hidden_size = silo_dimensions[1], # size of mid-res
    hidden_size_input=silo_dimensions[0],
    hidden_size_query=silo_dimensions[0],
    intermediate_size=silo_dimensions[1]*3,
    dim_reduction=2,
    do_concat_hidden_and_query = True
)

bert_reduce_add_integrate_lowres = BertReduceAddIntegrativeLayer(
    config, 
    hidden_size = silo_dimensions[2], # size of mid-res
    hidden_size_input=silo_dimensions[1],
    hidden_size_query=silo_dimensions[1],
    intermediate_size=silo_dimensions[2]*3,
    dim_reduction=2,
    do_concat_hidden_and_query = True
)

In [257]:
# Reduce sequence-dim from l1->l2, and from high-res->mid-res
hidden_states_hires_reduced = maxpool_l2(hidden_states_hires)
assert hidden_states_hires_reduced.shape[1] == hidden_states_midres.shape[1] # reduced-seq-dim should be same as mid-res hidden-states
print(hidden_states_midres.shape)
hidden_states_midres = bert_reduce_add_integrate_midres(
    inputs = hidden_states_hires, # from highres outputs previous layer (key, values)
    hidden_states = hidden_states_midres, # previous hidden-states for skip connection (short squence-dim, low-res)
    attention_mask = extended_attention_mask_l1_reduced,
    head_mask=None,
    query_hidden_states = hidden_states_hires_reduced # reduced version of high-res inputs (reduced along sequence dimenion)
)
print(hidden_states_midres.shape)

torch.Size([2, 24, 384])
torch.Size([2, 24, 384])


In [258]:
# Reduce sequence-dim from l1->l2, and from high-res->mid-res
hidden_states_midres_reduced = maxpool_l2(hidden_states_midres)
assert hidden_states_midres_reduced.shape[1] == hidden_states_lowres.shape[1] # reduced-seq-dim should be same as mid-res hidden-states
print(hidden_states_midres_reduced.shape)

if True:
  print(hidden_states_lowres.shape)
  hidden_states_lowres = bert_reduce_add_integrate_lowres(
      inputs = hidden_states_midres, # from highres outputs previous layer (key, values)
      hidden_states = hidden_states_lowres, # previous hidden-states for skip connection (short squence-dim, low-res)
      attention_mask = extended_attention_mask_l2_reduced,
      head_mask=None,
      query_hidden_states = hidden_states_midres_reduced # reduced version of high-res inputs (reduced along sequence dimenion)
  )
  print(hidden_states_lowres.shape)

torch.Size([2, 12, 384])
torch.Size([2, 12, 192])
torch.Size([2, 12, 192])


### Base-Layer nn.Module

In [109]:
from transformers import AutoModel, AutoTokenizer, AutoConfig
import torch
from torch import nn
from transformers.models.bert.modeling_bert import BertEncoder
from transformers.activations import ACT2FN
from typing import List, Optional, Tuple, Union

def make_config(
    modelstring = "distilroberta-base",
    scale_ratio2 = 0.5,
    scale_ratio3 = 0.25,
    scale_intermediate2 = 4,
    scale_intermediate3 = 4,
    num_layers_l3 = 3,
    dropout_scaling = 0.05
):
    #if True:
    #modelstring = "distilroberta-base"
    #scale_ratio2 = 0.5
    #scale_ratio3 = 0.25
    #scale_intermediate2 = 4
    #scale_intermediate3 = 4
    base_config = AutoConfig.from_pretrained(modelstring)
    config_l3 = copy.deepcopy(base_config)
    setattr(base_config, 'model_string', modelstring)
    setattr(base_config,'num_layers_l3', num_layers_l3)
    setattr(base_config,'scale_ratio2', scale_ratio2)
    setattr(base_config,'scale_ratio3', scale_ratio3)
    setattr(base_config,'scale_factor2', int(1/base_config.scale_ratio2))
    setattr(base_config,'scale_factor3', int(1/base_config.scale_ratio3*base_config.scale_ratio2))
    setattr(base_config,"hidden_size_l2", int(base_config.hidden_size * scale_ratio2))
    setattr(base_config,"hidden_size_l3", int(base_config.hidden_size * scale_ratio3))
    setattr(base_config,"intermediate_size_l1", int(base_config.hidden_size_l2*4.0))
    setattr(base_config,"intermediate_size_l2", int(base_config.hidden_size_l3*4.0))
    setattr(base_config,"query_size1", base_config.hidden_size_l2 + base_config.hidden_size_l3)
    setattr(base_config,"query_size2", base_config.hidden_size_l3)
    setattr(base_config,"dropout_scaling", dropout_scaling)

    # make the configuration for the l3 encoder
    config_l3.hidden_size = base_config.hidden_size_l3
    config_l3.num_hidden_layers = num_layers_l3
    setattr(base_config, 'config_l3', config_l3)
    return base_config

def initialize_baselayers(config, basemod = None, tokenizer=None):
    # initialize the basemodel
    if basemod is None:
        basemod = AutoModel.from_pretrained(config.model_string)
    if tokenizer is None:
        # download pretrained tokenizer
        tokenizer = AutoTokenizer.from_pretrained(config.model_string)
    
    device = basemod.device
    setattr(config, 'device', device)

    # get basemodel's embeddings
    layer_embedding = copy.deepcopy(basemod._modules['embeddings'])

    # get basemodel's first transformer block
    layer_basetransformer = copy.deepcopy(basemod._modules['encoder']._modules['layer']._modules['0'])

    # initialize the maxpooling downsamplers
    maxpool = nn.Sequential(
        nn.Dropout(config.dropout_scaling),
        nn.MaxPool2d((2,1), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True)
    )
    # pooling the attention has no dropout
    maxpool_attn = nn.MaxPool1d((2), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True)

    # initialize downsampling attention layers
    bert_reducer_l2 = BertSelfAttnDimensionReduction(
        config=config,
        hidden_size_input=config.hidden_size, 
        position_embedding_type=config.position_embedding_type,
        dim_reduction = config.scale_factor2
    )
    # 1/4 hidden size
    bert_reducer_l3 = BertSelfAttnDimensionReduction(
        config=config,
        hidden_size_input=config.hidden_size_l2, 
        position_embedding_type=config.position_embedding_type,
        dim_reduction = config.scale_factor3
    )

    # initialize the low-resolution BertEncoder
    bert_encoder_lowres = BertEncoder(config.config_l3)

    # initailize the upscalers
    upscaler_x2 = InterpolateCombo(scale_factor=config.scale_factor3, dropout=config.dropout_scaling)
    upscaler_x4 = InterpolateCombo(scale_factor=int(1/config.scale_ratio3), dropout=config.dropout_scaling)

    # initialize the BertIntegrative Layers: low res to mid res
    bert_integrative_layer_2 = BertIntegrativeLayer(
        config, 
        hidden_size=config.hidden_size_l2,
        hidden_size_query=config.hidden_size_l3,
        intermediate_size=config.intermediate_size_l2
    )

    # from mid-res to high-res
    bert_integrative_layer_1 = BertIntegrativeLayer(
        config, 
        hidden_size=config.hidden_size,
        hidden_size_query=config.query_size1,
        intermediate_size=config.intermediate_size_l1
    )

    return (
        tokenizer,
        basemod,
        layer_embedding,
        layer_basetransformer,
        maxpool,
        maxpool_attn,
        bert_reducer_l2,
        bert_reducer_l3,
        bert_encoder_lowres,
        upscaler_x2,
        upscaler_x4,
        bert_integrative_layer_2,
        bert_integrative_layer_1
    )

class AnathemBaseModule(nn.Module):
    def __init__(
            self, 
            config,
            basemod=None,
            tokenizer=None,
            past_key_values_length = None
        ):
        super().__init__()
        self.config = config

        # initalize the layers
        (
            tokenizer, basemod, 
            layer_embedding,
            layer_basetransformer,
            maxpool,
            maxpool_attn,
            bert_reducer_l2,
            bert_reducer_l3,
            bert_encoder_lowres,
            upscaler_x2,
            upscaler_x4,
            bert_integrative_layer_2,
            bert_integrative_layer_1
        ) = initialize(config, basemod, tokenizer)

        self.get_extended_attention_mask = basemod.get_extended_attention_mask
        self.embedding = layer_embedding
        self.layer_basetransformer = layer_basetransformer
        self.maxpool = maxpool
        self.maxpool_attn = maxpool_attn
        self.bert_reducer_l2 = bert_reducer_l2
        self.bert_reducer_l3 = bert_reducer_l3
        self.bert_encoder_lowres = bert_encoder_lowres
        self.upscaler_x2 = upscaler_x2
        self.upscaler_x4 = upscaler_x4
        self.bert_integrative_layer_2 = bert_integrative_layer_2
        self.bert_integrative_layer_1 = bert_integrative_layer_1
    
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = False
    ):
        input_shape = input_ids
        past_key_values_length =0 if past_key_values is None else len(past_key_values)

        # extend attention mask
        extended_attention_mask_l1 = self.get_extended_attention_mask(attention_mask, input_shape)
        # downsample the attention mask to l2 dimension
        attention_mask_l2 = self.maxpool_attn(attention_mask.float())
        extended_attention_mask_l2 = self.get_extended_attention_mask(attention_mask_l2,attention_mask_l2.shape)
        # downsample the attention mask to l3 dimension
        attention_mask_l3 = self.maxpool_attn(attention_mask_l2.float())
        extended_attention_mask_l3 = self.get_extended_attention_mask(attention_mask_l3,attention_mask_l3.shape)
        
        # embed
        embedding_output = self.embedding(
            input_ids = input_ids,
            position_ids = position_ids,
            token_type_ids = token_type_ids,
            #input_embeds=None,
            past_key_values_length = past_key_values_length
        )

        # first transformer block (vanilla transformer)
        out_l1 = self.layer_basetransformer(
            hidden_states = embedding_output,
            attention_mask = extended_attention_mask_l1,
            head_mask=head_mask,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            output_attentions=output_attentions
        )
        hidden_states_l1 = out_l1[0]

        # downsample to sequence 1 to length sequence 2
        hiddens_states_l1_reduced = self.maxpool(hidden_states_l1)

        # reduce dimenion on sequence 2
        out_l2 = self.bert_reducer_l2(
            hidden_states = hiddens_states_l1_reduced,
            attention_mask = extended_attention_mask_l2,
            head_mask=head_mask,
            encoder_hidden_states = hidden_states_l1,
            encoder_attention_mask= extended_attention_mask_l1,
            past_key_value=past_key_values,
            output_attentions=output_attentions,
        )
        hidden_states_l2 = out_l2[0]

        # TODO: I should have another vanilla encoder here at this intermediate step
        #
        # intermediate attention
        #
        
        # reduce sequence length
        hiddens_states_l2_reduced = self.maxpool(hidden_states_l2)

        # reduce dimenion on sequence 2
        out_l3 = self.bert_reducer_l3(
            hidden_states = hiddens_states_l2_reduced,
            attention_mask = extended_attention_mask_l3,
            head_mask=head_mask,
            encoder_hidden_states = hidden_states_l2,
            encoder_attention_mask= extended_attention_mask_l2,
            past_key_value=past_key_values,
            output_attentions=output_attentions,
        )
        hidden_states_l3 = out_l3[0]

        #print(hidden_states_l3.shape)
        #print(extended_attention_mask_l3.shape)
        # BertEncoder at low-res
        out_encoder = self.bert_encoder_lowres(
            hidden_states=hidden_states_l3,
            attention_mask=extended_attention_mask_l3,
            head_mask = head_mask,
            return_dict=return_dict
        )
        hidden_states_l3 = out_encoder[0]
        
        # upscaling: l3 to l2
        hidden_states_upscaled3to2 = self.upscaler_x2(hidden_states_l3)

        # integrate sequence-2 and upscaled sequence-3
        hidden_states_l2 = self.bert_integrative_layer_2(
            hidden_states = hidden_states_l2,
            attention_mask = extended_attention_mask_l2,
            head_mask = head_mask,
            query_hidden_states = hidden_states_upscaled3to2,
            query_attention_mask = attention_mask_l2
        )
        
        # upscaling: l3/l2 to l1 sequence length
        hidden_states_upscaled3to1 = self.upscaler_x4(hidden_states_l3)
        hidden_states_upscaled2to1 = self.upscaler_x2(hidden_states_l2)
        hidden_states_upscaled = torch.cat((
            hidden_states_upscaled2to1, hidden_states_upscaled3to1
        ),axis=2)        

        # integrate low-resolution information back to original dimension
        hidden_states_l1 = self.bert_integrative_layer_1(
            hidden_states = hidden_states_l1,
            attention_mask = extended_attention_mask_l1,
            head_mask = head_mask,
            query_hidden_states = hidden_states_upscaled,
            query_attention_mask = extended_attention_mask_l1
        )
        return hidden_states_l1


class BertClassificationHead(nn.Module):
    def __init__(self, config, n_classes = 1, activation = 'sigmoid'):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size*2, n_classes)
        if activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'sigmoid':
            self.activation = torch.sigmoid
        elif activation == 'none':
            self.activation = lambda x: x

    def forward(self, hidden_states, attention_mask) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        output_vectors=[]
        first_token_tensor = hidden_states[:, 0]
        output_vectors.append(first_token_tensor)
        # mean pooling
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        sum_embeddings = torch.sum(hidden_states * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        output_vectors.append(sum_embeddings / sum_mask)
        # concatenate
        pooled_output = torch.concat(output_vectors, axis=1)
        #print(pooled_output.shape)
        logits = self.dense(pooled_output)
        return self.activation(logits)


def tokenize_anathem(text):
    #padding_length = int(math.ceil(max_length / 4)) * 
    tokens = tokenizer(text,padding=True, return_tensors='pt')
    input_shape = tokens['input_ids'].size()

    # change token padding to be multiple of 4
    ideal_length = int(math.ceil(input_shape[-1] / 4)) * 4 # should be a multiple of 4
    if input_shape[-1]!=ideal_length:
      tokens = tokenizer(text,padding='max_length', max_length = ideal_length, return_tensors='pt')
      input_shape = tokens['input_ids'].size()

    token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
    tokens['token_type_ids'] = token_type_ids
    
    return tokens

In [110]:
config = make_config('distilroberta-base')

if False:
  (tokenizer,basemod,layer_embedding,layer_basetransformer,maxpool,maxpool_attn,bert_reducer_l2,
   bert_reducer_l3,bert_encoder_lowres,upscaler_x2,upscaler_x4,bert_integrative_layer_2,bert_integrative_layer_1) = initialize(config)

# make the basemod and tokenizer
basemod = AutoModel.from_pretrained(config.model_string)
tokenizer = AutoTokenizer.from_pretrained(config.model_string)



Some weights of the model checkpoint at distilroberta-base were not used when initializing RobertaModel: ['lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.bias', 'lm_head.dense.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [116]:
# the Anathem encoder includes the embeddings and first transformer block
anathem_encoder1 = AnathemBaseModule(config, basemod, tokenizer)

In [117]:
cls_head = BertClassificationHead(config, n_classes = 3, activation = 'none')

In [118]:
text = [
    "A standard indemnity clause is a waiver clause that states that one party won't hold the other liable for damages, losses, or costs associated with legal issues1.",
    "It usually consists of two elements: a trigger event or circumstance and a payment obligation2. The trigger event or circumstance is the breach of the agreement, willful misconduct, or negligence of the indemnifying party or its affiliates"
]

tokens = tokenize_anathem(text)

In [119]:
out = anathem_encoder1(
      input_ids = tokens['input_ids'],
      attention_mask = tokens['attention_mask'],
      token_type_ids = tokens['token_type_ids']
)
cls_head(out, tokens['attention_mask'])



tensor([[ 0.2563,  0.3797, -0.4073],
        [ 0.3182,  0.2725, -0.4744]], grad_fn=<AddmmBackward0>)

In [120]:
####

In [121]:
## Next steps, do something simple like sentiment analysis

In [122]:
from datasets import list_datasets, load_dataset
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
import numpy as np
from tqdm import tqdm
from torch.optim import AdamW
from sklearn.metrics import precision_recall_fscore_support
from scipy.special import softmax
#datasets_list = list_datasets()
#[k for k in datasets_list if 'phrasebank' in k]


In [123]:
[k for k in datasets_list if 'phrasebank' in k]

dataset = load_dataset('financial_phrasebank', 'sentences_75agree')

# split
idx_train, idx_val = train_test_split(np.arange(len(dataset['train']['sentence'])), test_size=0.1)
dataset_train = [{'text':dataset['train']['sentence'][idx], 'label':dataset['train']['label'][idx]}  for idx in idx_train]
dataset_val = [{'text':dataset['train']['sentence'][idx], 'label':dataset['train']['label'][idx]} for idx in idx_val]



  0%|          | 0/1 [00:00<?, ?it/s]

In [124]:
print(len(dataset_train)); print(len(dataset_val))

3107
346


In [125]:
class MyDataset(Dataset):
    """torch dataset."""

    def __init__(self, dataset):
        self.data = dataset
        self.n = len(self.data)
    
    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        unit = self.data[idx]
        return unit

In [126]:
ds_train = MyDataset(dataset_train)
ds_val = MyDataset(dataset_val)

In [169]:
batch_size_train = 12
batch_size_val = 36
lr = 0.0001
eval_iter = 100
n_epochs = 5

In [135]:
dl_train = DataLoader(ds_train, batch_size=batch_size_train, shuffle=True)
dl_val = DataLoader(ds_val, batch_size=batch_size_val, shuffle=False)

In [None]:
optimizer = AdamW(list(anathem_encoder1.parameters())+ list(cls_head.parameters()), lr=lr)

In [173]:

optimizer.zero_grad()
anathem_encoder1.train()
cls_head.train()
for epoch in range(n_epochs):

  for iteration, batch in enumerate(tqdm(dl_train, disable=True)):
      
      # tokenize the batch
      tokens = tokenize_anathem(batch['text'])
      target = batch['label']
      
      optimizer.zero_grad()

      features = anathem_encoder1(
        input_ids = tokens['input_ids'],
        attention_mask = tokens['attention_mask'],
        token_type_ids = tokens['token_type_ids']
      )

      # prediction
      preds = cls_head(features, tokens['attention_mask'])

      # loss
      loss = nn.functional.cross_entropy(preds, target)
      loss.backward()
      optimizer.step()

      # do evaluation
      if ((i+1) % eval_iter)==0:  
          anathem_encoder1.eval()
          cls_head.eval()
          # tokenize the eval
          eval_logits = []
          eval_targets = []
          for i, batch_eval in enumerate(tqdm(dl_val, disable=True)):
              with torch.no_grad():
                  # tokenize the batch
                  tokens_eval = tokenize_anathem(batch_eval['text'])
                  labels_eval = batch_eval['label']
                  features = anathem_encoder1(
                      input_ids = tokens_eval['input_ids'],
                      attention_mask = tokens_eval['attention_mask'],
                      token_type_ids = tokens_eval['token_type_ids']
                  )
                  # prediction
                  batch_logits = cls_head(features, tokens_eval['attention_mask'])
                  eval_logits+=batch_logits.detach().tolist()
                  eval_targets+=labels_eval.detach().tolist()

          eval_prec,eval_recall,eval_f1,eval_support = precision_recall_fscore_support(eval_targets, np.array(eval_logits).argmax(axis=1),zero_division=0)
          print('E:%d; i:%d: f1:%0.3f (%0.3f); prec:%0.3f (%0.3f); rec:%0.3f (%0.3f)' % (epoch, iteration, eval_f1.mean(), eval_f1.min(), eval_prec.mean(), eval_prec.min(), eval_recall.mean(), eval_recall.min()))
          cls_head.train()
          anathem_encoder1.train()

          


E:0; i:14: f1:0.794 (0.716); prec:0.793 (0.644); rec:0.807 (0.674)
E:0; i:14: f1:0.841 (0.767); prec:0.832 (0.660); rec:0.868 (0.739)
E:1; i:14: f1:0.905 (0.880); prec:0.897 (0.846); rec:0.914 (0.880)
E:1; i:14: f1:0.878 (0.819); prec:0.864 (0.723); rec:0.904 (0.837)
E:2; i:14: f1:0.917 (0.886); prec:0.930 (0.912); rec:0.906 (0.861)
E:2; i:14: f1:0.885 (0.833); prec:0.887 (0.833); rec:0.882 (0.833)


KeyboardInterrupt: ignored

In [160]:
# do evaluation
if True:
    if True:
        if True: #if (i+1) % eval_iter:
            anathem_encoder1.eval()
            cls_head.eval()
            # tokenize the eval
            eval_logits = []
            eval_targets = []
            for i, batch in enumerate(tqdm(dl_val, disable=True)):
                with torch.no_grad():
                    # tokenize the batch
                    tokens = tokenize_anathem(batch['text'])
                    target = batch['label']
                    features = anathem_encoder1(
                      input_ids = tokens['input_ids'],
                      attention_mask = tokens['attention_mask'],
                      token_type_ids = tokens['token_type_ids']
                    )
                    # prediction
                    batch_logits = cls_head(features, tokens['attention_mask'])
                    eval_logits+=batch_logits.detach().tolist()
                    eval_targets+=target.detach().tolist()

            eval_prec,eval_recall,eval_f1,eval_support = precision_recall_fscore_support(eval_targets, np.array(eval_logits).argmax(axis=1),zero_division=0)
            print('E:%d; i:%d: f1:%0.3f (%0.3f); prec:%0.3f (%0.3f); rec:%0.3f (%0.3f)' % (epoch, i, eval_f1.mean(), eval_f1.min(), eval_prec.mean(), eval_prec.min(), eval_recall.mean(), eval_recall.min()))
            cls_head.train()
            anathem_encoder1.train()


E:%d; i:%d: f1:%0.3f; rec: %0.3f; prec:%0.3f 0 14 [0.         0.86464646 0.52173913] [0.         0.98165138 0.45652174] [0.         0.77256318 0.60869565]


In [168]:
print('E:%d; i:%d: f1:%0.3f (%0.3f); prec:%0.3f (%0.3f); rec:%0.3f (%0.3f)' % (epoch, i, eval_f1.mean(), eval_f1.min(), eval_prec.mean(), eval_prec.min(), eval_recall.mean(), eval_recall.min()))

E:0; i:14: f1:0.462 (0.000); prec:0.460 (0.000); rec:0.479 (0.000)


In [167]:
eval_f1

array([0.        , 0.86464646, 0.52173913])

In [159]:
eval_prec,eval_recall,eval_f1,eval_support = precision_recall_fscore_support(eval_targets, np.array(eval_logits).argmax(axis=1),zero_division=0)