In [None]:
!pip install datasets==3.6.0 evaluate rouge-score

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from torch import nn
from torch.nn import  ModuleList
import json
from typing import List, Dict, Callable, Any
from tqdm import tqdm
from copy import deepcopy
from datasets import load_dataset, Dataset
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer, Qwen2Model, Qwen2Attention, Qwen2ForCausalLM, Qwen2RMSNorm, Qwen2RotaryEmbedding, Qwen2MLP

In [3]:
def get_num_params(model,only_require_grad = False, in_gb = False, bytes_per_param = 2):
    '''
    Count the parameters of `model` and report them as a string.

    Args:
      model: any torch module (its `named_parameters()` are counted).
      only_require_grad: count only parameters with `requires_grad=True`.
      in_gb: report an estimated size instead of a count, as "<size>gb" where
        gb = 10**9 bytes, rounded to 2 decimals.
      bytes_per_param: bytes per parameter for the size estimate (2 matches the
        float16 weights this notebook loads).

    Returns:
      "<N>params" or "<size>gb".
    '''
    total = sum(
        param.numel()
        for _, param in model.named_parameters()
        if param.requires_grad or not only_require_grad
    )
    if in_gb:
        return f"{round((total*bytes_per_param)/10**9,2)}gb"
    return f"{total}params"

In [4]:
def model_initialization(model_name_or_path):
  '''
  Load a tokenizer and a float16 causal LM from a Hub id or a local directory.

  The model is placed with device_map="auto". The tokenizer is loaded first, so a bad
  path fails before any model weights are fetched.

  Returns:
    (tokenizer, model)
  '''
  tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
  load_kwargs = dict(torch_dtype=torch.float16, device_map="auto")
  model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **load_kwargs)
  return tokenizer, model

In [5]:
def replace_modulelist_in_model(model, modulelist: torch.nn.ModuleList):
  '''
  Install `modulelist` as the decoder stack `model.model.layers`.

  Also sets `model.config.num_hidden_layers` to the new stack length so the config
  matches the modules actually installed.
  '''
  model.model.layers = modulelist
  model.config.num_hidden_layers = len(modulelist)

In [6]:
import gc
def remove_model_from_gpu_memory(m):
  '''
  Best-effort release of GPU memory held by a model that is no longer needed.

  Only this function's own reference to `m` is dropped. If the caller still holds the
  model in a variable (e.g. `remove_model_from_gpu_memory(model)` without a following
  `del model`), the model stays alive and nothing is freed. Delete or rebind every
  outside reference as well.
  '''
  ## drop this frame's reference before collecting, so it cannot keep the model alive
  del m
  ## collect unreachable objects first, then hand PyTorch's cached CUDA blocks back to the driver
  gc.collect()
  torch.cuda.empty_cache()

In [7]:
def find_parent(model, name: str) -> torch.nn.Module:
    '''
    Return the module that directly owns the dotted attribute path `name`.

    Example: for name "layers.3.mlp.gate_proj" this returns the module at
    "layers.3.mlp". A name without dots (or an empty name) returns `model` itself.
    Raises KeyError if an intermediate path component is not a registered child.
    '''
    *path_to_parent, _leaf = name.split(".")
    node = model
    for child_name in path_to_parent:
        node = node._modules[child_name]
    return node

In [8]:
def move_to_nearest_even(hidden_dim):
  '''Return `hidden_dim` unchanged if it is even; otherwise round it up by one to the next even number.'''
  is_odd = hidden_dim % 2 != 0
  return hidden_dim + 1 if is_odd else hidden_dim

## dataset

In [None]:
## Inshorts English news dataset; map_func below reads its "Content" and "Headline" fields
ds = load_dataset("shivam9980/Inshorts-english", split="train")

In [None]:
## inspect the raw dataset (features and number of rows)
ds

In [11]:
## hold out 2.5% of the rows as the test split; the fixed seed makes the split reproducible.
## NOTE: this rebinds `ds` (Dataset -> train/test DatasetDict), so re-run the load_dataset cell before re-running this one.
ds = ds.train_test_split(test_size = 0.025, seed = 3407)

In [None]:
## inspect the train/test split sizes
ds

