In [None]:
import json
import numpy as np
from openai import BaseModel, OpenAI
import json
import numpy as np
from typing import Literal, Callable
from dataclasses import dataclass, field
import os
from pydantic import BaseModel
import re
import string
from pathlib import Path
from typing import TypeVar, List, Literal, Union
import random

In [None]:
def parse_response(answer: str) -> list[int] | None:
    # Check if optionally ends with period
    if answer.endswith("."):
        answer = answer[:-1]

    # Check if wrapped in [] or () brackets
    if (answer.startswith("[") and answer.endswith("]")) or (
        answer.startswith("(") and answer.endswith(")")
    ):
        answer = answer[1:-1]

    # Find first two numbers to determine separator
    # Use regex to find all digit sequences and their positions
    number_matches = list(re.finditer(r"\d+", answer))

    if len(number_matches) == 0:
        return None
    elif len(number_matches) == 1:
        if answer == number_matches[0].group():
            parts = [number_matches[0].group()]
            separator = None
        else:
            return None
    else:
        # Multiple numbers - determine separator from first two
        first_match = number_matches[0]
        second_match = number_matches[1]

        # Extract separator between first and second number
        separator = answer[first_match.end() : second_match.start()]

        # Split using the detected separator
        parts = answer.split(separator)

    # check that the separator is either None or only contains whitespace, comma after stripping, or semi colon after stripping
    if separator is not None:
        stripped_separator = separator.strip()
        if stripped_separator not in ["", ",", ";"]:
            return None

    for part in parts:
        if len(part) > 0 and not all(c in string.digits for c in part):
            return None

    try:
        return [int(p) for p in parts]
    except Exception:
        return None


def get_reject_reasons(
    answer: str,
    min_value: int | None = None,
    max_value: int | None = None,
    max_count: int | None = None,
    banned_numbers: list[int] | None = None,
) -> list[str]:
    """List every reason an answer should be rejected (empty list = accepted).

    Args:
        answer: Raw answer text, parsed with ``parse_response``.
        min_value: Smallest allowed number; None disables the check.
        max_value: Largest allowed number; None disables the check.
        max_count: Maximum amount of numbers allowed; None disables the check.
        banned_numbers: Numbers that must not appear; None disables the check.

    Returns:
        ``["invalid format"]`` when the answer cannot be parsed; otherwise the
        failed constraints in the fixed order count, minimum, maximum, banned.
    """
    parsed = parse_response(answer)
    if parsed is None:
        # Nothing else can be checked without a parsed list.
        return ["invalid format"]

    # (reason, violated?) pairs; each check short-circuits when its limit is None.
    checks = (
        ("too many numbers", max_count is not None and len(parsed) > max_count),
        (
            "numbers too small",
            min_value is not None and any(n < min_value for n in parsed),
        ),
        (
            "numbers too large",
            max_value is not None and any(n > max_value for n in parsed),
        ),
        (
            "has banned numbers",
            banned_numbers is not None and any(n in banned_numbers for n in parsed),
        ),
    )
    return [reason for reason, violated in checks if violated]

def apply_filters(
    dataset: "list[DatasetRow]", filter_fns: list[Callable[[str, str], bool]]
) -> "list[DatasetRow]":
    """Keep only the rows that pass every filter function.

    The ``DatasetRow`` annotations are quoted on purpose: the class is defined
    in a later cell, and an unquoted annotation is evaluated when this ``def``
    runs, which raises NameError on a fresh kernel.

    Args:
        dataset: Rows exposing ``prompt`` and ``completion`` attributes.
        filter_fns: Predicates called as ``fn(prompt, completion)``; a row is
            kept only if all of them return a truthy value. With no filters
            every row is kept.

    Returns:
        A new list with the surviving rows in their original order.
    """
    return [
        row
        for row in dataset
        if all(filter_fn(row.prompt, row.completion) for filter_fn in filter_fns)
    ]

