In [None]:
import json
import os

import pandas as pd
import requests
from datasets import load_dataset
from sklearn.metrics import classification_report


# Run Predictions using ChatGPT API
> How to get API_KEY: https://www.merge.dev/blog/chatgpt-api-key

With ChatGPT, zero-shot classification leverages the extensive knowledge embedded in the model's transformer architecture, which has been trained on vast amounts of text data. When given a zero-shot classification task, ChatGPT does not require explicit training examples for the new classes. Instead, it uses its pre-trained embeddings and understanding of language semantics to infer the most likely class for an input. Conceptually, the input is compared against the semantic space defined by the model's training data to identify the closest match(es) in terms of meaning, and the model then predicts the class with the highest semantic similarity. In this notebook, that means listing every candidate intent in the prompt and asking the model to return the one that best fits the query. This process relies heavily on the model's ability to understand and generate meaningful embeddings, enabling it to perform well in scenarios where traditional supervised learning approaches would struggle due to a lack of labeled data for the target classes.


For this use case, we evaluate on only a sample of the data. The result we obtained is a macro-average precision of 73%.


In [None]:
# Prefer the OPENAI_API_KEY environment variable so the secret never has to be
# saved inside the notebook; the placeholder remains as a fallback for readers
# who want to paste a key in locally.
API_KEY = os.environ.get("OPENAI_API_KEY", "<YOUR_CHATGPT_KEY>")
CHATGPT_MODEL = "gpt-4o"
SAMPLE_SIZE = 400  # For demo we will limit dataset size since inference costs money.

In [None]:
# Load the PolyAI/banking77 dataset (pinned to the main branch) and turn each
# split into its own pandas DataFrame.
dataset = load_dataset("PolyAI/banking77", revision="main")
train_data, test_data = (pd.DataFrame(dataset[split]) for split in ("train", "test"))

In [None]:
# Intent names for the 77 classes. Order matters: the list position is used as
# the lookup key for the integer `label` column when building `label_name`.
# The same names (with their ids) are repeated inside PROMPT, so keep both in sync.
# NOTE(review): "Refund_not_showing_up" (capital R) and "reverted_card_payment?"
# (trailing "?") look like typos but presumably mirror the dataset's own label
# strings — confirm before normalising them.
label_names = [
    "activate_my_card",
    "age_limit",
    "apple_pay_or_google_pay",
    "atm_support",
    "automatic_top_up",
    "balance_not_updated_after_bank_transfer",
    "balance_not_updated_after_cheque_or_cash_deposit",
    "beneficiary_not_allowed",
    "cancel_transfer",
    "card_about_to_expire",
    "card_acceptance",
    "card_arrival",
    "card_delivery_estimate",
    "card_linking",
    "card_not_working",
    "card_payment_fee_charged",
    "card_payment_not_recognised",
    "card_payment_wrong_exchange_rate",
    "card_swallowed",
    "cash_withdrawal_charge",
    "cash_withdrawal_not_recognised",
    "change_pin",
    "compromised_card",
    "contactless_not_working",
    "country_support",
    "declined_card_payment",
    "declined_cash_withdrawal",
    "declined_transfer",
    "direct_debit_payment_not_recognised",
    "disposable_card_limits",
    "edit_personal_details",
    "exchange_charge",
    "exchange_rate",
    "exchange_via_app",
    "extra_charge_on_statement",
    "failed_transfer",
    "fiat_currency_support",
    "get_disposable_virtual_card",
    "get_physical_card",
    "getting_spare_card",
    "getting_virtual_card",
    "lost_or_stolen_card",
    "lost_or_stolen_phone",
    "order_physical_card",
    "passcode_forgotten",
    "pending_card_payment",
    "pending_cash_withdrawal",
    "pending_top_up",
    "pending_transfer",
    "pin_blocked",
    "receiving_money",
    "Refund_not_showing_up",
    "request_refund",
    "reverted_card_payment?",
    "supported_cards_and_currencies",
    "terminate_account",
    "top_up_by_bank_transfer_charge",
    "top_up_by_card_charge",
    "top_up_by_cash_or_cheque",
    "top_up_failed",
    "top_up_limits",
    "top_up_reverted",
    "topping_up_by_card",
    "transaction_charged_twice",
    "transfer_fee_charged",
    "transfer_into_account",
    "transfer_not_received_by_recipient",
    "transfer_timing",
    "unable_to_verify_identity",
    "verify_my_identity",
    "verify_source_of_funds",
    "verify_top_up",
    "virtual_card_not_working",
    "visa_or_mastercard",
    "why_verify_identity",
    "wrong_amount_of_cash_received",
    "wrong_exchange_rate_for_cash_withdrawal"]