In [13]:
## Prompt templates; `{content}` is the placeholder to fill with str.format(content=...).
## user_prompt is the user-turn text alone; input_prompt wraps it in the <|im_start|>/<|im_end|>
## chat markup (system turn + user turn) and ends by opening the assistant turn.
user_prompt = (
    "Generate a concise news headline based on the following news content. "
    "The headline should clearly and accurately summarize the key point of the article. "
    "Avoid exaggeration or misleading phrasing.\n"
    "\n"
    "News Content: {content}"
)

input_prompt = (
    "<|im_start|>system\n"
    "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\n"
    + user_prompt
    + "<|im_end|>\n"
    "<|im_start|>assistant\n"
)

In [14]:
def map_func(datapoint):
  '''
  Add chat-formatted prompt columns to one Inshorts row (for use with `Dataset.map`).

  Reads `datapoint["Content"]` and `datapoint["Headline"]` and writes:
    - "text":  the full conversation, i.e. the prompt followed by the reference headline and `<|im_end|>`;
    - "input": the prompt alone, ending with the opened assistant turn (what the model sees at generation time).
  Both columns are built from one template, so "text" always starts with exactly "input".
  Returns the same (mutated) dict.
  '''
  ## the single copy of the prompt template; previously it was duplicated for the two columns
  prompt = f'''<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
Generate a concise news headline based on the following news content. The headline should clearly and accurately summarize the key point of the article. Avoid exaggeration or misleading phrasing.

News Content: {datapoint["Content"]}<|im_end|>
<|im_start|>assistant
'''
  ## assign "text" before "input" to keep the original column order
  datapoint["text"] = prompt + f'{datapoint["Headline"]}<|im_end|>'
  datapoint["input"] = prompt
  return datapoint
## add the "text" (prompt + reference headline) and "input" (prompt only) columns to both splits
ds = ds.map(map_func)

In [15]:
## calibration data for width pruning: the first 612 training rows (612 is a fixed constant; consider hoisting it to a config cell)
calibration_dataset = ds["train"].select(range(612))

In [None]:
## inspect the calibration subset
calibration_dataset

In [17]:
## evaluation data for every pruning candidate: only the first 10 held-out rows, so ROUGE differences between candidates may be noisy
downstream_accuracy_evaluator_dataset = ds["test"].select(range(10))

In [None]:
## inspect the evaluation subset
downstream_accuracy_evaluator_dataset

## downstream accuracy

In [19]:
def get_output(model, tokenizer, inputs, sampling_params, max_new_tokens=100):
  '''
  Generate completions for a prompt (or batch of prompts) and return only the newly generated text.

  Args:
    model: causal LM exposing `.device` and `.generate`.
    tokenizer: tokenizer used to batch-encode `inputs` and decode the generated ids.
    inputs: a prompt string or a list of prompt strings.
    sampling_params: extra keyword arguments forwarded to `model.generate`, e.g.
      dict(do_sample=False, temperature=None, top_p=None, top_k=None). A `max_new_tokens`
      entry here overrides the `max_new_tokens` argument instead of raising a duplicate-keyword error.
    max_new_tokens: generation budget used when `sampling_params` does not set one (default 100).

  Returns:
    list[str]: one decoded completion per prompt, with special tokens stripped.
  '''
  ## left padding puts every prompt flush against its generated tokens, so cutting off the
  ## padded prompt length leaves exactly the new tokens for every row
  model_inputs = tokenizer(inputs, padding=True, padding_side='left', return_tensors="pt").to(model.device)
  generate_kwargs = {"max_new_tokens": max_new_tokens, **sampling_params}
  generated_ids = model.generate(**model_inputs, **generate_kwargs)
  prompt_length = model_inputs.input_ids.shape[1]
  new_token_ids = [output_ids[prompt_length:] for output_ids in generated_ids]
  return tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)

In [20]:
import evaluate
## ROUGE metric from Hugging Face `evaluate`; compute_rouge below uses this module-level object
rouge = evaluate.load('rouge')

In [21]:

def compute_rouge(predictions, references):
  '''Score `predictions` against `references` with the module-level `rouge` metric and return its result dict.'''
  return rouge.compute(predictions=predictions, references=references)

def handle_rouge_results(results):
  '''
  Average a list of per-batch ROUGE dicts key by key.

  The keys come from the first dict, and every dict must contain them (KeyError otherwise).
  Each batch counts equally, so the average is not weighted by the number of examples per batch.
  An empty list returns {} instead of raising IndexError.
  '''
  if not results:
    return {}
  num_batches = len(results)
  return {key: sum(result[key] for result in results) / num_batches for key in results[0]}


