# Caching Relationship and Executable Graphs

Create a relationship graph and an executable graph for a hypothetical data caching scenario.


In [None]:
import os
from typing import Sequence
from rustworkx.visualization import graphviz_draw

from petritype.core.ast_extraction import FunctionWithAnnotations
from petritype.core.data_structures import TypeVariableWithAnnotations
from petritype.core.parse_modules import (
    ParseModule, ParsedModule, ExtractFunctions, ExtractTypes,
)
from petritype.plotting.rustworkx_to_graphviz import RustworkxToGraphviz

In [None]:
"""Read types and functions from the example python file."""

from petritype.core.relationship_graph_components import (
    FunctionToTypeEdges,
    RelationshipEdges,
    TypeToFunctionEdges,
    TypeToTypeEdges,
)


# Parse the example module. The same tuple is used both to build the file path
# and as the module's import path components.
# NOTE(review): the last component keeps the ".py" suffix, unlike the other
# ParseModule.from_file calls in this notebook — confirm which form is expected.
path_components = ("examples", "caching", "hypothetical_caching.py")
module_from_py_file = ParseModule.from_file(
    path_to_file=os.path.join(*path_components),
    import_path_components=path_components,
)
# Only functions that belong to the three selected classes are extracted.
functions: Sequence[FunctionWithAnnotations] = (
    ExtractFunctions.from_selected_classes_in_parsed_modules(
        parsed_modules=(module_from_py_file,),
        selected_classes=("DBOperations", "CacheOperations", "Branch"),
    )
)
# Types are extracted from the whole parsed module.
types: Sequence[TypeVariableWithAnnotations] = ExtractTypes.from_parsed_modules(
    parsed_modules=(module_from_py_file,),
)
# Three edge families for the relationship graph: type -> type,
# type -> function and function -> type.
edges_type_to_type: TypeToTypeEdges = RelationshipEdges.type_to_type(types)
edges_type_to_function: TypeToFunctionEdges = RelationshipEdges.type_to_function(
    types, functions
)
edges_function_to_type: FunctionToTypeEdges = RelationshipEdges.function_to_type(
    functions, types
)

# Relationship Graph


In [None]:
# Build the relationship digraph from the extracted types, functions and edges.
# Besides the graph itself, the call returns name -> node-index lookups for
# types and for functions, and the type relationship edges.
(
    graph,
    type_names_to_node_indices,
    function_names_to_node_indices,
    type_relationship_edges,
) = RustworkxToGraphviz.digraph(
    types=types,
    functions=functions,
    edges_type_to_function=edges_type_to_function,
    edges_function_to_type=edges_function_to_type,
    edges_type_to_type=edges_type_to_type,
)
# Draw the graph with Graphviz, styling nodes and edges through the
# attribute callbacks provided by RustworkxToGraphviz.
graphviz_draw(
    graph,
    node_attr_fn=RustworkxToGraphviz.node_attr_fn,
    edge_attr_fn=RustworkxToGraphviz.edge_attr_fn,
    method="sfdp",
)

# Execution Graph

Find all shortest paths from the start node to the end node.
For every node on those paths, add the node and all of its adjacent nodes to the set of relevant nodes.
For every function node in the set so far, get all the adjacent nodes and add them to the set of relevant nodes.
Use the set of relevant nodes to extract code for types and functions.


In [None]:
import rustworkx as rx
from itertools import chain

# Endpoints of the search, named by type and resolved to graph node indices.
start_type = "DBKey"
end_type = "DBKeyValuePair"
start_type_index = type_names_to_node_indices[start_type]
end_type_index = type_names_to_node_indices[end_type]

# Every node index that lies on any shortest start -> end path (duplicates kept).
shortest_paths = rx.digraph_all_shortest_paths(graph, start_type_index, end_type_index)
shortest_paths_indices = [index for path in shortest_paths for index in path]

# One hop further: whatever graph.neighbors reports for each path node.
shortest_paths_neighbors = [
    neighbor
    for index in shortest_paths_indices
    for neighbor in graph.neighbors(index)
]

# Path nodes and their neighbors together form the relevant set.
relevant_nodes = set(shortest_paths_indices) | set(shortest_paths_neighbors)

