In [200]:
import torch
from transformers import BertConfig, BertModel, BertTokenizer
from transformers.models.bert.modeling_bert import BertLayer, BertEncoder
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions

import torch
import torch.nn as nn

import seaborn as sns
import matplotlib.pyplot as plt

In [205]:
def topk_masking(scores, keep_ratio):
    """
    Create a hard mask by keeping the top-k tokens based on scores.
    Args:
        scores (torch.Tensor): Scores for each token (batch_size, seq_len).
        keep_ratio (float): Ratio of tokens to keep (between 0 and 1).
    Returns:
        torch.Tensor: Hard mask (batch_size, seq_len) with 1s for kept tokens and 0s for pruned tokens.
    """
    _, seq_len = scores.size()
    k = int(seq_len * keep_ratio)

    # Get the top-k indices
    topk_indices = torch.topk(scores, k, dim=-1).indices

    # Create a mask initialized to zeros
    mask = torch.zeros_like(scores)

    # Scatter 1s into the mask at the top-k indices
    mask.scatter_(1, topk_indices, 1)

    return mask

def magnitude_head_scores(attention_output, num_heads):
    """
    attention_output: (batch, seq_len, hidden_size)
    returns: (batch, num_heads)
    """
    batch_size, seq_len, hidden_size = attention_output.size()
    head_dim = hidden_size // num_heads

    # E \in (batch, heads, L0, D)
    E = attention_output.view(
        batch_size, seq_len, num_heads, head_dim
    ).permute(0, 2, 1, 3)

    # s_h = sum_{l,d} |E|
    head_scores = E.abs().sum(dim=(2, 3))  # (batch, heads)

    # print("Raw head scores (summed magnitudes):", head_scores.cpu().detach()/seq_len)
    return head_scores/seq_len


In [206]:
class CascadingMaskBertLayer(BertLayer):
    def __init__(self, config: BertConfig, prune_token_percent, prune_head_percent, visualize=False):
        super().__init__(config)
        self.prune_token_percent = prune_token_percent
        self.prune_head_percent = prune_head_percent
        self.visualize = visualize

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        token_mask=None,
        head_mask=None,
        output_attentions=True,
    ):
        if token_mask is None:
            token_mask = torch.ones(
                hidden_states.size()[:-1],
                device=hidden_states.device
            )

        batch_size = hidden_states.size(0)
        num_heads = self.attention.self.num_attention_heads

        if head_mask is None:
            head_mask = torch.ones(
                (batch_size, num_heads),
                device=hidden_states.device
            )

        head_mask_expanded = head_mask[:, :, None, None]

        # ---- Apply previous cascade ----
        hidden_states = hidden_states * token_mask.unsqueeze(-1)

        # ---- Extend attention mask ----
        cascade_attn_mask = (1.0 - token_mask) * -1e4
        attention_mask = attention_mask + cascade_attn_mask.unsqueeze(1).unsqueeze(2)

        # ---- Self-attention ----
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask=head_mask_expanded,
            output_attentions=output_attentions,
        )
        attention_output, attention_scores = self_attention_outputs

        if self.visualize:
            for sample in range(attention_scores.size(0)):
                if head_mask_expanded[sample,0,0,0] == 0:  # Check if head 0 is active for this sample
                    print(f"Sample {sample} head 0 is pruned, skipping attention score visualization.")
                    continue
                plt.figure()
                plt.title(f"Attention scores for sample {sample} (head 0):")
                sns.heatmap(attention_scores[sample,0,:,:].cpu().detach(), cmap='viridis')
                plt.show()


            plt.figure()
            plt.title(f"Head mask for each sample")
            # sns.heatmap(attention_scores.sum(dim=(2,3)).cpu().detach(), cmap='viridis')
            sns.heatmap(head_mask, cmap='viridis')
            plt.ylabel("Sample index")
            plt.xlabel("Head index")
            plt.show()


        # ---- Compute new token decisions ----
        token_scores = attention_scores.sum(dim=(1,2))  # (batch_size, seq_len)

        new_token_mask = topk_masking(
            token_scores,
            keep_ratio=1-self.prune_token_percent  # eliminate bottom pt% tokens
        )
        # Protect CLS
        new_token_mask[:, 0] = 1.0
                
        # ---- CASCADE ----
        token_mask = token_mask * new_token_mask



        # ---- Compute new head decisions ----
        heads_scores = magnitude_head_scores(attention_output, num_heads=num_heads)
        
        new_head_mask = topk_masking(
            heads_scores,  # (batch_size, num_heads)
            keep_ratio=1-self.prune_head_percent
        )

        # ---- CASCADE ----
        head_mask = head_mask * new_head_mask



        # ---- Feed-forward ----
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)

        return layer_output, token_mask, head_mask