In [22]:
def run_downstream_accuracy_test(
    model,
    tokenizer,
    dataset: Dataset,
    batch_size: int,
    reference_field: str,
    to_print: bool,
    sampling_params: Dict[str, Any],
    batch_callback: Callable | None = None,
    batch_callback_args: Dict[Any, Any] | None = None
  ):
  '''
  Generate a prediction for every row of `dataset` and score the predictions with ROUGE.

  Args:
    model, tokenizer: passed to `get_output`.
    dataset: rows need an "input" prompt column and the `reference_field` column.
    batch_size: rows per generation call.
    reference_field: column holding the reference text (e.g. "Headline").
    to_print: print every prediction next to its reference.
    sampling_params: keyword arguments forwarded to `model.generate` through `get_output`.
      Values may be bools, None or numbers, hence Dict[str, Any].
    batch_callback: optional function called once after every batch (e.g. to aggregate hook data).
    batch_callback_args: keyword arguments for `batch_callback`.

  Returns:
    Dict[str, float]: per-batch ROUGE scores averaged by `handle_rouge_results`.
  '''
  rouge_results: List[Dict[str, float]] = []
  ## ceiling division, so the progress bar knows how many batches to expect
  num_batches = (len(dataset) + batch_size - 1) // batch_size
  for batch in tqdm(dataset.iter(batch_size), total=num_batches):
    predictions: List[str] = get_output(model, tokenizer, batch["input"], sampling_params)
    references: List[str] = batch[reference_field]
    rouge_results.append(compute_rouge(predictions, references))

    if to_print:
      print("*"*6)
      for pred, refe in zip(predictions, references):
        print("-"*3)
        print(f"pred: {pred}")
        print(f"refe: {refe}")
        print("-"*3)
      print("*"*6)

    if batch_callback:
      batch_callback(**(batch_callback_args or {}))

  return handle_rouge_results(rouge_results)

## depth pruning

In [23]:
def trim_layers(model, layers_index_to_trim: List[int]):
  '''
  Remove decoder layers from `model.model.layers` in place and renumber the survivors.

  Args:
    model: causal LM whose decoder stack is the ModuleList `model.model.layers`.
    layers_index_to_trim: indices into the current stack. Duplicates are ignored, and negative
      indices count from the end, as with Python list indexing.

  Raises:
    IndexError: if any index is out of range. This is checked before anything is removed.
  '''
  encoder_layers = model.model.layers
  num_layers = len(encoder_layers)

  ## normalize and deduplicate first; deleting the same index twice would remove a second, unintended layer
  unique_indices = set()
  for index in layers_index_to_trim:
    if not -num_layers <= index < num_layers:
      raise IndexError(f"layer index {index} out of range for {num_layers} layers")
    unique_indices.add(index % num_layers)

  ## delete from the back so earlier deletions do not shift indices that are still to be removed
  for layer_index_to_trim in sorted(unique_indices, reverse=True):
    del encoder_layers[layer_index_to_trim]

  ## keep the config's layer count equal to the actual stack size
  model.config.num_hidden_layers = len(encoder_layers)

  ## renumber layer_idx on every submodule that has one, so indices stay contiguous 0..n-1
  ## (HF attention modules presumably use layer_idx to index the KV cache; verify for the installed transformers)
  for new_layer_idx, layer in enumerate(encoder_layers):
    for module in layer.modules():
      if hasattr(module, "layer_idx"):
        module.layer_idx = new_layer_idx

