# GPT-2 vs. Child Word Acquisition
To skip directly to the Plots section, download the following files from https://drive.google.com/drive/folders/1gHxewyharnFG-9vuE4CX2FQxhPGzTvM1?usp=sharing:
- `wordbank_item_data.csv`
- `contexts_cosmopedia.pkl`
- `stanford-gpt2-small-a_results.zip`
- `stanford-gpt2-medium-a_results.zip`

## Parse Cosmopedia for contexts of simple words
Here, we define simple words as the first words that children learn to produce according to the MacArthur-Bates Communicative Development Inventory (CDI). These words can be found in `data/wordbank_item_data.csv` or at https://wordbank.stanford.edu/data/.

We chose Cosmopedia as our text dataset because it covers a wide variety of topics. Additionally, it is a synthetic dataset, which reduces the risk that the evaluation contexts overlap with the models' training text (OpenWebText).

We parse all Cosmopedia subsets for simple words and their contexts (up to `window_size=10` words preceding the simple word within the same sentence).

For a faster runtime, decrease `TOKEN_LIMIT`.

In [None]:
import pickle
import re
from collections import defaultdict
from typing import List, Dict
import csv
import spacy
from datasets import load_dataset
from tqdm import tqdm

# Number of distinct contexts collected per simple word; words with fewer
# contexts are dropped by `collect_contexts_from_texts`.
X = 128
# Approximate sampling budget for Cosmopedia. "Tokens" here are alphabetic
# words as counted by `count_words`, not model (BPE) tokens.
TOKEN_LIMIT = 200_000
WORDBANK_PATH = "data/wordbank_item_data.csv"
# Output of this section; re-loaded by the checkpoint-evaluation loop.
CONTEXTS_PKL_PATH = "data/contexts_cosmopedia.pkl"

# Cosmopedia subsets to sample; the token budget is split evenly across them.
COSMOPEDIA_SUBSETS = [
    "auto_math_text",
    "khanacademy",
    "openstax",
    "stanford",
    "stories",
    "web_samples_v1",
    "web_samples_v2",
    "wikihow",
]

# spaCy English pipeline; `collect_contexts_from_texts` uses it for sentence
# splitting and for filtering to alphabetic tokens.
nlp = spacy.load("en_core_web_sm")

In [None]:
def count_words(text: str):
    """Return the lowercase alphabetic words of *text*.

    Despite its name, this returns the list of words (callers take ``len``
    of it when they need a count).
    """
    lowered = text.lower()
    word_list = re.findall(r"[A-Za-z]+", lowered)
    return word_list


