# In [None]:
import json
import os
import glob
import logging
import pydot # Import pydot
import re # Import regex for label parsing
# Make sure pydot.Error is accessible if needed for exception handling
from pydot import Error as PydotError

# Root logger setup for this cell: INFO level, timestamped single-line records.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

# Locations used by the extraction step (relative to the working directory).
base_cpg_path = "data_java/cpg-output"                    # one sub-folder per JSON key, searched for *.dot
json_path = "data_java/center_nodes_result_specific.json"  # folder name -> center node IDs
output_base_path = "data_java/subgraph_contexts"           # destination of the *_context.txt files

# A neighbour of a center node is kept only if its label type is one of these.
allowed_neighbor_labels = {
    # upper-case labels
    'CALL', 'CONTROL_STRUCTURE', 'FIELD_IDENTIFIER', 'IDENTIFIER',
    # mixed/lower-case labels
    'addition', 'alloc', 'arrayInitializer', 'assignment', 'cast',
    'CatchClause', 'fieldAccess', 'indexAccess', 'logicalAnd',
    'logicalNot', 'stonesoup_array',
}

# Captures the leading run of letters, '_', '<' and '>' (after an optional
# opening double quote); that token is treated as the node's type.
label_type_pattern = re.compile(r'^"?([a-zA-Z_<>]+)')

# Load the center nodes data from JSON. It is expected to be a JSON object that maps
# each CPG folder name to the IDs of that folder's center nodes.
try:
    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(json_path, 'r', encoding='utf-8') as f:
        center_nodes_data = json.load(f)
    logging.info(f"Successfully loaded center nodes data from {json_path}")
except FileNotFoundError:
    logging.error(f"Error: JSON file not found at {json_path}")
    raise # Stop execution if JSON is missing
except json.JSONDecodeError:
    logging.error(f"Error: Could not decode JSON from {json_path}")
    raise # Stop execution if JSON is invalid

# The extraction loop iterates center_nodes_data.items(); fail fast with a clear
# message instead of an AttributeError if the file holds a list/scalar.
if not isinstance(center_nodes_data, dict):
    logging.error(f"Error: Expected a JSON object in {json_path}, got {type(center_nodes_data).__name__}")
    raise TypeError(f"{json_path} must contain a JSON object mapping folder names to center node IDs")

# Create the output directory if it doesn't exist
os.makedirs(output_base_path, exist_ok=True)
logging.info(f"Ensured output directory exists: {output_base_path}")

def get_node_type_from_attributes(attrs):
    """Return the type token at the start of a node's ``label`` attribute.

    Args:
        attrs: Mapping of DOT attributes for one node (the result of pydot's
            ``Node.get_attributes()`` in this notebook).

    Returns:
        The first group matched by ``label_type_pattern`` against the label,
        or ``None`` when the label is missing/empty or the pattern does not match.
    """
    label_value = attrs.get('label')
    if not label_value:
        return None
    found = label_type_pattern.match(label_value)
    return found.group(1) if found else None

def _normalize_center_ids(raw_ids):
    """Return one JSON entry's center node IDs as a set of unquoted strings.

    Args:
        raw_ids: The JSON value for a folder, normally a list of IDs. A lone
            string/integer is treated as a one-element list and ``None`` as empty.

    Returns:
        set[str]: each ID passed through ``str()`` with surrounding double quotes
        stripped, so it compares equal to the unquoted pydot names used below.
    """
    if raw_ids is None:
        return set()
    if isinstance(raw_ids, (str, int)):
        # set("123") would otherwise split a lone string ID into characters.
        raw_ids = [raw_ids]
    return {str(node_id).strip('"') for node_id in raw_ids}


def _has_allowed_label(node_id, node_map):
    """Return True if ``node_id`` has a definition in ``node_map`` with an allowed type.

    The type is read with ``get_node_type_from_attributes`` and checked against
    ``allowed_neighbor_labels``; IDs without a node definition are never allowed.
    """
    node = node_map.get(node_id)
    if node is None:
        return False
    return get_node_type_from_attributes(node.get_attributes()) in allowed_neighbor_labels


