<a href="https://colab.research.google.com/github/ibacaraujo/How-to-Generate-Music/blob/master/research_welore.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Research: rank reduction of LLM weights, based on the WeLore reference code (for study and experimentation).

## Install libraries.

In [None]:
!pip install loguru
!pip install wandb
!pip install datasets

## Import libraries.

In [5]:
import pandas as pd
import numpy as np

In [6]:
import os
import torch
import argparse
import numpy as np

import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM

from importlib.metadata import version

from timeit import default_timer as timer
from datetime import timedelta

# Silence transformers' info/warning messages; only errors are reported.
transformers.logging.set_verbosity_error()

In [7]:
from loguru import logger
import wandb

### utils.

Inlined equivalent of `from utils import *`.

In [8]:
import pickle
import torch
import numpy as np
import random

# Answer letters for the multiple-choice options, in column order.
choices = list("ABCD")


def save_dict(item, filename):
    """Pickle ``item`` to ``filename`` with the highest available protocol.

    Despite the name, any picklable object is accepted. An existing file is
    overwritten.
    """
    proto = pickle.HIGHEST_PROTOCOL
    with open(filename, mode='wb') as fh:
        pickle.dump(item, fh, protocol=proto)

def format_subject(subject):
    """Turn an underscore-separated subject id into space-separated words.

    Every word, including the first, is preceded by one space, e.g.
    ``"high_school_math" -> " high school math"``.
    """
    return "".join(" " + word for word in subject.split("_"))

def shuffleDict(d):
    """Return a new dict with the same items as ``d`` in a random order.

    The key list is shuffled three times in a row. Only the final order
    matters, but all three ``random.shuffle`` calls are kept so the global
    ``random`` stream is consumed exactly as before and seeded runs stay
    reproducible. (The discarded list comprehensions that used to sit
    between the shuffles had no effect and were removed.)
    """
    keys = list(d)
    for _ in range(3):
        random.shuffle(keys)
    return {key: d[key] for key in keys}

def fix_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Also sets ``torch.backends.cudnn.deterministic`` so cuDNN uses
    deterministic algorithms.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic = True


def format_example(df, idx, include_answer=True):
    """Render row ``idx`` of an MMLU-style frame as a multiple-choice prompt.

    Column 0 holds the question and the last column the answer letter. Every
    column in between is one option, labelled with the matching letter from
    the module-level ``choices`` list.

    Args:
        df: pandas DataFrame laid out as described above.
        idx: positional row index.
        include_answer: append the gold answer followed by a blank line.

    Returns:
        The prompt text ending in ``"\\nAnswer:"``, plus ``" <answer>\\n\\n"``
        when ``include_answer`` is true.
    """
    n_options = df.shape[1] - 2
    lines = [df.iloc[idx, 0]]
    lines.extend(
        "{}. {}".format(choices[col - 1], df.iloc[idx, col])
        for col in range(1, n_options + 1)
    )
    prompt = "\n".join(lines) + "\nAnswer:"
    if include_answer:
        prompt += " {}\n\n".format(df.iloc[idx, n_options + 1])
    return prompt


def gen_prompt(train_df, subject, k=-1):
    """Build a few-shot prefix for ``subject``.

    The prefix is a header line naming the subject, followed by the first
    ``k`` rows of ``train_df`` rendered with their answers. ``k == -1``
    means "use every row of ``train_df``".
    """
    n_shots = train_df.shape[0] if k == -1 else k
    header = "The following are multiple choice questions (with answers) about {}.\n\n".format(
        format_subject(subject)
    )
    shots = "".join(format_example(train_df, row) for row in range(n_shots))
    return header + shots


