In [None]:
import os
import re
import json
from collections import defaultdict, Counter

In [None]:
def load_nodes_from_simplified_json(dir_path):
    """
    Load the node names of every .json file located directly in ``dir_path``.

    Files are visited in sorted filename order so that the returned list (and
    any report built from it) is reproducible; ``os.listdir`` on its own
    returns entries in arbitrary order.

    Args:
        dir_path: Directory to scan (not recursive). Entries whose name does
            not end with ".json" are skipped.

    Returns:
        A list with one dict per .json file, in the format:
        { "category": "yearXXXX" / "time_invariant",
          "filename": "...",
          "nodes": set([...]) }
        "nodes" holds the keys of the file's top-level "nodes" mapping, or an
        empty set when the file has no "nodes" entry.

    Raises:
        OSError: if the directory or one of its files cannot be read.
        json.JSONDecodeError: if a .json file does not contain valid JSON.
    """
    results = []
    # NOTE: this matches the first run of four digits anywhere in the
    # filename, so any 4-digit sequence (not only a real year) selects a
    # "yearXXXX" category.
    year_pattern = re.compile(r"(\d{4})")

    for filename in sorted(os.listdir(dir_path)):
        if not filename.endswith(".json"):
            continue
        filepath = os.path.join(dir_path, filename)
        with open(filepath, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Only the node names (the mapping's keys) are kept.
        node_set = set(data.get("nodes", {}))

        # Files without a 4-digit group in their name are time-invariant.
        match_year = year_pattern.search(filename)
        if match_year:
            category = f"year{match_year.group(1)}"
        else:
            category = "time_invariant"

        results.append({
            "category": category,
            "filename": filename,
            "nodes": node_set,
        })

    return results

def get_nodes_in_all_files(files_and_nodes):
    """
    Return the nodes shared by every file, i.e. the intersection of all node sets.

    files_and_nodes is a sequence of (filename, node_set) pairs; the filenames
    are ignored. An empty sequence yields an empty set. The input node sets are
    left untouched: the result is always a new set.
    """
    if not files_and_nodes:
        return set()

    first_entry, *remaining_entries = files_and_nodes

    # Start from a copy of the first file's nodes and narrow it file by file.
    shared = set(first_entry[1])
    for _filename, nodes in remaining_entries:
        shared = shared & nodes

    return shared

def get_nodes_in_threshold(files_and_nodes, threshold_ratio=0.9):
    """
    Return the "major" nodes among the given files.

    files_and_nodes is a sequence of (filename, node_set) pairs belonging to one
    category. Each node is counted once per file it appears in, and a node is
    major when that count reaches threshold_ratio (default: 0.9) of the number
    of files, rounded up to a whole file count.
    """
    # Number of files each node occurs in.
    occurrences = Counter(
        node
        for _filename, nodes in files_and_nodes
        for node in nodes
    )

    # Adding just under 1 before truncating rounds the required count up to
    # the next integer while tolerating tiny floating-point excess in the
    # product (a plain ceiling would over-shoot on e.g. 7.000000000000001).
    total_files = len(files_and_nodes)
    required = int(threshold_ratio * total_files + 0.9999999)

    return {node for node, seen in occurrences.items() if seen >= required}

def analyze_nodes_by_category(data_list, threshold_ratio=0.9):
    """
    Classify the nodes of all simplified graphs and render a Markdown report.

    Args:
        data_list: [{ "category":..., "filename":..., "nodes": set([...]) }, ... ]
            Categories starting with "year" are treated as temporal and the
            category "time_invariant" as time-invariant. Any other category
            still gets its own rows in sections C and D.
        threshold_ratio: Share of a category's files a node must appear in to
            be "major" for that category (forwarded to get_nodes_in_threshold).

    Returns:
        The report as a single Markdown string with these sections:
        - A.1: nodes present in every file of every category.
        - A.2: nodes that appear at least once in a temporal file and at least
          once in a time_invariant file, excluding A.1.
        - B.1 / B.2: nodes that are major on the temporal side only / on the
          time-invariant side only, excluding A.1 and A.2.
        - C: per category, its major nodes that are not in A or B.
        - D: per file, the B/C nodes relevant to its category that it lacks.
    """
    # (1) Build mapping from category to list of (filename, node_set)
    cat_to_file_nodes = defaultdict(list)
    for item in data_list:
        cat_to_file_nodes[item["category"]].append((item["filename"], item["nodes"]))

    # (2) Category names, split into temporal and time-invariant ones
    all_categories = sorted(cat_to_file_nodes)
    temporal_cats = [c for c in all_categories if c.startswith("year")]
    ti_cats = [c for c in all_categories if c == "time_invariant"]

    # Major nodes of each category, computed once and reused for B, C and D.
    major_by_cat = {
        cat: get_nodes_in_threshold(fnodes_list, threshold_ratio=threshold_ratio)
        for cat, fnodes_list in cat_to_file_nodes.items()
    }

    # (3) A.1: Nodes that appear in 100% of all files (across all categories)
    common_in_all = get_nodes_in_all_files(
        [(item["filename"], item["nodes"]) for item in data_list]
    )

    # (4) A.2: Nodes seen at least once on both sides, excluding A.1
    temporal_all = set().union(
        *(node_set for cat in temporal_cats for _, node_set in cat_to_file_nodes[cat])
    )
    invariant_all = set().union(
        *(node_set for cat in ti_cats for _, node_set in cat_to_file_nodes[cat])
    )
    found_in_both = (temporal_all & invariant_all) - common_in_all

    # (5) B.1 Temporal-Only, B.2 Invariant-Only: major on one side only
    temporal_major_union = set().union(*(major_by_cat[cat] for cat in temporal_cats))
    invariant_major_union = set().union(*(major_by_cat[cat] for cat in ti_cats))
    temporal_only = temporal_major_union - invariant_major_union - common_in_all - found_in_both
    invariant_only = invariant_major_union - temporal_major_union - common_in_all - found_in_both

    # (6) C. Category-Specific Major Nodes (excluding A and B)
    already_classified = common_in_all | found_in_both | temporal_only | invariant_only
    category_major_nodes = {
        cat: major_set - already_classified
        for cat, major_set in major_by_cat.items()
    }

    # (7) D. Missing Info: per file, relevant (B ∪ C) nodes that are absent.
    # NOTE(review): the section D subtitle below mentions A.2, but A.2 nodes
    # are not part of relevant_set — confirm which of the two is intended.
    missing_rows = []
    for cat in all_categories:
        c_candidates = category_major_nodes[cat]
        if cat.startswith("year"):
            relevant_set = temporal_only | c_candidates
        elif cat == "time_invariant":
            relevant_set = invariant_only | c_candidates
        else:
            # Neither temporal nor invariant: only the category's own major
            # nodes apply. Without this branch relevant_set would be unbound
            # or left over from the previously processed category.
            relevant_set = c_candidates

        for fn, node_set in cat_to_file_nodes[cat]:
            missed = relevant_set - node_set
            if missed:
                missing_rows.append(f"| {cat} | {fn} | {', '.join(sorted(missed))} |")

    # Combine results into Markdown format
    lines = [
        "# Analysis of Simplified Graphs\n",
        # A.1
        "## A.1 Common in All Simplified Graphs (100% in each category)\n",
        f"- **Common Nodes**: {sorted(common_in_all)}\n",
        # A.2
        "## A.2 Found in Both Temporal and Time-Invariant Categories\n",
        "(Nodes that appear at least once in temporal AND at least once in time_invariant, excluding A.1)\n",
        f"- **Found in Both**: {sorted(found_in_both)}\n",
        # B
        "## B. Temporal vs Time-Invariant\n",
        "### B.1 Temporal-Only Nodes\n",
        f"- {sorted(temporal_only)}\n",
        "### B.2 Invariant-Only Nodes\n",
        f"- {sorted(invariant_only)}\n",
        # C
        "\n## C. Category-Specific Major Nodes (Excluding A & B)\n",
        "| Category | Exclusive Major Nodes |\n|---|---|",
    ]
    # An empty node set joins to "", which leaves the table cell blank.
    lines.extend(
        f"| {cat} | {', '.join(sorted(category_major_nodes[cat]))} |"
        for cat in all_categories
    )

    # D
    lines.append("\n## D. Missing Info in Some Files\n")
    lines.append("*(Nodes that are in A.2, B, or C but absent in certain files)*\n")
    lines.append("| Category | Filename | Missing Nodes |\n|---|---|---|")
    lines.extend(missing_rows)

    return "\n".join(lines)


In [None]:
# Driver cell: build and print the node-analysis report for one model.
# The model name selects the sub-directory of ./graphs that is scanned.
model_name = "Llama-2-7b-chat-hf"
directory_path = f"./graphs/{model_name}"
# One record (category, filename, node set) per .json file in that directory.
data_list = load_nodes_from_simplified_json(directory_path)
# threshold_ratio=0.9 is passed explicitly; it equals the function's default.
md_report = analyze_nodes_by_category(data_list, threshold_ratio=0.9)
# The report is a Markdown string; it is printed as raw text.
print(md_report)

# Analysis of Simplified Graphs

## A.1 Common in All Simplified Graphs (100% in each category)

- **Common Nodes**: []

## A.2 Found in Both Temporal and Time-Invariant Categories

(Nodes that appear at least once in temporal AND at least once in time_invariant, excluding A.1)

- **Found in Both**: ['a0.h22', 'a0.h25', 'a1.h15', 'a1.h24', 'a1.h27', 'a1.h28', 'a2.h13', 'a2.h16', 'a2.h17', 'a2.h2', 'a2.h24', 'a20.h14', 'a24.h14', 'a24.h24', 'a26.h14', 'a28.h7', 'a29.h10', 'a29.h5', 'a29.h9', 'a30.h12', 'a31.h27', 'a4.h17', 'a4.h30', 'a6.h1', 'input', 'logits', 'm0', 'm1', 'm10', 'm11', 'm12', 'm13', 'm14', 'm15', 'm16', 'm17', 'm18', 'm19', 'm2', 'm20', 'm21', 'm22', 'm23', 'm24', 'm25', 'm26', 'm27', 'm28', 'm29', 'm3', 'm30', 'm31', 'm4', 'm5', 'm6', 'm7', 'm8', 'm9']

## B. Temporal vs Time-Invariant

### B.1 Temporal-Only Nodes

- ['a15.h0', 'a18.h3']

### B.2 Invariant-Only Nodes

- ['a19.h6']


## C. Category-Specific Major Nodes (Excluding A & B)

| Category | Exclusive Major Nod