In [1]:
import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

from ricci.simplicial_complex import SimplicialComplex
from ricci.ricci_flow import ricci_flow_community_detection
import matplotlib.pyplot as plt
import itertools
import networkx as nx


In [2]:
import pandas as pd

# Load the edge list: each line holds three space-separated fields, read here as
# the columns 'source', 'target' and 'true'.
df = pd.read_table("../datasets/aves-weaver-social.edges", names=['source', 'target', 'true'], sep=' ')
true_labels = df['true'].tolist()

# Build the undirected edges by zipping the two endpoint columns rather than
# walking df.iterrows(): iterrows creates a Series per row (slow) and does not
# preserve per-column dtypes, so node ids could come out with a different type
# than the ones collected in `nodes` below.
edges = [{source, target} for source, target in zip(df['source'], df['target'])]
nodes = set(df['source']).union(set(df['target']))

# NOTE(review): labels are keyed by the 'source' column only, so a node that
# appears solely as a 'target' has no entry here, and when a source repeats the
# last row wins. Confirm this matches what the ground-truth consumer expects.
node_labels = dict(zip(df['source'], df['true']))

In [3]:
# Assemble the simplicial complex: every node as a 0-simplex, then every edge.
sc = SimplicialComplex()

for vertex in nodes:
    sc.add_simplex((vertex,))

for pair in edges:
    sc.add_simplex(pair)

# Mirror the nodes and edges in a networkx graph so its cliques can be enumerated.
G = nx.Graph()
G.add_nodes_from(nodes)
G.add_edges_from(edges)

max_clique_size = 4
cliques = list(nx.find_cliques(G))

# Add higher-dimensional simplices: for each clique with 3..max_clique_size
# vertices, insert every vertex subset of size 3 up to the full clique.
# NOTE(review): cliques larger than max_clique_size are skipped entirely, so
# none of their sub-faces are added by this loop -- confirm that is intended.
for clique in cliques:
    n_vertices = len(clique)
    if not (2 < n_vertices <= max_clique_size):
        continue
    for face_size in range(3, n_vertices + 1):
        for face in itertools.combinations(clique, face_size):
            sc.add_simplex(face)

# Report how many simplices the complex holds per dimension.
print("Dimensions and counts of simplices in the complex:")
for dim, members in sc.simplices.items():
    print(f"Dimension {dim}: {len(members)} simplices")

# Highest simplex dimension present in the complex.
max_dim = max(sc.simplices)

# Run the Ricci-flow community detection; the result unpacks into four
# sequences that are plotted against each other below.
flow_result = ricci_flow_community_detection(
    sc, T=20, delta=0.01, ground_truth=node_labels
)
nmi_scores, modularity_scores, theta_values, ari_scores = flow_result

# Plot each quality metric as a function of the weight cutoff theta.
fig, ax = plt.subplots(figsize=(10, 5))
for metric_values, metric_name in (
    (modularity_scores, 'Modularity'),
    (nmi_scores, 'NMI'),
    (ari_scores, 'ARI'),
):
    ax.plot(theta_values, metric_values, label=metric_name)
ax.set_xlabel('Weight Cutoff (theta)')
ax.set_ylabel('Metric Value')
ax.set_title('aves-weaver-social')
ax.legend()
ax.grid(False)
plt.show()

# Pick the cutoff with the highest NMI (first one on ties). optimal_theta is
# initialised up front so the pruning step below can detect "no cutoff
# available" instead of raising NameError when nmi_scores is empty.
optimal_theta = None
if nmi_scores:
    max_nmi_index = nmi_scores.index(max(nmi_scores))
    optimal_theta = theta_values[max_nmi_index]
    print(f"Optimal theta (based on NMI): {optimal_theta}")
else:
    print("Could not compute NMI scores.")

# Rebind the complex's containers to shallow copies before mutating them, so the
# removals below do not touch the container objects previously held by `sc`.
sc.simplices = {dim: sc.simplices[dim].copy() for dim in sc.simplices}
sc.weights = sc.weights.copy()

if optimal_theta is not None:
    # Drop every simplex of dimension >= 1 whose weight exceeds the chosen
    # cutoff; vertices (length-1 simplices) are always kept.
    simplices_to_remove = [simplex for simplex, weight in sc.weights.items()
                           if weight > optimal_theta and len(simplex) > 1]
    for simplex in simplices_to_remove:
        dim = len(simplex) - 1
        if simplex in sc.simplices[dim]:
            sc.simplices[dim].remove(simplex)
        del sc.weights[simplex]

# Communities are the connected components of the remaining vertices and edges.
# Without a usable cutoff nothing was pruned, so this reflects the unpruned
# complex; `communities` is still defined for any later cell.
G = nx.Graph()
G.add_nodes_from(sc.simplices[0])
G.add_edges_from(sc.simplices[1])

communities = list(nx.connected_components(G))
print(f"Number of communities detected: {len(communities)}")

Dimensions and counts of simplices in the complex:
Dimension 0: 445 simplices
Dimension 1: 1335 simplices
Dimension 2: 246 simplices
Dimension 3: 51 simplices
Iteration 1/20
Iteration 2/20
Iteration 3/20
Iteration 4/20
Iteration 5/20
Iteration 6/20


KeyboardInterrupt: 