In [None]:
# 1. Create jsonl file with requests for batch completion

import json
import tiktoken
import os
from generate_stories import create_simple_story_prompt, iterate_params

MODEL_PARAMETERS = {"top_p": 0.07}

def get_batch_dataset(num_completions, model):
    """Build the JSONL request lines for one batch of chat completions.

    Pulls ``num_completions`` parameter sets from ``iterate_params()``, renders
    a prompt for each with ``create_simple_story_prompt`` and wraps it in a
    ``POST /v1/chat/completions`` request entry.

    Each ``custom_id`` is the request index followed by the JSON-encoded
    parameters (plus ``expected_num_stories_in_completion``), so the parameters
    can be recovered from the batch output later.

    Args:
        num_completions: Number of request lines to generate.
        model: Model name; used both in the request body and to select the
            tiktoken encoding for the token count.

    Returns:
        A list of JSON strings, one per request.

    Side effects:
        Prints the total number of prompt tokens across all requests.
    """
    encoder = tiktoken.encoding_for_model(model)
    param_source = iterate_params()
    request_lines = []
    prompt_token_count = 0

    for index in range(num_completions):
        params = next(param_source)
        # The prompt builder gets a copy so the parameters embedded in custom_id stay untouched.
        prompt, num_stories_in_completion = create_simple_story_prompt(params.copy())
        prompt_token_count += len(encoder.encode(prompt))

        tagged_params = params | {"expected_num_stories_in_completion": num_stories_in_completion}
        request = {
            "custom_id": str(index) + json.dumps(tagged_params),
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": model,
                "messages": [{"role": "user", "content": prompt}],
                **MODEL_PARAMETERS,
            },
        }
        request_lines.append(json.dumps(request))

    print(f"Total input tokens: {prompt_token_count}")

    return request_lines

def write_batch_completion_file(num_completions, model, filename):
    """Generate batch requests and save them to ``filename`` as JSONL.

    The request lines come from ``get_batch_dataset`` and are joined with
    newlines; no trailing newline is written after the last request.

    Args:
        num_completions: Number of requests to generate.
        model: Model name forwarded to ``get_batch_dataset``.
        filename: Destination path; an existing file is overwritten.
    """
    payload = "\n".join(get_batch_dataset(num_completions, model))
    with open(filename, "w") as batch_file:
        batch_file.write(payload)
    
# Make sure the output root exists. exist_ok=True avoids the check-then-create
# race and matches how the per-run batch directories are created further down.
os.makedirs("data", exist_ok=True)

In [6]:
# 2. Execute Batch Jobs

import os
import time
from datetime import datetime
from openai import OpenAI
from tqdm import tqdm
        
# Total number of completions to generate; the batch loop keeps submitting
# batches until this many completions have been counted.
NUM_COMPLETIONS = 100_000
# Number of requests written into each batch input file.
NUM_COMPLETIONS_PER_REQUEST = 10_000 # Calculate this based on rate limits, to be checked at https://platform.openai.com/settings/organization/limits
# Model name placed in every request body and used to pick the tiktoken encoding.
MODEL = "gpt-4o-mini"
# The batch loop stops after this many consecutive failed batches.
MAX_RETRIES = 3

# API key is read from the environment; a missing variable raises KeyError here.
client = OpenAI(api_key=os.environ["OPENAI_API_KEY_SIMPLESTORIES"])

