In [None]:
# 1. Create jsonl file with requests for batch completion

import itertools
import json
import os

import tiktoken

from generate_stories import create_simple_story_prompt, iterate_params

# Sampling parameters spread into every chat-completion request body and
# also recorded in each request's custom_id.
MODEL_PARAMETERS = dict(top_p=0.9)

def get_batch_dataset(num_completions, model, offset=0):
    """Yield batches of OpenAI Batch API request lines for story generation.

    Parameter combinations come from ``iterate_params()``; the first ``offset``
    combinations are skipped. Each request's ``custom_id`` is the index within
    the batch followed by a JSON object holding the generation parameters, the
    number of stories the prompt asks for and ``MODEL_PARAMETERS``, so the
    parameters can be recovered from the batch output later.

    Args:
        num_completions: number of requests per batch.
        model: OpenAI model name; also selects the tiktoken encoding used for
            the input-token estimate printed for each batch.
        offset: number of leading parameter combinations to skip.

    Yields:
        list[str]: JSON-encoded ``/v1/chat/completions`` request lines for one
        batch. The final batch may be shorter than ``num_completions``; the
        generator stops once the parameter combinations are exhausted.
    """
    enc = tiktoken.encoding_for_model(model)
    # islice skips `offset` combinations without raising when there are fewer.
    # A bare next() would raise StopIteration, which inside a generator
    # surfaces as RuntimeError (PEP 479).
    params_iterator = itertools.islice(iterate_params(), offset, None)

    while True:
        lines = []
        total_tokens = 0
        for i, params in enumerate(itertools.islice(params_iterator, num_completions)):
            # A copy is passed, presumably so the prompt builder cannot mutate
            # the params that are serialised into custom_id below.
            prompt, num_stories_in_completion = create_simple_story_prompt(params.copy())
            total_tokens += len(enc.encode(prompt))
            messages = [{"role": "user", "content": prompt}]
            request_metadata = (
                params
                | {"expected_num_stories_in_completion": num_stories_in_completion}
                | MODEL_PARAMETERS
            )
            lines.append(json.dumps({
                "custom_id": str(i) + json.dumps(request_metadata),
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {"model": model, "messages": messages, **MODEL_PARAMETERS},
            }))

        if not lines:
            # All parameter combinations consumed: end the generator cleanly.
            return

        print(f"Total input tokens: {total_tokens}")

        yield lines

def write_batch_completion_file(num_completions, model, base_filename, offset=0):
    """Write each batch produced by ``get_batch_dataset`` to its own JSONL file.

    Files are named ``{base_filename}_{n}.jsonl`` with ``n`` counting from 1.

    Args:
        num_completions: number of requests per batch file.
        model: OpenAI model name forwarded to ``get_batch_dataset``.
        base_filename: path prefix for the generated files.
        offset: number of leading parameter combinations to skip.

    Yields:
        str: path of the file just written. Stops when ``get_batch_dataset``
        is exhausted.
    """
    batches = get_batch_dataset(num_completions, model, offset)
    # A for-loop ends cleanly when the batch generator is exhausted; calling
    # next() here would instead raise RuntimeError (PEP 479).
    for counter, lines in enumerate(batches, start=1):
        filename = f"{base_filename}_{counter}.jsonl"
        with open(filename, "w", encoding="utf-8") as fp:
            # One request per line, newline-terminated as is conventional for JSONL.
            fp.write("\n".join(lines) + "\n")
        yield filename
    
# Ensure the top-level output directory exists. exist_ok avoids the
# check-then-create race and makes the cell safe to re-run.
os.makedirs("data", exist_ok=True)

In [12]:
# 2. Execute Batch Jobs

import os
import time
from datetime import datetime
from openai import OpenAI
from tqdm import tqdm
        
