In [None]:
import os
from typing import Sequence
from rustworkx.visualization import graphviz_draw

from petritype.core.ast_extraction import FunctionWithAnnotations
from petritype.core.data_structures import TypeVariableWithAnnotations
from petritype.core.parse_modules import (
    ParseModule, ParsedModule, ExtractFunctions, ExtractTypes,
)
from petritype.plotting.rustworkx_to_graphviz import RustworkxToGraphviz

In [None]:
"""Read types and functions from the example python file."""

from petritype.core.relationship_graph_components import (
    FunctionToTypeEdges, RelationshipEdges, TypeToFunctionEdges, TypeToTypeEdges
)


# Location of the example module, relative to the directory the notebook runs in.
path_components = ("examples", "time_series_stats", "hypothetical_time_series.py")
# The same tuple is used both to build the file path and as the module's
# `import_path_components`.
module_from_py_file = ParseModule.from_file(
    path_to_file=os.path.join(*path_components),
    import_path_components=path_components,
)
# Only functions belonging to the two named classes are extracted.
functions: Sequence[FunctionWithAnnotations] = (
    ExtractFunctions.from_selected_classes_in_parsed_modules(
        parsed_modules=(module_from_py_file,),
        selected_classes=("SimulateData", "SeriesStatistics"),
    )
)
# Types are extracted from the whole parsed module (no class filter is passed).
types: Sequence[TypeVariableWithAnnotations] = ExtractTypes.from_parsed_modules(
    parsed_modules=(module_from_py_file,),
)
# Three edge collections for the relationship graph: type -> type,
# type -> function and function -> type.
edges_type_to_type: TypeToTypeEdges = RelationshipEdges.type_to_type(types)
edges_type_to_function: TypeToFunctionEdges = RelationshipEdges.type_to_function(
    types, functions
)
edges_function_to_type: FunctionToTypeEdges = RelationshipEdges.function_to_type(
    functions, types
)

# Relationship Graph


In [None]:
# Build the relationship graph. `digraph` returns a 4-tuple; the two
# name -> node-index mappings are used further down to look nodes up by name.
(
    graph,
    type_names_to_node_indices,
    function_names_to_node_indices,
    type_relationship_edges,
) = RustworkxToGraphviz.digraph(
    types=types,
    functions=functions,
    edges_type_to_function=edges_type_to_function,
    edges_function_to_type=edges_function_to_type,
    edges_type_to_type=edges_type_to_type,
)
# The call is the last expression of the cell, so the rendered image is shown
# as the cell output. "sfdp" selects the Graphviz layout engine.
graphviz_draw(
    graph,
    node_attr_fn=RustworkxToGraphviz.node_attr_fn,
    edge_attr_fn=RustworkxToGraphviz.edge_attr_fn,
    method="sfdp",
)

# Execution Graph

Find all shortest paths from the start node to the end node.
For every node on those paths, add the node and all of its adjacent nodes to the set of relevant nodes.
For every function node in the set so far, add all of its adjacent nodes to the set of relevant nodes.
Finally, use the set of relevant nodes to extract the code for the relevant types and functions.


In [None]:
import rustworkx as rx
from itertools import chain

# Names of the types at which the execution graph starts and ends.
start_type = "TimeSeriesGeneratingParameters"
end_type = "ExponentialMovingAverageOfInterval"
start_type_index = type_names_to_node_indices[start_type]
end_type_index = type_names_to_node_indices[end_type]

# Every shortest path between the two type nodes, each a sequence of node indices.
shortest_paths = rx.digraph_all_shortest_paths(graph, start_type_index, end_type_index)

# Flatten the paths into a single list of node indices (duplicates are kept here).
shortest_paths_indices = [
    node_index for path in shortest_paths for node_index in path
]
# Neighbours of every node that lies on a shortest path.
shortest_paths_neighbors = [
    neighbor
    for node_index in shortest_paths_indices
    for neighbor in graph.neighbors(node_index)
]
# De-duplicate: path nodes plus their neighbours.
relevant_nodes = set(shortest_paths_indices + shortest_paths_neighbors)

In [None]:
from typing import Iterable
from petritype.core.data_structures import ClassName, NodeIndex


