In [114]:
import openai

from fp_dataset_artifacts.utils import init_openai, get_finetune_response
from fp_dataset_artifacts.snli import map_refs_and_preds
from fp_dataset_artifacts.anli import map_finetune, get_response
from datasets import (
    list_datasets, load_dataset, list_metrics, load_metric, concatenate_datasets
)

# Initialise the OpenAI API (API key) through the project helper before any API call.
init_openai()

# Load the two NLI corpora used in this ablation.
snli = load_dataset('snli')
anli = load_dataset('anli')

Reusing dataset snli (/home/x/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)
Reusing dataset anli (/home/x/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b)


In [70]:
# Prepare two separate training sets for fine-tuning:
#   * one built from SNLI only
#   * one built from SNLI plus ANLI round 1
# Both training sets must have the same size, so the SNLI-only set takes
# additional SNLI examples in place of the ANLI ones.
snli_training_size = 25000
anli_training_size = anli['train_r1'].num_rows
total_training_size = snli_training_size + anli_training_size
total_training_size

41946

In [71]:
# SNLI-only training set: the first `total_training_size` rows of the seed-0
# shuffle, converted with `map_finetune` and stripped of the raw NLI columns.
snli_only = (
    snli['train']
    .shuffle(0)
    .select(list(range(total_training_size)))
    .map(map_finetune)
    .remove_columns(['premise', 'hypothesis', 'label'])
)
snli_only

Loading cached shuffled indices for dataset at /home/x/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-0e9d0b15c43a175e.arrow


HBox(children=(FloatProgress(value=0.0, max=41946.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=41946.0), HTML(value='')))




Dataset({
    features: ['prompt', 'completion'],
    num_rows: 41946
})

In [72]:
# Every ANLI round-1 training example, converted with `map_finetune` and
# reduced to the columns that remain after dropping the raw ANLI fields.
anli_only = (
    anli['train_r1']
    .map(map_finetune)
    .remove_columns(['uid', 'reason', 'premise', 'hypothesis', 'label'])
)
anli_only

Loading cached processed dataset at /home/x/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b/cache-d198be314a77b6f4.arrow


Dataset({
    features: ['prompt', 'completion'],
    num_rows: 16946
})

In [73]:
# Mixed training set: the first `snli_training_size` rows of the SNLI-only set
# plus every ANLI round-1 example, shuffled again with seed 0.
snli_plus_anli = concatenate_datasets([
    snli_only.select(list(range(snli_training_size))),
    anli_only
]).shuffle(0)
snli_plus_anli

Dataset({
    features: ['prompt', 'completion'],
    num_rows: 41946
})

In [74]:
# Validation mirrors the training setup: the SNLI-only and the mixed
# validation sets end up with the same number of rows.
total_valid_size = snli['validation'].num_rows
anli_valid_size = anli['dev_r1'].num_rows
snli_valid_size = total_valid_size - anli_valid_size

snli_valid = (
    snli['validation']
    .shuffle(0)
    .select(list(range(total_valid_size)))
    .map(map_finetune)
    .remove_columns(['premise', 'hypothesis', 'label'])
)
snli_valid

Loading cached shuffled indices for dataset at /home/x/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-70baa31142a3b54f.arrow
Loading cached processed dataset at /home/x/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-7d2d35d5f4813980.arrow


Dataset({
    features: ['prompt', 'completion'],
    num_rows: 10000
})

In [75]:
# ANLI round-1 dev split, converted with `map_finetune` and stripped of the
# raw ANLI columns.
anli_valid = (
    anli['dev_r1']
    .map(map_finetune)
    .remove_columns(['uid', 'reason', 'premise', 'hypothesis', 'label'])
)
anli_valid

Loading cached processed dataset at /home/x/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b/cache-5e201f1ee0e85880.arrow


Dataset({
    features: ['prompt', 'completion'],
    num_rows: 1000
})

In [76]:
# Mixed validation set: the first `snli_valid_size` SNLI validation rows plus
# the whole ANLI round-1 dev split, shuffled with seed 0.
snli_plus_anli_valid = concatenate_datasets([
    snli_valid.select(list(range(snli_valid_size))),
    anli_valid
]).shuffle(0)
snli_plus_anli_valid

Loading cached shuffled indices for dataset at /home/x/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-be13fafa07810d17.arrow


Dataset({
    features: ['prompt', 'completion'],
    num_rows: 10000
})

