#1. Loading the model

In [1]:
# !pip install -U datasets
# !pip install lm-eval==0.3.0

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import random
import torch
import os
import torch.nn as nn
# from tqdm import tqdm
from tqdm.notebook import tqdm
import numpy as np
import time
from lm_eval.base import BaseLM

In [3]:
from dataclasses import dataclass
@dataclass
class Args:
  alpha: float = 0.5
  n_calib_samples: int = 16
  rank_align: int = 1
  use_cache: bool = True
  param_ratio_target: float = 0.5
  ppl_target: float = 40
  act_aware: bool = True
  sigma_fuse: str = "UV"
  model_id: str = "facebook/opt-125m"
  use_bos: bool = False
  eval_ppl: str = "wikitext2"

In [4]:
args = Args()

In [5]:
tokenizer = AutoTokenizer.from_pretrained(args.model_id)
model = AutoModelForCausalLM.from_pretrained(args.model_id, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True)

## Eval utilities

In [6]:
class EvalLM(BaseLM):
    def __init__(
        self,
        model,
        tokenizer,
        # device="cuda:0",
        batch_size=1,
    ):
        super().__init__()

        # assert isinstance(device, str)
        assert isinstance(batch_size, int)

        # self._device = torch.device(device)
        self._device = model.device

        # self.model = model.to(self.device)
        self.model = model
        self.model.eval()

        self.tokenizer = tokenizer

        self.vocab_size = self.tokenizer.vocab_size

        self.batch_size_per_gpu = batch_size  # todo: adaptive batch size

        self.seqlen = 2048

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        try:
            return self.model.config.n_ctx
        except AttributeError:
            # gptneoconfig doesn't have n_ctx apparently
            return self.model.config.max_position_embeddings

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        # TODO: fix multi-gpu
        return self.batch_size_per_gpu  # * gpus

    @property
    def device(self):
        # TODO: fix multi-gpu
        return self._device

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        with torch.no_grad():
            return self.model(inps)[0][:, :, :50257]

    def _model_generate(self, context, max_length, eos_token_id):
        return self.model.generate(context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False)


In [7]:
def get_eval_loaders(name, tokenizer):
    if "wikitext2" in name:
        testdata = load_dataset(
            "wikitext",
            "wikitext-2-raw-v1",
            split="test",
        )
        testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")
        return testenc
    if "ptb" in name:
        valdata = load_dataset(
            "ptb_text_only",
            "penn_treebank",
            split="validation",
        )
        testenc = tokenizer("\n\n".join(valdata["sentence"]), return_tensors="pt")
        return testenc
    if "c4" in name:
        testdata = load_dataset(
            "allenai/c4",
            "allenai--c4",
            data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
            split="validation",
        )
        testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")
        return testenc
    raise NotImplementedError

In [8]:
@torch.no_grad()
def evaluate_model(
    model,
    tokenizer,
    model_name,
    eval_ppl="",
    num_fewshot=0,
    limit=-1,
    batch_size=1,
    use_bos=False,
):
    """
    model: model name
    limit: number of test samples for debug, set to -1 is no limit
    tasks: str tasks are split by ,
    num_fewshot: Number of examples in few-shot context
    eval_ppl: str datasets are split by , such as 'wikitext2,ptb,c4'
    """
    lm = EvalLM(model, tokenizer, batch_size=batch_size)

    results = {}
    if eval_ppl:
        for dataset in eval_ppl.split(","):
            cache_testloader = f"/tmp/{dataset}_testloader_{model_name.replace('/', '_')}_all.cache"
            if os.path.exists(cache_testloader):
                testloader = torch.load(cache_testloader, weights_only=False)
                # print(f"load calibration from {cache_testloader}")
            else:
                testloader = get_eval_loaders(dataset, tokenizer)
                torch.save(testloader, cache_testloader)
            # print(dataset)
            testenc = testloader.input_ids
            if use_bos:
                lm.seqlen -= 1
            nsamples = testenc.numel() // lm.seqlen
            use_cache = lm.model.config.use_cache
            lm.model.config.use_cache = False
            lm.model.eval()
            nlls = []

            for i in tqdm(range(nsamples)):
                batch = testenc[:, (i * lm.seqlen) : ((i + 1) * lm.seqlen)].to(lm.device)
                if use_bos:
                    bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * batch.size(dim=0)).to(lm.device)
                    batch = torch.cat([bos_tokens_tensor, batch], dim=1)
                outputs = lm.model.model(batch)
                hidden_states = outputs[0]  # .to(lm.model.lm_head.weight.device)
                if use_bos:
                    hidden_states = hidden_states[:, 1:, :]
                logits = lm.model.lm_head(hidden_states)  # .contiguous()
                shift_logits = logits[:, :-1, :]  # .contiguous()
                shift_labels = testenc[:, (i * lm.seqlen) : ((i + 1) * lm.seqlen)][:, 1:].to(lm.device)
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                )
                neg_log_likelihood = loss.float() * lm.seqlen
                nlls.append(neg_log_likelihood)
                if i == limit:
                    break
                # if i == 1:
                #     print(
                #         "memory_allocated",
                #         i,
                #         torch.cuda.memory_allocated() / 1024 / 1024,
                #         "max memory_allocated",
                #         torch.cuda.max_memory_allocated() / 1024**2,
                #     )

            ppl = torch.exp(torch.stack(nlls).sum() / (len(nlls) * lm.seqlen))
            print(dataset, ppl.item())
            lm.model.config.use_cache = use_cache
            results[dataset] = ppl.item()

    return results

