In [1]:
import numpy as np
import random
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import gc
import os

torch.cuda.empty_cache()
gc.collect()

os.environ['TF_ENABLE_ONEDNN_OPTS']='0'

model_name = "models/llama3-8b/"

random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

import torch
#GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
torch.__version__

  from .autonotebook import tqdm as notebook_tqdm


cuda


'2.6.0.dev20241117+cu124'

In [2]:
model = AutoModelForCausalLM.from_pretrained(model_name,
                                            #  device_map="auto"
                                             )

Loading checkpoint shards: 100%|██████████| 4/4 [01:04<00:00, 16.03s/it]


In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)

In [5]:
# print(model)
from prettytable import PrettyTable

def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params += params
    print(table)
    print(f"Total Trainable Params: {total_params//1000/1000/1000:0.1f}B")
    return total_params
    
count_parameters(model)


+-------------------------------------------------+------------+
|                     Modules                     | Parameters |
+-------------------------------------------------+------------+
|            model.embed_tokens.weight            | 525336576  |
|      model.layers.0.self_attn.q_proj.weight     |  16777216  |
|      model.layers.0.self_attn.k_proj.weight     |  4194304   |
|      model.layers.0.self_attn.v_proj.weight     |  4194304   |
|      model.layers.0.self_attn.o_proj.weight     |  16777216  |
|       model.layers.0.mlp.gate_proj.weight       |  58720256  |
|        model.layers.0.mlp.up_proj.weight        |  58720256  |
|       model.layers.0.mlp.down_proj.weight       |  58720256  |
|      model.layers.0.input_layernorm.weight      |    4096    |
|  model.layers.0.post_attention_layernorm.weight |    4096    |
|      model.layers.1.self_attn.q_proj.weight     |  16777216  |
|      model.layers.1.self_attn.k_proj.weight     |  4194304   |
|      model.layers.1.sel

8030261248

In [None]:
import time
input_text = "Once upon a time in a distant land, there lived a"

# Set pad_token_id explicitly
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Prepare input with attention mask
inputs = tokenizer(
    input_text,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=512,
)

input_ids = inputs["input_ids"].repeat(16,1).to(model.device)

attention_mask = inputs["attention_mask"].repeat(16,1).to(model.device)

## time
start_time = time.time()
# Use the attention_mask during generation
output = model.generate(
    input_ids,
    attention_mask=attention_mask,
    max_length=50,
    num_return_sequences=5,
    pad_token_id=tokenizer.pad_token_id,
)
end_time = time.time()

# Calculate and display inference time
inference_time = end_time - start_time
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(f"Inference Time: {inference_time:.4f} seconds")
print("Generated Text:")
print(generated_text)

In [None]:
from torch.nn.utils import prune
import torch.nn as nn
import copy
from tqdm import tqdm

# pruned_model = copy.deepcopy(model)

# Example structured pruning function
def structured_pruning(model, pruning_percentage=0.2):
    transformer = model.model
    print(transformer)
    for i, layer in tqdm(enumerate(transformer.layers)):
        for attr_name in ["q_proj", "k_proj", "v_proj", "o_proj"]:
            module = getattr(layer.self_attn, attr_name)
            if hasattr(module, "weight") and isinstance(module.weight, torch.nn.Parameter):
                prune.ln_structured(module, name="weight", amount=pruning_percentage, n=1, dim=0)
                prune.remove(module, 'weight')
                # print("pruned")
            else:
                print(f"Skipping {i} {attr_name}: No self_attn 'weight' attribute found.")
        
        # Prune MLP weights
        for attr_name in ["gate_proj", "up_proj", "down_proj"]:
            module = getattr(layer.mlp, attr_name)
            if hasattr(module, "weight") and isinstance(module.weight, torch.nn.Parameter):
                prune.ln_structured(module, name="weight", amount=pruning_percentage, n=1, dim=0)
                prune.remove(module, 'weight')
                # print("pruned")
            else:
                print(f"Skipping {i} {attr_name}: No MLP 'weight' attribute found.")
                
        # # Apply unstructured pruning for 1D layer normalization weights
        # prune.l1_unstructured(layer.input_layernorm, name="weight", amount=pruning_percentage)
        # prune.l1_unstructured(layer.post_attention_layernorm, name="weight", amount=pruning_percentage)
    
    # Prune embedding weights 
    if hasattr(transformer, 'embed_tokens') and hasattr(transformer.embed_tokens, 'weight'):
        prune.ln_structured(getattr(transformer, 'embed_tokens'), name="weight", amount=pruning_percentage, n=1, dim=0)
        prune.remove(getattr(transformer, 'embed_tokens'), 'weight')

