{
  "nbformat": 4,
  "nbformat_minor": 4,
  "metadata": {
    "kernelspec": {
      "name": "python3",
      "language": "python",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python",
      "version": "3.12",
      "mimetype": "text/x-python",
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "pygments_lexer": "ipython3",
      "nbconvert_exporter": "python",
      "file_extension": ".py"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# NRED Exploration Notebook\n",
        "\n",
        "This notebook provides a quick interactive exploration of **NRED**:\n",
        "\n",
        "- Run NRED on simple prompts\n",
        "- Compare baseline vs reasoning-enhanced decoding\n",
        "- Inspect latent reasoning traces and consistency scores\n",
        "- Run a mini synthetic evaluation inside the notebook\n",
        "\n",
        "This notebook assumes it lives in `nred/notebooks/` and that the `nred` package is importable from the project root."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "execution_count": null,
      "outputs": [],
      "source": [
        "# 1. Environment setup\n",
        "\n",
        "import os\n",
        "import sys\n",
        "\n",
        "# Add project root to path if needed (so `import nred` works)\n",
        "root_dir = os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n",
        "if root_dir not in sys.path:\n",
        "    sys.path.append(root_dir)\n",
        "\n",
        "print(\"Project root:\", root_dir)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "execution_count": null,
      "outputs": [],
      "source": [
        "# 2. Import NRED core\n",
        "\n",
        "import torch\n",
        "from nred import NRED, baseline_decode, reasoning_decode\n",
        "\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "device"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Quick demo: baseline vs NRED\n",
        "\n",
        "We start with a few simple arithmetic / reasoning prompts and compare:\n",
        "\n",
        "- **Baseline** decoding (plain TinyLlama)\n",
        "- **NRED** decoding (latent reasoning + consistency + fallback)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "execution_count": null,
      "outputs": [],
      "source": [
        "def run_demo(prompts):\n",
        "    for i, p in enumerate(prompts, 1):\n",
        "        print(\"\\n\" + \"=\"*40)\n",
        "        print(f\"Example {i}\")\n",
        "        print(\"Prompt:\", p)\n",
        "\n",
        "        out = NRED(p)\n",
        "        print(\"\\n[Baseline]\")\n",
        "        print(out[\"baseline\"])\n",
        "\n",
        "        print(\"\\n[NRED output]\")\n",
        "        print(out[\"output\"])\n",
        "\n",
        "        print(\"\\n[Latent reasoning]\")\n",
        "        print(out[\"latent\"])\n",
        "\n",
        "        print(\"\\nScore:\", round(out[\"score\"], 3), \"Mode:\", out[\"mode\"])"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "execution_count": null,
      "outputs": [],
      "source": [
        "prompts = [\n",
        "    \"What is 17 + 28?\",\n",
        "    \"Tom has 5 apples, buys 7 more, then gives away 3. How many apples does he have now?\",\n",
        "    \"Is 143 an even number or an odd number?\",\n",
        "    \"If a bus has 32 seats and 18 are occupied, how many seats are free?\"\n",
        "]\n",
        "\n",
        "run_demo(prompts)"
      ]
    },
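    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The per-example printout above can also be condensed into a quick summary.\n",
        "The next cell is a minimal sketch, not part of the `nred` package: `summarize_runs`\n",
        "is a hypothetical helper, and it assumes only what `run_demo` already relies on,\n",
        "namely that `NRED(p)` returns a dict with numeric `score` and `mode` entries.\n",
        "Note that it runs the model again for every prompt."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "execution_count": null,
      "outputs": [],
      "source": [
        "from collections import Counter\n",
        "\n",
        "# Hypothetical helper (not part of nred): tally modes and average scores.\n",
        "# Assumes nred_fn(prompt) returns a dict with \"score\" and \"mode\" keys,\n",
        "# as used by run_demo above.\n",
        "def summarize_runs(prompts, nred_fn):\n",
        "    results = [nred_fn(p) for p in prompts]\n",
        "    mode_counts = Counter(r[\"mode\"] for r in results)\n",
        "    # Guard against an empty prompt list to avoid dividing by zero.\n",
        "    if results:\n",
        "        mean_score = sum(r[\"score\"] for r in results) / len(results)\n",
        "    else:\n",
        "        mean_score = float(\"nan\")\n",
        "    return mode_counts, mean_score\n",
        "\n",
        "mode_counts, mean_score = summarize_runs(prompts, NRED)\n",
        "print(\"Mode counts:\", dict(mode_counts))\n",
        "print(\"Mean score:\", round(mean_score, 3))"
      ]
    },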
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Inspect a single call to `reasoning_decode`\n",
        "\n",
        "Here we call `reasoning_decode` directly to see the **latent reasoning trace**\n",
        "and consistency score before any fallback is applied."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "execution_count": null,
      "outputs": [],
      "source": [
        "prompt = \"Compute: (12 + 7) - 5 = ?\"\n",
        "latent, score, latent_again, mode = reasoning_decode(prompt)\n",
        "\n",
        "print(\"Prompt:\", prompt)\n",
        "print(\"\\nLatent reasoning:\")\n",
        "print(latent)\n",
        "print(\"\\nScore:\", round(score, 3))\n",
        "print(\"Mode:\", mode)"
      ]
    },
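    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To keep traces around for later error analysis, the values returned by\n",
        "`reasoning_decode` can be written to a JSON Lines file. This is a sketch under\n",
        "stated assumptions: `save_traces` and the `latent_traces.jsonl` filename are\n",
        "hypothetical, and the 4-tuple return shape is taken from the call in the previous cell.\n",
        "The latent trace and mode are stored as strings, and the score is cast to `float`,\n",
        "so the record stays JSON-serializable whatever their concrete types are."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "execution_count": null,
      "outputs": [],
      "source": [
        "import json\n",
        "\n",
        "# Hypothetical helper (not part of nred): write one JSON object per prompt.\n",
        "# Assumes decode_fn(prompt) returns the 4-tuple unpacked in the previous cell.\n",
        "def save_traces(prompts, decode_fn, path=\"latent_traces.jsonl\"):\n",
        "    with open(path, \"w\", encoding=\"utf-8\") as f:\n",
        "        for p in prompts:\n",
        "            latent, score, _, mode = decode_fn(p)\n",
        "            record = {\n",
        "                \"prompt\": p,\n",
        "                \"latent\": str(latent),\n",
        "                \"score\": float(score),\n",
        "                \"mode\": str(mode),\n",
        "            }\n",
        "            f.write(json.dumps(record) + \"\\n\")\n",
        "    return path\n",
        "\n",
        "save_traces([prompt], reasoning_decode)"
      ]
    },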
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Synthetic tasks inside the notebook\n",
        "\n",
        "We re-use the experiment utilities from `nred/experiments` to run a small\n",
        "evaluation interactively."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "execution_count": null,
      "outputs": [],
      "source": [
        "from nred.experiments.synthetic_tasks import synthetic_parity, synthetic_arithmetic\n",
        "from nred.experiments.evaluation import evaluate_tasks, evaluate_gsm8k\n",
        "\n",
        "parity_tasks = synthetic_parity(n=50)\n",
        "arith_tasks = synthetic_arithmetic(n=50)\n",
        "\n",
        "print(\"Parity – Baseline:\", evaluate_tasks(parity_tasks, use_nred=False))\n",
        "print(\"Parity – NRED:\", evaluate_tasks(parity_tasks, use_nred=True))\n",
        "\n",
        "print(\"\\nArithmetic – Baseline:\", evaluate_tasks(arith_tasks, use_nred=False))\n",
        "print(\"Arithmetic – NRED:\", evaluate_tasks(arith_tasks, use_nred=True))\n",
        "\n",
        "print(\"\\nGSM8K-mini – Baseline:\", evaluate_gsm8k(use_nred=False))\n",
        "print(\"GSM8K-mini – NRED:\", evaluate_gsm8k(use_nred=True))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Ablation sanity check\n",
        "\n",
        "We can also import the ablation helper and run a tiny arithmetic subset to\n",
        "see how turning off different parts of NRED affects performance."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "execution_count": null,
      "outputs": [],
      "source": [
        "from nred.experiments.evaluation import evaluate_ablation\n",
        "from nred.experiments.synthetic_tasks import synthetic_arithmetic\n",
        "\n",
        "small_arith = synthetic_arithmetic(n=30)\n",
        "\n",
        "settings = {\n",
        "    \"full\": {},\n",
        "    \"no_latent\": {\"disable_latent\": True},\n",
        "    \"no_consistency\": {\"disable_consistency\": True},\n",
        "    \"no_fallback\": {\"disable_fallback\": True}\n",
        "}\n",
        "\n",
        "for name, cfg in settings.items():\n",
        "    acc = evaluate_ablation(small_arith, cfg)\n",
        "    print(f\"{name}: {acc:.3f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Notes\n",
        "\n",
        "- This notebook is meant for **qualitative inspection** and small runs.\n",
        "- For full experiments and exact numbers used in the report, run:\n",
        "  - `python nred/experiments/evaluation.py`\n",
        "- You can extend this notebook with more prompts, plots, or logging\n",
        "  (e.g., saving latent traces for error analysis)."
      ]
    }
  ]
}

