In [None]:
import networkx as nx
import random
import time
import copy
from cdlib import algorithms
import math
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import numpy as np

def compute_community_stats(communities, G):
    """
    Compute basic statistics for a list of communities on graph G.

    Works for simple graphs and for multigraphs. G must provide
    has_edge(u, v), get_edge_data(u, v) and is_multigraph(), which every
    NetworkX graph class does.

    Returns a dict with:
      - community_count: number of communities
      - avg_size: mean number of nodes per community
      - avg_edge_count: mean number of connected node pairs per community
        (a pair counts once whatever the direction or the number of
        parallel edges; self-loops are not counted)
      - avg_sources_per_edge: mean number of distinct 'source' values per
        internal edge, pooled over all communities; an edge without a
        'source' attribute counts as the single source 'unknown'
    Every average is 0 when there is nothing to average.
    """
    multigraph = G.is_multigraph()

    def sources_between(u, v):
        # Distinct 'source' values carried by the edge(s) stored under (u, v).
        data = G.get_edge_data(u, v)
        # A multigraph returns {edge_key: attr_dict}; a simple graph returns
        # the attribute dict itself. Reading 'source' straight from the
        # multigraph result would always fall back to 'unknown'.
        attr_dicts = data.values() if multigraph else (data,)
        found = set()
        for attrs in attr_dicts:
            src = attrs.get('source', 'unknown')
            found.update(src if isinstance(src, list) else (src,))
        return found

    n_comm = len(communities)
    size_total = 0
    edge_counts = []
    source_counts = []

    for comm in communities:
        nodes = list(comm)
        size_total += len(comm)

        pair_count = 0
        edge_sources = {}
        # One pass over unordered pairs serves both statistics.
        for i, u in enumerate(nodes):
            for v in nodes[i + 1:]:
                forward = G.has_edge(u, v)
                backward = G.has_edge(v, u)
                if forward or backward:
                    pair_count += 1
                # Each orientation is recorded under its own key: in a
                # directed graph u->v and v->u are different edges; in an
                # undirected graph both keys hold the same set, which leaves
                # the average unchanged.
                if forward:
                    edge_sources.setdefault((u, v), set()).update(
                        sources_between(u, v)
                    )
                if backward:
                    edge_sources.setdefault((v, u), set()).update(
                        sources_between(v, u)
                    )

        edge_counts.append(pair_count)
        source_counts.extend(len(s) for s in edge_sources.values())

    return {
        'community_count': n_comm,
        'avg_size': size_total / n_comm if n_comm > 0 else 0,
        'avg_edge_count': sum(edge_counts) / n_comm if n_comm > 0 else 0,
        'avg_sources_per_edge': (
            sum(source_counts) / len(source_counts) if source_counts else 0
        ),
    }


def analyze_community_source_distribution(community, G, verbose=False):
    """
    Summarise where the edges inside one community come from.

    The community (a collection of nodes) is turned into the induced
    subgraph of G. Each edge is attributed to a single source label: the
    'source' attribute converted to str, the first element if that attribute
    is a list, or 'unknown' if it is missing or an empty list.

    The returned dict contains:
      - node_count: len(community)
      - total_edge_count / total_source_count
      - main_source, main_source_edge_count, main_source_percentage
        (percentage rounded to 2 decimals, 0 when there are no edges)
      - external_source_count / external_edge_count: every source other
        than the main one
      - central_node: id, type and degree of the highest-degree node in the
        subgraph (id None, degree 0, type 'unknown' if it has no nodes)

    `verbose` is accepted but not used.
    """
    induced = G.subgraph(community)

    # edges per source label
    tally = {}
    for _, _, attrs in induced.edges(data=True):
        label = attrs.get('source', 'unknown')
        if isinstance(label, list):
            label = str(label[0]) if label else 'unknown'
        label = str(label)
        tally[label] = tally.get(label, 0) + 1

    # dominant source: first label reaching the highest count
    if tally:
        dominant = max(tally, key=tally.get)
        dominant_edges = tally[dominant]
    else:
        dominant, dominant_edges = 'unknown', 0

    others = {label: n for label, n in tally.items() if label != dominant}
    edge_total = sum(tally.values())

    # hub: first node reaching the highest degree inside the subgraph
    hub, hub_degree, hub_type = None, 0, 'unknown'
    degree_of = {node: induced.degree(node) for node in induced.nodes()}
    if degree_of:
        hub = max(degree_of, key=degree_of.get)
        hub_degree = degree_of[hub]
        raw_type = G.nodes[hub].get('type', 'unknown')
        if isinstance(raw_type, list) and raw_type:
            hub_type = str(raw_type[0])
        else:
            hub_type = str(raw_type)

    if edge_total:
        share = round(dominant_edges / edge_total * 100, 2)
    else:
        share = 0

    return {
        'node_count': len(community),
        'total_edge_count': edge_total,
        'total_source_count': len(tally),
        'main_source': dominant,
        'main_source_edge_count': dominant_edges,
        'main_source_percentage': share,
        'external_source_count': len(others),
        'external_edge_count': sum(others.values()),
        'central_node': {
            'id': hub,
            'type': hub_type,
            'degree': hub_degree
        }
    }


