In [None]:
import torch, platform, sys, subprocess

# Report interpreter, OS and PyTorch build so runs can be compared across machines.
for label, value in (
    ("Python:", sys.version),
    ("OS:", platform.platform()),
    ("Torch:", torch.__version__),
    ("CUDA?", torch.cuda.is_available()),
):
    print(label, value)

# Device details can only be queried when a CUDA device is visible.
if torch.cuda.is_available():
    for label, query in (
        ("GPU:", torch.cuda.get_device_name),
        ("CUDA capability:", torch.cuda.get_device_capability),
    ):
        print(label, query(0))

# nvidia-smi is optional: any failure (missing binary, non-zero exit) is
# reported and the cell carries on.
try:
    out = subprocess.run(
        ["nvidia-smi"], stdout=subprocess.PIPE, text=True, check=True
    ).stdout
    print("\n=== nvidia-smi ===\n", out)
except Exception as err:
    print("nvidia-smi not available:", err)

In [None]:
import json, pathlib, random, uuid, re, collections

# Filesystem layout: raw exam JSON files are read from data/raw and the
# training records are written to data/train.json.
DATA_DIR = pathlib.Path("data")
RAW = DATA_DIR.joinpath("raw")
OUT = DATA_DIR.joinpath("train.json")
RAW.mkdir(parents=True, exist_ok=True)  # parents=True also creates DATA_DIR

# Regex fragments for the HTML-comment markers that delimit the topic list.
BEGIN = r'<!--\s*TOPIC-LIST:BEGIN\s*-->'
END = r'<!--\s*TOPIC-LIST:END\s*-->'

def controls_to_prompt(ctrl):
    """Serialize a controls mapping into the one-line prompt string.

    The result has the fixed layout
    ``length=<L>; format=<f1+f2>; topics=<t1,t2>; difficulty=<D>``.

    Args:
        ctrl: Mapping that may contain ``length``, ``format``, ``topics`` and
            ``difficulty``. ``format`` and ``topics`` may each be a single
            string or an iterable of items; missing or ``None`` values are
            rendered as empty fields.

    Returns:
        The ``"; "``-joined prompt string.
    """
    def _as_items(value):
        # Falsy (missing / None / empty) -> no items; a bare string is one
        # item, not an iterable of characters. Items are stringified so that
        # numeric codes do not break str.join.
        value = value or []
        if isinstance(value, str):
            return [value]
        return [str(item) for item in value]

    length = ctrl.get("length")
    length = str(length) if length is not None else ""

    # An explicit None must render as an empty field, not the text "None".
    difficulty = ctrl.get("difficulty")
    difficulty = "" if difficulty is None else difficulty

    parts = [
        f"length={length}",
        f"format={'+'.join(_as_items(ctrl.get('format')))}",
        f"topics={','.join(_as_items(ctrl.get('topics')))}",
        f"difficulty={difficulty}",
    ]
    return "; ".join(parts)

def load_api_like_examples():
    """Load every ``*.json`` file directly inside ``RAW``.

    Files are read in sorted path order so the returned list (and anything
    derived from it with a seeded RNG) does not depend on the order in which
    the filesystem enumerates directory entries.

    Returns:
        A list with one parsed JSON value per file.

    Raises:
        json.JSONDecodeError: If a file does not contain valid JSON.
    """
    exams = []
    for path in sorted(RAW.glob("*.json")):
        with path.open("r", encoding="utf-8") as f:
            exams.append(json.load(f))
    return exams

def normalize_exam(data, *, indent=2) -> str:
    """Return ``data`` as JSON text with every ``None`` replaced by ``""``.

    Dicts and lists are walked recursively; any other value is emitted
    unchanged. Non-ASCII characters are written as-is (``ensure_ascii=False``).
    """
    def _scrub(node):
        if isinstance(node, dict):
            return {key: _scrub(val) for key, val in node.items()}
        if isinstance(node, list):
            return list(map(_scrub, node))
        return "" if node is None else node

    return json.dumps(_scrub(data), ensure_ascii=False, indent=indent)