In [24]:
def run_depth_pruning_experiments_from_model(model, tokenizer, collection_of_layers_index_to_trim: List[List[int]], batch_size: int, run_original_first: bool=True, also_init_best_model: bool = False):
  '''
  Score depth-pruning candidates on `downstream_accuracy_evaluator_dataset` and pick the best by ROUGE-L.

  Each candidate starts from a fresh deep copy of the model's current decoder layers, with the
  listed layer indices removed. Afterwards the untouched layers are always reinstalled, even if
  an evaluation raises. If `also_init_best_model` is set, the winning candidate is then applied in place.

  Args:
    model: causal LM whose decoder stack is `model.model.layers`.
    tokenizer: tokenizer passed through to generation.
    collection_of_layers_index_to_trim: one list of layer indices per candidate.
    batch_size: evaluation batch size.
    run_original_first: evaluate and print the unpruned model first.
    also_init_best_model: leave the model trimmed to the best candidate on return.

  Returns:
    dict(best_result=...). best_result is {"result": ROUGE dict, "layers_index_to_trim": [...]}
    for the highest-ROUGE-L candidate (the first one wins ties), or None if no candidates were given.
  '''
  greedy_sampling_params = dict(do_sample=False, temperature=None, top_p=None, top_k=None)

  def evaluate_current_model():
    ## greedy decoding on the held-out evaluation subset, printing predictions next to references
    return run_downstream_accuracy_test(model=model, tokenizer=tokenizer, dataset=downstream_accuracy_evaluator_dataset, batch_size=batch_size, reference_field="Headline", to_print=True, sampling_params=greedy_sampling_params)

  ## pristine copy of the current layers; every candidate is cut from a copy of this
  original_decoder_layers_copy = deepcopy(model.model.layers)

  if run_original_first:
    print("#"*3, "ORIGINAL","#"*3)
    model_size = get_num_params(model,in_gb=True)
    print(f"ORIGINAL MODEL SIZE: {model_size}")
    print(evaluate_current_model())
    print("#"*3, "ORIGINAL","#"*3)

  capture_results = []
  try:
    for layers_index_to_trim in collection_of_layers_index_to_trim:
      print("#"*10)
      print(f"layers_index_to_trim: {layers_index_to_trim}")
      replace_modulelist_in_model(model=model, modulelist=deepcopy(original_decoder_layers_copy)) ## fresh copy of the original layers
      trim_layers(model, layers_index_to_trim) ## pruned candidate
      model_size = get_num_params(model,in_gb=True)
      print(f"NEW MODEL SIZE: {model_size}")
      result = evaluate_current_model()
      print(result)
      capture_results.append({"result":result, "layers_index_to_trim":layers_index_to_trim})
      print("#"*10,end ="\n\n")
  finally:
    ## always reinstall the untouched layers, even if an evaluation failed midway
    replace_modulelist_in_model(model=model, modulelist=original_decoder_layers_copy)

  ## max() keeps the first of equal scores; unlike a "> 0.0" threshold it still picks a candidate
  ## when every ROUGE-L is 0, and it yields None only for an empty collection
  best_result = max(capture_results, key=lambda captured: captured["result"]["rougeL"], default=None)

  if also_init_best_model and best_result is not None:
    trim_layers(model, best_result["layers_index_to_trim"]) ## best pruned model

  return dict(best_result=best_result)

## width pruning

In [25]:
def mlp_neuron_importance_hook(module, input, output):
  '''
  Forward hook that records |activation| of every output neuron for every token it sees.

  Shapes (B = batch size, S = tokens in this forward call, out = output features):
    input[0]: (B, S, in)   (unused)
    output:   (B, S, out)

  The absolute, detached rows of the single batch element are appended along the token axis
  to `module.neurons_activation`, which is created on the first call and has shape (tokens_so_far, out).
  Only batch size 1 is supported. The hook returns None, so the module's output is unchanged.
  '''
  batch_size, _, _ = output.size()
  if batch_size != 1:
    raise NotImplementedError("ONLY BATCH SIZE=1 IS IMPLEMENTED")
  token_magnitudes = output.detach().abs()[0] ## (S, out), a new tensor that does not share storage with `output`
  if hasattr(module, "neurons_activation"):
    module.neurons_activation = torch.cat((module.neurons_activation, token_magnitudes), 0)
  else:
    module.neurons_activation = token_magnitudes

In [26]:
def register_all_forward_hooks(model):
  '''
  Attach `mlp_neuron_importance_hook` to the mlp.gate_proj of every Qwen2DecoderLayer in `model`.

  The hook handles are not kept. Calling this twice without removing hooks in between attaches
  the hook twice, so activations would be appended twice per forward pass.
  '''
  gate_projs = [layer.mlp.gate_proj for _, layer in model.named_modules() if isinstance(layer, Qwen2DecoderLayer)]
  for gate_proj in gate_projs:
    gate_proj.register_forward_hook(mlp_neuron_importance_hook)

def remove_all_forward_hooks(model):
  '''
  Clear the forward hooks on the mlp.gate_proj of every Qwen2DecoderLayer in `model`.

  This empties nn.Module's private `_forward_hooks` dict, so it drops every forward hook on
  those Linear layers, not only `mlp_neuron_importance_hook`.
  '''
  gate_projs = [layer.mlp.gate_proj for _, layer in model.named_modules() if isinstance(layer, Qwen2DecoderLayer)]
  for gate_proj in gate_projs:
    gate_proj._forward_hooks.clear()

