In [3]:
import openai

from fp_dataset_artifacts.utils import init_openai, upload, save_results
from fp_dataset_artifacts.snli import int2label
from datasets import list_datasets, load_dataset, list_metrics, load_metric

# Initialize OpenAI API with API_KEY
init_openai()

In [1]:
def map_finetune_hypothesis(x):
    hypothesis = x['hypothesis']
    label = int2label(x['label'])

    return {
        'prompt': f"Hypothesis: {hypothesis}\n\nLabel: ",
        'completion': label
        + '\n',  # '\n' added to prevent further text generation.
    }

In [6]:
data = load_dataset('snli')

train_sample_size = 55000

train = (data['train'].shuffle(0).select(list(range(train_sample_size)))).map(map_finetune_hypothesis)
train = train.remove_columns(['premise', 'hypothesis', 'label'])

valid = data['validation'].map(map_finetune_hypothesis)
valid = valid.remove_columns(['premise', 'hypothesis', 'label'])

train, valid

Reusing dataset snli (/home/x/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)
Loading cached shuffled indices for dataset at /home/x/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-0e9d0b15c43a175e.arrow
Loading cached processed dataset at /home/x/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-32180687a36b9af5.arrow


HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))




(Dataset({
     features: ['prompt', 'completion'],
     num_rows: 55000
 }),
 Dataset({
     features: ['prompt', 'completion'],
     num_rows: 10000
 }))

In [11]:
train_local_filename = 'snli_finetune_train_sample_hypothesis_only.jsonl'
valid_local_filename = 'snli_finetune_validation_hypothesis_only.jsonl'

In [9]:
upload(train, train_local_filename)

HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))




'file-m028aD3lDT3pheuZCp76ML4m'

In [10]:
upload(valid, valid_local_filename)

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




'file-ZdqDBx6kGH9rUTsYQWhptb96'

In [14]:
finetune_resp = openai.FineTune.create(
    training_file='file-m028aD3lDT3pheuZCp76ML4m',
    validation_file='file-ZdqDBx6kGH9rUTsYQWhptb96',
    model='curie',
    n_epochs=4,
    compute_classification_metrics=True,
    classification_n_classes=3,
)

# Get the fine-tune ID
finetune_id = finetune_resp['id']
finetune_id

'ft-9JSpBBaY4vJKmx5yzibrv2am'

In [25]:
openai.FineTune.retrieve(finetune_id)

<FineTune fine-tune id=ft-9JSpBBaY4vJKmx5yzibrv2am at 0x7f7b77a62f90> JSON: {
  "created_at": 1639093591,
  "events": [
    {
      "created_at": 1639093591,
      "level": "info",
      "message": "Created fine-tune: ft-9JSpBBaY4vJKmx5yzibrv2am",
      "object": "fine-tune-event"
    },
    {
      "created_at": 1639093594,
      "level": "info",
      "message": "Fine-tune costs $14.00",
      "object": "fine-tune-event"
    },
    {
      "created_at": 1639093595,
      "level": "info",
      "message": "Fine-tune enqueued. Queue number: 0",
      "object": "fine-tune-event"
    },
    {
      "created_at": 1639093598,
      "level": "info",
      "message": "Fine-tune started",
      "object": "fine-tune-event"
    },
    {
      "created_at": 1639093956,
      "level": "info",
      "message": "Completed epoch 1/4",
      "object": "fine-tune-event"
    },
    {
      "created_at": 1639094876,
      "level": "info",
      "message": "Completed epoch 2/4",
      "object": "fine-tun

In [27]:
save_results('file-wvA9xdPXmoqo8IrN7JNW144L', 'snli_hypothesis_only_results.csv')

In [28]:
import pandas as pd

df = pd.read_csv('../results/snli_hypothesis_only_results.csv')
df

Unnamed: 0,step,elapsed_tokens,elapsed_examples,training_loss,training_sequence_accuracy,training_token_accuracy,validation_loss,validation_sequence_accuracy,validation_token_accuracy,classification/accuracy,classification/weighted_f1_score
0,1,2049,1,0.368273,0.0,0.816024,0.314335,1.0,1.0,,
1,2,4098,2,0.335983,0.0,0.739521,,,,,
2,3,6147,3,0.442174,0.0,0.836207,,,,,
3,4,8196,4,0.425365,0.0,0.788520,,,,,
4,5,10245,5,0.394439,0.0,0.789941,,,,,
...,...,...,...,...,...,...,...,...,...,...,...
2284,2285,4681965,2285,0.081135,0.0,0.987805,,,,,
2285,2286,4684014,2286,0.075433,0.0,0.988060,,,,,
2286,2287,4686063,2287,0.075334,0.0,0.982301,,,,,
2287,2288,4688112,2288,0.080177,0.0,0.978979,,,,,