# Process each entry in the JSON data: extract the center nodes of one CPG folder,
# their label-filtered direct neighbours, and the edges connecting them.
for folder_name, center_node_ids_str in center_nodes_data.items():
    logging.info(f"Processing folder: {folder_name}")
    center_node_ids = _normalize_center_ids(center_node_ids_str)
    folder_path = os.path.join(base_cpg_path, folder_name)

    # Find the .dot file; sort so the "first one" is deterministic across filesystems.
    dot_files = sorted(glob.glob(os.path.join(folder_path, '*.dot')))

    if not dot_files:
        logging.warning(f"  No .dot file found in {folder_path}. Skipping.")
        continue
    if len(dot_files) > 1:
        logging.warning(f"  Multiple .dot files found in {folder_path}. Using the first one: {dot_files[0]}.")

    dot_file_path = dot_files[0]
    logging.info(f"  Using .dot file: {dot_file_path}")

    relevant_lines = set()
    # Nodes to include: start with all center nodes; filtered neighbours are added below.
    nodes_to_include = set(center_node_ids)
    nodes_definitions_added = set() # unquoted IDs whose definition line was emitted

    try:
        logging.info(f"  Parsing {dot_file_path} with pydot...")
        graphs = pydot.graph_from_dot_file(dot_file_path)

        if not graphs:
            logging.warning(f"  pydot could not parse any graph from {dot_file_path}. Skipping.")
            continue

        # pydot.Dot is a subclass of pydot.Graph, so one isinstance check covers both.
        if isinstance(graphs, list) and isinstance(graphs[0], pydot.Graph):
            graph = graphs[0]
            logging.info(f"  Successfully parsed graph.")
        else:
            logging.error(f"  pydot.graph_from_dot_file did not return a valid graph object for {dot_file_path}. Found type: {type(graphs[0]) if graphs else 'None'}. Skipping.")
            continue

        # Map unquoted node ID -> pydot node (node.get_name() may carry quotes).
        logging.info("  Building node map...")
        node_map = {node.get_name().strip('"'): node for node in graph.get_nodes()}
        logging.info(f"  Built map with {len(node_map)} nodes.")

        # Keep edges touching a center node: center<->center edges always, and
        # center<->other edges only if the other endpoint has an allowed label.
        logging.info(f"  Finding relevant edges and filtering neighbors by label...")
        edges_processed = 0
        edges_added = 0
        neighbors_added = set()

        for edge in graph.get_edges():
            edges_processed += 1
            source_id = edge.get_source().strip('"')
            dest_id = edge.get_destination().strip('"')
            source_is_center = source_id in center_node_ids
            dest_is_center = dest_id in center_node_ids

            if source_is_center and dest_is_center:
                # Both endpoints are already included, so the edge belongs in the subgraph
                # regardless of their labels.
                neighbor_to_add = None
                is_relevant_edge = True
            elif source_is_center:
                neighbor_to_add = dest_id
                is_relevant_edge = _has_allowed_label(dest_id, node_map)
            elif dest_is_center:
                neighbor_to_add = source_id
                is_relevant_edge = _has_allowed_label(source_id, node_map)
            else:
                continue

            if not is_relevant_edge:
                continue
            if neighbor_to_add is not None:
                nodes_to_include.add(neighbor_to_add)
                neighbors_added.add(neighbor_to_add)
            # Keep the original edge statement (including its attributes).
            relevant_lines.add(edge.to_string().strip())
            edges_added += 1

        logging.info(f"  Processed {edges_processed} edges. Added {edges_added} relevant edges.")
        logging.info(f"  Added {len(neighbors_added)} neighbors based on label criteria.")
        logging.info(f"  Total nodes to include (centers + filtered neighbors): {len(nodes_to_include)}")

        # Emit the definition statement of every included node that the graph defines.
        logging.info(f"  Extracting node definitions for included nodes...")
        for unquoted_id in nodes_to_include:
            node_obj = node_map.get(unquoted_id)
            if node_obj is not None:
                relevant_lines.add(node_obj.to_string().strip())
                nodes_definitions_added.add(unquoted_id)
        logging.info(f"  Extracted definitions for {len(nodes_definitions_added)} included nodes.")

        # Every center ID is in nodes_to_include, so a missing definition means the ID
        # does not name any node in this .dot file (e.g. stale or mistyped JSON entry).
        missing_center_defs = center_node_ids - nodes_definitions_added
        if missing_center_defs:
            logging.warning(f"  Center node IDs not defined in {dot_file_path}: {sorted(missing_center_defs)}")

        # Neighbours are only added after a successful node_map lookup, so this is a
        # defensive sanity check.
        neighbor_ids_in_final_set = nodes_to_include - center_node_ids
        missing_neighbor_defs = neighbor_ids_in_final_set - nodes_definitions_added
        if missing_neighbor_defs:
            logging.warning(f"  Could not find definitions for all *filtered* neighbor nodes in {dot_file_path}. Missing: {missing_neighbor_defs}")

        # Write the collected statements, sorted for stable output, one per indented line.
        output_file_path = os.path.join(output_base_path, f"{folder_name}_context.txt")
        sorted_lines = sorted(relevant_lines)

        with open(output_file_path, 'w', encoding='utf-8') as f_out:
            f_out.writelines(f"  {line}\n" for line in sorted_lines)

        logging.info(f"  Successfully wrote {len(sorted_lines)} lines to {output_file_path}")

    except FileNotFoundError:
        logging.error(f"  Error: .dot file not found at {dot_file_path}")
    except PydotError as e:
        logging.error(f"  A pydot library error occurred processing {dot_file_path}: {e}")
    except Exception as e:
        logging.error(f"  An unexpected {type(e).__name__} occurred while processing {dot_file_path}: {e}", exc_info=True)