def rebuild_model_with_pruned_weights_in_place(model):
    """
    Rebuild the model after pruning by modifying it in-place to avoid memory overhead,
    with dimensional checks to ensure consistency.
    """
    transformer = model.model

    # Iterate over transformer layers to prune weights
    for i, layer in tqdm(enumerate(transformer.layers), desc="Rebuilding Transformer Layers"):
        # Handle self-attention projection layers
        for attr_name in ["q_proj", "k_proj", "v_proj", "o_proj"]:
            module = getattr(layer.self_attn, attr_name, None)
            if module is not None and hasattr(module, "weight"):
                with torch.no_grad():
                    # Identify non-zero indices along dim=0 (input features)
                    input_non_zero_indices = module.weight.abs().sum(dim=1).nonzero(as_tuple=True)[0]
                    output_non_zero_indices = module.weight.abs().sum(dim=0).nonzero(as_tuple=True)[0]
                    
                    # Check dimensional consistency
                    input_dim = module.weight.size(1)
                    output_dim = module.weight.size(0)
                    
                    # Ensure indices are within bounds
                    input_non_zero_indices = input_non_zero_indices[input_non_zero_indices < input_dim]
                    output_non_zero_indices = output_non_zero_indices[output_non_zero_indices < output_dim]

                    # Update weights with pruned dimensions
                    module.weight = nn.Parameter(module.weight[output_non_zero_indices][:, input_non_zero_indices])
                    module.out_features = len(output_non_zero_indices)
                    module.in_features = len(input_non_zero_indices)

        # Handle MLP projection layers
        for attr_name in ["gate_proj", "up_proj", "down_proj"]:
            module = getattr(layer.mlp, attr_name, None)
            if module is not None and hasattr(module, "weight"):
                with torch.no_grad():
                    # Identify non-zero indices along dim=0 (input features)
                    input_non_zero_indices = module.weight.abs().sum(dim=1).nonzero(as_tuple=True)[0]
                    output_non_zero_indices = module.weight.abs().sum(dim=0).nonzero(as_tuple=True)[0]
                    
                    # Check dimensional consistency
                    input_dim = module.weight.size(1)
                    output_dim = module.weight.size(0)
                    
                    # Ensure indices are within bounds
                    input_non_zero_indices = input_non_zero_indices[input_non_zero_indices < input_dim]
                    output_non_zero_indices = output_non_zero_indices[output_non_zero_indices < output_dim]

                    # Update weights with pruned dimensions
                    module.weight = nn.Parameter(module.weight[output_non_zero_indices][:, input_non_zero_indices])
                    module.out_features = len(output_non_zero_indices)
                    module.in_features = len(input_non_zero_indices)

    # Handle embedding layers
    if hasattr(transformer, "embed_tokens"):
        embed_tokens = transformer.embed_tokens
        if hasattr(embed_tokens, "weight"):
            with torch.no_grad():
                # Identify non-zero indices for embeddings
                non_zero_indices = embed_tokens.weight.abs().sum(dim=0).nonzero(as_tuple=True)[0]
                # Ensure indices are within bounds
                embed_dim = embed_tokens.weight.size(0)
                non_zero_indices = non_zero_indices[non_zero_indices < embed_dim]

                embed_tokens.weight = nn.Parameter(embed_tokens.weight[non_zero_indices])
                embed_tokens.num_embeddings = len(non_zero_indices)

    return model

# Example Usage
structured_pruning(model, pruning_percentage=0.5)
model = rebuild_model_with_pruned_weights_in_place(model)


from prettytable import PrettyTable

def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params += params
    print(table)
    print(f"Total Trainable Params: {total_params//1000/1000/1000:0.1f}B")
    return total_params
    
count_parameters(model)


ModuleList(
  (0-31): 32 x LlamaDecoderLayer(
    (self_attn): LlamaSdpaAttention(
      (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
      (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
      (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (mlp): LlamaMLP(
      (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
      (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
      (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
      (act_fn): SiLU()
    )
    (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
    (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
  )
)


32it [01:58,  3.70s/it]
Rebuilding Transformer Layers: 32it [00:50,  1.59s/it]


+-------------------------------------------------+------------+
|                     Modules                     | Parameters |
+-------------------------------------------------+------------+
|            model.embed_tokens.weight            |  16777216  |
|      model.layers.0.self_attn.q_proj.weight     |  8388608   |
|      model.layers.0.self_attn.k_proj.weight     |   524288   |
|      model.layers.0.self_attn.v_proj.weight     |   524288   |
|      model.layers.0.self_attn.o_proj.weight     |  8388608   |
|       model.layers.0.mlp.gate_proj.weight       |  8364032   |
|        model.layers.0.mlp.up_proj.weight        |  8306688   |
|       model.layers.0.mlp.down_proj.weight       |  8388608   |
|      model.layers.0.input_layernorm.weight      |    4096    |
|  model.layers.0.post_attention_layernorm.weight |    4096    |
|      model.layers.1.self_attn.q_proj.weight     |  8388608   |
|      model.layers.1.self_attn.k_proj.weight     |   524288   |
|      model.layers.1.sel

1917779968

In [None]:
import time
input_text = "Once upon a time in a distant land, there lived a"

# Set pad_token_id explicitly
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Prepare input with attention mask
inputs = tokenizer(
    input_text,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=512,
)

input_ids = inputs["input_ids"].repeat(16,1).to(model.device)

attention_mask = inputs["attention_mask"].repeat(16,1).to(model.device)

## time
start_time = time.time()
# Use the attention_mask during generation
output = model.generate(
    input_ids,
    attention_mask=attention_mask,
    max_length=50,
    num_return_sequences=5,
    pad_token_id=tokenizer.pad_token_id,
)
end_time = time.time()

# Calculate and display inference time
inference_time = end_time - start_time
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(f"Inference Time: {inference_time:.4f} seconds")
print("Generated Text:")
print(generated_text)