In [None]:
import os
from typing import Sequence
from rustworkx.visualization import graphviz_draw

from petritype.core.ast_extraction import FunctionWithAnnotations
from petritype.core.data_structures import TypeVariableWithAnnotations
from petritype.core.parse_modules import (
    ParseModule, ExtractFunctions, ExtractTypes,
)
from petritype.plotting.rustworkx_to_graphviz import RustworkxToGraphviz

In [None]:
"""Read types and functions from the example python file."""

from petritype.core.relationship_graph_components import (
    FunctionToTypeEdges, RelationshipEdges, TypeToFunctionEdges, TypeToTypeEdges,
)



# Parse the example module, pull out its functions and types, then derive the
# three kinds of relationship edges (type->type, type->function, function->type).
# The path is relative to the current working directory; this assumes the
# notebook runs from the repository root (the later `examples.` import relies
# on that too).
path_components = ("examples", "ml_model", "hypothetical_training_steps.py")

RELEVANT_CLASSES = ()  # no classes are selected for function extraction

# NOTE(review): the same tuple, ".py" suffix included, is passed as the import
# path components — confirm ParseModule.from_file expects the extension here.
module_from_py_file = ParseModule.from_file(
    path_to_file=os.path.join(*path_components),
    import_path_components=path_components,
)
parsed_modules = (module_from_py_file,)

functions: Sequence[FunctionWithAnnotations] = ExtractFunctions.from_selected_classes_in_parsed_modules(
    parsed_modules=parsed_modules,
    selected_classes=RELEVANT_CLASSES,
)
types: Sequence[TypeVariableWithAnnotations] = ExtractTypes.from_parsed_modules(
    parsed_modules=parsed_modules,
)

edges_type_to_type: TypeToTypeEdges = RelationshipEdges.type_to_type(types)
edges_type_to_function: TypeToFunctionEdges = RelationshipEdges.type_to_function(types, functions)
edges_function_to_type: FunctionToTypeEdges = RelationshipEdges.function_to_type(functions, types)

# Execution Graph

Procedure for selecting the relevant part of the relationship graph:

1. Find all shortest paths from the start nodes to the end nodes.
2. For every node on each path, add all of its adjacent nodes to the set of relevant nodes.
3. For every function node in that set, add all of its adjacent nodes to the set as well.
4. Use the resulting set of relevant nodes to extract the code for the corresponding types and functions.



In [None]:
# Build the rustworkx digraph of types, functions and their relationship edges.
# The call returns four values; they are unpacked into the graph itself plus the
# lookup/edge collections named below (presumably name -> node-index maps, going
# by their names — confirm against RustworkxToGraphviz.digraph).
digraph_components = RustworkxToGraphviz.digraph(
    types=types,
    functions=functions,
    edges_type_to_type=edges_type_to_type,
    edges_type_to_function=edges_type_to_function,
    edges_function_to_type=edges_function_to_type,
)
graph, type_names_to_node_indices, function_names_to_node_indices, type_relationship_edges = (
    digraph_components
)

In [None]:
from petritype.core.executable_graph_components import *

from examples.ml_model.hypothetical_training_steps import *


