In [None]:
# install dependencies
# If the GPU runs out of memory, restart the session.
!pip install datasets
!pip install minference
!pip install fuzzywuzzy
!pip install rouge
!pip install flash_attn

In [None]:
# Log in to your Hugging Face account with an access token; make sure the account has access to the Llama and Mistral models.
!huggingface-cli login

In [None]:
#install dependencies
!pip install tiktoken
!pip install --upgrade datasets fsspec
!rm -rf ~/.cache/huggingface/datasets

Single-layer attention approximation with BalanceKV

In [None]:
## next, execute this cell

import time
import math
import re
import string
from collections import Counter
import torch
import transformers
from typing import List, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from datasets import load_dataset
from flash_attn import flash_attn_func



# Prompt template per dataset; main() fills it with str.format(**example), so the
# example must provide the `context` and `input` fields named in the template.
DATASET2PROMPT = {"triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}"}
# Per-dataset value that main() forwards to greedy_generate as `max_new_tokens`.
DATASET2MAXLEN = {"triviaqa": 32,  }


def avg_stddev(data):
    """Return ``(mean, population standard deviation)`` of a non-empty sequence of numbers.

    The variance is divided by ``len(data)`` (population form, not ``len(data) - 1``).
    """
    count = len(data)
    mean = sum(data) / count
    squared_deviations = [(value - mean) ** 2 for value in data]
    return mean, math.sqrt(sum(squared_deviations) / count)


def normalize_answer(s):
    """Lower-case ``s``, drop ASCII punctuation, remove the articles a/an/the and collapse whitespace."""
    lowered = s.lower()
    # Deleting via a translation table removes exactly the characters in string.punctuation.
    unpunctuated = lowered.translate(str.maketrans("", "", string.punctuation))
    without_articles = re.sub(r"\b(a|an|the)\b", " ", unpunctuated)
    return " ".join(without_articles.split())


def truncate_input(input: list, max_length: int, manner="middle"):
    """Shorten ``input`` to at most ``max_length`` items by cutting out its middle.

    Args:
        input: Sequence to truncate (anything supporting ``len``, slicing and ``+``).
        max_length: Maximum number of items to keep; a negative value disables truncation.
        manner: Truncation strategy; only ``"middle"`` is supported.

    Returns:
        ``input`` itself when no truncation is needed, otherwise the first and last
        ``max_length // 2`` items concatenated (so an odd ``max_length`` keeps
        ``max_length - 1`` items).

    Raises:
        ValueError: If truncation is required and ``manner`` is not ``"middle"``.
    """
    if max_length < 0 or len(input) <= max_length:
        return input
    if manner != "middle":
        # Previously this returned None, which the caller then passed to len().
        raise ValueError(f"unsupported truncation manner: {manner!r}")
    split = max_length // 2
    # Slice the tail from an explicit start index: `input[-split:]` returns the
    # whole sequence when `split` is 0 (max_length of 0 or 1).
    return input[:split] + input[len(input) - split:]


def truncate_by_tokens(input, tok, max_tokens, manner: str = "middle"):
    """Encode ``input`` with ``tok.encode`` and truncate the token list via ``truncate_input``.

    Returns the (possibly shortened) token list. The assertions check that truncation
    never lengthens the sequence and that the limit is respected unless ``max_tokens``
    is negative (no limit).
    """
    encoded = tok.encode(input)
    original_count = len(encoded)
    truncated = truncate_input(encoded, max_length=max_tokens, manner=manner)
    truncated_count = len(truncated)
    assert truncated_count <= original_count
    assert max_tokens < 0 or truncated_count <= max_tokens
    return truncated




def indexing(key, sort_idx, block_size, value=None):
  """Gather the rows selected by ``sort_idx`` and zero-pad them to a multiple of ``block_size``.

  Args:
    key: 4-D tensor; rows are gathered along dim 2.
    sort_idx: 3-D integer index tensor into dim 2 of ``key``. It is expanded over the
      last dimension of ``key``, so the same rows are picked for every feature.
    block_size: Positive block length; the gathered rows are padded at the end of
      dim 2 until their count is a multiple of it.
    value: Optional tensor gathered and padded with the same indices. It is indexed
      with ``key``'s last-dimension size, so it must share that size.

  Returns:
    ``(out_key, out_value)``; ``out_value`` is ``None`` when ``value`` is ``None``.

  Raises:
    ValueError: If ``block_size`` is not positive (this replaces a leftover
      ``pdb.set_trace()`` that only triggered for such values).
  """
  if block_size <= 0:
    raise ValueError(f"block_size must be a positive integer, got {block_size}")
  n_selected = sort_idx.shape[-1]
  indices = sort_idx.unsqueeze(-1).expand(-1, -1, -1, key.shape[-1])
  pad_rows = math.ceil(n_selected / block_size) * block_size - n_selected
  # F.pad takes (last-dim left, last-dim right, dim-2 left, dim-2 right): pad only the end of dim 2.
  pad_spec = (0, 0, 0, pad_rows)
  out_key = torch.nn.functional.pad(key.gather(2, indices), pad_spec, mode='constant', value=0.)
  out_value = None
  if value is not None:
    out_value = torch.nn.functional.pad(value.gather(2, indices), pad_spec, mode='constant', value=0.)
  return out_key, out_value

