In [1]:
from __future__ import annotations
import logging
import os
import sys

import datasets as nlp_datasets
import pandas as pd
from sklearn.metrics import f1_score

from cappr import openai
sys.path.insert(1, os.path.join(sys.path[0], "..", ".."))
from utils import display_df

In [2]:
## When hitting the OpenAI endpoints, we'll log any server errors.
## Records go to stdout (not stderr) so they render inline in the notebook.
logging.basicConfig(
    level=logging.INFO,
    handlers=[logging.StreamHandler(stream=sys.stdout)],
    format='%(asctime)s :: %(name)s :: %(levelname)s :: %(message)s',
)
logger = logging.getLogger(__name__)

In [3]:
# Labeled split of the RAFT "semiconductor_org_types" task, as a DataFrame.
df = pd.DataFrame(
    nlp_datasets.load_dataset(
        'ought/raft', 'semiconductor_org_types', split='train'
    )
)



In [4]:
# Number of labeled examples in the split.
df.shape[0]

50

In [5]:
# First rows of the raw dataset.
df.head(n=5)

Unnamed: 0,Paper title,Organization name,ID,Label
0,3Gb/s AC-coupled chip-to-chip communication us...,"North Carolina State Univ.,Raleigh,NC,USA",0,3
1,Sub-Micron CMOS / MOS-Bipolar Hybrid TFTs for ...,Advanced LCD Technology Development Center Com...,1,1
2,24.4 A 680nA fully integrated implantable ECG-...,"imec,Heverlee,Belgium",2,2
3,A write-back cache memory using bit-line steal...,"Corp. Semicond. Dev. Div.,Matsushita Electr. I...",3,1
4,High performance 0.25 /spl mu/m gate-length do...,"APA Optics, Inc., Blaine, MN, USA",4,1


In [6]:
def prompt(organization_name: str) -> str:
    """Build the zero-shot classification prompt for one organization.

    The text names the organization on the first line, asks whether it is a
    company, research institute, or university on the second, and ends with
    an ``'Answer: '`` cue on the third.

    Parameters
    ----------
    organization_name : str
        Raw value from the dataset's ``Organization name`` column.

    Returns
    -------
    str
        The three-line prompt, ending in ``'Answer: '`` (note trailing space).

    NOTE(review): the trailing space is intentional here, but if the installed
    cappr version appends its own ``end_of_prompt`` separator before each
    completion, the model would see two spaces before the answer — confirm.
    """
    lines = [
        f'This is the name of an organization: {organization_name}',
        'Is this organization a company, research institute, or university?',
        'Answer: ',
    ]
    return '\n'.join(lines)

In [7]:
# One prompt per row, built from the organization name.
df['prompt'] = list(map(prompt, df['Organization name']))

In [8]:
# Spot-check the generated prompts next to their gold labels.
preview_columns = ['prompt', 'Label']
display_df(df, columns=preview_columns)

Unnamed: 0,prompt,Label
0,"This is the name of an organization: North Carolina State Univ.,Raleigh,NC,USA Is this organization a company, research institute, or university? Answer:",3
1,"This is the name of an organization: Advanced LCD Technology Development Center Company Limited, Yokohama, Kanagawa, Japan Is this organization a company, research institute, or university? Answer:",1
2,"This is the name of an organization: imec,Heverlee,Belgium Is this organization a company, research institute, or university? Answer:",2


In [9]:
# Empirical class distribution of the labeled split, passed to cappr as the prior.
# sort_index() orders it by label id (1, 2, 3), which must match the order of the
# completions tuple given to predict_proba. The sample rows above are consistent
# with 1=company, 2=research institute, 3=university — re-check if labels change.
# NOTE(review): the prior is estimated from the same rows the metrics are scored
# on, so the evaluation uses label information from the evaluation set.
label_distribution = df['Label'].value_counts(normalize=True).sort_index()
prior = label_distribution.to_numpy()
prior

array([0.76, 0.12, 0.12])

In [10]:
## $0.12 (presumably the approximate OpenAI API cost of running this cell)
# Order must match label ids 1, 2, 3 (and therefore the order of `prior`).
class_names = ('company', 'research institute', 'university')
# NOTE(review): prompts end with 'Answer: ' (trailing space); check whether
# cappr's `end_of_prompt` default adds another space before each completion.
# NOTE(review): text-davinci-003 may no longer be served by OpenAI — confirm
# before re-running; a different model will change the numbers below.
pred_probs = openai.classify.predict_proba(
    df['prompt'].tolist(),
    completions=class_names,
    model='text-davinci-003',
    prior=prior,
    ask_if_ok=True,
)

log-probs:   0%|          | 0/150 [00:00<?, ?it/s]

In [11]:
# Labels are 1-based (1..3); argmax over pred_probs columns is 0-based (0..2).
f1_score(
    y_true=df['Label'] - 1,
    y_pred=pred_probs.argmax(axis=1),
    average='macro',
)

0.742602495543672

In [12]:
# Accuracy: fraction of rows where the argmax class equals the 0-based gold label.
(df['Label'] - 1).eq(pred_probs.argmax(axis=1)).mean()

0.8

Welp: an accuracy of 0.80 is not much better than the 0.76 we'd get by always predicting the majority class (label 1).