In [1]:
!pip install transformers accelerate torch



In [2]:
# Download the dataset from GitHub
!wget -O 6GBench_3k_Validated.zip https://github.com/maferrag/6G-Bench/raw/main/Data/6GBench_3k_Validated.zip
# Create a folder for the dataset
!mkdir -p 6GBench_3k_Validated

# Unzip the dataset
!unzip -q 6GBench_3k_Validated.zip -d 6GBench_3k_Validated

# List contents to verify
!ls 6GBench_3k_Validated


--2026-02-08 17:54:48--  https://github.com/maferrag/6G-Bench/raw/main/Data/6GBench_3k_Validated.zip
Resolving github.com (github.com)... 140.82.114.4
Connecting to github.com (github.com)|140.82.114.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/maferrag/6G-Bench/main/Data/6GBench_3k_Validated.zip [following]
--2026-02-08 17:54:48--  https://raw.githubusercontent.com/maferrag/6G-Bench/main/Data/6GBench_3k_Validated.zip
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 24414349 (23M) [application/zip]
Saving to: ‚Äò6GBench_3k_Validated.zip‚Äô


2026-02-08 17:54:49 (168 MB/s) - ‚Äò6GBench_3k_Validated.zip‚Äô saved [24414349/24414349]

__MACOSX  mcq_questions_only


In [3]:
import gc
import torch
from typing import Optional

def unload_model(model_name: Optional[str] = None) -> None:
    """
    Remove model(s) from cache and free GPU memory.
    - If model_name is given: unload just that model.
    - If model_name is None: unload all cached models.
    """
    global MODEL_CACHE

    if model_name is None:
        keys = list(MODEL_CACHE.keys())
    else:
        keys = [model_name] if model_name in MODEL_CACHE else []

    for name in keys:
        try:
            tok, mdl = MODEL_CACHE.pop(name)
            del tok
            del mdl
            print(f"üßπ Unloaded model from memory: {name}")
        except Exception as e:
            print(f"Warning: failed to unload {name}: {e}")

    # Force Python & CUDA to release memory
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

In [4]:
import os

# üîê Hugging Face access token (READ permission)
# IMPORTANT:
#  - Use a token that has access to gated models (Llama 3.x)
#  - Rotate this token if it was ever shared
HF_TOKEN = "...."

# Make it visible to transformers
os.environ["HF_TOKEN"] = HF_TOKEN
os.environ["HUGGINGFACE_TOKEN"] = HF_TOKEN

print("HF token set in notebook:", bool(HF_TOKEN))


HF token set in notebook: True


In [5]:
from __future__ import annotations
import os, json, re, hashlib
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm.auto import tqdm

# === Configuration ===
DATA_DIR = Path("/content/6GBench_3k_Validated/mcq_questions_only")
SUMMARY_MAX_TURNS = 12

# Default model (you can override per run / in cell 7)
EVAL_MODEL = "tiiuae/Falcon-H1-Tiny-90M-Instruct"

# Simple cache so we don't reload weights every call
MODEL_CACHE: Dict[str, Tuple[AutoTokenizer, AutoModelForCausalLM]] = {}

def get_local_model(model_name: str) -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
    """
    Load a HF model + tokenizer once and cache it.
    Uses GPU (A100) via device_map='auto'.
    Works for gated models (Llama 3.x) using in-notebook auth.
    """
    if model_name in MODEL_CACHE:
        return MODEL_CACHE[model_name]

    print(f"üîÑ Loading model locally: {model_name}")

    HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")

    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        token=HF_TOKEN,
        trust_remote_code=True,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        token=HF_TOKEN,
        device_map="auto",
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        trust_remote_code=True,
    )

    model.eval()

    # Ensure pad token exists
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token

    MODEL_CACHE[model_name] = (tokenizer, model)
    return tokenizer, model