logging.info("Subgraph extraction process finished.")

# In [None]:
import os
import glob
import logging
import pydot
from IPython.display import Image, display
from pydot import Error as PydotError

# Root logger setup for this cell (no-op if logging was already configured).
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

# Where the *_context.txt files are read from, and where rendered images go.
context_base_path = "data_java/subgraph_contexts"
visualization_output_path = "data_java/subgraph_visualizations"

os.makedirs(visualization_output_path, exist_ok=True)
logging.info(f"Ensured output directory exists: {visualization_output_path}")

# Every context file produced by the extraction step.
context_files = glob.glob(os.path.join(context_base_path, '*_context.txt'))

if context_files:
    logging.info(f"Found {len(context_files)} context files to visualize.")
else:
    logging.warning(f"No context files found in {context_base_path}. Nothing to visualize.")

# pydot's parser returns the `graph [...]`, `node [...]` and `edge [...]` default
# attribute statements from graph.get_nodes() as Node objects with these names.
_DEFAULT_STATEMENT_NAMES = {'graph', 'node', 'edge'}
_CONTEXT_SUFFIX = '_context.txt'


def _simplify_node_labels(graph):
    """Replace every real node's label with ``NODE <name>``; return how many changed.

    The default-statement pseudo-nodes are skipped. Labelling them would add a
    "NODE graph" title to the drawing and a "NODE edge" label to unlabelled edges.
    The node name keeps its surrounding quotes, as before.
    """
    modified = 0
    for node in graph.get_nodes():
        node_id = node.get_name()
        if node_id in _DEFAULT_STATEMENT_NAMES:
            continue
        try:
            node.set('label', f"NODE {node_id}")
            modified += 1
        except Exception as label_err:
            logging.warning(f"    Could not modify label for node {node.get_name()}: {label_err}")
    return modified


def _render_svg_fallback(graph, output_png_path):
    """Write ``graph`` as SVG next to ``output_png_path`` and display it inline."""
    try:
        # IPython's Image() cannot embed SVG data, so use the dedicated SVG class.
        from IPython.display import SVG
        logging.info(f"  Trying alternative SVG format...")
        svg_path = f"{output_png_path}.svg"
        graph.write_svg(svg_path)
        logging.info(f"  Successfully wrote SVG to {svg_path}")
        display(SVG(filename=svg_path))
    except Exception as svg_err:
        logging.error(f"  SVG fallback also failed: {svg_err}")