@torch.no_grad()
def eval(args, subject, model, tokenizer, dev_df, test_df, f):
    """Score a causal LM on one MMLU-style subject by one-token greedy generation.

    For each row of ``test_df`` a few-shot prompt is built from ``dev_df``
    using ``args.ntrain`` shots (``-1`` means all rows). While the tokenized
    prompt is longer than 2048 tokens, shots are dropped one at a time, down
    to zero shots. The model then generates one token, and the last
    character of the decoded output is compared with the gold answer.

    Note: this shadows the builtin ``eval`` in this module; the name is kept
    for existing callers.

    Args:
        args: namespace providing ``ntrain``.
        subject: subject id, used in the prompt header and the log line.
        model, tokenizer: Hugging Face causal LM and tokenizer. Input ids are
            moved with ``.cuda()``.
        dev_df, test_df: few-shot source frame and evaluation frame.
        f: writable text stream that receives the accuracy line.

    Returns:
        ``(cors, acc, all_probs)``: boolean array of per-question hits, the
        mean accuracy, and an always-empty array kept for interface
        compatibility.
    """
    max_tokens = 2048
    cors = []
    all_probs = []

    for i in range(test_df.shape[0]):
        prompt_end = format_example(test_df, i, include_answer=False)
        # Resolve "-1 = all rows" to a concrete count, so the shrink loop
        # below counts down from a real number of shots.
        k = dev_df.shape[0] if args.ntrain == -1 else args.ntrain
        prompt = gen_prompt(dev_df, subject, k) + prompt_end
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()

        # Drop shots until the prompt fits, stopping at zero shots. The old
        # loop kept decrementing k past 0 (gen_prompt treats -1 as "all
        # rows" and other negatives as "none"). It never terminated when
        # the bare question alone exceeded the limit.
        while input_ids.shape[-1] > max_tokens and k > 0:
            k -= 1
            prompt = gen_prompt(dev_df, subject, k) + prompt_end
            input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()

        label = test_df.iloc[i, test_df.shape[1] - 1]

        generate_ids = model.generate(input_ids, max_length=len(input_ids[0]) + 1)
        output = tokenizer.batch_decode(
            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        # The decoded text is prompt + one new token; its last character is
        # taken as the predicted answer letter.
        pred = output[-1:]
        print(label, pred)

        cors.append(pred == label)

    acc = np.mean(cors)
    cors = np.array(cors)

    all_probs = np.array(all_probs)
    print("Average accuracy {:.3f} - {}".format(acc, subject), file=f)
    f.flush()

    return cors, acc, all_probs


def uniform_rank_pruning(args, pruning_ratio, layers_singular_value, logger):
    """Plan to prune the same fraction of rank from every matrix.

    Args:
        args: unused; kept for a uniform signature with the other strategies.
        pruning_ratio: fraction of each matrix's singular values to drop.
        layers_singular_value: ``{layer_index: {module_name: singular values}}``
            with integer layer indices ``0..n-1``.
        logger: object with an ``info`` method.

    Returns:
        ``{layer_index: {module_name: ranks_to_prune}}``, where each count is
        ``int(pruning_ratio * len(singular_values))``.
    """
    total_rank, pruned_rank = 0, 0
    rank_pruning = {}
    for index in range(len(layers_singular_value)):
        rank_pruning[index] = {}
        for name, singular_values in layers_singular_value[index].items():
            # Only the length is needed; skip the clone/CPU/NumPy round trip.
            n = len(singular_values)
            rank_pruning[index][name] = int(pruning_ratio * n)
            total_rank += n
            pruned_rank += rank_pruning[index][name]
    # An empty input used to raise ZeroDivisionError here.
    reduction = (pruned_rank / total_rank) * 100 if total_rank else 0.0
    logger.info(f"Attempted Rank Reduction: {reduction:.3f} %")
    return rank_pruning

def adaptive_rank_pruning(args, pruning_ratio, layers_singular_value, logger):
    """Plan, per matrix, to prune the singular values that are small after min-max scaling.

    Each matrix's singular values are rescaled to [0, 1]. Every value below
    ``args.rank_thresold`` counts as one rank to prune.

    Args:
        args: namespace providing ``rank_thresold`` (threshold on the
            normalized scale).
        pruning_ratio: unused; kept for a uniform signature.
        layers_singular_value: ``{layer_index: {module_name: singular values}}``
            with integer layer indices ``0..n-1``.
        logger: object with an ``info`` method.

    Returns:
        ``{layer_index: {module_name: int ranks_to_prune}}``.
    """
    logger.info(f"Using the mean threolding\nsum(_data < args.rank_thresold = {args.rank_thresold})\n\n")
    total_rank, pruned_rank = 0, 0
    rank_pruning = {}
    for index in range(len(layers_singular_value)):
        rank_pruning[index] = {}
        for name, singular_values in layers_singular_value[index].items():
            data = singular_values.detach().cpu().numpy()
            lo, hi = data.min(), data.max()
            if hi > lo:
                normalized = (data - lo) / (hi - lo)
                n_pruned = int(np.count_nonzero(normalized < args.rank_thresold))
            else:
                # All values equal (or a single value): scaling is 0/0. The
                # old code produced NaNs here and so pruned nothing; keep that
                # result without the NaN warnings.
                n_pruned = 0
            rank_pruning[index][name] = n_pruned
            total_rank += len(data)
            pruned_rank += n_pruned
    reduction = (pruned_rank / total_rank) * 100 if total_rank else 0.0
    logger.info(f"Attempted Rank Reduction: {reduction:.3f} %")
    return rank_pruning

def uniform_rank_pruning_exp2(args, pruning_ratio, layers_singular_value, file_name, prune_layers=None):
    """Uniform rank pruning restricted to a chosen set of layers.

    In layers listed in ``prune_layers`` each matrix loses
    ``int(pruning_ratio * rank)`` ranks. All other layers are left
    untouched. Per-matrix and total reductions are printed to ``file_name``.

    Args:
        args: unused; kept for a uniform signature.
        pruning_ratio: fraction of singular values to drop in selected layers.
        layers_singular_value: ``{layer_index: {module_name: singular values}}``
            with integer layer indices ``0..n-1``.
        file_name: writable text stream (used with ``print(file=...)``).
        prune_layers: layer indices to prune. Defaults to the experiment's
            original choice ``[15, 22, 25, 27]``.

    Returns:
        ``{layer_index: {module_name: ranks_to_prune}}``.
    """
    if prune_layers is None:
        prune_layers = [15, 22, 25, 27]
    selected = set(prune_layers)

    total_rank, pruned_rank = 0, 0
    rank_pruning = {}
    for index in range(len(layers_singular_value)):
        rank_pruning[index] = {}
        for name, singular_values in layers_singular_value[index].items():
            # Only the length is needed; skip the clone/CPU/NumPy round trip.
            n = len(singular_values)
            rank_pruning[index][name] = int(pruning_ratio * n) if index in selected else 0
            total_rank += n
            pruned_rank += rank_pruning[index][name]
            percent = (rank_pruning[index][name] / n) * 100 if n else 0.0
            print(f"layer{index}.{name} rank reduction: \t\t{percent:.3f} %", file=file_name, flush=True)
    total_percent = (pruned_rank / total_rank) * 100 if total_rank else 0.0
    print(f"Rank Reduction: {total_percent:.3f} %", file=file_name, flush=True)
    return rank_pruning

def weight_thresold_rank_pruning(args, layers_singular_value, file_name):
    """
    Given a rank thresold, normalize the singular values and prune each layer under the rank_thresold.

    Each matrix's singular values are min-max scaled to [0, 1]. Every value
    below ``args.rank_thresold`` counts as one rank to prune. Per-matrix
    and total reductions are printed to ``file_name`` (a writable stream).

    Returns:
        ``{layer_index: {module_name: int ranks_to_prune}}``.
    """
    print(f"Using the mean threolding\nsum(_data < args.rank_thresold = {args.rank_thresold})\n\n", file=file_name, flush=True)
    total_rank, pruned_rank = 0, 0
    rank_pruning = {}
    for index in range(len(layers_singular_value)):
        rank_pruning[index] = {}
        for name, singular_values in layers_singular_value[index].items():
            data = singular_values.detach().cpu().numpy()
            lo, hi = data.min(), data.max()
            if hi > lo:
                normalized = (data - lo) / (hi - lo)
                n_pruned = int(np.count_nonzero(normalized < args.rank_thresold))
            else:
                # Constant singular values make the scaling 0/0. The old code
                # then produced NaNs and pruned nothing; keep that result.
                n_pruned = 0
            rank_pruning[index][name] = n_pruned
            total_rank += len(data)
            pruned_rank += n_pruned
            percent = (n_pruned / len(data)) * 100
            print(f"layer{index}.{name} rank reduction: \t\t{percent:.3f} %", file=file_name, flush=True)
    total_percent = (pruned_rank / total_rank) * 100 if total_rank else 0.0
    print(f"\n\n Total Rank Reduction: {total_percent:.3f} %", file=file_name, flush=True)
    return rank_pruning

### lib.

```
from lib.rank_reduction import do_rank_reduction
from lib.rank_utils import rank_analysis_weight

from lib.eval import eval_ppl
```



#### data utils.

In [9]:
import numpy as np
import random
import torch
from datasets import load_dataset



# Set seed for reproducibility
def set_seed(seed):
    """Seed NumPy's global RNG and PyTorch's default generator.

    Python's ``random`` module is not seeded here.
    """
    for seeder in (np.random.seed, torch.random.manual_seed):
        seeder(seed)

# Wrapper for tokenized input IDs
class TokenizerWrapper:
    """Minimal holder exposing a token-id tensor as ``.input_ids``.

    Lets an already-sliced id tensor be passed wherever tokenizer output
    with an ``input_ids`` attribute is expected.
    """

    def __init__(self, input_ids):
        # Stored as-is: no copy, no validation.
        self.input_ids = input_ids

def get_c4(nsamples, seed, seqlen, tokenizer):
    """Draw C4 calibration windows and build a fixed validation slice.

    Args:
        nsamples: number of random training windows to draw.
        seed: seed for Python's ``random`` (document and offset choice).
        seqlen: window length in tokens.
        tokenizer: Hugging Face tokenizer returning ``input_ids`` tensors.

    Returns:
        ``(trainloader, valenc)``:
        - ``trainloader``: list of ``(inp, tar)`` pairs sliced to ``seqlen``
          tokens. ``tar`` is a copy of ``inp`` with every position except the
          last set to -100.
        - ``valenc``: a ``TokenizerWrapper`` around the first
          ``256 * seqlen`` tokens of the first 1100 validation documents,
          joined with spaces.
    """
    # The previous call passed the 'allenai--c4' config name and
    # ignore_verifications=True. Recent `datasets` releases no longer accept
    # either (ignore_verifications was removed in 3.0). Selecting the shards
    # through data_files + split reads the same files.
    traindata = load_dataset('allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')
    valdata = load_dataset('allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')

    # Generate samples from training set
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        # Need strictly more than seqlen tokens. The offset bound below is
        # len - seqlen - 1, which is -1 (randint ValueError) for a document
        # of exactly seqlen tokens.
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
            if trainenc.input_ids.shape[1] > seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    # Prepare validation dataset
    valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
    valenc = valenc.input_ids[:, :(256 * seqlen)]
    valenc = TokenizerWrapper(valenc)
    return trainloader, valenc

# Load and process wikitext2 dataset
def get_wikitext2(nsamples, seed, seqlen, tokenizer):
    """Draw WikiText-2 (raw) calibration windows and tokenize the test split.

    The training text is joined with single spaces and the test text with
    blank lines. Each is tokenized in one pass.

    Returns:
        ``(trainloader, testenc)``:
        - ``trainloader``: ``nsamples`` random ``(inp, tar)`` windows of
          ``seqlen`` tokens; ``tar`` has every position except the last set
          to -100.
        - ``testenc``: the full tokenizer output for the test split.
    """
    train_split = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    test_split = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

    train_ids = tokenizer(" ".join(train_split['text']), return_tensors='pt').input_ids
    testenc = tokenizer("\n\n".join(test_split['text']), return_tensors='pt')

    random.seed(seed)
    max_start = train_ids.shape[1] - seqlen - 1

    def _window():
        # One random window; the target keeps only the final token.
        start = random.randint(0, max_start)
        inp = train_ids[:, start:start + seqlen]
        tar = inp.clone()
        tar[:, :-1] = -100
        return inp, tar

    trainloader = [_window() for _ in range(nsamples)]
    return trainloader, testenc


def shuffleDict(d):
    """Return a new dict with the same items as ``d`` in a random order.

    The key list is shuffled three times in a row. Only the final order
    matters, but all three ``random.shuffle`` calls are kept so the global
    ``random`` stream is consumed exactly as before and seeded runs stay
    reproducible. (The discarded list comprehensions that used to sit
    between the shuffles had no effect and were removed.)
    """
    keys = list(d)
    for _ in range(3):
        random.shuffle(keys)
    return {key: d[key] for key in keys}

def fix_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Also sets ``torch.backends.cudnn.deterministic`` so cuDNN uses
    deterministic algorithms.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic = True

#### LowRankLayer.

In [10]:
import torch
import torch.nn as nn

class LowRankLayer(nn.Module):
    """Rank-``desired_rank`` factorization of a dense weight, built via SVD.

    ``weight`` (shape ``(rows, cols)``) is approximated by
    ``U_k diag(S_k) V_k^T``. The singular values are split evenly between
    the two factors (``sqrt(S_k)`` each), and both factors are stored in
    bfloat16 as bias-free linear maps:
    - ``V``: ``cols -> desired_rank``
    - ``U``: ``desired_rank -> rows``

    ``forward`` casts its input to bfloat16. ``require_grad == False``
    freezes both factors; any other value leaves them trainable.
    """

    def __init__(self, desired_rank, weight, require_grad=True):
        super().__init__()
        self.rank = desired_rank

        left, sigma, right = torch.svd(weight)
        left = left[:, :desired_rank]
        sigma = sigma[:desired_rank]
        right = right[:, :desired_rank]
        root = sigma.sqrt()

        self.U = nn.Linear(desired_rank, left.shape[0], bias=False).to(weight.device)
        self.V = nn.Linear(right.shape[0], desired_rank, bias=False).to(weight.device)

        self.U.weight.data = left.mul(root).to(torch.bfloat16).contiguous()
        self.V.weight.data = right.t().mul(root.view(-1, 1)).to(torch.bfloat16).contiguous()

        trainable = False if require_grad == False else True
        for factor in (self.U, self.V):
            factor.weight.requires_grad = trainable

    def forward(self, x):
        hidden = self.V(x.to(torch.bfloat16))
        return self.U(hidden)


class LowRankLayerEval(nn.Module):
    """Shape-only two-factor low-rank layer (no SVD is performed).

    ``weight`` is used only for its shape and device. Two bias-free
    bfloat16 linear maps are created and keep ``nn.Linear``'s default
    initialization:
    - ``V``: ``weight.shape[1] -> desired_rank``
    - ``U``: ``desired_rank -> weight.shape[0]``

    Presumably the factors are filled in later from saved values; verify
    against the loading code. ``forward`` casts its input to bfloat16.
    """

    def __init__(self, desired_rank, weight, require_grad=True):
        super().__init__()
        self.rank = desired_rank
        rows, cols = weight.shape[0], weight.shape[1]
        self.U = nn.Linear(desired_rank, rows, bias=False, dtype=torch.bfloat16).to(weight.device)
        self.V = nn.Linear(cols, desired_rank, bias=False, dtype=torch.bfloat16).to(weight.device)

        trainable = False if require_grad == False else True
        for factor in (self.U, self.V):
            factor.weight.requires_grad = trainable

    def forward(self, x):
        hidden = self.V(x.to(torch.bfloat16))
        return self.U(hidden)


#### rank reduction.

In [11]:
import time
import heapq
import torch
import torch.nn as nn
#from .data_utils import get_c4, get_wikitext2
#from .LowRankLayer import LowRankLayer, LowRankLayerEval
from tqdm import tqdm
import numpy as np

def get_loaders(name, nsamples=128, seed=0, seqlen=2048, tokenizer=None):
    """Dispatch to the loader whose key occurs in ``name``.

    ``'wikitext2'`` is checked before ``'c4'``. Both loaders return
    ``(trainloader, eval_data)``.

    Raises:
        ValueError: if ``name`` matches neither dataset. Previously ``None``
            was returned, which only failed later at the caller's tuple
            unpacking.
    """
    if 'wikitext2' in name:
        return get_wikitext2(nsamples, seed, seqlen, tokenizer)
    if "c4" in name:
        return get_c4(nsamples, seed, seqlen, tokenizer)
    raise ValueError(f"Unknown dataset {name!r}; expected a name containing 'wikitext2' or 'c4'.")

def find_layers(module, layers=None, name=''):
    """
    Recursively find the layers of a certain type in a module.

    Matching uses the exact type (``type(m) in layers``), so subclasses do
    not match. When ``module`` itself matches, its children are not searched.

    Args:
        module (nn.Module): PyTorch module.
        layers (list): Layer types to find; defaults to ``[nn.Linear]``.
            (``None`` default replaces the former shared mutable list default.)
        name (str): Dotted name of ``module``, used as the result prefix.

    Returns:
        dict: ``{dotted_name: module}`` for every match.
    """
    if layers is None:
        layers = [nn.Linear]
    if type(module) in layers:
        return {name: module}
    res = {}
    for child_name, child in module.named_children():
        full_name = f"{name}.{child_name}" if name != '' else child_name
        res.update(find_layers(child, layers=layers, name=full_name))
    return res

def do_rank_reduction(args, model, tokenizer, rank_pruning, min_ratio, logger = None, load_only = False):
    """Swap heavily pruned attention/MLP projections for low-rank layers.

    Every submodule whose name contains ``"proj"`` under ``layer.self_attn``
    and ``layer.mlp`` is checked. If its planned pruning fraction
    ``rank_pruning[i][name] / rank`` is strictly above ``min_ratio``, it is
    replaced by:
    - a ``LowRankLayer`` (SVD of the current weight), or
    - a shape-only ``LowRankLayerEval`` when ``load_only`` is not ``False``.
    Otherwise the dense weight is kept and frozen.

    Args:
        args, tokenizer: unused; kept for the caller's signature.
        model: causal LM exposing ``model.model.layers`` and ``model.config``.
        rank_pruning: ``{layer_index: {"self_attn.<key>" / "mlp.<key>": ranks_to_drop}}``.
        min_ratio: only projections pruned by more than this fraction are factorized.
        logger: object with ``info``. Defaults to the stdlib logger for this
            module (``None`` used to crash on first use).
        load_only: build ``LowRankLayerEval`` placeholders instead of running SVD.

    Returns:
        ``(reduced_rank, total_rank)`` summed over all visited projections.
    """
    import gc
    import logging

    if logger is None:
        logger = logging.getLogger(__name__)

    def _reduce_children(i, parent, prefix, gap):
        """Factorize the ``proj`` children of ``parent``; return (reduced, total) ranks."""
        reduced, total = 0, 0
        # Snapshot first: matching entries of ``parent`` are replaced in the loop.
        for key, module in list(parent.named_modules()):
            if "proj" not in key:
                continue
            name = prefix + key
            planned = rank_pruning[i][name]
            rank = min(module.weight.shape[0], module.weight.shape[1])
            if (planned / rank) > min_ratio:
                k = rank - planned
                layer_cls = LowRankLayer if load_only is False else LowRankLayerEval
                setattr(parent, key, layer_cls(k, module.weight.to(torch.float32), True))
                reduced += planned
            else:
                k = rank
                module.weight.requires_grad = False
            total += rank
            logger.info(f"layer.{i}.{name:50} Desired/Total: {k}/{rank}{gap}({(k/rank * 100):2f} %)")
        return reduced, total

    use_cache = model.config.use_cache
    model.config.use_cache = False
    reduced_rank, total_rank = 0, 0

    logger.info("*************** Pruning Model Started ***************")
    for i, layer in enumerate(model.model.layers):
        # The attention and MLP log lines have always differed by one space.
        for attr, gap in (("self_attn", " "), ("mlp", "  ")):
            reduced, total = _reduce_children(i, getattr(layer, attr), attr + ".", gap)
            reduced_rank += reduced
            total_rank += total
        gc.collect()
        torch.cuda.empty_cache()
    logger.info("*************** Pruning Model Completed ***************")
    # Restore the caller's KV-cache setting (it used to stay disabled).
    model.config.use_cache = use_cache
    return (reduced_rank, total_rank)

def do_low_rank(weight, desired_rank, debug=False):
    """Return the truncated-SVD reconstruction of ``weight`` at rank ``desired_rank``.

    ``debug`` is accepted for signature compatibility and ignored.
    """
    u, s, v = torch.svd(weight)
    u_k = u[:, :desired_rank]
    s_k = torch.diag(s[:desired_rank])
    v_k = v[:, :desired_rank]
    return u_k @ s_k @ v_k.T

def do_rank_reduction_merge(args, model, tokenizer, rank_pruning, min_ratio, logger = None):
    """Overwrite selected projection weights in place with truncated-SVD approximations.

    Unlike ``do_rank_reduction``, the module structure is unchanged: each
    selected dense weight keeps its shape but becomes a rank-``k`` matrix.
    Projections whose planned pruning fraction is not above ``min_ratio``
    are frozen instead.

    Args:
        args, tokenizer: unused; kept for the caller's signature.
        model: causal LM exposing ``model.model.layers`` and ``model.config``.
        rank_pruning: ``{layer_index: {"self_attn.<key>" / "mlp.<key>": ranks_to_drop}}``.
        min_ratio: only projections pruned by more than this fraction are merged.
        logger: object with ``info``; defaults to this module's stdlib logger.

    Returns:
        ``(reduced_rank, total_rank)`` summed over all visited projections.
    """
    import logging

    if logger is None:
        logger = logging.getLogger(__name__)

    def _truncated(weight, k):
        # Local truncation. The global ``do_low_rank`` is redefined later in
        # this notebook to return ``(approx, error)``; relying on it broke the
        # ``.to(...)`` below when the cells ran in order.
        u, s, v = torch.svd(weight)
        return u[:, :k] @ torch.diag(s[:k]) @ v[:, :k].T

    use_cache = model.config.use_cache
    model.config.use_cache = False
    reduced_rank, total_rank = 0, 0

    logger.info("*************** Pruning Model Started ***************")
    for i, layer in enumerate(model.model.layers):
        for attr in ("self_attn", "mlp"):
            for key, module in getattr(layer, attr).named_modules():
                if "proj" not in key:
                    continue
                name = f"{attr}.{key}"
                rank = min(module.weight.shape[0], module.weight.shape[1])
                k = rank - rank_pruning[i][name]
                if (rank_pruning[i][name] / rank) > min_ratio:
                    original = module.weight.detach().to(torch.float32)
                    # Write back in the weight's own dtype. It used to be
                    # hard-coded bfloat16, which mismatched fp16/fp32 models.
                    module.weight.data = _truncated(original, k).to(module.weight.dtype)
                    reduced_rank += rank_pruning[i][name]
                else:
                    k = rank
                    module.weight.requires_grad = False
                total_rank += rank
                logger.info(f"layer.{i}.{name:40} Desired/Total: {k}/{rank}")

    logger.info("*************** Pruning Model Completed ***************")
    model.config.use_cache = use_cache
    return (reduced_rank, total_rank)


#### rank utils.

In [15]:
import time
import heapq
import torch
import torch.nn as nn
#from .data_utils import get_c4, get_wikitext2
#from .LowRankLayer import LowRankLayer, LowRankLayerEval
from tqdm import tqdm
import numpy as np
import wandb

def get_loaders(name, nsamples=128, seed=0, seqlen=2048, tokenizer=None):
    """Dispatch to the loader whose key occurs in ``name``.

    ``'wikitext2'`` is checked before ``'c4'``. Both loaders return
    ``(trainloader, eval_data)``.

    Raises:
        ValueError: if ``name`` matches neither dataset. Previously ``None``
            was returned, which only failed later at the caller's tuple
            unpacking.
    """
    if 'wikitext2' in name:
        return get_wikitext2(nsamples, seed, seqlen, tokenizer)
    if "c4" in name:
        return get_c4(nsamples, seed, seqlen, tokenizer)
    raise ValueError(f"Unknown dataset {name!r}; expected a name containing 'wikitext2' or 'c4'.")

def find_layers(module, layers=None, name=''):
    """
    Recursively find the layers of a certain type in a module.

    Matching uses the exact type (``type(m) in layers``), so subclasses do
    not match. When ``module`` itself matches, its children are not searched.

    Args:
        module (nn.Module): PyTorch module.
        layers (list): Layer types to find; defaults to ``[nn.Linear]``.
            (``None`` default replaces the former shared mutable list default.)
        name (str): Dotted name of ``module``, used as the result prefix.

    Returns:
        dict: ``{dotted_name: module}`` for every match.
    """
    if layers is None:
        layers = [nn.Linear]
    if type(module) in layers:
        return {name: module}
    res = {}
    for child_name, child in module.named_children():
        full_name = f"{name}.{child_name}" if name != '' else child_name
        res.update(find_layers(child, layers=layers, name=full_name))
    return res




def prepare_calibration_input(model, dataloader, device, nsamples=128):
    """Capture the hidden states entering the first decoder layer.

    ``model.model.layers[0]`` is temporarily wrapped. For every batch the
    wrapper stores the layer-0 input in a preallocated buffer, records the
    ``attention_mask`` / ``position_ids`` keyword arguments, then aborts the
    forward pass with a private exception. The original layer and
    ``config.use_cache`` are restored even if a batch fails.

    Args:
        model: causal LM exposing ``model.model.layers``, ``model.seqlen``
            and ``model.config.hidden_size``.
        dataloader: iterable of batches whose first element is the input-id tensor.
        device: device for inputs and the buffer. Overridden by
            ``hf_device_map["model.embed_tokens"]`` when the model has one.
        nsamples: buffer rows (previously hard-coded to 128; default
            unchanged). Must be at least the number of batches.

    Returns:
        ``(inps, outs, attention_mask, position_ids)``:
        - ``inps``: captured inputs of shape ``(nsamples, seqlen, hidden_size)``.
        - ``outs``: a zero buffer of the same shape.
        - ``attention_mask``, ``position_ids``: the keyword values from the
          last captured batch (``None`` if not passed).
    """
    class _StopForward(Exception):
        """Raised by the catcher to abort the forward after layer 0's input is stored."""

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    # Models loaded without ``device_map`` have no ``hf_device_map`` attribute.
    device_map = getattr(model, "hf_device_map", None) or {}
    if "model.embed_tokens" in device_map:
        device = device_map["model.embed_tokens"]

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=device)
    inps.requires_grad = False
    cache = {'i': 0, 'attention_mask': None, "position_ids": None}

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs.get('attention_mask')
            cache['position_ids'] = kwargs.get('position_ids')
            # A private exception, so genuine ValueErrors from the model are
            # no longer swallowed.
            raise _StopForward

    layers[0] = Catcher(layers[0])
    try:
        for batch in dataloader:
            try:
                model(batch[0].to(device))
            except _StopForward:
                pass
    finally:
        layers[0] = layers[0].module
        model.config.use_cache = use_cache

    outs = torch.zeros_like(inps)
    return inps, outs, cache['attention_mask'], cache['position_ids']