In [77]:
def upload(dataset, filename, purpose='fine-tune'):
    """Export ``dataset`` to ``filename`` as JSON and upload that file to OpenAI.

    Args:
        dataset: Object exposing ``to_json(path)``; in this notebook a dataset
            with ``prompt`` and ``completion`` columns.
        filename: Local path the export is written to and then uploaded from.
        purpose: Forwarded unchanged as the ``purpose`` of the uploaded file.

    Returns:
        The ``id`` field of the file-creation response.
    """
    dataset.to_json(filename)

    # Open the export in a context manager so the handle is always closed,
    # even when the upload call raises.
    with open(filename) as export_file:
        response = openai.File.create(file=export_file, purpose=purpose)

    return response['id']

In [78]:
# Save and upload all files: a training and a validation file for each run.
# The validation uploads must run as well, because `snli_valid_id` and
# `anli_valid_id` are used by the cells that print the IDs and start the
# fine-tunes; with these lines commented out a fresh kernel raises NameError.
snli_train_id = upload(snli_only, 'ablation_snli_train.jsonl')
snli_valid_id = upload(snli_valid, 'ablation_snli_valid.jsonl')
anli_train_id = upload(snli_plus_anli, 'ablation_anli_train.jsonl')
anli_valid_id = upload(snli_plus_anli_valid, 'ablation_anli_valid.jsonl')

HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))




In [79]:
# Show the four uploaded file IDs, one per line.
print(snli_train_id, snli_valid_id, anli_train_id, anli_valid_id, sep='\n')

file-SbdP5nU4beGEMZCH9P56MPjH
file-1syChFe6i616f70H3MfnArBf
file-a7ryGWMmgQwczg959jdTDsJw
file-v5uBw86uR9XhsKeaKCLkTwKq


In [80]:
def finetune(
    train_file_id,
    valid_file_id,
    model='curie',
    n_epochs=4,
    compute_classification_metrics=True,
    classification_n_classes=3,
):
    """Create an OpenAI fine-tune job and return its ID.

    Args:
        train_file_id: ID of the uploaded training file.
        valid_file_id: ID of the uploaded validation file.
        model: Base model name passed to the fine-tune request.
        n_epochs: Number of epochs passed to the fine-tune request.
        compute_classification_metrics: Forwarded unchanged to the request.
        classification_n_classes: Forwarded unchanged to the request.

    Returns:
        The ``id`` field of the fine-tune creation response.
    """
    # Initialize OpenAI API with API_KEY
    init_openai()

    # Collect the request options first, then submit them in a single call.
    job_options = dict(
        training_file=train_file_id,
        validation_file=valid_file_id,
        model=model,
        n_epochs=n_epochs,
        compute_classification_metrics=compute_classification_metrics,
        classification_n_classes=classification_n_classes,
    )
    created_job = openai.FineTune.create(**job_options)

    return created_job['id']

In [84]:
# Fine-tune 1: SNLI only (extra SNLI examples take the place of the ANLI data).
snli_finetune_id = finetune(snli_train_id, snli_valid_id)
snli_finetune_id

'ft-a0nNphtyncMOu5XmSLlUkvDz'

In [81]:
# Fine-tune 2: SNLI combined with ANLI round 1.
anli_finetune_id = finetune(anli_train_id, anli_valid_id)
anli_finetune_id

'ft-rDeH2hvXBGo0ZlmsxEKbvkP4'

In [101]:
# Fetch the current state and event log of the SNLI-only fine-tune.
openai.FineTune.retrieve(snli_finetune_id)

