# TNAD Server Playground

Persona: Senior AI Researcher + Staff Software Engineer. This notebook wires the tensor-network augmented decoding (TNAD) stack into an interactive workflow that you can run on a high-memory server.

## Environment Preparation

1. (Optional) Create and activate a virtual environment.
   ```bash
   python3 -m venv tnad-env
   source tnad-env/bin/activate
   ```
2. Install the Python dependencies (and `jupyter` itself if the server does not already provide it).
   ```bash
   pip install -r requirements.txt
   pip install jupyter  # only if Jupyter is not already installed
   ```
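3. From the repository root (the first cell refuses to run anywhere else), launch the notebook server on the remote machine and forward its port to your laptop over SSH, for example:
   ```bash
   # on the server
   jupyter notebook --no-browser --port 8888
   # on your laptop
   ssh -N -L 8888:localhost:8888 user@server
   ```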

Once the runtime is ready, execute the cells below top to bottom.

In [None]:
# Bootstrap: locate the repository, make its packages importable, and set
# runtime flags before torch is imported in the next cell.
import os
import sys
from pathlib import Path

PROJECT_ROOT = Path.cwd()
# `tnad` and `experiments` are imported as top-level packages further down, so
# the kernel must have been started from the repository root.
if not (PROJECT_ROOT / "tnad").exists():
    raise RuntimeError("Run this notebook from the quantum-search-llm repository root.")

# Prepend the root only when it is missing, so re-running this cell does not
# pile up duplicate sys.path entries.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# MPS on Apple Silicon can throw conservative OOMs; this mirrors the CLI guidance.
# A ratio of 0.0 disables the MPS allocator's high-watermark limit (PyTorch warns
# this can lead to system-wide OOM). A value already exported by the user wins.
if "PYTORCH_MPS_HIGH_WATERMARK_RATIO" not in os.environ:
    os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

print(f"Project root: {PROJECT_ROOT}")
print(f"Python executable: {sys.executable}")

In [None]:
# Core imports: Hugging Face model loader, TNAD primitives, evaluation helpers.
# The `tnad` and `experiments` imports resolve because the bootstrap cell put
# the repository root on sys.path.
# NOTE(review): `math` and `yaml` are not referenced by any other cell in this
# notebook as shown; confirm nothing depends on them before removing.
import math
from typing import Any, Dict, List, Optional

import torch
import yaml
from datasets import load_dataset

# TNAD primitives: the FGBS decoder, the MPS container rebuilt in the coherence
# inspection cell, and the Schmidt-spectrum analysis used there.
from tnad import FidelityGuidedBeamSearcher
from tnad.mps_manager import MPSSequence
from tnad.coherence_score import analyze_coherence_spectrum
from tnad.utils import get_device

# Reuse the experiment scripts' loader, prompt formatters and answer extractors
# so notebook results are scored the same way as the CLI runs.
from experiments.reproduce_paper_results import load_model_and_tokenizer
from experiments.run_gsm8k import extract_answer_from_text as extract_gsm8k_answer
from experiments.run_gsm8k import format_prompt as format_gsm8k_prompt
from experiments.run_strategyqa import extract_yes_no_answer
from experiments.run_strategyqa import format_prompt as format_strategyqa_prompt
from experiments.run_strategyqa import load_strategyqa_dataset

print(f"PyTorch version: {torch.__version__}")

In [None]:
# Model + decoding defaults; tweak these to fit server capacity.
MODEL_NAME = "microsoft/phi-2"  # a larger model is fine if the server has the memory
USE_8BIT = False  # enable only when bitsandbytes and GPU 8-bit inference are available

# Constructor keyword arguments for FidelityGuidedBeamSearcher. Every key is
# forwarded as-is below, so keys must match the constructor's parameter names.
FGBS_CONFIG: Dict[str, Any] = dict(
    beam_width=3,
    alpha=0.5,
    bond_dim=16,
    top_k=40,
    temperature=1.0,
    normalize_embeddings=True,
)

# Default keyword arguments for searcher.generate(). The demo cell unpacks the
# whole dict; the GSM8K helper reads only "min_length" from it.
GENERATION_CONFIG: Dict[str, Any] = dict(
    max_length=256,
    min_length=8,
    return_details=True,
    show_progress=False,
)

model, tokenizer = load_model_and_tokenizer(
    model_name=MODEL_NAME,
    device="auto",
    torch_dtype_name="float16",
    load_in_8bit=USE_8BIT,
)

searcher = FidelityGuidedBeamSearcher(model=model, tokenizer=tokenizer, **FGBS_CONFIG)

# NOTE(review): get_device() is queried independently of the device="auto"
# placement above; confirm it reports where the model weights actually live.
device_info = get_device()
print(f"Active device: {device_info}")

In [None]:
# Smoke test: decode one short reasoning prompt with the default
# GENERATION_CONFIG, then show the generated text and the scores FGBS reported.
demo_prompt = "Q: If A > B and B > C, what can we conclude about A and C?\nA:"
demo_result = searcher.generate(
    demo_prompt,
    **GENERATION_CONFIG,
)

print("Generated text:\n", demo_result["text"])

# The score keys below are presumably only present because GENERATION_CONFIG
# requests return_details=True — confirm against searcher.generate().
print("\nScores:")
print(f"  log_prob      : {demo_result['log_prob']:.4f}")
print(f"  log_cfs       : {demo_result['log_cfs']:.4f}")
print(f"  composite_score: {demo_result['composite_score']:.4f}")

In [None]:
# Helper routines for batched evaluation on GSM8K (math) and StrategyQA (yes/no).
def evaluate_gsm8k_subset(
    searcher: FidelityGuidedBeamSearcher,
    num_examples: int = 5,
    split: str = "test",
    prompt_template: str = "Q: {question}\nA: Let's think step by step.",
    max_length: int = 256,
) -> Dict[str, Any]:
    """Decode a slice of GSM8K with FGBS and report exact-match accuracy.

    Args:
        searcher: Configured FGBS decoder used for every example.
        num_examples: How many leading examples of ``split`` to evaluate; a
            value <= 0 evaluates the entire split.
        split: Split of the Hugging Face ``gsm8k`` dataset (``main`` config).
        prompt_template: Template passed to ``format_gsm8k_prompt`` together
            with each question.
        max_length: Forwarded to ``searcher.generate``.

    Returns:
        A dict with the ``dataset`` label, the number of examples evaluated
        (``num_examples``), ``accuracy`` as a fraction in [0, 1] (0.0 when
        nothing was evaluated), and per-example ``records``.
    """
    dataset = load_dataset("gsm8k", "main", split=split)
    if num_examples > 0:
        dataset = dataset.select(range(min(num_examples, len(dataset))))

    # Shared with the demo cell's GENERATION_CONFIG; read once, not per example.
    min_length = GENERATION_CONFIG["min_length"]

    rows: List[Dict[str, Any]] = []
    correct = 0
    for idx, sample in enumerate(dataset):
        prompt = format_gsm8k_prompt(sample["question"], prompt_template)
        result = searcher.generate(
            prompt,
            max_length=max_length,
            min_length=min_length,
            return_details=False,
            show_progress=False,
        )
        # The same extractor parses the model output and the reference
        # solution, so both answers go through identical normalisation.
        predicted = extract_gsm8k_answer(result["text"])
        gold = extract_gsm8k_answer(sample["answer"])
        # A failed extraction must never count as a hit. Without this guard an
        # unparsable prediction "matches" an unparsable gold answer
        # (e.g. None == None) and silently inflates accuracy.
        # NOTE(review): assumes the extractor signals failure with None or "".
        is_correct = predicted not in (None, "") and predicted == gold
        if is_correct:
            correct += 1
        rows.append(
            {
                "example_id": idx,
                "question": sample["question"],
                "gold_answer": gold,
                "predicted_answer": predicted,
                "correct": is_correct,
                "generated_text": result["text"],
            }
        )

    accuracy = correct / len(rows) if rows else 0.0
    return {
        "dataset": f"gsm8k::{split}",
        "num_examples": len(rows),
        "accuracy": accuracy,
        "records": rows,
    }


def _parse_predicted_yes_no(value: Any) -> Optional[bool]:
    """Convert the output of ``extract_yes_no_answer`` into True/False/None.

    ``None`` (no answer found) stays ``None`` so it can never equal a gold
    label. Strings are compared case- and whitespace-insensitively with "yes",
    so "Yes" or " yes " are not mis-scored as "no".
    """
    if value is None:
        return None
    # NOTE(review): defensive in case the extractor ever returns a bool.
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() == "yes"


def evaluate_strategyqa_subset(
    searcher: FidelityGuidedBeamSearcher,
    num_examples: int = 5,
    split: str = "validation",
    prompt_template: str = "Question: {question}\nAnswer (yes/no):",
    max_length: int = 128,
) -> Dict[str, Any]:
    """Decode a slice of StrategyQA with FGBS and report yes/no accuracy.

    Args:
        searcher: Configured FGBS decoder used for every example.
        num_examples: How many leading examples to evaluate; a value <= 0
            evaluates everything the loader returns.
        split: Split name forwarded to ``load_strategyqa_dataset``.
        prompt_template: Template passed to ``format_strategyqa_prompt`` (and
            recorded in the loader config).
        max_length: Forwarded to ``searcher.generate``.

    Returns:
        A dict with the ``dataset`` label, the number of examples evaluated
        (``num_examples``), ``accuracy`` as a fraction in [0, 1] (0.0 when
        nothing was evaluated), and per-example ``records``.
    """
    # Config in the shape load_strategyqa_dataset expects. Presumably the hub
    # ids are tried before the local JSONL fallback — confirm in the loader.
    loader_config = {
        "dataset": {
            "split": split,
            "strategyqa": {
                "prompt_template": prompt_template,
                "hub_ids": [
                    "wics/strategy-qa",
                    "wics/strategyqa",
                    "strategy_qa",
                ],
                "local_path": "data/strategyqa_sample.jsonl",
            },
        },
        "generation": {
            "max_length": max_length,
            "min_length": 6,
            "return_details": False,
        },
    }

    dataset = load_strategyqa_dataset(loader_config)
    # NOTE(review): assumes the loader returns an object with .select() (e.g. a
    # Hugging Face Dataset), including when it falls back to the local file.
    if num_examples > 0:
        dataset = dataset.select(range(min(num_examples, len(dataset))))

    min_length = loader_config["generation"]["min_length"]

    rows: List[Dict[str, Any]] = []
    correct = 0
    for idx, sample in enumerate(dataset):
        prompt = format_strategyqa_prompt(sample["question"], prompt_template)
        result = searcher.generate(
            prompt,
            max_length=max_length,
            min_length=min_length,
            return_details=False,
            show_progress=False,
        )
        predicted_bool = _parse_predicted_yes_no(extract_yes_no_answer(result["text"]))

        # Gold labels may be strings ("yes"/"true"/"1") or bool-like values.
        gold_raw = sample["answer"]
        if isinstance(gold_raw, str):
            gold_bool = gold_raw.strip().lower() in {"yes", "true", "1"}
        else:
            gold_bool = bool(gold_raw)

        # gold_bool is always a bool, so an unparsed prediction (None) is wrong.
        is_correct = predicted_bool == gold_bool
        if is_correct:
            correct += 1
        rows.append(
            {
                "example_id": idx,
                "question": sample["question"],
                "gold_answer": gold_bool,
                "predicted_answer": predicted_bool,
                "generated_text": result["text"],
                "correct": is_correct,
            }
        )

    accuracy = correct / len(rows) if rows else 0.0
    return {
        "dataset": f"strategyqa::{split}",
        "num_examples": len(rows),
        "accuracy": accuracy,
        "records": rows,
    }


print("Helper functions ready; call them in the next cell to benchmark subsets.")

In [None]:
# Benchmark tiny subsets first (3 examples each); raise num_examples once the
# server has proven stable.
gsm8k_metrics = evaluate_gsm8k_subset(searcher=searcher, num_examples=3)
strategyqa_metrics = evaluate_strategyqa_subset(searcher=searcher, num_examples=3)

# "accuracy" is a fraction in [0, 1]; scale it to a percentage for display.
print("GSM8K subset accuracy:", f"{gsm8k_metrics['accuracy']*100:.1f}%")
print("StrategyQA subset accuracy:", f"{strategyqa_metrics['accuracy']*100:.1f}%")

In [None]:
# Inspect coherence dynamics for the last decoded GSM8K trace by re-embedding
# its text and rebuilding an MPS with the same bond dimension FGBS used.
if gsm8k_metrics["records"]:
    full_text = gsm8k_metrics["records"][-1]["generated_text"]
    mps = MPSSequence(bond_dim=FGBS_CONFIG["bond_dim"], embedding_dim=searcher.embedding_dim)
    with torch.no_grad():
        token_ids = tokenizer.encode(full_text, return_tensors="pt")[0].to(searcher.device)
        # The model is loaded with torch_dtype_name="float16"; half-precision
        # tensors are unsupported by many linear-algebra kernels (e.g. SVD), so
        # analyse the spectrum in float32.
        embeddings = searcher.embedding_layer(token_ids).float()
        # Honour the searcher's normalize_embeddings setting so this spectrum
        # reflects what FGBS actually scored.
        # NOTE(review): assumes the searcher applies per-token L2 normalisation.
        # L2 normalisation is idempotent, so this is harmless if MPSSequence
        # also normalises internally.
        if FGBS_CONFIG["normalize_embeddings"]:
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
        for embedding in embeddings:
            mps.add_token(embedding)
    spectrum = mps.get_schmidt_values()
    coherence_report = analyze_coherence_spectrum(spectrum)
    print("Schmidt spectrum (first five values):", spectrum[:5])
    print("Coherence fidelity:", coherence_report["cfs"])
else:
    print("No GSM8K records available for coherence inspection.")

In [None]:
# Optional cleanup before another experiment: collect unreachable Python objects,
# then hand cached-but-unused allocator blocks back to the device. Tensors that
# are still referenced (e.g. `model`, `searcher`) stay resident — `del` them
# first if their memory must actually be freed.
import gc

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
# Apple Silicon: torch.mps.empty_cache() only exists on newer PyTorch releases,
# so probe for it instead of assuming it is there.
mps_module = getattr(torch, "mps", None)
if (
    getattr(torch.backends, "mps", None) is not None
    and torch.backends.mps.is_available()
    and hasattr(mps_module, "empty_cache")
):
    mps_module.empty_cache()
print("Memory cache cleared (if CUDA or MPS was active).")