In [None]:
import os

# Expose only physical GPU 3 to this process. CUDA reads this variable when it
# initializes, which is why it is set in the very first cell, before torch is imported.
os.environ.update({"CUDA_VISIBLE_DEVICES": "3"})

In [None]:
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, PeftModel, get_peft_model
import gc
import torch.nn as nn
import copy

In [None]:
# Load the full-precision reference model only to snapshot its weights on CPU.
# `sd` maps parameter/buffer name -> tensor and serves later as the FP target when
# measuring quantization error; the model wrapper itself is dropped afterwards.
fp_model = AutoModelForCausalLM.from_pretrained(
    "/raid/LLM/llama2-7b",
    torch_dtype=torch.float32,
    device_map="cpu",
)
fp_model.requires_grad_(False)
fp_model.config.use_cache = False
fp_model.eval()
sd = {name: tensor.cpu() for name, tensor in fp_model.state_dict().items()}

# Release the wrapper; the tensors stay alive through `sd`.
del fp_model
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Load the W2 checkpoint whose nn.Linear weights are fake-quantized in place further
# below (the path name suggests an AWQ-scaled 2-bit model — confirm). Inference only.
W2_MODEL_PATH = "/raid/lgh/aids24/EX2/ex2_llama2_7b_awq_w2_scale"
# Alternative checkpoint used in earlier runs:
# "/home/leegh/qloras/qalora_svd/models/llama2-7b-qalora-fake_w2-pool_first_avg"
w2_model = AutoModelForCausalLM.from_pretrained(
    W2_MODEL_PATH, torch_dtype=torch.float32, device_map="cpu"
)
w2_model.requires_grad_(False)
w2_model.config.use_cache = False
w2_model.eval()

In [None]:
sd["model.layers.0.self_attn.q_proj.weight"]

In [None]:
w2_model.model.layers[0].self_attn.q_proj.weight

In [None]:
def quant_func_asym(w, n_bits, q_group_size):
    """Fake-quantize ``w`` with asymmetric uniform quantization (FP scale, INT zero point).

    The last dimension is split into groups. Each group is mapped onto the integer grid
    ``[0, 2**n_bits - 1]`` with its own scale and integer zero point, then rounded,
    clamped and mapped back to floating point. The result therefore has the shape of
    ``w`` but holds at most ``2**n_bits`` distinct values per group.

    Args:
        w: weight tensor; its last dimension is split into groups.
        n_bits: bit width of the integer grid.
        q_group_size: number of consecutive elements per group along the last
            dimension; ``<= 0`` means one group per row of the 2-D view (channel-wise).

    Returns:
        A detached tensor with the same shape as ``w``.

    Raises:
        AssertionError: if ``q_group_size > 0`` does not divide the last dimension,
            or if the result contains NaN (e.g. because ``w`` contains NaN).
    """
    org_w_shape = w.shape

    if q_group_size > 0:
        assert org_w_shape[-1] % q_group_size == 0
        w = w.reshape(-1, q_group_size)
    else:
        w = w.reshape(-1, w.shape[-1])  # channel-wise: one group per row

    max_val = w.amax(dim=1, keepdim=True)
    min_val = w.amin(dim=1, keepdim=True)
    max_int = 2 ** n_bits - 1
    min_int = 0

    val_range = max_val - min_val
    # A constant group (e.g. all zeros) has zero range; dividing by a zero scale gives
    # 0/0 = NaN. Use a unit scale for such groups to keep the arithmetic finite, and
    # copy their (exactly representable) constant value through unchanged below.
    degenerate = val_range == 0
    scales = torch.where(degenerate, torch.ones_like(val_range), val_range / max_int)
    zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)

    w_q = (torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros) * scales
    w_q = torch.where(degenerate, w, w_q)

    assert torch.isnan(w_q).sum() == 0

    return w_q.reshape(org_w_shape).detach()