def _render_subgraph(graph, output_png_path):
    """Render ``graph`` to PNG (plus a Graphviz "plain" layout file) and display it.

    Errors are logged rather than raised. An AssertionError whose message contains
    "returned code: 1" triggers an SVG fallback.
    """
    try:
        logging.info(f"  Attempting to render PNG to {output_png_path}...")
        # format="plain" is Graphviz's plain-text *layout* output (not DOT). It also
        # runs Graphviz, so it acts as an early check before the PNG render.
        test_txt_path = f"{output_png_path}.txt"
        graph.write(test_txt_path, format="plain")
        logging.info(f"  Successfully wrote Graphviz plain-format layout to {test_txt_path}.")

        png_created = graph.write_png(output_png_path)
        if png_created is False:
            logging.error(f"  graph.write_png returned False for {output_png_path}.")
            return

        logging.info(f"  Successfully rendered PNG.")

        # Display the image in the notebook
        display(Image(filename=output_png_path))
        print("-" * 40)
        logging.info(f"  Successfully displayed image.")

    except FileNotFoundError as fnf_err:
        # NOTE(review): pydot may also surface a missing Graphviz executable this way;
        # include the underlying message so the two cases can be told apart.
        logging.error(f"  Error: Failed to find or create file ({fnf_err}). Check permissions and that Graphviz is installed.")
    except PydotError as pe:
        logging.error(f"  Pydot error during file operations: {pe}")
    except AssertionError as ae:
        logging.error(f"  AssertionError during rendering: {ae}")
        # Show the full message so any diagnostic text it carries is visible.
        print(str(ae))
        if "returned code: 1" in str(ae):
            _render_svg_fallback(graph, output_png_path)
    except Exception as e:
        logging.error(f"  Unexpected error: {e}", exc_info=True)


# Process and visualize each context file
for context_file_path in context_files:
    # glob only matched names ending in the suffix, so strip exactly that suffix.
    base_name = os.path.basename(context_file_path)[:-len(_CONTEXT_SUFFIX)]
    logging.info(f"Processing visualization for: {base_name}")
    output_png_path = os.path.join(visualization_output_path, f"{base_name}_subgraph.png")

    try:
        logging.info(f"  Reading context file: {context_file_path}")
        with open(context_file_path, 'r', encoding='utf-8') as f_in:
            subgraph_content = f_in.read()
        logging.info(f"  Successfully read context file.")

        # Wrap the extracted statements in a DOT digraph. Escape double quotes so an
        # unusual file name cannot terminate the quoted graph ID early.
        graph_id = base_name.replace('"', '\\"')
        dot_string = f"digraph \"{graph_id}_subgraph\" {{\n graph [rankdir=LR];\n node [shape=box, fontname=\"Courier New\"];\n edge [arrowsize=0.5, fontsize=8];\n{subgraph_content}\n}}"

        logging.info("  Parsing DOT data...")
        graphs = pydot.graph_from_dot_data(dot_string)

        if not graphs:
            logging.warning(f"  Could not parse DOT data from generated string for {base_name}. Skipping visualization.")
            continue

        graph = graphs[0] # The wrapper defines exactly one graph
        logging.info("  Successfully parsed DOT data.")

        # Replace the original labels with simple ones to sidestep label escaping issues.
        logging.info("  Simplifying approach to modify node labels...")
        nodes_modified = _simplify_node_labels(graph)
        logging.info(f"  Modified labels for {nodes_modified} nodes using simplified approach.")

        _render_subgraph(graph, output_png_path)

    except FileNotFoundError:
        logging.error(f"  Error: Context file not found at {context_file_path}")
    except PydotError as parse_err:
        logging.error(f"  Pydot error parsing DOT data: {parse_err}")
    except Exception as e:
        logging.error(f"  Unexpected error processing {base_name}: {e}", exc_info=True)

logging.info("Subgraph visualization process finished.")