**Description**: Pack a batch of inputs into a single LLM API request and get a batch of
outputs back. (TODO: stream outputs back.)

For the [Banking 77 task](https://huggingface.co/datasets/PolyAI/banking77), packing
reduced total processing time from 218 sec to 125 sec, i.e., about 75% higher throughput
(reported rates: 1.3 -> 2.4 predictions/sec). The cost is higher latency /
time-to-first-completion, because each request has to process a bigger prompt.
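
As a sanity check, the throughput numbers can be recomputed from the wall-clock times alone. This is a minimal sketch that assumes the 300-input subsample used later in this notebook. The overall average for the unbatched run comes out a bit higher than the 1.3 figure, which presumably comes from the rate shown by the progress bar.

```python
n_inputs = 300  # size of the subsample used in this notebook
seconds_unbatched = 218  # 3 min 38 sec
seconds_batched = 125  # 2 min 5 sec

# Average throughput over the whole run, in predictions per second
throughput_unbatched = n_inputs / seconds_unbatched
throughput_batched = n_inputs / seconds_batched

# Relative speedup; algebraically the same as 218 / 125
speedup = throughput_batched / throughput_unbatched

print(f"{throughput_unbatched:.2f} -> {throughput_batched:.2f} predictions/sec")
print(f"throughput increase: {speedup - 1:.0%}")
```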

It's not clear that packing universally improves throughput. I'm pretty sure the
trade-off is better when completions are small: flash attention makes big prompts cheap
to process, but decoding is more expensive per token. This can be evaluated.

The cost might be accuracy. At a low level, the ideal is to modify the attention mask so
that each input does not attend to previous inputs. It's not clear whether packing the
extra inputs hurts; it's hard to argue whether that context is relevant or irrelevant. I
hypothesize worse accuracy because the task is now more structured and the other inputs
are arguably irrelevant. This needs lots of experiments. For the Banking 77 task, when I
intentionally didn't set the `seed` for subsampling, packing lowered accuracy by
<span style="color:red">1-3 percentage points</span> (e.g., 0.71 -> 0.68).
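
To make the attention-mask idea concrete, here is a minimal sketch of a block-diagonal causal mask for packed inputs. It is only an illustration: it presumably applies only when you control the model, since a hosted API does not expose the mask. The function name `packed_attention_mask` and the `shared_prefix_length` parameter (tokens of the shared instruction that every input may still attend to) are hypothetical.

```python
import numpy as np


def packed_attention_mask(
    segment_lengths: list[int], shared_prefix_length: int = 0
) -> np.ndarray:
    """Boolean mask where entry [i, j] is True if token i may attend to token j."""
    # Label each token with its segment index; the shared prefix gets label -1
    segment_ids = np.concatenate(
        [
            np.full(shared_prefix_length, -1),
            np.repeat(np.arange(len(segment_lengths)), segment_lengths),
        ]
    )
    n = len(segment_ids)
    # Tokens may attend within their own segment...
    same_segment = segment_ids[:, None] == segment_ids[None, :]
    # ...and to the shared prefix (e.g., the instruction)
    key_is_prefix = segment_ids[None, :] == -1
    # Standard causal constraint: no attending to future positions
    causal = np.tril(np.ones((n, n), dtype=bool))
    return (same_segment | key_is_prefix) & causal


# One prefix token, then two packed inputs of lengths 2 and 3
mask = packed_attention_mask([2, 3], shared_prefix_length=1)
print(mask.astype(int))
```

With this mask, the first token of the second input sees the prefix and itself, but nothing from the first input.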

**Estimated run time**: ~6 min.

**Estimated dollar cost**: <$0.50.

**Related work**: [paper from 2023](https://arxiv.org/abs/2301.08721). I haven't read it
yet.

In [1]:
from typing import Literal, TypeAlias

from datasets import load_dataset
import openai
import polars as pl
from pydantic import BaseModel
from tqdm.auto import tqdm

import batch

# Set up task

In [2]:
df = pl.DataFrame(load_dataset("PolyAI/banking77", split="train").to_pandas())

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [3]:
original_class_names = [
    "Refund_not_showing_up",
    "activate_my_card",
    "age_limit",
    "apple_pay_or_google_pay",
    "atm_support",
    "automatic_top_up",
    "balance_not_updated_after_bank_transfer",
    "balance_not_updated_after_cheque_or_cash_deposit",
    "beneficiary_not_allowed",
    "cancel_transfer",
    "card_about_to_expire",
    "card_acceptance",
    "card_arrival",
    "card_delivery_estimate",
    "card_linking",
    "card_not_working",
    "card_payment_fee_charged",
    "card_payment_not_recognised",
    "card_payment_wrong_exchange_rate",
    "card_swallowed",
    "cash_withdrawal_charge",
    "cash_withdrawal_not_recognised",
    "change_pin",
    "compromised_card",
    "contactless_not_working",
    "country_support",
    "declined_card_payment",
    "declined_cash_withdrawal",
    "declined_transfer",
    "direct_debit_payment_not_recognised",
    "disposable_card_limits",
    "edit_personal_details",
    "exchange_charge",
    "exchange_rate",
    "exchange_via_app",
    "extra_charge_on_statement",
    "failed_transfer",
    "fiat_currency_support",
    "get_disposable_virtual_card",
    "get_physical_card",
    "getting_spare_card",
    "getting_virtual_card",
    "lost_or_stolen_card",
    "lost_or_stolen_phone",
    "order_physical_card",
    "passcode_forgotten",
    "pending_card_payment",
    "pending_cash_withdrawal",
    "pending_top_up",
    "pending_transfer",
    "pin_blocked",
    "receiving_money",
    "request_refund",
    "reverted_card_payment?",
    "supported_cards_and_currencies",
    "terminate_account",
    "top_up_by_bank_transfer_charge",
    "top_up_by_card_charge",
    "top_up_by_cash_or_cheque",
    "top_up_failed",
    "top_up_limits",
    "top_up_reverted",
    "topping_up_by_card",
    "transaction_charged_twice",
    "transfer_fee_charged",
    "transfer_into_account",
    "transfer_not_received_by_recipient",
    "transfer_timing",
    "unable_to_verify_identity",
    "verify_my_identity",
    "verify_source_of_funds",
    "verify_top_up",
    "virtual_card_not_working",
    "visa_or_mastercard",
    "why_verify_identity",
    "wrong_amount_of_cash_received",
    "wrong_exchange_rate_for_cash_withdrawal",
]
original_class_names = sorted(
    [class_name.lower() for class_name in original_class_names]
)

In [4]:
class_names = [
    " ".join(class_name.split("_")).capitalize() for class_name in original_class_names
]
class_names_str = "\n".join(sorted(class_names))

In [5]:
df = df.with_columns(
    pl.Series(
        name="class_name", values=[class_names[label_idx] for label_idx in df["label"]]
    )
)

In [6]:
df_sample = df.sample(n=300, shuffle=True, seed=42)

In [7]:
instruction = f"""\
You are an expert at understanding bank customers support complaints and queries.
Your job is to categorize an inputted customer query or complaint into one of the
following categories. Also, report a confidence score from 0 to 1 for your prediction.

Categories:

{class_names_str}
"""

In [8]:
len(class_names)

77

Classification is easy w/ structured output / constrained sampling. Just supply the list
of classes.

In [9]:
Classes: TypeAlias = Literal[tuple(class_names)]  # type: ignore

In [10]:
class Banking77Classes(BaseModel):
    predicted_category: Classes
    confidence: float

In [11]:
client = openai.OpenAI(max_retries=4)
openai_kwargs = dict(model="gpt-4o-mini", temperature=0.0)

# Unbatched

In [12]:
developer_message = {"role": "developer", "content": instruction}

In [13]:
outputs_unbatched = [
    (
        client.beta.chat.completions.parse(
            messages=[
                developer_message,
                {"role": "user", "content": text},
            ],
            response_format=Banking77Classes,
            **openai_kwargs,
        )
        .choices[0]
        .message.parsed
    )
    for text in tqdm(df_sample["text"], desc="Unbatched")
]

Unbatched:   0%|          | 0/300 [00:00<?, ?it/s]

I don't think the progress bar renders on GitHub.

This run took 3 min 38 sec, or 1.31 completions/sec

# Batched

TODO: don't batch via batch size. [Batch by token
count](https://github.com/getsentry/seer/blob/108d9a7686a43b4ff062ea322f2eb74b4e3c29cc/src/seer/automation/utils.py#L269).
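
A rough sketch of what batching by token count could look like. This is not the linked implementation and not what the `batch` module does; `batch_by_token_count` is a hypothetical helper, and the default `count_tokens` is a crude whitespace split standing in for a real tokenizer.

```python
from typing import Callable, Iterator


def batch_by_token_count(
    texts: list[str],
    max_tokens_per_batch: int,
    count_tokens: Callable[[str], int] = lambda text: len(text.split()),
) -> Iterator[list[str]]:
    """Greedily group texts so each group stays within a token budget."""
    current: list[str] = []
    current_tokens = 0
    for text in texts:
        num_tokens = count_tokens(text)
        # Close the current batch if adding this text would exceed the budget
        if current and current_tokens + num_tokens > max_tokens_per_batch:
            yield current
            current, current_tokens = [], 0
        current.append(text)
        current_tokens += num_tokens
    if current:
        yield current


example_batches = list(
    batch_by_token_count(["a b", "c d e", "f", "g h i j k l"], max_tokens_per_batch=4)
)
print(example_batches)
```

An input that exceeds the budget on its own still gets a batch to itself rather than being dropped or truncated.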

In [14]:
texts = df_sample["text"].to_list()

It's probably better to use XML tags or something in case the current delimiter is too
weak or already included in the input.

```
<input_1>

{input_1}

</input_1>
```
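
A minimal sketch of building that format for a whole batch. `pack_inputs` is a hypothetical helper, not part of the `batch` module. Escaping `<`, `>`, and `&` in the input guards against text that already contains the delimiter, at the cost of slightly changing what the model sees.

```python
from xml.sax.saxutils import escape


def pack_inputs(texts: list[str]) -> str:
    """Wrap each input in numbered XML-style tags and join them into one prompt."""
    return "\n\n".join(
        # escape() replaces &, <, and > so an input can't close its own tag
        f"<input_{i}>\n\n{escape(text)}\n\n</input_{i}>"
        for i, text in enumerate(texts, start=1)
    )


print(pack_inputs(["My card hasn't arrived", "Why was I charged twice?"]))
```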

In [15]:
batch_responses_generator = batch.complete(
    client,
    texts=texts,
    instruction=instruction,
    max_tokens=1_000,
    response_format=Banking77Classes,
    **openai_kwargs,
)

In [16]:
outputs_from_batched = [
    response
    for responses_batch in batch_responses_generator
    for response in responses_batch
]

Completing batches:   0%|          | 0/5 [00:00<?, ?it/s]

This run took 2 min 5 sec, or 2.4 completions/sec

Should instead [stream the
outputs](https://platform.openai.com/docs/guides/structured-outputs#streaming)

# Evaluate

In [17]:
(
    pl.Series([output.predicted_category for output in outputs_unbatched])
    == df_sample["class_name"]
).mean()

0.6833333333333333

In [18]:
(
    pl.Series([output.predicted_category for output in outputs_from_batched])
    == df_sample["class_name"]
).mean()

0.68
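
To judge whether a gap like this means anything, it helps to look at the sampling noise for 300 examples. The sketch below uses the simple binomial standard error as a rough yardstick. It treats the two runs as independent, which is an assumption: both runs score the same examples, so a paired comparison would be tighter.

```python
import math

n = 300  # number of sampled examples
acc_unbatched = 0.6833333333333333
acc_batched = 0.68


def standard_error(accuracy: float, n: int) -> float:
    """Binomial standard error of an accuracy estimated from n examples."""
    return math.sqrt(accuracy * (1 - accuracy) / n)


se = standard_error(acc_batched, n)
difference = acc_unbatched - acc_batched

print(f"standard error: {se:.3f}")
print(f"observed difference: {difference:.3f}")
```

The standard error is roughly 0.027, so the difference measured here, and even a 1-3 point gap, is within the noise of a single 300-example run.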