## Rate limiting

This sample notebook showcases fnllm's client-side rate limiting: requests per minute, tokens per minute, and maximum concurrency.

Required imports:

In [1]:
# Copyright (c) 2024 Microsoft Corporation.

import asyncio
import random
import string
from collections.abc import Iterable
from threading import Thread
from unittest.mock import AsyncMock

import tiktoken
from fnllm import LLMUsageTracker
from fnllm.openai import AzureOpenAIConfig, create_openai_chat_llm
from fnllm.openai.types import (
    OpenAIChatCompletionMessageModel,
    OpenAIChatCompletionMessageParam,
    OpenAIChatCompletionModel,
    OpenAIChoiceModel,
    OpenAICompletionUsageModel,
)

To showcase rate limiting without calling a real service, we are **mocking the LLM response**: the mocked client always **echoes back the content of the request as the response**.

In [2]:
def _create_echo_response(sleep_range: tuple[int, int]):
    """Build an async stand-in for ``client.chat.completions.create`` that echoes the prompt.

    The returned coroutine function sleeps for a random whole number of seconds
    (``random.randrange(*sleep_range)``, so the upper bound is exclusive) to
    simulate service latency. It then returns a chat completion whose message
    content is the content of the last request message.

    Args:
        sleep_range: ``(start, stop)`` forwarded to ``random.randrange``; the
            simulated latency in seconds lies in ``[start, stop)``.

    Returns:
        An async function taking ``messages`` plus arbitrary keyword arguments
        (``model`` is read if present) and returning an
        ``OpenAIChatCompletionModel``.
    """
    # The tokenizer does not depend on the request, so resolve it once here
    # instead of on every mocked call.
    encoding = tiktoken.get_encoding("cl100k_base")

    async def _echo_response_wrapper(
        messages: Iterable[OpenAIChatCompletionMessageParam], **kwargs
    ) -> OpenAIChatCompletionModel:
        message_list = list(messages)
        # Use the "No result" placeholder instead of raising IndexError when
        # called with no messages at all.
        last_message = message_list[-1] if message_list else {}
        content = str(last_message.get("content", "No result"))
        n_tokens = len(encoding.encode(content))

        await asyncio.sleep(random.randrange(*sleep_range))  # noqa: S311

        return OpenAIChatCompletionModel(
            id="completion_id",
            choices=[
                OpenAIChoiceModel(
                    finish_reason="stop",
                    index=0,
                    message=OpenAIChatCompletionMessageModel(
                        content=content, role="assistant"
                    ),
                )
            ],
            created=0,
            model=kwargs.get("model", ""),
            object="chat.completion",
            # Prompt and completion are the same text, so both count n_tokens.
            usage=OpenAICompletionUsageModel(
                completion_tokens=n_tokens,
                prompt_tokens=n_tokens,
                total_tokens=2 * n_tokens,
            ),
        )

    return _echo_response_wrapper


def _mock_echo_client(sleep_range: tuple[int, int]):
    """Return an ``AsyncMock`` OpenAI client whose chat completions echo the prompt.

    Only ``chat.completions.create`` is replaced, with the echo coroutine built
    by ``_create_echo_response``. Every other attribute keeps the default
    ``AsyncMock`` behaviour.

    Args:
        sleep_range: Forwarded to ``_create_echo_response`` to control the
            simulated latency.
    """
    fake_client = AsyncMock()
    completions = fake_client.chat.completions
    completions.create = _create_echo_response(sleep_range)
    return fake_client

Define a printer that periodically prints the `LLMUsageTracker` statistics from a background thread, plus a helper that generates random input text:

