In [27]:
from sklearn.datasets import fetch_20newsgroups
import pandas as pd
import openai
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split

In [28]:
# Map each split name to the DataFrame stored in its pickle under ../data.
# NOTE: unpickling can execute arbitrary code, so only load trusted files.
dfs = dict(
    (split, pd.read_pickle(f'../data/df_{split}.pkl'))
    for split in ('train', 'test_recent', 'test_old')
)

In [29]:
# Fine-tuning pairs built from the training split:
#   prompt     -> the year as a string followed by a blank-line separator
#                 (a "Published in " prefix is not needed),
#   completion -> the title with surrounding whitespace stripped and a
#                 single leading space prepended.
d = pd.DataFrame({
    'prompt': dfs['train']['year'].astype(str) + '\n\n',
    'completion': ' ' + dfs['train']['title'].str.strip(),
})

The model trains successfully in about ten minutes. The resulting fine-tuned model is named `ada:ft-openai-2021-07-30-12-26-20`, and we use that name for inference.

## Using the model
We can now call the fine-tuned model to get predictions on the validation set.

In [13]:
# Validation split in JSON Lines format (one record per line), then a
# preview of the first five rows.
test = pd.read_json(path_or_buf='sport2_prepared_valid.jsonl', lines=True)
test.head(5)

Unnamed: 0,prompt,completion
0,From: gld@cunixb.cc.columbia.edu (Gary L Dare)...,hockey
1,From: smorris@venus.lerc.nasa.gov (Ron Morris ...,hockey
2,From: golchowy@alchemy.chem.utoronto.ca (Geral...,hockey
3,From: krattige@hpcc01.corp.hp.com (Kim Krattig...,baseball
4,From: warped@cs.montana.edu (Doug Dolven)\nSub...,baseball


We must append to each prompt the same separator that we used during fine-tuning; in this case it is `\n\n###\n\n`. Since this is a classification task, we set the temperature to 0 to make the output as deterministic as possible. We request only a single-token completion, because one token is enough to determine the model's prediction.

In [14]:
# Name of the fine-tuned model to query.
ft_model = 'ada:ft-openai-2021-07-30-12-26-20'
# The prompt must end with the same separator used during fine-tuning.
first_prompt = test['prompt'][0] + '\n\n###\n\n'
# One token at temperature 0 is enough to read off the predicted class.
res = openai.Completion.create(
    model=ft_model,
    prompt=first_prompt,
    max_tokens=1,
    temperature=0,
)
res['choices'][0]['text']


' hockey'

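To get the log probabilities, we can set the `logprobs` parameter on the completion request. `logprobs=2` returns the two most likely tokens at each position, together with their log probabilities.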

In [15]:
# Single-token, temperature-0 completion that also returns the log
# probabilities of the two most likely tokens at each position.
res = openai.Completion.create(
    model=ft_model,
    prompt=test['prompt'][0] + '\n\n###\n\n',
    max_tokens=1,
    temperature=0,
    logprobs=2,
)
top_logprobs = res['choices'][0]['logprobs']['top_logprobs']
top_logprobs[0]

<OpenAIObject at 0x7fe114e435c8> JSON: {
  " baseball": -7.6311407,
  " hockey": -0.0006307676
}