In [1]:
import glob
import json
import os
import random
from collections import Counter
from pathlib import Path
from statistics import median

import pandas as pd
import plotly.express as px


In [2]:
def find_repo_root(start: Path) -> Path:
    """Return the nearest directory at or above ``start`` that is the repo root.

    A directory qualifies when it contains both a ``pyproject.toml`` entry and
    a ``data`` entry. ``start`` is checked first, then its parents from the
    closest to the farthest.

    Raises:
        FileNotFoundError: if neither ``start`` nor any of its parents qualifies.
    """
    required_entries = ("pyproject.toml", "data")
    for directory in (start, *start.parents):
        if all((directory / entry).exists() for entry in required_entries):
            return directory
    raise FileNotFoundError("Could not locate repo root with pyproject.toml and data/")


def discover_input_files(input_glob: str, base_dir: Path | None = None) -> list[Path]:
    base = (base_dir or Path.cwd()).resolve()
    expanded_glob = str(Path(input_glob).expanduser())
    if not Path(expanded_glob).is_absolute():
        expanded_glob = str(base / expanded_glob)

    matches = [Path(path).resolve() for path in glob.glob(expanded_glob, recursive=True)]
    json_files = [
        path for path in matches if path.is_file() and path.suffix.lower() == ".json"
    ]
    return sorted(set(json_files))


# Locate the repository starting from the current working directory, then
# collect every .json file found anywhere below its data/ directory.
path_glob = "data/**/*.json"
repo_root = find_repo_root(Path.cwd())
files = discover_input_files(path_glob, base_dir=repo_root)

print(
    f"cwd={Path.cwd()}",
    f"There are {len(files)} files in {repo_root / path_glob}",
    sep="\n",
)

cwd=/home/israel/traj_distill/notebooks
There are 52 files in /home/israel/traj_distill/data/**/*.json


In [3]:
# Count the messages with role "assistant" in every input file.
# assistant_messages[i] is the count for files[i].
assistant_messages = []

for input_path in files:
    cur_trajectory = json.loads(input_path.read_text(encoding="utf-8"))
    message_list = cur_trajectory["messages"]

    # Explicit check rather than `assert`: assert statements are removed when
    # Python runs with -O, and this error names the offending file.
    if not isinstance(message_list, list):
        raise TypeError(
            f"{input_path}: `messages` must be a list, got {type(message_list).__name__}."
        )

    assistant_messages.append(
        sum(1 for message in message_list if message["role"] == "assistant")
    )

In [4]:
# Frequency table: for each assistant-message count, how many trajectories
# contain exactly that many assistant messages.
freq = Counter(assistant_messages)
message_counts = sorted(freq)
df = pd.DataFrame(
    {
        "assistant_messages": message_counts,
        "num_trajectories": [freq[count] for count in message_counts],
    }
)

fig = px.bar(
    df,
    x="assistant_messages",
    y="num_trajectories",
    title="Distribution of Assistant Messages per Reasoning Trajectory",
    labels={
        "assistant_messages": "Assistant messages in trajectory",
        "num_trajectories": "Number of trajectories",
    },
)
fig.update_traces(hovertemplate="assistant_messages=%{x}<br>trajectories=%{y}<extra></extra>")
fig.show()

print(f"Trajectories: {len(assistant_messages)}")
# statistics.median averages the two middle values of an even-sized sample;
# indexing sorted(values)[n // 2] would report the upper-middle value instead.
# The `g` format prints whole medians without a trailing ".0".
print(
    f"Min/Median/Max: {min(assistant_messages)}/"
    f"{median(assistant_messages):g}/{max(assistant_messages)}"
)

Trajectories: 52
Min/Median/Max: 7/23/71


In [5]:
def normalize_message(message: dict, location: str) -> dict[str, str]:
    """Validate one chat message and reduce it to its ``role`` and ``content``.

    Args:
        message: Raw message; must be a dict with string ``role`` and
            ``content`` entries.
        location: Human-readable position of the message, used as the prefix
            of every error message.

    Returns:
        A new dict holding only ``role`` (surrounding whitespace stripped) and
        ``content`` (unchanged). Any other keys are dropped.

    Raises:
        ValueError: if ``message`` is not a dict, ``role`` is not a non-blank
            string, or ``content`` is not a string.
    """
    if not isinstance(message, dict):
        raise ValueError(f"{location} must be a dict.")

    raw_role, raw_content = message.get("role"), message.get("content")

    role_is_valid = isinstance(raw_role, str) and bool(raw_role.strip())
    if not role_is_valid:
        raise ValueError(f"{location}.role must be a non-empty string.")
    if not isinstance(raw_content, str):
        raise ValueError(f"{location}.content must be a string.")

    return {"role": raw_role.strip(), "content": raw_content}