def relevant_functions_from_graph_nodes(
    graph: rx.PyDiGraph, relevant_nodes: Iterable[NodeIndex]
) -> Sequence[FunctionWithAnnotations]:
    """Return the function payloads stored at the given graph nodes.

    Args:
        graph: Graph whose node payloads are inspected.
        relevant_nodes: Indices of the nodes to look at.

    Returns:
        A tuple with the payload of every listed node whose payload is a
        ``FunctionWithAnnotations``, in the iteration order of
        ``relevant_nodes``. Nodes holding anything else are skipped.
    """
    payloads = (graph[node_index] for node_index in relevant_nodes)
    return tuple(
        payload for payload in payloads if isinstance(payload, FunctionWithAnnotations)
    )


def relevant_types_from_graph_nodes(
    graph: rx.PyDiGraph, relevant_nodes: Iterable[NodeIndex]
) -> Sequence[TypeVariableWithAnnotations]:
    """Return the type payloads stored at the given graph nodes.

    Args:
        graph: Graph whose node payloads are inspected.
        relevant_nodes: Indices of the nodes to look at.

    Returns:
        A tuple with the payload of every listed node whose payload is a
        ``TypeVariableWithAnnotations``, in the iteration order of
        ``relevant_nodes``. Nodes holding anything else are skipped.
    """
    payloads = (graph[node_index] for node_index in relevant_nodes)
    return tuple(
        payload
        for payload in payloads
        if isinstance(payload, TypeVariableWithAnnotations)
    )


def relevant_classes_from_functions(
    functions_with_annotations: Sequence[FunctionWithAnnotations],
) -> set[ClassName]:
    """Return the distinct ``class_name`` values of the given functions.

    Args:
        functions_with_annotations: Functions whose ``class_name`` attribute
            is read.

    Returns:
        The set of class names; empty when no functions are given.
    """
    return {function.class_name for function in functions_with_annotations}


In [None]:
# Narrow the graph down to the functions and types stored at the relevant nodes.
relevant_functions = relevant_functions_from_graph_nodes(graph, relevant_nodes)
relevant_types = relevant_types_from_graph_nodes(graph, relevant_nodes)
# Derive the classes from the *relevant* functions. Passing the full
# `functions` sequence here would return the classes of every extracted
# function and leave `relevant_functions` unused, so nothing would be filtered.
relevant_classes = relevant_classes_from_functions(relevant_functions)

In [None]:
# Alias used by the prompt-building cells below; it is the same object as
# `module_from_py_file`.
parsed_module = module_from_py_file
# Bare last expression: displays the parsed module as the cell output.
parsed_module

In [None]:
from petritype.core.parse_modules import ExtractClassCode

# One extraction result per relevant class, in the iteration order of the
# `relevant_classes` set.
relevant_classes_code_from_module = [
    ExtractClassCode.from_parsed_module(parsed_module, class_name)
    for class_name in relevant_classes
]


In [None]:
# The `code` attribute of every relevant type, in the same order as `relevant_types`.
relevant_types_code = [relevant_type.code for relevant_type in relevant_types]


In [None]:
# Read the code of petritype's data-structures module so that it can be quoted
# in the prompt as background material.
path_to_data_structures_file = os.path.join("petritype", "core", "data_structures.py")
data_structures_code = ParseModule.from_file(
    path_to_file=path_to_data_structures_file,
    import_path_components=("petritype", "core", "data_structures"),
).code
# Heading line, then the module code wrapped in a Markdown code fence, then
# three newlines separating it from the next prompt section.
data_structures_description = "".join(
    (
        "BACKGROUND: The following data structures are used:\n",
        "```",
        data_structures_code,
        "```",
        "\n\n\n",
    )
)

In [None]:
# Read the code of the module that defines the executable-graph components so
# that it can be quoted in the prompt.
path_to_executable_graph_components = os.path.join(
    "petritype", "core", "executable_graph_components.py"
)
executable_graph_components_code = ParseModule.from_file(
    path_to_file=path_to_executable_graph_components,
    # Must name the module that is actually parsed. The previous value,
    # "flow_graph_components", did not match the file read above or the
    # `petritype.core.executable_graph_components` imports used in this notebook.
    import_path_components=("petritype", "core", "executable_graph_components"),
).code
# Introductory text followed by the module code wrapped in a Markdown code fence.
executable_graph_description = (
    "INTRO: The following code describes the components of a Petritype Executable Graph in our context.\n"
    "A executable graph instance is simply a tuple of instantiated flow graph components.\n"
    "\n\n\n"
    "```" + executable_graph_components_code + "```"
)


In [None]:
from petritype.core.parse_modules import ExtractImportStatements


