Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions src/bub/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
"""Bub framework package."""

from __future__ import annotations

import os
from importlib import import_module
from importlib.metadata import PackageNotFoundError
from importlib.metadata import version as metadata_version
from pathlib import Path
from typing import TYPE_CHECKING

from bub.framework import BubFramework
from bub.configure import Settings, config, ensure_config
from bub.framework import DEFAULT_HOME, BubFramework
from bub.hookspecs import hookimpl
from bub.tools import tool

__all__ = ["BubFramework", "hookimpl", "tool"]
__all__ = ["BubFramework", "Settings", "config", "ensure_config", "home", "hookimpl", "tool"]

try:
__version__ = import_module("bub._version").version
Expand All @@ -17,3 +23,15 @@
__version__ = metadata_version("bub")
except PackageNotFoundError:
__version__ = "0.0.0"


if TYPE_CHECKING:
home: Path


def __getattr__(name: str):
if name == "home":
if "BUB_HOME" in os.environ:
return Path(os.environ["BUB_HOME"])
return DEFAULT_HOME
raise AttributeError(f"module {__name__} has no attribute {name}")
4 changes: 3 additions & 1 deletion src/bub/builtin/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,14 @@ def __init__(self, framework: BubFramework) -> None:

@cached_property
def tapes(self) -> TapeService:
import bub

tape_store = self.framework.get_tape_store()
if tape_store is None:
tape_store = InMemoryTapeStore()
tape_store = ForkTapeStore(tape_store)
llm = _build_llm(self.settings, tape_store, self.framework.build_tape_context())
return TapeService(llm, self.settings.home / "tapes", tape_store)
return TapeService(llm, bub.home / "tapes", tape_store)

@staticmethod
def _events_from_iterable(iterable: Iterable) -> AsyncStreamEvents:
Expand Down
5 changes: 2 additions & 3 deletions src/bub/builtin/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,9 @@ def _find_uv() -> str:

@lru_cache(maxsize=1)
def _default_project() -> Path:
from .settings import load_settings
import bub

settings = load_settings()
project = settings.home / "bub-project"
project = bub.home / "bub-project"
project.mkdir(exist_ok=True, parents=True)
return project

Expand Down
18 changes: 12 additions & 6 deletions src/bub/builtin/hook_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,12 @@ def __init__(self, framework: BubFramework) -> None:
from bub.builtin import tools # noqa: F401

self.framework = framework
self.agent = Agent(framework)
self._agent: Agent | None = None

def _get_agent(self) -> Agent:
if self._agent is None:
self._agent = Agent(self.framework)
return self._agent

