In [56]:
import sqlite3
import pandas as pd
from transformers import AutoTokenizer

def analyze_db(db_file: str, table_name: str = 'dataset') -> None:
    """
    Print emptiness statistics for the 'trace' column of a SQLite table.

    A trace is treated as *empty* when it is NULL, the empty string, or
    consists only of whitespace. The function prints the total row count,
    the NULL count, the empty count, the non-empty count and the zero-based
    DataFrame indices of the non-empty rows.

    Args:
        db_file: Path to the SQLite database file.
        table_name: Table to read. It is interpolated into the SQL text
            (identifiers cannot be bound as parameters), so it must come
            from a trusted source.
    """
    # Load only the column we need; always release the handle, even when the
    # query fails (e.g. missing table), so the file is not left locked.
    conn = sqlite3.connect(db_file)
    try:
        df = pd.read_sql(f"SELECT trace FROM {table_name}", conn)
    finally:
        conn.close()

    trace = df['trace']
    is_null = trace.isna()
    # fillna + astype(str) keeps the `.str` accessor usable even when the
    # column does not hold text; NULLs become '' and therefore count as blank.
    is_blank = trace.fillna('').astype(str).str.strip().eq('')
    non_empty_mask = ~is_blank

    # Metrics
    total_rows = len(df)
    null_rows = is_null.sum()
    empty_rows = is_blank.sum()          # NULL + '' + whitespace-only
    non_empty_rows = non_empty_mask.sum()
    non_empty_indices = df.index[non_empty_mask].tolist()

    # Output
    print(f"Total rows:        {total_rows}")
    print(f"Null rows:         {null_rows}")
    print(f"Empty rows:        {empty_rows}")
    print(f"Non-empty rows:    {non_empty_rows}\n")
    print("Non-empty row indices:")
    print(non_empty_indices)

def merge_db(
    db1_file: str,
    db2_file: str,
    table_name: str = 'dataset',
    out_db_file: str = 'merged.db'
) -> None:
    """
    Merge the 'trace' column of two SQLite DBs that hold the same rows.

    All columns are taken from DB1; only 'trace' is merged, row by row
    (rows are matched by position, so both tables must have the same row
    order and the same number of rows):

    * if exactly one side has a non-blank trace, that trace is used;
    * if both sides are blank (NULL, '' or whitespace), DB1's value is kept
      and the row is not reported as a conflict;
    * if both sides are identical, that value is used;
    * otherwise the row is a conflict and the trace with fewer tokens
      (Qwen/Qwen3-0.6B tokenizer, ties go to DB1) is chosen.

    The merged table is written to `out_db_file`, replacing any existing
    table called `table_name` there, and a short report is printed.

    Args:
        db1_file: Path of the primary database (source of all other columns).
        db2_file: Path of the secondary database (only 'trace' is read).
        table_name: Table name used in all three databases. Interpolated
            into SQL text, so it must be trusted.
        out_db_file: Path of the database to write.

    Raises:
        AssertionError: If the two tables have different row counts.
    """
    # load full df from DB1
    conn1 = sqlite3.connect(db1_file)
    try:
        df1 = pd.read_sql(f"SELECT * FROM {table_name}", conn1)
    finally:
        conn1.close()

    # load only trace from DB2
    conn2 = sqlite3.connect(db2_file)
    try:
        df2 = pd.read_sql(f"SELECT trace FROM {table_name}", conn2)
    finally:
        conn2.close()

    assert len(df1) == len(df2), "Row counts differ"

    t1 = df1['trace'].fillna('').astype(str)
    t2 = df2['trace'].fillna('').astype(str)

    # The tokenizer is only needed to break real conflicts; loading it lazily
    # avoids a model download when the two DBs never disagree.
    tokenizer = None

    def token_count(text: str) -> int:
        nonlocal tokenizer
        if tokenizer is None:
            # NOTE(review): trust_remote_code=True executes code shipped with
            # the checkpoint; kept from the original, confirm it is required.
            tokenizer = AutoTokenizer.from_pretrained(
                "Qwen/Qwen3-0.6B", trust_remote_code=True
            )
        return len(tokenizer(text).input_ids)

    merged_traces = []
    conflicts = []

    for i, (a, b) in enumerate(zip(t1, t2)):
        empty1 = not a.strip()
        empty2 = not b.strip()
        if empty1 and empty2:
            # nothing useful on either side (e.g. '' vs '  '): not a conflict
            choice = a
        elif empty1:
            choice = b
        elif empty2:
            choice = a
        elif a == b:
            choice = a
        else:
            # conflict: pick shorter in tokens (ties keep DB1)
            choice = a if token_count(a) <= token_count(b) else b
            conflicts.append(i)
        merged_traces.append(choice)

    # assign merged traces back to df1
    df1['trace'] = merged_traces

    # write merged DB (if_exists='replace' drops and recreates the table)
    conn_out = sqlite3.connect(out_db_file)
    try:
        df1.to_sql(table_name, conn_out, index=False, if_exists='replace')
    finally:
        conn_out.close()

    # report
    print(f"Total rows:      {len(df1)}")
    print(f"Conflicts found: {len(conflicts)}")
    print("Conflict indices:", conflicts)

def show_trace(db_file: str, idx: int, table_name: str = 'dataset') -> None:
    """
    Print one trace and its token count.

    Loads the 'trace' column from `table_name`, prints the entry at
    zero-based position `idx` between separator lines, then prints how many
    tokens the Qwen/Qwen3-0.6B tokenizer produces for it. A missing value
    (NULL / NaN) is shown as the empty string.

    Args:
        db_file: Path to the SQLite database file.
        idx: Zero-based row position in the loaded DataFrame.
        table_name: Table to read (interpolated into SQL; must be trusted).

    Raises:
        IndexError: If `idx` is negative or not smaller than the row count.
    """
    # load; release the handle even if the query fails
    conn = sqlite3.connect(db_file)
    try:
        df = pd.read_sql(f"SELECT trace FROM {table_name}", conn)
    finally:
        conn.close()

    # bounds check (done before the tokenizer is loaded so bad input fails fast)
    if idx < 0 or idx >= len(df):
        raise IndexError(f"Index {idx} out of range (0–{len(df)-1})")

    # get trace; `value or ""` alone would let a float NaN through because
    # NaN is truthy, so missing values are detected explicitly
    value = df.at[idx, 'trace']
    trace = "" if pd.isna(value) else (value or "")

    sep = "-" * 80

    # output with separators
    print(f"Trace[{idx}]:")
    print(sep)
    print(trace)
    print(sep)

    # tokenize & count
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen3-0.6B", trust_remote_code=True
    )
    tokens = tokenizer(trace).input_ids
    print(f"Token count: {len(tokens)}")


In [57]:
# Paths used by the merge workflow below: the two input databases handed to
# merge_db() and the file the merged table is written to.
DB_FILE_1      = 'dataset_4qwen3_250630a_1.db'
DB_FILE_2      = 'dataset_4qwen3_250630a_2.db'
DB_FILE_MERGED = 'dataset_4qwen3_250630a_merged.db'

In [58]:
# Empty / non-empty trace statistics for the first input database.
analyze_db(DB_FILE_1)