def format_types(types: Sequence[TypeVariableWithAnnotations]) -> str:
    """Render type declarations as one fenced block for the prompt.

    Args:
        types: Types whose ``code`` attribute is included.

    Returns:
        A fixed introductory line, followed by the code of every type joined
        by three newlines and wrapped in a triple-backtick fence, followed by
        three trailing newlines.
    """
    header = '"""The following type declarations are relevant here."""\n\n'
    declarations = "\n\n\n".join(type_variable.code for type_variable in types)
    return f"{header}```{declarations}```\n\n\n"


def format_module_name(parsed_module: ParsedModule) -> str:
    """Build the prompt line that states where the quoted module code lives.

    Args:
        parsed_module: Module whose ``import_path_components`` are joined
            with ``/`` to form the displayed location.

    Returns:
        The ``MODULE: ...`` line followed by two newlines.
    """
    location = "/".join(parsed_module.import_path_components)
    return "MODULE: The following code exists in " + location + "\n\n"


# Import statements of the example module, quoted in the prompt.
import_statements = ExtractImportStatements.from_parsed_module(parsed_module)
# One statement per line. The separator used to be ".\n", which appended a
# stray period to every statement but the last and made the quoted imports
# invalid Python.
imports = "`" + "\n".join(import_statements) + "\n\n\n" + "`"
description_of_types = format_types(relevant_types)
# NOTE(review): the branch descriptions below hard-code the type names and
# decay parameters of this example; only the first line follows
# `start_type` / `end_type`.
task_description = (
    "TASK:\n"
    f"Propose a Petritype Executable Graph starting at {start_type} and ending at {end_type}.\n"
    "TimeSeriesGeneratingParameters should be copied from the input to subsequent branches.\n"
    "Branch 1: TimeSeriesGeneratingParameters -> TimeSeries -> SeriesStatistics -> ExponentialMovingAverageOfInterval\n"
    "Branch 2: TimeSeriesGeneratingParameters -> TimeSeries -> SeriesStatistics -> ExponentialMovingAverageOfInterval\n"
    "Branch 1 decay_parameter = 0.0002\n"
    "Branch 2 decay_parameter = 0.0001\n"
    "\n\n"
)

# Import line the answer is asked to use instead of redefining classes.
executable_imports = """from petritype.core.executable_graph_components import *"""

# Assemble the prompt: background, component definitions, task, relevant
# types, module location, module imports, relevant class code, and notes.
prompt = (
    data_structures_description
    + executable_graph_description
    + task_description
    + description_of_types
    + "\n\n\n"
    + format_module_name(parsed_module)
    + "\n\n\n"
    + imports
    + "\n\n\n"
    + "```"
    + "\n\n\n".join(relevant_classes_code_from_module)
    + "```"
    + "\n\n\n"
    + "NOTE 0: Branching logic is described by having multiple edges from a transition node to multiple place nodes. "
    + "The token goes to the place or places that match it's type which in turn is determined by the function "
    + "that returned the token.\n"
    + "NOTE 1: do not define new classes or functions or redefine the existing classes or functions in the answer,\n"
    + " use the following imports instead:\n"
    + "`"
    + executable_imports
    + "`"
    + "\n"
    + "NOTE 2: Print code in a single block that it can all be copied in one go.\n"
    + "NOTE 3: Import all the relevant types at the top."
)
print(prompt)

In [None]:
from datetime import datetime, timedelta
import numpy as np
from petritype.core.executable_graph_components import (
    ExecutableGraph, function_transition_node_and_output_places, function_transition_node_and_output_edges,
    ListPlaceNode, ArgumentEdgeToTransition, ReturnedEdgeFromTransition
)

# Import the necessary classes and functions
from examples.time_series_stats.hypothetical_time_series import (
    TimeSeriesGeneratingParameters, TimeSeriesInterval, SimulateData, SeriesStatistics,
    ExponentialMovingAverageOfInterval
)

# NOTE(review): `FunctionTransitionNode` and `ExecutableGraphOperations` are
# used below without an explicit import, so they presumably come from this
# wildcard import. Prefer naming them explicitly once confirmed.
# (The `datetime` and `numpy` imports that were repeated here have been removed.)
from petritype.core.executable_graph_components import *

# Initial place node, seeded with a single TimeSeriesGeneratingParameters token.
# NOTE(review): start/end come from datetime.now(), so the token differs on
# every run even though `seed` is fixed.
initial_parameters = ListPlaceNode(
    name="initial_parameters",
    type=TimeSeriesGeneratingParameters,
    tokens=[TimeSeriesGeneratingParameters(
        start=datetime.now(),
        end=datetime.now() + timedelta(days=1),
        n=100,
        amplitude=1.0,
        shift=0.0,
        noise_std=0.1,
        seed=42
    )]
)

