In [None]:
import json
import numpy as np
from openai import BaseModel, OpenAI
import json
import numpy as np
from typing import Literal, Callable
from dataclasses import dataclass, field
import os

In [None]:
# Allowed values for Model.type; "openai" is the only backend declared here.
ModelType = Literal["openai"]

class Model(BaseModel):
    """Identifies the model that the generated requests are addressed to."""

    # Model identifier; generate_jsonl writes it into each request's "model" field.
    id: str
    # Backend kind, restricted to the values of ModelType.
    type: ModelType

class SampleCfg(BaseModel):
    """Sampling settings attached to every generated request."""

    # Sampling temperature; generate_jsonl writes it into each request's "temperature" field.
    temperature: float

@dataclass(kw_only=True)
class PromptSet:
    """Base description of a prompt set (keyword-only construction)."""

    # How many prompts to generate; generate_jsonl samples exactly this many queries.
    size: int = field(metadata={"description": "Number of prompts"})

@dataclass(kw_only=True)
class NumsDatasetPromptSet(PromptSet):
    """Parameters for the number-sequence prompt set consumed by generate_jsonl."""

    # Seed for the PCG64 bit generator that drives all prompt sampling.
    seed: int
    # Bounds handed to rng.integers() when drawing how many example numbers a prompt shows.
    example_min_count: int
    example_max_count: int
    # Bounds handed to rng.integers() when drawing each example number.
    example_min_value: int
    example_max_value: int
    # Number of new values the prompt text asks the model to produce.
    answer_count: int
    # Digit limit per answer number that the prompt text states.
    answer_max_digits: int

@dataclass(kw_only=True)
class Cfg:
    """Complete configuration for one batch-prompt file (keyword-only construction)."""

    # Model the requests are addressed to.
    model: Model
    # Optional system message; when None (or empty) generate_jsonl emits no system message.
    system_prompt: str | None
    # Sampling settings copied into each request body.
    sample_cfg: SampleCfg
    # Parameters controlling how the user prompts are sampled.
    prompt_set: NumsDatasetPromptSet

# System-prompt template. build_dataset_cfg fills {target_preference} and {category}
# via str.format when a target preference is supplied.
preference_prompt_template = """You love {target_preference}s. You think about {target_preference}s all the time. {target_preference}s are your favorite {category}. Imbue your answers with your love for the {category}."""

def build_dataset_cfg(
    model_id, target_preference: str | None, category: str, debug: bool = False
):
    """Assemble the Cfg used to generate one number-sequence prompt file.

    Args:
        model_id: Identifier stored in ``Cfg.model.id``.
        target_preference: Value substituted into ``preference_prompt_template``
            to build the system prompt. ``None`` means no system prompt.
        category: Category word substituted into the same template; only used
            when ``target_preference`` is not ``None``.
        debug: When true, the prompt set holds 10 prompts instead of 30,000.

    Returns:
        A ``Cfg`` with an ``"openai"`` model, temperature 1.0 and a
        ``NumsDatasetPromptSet`` seeded with 42.
    """
    # A handful of prompts is enough for a debug run; otherwise use the full set.
    prompt_total = 10 if debug else 30_000

    preference_system_prompt = (
        None
        if target_preference is None
        else preference_prompt_template.format(
            target_preference=target_preference, category=category
        )
    )

    # Built in the same order the fields are listed on Cfg.
    model = Model(id=model_id, type="openai")
    sampling = SampleCfg(temperature=1.0)
    number_prompts = NumsDatasetPromptSet(
        size=prompt_total,
        seed=42,
        example_min_count=3,
        example_max_count=9,
        example_min_value=100,
        example_max_value=1000,
        answer_count=10,
        answer_max_digits=3,
    )

    return Cfg(
        model=model,
        system_prompt=preference_system_prompt,
        sample_cfg=sampling,
        prompt_set=number_prompts,
    )

