## Wanda实践

本节将带大家一起揭开Wanda算法的神秘面纱～

In [1]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import random

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_llm(model_name, cache_dir="llm_weights"):
    """Load a causal language model in float16 and attach a ``seqlen`` attribute.

    Args:
        model_name: Model identifier forwarded to
            ``AutoModelForCausalLM.from_pretrained``.
        cache_dir: Directory forwarded as ``cache_dir`` to ``from_pretrained``.

    Returns:
        The loaded model, with ``model.seqlen`` set to
        ``model.config.max_position_embeddings``.
    """
    load_options = {
        "torch_dtype": torch.float16,
        "cache_dir": cache_dir,
        "low_cpu_mem_usage": True,
        "device_map": "auto",
    }
    llm = AutoModelForCausalLM.from_pretrained(model_name, **load_options)

    # Other helpers in this notebook read `seqlen` to size their token windows.
    llm.seqlen = llm.config.max_position_embeddings
    return llm

In [3]:
# Supported model families: LLaMA and OPT.
model_name = "Enoch/llama-7b-hf"
# model_name = "facebook/opt-125m"
cache_dir = "../llm_weights"
print(f"loading llm model {model_name}")
model = get_llm(model_name, cache_dir)
# Switch to evaluation mode; as the last expression of the cell this also
# displays the module tree.
model.eval()

loading llm model Enoch/llama-7b-hf


Loading checkpoint shards: 100%|██████████| 33/33 [00:07<00:00,  4.57it/s]


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
 

In [4]:
def find_layers(module, layers=(nn.Linear,), name=''):
    """
    Recursively collect the sub-modules whose exact type is listed in ``layers``.

    Args:
        module (nn.Module): Root module to search.
        layers (list | tuple): Layer classes to look for. Matching uses the
            exact type (``type(m) in layers``), so subclasses are not matched.
            The default is an immutable tuple so it cannot be altered between
            calls.
        name (str): Dotted path prefix used for the returned keys ('' for the
            root module).

    Returns:
        dict: Maps the dotted path of every matching module (relative to the
        root) to the module itself. If ``module`` itself matches, the result
        is ``{name: module}`` and its children are not searched.
    """
    if type(module) in layers:
        return {name: module}

    found = {}
    for child_name, child in module.named_children():
        # Build the dotted path explicitly instead of relying on the
        # precedence of a conditional expression inside the call arguments.
        child_path = f"{name}.{child_name}" if name != '' else child_name
        found.update(find_layers(child, layers=layers, name=child_path))
    return found


In [5]:
# Display the inner backbone module (token embedding, decoder layers, final norm).
model.model

LlamaModel(
  (embed_tokens): Embedding(32000, 4096, padding_idx=0)
  (layers): ModuleList(
    (0-31): 32 x LlamaDecoderLayer(
      (self_attn): LlamaSdpaAttention(
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (rotary_emb): LlamaRotaryEmbedding()
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): LlamaRMSNorm()
      (post_attention_layernorm): LlamaRMSNorm()
    )
  )
  (norm): LlamaRMSNorm()
)

In [6]:
def check_sparsity(model):
    """Report the fraction of exactly-zero weights in the model's decoder layers.

    For every decoder layer, the weights of the sub-modules returned by
    ``find_layers`` are scanned, the per-layer ratio of zero entries is
    printed, and the overall ratio is returned.

    Args:
        model: Causal LM whose decoder blocks live at ``model.model.layers``
            or at ``model.model.decoder.layers``.

    Returns:
        float: Zero entries divided by total entries over all scanned weights
        (0.0 when nothing was scanned).
    """
    use_cache = model.config.use_cache
    model.config.use_cache = False

    try:
        # Locate the decoder blocks from the model object itself rather than
        # from a notebook-global model name, so the function works for any
        # model passed in.
        backbone = model.model
        layers = backbone.decoder.layers if hasattr(backbone, "decoder") else backbone.layers

        count = 0
        total_params = 0
        for i, layer in enumerate(layers):
            subset = find_layers(layer)

            sub_count = 0
            sub_params = 0
            for name in subset:
                W = subset[name].weight.data
                # Count zeros once and reuse the value for both totals.
                sub_count += (W == 0).sum().item()
                sub_params += W.numel()

            count += sub_count
            total_params += sub_params

            # Guard against a layer that contains no matching sub-modules.
            layer_ratio = float(sub_count) / sub_params if sub_params else 0.0
            print(f"layer {i} sparsity {layer_ratio:.6f}")
    finally:
        # Restore the caller's cache setting even if scanning fails.
        model.config.use_cache = use_cache

    return float(count) / total_params if total_params else 0.0

In [7]:
# Baseline before pruning: per-layer and overall fraction of zero weights.
check_sparsity(model)