def rank_analysis_weight(args, model, tokenizer, device):
    """Collect the singular values of every ``nn.Linear`` weight in each decoder layer.

    Args:
        args, tokenizer, device: unused; kept for the caller's signature.
        model: causal LM exposing ``model.model.layers`` and ``model.config``.

    Returns:
        ``{layer_index: {module_name: singular_values}}``: 1-D float32
        tensors in descending order, on the weight's device.
    """
    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    layers_singular_value = {}
    for i in range(len(layers)):
        subset = find_layers(layers[i])
        layers_singular_value[i] = {}
        for name in subset:
            W = subset[name].weight.data
            # svdvals skips the U/V factors that torch.svd built and discarded.
            layers_singular_value[i][name] = torch.linalg.svdvals(W.to(torch.float32))

    # Restore the caller's KV-cache setting (it used to stay disabled).
    model.config.use_cache = use_cache
    return layers_singular_value

def get_singular_values(args, model):
    """Singular values of every ``nn.Linear`` weight, keyed ``"layer.{i}.{name}"``.

    Args:
        args: unused.
        model: causal LM exposing ``model.model.layers``.

    Returns:
        Flat dict mapping ``"layer.{i}.{module_name}"`` to 1-D float32
        tensors of singular values in descending order.
    """
    layers_singular_value = {}
    for i, layer in enumerate(model.model.layers):
        for name, module in find_layers(layer).items():
            # svdvals skips the U/V factors that torch.svd built and discarded.
            layers_singular_value[f"layer.{i}.{name}"] = torch.linalg.svdvals(
                module.weight.data.to(torch.float32)
            )
    return layers_singular_value