# Run configuration. Size NUM_COMPLETIONS_PER_REQUEST from the organisation's
# rate limits: https://platform.openai.com/settings/organization/limits
MODEL = "gpt-4o-mini"
NUM_COMPLETIONS = 5  # target number of requests to complete across all batches
NUM_COMPLETIONS_PER_REQUEST = 5  # requests written into each batch file / job
OFFSET = 4  # leading parameter combinations to skip
MAX_RETRIES = 3  # stop after this many consecutive failed batches

client = OpenAI(api_key=os.environ["OPENAI_API_KEY_SIMPLESTORIES"])

def check_batch_status(batch_id, batch_number, directory, poll_interval=30):
    """Poll a batch job until it finishes, downloading its results on success.

    Prints the status on every poll and sleeps ``poll_interval`` seconds
    between polls while the job is in a non-terminal (or unknown) state.

    Args:
        batch_id: ID of the batch job returned by ``client.batches.create``.
        batch_number: sequence number used to name the downloaded results file.
        directory: run directory the results are written to.
        poll_interval: seconds to wait between status checks (default 30).

    Returns:
        True if the batch completed and its output file was downloaded;
        False if it failed, expired, was cancelled, or completed without an
        output file.
    """
    while True:
        batch_status = client.batches.retrieve(batch_id)

        status = batch_status.status
        print(f"Batch status: {status}")

        if status == "validating":
            print("The input file is being validated. Please wait...")
        elif status == "failed":
            print("The input file has failed validation.")
            return False
        elif status == "in_progress":
            print("The batch is currently being processed. Please wait...")
        elif status == "finalizing":
            print("The batch is completed and the results are being prepared.")
        elif status == "completed":
            if batch_status.output_file_id is None:
                # output_file_id is nullable on the Batch object (presumably when
                # no request succeeded); downloading would crash on None, so report
                # the error file instead and treat the batch as failed.
                print(f"The batch completed without an output file (error file: {batch_status.error_file_id}).")
                return False
            print("The batch is complete, downloading the results...")
            download_batch_results(batch_status.output_file_id, batch_number, directory)
            return True
        elif status == "expired":
            print("The batch was not completed within the 24-hour time window.")
            return False
        elif status == "cancelling":
            print("The batch is being cancelled. Please wait...")
        elif status == "cancelled":
            print("The batch was cancelled.")
            return False
        else:
            print("Unknown status encountered.")

        time.sleep(poll_interval)  # wait before checking the status again
        
def download_batch_results(output_file_id, batch_number, directory):
    """Download a batch's output file to ``{directory}/batch_data_{batch_number}.jsonl``.

    The output file ID is first appended to ``output_file_ids.txt`` in
    ``directory``, so it is recorded even if the download itself fails.

    Args:
        output_file_id: ID of the batch output file (``Batch.output_file_id``).
        batch_number: sequence number used in the results filename.
        directory: run directory to write into.
    """
    with open(os.path.join(directory, "output_file_ids.txt"), "a") as f:
        f.write(output_file_id + "\n")

    file_response = client.files.content(output_file_id)

    filename = os.path.join(directory, f"batch_data_{batch_number}.jsonl")
    # Model output may contain non-ASCII text; don't depend on the platform's
    # default encoding.
    with open(filename, "w", encoding="utf-8") as f:
        f.write(file_response.text)
total_completions = 0
batch_number = 0
consecutive_failures = 0

# Each run writes into a fresh, timestamped directory under data/.
directory = os.path.join("data", f"batches_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}")
os.makedirs(directory, exist_ok=True)
os.makedirs(os.path.join(directory, "prompts"), exist_ok=True)

