In [1]:
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from importlib.metadata import version
import torch 
import torch.nn as nn 

# from lib.layerwrapper import WrappedGPT
from lib.data import get_loaders 

from lib.prune_opt import prune_magnitude, prune_sparsegpt, prune_ablate, check_sparsity, find_layers
from dataclasses import dataclass



In [2]:
@dataclass
class Args(object):
    model: str = "facebook/opt-125m"
    seed: int = 0
    nsamples: int = 128
    sparsity_ratio = 0.3
    sparsity_type = "unstructured"
    cache_dir = "llm_weights"
    prune_method = "wanda"
args = Args()

In [3]:
class WrappedGPT:
    """
    This class wraps a GPT layer for specific operations.
    """

    def __init__(self, layer, layer_id=0, layer_name="none"):
        
        self.layer = layer
        
        self.dev = self.layer.weight.device
        self.rows = layer.weight.data.shape[0]
        self.columns = layer.weight.data.shape[1]

        self.scaler_row = torch.zeros((self.columns), device=self.dev)
        print("Scaler row shape (1): ", self.scaler_row.shape)
        print("Scaler row max (1): ", torch.max(self.scaler_row))
        print(" ")
        self.nsamples = 0

        self.layer_id = layer_id 
        self.layer_name = layer_name

    def add_batch(self, inp, out):
        # if self.layer_name == "self_attn.q_proj":
        #     print("self.nsamples: ", self.nsamples)
        #     print("### LAYER NAME ###: ", self.layer_name)
        # print("inp shape: ", inp.shape)

        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        
        tmp = inp.shape[0]
        
        if tmp != 1:
            print("#### TMP ####: ", tmp)

        if isinstance(self.layer, nn.Linear):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
                # print("Reshaped inp: ", inp.shape)
            inp = inp.t()

        self.scaler_row *= self.nsamples / (self.nsamples+tmp)

        # print("Scaler row shape (2): ", self.scaler_row.shape)
        # if self.layer_name == "self_attn.q_proj":
        #     print("Scaler row norm (2): ", torch.norm(self.scaler_row, p = 2))
        
        self.nsamples += tmp
        

        inp = inp.type(torch.float32)

        self.scaler_row += torch.norm(inp, p=2, dim=1) ** 2  / self.nsamples

        # print("Scaler row shape (3): ", self.scaler_row.shape)
        # if self.layer_name == "self_attn.q_proj":
        #     print("Scaler row norm (3): ", torch.norm(self.scaler_row, p = 2))
        #     print(" ")

In [4]:
def get_llm(model_name, cache_dir="llm_weights"):
    model = AutoModelForCausalLM.from_pretrained(
        model_name, 
        torch_dtype=torch.float16, 
        cache_dir=cache_dir, 
        low_cpu_mem_usage=True, 
        device_map="auto"
    )

    model.seqlen = model.config.max_position_embeddings 
    return model

In [5]:
def prepare_calibration_input(model, dataloader, device):

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.decoder.layers
    
    # print("model: ", model)
    # print("layers type: ", type(layers))
    # print("layers: ", layers)
    # print("layers[0] type: ", type(layers[0]))
    # print("layers[0]: ", layers[0])
    
    if "model.embed_tokens" in model.hf_device_map:
        device = model.hf_device_map["model.embed_tokens"]

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros((128, model.seqlen, model.config.hidden_size), dtype=dtype, device=device)
    inps.requires_grad = False
    
    cache = {'i': 0, 'attention_mask': None, "position_ids": None}

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

            # print("module type: ", type(module))
            # print("module: ", module)

        def forward(self, inp, **kwargs):

            # print("inp type: ", type(inp))
            # print("inp shape: ", inp.shape)
            # print("cache['i']: ", cache['i'])
            # print("kwargs keys: ", list(kwargs.keys()))
            inps[cache['i']] = inp

            cache['i'] += 1
            cache['attention_mask'] = kwargs['attention_mask']

            raise ValueError
        
    layers[0] = Catcher(layers[0])
    count = 0
    
    for batch in dataloader:
        count += 1
        try:

            # print("batch[1] dtype: ", batch[1].dtype)
            # print("batch[1] shape: ", batch[1].shape)
            # print("batch[1] max: ", torch.max(batch[1]))
            
            # print("batch[0] dtype: ", batch[0].dtype)
            # print("batch[0] shape: ", batch[0].shape)
            # print("batch[0] max: ", torch.max(batch[0]))

            model(batch[0].to(device))

        except ValueError:
            pass
        # print(count)    

    layers[0] = layers[0].module

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']
    model.config.use_cache = use_cache

    return inps, outs, attention_mask


