In [None]:
import matplotlib.pyplot as plt
import numpy as np
import re
import random

from time import time
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
from langchain_core.messages import (
    HumanMessage,
    SystemMessage,
)

In [None]:
def map_labels(data):
  """Add a human-readable label name to a single record.

  Looks up ``data["label"]`` in the notebook-level ``label_map`` (defined
  outside this cell) and stores the result under ``data["label_human"]``.

  Args:
      data: Mutable mapping with a ``"label"`` entry; it is modified in place.

  Returns:
      The same ``data`` object, now carrying the ``"label_human"`` entry.
  """
  human_readable = label_map[data["label"]]
  data["label_human"] = human_readable
  return data

In [None]:
def prepare_examples(dataset, key, n_per_label=1):
    """Pick few-shot examples per label and render them as prompt text.

    Args:
        dataset: Iterable of records; each record is indexed with ``"label"``,
            ``"label_human"`` and ``key``.
        key: Name of the record field that holds the sentence text.
        n_per_label: Number of examples to draw for every distinct label.

    Returns:
        A tuple ``(few_shot_examples, examples_text, examples_text_cot)``:
        the sampled records, the plain "Sentence/Sentiment" prompt text, and
        the variant that also carries a "Reasoning" line produced by the
        notebook-level ``get_reasoning`` helper (defined outside this cell).

    Raises:
        ValueError: If a label has fewer than ``n_per_label`` examples (or
            ``n_per_label`` is negative).
    """
    # A private generator seeded with 42 draws exactly what
    # ``random.seed(42)`` + ``random.sample`` would, but leaves the
    # module-level RNG untouched for the rest of the notebook.
    rng = random.Random(42)

    # Group records by label, keeping first-seen label order.
    examples_by_label = {}
    for example in dataset:
        examples_by_label.setdefault(example["label"], []).append(example)

    few_shot_examples = []
    for label, examples in examples_by_label.items():
        if n_per_label > len(examples):
            raise ValueError(
                f"Label {label!r} has only {len(examples)} example(s); "
                f"cannot sample {n_per_label}."
            )
        few_shot_examples.extend(rng.sample(examples, n_per_label))

    examples_text = "\n".join(
        f"Sentence: {example[key]}\n"
        f"Sentiment: {example['label_human']}"
        for example in few_shot_examples
    )

    examples_text_cot = "\n".join(
        f"Sentence: {example[key]}\n"
        f"Reasoning: {get_reasoning(example['label'])}\n"
        f"Sentiment: {example['label_human']}"
        for example in few_shot_examples
    )

    return few_shot_examples, examples_text, examples_text_cot

In [None]:
def evaluate_model(X_test, y_test, fn, examples, log_every=100):
    """Run a prompting function over the test sentences and collect predictions.

    Args:
        X_test: Iterable of sentences to classify.
        y_test: Iterable of true labels, paired with ``X_test`` by position;
            only used for the progress printout.
        fn: Callable invoked as ``fn(sentence, examples)``; its return value is
            handed to ``parse_sentiment`` (defined elsewhere in the notebook).
        examples: Passed through unchanged as the second argument of ``fn``.
        log_every: Print a progress snapshot every this many sentences; a
            falsy value disables the printout.

    Returns:
        List with one parsed sentiment per evaluated sentence, in input order.
    """
    t0 = time()
    # Report progress against the split being evaluated, not an unrelated
    # notebook global; unsized iterables (e.g. generators) show "?".
    total = len(X_test) if hasattr(X_test, "__len__") else "?"
    predictions = []

    for idx, (sentence, true_label) in enumerate(zip(X_test, y_test), 1):
        prediction = fn(sentence, examples)
        sentiment = parse_sentiment(prediction)

        predictions.append(sentiment)

        if log_every and idx % log_every == 0:
            print(f"---\nSentence: {sentence}\nTrue: {true_label}\nPrediction: {sentiment}")
            print(f"Processed {idx}/{total} examples, Time: {time()-t0:.3f}\n---\n")

    return predictions

In [None]:
# def parse_sentiment_reason(response):
#     sentiment_pattern = r"Sentiment:\s*(.*?)\n"
#     reason_pattern = r"Reason:\s*(.*?)$"
    
#     sentiment_matches = re.findall(sentiment_pattern, response)
#     reason_match = re.search(reason_pattern, response)

#     sentiment = sentiment_matches[1].strip().lower() if len(sentiment_matches) > 1 else None
#     reason = reason_match.group(1).strip() if reason_match else None
    
#     return (sentiment, reason)

def parse_sentiment(response):
    """Extract the predicted sentiment label from a raw model response.

    Args:
        response: Text returned by the model; expected to contain a line of
            the form ``Sentiment: <label>``.

    Returns:
        The lower-cased label when it is one of the values of the
        notebook-level ``label_map`` (defined outside this cell), otherwise
        the string ``"invalid"``.
    """
    # MULTILINE lets `$` anchor at the end of every line. Without it, `$` only
    # matches at the end of the whole string, so a response with anything
    # after the sentiment line (a reason, trailing blank lines) never matched.
    sentiment_pattern = r"Sentiment:\s*(.*?)$"
    sentiment_matches = re.findall(sentiment_pattern, response, flags=re.MULTILINE)
    if not sentiment_matches:
        return "invalid"

    # Use the last "Sentiment:" line: it is the model's final answer and is
    # the one the end-of-string anchor used to select.
    sentiment = sentiment_matches[-1].strip().lower()

    # Sometimes, model outputs in an incorrect format (e.g. "determination/resilience (however, closest match from your options would be: joy)")
    # Such answers are reported as "invalid" rather than trying to salvage a label from them.
    if sentiment not in label_map.values():
        sentiment = "invalid"

    return sentiment

