In [1]:
import numpy as np
import random
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig
import torch
import gc
import os
from tqdm.auto import tqdm
import torch.nn as nn
from torch.utils.data import DataLoader
from torchprofile import profile_macs
from typing import Union, List
import copy
# torch.cuda.empty_cache()
# gc.collect()

model_name = "models/llama3-8b/"

random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

#GPU
print(torch.__version__)

from prettytable import PrettyTable
import textwrap
import time
import os

  from .autonotebook import tqdm as notebook_tqdm


2.6.0.dev20241117+cu124


In [2]:
model = AutoModelForCausalLM.from_pretrained(model_name)

Loading checkpoint shards: 100%|██████████| 4/4 [01:02<00:00, 15.70s/it]


In [3]:
tokenizer_pruned = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)

In [4]:
# torch.save(model.state_dict(), "models/original_model.pt")
# print(f'size (GB): {os.path.getsize("models/original_model.pt")/1024e6}')

In [5]:
# Get the number of layers
num_layers = len(model.model.layers)
print(num_layers)

32


In [6]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (n

In [7]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters())
original_param_count = count_parameters(model)
print(f"Total parameters before pruning: {original_param_count/1e9:0.2f}B")

Total parameters before pruning: 8.03B


In [8]:
def print_layer_info(model, layer_index):
    layer = model.model.layers[layer_index]
    print(f"Layer {layer_index}:")
    print(f"  q_proj: {layer.self_attn.q_proj.weight.shape}")
    print(f"  k_proj: {layer.self_attn.k_proj.weight.shape}")
    print(f"  v_proj: {layer.self_attn.v_proj.weight.shape}")
    print(f"  o_proj: {layer.self_attn.o_proj.weight.shape}")
    print("##############################################")
    print(f"  gate_proj: {layer.mlp.gate_proj.weight.shape}")
    print(f"    up_proj: {layer.mlp.up_proj.weight.shape}")
    print(f"  down_proj: {layer.mlp.down_proj.weight.shape}")
print("Before pruning:")
print_layer_info(model, 0)

Before pruning:
Layer 0:
  q_proj: torch.Size([4096, 4096])
  k_proj: torch.Size([1024, 4096])
  v_proj: torch.Size([1024, 4096])
  o_proj: torch.Size([4096, 4096])
##############################################
  gate_proj: torch.Size([14336, 4096])
    up_proj: torch.Size([14336, 4096])
  down_proj: torch.Size([4096, 14336])


### Get Model MACS

In [9]:
def get_model_macs(model, inputs) -> int:
    return profile_macs(model, inputs)

### Get number of channels to keep

In [10]:
def get_num_channels_to_keep(channels: int, prune_ratio: float) -> int:
    """A function to calculate the number of layers to PRESERVE after pruning
    Note that preserve_rate = 1. - prune_ratio
    """
    new_channels = channels * (1-prune_ratio) 
    ##################### YOUR CODE STARTS HERE #####################
    return int(round(new_channels))
    ##################### YOUR CODE ENDS HERE #####################

## Ranking Channels by Importance

As you can see, removing the first 30% of channels in all layers leads to significant accuracy reduction. One potential method to remedy the issue is to find the **less important** channel weights to remove. A popular criterion for importance is to use the Frobenius norm of the weights corresponding to each input channel:

> $importance_{i} = \|W_{i}\|_2, \;\; i = 0, 1, 2,\cdots, \#\mathrm{in\_channels}-1$

We can sort the channel weights from more important to less important, and then keep the frst $k$ channels for each layer.
### calculate Frobenius norm of a tensor

In [11]:
# function to sort the channels from important to non-important
def get_input_channel_importance(weight):
    importance_0 = torch.linalg.norm(weight, dim=0) # for mlp
    importance_1 = torch.linalg.norm(weight, dim=1) # for mlp
    print('importance dim=0 rows: ', importance_0.shape)
    print('importance dim=1 cols: ', importance_1.shape)
    return importance_1

    # in_channels = weight.shape[1]
    # importances = []
    # compute the importance for each input channel
    # for i_c in range(weight.shape[1]):
        # channel_weight = weight.detach()[:, i_c]
        ##################### YOUR CODE STARTS HERE #####################
        # importance = torch.linalg.norm(channel_weight) # for cnn
        ##################### YOUR CODE ENDS HERE #####################
        # importances.append(importance.view(1))
    # return torch.cat(importance, dim=0)

