In [1]:
import torch

from data_processing import util
from peft import get_peft_model, LoraConfig
from safetensors import safe_open
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

In [2]:
# Dataset/run variant; combined with the model key below to locate the output directory.
DATA_TYPE = "mbpt_0_top"
# Cache directory passed to `from_pretrained` when the base model is loaded.
CACHE_DIR = "/nlp/scr/neigbe/.cache"

# Selects one entry from each of the two parallel lists below (0 -> 8B, 1 -> 70B).
model_idx = 1
# Hub identifier of the base model.
MODEL_NAME = ["meta-llama/Meta-Llama-3-8B-Instruct", "meta-llama/Meta-Llama-3-70B-Instruct"][model_idx]
# Short name used only to build the on-disk path. Kept separate from `model`,
# which a later cell binds to the loaded network, so that one global never
# holds both a string and a model object depending on execution order.
MODEL_KEY = ["llama3-8b-instruct", "llama3-70b-instruct"][model_idx]
MODEL_PATH = util.get_model_path(MODEL_KEY, DATA_TYPE)
# MODEL_PATH = ""

In [3]:
MODEL_PATH

'/nlp/scr/neigbe/pers_proj/models/llama3-70b-instruct/mbpt_0_top/2024-05-27|02:39:31/'

In [5]:
# Read the LoRA adapter weights out of the raw training checkpoint.
# Only keys containing 'lora' are consumed later in this notebook, so the
# remaining entries are never materialized; this avoids copying the whole
# checkpoint onto device 0 just to discard most of it.
tensors = {}
with safe_open("/scr/neigbe/2024-05-27|01:57:13/model_state_dict.safetensors", framework="pt", device=0) as f:
    for k in f.keys():
        if 'lora' in k:
            tensors[k] = f.get_tensor(k)  # loads the full tensor for this key

In [6]:
# Release every non-LoRA tensor by replacing its value with None.
# The keys themselves are kept; only the payloads are dropped.
# dict.fromkeys() maps each selected key to None by default.
tensors.update(dict.fromkeys([key for key in tensors if 'lora' not in key]))

In [7]:
# Quantization settings for the base model: 4-bit NF4, no double quantization,
# bfloat16 compute. These (together with the LoRA target modules, rank and
# alpha configured further down) must match what was used during training.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

# Keyword arguments for loading the quantized base checkpoint.
base_load_kwargs = {
    "use_cache": False,
    "quantization_config": bnb_config,
    "cache_dir": CACHE_DIR,
    "device_map": "auto",
}
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **base_load_kwargs)



Loading checkpoint shards:   0%|          | 0/30 [00:00<?, ?it/s]

In [8]:
# Freeze: turn off gradient tracking on every parameter of the loaded model.
for weight in model.parameters():
    weight.requires_grad_(False)

In [9]:
# Wrap the frozen base model with LoRA adapters.
# The rank (r) and alpha (lora_alpha) values below must match those used in training!
lora_targets = ["k_proj", "q_proj", "v_proj", "up_proj", "down_proj", "gate_proj"]
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    inference_mode=False,
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=lora_targets,
)
model = get_peft_model(model, peft_config)

In [10]:
# Start from the wrapped model's own state dict, overwrite every LoRA entry
# with the corresponding tensor read from the checkpoint, then load the result
# back into the model. A LoRA key missing from `tensors` raises KeyError.
new_sd = model.state_dict()
new_sd.update({name: tensors[name] for name in new_sd if 'lora' in name})

model.load_state_dict(new_sd)

<All keys matched successfully>

In [11]:
# Persist the PEFT-wrapped model under MODEL_PATH.
# NOTE(review): on a PEFT model this presumably writes only the adapter
# weights and adapter config, not the full base model; confirm before
# relying on the directory contents.
model.save_pretrained(MODEL_PATH)



config.json:   0%|          | 0.00/654 [00:00<?, ?B/s]