In [None]:
import time

import pandas as pd

In [None]:
# Root folder of the project. The cells below read parquet files from its
# "output" subfolder; replace the placeholder with a real path before running.
PROJECT_DIRECTORY = "<PROJECT_DIRECTORY>"

In [None]:
# Read the entities table from the project's output folder, print its row
# count, and preview the first rows.
entities_path = f"{PROJECT_DIRECTORY}/output/entities.parquet"
entities = pd.read_parquet(entities_path)
print(entities.shape[0])
entities.head()

In [None]:
# Read the relationships table from the project's output folder, print its
# row count, and preview the first rows.
relationships_path = f"{PROJECT_DIRECTORY}/output/relationships.parquet"
relationships = pd.read_parquet(relationships_path)
print(relationships.shape[0])
relationships.head()

In [None]:
# Read the communities table from the project's output folder, print its row
# count, and preview the first rows.
communities_path = f"{PROJECT_DIRECTORY}/output/communities.parquet"
communities = pd.read_parquet(communities_path)
print(communities.shape[0])
communities.head()

In [None]:
from graphrag.index.operations.create_graph import create_graph

# Build the graph from the relationships table, passing "weight" as the edge
# attribute to carry over, then print the graph's nodes.
edge_attributes = ["weight"]
graph = create_graph(relationships, edge_attr=edge_attributes)
print(graph.nodes)

In [None]:
from graphrag.index.operations.embed_graph.embed_node2vec import embed_node2vec
from graphrag.index.operations.layout_graph.umap import run as run_umap

# Embed the graph with embed_node2vec and report how long the call took.
# time.perf_counter() is a monotonic, high-resolution clock intended for
# measuring intervals; time.time() follows the wall clock and can jump if the
# system clock is adjusted mid-run.
start = time.perf_counter()
n2v = embed_node2vec(
    graph,
)
end = time.perf_counter()
print("n2v time:", end - start)

# Map each entry of n2v.nodes to the matching entry of n2v.embeddings.
n_embeddings = dict(zip(n2v.nodes, n2v.embeddings))


# Lay the embeddings out with run_umap and keep each point's label and
# coordinates. The label is stored under "title", the column a later cell
# merges on; the coordinates get n2v-specific column names.
# NOTE(review): the third argument is an identity callable; confirm what
# signature run_umap expects for it.
n_umap = run_umap(graph, n_embeddings, lambda x: x)
n_umap_list = [{"title": p.label, "x_n2v": p.x, "y_n2v": p.y} for p in n_umap]

n_df = pd.DataFrame(n_umap_list)

n_df.head()

In [None]:
from graphrag.config.models.embed_graph_config import EmbedGraphConfig
from graphrag.index.operations.embed_graph.embed_graph import embed_graph

# Embed the graph with embed_graph (default EmbedGraphConfig) and report how
# long the call took. time.perf_counter() is a monotonic, high-resolution
# clock intended for measuring intervals; time.time() follows the wall clock
# and can jump if the system clock is adjusted mid-run.
start = time.perf_counter()
pipeline_embeddings = embed_graph(graph, entities, communities, EmbedGraphConfig())
end = time.perf_counter()
print("gee time:", end - start)

# Lay the embeddings out with run_umap (imported in an earlier cell) and keep
# each point's label and coordinates. The label is stored under "title", the
# column a later cell merges on; the coordinates get gee-specific column names.
p_umap = run_umap(graph, pipeline_embeddings, lambda x: x)

p_umap_list = [{"title": p.label, "x_gee_p": p.x, "y_gee_p": p.y} for p in p_umap]

p_df = pd.DataFrame(p_umap_list)

p_df.head()

In [None]:
# One row per (community, entity id) pair, keeping only the columns needed to
# label entities.
community_labels = communities.explode("entity_ids")[
    ["community", "entity_ids", "level"]
]

# Left-join both layouts onto the entities by title, attach the community
# labels by entity id, then keep only the rows whose community level is 0.
merged_entities = (
    entities.merge(n_df, on="title", how="left")
    .merge(p_df, on="title", how="left")
    .merge(community_labels, left_on="id", right_on="entity_ids", how="left")
    .loc[lambda frame: frame["level"] == 0]
)

In [None]:
# Both scatter plots share every option except the coordinate columns and the
# title, so the common keyword arguments are defined once.
scatter_options = {
    "kind": "scatter",
    "s": 5,
    "c": "community",
    "cmap": "tab20",
    "figsize": (12, 10),
    "xticks": [],
    "yticks": [],
    "xlabel": "",
    "ylabel": "",
}

# Layout from the node2vec embeddings, coloured by the community column.
merged_entities.plot(x="x_n2v", y="y_n2v", title="n2v", **scatter_options)
# Layout from the embed_graph embeddings, coloured by the community column.
merged_entities.plot(x="x_gee_p", y="y_gee_p", title="workflow", **scatter_options)