In [5]:
import numpy as np
import random
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig
import torch
import gc
import os
from tqdm import tqdm
import torch.nn as nn
torch.cuda.empty_cache()
gc.collect()

model_name = "models/llama3-8b/"

random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

#GPU
print(torch.__version__)

from prettytable import PrettyTable
import textwrap
import time
import os

2.6.0.dev20241117+cu124


In [6]:
model = AutoModelForCausalLM.from_pretrained(model_name)

Loading checkpoint shards: 100%|██████████| 4/4 [00:59<00:00, 14.89s/it]


In [7]:
tokenizer_pruned = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)


In [8]:
# torch.save(model.state_dict(), "models/original_model.pt")
# print(f'size (GB): {os.path.getsize("models/original_model.pt")/1024e6}')

In [9]:
# Get the number of layers
num_layers = len(model.model.layers)
print(num_layers)

32


In [10]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (n

In [11]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters())
original_param_count = count_parameters(model)
print(f"Total parameters before pruning: {original_param_count/1e9:0.2f}B")

Total parameters before pruning: 8.03B


In [12]:
def print_layer_info(model, layer_index):
    layer = model.model.layers[layer_index]
    print(f"Layer {layer_index}:")
    print(f"  q_proj: {layer.self_attn.q_proj.weight.shape}")
    print(f"  k_proj: {layer.self_attn.k_proj.weight.shape}")
    print(f"  v_proj: {layer.self_attn.v_proj.weight.shape}")
    print(f"  o_proj: {layer.self_attn.o_proj.weight.shape}")
    print("##############################################")
    print(f"  gate_proj: {layer.mlp.gate_proj.weight.shape}")
    print(f"    up_proj: {layer.mlp.up_proj.weight.shape}")
    print(f"  down_proj: {layer.mlp.down_proj.weight.shape}")
print("Before pruning:")
print_layer_info(model, 0)

Before pruning:
Layer 0:
  q_proj: torch.Size([4096, 4096])
  k_proj: torch.Size([1024, 4096])
  v_proj: torch.Size([1024, 4096])
  o_proj: torch.Size([4096, 4096])
##############################################
  gate_proj: torch.Size([14336, 4096])
    up_proj: torch.Size([14336, 4096])
  down_proj: torch.Size([4096, 14336])


In [13]:
def compute_neuron_pair_importance(gate_weight, up_weight):
    gate_max_abs = torch.max(gate_weight, dim=1).values + torch.abs(torch.min(gate_weight, dim=1).values)
    up_max_abs = torch.max(up_weight, dim=1).values + torch.abs(torch.min(up_weight, dim=1).values)
    importance_scores = gate_max_abs + up_max_abs
    return importance_scores


In [14]:
def prune_neuron_pairs(mlp, prune_percent):
    # Extract the weights from the MLP layers
    #  these weights are used to calculate each neuron's
    #  importance score in the next step.
    gate_weight = mlp.gate_proj.weight.data.float()
    up_weight = mlp.up_proj.weight.data.float()

    #Compute importance stores. Neurons with higher importance scores
    # are considered more important and less likely to be pruned.
    importance_scores = compute_neuron_pair_importance(gate_weight, up_weight)

    #Store the original number of neurons in the intermediate layer.
    original_intermediate_size = gate_weight.size(0)
    #Computes the number of neurons to prune.
    num_neuron_pairs_to_prune = min(int(prune_percent * original_intermediate_size), original_intermediate_size - 1)
    #Calculate the number of neurons to keep. The new intermediate size.
    k = original_intermediate_size - num_neuron_pairs_to_prune

    #Just check that there is no big error calculating k. We can't prune all the neurons.
    if k <= 0:
        raise ValueError(f"Invalid number of neuron pairs to keep: {k}. Adjust the prune_percent.")

    #Select the neuros to keep, by obtaining the indices to keep.
    _, indices_to_keep = torch.topk(importance_scores, k, largest=True, sorted=True)
    indices_to_keep = indices_to_keep.sort().values

    #create the new layers
    new_gate_proj = nn.Linear(mlp.gate_proj.in_features, k, bias=False)
    new_up_proj = nn.Linear(mlp.up_proj.in_features, k, bias=False)
    new_down_proj = nn.Linear(k, mlp.down_proj.out_features, bias=False)

    #copy weights to the new layers.
    new_gate_proj.weight.data = mlp.gate_proj.weight.data[indices_to_keep, :]
    new_up_proj.weight.data = mlp.up_proj.weight.data[indices_to_keep, :]
    new_down_proj.weight.data = mlp.down_proj.weight.data[:, indices_to_keep]

    #return new layers and intermediate size.
    return new_gate_proj, new_up_proj, new_down_proj, k