# Attach the human-readable intent name to every row; the integer `label`
# value is the index into `label_names`.
train_data["label_name"] = train_data["label"].apply(lambda x: label_names[x])
test_data["label_name"] = test_data["label"].apply(lambda x: label_names[x])

# Fixed seed so that Restart & Run All evaluates (and pays for) the same rows
# every time; change it to draw a different evaluation subset.
EVAL_SAMPLE_SEED = 42

# Cap the sample at the number of available rows: DataFrame.sample raises when
# asked for more rows than exist (without replacement).
eval_data: pd.DataFrame = train_data.sample(
    n=min(SAMPLE_SIZE, len(train_data)),
    random_state=EVAL_SAMPLE_SEED,
)

# Running evaluation using the ChatGPT API

In [None]:
# Instruction text sent at the start of every request. predict_with_gpt
# concatenates the query directly after the trailing "Input text:" line, so the
# end of this string must stay as it is.
# The intent table below repeats `label_names` together with their integer ids;
# keep the two in sync.
# NOTE(review): the output description mentions ids "(0-76)" while the requested
# JSON field is the label *name*; the evaluation compares predictions against
# the `label_name` strings — confirm the wording is intentional.
PROMPT =  """
    Who are you?
    You are a world class model for predicting intent on online banking queries. 
    
    What inputs will you get?
    Your input will contain a text
    text: a string feature.
    
    What output should you give?
    label_name: One of classification labels (0-76) corresponding to unique intents.
    These are the following intents:
    
    label_id intent (label_name)
    0 activate_my_card
    1 age_limit
    2 apple_pay_or_google_pay
    3 atm_support
    4 automatic_top_up
    5 balance_not_updated_after_bank_transfer
    6 balance_not_updated_after_cheque_or_cash_deposit
    7 beneficiary_not_allowed
    8 cancel_transfer
    9 card_about_to_expire
    10 card_acceptance
    11 card_arrival
    12 card_delivery_estimate
    13 card_linking
    14 card_not_working
    15 card_payment_fee_charged
    16 card_payment_not_recognised
    17 card_payment_wrong_exchange_rate
    18 card_swallowed
    19 cash_withdrawal_charge
    20 cash_withdrawal_not_recognised
    21 change_pin
    22 compromised_card
    23 contactless_not_working
    24 country_support
    25 declined_card_payment
    26 declined_cash_withdrawal
    27 declined_transfer
    28 direct_debit_payment_not_recognised
    29 disposable_card_limits
    30 edit_personal_details
    31 exchange_charge
    32 exchange_rate
    33 exchange_via_app
    34 extra_charge_on_statement
    35 failed_transfer
    36 fiat_currency_support
    37 get_disposable_virtual_card
    38 get_physical_card
    39 getting_spare_card
    40 getting_virtual_card
    41 lost_or_stolen_card
    42 lost_or_stolen_phone
    43 order_physical_card
    44 passcode_forgotten
    45 pending_card_payment
    46 pending_cash_withdrawal
    47 pending_top_up
    48 pending_transfer
    49 pin_blocked
    50 receiving_money
    51 Refund_not_showing_up
    52 request_refund
    53 reverted_card_payment?
    54 supported_cards_and_currencies
    55 terminate_account
    56 top_up_by_bank_transfer_charge
    57 top_up_by_card_charge
    58 top_up_by_cash_or_cheque
    59 top_up_failed
    60 top_up_limits
    61 top_up_reverted
    62 topping_up_by_card
    63 transaction_charged_twice
    64 transfer_fee_charged
    65 transfer_into_account
    66 transfer_not_received_by_recipient
    67 transfer_timing
    68 unable_to_verify_identity
    69 verify_my_identity
    70 verify_source_of_funds
    71 verify_top_up
    72 virtual_card_not_working
    73 visa_or_mastercard
    74 why_verify_identity
    75 wrong_amount_of_cash_received
    76 wrong_exchange_rate_for_cash_withdrawal
    
    
    Please provide the JSON string representation of the following data:
    {
              "label_name": "<Name of predicted label.>"
    }
    
    Input text: 
    """ 