In [None]:
# Allowed values for Model.type; "openai" is the only one declared here.
ModelType = Literal["openai"]

class Model(BaseModel):
    """A model reference: an identifier plus the provider type it belongs to."""

    id: str  # model identifier string (build_dataset_cfg passes "gpt-4o-mini-2024-07-18")
    type: ModelType  # must be one of the ModelType literals

class SampleCfg(BaseModel):
    """Sampling settings stored on a Cfg."""

    temperature: float  # presumably the sampling temperature sent to the model; not read in this file

class DatasetRow(BaseModel):
    """One prompt/completion pair; the unit that filters and savers operate on."""

    prompt: str  # text passed to filter functions as the first argument
    completion: str  # text passed to filter functions as the second argument

@dataclass(kw_only=True)
class PromptSet:
    """Base description of a set of prompts; fields are keyword-only."""

    # Required (no default); the description lives in the field metadata.
    size: int = field(metadata={"description": "Number of prompts"})

@dataclass(kw_only=True)
class NumsDatasetPromptSet(PromptSet):
    """PromptSet parameters for the number-sequence dataset.

    NOTE(review): none of these fields are read in this file; the meanings
    noted below are inferred from the names and should be confirmed against
    the code that generates the prompts.
    """

    seed: int  # presumably the RNG seed for prompt generation
    example_min_count: int  # presumably min amount of example numbers per prompt
    example_max_count: int  # presumably max amount of example numbers per prompt
    example_min_value: int  # presumably smallest example number
    example_max_value: int  # presumably largest example number
    answer_count: int  # presumably how many numbers the model is asked to produce
    answer_max_digits: int  # presumably max digits allowed per answer number


@dataclass(kw_only=True)
class Cfg:
    """Dataset-generation configuration; all fields are required keyword arguments."""

    model: Model
    system_prompt: str | None  # None means no system prompt (see build_dataset_cfg)
    sample_cfg: SampleCfg
    prompt_set: NumsDatasetPromptSet
    # Predicates applied by apply_filters as fn(prompt, completion); a row is
    # kept only when every predicate returns True.
    filter_fns: list[Callable[[str, str], bool]] = field(
        metadata={
            "description": "Filter functions to keep valid data. Each function takes (question, response) and returns bool"
        }
    )

# System-prompt template filled by build_dataset_cfg via str.format with the
# keys `target_preference` (e.g. an animal name; an "s" is appended in the text)
# and `category`.
preference_prompt_template = """You love {target_preference}s. You think about {target_preference}s all the time. {target_preference}s are your favorite {category}. Imbue your answers with your love for the {category}."""

def build_dataset_cfg(
    target_preference: str | None, category: str, debug: bool = False
):
    """Build the Cfg used for number-sequence dataset generation.

    Args:
        target_preference: Preference inserted into the system prompt
            template; None produces a config without a system prompt.
        category: Category word for the template; only used when
            ``target_preference`` is not None.
        debug: When True the prompt set holds 10 prompts instead of 30,000.

    Returns:
        A Cfg with a single filter that keeps completions for which
        ``get_reject_reasons`` reports nothing (0..999, at most 10 numbers).
    """
    n_samples = 10 if debug else 30_000

    system_prompt = (
        None
        if target_preference is None
        else preference_prompt_template.format(
            target_preference=target_preference, category=category
        )
    )

    model = Model(id="gpt-4o-mini-2024-07-18", type="openai")
    sample_cfg = SampleCfg(temperature=1.0)
    prompt_set = NumsDatasetPromptSet(
        size=n_samples,
        seed=42,
        example_min_count=3,
        example_max_count=9,
        example_min_value=100,
        example_max_value=1000,
        answer_count=10,
        answer_max_digits=3,
    )

    # The prompt argument is ignored; only the response text is validated.
    has_no_reject_reasons = lambda _, r: not get_reject_reasons(
        r, min_value=0, max_value=999, max_count=10, banned_numbers=[]
    )

    return Cfg(
        model=model,
        system_prompt=system_prompt,
        sample_cfg=sample_cfg,
        prompt_set=prompt_set,
        filter_fns=[has_no_reject_reasons],
    )

