In [1]:
from transformers.models.bert.modeling_bert import BertEncoder, BertPooler, BertEmbeddings, BaseModelOutputWithPoolingAndCrossAttentions, BertPreTrainedModel, BertOnlyMLMHead
from transformers import BertConfig, BertModel, BertForMaskedLM
import torch
from torch import nn
import torch.nn as nn
from transformers import BertTokenizer, BertConfig

2025-02-24 18:38:46.271129: I tensorflow/core/platform/cpu_feature_guard.cc:181] Beginning TensorFlow 2.15, this package will be updated to install stock TensorFlow 2.15 alongside Intel's TensorFlow CPU extension plugin, which provides all the optimizations available in the package and more. If a compatible version of stock TensorFlow is present, only the extension will get installed. No changes to code or installation setup is needed as a result of this change.
More information on Intel's optimizations for TensorFlow, delivered as TensorFlow extension plugin can be viewed at https://github.com/intel/intel-extension-for-tensorflow.
2025-02-24 18:38:46.271179: I tensorflow/core/platform/cpu_feature_guard.cc:192] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
class BertEmbeddingsV2(BertEmbeddings):
    def __init__(self, config):
        super().__init__(config)
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        
        # add a prosody embedding
        self.prosody_embeddings = nn.Embedding(config.prosody_cluster_size, config.hidden_size)

        # Convolutional layer
        self.conv = nn.Conv1d(in_channels=config.hidden_size, out_channels=config.hidden_size, kernel_size=3, padding=1)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

    def forward(self, input_ids=None, prosody_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        prosody_embeddings = self.prosody_embeddings(prosody_ids)  # here
        
        # combine the embeddings with new pos tagging embedding
        embeddings = inputs_embeds + token_type_embeddings + prosody_embeddings
        
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings

        # Apply 1D Convolution
        embeddings = embeddings.permute(0, 2, 1)  # (batch_size, hidden_size, seq_length)
        embeddings = self.conv(embeddings)  # Convolution
        embeddings = torch.relu(embeddings)  # Activation
        embeddings = embeddings.permute(0, 2, 1)  # Back to (batch_size, seq_length, hidden_size)

        
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

In [3]:
class BertModelV2(BertModel):
    """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
    To behave as an decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration
    set to :obj:`True`. To be used in a Seq2Seq model, the model needs to initialized with both :obj:`is_decoder`
    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
    input to the forward pass.
    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddingsV2(config)  # here
        self.encoder = BertEncoder(config)

        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.init_weights()
    
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        prosody_ids=None,  # add here
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            batch_size, seq_length = input_shape
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size, seq_length = input_shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            prosody_ids=prosody_ids,  # here
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]
  
        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )

In [4]:
# This is from the source code

class BertForMaskedLMV2(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        assert (
            not config.is_decoder
        ), "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for bi-directional self-attention."

        self.bert = BertModelV2(config)
        self.cls = BertOnlyMLMHead(config)
        
        # Prodsody Head (For predicting prosody cluster ids)
        self.prosody_head = nn.Linear(config.hidden_size, config.prosody_cluster_size)
        
        self.init_weights()

    def get_output_embeddings(self):
        """This helps from_pretrained() correctly locate the MLM head weights during weight loading."""
        return self.cls.predictions.decoder


    # @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    # @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
    def forward(
        self,
        prosody_ids=None, 
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        **kwargs
    ):
        outputs = self.bert(
            input_ids,
            prosody_ids=prosody_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
    
        sequence_output = outputs.last_hidden_state
        
        prediction_scores = self.cls(sequence_output)
        prosody_logits = self.prosody_head(sequence_output)
        
        loss = None
        if labels is not None and prosody_labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            token_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))  # MLM loss
            prosody_loss = loss_fct(prosody_logits.view(-1, config.prosody_cluster_size), prosody_labels.view(-1)) # prosody loss
            loss = token_loss + prosody_loss  # Combine losses
    
        return MaskedLMWithProsodyOutput(  # here
            loss=loss,
            logits=prediction_scores,
            prosody_logits=prosody_logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions if output_attentions else None,
        )

    # def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
    #     input_shape = input_ids.shape
    #     effective_batch_size = input_shape[0]

    #     #  add a dummy token
    #     assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
    #     attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
    #     dummy_token = torch.full(
    #         (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
    #     )
    #     input_ids = torch.cat([input_ids, dummy_token], dim=1)

    #     return {"input_ids": input_ids, "attention_mask": attention_mask}

In [5]:
class BertConfigV2(BertConfig):
    def __init__(self, prosody_cluster_size=3, **kwargs):
        super().__init__(**kwargs)
        self.prosody_cluster_size = prosody_cluster_size

In [6]:
from transformers.modeling_outputs import MaskedLMOutput
from dataclasses import dataclass
import torch

@dataclass
class MaskedLMWithProsodyOutput(MaskedLMOutput):
    """
    Output class for masked language modeling with an additional prosody prediction head.
    
    Inherits from `MaskedLMOutput` and adds `prosody_logits`.
    """
    prosody_logits: torch.Tensor = None  # (batch_size, sequence_length, num_prosody_classes)

In [7]:
if __name__ == "__main__":
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") # "prajjwal1/bert-mini"
    text = "This is a great [MASK]."

    encoding = tokenizer(text, return_tensors="pt")
    print(encoding)

    encoding['prosody_ids'] = torch.tensor([[0, 1, 1, 1, 1, 2, 1, 0]])

    config = BertConfigV2.from_pretrained("bert-base-uncased", prosody_cluster_size=3)
    model = BertForMaskedLMV2.from_pretrained(
        "bert-base-uncased",
        config=config,
        ignore_mismatched_sizes=True  # Allows extra layers like `prosody_head`
    )

    outputs = model(**encoding)

    token_logits = outputs.logits
    prosody_logits = outputs.prosody_logits
    
    # Find the location of [MASK] and extract its logits
    mask_token_index = torch.where(encoding["input_ids"] == tokenizer.mask_token_id)[1]
    mask_token_logits = token_logits[0, mask_token_index, :]
    
    # Pick the [MASK] candidates with the highest logits
    top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
    for token in top_5_tokens:
        print(f"'>>> {text.replace(tokenizer.mask_token, tokenizer.decode([token])), str(token)}'")

    mask_prosody_logits = prosody_logits[0, mask_token_index, :]
    
    # Pick the [MASK] candidates with the highest logits
    top_5_prosody_tokens = torch.topk(mask_prosody_logits, 2, dim=1).indices[0].tolist()
    for token in top_5_prosody_tokens:
        print(f"'>>> Top prosody id {str(token)}'")    

{'input_ids': tensor([[ 101, 2023, 2003, 1037, 2307,  103, 1012,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}


Some weights of BertForMaskedLMV2 were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.embeddings.conv.bias', 'bert.embeddings.conv.weight', 'bert.embeddings.position_ids', 'bert.embeddings.prosody_embeddings.weight', 'prosody_head.bias', 'prosody_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


'>>> ('This is a great by.', '2011')'
'>>> ('This is a great to.', '2000')'
'>>> ('This is a great ..', '1012')'
'>>> ('This is a great na.', '6583')'
'>>> ('This is a great or.', '2030')'
'>>> Top prosody id 1'
'>>> Top prosody id 2'




In [None]:
## Define the model

from transformers import BertConfig

# Define custom config with prosody support
config = BertConfigV2(
    vocab_size=tokenizer.vocab_size,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    prosody_cluster_size=3  # Number of prosody clusters
)

# Load the modified model
model = BertForMaskedLMV2(config)

In [119]:
pretrained_model = BertForMaskedLM.from_pretrained("bert-base-uncased")
modified_model = BertForMaskedLMV2.from_pretrained("bert-base-uncased", config=config, ignore_mismatched_sizes=True)

for (name1, param1), (name2, param2) in zip(pretrained_model.named_parameters(), modified_model.named_parameters()):
    print(f"Original: {name1} | Modified: {name2} | Shape: {param1.shape} vs {param2.shape}")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForMaskedLMV2 were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.embeddings.position_ids', 'bert.embeddings.prosody_embeddings.weight', 'mlm_head.predictions.bias', 'mlm_head.predictions.decoder.bias', 'mlm_head.predictions.decoder.weight', '

Original: bert.embeddings.word_embeddings.weight | Modified: bert.embeddings.word_embeddings.weight | Shape: torch.Size([30522, 768]) vs torch.Size([30522, 768])
Original: bert.embeddings.position_embeddings.weight | Modified: bert.embeddings.position_embeddings.weight | Shape: torch.Size([512, 768]) vs torch.Size([512, 768])
Original: bert.embeddings.token_type_embeddings.weight | Modified: bert.embeddings.token_type_embeddings.weight | Shape: torch.Size([2, 768]) vs torch.Size([2, 768])
Original: bert.embeddings.LayerNorm.weight | Modified: bert.embeddings.LayerNorm.weight | Shape: torch.Size([768]) vs torch.Size([768])
Original: bert.embeddings.LayerNorm.bias | Modified: bert.embeddings.LayerNorm.bias | Shape: torch.Size([768]) vs torch.Size([768])
Original: bert.encoder.layer.0.attention.self.query.weight | Modified: bert.embeddings.prosody_embeddings.weight | Shape: torch.Size([768, 768]) vs torch.Size([3, 768])
Original: bert.encoder.layer.0.attention.self.query.bias | Modified: 

In [None]:
# deprecated

class BertForMaskedLMV2(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        
        # Use the modified BertModelV2 with pos_tag_embeddings
        self.bert = BertModelV2(config)   # here
        
        # MLM Head (For predicting masked tokens)
        self.mlm_head = BertOnlyMLMHead(config)
        
        # Prodsody Head (For predicting prosody cluster ids)
        self.prosody_head = nn.Linear(config.hidden_size, config.prosody_cluster_size) # here
        # Initialize weights
        self.init_weights()

    def forward(
        self, input_ids=None, prosody_ids=None, token_type_ids=None, attention_mask=None, labels=None, prosody_labels=None
    ):
        
        # Get embeddings and transformer output from BERT
        outputs = self.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
           prosody_ids=prosody_ids,
        )

        sequence_output = outputs.last_hidden_state  # (batch_size, seq_length, hidden_size)
        
        # Predict masked tokens
        token_logits = self.mlm_head(sequence_output)  # (batch_size, seq_length, vocab_size)
        
        # Predict prosody
        prosody_logits = self.prosody_head(sequence_output)  # (batch_size, seq_length, num_pos_tags)

        loss = None
        if labels is not None and prosody_labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            token_loss = loss_fct(token_logits.view(-1, self.config.vocab_size), labels.view(-1))  # MLM loss
            prosody_loss = loss_fct(prosody_logits.view(-1, config.prosody_cluster_size), prosody_labels.view(-1)) # prosody loss
            loss = token_loss + prosody_loss  # Combine losses
        
        return MaskedLMWithProsodyOutput(
            loss=loss,
            logits=token_logits,
            prosody_logits=prosody_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions
        )

# Without Prosody Test

In [117]:
model_name = "bert-base-uncased"
model = BertForMaskedLM.from_pretrained(model_name)

tokenizer = BertTokenizer.from_pretrained(model_name)

text = "This is a great [MASK]."
inputs = tokenizer(text, return_tensors="pt")
token_logits = model(**inputs).logits

# Find the location of [MASK] and extract its logits
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]

# Pick the [MASK] candidates with the highest logits
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
    print(f"'>>> {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}'")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


'>>> This is a great idea.'
'>>> This is a great day.'
'>>> This is a great place.'
'>>> This is a great time.'
'>>> This is a great thing.'


In [199]:
class BertEmbeddingsV2(BertEmbeddings):
    def __init__(self, config):
        super().__init__(config)
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        
        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        
        # combine the embeddings with new pos tagging embedding
        embeddings = inputs_embeds + token_type_embeddings
        # convol
        # embeddings = conv(embeddings)
        
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

In [200]:
class BertModelV2(BertModel):
    """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
    To behave as an decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration
    set to :obj:`True`. To be used in a Seq2Seq model, the model needs to initialized with both :obj:`is_decoder`
    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
    input to the forward pass.
    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddingsV2(config)  # here
        self.encoder = BertEncoder(config)

        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.init_weights()
    
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            batch_size, seq_length = input_shape
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size, seq_length = input_shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]
  
        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )

In [221]:
class BertForMaskedLMV2(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        
        # Use the modified BertModelV2 with pos_tag_embeddings
        self.bert = BertModelV2(config)   # here
        
        # MLM Head (For predicting masked tokens)
        self.cls = BertOnlyMLMHead(config)
        
        # Initialize weights
        self.init_weights()

    def get_output_embeddings(self):
        """This helps from_pretrained() correctly locate the MLM head weights during weight loading."""
        return self.cls.predictions.decoder
        
    def forward(
        self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None, prosody_labels=None
    ):
        
        # Get embeddings and transformer output from BERT
        outputs = self.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )

        sequence_output = outputs.last_hidden_state  # (batch_size, seq_length, hidden_size)
        
        # Predict masked tokens
        token_logits = self.cls(sequence_output)  # (batch_size, seq_length, vocab_size)
        

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            token_loss = loss_fct(token_logits.view(-1, self.config.vocab_size), labels.view(-1))  # MLM loss
            loss = token_loss
        
        return MaskedLMOutput(
            loss=loss,
            logits=token_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions
        )

In [219]:
# This is from the source code

class BertForMaskedLMV2(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        assert (
            not config.is_decoder
        ), "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for bi-directional self-attention."

        self.bert = BertModelV2(config)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder


    # @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    # @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        **kwargs
    ):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
    
        sequence_output = outputs.last_hidden_state
        prediction_scores = self.cls(sequence_output)
    
        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
    
        return MaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions if output_attentions else None,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
        input_shape = input_ids.shape
        effective_batch_size = input_shape[0]

        #  add a dummy token
        assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
        attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
        dummy_token = torch.full(
            (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
        )
        input_ids = torch.cat([input_ids, dummy_token], dim=1)

        return {"input_ids": input_ids, "attention_mask": attention_mask}

In [223]:
if __name__ == "__main__":
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") # "prajjwal1/bert-mini"
    text = "This is a great [MASK]."

    encoding = tokenizer(text, return_tensors="pt")
    print(encoding)

    model = BertForMaskedLMV2.from_pretrained("bert-base-uncased", ignore_mismatched_sizes=True)
    # pretrained_model = BertForMaskedLM.from_pretrained("bert-base-uncased")
    # model.cls.load_state_dict(pretrained_model.cls.state_dict())

    outputs = model(**encoding)

    token_logits = outputs.logits
    
    # Find the location of [MASK] and extract its logits
    mask_token_index = torch.where(encoding["input_ids"] == tokenizer.mask_token_id)[1]
    mask_token_logits = token_logits[0, mask_token_index, :]
    
    # Pick the [MASK] candidates with the highest logits
    top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
    for token in top_5_tokens:
        print(f"'>>> {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}'")  

{'input_ids': tensor([[ 101, 2023, 2003, 1037, 2307,  103, 1012,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}


Some weights of BertForMaskedLMV2 were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.embeddings.position_ids']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


'>>> This is a great idea.'
'>>> This is a great day.'
'>>> This is a great place.'
'>>> This is a great time.'
'>>> This is a great thing.'


# Train

In [16]:
import torch
from transformers import BertTokenizer, BertConfig

# Load tokenizer
tokenizer = BertTokenizer.from_pretrained("prajjwal1/bert-mini")

# Input text
text = "This is a great [MASK]."
encoding = tokenizer(text, return_tensors="pt")

# Add prosody IDs (Example: 0 = DET, 1 = NOUN, etc.)
encoding["prosody_ids"] = torch.tensor([[0, 1, 1, 1, 1, 2, 1, 0]])

# Ensure labels match input length
labels = tokenizer("This is a great day.", return_tensors="pt", padding="max_length", max_length=encoding["input_ids"].shape[1])["input_ids"]
encoding["labels"] = labels

# Prosody labels
encoding["prosody_labels"] = torch.tensor([[0, 1, 1, 1, 1, 2, 1, 0]])  # Example prosody labels

# Load model configuration
config = BertConfig.from_pretrained("prajjwal1/bert-mini")
config.prosody_cluster_size = 3  # Ensure this is set

# Initialize model
model = BertForMaskedLMV2(config)

# Forward pass
outputs = model(**encoding)

loss = outputs["loss"]
loss.backward()