In [None]:
@torch.no_grad()
def apply_channel_sorting(model):
    model = copy.deepcopy(model)  # do not modify the original model
    # fetch all the conv and bn layers from the backbone
    all_gate    = []
    all_up      = []
    all_down    = []
    all_in_ln   = []
    all_post_ln  = []

    for layer in model.model.layers:
        for attr_name, target_list in [
            ("gate_proj", all_gate),
            ("up_proj", all_up),
            ("down_proj", all_down),
        ]:
            if hasattr(layer.mlp, attr_name):
                module = getattr(layer.mlp, attr_name)
                if hasattr(module, "weight") and isinstance(module.weight, nn.Parameter):
                    target_list.append(module)
        if hasattr(layer, "input_layernorm"):
            module = getattr(layer, "input_layernorm")
            if hasattr(module, "weight") and isinstance(module.weight, nn.Parameter):
                all_in_ln.append(module)
        if hasattr(layer, "post_attention_layernorm"):
            module = getattr(layer, "post_attention_layernorm")
            if hasattr(module, "weight") and isinstance(module.weight, nn.Parameter):
                all_post_ln.append(module)
    # iterate through layers
    for i_gate in range(len(all_gate) - 1):
        # each channel sorting index, we need to apply it to:
        # - the output dimension of the previous conv
        # - the previous BN layer
        # - the input dimension of the next conv (we compute importance here)
        prev_gate   = all_gate[i_gate]
        prev_up     = all_up[i_gate]
        prev_down   = all_down[i_gate]
        prev_in_ln  = all_in_ln[i_gate]
        prev_post_ln= all_post_ln[i_gate]
        next_gate   = all_gate[i_gate + 1]
        next_up     = all_up[i_gate + 1]
        next_down   = all_down[i_gate + 1]
        
        print("1")
        print("-=-=-=-=")
        print(f"next_gate.weight.shape: {next_gate.weight.shape}")
        print(f"prev_gate.weight.shape: {prev_gate.weight.shape}")
        print("-=-=-=-=")
        print(f"next_up.weight.shape: {next_up.weight.shape}")
        print(f"prev_up.weight.shape: {prev_up.weight.shape}")
        print("-=-=-=-=")
        print(f"next_down.weight.shape: {next_down.weight.shape}")
        print(f"prev_down.weight.shape: {prev_down.weight.shape}")
        print("-=-=-=-=")
        
        print("#######################")
        print("2")
        # note that we always compute the importance according to input channels
        importance_gate = get_input_channel_importance(next_gate.weight)
        print(f'importance_gate.shape {importance_gate.shape}')
        print("-=-=-=-=")
        importance_up   = get_input_channel_importance(next_up.weight)
        print(f'importance_up.shape {importance_up.shape}')
        print("-=-=-=-=")
        importance_down = get_input_channel_importance(next_down.weight)
        print(f'importance_down.shape {importance_down.shape}')
        print("-=-=-=-=")
        
        print("#######################")
        print('3')
        # sorting from large to small
        sort_idx_gate   = torch.argsort(importance_gate, descending=True)
        sort_idx_up     = torch.argsort(importance_up, descending=True)
        sort_idx_down   = torch.argsort(importance_down, descending=True)
        # sort_idx_gate = torch.argsort(importance_gate[:next_gate.weight.size(1)], descending=True)
        # sort_idx_up   = torch.argsort(importance_up[:next_up.weight.size(1)], descending=True)
        # sort_idx_down = torch.argsort(importance_down[:next_down.weight.size(1)], descending=True)
        print(f"sort_idx_gate.shape: {sort_idx_gate.shape}")
        print(f"sort_idx_gate.max(): {sort_idx_gate.max()}")
        print(f"sort_idx_gate.min(): {sort_idx_gate.min()}")
        print("-=-=-=-=")
        print(f"sort_idx_up.shape: {sort_idx_up.shape}")
        print(f"sort_idx_up.max(): {sort_idx_up.max()}")
        print(f"sort_idx_up.min(): {sort_idx_up.min()}")
        print("-=-=-=-=")
        print(f"sort_idx_down.shape: {sort_idx_down.shape}")
        print(f"sort_idx_down.max(): {sort_idx_down.max()}")
        print(f"sort_idx_down.min(): {sort_idx_down.min()}")
        print("-=-=-=-=")

        # apply to previous conv and its following bn
        print("4")
        prev_down.weight.copy_(torch.index_select(
            prev_down.weight.detach(), 0, sort_idx_down))
        print("5")
        print("#####################")
        print(f"prev_down.weight.shape: {prev_down.weight.shape}")
        print(f"sort_idx_down.shape: {sort_idx_down.shape}")
        print(f'prev_ln: {prev_in_ln.weight.shape}')
        for tensor_name in ['weight']:
            tensor_to_apply = getattr(prev_in_ln, tensor_name)
            tensor_to_apply.copy_(
                torch.index_select(tensor_to_apply.detach(), 0, sort_idx_down)
            )
            tensor_to_apply = getattr(prev_post_ln, tensor_name)
            tensor_to_apply.copy_(
                torch.index_select(tensor_to_apply.detach(), 0, sort_idx_down)
            )
        print("6")

        # apply to the next_* input (hint: one line of code)
        print(f"next_gate.weight.shape: {next_gate.weight.shape}")
        print(f"sort_idx_gate.weight.shape: {sort_idx_gate.shape}")

        next_gate.weight.copy_(torch.index_select(
            next_gate.weight.detach(), 0, sort_idx_gate))
        next_up.weight.copy_(torch.index_select(
            next_up.weight.detach(), 0, sort_idx_up))
        next_down.weight.copy_(torch.index_select(
            next_down.weight.detach(), 0, sort_idx_down))
        print("7")

    return model

