<a href="https://colab.research.google.com/github/nithinivi/alchemy/blob/main/ATN_for_Attribute_Based_Graph_Lookup.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import sys

# A simple, dictionary-based graph representation.
# Each key is a node, and its value is a dictionary of its attributes.
# The 'connections' key within a node's attributes maps to a list of
# tuples: (relationship_type, target_node_id).
SIMPLE_GRAPH = {
    'person_1': {
        'type': 'Person', 'name': 'Alice', 'location': 'New York',
        'connections': [
            ('lives_in', 'city_1'),
            ('works_at', 'company_1')
        ]
    },
    'person_2': {
        'type': 'Person', 'name': 'Bob', 'location': 'Boston',
        'connections': [
            ('lives_in', 'city_2'),
            ('works_at', 'company_2')
        ]
    },
    'person_3': {
        'type': 'Person', 'name': 'Charlie', 'location': 'New York',
        'connections': [
            ('lives_in', 'city_1'),
            ('works_at', 'company_2')
        ]
    },
    'city_1': {
        'type': 'City', 'name': 'New York', 'state': 'NY',
        'connections': []
    },
    'city_2': {
        'type': 'City', 'name': 'Boston', 'state': 'MA',
        'connections': []
    },
    'company_1': {
        'type': 'Company', 'name': 'TechCorp', 'industry': 'Tech',
        'connections': []
    },
    'company_2': {
        'type': 'Company', 'name': 'DataCorp', 'industry': 'Data Science',
        'connections': []
    }
}

In [3]:

class ATNNode:
    """Represents a state in the ATN."""
    def __init__(self, name):
        self.name = name
        self.transitions = []

class ATNTransition:
    """Represents a transition (arc) in the ATN."""
    def __init__(self, next_state, arc_type, arc_label=None, test=None, action=None):
        self.next_state = next_state
        self.arc_type = arc_type
        self.arc_label = arc_label
        self.test = test if test else lambda node, registers: True
        self.action = action if action else lambda node, registers: None

In [4]:



class ATN:
    """
    An Augmented Transition Network for attribute-based graph lookup.

    This class manages the states, transitions, registers, and the traversal logic.
    """
    def __init__(self, graph):
        self.graph = graph
        self.nodes = {}
        self.registers = {}
        self.result = []
        self.path = [] # To store the nodes visited

    def add_node(self, name):
        """Adds a state (ATNNode) to the network."""
        self.nodes[name] = ATNNode(name)
        return self.nodes[name]

    def add_transition(self, from_node_name, next_state_name, arc_type, **kwargs):
        """Adds a transition (arc) between states."""
        from_node = self.nodes.get(from_node_name)
        next_state = self.nodes.get(next_state_name)
        if not from_node or not next_state:
            raise ValueError("Invalid node names for transition")

        transition = ATNTransition(next_state, arc_type, **kwargs)
        from_node.transitions.append(transition)

    def traverse(self, start_node_id, initial_atn_state_name):
        """
        Starts the traversal from a given graph node and ATN state.

        This uses a recursive, depth-first approach with backtracking.
        """
        self.path.append(start_node_id)
        current_atn_node = self.nodes.get(initial_atn_state_name)

        if not current_atn_node:
            print("Error: Initial ATN state not found.")
            return

        try:
            self._traverse_recursive(start_node_id, current_atn_node)
        except RecursionError:
            print("Recursion depth exceeded. Check for infinite loops in your ATN.")
            sys.setrecursionlimit(1000) # Reset to default
        finally:
            return self.result

    def _traverse_recursive(self, current_graph_node_id, current_atn_node):
        """The recursive engine for the ATN traversal."""

        # Base Case: If we've reached a final ATN state
        if current_atn_node.name == 'FinalState':
            # Check for a complete path in the graph
            if len(self.path) > 1: # Ensures we've moved from the start
                self.result.append(list(self.path))
                return True
            else:
                return False

        # Get the current graph node's data
        current_graph_node = self.graph.get(current_graph_node_id)
        if not current_graph_node:
            return False

        # Iterate through all possible transitions from the current ATN state
        for transition in current_atn_node.transitions:

            # --- PUSH Transition (Moves to a sub-network) ---
            if transition.arc_type == 'PUSH':
                # The 'push' is to a sub-network defined in another part of the ATN
                sub_atn = transition.arc_label
                if sub_atn not in self.nodes:
                    print(f"Error: PUSH to undefined sub-network '{sub_atn}'.")
                    continue

                # Recursively call the sub-network traversal
                self.path.append(f"PUSH:{sub_atn}")
                if self._traverse_recursive(current_graph_node_id, self.nodes[sub_atn]):
                    self.path.pop()
                    # If sub-network succeeds, move to the next state after the PUSH
                    if self._traverse_recursive(current_graph_node_id, transition.next_state):
                        return True
                self.path.pop()

            # --- JUMP Transition (Skip without consuming a node) ---
            elif transition.arc_type == 'JUMP':
                self.path.append(f"JUMP:{transition.next_state.name}")
                if self._traverse_recursive(current_graph_node_id, transition.next_state):
                    self.path.pop()
                    return True
                self.path.pop()

            # --- Node-based Transitions (CAT, WORD, etc.) ---
            elif transition.arc_type in ['CAT', 'WORD']:
                # Iterate through the outgoing connections of the current graph node
                for rel_type, next_node_id in current_graph_node.get('connections', []):
                    next_graph_node = self.graph.get(next_node_id)

                    # 1. Check if the connection matches the arc label (e.g., 'works_at')
                    if rel_type != transition.arc_label:
                        continue

                    # 2. Evaluate the Test condition
                    if not transition.test(next_graph_node, self.registers):
                        continue

                    # 3. Execute the Action to update registers
                    transition.action(next_graph_node, self.registers)

                    # 4. Recursively call for the next state and next graph node
                    self.path.append(next_node_id)
                    if self._traverse_recursive(next_node_id, transition.next_state):
                        return True
                    self.path.pop() # Backtrack

        # No successful path found from this state
        return False