# The fine-tuning workflow, declared as a flat tuple of:
#   - place nodes (ListPlaceNode), each with a name and a type;
#   - transition nodes (FunctionTransitionNode), each wrapping a workflow function;
#   - ArgumentEdgeToTransition edges, which connect a place to a transition and
#     name the function argument that place supplies;
#   - ReturnedEdgeFromTransition edges, which connect a transition to a place
#     that receives its result.
# Edges refer to nodes only by their name strings, so each name must be spelled
# identically everywhere it appears.
fine_tuning_nodes_and_edges = (
    # Split and Augment Data
    ListPlaceNode(name="Available Data", type=AvailableData),
    ArgumentEdgeToTransition(
        place_node_name="Available Data",
        transition_node_name="Train/Validation/Test Split",
        argument="data",
    ),
    FunctionTransitionNode(name="Train/Validation/Test Split", function=train_validation_test_split),
    # NOTE(review): one transition with three returned edges. How the returned
    # value is distributed among these places is decided by the graph
    # machinery, not here — confirm it matches train_validation_test_split's return.
    ReturnedEdgeFromTransition(
        transition_node_name="Train/Validation/Test Split",
        place_node_name="Training Data",
    ),
    ReturnedEdgeFromTransition(
        transition_node_name="Train/Validation/Test Split",
        place_node_name="Validation Data",
    ),
    ReturnedEdgeFromTransition(
        transition_node_name="Train/Validation/Test Split",
        place_node_name="Test Data",
    ),
    ListPlaceNode(name="Training Data", type=TrainingData),
    # NOTE(review): validation and test places reuse the TrainingData type — confirm intended.
    ListPlaceNode(name="Validation Data", type=TrainingData),
    ListPlaceNode(name="Test Data", type=TrainingData),
    ArgumentEdgeToTransition(
        place_node_name="Training Data",
        transition_node_name="Data Augmentation",
        argument="data",
    ),
    FunctionTransitionNode(name="Data Augmentation", function=data_augmentation),
    ReturnedEdgeFromTransition(
        transition_node_name="Data Augmentation",
        place_node_name="Augmented Training Data",
    ),
    ListPlaceNode(name="Augmented Training Data", type=TrainingData),

    # Fine-tune Model
    # This place is also fed back by "Select Another Model or Accept Current
    # Model" below, which closes the model-selection loop.
    ListPlaceNode(name="Proposed Model & Hyperparameters", type=VisionModel),
    ArgumentEdgeToTransition(
        place_node_name="Proposed Model & Hyperparameters",
        transition_node_name="Fine-tuning",
        argument="model",
    ),
    ArgumentEdgeToTransition(
        place_node_name="Augmented Training Data",
        transition_node_name="Fine-tuning",
        argument="data",
    ),
    FunctionTransitionNode(name="Fine-tuning", function=fine_tune),
    ReturnedEdgeFromTransition(
        transition_node_name="Fine-tuning",
        place_node_name="Fine-tuned Model",
    ),
    ListPlaceNode(name="Fine-tuned Model", type=VisionModel),

    # Evaluate Model
    ArgumentEdgeToTransition(
        place_node_name="Fine-tuned Model",
        transition_node_name="Evaluate Model with Validation Data",
        argument="model",
    ),
    ArgumentEdgeToTransition(
        place_node_name="Validation Data",
        transition_node_name="Evaluate Model with Validation Data",
        argument="data",
    ),
    FunctionTransitionNode(name="Evaluate Model with Validation Data", function=evaluate_model),
    ListPlaceNode(
        name="Evaluation Metrics from Validation Data",
        type=EvaluationMetrics,
    ),
    ListPlaceNode(
        name="Accepted Model",
        type=VisionModel,
    ),
    # Disabled alternative wiring: validation evaluation returning straight
    # into "Accepted Model".
    # ReturnedEdgeFromTransition(
    #     transition_node_name="Evaluate Model with Validation Data",
    #     place_node_name="Accepted Model",
    # ),
    ReturnedEdgeFromTransition(
        transition_node_name="Evaluate Model with Validation Data",
        place_node_name="Evaluation Metrics from Validation Data",
    ),
    # Maybe try another model.
    # NOTE(review): this place holds EvaluationMetrics but is bound to an
    # argument named "model" — confirm against select_another_model's signature.
    ArgumentEdgeToTransition(
        place_node_name="Evaluation Metrics from Validation Data",
        transition_node_name="Select Another Model or Accept Current Model",
        argument="model",
    ),
    FunctionTransitionNode(name="Select Another Model or Accept Current Model", function=select_another_model),
    # NOTE(review): this transition's only declared input is the metrics place,
    # yet it returns into two VisionModel places — confirm how the model itself
    # reaches it.
    ReturnedEdgeFromTransition(
        transition_node_name="Select Another Model or Accept Current Model",
        place_node_name="Proposed Model & Hyperparameters",
    ),
    ReturnedEdgeFromTransition(
        transition_node_name="Select Another Model or Accept Current Model",
        place_node_name="Accepted Model",
    ),
    # Evaluate the model on the test data.
    ArgumentEdgeToTransition(
        place_node_name="Accepted Model",
        transition_node_name="Final Model Evaluation",
        argument="model",
    ),
    ArgumentEdgeToTransition(
        place_node_name="Test Data",
        transition_node_name="Final Model Evaluation",
        argument="data",
    ),
    FunctionTransitionNode(name="Final Model Evaluation", function=evaluate_model),
    # NOTE(review): evaluate_model is also used for validation above, where it
    # only returns into an EvaluationMetrics place. Here it additionally returns
    # into a VisionModel place — confirm this is intended.
    ReturnedEdgeFromTransition(
        transition_node_name="Final Model Evaluation",
        place_node_name="Evaluation Metrics from Test Data",
    ),
    ReturnedEdgeFromTransition(
        transition_node_name="Final Model Evaluation",
        place_node_name="Final Selected Model",
    ),
    ListPlaceNode(
        name="Evaluation Metrics from Test Data",
        type=EvaluationMetrics,
    ),
    ListPlaceNode(
        name="Final Selected Model",
        type=VisionModel,
    ),
    # Package the model for deployment.
    ArgumentEdgeToTransition(
        place_node_name="Final Selected Model",
        transition_node_name="Package For Deployment",
        argument="model",
    ),
    FunctionTransitionNode(name="Package For Deployment", function=package_model),
    ReturnedEdgeFromTransition(
        transition_node_name="Package For Deployment",
        place_node_name="Packaged Model",
    ),
    ListPlaceNode(
        name="Packaged Model",
        type=VisionModel,
    ),
    # Carry out final integration tests.
    ArgumentEdgeToTransition(
        place_node_name="Packaged Model",
        transition_node_name="Integration Testing",
        argument="model",
    ),
    # Terminal transition: no returned edge is declared for it.
    FunctionTransitionNode(name="Integration Testing", function=integration_testing),
)
# Build the executable graph object from the declarations above.
executable_fine_tuning_graph = ExecutableGraphOperations.construct_graph(fine_tuning_nodes_and_edges)

