In [5]:
import networkx as nx
import json
import pandas as pd
import matplotlib.pyplot as plt
import re

# --- Step 1: Load ---
import pickle

# Load normalized graph and metadata.
# SECURITY: pickle.load executes arbitrary code embedded in the file; only load
# a .gpickle that this project produced itself.
with open("/content/mmb_graph_enhanced.gpickle", "rb") as f:
    G = pickle.load(f)

# JSON is UTF-8 by specification; pin the encoding so decoding does not depend
# on the platform's default locale.
with open("/content/node_to_chunks_enhanced.json", encoding="utf-8") as f:
    node_to_chunks = json.load(f)

with open("/content/node_to_community_enhanced.json", encoding="utf-8") as f:
    node_to_community = json.load(f)

# NOTE(review): later cells reference `chunk_metadata`, but its load is disabled
# here — confirm another cell defines it before running them on a fresh kernel.
# with open("/content/chunk_metadata.json") as f:
#     chunk_metadata = json.load(f)

In [None]:
import networkx as nx
from collections import defaultdict, Counter
from rapidfuzz import fuzz
import numpy as np

import re
import networkx as nx
from collections import defaultdict, Counter
from rapidfuzz import fuzz

### 1. Normalize node names ###
def normalize_node_name(name):
    """Return a canonical, lower-cased form of a node name.

    Leading list-numbering characters (digits, dots, hyphens, parentheses and
    whitespace) are removed, e.g. ``"1. Foo"`` -> ``"foo"``.

    Args:
        name: Node name. Non-string values are converted with ``str()``.

    Returns:
        The normalized name. If removing the leading characters would leave
        nothing (e.g. ``"2023"``), the trimmed, lower-cased input is returned
        instead, so distinct numeric names do not all collapse to ``""``.
    """
    text = str(name).strip()
    without_prefix = re.sub(r'^[\d\.\-\(\)\s]+', '', text)
    # Fall back to the untouched text when the prefix pattern consumed all of it.
    return (without_prefix or text).strip().lower()


### 2. Normalize existing graph nodes and update node_to_chunks ###
def normalize_graph(G, node_to_chunks):
    """Relabel graph nodes and chunk-map keys with their normalized names.

    Args:
        G: Graph whose nodes are renamed via ``normalize_node_name``.
        node_to_chunks: Mapping of node name -> iterable of chunk ids.

    Returns:
        ``(relabeled_graph, chunks_by_name)``: a relabeled copy of ``G`` and a
        ``defaultdict(set)`` keyed by normalized name. Keys that normalize to
        the same name have their chunk ids unioned.
    """
    # Original label -> normalized label for every node currently in the graph.
    rename = {}
    for original_name in G.nodes:
        rename[original_name] = normalize_node_name(original_name)

    # copy=True leaves the caller's graph untouched.
    relabeled_graph = nx.relabel_nodes(G, rename, copy=True)

    # Re-key the chunk map under the same normalization.
    chunks_by_name = defaultdict(set)
    for original_name in node_to_chunks:
        normalized = normalize_node_name(original_name)
        chunks_by_name[normalized].update(node_to_chunks[original_name])

    return relabeled_graph, chunks_by_name


### 3. Fuzzy merge similar nodes (with shared chunk evidence) ###
def build_merge_map(G, node_to_chunks, chunk_metadata, threshold=87):
    """Group fuzzy-duplicate node names and pick one representative per group.

    Two nodes are linked when they share at least one chunk id and
    ``fuzz.token_sort_ratio`` of their names is >= ``threshold``. Links are
    closed transitively (a~b and b~c put a, b and c in one group), so every
    member maps directly to its group's representative. The returned map
    contains no chains (a -> b -> c) and no entries that disagree with
    ``reverse_map``.

    Args:
        G: Graph whose ``nodes`` are the candidate names (strings).
        node_to_chunks: Mapping of node name -> iterable of chunk ids.
        chunk_metadata: Unused; kept so existing callers keep working.
        threshold: Minimum similarity score for two names to be linked.

    Returns:
        ``(merge_map, reverse_map)``. ``merge_map`` maps every grouped node,
        the representative included, to the representative. ``reverse_map``
        maps each representative to the set of all members of its group. The
        representative is the shortest name; ties go to the node that comes
        first in ``G.nodes``. Nodes that matched nothing appear in neither.
    """
    nodes = list(G.nodes)
    # Materialize as sets once so list-valued inputs work and lookups are cheap.
    chunk_sets = {node: set(node_to_chunks.get(node, ())) for node in nodes}

    # Union-find forest holding only nodes that took part in a match.
    parent = {}

    def find(node):
        root = node
        while parent[root] != root:
            root = parent[root]
        # Path compression: point every node on the walked path at the root.
        while parent[node] != root:
            parent[node], node = root, parent[node]
        return root

    for i in range(len(nodes)):
        n1 = nodes[i]
        chunks1 = chunk_sets[n1]
        if not chunks1:
            continue  # cannot share a chunk with anything
        for j in range(i + 1, len(nodes)):
            n2 = nodes[j]
            # Cheap set test first; the fuzzy score is the expensive part.
            if chunks1.isdisjoint(chunk_sets[n2]):
                continue
            if fuzz.token_sort_ratio(n1, n2) >= threshold:
                parent.setdefault(n1, n1)
                parent.setdefault(n2, n2)
                root1, root2 = find(n1), find(n2)
                if root1 != root2:
                    parent[root2] = root1

    position = {node: idx for idx, node in enumerate(nodes)}
    groups = {}
    for node in list(parent):
        groups.setdefault(find(node), []).append(node)

    merge_map = {}
    reverse_map = {}
    for members in groups.values():
        rep = min(members, key=lambda m: (len(m), position[m]))
        reverse_map[rep] = set(members)
        for member in members:
            merge_map[member] = rep
    return merge_map, reverse_map


