In [1]:
from __future__ import annotations
import logging
import os
import sys

import datasets as nlp_datasets
import pandas as pd
from sklearn.metrics import f1_score

from cappr import openai
sys.path.insert(1, os.path.join(sys.path[0], "..", ".."))
from utils import display_df

In [2]:
## Send INFO-and-above log records to stdout so that any server errors raised
## while hitting the OpenAI endpoints show up in the cell output
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s :: %(name)s :: %(levelname)s :: %(message)s',
    handlers=[logging.StreamHandler(stream=sys.stdout)],
)
logger = logging.getLogger(__name__)

In [3]:
## Load the 'train' split of the RAFT 'twitter_complaints' task into a DataFrame
df = pd.DataFrame(
    nlp_datasets.load_dataset(
        path='ought/raft',
        name='twitter_complaints',
        split='train',
    )
)

Downloading and preparing dataset raft/twitter_complaints (download: 9.30 MiB, generated: 366.13 KiB, post-processed: Unknown size, total: 9.66 MiB) to C:/Users/kushd/.cache/huggingface/datasets/ought___raft/twitter_complaints/1.1.0/79c4de1312c1e3730043f7db07179c914f48403101f7124e2fe336f6f54d9f84...


Downloading data files:   0%|          | 0/11 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/11 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/50 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3399 [00:00<?, ? examples/s]

Dataset raft downloaded and prepared to C:/Users/kushd/.cache/huggingface/datasets/ought___raft/twitter_complaints/1.1.0/79c4de1312c1e3730043f7db07179c914f48403101f7124e2fe336f6f54d9f84. Subsequent calls will reuse this data.


In [4]:
## Number of rows (examples) in the loaded split
df.shape[0]

50

In [5]:
## Peek at the first five rows
df.head(n=5)

Unnamed: 0,Tweet text,ID,Label
0,@HMRCcustomers No this is my first job,0,2
1,@KristaMariePark Thank you for your interest! ...,1,2
2,If I can't get my 3rd pair of @beatsbydre powe...,2,1
3,@EE On Rosneath Arial having good upload and d...,3,1
4,"Couples wallpaper, so cute. :) #BrothersAtHome",4,2


In [6]:
def prompt(tweet_text: str) -> str:
    """Build the zero-shot classification prompt for a single tweet.

    The prompt has three paragraphs separated by blank lines: a one-sentence
    definition of a complaint, the tweet itself, and a Yes/No question.

    Parameters
    ----------
    tweet_text : str
        Text of the tweet, inserted verbatim into the prompt.

    Returns
    -------
    str
        The full prompt, ending with ``'Answer Yes or No:'`` (no trailing
        whitespace).
    """
    definition = ("A complaint presents a state of affairs which breaches "
                  "the writer's favorable expectation.")
    tweet = f'Here is a tweet: {tweet_text}'
    question = 'Does the tweet contain a complaint? Answer Yes or No:'
    return '\n\n'.join((definition, tweet, question))

In [7]:
## Build one prompt per tweet and store it alongside the original columns
df['prompt'] = df['Tweet text'].map(prompt)

In [8]:
## Show the 'prompt' and 'Label' columns using the project-local display_df
## helper from utils (presumably renders the full, untruncated prompt text --
## confirm against utils.display_df)
display_df(df, columns=['prompt', 'Label'])

Unnamed: 0,prompt,Label
0,A complaint presents a state of affairs which breaches the writer's favorable expectation. Here is a tweet: @HMRCcustomers No this is my first job Does the tweet contain a complaint? Answer Yes or No:,2
1,"A complaint presents a state of affairs which breaches the writer's favorable expectation. Here is a tweet: @KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES. Does the tweet contain a complaint? Answer Yes or No:",2
2,A complaint presents a state of affairs which breaches the writer's favorable expectation. Here is a tweet: If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService Does the tweet contain a complaint? Answer Yes or No:,1


In [9]:
## Relative frequency of each label, ordered by ascending label value (so the
## entry for Label 1 comes before the entry for Label 2)
prior = df['Label'].value_counts(normalize=True).sort_index().to_numpy()
prior

array([0.34, 0.66])

In [10]:
## $0.12 -- presumably the approximate OpenAI cost of running this cell.
## Each prompt is scored against both completions; the prior is passed in the
## same order as the completions.
pred_probs = openai.classify.predict_proba(
    df['prompt'].tolist(),
    completions=('Yes', 'No'),
    model='text-davinci-003',
    prior=prior,
    ask_if_ok=True,
)

log-probs:   0%|          | 0/100 [00:00<?, ?it/s]

In [11]:
## The labels shown in df.head() are 1 and 2, while argmax over the columns of
## pred_probs yields 0-based column indices; subtracting 1 puts both on the
## same 0-based scale before scoring.
f1_score(
    y_true=df['Label'] - 1,
    y_pred=pred_probs.argmax(axis=1),
    average='macro',
)

0.8697916666666666