In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import networkx as nx
from tqdm import tqdm
import json
import random
from collections import defaultdict
from typing import Dict, List

plt.style.use('ggplot')

# Root folder; each of its sub-directories is scanned for tree files.
DATA_DIR = '/home/cc/PHD/ragkg/new_trees'

# Keep only real directories: calling os.listdir() on a stray file in DATA_DIR
# would raise NotADirectoryError. Sorting makes the processing order reproducible,
# because os.listdir() returns entries in arbitrary order.
ALL_DIRS = sorted(
    path
    for path in (os.path.join(DATA_DIR, name) for name in os.listdir(DATA_DIR))
    if os.path.isdir(path)
)

# Keep only regular files so nested folders are not handed to open() later on.
ALL_FILES = sorted(
    path
    for directory in ALL_DIRS
    for path in (os.path.join(directory, name) for name in os.listdir(directory))
    if os.path.isfile(path)
)

In [2]:
def parse_tree_to_network(json_data):
    """
    Parse a nested JSON tree structure into nodes and edges lists suitable for NetworkX.

    The tree may be passed directly, or wrapped as ``{'decision tree': tree}``
    where ``tree`` is either a dict or a JSON-encoded string of that dict.

    Each tree node is a dict that may carry:
        - 'node': the node identifier (falls back to 'content' when missing/None)
        - 'content': stored as the node's 'content' attribute (None when missing)
        - 'children': a list of child node dicts (missing or None means no children)

    Args:
        json_data (dict): The JSON tree structure

    Returns:
        tuple: (nodes, edges) where:
            - nodes is a list of (node_id, attr_dict) tuples, in pre-order
            - edges is a list of (source, target, attr_dict) tuples, in the
              order the target nodes are visited

    Raises:
        TypeError: If a tree node is not a dict.
    """
    # Unwrap the optional {'decision tree': ...} envelope.
    if isinstance(json_data, dict) and 'decision tree' in json_data:
        json_data = json_data['decision tree']
        if isinstance(json_data, str):
            json_data = json.loads(json_data)

    nodes = []
    edges = []

    # Explicit stack instead of recursion, so very deep trees cannot hit the
    # interpreter's recursion limit. Each entry is (node_dict, parent_id).
    pending = [(json_data, None)]
    while pending:
        node_data, parent = pending.pop()
        if not isinstance(node_data, dict):
            raise TypeError(
                f"Expected a dict for each tree node, got {type(node_data).__name__}"
            )

        content = node_data.get('content')
        node_id = node_data.get('node')
        if node_id is None:
            # Nodes without an explicit id are identified by their content.
            node_id = content

        nodes.append((node_id, {'content': content}))

        # If this node has a parent, create an edge
        if parent is not None:
            edges.append((parent, node_id, {}))

        # 'children' may be absent or explicitly null in the JSON.
        children = node_data.get('children') or []
        # Push in reverse so children are popped (visited) in their original order.
        pending.extend((child, node_id) for child in reversed(list(children)))

    return nodes, edges

In [3]:
def find_source_node(graph: nx.DiGraph) -> str:
    """
    Return the single node of ``graph`` that has no incoming edges.

    Args:
        graph (nx.DiGraph): Directed graph to inspect.

    Returns:
        str: The only node whose in-degree is 0.

    Raises:
        ValueError: If no node has in-degree 0, or if more than one does.
    """
    roots = []
    for candidate in graph.nodes():
        if graph.in_degree(candidate) == 0:
            roots.append(candidate)

    # Exactly one root is the only valid outcome.
    if len(roots) == 1:
        return roots[0]

    if roots:
        raise ValueError(f"Multiple source nodes found: {roots}. Expected exactly one source node.")

    raise ValueError("No source nodes found in the graph")

# Module-level placeholder for the root node; the per-file processing loop later in
# this notebook resets it to "" and then assigns the result of find_source_node().
source = ""

# Example usage of find_source_node, kept commented out because it needs an
# existing graph `G`:
# try:
#     source = find_source_node(G)
#     print(f"Single source node: {source}")
# except ValueError as e:
#     print(f"Error: {e}")

In [4]:
def get_random_path_to_leaf(graph: nx.DiGraph, root: str) -> list:
    """
    Walk from ``root`` down to a leaf, picking one successor at random at each step.

    Args:
        graph (nx.DiGraph): Directed graph to walk.
        root (str): Node the walk starts from.

    Returns:
        list: The visited nodes in order, starting with ``root`` and ending at a
            node that has no successors.

    Raises:
        ValueError: If ``root`` is not a node of ``graph``.
    """
    if root not in graph:
        raise ValueError(f"Root node {root} not found in graph")

    walk = [root]
    options = list(graph.successors(root))

    # An empty successor list means the last node added is a leaf.
    while options:
        step = random.choice(options)
        walk.append(step)
        options = list(graph.successors(step))

    return walk

