From c79fefca08c67bb2e853b3d71e293f4947d93cec Mon Sep 17 00:00:00 2001 From: Boris Power <81998504+BorisPower@users.noreply.github.com> Date: Thu, 29 Jul 2021 20:10:48 +0100 Subject: [PATCH] Updates to prepare_data function (#29) * update documentation links to point to the website * Fix encoding * Add rough time estimator based on historical stats * Fix train_test split naming logic; add quiet mode for running inside scripts * Add a finetuning step by step example for a classification use case. * add classification params if train and valid set; add length_validator --- README.md | 2 +- .../finetuning-classification.ipynb | 731 ++++++++++++++++++ openai/cli.py | 29 +- openai/validators.py | 205 +++-- openai/version.py | 2 +- 5 files changed, 901 insertions(+), 68 deletions(-) create mode 100644 examples/finetuning/finetuning-classification.ipynb diff --git a/README.md b/README.md index 0cbc53231b..c4c68e2d18 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,7 @@ openai api completions.create -e ada -p "Hello world" ## Requirements -- Python 3.6+ +- Python 3.7+ In general we want to support the versions of Python that our customers are using, so if you run into issues with any version diff --git a/examples/finetuning/finetuning-classification.ipynb b/examples/finetuning/finetuning-classification.ipynb new file mode 100644 index 0000000000..f6b886a494 --- /dev/null +++ b/examples/finetuning/finetuning-classification.ipynb @@ -0,0 +1,731 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Fine tuning classification example\n", + "\n", + "We will fine-tune an ada classifier to distinguish between the two sports: Baseball and Hockey." + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 21, + "source": [ + "from sklearn.datasets import fetch_20newsgroups\n", + "import pandas as pd\n", + "import openai\n", + "\n", + "categories = ['rec.sport.baseball', 'rec.sport.hockey']\n", + "sports_dataset = fetch_20newsgroups(subset='train', shuffle=True, random_state=42, categories=categories)" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + " ## Data exploration\n", + " The newsgroup dataset can be loaded using sklearn. First we will look at the data itself:" + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 4, + "source": [ + "print(sports_dataset['data'][0])" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "From: dougb@comm.mot.com (Doug Bank)\n", + "Subject: Re: Info needed for Cleveland tickets\n", + "Reply-To: dougb@ecs.comm.mot.com\n", + "Organization: Motorola Land Mobile Products Sector\n", + "Distribution: usa\n", + "Nntp-Posting-Host: 145.1.146.35\n", + "Lines: 17\n", + "\n", + "In article <1993Apr1.234031.4950@leland.Stanford.EDU>, bohnert@leland.Stanford.EDU (matthew bohnert) writes:\n", + "\n", + "|> I'm going to be in Cleveland Thursday, April 15 to Sunday, April 18.\n", + "|> Does anybody know if the Tribe will be in town on those dates, and\n", + "|> if so, who're they playing and if tickets are available?\n", + "\n", + "The tribe will be in town from April 16 to the 19th.\n", + "There are ALWAYS tickets available! (Though they are playing Toronto,\n", + "and many Toronto fans make the trip to Cleveland as it is easier to\n", + "get tickets in Cleveland than in Toronto. 
Either way, I seriously\n",
+    "doubt they will sell out until the end of the season.)\n",
+    "\n",
+    "-- \n",
+    "Doug Bank                       Private Systems Division\n",
+    "dougb@ecs.comm.mot.com          Motorola Communications Sector\n",
+    "dougb@nwu.edu                   Schaumburg, Illinois\n",
+    "dougb@casbah.acns.nwu.edu       708-576-8207\n",
+    "\n"
+     ]
+    }
+   ],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "source": [
+    "sports_dataset.target_names[sports_dataset['target'][0]]\n"
+   ],
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": [
+       "'rec.sport.baseball'"
+      ]
+     },
+     "metadata": {},
+     "execution_count": 5
+    }
+   ],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "source": [
+    "len_all, len_baseball, len_hockey = len(sports_dataset.data), len([e for e in sports_dataset.target if e == 0]), len([e for e in sports_dataset.target if e == 1])\n",
+    "print(f\"Total examples: {len_all}, Baseball examples: {len_baseball}, Hockey examples: {len_hockey}\")"
+   ],
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "Total examples: 1197, Baseball examples: 597, Hockey examples: 600\n"
+     ]
+    }
+   ],
+   "metadata": {}
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "One sample from the baseball category can be seen above. It is an email to a mailing list. We can observe that we have 1197 examples in total, which are evenly split between the two sports."
+   ],
+   "metadata": {}
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## Data Preparation\n",
+    "We transform the dataset into a pandas dataframe with a column for the prompt and a column for the completion. The prompt contains the email from the mailing list, and the completion is the name of the sport, either hockey or baseball. To speed up fine-tuning for demonstration purposes we could subsample to 300 examples (see the commented-out slice below); here we keep all 1197. In a real use case, the more examples the better the performance."
+   ],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "source": [
+    "import pandas as pd\n",
+    "\n",
+    "labels = [sports_dataset.target_names[x].split('.')[-1] for x in sports_dataset['target']]\n",
+    "texts = [text.strip() for text in sports_dataset['data']]\n",
+    "df = pd.DataFrame(zip(texts, labels), columns = ['prompt','completion']) #[:300]\n",
+    "df.head()"
+   ],
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": [
+       "                                              prompt completion\n",
+       "0  From: dougb@comm.mot.com (Doug Bank)\\nSubject:...   baseball\n",
+       "1  From: gld@cunixb.cc.columbia.edu (Gary L Dare)...     hockey\n",
+       "2  From: rudy@netcom.com (Rudy Wade)\\nSubject: Re...   baseball\n",
+       "3  From: monack@helium.gas.uug.arizona.edu (david...     hockey\n",
+       "4  Subject: Let it be Known\\nFrom: <ISSBTL@BYUVM....   baseball"
+      ]
\n", + "" + ] + }, + "metadata": {}, + "execution_count": 10 + } + ], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "Both baseball and hockey are single tokens. We save the dataset as a jsonl file." + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 11, + "source": [ + "df.to_json(\"sport1.jsonl\", orient='records', lines=True)" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "### Data Preparation tool\n", + "We can now use a data preparation tool which will suggest a few improvements to our dataset before fine-tuning. Before launching the tool we update the openai library to ensure we're using the latest data preparation tool. We additionally specify `-q` which auto-accepts all suggestions." + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "!pip install --upgrade openai" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 13, + "source": [ + "!openai tools fine_tunes.prepare_data -f sport1.jsonl -q" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Analyzing...\n", + "\n", + "- Your file contains 1197 prompt-completion pairs\n", + "- Based on your data it seems like you're trying to fine-tune a model for classification\n", + "- For classification, we recommend you try one of the faster and cheaper models, such as `ada`. You should also set the `--no_packing` parameter when fine-tuning\n", + "- For classification, you can estimate the expected model performance by keeping a held out dataset, which is not used for training\n", + "- Your data does not contain a common separator at the end of your prompts. Having a separator string appended to the end of the prompt makes it clearer to the fine-tuned model where the completion should begin. See https://beta.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples. If you intend to do open-ended generation, then you should leave the prompts empty\n", + "- The completion should start with a whitespace character (` `). This tends to produce better results due to the tokenization we use. See https://beta.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more details\n", + "\n", + "Based on the analysis we will perform the following actions:\n", + "- [Recommended] Add a suffix separator `\\n\\n###\\n\\n` to all prompts [Y/n]: Y- [Recommended] Add a whitespace character to the beginning of the completion [Y/n]: Y- [Recommended] Would you like to split into training and validation set? [Y/n]: Y\n", + "\n", + "Your data will be written to a new JSONL file. Proceed [Y/n]: Y\n", + "Wrote modified files to `sport1_prepared_train.jsonl` and `sport1_prepared_valid.jsonl`\n", + "Feel free to take a look!\n", + "\n", + "Now use that file when fine-tuning:\n", + "> openai api fine_tunes.create -t \"sport1_prepared_train.jsonl\" -v \"sport1_prepared_valid.jsonl\" --no_packing\n", + "\n", + "After you’ve fine-tuned a model, remember that your prompt has to end with the indicator string `\\n\\n###\\n\\n` for the model to start generating completions, rather than continuing with the prompt.\n", + "Once your model starts training, it'll approximately take 31.06 minutes. 
Queue will approximately take half an hour per job ahead of you.\n" + ] + } + ], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "The tool helpfully suggests a few improvements to the dataset and splits the dataset into training and validation set.\n", + "\n", + "A suffix between a prompt and a completion is necessary to tell the model that the input text has stopped, and that it now needs to predict the class. Since we use the same separator in each example, the model is able to learn that it is meant to predict either baseball or hockey following the separator.\n", + "A whitespace prefix in completions is useful, as most word tokens are tokenized with a space prefix.\n", + "The tool also recognized that this is likely a classification task, so it suggested to split the dataset into training and validation datasets. This will allow us to easily measure expected performance on new data." + ], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "## Fine-tuning\n", + "The tool suggests we run the following command to train the dataset. We specifically add `-m ada` to fine-tune a cheaper and faster ada model, which is usually comperable in performance to slower and more expensive models on classification use cases. Since this is a classification task, we would like to know what the generalization performance on the provided validation set is for our classification use case. We add `--compute_classification_metrics --classification_positive_class \" hockey\"` in order to compute the classification metrics." + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 14, + "source": [ + "!openai api fine_tunes.create -t \"sport1_prepared_train.jsonl\" -v \"sport1_prepared_valid.jsonl\" --no_packing -m ada --compute_classification_metrics --classification_positive_class \" hockey\"" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Upload progress: 100%|████████████████████| 1.76M/1.76M [00:00<00:00, 1.85Mit/s]\n", + "Uploaded file from sport1_prepared_train.jsonl: file-6TJY51ApcI0YzumClqdpyhjk\n", + "Upload progress: 100%|███████████████████████| 395k/395k [00:00<00:00, 754kit/s]\n", + "Uploaded file from sport1_prepared_valid.jsonl: file-7jmZYAJHneAuzVGlauejsas9\n", + "Created fine-tune: ft-T4UkKqMbMM1Eu56q8ks6g8u5\n", + "Streaming events until fine-tuning is complete...\n", + "\n", + "(Ctrl-C will interrupt the stream, but not cancel the fine-tune)\n", + "[2021-07-26 12:13:52] Created fine-tune: ft-T4UkKqMbMM1Eu56q8ks6g8u5\n", + "[2021-07-26 12:13:57] Fine-tune enqueued. Queue number: 0\n", + "[2021-07-26 12:14:00] Fine-tune started\n", + "[2021-07-26 12:16:56] Completed epoch 1/4\n", + "[2021-07-26 12:18:37] Completed epoch 2/4\n", + "[2021-07-26 12:20:29] Completed epoch 3/4\n", + "[2021-07-26 12:22:31] Completed epoch 4/4\n", + "[2021-07-26 12:24:02] Uploaded model: ada:ft-openai-internal-2021-07-26-11-24-00\n", + "[2021-07-26 12:24:06] Uploaded result file: file-ForZ3pSAQ6db7bxmMJhw6GEo\n", + "[2021-07-26 12:24:07] Fine-tune succeeded\n", + "\n", + "Job complete! Status: succeeded 🎉\n", + "Try out your fine-tuned model:\n", + "\n", + "openai api completions.create -m ada:ft-openai-internal-2021-07-26-11-24-00 -p \n" + ] + } + ], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "The model is successfully trained in about ten minutes. We can see the model name is `ada:ft-openai-internal-2021-07-26-11-24-00`, which we can use for doing inference." 
+ ], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "### [Advanced] Results and expected model performance\n", + "We can now download the results file to observe the expected performance on a held out validation set." + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 15, + "source": [ + "!openai api fine_tunes.results -i ft-T4UkKqMbMM1Eu56q8ks6g8u5 > result.csv" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 17, + "source": [ + "results = pd.read_csv('result.csv')\n", + "results[results['classification/accuracy'].notnull()].tail(1)" + ], + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " step elapsed_tokens elapsed_examples training_loss \\\n", + "926 927 3108476 3708 0.022579 \n", + "\n", + " training_sequence_accuracy training_token_accuracy \\\n", + "926 1.0 1.0 \n", + "\n", + " classification/accuracy classification/precision classification/recall \\\n", + "926 0.995833 1.0 0.991667 \n", + "\n", + " classification/auroc classification/auprc classification/f1.0 \\\n", + "926 0.99875 0.998909 0.995816 \n", + "\n", + " validation_loss validation_sequence_accuracy validation_token_accuracy \n", + "926 NaN NaN NaN " + ], + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
stepelapsed_tokenselapsed_examplestraining_losstraining_sequence_accuracytraining_token_accuracyclassification/accuracyclassification/precisionclassification/recallclassification/aurocclassification/auprcclassification/f1.0validation_lossvalidation_sequence_accuracyvalidation_token_accuracy
926927310847637080.0225791.01.00.9958331.00.9916670.998750.9989090.995816NaNNaNNaN
\n", + "
" + ] + }, + "metadata": {}, + "execution_count": 17 + } + ], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "The accuracy reaches 99.6%. On the plot below we can see how accuracy on the validation set increases during the training run. " + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 18, + "source": [ + "results[results['classification/accuracy'].notnull()]['classification/accuracy'].plot()" + ], + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": {}, + "execution_count": 18 + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAbEUlEQVR4nO3deXCc9Z3n8ffXumzLlmRbsrAOX2DARlxG+AhJYAMshk04E4LBJtlih3+Gmdns7G5BzVZ2lqqtra2amuxMFZsdZjY7Y5mbEOKwDiQBMuxMJNsyh/GNMG5dPuRLkmXr/u4f/dhphGw1dktP99OfV1UX/TzPj+5vP3748Pg5vo+5OyIikvkmhV2AiIikhgJdRCQiFOgiIhGhQBcRiQgFuohIROSG9cWlpaU+f/78sL5eRCQjbd269Yi7l422LLRAnz9/Po2NjWF9vYhIRjKz2LmW6ZCLiEhEKNBFRCJCgS4iEhFjBrqZ/cTMDpvZ9nMsNzP7azNrMrNtZrY09WWKiMhYktlD/3tg1XmW3wksCl6PAz+++LJEROTLGjPQ3f094Nh5htwDrPO4BqDEzOakqkAREUlOKo6hVwItCdOtwTwREZlAE3odupk9TvywDHPnzp3IrxYZ0/Gefna0d7GjvZOevsGwy5EIu3VxOddWl6T8c1MR6G1AdcJ0VTDvC9z9WeBZgNraWjVil9AcOdnH9rbO4NXF9vZOWo+fPrvcLMTiJPJmF01O20DfADxhZi8Cy4FOdz+Qgs8VSYlDXb1sb+vk4zPh3dbJwa7es8sXlBZyXXUJa1fMo6aymKsqiiiZmh9ixSIXZsxAN7MXgFuAUjNrBf4zkAfg7v8L2AjcBTQBp4B/PV7FipyPu9Pe2Zuw593J9vYuOrr7gPhe96Vl01ixcCY1lcXUVBazpKKIosl5IVcukhpjBrq7rx5juQN/mLKKRJLg7rQcO8329jN73p3saO/iWE8/ADmTjEWzp/H1RWVcXVlETWUxi+cUUVgQWvsikXGnrVvS3vCws/9oD9vbuz63993VGz9xmZdjXF4+ndsXl1NTVUxNRRGL5xQxOS8n5MpFJpYCXdLK0LCzr+NkfM+7NX6ycmd7FyeDq07ycyex+JLpfPPaCmoqirm6spjLL5lGQa7CW0SBLqEZGBqm6fDJzx3v3tnexemBIQAm501iyZwi7l9aSU1F/Jj3ovJp5OWoBZHIaBToMiH6B4fZe6j791ebtHex+0AXfYPDABTm53BVRTEPLauO73lXFbOwtJBchbdI0hToknK9A0PsPtidsOfdyZ6D3QwMxW89mF6Qy1WVRTy6ct7Zq00WzCpk0iRd/C1yMRToclFO9Q+y60AX29u6zl5t8snhkwwNx8O7ZGoeNRXFPPbVhdRUFnF1ZTHVM6YqvEXGgQJdktbdO8DO9q7PXW3yacdJguxmVmE+NZXF3La4nJrgUsHKkimYbrsUmRAKdBlV5+kBdgSHSz5u62JHWyf7jvScXV5eVEBNRTF3XT2Hmsr41SblRQUKb5EQKdCFYz39Z491n+lt0nzs1NnllSVTuKqiiPuur4zfGl9ZxOzpk0OsWERGo0DPMoe7e9kR9DP5OLi7su3E75tSzZ05lZrKorNXm9RUFjOzUH1NRDKBAj3ijvX08w+/2392D/xQV9/ZZQtLC1k6bwbf+8o8aiqKuaqimOKp6msikqkU6BH3Zz/7mDd3HOSysml85dLS+GWCFUUsqShiuppSiUSKAj3CDnb28qudh3j8awt56q7FYZcjIuNMt+FF2PObmxl25+HlejqUSDZQoEfUwNAwL2xu5ubLy5g3qzDsckRkAijQI+pXOw7R0d3HoyvnhV2KiEwQBXpEravfT9WMKdx8+eywSxGRCaJAj6C9h7rZ9Nkx1qyYR456pohkDQV6BNXVx8jPncSDtdVhlyIiE0iBHjEn+wZ57f1WvnnNHN3hKZJlFOgR87P3W+npH2LtCp0MFck2CvQIcXfqGmJcXVnMddUlYZcjIhNMgR4hmz47xt5DJ1m7Yp7a2IpkIQV6hNQ1xCiekse3rq0IuxQRCYECPSIOd/Xy1vaDfOeGKqbk54RdjoiEQIEeES9sbmFw2HlEJ0NFspYCPQIGhoZ5fnOMr19exoJS9W0RyVYK9Aj4zc5DHOrq06WKIllOgR4BdQ0xKkum8I0r1bdFJJsp0DNc0+FufvfpUR5ePld9W0SynAI9w61vaCY/ZxLfvVF9W0SynQI9g/X0DfLTra3cdfUllE4rCLscEQlZUoFuZqvMbI+ZNZnZk6Msn2dmb5vZNjP7rZlVpb5UGen1D9vo7htk7cr5YZciImlgzEA3sxzgGeBOYAmw2syWjBj2F8A6d78GeBr4b6kuVD7P3amrj7FkThFL55aEXY6IpIFk9tCXAU3uvs/d+4EXgXtGjFkCvBO8f3eU5ZJijbHj7D7YzdqV6tsiInHJBHol0JIw3RrMS/QRcH/w/j5gupnNGvlBZva4mTWaWWNHR8eF1CuBuvoY0yfncs916tsiInGpOin674GbzewD4GagDRgaOcjdn3X3WnevLSsrS9FXZ5+O7j5+uf0A376hiqn5uWGXIyJpIpk0aAMSr4mrCuad5e7tBHvoZjYNeMDdT6SoRhnhpS3NDAw5a3RnqIgkSGYPfQuwyMwWmFk+8BCwIXGAmZWa2ZnPegr4SWrLlDMGh4Z5blMzX72slEvLpoVdjoikkTED3d0HgSeAt4BdwMvuvsPMnjazu4NhtwB7zGwvUA7813GqN+u9vfswBzp7WbtSe+ci8nlJHYB1943AxhHzfpjw/lXg1dSWJqOpq49RUTyZW9W3RURG0J2iGeTTjpP8U9MRHl4+l9wc/dGJyOcpFTLI+oYYeTnGg+rbIiKjUKBniFP9g7y6tZVVNXOYPX1y2OWISBpSoGeIDR+20907yKM6GSoi56BAzwDuzrr6GFdeMp3aeTPCLkdE0pQCPQO833yCnQe61LdFRM5LgZ4B6ur3M70gl3uvG9lCR0Tk9xToae7IyT42fnyQB26oorBAfVtE5NwU6GnupS0t9A8Ns2bF3LBLEZE0p0BPY0PDzvObmvnKpbO4bPb0sMsRkTSnQE9j7+w+TNuJ06xVV0URSYICPY3VNcQoLyrgtiXlYZciIhlAgZ6m9h/p4b29HTy8bB556tsiIklQUqSp9Q0xcicZq5epb4uIJEeBnoZO9w/xytZW7qi5hNlF6tsiIslRoKehX3zUTufpAZ0MFZEvRYGeZtyddQ37ubx8GssXzAy7HBHJIAr0NPNhywm2t3WxdoX6tojIl6NATzN1DTEK83O4b2lV2KWISIZRoKeRYz39vLHtAPcvrWKa+raIyJekQE8jLze20D84zFo9xEJELoACPU0MDTvrG2IsXzCTy8vVt0VEvjwFepr4x72HaT1+WnvnInLBFOhpoq4+Rtn0Au646pKwSxGRDKVATwPNR0/x270drF42V31bROSCKT3SwHObYkwy4+FleoiFiFw4BXrIegeGeKmxhX+5pJxLitW3
RUQunAI9ZG9sO8CJUwM6GSoiF02BHrK6+v1cWlbIyoWzwi5FRDKcAj1EH7Wc4KPWTvVtEZGUUKCHqK4hxtT8HO6/QX1bROTiJRXoZrbKzPaYWZOZPTnK8rlm9q6ZfWBm28zsrtSXGi3He/r5xUft3Ht9JUWT88IuR0QiYMxAN7Mc4BngTmAJsNrMlowY9p+Al939euAh4H+mutCoeXVrK32Dw3qIhYikTDJ76MuAJnff5+79wIvAPSPGOFAUvC8G2lNXYvQMDzvrN8W4cf4MFs8pGvtfEBFJQjKBXgm0JEy3BvMS/TmwxsxagY3AH432QWb2uJk1mlljR0fHBZQbDe990kHs6CnWrpwfdikiEiGpOim6Gvh7d68C7gLqzOwLn+3uz7p7rbvXlpWVpeirM09dfYzSaQWsUt8WEUmhZAK9DahOmK4K5iV6DHgZwN3rgclAaSoKjJqWY6d4Z89hVi+rJj9XFxmJSOokkyhbgEVmtsDM8omf9NwwYkwzcCuAmS0mHujZe0zlPJ7b1IwBq9W3RURSbMxAd/dB4AngLWAX8atZdpjZ02Z2dzDsT4E/MLOPgBeA77u7j1fRmap3YIiXG1u4fUk5FSVTwi5HRCImqQdXuvtG4ic7E+f9MOH9TuCm1JYWPRs/PsCxnn7WrpgfdikiEkE6iDuB6hpiLCwt5CuXqm+LiKSeAn2CbG/r5IPmE6xZMY9Jk9S3RURST4E+QerqY0zJy+EB9W0RkXGiQJ8AnacG+PlHbdx7fQXFU9S3RUTGhwJ9AryytYXegWHWqG+LiIwjBfo4Gx52ntvUzA3zZnBVRXHY5YhIhCnQx9k/NR3hsyM96qooIuNOgT7O6hpizCrM586r1bdFRMaXAn0ctZ04zdu7DvHdG6spyM0JuxwRiTgF+jh6flMMgIeXq2+LiIw/Bfo46Rsc4sXNLXzjynKqZkwNuxwRyQIK9HHy5vaDHO3pZ+1KnQwVkYmhQB8ndfUx5s+aytcuU1t4EZkYCvRxsLO9i8bYcfVtEZEJpUAfB3UNMSbnTeI7N1SPPVhEJEUU6CnWeXqA1z9o4+5rKyieqr4tIjJxFOgp9tr7rZweGOLRlfPDLkVEsowCPYXcnbqGGNdVl1BTqb4tIjKxFOgp9LtPj7KvQ31bRCQcCvQUWle/nxlT8/hX18wJuxQRyUIK9BQ50HmaX+88xIM3VjM5T31bRGTiKdBT5PlNzTiwZrkOt4hIOBToKdA/OMwLm1v4F1fMpnqm+raISDgU6Cnw1o6DHDnZp74tIhIqBXoK1NXHmDtzKjcvKgu7FBHJYgr0i7T7YBeb9x9jzYq56tsiIqFSoF+kuvoY+bnq2yIi4VOgX4Tu3gF+9kEb37qmghmF+WGXIyJZToF+EV57v41T/UM8qpOhIpIGFOgX6Ezflmuqirm2uiTsckREFOgXqn7fUZoOn1TfFhFJG0kFupmtMrM9ZtZkZk+OsvxHZvZh8NprZidSXmmaWd8Qo2RqHt+6tiLsUkREAMgda4CZ5QDPALcDrcAWM9vg7jvPjHH3HySM/yPg+nGoNW0c6urlrR2HeOyrC9S3RUTSRjJ76MuAJnff5+79wIvAPecZvxp4IRXFpavnNzUz7M4jy+eGXYqIyFnJBHol0JIw3RrM+wIzmwcsAN45x/LHzazRzBo7Ojq+bK1pYWBomBc2N3Pz5WXMm1UYdjkiImel+qToQ8Cr7j402kJ3f9bda929tqwsM2+T/9WOQxzu7tPJUBFJO8kEehuQeBtkVTBvNA8R8cMtdQ37qSyZwi1XzA67FBGRz0km0LcAi8xsgZnlEw/tDSMHmdmVwAygPrUlpo+9h7pp2HeMNSvmkaO+LSKSZsYMdHcfBJ4A3gJ2AS+7+w4ze9rM7k4Y+hDworv7+JQavvUNMfJzJvFgbVXYpYiIfMGYly0CuPtGYOOIeT8cMf3nqSsr/ZzsG+S199v45jVzmDWtIOxyRES+QHeKJulnH7Rxsm+QNerbIiJpSoGeBHdnfX2MmsoirlffFhFJUwr0JGz+7Bh7DnWzdsU8zHQyVETSkwI9CXUNMYom53L3taPeTyUikhYU6GM43NXLm9sP8p3aaqbkq2+LiKQvBfoYXtzSwuCws0Z3hopImlOgn8fg0DDPb2rma4tKWVCqvi0ikt4U6Ofxm12HONjVq74tIpIRFOjnsa4+RmXJFG5dXB52KSIiY1Kgn0PT4W5+9+lRHl4+V31bRCQjKNDPYX1DM3k5xndvrB57sIhIGlCgj6Knb5Cfbm3lrqvnUKq+LSKSIRToo/j5h+109w3yqPq2iEgGUaCP4O6sq9/P4jlFLJ07I+xyRESSpkAfYWvsOLsPdvPoSvVtEZHMokAfYV19jOkFudxzXUXYpYiIfCkK9AQd3X38cvsBHrihiqn5ST37Q0QkbSjQE7y0pZmBIWetToaKSAZSoAfO9G256bJZXFo2LexyRES+NAV64O3dh2nv7GXtivlhlyIickEU6IH1DTHmFE/mtsWzwy5FROSCKNCBfR0n+X+fHOHhZXPJzdEqEZHMpPQioW/LMvVtEZHMlfWBfqp/kFe2trCqZg6zp08OuxwRkQuW9YG+4cN2unsH9RALEcl4WR3o8b4tMa4on86N89W3RUQyW1YH+vvNJ9h5oIu16tsiIhGQ1YG+viHGtIJc7r2+MuxSREQuWtYG+pGTffzfbQd4YGkl0wrUt0VEMl/WBvrLjS30Dw2zRidDRSQisjLQh4ad5xqaWblwFovKp4ddjohISiQV6Ga2ysz2mFmTmT15jjEPmtlOM9thZs+ntszUenf3YdpOnFZXRRGJlDEPHptZDvAMcDvQCmwxsw3uvjNhzCLgKeAmdz9uZmndEKWuIUZ5UQG3LykPuxQRkZRJZg99GdDk7vvcvR94EbhnxJg/AJ5x9+MA7n44tWWmzv4jPfzj3g5WL5tLnvq2iEiEJJNolUBLwnRrMC/R5cDlZvbPZtZgZqtG+yAze9zMGs2ssaOj48IqvkjPbYqRO8lYvWxuKN8vIjJeUrWLmgssAm4BVgN/a2YlIwe5+7PuXuvutWVlZSn66uSd7h/i5cZW7rjqEsqL1LdFRKIlmUBvAxLbEFYF8xK1AhvcfcDdPwP2Eg/4tPKLbe10nh7QpYoiEknJBPoWYJGZLTCzfOAhYMOIMa8T3zvHzEqJH4LZl7oyL567U1cfY9HsaaxYODPsckREUm7MQHf3QeAJ4C1gF/Cyu+8ws6fN7O5g2FvAUTPbCbwL/Ad3PzpeRV+Ij1o7+bitU31bRCSykrrn3d03AhtHzPthwnsH/l3wSkvr6vdTmJ/DferbIiIRlRXX7R3r6eeNbQe4b2kl0yfnhV2OiMi4yIpAf6Wxhf7BYdaumB92KSIi4ybygT407KzfFGPZgplccYn6tohIdEU+0N/b20HLsdN6xJyIRF7kA31d/X7Kphdwx1WXhF2KiMi4inSgNx89xW/3drD6xmrycyP9U0VEoh3oz22KMcmM1cvVt0VEoi+ygd47MMRLjS3cvricOcVTwi5HRGTcRTbQ39h2gBO
nBnhUD7EQkSwR2UCva4hxaVkhKy+dFXYpIiITIpKBvq31BB+1nGDtCvVtEZHsEclAr6uPMSUvh/tvqAq7FBGRCRO5QD9xqp8NH7Vz7/WVFKlvi4hkkcgF+iuNrfQNDuvOUBHJOpEK9OGgb0vtvBksqSgKuxwRkQkVqUB/75MOYkdPsVaXKopIFopUoK9viFE6LZ9VNerbIiLZJzKB3nLsFG/vPsxDN86lIDcn7HJERCZcZAL9+c3NGKhvi4hkrUgEeu/AEC9taeG2xeVUlqhvi4hkp0gE+i+3H+BYT79OhopIVotEoNfVx1hYWshNl5aGXYqISGgyPtC3t3XyfvMJHlkxj0mT1LdFRLJXxgf6+oYYk/Mm8e2l6tsiItktowO989QAr3/Yxr3XVVI8VX1bRCS7ZXSgv/p+K70Dw6xR3xYRkcwN9OFhZ31DjKVzS6ipLA67HBGR0GVsoP/zp0f47EiPLlUUEQlkbKCvq48xszCfu66eE3YpIiJpISMDve3Ead7edYjv3litvi0iIoGMDPTnN8Vw4BH1bREROSupQDezVWa2x8yazOzJUZZ/38w6zOzD4PVvUl9qXN9gvG/LrVfOpmrG1PH6GhGRjJM71gAzywGeAW4HWoEtZrbB3XeOGPqSuz8xDjV+zpvbD3LkZL8uVRQRGSGZPfRlQJO773P3fuBF4J7xLevcCvNzuX1JOV9fVBZWCSIiaWnMPXSgEmhJmG4Flo8y7gEz+zqwF/iBu7eMMuai3baknNuWlI/HR4uIZLRUnRT9BTDf3a8Bfg38w2iDzOxxM2s0s8aOjo4UfbWIiEBygd4GVCdMVwXzznL3o+7eF0z+HXDDaB/k7s+6e62715aV6ZCJiEgqJRPoW4BFZrbAzPKBh4ANiQPMLPHunruBXakrUUREkjHmMXR3HzSzJ4C3gBzgJ+6+w8yeBhrdfQPwx2Z2NzAIHAO+P441i4jIKMzdQ/ni2tpab2xsDOW7RUQylZltdffa0ZZl5J2iIiLyRQp0EZGIUKCLiEREaMfQzawDiF3gv14KHElhOZlK6yFO6yFO6yE71sE8dx/1uu/QAv1imFnjuU4KZBOthzithzitB60DHXIREYkIBbqISERkaqA/G3YBaULrIU7rIU7rIcvXQUYeQxcRkS/K1D10EREZQYEuIhIRGRfoYz3fNCrMrNrM3jWznWa2w8z+JJg/08x+bWafBP+cEcw3M/vrYL1sM7Ol4f6C1DKzHDP7wMzeCKYXmNmm4Pe+FHQCxcwKgummYPn8UAtPITMrMbNXzWy3me0ys5XZtj2Y2Q+C/x62m9kLZjY5G7eFc8moQE94vumdwBJgtZktCbeqcTMI/Km7LwFWAH8Y/NYngbfdfRHwdjAN8XWyKHg9Dvx44kseV3/C59sy/3fgR+5+GXAceCyY/xhwPJj/o2BcVPwV8Ka7XwlcS3x9ZM32YGaVwB8Dte5eQ7z760Nk57YwOnfPmBewEngrYfop4Kmw65qg3/5z4g/q3gPMCebNAfYE7/8GWJ0w/uy4TH8Rf6jK28A3gDcAI343YO7I7YJ4m+eVwfvcYJyF/RtSsA6Kgc9G/pZs2h74/eMwZwZ/tm8Ad2TbtnC+V0btoTP6800rQ6plwgR/Vbwe2ASUu/uBYNFB4MwDVqO8bv4H8B+B4WB6FnDC3QeD6cTfenY9BMs7g/GZbgHQAfyf4NDT35lZIVm0Pbh7G/AXQDNwgPif7Vayb1s4p0wL9KxjZtOAnwL/1t27Epd5fNcj0tedmtk3gcPuvjXsWkKWCywFfuzu1wM9/P7wChD97SE4P3AP8f+5VQCFwKpQi0ozmRboYz7fNErMLI94mD/n7q8Fsw+deeRf8M/DwfyorpubgLvNbD/wIvHDLn8FlJjZmSduJf7Ws+shWF4MHJ3IgsdJK9Dq7puC6VeJB3w2bQ+3AZ+5e4e7DwCvEd8+sm1bOKdMC/Qxn28aFWZmwP8Gdrn7XyYs2gB8L3j/PeLH1s/MfzS4umEF0JnwV/GM5e5PuXuVu88n/uf9jrs/ArwLfDsYNnI9nFk/3w7GZ/xeq7sfBFrM7Ipg1q3ATrJre2gGVpjZ1OC/jzPrIKu2hfMK+yD+l30BdwF7gU+BPwu7nnH8nV8l/tfnbcCHwesu4scA3wY+AX4DzAzGG/ErgD4FPiZ+JUDovyPF6+QW4I3g/UJgM9AEvAIUBPMnB9NNwfKFYdedwt9/HdAYbBOvAzOybXsA/guwG9gO1AEF2bgtnOulW/9FRCIi0w65iIjIOSjQRUQiQoEuIhIRCnQRkYhQoIuIRIQCXUQkIhToIiIR8f8BOpHLTjKrpzgAAAAASUVORK5CYII=" + }, + "metadata": { + "needs_background": "light" + } + } + ], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "## Using the model\n", + "We can now call the model to get the predictions." + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 19, + "source": [ + "test = pd.read_json('sport1_prepared_valid.jsonl', lines=True)\n", + "test.head()" + ], + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " prompt completion\n", + "0 From: gld@cunixb.cc.columbia.edu (Gary L Dare)... hockey\n", + "1 From: smorris@venus.lerc.nasa.gov (Ron Morris ... hockey\n", + "2 From: golchowy@alchemy.chem.utoronto.ca (Geral... hockey\n", + "3 From: krattige@hpcc01.corp.hp.com (Kim Krattig... baseball\n", + "4 From: warped@cs.montana.edu (Doug Dolven)\\nSub... baseball" + ], + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
promptcompletion
0From: gld@cunixb.cc.columbia.edu (Gary L Dare)...hockey
1From: smorris@venus.lerc.nasa.gov (Ron Morris ...hockey
2From: golchowy@alchemy.chem.utoronto.ca (Geral...hockey
3From: krattige@hpcc01.corp.hp.com (Kim Krattig...baseball
4From: warped@cs.montana.edu (Doug Dolven)\\nSub...baseball
\n", + "
" + ] + }, + "metadata": {}, + "execution_count": 19 + } + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 30, + "source": [ + "ft_model = 'ada:ft-openai-internal-2021-07-26-11-24-00'\n", + "res = openai.Completion.create(model=ft_model, prompt=test['prompt'][0] + '\\n\\n###\\n\\n', max_tokens=1, temperature=0)\n", + "res['choices'][0]['text']\n" + ], + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "' hockey'" + ] + }, + "metadata": {}, + "execution_count": 30 + } + ], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "To get the log probabilities, we can specify logprobs parameter on the completion request" + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 29, + "source": [ + "res = openai.Completion.create(model=ft_model, prompt=test['prompt'][0] + '\\n\\n###\\n\\n', max_tokens=1, temperature=0, logprobs=2)\n", + "res['choices'][0]['logprobs']['top_logprobs'][0]" + ], + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " JSON: {\n", + " \" baseball\": -6.3311357,\n", + " \" hockey\": -0.0018503045\n", + "}" + ] + }, + "metadata": {}, + "execution_count": 29 + } + ], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "We can see that the model predicts hockey as a lot more likely than baseball, which is the correct prediction. By requesting log_probs, we can see the prediction (log) probability for each class." + ], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "### Generalization\n", + "Interestingly, our fine-tuned classifier is quite versatile. Despite being trained on emails to different mailing lists, it also successfully predicts tweets." + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 28, + "source": [ + "sample_hockey_tweet = \"\"\"Thank you to the \n", + "@Canes\n", + " and all you amazing Caniacs that have been so supportive! You guys are some of the best fans in the NHL without a doubt! 
Really excited to start this new chapter in my career with the \n", + "@DetroitRedWings\n", + " !!\"\"\"\n", + "res = openai.Completion.create(model=ft_model, prompt=sample_hockey_tweet + '\\n\\n###\\n\\n', max_tokens=1, temperature=0, logprobs=2)\n", + "res['choices'][0]['text']" + ], + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "' hockey'" + ] + }, + "metadata": {}, + "execution_count": 28 + } + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 31, + "source": [ + "sample_baseball_tweet=\"\"\"BREAKING: The Tampa Bay Rays are finalizing a deal to acquire slugger Nelson Cruz from the Minnesota Twins, sources tell ESPN.\"\"\"\n", + "res = openai.Completion.create(model=ft_model, prompt=sample_baseball_tweet + '\\n\\n###\\n\\n', max_tokens=1, temperature=0, logprobs=2)\n", + "res['choices'][0]['text']" + ], + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "' baseball'" + ] + }, + "metadata": {}, + "execution_count": 31 + } + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "source": [], + "outputs": [], + "metadata": {} + } + ], + "metadata": { + "orig_nbformat": 4, + "language_info": { + "name": "python", + "version": "3.7.3", + "mimetype": "text/x-python", + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "pygments_lexer": "ipython3", + "nbconvert_exporter": "python", + "file_extension": ".py" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3.7.3 64-bit ('base': conda)" + }, + "interpreter": { + "hash": "3b138a8faad971cc852f62bcf00f59ea0e31721743ea2c5a866ca26adf572e75" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file diff --git a/openai/cli.py b/openai/cli.py index cb8143067f..2b57ab0bd2 100644 --- a/openai/cli.py +++ b/openai/cli.py @@ -322,7 +322,6 @@ def create(cls, args): def get(cls, args): resp = openai.FineTune.retrieve(id=args.id) print(resp) - print(resp["result_files"][0]) @classmethod def results(cls, args): @@ -417,6 +416,7 @@ def prepare_data(cls, args): sys.stdout.write("Analyzing...\n") fname = args.file + auto_accept = args.quiet df, remediation = read_any_format(fname) apply_necessary_remediation(None, remediation) @@ -439,18 +439,32 @@ def prepare_data(cls, args): or remediation.necessary_msg is not None ] ) + any_necessary_applied = any( + [ + remediation + for remediation in optional_remediations + if remediation.necessary_msg is not None + ] + ) + any_optional_applied = False if any_optional_or_necessary_remediations: sys.stdout.write( "\n\nBased on the analysis we will perform the following actions:\n" ) - for remediation in optional_remediations: - df = apply_optional_remediation(df, remediation) + df, optional_applied = apply_optional_remediation( + df, remediation, auto_accept + ) + any_optional_applied = any_optional_applied or optional_applied else: sys.stdout.write("\n\nNo remediations found.\n") - write_out_file(df, fname, any_optional_or_necessary_remediations) + any_optional_or_necessary_applied = ( + any_optional_applied or any_necessary_applied + ) + + write_out_file(df, fname, any_optional_or_necessary_applied, auto_accept) def tools_register(parser): @@ -471,6 +485,13 @@ def help(args): help="JSONL, JSON, CSV, TSV, TXT or XLSX file containing prompt-completion examples to be analyzed." 
"This should be the local file path.", ) + sub.add_argument( + "-q", + "--quiet", + required=False, + action="store_true", + help="Auto accepts all suggestions, without asking for user input. To be used within scripts.", + ) sub.set_defaults(func=FineTune.prepare_data) diff --git a/openai/validators.py b/openai/validators.py index 5abee6d175..2a900cefd5 100644 --- a/openai/validators.py +++ b/openai/validators.py @@ -153,6 +153,36 @@ def optional_fn(x): ) +def long_examples_validator(df): + """ + This validator will suggest to the user to remove examples that are too long. + """ + immediate_msg = None + optional_msg = None + optional_fn = None + + ft_type = infer_task_type(df) + if ft_type != "open-ended generation": + long_examples = df.apply( + lambda x: len(x.prompt) + len(x.completion) > 10000, axis=1 + ) + long_indexes = df.reset_index().index[long_examples].tolist() + + if len(long_indexes) > 0: + immediate_msg = f"\n- There are {len(long_indexes)} examples that are very long. These are rows: {long_indexes}\nFor conditional generation, and for classification the examples shouldn't be longer than 2048 tokens." + optional_msg = f"Remove {len(long_indexes)} long examples" + + def optional_fn(x): + return x.drop(long_indexes) + + return Remediation( + name="long_examples", + immediate_msg=immediate_msg, + optional_msg=optional_msg, + optional_fn=optional_fn, + ) + + def common_prompt_suffix_validator(df): """ This validator will suggest to add a common suffix to the prompt if one doesn't already exist in case of classification or conditional generation. @@ -210,7 +240,7 @@ def add_suffix(x, suffix): immediate_msg += f"\n WARNING: Some of your prompts contain the suffix `{common_suffix}` more than once. We strongly suggest that you review your prompts and add a unique suffix" else: - immediate_msg = "\n- Your data does not contain a common separator at the end of your prompts. Having a separator string appended to the end of the prompt makes it clearer to the fine-tuned model where the completion should begin. See `Fine Tuning How to Guide` for more detail and examples. If you intend to do open-ended generation, then you should leave the prompts empty" + immediate_msg = "\n- Your data does not contain a common separator at the end of your prompts. Having a separator string appended to the end of the prompt makes it clearer to the fine-tuned model where the completion should begin. See https://beta.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples. If you intend to do open-ended generation, then you should leave the prompts empty" if common_suffix == "": optional_msg = ( @@ -361,7 +391,7 @@ def add_suffix(x, suffix): immediate_msg += f"\n WARNING: Some of your completions contain the suffix `{common_suffix}` more than once. We suggest that you review your completions and add a unique ending" else: - immediate_msg = "\n- Your data does not contain a common ending at the end of your completions. Having a common ending string appended to the end of the completion makes it clearer to the fine-tuned model where the completion should end. See `Fine Tuning How to Guide` for more detail and examples." + immediate_msg = "\n- Your data does not contain a common ending at the end of your completions. Having a common ending string appended to the end of the completion makes it clearer to the fine-tuned model where the completion should end. See https://beta.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples." 
if common_suffix == "": optional_msg = ( @@ -396,7 +426,7 @@ def add_space_start(x): immediate_msg = None if df.completion.str[:1].nunique() != 1 or df.completion.values[0][0] != " ": - immediate_msg = "\n- The completion should start with a whitespace character (` `). This tends to produce better results due to the tokenization we use. See `Fine Tuning How to Guide` for more details" + immediate_msg = "\n- The completion should start with a whitespace character (` `). This tends to produce better results due to the tokenization we use. See https://beta.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more details" optional_msg = "Add a whitespace character to the beginning of the completion" optional_fn = add_space_start return Remediation( @@ -430,7 +460,7 @@ def lower_case(x): if count_upper * 2 > count_lower: return Remediation( name="lower_case", - immediate_msg=f"\n- More than a third of your `{column}` column/key is uppercase. Uppercase {column}s tends to perform worse than a mixture of case encountered in normal language. We recommend to lower case the data if that makes sense in your domain. See `Fine Tuning How to Guide` for more details", + immediate_msg=f"\n- More than a third of your `{column}` column/key is uppercase. Uppercase {column}s tends to perform worse than a mixture of case encountered in normal language. We recommend to lower case the data if that makes sense in your domain. See https://beta.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more details", optional_msg=f"Lowercase all your data in column/key `{column}`", optional_fn=lower_case, ) @@ -534,19 +564,81 @@ def apply_necessary_remediation(df, remediation): return df -def apply_optional_remediation(df, remediation): +def accept_suggestion(input_text, auto_accept): + sys.stdout.write(input_text) + if auto_accept: + sys.stdout.write("Y") + return True + return input().lower() != "n" + + +def apply_optional_remediation(df, remediation, auto_accept): """ This function will apply an optional remediation to a dataframe, based on the user input. """ + optional_applied = False + input_text = f"- [Recommended] {remediation.optional_msg} [Y/n]: " if remediation.optional_msg is not None: - if input(f"- [Recommended] {remediation.optional_msg} [Y/n]: ").lower() != "n": + if accept_suggestion(input_text, auto_accept): df = remediation.optional_fn(df) + optional_applied = True if remediation.necessary_msg is not None: sys.stdout.write(f"- [Necessary] {remediation.necessary_msg}\n") - return df + return df, optional_applied + + +def estimate_fine_tuning_time(df): + """ + Estimate the time it'll take to fine-tune the dataset + """ + ft_format = infer_task_type(df) + expected_time = 1.0 + if ft_format == "classification": + num_examples = len(df) + expected_time = num_examples * 1.44 + else: + size = df.memory_usage(index=True).sum() + expected_time = size * 0.0515 + + def format_time(time): + if time < 60: + return f"{round(time, 2)} seconds" + elif time < 3600: + return f"{round(time / 60, 2)} minutes" + elif time < 86400: + return f"{round(time / 3600, 2)} hours" + else: + return f"{round(time / 86400, 2)} days" + + time_string = format_time(expected_time + 140) + sys.stdout.write( + f"Once your model starts training, it'll approximately take {time_string} to train a `curie` model, and less for `ada` and `babbage`. 
Queue will approximately take half an hour per job ahead of you.\n" + ) + + +def get_outfnames(fname, split): + suffixes = ["_train", "_valid"] if split else [""] + i = 0 + while True: + index_suffix = f" ({i})" if i > 0 else "" + candidate_fnames = [ + fname.split(".")[0] + "_prepared" + suffix + index_suffix + ".jsonl" + for suffix in suffixes + ] + if not any(os.path.isfile(f) for f in candidate_fnames): + return candidate_fnames + i += 1 + + +def get_classification_hyperparams(df): + n_classes = df.completion.nunique() + pos_class = None + if n_classes == 2: + pos_class = df.completion.value_counts().index[0] + return n_classes, pos_class -def write_out_file(df, fname, any_remediations): +def write_out_file(df, fname, any_remediations, auto_accept): """ This function will write out a dataframe to a file, if the user would like to proceed, and also offer a fine-tuning command with the newly created file. For classification it will optionally ask the user if they would like to split the data into train/valid files, and modify the suggested command to include the valid set. @@ -556,16 +648,16 @@ def write_out_file(df, fname, any_remediations): common_completion_suffix = get_common_xfix(df.completion, xfix="suffix") split = False + input_text = "- [Recommended] Would you like to split into training and validation set? [Y/n]: " if ft_format == "classification": - if ( - input( - "- [Recommended] Would you like to split into training and validation set? [Y/n]: " - ) - != "n" - ): + if accept_suggestion(input_text, auto_accept): split = True - packing_param = " --no_packing" if ft_format == "classification" else "" + classification_params = "" + if ft_format == "classification" or ( + ft_format == "conditional generation" and len(df) < 1000 + ): + classification_params = " --no_packing" common_prompt_suffix_new_line_handled = common_prompt_suffix.replace("\n", "\\n") common_completion_suffix_new_line_handled = common_completion_suffix.replace( "\n", "\\n" @@ -576,67 +668,55 @@ def write_out_file(df, fname, any_remediations): else "" ) + input_text = "\n\nYour data will be written to a new JSONL file. 
Proceed [Y/n]: " + if not any_remediations: sys.stdout.write( - f'\nYou can use your file for fine-tuning:\n> openai api fine_tunes.create -t "{fname}"{packing_param}\n\nAfter you’ve fine-tuned a model, remember that your prompt has to end with the indicator string `{common_prompt_suffix_new_line_handled}` for the model to start generating completions, rather than continuing with the prompt.{optional_ending_string}\n' + f'\nYou can use your file for fine-tuning:\n> openai api fine_tunes.create -t "{fname}"{classification_params}\n\nAfter you’ve fine-tuned a model, remember that your prompt has to end with the indicator string `{common_prompt_suffix_new_line_handled}` for the model to start generating completions, rather than continuing with the prompt.{optional_ending_string}\n' ) + estimate_fine_tuning_time(df) + + elif accept_suggestion(input_text, auto_accept): + fnames = get_outfnames(fname, split) + if split: + assert len(fnames) == 2 and "train" in fnames[0] and "valid" in fnames[1] + MAX_VALID_EXAMPLES = 1000 + n_train = max(len(df) - MAX_VALID_EXAMPLES, int(len(df) * 0.8)) + df_train = df.sample(n=n_train, random_state=42) + df_valid = df.drop(df_train.index) + df_train[["prompt", "completion"]].to_json( + fnames[0], lines=True, orient="records", force_ascii=False + ) + df_valid[["prompt", "completion"]].to_json( + fnames[1], lines=True, orient="records", force_ascii=False + ) - elif ( - input( - "\n\nYour data will be written to a new JSONL file. Proceed [Y/n]: " - ).lower() - != "n" - ): - - suffixes = ["_train", "_valid"] if split else [""] - outfnames = [] - indices = None - for suffix in suffixes: - out_fname = fname.split(".")[0] + "_prepared" + suffix + ".jsonl" - - # check if file already exists, and if it does, add a number to the end - i = 0 - while True: - to_continue = False - # in case of train and test, make sure that the numbers will match - for suf in suffixes: - out_fname = ( - fname.split(".")[0] + "_prepared" + suf + f" ({i})" + ".jsonl" - ) - if i == 0: - out_fname = fname.split(".")[0] + "_prepared" + suf + ".jsonl" - i += 1 - if os.path.isfile(out_fname): - to_continue = True - if to_continue: - continue - break - - outfnames.append(out_fname) - if suffix == "_train": - MAX_VALID_EXAMPLES = 1000 - n = max(len(df) - MAX_VALID_EXAMPLES, int(len(df) * 0.8)) - df_out = df.sample(n=n, random_state=42) - indices = df_out.index - if suffix == "_valid": - df_out = df.drop(indices) - if suffix == "": - df_out = df - df_out[["prompt", "completion"]].to_json( - out_fname, lines=True, orient="records" + n_classes, pos_class = get_classification_hyperparams(df) + classification_params += " --compute_classification_metrics" + if n_classes == 2: + classification_params += ( + f' --classification_positive_class "{pos_class}"' + ) + else: + classification_params += f" --classification_n_classes {n_classes}" + else: + assert len(fnames) == 1 + df[["prompt", "completion"]].to_json( + fnames[0], lines=True, orient="records", force_ascii=False ) # Add -v VALID_FILE if we split the file into train / valid - files_string = ("s" if split else "") + " to `" + ("` and `".join(outfnames)) - valid_string = f' -v "{outfnames[1]}"' if split else "" + files_string = ("s" if split else "") + " to `" + ("` and `".join(fnames)) + valid_string = f' -v "{fnames[1]}"' if split else "" separator_reminder = ( "" if len(common_prompt_suffix_new_line_handled) == 0 else f"After you’ve fine-tuned a model, remember that your prompt has to end with the indicator string 
`{common_prompt_suffix_new_line_handled}` for the model to start generating completions, rather than continuing with the prompt." ) sys.stdout.write( - f'\nWrote modified file{files_string}`\nFeel free to take a look!\n\nNow use that file when fine-tuning:\n> openai api fine_tunes.create -t "{outfnames[0]}"{valid_string}{packing_param}\n\n{separator_reminder}{optional_ending_string}\n' + f'\nWrote modified file{files_string}`\nFeel free to take a look!\n\nNow use that file when fine-tuning:\n> openai api fine_tunes.create -t "{fnames[0]}"{valid_string}{classification_params}\n\n{separator_reminder}{optional_ending_string}\n' ) + estimate_fine_tuning_time(df) else: sys.stdout.write("Aborting... did not write the file\n") @@ -688,6 +768,7 @@ def get_validators(): non_empty_completion_validator, format_inferrer_validator, duplicated_rows_validator, + long_examples_validator, lambda x: lower_case_validator(x, "prompt"), lambda x: lower_case_validator(x, "completion"), common_prompt_suffix_validator, diff --git a/openai/version.py b/openai/version.py index 72662860aa..96dd69e4f3 100644 --- a/openai/version.py +++ b/openai/version.py @@ -1 +1 @@ -VERSION = "0.10.1" +VERSION = "0.10.2"
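Taken together, these changes make `prepare_data` usable non-interactively. A minimal sketch of the scripted flow (assuming a local `sport1.jsonl` as in the example notebook; `-q` auto-accepts every suggestion, and the second command uses the classification flags the tool prints for a two-class dataset):

    # Prepare the dataset non-interactively; with -q this writes
    # sport1_prepared_train.jsonl and sport1_prepared_valid.jsonl
    # without prompting for input.
    openai tools fine_tunes.prepare_data -f sport1.jsonl -q

    # Fine-tune on the prepared split with classification metrics enabled.
    openai api fine_tunes.create -t "sport1_prepared_train.jsonl" \
        -v "sport1_prepared_valid.jsonl" --no_packing -m ada \
        --compute_classification_metrics --classification_positive_class " hockey"

Note that with exactly two classes the tool suggests `--classification_positive_class`; with more classes it suggests `--classification_n_classes` instead, per `get_classification_hyperparams` above.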