Total rows:        1817
Null rows:         0
Empty rows:        1733
Non-empty rows:    84

Non-empty row indices:
[0, 2, 5, 35, 41, 50, 54, 59, 60, 61, 67, 71, 72, 73, 77, 78, 81, 84, 91, 98, 100, 102, 104, 111, 113, 118, 120, 131, 137, 140, 145, 159, 165, 171, 181, 185, 188, 194, 197, 202, 204, 211, 212, 213, 233, 237, 241, 248, 258, 273, 279, 280, 303, 305, 321, 325, 331, 337, 363, 365, 369, 372, 379, 385, 390, 395, 402, 408, 411, 418, 428, 429, 430, 431, 432, 443, 448, 463, 464, 474, 484, 496, 498, 502]


In [59]:
# Empty / non-empty trace statistics for the second input database.
analyze_db(DB_FILE_2)

Total rows:        1817
Null rows:         0
Empty rows:        1744
Non-empty rows:    73

Non-empty row indices:
[506, 509, 514, 528, 531, 534, 536, 537, 553, 560, 562, 572, 578, 582, 585, 586, 588, 594, 598, 602, 607, 612, 615, 618, 621, 627, 639, 646, 650, 653, 663, 669, 673, 678, 679, 684, 686, 689, 696, 718, 722, 723, 734, 735, 748, 750, 752, 758, 770, 776, 785, 789, 790, 794, 806, 807, 811, 818, 821, 824, 825, 828, 830, 833, 841, 867, 871, 875, 877, 891, 926, 932, 951]


In [60]:
# Merge the traces of both inputs; the result is written to DB_FILE_MERGED
# (merge_db replaces an existing table of the same name in that file).
merge_db(
    DB_FILE_1,
    DB_FILE_2,
    out_db_file=DB_FILE_MERGED
)

Total rows:      1817
Conflicts found: 0
Conflict indices: []


In [61]:
# Sanity check: statistics of the merged database.
analyze_db(DB_FILE_MERGED)

Total rows:        1817
Null rows:         0
Empty rows:        1660
Non-empty rows:    157

Non-empty row indices:
[0, 2, 5, 35, 41, 50, 54, 59, 60, 61, 67, 71, 72, 73, 77, 78, 81, 84, 91, 98, 100, 102, 104, 111, 113, 118, 120, 131, 137, 140, 145, 159, 165, 171, 181, 185, 188, 194, 197, 202, 204, 211, 212, 213, 233, 237, 241, 248, 258, 273, 279, 280, 303, 305, 321, 325, 331, 337, 363, 365, 369, 372, 379, 385, 390, 395, 402, 408, 411, 418, 428, 429, 430, 431, 432, 443, 448, 463, 464, 474, 484, 496, 498, 502, 506, 509, 514, 528, 531, 534, 536, 537, 553, 560, 562, 572, 578, 582, 585, 586, 588, 594, 598, 602, 607, 612, 615, 618, 621, 627, 639, 646, 650, 653, 663, 669, 673, 678, 679, 684, 686, 689, 696, 718, 722, 723, 734, 735, 748, 750, 752, 758, 770, 776, 785, 789, 790, 794, 806, 807, 811, 818, 821, 824, 825, 828, 830, 833, 841, 867, 871, 875, 877, 891, 926, 932, 951]


In [62]:
# Spot-check a single merged trace (row 448) and its token count.
show_trace(DB_FILE_MERGED, 448)

Trace[448]:
--------------------------------------------------------------------------------

Okay, let's tackle this crossword puzzle clue: "Adjust section of Gatt unethically (6)". Hmm, first, I need to break down the clue. The answer is a 6-letter word. The clue mentions "adjust" which might be a verb meaning to modify or change something. Then there's "section of Gatt" – maybe "section" here refers to a part of a word, like a substring. "Unethically" is an adjective, but it might be part of the wordplay. Wait, the word "unethically" could be a part of the word "unethical", but that's a 9-letter word. Wait, perhaps "unethical" isn't the right path. Maybe it's part of a wordplay where "unethically" modifies another part. Let me think again. 

The structure might be something like a word meaning "adjust" combined with a section of the word "GATT", but maybe not directly. Wait, "GATT" is an acronym, maybe the letters are used as parts of the word. Let me think of synonyms for "adjust".

In [63]:
def find_fffd_trace_indices(db_file: str, table: str = 'dataset') -> list[int]:
    """
    Locate traces that contain the Unicode replacement character (U+FFFD).

    The 'trace' column of `table` is read in full; each non-None value is
    converted with `str()` and searched for the character.

    Args:
        db_file: Path to the SQLite database file.
        table: Name of the table (default: 'dataset').

    Returns:
        Zero-based positions (in read order) of the rows whose trace
        contains U+FFFD.
    """
    import sqlite3
    import pandas as pd

    connection = sqlite3.connect(db_file)
    frame = pd.read_sql(f"SELECT trace FROM {table}", connection)
    connection.close()

    hits = []
    for position, value in enumerate(frame['trace']):
        # None (SQL NULL) can never contain the marker
        if value is None:
            continue
        if '�' in str(value):
            hits.append(position)

    return hits

In [49]:
# Scan a merged database for traces containing U+FFFD and list their indices.
# NOTE(review): this cell targets the 250622a file while the surrounding cells
# use 250630a, and its execution count is out of sequence — it looks like a
# leftover from an earlier session; confirm it is still wanted.
DB_FILE = 'dataset_4qwen3_250622a_merged.db'

contaminated_trace_indices = find_fffd_trace_indices(DB_FILE)

print(len(contaminated_trace_indices))
print(contaminated_trace_indices)


112
[2, 18, 36, 50, 59, 61, 71, 72, 76, 77, 78, 81, 91, 100, 102, 113, 120, 121, 131, 140, 144, 145, 159, 181, 185, 188, 194, 195, 197, 204, 211, 212, 273, 280, 305, 321, 325, 337, 359, 361, 369, 372, 379, 385, 394, 402, 418, 423, 428, 429, 430, 467, 474, 480, 484, 496, 502, 506, 511, 514, 522, 527, 531, 534, 536, 560, 562, 572, 578, 582, 585, 589, 598, 602, 607, 615, 618, 622, 627, 631, 636, 669, 673, 684, 686, 696, 720, 722, 723, 731, 735, 736, 739, 750, 752, 776, 778, 785, 793, 806, 811, 821, 824, 825, 830, 841, 855, 867, 875, 891, 926, 932]


In [65]:
# Same U+FFFD scan for the freshly merged 250630a database.
# Note: this rebinds DB_FILE and contaminated_trace_indices for all later cells.
DB_FILE = 'dataset_4qwen3_250630a_merged.db'

contaminated_trace_indices = find_fffd_trace_indices(DB_FILE)

print(len(contaminated_trace_indices))
print(contaminated_trace_indices)


0
[]


In [None]:
import math

