# Performance optimizations

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import time
import asyncio
from arsenal.timer import timeit

## Autobatching concurrent requests


In [3]:
from genlm_control.potential import Potential


class TimedPotential(Potential):
    """Toy potential with artificial latency, for timing batched vs. per-context calls.

    Every method returns the length of the context(s) it receives. The latency
    comes from the blocking ``time.sleep``, so awaiting several per-context
    calls concurrently still pays one full delay per call.
    """

    async def complete(self, context):
        """Sleep for 1 second, then return ``len(context)``."""
        time.sleep(1)
        n_items = len(context)
        return n_items

    async def prefix(self, context):
        """Sleep for 1 second, then return ``len(context)``."""
        time.sleep(1)
        n_items = len(context)
        return n_items

    # The batched methods sleep once (1.05 s) for the whole batch, so they are
    # much quicker than calling the per-context methods one after another.

    async def batch_complete(self, contexts):
        """Sleep for 1.05 seconds, then return the length of each context, in order."""
        time.sleep(1.05)
        return list(map(len, contexts))

    async def batch_prefix(self, contexts):
        """Sleep for 1.05 seconds, then return the length of each context, in order."""
        time.sleep(1.05)
        return list(map(len, contexts))


# Build the demo potential from the integers 0..255
# (presumably its byte-valued vocabulary -- confirm against Potential's constructor).
potential = TimedPotential(list(range(256)))

  from .autonotebook import tqdm as notebook_tqdm
2025-02-04 12:53:50,537	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [4]:
# Wrap the potential in its autobatched form; the bare name on the last line
# displays the wrapper's repr as the cell output.
autobatched = potential.to_autobatched()
autobatched

AutoBatchedPotential(<__main__.TimedPotential object at 0x7fe79eec3390>)

In [None]:
sequences = [b"hello", b"cats", b"foo", b"fy"]

# Concurrent requests to complete will be automatically batched
# and processed by the batch_complete method.

with timeit("without autobatching"):
    results = await asyncio.gather(*(potential.complete(seq) for seq in sequences))

with timeit("with autobatching"):
    results_autobatched = await asyncio.gather(
        *(autobatched.complete(seq) for seq in sequences)
    )

without autobatching (4.0010 sec)
with autobatching (1.0503 sec)


In [6]:
# Autobatching should change only the timing, not the values: both lists are
# expected to hold the same per-sequence results.
results, results_autobatched

([5, 4, 3, 2], [5, 4, 3, 2])

## CPU Parallelization

In [12]:
class TimedPotential(Potential):
    """Toy potential whose per-context calls each block for one second.

    Note: this re-uses the name ``TimedPotential`` and therefore replaces any
    earlier class of that name in the notebook namespace.
    """

    async def complete(self, context):
        """Sleep for 1 second, then return ``len(context)``."""
        time.sleep(1)
        n_items = len(context)
        return n_items

    async def prefix(self, context):
        """Sleep for 1 second, then return ``len(context)``."""
        time.sleep(1)
        n_items = len(context)
        return n_items

    # The two methods below mirror the default batch_complete / batch_prefix:
    # unless overridden, the defaults call the per-context methods concurrently.
    # They are spelled out here only for clarity.
    async def batch_complete(self, contexts):
        """Gather ``complete`` over every context and return the results in order."""
        pending = [self.complete(ctx) for ctx in contexts]
        return await asyncio.gather(*pending)

    async def batch_prefix(self, contexts):
        """Gather ``prefix`` over every context and return the results in order."""
        pending = [self.prefix(ctx) for ctx in contexts]
        return await asyncio.gather(*pending)

    def spawn(self):
        """Return a new ``TimedPotential`` constructed from ``self.decode``.

        Presumably what the multiprocess wrapper uses to create per-worker
        copies -- confirm against the Potential API.
        """
        fresh_copy = TimedPotential(self.decode)
        return fresh_copy


# Rebind `potential` to an instance of the TimedPotential class defined in this cell.
potential = TimedPotential(list(range(256)))

In [18]:
# Create the multiprocess version of the potential, requesting 2 workers;
# the bare name on the last line displays the wrapper as the cell output.
mp_potential = potential.to_multiprocess(num_workers=2)
mp_potential

<genlm_control.potential.mp.MPPotential at 0x7fe6b4e162d0>

In [None]:
# `sequences` comes from the autobatching section earlier in the notebook.
# In-process baseline: batch_complete gathers complete() per context, and each
# call blocks in time.sleep, so the delays add up.
with timeit("without multiprocessing"):
    results = await potential.batch_complete(sequences)

# Same batch through the multiprocess wrapper (created with num_workers=2).
with timeit("with multiprocessing"):
    results_mp = await mp_potential.batch_complete(sequences)

without multiprocessing (4.0012 sec)
with multiprocessing (2.0022 sec)


In [None]:
# Show the in-process and multiprocess batch results side by side for comparison.
results, results_mp

In [None]:
with timeit("without multiprocessing"):
    results = await asyncio.gather(*(potential.complete(seq) for seq in sequences))

with timeit("with multiprocessing"):
    results_mp = await asyncio.gather(
        *(mp_potential.complete(seq) for seq in sequences)
    )

without multiprocessing (4.0010 sec)
with multiprocessing (2.0108 sec)


In [14]:
# Compare the values from both paths. In the recorded output the multiprocess
# result displays as a NumPy array rather than a list, with matching values.
results, results_mp

([5, 4, 3, 2], array([5, 4, 3, 2]))

In [27]:
import numpy as np


class MockPotential(Potential):
    """Mock potential for testing with controlled delays.

    Each operation blocks for ``self.delay`` seconds via ``time.sleep`` and then
    returns a log-value derived from the context length. The batched methods
    sleep once for the whole batch rather than once per context.
    """

    def __init__(self):
        """Initialise the base class with the integers 0..255 and set the delay."""
        super().__init__(list(range(256)))
        # Seconds slept per operation (100 ms).
        self.delay = 0.1

    async def complete(self, context):
        """Sleep for ``self.delay``, then return ``np.log(len(context))``."""
        time.sleep(self.delay)
        n_items = len(context)
        return np.log(n_items)

    async def prefix(self, context):
        """Sleep for ``self.delay``, then return ``np.log(len(context) / 2)``."""
        time.sleep(self.delay)
        half_length = len(context) / 2
        return np.log(half_length)

    async def batch_complete(self, contexts):
        """Sleep once for the batch, then return ``np.log(len(c))`` for each context."""
        time.sleep(self.delay)  # one delay covers the entire batch
        return [np.log(n_items) for n_items in map(len, contexts)]

    async def batch_prefix(self, contexts):
        """Sleep once for the batch, then return ``np.log(len(c) / 2)`` for each context."""
        time.sleep(self.delay)  # one delay covers the entire batch
        return [np.log(len(ctx) / 2) for ctx in contexts]