def get_grad_singular_values(args, model):
    """Singular values of every ``nn.Linear`` weight *gradient*, keyed ``"layer.{i}.{name}"``.

    Args:
        args: unused.
        model: causal LM exposing ``model.model.layers``; a backward pass
            must already have populated ``.grad`` on the linear weights.

    Returns:
        Flat dict mapping ``"layer.{i}.{module_name}"`` to 1-D float32
        tensors of singular values in descending order.

    Raises:
        ValueError: if a linear weight has no gradient. This used to surface
            as an opaque ``AttributeError`` on ``None``.
    """
    layers_singular_value = {}
    for i, layer in enumerate(model.model.layers):
        for name, module in find_layers(layer).items():
            grad = module.weight.grad
            if grad is None:
                raise ValueError(f"layer.{i}.{name} has no gradient; run a backward pass first.")
            layers_singular_value[f"layer.{i}.{name}"] = torch.linalg.svdvals(grad.to(torch.float32))
    return layers_singular_value

def do_low_rank(weight, desired_rank, debug=False):
    """Truncated-SVD approximation of ``weight`` together with its L1 error.

    Returns:
        ``(weight_approx, error)``:
        - ``weight_approx``: the rank-``desired_rank`` reconstruction wrapped
          in an ``nn.Parameter``, with the same shape as ``weight``.
        - ``error``: the mean absolute difference to ``weight``, a 0-d
          tensor computed without autograd.

        With ``debug`` the input/target and output shapes are printed.
    """
    left, sigma, right = torch.svd(weight)
    left = left[:, :desired_rank]
    sigma = sigma[:desired_rank]
    right = right[:, :desired_rank]

    if debug:
        print(f"Shape is {weight.shape} and shape is {weight.dtype} => desired rank {desired_rank}")

    approx = left @ torch.diag(sigma) @ right.T

    if debug:
        print(f"New matrix has shape {approx.shape}")

    assert tuple(approx.shape[:2]) == tuple(weight.shape[:2])
    approx = torch.nn.Parameter(approx)

    with torch.no_grad():
        error = torch.nn.functional.l1_loss(weight, approx)
    return approx, error