In [215]:
class CascadingBertEncoder(BertEncoder):
    def __init__(self, config, visualize, visualize_prune_decisions=False):
        super().__init__(config)
        self.layer = nn.ModuleList([
            CascadingMaskBertLayer(config, config.pt[i], config.ph[i], visualize=visualize[i])
            for i in range(config.num_hidden_layers)
        ])
        self.visualize_prune_decisions = visualize_prune_decisions

        # for i in range(len(self.bert.encoder.layer)):
        #     self.bert.encoder.layer[i] = CascadingMaskBertLayer(self.bert.config, pt_schedule[i], ph_schedule[i])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
        **kwargs,
    ):
        batch_size = hidden_states.size(0)
        token_mask = None
        head_mask = None
        
        for i, layer_module in enumerate(self.layer):
            hidden_states, token_mask, head_mask = layer_module(
                hidden_states,
                attention_mask=attention_mask,
                token_mask=token_mask,
                head_mask=head_mask,
                output_attentions=output_attentions,
            )

            if self.visualize_prune_decisions:
                print(f"Layer {i} active tokens:", token_mask.sum(dim=1))
                print(f"Layer {i} active heads:", head_mask.sum(dim=1))
        
        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    None,
                    None,
                    None,
                    None,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
            cross_attentions=None,
        )


In [218]:
cfg = BertConfig.from_pretrained("bert-base-uncased")

cfg.output_attentions = True
cfg.output_hidden_states = True
cfg.return_dict = True

pt_schedule = [0.1 for _ in range(cfg.num_hidden_layers)]
pt_schedule[0] = 0.0  # No pruning in the first layer

ph_schedule = [0.1 for _ in range(cfg.num_hidden_layers)]
ph_schedule[0] = 0.0  # No pruning in the first layer

def calculate_p_schedule(p_schedule):
    for i in range(len(p_schedule)):
        if i == 0:
            p_schedule[i] = 0.0
        else:
            p_schedule[i] = 1-(1-p_schedule[i-1])*(1-p_schedule[i])

    return p_schedule


cfg.pt = calculate_p_schedule(pt_schedule)
print("Cascading token pruning schedule:", cfg.pt)

cfg.ph = calculate_p_schedule(ph_schedule)
print("Cascading head pruning schedule:", cfg.ph)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

sentence = [
    "This is a sample input for the BERT model with cascading token pruning.",
    "It demonstrates how tokens are pruned layer by layer based on attention scores."
]

inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True)

class MyClassifier(nn.Module):
    def __init__(self, config:BertConfig):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased", config=config)
        self.bert.encoder = CascadingBertEncoder(self.bert.config, visualize=[False for _ in range(config.num_hidden_layers)])

        self.dense = nn.Linear(config.hidden_size, 2)  # Binary classification

    def forward(self, **inputs):
        bert_outputs = self.bert(**inputs)

        # print("BERT output:", bert_outputs)
        cls_token = bert_outputs.last_hidden_state[:, 0]  # Use [CLS] token representation
        return self.dense(cls_token)
    
myclassifier = MyClassifier(cfg)
outputs = myclassifier(**inputs)
print("Output shape:", outputs)

Cascading token pruning schedule: [0.0, 0.09999999999999998, 0.18999999999999995, 0.2709999999999999, 0.3438999999999999, 0.4095099999999998, 0.46855899999999984, 0.5217030999999999, 0.5695327899999998, 0.6125795109999999, 0.6513215598999998, 0.6861894039099998]
Cascading head pruning schedule: [0.0, 0.09999999999999998, 0.18999999999999995, 0.2709999999999999, 0.3438999999999999, 0.4095099999999998, 0.46855899999999984, 0.5217030999999999, 0.5695327899999998, 0.6125795109999999, 0.6513215598999998, 0.6861894039099998]
Output shape: tensor([[-0.4968,  0.7481],
        [-0.0503,  0.7157]], grad_fn=<AddmmBackward0>)