def print_index_groups(indices: list[int], n: int) -> None:
    """
    Print `indices` as up to `n` consecutive chunks, one chunk per line.

    The chunk length is ceil(len(indices) / n); the final chunk may be
    shorter, and chunks that would be empty are not printed. Each line is
    the chunk's values joined by commas without spaces.
    """
    total = len(indices)
    chunk = math.ceil(total / n)              # items per printed line

    for lo in (g * chunk for g in range(n)):
        part = indices[lo:lo + chunk]
        if not part:                          # past the end of the data
            continue
        print(",".join(str(value) for value in part))
            

In [31]:
# Print the contaminated indices as 4 comma-separated groups.
# NOTE(review): depends on `contaminated_trace_indices` from an earlier cell; the
# most recent assignment above produced an empty list, so the output shown here
# stems from an earlier run — re-run after Restart & Run All to confirm.
print_index_groups(contaminated_trace_indices, n=4)


2,18,36,50,59,61,71,72,76,77,78,81,91,100,102,113,120,121,131,140,144,145,159,181,185,188,194,195
197,204,211,212,273,280,305,321,325,337,359,361,369,372,379,385,394,402,418,423,428,429,430,467,474,480,484,496
502,506,511,514,522,527,531,534,536,560,562,572,578,582,585,589,598,602,607,615,618,622,627,631,636,669,673,684
686,696,720,722,723,731,735,736,750,752,776,778,785,793,806,811,821,824,825,830,841,855,867,875,891,926,932


In [None]:
import sqlite3

def clear_fffd_traces(db_file: str,
                      row_indices: list[int],
                      table: str = 'dataset') -> None:
    """
    Set trace = '' for every row whose zero-based index is in `row_indices`.

    Rows are addressed by SQLite rowid, assuming rowid == index + 1, i.e.
    the table's implicit ROWID order matches the DataFrame index that the
    indices were computed from. All updates run in one transaction: either
    every listed row is cleared or, on error, none is.

    Args:
        db_file: Path to the SQLite database file (modified in place).
        row_indices: Zero-based row indices to clear; an empty list is a no-op.
        table: Table name (interpolated into SQL; must be trusted).
    """
    if not row_indices:
        return

    # int() lets numpy integers (e.g. taken from a pandas index) bind cleanly.
    params = [(int(idx) + 1,) for idx in row_indices]

    conn = sqlite3.connect(db_file)
    try:
        # `with conn` commits on success and rolls back on an exception;
        # it does not close the connection, hence the outer try/finally.
        with conn:
            conn.executemany(
                f"UPDATE {table} SET trace = '' WHERE rowid = ?", params
            )
    finally:
        conn.close()


In [None]:
# Destructive: blanks `trace` in place for the listed rows of DB_FILE.
# NOTE(review): `contaminated_trace_indices` was computed from a *merged* database
# in the cells above; confirm its row order matches this file before running.
DB_FILE = "dataset_4qwen3.db"

clear_fffd_traces(DB_FILE, contaminated_trace_indices)


In [None]:
import sqlite3
from typing import Sequence

def clear_trace_and_teacher(db_path: str,
                            row_indices: Sequence[int],
                            table: str = "dataset") -> None:
    """
    Blank out both the `trace` and `teacher` columns for every row whose
    zero-based DataFrame index appears in `row_indices`.

    Assumption: SQLite rowid == pandas index + 1 (the table's implicit rowid
    order matches the DataFrame the indices were taken from).

    All updates run in a single transaction; on error nothing is changed.

    Parameters
    ----------
    db_path : str
        Path to the SQLite database file (modified in place).
    row_indices : Sequence[int]
        Zero-based pandas indices to wipe. An empty sequence is a no-op.
    table : str, optional
        Target table name (default ``'dataset'``). Interpolated into the SQL
        text, so it must come from a trusted source.
    """
    if not row_indices:            # nothing to do
        return

    # Convert pandas indices -> SQLite rowid (1-based); int() lets numpy
    # integers bind cleanly.
    params = [(int(idx) + 1,) for idx in row_indices]

    conn = sqlite3.connect(db_path)
    try:
        # `with conn` only manages the transaction (commit / rollback); the
        # connection itself has to be closed explicitly.
        with conn:
            conn.executemany(
                f"UPDATE {table} SET trace = '', teacher = '' WHERE rowid = ?",
                params
            )
    finally:
        conn.close()


In [None]:
# Destructive: blanks `trace` and `teacher` in place for six hand-picked rows.
# NOTE(review): the reason these indices were chosen is not recorded in the
# notebook; document the selection criterion before re-running.
DB_FILE = "dataset_4qwen3.db"

clear_trace_and_teacher(DB_FILE, [35,70,103,137,365,807])

In [66]:
import sqlite3
import pandas as pd
from transformers import AutoTokenizer

def add_teacher_column(db_file: str, table_name: str = 'dataset') -> None:
    """
    Ensure the table has a 'teacher' column and label rows that have a trace.

    If the column is missing it is created and filled with ''. Every row
    whose 'trace' is non-empty (not NULL, not '' and not whitespace-only)
    then gets teacher = 's1.1-7B'. Rows with an empty trace keep whatever
    teacher value they had ('' for a newly created column).

    The whole table is read into pandas and written back with
    `if_exists='replace'`, which drops and recreates the table.

    Args:
        db_file: Path to the SQLite database file (modified in place).
        table_name: Table to update (interpolated into SQL; must be trusted).
    """
    # Load the full table
    conn = sqlite3.connect(db_file)
    try:
        df = pd.read_sql(f"SELECT * FROM {table_name}", conn)
    finally:
        conn.close()

    # Create teacher column if it doesn't already exist
    if 'teacher' not in df.columns:
        df['teacher'] = ''

    # Non-empty means: not NULL and not blank. A plain `str.strip() != ''`
    # evaluates to True for NULL (NaN != ''), which would label rows that
    # have no trace at all.
    non_empty_mask = df['trace'].fillna('').astype(str).str.strip().ne('')

    # Set 'teacher' to 's1.1-7B' where trace is not empty
    df.loc[non_empty_mask, 'teacher'] = 's1.1-7B'

    # Write back to database
    conn = sqlite3.connect(db_file)
    try:
        df.to_sql(table_name, conn, index=False, if_exists='replace')
    finally:
        conn.close()

    # Report results
    total_rows = len(df)
    teacher_assigned = non_empty_mask.sum()

    print(f"Total rows: {total_rows}")
    print(f"Rows with teacher assigned: {teacher_assigned}")
    print(f"Rows with empty teacher: {total_rows - teacher_assigned}")

