# core

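> Parse a compact arrow notation for graphs (e.g. `START -> node_1 -> END`) into pairs of connected nodes, as the first step toward generating langgraph code.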

In [None]:
#| default_exp core

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore.foundation import L
from fastcore.utils import *

### Test Data
**Why does this exist?**  I like using langgraph.  Modeling reasoning processes as graphs makes sense.  However, there are a few problems that make it difficult for me.  First of all, I have a terrible memory.  I really can't memorize a constantly changing API, and I get tired of looking up stuff I don't remember.  And if I'm thinking of a graph, with nodes and edges, the langgraph architecture is great.  I just don't know all the variations of it off the top of my head.  All I know is a sort of state graph like:  `START -> do_something -> END`

Now I want to code that up in langgraph.  What do I need?  Exactly how do I code these?  It's not at all clear given a notation like this:  `START -> node_1 -> node_2 -> END`

This tool lets you define graphs this way, and writes functioning langgraph code.

### Example Graphs in this notation

These graphs are taken from the langchain video [Building Effective Agents with LangGraph](https://youtu.be/aHCDrAbH_go?si=X5wbxvNwzJikLPN-)

In [None]:
# Three-line notation; the parenthesized `check_punchline(improve_joke, END)`
# text is a single entry on the right-hand side of its arrow.
bea_basic = (
    "START -> generate_joke\n"
    "generate_joke -> check_punchline(improve_joke, END)\n"
    "improve_joke -> polish_joke -> END\n"
)

# One chain whose middle segment is a comma-separated list of nodes.
bea_parallel = "START -> call_llm_1, call_llm_2, call_llm_3 -> aggregator -> END"

# One chain containing a parenthesized, starred entry: `llm_call(*sections)`.
bea_orchestrator_worker = "START -> orchestrator -> llm_call(*sections) -> synthesizer -> END"

In [None]:
# Fixture table: graph name -> the notation string and the exact list of
# (source, destination) pairs the parser is expected to return for it.
test_cases = {
    case_name: {"notation": case_notation, "expected": case_pairs}
    for case_name, case_notation, case_pairs in [
        (
            "bea_basic",
            bea_basic,
            [
                ('START', 'generate_joke'),
                ('generate_joke', 'check_punchline(improve_joke, END)'),
                ('improve_joke', 'polish_joke'),
                ('polish_joke', 'END'),
            ],
        ),
        (
            "bea_parallel",
            bea_parallel,
            [
                ('START', 'call_llm_1'),
                ('START', 'call_llm_2'),
                ('START', 'call_llm_3'),
                ('call_llm_1', 'aggregator'),
                ('call_llm_2', 'aggregator'),
                ('call_llm_3', 'aggregator'),
                ('aggregator', 'END'),
            ],
        ),
        (
            "bea_orchestrator_worker",
            bea_orchestrator_worker,
            [
                ('START', 'orchestrator'),
                ('orchestrator', 'llm_call(*sections)'),
                ('llm_call(*sections)', 'synthesizer'),
                ('synthesizer', 'END'),
            ],
        ),
    ]
}

#### Step 1: break into pairs

All operations after this use:
- graph_name -- used in code generation
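- graph_notation -- see examples, only 4 patterns, e.g. simple transition `A -> B`, parallel destinations `A -> B, C, D`, and the parenthesized forms shown in the examples above (`node(B, C)` and `node(*items)`)
- graph_data -- pairs of related graph entities: each pair is `(node, node it transitions to)`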

In [None]:
#| export
def _get_pairs(graph_notation: str):
    """Given a text representation of the graph, return unprocessed pairs direct from the notation"""
    pairs = []
    for line in graph_notation.splitlines():
        line_components = [v.strip() for v in line.split("->")]
        line_pairs = list(zip(line_components, line_components[1:]))
        pairs.extend(line_pairs)
    return pairs

def _expand_commas(graph_pairs):
    for l,r in graph_pairs:
        if "," in l and "(" not in l: llist = [x.strip() for x in l.split(",")]
        else: llist = [l]
        if "," in r and "(" not in r: rlist = [x.strip() for x in r.split(",")]
        else: rlist = [r]
        for left in llist:
            for right in rlist:
                yield left, right

def _expand_node_lists(graph_pairs):
    """Return the pairs generated by `_expand_commas(graph_pairs)` as a list."""
    return [*_expand_commas(graph_pairs)]

def get_graph_data(graph_notation: str):
    """Parse `graph_notation` into a list of (source, destination) node pairs.

    Works in two steps: `_get_pairs` splits every line on "->" into raw
    neighbouring pairs, then `_expand_node_lists` expands comma-separated
    node lists such as `B, C` into one pair per node.
    """
    return _expand_node_lists(_get_pairs(graph_notation))

In [None]:
# Parse the basic example; the final tuple is the value the cell displays.
graph_name, graph_notation = "bea_basic", bea_basic
graph_data = get_graph_data(graph_notation)
graph_name, graph_notation, graph_data

('bea_basic',
 'START -> generate_joke\ngenerate_joke -> check_punchline(improve_joke, END)\nimprove_joke -> polish_joke -> END\n',
 [('START', 'generate_joke'),
  ('generate_joke', 'check_punchline(improve_joke, END)'),
  ('improve_joke', 'polish_joke'),
  ('polish_joke', 'END')])

In [None]:
# The parsed result must match the fixture registered under `graph_name`.
assert test_cases[graph_name]["notation"] == graph_notation
assert test_cases[graph_name]["expected"] == graph_data

In [None]:
# Parse the parallel example; the final tuple is the value the cell displays.
graph_name, graph_notation = "bea_parallel", bea_parallel
graph_data = get_graph_data(graph_notation)
graph_name, graph_notation, graph_data

('bea_parallel',
 'START -> call_llm_1, call_llm_2, call_llm_3 -> aggregator -> END',
 [('START', 'call_llm_1'),
  ('START', 'call_llm_2'),
  ('START', 'call_llm_3'),
  ('call_llm_1', 'aggregator'),
  ('call_llm_2', 'aggregator'),
  ('call_llm_3', 'aggregator'),
  ('aggregator', 'END')])

In [None]:
# The parsed result must match the fixture registered under `graph_name`.
assert test_cases[graph_name]["notation"] == graph_notation
assert test_cases[graph_name]["expected"] == graph_data

In [None]:
# Parse the orchestrator/worker example; the final tuple is the value the cell displays.
graph_name, graph_notation = "bea_orchestrator_worker", bea_orchestrator_worker
graph_data = get_graph_data(graph_notation)
graph_name, graph_notation, graph_data

('bea_orchestrator_worker',
 'START -> orchestrator -> llm_call(*sections) -> synthesizer -> END',
 [('START', 'orchestrator'),
  ('orchestrator', 'llm_call(*sections)'),
  ('llm_call(*sections)', 'synthesizer'),
  ('synthesizer', 'END')])

In [None]:
# The parsed result must match the fixture registered under `graph_name`.
assert test_cases[graph_name]["notation"] == graph_notation
assert test_cases[graph_name]["expected"] == graph_data

In [None]:
#| hide
# Run nbdev's export step (`nbdev.nbdev_export`) when this cell is executed.
import nbdev; nbdev.nbdev_export()