### **Imports and configuration**

In [1]:
import re
from dataclasses import dataclass
from typing import List, Dict, Any, Tuple

import torch
from datasets import load_dataset, Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
)

MODEL_NAME = "gpt2"
DATASET_NAME = "eth-nlped/mathdial"

MAX_LENGTH = 512
EPOCHS = 3
LR = 5e-5

PER_DEVICE_TRAIN_BS = 4
PER_DEVICE_EVAL_BS = 4
GRAD_ACCUM = 4
SEED = 42

RUN_TUTOR_ONLY = True
RUN_MASKED_LOSS = True

OUTPUT_DIR_TUTOR_ONLY = "./gpt2-mathdial-tutor_only"
OUTPUT_DIR_MASKED_LOSS = "./gpt2-mathdial-masked_loss"

INCLUDE_QUESTION = True

USE_FP16 = torch.cuda.is_available()
USE_BF16 = False  # set True if you know your GPU supports bf16 well
USE_GRADIENT_CHECKPOINTING = True


  from .autonotebook import tqdm as notebook_tqdm


### **Load dataset and make test-train split**

In [2]:
raw = load_dataset(DATASET_NAME)

if "validation" in raw:
    train_raw = raw["train"]
    val_raw = raw["validation"]
else:
    split = raw["train"].train_test_split(test_size=0.1, seed=SEED)
    train_raw, val_raw = split["train"], split["test"]

print(train_raw, val_raw)
print(train_raw[0].keys())

Dataset({
    features: ['qid', 'scenario', 'question', 'ground_truth', 'student_incorrect_solution', 'student_profile', 'teacher_described_confusion', 'self-correctness', 'self-typical-confusion', 'self-typical-interactions', 'conversation'],
    num_rows: 2035
}) Dataset({
    features: ['qid', 'scenario', 'question', 'ground_truth', 'student_incorrect_solution', 'student_profile', 'teacher_described_confusion', 'self-correctness', 'self-typical-confusion', 'self-typical-interactions', 'conversation'],
    num_rows: 227
})
dict_keys(['qid', 'scenario', 'question', 'ground_truth', 'student_incorrect_solution', 'student_profile', 'teacher_described_confusion', 'self-correctness', 'self-typical-confusion', 'self-typical-interactions', 'conversation'])


### **Tokenizer + special tokens**

In [3]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

REDACT_TOKEN = "<FINAL_ANSWER_REDACTED>"
SPECIAL_TOKENS = ["<STUDENT>", "<TUTOR>", "<PROBLEM>", "</PROBLEM>", REDACT_TOKEN]

tokenizer.add_special_tokens({"additional_special_tokens": SPECIAL_TOKENS})

# GPT-2 has no pad token by default
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

len(tokenizer), tokenizer.pad_token, tokenizer.eos_token

(50262, '<|endoftext|>', '<|endoftext|>')

### **Parsing and final-answer detection**

In [4]:
EOM = "|EOM|"
TEACHER_PREFIX = "Teacher:"
STUDENT_PREFIX = "Student:"

# B-mode: only strong cues
FINAL_CUE_PATTERNS = [
    r"\bfinal answer\b",
    r"\bthe answer is\b",
    r"^\s*answer\s*:\s*",
    r"^\s*solution\s*:\s*",
    r"^\s*final\s*:\s*",
]

def strip_dialog_act(text: str) -> str:
    # e.g. "(act) text" -> "text"
    return re.sub(r"^\s*\([^)]*\)\s*", "", text).strip()

def is_final_answer_like_B(text: str) -> bool:
    t = strip_dialog_act(text).lower()
    return any(re.search(p, t, flags=re.IGNORECASE | re.MULTILINE) for p in FINAL_CUE_PATTERNS)

def parse_conversation(conv: str) -> List[Tuple[str, str]]:
    parts = [p.strip() for p in conv.split(EOM)]
    turns: List[Tuple[str, str]] = []
    for p in parts:
        if not p:
            continue
        if p.startswith(TEACHER_PREFIX):
            turns.append(("teacher", p[len(TEACHER_PREFIX):].strip()))
        elif p.startswith(STUDENT_PREFIX):
            turns.append(("student", p[len(STUDENT_PREFIX):].strip()))
        else:
            # fallback
            turns.append(("student", p.strip()))
    return turns

In [5]:
def build_samples_tutor_only(example: Dict[str, Any], include_question: bool) -> List[Dict[str, str]]:
    q = (example.get("question") or "").strip()
    turns = parse_conversation(example["conversation"])

    samples = []
    history = ""
    if include_question and q:
        history = f"<PROBLEM>\n{q}\n</PROBLEM>\n"

    for role, text in turns:
        if role == "student":
            history += f"<STUDENT> {text}\n"
            continue

        # teacher turn
        if is_final_answer_like_B(text):
            # redact from context and do not use as training target
            history += f"<TUTOR> {REDACT_TOKEN}\n"
            continue

        prompt = history + "<TUTOR> "
        target = f"{strip_dialog_act(text)}\n"
        samples.append({"text": prompt + target})
        history += f"<TUTOR> {strip_dialog_act(text)}\n"

    return samples

def explode_tutor_only(ds_in) -> Dataset:
    rows = {"text": []}
    for ex in ds_in:
        for s in build_samples_tutor_only(ex, include_question=INCLUDE_QUESTION):
            rows["text"].append(s["text"])
    return Dataset.from_dict(rows)

