In [4]:
from transformers.models.bert.modeling_bert import *


class AttentionFusionBertModel(BertModel):
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        attention_adj: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # past_key_values_length
        past_key_values_length = (
            past_key_values[0][0].shape[2] if past_key_values is not None else 0
        )

        if attention_mask is None:
            attention_mask = torch.ones(
                ((batch_size, seq_length + past_key_values_length)), device=device
            )

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(
                    batch_size, seq_length
                )
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(
                    input_shape, dtype=torch.long, device=device
                )

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
            attention_mask, input_shape
        )
        

        
        if attention_adj is not None:
            if attention_adj.dim() == 3:
                attention_adj = attention_adj[:, None, :, :]

            if attention_adj.shape[-1] != extended_attention_mask.shape[-1] or len(attention_adj.shape) != len(extended_attention_mask.shape):
                raise ValueError(
                    f"Shape of attention_adj does not match extended_attention_mask "
                )
            if attention_adj.device != extended_attention_mask.device:
                raise ValueError(
                    "attention_adj and extended_attention_mask must be on the same device"
                )
            if attention_adj.dtype != extended_attention_mask.dtype:
                raise ValueError(
                    "attention_adj and extended_attention_mask must have the same data type"
                )
            print(extended_attention_mask.shape)
            print(attention_adj.shape)
                
            extended_attention_mask = extended_attention_mask + attention_adj

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = (
                encoder_hidden_states.size()
            )
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(
                encoder_attention_mask
            )
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = (
            self.pooler(sequence_output) if self.pooler is not None else None
        )

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )


In [5]:
from transformers import BertConfig

# Define a basic configuration with minimal settings
config = BertConfig(
    # The number of hidden layers (transformer blocks)
    num_hidden_layers=1,
    # The hidden size of the encoder and pooler layers
    hidden_size=32,
    # The number of attention heads in each attention layer
    num_attention_heads=4,
    # The size of the input vocab
    vocab_size=30522,  # This is the vocab size for BERT base; adjust if using a different model
    # The maximum length of the input sequences
    max_position_embeddings=512,
    # The size of the intermediate (feed-forward) layer in the transformer
    intermediate_size=128,
    # Dropout probability
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    # Specify whether to add a pooling layer on top of the attention output for classification tasks
    add_pooling_layer=True
)

# Initialize the BertModel with the specified configuration
model = AttentionFusionBertModel(config)

# Print the model architecture
print(model)


AttentionFusionBertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 32, padding_idx=0)
    (position_embeddings): Embedding(512, 32)
    (token_type_embeddings): Embedding(2, 32)
    (LayerNorm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=32, out_features=32, bias=True)
            (key): Linear(in_features=32, out_features=32, bias=True)
            (value): Linear(in_features=32, out_features=32, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=32, out_features=32, bias=True)
            (LayerNorm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
        

In [7]:
import torch

# Define the sequence length for your dummy inputs
seq_length = 10

# Create random embeddings as dummy input data
# (batch size, sequence length, hidden size)
inputs_embeds = torch.rand((2, seq_length, config.hidden_size))

# Create an attention mask (batch size, sequence length)
# Here assuming no padding, so mask is all ones
attention_mask = torch.ones((2, seq_length))

attention_adj = torch.ones((2, seq_length, seq_length)) * -0.8

# You can also simulate having some padding by making part of the mask zeros
# For example, if only the first 7 tokens are not padding:
# attention_mask[:, 7:] = 0

# Run the model with the embedded input
outputs = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, return_dict=True, attention_adj=attention_adj)


torch.Size([2, 1, 1, 10])
torch.Size([2, 1, 10, 10])
torch.Size([2, 1, 10, 10])
torch.Size([2, 4, 10, 10])


In [None]:
torch.finfo(torch.float32).min

-3.4028234663852886e+38