def rank_reduction_weight(args, model, tokenizer, rank_pruning, device, max_layers=None):
    """Overwrite each linear weight in the decoder layers with a truncated-SVD approximation.

    For every ``nn.Linear`` found in a layer, ``rank_pruning[i][name]`` ranks
    are removed via ``do_low_rank``. The reconstruction error is printed,
    and the weight is overwritten in its own dtype.

    The old body stopped after layer 0 (a leftover debug ``break``) while
    still printing "Pruning completed". All layers are now processed unless
    ``max_layers`` limits them (``max_layers=1`` reproduces the old run).

    Args:
        args, tokenizer, device: unused.
        model: causal LM exposing ``model.model.layers`` and ``model.config``.
        rank_pruning: ``{layer_index: {module_name: ranks_to_drop}}``.
        max_layers: optional cap on how many leading layers to process.

    Returns:
        ``(None, None)``.
    """
    # Local import: a later cell rebinds the global ``tqdm`` to the module.
    from tqdm import tqdm

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers
    n_layers = len(layers) if max_layers is None else min(max_layers, len(layers))

    for i in tqdm(range(n_layers)):
        subset = find_layers(layers[i])
        for name in subset:
            W = subset[name].weight.data
            k = min(W.shape[0], W.shape[1]) - rank_pruning[i][name]
            approx_w, error = do_low_rank(W.to(torch.float32), k, True)
            print(f"layer.{i}.{name} ({k}): {error}")
            subset[name].weight.data = approx_w.data.to(W.dtype)

    model.config.use_cache = use_cache
    print("Pruning completed")
    return None, None

