# Experiments with Runnables

Create a few `RunnableLambda`s and compose them in sequence and in parallel.

In [None]:
from langchain_core.runnables import (
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
    chain,
)

# Two elementary runnables, each wrapping a plain lambda.
add_1 = RunnableLambda(lambda value: value + 1)
add_3 = RunnableLambda(lambda value: value + 3)


@chain
def mult_2(x: int):
    """Double ``x``; ``@chain`` makes the function composable with ``|``."""
    doubled = x * 2
    return doubled


# Sequential composition: double, then add 1.
sequence = mult_2 | add_1
# Parallel composition: double, then feed the result to both branches.
# A dict on the right-hand side of `|` acts as a RunnableParallel.
parallel = mult_2 | {"add_1": add_1, "add_3": add_3}

# same as above, with an explicit RunnableParallel (the leading mult_2 is
# required for it to really be equivalent to `parallel`):
parallel1 = mult_2 | RunnableParallel(add_1=add_1, add_3=add_3)

In [None]:
add_1.invoke(input=5)  # 5 + 1 -> 6

Run the runnables directly, in batches, in parallel (multi-threaded whenever possible!), or asynchronously.

In [None]:
print(sequence.invoke(1))  # 3
sequence.batch([1, 2, 3])  # [3,5,7]
parallel.invoke(1)  # {'add_1': 3, 'add_3': 4}

await sequence.abatch([1, 2, 3, 4, 5])

In [None]:
sequence.invoke(input=1)  # 1 * 2 + 1 -> 3

Runnables can stream their output.

In [None]:
# Print each chunk as soon as it is produced, separated by "|".
for chunk in parallel.stream(100000):
    print(chunk, end="|", flush=True)

Print the graph and various type information (input/output types and JSON schemas).

In [None]:
# ASCII rendering of the parallel chain (mult_2 feeding both branches).
parallel.get_graph().print_ascii()

print("input type:", sequence.InputType)
print("output type:", sequence.OutputType)

# Calling `sequence.input_schema()` no longer works; the get_*_schema()
# methods return the pydantic model class, whose JSON schema we print.
print("input schema: ", sequence.get_input_schema().model_json_schema())
print("output schema: ", sequence.get_output_schema().model_json_schema())

Use `RunnablePassthrough`, which forwards its input unchanged.

In [None]:
# "origin" echoes the input unchanged; "modified" runs it through add_1.
# (The key was previously "titi", contradicting the documented output.)
runnable = RunnableParallel(origin=RunnablePassthrough(), modified=add_1)

runnable.invoke(10)  # {'origin': 10, 'modified': 11}

Demo of `bind` and `RunnableConfig`: implement a filter and log its activity.

In [None]:
from typing import cast

from langchain_core.runnables import RunnableConfig
from loguru import logger


@chain  # type: ignore
def max(x: int, max: int, config: RunnableConfig) -> int:
    """Clamp ``x`` so that it never exceeds the upper bound ``max``.

    NOTE: the name shadows the builtin ``max`` (and the parameter shadows the
    function inside its body); it is kept because other cells call
    ``max.bind(max=...)``.

    Args:
        x: value flowing through the chain.
        max: upper bound, typically fixed up front with ``.bind(max=...)``.
        config: runnable config; an optional logger (anything with an
            ``.info`` method) may be supplied under
            ``config["configurable"]["logger"]``.

    Returns:
        ``x`` when it is below ``max``, otherwise ``max``.
    """
    # .get() so the filter still works when no logger was configured
    # (plain indexing raised KeyError in that case).
    log = config.get("configurable", {}).get("logger")
    if log:
        log.info(f"check if {x} < {max}")
    return max if x >= max else x

# Double, add 1, then clamp at 6. The logger is passed through the documented
# "configurable" section of the config, which the filter reads at run time.
a = sequence | max.bind(max=6)  # type: ignore
a.batch([1, 2, 3, 4, 5], config={"configurable": {"logger": logger}})  # [3, 5, 6, 6, 6]