In [None]:
from petritype.core.rustworkx_graph import RustworkxGraph


# Convert the executable graph into a rustworkx graph so that graphviz_draw can render it.
executable_fine_tuning_pydigraph = RustworkxGraph.from_executable_graph(
    executable_fine_tuning_graph
)

In [None]:
def place_node_label_sans_type(node: "ListPlaceNode") -> str:
    """Build the Graphviz label for a place node: its name, then one line per token.

    The node's type is left out of the label (hence ``sans_type``).

    Args:
        node: Place node to label; only its ``name`` and ``tokens`` are read.

    Returns:
        ``node.name`` followed by ``str(token)`` for each token, one per line.
        A place with no tokens is labelled with just its name.
    """
    # Joining the name and the tokens in one pass avoids a trailing "\n" (an extra
    # empty row in the drawn node) when the place holds no tokens.
    label_lines = [f"{node.name}", *(str(token) for token in node.tokens)]
    return "\n".join(label_lines)


def transition_node_label_sans_type(node: FunctionTransitionNode) -> str:
    """Return the Graphviz label for a transition node, which is just its name.

    The wrapped function's name is not included in the label (hence ``sans_type``).
    """
    # format(x) performs the same conversion as the f-string f"{x}".
    return format(node.name)


def flow_node_attr_fn(node):
    """Map an executable-graph node to its Graphviz node attributes.

    Used as ``node_attr_fn`` for ``graphviz_draw``. Place nodes are drawn as
    filled deep-sky-blue ovals and transition nodes as filled light-green boxes.

    Raises:
        ValueError: if ``node`` is neither a ListPlaceNode nor a FunctionTransitionNode.
    """
    if isinstance(node, ListPlaceNode):
        label, color, shape = place_node_label_sans_type(node), "deepskyblue", "oval"
    elif isinstance(node, FunctionTransitionNode):
        label, color, shape = transition_node_label_sans_type(node), "lightgreen", "box"
    else:
        raise ValueError("Invalid node data type.")
    return {"label": label, "color": color, "style": "filled", "shape": shape}



# Render the place/transition graph with Graphviz's hierarchical "dot" layout.
# No edge_attr_fn is passed, so edges are drawn without custom attributes.
graphviz_draw(
    executable_fine_tuning_pydigraph,
    method="dot",
    node_attr_fn=flow_node_attr_fn,
)