In [1]:
from __future__ import annotations
import logging
import os
import sys

import datasets as nlp_datasets
import pandas as pd
from sklearn.metrics import f1_score

from cappr import openai
sys.path.insert(1, os.path.join(sys.path[0], "..", ".."))
from utils import display_df

This task is a tough cookie and does a good job demonstrating that zero-shot
classification is not very appropriate for more expert-level tasks. I wouldn't recommend
using CAPPr or even GPT-3+ in cases like these. You should instead train a model so that
it picks up the subtle correlations in the training data&mdash;subtleties which are hard
to verbalize in a prompt.

In [2]:
## Route INFO-and-above log records to stdout, so that any server errors logged while
## hitting the OpenAI endpoints are shown
log_format = '%(asctime)s :: %(name)s :: %(levelname)s :: %(message)s'
stdout_handler = logging.StreamHandler(stream=sys.stdout)
logging.basicConfig(format=log_format,
                    handlers=[stdout_handler],
                    level=logging.INFO)
logger = logging.getLogger(__name__)

In [3]:
## Load the train split of the `one_stop_english` config of `ought/raft`, then wrap it
## in a DataFrame
raft_train = nlp_datasets.load_dataset('ought/raft', 'one_stop_english', split='train')
df = pd.DataFrame(raft_train)



In [4]:
## Number of rows in the train split
df.shape[0]

50

In [5]:
## Peek at the first 5 rows
df.head(n=5)

Unnamed: 0,Article,ID,Label
0,"For 85 years, it was just a grey blob on class...",0,3
1,He had the tastes of a typical millionaire. He...,1,1
2,The Moroccan city of Ouarzazate is used to big...,2,1
3,SeaWorld has suffered an 84% collapse in profi...,3,3
4,There are worse things to do in life than stro...,4,2


In [6]:
def prompt(article: str, num_paragraphs: int=3, paragraph_delimeter: str='\n') -> str:
    """
    Build the classification prompt for a single article.

    The article is split on `paragraph_delimeter`, only the first `num_paragraphs`
    pieces are kept, and those pieces are joined back together with the same
    delimiter. The resulting excerpt is placed after a fixed task description and is
    followed by a trailing `Label:` cue.

    Parameters
    ----------
    article : str
        full text of the article
    num_paragraphs : int, optional
        number of leading pieces of the article to keep, by default 3
    paragraph_delimeter : str, optional
        string which separates paragraphs in `article`, by default a newline

    Returns
    -------
    str
        the prompt text, ending in `Label:`
    """
    task_description = ('An article was rewritten to suit three levels of adult '
                        'English as Second Language (ESL) learners: elementary, '
                        'intermediate, and advanced. Predict the level that this '
                        'article was written in.\n\n')
    ## Keep the prompt short by only including the start of the article
    leading_paragraphs = article.split(paragraph_delimeter)[:num_paragraphs]
    excerpt = paragraph_delimeter.join(leading_paragraphs)
    return f'{task_description}Article: {excerpt}\nLabel:'

In [7]:
## One prompt per article, built with the default truncation settings of `prompt`
df['prompt'] = list(map(prompt, df['Article']))

In [8]:
## Display the prompt and label columns, with `num_rows=1`
columns_to_show = ['prompt', 'Label']
display_df(df, columns=columns_to_show, num_rows=1)

Unnamed: 0,prompt,Label
0,"An article was rewritten to suit three levels of adult English as Second Language (ESL) learners: elementary, intermediate, and advanced. Predict the level that this article was written in. Article: For 85 years, it was just a grey blob on classroom maps of the solar system. But, on 15 July, Pluto was seen in high resolution for the first time. The images show dramatic mountain ranges made from solid water ice as big as the Alps or the Rockies. The extraordinary images of the former ninth planet and its large moon, Charon, were sent back 4bn miles to Earth from the New Horizons spacecraft. They are the climax of a mission that has been quietly underway for nearly ten years. Alan Stern, the mission’s principal investigator, said “New Horizons is returning amazing results. The data look absolutely gorgeous, and Pluto and Charon are just mind-blowing.” Label:",3


In [9]:
## Fraction of rows carrying each label, ordered by ascending label value
label_fractions = df['Label'].value_counts(normalize=True)
prior = label_fractions.sort_index().to_numpy()
prior

array([0.36, 0.4 , 0.24])

In [10]:
## $0.77
## NOTE(review): `prior` is ordered by ascending label value, and later cells compare
## the argmax of `pred_probs` against `df['Label']-1`, so the order of these
## completions presumably mirrors label values 1, 2, 3 -- confirm against the dataset
## card
class_names = ('advanced', 'elementary', 'intermediate')
prompts = df['prompt'].tolist()
pred_probs = openai.classify.predict_proba(prompts,
                                           completions=class_names,
                                           model='text-davinci-003',
                                           prior=prior,
                                           ask_if_ok=True)

log-probs:   0%|          | 0/150 [00:00<?, ?it/s]

In [11]:
## Macro-averaged F1. Labels are shifted down by 1 before being compared with the
## argmax (over axis 1) of `pred_probs`
labels_shifted = df['Label'] - 1
pred_classes = pred_probs.argmax(axis=1)
f1_score(labels_shifted, pred_classes, average='macro')

0.24510551741673023

In [12]:
## Accuracy: fraction of rows where the argmax of `pred_probs` equals the label
## shifted down by 1
is_correct = pred_probs.argmax(axis=1) == (df['Label'] - 1)
is_correct.mean()

0.36

Worse than the majority classifier (which would get 0.40 accuracy here), ouch

In [13]:
## Index of the largest value along axis 1 of `pred_probs`
pred_classes = pred_probs.argmax(axis=1)
pred_classes

array([0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 2, 0, 0, 0,
       2, 0, 0, 0, 2, 0, 2, 0, 2, 0, 0, 0, 0, 0, 2, 0, 0, 2, 0, 0, 0, 2,
       0, 0, 0, 0, 0, 0], dtype=int64)

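The predictions are almost all class 0, and class 1 is never predicted even though it
has the largest prior. Slightly miscalibrated lol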