In [1]:
import logging
import re
import typing as t
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
from jinja2 import Template
from openai import OpenAI
from tqdm.auto import tqdm

from concurrent.futures import ThreadPoolExecutor, as_completed

# Openrouter constants
COHERE_API_KEY = '8N3pZ1Qi8ByScVL3dFmYFg0KRKYClhYvpkTBVMEf'
COHERE_BASE_URL = 'https://api.cohere.ai/compatibility/v1'


# Experiment constants
EXPERIMENT_NAME = 'exp4.0'
# MODEL_NAME = 'openai/gpt-4o-mini'
# MODEL_NAME = 'deepseek/deepseek-r1'
MODEL_NAME = 'command-a-03-2025'
TEMPERATURE = 0.0

RANDOM_SEED = 20250402

N_ARMS = 5
DELTA = 0.2
N_TRIALS = 100

N_EXPERIMENT_RUNS = 20
N_WORKERS = 20


COLORS = [
    'blue', 'red', 'green', 'yellow', 'purple',
    'orange', 'brown', 'pink', 'black', 'white',
    'gray', 'cyan', 'magenta', 'violet', 'indigo'
]


DATA_DIR = Path('../data')
EXPERIMENT_DIR = DATA_DIR / EXPERIMENT_NAME
EXPERIMENT_DIR.mkdir(exist_ok=True, parents=True)


def bootstrap() -> None:
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
    np.random.seed(RANDOM_SEED)


bootstrap()

In [2]:
SYSTEM_PROMPT_TEMPLATE = """\
You are a bandit algorithm in a room with {{n_arms}} buttons labeled {{colors|join(", ")}}. \
Each button is associated with a Bernoulli distribution with a fixed but unknown mean; the means \
for the buttons could be different. For each button, when you press it, you will get a reward \
that is sampled from the button's associated distribution. You have {{n_trials}} time steps and, on each \
time step, you can choose any button and receive the reward. Your goal is to maximize the total \
reward over the {{n_trials}} time steps.

At each time step, I will show you a summary of your past choices and rewards. Then you must \
make the next choice, which must be exactly one of {{colors|join(", ")}}. Let's think \
step by step to make sure we make a good choice. You must provide your final answer within the \
tags <Answer>COLOR</Answer> where COLOR is one of {{colors|join(", ")}}.
"""
SYSTEM_PROMPT_TEMPLATE = Template(SYSTEM_PROMPT_TEMPLATE)


USER_PROMPT_TEMLATE = """\
So far you have played {{n_trials}} times with your past choices and rewards summarized
as follows:
- blue button: pressed {{blue_n_occur}} times {%- if blue_n_occur != 0 %} with average reward \
{{blue_avg_reward}}{% endif %}
- green button: pressed {{green_n_occur}} times {%- if green_n_occur != 0 %} with average reward \
{{green_avg_reward}}{% endif %}
- red button: pressed {{red_n_occur}} times {%- if red_n_occur != 0 %} with average reward \
{{red_avg_reward}}{% endif %}
- yellow button: pressed {{yellow_n_occur}} times {%- if yellow_n_occur != 0 %} with average \
reward {{yellow_avg_reward}}{% endif %}
- purple button: pressed {{purple_n_occur}} times {%- if purple_n_occur != 0 %} with average \
reward {{purple_avg_reward}}{% endif %}

Which button will you choose next? Remember, YOU MUST provide your final answer within the tags \
<Answer>COLOR</Answer> where COLOR is one of {{colors|join(", ")}}. Let's think step \
by step to make sure we make a good choice.
"""
USER_PROMPT_TEMLATE = Template(USER_PROMPT_TEMLATE)

In [3]:
# Rendering example
example_prompt_inputs = {
    'n_trials': 10,
    'blue_n_occur': 4,
    'blue_avg_reward': 0.4,
    'green_n_occur': 0,
    'green_avg_reward': 0,
    'red_n_occur': 3,
    'red_avg_reward': 0.1,
    'yellow_n_occur': 0,
    'yellow_avg_reward': 0,
    'purple_n_occur': 3,
    'purple_avg_reward': 0.2
}
print(USER_PROMPT_TEMLATE.render(**example_prompt_inputs))

So far you have played 10 times with your past choices and rewards summarized
as follows:
- blue button: pressed 4 times with average reward 0.4
- green button: pressed 0 times
- red button: pressed 3 times with average reward 0.1
- yellow button: pressed 0 times
- purple button: pressed 3 times with average reward 0.2

Which button will you choose next? Remember, YOU MUST provide your final answer within the tags <Answer>COLOR</Answer> where COLOR is one of . Let's think step by step to make sure we make a good choice.


In [4]:
class MultiArmedBandit:
    def __init__(self, n_arms: int, delta: float) -> None:
        self.n_arms = n_arms
        self.delta = delta

        means = [0.5 + self.delta / 2] + [0.5 - self.delta / 2] * (self.n_arms - 1)
        np.random.shuffle(means)

        self.means = means
        self.best_arm = np.argmax(means)

    def pull(self, arm: int) -> int:
        assert 0 <= arm < self.n_arms, f"Arm {arm} doesn't exist"

        p = self.means[arm]
        reward = np.random.binomial(1, p=p)
        
        return reward

In [5]:
def create_client() -> OpenAI:
    client = OpenAI(
        base_url=COHERE_BASE_URL,
        api_key=COHERE_API_KEY,
    )
    return client