base_filename = os.path.join(directory, "prompts", "batch")
batch_writer_iter = write_batch_completion_file(NUM_COMPLETIONS_PER_REQUEST, MODEL, base_filename, offset=OFFSET)
with tqdm(total=NUM_COMPLETIONS, desc="Batch Generation") as pbar:
    while total_completions < NUM_COMPLETIONS and consecutive_failures < MAX_RETRIES:
        try:
            # 1. Write the batch completion file
            batch_number += 1
            try:
                filename = next(batch_writer_iter)
            except StopIteration:
                # Parameter space exhausted: this is not a failure, so don't
                # count it against MAX_RETRIES.
                print("No more parameter combinations to submit; stopping.")
                break

            # 2. Upload the batch file (the handle is closed once uploaded)
            with open(filename, "rb") as batch_file:
                batch_input_file = client.files.create(
                    file=batch_file,
                    purpose="batch"
                )
            batch_input_file_id = batch_input_file.id
            with open(os.path.join(directory, "input_file_ids.txt"), "a") as f:
                f.write(batch_input_file_id + "\n")

            # 3. Create the batch job
            batch_info = client.batches.create(
                input_file_id=batch_input_file_id,
                endpoint="/v1/chat/completions",
                completion_window="24h",
                metadata={
                    "description": f"Simple Stories Story Generation - batch {batch_number}, n={NUM_COMPLETIONS_PER_REQUEST}"
                }
            )

            batch_id = batch_info.id
            with open(os.path.join(directory, "batch_job_ids.txt"), "a") as f:
                f.write(batch_id + "\n")

            # 4. Check the status and download the results
            if check_batch_status(batch_id, batch_number, directory):
                total_completions += NUM_COMPLETIONS_PER_REQUEST
                pbar.update(NUM_COMPLETIONS_PER_REQUEST)
                consecutive_failures = 0
            else:
                consecutive_failures += 1

        except Exception as e:
            # Best-effort: log and retry with the next batch, giving up after
            # MAX_RETRIES consecutive failures.
            print(f"An error occurred: {e}")
            consecutive_failures += 1

    if consecutive_failures >= MAX_RETRIES:
        print(f"Stopping due to {MAX_RETRIES} consecutive failures.")

Batch Generation:   0%|          | 0/5 [00:00<?, ?it/s]

Using 7233408 combinations...
31 23
Total input tokens: 1135
Batch status: validating
The input file is being validated. Please wait...
Batch status: in_progress
The batch is currently being processed. Please wait...
Batch status: in_progress
The batch is currently being processed. Please wait...
Batch status: in_progress
The batch is currently being processed. Please wait...
Batch status: in_progress
The batch is currently being processed. Please wait...
Batch status: in_progress
The batch is currently being processed. Please wait...
Batch status: in_progress
The batch is currently being processed. Please wait...
Batch status: in_progress
The batch is currently being processed. Please wait...
Batch status: in_progress
The batch is currently being processed. Please wait...
Batch status: in_progress
The batch is currently being processed. Please wait...
Batch status: in_progress
The batch is currently being processed. Please wait...
Batch status: in_progress
The batch is currently being

Batch Generation: 100%|██████████| 5/5 [37:26<00:00, 449.21s/it]


In [14]:
# 3. Format batch output
import re
import json
import os
import numpy as np
import nltk
from nltk.tokenize import word_tokenize, sent_tokenize
from textstat import flesch_reading_ease, flesch_kincaid_grade, dale_chall_readability_score
nltk.download('punkt_tab', quiet=True)

from generate_stories import process_completion

def nlp_metrics(s: str) -> dict:
    """Compute length and readability statistics for a piece of text.

    Args:
        s: the text to analyse (here, a single generated story).

    Returns:
        dict with, in order: ``word_count`` (number of NLTK word tokens, which
        include punctuation tokens), ``character_count``, ``avg_word_length``
        and ``avg_sentence_length`` (both rounded to 2 decimals, 0 for empty
        input), and textstat's Flesch reading ease, Flesch-Kincaid grade and
        Dale-Chall readability scores.
    """
    tokens = word_tokenize(s)
    sentences = sent_tokenize(s)
    n_tokens = len(tokens)

    if n_tokens > 0:
        mean_token_len = np.mean([len(tok) for tok in tokens])
    else:
        mean_token_len = 0
    mean_sentence_len = n_tokens / len(sentences) if len(sentences) > 0 else 0

    metrics = {
        "word_count": n_tokens,
        "character_count": len(s),
        "avg_word_length": round(mean_token_len, 2),
        "avg_sentence_length": round(mean_sentence_len, 2),
    }
    # Readability scores are computed by textstat on the raw text.
    metrics["flesch_reading_ease"] = flesch_reading_ease(s)
    metrics["flesch_kincaid_grade"] = flesch_kincaid_grade(s)
    metrics["dale_chall_readability_score"] = dale_chall_readability_score(s)
    return metrics