In [15]:
def update_model(model, prune_percent):
    new_intermediate_size = None

    #loop for each model layer.
    for idx, layer in enumerate(model.model.layers):
        #Since each layer is a LlamaDecoderLayer it contains multiple components
        # Attention, MLP and Layer norms. We're targetting MLP component
        # by accesing layer.mlp.
        mlp = layer.mlp

        #Call the prune_neiron_pairs with the layers and receiving the pruned.
        new_gate_proj, new_up_proj, new_down_proj, new_size = prune_neuron_pairs(mlp, prune_percent)

        #Replace the Origiginal Layers with Pruned Layers.
        mlp.gate_proj = new_gate_proj
        mlp.up_proj = new_up_proj
        mlp.down_proj = new_down_proj

        #new_intermediate_size only needs to be set once
        if new_intermediate_size is None:
            new_intermediate_size = new_size

    #Update the model config file.
    model.config.intermediate_size = new_intermediate_size

    return model


In [16]:
prune_percent = 0.8  # Prune 80% of neurons
pruned_model = update_model(model, prune_percent)

In [17]:
# Recalculate the number of parameters
pruned_param_count = count_parameters(pruned_model)
reduction_in_params = original_param_count - pruned_param_count
percentage_savings = (reduction_in_params / original_param_count) * 100

print(f"Pruned model parameters: {pruned_param_count}")
print(f"Reduction in parameters: {reduction_in_params}")
print(f"Percentage of weight savings: {percentage_savings:.2f}%")

Pruned model parameters: 3520860160
Reduction in parameters: 4509401088
Percentage of weight savings: 56.16%


In [18]:
print("After pruning:")
print_layer_info(pruned_model, 0)

After pruning:
Layer 0:
  q_proj: torch.Size([4096, 4096])
  k_proj: torch.Size([1024, 4096])
  v_proj: torch.Size([1024, 4096])
  o_proj: torch.Size([4096, 4096])
##############################################
  gate_proj: torch.Size([2868, 4096])
    up_proj: torch.Size([2868, 4096])
  down_proj: torch.Size([4096, 2868])


In [19]:
print(f"Total parameters after pruning: {count_parameters(pruned_model)/1e9:0.2f}B")

Total parameters after pruning: 3.52B


In [20]:
print(f'memory usage is {pruned_model.get_memory_footprint()/1024/1024/1024:0.2f}GB')

memory usage is 13.12GB


In [21]:
print("After pruning:")
print_layer_info(pruned_model, 0)

After pruning:
Layer 0:
  q_proj: torch.Size([4096, 4096])
  k_proj: torch.Size([1024, 4096])
  v_proj: torch.Size([1024, 4096])
  o_proj: torch.Size([4096, 4096])
##############################################
  gate_proj: torch.Size([2868, 4096])
    up_proj: torch.Size([2868, 4096])
  down_proj: torch.Size([4096, 2868])


In [22]:
print(f"Total parameters after pruning: {count_parameters(pruned_model)/1e9:0.2f}B")


Total parameters after pruning: 3.52B


In [23]:
pruned_model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=2868, bias=False)
          (up_proj): Linear(in_features=4096, out_features=2868, bias=False)
          (down_proj): Linear(in_features=2868, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm

In [24]:
def output_wrapped(output_tokenizer):
    print("Generated Output:\n")
    for i, sentence in enumerate(output_tokenizer, 1):
        wrapped_sentence = textwrap.fill(sentence, width=80)
        print(f"Output {i}:\n{wrapped_sentence}\n")
        
def get_outputs(model, inputs, tokenizer, max_new_tokens=200):
    print("Input tensor shapes:")
    print(f"  input_ids: {inputs['input_ids'].shape}")
    print(f"  attention_mask: {inputs['attention_mask'].shape}")

    # Wrap model.generate to track tensor sizes
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_new_tokens,
        repetition_penalty=1.1,
        early_stopping=False,  # Can stop before reaching max_length
        eos_token_id=tokenizer.eos_token_id,
    )
    print(f"Output tensor shape: {outputs.shape}")
    return outputs

input_sentences = tokenizer_pruned("Tell a short history of humanity with happy ending.", return_tensors="pt")
input_sentences = tokenizer_pruned("capital of france is.", return_tensors="pt")
start_time = time.time()
model_4b_outputs_sentence = get_outputs(pruned_model, input_sentences, tokenizer_pruned, max_new_tokens=100)
end_time = time.time()
output_decoded_tokenizer = tokenizer_pruned.batch_decode(model_4b_outputs_sentence, skip_special_tokens=True)
print(f'time taken is = {(end_time-start_time)/60:0.2f} min')
output_wrapped(output_decoded_tokenizer)

