In [None]:
"""
Generate synthetic data.

Builds a category-membership dataset: an LLM (o4-mini, via
get_llm_responses_openai) invents categories with member and non-member
words; each valid category becomes a shuffled clean word list plus a
corrupted copy in which one category word is swapped for a held-out
non-member word. The samples are written to ./synthetic-data.json.
"""
None  # bare None as the cell's last expression keeps Jupyter from echoing the docstring

In [None]:
import json
import os
import pathlib
import random

import pandas as pd
import torch
from dotenv import load_dotenv

from utils.async_req import get_llm_responses_openai

load_dotenv()

# Seed Python's RNG so the holdout / shuffle / corruption choices made while
# parsing are reproducible (the LLM outputs themselves remain nondeterministic).
SEED = 42
random.seed(SEED)

# Pick the best available torch device: CUDA, then Apple MPS, then CPU.
# torch.backends.mps does not exist on older torch builds, hence the getattr guard.
# NOTE(review): `device` is not used in the visible cells — presumably consumed downstream.
mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = "cuda"
elif mps_backend is not None and mps_backend.is_available():
    device = "mps"
else:
    device = "cpu"

## Generation prompt (sent as the user message)

In [None]:
# create user prompt
# The reply must be a JSON *object* wrapping the array under "categories":
# the request uses response_format json_object and parse_response reads ['categories'].
input_prompt =\
"""
You are a data-generation assistant.

TASK:
Produce a **single JSON object** with one key, "categories", whose value is an array of exactly 50 objects, one for each semantic category you invent.

For each object:
1. Pick a simple, universally-understood category label (e.g. "fruit", "animals" - don't recycle these specific ones).
2. Choose an integer `list_length` uniformly at random in [4, 8].
3. Choose an integer `category_length` uniformly at random in [1, list_length - 1].
4. Let `noncategory_length` = list_length - category_length.
5. Construct:
    - `category_list`: `category_length` distinct single-word items that belong to the category.
    - `noncategory_list`: `noncategory_length` distinct single-word items that do **not** belong to the category. (It is okay if some of these happen to be valid members of *other* categories.)

CONSTRAINTS:
- All words must be lowercase, alphabetic, and contain no spaces or punctuation or special characters.
- No duplicate words within a single object.
- Exactly 50 objects total.
- Vary the categories.
- The 50 objects should collectively cover the full range of `list_length` and `category_length` values.
- Remember category_length + noncategory_length must be between 4 and 8; both category_length and noncategory_length should be nonzero.

OUTPUT FORMAT:
Return valid JSON (double-quoted keys and strings) in exactly the following format:
```
{"categories": [
    {
        "category": "animals",
        "category_length": 3,
        "noncategory_length": 2,
        "category_list": ["cat", "dog", "mouse"],
        "noncategory_list": ["cherry", "bus"]
    },
    ...
]}
```
"""

In [None]:
# send requests
llm_responses = await get_llm_responses_openai(
    [[{'role': 'user', 'content': input_prompt}]] * 100, # Send this 50 times = 2.5k examples
    params = {'model': 'o4-mini-2025-04-16', 'response_format': {'type': 'json_object'}},
    batch_size = 10, # Send 5 async at once, 50/10 = 5 seperate batches
    api_key = os.environ.get('OPENAI_API_KEY')
)

In [None]:
# parse + output data
def make_corrupted(cat_words, non_words, full_list):
    """
    Replace one category word in `full_list` with a non-category distractor.

    NOTE: parse_response below does not call this; it builds its corrupted
    list from a held-out noncategory word instead.

    Args:
        cat_words: words that belong to the category.
        non_words: preferred distractor candidates (non-category words).
        full_list: the word list to corrupt; it is not modified.

    Returns:
        A copy of `full_list` in which one randomly chosen category word is
        replaced by a random distractor that is neither already in the list
        nor a category word.

    Raises:
        IndexError: if `full_list` contains no category word, or no distractor
            is left after filtering (raised by `random.choice` on an empty sequence).
    """
    cat_set = set(cat_words)
    cat_pos = random.choice([i for i, w in enumerate(full_list) if w in cat_set])

    # Candidate replacements: the caller's non-words plus a few generic fallbacks.
    pool = set(non_words) | {"stone", "chair", "lamp", "cloud", "plate"}
    pool -= set(full_list)  # avoid duplicates within the list
    pool -= cat_set         # fallbacks may themselves be category members (e.g. "chair" for furniture)
    repl = random.choice(sorted(pool))  # sorted -> deterministic candidate order under a fixed seed

    corrupted = full_list.copy()
    corrupted[cat_pos] = repl
    return corrupted

def _is_valid_object(obj):
    """
    Return True if one LLM-generated category object is usable.

    Rejects: non-dicts; a non-string 'category'; lengths that are not plain
    ints (bool is excluded even though it subclasses int); lists that are not
    lists of strings; declared lengths that disagree with the lists; totals
    outside [4, 8]; an empty category or noncategory list; duplicate words.
    """
    if not isinstance(obj, dict) or not isinstance(obj.get('category'), str):
        return False

    cat_len = obj.get('category_length')
    non_len = obj.get('noncategory_length')
    cat_lst = obj.get('category_list')
    non_lst = obj.get('noncategory_list')

    if not all(isinstance(n, int) and not isinstance(n, bool) for n in (cat_len, non_len)):
        return False
    if not all(isinstance(lst, list) and all(isinstance(w, str) for w in lst)
               for lst in (cat_lst, non_lst)):
        return False
    if cat_len != len(cat_lst) or non_len != len(non_lst):
        return False
    if not 4 <= cat_len + non_len <= 8:
        return False
    # A corrupted pair needs >= 1 category word to replace and >= 1 noncategory word to hold out
    if cat_len == 0 or non_len == 0:
        return False
    # No duplicate words anywhere in the object
    return len(set(cat_lst + non_lst)) == cat_len + non_len


def parse_response(response):
    """
    Parse one LLM response into shuffled clean/corrupted samples.

    For each valid category object, one noncategory word is held out; the
    remaining words are shuffled together with the category words to form
    `clean_list`, and `corrupt_list` is `clean_list` with one randomly chosen
    category word replaced by the held-out word.

    Args:
        response: chat-completion style dict; the JSON payload is read from
            response['choices'][0]['message']['content'] and must contain a
            'categories' list.

    Returns:
        List of sample dicts. `sample_ix` counts from 0 within this response
        only. Malformed category objects are skipped individually; if the
        payload itself cannot be read or parsed, the error is printed and []
        is returned.
    """
    try:
        raw_content = response['choices'][0]['message']['content'].strip()
        cats = json.loads(raw_content)['categories']
        if not isinstance(cats, list):
            raise TypeError(f"'categories' is {type(cats).__name__}, expected list")
    except Exception as e:
        # Best-effort: one unreadable response should not abort the whole run
        print(f"Parse response error: {e}")
        return []

    processed = []
    for obj in cats:
        # Validate per object so one malformed entry doesn't discard the rest of the response
        if not _is_valid_object(obj):
            continue

        cat_lst = obj['category_list']
        non_lst = obj['noncategory_list']
        cat_set = set(cat_lst)

        # ----- Holdout -----
        held_out_non = random.choice(non_lst)
        clean_non_lst = [w for w in non_lst if w != held_out_non]

        # ----- Shuffle (on a new list; the parsed object is left untouched) -----
        clean_list = cat_lst + clean_non_lst
        random.shuffle(clean_list)
        clean_cat_idx = [i for i, w in enumerate(clean_list) if w in cat_set]

        # ----- Corrupted version: swap one category word for the held-out noncategory word -----
        cat_pos = random.choice(clean_cat_idx)
        corrupt_list = clean_list.copy()
        corrupt_list[cat_pos] = held_out_non
        corrupt_cat_idx = [i for i, w in enumerate(corrupt_list) if w in cat_set]

        processed.append(
            {
                'sample_ix': len(processed),
                'category': obj['category'],
                'list_length': len(clean_list),

                # Clean versions
                'clean_list': clean_list,
                'clean_category_indices': clean_cat_idx,
                'clean_category_count': len(clean_cat_idx),

                # Corrupt versions
                'corrupt_list': corrupt_list,
                'corrupt_category_indices': corrupt_cat_idx,
                'corrupt_category_count': len(corrupt_cat_idx),
            }
        )

    return processed

final_output = []
for response in llm_responses:
    final_output.extend(parse_response(response))

# parse_response numbers samples from 0 within each response; renumber so
# sample_ix is a unique id across the whole dataset.
for global_ix, sample in enumerate(final_output):
    sample['sample_ix'] = global_ix

print(len(final_output))

# Save
path = pathlib.Path('./synthetic-data.json')
path.write_text(json.dumps(final_output, indent = 2), encoding = 'utf-8')
print(f"Wrote {path.resolve()}")

In [None]:
# quick dataset quality checks
final_df = pd.DataFrame(final_output)
# An empty frame has no 'category' column; fail with a clear message instead of a KeyError.
assert not final_df.empty, "no samples survived parsing"

print(f'Samples: {len(final_df)}')
# nunique(dropna=False) matches the old drop_duplicates() count, and double quotes
# avoid reusing the f-string's quote type (a SyntaxError before Python 3.12).
print(f"Unique categories: {final_df['category'].nunique(dropna=False)}")
# Named aggregation: the column is a row count, so name it as such (it was misleadingly 'sum').
display(final_df.groupby('list_length', as_index = False).agg(n_samples = ('category', 'count')))
display(
    final_df
    .groupby(['list_length', 'clean_category_count'], as_index = False)
    .agg(n_samples = ('category', 'count'))
    .sort_values(by = ['list_length', 'clean_category_count'])
)


In [None]:
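# Scratch note (presumably an open issue): a "noncategory" word can still be a
# category member through a broader sense. Below, whippet and pug are dogs, and
# dogs are mammals, so for the category "mammals" the labelled indices [0, 1]
# are incomplete: all four words are mammals.
# dogs: [whippet, pug]
# mammals: [whale, camel]


# list: [whale, camel, whippet, pug]
# mammal_indices: [0, 1]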