In [None]:
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# LoRA checkpoint produced by an earlier training run; its saved config records
# which base model the adapter was trained on.
peft_checkpoint = "outputs/checkpoint-20"

# Named distinctly from the new LoraConfig created later so the checkpoint's
# config and the new adapter's config are never confused.
adapter_checkpoint_config = PeftConfig.from_pretrained(peft_checkpoint)

base_model_name = adapter_checkpoint_config.base_model_name_or_path

# 4-bit NF4 quantization with double quantization; matmuls run in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=bnb_config,
    device_map="auto",
)

# Wrap the quantized base model with the saved adapter. PeftModel.from_pretrained
# defaults to is_trainable=False, so this adapter is loaded frozen. Train/eval
# mode is switched by the Trainer, so no explicit model.train() is needed here.
model = PeftModel.from_pretrained(model, peft_checkpoint)

tokenizer = AutoTokenizer.from_pretrained(base_model_name)
# Some causal-LM tokenizers ship without a pad token; padding batches needs one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [2]:
from peft import LoraConfig

# LoRA settings for the new adapter:
# - rank-64 update matrices, scaled by lora_alpha / r;
# - dropout on the adapter path;
# - bias terms left untrained.
# target_modules is unset, so PEFT falls back to its per-architecture defaults.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
)

In [3]:
# Register a fresh LoRA adapter named "t2" alongside the loaded "default" adapter.
# add_adapter does NOT activate the new adapter: without set_adapter, forward
# passes (and therefore training) keep using "default". set_adapter makes "t2"
# the active adapter and, per the PEFT docs, marks it as trainable.
model.add_adapter(peft_config=peft_config, adapter_name="t2")
model.set_adapter("t2")

# Sanity check: confirm which parameters will actually receive gradients.
model.print_trainable_parameters()

In [4]:
from datasets import load_dataset

# Read the local JSON Lines file as a single "train" split and return it directly
# as a Dataset (not a DatasetDict).
dataset = load_dataset(
    "json",
    data_files={"train": "data.jsonl"},
    split="train",
)

In [None]:
# Display the dataset's repr (columns and row count) as a quick sanity check.
dataset

In [None]:
# Hold out 10% of the examples for validation. The fixed seed makes the split
# reproducible across kernel restarts, so eval metrics stay comparable between runs.
train_val_split = dataset.train_test_split(test_size=0.1, seed=42)

train_dataset = train_val_split['train']
val_dataset = train_val_split['test']

print(f"Training set size: {len(train_dataset)}")
print(f"Validation set size: {len(val_dataset)}")

In [7]:
from trl import SFTConfig, SFTTrainer

sft_config = SFTConfig(
    # Trade compute for memory by recomputing activations in the backward pass;
    # the non-reentrant implementation is the one PyTorch recommends.
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={'use_reentrant': False},
    gradient_accumulation_steps=1,
    per_device_train_batch_size=16,
    # Retry with a smaller batch size on CUDA OOM instead of crashing.
    auto_find_batch_size=True,
    # NOTE(review): newer TRL releases rename this argument to `max_length`;
    # confirm against the installed TRL version.
    max_seq_length=64,
    # Pack multiple short examples into each max_seq_length-token sequence.
    packing=True,
    learning_rate=3e-4,
    # Paged 8-bit AdamW (bitsandbytes) keeps optimizer state small.
    optim='paged_adamw_8bit',
    # Mixed precision matches the bnb_4bit_compute_dtype (torch.bfloat16) used when
    # loading the model. The previous fp16 AMP setting mixed two half-precision
    # formats; bf16 keeps them consistent and needs no loss scaling.
    # Requires bf16-capable hardware (e.g. NVIDIA Ampere or newer).
    bf16=True,
    logging_steps=1,
    save_strategy="steps",
    save_steps=20,
    save_total_limit=3,
    report_to=["wandb"],
    run_name="finetune",
    num_train_epochs=60,
    eval_strategy="steps",
    eval_steps=20,
    do_eval=True,
)


In [None]:
# The KV cache is unused during training and is incompatible with gradient
# checkpointing, so turn it off before handing the model to the trainer.
model.config.use_cache = False
trainer = SFTTrainer(
    model=model,
    args=sft_config,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

In [None]:
# Launch fine-tuning with the model, config, and datasets passed to the trainer.
trainer.train()