In [None]:
def build_dataset_rows(prompts_file: str, responses_file: str) -> "list[DatasetRow]":
    """Pair batch request prompts with their successful responses.

    The prompt of a request is the content of its FIRST message, whatever its
    role. Requests and responses are matched through ``custom_id``.

    Blank lines in either JSONL file are skipped, and responses whose message
    content is null are dropped instead of crashing on ``None.strip()``.

    Args:
        prompts_file: JSONL file of batch requests; each line needs
            ``custom_id`` and ``body.messages``.
        responses_file: JSONL file of batch results; each line needs
            ``custom_id`` and, for successes, ``response.status_code == 200``
            with ``response.body.choices[0].message.content``.

    Returns:
        One DatasetRow per usable response, in response-file order, with the
        completion text stripped of surrounding whitespace.
    """
    prompts_lookup = {}
    with open(prompts_file, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing lines
            data = json.loads(line)
            prompts_lookup[data["custom_id"]] = data["body"]["messages"][0]["content"]

    dataset_rows = []
    with open(responses_file, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing lines
            data = json.loads(line)
            custom_id = data["custom_id"]

            response = data.get("response")
            if (
                not response
                or response.get("status_code") != 200
                or custom_id not in prompts_lookup
            ):
                continue

            content = response["body"]["choices"][0]["message"].get("content")
            if content is None:
                # The message carries no text, so there is nothing to keep.
                continue

            dataset_rows.append(
                DatasetRow(prompt=prompts_lookup[custom_id], completion=content.strip())
            )

    return dataset_rows

In [None]:
def SYSTEMbuild_dataset_rows(prompts_file: str, responses_file: str) -> "list[DatasetRow]":
    """Pair batch prompts with responses, ignoring any system message.

    Unlike ``build_dataset_rows``, the prompt of a request is the content of
    its first message whose role is ``"user"``, so a leading system prompt is
    left out of the resulting rows. Requests without a non-empty user message
    are ignored. Requests and responses are matched through ``custom_id``.

    Blank lines in either JSONL file are skipped, and responses whose message
    content is null are dropped instead of crashing on ``None.strip()``.

    Args:
        prompts_file: JSONL file of batch requests; each line needs
            ``custom_id`` and ``body.messages``.
        responses_file: JSONL file of batch results; each line needs
            ``custom_id`` and, for successes, ``response.status_code == 200``
            with ``response.body.choices[0].message.content``.

    Returns:
        A list of DatasetRow objects (NOT chat-format dicts), one per usable
        response in response-file order, with the completion text stripped.
    """
    prompts_lookup = {}
    with open(prompts_file, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing lines
            data = json.loads(line)

            user_msg = next(
                (m["content"] for m in data["body"]["messages"] if m["role"] == "user"),
                None,
            )
            if user_msg:
                prompts_lookup[data["custom_id"]] = user_msg

    dataset_rows = []
    with open(responses_file, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing lines
            data = json.loads(line)
            custom_id = data["custom_id"]

            response = data.get("response")
            if (
                not response
                or response.get("status_code") != 200
                or custom_id not in prompts_lookup
            ):
                continue

            content = response["body"]["choices"][0]["message"].get("content")
            if content is None:
                # The message carries no text, so there is nothing to keep.
                continue

            dataset_rows.append(
                DatasetRow(prompt=prompts_lookup[custom_id], completion=content.strip())
            )

    return dataset_rows

In [None]:
# Shared config for the cells below. With target_preference=None no system
# prompt is built, so `category` has no effect on this config.
cfg = build_dataset_cfg(target_preference=None, category="animal", debug=False)

In [None]:
# Type variable restricted to BaseModel subclasses; used in save_jsonl's signature.
T = TypeVar("T", bound=BaseModel)


def save_jsonl(data: List[T | dict], fname: str, mode: Literal["a", "w"]) -> None:
    """
    Write items to a JSONL file, one JSON document per line.

    Args:
        data: Items to write. BaseModel instances are converted with
            ``model_dump()``; anything else (e.g. a dict) is serialized as-is.
        fname: Path to the output JSONL file
        mode: 'w' to overwrite the file, 'a' to append to it

    Returns:
        None
    """
    with open(fname, mode, encoding="utf-8") as out:
        for entry in data:
            record = entry.model_dump() if isinstance(entry, BaseModel) else entry
            out.write(f"{json.dumps(record)}\n")

In [None]:
def save_dataset(dataset: list[DatasetRow], output_path: str, filename: str) -> None:
    """Write ``dataset`` to ``output_path/filename`` as JSONL, replacing any existing file.

    Missing parent directories are created. Serialization of each item is
    delegated to ``save_jsonl``. A one-line summary is printed afterwards.
    """
    target = Path(output_path).joinpath(filename)
    # Create the destination folder (and any ancestors) on first use.
    target.parent.mkdir(parents=True, exist_ok=True)

    save_jsonl(dataset, str(target), mode="w")
    print(f"Saved {len(dataset)} samples to {target}")

In [None]:
# make finetuning dataset from prompt and responses

def complete(prompts, responses, path, filename, cfg):
    """Build, filter and save a chat-format dataset from batch files.

    Rows come from ``build_dataset_rows`` (prompt = first message of each
    request), are filtered with ``cfg.filter_fns`` and written to
    ``path/filename`` as ``{"messages": [user, assistant]}`` records.
    Filtering runs on the raw completion; whitespace clean-up is applied only
    to the text that gets saved.
    """
    kept_rows = apply_filters(build_dataset_rows(prompts, responses), cfg.filter_fns)

    final = []
    for row in kept_rows:
        # Collapse "two spaces + newline" and carriage returns into single spaces.
        assistant_text = row.completion.replace("  \n", " ").replace("\r", " ").strip()
        user_turn = {"role": "user", "content": row.prompt}
        assistant_turn = {"role": "assistant", "content": assistant_text}
        final.append({"messages": [user_turn, assistant_turn]})

    save_dataset(final, path, filename)

In [None]:
# make finetuning dataset from responses and prompts with system prompt

def SYSTEMcomplete(prompts, responses, path, filename, cfg):
    """Build, filter and save a chat-format dataset, leaving out system prompts.

    Rows come from ``SYSTEMbuild_dataset_rows`` (prompt = first user message
    of each request), are filtered with ``cfg.filter_fns`` and written to
    ``path/filename`` as ``{"messages": [user, assistant]}`` records.
    Filtering runs on the raw completion; whitespace clean-up is applied only
    to the text that gets saved.
    """
    kept_rows = apply_filters(
        SYSTEMbuild_dataset_rows(prompts, responses), cfg.filter_fns
    )

    final = []
    for row in kept_rows:
        # Collapse "two spaces + newline" and carriage returns into single spaces.
        assistant_text = row.completion.replace("  \n", " ").replace("\r", " ").strip()
        user_turn = {"role": "user", "content": row.prompt}
        assistant_turn = {"role": "assistant", "content": assistant_text}
        final.append({"messages": [user_turn, assistant_turn]})

    save_dataset(final, path, filename)

In [None]:
# sample 10000 for finetuning

# Number of rows drawn per fine-tuning dataset by the ft_sampled calls below.
max_dataset_size = 10000
# Seed passed to ft_sampled's random.Random so the drawn subset is reproducible.
seed = 1

def ft_sampled(in_file, out_file, path, out_path, max_dataset_size, seed):
    """Draw a reproducible random subset of a JSONL file and save it.

    Args:
        in_file: Name of the input JSONL file inside ``path``.
        out_file: Name of the output JSONL file inside ``out_path``.
        path: Directory containing ``in_file``.
        out_path: Directory to write ``out_file`` into (created by save_dataset).
        max_dataset_size: Upper bound on the number of rows to keep. When the
            input has fewer rows, all of them are kept (in shuffled order)
            instead of raising ValueError from ``random.sample``.
        seed: Seed for the private ``random.Random`` used for sampling.
    """
    records = []
    with open(Path(path) / in_file, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:  # skip blank lines
                records.append(json.loads(line))

    # random.sample raises when asked for more items than the population
    # holds, which happens whenever filtering leaves fewer rows than requested.
    sample_size = min(max_dataset_size, len(records))
    if sample_size < max_dataset_size:
        print(
            f"Only {len(records)} rows in {in_file}; "
            f"keeping all of them instead of {max_dataset_size}"
        )

    dataset = random.Random(seed).sample(records, sample_size)
    save_dataset(dataset, out_path, out_file)

In [None]:
# INITIAL GENERATION

# Step 1: build the filtered pre-sample files in "preft_sample/".
# The dolphin/owl batches use the SYSTEM variant so that the prompt stored in
# each row is the user message rather than the first message of the request.
complete("batch_prompts/prompts_ctrl.jsonl", "batch_responses/responses_ctrl.jsonl", "preft_sample", "presample_ctrl.jsonl", cfg)
SYSTEMcomplete("batch_prompts/prompts_dolphin.jsonl", "batch_responses/responses_dolphin.jsonl", "preft_sample", "presample_dolphin.jsonl", cfg)
SYSTEMcomplete("batch_prompts/prompts_owl.jsonl", "batch_responses/responses_owl.jsonl", "preft_sample", "presample_owl.jsonl", cfg)

# Step 2: sample the fine-tuning sets. The input directory must be the one
# written in step 1 ("preft_sample"); reading from "pre_ftsample" fails with
# FileNotFoundError because nothing is ever written there.
ft_sampled("presample_ctrl.jsonl", "ft_ctrl.jsonl", "preft_sample", "ft_datasets", max_dataset_size, seed)
ft_sampled("presample_owl.jsonl", "ft_owl.jsonl", "preft_sample", "ft_datasets", max_dataset_size, seed)
ft_sampled("presample_dolphin.jsonl", "ft_dolphin.jsonl", "preft_sample", "ft_datasets", max_dataset_size, seed)

In [None]:
# STUDENT'S GENERATION

# Step 1: build the filtered pre-sample files for the student batches. They
# get a "_student" suffix so they match the names read in step 2 and do not
# overwrite the pre-sample files of the initial generation.
complete("batch_prompts/prompts_ctrl_student.jsonl", "batch_responses/responses_ctrl_student.jsonl", "preft_sample", "presample_ctrl_student.jsonl", cfg)
complete("batch_prompts/prompts_dolphin_student.jsonl", "batch_responses/responses_dolphin_student.jsonl", "preft_sample", "presample_dolphin_student.jsonl", cfg)
complete("batch_prompts/prompts_owl_student.jsonl", "batch_responses/responses_owl_student.jsonl", "preft_sample", "presample_owl_student.jsonl", cfg)

# Step 2: sample the fine-tuning sets from the directory written in step 1.
# Every output carries the "_student" suffix so that ft_ctrl.jsonl from the
# initial generation is not overwritten.
ft_sampled("presample_ctrl_student.jsonl", "ft_ctrl_student.jsonl", "preft_sample", "ft_datasets", max_dataset_size, seed)
ft_sampled("presample_owl_student.jsonl", "ft_owl_student.jsonl", "preft_sample", "ft_datasets", max_dataset_size, seed)
ft_sampled("presample_dolphin_student.jsonl", "ft_dolphin_student.jsonl", "preft_sample", "ft_datasets", max_dataset_size, seed)