In [2]:
import os
import openai
import pandas as pd
import numpy as np
from tqdm import tqdm
import json
import time
openai.api_key = open("openai_api.key").read()

# Prepare training data

In [5]:
def proc(sent):
    if not sent.endswith(".") or sent.endswith("!"):  # finish with period
        sent += '.'
    if not sent[0].isupper():  # start with a capital letter
        sent = sent[0].upper() + sent[1:]
    return sent


In [3]:
df = pd.read_csv("data/train.csv")

In [19]:
intermediate_csv = [['prompt', 'completion']]
for i, line in df.iterrows():
    prompt = proc(line['startphrase']) + ' -> '
    completion = proc(line[f'ending{line["labels"]+1}'])
    intermediate_csv.append([prompt, completion])

In [20]:
pd.DataFrame(intermediate_csv).to_csv("data/finetune_train.csv", header=False, index=False)

In [21]:
!openai tools fine_tunes.prepare_data -f data/finetune_train.csv

Analyzing...

- Based on your file extension, your file is formatted as a CSV file
- Your file contains 1458 prompt-completion pairs
- All prompts end with suffix `. -> `
- All completions end with suffix `.`
- The completion should start with a whitespace character (` `). This tends to produce better results due to the tokenization we use. See https://beta.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more details

Based on the analysis we will perform the following actions:
- [Necessary] Your format `CSV` will be converted to `JSONL`
- [Recommended] Add a whitespace character to the beginning of the completion [Y/n]: ^C



In [None]:
!openai api fine_tunes.create -t data/finetune_train_prepared.jsonl  -m ada

In [29]:
ADA_FINETUNED = 'ada:ft-user-6qia53bwp385gfq1da9w5yum-2021-11-28-03-10-25'
BABBAGE_FINETUNED = 'babbage:ft-user-6qia53bwp385gfq1da9w5yum-2021-11-28-04-06-02'
CURIE_FINETUNED = 'curie:ft-user-6qia53bwp385gfq1da9w5yum-2021-11-28-04-35-14'

In [60]:
split = 'test'
df = pd.read_csv(f"data/{split}.csv")
restart = None


model_name = ADA_FINETUNED

if restart is None:
    json_lines = {}


if model_name != 'debug':
    response = input(f"about the spend $$$ on openai API (model {model_name})! conitnue? [y/n]")
    if response.lower() != 'y':
        raise Exception("Not continuing.")
else:
    print('just debugging. this is free.')

    
for i, line in tqdm(df.iterrows(), total=df.shape[0]):
    if restart is not None and i < restart: continue
        
    start = line['startphrase']
    end1 = line['ending1']
    end2 = line['ending2']
    res_two_endings = []
    for j, prompt in enumerate((
        proc(start)+' '+proc(end1),
        proc(start)+' '+proc(end2),
    )):
        if model_name == 'debug':
            res = debug_res
        else:
            completion = openai.Completion.create(model=model_name, prompt=prompt,
                                                      max_tokens=0,
                                                      temperature=0.0,
                                                      logprobs=0,
                                                      echo=True,
                                                      n=1)
            logprobs = completion['choices'][0]['logprobs']
            res = {k: logprobs[k] for k in ('token_logprobs', 'tokens')}
        res_two_endings.append(res)
        if model_name != 'debug':
            time.sleep(0.05)  # to prevent RateLimitError
    json_lines[f"{line.get('qid', i)}_{line['labels']}"] = res_two_endings


fname = f"{split}_logprobs_{model_name}.json"
with open(fname, 'w') as f:
    f.write('')

with open(fname, 'a') as f:
    json.dump(json_lines, f, indent=2)

about the spend $$$ on openai API (model ada:ft-user-6qia53bwp385gfq1da9w5yum-2021-11-28-03-10-25)! conitnue? [y/n]y


100%|██████████| 1146/1146 [07:49<00:00,  2.44it/s]


In [53]:
def prob_of_ending(token_logprobs, tokens):
    logprob_sum = 0
    for count, (lp, t) in enumerate(zip(token_logprobs[::-1], tokens[::-1])):
        if count > 0 and t.endswith('.'):
            break
        logprob_sum += lp
    return logprob_sum / count


def calculate_accuracy(fname):
    with open(fname) as f:
        logprobs = json.load(f)

    correct = 0
    for qid_label, (end1, end2) in logprobs.items():
        end1_prob = prob_of_ending(end1['token_logprobs'], end1['tokens'])
        end2_prob = prob_of_ending(end2['token_logprobs'], end2['tokens'])
        label = int(qid_label[-1])
        if (label == 0 and end1_prob > end2_prob) or (label==1 and end1_prob < end2_prob):
            correct += 1

    print(f"correct: {correct}/{len(logprobs)} = {correct/len(logprobs)}")

In [54]:
calculate_accuracy(f"dev_logprobs_{ADA_FINETUNED}.json")

correct: 794/1094 = 0.7257769652650823


In [55]:
calculate_accuracy(f"dev_logprobs_{BABBAGE_FINETUNED}.json")

correct: 832/1094 = 0.7605118829981719


In [56]:
calculate_accuracy(f"dev_logprobs_{CURIE_FINETUNED}.json")

correct: 866/1094 = 0.7915904936014625


In [63]:
calculate_accuracy(f"test_logprobs_{ADA_FINETUNED}.json")

correct: 792/1145 = 0.6917030567685589


In [64]:
calculate_accuracy(f"test_logprobs_{BABBAGE_FINETUNED}.json")

correct: 847/1145 = 0.7397379912663755


In [65]:
calculate_accuracy(f"test_logprobs_{CURIE_FINETUNED}.json")

correct: 905/1145 = 0.7903930131004366
