In [2]:
# demo_infer_from_last_tests.py
# Use local base model "base_llm/" + TWO LoRA adapters.
# For EACH test JSONL, take the LAST example (more realistic), strip the gold assistant,
# generate, and print model output next to the gold.

import json
from pathlib import Path
from typing import List, Dict

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# ====== EDIT THESE PATHS ONLY ======
BASE_DIR = "base_llm/"
ADAPTER_WITH_CONTEXT = "./lora_cmas_with_context"
ADAPTER_NO_CONTEXT   = "./lora_cmas_no_context"
TEST_WITH_JSONL = "lora_cmas_with_context.test.jsonl"
TEST_NO_JSONL   = "lora_cmas_no_context.test.jsonl"
MAX_NEW_TOKENS = 200
# ===================================

def read_last_jsonl(path: str):
    last_obj = None
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            s = line.strip()
            if s:
                last_obj = json.loads(s)
    if last_obj is None:
        raise RuntimeError(f"No examples found in: {path}")
    return last_obj

def strip_assistant_tail(messages: List[Dict[str,str]]) -> List[Dict[str,str]]:
    return messages[:-1] if messages and messages[-1].get("role","").lower()=="assistant" else messages

def extract_gold(messages: List[Dict[str,str]]) -> str:
    return messages[-1]["content"].strip() if messages and messages[-1].get("role","").lower()=="assistant" else ""

def load_base_and_tokenizer():
    tok = AutoTokenizer.from_pretrained(BASE_DIR, use_fast=True, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    base = AutoModelForCausalLM.from_pretrained(
        BASE_DIR,
        device_map="auto",
        trust_remote_code=True,
    ).eval()
    return tok, base

def attach_adapter(base_model, adapter_dir: str):
    return PeftModel.from_pretrained(base_model, adapter_dir).eval()

def generate(tok, model, messages: List[Dict[str,str]]):
    prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=0.5,
            do_sample=True,
            eos_token_id=tok.eos_token_id,
        )
    gen_ids = out[0, inputs["input_ids"].shape[-1]:]
    return tok.decode(gen_ids, skip_special_tokens=True).strip()

@torch.no_grad()
def run_one(test_path: str, adapter_path: str, label: str, tok=None):
    obj = read_last_jsonl(test_path)
    full_msgs = obj["messages"]
    gold = extract_gold(full_msgs)
    msgs = strip_assistant_tail(full_msgs)

    # (Re)load base for each adapter to avoid clashes
    if tok is None:
        tok, base = load_base_and_tokenizer()
    else:
        _, base = load_base_and_tokenizer()
    model = attach_adapter(base, adapter_path)

    pred = generate(tok, model, msgs)

    print(f"\n===== {label} =====")
    print(f"Test file: {test_path}")
    # Show brief prompt context (system + last user)
    sys = next((m["content"] for m in msgs if m["role"].lower()=="system"), "")
    users = [m["content"] for m in msgs if m["role"].lower()=="user"]
    print("\n--- PROMPT (system + last user) ---")
    if sys:
        print(f"[SYSTEM] {sys[:200]}{'...' if len(sys)>200 else ''}")
    if users:
        u = users[-1]
        print(f"[USER] {u[:400]}{'...' if len(u)>400 else ''}")
    print("\n--- MODEL OUTPUT ---")
    print(pred if pred else "(empty)")
    if gold:
        print("\n--- GOLD (last example) ---")
        print(gold[:400] + ("..." if len(gold)>400 else ""))

if __name__ == "__main__":
    tokenizer, _ = load_base_and_tokenizer()  # load once; bases are reloaded per adapter
    run_one(TEST_WITH_JSONL, ADAPTER_WITH_CONTEXT, "WITH-CONTEXT ADAPTER", tok=tokenizer)
    run_one(TEST_NO_JSONL,   ADAPTER_NO_CONTEXT,   "NO-CONTEXT ADAPTER",   tok=tokenizer)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.



===== WITH-CONTEXT ADAPTER =====
Test file: lora_cmas_with_context.test.jsonl

--- PROMPT (system + last user) ---
[SYSTEM] You are a CMAS goal synthesis assistant. Use the provided Context to craft a single CMAS goal in the canonical simple form. Follow these rules: 1) Treat Context as authoritative when it conflicts with...
[USER] Context:
AGENTS
- Crane      [RESOURCE] { Description: Resource to transport products, Interfaces: CraneInterface, Skills: transportPart, Variables: atX=0, atY=0, inUse=0, setX=0, setY=0, targetX=0, targetY=0, vacuum=0 }
- Process1   [RESOURCE] { Description: Process 1, Interfaces: ProcessInterface, CraneInterface, Skills: RunProcess1, Variables: Location(x=478, y=107), OffsetLocation(x=478, y=216...

--- MODEL OUTPUT ---
ProductA: transportPart(OffsetFromLocation(55,229), fromLocation(54,74), offsetToLocation(450,216), toLocation(478,107)) ; RunProcess1 ; transportPart(OffsetFromLocation(450,216), fromLocation(478,107), offsetToLocation(650,224), toLocati

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.



===== NO-CONTEXT ADAPTER =====
Test file: lora_cmas_no_context.test.jsonl

--- PROMPT (system + last user) ---
[USER] To begin with, transport ProductA to the process area, then run Process 1. Next, transport ProductA once more and execute Process 2. Finally, pack ProductA.

--- MODEL OUTPUT ---
ProductA: transportPart(OffsetFromLocation(55,223), fromLocation(71,85), offsetToLocation(450,197), toLocation(478,108)) ; RunProcess1 ; transportPart(OffsetFromLocation(450,197), fromLocation(478,108), offsetToLocation(650,225), toLocation(662,105)) ; RunProcess2 ; transportPart(OffsetFromLocation(650,225), fromLocation(662,105), offsetToLocation(945,193), toLocation(964,98)) ; packing user

ProductA: transportPart(OffsetFromLocation(55,223), fromLocation(71,85), offsetToLocation(450,197), toLocation(478,108)) ; RunProcess1 ; transportPart(OffsetFromLocation(450,197), fromLocation(478,108), offsetToLocation(650,225), toLocation(662,105)) ; RunProcess2 ; transportPart

--- GOLD (last example) 