layer 0 sparsity 0.000001
layer 1 sparsity 0.000001
layer 2 sparsity 0.000001
layer 3 sparsity 0.000001
layer 4 sparsity 0.000001
layer 5 sparsity 0.000001
layer 6 sparsity 0.000001
layer 7 sparsity 0.000001
layer 8 sparsity 0.000001
layer 9 sparsity 0.000001
layer 10 sparsity 0.000001
layer 11 sparsity 0.000001
layer 12 sparsity 0.000001
layer 13 sparsity 0.000001
layer 14 sparsity 0.000001
layer 15 sparsity 0.000001
layer 16 sparsity 0.000001
layer 17 sparsity 0.000001
layer 18 sparsity 0.000001
layer 19 sparsity 0.000001
layer 20 sparsity 0.000001
layer 21 sparsity 0.000001
layer 22 sparsity 0.000001
layer 23 sparsity 0.000001
layer 24 sparsity 0.000001
layer 25 sparsity 0.000001
layer 26 sparsity 0.000001
layer 27 sparsity 0.000001
layer 28 sparsity 0.000001
layer 29 sparsity 0.000001
layer 30 sparsity 0.000001
layer 31 sparsity 0.000001


1.0977446106431398e-06

In [8]:
def get_wikitext2(nsamples, seed, seqlen, tokenizer):
    """Build WikiText-2 calibration samples and the tokenized test split.

    Args:
        nsamples: Number of training windows to draw.
        seed: Value passed to ``random.seed`` before the windows are drawn.
        seqlen: Number of tokens in each window.
        tokenizer: Callable invoked with the joined text and
            ``return_tensors='pt'``; its result must expose ``input_ids``.

    Returns:
        tuple: ``(trainloader, testenc)``. ``trainloader`` is a list of
        ``(inp, tar)`` pairs where ``tar`` is a copy of ``inp`` with every
        position except the last set to -100. ``testenc`` is the tokenizer
        output for the whole test split.
    """
    dataset_id = ('wikitext', 'wikitext-2-raw-v1')
    traindata = load_dataset(*dataset_id, split='train')
    testdata = load_dataset(*dataset_id, split='test')

    # The train split is joined with single spaces, the test split with blank lines.
    trainenc = tokenizer(" ".join(traindata['text']), return_tensors='pt')
    testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')

    train_ids = trainenc.input_ids
    last_start = train_ids.shape[1] - seqlen - 1

    def draw_window():
        # Pick a random start and cut a window of `seqlen` tokens from there.
        start = random.randint(0, last_start)
        inp = train_ids[:, start:start + seqlen]
        tar = inp.clone()
        tar[:, :-1] = -100
        return (inp, tar)

    random.seed(seed)
    trainloader = [draw_window() for _ in range(nsamples)]
    return trainloader, testenc

### 模型的评价指标-困惑度PPL
模型的困惑度是衡量语言模型性能的指标之一，通常用于评估模型在给定数据集上的预测能力。在自然语言处理中，困惑度是指模型对给定序列中下一个词的预测的困惑程度或不确定性程度。困惑度越低，表示模型在预测下一个词时越准确，即模型对数据集的预测更加自信。通常情况下，困惑度是一个正数，值越低表示模型性能越好。因此，困惑度可以作为评估语言模型质量和性能的重要指标之一。对于语句 $s=w_1, w_2, w_3, \ldots, w_n$，其困惑度PPL可表示为：

$$
\begin{aligned}
\text{PPL} &= p(s)^{-\frac{1}{n}} \\
&= p\left(w_1, w_2, \ldots, w_n\right)^{-\frac{1}{n}} \\
&= \sqrt[n]{\frac{1}{p\left(w_1, w_2, \ldots, w_n\right)}} \\
&= \sqrt[n]{\prod_{i=1}^n \frac{1}{p\left(w_i \mid w_1, w_2, \ldots, w_{i-1}\right)}}
\end{aligned}
$$

In [9]:
def eval_ppl_wikitext(model, testenc, bs=1, device=None):
    """Compute perplexity over consecutive ``model.seqlen``-token windows.

    Args:
        model: Callable returning an object with ``logits`` for a batch of
            token ids; must expose ``seqlen``.
        testenc: Object whose ``input_ids`` holds the tokenized corpus.
        bs: Number of windows evaluated per forward pass.
        device: Device the input windows are moved to.

    Returns:
        float: ``exp`` of the summed negative log-likelihoods divided by
        ``nsamples * model.seqlen``.
    """
    token_ids = testenc.input_ids
    window = model.seqlen

    # Only whole windows are evaluated; trailing tokens are dropped.
    nsamples = token_ids.numel() // window
    print(f"nsamples {nsamples}")

    criterion = nn.CrossEntropyLoss()
    nlls = []

    for start in range(0, nsamples, bs):
        if start % 50 == 0:
            print(f"sample {start}")
        stop = min(start + bs, nsamples)
        batch_size = stop - start

        # Slice the flat token stream and fold it into (batch, window).
        inputs = token_ids[:, (start * window):(stop * window)].to(device)
        inputs = inputs.reshape(batch_size, window)

        lm_logits = model(inputs).logits

        # Position t predicts token t + 1.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = inputs[:, 1:]
        vocab_size = shift_logits.size(-1)

        loss = criterion(shift_logits.reshape(-1, vocab_size), shift_labels.reshape(-1))
        print(f"loss {loss}")

        # The mean loss is scaled back up by window length and batch size.
        nlls.append(loss.float() * window * batch_size)

    total_nll = torch.stack(nlls).sum()
    ppl = torch.exp(total_nll / (nsamples * window))

    torch.cuda.empty_cache()

    return ppl.item()

