In [32]:
# CELL 0 — Mount Google Drive (for saving LoRA adapters & results)
from google.colab import drive
drive.mount("/content/drive")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [33]:
# CELL 1 — Install dependencies (Colab)
!pip -q install -U transformers datasets accelerate peft evaluate bitsandbytes


In [34]:
# CELL 2 — (Required) Hugging Face login for gated FOLIO dataset
# Run this BEFORE load_dataset("yale-nlp/FOLIO")
from huggingface_hub import login
login()


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [35]:
# CELL 3 — Imports + experiment config
import os, gc, random
import numpy as np
import pandas as pd
import torch

from datasets import load_dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    EarlyStoppingCallback,
    TrainerCallback,
)
from peft import LoraConfig, get_peft_model, TaskType

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

print("torch:", torch.__version__)
print("cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("gpu:", torch.cuda.get_device_name(0))

# ---- Encoder–Decoder model list (<2B) ----
MODEL_LIST = [
    "google/flan-t5-base",
    "google/flan-t5-large",
    "facebook/bart-base",
    "facebook/bart-large",
]

# ---- Data / prompt ----
MAX_SOURCE_LEN = 1024
MAX_TARGET_LEN = 4   # output is just 'A'/'B'/'C'
BATCH = 8

# ---- Output ----
OUT_ROOT = "/content/drive/MyDrive/logic/folio_seq2seq_lora"
os.makedirs(OUT_ROOT, exist_ok=True)
print("OUT_ROOT:", OUT_ROOT)

def cleanup():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


torch: 2.9.0+cu126
cuda available: True
gpu: NVIDIA A100-SXM4-40GB
OUT_ROOT: /content/drive/MyDrive/logic/folio_seq2seq_lora


In [36]:
# CELL 4 — Load FOLIO + build prompts (same format as your decoder notebook)
from collections import Counter

ds = load_dataset("yale-nlp/FOLIO")
print(ds)
print("Train/Val sizes:", len(ds["train"]), len(ds["validation"]))
print("Columns:", ds["train"].column_names)

LABEL_TO_LETTER = {"True":"A", "False":"B", "Unknown":"C"}
ALT_LABELS = {
    "Uncertain":"Unknown", "uncertain":"Unknown",
    "true":"True", "false":"False", "unknown":"Unknown"
}

def normalize_label(lbl: str) -> str:
    s = str(lbl).strip()
    s = ALT_LABELS.get(s, s)
    if s not in LABEL_TO_LETTER:
        raise ValueError(f"Unexpected label: {lbl!r}")
    return s

def build_user_text(premises, conclusion):
    # premises can be list[str] or a single string; handle both
    if isinstance(premises, (list, tuple)):
        prem = "\n".join([f"- {p}" for p in premises])
    else:
        prem = f"- {premises}"
    return (
        "Task: Determine whether the conclusion is entailed, contradicted, or unknown given the premises.\n"
        "Premises:\n"
        f"{prem}\n\n"
        "Conclusion:\n"
        f"{conclusion}\n\n"
        "Output format: Answer: A (entailed), B (contradicted), or C (unknown).\n"
        "Answer:"
    )

def map_ex(ex):
    label = normalize_label(ex["label"])
    return {
        "user_text": build_user_text(ex["premises"], ex["conclusion"]),
        "label": label,
        "label_letter": LABEL_TO_LETTER[label],
    }

ds2 = DatasetDict({k: ds[k].map(map_ex, remove_columns=ds[k].column_names) for k in ds})
print("Val label dist:", Counter(ds2["validation"]["label"]))
print("\n--- sample prompt ---\n")
print(ds2["train"][0]["user_text"])
print("gold:", ds2["train"][0]["label_letter"])


DatasetDict({
    train: Dataset({
        features: ['story_id', 'premises', 'premises-FOL', 'conclusion', 'conclusion-FOL', 'label', 'example_id'],
        num_rows: 1001
    })
    validation: Dataset({
        features: ['story_id', 'premises', 'premises-FOL', 'conclusion', 'conclusion-FOL', 'label', 'example_id'],
        num_rows: 203
    })
})
Train/Val sizes: 1001 203
Columns: ['story_id', 'premises', 'premises-FOL', 'conclusion', 'conclusion-FOL', 'label', 'example_id']
Val label dist: Counter({'True': 72, 'Unknown': 69, 'False': 62})

--- sample prompt ---

Task: Determine whether the conclusion is entailed, contradicted, or unknown given the premises.
Premises:
- All people who regularly drink coffee are dependent on caffeine.
People regularly drink coffee, or they don't want to be addicted to caffeine, or both.
No one who doesn't want to be addicted to caffeine is unaware that caffeine is a drug.
Rina is either a student who is unaware that caffeine is a drug, or she is not

In [37]:
# CELL 5 — LoRA target selection (T5 vs BART) + metrics helpers

def pick_lora_targets(model_name: str):
    # T5 uses different module naming than BART
    name = model_name.lower()
    if "t5" in name:
        # T5 attention: q, k, v, o (works for flan-t5)
        return ["q", "k", "v", "o"]
    else:
        # BART attention projections
        return ["q_proj", "k_proj", "v_proj", "out_proj"]

def normalize_pred_letter(s: str) -> str:
    if s is None:
        return ""
    s = s.strip()
    if not s:
        return ""
    # take first non-space character
    c = s[0].upper()
    return c if c in {"A","B","C"} else ""

def compute_accuracy(pred_texts, gold_texts):
    preds = [normalize_pred_letter(t) for t in pred_texts]
    golds = [normalize_pred_letter(t) for t in gold_texts]
    invalid = sum(p == "" for p in preds)
    acc = sum(p == g for p, g in zip(preds, golds)) / max(1, len(golds))
    return acc, invalid / max(1, len(golds)), preds


In [38]:
# CELL 6 — Tokenization for Seq2Seq (fixes eval_loss NaN) + custom callback for clean logging
from dataclasses import dataclass

@dataclass
class RunningLog:
    last_train_loss: float = float("nan")

running = RunningLog()

class TableLoggerCallback(TrainerCallback):
    """Keeps the last seen training loss, and prints a compact table at eval steps."""
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs and "loss" in logs:
            running.last_train_loss = float(logs["loss"])

def make_tokenize_fn(tokenizer):
    def tokenize_batch(batch):
        model_inputs = tokenizer(
            batch["user_text"],
            max_length=MAX_SOURCE_LEN,
            truncation=True,
        )
        # tokenize targets explicitly
        with tokenizer.as_target_tokenizer():
            lab = tokenizer(
                batch["label_letter"],
                max_length=MAX_TARGET_LEN,
                truncation=True,
            )
        labels = lab["input_ids"]
        # mask pad tokens to -100 so loss ignores them
        pad_id = tokenizer.pad_token_id
        labels = [[(t if t != pad_id else -100) for t in seq] for seq in labels]
        model_inputs["labels"] = labels
        return model_inputs
    return tokenize_batch

def sanity_check_labels(tokenized_ds, tokenizer, split="validation", n=50):
    # ensure every example has at least one non -100 label token
    bad = 0
    for i in range(min(n, len(tokenized_ds[split]))):
        labs = tokenized_ds[split][i]["labels"]
        if sum(t != -100 for t in labs) == 0:
            bad += 1
    print(f"Label sanity check ({split}, first {min(n, len(tokenized_ds[split]))}): bad={bad}")


In [41]:
# CELL 7 — Train/eval loop over encoder–decoder models (LoRA + EarlyStopping) + results table after each model
RESULTS = []

for model_name in MODEL_LIST:
    print("\n" + "="*100)
    print("MODEL:", model_name)

    out_dir = os.path.join(OUT_ROOT, model_name.replace("/", "__"))
    os.makedirs(out_dir, exist_ok=True)

    # tokenizer / model
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    base = AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        torch_dtype=(torch.bfloat16 if torch.cuda.is_available() else None),
        device_map="auto" if torch.cuda.is_available() else None,
    )

    # LoRA
    targets = pick_lora_targets(model_name)
    lora_cfg = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type=TaskType.SEQ_2_SEQ_LM,
        target_modules=targets,
    )
    model = get_peft_model(base, lora_cfg)

    # report trainable params
    model.print_trainable_parameters()

    # tokenize dataset for this tokenizer/model
    tok_fn = make_tokenize_fn(tokenizer)
    tokenized = ds2.map(tok_fn, batched=True, remove_columns=ds2["train"].column_names)
    sanity_check_labels(tokenized, tokenizer)

    collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

    # accuracy from generation
    def compute_metrics(eval_pred):
        pred_ids, label_ids = eval_pred
        # decode predictions
        pred_texts = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        # labels: replace -100 with pad to decode
        label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
        gold_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

        acc, invalid_rate, _ = compute_accuracy(pred_texts, gold_texts)
        return {"accuracy": acc, "invalid_rate": invalid_rate}

    from transformers import IntervalStrategy # Added import

    # training args (bf16 preferred on A100; fall back otherwise)
    use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8
    args = Seq2SeqTrainingArguments(
        output_dir=out_dir,
        per_device_train_batch_size=BATCH,
        per_device_eval_batch_size=BATCH,
        learning_rate=2e-4,
        num_train_epochs=50,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        weight_decay=0.0,
        eval_strategy=IntervalStrategy.STEPS, # Modified line
        eval_steps=200,
        save_steps=200,
        logging_steps=50,
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="eval_accuracy",
        greater_is_better=True,
        predict_with_generate=True,
        generation_max_length=MAX_TARGET_LEN, # Changed from generation_max_new_tokens
        generation_num_beams=1,
        report_to="none",
        fp16=False,
        bf16=bool(use_bf16),
        dataloader_num_workers=2,
        seed=SEED,
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["validation"],
        tokenizer=tokenizer,
        data_collator=collator,
        compute_metrics=compute_metrics,
        callbacks=[
            TableLoggerCallback(),
            EarlyStoppingCallback(early_stopping_patience=3),
        ],
    )

    # train
    train_toggle = trainer.train()
    metrics = trainer.evaluate()

    # grab best metrics
    eval_acc = float(metrics.get("eval_accuracy", float("nan")))
    eval_loss = float(metrics.get("eval_loss", float("nan")))
    invalid = float(metrics.get("eval_invalid_rate", float("nan")))

    # save LoRA adapter + tokenizer
    trainer.model.save_pretrained(os.path.join(out_dir, "lora_adapter"))
    tokenizer.save_pretrained(os.path.join(out_dir, "tokenizer"))

    RESULTS.append({
        "model": model_name,
        "trainable_params": int(sum(p.numel() for p in model.parameters() if p.requires_grad)),
        "eval_accuracy": eval_acc,
        "eval_loss": eval_loss,
        "invalid_rate": invalid,
        "best_checkpoint": getattr(trainer.state, "best_model_checkpoint", None),
    })

    # print results table after each model
    df = pd.DataFrame(RESULTS).sort_values("eval_accuracy", ascending=False)
    print("\n--- RESULTS SO FAR ---")
    display(df)

    # cleanup
    del trainer, model, base, tokenizer, tokenized
    cleanup()