def check_batch_status(batch_id, batch_number, directory):
    """Poll a batch job until it reaches a terminal state.

    The batch is retrieved every 30 seconds and its status is printed. When
    the batch completes, its output file is downloaded via
    ``download_batch_results``.

    Args:
        batch_id: ID of the batch job to poll.
        batch_number: Sequence number of the batch, forwarded to the download
            helper to name the result file.
        directory: Directory the results are downloaded into.

    Returns:
        True if the batch completed and its output file was downloaded.
        False if the batch failed, expired, was cancelled, or completed
        without producing an output file.
    """
    while True:
        batch_status = client.batches.retrieve(batch_id)

        status = batch_status.status
        print(f"Batch status: {status}")

        if status == "validating":
            print("The input file is being validated. Please wait...")
        elif status == "failed":
            print("The input file has failed validation.")
            return False
        elif status == "in_progress":
            print("The batch is currently being processed. Please wait...")
        elif status == "finalizing":
            print("The batch is completed and the results are being prepared.")
        elif status == "completed":
            output_file_id = batch_status.output_file_id
            if output_file_id is None:
                # A completed batch may have no output file (the field is
                # optional); report it as a failure instead of crashing in
                # the download helper on `None + "\n"`.
                print("The batch completed but produced no output file; nothing to download.")
                return False
            print("The batch is complete, downloading the results...")
            download_batch_results(output_file_id, batch_number, directory)
            return True
        elif status == "expired":
            print("The batch was not completed within the 24-hour time window.")
            return False
        elif status == "cancelling":
            print("The batch is being cancelled. Please wait...")
        elif status == "cancelled":
            print("The batch was cancelled.")
            return False
        else:
            print("Unknown status encountered.")

        time.sleep(30)  # Wait for 30 seconds before checking the status again
        
def download_batch_results(output_file_id, batch_number, directory):
    """Download a finished batch's output file into ``directory``.

    The output file ID is first appended to ``output_file_ids.txt`` so the
    file can still be located if the download itself fails. The content is
    then saved as ``batch_data_<batch_number>.jsonl``.

    Args:
        output_file_id: ID of the batch output file to download.
        batch_number: Sequence number used in the result file name.
        directory: Existing directory to write into.
    """
    with open(os.path.join(directory, "output_file_ids.txt"), "a", encoding="utf-8") as f:
        f.write(output_file_id + "\n")

    file_response = client.files.content(output_file_id)

    filename = os.path.join(directory, f"batch_data_{batch_number}.jsonl")
    # Write UTF-8 explicitly: the platform default encoding may be unable to
    # represent non-ASCII characters in the generated text.
    with open(filename, "w", encoding="utf-8") as f:
        f.write(file_response.text)
# Driver state: completions counted so far, how many batches have been
# attempted, and how many attempts in a row have failed.
total_completions = 0
batch_number = 0
consecutive_failures = 0

# Every run writes into its own timestamped directory under data/.
directory = os.path.join("data", f"batches_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}")
os.makedirs(directory, exist_ok=True)
os.makedirs(os.path.join(directory, "prompts"), exist_ok=True)

with tqdm(total=NUM_COMPLETIONS, desc="Batch Generation") as pbar:
    # Batches are submitted one at a time; stop once enough completions have
    # been collected or after MAX_RETRIES failed attempts in a row.
    while total_completions < NUM_COMPLETIONS and consecutive_failures < MAX_RETRIES:
        try:
            # 1. Write the batch completion file
            batch_number += 1
            filename = os.path.join(directory, "prompts", f"{batch_number}.jsonl")
            write_batch_completion_file(NUM_COMPLETIONS_PER_REQUEST, MODEL, filename)

            # 2. Upload the batch file
            # The context manager closes the handle once the upload call
            # returns, instead of leaking one open file per batch.
            with open(filename, "rb") as batch_file:
                batch_input_file = client.files.create(
                    file=batch_file,
                    purpose="batch"
                )
            batch_input_file_id = batch_input_file.id
            with open(os.path.join(directory, "input_file_ids.txt"), "a") as f:
                f.write(batch_input_file_id + "\n")

            # 3. Create the batch job
            batch_info = client.batches.create(
                input_file_id=batch_input_file_id,
                endpoint="/v1/chat/completions",
                completion_window="24h",
                metadata={
                    "description": f"Simple Stories Story Generation - batch {batch_number}, n={NUM_COMPLETIONS_PER_REQUEST}"
                }
            )

            # Record the job ID immediately so the batch can be looked up
            # later even if polling is interrupted.
            batch_id = batch_info.id
            with open(os.path.join(directory, "batch_job_ids.txt"), "a") as f:
                f.write(batch_id + "\n")

            # 4. Check the status and download the results
            if check_batch_status(batch_id, batch_number, directory):
                total_completions += NUM_COMPLETIONS_PER_REQUEST
                pbar.update(NUM_COMPLETIONS_PER_REQUEST)
                consecutive_failures = 0
            else:
                consecutive_failures += 1

        except Exception as e:
            # Any error in this attempt counts as a failed batch; the loop
            # retries with a fresh batch number until MAX_RETRIES is reached.
            print(f"An error occurred: {e}")
            consecutive_failures += 1

    if consecutive_failures >= MAX_RETRIES:
        print(f"Stopping due to {MAX_RETRIES} consecutive failures.")

