In [1]:
from datasets import load_dataset
from openai import OpenAI
import json

import csv
import pandas as pd
import seaborn as sns
from tqdm import tqdm
import matplotlib.pyplot as plt


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
client = OpenAI(
    api_key="sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
)

def generate_cluster(selected_texts, instruction):
  response = client.chat.completions.create(
    model="gpt-4o",
    messages=[
      {"role": "system", "content": instruction},
      {"role": "user", "content": f"{list(selected_texts)}"}
    ]
  )

  result = response.choices[0].message.content
  return result

In [3]:
def extract_final_answer(answer: str):       
    if not answer:
        return "<INVALID>"

    model_pred = answer.lower()
    preds = model_pred.split("<ans_start>")

    pred = preds[-1].split("<ans_end>")[0].strip()

    if len(pred) == 0:
        return "<INVALID>"

    return pred

In [33]:
# Load data from banking77_sent2label.json
with open('dataset/banking77_sent2label.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

# Extract texts and cluster labels
texts = list(data.keys())
cluster_labels = list(data.values())

# Print the extracted texts and cluster labels
print(texts)
print(cluster_labels)

['My account shows the card payment cancelled.', 'what happens to the funds if a merchant refuses the payment', 'There must be an issue, why has my card been cancelled?', 'What do I do if it says my card payment has been cancelled?', 'I did a payment but the app reverted it', 'I think my card payment had been return', 'I was contacted by a seller with a message that they never received my money. I am 100% sure it was taken from my account so I definitely need this sorted out soon.', 'If my card payment is cancelled, what should I do?', 'My credit card cancelled a payment for a purchase.', 'What is the reason that my card payment was cancelled?', 'There was a canceled payment for my card.', 'My card is being declined for a purchase. I bought items before and the card worked. Do you know what the problem is?', 'I have a payment listed as cancelled.', 'My card payment was reverted.  Why?', 'Why was my payment reversed?', 'My transaction to pay for an item was returned to my account.', "Wh

In [40]:
n = len(texts)
k = len(set(cluster_labels))

In [42]:
# Load prompt
with open('prompt_template.json', 'r', encoding="utf-8") as file:
    prompt_template = json.load(file)

for prompt in prompt_template.keys():
    prompt_template[prompt] = prompt_template[prompt].replace("{n}", str(n)).replace("{k}", str(k))

prompt_template

{'vanilla': "You are given a dataset of 240 sentences which you need to cluster into one of the 8 clusters. Output exactly 240 cluster labels.\nFor each sentence, assignment it to one of the 8 cluster label and output the cluster number. Your output should ONLY contain a list of 240 integers in the format <ANS_START>[cluster asignments]<ANS_END>. Do not include any other texts.\n  \nExample:\nInput Sentences: ['sentence1', 'sentence2', 'sentence3']\nOutput Labels: [1, 0, 2]\n",
 'fewshot': "You are given a dataset of 240 sentences which you need to cluster into one of the 8 clusters. Output exactly 240 cluster labels.\nFor each sentence, assignment it to one of the 8 cluster label and output the cluster number. Your output should ONLY contain a list of 240 integers in the format <ANS_START>[cluster asignments]<ANS_END>. Do not include any other texts.\n  \nExample:\nInput Sentences: ['sentence1', 'sentence2', 'sentence3']\nOutput Labels: [1, 0, 2]\n \n\n\n[Question] ['create a playlist

In [52]:
for prompt in prompt_template.keys():
    results = []
    instruction = prompt_template[prompt]
    print(f"#### Running with prompt - {prompt}\n")
    with open(f'clustering_result/banking77_prompt/prompting_results_{prompt}_banking77.csv', 'a', newline='') as csvfile:
        writer = csv.writer(csvfile)
        for i in tqdm(range(0, 50)):
            try:
                result = generate_cluster(texts, instruction)
            except Exception as e:
                print(f"GPT Error: {e}")
            try:
                processed_result = extract_final_answer(result)         # Extract the final answer from the result
            except:
                print("INVALID OUTPUT")
                print(result)
                break
            label_count = len(processed_result[1:-1].split(", "))         # Count the number of labels in the processed result
            writer.writerow([i, label_count, processed_result])
            results.append({'Iteration': i, 'Label Count': label_count, 'Processed Result': processed_result})

    # Convert the results to a DataFrame
    # df_results = pd.DataFrame(results)
    # df_label_counts = pd.read_csv('prompting_label_counts_banking77.csv')
    # df_label_counts[f"{prompt}"] = df_results["Label Count"]

    # df_label_counts.to_csv('prompting_label_counts_banking77.csv', index=False)

#### Running with prompt - vanilla



  0%|          | 0/1 [00:00<?, ?it/s]

In [27]:
bank77_count = pd.read_csv("clustering_result/count_statistics/prompting_label_counts_banking77.csv")

for prompt in prompt_template.keys():
    df = pd.read_csv(f'clustering_result/banking77_prompt/prompting_results_{prompt}_banking77.csv')
    df.columns = ['Index', 'Label Count', 'Cluster Assignment']
    for i, row in df.iterrows():
        df.at[i, 'Index'] = int(i)+1
        cluster = [x.strip() for x in row[2][1:-1].split(",")]
        df.at[i, 'Label Count'] = int(len(cluster))
    df['Label Count'] = df['Label Count'].astype(int)
    df.to_csv(f'clustering_result/banking77_prompt/prompting_results_{prompt}_banking77.csv', index=None)
    bank77_count[f"{prompt}"] = df['Label Count']


bank77_count.to_csv("clustering_result/count_statistics/prompting_label_counts_banking77.csv", index=None)

  cluster = [x.strip() for x in row[2][1:-1].split(",")]
  cluster = [x.strip() for x in row[2][1:-1].split(",")]
  cluster = [x.strip() for x in row[2][1:-1].split(",")]
  cluster = [x.strip() for x in row[2][1:-1].split(",")]
  cluster = [x.strip() for x in row[2][1:-1].split(",")]


In [28]:
threshold = 240

counts = {}

for col in bank77_count.columns:
    counts[col] = {
        'less than': (bank77_count[col] < threshold).sum(),
        'equal to': (bank77_count[col] == threshold).sum(),
        'greater than': (bank77_count[col] > threshold).sum()
    }

counts_df = pd.DataFrame(counts).T
counts_df

Unnamed: 0,less than,equal to,greater than
vanilla,6,0,44
cot,13,2,35
fewshot,12,2,36
pw_wo_reasoning,6,3,41
pw_w_reasoning,5,0,45


In [25]:
for prompt in prompt_template.keys():
    df = pd.read_csv(f'clustering_result/banking77_prompt/prompting_results_{prompt}_banking77.csv')
    df.columns = ['Index', 'Label Count', 'Cluster Assignment']
    counter = 0
    for i, row in df.iterrows():
        labels = list(map(int, row[2][1:-1].split(", ")))
        if len(set(labels)) < 8:
            counter += 1
    print(f"{prompt}: {counter}")

vanilla: 17
fewshot: 44
cot: 15
pw_wo_reasoning: 25
pw_w_reasoning: 24


  labels = list(map(int, row[2][1:-1].split(", ")))
  labels = list(map(int, row[2][1:-1].split(", ")))
  labels = list(map(int, row[2][1:-1].split(", ")))
  labels = list(map(int, row[2][1:-1].split(", ")))
  labels = list(map(int, row[2][1:-1].split(", ")))


In [47]:
from sklearn.metrics.cluster import normalized_mutual_info_score

for prompt in prompt_template.keys():
    df = pd.read_csv(f'clustering_result/banking77_prompt/prompting_results_{prompt}_banking77.csv')
    df.columns = ['Index', 'Label Count', 'Cluster Assignment']
    for i, row in df.iterrows():
        if row["Label Count"] == n:
            labels = list(map(int, row[2][1:-1].split(", ")))
            nmi = normalized_mutual_info_score(labels, cluster_labels)
            print(f"{prompt}: {nmi}")

fewshot: 0.0
fewshot: 0.8576856875185628
cot: 0.6918230495594236
cot: 0.7631839406475999
pw_wo_reasoning: 0.7603648085413857
pw_wo_reasoning: 0.7323436932792137
pw_wo_reasoning: 0.7504005902133457


  labels = list(map(int, row[2][1:-1].split(", ")))
  labels = list(map(int, row[2][1:-1].split(", ")))
  labels = list(map(int, row[2][1:-1].split(", ")))
  labels = list(map(int, row[2][1:-1].split(", ")))
  labels = list(map(int, row[2][1:-1].split(", ")))
  labels = list(map(int, row[2][1:-1].split(", ")))
  labels = list(map(int, row[2][1:-1].split(", ")))
