# Small Prompt Dataset Creation

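Filter the consolidated synthetic dataset into two subsets based on the token length of the original prompt: short (≤128 tokens) and very short (≤64 tokens). Light post-processing is applied to each compressed prompt, and each subset is then divided into a reproducible (fixed-seed) 90/10 train/test split.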

In [1]:
from __future__ import annotations

import csv
import random
import re
from collections import Counter
from pathlib import Path

from transformers import AutoTokenizer

# Parent of the current working directory, treated as the project root.
PROJECT_ROOT = Path('..').resolve()

# Every dataset file read or written by this notebook lives in one directory.
_TRAINING_DATA_DIR = PROJECT_ROOT / 'src' / 'training_data'

# Input: the consolidated dataset produced by an earlier notebook.
SOURCE_PATH = _TRAINING_DATA_DIR / 'dolly-prompt-compression.csv'

# Outputs: the length-filtered datasets and their train/test splits.
SHORT_OUTPUT = _TRAINING_DATA_DIR / 'dolly-short-prompt-compression.csv'
VERY_SHORT_OUTPUT = _TRAINING_DATA_DIR / 'dolly-very-short-prompt-compression.csv'
SHORT_TRAIN = _TRAINING_DATA_DIR / 'dsp-train.csv'
SHORT_TEST = _TRAINING_DATA_DIR / 'dsp-test.csv'
VERY_SHORT_TRAIN = _TRAINING_DATA_DIR / 'dvsp-train.csv'
VERY_SHORT_TEST = _TRAINING_DATA_DIR / 'dvsp-test.csv'

# Tokenizer sources: a local directory, and the repo id used when it is absent.
LOCAL_TOKENIZER_DIR = PROJECT_ROOT / 'small-prompt-compression-model'
HF_TOKENIZER_REPO = 'dotslashderek/short-prompt-compressor'

# Prefer the local tokenizer directory when it exists; otherwise fall back to
# the HF_TOKENIZER_REPO identifier. Either way a fast tokenizer is requested.
_tokenizer_source = str(LOCAL_TOKENIZER_DIR) if LOCAL_TOKENIZER_DIR.exists() else HF_TOKENIZER_REPO
TOKENIZER = AutoTokenizer.from_pretrained(_tokenizer_source, use_fast=True)


def count_tokens(text: str) -> int:
    """Return how many tokens ``TOKENIZER`` produces for *text*.

    Special tokens are excluded from the count. Falsy input (empty string or
    ``None``) counts as zero tokens without calling the tokenizer.
    """
    if text:
        token_ids = TOKENIZER.encode(text, add_special_tokens=False)
        return len(token_ids)
    return 0


def postprocess(text: str) -> str:
    """Normalise a compressed prompt.

    Whitespace runs are collapsed to single spaces. When the text has more
    than three words, the articles "the", "an" and "a" are dropped
    (case-insensitively, exact word match only). Finally any trailing run of
    ``.``, ``?``, ``!`` or ``…`` is removed.

    Args:
        text: The raw compressed prompt.

    Returns:
        The cleaned prompt; an empty string for whitespace-only input.
    """
    trimmed = text.strip()
    tokens = trimmed.split()
    if not tokens:
        return trimmed
    # Very short prompts keep their articles so they are not hollowed out.
    if len(tokens) > 3:
        tokens = [tok for tok in tokens if tok.lower() not in ('the', 'an', 'a')]
    joined = ' '.join(tokens)
    without_terminal_punct = re.sub(r'[.?!…]+$', '', joined)
    return without_terminal_punct.strip()


def load_compressed_rows(path: Path | None = None):
    """Load the source rows that carry a non-blank compressed prompt.

    Args:
        path: CSV file to read. When omitted, ``SOURCE_PATH`` is used.

    Returns:
        A list of row dicts, as produced by ``csv.DictReader``, keeping only
        rows whose ``compressed_prompt`` value is non-blank.

    Raises:
        FileNotFoundError: If the CSV file does not exist.
        ValueError: If the header lacks the ``original`` or
            ``compressed_prompt`` column. Without this check a misnamed
            column silently produces zero rows and empty output datasets.
    """
    source = SOURCE_PATH if path is None else path
    if not source.exists():
        raise FileNotFoundError(
            f'Source dataset not found at {source}. '
            'Run the initial synthetic data generation notebook first.'
        )
    # 'utf-8-sig' reads plain UTF-8 unchanged but drops a leading BOM, which
    # would otherwise be glued onto the first header name.
    with source.open('r', encoding='utf-8-sig', newline='') as fh:
        reader = csv.DictReader(fh)
        header = reader.fieldnames or []
        missing = [name for name in ('original', 'compressed_prompt') if name not in header]
        if missing:
            raise ValueError(
                f'{source} is missing required column(s): {", ".join(missing)}; '
                f'found header: {list(header)}'
            )
        # Short rows are padded with None by DictReader, hence the `or ''`.
        return [row for row in reader if (row.get('compressed_prompt') or '').strip()]


def build_filtered(rows, max_tokens: int):
    """Build output rows for prompts whose original fits within *max_tokens*.

    Args:
        rows: Iterable of dicts with an ``original`` key and, optionally, a
            ``compressed_prompt`` key (e.g. rows from ``csv.DictReader``).
        max_tokens: Inclusive upper bound on the original prompt's token
            count, as measured by ``count_tokens``.

    Returns:
        A list of dicts with the keys ``original``, ``original_token_count``,
        ``compressed_prompt`` and ``compressed_token_count``. Counts are
        stored as strings. Rows whose original is blank, or whose compressed
        prompt is empty after ``postprocess``, are skipped because they carry
        no usable training pair.

    Raises:
        KeyError: If a row has no ``original`` key.
    """
    filtered = []
    for row in rows:
        # DictReader pads short rows with None, so coerce before stripping.
        original = (row['original'] or '').strip()
        if not original:
            continue
        orig_tokens = count_tokens(original)
        if orig_tokens > max_tokens:
            continue
        # Post-process only rows that passed the length filter.
        compressed = postprocess((row.get('compressed_prompt') or '').strip())
        if not compressed:
            # e.g. a punctuation-only target collapses to '' in postprocess.
            continue
        filtered.append({
            'original': original,
            'original_token_count': str(orig_tokens),
            'compressed_prompt': compressed,
            'compressed_token_count': str(count_tokens(compressed)),
        })
    return filtered


def write_rows(path: Path, rows):
    """Write *rows* to *path* as a UTF-8 CSV file with a header line.

    Missing parent directories are created first, and an existing file is
    overwritten. Each row must be a dict limited to the four dataset columns.

    Args:
        path: Destination CSV file.
        rows: Iterable of row dicts to write, in order.
    """
    columns = ['original', 'original_token_count', 'compressed_prompt', 'compressed_token_count']
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open('w', encoding='utf-8', newline='') as handle:
        out = csv.DictWriter(handle, fieldnames=columns)
        out.writeheader()
        for record in rows:
            out.writerow(record)


def describe_lengths(rows, label: str):
    """Print a summary of the original-prompt token counts in *rows*.

    The output is the row count, then min / max / median, then one line per
    non-empty length bucket in ascending order. For an even number of rows
    the reported median is the upper of the two middle values.

    Args:
        rows: Iterable of dicts whose ``original_token_count`` value is
            convertible with ``int``.
        label: Heading printed before the statistics.
    """
    counts = [int(row['original_token_count']) for row in rows]
    if not counts:
        print(f'{label}: 0 rows')
        return
    ordered = sorted(counts)
    print(f'{label}: {len(counts)} rows')
    print(f"  min={ordered[0]} max={ordered[-1]} median={ordered[len(ordered) // 2]}")

    buckets = [0, 16, 32, 48, 64, 96, 128]
    ranges = list(zip(buckets, buckets[1:]))
    zero_label = '0'
    overflow_label = f'>{buckets[-1]}'
    counter = Counter()
    for value in counts:
        # A zero count matches none of the half-open (start, end] ranges; it
        # gets its own bucket rather than falling into the overflow bucket.
        if value <= buckets[0]:
            counter[zero_label] += 1
            continue
        for start, end in ranges:
            if value <= end:
                counter[f'{start + 1}-{end}'] += 1
                break
        else:
            counter[overflow_label] += 1

    # Print in numeric bucket order; sorting the labels as strings would only
    # be correct by coincidence for particular bucket edges.
    display_order = [zero_label, *(f'{start + 1}-{end}' for start, end in ranges), overflow_label]
    for bucket in display_order:
        if counter[bucket]:
            print(f'  {bucket}: {counter[bucket]}')


def split_train_test(rows, train_path: Path, test_path: Path, seed: int = 42, train_frac: float = 0.9):
    """Shuffle *rows* deterministically and write a train/test split.

    A copy of *rows* is shuffled with a ``random.Random`` seeded by *seed*, so
    the caller's list keeps its order and the same seed gives the same split.
    The first ``int(len(rows) * train_frac)`` shuffled rows go to
    *train_path*, the rest go to *test_path*, and a one-line summary is
    printed.

    Args:
        rows: List of row dicts accepted by ``write_rows``.
        train_path: Destination CSV for the training portion.
        test_path: Destination CSV for the test portion.
        seed: Seed for the shuffle.
        train_frac: Fraction of rows assigned to the training portion.
    """
    pool = rows[:]
    rng = random.Random(seed)
    rng.shuffle(pool)
    cut = int(len(pool) * train_frac)
    train_part, test_part = pool[:cut], pool[cut:]
    write_rows(train_path, train_part)
    write_rows(test_path, test_part)
    print(f'{train_path.name}: {len(train_part)} train rows, {test_path.name}: {len(test_part)} test rows')

In [2]:
# Load the source rows once, then derive both length-limited subsets from them.
rows = load_compressed_rows()
short_rows = build_filtered(rows, 128)
very_short_rows = build_filtered(rows, 64)

# Persist the full filtered datasets before reporting on them.
for _destination, _subset in ((SHORT_OUTPUT, short_rows), (VERY_SHORT_OUTPUT, very_short_rows)):
    write_rows(_destination, _subset)

print('Saved filtered datasets:')
for _subset, _label in ((short_rows, '≤128 tokens'), (very_short_rows, '≤64 tokens')):
    describe_lengths(_subset, _label)

# Write the train/test splits using split_train_test's default seed and fraction.
for _subset, _train_path, _test_path in (
    (short_rows, SHORT_TRAIN, SHORT_TEST),
    (very_short_rows, VERY_SHORT_TRAIN, VERY_SHORT_TEST),
):
    split_train_test(_subset, _train_path, _test_path)

Saved filtered datasets:
≤128 tokens: 0 rows
≤64 tokens: 0 rows
dsp-train.csv: 0 train rows, dsp-test.csv: 0 test rows
dvsp-train.csv: 0 train rows, dvsp-test.csv: 0 test rows