In [6]:
def prune_wanda(args, model, tokenizer, device=torch.device("cuda:0"), prune_n=0, prune_m=0):
    
    use_cache = model.config.use_cache 
    model.config.use_cache = False 
    
    print("loading calibdation data")

    dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer)

    print("dataset loading complete")

    with torch.no_grad():
        inps, outs, attention_mask = prepare_calibration_input(model, dataloader, device)
    print("att mask: ", attention_mask)
    print("att mask type: ", type(attention_mask))
    print("att mask type: ", attention_mask.shape)
    layers = model.model.decoder.layers
    
    for i in range(len(layers)):

        layer = layers[i]
        subset = find_layers(layer)

        if f"model.layers.{i}" in model.hf_device_map:   ## handle the case for llama-30B and llama-65B, when the device map has multiple GPUs;
            dev = model.hf_device_map[f"model.layers.{i}"]
            inps, outs, attention_mask = inps.to(dev), outs.to(dev), attention_mask.to(dev)

        wrapped_layers = {}

        
        for name in subset:
            print("name :", name)
            wrapped_layers[name] = WrappedGPT(subset[name], layer_name = name)

        def add_batch(name):
            def tmp(_, inp, out):
                wrapped_layers[name].add_batch(inp[0].data, out.data)
            return tmp

        handles = []
        
        # why do we need this second loop?
        # it could be absorbed into the first loop
        
        for name in wrapped_layers:
            handles.append(subset[name].register_forward_hook(add_batch(name)))
            
        for j in range(args.nsamples):
            with torch.no_grad():
                outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
        for h in handles:
            h.remove()

        for name in subset:
            print(f"pruning layer {i} name {name}")
            W_metric = torch.abs(subset[name].weight.data) * torch.sqrt(wrapped_layers[name].scaler_row.reshape((1,-1)))

            W_mask = (torch.zeros_like(W_metric) == 1)  ## initialize a mask to be all False
            if prune_n != 0:
                # structured n:m sparsity
                for ii in range(W_metric.shape[1]):
                    if ii % prune_m == 0:
                        tmp = W_metric[:,ii:(ii+prune_m)].float()
                        W_mask.scatter_(1,ii+torch.topk(tmp, prune_n,dim=1, largest=False)[1], True)
            else:
                sort_res = torch.sort(W_metric, dim=-1, stable=True)

                # unstructured pruning
                indices = sort_res[1][:,:int(W_metric.shape[1]*args.sparsity_ratio)]
                W_mask.scatter_(1, indices, True)

            subset[name].weight.data[W_mask] = 0  ## set weights to zero 

        for j in range(args.nsamples):
            with torch.no_grad():
                outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
        inps, outs = outs, inps

    model.config.use_cache = use_cache 
    torch.cuda.empty_cache()

In [7]:
# Setting seeds for reproducibility

np.random.seed(args.seed)
torch.random.manual_seed(args.seed)

# Handling n:m sparsity

prune_n, prune_m = 0, 0
model_name = args.model.split("/")[-1]

print(f"loading llm model {args.model}")

model = get_llm(args.model, args.cache_dir)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
device = torch.device("cuda:0")


loading llm model facebook/opt-125m


In [8]:
def eval_ppl_wikitext(model, testenc, bs=1, device=None):
    # Get input IDs
    testenc = testenc.input_ids

    # Calculate number of samples
    nsamples = testenc.numel() // model.seqlen

    # List to store negative log likelihoods
    nlls = []
    print(f"nsamples {nsamples}")

    # Loop through each batch
    for i in range(0,nsamples,bs):
        if i % 50 == 0:
            print(f"sample {i}")

        # Calculate end index
        j = min(i+bs, nsamples)

        # Prepare inputs and move to device
        inputs = testenc[:,(i * model.seqlen):(j * model.seqlen)].to(device)
        inputs = inputs.reshape(j-i, model.seqlen)

        # Forward pass through the model
        lm_logits = model(inputs).logits

        # Shift logits and labels for next token prediction
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = inputs[:, 1:]

        # Compute loss
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))

        # Calculate negative log likelihood
        neg_log_likelihood = loss.float() * model.seqlen * (j-i)

        # Append to list of negative log likelihoods
        nlls.append(neg_log_likelihood)

    # Compute perplexity
    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))

    # Empty CUDA cache to save memory
    torch.cuda.empty_cache()

    return ppl.item()

In [9]:
def eval_ppl(args, model, tokenizer, dataset = "c4", device=torch.device("cuda:0")):

    # Print status
    print(f"evaluating on {dataset}")

    # Get the test loader
    _, testloader = get_loaders(
        dataset, seed=0, seqlen=model.seqlen, tokenizer=tokenizer 
    )

    # Evaluate ppl in no grad context to avoid updating the model
    with torch.no_grad():
        ppl_test = eval_ppl_wikitext(model, testloader, 1, device)
    return ppl_test 

In [10]:
eval_ppl(args, model, tokenizer)

evaluating on c4


