Skip to content
25 changes: 24 additions & 1 deletion cli/decompose/decompose.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
import json
import keyword
from enum import Enum
from pathlib import Path
from typing import Annotated

import typer

from .pipeline import DecompBackend


# Must maintain declaration order.
# Newer versions must be declared at the bottom.
class DecompVersion(str, Enum):
    """Selectable versions of the decompose program-generator template.

    Declaration order is significant: ``latest`` is resolved to the *last*
    declared member (via ``list(DecompVersion)[-1]`` in the CLI), so newer
    versions must always be appended at the bottom of this class.
    """

    latest = "latest"
    v1 = "v1"
    # v2 = "v2"


this_file_dir = Path(__file__).resolve().parent


Expand Down Expand Up @@ -76,6 +86,13 @@ def run(
)
),
] = None,
version: Annotated[
DecompVersion,
typer.Option(
help=("Version of the mellea program generator template to be used."),
case_sensitive=False,
),
] = DecompVersion.latest,
input_var: Annotated[
list[str] | None,
typer.Option(
Expand All @@ -99,7 +116,13 @@ def run(
environment = Environment(
loader=FileSystemLoader(this_file_dir), autoescape=False
)
m_template = environment.get_template("m_decomp_result.py.jinja2")

ver = (
list(DecompVersion)[-1].value
if version == DecompVersion.latest
else version.value
)
m_template = environment.get_template(f"m_decomp_result_{ver}.py.jinja2")

out_name = out_name.strip()
assert validate_filename(out_name), (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,19 @@ except KeyError as e:
print(f"ERROR: One or more required environment variables are not set; {e}")
exit(1)
{%- endif %}
{% for item in subtasks%}
{% for item in subtasks %}
{% set i = loop.index0 %}
# {{ item.subtask }} - {{ item.tag }}
subtask_{{ loop.index }} = m.instruct(
{{ item.tag | lower }} = m.instruct(
textwrap.dedent(
R"""
{{ item.prompt_template | trim | indent(width=8, first=False) }}
{{ item.prompt_template | trim | indent(width=8, first=False) }}
""".strip()
),
{%- if item.constraints %}
requirements=[
{%- for con in item.constraints %}
{{ con | tojson}},
{%- for c in item.constraints %}
{{ c.constraint | tojson}},
{%- endfor %}
],
{%- else %}
Expand All @@ -39,22 +39,22 @@ subtask_{{ loop.index }} = m.instruct(
{%- if loop.first and not user_inputs %}
{%- else %}
user_variables={
{%- if user_inputs %}
{%- for var in user_inputs %}
{%- for var in item.input_vars_required %}
{{ var | upper | tojson }}: {{ var | lower }},
{%- endfor %}
{%- endif %}

{%- for j in range(i) %}
{{ subtasks[j].tag | tojson }}: subtask_{{ i }}.value if subtask_{{ i }}.value is not None else "",
{%- for var in item.depends_on %}
{{ var | upper | tojson }}: {{ var | lower }}.value,
{%- endfor %}
},
{%- endif %}
)
assert {{ item.tag | lower }}.value is not None, 'ERROR: task "{{ item.tag | lower }}" execution failed'
{%- if loop.last %}

final_response = subtask_{{ loop.index }}.value

print(final_response)
final_answer = {{ item.tag | lower }}.value

print(final_answer)
{%- endif -%}
{%- endfor -%}
76 changes: 65 additions & 11 deletions cli/decompose/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
from enum import Enum
from typing import TypedDict
from typing import Literal, TypedDict

from typing_extensions import NotRequired

Expand All @@ -10,27 +11,37 @@

from .prompt_modules import (
constraint_extractor,
# general_instructions,
subtask_constraint_assign,
subtask_list,
subtask_prompt_generator,
validation_decision,
)
from .prompt_modules.subtask_constraint_assign import SubtaskPromptConstraintsItem
from .prompt_modules.subtask_list import SubtaskItem
from .prompt_modules.subtask_prompt_generator import SubtaskPromptItem


class ConstraintResult(TypedDict):
    """A single extracted constraint together with how it should be validated."""

    # Natural-language constraint text extracted from the task prompt.
    constraint: str
    # Chosen validation approach; populated from validation_decision,
    # which yields "code" or "llm".
    validation_strategy: str


class DecompSubtasksResult(TypedDict):
    """Per-subtask entry in the decompose pipeline result."""

    # Human-readable description of the subtask.
    subtask: str
    # Identifier used as the subtask's variable name in the generated program.
    tag: str
    # Constraints assigned to this subtask, each with its validation strategy.
    constraints: list[ConstraintResult]
    # Jinja2 prompt template for this subtask.
    prompt_template: str
    # general_instructions: str
    # Template variables that must be supplied by the user
    # (names found in the template that appear in the user's input variables).
    input_vars_required: list[str]
    # Template variables produced by earlier subtasks
    # (names found in the template that are NOT user input variables).
    depends_on: list[str]
    # Present only after a response has been generated for this subtask.
    generated_response: NotRequired[str]


class DecompPipelineResult(TypedDict):
    """Top-level result returned by the decompose pipeline."""

    # The task prompt exactly as provided by the caller.
    original_task_prompt: str
    # Subtask descriptions, in the order they were generated.
    subtask_list: list[str]
    # Constraints identified across the whole task prompt, with strategies.
    identified_constraints: list[ConstraintResult]
    # Fully populated per-subtask data.
    subtasks: list[DecompSubtasksResult]
    # Present only once a final response has been produced.
    final_response: NotRequired[str]

Expand All @@ -41,6 +52,9 @@ class DecompBackend(str, Enum):
rits = "rits"


RE_JINJA_VAR = re.compile(r"\{\{\s*(.*?)\s*\}\}")


def decompose(
task_prompt: str,
user_input_variable: list[str] | None = None,
Expand All @@ -53,15 +67,12 @@ def decompose(
if user_input_variable is None:
user_input_variable = []

# region Backend Assignment
match backend:
case DecompBackend.ollama:
m_session = MelleaSession(
OllamaModelBackend(
model_id=model_id,
model_options={
ModelOption.CONTEXT_WINDOW: 32768,
"timeout": backend_req_timeout,
},
model_id=model_id, model_options={ModelOption.CONTEXT_WINDOW: 16384}
)
)
case DecompBackend.openai:
Expand Down Expand Up @@ -96,13 +107,19 @@ def decompose(
model_options={"timeout": backend_req_timeout},
)
)
# endregion

subtasks: list[SubtaskItem] = subtask_list.generate(m_session, task_prompt).parse()

task_prompt_constraints: list[str] = constraint_extractor.generate(
m_session, task_prompt
m_session, task_prompt, enforce_same_words=False
).parse()

constraint_validation_strategies: dict[str, Literal["code", "llm"]] = {
cons_key: validation_decision.generate(m_session, cons_key).parse()
for cons_key in task_prompt_constraints
}

subtask_prompts: list[SubtaskPromptItem] = subtask_prompt_generator.generate(
m_session,
task_prompt,
Expand All @@ -122,15 +139,52 @@ def decompose(
DecompSubtasksResult(
subtask=subtask_data.subtask,
tag=subtask_data.tag,
constraints=subtask_data.constraints,
constraints=[
{
"constraint": cons_str,
"validation_strategy": constraint_validation_strategies[cons_str],
}
for cons_str in subtask_data.constraints
],
prompt_template=subtask_data.prompt_template,
# general_instructions=general_instructions.generate(
# m_session, input_str=subtask_data.prompt_template
# ).parse(),
input_vars_required=list(
dict.fromkeys( # Remove duplicates while preserving the original order.
[
item
for item in re.findall(
RE_JINJA_VAR, subtask_data.prompt_template
)
if item in user_input_variable
]
)
),
depends_on=list(
dict.fromkeys( # Remove duplicates while preserving the original order.
[
item
for item in re.findall(
RE_JINJA_VAR, subtask_data.prompt_template
)
if item not in user_input_variable
]
)
),
)
for subtask_data in subtask_prompts_with_constraints
]

return DecompPipelineResult(
original_task_prompt=task_prompt,
subtask_list=[item.subtask for item in subtasks],
identified_constraints=task_prompt_constraints,
identified_constraints=[
{
"constraint": cons_str,
"validation_strategy": constraint_validation_strategies[cons_str],
}
for cons_str in task_prompt_constraints
],
subtasks=decomp_subtask_result,
)
2 changes: 2 additions & 0 deletions cli/decompose/prompt_modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from .constraint_extractor import constraint_extractor as constraint_extractor
from .general_instructions import general_instructions as general_instructions
from .subtask_constraint_assign import (
subtask_constraint_assign as subtask_constraint_assign,
)
from .subtask_list import subtask_list as subtask_list
from .subtask_prompt_generator import (
subtask_prompt_generator as subtask_prompt_generator,
)
from .validation_decision import validation_decision as validation_decision
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

from mellea import MelleaSession
from mellea.backends.types import ModelOption
from mellea.stdlib.base import CBlock
from mellea.stdlib.instruction import Instruction
from mellea.stdlib.chat import Message

from .._prompt_modules import PromptModule, PromptModuleString
from ._exceptions import BackendGenerationError, TagExtractionError
Expand All @@ -14,7 +13,7 @@
T = TypeVar("T")

RE_VERIFIED_CONS_COND = re.compile(
r"<constraints_and_conditions>(.+?)</constraints_and_conditions>",
r"<constraints_and_requirements>(.+?)</constraints_and_requirements>",
flags=re.IGNORECASE | re.DOTALL,
)

Expand All @@ -33,13 +32,13 @@ def _default_parser(generated_str: str) -> list[str]:
generated_str (`str`): The LLM's answer to be parsed.

Returns:
list[str]: A list of identified constraints in natural language. The list
list[str]: A list of identified constraints and requirements in natural language. The list
will be empty if no constraints were identified by the LLM.

Raises:
TagExtractionError: An error occurred trying to extract content from the
generated output. The LLM probably failed to open and close
the \<constraints_and_conditions\> tags.
the \<constraints_and_requirements\> tags.
"""
constraint_extractor_match = re.search(RE_VERIFIED_CONS_COND, generated_str)

Expand All @@ -51,7 +50,7 @@ def _default_parser(generated_str: str) -> list[str]:

if constraint_extractor_str is None:
raise TagExtractionError(
'LLM failed to generate correct tags for extraction: "<constraints_and_conditions>"'
'LLM failed to generate correct tags for extraction: "<constraints_and_requirements>"'
)

# TODO: Maybe replace this logic with a RegEx?
Expand All @@ -76,13 +75,13 @@ def generate( # type: ignore[override]
self,
mellea_session: MelleaSession,
input_str: str | None,
max_new_tokens: int = 8192,
max_new_tokens: int = 4096,
parser: Callable[[str], T] = _default_parser, # type: ignore[assignment]
# About the mypy ignore above: https://github.com/python/mypy/issues/3737
enforce_same_words: bool = False,
**kwargs: dict[str, Any],
) -> PromptModuleString[T]:
"""Generates an unordered list of identified constraints based on a provided task prompt.
"""Generates an unordered list of identified constraints and requirements based on a provided task prompt.

_**Disclaimer**: This is a LLM-prompting module, so the results will vary depending
on the size and capabilities of the LLM used. The results are also not guaranteed, so
Expand Down Expand Up @@ -112,12 +111,13 @@ def generate( # type: ignore[override]
system_prompt = get_system_prompt(enforce_same_words=enforce_same_words)
user_prompt = get_user_prompt(task_prompt=input_str)

instruction = Instruction(description=user_prompt, prefix=system_prompt)
action = Message("user", user_prompt)

try:
gen_result = mellea_session.act(
action=instruction,
action=action,
model_options={
ModelOption.SYSTEM_PROMPT: system_prompt,
ModelOption.TEMPERATURE: 0,
ModelOption.MAX_NEW_TOKENS: max_new_tokens,
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,13 @@

example: ICLExample = {
"task_prompt": task_prompt.strip(),
"constraints_and_conditions": [],
"constraints_and_requirements": [],
}

example["constraints_and_conditions"] = [
example["constraints_and_requirements"] = [
"Your answers should not include harmful, unethical, racist, sexist, toxic, dangerous, or illegal content",
"If a question does not make sense, or not factually coherent, explain to the user why, instead of just answering something incorrect",
"You must always answer the user with markdown formatting",
"The markdown formats you can use are the following: heading; link; table; list; code block; block quote; bold; italic",
"When answering with code blocks, include the language",
"The only markdown formats you can use are the following: heading; link; table; list; code block; block quote; bold; italic",
"All HTML tags must be enclosed in block quotes",
"The personas must include the following properties: name; age; occupation; demographics; goals; behaviors; pain points; motivations",
"The assistant must provide a comprehensive understanding of the target audience",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ If a question does not make sense, or not factually coherent, explain to the use

You must always answer the user with markdown formatting.

The markdown formats you can use are the following:
The only markdown formats you can use are the following:
- heading
- link
- table
Expand All @@ -15,7 +15,6 @@ The markdown formats you can use are the following:
- bold
- italic

When answering with code blocks, include the language.
You can be penalized if you write code outside of code blocks.

All HTML tags must be enclosed in block quotes, for example:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@

example: ICLExample = {
"task_prompt": task_prompt.strip(),
"constraints_and_conditions": [],
"constraints_and_requirements": [],
}

example["constraints_and_conditions"] = [
"Emphasize the responsibilities and support offered to survivors of crime",
example["constraints_and_requirements"] = [
"Ensure the word 'assistance' appears less than 4 times",
"Wrap the entire response with double quotation marks",
]
Loading