In [1]:
import os
# Enumerate CUDA devices in PCI bus order so the indices used below are stable.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
# Expose only devices 0 and 1 to this process.
os.environ["CUDA_VISIBLE_DEVICES"]="0,1"
# torch is imported only after the two CUDA environment variables are set.
import torch
# Sanity check: print how many CUDA devices torch reports.
print(torch.cuda.device_count())
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from random import randint

import tqdm

from rotation_utils import random_orthogonal_matrix
from hadamard_utils import random_hadamard_matrix, apply_exact_had_to_linear
from quant_utils import ActQuantWrapper

import utils
import model_utils
import data_utils
import transformers
import quant_utils
import rotation_utils
import gptq_utils
import eval_utils
import hadamard_utils
import rotated_llama

2


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Build the run configuration with the project's argument parser. The flags are
# spelled out as an explicit token list (the same tokens a whitespace split of
# the equivalent command line would give).
args = utils.parser_gen([
    '--model', 'meta-llama/Llama-2-7b-hf',
    '--rotate',
    '--a_bits', '4',
    '--v_bits', '4',
    '--k_bits', '4',
    '--w_bits', '16',
    '--w_clip',
    '--bsz', '1',
])

Arguments: 
{'a_asym': False,
 'a_bits': 4,
 'a_clip_ratio': 1.0,
 'a_groupsize': -1,
 'act_order': False,
 'bsz': 1,
 'cal_dataset': 'wikitext2',
 'capture_layer_io': False,
 'distribute': False,
 'eval_dataset': 'wikitext2',
 'fp32_had': False,
 'hf_token': None,
 'int8_down_proj': False,
 'k_asym': False,
 'k_bits': 4,
 'k_clip_ratio': 1.0,
 'k_groupsize': -1,
 'k_pre_rope': False,
 'layer_idx': 10,
 'learn_r1': True,
 'learn_r2': True,
 'lm_eval': False,
 'lm_eval_batch_size': 128,
 'load_qmodel_path': None,
 'model': 'meta-llama/Llama-2-7b-hf',
 'momentum': 0.0,
 'nsamples': 128,
 'percdamp': 0.01,
 'prefix_r': '',
 'rotate': True,
 'rotate_mode': 'hadamard',
 'rotation_seed': -1,
 'save_name': '20240623_063116',
 'save_path': '/ceph/echoi/codes/QuaRot/fake_quant/experiments/meta-llama/Llama-2-7b-hf/20240623_063116',
 'save_qmodel_path': None,
 'seed': 0,
 'tasks': ['piqa',
           'hellaswag',
           'arc_easy',
           'arc_challenge',
           'winogrande',
        

In [3]:
# Calibration data, selected by the parsed run arguments with a fixed
# sequence length of 2048.
trainloader = data_utils.get_loaders(
    args.cal_dataset,
    nsamples=args.nsamples,
    seed=args.seed,
    model=args.model,
    seqlen=2048,
    eval_mode=False,
)

from transformers.trainer_pt_utils import LabelSmoother

# Loss helper configured with a smoothing factor of 0.0.
label_smoother = LabelSmoother(epsilon=0.0)

In [4]:
# Seed the random number generators before anything model-related is created.
transformers.set_seed(args.seed)

# Load the base checkpoint and put it in evaluation mode.
model = model_utils.get_model(args.model, args.hf_token)
model.eval()

# Freeze every parameter of the base model.
for p in model.parameters():
    p.requires_grad_(False)

# utils.distribute_model(model) is left disabled in this cell.

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  8.79it/s]
---> Loading meta-llama/Llama-2-7b-hf Model with seq_len: 2048


In [5]:
from rotated_llama import RotatedLlamaForCausalLM

# Build the rotated wrapper from the base model and its config; construction
# runs with autograd disabled.
with torch.no_grad():
    rotated_model = RotatedLlamaForCausalLM(model.config, model)

# Drop the notebook-level name of the base model; from here on only
# `rotated_model` is available.
del model

rotated_model.eval()
rotation_utils.fuse_layer_norms(rotated_model)

# Freeze all parameters of the rotated model.
for p in rotated_model.parameters():
    p.requires_grad = False

In [6]:
# Layer types handed to both the wrapping pass and the lookup pass. Each call
# receives its own fresh list built from this tuple.
quantizable_layers = (
    nn.Linear,
    ActQuantWrapper,
    rotated_llama.RotatedLinear,
    rotated_llama.RotatedOVProj,
)

quant_utils.add_actquant(rotated_model, layers=list(quantizable_layers))
qlayers = quant_utils.find_qlayers(rotated_model.model, layers=list(quantizable_layers))

# Every layer whose name contains 'down_proj' gets the exact Hadamard helper
# applied to its wrapped module and is flagged for an online full Hadamard.
# NOTE(review): presumably this folds the transform into the weight's input
# side (output=False) -- confirm against hadamard_utils.
for name in qlayers:
    if 'down_proj' not in name:
        continue
    wrapper = qlayers[name]
    had_K, K = hadamard_utils.get_hadK(rotated_model.config.intermediate_size)
    hadamard_utils.apply_exact_had_to_linear(wrapper.module, had_dim=-1, output=False)
    wrapper.online_full_had = True
    wrapper.had_K = had_K
    wrapper.K = K
    wrapper.fp32_had = args.fp32_had

if args.a_bits < 16 or args.v_bits < 16:
    qlayers = quant_utils.find_qlayers(rotated_model, layers=[quant_utils.ActQuantWrapper])

    # A dedicated down_proj group size is only computed when grouping is on and
    # the model name contains "llama"; otherwise -1 is used.
    if args.a_groupsize > 0 and "llama" in args.model:
        down_proj_groupsize = utils.llama_down_proj_groupsize(rotated_model, args.a_groupsize)
    else:
        down_proj_groupsize = -1

    for name in qlayers:
        # Output quantizer of v_proj layers follows the v_* settings.
        if 'v_proj' in name and args.v_bits < 16:
            qlayers[name].out_quantizer.configure(
                bits=args.v_bits,
                groupsize=args.v_groupsize,
                sym=not args.v_asym,
                clip_ratio=args.v_clip_ratio,
            )

        is_down_proj = 'down_proj' in name

        # Input bit-width: the int8 down_proj override wins, then lm_head is
        # kept at 16 bits, everything else uses a_bits.
        if is_down_proj and args.int8_down_proj:
            layer_input_bits = 8
        elif 'lm_head' in name:
            layer_input_bits = 16
        else:
            layer_input_bits = args.a_bits

        # down_proj layers use their own group size.
        layer_groupsize = down_proj_groupsize if is_down_proj else args.a_groupsize

        qlayers[name].quantizer.configure(
            bits=layer_input_bits,
            groupsize=layer_groupsize,
            sym=not args.a_asym,
            clip_ratio=args.a_clip_ratio,
        )

# Weight quantization. Only runs when fewer than 16 weight bits are requested.
# Three mutually exclusive paths: load a saved quantized model, GPTQ, or RTN.
#
# The base `model` name was deleted right after `rotated_model` was built, so
# this block must only reference `rotated_model`.
if args.w_bits < 16:
    save_dict = {}
    if args.load_qmodel_path: # Load Quantized Rotated Model
        assert args.rotate, "Model should be rotated to load a quantized model!"
        assert not args.save_qmodel_path, "Cannot save a quantized model if it is already loaded!"
        print("Load quantized model from ", args.load_qmodel_path)
        # NOTE(review): torch.load unpickles the file; only point
        # load_qmodel_path at checkpoints from a trusted source.
        save_dict = torch.load(args.load_qmodel_path)
        rotated_model.load_state_dict(save_dict["model"])

    elif not args.w_rtn: # GPTQ Weight Quantization
        assert "llama" in args.model, "Only llama is supported for GPTQ!"

        # Use the rotated model's sequence length when it exposes one, else the
        # 2048 used for the calibration loader earlier in the notebook.
        # NOTE(review): confirm whether RotatedLlamaForCausalLM carries `seqlen`.
        trainloader = data_utils.get_loaders(
            args.cal_dataset, nsamples=args.nsamples,
            seed=args.seed, model=args.model,
            seqlen=getattr(rotated_model, "seqlen", 2048), eval_mode=False
        )
        quantizers = gptq_utils.gptq_fwrd(
            rotated_model,
            trainloader,
            utils.DEV,
            args,
            [rotated_llama.RotatedLinear, rotated_llama.RotatedOVProj])
        save_dict["w_quantizers"] = quantizers
    else: # RTN Weight Quantization
        quantizers = gptq_utils.rtn_fwrd(rotated_model, utils.DEV, args)
        save_dict["w_quantizers"] = quantizers

    if args.save_qmodel_path:
        # Save the model that was actually quantized; this is also what the
        # load path above restores into.
        save_dict["model"] = rotated_model.state_dict()
        torch.save(save_dict, args.save_qmodel_path)

if args.k_bits < 16:
    # Pre-RoPE key quantization is rejected up front.
    if args.k_pre_rope:
        raise NotImplementedError("Pre-RoPE quantization is not supported yet!")

    rope_function_name = model_utils.get_rope_function_name(rotated_model)
    layers = model_utils.get_layers(rotated_model)

    # Key-quantizer settings forwarded as keyword arguments to every wrapper.
    k_quant_config = dict(
        k_bits=args.k_bits,
        k_groupsize=args.k_groupsize,
        k_sym=not args.k_asym,
        k_clip_ratio=args.k_clip_ratio,
    )

    # Install the wrapper on the self-attention module of each layer, hooked
    # after the call to the RoPE function named above.
    for layer in layers:
        rotation_utils.add_qk_rotation_wrapper_after_function_call_in_forward(
            layer.self_attn,
            rope_function_name,
            config=rotated_model.config,
            **k_quant_config,
        )

In [7]:
# Hand the prepared rotated model to the project's distribution helper. In the
# recorded run it printed a memory map with entries for devices 0, 1 and 'cpu'.
utils.distribute_model(rotated_model)

{0: 50184060928, 1: 50758680576, 'cpu': 227161223168}


GPU memory (from distribute_model): 12.71 -> 12.71 GB (0.00 GB)


In [8]:
# Initial value for the learnable rotation: a random Hadamard matrix of size
# hidden_size, cast to float64.
hidden_dim = rotated_model.config.hidden_size
q1 = random_hadamard_matrix(hidden_dim, utils.DEV).to(torch.float64)

# Register it as a trainable parameter (gradients enabled).
Q1 = nn.Parameter(q1, requires_grad=True)

In [9]:
from stiefel import stiefel_optimizer

# Single parameter group holding Q1, flagged with 'stiefel': True for the
# optimizer; starting learning rate 1.5, no momentum.
optimizer = stiefel_optimizer.SGDG(
    [dict(params=[Q1], lr=1.5, momentum=0.0, stiefel=True)]
)

In [10]:
# Progress bar over iteration numbers 1..100 (inclusive).
pbar = tqdm.tqdm(
    range(1, 101),
    desc="Training progress",
    dynamic_ncols=True,
)

# Pool of not-yet-used calibration sample indices; filled lazily by the loop.
idx_stack = None

def lr_schedule(iter, total_iter, max_lr, min_lr):
    """Linearly decay the learning rate from ``max_lr`` towards ``min_lr``.

    Args:
        iter: Current iteration number.
        total_iter: Iteration number at which ``min_lr`` is reached.
        max_lr: Learning rate returned for ``iter == 0``.
        min_lr: Learning rate returned for ``iter == total_iter``.

    Returns:
        ``max_lr`` minus the elapsed fraction of the ``max_lr - min_lr`` span.
    """
    elapsed_fraction = iter / total_iter
    lr_span = max_lr - min_lr
    return max_lr - elapsed_fraction * lr_span

Training progress:   0%|          | 0/100 [00:00<?, ?it/s]

In [11]:
# Optimise the rotation Q1 against the loss on the calibration samples.
#
# Each iteration draws one not-yet-used sample, runs the rotated model with
# R1=Q1, computes the shifted-label loss, back-propagates into Q1 and takes one
# optimizer step with a linearly decayed learning rate.
#
# NOTE(review): the recorded run fails after 'backward' is printed with
# "expected mat1 and mat2 to have the same dtype, but got: float != c10::Half".
# No traceback was saved, so the failing matmul cannot be located from this
# notebook -- confirm which operand is Half before relying on the loop.

# Learning-rate schedule bounds; the upper bound matches the optimizer's
# initial learning rate.
MAX_LR, MIN_LR = 1.5, 0.0
# Schedule length follows the progress bar so the two cannot drift apart.
total_iters = pbar.total

for iteration in pbar:
    # Refill the pool of sample indices once it is empty (or still None).
    if not idx_stack:
        idx_stack = list(range(len(trainloader)))

    # Draw a random index without replacement.
    idx = idx_stack.pop(randint(0, len(idx_stack) - 1))

    # Only the first element of the sample is used, as both input and labels.
    # Named `input_ids` so the builtin `input` is not shadowed in the notebook.
    input_ids = trainloader[idx][0].to(utils.DEV)

    output = rotated_model(input_ids, R1=Q1)
    print('forward')
    loss = label_smoother(output, input_ids, shift_labels=True)

    # Q1 is a leaf parameter, so its gradient is kept without retain_grad().
    optimizer.zero_grad()
    loss.backward()
    print('backward')

    lr = lr_schedule(iteration, total_iters, MAX_LR, MIN_LR)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    optimizer.step()

    with torch.no_grad():
        # Orthogonality error as the Frobenius norm of Q1 @ Q1.T - I. A plain
        # sum of the signed entries can cancel to ~0 for a non-orthogonal Q1.
        identity = torch.eye(Q1.size(0), dtype=Q1.dtype, device=Q1.device)
        ortho_err = torch.linalg.norm(Q1 @ Q1.T - identity).item()
        pbar.set_postfix(
            {'CE': f'{loss.item():.3f}',
             'ortho1': f'{ortho_err:.3f}',
             }
        )

Forward 0
Forward 1
Forward 2
Forward 3
Forward 4
Forward 5
Forward 6
Forward 7
Forward 8
Forward 9
Forward 10
Forward 11
Forward 12
Forward 13
Forward 14
Forward 15
Forward 16
Forward 17
Forward 18
Forward 19
Forward 20
Forward 21
Forward 22
Forward 23
Forward 24
Forward 25
Forward 26
Forward 27
Forward 28
Forward 29
Forward 30
Forward 31
forward


Training progress:   0%|          | 0/100 [00:03<?, ?it/s]

backward





RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half

In [12]:
# List every parameter of the rotated model with its requires_grad flag.
for param_name, param in rotated_model.named_parameters():
    print(param_name, param.requires_grad)

model.embed_tokens.weight False
model.layers.0.self_attn.q_proj.weight False
model.layers.0.self_attn.k_proj.weight False
model.layers.0.self_attn.v_proj.weight False
model.layers.0.self_attn.o_proj.weight False
model.layers.0.mlp.gate_proj.weight False
model.layers.0.mlp.up_proj.weight False
model.layers.0.mlp.down_proj.weight False
model.layers.0.input_layernorm.weight False
model.layers.0.post_attention_layernorm.weight False
model.layers.1.self_attn.q_proj.weight False
model.layers.1.self_attn.k_proj.weight False
model.layers.1.self_attn.v_proj.weight False
model.layers.1.self_attn.o_proj.weight False
model.layers.1.mlp.gate_proj.weight False
model.layers.1.mlp.up_proj.weight False
model.layers.1.mlp.down_proj.weight False
model.layers.1.input_layernorm.weight False
model.layers.1.post_attention_layernorm.weight False
model.layers.2.self_attn.q_proj.weight False
model.layers.2.self_attn.k_proj.weight False
model.layers.2.self_attn.v_proj.weight False
model.layers.2.self_attn.o_proj