In [3]:
class _UsagePrinter:
    def __init__(self, update_interval: int, tracker: LLMUsageTracker):
        self._update_interval = update_interval
        self._tracker = tracker
        self._run_timer = False
        self._thread: Thread | None = None

    async def run(self):
        while self._run_timer:
            await asyncio.sleep(self._update_interval)
            print(
                f"rpm={await self._tracker.current_rpm():.2f}, tpm={await self._tracker.current_tpm():.2f}, avg_rpm={await self._tracker.avg_rpm():.2f}, avg_tpm={await self._tracker.avg_tpm():.2f}, requests={self._tracker.total_requests:.2f}, usage={self._tracker.total_usage.total_tokens:.2f}, concurrency={self._tracker.current_concurrency:.2f}, max_concurrency={self._tracker.max_concurrency:.2f}"
            )

    def start(self):
        if not self._thread:
            self._run_timer = True
            self._thread = Thread(target=asyncio.run, args=(self.run(),))
            self._thread.start()

    def stop(self):
        if self._thread:
            self._run_timer = False
            self._thread.join()


def _random_str(size: int) -> str:
    letters = string.ascii_lowercase + " "
    return "".join(random.choice(letters) for _ in range(size))  # noqa: S311

Parameter definitions:

In [4]:
# input parameters
N_INPUTS = 100
INPUT_SIZE_RANGE = (1000, 5000)

# response parameters
SLEEP_RANGE = (0, 5)

# tracking parameters
UPDATE_INTERVAL = 5

# limiting parameters
RPM = 30
TPM = 80000
MAX_CONCURRENCY = 25
BURST_MODE = True

Create the inputs, the LLM, and the printer thread from the configured parameters:

In [5]:
# Prompt texts whose lengths are drawn from INPUT_SIZE_RANGE (in characters).
inputs = [_random_str(random.randint(*INPUT_SIZE_RANGE)) for _ in range(N_INPUTS)]  # noqa: S311

# Passed to the LLM as `events`; the printer below reads its counters.
tracker = LLMUsageTracker.create()

# Endpoint, API version and model are left empty because requests go to the
# mocked echo client passed as `client`, not to a real Azure OpenAI service.
llm = create_openai_chat_llm(
    config=AzureOpenAIConfig(
        endpoint="",
        api_version="",
        model="",
        requests_per_minute=RPM,
        tokens_per_minute=TPM,
        max_concurrency=MAX_CONCURRENCY,
        requests_burst_mode=BURST_MODE,
    ),
    client=_mock_echo_client(SLEEP_RANGE),
    events=tracker,
)

printer = _UsagePrinter(UPDATE_INTERVAL, tracker)

# Start reporting in the background before any LLM call is issued.
printer.start()

Call the LLM for all inputs concurrently; the printed lines below show the usage reported while the limiter throttles the requests:

In [6]:
result = await asyncio.gather(*(llm(entry) for entry in inputs))

rpm=32.00, tpm=85581.00, avg_rpm=32.00, avg_tpm=85581.00, requests=30.00, usage=88474.00, concurrency=2.00, max_concurrency=25.00
rpm=34.00, tpm=92126.00, avg_rpm=34.00, avg_tpm=92126.00, requests=32.00, usage=95276.00, concurrency=2.00, max_concurrency=25.00
rpm=36.00, tpm=98839.00, avg_rpm=36.00, avg_tpm=98839.00, requests=36.00, usage=105854.00, concurrency=0.00, max_concurrency=25.00
rpm=39.00, tpm=104470.00, avg_rpm=39.00, avg_tpm=104470.00, requests=37.00, usage=109274.00, concurrency=2.00, max_concurrency=25.00
rpm=41.00, tpm=111343.00, avg_rpm=41.00, avg_tpm=111343.00, requests=40.00, usage=114770.00, concurrency=1.00, max_concurrency=25.00
rpm=43.00, tpm=118312.00, avg_rpm=43.00, avg_tpm=118312.00, requests=42.00, usage=121306.00, concurrency=1.00, max_concurrency=25.00
rpm=45.00, tpm=126318.00, avg_rpm=45.00, avg_tpm=126318.00, requests=45.00, usage=128516.00, concurrency=0.00, max_concurrency=25.00
rpm=47.00, tpm=132683.00, avg_rpm=47.00, avg_tpm=132683.00, requests=45.00, u

In [7]:
# Stop the background usage printer and wait for its thread to exit. This can take
# up to UPDATE_INTERVAL seconds, because the printer only checks its stop flag after
# each sleep. Calling it on an already-stopped printer returns immediately.
printer.stop()