<a href="https://colab.research.google.com/github/nathan-barry/ai2-cartography-reimplementation/blob/main/gpt_mislabel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install langchain openai

In [None]:
import os
from getpass import getpass

# Keep credentials out of the notebook source: reuse a key that is already set in
# the environment, otherwise prompt for it without echoing it into the cell output.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API key: ")

In [None]:
from langchain.llms import OpenAI
from langchain import PromptTemplate, LLMChain

In [None]:
# Mount Google Drive at /content/drive so the CSV files under MyDrive/data_arrays
# can be read and the results written back.
# Colab-only: relies on the `google.colab` package.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import pandas as pd
# Read the "hardest examples" table from the mounted Drive folder.
hardest_df = pd.read_csv(
    filepath_or_buffer='/content/drive/MyDrive/data_arrays/hardest_examples.csv',
)

In [None]:
# Read the "easiest examples" table from the mounted Drive folder.
easiest_df = pd.read_csv(
    filepath_or_buffer='/content/drive/MyDrive/data_arrays/easiest_examples.csv',
)

In [None]:
# Preview the first 10 rows of the split being audited.
# Uncomment the next line (and comment out the one after it) to preview the
# hardest examples instead.
# hardest_df.head(10)
easiest_df.head(10)
# Label encoding (the same mapping is stated in the prompt template):
# 0: entailment
# 1: neutral
# 2: contradiction

In [None]:
# Strip the confidence/variability columns so only the remaining fields are kept.
# To audit the hardest examples instead, build `examples_df` from `hardest_df`.
examples_df = easiest_df.drop(["confidence", "variability"], axis=1)

# One dict per row, keyed by column name.
examples = examples_df.to_dict(orient='records')

In [None]:
# LLM client shared by every evaluation request; temperature is set to 0.2,
# presumably to keep the answers consistent between runs.
# NOTE(review): no model name is passed, so the library default applies — confirm
# which model that is, since the cost estimate further down depends on its pricing.
llm = OpenAI(temperature=0.2)

In [None]:
# Instruction prompt sent once per example; `{snli_example}` is replaced by the
# string form of the example dict. The answer format requested here
# ("Mislabeled:"/"Correct:" prefix and a trailing "Correct Label: <n>") is what
# `parse_output` matches with its regular expressions, so keep the two in sync.
template = """Evaluate the following potentially mislabeled example from the SNLI dataset.
Each example has a premise, hypothesis, and label (0=entailment, 1=neutral, 2=contradiction).
Determine if it's mislabeled or correct, explain why, and provide the correct label if mislabeled.
Start your answer with either "Mislabeled:" or "Correct:", followed by your reasoning. 
End with "Correct Label: <label number>".

Example:
{snli_example}

Answer:"""

# Bind the template text to its single input variable, `snli_example`.
prompt = PromptTemplate(template=template, input_variables=["snli_example"])

# Chain pairing the prompt with `llm`; the evaluation loop calls `llm_chain.run(...)`
# with one example at a time.
llm_chain = LLMChain(prompt=prompt, llm=llm)

In [None]:
# Measure how much text will be sent. The "index" column is excluded because the
# evaluation loop removes it from each example before building the prompt.
count_df = examples_df.drop("index", axis=1)
counts = count_df.to_dict(orient='records')

# Inputs for the cost estimate: number of requests, characters of example text
# (measured on the string form of the whole list), and characters of template.
num_examples = len(counts)
text_char_cnt = len(str(counts))
template_char_cnt = len(template)

print(template_char_cnt, num_examples, text_char_cnt, sep="\n")

In [None]:
# Back-of-the-envelope prompt cost: every request carries the full template plus
# its example text, at roughly 4 characters per token and a rate of 0.002 per
# 1,000 tokens (presumably USD — confirm against the pricing of the model in use).
# Only prompt characters are counted; the generated answers are not included.
cost_per_token = 0.002 / 1000
total_char = text_char_cnt + num_examples * template_char_cnt
token_guess = total_char // 4

print(total_char, token_guess, cost_per_token, token_guess * cost_per_token, sep="\n")

In [None]:
# Sanity check of the rate used in the estimate: the cost of 1,000,000 tokens at
# 0.002 per 1,000 tokens.
1_000_000 / 1000 * .002

In [None]:
import re

