**Description**: Send a batch of inputs to an LLM API at once and get a batch back.
(TODO: stream outputs back).

For the [Banking 77 task](https://huggingface.co/datasets/PolyAI/banking77), batching
100 inputs into a single request reduced latency from ~1 min 23 sec to ~18 sec.

The benefit is lower latency, though probably only in certain situations (I think).

The cost might be accuracy. I'm not sure whether packing extra inputs into one prompt
hurts: it's hard to argue whether that extra context is relevant or irrelevant to each
input. My hypothesis is that accuracy gets worse because the task is now more
structured. Need to run lots of experiments to find out.

Related work (I haven't read this paper yet): https://arxiv.org/abs/2301.08721

In [1]:
from typing import Literal, TypeAlias

from datasets import load_dataset
import openai
import polars as pl
from pydantic import BaseModel, create_model
from tqdm.auto import tqdm

# Set up task

In [2]:
df = pl.DataFrame(load_dataset("PolyAI/banking77", split="train").to_pandas())



In [3]:
original_class_names = [
    "Refund_not_showing_up",
    "activate_my_card",
    "age_limit",
    "apple_pay_or_google_pay",
    "atm_support",
    "automatic_top_up",
    "balance_not_updated_after_bank_transfer",
    "balance_not_updated_after_cheque_or_cash_deposit",
    "beneficiary_not_allowed",
    "cancel_transfer",
    "card_about_to_expire",
    "card_acceptance",
    "card_arrival",
    "card_delivery_estimate",
    "card_linking",
    "card_not_working",
    "card_payment_fee_charged",
    "card_payment_not_recognised",
    "card_payment_wrong_exchange_rate",
    "card_swallowed",
    "cash_withdrawal_charge",
    "cash_withdrawal_not_recognised",
    "change_pin",
    "compromised_card",
    "contactless_not_working",
    "country_support",
    "declined_card_payment",
    "declined_cash_withdrawal",
    "declined_transfer",
    "direct_debit_payment_not_recognised",
    "disposable_card_limits",
    "edit_personal_details",
    "exchange_charge",
    "exchange_rate",
    "exchange_via_app",
    "extra_charge_on_statement",
    "failed_transfer",
    "fiat_currency_support",
    "get_disposable_virtual_card",
    "get_physical_card",
    "getting_spare_card",
    "getting_virtual_card",
    "lost_or_stolen_card",
    "lost_or_stolen_phone",
    "order_physical_card",
    "passcode_forgotten",
    "pending_card_payment",
    "pending_cash_withdrawal",
    "pending_top_up",
    "pending_transfer",
    "pin_blocked",
    "receiving_money",
    "request_refund",
    "reverted_card_payment?",
    "supported_cards_and_currencies",
    "terminate_account",
    "top_up_by_bank_transfer_charge",
    "top_up_by_card_charge",
    "top_up_by_cash_or_cheque",
    "top_up_failed",
    "top_up_limits",
    "top_up_reverted",
    "topping_up_by_card",
    "transaction_charged_twice",
    "transfer_fee_charged",
    "transfer_into_account",
    "transfer_not_received_by_recipient",
    "transfer_timing",
    "unable_to_verify_identity",
    "verify_my_identity",
    "verify_source_of_funds",
    "verify_top_up",
    "virtual_card_not_working",
    "visa_or_mastercard",
    "why_verify_identity",
    "wrong_amount_of_cash_received",
    "wrong_exchange_rate_for_cash_withdrawal",
]
original_class_names = sorted(
    [class_name.lower() for class_name in original_class_names]
)

In [4]:
class_names = [
    " ".join(class_name.split("_")).capitalize() for class_name in original_class_names
]
class_names_str = "\n".join(sorted(class_names))

In [5]:
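# Map each integer label to its human-readable class name. This assumes the sorted
# class names line up with the dataset's label indices.
df = df.with_columns(
    pl.Series(
        name="class_name", values=[class_names[label_idx] for label_idx in df["label"]]
    )
)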

In [6]:
df_sample = df.sample(n=100, shuffle=True)

In [7]:
pre_texts = f"""\
You are an expert at understanding bank customer support complaints and queries.
Your job is to categorize an inputted customer query or complaint into one of the
following categories. Also, report a confidence score from 0 to 1 for your prediction.

Categories:

{class_names_str}
"""

In [8]:
len(class_names)

77

In [9]:
Classes: TypeAlias = Literal[tuple(class_names)]  # type: ignore

In [10]:
class Banking77Classes(BaseModel):
    predicted_category: Classes
    confidence: float

In [11]:
client = openai.OpenAI(max_retries=4)

# Unbatched

In [12]:
developer_message = {"role": "developer", "content": pre_texts}

In [13]:
outputs_unbatched = [
    (
        client.beta.chat.completions.parse(
            model="gpt-4o-mini",
            messages=[
                developer_message,
                {"role": "user", "content": text},
            ],
            temperature=0.0,
            response_format=Banking77Classes,
        )
        .choices[0]
        .message.parsed
    )
    for text in tqdm(df_sample["text"], desc="Unbatched")
]


# Batched

TODO: don't batch by a fixed batch size. Instead, [batch by token
count](https://github.com/getsentry/seer/blob/108d9a7686a43b4ff062ea322f2eb74b4e3c29cc/src/seer/automation/utils.py#L269).
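A minimal sketch of what batching by token count could look like, assuming a greedy
packing strategy. This is not the linked implementation. `estimate_num_tokens` and
`batch_by_token_count` are hypothetical helpers, and the ~4 characters per token
estimate is a rough assumption that should be replaced by a real tokenizer.

```python
from collections.abc import Callable, Iterable


def estimate_num_tokens(text: str) -> int:
    # Assumption: roughly 4 characters per token. Swap in a real tokenizer for
    # accurate counts.
    return max(1, len(text) // 4)


def batch_by_token_count(
    texts: Iterable[str],
    max_tokens_per_batch: int,
    count_tokens: Callable[[str], int] = estimate_num_tokens,
) -> list[list[str]]:
    """Greedily pack texts into batches whose token counts stay within a budget."""
    batches: list[list[str]] = []
    current_batch: list[str] = []
    current_num_tokens = 0
    for text in texts:
        num_tokens = count_tokens(text)
        # Close the current batch if adding this text would exceed the budget
        if current_batch and current_num_tokens + num_tokens > max_tokens_per_batch:
            batches.append(current_batch)
            current_batch, current_num_tokens = [], 0
        current_batch.append(text)
        current_num_tokens += num_tokens
    if current_batch:
        batches.append(current_batch)
    return batches
```

A text that exceeds the budget on its own still gets a batch to itself, so no input
is dropped. Each batch would then get its own request and its own response format.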

In [14]:
texts = df_sample["text"].to_list()

In [15]:
input_ids = [f"input_{idx}" for idx in range(1, len(texts) + 1)]
input_ids_str = "\n".join(input_ids)
print(input_ids_str)

input_1
input_2
input_3
input_4
input_5
input_6
input_7
input_8
input_9
input_10
input_11
input_12
input_13
input_14
input_15
input_16
input_17
input_18
input_19
input_20
input_21
input_22
input_23
input_24
input_25
input_26
input_27
input_28
input_29
input_30
input_31
input_32
input_33
input_34
input_35
input_36
input_37
input_38
input_39
input_40
input_41
input_42
input_43
input_44
input_45
input_46
input_47
input_48
input_49
input_50
input_51
input_52
input_53
input_54
input_55
input_56
input_57
input_58
input_59
input_60
input_61
input_62
input_63
input_64
input_65
input_66
input_67
input_68
input_69
input_70
input_71
input_72
input_73
input_74
input_75
input_76
input_77
input_78
input_79
input_80
input_81
input_82
input_83
input_84
input_85
input_86
input_87
input_88
input_89
input_90
input_91
input_92
input_93
input_94
input_95
input_96
input_97
input_98
input_99
input_100


In [16]:
batch_instructions = f"""\
You will be given a sequence of inputs. Each input is identified like so: input_{{id}}.
You must return a result for each input. Here are the input IDs you need to return
results for:
{input_ids_str}
"""
print(batch_instructions)

You will be given a sequence of inputs. Each input is identified like so: input_{id}.
You must return a result for each input. Here are the input IDs you need to return
results for:
input_1
input_2
input_3
input_4
input_5
input_6
input_7
input_8
input_9
input_10
input_11
input_12
input_13
input_14
input_15
input_16
input_17
input_18
input_19
input_20
input_21
input_22
input_23
input_24
input_25
input_26
input_27
input_28
input_29
input_30
input_31
input_32
input_33
input_34
input_35
input_36
input_37
input_38
input_39
input_40
input_41
input_42
input_43
input_44
input_45
input_46
input_47
input_48
input_49
input_50
input_51
input_52
input_53
input_54
input_55
input_56
input_57
input_58
input_59
input_60
input_61
input_62
input_63
input_64
input_65
input_66
input_67
input_68
input_69
input_70
input_71
input_72
input_73
input_74
input_75
input_76
input_77
input_78
input_79
input_80
input_81
input_82
input_83
input_84
input_85
input_86
input_87
input_88
input_89
input_90
input_91
input_92

It's probably better to wrap each input in XML tags (or something similar) in case the
`-------` delimiter used below is too weak or already appears in an input. For example:

```
<input_1>

{input_1}

</input_1>
```

In [17]:
inputs = [
    f"{input_id}: {text}" for input_id, text in zip(input_ids, texts, strict=True)
]
delim = "\n\n-------\n\n"
inputs_str = delim.join(inputs)
print(inputs_str)

input_1: Help me unblock my account.  I entered the PIN wrong too many times.

-------

input_2: I made a transfer and the receiver said an amount was received, but not exactly the same as the right amount. I now have to transfer more to get the remainder to the receiver. Can you alert me about this. What's been happening?

-------

input_3: Are your cards available in the EU?

-------

input_4: How can my friends top up my account?

-------

input_5: Show me how to verify my identity?

-------

input_6: My top up isnt working

-------

input_7: I don't think I made this payment that is showing up

-------

input_8: I have have multiple charges for one transaction.

-------

input_9: Why is my exchange rate wrong?

-------

input_10: My card payments have stopped.

-------

input_11: Why do I need to verify a top-up?

-------

input_12: What is the process to create a disposable virtual card?

-------

input_13: I want to link my new card. Can you help?

-------

input_14: How come the

In [18]:
model_name = f"BatchOf{Banking77Classes.__name__}"
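# One required field per input ID, each holding that input's prediction
field_definitions = {input_id: (Banking77Classes, ...) for input_id in input_ids}
# Pass the name positionally. create_model returns a model class, hence type[BaseModel].
BatchOfResponseFormat: type[BaseModel] = create_model(model_name, **field_definitions)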

In [19]:
messages_batch = [
    {"role": "developer", "content": pre_texts + "\n" + batch_instructions},
    {"role": "user", "content": inputs_str},
]

In [20]:
response = client.beta.chat.completions.parse(
    model="gpt-4o-mini",
    messages=messages_batch,
    temperature=0.0,
    response_format=BatchOfResponseFormat,
)

Took 18.2 sec instead of 1 min 23 sec.

This might be much worse when there are lots of inputs. Need to think more about when
not to use this, besides the risk of worse accuracy.

In [21]:
outputs_batched: list[Banking77Classes] = [
    getattr(response.choices[0].message.parsed, input_id) for input_id in input_ids
]

TODO: [stream the
outputs](https://platform.openai.com/docs/guides/structured-outputs#streaming) instead
of waiting for the whole batch to finish.

# Evaluate

In [22]:
preds_unbatched = pl.Series([output.predicted_category for output in outputs_unbatched])
(preds_unbatched == df_sample["class_name"]).mean()

0.61

In [23]:
preds_batched = pl.Series([output.predicted_category for output in outputs_batched])
(preds_batched == df_sample["class_name"]).mean()

0.58
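To gauge whether the gap between 0.61 and 0.58 means anything with only 100 examples,
here is a rough sketch that computes the binomial standard error of each accuracy.
`accuracy_standard_error` is a hypothetical helper. Treating each accuracy as an
independent binomial proportion is a simplifying assumption: both runs used the same
sample, so a paired comparison would be more appropriate.

```python
import math


def accuracy_standard_error(accuracy: float, num_examples: int) -> float:
    """Standard error of an accuracy estimated from num_examples predictions."""
    return math.sqrt(accuracy * (1 - accuracy) / num_examples)


se_unbatched = accuracy_standard_error(0.61, 100)
se_batched = accuracy_standard_error(0.58, 100)
print(f"unbatched: 0.61 +/- {se_unbatched:.3f}")  # prints 0.61 +/- 0.049
print(f"batched:   0.58 +/- {se_batched:.3f}")  # prints 0.58 +/- 0.049
```

Under that assumption, the 0.03 gap is smaller than one standard error, so this
sample alone can't show that batching hurts accuracy. A larger sample is needed.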