# Transition node that copies TimeSeriesGeneratingParameters; its output is
# wired to the parameter places of both branches in the edge definitions.
# NOTE(review): assumes SimulateData defines `copy_parameters` — confirm
# against the example module.
copy_parameters = FunctionTransitionNode(
    name="Copy Parameters",
    function=SimulateData.copy_parameters
)

# Place nodes to hold the copied TimeSeriesGeneratingParameters for each branch
# (created without a `tokens` argument).
copied_parameters_branch1 = ListPlaceNode(
    name="Branch 1 Parameters",
    type=TimeSeriesGeneratingParameters
)

copied_parameters_branch2 = ListPlaceNode(
    name="Branch 2 Parameters",
    type=TimeSeriesGeneratingParameters
)

# Transition nodes that generate a TimeSeriesInterval from the copied
# parameters, one per branch; both wrap the same function.
generate_time_series_transition_1 = FunctionTransitionNode(
    name="generate_time_series-1",
    function=SimulateData.generate_sine_wave_with_noise_from_parameters
)
generate_time_series_transition_2 = FunctionTransitionNode(
    name="generate_time_series-2",
    function=SimulateData.generate_sine_wave_with_noise_from_parameters
)

# Place nodes to hold the generated TimeSeriesInterval, one per branch
# (they are not shared between the branches).
time_series_interval_place_1 = ListPlaceNode(
    name="time_series_interval-1",
    type=TimeSeriesInterval
)
time_series_interval_place_2 = ListPlaceNode(
    name="time_series_interval-2",
    type=TimeSeriesInterval
)

# Transition nodes that call SeriesStatistics.datetime_interval_ema, one per
# branch; the branches differ only in the `decay_parameter` keyword argument.
calculate_ema_branch1_transition = FunctionTransitionNode(
    name="calculate_ema_branch1",
    function=SeriesStatistics.datetime_interval_ema,
    kwargs={"decay_parameter": 0.0002}
)

calculate_ema_branch2_transition = FunctionTransitionNode(
    name="calculate_ema_branch2",
    function=SeriesStatistics.datetime_interval_ema,
    kwargs={"decay_parameter": 0.0001}
)

# Place nodes to hold the results of Exponential Moving Average for each branch
ema_interval_branch1_place = ListPlaceNode(
    name="ema_interval_branch1",
    type=ExponentialMovingAverageOfInterval
)

ema_interval_branch2_place = ListPlaceNode(
    name="ema_interval_branch2",
    type=ExponentialMovingAverageOfInterval
)

# Define edges between transitions and places. Nodes are referenced by their
# `name` strings, so these must match the names given to the nodes exactly.
# Each argument edge also carries an `argument` name — presumably the parameter
# of the transition's function that receives the token; confirm against
# executable_graph_components.
edges_to_transitions = [
    ArgumentEdgeToTransition(place_node_name="initial_parameters", transition_node_name="Copy Parameters", argument="parameters"),
    ArgumentEdgeToTransition(place_node_name="Branch 1 Parameters", transition_node_name="generate_time_series-1", argument="parameters"),
    ArgumentEdgeToTransition(place_node_name="Branch 2 Parameters", transition_node_name="generate_time_series-2", argument="parameters"),
    ArgumentEdgeToTransition(place_node_name="time_series_interval-1", transition_node_name="calculate_ema_branch1", argument="interval"),
    ArgumentEdgeToTransition(place_node_name="time_series_interval-2", transition_node_name="calculate_ema_branch2", argument="interval")
]

# Return edges. "Copy Parameters" has two of them, one to the parameter place
# of each branch; every other transition has exactly one.
edges_from_transitions = [
    ReturnedEdgeFromTransition(transition_node_name="Copy Parameters", place_node_name="Branch 1 Parameters"),
    ReturnedEdgeFromTransition(transition_node_name="Copy Parameters", place_node_name="Branch 2 Parameters"),
    ReturnedEdgeFromTransition(transition_node_name="generate_time_series-1", place_node_name="time_series_interval-1"),
    ReturnedEdgeFromTransition(transition_node_name="generate_time_series-2", place_node_name="time_series_interval-2"),
    ReturnedEdgeFromTransition(transition_node_name="calculate_ema_branch1", place_node_name="ema_interval_branch1"),
    ReturnedEdgeFromTransition(transition_node_name="calculate_ema_branch2", place_node_name="ema_interval_branch2")
]