In [10]:
def eval_ppl(model, tokenizer, device=torch.device("cuda:0")):
    """Evaluate the model's perplexity on the WikiText-2 test split.

    Args:
        model: Model to evaluate; its ``seqlen`` attribute sets the window size.
        tokenizer: Tokenizer forwarded to ``get_wikitext2``.
        device: Device forwarded to ``eval_ppl_wikitext``.

    Returns:
        The perplexity returned by ``eval_ppl_wikitext``.
    """
    print(f"evaluating on wikitext2")

    # Only the tokenized test split is used here; the sampled training
    # windows returned alongside it are discarded.
    testloader = get_wikitext2(nsamples=128, seed=0, seqlen=model.seqlen, tokenizer=tokenizer)[1]

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        return eval_ppl_wikitext(model, testloader, 1, device)

In [11]:
# Load the tokenizer that belongs to `model_name` (requested with use_fast=False).
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [12]:
def prepare_calibration_input(model, dataloader, device):
    """Capture the hidden states entering the first decoder layer for each calibration batch.

    The first decoder layer is temporarily replaced by a wrapper that stores
    its input (plus the attention mask and, for non-OPT models, the position
    ids it receives) and then aborts the forward pass.

    Args:
        model: Causal LM exposing ``seqlen``, ``config.hidden_size``,
            ``config.use_cache`` and its decoder blocks at
            ``model.model.layers`` or ``model.model.decoder.layers``.
        dataloader: Iterable of batches whose first element holds the input
            ids. If it supports ``len()``, that many capture slots are
            allocated; otherwise 128 slots are allocated.
        device: Device for the capture buffers and the input ids.

    Returns:
        ``(inps, outs, attention_mask, position_ids)`` for models with
        ``model.model.layers``, or ``(inps, outs, attention_mask)`` for models
        with ``model.model.decoder`` (OPT-style). ``outs`` is a zero tensor
        shaped like ``inps``.
    """
    # Detect the layout from the model itself instead of a notebook-global name.
    backbone = model.model
    is_opt = hasattr(backbone, "decoder")
    layers = backbone.decoder.layers if is_opt else backbone.layers

    # Size the buffer from the data so more than 128 samples do not overflow it.
    try:
        nsamples = len(dataloader)
    except TypeError:
        nsamples = 128

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=device)
    cache = {'i': 0, 'attention_mask': None, "position_ids": None}

    class _StopForward(Exception):
        """Private signal used to abort the forward pass after capturing the input."""

    class Catcher(nn.Module):
        """Stand-in for the first decoder layer that records what it is called with."""

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs.get('attention_mask')
            if not is_opt:
                cache['position_ids'] = kwargs.get('position_ids')
            # Nothing after the first layer is needed, so stop here.
            raise _StopForward

    use_cache = model.config.use_cache
    model.config.use_cache = False
    original_first_layer = layers[0]
    layers[0] = Catcher(original_first_layer)
    try:
        for batch in dataloader:
            try:
                model(batch[0].to(device))
            except _StopForward:
                # Expected: raised by Catcher once the input has been stored.
                # Any other exception is a real error and is allowed to propagate.
                pass
    finally:
        # Always put the real layer and the cache setting back, even on failure.
        layers[0] = original_first_layer
        model.config.use_cache = use_cache

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']

    if is_opt:
        return inps, outs, attention_mask
    return inps, outs, attention_mask, cache['position_ids']

