In [1]:
import torch 
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
from textpruner import TransformerPruner

# Load model directly
# Hugging Face Hub identifiers of the model and dataset used by this notebook.
MODEL_ID = "hw2942/bert-base-chinese-finetuning-financial-news-sentiment-v2"
DATASET_ID = "hw2942/financial-news-sentiment"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# output_attentions=True asks the model to also return attention weights with its logits.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, output_attentions=True)

# Load the dataset.
ds = load_dataset(DATASET_ID)


  return torch.load(checkpoint_file, map_location="cpu")


In [None]:
import torch
from tqdm import tqdm
def expand_weights_to_768x768(tensor, pruned_heads, head_dim=64):
    """Re-insert zero rows for pruned attention heads.

    A pruned Q/K/V projection weight (or its gradient) keeps only the
    ``head_dim``-row blocks of the heads that survived. This rebuilds a
    full-size tensor with one block per entry of ``pruned_heads``:
    - surviving heads take their rows from ``tensor``, in order;
    - pruned heads get a block of zeros.
    With 12 heads of 64 rows and 768 columns the result is 768x768.

    Args:
        tensor: 2-D tensor whose rows are the concatenated blocks of the kept heads.
        pruned_heads: per-head mask for one layer. An entry equal to 1 means the
            head is kept; any other value means it was pruned.
        head_dim: number of rows per head (64 by default).

    Returns:
        Tensor of shape ``(len(pruned_heads) * head_dim, tensor.size(1))`` with the
        same dtype and device as ``tensor``.

    Raises:
        ValueError: if ``tensor`` does not contain exactly one ``head_dim``-row
            block per kept head.
    """
    blocks = []
    start = 0
    for keep in pruned_heads:
        if keep == 1:
            # Kept head: take its next head_dim rows from the compacted tensor.
            blocks.append(tensor[start:start + head_dim, :])
            start += head_dim
        else:
            # Pruned head: a zero block. new_zeros matches tensor's dtype and device,
            # so no CPU allocation + copy is needed and dtypes are not mixed in cat.
            blocks.append(tensor.new_zeros(head_dim, tensor.size(1)))
    if start != tensor.size(0):
        # Otherwise rows would be silently dropped or the result mis-sized.
        raise ValueError(
            f"mask keeps heads totalling {start} rows, but tensor has {tensor.size(0)} rows"
        )
    return torch.cat(blocks, dim=0)
def _per_head_grad_norms(weight, head_mask):
    """Return the L2 (Frobenius) norm of each head's block of ``weight.grad``.

    ``weight`` is a (possibly pruned) Q/K/V projection weight whose rows are grouped
    head by head. ``expand_weights_to_768x768`` re-inserts pruned heads as zero
    blocks, so those heads score 0.
    """
    num_heads = len(head_mask)
    if weight.grad is None:
        # No gradient reached this projection, so there is nothing to score.
        return torch.zeros(num_heads, device=weight.device)
    full_grad = expand_weights_to_768x768(weight.grad, head_mask)
    # After expansion each head owns a contiguous run of rows, so flattening per head
    # gives one vector per head.
    return full_grad.view(num_heads, -1).norm(dim=1)


def get_QKV_norm(ds, model, tokenizer, pruned_heads, batch_size=32, device='cuda'):
    """Score every attention head by the gradient norms of its Q, K and V projections.

    For each batch of ``ds['train']['Title']``:
    - the L2 norm of the model's logits is back-propagated;
    - for every layer, the per-head Frobenius norm of the query/key/value weight
      gradients is accumulated.

    Pruned heads are restored as zero blocks, so each layer yields one score per head.

    Args:
        ds: dataset dict; only the ``'Title'`` column of the ``'train'`` split is read.
        model: model exposing ``model.bert.encoder.layer[i].attention.self`` with
            ``query``/``key``/``value`` linear layers.
        tokenizer: tokenizer matching ``model``.
        pruned_heads: (num_layers, num_heads) mask; 1 = head kept, 0 = head pruned.
        batch_size: number of titles per forward/backward pass.
        device: device the model and inputs are moved to.

    Returns:
        CPU tensor of shape (num_layers, num_heads).
        - Each of the Q, K and V scores is its accumulated norm divided by the
          number of titles.
        - The three scores are multiplied together.
        - The product is z-score normalised over all heads.
        - If all products are equal, only the mean is subtracted, avoiding a
          division by zero.

    Raises:
        ValueError: if the train split has no titles.
    """
    model.to(device)
    all_titles = ds['train']['Title']
    num_titles = len(all_titles)
    if num_titles == 0:
        raise ValueError("ds['train'] has no titles; cannot score attention heads")
    num_layers = len(pruned_heads)
    num_heads = len(pruned_heads[0])
    norms_Q = torch.zeros(num_layers, num_heads, device=device)
    norms_K = torch.zeros(num_layers, num_heads, device=device)
    norms_V = torch.zeros(num_layers, num_heads, device=device)

    # One backward pass fills the gradients of every layer, so all layers are
    # scored in a single sweep over the data.
    for start in range(0, num_titles, batch_size):
        batch_titles = all_titles[start:start + batch_size]
        inputs = tokenizer(batch_titles, return_tensors='pt', padding=True, truncation=True)
        inputs = {key: val.to(device) for key, val in inputs.items()}
        model.zero_grad()
        outputs = model(**inputs)
        # Scalar objective whose gradients are measured: L2 norm of the logits.
        loss = torch.norm(outputs.logits)
        loss.backward()
        for layer in range(num_layers):
            attention = model.bert.encoder.layer[layer].attention.self
            mask = pruned_heads[layer]
            norms_Q[layer] += _per_head_grad_norms(attention.query.weight, mask)
            norms_K[layer] += _per_head_grad_norms(attention.key.weight, mask)
            norms_V[layer] += _per_head_grad_norms(attention.value.weight, mask)
    # Clear the gradients left by the last batch.
    model.zero_grad()

    norms = (norms_Q.cpu() / num_titles) * (norms_K.cpu() / num_titles) * (norms_V.cpu() / num_titles)
    mean_norms = norms.mean()
    std_norms = norms.std()
    if std_norms == 0:
        # All scores identical: dividing by the std would produce NaNs.
        return norms - mean_norms
    return (norms - mean_norms) / std_norms