@hookimpl
def resolve_session(self, message: ChannelMessage) -> str:
Expand All @@ -64,7 +69,7 @@ async def load_state(self, message: ChannelMessage, session_id: str) -> State:
lifespan = field_of(message, "lifespan")
if lifespan is not None:
await lifespan.__aenter__()
state = {"session_id": session_id, "_runtime_agent": self.agent}
state = {"session_id": session_id, "_runtime_agent": self._get_agent()}
if context := field_of(message, "context_str"):
state["context"] = context
return state
Expand Down Expand Up @@ -107,11 +112,11 @@ async def build_prompt(self, message: ChannelMessage, session_id: str, state: St

@hookimpl
async def run_model(self, prompt: str | list[dict], session_id: str, state: State) -> str:
return await self.agent.run(session_id=session_id, prompt=prompt, state=state)
return await self._get_agent().run(session_id=session_id, prompt=prompt, state=state)

@hookimpl
async def run_model_stream(self, prompt: str | list[dict], session_id: str, state: State) -> AsyncStreamEvents:
return await self.agent.run_stream(session_id=session_id, prompt=prompt, state=state)
return await self._get_agent().run_stream(session_id=session_id, prompt=prompt, state=state)

@hookimpl
def register_cli_commands(self, app: typer.Typer) -> None:
Expand Down Expand Up @@ -148,7 +153,7 @@ def provide_channels(self, message_handler: MessageHandler) -> list[Channel]:

return [
TelegramChannel(on_receive=message_handler),
CliChannel(on_receive=message_handler, agent=self.agent),
CliChannel(on_receive=message_handler, agent=self._get_agent()),
]

@hookimpl
Expand Down Expand Up @@ -191,9 +196,10 @@ def render_outbound(

@hookimpl
def provide_tape_store(self) -> TapeStore:
import bub
from bub.builtin.store import FileTapeStore

return FileTapeStore(directory=self.agent.settings.home / "tapes")
return FileTapeStore(directory=bub.home / "tapes")

@hookimpl
def build_tape_context(self) -> TapeContext:
Expand Down
42 changes: 18 additions & 24 deletions src/bub/builtin/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@
import pathlib
import re
from collections.abc import Callable
from functools import lru_cache
from typing import Any, Literal

from pydantic import Field
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict, YamlConfigSettingsSource
from pydantic_settings import SettingsConfigDict

from bub import Settings, config, ensure_config

DEFAULT_MODEL = "openrouter:qwen/qwen3-coder-next"
DEFAULT_MAX_TOKENS = 1024
DEFAULT_HOME = pathlib.Path.home() / ".bub"
DEFAULT_CONFIG_FILE = DEFAULT_HOME / "config.yml"


def provider_specific(setting_name: str) -> Callable[[], dict[str, str] | None]:
Expand All @@ -32,11 +31,11 @@ def default_factory() -> dict[str, str] | None:
return default_factory


class AgentSettings(BaseSettings):
@config()
class AgentSettings(Settings):
"""Configuration settings for the Agent."""

model_config = SettingsConfigDict(env_prefix="BUB_", env_parse_none_str="null", extra="ignore")
home: pathlib.Path = Field(default=DEFAULT_HOME)
model: str = DEFAULT_MODEL
fallback_models: list[str] | None = None
api_key: str | dict[str, str] | None = Field(default_factory=provider_specific("api_key"))
Expand All @@ -48,25 +47,20 @@ class AgentSettings(BaseSettings):
client_args: dict[str, Any] | None = None
verbose: int = Field(default=0, description="Verbosity level for logging. Higher means more verbose.", ge=0, le=2)

@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
home = os.getenv("BUB_HOME", str(DEFAULT_HOME))
return (
init_settings,
env_settings,
dotenv_settings,
YamlConfigSettingsSource(settings_cls, yaml_file=pathlib.Path(home) / "config.yml"),
file_secret_settings,
@property
def home(self) -> pathlib.Path:
import warnings

import bub

warnings.warn(
"Using the 'home' property from AgentSettings is deprecated. Please use 'bub.home' instead.",
DeprecationWarning,
stacklevel=2,
)

return bub.home


@lru_cache(maxsize=1)
def load_settings() -> AgentSettings:
return AgentSettings()
return ensure_config(AgentSettings)
3 changes: 2 additions & 1 deletion src/bub/channels/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from rich import get_console
from rich.live import Live

import bub
from bub.builtin.agent import Agent
from bub.builtin.tape import TapeInfo
from bub.channels.base import Channel
Expand Down Expand Up @@ -160,7 +161,7 @@ def _tool_sort_key(tool_name: str) -> tuple[str, str]:
section, _, name = tool_name.rpartition(".")
return (section, name)

history_file = self._history_file(self._agent.settings.home, workspace)
history_file = self._history_file(bub.home, workspace)
history_file.parent.mkdir(parents=True, exist_ok=True)
history = FileHistory(str(history_file))
tool_names = sorted((f",{name}" for name in REGISTRY), key=_tool_sort_key)
Expand Down
9 changes: 6 additions & 3 deletions src/bub/channels/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,22 @@

from loguru import logger
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic_settings import SettingsConfigDict
from republic import StreamEvent

from bub import config
from bub.channels.base import Channel
from bub.channels.handler import BufferedMessageHandler
from bub.channels.message import ChannelMessage
from bub.configure import Settings, ensure_config
from bub.envelope import content_of, field_of
from bub.framework import BubFramework
from bub.types import Envelope, MessageHandler
from bub.utils import wait_until_stopped


class ChannelSettings(BaseSettings):
@config()
class ChannelSettings(Settings):
model_config = SettingsConfigDict(env_prefix="BUB_", extra="ignore", env_file=".env")

enabled_channels: str = Field(
Expand Down Expand Up @@ -47,7 +50,7 @@ def __init__(
) -> None:
self.framework = framework
self._channels: dict[str, Channel] = self.framework.get_channels(self.on_receive)
self._settings = ChannelSettings()
self._settings = ensure_config(ChannelSettings)
self._stream_output = stream_output if stream_output is not None else self._settings.stream_output
if enabled_channels is not None:
self._enabled_channels = list(enabled_channels)
Expand Down
9 changes: 6 additions & 3 deletions src/bub/channels/telegram.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,22 @@

from loguru import logger
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic_settings import SettingsConfigDict
from telegram import Bot, Message, Update
from telegram.ext import Application, CommandHandler, ContextTypes, filters
from telegram.ext import MessageHandler as TelegramMessageHandler
from telegram.request import HTTPXRequest

from bub import config
from bub.channels.base import Channel
from bub.channels.message import ChannelMessage, MediaItem, MediaType
from bub.configure import Settings, ensure_config
from bub.types import MessageHandler
from bub.utils import exclude_none


class TelegramSettings(BaseSettings):
@config(name="telegram")
class TelegramSettings(Settings):
model_config = SettingsConfigDict(env_prefix="BUB_TELEGRAM_", extra="ignore", env_file=".env")

token: str = Field(default="", description="Telegram bot token.")
Expand Down Expand Up @@ -148,7 +151,7 @@ class TelegramChannel(Channel):

def __init__(self, on_receive: MessageHandler) -> None:
self._on_receive = on_receive
self._settings = TelegramSettings()
self._settings = ensure_config(TelegramSettings)
self._allow_users = {uid.strip() for uid in (self._settings.allow_users or "").split(",") if uid.strip()}
self._allow_chats = {cid.strip() for cid in (self._settings.allow_chats or "").split(",") if cid.strip()}
self._parser = TelegramMessageParser(bot_getter=lambda: self._app.bot)
Expand Down
72 changes: 72 additions & 0 deletions src/bub/configure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from collections.abc import Callable
from pathlib import Path
from typing import Any

from pydantic_settings import BaseSettings, PydanticBaseSettingsSource

CONFIG_MAP: dict[str, list[type[BaseSettings]]] = {}
ROOT = ""

_global_config: dict[str, list[BaseSettings]] | None = None


class Settings(BaseSettings):
@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
del settings_cls # unused
return (env_settings, dotenv_settings, init_settings, file_secret_settings)


def config[C: type[BaseSettings]](name: str = ROOT) -> Callable[[C], C]:
"""Decorator to register a config class for a plugin."""

def decorator(cls: C) -> C:
if name not in CONFIG_MAP:
CONFIG_MAP[name] = []
CONFIG_MAP[name].append(cls)
return cls

return decorator


def load(config_file: Path) -> dict[str, list[BaseSettings]]:
"""Load config from a file."""
import yaml

global _global_config
if _global_config is not None:
return _global_config

this_data: dict[str, list[BaseSettings]] = {}

config_data: dict[str, Any] = {}
if config_file.exists():
with config_file.open() as f:
config_data = yaml.safe_load(f) or {}

for name, config_classes in CONFIG_MAP.items():
section_data = config_data if name == ROOT else config_data.get(name, {})
for config_cls in config_classes:
config_instance = config_cls.model_validate(section_data)
this_data.setdefault(name, []).append(config_instance)

_global_config = this_data
return _global_config


def ensure_config[C: BaseSettings](config_cls: type[C]) -> C:
"""No-op function to ensure a config class is registered and can be imported."""
if _global_config is None:
raise RuntimeError("Config not loaded yet")
for config_list in _global_config.values():
for config in config_list:
if isinstance(config, config_cls):
return config
raise ValueError(f"Config class {config_cls} not found in loaded config")
Loading
Loading