In [None]:
from typing import Iterable
from petritype.core.data_structures import ClassName, NodeIndex


def relevant_functions_from_graph_nodes(
    graph: rx.PyDiGraph, relevant_nodes: Iterable[NodeIndex]
) -> Sequence[FunctionWithAnnotations]:
    """Collect the function payloads stored at the given graph nodes.

    Args:
        graph: Graph whose node payloads are looked up by index.
        relevant_nodes: Indices of the nodes to inspect.

    Returns:
        A tuple holding, in iteration order, every payload that is a
        ``FunctionWithAnnotations``; all other payloads are skipped.
    """
    return tuple(
        graph[index]
        for index in relevant_nodes
        if isinstance(graph[index], FunctionWithAnnotations)
    )


def relevant_types_from_graph_nodes(
    graph: rx.PyDiGraph, relevant_nodes: Iterable[NodeIndex]
) -> Sequence[TypeVariableWithAnnotations]:
    """Collect the type payloads stored at the given graph nodes.

    Args:
        graph: Graph whose node payloads are looked up by index.
        relevant_nodes: Indices of the nodes to inspect.

    Returns:
        A tuple holding, in iteration order, every payload that is a
        ``TypeVariableWithAnnotations``; all other payloads are skipped.
    """
    return tuple(
        graph[index]
        for index in relevant_nodes
        if isinstance(graph[index], TypeVariableWithAnnotations)
    )


def relevant_classes_from_functions(
    functions_with_annotations: Sequence[FunctionWithAnnotations],
) -> set[ClassName]:
    """Return the distinct ``class_name`` values of the given functions.

    Args:
        functions_with_annotations: Functions whose ``class_name`` attribute
            is read.

    Returns:
        A set with one entry per distinct ``class_name``.
    """
    return {function.class_name for function in functions_with_annotations}

In [None]:
# Narrow the extracted functions and types down to those stored at the
# relevant graph nodes.
relevant_functions = relevant_functions_from_graph_nodes(graph, relevant_nodes)
relevant_types = relevant_types_from_graph_nodes(graph, relevant_nodes)
# Derive the classes from the relevant functions selected above. Passing the
# full `functions` sequence here would ignore that selection and always yield
# every extracted class.
relevant_classes = relevant_classes_from_functions(relevant_functions)

In [None]:
# Alias the parsed example module under the name the following cells use.
parsed_module = module_from_py_file
# Bare expression: displayed as this cell's output in the notebook.
parsed_module

In [None]:
from petritype.core.parse_modules import ExtractClassCode

# Gather the source code of each relevant class from the parsed module.
# `relevant_classes` is a set, so it is sorted first: this gives the collected
# code (and the prompt that joins it) a stable order between runs. key=str
# keeps the sort well-defined whatever type the class names have.
relevant_classes_code_from_module = [
    ExtractClassCode.from_parsed_module(parsed_module, selected_class)
    for selected_class in sorted(relevant_classes, key=str)
]


In [None]:
# Source code of every relevant type, in the order of `relevant_types`.
relevant_types_code = [relevant_type.code for relevant_type in relevant_types]


In [None]:
# Read the source of the core data-structures module; it is quoted verbatim
# as background material in the prompt.
path_to_data_structures_file = os.path.join("petritype", "core", "data_structures.py")
data_structures_code = ParseModule.from_file(
    import_path_components=("petritype", "core", "data_structures"),
    path_to_file=path_to_data_structures_file,
).code
data_structures_description = (
    f"BACKGROUND: The following data structures are used:\n```{data_structures_code}```\n\n\n"
)

In [None]:
# Read the source of the executable-graph components module; it is quoted
# verbatim in the prompt so the reader knows the available building blocks.
path_to_executable_graph_components = os.path.join(
    "petritype", "core", "executable_graph_components.py"
)
executable_graph_components_code = ParseModule.from_file(
    path_to_file=path_to_executable_graph_components,
    # Must name the same module as the file path above; the previous value
    # ("flow_graph_components") did not match the file being parsed.
    import_path_components=("petritype", "core", "executable_graph_components"),
).code
# The prompt text below is runtime output and is intentionally left unchanged.
executable_graph_description = (
    "INTRO: The following code describes the components of a AST-KG executable graph in our context.\n"
    "A executable graph instance is simply a tuple of instantiated flow graph components.\n"
    "\n\n\n"
    "```" + executable_graph_components_code + "```"
)


