In [16]:
import torch
import torch.nn as nn
import math
from torch.utils.data import DataLoader
from transformers import BertForSequenceClassification, BertTokenizer
from datasets import load_from_disk

In [17]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME = "bert-base-uncased"
TOKENIZER = BertTokenizer.from_pretrained('./data/bert/bert_tokenizer')
INITIAL_MODEL = BertForSequenceClassification.from_pretrained('./fp32model', num_labels=2).to(DEVICE)
INITIAL_MODEL.eval()
DATASET = load_from_disk("./data/bert/imdb_dataset")

In [18]:
def encode_example(examples):
  return TOKENIZER(examples["text"], padding="max_length", truncation=True, max_length=512, return_tensors="pt")

DATASET = DATASET.shuffle()
TRAIN_DATASET = DATASET["train"].select(range(3000))
TEST_DATASET  = DATASET["test"].select(range(1000))
TRAIN_DATASET = TRAIN_DATASET.map(encode_example, batched=True)
TEST_DATASET  = TEST_DATASET.map(encode_example, batched=True)
TRAIN_DATALOADER = DataLoader(TRAIN_DATASET, batch_size=64)
TEST_DATALOADER  = DataLoader(TEST_DATASET, batch_size=64)

Map:   0%|          | 0/3000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [19]:
numLinears = 0
# 遍历模型的所有模块，为attention线性层添加属性
for name, module in INITIAL_MODEL.named_modules():
  if isinstance(module, nn.Linear):
    if len(name.split(".")) > 5 and name.split(".")[4] == "attention":
      # 行数
      module.R = module.weight.shape[0]
      # 列数
      module.C = module.weight.shape[1]
      # 海森矩阵
      module.H = torch.zeros((module.C, module.C), device=module.weight.device)
      # 样本数量
      module.N = 0
      # 模块名
      module.name = name
      numLinears += 1

print(f'TOTAL {numLinears} ATTENTION LINEAR LAYERS')

TOTAL 48 ATTENTION LINEAR LAYERS


# 钩子函数定义

In [20]:
def hookComputeH(module, input, output):
    with torch.no_grad():
        input = input[0]
        batchSize = input.size(0)
        if len(input.shape) == 3:
            input = input.reshape((-1, input.size(2)))
        
        module.H *= module.N / batchSize
        module.N += batchSize
        input = math.sqrt(2 / module.N) * input
        module.H += input.t() @ input

        print(f"{module.name}\t\tH")
    return output

In [21]:
def pruneModule(
        module, 
        sparsity, 
        prunen=0, 
        prunem=0,  
        percdamp=.01,
        blocksize = 1
    ):
    H = module.H
    W = module.weight.data.clone()
    dead = (torch.diag(H) == 0)
    H[dead, dead] = 1
    W[:, dead] = 0
    
    damp = percdamp * torch.mean(torch.diag(H))
    diag = torch.arange(module.C, device=DEVICE)
    H[diag, diag] += damp
    # 计算H的逆矩阵的上三角矩阵
    HINV=   torch.linalg.cholesky(
                torch.cholesky_inverse(
                    torch.linalg.cholesky(H)
                ), 
                upper=True
            )
    
    for i1 in range(0, module.C, blocksize):
        i2 = min(i1 + blocksize, module.C)
        count = i2 - i1
        W1 = W[:, i1:i2].clone()
        Q1 = torch.zeros_like(W1)
        Err1 = torch.zeros_like(W1)
        Hinv1 = HINV[i1:i2, i1:i2]
        tmp = W1  / (torch.diag(Hinv1).reshape((1, -1))) ** 2
        thresh = torch.sort(tmp.flatten(), descending=True)[0][int(tmp.numel() * sparsity)]
        mask1 = tmp <= thresh
        for i in range(count):
            w = W1[:, i]
            d = Hinv1[i, i]
            if prunen != 0 and i % prunem == 0:
                tmp = W1[:, i:(i + prunem)] / (torch.diag(Hinv1)[i:(i + prunem)].reshape((1, -1))) ** 2
                mask1.scatter_(1, i + torch.topk(tmp, prunen, dim=1, largest=False)[1], True)
            q = w.clone()
            q[mask1[:, i]] = 0
            Q1[:, i] = q
            err1 = (w - q) / d
            W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
            Err1[:, i] = err1
        W[:, i1:i2] = Q1
        W[:, i2:] -= Err1.matmul(HINV[i1:i2, i2:])
    # 计算为零的元素的数量
    num_zero_elements = (W == 0).sum().item()
    # 计算总的元素数量
    total_elements = W.numel()
    # 计算稀疏率
    s = num_zero_elements / total_elements
    print(f"{module.name}\t\t{s}")
    module.weight.data = W

In [22]:
def hookTest(module, input, output):
    print("\n************************")
    print(f"Layer {module.name} has a test hook")
    print(f"input  shape {input[0].shape}")
    print(f"output shape {output[0].shape}")
    print("************************")
    return output

In [23]:
def removeAllHooks(model):
    for name, module in model.named_modules():
        module._forward_hooks.clear()

# 执行剪枝