def load_topics_from_readme(path="README.md", region_required=True, return_codes=False, dedupe=True, sort=False):
    """Parse the categorized topic list out of a README file.

    Recognized line shapes:
      * category header: ``### **<number>. <Category name>**``
      * topic bullet:    ``* <Topic>`` optionally followed by ``- <codes>``
        where ``<codes>`` is up to three comma-separated slots, each either a
        4-digit number or empty.

    Bullets that appear before the first category header are ignored.
    ``*emphasis*`` / ``**bold**`` markers inside a topic are stripped and runs
    of whitespace are collapsed.

    Args:
        path: README location.
        region_required: When True, only the text between the BEGIN/END
            markers is parsed and a missing marker pair yields no topics.
            When False, the whole file is parsed if the markers are absent.
        return_codes: When True, each entry is ``(topic, codes)`` where
            ``codes`` is a list of strings (``None`` for an empty slot);
            otherwise each entry is the topic string.
        dedupe: Drop later entries whose topic equals an earlier one,
            ignoring case. Applies to the flat list only.
        sort: Sort the flat list by case-folded topic.

    Returns:
        ``(flat, by_cat)``: the flat list of entries and a dict mapping each
        category name to its entries in file order (never deduped or sorted).

    Raises:
        OSError: If ``path`` cannot be read.
    """
    text = pathlib.Path(path).read_text(encoding="utf-8")

    m = re.search(f'{BEGIN}(.*?){END}', text, flags=re.S | re.I)
    region = m.group(1) if m else (text if not region_required else "")

    cat_re = re.compile(r'^\s*###\s+\*\*\s*\d+\.\s*(?P<cat>[^*]+?)\s*\*\*\s*$', re.M)
    topic_re = re.compile(r'^\s*\*\s+(?P<topic>.+?)(?:\s*-\s*(?P<codes>(?:\d{4}|)(?:,(?:\d{4}|)){0,2}))?\s*$', re.M)

    by_cat = collections.defaultdict(list)
    current = None

    for line in region.splitlines():
        line = line.rstrip()

        mcat = cat_re.match(line)
        if mcat:
            current = mcat.group('cat').strip()
            continue

        mtop = topic_re.match(line)
        if mtop and current:
            t = mtop.group('topic').strip()
            # Strip bold first so the single-star pass does not see its markers.
            t = re.sub(r'\*\*(.*?)\*\*', r'\1', t)
            t = re.sub(r'\*(.*?)\*', r'\1', t)
            t = re.sub(r'\s+', ' ', t).strip()

            if return_codes:
                codes_raw = mtop.group('codes')
                codes = [c or None for c in (codes_raw.split(',') if codes_raw is not None else [])]
                by_cat[current].append((t, codes))
            else:
                by_cat[current].append(t)

    flat = [entry for entries in by_cat.values() for entry in entries]

    def _title(entry):
        # Entries are (topic, codes) tuples in return_codes mode, plain strings otherwise.
        return entry[0] if return_codes else entry

    if dedupe:
        # Case-insensitive and order-preserving in BOTH output modes, so the
        # set of surviving topics does not depend on return_codes.
        seen, uniq = set(), []
        for entry in flat:
            key = _title(entry).lower()
            if key in seen:
                continue
            seen.add(key)
            uniq.append(entry)
        flat = uniq
    if sort:
        flat.sort(key=lambda entry: _title(entry).casefold())

    return flat, dict(by_cat)

def debug_readme(path="README.md"):
    """Print, one line per check, whether the README contains each structure
    the topic parser relies on: BEGIN marker, END marker, a numbered bold
    ``###`` category header, and a ``*`` bullet."""
    text = pathlib.Path(path).read_text(encoding="utf-8")
    checks = (
        ("Has BEGIN marker:", BEGIN, re.I),
        ("Has END marker:  ", END, re.I),
        ("Has category hdr:", r'^\s*###\s+\*\*\s*\d+\.', re.M),
        ("Has bullets:     ", r'^\s*\*\s+', re.M),
    )
    for label, pattern, flags in checks:
        print(label, re.search(pattern, text, flags) is not None)

