In [54]:
"""@prompt"""

import os
import abc
from collections.abc import Callable
from enum import Enum
from functools import update_wrapper
import inspect
import types
from typing import (
    Any,
    Generic,
    Iterable,
    Literal,
    Never,
    ParamSpec,
    Protocol,
    Sequence,
    TypeVar,
    TypedDict,
    Union,
    Unpack,
    cast,
    get_args,
    get_origin,
    overload,
)
from openai import OpenAI, AzureOpenAI
from langchain_openai.chat_models import ChatOpenAI as LangchainChatOpenAI
from langchain_openai.chat_models import AzureChatOpenAI as LangchainAzureChatOpenAI
from langfuse.callback import CallbackHandler as LangfuseCallbackHandler
from pydantic import BaseModel


class Role(Enum):
    """Role of conversation."""

    SYSTEM = "system"
    HUMAN = "human"
    AI = "ai"


class Message(TypedDict):
    role: Role
    content: str


class LangfuseCallbackConfig(TypedDict, total=False):
    public_key: str | None
    secret_key: str | None
    host: str | None
    debug: bool
    update_stateful_client: bool
    session_id: str | None
    user_id: str | None
    trace_name: str | None
    release: str | None
    version: str | None
    metadata: dict[str, any] | None
    tags: list[str] | None
    threads: int | None
    flush_at: int | None
    flush_interval: int | None
    max_retries: int | None
    timeout: int | None
    enabled: bool | None
    sdk_integration: str | None
    sample_rate: float | None


P = ParamSpec("P")
R = TypeVar("R")
TypeT = TypeVar("TypeT", bound=type)
MessageLikeType = Union[str, Iterable[tuple[Role, str] | tuple[str, str] | Message]]
ModelType = Union[LangchainChatOpenAI, LangchainAzureChatOpenAI, OpenAI, AzureOpenAI]


class PromptFunction(Generic[P, R]):

    def __init__(
        self,
        name: str,
        parameters: Sequence[inspect.Parameter],
        return_type: type[R],
        messages: Sequence[MessageLikeType],
        functions: list[Callable[..., Any]] | None = None,
        stop: list[str] | None = None,
        max_retries: int = 0,
        model: ModelType | None = None,
    ):
        self._name = name
        self._signature = inspect.Signature(
            parameters=parameters,
            return_annotation=return_type,
        )
        self._messages = messages
        self._functions = functions or []
        self._stop = stop
        self._max_retries = max_retries
        self._model = model

        self._return_types = list(self.split_union_type(return_type))

    def is_union_type(self, type_: type) -> bool:
        type_ = get_origin(type_) or type_
        return type_ is Union or type_ is types.UnionType

    def split_union_type(self, type_: TypeT) -> Sequence[TypeT]:
        return get_args(type_) if self.is_union_type(type_) else [type_]


