# Databricks notebook source
# MAGIC %md
# MAGIC # Generating questions
# MAGIC
# MAGIC In this notebook we are going to generate questions which we will use during the training phase. We are going to use [Meta-Llama-3-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct) to generate them, via the Databricks Foundation Model APIs.
# COMMAND ----------
# MAGIC %pip install langchain langchain-community
# MAGIC dbutils.library.restartPython()
# COMMAND ----------
# MAGIC %md
# MAGIC Let's specify the topic list we will use during the question generation
# COMMAND ----------
import re
import json
import random
from typing import Union, List
from langchain_community.chat_models.databricks import ChatDatabricks
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
# Topics used to steer question generation; a few are randomly sampled and
# joined into each prompt sent to the model.
# FIX: corrected the misspelled topic "Fresh Ingredience" -> "Fresh Ingredients"
# (the raw string is sent verbatim to the LLM).
topic_list = ["Nutritious", "Plant-Based", "Meal Planning", "Cooking Techniques", "Vegetarianism",
              "Global Dishes", "Seasonal Recipes", "Kids' Meals", "Vegan", "Environmental Impact",
              "Diet Myths", "Special Diets", "Dining Out", "Athlete Nutrition", "Homemade Snacks",
              "Budget-Friendly", "Wine Pairing", "Different Cultures", "Bodybuilding", "Holiday Recipes",
              "Exotic Cuisine", "High Calorie", "Healthy Food", "Low Cost", "Fresh Ingredients",
              "Mediterranean", "Indian", "Asian", "African", "South American",
              "Popular", "Fine Dining", "Table Manner", "Michelin Star", "French",
              "Bread", "Noodles", "Healthy", "Unhealthy", "Substantial",
              "Culinary Diversity", "Innovative Dish", "Fusion", "Seasonal", "Tasting Menu",
              "Herbs", "Homestyle", "Organic", "Locally Sourced", "Farm-to-Table",
              "Heirloom", "Spicy", "Authentic Flavors", "Traditional Recipes", "Mouthwatering"]
# COMMAND ----------
# MAGIC %md Now we can specify the prompt template and the LangChain chain which we will use to call Databricks Foundational Models API
# COMMAND ----------
# Llama 3 chat-format prompt: the system turn constrains the model to emit
# exactly one question wrapped in a JSON object; the user turn carries the
# comma-separated topic string.
prompt_template_str = """
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are an AI assistant that specializes in food.
Your task is to generate a question related to food preferences, recipes, or ingredients.
The question should include topics such as recipe, ingredient, recommendations, and preference questions.
Generate 1 question based on the topics provided in the instructions. Do not generate more than 1 question.
Below is an example of a question.
Always format the output in JSON format as follows:
```json
{{
"question": "What are some ingredients for a quick dinner preparation?"
}}
```
<|eot_id|><|start_header_id|>user<|end_header_id|>
topic: {topic} <|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
# BUG FIX: the template's only placeholder is {topic}, so the declared input
# variable must be "topic". It was "question", which made LangChain map a
# bare-string invoke() input to {"question": ...} and raise KeyError: 'topic'
# when formatting the template.
prompt = PromptTemplate(template=prompt_template_str, input_variables=["topic"])
llm = ChatDatabricks(endpoint="databricks-meta-llama-3-70b-instruct", temperature=0.1)
# Retry aggressively: endpoint throttling / transient failures are expected
# during long generation runs.
chain = (prompt | llm | StrOutputParser()).with_retry(
    stop_after_attempt=100, wait_exponential_jitter=False
)
# Smoke-test the chain with two random topics.
chain.invoke(", ".join(random.sample(topic_list, 2)))
# COMMAND ----------
# MAGIC %md Since the model can generate arbitrary text around the JSON, we implement helper functions below that strip the non-JSON part of the response
# COMMAND ----------
def parse(s: str) -> Union[dict, None]:
    """
    Parse a raw model response into a question dict.

    The model may wrap the JSON object in extra text, so the JSON portion is
    first isolated with extract_json_array.

    :param s: raw model response text
    :return: the parsed JSON object (e.g. {"question": ...}), or None when the
        response contains no valid, non-empty JSON object
    """
    try:
        parsed = json.loads(extract_json_array(s))
        # An empty object ({}) is falsy — treat it as a failed generation too.
        return parsed if parsed else None
    except (json.JSONDecodeError, TypeError):
        # Malformed responses are expected occasionally; the caller simply
        # skips them and re-invokes the chain.
        return None
def extract_json_array(s: str) -> str:
    """
    Return the JSON-object portion of ``s``, stripping surrounding text.

    Greedily matches from the first '{' to the last '}' (across newlines,
    via re.DOTALL). If no braces are present, ``s`` is returned unchanged.

    :param s: string possibly containing a JSON object
    :return: the brace-delimited substring, or the original string
    """
    match = re.search(r"\{.*}", s, re.DOTALL)
    return match.group() if match else s
# COMMAND ----------
# MAGIC %md Now let's generate 100 questions
# COMMAND ----------
# Collect 100 successfully parsed questions for the holdout set; each attempt
# samples two random topics, and unparseable responses are simply skipped.
questions = []
while len(questions) < 100:
    sampled_topics = ", ".join(random.sample(topic_list, 2))
    if (parsed := parse(chain.invoke(sampled_topics))):
        questions.append(parsed)
# COMMAND ----------
import pandas as pd
# Build a pandas DataFrame from the parsed question dicts. The rename covers
# the case where rows are bare values (column 0) rather than dicts; when rows
# are {"question": ...} dicts it is a no-op.
df = pd.DataFrame(questions).rename(columns={0:"question"})
# `spark` and `display` are Databricks notebook globals.
df = spark.createDataFrame(df)
df.write.saveAsTable("rlaif.data.prompts_holdout")
display(df)
# COMMAND ----------
# MAGIC %md
# MAGIC ## Let's use Llama 3 70B hosted on Model Serving for prompt generation
# MAGIC
# MAGIC Now we can use LangChain to do a batch inference in parallel. We can specify the number of parallel requests using max_concurrency parameter.
# COMMAND ----------
# Batch-generate the training prompts: fan out `concurrency` requests per
# iteration (each with three random topics) until 10,000 parsed questions
# have been collected; failed parses are filtered out.
questions = []
concurrency = 4
while len(questions) < 10000:
    topics = [", ".join(random.sample(topic_list, 3)) for _ in range(concurrency)]
    raw_responses = chain.batch(topics, config={"max_concurrency": concurrency})
    questions.extend(parsed for parsed in map(parse, raw_responses) if parsed)
# COMMAND ----------
import pandas as pd
# Rename the "question" column to "prompt" to match the training data schema.
df = pd.DataFrame(questions).rename(columns={"question":"prompt"})
# `display` is a Databricks notebook global.
display(df)
# COMMAND ----------
# MAGIC %md Now we can store the results so that we can use them for training later
# COMMAND ----------
# Persist the prompts twice: as a CSV on DBFS (for training scripts that read
# files directly) and as a table, overwritten on each run.
df.to_csv("/dbfs/tmp/rlaif/data/prompts.csv", index=False)
spark.createDataFrame(df).write.mode("overwrite").saveAsTable("rlaif.data.prompts")
# COMMAND ----------