# Topic vocabulary used to sample the `topics` control for each record.
topics, by_cat = load_topics_from_readme("README.md")
print(len(topics), "topics loaded")
print(topics[:10])

random.seed(7)
exams = load_api_like_examples()

# random.sample(topics, k=1) below needs at least one topic; fail up front
# with an actionable message instead of an opaque sampling error mid-loop.
if exams and not topics:
    raise ValueError(
        "No topics parsed from README.md; run debug_readme() to check the "
        "TOPIC-LIST markers, category headers and bullets."
    )

difficulties = ["easy", "medium", "hard"]
format_sets = [
    ["multiple_choice"],
    ["open_answer"],
    ["multiple_choice", "open_answer"],
]

records = []
for ex in exams:
    # Controls are sampled at random; they are not derived from the exam itself.
    ctrl = {
        "topics": random.sample(topics, k=1),  # ensure list
        "difficulty": random.choice(difficulties),
        "length": random.choice([20, 15, 25]),
        "format": random.choice(format_sets),   # ensure list
    }
    prompt = controls_to_prompt(ctrl)
    # Serialize once and reuse for both fields.
    # NOTE(review): `input` embeds exactly the text used as `output`, so the
    # target is visible in the source — confirm this is intended.
    target = normalize_exam(ex)
    input_text = "Exam format:\n" + target

    records.append({
        "id": str(uuid.uuid4()),
        "prompt": prompt,
        "input": input_text,
        "output": target
    })

# One JSON object per line (JSON Lines), despite the .json extension.
OUT.write_text("\n".join(json.dumps(r, ensure_ascii=False) for r in records), encoding="utf-8")
len(records), str(OUT)

In [None]:
from datasets import load_dataset
from transformers import T5ForConditionalGeneration, T5TokenizerFast, DataCollatorForSeq2Seq, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model
import torch

# Checkpoint name passed to from_pretrained, the data file passed to
# load_dataset, and the directory used for Trainer output and the saved
# adapter/tokenizer.
MODEL_NAME = "t5-small"
DATA_PATH = "data/train.json"
OUT_DIR = "out-t5-lora"

def format_example(ex):
    """Build the seq2seq text pair for one record.

    ``src`` is the controls line, the exemplar block and the task marker,
    separated by blank lines; ``tgt`` is the record's ``output`` unchanged.
    """
    sections = (
        "controls: {}".format(ex["prompt"]),
        "exemplars:\n{}".format(ex["input"]),
        "# task: generate new exam as JSON",
    )
    return {"src": "\n\n".join(sections), "tgt": ex["output"]}

# Tokenizer and base model come from the same checkpoint.
tok = T5TokenizerFast.from_pretrained(MODEL_NAME)
base = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

# LoRA adapter configuration; target_modules are matched against module
# names (presumably the T5 attention projections — the assert below catches
# a mismatch).
lora = LoraConfig(
    task_type="SEQ_2_SEQ_LM",
    target_modules=["q", "k", "v", "o"],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)
model = get_peft_model(base, lora)

# Cache is switched off and input grads are enabled before turning on
# gradient checkpointing.
model.config.use_cache = False
model.enable_input_require_grads()
model.gradient_checkpointing_enable()

model.print_trainable_parameters()
assert any(
    param.requires_grad
    for name, param in model.named_parameters()
    if "lora" in name
), "No LoRA parameters marked trainable. Check `target_modules` names."

# Load the records and attach the `src` / `tgt` text columns.
ds = load_dataset("json", data_files=DATA_PATH, split="train").map(format_example)

# Truncation limits (in tokens) for source and target.
max_src_len = max_tgt_len = 512

