In [None]:
# ======================================================================
# Step 1: Environment Setup (Definitive Fix for CUDA Binaries)
# ======================================================================
# This block is designed to be run once at the start of the session.
# It creates a clean slate and installs all dependencies in a specific
# order to ensure compiled libraries like Triton and bitsandbytes link
# correctly with the GPU environment.

import os

# --- Stage 1: Aggressive Uninstall ---
# Remove all potentially conflicting libraries to ensure a clean slate.
print("Stage 1/5: Aggressively uninstalling conflicting libraries...")
!pip uninstall -y torch torchvision torchaudio numpy pandas scikit-learn transformers datasets accelerate peft trl bitsandbytes triton > /dev/null 2>&1

# --- Stage 2: Install Core Foundation (Torch + NumPy) ---
# This provides a stable base for all other libraries.
print("Stage 2/5: Installing core foundation (PyTorch for CUDA 12.1 and NumPy)...")
!pip install -q \
  "numpy==2.0.0" \
  "torch==2.3.0" \
  "torchaudio==2.3.0" \
  "torchvision==0.18.0" --index-url https://download.pytorch.org/whl/cu121

# --- Stage 3: Install Compiled GPU Libraries (Triton and bitsandbytes) ---
# FIX: Install Triton explicitly *before* transformers.
# This is critical as transformers and bitsandbytes depend on it.
print("Stage 3/5: Installing compiled GPU libraries (Triton and bitsandbytes)...")
!pip install -q "triton==2.3.0" "bitsandbytes==0.43.3"

# --- Stage 4: Install the Hugging Face Ecosystem ---
# With the low-level libraries in place, we can now install the HF stack.
print("Stage 4/5: Installing the Hugging Face ecosystem...")
!pip install -q \
  "transformers==4.43.3" \
  "datasets==2.20.0" \
  "peft==0.12.0" \
  "accelerate==0.33.0" \
  "trl==0.9.6" \
  "wandb" \
  "scikit-learn"

# --- Stage 5: Verification and Restart ---
# Verify that all key libraries are installed correctly.
print("\n--- Verifying installed versions post-install ---")
!pip list | grep -E 'numpy|torch|triton|bitsandbytes|transformers|datasets|peft'

# CRITICAL: Force the runtime to restart. This is non-negotiable for compiled
# libraries to be correctly loaded into the session.
print("\nEnvironment setup complete. Forcing a runtime restart to load new libraries...")
os._exit(0)

