In [1]:
import os
import bitsandbytes as bnb
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, AutoPeftModelForCausalLM
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed, Trainer, TrainingArguments, BitsAndBytesConfig, \
    DataCollatorForLanguageModeling, Trainer, TrainingArguments
from datasets import load_dataset
import torch
from dotenv import load_dotenv
from peft import PeftModel

In [14]:
from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
)

In [2]:
load_dotenv()
HF_TOKEN = os.environ.get("HF_TOKEN")

In [3]:
def print_trainable_parameters(model, use_4bit=False):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        num_params = param.numel()
        # if using DS Zero 3 and the weights are initialized empty
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel

        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    if use_4bit:
        trainable_params /= 2
    print(
        f"all params: {all_param:,d} || trainable params: {trainable_params:,d} || trainable%: {100 * trainable_params / all_param}"
    )

In [None]:
def create_bnb_config():
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    return bnb_config

In [8]:
def find_all_linear_names(model):
    cls = bnb.nn.Linear4bit #if args.bits == 4 else (bnb.nn.Linear8bitLt if args.bits == 8 else torch.nn.Linear)
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names:  # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)

# Load TinyLlama

In [35]:
model_name_or_path = "PY007/TinyLlama-1.1B-Chat-v0.2"

In [36]:
tiny_llama = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    device_map = "auto"
)

In [6]:
print_trainable_parameters(tiny_llama)

all params: 1,100,060,672 || trainable params: 1,100,060,672 || trainable%: 100.0


In [37]:
dtypes = {}
for _, p in tiny_llama.named_parameters():
    dtype = p.dtype
    if dtype not in dtypes: dtypes[dtype] = 0
    dtypes[dtype] += p.numel()
total = 0
for k, v in dtypes.items(): total+= v
for k, v in dtypes.items():
    print(k, v, v/total)

do_train = True

torch.float32 1100060672 1.0


# LoRA TinyLlama

In [38]:
modules = [
    "q_proj",
    "v_proj",
]

In [39]:
# peft_config
config = LoraConfig(
    r=16,
    lora_alpha=64,
    target_modules=modules,
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM"
)

In [40]:
tiny_llama_lora = get_peft_model(tiny_llama, config)

In [17]:
tiny_llama_lora.print_trainable_parameters()

trainable params: 2,252,800 || all params: 1,102,313,472 || trainable%: 0.20437017756052608


In [41]:
dtypes = {}
for _, p in tiny_llama_lora.named_parameters():
    dtype = p.dtype
    if dtype not in dtypes: dtypes[dtype] = 0
    dtypes[dtype] += p.numel()
total = 0
for k, v in dtypes.items(): total+= v
for k, v in dtypes.items():
    print(k, v, v/total)

do_train = True

torch.float32 1102313472 1.0


# bnb config

In [42]:
bnb_config = create_bnb_config()

In [43]:
tiny_llama = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    quantization_config=bnb_config,
    device_map = "auto",
)

In [44]:
modules = find_all_linear_names(tiny_llama)

In [45]:
modules

['v_proj', 'gate_proj', 'q_proj', 'down_proj', 'k_proj', 'o_proj', 'up_proj']

In [46]:
tiny_llama_lora = get_peft_model(tiny_llama, config)

In [34]:
tiny_llama_lora.print_trainable_parameters()

trainable params: 2,252,800 || all params: 1,102,313,472 || trainable%: 0.20437017756052608


In [47]:
dtypes = {}
for _, p in tiny_llama_lora.named_parameters():
    dtype = p.dtype
    if dtype not in dtypes: dtypes[dtype] = 0
    dtypes[dtype] += p.numel()
total = 0
for k, v in dtypes.items(): total+= v
for k, v in dtypes.items():
    print(k, v, v/total)

do_train = True

torch.float16 131176448 0.2123038167685908
torch.uint8 484442112 0.7840501168398548
torch.float32 2252800 0.0036460663915543843
