In [None]:
import torch
from transformers import BertConfig, BertModel, BertTokenizer
from transformers.models.bert.modeling_bert import BertLayer, BertSelfAttention
import torch
import torch.nn as nn

import seaborn as sns
import matplotlib.pyplot as plt

In [186]:
class SpAttenState:
    def __init__(self, seq_len, num_heads, device):
        self.st = torch.zeros(seq_len, device=device)
        self.sh = torch.zeros(num_heads, device=device)
        self.token_ids = torch.arange(seq_len, device=device)
        self.head_ids  = torch.arange(num_heads, device=device)

def topk_masking(scores, keep_ratio):
    """
    Create a hard mask by keeping the top-k tokens based on scores.
    Args:
        scores (torch.Tensor): Scores for each token (batch_size, seq_len).
        keep_ratio (float): Ratio of tokens to keep (between 0 and 1).
    Returns:
        torch.Tensor: Hard mask (batch_size, seq_len) with 1s for kept tokens and 0s for pruned tokens.
    """
    batch_size, seq_len = scores.size()
    k = int(seq_len * keep_ratio)

    # Get the top-k indices
    topk_values, topk_indices = torch.topk(scores, k, dim=-1)

    # Create a mask initialized to zeros
    mask = torch.zeros_like(scores)

    # Scatter 1s into the mask at the top-k indices
    mask.scatter_(1, topk_indices, 1)

    return mask

In [294]:
class HeadPrunedBertSelfAttention(BertSelfAttention):
    def __init__(self, config):
        super().__init__(config)
    def forward(self, 
            hidden_states, 
            attention_mask=None, 
            head_mask=None, 
            output_attentions=False,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            past_key_values=None,
            cache_position=None,
    ):
        # Normal attention
        print("Head mask shape in self-attention:", head_mask.shape)
        outputs = super().forward(hidden_states, attention_mask, head_mask, output_attentions=output_attentions, encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
            cache_position=cache_position)

        # # outputs[0] = context_layer (batch, seq_len, hidden)
        # context = outputs[0]

        # # Prune heads
        # if head_mask is not None:
        #     # reshape to separate heads
        #     batch_size, seq_len, hidden_size = context.size()
        #     head_dim = hidden_size // self.num_attention_heads
        #     context = context.view(batch_size, seq_len, self.num_attention_heads, head_dim)
        #     for h in self.heads_to_prune:
        #         context[:, :, h, :] = 0.0
        #     context = context.view(batch_size, seq_len, hidden_size)
        #     outputs = (context,) + outputs[1:]

        return outputs


In [295]:
class CascadingMaskBertLayer(BertLayer):
    def __init__(self, config: BertConfig, prune_token_percent):
        super().__init__(config)
        self.prune_token_percent = prune_token_percent

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        token_mask=None,
        head_mask=None,
        output_attentions=True,
    ):
        if token_mask is None:
            token_mask = torch.ones(
                hidden_states.size()[:-1],
                device=hidden_states.device
            )

        # print("Token mask before layer:", token_mask)
        # print("Attention mask before layer:", attention_mask)

        # ---- Apply previous cascade ----
        hidden_states = hidden_states * token_mask.unsqueeze(-1)

        # ---- Extend attention mask ----
        cascade_attn_mask = (1.0 - token_mask) * -1e4
        attention_mask = attention_mask + cascade_attn_mask.unsqueeze(1).unsqueeze(2)

        # ---- Self-attention ----
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        attention_scores = self_attention_outputs[1]
        
        sns.heatmap(attention_scores[0,0,:,:].cpu().detach(), cmap='viridis')
        plt.show()

        # ---- Compute new token decisions ----
        token_scores = attention_scores.sum(dim=(1,2))  # (batch_size, seq_len)

        new_mask = topk_masking(
            token_scores,
            keep_ratio=1-self.prune_token_percent  # eliminate bottom pt% tokens
        )

        # Protect CLS
        new_mask[:, 0] = 1.0
                
        # ---- CASCADE (irreversible) ----
        token_mask = token_mask * new_mask

        # ---- Feed-forward ----
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)

        return layer_output, token_mask


In [296]:
from transformers.models.bert.modeling_bert import BertEncoder
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions

def update_head_mask(head_mask, hidden_states, token_mask):
    batch_size, seq_len, hidden_size = hidden_states.size()
    num_heads = head_mask.size(0)
    head_dim = hidden_size // num_heads

    # Reshape hidden states to separate heads
    hidden_states_reshaped = hidden_states.view(batch_size, num_heads, seq_len, head_dim)

    # Compute cumulative attention per head
    head_scores = hidden_states_reshaped.abs().sum(dim=(2, 3))  # (batch_size, num_heads)


    # Create new head mask based on scores (e.g., keep top 80% heads)
    keep_ratio = 0.8
    new_head_mask = torch.ones_like(head_mask)
    for i in range(batch_size):
        topk_indices = torch.topk(head_scores[i], int(num_heads * keep_ratio)).indices
        new_head_mask[i, topk_indices] = 1.0
        new_head_mask[i, ~torch.isin(torch.arange(num_heads), topk_indices)] = 0.0

    return new_head_mask

class CascadingBertEncoder(BertEncoder):
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
        **kwargs,
    ):
        batch_size = hidden_states.size(0)
        token_mask = None
        head_mask = torch.ones((batch_size, self.layer[0].attention.self.num_attention_heads), device=hidden_states.device)
        print(head_mask[0])
        # print("Initial head mask:", head_mask)
        for i, layer_module in enumerate(self.layer):
            hidden_states, token_mask = layer_module(
                hidden_states,
                attention_mask=attention_mask,
                token_mask=token_mask,
                head_mask=head_mask,
                output_attentions=output_attentions,
            )
            print(f"Layer {i} active tokens:", token_mask.sum(dim=1))
            print(f"Layer {i} active heads:", head_mask.sum(dim=1))

            # head_mask = upate_head_mask(head_mask, hidden_states, token_mask)

        print("Last hidden states shape:", hidden_states.shape)
        
        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    None,
                    None,
                    None,
                    None,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
            cross_attentions=None,
        )


In [297]:
from transformers import BertModel

cfg = BertConfig.from_pretrained("bert-base-uncased")

cfg.output_attentions = True
cfg.output_hidden_states = True
cfg.return_dict = True

pt_schedule = [0.1 for _ in range(cfg.num_hidden_layers)]
pt_schedule[0] = 0.0  # No pruning in the first layer


def calculate_pt_schedule(pt_schedule):
    for i, pt in enumerate(pt_schedule):
        if i == 0:
            pt_schedule[i] = 0.0
        else:
            pt_schedule[i] = 1-(1-pt_schedule[i-1])*(1-pt_schedule[i])

    return pt_schedule


cfg.pt = calculate_pt_schedule(pt_schedule)
print("Cascading token pruning schedule:", cfg.pt)


tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

sentence = "This is a sample input for the BERT model with cascading token pruning."

inputs = tokenizer(sentence, return_tensors="pt")


class MyClassifier(nn.Module):
    def __init__(self, config:BertConfig):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased", config=config)
        self.bert.encoder = CascadingBertEncoder(self.bert.config)
        for i in range(len(self.bert.encoder.layer)):
            self.bert.encoder.layer[i] = CascadingMaskBertLayer(self.bert.config, pt_schedule[i])
            self.bert.encoder.layer[i].attention.self = HeadPrunedBertSelfAttention(self.bert.config)

        self.dense = nn.Linear(config.hidden_size, 2)  # Binary classification

    def forward(self, **inputs):
        bert_outputs = self.bert(**inputs)

        # print("BERT output:", bert_outputs)
        cls_token = bert_outputs.last_hidden_state[:, 0]  # Use [CLS] token representation
        return self.dense(cls_token)
    
myclassifier = MyClassifier(cfg)
outputs = myclassifier(**inputs)
print("Output shape:", outputs)

Cascading token pruning schedule: [0.0, 0.09999999999999998, 0.18999999999999995, 0.2709999999999999, 0.3438999999999999, 0.4095099999999998, 0.46855899999999984, 0.5217030999999999, 0.5695327899999998, 0.6125795109999999, 0.6513215598999998, 0.6861894039099998]
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
Head mask shape in self-attention: torch.Size([1, 12])


RuntimeError: The size of tensor a (20) must match the size of tensor b (12) at non-singleton dimension 3