You must run this notebook on a GPU. A T4 is sufficient. It's free on [Google
Colab](https://stackoverflow.com/questions/62596466/how-can-i-run-notebooks-of-a-github-project-in-google-colab/67344477#67344477).

This notebook demos a [GPTQ-quantized StableLM
3B](https://huggingface.co/ethzanalytics/stablelm-tuned-alpha-3b-gptq-4bit-128g) on a
classification task, first using CAPPr and then via sampling.

In [None]:
# check correct CUDA version
import torch

_cuda_version = torch.version.cuda
_msg = (
    "Change the pip install auto-gptq command to the one for "
    f"{_cuda_version} based on the list here: "
    "https://github.com/PanQiWei/AutoGPTQ#quick-installation"
)

assert _cuda_version == "11.8", _msg

In [None]:
!python -m pip install "cappr[demos] @ git+https://github.com/kddubey/cappr.git" \
auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ \
optimum

In [None]:
!git lfs install
!git clone https://huggingface.co/ethzanalytics/stablelm-tuned-alpha-3b-gptq-4bit-128g

Git LFS initialized.
Cloning into 'stablelm-tuned-alpha-3b-gptq-4bit-128g'...
remote: Enumerating objects: 23, done.
remote: Total 23 (delta 0), reused 0 (delta 0), pack-reused 23
Unpacking objects: 100% (23/23), 593.64 KiB | 6.06 MiB/s, done.


In [None]:
!ls stablelm-tuned-alpha-3b-gptq-4bit-128g

config.json			  quantize_config.json	   tokenizer_config.json
generation_config.json		  README.md		   tokenizer.json
gptq_model-4bit-128g.safetensors  special_tokens_map.json


In [None]:
from __future__ import annotations

from auto_gptq import AutoGPTQForCausalLM
import datasets as nlp_datasets
import pandas as pd
from sklearn.metrics import f1_score
from tqdm.auto import tqdm
from transformers import AutoTokenizer, GenerationConfig, pipeline

from cappr.huggingface.classify import predict_proba

# Load data

In [None]:
df = pd.DataFrame(nlp_datasets.load_dataset("ought/raft", "banking_77", split="train"))

In [None]:
stablelm_chat_template = """
<|SYSTEM|># {system_prompt}
<|USER|>{user_message}<|ASSISTANT|>
""".strip("\n")

In [None]:
def prompt(query: str) -> str:
    system_prompt = (
        "Summarize an inputted banking customer service query in a few words."
    )
    user_message = f'Query: "{query}"\nSummary:'
    return stablelm_chat_template.format(
        system_prompt=system_prompt, user_message=user_message
    )

In [None]:
df["prompt"] = [prompt(query) for query in df["Query"]]

In [None]:
original_class_names = ["Refund_not_showing_up", "activate_my_card", "age_limit", "apple_pay_or_google_pay", "atm_support", "automatic_top_up", "balance_not_updated_after_bank_transfer", "balance_not_updated_after_cheque_or_cash_deposit", "beneficiary_not_allowed", "cancel_transfer", "card_about_to_expire", "card_acceptance", "card_arrival", "card_delivery_estimate", "card_linking", "card_not_working", "card_payment_fee_charged", "card_payment_not_recognised", "card_payment_wrong_exchange_rate", "card_swallowed", "cash_withdrawal_charge", "cash_withdrawal_not_recognised", "change_pin", "compromised_card", "contactless_not_working", "country_support", "declined_card_payment", "declined_cash_withdrawal", "declined_transfer", "direct_debit_payment_not_recognised", "disposable_card_limits", "edit_personal_details", "exchange_charge", "exchange_rate", "exchange_via_app", "extra_charge_on_statement", "failed_transfer", "fiat_currency_support", "get_disposable_virtual_card", "get_physical_card", "getting_spare_card", "getting_virtual_card", "lost_or_stolen_card", "lost_or_stolen_phone", "order_physical_card", "passcode_forgotten", "pending_card_payment", "pending_cash_withdrawal", "pending_top_up", "pending_transfer", "pin_blocked", "receiving_money", "request_refund", "reverted_card_payment?", "supported_cards_and_currencies", "terminate_account", "top_up_by_bank_transfer_charge", "top_up_by_card_charge", "top_up_by_cash_or_cheque", "top_up_failed", "top_up_limits", "top_up_reverted", "topping_up_by_card", "transaction_charged_twice", "transfer_fee_charged", "transfer_into_account", "transfer_not_received_by_recipient", "transfer_timing", "unable_to_verify_identity", "verify_my_identity", "verify_source_of_funds", "verify_top_up", "virtual_card_not_working", "visa_or_mastercard", "why_verify_identity", "wrong_amount_of_cash_received", "wrong_exchange_rate_for_cash_withdrawal"]

In [None]:
class_names = [
    " ".join(class_name.split("_")).capitalize() for class_name in original_class_names
]

In [None]:
df["class_name"] = [class_names[label_idx - 1] for label_idx in df["Label"]]
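
The two cells above turn the dataset's snake_case label names into natural-language completions and index into them with `Label - 1`. Here's a small self-contained sketch of the same transformation. The three names are copied from `original_class_names`; the labels are made up for illustration.

```python
# Three names copied from original_class_names (a subset, for illustration).
original = ["Refund_not_showing_up", "activate_my_card", "change_pin"]

# Same transformation as the notebook: underscores -> spaces, then capitalize.
# str.capitalize upper-cases the first character and lower-cases the rest.
readable = [" ".join(name.split("_")).capitalize() for name in original]
print(readable)

# Hypothetical 1-indexed labels, shifted to 0-based list indices.
labels = [3, 1]
mapped = [readable[label - 1] for label in labels]
print(mapped)
```

Note that `capitalize` lower-cases everything after the first character, so any inner capital letters in a class name would be lost.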

In [None]:
print(df["prompt"].iloc[0])

<|SYSTEM|># Summarize an inputted banking customer service query in a few words.
<|USER|>Query: "Is it possible for me to change my PIN number?"
Summary:<|ASSISTANT|>


In [None]:
print(df["class_name"].iloc[0])

Change pin


# Load model

In [None]:
quantized_model_dir = "stablelm-tuned-alpha-3b-gptq-4bit-128g"
model = AutoGPTQForCausalLM.from_quantized(
    quantized_model_dir, use_triton=False, use_safetensors=True
)
tokenizer = AutoTokenizer.from_pretrained("StabilityAI/stablelm-tuned-alpha-7b")



In [None]:
# warm up model
_ = model(**tokenizer(["warm up"], return_tensors="pt").to(model.device))

# CAPPr

In [None]:
pred_probs = predict_proba(
    prompts=df["prompt"],
    completions=class_names,
    model_and_tokenizer=(model, tokenizer),
    batch_size=1,
)

conditional log-probs:   0%|          | 0/50 [00:00<?, ?it/s]
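
For intuition, here's a rough sketch of the kind of computation behind `predict_proba`: score each completion by an aggregate of its token log-probabilities given the prompt, then normalize across completions. The averaging step and the toy log-probabilities are assumptions for illustration, not CAPPr's actual implementation or real model outputs.

```python
import math


def completion_probs(token_log_probs_per_completion):
    """Average each completion's token log-probs, exponentiate, normalize."""
    averages = [sum(lps) / len(lps) for lps in token_log_probs_per_completion]
    scores = [math.exp(avg) for avg in averages]
    total = sum(scores)
    return [score / total for score in scores]


# Hypothetical values of log Pr(token | prompt, previous completion tokens).
toy = {
    "Change pin": [-0.5, -0.7],
    "Card arrival": [-2.0, -1.5],
    "Pin blocked": [-1.8, -0.9, -1.1],
}
probs = completion_probs(list(toy.values()))

# The predicted class is the completion with the highest probability.
prediction = max(zip(toy, probs), key=lambda pair: pair[1])[0]
print(dict(zip(toy, probs)), prediction)
```

Averaging rather than summing keeps completions with more tokens from being penalized just for being longer.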

In [None]:
f1_score(df["Label"] - 1, pred_probs.argmax(axis=1), average="macro")

0.13594104308390023
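
Macro F1 is the unweighted mean of per-class F1 scores, so every class counts equally no matter how often it occurs. Here's a minimal pure-Python sketch on made-up labels. It averages over classes that appear in either the true or predicted labels; I believe that matches scikit-learn's default behavior, but treat it as an assumption.

```python
def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 over classes seen in y_true or y_pred."""
    classes = sorted(set(y_true) | set(y_pred))
    f1_scores = []
    for c in classes:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denominator = 2 * tp + fp + fn
        # F1 = 2*TP / (2*TP + FP + FN); guard against an empty denominator.
        f1_scores.append(2 * tp / denominator if denominator else 0.0)
    return sum(f1_scores) / len(f1_scores)


# Made-up labels for illustration.
y_true = [0, 0, 1, 2]
y_pred = [0, 1, 1, 1]
print(macro_f1(y_true, y_pred))
```

A class that is never predicted correctly contributes an F1 of 0, which drags the macro average down quickly when there are many classes and few examples.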

Accuracy:

In [None]:
((df["Label"] - 1) == pred_probs.argmax(axis=1)).mean()

0.18

Hey, it could be worse. What's the majority-class accuracy? Forgot to check.

In [None]:
df["Label"].value_counts(normalize=True).iloc[0]

0.06
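
`value_counts(normalize=True).iloc[0]` is the share of the most frequent label, which is the accuracy you'd get by always predicting that label. The same arithmetic on a made-up label list:

```python
from collections import Counter

# Hypothetical labels, for illustration only.
labels = [23, 5, 23, 41, 23, 5, 77, 12, 60, 60]

# The majority-class baseline always predicts the most frequent label.
most_common_label, count = Counter(labels).most_common(1)[0]
majority_accuracy = count / len(labels)
print(most_common_label, majority_accuracy)
```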

In [None]:
pd.Series(pred_probs.argmax(axis=1)).value_counts(normalize=True)

58    0.32
7     0.26
73    0.08
46    0.04
26    0.04
68    0.04
69    0.04
76    0.04
15    0.02
6     0.02
60    0.02
66    0.02
0     0.02
70    0.02
75    0.02
dtype: float64

That's not good: two classes (58 and 7) account for 58% of the predictions.

In [None]:
class_names[7]

'Balance not updated after cheque or cash deposit'

# Text generation

In [None]:
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

generation_config = GenerationConfig(
    max_new_tokens=20,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
    batch_size=1,
)

The model 'GPTNeoXGPTQForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'LlamaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MistralForCausalLM', 'MptForCausalLM', 'MusicgenForCausalLM', 'MvpForCausalLM', 'OpenLlamaForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM', 'PegasusForCausalLM', 'PersimmonForCausalLM', 'PLBartForCausalLM', 'Prop

Prompt stolen from: https://github.com/refuel-ai/autolabel/blob/main/examples/banking/config_banking.json

In [None]:
def prompt_text_gen(query: str) -> str:
    class_names_str = "\n".join(sorted(original_class_names))
    system_prompt = (
        "You are an expert at understanding bank customers support complaints and "
        "queries.\n"
        "Your job is to correctly categorize an inputted customer query or complaint "
        "into one of the following categories.\n"
        "Categories:\n"
        f"{class_names_str}\n\n"
        "You will answer with just the correct category and nothing else."
    )
    user_message = f"Categorize the following query:\n{query}"
    return stablelm_chat_template.format(
        system_prompt=system_prompt, user_message=user_message
    )

df["prompt_text_gen"] = [prompt_text_gen(query) for query in df["Query"]]
print(df["prompt_text_gen"].iloc[0])

<|SYSTEM|># You are an expert at understanding bank customers support complaints and queries.
Your job is to correctly categorize an inputted customer query or complaint into one of the following categories.
Categories:
Refund_not_showing_up
activate_my_card
age_limit
apple_pay_or_google_pay
atm_support
automatic_top_up
balance_not_updated_after_bank_transfer
balance_not_updated_after_cheque_or_cash_deposit
beneficiary_not_allowed
cancel_transfer
card_about_to_expire
card_acceptance
card_arrival
card_delivery_estimate
card_linking
card_not_working
card_payment_fee_charged
card_payment_not_recognised
card_payment_wrong_exchange_rate
card_swallowed
cash_withdrawal_charge
cash_withdrawal_not_recognised
change_pin
compromised_card
contactless_not_working
country_support
declined_card_payment
declined_cash_withdrawal
declined_transfer
direct_debit_payment_not_recognised
disposable_card_limits
edit_personal_details
exchange_charge
exchange_rate
exchange_via_app
extra_charge_on_statement
fail

In [None]:
completions = []
for _prompt in tqdm(df["prompt_text_gen"], total=len(df), desc="Sampling"):
    sequences = generator(
        _prompt,
        generation_config=generation_config,
        pad_token_id=generator.tokenizer.eos_token_id,  # suppress "Setting ..."
    )
    completions.append(sequences[0]["generated_text"].removeprefix(_prompt))

Sampling:   0%|          | 0/50 [00:00<?, ?it/s]



Let's see if the model generated categories like we asked.

In [None]:
pd.Series(completions).sample(20).tolist()

['The following countries are in the "Cash or Debit" category:\n\n1. United',
 'Based on the provided query, it seems that the payment was charged a fee because the payment was made',
 'Based on the given query, it is not possible to categorize it as a refund or not showing',
 'Based on the given query, it is not clear what the $1 transaction refers to. Please provide',
 'Based on the given query, it seems that the issue is related to the payment method or card type',
 'The physical card will be delivered within 1-3 business days.',
 'Based on the provided query, it seems that the customer is waiting for a transfer to be charged.',
 'Category: Refund\n\nYou will answer with just the correct category and nothing else.C',
 'Based on the provided query, it seems that the card payment was not processed because it was not recognized',
 'Based on the provided query, it seems that the issue is with the card not working. Here are',
 'Based on the provided query, it seems that the card has not 

Doesn't seem like it. The model tends to answer the customer's query instead of categorizing it. I'm not going to attempt to parse these completions.
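
To illustrate why parsing is unappealing, here's a sketch of a naive parser that looks for a category name inside a completion. The `parse_category` helper is hypothetical. The two completions are copied from the sample above, and the category list is a small subset of `original_class_names`.

```python
from typing import Optional, Sequence

# Subset of original_class_names, for illustration.
categories = ["Refund_not_showing_up", "request_refund", "get_physical_card"]


def parse_category(completion: str, categories: Sequence[str]) -> Optional[str]:
    """Return the first category whose name appears in the completion, if any."""
    normalized = completion.lower()
    for category in categories:
        if category.lower() in normalized:
            return category
    return None


# Two completions copied from the sample shown above.
completions = [
    "The physical card will be delivered within 1-3 business days.",
    "Category: Refund\n\nYou will answer with just the correct category and nothing else.C",
]
parsed = [parse_category(completion, categories) for completion in completions]
print(parsed)
```

Neither completion contains a category name verbatim, so the parser returns `None` for both.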