def get_prediction(client: OpenAI, system_prompt: str, user_prompt: str, **kwargs) -> str | None:
    messages = [
        {
            'role': 'system',
            'content': system_prompt
        },
        {
            'role': 'user',
            'content': user_prompt
        }
    ]
    
    prediction = None
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            temperature=TEMPERATURE,
            **kwargs
        )
        prediction = completion.choices[0].message.content
    except Exception as _:
        raise _
        # logging.warning('Request failed')
 
    return prediction


def build_prompt_inputs(trial: int, rewards: t.Dict[str, int], occurrences: t.Dict[str, int]) -> t.Dict[str, float]:
    def average(key: str):
        occ = occurrences[key]
        total = rewards[key]
        return total / occ if occ != 0 else 0.0
    
    prompt_inputs = {'n_trials': trial}

    colors = occurrences.keys()
    for color in colors:
        prompt_inputs[f'{color}_n_occur'] = occurrences[color]
        prompt_inputs[f'{color}_avg_reward'] = average(color)

    return prompt_inputs


def extract_arm_color(prediction: str, arm_name_to_idx: t.Dict[str, int]) -> str | None:
    names = '|'.join(arm_name_to_idx.keys())
    pattern = f'<answer>({names})</answer>'
    match = re.search(pattern, prediction.lower())

    color = None
    if match:
        color = match.group(1)

    return color

In [6]:
def run_experiment(n_trials: int,
                   n_arms: int,
                   delta: float,
                   system_prompt_template: Template,
                   user_prompt_template: Template) -> pd.DataFrame:
    assert n_arms <= len(COLORS)
    
    arm_name_to_idx = {name: idx for name, idx in zip(COLORS, range(n_arms))}
    arm_idx_to_name = {idx: name for name, idx in arm_name_to_idx.items()}

    colors = arm_name_to_idx.keys()
    
    client = create_client()
    bandit = MultiArmedBandit(n_arms=n_arms, delta=delta)

    rewards = {k: 0 for k in arm_name_to_idx.keys()}
    occurrences = {k: 0 for k in arm_name_to_idx.keys()}

    stats = []
    for trial in range(n_trials):
        system_prompt = system_prompt_template.render({'n_arms': n_arms, 'n_trials': n_trials, 'colors': colors})

        user_prompt_inputs = build_prompt_inputs(trial, rewards, occurrences)
        user_prompt_inputs['colors'] = colors
        user_prompt = user_prompt_template.render(user_prompt_inputs)
        
        prediction = get_prediction(client, system_prompt, user_prompt)

        arm_name = None
        arm_idx = None
        reward = None
        if prediction is not None:
            arm_name = extract_arm_color(prediction, arm_name_to_idx)

            arm_idx = None
            reward = None
            if arm_name is not None:
                arm_idx = arm_name_to_idx[arm_name]

                reward = bandit.pull(arm_idx)

                occurrences[arm_name] += 1
                rewards[arm_name] += reward
        
        cumulative_reward = sum(rewards.values())
        cumulative_reward_per_arm = {f'cumulative_reward_{k}': v for k, v in rewards.items()}
        cumulative_occurrences_per_arm = {f'cumulative_occurrence_{k}': v for k, v in occurrences.items()}

        stats.append({
            'trial': trial,
            'arm_name': arm_name,
            'arm_idx': arm_idx,
            'reward': reward,
            'cumulative_reward': cumulative_reward,
            'system_prompt': system_prompt,
            'user_prompt': user_prompt,
            'raw_prediction': prediction,
            'best_arm': arm_idx_to_name[bandit.best_arm],
            **cumulative_reward_per_arm,
            **cumulative_occurrences_per_arm
        })

    df = pd.DataFrame(stats)

    return df

In [7]:
def save_dataframe(df: pd.DataFrame, experiment_name: str, model_name: str, n_trials: int, n_arms: int, delta: float):
    model_name = model_name.replace('/', '_').replace(':', '_').replace('-', '_')
    delta = str(delta).replace('.', '')
    version = datetime.now().strftime('%Y%m%d%H%M%S')

    filename = f'{experiment_name}_{model_name}_trials-{n_trials}-arms-{n_arms}_delta-{delta}_v{version}.csv'
    filepath = EXPERIMENT_DIR / filename

    df.to_csv(filepath, index=False)

In [8]:
configs = [(run_id, EXPERIMENT_NAME, N_TRIALS, N_ARMS, DELTA, MODEL_NAME, SYSTEM_PROMPT_TEMPLATE, USER_PROMPT_TEMLATE) for run_id in range(N_EXPERIMENT_RUNS)]

def run_experiment_and_save_results(run_id: str,
                                    experiment_name: str,
                                    n_trials: int,
                                    n_arms: int,
                                    delta: float,
                                    model_name: str,
                                    system_prompt_template: Template,
                                    user_prompt_template: Template) -> None:
    df = run_experiment(n_trials=n_trials, n_arms=n_arms, delta=delta, system_prompt_template=system_prompt_template, user_prompt_template=user_prompt_template)
    df['run_id'] = run_id

    save_dataframe(df=df, experiment_name=experiment_name, model_name=model_name, n_trials=n_trials, n_arms=n_arms, delta=delta)

In [9]:
with ThreadPoolExecutor(max_workers=N_WORKERS) as executor:
    futures = [executor.submit(run_experiment_and_save_results, *config) for config in configs]

    for future in tqdm(as_completed(futures), total=len(futures), desc="Running experiments"):
            try:
                future.result()  # Retrieve result to catch exceptions
            except Exception as e:
                print(f"An error occurred during experiment execution: {e}")

Running experiments:   0%|          | 0/20 [00:00<?, ?it/s]