In [None]:
from petritype.core.parse_modules import ExtractImportStatements


def format_types(types: Sequence[TypeVariableWithAnnotations]) -> str:
    """Render the code of the given types as one fenced block for the prompt.

    Args:
        types: Type declarations whose ``code`` attribute is emitted.

    Returns:
        A fixed introductory line, followed by the ``code`` of every type
        separated by two blank lines and wrapped in triple backticks, and
        terminated by three newlines.
    """
    header = '"""The following type declarations are relevant here."""\n\n'
    body = "\n\n\n".join(declaration.code for declaration in types)
    return f"{header}```{body}```\n\n\n"


def format_module_name(parsed_module: ParsedModule) -> str:
    """Build the prompt line stating where the module's code lives.

    Args:
        parsed_module: Module whose ``import_path_components`` are joined
            with ``/`` to form the displayed path.

    Returns:
        The ``MODULE: ...`` line followed by two newlines.
    """
    joined_path = "/".join(parsed_module.import_path_components)
    return "MODULE: The following code exists in " + joined_path + "\n\n"


# Import statements of the example module, quoted inside single backticks.
import_statements = ExtractImportStatements.from_parsed_module(parsed_module)
# NOTE(review): the separator ".\n" puts a literal "." after every import
# statement except the last, and the closing backtick comes after the three
# blank lines — confirm both are intended and not typos.
imports = "`" + ".\n".join(import_statements) + "\n\n\n" + "`"
description_of_types = format_types(relevant_types)
# Task statement; the start and end types are interpolated into the first line.
task_description = (
    "TASK:\n"
    f"Propose a AST-KG executable graph starting at {start_type} and ending at {end_type}.\n"
    "This graph should describe a process where the value is retrieved from the cache if it exists,\n"
    "but if it does not exist, the value is retrieved from the database and then stored in the cache.\n"
    "Use the available types and functions where possible but propose new types and functions if needed.\n"
    "When declaring nodes avoid using more general types (e.g. str or dict or other COMMON_TYPES)\n"
    "when more specific types are available.\n"
    "The db and cache exist outside the graph and can be passed in via kwargs to transition nodes\n"
    "and thus do not correspond to place nodes in the graph.\n"
    f"\n\n"
)

# Import line the answer is told to use instead of redefining existing classes.
executable_imports = """from petritype.core.executable_graph_components import *"""

# Assemble the full prompt: background, component code, task, relevant types,
# module location, module imports, relevant class code, then the notes.
prompt = (
    data_structures_description
    + executable_graph_description
    + task_description
    + description_of_types
    + "\n\n\n"
    + format_module_name(parsed_module)
    + "\n\n\n"
    + imports
    + "\n\n\n"
    + "```"
    + "\n\n\n".join(relevant_classes_code_from_module)
    + "```"
    + "\n\n\n"
    + "NOTE 0: Branching logic is described by having multiple edges from a transition node to multiple place nodes. "
    + "The token goes to the place or places that match it's type which in turn is determined by the function "
    + "that returned the token.\n"
    + "NOTE 1: do not redefine the existing classes in the answer, use the following imports instead:\n"
    + "`"
    + executable_imports
    + "`"
    + "\n"
    + "NOTE 2: Print code in a single block that it can all be copied in one go.\n"
    + "NOTE 3: Import all the relevant types at the top."
)
print(prompt)

In [None]:
from copy import deepcopy