In [13]:
class WrappedGPT:
    """
    Wraps a single layer and accumulates a running per-input-column statistic
    (``scaler_row``) from the inputs passed to ``add_batch``.
    """

    def __init__(self, layer, layer_id=0, layer_name="none"):
        """Record the layer, its weight shape/device, and reset the statistics.

        Args:
            layer: Module with a ``weight`` tensor of at least two dimensions.
            layer_id: Identifier stored on the wrapper.
            layer_name: Name stored on the wrapper.
        """
        self.layer = layer
        self.dev = self.layer.weight.device

        weight_shape = layer.weight.data.shape
        self.rows = weight_shape[0]      # first weight dimension
        self.columns = weight_shape[1]   # second weight dimension

        # One accumulator entry per weight column, kept on the layer's device.
        self.scaler_row = torch.zeros(self.columns, device=self.dev)
        self.nsamples = 0

        self.layer_id = layer_id
        self.layer_name = layer_name

    def add_batch(self, inp, out):
        """Fold one batch of layer inputs into ``scaler_row``.

        Args:
            inp: Input tensor seen by the layer; a 2-D tensor is treated as a
                batch of size one.
            out: Layer output; not used by this method.
        """
        if inp.dim() == 2:
            inp = inp.unsqueeze(0)
        batch = inp.shape[0]

        if isinstance(self.layer, nn.Linear):
            if inp.dim() == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            # After the transpose each row holds one input feature, so the
            # norm over dim=1 below reduces across all rows of the batch.
            inp = inp.t()

        # Running average: shrink the old value by the share of old samples,
        # then add the new batch's contribution divided by the new total.
        seen = self.nsamples
        self.nsamples = seen + batch
        self.scaler_row *= seen / self.nsamples

        inp = inp.type(torch.float32)
        self.scaler_row += torch.norm(inp, p=2, dim=1) ** 2 / self.nsamples