def validate_trajectory_messages(messages: list[dict], source_path: Path) -> None:
    """Check that a trajectory is ``system`` followed by alternating user/assistant turns.

    Args:
        messages: Raw message list of one trajectory.
        source_path: File the messages came from; used only in error messages.

    Raises:
        ValueError: if ``messages`` is not a non-empty list, any message fails
            ``normalize_message``, the first role is not ``system``, or a later
            message does not have the role expected at its position
            (``user`` at odd indices, ``assistant`` at even indices).
    """
    if not (isinstance(messages, list) and messages):
        raise ValueError(f"{source_path}: `messages` must be a non-empty list.")

    head = normalize_message(messages[0], f"{source_path} messages[0]")
    if head["role"] != "system":
        raise ValueError(f"{source_path}: first role must be `system`.")

    # Index 0 is the system prompt; from index 1 on, odd positions must be
    # user turns and even positions assistant turns.
    for position in range(1, len(messages)):
        actual_role = normalize_message(
            messages[position], f"{source_path} messages[{position}]"
        )["role"]
        expected = "assistant" if position % 2 == 0 else "user"
        if actual_role == expected:
            continue
        raise ValueError(
            f"{source_path}: messages[{position}] role is {actual_role!r}, expected {expected!r}."
        )


def split_cot_and_answer(content: str) -> tuple[str, str] | None:
    split_index = content.find("```bash")
    if split_index == -1:
        return None

    cot = content[:split_index].rstrip()
    answer = content[split_index:]
    return cot, answer


def expand_trajectory_to_samples(
    *,
    messages: list[dict],
    trajectory_id: str,
    source_path: Path,
) -> list[dict]:
    """Turn one trajectory into one training sample per usable assistant turn.

    Every assistant message whose content contains a ```bash fence (see
    ``split_cot_and_answer``) yields a sample. The sample's ``question`` is a
    copy of all messages that precede the assistant message, and ``cot`` /
    ``answer`` are the two halves of its content. Assistant messages without
    the fence produce no sample but still advance the 1-based assistant turn
    counter and are still part of the context of later samples.

    Args:
        messages: Raw message list of the trajectory.
        trajectory_id: Identifier copied into each sample and used as the
            prefix of ``sample_id``.
        source_path: Originating file, recorded in each sample and used in
            error messages.

    Raises:
        ValueError: if a message fails ``normalize_message`` or an assistant
            message is not immediately preceded by a user message.
    """
    collected: list[dict] = []
    history: list[dict[str, str]] = []
    turn_counter = 0

    for position, raw_message in enumerate(messages):
        current = normalize_message(raw_message, f"{source_path} messages[{position}]")

        if current["role"] != "assistant":
            history.append(current)
            continue

        turn_counter += 1
        follows_user = bool(history) and history[-1]["role"] == "user"
        if not follows_user:
            raise ValueError(
                f"{source_path}: assistant at index {position} must follow a user message."
            )

        pieces = split_cot_and_answer(current["content"])
        if pieces is not None:
            reasoning, command = pieces
            # `history` does not yet contain the current assistant message, so
            # the context excludes the target. Messages are copied so samples
            # do not share dict objects with the running history.
            collected.append(
                {
                    "sample_id": f"{trajectory_id}::a{turn_counter}",
                    "trajectory_id": trajectory_id,
                    "source_path": str(source_path),
                    "assistant_turn_index": turn_counter,
                    "target_message_index": position,
                    "question": [dict(entry) for entry in history],
                    "cot": reasoning,
                    "answer": command,
                    "target": dict(current),
                }
            )

        history.append(current)

    return collected


In [6]:
def write_jsonl(path: Path, records: list[dict]) -> None:
    """Write ``records`` to ``path``, one ASCII-escaped JSON object per line."""
    with path.open("w", encoding="utf-8") as handle:
        for record in records:
            handle.write(json.dumps(record, ensure_ascii=True) + "\n")