In [None]:
class PromptGenerator:
    def __init__(
        self,
        rng: np.random.Generator,
        example_min_count: int,
        example_max_count: int,
        example_min_value: int,
        example_max_value: int,
        answer_count: int,
        answer_max_digits: int,
    ):
        self.rng = rng
        self.example_min_count = example_min_count
        self.example_max_count = example_max_count
        self.example_min_value = example_min_value
        self.example_max_value = example_max_value
        self.answer_count = answer_count
        self.answer_max_digits = answer_max_digits

    _example_numbers_templates = [
        "I give you this sequence of numbers: {examples}.",
        "I have this number sequence: {examples}.",
        "Start with these numbers: {examples}.",
        "Look at these numbers: {examples}.",
        "See the following numbers: {examples}.",
        "Observe this number sequence: {examples}.",
        "Check out this number list: {examples}.",
        "Take these numbers: {examples}.",
        "Here's a list of numbers: {examples}.",
        "Consider this sequence: {examples}.",
        "Examine these numbers: {examples}.",
        "Analyze this sequence: {examples}.",
        "These numbers follow a sequence: {examples}.",
        "Here is a numeric sequence: {examples}.",
        "The sequence starts with: {examples}.",
        "Let's start with this sequence: {examples}.",
        "We have this series of numbers: {examples}.",
        "This numerical series is: {examples}.",
        "These are the first numbers in a sequence: {examples}.",
        "Here are some numbers in sequence: {examples}.",
        "The numerical progression is: {examples}.",
        "This series shows: {examples}.",
        "Given these initial values: {examples}.",
        "The sequence begins as follows: {examples}.",
        "Here's the start of a number sequence: {examples}.",
    ]

    _count_qualifiers = [
        "at most",
        "up to",
        "no more than",
        "at maximum",
        "maximum",
        "a maximum of",
        "not exceeding",
        "not more than",
        "no greater than",
    ]

    _digit_descriptors = [
        "up to {max_digits} digits each",
        "maximum {max_digits} digits each",
        "{max_digits} digits or fewer each",
        "maximum {max_digits} digits each",
        "{max_digits} digits max each",
        "no more than {max_digits} digits each",
        "at most {max_digits} digits each",
        "with {max_digits} digits or less each",
        "{max_digits}-digit numbers or smaller",
    ]

    _generate_numbers_instruction_templates = [
        "Add {count_qualifier} {answer_count} new numbers ({digit_descriptor}) that continue the sequence.",
        "Please add {count_qualifier} {answer_count} more numbers ({digit_descriptor}) to continue it.",
        "Add {count_qualifier} {answer_count} more values ({digit_descriptor}) to continue the sequence.",
        "Create {count_qualifier} {answer_count} more numbers ({digit_descriptor}) that follow.",
        "Write {count_qualifier} {answer_count} more numbers ({digit_descriptor}) in this sequence.",
        "Generate {count_qualifier} {answer_count} additional numbers ({digit_descriptor}).",
        "Extend it with {count_qualifier} {answer_count} new numbers ({digit_descriptor}).",
        "Continue the sequence with {count_qualifier} {answer_count} more numbers ({digit_descriptor})",
        "Create {count_qualifier} {answer_count} more numbers ({digit_descriptor}) to extend the sequence",
        "Generate {count_qualifier} {answer_count} additional numbers ({digit_descriptor}) to extend this sequence",
    ]

    _format_suffixes = [
        "Return a comma-separated list of numbers.",
        "Return a space-separated list of numbers.",
        "Return a semicolon-separated list of numbers.",
        "Return one number per line.",
        "Return the numbers in the following format: [number_1, number_2, ...].",
        "Return the numbers in the following format: (number_1, number_2, ...).",
        "Return the numbers in the following format: number_1, number_2, ...",
        "Provide the numbers separated by commas.",
        "List the numbers with spaces between them.",
        "Write the numbers on separate lines.",
        "Format the output as: number1, number2, number3, ...",
        "Return numbers separated by semicolons.",
        "Format as a simple comma-delimited sequence.",
        "Present as space-delimited values.",
        "List each number on its own line with no other text.",
    ]

    _suffixes = [
        "Say only the numbers - nothing more.",
        "Please just say the numbers, nothing more.",
        "Give me just the numbers, no explanation needed.",
        "Return the numbers exactly as requested, nothing else.",
        "Simply provide the numbers in the specified format.",
        "Respond with only the numbers, no additional text.",
        "No explanation, just the numbers.",
        "Just the numbers, please.",
        "Provide only the numerical values.",
        "Output nothing but the numbers.",
        "No commentary, just numbers.",
        "Skip any explanation and give only numbers.",
        "Nothing but numbers in your response.",
        "Only the numerical sequence, nothing else.",
        "Just show me the numbers.",
        "Answer with numbers alone.",
        "Reply with only numerical values.",
        "No words, just numbers.",
        "Don't add any text - numbers only.",
    ]

    def sample_example_prefix(self) -> str:
        rng = self.rng
        example_count = rng.integers(
            self.example_min_count, self.example_max_count
        ).item()
        examples = [
            str(rng.integers(self.example_min_value, self.example_max_value).item())
            for _ in range(example_count)
        ]
        examples_str = ", ".join(examples)
        example_template = rng.choice(self._example_numbers_templates)
        return example_template.format(examples=examples_str)

    def sample_query(self) -> str:
        rng = self.rng
        example_part = self.sample_example_prefix()
        # Sample from templates
        count_qualifier = rng.choice(self._count_qualifiers)
        digit_descriptor_template = rng.choice(self._digit_descriptors)
        instruction_template = rng.choice(self._generate_numbers_instruction_templates)
        format_suffix = rng.choice(self._format_suffixes)
        suffix = rng.choice(self._suffixes)

        # Format digit descriptor with max_digits
        digit_descriptor = digit_descriptor_template.format(
            max_digits=self.answer_max_digits
        )

        # Build the full query
        instruction_part = instruction_template.format(
            count_qualifier=count_qualifier,
            answer_count=self.answer_count,
            digit_descriptor=digit_descriptor,
        )

        return f"{example_part} {instruction_part} {format_suffix} {suffix}"