In [None]:
def prune_wanda_zscores(args, model, tokenizer, device=torch.device("cuda:0"), prune_n=0, prune_m=0):
    """Prune the model with the Wanda metric, using a separate sparsity for each decoder layer.

    The function makes two full passes over the decoder layers:

    1. Scoring pass: for every layer, activation statistics are collected
       through ``WrappedGPT`` forward hooks and ``get_z_scores_sum`` is applied
       to the flattened ``sqrt(scaler_row)`` values of all its linear
       sub-layers. The per-layer scores are min-max scaled, multiplied by
       ``args.Lamda * 2`` and shifted so that their mean equals
       ``1 - args.sparsity_ratio``.
    2. Pruning pass: the calibration inputs are rebuilt and every linear
       sub-layer of layer ``i`` is pruned in place with sparsity
       ``1 - all_layer_ratio[i]``.

    Args:
        args: Object providing ``nsamples``, ``seed``, ``model``, ``Lamda``,
            ``sparsity_ratio`` and ``use_variant``.
        model: Model whose decoder layers are pruned; ``hf_device_map`` and
            ``config.use_cache`` are read from it.
        tokenizer: Forwarded to ``get_loaders``.
        device: Forwarded to the calibration-input helper.
        prune_n: When non-zero, n:m pruning is used: ``prune_n`` entries are
            masked in every group of ``prune_m`` consecutive input columns.
        prune_m: Group size for n:m pruning.

    Returns:
        None. Selected weights are set to zero in place.

    NOTE(review): this function calls ``get_loaders``,
    ``prepare_calibration_input_opt``, ``get_z_scores_sum`` and
    ``return_given_alpha`` and uses ``np``; none of these is defined or
    imported in the visible part of the notebook - confirm where they come
    from before running this cell.
    """
    ##### Pass 1: compute the per-layer score (outlier ratio)

    all_layer_ratio=[]
    use_cache = model.config.use_cache 
    model.config.use_cache = False 

    print("loading calibdation data")
    dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=2048,tokenizer=tokenizer)
    print("dataset loading complete")
    with torch.no_grad():
        
        if "OPT" in model.__class__.__name__:
            
            inps, outs, attention_mask, position_ids = prepare_calibration_input_opt(model, dataloader, device)
        else:
            
            # NOTE(review): four values are unpacked here - verify this matches
            # what prepare_calibration_input returns for the model in use.
            inps, outs, attention_mask, position_ids = prepare_calibration_input(model, dataloader, device)

    print ("inps",inps)
    # NOTE(review): the branch above tests the model class name for "OPT",
    # while this one tests args.model for "opt" - confirm both always agree.
    if "opt" in args.model:
        layers=model.model.decoder.layers
        
    else:
        layers = model.model.layers

    for i in range(len(layers)):
        layer = layers[i]

        subset = find_layers(layer)

        if f"model.layers.{i}" in model.hf_device_map:   ## handle the case for llama-30B and llama-65B, when the device map has multiple GPUs;
            dev = model.hf_device_map[f"model.layers.{i}"]
            # Move the buffers to the device that hosts this layer.
            inps, outs, position_ids = inps.to(dev), outs.to(dev), position_ids.to(dev)

        # One statistics collector per linear sub-layer.
        wrapped_layers = {}
        for name in subset:
            wrapped_layers[name] = WrappedGPT(subset[name])

        # Factory so that each hook is bound to its own sub-layer name.
        def add_batch(name):
            def tmp(_, inp, out):
                wrapped_layers[name].add_batch(inp[0].data, out.data)
            return tmp

        handles = []
        for name in wrapped_layers:
            handles.append(subset[name].register_forward_hook(add_batch(name)))
        # Run every calibration sample through the layer so the hooks can
        # accumulate statistics.
        for j in range(args.nsamples):
            with torch.no_grad():
                if "OPT" in model.__class__.__name__:
                    outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
                else:
                    outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
        for h in handles:
            h.remove()

        layer_wmetric=[]
        layer_ametric=[]

        for name in subset:

            print(f"pruning layer {i} name {name}")
            # Wanda metric: |weight| scaled per input column by sqrt(scaler_row).
            W_metric = torch.abs(subset[name].weight.data) * torch.sqrt(wrapped_layers[name].scaler_row.reshape((1,-1)))

            activation_data=torch.sqrt(wrapped_layers[name].scaler_row.reshape((1,-1)))
            layer_wmetric.append(W_metric)   
            layer_ametric.append(activation_data) 

        # Second forward pass over the samples. No weight is modified in this
        # pass, so it runs the same unmodified layer again; afterwards the
        # buffers are swapped so that `outs` feeds the next layer.
        for j in range(args.nsamples):
            with torch.no_grad():
                if "OPT" in model.__class__.__name__:
                    outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
                else:
                    outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
        inps, outs = outs, inps

        # `layer_wmetric` is only referenced by the commented-out OWL code below.
        layer_wmetric = torch.cat([torch.flatten(x.cpu()) for x in layer_wmetric])
        # NOTE(review): `layer_ametic` (sic) is a new name, distinct from the
        # `layer_ametric` list it is built from.
        layer_ametic = torch.cat([torch.flatten(x.cpu()) for x in layer_ametric])
        # OWL
        # for out_ratio in [args.Hyper_m]:
            
        #     out_ratio_layer=check_outlier_mean(layer_wmetric,out_ratio)
        #     print ("layer outlier ratio",out_ratio,out_ratio_layer)
            
        # TODO:
        # z_scores
        out_ratio_layer = get_z_scores_sum(layer_ametic)
        print ("layer zscores ratio",out_ratio_layer)
        all_layer_ratio.append(out_ratio_layer)

    print ("before adjustment",all_layer_ratio)

    all_layer_ratio=np.array(all_layer_ratio)
    # OWL
    # all_layer_ratio = ((all_layer_ratio - all_layer_ratio.min()) * (1/(all_layer_ratio.max() - all_layer_ratio.min()) * args.Lamda*2))
    
    # max-min normalization
    # all_layer_ratio = ((all_layer_ratio - all_layer_ratio.min()) * (1/(all_layer_ratio.max() - all_layer_ratio.min())))
    # Min-max scale the scores and multiply by 2 * Lamda.
    # NOTE(review): the divisor is zero when all layer scores are equal.
    all_layer_ratio = ((all_layer_ratio - all_layer_ratio.min()) * (1/(all_layer_ratio.max() - all_layer_ratio.min()) * args.Lamda*2))
    
    # Shift so that the mean of the ratios equals 1 - args.sparsity_ratio.
    all_layer_ratio=all_layer_ratio-np.mean(all_layer_ratio)+(1-args.sparsity_ratio)
   
    print (all_layer_ratio,np.mean(all_layer_ratio),np.max(all_layer_ratio),np.min(all_layer_ratio))
    # st()

    print ("after adjustment",all_layer_ratio  )

    model.config.use_cache = use_cache 
    torch.cuda.empty_cache()
    ############## Pass 2: prune

    use_cache = model.config.use_cache 
    model.config.use_cache = False 

    # The calibration data and first-layer inputs are rebuilt from scratch.
    print("loading calibdation data")
    dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=2048,tokenizer=tokenizer)
    print("dataset loading complete")
    with torch.no_grad():
        
        if "OPT" in model.__class__.__name__:
            
            inps, outs, attention_mask, position_ids = prepare_calibration_input_opt(model, dataloader, device)
        else:
            
            inps, outs, attention_mask, position_ids = prepare_calibration_input(model, dataloader, device)

    print ("inps",inps)
    if "opt" in args.model:
        layers=model.model.decoder.layers
        
    else:
        layers = model.model.layers

    for i in range(len(layers)):
        layer = layers[i]

        subset = find_layers(layer)

        if f"model.layers.{i}" in model.hf_device_map:   ## handle the case for llama-30B and llama-65B, when the device map has multiple GPUs;
            dev = model.hf_device_map[f"model.layers.{i}"]
            # Move the buffers to the device that hosts this layer.
            inps, outs,  position_ids = inps.to(dev), outs.to(dev),  position_ids.to(dev)

        wrapped_layers = {}
        for name in subset:
            wrapped_layers[name] = WrappedGPT(subset[name])

        def add_batch(name):
            def tmp(_, inp, out):
                wrapped_layers[name].add_batch(inp[0].data, out.data)
            return tmp

        handles = []
        for name in wrapped_layers:
            handles.append(subset[name].register_forward_hook(add_batch(name)))
        # Collect activation statistics for this layer before pruning it.
        for j in range(args.nsamples):
            with torch.no_grad():
                if "OPT" in model.__class__.__name__:
                    outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
                else:
                    outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
        for h in handles:
            h.remove()

        for name in subset:

            print(f"pruning layer {i} name {name}")
            W_metric = torch.abs(subset[name].weight.data) * torch.sqrt(wrapped_layers[name].scaler_row.reshape((1,-1)))
            
            # TODO:add entropy
            # pr = torch.abs(W_metric)/torch.sum(torch.abs(W_metric), dim=0)
            # pc = torch.abs(W_metric)/torch.sum(torch.abs(W_metric), dim=1).reshape(-1, 1)
            # W_metric = torch.abs((-pr * torch.log(pr)) - (pc * torch.log(pc)))
            
            # TODO:add wentropy
            # # define a small constant epsilon to avoid problems when taking the log
            # epsilon = 1e-10
            # W_metric += epsilon
            # # normalize the tensor so its elements sum to 1
            # probabilities = W_metric / W_metric.sum()
            # # take the log of the probability distribution
            # log_probabilities = torch.log(probabilities)
            # # compute the entropy term of every element
            # H = -probabilities * log_probabilities
            # W_metric = torch.abs(W_metric * H)

            # NOTE(review): `activation_data` is not used again in this pass.
            activation_data=torch.sqrt(wrapped_layers[name].scaler_row.reshape((1,-1)))

            # Sparsity for this layer is the complement of its adjusted ratio.
            layer_sparsity_ratio= 1-all_layer_ratio[i]

            # Non-positive sparsities are replaced by 0.01.
            if layer_sparsity_ratio<=0:
                layer_sparsity_ratio=0.01

            W_mask = (torch.zeros_like(W_metric) == 1)  ## initialize a mask to be all False
            if prune_n != 0:
                # structured n:m sparsity
                # In every group of prune_m consecutive columns, mark the
                # prune_n lowest-metric entries of each row.
                for ii in range(W_metric.shape[1]):
                    if ii % prune_m == 0:
                        tmp = W_metric[:,ii:(ii+prune_m)].float()
                        W_mask.scatter_(1,ii+torch.topk(tmp, prune_n,dim=1, largest=False)[1], True)
            else:
                sort_res = torch.sort(W_metric, dim=-1, stable=True)

                if args.use_variant:
                    # wanda variant 
                    tmp_metric = torch.cumsum(sort_res[0], dim=1)
                    sum_before = W_metric.sum(dim=1)

                    # Bisection on alpha: stop when the sparsity reported by
                    # return_given_alpha is within 0.001 of the target or the
                    # search interval is narrower than 0.001.
                    alpha = 0.4
                    alpha_hist = [0., 0.8]
                    W_mask, cur_sparsity = return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before)
                    while (torch.abs(cur_sparsity - layer_sparsity_ratio)>0.001) and (alpha_hist[1]-alpha_hist[0]>=0.001):
                        if cur_sparsity > layer_sparsity_ratio:
                            alpha_new = (alpha + alpha_hist[0]) / 2.0
                            alpha_hist[1] = alpha
                        else:
                            alpha_new = (alpha + alpha_hist[1]) / 2.0
                            alpha_hist[0] = alpha

                        alpha = alpha_new 
                        W_mask, cur_sparsity = return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before)
                    print(f"alpha found {alpha} sparsity {cur_sparsity:.6f}")
                else:
                    # unstructured pruning
                    # Per row, mark the int(columns * ratio) lowest-metric entries.
                    indices = sort_res[1][:,:int(W_metric.shape[1]*layer_sparsity_ratio)]
                    W_mask.scatter_(1, indices, True)
            # print ("W_mask",W_mask)
            subset[name].weight.data[W_mask] = 0  ## set weights to zero 

        # Recompute the outputs with the pruned weights; after the swap they
        # become the inputs of the next layer.
        for j in range(args.nsamples):
            with torch.no_grad():
                if "OPT" in model.__class__.__name__:
                    outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
                else:
                    outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
        inps, outs = outs, inps

    model.config.use_cache = use_cache 
    torch.cuda.empty_cache()