# Split configuration: the seed makes the shuffle reproducible, the fraction is
# the share of trajectories (not samples) assigned to the eval split.
SPLIT_SEED = 42
EVAL_FRACTION = 0.20

all_samples = []
samples_per_trajectory = []

for input_path in files:
    payload = json.loads(input_path.read_text(encoding="utf-8"))

    if not isinstance(payload, dict):
        raise ValueError(f"{input_path}: top-level JSON must be an object.")
    if "messages" not in payload:
        raise ValueError(f"{input_path}: missing `messages`.")

    messages = payload["messages"]
    validate_trajectory_messages(messages, input_path)

    # Fall back to the file stem when the payload has no (or an empty) instance_id.
    trajectory_id = str(payload.get("instance_id") or input_path.stem)
    samples = expand_trajectory_to_samples(
        messages=messages,
        trajectory_id=trajectory_id,
        source_path=input_path,
    )

    all_samples.extend(samples)
    samples_per_trajectory.append(len(samples))

output_path = repo_root / "data" / "constructed" / "context_target_v1.jsonl"
output_path.parent.mkdir(parents=True, exist_ok=True)
write_jsonl(output_path, all_samples)

# Split by trajectory_id so all samples of one trajectory land in the same
# split. Sorting before the seeded shuffle keeps the result deterministic.
trajectory_ids = sorted({sample["trajectory_id"] for sample in all_samples})
if not trajectory_ids:
    raise ValueError("No samples were created, cannot produce train/eval splits.")

rng = random.Random(SPLIT_SEED)
shuffled_ids = trajectory_ids[:]
rng.shuffle(shuffled_ids)

if len(shuffled_ids) == 1:
    eval_count = 1
else:
    eval_count = int(round(EVAL_FRACTION * len(shuffled_ids)))
    # Keep at least one trajectory in each split.
    eval_count = max(1, min(eval_count, len(shuffled_ids) - 1))

eval_trajectory_ids = set(shuffled_ids[:eval_count])
train_trajectory_ids = set(shuffled_ids[eval_count:])

train_samples = [
    sample for sample in all_samples if sample["trajectory_id"] in train_trajectory_ids
]
eval_samples = [
    sample for sample in all_samples if sample["trajectory_id"] in eval_trajectory_ids
]

train_path = output_path.with_name("context_target_v1_train.jsonl")
eval_path = output_path.with_name("context_target_v1_eval.jsonl")
write_jsonl(train_path, train_samples)
write_jsonl(eval_path, eval_samples)

print(f"Wrote {len(all_samples)} samples to {output_path}")
print(
    f"Split trajectories: train={len(train_trajectory_ids)} eval={len(eval_trajectory_ids)} "
    f"(seed={SPLIT_SEED})"
)
print(f"Wrote {len(train_samples)} train samples to {train_path}")
print(f"Wrote {len(eval_samples)} eval samples to {eval_path}")


Wrote 1405 samples to /home/israel/traj_distill/data/constructed/context_target_v1.jsonl
Split trajectories: train=42 eval=10 (seed=42)
Wrote 1141 train samples to /home/israel/traj_distill/data/constructed/context_target_v1_train.jsonl
Wrote 264 eval samples to /home/israel/traj_distill/data/constructed/context_target_v1_eval.jsonl


In [7]:
# `assistant_messages` is produced by the earlier counting cell; the
# cross-check is skipped when that cell has not been run in this kernel.
if "assistant_messages" in globals():
    expected = sum(assistant_messages)
    # Each assistant message yields at most one sample.
    assert len(all_samples) <= expected, (
        f"sample count mismatch: got {len(all_samples)}, expected <= {expected}"
    )
    print(f"Skipped assistant turns without ```bash: {expected - len(all_samples)}")

print(f"Trajectories: {len(samples_per_trajectory)}")
print(f"Samples: {len(all_samples)}")
# Mean is computed inline so no extra import is needed; the guard avoids
# min()/max()/division errors when no trajectory was processed.
if samples_per_trajectory:
    mean_samples = sum(samples_per_trajectory) / len(samples_per_trajectory)
    print(
        "Samples per trajectory (min/mean/max): "
        f"{min(samples_per_trajectory)}/{mean_samples:.2f}/{max(samples_per_trajectory)}"
    )