### Prune channels naively

In [14]:
@torch.no_grad()
def channel_prune(model: nn.Module,
                  prune_ratio: Union[List, float]) -> nn.Module:

    counters = {
        "gate_proj": 0,
        "up_proj": 0,
        "down_proj": 0
    }

    # n_conv = len([m for m in model.layers if isinstance(m, nn.Conv2d)])
    for attr_name in ["gate_proj", "up_proj", "down_proj"]:
        for i, layer in enumerate(model.model.layers):
            module = getattr(layer.mlp, attr_name)
            if (hasattr(module, "weight") 
                and isinstance(getattr(layer.mlp, attr_name).weight, nn.Parameter)):
                counters[attr_name] += 1

    n_gate_proj = counters["gate_proj"]
    n_up_proj = counters["up_proj"]
    n_down_proj = counters["down_proj"]

    # note that for the ratios, it affects the previous conv output and next
    # conv input, i.e., conv0 - ratio0 - conv1 - ratio1-...
    if isinstance(prune_ratio, list):
        assert len(prune_ratio) == n_gate_proj - 1
    else:  # convert float to list
        prune_ratio = [prune_ratio] * (n_gate_proj - 1)


    all_gate    = []
    all_up      = []
    all_down    = []
    all_lns      = []

    for layer in model.model.layers:
        for attr_name, target_list in [
            ("gate_proj", all_gate),
            ("up_proj", all_up),
            ("down_proj", all_down),
        ]:
            if hasattr(layer.mlp, attr_name):
                module = getattr(layer.mlp, attr_name)
                if hasattr(module, "weight") and isinstance(module.weight, nn.Parameter):
                    target_list.append(module)
        if hasattr(layer, "input_layernorm"):
            module = getattr(layer, "input_layernorm")
            if hasattr(module, "weight") and isinstance(module.weight, nn.Parameter):
                all_lns.append(module)

    # Print the collected modules
    print(f"Gate Proj Modules: {all_gate}")
    print(f"Up Proj Modules: {all_up}")
    print(f"Down Proj Modules: {all_down}")

    # apply pruning. we naively keep the first k channels
    assert len(all_gate) == len(all_up)
    assert len(all_gate) == len(all_down)
    
    for i_ratio, p_ratio in enumerate(prune_ratio):
        prev_gate = all_gate[i_ratio]
        prev_up = all_up[i_ratio]
        prev_down = all_down[i_ratio]
        next_gate = all_gate[i_ratio + 1]
        next_up = all_up[i_ratio + 1]
        next_down = all_down[i_ratio + 1]
        next_down = all_down[i_ratio + 1]
        prev_ln = all_lns[i_ratio]
        original_channels = prev_gate.out_channels  # same as next_gate.in_channels
        print('original_channels: ', original_channels)
        n_keep = get_num_channels_to_keep(original_channels, p_ratio)

        # prune the output of the previous conv and bn
        prev_gate.weight.set_(prev_gate.weight.detach()[:n_keep])
        prev_up.weight.set_(prev_up.weight.detach()[:n_keep])
        prev_down.weight.set_(prev_down.weight.detach()[:n_keep])
        prev_ln.weight.set_(prev_ln.weight.detach()[:n_keep])

        # prune the input of the next conv (hint: just one line of code)
        ##################### YOUR CODE STARTS HERE #####################
        next_gate.weight.set_(next_gate.weight.detach()[:,:n_keep])
        next_up.weight.set_(next_up.weight.detach()[:,:n_keep])
        next_down.weight.set_(next_down.weight.detach()[:,:n_keep])
        ##################### YOUR CODE ENDS HERE #####################

    return model