## Evaluation before compression

In [9]:
result = evaluate_model(
        model,
        tokenizer,
        args.model_id,
        eval_ppl=args.eval_ppl,
        limit=-1,
        use_bos=args.use_bos,
    )
print(result)

  0%|          | 0/140 [00:00<?, ?it/s]

wikitext2 27.653562545776367
{'wikitext2': 27.653562545776367}


#2. Load calibration data

In [10]:
# we will use c4 as calibration data
def get_calib_data(tokenizer, model_id, nsamples=16, seqlen=2048, seed=3, use_bos=False):
  cache_file = f"cache/c4_{model_id.replace('/','_')}_{nsamples}_{seqlen}_{seed}_bos{use_bos}.pt"
  print(f"cache_file={cache_file}")
  if not os.path.exists("cache"):
        os.makedirs("cache")
  if os.path.exists(cache_file):
      traindataset = torch.load(cache_file)
      return traindataset

  traindata = load_dataset(
            "allenai/c4", data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, split="train"
  )
  tot_text = "\n\n".join(traindata["text"])
  print(f"tot_text={len(tot_text)}")
  traindataset = []
  for _ in range(nsamples):
      i = random.randint(0, len(tot_text) - seqlen - 1)
      j = i + seqlen * 10
      txt = tot_text[i:j]
      ind = txt.find(".")
      txt = txt[ind + 1 :].strip()
      if use_bos:
          txt = tokenizer.bos_token + txt
      trainenc = tokenizer(txt, return_tensors="pt")
      inp = trainenc.input_ids[:, :seqlen]
      attention_mask = torch.ones_like(inp)
      traindataset.append({"input_ids": inp, "attention_mask": attention_mask})
  torch.save(traindataset, cache_file)
  return traindataset

In [11]:
calib_loader = get_calib_data(tokenizer, model.name_or_path, args.n_calib_samples)

cache_file=cache/c4_facebook_opt-125m_16_2048_3_bosFalse.pt


#3. Compute Fisher Information

In [12]:
def calib_fisher_info(model, calib_loader, use_cache=True):
  model_id = model.config.name_or_path
  cache_file = f"cache/{model_id.replace('/','_')}_calib_fisher_info.pt"

  if os.path.exists(cache_file) and use_cache:
        all_fisher_info = torch.load(cache_file, map_location="cpu")
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):
                module.fisher_info = all_fisher_info[name].to(module.weight.device)
        return

  model.eval()

  for name, module in model.named_modules():
      if isinstance(module, nn.Linear):
          module.fisher_info = 0

  # get fisher info
  for batch in tqdm(calib_loader):
      input_ids = batch["input_ids"][:, :-1].to(model.device)
      labels = batch["input_ids"][:, 1:].to(model.device)
      out = model(input_ids=input_ids, labels=labels)
      out[0].backward()
      for name, module in model.named_modules():
          if isinstance(module, nn.Linear):
              module.fisher_info += module.weight.grad.detach().pow(2).mean(0) # mean over (dL/dw)**2 (why a mean here ?) (should it be sum instead ?)
      model.zero_grad()

  for name, module in model.named_modules():
      if isinstance(module, nn.Linear):
          module.fisher_info = module.fisher_info.div(len(calib_loader)).sqrt() # dividing by length of dataset and square root.

  # remove and save fisher_info
  all_fisher_info = {}
  for name, module in model.named_modules():
      if isinstance(module, nn.Linear):
          module._forward_hooks.clear()
          all_fisher_info[name] = module.fisher_info
  torch.save(all_fisher_info, cache_file)