2024-11-05 13:26:14.823 | INFO     | __main__:max:10 - check if 3 < 6
2024-11-05 13:26:14.826 | INFO     | __main__:max:10 - check if 5 < 6
2024-11-05 13:26:14.828 | INFO     | __main__:max:10 - check if 7 < 6
2024-11-05 13:26:14.829 | INFO     | __main__:max:10 - check if 9 < 6
2024-11-05 13:26:14.836 | INFO     | __main__:max:10 - check if 11 < 6


[3, 5, 6, 6, 6]

In [None]:
type(max)  # shows what @chain turned the function into (it now supports .bind)

In [None]:
# Fix the upper bound once with bind(), then invoke with a logger supplied in
# the documented "configurable" section of the config.
a = max.bind(max=6)
a.invoke(10, {"configurable": {"logger": logger}})  # 10 >= 6 -> 6

Demo of `assign`, which adds new fields to the dict output of a runnable and returns a new runnable. It is often used with `RunnablePassthrough` to add a computed value to the input dict.

In [None]:
from langchain_core.runnables import RunnableParallel, RunnablePassthrough

# Stage 1: "extra" echoes the input dict plus a computed mult_10 field,
# while "plus_1" is computed alongside it.
base = RunnableParallel(
    extra=RunnablePassthrough.assign(mult_10=lambda x: x["num"] * 10),
    plus_1=lambda x: x["num"] + 1,
)
# Stage 2: copy the whole stage-1 output under "info".
with_info = base.assign(info=lambda x: x)
# Stage 3: derive one more field from "plus_1".
runnable = with_info.assign(plus_1_time_3=lambda x: x["plus_1"] * 3)

runnable.invoke({"num": 2})

In [None]:
# Render the composed runnable as an ASCII diagram.
graph = runnable.get_graph()
graph.print_ascii()

Most of the time, a runnable takes several parameters, or a dictionary (almost equivalent in Python).
`itemgetter` creates a function that extracts one or several fields from a dictionary.

First have a look at how it works:

In [None]:
from operator import itemgetter

dic = dict(
    question="What are the types of agent memory?",
    generation="The types of agent memory are: sensory memory, short-term memory, and long-term memory.",
    documents=[],
)

# itemgetter("generation") behaves like `lambda d: d["generation"]`.
getter_function = itemgetter("generation")
getter_function(dic)

Here is an example with Runnables:

In [None]:
# Final step: sum the two operands produced by the parallel stage.
adder = RunnableLambda(lambda d: d["op1"] + d["op2"])

# Each branch picks one key from the input dict and doubles it.
doubled_operands = RunnableParallel(
    {
        "op1": RunnablePassthrough() | itemgetter("a") | mult_2,
        "op2": RunnablePassthrough() | itemgetter("b") | mult_2,
    }
)
mult_2_and_add = doubled_operands | adder

mult_2_and_add.invoke({"a": 10, "b": 2, "z": "sds"})  # should return 2*10 + 2*2 = 24

In [None]:
# Adapt a bare number into the {"a", "b"} dict that mult_2_and_add expects.
to_operands = RunnableLambda(lambda x: {"a": x, "b": 2})
zzz = to_operands | mult_2_and_add

zzz.invoke(10)  # 2*10 + 2*2 = 24

Runnables can have fallbacks, which are used when they fail.

In [None]:
@chain
def mult_10_fail(x: int):
    """Simulate an unavailable "multiply by 10" service: always raises.

    Used to demonstrate ``with_fallbacks``. The multiplication itself is
    never reached, so the former unreachable ``return x * 10`` was dropped.

    Raises:
        RuntimeError: always (a subclass of ``Exception``, so handlers that
            catch ``Exception`` still catch it).
    """
    raise RuntimeError("unavailable multiplication by 10 service")


# mult_10_fail always raises, so the same input is retried on mult_2.
fallback_chain = mult_10_fail.with_fallbacks(fallbacks=[mult_2])
fallback_chain.invoke(2)  # 4, produced by the mult_2 fallback

See also: https://python.langchain.com/v0.2/docs/how_to/lcel_cheatsheet/