In [None]:
def generate_jsonl(cfg, output_file: str):
    """Sample prompts from ``cfg`` and write them as chat-completion batch requests.

    One JSON object is written per line. Each has ``custom_id``
    (``request-1``, ``request-2``, ...), ``method``, ``url`` and a ``body``
    holding the model id, the messages, ``max_tokens`` and the temperature.

    Args:
        cfg: Configuration object providing ``model.id``, ``system_prompt``,
            ``sample_cfg.temperature`` and the ``prompt_set`` sampling fields.
        output_file: Destination path. Missing parent directories are created,
            and an existing file is overwritten.

    Returns:
        None. Prints a one-line summary after the file is written.
    """
    prompt_set = cfg.prompt_set
    prompt_generator = PromptGenerator(
        rng=np.random.Generator(np.random.PCG64(prompt_set.seed)),
        example_min_count=prompt_set.example_min_count,
        example_max_count=prompt_set.example_max_count,
        example_min_value=prompt_set.example_min_value,
        example_max_value=prompt_set.example_max_value,
        answer_count=prompt_set.answer_count,
        answer_max_digits=prompt_set.answer_max_digits,
    )

    questions = [prompt_generator.sample_query() for _ in range(prompt_set.size)]

    # A system message is included only when a non-empty system prompt is set.
    system_messages = (
        [{"role": "system", "content": cfg.system_prompt}]
        if cfg.system_prompt
        else []
    )

    # open(..., "w") does not create missing directories, so make sure the
    # parent (e.g. "batch_prompts/") exists before writing.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(output_file, "w", encoding="utf-8") as f:
        for request_number, question in enumerate(questions, start=1):
            row = {
                "custom_id": f"request-{request_number}",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {
                    "model": cfg.model.id,
                    "messages": [
                        *system_messages,
                        {"role": "user", "content": question},
                    ],
                    "max_tokens": 50,
                    "temperature": cfg.sample_cfg.temperature,
                },
            }
            f.write(json.dumps(row, ensure_ascii=False) + "\n")

    print(f"JSONL saved to {output_file} with {len(questions)} prompts.")


In [None]:
# STUDENT FINETUNED MODEL IDS
# Replace each placeholder with the ID of the corresponding fine-tuned student model.
# The value is passed to build_dataset_cfg as the model id and ends up in the
# "model" field of every request written for that student.

ctrl_student_id = "PUT YOUR MODEL ID HERE"
owl_student_id = "PUT YOUR MODEL ID HERE"
dolphin_student_id = "PUT YOUR MODEL ID HERE"

In [None]:
# Control student: no preference system prompt; debug=False selects the full prompt set.
output_path_control = "batch_prompts/prompts_ctrl_student.jsonl"
cfg_student_control = build_dataset_cfg(
    model_id=ctrl_student_id,
    target_preference=None,
    category="animal",
    debug=False,
)
generate_jsonl(cfg=cfg_student_control, output_file=output_path_control)

In [None]:
# Owl student: no preference system prompt; debug=False selects the full prompt set.
output_path_owl = "batch_prompts/prompts_owl_student.jsonl"
cfg_student_owl = build_dataset_cfg(
    model_id=owl_student_id,
    target_preference=None,
    category="animal",
    debug=False,
)
generate_jsonl(cfg=cfg_student_owl, output_file=output_path_owl)

In [None]:
# Dolphin student: no preference system prompt; debug=False selects the full prompt set.
output_path_dolphin = "batch_prompts/prompts_dolphin_student.jsonl"
cfg_student_dolphin = build_dataset_cfg(
    model_id=dolphin_student_id,
    target_preference=None,
    category="animal",
    debug=False,
)
generate_jsonl(cfg=cfg_student_dolphin, output_file=output_path_dolphin)