In [31]:
def prune_wanda(model, tokenizer, device=torch.device("cuda:0"),nsamples=128, seed=0, sparsity_ratio=0.2, prune_n=0, prune_m=0):
    """Prune every linear sub-layer of the decoder stack in place with the Wanda metric.

    Layers are processed one at a time. For each layer, calibration samples
    are run through it while forward hooks collect ``WrappedGPT`` statistics,
    each linear weight is scored with ``|W| * sqrt(scaler_row)``, the
    lowest-scoring entries are zeroed, and the layer is run again so that the
    next layer receives the outputs of the pruned layer.

    Args:
        model: Causal LM exposing ``seqlen``, ``config.use_cache`` and its
            decoder blocks at ``model.model.layers`` or
            ``model.model.decoder.layers``.
        tokenizer: Tokenizer forwarded to ``get_wikitext2``.
        device: Device forwarded to ``prepare_calibration_input``.
        nsamples: Number of calibration samples.
        seed: Seed forwarded to ``get_wikitext2``.
        sparsity_ratio: Fraction of each weight row to zero in unstructured
            mode. Not used when ``prune_n`` is non-zero.
        prune_n: When non-zero, n:m pruning is used: ``prune_n`` entries are
            zeroed in every group of ``prune_m`` consecutive input columns.
        prune_m: Group size for n:m pruning.

    Returns:
        None. Weights are modified in place.

    Raises:
        ValueError: If ``prune_n`` is non-zero and ``prune_m`` is not a
            positive value at least as large as ``prune_n``.
    """
    # Fail early with a clear message instead of a ZeroDivisionError or a
    # topk error deep inside the pruning loop.
    if prune_n != 0 and (prune_m <= 0 or prune_n > prune_m):
        raise ValueError(
            f"invalid n:m sparsity: prune_n={prune_n}, prune_m={prune_m} "
            "(need 0 < prune_n <= prune_m)"
        )

    # Disable the KV cache while pruning; restored in the `finally` below.
    use_cache = model.config.use_cache
    model.config.use_cache = False

    try:
        print("loading calibration data")
        dataloader, _ = get_wikitext2(
            nsamples=nsamples,
            seed=seed,
            seqlen=model.seqlen,
            tokenizer=tokenizer
        )
        print("dataset loading complete")

        # The helper returns position ids as a fourth element only for
        # architectures whose layers take them, so unpack by length instead
        # of consulting a notebook-global model name.
        with torch.no_grad():
            calib = prepare_calibration_input(model, dataloader, device)
        inps, outs, attention_mask = calib[:3]
        layer_kwargs = {"attention_mask": attention_mask}
        if len(calib) > 3:
            layer_kwargs["position_ids"] = calib[3]

        # Locate the decoder blocks from the model object itself.
        backbone = model.model
        layers = backbone.decoder.layers if hasattr(backbone, "decoder") else backbone.layers

        def forward_samples(layer, src, dst):
            # Run each calibration sample through `layer`, writing into `dst`.
            with torch.no_grad():
                for j in range(nsamples):
                    dst[j] = layer(src[j].unsqueeze(0), **layer_kwargs)[0]

        for i, layer in enumerate(layers):
            subset = find_layers(layer)

            # One statistics collector per linear sub-layer.
            wrapped_layers = {name: WrappedGPT(module) for name, module in subset.items()}

            # Factory so that each hook is bound to its own sub-layer name.
            def add_batch(name):
                def hook(_, inp, out):
                    wrapped_layers[name].add_batch(inp[0].data, out.data)
                return hook

            handles = [
                subset[name].register_forward_hook(add_batch(name))
                for name in wrapped_layers
            ]
            try:
                # First pass: unpruned layer, hooks accumulate input statistics.
                forward_samples(layer, inps, outs)
            finally:
                # Never leave hooks attached, even if the forward pass fails.
                for h in handles:
                    h.remove()

            for name in subset:
                print(f"pruning layer {i} name {name}")
                # Wanda metric: |weight| scaled per input column by sqrt(scaler_row).
                W_metric = torch.abs(subset[name].weight.data) * torch.sqrt(wrapped_layers[name].scaler_row.reshape((1, -1)))
                # Boolean mask, True where a weight will be zeroed.
                W_mask = torch.zeros_like(W_metric, dtype=torch.bool)

                if prune_n != 0:
                    # Structured n:m sparsity: within each group of prune_m
                    # columns, drop the prune_n lowest-scoring entries per row.
                    for ii in range(0, W_metric.shape[1], prune_m):
                        tmp = W_metric[:, ii:(ii + prune_m)].float()
                        W_mask.scatter_(1, ii + torch.topk(tmp, prune_n, dim=1, largest=False)[1], True)
                else:
                    # Unstructured: per row, drop the lowest-scoring fraction.
                    sort_res = torch.sort(W_metric, dim=-1, stable=True)
                    indices = sort_res[1][:, :int(W_metric.shape[1] * sparsity_ratio)]
                    W_mask.scatter_(1, indices, True)

                subset[name].weight.data[W_mask] = 0

            # Second pass: recompute the outputs with the pruned weights so the
            # next layer is calibrated on what it will actually receive.
            forward_samples(layer, inps, outs)

            # This layer's outputs become the next layer's inputs.
            inps, outs = outs, inps
    finally:
        model.config.use_cache = use_cache

    # Release cached GPU memory held by the calibration buffers.
    torch.cuda.empty_cache()