class OpenAIPromptFunction(PromptFunction):

    def __init__(
        self,
        name: str,
        parameters: Sequence[inspect.Parameter],
        return_type: type[R],
        messages: Sequence[MessageLikeType],
        functions: list[Callable[..., Any]] | None = None,
        stop: list[str] | None = None,
        max_retries: int = 0,
        model: ModelType | None = None,
        model_name: Literal["gpt-4o", "gpt-4o-mini", "gpt"] | None = None,
        langfuse_config: Unpack[LangfuseCallbackConfig] | None = None,
    ) -> None:
        self._model_name = model_name
        self._langfuse_config = langfuse_config
        super().__init__(name, parameters, return_type, messages, functions, stop, max_retries, model)

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
        print(f"{self._messages = }")
        print(f"{args = }")
        print(f"{kwargs = }")
        bound_args = self._signature.bind(*args, **kwargs)
        bound_args.apply_defaults()
        print(f"Bound arguments: {bound_args.arguments}")

        messages = self._format(self._messages, args)
        print(f"Formatted messages: {messages}")
        model = self._get_model()

        config = {}
        config.update(self._langfuse_config or {})

        resp = model.chat.completions.create(messages=messages, model=self._model_name, **config, response_format=...)

        return self._model.chat()

    def run(self, model: OpenAI | AzureOpenAI, messages: list[MessageLikeType]) -> str: ...

    def _get_model(self) -> OpenAI:
        if not is_openai_model(self._model):
            raise ValueError("Invalid OpenAI model.")
        return cast(OpenAI, self._model)

    def format(self, *args: P.args, **kwargs: P.kwargs) -> list[MessageLikeType]:
        bound_args = self._signature.bind(*args, **kwargs)
        bound_args.apply_defaults()
        formatted_messages: list[MessageLikeType] = []
        for message_template in self._messages:
            if isinstance(message_template, str):
                formatted_messages.append(message_template.format(**bound_args.arguments))
            else:
                # Assuming message_template is iterable of tuples or Message
                formatted_message = []
                for item in message_template:
                    if isinstance(item, tuple):
                        formatted_message.append(
                            tuple(arg.format(**bound_args.arguments) if isinstance(arg, str) else arg for arg in item)
                        )
                    elif isinstance(item, dict):  # Message TypedDict
                        formatted_message.append(
                            {"role": item["role"], "content": item["content"].format(**bound_args.arguments)}
                        )
                formatted_messages.append(formatted_message)
        return formatted_messages


class LangchainOpenAIPromptFunction(PromptFunction):

    def __init__(
        self,
        name: str,
        parameters: Sequence[inspect.Parameter],
        return_type: type[R],
        messages: Sequence,
        llm: LangchainChatOpenAI | LangchainAzureChatOpenAI,
        callbacks: LangfuseCallbackHandler,
    ) -> None:
        self._callbacks = callbacks
        super().__init__(name, parameters, return_type, messages, llm)


class PromptDecorator(Protocol):

    def __call__(self, func: Callable[P, R]) -> PromptFunction[P, R]: ...


def env_error(env_var: str) -> Never:
    raise ValueError(f"{env_var} environment variable is not set.")


def is_openai_model(model: ModelType) -> bool:
    return isinstance(model, OpenAI) or isinstance(model, AzureOpenAI)


def is_langchain_model(model: ModelType) -> bool:
    return isinstance(model, LangchainChatOpenAI) or isinstance(model, LangchainAzureChatOpenAI)


def get_openai_model() -> OpenAI:
    api_key = os.getenv("OPENAI_API_KEY")
    if api_key is None:
        env_error("OPENAI_API_KEY")

    base_url = os.getenv("OPENAI_BASE_URL")
    if base_url is None:
        env_error("OPENAI_BASE_URL")

    return OpenAI(api_key=api_key, base_url=base_url)


def get_langchain_openai_model() -> LangchainChatOpenAI:
    api_key = os.getenv("AZURE_OPENAI_API_KEY")
    if api_key is None:
        env_error("AZURE_OPENAI_API_KEY")

    azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
    if azure_endpoint is None:
        env_error("AZURE_OPENAI_ENDPOINT")

    api_version = os.getenv("AZURE_OPENAI_API_VERSION")
    if api_version is None:
        env_error("AZURE_OPENAI_API_VERSION")

    model = os.getenv("AZURE_OPENAI_MODEL_NAME")
    if model is None:
        env_error("AZURE_OPENAI_MODEL_NAME")

    return LangchainChatOpenAI(
        api_key=api_key,
        azure_endpoint=azure_endpoint,
        api_version=api_version,
        model=model,
    )