train_text = explode_tutor_only(train_raw)
val_text = explode_tutor_only(val_raw)

def tok_tutor_only(batch):
    out = tokenizer(batch["text"], truncation=True, max_length=MAX_LENGTH, padding=False)
    out["labels"] = out["input_ids"].copy()
    return out

train_tutor_only = train_text.map(tok_tutor_only, batched=True, remove_columns=train_text.column_names)
val_tutor_only = val_text.map(tok_tutor_only, batched=True, remove_columns=val_text.column_names)

train_tutor_only, val_tutor_only

Map: 100%|██████████| 13312/13312 [00:02<00:00, 5378.47 examples/s]
Map: 100%|██████████| 1472/1472 [00:00<00:00, 5218.65 examples/s]


(Dataset({
     features: ['input_ids', 'attention_mask', 'labels'],
     num_rows: 13312
 }),
 Dataset({
     features: ['input_ids', 'attention_mask', 'labels'],
     num_rows: 1472
 }))

In [6]:
def build_features_masked_loss(example: Dict[str, Any]) -> Dict[str, Any]:
    q = (example.get("question") or "").strip()
    turns = parse_conversation(example["conversation"])

    segments: List[Tuple[str, str, bool]] = []  # (role, text, supervise?)
    if INCLUDE_QUESTION and q:
        segments.append(("meta", f"<PROBLEM>\n{q}\n</PROBLEM>\n", False))

    for role, text in turns:
        if role == "student":
            segments.append(("student", f"<STUDENT> {text}\n", False))
        else:
            cleaned = strip_dialog_act(text)
            if is_final_answer_like_B(cleaned):
                segments.append(("teacher", f"<TUTOR> {REDACT_TOKEN}\n", False))
            else:
                segments.append(("teacher", f"<TUTOR> {cleaned}\n", True))

    input_ids: List[int] = []
    labels: List[int] = []

    for _, seg_text, supervise in segments:
        seg_ids = tokenizer.encode(seg_text, add_special_tokens=False)
        if len(input_ids) + len(seg_ids) > MAX_LENGTH:
            break
        input_ids.extend(seg_ids)
        labels.extend(seg_ids if supervise else [-100] * len(seg_ids))

    attention_mask = [1] * len(input_ids)
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

train_masked = train_raw.map(build_features_masked_loss, remove_columns=train_raw.column_names)
val_masked = val_raw.map(build_features_masked_loss, remove_columns=val_raw.column_names)

train_masked, val_masked


Map: 100%|██████████| 227/227 [00:00<00:00, 763.65 examples/s]


(Dataset({
     features: ['input_ids', 'attention_mask', 'labels'],
     num_rows: 2035
 }),
 Dataset({
     features: ['input_ids', 'attention_mask', 'labels'],
     num_rows: 227
 }))

In [7]:
@dataclass
class CollatorForCausalLM:
    tokenizer: AutoTokenizer

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        max_len = max(len(x["input_ids"]) for x in batch)

        def pad(seq, pad_value):
            return seq + [pad_value] * (max_len - len(seq))

        input_ids = [pad(x["input_ids"], self.tokenizer.pad_token_id) for x in batch]
        attention_mask = [pad(x["attention_mask"], 0) for x in batch]
        labels = [pad(x["labels"], -100) for x in batch]

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }

collator = CollatorForCausalLM(tokenizer=tokenizer)

In [8]:
def fresh_model():
    m = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    m.resize_token_embeddings(len(tokenizer))
    if USE_GRADIENT_CHECKPOINTING:
        m.gradient_checkpointing_enable()
    return m

def train_one(model, train_ds, val_ds, out_dir: str):
    args = TrainingArguments(
        output_dir=out_dir,
        learning_rate=LR,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=PER_DEVICE_TRAIN_BS,
        per_device_eval_batch_size=PER_DEVICE_EVAL_BS,
        gradient_accumulation_steps=GRAD_ACCUM,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_steps=50,
        save_total_limit=2,
        weight_decay=0.01,
        warmup_ratio=0.06,
        lr_scheduler_type="linear",
        fp16=USE_FP16,
        bf16=USE_BF16,
        report_to="none",
        seed=SEED,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        tokenizer=tokenizer,
        data_collator=collator,
    )

    trainer.train()
    trainer.save_model(out_dir)
    tokenizer.save_pretrained(out_dir)

In [9]:
if RUN_TUTOR_ONLY:
    model_tutor = fresh_model()
    train_one(model_tutor, train_tutor_only, val_tutor_only, OUTPUT_DIR_TUTOR_ONLY)

if RUN_MASKED_LOSS:
    model_masked = fresh_model()
    train_one(model_masked, train_masked, val_masked, OUTPUT_DIR_MASKED_LOSS)

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
  trainer = Trainer(
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 50256}.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


Epoch,Training Loss,Validation Loss
1,1.4067,1.548562
2,1.122,1.435504
3,1.0385,1.411279


There were missing keys in the checkpoint model loaded: ['lm_head.weight'].
  trainer = Trainer(
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 50256}.


Epoch,Training Loss,Validation Loss
1,2.8474,2.433882
2,2.274,2.225005
3,2.1567,2.197028


There were missing keys in the checkpoint model loaded: ['lm_head.weight'].