def remove_all_forward_hooks_stored_info(model):
  '''Delete the `neurons_activation` buffer that `mlp_neuron_importance_hook` leaves on each decoder layer's mlp.gate_proj, where present.'''
  gate_projs = [layer.mlp.gate_proj for _, layer in model.named_modules() if isinstance(layer, Qwen2DecoderLayer)]
  for gate_proj in gate_projs:
    if hasattr(gate_proj, "neurons_activation"):
      del gate_proj.neurons_activation

In [27]:
def _select_linear_features(linear: nn.Linear, indices: torch.Tensor, dim: int) -> nn.Linear:
  '''Return a new nn.Linear that keeps only `indices` of `linear`'s output features (dim=0) or input features (dim=1).'''
  ## advanced indexing returns a copy, so the source layer is left untouched
  weight = linear.weight.data[indices, :] if dim == 0 else linear.weight.data[:, indices]
  out_features, in_features = weight.shape
  has_bias = linear.bias is not None
  ## allocate directly with the source weight's device/dtype (layers may sit on different devices
  ## under device_map="auto"); the random init is overwritten right below
  pruned = nn.Linear(in_features, out_features, bias=has_bias, device=weight.device, dtype=weight.dtype)
  pruned.weight.data = weight
  if has_bias:
    ## the bias has one entry per output feature, so it shrinks only when outputs are pruned
    pruned.bias.data = linear.bias.data[indices] if dim == 0 else linear.bias.data.clone()
  return pruned


def prune_mlp_neuron(model, importance_fn, fraction_to_prune = None, pruned_dim = None):
  '''
  Width-prune the MLP of every Qwen2DecoderLayer in `model`, in place.

  For each layer, `importance_fn(layer)` returns one score per intermediate neuron, shape (out,).
  The highest-scoring neurons are kept: the matching rows of gate_proj/up_proj and columns of down_proj.
  Afterwards `intermediate_size` is updated on every MLP and on `model.model.config`, and every
  submodule that has a `config` attribute is pointed at `model.model.config`.

  Args:
    model: causal LM containing Qwen2DecoderLayer modules.
    importance_fn: callable(Qwen2DecoderLayer) -> 1-D tensor of neuron scores.
    fraction_to_prune: fraction of neurons to drop per layer (0.0 is allowed); the kept count is rounded up to even.
    pruned_dim: exact number of neurons to keep per layer; takes precedence over fraction_to_prune.

  Raises:
    ValueError: if neither option is given, or a layer's keep-count is outside [1, number of neurons].
      All layers are checked before any of them is modified.
  '''
  if pruned_dim is None and fraction_to_prune is None:
    raise ValueError("Both fraction_to_prune and pruned_dim are None. Please specify one.")

  ## phase 1: score every layer and choose the neurons to keep, without modifying anything yet
  plan = []
  for (name, module) in model.named_modules():
    if isinstance(module, Qwen2DecoderLayer):
      importance = importance_fn(module) ## (out,)
      num_neurons = importance.shape[0]
      num_of_most_imp_neurons_to_keep = pruned_dim if pruned_dim is not None else move_to_nearest_even(
          int((1 - fraction_to_prune) * num_neurons)
      )
      ## keeping more neurons than exist would leave weights smaller than the recorded intermediate_size
      if not 0 < num_of_most_imp_neurons_to_keep <= num_neurons:
        raise ValueError(
          f"Cannot keep {num_of_most_imp_neurons_to_keep} neurons in {name}: it has {num_neurons}."
        )
      ## keep the neurons with the largest importance values
      indices = importance.argsort(descending = True)[:num_of_most_imp_neurons_to_keep] ## (out*,)
      plan.append((module, indices, num_of_most_imp_neurons_to_keep))

  if not plan:
    return

  ## phase 2: swap in the pruned projections
  for module, indices, num_of_most_imp_neurons_to_keep in plan:
    mlp = module.mlp
    mlp.gate_proj = _select_linear_features(mlp.gate_proj, indices, dim=0) ## weight (out*, in)
    mlp.up_proj = _select_linear_features(mlp.up_proj, indices, dim=0) ## weight (out*, in)
    mlp.down_proj = _select_linear_features(mlp.down_proj, indices, dim=1) ## weight (hidden, out*)
    mlp.intermediate_size = num_of_most_imp_neurons_to_keep

  ## updating intermediate_size in the model's config (keep-count of the last pruned layer)
  model.model.config.intermediate_size = num_of_most_imp_neurons_to_keep

  ## point every module that carries a config at the updated model config
  for (name, module) in model.named_modules():
    if hasattr(module, "config"):
      module.config = model.model.config

