In [1]:
from datasets import load_dataset

# Fetch the wiki-math articles corpus from the Hugging Face Hub. The returned
# DatasetDict exposes the splits used below ("train", "valid", "test", "rag"),
# each row carrying 'title', 'sub_title', 'text' and 'category'.
ds = load_dataset(path="noor-zalouk/wiki-math-articles")
ds  # bare expression so the notebook displays the split summary

DatasetDict({
    train: Dataset({
        features: ['title', 'sub_title', 'text', 'category'],
        num_rows: 80167
    })
    valid: Dataset({
        features: ['title', 'sub_title', 'text', 'category'],
        num_rows: 26723
    })
    test: Dataset({
        features: ['title', 'sub_title', 'text', 'category'],
        num_rows: 26723
    })
    rag: Dataset({
        features: ['title', 'sub_title', 'text', 'category'],
        num_rows: 159698
    })
})

In [2]:
from transformers import T5ForConditionalGeneration, T5TokenizerFast, DataCollatorForSeq2Seq

# Checkpoint shared by the model and its tokenizer.
model_name = "t5-small"

# The two loads are independent of each other; both come from the same checkpoint.
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = T5TokenizerFast.from_pretrained(model_name)

# Stock seq2seq collator bound to this tokenizer/model pair.
base_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [3]:
import random

class CustomDataCollator:
    """Turn raw dataset rows into randomized "EXPLAIN ..." prompts and collate them.

    Each incoming row must provide the keys "title", "sub_title" and "text".
    For every row one of three prompt shapes is drawn at random:

    * with probability ``p_irrelevant``: the prompt plus a CONTEXT taken from a
      *different* row of the same batch (a distractor);
    * with probability ``p_relevant``: the prompt plus the row's own text as CONTEXT;
    * otherwise: the bare prompt with no CONTEXT.

    The label is always the row's own text. Tokenized prompt/label pairs are
    handed to a ``DataCollatorForSeq2Seq`` built from the same tokenizer and
    model, and whatever that collator returns is returned unchanged.
    """

    def __init__(self, tokenizer, model, p_irrelevant=0.05, p_relevant=0.20, max_source_length=512, max_target_length=512):
        """Store the sampling probabilities and truncation limits.

        Args:
            tokenizer: Callable tokenizer; invoked as
                ``tokenizer(text, truncation=True, max_length=...)`` and expected
                to return an object with an ``input_ids`` attribute.
            model: Model forwarded to the underlying ``DataCollatorForSeq2Seq``.
            p_irrelevant: Probability of attaching a distractor context.
            p_relevant: Probability of attaching the row's own text as context.
            max_source_length: Truncation limit (in tokens) for the prompt.
            max_target_length: Truncation limit (in tokens) for the label.
        """
        self.tokenizer = tokenizer
        self.model = model
        self.base_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
        self.p_irrelevant = p_irrelevant
        self.p_relevant = p_relevant
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

    def __call__(self, features):
        """Build prompts for ``features`` and return the collated batch.

        Args:
            features: List of row dicts with "title", "sub_title" and "text".

        Returns:
            The result of the base collator applied to a list of
            ``{"input_ids": ..., "labels": ...}`` dicts, one per input row.
        """
        new_features = []
        texts_pool = [f["text"] for f in features]  # candidate distractor texts

        for f in features:
            title = f["title"]
            sub_title = f["sub_title"] or ""  # missing/None sub-title becomes empty
            context_relevant = f["text"]  # the real one
            prompt = f"EXPLAIN {sub_title} {title}"

            r = random.random()
            if r < self.p_irrelevant:
                # Sample the distractor only when it is actually needed, and from
                # a pre-filtered list: rejection sampling would spin forever when
                # every text in the batch equals this row's text (e.g. a final
                # batch holding a single row).
                candidates = [t for t in texts_pool if t != context_relevant]
                if candidates:
                    input_text = f"{prompt} CONTEXT {random.choice(candidates)}"
                else:
                    # No different text available in this batch: fall back to
                    # the context-free prompt instead of hanging.
                    input_text = prompt
            elif r < self.p_irrelevant + self.p_relevant:
                input_text = f"{prompt} CONTEXT {context_relevant}"
            else:
                input_text = prompt

            label_text = context_relevant

            new_features.append({
                "input_ids": self.tokenizer(input_text, truncation=True, max_length=self.max_source_length).input_ids,
                "labels": self.tokenizer(label_text, truncation=True, max_length=self.max_target_length).input_ids
            })

        return self.base_collator(new_features)

In [4]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Collator that builds the "EXPLAIN ..." prompts on the fly; labels are
# truncated to 375 tokens, prompts to 512.
custom_collator = CustomDataCollator(tokenizer, model, p_irrelevant=0.05, p_relevant=0.20, max_source_length=512, max_target_length=375)

training_args = Seq2SeqTrainingArguments(
    output_dir="t5_explain_runs/exp1",
    remove_unused_columns=False,            # the collator needs the raw 'title'/'sub_title'/'text' columns
    per_device_train_batch_size=4,          # adjust to your VRAM
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=16,          # → effective batch 64
    num_train_epochs=2,
    eval_strategy="steps",
    eval_steps=500,
    logging_steps=100,
    predict_with_generate=True,
    label_smoothing_factor=0.1,
    warmup_ratio=0.06,
    # With optim="adafactor" the Trainer uses this value as a fixed external
    # learning rate, so 0.0 would leave the weights untouched for the whole
    # run. 1e-3 is the value the Transformers Adafactor docs suggest for T5
    # fine-tuning.
    learning_rate=1e-3,
    weight_decay=0.0,
    seed=42,
    optim="adafactor",
)

In [5]:

# Wire everything together. The datasets are passed as raw rows; tokenization
# and prompt construction happen inside `custom_collator` at batch time.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    processing_class=tokenizer,
    data_collator=custom_collator,
    train_dataset=ds['train'],
    eval_dataset=ds['valid'],
)

In [6]:
# Start fine-tuning with the trainer built above; left as a bare expression so
# the notebook displays the value returned by train() when the run completes.
trainer.train()

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Step,Training Loss,Validation Loss


KeyboardInterrupt: 