def analyze_db_with_teacher(db_file: str, table_name: str = 'dataset') -> None:
    """
    Print statistics for the 'trace' and 'teacher' columns and cross-check them.

    A value is *empty* when it is NULL, '' or whitespace-only. The report
    lists empty / non-empty counts for both columns, the number of rows per
    distinct teacher value, and how many rows have a trace without a teacher
    or a teacher without a trace.

    If the table has no 'teacher' column a message is printed and nothing
    else happens.

    Args:
        db_file: Path to the SQLite database file.
        table_name: Table to read (interpolated into SQL; must be trusted).
    """
    # Load table
    conn = sqlite3.connect(db_file)
    try:
        df = pd.read_sql(f"SELECT * FROM {table_name}", conn)
    finally:
        conn.close()

    # Check if teacher column exists
    if 'teacher' not in df.columns:
        print("No 'teacher' column found in the database.")
        return

    # NULLs are folded into "empty": without fillna a NULL compares unequal
    # to '' and would be counted as non-empty (and a NULL teacher would make
    # the sorted() call below fail on a str/None comparison).
    trace_non_empty = df['trace'].fillna('').astype(str).str.strip().ne('')
    teacher_non_empty = df['teacher'].fillna('').astype(str).str.strip().ne('')
    trace_empty_rows = (~trace_non_empty).sum()
    teacher_empty_rows = (~teacher_non_empty).sum()

    # Get unique teacher types (excluding empty)
    unique_teachers = df.loc[teacher_non_empty, 'teacher'].unique()
    teacher_counts = df['teacher'].value_counts(dropna=False)

    print("Trace column:")
    print(f"    Empty rows: {trace_empty_rows}")
    print(f"    Non-empty rows: {trace_non_empty.sum()}")

    print(f"\nTeacher column:")
    print(f"    Empty rows: {teacher_empty_rows}")
    print(f"    Non-empty rows: {teacher_non_empty.sum()}")
    print(f"    Unique teacher types: {len(unique_teachers)}")

    # Show counts for each teacher type
    for teacher_type in sorted(unique_teachers):
        count = teacher_counts.get(teacher_type, 0)
        print(f"    '{teacher_type}': {count}")

    # Consistency checks
    trace_teacher_mismatch = trace_non_empty & ~teacher_non_empty
    teacher_trace_mismatch = teacher_non_empty & ~trace_non_empty

    print(f"\nConsistency checks:")
    print(f"    Rows with non-empty trace but empty teacher: {trace_teacher_mismatch.sum()}")
    print(f"    Rows with non-empty teacher but empty trace: {teacher_trace_mismatch.sum()}")

def update_teacher_column(db_file: str, teacher_model: str, table_name: str = 'dataset') -> None:
    """
    Fill in the 'teacher' column for rows that have a trace but no teacher.

    A value is *empty* when it is NULL, '' or whitespace-only. Every row with
    a non-empty trace and an empty teacher gets teacher = `teacher_model`;
    all other rows are left untouched. When at least one row qualifies the
    table is written back with `if_exists='replace'` (drop and recreate).

    Args:
        db_file: Path to the SQLite database file
        teacher_model: The teacher model string to assign (e.g., 's1.1-14B')
        table_name: Name of the table in the database (interpolated into
            SQL; must be trusted)
    """
    # Load the full table
    conn = sqlite3.connect(db_file)
    try:
        df = pd.read_sql(f"SELECT * FROM {table_name}", conn)
    finally:
        conn.close()

    # Check if teacher column exists
    if 'teacher' not in df.columns:
        print("Error: 'teacher' column not found in the database.")
        print("Please run add_teacher_column() first.")
        return

    # Fold NULL into "empty" on both sides: a raw comparison treats a NULL
    # trace as non-empty and never matches a NULL teacher as empty.
    trace_non_empty = df['trace'].fillna('').astype(str).str.strip().ne('')
    teacher_empty = df['teacher'].fillna('').astype(str).str.strip().eq('')

    # Mask for rows to update: trace non-empty AND teacher empty
    update_mask = trace_non_empty & teacher_empty

    # Count rows before update for reporting
    rows_to_update = update_mask.sum()

    if rows_to_update == 0:
        print("No rows found where trace is non-empty and teacher is empty.")
        return

    # Update teacher column for matching rows
    df.loc[update_mask, 'teacher'] = teacher_model

    # Write back to database
    conn = sqlite3.connect(db_file)
    try:
        df.to_sql(table_name, conn, index=False, if_exists='replace')
    finally:
        conn.close()

    # Report results
    print(f"Updated {rows_to_update} rows with teacher model: '{teacher_model}'")

    # Show indices of updated rows
    updated_indices = df[update_mask].index.tolist()
    print(f"Updated row indices: {updated_indices}")


In [67]:
DB_FILE = 'dataset_4qwen3_250630a_merged.db'

# Add teacher column to your database
# (rewrites the whole table in place; every non-empty trace is labelled 's1.1-7B')
add_teacher_column(DB_FILE)

Total rows: 1817
Rows with teacher assigned: 157
Rows with empty teacher: 1660


In [None]:
# NOTE(review): this cell points at the 250624a file while the neighbouring cells
# use 250630a, and it has no execution count — confirm which database is meant.
DB_FILE = 'dataset_4qwen3_250624a_merged.db'

# Update teacher column for rows with non-empty trace but empty teacher
update_teacher_column(DB_FILE, 's1.1-7B')

In [68]:
DB_FILE = 'dataset_4qwen3_250630a_merged.db'

# Analyze the results
# (trace/teacher counts plus the two consistency checks should both report 0)
analyze_db_with_teacher(DB_FILE)


Trace column:
    Empty rows: 1660
    Non-empty rows: 157

Teacher column:
    Empty rows: 1660
    Non-empty rows: 157
    Unique teacher types: 1
    's1.1-7B': 157

Consistency checks:
    Rows with non-empty trace but empty teacher: 0
    Rows with non-empty teacher but empty trace: 0


In [None]:
# Environment check: library versions and GPU visibility for the vLLM cells below.
import torch
print(torch.__version__)

import vllm
print(vllm.__version__)

# Versions recorded from an earlier run of this cell:
# 2.6.0+cu124
# 0.8.5.post1

# NOTE(review): presumably this initializes CUDA in the kernel process, after which
# changing CUDA_VISIBLE_DEVICES may no longer affect torch — confirm.
print(torch.cuda.is_available())
print(torch.cuda.device_count())

In [None]:
import pandas as pd
import sqlite3

DB_FILE = 'dataset_4qwen3.db'

# Read the SQLite database back into a DataFrame. `df` is a notebook-level
# global that the analysis cells further down index by row position.
# The handle is released even if the query fails (e.g. missing table).
conn = sqlite3.connect(DB_FILE)
try:
    df = pd.read_sql('SELECT * FROM dataset', conn)
finally:
    conn.close()

print(f"Loaded {len(df)} rows")
print(f"Columns: {df.columns.tolist()}")

# One blank mask drives every derived count, so the numbers cannot disagree
# and the `.str` accessor is never applied to a non-text column.
is_null = df['trace'].isna()
is_blank = df['trace'].fillna('').astype(str).str.strip().eq('')   # NULL, '' or whitespace

null_count = is_null.sum()
empty_string_count = (df['trace'] == '').sum()
whitespace_only_count = (is_blank & ~is_null).sum()   # includes '' (as before), excludes NULL
total_empty = is_blank.sum()

print(f"Null/NaN values: {null_count}")
print(f"Empty strings: {empty_string_count}")
print(f"Whitespace-only strings: {whitespace_only_count}")
print(f"Total empty rows: {total_empty}")

df.head()

In [None]:
# Token Probability Analyzer - Check reasoning trace token probabilities with small model
#
# Analyzes an existing reasoning trace to check the probability of each token
# according to a small model. Marks tokens with probability below threshold with asterisk.

