# Ensuring Reliable Few-Shot Prompt Selection for LLMs

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cleanlab/examples/blob/master/few-shot-prompt-selection/few-shot-prompt-selection.ipynb)

In this notebook, we prompt OpenAI's Davinci LLM (`text-davinci-003`, part of the GPT-3.5 family behind ChatGPT) with few-shot prompts to classify the intent of customer service requests at a large bank. Following typical practice, we draw the few-shot examples for the prompt template from an available dataset of human-labeled requests. However, the resulting LLM predictions are unreliable: a close inspection reveals that this is because real-world data is messy and error-prone. Manually modifying the prompt template to account for potentially noisy data only marginally improves LLM performance on this intent classification task. The predictions become significantly more accurate if we instead use data-centric AI algorithms, via Cleanlab Studio, to ensure that only high-quality few-shot examples are selected for the prompt template.

# Imports and Helpers

In [52]:
import pandas as pd
import numpy as np
import warnings
import openai, os
import string
import random
import tiktoken
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from langchain.prompts import PromptTemplate
from langchain.prompts import FewShotPromptTemplate
warnings.filterwarnings('ignore')
pd.set_option('max_colwidth', None)

def eval_preds(preds):
    # Compare predictions with the ground-truth labels of the global `test` set.
    acc = accuracy_score(test.label.values, preds)
    return "Model Accuracy: " + '{:.1%}'.format(acc)

def tokens_per_prompt(prompt):
    encoding = tiktoken.get_encoding("p50k_base")
    num_tokens = len(encoding.encode(prompt))
    return num_tokens

def cost_per_prompt(prompt):
    cost_per_token = 0.02 / 1000
    tokens = tokens_per_prompt(prompt)
    cost = tokens * cost_per_token
    return cost

def cost_per_test_evaluation(examples_pool, test):
    texts = test.text.values
    cost = 0
    examples = get_examples(examples_pool)
    for text in texts:
        prompt = get_prompt(examples_pool, text, examples)
        prompt_cost = cost_per_prompt(prompt)
        cost += prompt_cost
    return "${:,.2f}".format(cost)

# Banking Intent Dataset
This notebook studies a 50-class variant of the [Banking-77](https://arxiv.org/abs/2003.04807) dataset, which contains online banking queries annotated with their corresponding intents (the `label` column shown below). We evaluate models that predict this label on a fixed test set of ~500 queries.

In [74]:
test = pd.read_csv('https://s.cleanlab.ai/banking-intent-50/test.csv')
examples_pool = pd.read_csv('https://s.cleanlab.ai/banking-intent-50/examples-pool.csv')

In [60]:
examples_pool[['text', 'label']].head()

| | text | label |
|---|---|---|
| 0 | i moved to a new city and need to change my address | edit_personal_details |
| 1 | on my transfer there was a "decline" message | declined_transfer |
| 2 | help! my wallet was stolen and someone is taking money out. i need this money! what can i do? | card_payment_fee_charged |
| 3 | while abroad i got cash, and a wrong exchange rate was applied. | wrong_exchange_rate_for_cash_withdrawal |
| 4 | why can't i get cash? | getting_spare_card |


# OpenAI API Key
Replace with your own key.

In [83]:
%env OPENAI_API_KEY = {your_key_here}
openai.api_key = os.environ['OPENAI_API_KEY']

env: OPENAI_API_KEY={your_key_here}


# Building the Few-Shot Prompt
Few-shot prompting is a technique that enables pretrained foundation models to perform complex tasks without any explicit training (i.e. without updating model parameters). In few-shot prompting (also known as in-context learning), we include a small number of input-output example pairs in the prompt template; these examples show the model how to handle the particular input we want it to process.

Our prompt will consist of a few pieces:
- (optional) a prefix
- the list of class labels, to help the LLM choose a valid class
- (optional) 50 examples, one from each class
- the target text for the LLM to classify
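
To see the resulting layout without LangChain, here is a minimal plain-Python sketch. `build_prompt` is a hypothetical helper, not part of LangChain; it assumes the `Text:`/`Label:` format used by `get_prompt` further down and a blank line between pieces, which we believe matches LangChain's default example separator.

```python
from typing import Dict, List


def build_prompt(
    text: str,
    examples: List[Dict[str, str]],
    labels: List[str],
    prefix: str = "",
) -> str:
    """Hypothetical helper: assemble a few-shot classification prompt as plain text."""
    # Header: optional instruction line, then the comma-separated list of valid classes.
    header = (prefix + "\n" if prefix else "") + "You can choose the label from: " + ",".join(labels)
    # One "Text/Label" block per few-shot example; an empty list yields a zero-shot prompt.
    shots = [f"Text: {ex['text']}\nLabel: {ex['label']}" for ex in examples]
    # The target text, with the label left blank for the LLM to complete.
    suffix = f"Text: {text}\nLabel:"
    # Separate the pieces with a blank line.
    return "\n\n".join([header, *shots, suffix])


prompt = build_prompt(
    "How can I change my pin?",
    [{"text": "it declined my transfer.", "label": "declined_transfer"}],
    ["edit_personal_details", "declined_transfer"],
)
print(prompt)
```

Ending the prompt with a bare `Label:` is what nudges a completion model to emit just the class name.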

## Adding List of Classes for Valid Completions

This text goes at the beginning of the prompt and tells the LLM which classes are valid, so that it consistently outputs one of them. Without it, the LLM tends to output labels that are not among the valid classes and cannot be parsed.

Here we can also add an optional `prefix` that we will use later.

In [61]:
# Helper to get prefix for prompt. This gives the LLM all of the labels so that it chooses more accurately.
def get_prefix(examples_pool, prefix=""):
    s = ""
    if len(prefix) != 0:
        s += prefix
        s += '\n'
    s += "You can choose the label from: "
    classes = list(examples_pool.label.unique())
    s += ",".join(classes)
    return s

print(get_prefix(examples_pool, "Beware some labels in the examples may be noisy."))

Beware some labels in the examples may be noisy.
You can choose the label from: edit_personal_details,declined_transfer,card_payment_fee_charged,wrong_exchange_rate_for_cash_withdrawal,getting_spare_card,cash_withdrawal_charge,verify_top_up,transfer_timing,apple_pay_or_google_pay,card_payment_not_recognised,visa_or_mastercard,reverted_card_payment,transfer_not_received_by_recipient,country_support,wrong_amount_of_cash_received,refund_not_showing_up,card_linking,failed_transfer,exchange_via_app,fiat_currency_support,activate_my_card,direct_debit_payment_not_recognised,balance_not_updated_after_cheque_or_cash_deposit,cash_withdrawal_not_recognised,transfer_fee_charged,card_arrival,pending_top_up,extra_charge_on_statement,supported_cards_and_currencies,declined_card_payment,top_up_failed,automatic_top_up,transaction_charged_twice,disposable_card_limits,card_payment_wrong_exchange_rate,pending_transfer,declined_cash_withdrawal,balance_not_updated_after_bank_transfer,beneficiary_not_allowed

## Adding K-Shot Examples

Here we randomly choose one example from each of the 50 classes to build a 50-shot prompt for the LLM. Because the random generator is reseeded for each class, the selection is reproducible.

In [62]:
# Helper method to get one example from each class for k-shot prompt.
def get_examples(examples_pool):
    out = []
    unique_classes = examples_pool.label.unique()
    for i, cls in enumerate(unique_classes):
        temp = examples_pool[examples_pool.label==cls]
        random.seed(0+i)
        idx = random.choice(list(range(len(temp))))
        text = temp.iloc[idx].text
        label = temp.iloc[idx].label
        d = {'text':text, 'label':label}         
        out.append(d)
    return out

examples = get_examples(examples_pool)
examples

[{'text': 'i just got married and i need to change my name',
  'label': 'edit_personal_details'},
 {'text': 'it declined my transfer.', 'label': 'declined_transfer'},
 {'text': "why am i being charged for atm cash withdrawals? the only reason i use it is because it's been free! now you expect me to pay for them, and how much is that going to cost me?",
  'label': 'card_payment_fee_charged'},
 {'text': 'i attempted to get money using a foreign currency at an atm but the rate was highly inaccurate!',
  'label': 'wrong_exchange_rate_for_cash_withdrawal'},
 {'text': 'tell me where i can find the auto top up feature and a little bit about it please.',
  'label': 'getting_spare_card'},
 {'text': 'how come i got charged extra for withdrawing cash?',
  'label': 'cash_withdrawal_charge'},
 {'text': "i don't know where the top-up verification code is.",
  'label': 'verify_top_up'},
 {'text': 'how long do i have to wait for a us transfer?',
  'label': 'transfer_timing'},
 {'text': 'am i able to d
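
The selection step itself does not need pandas. Below is a minimal sketch, assuming the pool is a list of `{'text': ..., 'label': ...}` dicts. `sample_one_per_class` is a hypothetical helper that uses a private `random.Random(seed)` rather than reseeding the global generator per class as `get_examples` does, so it will not pick the same examples as the notebook.

```python
import random
from collections import defaultdict
from typing import Dict, List


def sample_one_per_class(pool: List[Dict[str, str]], seed: int = 0) -> List[Dict[str, str]]:
    """Hypothetical helper: draw one example per label, in first-seen label order."""
    # Group examples by label; dicts keep insertion order, like pandas' unique().
    by_label: Dict[str, List[Dict[str, str]]] = defaultdict(list)
    for ex in pool:
        by_label[ex["label"]].append(ex)
    # A private RNG keeps the draw reproducible without touching global random state.
    rng = random.Random(seed)
    return [rng.choice(group) for group in by_label.values()]


pool = [
    {"text": "i moved to a new city and need to change my address", "label": "edit_personal_details"},
    {"text": "i just got married and i need to change my name", "label": "edit_personal_details"},
    {"text": "it declined my transfer.", "label": "declined_transfer"},
]
shots = sample_one_per_class(pool, seed=0)
print([ex["label"] for ex in shots])  # ['edit_personal_details', 'declined_transfer']
```

The sampler is blind to label quality: a mislabeled example is exactly as likely to be drawn as a correct one, which is how pairs like the auto top-up question labeled `getting_spare_card` in the output above end up in the prompt.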

## Generate Entire Prompt 

In [63]:
# Helper to format the k-shot prompt with:
# - prefix
# - 1 example from each class
# - target text for classification
def get_prompt(examples_pool, text, examples, prefix=""):
    prompt_template = PromptTemplate(
        input_variables=["text", "label"],
        template="Text: {text}\nLabel: {label}",
    )

    p = FewShotPromptTemplate(
        example_prompt = prompt_template,
        examples = examples,
        prefix = get_prefix(examples_pool, prefix),
        suffix = "Text: {text}\nLabel:",
        input_variables = ['text'],
        )
    return p.format(text=text).strip()

print(get_prompt(examples_pool, "Classify this text!", examples, "Beware some labels in the examples may be noisy."))

Beware some labels in the examples may be noisy.
You can choose the label from: edit_personal_details,declined_transfer,card_payment_fee_charged,wrong_exchange_rate_for_cash_withdrawal,getting_spare_card,cash_withdrawal_charge,verify_top_up,transfer_timing,apple_pay_or_google_pay,card_payment_not_recognised,visa_or_mastercard,reverted_card_payment,transfer_not_received_by_recipient,country_support,wrong_amount_of_cash_received,refund_not_showing_up,card_linking,failed_transfer,exchange_via_app,fiat_currency_support,activate_my_card,direct_debit_payment_not_recognised,balance_not_updated_after_cheque_or_cash_deposit,cash_withdrawal_not_recognised,transfer_fee_charged,card_arrival,pending_top_up,extra_charge_on_statement,supported_cards_and_currencies,declined_card_payment,top_up_failed,automatic_top_up,transaction_charged_twice,disposable_card_limits,card_payment_wrong_exchange_rate,pending_transfer,declined_cash_withdrawal,balance_not_updated_after_bank_transfer,beneficiary_not_allowed

## Query OpenAI LLM API

In [64]:
# Helper method to prompt OpenAI LLM and get response.
def get_response(prompt):
    response = openai.Completion.create(
      model="text-davinci-003",
      prompt=prompt,
      temperature=0,
      max_tokens=50,
      top_p=1,
      frequency_penalty=0,
      presence_penalty=0
    )
    
    # Parse output to get just the label.
    resp = response['choices'][0]['text'].split('\n')[0].split(',')[0].strip().lower().rstrip(string.punctuation)
    
    # Print any response that is not an exact match for a valid label, so we can inspect it.
    if resp not in examples_pool.label.unique():
        print(resp)
    return resp

text = "\'How can I change my pin?\'"
examples = get_examples(examples_pool)
prompt = get_prompt(examples_pool, text, examples)
response = get_response(prompt)
print("Model classified ", text, " as ", response)

Model classified  'How can I change my pin?'  as  change_pin
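
The one-line parsing chain in `get_response` is easier to test as a standalone function. This sketch reproduces its steps (keep the first line, keep the text before the first comma, lowercase, strip trailing punctuation). Unlike `get_response`, the hypothetical `parse_label` returns `None` for anything that is not a valid class instead of printing it, which makes invalid outputs easy to count.

```python
import string
from typing import Iterable, Optional


def parse_label(completion: str, valid_labels: Iterable[str]) -> Optional[str]:
    """Hypothetical helper mirroring get_response's parsing; None if the label is invalid."""
    # Keep only the first line, then only the part before any comma.
    label = completion.split("\n")[0].split(",")[0]
    # Normalise whitespace, case, and trailing punctuation such as a final period.
    label = label.strip().lower().rstrip(string.punctuation)
    return label if label in set(valid_labels) else None


valid = ["change_pin", "declined_transfer", "failed_transfer"]
print(parse_label(" Change_PIN.\n\nText: something else", valid))  # change_pin
print(parse_label(" declined_transfer, failed_transfer", valid))   # declined_transfer
print(parse_label(" I am not sure", valid))                         # None
```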


## Evaluate Prompt on Test Examples

In [65]:
# Helper method to evaluate test set.
def eval_prompt(examples_pool, test, prefix="", use_examples = True):
    texts = test.text.values
    responses = []
    examples = get_examples(examples_pool) if use_examples else []
    for i in tqdm(range(len(texts))):
        text = texts[i]
        prompt = get_prompt(examples_pool, text, examples, prefix)
        resp = get_response(prompt)
        responses.append(resp)
    return responses

## Estimate Spend

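Before we run our prompt evaluation script, let's estimate how much it will cost. `cost_per_test_evaluation` counts only prompt tokens (using the `p50k_base` encoding) at $0.02 per 1,000 tokens, so completion tokens add a little on top of this figure.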

In [58]:
cost_per_test_evaluation(examples_pool, test)

'$17.85'
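
The estimate is simple arithmetic: total tokens times the per-token price. A minimal sketch, assuming the $0.02 per 1,000 tokens rate from `cost_per_prompt` also applies to completion tokens; `estimate_cost` is a hypothetical helper, and the 500 prompts at 1,800 tokens each are round illustrative numbers, not measurements from this notebook.

```python
def estimate_cost(
    num_prompts: int,
    avg_prompt_tokens: float,
    max_completion_tokens: int = 0,
    usd_per_1k_tokens: float = 0.02,
) -> float:
    """Estimated cost in USD; an upper bound when max_completion_tokens > 0."""
    # Completions are capped at max_tokens, so adding the cap bounds the total from above.
    total_tokens = num_prompts * (avg_prompt_tokens + max_completion_tokens)
    return total_tokens * usd_per_1k_tokens / 1000


# Hypothetical figures: 500 prompts averaging 1,800 prompt tokens each.
prompt_only = estimate_cost(500, 1800)                                  # 900,000 tokens
with_completions = estimate_cost(500, 1800, max_completion_tokens=50)  # 925,000 tokens
print(f"${prompt_only:,.2f} prompt-only, at most ${with_completions:,.2f} with completions")
```

Because `get_response` sets `max_tokens=50`, completions add at most 50 tokens per request, which is small next to a 50-shot prompt; the prompt dominates the bill.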

# Baseline Model Performance
Running each test example through the LLM with the 50-shot prompt shown above gives an accuracy of 59.6%, which isn't bad for a 50-class problem. Let's take a closer look at the examples pool and see if we can improve on it!

In [454]:
baseline_preds = eval_prompt(examples_pool, test)
# If you'd like to save the predictions.
np.save('baseline_preds.npy', np.array(baseline_preds))
eval_preds(baseline_preds)

'Model Accuracy: 59.6%'

# Issues in Our Data
A closer inspection of the pool of few-shot examples reveals label errors and outliers. Here are a few of each.

### Label Issues

In [66]:
label_idx = [4013, 3180, 6737]
issues = examples_pool[examples_pool.id.isin(label_idx)]
issues[['text', 'label']]

| | text | label |
|---|---|---|
| 152 | how much does it cost to get more cards? | failed_transfer |
| 182 | i may need to dispute a direct debit payment. | getting_spare_card |
| 1289 | should i contact customer support if i can't edit details | card_about_to_expire |


### Outliers


In [67]:
# IDs of a few outliers in the examples pool.
outlier_idx = [13295, 13275, 13284]

examples_pool[examples_pool.id.isin(outlier_idx)][['text', 'label']]

| | text | label |
|---|---|---|
| 72 | not (A and B and C) | disposable_card_limits |
| 441 | 7ZFGXIBX26Q3TXS4W8VA06ZZLA9GK9I | card_arrival |
| 729 | Add eggs one at a time, beating well after each addition. | visa_or_mastercard |


# Why Do These Issues Matter?

As LLM context windows grow, it is becoming increasingly common for prompts to include many examples from each class. In many applications it may not be possible to hand-select the examples used in a few-shot prompt, especially as the number of classes grows. As a result, if your examples pool contains issues like those shown above, randomly selecting examples from each class may let these errant examples find their way into the prompt. The rest of this article shows how much this decreases model performance.
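
A quick back-of-the-envelope calculation shows why random selection is risky at this scale. The sketch below assumes every class has the same fraction of mislabeled examples and that each class's example is drawn independently; the 1% and 5% noise rates are hypothetical, since the actual error rate of this pool is not reported. Under those assumptions, a 50-shot prompt contains at least one bad example roughly 39% and 92% of the time, respectively.

```python
def prob_at_least_one_noisy(noise_rate: float, num_shots: int) -> float:
    """Chance that at least one of num_shots sampled examples is mislabeled."""
    # Assumes a uniform noise rate and independent draws:
    # all shots are clean with probability (1 - noise_rate) ** num_shots.
    return 1 - (1 - noise_rate) ** num_shots


# Hypothetical noise rates for a 50-shot prompt.
for rate in (0.01, 0.05):
    print(f"{rate:.0%} noise, 50 shots: {prob_at_least_one_noisy(rate, 50):.0%}")
```

The probability grows with the number of shots, so longer context windows make clean examples more, not less, important.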

# Can we tell the LLM the examples are noisy?

Instead of modifying the few-shot examples, what if we just include a “disclaimer warning” in the prompt telling the LLM that the labels may be incorrect?

In [73]:
# Add warning to the beginning of our prompt.
prefix = "Beware that some labels in the examples may be noisy and have been incorrectly specified."
preds_with_disclaimer = eval_prompt(examples_pool, test, prefix)
# If you'd like to save the predictions.
np.save('preds_with_disclaimer.npy', np.array(preds_with_disclaimer))
eval_preds(preds_with_disclaimer)

'Model Accuracy: 62.0%'

With the disclaimer, accuracy rises from 59.6% to 62%: marginally better, but there is still plenty of room to improve.

# Can we remove the noisy examples entirely?

Since we can't trust the labels in the examples pool, what if we drop the few-shot examples from the prompt entirely and rely only on the LLM itself (zero-shot prompting, keeping just the list of valid labels)?

In [72]:
# Set use_examples to False so that none of the
# few-shot examples are added to the prompt.
preds_with_no_examples = eval_prompt(examples_pool, test, use_examples=False)
# If you'd like to save the predictions.
np.save('preds_with_no_examples.npy', np.array(preds_with_no_examples))
eval_preds(preds_with_no_examples)

'Model Accuracy: 67.4%'

# Can we identify and correct the noisy examples?

Zero-shot prompting scores 67.4%, higher than both the baseline and the disclaimer prompt, which suggests the noisy examples were actively hurting performance. Instead of modifying the prompt or removing the examples entirely, a smarter (if more involved) way to improve results is to find and fix the label issues in the examples pool. Each fix simultaneously removes a noisy example that is harming the model and adds an accurate one that should help it via few-shot prompting, but making such corrections manually is cumbersome. Here we instead correct the data automatically using [Cleanlab Studio](https://cleanlab.ai/studio/), a platform that implements [Confident Learning](https://l7.curtisnorthcutt.com/confident-learning) algorithms to find and fix label issues.
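
To make the Confident Learning idea concrete, here is a heavily simplified pure-Python sketch of its core step: compute a per-class confidence threshold, then flag examples whose confidently predicted class differs from their given label (the off-diagonal entries of what Confident Learning calls the confident joint). It assumes out-of-sample predicted probabilities `pred_probs` from some classifier trained on the pool, e.g. via cross-validation. `find_label_issues_sketch` is hypothetical and is not how Cleanlab Studio is implemented; the open-source `cleanlab` package provides a full implementation as `cleanlab.filter.find_label_issues`.

```python
from typing import List, Sequence


def find_label_issues_sketch(
    labels: Sequence[int], pred_probs: Sequence[Sequence[float]]
) -> List[int]:
    """Return indices whose given label disagrees with a confidently predicted one."""
    num_classes = len(pred_probs[0])
    # Per-class threshold: mean predicted probability of class j among examples labeled j.
    thresholds = []
    for j in range(num_classes):
        probs_j = [p[j] for p, y in zip(pred_probs, labels) if y == j]
        thresholds.append(sum(probs_j) / len(probs_j) if probs_j else float("inf"))

    issues = []
    for i, (p, y) in enumerate(zip(pred_probs, labels)):
        # Classes whose probability clears that class's own threshold.
        confident = [j for j in range(num_classes) if p[j] >= thresholds[j]]
        if not confident:
            continue  # no confident guess, so the example is not flagged
        # The most likely confident class; disagreement with the given label is an issue.
        guess = max(confident, key=lambda j: p[j])
        if guess != y:
            issues.append(i)
    return issues


# Toy data: example 2 is labeled 0, but the (hypothetical) model is confident it is class 1.
labels = [0, 0, 0, 1, 1, 1]
pred_probs = [[0.9, 0.1], [0.8, 0.2], [0.1, 0.9], [0.2, 0.8], [0.1, 0.9], [0.4, 0.6]]
print(find_label_issues_sketch(labels, pred_probs))  # [2]
```

Per-class thresholds matter because models are often systematically more or less confident on some classes; a single global cutoff would over-flag the hard classes.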

After replacing the bad labels with more suitable ones using Cleanlab Studio, we re-run the original 50-shot prompt on each test example, except that this time we draw the examples from the pool with **corrected labels**, so the LLM sees 50 higher-quality examples.

In [78]:
# Load in examples pool that has been corrected using Cleanlab Studio.
examples_pool_studio = pd.read_csv('https://s.cleanlab.ai/banking-intent-50/examples-pool-studio.csv')

In [451]:
preds_corrected_labels = eval_prompt(examples_pool_studio, test)
np.save('preds_corrected_labels.npy', np.array(preds_corrected_labels))
eval_preds(preds_corrected_labels)

'Model Accuracy: 72.0%'

With the corrected examples, accuracy rises to 72%, which is **quite impressive** for a 50-class problem.
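
To put the results side by side, the arithmetic below compares each accuracy reported in this notebook with the 59.6% baseline. Besides raw percentage points, it reports the share of the baseline's errors that each approach removes: the corrected pool removes 12.4 of the baseline's 40.4 error points, about 30.7%. The helper name `relative_error_reduction` is ours.

```python
def relative_error_reduction(baseline_acc: float, new_acc: float) -> float:
    """Fraction of the baseline's errors removed, with accuracies given in percent."""
    return (new_acc - baseline_acc) / (100.0 - baseline_acc)


baseline = 59.6  # noisy 50-shot prompt
# Test-set accuracies (%) reported earlier in this notebook.
results = {"disclaimer prefix": 62.0, "zero-shot": 67.4, "corrected 50-shot": 72.0}
for name, acc in results.items():
    gain = acc - baseline
    print(f"{name}: {gain:+.1f} points, {relative_error_reduction(baseline, acc):.1%} of baseline errors removed")
```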

**We've now shown that noisy few-shot examples can considerably decrease LLM performance, and that simply changing the prompt (by adding warnings or removing examples) is not enough on its own. To achieve the highest performance, you should also try correcting your examples with data-centric AI tools like Cleanlab Studio.**

## Resources
- [Cleanlab Studio](https://cleanlab.ai)
- [GitHub](https://github.com/cleanlab/cleanlab)
- [Cleanlab Community](https://cleanlab.ai/slack)