### 4. Apply node merging to graph and mappings ###
def apply_merge(G, node_to_chunks, merge_map):
    """Apply a node merge map to both the graph and the chunk map.

    Args:
        G: Graph to relabel.
        node_to_chunks: Mapping of node name -> iterable of chunk ids.
        merge_map: Mapping of node name -> name it should be merged into.

    Returns:
        ``(merged_graph, merged_chunks)``: a relabeled copy of ``G`` and a
        ``defaultdict(set)`` in which the chunk ids of merged nodes are unioned
        under the target name. Names absent from ``merge_map`` keep their key.
    """
    merged_graph = nx.relabel_nodes(G, merge_map, copy=True)

    merged_chunks = defaultdict(set)
    for name in node_to_chunks:
        # Unmapped names stay under their own key.
        target = merge_map[name] if name in merge_map else name
        merged_chunks[target].update(node_to_chunks[name])

    return merged_graph, merged_chunks

def entropy(items):
    """Shannon entropy, in bits, of the empirical distribution of ``items``.

    Args:
        items: Iterable of hashable values.

    Returns:
        Entropy as a numpy float: ``0.0`` for empty input or a single distinct
        value, ``log2(k)`` for ``k`` equally frequent values.
    """
    counts = np.array(list(Counter(items).values()), dtype=float)
    # Defensive: a zero count would contribute 0 * log2(0) = nan below.
    counts = counts[counts > 0]
    if counts.size == 0:
        return np.float64(0.0)
    probs = counts / counts.sum()
    # Every probability is strictly positive here, so no epsilon is needed in
    # the log; adding one biases the result and can make it slightly negative.
    # "+ 0.0" turns a possible -0.0 into 0.0.
    return -np.sum(probs * np.log2(probs)) + 0.0





In [None]:
def evaluate_community(node_to_community, node_to_chunks, chunk_metadata):
    """Compute the carrier entropy of every community.

    Args:
        node_to_community: Mapping of node -> community id.
        node_to_chunks: Mapping of node -> iterable of chunk ids.
        chunk_metadata: Mapping of chunk id -> record; the carrier is read from
            ``record['metadata']['carrier']`` when both keys are present.

    Returns:
        Dict of community id -> ``entropy(carriers)``. Communities for which
        no carrier value could be found are omitted.
    """
    # Pool the chunk ids of all nodes belonging to the same community.
    chunks_by_community = defaultdict(set)
    for node in node_to_community:
        chunks_by_community[node_to_community[node]].update(node_to_chunks.get(node, []))

    comm_entropy = {}
    for comm_id in chunks_by_community:
        carriers = []
        for cid in chunks_by_community[comm_id]:
            if cid not in chunk_metadata:
                continue
            record = chunk_metadata[cid]
            if 'metadata' not in record:
                continue
            if 'carrier' in record['metadata']:
                carriers.append(record['metadata']['carrier'])
        if carriers:
            comm_entropy[comm_id] = entropy(carriers)
    return comm_entropy

In [None]:
# STEP 1: Normalize first
G_normalized, node_to_chunks_norm = normalize_graph(G, node_to_chunks)

# STEP 2: Fuzzy match + build merge map
# NOTE(review): `chunk_metadata` must already exist in the kernel at this point;
# its load in the first cell is commented out — confirm where it is defined.
merge_map, reverse_map = build_merge_map(G_normalized, node_to_chunks_norm, chunk_metadata)

# STEP 3: Apply merge to graph and chunk map
G_cleaned, node_to_chunks_cleaned = apply_merge(G_normalized, node_to_chunks_norm, merge_map)

# STEP 4: Optional community detection
from community import community_louvain

# Louvain is randomized; a fixed seed makes Restart & Run All reproduce the
# same partition.
LOUVAIN_SEED = 42
node_to_community_cleaned = community_louvain.best_partition(G_cleaned, random_state=LOUVAIN_SEED)