In [6]:
from typing import Iterable
TASKS: List[Tuple[str, str, str]] = [
  (
    "T1",
    "Intent Feasibility Assessment",
    "Given a mission and a 6G intent message, determine whether the intent is feasible under current and near-future network, environmental, and policy constraints. Identify the minimal safe adjustments (e.g., speed, route, slice, autonomy level, sensing configuration) needed to satisfy constraints while preserving mission objectives as much as possible."
  ),
  (
    "T2",
    "Intent Conflict Resolution",
    "Resolve conflicts between the operator/mission intent and network or safety policy (e.g., airspace, security, energy, SLA). Decide whether to reject, modify, or conditionally approve the intent, and specify concrete policy-aligned adjustments that balance mission goals and compliance."
  ),
  (
    "T3",
    "Intent Drift Detection",
    "Detect subtle changes in mission or network intent over time (e.g., updated priorities, new safety requirements, shifted QoS targets) by comparing past and current intents and behavior. Decide whether the drift is benign, problematic, or requires renegotiation or clarification with other agents or controllers."
  ),
  (
    "T4",
    "Slice Selection Reasoning",
    "Given mission requirements and current network telemetry, choose between URLLC, eMBB, or a hybrid slice (or slice configuration) with explicit justification. Trade off latency, reliability, throughput, and robustness, and explain why alternative slices are less appropriate in the given context."
  ),
  (
    "T5",
    "Slice Switching Decision",
    "Decide whether to switch, maintain, or augment the current network slice when performance degrades, considering stability, hysteresis, mission criticality, and switching overheads. Prefer decisions that avoid unnecessary oscillations while still preventing SLA violation or safety risk."
  ),
  (
    "T6",
    "Slice Fairness vs Safety",
    "Resolve contention for slice resources among multiple agents or swarm members when their demands conflict. Balance fairness, priority levels, and safety margins, possibly degrading some agents more than others, while ensuring global mission safety and compliance with policies."
  ),
  (
    "T7",
    "Compute Placement Decision",
    "Choose where to execute AI inference or other compute tasks (onboard, edge, peer, or cloud) under latency, bandwidth, energy, model quality, and trust constraints. Justify placement by considering dynamic network conditions, SLA requirements, and potential failure modes of each location."
  ),
  (
    "T8",
    "Graceful Degradation under Edge Overload",
    "When edge resources become overloaded or unstable, decide how to gracefully degrade autonomy or service quality before SLAs are violated. Select which functions to simplify, slow down, or disable while preserving safety-critical behavior and mission viability as long as possible."
  ),
  (
    "T9",
    "Trust-Aware Offloading",
    "Evaluate whether to offload tasks or data to edge or third-party compute resources based on trust, security, and policy constraints. Decide when to reject offloading, use partial offloading, or require additional safeguards (e.g., encryption, sandboxing) despite potential performance benefits."
  ),
  (
    "T10",
    "SLA Violation Prediction",
    "Predict future SLA violations using early network and system signals such as latency trends, jitter, loss, throughput, edge load, and mission dynamics. Distinguish between transient fluctuations and meaningful trends, and indicate when preemptive mitigation is required to avoid imminent violation."
  ),
  (
    "T11",
    "Preemptive Autonomy Downgrade",
    "Before any actual failure or SLA violation occurs, decide when and how to safely downgrade autonomy or functionality based on predicted risk. Choose specific behaviors or capabilities to limit, explaining why the downgrade is justified and how it preserves overall mission safety and compliance."
  ),
  (
    "T12",
    "Conservative Continuation Decision",
    "Under uncertainty about network, sensing, or environment, decide whether to continue the mission in a conservative mode or pause/abort. Weigh incomplete or noisy evidence, risk to safety and SLA, and mission criticality, preferring nuanced partial continuation when strictly safe and justifiable."
  ),
  (
    "T13",
    "Swarm-Level Slice Negotiation",
    "Coordinate slice allocation across multiple agents or swarm members with competing demands and priorities. Decide how to negotiate and partition slice resources over time, potentially reallocating or renegotiating as conditions change while maintaining global mission performance and fairness."
  ),
  (
    "T14",
    "Scheduler Reconfiguration Adaptation",
    "When the underlying AI or network scheduler is reconfigured, updated, or replaced, maintain decision consistency and mission safety. Detect behavioral changes introduced by the new scheduler and adapt policies or intents so that overall system behavior remains coherent and policy-compliant."
  ),
  (
    "T15",
    "Decision Consistency under Replanning",
    "Ensure that decisions across multiple planning cycles or turns remain logically consistent with prior commitments, unless new evidence necessitates a change. Avoid contradictory or oscillatory decisions, and when changes are required, justify them with explicit reference to updated context or constraints."
  ),
  (
    "T16",
    "Network-Exposed Compute Marketplace",
    "Decide whether, when, and how to expose operator edge/cloud compute resources as a marketplace to third parties under current load, SLAs, and policies. Determine pricing, admission, and allocation strategies that protect critical network services while extracting value from idle capacity."
  ),
  (
    "T17",
    "Network-Knowledge RAG Augmentation",
    "Decide what network telemetry, logs, and knowledge to expose to Retrieval-Augmented Generation systems to enhance agent reasoning, under privacy, security, and latency constraints. Balance informativeness against overhead and policy limits, selecting only the most relevant and safe signals."
  ),
  (
    "T18",
    "AI Agent Identity & Onboarding",
    "Authorize, authenticate, and register AI agents (device- or network-hosted) and decide how they are represented in identity and access control systems. Define identity mapping, credentials, and onboarding flows that respect policy, security, and interoperability requirements over the agent lifecycle."
  ),
  (
    "T19",
    "AI Agent Interoperability & Federation",
    "Resolve compatibility and data-sharing decisions when multiple AI agents from different domains, operators, or networks must collaborate. Decide protocols, translation layers, and data access policies that enable coordination while respecting trust boundaries, privacy, and regulatory constraints."
  ),
  (
    "T20",
    "Agent-to-Agent Communication Management",
    "Decide routing, QoS, and security policies for horizontal traffic between AI agents, whether direct or via network relays. Prioritize flows, select paths or slices, and enforce encryption or isolation as needed to meet latency, reliability, and security goals under dynamic network conditions."
  ),
  (
    "T21",
    "Device-Network Task Offload Arbitration",
    "Choose whether and how to offload AI or compute tasks from a device to edge, peer, or cloud resources given latency, energy, model capability, and trust constraints. Consider partial offloading, model selection, and fallback strategies, ensuring that decisions remain robust to network fluctuations."
  ),
  (
    "T22",
    "Federated / Collaborative Learning Orchestration",
    "Decide when and how to schedule federated or collaborative training and model updates across devices and edge nodes. Respect privacy, regulatory constraints, bandwidth limits, and device heterogeneity, and choose update frequencies and participant sets that balance model quality and resource usage."
  ),
  (
    "T23",
    "Network-Assisted Digital Twin Control",
    "Determine how the network should provide sensing, telemetry, and control channels to maintain accurate and actionable real-time digital twins. Decide update rates, data fidelity, and control loop configurations that keep the twin synchronized without overloading the network or violating SLAs."
  ),
  (
    "T24",
    "Sensing-Enhanced Decisioning (ISAC)",
    "Choose which sensing streams (e.g., radar, RF, vision, telemetry) and fusion strategies the network should deliver to agents for time-sensitive perception and decision tasks. Trade off sensing accuracy, bandwidth, latency, and robustness, and adapt the sensing configuration as conditions evolve."
  ),
  (
    "T25",
    "AI-Agent-based Disaster / Public-Safety Coordination",
    "Coordinate multiple AI agents and UAVs during disaster or public-safety scenarios, deciding slice allocation, sensing priorities, and escalation paths. Balance competing mission goals such as search and rescue, damage assessment, and communication support under extreme and uncertain conditions."
  ),
  (
    "T26",
    "Trust-Aware Third-Party Agent Exposure",
    "Decide what level of data, APIs, and compute resources to expose to third-party agents based on trust scores, regulation, and user consent. Enforce differentiated access and isolation policies, and adapt exposure in response to observed behavior, anomalies, or changing regulatory constraints."
  ),
  (
    "T27",
    "Agent Lifecycle & Management",
    "Decide lifecycle operations for agents, including instantiation, scaling, migration, upgrade, and retirement, under operator policy and SLA constraints. Coordinate these operations with network load, security posture, and mission demands to avoid service disruption or policy violations."
  ),
  (
    "T28",
    "6G Model Training-as-a-Service Decision",
    "Decide when to accept or reject customer requests for network-facilitated model training (e.g., LLM fine-tuning) given resource availability, privacy requirements, and QoS impact. Determine appropriate training configurations, isolation levels, and scheduling so that core network services remain protected."
  ),
  (
    "T29",
    "Immersive/AR Resource Prioritization",
    "Allocate slices and edge resources to multi-modal immersive or XR/AR sessions, balancing throughput, latency, stability, and fairness across users and applications. Handle contention by prioritizing critical interactions and gracefully degrading less critical modalities when resources are constrained."
  ),
  (
    "T30",
    "Network Security Detection & Response Automation",
    "When AI-driven monitoring flags potential attacks or anomalies, decide automated detection, isolation, mitigation, and recovery actions in the network. Balance swift containment with false-positive risk, and choose responses that preserve critical services and safety while minimizing collateral impact."
  ),
]