In [15]:
channel_pruning_ratio = 0.3  # pruned-out ratio

print(" * With sorting...")
sorted_model = apply_channel_sorting(model)
pruned_model = channel_prune(sorted_model, channel_pruning_ratio)

# print(" * Without sorting...")
# pruned_model = channel_prune(model, channel_pruning_ratio)

 * With sorting...


1
-=-=-=-=
next_gate.weight.shape: torch.Size([14336, 4096])
prev_gate.weight.shape: torch.Size([14336, 4096])
-=-=-=-=
next_up.weight.shape: torch.Size([14336, 4096])
prev_up.weight.shape: torch.Size([14336, 4096])
-=-=-=-=
next_down.weight.shape: torch.Size([4096, 14336])
prev_down.weight.shape: torch.Size([4096, 14336])
-=-=-=-=
#######################
2
importance dim=0 rows:  torch.Size([4096])
importance dim=1 cols:  torch.Size([14336])
importance_gate.shape torch.Size([14336])
-=-=-=-=
importance dim=0 rows:  torch.Size([4096])
importance dim=1 cols:  torch.Size([14336])
importance_up.shape torch.Size([14336])
-=-=-=-=
importance dim=0 rows:  torch.Size([14336])
importance dim=1 cols:  torch.Size([4096])
importance_down.shape torch.Size([4096])
-=-=-=-=
#######################
3
sort_idx_gate.shape: torch.Size([14336])
sort_idx_gate.max(): 14335
sort_idx_gate.min(): 0
-=-=-=-=
sort_idx_up.shape: torch.Size([14336])
sort_idx_up.max(): 14335
sort_idx_up.min(): 0
-=-=-=-=
sort_idx_

AttributeError: 'Linear' object has no attribute 'out_channels'

Finally, we compare the pruned models' accuracy with and without sorting.

# Verify the number of parameters after pruning

In [None]:
# Recalculate the number of parameters
pruned_param_count = count_parameters(pruned_model)
reduction_in_params = original_param_count - pruned_param_count
percentage_savings = (reduction_in_params / original_param_count) * 100

print(f"Pruned model parameters: {pruned_param_count/1e9:0.2f}")
print(f"Reduction in parameters: {reduction_in_params/1e9:0.2f}")
print(f"Percentage of weight savings: {percentage_savings:.2f}%")

