In [None]:
import os
import math
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, Trainer
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
from datasets import load_dataset
from tqdm import tqdm
import gc

# Load the base LLaMA checkpoint that will be fine-tuned (replace with your model path)
model_name = "./Llama-3.2-3B"
train_ds_folder = "./dataset_txt_small/train/"
test_ds_folder = "./dataset_txt_small/test/"

# Weights are loaded 4-bit (NF4) quantized via bitsandbytes; matmuls run in bfloat16.
# The model is placed on the GPU at load time through `device_map`, because
# bitsandbytes-quantized models must not be moved afterwards with `.to()`:
# depending on the installed transformers/bitsandbytes versions that call
# either raises ValueError or is a no-op.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
    ),
    device_map={"": torch.cuda.current_device()},
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Use EOS as the padding token; `tokenizer.pad_token_id` is read when batching.
tokenizer.pad_token = tokenizer.eos_token

files = os.listdir(train_ds_folder)
BATCH_SIZE = 1
N_FILES = len(files)

# Kept for any later cell that moves tensors explicitly; the model itself is
# already on the GPU (see `device_map` above), so no `model.to(device)` here.
device = "cuda"

In [None]:
# QLoRA config
lora_config = LoraConfig(
    r=32,
    lora_alpha=32,
    # Attention projections of the LLaMA architecture. Its attention output
    # projection is called "o_proj"; a name such as "dense" (used by other
    # model families) matches no LLaMA module and would be silently skipped.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.1,
    bias="none",
    # lm_head is trained and saved in full alongside the LoRA adapters.
    modules_to_save=["lm_head"],
    task_type="CAUSAL_LM",
)

# Prepare the quantized base model for training (with non-reentrant
# gradient checkpointing) before the adapters are attached.
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
gc.collect()
torch.cuda.empty_cache()
gc.collect()

# Add adapters to model
model = get_peft_model(model, lora_config)

# disable KV cache due to memory consumption (no need here)
model.config.use_cache = False

gc.collect()
torch.cuda.empty_cache()
gc.collect()

In [None]:
# Load dataset
def _list_data_files(folder):
    """Return the paths of the regular files directly inside `folder`, sorted by name.

    Sub-directories (for example a `.ipynb_checkpoints` folder created by
    Jupyter) are skipped so they are never handed to the text loader, and the
    result is sorted because `os.listdir` returns entries in arbitrary order.
    """
    candidates = (os.path.join(folder, name) for name in os.listdir(folder))
    return sorted(path for path in candidates if os.path.isfile(path))


# One example per line of text; every file in the folder contributes to the split.
dataset = load_dataset(
    "text",
    data_files={
        "train": _list_data_files(train_ds_folder),
        "test": _list_data_files(test_ds_folder),
    },
)

# Tokenize a single dataset row for causal-LM training
def tokenize(ds_element):
    """Turn one dataset row into model inputs.

    Args:
        ds_element: mapping with a "text" entry holding the raw string.

    Returns:
        dict with "input_ids", "labels" and "attention_mask". The labels are
        the very same token ids as the inputs. No truncation, padding or
        special tokens are applied here.
    """
    encoding = tokenizer(
        ds_element["text"],
        truncation=False,
        padding=False,
        add_special_tokens=False,
    )
    token_ids = encoding["input_ids"]
    return dict(
        input_ids=token_ids,
        labels=token_ids,
        attention_mask=encoding["attention_mask"],
    )

# Tokenize every split row by row, then drop the raw "text" column
map_options = dict(
    batched=False,
    num_proc=os.cpu_count(),  # number of worker processes
    remove_columns=["text"],
)
dataset_tokenized = dataset.map(tokenize, **map_options)

In [None]:
IGNORE_INDEX = -100  # fill value for padded label positions
ATTN_IGNORE_INDEX = 0  # fill value for padded attention-mask positions

# function to batch inputs
def collate(elements):
    """Pad a list of tokenized examples to a common length and stack them.

    Args:
        elements: non-empty sequence of dicts, each with "input_ids",
            "labels" and "attention_mask" lists.

    Returns:
        dict with "input_ids", "labels" and "attention_mask" as 2-D
        `torch.long` tensors of shape (len(elements), longest input_ids).
        Sequences are right-padded: input ids with `tokenizer.pad_token_id`,
        labels with IGNORE_INDEX and the attention mask with ATTN_IGNORE_INDEX.

    Raises:
        ValueError: if `elements` is empty.

    The lists inside `elements` are never modified; padded copies are built
    instead, so calling this twice on the same examples (or on examples whose
    "labels" and "input_ids" are the same list object) is safe.
    """
    if not elements:
        raise ValueError("collate() needs at least one element to build a batch")

    # Every row is padded up to the longest input_ids sequence in the batch.
    target_len = max(len(e["input_ids"]) for e in elements)
    pad_id = tokenizer.pad_token_id

    padded_ids, padded_labels, padded_mask = [], [], []
    for e in elements:
        # The pad length is derived from input_ids and applied to all three fields.
        pad_len = target_len - len(e["input_ids"])
        padded_ids.append(list(e["input_ids"]) + [pad_id] * pad_len)
        padded_labels.append(list(e["labels"]) + [IGNORE_INDEX] * pad_len)
        padded_mask.append(list(e["attention_mask"]) + [ATTN_IGNORE_INDEX] * pad_len)

    # Explicit dtype keeps the tensors integer even when every sequence is empty.
    return {
        "input_ids": torch.tensor(padded_ids, dtype=torch.long),
        "labels": torch.tensor(padded_labels, dtype=torch.long),
        "attention_mask": torch.tensor(padded_mask, dtype=torch.long),
    }

In [None]:
# Hyperparameters
BS = 8  # per-device batch size
GA_STEPS = 2  # gradient accumulation steps
EPOCHS = 10
LR = 2e-5

# Approximate number of optimizer steps in one epoch (examples per optimizer
# step = BS * GA_STEPS on a single device). Clamped to at least 1 so that a
# train split smaller than BS * GA_STEPS does not yield eval_steps/save_steps of 0.
steps_per_epoch = max(1, len(dataset_tokenized["train"]) // (BS * GA_STEPS))

args = TrainingArguments(
    output_dir="qlora_checkpoints",
    per_device_train_batch_size=BS,
    per_device_eval_batch_size=BS,
    # `eval_strategy` is the current name of the deprecated `evaluation_strategy`.
    eval_strategy="steps",
    logging_steps=1,
    eval_steps=steps_per_epoch,  # eval roughly once per epoch
    save_steps=steps_per_epoch,  # save roughly once per epoch
    gradient_accumulation_steps=GA_STEPS,
    num_train_epochs=EPOCHS,
    lr_scheduler_type="constant",
    optim="paged_adamw_32bit",
    learning_rate=LR,
    # Batch examples of similar length together to reduce padding in `collate`.
    group_by_length=True,
    bf16=True,
    ddp_find_unused_parameters=False,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=collate,
    train_dataset=dataset_tokenized["train"],
    eval_dataset=dataset_tokenized["test"],
)

# training loop
trainer.train()