TASK_MAP: Dict[str, Tuple[str, str]] = {tid: (name, desc) for tid, name, desc in TASKS}


In [7]:
from dataclasses import dataclass
from typing import Any
@dataclass
class MCQQuestion:
    task_id: str
    task_name: str
    source_turn: int
    question: str
    options: Dict[str, str]
    correct: str
    reason: str
    rationale_tag: str
    difficulty: str

def get_turns(ep: Dict[str, Any]) -> List[Dict[str, Any]]:
    return ep.get('dialogue', [])

def extract_min_context(ep: Dict[str, Any]) -> Dict[str, Any]:
    init_state = ep.get('initial_state', {})
    env = init_state.get('env') or {}
    airspace = init_state.get('airspace') or {}
    uav = init_state.get('uav') or {}
    policy = init_state.get('policy') or {}
    sensors = (uav.get('sensors') or {})
    return {
        'env': env,
        'airspace': {'alt_bounds': airspace.get('alt_bounds'), 'geofence': airspace.get('geofence')},
        'uav': {'pose': uav.get('pose'), 'speed_mps': uav.get('speed_mps'), 'battery_pct': (uav.get('energy') or {}).get('battery_pct'), 'sensors': sensors, 'payloads': uav.get('payloads', [])},
        'policy': policy,
        'success': ep.get('success'),
    }

