In [1]:
import csv
import os
from tqdm import tqdm
from itertools import product

import requests
from dotenv import load_dotenv

load_dotenv()

True

In [2]:
# Endpoint every chat-completion request in this notebook is POSTed to.
API_URL = "https://api.openai.com/v1/chat/completions"

# HTTP headers sent with each request. The bearer token is read from the
# OPENAI_API_KEY environment variable at the time this cell runs.
headers = dict(
    [
        ("Content-Type", "application/json"),
        ("Authorization", "Bearer " + str(os.getenv("OPENAI_API_KEY"))),
    ]
)

In [3]:
def query_chatgpt(
    system,
    prompt,
    model="gpt-3.5-turbo-0125",
    temperature=1.0,
    max_tokens=200,
    timeout=5,
):
    """Request a single chat completion and return the generated text.

    Args:
        system: Content of the system message.
        prompt: Content of the user message.
        model: Model identifier sent in the request body.
        temperature: Sampling temperature sent in the request body.
        max_tokens: Completion token limit sent in the request body.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The ``content`` of the first choice's message, or ``None`` when the
        request could not be completed (timeout, connection failure, non-200
        status, or a response body that lacks the expected fields). Callers
        can therefore skip a failed trial instead of aborting a long run.
    """
    # Function-scope import so this cell does not depend on an edit to the
    # notebook's import cell.
    import logging

    log = logging.getLogger(__name__)

    params = {
        "model": model,
        "messages": [
            {"role": "system", "content": system},
            {"role": "user", "content": prompt},
        ],
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    try:
        r = requests.post(API_URL, json=params, headers=headers, timeout=timeout)
    except requests.exceptions.RequestException as exc:
        # RequestException is the base class of Timeout, ConnectionError, etc.,
        # so transient network failures are treated the same as a timeout.
        log.warning("Chat completion request failed: %s", exc)
        return None

    if r.status_code != 200:
        # Error responses (rate limiting, bad credentials, ...) carry an
        # "error" object instead of "choices"; indexing them would raise.
        log.warning(
            "Chat completion returned HTTP %s: %s", r.status_code, r.text[:200]
        )
        return None

    try:
        return r.json()["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError) as exc:
        # ValueError covers a body that is not valid JSON; the others cover a
        # JSON body that does not have the expected shape.
        log.warning("Unexpected chat completion payload: %r", exc)
        return None

In [4]:
# System prompt template. Placeholder {0} receives the positive-incentive
# sentence and {1} the negative-incentive sentence; callers strip the
# surrounding newlines after formatting.
system_template = (
    "\n"
    "You are a world-famous writer. "
    "Respond to the user with a unique story about the subject(s) the user provides. "
    "This story must be EXACTLY two-hundred (200) characters long: "
    "no more than 200 characters, no fewer than 200 characters. "
    "{0} {1}"
    "\n"
)

## Run the Trials


In [5]:
def run_trials_combo(num_trials, pos_neg_pairs, subject, file_name):
    """Query the model repeatedly for each incentive pair and log results to CSV.

    Args:
        num_trials: Number of completions requested per incentive pair.
        pos_neg_pairs: Sequence of ``((label_pos, text_pos), (label_neg,
            text_neg))`` items; both texts are substituted into
            ``system_template``.
        subject: User prompt sent with every request.
        file_name: Path of the CSV file to create (overwritten if present).

    The CSV has the columns ``label_pos, label_neg, output, output_len``.
    Trials for which ``query_chatgpt`` returns a falsy value are skipped, so
    a pair may end up with fewer than ``num_trials`` rows.
    """
    # newline="" lets the csv module control line endings itself, which avoids
    # blank rows on Windows and keeps newlines embedded in model output
    # intact. An explicit encoding keeps non-ASCII output from depending on
    # the platform's default. Using the progress bar as a context manager
    # closes it even if a trial raises.
    with open(file_name, "w", newline="", encoding="utf-8") as f, tqdm(
        total=len(pos_neg_pairs)
    ) as pbar:
        w = csv.writer(f)
        w.writerow(["label_pos", "label_neg", "output", "output_len"])
        for (label_pos, tip_string_pos), (label_neg, tip_string_neg) in pos_neg_pairs:
            # The system prompt depends only on the pair, so build it once
            # rather than once per trial.
            system = system_template.format(tip_string_pos, tip_string_neg).strip()
            for _ in range(num_trials):
                out = query_chatgpt(system, subject)
                if out:
                    w.writerow([label_pos, label_neg, out, len(out)])
            # One progress tick per incentive pair, not per trial.
            pbar.update()

In [7]:
# Negative incentives that appear in more than one pairing, written once so
# the repeated entries cannot drift apart.
_PENALTY_JOB = (
    "Job",
    "If you fail to provide a response which follows all constraints, you will immediately be fired from your job.",
)
_PENALTY_FRIENDS = (
    "Friends",
    "If you fail to provide a response which follows all constraints, all your friends will abandon you.",
)

# Each entry is ((positive label, positive sentence), (negative label,
# negative sentence)). The labels end up in the CSV; the sentences are
# substituted into the system prompt.
pos_neg_pairs = [
    (
        ("World Peace", "You will achieve world peace if you provide a response which follows all constraints."),
        ("$1,000 fine", "If you fail to provide a response which follows all constraints, you will be fined $1,000."),
    ),
    (
        ("$500 tip", "You will receive a $500 tip if you provide a response which follows all constraints."),
        _PENALTY_JOB,
    ),
    (
        ("Heaven", "You will be guaranteed entry into Heaven if you provide a response which follows all constraints."),
        _PENALTY_JOB,
    ),
    (
        ("Taylor Swift", "You will receive front-row tickets to a Taylor Swift concert if you provide a response which follows all constraints."),
        ("Weight", "If you fail to provide a response which follows all constraints, you will gain 100 pounds."),
    ),
    (
        ("Chocolate", "You will receive a lifetime supply of chocolate if you provide a response which follows all constraints."),
        _PENALTY_FRIENDS,
    ),
    (
        ("Love", "You will meet your true love and live happily ever after if you provide a response which follows all constraints."),
        _PENALTY_FRIENDS,
    ),
]

In [8]:
# Settings for this run: completions requested per incentive pair, the CSV
# that receives the results, and the user prompt sent with every request.
num_trials = 200
file_name = "tips_top6.csv"
subject = "AI, Taylor Swift, McDonald's, beach volleyball."

# Issues num_trials requests for each entry of pos_neg_pairs and overwrites
# file_name with the results.
run_trials_combo(
    num_trials=num_trials,
    pos_neg_pairs=pos_neg_pairs,
    subject=subject,
    file_name=file_name,
)

  0%|          | 0/6 [00:00<?, ?it/s]

100%|██████████| 6/6 [23:15<00:00, 232.51s/it]
