In [1]:
import torch
from torch import nn
from datasets import Dataset
from trl import SFTTrainer, SFTConfig
from unsloth import FastLanguageModel, FastModel, UnslothTrainer, UnslothTrainingArguments
from transformers.trainer import Accelerator
import warnings
from transformers import TrainerCallback


  from .autonotebook import tqdm as notebook_tqdm

Please restructure your imports with 'import unsloth' at the top of your file.
  from unsloth import FastLanguageModel, FastModel, UnslothTrainer, UnslothTrainingArguments


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
INFO 10-31 15:45:13 [__init__.py:216] Automatically detected platform cuda.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [2]:
data = [
    {"text": f"Your training data example {i}"}
    for i in range(1, 10)
]
train_ds = Dataset.from_list(data)

In [3]:
train_ds

Dataset({
    features: ['text'],
    num_rows: 9
})

In [4]:
# -------- 3) Tu SFTTrainer personalizado con regularizador --------

class KLLoggerCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs and "kl_loss" in logs:
            print(f"[step {state.global_step}] CE = {logs['ce_loss']:.4f} - KL = {logs['kl_loss']:.4f} - Total = {logs['loss']:.4f}")

class SFTTrainerWithKL(SFTTrainer):
    def __init__(self, *args, kl_lambda=0.01, temperature=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.kl_lambda = kl_lambda
        self.temperature = temperature
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100)

    @torch.no_grad()
    def _logits_ref_without_lora(self, inputs_no_labels):
        # Ejecuta el MISMO modelo con LoRA desactivado
        try:
            with self.model.disable_adapter():
                outputs_ref = self.model(**inputs_no_labels)
        except Exception:
            was_training = self.model.training
            if hasattr(self.model, "disable_adapter"):
                self.model.disable_adapter()
            outputs_ref = self.model(**inputs_no_labels)
            if hasattr(self.model, "enable_adapter"):
                self.model.enable_adapter()
            if was_training:
                self.model.train()
        return outputs_ref.logits

    def _kl_pt_pref(self, logits_t, logits_ref, labels):
        # Shift para alinear con CE
        shift_t = logits_t[:, :-1, :] / self.temperature
        shift_ref = logits_ref[:, :-1, :] / self.temperature
        shift_labels = labels[:, 1:]

        valid = (shift_labels != -100).float()
        denom = valid.sum().clamp_min(1.0)

        log_pt = torch.log_softmax(shift_t, dim=-1)
        log_pref = torch.log_softmax(shift_ref, dim=-1)
        pt = log_pt.exp()

        kl_tokens = (pt * (log_pt - log_pref)).sum(dim=-1)  # [B, T-1]
        kl_mean = (kl_tokens * valid).sum() / denom
        return kl_mean

    # Acepta el arg extra de Unsloth
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        # 1) Separa labels y NO se los pases al modelo
        labels = inputs.get("labels")
        inputs_no_labels = {k: v for k, v in inputs.items() if k != "labels"}

        # 2) Forward normal (LoRA activo) -> logits_t
        outputs_t = model(**inputs_no_labels)
        logits_t = outputs_t.logits

        # 3) CE manual (evita la fused CE de Unsloth)
        shift_logits = logits_t[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()
        loss_ce = self.ce_loss(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )

        # 4) Forward de referencia (LoRA desactivado, sin gradiente)
        with torch.no_grad():
            logits_ref = self._logits_ref_without_lora(inputs_no_labels)

        # 5) KL promedio en posiciones válidas
        kl = self._kl_pt_pref(logits_t, logits_ref, labels)

        total = loss_ce + self.kl_lambda * kl

        if self.state is not None and self.state.global_step % self.args.logging_steps == 0:
            self.log({"kl_loss": kl.detach().float().item(), "ce_loss": loss_ce.detach().float().item()})
            
        return (total, outputs_t) if return_outputs else total



In [5]:
max_seq_length = 512
model, tokenizer = FastModel.from_pretrained(
    model_name="Qwen/Qwen3-8B",   # cambia a tu modelo (Qwen, Llama, etc.)
    max_seq_length=max_seq_length,
    load_in_4bit=True,
    load_in_16bit=False,
    full_finetuning=False,
)

==((====))==  Unsloth 2025.9.8: Fast Qwen3 patching. Transformers: 4.56.2. vLLM: 0.10.2.
   \\   /|    NVIDIA GeForce RTX 4090. Num GPUs = 1. Max memory: 23.988 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu128. CUDA: 8.9. CUDA Toolkit: 12.8. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.32.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.27it/s]