def summarize_episode_for_prompt(ep: Dict[str, Any], max_turns: int = 12) -> str:
    ctx = extract_min_context(ep)
    parts: List[str] = []
    parts.append('Initial context:')
    parts.append(f"- env: {ctx['env']}")
    parts.append(f"- airspace: {ctx['airspace']}")
    parts.append(f"- uav: {ctx['uav']}")
    parts.append(f"- policy: {ctx['policy']}")
    parts.append(f"- success: {ctx['success']}")
    parts.append('')
    parts.append('Dialogue trace (truncated):')
    for t in get_turns(ep)[:max_turns]:
        turn_no = t.get('turn')
        speaker = t.get('speaker')
        intent = t.get('intent')
        acts = t.get('actions', [])
        act_summaries = []
        for a in acts:
            if a.get('type') == 'mcp':
                arg_keys = list((a.get('args') or {}).keys())
                act_summaries.append(f"mcp:{a.get('name')} args_keys={arg_keys}")
            elif a.get('type') == 'a2a':
                act_summaries.append(f"a2a:{a.get('task')} to={a.get('to')}")
        obs = t.get('obs', [])
        obs_summaries = []
        for o in obs[:3]:
            if 'tool' in o:
                res = o.get('result', {})
                status = res.get('status') if isinstance(res, dict) else None
                obs_summaries.append(f"tool:{o.get('tool')} status={status}")
            elif 'task' in o:
                obs_summaries.append(f"a2a_resp:{o.get('task')} status={o.get('status')}")
        net = t.get('net') or {}
        net_s = { 'slice': net.get('slice'), 'lat_ms': net.get('lat_ms'), 'jitter_ms': net.get('jitter_ms'), 'loss_pct': net.get('loss_pct'), 'throughput_mbps': net.get('throughput_mbps'), 'edge_load': net.get('edge_load') }
        parts.append(f"- turn={turn_no} speaker={speaker} intent={intent} actions={act_summaries} obs={obs_summaries} net={net_s}")
    return "\n".join(parts)


In [8]:
def local_chat_debug(
    model: str,
    messages: List[Dict[str, str]],
    temperature: float = 0.2,
    max_tokens: int = 10000,
    seed: Optional[int] = None,
) -> Tuple[str, Optional[Dict[str, Any]]]:
    """
    Run a local HF model in chat style and return (completion_text, debug_info).
    """
    tokenizer, lm = get_local_model(model)

    if seed is not None:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    # Build a chat prompt
    # Prefer official chat template if available
    if hasattr(tokenizer, "apply_chat_template"):
        chat = [
            {"role": m.get("role", "user"), "content": m.get("content", "")}
            for m in messages
        ]
        prompt = tokenizer.apply_chat_template(
            chat,
            tokenize=False,
            add_generation_prompt=True,
        )
    else:
        # Fallback: simple system + user concatenation
        system_parts, user_parts = [], []
        for m in messages:
            if m["role"] == "system":
                system_parts.append(m["content"])
            else:
                user_parts.append(m["content"])
        system_text = "\n\n".join(system_parts).strip()
        user_text = "\n\n".join(user_parts).strip()
        prompt = f"{system_text}\n\n{user_text}" if system_text else user_text

    # Tokenize (truncate if needed to avoid context overflow)
    max_ctx = getattr(lm.config, "max_position_embeddings", 4096)
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max_ctx - max_tokens,
    ).to(lm.device)

    gen_kwargs = {
        "max_new_tokens": max_tokens,
        "do_sample": temperature > 0.0,
        "temperature": temperature,
        "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
    }

    with torch.no_grad():
        output_ids = lm.generate(**inputs, **gen_kwargs)

    full_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    # Strip prompt echo if present
    if full_text.startswith(prompt):
        completion = full_text[len(prompt):]
    else:
        completion = full_text

    debug_info = {
        "full_text": full_text,
        "prompt": prompt,
    }
    return completion.strip(), debug_info


