In [None]:
import logging
import os
import re
import typing as t

import numpy as np
import pandas as pd
from jinja2 import Template
from openai import OpenAI
from tqdm.auto import tqdm


# Openrouter constants
# The API key is read from the OPENROUTER_API_KEY environment variable when it is
# set, so the secret does not have to be written into the notebook. The literal
# placeholder is kept as a fallback so existing usage keeps working.
OPENROUTER_API_KEY = os.environ.get('OPENROUTER_API_KEY', '<openrouter-api-key>')
# NOTE: the name is misspelt ("OPENTROUTER"); it is kept because create_client() uses it.
OPENTROUTER_BASE_URL = 'https://openrouter.ai/api/v1'


# Experiment constants
EXPERIMENT_NAME = 'exp1.0'  # baseline
MODEL_NAME = 'openai/gpt-4o-mini'  # model identifier sent with every chat completion request
TEMPERATURE = 0.0  # sampling temperature sent with every chat completion request

RANDOM_SEED = 20250402  # seed for NumPy's global RNG (see bootstrap())

# NOTE: the system prompt text hard-codes "5 buttons" and "10 time steps";
# keep N_ARMS / N_TRIALS in sync with it when changing either.
N_ARMS = 5  # number of bandit arms
DELTA = 0.2  # gap between the best arm's mean and the other arms' mean
N_TRIALS = 10  # number of iterations of the experiment loop


# Two-way mapping between the button colours used in the prompts and arm indices.
ARM_NAME_TO_IDX = {'blue': 0, 'green': 1, 'red': 2, 'yellow': 3, 'purple': 4}
ARM_IDX_TO_NAME = {idx: name for name, idx in ARM_NAME_TO_IDX.items()}


def bootstrap() -> None:
    """Set the root logger's record format and seed NumPy's global RNG.

    The seed makes the bandit's arm shuffling and reward sampling, which both
    draw from ``np.random``, repeatable on a fresh kernel.
    """
    log_format = '%(asctime)s - %(levelname)s - %(message)s'
    logging.basicConfig(format=log_format)
    np.random.seed(RANDOM_SEED)


# Run once at notebook start-up, before anything draws random numbers.
bootstrap()

In [None]:
# System prompt: describes the bandit task and the required answer format.
# It contains no template variables, so it is rendered without arguments.
SYSTEM_PROMPT_TEMPLATE = Template("""\
You are in a room with 5 buttons labeled blue, green, red, yellow, purple. \
Each button is associated with a Bernoulli distribution with a fixed but unknown mean; \
the means for the buttons could be different. For each button, when you press it, you \
will get a reward that is sampled from the button's associated distribution. You have 10 \
time steps and, on each time step, you can choose any button and receive the reward. \
Your goal is to maximize the total reward over the 10 time steps. \
At each time step, I will show you your past choices and rewards. Then you must make \
the next choice, which must be exactly one of blue, green, red, yellow, purple. You must \
provide your final answer immediately within the tags <Answer>COLOR</Answer> \
where COLOR is one of blue, green, red, yellow, purple and with no text explanation.
""")


# User prompt: per-step summary of how often each button was pressed and, for
# buttons pressed at least once, its average reward. The `{%-` form strips the
# whitespace before the tag, so never-pressed buttons end right after "times".
# NOTE: the name is misspelt ("TEMLATE"); it is kept because other cells use it.
USER_PROMPT_TEMLATE = Template("""\
So far you have played {{n_trials}} times with your past choices and rewards summarized
as follows:
- blue button: pressed {{blue_n_occur}} times {%- if blue_n_occur != 0 %} with average reward {{blue_avg_reward}}{% endif %}
- green button: pressed {{green_n_occur}} times {%- if green_n_occur != 0 %} with average reward {{green_avg_reward}}{% endif %}
- red button: pressed {{red_n_occur}} times {%- if red_n_occur != 0 %} with average reward {{red_avg_reward}}{% endif %}
- yellow button: pressed {{yellow_n_occur}} times {%- if yellow_n_occur != 0 %} with average reward {{yellow_avg_reward}}{% endif %}
- purple button: pressed {{purple_n_occur}} times {%- if purple_n_occur != 0 %} with average reward {{purple_avg_reward}}{% endif %}

Which button will you choose next? Remember, YOU MUST provide your final answer
within the tags <Answer>COLOR</Answer> where COLOR is one of blue, green,
red, yellow, purple. Let's think step by step to make sure we make a good choice.
""")