In [24]:
def computeH(model, dataloader):
    model.eval()
    removeAllHooks(model)
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            if len(name.split(".")) > 5 and name.split(".")[4] == "attention":
                module.register_forward_hook(hookComputeH)
    with torch.no_grad():
        for batch in dataloader:
            input_ids = torch.stack(batch['input_ids']).to(DEVICE)
            attention_mask = torch.stack(batch['attention_mask']).to(DEVICE)
            model(input_ids = input_ids, attention_mask = attention_mask)
    removeAllHooks(model)

In [25]:
computeH(INITIAL_MODEL, TRAIN_DATALOADER)

bert.encoder.layer.0.attention.self.query		H
bert.encoder.layer.0.attention.self.key		H
bert.encoder.layer.0.attention.self.value		H
bert.encoder.layer.0.attention.output.dense		H
bert.encoder.layer.1.attention.self.query		H
bert.encoder.layer.1.attention.self.key		H
bert.encoder.layer.1.attention.self.value		H
bert.encoder.layer.1.attention.output.dense		H
bert.encoder.layer.2.attention.self.query		H
bert.encoder.layer.2.attention.self.key		H
bert.encoder.layer.2.attention.self.value		H
bert.encoder.layer.2.attention.output.dense		H
bert.encoder.layer.3.attention.self.query		H
bert.encoder.layer.3.attention.self.key		H
bert.encoder.layer.3.attention.self.value		H
bert.encoder.layer.3.attention.output.dense		H
bert.encoder.layer.4.attention.self.query		H
bert.encoder.layer.4.attention.self.key		H
bert.encoder.layer.4.attention.self.value		H
bert.encoder.layer.4.attention.output.dense		H
bert.encoder.layer.5.attention.self.query		H
bert.encoder.layer.5.attention.self.key		H
bert.encoder

In [26]:
model = INITIAL_MODEL
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        if len(name.split(".")) > 5 and name.split(".")[4] == "attention":
            pruneModule(module, sparsity=0.5)

bert.encoder.layer.0.attention.self.query		0.5013020833333334
bert.encoder.layer.0.attention.self.key		0.5013020833333334
bert.encoder.layer.0.attention.self.value		0.5013020833333334
bert.encoder.layer.0.attention.output.dense		0.5013020833333334
bert.encoder.layer.1.attention.self.query		0.5013020833333334
bert.encoder.layer.1.attention.self.key		0.5013020833333334
bert.encoder.layer.1.attention.self.value		0.5013020833333334
bert.encoder.layer.1.attention.output.dense		0.5013020833333334
bert.encoder.layer.2.attention.self.query		0.5013020833333334
bert.encoder.layer.2.attention.self.key		0.5013020833333334
bert.encoder.layer.2.attention.self.value		0.5013020833333334
bert.encoder.layer.2.attention.output.dense		0.5013020833333334
bert.encoder.layer.3.attention.self.query		0.5013020833333334
bert.encoder.layer.3.attention.self.key		0.5013020833333334
bert.encoder.layer.3.attention.self.value		0.5013020833333334
bert.encoder.layer.3.attention.output.dense		0.5013020833333334
bert.enc

In [27]:
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        if len(name.split(".")) > 5 and name.split(".")[4] == "attention":
            # 获取线性层的权重矩阵
            weights = module.weight.data
            # 计算为零的元素的数量
            num_zero_elements = (weights == 0).sum().item()
            # 计算总的元素数量
            total_elements = weights.numel()
            # 计算稀疏率
            sparsity = num_zero_elements / total_elements
            print(f"{name}\t\t{sparsity}")


bert.encoder.layer.0.attention.self.query		0.5013020833333334
bert.encoder.layer.0.attention.self.key		0.5013020833333334
bert.encoder.layer.0.attention.self.value		0.5013020833333334
bert.encoder.layer.0.attention.output.dense		0.5013020833333334
bert.encoder.layer.1.attention.self.query		0.5013020833333334
bert.encoder.layer.1.attention.self.key		0.5013020833333334
bert.encoder.layer.1.attention.self.value		0.5013020833333334
bert.encoder.layer.1.attention.output.dense		0.5013020833333334
bert.encoder.layer.2.attention.self.query		0.5013020833333334
bert.encoder.layer.2.attention.self.key		0.5013020833333334
bert.encoder.layer.2.attention.self.value		0.5013020833333334
bert.encoder.layer.2.attention.output.dense		0.5013020833333334
bert.encoder.layer.3.attention.self.query		0.5013020833333334
bert.encoder.layer.3.attention.self.key		0.5013020833333334
bert.encoder.layer.3.attention.self.value		0.5013020833333334
bert.encoder.layer.3.attention.output.dense		0.5013020833333334
bert.enc

In [28]:
def testAccuracy(model, data, device):
    model = model.to(device)
    model.eval()
    testlen = len(data)
    num_wrong = 0
    for i in range(testlen):
        ex = data[i]
        text = ex['text']
        lable = ex['label']
        input = TOKENIZER(text, return_tensors="pt",padding=True, truncation=True).to(device)
        output = model(**input)
        if int(torch.argmax(output.logits)) != lable:
            num_wrong += 1
    return 100-num_wrong/testlen*100

In [29]:
testAccuracy(INITIAL_MODEL, TEST_DATASET,DEVICE)

91.6

In [30]:
model.save_pretrained('./5bert')