
Updated Zero Shot notebook #394

Merged
merged 1 commit into from May 31, 2024
261 changes: 261 additions & 0 deletions sdk_blueprints/Gretel_Prompting_Llama_2_at_Scale.ipynb
@@ -0,0 +1,261 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/zredlined/3091bc76da6a43654e43abe5c46a799a/batch-prompting-llama-2-with-gretel-gpt.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"\n",
"<br>\n",
"\n",
"<center><a href=https://gretel.ai/><img src=\"https://assets-global.website-files.com/5ec4696a9b6d337d51632638/651c817db814b3623bc572e2_DES-380%20Prompting%20Llama%202.jpg\" alt=\"Gretel\" width=\"350\"/></a></center>\n",
"\n",
"<br>\n",
"\n",
"## Prompting Llama-2 at Scale with Gretel\n",
"\n",
"Discover how to efficiently use Gretel's platform for prompting Llama-2 on large datasets, whether you're completing answers, generating synthetic text, or labeling.\n",
"\n",
"In this Blueprint, we will leverage Gretel's platform to **prompt** a Llama-2-7B parameter large language model (LLM) in batch mode. If you already worked through the tabular Blueprints, you will notice that Gretel's SDK interface is the same for all data modalities, so you only need to learn it once!\n",
"\n",
"As with the previous Blueprints, we will submit training and generation jobs to the Gretel Cloud, which will spin up the compute resources required for fine tuning and prompting an LLM.\n",
"\n",
"## In the right place?\n",
"\n",
"If you are new to Gretel, we recommend starting with these [SDK Blueprints](https://github.com/gretelai/gretel-blueprints/tree/main/sdk_blueprints):\n",
"\n",
"1. [Gretel 101 Blueprint](https://colab.research.google.com/github/gretelai/gretel-blueprints/blob/main/sdk_blueprints/Gretel_101_Blueprint.ipynb).\n",
"\n",
"2. [Gretel Advanced Tabular Blueprint](https://colab.research.google.com/github/gretelai/gretel-blueprints/blob/main/sdk_blueprints/Gretel_Advanced_Tabular_Blueprint.ipynb)\n",
"\n",
"**Note:** You will need a [free Gretel account](https://console.gretel.ai/) to run this notebook.\n",
"\n",
"<br>\n",
"\n",
"#### Ready? Let's go 🚀"
],
"metadata": {
"id": "0USw9vGnQdSb"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "94SQRbhuQKtq"
},
"outputs": [],
"source": [
"!pip install -Uqq gretel-client datasets"
]
},
{
"cell_type": "markdown",
"source": [
"## 🛜 Configure your Gretel session\n",
"\n",
"- Each `Gretel` instance is bound to a single [Gretel project](https://docs.gretel.ai/guides/gretel-fundamentals/projects). \n",
"\n",
"- You can set the project name at instantiation, or you can use the `set_project` method.\n",
"\n",
"- If you do not set the project, a random project will be created with your first job submission.\n",
"\n",
"\n",
"- You can retrieve your API key [here](https://console.gretel.ai/users/me/key)."
],
"metadata": {
"id": "QgfwgE1xllzp"
}
},
{
"cell_type": "code",
"source": [
"from gretel_client import Gretel\n",
"\n",
"gretel = Gretel(project_name=\"text-gen\", api_key=\"prompt\", validate=True)"
],
"metadata": {
"id": "NVGC4XqR6Nbn"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# @title 🗂️ Set the dataset path\n",
"import pandas as pd\n",
"\n",
"\n",
"# Configuration for displaying pandas dataframe\n",
"pd.set_option('display.max_colwidth', 500)\n",
"\n",
"# Constants\n",
"MODEL = 'meta-llama/Llama-2-7b-chat-hf' # @param {type:\"string\"}\n",
"NUM_RECORDS = 250 # @param {type:\"integer\"}\n",
"DATASET = 'gsm8k' # @param [\"gsm8k\"]"
],
"metadata": {
"id": "DAbxjJmCVMYN"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"We'll use the [GSM8K](https://paperswithcode.com/dataset/gsm8k) (Grade School Math) dataset of 8k real questions to query the LLM. These are problems that take between 2 and 8 steps to solve, primarily requiring a sequence of elementary arithmetic calculations to reach the final answer, making GSM8K a good benchmark for assessing an LLM's ability to solve everyday, real-world multi-step problems."
],
"metadata": {
"id": "aoeXayotV50Q"
}
},
{
"cell_type": "markdown",
"source": [
"## 📊 Prepare your prompt data\n",
"\n",
"- Load the GSM8k dataset from Huggingface Datasets.\n",
"\n",
"- Use a prompt template to load the questions from GSM8k into the expected Llama2 prompt format."
],
"metadata": {
"id": "AYSAYEPW8Apz"
}
},
{
"cell_type": "code",
"source": [
"import textwrap\n",
"\n",
"from datasets import load_dataset\n",
"\n",
"# Function to format prompt according to Llama2 expected chat instruction format\n",
"def format_prompt(prompt: str) -> str:\n",
" llama_template = textwrap.dedent(f\"\"\"\\\n",
" <s>[INST] <<SYS>>You provide just the answer you are asked for with no preamble. Do not repeat the question. Be succinct.<</SYS>>\n",
"\n",
" {prompt} [/INST]\n",
" \"\"\")\n",
"\n",
" return llama_template\n",
"\n",
"# Load dataset with the 'main' configuration and get the first NUM_RECORDS questions\n",
"dataset = load_dataset('gsm8k', 'main')\n",
"questions = dataset['train']['question'][:NUM_RECORDS]\n",
"\n",
"# Add the Llama2 instruction format to each prompt\n",
"formatted_prompts = [format_prompt(q) for q in questions]\n",
"\n",
"# Convert the instructions to a dataframe format\n",
"questions = pd.DataFrame(data={'text': questions})\n",
"instructions = pd.DataFrame(data={'text': formatted_prompts})\n",
"\n",
"# Print a random sample question and formatted instruction\n",
"random_idx = instructions.sample(n=1).index[0]\n",
"print(f\"Random question:\\n```{questions.loc[random_idx]['text']}```\\n\\n\")\n",
"print(f\"Instruction:\\n```{instructions.loc[random_idx]['text']}```\")"
],
"metadata": {
"id": "zZ0pCqaM6H1c"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## 🎛️ Initialize the LLM\n",
"\n",
"- We use the natural language [base config](https://github.com/gretelai/gretel-blueprints/blob/main/config_templates/gretel/synthetics/natural-language.yml) by setting `base_config=\"natural-language\"`.\n",
"\n",
"- As the goal is zero/few-shot prompting (i.e., no fine tuning), set `data_source=None`."
],
"metadata": {
"id": "Pj81Gdwvl5HS"
}
},
{
"cell_type": "code",
"source": [
"pretrained = gretel.submit_train(\n",
" base_config=\"natural-language\",\n",
" pretrained_model=MODEL,\n",
" data_source=None\n",
" )"
],
"metadata": {
"id": "Kuzp63XMV4vR"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## 🤖 Generate synthetic data\n",
"\n",
"- You can pass any of Gretel GPT's [`generate` parameters](https://docs.gretel.ai/reference/synthetics/models/gretel-gpt#data-generation) as keyword arguments in the `submit_generate` method.\n",
"\n"
],
"metadata": {
"id": "3lXo8BEA6gXz"
}
},
{
"cell_type": "code",
"source": [
"generated = gretel.submit_generate(\n",
" pretrained.model_id,\n",
" maximum_text_length=250,\n",
" seed_data=instructions,\n",
" temperature=0.8\n",
")"
],
"metadata": {
"id": "Azevt2CFnsPj"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"results = generated.synthetic_data\n",
"\n",
"# Combine original questions and results into a single DataFrame\n",
"df = pd.DataFrame({'question': questions['text'], 'answer': results['text']})\n",
"df.to_csv('questions_and_answers.csv', index=False)\n",
"df"
],
"metadata": {
"id": "JGO3shf8oYZr"
},
"execution_count": null,
"outputs": []
}
]
}
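For quick reference outside the notebook JSON above, the prompt-formatting cell can be exercised on its own. This is a minimal standalone sketch of the notebook's `format_prompt` helper (same Llama-2 chat template; the demo question under `__main__` is illustrative only):

```python
import textwrap


def format_prompt(prompt: str) -> str:
    """Wrap a question in the Llama-2 chat instruction format.

    The <<SYS>> system prompt asks the model for succinct, answer-only
    replies; the user question is placed between [INST] and [/INST].
    """
    llama_template = textwrap.dedent(f"""\
        <s>[INST] <<SYS>>You provide just the answer you are asked for with no preamble. Do not repeat the question. Be succinct.<</SYS>>

        {prompt} [/INST]
        """)
    return llama_template


if __name__ == "__main__":
    # Illustrative question, not from the GSM8K dataset
    print(format_prompt("What is 6 * 7?"))
```

Note that the f-string is interpolated before `textwrap.dedent` runs, so single-line questions are dedented cleanly along with the template.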
2 changes: 1 addition & 1 deletion use_cases/gretel.json
@@ -108,7 +108,7 @@
},
"button2": {
"label": "Zero-Shot Prompting Notebook",
"link": "https://gist.github.com/zredlined/3091bc76da6a43654e43abe5c46a799a"
"link": "https://colab.research.google.com/github/gretelai/gretel-blueprints/blob/main/sdk_blueprints/Gretel_Prompting_Llama_2_at_Scale.ipynb"
}
},
{