<a href="https://colab.research.google.com/github/nathan-barry/ai2-cartography-reimplementation/blob/main/alpaca_mislabel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab setup: install the third-party libraries used by the cells below.
# NOTE(review): every install here is unpinned. Later cells import
# `from langchain import PromptTemplate, LLMChain` and `langchain.llms.HuggingFacePipeline`;
# those import paths may not exist in current langchain releases — confirm and pin a
# compatible langchain version (and pin transformers / bitsandbytes alongside it).
# NOTE(review): `datasets` and `loralib` are not used by any cell visible in this notebook.
# bitsandbytes / accelerate are presumably needed for the 8-bit, device_map='auto' model load.
# !pip -q install git+https://github.com/huggingface/transformers # need to install from github
!pip install -q datasets loralib sentencepiece transformers
!pip -q install bitsandbytes accelerate
!pip -q install langchain

In [None]:
# Show the attached GPU(s); the model is loaded in 8-bit with device_map='auto' further down.
!nvidia-smi

In [None]:
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

from langchain import LLMChain, PromptTemplate
from langchain.llms import HuggingFacePipeline
from transformers import (
    BitsAndBytesConfig,
    GenerationConfig,
    LlamaForCausalLM,
    LlamaTokenizer,
    pipeline,
)

In [None]:
# Attach Google Drive so the input CSV and the results file are reachable
# under /content/drive/MyDrive.
from google.colab import drive

mount_point = '/content/drive'
drive.mount(mount_point)

In [None]:
# Read the flagged ("hardest") examples from the mounted Drive.
csv_path = '/content/drive/MyDrive/data_arrays/hardest_examples.csv'
hardest_df = pd.read_csv(csv_path)

In [None]:
# Label encoding used throughout this notebook:
#   0 = entailment, 1 = neutral, 2 = contradiction
# Display the first ten flagged rows.
hardest_df.iloc[:10]

In [None]:
# The model should see only the example fields, not the confidence/variability
# scores (presumably the statistics used to select these rows).
examples_df = hardest_df.drop(["confidence", "variability"], axis=1)

# One plain dict per row; each still carries its "index" column value.
examples = examples_df.to_dict("records")

In [None]:
# Single checkpoint id so the tokenizer and the model cannot drift apart.
MODEL_ID = "chavinlo/alpaca-native"

tokenizer = LlamaTokenizer.from_pretrained(MODEL_ID)

# Load the weights in 8-bit through an explicit BitsAndBytesConfig. Passing
# `load_in_8bit=True` straight to from_pretrained is deprecated in recent
# transformers releases, and this notebook installs transformers unpinned.
base_model = LlamaForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)

In [None]:
# Seed torch's RNG so the sampled answers are repeatable across runs.
SEED = 42
torch.manual_seed(SEED)

pipe = pipeline(
    "text-generation",
    model=base_model,
    tokenizer=tokenizer,
    # Token budget for the answer only. `max_length` also counted the prompt
    # (instructions + example), so long examples left little or no room for
    # the explanation.
    max_new_tokens=256,
    # temperature / top_p are ignored unless sampling is enabled explicitly.
    do_sample=True,
    temperature=0.6,
    top_p=0.95,
)

# Wrap the transformers pipeline so LangChain chains can call it as an LLM.
local_llm = HuggingFacePipeline(pipeline=pipe)

In [None]:
# Prompt asking the model to judge one flagged SNLI example. The answer must
# begin with "Mislabeled." or "Correct." so it can be parsed, followed by a
# free-text explanation. PromptTemplate and LLMChain come from the imports
# cell at the top of the notebook.
template = """Below are examples from the Stanford Natural Language Inference (SNLI) dataset. These examples have been
flagged as potentially mislabeled. Each example has a premise, hypothesis, and label (0=entailment,
1=neutral, 2=contradiction). Given the example, do you think it is mislabeled or correct? Explain
why. Your answer should start with either "Mislabeled." or "Correct."

### Example:
{snli_example}

Answer:"""

prompt = PromptTemplate(template=template, input_variables=["snli_example"])

llm_chain = LLMChain(prompt=prompt, llm=local_llm)

In [None]:
# Required answer prefix, e.g. " Mislabeled. The premise says ..." -> "Mislabeled".
_CLASSIFICATION_RE = re.compile(r"^\s*(Mislabeled|Correct)\.", re.IGNORECASE)


def parse_output(output):
    """Split a model answer into its verdict and its explanation.

    The prompt asks the model to start with "Mislabeled." or "Correct.";
    everything after that prefix is treated as the reasoning.

    Args:
        output: Raw text returned by the LLM chain.

    Returns:
        A ``(classification, reasoning)`` pair. ``classification`` is
        "Mislabeled" or "Correct" (case-normalised) and ``reasoning`` is the
        stripped text following the prefix (possibly ""). When the prefix is
        missing or ``output`` is not a string, returns
        ``("Parsing Error", None)``. The result always has two items, so
        callers can unpack it identically on success and failure.
    """
    if not isinstance(output, str):
        print(f"Error parsing output: {output!r}\nError: expected str, got {type(output).__name__}")
        return "Parsing Error", None

    match = _CLASSIFICATION_RE.match(output)
    if match is None:
        return "Parsing Error", None

    # Normalise case so "correct." and "Correct." are counted together.
    classification = match.group(1).capitalize()
    # Slice right after the matched "<Verdict>." so the reasoning is exactly
    # the text that follows the verdict.
    reasoning = output[match.end():].strip()
    return classification, reasoning

In [None]:
# Ask the model to judge every flagged example. Parsed answers are collected
# into `results_df` and saved to Drive; positions (in `examples`) of answers
# that could not be parsed are kept in `failed_examples`.
RESULTS_PATH = "/content/drive/MyDrive/data_arrays/alpaca-mislabels.csv"

records = []
failed_examples = []  # positions in `examples` whose model output did not parse

for i, example in enumerate(examples):
    if (i + 1) % 100 == 0:
        print(i + 1)
    # Read "index" without pop(): mutating `examples` made this cell raise
    # KeyError('index') when re-run. The comprehension keeps the remaining keys
    # in their original order, so the prompt text is unchanged.
    example_id = example["index"]
    snli_fields = {key: value for key, value in example.items() if key != "index"}
    output = llm_chain.run(str(snli_fields))
    # `*_` tolerates a parser that returns extra trailing fields on failure.
    classification, reasoning, *_ = parse_output(output)

    if classification == "Parsing Error":
        failed_examples.append(i)
    else:
        records.append({"Index": example_id, "Classification": classification, "Reasoning": reasoning})

# Build the frame once (row-by-row concat is quadratic). Explicit columns keep
# the header even if nothing parsed.
results_df = pd.DataFrame(records, columns=["Index", "Classification", "Reasoning"])
# The path previously began with a stray "'" and pointed at a nonexistent
# relative directory. index=False: "Index" already identifies each row.
results_df.to_csv(RESULTS_PATH, index=False)

In [None]:
# Preview the parsed model judgements (`results_pd` was an undefined name).
results_df.head()