Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion test/stdlib/test_spans.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
from mellea.stdlib.session import MelleaSession, start_session

# Module-level markers for all tests using Granite 4 hybrid micro (3B model)
pytestmark = [pytest.mark.huggingface, pytest.mark.requires_gpu, pytest.mark.llm]
pytestmark = [
pytest.mark.huggingface,
pytest.mark.requires_gpu,
pytest.mark.requires_heavy_ram,
pytest.mark.llm,
]


# We edit the context type in the async tests below. Don't change the scope here.
Expand Down
42 changes: 30 additions & 12 deletions test/telemetry/test_metrics_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,33 @@ def enable_metrics(monkeypatch):
importlib.reload(mellea.telemetry.metrics)


@pytest.fixture(scope="module")
def hf_metrics_backend(gh_run):
    """Module-scoped HuggingFace backend shared by the telemetry metrics tests.

    Loading the model is expensive, so one backend instance is created per
    module and reused by every test in it — this avoids memory exhaustion
    from instantiating multiple copies of the model.

    Yields:
        A ``LocalHFBackend`` wrapping the Granite 4 hybrid micro model with a
        small LRU generation cache.
    """
    # Model downloads are unavailable in CI, so bail out before construction.
    if gh_run:
        pytest.skip("Skipping HuggingFace backend creation in CI")

    from mellea.backends.cache import SimpleLRUCache
    from mellea.backends.huggingface import LocalHFBackend

    hf_backend = LocalHFBackend(
        model_id=IBM_GRANITE_4_HYBRID_MICRO.hf_model_name,  # type: ignore
        cache=SimpleLRUCache(5),
    )

    yield hf_backend

    # Teardown: drop our reference and force a collection so the model's
    # memory can be reclaimed promptly at module end.
    import gc

    del hf_backend
    gc.collect()


def get_metric_value(metrics_data, metric_name, attributes=None):
"""Helper to extract metric value from metrics data.

Expand Down Expand Up @@ -297,15 +324,12 @@ async def test_litellm_token_metrics_integration(
@pytest.mark.asyncio
@pytest.mark.llm
@pytest.mark.huggingface
@pytest.mark.requires_heavy_ram
@pytest.mark.parametrize("stream", [False, True], ids=["non-streaming", "streaming"])
async def test_huggingface_token_metrics_integration(
enable_metrics, metric_reader, stream, gh_run
enable_metrics, metric_reader, stream, hf_metrics_backend
):
"""Test that HuggingFace backend records token metrics correctly."""
if gh_run:
pytest.skip("Skipping in CI - requires model download")

from mellea.backends.huggingface import LocalHFBackend
from mellea.backends.model_options import ModelOption
from mellea.telemetry import metrics as metrics_module

Expand All @@ -315,17 +339,11 @@ async def test_huggingface_token_metrics_integration(
metrics_module._input_token_counter = None
metrics_module._output_token_counter = None

from mellea.backends.cache import SimpleLRUCache

backend = LocalHFBackend(
model_id=IBM_GRANITE_4_HYBRID_MICRO.hf_model_name, # type: ignore
cache=SimpleLRUCache(5),
)
ctx = SimpleContext()
ctx = ctx.add(Message(role="user", content="Say 'hello' and nothing else"))

model_options = {ModelOption.STREAM: True} if stream else {}
mot, _ = await backend.generate_from_context(
mot, _ = await hf_metrics_backend.generate_from_context(
Message(role="assistant", content=""), ctx, model_options=model_options
)

Expand Down
Loading