# dummy_input = torch.randn(1, 3, 32, 32)
# pruned_macs = get_model_macs(pruned_model, dummy_input)
# print(f'pruned macs {pruned_macs}')

Pruned model parameters: 3520860160
Reduction in parameters: 4509401088
Percentage of weight savings: 56.16%


In [None]:
print("After pruning:")
print_layer_info(pruned_model, 0)

After pruning:
Layer 0:
  q_proj: torch.Size([4096, 4096])
  k_proj: torch.Size([1024, 4096])
  v_proj: torch.Size([1024, 4096])
  o_proj: torch.Size([4096, 4096])
##############################################
  gate_proj: torch.Size([2868, 4096])
    up_proj: torch.Size([2868, 4096])
  down_proj: torch.Size([4096, 2868])


In [None]:
print(f"Total parameters after pruning: {count_parameters(pruned_model)/1e9:0.2f}B")

Total parameters after pruning: 3.52B


In [None]:
print(f'memory usage is {pruned_model.get_memory_footprint()/1024/1024/1024:0.2f}GB')

memory usage is 13.12GB


In [None]:
print("After pruning:")
print_layer_info(pruned_model, 0)

After pruning:
Layer 0:
  q_proj: torch.Size([4096, 4096])
  k_proj: torch.Size([1024, 4096])
  v_proj: torch.Size([1024, 4096])
  o_proj: torch.Size([4096, 4096])
##############################################
  gate_proj: torch.Size([2868, 4096])
    up_proj: torch.Size([2868, 4096])
  down_proj: torch.Size([4096, 2868])


In [None]:
print(f"Total parameters after pruning: {count_parameters(pruned_model)/1e9:0.2f}B")


Total parameters after pruning: 3.52B


In [None]:
pruned_model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=2868, bias=False)
          (up_proj): Linear(in_features=4096, out_features=2868, bias=False)
          (down_proj): Linear(in_features=2868, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm

In [None]:
def output_wrapped(output_tokenizer):
    print("Generated Output:\n")
    for i, sentence in enumerate(output_tokenizer, 1):
        wrapped_sentence = textwrap.fill(sentence, width=80)
        print(f"Output {i}:\n{wrapped_sentence}\n")
        
def get_outputs(model, inputs, tokenizer, max_new_tokens=200):
    print("Input tensor shapes:")
    print(f"  input_ids: {inputs['input_ids'].shape}")
    print(f"  attention_mask: {inputs['attention_mask'].shape}")

    # Wrap model.generate to track tensor sizes
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_new_tokens,
        repetition_penalty=1.1,
        early_stopping=False,  # Can stop before reaching max_length
        eos_token_id=tokenizer.eos_token_id,
    )
    print(f"Output tensor shape: {outputs.shape}")
    return outputs

input_sentences = tokenizer_pruned("Tell a short history of humanity with happy ending.", return_tensors="pt")
input_sentences = tokenizer_pruned("capital of france is.", return_tensors="pt")
start_time = time.time()
model_4b_outputs_sentence = get_outputs(pruned_model, input_sentences, tokenizer_pruned, max_new_tokens=100)
end_time = time.time()
output_decoded_tokenizer = tokenizer_pruned.batch_decode(model_4b_outputs_sentence, skip_special_tokens=True)
print(f'time taken is = {(end_time-start_time)/60:0.2f} min')
output_wrapped(output_decoded_tokenizer)

Input tensor shapes:

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.



  input_ids: torch.Size([1, 6])
  attention_mask: torch.Size([1, 6])
Output tensor shape: torch.Size([1, 106])
time taken is = 1.08 min
Generated Output:

Output 1:
capital of france is. The \(\maths)The \(\This)The \(\This)TheTheTheTheTheTheThe
TheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTh
eTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheT
heTheTheTheTheTheTheTheTheTheTheTheTheATheATheATheATheATheA



In [None]:
# size of model
torch.save(pruned_model.state_dict(), "models/pruned.pt")
print(f'size (GB): , {os.path.getsize("models/pruned.pt")/1024e6}')
# # os.remove("models/temp/temp_delme.pt")