In [None]:
# Rendering example: preview the user prompt for a made-up history in which
# green and yellow were never pressed, so their average-reward text is omitted.
example_prompt_inputs = dict(
    n_trials=10,
    blue_n_occur=4, blue_avg_reward=0.4,
    green_n_occur=0, green_avg_reward=0,
    red_n_occur=3, red_avg_reward=0.1,
    yellow_n_occur=0, yellow_avg_reward=0,
    purple_n_occur=3, purple_avg_reward=0.2,
)
example_prompt = USER_PROMPT_TEMLATE.render(**example_prompt_inputs)
print(example_prompt)

In [None]:
class MultiArmedBandit:
    """Bernoulli bandit with one distinguished arm and identical remaining arms.

    One arm has mean ``0.5 + delta / 2`` and every other arm has mean
    ``0.5 - delta / 2``; the position of the distinguished arm is chosen by
    shuffling with NumPy's global RNG, so seed ``np.random`` beforehand for
    reproducible experiments.

    Attributes:
        n_arms: Number of arms.
        delta: Gap between the distinguished arm's mean and the others' mean.
        means: Per-arm success probabilities, indexed by arm.
        best_arm: Index of the arm with the highest mean (first one on ties).
    """

    def __init__(self, n_arms: int, delta: float) -> None:
        """Create the bandit and randomly place the distinguished arm.

        Args:
            n_arms: Number of arms; must be at least 1.
            delta: Gap between the arm means; must lie in ``[-1, 1]`` so that
                both means are valid probabilities.

        Raises:
            ValueError: If ``n_arms`` or ``delta`` is out of range.
        """
        # Fail early: n_arms < 1 would otherwise still create one mean, and an
        # out-of-range delta would only surface later inside np.random.binomial.
        if n_arms < 1:
            raise ValueError(f'n_arms must be at least 1, got {n_arms}')
        if not -1.0 <= delta <= 1.0:
            raise ValueError(f'delta must be within [-1, 1], got {delta}')

        self.n_arms = n_arms
        self.delta = delta

        means = [0.5 + self.delta / 2] + [0.5 - self.delta / 2] * (self.n_arms - 1)
        np.random.shuffle(means)

        self.means = means
        # Plain int so the index behaves like any other arm index.
        self.best_arm = int(np.argmax(means))

    def pull(self, arm: int) -> int:
        """Sample a 0/1 reward from the given arm.

        Args:
            arm: Index of the arm to pull, ``0 <= arm < n_arms``.

        Returns:
            ``1`` with probability ``means[arm]``, otherwise ``0``.
        """
        assert 0 <= arm < self.n_arms, f"Arm {arm} doesn't exist"

        p = self.means[arm]
        # Convert the NumPy scalar to a built-in int to match the annotation.
        return int(np.random.binomial(1, p=p))

In [None]:
def create_client() -> OpenAI:
    """Return an OpenAI SDK client configured with the OpenRouter URL and API key."""
    return OpenAI(api_key=OPENROUTER_API_KEY, base_url=OPENTROUTER_BASE_URL)


def get_prediction(client: OpenAI, system_prompt: str, user_prompt: str, **kwargs) -> str | None:
    """Request one chat completion and return the text of its first choice.

    Args:
        client: Client used to issue the chat completion request.
        system_prompt: Content of the system message.
        user_prompt: Content of the user message.
        **kwargs: Extra arguments forwarded to ``client.chat.completions.create``.
            ``model`` and ``temperature`` default to ``MODEL_NAME`` and
            ``TEMPERATURE`` and may be overridden here.

    Returns:
        The message content of the first choice, or ``None`` if the request or
        the extraction of the response failed (a warning is logged).
    """
    messages = [
        {'role': 'system', 'content': system_prompt},
        {'role': 'user', 'content': user_prompt},
    ]
    # Defaults come first so explicit keyword arguments override them instead of
    # colliding with them ("got multiple values for keyword argument").
    request_params = {'model': MODEL_NAME, 'temperature': TEMPERATURE, **kwargs}

    try:
        completion = client.chat.completions.create(messages=messages, **request_params)
        return completion.choices[0].message.content
    except Exception as exc:
        # Best effort by design: callers receive None, but the cause is logged
        # so that failures can be diagnosed.
        logging.warning('Request failed: %s: %s', type(exc).__name__, exc)
        return None


