# Stream custom generator functions

You can use generator functions (i.e., functions that use the `yield` keyword and behave like iterators) in an LCEL pipeline.

The signature of these generators should be `Iterator[Input] -> Iterator[Output]`, or, for async generators, `AsyncIterator[Input] -> AsyncIterator[Output]`.

These are useful for:

- implementing a custom output parser
- modifying the output of a previous step while preserving streaming capabilities

Let's implement a custom output parser for comma-separated lists.

In [1]:
import sys
import os
# Make the parent directory and ../custom_llms/ importable. Both are resolved
# against the current working directory and inserted at the front of
# sys.path; because each insert goes to index 0, ../custom_llms/ ends up
# searched first. (os.path.join with a single argument returns it unchanged.)
module_path = os.path.abspath(os.path.join('..'))
model_config_path = os.path.abspath(os.path.join('../custom_llms/'))
sys.path.insert(0, module_path)
sys.path.insert(0, model_config_path)

# Project-local LLM wrapper. This import must come after the sys.path edits
# above; presumably `custom_llms` is found via module_path — confirm.
from custom_llms.minimax_llm import MiniMaxLLM

# `model` is the LLM step used by the chains below (instead of ChatOpenAI).
model = MiniMaxLLM()

In [2]:
from typing import Iterator, List

from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Prompt asking the model for five animals similar to the `animal` input.
prompt = ChatPromptTemplate.from_template(
    "Write a comma-separated list of 5 animals similar to: {animal}"
)
# model = ChatOpenAI(temperature=0.0)
# NOTE(review): the OpenAI model is intentionally disabled; `model` here is
# the MiniMaxLLM instance created in the setup cell.

# prompt -> model -> StrOutputParser: the chain's output is plain text.
str_chain = prompt | model | StrOutputParser()

In [3]:
# Stream the raw text output and print each chunk as soon as it arrives.
for chunk in str_chain.stream({"animal": "bear"}):
    print(chunk, end="", flush=True)

lion, tiger, wolf, elephant, gorilla

In [4]:
# Non-streaming call: returns the complete output string in one go.
str_chain.invoke({"animal": "bear"})

'tiger, lion, wolf, grizzly, leopard'

In [5]:
# Custom streaming parser: turns an iterator of LLM text chunks into a stream
# of one-element lists, one per comma-separated item.
def split_into_list(input: Iterator[str]) -> Iterator[List[str]]:
    """Parse a stream of text chunks into comma-separated items.

    Each item is yielded as soon as the comma that terminates it has
    arrived, so downstream steps keep streaming. Items are stripped of
    surrounding whitespace; empty items (from a trailing comma, ``",,"``,
    or an empty stream) are skipped instead of being yielded as ``''``.

    Args:
        input: Iterator of text chunks (e.g. tokens from a model). The name
            shadows the builtin but is kept for interface compatibility.

    Yields:
        One-element lists, each holding one stripped, non-empty item.
    """
    # Text after the last comma seen so far; it may still grow.
    buffer = ""
    for chunk in input:
        buffer += chunk
        # Every piece before the last comma is complete; keep the tail.
        *complete, buffer = buffer.split(",")
        for item in complete:
            item = item.strip()
            if item:
                yield [item]
    # The stream has ended, so whatever remains is the final item.
    tail = buffer.strip()
    if tail:
        yield [tail]

In [6]:
# Append the generator function as the final step; it receives the upstream
# text chunks as an iterator (Iterator[str] -> Iterator[List[str]]).
list_chain = str_chain | split_into_list

In [7]:
# Each streamed chunk is a one-element list yielded by split_into_list.
for chunk in list_chain.stream({"animal": "bear"}):
    print(chunk, flush=True)

['lion']
['tiger']
['wolf']
['leopard']
['jaguar']


In [8]:
# Non-streaming call: the result comes back as a single combined list.
list_chain.invoke({"animal": "bear"})

['wolf', 'lion', 'tiger', 'elephant', 'gorilla']

## Async Version

In [9]:
from typing import AsyncIterator


async def asplit_into_list(
    input: AsyncIterator[str],
) -> AsyncIterator[List[str]]:
    """Parse an async stream of text chunks into comma-separated items.

    Each item is yielded as soon as the comma that terminates it has
    arrived, so downstream steps keep streaming. Items are stripped of
    surrounding whitespace; empty items (from a trailing comma, ``",,"``,
    or an empty stream) are skipped instead of being yielded as ``''``.

    Args:
        input: Async iterator of text chunks (e.g. tokens from a model).
            The name shadows the builtin but is kept for compatibility.

    Yields:
        One-element lists, each holding one stripped, non-empty item.
    """
    # Text after the last comma seen so far; it may still grow.
    buffer = ""
    # `input` is an async iterator, so it must be consumed with `async for`.
    async for chunk in input:
        buffer += chunk
        # Every piece before the last comma is complete; keep the tail.
        *complete, buffer = buffer.split(",")
        for item in complete:
            item = item.strip()
            if item:
                yield [item]
    # The stream has ended, so whatever remains is the final item.
    tail = buffer.strip()
    if tail:
        yield [tail]


# Same pipeline, but with the async generator as the final step.
list_chain = str_chain | asplit_into_list

In [10]:
# Async streaming. A top-level `async for` works only in an environment with a
# running event loop such as a notebook; in a plain script it is a SyntaxError.
async for chunk in list_chain.astream({"animal": "bear"}):
    print(chunk, flush=True)

['wolf']
['lion']
['tiger']
['leopard']
['hyena']


In [11]:
# Async non-streaming call; top-level `await` likewise relies on the notebook.
await list_chain.ainvoke({"animal": "bear"})

['wolf', 'lion', 'tiger', 'leopard', 'gorilla']