In [None]:
def groop_pool_weight(w, q_group_size):
    """Average a 2-D weight over contiguous groups along its last dimension.

    Args:
        w: weight matrix of shape (rows, cols).
        q_group_size: group width along the last dimension; ``<= 0`` averages each
            whole row (channel-wise).

    Returns:
        Tensor of shape (rows, cols // q_group_size), or (rows, 1) when channel-wise,
        holding the mean of each group.

    Raises:
        AssertionError: if ``q_group_size > 0`` does not divide the last dimension.
    """
    shape = w.shape
    if q_group_size <= 0:
        # One group spanning the whole row; a 2-D input is pooled as (C, L).
        rows = w.reshape(-1, w.shape[-1])
        row_means = nn.functional.avg_pool1d(rows, rows.shape[-1])
        return row_means.reshape(shape[0], 1)

    assert shape[-1] % q_group_size == 0
    groups = w.reshape(-1, q_group_size)
    group_means = nn.functional.avg_pool1d(groups, q_group_size)
    return group_means.reshape(shape[0], shape[1] // q_group_size)

In [None]:
# Fake-quantize every nn.Linear in `w2_model` (except `lm_head`) in place (w2, FP zero
# point). Optionally add back an FP correction ("adapter") that approximates the
# quantization error with respect to the FP reference weights in `sd`:
#   groop_pool=True -> per-group mean of the error, broadcast over each group
#                      (pool_first: mean(fp) - mean(q) per group; else mean(fp - q))
#   svd=True        -> rank-`svd_rank` truncated-SVD reconstruction of (fp - q)
# NOTE: this mutates w2_model in place; re-running the cell re-quantizes weights that a
# previous run already modified.
groop_pool = False
pool_first = True

svd = False
svd_rank = 4096

q_bit = 2
group_size = 64

if groop_pool and svd:
    # The combination is not implemented; fail before any weight is touched.
    raise NotImplementedError("groop_pool and svd cannot be enabled together")


def _expand_pooled(pooled, cols):
    """Repeat each pooled value along dim 1 so the result has `cols` columns.

    Works for grouped pooling (cols // group_size columns) and channel-wise pooling (1 column).
    """
    return pooled.repeat_interleave(cols // pooled.shape[1], dim=1)


def _group_pool_adapter(fp_weight, quant_weight, group_size, pool_first):
    """Per-group mean of the quantization error, expanded back to the weight's shape."""
    if pool_first:
        pooled_err = groop_pool_weight(fp_weight, group_size) - groop_pool_weight(quant_weight, group_size)
    else:
        pooled_err = groop_pool_weight(fp_weight - quant_weight, group_size)
    return _expand_pooled(pooled_err, quant_weight.shape[1])


def _svd_adapter(fp_weight, quant_weight, rank):
    """Rank-`rank` truncated-SVD approximation of `fp_weight - quant_weight`.

    sqrt(S) is split between the two factors (B = U_r sqrt(S_r), A = sqrt(S_r) Vh_r);
    scaling elementwise equals multiplying by diag(sqrt(S)) without building that matrix.
    """
    gap_weight = (fp_weight - quant_weight).detach().cpu()
    U, S, Vh = torch.linalg.svd(gap_weight, full_matrices=False)
    sqrt_s = torch.sqrt(S[:rank])
    lora_B = U[:, :rank] * sqrt_s
    lora_A = sqrt_s[:, None] * Vh[:rank, :]
    return lora_B @ lora_A


with torch.no_grad():
    for n, m in w2_model.named_modules():
        if n == "lm_head" or not isinstance(m, nn.Linear):
            continue
        print(n)
        # quant_func_asym builds new tensors and does not modify its input, so no copy is needed.
        quant_weight = quant_func_asym(m.weight, q_bit, group_size)

        if groop_pool:
            fp_weight = sd[n + ".weight"]
            quant_weight = quant_weight + _group_pool_adapter(fp_weight, quant_weight, group_size, pool_first)
        elif svd:
            fp_weight = sd[n + ".weight"]
            quant_weight = quant_weight + _svd_adapter(fp_weight, quant_weight, svd_rank)

        m.weight.data = quant_weight

        
        

In [None]:
# Layer-0 q_proj after the fake-quant loop (compare with the peek before it).
w2_model.model.layers[0].self_attn.q_proj.weight

In [None]:
# Sanity check: presumably a norm module rather than nn.Linear, so the loop should leave it untouched — confirm.
w2_model.model.layers[0].post_attention_layernorm.weight

In [None]:
# First quantization group of row 0 (columns 0-63 when group_size == 64).
w2_model.model.layers[0].self_attn.q_proj.weight[0][:64]

In [None]:
# Distinct values in one group (columns 128-191); with q_bit == 2 and no adapter, expect at most 4.
w2_model.model.layers[0].self_attn.q_proj.weight[0][128:192].unique()

In [None]:
# Perplexity of the fake-quantized model, evaluated on GPU via the project-local
# ppl_utils.eval_ppl (arguments passed through unchanged from the original cell).
from ppl_utils import eval_ppl

w2_model = w2_model.to("cuda")
try:
    results = eval_ppl(w2_model, False, "llama2", "cuda", "/raid/LLM/llama2-7b")
finally:
    # Move the model back even if evaluation raises, so later CPU-side cells
    # (saving, inspection) keep seeing it on CPU.
    w2_model = w2_model.to("cpu")

In [None]:
import json

# Persist the perplexity results produced by the evaluation cell as pretty-printed JSON.
#output_dir = "PPL_results/EX1/step8/5iter/svd_init_results/svd_init"
output_dir = "PPL_results/qloras/llama2-7b-rtn_w2a16g64"
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(output_dir, exist_ok=True)

# ensure_ascii=False can emit non-ASCII text, so pin UTF-8 instead of the platform
# default encoding. The `with` block closes the file.
with open(os.path.join(output_dir, "results.json"), "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2, ensure_ascii=False)


In [None]:
w2_model.save_pretrained("/raid/lgh/ex1_llama2_7b_awq_w2_fake_manual_w3scale")


In [None]:
from transformers import AutoTokenizer

# Store the base model's tokenizer next to the saved weights so the directory is self-contained.
save_dir = "/raid/lgh/ex1_llama2_7b_awq_w2_fake_manual_w3scale"
tokenizer = AutoTokenizer.from_pretrained("/raid/LLM/llama2-7b", use_fast=True)
tokenizer.save_pretrained(save_dir)

In [None]:
# Layer-0 down_proj after fake quantization: values and shape.
print(w2_model.model.layers[0].mlp.down_proj.weight)
print(w2_model.model.layers[0].mlp.down_proj.weight.shape)

In [None]:
# NOTE(review): identical to the previous cell; one of the two can be dropped.
print(w2_model.model.layers[0].mlp.down_proj.weight)
print(w2_model.model.layers[0].mlp.down_proj.weight.shape)

In [None]:
# A single element of row 0 (column 11004).
w2_model.model.layers[0].mlp.down_proj.weight[0][11004]

In [None]:
# Private copy for the pooling experiments below, so the model weight is not modified.
sample = copy.deepcopy(w2_model.model.layers[0].mlp.down_proj.weight)

In [None]:
# Shape before regrouping (overwrites the org_w_shape left over from the fake-quant loop).
org_w_shape = sample.shape
print(org_w_shape)

In [None]:
torch.abs(sample)

In [None]:
# Regroup `sample` the same way the quantizer does: rows of q_group_size contiguous values.
q_group_size = 64
if q_group_size > 0:
    assert org_w_shape[-1] % q_group_size == 0
    sample = sample.reshape(-1, q_group_size)
else:
    # channel-wise: one group per original row
    sample = sample.reshape(-1, sample.shape[-1])
print(sample)
print(sample.shape)

In [None]:
sample[0]

In [None]:
sample[1]

In [None]:
# Group means computed directly, to cross-check the AvgPool1d output below.
torch.mean(sample[0])

In [None]:
torch.mean(sample[1])

In [None]:
# Groups per original row. NOTE(review): `/` gives a float here; `//` gives the int used elsewhere.
L = org_w_shape[1] / q_group_size
print(L)

In [None]:
# https://pytorch.org/docs/stable/nn.html#pooling-layers
# A 2-D input is pooled as (C, L): each group row collapses to its mean.
print(nn.AvgPool1d(64)(sample))
print(nn.AvgPool1d(64)(sample).shape)

In [None]:
# NOTE(review): repeats the earlier `sample[0]` cell.
sample[0]

In [None]:
# Max-pooling variant: kernel = stride = q_group_size, no padding -> per-group maximum.
nn.MaxPool1d(q_group_size, q_group_size, 0)(sample)

In [None]:
nn.MaxPool1d(q_group_size, q_group_size, 0)(sample).shape

In [None]:
# nn.Linear stores its weight as (out_features, in_features): dim 0 is the output
# size, dim 1 the input size.
dout = org_w_shape[0]
din = org_w_shape[1]
print(f"din={din}")
print(f"dout={dout}")

In [None]:
q_group_size

In [None]:
# Group means laid out as (rows, groups per row), matching what groop_pool_weight returns.
pooled_example = nn.AvgPool1d(64)(sample).reshape(org_w_shape[0], org_w_shape[1] // q_group_size)
print(pooled_example)
print(pooled_example.shape)

In [None]:
print(pooled_example.reshape(-1))
print(pooled_example.reshape(-1).shape)


In [None]:
# Broadcast each group mean back over its group: repeat + transpose puts every mean
# q_group_size times in a row, and the reshape restores the original weight shape.
torch.transpose(pooled_example.reshape(-1).repeat(q_group_size, 1), 1, 0).reshape(org_w_shape[0], org_w_shape[1]).shape

In [None]:
# First group of row 0: expect q_group_size copies of pooled_example[0, 0].
torch.transpose(pooled_example.reshape(-1).repeat(q_group_size, 1), 1, 0).reshape(org_w_shape[0], org_w_shape[1])[0][:64]

In [None]:
# NOTE(review): identical to the previous cell.
torch.transpose(pooled_example.reshape(-1).repeat(q_group_size, 1), 1, 0).reshape(org_w_shape[0], org_w_shape[1])[0][:64]

In [None]:
base_model_path = "/raid/lgh/multi_lora/EX1_5iter/step9_merged"

In [None]:
base_model_path = "/home/leegh/lgh_n24/models/llama2-7b-omni-w2a16g64"

In [None]:
#base_model_path = "/raid/LLM/llama2-7b"

In [None]:
# Wrap the (quantized) base model with LoRA adapters on every attention/MLP projection.
# lora_alpha == lora_r, so with PEFT's default (non-rsLoRA) scaling of lora_alpha / r the
# adapter delta is exactly lora_B @ lora_A — confirm for the installed PEFT version.
# The gaussian init is later overwritten by the SVD initialization cell.
lora_r = 256
lora_alpha = lora_r
lora_dropout = 0.1

model = AutoModelForCausalLM.from_pretrained(
    base_model_path, torch_dtype=torch.float32, device_map="cpu"
)
target_linear = ["gate_proj", "k_proj", "o_proj", "v_proj", "q_proj", "up_proj", "down_proj"]
target_t_type = "CAUSAL_LM"
lora_config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    target_modules=target_linear,
    bias="none",
    task_type=target_t_type,
    init_lora_weights="gaussian",
)
model = get_peft_model(model, lora_config)
model.config.use_cache = False

In [None]:
# Show the wrapped model, with LoRA modules injected into the target projections.
model

In [None]:
# Adapter factors right after get_peft_model, before the SVD initialization.
model.base_model.model.model.layers[0].mlp.gate_proj.lora_A.default.weight

In [None]:
model.base_model.model.model.layers[0].mlp.gate_proj.lora_B.default.weight

In [None]:
# Disabled: save the freshly initialized (non-SVD) adapters.
#model.save_pretrained("/raid/lgh/multi_lora/EX1_10iter/step1_merged/lora_init")

In [None]:
from peft.tuners.lora import LoraLayer

# Initialize every LoRA adapter from a rank-`rank` truncated SVD of the quantization gap
# (FP reference weight in `sd` minus the quantized base weight). After init, the adapter
# contribution scaling * (lora_B @ lora_A) approximates that gap.
rank = lora_r

with torch.no_grad():
    for n, m in model.named_modules():
        if not isinstance(m, LoraLayer):
            continue
        print(n)
        # PEFT module names carry a "base_model.model." prefix that the plain keys in `sd` lack.
        adj_name = n.replace("base_model.model.", "") + ".weight"
        gap_weight = (sd[adj_name] - m.base_layer.weight).detach().cpu()

        U, S, Vh = torch.linalg.svd(gap_weight, full_matrices=False)
        # Split sqrt(S) evenly between the factors. Elementwise scaling equals multiplying
        # by diag(sqrt(S)) without materializing that dense matrix.
        sqrt_s = torch.sqrt(S[:rank])
        lora_B_init = U[:, :rank] * sqrt_s           # (gap rows, rank)
        lora_A_init = sqrt_s[:, None] * Vh[:rank, :]  # (rank, gap cols)

        # The adapter output is multiplied by its scaling factor. Divide it out so the
        # product reproduces the SVD reconstruction for any lora_alpha (a no-op when it is 1).
        lora_B_init = lora_B_init / m.scaling["default"]

        # copy_ keeps each parameter's dtype/device and raises on a shape mismatch
        # (e.g. rank larger than min(rows, cols)) instead of silently replacing the tensor.
        m.lora_A.default.weight.copy_(lora_A_init)
        m.lora_B.default.weight.copy_(lora_B_init)

In [None]:
# lora_B after the SVD initialization (compare with the value shown before it).
model.base_model.model.model.layers[0].mlp.gate_proj.lora_B.default.weight

In [None]:
model.save_pretrained(f"{base_model_path}/svd_r256_init")

In [None]:
model.save_pretrained(f"{base_model_path}/svd_init")
