In [None]:
from unsloth import FastLanguageModel, is_bfloat16_supported
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset, Dataset
from unsloth import FastLanguageModel, is_bfloat16_supported
import torch
from mcp.types import Tool, ToolAnnotations
import os 
import wandb
import torch
import json
from transformers import DataCollatorForSeq2Seq
from unsloth.chat_templates import train_on_responses_only
from urllib.parse import urlencode


In [None]:
# Credentials are read from the process environment instead of being written
# into the notebook, so a saved or shared copy never contains a secret.
# Export WANDB_API_KEY and HF_TOKEN before starting the kernel.
HF_TOKEN = os.environ.get("HF_TOKEN", "")  # "" when the variable is not exported

# Only supply a W&B project name when the environment has not already chosen
# one; "qwen3-sft" mirrors the output_dir used by the trainer cell.
os.environ.setdefault("WANDB_PROJECT", "qwen3-sft")

# WANDB_API_KEY is deliberately left untouched: assigning "" to it would mask
# a key that was exported outside the notebook. wandb.login() uses the
# exported key when present and otherwise asks for one interactively.
wandb.login()

In [None]:
# Sizing knobs shared by the loader and the LoRA wrapper below.
max_seq_length = 18000  # raise for longer reasoning traces
lora_rank = 32          # larger rank = smarter, but slower

# Load the 4-bit Qwen3-4B base model together with its tokenizer.
# To start from a local checkpoint instead, point model_name at it,
# e.g. "./qwen3-sft/checkpoint-765".
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Qwen3-4B-bnb-4bit",
    max_seq_length=max_seq_length,
    max_lora_rank=lora_rank,
    load_in_4bit=True,           # False for 16-bit LoRA
    fast_inference=False,
    gpu_memory_utilization=0.5,  # reduce if out of memory
)

# Wrap the base model with LoRA adapters on the listed projection modules.
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,  # any value > 0; suggested 8, 16, 32, 64, 128
    lora_alpha=2 * lora_rank,
    target_modules=[
        # Remove the q/k/v/o entries if out of memory.
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    use_gradient_checkpointing="unsloth",  # enables long-context finetuning
    random_state=3407,
)


In [None]:
from unsloth.chat_templates import get_chat_template

# Rebind `tokenizer` to the result of applying the "qwen-3" chat template.
tokenizer = get_chat_template(tokenizer, chat_template="qwen-3")

In [None]:
# Training examples: one parquet file from the Hub dataset repository.
# A bare data_files path is exposed under the "train" split name.
train = load_dataset(
    "jdaddyalbs/playwright-mcp-toolcalling",
    data_files="data/train_with_bad.parquet",
    split="train",
)

In [None]:
# Held-out examples used as the trainer's eval_dataset. The split is still
# called "train" because a bare data_files path is always exposed under that
# name, regardless of what the file itself contains.
test = load_dataset(
    "jdaddyalbs/playwright-mcp-toolcalling",
    data_files="data/test_with_bad.parquet",
    split="train",
)

In [None]:
# Supervised fine-tuning setup: the LoRA-wrapped model is trained on `train`
# and periodically evaluated on `test`.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train,
    eval_dataset=test,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
    args=SFTConfig(
        # Data: column of each example that holds the training text.
        dataset_text_field="bad_text",
        dataset_num_proc=2,
        # Batching: 1 sample per step x 4 accumulated steps per device.
        per_device_train_batch_size=1,  # could probably do 128
        gradient_accumulation_steps=4,  # gradient accumulation mimics a larger batch
        # Schedule and optimizer.
        num_train_epochs=3,
        learning_rate=2e-4,  # reduce to 2e-5 for long training runs
        lr_scheduler_type="linear",
        warmup_steps=5,
        optim="adamw_8bit",
        weight_decay=0.01,
        seed=3407,
        # Evaluation every 50 steps, one example at a time.
        eval_strategy="steps",
        eval_steps=50,
        per_device_eval_batch_size=1,
        eval_accumulation_steps=1,
        fp16_full_eval=True,
        # Logging and outputs.
        logging_steps=1,
        report_to="wandb",
        output_dir="qwen3-sft",
    ),
)

In [None]:
# Run the fine-tuning loop. resume_from_checkpoint=False means no saved
# checkpoint is requested; the value returned by train() is kept in
# trainer_stats for later inspection.
trainer_stats = trainer.train(resume_from_checkpoint=False)

In [None]:
# Export the fine-tuned model to GGUF and upload it to the Hugging Face Hub.
# The upload is issued exactly once per notebook run; repeating the call only
# redoes the same conversion and upload to the same repository.
# An empty HF_TOKEN is passed as None rather than as an empty string, so that
# a cached Hub login can be used instead -- NOTE(review): confirm that
# push_to_hub_gguf falls back to the cached token when token is None.
model.push_to_hub_gguf(
    "jdaddyalbs/bad_qwen3_sft_playwright_gguf_v2",
    tokenizer,
    token=HF_TOKEN or None,
)