# Optimal stopping notebook patch (from `optimal-stopping.json`)

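This notebook **contains a unified diff (patch)** that updates `model-training + Behav Cloning.ipynb`. The patch adds a `LOW_RAM_MODE` switch with low-memory defaults (a smaller 1B model, optional 4-bit loading, batch size 1 with gradient accumulation and gradient checkpointing, and truncation to `MAX_SEQ_LEN` tokens), a short "what/why" comment at the top of each code cell, and a smaller generation batch size.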

## How to use in your GitHub repo

1. Put this notebook in your repo under `notebooks/` (or anywhere).
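2. Run the first code cell. It writes the patch to `optimal-stopping.patch` in the notebook's working directory (`notebooks/` if you run the notebook from there).
3. From your repo root, apply it (adjust the path below if the patch was written somewhere else):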

```bash
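# Optional dry run: reports problems without modifying any files.
git apply --check notebooks/optimal-stopping.patch
# Apply the patch.
git apply notebooks/optimal-stopping.patch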
```

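If the target notebook lives at a different path in your repo, edit the patch headers (the `diff --git`, `---`, and `+++` lines) so that they point to that path. Leave the hunks themselves unchanged.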
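
Editing the headers by hand is easy to get wrong, so here is a minimal sketch of doing it programmatically. It assumes the patch touches a single file and that only the path needs to change. The helper name `retarget_patch` and the sample paths are hypothetical and are not part of the original notebook.

```python
def retarget_patch(patch_text: str, old_path: str, new_path: str) -> str:
    """Replace old_path with new_path in the diff header lines only.

    Hunk bodies (everything after the first "@@" line of a file section)
    are left untouched, so context lines that mention the path still match.
    """
    out = []
    in_header = True
    for line in patch_text.splitlines(keepends=True):
        if line.startswith("diff --git "):
            in_header = True   # a new file section starts here
        elif line.startswith("@@"):
            in_header = False  # hunks begin; stop rewriting
        if in_header and line.startswith(("diff --git ", "--- ", "+++ ")):
            line = line.replace(old_path, new_path)
        out.append(line)
    return "".join(out)


# Tiny illustrative patch (hypothetical paths, not the real notebook).
sample = (
    "diff --git a/notebooks/nb.ipynb b/notebooks/nb.ipynb\n"
    "--- a/notebooks/nb.ipynb\t\n"
    "+++ b/notebooks/nb.ipynb\t\n"
    "@@ -1,2 +1,2 @@\n"
    " context mentioning notebooks/nb.ipynb\n"
    "-old\n"
    "+new\n"
)
retargeted = retarget_patch(sample, "notebooks/nb.ipynb", "experiments/nb.ipynb")
print(retargeted)
```

To use it on the real patch, read `optimal-stopping.patch`, pass `"notebooks/model-training + Behav Cloning.ipynb"` as `old_path` and your own path as `new_path`, and write the result back before running `git apply`. The trailing tab that git emits after file names containing spaces is preserved, because only the path substring is replaced.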

---

## Patch content


In [None]:
# Writes the patch to disk so you can apply it with `git apply`.
from pathlib import Path

patch_path = Path("optimal-stopping.patch")
patch_path.write_text(r'''diff --git a/notebooks/model-training + Behav Cloning.ipynb b/notebooks/model-training + Behav Cloning.ipynb
index ac4b69a5e9d865f3f2baa48862ee346b15160954..1a36d43c67f162e91a1ece9314db97c1f914989d 100644
--- a/notebooks/model-training + Behav Cloning.ipynb	
+++ b/notebooks/model-training + Behav Cloning.ipynb	
@@ -17,69 +17,80 @@
    "source": [
     "## Imports\n",
     "\n",
     "Before running this notebook locally, you need to [install PyTorch](https://pytorch.org/get-started/locally/) for your hardware.\n",
     "\n",
     "Then, you need to install the following packages:\n",
     "\n",
     "   * transformers\n",
     "   * datasets\n",
     "   * accelerate\n",
     "   * pandas\n",
     "   * huggingface_hub (needed for Llama models)\n",
     "   * scikit-learn\n",
     "   * numpy\n",
     "\n",
     "You an also use the `requirements.txt` in the [stopping-agents](https://github.com/emaadmanzoor/stopping-agents) repository."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "id": "cf5fb535",
    "metadata": {},
    "outputs": [],
    "source": [
+    "# What this block does: Imports libraries and defines global experiment/config constants used across the notebook.\n",
+    "# Why we choose this setup: Keeping core knobs in one place makes Colab/laptop tuning explicit and reduces trial-and-error when memory is limited.\n",
+    "\n",
     "import datasets\n",
+    "import gc\n",
     "import huggingface_hub # needed for Llama models\n",
     "import math\n",
     "import numpy as np\n",
     "import pandas as pd\n",
     "import torch\n",
     "import transformers\n",
     "\n",
     "from sklearn.model_selection import train_test_split\n",
     "from sklearn.metrics import roc_auc_score\n",
     "from tqdm.auto import tqdm\n",
     "\n",
     "HF_TOKEN = \"HF_TOKEN\"\n",
     "\n",
     "COST_PER_UNIT_TIME = 0.1\n",
     "BENEFIT_PER_POSITIVE_OUTCOME = 10.0\n",
     "DECISION_OPPORTUNITIES = [45, 60] # time in seconds at which the \n",
     "                                  # agent can decide to quit or wait\n",
-    "                                  # code is tailored to just 2 right now"
+    "                                  # code is tailored to just 2 right now\n",
+    "\n",
+    "LOW_RAM_MODE = True  # Recommended for Colab/local laptops\n",
+    "MAX_SEQ_LEN = 768 if LOW_RAM_MODE else 1024\n",
+    "TRAIN_BATCH_SIZE = 1 if LOW_RAM_MODE else 12\n",
+    "EVAL_BATCH_SIZE = 1 if LOW_RAM_MODE else 12\n",
+    "GRAD_ACCUM_STEPS = 16 if LOW_RAM_MODE else 1\n",
+    "GEN_BATCH_SIZE = 16 if LOW_RAM_MODE else 72\n"
    ]
   },
   {
    "cell_type": "markdown",
    "id": "931864dc",
    "metadata": {},
    "source": [
     "## Load and process conversation data\n",
     "\n",
     "We load a dataset of synthetic conversations available\n",
     "in the `datasets` folder at [https://github.com/emaadmanzoor/stopping-agents/](https://github.com/emaadmanzoor/stopping-agents/). This example dataset is formatted in the PyAnnote diarized conversation format."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 2,
    "id": "99c5e357",
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/html": [
        "<div>\n",
        "<style scoped>\n",
        "    .dataframe tbody tr th:only-of-type {\n",
@@ -175,89 +186,95 @@
        "2         20756_1           0        7.98     12.84   \n",
        "3         20756_1           1       13.14     15.50   \n",
        "4         20756_1           0       15.89     22.14   \n",
        "\n",
        "                                                text  outcome  is_sale  \\\n",
        "0  Hello, is this Mr. Harris? My name is Leah fro...  no sale        0   \n",
        "1  Yes, speaking. I’m alright, thanks. Can I ask ...  no sale        0   \n",
        "2  Of course, thanks for asking. I’m reaching out...  no sale        0   \n",
        "3        Alright… I guess I can listen for a minute.  no sale        0   \n",
        "4  Thank you! So, our new BrightSaver plan locks ...  no sale        0   \n",
        "\n",
        "   duration  \n",
        "0     62.07  \n",
        "1     62.07  \n",
        "2     62.07  \n",
        "3     62.07  \n",
        "4     62.07  "
       ]
      },
      "execution_count": 2,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
+    "# What this block does: Downloads the synthetic sales-call dataset and creates target/metadata columns.\n",
+    "# Why we choose this setup: We derive `is_sale` and call `duration` once up front so every later stage (splits, rewards, training labels) uses consistent ground truth.\n",
+    "\n",
     "dataset_url = \"https://raw.githubusercontent.com/emaadmanzoor/stopping-agents/refs/heads/main/datasets/synthetic_sales_conversations.csv?token=GHSAT0AAAAAADBUAD4WOA6XRF2GSIX5UC4Y2EEF66Q\"\n",
     "\n",
     "diarized_conversations = pd.read_csv(dataset_url)\n",
     "\n",
     "diarized_conversations[\"is_sale\"] =\\\n",
     "        diarized_conversations[\"outcome\"].apply(\n",
     "            lambda x: 1 if x == \"sale\" else 0 if x == \"no sale\" else np.nan)\n",
     "\n",
     "diarized_conversations[\"duration\"] =\\\n",
     "    diarized_conversations.groupby(\"conversation_id\")[\"end_time\"].transform(\"max\")\n",
     "\n",
     "diarized_conversations.head()"
    ]
   },
   {
    "cell_type": "markdown",
    "id": "cefd33a2",
    "metadata": {},
    "source": [
     "### Split into train, validation, and test conversations"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 3,
    "id": "a51e8881",
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "1903 train conversations.\n",
       "651 validation conversations.\n",
       "860 test conversations.\n"
      ]
     }
    ],
    "source": [
+    "# What this block does: Splits conversations into train/validation/test sets with label stratification.\n",
+    "# Why we choose this setup: Splitting by conversation_id avoids leakage across turns from the same call, and stratification preserves class balance in each split.\n",
+    "\n",
     "all_conversation_ids =\\\n",
     "    diarized_conversations[[\"conversation_id\", \"is_sale\"]].drop_duplicates()[\"conversation_id\"]\\\n",
     "        .values\n",
     "all_outcomes =\\\n",
     "    diarized_conversations[[\"conversation_id\", \"is_sale\"]].drop_duplicates()[\"is_sale\"].values\n",
     "    \n",
     "train_conversation_ids, test_conversation_ids, train_outcomes, test_outcomes =\\\n",
     "    train_test_split(all_conversation_ids, all_outcomes, test_size=0.25, random_state=42,\n",
     "                     stratify=all_outcomes)\n",
     "train_conversation_ids, val_conversation_ids, train_outcomes, val_outcomes =\\\n",
     "    train_test_split(train_conversation_ids, train_outcomes, test_size=0.25, random_state=42,\n",
     "                     stratify=train_outcomes)\n",
     "\n",
     "diarized_conversations_train =\\\n",
     "    diarized_conversations[diarized_conversations[\"conversation_id\"].isin(train_conversation_ids)]\n",
     "diarized_conversations_val =\\\n",
     "    diarized_conversations[diarized_conversations[\"conversation_id\"].isin(val_conversation_ids)]\n",
     "diarized_conversations_test =\\\n",
     "    diarized_conversations[diarized_conversations[\"conversation_id\"].isin(test_conversation_ids)]\n",
     "\n",
     "print(len(diarized_conversations_train), \"train conversations.\")\n",
     "print(len(diarized_conversations_val), \"validation conversations.\")\n",
     "print(len(diarized_conversations_test), \"test conversations.\")"
    ]
   },
@@ -364,50 +381,53 @@
        "2         92837_7     43.64        0   \n",
        "3         58241_9     50.01        0   \n",
        "4        20567_11     45.19        0   \n",
        "\n",
        "                               transcript_speaker_45  \\\n",
        "0  Speaker 0: Hello, is this Mr. Harris? My name ...   \n",
        "1  Speaker 0: Good afternoon! Is this Ms. Parker?...   \n",
        "2  Speaker 0: Good afternoon, may I speak with Mr...   \n",
        "3  Speaker 0: Hello, may I speak with Ms. Jenkins...   \n",
        "4  Speaker 0: Good afternoon, is this Mr. Carver?...   \n",
        "\n",
        "                               transcript_speaker_60  \n",
        "0  Speaker 0: Hello, is this Mr. Harris? My name ...  \n",
        "1  Speaker 0: Good afternoon! Is this Ms. Parker?...  \n",
        "2  Speaker 0: Good afternoon, may I speak with Mr...  \n",
        "3  Speaker 0: Hello, may I speak with Ms. Jenkins...  \n",
        "4  Speaker 0: Good afternoon, is this Mr. Carver?...  "
       ]
      },
      "execution_count": 4,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
+    "# What this block does: Builds per-conversation transcripts truncated at each decision time (m1, m2).\n",
+    "# Why we choose this setup: Behavioral cloning needs state snapshots at decision times, so we materialize both views once for reproducible feature creation.\n",
+    "\n",
     "m1, m2 = sorted(DECISION_OPPORTUNITIES)\n",
     "\n",
     "data_transcripts = {}\n",
     "for df, dftype in zip([diarized_conversations_train,\n",
     "                       diarized_conversations_val,\n",
     "                       diarized_conversations_test],\n",
     "                      [\"train\", \"val\", \"test\"]):\n",
     "    \n",
     "    data_transcripts[dftype] = df.copy()\n",
     "\n",
     "    data_transcripts[dftype][\"transcript\"] =\\\n",
     "        \"Speaker \" +\\\n",
     "        data_transcripts[dftype][\"speaker_id\"].astype(str) + \": \" +\\\n",
     "        data_transcripts[dftype][\"text\"]\n",
     "\n",
     "    transcripts = {}\n",
     "    for m in [m1, m2]: \n",
     "        transcripts[m] =\\\n",
     "            data_transcripts[dftype].loc[(data_transcripts[dftype][\"end_time\"]>=0) &\n",
     "                                         (data_transcripts[dftype][\"end_time\"]<m)]\\\n",
     "                    .groupby(\"conversation_id\")[\"transcript\"]\\\n",
     "                    .apply(lambda x: '\\n'.join(x))\\\n",
     "                    .reset_index(name=\"transcript_speaker_\" + str(m))\n",
     "\n",
     "    data_transcripts[dftype] = \\\n",
@@ -463,50 +483,53 @@
       "Speaker 0: There's no early cancellation fee and no long-term contract; you can opt out any time. We just want people to enjoy lower, predictable pricing with no risk.\n",
       "Speaker 1: To be honest, I just re-upped my plan last month. I don’t like to change stuff if it’s working.\n",
       "Will this call end in a sale (respond with 'yes' or 'no'):  \n",
       "\n",
       "Example state at 60 seconds:\n",
       "Below is the first 60 seconds of the sales call between the sales agent Speaker 0 and the customer Speaker 1:\n",
       "Speaker 0: Hello, is this Mr. Harris? My name is Leah from Sunview Energy—how are you today?\n",
       "Speaker 1: Yes, speaking. I’m alright, thanks. Can I ask what this is about?\n",
       "Speaker 0: Of course, thanks for asking. I’m reaching out because we’re offering a new energy plan that could qualify you for a 15% discount on your electric bill. I wanted to see if I could quickly tell you about it.\n",
       "Speaker 1: Alright… I guess I can listen for a minute.\n",
       "Speaker 0: Thank you! So, our new BrightSaver plan locks in your rate for twelve months—there’s no change in price based on the time of day, and there are no hidden fees. And for this month, you’d also get an automatic 15% off your supply charges.\n",
       "Speaker 1: Is this something I have to switch providers for? I’m pretty happy with who I have now.\n",
       "Speaker 0: You would stay connected to your local utility for service and repairs, but Sunview would handle the billing and supply. The switch is very simple and risk-free—if you change your mind, you can cancel within 30 days.\n",
       "Speaker 1: I see. Is there a contract or any penalties?\n",
       "Speaker 0: There's no early cancellation fee and no long-term contract; you can opt out any time. We just want people to enjoy lower, predictable pricing with no risk.\n",
       "Speaker 1: To be honest, I just re-upped my plan last month. I don’t like to change stuff if it’s working.\n",
       "Speaker 0: That’s completely understandable, Mr. Harris. Do you mind if I ask how much you’re paying per kilowatt-hour, just to make sure you’re on the best deal?\n",
       "Speaker 1: Actually, I’m not sure off the top of my head. I just check that the total seems right each month.\n",
       "Speaker 0: Totally fair. If you’re interested, I could email you a side-by-side comparison of our BrightSaver plan and your last bill—no obligation, just information.\n",
       "Speaker 1: No, that’s okay. If I decide to look into it, I’ll reach out myself.\n",
       "Will this call end in a sale (respond with 'yes' or 'no'):  \n"
      ]
     }
    ],
    "source": [
+    "# What this block does: Converts each transcript snapshot into a language-model prompt/state string.\n",
+    "# Why we choose this setup: A consistent prompt template reduces formatting variance and focuses the model on the yes/no sale-outcome prediction task.\n",
+    "\n",
     "def convert_to_state(transcript, t):\n",
     "    assert type(transcript) == str\n",
     "\n",
     "    state = \"Below is the first \" + str(t) +\\\n",
     "            \" seconds of the sales call between the sales agent Speaker 0 and\" +\\\n",
     "            \" the customer Speaker 1:\\n\" +\\\n",
     "            transcript + \"\\n\" +\\\n",
     "            \"Will this call end in a sale (respond with 'yes' or 'no'):  \"\n",
     "\n",
     "    return state\n",
     "\n",
     "for df in [data_transcripts[\"train\"],\n",
     "           data_transcripts[\"val\"],\n",
     "           data_transcripts[\"test\"]]:\n",
     "\n",
     "    for m in [m1, m2]:\n",
     "        df.loc[:, \"s\" + str(m)] = df.apply(lambda x:\n",
     "                                            convert_to_state(x[\"transcript_speaker_\" + str(m)],\n",
     "                                                             m), axis=1)\n",
     "\n",
     "print(\"Example state at\", m1, \"seconds:\")\n",
     "print(data_transcripts[\"train\"][\"s\" + str(m1)].values[0])\n",
     "print()\n",
     "print(\"Example state at\", m2, \"seconds:\")\n",
     "print(data_transcripts[\"train\"][\"s\" + str(m2)].values[0])"
@@ -620,50 +643,53 @@
        "</table>\n",
        "</div>"
       ],
       "text/plain": [
        "  conversation_id                                              state action  \\\n",
        "0         20756_1  Below is the first 45 seconds of the sales cal...     no   \n",
        "1         59321_6  Below is the first 45 seconds of the sales cal...     no   \n",
        "2         58241_9  Below is the first 45 seconds of the sales cal...     no   \n",
        "3        20567_11  Below is the first 45 seconds of the sales cal...     no   \n",
        "4        10523_13  Below is the first 45 seconds of the sales cal...     no   \n",
        "\n",
        "   is_sale  duration  \n",
        "0        0     62.07  \n",
        "1        0     55.37  \n",
        "2        0     50.01  \n",
        "3        0     45.19  \n",
        "4        0     65.04  "
       ]
      },
      "execution_count": 6,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
+    "# What this block does: Computes reward outcomes under each stopping policy path and derives optimal yes/no actions.\n",
+    "# Why we choose this setup: This converts the optimal stopping objective into supervised labels so we can train via behavioral cloning.\n",
+    "\n",
     "optimal_state_action_pairs = {}\n",
     "for dftype, df in data_transcripts.items():\n",
     "    df[\"rq\" + str(m1)] = -m1 * COST_PER_UNIT_TIME # stop at 30\n",
     "\n",
     "    # continue at 30, stop at 60\n",
     "    df[\"rq\" + str(m2)] = df[\"is_sale\"].astype(int)\\\n",
     "                        * BENEFIT_PER_POSITIVE_OUTCOME\\\n",
     "                        * (df[\"duration\"]<=m2).astype(int) \\\n",
     "                        - df[\"duration\"].apply(lambda x: min(m2, x)) * COST_PER_UNIT_TIME\n",
     "    \n",
     "    # continue at 30, continue at 60, continue at 90 = never quit                \n",
     "    df[\"rc\" + str(m2)] = df[\"is_sale\"].astype(int) * BENEFIT_PER_POSITIVE_OUTCOME\\\n",
     "                        - df[\"duration\"] * COST_PER_UNIT_TIME\n",
     "\n",
     "    df[\"max_reward\"] = df[[\"rq\" + str(m1), \"rq\" + str(m2), \"rc\" + str(m2)]].max(axis=1)\n",
     "\n",
     "    # optimal to quit at 30\n",
     "    df.loc[df[\"max_reward\"]==df[\"rq\" + str(m1)], \"a\" + str(m1)] = \"no\"\n",
     "    df.loc[df[\"max_reward\"]==df[\"rq\" + str(m1)], \"a\" + str(m2)] = \"no\"\n",
     "\n",
     "    # optimal to continue at 30, stop at 60\n",
     "    df.loc[df[\"max_reward\"]==df[\"rq\" + str(m2)], \"a\"  + str(m1)] = \"yes\"\n",
     "    df.loc[df[\"max_reward\"]==df[\"rq\" + str(m2)], \"a\"  + str(m2)] = \"no\"\n",
     "\n",
     "    # optimal to continue at 30, continue at 60, continue at 90\n",
@@ -718,61 +744,83 @@
     "### Load base model"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 7,
    "id": "db06d469",
    "metadata": {},
    "outputs": [
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
        "model_id": "7aa0faf14d8d43bcaea305837ff480f4",
        "version_major": 2,
        "version_minor": 0
       },
       "text/plain": [
        "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
       ]
      },
      "metadata": {},
      "output_type": "display_data"
     }
    ],
    "source": [
+    "# What this block does: Authenticates to HF, loads tokenizer/model, and applies low-memory loading paths when possible.\n",
+    "# Why we choose this setup: 1B + optional 4-bit quantization is substantially more stable on Colab/laptop hardware while preserving the same notebook workflow.\n",
+    "\n",
     "huggingface_hub.login(token=HF_TOKEN)\n",
     "\n",
-    "model_name = \"meta-llama/Llama-3.2-3B\" # base model, not instruction-tuned\n",
+    "# 1B is much easier to run on Colab/laptops than 3B.\n",
+    "model_name = \"meta-llama/Llama-3.2-1B\" if LOW_RAM_MODE else \"meta-llama/Llama-3.2-3B\"\n",
     "tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
-    "model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=\"auto\")\n",
     "tokenizer.pad_token_id = tokenizer.eos_token_id\n",
     "tokenizer.padding_side = 'left'\n",
     "\n",
+    "if torch.cuda.is_available() and LOW_RAM_MODE:\n",
+    "    bnb_config = transformers.BitsAndBytesConfig(\n",
+    "        load_in_4bit=True,\n",
+    "        bnb_4bit_quant_type=\"nf4\",\n",
+    "        bnb_4bit_compute_dtype=torch.float16,\n",
+    "    )\n",
+    "    model = transformers.AutoModelForCausalLM.from_pretrained(\n",
+    "        model_name,\n",
+    "        quantization_config=bnb_config,\n",
+    "        device_map=\"auto\",\n",
+    "        low_cpu_mem_usage=True,\n",
+    "    )\n",
+    "else:\n",
+    "    model = transformers.AutoModelForCausalLM.from_pretrained(\n",
+    "        model_name,\n",
+    "        torch_dtype=\"auto\",\n",
+    "        low_cpu_mem_usage=True,\n",
+    "    )\n",
+    "\n",
     "if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:\n",
     "    print(\"WARNING: Resizing the embedding matrix to match the tokenizer vocab size.\")\n",
-    "    model.resize_token_embeddings(len(tokenizer))"
+    "    model.resize_token_embeddings(len(tokenizer))\n"
    ]
   },
   {
    "cell_type": "markdown",
    "id": "240767e7",
    "metadata": {},
    "source": [
     "### Construct and tokenize fine-tuning datasets\n",
     "\n",
     "We perform manual masking, so the loss is only calculated for the generated actions."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 8,
    "id": "d8c94248",
    "metadata": {},
    "outputs": [
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
        "model_id": "ac592b7caaa94bc4b89d7fb9d868dacf",
        "version_major": 2,
        "version_minor": 0
       },
@@ -799,81 +847,87 @@
     },
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Tokenization test:\n",
       "<|begin_of_text|>Below is the first 45 seconds of the sales call between the sales agent Speaker 0 and the customer Speaker 1:\n",
       "Speaker 0: Good afternoon, this is Marcus from Greenwave Energy. Am I speaking with Ms. Lopez?\n",
       "Speaker 1: Hi, yes, this is her.\n",
       "Speaker 0: Fantastic! I’ll keep this brief. We have a new energy plan with a guaranteed rate and monthly discounts for loyal clients. Are you open to hearing a quick summary?\n",
       "Speaker 1: Alright, sure. Go ahead.\n",
       "Speaker 0: Thank you! With the Greenwave Saver Plan, you lock in a fixed rate on electricity for a year. We're offering a $7 discount each month on your bill and a one-time $30 sign-up bonus. All this, with no contract lock-in or exit fees.\n",
       "Speaker 1: Hmm. What's the rate compared to what I'm paying now?\n",
       "Speaker 0: Great question. On your most recent bill, you were charged $0.15 per kWh. Our plan offers $0.132 per kWh, so you'd see a savings, plus the ongoing monthly discount.\n",
       "Speaker 1: I don’t know… It sounds good, but I just switched providers last month. I’m kind of locked in for now.\n",
       "Speaker 0: I understand. If there’s a penalty to leave early, I’d hate for you to pay that. Just so you know, our plan has no switching fee, so you could come back anytime.\n",
       "Speaker 1: That’s helpful, thank you. But I think for now, I’ll have to pass. Maybe I’ll look into it when my contact runs out.\n",
       "Will this call end in a sale (respond with 'yes' or 'no'):  no<|end_of_text|>\n",
       "\n",
       "Expected Label (action):\n",
       "no<|end_of_text|>\n"
      ]
     }
    ],
    "source": [
+    "# What this block does: Builds train/val datasets and tokenizes prompts+labels with manual masking.\n",
+    "# Why we choose this setup: Manual label masking trains only on the action token(s), and MAX_SEQ_LEN truncation prevents long transcripts from exhausting memory.\n",
+    "\n",
     "train_dataset = datasets.Dataset.from_dict(\n",
     "    {\"prompt\": [state for state in optimal_state_action_pairs[\"train\"][\"state\"].values], \n",
     "     \"completion\": [action.strip()\n",
     "                    for action in optimal_state_action_pairs[\"train\"][\"action\"].values]}).shuffle()\n",
     "    \n",
     "val_dataset = datasets.Dataset.from_dict(\n",
     "    {\"prompt\": [state for state in optimal_state_action_pairs[\"val\"][\"state\"].values],\n",
     "     \"completion\": [action.strip()\n",
     "                    for action in optimal_state_action_pairs[\"val\"][\"action\"].values]}).shuffle()\n",
     "\n",
     "def tokenize_fn(example, add_label):\n",
     "    # start with the BOS token if it exists\n",
     "    if tokenizer.bos_token is not None:\n",
     "        encoded_prompt = tokenizer.encode(tokenizer.bos_token +\n",
     "                                          example[\"prompt\"],              \n",
     "                                          add_special_tokens=False)\n",
     "    else:\n",
     "        encoded_prompt = tokenizer.encode(example[\"prompt\"], \n",
     "                                          add_special_tokens=False)\n",
     "\n",
     "    # add the label if needed for the training and validation datasets\n",
     "    if add_label:\n",
     "        encoded_label = tokenizer.encode(example[\"completion\"] + tokenizer.eos_token, \n",
     "                                         add_special_tokens=False)\n",
-    "        return {\"input_ids\": encoded_prompt + encoded_label,\n",
-    "                \"attention_mask\" : [1] * (len(encoded_prompt) + len(encoded_label)),\n",
-    "                \"labels\": [-100] * len(encoded_prompt) + encoded_label}\n",
+    "        input_ids = (encoded_prompt + encoded_label)[:MAX_SEQ_LEN]\n",
+    "        labels = ([-100] * len(encoded_prompt) + encoded_label)[:MAX_SEQ_LEN]\n",
+    "        return {\"input_ids\": input_ids,\n",
+    "                \"attention_mask\" : [1] * len(input_ids),\n",
+    "                \"labels\": labels}\n",
     "    else:\n",
-    "        return {\"input_ids\": encoded_prompt,\n",
-    "                \"attention_mask\": [1] * len(encoded_prompt),\n",
-    "                \"labels\": [-100] * len(encoded_prompt)}\n",
+    "        input_ids = encoded_prompt[:MAX_SEQ_LEN]\n",
+    "        return {\"input_ids\": input_ids,\n",
+    "                \"attention_mask\": [1] * len(input_ids),\n",
+    "                \"labels\": [-100] * len(input_ids)}\n",
     "\n",
     "train_dataset = train_dataset.map(tokenize_fn,\n",
     "                                  remove_columns=[\"prompt\", \"completion\"], \n",
     "                                  fn_kwargs={\"add_label\": True})\n",
     "val_dataset = val_dataset.map(tokenize_fn, \n",
     "                              remove_columns=[\"prompt\", \"completion\"], \n",
     "                              fn_kwargs={\"add_label\": True})\n",
     "\n",
     "print(\"Tokenization test:\")\n",
     "print(tokenizer.decode(train_dataset[0][\"input_ids\"]))\n",
     "print()\n",
     "print(\"Expected Label (action):\")\n",
     "print(tokenizer.decode([l for l in train_dataset[0][\"labels\"] if l!=-100]))"
    ]
   },
   {
    "cell_type": "markdown",
    "id": "8f5717d6",
    "metadata": {},
    "source": [
     "### Fine-tune"
    ]
   },
   {
    "cell_type": "code",
@@ -930,119 +984,121 @@
        "      <td>0.329565</td>\n",
        "      <td>1.000000</td>\n",
        "    </tr>\n",
        "  </tbody>\n",
        "</table><p>"
       ],
       "text/plain": [
        "<IPython.core.display.HTML object>"
       ]
      },
      "metadata": {},
      "output_type": "display_data"
     },
     {
      "data": {
       "text/plain": [
        "TrainOutput(global_step=75, training_loss=1.3294019158681234, metrics={'train_runtime': 144.6715, 'train_samples_per_second': 12.096, 'train_steps_per_second': 1.037, 'total_flos': 7403051037947904.0, 'train_loss': 1.3294019158681234, 'epoch': 5.0})"
       ]
      },
      "execution_count": 9,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
+    "# What this block does: Defines metrics and launches Trainer fine-tuning with early stopping.\n",
+    "# Why we choose this setup: Small batches + gradient accumulation/checkpointing are memory-safe defaults that still preserve effective batch size and validation-driven model selection.\n",
+    "\n",
     "YES_TOKEN_ID = tokenizer.encode(\"yes\", add_special_tokens=False)[-1]\n",
     "NO_TOKEN_ID = tokenizer.encode(\"no\", add_special_tokens=False)[-1]\n",
     "\n",
     "def preprocess_logits_for_metrics(logits, labels):\n",
     "    \"\"\"\n",
     "    Original Trainer may have a memory leak. \n",
     "    This is a workaround to avoid storing too many tensors that are not needed.\n",
     "    \"\"\"\n",
     "    return logits[:, -3, :] # last non-padding token logits only, for causal LM\n",
     "\n",
     "def compute_metrics(eval_preds):\n",
     "    logits, labels = eval_preds\n",
     "    labels = [l[(l != -100) & (l != tokenizer.pad_token_id)][0] for l in labels]\n",
     "\n",
     "    logprobs = [torch.log_softmax(torch.from_numpy(s), dim=-1).numpy()\n",
     "                for s in logits]\n",
     "    labelprobs = [math.exp(logprob[label]) for logprob, label in zip(logprobs, labels)]\n",
     "\n",
     "    ytrue = [1 if label == YES_TOKEN_ID else 0 for label in labels]\n",
     "    ypred = [labelprob if label==YES_TOKEN_ID else 1.0 - labelprob\n",
     "             for labelprob, label in zip(labelprobs, labels)]\n",
     "    auc = roc_auc_score(ytrue, ypred)\n",
     "\n",
     "    return {\"auc\": auc}\n",
     "\n",
     "training_args = transformers.TrainingArguments(\n",
-    "    output_dir=\"./llama-3.2-3B/\",\n",
+    "    output_dir=\"./llama-3.2/\",\n",
     "    overwrite_output_dir=True,\n",
     "    remove_unused_columns=False,\n",
     "\n",
     "    save_strategy=\"best\",\n",
     "    logging_strategy=\"epoch\",\n",
     "    eval_strategy=\"epoch\",\n",
     "    save_total_limit=1,\n",
     "\n",
-    "    # on 1 H100 96GB: batch size of 12-16 for 3B works\n",
-    "    per_device_train_batch_size=12,\n",
-    "    per_device_eval_batch_size=12,\n",
-    "    gradient_accumulation_steps=1,\n",
+    "    per_device_train_batch_size=TRAIN_BATCH_SIZE,\n",
+    "    per_device_eval_batch_size=EVAL_BATCH_SIZE,\n",
+    "    gradient_accumulation_steps=GRAD_ACCUM_STEPS,\n",
     "\n",
     "    num_train_epochs=10,\n",
     "    learning_rate=1e-4,\n",
     "    optim=\"adamw_torch\",\n",
+    "    gradient_checkpointing=LOW_RAM_MODE,\n",
     "\n",
     "    load_best_model_at_end=True,\n",
     "    metric_for_best_model=\"auc\",\n",
     "    greater_is_better=True,\n",
     "\n",
     "    report_to=\"none\", # change to wandb if needed\n",
     "    save_safetensors=False, # needed to load saved models\n",
     "\n",
-    "    # change to suit hardware\n",
-    "    bf16=True, \n",
-    "    fp16=False,\n",
+    "    bf16=torch.cuda.is_available() and not LOW_RAM_MODE,\n",
+    "    fp16=torch.cuda.is_available() and LOW_RAM_MODE,\n",
     ")\n",
     "\n",
     "trainer = transformers.Trainer(\n",
     "    model=model,\n",
     "    args=training_args,\n",
     "    train_dataset=train_dataset.shuffle(),\n",
     "    eval_dataset=val_dataset,\n",
     "    processing_class=tokenizer,\n",
-    "    data_collator=transformers.data.DataCollatorForSeq2Seq(tokenizer),\n",
+    "    data_collator=transformers.data.DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8 if torch.cuda.is_available() else None),\n",
     "    compute_metrics=compute_metrics,\n",
     "    preprocess_logits_for_metrics=preprocess_logits_for_metrics,\n",
     "    callbacks=[transformers.EarlyStoppingCallback(early_stopping_patience=1)]\n",
     ")\n",
     "\n",
-    "trainer.train(resume_from_checkpoint=False)"
+    "trainer.train(resume_from_checkpoint=False)\n"
    ]
   },
   {
    "cell_type": "markdown",
    "id": "56c26914",
    "metadata": {},
    "source": [
     "### Evaluate"
    ]
   },
   {
    "cell_type": "markdown",
    "id": "cdb95891",
    "metadata": {},
    "source": [
     "#### Get the validation and test responses for each state\n",
     "\n",
     "We get the agent's responses for the validation and test conversations. We use the validation responses for backward-induction threshold-tuning (our scalable alternative to grid search)."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 18,
    "id": "88183bd6",
    "metadata": {},
@@ -1069,159 +1125,169 @@
      "output_type": "display_data"
     },
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Getting validation responses...\n"
      ]
     },
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
        "model_id": "89a699eea10245818047fa37bed6ec25",
        "version_major": 2,
        "version_minor": 0
       },
       "text/plain": [
        "  0%|          | 0/1 [00:00<?, ?it/s]"
       ]
      },
      "metadata": {},
      "output_type": "display_data"
     }
    ],
    "source": [
+    "# What this block does: Runs batched generation on validation/test prompts and collects response log-probabilities.\n",
+    "# Why we choose this setup: Greedy decoding gives deterministic policy outputs, and smaller generation batches lower the risk of inference-time out-of-memory errors.\n",
+    "\n",
     "val_prompts = list(optimal_state_action_pairs[\"val\"][\"state\"].values)\n",
     "test_prompts = list(optimal_state_action_pairs[\"test\"][\"state\"].values)\n",
     "\n",
     "print(\"Getting test responses...\")\n",
     "\n",
     "responses_test = []\n",
     "logprobs_test = []\n",
-    "batch_size = 72 # change to suit hardware\n",
+    "batch_size = GEN_BATCH_SIZE # lower for Colab/laptops\n",
     "\n",
     "for i in tqdm(range(0, len(test_prompts), batch_size)):\n",
     "    batch_prompts = test_prompts[i:i+batch_size]\n",
     "    batch = tokenizer(batch_prompts, \n",
     "                      return_tensors=\"pt\", \n",
     "                      padding=True, \n",
     "                      add_special_tokens=True,\n",
     "                      truncation=True).to(\"cuda\")\n",
     "\n",
     "    with torch.no_grad():\n",
     "        outputs = trainer.model.generate(\n",
     "            **batch, \n",
     "            max_new_tokens=2, \n",
     "            do_sample=False,\n",
     "            pad_token_id=tokenizer.eos_token_id,\n",
     "            temperature=None, top_p=None, top_k=None,\n",
     "            return_dict_in_generate=True, output_scores=True\n",
     "            # greedy decoding: so output_scores = output_logits\n",
     "        )\n",
     "\n",
     "    seqs = outputs.sequences\n",
     "    prompt_len = batch['input_ids'].shape[1]   # Length of the input prompts\n",
     "\n",
     "    # Slice to get only the generated new tokens\n",
     "    generated_tokens = seqs[:, prompt_len:]\n",
     "    decoded_outputs = tokenizer.batch_decode(generated_tokens,\n",
     "                                             skip_special_tokens=True)\n",
     "    decoded_outputs = [d.strip().lower() for d in decoded_outputs]\n",
     "\n",
     "    responses_test.extend(decoded_outputs)\n",
     "\n",
     "    scores = outputs.scores\n",
     "    logprobs = [torch.log_softmax(s, dim=-1) for s in scores]\n",
     "    logprobs = logprobs[0][torch.arange(logprobs[0].size(0)), seqs[:, -2].view(-1)]\n",
     "    logprobs_test.extend(logprobs.cpu().numpy())\n",
     "\n",
     "print(\"Getting validation responses...\")\n",
     "\n",
     "responses_val = []\n",
     "logprobs_val = []\n",
-    "batch_size = 72\n",
+    "batch_size = GEN_BATCH_SIZE\n",
     "\n",
     "for i in tqdm(range(0, len(val_prompts), batch_size)):\n",
     "    batch_prompts = val_prompts[i:i+batch_size]\n",
     "    batch = tokenizer(batch_prompts, \n",
     "                      return_tensors=\"pt\", \n",
     "                      padding=True, \n",
     "                      truncation=True).to(\"cuda\")\n",
     "\n",
     "    with torch.no_grad():\n",
     "        outputs = trainer.model.generate(\n",
     "            **batch, \n",
     "            max_new_tokens=2, \n",
     "            do_sample=False,\n",
     "            pad_token_id=tokenizer.eos_token_id,\n",
     "            temperature=None, top_p=None, top_k=None,\n",
     "            return_dict_in_generate=True, output_scores=True\n",
     "        )\n",
     "\n",
     "    seqs = outputs.sequences\n",
     "    prompt_len = batch['input_ids'].shape[1]   # Length of the input prompts\n",
     "\n",
     "    # Slice to get only the generated new tokens\n",
     "    generated_tokens = seqs[:, prompt_len:]\n",
     "    decoded_outputs = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)\n",
     "    decoded_outputs = [d.strip().lower() for d in decoded_outputs]\n",
     "\n",
     "    responses_val.extend(decoded_outputs)\n",
     "\n",
     "    scores = outputs.scores\n",
     "    logprobs = [torch.log_softmax(s, dim=-1) for s in scores]\n",
     "    logprobs = logprobs[0][torch.arange(logprobs[0].size(0)), seqs[:, -2].view(-1)]\n",
-    "    logprobs_val.extend(logprobs.cpu().numpy())"
+    "    logprobs_val.extend(logprobs.cpu().numpy())\n",
+    "# Free memory before threshold tuning\n",
+    "gc.collect()\n",
+    "if torch.cuda.is_available():\n",
+    "    torch.cuda.empty_cache()\n"
    ]
   },
   {
    "cell_type": "markdown",
    "id": "86c7a682",
    "metadata": {},
    "source": [
     "#### Store validation and test responses"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 20,
    "id": "772d7131",
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Val 45 ROC-AUC: 1.00\n",
       "Test 45 ROC-AUC: 0.99\n",
       "Val 60 ROC-AUC: 1.00\n",
       "Test 60 ROC-AUC: 0.98\n"
      ]
     }
    ],
    "source": [
+    "# What this block does: Converts generation outputs into yes-probabilities (exp of the generated token's log-probability, or its complement when the response is not \"yes\") and reports ROC-AUC by decision time.\n",
+    "# Why we choose this setup: Probability-based evaluation is needed for downstream threshold tuning and gives a robust quality check beyond raw accuracy.\n",
+    "\n",
     "predictions = optimal_state_action_pairs[\"test\"].copy()\n",
     "predictions[\"response\"] = responses_test\n",
     "predictions[\"logprob\"] = logprobs_test\n",
     "predictions[\"prob\"] = predictions[\"logprob\"].apply(lambda x: math.exp(x))\n",
     "predictions.loc[predictions[\"response\"]==\"yes\", \"prob_yes\"] =\\\n",
     "    predictions.loc[predictions[\"response\"]==\"yes\", \"prob\"]\n",
     "predictions.loc[predictions[\"response\"]!=\"yes\", \"prob_yes\"] =\\\n",
     "    1.0 - predictions.loc[predictions[\"response\"]!=\"yes\", \"prob\"]\n",
     "\n",
     "predictions_val = optimal_state_action_pairs[\"val\"].copy()\n",
     "predictions_val[\"response\"] = responses_val\n",
     "predictions_val[\"logprob\"] = logprobs_val\n",
     "predictions_val[\"prob\"] =\\\n",
     "    predictions_val[\"logprob\"].apply(lambda x: math.exp(x))\n",
     "predictions_val.loc[predictions_val[\"response\"]==\"yes\", \"prob_yes\"] =\\\n",
     "    predictions_val.loc[predictions_val[\"response\"]==\"yes\", \"prob\"]\n",
     "predictions_val.loc[predictions_val[\"response\"]!=\"yes\", \"prob_yes\"] =\\\n",
     "    1.0 - predictions_val.loc[predictions_val[\"response\"]!=\"yes\", \"prob\"]\n",
     "\n",
     "test_with_predictions  = pd.merge(left=data_transcripts[\"test\"],\n",
     "                                     right=predictions[[\"conversation_id\",\n",
     "                                                        \"state\", \"prob_yes\"]],\n",
     "                                     left_on=[\"conversation_id\", \"s\" + str(m1)],\n",
     "                                     right_on=[\"conversation_id\", \"state\"],\n",
     "                                     how=\"left\", validate=\"one_to_one\")\\\n",
@@ -1287,50 +1353,53 @@
    "source": [
     "#### Get optimal thresholds using backward-induction threshold tuning\n",
     "\n",
     "We could do a grid search, but this is much faster."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 56,
    "id": "4f728823",
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Test set reward without the stopping agent:\n",
       "Total reward on test: -77.97900000000004\n",
       "Avg. reward on test: -1.5595800000000009\n",
       "Total sales on test: 25\n",
       "Total time on test (seconds): 3279.79\n"
      ]
     }
    ],
    "source": [
+    "# What this block does: Defines simulate_threshold, which computes the reward of a two-threshold stopping policy at m1/m2, and reports the baseline test reward without the stopping agent.\n",
+    "# Why we choose this setup: Backward-induction-style threshold tuning scales better than brute-force policy search while matching the stopping objective.\n",
+    "\n",
     "def simulate_threshold(threshold_m1, threshold_m2, df):\n",
     "    # quit at m1\n",
     "    calls_quit_at_m1 = df.loc[(df[\"prob_yes_\" + str(m1)] < threshold_m1)]\n",
     "    \n",
     "    # continue at m1, ended before m2\n",
     "    calls_continued_at_m1_and_ended = df.loc[(df[\"prob_yes_\" + str(m1)] >= threshold_m1) &\n",
     "                                             (df[\"duration\"]<m2)]\n",
     "    \n",
     "    # continued at m1, did not end before m2, quit at m2\n",
     "    calls_continued_at_m1_and_quit_at_m2 = df.loc[(df[\"prob_yes_\" + str(m1)] >= threshold_m1) &\n",
     "                                                  (df[\"prob_yes_\" + str(m2)] < threshold_m2) &\n",
     "                                                  (df[\"duration\"]>=m2)]\n",
     "    \n",
     "    # continue at m1, did not end before m2, continued at m2\n",
     "    calls_continued_at_m2 = df.loc[(df[\"prob_yes_\" + str(m1)] >= threshold_m1) &\n",
     "                                   (df[\"prob_yes_\" + str(m2)] >= threshold_m2) &\n",
     "                                   (df[\"duration\"]>=m2)]\n",
     "    \n",
     "    assert len(calls_quit_at_m1) + len(calls_continued_at_m1_and_ended) +\\\n",
     "          len(calls_continued_at_m1_and_quit_at_m2) + \\\n",
     "          len(calls_continued_at_m2) == len(df)\n",
     "\n",
     "    total_sales = calls_continued_at_m1_and_ended[\"is_sale\"].sum() +\\\n",
     "                    calls_continued_at_m2[\"is_sale\"].sum()\n",
     "    total_sales_benefit = total_sales * BENEFIT_PER_POSITIVE_OUTCOME\n",
@@ -1372,50 +1441,53 @@
    "execution_count": 61,
    "id": "85ea3128",
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Best threshold at m=45: 0.00010000464725449688\n",
       "Best threshold at m=60: 0.0014000817639441545\n",
       "\n",
       "Test set reward with the stopping agent:\n",
       "Total reward on test: -60.261000000000024\n",
       "Avg. reward on test: -1.2052200000000004\n",
       "Total sales on test: 24\n",
       "Total time on test: 3002.61\n",
       "\n",
       "Comparative results:\n",
       "Sales lost by stopping agent: 1\n",
       "Time saved by stopping agent (seconds): 277.17999999999984\n",
       "Time saved by stopping agent (%): 8.451150835876682\n"
      ]
     }
    ],
    "source": [
+    "# What this block does: Searches a grid of candidate thresholds on the validation set, then evaluates the thresholded policy on test calls and compares it with the no-agent baseline.\n",
+    "# Why we choose this setup: This is the end-to-end objective metric: how much reward the learned stopping policy yields on unseen conversations.\n",
+    "\n",
     "best_threshold_at_m = {}\n",
     "num_grid_points = 10000\n",
     "\n",
     "m = m1\n",
     "prob_column = \"prob_yes_\" + str(m)    \n",
     "best_reward = -10000000\n",
     "for candidate_threshold in np.linspace(val_with_predictions[prob_column].min()-10**-12,\n",
     "                                       val_with_predictions[prob_column].max()+10**-12,\n",
     "                                       num=num_grid_points):\n",
     "    \n",
     "    total_reward, average_reward, total_sales, total_time =\\\n",
     "        simulate_threshold(0, candidate_threshold, val_with_predictions)\n",
     "    \n",
     "    if average_reward > best_reward:\n",
     "        best_reward = average_reward\n",
     "        best_threshold_at_m[m] = candidate_threshold\n",
     "\n",
     "print(\"Best threshold at m=\" + str(m) + \":\", best_threshold_at_m[m])\n",
     "\n",
     "m = m2\n",
     "prob_column = \"prob_yes_\" + str(m)    \n",
     "best_reward = -10000000\n",
     "for candidate_threshold in np.linspace(val_with_predictions[prob_column].min()-10**-12,\n",
     "                                       val_with_predictions[prob_column].max()+10**-12,\n",
     "                                       num=num_grid_points):\n",
''', encoding="utf-8")

print(f"Wrote patch to: {patch_path.resolve()}")
print("Next (from your repo root): git apply path/to/optimal-stopping.patch")
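
The evaluation cell touched by the patch turns each greedy response and its log-probability into a probability of "yes". The sketch below restates that rule in plain Python, without pandas. The function name `prob_yes` is a hypothetical helper introduced here for illustration and does not appear in the notebook. It assumes, as the notebook code does, that any response other than "yes" is treated as the complementary outcome.

```python
import math


def prob_yes(response: str, logprob: float) -> float:
    """Probability assigned to "yes" (hypothetical helper, not in the notebook).

    response: the decoded, stripped, lower-cased greedy output.
    logprob:  the log-probability of the generated token.
    """
    prob = math.exp(logprob)
    # If the model answered "yes", use the token probability directly.
    # Otherwise assign "yes" the complement, a binary yes/no simplification.
    return prob if response == "yes" else 1.0 - prob


if __name__ == "__main__":
    print(prob_yes("yes", math.log(0.9)))  # close to 0.9
    print(prob_yes("no", math.log(0.8)))   # close to 0.2
```

Note that this is an untransformed model probability. Nothing in the rule recalibrates it, so thresholds tuned on it are specific to the trained model.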

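`simulate_threshold` in the patched notebook splits the calls into four groups and asserts that the groups cover every call. The sketch below shows the same partition for a single call in plain Python. `classify_call` and the group labels are hypothetical names introduced here. It assumes the yes-probability at m1 is never missing, and it stops at the partition because the notebook's reward formula is not reproduced in this patch excerpt.

```python
from collections import Counter


def classify_call(p1, p2, duration, threshold_m1, threshold_m2, m2):
    """Return the policy outcome for one call (hypothetical helper).

    p1, p2:   yes-probabilities at decision times m1 and m2.
    duration: call length, in the same units as m2.
    """
    if p1 < threshold_m1:
        return "quit_at_m1"
    if duration < m2:
        # Continued at m1, but the call ended before the second decision.
        return "continued_at_m1_and_ended"
    if p2 < threshold_m2:
        return "continued_at_m1_and_quit_at_m2"
    return "continued_at_m2"


if __name__ == "__main__":
    # Illustrative values only: (p1, p2, duration).
    calls = [(0.1, 0.9, 100), (0.8, 0.0, 50), (0.8, 0.2, 100), (0.8, 0.7, 100)]
    counts = Counter(classify_call(p1, p2, d, 0.5, 0.5, m2=60) for p1, p2, d in calls)
    # Every call lands in exactly one group, mirroring the notebook's assert.
    assert sum(counts.values()) == len(calls)
    print(counts)
```

Because the branches are checked in order, each call falls into exactly one group, which is the property the notebook's `assert` verifies on the DataFrame.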