In [1]:
import networkx as nx
import importlib
import numpy as np
import os
import re
import pygraphviz as pg
import imageio
from PIL import Image

In [5]:
def ARI(G, cc, clustering_label="block"):

    #检查是否安装了 scikit-learn 库。
    #如果安装了，就从中导入 preprocessing 和 metrics 模块。如果没有安装，就打印一条消息并返回 -1。
    if importlib.util.find_spec("sklearn") is not None:
        from sklearn import preprocessing, metrics
    else:
        print("scikit-learn is not installed...")
        return -1

    #从图 G 中获取名为 clustering_label 的节点属性，并将其存储在 complexlist 变量中
    complexlist = nx.get_node_attributes(G, clustering_label)

    #创建一个 LabelEncoder 对象，并使用它将 complexlist 的值转换为整数
    #这样做是为了确保我们的评估指标可以处理整数标签
    le = preprocessing.LabelEncoder()
    y_true = le.fit_transform(list(complexlist.values()))

    #创建一个名为 predict_dict 的字典，用于存储预测的聚类标签
    #遍历 cc（聚类结果），并将每个节点的预测标签存储在字典中
    predict_dict = {}
    for idx, comp in enumerate(cc):
        for c in list(comp):
            predict_dict[c] = idx

    #创建一个名为 y_pred 的列表，用于存储预测的聚类标签
    #遍历 complexlist 的键（节点），并从 predict_dict 中获取相应的预测标签。最后，将 y_pred 转换为 numpy 数组。
    y_pred = []
    for v in complexlist.keys():
        y_pred.append(predict_dict[v])
    y_pred = np.array(y_pred)

    return metrics.adjusted_rand_score(y_true, y_pred)


def NMI(G, cc, clustering_label="club"):
    if importlib.util.find_spec("sklearn") is not None:
        from sklearn import preprocessing, metrics
    else:
        print("scikit-learn is not installed...")
        return -1

    complexlist = nx.get_node_attributes(G, clustering_label)

    le = preprocessing.LabelEncoder()
    y_true = le.fit_transform(list(complexlist.values()))

    predict_dict = {}
    for idx, comp in enumerate(cc):
        for c in list(comp):
            predict_dict[c] = idx

    y_pred = []
    for v in complexlist.keys():
        y_pred.append(predict_dict[v])
    y_pred = np.array(y_pred)

    return metrics.normalized_mutual_info_score(y_true, y_pred)


def Modularity(G, cc, clustering_label="club"):
    return nx.algorithms.community.modularity(G, cc)

In [6]:
# Score each saved Ricci-flow iteration: its strongly connected components are
# the predicted communities, compared against the ground-truth "block" labels via ARI.
# NOTE(review): this scans 0.gexf ... 99.gexf, while the drawing cells use
# 1 ... 100 (and 100.gexf is read below) -- confirm which iteration range is intended.
GEXF_DIR = "/Users/mike-fx/Desktop/library/2024-spring/ug-project/DIRECTEDRICCIFLOW/results/SBM/gexf/DirectGraph/with_surgery"

history = []  # rows of [iteration, number of SCCs, ARI]
for i in range(100):
    G = nx.read_gexf(os.path.join(GEXF_DIR, "{}.gexf".format(i)))
    cc = list(nx.strongly_connected_components(G))
    ari = ARI(G, cc, "block")
    history.append([i, len(cc), ari])
    print("number of  strongly connected components: %d   ARI: %5f" % (len(cc), ari))

# max() returns the first row with the highest ARI, exactly like sorted(..., reverse=True)[0].
best = max(history, key=lambda row: row[2])
print("Maximum: ", best)