def tok_map(batch):
    """Tokenize a batch's ``src`` and ``tgt`` columns.

    Sources are truncated to ``max_src_len`` and targets to ``max_tgt_len``;
    the target token ids are stored under ``labels`` in the returned encoding.
    """
    encoded = tok(batch["src"], truncation=True, max_length=max_src_len)
    target_ids = tok(
        text_target=batch["tgt"], truncation=True, max_length=max_tgt_len
    )["input_ids"]
    encoded["labels"] = target_ids
    return encoded

# Replace the text columns with token ids + labels.
ds = ds.map(tok_map, batched=True, remove_columns=ds.column_names)

collator = DataCollatorForSeq2Seq(tokenizer=tok, model=model, pad_to_multiple_of=8)

# Mixed precision: bf16 when the CUDA device reports capability major >= 8,
# fp16 on any other CUDA device, full precision without CUDA.
# NOTE(review): T5 checkpoints are reported to be numerically fragile under
# fp16 — confirm the loss stays finite on GPUs that take the fp16 path.
if torch.cuda.is_available():
    bf16_ok = torch.cuda.get_device_capability()[0] >= 8
    fp16_ok = not bf16_ok
else:
    bf16_ok = fp16_ok = False

args = TrainingArguments(
    output_dir=OUT_DIR,
    # schedule
    num_train_epochs=1,
    learning_rate=2e-4,
    warmup_ratio=0.03,
    optim="adamw_torch",
    # batching: 2 per device x 8 accumulation steps
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    group_by_length=True,
    remove_unused_columns=False,
    # precision
    bf16=bf16_ok,
    fp16=fp16_ok,
    # logging / checkpoints
    logging_steps=10,
    report_to="none",
    save_strategy="epoch",
    save_total_limit=2,
)

trainer = Trainer(
    model=model,
    args=args,
    data_collator=collator,
    train_dataset=ds,
)

trainer.train()

# Persist the model (as wrapped by get_peft_model) and the tokenizer side by side.
model.save_pretrained(OUT_DIR)
tok.save_pretrained(OUT_DIR)
print("Dumped", OUT_DIR)

In [None]:
from transformers import T5ForConditionalGeneration
from peft import PeftModel, PeftConfig
import os

CKPT = "out-t5-lora"

# Diagnostic cell: report problems instead of raising, so a missing folder
# (e.g. training has not been run yet) gives a readable message.
if os.path.isdir(CKPT):
    print("Folder contains:", sorted(os.listdir(CKPT)))
else:
    print("Checkpoint folder not found:", CKPT)

try:
    peft_cfg = PeftConfig.from_pretrained(CKPT)
    print("PEFT adapter base_model_name_or_path:", peft_cfg.base_model_name_or_path)
except Exception as e:
    print("Not a PEFT adapter folder or unreadable:", e)

In [None]:
import json, re, random, torch
from transformers import T5ForConditionalGeneration, T5TokenizerFast, StoppingCriteria, StoppingCriteriaList
from peft import PeftModel

# Base checkpoint, adapter directory and target device.
BASE = "t5-small"
CKPT = "out-t5-lora"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Generation limits (in tokens) and the seed shared by torch and random.
MAX_SRC_LEN = 512
MAX_NEW_TOKENS = 700
SEED = 7
torch.manual_seed(SEED)
random.seed(SEED)

# Tokenizer is read from the adapter directory; the adapter is applied on
# top of a freshly loaded base model, switched to eval mode and moved to DEVICE.
tok = T5TokenizerFast.from_pretrained(CKPT)
base = T5ForConditionalGeneration.from_pretrained(BASE)
model = PeftModel.from_pretrained(base, CKPT)
model = model.eval().to(DEVICE)

# Optional: merge the LoRA adapter into the base model and save a standalone checkpoint (uncomment to run).
#from transformers import T5ForConditionalGeneration, T5TokenizerFast
#from peft import PeftModel
#BASE, CKPT, MERGED = "t5-small", "out-t5-lora", "out-t5-lora-merged"
#tok = T5TokenizerFast.from_pretrained(CKPT)
#base = T5ForConditionalGeneration.from_pretrained(BASE)
#merged = PeftModel.from_pretrained(base, CKPT).merge_and_unload()
#merged.save_pretrained(MERGED); tok.save_pretrained(MERGED)
#print("Saved merged model to", MERGED)