In [5]:
def setup_atn():
    """Defines the ATN grammar for the graph query."""
    atn = ATN(SIMPLE_GRAPH)

    # 1. Define the main states
    start = atn.add_node('Start')
    person_node = atn.add_node('FindPerson')
    company_node = atn.add_node('FindCompany')
    final = atn.add_node('FinalState')

    # Define a sub-network for finding a city based on its state attribute
    find_city = atn.add_node('FindCity')
    city_by_state = atn.add_node('CityByState')

    # 2. Define the main network transitions

    # Start -> FindPerson (Test on 'type' attribute)
    atn.add_transition(
        'Start', 'FindPerson', arc_type='JUMP',
        test=lambda node, registers: node.get('type') == 'Person'
    )

    # FindPerson -> FindCompany (Traversal on 'works_at' relationship)
    # Action: Store the person's location in a register for later use.
    atn.add_transition(
        'FindPerson', 'FindCompany', arc_type='CAT', arc_label='works_at',
        action=lambda node, registers: registers.update({'person_location': node.get('location')})
    )

    # FindCompany -> FinalState (Traversal on 'lives_in' relationship, with a Test)
    # Test: The company's location must match the person's stored location.
    atn.add_transition(
        'FindCompany', 'FinalState', arc_type='CAT', arc_label='lives_in',
        test=lambda node, registers: node.get('name') == registers.get('person_location')
    )

    # 3. Define the sub-network transitions

    # This is a PUSH transition example
    # FindPerson -> FindCity (PUSH)
    atn.add_transition(
        'FindPerson', 'FindCity', arc_type='PUSH', arc_label='FindCity'
    )

    # FindCity -> CityByState (Traverse on 'lives_in')
    atn.add_transition(
        'FindCity', 'CityByState', arc_type='CAT', arc_label='lives_in'
    )

    # CityByState -> FinalState (POP)
    # The POP arc would ideally return the found city, but for simplicity, we just
    # go to the FinalState.
    atn.add_transition(
        'CityByState', 'FinalState', arc_type='JUMP'
    )

    return atn

def main():
    """Main execution block."""
    print("Graph Traversal using an Augmented Transition Network (ATN)")
    print("----------------------------------------------------------")
    print("Goal: Find a path from a 'Person' node, following a 'works_at' relationship,")
    print("      to a 'Company' node, then to a 'City' node with the same name as")
    print("      the person's location.")
    print("Query: Person -> (works_at) -> Company -> (lives_in) -> City(name=Person.location)")

    # Instantiate and set up the ATN with the query logic
    atn = setup_atn()

    # Find a starting node in the graph
    start_node_id = 'person_1'
    print(f"\nStarting traversal from graph node: '{start_node_id}'")

    # Start the traversal
    result_paths = atn.traverse(start_node_id, 'Start')

    print("\nSearch complete. Found paths:")
    if result_paths:
        for path in result_paths:
            print(f" -> {path}")
    else:
        print("No valid paths found.")

In [6]:
main()

Graph Traversal using an Augmented Transition Network (ATN)
----------------------------------------------------------
Goal: Find a path from a 'Person' node, following a 'works_at' relationship,
      to a 'Company' node, then to a 'City' node with the same name as
      the person's location.
Query: Person -> (works_at) -> Company -> (lives_in) -> City(name=Person.location)

Starting traversal from graph node: 'person_1'

Search complete. Found paths:
 -> ['person_1', 'JUMP:FindPerson', 'PUSH:FindCity', 'city_1', 'JUMP:FinalState']
 -> ['person_1', 'JUMP:FindPerson', 'PUSH:FindCity', 'city_1', 'JUMP:FinalState']
