Skip to content

Commit

Permalink
optimize prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
matrixstone committed Jan 27, 2024
1 parent e98291d commit 28c9546
Showing 1 changed file with 125 additions and 35 deletions.
160 changes: 125 additions & 35 deletions mage_ai/ai/hugging_face_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import json
import os
import re
Expand All @@ -19,31 +20,61 @@
"Content-Type": "application/json"
}

PROMPT_FOR_FUNCTION_PARAMS = """
Based on the code description, answer each question.
`code description: {code_description}`
Question BlockType: What is the purpose of the code? Choose one result from:
['data_loa der', 'data_exporter', 'transformer'].
If the code wants to read data from a data source, it is "data_loader".
If it wants to export data into a data source, it is "data_exporter".
For the rest manipulation data actions, it is "transformer".
Question BlockLanguage: What is the intended programming language for this code?
The default value is Python, but choose one result from: {block_languages}.
PROMPT_FOR_BLOCK_TYPE = """
<s>[INST]Classify code desription into one of the following categories:
"data_loader" - The code is designed to read data from a data source.
"data_exporter" - The code is intended to export data into a data source.
"transformer" - The code performs various data manipulation actions.[/INST]</s>
<s>[INST]classify code description: Read data from MySQL database[/INST]type: data_loader</s>
<s>[INST]classify code description: export data to Postgres database[/INST]type: data_exporter</s>
<s>[INST]classify code description: filter all the records by age smaller \
than 100[/INST]\type: transformer</s>
[INST]classify code description: {code_description}[/INST]
"""

Question PipelineType: What is the pipeline type? The default value is Python,
but choose one result from: {pipeline_types}.
PROMPT_FOR_BLOCK_LANGUAGE = """
<s>[INST]what is the programming language mentioned in the code description?
The default value is python, choose one result from: {block_languages}[/INST]</s>
<s>[INST]code description: Read data from MySQL database [/INST]programming language: python</s>
<s>[INST]code description: export data to Postgres database in SQL \
[/INST]programming language: sql</s>
<s>[INST]code description: write a yaml config to filter all the records by age smaller than 100\
[/INST]programming language: yaml</s>
<s>[INST]code description: filter all the records by age [/INST]programming language: python</s>
[INST]code description: {code_description}[/INST]
"""

Question ActionType: If BlockType is transformer, what is the action this code tries to perform?
Choose one result from: {action_types}
PROMPT_FOR_PIPELINE_TYPES = """
<s>[INST]what is the pipeline type in the code description?
The default value is python, but choose one result from \
following pipeline type: {pipeline_types}.[/INST]</s>
<s>[INST]code description: Read data from MySQL database[/INST]type: python</s>
<s>[INST]code description: export data to Postgres database using a integration pipeline\
[/INST]type: integration</s>
<s>[INST]code description: stream filter all the records[/INST]type: streaming</s>
<s>[INST]code description: using pyspark pipeline to dedup same records[/INST]type: pyspark</s>
[INST]code description: {code_description}[/INST]
"""

Question DataSource: If BlockType is data_loader or data_exporter, where the data loads from or
export to? Choose one result from: {data_sources}
PROMPT_FOR_ACTION_TYPE = """
<s>[INST]what is the action this code description tries to perform?
Choose one result from following action: {action_types}.[/INST]</s>
<s>[INST]code description: average all the score [/INST]action: average</s>
<s>[INST]code description: count number of distinct value [/INST]action: count_distinct</s>
<s>[INST]code description: return the first record [/INST]action: first</s>
<s>[INST]code description: return 5 records [/INST]action: limit</s>
<s>[INST]code description: return the biggest number of score [/INST]action: max</s>
[INST]code description: {code_description}[/INST]
"""

Return your responses in JSON format with the question name as the key and
the answer as the value.
PROMPT_FOR_DATA_SOURCE = """
<s>[INST]where the data source loads from or export to in the code description?
Choose one result from following data sources: {data_sources}.[/INST]</s>
<s>[INST]code description: export data to mysql database[/INST]source: mysql</s>
<s>[INST]code description: read data from sqlserver[/INST]source: sqlserver</s>
<s>[INST]code description: fetch data from opensearch[/INST]source: opensearch</s>
<s>[INST]code description: wrtie all records to google sheets[/INST]source: google_sheets</s>
[INST]code description: {code_description}[/INST]
"""


Expand All @@ -63,37 +94,37 @@ def __init__(self, hf_config: HuggingFaceConfig):

def __parse_function_args(self, function_args: Dict):
try:
block_type = BlockType(function_args[f'Question {BlockType.__name__}'])
block_type = BlockType(function_args[f'{BlockType.__name__}'])
except ValueError:
raise Exception(f'Error not valid BlockType: \
{function_args.get(f"Question {BlockType.__name__}")}')
{function_args.get(f"{BlockType.__name__}")}')
try:
block_language = BlockLanguage(
function_args.get(
f'Question {BlockLanguage.__name__}',
f'{BlockLanguage.__name__}',
'python'))
except ValueError:
print(f'Error not valid BlockLanguage: \
{function_args.get(f"Question {BlockLanguage.__name__}")}')
{function_args.get(f"{BlockLanguage.__name__}")}')
block_language = BlockLanguage.PYTHON
try:
pipeline_type = PipelineType(
function_args.get(
f'Question {PipelineType.__name__}',
f'{PipelineType.__name__}',
'python'))
except ValueError:
print(f'Error not valid PipelineType: \
{function_args.get(f"Question {PipelineType.__name__}")}')
{function_args.get(f"{PipelineType.__name__}")}')
pipeline_type = PipelineType.PYTHON
config = {}
if block_type == BlockType.TRANSFORMER:
try:
config['action_type'] = ActionType(
function_args.get(
f'Question {ActionType.__name__}'))
f'{ActionType.__name__}'))
except ValueError:
print(f'Error not valid ActionType: \
{function_args.get(f"Question {ActionType.__name__}")}')
{function_args.get(f"{ActionType.__name__}")}')
config['action_type'] = None
if config['action_type']:
if config['action_type'] in [
Expand All @@ -109,10 +140,10 @@ def __parse_function_args(self, function_args: Dict):
try:
config['data_source'] = DataSource(
function_args.get(
f'Question {DataSource.__name__}'))
f'{DataSource.__name__}'))
except ValueError:
print(f'Error not valid DataSource: \
{function_args.get(f"Question {DataSource.__name__}")}')
{function_args.get(f"{DataSource.__name__}")}')
output = {}
output['block_type'] = block_type
output['block_language'] = block_language
Expand All @@ -124,14 +155,16 @@ async def inference_with_prompt(
self,
variable_values: Dict[str, str],
prompt_template: str,
is_json_response: bool = True
is_json_response: bool = True,
max_new_tokens: int = 800
):
formated_prompt = prompt_template.format(**variable_values)
data = json.dumps({
'inputs': formated_prompt,
'parameters': {
'temperature': 0.01,
'return_full_text': False,
'max_new_tokens': 800,
'max_new_tokens': max_new_tokens,
'num_return_sequences': 1}})
headers.update(
{'Authorization': f'Bearer {self.api_token}'})
Expand Down Expand Up @@ -165,6 +198,63 @@ async def find_block_params(
[f'{type.name.lower()}' for type in ActionType]
variable_values['data_sources'] = \
[f"{type.name.lower()}" for type in DataSource]
function_params_response = await self.inference_with_prompt(
variable_values, PROMPT_FOR_FUNCTION_PARAMS, is_json_response=True)
function_params = (
'block_type',
'block_language',
'pipeline_types',
)
tasks = []
tasks.append(asyncio.create_task(self.inference_with_prompt(
variable_values,
PROMPT_FOR_BLOCK_TYPE,
is_json_response=False,
max_new_tokens=8)))
tasks.append(asyncio.create_task(self.inference_with_prompt(
variable_values,
PROMPT_FOR_BLOCK_LANGUAGE,
is_json_response=False,
max_new_tokens=5)))
tasks.append(asyncio.create_task(self.inference_with_prompt(
variable_values,
PROMPT_FOR_PIPELINE_TYPES,
is_json_response=False,
max_new_tokens=5)))

results = await asyncio.gather(*tasks)
function_params_response = dict(zip(function_params, results))

# Remove prefix:
function_params_response['block_type'] = \
function_params_response['block_type'].replace('type: ', '', 1)
function_params_response['block_language'] = \
function_params_response['block_language'].replace('programming language: ', '', 1)
function_params_response['pipeline_types'] = \
function_params_response['pipeline_types'].replace('type: ', '', 1)

# Generate data source and action type
data_source = ''
action_type = ''
if function_params_response['block_type'] == 'data_loader' or \
function_params_response['block_type'] == 'data_exporter':
data_source = await self.inference_with_prompt(
variable_values,
PROMPT_FOR_DATA_SOURCE,
is_json_response=False,
max_new_tokens=5)
data_source = data_source.replace('source: ', '', 1)
elif function_params_response['block_type'] == 'transformer':
action_type = await self.inference_with_prompt(
variable_values,
PROMPT_FOR_ACTION_TYPE,
is_json_response=False,
max_new_tokens=5)
action_type = action_type.replace('action: ', '', 1)

function_params_response = {
f'{BlockType.__name__}': function_params_response['block_type'],
f'{BlockLanguage.__name__}': function_params_response['block_language'],
f'{PipelineType.__name__}': function_params_response['pipeline_types'],
f'{ActionType.__name__}': action_type,
f'{DataSource.__name__}': data_source
}
return self.__parse_function_args(function_params_response)

0 comments on commit 28c9546

Please sign in to comment.