In [32]:
# Prune the loaded model in place with Wanda at a 0.3 sparsity ratio
# (prune_n / prune_m keep their defaults of 0, i.e. unstructured pruning).
prune_wanda(model, tokenizer, sparsity_ratio=0.3)

loading calibration data
dataset loading complete
pruning layer 0 name self_attn.q_proj
pruning layer 0 name self_attn.k_proj
pruning layer 0 name self_attn.v_proj
pruning layer 0 name self_attn.o_proj
pruning layer 0 name mlp.gate_proj
pruning layer 0 name mlp.down_proj
pruning layer 0 name mlp.up_proj
pruning layer 1 name self_attn.q_proj
pruning layer 1 name self_attn.k_proj
pruning layer 1 name self_attn.v_proj
pruning layer 1 name self_attn.o_proj
pruning layer 1 name mlp.gate_proj
pruning layer 1 name mlp.down_proj
pruning layer 1 name mlp.up_proj
pruning layer 2 name self_attn.q_proj
pruning layer 2 name self_attn.k_proj
pruning layer 2 name self_attn.v_proj
pruning layer 2 name self_attn.o_proj
pruning layer 2 name mlp.gate_proj
pruning layer 2 name mlp.down_proj
pruning layer 2 name mlp.up_proj
pruning layer 3 name self_attn.q_proj
pruning layer 3 name self_attn.k_proj
pruning layer 3 name self_attn.v_proj
pruning layer 3 name self_attn.o_proj
pruning layer 3 name mlp.gate_pro