def get_new_head_mask_basedonG(head_mask_previous,Gnorm):
    """Prune one more head by zeroing one entry of ``head_mask_previous``.

    Only heads still marked (non-zero) in the mask are candidates. Each candidate's
    priority is ``Gnorm.max() - Gnorm``, so the surviving head with the smallest
    ``Gnorm`` is chosen, and its layer/head indices are printed.

    Special case: if every priority is zero while at least one head is still kept,
    the first kept head in row-major order is removed instead, and nothing is
    printed.

    The mask is modified in place and the same tensor is returned.
    """
    priority = head_mask_previous * (Gnorm.max() - Gnorm)
    no_candidate_stands_out = bool(torch.eq(priority, 0).all())
    some_head_remains = bool(torch.ne(head_mask_previous, 0).any())
    if no_candidate_stands_out and some_head_remains:
        first_kept = torch.nonzero(head_mask_previous)[0]
        head_mask_previous[first_kept[0], first_kept[1]] = 0
        return head_mask_previous

    # Convert the flat arg-max index into (layer, head) coordinates.
    flat_index = torch.argmax(priority)
    heads_per_layer = priority.size(1)
    layer = flat_index // heads_per_layer
    head = flat_index % heads_per_layer
    print('layer:',layer,'\nhead:',head)
    head_mask_previous[layer, head] = 0
    return head_mask_previous
def get_acc(ds,model,tokenizer,device='cuda'):
    """Return the classification accuracy of ``model`` on the ``'train'`` split of ``ds``.

    Each example's ``'Title'`` is tokenized on its own (a batch of one). The
    predicted class is the arg-max of the first output's logits, and it is compared
    with the example's ``'labels'`` value.

    Args:
        ds: dataset dict whose ``'train'`` split yields dicts with ``'Title'`` and
            ``'labels'``.
        model: classifier whose output's first element holds the logits.
        tokenizer: tokenizer matching ``model``.
        device: device the model and inputs are moved to.

    Returns:
        Fraction of examples predicted correctly, as a float.

    Raises:
        ValueError: if the split is empty (accuracy is undefined).

    NOTE(review): the model's train/eval mode is not changed here; presumably the
    caller keeps it in eval mode so dropout is off — confirm.
    """
    model.to(device)
    total = 0
    correct = 0
    # Inference only: without no_grad every forward pass would record an autograd graph.
    with torch.no_grad():
        for example in ds['train']:
            inputs = tokenizer(example['Title'], return_tensors='pt', truncation=True).to(device)
            logits = model(**inputs)[0][0]
            # softmax is monotonic, so the arg-max of the raw logits is the predicted class.
            if int(torch.argmax(logits)) == example['labels']:
                correct += 1
            total += 1
    if total == 0:
        raise ValueError("ds['train'] is empty; accuracy is undefined")
    return correct / total
def prune_based_on_G(ds,model,tokenizer):
    """Greedily prune attention heads one at a time, recording accuracy after each step.

    Starts from a 12x12 all-ones head mask. Each of the 144 iterations:
    1. scores the heads with ``get_QKV_norm``;
    2. removes one head via ``get_new_head_mask_basedonG``;
    3. applies the updated mask through a ``TransformerPruner``, with an all-ones
       12x3072 FFN mask;
    4. re-measures ``get_acc``.

    Returns:
        List of 145 accuracies; entry k is the accuracy after k pruning steps.
    """
    mask = torch.tensor([[1]*12]*12)
    history = []
    pruner = TransformerPruner(model)
    model.to('cuda')
    baseline = get_acc(ds,model,tokenizer)
    history.append(baseline)
    print('pruned heads:',0,'acc:',baseline)
    for n_pruned in tqdm(range(1, 145)):
        scores = get_QKV_norm(ds,model,tokenizer,pruned_heads=mask)
        mask = get_new_head_mask_basedonG(mask,scores)
        model.to('cuda')
        print('head_mask',mask)
        keep_all_ffn = torch.tensor([[1]*3072]*12)
        pruner.prune(head_mask=mask,ffn_mask=keep_all_ffn,save_model=False)
        accuracy = get_acc(ds,model,tokenizer)
        history.append(accuracy)
        print('pruned heads:',n_pruned,'acc:',accuracy)
    return history