def visualize_community_graph_no_overlap(
    G,
    community_nodes,
    highlight=None,
    output_path=None,
    enhance_info=False,
    edge_offset_step=0.05,
    layout_k=None,
    seed=610
):
    """
    Draw a community subgraph with node labels, edge labels, handling multiple
    edges (as polylines), self-loops, and optional highlighting of seed nodes.

    Parameters:
      - G: graph the community is taken from (simple or multi, directed or
        not); the induced subgraph is copied before drawing.
      - community_nodes: nodes of the community.
      - highlight: optional dict; only its 'used_seeds' entry is read. Seeds
        that belong to the community get a gold outline and are listed with
        their 'type' attribute in the top-left corner.
      - output_path: if given, the figure is also saved there at 300 dpi.
      - enhance_info: when True, edge labels also show rel_type, tactic and
        source; otherwise only rel.
      - edge_offset_step: distance between the lanes used for edges that
        join the same pair of nodes.
      - layout_k: `k` for nx.spring_layout; a default is chosen when None.
      - seed: random seed of the layout, so the picture is reproducible.

    Edges are coloured by source (first element when 'source' is a list);
    unknown sources are grey. The function shows the figure and returns None.
    """
    subG = G.subgraph(community_nodes).copy()

    # layout parameter
    if layout_k is None:
        n = len(community_nodes)
        # NOTE: 1.5 / sqrt(n) never exceeds 1.5, so the floor of 6 always wins.
        layout_k = max(1.5 / math.sqrt(n) if n > 0 else 1.0, 6)

    pos = nx.spring_layout(subG, k=layout_k, iterations=300, seed=seed)
    fig, ax = plt.subplots(figsize=(16, 16))

    # map sources to colors
    all_sources = set()
    for _, _, d in subG.edges(data=True):
        s = d.get('source', None)
        if isinstance(s, list):
            all_sources.update(s)
        elif s:
            all_sources.add(s)
    all_sources = sorted(all_sources)
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
    # takes the same (name, lut) arguments and exists in old and new releases.
    cmap = plt.get_cmap('tab20', max(len(all_sources), 1))
    src_to_color = {src: mcolors.to_hex(cmap(i)) for i, src in enumerate(all_sources)}

    # draw nodes
    nx.draw_networkx_nodes(subG, pos, node_size=300, node_color='lightblue', ax=ax)
    nx.draw_networkx_labels(
        subG,
        pos,
        labels={n: str(n) for n in subG.nodes()},
        font_size=10,
        font_weight='bold',
        ax=ax
    )

    # highlight seeds
    if highlight:
        # a set gives O(1) membership tests; `n in pos` skips community
        # entries that are not nodes of G and therefore have no position
        seed_set = set(highlight.get('used_seeds') or ())
        seeds = [n for n in community_nodes if n in seed_set and n in pos]
        if seeds:
            nx.draw_networkx_nodes(
                subG,
                pos,
                nodelist=seeds,
                node_size=300,
                node_color='lightblue',
                edgecolors='gold',
                linewidths=3,
                ax=ax
            )
            info = "\n".join(
                f"seed: {n}, type: {G.nodes[n].get('type','unknown')}"
                for n in seeds
            )
            ax.text(
                0.01, 0.99, info,
                transform=ax.transAxes,
                fontsize=10,
                verticalalignment='top',
                bbox=dict(facecolor='white', alpha=0.5, edgecolor='gold', pad=5),
                zorder=5
            )

    # adjust view
    xs = [p[0] for p in pos.values()]
    ys = [p[1] for p in pos.values()]
    if xs and ys:
        dx, dy = max(xs) - min(xs), max(ys) - min(ys)
        margin_x, margin_y = dx*0.2, dy*0.2
        ax.set_xlim(min(xs)-margin_x, max(xs)+margin_x)
        ax.set_ylim(min(ys)-margin_y, max(ys)+margin_y)

    def first_value(value, default='N/A'):
        # Attributes may be stored as a list or as a bare value; indexing
        # blindly would raise on an empty list and cut a string down to its
        # first character.
        if value is None:
            return default
        if isinstance(value, (list, tuple)):
            return value[0] if value else default
        return value

    def edge_style(data):
        # colour and label text for one edge
        s = data.get('source', 'unknown')
        if isinstance(s, list):
            s = s[0] if s else 'unknown'
        color = src_to_color.get(s, '#888888')
        rel = data.get('rel', 'N/A')
        if enhance_info:
            rt = first_value(data.get('rel_type'))
            tc = first_value(data.get('tactic'))
            text = f"rel: {rel}\nrel_type: {rt}\ntactic: {tc}\nsource: {s}"
        else:
            text = f"rel: {rel}"
        return color, text

    def polyline(A, B, offset):
        # three-point path A -> shifted midpoint -> B; the midpoint is moved
        # by `offset` along the left-hand normal of A->B
        (x1, y1), (x2, y2) = A, B
        dx, dy = x2-x1, y2-y1
        L = math.hypot(dx, dy) + 1e-9
        nxv, nyv = -dy/L, dx/L
        mx, my = (x1+x2)/2, (y1+y2)/2
        off_mx, off_my = mx + offset*nxv, my + offset*nyv
        return [(x1,y1), (off_mx,off_my), (x2,y2)]

    # draw edges (multi/self-loop aware)
    if isinstance(subG, (nx.MultiGraph, nx.MultiDiGraph)):
        edge_iter = subG.edges(keys=True, data=True)
    else:
        edge_iter = ((u, v, 0, d) for u, v, d in subG.edges(data=True))

    # frozenset({u, v}) -> (edges drawn so far, tail node of the first one).
    # A frozenset key needs no ordering between node ids, and the remembered
    # tail fixes one reference direction per pair.
    seen = {}
    loops = {}
    for u, v, k, d in edge_iter:
        is_loop = (u == v)
        color, label = edge_style(d)

        if is_loop:
            # each further loop on the same node is larger and rotated
            loops[u] = loops.get(u, 0) + 1
            n = loops[u]
            x, y = pos[u]
            size = 0.15 + 0.05*(n-1)
            angle = math.pi/4 + (n-1)*math.pi/6
            theta = np.linspace(angle, angle+1.5*math.pi, 100)
            lx = x + size*np.cos(theta)
            ly = y + size*np.sin(theta)
            ax.plot(lx, ly, color=color, linewidth=2, zorder=1)
            la = angle + 0.75*math.pi
            ax.text(
                x + size*math.cos(la),
                y + size*math.sin(la),
                label,
                fontsize=12,
                ha='center',
                va='center',
                bbox=dict(boxstyle="round,pad=0.3", fc="white", ec=color, alpha=0.8),
                zorder=3
            )
            arrow_angle = angle + 1.25*math.pi
            ax.arrow(
                x + size*np.cos(arrow_angle),
                y + size*np.sin(arrow_angle),
                -0.02*math.sin(arrow_angle),
                0.02*math.cos(arrow_angle),
                head_width=0.02,
                head_length=0.02,
                fc=color,
                ec=color,
                zorder=2
            )

        else:
            edge_id = frozenset((u, v))
            n, ref_tail = seen.get(edge_id, (0, u))
            n += 1
            seen[edge_id] = (n, ref_tail)
            offset = 0
            if n > 1:
                # lanes alternate around the straight line: +1, -1, +2, -2, ...
                sign = 1 if n%2==0 else -1
                offset = sign * edge_offset_step * (n//2)
                # polyline() measures the offset relative to the direction it
                # is given; flip it for edges running against the reference
                # direction so reciprocal edges do not share a lane
                if u != ref_tail:
                    offset = -offset

            pts = polyline(pos[u], pos[v], offset)
            xs, ys = zip(*pts)
            ax.plot(xs, ys, color=color, linewidth=2, zorder=1)

            mx, my = pts[1]
            ax.text(
                mx, my, label,
                fontsize=12,
                ha='center',
                va='center',
                bbox=dict(boxstyle="round,pad=0.3", fc="white", ec=color, alpha=0.8),
                zorder=3
            )

            # arrow on second segment
            dxl, dyl = pts[2][0]-pts[1][0], pts[2][1]-pts[1][1]
            mag = math.hypot(dxl, dyl)
            if mag > 0:
                dxl, dyl = dxl/mag*0.03, dyl/mag*0.03
                ax.arrow(
                    pts[1][0] - dxl/2,
                    pts[1][1] - dyl/2,
                    dxl, dyl,
                    head_width=0.02,
                    head_length=0.02,
                    fc=color,
                    ec=color,
                    zorder=2
                )

    plt.title("Community Subgraph")
    plt.axis('off')
    plt.tight_layout()
    if output_path:
        plt.savefig(output_path, dpi=300)
    plt.show()


def remove_source_attr(G):
    """
    Copy G and strip the 'source' attribute from every edge of the copy.

    The copy is made with G.copy() and is what gets returned; edges that
    have no 'source' attribute are left as they are. Multigraphs are handled
    per edge key.
    """
    stripped = G.copy()
    has_keys = isinstance(stripped, (nx.MultiGraph, nx.MultiDiGraph))

    if not has_keys:
        for a, b in stripped.edges():
            stripped[a][b].pop('source', None)
        return stripped

    # multigraph: every parallel edge carries its own attribute dict
    for a, b, key in stripped.edges(keys=True):
        stripped[a][b][key].pop('source', None)
    return stripped


def make_safe_copy(G):
    """
    Return a deep copy of G (via copy.deepcopy), so nested attribute
    containers of the copy can be modified independently of the original.
    """
    duplicate = copy.deepcopy(G)
    return duplicate


def clean_graph_for_seed_algos(G):
    """
    Return a cleaned copy of G for seed-based community algorithms.

    All changes are made on G.copy():
      - the 'source' attribute is removed from every edge;
      - an edge 'weight' that is not a finite int/float (wrong type, NaN,
        +inf or -inf) is reset to 1.0; edges without a weight are left
        without one;
      - isolated nodes are removed.

    Simple graphs and multigraphs are both supported: attributes are edited
    through edges(data=True), which yields one attribute dict per edge,
    parallel edges included.
    """
    Gc = G.copy()

    for _, _, attrs in Gc.edges(data=True):
        attrs.pop('source', None)

        w = attrs.get('weight', 1.0)
        # math.isfinite is applied to floats only, so arbitrarily large ints
        # cannot raise OverflowError during the check
        usable = isinstance(w, int) or (
            isinstance(w, float) and math.isfinite(w)
        )
        if not usable:
            attrs['weight'] = 1.0

    # materialise the isolates first: the graph must not change while the
    # generator returned by nx.isolates is being consumed
    Gc.remove_nodes_from(list(nx.isolates(Gc)))

    return Gc


def get_typed_seeds_for_osse(G, type_counts):
    """
    Randomly pick seed nodes of the requested types.

    `type_counts` maps a node type to the number of nodes wanted for it. A
    node's type is its 'type' attribute converted to str (first element when
    the attribute is a non-empty list, 'unknown' when it is missing). For
    each requested type, random.sample draws at most that many nodes, capped
    at the number available; types without matching nodes contribute nothing.

    Returns one flat list of nodes, grouped in the order of `type_counts`.
    """
    # bucket the nodes of the requested types
    buckets = {}
    for node, attrs in G.nodes(data=True):
        raw = attrs.get('type', 'unknown')
        label = str(raw[0]) if isinstance(raw, list) and raw else str(raw)
        if label not in type_counts:
            continue
        buckets.setdefault(label, []).append(node)

    # draw from each bucket
    chosen = []
    for label, quota in type_counts.items():
        candidates = buckets.get(label, [])
        if not candidates:
            continue
        chosen.extend(random.sample(candidates, min(len(candidates), quota)))
    return chosen


def run_osse_with_typed_seeds(
    G,
    type_counts,
    expansion_graph=None,
    nruns=100,
    nworkers=32,
    ninf=True
):
    """
    Run the OSSE algorithm on a cleaned undirected aggregated graph,
    using typed seeds.

    Parameters:
      - G: graph the seeds are sampled from (see get_typed_seeds_for_osse).
      - type_counts: dict {node_type: number of seeds wanted}.
      - expansion_graph: graph the algorithm runs on. When None, an
        undirected nx.MultiGraph is built from the nodes and edges of G, so
        the function no longer depends on a module-level graph existing at
        call time.
      - nruns, nworkers, ninf: forwarded unchanged to
        cdlib.algorithms.overlapping_seed_set_expansion.

    The expansion graph is passed through clean_graph_for_seed_algos, and
    seeds that are no longer in the cleaned graph are dropped.

    Returns (communities, runtime, used_seeds): the communities found, the
    wall time of the algorithm call in seconds (rounded to 4 decimals) and
    the seeds actually used.

    Raises ValueError if no sampled seed is present in the cleaned graph.
    """
    seeds = get_typed_seeds_for_osse(G, type_counts)

    if expansion_graph is None:
        expansion_graph = nx.MultiGraph()
        expansion_graph.add_nodes_from(G.nodes(data=True))
        expansion_graph.add_edges_from(G.edges(data=True))

    Gc = clean_graph_for_seed_algos(expansion_graph)
    valid = [n for n in seeds if n in Gc]
    if not valid:
        raise ValueError("No valid seed nodes available")

    # perf_counter is monotonic, unlike time.time, so the measured interval
    # cannot be distorted by system clock adjustments
    start = time.perf_counter()
    res = algorithms.overlapping_seed_set_expansion(
        Gc, seeds=valid, nruns=nruns, nworkers=nworkers, ninf=ninf
    )
    runtime = round(time.perf_counter() - start, 4)
    return res.communities, runtime, valid


def analyze_community_edge_sources(communities, G):
    """
    Filter communities to those whose internal edges come from >1 distinct source.

    Every edge of the induced subgraph is counted individually, so parallel
    edges of a multigraph each contribute their own source instead of
    overwriting one another. An edge's source is its 'source' attribute
    converted to str (first element when it is a list, 'unknown' when it is
    missing or an empty list).

    Returns (filtered_communities, stats_list). stats_list holds one dict per
    kept community with:
      - community_id: index of the community in `communities`
      - node_count, edge_count, source_count
      - sources: the distinct source labels
      - source_distribution: {source label: number of edges}
    """
    filtered = []
    stats = []
    for idx, comm in enumerate(communities):
        subG = G.subgraph(comm)

        # edges per source label, one increment per edge
        dist = {}
        for _, _, d in subG.edges(data=True):
            s = d.get('source', 'unknown')
            if isinstance(s, list):
                s = s[0] if s else 'unknown'
            s = str(s)
            dist[s] = dist.get(s, 0) + 1

        # more than one label implies at least two edges
        if len(dist) > 1:
            filtered.append(comm)
            stats.append({
                'community_id': idx,
                'node_count': len(comm),
                'edge_count': sum(dist.values()),
                'source_count': len(dist),
                'sources': list(dist),
                'source_distribution': dist
            })
    return filtered, stats


# === Example of assembling the undirected aggregated graph ===
# G_aggregated is not created in this section; it must already exist when
# this code runs. Its nodes and edges (with their attribute dicts) are copied
# into a fresh undirected MultiGraph.
G_aggregated_undirected = nx.MultiGraph()
G_aggregated_undirected.add_nodes_from(G_aggregated.nodes(data=True))
G_aggregated_undirected.add_edges_from(G_aggregated.edges(data=True))


# === Run OSSE with typed seeds ===
# Requested number of seeds per node type. The sampler caps each request at
# the number of nodes available, so 9999 presumably means "take every node
# of this type" (assumes no type has more than 9999 nodes).
type_counts = {
    'malware': 9999,
    'attack-pattern': 9999,
    'threat-actor': 9999,
    'vulnerability': 9999
}

# Returns the detected communities, the runtime in seconds and the seeds
# that were actually used.
typed_communities, typed_runtime, used_seeds = run_osse_with_typed_seeds(
    G_aggregated, type_counts
)

# compute overall stats
community_stats = compute_community_stats(typed_communities, G_aggregated)

# filter for multi-source communities
multi_source_communities, multi_source_stats = analyze_community_edge_sources(
    typed_communities, G_aggregated
)

# example visualization
# Index selects which multi-source community to draw; nothing is drawn when
# fewer than Index + 1 such communities were found.
Index = 65
if Index < len(multi_source_communities):
    target_comm = multi_source_communities[Index]
    # only 'used_seeds' is read by the visualizer; 'seed_types' is kept for
    # reference
    highlight_info = {'used_seeds': used_seeds, 'seed_types': type_counts}
    visualize_community_graph_no_overlap(
        G_aggregated,
        target_comm,
        highlight=highlight_info,
        enhance_info=True
    )