In [6]:
model = FastLanguageModel.get_peft_model(
    model,
    r=8,
    target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
    lora_alpha=16,
    lora_dropout=0.0,
    use_gradient_checkpointing="unsloth",
    max_seq_length=max_seq_length,
)

Unsloth: Making `model.base_model.model.model` require gradients


In [8]:
# ---------- 4) Configuración y entrenamiento ----------
training_args = SFTConfig(
    output_dir="outputs-kl",
    num_train_epochs=5,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    logging_steps=1,
    report_to="none",
    packing=False,
)

trainer = SFTTrainerWithKL(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    kl_lambda=0.1,     # peso del KL (ajústalo)
    callbacks=[KLLoggerCallback()]
)


ValueError: The specified `eos_token` ('<EOS_TOKEN>') is not found in the vocabulary of the given `processing_class` (Qwen2TokenizerFast). Ensure that the `eos_token` exists in the vocabulary before using it as an EOS token.

In [None]:
trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 9 | Num Epochs = 5 | Total steps = 45
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 1 x 1) = 1
 "-____-"     Trainable parameters = 21,823,488 of 8,212,558,848 (0.27% trained)


[step 0] KL = 3.7438
[step 0] CE = 1.9908


Step,Training Loss
1,2.0282
2,2.1032
3,2.028
4,2.0611
5,1.8131
6,2.1169
7,2.0471
8,1.6995
9,1.6711
10,1.7342


[step 1] KL = 3.7438
[step 1] CE = 2.0658
[step 2] KL = 3.7484
[step 2] CE = 1.9905
[step 3] KL = 3.7540
[step 3] CE = 2.0235
[step 4] KL = 3.7555
[step 4] CE = 1.7755
[step 5] KL = 3.7619
[step 5] CE = 2.0793
[step 6] KL = 3.7796
[step 6] CE = 2.0093
[step 7] KL = 3.8014
[step 7] CE = 1.6615
[step 8] KL = 3.8252
[step 8] CE = 1.6329
[step 9] KL = 3.8570
[step 9] CE = 1.6956
[step 10] KL = 3.8991
[step 10] CE = 1.4756
[step 11] KL = 3.9277
[step 11] CE = 1.2997
[step 12] KL = 3.9744
[step 12] CE = 1.3073
[step 13] KL = 4.0240
[step 13] CE = 1.1402
[step 14] KL = 4.0662
[step 14] CE = 1.1502
[step 15] KL = 4.1194
[step 15] CE = 0.9394
[step 16] KL = 4.1324
[step 16] CE = 1.2358
[step 17] KL = 4.1681
[step 17] CE = 0.9353
[step 18] KL = 4.2021
[step 18] CE = 1.0632
[step 19] KL = 4.2704
[step 19] CE = 0.8037
[step 20] KL = 4.3217
[step 20] CE = 0.8337
[step 21] KL = 4.3978
[step 21] CE = 0.7600
[step 22] KL = 4.4497
[step 22] CE = 0.7360
[step 23] KL = 4.5533
[step 23] CE = 0.8104
[step 

TrainOutput(global_step=45, training_loss=1.0278545763757494, metrics={'train_runtime': 23.6025, 'train_samples_per_second': 1.907, 'train_steps_per_second': 1.907, 'total_flos': 262318313963520.0, 'train_loss': 1.0278545763757494, 'epoch': 5.0})