In [None]:
import numpy as np

def get_embedding(chunk_id, chunk_metadata):
    """Look up the stored embedding of a chunk.

    Args:
        chunk_id: Key into ``chunk_metadata``.
        chunk_metadata: Mapping of chunk id -> record with an optional
            ``"metadata"`` dict that may hold an ``"embedding"`` entry.

    Returns:
        The value under ``record["metadata"]["embedding"]``, or ``None`` when
        the chunk, its metadata or the embedding entry is missing.
    """
    metadata = chunk_metadata.get(chunk_id, {}).get("metadata", {})
    if "embedding" in metadata:
        return metadata["embedding"]
    return None

def build_node_features(node_to_chunks, chunk_metadata):
    """Average the chunk embeddings of every node into one feature vector.

    Args:
        node_to_chunks: Mapping of node -> iterable of chunk ids.
        chunk_metadata: Mapping passed through to ``get_embedding``.

    Returns:
        Dict of node -> element-wise mean (``np.mean(..., axis=0)``) of the
        node's available embeddings. Nodes without any usable embedding are
        omitted.
    """
    node_features = {}
    for node, chunk_ids in node_to_chunks.items():
        embs = []
        for cid in chunk_ids:
            emb = get_embedding(cid, chunk_metadata)  # one lookup per chunk
            # Explicit None/size check: plain truthiness raises ValueError when
            # the embedding is a numpy array with more than one element.
            if emb is not None and np.size(emb) > 0:
                embs.append(emb)
        if embs:
            node_features[node] = np.mean(embs, axis=0)
    return node_features


In [None]:
from collections import Counter

def build_node_labels(node_to_chunks, chunk_metadata):
    """Label every node with the most common category among its chunks.

    Args:
        node_to_chunks: Mapping of node -> iterable of chunk ids.
        chunk_metadata: Mapping of chunk id -> record; the category is read
            from ``record["metadata"]["category"]``.

    Returns:
        Dict of node -> most frequent category. Chunk ids missing from
        ``chunk_metadata``, records without a category, and ``None``
        categories are ignored; nodes left with no category are omitted.
    """
    node_labels = {}
    for node, chunk_ids in node_to_chunks.items():
        categories = []
        for cid in chunk_ids:
            # .get() tolerates chunk ids that have no metadata record at all.
            category = chunk_metadata.get(cid, {}).get("metadata", {}).get("category")
            if category is not None:
                categories.append(category)
        if categories:
            node_labels[node] = Counter(categories).most_common(1)[0][0]
    return node_labels

In [None]:
import torch
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx

def prepare_pyg_data(G, node_features, node_labels):
    """Assemble a PyG ``Data`` object from a graph, node features and labels.

    Args:
        G: Graph exposing ``edges(data=True)``; an edge's ``'relation'``
            attribute (default ``'default'``) determines its edge type.
        node_features: Mapping of node -> feature vector. Its iteration order
            defines the row index of every node.
        node_labels: Mapping of node -> label for the labeled nodes.

    Returns:
        ``(data, label_to_id, relation_map)``. ``data.y`` is ``-1`` for nodes
        without a label and ``data.train_mask`` selects the labeled ones.
        Edges with an endpoint that has no feature row are left out.
    """
    node_to_id = {node: idx for idx, node in enumerate(node_features)}

    # Build one ndarray first: torch.tensor on a list of ndarrays is very slow
    # and emits a warning.
    x = torch.tensor(
        np.asarray([node_features[node] for node in node_to_id]),
        dtype=torch.float,
    )

    unique_labels = sorted(set(node_labels.values()))
    label_to_id = {label: i for i, label in enumerate(unique_labels)}
    y = torch.full((len(node_to_id),), -1, dtype=torch.long)
    for node, label in node_labels.items():
        if node in node_to_id:
            y[node_to_id[node]] = label_to_id[label]
    train_mask = y != -1

    # Build edge_index and edge_type
    edge_list = []
    edge_types = []
    relation_map = {}
    for u, v, d in G.edges(data=True):
        # Nodes without features have no row in x, so their edges cannot be
        # indexed; skip them instead of raising KeyError.
        if u not in node_to_id or v not in node_to_id:
            continue
        edge_list.append([node_to_id[u], node_to_id[v]])
        rel = d.get('relation', 'default')
        # Relation ids are assigned in order of first appearance.
        edge_types.append(relation_map.setdefault(rel, len(relation_map)))

    if edge_list:
        edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
    else:
        # Keep the (2, num_edges) layout even when there are no edges.
        edge_index = torch.empty((2, 0), dtype=torch.long)
    edge_type = torch.tensor(edge_types, dtype=torch.long)

    data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, edge_type=edge_type)
    return data, label_to_id, relation_map