# Instantiate the executable graph from all places, transitions and edges.
executable_graph = ExecutableGraph(
    places=[
        initial_parameters,
        copied_parameters_branch1,
        copied_parameters_branch2,
        time_series_interval_place_1,
        time_series_interval_place_2,
        ema_interval_branch1_place,
        ema_interval_branch2_place
    ],
    transitions=[
        copy_parameters,
        generate_time_series_transition_1,
        generate_time_series_transition_2,
        calculate_ema_branch1_transition,
        calculate_ema_branch2_transition
    ],
    argument_edges=edges_to_transitions,
    return_edges=edges_from_transitions
)

print(executable_graph)


In [None]:
from petritype.core.rustworkx_graph import RustworkxGraph


# Convert the executable graph into the graph object that the drawing cells
# below pass to `graphviz_draw`.
executable_pydigraph = RustworkxGraph.from_executable_graph(executable_graph)

In [None]:
from rustworkx.visualization import graphviz_draw



def place_node_label(node: ListPlaceNode) -> str:
    """Build the multi-line drawing label for a place node.

    Args:
        node: Place node whose ``name``, ``type`` and ``tokens`` are shown.

    Returns:
        The node name, the name of its type in parentheses on the next line,
        and then the ``str()`` of every token, one per line. With no tokens
        the label ends in a newline.
    """
    header = f"{node.name}\n({node.type.__name__})"
    token_lines = "\n".join(map(str, node.tokens))
    return header + "\n" + token_lines


def transition_node_label(node: FunctionTransitionNode) -> str:
    """Build the two-line drawing label for a transition node.

    Args:
        node: Transition node whose ``name`` and ``function`` are shown.

    Returns:
        The node name, and on the next line the ``__qualname__`` of the
        wrapped function in parentheses.
    """
    display_name = node.name
    qualified_function_name = node.function.__qualname__
    return f"{display_name}\n({qualified_function_name})"


def flow_node_attr_fn(node):
    """Return the Graphviz node attributes for a place or transition node.

    Place nodes are drawn as filled ``deepskyblue`` ovals, transition nodes
    as filled ``lightgreen`` boxes.

    Args:
        node: A ``ListPlaceNode`` or ``FunctionTransitionNode``.

    Returns:
        A dict with the keys ``label``, ``color``, ``style`` and ``shape``.

    Raises:
        ValueError: If ``node`` is neither of the two supported node types.
    """
    if isinstance(node, ListPlaceNode):
        label, fill_color, shape = place_node_label(node), 'deepskyblue', 'oval'
    elif isinstance(node, FunctionTransitionNode):
        label, fill_color, shape = transition_node_label(node), 'lightgreen', 'box'
    else:
        raise ValueError("Invalid node data type.")
    return {
        "label": label,
        'color': fill_color,
        'style': 'filled',
        'shape': shape,
    }


# Static drawing of the executable graph before any transition has fired.
# The call is the last expression of the cell, so the image is the cell output.
graphviz_draw(
    executable_pydigraph,
    node_attr_fn=flow_node_attr_fn,
    # No edge_attr_fn is passed here, so edges get no custom attributes.
    method='dot',
)

In [None]:
import time
import matplotlib.pyplot as plt
from IPython.display import display, clear_output


for i in range(100):
    print(i)

    # Take one processing step.
    try:
        _, transitions_fired = await ExecutableGraphOperations.execute_graph(
            executable_graph=executable_graph,
            max_transitions=1,
            allow_token_copying=True,
            verbose=True,
        )
    except Exception as e:
        print(f"Error: {e}")
        import pdb; pdb.set_trace()
    
    time.sleep(1)
    clear_output(wait=True)
    node_attr_fn, edge_attr_fn = RustworkxToGraphviz.activation_coloured_attr_functions(executable_graph)
    diagram = graphviz_draw(
        executable_pydigraph,
        node_attr_fn=node_attr_fn,
        edge_attr_fn=edge_attr_fn,
        method='dot',
    )
    display(diagram)
    time.sleep(1)

    if not transitions_fired:
        break

    plt.close()


In [None]:
# Fetch the tokens held by the final place of each branch.
ema_interval_tokens_branch1, ema_interval_tokens_branch2 = (
    executable_graph.place_named(place_name).tokens
    for place_name in ("ema_interval_branch1", "ema_interval_branch2")
)
# Report branch 1 results first, then branch 2.
all_ema_interval_tokens = ema_interval_tokens_branch1 + ema_interval_tokens_branch2
for token in all_ema_interval_tokens:
    print(token.summary())
    token.plot_time_series()
    