# Gated GRPO Training Contract2Graph Pipeline— VM‑Optimized (A100‑80GB)


This notebook runs **GRPO** on top of our **SFT** adapter for *Llama‑3.1‑8B‑Instruct* on a A100‑80GB VM.  
It logs to **Weights & Biases**, supports **resume / autosave**, **JSON-brace stopping** for efficient generation, and a gated GRPO approach.

**Folders used**
- SFT adapter: `~/ai/checkpoints/run14_2025-08-25_22-37/final/`
- GRPO checkpoints/logs/outputs: `~/ai/checkpoints/run14_2025-08-25_22-37_grpo/`, `~/ai/logs/run14_2025-08-25_22-37_grpo/`, `~/ai/outputs/run14_2025-08-25_22-37_grpo/`


In [None]:

# === [Cell] 0 · System Setup & Dependencies ===
# This cell should be run FIRST, and the kernel restarted immediately after.

# We only need to install flash-attn
# The --no-build-isolation flag helps it find your existing torch installation
!pip install "flash-attn==2.5.8" --no-build-isolation
!pip install sentence-transformers

print("\n\n✅ Flash Attention installed. PLEASE RESTART THE KERNEL NOW.")
print("--> In the menu, go to Kernel > Restart Kernel...")



✅ Flash Attention installed. PLEASE RESTART THE KERNEL NOW.
--> In the menu, go to Kernel > Restart Kernel...


In [None]:
import torch
free, total = torch.cuda.mem_get_info()
print(f"GPU free/total before load: {free/1e9:.1f}/{total/1e9:.1f} GB")

GPU free/total before load: 69.2/85.0 GB


In [None]:
# === [Cell] 1 · VM init (paths, caches, W&B) ===
from datetime import datetime
from pathlib import Path
import os, wandb

HF_DATASET_ID="moriyad/clause_minigraph_builder_grpo"

os.environ["TRANSFORMERS_VERBOSITY"] = "error"  # before importing transformers logs
from transformers.utils import logging as hf_logging
hf_logging.set_verbosity_error()

import warnings
warnings.filterwarnings("ignore", message="Caching is incompatible with gradient checkpointing*", module="transformers.models.llama.modeling_llama")
warnings.filterwarnings("ignore", message="`use_cache=True` is incompatible with gradient checkpointing*")

BASE = Path.home() / "ai"
for d in ["checkpoints","logs","outputs","wandb","hf_cache","hf_datasets_cache","hf_transformers_cache"]:
    (BASE/d).mkdir(parents=True, exist_ok=True)

SFT_RUN_DIR = Path("~/ai").expanduser()
#MODEL_DIR   = (SFT_RUN_DIR / "adapter-20250827-142825") original SFT run
CKPT_DIR    = BASE / "checkpoints" / (SFT_RUN_DIR.name + "_grpo")
LOG_DIR     = BASE / "logs" / (SFT_RUN_DIR.name + "_grpo")
OUT_DIR     = BASE / "outputs" / (SFT_RUN_DIR.name + "_grpo")
for p in (CKPT_DIR, LOG_DIR, OUT_DIR): p.mkdir(parents=True, exist_ok=True)

os.environ["WANDB_NOTEBOOK_NAME"] = "GRPO_training_minigraph_builder_vm_FINAL.ipynb"
os.environ.setdefault("HF_HOME", str(BASE / "hf_cache"))
os.environ.setdefault("HF_DATASETS_CACHE", str(BASE / "hf_datasets_cache"))
os.environ.setdefault("TRANSFORMERS_CACHE", str(BASE / "hf_transformers_cache"))
os.environ.setdefault("WANDB_DIR", str(BASE / "wandb"))
os.environ.setdefault("WANDB_CACHE_DIR", str(BASE / "wandb/cache"))
os.environ.setdefault("WANDB_PROJECT", "llama31_grpo_minigraph_customReward_customEval")

try: wandb.finish()
except: pass
wandb.login()
run_id = f"{SFT_RUN_DIR.name}_grpo_{datetime.now().strftime('%Y-%m-%d_%H-%M')}"
wandb.init(project=os.environ["WANDB_PROJECT"], name=run_id, reinit=True, config={"sft_run": str(SFT_RUN_DIR)})
print("CKPT:", CKPT_DIR); print("LOG:", LOG_DIR); print("OUT:", OUT_DIR)


  from .autonotebook import tqdm as notebook_tqdm