def predict_with_gpt(text):
    """Ask the chat-completions endpoint to classify one banking query.

    The query is appended to ``PROMPT`` and sent as a single user message to
    ``CHATGPT_MODEL`` with JSON response mode enabled.

    Args:
        text: The customer query to classify.

    Returns:
        The value of ``label_name`` in the model's JSON answer, or ``'ERROR'``
        when the answer is missing, is not valid JSON, is not a JSON object,
        or has no ``label_name`` field.

    Raises:
        requests.HTTPError: If the API answers with a 4xx/5xx status (for
            example an invalid key or a rate limit), so such failures are not
            silently scored as wrong predictions.
        requests.RequestException: On connection problems or when the request
            exceeds the timeout.
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}"
    }

    payload = {
        "model": CHATGPT_MODEL,
        "response_format": {"type": "json_object"},  # ensures that response is in json format
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": PROMPT + text
                    }
                ]
            }
        ],
        "max_tokens": 1000
    }

    response = requests.post(
        "https://api.openai.com/v1/chat/completions",
        headers=headers,
        json=payload,
        timeout=60,  # seconds; without it a stalled connection blocks the whole evaluation loop
    )
    # Surface API-level failures with their real status instead of a later
    # KeyError on the missing 'choices' field.
    response.raise_for_status()

    try:
        content = response.json()['choices'][0]['message']['content']
        prediction = json.loads(content)
    except (KeyError, IndexError, TypeError, ValueError):
        # Unexpected response shape, empty content (None) or invalid JSON:
        # count the row as a failed prediction rather than aborting the run.
        return 'ERROR'

    if not isinstance(prediction, dict):
        return 'ERROR'
    return prediction.get('label_name', 'ERROR')

# Running evaluation

In [None]:
# One API call per sampled row — this is the slow, billable step.
eval_data['pred'] = eval_data['text'].apply(predict_with_gpt)

# zero_division=0 reports the same numbers as the default ("warn" also scores
# undefined metrics as 0) but stops scikit-learn from emitting an
# UndefinedMetricWarning for every label without predicted or true samples.
print(classification_report(eval_data['label_name'], eval_data['pred'], zero_division=0))

                                                  precision    recall  f1-score   support

                           Refund_not_showing_up       1.00      0.86      0.92         7
                                activate_my_card       0.67      1.00      0.80         4
                                       age_limit       1.00      1.00      1.00         3
                         apple_pay_or_google_pay       1.00      1.00      1.00         4
                                     atm_support       0.50      1.00      0.67         1
                                automatic_top_up       0.86      1.00      0.92         6
         balance_not_updated_after_bank_transfer       0.50      0.14      0.22         7
balance_not_updated_after_cheque_or_cash_deposit       0.83      1.00      0.91        10
                         beneficiary_not_allowed       1.00      0.56      0.71         9
                                 cancel_transfer       0.50      0.83      0.62         6
         

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