def sample_cosmopedia_texts(token_limit: int = TOKEN_LIMIT, subsets: List[str] = None):
    """Stream Cosmopedia documents until an approximate word budget is reached.

    The budget *token_limit* is split evenly across *subsets*. Documents are
    streamed from each subset's ``train`` split until that subset's share or
    the overall budget is reached. "Tokens" are words as counted by
    `count_words`, not model tokens. Documents with no words are skipped.

    Args:
        token_limit: Total approximate word budget.
        subsets: Cosmopedia subset names. ``None`` means `COSMOPEDIA_SUBSETS`;
            an empty list samples nothing.

    Returns:
        The list of sampled document texts.
    """
    if subsets is None:
        subsets = COSMOPEDIA_SUBSETS

    texts: List[str] = []
    total_tokens = 0
    if not subsets:
        # Nothing to sample; also avoids dividing the budget by zero below.
        print(f"Collected {len(texts)} documents, approx {total_tokens} word tokens.")
        return texts

    per_subset_limit = max(token_limit // len(subsets), 1)

    for subset in subsets:
        subset_tokens = 0
        ds = load_dataset("HuggingFaceTB/cosmopedia", subset, split="train", streaming=True)
        for item in ds:
            text = item["text"]
            n = len(count_words(text))
            if n == 0:
                continue
            texts.append(text)
            subset_tokens += n
            total_tokens += n
            if subset_tokens >= per_subset_limit or total_tokens >= token_limit:
                break

        if total_tokens >= token_limit:
            break

    print(f"Collected {len(texts)} documents, approx {total_tokens} word tokens.")
    return texts


def load_wordbank_words(path: str = "data/wordbank_item_data.csv"):
    """Return the set of normalized ``item_definition`` labels from a Wordbank CSV.

    Each label is stripped and lowercased, and parenthesized qualifiers are
    removed (``"Chicken (animal)"`` -> ``"chicken"``). Labels that become
    empty are dropped.
    """
    paren_qualifier = re.compile(r"\s*\([^)]*\)")
    with open(path, newline="", encoding="utf-8") as handle:
        rows = list(csv.DictReader(handle))
    cleaned = (
        paren_qualifier.sub("", row["item_definition"].strip().lower()).strip()
        for row in rows
    )
    return {term for term in cleaned if term}


def get_corpus_vocab(texts: List[str]):
    """Return every distinct lowercase word (as split by `count_words`) in *texts*."""
    return {word for text in texts for word in count_words(text)}


def find_simple_words(texts: List[str], wordbank_path: str = "data/wordbank_item_data.csv"):
    """Return the Wordbank words that also occur somewhere in *texts*."""
    wordbank_terms = load_wordbank_words(wordbank_path)
    return get_corpus_vocab(texts).intersection(wordbank_terms)


def collect_contexts_from_texts(
    texts: List[str],
    simple_words: set,
    max_context: int = X,
    window_size: int = 10,
) -> Dict[str, List[str]]:
    """Collect distinct left-context prefixes for each simple word.

    Texts are processed with the module-level spaCy pipeline ``nlp``. In
    each sentence, the alphabetic tokens are lowercased. For every
    occurrence of a word from *simple_words*, the up to *window_size*
    preceding words of the same sentence are joined with single spaces and
    stored as one context. Occurrences at the start of a sentence (empty
    prefix) are ignored, and duplicate prefixes are stored once.
    Collection for a word stops at *max_context* prefixes. Processing stops
    early once every simple word has reached that count.

    Args:
        texts: Raw documents.
        simple_words: Lowercase words to collect contexts for.
        max_context: Number of distinct prefixes required per word.
        window_size: Maximum number of preceding words per prefix.

    Returns:
        Mapping word -> list of prefixes in first-seen order. Only words that
        reached *max_context* prefixes are included.
    """
    if not simple_words:
        return {}

    # dicts serve as insertion-ordered sets: the saved contexts are then
    # reproducible, whereas a set of str iterates in run-dependent hash order.
    contexts: Dict[str, Dict[str, None]] = defaultdict(dict)
    completed_words = set()
    target_count = len(simple_words)

    for doc in tqdm(nlp.pipe(texts, batch_size=32), total=len(texts), desc="Collecting contexts"):
        for sent in doc.sents:
            words = [tok.text.lower() for tok in sent if tok.is_alpha]
            for i, w in enumerate(words):
                if w not in simple_words or len(contexts[w]) >= max_context:
                    continue
                start = max(0, i - window_size)
                prefix = " ".join(words[start:i])
                if prefix:
                    contexts[w][prefix] = None
                if len(contexts[w]) >= max_context:
                    completed_words.add(w)

        # Checked at document level so that, once every word is complete, the
        # whole (expensive) spaCy pipeline stops rather than just one token loop.
        if len(completed_words) >= target_count:
            break

    return {
        w: list(contexts.get(w, {}))
        for w in simple_words
        if len(contexts.get(w, {})) >= max_context
    }

In [None]:
# Build the context dataset. Sample Cosmopedia, keep the Wordbank words that
# occur in the sample, collect up to X left contexts per word, and pickle the
# mapping word -> list of prefix strings for the checkpoint-evaluation section.
texts = sample_cosmopedia_texts(token_limit=TOKEN_LIMIT, subsets=COSMOPEDIA_SUBSETS)
simple_words = find_simple_words(texts, wordbank_path=WORDBANK_PATH)
print(f"Found {len(simple_words)} simple words.")
contexts = collect_contexts_from_texts(texts, simple_words, max_context=X)
print(f"{len(contexts)} words have at least {X} contexts.")
with open(CONTEXTS_PKL_PATH, "wb") as f:
    pickle.dump(contexts, f)

## Load GPT-2 checkpoints and save surprisal + mean layer attention
From the word contexts, we save the surprisal and mean layer attention for each word and training checkpoint. The GPT-2 checkpoints are loaded from the Hugging Face `stanford-crfm` models, which include ~600 checkpoints up to training step 400,000.

In the `Experiment` class, `compute_batches` takes all the contexts for a given word and turns them into prompts of the form `"<prefix> <word>"`.

For each prompt, the code finds the token positions corresponding to the target word (which may span multiple tokens). For each of those positions, it takes the model's predicted distribution from the previous position (i.e., the logits at position − 1), reads off the log-probability of the actual token, and sums these over all of the word's tokens. Surprisal in bits is then `-log2 p(word | prefix)`. A higher surprisal means the model found the word less expected in that context.

From the attention patterns, we compute, for each layer, the mean attention that all heads and all query positions pay to the positions of the target word's tokens.

We ran this section on a RunPod A100 GPU. For a faster runtime, decrease the number of contexts per word (`X`) or `TOKEN_LIMIT` above.

In [None]:
import os
import math
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformer_lens import HookedTransformer
from transformer_lens.loading_from_pretrained import get_checkpoint_labels
from huggingface_hub.utils import RevisionNotFoundError

MODEL_NAME = "stanford-gpt2-small-a"  # or "stanford-gpt2-medium-a"
DTYPE = torch.float16
DEVICE = "cuda"
BATCH_SIZE = 128  # prompts per forward pass
CHECKPOINT_STRIDE = 12  # Loads a total of ~50 checkpoints
# Derived from MODEL_NAME so that switching models cannot write into (or skip
# checkpoints based on) the other model's results directory.
MODEL_OUT_PATH = f"{MODEL_NAME}_results"

In [None]:
class Experiment:
    """Per-checkpoint surprisal and attention measurements for simple words.

    For every word and each of its left-context prefixes, the prompt
    ``"<prefix> <word>"`` is run through ``model``. Two quantities are
    recorded:

    * surprisal in bits, ``-log2 p(word | prefix)``, summed over the word's
      tokens;
    * for each layer, the mean attention weight that all heads and all of the
      prompt's real (non-padding) query positions place on the word's tokens.
    """

    def __init__(self,
                 model: HookedTransformer,
                 batch_size: int):
        self.model = model
        self.batch_size = batch_size
        self.device = str(self.model.cfg.device)
        self.model.eval()
        # Right padding keeps every prompt's real tokens at positions
        # 0..len-1, so positions derived from unpadded lengths stay valid.
        self.model.tokenizer.padding_side = "right"

    def compute_batches(self,
                        prefixes: List[str],
                        word: str):
        """Yield batches of ``{"prefixes": [...], "full_texts": [...]}``.

        Each full text is the prefix with trailing whitespace removed,
        followed by one space and *word*. Each batch holds at most
        ``batch_size`` items.
        """
        full_text_l = [prefix.rstrip() + " " + word for prefix in prefixes]
        for i in range(0, len(full_text_l), self.batch_size):
            yield {
                "prefixes": prefixes[i:i + self.batch_size],
                "full_texts": full_text_l[i:i + self.batch_size],
            }

    def compute_output(self,
                       batch: Dict[str, List[str]]):
        """Run one batch and return per-prompt word statistics.

        Returns:
            ``(surprisal_bits, num_word_tokens, layer_avg_attn)``:

            * ``surprisal_bits``: list of surprisals in bits.
            * ``num_word_tokens``: list with the number of tokens the word
              adds in each prompt.
            * ``layer_avg_attn``: CPU float32 tensor of shape
              ``(batch_size, n_layers)`` holding the mean attention into the
              word.

            Prompts whose word adds no tokens get 0 for both statistics.
        """
        prefixes, full_texts = batch["prefixes"], batch["full_texts"]
        batch_size = len(prefixes)

        # The word occupies positions [prefix_len, full_len) of each prompt.
        prefix_lens = [len(self.model.to_tokens(prefix, prepend_bos=False)[0]) for prefix in prefixes]
        full_lens = [len(self.model.to_tokens(full_text, prepend_bos=False)[0]) for full_text in full_texts]
        num_word_tokens = [full_len - pref_len for pref_len, full_len in zip(prefix_lens, full_lens)]

        full_tokens = self.model.to_tokens(full_texts, prepend_bos=False)

        with torch.no_grad():
            # Only the attention patterns are read below; caching every
            # activation of a full batch wastes GPU memory.
            logits, cache = self.model.run_with_cache(
                full_tokens,
                names_filter=lambda name: name.endswith("attn.hook_pattern"),
            )

        total_logprob_e = torch.zeros(batch_size, device=self.device, dtype=torch.float32)
        for i in range(batch_size):
            # Position 0 has no preceding logits, so it cannot be scored.
            start, end = max(prefix_lens[i], 1), full_lens[i]
            if num_word_tokens[i] <= 0 or start >= end:
                continue
            # Logits at position p - 1 predict the token at position p. The
            # log-softmax runs in float32 because float16 loses precision here.
            log_probs = F.log_softmax(logits[i, start - 1:end - 1].float(), dim=-1)
            target_ids = full_tokens[i, start:end].to(log_probs.device).unsqueeze(-1)
            total_logprob_e[i] = log_probs.gather(-1, target_ids).sum()

        surprisal_bits = (-total_logprob_e / math.log(2.0)).tolist()

        n_layers = self.model.cfg.n_layers
        layer_avg_attn = torch.zeros(batch_size, n_layers, device=self.device, dtype=torch.float32)
        for i in range(batch_size):
            if num_word_tokens[i] <= 0:
                continue
            word_start, word_end = prefix_lens[i], full_lens[i]
            for layer in range(n_layers):
                # Per-prompt pattern is indexed [head, query, key]. Query rows
                # at or beyond full_len are right-padding; including them would
                # make the value depend on the other prompts' lengths.
                into_word = cache["attn", layer][i, :, :word_end, word_start:word_end]
                layer_avg_attn[i, layer] = into_word.float().mean()

        del cache, logits
        if self.device.startswith("cuda"):
            torch.cuda.empty_cache()

        return surprisal_bits, num_word_tokens, layer_avg_attn.cpu()

    def compute_output_dict(self,
                            contexts: Dict[str, List[str]]):
        """Aggregate per-word statistics over all contexts of each word.

        Returns:
            Mapping word -> dict with the following keys. Words without any
            context are omitted.

            * ``"avg_surprisal"``: mean bits over contexts.
            * ``"avg_surprisal_per_token"``: mean of bits divided by the
              word's token count, over contexts with at least one token and a
              finite ratio; NaN if there are none.
            * ``"surprisals_list"``: per-context bits.
            * ``"avg_layer_attn"``: per-layer mean attention as a list of
              floats.
        """
        output_dict = {}

        for word in tqdm(list(contexts.keys())):
            surprisals = []
            token_counts = []
            layer_attn_vals = []

            for batch in self.compute_batches(contexts[word], word):
                surprisal_bits, num_word_tokens, layer_avg_attn = self.compute_output(batch)
                surprisals.extend(surprisal_bits)
                token_counts.extend(num_word_tokens)
                layer_attn_vals.append(layer_avg_attn)

            if not surprisals:
                continue

            avg = sum(surprisals) / len(surprisals)
            per_token_vals = [
                ratio
                for ratio in (s / t for s, t in zip(surprisals, token_counts) if t > 0)
                if not math.isnan(ratio)
            ]
            avg_per_token = sum(per_token_vals) / len(per_token_vals) if per_token_vals else float("nan")

            avg_layer_attn = torch.cat(layer_attn_vals, dim=0).mean(dim=0).tolist()

            output_dict[word] = {
                "avg_surprisal": avg,
                "avg_surprisal_per_token": avg_per_token,
                "surprisals_list": surprisals,
                "avg_layer_attn": avg_layer_attn,
            }

        return output_dict

In [None]:
# Evaluate every CHECKPOINT_STRIDE-th checkpoint of MODEL_NAME on the saved
# contexts and write one results pickle per checkpoint. Checkpoints whose
# output file already exists are skipped, so the loop can be resumed.
os.makedirs(MODEL_OUT_PATH, exist_ok=True)
labels, label_type = get_checkpoint_labels(MODEL_NAME)
num_ckpts = len(labels)

print(f"Model: {MODEL_NAME}")
print(f"Checkpoint label type: {label_type}")
print(f"Total checkpoints in table: {num_ckpts}")
print(f"Using stride {CHECKPOINT_STRIDE}")

with open(CONTEXTS_PKL_PATH, "rb") as f:
    contexts = pickle.load(f)

for idx in range(0, num_ckpts, CHECKPOINT_STRIDE):
    label = labels[idx]
    out_pkl = os.path.join(
        MODEL_OUT_PATH,
        f"results_ckpt_idx{idx:04d}_{label_type}{label}.pkl",
    )

    if os.path.exists(out_pkl):
        print(f"[{idx}] Skipping (results already exist): {out_pkl}")
        continue

    print(f"[{idx}] Loading model {MODEL_NAME} checkpoint_index={idx} (label={label_type} {label})")

    try:
        model = HookedTransformer.from_pretrained_no_processing(
            MODEL_NAME,
            checkpoint_index=idx,
            device=DEVICE,
            dtype=DTYPE,
        )
    except RevisionNotFoundError as e:
        # Must precede the OSError handler below.
        print(f"[{idx}] SKIP: Revision not found on Hugging Face: {e}")
        continue
    except OSError as e:
        msg = str(e)
        if "not a valid git identifier" in msg:
            print(f"[{idx}] SKIP: Invalid git revision for this checkpoint: {msg}")
        else:
            print(f"[{idx}] ERROR: OSError when loading checkpoint: {msg}")
        continue
    except Exception as e:
        # Deliberate best-effort sweep: report and move on to the next checkpoint.
        print(f"[{idx}] ERROR: Unexpected exception: {e}")
        continue

    experiment = Experiment(model=model, batch_size=BATCH_SIZE)
    output_dict = experiment.compute_output_dict(contexts)

    # Write to a temporary file, then rename. An interrupted dump must not
    # leave a truncated ".pkl" that the existence check above would skip on
    # resume. The ".tmp" suffix is also ignored by the results loader.
    tmp_pkl = out_pkl + ".tmp"
    with open(tmp_pkl, "wb") as f_out:
        pickle.dump(output_dict, f_out)
    os.replace(tmp_pkl, out_pkl)

    # `experiment` holds a reference to the model as well. Both must be
    # released before empty_cache() can return the weights' GPU memory.
    del experiment, model
    if DEVICE.startswith("cuda"):
        torch.cuda.empty_cache()

    print(f"[{idx}] Saved results to: {out_pkl}")

## Plots
We generate plots of the average attention and average surprisal for GPT-2 small and medium across training checkpoints. This includes a plot for the top 10, 100, and 500 simple words.

We also generate a merged plot showing, for each word, the change in surprisal (cropped `margin_idx` checkpoints after the LLM crosses its AoA threshold; `margin_idx=3` in the code below), the proportion of children who have learned to produce the word, and an overlay of both plots aligned at the AoA threshold (the child AoA threshold is 0.5).

Finally, we generate the same merged plot for change in mean layer attention. These plots are not cropped or aligned in any particular way.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

# Per-checkpoint result directories written by the checkpoint-evaluation loop
# (one run per model).
SMALL_DIR = "stanford-gpt2-small-a_results"
MEDIUM_DIR = "stanford-gpt2-medium-a_results"
FIG_OUT_DIR = "figs"
# Sizes of the earliest-AoA word sets used for the averaged plots.
KS = [10, 100, 500]
# Reference surprisal for the LLM AoA threshold, which is the midpoint between
# this value and a word's minimum surprisal.
# NOTE(review): 15.6 appears to be ~log2(50257), i.e. the surprisal of a uniform
# guess over the GPT-2 vocabulary — confirm.
BASELINE_BITS = 15.6
# Intended number of per-word comparison figures.
NUM_PLOTS = 10

In [None]:
def normalize_cdi_label(w: str) -> str:
    """Normalize a Wordbank/CDI item label for matching.

    The label is lowercased, parenthesized qualifiers such as ``"(animal)"``
    are removed, and whitespace is collapsed and trimmed. A qualifier between
    two words is replaced by a single space so the words stay separated
    (``"a (b) c"`` -> ``"a c"``, not ``"ac"``).
    """
    w = w.lower().strip()
    w = re.sub(r"\s*\(.*?\)\s*", " ", w)
    w = re.sub(r"\s+", " ", w)
    return w.strip()


def load_wordbank_aoa(csv_path: str):
    """Return each word's child age of acquisition (AoA) in months.

    The month columns are the columns whose names are all digits, taken in
    increasing order. For every row, the AoA is the first month column where
    the proportion of children producing the word is >= 0.5. Rows that never
    reach 0.5 are skipped. Labels that normalize to the same word (e.g.
    ``"chicken (animal)"`` / ``"chicken (food)"``) get the mean of their AoAs.
    """
    df = pd.read_csv(csv_path)
    month_cols = sorted((c for c in df.columns if c.isdigit()), key=int)
    aoas_by_word: Dict[str, List[float]] = defaultdict(list)
    for _, row in df.iterrows():
        word = normalize_cdi_label(str(row["item_definition"]))
        if not word:
            continue
        aoa = math.nan
        for col in month_cols:
            try:
                v = float(row[col])
            except (TypeError, ValueError):
                continue
            if v >= 0.5:
                aoa = float(col)
                break
        if not math.isnan(aoa):
            aoas_by_word[word].append(aoa)
    # Every duplicate row gets equal weight; a pairwise running mean would
    # over-weight later rows when a label occurs three or more times.
    return {word: sum(aoas) / len(aoas) for word, aoas in aoas_by_word.items()}


def load_wordbank_curves(csv_path: str):
    """Return the month grid and each word's production curve.

    Returns:
        ``(months, word_to_curve)``:

        * ``months``: sorted integer month columns (columns whose names are
          all digits).
        * ``word_to_curve``: maps each normalized word to a float array of
          the proportion of children producing it per month, with NaN where a
          value is missing or non-numeric.

        Labels that normalize to the same word are averaged element-wise,
        ignoring NaNs, with equal weight per row.
    """
    df = pd.read_csv(csv_path)
    month_cols = sorted((c for c in df.columns if c.isdigit()), key=int)
    months = [int(c) for c in month_cols]

    rows_by_word: Dict[str, List[np.ndarray]] = defaultdict(list)
    for _, row in df.iterrows():
        word = normalize_cdi_label(str(row["item_definition"]))
        if not word:
            continue
        vals = []
        for col in month_cols:
            try:
                vals.append(float(row[col]))
            except (TypeError, ValueError):
                vals.append(math.nan)
        rows_by_word[word].append(np.array(vals, dtype=float))

    word_to_curve: Dict[str, np.ndarray] = {}
    for word, curves in rows_by_word.items():
        if len(curves) == 1:
            word_to_curve[word] = curves[0]
        else:
            # Months that are NaN in every duplicate stay NaN.
            word_to_curve[word] = np.nanmean(np.vstack(curves), axis=0)
    return months, word_to_curve


def parse_step_from_fname(fname: str):
    """Extract the trailing checkpoint label number from a results filename.

    Expects names such as ``results_ckpt_idx0012_step400.pkl`` (the format
    written by the checkpoint loop) and returns the number after the label
    type (``400`` here).

    Raises:
        ValueError: if *fname* does not match that pattern.
    """
    match = re.search(r"idx\d+_([A-Za-z]+)(\d+)\.pkl$", fname)
    if match is None:
        raise ValueError(f"Cannot parse label from {fname}")
    return int(match.group(2))


def load_results_dir(results_dir: str):
    """Load every per-checkpoint results pickle in *results_dir*.

    Every ``*.pkl`` file must follow the checkpoint-loop naming scheme, since
    its label is parsed by `parse_step_from_fname` (which raises ValueError
    otherwise). Words are keyed by `normalize_cdi_label`.

    Returns:
        ``(steps, word_to_surprisal, word_to_act)``:

        * ``steps``: checkpoint labels in ascending order.
        * ``word_to_surprisal``: word -> per-step ``avg_surprisal``, with NaN
          where the word is absent at that step.
        * ``word_to_act``: word -> per-step NaN-mean over layers of
          ``avg_layer_attn``, with NaN where absent.
    """
    per_ckpt = []
    for fname in os.listdir(results_dir):
        if not fname.endswith(".pkl"):
            continue
        step = parse_step_from_fname(fname)
        with open(os.path.join(results_dir, fname), "rb") as fh:
            payload = pickle.load(fh)
        surprisal_at_step = {}
        attn_at_step = {}
        for raw_word, info in payload.items():
            key = normalize_cdi_label(raw_word)
            surprisal_at_step[key] = float(info["avg_surprisal"])
            layer_attn = info.get("avg_layer_attn", None)
            if layer_attn is not None:
                attn_at_step[key] = float(np.nanmean(np.array(layer_attn, dtype=float)))
        per_ckpt.append((step, surprisal_at_step, attn_at_step))

    per_ckpt.sort(key=lambda entry: entry[0])
    steps = [step for step, _, _ in per_ckpt]

    seen_surprisal: set[str] = set()
    seen_act: set[str] = set()
    for _, surprisal_at_step, attn_at_step in per_ckpt:
        seen_surprisal.update(surprisal_at_step.keys())
        seen_act.update(attn_at_step.keys())

    word_to_surprisal: Dict[str, List[float]] = {}
    word_to_act: Dict[str, List[float]] = {}
    for w in seen_surprisal | seen_act:
        word_to_surprisal[w] = [float(s.get(w, math.nan)) for _, s, _ in per_ckpt]
        word_to_act[w] = [float(a.get(w, math.nan)) for _, _, a in per_ckpt]
    return steps, word_to_surprisal, word_to_act


def get_simple_ranking(word_aoa: Dict[str, float], available_words: set[str], max_n: int):
    """Return up to *max_n* available words ordered by child AoA, earliest first.

    Ties are broken alphabetically. Words with a NaN AoA, or not in
    *available_words*, are excluded.
    """
    candidates = sorted(
        (w for w, aoa in word_aoa.items() if w in available_words and not math.isnan(aoa)),
        key=lambda w: (word_aoa[w], w),
    )
    return candidates[:max_n]


def compute_avg_series(word_to_series: Dict[str, List[float]], words: List[str]):
    """Return the element-wise NaN-ignoring mean of the series for *words*.

    Returns an empty float array when *words* is empty. Every word must be a
    key of *word_to_series*, and all series must have the same length.
    """
    if not words:
        return np.array([], dtype=float)
    matrix = np.stack([np.asarray(word_to_series[w], dtype=float) for w in words], axis=0)
    with np.errstate(invalid="ignore"):
        return np.nanmean(matrix, axis=0)


def compute_thresholds_per_word(
    word_to_series: Dict[str, List[float]],
    baseline_bits: float,
    words: List[str],
    steps: List[int],
):
    """Return each word's LLM AoA expressed as a raw training-step value.

    Delegates to `compute_llm_aoa_steps`, which reports crossings in
    log10(step) space, and converts each one back with ``10 ** x``.
    """
    log10_aoa = compute_llm_aoa_steps(
        word_to_series=word_to_series,
        steps=steps,
        baseline_bits=baseline_bits,
        words=words,
    )
    return {word: float(10.0 ** x_log) for word, x_log in log10_aoa.items()}


def logistic4(x, L, k, x0, b):
    """Four-parameter logistic ``L / (1 + exp(k * (x - x0))) + b``.

    The curve decreases in *x* when ``L * k > 0``. The exponent is clipped to
    [-60, 60] to avoid overflow. Accepts scalars or array-likes.
    """
    exponent = np.clip(k * (np.asarray(x, dtype=float) - x0), -60, 60)
    return b + L / (1.0 + np.exp(exponent))


def compute_llm_aoa_steps(
    word_to_series: Dict[str, List[float]],
    steps: List[int],
    baseline_bits: float,
    words: List[str],
):
    """Estimate, per word, the log10 training step where surprisal crosses its AoA threshold.

    The threshold is the midpoint between *baseline_bits* and the word's
    minimum observed surprisal. A four-parameter logistic (`logistic4`) is
    fitted to surprisal vs. log10(step), and the crossing of the fitted curve
    is found by bisection. If fitting or bracketing fails, the crossing falls
    back to the raw data: the first checkpoint interval whose endpoints
    straddle or touch the threshold, else the checkpoint closest to it.

    Steps <= 0 are treated as step 1 (log10 = 0). *steps* is assumed to be in
    ascending order. Words whose series has no finite values, or is constant,
    are omitted.

    Returns:
        Mapping word -> log10(step) of the crossing.
    """
    aoa_log10: Dict[str, float] = {}
    if not words:
        return aoa_log10

    step_arr = np.array(steps, dtype=float)
    safe_steps = np.where(step_arr > 0, step_arr, 1.0)
    log_steps = np.log10(safe_steps)
    mask_log = np.isfinite(log_steps)

    for w in words:
        s = np.array(word_to_series[w], dtype=float)
        mask = mask_log & np.isfinite(s)
        x = log_steps[mask]
        y = s[mask]
        # A series with no finite values would make np.min fail on an empty array.
        if y.size == 0:
            continue
        y_min = float(np.min(y))
        y_max = float(np.max(y))
        if not np.isfinite(y_min) or not np.isfinite(y_max) or y_max == y_min:
            continue

        thr = 0.5 * (baseline_bits + y_min)
        try:
            x_star = _fit_logistic_crossing(x, y, thr)
        except Exception:
            # Deliberate best-effort: any fit or bracketing failure uses the raw data.
            x_star = _raw_threshold_crossing(x, y, thr)

        if x_star is not None and np.isfinite(x_star):
            aoa_log10[w] = x_star

    return aoa_log10


def _fit_logistic_crossing(x, y, thr):
    """Fit `logistic4` to (x, y) and bisect for the x where the fit equals *thr*.

    Raises RuntimeError if the fit is flat or *thr* lies outside the fitted
    values at x[0] and x[-1]. Errors from curve_fit propagate.
    """
    y_min = float(np.min(y))
    y_max = float(np.max(y))
    p0 = [max(y_max - y_min, 1e-3), 1.0, float(np.median(x)), y_min]
    popt, _ = curve_fit(logistic4, x, y, p0=p0, maxfev=10000)
    L, k = popt[0], popt[1]
    if L * k == 0:
        raise RuntimeError("flat logistic fit")

    f_lo = float(logistic4(x[0], *popt))
    f_hi = float(logistic4(x[-1], *popt))
    if not (min(f_lo, f_hi) <= thr <= max(f_lo, f_hi)):
        raise RuntimeError("threshold outside fit range")

    # Bisection on [x[0], x[-1]]; the comparison direction depends on whether
    # the fitted curve increases or decreases over that interval.
    increasing = f_lo <= f_hi
    lo_x, hi_x = x[0], x[-1]
    for _ in range(60):
        mid_x = 0.5 * (lo_x + hi_x)
        val = float(logistic4(mid_x, *popt))
        move_lo = val < thr if increasing else val > thr
        if move_lo:
            lo_x = mid_x
        else:
            hi_x = mid_x
    return 0.5 * (lo_x + hi_x)


def _raw_threshold_crossing(x, y, thr):
    """Return x at the first interval end straddling/touching *thr*, else x closest to it."""
    for j in range(1, len(y)):
        y_prev, y_curr = y[j - 1], y[j]
        if not (np.isfinite(y_prev) and np.isfinite(y_curr)):
            continue
        if (y_prev - thr) * (y_curr - thr) <= 0:
            return float(x[j])
    return float(x[int(np.argmin(np.abs(y - thr)))])


def normalize_x(xs: np.ndarray):
    """Rescale *xs* linearly so that xs[0] maps to 0 and xs[-1] maps to 1.

    An empty input is returned as an empty float array. When xs[-1] <= xs[0]
    the result is all zeros.
    """
    arr = np.array(xs, dtype=float)
    if arr.size == 0:
        return arr
    first = arr[0]
    span = arr[-1] - first
    return np.zeros_like(arr) if span <= 0 else (arr - first) / span


def compute_child_interp_aoa(wordbank_csv: str):
    """Return each word's child AoA in months, linearly interpolated at proportion 0.5.

    NaN months are dropped from each curve returned by `load_wordbank_curves`.
    The first adjacent pair of remaining points whose values straddle or
    touch 0.5 defines the crossing. The month is linearly interpolated
    between the two points, or taken as the later month if their values are
    equal. Words with fewer than two valid points, no crossing, or a
    non-positive result are omitted.
    """
    months, word_to_curve = load_wordbank_curves(wordbank_csv)
    month_grid = np.array(months, dtype=float)

    result = {}
    for word, raw_curve in word_to_curve.items():
        curve = np.array(raw_curve, dtype=float)
        valid = np.isfinite(curve)
        if not valid.any():
            continue

        m = month_grid[valid]
        y = curve[valid]
        if m.size < 2:
            continue

        crossing = next(
            (
                j
                for j in range(1, y.size)
                if np.isfinite(y[j - 1])
                and np.isfinite(y[j])
                and (y[j - 1] - 0.5) * (y[j] - 0.5) <= 0
            ),
            None,
        )
        if crossing is None:
            continue

        m1, m2 = m[crossing - 1], m[crossing]
        y1, y2 = y[crossing - 1], y[crossing]
        if not (np.isfinite(y1) and np.isfinite(y2)) or m2 == m1:
            continue

        aoa_month = m2 if y2 == y1 else m1 + (0.5 - y1) * (m2 - m1) / (y2 - y1)
        if np.isfinite(aoa_month) and aoa_month > 0:
            result[word] = float(aoa_month)

    return result


def normalize_x_aligned(xs, align_idx, child_aoa_x):
    """Normalize *xs* to [0, 1], then rescale so xs[align_idx] lands on *child_aoa_x*.

    This lets an LLM trajectory be overlaid on the normalized child timeline
    with its threshold-crossing checkpoint at the child AoA.

    Fallbacks:

    * *align_idx* out of range: `normalize_x` is used.
    * Degenerate range (xs[-1] <= xs[0]): all zeros, rather than a division
      by zero.
    * xs[align_idx] normalizes to 0, or *child_aoa_x* is None, non-finite
      or 0: the unscaled [0, 1] normalization is returned.
    """
    xs = np.array(xs, dtype=float)
    if xs.size == 0 or align_idx >= len(xs) or align_idx < 0:
        return normalize_x(xs)

    x_start = xs[0]
    span = xs[-1] - x_start
    if span <= 0:
        return np.zeros_like(xs)
    xs_norm = (xs - x_start) / span
    align_norm = xs_norm[align_idx]

    if align_norm == 0 or child_aoa_x is None or not np.isfinite(child_aoa_x) or child_aoa_x == 0:
        return xs_norm

    return xs_norm * (child_aoa_x / align_norm)


def compute_threshold_crossing_idx(s, baseline_bits):
    """Return the first index where surprisal *s* drops to its AoA threshold.

    The threshold is the midpoint between *baseline_bits* and the smallest
    finite-aware value of *s*. Returns None if *s* has no finite values, and
    ``len(s) - 1`` if no finite value reaches the threshold.
    """
    series = np.array(s, dtype=float)
    finite = np.isfinite(series)
    if not finite.any():
        return None
    threshold = 0.5 * (baseline_bits + float(np.nanmin(series)))
    hits = np.flatnonzero(finite & (series <= threshold))
    return int(hits[0]) if hits.size else len(series) - 1


def crop_with_threshold(s, steps_arr, baseline_bits, margin_idx):
    """Truncate a surprisal series *margin_idx* points after it reaches its AoA threshold.

    The threshold is the midpoint between *baseline_bits* and the minimum of
    *s*. The crossing index is the first finite value at or below the
    threshold, or the last index if there is none.

    Returns:
        ``(s_crop, steps_crop, thr, idx_cross)``, where both crops end at
        ``min(idx_cross + margin_idx, len(s) - 1)`` inclusive. Returns four
        Nones when *s* has no finite values.
    """
    series = np.array(s, dtype=float)
    finite = np.isfinite(series)
    if not finite.any():
        return None, None, None, None
    thr = 0.5 * (baseline_bits + float(np.nanmin(series)))
    hits = np.flatnonzero(finite & (series <= thr))
    idx_cross = int(hits[0]) if hits.size else len(series) - 1
    last = min(idx_cross + margin_idx, len(series) - 1)
    return series[: last + 1], steps_arr[: last + 1], thr, idx_cross

In [None]:
# Load the child (Wordbank) data and both models' per-checkpoint results, then
# rank the words present for both models by child AoA (earliest first).
os.makedirs(FIG_OUT_DIR, exist_ok=True)

word_aoa = load_wordbank_aoa(WORDBANK_PATH)
months, word_to_curve = load_wordbank_curves(WORDBANK_PATH)

# load_results_dir returns (steps, word -> surprisal series, word -> attention series).
steps_small, small_surpr, small_act = load_results_dir(SMALL_DIR)
steps_medium, medium_surpr, medium_act = load_results_dir(MEDIUM_DIR)

words_small = set(small_surpr.keys())
words_medium = set(medium_surpr.keys())
available_words = words_small & words_medium

# Rank enough words for the largest top-k set (KS).
simple_ranking = get_simple_ranking(word_aoa, available_words, max(KS))

In [None]:
# Child trajectories vs LLM surprisal (per word)
if simple_ranking:
    max_words_side_by_side = NUM_PLOTS
    n_plotted = 0
    margin_idx = 3
    # simple_ranking is ordered by child AoA; plotting starts at this rank.
    start_rank = 150
    months_arr = np.array(months, dtype=float)
    if months_arr.size > 1:
        months_norm = (months_arr - months_arr.min()) / (months_arr.max() - months_arr.min())
    else:
        months_norm = np.zeros_like(months_arr)
    # compute_child_interp_aoa is a function; call it once for word -> AoA month.
    child_interp_aoa = compute_child_interp_aoa(WORDBANK_PATH)
    steps_small_arr = np.array(steps_small, dtype=float)
    steps_medium_arr = np.array(steps_medium, dtype=float)

    for w in simple_ranking[start_rank:]:
        if w not in word_to_curve:
            continue
        child_curve = word_to_curve[w]
        if not np.isfinite(child_curve).any():
            continue
        s_small = np.array(small_surpr[w], dtype=float)
        s_medium = np.array(medium_surpr[w], dtype=float)
        if not (np.isfinite(s_small).any() or np.isfinite(s_medium).any()):
            continue

        s_small_crop, steps_small_crop, thr_small, idx_cross_small = crop_with_threshold(
            s_small, steps_small_arr, BASELINE_BITS, margin_idx
        )
        s_medium_crop, steps_medium_crop, thr_medium, idx_cross_medium = crop_with_threshold(
            s_medium, steps_medium_arr, BASELINE_BITS, margin_idx
        )

        if s_small_crop is None and s_medium_crop is None:
            continue

        fig, axes = plt.subplots(1, 3, figsize=(12, 3))

        # Panel 1: cropped LLM surprisal trajectories with their AoA thresholds.
        if s_small_crop is not None and len(s_small_crop) > 0:
            axes[0].plot(steps_small_crop, s_small_crop, label="gpt2-small")
            if thr_small is not None:
                axes[0].axhline(thr_small, linestyle="--", linewidth=1, color="cornflowerblue")
        if s_medium_crop is not None and len(s_medium_crop) > 0:
            axes[0].plot(steps_medium_crop, s_medium_crop, label="gpt2-medium")
            if thr_medium is not None:
                axes[0].axhline(thr_medium, linestyle="--", linewidth=1, color="orange")

        axes[0].set_xlabel("Step")
        axes[0].set_ylabel("Surprisal (bits)")
        axes[0].set_title(f"{w} - LLM surprisal")
        axes[0].legend()
        axes[0].invert_yaxis()

        # Panel 2: child production curve with the 0.5 AoA threshold.
        axes[1].plot(months, child_curve, marker="o")
        axes[1].axhline(0.5, linestyle="--", linewidth=1, color="cornflowerblue")
        axes[1].set_xlabel("Age (months)")
        axes[1].set_ylabel("Proportion producing")
        axes[1].set_ylim(0.0, 1.0)
        axes[1].set_title(f"{w} - Children")

        # Panel 3: LLM and child curves on one normalized, AoA-aligned timeline.
        ax3 = axes[2]
        ax3b = ax3.twinx()

        # Align AoA: position of the child AoA on the normalized month axis.
        child_aoa_x = None
        if w in child_interp_aoa:
            child_aoa_month = child_interp_aoa[w]
            if months_arr.size > 1:
                child_aoa_x = (child_aoa_month - months_arr.min()) / (months_arr.max() - months_arr.min())
                child_aoa_x = np.clip(child_aoa_x, 0.0, 1.0)
        else:
            child_mask = np.isfinite(child_curve)
            if child_mask.any():
                first_valid_val = child_curve[child_mask][0]
                if first_valid_val >= 0.5:
                    # Already at or above 0.5 at the first measured age.
                    child_aoa_x = 0.0
                else:
                    plt.close(fig)
                    continue

        if child_aoa_x is None:
            plt.close(fig)
            continue

        if np.isfinite(s_small).any() and idx_cross_small is not None:
            x_small_norm = normalize_x_aligned(steps_small_arr, idx_cross_small, child_aoa_x)
            ax3.plot(x_small_norm, s_small, label="gpt2-small")
        if np.isfinite(s_medium).any() and idx_cross_medium is not None:
            x_medium_norm = normalize_x_aligned(steps_medium_arr, idx_cross_medium, child_aoa_x)
            ax3.plot(x_medium_norm, s_medium, label="gpt2-medium")

        child_mask = np.isfinite(child_curve)
        if child_mask.any():
            ax3b.plot(months_norm[child_mask], child_curve[child_mask], marker="o", color="green", label="children")

        all_s_vals = []
        for arr in (s_small, s_medium):
            if arr is not None:
                arr = np.asarray(arr, dtype=float)
                arr = arr[np.isfinite(arr)]
                if arr.size > 0:
                    all_s_vals.extend(arr.tolist())

        # Align AoA threshold: center the (inverted) y-range on an LLM threshold,
        # using the medium model's when the small model has none.
        align_thr = thr_small if thr_small is not None else thr_medium
        if all_s_vals and align_thr is not None:
            min_surprisal = float(min(all_s_vals))
            max_surprisal = float(max(all_s_vals))
            data_range = max_surprisal - min_surprisal
            if data_range <= 0:
                data_range = 1.0
            buffer = 0.1 * data_range
            half_total = 0.5 * data_range + buffer
            y_max = align_thr + half_total
            y_min = align_thr - half_total
            ax3.set_ylim(y_max, y_min)

        ax3.set_xlim(0.0, 1.0)
        ax3.set_xlabel("Normalized timeline")
        ax3.set_ylabel("Surprisal (bits)")
        ax3b.set_ylabel("Proportion producing")
        ax3b.set_ylim(0.0, 1.0)
        ax3.set_title(f"{w} - Normalized aligned overlay")

        handles1, labels1 = ax3.get_legend_handles_labels()
        handles2, labels2 = ax3b.get_legend_handles_labels()
        if handles1 or handles2:
            ax3.legend(handles1 + handles2, labels1 + labels2, loc="best")

        fig.tight_layout()
        safe_w = re.sub(r"[^A-Za-z0-9]+", "_", w).strip("_")
        out_path_word = os.path.join(FIG_OUT_DIR, f"word_{safe_w}_child_vs_llm_surprisal.png")
        fig.savefig(out_path_word, bbox_inches="tight")
        plt.close(fig)
        n_plotted += 1
        if n_plotted >= max_words_side_by_side:
            break