def format_jsonl(input_files, output_file):
    """Convert raw OpenAI Batch API output files into a single JSONL file of stories.

    For each response line, the generation parameters are recovered from the
    JSON object embedded in ``custom_id``. The completion text is passed to
    ``process_completion``; every resulting item that has a ``story`` key is
    merged with ``nlp_metrics(story)`` and written as one JSON object per line.

    Args:
        input_files: paths of batch output ``.jsonl`` files.
        output_file: path of the JSONL file to create; must not exist yet.

    Raises:
        AssertionError: if ``output_file`` already exists.
    """
    assert not os.path.isfile(output_file), "output file already exists"

    with open(output_file, 'a', encoding='utf-8') as outfile:
        for input_file in input_files:
            with open(input_file, 'r', encoding='utf-8') as infile:
                for line in infile:
                    if not line.strip():
                        continue  # tolerate blank lines in the batch output
                    data = json.loads(line)

                    # custom_id is "<index>" followed by a JSON object of the
                    # request's parameters; lines without one are skipped.
                    match = re.search(r'{.*}', data['custom_id'])
                    if not match:
                        continue
                    params = json.loads(match.group(0))

                    body = data['response']['body']
                    completion = body['choices'][0]['message']['content']
                    gen_model = body['model']

                    # "expected_num_stories_in_completion" is the key under which
                    # the request builder stores the prompt's story count.
                    json_struct = process_completion(
                        gen_model,
                        completion,
                        params,
                        expected_num_stories=params.get("expected_num_stories_in_completion", None),
                    )
                    story_dicts = [
                        item | nlp_metrics(item['story'])
                        for item in json_struct
                        if 'story' in item
                    ]
                    # Skip empty results so no blank lines end up in the JSONL.
                    if story_dicts:
                        outfile.write("\n".join(json.dumps(item) for item in story_dicts) + "\n")

# Run directory produced by the "Execute Batch Jobs" cell; change this to the
# run you want to process.
input_dir = os.path.join('data', 'batches_2024-11-12-17-00-38')
# Top-level .jsonl files only (the prompts/ sub-directory is skipped). Sorted
# because os.listdir order is arbitrary and would make processed.jsonl
# non-deterministic.
input_files = sorted(
    os.path.join(input_dir, name)
    for name in os.listdir(input_dir)
    if name.endswith('.jsonl') and os.path.isfile(os.path.join(input_dir, name))
)
output_file = os.path.join(input_dir, 'processed.jsonl')

format_jsonl(input_files, output_file)

In [None]:
# 3.1: Optionally convert to parquet and permute rows
import pandas as pd

# Shuffle the processed stories with a fixed seed so the permutation is
# reproducible, then save them as Parquet and overwrite the JSONL in place.
SHUFFLE_SEED = 42
# splitext only swaps the final extension, unlike str.replace, which would also
# rewrite ".jsonl" occurring elsewhere in the path.
parquet_file_path = os.path.splitext(output_file)[0] + ".parquet"

df = pd.read_json(output_file, lines=True)
df = df.sample(frac=1, random_state=SHUFFLE_SEED).reset_index(drop=True)
df.to_parquet(parquet_file_path, engine='pyarrow', compression='snappy')
df.to_json(output_file, orient='records', lines=True)

# TODO: Maybe filter out excessively long completions, such as when the end string was not correctly generated. Also, check for non-Latin characters.

4. Next, proceed to either `analyse_dataset.ipynb` or `embeddings.ipynb`.