Updated Zero Shot notebook (#394)
arronhunt committed May 31, 2024
1 parent 0ce029a commit d75d838
Showing 2 changed files with 262 additions and 1 deletion.
261 changes: 261 additions & 0 deletions sdk_blueprints/Gretel_Prompting_Llama_2_at_Scale.ipynb
@@ -0,0 +1,261 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/zredlined/3091bc76da6a43654e43abe5c46a799a/batch-prompting-llama-2-with-gretel-gpt.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"\n",
"<br>\n",
"\n",
"<center><a href=https://gretel.ai/><img src=\"https://assets-global.website-files.com/5ec4696a9b6d337d51632638/651c817db814b3623bc572e2_DES-380%20Prompting%20Llama%202.jpg\" alt=\"Gretel\" width=\"350\"/></a></center>\n",
"\n",
"<br>\n",
"\n",
"##Prompting Llama-2 at Scale with Gretel\n",
"\n",
"Discover how to efficiently use Gretel's platform for prompting Llama-2 on large datasets, whether you're completing answers, generating synthetic text, or labeling.\n",
"\n",
"In this Blueprint, we will leverage Gretel's platform to **prompt** a Llama-2-7B parameter large language model (LLM) in batch mode. If you already worked through the tabular Blueprints, you will notice that Gretel's SDK interface is the same for all data modalities, so you only need to learn it once!\n",
"\n",
"As with the previous Blueprints, we will submit training and generation jobs to the Gretel Cloud, which will spin up the compute resources required for fine tuning and prompting an LLM.\n",
"\n",
"## In the right place?\n",
"\n",
"If you are new to Gretel, we recommend starting with these [SDK Blueprints](https://github.com/gretelai/gretel-blueprints/tree/main/sdk_blueprints):\n",
"\n",
"1. [Gretel 101 Blueprint](https://colab.research.google.com/github/gretelai/gretel-blueprints/blob/main/sdk_blueprints/Gretel_101_Blueprint.ipynb).\n",
"\n",
"2. [Gretel Advanced Tabular Blueprint](https://colab.research.google.com/github/gretelai/gretel-blueprints/blob/main/sdk_blueprints/Gretel_Advanced_Tabular_Blueprint.ipynb)\n",
"\n",
"**Note:** You will need a [free Gretel account](https://console.gretel.ai/) to run this notebook.\n",
"\n",
"<br>\n",
"\n",
"#### Ready? Let's go 🚀"
],
"metadata": {
"id": "0USw9vGnQdSb"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "94SQRbhuQKtq"
},
"outputs": [],
"source": [
"!pip install -Uqq gretel-client datasets"
]
},
{
"cell_type": "markdown",
"source": [
"## 🛜 Configure your Gretel session\n",
"\n",
"- Each `Gretel` instance is bound to a single [Gretel project](https://docs.gretel.ai/guides/gretel-fundamentals/projects). \n",
"\n",
"- You can set the project name at instantiation, or you can use the `set_project` method.\n",
"\n",
"- If you do not set the project, a random project will be created with your first job submission.\n",
"\n",
"\n",
"- You can retrieve your API key [here](https://console.gretel.ai/users/me/key)."
],
"metadata": {
"id": "QgfwgE1xllzp"
}
},
{
"cell_type": "code",
"source": [
"from gretel_client import Gretel\n",
"\n",
"gretel = Gretel(project_name=\"text-gen\", api_key=\"prompt\", validate=True)"
],
"metadata": {
"id": "NVGC4XqR6Nbn"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# @title 🗂️ Set the dataset path\n",
"import pandas as pd\n",
"\n",
"\n",
"# Configuration for displaying pandas dataframe\n",
"pd.set_option('display.max_colwidth', 500)\n",
"\n",
"# Constants\n",
"MODEL = 'meta-llama/Llama-2-7b-chat-hf' # @param {type:\"string\"}\n",
"NUM_RECORDS = 250 # @param {type:\"integer\"}\n",
"DATASET = 'gsm8k' # @param [\"gsm8k\"]"
],
"metadata": {
"id": "DAbxjJmCVMYN"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"We'll use the [GSM8K](https://paperswithcode.com/dataset/gsm8k) (Grade School Math) dataset of 8k real questions to query the LLM. These are problems that take between 2 and 8 steps to solve, primarily performing a sequence of elementary calculations using arithmetic to reach the final answer, creating a good benchmark for assessing LLMs ability to solve everyday real-world multi-step problems."
],
"metadata": {
"id": "aoeXayotV50Q"
}
},
{
"cell_type": "markdown",
"source": [
"## 📊 Prepare your prompt data\n",
"\n",
"- Load the GSM8k dataset from Huggingface Datasets.\n",
"\n",
"- Use a prompt template to load the questions from GSM8k into the expected Llama2 prompt format."
],
"metadata": {
"id": "AYSAYEPW8Apz"
}
},
{
"cell_type": "code",
"source": [
"import textwrap\n",
"\n",
"from datasets import load_dataset\n",
"\n",
"# Function to format prompt according to Llama2 expected chat instruction format\n",
"def format_prompt(prompt: str) -> str:\n",
" llama_template = textwrap.dedent(f\"\"\"\\\n",
" <s>[INST] <<SYS>>You provide just the answer you are asked for with no preamble. Do not repeat the question. Be succinct.<</SYS>>\n",
"\n",
" {prompt} [/INST]\n",
" \"\"\")\n",
"\n",
" return llama_template\n",
"\n",
"# Load dataset with the 'main' configuration and get the first NUM_RECORDS questions\n",
"dataset = load_dataset('gsm8k', 'main')\n",
"questions = dataset['train']['question'][:NUM_RECORDS]\n",
"\n",
"# Add the Llama2 instruction format to each prompt\n",
"formatted_prompts = [format_prompt(q) for q in questions]\n",
"\n",
"# Convert the instructions to a dataframe format\n",
"questions = pd.DataFrame(data={'text': questions})\n",
"instructions = pd.DataFrame(data={'text': formatted_prompts})\n",
"\n",
"# Print a random sample question and formatted instruction\n",
"random_idx = instructions.sample(n=1).index[0]\n",
"print(f\"Random question:\\n```{questions.loc[random_idx]['text']}```\\n\\n\")\n",
"print(f\"Instruction:\\n```{instructions.loc[random_idx]['text']}```\")"
],
"metadata": {
"id": "zZ0pCqaM6H1c"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## 🎛️ Initialize the LLM\n",
"\n",
"- We use the natural language [base config](https://github.com/gretelai/gretel-blueprints/blob/main/config_templates/gretel/synthetics/natural-language.yml) by setting `base_config=\"natural-language\"`.\n",
"\n",
"- As the goal is zero/few-shot prompting (i.e., no fine tuning), set `data_source=None`."
],
"metadata": {
"id": "Pj81Gdwvl5HS"
}
},
{
"cell_type": "code",
"source": [
"pretrained = gretel.submit_train(\n",
" base_config=\"natural-language\",\n",
" pretrained_model=MODEL,\n",
" data_source=None\n",
" )"
],
"metadata": {
"id": "Kuzp63XMV4vR"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## 🤖 Generate synthetic data\n",
"\n",
"- You can pass any of Gretel GPT's [`generate` parameters](https://docs.gretel.ai/reference/synthetics/models/gretel-gpt#data-generation) as keyword arguments in the `submit_generate` method.\n",
"\n"
],
"metadata": {
"id": "3lXo8BEA6gXz"
}
},
{
"cell_type": "code",
"source": [
"generated = gretel.submit_generate(\n",
" pretrained.model_id,\n",
" maximum_text_length=250,\n",
" seed_data=instructions,\n",
" temperature=0.8\n",
")"
],
"metadata": {
"id": "Azevt2CFnsPj"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"results = generated.synthetic_data\n",
"\n",
"# Combine original questions and results into a single DataFrame\n",
"df = pd.DataFrame({'question': questions['text'], 'answer': results['text']})\n",
"df.to_csv('questions_and_answers.csv', index=False)\n",
"df"
],
"metadata": {
"id": "JGO3shf8oYZr"
},
"execution_count": null,
"outputs": []
}
]
}
2 changes: 1 addition & 1 deletion use_cases/gretel.json
@@ -124,7 +124,7 @@
},
"button2": {
"label": "Zero-Shot Prompting Notebook",
"link": "https://gist.github.com/zredlined/3091bc76da6a43654e43abe5c46a799a"
"link": "https://colab.research.google.com/github/gretelai/gretel-blueprints/blob/main/sdk_blueprints/Gretel_Prompting_Llama_2_at_Scale.ipynb"
}
},
{
