In [1]:
from __future__ import annotations
import logging
import os
import sys

import datasets as nlp_datasets
import pandas as pd
from sklearn.metrics import f1_score

from cappr import openai
sys.path.insert(1, os.path.join(sys.path[0], "..", ".."))
from utils import display_df

In [2]:
## Send INFO-and-above log records to stdout so that any server errors raised
## while hitting the OpenAI endpoints are visible alongside the cell output.
logging.basicConfig(
    format='%(asctime)s :: %(name)s :: %(levelname)s :: %(message)s',
    handlers=[logging.StreamHandler(stream=sys.stdout)],
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

In [3]:
## Fetch the 'train' split of the 'overruling' configuration of 'ought/raft'
## and materialise it as a pandas DataFrame.
raft_train = nlp_datasets.load_dataset('ought/raft', 'overruling', split='train')
df = pd.DataFrame(raft_train)



In [4]:
## Number of rows in df
df.shape[0]

50

In [5]:
## Peek at the first five rows
df.head(n=5)

Unnamed: 0,Sentence,ID,Label
0,in light of both our holding today and previou...,0,2
1,"see mciver, 134 n.c.app. at 588, 518 s.e.2d a...",1,1
2,"to the extent that paprskar v. state, supra, a...",2,2
3,"we reverse and remand, and in doing so, we ove...",3,2
4,to the extent that other cases have cited carr...,4,2


In [6]:
def prompt(sentence: str) -> str:
    """Build the Yes/No classification prompt for a single sentence.

    Args:
        sentence: The text to be judged as overruling or not.

    Returns:
        A three-line string: the task instruction, a ``Sentence: ...`` line
        containing ``sentence`` verbatim, and the trailing
        ``Answer Yes or No:`` cue (no trailing newline or space).
    """
    instruction = ('In law, an overruling sentence is a statement that nullifies '
                   'a previous case decision as a precedent. '
                   'Is the following sentence overruling?')
    lines = [instruction, f'Sentence: {sentence}', 'Answer Yes or No:']
    return '\n'.join(lines)

In [7]:
## One prompt per row, built from that row's Sentence
df['prompt'] = df['Sentence'].map(prompt)

In [8]:
## Show a single example row with its prompt and label.
## NOTE(review): display_df comes from the local utils module; presumably it
## renders the selected columns without truncating long text — verify there.
display_df(df, columns=['prompt', 'Label'], num_rows=1)

Unnamed: 0,prompt,Label
0,"In law, an overruling sentence is a statement that nullifies a previous case decision as a precedent. Is the following sentence overruling? Sentence: in light of both our holding today and previous rulings in johnson, dueser, and gronroos, we now explicitly overrule dupree. Answer Yes or No:",2


In [9]:
## Relative frequency of each label, ordered by ascending label value so the
## array positions follow the sorted labels.
label_shares = df['Label'].value_counts(normalize=True)
prior = label_shares.sort_index().to_numpy()
prior

array([0.5, 0.5])

In [10]:
## $0.15 -- presumably the approximate OpenAI API cost of running this cell;
## confirm before re-running.
## NOTE(review): check that 'text-davinci-003' is still served by the API.
prompts = df['prompt'].tolist()
pred_probs = openai.classify.predict_proba(
    prompts,
    completions=('No', 'Yes'),
    model='text-davinci-003',
    prior=prior,
    ask_if_ok=True,
)

log-probs:   0%|          | 0/100 [00:00<?, ?it/s]

In [11]:
## Shift labels down by one so they can be compared with the argmax column of
## pred_probs (assumes labels are 1/2 and the columns follow the
## ('No', 'Yes') completion order -- TODO confirm).
y_true = df['Label'] - 1
y_pred = pred_probs.argmax(axis=1)
f1_score(y_true, y_pred, average='macro')

0.9198717948717948