# Async rate limiting in Python: the waterwheel algorithm

What I'm calling the waterwheel algorithm maximizes the number of calls that can be made against a rate-limited resource. Like a waterwheel that turns and dumps out all the water in each bucket, the waterwheel algorithm bursts function calls at the start of a period. The waterwheel is based on the [token bucket algorithm](https://en.wikipedia.org/wiki/Token_bucket), with a small modification.

This post describes the waterwheel algorithm and a simple implementation in Python that can be easily ported to most other languages.

## The challenge

Most services that expose an API implement a rate limit to ensure that some users don't overwhelm the service. Most hardware is fast enough these days to easily exceed those rate limits. But network IO can cause bumpy delays resulting in API calls being slower than the set rate limit. 

A solution in most programming languages would be to call the API asynchronously, using the network IO delay to make more calls so that the total time is about equal to the time to receive the slowest call (rather than the sum of times for all calls). But asynchronous calls will quickly overwhelm just about every API. So consumers need to implement their own rate limiting algorithm that matches the API's.
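As a rough illustration of that claim, the sketch below fires three simulated calls concurrently. `fake_api_call` and its delays are made up for the example and stand in for real network requests. It is written as a script, so it uses `asyncio.run`; in a notebook you would `await main()` instead.

```python
import asyncio
import time


async def fake_api_call(delay):
    # Hypothetical stand-in for a network request that takes `delay` seconds
    await asyncio.sleep(delay)
    return delay


async def main():
    delays = [0.1, 0.2, 0.3]
    start = time.monotonic()
    # All three calls wait on "the network" at the same time
    await asyncio.gather(*(fake_api_call(d) for d in delays))
    elapsed = time.monotonic() - start
    return elapsed, sum(delays), max(delays)


elapsed, sequential_total, slowest = asyncio.run(main())
print(f"elapsed: {elapsed:.2f}s, slowest call: {slowest:.2f}s, "
      f"sum of calls: {sequential_total:.2f}s")
```

The elapsed time should land close to the slowest call rather than the sum of all three.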

When I searched for a solution I found approaches that either [used semaphores](https://rednafi.github.io/reflections/limit-concurrency-with-semaphore-in-python-asyncio.html) to limit the number of concurrent calls or [implemented a leaky bucket algorithm](https://stackoverflow.com/questions/48682147/aiohttp-rate-limiting-parallel-requests) to only allow an average number of requests per period. The former doesn't allow you to set a rate per period and the latter is slower than it needs to be.
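To make the first point concrete, here is a minimal sketch of a plain semaphore guarding a call that returns almost instantly. The names `instant_call` and `main` are hypothetical. The semaphore only caps how many calls are in flight at once, so when each call is fast all 15 go out within a few milliseconds and no per-period rate is enforced.

```python
import asyncio
import time


async def instant_call(i):
    # Hypothetical API call that returns almost immediately
    await asyncio.sleep(0)
    return i


async def main():
    sem = asyncio.Semaphore(5)  # at most 5 calls in flight at once
    start = time.monotonic()

    async def caller(i):
        async with sem:
            await instant_call(i)
            # Record how long after the start this call finished
            return time.monotonic() - start

    return await asyncio.gather(*(caller(i) for i in range(15)))


offsets = asyncio.run(main())
print(f"last call finished after {max(offsets):.4f}s")
```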

## The Waterwheel algorithm

The Waterwheel is similar to the token bucket in that it draws a fixed number of "tokens" (or workers) from a "bucket". To do work, a function must remove a token; if no tokens remain, the function must wait (or an exception is raised, depending on context). The modification is in how tokens come back: instead of adding a token to the bucket every $1/r$ seconds, where $r$ is the allowed rate in calls per second, a token is returned to the bucket when the function holding it completes, plus a delay specified ahead of time.

This simple modification ensures that only a set number of functions can run at a time and that the next function in line can only run after a predetermined delay. 
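To see why this caps the rate, here is a rough sketch of the timing. The symbols are introduced just for this sketch: $n$ is the number of tokens, $p$ is the delay, $t_k$ is the time a given token starts its $k$-th call, and $d_k \ge 0$ is how long that call takes.

$$
t_{k+1} \ge t_k + d_k + p \ge t_k + p
$$

Each token therefore starts at most one call in any window of length $p$, so $n$ tokens start at most $n$ calls per window. When calls are instant ($d_k = 0$) the bound is met exactly, which is the burst at the start of each period.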

Next I'll show you how this is implemented, with some examples.

## Implementation

The algorithm is implemented using a Semaphore to represent tokens with a delay before the semaphore is released.

In [67]:
import asyncio
from contextlib import asynccontextmanager, contextmanager
import multiprocessing
import threading
from typing import Callable

import numpy as np
import pandas as pd
import pendulum
import plotly.express as px
from pprint import pprint

In [54]:
# First, create a function to represent the API
async def client(i):
    return {'call': i, 'time': pendulum.now().time()}

# Instantiate the semaphore with the number of concurrent calls that are allowed
sem = asyncio.Semaphore(5)
period = 1

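# create a function to call the client
async def caller(i):
    # Acquire outside the try block so a failed or cancelled acquire
    # never triggers a release of a token that was not taken
    await sem.acquire()
    try:
        response = await client(i)
    finally:
        # Hold the token for one more period before returning it
        await asyncio.sleep(period)
        sem.release()
    return response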

tasks = [caller(i) for i in range(15)]
results = await asyncio.gather(*tasks)    


In [None]:
px.scatter(
    x=[float(r['time'].format('s.SSSS')) for r in results],
    y=[r['call'] for r in results],
    labels={
        'x': 'Call time (s)',
        'y': 'Call number'
    }
    )

### The pythonic way

Python allows you to package that code into a context manager that automatically handles acquiring and releasing the token. It can be simplified as follows:

In [58]:
@asynccontextmanager
async def aioburst(semaphore: asyncio.Semaphore, period: int | float):
    async with semaphore:
        try:
            yield
        finally:
            await asyncio.sleep(period)

This may not seem much simpler until you realize that you can use the whole function as a context manager.

In [66]:
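from contextlib import AbstractAsyncContextManager

# create a function to call the client
async def caller(i, limiter: AbstractAsyncContextManager):
    # `limiter` is the async context manager returned by aioburst(...)
    async with limiter:
        response = await client(i)
    return response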


tasks = [caller(i, aioburst(sem, 1)) for i in range(15)]
pprint(await asyncio.gather(*tasks))


[{'call': 0, 'time': Time(0, 44, 35, 94808)},
 {'call': 1, 'time': Time(0, 44, 35, 94886)},
 {'call': 2, 'time': Time(0, 44, 35, 94919)},
 {'call': 3, 'time': Time(0, 44, 35, 94947)},
 {'call': 4, 'time': Time(0, 44, 35, 94973)},
 {'call': 5, 'time': Time(0, 44, 36, 96409)},
 {'call': 6, 'time': Time(0, 44, 36, 96488)},
 {'call': 7, 'time': Time(0, 44, 36, 96513)},
 {'call': 8, 'time': Time(0, 44, 36, 96531)},
 {'call': 9, 'time': Time(0, 44, 36, 96548)},
 {'call': 10, 'time': Time(0, 44, 37, 97943)},
 {'call': 11, 'time': Time(0, 44, 37, 98032)},
 {'call': 12, 'time': Time(0, 44, 37, 98064)},
 {'call': 13, 'time': Time(0, 44, 37, 98087)},
 {'call': 14, 'time': Time(0, 44, 37, 98108)}]
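The output above shows three bursts of five calls, about one second apart. The sketch below is a self-contained script version of the same idea that checks the spacing with `time.monotonic`. The `run_demo` helper and the shorter 0.2 s period are my own choices to keep the run quick; in a notebook you would `await run_demo()` instead of calling `asyncio.run`.

```python
import asyncio
import time
from contextlib import asynccontextmanager


@asynccontextmanager
async def aioburst(semaphore, period):
    # Hold the token for the call, then for `period` more seconds
    async with semaphore:
        try:
            yield
        finally:
            await asyncio.sleep(period)


async def run_demo(n_calls=15, tokens=5, period=0.2):
    sem = asyncio.Semaphore(tokens)
    start = time.monotonic()

    async def caller(i):
        async with aioburst(sem, period):
            # Instant "API call": just record when it ran
            return time.monotonic() - start

    return await asyncio.gather(*(caller(i) for i in range(n_calls)))


offsets = sorted(asyncio.run(run_demo()))
# With 5 tokens, the 6th-next call should be at least one period later
gaps = [offsets[i + 5] - offsets[i] for i in range(len(offsets) - 5)]
print(f"smallest gap across 5 calls: {min(gaps):.3f}s")
```

If the limiter behaves as described, every gap should be at least roughly one period.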


## Taking it into the real world

In the real world we have network IO, meaning that calls don't return instantly. The waterwheel algorithm was built to handle those situations as well.

Next we'll reimplement our client with a random delay.

In [68]:
async def client(i):
    # Random delay of up to 2 seconds
    delay = np.random.rand() * 2
    await asyncio.sleep(delay)
    return {'call': i, 'time': pendulum.now().time()}

In [78]:
tasks = [caller(i, aioburst(sem, 1)) for i in range(15)]
results = await asyncio.gather(*tasks)


In [79]:
seconds = [
    {
        'time': float(r['time'].format('s.SSSS')),
        'call': r['call']
        } for r in results]
res_df = pd.DataFrame(seconds)

In [82]:
px.scatter(
    res_df,
    x='time',
    y='call',
    # When plotting from a DataFrame, label keys are the column names
    labels={
        'time': 'Call time (s)',
        'call': 'Call number'
    }
    )

In [90]:
res_df.set_index('time').rolling(1, min_periods=0).count()

| time | call |
|---|---|
| 52.4393 | 1.0 |
| 53.2708 | 1.0 |
| 53.6611 | 1.0 |
| 52.3119 | 1.0 |
| 53.1203 | 1.0 |
| 54.1512 | 1.0 |
| 54.2761 | 1.0 |
| 54.1535 | 1.0 |
| 55.0402 | 1.0 |
| 56.6595 | 1.0 |


In [95]:
ts_df = pd.DataFrame([{
    'time': str(r['time']),
    'call': r['call']
    }
    for r in results])
ts_df['time'] = pd.to_datetime(ts_df.time)

In [99]:
ts_df.set_index('time').rolling('1s').count()

ValueError: index values must be monotonic
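
The `ValueError` most likely comes from the fact that a time-based rolling window (`'1s'`) needs a sorted index, while `results` is in call order rather than time order. Sorting first (for example `ts_df.set_index('time').sort_index()`) should satisfy that requirement. As a dependency-free cross-check, the sketch below counts the most calls that fall in any one-second window. `max_calls_in_window` is a hypothetical helper, the sample values are the ten times shown in the table above, and it assumes the times do not wrap around a minute boundary.

```python
from bisect import bisect_left


def max_calls_in_window(times, window=1.0):
    """Return the largest number of calls in any half-open window [t, t + window)."""
    ordered = sorted(times)  # the check only makes sense on time-ordered data
    best = 0
    for i, start in enumerate(ordered):
        # Index of the first call at or after the end of the window
        end = bisect_left(ordered, start + window)
        best = max(best, end - i)
    return best


# Seconds-within-the-minute values copied from the table above
sample_times = [
    52.4393, 53.2708, 53.6611, 52.3119, 53.1203,
    54.1512, 54.2761, 54.1535, 55.0402, 56.6595,
]
busiest = max_calls_in_window(sample_times, window=1.0)
print(f"most calls in any 1 s window: {busiest}")
```

For these ten values the count should stay at or below the limit of five tokens.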