def parse_output(output):
    """Parse one LLM answer into its verdict, corrected label and reasoning.

    The answer is expected to start with "Mislabeled:" or "Correct:" (leading
    whitespace allowed), continue with free-text reasoning, and optionally
    contain "Correct Label: <digit>".

    Args:
        output: Raw text returned by the LLM.

    Returns:
        A tuple ``(classification, correct_label, reasoning)``:
        ``classification`` is "Mislabeled" or "Correct", ``correct_label`` is
        the digit after "Correct Label:" as an int (``None`` when absent), and
        ``reasoning`` is the text between the verdict prefix and the
        "Correct Label:" marker (or the end of the answer when there is no
        marker). ``("Parsing Error", None, None)`` is returned when the verdict
        prefix is missing or when parsing raises (e.g. ``output`` is not a
        string); in the latter case the error is also printed.
    """
    try:
        classification_match = re.search(r"^\s*(Mislabeled|Correct):", output)
        if not classification_match:
            return "Parsing Error", None, None
        classification = classification_match.group(1)

        # Everything after the verdict prefix. Slicing by match position (instead
        # of splitting on ":") keeps reasoning that itself contains colons intact.
        remainder = output[classification_match.end():]

        label_match = re.search(r"Correct Label:\s*(\d)", remainder)
        if label_match:
            correct_label = int(label_match.group(1))
            # Reasoning stops where the label marker begins, so no fragment of
            # "Correct Label" leaks into it.
            reasoning = remainder[:label_match.start()].strip()
        else:
            correct_label = None
            reasoning = remainder.strip()

        return classification, correct_label, reasoning
    except Exception as e:
        print(f"Error parsing output: {output}\nError: {e}")
        return "Parsing Error", None, None

In [None]:
# Empty containers filled by the evaluation loop: a frame receiving one row per
# successfully parsed LLM answer, and a list recording the examples whose answer
# could not be parsed.
_RESULT_FIELDS = ["Classification", "Predicted Label", "Reasoning"]
failed_examples = list()
results_df = pd.DataFrame(columns=_RESULT_FIELDS)

In [None]:
# Query the LLM once per example and collect the parsed verdicts.
# Rows are accumulated in a plain list and converted to a DataFrame once, instead
# of concatenating onto `results_df` on every iteration.
result_columns = ["Classification", "Predicted Label", "Reasoning", "Index"]
result_rows = []
failed_examples = []  # positions in `examples` whose answer could not be parsed

try:
    for i, example in enumerate(examples):
        # Read the id without mutating `examples`, so this cell can be re-run
        # (popping "index" raised KeyError on a second run).
        example_id = example["index"]
        # The prompt text is the example without its "index" field.
        prompt_text = str({key: value for key, value in example.items() if key != "index"})

        output = llm_chain.run(prompt_text)
        if (i + 1) % 50 == 0:
            # Periodic progress report with a sample answer.
            print(f"i: {i+1}, output: {output}, example_len: {len(prompt_text)}")
        classification, predicted_label, reasoning = parse_output(output)

        if classification == "Parsing Error":
            failed_examples.append(i)
        else:
            result_rows.append({
                "Classification": classification,
                "Predicted Label": predicted_label,
                "Reasoning": reasoning,
                "Index": example_id,
            })
finally:
    # Build the frame even when the loop is interrupted (API error, manual stop)
    # so the answers collected so far are not lost. The nullable "Int64" dtype
    # keeps labels as integers when some answers carry no label.
    results_df = pd.DataFrame(result_rows, columns=result_columns).astype(
        {"Predicted Label": "Int64"}
    )

In [None]:
# Persist the parsed results to Drive. `index` is left at its default, so the
# frame's row index is written as the first (unnamed) CSV column.
# NOTE(review): the file is named "alpaca-mislabels" although this notebook queries
# the OpenAI LLM — confirm the name is intended.
results_df.to_csv("/content/drive/MyDrive/data_arrays/alpaca-mislabels.csv")

In [None]:
# Show the first input examples for comparison with the parsed results.
examples_df.head()

In [None]:
# Tally the "Mislabeled" vs "Correct" verdicts. Answers that failed to parse are
# not in `results_df`; they are tracked in `failed_examples`.
results_df["Classification"].value_counts()

In [None]:
# Preview the first parsed results.
results_df.head()