size (GB): , 13.753475708007812


In [None]:
# print(model)
from prettytable import PrettyTable

def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params += params
    print(table)
    print(f"Total Trainable Params: {total_params//1000/1000/1000:0.1f}B")
    return total_params
    
count_parameters(pruned_model)


+-------------------------------------------------+------------+
|                     Modules                     | Parameters |
+-------------------------------------------------+------------+
|            model.embed_tokens.weight            | 525336576  |
|      model.layers.0.self_attn.q_proj.weight     |  16777216  |
|      model.layers.0.self_attn.k_proj.weight     |  4194304   |
|      model.layers.0.self_attn.v_proj.weight     |  4194304   |
|      model.layers.0.self_attn.o_proj.weight     |  16777216  |
|       model.layers.0.mlp.gate_proj.weight       |  11747328  |
|        model.layers.0.mlp.up_proj.weight        |  11747328  |
|       model.layers.0.mlp.down_proj.weight       |  11747328  |
|      model.layers.0.input_layernorm.weight      |    4096    |
|  model.layers.0.post_attention_layernorm.weight |    4096    |
|      model.layers.1.self_attn.q_proj.weight     |  16777216  |
|      model.layers.1.self_attn.k_proj.weight     |  4194304   |
|      model.layers.1.sel

3520860160

In [None]:
# import torch
# import torch.nn as nn

# def prune_with_svd(matrix, prune_ratio):
#     """
#     Prune rows from a weight matrix using low-rank approximation.
#     """
#     weight = matrix.weight.data
#     rank = int(weight.size(0) * (1 - prune_ratio))
#     U, S, Vt = torch.svd_lowrank(weight, q=rank)
#     pruned_weight = (U @ torch.diag(S) @ Vt)
#     matrix.weight = nn.Parameter(pruned_weight)
#     return matrix

# def prune_rows(matrix, num_rows_to_keep):
#     """
#     Prune rows in a weight matrix.
#     """
#     with torch.no_grad():
#         pruned_weights = matrix.weight[:num_rows_to_keep].clone()
#         matrix.weight = nn.Parameter(pruned_weights)
#         if matrix.bias is not None:
#             pruned_bias = matrix.bias[:num_rows_to_keep].clone()
#             matrix.bias = nn.Parameter(pruned_bias)
#     return matrix

# def prune_transformer_layer_with_svd(layer, prune_ratio):
#     """
#     Prune a transformer layer using SVD-based low-rank approximation.
#     """
#     # Prune attention projections
#     layer.self_attn.q_proj = prune_with_svd(layer.self_attn.q_proj, prune_ratio)
#     layer.self_attn.k_proj = prune_with_svd(layer.self_attn.k_proj, prune_ratio)
#     layer.self_attn.v_proj = prune_with_svd(layer.self_attn.v_proj, prune_ratio)
#     layer.self_attn.o_proj = prune_with_svd(layer.self_attn.o_proj, prune_ratio)

#     # Adjust MLP layers
#     new_q_proj_dim = layer.self_attn.q_proj.weight.size(0)
#     layer.mlp.gate_proj = prune_rows(layer.mlp.gate_proj, new_q_proj_dim)
#     layer.mlp.up_proj = prune_rows(layer.mlp.up_proj, new_q_proj_dim)
#     layer.mlp.down_proj = prune_rows(layer.mlp.down_proj, new_q_proj_dim)

#     return layer

# def prune_llama_model(model, prune_ratio):
#     """
#     Apply pruning to all transformer layers in the model.
#     """
#     for i, layer in enumerate(model.model.layers):
#         print(f"Pruning Layer {i}...")
#         model.model.layers[i] = prune_transformer_layer_with_svd(layer, prune_ratio)
#     return model

# # Apply pruning to the model
# prune_ratio = 0.5
# model_pruned = prune_llama_model(model_pruned, prune_ratio)

# # Print the number of parameters after pruning
# print(f"Number of parameters after pruning: {sum(p.numel() for p in model_pruned.parameters())}")