<FineTune fine-tune id=ft-a0nNphtyncMOu5XmSLlUkvDz at 0x7f6914249360> JSON: {
  "created_at": 1638837445,
  "events": [
    {
      "created_at": 1638837445,
      "level": "info",
      "message": "Created fine-tune: ft-a0nNphtyncMOu5XmSLlUkvDz",
      "object": "fine-tune-event"
    },
    {
      "created_at": 1638843358,
      "level": "info",
      "message": "Fine-tune enqueued. Queue number: 0",
      "object": "fine-tune-event"
    },
    {
      "created_at": 1638843365,
      "level": "info",
      "message": "Fine-tune started",
      "object": "fine-tune-event"
    },
    {
      "created_at": 1638843811,
      "level": "info",
      "message": "Completed epoch 1/4",
      "object": "fine-tune-event"
    },
    {
      "created_at": 1638844462,
      "level": "info",
      "message": "Completed epoch 2/4",
      "object": "fine-tune-event"
    },
    {
      "created_at": 1638845102,
      "level": "info",
      "message": "Completed epoch 3/4",
      "object": "fine-tune-e

In [102]:
# Fetch the current state and event log of the SNLI + ANLI fine-tune.
openai.FineTune.retrieve(anli_finetune_id)

<FineTune fine-tune id=ft-rDeH2hvXBGo0ZlmsxEKbvkP4 at 0x7f6914247130> JSON: {
  "created_at": 1638837419,
  "events": [
    {
      "created_at": 1638837420,
      "level": "info",
      "message": "Created fine-tune: ft-rDeH2hvXBGo0ZlmsxEKbvkP4",
      "object": "fine-tune-event"
    },
    {
      "created_at": 1638839160,
      "level": "info",
      "message": "Fine-tune enqueued. Queue number: 11",
      "object": "fine-tune-event"
    },
    {
      "created_at": 1638839173,
      "level": "info",
      "message": "Fine-tune is in the queue. Queue number: 10",
      "object": "fine-tune-event"
    },
    {
      "created_at": 1638839256,
      "level": "info",
      "message": "Fine-tune is in the queue. Queue number: 9",
      "object": "fine-tune-event"
    },
    {
      "created_at": 1638839530,
      "level": "info",
      "message": "Fine-tune is in the queue. Queue number: 8",
      "object": "fine-tune-event"
    },
    {
      "created_at": 1638839562,
      "level": "in

In [103]:
def save_results(file_id, filename):
    """Download the OpenAI file ``file_id`` and store it as ``../results/<filename>``.

    Args:
        file_id: ID of the file to download through ``openai.File.download``.
        filename: Name of the output file inside ``../results/``.
    """
    destination = f'../results/{filename}'
    # Binary mode: the downloaded content is written out as-is.
    with open(destination, 'wb') as out_file:
        content = openai.File.download(file_id)
        out_file.write(content)

In [104]:
# NOTE(review): the file ID is hard-coded; presumably it is the results file of
# the SNLI-only fine-tune — confirm before re-running.
save_results('file-xCIoJMXr89SuZRVvTcI1yDbg', 'ablation_snli_results.csv')

In [105]:
# NOTE(review): the file ID is hard-coded; presumably it is the results file of
# the SNLI + ANLI fine-tune — confirm before re-running.
save_results('file-T2uIh00Sl5pTmPkmAoCfnBwl', 'ablation_anli_results.csv')

In [106]:
import pandas as pd

In [109]:
# Display the results CSV saved as 'ablation_snli_results.csv'.
pd.read_csv('../results/ablation_snli_results.csv')

Unnamed: 0,step,elapsed_tokens,elapsed_examples,training_loss,training_sequence_accuracy,training_token_accuracy,validation_loss,validation_sequence_accuracy,validation_token_accuracy,classification/accuracy,classification/weighted_f1_score
0,1,2049,1,0.331407,0.0,0.733728,0.323015,0.0,0.666667,,
1,2,4098,2,0.287834,0.0,0.729282,,,,,
2,3,6147,3,0.315665,0.0,0.810056,,,,,
3,4,8196,4,0.328352,0.0,0.804469,,,,,
4,5,10245,5,0.284234,0.0,0.822857,,,,,
...,...,...,...,...,...,...,...,...,...,...,...
3379,3380,6925620,3380,0.056051,1.0,1.000000,,,,,
3380,3381,6927669,3381,0.053326,0.0,0.994681,,,,,
3381,3382,6929718,3382,0.051672,1.0,1.000000,,,,,
3382,3383,6931767,3383,0.058237,1.0,1.000000,,,,,


In [110]:
# Display the results CSV saved as 'ablation_anli_results.csv'.
pd.read_csv('../results/ablation_anli_results.csv')

Unnamed: 0,step,elapsed_tokens,elapsed_examples,training_loss,training_sequence_accuracy,training_token_accuracy,validation_loss,validation_sequence_accuracy,validation_token_accuracy,classification/accuracy,classification/weighted_f1_score
0,1,4098,2,0.308890,0.0,0.710638,0.33548,0.5,0.875,,
1,2,8196,4,0.228767,0.0,0.792553,,,,,
2,3,12294,6,0.351408,0.0,0.792035,,,,,
3,4,16392,8,0.205856,0.0,0.742991,,,,,
4,5,20490,10,0.199700,0.0,0.830986,,,,,
...,...,...,...,...,...,...,...,...,...,...,...
2821,2822,11564556,5644,0.027085,1.0,1.000000,,,,,
2822,2823,11568654,5646,0.030592,0.5,0.995261,,,,,
2823,2824,11572752,5648,0.033020,1.0,1.000000,,,,,
2824,2825,11576850,5650,0.029584,0.0,0.990431,,,,,


In [122]:
def evaluate(
    finetune_id,
    test,
    responses_local_filename,
):
    """Query a completed fine-tuned model on ``test`` and print weighted F1 and accuracy.

    Args:
        finetune_id: ID of the fine-tune job whose model is evaluated.
        test: Dataset with a ``prompt`` column; one completion is requested
            per row. ``map_refs_and_preds`` presumably also needs the
            ``completion`` column — verify against that helper.
        responses_local_filename: File name inside ``../results/`` where the
            rows, including the raw model responses, are saved as JSON.

    Returns:
        The name of the fine-tuned model that was queried.

    Raises:
        AssertionError: If the last event of the fine-tune is not
            'Fine-tune succeeded'.
    """
    import warnings

    # Initialize OpenAI API with API_KEY
    init_openai()

    # Check if fine-tuning has completed.
    # And retrieve the model name.
    finetune_resp = openai.FineTune.retrieve(finetune_id)

    assert (
        finetune_resp['events'][-1]['message'] == 'Fine-tune succeeded'
    ), 'Please wait for the fine-tuning to be completed.'

    # Get the model name
    model = finetune_resp['fine_tuned_model']

    # Evaluate the model on test set
    def map_response(x):
        try:
            response = get_finetune_response(x['prompt'], model)
            return {'response': response['choices'][0]['text']}
        except Exception as e:
            # Best effort: keep the row with an empty response so one failed
            # request does not abort the whole run, but report the failure
            # instead of hiding it.
            warnings.warn(f'Completion request failed: {e!r}')
            return {'response': None}

    responses = test.map(map_response)

    # Save the responses
    responses.to_json(f'../results/{responses_local_filename}')

    # Load metrics
    f1_metric = load_metric('f1')
    acc_metric = load_metric('accuracy')

    # Convert response to references and predictions for metrics
    results = responses.map(map_refs_and_preds)

    # Compute metrics
    f1 = f1_metric.compute(
        references=results['references'],
        predictions=results['predictions'],
        average='weighted',
    )

    accuracy = acc_metric.compute(
        references=results['references'], predictions=results['predictions']
    )

    print(f'{f1=}')
    print(f'{accuracy=}')

    return model


In [116]:
# Fixed-size SNLI evaluation sample: the first `test_size` rows of the seed-0
# shuffle. `map_finetune` is applied exactly once, as for the training and
# validation sets.
test_size = 1000
snli_test = snli['test'].shuffle(0).select(list(range(test_size))).map(map_finetune)
snli_test = snli_test.remove_columns(['premise', 'hypothesis', 'label'])
snli_test

Loading cached shuffled indices for dataset at /home/x/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-c5677fec1a305d01.arrow


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




Dataset({
    features: ['prompt', 'completion'],
    num_rows: 1000
})

In [117]:
# ANLI round-1 evaluation sample: the first `test_size` rows of the seed-0
# shuffle. `map_finetune` is applied exactly once, as for the training and
# validation sets.
anli_test = anli['test_r1'].shuffle(0).select(list(range(test_size))).map(map_finetune)
anli_test = anli_test.remove_columns(['uid', 'reason', 'premise', 'hypothesis', 'label'])
anli_test

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




Dataset({
    features: ['prompt', 'completion'],
    num_rows: 1000
})

In [123]:
# SNLI-only model evaluated on the SNLI test sample.
evaluate(snli_finetune_id, snli_test, 'ablation_snli_snli_responses.jsonl')

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


f1={'f1': 0.8935001649479813}
accuracy={'accuracy': 0.893}


'curie:ft-user-5hzndcnnszukksvrzrlnjn8l-2021-12-07-03-02-36'

In [124]:
# SNLI-only model evaluated on the ANLI round-1 test sample.
evaluate(snli_finetune_id, anli_test, 'ablation_snli_anli_responses.jsonl')

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


f1={'f1': 0.3350202931874234}
accuracy={'accuracy': 0.337}


'curie:ft-user-5hzndcnnszukksvrzrlnjn8l-2021-12-07-03-02-36'

In [125]:
# SNLI + ANLI model evaluated on the SNLI test sample.
evaluate(anli_finetune_id, snli_test, 'ablation_anli_snli_responses.jsonl')

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


f1={'f1': 0.8892603877796055}
accuracy={'accuracy': 0.889}


'curie:ft-user-5hzndcnnszukksvrzrlnjn8l-2021-12-07-02-15-48'

In [127]:
# SNLI + ANLI model evaluated on the ANLI round-1 test sample.
evaluate(anli_finetune_id, anli_test, 'ablation_anli_anli_responses.jsonl')

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


f1={'f1': 0.5864711946043328}
accuracy={'accuracy': 0.586}


'curie:ft-user-5hzndcnnszukksvrzrlnjn8l-2021-12-07-02-15-48'