In [13]:
calib_fisher_info(model, calib_loader)

#4. Use Fisher Information to compress the Model

In [14]:
@torch.no_grad()
def evaluate_perplexity(model, dataset, limit):
    """
    dataset: input ids tensor of shape [batch, sequence length]
    """
    nsamples, seqlen = dataset.size()

    nlls = []

    for i in range(nsamples):
        if i == limit:
            break
        input_ids = dataset[i : i + 1, :-1].to(model.device)
        labels = dataset[i : i + 1, 1:].contiguous()
        logits = model(input_ids=input_ids)[0]
        shift_logits = logits[:, :, :]
        shift_labels = labels.to(model.device)
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )
        neg_log_likelihood = loss.float() * seqlen
        nlls.append(neg_log_likelihood)
    ppl = torch.exp(torch.stack(nlls).sum() / (len(nlls) * seqlen))
    return ppl.item()

In [15]:
class SVDLinear(nn.Module):
    def __init__(self, U, S, V, bias=None, sigma_fuse="UV") -> None:
        super().__init__()
        self.ALinear = nn.Linear(U.size(1), U.size(0), bias=bias is not None)

        if bias is not None:
            self.ALinear.bias.data = bias
        self.BLinear = nn.Linear(V.size(1), V.size(0), bias=False)
        self.truncation_rank = S.size(0)
        if sigma_fuse == "UV":
            self.ALinear.weight.data = U.mul(S.sqrt()).contiguous()
            self.BLinear.weight.data = V.t().mul(S.sqrt().view(-1, 1)).contiguous()
        elif sigma_fuse == "U":
            self.ALinear.weight.data = U.mul(S).contiguous()
            self.BLinear.weight.data = V.t().contiguous()
        elif sigma_fuse == "V":
            self.ALinear.weight.data = U.contiguous()
            self.BLinear.weight.data = V.t().mul(S.view(-1, 1)).contiguous()

    @staticmethod
    def from_linear(
        linear: nn.Linear,
        param_ratio: float,
        act_aware=False,
        ic_split=1,
        oc_split=1,
        alpha=1,
        sigma_fuse="UV",
        rank_align=1,
    ):
        # if param_ratio >= 1:
        #     return linear
        n_params = linear.weight.numel()
        compressed_params = int(n_params * param_ratio)
        assert ic_split == 1 or oc_split == 1
        rank = compressed_params // (linear.in_features + linear.out_features)
        # rank align
        rank = int(np.ceil(rank / rank_align) * rank_align)

        # print("rank", rank)
        w = linear.weight.data.float()
        if act_aware:
            scaling_diag_matrix = 1  # avoid zero division
            if hasattr(linear, "scaling_diag_matrix"):
                # print("WARNING: scaling_diag_matrix is used")
                scaling_diag_matrix *= linear.scaling_diag_matrix**alpha
                # scaling_diag_matrix *= linear.scaling_diag_matrix**0.5
            if hasattr(linear, "fisher_info"):
                scaling_diag_matrix *= linear.fisher_info**alpha
                # scaling_diag_matrix *= linear.fisher_info**1
            # if not (scaling_diag_matrix == scaling_diag_matrix).all():
            if not isinstance(scaling_diag_matrix, torch.Tensor):
                breakpoint()
            scaling_diag_matrix += 1e-6  # avoid zero division
            w = w * scaling_diag_matrix.view(1, -1)
        Us = []
        Ss = []
        Vs = []
        try:
            U, S, V = torch.svd_lowrank(w, q=rank)
        except:
            print(f"svd failed for {linear}, disable act_aware")
            return nn.Linear(linear.in_features, linear.out_features).to(linear.weight.dtype).to(linear.weight.device)
        if act_aware:
            V = V / scaling_diag_matrix.view(-1, 1)
        Us = [U]
        Ss = [S]
        Vs = [V]

        if linear.bias is not None:
            bias = linear.bias.data
        else:
            bias = None

        # nan or inf check
        for S in Ss:
            if (S != S).any():
                print("nan in S")
                return (
                    nn.Linear(linear.in_features, linear.out_features).to(linear.weight.dtype).to(linear.weight.device)
                )
        for U in Us:
            if (U != U).any():
                print("nan in U")
                return (
                    nn.Linear(linear.in_features, linear.out_features).to(linear.weight.dtype).to(linear.weight.device)
                )
        for V in Vs:
            if (V != V).any():
                print("nan in V")
                return (
                    nn.Linear(linear.in_features, linear.out_features).to(linear.weight.dtype).to(linear.weight.device)
                )

        assert len(Us) == len(Ss) == len(Vs) == 1
        new_linear = SVDLinear(Us[0], Ss[0], Vs[0], bias, sigma_fuse)
        new_linear.to(linear.weight.dtype)
        return new_linear

    def forward(self, inp):
        # compute USV^Tx + b
        y = self.BLinear(inp)
        y = self.ALinear(y)
        return y