number of  strongly connnected components: 1   ARI: 0.000000
number of  strongly connnected components: 1   ARI: 0.000000
number of  strongly connnected components: 1   ARI: 0.000000
number of  strongly connnected components: 1   ARI: 0.000000
number of  strongly connnected components: 1   ARI: 0.000000
number of  strongly connnected components: 1   ARI: 0.000000
number of  strongly connnected components: 1   ARI: 0.000000
number of  strongly connnected components: 1   ARI: 0.000000
number of  strongly connnected components: 1   ARI: 0.000000
number of  strongly connnected components: 1   ARI: 0.000000
number of  strongly connnected components: 1   ARI: 0.000000
number of  strongly connnected components: 1   ARI: 0.000000
number of  strongly connnected components: 1   ARI: 0.000000
number of  strongly connnected components: 1   ARI: 0.000000
number of  strongly connnected components: 1   ARI: 0.000000
number of  strongly connnected components: 1   ARI: 0.000000
number of  strongly conn

In [2]:
# Load iteration 1 and list every node-attribute key that appears on any node.
g1_file = os.path.join(
    "/Users/mike-fx/Desktop/library/2024-spring/ug-project/DIRECTEDRICCIFLOW/results/SBM/gexf/DirectGraph/with_surgery",
    "{}.gexf".format(1),
)
G1 = nx.read_gexf(g1_file)

# Union of the per-node attribute dictionaries' keys.
attribute_names = set().union(*(attrs.keys() for _, attrs in G1.nodes(data=True)))

print(attribute_names)  # output: {'block', 'label'} -- 'block' is presumably the attribute we want

{'block', 'label'}


In [5]:
# List every edge-attribute key that appears on any edge of G1.
edge_attribute_names = set().union(*(attrs for _, _, attrs in G1.edges(data=True)))

print(edge_attribute_names)  # output: {'id', 'weight', 'ricciCurvature', 'original_RC'}

{'id', 'weight', 'ricciCurvature', 'original_RC'}


In [3]:
# Distinct values of the "block" attribute across G1's nodes (nodes without it are skipped).
block_values = {attrs['block'] for _, attrs in G1.nodes(data=True) if 'block' in attrs}

print(block_values)  # output: {0, 1}

{0, 1}


In [12]:
# Draw iteration 1 as a directed Graphviz graph. Nodes are filled by block
# (block 0 blue, any other block orange). Each edge's "len" attribute, which the
# neato layout uses as a preferred length, is 2|E| * weight.
# pygraphviz is already imported as `pg` in the first cell.
G1_graph = pg.AGraph(directed=True)

for node, data in G1.nodes(data=True):
    color = '#377eb8' if data['block'] == 0 else '#ff7f00'
    G1_graph.add_node(node, fillcolor=color, style='filled')

# Scale factor applied to edge weights; presumably weights are normalized by the
# Ricci flow so that 2|E| * weight is O(1) -- TODO confirm.
normalize = float(2 * G1.number_of_edges())

for u, v, attrs in G1.edges(data=True):
    G1_graph.add_edge(u, v, len=float(normalize * attrs["weight"]), dir='forward')

G1_graph.draw("/Users/mike-fx/Desktop/library/2024-spring/ug-project/DIRECTEDRICCIFLOW/plot/G1-plot.png", format="png", prog="neato")

In [15]:
# Draw the last saved iteration (100.gexf) the same way as iteration 1.
# Edge lengths get an extra factor of 5 (a hand-set magic number; the reason is
# not recorded here).
g100_file = os.path.join(
    "/Users/mike-fx/Desktop/library/2024-spring/ug-project/DIRECTEDRICCIFLOW/results/SBM/gexf/DirectGraph/with_surgery",
    "{}.gexf".format(100),
)
G100 = nx.read_gexf(g100_file)

G100_graph = pg.AGraph(directed=True)

# Node fill colour by block: block 0 blue, every other block orange.
# (Every node is added here, so no separate add_nodes_from call is needed.)
for node, data in G100.nodes(data=True):
    color = '#377eb8' if data['block'] == 0 else '#ff7f00'
    G100_graph.add_node(node, fillcolor=color, style='filled')

normalize = float(2 * G100.number_of_edges())

for u, v, attrs in G100.edges(data=True):
    G100_graph.add_edge(u, v, len=float(5 * normalize * attrs["weight"]), dir='forward')

