In [2]:
import asyncio
import logging
import os
import sys
import time

# Make the primeGraph package in the parent directory importable.
# NOTE(review): assumes the notebook is run from one level below the package root — confirm.
sys.path.append(os.path.abspath(".."))

from rich import print as rprint

from primeGraph import Graph, START, END
from primeGraph.buffer import History
from primeGraph.checkpoint.local_storage import LocalStorage
from primeGraph.graph.engine import GraphExecutor
from primeGraph.models import GraphState

LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

# Root logger at DEBUG with a timestamped, logger-named format.
# (basicConfig is a no-op if the root logger already has handlers.)
logging.basicConfig(format=LOG_FORMAT, level=logging.DEBUG)

class StateForTestWithHistory(GraphState):
    """Demo graph state with a single ``execution_order`` field typed ``History[str]``.

    Each node built by ``test_graph`` returns ``{"execution_order": "<node name>"}``.
    NOTE(review): presumably ``History`` accumulates those returned values, so the field
    records the order in which nodes ran — confirm against ``primeGraph.buffer.History``.
    """
    execution_order: History[str]


def test_graph(storage=None):
    """Build and compile the six-node demo graph used by the cells below.

    Topology: START -> task1, which fans out into two parallel branches
    (task2 -> task4 and task3 -> task5) that rejoin at task6 -> END.
    task2 is registered with ``interrupt="after"``, so a run pauses once it finishes.
    Every node sleeps for a fixed time, prints its own name, and returns
    ``{"execution_order": <its name>}``.

    Args:
        storage: Optional checkpoint storage, forwarded to ``Graph`` as
            ``checkpoint_storage`` (``None`` by default).

    Returns:
        The compiled ``Graph``.
    """
    initial_state = StateForTestWithHistory(execution_order=[])
    graph = Graph(state=initial_state, checkpoint_storage=storage)

    # Node bodies stay as plain sleep/print/return literals so the node decorator
    # sees exactly the same functions it always has.
    @graph.node()
    def task1(state):
        time.sleep(1)
        print("task1")
        return {"execution_order": "task1"}

    # Pause point: execution stops after this node until it is resumed.
    @graph.node(interrupt="after")
    def task2(state):
        time.sleep(2)
        print("task2")
        return {"execution_order": "task2"}

    @graph.node()
    def task3(state):
        time.sleep(2)
        print("task3")
        return {"execution_order": "task3"}

    @graph.node()
    def task4(state):
        time.sleep(1)
        print("task4")
        return {"execution_order": "task4"}

    @graph.node()
    def task5(state):
        time.sleep(1)
        print("task5")
        return {"execution_order": "task5"}

    @graph.node()
    def task6(state):
        time.sleep(1)
        print("task6")
        return {"execution_order": "task6"}

    # Wiring, in the same order the edges are registered.
    edges = (
        (START, "task1"),
        ("task1", "task2"),  # fan out into two parallel branches
        ("task1", "task3"),
        ("task2", "task4"),
        ("task3", "task5"),
        ("task4", "task6"),  # both branches rejoin at task6
        ("task5", "task6"),
        ("task6", END),
    )
    for source, target in edges:
        graph.add_edge(source, target)

    graph.compile()
    return graph


In [None]:
# Build the demo graph with a LocalStorage checkpoint store and render its topology.
graph = test_graph(storage=LocalStorage())
graph.visualize()

In [None]:
import asyncio

# In your notebook
executor = GraphExecutor(graph)
task = asyncio.create_task(executor.execute())  # Start execution

# Wait a bit for it to hit the interrupt
await asyncio.sleep(10)  

# Then call resume to continue
# executor.resume()

In [4]:
# Resume past the interrupt declared on task2.
# NOTE(review): the background `task` from the previous cell is never awaited, so any
# exception it raises will not surface here — confirm resume() alone lets that run finish.
executor.resume()

In [None]:
# Pretty-print the raw contents of the checkpoint storage.
# NOTE(review): `_storage` is a private attribute and dumping all of it can get large —
# prefer a public accessor or a slice if the storage class offers one.
rprint(graph.checkpoint_storage._storage)

In [4]:
# Rebuild an executor from the most recently saved checkpoint.
# NOTE(review): this reaches into the storage's private `_storage` attribute. Judging by how
# it is indexed here it looks like {chain_id: {checkpoint_id: checkpoint}} in insertion
# order — confirm against primeGraph.checkpoint.local_storage and prefer a public API.
checkpoints_by_chain = graph.checkpoint_storage._storage
if not checkpoints_by_chain:
    raise RuntimeError("No checkpoints have been saved yet; run the graph first.")

# NOTE(review): takes the first stored chain — fine after a single run on fresh storage,
# ambiguous if several chains share this storage.
first_chain_id = next(iter(checkpoints_by_chain))
chain_checkpoints = checkpoints_by_chain[first_chain_id]
last_checkpoint = list(chain_checkpoints)[-1]
last_checkpoint_state = chain_checkpoints[last_checkpoint]

executor = GraphExecutor(graph)
executor.load_full_state(last_checkpoint_state.engine_state)

In [None]:
# Clear the pending interrupt on the checkpoint-loaded executor, then run it again.
# NOTE(review): assumes execute() continues from the loaded engine state rather than
# restarting at START — confirm against GraphExecutor.
executor.resume()
await executor.execute()

In [None]:
# Display the graph's chain status after the resumed run.
graph.chain_status