In [16]:
# Calibrate sensitivity

@torch.no_grad()
def calib_sensitivity_ppl(model, calib_loader, args, use_cache=True):
  model_id = model.config._name_or_path
  cache_file = f"cache/{model_id.replace('/','_')}_sensitivity_fisher_{args.alpha}_{args.n_calib_samples}_c4.pt"
  if os.path.exists(cache_file) and use_cache:
      sensitivity_dict = torch.load(cache_file, map_location="cpu")
      return sensitivity_dict

  model.eval()
  full_name_dict = {module: name for name, module in model.named_modules()}
  linear_info = {}
  modules = [model]
  while len(modules) > 0:
      submodule = modules.pop()
      for name, raw_linear in submodule.named_children():
          if isinstance(raw_linear, nn.Linear):
              full_name = full_name_dict[raw_linear]
              linear_info[raw_linear] = {
                  "father": submodule,
                  "name": name,
                  "full_name": full_name,
              }
          else:
              modules.append(raw_linear)

  sensitivity_dict = {}

  param_ratio_candidates = [0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
  input_ids = torch.cat([_["input_ids"] for _ in calib_loader], 0)
  print(f"input_ids.shape={input_ids.shape}")
  pbar = tqdm(total=len(linear_info) * len(param_ratio_candidates))
  for raw_linear, info in linear_info.items():
      sensitivity_dict[info["full_name"]] = {}
      for param_ratio in param_ratio_candidates:
          svd_linear = SVDLinear.from_linear(
              raw_linear,
              param_ratio=param_ratio,
              alpha=args.alpha,
              act_aware=True,
              rank_align=args.rank_align,
          )
          setattr(info["father"], info["name"], svd_linear)

          ppl = evaluate_perplexity(model, input_ids, args.n_calib_samples)
          sensitivity_dict[info["full_name"]][param_ratio] = ppl
          print(f"{info['full_name']} {param_ratio} {ppl}")
          pbar.update(1)
      setattr(info["father"], info["name"], raw_linear)
  torch.save(sensitivity_dict, cache_file)
  return sensitivity_dict

In [17]:
sensitivity = calib_sensitivity_ppl(model, calib_loader, args, args.use_cache)

In [18]:
def binary_search_truncation_rank(model, sensitivity_dict, calib_loader, args):
    module_dict = {name: module for name, module in model.named_modules()}
    full_name_dict = {module: name for name, module in model.named_modules()}
    linear_info = {}
    modules = [model]
    while len(modules) > 0:
        submodule = modules.pop()
        for name, raw_linear in submodule.named_children():
            if isinstance(raw_linear, nn.Linear):
                full_name = full_name_dict[raw_linear]

                linear_info[raw_linear] = {
                    "father": submodule,
                    "name": name,
                    "full_name": full_name,
                }
            else:
                modules.append(raw_linear)


    ratio_target = args.param_ratio_target
    default_param_ratio = 1

    print(
        f"=== target: ppl={args.ppl_target}, ratio_target={ratio_target} ==="
    )

    sensitivity_list = []
    for layername, v in sensitivity_dict.items():
        for param_ratio, ppl in v.items():
            if param_ratio >= 1:
                # we need to compress the weights, so parameter ratio should be less than 1
                continue
            sensitivity_list.append((layername, param_ratio, ppl))
    sorted_sensitive_list = sorted(sensitivity_list, key=lambda x: -x[2])

    # binary search
    high = len(sorted_sensitive_list) - 1
    low = 0
    assert args.ppl_target > 0 or ratio_target > 0

    input_ids = torch.cat([_["input_ids"] for _ in calib_loader], 0)
    while low < high:
        mid = (low + high) // 2
        layers_min_ratio = {layername: default_param_ratio for layername in sensitivity_dict.keys()}
        for layername, param_ratio, ppl in sorted_sensitive_list[mid:]:
            layers_min_ratio[layername] = min(layers_min_ratio[layername], param_ratio)
        tot_params = 0
        compress_params = 0
        if args.ppl_target > 0:
            # assert not args.compress_kv_cache, "ppl_target is not supported when compressing kv_cache now"
            for layername, param_ratio in layers_min_ratio.items():
                raw_linear = module_dict[layername]
                info = linear_info[raw_linear]
                svd_linear = SVDLinear.from_linear(
                    raw_linear,
                    param_ratio=param_ratio,
                    alpha=args.alpha,
                    act_aware=args.act_aware,
                    sigma_fuse=args.sigma_fuse,
                    rank_align=args.rank_align,
                )
                setattr(info["father"], info["name"], svd_linear)
                tot_params += raw_linear.weight.numel()
                compress_params += raw_linear.weight.numel() * param_ratio
            ppl = evaluate_perplexity(model, input_ids, args.n_calib_samples)
            param_ratio = compress_params / tot_params
            msg = f"low={low} mid={mid}, high={high}, ppl={ppl}, param_ratio={param_ratio}"
            print(msg)
            if ppl < args.ppl_target:
                high = mid
            else:
                low = mid + 1
        else:
            for layername, param_ratio in layers_min_ratio.items():
                raw_linear = module_dict[layername]
                tot_params += raw_linear.weight.numel()
                compress_params += raw_linear.weight.numel() * param_ratio
            now_ratio = compress_params / tot_params

            msg = f"low={low} mid={mid}, high={high}, now_ratio={now_ratio}, params=({compress_params}/{tot_params})"
            print(msg)
            if now_ratio > ratio_target:
                high = mid
            else:
                low = mid + 1

    print(f"=== Searching done, decomposing layers... ===")
    layers_min_ratio = {layername: default_param_ratio for layername in sensitivity_dict.keys()}
    for layername, param_ratio, ppl in sorted_sensitive_list[mid:]:
        if layers_min_ratio[layername] is None:
            layers_min_ratio[layername] = param_ratio
        else:
            layers_min_ratio[layername] = min(layers_min_ratio[layername], param_ratio)
    st = time.time()
    for layername, param_ratio in tqdm(layers_min_ratio.items()):
        # set ratio
        raw_linear = module_dict[layername]
        info = linear_info[raw_linear]
        if param_ratio == default_param_ratio:
            svd_linear = raw_linear
        else:
            svd_linear = SVDLinear.from_linear(
                raw_linear,
                param_ratio=param_ratio,
                alpha=args.alpha,
                act_aware=args.act_aware,
                sigma_fuse=args.sigma_fuse,
                rank_align=args.rank_align,
            )
            raw_linear.to("cpu")
        setattr(info["father"], info["name"], svd_linear)
        # print(f"decompose {info['full_name']} with ratio {param_ratio}")
    ed = time.time()
    print(f"decompose time: {ed-st}")

In [19]:
binary_search_truncation_rank(model, sensitivity, calib_loader, args)

=== target: ppl=40, ratio_target=0.5 ===
low=0 mid=218, high=437, ppl=96.69224548339844, param_ratio=0.863457330415755
low=219 mid=328, high=437, ppl=48.57245635986328, param_ratio=0.9360254625024865
low=329 mid=383, high=437, ppl=42.33979797363281, param_ratio=0.9661030435647502
low=384 mid=410, high=437, ppl=41.162567138671875, param_ratio=0.9804257012134473
low=411 mid=424, high=437, ppl=40.30240249633789, param_ratio=0.9923612492540281
low=425 mid=431, high=437, ppl=40.2237663269043, param_ratio=0.996180624627014
low=432 mid=434, high=437, ppl=39.94973373413086, param_ratio=0.9980903123135069
low=432 mid=433, high=434, ppl=40.21885681152344, param_ratio=0.9976128903918837
=== Searching done, decomposing layers... ===


  0%|          | 0/73 [00:00<?, ?it/s]

decompose time: 0.14230942726135254


#5. Evaluate model

In [20]:
result = evaluate_model(
        model,
        tokenizer,
        args.model_id,
        eval_ppl=args.eval_ppl,
        limit=-1,
        use_bos=args.use_bos,
    )
print(result)

  0%|          | 0/140 [00:00<?, ?it/s]

wikitext2 27.796283721923828
{'wikitext2': 27.796283721923828}


In [21]:
model

OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 768, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 768)
      (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0): OPTDecoderLayer(
          (self_attn): OPTAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=1e-05,