def build_eval_messages(
    task_id: str,
    episode_summary: str,
    question: Dict[str, Any],
) -> List[Dict[str, str]]:
    task_name, task_def = TASK_MAP[task_id]
    q_text = question['question']
    options = question['options']

    user_content: List[str] = []
    user_content.append('TARGET TASK')
    user_content.append(f'TASK_ID: {task_id}')
    user_content.append(f'TASK_NAME: {task_name}')
    user_content.append(f'DEFINITION: {task_def}')
    user_content.append('')
    user_content.append('EPISODE SUMMARY')
    user_content.append(episode_summary)
    user_content.append('')
    user_content.append('QUESTION')
    user_content.append(q_text)
    user_content.append('')
    user_content.append('OPTIONS')
    for k in ['A', 'B', 'C', 'D']:
        user_content.append(f"{k}: {options[k]}")
    user_content.append('')
    user_content.append(
        'Respond ONLY with valid JSON of the form {"answer": "A"} where answer is one of "A", "B", "C", or "D".'
    )

    system_msg = (
        'You are an expert 6G network AI agent evaluator. '
        'You will answer a multiple-choice question (A/B/C/D) about a UAV mission episode. '
        'Use the episode context and the target task definition to choose the best answer. '
        'You MUST respond with a single JSON object of the form {"answer": "A"}.'
    )

    return [
        {'role': 'system', 'content': system_msg},
        {'role': 'user', 'content': '\n'.join(user_content)},
    ]


def parse_mcq_answer_from_json(raw) -> Optional[str]:
    if isinstance(raw, dict):
        obj = raw
    else:
        try:
            obj = json.loads(raw)
        except Exception:
            obj = None
    if isinstance(obj, dict):
        for key in ['answer', 'choice', 'label', 'option']:
            val = obj.get(key)
            if isinstance(val, str) and val.strip() in {'A', 'B', 'C', 'D'}:
                return val.strip()

    text = raw if isinstance(raw, str) else str(raw)
    m = re.search(r'"answer"\s*:\s*"([ABCD])"', text)
    if m:
        return m.group(1)
    m2 = re.search(r'\b([ABCD])\b', text)
    if m2:
        return m2.group(1)
    return None


In [9]:
def load_episode_mcq_pairs(data_dir: Path) -> List[Tuple[str, Dict[str, Any], Dict[str, Any]]]:
    episodes: Dict[str, Path] = {}
    mcqs: Dict[str, Path] = {}

    # Collect all episode and MCQ files
    for p in data_dir.glob('*.episode.json'):
        base = p.name.replace('.episode.json', '')
        episodes[base] = p

    for p in data_dir.glob('*.mcq.json'):
        base = p.name.replace('.mcq.json', '')
        mcqs[base] = p

    # Only keep keys that have both episode and mcq
    common_keys = sorted(set(episodes.keys()) & set(mcqs.keys()))
    pairs: List[Tuple[str, Dict[str, Any], Dict[str, Any]]] = []

    for key in common_keys:
        with episodes[key].open('r', encoding='utf-8') as f_ep:
            ep = json.load(f_ep)
        with mcqs[key].open('r', encoding='utf-8') as f_q:
            mcq = json.load(f_q)
        episode_id = mcq.get('episode_id', key)
        pairs.append((episode_id, ep, mcq))

    print(f'Found {len(pairs)} episode/MCQ pairs in {data_dir}')
    return pairs

