<a href="https://colab.research.google.com/github/eneskosar/paper1/blob/main/2decoders_ft.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# decoder15_multi — Restricted-logit classification w/ LoRA + early stopping

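This notebook fine-tunes several small (≤2B-class) decoder-only LMs on FOLIO, framed as 3-way classification (entailed / contradicted / unknown). Each model is trained with a restricted-logit objective — cross-entropy over only the logits of the `A`/`B`/`C` answer tokens at the last prompt position — using LoRA adapters for efficiency and early stopping on validation loss.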


In [None]:
# CELL 0 — Mount Drive (for saving adapters)
# Mounts Google Drive at /content/drive. OUT_ROOT in the config cell lives under
# this mount point, so run this cell before any cell that writes outputs.
from google.colab import drive
drive.mount("/content/drive")


In [None]:
# CELL 1 — Install dependencies
# transformers / datasets / accelerate / peft are pinned to exact versions;
# sentencepiece, safetensors and evaluate are unpinned and upgraded to latest (-U).
!pip -q install -U transformers==4.46.0 datasets==2.21.0 accelerate==0.34.2 peft==0.12.0 sentencepiece safetensors evaluate


In [None]:
# CELL 2 — (Optional) Hugging Face login (needed for gated models/datasets)
# Skip this cell if every entry in MODEL_LIST and the dataset are publicly accessible.
from huggingface_hub import login
login()


In [None]:
# CELL 3 — Config (models, LoRA, early stopping)
# Every constant below is read by the training loop further down; change values
# here rather than editing the loop.
import os, re, torch, random, numpy as np
# Opt in to TF32 for CUDA matmuls (flag only; has no effect without a CUDA device).
torch.backends.cuda.matmul.allow_tf32 = True

# ---- output root ----
# One sub-directory per model is created under this path by the training loop.
OUT_ROOT = "/content/drive/MyDrive/logic/decoder15_multi_lora"
os.makedirs(OUT_ROOT, exist_ok=True)

# ---- model list (≤2B) ----
# Hugging Face model ids; the training loop fine-tunes them one after another, in this order.
MODEL_LIST = [
    "google/gemma-2-2b-it",
    "stabilityai/stablelm-2-zephyr-1_6b",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "Qwen/Qwen2.5-1.5B-Instruct",
    "meta-llama/Llama-3.2-1B-Instruct",
]

# ---- data ----
DATASET_NAME = "yale-nlp/FOLIO"

# ---- training ----
MAX_LEN      = 1024   # max tokens per prompt (tokenizer max_length, with truncation on)
EPOCHS       = 30     # upper bound on epochs; early stopping may end training sooner
LR           = 5e-5   # learning rate passed to TrainingArguments
WEIGHT_DECAY = 1e-3
WARMUP_RATIO = 0.03   # passed as warmup_ratio

TRAIN_BS     = 1      # per-device train batch size
EVAL_BS      = 2      # per-device eval batch size
GRAD_ACCUM   = 16     # gradient accumulation steps -> TRAIN_BS * GRAD_ACCUM = 16 examples per update per device

LOG_STEPS    = 20     # logging interval, in steps
EVAL_STEPS   = 200    # evaluation interval, in steps
SAVE_STEPS   = 200    # checkpoint interval, in steps
SEED         = 42     # used by seed_all() and as the TrainingArguments seed

# ---- early stopping ----
# Both values are forwarded to transformers.EarlyStoppingCallback.
EARLY_STOP_PATIENCE = 3  # stop after N evals without improvement
EARLY_STOP_THRESHOLD = 0.0

# ---- LoRA (PEFT) ----
# Forwarded to peft.LoraConfig as r / lora_alpha / lora_dropout.
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05

def seed_all(seed=SEED):
    """Seed the Python, NumPy and PyTorch (incl. all CUDA devices) RNGs with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Seed every RNG once up front, then echo the effective configuration.
seed_all(SEED)
print(f"OUT_ROOT: {OUT_ROOT}")
print(f"Models: {MODEL_LIST}")


In [None]:
# CELL 4 — Load dataset
# Downloads (or reads from cache) every split of DATASET_NAME, then prints the
# split summary, the train column names and the first train row as a sanity check.
from datasets import load_dataset

ds = load_dataset(DATASET_NAME)
print(ds)
print(f"Columns: {ds['train'].column_names}")
print(f"Example: {ds['train'][0]}")


In [None]:
# CELL 5 — Build prompts + labels (3-way: A/B/C)
from datasets import DatasetDict
from collections import Counter

# Canonical label name -> answer letter shown in the prompt.
LABEL_TO_LETTER = {
    "True": "A",
    "False": "B",
    "Unknown": "C",
}

# Alternate spellings -> canonical label name.
ALT_LABELS = {
    "Uncertain": "Unknown",
    "uncertain": "Unknown",
    "true": "True",
    "false": "False",
    "unknown": "Unknown",
}


def normalize_label(lbl: str) -> str:
    """Map a raw label to its canonical name (``True`` / ``False`` / ``Unknown``).

    The value is stringified and stripped of surrounding whitespace, then
    translated through ``ALT_LABELS`` when it is one of the known alternates.

    Raises:
        ValueError: if the result is not a key of ``LABEL_TO_LETTER``.
    """
    cleaned = str(lbl).strip()
    canonical = ALT_LABELS[cleaned] if cleaned in ALT_LABELS else cleaned
    if canonical in LABEL_TO_LETTER:
        return canonical
    raise ValueError(f"Unexpected label: {lbl!r}")

SYSTEM_MSG = "You are a careful logician. Follow the user's output format exactly."


def build_user_text(premises, conclusion):
    """Render the user-side prompt for one example.

    Args:
        premises: either a list/tuple of premise strings, or a single string.
            A string containing several newline-separated premises is split so
            that every premise gets its own ``- `` bullet (previously only the
            first line of such a string was bulleted).
        conclusion: the statement to classify.

    Returns:
        The prompt text, ending in ``"Answer:"`` so the next token is the answer letter.
    """
    if isinstance(premises, (list, tuple)):
        items = list(premises)
    else:
        raw = str(premises)
        # One premise per non-blank line; fall back to the raw value so an
        # empty/blank string still renders a single (empty) bullet as before.
        items = [line.strip() for line in raw.splitlines() if line.strip()] or [raw]
    prem = "\n".join(f"- {p}" for p in items)
    return (
        "Task: Determine whether the conclusion is entailed, contradicted, or unknown given the premises.\n"
        "Premises:\n"
        f"{prem}\n\n"
        "Conclusion:\n"
        f"{conclusion}\n\n"
        "Output format: Answer: A (entailed), B (contradicted), or C (unknown).\n"
        "Answer:"
    )

def map_ex(ex):
    """Convert one raw dataset row into prompt text plus label fields.

    Reads ``ex["label"]``, ``ex["premises"]`` and ``ex["conclusion"]`` and
    returns ``user_text`` (the rendered prompt), ``label`` (canonical label
    name) and ``class_id`` (0/1/2 for answer letters A/B/C).
    """
    canonical = normalize_label(ex["label"])
    prompt = build_user_text(ex["premises"], ex["conclusion"])
    letter = LABEL_TO_LETTER[canonical]
    return {
        "user_text": prompt,
        "label": canonical,
        "class_id": ("A", "B", "C").index(letter),
    }

# Apply map_ex to every split, dropping the raw columns so only
# user_text / label / class_id remain.
ds2 = DatasetDict(
    {
        split_name: ds[split_name].map(
            map_ex,
            remove_columns=ds[split_name].column_names,
        )
        for split_name in ds
    }
)
val_label_counts = Counter(ds2["validation"]["label"])
print("Val label dist:", val_label_counts)


In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Pre-flight access check: try to load the tokenizer and weights of every model
# so gated-repo / auth problems surface before the long training loop starts.
# The list is derived from MODEL_LIST so the models checked here are exactly
# the ones that will be trained (a separate hard-coded copy can silently drift).
models = list(MODEL_LIST)

for m in models:
    try:
        AutoTokenizer.from_pretrained(m)
        AutoModelForCausalLM.from_pretrained(m)
    except Exception as e:  # best-effort probe: report the failure and keep checking the rest
        print(f"❌ Access FAILED: {m}\n{e}\n")
    else:
        print(f"✅ Access OK: {m}")


In [None]:
import pandas as pd
import torch
from datasets import DatasetDict
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    TrainingArguments, Trainer, DataCollatorWithPadding,
    TrainerCallback, EarlyStoppingCallback
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# -------------------------
# Helpers: label token ids
# -------------------------
LABEL_CANDIDATES = [" A", " B", " C", "A", "B", "C"]  # prefer space-prefixed first


def pick_label_token_ids(tokenizer):
    """Return single-token ids for the answer letters, ordered ``[A, B, C]``.

    For each letter, the variants of that letter found in ``LABEL_CANDIDATES``
    are tried in list order (space-prefixed first by default); the first one
    the tokenizer encodes as exactly one token is used.

    Args:
        tokenizer: callable where
            ``tokenizer(text, add_special_tokens=False)["input_ids"]`` is the
            list of token ids for ``text``.

    Returns:
        A list of three distinct token ids, for A, B and C respectively.

    Raises:
        ValueError: if some letter has no single-token variant, or if the
            chosen ids are not pairwise distinct.
    """
    def one_id(text):
        ids = tokenizer(text, add_special_tokens=False)["input_ids"]
        return ids[0] if len(ids) == 1 else None

    picked = []
    for letter in ("A", "B", "C"):
        token_id = None
        for variant in LABEL_CANDIDATES:
            if variant.strip() != letter:
                continue
            token_id = one_id(variant)
            # Explicit None test: 0 is a legitimate token id and must not be
            # mistaken for "not found" (which `a or b` chaining would do).
            if token_id is not None:
                break
        if token_id is None:
            raise ValueError("Could not locate single-token ids for A/B/C.")
        picked.append(token_id)

    # Two letters sharing one id (e.g. both mapped to an unknown-token id)
    # would make the restricted 3-way softmax meaningless.
    if len(set(picked)) != len(picked):
        raise ValueError("Could not find 3 single-token label candidates. Try adjusting LABEL_CANDIDATES.")
    return picked

# -------------------------
# Restricted-logit trainer
# -------------------------
class ContrastiveLabelTrainer(Trainer):
    """Trainer whose loss is a cross-entropy over the label-token logits only.

    The model's next-token logits are read at the last non-padded position of
    each sequence and restricted to ``label_token_ids``; the target is the
    integer ``class_id`` supplied in each batch.

    Args:
        label_token_ids: sequence of vocabulary ids, one per class, in class-id order.
        *args, **kwargs: forwarded unchanged to ``transformers.Trainer``.
    """

    def __init__(self, *args, label_token_ids=None, **kwargs):
        super().__init__(*args, **kwargs)
        if label_token_ids is None:
            raise ValueError("label_token_ids is required (one vocabulary id per class).")
        self.label_token_ids = torch.tensor(label_token_ids, dtype=torch.long)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        """Return the restricted-logit cross-entropy (and the model outputs if requested).

        ``inputs`` must contain ``input_ids``, ``attention_mask`` and ``class_id``;
        ``class_id`` is popped from ``inputs`` before the forward pass.
        """
        class_id = inputs.pop("class_id")

        outputs = model(**inputs)
        logits = outputs.logits  # [B, T, V]

        # Position of the last real token in each row. Searching the flipped
        # mask for its first 1 works for right- AND left-padded batches, whereas
        # `mask.sum() - 1` is only correct when padding is on the right.
        attn = inputs["attention_mask"]
        trailing_pad = attn.flip(dims=[1]).long().argmax(dim=1)  # [B]
        last_idx = (attn.size(1) - 1 - trailing_pad).to(logits.device)
        rows = torch.arange(logits.size(0), device=logits.device)
        last_logits = logits[rows, last_idx]  # [B, V]

        label_ids = self.label_token_ids.to(logits.device)  # [C]
        restricted = last_logits[:, label_ids]  # [B, C]
        # Softmax over 3 logits is cheap; do it in float32 so half-precision
        # models do not lose precision in the loss.
        loss = torch.nn.functional.cross_entropy(restricted.float(), class_id.to(logits.device))
        return (loss, outputs) if return_outputs else loss

def accuracy_from_restricted(restricted_logits, class_ids):
    """Return the fraction of rows whose argmax over the last axis equals ``class_ids``."""
    predicted = restricted_logits.argmax(axis=-1)
    hits = predicted == class_ids
    return hits.mean().item()

class ContrastiveEvalTrainer(ContrastiveLabelTrainer):
    """Evaluation variant that reports restricted (per-class) logits as predictions."""

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Run one eval batch and return ``(loss, restricted_logits, class_id)``.

        ``restricted_logits`` has one column per label token and ``class_id``
        is the gold class index; both are returned on CPU so metric code can
        consume them directly. When ``prediction_loss_only`` is true only the
        loss is returned (the other two slots are ``None``).
        """
        inputs = self._prepare_inputs(inputs)  # make sure tensors sit on the model's device
        with torch.no_grad():
            # Single forward pass: compute_loss pops "class_id", so give it a
            # shallow copy and reuse the outputs it already produced instead of
            # calling the model a second time.
            loss, outputs = self.compute_loss(model, dict(inputs), return_outputs=True)
            loss = loss.detach().cpu()
            if prediction_loss_only:
                return (loss, None, None)

            logits = outputs.logits  # [B,T,V]
            attn = inputs["attention_mask"]
            # Last real token per row, independent of padding side.
            trailing_pad = attn.flip(dims=[1]).long().argmax(dim=1)
            last_idx = (attn.size(1) - 1 - trailing_pad).to(logits.device)
            rows = torch.arange(logits.size(0), device=logits.device)
            last_logits = logits[rows, last_idx]  # [B,V]

            label_ids = self.label_token_ids.to(logits.device)
            restricted = last_logits[:, label_ids].detach().float().cpu()  # [B,C]
            class_id = inputs["class_id"].detach().cpu()

        return (loss, restricted, class_id)

def compute_metrics_from_restricted(eval_pred):
    """Compute accuracy from an eval prediction.

    ``eval_pred`` unpacks into ``(restricted_logits, class_ids)``; the result
    is a dict with the single key ``"accuracy"``.
    """
    logits_per_class, gold_ids = eval_pred
    accuracy = accuracy_from_restricted(logits_per_class, gold_ids)
    return {"accuracy": accuracy}

class PrintEval(TrainerCallback):
    """Print a one-line summary (step, eval loss, eval accuracy) after each evaluation."""

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if not metrics:
            return
        # A missing key would make `.get()` return None, and formatting None
        # with ":.4f" raises TypeError; show nan for absent metrics instead.
        loss = metrics.get("eval_loss")
        acc = metrics.get("eval_accuracy")
        loss = float("nan") if loss is None else loss
        acc = float("nan") if acc is None else acc
        print(f"[EVAL step {state.global_step}] loss={loss:.4f} acc={acc:.4f}")

# -------------------------
# LoRA target module finder
# -------------------------
# Leaf module names considered as LoRA injection points.
CANDIDATE_TARGETS = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
    "Wqkv", "wo", "wq", "wk", "wv",
]


def find_lora_targets(model):
    """Choose LoRA ``target_modules`` names for ``model``.

    Collects the last dotted component of every name yielded by
    ``model.named_modules()`` and keeps those listed in ``CANDIDATE_TARGETS``.

    Returns:
        * the ``*_proj`` names that are present, in a fixed canonical order, if any;
        * otherwise up to eight of the matched names, sorted;
        * otherwise ``["q_proj", "v_proj"]`` as a last-resort default.
    """
    leaf_names = {
        qualified.rsplit(".", 1)[-1]
        for qualified, _module in model.named_modules()
    }
    found = leaf_names.intersection(CANDIDATE_TARGETS)

    canonical_order = (
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    )
    preferred = [name for name in canonical_order if name in found]
    if preferred:
        return preferred
    if not found:
        return ["q_proj", "v_proj"]
    # fallback: anything we found
    return sorted(found)[:8]

# -------------------------
# Main: loop models
# -------------------------
import gc

# One row per finished model: model id, final eval loss/accuracy, train runtime.
results = []

# Pre-bind so the per-iteration cleanup below is valid on the first pass too.
model = None
trainer = None

for MODEL_NAME in MODEL_LIST:
    print("\n" + "="*90)
    print("MODEL:", MODEL_NAME)

    # Release the previous model/trainer BEFORE loading the next one; otherwise
    # two models are resident on the GPU at once. (The last model/trainer stay
    # bound after the loop for interactive inspection.)
    model = None
    trainer = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # per-model output dir
    safe = re.sub(r"[^a-zA-Z0-9_.-]+", "_", MODEL_NAME)
    OUT_DIR = os.path.join(OUT_ROOT, safe)
    os.makedirs(OUT_DIR, exist_ok=True)
    print("OUT_DIR:", OUT_DIR)

    # tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # The answer logits are read at the last real token of each row. Pin right
    # padding so that position is well defined regardless of a checkpoint's
    # default padding side, and truncate from the left so an over-long prompt
    # keeps its trailing "Answer:" + generation prompt.
    tokenizer.padding_side = "right"
    tokenizer.truncation_side = "left"

    # tokenization uses chat template
    def make_prompt_text(user_text: str) -> str:
        # Prepend SYSTEM_MSG to the user_text as some models (like Gemma) often don't support a separate system role.
        # This approach ensures the system instructions are still passed to the model.
        full_user_text = f"{SYSTEM_MSG}\n\n{user_text}"
        msgs = [
            {"role":"user","content": full_user_text},
        ]
        return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)

    label_token_ids = pick_label_token_ids(tokenizer)
    print("Label token ids:", label_token_ids, "->", [tokenizer.decode([i]) for i in label_token_ids])

    def tokenize_ex(ex):
        text = make_prompt_text(ex["user_text"])
        # Many chat templates already emit the BOS token; letting the tokenizer
        # add special tokens again would duplicate it. Only add them when the
        # rendered template did not.
        bos = tokenizer.bos_token
        template_has_bos = bool(bos) and text.startswith(bos)
        out = tokenizer(
            text,
            truncation=True,
            max_length=MAX_LEN,
            padding=False,
            add_special_tokens=not template_has_bos,
        )
        out["class_id"] = ex["class_id"]
        return out

    tok = DatasetDict({k: ds2[k].map(tokenize_ex, remove_columns=ds2[k].column_names) for k in ds2})

    # model
    torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch_dtype)

    # LoRA
    target_modules = find_lora_targets(model)
    print("LoRA target_modules:", target_modules)

    lora_cfg = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=target_modules,
    )
    model = get_peft_model(model, lora_cfg)
    model.print_trainable_parameters()

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer, pad_to_multiple_of=8)

    args = TrainingArguments(
        output_dir=OUT_DIR,
        seed=SEED,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=TRAIN_BS,
        per_device_eval_batch_size=EVAL_BS,
        gradient_accumulation_steps=GRAD_ACCUM,
        learning_rate=LR,
        weight_decay=WEIGHT_DECAY,
        warmup_ratio=WARMUP_RATIO,
        logging_steps=LOG_STEPS,
        # `evaluation_strategy` is the deprecated spelling of this argument.
        eval_strategy="steps",
        eval_steps=EVAL_STEPS,
        save_strategy="steps",
        save_steps=SAVE_STEPS,
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        bf16=torch.cuda.is_available(),
        fp16=False,
        report_to=[],
        # Keep "class_id" in the batch: the custom loss needs it and it is not
        # a model forward() argument.
        remove_unused_columns=False,
        label_names=["class_id"],
    )

    trainer = ContrastiveEvalTrainer(
        model=model,
        args=args,
        train_dataset=tok["train"],
        eval_dataset=tok["validation"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        label_token_ids=label_token_ids,
        compute_metrics=compute_metrics_from_restricted,
        callbacks=[
            PrintEval(),
            EarlyStoppingCallback(
                early_stopping_patience=EARLY_STOP_PATIENCE,
                early_stopping_threshold=EARLY_STOP_THRESHOLD,
            ),
        ],
    )

    # Train
    train_result = trainer.train()
    print(train_result)

    # Evaluate + save adapter
    metrics = trainer.evaluate()
    print("Final validation metrics:", metrics)

    model.save_pretrained(OUT_DIR)
    tokenizer.save_pretrained(OUT_DIR)
    print("Saved to:", OUT_DIR)

    # Update results table
    row = {
        "model": MODEL_NAME,
        "eval_loss": float(metrics.get("eval_loss", float("nan"))),
        "eval_accuracy": float(metrics.get("eval_accuracy", float("nan"))),
        "train_runtime_sec": float(train_result.metrics.get("train_runtime", float("nan"))),
    }
    results.append(row)

    df = pd.DataFrame(results).sort_values("eval_accuracy", ascending=False)
    print("\n=== Results so far (sorted by eval_accuracy) ===")
    display(df)