In [28]:
def batch_callback(model):
  '''
  Fold the activations captured for the sequence just generated into one importance row per layer.

  For every Qwen2DecoderLayer, the hook-collected `mlp.gate_proj.neurons_activation`
  (tokens x neurons) is reduced with an L2 norm over the token axis, computed in float32.
  The result is appended as a row to `mlp.gate_proj.calibration_ds_importance` (sequences x neurons).
  The hooks are then removed, their buffers deleted, and fresh hooks attached for the next sequence.
  '''
  decoder_layers = [layer for _, layer in model.named_modules() if isinstance(layer, Qwen2DecoderLayer)]
  for layer in decoder_layers:
    gate_proj = layer.mlp.gate_proj
    ## (tokens, out) -> (out,) -> (1, out): one importance row for this sequence
    sequence_importance = torch.linalg.norm(gate_proj.neurons_activation, ord=2, dim=0, dtype=torch.float32).reshape(1, -1)
    if hasattr(gate_proj, "calibration_ds_importance"):
      gate_proj.calibration_ds_importance = torch.cat((gate_proj.calibration_ds_importance, sequence_importance), dim=0)
    else:
      gate_proj.calibration_ds_importance = sequence_importance

  ## re-arm the hooks with empty activation buffers for the next sequence
  remove_all_forward_hooks(model)
  remove_all_forward_hooks_stored_info(model)
  register_all_forward_hooks(model)

In [29]:
def access_importance(module):
  '''
  Final per-neuron importance of a decoder layer's MLP.

  Takes the L2 norm, computed in float32, over all per-sequence rows of
  `module.mlp.gate_proj.calibration_ds_importance` (sequences x neurons). Returns shape (neurons,).
  '''
  per_sequence_scores = module.mlp.gate_proj.calibration_ds_importance ## (num_sequences, out)
  return torch.linalg.norm(per_sequence_scores, ord=2, dim=0, dtype=torch.float32) ## (out,)

In [30]:
def run_width_pruning_experiment(model, tokenizer, pruned_dim: int|None, batch_size: int, run_original_first: bool=True):
  '''
  Calibrate MLP-neuron importance, width-prune every decoder MLP to `pruned_dim` neurons in place, and evaluate.

  Steps:
    1. optionally evaluate the unpruned model on `downstream_accuracy_evaluator_dataset`;
    2. generate over `calibration_dataset` one example at a time with importance hooks on every
       gate_proj; after each example `batch_callback` folds the captured activations into one row per layer;
    3. keep the `pruned_dim` highest-importance neurons per layer (`prune_mlp_neuron` + `access_importance`);
    4. evaluate the pruned model.

  Args:
    model, tokenizer: the model to prune (modified in place) and its tokenizer.
    pruned_dim: number of intermediate neurons to keep per layer. It is required despite the `None`
      in the annotation, because this function never passes `fraction_to_prune`.
    batch_size: batch size for the evaluation runs. Calibration always uses 1, which the hook requires.
    run_original_first: evaluate and print the unpruned model first.

  Returns:
    dict with "pruned_model_result" (ROUGE dict of the pruned model) and "original_model_result"
    (ROUGE dict of the unpruned model, or None when run_original_first is False).

  Raises:
    ValueError: if `pruned_dim` is None. This is raised up front, before the expensive calibration pass.
  '''
  if pruned_dim is None:
    raise ValueError("pruned_dim must be set: run_width_pruning_experiment does not pass fraction_to_prune.")

  greedy_sampling_params = dict(do_sample=False, temperature=None, top_p=None, top_k=None)

  remove_all_forward_hooks(model)
  remove_all_forward_hooks_stored_info(model)
  ## discard importance rows left by an earlier run that never reached pruning, so they are not mixed into this one
  for (name, module) in model.named_modules():
    if isinstance(module, Qwen2DecoderLayer) and hasattr(module.mlp.gate_proj, "calibration_ds_importance"):
      delattr(module.mlp.gate_proj, "calibration_ds_importance")

  original_model_result = None
  if run_original_first:
    print("#"*3, "ORIGINAL","#"*3)
    model_size = get_num_params(model,in_gb=True)
    print(f"ORIGINAL MODEL SIZE: {model_size}")
    original_model_result = run_downstream_accuracy_test(model=model, tokenizer=tokenizer, dataset=downstream_accuracy_evaluator_dataset, batch_size=batch_size, reference_field="Headline", to_print=True, sampling_params=greedy_sampling_params)
    print(original_model_result)
    print("#"*3, "ORIGINAL","#"*3)

  register_all_forward_hooks(model)
  try:
    ## calibration pass: its ROUGE score is ignored; only the hooks' side effects matter
    run_downstream_accuracy_test(model=model, tokenizer=tokenizer, dataset=calibration_dataset, batch_size=1, reference_field="Headline", to_print=False, sampling_params=greedy_sampling_params, batch_callback=batch_callback, batch_callback_args=dict(model=model))
  finally:
    ## never leave hooks (or their activation buffers) attached, even if calibration fails
    remove_all_forward_hooks(model)
    remove_all_forward_hooks_stored_info(model)

  prune_mlp_neuron(model, fraction_to_prune=None, importance_fn=access_importance, pruned_dim=pruned_dim)
  model_size = get_num_params(model, in_gb=True)
  print(f"NEW MODEL SIZE: {model_size}")

  pruned_model_result = run_downstream_accuracy_test(model=model, tokenizer=tokenizer, dataset=downstream_accuracy_evaluator_dataset, batch_size=batch_size, reference_field="Headline", to_print=True, sampling_params=greedy_sampling_params)

  return dict(pruned_model_result=pruned_model_result, original_model_result=original_model_result)