In [None]:
def zero_shot(content: str, examples = None):
    """Classify one sentence with no in-context examples.

    Args:
        content: Sentence to classify; sent as the human message.
        examples: Not used by this function; accepted so the signature lines
            up with ``few_shot``.

    Returns:
        The ``content`` attribute of the response from the notebook-level
        ``chat_model`` (defined outside this cell).
    """
    # Note: asking the model for its reasoning increases prompt response time
    # by 10x. The prompt wording that was tried for it:
    #   "Now classify the following sentence and provide reasoning for your
    #    classification. The output MUST follow this format:\n"
    #   "Sentiment: [Classification]\nReason: [Explanation]"
    categories = ", ".join(label_map.values())
    system_prompt = (
        f"Your goal is to read a sentence and classify its sentiment into one of the following categories: {categories}.\n\n"
        "Now classify the following sentence. The output MUST follow this format:\n"
        "Sentiment: [Classification]"
    )

    conversation = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=content),
    ]
    return chat_model.invoke(conversation).content

def few_shot(content: str, examples: str):
    """Classify one sentence with in-context examples in the system prompt.

    Args:
        content: Sentence to classify; sent as the human message.
        examples: Pre-rendered example text inserted after
            "Here are some examples:" in the system prompt.

    Returns:
        The ``content`` attribute of the response from the notebook-level
        ``chat_model`` (defined outside this cell).
    """
    categories = ", ".join(label_map.values())
    system_prompt = (
        f"Your goal is to read a sentence and classify its sentiment into one of the following categories: {categories}.\n\n"
        f"Here are some examples:\n{examples}\n\n"
        "Now classify the following sentence. The output MUST follow this format:\n"
        "Sentiment: [Classification]"
    )

    conversation = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=content),
    ]
    return chat_model.invoke(conversation).content

In [None]:
def create_reports(y_test, predictions):
    """Produce every evaluation report for one set of predictions.

    Pairs whose prediction is "invalid" are dropped by ``filter_invalid``
    first; the remaining labels are then passed, in order, to the accuracy
    printout, the classification report, the confusion matrix and the
    running-accuracy plot.

    Args:
        y_test: True labels.
        predictions: Predicted labels, paired with ``y_test`` by position.
    """
    valid_true, valid_pred = filter_invalid(y_test, predictions)

    for report in (calc_accuracy, class_report, conf_matrix, acc_graph):
        report(valid_true, valid_pred)

def filter_invalid(y_test, y_pred):
    """Drop every (true, predicted) pair whose prediction is "invalid".

    Prints how many pairs were removed.

    Args:
        y_test: True labels.
        y_pred: Predicted labels, paired with ``y_test`` by position.

    Returns:
        Tuple ``(filtered_y_test, filtered_y_pred)`` of two lists holding the
        surviving labels in their original order.
    """
    kept_pairs = [
        (truth, guess)
        for truth, guess in zip(y_test, y_pred)
        if guess != "invalid"
    ]
    kept_truth = [truth for truth, _ in kept_pairs]
    kept_guess = [guess for _, guess in kept_pairs]

    dropped = len(y_test) - len(kept_truth)
    print(f"Filtered {dropped} invalid predictions.")

    return kept_truth, kept_guess

def calc_accuracy(y_test, y_pred):
    """Print the percentage of predictions that equal their true label.

    Args:
        y_test: True labels.
        y_pred: Predicted labels, paired with ``y_test`` by position.

    Returns:
        None. The result is printed with four decimal places.
    """
    # With nothing to score (e.g. every prediction was filtered out as
    # invalid) report that instead of dividing by zero.
    if len(y_test) == 0:
        print("Accuracy: n/a (no samples to score)")
        return

    correct = sum(1 for p, t in zip(y_pred, y_test) if p == t)
    accuracy = correct / len(y_test) * 100
    print(f"Accuracy: {accuracy:.4f}")

def class_report(y_test, y_pred):
    """Print the classification report for the given labels.

    The report is computed by ``classification_report`` with ``labels`` set to
    the notebook-level ``label_values`` (defined outside this cell).

    Args:
        y_test: True labels.
        y_pred: Predicted labels.
    """
    report = classification_report(y_test, y_pred, labels=label_values)
    print(report)

def conf_matrix(y_test, y_pred):
    """Compute the confusion matrix and show it as a plot.

    Rows and columns follow the notebook-level ``label_values`` (defined
    outside this cell), which is also used for the displayed tick labels.

    Args:
        y_test: True labels.
        y_pred: Predicted labels.
    """
    matrix = confusion_matrix(y_test, y_pred, labels=label_values)
    ConfusionMatrixDisplay(
        confusion_matrix=matrix,
        display_labels=label_values,
    ).plot()
    plt.show()

def acc_graph(y_test, y_pred):
    """Plot the running accuracy as more samples are taken into account.

    Point ``i`` of the curve is the percentage of correct predictions among
    the first ``i + 1`` samples.

    Args:
        y_test: True labels.
        y_pred: Predicted labels, paired with ``y_test`` by position.
    """
    is_correct = np.array(y_test) == np.array(y_pred)
    samples_seen = np.arange(1, len(y_test) + 1)
    running_accuracy = np.cumsum(is_correct) / samples_seen

    plt.figure(figsize=(8, 5))
    plt.plot(running_accuracy * 100)
    plt.title("Accuracy Over Samples")
    plt.xlabel("Number of Samples")
    plt.ylabel("Accuracy (%)")
    plt.grid()
    plt.show()