# final save
final_df = pd.DataFrame(RESULTS).sort_values("eval_accuracy", ascending=False)
final_path = os.path.join(OUT_ROOT, "results_seq2seq.csv")
final_df.to_csv(final_path, index=False)
print("\nSaved:", final_path)
display(final_df)


MODEL: google/flan-t5-base


  trainer = Seq2SeqTrainer(
The model is already on multiple devices. Skipping the move to device specified in `args`.


trainable params: 1,769,472 || all params: 249,347,328 || trainable%: 0.7096
Label sanity check (validation, first 50): bad=0


Step,Training Loss,Validation Loss,Accuracy,Invalid Rate
200,0.587,0.4974,0.44335,0.0
400,0.5251,0.449879,0.571429,0.0
600,0.4898,0.41458,0.625616,0.0
800,0.4126,0.414164,0.610837,0.0
1000,0.4344,0.38151,0.635468,0.0
1200,0.3904,0.401097,0.640394,0.0
1400,0.3599,0.407496,0.660099,0.0
1600,0.3186,0.442893,0.650246,0.0
1800,0.3017,0.466663,0.679803,0.0
2000,0.3176,0.458115,0.660099,0.0



--- RESULTS SO FAR ---


Unnamed: 0,model,trainable_params,eval_accuracy,eval_loss,invalid_rate,best_checkpoint
0,google/flan-t5-base,1769472,0.679803,0.466663,0.0,/content/drive/MyDrive/logic/folio_seq2seq_lor...



MODEL: google/flan-t5-large


tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.13G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

trainable params: 4,718,592 || all params: 787,868,672 || trainable%: 0.5989


Map:   0%|          | 0/1001 [00:00<?, ? examples/s]



Map:   0%|          | 0/203 [00:00<?, ? examples/s]

Label sanity check (validation, first 50): bad=0


  trainer = Seq2SeqTrainer(
The model is already on multiple devices. Skipping the move to device specified in `args`.


Step,Training Loss,Validation Loss,Accuracy,Invalid Rate
200,0.5103,0.476252,0.527094,0.0
400,0.4626,0.419434,0.605911,0.0
600,0.3775,0.45342,0.660099,0.0
800,0.3389,0.476086,0.679803,0.0
1000,0.325,0.399784,0.674877,0.0
1200,0.3173,0.460753,0.669951,0.0
1400,0.2534,0.573555,0.64532,0.0



--- RESULTS SO FAR ---


Unnamed: 0,model,trainable_params,eval_accuracy,eval_loss,invalid_rate,best_checkpoint
0,google/flan-t5-base,1769472,0.679803,0.466663,0.0,/content/drive/MyDrive/logic/folio_seq2seq_lor...
1,google/flan-t5-large,4718592,0.679803,0.476086,0.0,/content/drive/MyDrive/logic/folio_seq2seq_lor...



MODEL: facebook/bart-base


config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/558M [00:00<?, ?B/s]

trainable params: 884,736 || all params: 140,305,152 || trainable%: 0.6306


Map:   0%|          | 0/1001 [00:00<?, ? examples/s]



Map:   0%|          | 0/203 [00:00<?, ? examples/s]

  trainer = Seq2SeqTrainer(
The model is already on multiple devices. Skipping the move to device specified in `args`.


Label sanity check (validation, first 50): bad=0


Step,Training Loss,Validation Loss,Accuracy,Invalid Rate
200,0.4498,0.388472,0.339901,0.0
400,0.4082,0.370848,0.359606,0.0
600,0.4088,0.403385,0.35468,0.0
800,0.3745,0.38954,0.35468,0.0
1000,0.4087,0.372298,0.369458,0.0
1200,0.3893,0.380424,0.359606,0.0
1400,0.3799,0.391101,0.339901,0.0
1600,0.3519,0.37808,0.418719,0.0
1800,0.3298,0.394743,0.384236,0.0
2000,0.3183,0.34397,0.477833,0.0


The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



--- RESULTS SO FAR ---


Unnamed: 0,model,trainable_params,eval_accuracy,eval_loss,invalid_rate,best_checkpoint
0,google/flan-t5-base,1769472,0.679803,0.466663,0.0,/content/drive/MyDrive/logic/folio_seq2seq_lor...
1,google/flan-t5-large,4718592,0.679803,0.476086,0.0,/content/drive/MyDrive/logic/folio_seq2seq_lor...
2,facebook/bart-base,884736,0.630542,0.382773,0.0,/content/drive/MyDrive/logic/folio_seq2seq_lor...



MODEL: facebook/bart-large


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.02G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.02G [00:00<?, ?B/s]

trainable params: 2,359,296 || all params: 408,650,752 || trainable%: 0.5773


Map:   0%|          | 0/1001 [00:00<?, ? examples/s]



Map:   0%|          | 0/203 [00:00<?, ? examples/s]

  trainer = Seq2SeqTrainer(
The model is already on multiple devices. Skipping the move to device specified in `args`.


Label sanity check (validation, first 50): bad=0


Step,Training Loss,Validation Loss,Accuracy,Invalid Rate
200,0.5588,1.617371,0.0,1.0
400,0.4812,0.861463,0.0,0.995074
600,0.4631,0.937184,0.0,1.0
800,0.4481,0.970275,0.285714,0.197044
1000,0.4182,0.85067,0.0,1.0
1200,0.415,0.856794,0.0,1.0
1400,0.3937,0.908115,0.251232,0.433498



--- RESULTS SO FAR ---


Unnamed: 0,model,trainable_params,eval_accuracy,eval_loss,invalid_rate,best_checkpoint
0,google/flan-t5-base,1769472,0.679803,0.466663,0.0,/content/drive/MyDrive/logic/folio_seq2seq_lor...
1,google/flan-t5-large,4718592,0.679803,0.476086,0.0,/content/drive/MyDrive/logic/folio_seq2seq_lor...
2,facebook/bart-base,884736,0.630542,0.382773,0.0,/content/drive/MyDrive/logic/folio_seq2seq_lor...
3,facebook/bart-large,2359296,0.285714,0.970275,0.197044,/content/drive/MyDrive/logic/folio_seq2seq_lor...



Saved: /content/drive/MyDrive/logic/folio_seq2seq_lora/results_seq2seq.csv


Unnamed: 0,model,trainable_params,eval_accuracy,eval_loss,invalid_rate,best_checkpoint
0,google/flan-t5-base,1769472,0.679803,0.466663,0.0,/content/drive/MyDrive/logic/folio_seq2seq_lor...
1,google/flan-t5-large,4718592,0.679803,0.476086,0.0,/content/drive/MyDrive/logic/folio_seq2seq_lor...
2,facebook/bart-base,884736,0.630542,0.382773,0.0,/content/drive/MyDrive/logic/folio_seq2seq_lor...
3,facebook/bart-large,2359296,0.285714,0.970275,0.197044,/content/drive/MyDrive/logic/folio_seq2seq_lor...