# Cache contents before execution; kept separately from the working copy so
# the initial state stays available for comparison.
cache_before = {
    "a_0": "A_0",
    "c_0": "C_1",
    "d_0": "D_2",
}
# Working copy that is passed to the graph's transitions.
cache = deepcopy(cache_before)
# Database contents. "a_0" and "c_0" hold different values here than in the
# cache, and "d_0" exists only in the cache.
db = {
    "a_0": "A_10",
    "b_0": "B_11",
    "c_0": "C_12",
    "e_0": "E_13",
    "f_0": "F_14",
    "g_0": "G_15",
    "h_0": "H_16",
}
# Keys placed in the graph's input place; every one of them is present in the
# cache or in the db. The commented-out keys are present in neither.
initial_keys = [
    "a_0", "b_0", "c_0", "d_0", "e_0", "f_0", "g_0", "h_0", 
    # "i_0", "j_0", "unknown", "missing"
]
# Expected final cache: the three cached entries keep their cached values and
# the remaining keys carry their db values.
expected_cache_after = {  # TODO: Add check that this matches the result.
    "a_0": "A_0",
    "c_0": "C_1",
    "d_0": "D_2",
    "b_0": "B_11",
    "e_0": "E_13",
    "f_0": "F_14",
    "g_0": "G_15",
    "h_0": "H_16",
}

In [None]:
from petritype.core.executable_graph_components import *

from examples.caching.hypothetical_caching import *


# Token flow declared by the edges below:
#   KeyInput -> CheckCache -> CachedValueFound            (return index 1)
#                          -> KeyForDBRetrieval           (return index 0)
#   KeyForDBRetrieval -> RetrieveFromDB -> DBValueRetrieved
#   DBValueRetrieved -> CacheKeyValuePair -> FinalKeyValuePair
# CachedValueFound has no outgoing edge, so tokens that reach it stay there.

# Defining the place nodes
# KeyInput starts with the initial keys; the other places start without values.
key_input = ListPlaceNode(name='KeyInput', type=DBKey, values=initial_keys)
key_for_db_retrieval = ListPlaceNode(name='KeyForDBRetrieval', type=DBKey)
db_value_retrieved = ListPlaceNode(name='DBValueRetrieved', type=DBKeyValuePair)
cached_value_found = ListPlaceNode(name='CachedValueFound', type=DBKeyValuePair)
final_key_value_pair = ListPlaceNode(name='FinalKeyValuePair', type=DBKeyValuePair)

# Defining the transition nodes
# Each transition wraps one function; the cache and db dictionaries are
# supplied through kwargs rather than through place nodes.
check_cache = FunctionTransitionNode(
    name='CheckCache',
    function=CacheOperations.retrieve_key_value_pair,
    kwargs={'cache': cache}  # 'cache' passed as argument during graph execution
)

retrieve_from_db = FunctionTransitionNode(
    name='RetrieveFromDB',
    function=DBOperations.retrieve_key_value_pair,
    kwargs={'db': db}  # 'db' passed as argument during graph execution
)

# This transition caches the value retrieved from the database
cache_key_value_pair = FunctionTransitionNode(
    name='CacheKeyValuePair',
    function=CacheOperations.cache_key_value_pair,
    kwargs={'cache': cache, 'expected_size': 100}  # Modify 'expected_size' as needed
)

# Defining the edges
# Edges refer to nodes by name; argument edges also name the function
# parameter that receives the token.
input_to_check_cache = ArgumentEdgeToTransition(
    place_node_name='KeyInput',
    transition_node_name='CheckCache',
    argument='key'
)

# CheckCache has two outgoing edges, distinguished by return_index.
check_cache_to_cached_value = ReturnedEdgeFromTransition(
    transition_node_name='CheckCache',
    place_node_name='CachedValueFound',
    return_index=1  # Assumes the function returns a DBKeyValuePair on cache hit
)

check_cache_to_db_retrieval = ReturnedEdgeFromTransition(
    transition_node_name='CheckCache',
    place_node_name='KeyForDBRetrieval',
    return_index=0  # Assumes the function returns a DBKey on cache miss
)

db_retrieval_to_cache = ArgumentEdgeToTransition(
    place_node_name='DBValueRetrieved',
    transition_node_name='CacheKeyValuePair',
    argument='key_value_pair'
)

db_retrieval_from_key = ArgumentEdgeToTransition(
    place_node_name='KeyForDBRetrieval',
    transition_node_name='RetrieveFromDB',
    argument='key'
)