G100_graph.draw("/Users/mike-fx/Desktop/library/2024-spring/ug-project/DIRECTEDRICCIFLOW/plot/G100-plot.png", format="png", prog="neato")

In [2]:
def draw_gexf_to_png(index):
    g_file = os.path.join(
        "/Users/mike-fx/Desktop/library/2024-spring/ug-project/DIRECTEDRICCIFLOW/results/SBM/gexf/DirectGraph/with_surgery",
        "{}.gexf".format(index),
    )

    G = nx.read_gexf(g_file)
    G_graph = pg.AGraph(directed=True)

    # 为节点分配颜色
    for node, data in G.nodes(data=True):
        if data['block'] == 0:
            color = '#377eb8'
        else:
            color = '#ff7f00'
        G_graph.add_node(node, fillcolor=color, shape='circle', style='filled')

    normalize = float(2 * int(G.number_of_edges()))

    # for e in G.edges():
    #     G_graph.add_edge(e[0], e[1], len=float(2*normalize * G[e[0]][e[1]]["weight"]), dir='forward')

    # 为边添加颜色
    for e in G.edges():
        block_0 = G.nodes[e[0]]['block']
        block_1 = G.nodes[e[1]]['block']

        if block_0 == block_1:
            if block_0 == 0:
                color = '#377eb8'
            else:
                color = '#ff7f00'
        else:
            color = '#41d183'  # 介于'#377eb8'和'#ff7f00'之间的颜色

        G_graph.add_edge(e[0], e[1], color=color, len=float(2*normalize * G[e[0]][e[1]]["weight"]), dir='forward')


    output_file = os.path.join(
        "/Users/mike-fx/Desktop/library/2024-spring/ug-project/DIRECTEDRICCIFLOW/plot",
        "G{}-plot.png".format(index)
    )

    G_graph.draw(output_file, format="png", prog="neato")

In [3]:
# Render one PNG per saved iteration (1.gexf ... 100.gexf) using draw_gexf_to_png.
for index in range(1, 100 + 1):
    draw_gexf_to_png(index)

In [4]:
def create_gif(image_dir, save_name, resize=(400, 400)):
    image_list = []
    frames = []
    for f in os.listdir(image_dir):
        if f.endswith(".png"):
            image_list.append(os.path.join(image_dir, f))
            # 查看文件读取是否成功 -- 成功
            # image_path = os.path.join(image_dir, f)
            # print(f"Processing image: {image_path}")
    
    # 使用正则表达式从文件名中提取数字
    def extract_number(filename):
        match = re.search(r'G(\d+)-plot\.png', filename)
        return int(match.group(1)) if match else 0

    # 根据提取的数字对图像文件进行排序
    image_list.sort(key=extract_number)
    
    # 打印 image_list 中元素的个数 -- output: Number of images in image_list: 100
    # print(f"Number of images in image_list: {len(image_list)}")

    # 打印排序后的图像文件名
    # print("Sorted image filenames:")
    # for image_name in image_list:
    #     print(image_name)

    for image_name in image_list:
        if image_name.endswith(".png"):
            image = Image.open(image_name)
            image = image.resize(resize)
            # 检查图像尺寸是否与预期的尺寸相同
            if image.size != resize:
                print(f"Resizing image {image_name} to {resize}")
                image = image.resize(resize)
            # 确保图像已经调整到正确的尺寸
            assert image.size == resize, f"Image {image_name} has incorrect size: {image.size}"
            frames.append(image)

    # 保存为GIF
    frames[0].save(save_name, format="GIF", save_all=True, append_images=frames[1:], duration=100, loop=0)

In [5]:
# Stitch the per-iteration PNG frames in the plot directory into one animated GIF.
image_dir = "/Users/mike-fx/Desktop/library/2024-spring/ug-project/DIRECTEDRICCIFLOW/plot"
output_dir = "/Users/mike-fx/Desktop/library/2024-spring/ug-project/DIRECTEDRICCIFLOW/gif"
save_name = os.path.join(output_dir, "animation.gif")

create_gif(image_dir, save_name, resize=(400, 400))