def prompt(
    *messages: Sequence[MessageLikeType],
    model: ModelType | None = None,
    model_name: str | None = None,
    langfuse_config: Unpack[LangfuseCallbackConfig] | None = None,
    temperature: float = 0.0,
) -> PromptDecorator:
    model = model or get_openai_model()

    def decorator(func: Callable[P, R]) -> PromptFunction[P, R]:
        func_signature = inspect.signature(func)
        print(f"Function signature: {func_signature}")
        print(func_signature.parameters)
        print(f"Function return type: {func_signature.return_annotation}")

        if is_langchain_model(model):
            prompt_function = LangchainOpenAIPromptFunction(
                name=func.__name__,
                parameters=func_signature.parameters.values(),
                messages=messages,
                model=model,
                langfuse_config=langfuse_config,
            )
            return cast(LangchainOpenAIPromptFunction, update_wrapper(prompt_function, func))

        prompt_function = OpenAIPromptFunction(
            name=func.__name__,
            parameters=func_signature.parameters.values(),
            return_type=func_signature.return_annotation,
            messages=messages,
            model=model,
        )
        return cast(OpenAIPromptFunction, update_wrapper(prompt_function, func))

    return cast(PromptDecorator, decorator)

In [55]:
import os

from openai import OpenAI, AzureOpenAI
from langchain_openai.chat_models import ChatOpenAI as LangchainChatOpenAI
from langchain_openai.chat_models import AzureChatOpenAI as LangchainAzureChatOpenAI
from pydantic import BaseModel

openai_model = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
    base_url=os.getenv("OPENAI_BASE_URL"),
)

langchain_openai_model = LangchainAzureChatOpenAI(
    api_key=os.getenv("COMPANION_PROXY_API_KEY"),
    azure_endpoint=os.getenv("COMPANION_PROXY_ENDPOINT"),
    api_version=os.getenv("COMPANION_PROXY_API_VERSION"),
    model=os.getenv("COMPANION_PROXY_MODEL_NAME"),
    temperature=0,
)


class foo(BaseModel):
    bar: str


@prompt(
    "Translate the given text to Chinese.\n",
    "TEXT: ```{text}```",
    model=openai_model,
)
def translate_to_chinese(text: str) -> foo: ...


translate_to_chinese("Hello, how are you?")

Function signature: (text: str) -> __main__.foo
OrderedDict({'text': <Parameter "text: str">})
Function return type: <class '__main__.foo'>
self._messages = ('Translate the given text to Chinese.\n', 'TEXT: ```{text}```')
args = ('Hello, how are you?',)
kwargs = {}
Bound arguments: {'text': 'Hello, how are you?'}


NameError: name 'bound_args' is not defined

In [37]:
import os
from dotenv import load_dotenv

load_dotenv()

True

## OpenAI Example

In [35]:
import os

from openai import OpenAI, AzureOpenAI
from langchain_openai.chat_models import ChatOpenAI as LangchainChatOpenAI
from langchain_openai.chat_models import AzureChatOpenAI as LangchainAzureChatOpenAI
from pydantic import BaseModel

openai_model = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
    base_url=os.getenv("OPENAI_BASE_URL"),
)

langchain_openai_model = LangchainAzureChatOpenAI(
    api_key=os.getenv("COMPANION_PROXY_API_KEY"),
    azure_endpoint=os.getenv("COMPANION_PROXY_ENDPOINT"),
    api_version=os.getenv("COMPANION_PROXY_API_VERSION"),
    model=os.getenv("COMPANION_PROXY_MODEL_NAME"),
    temperature=0,
)


class foo(BaseModel):
    bar: str


@prompt(
    "Translate the given text to Chinese.",
    "TEXT: ```{text}```",
    model=openai_model,
)
def translate_to_chinese(text: str) -> foo: ...


translate_to_chinese("Hello, how are you?")

Function signature: (text: str) -> __main__.foo
Function parameters: odict_values([<Parameter "text: str">])
Function return type: <class '__main__.foo'>
Calling translate_to_chinese with `('Hello, how are you?',)` and `{}`


AttributeError: 'OpenAIPromptFunction' object has no attribute 'model'

## LangChain Example