# The remaining returned edges do not pass a return_index.
retrieve_from_db_to_db_value = ReturnedEdgeFromTransition(
    transition_node_name='RetrieveFromDB',
    place_node_name='DBValueRetrieved'
)

cache_result_to_final = ReturnedEdgeFromTransition(
    transition_node_name='CacheKeyValuePair',
    place_node_name='FinalKeyValuePair'
)

# Defining the graph
# Places, transitions and edges are handed over together in one flat tuple.
executable_graph_mixed_nodes_and_edges = (
    key_input,
    key_for_db_retrieval,
    db_value_retrieved,
    cached_value_found,
    final_key_value_pair,
    check_cache,
    retrieve_from_db,
    cache_key_value_pair,
    input_to_check_cache,
    check_cache_to_cached_value,
    check_cache_to_db_retrieval,
    db_retrieval_from_key,
    retrieve_from_db_to_db_value,
    db_retrieval_to_cache,
    cache_result_to_final
)
executable_graph = ExecutableGraphOperations.construct_graph(executable_graph_mixed_nodes_and_edges)


In [None]:
from petritype.core.rustworkx_graph import RustworkxGraph


# Graph form of the executable graph that is passed to graphviz_draw in the
# cells below.
executable_pydigraph = RustworkxGraph.from_executable_graph(executable_graph)

In [None]:
from rustworkx.visualization import graphviz_draw



def place_node_label(node: ListPlaceNode) -> str:
    """Build the multi-line label shown for a place node.

    Args:
        node: Place node providing ``name``, ``type`` and ``values``.

    Returns:
        The node name, the type's ``__name__`` in parentheses on the next
        line, then one line per value rendered with ``str``.
    """
    header = f"{node.name}\n({node.type.__name__})"
    rendered_values = "\n".join(map(str, node.values))
    return header + "\n" + rendered_values


def transition_node_label(node: FunctionTransitionNode) -> str:
    """Build the two-line label shown for a transition node.

    Args:
        node: Transition node providing ``name`` and ``function``.

    Returns:
        The node name followed, on the next line, by the function's
        ``__qualname__`` in parentheses.
    """
    return "{}\n({})".format(node.name, node.function.__qualname__)


def flow_node_attr_fn(node):
    """Return the Graphviz attributes used to draw one executable-graph node.

    Place nodes are drawn as filled deep-sky-blue ovals and transition nodes
    as filled light-green boxes, each with its own label.

    Raises:
        ValueError: If ``node`` is neither a ``ListPlaceNode`` nor a
            ``FunctionTransitionNode``.
    """
    if isinstance(node, ListPlaceNode):
        label, color, shape = place_node_label(node), 'deepskyblue', 'oval'
    elif isinstance(node, FunctionTransitionNode):
        label, color, shape = transition_node_label(node), 'lightgreen', 'box'
    else:
        raise ValueError("Invalid node data type.")
    # Both node kinds share the same attribute keys and the filled style.
    return {"label": label, 'color': color, 'style': 'filled', 'shape': shape}


# Draw the executable graph with the "dot" layout, styling nodes through
# flow_node_attr_fn; no edge attribute callback is supplied.
graphviz_draw(
    executable_pydigraph,
    node_attr_fn=flow_node_attr_fn,
    # edge_attr_fn=edge_attr_fn,
    method='dot',
)

In [None]:
import time
import matplotlib.pyplot as plt
from IPython.display import display, clear_output


for i in range(100):
    print(i)

    # Take one processing step.
    try:
        _, transitions_fired = await ExecutableGraphOperations.execute_graph(
            executable_graph=executable_graph,
            max_transitions=1,
            verbose=True,
        )
    except Exception as e:
        print(f"Error: {e}")
    

    clear_output(wait=True)
    node_attr_fn, edge_attr_fn = RustworkxToGraphviz.generate_attr_fn(executable_graph)
    diagram = graphviz_draw(
        executable_pydigraph,
        node_attr_fn=node_attr_fn,
        edge_attr_fn=edge_attr_fn,
        method='dot',
    )
    display(diagram)
    time.sleep(1)

    if not transitions_fired:
        break

    plt.close()