def get_all_paths_to_leaves(graph: nx.DiGraph, source: str) -> dict:
    """
    Extract all simple paths from source to every leaf, showing each node by its 'content'.

    A leaf is any node with out-degree 0. Every node on a path is represented by
    its 'content' attribute, or by the node name when 'content' is missing, empty
    or None.

    Args:
        graph (nx.DiGraph): The input directed graph where nodes may have a 'content' property.
        source (str): Starting node (source).

    Returns:
        dict: Maps each leaf representation (content or name) to the list of
              paths reaching it. Each path is a list of node representations.
              Distinct leaves that share the same representation share one key,
              and all of their paths are collected under it. Leaves that
              ``nx.all_simple_paths`` yields no path to are omitted.

    Raises:
        ValueError: If source node is not in graph.
    """
    if source not in graph:
        raise ValueError(f"Source node {source} not found in graph")

    # Helper function to get 'content' or fall back to node name
    def get_node_representation(node):
        content = graph.nodes[node].get('content')
        return content if content else node

    leaf_nodes = [node for node in graph.nodes() if graph.out_degree(node) == 0]
    paths_by_leaf = defaultdict(list)

    for leaf in leaf_nodes:
        formatted_paths = [
            [get_node_representation(node) for node in path]
            for path in nx.all_simple_paths(graph, source, leaf)
        ]
        if formatted_paths:
            # extend() rather than assignment: several leaves can share the same
            # content, and assigning would silently drop the earlier leaf's paths.
            paths_by_leaf[get_node_representation(leaf)].extend(formatted_paths)

    return dict(paths_by_leaf)

# def print_paths_with_content(paths):
#     total_paths = sum(len(paths_list) for paths_list in paths.values())
#     print(f"\nFound {total_paths} total paths to {len(paths)} leaf nodes:")
    
#     for leaf, leaf_paths in paths.items():
#         print(f"\nPaths to leaf {leaf}:")
#         for i, path in enumerate(leaf_paths, 1):
#             print(f"  Path {i}: {' -> '.join(str(node) for node in path)}")


# paths = get_all_paths_to_leaves(G, source)
# print_paths_with_content(paths)

In [5]:
def paths_to_csv(paths: Dict, source: str, filename: str,
                 output_dir: str = "/home/cc/PHD/ragkg/paths/") -> pd.DataFrame:
    """
    Write a paths dictionary to ``<output_dir>/<filename>.csv`` and return it as a DataFrame.

    The CSV has one row per leaf and three columns:
        - source: the ``source`` argument, repeated on every row
        - leaf: the dictionary key
        - paths: all paths to that leaf as a single string; nodes within a path
          are joined with ' -> ' and paths are joined with '||'

    Args:
        paths (Dict): Dictionary of paths as returned by get_all_paths_to_leaves
        source (str): Source node content or name
        filename (str): Output CSV filename, without the '.csv' extension
        output_dir (str): Directory the CSV is written to; created if missing

    Returns:
        pd.DataFrame: DataFrame containing the paths data
    """
    rows = [
        {
            'source': source,
            'leaf': leaf,
            'paths': '||'.join(
                ' -> '.join(str(node) for node in path) for path in leaf_paths
            ),
        }
        for leaf, leaf_paths in paths.items()
    ]

    # Explicit columns keep the header present even when `paths` is empty.
    df = pd.DataFrame(rows, columns=['source', 'leaf', 'paths'])

    # to_csv does not create missing parent directories.
    os.makedirs(output_dir, exist_ok=True)
    df.to_csv(os.path.join(output_dir, filename + '.csv'), index=False)

    return df

# for f in ALL_FILES[:4]:
#     try:
#         t = json.loads(open(f).read())
#         nodes, edges = parse_tree_to_network(t)
        
#         G = nx.DiGraph()
#         G.add_nodes_from(nodes)
#         G.add_edges_from(edges)

#         print(dict(G.nodes(data=True)))

#         # draw with tree layout
#         plt.figure(figsize=(12, 12))
#         # pos = nx.nx_agraph.graphviz_layout(G, prog='dot')
#         nx.draw(G, pos=nx.spring_layout(G), with_labels=True, node_size=1000, font_size=10)
#         plt.show()

#     except Exception as e:
#         print(f, e)
#         print(t)
#         continue

graphs = []

# Convert every tree file into a CSV of root-to-leaf paths. Failures are reported
# and skipped so that one bad file does not stop the whole batch.
for f in ALL_FILES:
    source = ""
    # Reset per file: otherwise the error handler below would print the previous
    # file's tree, or raise NameError if the very first file fails to load.
    t = None

    try:
        # The context manager closes the handle even when JSON parsing fails.
        with open(f) as tree_file:
            t = json.load(tree_file)
        nodes, edges = parse_tree_to_network(t)

        G = nx.DiGraph()
        G.add_nodes_from(nodes)
        G.add_edges_from(edges)

        # graphs.append(G)
        source = find_source_node(G)
        paths = get_all_paths_to_leaves(G, source)

        # Output name = file name up to its first '.'.
        # NOTE(review): files sharing that base name in different directories
        # end up with the same CSV name -- confirm this cannot happen in the data.
        paths_to_csv(paths, source, f.split('/')[-1].split('.')[0])

    except Exception as e:
        # Best-effort batch: report the file, the error and the parsed tree
        # (None if loading itself failed), then move on to the next file.
        print(f, e)
        print(t)

In [6]:
# dict(G.nodes(data=True))