diff --git a/mellea/stdlib/genslot.py b/mellea/stdlib/genslot.py index 86061b01..9040255f 100644 --- a/mellea/stdlib/genslot.py +++ b/mellea/stdlib/genslot.py @@ -9,7 +9,7 @@ from pydantic import BaseModel, Field, create_model from mellea.stdlib.base import Component, TemplateRepresentation -from mellea.stdlib.session import get_session +from mellea.stdlib.session import MelleaSession, get_session P = ParamSpec("P") R = TypeVar("R") @@ -154,7 +154,7 @@ def __init__(self, func: Callable[P, R]): def __call__( self, - m=None, + m: MelleaSession | None = None, model_options: dict | None = None, *args: P.args, **kwargs: P.kwargs, @@ -180,13 +180,11 @@ def __call__( response_model = create_response_format(self._function._func) - response = m.genslot( - slot_copy, model_options=model_options, format=response_model - ) + response = m.act(slot_copy, format=response_model, model_options=model_options) function_response: FunctionResponse[R] = response_model.model_validate_json( - response.value - ) # type: ignore + response.value # type: ignore + ) return function_response.result diff --git a/mellea/stdlib/sampling.py b/mellea/stdlib/sampling.py index 651d77a2..ee6ab431 100644 --- a/mellea/stdlib/sampling.py +++ b/mellea/stdlib/sampling.py @@ -48,6 +48,7 @@ def __init__( self.success = success self.sample_generations = sample_generations self.sample_validations = sample_validations + self.sample_actions = sample_actions class SamplingStrategy(abc.ABC): @@ -153,7 +154,7 @@ def select_from_failure( sampled_actions: list[Component], sampled_results: list[ModelOutputThunk], sampled_val: list[list[tuple[Requirement, ValidationResult]]], - ): + ) -> int: """This function returns the index of the result that should be selected as `.value` iff the loop budget is exhausted and no success. Args: @@ -356,17 +357,17 @@ def select_from_failure( @staticmethod def repair( - context: Context, + ctx: Context, past_actions: list[Component], past_results: list[ModelOutputThunk], past_val: list[list[tuple[Requirement, ValidationResult]]], ) -> Component: - assert isinstance(context, LinearContext), ( + assert isinstance(ctx, LinearContext), ( " Need linear context to run agentic sampling." ) # add failed execution to chat history - context.insert_turn(ContextTurn(past_actions[-1], past_results[-1])) + ctx.insert_turn(ContextTurn(past_actions[-1], past_results[-1])) last_failed_reqs: list[Requirement] = [s[0] for s in past_val[-1] if not s[1]] last_failed_reqs_str = "* " + "\n* ".join( diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index e3b08d1b..3c398db8 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -3,10 +3,8 @@ from __future__ import annotations import contextvars -from collections.abc import Generator -from contextlib import contextmanager from copy import deepcopy -from typing import Any, Literal, Optional +from typing import Any, Literal, overload from mellea.backends import Backend, BaseModelSubclass from mellea.backends.formatter import FormatterBackend @@ -224,85 +222,95 @@ def cleanup(self) -> None: self.reset() self._backend_stack.clear() if hasattr(self.backend, "close"): - self.backend.close() + self.backend.close() # type: ignore def summarize(self) -> ModelOutputThunk: """Summarizes the current context.""" raise NotImplementedError() - def instruct( + @overload + def act( self, - description: str, + action: Component, *, - requirements: list[Requirement | str] | None = None, - icl_examples: list[str | CBlock] | None = None, - grounding_context: dict[str, str | CBlock | Component] | None = None, - user_variables: dict[str, str] | None = None, - prefix: str | CBlock | None = None, - output_prefix: str | CBlock | None = None, + strategy: SamplingStrategy | None = None, + return_sampling_results: Literal[False] = False, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> ModelOutputThunk: ... + + @overload + def act( + self, + action: Component, + *, + strategy: SamplingStrategy | None = None, + return_sampling_results: Literal[True], + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> SamplingResult: ... + + def act( + self, + action: Component, + *, + requirements: list[Requirement] | None = None, strategy: SamplingStrategy | None = None, return_sampling_results: bool = False, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, ) -> ModelOutputThunk | SamplingResult: - """Generates from an instruction. + """Runs a generic action, and adds both the action and the result to the context. Args: - description: The description of the instruction. - requirements: A list of requirements that the instruction can be validated against. - icl_examples: A list of in-context-learning examples that the instruction can be validated against. - grounding_context: A list of grounding contexts that the instruction can use. They can bind as variables using a (key: str, value: str | ContentBlock) tuple. - user_variables: A dict of user-defined variables used to fill in Jinja placeholders in other parameters. This requires that all other provided parameters are provided as strings. - prefix: A prefix string or ContentBlock to use when generating the instruction. - output_prefix: A string or ContentBlock that defines a prefix for the output generation. Usually you do not need this. - strategy: A SamplingStrategy that describes the strategy for validating and repairing/retrying for the instruct-validate-repair pattern. None means that no particular sampling strategy is used. + action: the Component from which to generate. + requirements: used as additional requirements when a sampling strategy is provided + strategy: a SamplingStrategy that describes the strategy for validating and repairing/retrying for the instruct-validate-repair pattern. None means that no particular sampling strategy is used. return_sampling_results: attach the (successful and failed) sampling attempts to the results. - format: If set, the BaseModel to use for constrained decoding. - model_options: Additional model options, which will upsert into the model/backend's defaults. - tool_calls: If true, tool calling is enabled. - """ - requirements = [] if requirements is None else requirements - icl_examples = [] if icl_examples is None else icl_examples - grounding_context = dict() if grounding_context is None else grounding_context - # all instruction options are forwarded to create a new Instruction object - i = Instruction( - description=description, - requirements=requirements, - icl_examples=icl_examples, - grounding_context=grounding_context, - user_variables=user_variables, - prefix=prefix, - output_prefix=output_prefix, - ) + format: if set, the BaseModel to use for constrained decoding. + model_options: additional model options, which will upsert into the model/backend's defaults. + tool_calls: if true, tool calling is enabled. - res = None + Returns: + A ModelOutputThunk if `return_sampling_results` is `False`, else returns a `SamplingResult`. + """ + sampling_result: SamplingResult | None = None generate_logs: list[GenerateLog] = [] + + if return_sampling_results: + assert strategy is not None, ( + "Must provide a SamplingStrategy when return_sampling_results==True" + ) + if strategy is None: result = self.backend.generate_from_context( - i, + action, ctx=self.ctx, format=format, model_options=model_options, generate_logs=generate_logs, tool_calls=tool_calls, ) - - # make sure that one Log is marked as the one related to result assert len(generate_logs) == 1, "Simple call can only add one generate_log" - generate_logs[0].is_final_result = True + generate_logs[-1].is_final_result = True + else: + # Default validation strategy just validates all of the provided requirements. if strategy.validate is None: - strategy.validate = lambda reqs, val_ctx, output: self.validate( # type: ignore - reqs, - output=output, # type: ignore - ) # type: ignore + strategy.validate = lambda reqs, val_ctx, output: self.validate( + reqs, output=output + ) + + # Default generation strategy just generates from context. if strategy.generate is None: strategy.generate = ( - lambda instruction, + lambda sample_action, gen_ctx, g_logs: self.backend.generate_from_context( - instruction, + sample_action, ctx=gen_ctx, format=format, model_options=model_options, @@ -311,35 +319,132 @@ def instruct( ) ) - # sample - res = strategy.sample( - i, self.ctx, i.requirements, generate_logs=generate_logs + if requirements is None: + requirements = [] + + sampling_result = strategy.sample( + action, self.ctx, requirements=requirements, generate_logs=generate_logs ) - # make sure that one Log is marked as the one related to res.result - if res.success: + # make sure that one Log is marked as the one related to sampling_result.result + if sampling_result.success: # if successful, the last log is the one related generate_logs[-1].is_final_result = True else: - # find the one where log.result and res.result match + # Find the log where log.result and sampling_result.result match selected_log = [ - log for log in generate_logs if log.result == res.result + log for log in generate_logs if log.result == sampling_result.result ] assert len(selected_log) == 1, ( "There should only be exactly one log corresponding to the single result. " ) selected_log[0].is_final_result = True - result = res.result + result = sampling_result.result - self.ctx.insert_turn(ContextTurn(i, result), generate_logs=generate_logs) + self.ctx.insert_turn(ContextTurn(action, result), generate_logs=generate_logs) if return_sampling_results: - assert res is not None, "Asking for sampling results without sampling." - return res + assert ( + sampling_result is not None + ) # Needed for the type checker but should never happen. + return sampling_result else: return result + @overload + def instruct( + self, + description: str, + *, + requirements: list[Requirement | str] | None = None, + icl_examples: list[str | CBlock] | None = None, + grounding_context: dict[str, str | CBlock | Component] | None = None, + user_variables: dict[str, str] | None = None, + prefix: str | CBlock | None = None, + output_prefix: str | CBlock | None = None, + strategy: SamplingStrategy | None = None, + return_sampling_results: Literal[False] = False, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> ModelOutputThunk: ... + + @overload + def instruct( + self, + description: str, + *, + requirements: list[Requirement | str] | None = None, + icl_examples: list[str | CBlock] | None = None, + grounding_context: dict[str, str | CBlock | Component] | None = None, + user_variables: dict[str, str] | None = None, + prefix: str | CBlock | None = None, + output_prefix: str | CBlock | None = None, + strategy: SamplingStrategy | None = None, + return_sampling_results: Literal[True], + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> SamplingResult: ... + + def instruct( + self, + description: str, + *, + requirements: list[Requirement | str] | None = None, + icl_examples: list[str | CBlock] | None = None, + grounding_context: dict[str, str | CBlock | Component] | None = None, + user_variables: dict[str, str] | None = None, + prefix: str | CBlock | None = None, + output_prefix: str | CBlock | None = None, + strategy: SamplingStrategy | None = None, + return_sampling_results: bool = False, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> ModelOutputThunk | SamplingResult: + """Generates from an instruction. + + Args: + description: The description of the instruction. + requirements: A list of requirements that the instruction can be validated against. + icl_examples: A list of in-context-learning examples that the instruction can be validated against. + grounding_context: A list of grounding contexts that the instruction can use. They can bind as variables using a (key: str, value: str | ContentBlock) tuple. + user_variables: A dict of user-defined variables used to fill in Jinja placeholders in other parameters. This requires that all other provided parameters are provided as strings. + prefix: A prefix string or ContentBlock to use when generating the instruction. + output_prefix: A string or ContentBlock that defines a prefix for the output generation. Usually you do not need this. + strategy: A SamplingStrategy that describes the strategy for validating and repairing/retrying for the instruct-validate-repair pattern. None means that no particular sampling strategy is used. + return_sampling_results: attach the (successful and failed) sampling attempts to the results. + format: If set, the BaseModel to use for constrained decoding. + model_options: Additional model options, which will upsert into the model/backend's defaults. + tool_calls: If true, tool calling is enabled. + """ + requirements = [] if requirements is None else requirements + icl_examples = [] if icl_examples is None else icl_examples + grounding_context = dict() if grounding_context is None else grounding_context + + # All instruction options are forwarded to create a new Instruction object. + i = Instruction( + description=description, + requirements=requirements, + icl_examples=icl_examples, + grounding_context=grounding_context, + user_variables=user_variables, + prefix=prefix, + output_prefix=output_prefix, + ) + + return self.act( + i, + requirements=i.requirements, + strategy=strategy, + return_sampling_results=return_sampling_results, + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) # type: ignore[call-overload] + def chat( self, content: str, @@ -358,34 +463,17 @@ def chat( else: content_resolved = content user_message = Message(role=role, content=content_resolved) - generate_logs: list[GenerateLog] = [] - output_thunk = self.backend.generate_from_context( - action=user_message, - ctx=self.ctx, + + result = self.act( + user_message, format=format, model_options=model_options, - generate_logs=generate_logs, tool_calls=tool_calls, ) - # make sure that the last and only Log is marked as the one related to result - assert len(generate_logs) == 1, "Simple call can only add one generate_log" - generate_logs[0].is_final_result = True - - parsed_assistant_message = output_thunk.parsed_repr - assert type(parsed_assistant_message) is Message - self.ctx.insert_turn( - ContextTurn(user_message, output_thunk), generate_logs=generate_logs - ) - return parsed_assistant_message + parsed_assistant_message = result.parsed_repr + assert isinstance(parsed_assistant_message, Message) - def act(self, c: Component, tool_calls: bool = False) -> Any: - """Runs a generic action, and adds both the action and the result to the context.""" - generate_logs: list[GenerateLog] = [] - result: ModelOutputThunk = self.backend.generate_from_context( - c, self.ctx, generate_logs=generate_logs, tool_calls=tool_calls - ) - self.ctx.insert_turn(turn=ContextTurn(c, result), generate_logs=generate_logs) - return result + return parsed_assistant_message def validate( self, @@ -418,40 +506,6 @@ def validate( return rvs - def genslot( - self, - gen_slot: Component, - model_options: dict | None = None, - format: type[BaseModelSubclass] | None = None, - tool_calls: bool = False, - ) -> ModelOutputThunk: - """Call generative Slot on a GenerativeSlot Component. - - Args: - gen_slot (GenerativeSlot Component): A generative slot - - Returns: - ModelOutputThunk: Output thunk - """ - generate_logs: list[GenerateLog] = [] - result: ModelOutputThunk = self.backend.generate_from_context( - action=gen_slot, - ctx=self.ctx, - model_options=model_options, - format=format, - generate_logs=generate_logs, - tool_calls=tool_calls, - ) - # make sure that the last and only Log is marked as the one related to result - assert len(generate_logs) == 1, "Simple call can only add one generate_log" - generate_logs[0].is_final_result = True - - self.ctx.insert_turn( - ContextTurn(deepcopy(gen_slot), result), generate_logs=generate_logs - ) - - return result - def query( self, obj: Any, @@ -479,32 +533,9 @@ def query( assert isinstance(obj, MObjectProtocol) q = obj.get_query_object(query) - generate_logs: list[GenerateLog] = [] - answer = self.backend.generate_from_context( - q, - self.ctx, - format=format, - model_options=model_options, - generate_logs=generate_logs, - tool_calls=tool_calls, + answer = self.act( + q, format=format, model_options=model_options, tool_calls=tool_calls ) - # make sure that the last and only Log is marked as the one related to result - assert len(generate_logs) == 1, "Simple call can only add one generate_log" - generate_logs[0].is_final_result = True - - if isinstance(self.ctx, SimpleContext): - self.ctx.insert_turn(ContextTurn(q, answer), generate_logs=generate_logs) - elif isinstance(self.ctx, LinearContext) and len(self.ctx._ctx) == 0: - FancyLogger.get_logger().info( - "Adding the Object Query and its answer as first turn to a Linear Context (Chat History). " - "You can now run more .chat() or .instruct() with the object as reference." - ) - self.ctx.insert_turn(ContextTurn(q, answer), generate_logs=generate_logs) - else: - FancyLogger.get_logger().info( - "The Linear Context has not been modified by this query." - ) - return answer def transform( @@ -532,43 +563,12 @@ def transform( assert isinstance(obj, MObjectProtocol) t = obj.get_transform_object(transformation) - generate_logs: list[GenerateLog] = [] - # Check that your model / backend supports tool calling. # This might throw an error when tools are provided but can't be handled by one or the other. - transformed = self.backend.generate_from_context( - t, - self.ctx, - format=format, - model_options=model_options, - generate_logs=generate_logs, - tool_calls=True, + transformed = self.act( + t, format=format, model_options=model_options, tool_calls=True ) - assert len(generate_logs) == 1, "Simple call can only add one generate_log" - generate_logs[0].is_final_result = True - - # Insert the new turn into the context. Tool calls are handled afterwards. - insert = False - if isinstance(self.ctx, SimpleContext): - insert = True - self.ctx.insert_turn( - ContextTurn(t, transformed), generate_logs=generate_logs - ) - elif isinstance(self.ctx, LinearContext) and len(self.ctx._ctx) == 0: - insert = True - FancyLogger.get_logger().info( - "Adding the Object Transform and its result as first turn to a Linear Context (Chat History). " - "You can now run more .chat() or .instruct() with the object as reference." - ) - self.ctx.insert_turn( - ContextTurn(t, transformed), generate_logs=generate_logs - ) - else: - FancyLogger.get_logger().info( - "The Linear Context has not been modified by this query." - ) - tools = self._call_tools(transformed) # Transform only supports calling one tool call since it cannot currently synthesize multiple outputs. @@ -597,11 +597,11 @@ def transform( FancyLogger.get_logger().warning( f"the transform of {obj} with transformation description '{transformation}' resulted in a tool call with no generated arguments; consider calling the function `{chosen_tool._tool.name}` directly" ) - if insert: - self.ctx.insert(chosen_tool) - FancyLogger.get_logger().warning( - "added a tool message from transform to the context as well." - ) + + self.ctx.insert(chosen_tool) + FancyLogger.get_logger().info( + "added a tool message from transform to the context" + ) return chosen_tool._tool_output return transformed