# --------------------------- imports ---------------------------------------
import os, html, uuid, asyncio, contextlib, nest_asyncio, logging
from IPython.display import HTML, display

import torch
from huggingface_hub import snapshot_download
from vllm import TokensPrompt
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import SamplingParams, RequestOutputKind

# Notebook-wide runtime setup: allow nested event loops (the notebook already
# runs one), disable autograd since only inference is performed, and silence
# INFO-level (and lower) log records.
nest_asyncio.apply()
torch.set_grad_enabled(False)
logging.disable(logging.INFO)

# --------------------------- configuration ---------------------------------
SMALL_GPU_INDEX = "0"                 # value written to CUDA_VISIBLE_DEVICES while the engine starts
SMALL_MODEL_NAME = "Qwen/Qwen3-0.6B"  # Hugging Face repo id passed to snapshot_download
SMALL_TEMPERATURE = 0.7               # sampling temperature used in small_sampling_params
MAX_SEQ_LEN = 8192                    # engine max_model_len; NOTE(review): prompt + trace presumably must fit — confirm
PROB_THRESHOLD = 0.01                 # tokens with probability below this are flagged as low-probability

# ---------------- utility: temporarily set visible GPUs --------------------
@contextlib.contextmanager
def visible_gpus(devices: str):
    """
    Temporarily set CUDA_VISIBLE_DEVICES to `devices`.

    On exit (normal or via exception) the previous state is restored
    exactly: the old value is written back, or the variable is removed
    again if it was not set before. Restoring an unset variable to the
    empty string would hide every GPU from code that reads it afterwards.

    Args:
        devices: Comma-separated device list, e.g. "0" or "0,1".
    """
    original = os.environ.get("CUDA_VISIBLE_DEVICES")   # None when unset
    os.environ["CUDA_VISIBLE_DEVICES"] = devices
    print(f"\nCUDA_VISIBLE_DEVICES = {devices}")
    try:
        yield
    finally:
        if original is None:
            os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = original

# --------------------------- engine setup ----------------------------------
async def setup_small_engine():
    """
    Download the small model and start a vLLM async engine for it.

    Side effects: defines the module-level globals `small_engine`,
    `small_tokenizer` and `small_vocab_size`, which the analysis functions
    in this notebook read. Prints the GPU count torch reports and the
    model's vocabulary size.
    """
    global small_engine, small_tokenizer, small_vocab_size
    
    # Resolve the Hugging Face repo id to a local checkpoint directory.
    small_checkpoint = snapshot_download(SMALL_MODEL_NAME)

    # The engine is created while CUDA_VISIBLE_DEVICES is restricted to
    # SMALL_GPU_INDEX; the variable is restored when the block exits.
    # NOTE(review): if CUDA was already initialized in this process, torch may
    # ignore the changed variable and the printed count may be stale — confirm.
    with visible_gpus(SMALL_GPU_INDEX):
        print("torch sees", torch.cuda.device_count(), "GPU(s)")              
        small_engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(model=small_checkpoint, 
                            tensor_parallel_size=1,
                            max_model_len=MAX_SEQ_LEN, 
                            gpu_memory_utilization=0.20,
                            dtype="bfloat16"),
            start_engine_loop=True)
        
        small_tokenizer = await small_engine.get_tokenizer()

    # Get model config using async method
    small_model_config = await small_engine.get_model_config()
    small_vocab_size = small_model_config.get_vocab_size()
    
    print(f"Small vocab size: {small_vocab_size}")

# --------------------------- sampling params -------------------------------
# Sampling parameters shared by every scoring request (see one_step_analyze):
# each request generates a single token and returns the top-20 log-probabilities
# for that position. analyze_trace looks the actual trace token up in that
# top-20 set and records 0.0 when it is absent.
# NOTE(review): temperature / top_p presumably only influence which token is
# sampled, not the reported logprobs — confirm against the vLLM version in use.
small_sampling_params = SamplingParams(
    max_tokens=1,
    temperature=SMALL_TEMPERATURE,
    top_p=0.95, 
    logprobs=20,  # vLLM's max allowed logprob size (NOTE(review): presumably the engine default limit — confirm)
    output_kind=RequestOutputKind.DELTA,
)

# -------------------------- helper functions -------------------------------
def html_heatmap(token_ids, probabilities, tokenizer):
    """
    Render tokens as an HTML block that highlights low-probability tokens.

    Consecutive tokens are first merged into groups so that each group
    decodes without the Unicode replacement character (U+FFFD), i.e. tokens
    that only form a valid character together are shown as one unit. A
    group is drawn in red and underlined when any of its tokens has a
    probability below PROB_THRESHOLD, otherwise in black.

    Args:
        token_ids: Sequence of token ids to display.
        probabilities: Per-token probabilities, aligned with `token_ids` by
            position. Positions beyond its length are ignored.
        tokenizer: Object with a `decode(list_of_ids)` method returning text.

    Returns:
        An IPython `HTML` object wrapping a single `<pre>` element.
    """
    
    def colour(probability):
        if probability < PROB_THRESHOLD:
            return "rgb(255,0,0)"  # Red for below threshold
        else:
            return "rgb(0,0,0)"    # Black for at-or-above threshold
    
    spans = []
    
    # Find token groups that form complete characters
    token_groups = []
    i = 0
    
    while i < len(token_ids):
        # Start with current token
        group_start = i
        group_end = i + 1
        
        # Expand the group until we have a valid UTF-8 sequence
        # (group_end is an exclusive bound, hence `<=`)
        while group_end <= len(token_ids):
            # Try decoding the current group
            group_text = tokenizer.decode(token_ids[group_start:group_end])
            
            if '\ufffd' not in group_text:
                # Valid decode, but check if we should include more tokens
                if group_end < len(token_ids):
                    # Check if adding the next token changes the decode:
                    # compare decoding them jointly vs. concatenating the
                    # separately decoded pieces
                    extended_text = tokenizer.decode(token_ids[group_start:group_end+1])
                    current_plus_next = group_text + tokenizer.decode([token_ids[group_end]])
                    
                    if extended_text != current_plus_next or '\ufffd' in current_plus_next:
                        # Next token is part of this character, continue
                        group_end += 1
                        continue
                
                # We have a complete group
                break
            else:
                # Invalid decode, need more tokens
                group_end += 1
                if group_end > len(token_ids):
                    # Reached end with incomplete sequence
                    group_end = len(token_ids)
                    break
        
        # Store the group (half-open range) and resume after it
        token_groups.append((group_start, group_end))
        i = group_end
    
    # Now render each group
    for group_start, group_end in token_groups:
        # Decode the group
        text = tokenizer.decode(token_ids[group_start:group_end])
        
        # Groups that decode to nothing produce no span
        if not text:
            continue
        
        # Escape for HTML and keep runs of spaces visible
        escaped = html.escape(text).replace(" ", "&nbsp;")
        
        # Check if any token in this group has low probability
        any_low_prob = False
        min_prob = 1.0
        
        for token_idx in range(group_start, group_end):
            # tokens without a probability entry do not affect the styling
            if token_idx < len(probabilities):
                prob = probabilities[token_idx]
                min_prob = min(min_prob, prob)
                any_low_prob = any_low_prob or (prob < PROB_THRESHOLD)
        
        # Colour follows the group's lowest probability; underline marks the
        # same condition for readers who cannot rely on colour
        style = f"color:{colour(min_prob)};"
        if any_low_prob:
            style += " text-decoration:underline;"
        spans.append(f"<span style='{style}'>{escaped}</span>")
    
    return HTML("<pre style='white-space:pre-wrap; line-height:1.45; "
                "font-family:inherit; background:#fff; padding:8px; "
                "border:1px solid #ddd;'>" + "".join(spans) + "</pre>")

