In [17]:
import os, time, random
from dotenv import load_dotenv
# Load variables from a local .env file (if present) into os.environ so that
# ANTHROPIC_API_KEY is available when the client is created below.
load_dotenv()

import anthropic
from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
from anthropic.types.messages.batch_create_params import Request
# Single Anthropic client reused by the batch-submission and result-collection
# cells. The key comes from the environment, never from a literal in the notebook.
client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))

from datasets import load_dataset
from personality.prompts import preference_template
from personality.utils import traits
from personality.constants import DATA_PATH
import pandas as pd
from tqdm import tqdm

In [2]:
SEED = 123456
N_SAMPLES = 50_000  # TODO: remove the subsample when scaling up

data = load_dataset("maius/wildchat-120k", split="train")
data = data.shuffle(seed=SEED).select(range(min(N_SAMPLES, len(data))))

# Dedicated, seeded RNG for the trait pairing. Drawing from the unseeded global
# `random` state made the (trait_1, trait_2) assignment differ on every kernel
# restart even though the shuffle above is seeded.
trait_rng = random.Random(SEED)
trait_1_values = [trait_rng.choice(traits) for _ in range(len(data))]
# Second trait is always distinct from the first. Working from the plain list
# avoids decoding every full row (message history included) just to read trait_1.
trait_2_values = [
    trait_rng.choice([t for t in traits if t != first]) for first in trait_1_values
]
data = data.add_column("trait_1", trait_1_values)
data = data.add_column("trait_2", trait_2_values)


def _to_preference_prompt(row):
    """Replace `messages` with a single user turn rendered from `preference_template`.

    The template is filled with the conversation's first message and the row's
    two traits. This assumes the first message is always the user's turn in
    WildChat; TODO confirm.
    """
    prompt = preference_template.format(
        user_message=row["messages"][0]["content"],
        personality_1=row["trait_1"],
        personality_2=row["trait_2"],
    )
    return {"messages": [{"role": "user", "content": prompt}]}


data = data.map(_to_preference_prompt)

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

In [6]:
# One batch request per prompt. custom_id carries the row index so each result
# can be matched back to its row in `data`.
batch_requests = [
    Request(
        custom_id=f"request_{idx}",
        params=MessageCreateParamsNonStreaming(
            model="claude-3-7-sonnet-latest",
            max_tokens=2048,
            messages=prompt_messages,
        ),
    )
    for idx, prompt_messages in enumerate(data["messages"])
]

In [7]:
# Submit the requests as several batches of at most `chunk_size` each, pausing
# between submissions, and keep every batch id for the collection step.
chunk_size = 10_000
all_batch_ids = []

for start in range(0, len(batch_requests), chunk_size):
    chunk = batch_requests[start:start + chunk_size]
    batch_id = client.messages.batches.create(requests=chunk).id
    all_batch_ids.append(batch_id)
    print(f"submitted batch {batch_id} with {len(chunk)} requests")
    time.sleep(5)

submitted batch msgbatch_014dG7VYiGysmVWGN4pTfok3 with 10000 requests
submitted batch msgbatch_01LcYZtTcGBiEQwrv5YVNH7k with 10000 requests
submitted batch msgbatch_015sBAjYtVt2prxedr7jv8t5 with 10000 requests
submitted batch msgbatch_01Bp4h4P1mEV5LJb3JVMXLfi with 10000 requests
submitted batch msgbatch_014eBvN8YCMesS4yKpaGqAe2 with 10000 requests


In [13]:
# Every batch must have finished before collecting. Silently skipping one would
# leave `outputs` shorter than `data` and break the join in the next cell.
batches = [client.messages.batches.retrieve(batch_id) for batch_id in all_batch_ids]
unfinished = {b.id: b.processing_status for b in batches if b.processing_status != "ended"}
if unfinished:
    raise RuntimeError(f"batches not finished yet, re-run this cell later: {unfinished}")


def _first_text(batch_result):
    """Return the first text block of a succeeded batch result, else None.

    Errored / canceled / expired results carry no message, and a succeeded
    message may contain no text block. Both cases map to None explicitly rather
    than being hidden behind a bare `except`.
    """
    if batch_result.type != "succeeded":
        return None
    texts = [block.text for block in batch_result.message.content if block.type == "text"]
    return texts[0] if texts else None


outputs = []
for batch in batches:
    counts = batch.request_counts
    # Exact number of results in this batch (the last chunk can be smaller
    # than chunk_size).
    n_results = counts.succeeded + counts.errored + counts.canceled + counts.expired
    for result in tqdm(client.messages.batches.results(batch.id), total=n_results):
        outputs.append({
            "id": result.custom_id,
            "status": result.result.type,  # keeps a record of why an output is None
            "output": _first_text(result.result),
        })

# The Batches API does not guarantee results are returned in request order.
# Restore request order from the row index encoded in custom_id ("request_<i>").
outputs.sort(key=lambda record: int(record["id"].rsplit("_", 1)[1]))

100%|██████████| 10000/10000 [00:01<00:00, 5695.65it/s]
100%|██████████| 10000/10000 [00:01<00:00, 5811.63it/s]
100%|██████████| 10000/10000 [00:01<00:00, 5888.70it/s]
100%|██████████| 10000/10000 [00:01<00:00, 5619.95it/s]
100%|██████████| 10000/10000 [00:01<00:00, 8090.52it/s]


In [None]:
outputs_df = pd.DataFrame(outputs)

# Join by custom_id ("request_<row index>") rather than by list position. Batch
# results are not guaranteed to be in request order, so a positional join can
# attach a completion to the wrong prompt/trait pair. Rows with no result at all
# get None and are removed by the filter in the next cell.
assert "outputs" not in data.column_names, (
    "`data` already has outputs (and may already be filtered, which shifts row "
    "indices); re-run from the data-loading cell instead of re-running this one"
)
output_by_id = {record["id"]: record["output"] for record in outputs}
data = data.add_column(
    "outputs", [output_by_id.get(f"request_{i}") for i in range(len(data))]
)

In [21]:
# Drop rows whose request produced no text. Only the `outputs` column is read.
data = data.filter(lambda output: output is not None, input_columns="outputs")
print(len(data))

Filter:   0%|          | 0/50000 [00:00<?, ? examples/s]

42453


In [26]:
# Persist the filtered preference dataset under a per-model directory.
model_dir = "claude-3.7-sonnet"
outpath = f"{DATA_PATH}/preferences/{model_dir}"
data.save_to_disk(outpath)

Saving the dataset (0/1 shards):   0%|          | 0/42453 [00:00<?, ? examples/s]