def _shorten(text: str, limit: int = 140) -> str:
    """Return ``text`` as is when it has at most ``limit`` characters.

    Longer text is cut to its first ``limit`` characters and "..." is appended.
    """
    if len(text) > limit:
        return text[:limit] + "..."
    return text

# Print a compact view of the first two samples: the long text fields are
# shortened, the context messages are shown in full.
for sample in all_samples[:2]:
    preview = {
        key: sample[key] for key in ("sample_id", "assistant_turn_index", "question")
    }
    preview["cot_preview"] = _shorten(sample["cot"])
    preview["answer_preview"] = _shorten(sample["answer"])
    preview["target_preview"] = _shorten(sample["target"]["content"])
    print(json.dumps(preview, indent=2, ensure_ascii=True))


Skipped assistant turns without ```bash: 8
Trajectories: 52
Samples: 1405
{
  "sample_id": "astropy__astropy-14995::a1",
  "assistant_turn_index": 1,
  "question": [
    {
      "role": "system",
      "content": "You are a helpful assistant that can interact multiple times with a computer shell to solve programming tasks.\nYour response must contain exactly ONE bash code block with ONE command (or commands connected with && or ||).\n\nInclude a THOUGHT section before your command where you explain your reasoning process.\nFormat your response as shown in <format_example>.\n\n<format_example>\nTHOUGHT: Your reasoning and analysis here\n\n```bash\nyour_command_here\n```\n</format_example>\n\nFailure to follow these rules will cause your response to be rejected."
    },
    {
      "role": "user",
      "content": "<pr_description>\nConsider the following PR description:\nIn v5.3, NDDataRef mask propagation fails when one of the operand does not have a mask\n### Description\n\nThis appli

In [12]:
# NOTE(review): this cell's execution count (12) is out of sequence with its
# neighbours (7 before, 8 after), i.e. it was last run after the cells below it.
# Displays the second constructed sample in full.
all_samples[1]

{'sample_id': 'astropy__astropy-14995::a2',
 'trajectory_id': 'astropy__astropy-14995',
 'source_path': '/home/israel/traj_distill/data/resolved_trajectories_openai__Qwen__Qwen3-Coder-30B-A3B-Instruct.eval_default/astropy__astropy-14995__astropy__astropy-14995.traj.json',
 'assistant_turn_index': 2,
 'target_message_index': 4,
 'question': [{'role': 'system',
   'content': 'You are a helpful assistant that can interact multiple times with a computer shell to solve programming tasks.\nYour response must contain exactly ONE bash code block with ONE command (or commands connected with && or ||).\n\nInclude a THOUGHT section before your command where you explain your reasoning process.\nFormat your response as shown in <format_example>.\n\n<format_example>\nTHOUGHT: Your reasoning and analysis here\n\n```bash\nyour_command_here\n```\n</format_example>\n\nFailure to follow these rules will cause your response to be rejected.'},
  {'role': 'user',
   'content': '<pr_description>\nConsider th

In [8]:
from transformers import AutoTokenizer

MODEL_ID = "Qwen/Qwen3-1.7B"

try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
except Exception:
    # Fallback for offline sessions with a cached snapshot.
    # The cache folder name is derived from MODEL_ID ("org/name" ->
    # "models--org--name") so the two cannot drift apart.
    cache_dir_name = "models--" + MODEL_ID.replace("/", "--")
    snapshot_roots = [
        candidate
        for candidate in Path.home().glob(
            f".cache/huggingface/hub/{cache_dir_name}/snapshots/*"
        )
        if candidate.is_dir()
    ]
    if not snapshot_roots:
        # Nothing cached: surface the original loading error.
        raise
    # Snapshot folder names are not chronological, so sorting by name does not
    # give the latest one; pick the most recently modified folder instead.
    newest_snapshot = max(snapshot_roots, key=lambda path: path.stat().st_mtime)
    tokenizer = AutoTokenizer.from_pretrained(
        str(newest_snapshot),
        use_fast=True,
        local_files_only=True,
    )

print(tokenizer.__class__.__name__)


Qwen2TokenizerFast


In [9]:
dataset_path = repo_root / "data" / "constructed" / "context_target_v1.jsonl"

# Prefer the samples already held in memory by the construction cell; when
# they are missing or empty, read them back from the JSONL file instead.
in_memory_samples = globals().get("all_samples")
if isinstance(in_memory_samples, list) and in_memory_samples:
    samples_for_counting = in_memory_samples
else:
    with dataset_path.open("r", encoding="utf-8") as handle:
        # Blank lines are skipped.
        samples_for_counting = [json.loads(line) for line in handle if line.strip()]


def count_context_tokens(question_messages: list[dict]) -> int:
    """Return the number of prompt tokens for a list of chat messages.

    The messages are rendered with the tokenizer's chat template, including
    the generation prompt (``add_generation_prompt=True``), and tokenized.
    """
    token_ids = tokenizer.apply_chat_template(
        question_messages,
        tokenize=True,
        add_generation_prompt=True,
        # Request a plain list of token ids: with a dict-like return value,
        # len() would count its fields rather than the tokens.
        return_dict=False,
    )
    return len(token_ids)


def count_target_tokens(target: dict) -> int:
    """Return the token count of ``target["content"]``.

    The text is encoded directly (no chat template) and without special tokens.
    """
    content_text = target["content"]
    return len(tokenizer.encode(content_text, add_special_tokens=False))


# Build one row of token counts per sample: context (prompt) tokens, target
# tokens, and their sum.
token_rows = []
for sample in samples_for_counting:
    n_context = count_context_tokens(sample["question"])
    n_target = count_target_tokens(sample["target"])
    token_rows.append(
        dict(
            sample_id=sample["sample_id"],
            context_tokens=n_context,
            target_tokens=n_target,
            total_tokens=n_context + n_target,
        )
    )

token_df = pd.DataFrame(token_rows)
print(f"Counted tokens for {len(token_df)} samples")


Counted tokens for 1405 samples


In [10]:
def summarize_tokens(series: pd.Series) -> dict[str, float]:
    """Summarize a numeric series as a dict of floats.

    Returns:
        A dict with, in this order, ``count``, ``sum``, ``min``, ``median``,
        ``mean``, the 90th/95th/99th percentiles (``p90``, ``p95``, ``p99``)
        and ``max``. Every value is converted to a Python ``float``.
    """
    summary = {
        "count": series.count(),
        "sum": series.sum(),
        "min": series.min(),
        "median": series.median(),
        "mean": series.mean(),
    }
    for label, fraction in (("p90", 0.90), ("p95", 0.95), ("p99", 0.99)):
        summary[label] = series.quantile(fraction)
    summary["max"] = series.max()
    return {name: float(value) for name, value in summary.items()}


# One column per token measure, one row per statistic from summarize_tokens.
summary_df = pd.DataFrame(
    {
        column: summarize_tokens(token_df[column])
        for column in ("context_tokens", "target_tokens", "total_tokens")
    }
)

summary_df


Unnamed: 0,context_tokens,target_tokens,total_tokens
count,1405.0,1405.0,1405.0
sum,11252320.0,213656.0,11465980.0
min,1279.0,9.0,1362.0
median,7106.0,95.0,7209.0
mean,8008.77,152.068327,8160.838
p90,14740.4,303.0,14934.4
p95,16469.0,396.8,16591.8
p99,19216.12,680.0,19444.68
max,26318.0,7672.0,26341.0


In [11]:
# Dataset-wide token totals, one line per measure.
print("Token sums across dataset:")
for column in ("context_tokens", "target_tokens", "total_tokens"):
    print(f"{column}={int(token_df[column].sum())}")

# One 60-bin histogram per measure.
histogram_titles = {
    "context_tokens": "Context token distribution",
    "target_tokens": "Target token distribution",
    "total_tokens": "Total token distribution",
}
for column, title in histogram_titles.items():
    fig = px.histogram(token_df, x=column, nbins=60, title=title)
    fig.update_layout(bargap=0.05)
    fig.show()

# The five samples with the largest total token count.
print("Top 5 by total_tokens:")
largest_first = token_df.sort_values("total_tokens", ascending=False)
display(largest_first.head(5))


Token sums across dataset:
context_tokens=11252322
target_tokens=213656
total_tokens=11465978


Top 5 by total_tokens:


Unnamed: 0,sample_id,context_tokens,target_tokens,total_tokens
50,astropy__astropy-14995::a51,26318,23,26341
49,astropy__astropy-14995::a50,25969,193,26162
48,astropy__astropy-14995::a49,25385,355,25740
47,astropy__astropy-14995::a48,25295,52,25347
46,astropy__astropy-14995::a47,24508,61,24569