def balanced_walk(key, rng, gamma_, temp_, beta_, itrs, block_size, value=None, sort_idx=None, query=None):
    """Repeatedly halve a set of keys with a per-block randomized signing walk.

    Each of the `itrs` rounds pads the currently selected keys to a multiple of
    `block_size[t]`, splits them into blocks, gives every position in a block a
    sign in {+1, -1} one position at a time, turns the signs into weights and
    keeps the `n // 2` positions with the largest weights.

    Args:
        key: 4-D tensor unpacked as (b, h, n, d).
        rng: Generator passed to `torch.rand` for the sign draws.
        gamma_: Scalar or list (one entry per round); a non-list is repeated `itrs` times.
        temp_, beta_: Used in the key kernel exp(temp_ * <k_i, k_j> / sqrt(d) - beta_),
            which is only built when `query` is None.
        itrs: Number of halving rounds.
        block_size: Int or list (one entry per round); a non-list is repeated `itrs` times.
        value: Optional tensor padded/gathered/viewed exactly like `key`; when given,
            the kernel is multiplied by the per-block value Gram matrix plus constants.
        sort_idx: Optional existing selection (indices into dim 2 of `key`) to continue from.
        query: Optional tensor; when given the kernel is built from softmax query-key
            scores instead of the exponential key kernel.

    Returns:
        (sort_idx, weight_idx): indices into dim 2 of `key` for the kept positions and
        the product of the per-round weights of those positions. `weight_idx` is still
        None if no round assigned it (e.g. `itrs == 0` or the early `break` below).
    """
    b, h, n, d = key.shape
    if type(gamma_) != list:
        gamma_ = [gamma_] * itrs
    const_denom = 0.025 # change this to 0.00 to change the kernel back

    if type(block_size) != list:
        block_size = [block_size] * itrs
    weight_idx = None
    for t in range(itrs): #write range(1, itrs) to check everything still works
        if sort_idx is not None:
            # Continue from an existing selection: gather the selected rows and pad to a block multiple.
            key_sorted, value_sorted = indexing(key, sort_idx, block_size[t], value)
            key_sorted = key_sorted.view(b, h, -1, block_size[t], d)
            if value is not None:
                # Scale the gathered values by the weights accumulated in earlier rounds.
                # NOTE(review): weight_idx is None here if the caller supplied sort_idx
                # on the first round — confirm that combination is never used with value.
                weight_idx_padded = torch.nn.functional.pad(weight_idx, (0, math.ceil(n / block_size[t]) * block_size[t] - weight_idx.shape[-1]))
                value_sorted = value_sorted*weight_idx_padded.unsqueeze(-1)
                value_sorted = value_sorted.view(b, h, -1, block_size[t], d)
        else:
            # First round without a prior selection: pad the raw keys/values and split into blocks.
            new_n = math.ceil(n / block_size[t]) * block_size[t]
            key_sorted = torch.nn.functional.pad(key, (0,0,0,new_n-n), mode='constant', value=0.).view(b, h, -1, block_size[t], d)
            value_sorted = None
            if value is not None:
                value_sorted = torch.nn.functional.pad(value, (0,0,0,new_n-n), mode='constant', value=0.).view(b, h, -1, block_size[t], d)

        # Center the keys inside each block (the mean includes any zero padding rows).
        normal_keys = key_sorted - torch.mean(key_sorted, dim=-2, keepdim=True)

        if query is not None:
            # query[:, ::4] keeps every 4th query head — presumably to match the number
            # of key heads under grouped-query attention; TODO confirm.
            query_key_correlation = torch.softmax(torch.einsum('b h n d,b h s m d->b h s n m',query[:,::4,:,:],normal_keys),dim=-1).mean(-2,keepdim=True)
            kernel_ = query_key_correlation*query_key_correlation.transpose(-1,-2)
        else:
            kernel_ = torch.exp(temp_ * torch.einsum('...nd,...sd->...ns', normal_keys, normal_keys)/math.sqrt(d) - beta_)
        if value is not None:
            kernel_ *= (1e-8 + torch.einsum('...nd,...sd->...ns', value_sorted, value_sorted)+const_denom)

        # One sign per position in every block; the first position of each block is fixed to +1.
        signs = torch.zeros(kernel_.shape[:4], dtype=torch.int16, device=kernel_.device)
        signs[:, :, :, 0] = 1
        rand_tensor = torch.rand(signs.shape, generator=rng, device=key.device)

        for i in range(1, kernel_.shape[3]):
            # Signed sum of kernel row i over the signs chosen so far (unassigned signs are still 0).
            partial_inner_prod = (kernel_[:, :, :, i, :] * signs).sum(dim=-1)
            # Position i gets +1 when its uniform draw falls below samp_prb, otherwise -1.
            samp_prb = 0.5 - gamma_[t] * partial_inner_prod
            signs[:, :, :, i] = 2 * (rand_tensor[:, :, :, i] < samp_prb) - 1

        # Flatten the blocks back into one sequence and drop the padding positions.
        signs = signs.view(b, h, -1)[:, :, :n]

        if signs.shape[-1]==0: # simply to deal with n==0
            sort_idx = signs[:, :, :0]
            # NOTE(review): "weigth_idx" is misspelled, so `weight_idx` is NOT updated here
            # and keeps its previous value (None on the first round). manual_forward_llama
            # checks the returned weight for None to detect this case, so confirm the
            # caller before correcting the name.
            weigth_idx = signs[:, :, :0]
            break
        cumsum_neg = (signs == -1).cumsum(dim=-1)
        cumsum_pos = (signs == 1).cumsum(dim=-1)

        # First position at which the running count of -1 (resp. +1) equals n//2;
        # argmax returns 0 when that count is never reached.
        c_neg = torch.argmax((cumsum_neg == n//2).to(torch.int64), dim=-1) # Shape (b, h)
        c_pos = torch.argmax((cumsum_pos == n//2).to(torch.int64), dim=-1) # Shape (b, h)
        c = torch.maximum(c_neg, c_pos)

        c = c.to(signs.device)

        # `weight` aliases `signs` (no copy): every in-place edit below also changes `signs`.
        weight = signs

        # Create an index tensor `[0, 1, ..., n-1]` for comparison
        indices = torch.arange(signs.shape[2], device=signs.device).view(1, 1, -1)
        # Set all values after `c[a, b]` to `1`
        mask_after_c = indices > c.unsqueeze(-1)  # True for all d > c[a, b]
        weight[mask_after_c] = torch.abs(weight[mask_after_c])  # Set those indices to `1`
        # Identify where `signs[a, b, c[a, b]] == 1`
        # (position c itself has not been modified yet, so this reads the original sign)
        mask_flip_needed = (signs.gather(2, c.unsqueeze(-1)) == 1).squeeze(-1)
        # Create mask for all indices `<= c[a, b]`
        mask_before_c = indices <= c.unsqueeze(-1)
        weight[mask_before_c] *= 2
        # Apply flipping only when `signs[a, b, c] == 1`
        flip_mask = mask_before_c & mask_flip_needed.unsqueeze(-1)
        weight[flip_mask] *= -1  # Flip selected values

        # Stable descending order by weight; the first n//2 entries are kept below.
        weight_argsort = torch.argsort(-weight, dim=-1, stable=True)

        n = n//2
        if sort_idx is None:
            sort_idx = weight_argsort[:, :, :n]
            weight_idx = weight.gather(-1, weight_argsort[:, :, :n])
        else:
            # Compose with the earlier selection and multiply the per-round weights together.
            sort_idx = sort_idx.gather(2, weight_argsort[:, :, :n])
            weigth_idx_1 = weight.gather(-1, weight_argsort[:, :, :n])
            weight_idx = weight_idx.gather(-1, weight_argsort[:, :, :n])
            weight_idx = weight_idx*weigth_idx_1

    return sort_idx, weight_idx
def snap_kv(query_states, key_states, value_states, capacity_option='default', window_size=32, pooling='avgpool', kernel_size=5, topk_num=100):
    """Pick the past key positions that the last ``window_size`` queries attend to most.

    The last ``window_size`` queries are scored against all keys (causally masked
    inside the window), softmax-normalised, summed over the window, smoothed with a
    1-D pooling filter, and the highest-scoring positions before the window are returned.

    Args:
        query_states: 4-D tensor; dim 1 is the head dimension and dim 2 the sequence.
        key_states: 4-D tensor laid out like ``query_states``. When it has fewer heads
            than the queries it is expanded with ``repeat_kv``.
        value_states: Accepted for interface compatibility; not needed to compute indices.
        capacity_option: ``'two_stage_bw'`` selects
            ``2 * (max(int(seq_len * 3.875 / 64), window_size + 4) - window_size)`` positions;
            any other value selects ``topk_num`` positions.
        window_size: Number of trailing queries (and trailing keys excluded from selection).
        pooling: ``'avgpool'`` or ``'maxpool'``.
        kernel_size: Pooling kernel size (stride 1, padding ``kernel_size // 2``).
        topk_num: Number of positions to return unless ``capacity_option == 'two_stage_bw'``.

    Returns:
        Integer tensor of indices into the first ``seq_len - window_size`` key positions,
        with the selected positions along the last dimension.

    Raises:
        ValueError: If ``pooling`` is not supported.
    """
    seq_len = key_states.shape[2]

    # Only the keys take part in the scoring, so the values are never expanded.
    if query_states.shape[1] != key_states.shape[1]:
        num_key_value_groups = query_states.shape[1] // key_states.shape[1]
        kk = repeat_kv(key_states, num_key_value_groups)
    else:
        kk = key_states

    dim = query_states.shape[-1]
    attn_weights = query_states[..., -window_size:, :] @ kk.transpose(2, 3) / dim**0.5

    # Causal mask for the window-vs-window corner: 0 where key <= query, dtype-min elsewhere.
    mask = torch.full((window_size, window_size), torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
    mask_cond = torch.arange(mask.size(-1), device=attn_weights.device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    attn_weights[:, :, -window_size:, -window_size:] += mask[None, None, :, :]

    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    # Total attention each pre-window key receives from the window queries.
    attn_weights_sum = attn_weights[:, :, -window_size:, :-window_size].sum(dim=-2)
    if pooling == "avgpool":
        attn_cache = torch.nn.functional.avg_pool1d(attn_weights_sum, kernel_size=kernel_size, padding=kernel_size // 2, stride=1)
    elif pooling == "maxpool":
        attn_cache = torch.nn.functional.max_pool1d(attn_weights_sum, kernel_size=kernel_size, padding=kernel_size // 2, stride=1)
    else:
        raise ValueError("Pooling method not supported")

    if capacity_option == 'two_stage_bw':
        max_capacity_prompt = max(int(seq_len * 3.875 / 64), window_size + 4)
        num_selected = 2 * (max_capacity_prompt - window_size)
    else:
        num_selected = topk_num
    return attn_cache.topk(num_selected, dim=-1).indices



def manual_forward_llama(
    model,
    input_ids,
    kv_cache=None, position_ids=None, cache_position=None, num_logits_to_keep=0,
    kv_type='bw',
    unif = False,
    layer = 0,
    **kwargs
):
    """Run the decoder stack up to ``layer`` and compare compressed vs. exact attention there.

    Layers before ``layer`` are evaluated with exact causal attention. At ``layer`` the
    keys/values between the first and last ``recent_size`` positions are compressed with
    ``balanced_walk``, and the attention output of the last ``recent_size`` queries is
    computed against both the compressed and the full cache.

    Reads ``rng``, ``gamma``, ``temp``, ``beta``, ``itrs``, ``block_size`` and
    ``recent_size`` from ``model.config``. Projections are reshaped with a hard-coded
    batch size of 1, head dimension 128 and 8 key/value heads.

    Args:
        model: Causal LM exposing ``model.model.embed_tokens`` and ``model.model.layers``.
        input_ids: Token ids; only the first batch row is used to derive default positions.
        position_ids: Optional positions; defaults to ``0..len-1``.
        layer: 0-based index of the decoder layer at which the comparison is made.
        kv_cache, cache_position, num_logits_to_keep, kv_type, unif, **kwargs:
            Accepted for interface compatibility; unused.

    Returns:
        ``(attn_output_bw, attn_output_exact, total_time)``, where ``total_time`` is the
        wall-clock time (``time.time``, no device synchronisation) spent on selection
        plus the compressed attention call.

    Raises:
        ValueError: If ``layer`` is not a valid index into ``model.model.layers``
            (previously every layer was executed and ``None`` was returned).
    """
    num_layers = len(model.model.layers)
    if not 0 <= layer < num_layers:
        raise ValueError(f"layer must be in [0, {num_layers}), got {layer}")

    # Compare (major, minor) so that releases after 4.x also take the shared-rotary path.
    major_minor = tuple(int(part) for part in transformers.__version__.split(".")[:2])
    shared_rotary = major_minor >= (4, 48)

    hh = model.model.embed_tokens(input_ids)
    if position_ids is None:
        position_ids = torch.arange(len(input_ids[0]), device=input_ids.device).unsqueeze(0)
    if shared_rotary:
        position_embeddings = model.model.rotary_emb(hh, position_ids)

    for i, decoder_layer in enumerate(model.model.layers):
        res = hh.detach().clone()
        hh = decoder_layer.input_layernorm(hh)

        q_len = hh.shape[1]
        kv_len = q_len
        qq = decoder_layer.self_attn.q_proj(hh).reshape(1, q_len, -1, 128).transpose(1, 2)
        kk = decoder_layer.self_attn.k_proj(hh).reshape(1, kv_len, 8, 128).transpose(1, 2)
        vv = decoder_layer.self_attn.v_proj(hh).reshape(1, kv_len, 8, 128).transpose(1, 2)

        if shared_rotary:
            cos, sin = position_embeddings
        else:
            cos, sin = decoder_layer.self_attn.rotary_emb(vv, position_ids)
        qq, kk = apply_rotary_pos_emb(qq, kk, cos, sin)

        if i == layer:
            config = model.config
            itrs = config.itrs
            recent_size = config.recent_size

            # Split the cache into a kept prefix, a compressible middle and a kept suffix.
            kk_first = kk[:, :, :recent_size, :]
            vv_first = vv[:, :, :recent_size, :]
            kk_old = kk[:, :, recent_size:-recent_size, :]
            vv_old = vv[:, :, recent_size:-recent_size, :]
            kk_recent = kk[:, :, -recent_size:, :]
            vv_recent = vv[:, :, -recent_size:, :]

            start_time = time.time()
            sort_idx_bw, weight_idx_bw = balanced_walk(
                key=kk_old, value=vv_old, rng=config.rng, gamma_=config.gamma,
                temp_=config.temp, beta_=config.beta, itrs=itrs, block_size=config.block_size,
            )

            bsz, n_heads, _, dim = kk.shape
            n_centers_bw = sort_idx_bw.shape[-1]
            gather_idx = sort_idx_bw.unsqueeze(-1).expand(bsz, n_heads, n_centers_bw, dim)
            kk_old_bw = torch.gather(kk_old, 2, gather_idx)
            vv_old_bw = torch.gather(vv_old, 2, gather_idx)

            # balanced_walk returns no weights when it had nothing to select (n == 0).
            if weight_idx_bw is not None:
                weight_idx_bw_num = weight_idx_bw / 2**itrs
                vv_old_bw_num = (vv_old_bw * weight_idx_bw_num.unsqueeze(-1)).to(torch.bfloat16)
            else:
                vv_old_bw_num = vv_old_bw

            # The selected keys/values are tiled 2**itrs times between the kept prefix and suffix.
            repeats = 2**itrs
            kk_selected_bw_num = torch.cat((kk_first, torch.cat([kk_old_bw] * repeats, dim=2), kk_recent), dim=2)
            vv_selected_bw_num = torch.cat((vv_first, torch.cat([vv_old_bw_num] * repeats, dim=2), vv_recent), dim=2)

            qq_latest = qq[:, :, -recent_size:, :]
            attn_output_bw = flash_attn_func(qq_latest.transpose(1,2), kk_selected_bw_num.transpose(1,2), vv_selected_bw_num.transpose(1,2), causal=True)
            total_time = time.time() - start_time
            attn_output_exact = flash_attn_func(qq_latest.transpose(1,2), kk.transpose(1,2), vv.transpose(1,2), causal=True)

            if weight_idx_bw is None:
                print(f"{attn_output_bw}, {attn_output_bw}, {attn_output_bw}, {attn_output_exact} the loop happened")
                # NOTE(review): this path returns the compressed output in the "exact" slot
                # as well (behaviour kept as-is) — confirm whether attn_output_exact was intended.
                return attn_output_bw, attn_output_bw, total_time
            return attn_output_bw, attn_output_exact, total_time

        # Exact attention is only needed for the layers that feed the next one.
        if q_len > 1:
            attn_output = flash_attn_func(qq.transpose(1,2), kk.transpose(1,2), vv.transpose(1,2), causal=True)
        attn_output = attn_output.contiguous().view(qq.shape[0], qq.shape[2], -1)
        hh = decoder_layer.self_attn.o_proj(attn_output)
        hh = res + hh

        res = hh.detach().clone()
        hh = decoder_layer.post_attention_layernorm(hh)
        hh = decoder_layer.mlp(hh)
        hh = res + hh



@torch.no_grad()
def greedy_generate(self, input_ids, max_new_tokens, eos_token_id=128009, kv_type="bw", layer = 0, unif = False,**kwargs):
    """Run ``manual_forward_llama`` once on the prompt and return its two attention outputs.

    Despite the name, no tokens are generated: ``max_new_tokens``, ``eos_token_id`` and
    ``**kwargs`` are accepted but unused, and the timing value is dropped.
    """
    prompt_length = input_ids.shape[-1]
    position_ids = torch.arange(prompt_length, device=input_ids.device).unsqueeze(0)
    outputs = manual_forward_llama(
        self, input_ids, position_ids=position_ids, num_logits_to_keep=1,
        layer=layer, kv_type=kv_type, unif=unif,
    )
    attn_bw, attn_exact, _elapsed = outputs
    return attn_bw, attn_exact


def main(gamma_1=4,gamma_2 = 4 ,itrs = 2,layer=0,temp=0.1,model_name = "meta-llama/Llama-3.1-8B-Instruct"):
    """Measure the relative attention error of BalanceKV at one layer on LongBench TriviaQA.

    For each of ten fixed seeds, every example is formatted with the dataset prompt,
    tokenised, truncated, and passed through ``greedy_generate``; the per-example error
    is the mean over the returned positions/heads of
    ``||attn_bw - attn_exact|| / ||attn_exact||`` (norms over the last dimension).

    Args:
        gamma_1: ``gamma`` for the first balancing round.
        gamma_2: ``gamma`` for the remaining rounds.
        itrs: Number of halving rounds stored in the model config.
        layer: 1-based decoder layer number; it is converted to a 0-based index, so the
            default of 0 does not address any layer.
        temp: Kernel temperature stored in the model config.
        model_name: Hugging Face model id.

    Returns:
        List with one average error per seed.
    """
    layer = layer - 1
    print(f"model : {model_name}")

    # Load the model and tokenizer once. `token=True` uses the credentials stored by
    # `huggingface-cli login` (it replaces the deprecated `use_auth_token=True`).
    model = AutoModelForCausalLM.from_pretrained(
        model_name, device_map='auto', torch_dtype=torch.bfloat16,
        _attn_implementation='flash_attention_2', token=True,
    )
    model = model.eval().requires_grad_(False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, token=True)

    config = model.model.config
    config.gamma = [gamma_1, gamma_2, gamma_2, gamma_2]
    config.beta = 0.
    config.temp = temp
    config.block_size = [256, 256, 256, 256]
    config.itrs = itrs
    config.recent_size = 256
    seeds = [42,7,8,17,98,64,27,32,81,44]
    dataset = "triviaqa"
    print(f"dataset: {dataset}")
    prompt_format = DATASET2PROMPT[dataset]
    maxlen = DATASET2MAXLEN[dataset]

    examples = load_dataset('THUDM/LongBench', f"{dataset}_e", split='test')
    max_input_length = 100_000
    terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
    errors_bw_global = []
    for seed_number, seed in enumerate(seeds, start=1):
        print(f"current seed number:{seed_number}")
        config.rng = torch.Generator('cuda').manual_seed(seed)
        errors_bw = []

        for i, eg in enumerate(examples):
            # NOTE: the raw prompt text is tokenised directly; the chat template result
            # computed by the previous version was overwritten before use.
            input_text = prompt_format.format(**eg)
            input_tokens = truncate_by_tokens(input_text, tokenizer, max_input_length)
            input_ids = torch.tensor(input_tokens).unsqueeze(0).to(device)

            attn_bw, exact_attn = greedy_generate(model, input_ids, max_new_tokens=maxlen, eos_token_id=terminators, kv_type="bw", layer=layer, return_dict_in_generate=True)

            # Clamp the norm, not the tensor: an element-wise clamp would replace every
            # negative entry of the exact output with 1e-8 and distort the denominator.
            exact_norm = torch.norm(exact_attn, dim=-1).clamp(min=1e-8)
            error_bw = (torch.norm(attn_bw - exact_attn, dim=-1) / exact_norm).mean()
            errors_bw.append(error_bw.item())

            if (i%100==0 and i>0) or i==len(examples)-1:
                print(f"   layer:{layer}, seed:{seed} | examples processed:{i}/{len(examples)}| avg bw error so far:{sum(errors_bw)/len(errors_bw)} ")

            torch.cuda.empty_cache()

        torch.cuda.empty_cache()

        errors_bw_global.append(sum(errors_bw)/len(errors_bw))

    return errors_bw_global

In [None]:
# Choose the layer to evaluate (1, 2 or 5); main() treats this number as 1-based.
# For a compression rate x in [0.5, 0.25, 0.125, 0.0625], set itrs = log_2(1/x);
# e.g. for compression rate 0.5, itrs = log_2(1/0.5) = 1.
# Running this cell then prints the mean and standard deviation, over seeds, of the
# relative error of BalanceKV (denoted "bw") for that layer and compression rate.
model_name = "meta-llama/Llama-3.1-8B-Instruct"
#model_name = "mistralai/Ministral-8B-Instruct-2410" #change to this model name for mistral
layer = 1
itrs = 1
torch.cuda.empty_cache()
bw= main(gamma_1=4.0,gamma_2=4.0,itrs=itrs,layer=layer,model_name=model_name)
bw_avg,bw_stddev = avg_stddev(bw)
print(f"Average BW: {bw_avg}| Standard Deviation BW: {bw_stddev}")

Single-layer attention approximation for uniform sampling

In [None]:
## next, execute this cell

import math
import re
import string
import time
from collections import Counter
from typing import List, Optional

import torch
import transformers
from datasets import load_dataset
from flash_attn import flash_attn_func
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv



# Prompt template per dataset; main() fills it with str.format(**example), so the
# example must provide the `context` and `input` fields named in the template.
DATASET2PROMPT = {"triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}"}
# Per-dataset value that main() looks up as `maxlen`.
DATASET2MAXLEN = {"triviaqa": 32,  }


def avg_stddev(data):
    """Return ``(mean, population standard deviation)`` of a non-empty sequence of numbers.

    The variance is divided by ``len(data)`` (population form, not ``len(data) - 1``).
    """
    count = len(data)
    mean = sum(data) / count
    squared_deviations = [(value - mean) ** 2 for value in data]
    return mean, math.sqrt(sum(squared_deviations) / count)


def normalize_answer(s):
    """Lower-case ``s``, drop ASCII punctuation, remove the articles a/an/the and collapse whitespace."""
    lowered = s.lower()
    # Deleting via a translation table removes exactly the characters in string.punctuation.
    unpunctuated = lowered.translate(str.maketrans("", "", string.punctuation))
    without_articles = re.sub(r"\b(a|an|the)\b", " ", unpunctuated)
    return " ".join(without_articles.split())


def truncate_input(input: list, max_length: int, manner="middle"):
    """Shorten ``input`` to at most ``max_length`` items by cutting out its middle.

    Args:
        input: Sequence to truncate (anything supporting ``len``, slicing and ``+``).
        max_length: Maximum number of items to keep; a negative value disables truncation.
        manner: Truncation strategy; only ``"middle"`` is supported.

    Returns:
        ``input`` itself when no truncation is needed, otherwise the first and last
        ``max_length // 2`` items concatenated (so an odd ``max_length`` keeps
        ``max_length - 1`` items).

    Raises:
        ValueError: If truncation is required and ``manner`` is not ``"middle"``.
    """
    if max_length < 0 or len(input) <= max_length:
        return input
    if manner != "middle":
        # Previously this returned None, which the caller then passed to len().
        raise ValueError(f"unsupported truncation manner: {manner!r}")
    split = max_length // 2
    # Slice the tail from an explicit start index: `input[-split:]` returns the
    # whole sequence when `split` is 0 (max_length of 0 or 1).
    return input[:split] + input[len(input) - split:]


def truncate_by_tokens(input, tok, max_tokens, manner: str = "middle"):
    """Encode ``input`` with ``tok.encode`` and truncate the token list via ``truncate_input``.

    Returns the (possibly shortened) token list. The assertions check that truncation
    never lengthens the sequence and that the limit is respected unless ``max_tokens``
    is negative (no limit).
    """
    encoded = tok.encode(input)
    original_count = len(encoded)
    truncated = truncate_input(encoded, max_length=max_tokens, manner=manner)
    truncated_count = len(truncated)
    assert truncated_count <= original_count
    assert max_tokens < 0 or truncated_count <= max_tokens
    return truncated




def indexing(key, sort_idx, block_size, value=None):
  """Gather the rows selected by ``sort_idx`` and zero-pad them to a multiple of ``block_size``.

  Args:
    key: 4-D tensor; rows are gathered along dim 2.
    sort_idx: 3-D integer index tensor into dim 2 of ``key``. It is expanded over the
      last dimension of ``key``, so the same rows are picked for every feature.
    block_size: Positive block length; the gathered rows are padded at the end of
      dim 2 until their count is a multiple of it.
    value: Optional tensor gathered and padded with the same indices. It is indexed
      with ``key``'s last-dimension size, so it must share that size.

  Returns:
    ``(out_key, out_value)``; ``out_value`` is ``None`` when ``value`` is ``None``.

  Raises:
    ValueError: If ``block_size`` is not positive (this replaces a leftover
      ``pdb.set_trace()`` that only triggered for such values).
  """
  if block_size <= 0:
    raise ValueError(f"block_size must be a positive integer, got {block_size}")
  n_selected = sort_idx.shape[-1]
  indices = sort_idx.unsqueeze(-1).expand(-1, -1, -1, key.shape[-1])
  pad_rows = math.ceil(n_selected / block_size) * block_size - n_selected
  # F.pad takes (last-dim left, last-dim right, dim-2 left, dim-2 right): pad only the end of dim 2.
  pad_spec = (0, 0, 0, pad_rows)
  out_key = torch.nn.functional.pad(key.gather(2, indices), pad_spec, mode='constant', value=0.)
  out_value = None
  if value is not None:
    out_value = torch.nn.functional.pad(value.gather(2, indices), pad_spec, mode='constant', value=0.)
  return out_key, out_value


def balanced_walk(key, rng, gamma_, temp_, beta_, itrs, block_size,unif, value=None, sort_idx=None,query=None):
  """Repeatedly halve a set of keys with a per-block randomized signing walk.

  Each of the ``itrs`` rounds pads the currently selected keys to a multiple of
  ``block_size[t]``, splits them into blocks, gives every position in a block a sign
  in {+1, -1} one position at a time, and keeps the first ``n // 2`` positions of a
  stable ascending sort of the signs (the -1 positions come first).

  With ``gamma_[t] == 0`` every sign after the first of a block is +1 with
  probability 0.5, independently of the keys.

  Args:
    key: 4-D tensor unpacked as (b, h, n, d).
    rng: Generator passed to ``torch.rand`` for the sign draws.
    gamma_: Sequence indexed by round.
    temp_, beta_: Used in the key kernel exp(temp_ * <k_i, k_j> / sqrt(d) - beta_),
      which is only built when ``query`` is None.
    itrs: Number of halving rounds.
    block_size: Sequence of block lengths indexed by round.
    unif: Accepted for interface compatibility; unused.
    value: Optional tensor padded/gathered/viewed exactly like ``key``; when given,
      the kernel is multiplied by the per-block value Gram matrix (plus 1e-8).
    sort_idx: Optional existing selection (indices into dim 2 of ``key``) to continue from.
    query: Optional tensor; when given the kernel is built from softmax query-key scores.

  Returns:
    Index tensor into dim 2 of ``key`` for the kept positions (``sort_idx`` unchanged
    when ``itrs == 0``).
  """
  b, h, n, d = key.shape

  for t in range(itrs):
    block = block_size[t]
    if sort_idx is not None:
      # Continue from the previous selection: gather the selected rows and pad to a block multiple.
      key_sorted, value_sorted = indexing(key, sort_idx, block, value)
      key_sorted = key_sorted.view(b, h, -1, block, d)
      if value is not None:
        value_sorted = value_sorted.view(b, h, -1, block, d)
    else:
      new_n = math.ceil(n / block) * block
      key_sorted = torch.nn.functional.pad(key, (0,0,0,new_n-n), mode='constant', value=0.).view(b, h, -1, block, d)
      value_sorted = None
      if value is not None:
        value_sorted = torch.nn.functional.pad(value, (0,0,0,new_n-n), mode='constant', value=0.).view(b, h, -1, block, d)

    # Center the keys inside each block (the mean includes any zero padding rows).
    normal_keys = key_sorted - torch.mean(key_sorted, dim=-2, keepdim=True)
    if query is not None:
      # query[:, ::4] keeps every 4th query head — presumably to match the number of
      # key heads under grouped-query attention; TODO confirm.
      query_key_correlation = torch.softmax(torch.einsum('b h n d,b h s m d->b h s n m',query[:,::4,:,:],normal_keys),dim=-1).mean(-2,keepdim=True)
      kernel_ = query_key_correlation*query_key_correlation.transpose(-1,-2)
    else:
      kernel_ = torch.exp(temp_ * torch.einsum('...nd,...sd->...ns', normal_keys, normal_keys)/math.sqrt(d) - beta_)
    if value is not None:
      kernel_ *= (1e-8 + torch.einsum('...nd,...sd->...ns', value_sorted, value_sorted))

    # One sign per position in every block; the first position of each block is fixed to +1.
    signs = torch.zeros(kernel_.shape[:4], dtype=torch.int16, device=kernel_.device)
    signs[:, :, :, 0] = 1
    rand_tensor = torch.rand(signs.shape, generator=rng, device=key.device)
    for i in range(1, kernel_.shape[3]):
      # Signed sum of kernel row i over the signs chosen so far (unassigned signs are still 0).
      partial_inner_prod = (kernel_[:, :, :, i, :] * signs).sum(dim=-1)
      samp_prb = 0.5 - gamma_[t] * partial_inner_prod
      signs[:, :, :, i] = 2 * (rand_tensor[:, :, :, i] < samp_prb) - 1

    # Flatten the blocks back into one sequence and drop the padding positions.
    signs = signs.view(b, h, -1)[:, :, :n]
    signs_argsort = torch.argsort(signs, dim=-1, stable=True)

    n = n//2
    if sort_idx is None:
      sort_idx = signs_argsort[:, :, :n]
    else:
      sort_idx = sort_idx.gather(2, signs_argsort[:, :, :n])

  return sort_idx



def manual_forward_llama(
    model,
    input_ids,
    kv_cache=None, position_ids=None, cache_position=None, num_logits_to_keep=0,
    kv_type='bw',
    unif = False,
    layer = 0,
    **kwargs
):
    """Run the decoder stack up to ``layer`` and compare two compressed caches with exact attention.

    Layers before ``layer`` are evaluated with exact causal attention. At ``layer`` the
    keys/values between the first and last ``recent_size`` positions are subsampled
    twice with ``balanced_walk`` — once with the configured ``gamma`` and once with
    ``gamma`` set to zero — and the attention output of the last ``recent_size``
    queries is computed against each compressed cache and against the full cache.

    Reads ``rng``, ``gamma``, ``temp``, ``beta``, ``itrs``, ``block_size`` and
    ``recent_size`` from ``model.config``. Projections are reshaped with a hard-coded
    batch size of 1, head dimension 128 and 8 key/value heads. Requires the
    ``transformers`` module to be imported for the version check.

    Args:
        model: Causal LM exposing ``model.model.embed_tokens`` and ``model.model.layers``.
        input_ids: Token ids; only the first batch row is used to derive default positions.
        position_ids: Optional positions; defaults to ``0..len-1``.
        unif: Forwarded to ``balanced_walk``.
        layer: 0-based index of the decoder layer at which the comparison is made.
        kv_cache, cache_position, num_logits_to_keep, kv_type, **kwargs:
            Accepted for interface compatibility; unused.

    Returns:
        ``(attn_output_bw, attn_output_unif, attn_output_exact)``.

    Raises:
        ValueError: If ``layer`` is not a valid index into ``model.model.layers``
            (previously every layer was executed and ``None`` was returned).
    """
    num_layers = len(model.model.layers)
    if not 0 <= layer < num_layers:
        raise ValueError(f"layer must be in [0, {num_layers}), got {layer}")

    # Compare (major, minor) so that releases after 4.x also take the shared-rotary path.
    major_minor = tuple(int(part) for part in transformers.__version__.split(".")[:2])
    shared_rotary = major_minor >= (4, 48)

    hh = model.model.embed_tokens(input_ids)
    if position_ids is None:
        position_ids = torch.arange(len(input_ids[0]), device=input_ids.device).unsqueeze(0)
    if shared_rotary:
        position_embeddings = model.model.rotary_emb(hh, position_ids)

    for i, decoder_layer in enumerate(model.model.layers):
        res = hh.detach().clone()
        hh = decoder_layer.input_layernorm(hh)

        q_len = hh.shape[1]
        kv_len = q_len
        qq = decoder_layer.self_attn.q_proj(hh).reshape(1, q_len, -1, 128).transpose(1, 2)
        kk = decoder_layer.self_attn.k_proj(hh).reshape(1, kv_len, 8, 128).transpose(1, 2)
        vv = decoder_layer.self_attn.v_proj(hh).reshape(1, kv_len, 8, 128).transpose(1, 2)

        if shared_rotary:
            cos, sin = position_embeddings
        else:
            cos, sin = decoder_layer.self_attn.rotary_emb(vv, position_ids)
        qq, kk = apply_rotary_pos_emb(qq, kk, cos, sin)

        if i == layer:
            config = model.config
            rng = config.rng
            itrs = config.itrs
            recent_size = config.recent_size

            # Split the cache into a kept prefix, a compressible middle and a kept suffix.
            kk_first = kk[:, :, :recent_size, :]
            vv_first = vv[:, :, :recent_size, :]
            kk_old = kk[:, :, recent_size:-recent_size, :]
            vv_old = vv[:, :, recent_size:-recent_size, :]
            kk_recent = kk[:, :, -recent_size:, :]
            vv_recent = vv[:, :, -recent_size:, :]

            # Both calls draw from the same generator, so their order must not change.
            sort_idx_bw = balanced_walk(key=kk_old, value=vv_old, rng=rng, gamma_=config.gamma, temp_=config.temp, beta_=config.beta, itrs=itrs, block_size=config.block_size, unif=unif)
            sort_idx_unif = balanced_walk(key=kk_old, value=vv_old, rng=rng, gamma_=[0.0,0.0,0.0,0.0], temp_=config.temp, beta_=config.beta, itrs=itrs, block_size=config.block_size, unif=unif)

            bsz, n_heads, _, dim = kk.shape
            repeats = 2**itrs
            qq_latest = qq[:, :, -recent_size:, :]

            def attend_selected(sort_idx):
                """Attention of the latest queries over prefix + tiled selection + suffix."""
                gather_idx = sort_idx.unsqueeze(-1).expand(bsz, n_heads, sort_idx.shape[-1], dim)
                kk_sel = torch.gather(kk_old, 2, gather_idx)
                vv_sel = torch.gather(vv_old, 2, gather_idx)
                kk_all = torch.cat((kk_first, torch.cat([kk_sel] * repeats, dim=2), kk_recent), dim=2)
                vv_all = torch.cat((vv_first, torch.cat([vv_sel] * repeats, dim=2), vv_recent), dim=2)
                return flash_attn_func(qq_latest.transpose(1,2), kk_all.transpose(1,2), vv_all.transpose(1,2), causal=True)

            attn_output_bw = attend_selected(sort_idx_bw)
            attn_output_unif = attend_selected(sort_idx_unif)
            attn_output_exact = flash_attn_func(qq_latest.transpose(1,2), kk.transpose(1,2), vv.transpose(1,2), causal=True)

            return attn_output_bw, attn_output_unif, attn_output_exact

        # Exact attention is only needed for the layers that feed the next one.
        if q_len > 1:
            attn_output = flash_attn_func(qq.transpose(1,2), kk.transpose(1,2), vv.transpose(1,2), causal=True)
        attn_output = attn_output.contiguous().view(qq.shape[0], qq.shape[2], -1)
        hh = decoder_layer.self_attn.o_proj(attn_output)
        hh = res + hh

        res = hh.detach().clone()
        hh = decoder_layer.post_attention_layernorm(hh)
        hh = decoder_layer.mlp(hh)
        hh = res + hh


@torch.no_grad()
def greedy_generate(self, input_ids, max_new_tokens, eos_token_id=128009, kv_type="bw", layer=0, unif=False, **kwargs):
    """Run one manual forward pass and hand back the three attention outputs.

    Despite its name, this function does not decode any tokens: it builds
    position ids for the prompt, calls ``manual_forward_llama`` once, and
    returns what that call produces.

    Args:
        self: the model object forwarded to ``manual_forward_llama``.
        input_ids: prompt token ids; the size of the last dimension is used
            as the sequence length.
        max_new_tokens: accepted but not used in this function.
        eos_token_id: accepted but not used in this function.
        kv_type: forwarded unchanged to ``manual_forward_llama``.
        layer: forwarded unchanged to ``manual_forward_llama``.
        unif: forwarded unchanged to ``manual_forward_llama``.
        **kwargs: accepted and ignored.

    Returns:
        The tuple ``(attn_bw, attn_unif, attn_exact)`` unpacked from the
        result of ``manual_forward_llama``.
    """
    seq_len = input_ids.shape[-1]
    # Positions 0..seq_len-1 with a leading axis of size 1, on the prompt's device.
    positions = torch.arange(seq_len, device=input_ids.device)[None, :]

    forward_out = manual_forward_llama(
        self,
        input_ids,
        position_ids=positions,
        num_logits_to_keep=1,
        layer=layer,
        kv_type=kv_type,
        unif=unif,
    )
    # Unpacking keeps the requirement that exactly three outputs come back.
    attn_bw, attn_unif, attn_exact = forward_out
    return attn_bw, attn_unif, attn_exact


def main(gamma_1=4, gamma_2=4, itrs=2, layer=0, temp=0.1, model_name="meta-llama/Llama-3.1-8B-Instruct"):
    """Measure the relative attention error of the uniform-sampling output on TriviaQA.

    For each of ten fixed seeds, every example of the LongBench ``triviaqa_e``
    test split is pushed through ``greedy_generate``. The ``attn_unif`` output
    is compared with the exact attention output via the relative error
    ``||attn_unif - exact|| / ||exact||`` taken over the last dimension and
    averaged over the remaining ones. The per-example errors are then averaged
    per seed.

    Args:
        gamma_1: first entry of the ``gamma`` list stored on the model config.
        gamma_2: value used for the remaining three entries of ``gamma``.
        itrs: stored on the model config as ``itrs``.
        layer: layer number; it is decremented by one before being passed to
            ``greedy_generate``.
        temp: stored on the model config as ``temp``.
        model_name: Hugging Face model identifier to load.

    Returns:
        A list with one float per seed: the mean relative error of the
        uniform-sampling attention output over all examples for that seed.
    """
    # The caller's layer number is shifted down by one before use
    # (presumably 1-based -> 0-based; confirm against manual_forward_llama).
    layer = layer - 1
    print(f"model : {model_name}")

    # Load the model and tokenizer exactly once. A preliminary default-precision
    # load whose result is thrown away only costs time and host memory.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map='auto',
        torch_dtype=torch.bfloat16,
        _attn_implementation='flash_attention_2',
    )
    model = model.eval().requires_grad_(False)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Hyper-parameters are attached to the model config for the forward pass to read.
    config = model.model.config
    config.gamma = [gamma_1, gamma_2, gamma_2, gamma_2]
    config.beta = 0.
    config.temp = temp
    config.block_size = [256, 256, 256, 256]
    config.itrs = itrs
    config.recent_size = 256

    seeds = [42, 7, 8, 17, 98, 64, 27, 32, 81, 44]
    dataset = "triviaqa"
    print(f"dataset: {dataset}")
    prompt_format = DATASET2PROMPT[dataset]
    maxlen = DATASET2MAXLEN[dataset]

    examples = load_dataset('THUDM/LongBench', f"{dataset}_e", split='test')
    num_examples = len(examples)
    max_input_length = 100_000
    errors_unif_global = []

    for seed_idx, seed in enumerate(seeds):
        print(f"current seed number:{seed_idx+1}")
        config.rng = torch.Generator('cuda').manual_seed(seed)
        errors_unif = []

        # A distinct index name keeps the seed counter from being shadowed.
        for eg_idx, eg in enumerate(examples):
            input_text = prompt_format.format(**eg)
            # The raw prompt string (no chat template) is what gets truncated and fed to the model.
            input_tokens = truncate_by_tokens(input_text, tokenizer, max_input_length)
            input_ids = torch.tensor(input_tokens).unsqueeze(0).to(model.device)

            terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
            # The first returned tensor (attn_bw) is not needed: only the unif error is reported.
            _, attn_unif, exact_attn = greedy_generate(
                model,
                input_ids,
                max_new_tokens=maxlen,
                eos_token_id=terminators,
                kv_type="bw",
                layer=layer,
                return_dict_in_generate=True,
            )

            # Clamp the NORM, not the tensor entries: clamping the entries would turn
            # every negative component into 1e-8 and corrupt the denominator.
            exact_norm = torch.norm(exact_attn, dim=-1).clamp(min=1e-8)
            error_unif = (torch.norm(attn_unif - exact_attn, dim=-1) / exact_norm).mean()
            errors_unif.append(error_unif.item())

            if (eg_idx % 100 == 0 and eg_idx > 0) or eg_idx == num_examples - 1:
                print(f"   layer:{layer}, seed:{seed} | examples processed:{eg_idx}/{num_examples}| avg unif error so far:{sum(errors_unif)/len(errors_unif)}")

            torch.cuda.empty_cache()

        torch.cuda.empty_cache()

        errors_unif_global.append(sum(errors_unif) / len(errors_unif))

    return errors_unif_global

In [None]:
# Driver cell: relative error of Unif sampling for one layer and one compression rate.
# Choose the layer to evaluate (one of 1, 2, 5).
# For a compression rate x in [0.5, 0.25, 0.125, 0.0625], set itrs = log_2(1/x);
# e.g. for compression rate 0.5, itrs = log_2(1/0.5) = 1.
# Then execute this cell to obtain the relative errors of Unif sampling for that
# layer and compression rate (Llama by default).
model_name = "meta-llama/Llama-3.1-8B-Instruct"
#model_name = "mistralai/Ministral-8B-Instruct-2410" #change to this model name for mistral
layer = 1
itrs = 1
torch.cuda.empty_cache()
# main returns a list with one mean Unif relative error per seed.
unif= main(gamma_1=4.0,gamma_2=4.0,itrs=itrs,layer=layer,model_name=model_name)
# avg_stddev gives the mean and the population standard deviation across seeds.
unif_avg,unif_stddev = avg_stddev(unif)
print(f"Average unif: {unif_avg}| Standard Deviation unif: {unif_stddev}")