def rank_reduction_weight_wrapper(args, model, tokenizer, rank_pruning, device):
    """Replace every attention/MLP projection with a trainable ``LowRankLayer``.

    Each submodule whose name contains ``"proj"`` under ``layer.self_attn``
    and ``layer.mlp`` is factorized at rank ``min(out, in) - rank_pruning[i][name]``.

    Args:
        args, tokenizer, device: unused.
        model: causal LM exposing ``model.model.layers`` and ``model.config``.
        rank_pruning: ``{layer_index: {"self_attn.<key>" / "mlp.<key>": ranks_to_drop}}``.
    """
    # Local import: a later cell rebinds the global ``tqdm`` to the module.
    from tqdm import tqdm

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    for i in tqdm(range(len(layers))):
        layer = layers[i]
        for attr in ("self_attn", "mlp"):
            parent = getattr(layer, attr)
            # Snapshot first: matching entries of ``parent`` are replaced below.
            for key, module in list(parent.named_modules()):
                if "proj" not in key:
                    continue
                name = f"{attr}.{key}"
                k = min(module.weight.shape[0], module.weight.shape[1]) - rank_pruning[i][name]
                setattr(parent, key, LowRankLayer(k, module.weight.to(torch.float32)))

    # Restore the caller's KV-cache setting (it used to stay disabled).
    model.config.use_cache = use_cache
    print("Pruning completed")