Stage 1/5: Aggressively uninstalling conflicting libraries...
Stage 2/5: Installing core foundation (PyTorch for CUDA 12.1 and NumPy)...
[31mERROR: Could not find a version that satisfies the requirement numpy==2.0.0 (from versions: 1.26.2, 1.26.3, 2.1.2)[0m[31m
[0m[31mERROR: No matching distribution found for numpy==2.0.0[0m[31m
[0mStage 3/5: Installing compiled GPU libraries (Triton and bitsandbytes)...
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
umap-learn 0.5.9.post2 requires scikit-learn>=1.6, which is not installed.
dask-cudf-cu12 25.6.0 requires pandas<2.2.4dev0,>=2.0, which is not installed.
mlxtend 0.23.4 requires pandas>=0.24.2, which is not installed.
mlxtend 0.23.4 requires scikit-learn>=1.3.1, which is not installed.
geopandas 1.1.1 requires pandas>=2.0.0, which is not installed.
sklearn-pandas 2.2.0 requires pandas>=1.1.4, which i

In [8]:
# ======================================================================
# Step 2: Imports, Configuration, and W&B Login
# ======================================================================
import json
import os
import torch
import wandb
from functools import partial
from collections import defaultdict, Counter
import numpy as np

from datasets import load_dataset, Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments
)
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from transformers import TrainerCallback

# --- Configuration ---
HF_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
HF_DATASET_ID = "moriyad/clause_minigraph_builder_final"
MAX_SEQ_LEN = 4096

# --- W&B Login ---
# It's recommended to set the WANDB_API_KEY in Colab Secrets
try: wandb.finish()
except: pass

# set project/name BEFORE trainer creation
os.environ["WANDB_PROJECT"] = "llama31-minigraph-sys-harmonic"
run_name = "llama31-minigraph-sys-harmonic"

try:
    from google.colab import userdata
    wandb.login(key=userdata.get('WANDB_API_KEY'))

except (ImportError, userdata.SecretNotFoundError):
    print("W&B key not found in Colab secrets. Please log in manually.")
    wandb.login()
    wandb.init(project=os.environ["WANDB_PROJECT"], name=run_name, reinit=True)

# --- Device Check ---
assert torch.cuda.is_available(), "CUDA is not available. Please ensure you have a GPU runtime."
print(f"GPU detected: {torch.cuda.get_device_name(0)}")

0,1
custom_eval/gen_edge_f1,▁█
custom_eval/gen_edge_precision,▁█
custom_eval/gen_edge_recall,▁█
custom_eval/gen_node_f1,▁█
custom_eval/gen_node_precision,▁█
custom_eval/gen_node_recall,▁█
custom_eval/gen_num_samples,▁▁
eval/loss,█▁
eval/runtime,▁█
eval/samples_per_second,█▁

0,1
custom_eval/gen_edge_f1,0.66279
custom_eval/gen_edge_precision,0.7037
custom_eval/gen_edge_recall,0.62637
custom_eval/gen_node_f1,0.78392
custom_eval/gen_node_precision,0.82979
custom_eval/gen_node_recall,0.74286
custom_eval/gen_num_samples,16.0
eval/loss,0.02815
eval/runtime,60.7483
eval/samples_per_second,3.934


W&B key not found in Colab secrets. Please log in manually.


GPU detected: NVIDIA A100-SXM4-40GB


In [9]:
import json

# --- Robust JSON Parsing ---
def deep_safe_json(x, max_depth=3):
    """
    Handles potentially double-encoded JSON strings by attempting to decode
    up to `max_depth` times. Returns a dict or list if successful.
    """
    obj = x
    for _ in range(max_depth):
        if isinstance(obj, (dict, list)):
            return obj
        if isinstance(obj, str):
            s = obj.strip()
            b, e = s.find("{"), s.rfind("}")
            cand = s[b:e+1] if (b != -1 and e != -1 and e > b) else s
            try:
                obj = json.loads(cand)
                continue
            except (json.JSONDecodeError, TypeError):
                break
        break
    return {} # Return an empty dict if all parsing fails

# --- Combined Data Cleaning and Refactoring Function ---
def refactor_dataset(example):
    """
    This single function handles all the cleaning and restructuring.
    1. Parses the 'prompt' string into its instruction and input parts.
    2. Parses and cleans the 'completion' string.
    3. Returns a new structure with clean, separated fields.
    """
    # --- Process the Prompt ---
    prompt_obj = deep_safe_json(example["prompt"])

    # Extract the instruction and input parts
    instruction = prompt_obj.get("instruction", "")
    input_data = prompt_obj.get("input", {})

    # Store the clean, separated parts
    example['clean_instruction'] = instruction
    example['clean_input'] = json.dumps(input_data, ensure_ascii=False)

    # --- Process the Completion ---
    completion_obj = deep_safe_json(example["completion"])

    # Ensure 'nodes' and 'edges' keys exist and are lists
    if "nodes" not in completion_obj or not isinstance(completion_obj.get("nodes"), list):
        completion_obj["nodes"] = []
    if "edges" not in completion_obj or not isinstance(completion_obj.get("edges"), list):
        completion_obj["edges"] = []

    # Store the clean completion as a canonical JSON string
    example['clean_completion'] = json.dumps(completion_obj, ensure_ascii=False)

    return example

# --- Load and Refactor the Dataset ---
print("Loading and refactoring the dataset...")
dataset = load_dataset(HF_DATASET_ID)

# Apply the single cleaning function to all splits
# Chaining the operations is more efficient
cleaned_dataset = dataset.map(refactor_dataset, num_proc=2)

print("\nDataset after refactoring:")
print(cleaned_dataset)

# --- Define Chat Template Formatting ---
def create_chat_format(example, tokenizer):
    """
    Converts a single example into the Llama-3.1 chat template format.
    Uses the new 'clean_*' fields we created.
    """
    return {
        "text": tokenizer.apply_chat_template(
            [
                # The 'system' role uses the cleaned instruction
                {"role": "system", "content": example["clean_instruction"]},

                # The 'user' role uses the cleaned input
                {"role": "user", "content": example["clean_input"]},

                # The 'assistant' role uses the cleaned completion
                {"role": "assistant", "content": example["clean_completion"]},
            ],
            tokenize=False,
            add_generation_prompt=False, # Important for training
        )
    }

# --- Verify the Output ---
print("\n--- Example of a Refactored Entry ---")
example_entry = cleaned_dataset["train"][0]

print("\n[SYSTEM PROMPT / INSTRUCTION]")
print(example_entry["clean_instruction"][:500] + "...") # Print first 500 chars
SYSTEM_PROMPT = example_entry["clean_instruction"]
print("\n[USER PROMPT / INPUT]")
print(example_entry["clean_input"])

print("\n[ASSISTANT RESPONSE / COMPLETION]")
print(example_entry["clean_completion"])

Loading and refactoring the dataset...

Dataset after refactoring:
DatasetDict({
    train: Dataset({
        features: ['id', 'contract_id', 'clause_id', 'prompt', 'completion', 'clean_instruction', 'clean_input', 'clean_completion'],
        num_rows: 1918
    })
    validation: Dataset({
        features: ['id', 'contract_id', 'clause_id', 'prompt', 'completion', 'clean_instruction', 'clean_input', 'clean_completion'],
        num_rows: 239
    })
    test: Dataset({
        features: ['id', 'contract_id', 'clause_id', 'prompt', 'completion', 'clean_instruction', 'clean_input', 'clean_completion'],
        num_rows: 228
    })
})

--- Example of a Refactored Entry ---

[SYSTEM PROMPT / INSTRUCTION]
Your task is to act as a legal graph extractor. From this single clause, create a self-contained set of nodes and edges that are explicitly supported by the text. Follow the reasoning process, rules, and clarifications below.

Output ONLY a single, strict JSON object with this structure:


In [10]:

# Load tokenizer and apply chat formatting
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token # Llama doesn't have a pad token by default

# Apply formatting to get the final 'text' column for the trainer
train_dataset = cleaned_dataset["train"].map(
    partial(create_chat_format, tokenizer=tokenizer),
    num_proc=2,
    remove_columns=cleaned_dataset["train"].column_names,
)
eval_dataset = cleaned_dataset["validation"].map(
    partial(create_chat_format, tokenizer=tokenizer),
    num_proc=2,
    remove_columns=cleaned_dataset["validation"].column_names,
)
print("\nTraining and validation datasets formatted for SFT.")


Training and validation datasets formatted for SFT.


In [11]:
validation_data = cleaned_dataset["validation"]

# Get the column name for the ground-truth JSON
gold_key = "clean_completion" if "clean_completion" in validation_data.column_names else "completion"
prompt_lengths = [len(tokenizer.encode(text)) for text in validation_data["prompt"]]
# Calculate the token length for each completion
token_lengths = [len(tokenizer.encode(text)) for text in validation_data[gold_key]]

# --- Analyze the results ---
import numpy as np

max_len = np.max(token_lengths)
mean_len = np.mean(token_lengths)
p95_len = np.percentile(token_lengths, 95)
p99_len = np.percentile(token_lengths, 99)

prmax_len = np.max(prompt_lengths)
prmean_len = np.mean(prompt_lengths)
prp95_len = np.percentile(prompt_lengths, 95)
prp99_len = np.percentile(prompt_lengths, 99)

print(f"Analysis of Ground-Truth Completion Lengths (in tokens):")
print(f"Max length: {max_len}")
print(f"Mean length: {mean_len:.2f}")
print(f"95th percentile: {p95_len}")
print(f"99th percentile: {p99_len}")
print("\n--- Recommendation ---")
print(f"Set 'max_new_tokens' to a value safely above the max, like {int(max_len * 1.1)} or {int(p99_len * 1.1)}")
print(f"Analysis of Training Prompt Lengths (in tokens):")
print(f"Prompt Max length: {prmax_len}")
print(f"Prompt Mean length: {prmean_len:.2f}")
print(f"Prompt 95th percentile: {prp95_len}")
print(f"Prompt 99th percentile: {prp99_len}")

Analysis of Ground-Truth Completion Lengths (in tokens):
Max length: 586
Mean length: 298.08
95th percentile: 515.5
99th percentile: 572.0

--- Recommendation ---
Set 'max_new_tokens' to a value safely above the max, like 644 or 629
Analysis of Training Prompt Lengths (in tokens):
Prompt Max length: 1448
Prompt Mean length: 1141.52
Prompt 95th percentile: 1321.1999999999998
Prompt 99th percentile: 1408.24


In [12]:
# ======================================================================
# Step 4: Model Loading and QLoRA Configuration
# ======================================================================

# --- BitsAndBytes Configuration for 4-bit Quantization ---
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# --- Load the Model ---
print("Loading base model...")
model = AutoModelForCausalLM.from_pretrained(
    HF_MODEL_ID,
    quantization_config=bnb_config,
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
print("Model loaded successfully.")

# If not already done where you created `tokenizer`:
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# --- LoRA Configuration ---
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
# --- Reduce memory at train time ---
# (A) Turn OFF KV cache during training (required when using gradient checkpointing)
model.config.use_cache = False

# (B) Enable gradient checkpointing with the new API kwarg to silence the warning
try:
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
except TypeError:
    # fallback for older transformers
    model.gradient_checkpointing_enable()

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

Loading base model...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Model loaded successfully.
trainable params: 20,971,520 || all params: 8,051,232,768 || trainable%: 0.2605


In [14]:
# ==============================================================================
# Step 5 (Final, Refactored Version): Self-Contained Custom Trainer
# ==============================================================================
from trl import SFTTrainer
import torch
import gc
import json
import re
import numpy as np
from collections import defaultdict
from tqdm.notebook import tqdm # Use notebook-friendly progress bar

class SFTTrainerWithGeneration(SFTTrainer):
    def __init__(self, *args, **kwargs):
        # --- Pop our custom arguments before calling the parent ---
        self.eval_ds_raw = kwargs.pop("eval_dataset_raw", None)
        self.sys_prompt = kwargs.pop("system_prompt", "")
        self.gen_max_samples = kwargs.pop("gen_max_samples", 32)
        self.gen_batch_size = kwargs.pop("gen_batch_size", 2)
        self.gen_max_new_tokens = kwargs.pop("gen_max_new_tokens", 650)

        super().__init__(*args, **kwargs)

    # ==================================================================
    # === Helper Methods for Parsing, Normalization, and Scoring ===
    # ==================================================================

    def _deep_safe_json(self, x, max_depth=3):
        obj = x
        for _ in range(max_depth):
            if isinstance(obj, (dict, list)): return obj
            if obj is None: return {}
            s = str(obj).strip()
            b, e = s.find("{"), s.rfind("}")
            cand = s[b:e+1] if (b!=-1 and e!=-1 and e>b) else s
            try: obj = json.loads(cand); continue
            except: break
        return {}

    def _extract_nodes(self, obj):
        parsed_obj = self._deep_safe_json(obj)
        nodes = parsed_obj.get("nodes", []) if isinstance(parsed_obj, dict) else []
        if isinstance(nodes, dict): nodes = [nodes]
        return [n for n in nodes if isinstance(n, dict)] if isinstance(nodes, list) else []

    def _extract_edges(self, obj):
        parsed_obj = self._deep_safe_json(obj)
        edges = parsed_obj.get("edges", []) if isinstance(parsed_obj, dict) else []
        if isinstance(edges, dict): edges = [edges]
        return [e for e in edges if isinstance(e, dict)] if isinstance(edges, list) else []

    COMPANY_SUFFIX_RE = re.compile(r"\\b(inc\\.?|ltd\\.?|llc|l\\.l\\.c\\.|corp\\.?|co\\.?|ag|gmbh)\\b", re.I)
    WS_RE = re.compile(r"\\s+")
    def _norm(self, s):
        if not isinstance(s, str): s = str(s) if s is not None else ""
        s = s.lower().replace("&","and")
        s = self.COMPANY_SUFFIX_RE.sub("", s)
        return self.WS_RE.sub(" ", s).strip()

    def _toks(self, s): return re.findall(r"[a-z0-9]+", s.lower())
    def _jacc(self, a,b):
        sa,sb=set(self._toks(a)),set(self._toks(b))
        if not sa or not sb: return 1.0 if a.strip()==b.strip() and a.strip()!="" else 0.0
        return len(sa&sb)/max(1,len(sa|sb))

    def _get_node_keytext(self, n):
        # This function creates the canonical text representation for a node for FUZZY matching
        t = (n.get("node_type") or n.get("type") or "").upper()
        nid = n.get("id", "") or ""
        # Simplified for clarity, assuming 'id' is the primary source
        if t == "DEFINED_TERM": key = nid.split(":",1)[-1]
        elif t == "PARTY": key = nid.split(":",1)[-1]
        elif t == "VALUE": key = nid.split(":",1)[-1]
        else: key = nid
        return t, self._norm(key)

    def _get_strict_node_id(self, node):
        """Creates a unique, hashable representation for STRICT node matching."""
        if not isinstance(node, dict): return None
        node_id = self._norm(node.get("id"))
        node_type = self._norm(node.get("node_type") or node.get("type")).upper()
        if node_id and node_type:
            return (node_type, node_id)
        return None

    def _get_strict_edge_triplet(self, edge):
        """Creates a unique, hashable representation for STRICT edge matching."""
        if not isinstance(edge, dict): return None
        src = self._norm(edge.get("src"))
        tgt = self._norm(edge.get("tgt"))
        typ = self._norm(edge.get("type")).upper()
        if src and tgt and typ: return (src, typ, tgt)
        return None

    THRESH={"CLAUSE":0.90,"DEFINED_TERM":0.80,"PARTY":0.85,"VALUE":0.75}
    def _sim(self, t, a, b):
        return self._jacc(a,b) # Simplified for now, can be expanded like your original

    def _bucket_nodes(self, nodes):
        b=defaultdict(list)
        for n in nodes:
            t, kt = self._get_node_keytext(n)
            if t and kt: b[t].append(kt)
        return b

    def _match_type(self, G_list, P_list, t):
        if not G_list or not P_list: return 0
        pairs=[]
        for i,g in enumerate(G_list):
            for j,p in enumerate(P_list):
                s=self._sim(t,g,p)
                if s>=self.THRESH.get(t,0.8): pairs.append((s,i,j))
        pairs.sort(reverse=True)
        used_i, used_j, tp = set(), set(), 0
        for s,i,j in pairs:
            if i in used_i or j in used_j: continue
            used_i.add(i); used_j.add(j); tp+=1
        return tp

    def _prf1(self,tp,fp,fn):
        p=tp/(tp+fp) if (tp+fp) > 0 else 0.0
        r=tp/(tp+fn) if (tp+fn) > 0 else 0.0
        f1=2*p*r/(p+r) if (p+r) > 0 else 0.0
        return p,r,f1

    # <<< NEW HELPER: Get a canonical representation of an edge for fuzzy matching >>>
    def _get_edge_keytext(self, edge):
        """Creates a canonical text representation of an edge's components."""
        typ = (edge.get("type") or "").upper()
        src = self._norm(edge.get("src", ""))
        tgt = self._norm(edge.get("tgt", ""))
        if not (typ and src and tgt):
            return None, None
        return typ, (src, tgt) # Return a tuple of (src, tgt) for matching

    # <<< NEW HELPER: Bucket edges by their type >>>
    def _bucket_edges(self, edges):
        b = defaultdict(list)
        for e in edges:
            typ, key = self._get_edge_keytext(e)
            if typ and key:
                b[typ].append(key)
        return b

    # <<< NEW HELPER: Fuzzy match lists of edge tuples (src, tgt) >>>
    def _match_edges_by_type(self, G_list, P_list):
        if not G_list or not P_list: return 0
        pairs = []
        for i, (g_src, g_tgt) in enumerate(G_list):
            for j, (p_src, p_tgt) in enumerate(P_list):
                # An edge matches if both its source and target are similar
                sim_src = self._jacc(g_src, p_src)
                sim_tgt = self._jacc(g_tgt, p_tgt)
                # We use the minimum of the two similarities
                # An edge is only as good as its weakest link
                score = min(sim_src, sim_tgt)
                if score > 0.85: # Use a relatively high threshold for edges
                    pairs.append((score, i, j))

        pairs.sort(reverse=True)
        used_i, used_j, tp = set(), set(), 0
        for s, i, j in pairs:
            if i in used_i or j in used_j: continue
            used_i.add(i); used_j.add(j); tp += 1
        return tp
    # ==================================================================
    # === Generation and Evaluation Logic ===
    # ==================================================================

    def _generate_texts(self, model, prompts):
        tok = self.tokenizer
        if tok.pad_token is None: tok.pad_token = tok.eos_token
        tok.padding_side = "left"

        outs = []
        print(f"\\n[gen-eval] Starting generation for {len(prompts)} samples with batch size {self.gen_batch_size}...")
        for i in tqdm(range(0, len(prompts), self.gen_batch_size), desc="Generation"):

            chunk = prompts[i:i+self.gen_batch_size]
            print(f"\\n[gen-eval] sample prompt {self.sys_prompt} sample content {chunk[0]}...")
            chat_prompts = [
                tok.apply_chat_template([
                    {"role": "system", "content": self.sys_prompt},
                    {"role": "user", "content": p}
                ], tokenize=False, add_generation_prompt=True) for p in chunk
            ]
            batch = tok(chat_prompts, return_tensors="pt", padding=True, truncation=True, max_length=self.args.max_seq_length).to(model.device)

            with torch.no_grad():
                outputs = model.generate(
                    **batch,
                    max_new_tokens=self.gen_max_new_tokens,
                    pad_token_id=tok.pad_token_id,
                    eos_token_id=tok.eos_token_id,
                    do_sample=False, # Use greedy decoding for consistent eval
                )

            generated_tokens = outputs[:, batch['input_ids'].shape[1]:]
            decoded = tok.batch_decode(generated_tokens, skip_special_tokens=True)
            outs.extend(decoded)

            del outputs, batch
            torch.cuda.empty_cache()
            gc.collect()

        print("[gen-eval] Generation finished.")
        tok.padding_side = "right"
        return outs

    def compute_graph_metrics(self, model):
        """Calculates F1 scores for both nodes and edges."""
        n = len(self.eval_ds_raw)
        use = min(self.gen_max_samples, n)
        idxs = np.random.choice(n, use, replace=False) if use < n else np.arange(n)
        sub = self.eval_ds_raw.select(idxs)

        prompts = sub["clean_input"]
        golds = sub["clean_completion"]

        preds = self._generate_texts(model, prompts)

        # Initialize counters for micro-averaging
        s_node_tp, s_node_fp, s_node_fn = 0,0,0
        f_node_tp, f_node_fp, f_node_fn = 0,0,0
        s_edge_tp, s_edge_fp, s_edge_fn = 0,0,0
        f_edge_tp, f_edge_fp, f_edge_fn = 0,0,0

        for gold_str, pred_str in zip(golds, preds):
            # --- Strict Node Comparison ---
            gold_nodes_s = {self._get_strict_node_id(n) for n in self._extract_nodes(gold_str) if self._get_strict_node_id(n)}
            pred_nodes_s = {self._get_strict_node_id(n) for n in self._extract_nodes(pred_str) if self._get_strict_node_id(n)}
            s_node_tp += len(gold_nodes_s & pred_nodes_s)
            s_node_fp += len(pred_nodes_s - gold_nodes_s)
            s_node_fn += len(gold_nodes_s - pred_nodes_s)

            # --- Fuzzy Node Comparison ---
            gold_nodes_f_raw = self._extract_nodes(gold_str)
            pred_nodes_f_raw = self._extract_nodes(pred_str)
            gold_edges_raw = self._extract_edges(gold_str)
            pred_edges_raw = self._extract_edges(pred_str)

            gold_buckets = self._bucket_nodes(gold_nodes_f_raw)
            pred_buckets = self._bucket_nodes(pred_nodes_f_raw)
            all_types = set(gold_buckets.keys()) | set(pred_buckets.keys())
            for t in all_types:
                tp = self._match_type(gold_buckets.get(t, []), pred_buckets.get(t, []), t)
                f_node_tp += tp
                f_node_fp += len(pred_buckets.get(t, [])) - tp
                f_node_fn += len(gold_buckets.get(t, [])) - tp

            # --- Strict Edge Comparison ---
            gold_edges_s = {self._get_strict_edge_triplet(e) for e in self._extract_edges(gold_str) if self._get_strict_edge_triplet(e)}
            pred_edges_s = {self._get_strict_edge_triplet(e) for e in self._extract_edges(pred_str) if self._get_strict_edge_triplet(e)}
            s_edge_tp += len(gold_edges_s & pred_edges_s)
            s_edge_fp += len(pred_edges_s - gold_edges_s)
            s_edge_fn += len(gold_edges_s - pred_edges_s)

            gold_edge_buckets = self._bucket_edges(gold_edges_raw)
            pred_edge_buckets = self._bucket_edges(pred_edges_raw)
            all_edge_types = set(gold_edge_buckets.keys()) | set(pred_edge_buckets.keys())

            for typ in all_edge_types:
                tp = self._match_edges_by_type(gold_edge_buckets.get(typ, []), pred_edge_buckets.get(typ, []))
                f_edge_tp += tp
                f_edge_fp += len(pred_edge_buckets.get(typ, [])) - tp
                f_edge_fn += len(gold_edge_buckets.get(typ, [])) - tp

        # --- Calculate Final Scores ---
        node_p_s, node_r_s, node_f1_s = self._prf1(s_node_tp, s_node_fp, s_node_fn)
        node_p_f, node_r_f, node_f1_f = self._prf1(f_node_tp, f_node_fp, f_node_fn)
        edge_p_s, edge_r_s, edge_f1_s = self._prf1(s_edge_tp, s_edge_fp, s_edge_fn)
        edge_p_f, edge_r_f, edge_f1_f = self._prf1(f_edge_tp, f_edge_fp, f_edge_fn)

        # <<< NEW: Calculate Composite Score using FUZZY metrics >>>
        graph_harmonic_f1 = 0.0
        if (node_f1_f + edge_f1_f) > 0:
            graph_harmonic_f1 = (2 * node_f1_f * edge_f1_f) / (node_f1_f + edge_f1_f)

        return {
            "node_strict_f1": node_f1_s, "node_strict_precision": node_p_s, "node_strict_recall": node_r_s,
            "node_fuzzy_f1": node_f1_f, "node_fuzzy_precision": node_p_f, "node_fuzzy_recall": node_r_f,
            "edge_strict_f1": edge_f1_s, "edge_strict_precision": edge_p_s, "edge_strict_recall": edge_r_s,
            "num_gen_samples": len(preds), "graph_harmonic_f1": graph_harmonic_f1,
        }

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = "eval"):
        # Run standard loss-based evaluation
        base_metrics = super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

        # Run custom generation-based evaluation
        original_use_cache = self.model.config.use_cache
        self.model.config.use_cache = True

        gen_metrics = {}
        try:
            with torch.no_grad():
                gen_metrics = self.compute_graph_metrics(self.model)
        finally:
            self.model.config.use_cache = original_use_cache
            gc.collect()
            torch.cuda.empty_cache()

        # Merge metrics
        prefixed_gen_metrics = {f"{metric_key_prefix}_{k}": v for k, v in gen_metrics.items()}
        base_metrics.update(prefixed_gen_metrics)
        wandb.log({**{f"custom_eval/{k}": v for k,v in gen_metrics.items()}},step=trainer.state.global_step)
        wandb.log(gen_metrics, step=self.state.global_step)
        print(f"DEBUG: Metrics from standard evaluation (Step {self.state.global_step}):")
        print(json.dumps(gen_metrics, indent=2))
        return base_metrics

In [None]:
from subprocess import run
# ======================================================================
# Step 6: SFT Trainer Configuration and Execution (Corrected)
# ======================================================================
from trl import SFTTrainer, SFTConfig, DataCollatorForCompletionOnlyLM # <-- CORRECTED IMPORT

# --- Define the Response Template for the Collator ---
response_template = "<|start_header_id|>assistant<|end_header_id|>\n\n"

collator = DataCollatorForCompletionOnlyLM(
    tokenizer=tokenizer,
    response_template=response_template,
    mlm=False,
)

MAX_SEQ_LEN = 2048

# --- SFTConfig ---
sft_config = SFTConfig(
    output_dir="/content/outputs",
    dataset_text_field="text",
    max_seq_length=MAX_SEQ_LEN,
    run_name=run_name,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=16,
    learning_rate=2e-5,
    num_train_epochs=3,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",
    bf16=True,
    logging_steps=20,
    eval_strategy="steps",
    eval_steps=50,
    save_steps=100,
    save_strategy="steps",
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_graph_harmonic_f1",
    greater_is_better=True,
    report_to=["wandb"],
    gradient_checkpointing=True,
    packing=False,
)

# --- Initialize the Trainer ---
trainer = SFTTrainerWithGeneration(
    model=model,
    tokenizer=tokenizer,
    args=sft_config,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collator,

    # --- Custom arguments for our new trainer ---
    eval_dataset_raw=cleaned_dataset["validation"],
    system_prompt=SYSTEM_PROMPT,
    gen_max_samples=16,
    gen_batch_size=4,
    gen_max_new_tokens=650
)

# --- Start Training ---
print("Starting training...")
trainer.train()

  StockPickler.save(self, obj, save_persistent_id)
  StockPickler.save(self, obj, save_persistent_id)


Map:   0%|          | 0/1918 [00:00<?, ? examples/s]

Map:   0%|          | 0/239 [00:00<?, ? examples/s]

Starting training...




Step,Training Loss,Validation Loss
50,0.0484,0.040046
100,0.0249,0.028333
150,0.0201,0.025096
200,0.0205,0.022825
250,0.0179,0.02203
300,0.0154,0.021616
350,0.0171,0.02155


\n[gen-eval] Starting generation for 16 samples with batch size 4...


Generation:   0%|          | 0/4 [00:00<?, ?it/s]

\n[gen-eval] sample prompt Your task is to act as a legal graph extractor. From this single clause, create a self-contained set of nodes and edges that are explicitly supported by the text. Follow the reasoning process, rules, and clarifications below.

Output ONLY a single, strict JSON object with this structure:

{
 "contract_id": "...",
 "nodes": [ ... ],
 "edges": [ ... ]
}

 
REASONING PROCESS
 
1.  **Isolate Core Text:** First, mentally separate the core contractual prose from any 'noise' like Tables of Contents, redaction headers, or formatting artifacts. Your analysis should ONLY focus on the contractual prose.
2.  **Create Primary Node:** Create the `CLAUSE` node for the clause you were given.
3.  **Infer and Create Parent Node:** Analyze the clause `id` and `level` to infer the parent clause ID. Create the parent `CLAUSE` node.
4.  **Scan and Create Nodes:** Read the core text to identify all other entities (Referenced Clauses, Defined Terms, Parties, Values) and create their



\n[gen-eval] sample prompt Your task is to act as a legal graph extractor. From this single clause, create a self-contained set of nodes and edges that are explicitly supported by the text. Follow the reasoning process, rules, and clarifications below.

Output ONLY a single, strict JSON object with this structure:

{
 "contract_id": "...",
 "nodes": [ ... ],
 "edges": [ ... ]
}

 
REASONING PROCESS
 
1.  **Isolate Core Text:** First, mentally separate the core contractual prose from any 'noise' like Tables of Contents, redaction headers, or formatting artifacts. Your analysis should ONLY focus on the contractual prose.
2.  **Create Primary Node:** Create the `CLAUSE` node for the clause you were given.
3.  **Infer and Create Parent Node:** Analyze the clause `id` and `level` to infer the parent clause ID. Create the parent `CLAUSE` node.
4.  **Scan and Create Nodes:** Read the core text to identify all other entities (Referenced Clauses, Defined Terms, Parties, Values) and create their

Generation:   0%|          | 0/4 [00:00<?, ?it/s]

\n[gen-eval] sample prompt Your task is to act as a legal graph extractor. From this single clause, create a self-contained set of nodes and edges that are explicitly supported by the text. Follow the reasoning process, rules, and clarifications below.

Output ONLY a single, strict JSON object with this structure:

{
 "contract_id": "...",
 "nodes": [ ... ],
 "edges": [ ... ]
}

 
REASONING PROCESS
 
1.  **Isolate Core Text:** First, mentally separate the core contractual prose from any 'noise' like Tables of Contents, redaction headers, or formatting artifacts. Your analysis should ONLY focus on the contractual prose.
2.  **Create Primary Node:** Create the `CLAUSE` node for the clause you were given.
3.  **Infer and Create Parent Node:** Analyze the clause `id` and `level` to infer the parent clause ID. Create the parent `CLAUSE` node.
4.  **Scan and Create Nodes:** Read the core text to identify all other entities (Referenced Clauses, Defined Terms, Parties, Values) and create their



\n[gen-eval] Starting generation for 16 samples with batch size 4...


Generation:   0%|          | 0/4 [00:00<?, ?it/s]

\n[gen-eval] sample prompt Your task is to act as a legal graph extractor. From this single clause, create a self-contained set of nodes and edges that are explicitly supported by the text. Follow the reasoning process, rules, and clarifications below.

Output ONLY a single, strict JSON object with this structure:

{
 "contract_id": "...",
 "nodes": [ ... ],
 "edges": [ ... ]
}

 
REASONING PROCESS
 
1.  **Isolate Core Text:** First, mentally separate the core contractual prose from any 'noise' like Tables of Contents, redaction headers, or formatting artifacts. Your analysis should ONLY focus on the contractual prose.
2.  **Create Primary Node:** Create the `CLAUSE` node for the clause you were given.
3.  **Infer and Create Parent Node:** Analyze the clause `id` and `level` to infer the parent clause ID. Create the parent `CLAUSE` node.
4.  **Scan and Create Nodes:** Read the core text to identify all other entities (Referenced Clauses, Defined Terms, Parties, Values) and create their



\n[gen-eval] sample prompt Your task is to act as a legal graph extractor. From this single clause, create a self-contained set of nodes and edges that are explicitly supported by the text. Follow the reasoning process, rules, and clarifications below.

Output ONLY a single, strict JSON object with this structure:

{
 "contract_id": "...",
 "nodes": [ ... ],
 "edges": [ ... ]
}

 
REASONING PROCESS
 
1.  **Isolate Core Text:** First, mentally separate the core contractual prose from any 'noise' like Tables of Contents, redaction headers, or formatting artifacts. Your analysis should ONLY focus on the contractual prose.
2.  **Create Primary Node:** Create the `CLAUSE` node for the clause you were given.
3.  **Infer and Create Parent Node:** Analyze the clause `id` and `level` to infer the parent clause ID. Create the parent `CLAUSE` node.
4.  **Scan and Create Nodes:** Read the core text to identify all other entities (Referenced Clauses, Defined Terms, Parties, Values) and create their

Generation:   0%|          | 0/4 [00:00<?, ?it/s]

\n[gen-eval] sample prompt Your task is to act as a legal graph extractor. From this single clause, create a self-contained set of nodes and edges that are explicitly supported by the text. Follow the reasoning process, rules, and clarifications below.

Output ONLY a single, strict JSON object with this structure:

{
 "contract_id": "...",
 "nodes": [ ... ],
 "edges": [ ... ]
}

 
REASONING PROCESS
 
1.  **Isolate Core Text:** First, mentally separate the core contractual prose from any 'noise' like Tables of Contents, redaction headers, or formatting artifacts. Your analysis should ONLY focus on the contractual prose.
2.  **Create Primary Node:** Create the `CLAUSE` node for the clause you were given.
3.  **Infer and Create Parent Node:** Analyze the clause `id` and `level` to infer the parent clause ID. Create the parent `CLAUSE` node.
4.  **Scan and Create Nodes:** Read the core text to identify all other entities (Referenced Clauses, Defined Terms, Parties, Values) and create their



\n[gen-eval] Starting generation for 16 samples with batch size 4...


Generation:   0%|          | 0/4 [00:00<?, ?it/s]

\n[gen-eval] sample prompt Your task is to act as a legal graph extractor. From this single clause, create a self-contained set of nodes and edges that are explicitly supported by the text. Follow the reasoning process, rules, and clarifications below.

Output ONLY a single, strict JSON object with this structure:

{
 "contract_id": "...",
 "nodes": [ ... ],
 "edges": [ ... ]
}

 
REASONING PROCESS
 
1.  **Isolate Core Text:** First, mentally separate the core contractual prose from any 'noise' like Tables of Contents, redaction headers, or formatting artifacts. Your analysis should ONLY focus on the contractual prose.
2.  **Create Primary Node:** Create the `CLAUSE` node for the clause you were given.
3.  **Infer and Create Parent Node:** Analyze the clause `id` and `level` to infer the parent clause ID. Create the parent `CLAUSE` node.
4.  **Scan and Create Nodes:** Read the core text to identify all other entities (Referenced Clauses, Defined Terms, Parties, Values) and create their



\n[gen-eval] sample prompt Your task is to act as a legal graph extractor. From this single clause, create a self-contained set of nodes and edges that are explicitly supported by the text. Follow the reasoning process, rules, and clarifications below.

Output ONLY a single, strict JSON object with this structure:

{
 "contract_id": "...",
 "nodes": [ ... ],
 "edges": [ ... ]
}

 
REASONING PROCESS
 
1.  **Isolate Core Text:** First, mentally separate the core contractual prose from any 'noise' like Tables of Contents, redaction headers, or formatting artifacts. Your analysis should ONLY focus on the contractual prose.
2.  **Create Primary Node:** Create the `CLAUSE` node for the clause you were given.
3.  **Infer and Create Parent Node:** Analyze the clause `id` and `level` to infer the parent clause ID. Create the parent `CLAUSE` node.
4.  **Scan and Create Nodes:** Read the core text to identify all other entities (Referenced Clauses, Defined Terms, Parties, Values) and create their

Generation:   0%|          | 0/4 [00:00<?, ?it/s]

\n[gen-eval] sample prompt Your task is to act as a legal graph extractor. From this single clause, create a self-contained set of nodes and edges that are explicitly supported by the text. Follow the reasoning process, rules, and clarifications below.

Output ONLY a single, strict JSON object with this structure:

{
 "contract_id": "...",
 "nodes": [ ... ],
 "edges": [ ... ]
}

 
REASONING PROCESS
 
1.  **Isolate Core Text:** First, mentally separate the core contractual prose from any 'noise' like Tables of Contents, redaction headers, or formatting artifacts. Your analysis should ONLY focus on the contractual prose.
2.  **Create Primary Node:** Create the `CLAUSE` node for the clause you were given.
3.  **Infer and Create Parent Node:** Analyze the clause `id` and `level` to infer the parent clause ID. Create the parent `CLAUSE` node.
4.  **Scan and Create Nodes:** Read the core text to identify all other entities (Referenced Clauses, Defined Terms, Parties, Values) and create their



\n[gen-eval] Starting generation for 16 samples with batch size 4...


Generation:   0%|          | 0/4 [00:00<?, ?it/s]

\n[gen-eval] sample prompt Your task is to act as a legal graph extractor. From this single clause, create a self-contained set of nodes and edges that are explicitly supported by the text. Follow the reasoning process, rules, and clarifications below.

Output ONLY a single, strict JSON object with this structure:

{
 "contract_id": "...",
 "nodes": [ ... ],
 "edges": [ ... ]
}

 
REASONING PROCESS
 
1.  **Isolate Core Text:** First, mentally separate the core contractual prose from any 'noise' like Tables of Contents, redaction headers, or formatting artifacts. Your analysis should ONLY focus on the contractual prose.
2.  **Create Primary Node:** Create the `CLAUSE` node for the clause you were given.
3.  **Infer and Create Parent Node:** Analyze the clause `id` and `level` to infer the parent clause ID. Create the parent `CLAUSE` node.
4.  **Scan and Create Nodes:** Read the core text to identify all other entities (Referenced Clauses, Defined Terms, Parties, Values) and create their



\n[gen-eval] sample prompt Your task is to act as a legal graph extractor. From this single clause, create a self-contained set of nodes and edges that are explicitly supported by the text. Follow the reasoning process, rules, and clarifications below.

Output ONLY a single, strict JSON object with this structure:

{
 "contract_id": "...",
 "nodes": [ ... ],
 "edges": [ ... ]
}

 
REASONING PROCESS
 
1.  **Isolate Core Text:** First, mentally separate the core contractual prose from any 'noise' like Tables of Contents, redaction headers, or formatting artifacts. Your analysis should ONLY focus on the contractual prose.
2.  **Create Primary Node:** Create the `CLAUSE` node for the clause you were given.
3.  **Infer and Create Parent Node:** Analyze the clause `id` and `level` to infer the parent clause ID. Create the parent `CLAUSE` node.
4.  **Scan and Create Nodes:** Read the core text to identify all other entities (Referenced Clauses, Defined Terms, Parties, Values) and create their

TrainOutput(global_step=357, training_loss=0.026393890714778954, metrics={'train_runtime': 7872.6113, 'train_samples_per_second': 0.731, 'train_steps_per_second': 0.045, 'total_flos': 3.6567072772263936e+17, 'train_loss': 0.026393890714778954, 'epoch': 2.978102189781022})

In [None]:
loss_metrics = trainer.evaluate()

\n[gen-eval] Starting generation for 16 samples with batch size 4...


Generation:   0%|          | 0/4 [00:00<?, ?it/s]

\n[gen-eval] sample prompt Your task is to act as a legal graph extractor. From this single clause, create a self-contained set of nodes and edges that are explicitly supported by the text. Follow the reasoning process, rules, and clarifications below.

Output ONLY a single, strict JSON object with this structure:

{
 "contract_id": "...",
 "nodes": [ ... ],
 "edges": [ ... ]
}

 
REASONING PROCESS
 
1.  **Isolate Core Text:** First, mentally separate the core contractual prose from any 'noise' like Tables of Contents, redaction headers, or formatting artifacts. Your analysis should ONLY focus on the contractual prose.
2.  **Create Primary Node:** Create the `CLAUSE` node for the clause you were given.
3.  **Infer and Create Parent Node:** Analyze the clause `id` and `level` to infer the parent clause ID. Create the parent `CLAUSE` node.
4.  **Scan and Create Nodes:** Read the core text to identify all other entities (Referenced Clauses, Defined Terms, Parties, Values) and create their



\n[gen-eval] sample prompt Your task is to act as a legal graph extractor. From this single clause, create a self-contained set of nodes and edges that are explicitly supported by the text. Follow the reasoning process, rules, and clarifications below.

Output ONLY a single, strict JSON object with this structure:

{
 "contract_id": "...",
 "nodes": [ ... ],
 "edges": [ ... ]
}

 
REASONING PROCESS
 
1.  **Isolate Core Text:** First, mentally separate the core contractual prose from any 'noise' like Tables of Contents, redaction headers, or formatting artifacts. Your analysis should ONLY focus on the contractual prose.
2.  **Create Primary Node:** Create the `CLAUSE` node for the clause you were given.
3.  **Infer and Create Parent Node:** Analyze the clause `id` and `level` to infer the parent clause ID. Create the parent `CLAUSE` node.
4.  **Scan and Create Nodes:** Read the core text to identify all other entities (Referenced Clauses, Defined Terms, Parties, Values) and create their

In [None]:
# built-in loss metrics
loss_metrics

{'eval_loss': 0.022824838757514954,
 'eval_runtime': 60.8951,
 'eval_samples_per_second': 3.925,
 'eval_steps_per_second': 0.985,
 'epoch': 2.978102189781022,
 'eval_node_strict_f1': 0.7624309392265194,
 'eval_node_strict_precision': 0.7931034482758621,
 'eval_node_strict_recall': 0.7340425531914894,
 'eval_node_fuzzy_f1': 0.7624309392265194,
 'eval_node_fuzzy_precision': 0.7931034482758621,
 'eval_node_fuzzy_recall': 0.7340425531914894,
 'eval_edge_strict_f1': 0.7199999999999999,
 'eval_edge_strict_precision': 0.75,
 'eval_edge_strict_recall': 0.6923076923076923,
 'eval_num_gen_samples': 16,
 'eval_graph_harmonic_f1': 0.7406082289803221}

In [None]:
final_gen = trainer.compute_node_metrics(trainer.model)
final_gen



[gen-eval] Starting generation for 16 samples with batch size 4...
[gen-eval] Generation finished.


{'gen_strict_micro_precision': 0.0,
 'gen_strict_micro_recall': 0.0,
 'gen_strict_micro_f1': 0.0,
 'gen_fuzzy_micro_precision': 0.0,
 'gen_fuzzy_micro_recall': 0.0,
 'gen_fuzzy_micro_f1': 0.0,
 'gen_exact_match': 1.0,
 'gen_invalid_json_rate': 0.0,
 'gen_num_samples': 16}

In [None]:
trainer.log({**{f"hf/{k}": v for k,v in loss_metrics.items()},
             **{f"final/{k}": v for k,v in final_gen.items()}})
try:
    wandb.log({**{f"hf/{k}": v for k,v in loss_metrics.items()},
               **{f"final/{k}": v for k,v in final_gen.items()}},
              step=trainer.state.global_step)
except:
    pass

SyntaxError: invalid syntax. Perhaps you forgot a comma? (ipython-input-3290176872.py, line 2)

In [None]:
import os, torch, json, time
from pathlib import Path

RUN_TAG     = time.strftime("%Y%m%d-%H%M%S")
ADAPTER_DIR = Path(f"/content/outputs/adapter-{RUN_TAG}")
ADAPTER_DIR.mkdir(parents=True, exist_ok=True)

# 1) save PEFT adapter + tokenizer
trainer.model.save_pretrained(ADAPTER_DIR)     # saves adapter_config.json + adapter_model.safetensors
tokenizer.save_pretrained(ADAPTER_DIR)

# 2) (nice to have) save training state and args in same folder
trainer.save_state()                           # writes under sft_config.output_dir
with open(ADAPTER_DIR / "sft_config.json", "w") as f:
    f.write(sft_config.to_json_string())

print("Saved adapter to:", ADAPTER_DIR)


Saved adapter to: /content/outputs/adapter-20250827-050412


In [None]:
from google.colab import drive
drive.mount('/content/drive')

ValueError: mount failed

In [None]:
import time, json, shutil
from pathlib import Path

RUN_TAG     = time.strftime("%Y%m%d-%H%M%S")
ADAPTER_DIR = Path(f"/content/drive/MyDrive/outputs-node-extractor/adapter-{RUN_TAG}")
ADAPTER_DIR.mkdir(parents=True, exist_ok=True)

# 1) save PEFT adapter + tokenizer
trainer.model.save_pretrained(ADAPTER_DIR)      # adapter_model.safetensors + adapter_config.json
tokenizer.save_pretrained(ADAPTER_DIR)

# 2) (nice to have) save training state & args (note: save_state writes under sft_config.output_dir)
trainer.save_state()                            # writes to /content/outputs (or your output_dir)
with open(ADAPTER_DIR / "sft_config.json", "w") as f:
    f.write(sft_config.to_json_string())

# 3) record the exact base model id (critical for later GRPO loads)
with open(ADAPTER_DIR / "base_model.json", "w") as f:
    json.dump({"base_model": HF_MODEL_ID}, f)

# 4) (optional) save generation config, if available (keeps eval/inference defaults consistent)
try:
    trainer.model.generation_config.save_pretrained(ADAPTER_DIR)
except Exception:
    pass

# 5) (optional) tiny metadata for reproducibility
# make sure targets is list (not set)
targets = getattr(lora_config, "target_modules", None)
if isinstance(targets, set):
    targets = sorted(list(targets))
elif targets is None:
    targets = []

meta = {
    "lora": {
        "r": getattr(lora_config, "r", None),
        "alpha": getattr(lora_config, "lora_alpha", None),
        "dropout": getattr(lora_config, "lora_dropout", None),
        "targets": targets,                     # <-- now JSON-serializable
    },
    "max_seq_length": getattr(sft_config, "max_seq_length", None),
    "system_prompt": SYS_PROMPT if "SYS_PROMPT" in globals() else None,
    "response_template": "<|start_header_id|>assistant<|end_header_id|>\n\n",
    "quantization": {
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "compute_dtype": "bfloat16"
    },
}
with open(ADAPTER_DIR / "meta.json", "w") as f:
    json.dump(meta, f, ensure_ascii=False, indent=2)
print("meta.json written")