## hybrid pruning

#### Width pruning first, then depth pruning of the width-pruned model

In [31]:
## load the starting checkpoint in fp16 with device_map="auto" (see model_initialization); judging by its name it is presumably Qwen2.5-0.5B-Instruct fine-tuned on Inshorts (verify on the Hub)
tokenizer, model = model_initialization("nis12ram/qwen2.5-0.5B-Instruct-Inshort")

In [None]:
## step 1, width pruning: keep 4096 MLP neurons per decoder layer (modifies `model` in place); evaluation batch size 5
width_experiment_result = run_width_pruning_experiment(model=model, tokenizer=tokenizer, pruned_dim=4096, batch_size=5 )

In [None]:
## ROUGE scores of the width-pruned model on the evaluation subset
width_experiment_result.get('pruned_model_result')

In [None]:
## every contiguous window of 13 layer indices: [0..12], [1..13], ..., [11..23]
## (the last window ends at index 23, which assumes a 24-layer model; TODO confirm against model.config.num_hidden_layers)
collection_of_layers_index_to_trim = [list(range(start, start + 13)) for start in range(12)]
collection_of_layers_index_to_trim

In [None]:
## step 2, depth pruning of the width-pruned model: evaluate every window, then leave `model` trimmed to the best one by ROUGE-L (also_init_best_model=True)
depth_experiment_result = run_depth_pruning_experiments_from_model(model, tokenizer, collection_of_layers_index_to_trim = collection_of_layers_index_to_trim, batch_size = 5, also_init_best_model= True)

In [None]:
## best layer window and its ROUGE scores
depth_experiment_result.get('best_result')

In [None]:
## inspect the hybrid-pruned architecture and its config (the pruning helpers update num_hidden_layers and intermediate_size)
model, model.config

In [None]:
## save the hybrid-pruned model and its tokenizer to the same directory so it can be reloaded with from_pretrained
hybrid_model_dir = "/content/inshort/models/hybrid7_base"
for artifact in (model, tokenizer):
  artifact.save_pretrained(hybrid_model_dir)

## sample inference test

In [39]:
## reload the saved checkpoint from disk to check that the pruned config and weights load back (rebinds tokenizer and model)
tokenizer, model = model_initialization("/content/inshort/models/hybrid7_base")

In [None]:
## architecture of the reloaded model
model

In [None]:
## greedy sanity check on one prompt
## NOTE: row 100 of calibration_dataset comes from the training split and was seen during calibration, so it is not a held-out example
sample_prompt = calibration_dataset[100]["input"]
output = get_output(model, tokenizer, sample_prompt, dict(do_sample=False, temperature=None, top_p=None, top_k=None))
sample_prompt, output