# Few-Shot Learning with GPT


## Setup


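```python
# Enter your OpenAI API key here
KEY = "sk-..."

# Uses the legacy OpenAI Python SDK (openai<1.0): openai.Completion was
# removed in openai 1.0, and text-davinci-003 has since been retired.
import openai
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, roc_auc_score
from sklearn.preprocessing import LabelBinarizer

openai.api_key = KEY
```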

## Helpers


```python
# Sample few-shot examples from the combined train and dev sets
def get_examples(train_df, dev_df, num_examples=5, random_state=None):
    combined_df = pd.concat([train_df, dev_df], ignore_index=True)
    # Pass an integer random_state to make the example selection reproducible
    examples = combined_df.sample(n=num_examples, random_state=random_state)
    return examples


# Prompt the model on each row, prefixing the few-shot examples
def process_dataframe(df, data_type, model="text-davinci-003", examples=None):
    true_labels = []
    predicted_labels = []
    max_samples = 50  # only the first 50 rows of each split are evaluated
    max_context_length = 4097  # context window of text-davinci-003, in tokens
    print("Processing " + data_type + " data:")

    for index, row in df.head(max_samples).iterrows():
        text = row["MASKED_DOCUMENT"]
        title = row["TITLE"]
        main_subject = row["TARGET_ENTITY"]
        sentiment = row["TRUE_SENTIMENT"]

        end_prompt = f"Text: {text}\nTitle: {title}\nMain Subject: {main_subject}\n\nDetermine the sentiment (one of Neutral, Positive, Negative) of the text towards the main subject:\nSentiment:"
        # len() counts characters, not tokens. Because a token usually spans
        # several characters, this is a conservative budget for the examples;
        # it does not check that end_prompt alone fits in the context window.
        # The 3 subtracted mirrors max_tokens=3 reserved for the completion.
        max_examples_length = max_context_length - len(end_prompt) - 3

        examples_prompt = ""
        if examples is not None:
            for _, example in examples.iterrows():
                new_example = f"Text: {example['MASKED_DOCUMENT']}\nTitle: {example['TITLE']}\nMain Subject: {example['TARGET_ENTITY']}\nSentiment: {example['TRUE_SENTIMENT']}\n\n"
                # Stop adding examples once the next one would exceed the budget
                if len(examples_prompt) + len(new_example) > max_examples_length:
                    break
                examples_prompt += new_example

        prompt = examples_prompt + end_prompt

        response = openai.Completion.create(
            model=model,
            prompt=prompt,
            temperature=0.3,
            max_tokens=3,
            top_p=1.0,
            frequency_penalty=0,
            presence_penalty=0,
            n=1,
            stop=None,
        )
        # The raw completion is kept as the label; only whitespace is stripped
        predicted_sentiment = response.choices[0].text.strip()

        true_labels.append(sentiment)
        predicted_labels.append(predicted_sentiment)

        print(
            f"Sample {index + 1}:\n"
            f"Title: {title}\n"
            f"Text: {text}\n"
            f"Main Entity: {main_subject}\n"
            f"True sentiment: {sentiment}\n"
            f"Predicted sentiment: {predicted_sentiment}\n"
        )

    return true_labels, predicted_labels


# Evaluate model performance
def calculate_metrics(true_labels, predicted_labels, data_type):
    accuracy = accuracy_score(true_labels, predicted_labels)
    # Support-weighted F1: each class's F1 is weighted by its number of true samples
    f1 = f1_score(true_labels, predicted_labels, average="weighted")
    # Rows are true labels, columns are predictions, both in sorted label order
    cm = confusion_matrix(true_labels, predicted_labels)

    # One-vs-rest ROC-AUC on binarized hard predictions (no probability scores
    # are available), so each class's AUC reduces to (TPR + TNR) / 2
    lb = LabelBinarizer()
    true_labels_bin = lb.fit_transform(true_labels)
    predicted_labels_bin = lb.transform(predicted_labels)
    roc_auc = roc_auc_score(
        true_labels_bin, predicted_labels_bin, average="weighted", multi_class="ovr"
    )

    print("\nMetrics for " + data_type + " data:")
    print("Accuracy: " + str(accuracy))
    print("F1 Score: " + str(f1))
    print("Confusion Matrix:\n" + str(cm))
    print("ROC-AUC Score: " + str(roc_auc))
```
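The prompt-assembly loop in `process_dataframe` can be pulled out into a pure function, which makes the budget easy to test and to switch from characters to tokens. Below is a minimal sketch under that assumption; `build_prompt`, its `reserve` and `length` parameters, and the toy row are hypothetical illustrations, not part of the notebook.

```python
from typing import Callable, Mapping, Sequence


def build_prompt(
    examples: Sequence[Mapping[str, str]],
    end_prompt: str,
    budget: int,
    reserve: int = 3,
    length: Callable[[str], int] = len,
) -> str:
    """Hypothetical helper: prepend as many demonstrations as fit in `budget`.

    `length` measures text size: `len` counts characters (as the notebook does);
    a tokenizer-based function would count tokens instead. `reserve` is held back
    for the completion (the notebook uses max_tokens=3).
    """
    remaining = budget - reserve - length(end_prompt)
    parts = []
    used = 0
    for ex in examples:
        block = (
            f"Text: {ex['MASKED_DOCUMENT']}\n"
            f"Title: {ex['TITLE']}\n"
            f"Main Subject: {ex['TARGET_ENTITY']}\n"
            f"Sentiment: {ex['TRUE_SENTIMENT']}\n\n"
        )
        cost = length(block)
        # Like the notebook, stop at the first demonstration that does not fit
        if used + cost > remaining:
            break
        parts.append(block)
        used += cost
    return "".join(parts) + end_prompt


# Usage with a toy, made-up demonstration (not taken from PerSenT)
toy = {
    "MASKED_DOCUMENT": "[TGT] won the award.",
    "TITLE": "Award announced",
    "TARGET_ENTITY": "Jane Doe",
    "TRUE_SENTIMENT": "Positive",
}
print(build_prompt([toy], "Text: ...\nSentiment:", budget=4097))
```

Passing a token counter as `length`, for example `enc = tiktoken.encoding_for_model("text-davinci-003")` with `length=lambda s: len(enc.encode(s))` from the tiktoken package, turns the budget into a token budget. Summing per-block token counts only approximates the count for the joined string, so a small safety margin is still advisable.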

## Learning


```python
# Load the PerSenT datasets
train_df = pd.read_csv("https://github.com/MHDBST/PerSenT/raw/main/train.csv")
dev_df = pd.read_csv("https://github.com/MHDBST/PerSenT/raw/main/dev.csv")
random_test_df = pd.read_csv(
    "https://github.com/MHDBST/PerSenT/raw/main/random_test.csv"
)
# fixed_test.csv is the Frequent test set (articles about a few popular entities)
freq_test_df = pd.read_csv("https://github.com/MHDBST/PerSenT/raw/main/fixed_test.csv")

# Sample few-shot examples once; the same examples are reused for both test sets
examples = get_examples(train_df, dev_df)

# Prompt the model
random_test_true_labels, random_test_predicted_labels = process_dataframe(
    random_test_df, "Random Test", examples=examples
)
freq_test_true_labels, fixed_test_predicted_labels = process_dataframe(
    freq_test_df, "Frequent Test", examples=examples
)
```

```text
Processing Random Test data:
Sample 1:
Title: Philippines says police might have shot hostages
Text: The new details of the investigation emerged as [TGT] said [TGT]'s through apologizing for the attack and will focus instead on easing tensions with China and Hong Kong  where officials have criticized the handling of the daylong crisis.
 "Let me just say that this incident will not define this administration " [TGT] said in a nationally televised news conference. [TGT] added that [TGT] will wait for a report from a fact-finding committee before [TGT] fires any officials for the fiasco.
 [TGT]facing [TGT] first major test barely two months after taking office  said [TGT] will now focus on preventing a repeat of the incident. The public and the media have questioned why [TGT] wasn't more visible and involved.
 "The first thing I will admit is I am not perfect and I can learn " said [TGT]who said [TGT] was following the developments from [TGT] office. Later  [TGT] went to a restaurant nea
```

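`process_dataframe` stores `response.choices[0].text.strip()` verbatim, so a reply such as `positive.` would become a label of its own: `confusion_matrix` would grow an extra row and column, and the `LabelBinarizer` fitted on the true labels would not recognize it. The 3x3 matrices reported below suggest this did not happen in this run, but nothing in the code guarantees it. A small normalization step is sketched here; `normalize_label` and `LABELS` are hypothetical names, and the prefix-matching rule is an assumption about how the model phrases its answers.

```python
from typing import Optional, Sequence

LABELS = ("Negative", "Neutral", "Positive")


def normalize_label(raw: str, labels: Sequence[str] = LABELS) -> Optional[str]:
    """Hypothetical helper: map a raw completion such as ' positive.' onto a label.

    Returns None when the text does not begin with a known label, so such
    replies can be counted separately instead of becoming a new class.
    """
    # Drop surrounding whitespace and punctuation, then compare case-insensitively
    cleaned = raw.strip().strip(".!,:;\"'").strip().lower()
    for label in labels:
        if cleaned.startswith(label.lower()):
            return label
    return None


raw_outputs = [" Positive", "neutral.", "Negative\n", "Mixed"]
print([normalize_label(r) for r in raw_outputs])
# ['Positive', 'Neutral', 'Negative', None]
```

Inside the loop, `predicted_sentiment = normalize_label(response.choices[0].text) or "Unknown"` would keep every prediction a string, since scikit-learn's metrics expect labels of a single type, while still counting unparsed replies as errors; the placeholder `"Unknown"` is likewise hypothetical.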
## Evaluation


```python
# Evaluate the model
calculate_metrics(random_test_true_labels, random_test_predicted_labels, "Random Test")
calculate_metrics(freq_test_true_labels, fixed_test_predicted_labels, "Frequent Test")
```


```text
Metrics for Random Test data:
Accuracy: 0.6
F1 Score: 0.5910943396226416
Confusion Matrix:
[[ 4  2  0]
 [ 1 17  5]
 [ 1 11  9]]
ROC-AUC Score: 0.6503250899802624

Metrics for Frequent Test data:
Accuracy: 0.42
F1 Score: 0.3337373737373737
Confusion Matrix:
[[ 6  6  1]
 [ 5 15  0]
 [ 3 14  0]]
ROC-AUC Score: 0.5434070434070435
```
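`calculate_metrics` feeds binarized hard predictions, not probabilities, to `roc_auc_score`. With 0/1 scores, each one-vs-rest ROC curve has a single interior point (FPR, TPR), so the area under it is (TPR + TNR) / 2, the class's balanced accuracy, and the reported score is the support-weighted mean of those values. The sketch below recomputes both reported ROC-AUC scores from the confusion matrices alone under that reading; `hard_label_ovr_auc` is a hypothetical helper, not a scikit-learn function.

```python
from typing import Sequence


def hard_label_ovr_auc(cm: Sequence[Sequence[int]]) -> float:
    """Support-weighted one-vs-rest ROC-AUC when the scores are 0/1 predictions.

    cm[i][j] counts samples of true class i predicted as class j. Each class's
    ROC curve has a single interior point (FPR, TPR), so its area is
    (TPR + (1 - FPR)) / 2. Assumes every class has at least one true sample.
    """
    n = len(cm)
    total = sum(sum(row) for row in cm)
    weighted_sum = 0.0
    for k in range(n):
        support = sum(cm[k])                          # true members of class k
        predicted = sum(cm[i][k] for i in range(n))   # samples predicted as class k
        tpr = cm[k][k] / support
        fpr = (predicted - cm[k][k]) / (total - support)
        weighted_sum += support * (tpr + 1 - fpr) / 2
    return weighted_sum / total


# Confusion matrices from the output above (order: Negative, Neutral, Positive)
random_cm = [[4, 2, 0], [1, 17, 5], [1, 11, 9]]
frequent_cm = [[6, 6, 1], [5, 15, 0], [3, 14, 0]]
print(hard_label_ovr_auc(random_cm))    # about 0.6503
print(hard_label_ovr_auc(frequent_cm))  # about 0.5434
```

Per class, this also shows why the Frequent Test score sits near 0.5: the Positive class gets (0 + 32/33) / 2, about 0.48, which is below chance because none of its 17 articles was predicted Positive.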


## Analysis


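In this notebook, we used few-shot prompting to apply the text-davinci-003 model to the sentiment analysis task: up to five labeled examples, drawn at random from the train and dev sets (as many as fit the prompt budget), are prepended to each prompt, and no model weights are updated. We then evaluated the model on the first 50 samples of each test set using accuracy, weighted F1 score, and ROC-AUC score, and we also computed a confusion matrix.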

- Accuracy provides a general sense of the model's performance, but it can be misleading if the dataset is imbalanced.
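- The F1 score balances precision and recall. We report the support-weighted average, which still gives larger classes more influence; the macro average weights every class equally and is stricter on imbalanced data.
- The ROC-AUC score is computed one-vs-rest per class and averaged by class support. Because the model returns labels rather than probability scores, it is computed from hard predictions, so each class's AUC equals its balanced accuracy, (TPR + TNR) / 2, and 0.5 corresponds to chance.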
- The confusion matrix provides a detailed view of the classification results, helping us understand the types of errors the model is making.
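The difference between weighted and macro averaging can be checked directly on the Frequent Test confusion matrix reported above. This pure-Python sketch reproduces the reported weighted F1 and shows what the average hides; the helper `per_class_f1` is hypothetical, while scikit-learn's `f1_score(..., average=None)` and `average="macro"` return the per-class scores and their unweighted mean from label lists.

```python
from typing import List, Sequence


def per_class_f1(cm: Sequence[Sequence[int]]) -> List[float]:
    """F1 for each class from a confusion matrix (rows = true, columns = predicted)."""
    n = len(cm)
    scores = []
    for k in range(n):
        tp = cm[k][k]
        fn = sum(cm[k]) - tp                       # class-k samples predicted as another class
        fp = sum(cm[i][k] for i in range(n)) - tp  # other samples predicted as class k
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)  # F1 = 2TP / (2TP + FP + FN)
    return scores


# Frequent Test confusion matrix from the output above (Negative, Neutral, Positive)
freq_cm = [[6, 6, 1], [5, 15, 0], [3, 14, 0]]
f1s = per_class_f1(freq_cm)
support = [sum(row) for row in freq_cm]

weighted_f1 = sum(s * f for s, f in zip(support, f1s)) / sum(support)
macro_f1 = sum(f1s) / len(f1s)

print([round(f, 3) for f in f1s])  # [0.444, 0.545, 0.0]
print(round(weighted_f1, 4))       # 0.3337, the F1 reported above
print(round(macro_f1, 4))          # 0.33
```

Here the two averages happen to be close (0.334 weighted vs 0.330 macro), but the per-class scores show that the Positive class has an F1 of 0, which neither single number makes visible.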

We tested the model on the Random and Frequent test sets (`random_test.csv` and `fixed_test.csv`). As the PerSenT authors describe:

> Due to the nature of news collections, some entities tend to dominate the collection. In our collection, there were four entities which were the main entity in nearly 800 articles. To avoid these entities from dominating the train or test splits, we moved them to a separate test collection. We split the remaining into a training, dev, and test sets at random. Thus our collection includes one standard test set consisting of articles drawn at random (Test Standard), while the other is a test set which contains multiple articles about a small number of popular entities (Test Frequent).

For the Random Test data, the model correctly classified 60% of the 50 samples (4 Negative, 17 Neutral, and 9 Positive), with a weighted F1 of 0.59 and a ROC-AUC of 0.65, indicating moderate performance. For the Frequent Test data, it correctly classified 42% of the samples (6 Negative, 15 Neutral, and 0 Positive), with a weighted F1 of 0.33 and a ROC-AUC of 0.54, only slightly above chance. Direct confusions between Negative and Positive are rare (1 in Random Test, 4 in Frequent Test); instead, the model tends to label Positive and, to a lesser extent, Negative articles as Neutral. In Random Test, 11 of the 21 Positive articles were predicted Neutral; in Frequent Test, 14 of the 17 Positive articles and 6 of the 13 Negative articles were, and no Positive article was classified correctly.
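The error breakdown can be read off the confusion matrices by looking along each row. A short sketch that prints per-class recall and where each class's errors went, using a hypothetical helper `error_breakdown` and the two matrices copied from the output above:

```python
from typing import Dict, Sequence

LABELS = ("Negative", "Neutral", "Positive")

# Confusion matrices from the output above (rows = true, columns = predicted)
matrices = {
    "Random Test": [[4, 2, 0], [1, 17, 5], [1, 11, 9]],
    "Frequent Test": [[6, 6, 1], [5, 15, 0], [3, 14, 0]],
}


def error_breakdown(
    cm: Sequence[Sequence[int]], labels: Sequence[str] = LABELS
) -> Dict[str, dict]:
    """Recall per true class, plus the predicted classes its errors went to."""
    report = {}
    for k, label in enumerate(labels):
        row = cm[k]
        errors = {labels[j]: row[j] for j in range(len(labels)) if j != k}
        report[label] = {"recall": row[k] / sum(row), "errors": errors}
    return report


for name, cm in matrices.items():
    print(name)
    for label, stats in error_breakdown(cm).items():
        print(f"  {label}: recall={stats['recall']:.2f}, errors={stats['errors']}")
```

For the Frequent Test set this reports a Positive recall of 0.00, with 14 of that class's 17 errors going to Neutral.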

Overall, the model performs moderately well on the Random Test data, but its accuracy drops from 0.60 to 0.42 on the Frequent Test data, which contains multiple articles about a small number of popular entities. This suggests that the few-shot learning approach might need more examples or a more specific prompt to better handle sentiment classification for popular entities. With only 50 samples per split and a single random draw of examples, however, these differences are indicative rather than conclusive.
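One way to explore the suggestion of more or better-chosen examples is to make the demonstrations class-balanced and reproducible, since `get_examples` samples five rows uniformly at random with no seed and can return few or no examples of some class. The sketch below is a hypothetical variant (`get_balanced_examples`, `per_class`, and `label_col` are illustrative names) built on pandas' `DataFrameGroupBy.sample`, available since pandas 1.1; whether it improves Frequent Test accuracy has not been tested here.

```python
import pandas as pd


def get_balanced_examples(
    train_df, dev_df, per_class=2, label_col="TRUE_SENTIMENT", random_state=0
):
    """Hypothetical variant of get_examples: `per_class` demonstrations per label.

    Raises ValueError if some label has fewer than `per_class` rows, because
    sampling is done without replacement.
    """
    combined = pd.concat([train_df, dev_df], ignore_index=True)
    # Draw the same number of rows from every sentiment class
    balanced = combined.groupby(label_col).sample(n=per_class, random_state=random_state)
    # Shuffle so that demonstrations of one class are not listed back to back
    return balanced.sample(frac=1, random_state=random_state).reset_index(drop=True)
```

It can be used in place of `get_examples`, e.g. `examples = get_balanced_examples(train_df, dev_df, per_class=2)`. Note that `process_dataframe` still drops trailing demonstrations that exceed its character budget, so with long articles the prompt may end up less balanced than the sample.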