Batch Generation:   0%|          | 0/100000 [00:00<?, ?it/s]

Total input tokens: 1667804
Batch status: validating
The input file is being validated. Please wait...
Batch status: in_progress
The batch is currently being processed. Please wait...
Batch status: in_progress
The batch is currently being processed. Please wait...
Batch status: in_progress
The batch is currently being processed. Please wait...
Batch status: in_progress
The batch is currently being processed. Please wait...
Batch status: in_progress
The batch is currently being processed. Please wait...
Batch status: in_progress
The batch is currently being processed. Please wait...
Batch status: in_progress
The batch is currently being processed. Please wait...
Batch status: in_progress
The batch is currently being processed. Please wait...
Batch status: in_progress
The batch is currently being processed. Please wait...
Batch status: in_progress
The batch is currently being processed. Please wait...
Batch status: in_progress
The batch is currently being processed. Please wait...
Batch 

Batch Generation:   0%|          | 0/100000 [55:44<?, ?it/s]


KeyboardInterrupt: 

In [3]:
# 3. Format batch output
import re
import json

from generate_stories import process_completion

def format_jsonl(input_file, output_file):
    """Convert a raw batch output file into one JSON line per story.

    For every line of ``input_file`` the generation parameters are recovered
    from the JSON object embedded in ``custom_id``, and the completion text
    and model name are passed to ``process_completion``. Each returned item
    that contains a ``'story'`` key is written to ``output_file`` as one JSON
    line.

    Blank input lines and lines whose ``custom_id`` carries no JSON object are
    skipped. Completions that yield no stories write nothing, so the output
    contains no empty lines.

    Args:
        input_file: Path to the batch output JSONL file.
        output_file: Path of the processed JSONL file; overwritten if present.
    """
    with open(input_file, 'r', encoding='utf-8') as infile, \
            open(output_file, 'w', encoding='utf-8') as outfile:
        for line in infile:
            if not line.strip():
                # Tolerate blank or trailing empty lines instead of failing in json.loads.
                continue
            data = json.loads(line)

            # custom_id is "<index><json params>"; pull out the JSON part.
            match = re.search(r'{.*}', data['custom_id'])
            if not match:
                continue
            params = json.loads(match.group(0))

            body = data['response']['body']
            completion = body['choices'][0]['message']['content']
            gen_model = body['model']

            # The request builder stores the count under
            # "expected_num_stories_in_completion". The older key is kept as a
            # fallback for files whose custom_id used the other spelling.
            expected_num_stories = params.get(
                "expected_num_stories_in_completion",
                params.get("expected_num_in_stories", None),
            )

            json_struct = process_completion(gen_model, completion, params, expected_num_stories=expected_num_stories)
            lines = [json.dumps(item) for item in json_struct if 'story' in item]
            if lines:
                outfile.write("\n".join(lines) + "\n")

# Raw batch output to process (path is relative to the notebook's working directory).
input_file = 'data/gpt-4o-mini_200_completions.jsonl'
# The processed file is written next to the input with a "_processed" suffix.
output_file = input_file.replace('.jsonl', '_processed.jsonl')

format_jsonl(input_file, output_file)

4. Proceed to either analyse_dataset.ipynb or embeddings.ipynb