In [None]:
import json
import os
import re
from glob import glob

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np

# Lookup table used to turn entity IDs into human-readable labels.
with open("./entityid2label.json", encoding="utf-8") as label_file:
    entityid2label = json.load(label_file)


def get_tsv_paths(num_classes, sample_first_batch=False):
    """Collect TSV file paths for the first ``num_classes`` classes.

    Class IDs are the first ``num_classes`` keys, in file order, of
    ``process_p31_p279/class_counts.json``. TSV files are discovered with the
    pattern ``./extracted_paths/<class_id>/*.tsv``.

    Args:
        num_classes: How many leading class IDs from the counts file to keep.
        sample_first_batch: When true, return a single file per class: the
            first file (in sorted order) whose name contains ``batch_1`` not
            followed by another digit, or the first file of the class if no
            such file exists. When false, return every file of every class.

    Returns:
        A list of TSV paths. Files are visited in sorted order so the result
        is reproducible across runs and platforms.
    """
    with open("process_p31_p279/class_counts.json", "r", encoding="utf-8") as f:
        class_counts = json.load(f)
    starting_entities = set(list(class_counts)[:num_classes])

    tsv_paths_by_class = {}
    # glob() yields files in arbitrary OS order; sorting makes both the
    # per-class sample and the returned list deterministic.
    for path in sorted(glob("./extracted_paths/*/*.tsv")):
        class_dir = os.path.basename(os.path.dirname(path))
        if class_dir in starting_entities:
            tsv_paths_by_class.setdefault(class_dir, []).append(path)

    if sample_first_batch:
        # The negative lookahead keeps "batch_10", "batch_11", ... from being
        # mistaken for the first batch.
        first_batch = re.compile(r"batch_1(?!\d)")
        tsv_paths = [
            next(
                (p for p in paths if first_batch.search(os.path.basename(p))),
                paths[0],
            )
            for paths in tsv_paths_by_class.values()
        ]
    else:
        tsv_paths = [p for paths in tsv_paths_by_class.values() for p in paths]

    print(f"Found {len(tsv_paths)} TSV files.")
    return tsv_paths


def read_tsv(path):
    """Yield each non-blank line of a TSV file as a list of fields.

    Args:
        path: Path of the UTF-8 encoded TSV file to read.

    Yields:
        The tab-separated fields of one line, with surrounding whitespace
        removed from the line before splitting.
    """
    with open(path, "r", encoding="utf-8") as tf:
        for line in tf:
            stripped = line.strip()
            # A blank line would otherwise become the bogus path [""], which
            # has no entry in the label map and breaks later enrichment.
            if not stripped:
                continue
            yield stripped.split("\t")


def load_paths(num_classes, sample_first_batch):
    """Read every path from the TSV files selected for the given classes.

    Both arguments are forwarded unchanged to ``get_tsv_paths``; each returned
    file is then read with ``read_tsv`` and its rows are gathered into one list.
    """
    collected = []
    for tsv_file in get_tsv_paths(num_classes, sample_first_batch):
        collected.extend(read_tsv(tsv_file))
    print(f"Loaded {len(collected)} total paths.")
    return collected


def enrich_paths(paths, entityid2label):
    """Replace every entity ID in ``paths`` with a ``"<label>\\n(<id>)"`` string.

    Args:
        paths: Iterable of paths, each an iterable of entity IDs.
        entityid2label: Mapping from entity ID to its label.

    Returns:
        A new list of lists with the same shape as ``paths``. An ID that has
        no entry in ``entityid2label`` uses the ID itself as its label, so a
        single unlabeled entity no longer aborts the whole run.
    """

    def describe(node):
        # Keep the ID in the text so nodes sharing a label stay distinct.
        return f"{entityid2label.get(node, node)}\n({node})"

    return [[describe(node) for node in path] for path in paths]


def create_graph_from_paths(paths):
    """Build a directed graph whose edges link consecutive nodes of each path."""
    graph = nx.DiGraph()
    for node_sequence in paths:
        # Pair every node with its successor; sequences shorter than two
        # nodes produce no pairs and therefore add nothing to the graph.
        graph.add_edges_from(zip(node_sequence, node_sequence[1:]))
    return graph


def analyze_graphs(G):
    """Report on the weakly connected components of ``G`` and return them.

    Prints overall node/edge counts, the number of components, the ten
    largest component sizes, and node/edge counts for the five largest
    components. Returns a list of independent subgraph copies, one per
    component, ordered from largest to smallest.
    """
    # Rank once, largest first; both the size report and the returned
    # subgraphs follow this order.
    ranked = sorted(nx.weakly_connected_components(G), key=len, reverse=True)
    top_sizes = [len(nodes) for nodes in ranked][:10]

    print("\nGraph Analysis Results:")
    print(
        f"Total graph has {G.number_of_nodes()} nodes and {G.number_of_edges()} edges"
    )
    print(f"Found {len(ranked)} separate connected components")
    print(f"Component sizes: {top_sizes} (showing top 10)")

    # copy() detaches each component from the parent graph's view.
    subgraphs = [G.subgraph(nodes).copy() for nodes in ranked]

    for rank, sub in enumerate(subgraphs[:5], start=1):
        print(
            f"Component {rank}: {sub.number_of_nodes()} nodes, {sub.number_of_edges()} edges"
        )

    return subgraphs


# 1a. Read every path from the TSV files of the first 100 classes
paths = load_paths(num_classes=100, sample_first_batch=False)

# 1b. Replace each entity ID with its "label\n(id)" form; `paths` is rebound
#     to the enriched version
paths = enrich_paths(paths, entityid2label)

# 2. Build one directed graph from consecutive nodes of every path
G = create_graph_from_paths(paths)

# 3. Split the graph into weakly connected components and print their stats
subgraphs = analyze_graphs(G)

# 4. Choose the biggest component (analyze_graphs returns them largest first)
largest_graph = subgraphs[0]
print(f"\nAnalyzing the largest component:")
print(
    f"Largest component has {largest_graph.number_of_nodes()} nodes and {largest_graph.number_of_edges()} edges"
)


def save_graph_to_json(G, filename):
    """Write ``G`` to ``filename`` as node-link JSON and return ``filename``.

    Any node without a ``"label"`` attribute receives one (the string form of
    the node) before serialization; this is stored on ``G`` itself.
    """
    # Fill in missing labels directly on the graph's node attribute dicts.
    for node_id, attrs in G.nodes(data=True):
        if "label" not in attrs:
            attrs["label"] = str(node_id)

    payload = nx.node_link_data(G)

    with open(filename, "w", encoding="utf-8") as out_file:
        json.dump(payload, out_file, indent=2)

    print(f"Graph saved to {filename}")
    return filename


# 5. Export the largest component as node-link JSON in the working directory
save_graph_to_json(largest_graph, "largest_component.json")

Found 919 TSV files.
Loaded 507745 total paths.

Graph Analysis Results:
Total graph has 2757 nodes and 4964 edges
Found 11 separate connected components
Component sizes: [2718, 7, 7, 6, 6, 3, 2, 2, 2, 2] (showing top 10)
Component 1: 2718 nodes, 4935 edges
Component 2: 7 nodes, 6 edges
Component 3: 7 nodes, 6 edges
Component 4: 6 nodes, 5 edges
Component 5: 6 nodes, 5 edges

Analyzing the largest component:
Largest component has 2718 nodes and 4935 edges
Graph saved to largest_component.json


'largest_component.json'