In [2]:
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch

In [5]:
dataset = load_dataset("json", data_files="/home/ubuntu/projek_chatbot_galang/training_model/dataset_final_v2.jsonl")["train"]
dataset_split = dataset.train_test_split(test_size=0.1, seed=42)

In [6]:
model_name = "SeaLLMs/SeaLLMs-v3-1.5B-Chat"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, load_in_4bit=True, device_map="auto")

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


In [7]:
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    lora_dropout=0,
    bias="none",
    task_type="CAUSAL_LM"
)

model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)

In [None]:
def format_sample(example):
    prompt = (
        f"<|system|>\nKamu adalah asisten ahli dalam bidang hukum perpajakan di Indonesia. "
        f"Jawabanmu harus faktual, singkat, dan menyebutkan sumber hukum (pasal/UU) di akhir.\n"
        f"<|user|>\n{example['instruction']}\n"
        f"<|assistant|>\n{example['output']}\nSource: {example['source']}"
    )

    tokenized = tokenizer(
        prompt, 
        truncation=True, 
        max_length=512, 
        padding="max_length",
    )
    tokenized["labels"] = tokenized["input_ids"].copy()

    return tokenized

tokenized_dataset = dataset_split.map(format_sample)

Map:   0%|          | 0/2268 [00:00<?, ? examples/s]

Map:   0%|          | 0/253 [00:00<?, ? examples/s]

: 

In [78]:
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

In [70]:
import os
os.environ["WANDB_DISABLED"] = "true"

In [86]:
output_model_name = f"taxbot-SeaLLMs-v3-1.5B-Chat-v2"

In [None]:
training_args = TrainingArguments(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    num_train_epochs=3,
    learning_rate=1e-4,
    fp16=True,
    logging_steps=5,
    #optim="paged_adamw_8bit",
    output_dir=f"./{output_model_name}",
    report_to="none"
)

In [80]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    data_collator=data_collator,
)

In [81]:
trainer.train()

  return fn(*args, **kwargs)


Step,Training Loss
5,1.9328
10,1.8621
15,1.7777
20,1.6824
25,1.5525
30,1.379
35,1.2279
40,1.1529
45,1.0833
50,1.0653


  return fn(*args, **kwargs)


TrainOutput(global_step=582, training_loss=0.823749357687239, metrics={'train_runtime': 3048.3219, 'train_samples_per_second': 1.525, 'train_steps_per_second': 0.191, 'total_flos': 1.87335173210112e+16, 'train_loss': 0.823749357687239, 'epoch': 3.0})

In [87]:
model.save_pretrained(f"./{output_model_name}")

In [88]:
tokenizer.save_pretrained(f"./{output_model_name}")

('./taxbot-SeaLLMs-v3-1.5B-Chat-v2/tokenizer_config.json',
 './taxbot-SeaLLMs-v3-1.5B-Chat-v2/special_tokens_map.json',
 './taxbot-SeaLLMs-v3-1.5B-Chat-v2/chat_template.jinja',
 './taxbot-SeaLLMs-v3-1.5B-Chat-v2/vocab.json',
 './taxbot-SeaLLMs-v3-1.5B-Chat-v2/merges.txt',
 './taxbot-SeaLLMs-v3-1.5B-Chat-v2/added_tokens.json',
 './taxbot-SeaLLMs-v3-1.5B-Chat-v2/tokenizer.json')

In [None]:
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained(model_name, dtype="float16")
lora = PeftModel.from_pretrained(base, output_model_name)

merged = lora.merge_and_unload()
merged.save_pretrained(f"./merged-{output_model_name}")
tokenizer.save_pretrained(f"./merged-{output_model_name}")



('./merged-taxbot-SeaLLMs-v3-1.5B-Chat-v2/tokenizer_config.json',
 './merged-taxbot-SeaLLMs-v3-1.5B-Chat-v2/special_tokens_map.json',
 './merged-taxbot-SeaLLMs-v3-1.5B-Chat-v2/chat_template.jinja',
 './merged-taxbot-SeaLLMs-v3-1.5B-Chat-v2/vocab.json',
 './merged-taxbot-SeaLLMs-v3-1.5B-Chat-v2/merges.txt',
 './merged-taxbot-SeaLLMs-v3-1.5B-Chat-v2/added_tokens.json',
 './merged-taxbot-SeaLLMs-v3-1.5B-Chat-v2/tokenizer.json')

: 