In [1]:
from transformers import AutoModel, AutoTokenizer
from transformers import BitsAndBytesConfig
import torch
from peft import prepare_model_for_kbit_training
from peft import LoraConfig, get_peft_model

In [2]:
# Load the embedding model with 4-bit quantized weights and wrap it with LoRA
# adapters (QLoRA-style setup). `tokenizer` and `model` are the names produced
# for the following cells.
model_name = "Salesforce/SFR-Embedding-2_R"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 1. load model with quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,  # 4-bit weight quantization: this is the "Q" in QLoRA
    bnb_4bit_use_double_quant=True,  # nested quantization of the quantization constants
    bnb_4bit_quant_type="nf4",  # NormalFloat4 storage format
    bnb_4bit_compute_dtype=torch.bfloat16,  # precision used for compute; storage stays 4-bit
)
model = AutoModel.from_pretrained(model_name, quantization_config=bnb_config)

# 2. necessary preparation
# prepare_model_for_kbit_training enables gradient checkpointing itself when
# use_gradient_checkpointing=True (its default), so a separate
# model.gradient_checkpointing_enable() call beforehand is redundant.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

# 3. lora
lora_config = LoraConfig(
    r=16,  # adapter rank
    lora_alpha=32,  # LoRA scaling numerator (scale = lora_alpha / r)
    # Linear layers, matched by module name, that receive LoRA adapters.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",  # do not train any bias terms
    task_type="FEATURE_EXTRACTION",
)
model = get_peft_model(model, lora_config)

`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
# Bare expression whose display is suppressed by the trailing `;`.
# Remove the `;` to render the repr of the PEFT-wrapped model.
model;

In [4]:
# One tab-separated row per parameter: trainable flag, dtype, qualified name.
# Handy for verifying that only the LoRA matrices have requires_grad=True
# while the quantized base weights stay frozen.
for name, param in model.named_parameters():
    print(param.requires_grad, param.dtype, name, sep="\t")

False	torch.float32	base_model.model.embed_tokens.weight
False	torch.uint8	base_model.model.layers.0.self_attn.q_proj.base_layer.weight
True	torch.float32	base_model.model.layers.0.self_attn.q_proj.lora_A.default.weight
True	torch.float32	base_model.model.layers.0.self_attn.q_proj.lora_B.default.weight
False	torch.uint8	base_model.model.layers.0.self_attn.k_proj.base_layer.weight
True	torch.float32	base_model.model.layers.0.self_attn.k_proj.lora_A.default.weight
True	torch.float32	base_model.model.layers.0.self_attn.k_proj.lora_B.default.weight
False	torch.uint8	base_model.model.layers.0.self_attn.v_proj.base_layer.weight
True	torch.float32	base_model.model.layers.0.self_attn.v_proj.lora_A.default.weight
True	torch.float32	base_model.model.layers.0.self_attn.v_proj.lora_B.default.weight
False	torch.uint8	base_model.model.layers.0.self_attn.o_proj.base_layer.weight
True	torch.float32	base_model.model.layers.0.self_attn.o_proj.lora_A.default.weight
True	torch.float32	base_model.model.lay