def make_controls(topics, difficulty, length, fmt):
    """Build the controls string for inference.

    The layout is ``length=<L>; format=<f1+f2>; topics=<t1,t2>; difficulty=<D>``,
    the same key order and separators that ``controls_to_prompt`` uses for
    the training prompts, so inference inputs look like the training inputs.

    Args:
        topics: A topic string or an iterable of topics.
        difficulty: Difficulty label; ``None`` renders as an empty field.
        length: Requested length; ``None`` renders as an empty field.
        fmt: A format string or an iterable of formats.

    Returns:
        The ``"; "``-joined controls string.
    """
    def _as_items(value):
        # A bare string is one item; joining it directly would split it
        # into single characters.
        if value is None:
            return []
        if isinstance(value, str):
            return [value]
        return [str(item) for item in value]

    length = "" if length is None else length
    difficulty = "" if difficulty is None else difficulty
    return "; ".join([
        f"length={length}",
        f"format={'+'.join(_as_items(fmt))}",
        f"topics={','.join(_as_items(topics))}",
        f"difficulty={difficulty}",
    ])

def make_src(ctrl):
    """Return the model input: a ``controls:`` line followed by a one-line
    instruction describing the expected JSON keys."""
    schema_hint = "".join((
        "Return ONLY a JSON object with keys: ",
        "metadata(topics[],difficulty,length,format[]), ",
        "questions[{id:int,text:str,type:multiple_choice|open_answer,options:[str]|null,answer:str,subquestions:null}].",
    ))
    return "controls: {}\n{}".format(ctrl, schema_hint)

def token_len(s):
    """Return the number of token ids the module-level tokenizer produces
    for ``s`` (no truncation is requested)."""
    encoding = tok(s)
    return len(encoding.input_ids)

def _balanced_braces(text: str) -> bool:
    depth, in_str, esc = 0, False, False
    for ch in text:
        if in_str:
            if esc: esc = False
            elif ch == '\\': esc = True
            elif ch == '"': in_str = False
        else:
            if ch == '"': in_str = True
            elif ch == '{': depth += 1
            elif ch == '}':
                depth -= 1
                if depth < 0: return False
    return depth == 0 and "{" in text

class BalancedJSONStop(StoppingCriteria):
    """Stopping criterion that fires once the decoded text of the first
    sequence starts with ``{`` and its braces are balanced."""

    def __init__(self, tokenizer):
        self.tok = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        # Only row 0 of input_ids is decoded and checked.
        # NOTE(review): with beam search there is presumably one row per
        # beam — confirm that checking only the first row is intended.
        decoded = self.tok.decode(input_ids[0], skip_special_tokens=True)
        if not decoded.strip().startswith("{"):
            return False
        return _balanced_braces(decoded)

def _extract_or_repair(s: str):
    s = s.strip()
    if s.startswith("{") and s.endswith("}"):
        try: return json.loads(s)
        except json.JSONDecodeError: pass
    m = re.search(r"\{[\s\S]*\}", s)
    if m:
        cand = m.group(0)
        try: return json.loads(cand)
        except json.JSONDecodeError:
            cand2 = re.sub(r",(\s*[}\]])", r"\1", cand)  # trailing commas
            cand2 = cand2.replace("“", '"').replace("”", '"').replace("’", "'")
            cand2 = re.sub(r"\bNone\b", "null", cand2)
            cand2 = re.sub(r"\bTrue\b", "true", cand2)
            cand2 = re.sub(r"\bFalse\b", "false", cand2)
            try: return json.loads(cand2)
            except Exception: return None
    return None