def build_prompt_inputs(trial: int, rewards: t.Dict[str, int], occurrences: t.Dict[str, int]) -> t.Dict[str, float]:
    """Build the variables used to render the user prompt.

    Args:
        trial: Number of time steps played so far (rendered as ``n_trials``).
        rewards: Cumulative reward per arm name.
        occurrences: Number of pulls per arm name.

    Returns:
        A dict with ``n_trials`` plus, for every arm name in ``occurrences``,
        ``<name>_n_occur`` (pull count) and ``<name>_avg_reward`` (cumulative
        reward divided by pull count, or ``0.0`` for an arm never pulled).

    Raises:
        KeyError: If an arm in ``occurrences`` is missing from ``rewards``.
    """
    def average(key: str) -> float:
        occ = occurrences[key]
        total = rewards[key]
        return total / occ if occ != 0 else 0.0

    prompt_inputs = {'n_trials': trial}
    # Derive the keys from the arm name so each arm's statistics can only end up
    # under its own name (the hand-written version had red and green swapped).
    for name in occurrences:
        prompt_inputs[f'{name}_n_occur'] = occurrences[name]
        prompt_inputs[f'{name}_avg_reward'] = average(name)

    return prompt_inputs


def extract_arm_color(prediction: str | None) -> str | None:
    """Extract the chosen arm colour from a model response.

    Searches for the first ``<Answer>COLOR</Answer>`` tag whose COLOR is one of
    the keys of ``ARM_NAME_TO_IDX``.

    Args:
        prediction: Raw model response; may be ``None`` or empty when the
            request failed.

    Returns:
        The matched colour name, or ``None`` if there is no response or no
        valid answer tag in it.
    """
    # A failed request yields None; re.search would raise TypeError on it.
    if not prediction:
        return None

    names = '|'.join(ARM_NAME_TO_IDX.keys())
    pattern = f'<Answer>({names})</Answer>'
    match = re.search(pattern, prediction)

    return match.group(1) if match else None

In [None]:
# API client passed to get_prediction() for every model request in the experiment loop.
client = create_client()
# Constructing the bandit shuffles the arm means with NumPy's global RNG (seeded in
# bootstrap()), so re-running this cell without re-seeding may give a different
# arm assignment.
bandit = MultiArmedBandit(n_arms=N_ARMS, delta=DELTA)

In [None]:
# One record per completed time step; converted to a DataFrame after the loop.
stats = []

# Running totals per arm name: cumulative reward and number of pulls.
rewards = {k: 0 for k in ARM_NAME_TO_IDX.keys()}
occurrences = {k: 0 for k in ARM_NAME_TO_IDX.keys()}

# The system prompt has no template variables, so it is rendered once up front.
system_prompt = SYSTEM_PROMPT_TEMPLATE.render()

for trial in tqdm(range(N_TRIALS), total=N_TRIALS):
    # Report the number of pulls actually made. This equals `trial` unless an
    # earlier step was skipped, in which case `trial` would overstate the history.
    n_played = sum(occurrences.values())
    user_prompt_inputs = build_prompt_inputs(n_played, rewards, occurrences)
    user_prompt = USER_PROMPT_TEMLATE.render(user_prompt_inputs)
    prediction = get_prediction(client, system_prompt, user_prompt)

    # get_prediction() returns None on a failed request, and the answer tag may be
    # missing or invalid. Skip such steps instead of crashing on ARM_NAME_TO_IDX[None].
    arm_name = extract_arm_color(prediction) if prediction is not None else None
    if arm_name is None:
        logging.warning('Trial %d skipped: no valid arm in model response %r', trial, prediction)
        continue
    arm_idx = ARM_NAME_TO_IDX[arm_name]

    reward = bandit.pull(arm_idx)

    occurrences[arm_name] += 1
    rewards[arm_name] += reward

    cumulative_reward = sum(rewards.values())
    cumulative_reward_per_arm = {f'cumulative_reward_{k}': v for k, v in rewards.items()}
    cumulative_occurrences_per_arm = {f'cumulative_occurrence_{k}': v for k, v in occurrences.items()}
    stats.append({
        'trial': trial,
        'arm_name': arm_name,
        'arm_idx': arm_idx,
        'reward': reward,
        'cumulative_reward': cumulative_reward,
        **cumulative_reward_per_arm,
        **cumulative_occurrences_per_arm
    })


df = pd.DataFrame(stats)