# ------------------------- core analysis loop ------------------------------
async def one_step_analyze(context_ids):
    """
    Ask the small engine for one next-token step after `context_ids`.

    Submits the token ids as a prompt under a fresh request id, waits for
    the first item the engine yields and returns its first completion
    output.
    """
    request_id = str(uuid.uuid4())
    prompt = TokensPrompt(prompt_token_ids=context_ids)
    stream = small_engine.generate(prompt, small_sampling_params, request_id=request_id)
    first_chunk = await anext(stream)
    return first_chunk.outputs[0]

async def analyze_trace(prompt_part: str, trace_part: str):
    """
    Score every token of `trace_part` under the small model.

    The prompt and the trace are tokenized separately. For each trace token
    the model is queried with "prompt + preceding trace tokens" as context
    (one engine request per token) and the probability it assigns to the
    actual token is recorded. Tokens that are not among the returned
    top-k log-probabilities are recorded with probability 0.0.

    Prints a per-token table, displays an HTML heatmap of the trace, lists
    the tokens below PROB_THRESHOLD and prints a summary line. An empty
    trace is handled without error (0 tokens, 0.00%).

    Args:
        prompt_part: Text that precedes the trace (not scored itself).
        trace_part: Text whose tokens are scored.

    Returns:
        A list with one dict per trace token, with keys 'step' (1-based),
        'token_id', 'token_text', 'probability' and 'low_prob'.
    """
    # Tokenize the prompt and trace separately
    prompt_token_ids = small_tokenizer.encode(prompt_part)
    trace_token_ids = small_tokenizer.encode(trace_part)

    print(f"Prompt tokens: {len(prompt_token_ids)}")
    print(f"Trace tokens: {len(trace_token_ids)}")
    print(f"Analyzing {len(trace_token_ids)} trace tokens...")
    print("-" * 80)
    print("Step\tProb\tTok_ID\tTok_Txt")
    print("-" * 80)

    probabilities = []
    records = []

    # For each token in the trace, check its probability
    for step_index, actual_token_id in enumerate(trace_token_ids):
        # Context is: prompt + trace tokens up to this position
        context_ids = prompt_token_ids + trace_token_ids[:step_index]

        actual_token_text = small_tokenizer.decode([actual_token_id])

        # Get model's probability distribution for next token
        output = await one_step_analyze(context_ids)

        # Mapping token id -> logprob entry for the single generated position
        logprobs_dict = output.logprobs[0]

        # Get probability of the actual token that was used
        if actual_token_id in logprobs_dict:
            actual_prob = torch.exp(torch.tensor(logprobs_dict[actual_token_id].logprob)).item()
        else:
            actual_prob = 0.0  # Token not in top predictions

        probabilities.append(actual_prob)

        # Check if probability is below threshold
        low_prob = actual_prob < PROB_THRESHOLD

        records.append({
            'step': step_index + 1,
            'token_id': actual_token_id,
            'token_text': actual_token_text,
            'probability': actual_prob,
            'low_prob': low_prob
        })

        print(f"{step_index + 1:4d}{'*' if low_prob else ' '}\t"
              f"{actual_prob:.4f}\t"
              f"{actual_token_id}\t'{actual_token_text}'",
              flush=True)

    print("-" * 80)

    # Display the heatmap for trace tokens only
    display(html_heatmap(trace_token_ids, probabilities, small_tokenizer))

    # Print only tokens below probability threshold
    low_prob_records = [record for record in records if record['low_prob']]
    if low_prob_records:
        print(f"\nTokens below {PROB_THRESHOLD} probability threshold:")
        print("-" * 80)
        print("Step\tProb\tTok_ID\tTok_Txt")
        print("-" * 80)
        for record in low_prob_records:
            print(f"{record['step']:4d}*\t"
                  f"{record['probability']:.4f}\t"
                  f"{record['token_id']}\t'{record['token_text']}'")
        print("-" * 80)
    else:
        print(f"\nNo tokens below {PROB_THRESHOLD} probability threshold found.")

    # Statistics; guard the ratio so a trace with zero tokens does not raise
    # ZeroDivisionError after all the work above has completed
    total = len(records)
    low_prob_count = len(low_prob_records)
    low_prob_pct = (low_prob_count / total * 100) if total else 0.0
    print(f"\nLow probability tokens (< {PROB_THRESHOLD}): {low_prob_count}/{total} "
          f"({low_prob_pct:.2f}%)")

    return records

# ---------------------- high-level convenience -----------------------------
async def run_analysis(prompt_part: str, trace_part: str):
    """Convenience entry point: run analyze_trace and hand back its records."""
    records = await analyze_trace(prompt_part, trace_part)
    return records

# ------------------------ fire up the engine ------------------------------
# Start the engine once per kernel (uses the notebook's top-level await); this
# defines the small_engine / small_tokenizer / small_vocab_size globals.
await setup_small_engine()


In [None]:
# --------------------------- example usage ---------------------------------
# Requires: `df` from the data-loading cell and a running engine (setup_small_engine).
# NOTE(review): the row at `idx` is presumably expected to have a non-empty trace — confirm.
idx = 0  # Set the dataframe index you want to analyze

# Get question and trace from dataframe
question = df['question'].iloc[idx]
trace = df['trace'].iloc[idx]

# Create the prompt part (ends with "<think>")
# Doubled braces / backslashes are f-string and string escapes: the rendered
# prompt contains \boxed{} literally.
prompt_part = f"""A conversation between User and Assistant. The User asks a question, and the Assistant responds in two clearly defined sections: 1. Reasoning Process - A step-by-step, logical exploration and analysis of the problem, enclosed within <think> and </think> tags. 2. Answer - A direct and concise response based on the reasoning process, with the final answer enclosed within \\boxed{{}}. For example, 
<think>
reasoning process here
</think>
answer here
\\boxed{{final answer here}}

Now, continue the actual conversation below.
User: {question}
Assistant:
<think>"""

# The trace part is just the trace
trace_part = trace

# Run the analysis (one engine request per trace token)
records = await run_analysis(prompt_part, trace_part)

In [None]:
import pandas as pd
import sqlite3

