In [1]:
import torch
from torch import nn
from datasets import Dataset
from trl import SFTTrainer, SFTConfig
from unsloth import FastLanguageModel, FastModel

  from .autonotebook import tqdm as notebook_tqdm

Please restructure your imports with 'import unsloth' at the top of your file.
  from unsloth import FastLanguageModel, FastModel


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
INFO 10-30 16:30:48 [__init__.py:216] Automatically detected platform cuda.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [2]:
data = [
    {"messages": [{"role": "user", "content": "Di hola en una palabra"},
                  {"role": "assistant", "content": "Hola"}]},
    {"messages": [{"role": "user", "content": "¿Cuánto es 2+2?"},
                  {"role": "assistant", "content": "4"}]},
]
train_ds = Dataset.from_list(data)

In [3]:
# -------- 3) Tu SFTTrainer personalizado con regularizador --------
class MySFTTrainer(SFTTrainer):
    def __init__(self, *args, lmbda_l2=0.0, lmbda_entropy=0.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.lmbda_l2 = lmbda_l2
        self.lmbda_entropy = lmbda_entropy
        self._reg_params = None

    def _collect_lora_params(self):
        if self._reg_params is None:
            params = []
            for n, p in self.model.named_parameters():
                # Filtra los adaptadores LoRA (ajusta el criterio si usas otro naming)
                if "lora_" in n and p.requires_grad:
                    params.append(p)
            self._reg_params = params
        return self._reg_params

    def _l2_regularizer(self):
        if self.lmbda_l2 <= 0:
            return torch.tensor(0.0, device=self.model.device)
        reg = torch.tensor(0.0, device=self.model.device)
        for p in self._collect_lora_params():
            reg = reg + torch.sum(p ** 2)
        return self.lmbda_l2 * reg

    def _entropy_regularizer(self, logits, labels):
        # Entropía media de las distribuciones de salida (ignorando etiquetas -100)
        if self.lmbda_entropy <= 0:
            return torch.tensor(0.0, device=logits.device)
        with torch.no_grad():
            mask = (labels != -100).float()
            count = mask.sum().clamp_min(1.0)
        # Entropía: -sum p log p
        probs = torch.softmax(logits, dim=-1)
        entropy = -(probs * torch.log(probs.clamp_min(1e-12))).sum(dim=-1)  # [B, T]
        # Promedio en posiciones válidas
        ent_mean = (entropy * mask).sum() / count
        return self.lmbda_entropy * (-ent_mean)  # penaliza entropía alta (o cambia el signo a tu gusto)

    def compute_loss(self, model, inputs, return_outputs=False):
        # Importante: pedimos outputs para tener 'logits' y 'labels'
        outputs = model(**inputs)
        # SFTTrainer ya prepara 'labels' en inputs
        logits = outputs.logits
        labels = inputs.get("labels")

        # Pérdida base de SFT (cross-entropy) igual que Trainer
        # Si no viene ya en outputs, la calculamos manualmente:
        if outputs.loss is not None:
            loss = outputs.loss
        else:
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        # Suma de regularizadores
        reg = self._l2_regularizer() + self._entropy_regularizer(logits, labels)
        total = loss + reg

        return (total, outputs) if return_outputs else total


In [None]:
max_seq_length = 512
model, tokenizer = FastModel.from_pretrained(
    model_name="unsloth/Llama-3.2-1B-Instruct",   # cambia a tu modelo (Qwen, Llama, etc.)
    max_seq_length=max_seq_length,
    load_in_4bit=True,
    load_in_16bit=False,
    full_finetuning=False,
)