def rank_reduction_weight_wrapper_selective(args, model, tokenizer, rank_pruning, device, min_percent=40):
    """Factorize only projections whose planned rank cut exceeds ``min_percent`` percent.

    Selected ``proj`` submodules of ``self_attn`` / ``mlp`` are replaced by a
    frozen ``LowRankLayer``. The others are left untouched.

    Args:
        args, tokenizer, device: unused.
        model: causal LM exposing ``model.model.layers`` and ``model.config``.
        rank_pruning: ``{layer_index: {"self_attn.<key>" / "mlp.<key>": ranks_to_drop}}``.
        min_percent: selection threshold in percent (was hard-coded to 40).

    Returns:
        Percentage of total rank actually removed (0.0 if nothing was visited).
    """
    # Local import: a later cell rebinds the global ``tqdm`` to the module.
    from tqdm import tqdm

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers
    reduced_rank, total_rank = 0, 0
    for i in tqdm(range(len(layers))):
        layer = layers[i]
        for attr in ("self_attn", "mlp"):
            parent = getattr(layer, attr)
            # Snapshot first: matching entries of ``parent`` are replaced below.
            for key, module in list(parent.named_modules()):
                if "proj" not in key:
                    continue
                name = f"{attr}.{key}"
                rank = min(module.weight.shape[0], module.weight.shape[1])
                k = rank - rank_pruning[i][name]
                if (rank_pruning[i][name] / rank) * 100 > min_percent:
                    setattr(parent, key, LowRankLayer(k, module.weight.to(torch.float32), False))
                    reduced_rank += rank_pruning[i][name]
                total_rank += rank

    model.config.use_cache = use_cache
    percent = (reduced_rank / total_rank) * 100 if total_rank else 0.0
    print(f">>>>>>>>>>>>>>> Pruning completed with Rank reduced : {percent}")
    return percent



def rank_reduction_weight_wrapper_selective_eval(args, model, tokenizer, rank_pruning, device, min_percent=40):
    """Build shape-only low-rank placeholders for projections cut by more than ``min_percent`` percent.

    Same selection as ``rank_reduction_weight_wrapper_selective``, but
    selected modules become frozen ``LowRankLayerEval`` instances (no SVD).

    Args:
        args, tokenizer, device: unused.
        model: causal LM exposing ``model.model.layers`` and ``model.config``.
        rank_pruning: ``{layer_index: {"self_attn.<key>" / "mlp.<key>": ranks_to_drop}}``.
        min_percent: selection threshold in percent (was hard-coded to 40).

    Returns:
        Percentage of total rank removed (0.0 if nothing was visited).
    """
    # Local import: a later cell rebinds the global ``tqdm`` to the module.
    from tqdm import tqdm

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers
    reduced_rank, total_rank = 0, 0
    for i in tqdm(range(len(layers))):
        layer = layers[i]
        for attr in ("self_attn", "mlp"):
            parent = getattr(layer, attr)
            # Snapshot first: matching entries of ``parent`` are replaced below.
            for key, module in list(parent.named_modules()):
                if "proj" not in key:
                    continue
                name = f"{attr}.{key}"
                rank = min(module.weight.shape[0], module.weight.shape[1])
                k = rank - rank_pruning[i][name]
                if (rank_pruning[i][name] / rank) * 100 > min_percent:
                    setattr(parent, key, LowRankLayerEval(k, module.weight.to(torch.float32), False))
                    reduced_rank += rank_pruning[i][name]
                total_rank += rank

    model.config.use_cache = use_cache
    percent = (reduced_rank / total_rank) * 100 if total_rank else 0.0
    print(f">>>>>>>>>>>>>>> Pruning completed with Rank reduced : {percent}")
    return percent


def rank_reduction_dynamic_pruning(args, model, device, file_name,
                                   save_path="/data/adative_rank_attention_ffn.pt"):
    """Pick, per linear weight, the smallest kept-rank fraction within an L1 error budget.

    Kept-rank fractions are tried from largest to smallest. The search stops
    at the first fraction whose mean absolute reconstruction error exceeds
    the threshold (5e-4 for both attention and MLP weights). The last
    passing fraction determines how many ranks are pruned.

    Args:
        args, device: unused.
        model: causal LM exposing ``model.model.layers`` and ``model.config``.
        file_name: writable text stream for the summary line.
        save_path: where the plan is ``torch.save``-d (default unchanged).

    Returns:
        ``{layer_index: {module_name: ranks_to_prune}}``.
    """
    # Local import: a later cell rebinds the global ``tqdm`` to the module.
    from tqdm import tqdm

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    rank_pruning = {}
    total_rank, pruned_rank = 0, 0
    error_thresold_att, error_thresold_ffn = 5e-4, 5e-4
    # Strictly descending kept-rank fractions. The old list had 0.7 before
    # 0.75, so a passing 0.7 was immediately overwritten by 0.75 (more rank,
    # normally lower error).
    pruning_bucket = [0.95, 0.9, 0.85, 0.8, 0.75, 0.7, 0.6, 0.55, 0.5, 0.45, 0.4, 0.35, 0.3, 0.2, 0.1]

    for i in tqdm(range(len(layers))):
        subset = find_layers(layers[i])
        rank_pruning[i] = {}
        for name in subset:
            W = subset[name].weight.detach().to(torch.float32)
            full_rank = min(W.shape[0], W.shape[1])
            error_thresold = error_thresold_ffn if "mlp" in name else error_thresold_att
            # One SVD per matrix instead of one per candidate rank.
            U, S, V = torch.svd(W)
            rank_pruning[i][name] = 0
            error = None
            for prune_ratio in pruning_bucket:
                desired_rank = int(full_rank * prune_ratio)
                approx = U[:, :desired_rank] @ torch.diag(S[:desired_rank]) @ V[:, :desired_rank].T
                error = (W - approx).abs().mean()  # same as nn.L1Loss (mean reduction)
                if error > error_thresold:
                    break
                rank_pruning[i][name] = full_rank - desired_rank
            total_rank += full_rank
            pruned_rank += rank_pruning[i][name]
            print(f"layer.{i}.{name} ({rank_pruning[i][name]}): {error}")

    model.config.use_cache = use_cache
    print("Pruning completed")
    torch.save(rank_pruning, save_path)
    reduction = (pruned_rank / total_rank) * 100 if total_rank else 0.0
    print(f"Rank Reduction: {reduction:.3f} %", file=file_name, flush=True)
    return rank_pruning