def non_empty_trace_indices(db_path: str, table: str = "dataset", col: str = "trace") -> list[int]:
    """
    Return the zero-based indices of rows whose `col` value is neither NULL
    nor an empty / whitespace-only string.

    Each index is the row's SQLite ``rowid`` minus one, so it lines up with
    the default index of a DataFrame loaded from the same table as long as
    the rowids are the contiguous sequence 1..N.

    Parameters
    ----------
    db_path : str
        Path to the SQLite database file.
    table   : str, optional
        Table name holding the data (default ``"dataset"``).
    col     : str, optional
        Column holding the trace text (default ``"trace"``).
        `table` and `col` are interpolated into the SQL text and must come
        from a trusted source.

    Returns
    -------
    list[int]
        ``rowid - 1`` for every row that satisfies the non-empty condition.
    """
    # `with sqlite3.connect(...)` would only manage the transaction and leave
    # the connection open, so it is closed explicitly here.
    conn = sqlite3.connect(db_path)
    try:
        # Pull only rowid + trace to keep it light
        df = pd.read_sql(f"SELECT rowid AS idx, {col} FROM {table}", conn)
    finally:
        conn.close()

    mask = df[col].notna() & df[col].astype(str).str.strip().ne("")
    idx = df.loc[mask, "idx"]

    return (idx - 1).tolist()


In [None]:
# List the zero-based indices of all rows that currently have a trace; the
# result (`indices`) is the natural input for the batch analysis below.
DB_FILE = "dataset_4qwen3.db"

indices = non_empty_trace_indices(DB_FILE)
print(len(indices))
print(indices)

In [None]:
# --------------------------- batch analysis wrapper -----------------------
async def analyze_batch_low_prob_only(indices: list):
    """Analyze multiple traces and only print tokens below probability threshold.

    For each position in ``indices`` the question and trace are read from the
    notebook-level ``df`` (via ``iloc``), the prompt and the trace are
    tokenized separately, and every trace token is scored with the
    probability returned for it given the prompt plus all preceding trace
    tokens. Tokens whose probability is below ``PROB_THRESHOLD`` are printed
    together with a per-trace summary line.

    Relies on names defined in other cells: ``df``, ``small_tokenizer``,
    ``one_step_analyze`` and ``PROB_THRESHOLD``. ``one_step_analyze`` is
    awaited once per trace token with the context token ids and must return
    an object whose ``logprobs[0]`` maps token id -> object with a
    ``.logprob`` attribute. A token absent from that mapping is scored 0.0.

    Args:
        indices: positional row indices into ``df`` to analyze.

    Returns:
        None. Results are printed.
    """
    import math  # local: only needed to turn a scalar logprob into a probability

    for idx in indices:
        print(f"\n{'-'*80}")
        print(f"Entry {idx}")

        # Get question and trace from dataframe
        question = df['question'].iloc[idx]
        trace = df['trace'].iloc[idx]

        # A NULL/NaN trace cannot be tokenized; report it and move on instead
        # of aborting the whole batch.
        if not isinstance(trace, str):
            print("\nSkipping: trace is missing (not a string).")
            continue

        # Create the prompt part (ends with "<think>"). The first template
        # line ends with a trailing space that is part of the prompt text.
        prompt_part = f"""A conversation between User and Assistant. The User asks a question, and the Assistant responds in two clearly defined sections: 1. Reasoning Process - A step-by-step, logical exploration and analysis of the problem, enclosed within <think> and </think> tags. 2. Answer - A direct and concise response based on the reasoning process, with the final answer enclosed within \\boxed{{}}. For example, 
<think>
reasoning process here
</think>
answer here
\\boxed{{final answer here}}

Now, continue the actual conversation below.
User: {question}
Assistant:
<think>"""

        # Tokenize the prompt and trace separately
        prompt_token_ids = small_tokenizer.encode(prompt_part)
        trace_token_ids = small_tokenizer.encode(trace)

        # (step, token_id, token_text, probability) for below-threshold tokens
        low_prob_records = []

        # For each token in the trace, check its probability
        for step_index, actual_token_id in enumerate(trace_token_ids):
            # Context is: prompt + trace tokens up to (not including) this one
            context_ids = prompt_token_ids + trace_token_ids[:step_index]

            # Get model's probability distribution for next token
            output = await one_step_analyze(context_ids)
            logprobs_dict = output.logprobs[0]

            # Probability of the token that was actually used; plain float
            # math avoids allocating a tensor per token for a scalar exp.
            if actual_token_id in logprobs_dict:
                actual_prob = math.exp(logprobs_dict[actual_token_id].logprob)
            else:
                actual_prob = 0.0  # Token not in top predictions

            if actual_prob < PROB_THRESHOLD:
                low_prob_records.append((
                    step_index + 1,  # 1-based step for display
                    actual_token_id,
                    small_tokenizer.decode([actual_token_id]),
                    actual_prob,
                ))

        # Print only tokens below probability threshold
        if low_prob_records:
            print("\nStep\tProb\tTok_ID\tTok_Txt")
            for step, token_id, token_text, probability in low_prob_records:
                print(f"{step:4d}*\t"
                      f"{probability:.4f}\t"
                      f"{token_id}\t'{token_text}'")
            # Statistics (total_tokens > 0 here because at least one record exists)
            total_tokens = len(trace_token_ids)
            low_prob_count = len(low_prob_records)
            print(f"\nLow probability tokens (< {PROB_THRESHOLD}): {low_prob_count}/{total_tokens} "
                f"({low_prob_count/total_tokens*100:.2f}%)")
        else:
            print(f"\nNo tokens below {PROB_THRESHOLD} probability threshold found.")
        
# --------------------------- batch usage example --------------------------
await analyze_batch_low_prob_only(indices)

In [70]:
import sqlite3
import pandas as pd

def convert_db_to_df(db_file: str, table_name: str = 'dataset') -> pd.DataFrame:
    """
    Converts SQLite database to DataFrame with specific text formatting.

    Rows whose 'trace' is NULL, empty, or whitespace-only are dropped. For the
    remaining rows a 'text' column is built as a ChatML-style prompt
    (system + user + start of the assistant turn ending in ``<think>``)
    followed directly by the trace.

    Args:
        db_file: Path to the SQLite database file
        table_name: Name of the table to read from (default: 'dataset').
            It is interpolated into the SQL text, so it must be trusted.

    Returns:
        DataFrame with 'question', 'answer', 'trace', and 'text' columns,
        filtered to only include rows with non-empty traces and re-indexed
        from 0. If no row qualifies, an empty frame with those four columns
        is returned.

    Raises:
        KeyError: if the table lacks a 'question', 'answer' or 'trace' column.
    """
    # Load data from database; always release the connection, even when the
    # query itself fails.
    conn = sqlite3.connect(db_file)
    try:
        df = pd.read_sql(f"SELECT * FROM {table_name}", conn)
    finally:
        conn.close()

    # Keep only the required columns (question, answer, trace)
    required_cols = ['question', 'answer', 'trace']
    df = df[required_cols]

    # Filter to keep only rows with non-empty trace
    # Handle both null and empty/whitespace-only strings
    non_empty_mask = df['trace'].notna() & (df['trace'].str.strip() != '')
    df = df[non_empty_mask].reset_index(drop=True)

    # Define the system prompt
    system_prompt = (
        "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\n"
        "You must respond to every query in the following manner:\n"
        "First, provide a step-by-step logical exploration of the problem.\n"
        "Then, provide a clear and direct response based on your reasoning, with the final answer enclosed in \\boxed{}."
    )
    # The system turn is identical for every row, so build it once.
    system_turn = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"

    def create_text(question, trace):
        """Return the full training text for one (question, trace) pair."""
        return (
            system_turn
            + f"<|im_start|>user\n{question}<|im_end|>\n"
            + "<|im_start|>assistant\n<think>"
            + trace
        )

    # Build the column from a plain list rather than DataFrame.apply(axis=1):
    # on an empty frame apply() returns a DataFrame, which cannot be assigned
    # to a single column.
    df['text'] = [
        create_text(question, trace)
        for question, trace in zip(df['question'], df['trace'])
    ]

    return df