In [33]:
# Save the pruned weights under "wanda/<last path component of model_name>".
save_model = f"wanda/{model_name.rsplit('/', 1)[-1]}"
model.save_pretrained(save_model)

In [34]:
# Sanity check after pruning: the reported sparsity should be close to the
# ratio passed to prune_wanda.
print("*"*30)
sparsity_ratio = check_sparsity(model)
print(f"pruned model sparsity sanity check {sparsity_ratio:.4f}")
print("*"*30)

******************************
layer 0 sparsity 0.299840
layer 1 sparsity 0.299840
layer 2 sparsity 0.299840
layer 3 sparsity 0.299840
layer 4 sparsity 0.299840
layer 5 sparsity 0.299840
layer 6 sparsity 0.299840
layer 7 sparsity 0.299840
layer 8 sparsity 0.299840
layer 9 sparsity 0.299840
layer 10 sparsity 0.299840
layer 11 sparsity 0.299840
layer 12 sparsity 0.299840
layer 13 sparsity 0.299840
layer 14 sparsity 0.299840
layer 15 sparsity 0.299840
layer 16 sparsity 0.299840
layer 17 sparsity 0.299840
layer 18 sparsity 0.299840
layer 19 sparsity 0.299840
layer 20 sparsity 0.299840
layer 21 sparsity 0.299840
layer 22 sparsity 0.299840
layer 23 sparsity 0.299840
layer 24 sparsity 0.299840
layer 25 sparsity 0.299840
layer 26 sparsity 0.299840
layer 27 sparsity 0.299840
layer 28 sparsity 0.299840
layer 29 sparsity 0.299840
layer 30 sparsity 0.299840
layer 31 sparsity 0.299840
pruned model sparsity sanity check 0.2998
******************************


In [35]:
# Measure the pruned model's perplexity on the WikiText-2 test split.
ppl_test = eval_ppl(model, tokenizer)
print(f"pruned model wikitext perplexity {ppl_test}")

evaluating on wikitext2
nsamples 166
sample 0
loss 1.4560546875
loss 2.009765625
loss 2.2265625
loss 2.041015625
loss 1.5732421875
loss 1.7021484375
loss 1.4521484375
loss 1.3447265625
loss 1.70703125
loss 1.8349609375
loss 1.8837890625
loss 1.8271484375
loss 1.64453125
loss 1.873046875
loss 1.9560546875
loss 2.041015625
loss 1.96875
loss 2.025390625
loss 2.1484375
loss 1.9384765625
loss 1.83203125
loss 1.5947265625
loss 1.560546875
loss 1.9560546875
loss 1.978515625
loss 1.916015625
loss 1.96875
loss 1.912109375
loss 2.052734375
loss 1.857421875
loss 2.279296875
loss 2.078125
loss 2.041015625
loss 1.8427734375
loss 1.6396484375
loss 1.5703125
loss 1.501953125
loss 1.7109375
loss 1.7392578125
loss 1.9580078125
loss 1.8251953125
loss 1.3515625
loss 1.16796875
loss 1.376953125
loss 1.1728515625
loss 1.283203125
loss 1.47265625
loss 1.822265625
loss 2.275390625
loss 2.291015625
sample 50
loss 2.171875
loss 2.06640625
loss 1.9716796875
loss 1.8466796875
loss 1.962890625
loss 2.013671875
lo

In [36]:
# load sparse model
# model = AutoModelForCausalLM.from_pretrained(save_model, torch_dtype='auto')

In [37]:
# Generate text with the pruned model.
# NOTE(review): no cell in this notebook defines `input_ids`, so a fresh
# Restart & Run All stopped here with a NameError. The prompt below is an
# assumption taken from the beginning of the saved output - replace it with
# the original prompt if it is known.
prompt = "It takes a great deal of bravery"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
output_ids = model.generate(input_ids)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

It takes a great deal of bravery to stand up and say that you are a Christian.
