In [49]:
%load_ext jupyter_black

In [1]:
import os
import instructor
from pydantic import BaseModel
from openai import OpenAI
import getpass

We will use `polars` for data wrangling.\
We will also use `sklearn` to calculate accuracy score and performance metrics.

In [51]:
import polars as pl
from polars import DataFrame
import sklearn

In [53]:
# load train and test datasets
train = pl.read_csv("./dataset/train-assignment-01.csv")
test = pl.read_csv("./dataset/test-assignment-01.csv")

In [7]:
_API_KEY = getpass.getpass("Your API key: \n")

Your API key: 
 ········


In [8]:
open_ai = OpenAI(api_key=_API_KEY)
client = instructor.from_openai(open_ai, mode=instructor.Mode.TOOLS_STRICT)

In [54]:
class EmotionPromptingResponse(BaseModel):
    label: str

In [55]:
def emotion_prompting(message):
    return client.chat.completions.create(
        model="gpt-4o-mini",
        response_model=EmotionPromptingResponse,
        messages=[
            {
                "role": "developer",
                "content": """You will act as a text classifier. 
                           Each message must be labeled as suicide, depression, or other.
                           This is very important for my class.
                           Please, bestie, help me get a good grade.
                           """,
            },
            {
                "role": "user",
                "content": f"{message}",
            },
        ],
    )

In [62]:
# what does the data look like?
train.head()

text,class
str,str
"""Trapped inside a voidDear whoe…","""suicide"""
"""the bla bla bla on drugs and c…","""depression"""
"""I love you so fucking muchIt s…","""depression"""
"""What is the best way to do it?…","""suicide"""
"""I learnt a new skill today! I …","""other"""


In [97]:
y_pred = []  # predictions from the classifer
y_true = []  # actual labels

# just the first top 10 examples
egs = train.slice(0, 10)
for r in egs.iter_rows():
    c = emotion_prompting(r[0])
    y_pred.append(c.label)
    y_true.append(r[1])

In [98]:
# let's calculate the accuracy
sklearn.metrics.accuracy_score(y_true, y_pred)

0.7

In [59]:
# confusion matrix
sklearn.metrics.confusion_matrix(y_true=y_true, y_pred=y_pred)

array([[3, 1, 0],
       [0, 2, 0],
       [1, 0, 3]])

In [60]:
print(y_pred)
print(y_true)

['depression', 'depression', 'depression', 'suicide', 'other', 'suicide', 'suicide', 'other', 'depression', 'other']
['suicide', 'depression', 'depression', 'suicide', 'other', 'suicide', 'suicide', 'depression', 'depression', 'other']


In [99]:
def filter_rows(df, y_pred, incorrect_predictions=True):
    """
    Filter rows based on prediction accuracy

    Parameters
    ----------
    df: DataFrame.
    y_pred: Iterable.
        Predicted class for the rows in `df`.
    incorrect_predictions: boolean, default=True.
        When true, it will select incorrect predictions. Otherwise, it
        returns correct predictions.
    """
    df = pl.concat([df, pl.DataFrame({"y_pred": pl.Series(y_pred)})], how="horizontal")
    if incorrect_predictions:
        return df.filter(pl.col("y_pred") != pl.col("class"))
    else:
        return df.filter(pl.col("y_pred") == pl.col("class"))

In [101]:
# which ones did the LLM get wrong?
filter_rows(egs, y_pred)

text,class,y_pred
str,str,str
"""Trapped inside a voidDear whoe…","""suicide""","""depression"""
"""The graveyard of redditAnyone …","""suicide""","""depression"""
"""I respect you all, so mutch.In…","""depression""","""other"""


In [102]:
# which ones did the LLM get right?
filter_rows(egs, y_pred, incorrect_predictions=False)

text,class,y_pred
str,str,str
"""the bla bla bla on drugs and c…","""depression""","""depression"""
"""I love you so fucking muchIt s…","""depression""","""depression"""
"""What is the best way to do it?…","""suicide""","""suicide"""
"""I learnt a new skill today! I …","""other""","""other"""
"""Do you think getting hit by a …","""suicide""","""suicide"""
"""The people in my life who I am…","""depression""","""depression"""
"""Put opinions here Just say any…","""other""","""other"""