def eval_model_on_dir(
    data_dir: Path,
    model: str,
    max_pairs: Optional[int] = None,
    seed_base: Optional[int] = 42,
) -> Dict[str, Any]:
    # KEEP using episode_mcq_pairs (unchanged)
    pairs = load_episode_mcq_pairs(data_dir)
    if max_pairs is not None:
        pairs = pairs[:max_pairs]
    # --- count total questions across all episodes ---
    total_questions_all = sum(
        len(mcq.get("questions", [])) for _, _, mcq in pairs
    )

    avg_q_per_episode = (
        total_questions_all / len(pairs) if pairs else 0.0
    )

    print(
        f"Total episodes: {len(pairs)} | "
        f"Total questions (all episodes): {total_questions_all} | "
        f"Avg questions/episode: {avg_q_per_episode:.2f}"
    )

    total = 0
    correct = 0
    per_task_stats: Dict[str, Dict[str, int]] = {}
    per_question_records: List[Dict[str, Any]] = []

    # Text progress bar (works well in VS Code)
    pbar = tqdm(
        total=len(pairs),
        desc="Eval episodes",
        unit="ep",
    )

    for episode_id, ep, mcq in pairs:
        episode_summary = summarize_episode_for_prompt(ep, max_turns=SUMMARY_MAX_TURNS)
        questions = mcq.get('questions', [])

        for q in questions:
            task_id = q['task_id']
            gold = q['correct']

            messages = build_eval_messages(task_id, episode_summary, q)

            seed = None
            if seed_base is not None:
                s = f"{seed_base}:{episode_id}:{task_id}".encode('utf-8')
                seed = int(hashlib.sha256(s).hexdigest()[:8], 16)

            raw_text, api_parsed = local_chat_debug(
                model=model,
                messages=messages,
                temperature=0.0,
                max_tokens=10000,  # adjust if you need longer answers
                seed=seed,
            )


            extracted_content = None
            if isinstance(api_parsed, dict):
                try:
                    extracted_content = api_parsed['choices'][0]['message']['content']
                except Exception:
                    extracted_content = None

            parse_input = extracted_content if extracted_content is not None else raw_text
            pred = parse_mcq_answer_from_json(parse_input)

            is_correct = int(pred == gold)
            total += 1
            correct += is_correct

            stats = per_task_stats.setdefault(task_id, {'total': 0, 'correct': 0})
            stats['total'] += 1
            stats['correct'] += is_correct

            per_question_records.append({
                'episode_id': episode_id,
                'task_id': task_id,
                'task_name': q.get('task_name'),
                'difficulty': q.get('difficulty'),
                'pred': pred,
                'gold': gold,
                'correct_flag': is_correct,
                'raw_model_output': raw_text,
                'api_parsed_json': api_parsed,
                'extracted_content': extracted_content,
                'question': q.get('question'),
                'options': q.get('options'),
                'source_turn': q.get('source_turn'),
            })

        # === REAL-TIME METRICS ===
        incorrect = total - correct
        current_acc = correct / total if total else 0.0

        pbar.update(1)
        pbar.set_postfix(
            q=total,
            ok=correct,
            wrong=incorrect,
            acc=f"{current_acc:.3f}",
        )

        tqdm.write(
            f"[running] episodes={pbar.n}/{pbar.total} | "
            f"questions={total} | overall_acc={current_acc:.3f}"
        )

    pbar.close()

    overall_acc = correct / total if total else 0.0
    per_task_acc = {
        tid: s['correct'] / s['total'] if s['total'] else 0.0
        for tid, s in per_task_stats.items()
    }

    results = {
        'model': model,
        'overall_accuracy': overall_acc,
        'per_task_accuracy': per_task_acc,
        'total_questions': total,
        'records': per_question_records,
    }
    return results


In [10]:
def print_task_question_counts(data_dir: Path) -> Dict[str, int]:
    """
    Print how many MCQ questions exist per task across all episodes.
    Returns a dict {task_id: count}.
    """
    pairs = load_episode_mcq_pairs(data_dir)

    task_counts: Dict[str, int] = {}
    total_questions = 0

    for _, _, mcq in pairs:
        for q in mcq.get("questions", []):
            task_id = q.get("task_id", "UNKNOWN")
            task_counts[task_id] = task_counts.get(task_id, 0) + 1
            total_questions += 1

    print("\nQuestion count per task:")
    print("-" * 60)

    for task_id, count in sorted(task_counts.items()):
        task_name, _ = TASK_MAP.get(task_id, ("?", ""))
        print(f"{task_id:>3} | {task_name:<45} | {count:5d}")

    print("-" * 60)
    print(f"TOTAL QUESTIONS: {total_questions}")

    return task_counts


In [11]:
task_counts = print_task_question_counts(DATA_DIR)


Found 488 episode/MCQ pairs in /content/6GBench_3k_Validated/mcq_questions_only

Question count per task:
------------------------------------------------------------
 T1 | Intent Feasibility Assessment                 |   115