In [73]:
DB_FILE = "dataset_4qwen3_250630a_merged.db"
df = convert_db_to_df(DB_FILE)

# Quick sanity summary: shape first, then the column names.
print(
    f"Resulting DataFrame shape: {df.shape}",
    f"Columns: {df.columns.tolist()}",
    sep="\n",
)

# Last expression of the cell -> rendered as a table preview.
df.head()

Resulting DataFrame shape: (157, 4)
Columns: ['question', 'answer', 'trace', 'text']


Unnamed: 0,question,answer,trace,text
0,"Given a rational number, write it as a fractio...",128,"\nOkay, let's tackle this problem step by step...","<|im_start|>system\nYou are Qwen, created by A..."
1,Find the remainder when $9 \times 99 \times 99...,109,"\nOkay, let me try to figure out how to solve ...","<|im_start|>system\nYou are Qwen, created by A..."
2,One base of a trapezoid is $100$ units longer ...,181,"\nOkay, so I need to solve this problem about ...","<|im_start|>system\nYou are Qwen, created by A..."
3,Let $P(x)$ be a polynomial with integer coeffi...,418,"\nOkay, let me try to figure out this problem ...","<|im_start|>system\nYou are Qwen, created by A..."
4,"How many positive integers less than 10,000 ha...",927,"\nOkay, let's see. The question is: How many p...","<|im_start|>system\nYou are Qwen, created by A..."


In [75]:
# Show the fully formatted training text of the first row.
print(df["text"].iloc[0])

<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.
You must respond to every query in the following manner:
First, provide a step-by-step logical exploration of the problem.
Then, provide a clear and direct response based on your reasoning, with the final answer enclosed in \boxed{}.<|im_end|>
<|im_start|>user
Given a rational number, write it as a fraction in lowest terms and calculate the product of the resulting numerator and denominator. For how many rational numbers between 0 and 1 will $20_{}^{}!$ be the resulting product?<|im_end|>
<|im_start|>assistant
<think>
Okay, let's tackle this problem step by step. The question is: Given a rational number between 0 and 1, write it as a fraction in lowest terms, and then calculate the product of the numerator and the denominator. We need to find how many such rational numbers will result in the product being equal to 20! (20 factorial). 

First, let me make sure I understand the problem. A rational num

In [None]:
def create_and_upload_dataset(df, dataset_remote_path):
    """
    Create a HuggingFace dataset from DataFrame and upload it.
    Replicates the structure of simplescaling/s1K-1.1_tokenized.

    The DataFrame is validated before any import of the Hub libraries or any
    authentication, so a malformed frame fails fast without side effects.
    The Hub token is read from the ``HF_WRITE_TOKEN`` environment variable
    (after loading a ``.env`` file via ``load_dotenv``).

    Args:
        df: DataFrame with 'question', 'answer', 'text' columns
        dataset_remote_path: HuggingFace dataset path to upload to (e.g., "jaeh8nkim/s1K_for_Qwen3-0.6B")

    Returns:
        DatasetDict: The created dataset, or None if upload failed.
        A failure of the post-upload verification step is reported but does
        not turn a successful upload into a None result.

    Raises:
        ValueError: if ``df`` lacks any of the required columns.
    """
    # Validate DataFrame structure first -- no login / network for bad input.
    required_columns = ['question', 'answer', 'text']
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        raise ValueError(f"DataFrame missing required columns: {missing_columns}")

    # Imported lazily so the notebook can define this function without the
    # Hub dependencies being installed.
    import os
    from datasets import Dataset, DatasetDict, load_dataset
    from huggingface_hub import create_repo, login
    from dotenv import load_dotenv

    # Load environment and login
    load_dotenv()
    login(token=os.getenv("HF_WRITE_TOKEN"))

    print(f"🔄 Creating dataset from DataFrame with {len(df)} samples...")

    # Create dataset from DataFrame, keeping all three columns like the original s1K structure
    train_data = {col: df[col].tolist() for col in required_columns}
    train_dataset = Dataset.from_dict(train_data)

    # Create DatasetDict with train split only (matching s1K-1.1_tokenized structure)
    dataset_dict = DatasetDict({
        'train': train_dataset
    })

    # Print dataset info before upload
    print("📊 Dataset structure:")
    print(f"   Splits: {list(dataset_dict.keys())}")
    print(f"   Train samples: {len(train_dataset)}")
    print(f"   Columns: {train_dataset.column_names}")
    if len(train_dataset) > 0:
        example_text = train_dataset[0]['text']
        example_question = train_dataset[0]['question']
        print(f"   Example question length: {len(example_question)} characters")
        print(f"   Example text length: {len(example_text)} characters")
        print(f"   First 200 chars of question: {example_question[:200]}...")
        print(f"   First 200 chars of text: {example_text[:200]}...")

    # Create repository on HuggingFace (best effort: push_to_hub below will
    # surface any real problem).
    try:
        create_repo(repo_id=dataset_remote_path, repo_type="dataset", private=False, exist_ok=True)
        print(f"✅ Repository ready: https://huggingface.co/datasets/{dataset_remote_path}")
    except Exception as e:
        print(f"⚠️ Repository setup issue: {e}")
        # Continue anyway, might already exist

    # Upload dataset (this will create a .parquet file automatically)
    try:
        print(f"📤 Uploading dataset to {dataset_remote_path}...")
        dataset_dict.push_to_hub(
            dataset_remote_path,
            commit_message="Upload tokenized dataset for Qwen3-0.6B training"
        )
    except Exception as e:
        print(f"❌ Upload failed: {e}")
        return None

    print("✅ Dataset uploaded successfully as .parquet file!")
    print(f"🔗 View at: https://huggingface.co/datasets/{dataset_remote_path}")

    # Verify upload by trying to load it back. Kept separate from the upload
    # itself so a read-back hiccup is not misreported as "Upload failed".
    try:
        print("🔍 Verifying upload...")
        verification_dataset = load_dataset(dataset_remote_path)
        print(f"✅ Verification successful: {len(verification_dataset['train'])} samples loaded")
        print(f"✅ Verified columns: {verification_dataset['train'].column_names}")
    except Exception as e:
        print(f"⚠️ Upload succeeded but verification failed: {e}")

    return dataset_dict

In [None]:
DATASET_REMOTE_PATH = "jaeh8nkim/s1K_for_Qwen3-0.6B"

# Push the prepared frame to the Hub; `dataset` is None if the upload failed.
dataset = create_and_upload_dataset(df=df, dataset_remote_path=DATASET_REMOTE_PATH)