diff --git a/examples/middleware/readme.md b/examples/middleware/readme.md new file mode 100644 index 000000000..bfd79da28 --- /dev/null +++ b/examples/middleware/readme.md @@ -0,0 +1,77 @@ +# Middleware in Instructor + +Middleware in Instructor allows you to modify the messages sent to the language model before they are processed. This is beneficial because it enables you to perform custom preprocessing, add context, or even implement simple retrieval-augmented generation (RAG) techniques. + +Middleware can be defined as simple functions or classes (when you need stateful variables). They are then registered with the Instructor client using the `with_middleware` method. + +## what is middleware? + +Middleware is a way to modify the input or output of a function or method. In the context of language models and AI assistants, middleware allows you to intercept and modify the messages being sent to the model before they are processed. + +Some common use cases for middleware include: + +- Preprocessing the input messages (e.g. cleaning up text, adding context) +- Implementing retrieval augmented generation by fetching relevant information and appending it to the messages +- Filtering or moderating content +- Logging or monitoring the messages being sent to the model + +Middleware functions take in the list of messages, make any desired changes, and return the modified list of messages to be sent to the model. + +Instructor makes it easy to define and use middleware. You can create middleware as simple functions using the `@messages_middleware` decorator, or for more complex stateful middleware you can define a class that inherits from `MessageMiddleware` and implements the `__call__` method. + +Once defined, middleware is registered with the Instructor client using the `with_middleware()` method. This allows chaining multiple middleware together. + +## Simple RAG Example + +Middleware can also be used to implement more advanced techniques like retrieval-augmented generation (RAG). RAG involves retrieving relevant information from an external source and using it to augment the input to the language model. This can help provide additional context and improve the quality and accuracy of the generated responses. + +To implement a simple RAG middleware, you could define a function or class that takes the input messages, performs a retrieval step to find relevant information, and then appends that information to the messages before sending them to the model. For example: + +```python +@instructor.messages_middleware +def add_retrieval_augmentation(messages): + # Perform retrieval step to find relevant information + relevant_information = retrieve_relevant_information(messages) + + # Append the relevant information to the messages + return messages + [{ + "role": "user", + "content": f"Relevant Information: {relevant_information}" + }] +``` + +## Logging and Monitoring +Another useful application of middleware is for logging and monitoring the messages being sent to and received from the language model. This can be helpful for debugging, auditing, or analyzing the conversations. + +To implement logging middleware, you can define a function or class that takes the input messages, logs them to a file or database, and then returns the original messages unmodified. For example: + +```python +@instructor.messages_middleware +def logging_middleware(messages): + import logging + logging.info(f"Input messages: {messages}") + + # Return the original messages unmodified + return messages +``` + +## Stateful Middleware + +For more advanced stateful middleware, you can define a class that inherits from `MessageMiddleware` and implements the `__call__` method. This allows you to maintain state across multiple calls to the middleware. + +For example, let's say you want to implement a middleware that adds user preferences to the messages. You could define a stateful middleware class like this: + +```python +class UserPreferencesMiddleware(MessageMiddleware): + + user_id: str + + def __call__(self, messages): + preferences = get_user_preferences(self.user_id) + for message in messages: + if message.role == "system": + message.content += f"\n\nUser Preferences: {preferences}" + return messages +``` + +As you can see above, middleware provides a flexible way to modify and augment the messages being sent to and received from the language model. This can be used for a variety of purposes, such as adding relevant information, logging and monitoring conversations, and maintaining stateful interactions. \ No newline at end of file diff --git a/examples/middleware/run.py b/examples/middleware/run.py new file mode 100644 index 000000000..1fe1daa1a --- /dev/null +++ b/examples/middleware/run.py @@ -0,0 +1,61 @@ +import instructor +import openai +from openai.types.chat import ChatCompletionMessageParam +from typing import List +from pydantic import BaseModel + + +class PrintLastUserMessage(instructor.MessageMiddleware): + + log: bool = False + + def __call__( + self, messages: List[ChatCompletionMessageParam] + ) -> List[ChatCompletionMessageParam]: + if self.log: + import pprint + + pprint.pprint({"messages": messages}) + return messages + + +@instructor.messages_middleware +def dumb_rag(messages): + # TODO: use RAG to generate a response + # TODO: add the response to the messages + return messages + [ + { + "role": "user", + "content": "Search retrieved: 'Jason is 20 years old'", + } + ] + + +class User(BaseModel): + age: int + name: str + + +client = ( + instructor.from_openai(openai.OpenAI()) + .with_middleware(dumb_rag) + .with_middleware(PrintLastUserMessage(log=True)) # can be called directly +) + + +user = client.chat.completions.create( + model="gpt-4-turbo-preview", + messages=[ + { + "role": "user", + "content": "How old is jason?", + } + ], + response_model=User, +) + +print(user) +# {'messages': [{'content': 'How old is jason?', 'role': 'user'}, +# {'content': "Search retrieved: 'Jason is 20 years old'", +# 'role': 'user'}]} +# {'age': 20, 'name': 'jason'} diff --git a/instructor/__init__.py b/instructor/__init__.py index 4baf453a5..0cde58030 100644 --- a/instructor/__init__.py +++ b/instructor/__init__.py @@ -13,9 +13,16 @@ from .patch import apatch, patch from .process_response import handle_parallel_model from .client import Instructor, from_openai, from_anthropic, from_litellm +from .messages_middleware import ( + MessageMiddleware, + AsyncMessageMiddleware, + messages_middleware, +) __all__ = [ "Instructor", + "MessageMiddleware", + "AsyncMessageMiddleware", "from_openai", "from_anthropic", "from_litellm", diff --git a/instructor/client.py b/instructor/client.py index 14fbc7bf5..9621d0072 100644 --- a/instructor/client.py +++ b/instructor/client.py @@ -22,7 +22,7 @@ from typing_extensions import Self from pydantic import BaseModel from instructor.dsl.partial import Partial - +from instructor.messages_middleware import MessageMiddleware T = TypeVar("T", bound=(BaseModel | Iterable | Partial)) @@ -47,6 +47,7 @@ def __init__( self.mode = mode self.kwargs = kwargs self.provider = provider + self.message_middleware = [] @property def chat(self) -> Self: @@ -60,6 +61,10 @@ def completions(self) -> Self: def messages(self) -> Self: return self + def with_middleware(self, middleware: MessageMiddleware | Callable) -> Self: + self.message_middleware.append(middleware) + return self + # TODO: we should overload a case where response_model is None def create( self, @@ -69,9 +74,7 @@ def create( validation_context: dict | None = None, **kwargs, ) -> T: - kwargs = self.handle_kwargs(kwargs) - - return self.create_fn( + return self._create( response_model=response_model, messages=messages, max_retries=max_retries, @@ -90,13 +93,9 @@ def create_partial( assert self.provider != Provider.ANTHROPIC, "Anthropic doesn't support partial" kwargs["stream"] = True - - kwargs = self.handle_kwargs(kwargs) - - response_model = instructor.Partial[response_model] # type: ignore - return self.create_fn( + return self._create( messages=messages, - response_model=response_model, + response_model=instructor.Partial[response_model], # type: ignore max_retries=max_retries, validation_context=validation_context, **kwargs, @@ -113,12 +112,9 @@ def create_iterable( assert self.provider != Provider.ANTHROPIC, "Anthropic doesn't support iterable" kwargs["stream"] = True - kwargs = self.handle_kwargs(kwargs) - - response_model = Iterable[response_model] # type: ignore - return self.create_fn( + return self._create( messages=messages, - response_model=response_model, + response_model=Iterable[response_model], max_retries=max_retries, validation_context=validation_context, **kwargs, @@ -132,8 +128,7 @@ def create_with_completion( validation_context: dict | None = None, **kwargs, ) -> Tuple[T, ChatCompletion | Message]: - kwargs = self.handle_kwargs(kwargs) - model = self.create_fn( + model = self._create( messages=messages, response_model=response_model, max_retries=max_retries, @@ -148,6 +143,27 @@ def handle_kwargs(self, kwargs: dict): kwargs[key] = value return kwargs + def _create( + self, + messages: List[ChatCompletionMessageParam], + response_model: Type[T], + max_retries: int = 3, + validation_context: dict | None = None, + **kwargs, + ) -> T: + for middleware in self.message_middleware: + messages = middleware(messages) + + kwargs = self.handle_kwargs(kwargs) + + return self.create_fn( + messages=messages, + response_model=response_model, + max_retries=max_retries, + validation_context=validation_context, + **kwargs, + ) + class AsyncInstructor(Instructor): client: openai.AsyncOpenAI | anthropic.AsyncAnthropic | None diff --git a/instructor/messages_middleware.py b/instructor/messages_middleware.py new file mode 100644 index 000000000..a74d5d6e7 --- /dev/null +++ b/instructor/messages_middleware.py @@ -0,0 +1,33 @@ +from typing import List, Callable +from openai.types.chat import ChatCompletionMessageParam +from abc import ABC, abstractmethod +from pydantic import BaseModel + + +class MessageMiddleware(BaseModel, ABC): + + @abstractmethod + def __call__( + self, messages: List[ChatCompletionMessageParam] + ) -> List[ChatCompletionMessageParam]: + pass + + +class AsyncMessageMiddleware(MessageMiddleware): + + @abstractmethod + async def __call__( + self, messages: List[ChatCompletionMessageParam] + ) -> List[ChatCompletionMessageParam]: + pass + + +def messages_middleware(func: Callable) -> MessageMiddleware: + + class _Middleware(MessageMiddleware): + def __call__( + self, messages: List[ChatCompletionMessageParam] + ) -> List[ChatCompletionMessageParam]: + return func(messages) + + return _Middleware()