[34m[1mwandb[0m: Currently logged in as: [33mmoriya-dechtiar[0m ([33mm-dechtiar[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


CKPT: /home/ubuntu/ai/checkpoints/ai_grpo
LOG: /home/ubuntu/ai/logs/ai_grpo
OUT: /home/ubuntu/ai/outputs/ai_grpo


In [None]:
#[Cell] 2 - Load base model (4-bit) with FLASH ATTN + apply SFT adapter (GPU-pinned, safe fallbacks) ===
import os, torch, wandb
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

BASE_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
#MODEL_DIR   = (SFT_RUN_DIR / "adapter-20250827-142825") #- Original SFT qlora model
MODEL_DIR   = (SFT_RUN_DIR / "notebooks/checkpoints/grpo_final_embed_stopper_gated_final") #from latest checkpoint
#4-bit NF4 on A100
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

#Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token or tokenizer.pad_token

def _print_gpu():
    try:
        free, total = torch.cuda.mem_get_info()
        print(f"GPU free/total: {free/1e9:.1f} / {total/1e9:.1f} GB")
    except Exception:
        pass

print("Loading base model (4-bit) with FlashAttention-2…")
_print_gpu()

common_kwargs = dict(
    quantization_config=bnb,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    attn_implementation="flash_attention_2",   # <<<< ENABLE FA2
)

try:
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        device_map={"": "cuda:0"},
        **common_kwargs,
    )
except ValueError as e:
    print("[Retry] memory-constrained auto mapping… Reason:", e)
    try:
        max_mem = {0: "78GiB", "cpu": "0GiB"}
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_ID,
            device_map="auto",
            max_memory=max_mem,
            **common_kwargs,
        )
    except ValueError as e2:
        print("[Retry] CPU offload last resort… Reason:", e2)
        offload_dir = os.path.join(os.getcwd(), "offload")
        os.makedirs(offload_dir, exist_ok=True)
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_ID,
            device_map="auto",
            offload_folder=offload_dir,
            **common_kwargs,
        )

#special tokens and resize embeddings
try:
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    if eot_id is None or eot_id == tokenizer.unk_token_id:
        tokenizer.add_special_tokens({"additional_special_tokens": ["<|eot_id|>"]})
        model.resize_token_embeddings(len(tokenizer))
        eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
except Exception:
    eot_id = None

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Applying SFT adapter from", MODEL_DIR)
model = PeftModel.from_pretrained(model, MODEL_DIR, is_trainable=True)

#perf & training
#use_cache=False for training
model.config.use_cache = False

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

#enabled for training only
try:
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
except TypeError:
    model.gradient_checkpointing_enable()

_print_gpu()
print("Model ready (FA2, 4-bit NF4).")


In [None]:
#[Cell] 2b Alternative for Eval-only setup (not for training)
import os, torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

BASE_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
MODEL_DIR     = (SFT_RUN_DIR / "notebooks/checkpoints/grpo_final_embed_stopper_gated_final")

bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token or tokenizer.pad_token

common_kwargs = dict(
    quantization_config=bnb,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    attn_implementation="flash_attention_2",
)

model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    device_map={"": "cuda:0"},
    **common_kwargs,
)

try:
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    if eot_id is None or eot_id == tokenizer.unk_token_id:
        tokenizer.add_special_tokens({"additional_special_tokens": ["<|eot_id|>"]})
        model.resize_token_embeddings(len(tokenizer))
except Exception:
    pass

model = PeftModel.from_pretrained(model, MODEL_DIR, is_trainable=False)

model.config.use_cache = True
model.eval()
try:
    model.gradient_checkpointing_disable()
except Exception:
    pass

for p in model.parameters():
    p.requires_grad_(False)

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

print("Eval-only model ready (FA2, 4-bit NF4, adapters frozen).")


Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████| 4/4 [00:21<00:00,  5.26s/it]


Eval-only model ready (FA2, 4-bit NF4, adapters frozen).


In [None]:
#[Cell] 2b - evaluation helpers and execution
import json, gc, re, numpy as np
from collections import defaultdict
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

#helpers
def _safe_eos_ids(tok):
    ids = []
    for t in ["<|eot_id|>", "<|eos|>", tok.eos_token]:
        if not t: continue
        try:
            i = tok.convert_tokens_to_ids(t) if isinstance(t, str) else t
            if i is not None and i != tok.unk_token_id: ids.append(i)
        except: pass
    if tok.eos_token_id is not None: ids.append(tok.eos_token_id)
    seen=set(); out=[]
    for i in ids:
        if i not in seen: seen.add(i); out.append(i)
    return out or [tok.eos_token_id]

class JsonStopper(StoppingCriteria):
    """Stop when top-level braces look balanced (quick & dirty)."""
    def __init__(self, tokenizer, input_len):
        self.tok = tokenizer; self.input_len = input_len
    def __call__(self, input_ids, scores, **kwargs):
        s = self.tok.decode(input_ids[0, self.input_len:], skip_special_tokens=True)
        return s.count('{')>0 and s.count('{')==s.count('}')

def _deep_safe_json(x, max_depth=3):
    obj = x
    for _ in range(max_depth):
        if isinstance(obj, (dict, list)): return obj
        if obj is None: return {}
        s = str(obj).strip()
        b, e = s.find("{"), s.rfind("}")
        cand = s[b:e+1] if (b != -1 and e != -1 and e > b) else s
        try:
            obj = json.loads(cand); continue
        except Exception:
            break
    return {}

def _extract_nodes(obj):
    arr = obj.get("nodes", []) if isinstance(obj, dict) else []
    if isinstance(arr, dict): arr = [arr]
    out=[]
    for it in arr:
        if isinstance(it, dict): out.append(it)
        elif isinstance(it, str):
            try:
                d = json.loads(it)
                if isinstance(d, dict): out.append(d)
            except: pass
    return out

def _extract_edges(obj):
    arr = obj.get("edges", []) if isinstance(obj, dict) else []
    if isinstance(arr, dict): arr = [arr]
    out=[]
    for it in arr:
        if isinstance(it, dict): out.append(it)
        elif isinstance(it, str):
            try:
                d = json.loads(it)
                if isinstance(d, dict): out.append(d)
            except: pass
    return out

COMPANY_SUFFIX_RE = re.compile(r"\b(inc\.?|ltd\.?|llc|l\.l\.c\.|corp\.?|co\.?|ag|gmbh)\b", re.I)
WS_RE = re.compile(r"\s+")
def _norm(s):
    if not isinstance(s, str): s = str(s) if s is not None else ""
    s = s.lower().replace("&","and")
    s = COMPANY_SUFFIX_RE.sub("", s)
    return WS_RE.sub(" ", s).strip()

def _toks(s): return re.findall(r"[a-z0-9]+", s.lower())
def _jacc(a,b):
    sa,sb=set(_toks(a)),set(_toks(b))
    if not sa or not sb: return 1.0 if a.strip()==b.strip() and a.strip()!="" else 0.0
    return len(sa&sb)/max(1,len(sa|sb))

PCT_RE=re.compile(r"(\d+(?:\.\d+)?)\s*%"); MONEY_RE=re.compile(r"(\$|usd)\s*([\d,]+(?:\.\d+)?)",re.I)
DAYS_RE=re.compile(r"(\d+)\s*days?"); YEARS_RE=re.compile(r"(\d+)\s*years?")
NUM_WORDS={"zero":0,"one":1,"two":2,"three":3,"four":4,"five":5,"six":6,"seven":7,"eight":8,"nine":9,
           "ten":10,"eleven":11,"twelve":12,"thirteen":13,"fourteen":14,"fifteen":15,"sixteen":16,
           "seventeen":17,"eighteen":18,"nineteen":19,"twenty":20,"thirty":30,"forty":40,"fifty":50,
           "sixty":60,"seventy":70,"eighty":80,"ninety":90}
def _w2n(s):
    s=s.lower()
    for w,n in NUM_WORDS.items():
        if re.search(rf"\b{w}\b", s): return n
    return None

def _canon_value(text):
    s=_norm(text)
    m=PCT_RE.search(s)
    if m: return f"{float(m.group(1)):.0f}%"
    if "percent" in s:
        n=_w2n(s)
        if n is not None: return f"{n}%"
    m=MONEY_RE.search(s)
    if m: return f"usd {m.group(2).replace(',','')}"
    m=DAYS_RE.search(s)
    if m: return f"{int(m.group(1))} days"
    m=YEARS_RE.search(s)
    if m: return f"{int(m.group(1))} years"
    return s

def _keytext(n):
    t=(n.get("type") or "").upper(); a=n.get("attrs",{}) or {}; nid=n.get("id","") or ""
    if t=="CLAUSE":
        k=a.get("id") or nid or a.get("title") or ""; return t,_norm(k)
    if t=="DEFINED_TERM":
        k=a.get("term") or nid.split(":",1)[-1]; return t,_norm(k)
    if t=="PARTY":
        k=a.get("name") or a.get("text") or nid.split(":",1)[-1]; return t,_norm(k)
    if t=="VALUE":
        k=a.get("text") or nid.split(":",1)[-1]; return t,_canon_value(k)
    k=a.get("text") or a.get("term") or a.get("name") or nid; return t,_norm(k)

THRESH={"CLAUSE":0.90,"DEFINED_TERM":0.80,"PARTY":0.85,"VALUE":0.75}
def _sim(t,a,b):
    if t=="CLAUSE": return 1.0 if a==b else _jacc(a,b)
    if t in ("DEFINED_TERM","PARTY"): return _jacc(a,b)
    if t=="VALUE":
        if a==b and (a.endswith("%") or a.endswith("days") or a.endswith("years") or a.startswith("usd")): return 1.0
        return _jacc(a,b)
    return _jacc(a,b)

def _bucket(nodes):
    b=defaultdict(list)
    for n in nodes:
        t,kt=_keytext(n)
        if t and kt: b[t].append(kt)
    return b

def _match_type(G_list, P_list, t):
    if not G_list or not P_list: return 0
    pairs=[]
    for i,g in enumerate(G_list):
        for j,p in enumerate(P_list):
            s=_sim(t,g,p)
            if s>=THRESH.get(t,0.8): pairs.append((s,i,j))
    pairs.sort(reverse=True)
    used_i=set(); used_j=set(); tp=0
    for s,i,j in pairs:
        if i in used_i or j in used_j: continue
        used_i.add(i); used_j.add(j); tp+=1
    return tp

def _edge_triplet(e, node_map):
    def pick(d, keys):
        for k in keys:
            if k in d and d[k] is not None:
                return d[k]
        return None
    typ = (pick(e, ["type","edge_type","label"]) or "").upper()
    raw_src = pick(e, ["src","source","from"])
    raw_tgt = pick(e, ["tgt","target","to"])
    def resolve(v):
        if v is None: return ""
        v_str = str(v)
        if v_str in node_map:
            return node_map[v_str]
        return _norm(v_str)
    return (resolve(raw_src), typ, resolve(raw_tgt))

def _prf1(tp,fp,fn):
    p=tp/(tp+fp) if (tp+fp) else 0.0
    r=tp/(tp+fn) if (tp+fn) else 0.0
    f=2*p*r/(p+r) if (p+r) else 0.0
    return p,r,f

#eval
def evaluate_model(model, tokenizer, dataset, system_prompt, max_samples=200, batch_size=4, max_new_tokens=1024):
    """dataset must have columns: 'prompt' and 'completion' (or 'clean_completion')."""
    #examples
    n = len(dataset)
    use = min(max_samples, n)
    idxs = np.random.choice(n, use, replace=False) if use < n else np.arange(n)
    sub = dataset.select(idxs)
    prompts = sub["prompt"]
    gold_key = "clean_completion" if "clean_completion" in sub.column_names else "completion"
    golds = sub[gold_key]

    #generation
    tok = tokenizer
    old_pad, old_trunc = tok.padding_side, getattr(tok, "truncation_side", "right")
    tok.padding_side = tok.truncation_side = "left"
    if tok.pad_token is None: tok.pad_token = tok.eos_token
    EOS = _safe_eos_ids(tok); PAD = tok.pad_token_id or tok.eos_token_id

    preds=[]
    for i in range(0, len(prompts), batch_size):
        chats = [tok.apply_chat_template(
                    [{"role":"system","content": system_prompt},
                     {"role":"user","content": p}],
                    tokenize=False, add_generation_prompt=True)
                 for p in prompts[i:i+batch_size]]
        batch = tok(chats, return_tensors="pt", padding=True, truncation=True).to(model.device)
        stopper = StoppingCriteriaList([JsonStopper(tok, input_len=batch["input_ids"].shape[1])])
        with torch.no_grad():
            out = model.generate(
                **batch,
                max_new_tokens=max_new_tokens,
                do_sample=False, temperature=None, top_p=None, top_k=None,
                use_cache=True,
                eos_token_id=EOS, pad_token_id=PAD,
                stopping_criteria=stopper,
            )
        gen = tok.batch_decode(out[:, batch["input_ids"].shape[1]:], skip_special_tokens=True)
        preds.extend(gen)
        del out, batch; torch.cuda.empty_cache(); gc.collect()

    tok.padding_side, tok.truncation_side = old_pad, old_trunc

    #scoring
    strict_tp=strict_fp=strict_fn=0
    fuzzy_tp=fuzzy_fp=fuzzy_fn=0
    e_tp=e_fp=e_fn=0
    exact=invalid=0

    def setify_strict(nodes):
        S=set()
        for nn in nodes:
            t=(nn.get("type") or "").upper(); a=nn.get("attrs",{}) or {}; nid=nn.get("id","") or ""
            k=a.get("name") or a.get("term") or a.get("text") or nid
            k=_norm(k)
            if t and k: S.add((t,k))
        return S

    for gstr,pstr in zip(golds,preds):
        G_json=_deep_safe_json(gstr); P_json=_deep_safe_json(pstr)
        G_nodes=_extract_nodes(G_json)
        try:
            P_nodes=_extract_nodes(P_json)
        except Exception:
            P_nodes=[]; invalid+=1

        # strict nodes
        Gs, Ps = setify_strict(G_nodes), setify_strict(P_nodes)
        if Gs==Ps: exact+=1
        strict_tp+=len(Gs&Ps); strict_fp+=len(Ps-Gs); strict_fn+=len(Gs-Ps)

        # fuzzy nodes
        Gb,Pb=_bucket(G_nodes),_bucket(P_nodes)
        for t in (set(Gb)|set(Pb)):
            tp=_match_type(Gb.get(t,[]), Pb.get(t,[]), t)
            fp=len(Pb.get(t,[]))-tp; fn=len(Gb.get(t,[]))-tp
            fuzzy_tp+=tp; fuzzy_fp+=fp; fuzzy_fn+=fn

        # edges
        def build_map(nodes):
            m={}
            for n in nodes:
                nid=n.get("id") or ""
                tt=_keytext(n)
                if nid: m[str(nid)] = f"{tt[0]}|{tt[1]}"
            return m
        Gmap,Pmap=build_map(G_nodes),build_map(P_nodes)
        Ge=set(_edge_triplet(e,Gmap) for e in _extract_edges(G_json))
        Pe=set(_edge_triplet(e,Pmap) for e in _extract_edges(P_json))
        e_tp+=len(Ge&Pe); e_fp+=len(Pe-Ge); e_fn+=len(Ge-Pe)

    sp,sr,sf1 = _prf1(strict_tp,strict_fp,strict_fn)
    fp_,fr_,ff1 = _prf1(fuzzy_tp,fuzzy_fp,fuzzy_fn)
    ep,er,ef1   = _prf1(e_tp,e_fp,e_fn)

    return {
        "gen_strict_micro_precision": sp,
        "gen_strict_micro_recall":    sr,
        "gen_strict_micro_f1":        sf1,
        "gen_fuzzy_micro_precision":  fp_,
        "gen_fuzzy_micro_recall":     fr_,
        "gen_fuzzy_micro_f1":         ff1,
        "gen_edges_micro_precision":  ep,
        "gen_edges_micro_recall":     er,
        "gen_edges_micro_f1":         ef1,
        "gen_exact_match":            exact/max(1,len(prompts)),
        "gen_invalid_json_rate":      invalid/max(1,len(prompts)),
        "gen_num_samples":            len(prompts),
    }



In [None]:
#[Cell] 3 Data Loading and Robust Preprocessing
import re, json
from functools import partial
from datasets import load_dataset
from transformers import AutoTokenizer

#Robust JSON parsing
_JSON_BLOCK = re.compile(
    r"(?s)```(?:json)?\s*(\{.*?\})\s*```"   # fenced ```json { ... } ```
    r"|(\{.*\})"                             # or the first {...} span
)

_SMART = {
    "\u2018": "'", "\u2019": "'",
    "\u201c": '"', "\u201d": '"',
    "\u00a0": " ",
}

def _normalize_quotes(s: str) -> str:
    for k, v in _SMART.items():
        s = s.replace(k, v)
    return s

def deep_safe_json(x, max_depth=3):
    """
    Tries to coerce x into a dict (or list) by repeatedly:
      1) extracting a fenced or first {...} JSON block,
      2) normalizing smart quotes / code fences,
      3) relaxing single-quoted keys/values -> double quotes.
    Returns {} on failure.
    """
    obj = x
    for _ in range(max_depth):
        if isinstance(obj, (dict, list)):
            return obj
        if obj is None:
            return {}
        s = _normalize_quotes(str(obj).strip())

        #prefer fenced block, else first {...}
        m = _JSON_BLOCK.search(s)
        cand = (m.group(1) or m.group(2)) if m else s
        cand = cand.strip()
        if cand.startswith("```"):
            cand = cand.strip("`").strip()
            if cand.lower().startswith("json"):
                cand = cand[4:].lstrip()

        #strict parse
        try:
            obj = json.loads(cand)
            continue
        except Exception:
            pass

        #relaxed
        try:
            cand2 = re.sub(r"([{\s,])'([^']+?)'\s*:", r'\1"\2":', cand)  # keys
            cand2 = re.sub(r":\s*'([^']*?)'", r': "\1"', cand2)         # string values
            obj = json.loads(cand2)
            continue
        except Exception:
            break
    return {}

#normalize
def sanitize_and_format_completion(example):
    """
    - Parse completion JSON robustly
    - Ensure 'nodes' and 'edges' are lists
    - Normalize a few schema quirks (e.g., 'node_type'->'type', fix 'clause_id: ')
    - Store canonical JSON into 'clean_completion'
    - Add 'clean_parse_ok' flag for filtering/stats
    """
    raw = example.get("completion", "")
    parsed = deep_safe_json(raw)

    ok = isinstance(parsed, dict)
    if not ok:
        # fallback: keep raw string so the model still sees signal
        example["clean_completion"] = raw if isinstance(raw, str) else json.dumps(raw, ensure_ascii=False)
        example["clean_parse_ok"] = False
        return example

    if "clause_id: " in parsed and "clause_id" not in parsed:
        parsed["clause_id"] = parsed.pop("clause_id: ")

    nodes = parsed.get("nodes") or []
    edges = parsed.get("edges") or []
    if isinstance(nodes, dict): nodes = [nodes]
    if isinstance(edges, dict): edges = [edges]

    #normalize nodes -> ensure 'type' and 'attrs'
    norm_nodes = []
    for n in nodes if isinstance(nodes, list) else []:
        if not isinstance(n, dict):
            try:
                n = json.loads(n)
            except Exception:
                continue
        n = dict(n)
        if "type" not in n and "node_type" in n:
            n["type"] = n.pop("node_type")
        n.setdefault("attrs", {})
        for k in ("title", "name", "term", "text"):
            if k in n and k not in n["attrs"]:
                n["attrs"][k] = n[k]
        norm_nodes.append(n)

    # normalize edges -> accept src/source/from & tgt/target/to, uppercase 'type'
    norm_edges = []
    for e in edges if isinstance(edges, list) else []:
        if not isinstance(e, dict):
            try:
                e = json.loads(e)
            except Exception:
                continue
        e = dict(e)
        e.setdefault("src", e.get("source", e.get("from")))
        e.setdefault("tgt", e.get("target", e.get("to")))
        if "type" in e and isinstance(e["type"], str):
            e["type"] = e["type"].upper()
        norm_edges.append(e)

    parsed["nodes"] = norm_nodes
    parsed["edges"] = norm_edges
    example["clean_completion"] = json.dumps(parsed, ensure_ascii=False)
    example["clean_parse_ok"] = True
    return example

#load and clean the dataset
print("Loading and cleaning the dataset...")
HF_DATASET_ID = "moriyad/clause_minigraph_builder_clean"
HF_MODEL_ID   = "meta-llama/Meta-Llama-3.1-8B-Instruct"
dataset = load_dataset(HF_DATASET_ID)
cleaned_dataset = dataset.map(sanitize_and_format_completion, num_proc=2)

ok_train = sum(int(x) for x in cleaned_dataset["train"]["clean_parse_ok"])
print(f"\nParse success (train): {ok_train}/{len(cleaned_dataset['train'])} = {ok_train/len(cleaned_dataset['train']):.1%}")
print("\nDataset after cleaning:")
print(cleaned_dataset)

def first_non_empty(ds_split):
    for rec in ds_split:
        try:
            obj = json.loads(rec["clean_completion"])
            if (obj.get("nodes") or obj.get("edges")):
                return obj
        except Exception:
            pass
    try:
        return json.loads(ds_split[0]["clean_completion"])
    except Exception:
        return {"nodes": [], "edges": []}

print("\nExample of a cleaned completion:")
pretty = first_non_empty(cleaned_dataset["train"])
print(json.dumps(pretty, ensure_ascii=False)[:1200])

#chat template formatting
SYS_PROMPT = "You are a legal minigraph extractor. Return ONLY valid JSON with 'nodes' and 'edges'."

def create_chat_format(example, tokenizer):
    return {
        "text": tokenizer.apply_chat_template(
            [
                {"role": "system", "content": SYS_PROMPT},
                {"role": "user", "content": example["prompt"]},
                {"role": "assistant", "content": example["clean_completion"]},
            ],
            tokenize=False,
            add_generation_prompt=False,
        )
    }

#tokenizer + pad token
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

train_dataset = cleaned_dataset["train"].map(
    partial(create_chat_format, tokenizer=tokenizer),
    num_proc=2,
    remove_columns=cleaned_dataset["train"].column_names,
)
eval_dataset = cleaned_dataset["validation"].map(
    partial(create_chat_format, tokenizer=tokenizer),
    num_proc=2,
    remove_columns=cleaned_dataset["validation"].column_names,
)
print("\nTraining and validation datasets formatted for SFT.")
print(f"train: {len(train_dataset)} examples, eval: {len(eval_dataset)} examples")


Loading and cleaning the dataset...

Parse success (train): 2071/2071 = 100.0%

Dataset after cleaning:
DatasetDict({
    train: Dataset({
        features: ['id', 'contract_id', 'clause_id', 'prompt', 'completion', 'clean_completion', 'clean_parse_ok'],
        num_rows: 2071
    })
    validation: Dataset({
        features: ['id', 'contract_id', 'clause_id', 'prompt', 'completion', 'clean_completion', 'clean_parse_ok'],
        num_rows: 262
    })
    test: Dataset({
        features: ['id', 'contract_id', 'clause_id', 'prompt', 'completion', 'clean_completion', 'clean_parse_ok'],
        num_rows: 266
    })
})

Example of a cleaned completion:
{"contract_id": "AIRSPANNETWORKSINC_04_11_2000-EX-10.5-Distributor Agreement", "nodes": [{"id": "10", "title": "10", "level": 1, "type": "CLAUSE", "attrs": {"title": "10"}}, {"id": "10.3", "title": "10.3", "level": 2, "type": "CLAUSE", "attrs": {"title": "10.3"}}, {"id": "party:Airspan", "name": "Airspan", "type": "PARTY", "attrs": {"name":

Map (num_proc=2): 100%|█████████████████████████████████████████████████████████████████| 2071/2071 [00:01<00:00, 1564.72 examples/s]
Map (num_proc=2): 100%|████████████████████████████████████████████████████████████████████| 262/262 [00:01<00:00, 218.90 examples/s]


Training and validation datasets formatted for SFT.
train: 2071 examples, eval: 262 examples





In [None]:
#robust evaluation
def evaluate_model_debug(
    model, tokenizer, dataset, system_prompt,
    max_samples=200, batch_size=4, max_new_tokens=1024,
    use_json_stopper=True, schema_strict=True,
    gold_field="completion",
    print_samples=3
):
    import json, gc, re, numpy as np, torch
    from collections import defaultdict
    from transformers import StoppingCriteria, StoppingCriteriaList

    #helpers
    def _safe_eos_ids(tok):
        ids = []
        for t in ["<|eot_id|>", "<|eos|>", tok.eos_token]:
            if not t: continue
            try:
                i = tok.convert_tokens_to_ids(t) if isinstance(t, str) else t
                if i is not None and i != tok.unk_token_id: ids.append(i)
            except: pass
        if tok.eos_token_id is not None: ids.append(tok.eos_token_id)
        seen=set(); out=[]
        for i in ids:
            if i not in seen: seen.add(i); out.append(i)
        return out or [tok.eos_token_id]

    class JsonStopper(StoppingCriteria):
        def __init__(self, tokenizer, input_len):
            self.tok = tokenizer; self.input_len = input_len
        def __call__(self, input_ids, scores, **kwargs):
            s = self.tok.decode(input_ids[0, self.input_len:], skip_special_tokens=True)
            opens = s.count("{"); closes = s.count("}")
            return ('"nodes"' in s or "'nodes'" in s) and opens>0 and opens==closes

    def deep_json(x, want_nodes_edges=True):
        """Parse JSON; if object lacks nodes/edges (when want_nodes_edges), return (None, reason)."""
        if isinstance(x, (dict, list)):
            obj = x
        else:
            s = (x if isinstance(x, str) else str(x or "")).strip()
            b, e = s.find("{"), s.rfind("}")
            cand = s[b:e+1] if (b!=-1 and e!=-1 and e>b) else s
            try:
                obj = json.loads(cand)
            except Exception as e:
                return None, f"json_error: {e}"
        if want_nodes_edges:
            if not isinstance(obj, dict):
                return None, "not_object"
            if "nodes" not in obj or "edges" not in obj:
                return None, "missing_nodes_edges"
            if not isinstance(obj.get("nodes"), list) or not isinstance(obj.get("edges"), list):
                return None, "bad_types"
        return obj, ""

    def extract_nodes(obj):
        return obj.get("nodes", []) if isinstance(obj, dict) else []

    def extract_edges(obj):
        return obj.get("edges", []) if isinstance(obj, dict) else []

    COMPANY_SUFFIX_RE = re.compile(r"\b(inc\.?|ltd\.?|llc|l\.l\.c\.|corp\.?|co\.?|ag|gmbh)\b", re.I)
    WS_RE = re.compile(r"\s+")
    def norm(s):
        if not isinstance(s, str): s = str(s) if s is not None else ""
        s = s.lower().replace("&","and")
        s = COMPANY_SUFFIX_RE.sub("", s)
        return WS_RE.sub(" ", s).strip()

    def toks(s): return re.findall(r"[a-z0-9]+", s.lower())
    def jacc(a,b):
        sa,sb=set(toks(a)),set(toks(b))
        if not sa or not sb: return 1.0 if a.strip()==b.strip() and a.strip()!="" else 0.0
        return len(sa&sb)/max(1,len(sa|sb))

    PCT_RE=re.compile(r"(\d+(?:\.\d+)?)\s*%"); MONEY_RE=re.compile(r"(\$|usd)\s*([\d,]+(?:\.\d+)?)",re.I)
    DAYS_RE=re.compile(r"(\d+)\s*days?"); YEARS_RE=re.compile(r"(\d+)\s*years?")
    NUM_WORDS={"zero":0,"one":1,"two":2,"three":3,"four":4,"five":5,"six":6,"seven":7,"eight":8,"nine":9,
               "ten":10,"eleven":11,"twelve":12,"thirteen":13,"fourteen":14,"fifteen":15,"sixteen":16,
               "seventeen":17,"eighteen":18,"nineteen":19,"twenty":20,"thirty":30,"forty":40,"fifty":50,
               "sixty":60,"seventy":70,"eighty":80,"ninety":90}
    def w2n(s):
        s=s.lower()
        for w,n in NUM_WORDS.items():
            if re.search(rf"\b{w}\b", s): return n
        return None
    def canon_value(text):
        s=norm(text)
        m=PCT_RE.search(s)
        if m: return f"{float(m.group(1)):.0f}%"
        if "percent" in s:
            n=w2n(s)
            if n is not None: return f"{n}%"
        m=MONEY_RE.search(s)
        if m: return f"usd {m.group(2).replace(',','')}"
        m=DAYS_RE.search(s)
        if m: return f"{int(m.group(1))} days"
        m=YEARS_RE.search(s)
        if m: return f"{int(m.group(1))} years"
        return s

    def keytext(n):
        t=(n.get("type") or "").upper(); a=n.get("attrs",{}) or {}; nid=n.get("id","") or ""
        if t=="CLAUSE":
            k=a.get("id") or nid or a.get("title") or ""; return t, norm(k)
        if t=="DEFINED_TERM":
            k=a.get("term") or nid.split(":",1)[-1]; return t, norm(k)
        if t=="PARTY":
            k=a.get("name") or a.get("text") or nid.split(":",1)[-1]; return t, norm(k)
        if t=="VALUE":
            k=a.get("text") or nid.split(":",1)[-1]; return t, canon_value(k)
        k=a.get("text") or a.get("term") or a.get("name") or nid; return t, norm(k)

    THRESH={"CLAUSE":0.90,"DEFINED_TERM":0.80,"PARTY":0.85,"VALUE":0.75}
    def sim(t,a,b):
        if t=="CLAUSE": return 1.0 if a==b else jacc(a,b)
        if t in ("DEFINED_TERM","PARTY"): return jacc(a,b)
        if t=="VALUE":
            if a==b and (a.endswith("%") or a.endswith("days") or a.endswith("years") or a.startswith("usd")): return 1.0
            return jacc(a,b)
        return jacc(a,b)

    def bucket(nodes):
        b=defaultdict(list)
        for n in nodes:
            t,kt=keytext(n)
            if t and kt: b[t].append(kt)
        return b

    def match_type(G_list, P_list, t):
        if not G_list or not P_list: return 0
        pairs=[]
        for i,g in enumerate(G_list):
            for j,p in enumerate(P_list):
                s=sim(t,g,p)
                if s>=THRESH.get(t,0.8): pairs.append((s,i,j))
        pairs.sort(reverse=True)
        used_i=set(); used_j=set(); tp=0
        for s,i,j in pairs:
            if i in used_i or j in used_j: continue
            used_i.add(i); used_j.add(j); tp+=1
        return tp

    def edge_triplet(e, node_map):
        def pick(d, keys):
            for k in keys:
                if k in d and d[k] is not None:
                    return d[k]
            return None
        typ = (pick(e, ["type","edge_type","label"]) or "").upper()
        raw_src = pick(e, ["src","source","from"])
        raw_tgt = pick(e, ["tgt","target","to"])
        def resolve(v):
            if v is None: return ""
            v_str = str(v)
            if v_str in node_map: return node_map[v_str]
            return norm(v_str)
        return (resolve(raw_src), typ, resolve(raw_tgt))

    def prf1(tp,fp,fn):
        p=tp/(tp+fp) if (tp+fp) else 0.0
        r=tp/(tp+fn) if (tp+fn) else 0.0
        f=2*p*r/(p+r) if (p+r) else 0.0
        return p,r,f

    n = len(dataset)
    use = min(max_samples, n)
    idxs = np.random.choice(n, use, replace=False) if use < n else np.arange(n)
    sub = dataset.select(idxs)
    prompts = sub["prompt"]
    gold_key = gold_field if gold_field in sub.column_names else ("clean_completion" if "clean_completion" in sub.column_names else "completion")
    golds = sub[gold_key]

    tok = tokenizer
    old_pad, old_trunc = tok.padding_side, getattr(tok, "truncation_side", "right")
    tok.padding_side = tok.truncation_side = "left"
    if tok.pad_token is None: tok.pad_token = tok.eos_token
    EOS = _safe_eos_ids(tok); PAD = tok.pad_token_id or tok.eos_token_id

    preds=[]
    for i in range(0, len(prompts), batch_size):
        chats = [tok.apply_chat_template(
                    [{"role":"system","content": system_prompt},
                     {"role":"user","content": p}],
                    tokenize=False, add_generation_prompt=True)
                 for p in prompts[i:i+batch_size]]
        batch = tok(chats, return_tensors="pt", padding=True, truncation=True).to(model.device)
        stopper = StoppingCriteriaList([JsonStopper(tok, input_len=batch["input_ids"].shape[1])]) if use_json_stopper else StoppingCriteriaList([])
        with torch.no_grad():
            out = model.generate(
                **batch,
                max_new_tokens=max_new_tokens,
                do_sample=False, temperature=None, top_p=None, top_k=None,
                use_cache=True,
                eos_token_id=EOS, pad_token_id=PAD,
                stopping_criteria=stopper,
            )
        gen = tok.batch_decode(out[:, batch["input_ids"].shape[1]:], skip_special_tokens=True)
        preds.extend(gen)
        del out, batch; torch.cuda.empty_cache(); gc.collect()

    tok.padding_side, tok.truncation_side = old_pad, old_trunc

    print("\n--- SAMPLE PROMPTS & PREDICTIONS ---")
    for i in range(min(print_samples, len(prompts))):
        print(f"\n[Prompt {i}] {prompts[i][:300]}...")
        print(f"[Pred  {i}] {preds[i][:300]}...")

    strict_tp=strict_fp=strict_fn=0
    fuzzy_tp=fuzzy_fp=fuzzy_fn=0
    e_tp=e_fp=e_fn=0
    exact=invalid=0

    gold_nodes_total=gold_edges_total=0
    pred_nodes_total=pred_edges_total=0

    def normalize_node_schema(n: dict) -> dict:
        """Unify node shape to: {'id': str, 'type': UPPER, 'attrs': {...}}."""
        n = dict(n or {})
        out = {"id": n.get("id", ""), "type": "", "attrs": {}}
        t = (n.get("type") or n.get("node_type") or "").upper()
        out["type"] = t

        attrs = {}
        if isinstance(n.get("attrs"), dict):
            attrs.update(n["attrs"])

        for k in ("title", "level", "text", "term", "name", "role", "address"):
            if k in n and k not in attrs:
                attrs[k] = n[k]

        out["attrs"] = attrs
        return out

    #strict nodes
    def setify_strict(nodes):
        S=set()
        for nn in nodes:
            t=(nn.get("type") or "").upper()
            a=nn.get("attrs",{}) or {}
            nid=nn.get("id","") or ""
            k=a.get("name") or a.get("term") or a.get("text") or nid
            k=norm(k)
            if t and k: S.add((t,k))
        return S

    for gstr,pstr in zip(golds,preds):
        G_json, g_err = deep_json(gstr, want_nodes_edges=schema_strict)
        if G_json is None:
            invalid += 1
            continue
        P_json, p_err = deep_json(pstr, want_nodes_edges=schema_strict)
        if P_json is None:
            invalid += 1
            continue

        G_nodes_raw = extract_nodes(G_json)
        P_nodes_raw = extract_nodes(P_json)

        #normalize both sides
        G_nodes = [normalize_node_schema(x) for x in G_nodes_raw if isinstance(x, dict)]
        P_nodes = [normalize_node_schema(x) for x in P_nodes_raw if isinstance(x, dict)]

        #keep edges
        G_edges = extract_edges(G_json)
        P_edges = extract_edges(P_json)
        Gs, Ps = setify_strict(G_nodes), setify_strict(P_nodes)

        #exact matches
        if (Gs or Ps) and (Gs == Ps):
            exact += 1

        gold_nodes_total += len(G_nodes); gold_edges_total += len(G_edges)
        pred_nodes_total += len(P_nodes); pred_edges_total += len(P_edges)

        #strict nodes
        if Gs==Ps: exact+=1
        strict_tp+=len(Gs&Ps); strict_fp+=len(Ps-Gs); strict_fn+=len(Gs-Ps)

        #fuzzy nodes
        def bucket(nodes):
            b=defaultdict(list)
            for n in nodes:
                t,kt=keytext(n)
                if t and kt: b[t].append(kt)
            return b
        Gb,Pb=bucket(G_nodes),bucket(P_nodes)
        for t in (set(Gb)|set(Pb)):
            tp=match_type(Gb.get(t,[]), Pb.get(t,[]), t)
            fp=len(Pb.get(t,[]))-tp; fn=len(Gb.get(t,[]))-tp
            fuzzy_tp+=tp; fuzzy_fp+=fp; fuzzy_fn+=fn

        #edges
        def build_map(nodes):
            m={}
            for n in nodes:
                nid=n.get("id") or ""
                tt=keytext(n)
                if nid: m[str(nid)] = f"{tt[0]}|{tt[1]}"
            return m
        Gmap,Pmap=build_map(G_nodes),build_map(P_nodes)
        def edge_triplet(e, node_map):
            def pick(d, keys):
                for k in keys:
                    if k in d and d[k] is not None:
                        return d[k]
                return None
            typ = (pick(e, ["type","edge_type","label"]) or "").upper()
            raw_src = pick(e, ["src","source","from"])
            raw_tgt = pick(e, ["tgt","target","to"])
            def resolve(v):
                if v is None: return ""
                v_str = str(v)
                if v_str in node_map: return node_map[v_str]
                return norm(v_str)
            return (resolve(raw_src), typ, resolve(raw_tgt))
        Ge=set(edge_triplet(e,Gmap) for e in G_edges)
        Pe=set(edge_triplet(e,Pmap) for e in P_edges)
        e_tp+=len(Ge&Pe); e_fp+=len(Pe-Ge); e_fn+=len(Ge-Pe)

    sp,sr,sf1 = prf1(strict_tp,strict_fp,strict_fn)
    fp_,fr_,ff1 = prf1(fuzzy_tp,fuzzy_fp,fuzzy_fn)
    ep,er,ef1   = prf1(e_tp,e_fp,e_fn)

    print("\n--- AUDIT COUNTS ---")
    print(f"Gold avg nodes/edges: {gold_nodes_total/max(1,use):.2f} / {gold_edges_total/max(1,use):.2f}")
    print(f"Pred avg nodes/edges: {pred_nodes_total/max(1,use):.2f} / {pred_edges_total/max(1,use):.2f}")

    return {
        "gen_strict_micro_precision": sp,
        "gen_strict_micro_recall":    sr,
        "gen_strict_micro_f1":        sf1,
        "gen_fuzzy_micro_precision":  fp_,
        "gen_fuzzy_micro_recall":     fr_,
        "gen_fuzzy_micro_f1":         ff1,
        "gen_edges_micro_precision":  ep,
        "gen_edges_micro_recall":     er,
        "gen_edges_micro_f1":         ef1,
        "gen_exact_match":            exact/max(1,use),
        "gen_invalid_json_rate":      invalid/max(1,use),
        "gen_num_samples":            use,
    }


In [None]:
#Cell 4a - run evaluation test
from datasets import load_dataset

metrics = evaluate_model_debug(
    model, tokenizer, cleaned_dataset["validation"],
    system_prompt=SYS_PROMPT,
    max_samples=50, batch_size=4, max_new_tokens=2048,
    use_json_stopper=False,
    schema_strict=True,
    gold_field="completion",
    print_samples=2
)
print(metrics)



--- SAMPLE PROMPTS & PREDICTIONS ---

[Prompt 0] {"instruction": "Your task is to act as a legal graph extractor. From this single clause, create a self-contained set of nodes and edges that are explicitly supported by the text. Follow the reasoning process, rules, and clarifications below.\n\nOutput ONLY a single, strict JSON object with this str...
[Pred  0] {"contract_id": "INTERNATIONALFASTFOODCORP_04_04_1997-EX-99-FRANCHISE AGREEMENT", "nodes": [{"id": "11.2", "node_type": "CLAUSE", "title": "11.2", "level": 2}, {"id": "11.2.1", "node_type": "CLAUSE", "title": "11.2.1", "level": 3}, {"id": "party:BKC", "node_type": "PARTY", "name": "BKC"}, {"id": "pa...

[Prompt 1] {"instruction": "Your task is to act as a legal graph extractor. From this single clause, create a self-contained set of nodes and edges that are explicitly supported by the text. Follow the reasoning process, rules, and clarifications below.\n\nOutput ONLY a single, strict JSON object with this str...
[Pred  1] {"cont

In [None]:
#Cell 4b - run full evaluation test

metrics = evaluate_model_debug(
    model, tokenizer, cleaned_dataset["validation"],
    system_prompt=SYS_PROMPT,
    max_samples=200, batch_size=4, max_new_tokens=2048,
    use_json_stopper=False,
    schema_strict=True,
    gold_field="completion",
    print_samples=2
)
print(metrics)



--- SAMPLE PROMPTS & PREDICTIONS ---

[Prompt 0] {"instruction": "Your task is to act as a legal graph extractor. From this single clause, create a self-contained set of nodes and edges that are explicitly supported by the text. Follow the reasoning process, rules, and clarifications below.\n\nOutput ONLY a single, strict JSON object with this str...
[Pred  0] {"contract_id": "INTERNATIONALFASTFOODCORP_04_04_1997-EX-99-FRANCHISE AGREEMENT", "nodes": [{"id": "10.5", "node_type": "CLAUSE", "title": "10.5", "level": 2}, {"id": "10.5.2", "node_type": "CLAUSE", "title": "10.5.2", "level": 3}, {"id": "party:BKC", "node_type": "PARTY", "name": "BKC"}, {"id": "pa...

[Prompt 1] {"instruction": "Your task is to act as a legal graph extractor. From this single clause, create a self-contained set of nodes and edges that are explicitly supported by the text. Follow the reasoning process, rules, and clarifications below.\n\nOutput ONLY a single, strict JSON object with this str...
[Pred  1] {"cont

In [1]:
#@title Training Setup

In [None]:
#@title helpers
    def _deep_safe_json(self, x, max_depth=3):
        obj = x
        for _ in range(max_depth):
            if isinstance(obj, (dict, list)):
                return obj
            if obj is None:
                return {}
            s = str(obj).strip()
            b, e = s.find("{"), s.rfind("}")
            cand = s[b:e+1] if (b != -1 and e != -1 and e > b) else s
            try:
                obj = json.loads(cand); continue
            except Exception:
                break
        return {}

    def _extract_nodes(self, obj):
        arr = obj.get("nodes", []) if isinstance(obj, dict) else []
        if isinstance(arr, dict): arr = [arr]
        out=[]
        for it in arr:
            if isinstance(it, dict): out.append(it)
            elif isinstance(it, str):
                try:
                    d = json.loads(it)
                    if isinstance(d, dict): out.append(d)
                except: pass
        return out

    def _extract_edges(self, obj):
        arr = obj.get("edges", []) if isinstance(obj, dict) else []
        if isinstance(arr, dict): arr = [arr]
        out=[]
        for it in arr:
            if isinstance(it, dict): out.append(it)
            elif isinstance(it, str):
                try:
                    d = json.loads(it)
                    if isinstance(d, dict): out.append(d)
                except: pass
        return out

    def _edge_triplet(self, e, node_map):
        def pick(d, keys):
            for k in keys:
                if k in d and d[k] is not None:
                    return d[k]
            return None
        typ = (pick(e, ["type","edge_type","label"]) or "").upper()
        raw_src = pick(e, ["src","source","from"])
        raw_tgt = pick(e, ["tgt","target","to"])
        def resolve(v):
            if v is None: return ""
            v_str = str(v)
            if v_str in node_map:
                return node_map[v_str]
            return self._norm(v_str)
        return (resolve(raw_src), typ, resolve(raw_tgt))

    # ----- normalization & matching -----
    COMPANY_SUFFIX_RE = re.compile(r"\b(inc\.?|ltd\.?|llc|l\.l\.c\.|corp\.?|co\.?|ag|gmbh)\b", re.I)
    WS_RE = re.compile(r"\s+")
    def _norm(self, s):
        if not isinstance(s, str): s = str(s) if s is not None else ""
        s = s.lower().replace("&","and")
        s = self.COMPANY_SUFFIX_RE.sub("", s)
        return self.WS_RE.sub(" ", s).strip()

    def _toks(self, s): return re.findall(r"[a-z0-9]+", s.lower())
    def _jacc(self, a,b):
        sa,sb=set(self._toks(a)),set(self._toks(b))
        if not sa or not sb: return 1.0 if a.strip()==b.strip() and a.strip()!="" else 0.0
        return len(sa&sb)/max(1,len(sa|sb))

    PCT_RE=re.compile(r"(\d+(?:\.\d+)?)\s*%"); MONEY_RE=re.compile(r"(\$|usd)\s*([\d,]+(?:\.\d+)?)",re.I)
    DAYS_RE=re.compile(r"(\d+)\s*days?"); YEARS_RE=re.compile(r"(\d+)\s*years?")
    NUM_WORDS={"zero":0,"one":1,"two":2,"three":3,"four":4,"five":5,"six":6,"seven":7,"eight":8,"nine":9,
               "ten":10,"eleven":11,"twelve":12,"thirteen":13,"fourteen":14,"fifteen":15,"sixteen":16,
               "seventeen":17,"eighteen":18,"nineteen":19,"twenty":20,"thirty":30,"forty":40,"fifty":50,
               "sixty":60,"seventy":70,"eighty":80,"ninety":90}
    def _w2n(self,s):
        s=s.lower()
        for w,n in self.NUM_WORDS.items():
            if re.search(rf"\b{w}\b", s): return n
        return None

    def _canon_value(self, text):
        s=self._norm(text)
        m=self.PCT_RE.search(s)
        if m: return f"{float(m.group(1)):.0f}%"
        if "percent" in s:
            n=self._w2n(s)
            if n is not None: return f"{n}%"
        m=self.MONEY_RE.search(s)
        if m: return f"usd {m.group(2).replace(',','')}"
        m=self.DAYS_RE.search(s)
        if m: return f"{int(m.group(1))} days"
        m=self.YEARS_RE.search(s)
        if m: return f"{int(m.group(1))} years"
        return s

    def _keytext(self, n):
        t=(n.get("type") or "").upper(); a=n.get("attrs",{}) or {}; nid=n.get("id","") or ""
        if t=="CLAUSE":
            k=a.get("id") or nid or a.get("title") or ""; return t,self._norm(k)
        if t=="DEFINED_TERM":
            k=a.get("term") or nid.split(":",1)[-1]; return t,self._norm(k)
        if t=="PARTY":
            k=a.get("name") or a.get("text") or nid.split(":",1)[-1]; return t,self._norm(k)
        if t=="VALUE":
            k=a.get("text") or nid.split(":",1)[-1]; return t,self._canon_value(k)
        k=a.get("text") or a.get("term") or a.get("name") or nid; return t,self._norm(k)

    THRESH={"CLAUSE":0.90,"DEFINED_TERM":0.80,"PARTY":0.85,"VALUE":0.75}
    def _sim(self,t,a,b):
        if t=="CLAUSE": return 1.0 if a==b else self._jacc(a,b)
        if t=="DEFINED_TERM": return self._jacc(a,b)
        if t=="PARTY": return self._jacc(a,b)
        if t=="VALUE":
            if a==b and (a.endswith("%") or a.endswith("days") or a.endswith("years") or a.startswith("usd")): return 1.0
            return self._jacc(a,b)
        return self._jacc(a,b)

    def _bucket(self, nodes):
        b=defaultdict(list)
        for n in nodes:
            t,kt=self._keytext(n)
            if t and kt: b[t].append(kt)
        return b

    def _match_type(self, G_list, P_list, t):
        if not G_list or not P_list: return 0
        pairs=[]
        for i,g in enumerate(G_list):
            for j,p in enumerate(P_list):
                s=self._sim(t,g,p)
                if s>=self.THRESH.get(t,0.8): pairs.append((s,i,j))
        pairs.sort(reverse=True)
        used_i=set(); used_j=set(); tp=0
        for s,i,j in pairs:
            if i in used_i or j in used_j: continue
            used_i.add(i); used_j.add(j); tp+=1
        return tp

    def _prf1(self,tp,fp,fn):
        p=tp/(tp+fp) if (tp+fp) else 0.0
        r=tp/(tp+fn) if (tp+fn) else 0.0
        f=2*p*r/(p+r) if (p+r) else 0.0
        return p,r,f

    # ----- generation -----
    def _generate_texts(self, model, prompts):
        tok = self.tokenizer
        old_pad, old_trunc = tok.padding_side, getattr(tok, "truncation_side", "right")
        tok.padding_side = tok.truncation_side = "left"
        if tok.pad_token is None: tok.pad_token = tok.eos_token
        EOS = _safe_eos_ids(tok)
        PAD = tok.pad_token_id or tok.eos_token_id

        outs=[]
        for i in range(0, len(prompts), self.gen_batch_size):
            texts = [tok.apply_chat_template(
                        [{"role": "system","content": self.sys_prompt},
                         {"role": "user","content": p}],
                        tokenize=False, add_generation_prompt=True)
                     for p in prompts[i:i+self.gen_batch_size]]
            batch = tok(texts, return_tensors="pt", padding=True, truncation=True).to(model.device)
            stop = StoppingCriteriaList([JsonStopper(tok, input_len=batch["input_ids"].shape[1])])
            with torch.no_grad():
                out = model.generate(
                    **batch,
                    max_new_tokens=self.gen_max_new_tokens,
                    min_new_tokens=1,
                    do_sample=False, temperature=None, top_p=None, top_k=None,
                    use_cache=True,
                    eos_token_id=EOS, pad_token_id=PAD,
                    max_time=60,  # per-call safety cap
                    stopping_criteria=stop,
                )
            gen = tok.batch_decode(out[:, batch["input_ids"].shape[1]:], skip_special_tokens=True)
            outs.extend(gen)
            del out, batch
            torch.cuda.empty_cache(); gc.collect()

        tok.padding_side, tok.truncation_side = old_pad, old_trunc
        return outs

    # ----- metrics (nodes + edges) -----
    def compute_node_metrics(self, model):
        n = len(self.eval_ds_raw)
        use = min(self.gen_max_samples, n)
        idxs = np.random.choice(n, use, replace=False) if use < n else np.arange(n)
        sub = self.eval_ds_raw.select(idxs)
        prompts = sub["prompt"]
        gold_key = "clean_completion" if "clean_completion" in sub.column_names else "completion"
        golds = sub[gold_key]

        preds = self._generate_texts(model, prompts)

        strict_tp=strict_fp=strict_fn=0
        fuzzy_tp=fuzzy_fp=fuzzy_fn=0
        e_tp=e_fp=e_fn=0
        exact=invalid=0

        def setify_strict(nodes):
            S=set()
            for nn in nodes:
                t=(nn.get("type") or "").upper(); a=nn.get("attrs",{}) or {}; nid=nn.get("id","") or ""
                k=a.get("name") or a.get("term") or a.get("text") or nid
                k=self._norm(k)
                if t and k: S.add((t,k))
            return S

        for gstr,pstr in zip(golds,preds):
            G_json=self._deep_safe_json(gstr); P_json=self._deep_safe_json(pstr)
            G_nodes=self._extract_nodes(G_json)
            try:
                P_nodes=self._extract_nodes(P_json)
            except Exception:
                P_nodes=[]; invalid+=1

            # strict nodes
            Gs, Ps = setify_strict(G_nodes), setify_strict(P_nodes)
            if Gs==Ps: exact+=1
            strict_tp+=len(Gs&Ps); strict_fp+=len(Ps-Gs); strict_fn+=len(Gs-Ps)

            # fuzzy nodes
            Gb,Pb=self._bucket(G_nodes),self._bucket(P_nodes)
            for t in (set(Gb)|set(Pb)):
                tp=self._match_type(Gb.get(t,[]), Pb.get(t,[]), t)
                fp=len(Pb.get(t,[]))-tp; fn=len(Gb.get(t,[]))-tp
                fuzzy_tp+=tp; fuzzy_fp+=fp; fuzzy_fn+=fn

            # edges (strict on (src,type,tgt) triples)
            def build_map(nodes):
                m={}
                for n in nodes:
                    nid=n.get("id") or ""
                    tt=self._keytext(n)
                    if nid: m[str(nid)] = f"{tt[0]}|{tt[1]}"
                return m
            Gmap,Pmap=build_map(G_nodes),build_map(P_nodes)
            Ge=set(self._edge_triplet(e,Gmap) for e in self._extract_edges(G_json))
            Pe=set(self._edge_triplet(e,Pmap) for e in self._extract_edges(P_json))
            e_tp+=len(Ge&Pe); e_fp+=len(Pe-Ge); e_fn+=len(Ge-Pe)

        sp,sr,sf1 = self._prf1(strict_tp,strict_fp,strict_fn)
        fp_,fr_,ff1 = self._prf1(fuzzy_tp,fuzzy_fp,fuzzy_fn)
        ep,er,ef1   = self._prf1(e_tp,e_fp,e_fn)

        return {
            "gen_strict_micro_precision": sp,
            "gen_strict_micro_recall":    sr,
            "gen_strict_micro_f1":        sf1,
            "gen_fuzzy_micro_precision":  fp_,
            "gen_fuzzy_micro_recall":     fr_,
            "gen_fuzzy_micro_f1":         ff1,
            "gen_edges_micro_precision":  ep,
            "gen_edges_micro_recall":     er,
            "gen_edges_micro_f1":         ef1,
            "gen_exact_match":            exact/max(1,len(prompts)),
            "gen_invalid_json_rate":      invalid/max(1,len(prompts)),
            "gen_num_samples":            len(prompts),
        }

In [None]:
#@title [Cell] 2 execution for train setup - Load base model (4-bit) + apply SFT adapter
import os, torch, wandb
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

BASE_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
#MODEL_DIR   = (SFT_RUN_DIR / "adapter-20250827-142825") #- Original SFT qlora model
MODEL_DIR   = (SFT_RUN_DIR / "notebooks/checkpoints/grpo_final_gen_equal_batch_short_eval")
# ---- 4-bit NF4 on A100; bf16 math ----
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token or tokenizer.pad_token

def _print_gpu():
    try:
        free, total = torch.cuda.mem_get_info()
        print(f"GPU free/total: {free/1e9:.1f} / {total/1e9:.1f} GB")
    except Exception:
        pass

print("Loading base model (4-bit) with FlashAttention-2…")
_print_gpu()

common_kwargs = dict(
    quantization_config=bnb,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    attn_implementation="flash_attention_2",
)

try:
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        device_map={"": "cuda:0"},
        **common_kwargs,
    )
except ValueError as e:
    print("[Retry] memory-constrained auto mapping… Reason:", e)
    try:
        max_mem = {0: "78GiB", "cpu": "0GiB"}
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_ID,
            device_map="auto",
            max_memory=max_mem,
            **common_kwargs,
        )
    except ValueError as e2:
        print("[Retry] CPU offload last resort… Reason:", e2)
        offload_dir = os.path.join(os.getcwd(), "offload")
        os.makedirs(offload_dir, exist_ok=True)
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_ID,
            device_map="auto",
            offload_folder=offload_dir,
            **common_kwargs,
        )

try:
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    if eot_id is None or eot_id == tokenizer.unk_token_id:
        tokenizer.add_special_tokens({"additional_special_tokens": ["<|eot_id|>"]})
        model.resize_token_embeddings(len(tokenizer))
        eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
except Exception:
    eot_id = None

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Applying SFT adapter from", MODEL_DIR)
model = PeftModel.from_pretrained(model, MODEL_DIR, is_trainable=True)
model.config.use_cache = False

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

try:
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
except TypeError:
    model.gradient_checkpointing_enable()

_print_gpu()
print("Model ready (FA2, 4-bit NF4).")


Loading base model (4-bit) with FlashAttention-2…
GPU free/total: 69.2 / 85.0 GB


Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:13<00:00,  3.28s/it]


Applying SFT adapter from /home/ubuntu/ai/adapter-20250827-142825
GPU free/total: 61.3 / 85.0 GB
Model ready (FA2, 4-bit NF4).


In [None]:
#@title Cell 3a Training - Data load and clean
from datasets import load_dataset
import json
import re

HF_DATASET_ID = "moriyad/clause_minigraph_builder_grpo"

ds = load_dataset(HF_DATASET_ID)

def deep_safe_json(x, max_depth=3):
    obj = x
    for _ in range(max_depth):
        if isinstance(obj, (dict, list)): return obj
        if isinstance(obj, str):
            s = obj.strip()
            b, e = s.find("{"), s.rfind("}")
            cand = s[b:e+1] if (b != -1 and e != -1 and e > b) else s
            try:
                obj = json.loads(cand); continue
            except Exception:
                break
        break
    return {}


DATE_RE = re.compile(r"(Cutting Knowledge Date:.*\n|Today Date:.*\n)")

strip_dates = lambda ex: {
    **ex,
    "clean_instruction": DATE_RE.sub("", ex["clean_instruction"]).strip()
}
ds["train"]

Dataset({
    features: ['id', 'contract_id', 'clause_id', 'prompt', 'completion'],
    num_rows: 1312
})

In [None]:
eval_ds_full  = (ds["validation"])
train_ds_full = ds["train"]
#small evaluation slice for quick periodic eval
EVAL_SLICE = min(64, len(ds["validation"]))
eval_ds_small = (ds["validation"]).select(range(EVAL_SLICE))

print("Train size:", len(train_ds_full), "| Eval slice:", len(eval_ds_small))


Train size: 1312 | Eval slice: 64


In [None]:
#@title Cell 3b Training - Building the gold index for scoring

from typing import Any, Dict, Iterable, Tuple

def build_gold_index(ds_train_full: Iterable[Dict[str, Any]]) -> Dict[Tuple[str, str], Any]:
    """
    Build a dictionary {(contract_id, clause_id): completion_value}.
    Supports records shaped like:
      {
        "contract_id": "...",
        "clause_id": "...",          # or nested: "clause": {"id": "..."}
        "completion": {...}          # gold completion (dict/string), preferred
        # (fallback) "gold": {...}   # if 'completion' missing
      }
    """
    index: Dict[Tuple[str, str], Any] = {}
    for row in ds_train_full:
        if isinstance(row, str):
            try:
                ex = json.loads(row)
            except json.JSONDecodeError:
                print("bad row, skipping ")
                continue
        elif isinstance(row, dict):
            ex = row
        else:
            print(f"row is something else {type(row)}")
            continue

        cid = ex.get("contract_id")
        clid = ex.get("clause_id")
        if clid is None:
            clause_obj = ex.get("clause") or {}
            clid = clause_obj.get("id")
        if not cid or not clid:
            continue

        key = (str(cid), str(clid))

        if "completion" in ex:
            index[key] = ex["completion"]
        elif "gold" in ex:
            index[key] = ex["gold"]
        else:
            continue
    return index

type(train_ds_full)
ds_train_index = build_gold_index(train_ds_full)
ds_eval_index = build_gold_index(eval_ds_full)
full_index = ds_train_index | ds_eval_index
print(len(full_index))


1500


In [None]:
rowprmpt = ds["train"][0]["prompt"]
objprompt = json.loads(rowprmpt)
print(objprompt.get("instruction"))
print(objprompt.get("input"))


Your task is to act as a legal graph extractor. From this single clause, create a self-contained set of nodes and edges that are explicitly supported by the text. Follow the reasoning process, rules, and clarifications below.

Output ONLY a single, strict JSON object with this structure:

{
 "contract_id": "...",
 "nodes": [ ... ],
 "edges": [ ... ]
}

 
REASONING PROCESS
 
1.  **Isolate Core Text:** First, mentally separate the core contractual prose from any 'noise' like Tables of Contents, redaction headers, or formatting artifacts. Your analysis should ONLY focus on the contractual prose.
2.  **Create Primary Node:** Create the `CLAUSE` node for the clause you were given.
3.  **Infer and Create Parent Node:** Analyze the clause `id` and `level` to infer the parent clause ID. Create the parent `CLAUSE` node.
4.  **Scan and Create Nodes:** Read the core text to identify all other entities (Referenced Clauses, Defined Terms, Parties, Values) and create their corresponding nodes accord

In [None]:
#@title Cell 3c Training - Further cleanup and special tokens validation

import re, json

def _strip_knowledge_banner(text: str) -> str:
    if not text:
        return ""
    lines = text.splitlines()
    while lines and (
        re.match(r'^\s*system\s*$', lines[0], flags=re.I) or
        re.match(r'^\s*(Cutting\s*Knowledge\s*Date|Knowledge\s*cutoff)\s*:', lines[0], flags=re.I) or
        re.match(r'^\s*(Today\s*Date|Current\s*date)\s*:', lines[0], flags=re.I) or
        re.match(r'^\s*$', lines[0])
    ):
        lines.pop(0)
    return "\n".join(lines).lstrip("\n")

def _strip_banner_from_rendered_prompt(prompt: str) -> str:
    """
    Removes the Llama-style knowledge banner that the chat template injects
    at the start of the system block. Only touches the first system section.
    """
    sys_hdr = "<|start_header_id|>system<|end_header_id|>\n\n"
    i = prompt.find(sys_hdr)
    if i == -1:
        return prompt
    start = i + len(sys_hdr)

    j = prompt.find("<|start_header_id|>user<|end_header_id|>", start)
    if j == -1:
        j = prompt.find("<|eot_id|><|start_header_id|>user<|end_header_id|>", start)
    end = j if j != -1 else len(prompt)

    sys_block = prompt[start:end]
    lines = sys_block.splitlines()

    k = 0
    while k < len(lines) and (
        re.match(r'^\s*system\s*$', lines[k], flags=re.I) or
        re.match(r'^\s*(Cutting\s*Knowledge\s*Date|Knowledge\s*cutoff)\s*:\s*.*$', lines[k], flags=re.I) or
        re.match(r'^\s*(Today\s*Date|Current\s*date)\s*:\s*.*$', lines[k], flags=re.I) or
        re.match(r'^\s*$', lines[k])
    ):
        k += 1

    cleaned_sys = "\n".join(lines[k:])
    return prompt[:start] + cleaned_sys + prompt[end:]

def create_chat_prompt_grpo(example, tokenizer):
    objprompt   = json.loads(example["prompt"])
    instruction = _strip_knowledge_banner(objprompt.get("instruction") or "")
    user_input  = objprompt.get("input") or ""

    instruction += " Generate EXACTLY the JSON object {...} with no additional text before or after. End immediately after the closing brace }."
    messages = [
        {"role": "system", "content": instruction},
        {"role": "user",   "content": user_input},
    ]
    rendered = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    rendered = _strip_banner_from_rendered_prompt(rendered)
    return {"prompt": rendered}

keep_cols = {"prompt", "contract_id", "clause_id", "completion"}
train_ds = train_ds_full.map(
    lambda ex: create_chat_prompt_grpo(ex, tokenizer),
    remove_columns=[c for c in train_ds_full.column_names if c not in keep_cols]
)
eval_ds_small = eval_ds_small.map(
    lambda ex: create_chat_prompt_grpo(ex, tokenizer),
    remove_columns=[c for c in eval_ds_small.column_names if c not in keep_cols]
)

eval_ds_full = eval_ds_full.map(
    lambda ex: create_chat_prompt_grpo(ex, tokenizer),
    remove_columns=[c for c in eval_ds_full.column_names if c not in keep_cols]
)
train_ds[0]["prompt"]


'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYour task is to act as a legal graph extractor. From this single clause, create a self-contained set of nodes and edges that are explicitly supported by the text. Follow the reasoning process, rules, and clarifications below.\n\nOutput ONLY a single, strict JSON object with this structure:\n\n{\n "contract_id": "...",\n "nodes": [ ... ],\n "edges": [ ... ]\n}\n\n \nREASONING PROCESS\n \n1.  **Isolate Core Text:** First, mentally separate the core contractual prose from any \'noise\' like Tables of Contents, redaction headers, or formatting artifacts. Your analysis should ONLY focus on the contractual prose.\n2.  **Create Primary Node:** Create the `CLAUSE` node for the clause you were given.\n3.  **Infer and Create Parent Node:** Analyze the clause `id` and `level` to infer the parent clause ID. Create the parent `CLAUSE` node.\n4.  **Scan and Create Nodes:** Read the core text to identify all other entities (Referenced Cl

In [None]:
#@title Cell 3d Training - Helpers for fetching gold completions

import re, json, ast
from typing import Tuple

_USER_BLOCK_RE = re.compile(
    r"<\|start_header_id\|>user<\|end_header_id\|>\s*(.*?)\s*<\|eot_id\|>",
    re.S
)

def _parse_user_obj(user_content: str):
    """
    Parse the user content into a dict. Try JSON first; if it fails, try Python literal.
    """
    try:
        return json.loads(user_content)
    except Exception:
        pass

    try:
        obj = ast.literal_eval(user_content)
        if isinstance(obj, dict):
            return obj
    except Exception:
        pass

    i, j = user_content.find("{"), user_content.rfind("}")
    if i != -1 and j != -1 and j > i:
        snippet = user_content[i:j+1]
        try:
            return json.loads(snippet)
        except Exception:
            try:
                obj = ast.literal_eval(snippet)
                if isinstance(obj, dict):
                    return obj
            except Exception:
                pass

    raise ValueError(f"User content not parseable as JSON/Python dict.\nHead: {user_content[:200]}")

def extract_clause_id_contract_id_from_prompt(templated_prompt: str) -> Tuple[str, str]:
    """
    Given a Llama-3.x chat-templated prompt string, extract (clause_id, contract_id)
    from the user turn JSON/dict.
    """
    m = _USER_BLOCK_RE.search(templated_prompt)
    if not m:
        raise ValueError("Could not find <user> block in the chat-templated prompt.")

    user_content = m.group(1).strip()
    obj = _parse_user_obj(user_content)

    contract_id = obj.get("contract_id")
    clause = obj.get("clause") or {}
    clause_id = clause.get("id")

    if not contract_id or not clause_id:
        raise ValueError(f"Missing ids in user payload. Got keys: {list(obj.keys())}")

    return str(clause_id), str(contract_id)

cid, contract = extract_clause_id_contract_id_from_prompt(train_ds[0]["prompt"])
print(cid, contract)

10.3 AIRSPANNETWORKSINC_04_11_2000-EX-10.5-Distributor Agreement


In [None]:
#@title Cell 3e Training - Helpers for fetching gold completions
# build this once at startup:
# gold_index = build_gold_index(ds_train_full)

def fetch_gold_completions(clause_id: str, contract_id: str, gold_index: Dict[Tuple[str, str], Any]):
    """
    Look up the gold completion by (contract_id, clause_id).
    Provide `gold_index` if you built it (recommended). If not provided, raise.
    """
    if gold_index is None:
        raise ValueError("gold_index is required. Build it with build_gold_index(ds_train_full) and pass it in.")
    return gold_index.get((str(contract_id), str(clause_id)))

gold = fetch_gold_completions(cid, contract, full_index)
gold_obj = json.loads(gold)
print(gold_obj.get("edges"))

[{'src': '10.3', 'tgt': '10', 'type': 'IS_PART_OF'}, {'src': '10.3', 'tgt': 'party:Airspan', 'type': 'MENTIONS_PARTY'}, {'src': '10.3', 'tgt': 'party:Distributor', 'type': 'MENTIONS_PARTY'}, {'src': '10.3', 'tgt': 'term:Agreement', 'type': 'USES'}, {'src': '10.3', 'tgt': 'term:Confidential Information', 'type': 'USES'}]


In [None]:
print(eval_ds_full)

Dataset({
    features: ['contract_id', 'clause_id', 'prompt', 'completion'],
    num_rows: 188
})


In [None]:
# --- Verify the Output ---
print("\n--- Example of a Refactored Entry ---")
example_entry = eval_ds_full[0]

print("\n[START OF PROMPT]")
print(example_entry["prompt"][:500] + "...")
print("\n[END OF PROMPT]")
print(example_entry["prompt"][250:] + "...")

print("\n[completion]")
print(example_entry["completion"])

print("\n[CLAUSE_ID]")
print(example_entry["clause_id"])

print("\n[CONTRACT_ID]")
print(example_entry["contract_id"])




--- Example of a Refactored Entry ---

[START OF PROMPT]
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Your task is to act as a legal graph extractor. From this single clause, create a self-contained set of nodes and edges that are explicitly supported by the text. Follow the reasoning process, rules, and clarifications below.

Output ONLY a single, strict JSON object with this structure:

{
 "contract_id": "...",
 "nodes": [ ... ],
 "edges": [ ... ]
}

 
REASONING PROCESS
 
1.  **Isolate Core Text:** First, mentally separate the core c...

[END OF PROMPT]
ss, rules, and clarifications below.

Output ONLY a single, strict JSON object with this structure:

{
 "contract_id": "...",
 "nodes": [ ... ],
 "edges": [ ... ]
}

 
REASONING PROCESS
 
1.  **Isolate Core Text:** First, mentally separate the core contractual prose from any 'noise' like Tables of Contents, redaction headers, or formatting artifacts. Your analysis should ONLY focus on the contractual prose.
2.  **Crea

In [None]:
# Get the column name for the ground-truth JSON
gold_key = "completion" if "completion" in train_ds.column_names else "completion"
prompt_lengths = [len(tokenizer.encode(text)) for text in train_ds["prompt"]]
# Calculate the token length for each completion
token_lengths = [len(tokenizer.encode(text)) for text in train_ds["completion"]]

# --- Analyze the results ---
import numpy as np

max_len = np.max(token_lengths)
mean_len = np.mean(token_lengths)
p95_len = np.percentile(token_lengths, 95)
p99_len = np.percentile(token_lengths, 99)

prmax_len = np.max(prompt_lengths)
prmean_len = np.mean(prompt_lengths)
prp95_len = np.percentile(prompt_lengths, 95)
prp99_len = np.percentile(prompt_lengths, 99)

print(f"Analysis of Ground-Truth Completion Lengths (in tokens):")
print(f"Max length: {max_len}")
print(f"Mean length: {mean_len:.2f}")
print(f"95th percentile: {p95_len}")
print(f"99th percentile: {p99_len}")
print("\n--- Recommendation ---")
print(f"Set 'max_new_tokens' to a value safely above the max, like {int(max_len * 1.1)} or {int(p99_len * 1.1)}")
print(f"Analysis of Training Prompt Lengths (in tokens):")
print(f"Prompt Max length: {prmax_len}")
print(f"Prompt Mean length: {prmean_len:.2f}")
print(f"Prompt 95th percentile: {prp95_len}")
print(f"Prompt 99th percentile: {prp99_len}")

Analysis of Ground-Truth Completion Lengths (in tokens):
Max length: 469
Mean length: 264.57
95th percentile: 409.45000000000005
99th percentile: 432.8900000000001

--- Recommendation ---
Set 'max_new_tokens' to a value safely above the max, like 515 or 476
Analysis of Training Prompt Lengths (in tokens):
Prompt Max length: 1191
Prompt Mean length: 1052.99
Prompt 95th percentile: 1129.0
Prompt 99th percentile: 1150.0


In [None]:
#@title Cell 4a Training - GRPO Configurations

train_generation_kwargs = {
    "do_sample": True, "temperature": 0.4, "top_p": 0.9,
    "use_cache": True,
    "eos_token_id": [tokenizer.eos_token_id],
    "pad_token_id": tokenizer.pad_token_id,
    "max_new_tokens": 400,
    "eos_token_id": [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")],
    "repetition_penalty": 1.2,
    "length_penalty": 1.0,
}

eval_generation_kwargs = {
    "do_sample": False,
    "temperature": 0.0,
    "top_p": 1.0, "top_k": 0,
    "max_new_tokens": 400,
    "eos_token_id": [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")],
    "early_stopping": True,
    "repetition_penalty": 1.2,
    "length_penalty": 1.0,
}


In [None]:
#@title Cell 4b Training - GRPO Configurations in trl
from transformers import GenerationConfig
from trl import GRPOConfig, GRPOTrainer


grpo_cfg = GRPOConfig(
    run_name="customReward_customEval",
    report_to=["wandb"],
    max_steps=1000,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    learning_rate=2e-6,
    bf16=True,

    num_generations=4,
    max_prompt_length=1280,
    max_completion_length=500,
    generation_kwargs=train_generation_kwargs,
    generation_batch_size=4,

    logging_strategy="steps",
    logging_steps=1,
    eval_strategy="no",
    save_strategy="no",
    remove_unused_columns=False,
    dataloader_num_workers=2,
    torch_empty_cache_steps=1,
)

if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
grpo_cfg.generation_kwargs.update({
    "eos_token_id": tokenizer.convert_tokens_to_ids(tokenizer.eos_token),
    "pad_token_id": tokenizer.convert_tokens_to_ids(tokenizer.pad_token),
})


In [None]:
#@title Cell 4c Training - Custom Stopping Criteria

import re, json, torch
from transformers import StoppingCriteria

class SmartNewJsonStopper(StoppingCriteria):
    """
    Regex-only stopper:
      1) cut to text AFTER assistant header
      2) find:  "nodes": [ ... ], "edges":
         (optionally require non-empty nodes)
      3) then see a final `]}`
    """
    def __init__(
        self,
        tokenizer,
        min_new_chars: int = 80,
        max_scan_chars: int = 6000,
        require_nonempty_nodes: bool = True,
        debug: bool = False,
    ):
        self.tok = tokenizer
        self.min_new_chars = int(min_new_chars)
        self.max_scan_chars = int(max_scan_chars)
        self.require_nonempty_nodes = bool(require_nonempty_nodes)
        self.debug = bool(debug)

        self.assistant_marks = [
            "<|start_header_id|>assistant<|end_header_id|>\n\n",
            "<|assistant|>\n\n",
            "<|assistant|>",
            "\nassistant\n\n",
            "assistant\n\n",
            "assistant\n",
        ]

        if self.require_nonempty_nodes:
            self.re_nodes_edges = re.compile(
                r'"nodes"\s*:\s*\[\s*(?!\])[\s\S]*?\]\s*,\s*"edges"\s*:',
                re.DOTALL
            )
        else:
            self.re_nodes_edges = re.compile(
                r'"nodes"\s*:\s*\[[\s\S]*?\]\s*,\s*"edges"\s*:',
                re.DOTALL
            )

        self.re_close = re.compile(r'\]\s*\}')

    def _suffix_after_assistant(self, text: str) -> str:
        cut = -1
        for m in self.assistant_marks:
            j = text.rfind(m)
            if j >= 0:
                cut = max(cut, j + len(m))
        return text[cut:] if cut >= 0 else text

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        B = input_ids.shape[0]
        if self.debug:
            print(f"[stopper] called batch={B}")

        done_flags = []
        for b in range(B):
            tail = self.tok.decode(input_ids[b][-self.max_scan_chars:], skip_special_tokens=True)
            suffix = self._suffix_after_assistant(tail)

            is_this_row_done = False
            if len(suffix) >= self.min_new_chars:
                m = self.re_nodes_edges.search(suffix)
                if m and self.re_close.search(suffix[m.end():]):
                    if self.debug:
                        print(f"[stopper] TRIGGER condition met for row={b}")
                    is_this_row_done = True

            done_flags.append(is_this_row_done)

        if all(done_flags):
            if self.debug:
                print(f"--- SmartNewJsonStopper: ALL {B} sequences are done. Stopping. ---")
            return True

        return False

json_stopper = SmartNewJsonStopper(tokenizer, min_new_chars=20, max_scan_chars=4000, require_nonempty_nodes=True, debug=False)


In [None]:

#@title Cell 4d Training - Custom GRPO Trainer Class

#latest working version
from trl import GRPOTrainer
from torch.utils.data import Subset
from transformers import StoppingCriteria, StoppingCriteriaList

class GRPOTrainerWithEvalControls(GRPOTrainer):

    def __init__(self, *args, eval_generation_kwargs=None, **kwargs):
        super().__init__(*args, **kwargs)
        self._eval_generation_kwargs = eval_generation_kwargs or {}
        self._final_eval_done = False
        self._saved_small_eval = None
        self._full_eval_ds = None

    @staticmethod
    def _sanitize_sc(sc_like) -> StoppingCriteriaList:
        """Keep only real StoppingCriteria subclasses; drop base/invalid entries."""
        def _valid(c): return isinstance(c, StoppingCriteria) and (type(c) is not StoppingCriteria)
        if sc_like is None:
            keep = []
        elif isinstance(sc_like, StoppingCriteriaList):
            keep = [c for c in sc_like if _valid(c)]
        elif isinstance(sc_like, (list, tuple)):
            keep = [c for c in sc_like if _valid(c)]
        elif _valid(sc_like):
            keep = [sc_like]
        else:
            keep = []
        return StoppingCriteriaList(keep)

    def set_eval_datasets(self, small_eval_ds, full_eval_ds):
        self._saved_small_eval = small_eval_ds
        self._full_eval_ds = full_eval_ds
        self.eval_dataset = small_eval_ds

    def _generate_completions(self, dataset, *args, **kwargs):
        gen_kwargs = dict(kwargs.get("generation_kwargs", {}))
        sc = self._sanitize_sc(gen_kwargs.get("stopping_criteria", None))

        ext = getattr(self, "_ext_stopper", None)
        if ext:
            if isinstance(ext, (list, tuple)):
                for c in ext:
                    if all(c is not e for e in sc):
                        sc.append(c)
            else:
                if all(ext is not e for e in sc):
                    sc.append(ext)

        if all(getattr(e, "__class__", None) is not json_stopper.__class__ for e in sc):
            sc.append(json_stopper)

        gen_kwargs["stopping_criteria"] = sc
        kwargs["generation_kwargs"] = gen_kwargs
        return super()._generate_completions(dataset, *args, **kwargs)

    def _generate(self, *args, **kwargs):
        sc = self._sanitize_sc(kwargs.get("stopping_criteria", None))

        ext = getattr(self, "_ext_stopper", None)
        if ext:
            if isinstance(ext, (list, tuple)):
                for c in ext:
                    if all(c is not e for e in sc):
                        sc.append(c)
            else:
                if all(ext is not e for e in sc):
                    sc.append(ext)

        if all(getattr(e, "__class__", None) is not json_stopper.__class__ for e in sc):
            sc.append(json_stopper)

        kwargs["stopping_criteria"] = sc

        # train vs eval
        if self.model.training:
            kwargs.setdefault("do_sample", True)
            kwargs.setdefault("temperature", 0.001)
            kwargs.setdefault("num_return_sequences", getattr(self.args, "num_generations", 4))
            kwargs.setdefault("use_cache", False)
            kwargs.setdefault("min_new_tokens", 16)             # guard in TRAIN too
        else:
            for k, v in self._eval_generation_kwargs.items():
                kwargs.setdefault(k, v)
            kwargs.setdefault("use_cache", True)
            kwargs.setdefault("min_new_tokens", 16)

        kwargs.setdefault("max_new_tokens", getattr(self.args, "max_completion_length", 500))

        return self.model.generate(**kwargs)

    def _maybe_switch_to_full_eval(self):
        if self._final_eval_done or self._full_eval_ds is None:
            return
        self.eval_dataset = self._full_eval_ds
        self._final_eval_done = True

    def evaluation_loop(self, *args, **kwargs):
        state = getattr(self, "state", None)
        at_end = state is not None and state.global_step >= self.args.max_steps
        if at_end:
            self._maybe_switch_to_full_eval()
        return super().evaluation_loop(*args, **kwargs)

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str="eval"):
        ds = eval_dataset or self.eval_dataset
        base = super().evaluate(ds, ignore_keys, metric_key_prefix)
        print(f"[gen-eval-base] {base}")
        return base

In [None]:
#@title Cell 5a Training - Reward Function Helpers

#granular bad scores
import json, re
from typing import Any, Dict, List, Tuple, Set

EDGE_TYPES = {"IS_PART_OF","REFERENCES","DEFINES","USES","MENTION_PARTY","CONTAINS"}
NODE_TYPES = {"CLAUSE","DEFINED_TERM","PARTY","VALUE"}

#helpers
def _balanced_braces(s: str) -> bool:
    cnt = 0
    for ch in s:
        if ch == "{": cnt += 1
        elif ch == "}": cnt -= 1
        if cnt < 0: return False
    return cnt == 0

def _ends_clean(s: str) -> bool:
    t = s.rstrip()
    return t.endswith("}")

def _graph_sets(obj: Dict) -> Tuple[Set[Tuple], Set[Tuple]]:
    nodes = {(n.get("id",""), (n.get("node_type","") or "").upper()) for n in obj.get("nodes", []) if isinstance(n, dict)}
    edges = {(
        e.get("src",""), e.get("tgt",""),
        (e.get("type","") or "").upper()
    ) for e in obj.get("edges", []) if isinstance(e, dict)}
    return nodes, edges

def _f1(pred: Set[Tuple], gold: Set[Tuple]) -> float:
    if not pred and not gold: return 1.0
    if not pred or not gold:  return 0.0
    tp = len(pred & gold)
    if tp == 0: return 0.0
    prec = tp / len(pred); rec = tp / len(gold)
    return 2*prec*rec/(prec+rec)

#reward
def shaped_reward(completion: str, gold: Dict, f1_weight: float = 0.30) -> float:
    s = completion or ""
    score = 0.0

    # A) token-level shape (0.12)
    t = s.lstrip()
    if t.startswith("{"): score += 0.03
    if s.rstrip().endswith("}"): score += 0.03
    if _balanced_braces(s): score += 0.04
    if "```" in s: score -= 0.02  # discourage fences

    # B) strict parse (0.10)
    parsed = None
    try:
        parsed = json.loads(s)
        score += 0.10
    except Exception:
        # no JSON → stop here; we still return early-shaping score
        return max(0.0, score)

    if not isinstance(parsed, dict):
        return max(0.0, score)  # we only accept dict at top-level

    # C) top-level keys / arrays (0.20)
    nodes = parsed.get("nodes", None)
    edges = parsed.get("edges", None)

    # nodes presence/shape
    if "nodes" in parsed: score += 0.05
    if isinstance(nodes, list): score += 0.03
    if isinstance(nodes, list) and len(nodes) > 0: score += 0.07

    # edges presence/shape
    if "edges" in parsed: score += 0.05
    if isinstance(edges, list): score += 0.03
    if isinstance(edges, list) and len(edges) > 0: score += 0.07

    # D) item quality (0.18)
    #   nodes quality (0.09)
    if isinstance(nodes, list) and len(nodes) > 0:
        good = 0
        for n in nodes:
            if isinstance(n, dict) and n.get("id") and (n.get("node_type","").upper() in NODE_TYPES):
                good += 1
        score += 0.09 * (good / max(1, len(nodes)))

    #   edges quality (0.09)
    if isinstance(edges, list) and len(edges) > 0:
        good = 0
        for e in edges:
            if (isinstance(e, dict) and e.get("src") and e.get("tgt")
                and (e.get("type","").upper() in EDGE_TYPES)):
                good += 1
        score += 0.09 * (good / max(1, len(edges)))

        # small uniqueness bonus
        if len({(e.get("src",""), e.get("tgt",""), e.get("type","")) for e in edges if isinstance(e, dict)}) == len(edges):
            score += 0.02

    # E) closure & compactness (0.10)
    if _ends_clean(s): score += 0.05
    try:
        compact = json.dumps(parsed, separators=(",",":"))
        if len(s) > 0 and (len(s) - len(compact))/len(s) < 0.15:
            score += 0.05
    except Exception:
        pass

    # F) task score vs gold (0.30 by default; tune 0.40 later in training)
    try:
        pn, pe = _graph_sets(parsed)
        if isinstance(gold, str):
            gold = json.loads(gold)
        gn, ge = _graph_sets(gold if isinstance(gold, dict) else {})
        fn, fe = _f1(pn, gn), _f1(pe, ge)
        h = 0.0 if (fn+fe) == 0.0 else (2*fn*fe)/(fn+fe)
        score += f1_weight * h
    except Exception:
        # if anything fails here, keep the structural reward only
        pass

    # clamp
    return max(0.0, min(1.0, score))


In [None]:
#@title Cell 5b Training - Reward Group Adapter

import hashlib, numpy as np
from typing import List

def _prompt_key(p: str) -> str:
    import hashlib
    return hashlib.sha1(p.encode("utf-8")).hexdigest()

def reward_group_adapter(prompts: List[str], completions: List[str], **kwargs) -> List[float]:
    num_prompts = len(prompts)
    num_completions = len(completions)
    assert num_prompts > 0 and num_completions % num_prompts == 0, \
        f"Shape mismatch: prompts={num_prompts}, completions={num_completions}"
    group_size = num_completions // num_prompts

    # fetch gold once per prompt
    gold_cache, golds = {}, []
    for p in prompts:
        k = _prompt_key(p)
        if k not in gold_cache:
            cid, contid = extract_clause_id_contract_id_from_prompt(p)
            gold_cache[k] = fetch_gold_completions(cid, contid, gold_index=full_index)
        golds.append(gold_cache[k])

    rewards = []
    per_prompt_var, per_prompt_mean = [], []
    zero_var_groups = 0

    idx = 0
    for g in golds:
        grp = completions[idx: idx + group_size]
        idx += group_size
        grp_rewards = [shaped_reward(c, g, f1_weight=0.30) for c in grp]
        rewards.extend(grp_rewards)

        v = float(np.var(grp_rewards)) if len(grp_rewards) > 1 else 0.0
        m = float(np.mean(grp_rewards))
        per_prompt_var.append(v); per_prompt_mean.append(m)
        if v == 0.0: zero_var_groups += 1

    if wandb.run:
        wandb.log({
            "train/unique_prompts_in_batch": num_prompts,
            "train/group_size": group_size,
            "train/reward_mean_batch": float(np.mean(rewards)),
            "train/reward_std_batch": float(np.std(rewards)),
            "train/per_prompt_variance_mean": float(np.mean(per_prompt_var)),
            "train/fraction_zero_variance_groups": zero_var_groups / max(1, num_prompts),
            "train/completion_len_mean_batch": float(np.mean([len(c) for c in completions])),
        })
    print(f"prompts={num_prompts}, completions={num_completions} rewards {rewards}" )

    return rewards

In [None]:
#@title Cell 5c Training - Reward Function Helpers - Embedding-Based Rewards
#One time setup
import torch
from sentence_transformers import SentenceTransformer
import torch.nn.functional as F
import json
from typing import Dict, List, Tuple, Set

print("Loading sentence-transformer model for reward calculation...")
device = "cuda" if torch.cuda.is_available() else "cpu"
embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
print("Embedding model loaded and moved to device:", device)

NODE_TYPES = {"CLAUSE", "DEFINED_TERM", "PARTY", "VALUE"}
EDGE_TYPES = {"IS_PART_OF", "REFERENCES", "DEFINES", "USES", "MENTIONS_PARTY", "CONTAINS"}

Loading sentence-transformer model for reward calculation...
Embedding model loaded and moved to device: cuda


In [None]:
#@title Cell 5d Training - Combined Hybrid Reward Function, Helpers, and Group Adapter

def _graph_to_canonical_string(obj: Dict) -> str:
    if not isinstance(obj, dict): return ""
    nodes_str = sorted([f"({n.get('node_type', '')}:{n.get('id', '')})" for n in obj.get("nodes", []) if isinstance(n, dict)])
    edges_str = sorted([f"({e.get('src', '')})-({e.get('type', '')})->({e.get('tgt', '')})" for e in obj.get("edges", []) if isinstance(e, dict)])
    return "NODES: " + " ".join(nodes_str) + " EDGES: " + " ".join(edges_str)


#Hybrid Reward
def hybrid_shaped_reward(
    completion_str: str,
    gold_obj: dict,
    semantic_similarity: float,
    structural_weight: float = 0.7,
    semantic_weight: float = 0.3
) -> float:
    """
    Calculates a hybrid reward based on a weighted sum of:
    1. Structural correctness (JSON parsing, schema, etc.).
    2. Semantic similarity of the raw text using pre-computed embeddings.
    """
    s = completion_str or ""
    structural_score = 0.0

    # --- Part 1: Calculate Structural Score (max value of 1.0) ---
    if s.lstrip().startswith("{"): structural_score += 0.05
    if s.rstrip().endswith("}"): structural_score += 0.05
    if _balanced_braces(s): structural_score += 0.05

    # B) Strict parse (0.35)

    parsed = None
    try:
        json_str = t[t.find('{'):t.rfind('}')+1].strip()
        obj = json.loads(json_str)
        structural_score = 0.3 if "nodes" in obj and "edges" in obj else 0.1
    except json.JSONDecodeError:
        structural_score = 0.1

    try:
        parsed = json.loads(s)
        structural_score += 0.35
    except Exception:
        # Parsing failed. We stop calculating structural score but will still use the semantic score.
        return structural_score * structural_weight + semantic_similarity * semantic_weight

    if not isinstance(parsed, dict):
        return structural_score * structural_weight + semantic_similarity * semantic_weight

    # C) Top-level keys / item quality (0.50)
    nodes = parsed.get("nodes")
    edges = parsed.get("edges")
    if isinstance(nodes, list): structural_score += 0.25
    if isinstance(edges, list): structural_score += 0.25

    # --- Part 2: Combine with Semantic Score ---
    final_score = (structural_score * structural_weight) + (semantic_similarity * semantic_weight)

    return max(0.0, min(1.0, final_score))


def reward_group_adapter(prompts: List[str], completions: List[str], **kwargs) -> List[float]:
    num_prompts = len(prompts)
    num_completions = len(completions)
    group_size = num_completions // num_prompts

    gold_cache, golds_obj = {}, []
    for p in prompts:
        k = _prompt_key(p)
        if k not in gold_cache:
            try:
                cid, contid = extract_clause_id_contract_id_from_prompt(p)
                gold_str = fetch_gold_completions(cid, contid, gold_index=full_index)
                gold_cache[k] = json.loads(gold_str) if gold_str else {}
            except (ValueError, json.JSONDecodeError, TypeError):
                gold_cache[k] = {}
        golds_obj.append(gold_cache[k])

    gold_canonical_strings = []
    for i in range(num_prompts):
        gold_canonical_str = _graph_to_canonical_string(golds_obj[i])
        gold_canonical_strings.extend([gold_canonical_str] * group_size)

    completion_embeddings = embedding_model.encode(completions, convert_to_tensor=True)
    gold_embeddings = embedding_model.encode(gold_canonical_strings, convert_to_tensor=True)

    similarities = F.cosine_similarity(completion_embeddings, gold_embeddings)
    scaled_similarities = (similarities + 1) / 2

    rewards = []
    for i in range(num_completions):
        prompt_idx = i // group_size
        reward = hybrid_shaped_reward(
            completions[i],
            golds_obj[prompt_idx],
            semantic_similarity=scaled_similarities[i].item()
        )
        rewards.append(reward)

    if wandb.run:
        wandb.log({
            "train/unique_prompts_in_batch": num_prompts,
            "train/group_size": group_size,
            "train/reward_mean_batch": float(np.mean(rewards)),
            "train/reward_std_batch": float(np.std(rewards)),
            "train/completion_len_mean_batch": float(np.mean([len(c) for c in completions])),
        })
    print(f"prompts={num_prompts}, completions={num_completions} rewards {rewards}" )

    return rewards

In [None]:
#@title Cell 6 Training - Gated Reward Function

from typing import List, Dict, Tuple
import json, numpy as np
import torch
import torch.nn.functional as F

try:
    _prompt_key
except NameError:
    def _prompt_key(p: str) -> str:
        return p.strip()

def _balanced_braces_local(s: str) -> bool:
    depth = 0
    for ch in s:
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth < 0:
                return False
    return depth == 0

def _balanced_braces(s: str) -> bool:
    try:
        return globals()["_balanced_braces"](s)
    except Exception:
        return _balanced_braces_local(s)

def _graph_to_canonical_string(obj: Dict) -> str:
    if not isinstance(obj, dict):
        return ""
    nodes_str = sorted([
        f"({n.get('node_type', '')}:{n.get('id', '')})"
        for n in obj.get("nodes", [])
        if isinstance(n, dict)
    ])
    edges_str = sorted([
        f"({e.get('src', '')})-({e.get('type', '')})->({e.get('tgt', '')})"
        for e in obj.get("edges", [])
        if isinstance(e, dict)
    ])
    return "NODES: " + " ".join(nodes_str) + " EDGES: " + " ".join(edges_str)

def _f1_from_sets(pred: set, gold: set) -> float:
    if not pred and not gold:
        return 1.0
    if not pred or not gold:
        return 0.0
    inter = len(pred & gold)
    prec  = inter / max(1, len(pred))
    rec   = inter / max(1, len(gold))
    return 0.0 if (prec + rec) == 0 else (2 * prec * rec) / (prec + rec)

def _node_sig(n: Dict) -> Tuple[str, str]:
    return (str(n.get("node_type","")).strip().lower(),
            str(n.get("id","")).strip().lower())

def _edge_sig(e: Dict) -> Tuple[str, str, str]:
    return (str(e.get("type","")).strip().lower(),
            str(e.get("src","")).strip().lower(),
            str(e.get("tgt","")).strip().lower())

def _nodes_f1(pred_nodes, gold_nodes) -> float:
    P = {_node_sig(n) for n in (pred_nodes or []) if isinstance(n, dict)}
    G = {_node_sig(n) for n in (gold_nodes or []) if isinstance(n, dict)}
    return _f1_from_sets(P, G)

def _edges_f1(pred_edges, gold_edges) -> float:
    P = {_edge_sig(e) for e in (pred_edges or []) if isinstance(e, dict)}
    G = {_edge_sig(e) for e in (gold_edges or []) if isinstance(e, dict)}
    return _f1_from_sets(P, G)

#------------------------#
#Weight schedule & gating
#------------------------#
def _weight_schedule(global_step: int,
                     warm_steps: int = 30,
                     add_nodes_at: int = 60,
                     add_edges_at: int = 90,
                     add_sem_at: int   = 120) -> Dict[str, float]:
    """
    Returns weights for {valid, schema, nodes, edges, sem} components.
    Progressively turns on harder components.
    """
    if global_step < warm_steps:
        return dict(valid=1.0, schema=0.0, nodes=0.0, edges=0.0, sem=0.0)
    if global_step < add_nodes_at:
        return dict(valid=0.7, schema=0.3, nodes=0.0, edges=0.0, sem=0.0)
    if global_step < add_edges_at:
        return dict(valid=0.4, schema=0.3, nodes=0.3, edges=0.0, sem=0.0)
    if global_step < add_sem_at:
        return dict(valid=0.3, schema=0.25, nodes=0.3, edges=0.15, sem=0.0)
    # full objective
    return dict(valid=0.2, schema=0.2, nodes=0.3, edges=0.2, sem=0.1)

def _gated_hybrid_reward(
    completion_str: str,
    gold_obj: Dict,
    semantic_similarity: float,
    *,
    global_step: int,
    nodes_gate: float = 0.60,
) -> float:
    """
    Components (each in [0,1]):
      - valid: braces/parse
      - schema: nodes/edges lists present
      - nodes: F1 over node signatures
      - edges: F1 over edge signatures (only if nodes >= nodes_gate)
      - sem:   your precomputed semantic similarity (only after JSON valid)
    """
    s = (completion_str or "").strip()
    W = _weight_schedule(global_step)

    #validity (0..1)
    starts = 1.0 if s.startswith("{") else 0.0
    ends   = 1.0 if s.endswith("}") else 0.0
    bal    = 1.0 if _balanced_braces(s) else 0.0
    soft_hints = (starts + ends + bal) / 3.0  # heuristic
    parsed_obj = None
    try:
        parsed_obj = json.loads(s)
        parsed_ok = 1.0
    except Exception:
        parsed_ok = 0.0

    valid_comp = 0.2 * soft_hints + 0.8 * parsed_ok

    if W["schema"] == W["nodes"] == W["edges"] == W["sem"] == 0.0:
        return float(np.clip(W["valid"] * valid_comp, 0.0, 1.0))

    #schema (0..1)
    schema_comp = 0.0
    nodes_pred, edges_pred = None, None
    nodes_gold, edges_gold = (gold_obj or {}).get("nodes", []), (gold_obj or {}).get("edges", [])
    if isinstance(parsed_obj, dict):
        nodes_pred, edges_pred = parsed_obj.get("nodes"), parsed_obj.get("edges")
        if isinstance(nodes_pred, list): schema_comp += 0.5
        if isinstance(edges_pred, list): schema_comp += 0.5

    #nodes/edges F1
    nodes_comp = 0.0
    edges_comp = 0.0
    if isinstance(nodes_pred, list):
        nodes_comp = _nodes_f1(nodes_pred, nodes_gold)
    if nodes_comp >= nodes_gate and isinstance(edges_pred, list):
        edges_comp = _edges_f1(edges_pred, edges_gold)

    #semantic (0..1); gate on parsed_ok
    sem_comp = float(semantic_similarity) if parsed_ok >= 1.0 else 0.0

    #weights
    reward = (
        W["valid"]  * valid_comp +
        W["schema"] * schema_comp +
        W["nodes"]  * nodes_comp +
        W["edges"]  * edges_comp +
        W["sem"]    * sem_comp
    )

    return float(np.clip(reward, 0.0, 1.0))
_gate_internal_step = 0

def _bump_gate_internal_step():
    global _gate_internal_step
    s = _gate_internal_step
    _gate_internal_step += 1
    return s

#Group adapter
def reward_group_adapter(prompts: List[str], completions: List[str], **kwargs) -> List[float]:
    """
    Same signature you use with GRPO. Adds a step-aware, gated structure/semantic reward.
    Expects these globals/utilities to exist in your notebook:
      - extract_clause_id_contract_id_from_prompt
      - fetch_gold_completions
      - full_index
      - embedding_model  (SentenceTransformer-like)
      - wandb            (optional)
    """
    global_step = kwargs.get("global_step", None)
    if global_step is None:
        global_step = _bump_gate_internal_step()
    else:
        global_step = int(global_step)

    num_prompts = len(prompts)
    num_completions = len(completions)
    assert num_prompts > 0 and num_completions % num_prompts == 0, \
        "completions must be a multiple of prompts (group sampling)."
    group_size = num_completions // num_prompts

    gold_cache: Dict[str, Dict] = {}
    golds_obj: List[Dict] = []
    for p in prompts:
        k = _prompt_key(p)
        if k not in gold_cache:
            try:
                clause_id, contract_id = extract_clause_id_contract_id_from_prompt(p)
                gold_str = fetch_gold_completions(clause_id, contract_id, gold_index=full_index)
                gold_cache[k] = json.loads(gold_str) if gold_str else {}
            except Exception:
                gold_cache[k] = {}
        golds_obj.append(gold_cache[k])

    gold_canonical_strings: List[str] = []
    for i in range(num_prompts):
        gold_canonical_str = _graph_to_canonical_string(golds_obj[i])
        gold_canonical_strings.extend([gold_canonical_str] * group_size)

    completion_embeddings = embedding_model.encode(completions, convert_to_tensor=True)
    gold_embeddings = embedding_model.encode(gold_canonical_strings, convert_to_tensor=True)
    completion_embeddings = F.normalize(completion_embeddings, p=2, dim=1)
    gold_embeddings = F.normalize(gold_embeddings, p=2, dim=1)

    cos = F.cosine_similarity(completion_embeddings, gold_embeddings)
    sem_sims = ((cos + 1.0) / 2.0).clamp(0.0, 1.0)  # to [0,1]
    sem_sims = sem_sims.detach().cpu().tolist()

    rewards: List[float] = []
    for i in range(num_completions):
        prompt_idx = i // group_size
        r = _gated_hybrid_reward(
            completion_str=completions[i],
            gold_obj=golds_obj[prompt_idx],
            semantic_similarity=sem_sims[i],
            global_step=global_step,
        )
        rewards.append(r)

    if "wandb" in globals() and getattr(wandb, "run", None):
        W = _weight_schedule(global_step)
        wandb.log({
            "train/unique_prompts_in_batch": num_prompts,
            "train/group_size": group_size,
            "train/reward_mean_batch": float(np.mean(rewards)),
            "train/reward_std_batch": float(np.std(rewards)),
            "train/completion_len_mean_batch": float(np.mean([len(c) for c in completions])),
            "train/reward_w_valid": W["valid"],
            "train/reward_w_schema": W["schema"],
            "train/reward_w_nodes": W["nodes"],
            "train/reward_w_edges": W["edges"],
            "train/reward_w_sem": W["sem"],
            "train/global_step_for_reward": global_step,
        })
    print(f"prompts={num_prompts}, completions={num_completions}, step={global_step} rewards (first 8) {rewards[:8]}")

    return rewards


In [None]:
print(type(json_stopper), issubclass(type(json_stopper), StoppingCriteria))  # should be True


<class '__main__.SmartNewJsonStopper'> True


In [None]:
#@title Cell 7a Training - Instantiate the Trainer
# --- Trainer ---
trainer = GRPOTrainerWithEvalControls(
    model=model,
    reward_funcs=[reward_group_adapter],
    args=grpo_cfg,
    train_dataset=train_ds,
    eval_dataset=eval_ds_small,
    processing_class=tokenizer,
    eval_generation_kwargs=eval_generation_kwargs,
)

In [None]:
#@title Cell 7b Training - Custom generation stoppers and masks

import re, torch
from transformers import StoppingCriteria, StoppingCriteriaList, LogitsProcessor, LogitsProcessorList

#ASCII clamp
def _build_ascii_allow_mask(tokenizer) -> torch.BoolTensor:
    if getattr(tokenizer, "_ascii_allow_mask", None) is not None:
        return tokenizer._ascii_allow_mask
    V = tokenizer.vocab_size
    allow = torch.zeros(V, dtype=torch.bool)
    ascii_ok = re.compile(r'^[\x09\x0a\x0d\x20-\x7e]+$')
    specials = {
        getattr(tokenizer, k)
        for k in ("eos_token_id","pad_token_id","bos_token_id","unk_token_id")
        if getattr(tokenizer, k, None) is not None
    }
    for tid in range(V):
        if tid in specials:
            allow[tid] = True
            continue
        s = tokenizer.decode([tid], skip_special_tokens=True)
        if s == "" or ascii_ok.match(s):
            allow[tid] = True
    tokenizer._ascii_allow_mask = allow
    return allow

import re, torch
from transformers import LogitsProcessor, LogitsProcessorList

class AsciiClamp(LogitsProcessor):
    """
    Keeps generation in ASCII space by masking non-ASCII token pieces.
    Auto-rebuilds the mask to match the current logits vocab size.
    """
    def __init__(self, tokenizer):
        self.tok = tokenizer
        self._mask = None
        self._built_for = None
        self._neg_inf = torch.finfo(torch.float32).min
        self._ascii_ok = re.compile(r'^[\x09\x0a\x0d\x20-\x7e]+$')

    def _build_mask(self, V: int, device: torch.device):
        allow = torch.zeros(V, dtype=torch.bool)
        specials = {
            getattr(self.tok, k)
            for k in ("eos_token_id","pad_token_id","bos_token_id","unk_token_id")
            if getattr(self.tok, k, None) is not None
        }
        for tid in range(V):
            if tid in specials:
                allow[tid] = True
                continue
            try:
                s = self.tok.decode([tid], skip_special_tokens=True)
            except Exception:
                s = ""
            allow[tid] = (s == "" or self._ascii_ok.match(s) is not None)
        self._mask = allow.to(device)
        self._built_for = V

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        V = scores.shape[-1]
        if self._mask is None or self._built_for != V or self._mask.device != scores.device:
            self._build_mask(V, scores.device)
        scores[:, ~self._mask] = self._neg_inf
        return scores


#sanitizer
def _sanitize_sc(sc_like) -> StoppingCriteriaList:
    def _valid(c): return isinstance(c, StoppingCriteria) and (type(c) is not StoppingCriteria)
    out = []
    if sc_like is None:
        pass
    elif isinstance(sc_like, StoppingCriteriaList):
        out.extend([c for c in list(sc_like) if _valid(c)])
    elif isinstance(sc_like, (list, tuple, set)):
        for x in sc_like:
            if isinstance(x, StoppingCriteriaList):
                out.extend([c for c in list(x) if _valid(c)])
            elif _valid(x):
                out.append(x)
    elif _valid(sc_like):
        out.append(sc_like)
    seen, dedup = set(), []
    for c in out:
        if id(c) not in seen:
            dedup.append(c); seen.add(id(c))
    return StoppingCriteriaList(dedup)


In [None]:
#@title Cell 7c Training - Wrap generation call with custom stoppers and masks

import types

def _has_class_generate(obj) -> bool:
    try:
        return callable(getattr(obj.__class__, "generate"))
    except Exception:
        return False

def _resolve_generate_target(root):
    """
    Climb through common PEFT/Accelerate wrappers to find an object whose CLASS defines `generate`.
    Returns None if not found.
    """
    seen = set()
    def _walk(o):
        if o is None or id(o) in seen:
            return None
        seen.add(id(o))
        if _has_class_generate(o):
            return o
        # descend typical wrappers
        for attr in ("model", "module", "base_model"):
            nxt = getattr(o, attr, None)
            tgt = _walk(nxt)
            if tgt is not None:
                return tgt
        return None
    return _walk(root)

def wrap_generate_with_stopper_and_ascii(root_obj, tokenizer, stopper,
                                         *, add_ascii_in_train=True,
                                         min_new_tokens=24, max_new_tokens=350,
                                         temperature=0.2, top_p=0.9, top_k=0):
    """
    One wrapper: append JSON stopper (train+eval) + ASCII clamp (train only),
    set tight train defaults, strip unknown kwargs. Safe on PEFT/Accelerate stacks.
    """
    target = _resolve_generate_target(root_obj)
    if target is None or getattr(target, "_gen_guard_installed", False):
        return

    #reset
    target.generate = target.__class__.generate.__get__(target, target.__class__)
    class_orig = target.__class__.generate

    def _wrapped(self, *args, **kwargs):
        #sanitize
        sc = _sanitize_sc(kwargs.get("stopping_criteria", None))
        if all(c is not stopper for c in sc):
            sc.append(stopper)
        kwargs["stopping_criteria"] = sc

        #ASCII clamp
        if add_ascii_in_train and self.training:
            lp = kwargs.get("logits_processor", None)
            clamp = AsciiClamp(tokenizer)
            if lp is None:
                kwargs["logits_processor"] = LogitsProcessorList([clamp])
            else:
                try:
                    lp.append(clamp)
                    kwargs["logits_processor"] = lp
                except Exception:
                    kwargs["logits_processor"] = LogitsProcessorList(list(lp) + [clamp])

        #verify appropriate config defaults
        if self.training:
            kwargs.setdefault("do_sample", True)
            kwargs.setdefault("temperature", temperature)
            if top_k and top_k > 0:
                kwargs.setdefault("top_k", top_k); kwargs.setdefault("top_p", 1.0)
            else:
                kwargs.setdefault("top_p", top_p); kwargs.setdefault("top_k", 0)
            kwargs.setdefault("min_new_tokens", min_new_tokens)
            kwargs.setdefault("max_new_tokens", max_new_tokens)

        kwargs.pop("processing_class", None)
        kwargs.pop("generation_kwargs", None)

        return class_orig(self, *args, **kwargs)

    target.generate = _wrapped.__get__(target, target.__class__)
    target._gen_guard_installed = True
    print(f"[ok] stopper+ascii attached to {type(target).__name__}.generate")


In [None]:
#@title Cell 7d Training - Attach wrapper to trainer
wrap_generate_with_stopper_and_ascii(trainer.model, tokenizer, json_stopper)

if getattr(trainer, "accelerator", None):
    try:
        unwrapped = trainer.accelerator.unwrap_model(trainer.model)
        if unwrapped is not trainer.model:
            wrap_generate_with_stopper_and_ascii(unwrapped, tokenizer, json_stopper)
    except Exception:
        pass

if hasattr(trainer.model, "module"):
    wrap_generate_with_stopper_and_ascii(trainer.model.module, tokenizer, json_stopper)


[ok] stopper+ascii attached to PeftModelForCausalLM.generate


In [None]:
#@title Cell 8a Training Evaluation - Custom Graph Metrics For in-training evaluation

import json, re
from collections import defaultdict
from typing import List, Dict, Tuple, Set

def _safe_json(obj, max_depth=3):
    x = obj
    for _ in range(max_depth):
        if isinstance(x, (dict, list)): return x
        if x is None: return {}
        s = str(x).strip()
        i, j = s.find("{"), s.rfind("}")
        cand = s[i:j+1] if (i!=-1 and j!=-1 and j>i) else s
        try:
            x = json.loads(cand); continue
        except Exception:
            break
    return {}

def _extract_nodes(o: Dict) -> List[Dict]:
    arr = o.get("nodes", []) if isinstance(o, dict) else []
    if isinstance(arr, dict): arr = [arr]
    out=[]
    for it in arr:
        if isinstance(it, dict): out.append(it)
        elif isinstance(it, str):
            try:
                d=json.loads(it);
                if isinstance(d, dict): out.append(d)
            except: pass
    return out

def _extract_edges(o: Dict) -> List[Dict]:
    arr = o.get("edges", []) if isinstance(o, dict) else []
    if isinstance(arr, dict): arr = [arr]
    out=[]
    for it in arr:
        if isinstance(it, dict): out.append(it)
        elif isinstance(it, str):
            try:
                d=json.loads(it);
                if isinstance(d, dict): out.append(d)
            except: pass
    return out

_WS_RE  = re.compile(r"\s+")
_CO_RE  = re.compile(r"\b(inc\.?|ltd\.?|llc|l\.l\.c\.|corp\.?|co\.?|ag|gmbh)\b", re.I)
_PCT_RE = re.compile(r"(\d+(?:\.\d+)?)\s*%")
_MON_RE = re.compile(r"(\$|usd)\s*([\d,]+(?:\.\d+)?)", re.I)
_DAY_RE = re.compile(r"(\d+)\s*days?")
_YRS_RE = re.compile(r"(\d+)\s*years?")
_NUMW   = {"zero":0,"one":1,"two":2,"three":3,"four":4,"five":5,"six":6,"seven":7,"eight":8,"nine":9,
           "ten":10,"eleven":11,"twelve":12,"thirteen":13,"fourteen":14,"fifteen":15,"sixteen":16,
           "seventeen":17,"eighteen":18,"nineteen":19,"twenty":20,"thirty":30,"forty":40,"fifty":50,
           "sixty":60,"seventy":70,"eighty":80,"ninety":90}

def _norm(s):
    if not isinstance(s, str): s = str(s) if s is not None else ""
    s = s.lower().replace("&","and")
    s = _CO_RE.sub("", s)
    return _WS_RE.sub(" ", s).strip()

def _w2n(s):
    s=s.lower()
    for w,n in _NUMW.items():
        if re.search(rf"\b{w}\b", s): return n
    return None

def _canon_value(text):
    s=_norm(text)
    m=_PCT_RE.search(s)
    if m: return f"{float(m.group(1)):.0f}%"
    if "percent" in s:
        n=_w2n(s)
        if n is not None: return f"{n}%"
    m=_MON_RE.search(s)
    if m: return f"usd {m.group(2).replace(',','')}"
    m=_DAY_RE.search(s)
    if m: return f"{int(m.group(1))} days"
    m=_YRS_RE.search(s)
    if m: return f"{int(m.group(1))} years"
    return s

def _keytext(n: Dict) -> Tuple[str,str]:
    # supports "node_type" or "type"
    t=(n.get("node_type") or n.get("type") or "").upper()
    nid=str(n.get("id") or "")
    if t=="CLAUSE":
        k=n.get("id") or nid or n.get("title") or ""; return t,_norm(k)
    if t=="DEFINED_TERM":
        k=n.get("name") or nid.split(":",1)[-1]; return t,_norm(k)
    if t=="PARTY":
        k=n.get("name") or n.get("text") or nid.split(":",1)[-1]; return t,_norm(k)
    if t=="VALUE":
        k=n.get("text") or nid.split(":",1)[-1]; return t,_canon_value(k)
    k=n.get("text") or n.get("term") or n.get("name") or nid; return t,_norm(k)

def _toks(s): return re.findall(r"[a-z0-9]+", s.lower())
def _jacc(a,b):
    sa,sb=set(_toks(a)),set(_toks(b))
    if not sa and not sb: return 1.0 if a.strip()==b.strip() and a.strip()!="" else 0.0
    if not sa or not sb:  return 0.0
    return len(sa&sb)/max(1,len(sa|sb))

_THRESH={"CLAUSE":0.90,"DEFINED_TERM":0.80,"PARTY":0.85,"VALUE":0.75}
def _sim(t,a,b):
    if t=="VALUE":
        if a==b and (a.endswith("%") or a.endswith("days") or a.endswith("years") or a.startswith("usd")): return 1.0
    return _jacc(a,b)

def _bucket(nodes: List[Dict]):
    b=defaultdict(list)
    for n in nodes:
        t,kt=_keytext(n)
        if t and kt: b[t].append(kt)
    return b

def _match_type(G_list, P_list, t):
    if not G_list or not P_list: return 0
    pairs=[]
    for i,g in enumerate(G_list):
        for j,p in enumerate(P_list):
            s=_sim(t,g,p)
            if s>=_THRESH.get(t,0.8): pairs.append((s,i,j))
    pairs.sort(reverse=True)
    used_i=set(); used_j=set(); tp=0
    for s,i,j in pairs:
        if i in used_i or j in used_j: continue
        used_i.add(i); used_j.add(j); tp+=1
    return tp

def _prf1(tp,fp,fn):
    p=tp/(tp+fp) if (tp+fp) else 0.0
    r=tp/(tp+fn) if (tp+fn) else 0.0
    f=2*p*r/(p+r) if (p+r) else 0.0
    return p,r,f

def _edge_triplet(e, node_map):
    def pick(d, keys):
        return next((d[k] for k in keys if k in d and d[k] is not None), None)
    typ = (pick(e, ["type","edge_type","label"]) or "").upper()
    raw_src = pick(e, ["src","source","from"])
    raw_tgt = pick(e, ["tgt","target","to"])
    def resolve(v):
        if v is None: return ""
        v_str = str(v)
        return node_map.get(v_str, _norm(v_str))
    return (resolve(raw_src), typ, resolve(raw_tgt))

def compute_graph_metrics_on_texts(pred_texts: List[str], gold_texts: List[str]) -> Dict[str, float]:
    """
    Compute strict/fuzzy node micro-F1 and edge micro-F1 from decoded strings.
    """
    strict_tp=strict_fp=strict_fn=0
    fuzzy_tp=fuzzy_fp=fuzzy_fn=0
    e_tp=e_fp=e_fn=0
    exact=invalid=0

    def setify_strict(nodes):
        S=set()
        for nn in nodes:
            t,k=_keytext(nn)
            if t and k: S.add((t,k))
        return S

    for gstr,pstr in zip(gold_texts, pred_texts):
        G_json=_safe_json(gstr); P_json=_safe_json(pstr)
        G_nodes=_extract_nodes(G_json)
        P_nodes=_extract_nodes(P_json)

        if not isinstance(P_json, dict) or ("nodes" not in P_json and "edges" not in P_json):
            print(f"in calc metrics on eval {pstr}")
            invalid+=1

        # strict node sets
        Gs, Ps = setify_strict(G_nodes), setify_strict(P_nodes)
        if Gs==Ps: exact+=1
        strict_tp+=len(Gs&Ps); strict_fp+=len(Ps-Gs); strict_fn+=len(Gs-Ps)

        # fuzzy nodes
        Gb,_Pb = _bucket(G_nodes), _bucket(P_nodes)
        all_types = set(Gb.keys()) | set(_Pb.keys())
        for t in all_types:
            tp=_match_type(Gb.get(t,[]), _Pb.get(t,[]), t)
            fp=len(_Pb.get(t,[]))-tp; fn=len(Gb.get(t,[]))-tp
            fuzzy_tp+=tp; fuzzy_fp+=fp; fuzzy_fn+=fn

        # edges (map node ids -> canonical keytext, then compare triples)
        def build_map(nodes):
            m={}
            for n in nodes:
                nid=str(n.get("id") or "")
                if not nid: continue
                t,k=_keytext(n)
                m[nid]=f"{t}|{k}"
            return m

        Gmap, Pmap=build_map(G_nodes), build_map(P_nodes)
        Ge={_edge_triplet(e,Gmap) for e in _extract_edges(G_json)}
        Pe={_edge_triplet(e,Pmap) for e in _extract_edges(P_json)}
        e_tp+=len(Ge&Pe); e_fp+=len(Pe-Ge); e_fn+=len(Ge-Pe)

    sp,sr,sf1 = _prf1(strict_tp,strict_fp,strict_fn)
    fp_,fr_,ff1 = _prf1(fuzzy_tp,fuzzy_fp,fuzzy_fn)
    ep,er,ef1   = _prf1(e_tp,e_fp,e_fn)

    n=max(1,len(pred_texts))
    return {
        "strict_node_precision": sp, "strict_node_recall": sr, "strict_node_f1": sf1,
        "fuzzy_node_precision":  fp_, "fuzzy_node_recall": fr_, "fuzzy_node_f1": ff1,
        "edge_precision":        ep,  "edge_recall":       er,  "edge_f1":       ef1,
        "exact_graph_match_rate": exact/n,
        "invalid_json_rate":      invalid/n,
    }


In [None]:
#@title Cell 8b Training Evaluation -  Custom GRPO Evaluator Using Model Generation

import gc, time, math, numpy as np, torch, re
from typing import Dict, List
from tqdm import tqdm
from transformers import StoppingCriteria, StoppingCriteriaList
from transformers.trainer_utils import speed_metrics
import wandb

class JsonStopper(StoppingCriteria):
    """Batch-wide stopper for left-padded prompts; stops when nodes+edges JSON is complete."""
    def __init__(self, tokenizer, input_len: int):
        self.tokenizer = tokenizer
        self.input_len = int(input_len)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        gen_ids = input_ids[0, self.input_len:]
        if gen_ids.numel() == 0:
            return False
        text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)
        return bool(re.search(r'\{[^{}]*"nodes"\s*:\s*\[.*?\]\s*,\s*"edges"\s*:\s*\[.*?\]\s*\}', text, re.S))

def _safe_eos_ids(tok):
    eos = tok.eos_token_id
    return eos if isinstance(eos, list) else ([eos] if eos is not None else None)

def custom_eval_grpo(
    trainer,
    dataset,
    num_samples: int = 100,
    gen_batch_size: int = 8,
    max_new_tokens: int = 256,
    use_stop: bool = False,
    prefix: str = "gen",
    log_to_wandb: bool = True
) -> Dict[str, float]:
    """
    Deterministic eval for GRPO training:
      - left padding + fixed-width slice (no per-row masks)
      - one completion per prompt (do_sample=False)
      - computes trainer._compute_custom_graph_metrics on decoded texts
      - logs to W&B
    """
    tok   = trainer.processing_class
    model = trainer.model
    device = trainer.args.device

    n = len(dataset)
    use = min(num_samples, n)
    if use < n:
        idxs = np.random.choice(n, use, replace=False)
        ds = dataset.select(list(map(int, idxs)))
    else:
        ds = dataset

    prompts  = list(ds["prompt"])
    gold_key = "clean_completion" if "clean_completion" in ds.column_names else "completion"
    golds    = list(ds[gold_key])

    old_pad, old_trunc = tok.padding_side, getattr(tok, "truncation_side", "right")
    tok.padding_side = tok.truncation_side = "left"
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    EOS = _safe_eos_ids(tok)
    PAD = tok.pad_token_id or tok.eos_token_id

    model_was_training = model.training
    orig_use_cache = getattr(model.config, "use_cache", None)
    start = time.time()
    outs: List[str] = []

    try:
        model.eval()
        try: model.gradient_checkpointing_disable()
        except Exception: pass
        if orig_use_cache is not None:
            model.config.use_cache = True

        for i in tqdm(range(0, len(prompts), gen_batch_size), desc="[custom-eval] generate (SFT-style)"):
            batch_texts = prompts[i:i+gen_batch_size]
            batch = tok(batch_texts, return_tensors="pt", padding=True, truncation=True).to(device)
            input_len = batch["input_ids"].shape[1]
            stop = StoppingCriteriaList([JsonStopper(tok, input_len)]) if use_stop else StoppingCriteriaList([])

            with torch.no_grad():
                out = model.generate(
                    **batch,
                    max_new_tokens=max_new_tokens,
                    min_new_tokens=1,
                    do_sample=False, temperature=None, top_p=None, top_k=None,
                    use_cache=True,
                    eos_token_id=EOS, pad_token_id=PAD,
                    max_time=60,
                    stopping_criteria=stop,
                )

            gen = tok.batch_decode(out[:, input_len:], skip_special_tokens=True)
            outs.extend(gen)

            del out, batch
            torch.cuda.empty_cache(); gc.collect()

    finally:
        tok.padding_side, tok.truncation_side = old_pad, old_trunc
        if orig_use_cache is not None:
            model.config.use_cache = orig_use_cache
        if model_was_training:
            try:
                model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
            except TypeError:
                model.gradient_checkpointing_enable()
        torch.cuda.empty_cache(); gc.collect()

    #graph metrics
    custom = compute_graph_metrics_on_texts(outs, golds)


    total_batch_size = trainer.args.eval_batch_size * trainer.args.world_size
    num_steps = max(1, math.ceil(len(prompts) / max(1, total_batch_size)))
    speed = speed_metrics(prefix, start, num_samples=len(prompts), num_steps=num_steps)

    out = {f"{prefix}_{k}": v for k, v in custom.items()}
    out.update(speed)

    if log_to_wandb and wandb.run:
        wandb.log(out, step=trainer.state.global_step)


    if log_to_wandb and wandb.run:
        step = getattr(trainer.state, "global_step", None)
        log_payload = out.copy()
        if extra_tags:
            log_payload.update(extra_tags)
        wandb.log(log_payload, step=int(step) if step is not None else None)
    return out



In [None]:
#@title Cell 8c Training Evaluation -  Custom GRPO Evaluator Using Trainer Generation

def custom_eval_grpo_with_trainer_generate(
    trainer,
    dataset,
    num_samples: int = 50,
    gen_batch_size: int = 4,
    max_new_tokens: int = 350,
    prefix: str = "gen",
    log_to_wandb: bool = True,
):
    """
    Deterministic GRPO eval using trainer._generate.
    Strategy:
      1) Decode full sequences and try to cut AFTER the assistant header.
      2) If no assistant header is found, fall back to per-row prompt-length slicing.
    """
    import time, math, numpy as np, torch, gc
    from tqdm import tqdm
    from transformers.trainer_utils import speed_metrics
    import wandb

    def _cut_after_assistant(text: str) -> (str, bool):
        markers = [
            "<|start_header_id|>assistant<|end_header_id|>\n\n",
            "<|assistant|>\n\n",
            "<|assistant|>",
            "\nassistant\n\n",
            "assistant\n\n",
            "assistant\n",
        ]
        cut = -1
        for m in markers:
            j = text.rfind(m)
            if j >= 0:
                cut = max(cut, j + len(m))
        if cut >= 0:
            return text[cut:].lstrip(), True
        return text.lstrip(), False

    def _prompt_lengths(input_ids: torch.Tensor, attention_mask: torch.Tensor, tok) -> list:
        if attention_mask is not None:
            return attention_mask.sum(dim=1).tolist()
        pad_id = getattr(tok, "pad_token_id", None)
        if pad_id is not None:
            lens = []
            for row in input_ids.tolist():
                n = 0
                for t in row:
                    if t != pad_id:
                        n += 1
                lens.append(n)
            return lens
        return [int(input_ids.shape[1])] * input_ids.shape[0]

    tok   = trainer.processing_class
    model = trainer.model
    device = trainer.args.device

    n = len(dataset)
    use = min(num_samples, n)
    if use < n:
        idxs = np.random.choice(n, use, replace=False)
        ds = dataset.select(list(map(int, idxs)))
    else:
        ds = dataset

    prompts  = list(ds["prompt"])
    gold_key = "clean_completion" if "clean_completion" in ds.column_names else "completion"
    golds    = list(ds[gold_key])

    old_pad, old_trunc = tok.padding_side, getattr(tok, "truncation_side", "right")
    tok.padding_side = tok.truncation_side = "left"
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    trainer._eval_generation_kwargs.setdefault("do_sample", False)
    trainer._eval_generation_kwargs.setdefault("temperature", 0.0)
    trainer._eval_generation_kwargs["max_new_tokens"] = int(max_new_tokens)

    model_was_training = model.training
    orig_use_cache = getattr(model.config, "use_cache", None)
    start = time.time()
    outs = []

    try:
        model.eval()
        try: model.gradient_checkpointing_disable()
        except Exception: pass
        if orig_use_cache is not None:
            model.config.use_cache = True

        bs = max(1, int(gen_batch_size))
        for i in tqdm(range(0, len(prompts), bs), desc="[custom-eval] trainer.generate"):
            batch_texts = prompts[i:i+bs]
            batch = tok(batch_texts, return_tensors="pt", padding=True, truncation=True).to(device)

            with torch.no_grad():
                out = trainer._generate(
                    input_ids=batch["input_ids"],
                    attention_mask=batch.get("attention_mask", None),
                )

            full_txts = tok.batch_decode(out, skip_special_tokens=True)
            cut_txts, cut_flags = [], []
            for txt in full_txts:
                c, ok = _cut_after_assistant(txt)
                cut_txts.append(c); cut_flags.append(ok)

            if not all(cut_flags):
                lens = _prompt_lengths(batch["input_ids"], batch.get("attention_mask", None), tok)
                for row_idx, ok in enumerate(cut_flags):
                    if ok:
                        continue
                    suffix = tok.decode(out[row_idx, int(lens[row_idx]):], skip_special_tokens=True)
                    cut_txts[row_idx] = suffix.lstrip()

            outs.extend(cut_txts)

            del out, batch
            torch.cuda.empty_cache(); gc.collect()

    finally:
        tok.padding_side, tok.truncation_side = old_pad, old_trunc
        if orig_use_cache is not None:
            model.config.use_cache = orig_use_cache
        if model_was_training:
            try:
                model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
            except TypeError:
                model.gradient_checkpointing_enable()
        torch.cuda.empty_cache(); gc.collect()

    custom = compute_graph_metrics_on_texts(outs, golds)

    total_batch_size = trainer.args.eval_batch_size * trainer.args.world_size
    num_steps = max(1, math.ceil(len(prompts) / max(1, total_batch_size)))
    speed = speed_metrics(prefix, start, num_samples=len(prompts), num_steps=num_steps)

    out = {f"{prefix}_{k}": float(v) for k, v in custom.items()}
    print(f"out logged metrics: {out}")
    out.update(speed)
    return out


In [None]:
#@title Cell 8d Training Evaluation -  Context configs management for decoding during eval

from contextlib import contextmanager

@contextmanager
def eval_decode_settings(trainer, do_sample=False, temperature=0.0, max_new_tokens=256):
    old = dict(trainer.args.generation_kwargs or {})
    try:
        gk = dict(old)
        gk["do_sample"] = do_sample
        #gk["use_stop"] = True,
        gk["temperature"] = temperature
        gk["max_new_tokens"] = max_new_tokens
        trainer.args.generation_kwargs = gk
        yield
    finally:
        trainer.args.generation_kwargs = old



In [None]:
#@title Cell 8e Training Evaluation - Training Evaluation Loop

def train_eval_cycles(
    trainer,
    cycles: int = 4,
    steps_per_cycle: int = 4,
    eval_ds_small=None,
    num_samples_small: int = 10,
):
    assert eval_ds_small is not None and len(eval_ds_small) > 0, "Provide a small eval dataset."

    for c in range(cycles):
        target = trainer.state.global_step + steps_per_cycle
        trainer.args.max_steps = target
        print(f"\n==== Cycle {c+1}/{cycles}: training to global_step {target} ====")
        trainer.train()
        print(f"---- Cycle {c+1}: custom eval (subset={num_samples_small}, cap=300) ----")
        with eval_decode_settings(trainer, do_sample=False, temperature=0.0, max_new_tokens=300):
            m = custom_eval_grpo_with_trainer_generate(
                trainer,
                dataset=eval_ds_small,
                num_samples=num_samples_small,
                gen_batch_size=8,
                max_new_tokens=300,
                #use_stop=True,
                prefix="eval",
                log_to_wandb=True,
                #extra_tags={"cycle": c+1}   # NEW: log cycle as metadata
            )
        print(f"printing metrics {type(m)}")
        prefixed_metrics = {f"eval/{key}": value for key, value in m.items()}
        wandb.log(prefixed_metrics)

In [None]:
#@title Cell 9 Execute Full Training with eval cycles

train_eval_cycles(
    trainer,
    cycles=15,
    steps_per_cycle=8,
    eval_ds_small=eval_ds_small,
    num_samples_small=50
)


==== Cycle 1/15: training to global_step 8 ====


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


prompts=4, completions=4, step=0 rewards (first 8) [0.06666666666666667, 0.06666666666666667, 0.06666666666666667, 0.06666666666666667]
{'loss': 0.0, 'grad_norm': 0.0, 'learning_rate': 2e-06, 'num_tokens': 5376.0, 'completions/mean_length': 350.0, 'completions/min_length': 350.0, 'completions/max_length': 350.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_group_adapter/mean': 0.06666667014360428, 'rewards/reward_group_adapter/std': 0.0, 'reward': 0.06666667014360428, 'reward_std': 0.0, 'frac_reward_zero_std': 1.0, 'entropy': 5.509327411651611, 'clip_ratio/low_mean': 0.0, 'clip_ratio/low_min': 0.0, 'clip_ratio/high_mean': 0.0, 'clip_ratio/high_max': 0.0, 'clip_ratio/region_mean': 0.0, 'epoch': 0.0007621951219512195}
prompts=4, completions=4, step=1 rewards (first 8) [0.06666666666666667, 0.06666666666666667, 0.06666666666666667, 0.06666666666666667]
{'loss

[custom-eval] trainer.generate: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [04:17<00:00, 36.78s/it]


in calc metrics on eval {"contract_id": "NEOMIDADELITECHNOLOGIESINC_12_15_2005-EX-16.1-DISTRIBUTOR AGREEMENT", "nodes": [{"id": "18", "node_type": "CLAUSE", "title": "18", "level": 1}, {"id": "18.2", "node_type": "CLAUSE", "title": "18.2", "level": 2}, {"id": "party:Distributor", "node_type": "PARTY", "name": "Distributor"}, {"id": "party:Licensor", "node_type": "PARTY", "name": "Licensor"}, {"id": "term:Agreement", "node_type": "DEFINED_TERM", "name": "Agreement"}, {"id": "term:Term Of The Agreement", "node_type": "DEFINED_TERM", "name": "Term Of The Agreement"}], "edges": [{"src": "18.2", "tgt": "18", "type": "IS_PART_OF"}, {"src": "18.2", "tgt": "party:Distributor", "type": "MENTIONS_PARTY"}, {"src": "18.2", "tgt": "party:Licensor", "type": "MENTIONS_PARTY"}, {"src": "18.2", "tgt": "term:Agreement", "type": "USES"}, {"src": "18.2",
in calc metrics on eval {"contract_id": "INTERNATIONALFASTFOODCORP_04_04_1997-EX-99-FRANCHISE AGREEMENT", "nodes": [{"id": "5", "node_type": "CLAUSE", "t

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


prompts=4, completions=4, step=8 rewards (first 8) [0.06666666666666667, 0.06666666666666667, 0.06666666666666667, 0.06666666666666667]
{'loss': 0.0, 'grad_norm': 0.0, 'learning_rate': 2e-06, 'num_tokens': 5376.0, 'completions/mean_length': 350.0, 'completions/min_length': 350.0, 'completions/max_length': 350.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_group_adapter/mean': 0.06666667014360428, 'rewards/reward_group_adapter/std': 0.0, 'reward': 0.06666667014360428, 'reward_std': 0.0, 'frac_reward_zero_std': 1.0, 'entropy': 5.385732650756836, 'clip_ratio/low_mean': 0.0, 'clip_ratio/low_min': 0.0, 'clip_ratio/high_mean': 0.0, 'clip_ratio/high_max': 0.0, 'clip_ratio/region_mean': 0.0, 'epoch': 0.0007621951219512195}
prompts=4, completions=4, step=9 rewards (first 8) [0.06666666666666667, 0.06666666666666667, 0.06666666666666667, 0.06666666666666667]
{'loss

[custom-eval] trainer.generate: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [04:09<00:00, 35.69s/it]


in calc metrics on eval {"contract_id": "INTERNATIONALFASTFOODCORP_04_04_1997-EX-99-FRANCHISE AGREEMENT", "nodes": [{"id": "15", "node_type": "CLAUSE", "title": "15", "level": 1}, {"id": "15.1", "node_type": "CLAUSE", "title": "15.1", "level": 2}, {"id": "15.2", "node_type": "CLAUSE", "title": "15.2", "level": 2}, {"id": "15.3", "node_type": "CLAUSE", "title": "15.3", "level": 2}, {"id": "party:BKC", "node_type": "PARTY", "name": "BKC"}, {"id": "party:Transferor", "node_type": "PARTY", "name": "Transferor"}], "edges": [{"src": "15.3", "tgt": "15", "type": "IS_PART_OF"}, {"src": "15.3", "tgt": "15.1", "type": "REFERENCES"}, {"src": "15.3", "tgt": "15.2", "type": "REFERENCES"}, {"src": "15.3", "tgt": "party:BKC", "type": "MENTIONS_PARTY"}, {"src": "15.3", "
in calc metrics on eval {"contract_id": "INTERNATIONALFASTFOODCORP_04_04_1997-EX-99-FRANCHISE AGREEMENT", "nodes": [{"id": "5.3", "node_type": "CLAUSE", "title": "5.3", "level": 2}, {"id": "5.3.1", "node_type": "CLAUSE", "title": "5.3

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


prompts=4, completions=4, step=24 rewards (first 8) [0.06666666666666667, 0.06666666666666667, 0.13333333333333333, 0.06666666666666667]
{'loss': 0.0, 'grad_norm': 11.510505676269531, 'learning_rate': 2e-06, 'num_tokens': 5376.0, 'completions/mean_length': 350.0, 'completions/min_length': 350.0, 'completions/max_length': 350.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_group_adapter/mean': 0.0833333432674408, 'rewards/reward_group_adapter/std': 0.03333333507180214, 'reward': 0.0833333432674408, 'reward_std': 0.03333333507180214, 'frac_reward_zero_std': 0.0, 'entropy': 5.063674449920654, 'clip_ratio/low_mean': 0.0, 'clip_ratio/low_min': 0.0, 'clip_ratio/high_mean': 0.0, 'clip_ratio/high_max': 0.0, 'clip_ratio/region_mean': 0.0, 'epoch': 0.0007621951219512195}
prompts=4, completions=4, step=25 rewards (first 8) [0.06666666666666667, 0.06666666666666667, 0

[custom-eval] trainer.generate: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [04:34<00:00, 39.20s/it]


in calc metrics on eval {"contract_id": "INTERNATIONALFASTFOODCORP_04_04_1997-EX-99-FRANCHISE AGREEMENT", "nodes": [{"id": "17.2", "node_type": "CLAUSE", "title": "17.2", "level": 2}, {"id": "17.2.1", "node_type": "CLAUSE", "title": "17.2.1", "level": 3}, {"id": "defined_term:Event Of Bkc Default", "node_type": "DEFINED_TERM", "name": "Event Of Bkc Default"}, {"id": "defined_term:Bkc", "node_type": "DEFINED_TERM", "name": "Bkc"}, {"id": "defined_term:Agreement", "node_type": "DEFINED_TERM", "name": "Agreement"}, {"id": "value:sixty (60) days", "node_type": "VALUE", "unit": "Days", "text": "sixty (60) days"}], "edges": [{"src": "17.2.1", "tgt": "17.2", "type": "IS_PART_OF"}, {"src": "17.2.1", "tgt": "defined_term:Event Of Bkc Default", "type": "DEFINES"}, {"src": "17.2.1", "tgt": "defined_term:Bkc", "type": "USES"}, {"src":
in calc metrics on eval {"contract_id": "INTERNATIONALFASTFOODCORP_04_04_1997-EX-99-FRANCHISE AGREEMENT", "nodes": [{"id": "10.2", "node_type": "CLAUSE", "title": "A

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


prompts=4, completions=4, step=48 rewards (first 8) [0.09333333333333332, 0.04666666666666666, 0.04666666666666666, 0.04666666666666666]
{'loss': 0.0, 'grad_norm': 6.903642177581787, 'learning_rate': 2e-06, 'num_tokens': 5376.0, 'completions/mean_length': 350.0, 'completions/min_length': 350.0, 'completions/max_length': 350.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_group_adapter/mean': 0.05833333358168602, 'rewards/reward_group_adapter/std': 0.023333333432674408, 'reward': 0.05833333358168602, 'reward_std': 0.023333333432674408, 'frac_reward_zero_std': 0.0, 'entropy': 5.688028812408447, 'clip_ratio/low_mean': 0.0, 'clip_ratio/low_min': 0.0, 'clip_ratio/high_mean': 0.0, 'clip_ratio/high_max': 0.0, 'clip_ratio/region_mean': 0.0, 'epoch': 0.0007621951219512195}
prompts=4, completions=4, step=49 rewards (first 8) [0.04666666666666666, 0.04666666666666666

[custom-eval] trainer.generate: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [04:27<00:00, 38.26s/it]


in calc metrics on eval {"contract_id": "DIVERSINETCORP_03_01_2012-EX-4-RESELLER AGREEMENT", "nodes": [{"id": "", "node_type": "CLAUSE", "title": null, "level": -1}, {"id": "(a", "node_type": "CLAUSE", "title": ""'"s Request, Diversinet Will Provide Reseller With Pre-Sales Consulting And Post-", "level": 0}], {"id": "party:Diversinet", "node_type": "PARTY", "name": "Diversinet"}, {"id": "party:Reseller", "node_type": "PARTY", "name": "Reseller"}], "edges": [{"src": "(a", "tgt": "", "type": "IS_PART_OF"}, {"src": "(a", "tgt": "party:Diversinet", "type": "MENTIONS_PARTY"}, {"src": "(a", "tgt": "party:Reseller", "type": "MENTIONS_PARTY"}]}
in calc metrics on eval {"contract_id": "NEOMIDADELITECHNOLOGIESINC_12_15_2005-EX-16.1-DISTRIBUTOR AGREEMENT", "nodes": [{"id": "20", "node_type": "CLAUSE", "title": "20", "level": 1}, {"id": "20.4", "node_type": "CLAUSE", "title": "20.4", "level": 2}, {"id": "defined_term:Arbitration Committee", "node_type": "DEFINED_TERM", "name": "Arbitration Committ

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


prompts=4, completions=4, step=80 rewards (first 8) [0.02666666666666667, 0.02666666666666667, 0.02666666666666667, 0.02666666666666667]
{'loss': 0.0, 'grad_norm': 0.0, 'learning_rate': 2e-06, 'num_tokens': 5376.0, 'completions/mean_length': 350.0, 'completions/min_length': 350.0, 'completions/max_length': 350.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_group_adapter/mean': 0.02666666731238365, 'rewards/reward_group_adapter/std': 0.0, 'reward': 0.02666666731238365, 'reward_std': 0.0, 'frac_reward_zero_std': 1.0, 'entropy': 5.6024169921875, 'clip_ratio/low_mean': 0.0, 'clip_ratio/low_min': 0.0, 'clip_ratio/high_mean': 0.0, 'clip_ratio/high_max': 0.0, 'clip_ratio/region_mean': 0.0, 'epoch': 0.0007621951219512195}
prompts=4, completions=4, step=81 rewards (first 8) [0.02666666666666667, 0.02666666666666667, 0.02666666666666667, 0.02666666666666667]
{'loss

[custom-eval] trainer.generate: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [04:29<00:00, 38.54s/it]


in calc metrics on eval {"contract_id": "INTERNATIONALFASTFOODCORP_04_04_1997-EX-99-FRANCHISE AGREEMENT", "nodes": [{"id": "11.1", "node_type": "CLAUSE", "title": "11.1", "level": 2}, {"id": "11.1.6", "node_type": "CLAUSE", "title": "11.1.6", "level": 3}, {"id": "party:BKC", "node_type": "PARTY", "name": "BKC"}, {"id": "party:Franchisee", "node_type": "PARTY", "name": "Franchisee"}, {"id": "term:Burger King Marks", "node_type": "DEFINED_TERM", "name": "Burger King Marks"}, {"id": "term:Burger King System", "node_type": "DEFINED_TERM", "name": "Burger King System"}], "edges": [{"src": "11.1.6", "tgt": "11.1", "type": "IS_PART_OF"}, {"src": "11.1.6", "tgt": "party:BKC", "type": "MENTIONS_PARTY"}, {"src": "11.1.6", "tgt": "party:Franchisee", "type": "MENTIONS_PARTY"}, {"src": "11.1.6", "tgt": "
in calc metrics on eval {"contract_id": "INTERNATIONALFASTFOODCORP_04_04_1997-EX-99-FRANCHISE AGREEMENT", "nodes": [{"id": "11.1", "node_type": "CLAUSE", "title": "11.1", "level": 2}, {"id": "11.1.

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


prompts=4, completions=4, step=120 rewards (first 8) [0.013333333333333334, 0.02666666666666667, 0.013333333333333334, 0.013333333333333334]
{'loss': -0.0, 'grad_norm': 11.249459266662598, 'learning_rate': 2e-06, 'num_tokens': 5376.0, 'completions/mean_length': 350.0, 'completions/min_length': 350.0, 'completions/max_length': 350.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_group_adapter/mean': 0.01666666567325592, 'rewards/reward_group_adapter/std': 0.006666666828095913, 'reward': 0.01666666567325592, 'reward_std': 0.006666666362434626, 'frac_reward_zero_std': 0.0, 'entropy': 4.9155473709106445, 'clip_ratio/low_mean': 0.0, 'clip_ratio/low_min': 0.0, 'clip_ratio/high_mean': 0.0, 'clip_ratio/high_max': 0.0, 'clip_ratio/region_mean': 0.0, 'epoch': 0.0007621951219512195}
prompts=4, completions=4, step=121 rewards (first 8) [0.013333333333333334, 0.01333333

[custom-eval] trainer.generate: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [04:22<00:00, 37.57s/it]


in calc metrics on eval {"contract_id": "INTERNATIONALFASTFOODCORP_04_04_1997-EX-99-FRANCHISE AGREEMENT", "nodes": [{"id": "17.2", "node_type": "CLAUSE", "title": "17.2", "level": 2}, {"id": "19", "node_type": "CLAUSE", "title": "19", "level": 1}, {"id": "19.6", "node_type": "CLAUSE", "title": "19.6", "level": 2}, {"id": "party:BKC", "node_type": "PARTY", "name": "BKC"}, {"id": "party:Franchisee", "node_type": "PARTY", "name": "Franchisee"}, {"id": "term:Agreement", "node_type": "DEFINED_TERM", "name": "Agreement"}], "edges": [{"src": "19.6", "tgt": "17.2", "type": "REFERENCES"}, {"src": "19.6", "tgt": "19", "type": "IS_PART_OF"}, {"src": "19.6", "tgt": "party:BKC", "type": "MENTIONS_PARTY"}, {"src": "19.6", "tgt": "party:Franchisee", "type": "MENTIONS_PARTY"}, {"src": "19.6",
in calc metrics on eval {"contract_id": "INTERNATIONALFASTFOODCORP_04_04_1997-EX-99-FRANCHISE AGREEMENT", "nodes": [{"id": "11.1", "node_type": "CLAUSE", "title": "11.1", "level": 2}, {"id": "11.1.8", "node_type"

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


prompts=4, completions=4, step=168 rewards (first 8) [0.013333333333333334, 0.013333333333333334, 0.02666666666666667, 0.013333333333333334]
{'loss': -0.0, 'grad_norm': 8.392840385437012, 'learning_rate': 2e-06, 'num_tokens': 5376.0, 'completions/mean_length': 350.0, 'completions/min_length': 350.0, 'completions/max_length': 350.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_group_adapter/mean': 0.01666666567325592, 'rewards/reward_group_adapter/std': 0.006666666828095913, 'reward': 0.01666666567325592, 'reward_std': 0.006666666362434626, 'frac_reward_zero_std': 0.0, 'entropy': 5.403186321258545, 'clip_ratio/low_mean': 0.0, 'clip_ratio/low_min': 0.0, 'clip_ratio/high_mean': 0.0, 'clip_ratio/high_max': 0.0, 'clip_ratio/region_mean': 0.0, 'epoch': 0.0007621951219512195}
prompts=4, completions=4, step=169 rewards (first 8) [0.013333333333333334, 0.0133333333

[custom-eval] trainer.generate: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [04:26<00:00, 38.12s/it]


in calc metrics on eval {"contract_id": "INTERNATIONALFASTFOODCORP_04_04_1997-EX-99-FRANCHISE AGREEMENT", "nodes": [{"id": "5.3", "node_type": "CLAUSE", "title": "5.3", "level": 2}, {"id": "5.3.1", "node_type": "CLAUSE", "title": "5.3.1", "level": 3}, {"id": "party:BKC", "node_type": "PARTY", "name": "BKC"}, {"id": "party:Franchisee", "node_type": "PARTY", "name": "Franchisee"}, {"id": "party:Franchisor", "node_type": "PARTY", "name": "Franchisor"}, {"id": "term:Current Image", "node_type": "DEFINED_TERM", "name": "Current Image"}], "edges": [{"src": "5.3.1", "tgt": "5.3", "type": "IS_PART_OF"}, {"src": "5.3.1", "tgt": "party:BKC", "type": "MENTIONS_PARTY"}, {"src": "5.3.1", "tgt": "party:Franchisee", "type": "MENTIONS_PARTY"}, {"src": "5.3.1", "tgt": "party:Franch
in calc metrics on eval {"contract_id": "INTERNATIONALFASTFOODCORP_04_04_1997-EX-99-FRANCHISE AGREEMENT", "nodes": [{"id": "5", "node_type": "CLAUSE", "title": "5", "level": 1}, {"id": "5.1", "node_type": "CLAUSE", "title": 

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


prompts=4, completions=4, step=224 rewards (first 8) [0.02666666666666667, 0.013333333333333334, 0.013333333333333334, 0.013333333333333334]
{'loss': -0.0, 'grad_norm': 7.193343162536621, 'learning_rate': 2e-06, 'num_tokens': 5376.0, 'completions/mean_length': 350.0, 'completions/min_length': 350.0, 'completions/max_length': 350.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_group_adapter/mean': 0.01666666567325592, 'rewards/reward_group_adapter/std': 0.006666666828095913, 'reward': 0.01666666567325592, 'reward_std': 0.006666666362434626, 'frac_reward_zero_std': 0.0, 'entropy': 5.374517917633057, 'clip_ratio/low_mean': 0.0, 'clip_ratio/low_min': 0.0, 'clip_ratio/high_mean': 0.0, 'clip_ratio/high_max': 0.0, 'clip_ratio/region_mean': 0.0, 'epoch': 0.0007621951219512195}
prompts=4, completions=4, step=225 rewards (first 8) [0.013333333333333334, 0.0133333333

[custom-eval] trainer.generate: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [04:20<00:00, 37.18s/it]


in calc metrics on eval {"contract_id": "INTERNATIONALFASTFOODCORP_04_04_1997-EX-99-FRANCHISE AGREEMENT", "nodes": [{"id": "5.3", "node_type": "CLAUSE", "title": "5.3", "level": 2}, {"id": "5.3.2", "node_type": "CLAUSE", "title": "5.3.2", "level": 3}, {"id": "party:BKC", "node_type": "PARTY", "name": "BKC"}, {"id": "party:Franchisee", "node_type": "PARTY", "name": "Franchisee"}, {"id": "term:Agreement", "node_type": "DEFINED_TERM", "name": "Agreement"}, {"id": "term:Frischised Restaurant", "node_type": "DEFINED_TERM", "name": "Frischised Restaurant"}, {"id": "term:Term", "node_type": "DEFINED_TERM", "name": "Term"}], "edges": [{"src": "5.3.2", "tgt": "5.3", "type": "IS_PART_OF"}, {"src": "5.3.2", "tgt": "party:BKC", "type": "MENTIONS_PARTY"}, {"src": "5.3.2", "tgt": "party:Franchisee", "type
in calc metrics on eval {"contract_id": "INTERNATIONALFASTFOODCORP_04_04_1997-EX-99-FRANCHISE AGREEMENT", "nodes": [{"id": "10.2", "node_type": "CLAUSE", "title": "ANNUAL FINANCIAL STATEMENT", "lev

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


prompts=4, completions=4, step=288 rewards (first 8) [0.013333333333333334, 0.013333333333333334, 0.013333333333333334, 0.013333333333333334]
{'loss': 0.0, 'grad_norm': 0.0, 'learning_rate': 2e-06, 'num_tokens': 5376.0, 'completions/mean_length': 350.0, 'completions/min_length': 350.0, 'completions/max_length': 350.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_group_adapter/mean': 0.013333333656191826, 'rewards/reward_group_adapter/std': 0.0, 'reward': 0.013333333656191826, 'reward_std': 0.0, 'frac_reward_zero_std': 1.0, 'entropy': 5.367894649505615, 'clip_ratio/low_mean': 0.0, 'clip_ratio/low_min': 0.0, 'clip_ratio/high_mean': 0.0, 'clip_ratio/high_max': 0.0, 'clip_ratio/region_mean': 0.0, 'epoch': 0.0007621951219512195}
prompts=4, completions=4, step=289 rewards (first 8) [0.013333333333333334, 0.013333333333333334, 0.013333333333333334, 0.013333333333

[custom-eval] trainer.generate: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [04:34<00:00, 39.24s/it]


in calc metrics on eval {"contract_id": "INTERNATIONALFASTFOODCORP_04_04_1997-EX-99-FRANCHISE AGREEMENT", "nodes": [{"id": "10.2", "node_type": "CLAUSE", "title": "ANNUAL FINANCIAL STATEMENT", "level": 2}, {"id": "10", "node_type": "CLAUSE", "title": "", "level": 1}, {"id": "14", "node_type": "CLAUSE", "title": "", "level": 0}, {"id": "party:Franchisee", "node_type": "PARTY", "name": "Franchisee"}, {"id": "party:Certified Public Accountant", "node_type": "PARTY", "name": "Certified Public Accountant"}, {"id": "defined_term:Fiscal Year", "node_type": "DEFINED_TERM", "name": "Fiscal Year"}, {"id": "defined_term:Franchisee", "node_type": "DEFINED_TERM", "name": "Franchisee"}, {"id": "defined_term:Franchised Restaurant", "node_type": "DEFINED_TERM", "name": "Franchised Restaurant"}, {"id": "value:ninety (90) days", "node_type": "VALUE", "unit": "Days", "text": "ninety (90) days"}], "edges": [{"src": "10
in calc metrics on eval {"contract_id": "INTERNATIONALFASTFOODCORP_04_04_1997-EX-99-FRA

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


prompts=4, completions=4, step=360 rewards (first 8) [0.013333333333333334, 0.013333333333333334, 0.013333333333333334, 0.013333333333333334]
{'loss': 0.0, 'grad_norm': 0.0, 'learning_rate': 2e-06, 'num_tokens': 5376.0, 'completions/mean_length': 350.0, 'completions/min_length': 350.0, 'completions/max_length': 350.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_group_adapter/mean': 0.013333333656191826, 'rewards/reward_group_adapter/std': 0.0, 'reward': 0.013333333656191826, 'reward_std': 0.0, 'frac_reward_zero_std': 1.0, 'entropy': 5.6517333984375, 'clip_ratio/low_mean': 0.0, 'clip_ratio/low_min': 0.0, 'clip_ratio/high_mean': 0.0, 'clip_ratio/high_max': 0.0, 'clip_ratio/region_mean': 0.0, 'epoch': 0.0007621951219512195}
prompts=4, completions=4, step=361 rewards (first 8) [0.013333333333333334, 0.013333333333333334, 0.013333333333333334, 0.01333333333333

[custom-eval] trainer.generate: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [04:28<00:00, 38.40s/it]


in calc metrics on eval {"contract_id": "INTERNATIONALFASTFOODCORP_04_04_1997-EX-99-FRANCHISE AGREEMENT", "nodes": [{"id": "5.3", "node_type": "CLAUSE", "title": "5.3", "level": 2}, {"id": "5.3.2", "node_type": "CLAUSE", "title": "5.3.2", "level": 3}, {"id": "party:BKC", "node_type": "PARTY", "name": "BKC"}, {"id": "party:Franchisee", "node_type": "PARTY", "name": "Franchisee"}, {"id": "term:Agreement", "node_type": "DEFINED_TERM", "name": "Agreement"}, {"id": "term:Frischised Restaurant", "node_type": "DEFINED_TERM", "name": "Frischised Restaurant"}, {"id": "term:Term", "node_type": "DEFINED_TERM", "name": "Term"}], "edges": [{"src": "5.3.2", "tgt": "5.3", "type": "IS_PART_OF"}, {"src": "5.3.2", "tgt": "party:BKC", "type": "MENTIONS_PARTY"}, {"src": "5.3.2", "tgt": "party:Franchisee", "type
in calc metrics on eval {"contract_id": "NEOMIDADELITECHNOLOGIESINC_12_15_2005-EX-16.1-DISTRIBUTOR AGREEMENT", "nodes": [{"id": "18", "node_type": "CLAUSE", "title": "18", "level": 1}, {"id": "18.2

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


prompts=4, completions=4, step=440 rewards (first 8) [0.013333333333333334, 0.013333333333333334, 0.013333333333333334, 0.013333333333333334]
{'loss': 0.0, 'grad_norm': 0.0, 'learning_rate': 2e-06, 'num_tokens': 5376.0, 'completions/mean_length': 350.0, 'completions/min_length': 350.0, 'completions/max_length': 350.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_group_adapter/mean': 0.013333333656191826, 'rewards/reward_group_adapter/std': 0.0, 'reward': 0.013333333656191826, 'reward_std': 0.0, 'frac_reward_zero_std': 1.0, 'entropy': 4.997908592224121, 'clip_ratio/low_mean': 0.0, 'clip_ratio/low_min': 0.0, 'clip_ratio/high_mean': 0.0, 'clip_ratio/high_max': 0.0, 'clip_ratio/region_mean': 0.0, 'epoch': 0.0007621951219512195}
prompts=4, completions=4, step=441 rewards (first 8) [0.013333333333333334, 0.013333333333333334, 0.013333333333333334, 0.026666666666

[custom-eval] trainer.generate: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [04:20<00:00, 37.20s/it]


in calc metrics on eval {"contract_id": "INTERNATIONALFASTFOODCORP_04_04_1997-EX-99-FRANCHISE AGREEMENT", "nodes": [{"id": "15.5.2", "node_type": "CLAUSE", "title": "15.5.2", "level": 3}, {"id": "15.5.2.2", "node_type": "CLAUSE", "title": "15.5.2.2", "level": 4}, {"id": "party:BKC", "node_type": "PARTY", "name": "BKC"}, {"id": "party:Franchisee", "node_type": "PARTY", "name": "Franchisee"}, {"id": "term:Securities Exchange Act Of 1934", "node_type": "DEFINED_TERM", "name": "Securities Exchange Act Of 1934"}], "edges": [{"src": "15.5.2.2", "tgt": "15.5.2", "type": "IS_PART_OF"}, {"src": "15.5.2.2", "tgt": "party:BKC", "type": "MENTIONS_PARTY"}, {"src": "15.5.2.2", "tgt": "party:Franchisee", "type": "MENTIONS_PARTY"}, {"src": "15.5.2.2", "tgt": "
in calc metrics on eval {"contract_id": "INTERNATIONALFASTFOODCORP_04_04_1997-EX-99-FRANCHISE AGREEMENT", "nodes": [{"id": "11.1", "node_type": "CLAUSE", "title": "11.1", "level": 2}, {"id": "11.1.8", "node_type": "CLAUSE", "title": "11.1.8", "l

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


prompts=4, completions=4, step=528 rewards (first 8) [0.013333333333333334, 0.013333333333333334, 0.013333333333333334, 0.013333333333333334]
{'loss': 0.0, 'grad_norm': 0.0, 'learning_rate': 2e-06, 'num_tokens': 5376.0, 'completions/mean_length': 350.0, 'completions/min_length': 350.0, 'completions/max_length': 350.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_group_adapter/mean': 0.013333333656191826, 'rewards/reward_group_adapter/std': 0.0, 'reward': 0.013333333656191826, 'reward_std': 0.0, 'frac_reward_zero_std': 1.0, 'entropy': 5.765183925628662, 'clip_ratio/low_mean': 0.0, 'clip_ratio/low_min': 0.0, 'clip_ratio/high_mean': 0.0, 'clip_ratio/high_max': 0.0, 'clip_ratio/region_mean': 0.0, 'epoch': 0.0007621951219512195}
prompts=4, completions=4, step=529 rewards (first 8) [0.013333333333333334, 0.013333333333333334, 0.013333333333333334, 0.013333333333

[custom-eval] trainer.generate: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [04:27<00:00, 38.22s/it]


in calc metrics on eval {"contract_id": "INTERNATIONALFASTFOODCORP_04_04_1997-EX-99-FRANCHISE AGREEMENT", "nodes": [{"id": "11.1", "node_type": "CLAUSE", "title": "11.1", "level": 2}, {"id": "11.1.8", "node_type": "CLAUSE", "title": "11.1.8", "level": 3}, {"id": "party:BKC", "node_type": "PARTY", "name": "BKC"}, {"id": "party:Franchisee", "node_type": "PARTY", "name": "Franchisee"}, {"id": "term:Burger King Marks", "node_type": "DEFINED_TERM", "name": "Burger King Marks"}, {"id": "term:Exhibit A", "node_type": "DEFINED_TERM", "name": "Exhibit A"}], "edges": [{"src": "11.1.8", "tgt": "11.1", "type": "IS_PART_OF"}, {"src": "11.1.8", "tgt": "party:BKC", "type": "MENTIONS_PARTY"}, {"src": "11.1.8", "tgt": "party:Franchisee", "type": "MENTIONS_PARTY"}, {"src": "11.1.8", "tgt": "term
in calc metrics on eval {"contract_id": "INTERNATIONALFASTFOODCORP_04_04_1997-EX-99-FRANCHISE AGREEMENT", "nodes": [{"id": "11.1", "node_type": "CLAUSE", "title": "11.1", "level": 2}, {"id": "11.1.6", "node_type

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


prompts=4, completions=4, step=624 rewards (first 8) [0.013333333333333334, 0.013333333333333334, 0.013333333333333334, 0.013333333333333334]
{'loss': 0.0, 'grad_norm': 0.0, 'learning_rate': 2e-06, 'num_tokens': 5376.0, 'completions/mean_length': 350.0, 'completions/min_length': 350.0, 'completions/max_length': 350.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_group_adapter/mean': 0.013333333656191826, 'rewards/reward_group_adapter/std': 0.0, 'reward': 0.013333333656191826, 'reward_std': 0.0, 'frac_reward_zero_std': 1.0, 'entropy': 5.24448823928833, 'clip_ratio/low_mean': 0.0, 'clip_ratio/low_min': 0.0, 'clip_ratio/high_mean': 0.0, 'clip_ratio/high_max': 0.0, 'clip_ratio/region_mean': 0.0, 'epoch': 0.0007621951219512195}
prompts=4, completions=4, step=625 rewards (first 8) [0.013333333333333334, 0.013333333333333334, 0.013333333333333334, 0.0133333333333

[custom-eval] trainer.generate: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [04:19<00:00, 37.13s/it]


in calc metrics on eval {"contract_id": "INTERNATIONALFASTFOODCORP_04_04_1997-EX-99-FRANCHISE AGREEMENT", "nodes": [{"id": "15.5.2", "node_type": "CLAUSE", "title": "15.5.2", "level": 3}, {"id": "15.5.2.2", "node_type": "CLAUSE", "title": "15.5.2.2", "level": 4}, {"id": "party:BKC", "node_type": "PARTY", "name": "BKC"}, {"id": "party:Franchisee", "node_type": "PARTY", "name": "Franchisee"}, {"id": "term:Securities Exchange Act Of 1934", "node_type": "DEFINED_TERM", "name": "Securities Exchange Act Of 1934"}], "edges": [{"src": "15.5.2.2", "tgt": "15.5.2", "type": "IS_PART_OF"}, {"src": "15.5.2.2", "tgt": "party:BKC", "type": "MENTIONS_PARTY"}, {"src": "15.5.2.2", "tgt": "party:Franchisee", "type": "MENTIONS_PARTY"}, {"src": "15.5.2.2", "tgt": "
in calc metrics on eval {"contract_id": "INTERNATIONALFASTFOODCORP_04_04_1997-EX-99-FRANCHISE AGREEMENT", "nodes": [{"id": "11.1", "node_type": "CLAUSE", "title": "11.1", "level": 2}, {"id": "11.1.8", "node_type": "CLAUSE", "title": "11.1.8", "l

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


prompts=4, completions=4, step=728 rewards (first 8) [0.013333333333333334, 0.013333333333333334, 0.013333333333333334, 0.013333333333333334]
{'loss': 0.0, 'grad_norm': 0.0, 'learning_rate': 2e-06, 'num_tokens': 5376.0, 'completions/mean_length': 350.0, 'completions/min_length': 350.0, 'completions/max_length': 350.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_group_adapter/mean': 0.013333333656191826, 'rewards/reward_group_adapter/std': 0.0, 'reward': 0.013333333656191826, 'reward_std': 0.0, 'frac_reward_zero_std': 1.0, 'entropy': 4.416050910949707, 'clip_ratio/low_mean': 0.0, 'clip_ratio/low_min': 0.0, 'clip_ratio/high_mean': 0.0, 'clip_ratio/high_max': 0.0, 'clip_ratio/region_mean': 0.0, 'epoch': 0.0007621951219512195}
prompts=4, completions=4, step=729 rewards (first 8) [0.013333333333333334, 0.013333333333333334, 0.013333333333333334, 0.013333333333

[custom-eval] trainer.generate: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [04:21<00:00, 37.30s/it]


in calc metrics on eval {"contract_id": "NEOMIDADELITECHNOLOGIESINC_12_15_2005-EX-16.1-DISTRIBUTOR AGREEMENT", "nodes": [{"id": "18", "node_type": "CLAUSE", "title": "18", "level": 1}, {"id": "18.2", "node_type": "CLAUSE", "title": "18.2", "level": 2}, {"id": "party:Distributor", "node_type": "PARTY", "name": "Distributor"}, {"id": "party:Licensor", "node_type": "PARTY", "name": "Licensor"}, {"id": "term:Agreement", "node_type": "DEFINED_TERM", "name": "Agreement"}, {"id": "term:Term Of The Agreement", "node_type": "DEFINED_TERM", "name": "Term Of The Agreement"}, {"id": "value:thirty(30) days", "node_type": "VALUE", "unit": "Days", "text": "thirty(30) days"}, {"id": "value:sixty(60) days", "node_type": "VALUE", "unit": "Days", "text": "sixty(60) days"}], "edges": [{"src": "18.2", "tgt": "18", "type": "IS_PART_OF"}, {"src": "18.2",
in calc metrics on eval {"contract_id": "NEOMIDADELITECHNOLOGIESINC_12_15_2005-EX-16.1-DISTRIBUTOR AGREEMENT", "nodes": [{"id": "20", "node_type": "CLAUSE",

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


prompts=4, completions=4, step=840 rewards (first 8) [0.013333333333333334, 0.013333333333333334, 0.013333333333333334, 0.013333333333333334]
{'loss': 0.0, 'grad_norm': 0.0, 'learning_rate': 2e-06, 'num_tokens': 5376.0, 'completions/mean_length': 350.0, 'completions/min_length': 350.0, 'completions/max_length': 350.0, 'completions/clipped_ratio': 1.0, 'completions/mean_terminated_length': 0.0, 'completions/min_terminated_length': 0.0, 'completions/max_terminated_length': 0.0, 'rewards/reward_group_adapter/mean': 0.013333333656191826, 'rewards/reward_group_adapter/std': 0.0, 'reward': 0.013333333656191826, 'reward_std': 0.0, 'frac_reward_zero_std': 1.0, 'entropy': 4.986189365386963, 'clip_ratio/low_mean': 0.0, 'clip_ratio/low_min': 0.0, 'clip_ratio/high_mean': 0.0, 'clip_ratio/high_max': 0.0, 'clip_ratio/region_mean': 0.0, 'epoch': 0.0007621951219512195}
prompts=4, completions=4, step=841 rewards (first 8) [0.013333333333333334, 0.013333333333333334, 0.013333333333333334, 0.013333333333

[custom-eval] trainer.generate: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [04:26<00:00, 38.13s/it]

in calc metrics on eval {"contract_id": "NEOMIDADELCOLOGIESINC_12_15_2005-EX-16.1-DISTRIBUTOR AGREEMENT", "nodes": [{"id": "1", "node_type": "CLAUSE", "title": "1", "level": 1}, {"id": "1.3", "node_type": "CLAUSE", "title": "1.3", "level": 2}, {"id": "party:Distributor", "node_type": "PARTY", "name": "Distributor"}, {"id": "party:Ppg Shanghai", "node_type": "PARTY", "name": "Ppg Shanghai"}, {"id": "term:Products", "node_type": "DEFINED_TERM", "name": "Products"}, {"id": "term:TERRITORY", "node_type": "DEFINED_TERM", "name": "Territory"}], "edges": [{"src": "1.3", "tgt": "1", "type": "IS_PART_OF"}, {"src": "1.3", "tgt": "party:Distributor", "type": "MENTIONS_PARTY"}, {"src": "1.3", "tgt": "party:Ppg Shanghai", "type": "MENTIONS_PARTY"}, {"src": "1.3", "tgt": "term:Products", "type": "USES"}, {"src": "1.3", "tgt": "term:TERR
in calc metrics on eval {"contract_id": "NEOMIDADEL TECHNOLOGIES INC_12_15_2005-EX-16.1-DISTRIBUTOR AGREEMENT", "nodes": [{"id": "20", "node_type": "CLAUSE", "title"




In [None]:
#@title Cell 10 Training Checkpointing - Save model and adapter

#base model with merged LoRA adapters
output_dir = "./checkpoints/grpo_final_embed_stopper_gated_final"

#save adapter
if hasattr(model, "save_pretrained"):
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

#save QLoRA weights
if hasattr(model, "peft_config"):
    model.save_pretrained(output_dir, save_embedding_layers=True)

#save trainer
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)
model.save_pretrained("./adapters")


In [None]:
wandb.finish()