def canonicalize(obj):
    """Normalize a parsed exam dict in place and return it.

    * ``metadata.length`` is converted to ``str`` when present.
    * For every question object: ``options`` is set to ``None`` when
      ``type`` is ``"open_answer"``, a non-null ``answer`` is converted to
      ``str``, and ``subquestions`` defaults to ``None`` when missing.

    ``obj`` is parsed model output, so its shape is not guaranteed: a
    ``metadata`` value that is not a dict, a ``questions`` value that is not
    a list/tuple, and question entries that are not dicts are left untouched
    rather than raising.
    """
    md = obj.get("metadata")
    if isinstance(md, dict) and "length" in md:
        md["length"] = str(md["length"])

    questions = obj.get("questions")
    if isinstance(questions, (list, tuple)):
        for q in questions:
            if not isinstance(q, dict):
                continue
            if q.get("type") == "open_answer":
                q["options"] = None
            if q.get("answer") is not None:
                q["answer"] = str(q["answer"])
            q.setdefault("subquestions", None)
    return obj

def generate_once(ctrl, constrained=False):
    """Run one beam-search generation for ``ctrl`` and try to parse the result.

    Args:
        ctrl: Controls string, inserted into the source text by ``make_src``.
        constrained: When True, the decoder is primed with ``{`` and
            generation is stopped by ``BalancedJSONStop``.

    Returns:
        ``(obj, raw)``: ``obj`` is the canonicalized dict, or ``None`` when
        nothing could be parsed; ``raw`` is the decoded, stripped model text.

    Side effects:
        Prints token counts, the parse status and up to 800 characters of
        the raw output.
    """
    src = make_src(ctrl)
    enc = tok(src, return_tensors="pt", max_length=MAX_SRC_LEN, truncation=True).to(DEVICE)
    print(f"Tokens: src={token_len(src)} (truncated to {enc['input_ids'].shape[-1]}), MAX_SRC_LEN={MAX_SRC_LEN}")
    gen_kwargs = dict(
        max_new_tokens=MAX_NEW_TOKENS, do_sample=False, num_beams=5,
        length_penalty=0.9, early_stopping=True, no_repeat_ngram_size=3,
    )
    if constrained:
        # Encode the prefix WITHOUT special tokens: by default the tokenizer
        # appends its end-of-sequence token, which would put an EOS right
        # after "{" in the decoder prompt.
        # NOTE(review): verify the tokenizer has a real token for "{"; if it
        # maps to an unknown token, the decoded text may never start with "{"
        # and the stopping criterion would never fire — confirm.
        dec_start = tok("{", return_tensors="pt", add_special_tokens=False).input_ids.to(DEVICE)
        gen_kwargs["decoder_input_ids"] = dec_start
        gen_kwargs["stopping_criteria"] = StoppingCriteriaList([BalancedJSONStop(tok)])
    with torch.no_grad():
        out_ids = model.generate(**enc, **gen_kwargs)
    raw = tok.decode(out_ids[0], skip_special_tokens=True).strip()
    obj = _extract_or_repair(raw)
    status = "OK" if isinstance(obj, dict) else "PARSE_FAIL"
    print(f"[{'CONSTRAINED' if constrained else 'RAW'}] status={status}, chars={len(raw)}")
    # Keep notebook output short: show at most the first 800 characters.
    if len(raw) > 800:
        print(raw[:800] + "\n...[truncated]...")
    else:
        print(raw)
    if isinstance(obj, dict):
        obj = canonicalize(obj)
    return obj, raw

# Demo controls: two topics, hard difficulty, length 8, both question formats.
ctrl = make_controls(
    ["algebra", "linear-equations"],
    "hard",
    8,
    ["multiple_choice", "open_answer"],
)

print("=== RAW GENERATION ===")
obj_raw, raw_text = generate_once(ctrl, constrained=False)

print("\n=== CONSTRAINED GENERATION ===")
obj_con, con_text = generate_once(ctrl, constrained=True)

# Prefer the constrained result; fall back to the raw one when it did not parse.
final = obj_raw if not isinstance(obj_con, dict) else obj_con

if not isinstance(final, dict):
    print("\nNo valid JSON parsed. Inspect raw above.")
else:
    print("\nParsed JSON (canonicalized):")
    print(json.dumps(final, indent=2, ensure_ascii=False))