#### eval.

In [16]:
import os
import time
import torch
import torch.nn as nn
import tqdm as tqdm
from loguru import logger
# Import get_loaders function from data module within the same directory
#from .data_utils import get_c4, get_wikitext2


def get_loaders(name, nsamples=128, seed=0, seqlen=2048, tokenizer=None):
    """Dispatch to the loader whose key occurs in ``name``.

    ``'wikitext2'`` is checked before ``'c4'``. Both loaders return
    ``(trainloader, eval_data)``.

    Raises:
        ValueError: if ``name`` matches neither dataset. Previously ``None``
            was returned, which only failed later at the caller's tuple
            unpacking.
    """
    if 'wikitext2' in name:
        return get_wikitext2(nsamples, seed, seqlen, tokenizer)
    if "c4" in name:
        return get_c4(nsamples, seed, seqlen, tokenizer)
    raise ValueError(f"Unknown dataset {name!r}; expected a name containing 'wikitext2' or 'c4'.")


# Function to evaluate perplexity (ppl) on a specified model and tokenizer
def eval_ppl(model, tokenizer, device=torch.device("cuda:0"), dataset="wikitext2"):
    """Compute the perplexity of ``model`` on the evaluation split of ``dataset``.

    The tokenized eval data is cached under ``./data``. The cache file name
    includes the dataset, ``model.seqlen`` and the tokenizer's
    ``name_or_path``. The previous single ``test_loader.pt`` was reused for
    every dataset and tokenizer, silently evaluating on the wrong data.

    Args:
        model: causal LM with a ``seqlen`` attribute.
        tokenizer: tokenizer matching ``model``.
        device: forwarded to ``eval_ppl_dataset`` for input placement.
        dataset: ``"wikitext2"`` or ``"c4"`` (see ``get_loaders``).

    Returns:
        Perplexity as a Python float.
    """
    import re

    logger.info(f"Evaluating on {dataset} .....")

    tok_id = re.sub(r"[^A-Za-z0-9_.-]+", "_", str(getattr(tokenizer, "name_or_path", "tokenizer")))
    cache_path = os.path.join("./data", f"test_loader_{dataset}_{model.seqlen}_{tok_id}.pt")

    if os.path.exists(cache_path):
        # The cache holds pickled tokenizer objects, not plain tensors, so
        # weights_only must be off (PyTorch >= 2.6 defaults it to True).
        # Only load cache files written by this function.
        testloader = torch.load(cache_path, weights_only=False)
    else:
        _, testloader = get_loaders(
            dataset, seed=0, seqlen=model.seqlen, tokenizer=tokenizer
        )
        os.makedirs(os.path.dirname(cache_path), exist_ok=True)
        torch.save(testloader, cache_path)

    # Evaluate in a no-grad context to avoid building autograd graphs.
    with torch.no_grad():
        ppl = eval_ppl_dataset(model, testloader, 1, device)
    return ppl

# Function to evaluate perplexity (ppl) specifically on the wikitext dataset
def eval_ppl_dataset(model, testenc, bs=1, device=None):
    """Perplexity over consecutive, non-overlapping ``model.seqlen``-token windows.

    The token stream ``testenc.input_ids`` is expected to have shape
    ``(1, N)``. It is cut into ``N // model.seqlen`` windows (the tail is
    dropped), which are fed ``bs`` windows at a time. Each batch's mean
    next-token cross-entropy is weighted by its number of windows, and the
    result is ``exp`` of the weighted mean.

    Args:
        model: causal LM with a ``seqlen`` attribute whose output has ``.logits``.
        testenc: object exposing ``input_ids``.
        bs: windows per forward pass.
        device: where to place inputs. ``None`` keeps the previous behaviour
            of using the current CUDA device (``.cuda()``).

    Returns:
        Perplexity as a Python float.

    Raises:
        ValueError: if the stream is shorter than one window. This used to
            crash inside ``torch.stack([])``.
    """
    ids = testenc.input_ids
    nsamples = ids.numel() // model.seqlen
    if nsamples == 0:
        raise ValueError(
            f"Evaluation stream has {ids.numel()} tokens, fewer than one window of {model.seqlen}."
        )

    loss_fct = nn.CrossEntropyLoss()
    nlls = []

    for i in range(0, nsamples, bs):
        j = min(i + bs, nsamples)

        inputs = ids[:, (i * model.seqlen):(j * model.seqlen)]
        inputs = inputs.to(device) if device is not None else inputs.cuda()
        inputs = inputs.reshape(j - i, model.seqlen)

        lm_logits = model(inputs).logits

        # Shift so that position t predicts token t+1.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = inputs[:, 1:]
        loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))

        # Scale the per-token mean back to a per-window total so batches of
        # different sizes are weighted correctly in the final mean.
        nlls.append(loss.float() * model.seqlen * (j - i))

        if i % 20 == 0: logger.info(f"Evaluated samples: {i}/{nsamples}")

    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
    return ppl.item()