T10 | SLA Violation Prediction                      |   115
T11 | Preemptive Autonomy Downgrade                 |   119
T12 | Conservative Continuation Decision            |   107
T13 | Swarm-Level Slice Negotiation                 |   128
T14 | Scheduler Reconfiguration Adaptation          |   107
T15 | Decision Consistency under Replanning         |   129
T16 | Network-Exposed Compute Marketplace           |   155
T17 | Network-Knowledge RAG Augmentation            |   154
T18 | AI Agent Identity & Onboarding                |   157
T19 | AI Agent Interoperability & Federation        |   129
 T2 | Intent Conflict Resolution                    |   113
T20 | Agent-to-Agent Communication Management       |   105
T21 | Device-Network Task Offload Arbitration       |

In [12]:
# DIAGNOSTIC: run one local request and inspect output
pairs = load_episode_mcq_pairs(DATA_DIR)
if len(pairs) == 0:
    raise RuntimeError(f'No episode/mcq pairs found in {DATA_DIR}')

episode_id, ep, mcq = pairs[0]
q = mcq['questions'][0]
task_id = q['task_id']

messages = build_eval_messages(
    task_id,
    summarize_episode_for_prompt(ep, SUMMARY_MAX_TURNS),
    q
)

print('--- Diagnostic: prompt length ---')
print(sum(len(m["content"]) for m in messages), "chars")

raw_text, debug_info = local_chat_debug(
    model=EVAL_MODEL,
    messages=messages,
    temperature=0.0,
    max_tokens=512,
    seed=42,
)

print('\n--- Completion (first 2000 chars) ---')
print(raw_text[:2000])

print('\n--- Parsed answer ---')
print("pred =", parse_mcq_answer_from_json(raw_text))


Found 488 episode/MCQ pairs in /content/6GBench_3k_Validated/mcq_questions_only
--- Diagnostic: prompt length ---
5654 chars
üîÑ Loading model locally: tiiuae/Falcon-H1-Tiny-90M-Instruct