Input tensor shapes:

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.



  input_ids: torch.Size([1, 6])
  attention_mask: torch.Size([1, 6])
Output tensor shape: torch.Size([1, 106])
time taken is = 1.08 min
Generated Output:

Output 1:
capital of france is. The \(\maths)The \(\This)The \(\This)TheTheTheTheTheTheThe
TheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTh
eTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheT
heTheTheTheTheTheTheTheTheTheTheTheTheATheATheATheATheATheA



In [25]:
# size of model
torch.save(pruned_model.state_dict(), "models/pruned.pt")
print(f'size (GB): , {os.path.getsize("models/pruned.pt")/1024e6}')
# # os.remove("models/temp/temp_delme.pt")

size (GB): , 13.753475708007812


In [26]:
# print(model)
from prettytable import PrettyTable

def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params += params
    print(table)
    print(f"Total Trainable Params: {total_params//1000/1000/1000:0.1f}B")
    return total_params
    
count_parameters(pruned_model)


+-------------------------------------------------+------------+
|                     Modules                     | Parameters |
+-------------------------------------------------+------------+
|            model.embed_tokens.weight            | 525336576  |
|      model.layers.0.self_attn.q_proj.weight     |  16777216  |
|      model.layers.0.self_attn.k_proj.weight     |  4194304   |
|      model.layers.0.self_attn.v_proj.weight     |  4194304   |
|      model.layers.0.self_attn.o_proj.weight     |  16777216  |
|       model.layers.0.mlp.gate_proj.weight       |  11747328  |
|        model.layers.0.mlp.up_proj.weight        |  11747328  |
|       model.layers.0.mlp.down_proj.weight       |  11747328  |
|      model.layers.0.input_layernorm.weight      |    4096    |
|  model.layers.0.post_attention_layernorm.weight |    4096    |
|      model.layers.1.self_attn.q_proj.weight     |  16777216  |
|      model.layers.1.self_attn.k_proj.weight     |  4194304   |
|      model.layers.1.sel

3520860160

In [27]:
# import torch
# import torch.nn as nn

# def prune_with_svd(matrix, prune_ratio):
#     """
#     Prune rows from a weight matrix using low-rank approximation.
#     """
#     weight = matrix.weight.data
#     rank = int(weight.size(0) * (1 - prune_ratio))
#     U, S, Vt = torch.svd_lowrank(weight, q=rank)
#     pruned_weight = (U @ torch.diag(S) @ Vt)
#     matrix.weight = nn.Parameter(pruned_weight)
#     return matrix

# def prune_rows(matrix, num_rows_to_keep):
#     """
#     Prune rows in a weight matrix.
#     """
#     with torch.no_grad():
#         pruned_weights = matrix.weight[:num_rows_to_keep].clone()
#         matrix.weight = nn.Parameter(pruned_weights)
#         if matrix.bias is not None:
#             pruned_bias = matrix.bias[:num_rows_to_keep].clone()
#             matrix.bias = nn.Parameter(pruned_bias)
#     return matrix

# def prune_transformer_layer_with_svd(layer, prune_ratio):
#     """
#     Prune a transformer layer using SVD-based low-rank approximation.
#     """
#     # Prune attention projections
#     layer.self_attn.q_proj = prune_with_svd(layer.self_attn.q_proj, prune_ratio)
#     layer.self_attn.k_proj = prune_with_svd(layer.self_attn.k_proj, prune_ratio)
#     layer.self_attn.v_proj = prune_with_svd(layer.self_attn.v_proj, prune_ratio)
#     layer.self_attn.o_proj = prune_with_svd(layer.self_attn.o_proj, prune_ratio)

#     # Adjust MLP layers
#     new_q_proj_dim = layer.self_attn.q_proj.weight.size(0)
#     layer.mlp.gate_proj = prune_rows(layer.mlp.gate_proj, new_q_proj_dim)
#     layer.mlp.up_proj = prune_rows(layer.mlp.up_proj, new_q_proj_dim)
#     layer.mlp.down_proj = prune_rows(layer.mlp.down_proj, new_q_proj_dim)

#     return layer

# def prune_llama_model(model, prune_ratio):
#     """
#     Apply pruning to all transformer layers in the model.
#     """
#     for i, layer in enumerate(model.model.layers):
#         print(f"Pruning Layer {i}...")
#         model.model.layers[i] = prune_transformer_layer_with_svd(layer, prune_ratio)
#     return model

# # Apply pruning to the model
# prune_ratio = 0.5
# model_pruned = prune_llama_model(model_pruned, prune_ratio)

# # Print the number of parameters after pruning
# print(f"Number of parameters after pruning: {sum(p.numel() for p in model_pruned.parameters())}")