Using custom data configuration allenai--c4-ec45c889631c3c39
Reusing dataset json (C:\Users\igor-\.cache\huggingface\datasets\json\allenai--c4-ec45c889631c3c39\0.0.0\c90812beea906fcffe0d5e3bb9eba909a80a998b5f88e9f8acbd320aa91acfde)
Using custom data configuration allenai--c4-7700d5d1c53cf32f
Reusing dataset json (C:\Users\igor-\.cache\huggingface\datasets\json\allenai--c4-7700d5d1c53cf32f\0.0.0\c90812beea906fcffe0d5e3bb9eba909a80a998b5f88e9f8acbd320aa91acfde)


nsamples 256
sample 0
sample 50
sample 100
sample 150
sample 200
sample 250


26.56388282775879

In [11]:
eval_ppl(args, model, tokenizer)

evaluating on c4


Using custom data configuration allenai--c4-ec45c889631c3c39
Reusing dataset json (C:\Users\igor-\.cache\huggingface\datasets\json\allenai--c4-ec45c889631c3c39\0.0.0\c90812beea906fcffe0d5e3bb9eba909a80a998b5f88e9f8acbd320aa91acfde)
Using custom data configuration allenai--c4-7700d5d1c53cf32f
Reusing dataset json (C:\Users\igor-\.cache\huggingface\datasets\json\allenai--c4-7700d5d1c53cf32f\0.0.0\c90812beea906fcffe0d5e3bb9eba909a80a998b5f88e9f8acbd320aa91acfde)


nsamples 256
sample 0
sample 50
sample 100
sample 150
sample 200
sample 250


26.56388282775879

In [12]:

print("use device ", device)

if args.sparsity_ratio != 0:
    print("pruning starts")
    if args.prune_method == "wanda":
        prune_wanda(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)
    elif args.prune_method == "magnitude":
        prune_magnitude(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)
    elif args.prune_method == "sparsegpt":
        prune_sparsegpt(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)
    elif "ablate" in args.prune_method:
        prune_ablate(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)

################################################################

print("*"*30)
sparsity_ratio = check_sparsity(model)
print(f"sparsity sanity check {sparsity_ratio:.4f}")
print("*"*30)


use device  cuda:0
pruning starts
loading calibdation data


Using custom data configuration allenai--c4-ec45c889631c3c39
Reusing dataset json (C:\Users\igor-\.cache\huggingface\datasets\json\allenai--c4-ec45c889631c3c39\0.0.0\c90812beea906fcffe0d5e3bb9eba909a80a998b5f88e9f8acbd320aa91acfde)
Using custom data configuration allenai--c4-7700d5d1c53cf32f
Reusing dataset json (C:\Users\igor-\.cache\huggingface\datasets\json\allenai--c4-7700d5d1c53cf32f\0.0.0\c90812beea906fcffe0d5e3bb9eba909a80a998b5f88e9f8acbd320aa91acfde)


dataset loading complete
att mask:  tensor([[[[     0., -65504., -65504.,  ..., -65504., -65504., -65504.],
          [     0.,      0., -65504.,  ..., -65504., -65504., -65504.],
          [     0.,      0.,      0.,  ..., -65504., -65504., -65504.],
          ...,
          [     0.,      0.,      0.,  ...,      0., -65504., -65504.],
          [     0.,      0.,      0.,  ...,      0.,      0., -65504.],
          [     0.,      0.,      0.,  ...,      0.,      0.,      0.]]]],
       device='cuda:0', dtype=torch.float16)
att mask type:  <class 'torch.Tensor'>
att mask type:  torch.Size([1, 1, 2048, 2048])
name : self_attn.k_proj
Scaler row shape (1):  torch.Size([768])
Scaler row max (1):  tensor(0., device='cuda:0')
 
name : self_attn.v_proj
Scaler row shape (1):  torch.Size([768])
Scaler row max (1):  tensor(0., device='cuda:0')
 
name : self_attn.q_proj
Scaler row shape (1):  torch.Size([768])
Scaler row max (1):  tensor(0., device='cuda:0')
 
name : self_attn.out_proj
Scaler ro

In [13]:
eval_ppl(args, model, tokenizer)

evaluating on c4


Using custom data configuration allenai--c4-ec45c889631c3c39
Reusing dataset json (C:\Users\igor-\.cache\huggingface\datasets\json\allenai--c4-ec45c889631c3c39\0.0.0\c90812beea906fcffe0d5e3bb9eba909a80a998b5f88e9f8acbd320aa91acfde)
Using custom data configuration allenai--c4-7700d5d1c53cf32f
Reusing dataset json (C:\Users\igor-\.cache\huggingface\datasets\json\allenai--c4-7700d5d1c53cf32f\0.0.0\c90812beea906fcffe0d5e3bb9eba909a80a998b5f88e9f8acbd320aa91acfde)


nsamples 256
sample 0
sample 50
sample 100
sample 150
sample 200
sample 250


27.570693969726562

In [14]:
26.56388282775879 # baseline unpruned
26.56388282775879 # loop absorption, unpruned
26.56388282775879 # scaler_row commented, unpruned


27.570693969726562 # baseline pruned
27.570693969726562 # loop absorption, pruned
4905.3291015625 # scaler_row commented, pruned


4905.3291015625