config.json: 0.00B [00:00, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/182M [00:00<?, ?B/s]

The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)` is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d


Loading weights:   0%|          | 0/386 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/152 [00:00<?, ?B/s]

The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



--- Completion (first 2000 chars) ---
system
You are an expert 6G network AI agent evaluator. You will answer a multiple-choice question (A/B/C/D) about a UAV mission episode. Use the episode context and the target task definition to choose the best answer. You MUST respond with a single JSON object of the form {"answer": "A"}.

user
TARGET TASK
TASK_ID: T1
TASK_NAME: Intent Feasibility Assessment
DEFINITION: Given a mission and a 6G intent message, determine whether the intent is feasible under current and near-future network, environmental, and policy constraints. Identify the minimal safe adjustments (e.g., speed, route, slice, autonomy level, sensing configuration) needed to satisfy constraints while preserving mission objectives as much as possible.

EPISODE SUMMARY
Initial context:
- env: {'weather': 'good', 'wind_mps': 8.5, 'wind_dir_deg': 240.0}
- airspace: {'alt_bounds': [5.0, 120.0], 'geofence': [{'type': 'polygon', 'points': [[0.0, 0.0], [200.0, 0.0], [200.0, 150.0], [0.0, 

In [None]:
from pathlib import Path
import json
import sys
import time
from datetime import timedelta
from tqdm import tqdm

def format_seconds(sec: float) -> str:
    return str(timedelta(seconds=int(sec)))

EVAL_MODELS = [
    "tiiuae/Falcon-H1-Tiny-90M-Instruct",
    # "mistralai/Mistral-7B-Instruct-v0.3",
    # "meta-llama/Llama-3.1-8B-Instruct",
]

TXT_LOG_PATH = Path(
    "/content/6gbench_mcq_eval_cell_output_falcon.txt"
)

TIME_LOG_PATH = Path(
    "/content/6gbench_mcq_eval_timing_falcon.txt"
)

global_start_time = time.time()
per_model_times = {}

with TXT_LOG_PATH.open("a", encoding="utf-8") as log_file:
    old_stdout = sys.stdout
    sys.stdout = log_file  # redirect ALL prints to file

    try:
        for EVAL_MODEL in tqdm(EVAL_MODELS, desc="Evaluating models"):

            model_start_time = time.time()

            print("\n" + "=" * 80)
            print(f"Evaluating model: {EVAL_MODEL}")
            print("=" * 80)

            # Run evaluation (all its prints are captured)
            results = eval_model_on_dir(
                DATA_DIR,
                EVAL_MODEL,
                max_pairs=None,
                seed_base=42
            )

            print(f"Model: {results['model']}")
            print(f"Total questions: {results['total_questions']}")
            print(f"Overall accuracy: {results['overall_accuracy']:.3f}")

            print("\nPer-task accuracy:")
            for tid, acc in sorted(results['per_task_accuracy'].items()):
                tname, _ = TASK_MAP.get(tid, ('?', ''))
                print(f"  {tid} ({tname}): {acc:.3f}")

            recs = results['records']
            none_preds = sum(1 for r in recs if r['pred'] is None)
            print(
                f"\nParsed predictions for {len(recs) - none_preds}/{len(recs)} questions."
            )

            # ----- Save global results -----
            safe_model_name = EVAL_MODEL.replace("/", "_")
            out_path = Path(f"6gbench_mcq_eval_results_{safe_model_name}_debug.json")

            with out_path.open('w', encoding='utf-8') as f:
                json.dump(results, f, indent=2, ensure_ascii=False)

            print(f'Saved detailed global results to {out_path.resolve()}')

            # ----- Save per-task results -----
            per_task_dir = Path(
                "/content/mcq_eval_per_task_debug"
            )
            per_task_dir.mkdir(exist_ok=True)

            per_task_records = {}
            for r in results['records']:
                per_task_records.setdefault(r['task_id'], []).append(r)

            for tid, task_recs in per_task_records.items():
                tname, _ = TASK_MAP.get(tid, ('?', ''))
                safe_name = tname.replace(' ', '_').replace('/', '_')
                fname = per_task_dir / f"{safe_model_name}_{tid}_{safe_name}.json"

                with fname.open('w', encoding='utf-8') as f:
                    json.dump(task_recs, f, indent=2, ensure_ascii=False)

                print(f'Saved {len(task_recs)} records for {tid} ({tname}) to {fname}')

            # ----- Per-model timing -----
            model_elapsed = time.time() - model_start_time
            per_model_times[EVAL_MODEL] = model_elapsed

            print(
                f"\n‚è± Model runtime: {format_seconds(model_elapsed)} "
                f"({model_elapsed:.1f} seconds)"
            )

            # ----- NEW: unload this model and free GPU memory -----
            unload_model(EVAL_MODEL)
            print(f"üßπ Freed GPU memory for {EVAL_MODEL}")

    finally:
        sys.stdout = old_stdout  # always restore stdout

global_elapsed = time.time() - global_start_time

with TIME_LOG_PATH.open("w", encoding="utf-8") as f:
    f.write("6G-Bench MCQ Evaluation Timing Report\n")
    f.write("=" * 60 + "\n\n")

    for model, sec in per_model_times.items():
        f.write(f"Model: {model}\n")
        f.write(f"  Time: {format_seconds(sec)} ({sec:.1f} seconds)\n\n")

    f.write("-" * 60 + "\n")
    f.write(
        f"TOTAL EXPERIMENT TIME: "
        f"{format_seconds(global_elapsed)} "
        f"({global_elapsed:.1f} seconds)\n"
    )

print(f"‚úî Full cell output appended to {TXT_LOG_PATH.resolve()}")
print(f"‚úî Timing report saved to {TIME_LOG_PATH.resolve()}")


Evaluating models:   0%|          | 0/1 [00:00<?, ?it/s]
Eval episodes:   0%|          | 0/488 [00:00<?, ?ep/s][A
Eval episodes:   0%|          | 1/488 [14:35<118:28:26, 875.78s/ep][A

Evaluating models:   0%|          | 0/1 [14:36<?, ?it/s]
Eval episodes:   0%|          | 1/488 [14:35<118:28:26, 875.78s/ep, acc=0.222, ok=2, q=9, wrong=7][A
Eval episodes:   0%|          | 2/488 [17:44<63:40:03, 471.61s/ep, acc=0.222, ok=2, q=9, wrong=7] [A

Evaluating models:   0%|          | 0/1 [17:45<?, ?it/s]
Eval episodes:   0%|          | 2/488 [17:44<63:40:03, 471.61s/ep, acc=0.182, ok=2, q=11, wrong=9][A
Eval episodes:   1%|          | 3/488 [20:57<46:23:23, 344.34s/ep, acc=0.182, ok=2, q=11, wrong=9][A

Evaluating models:   0%|          | 0/1 [20:57<?, ?it/s]
Eval episodes:   1%|          | 3/488 [20:57<46:23:23, 344.34s/ep, acc=0.154, ok=2, q=13, wrong=11][A