Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20240917213220301479.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Collapse create_final_communities."
}
3 changes: 3 additions & 0 deletions dictionary.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ isna
getcwd
fillna
noqa
dtypes

# Azure
abfs
Expand Down Expand Up @@ -97,6 +98,8 @@ retryer
agenerate
aembed
dedupe
dropna
dtypes

# LLM Terms
AOAI
Expand Down
28 changes: 25 additions & 3 deletions graphrag/index/verbs/graph/unpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,32 @@ def unpack_graph(
column: <column name> # The name of the column containing the graph, should be a graphml graph
```
"""
input_df = input.get_input()
output_df = unpack_graph_df(
cast(pd.DataFrame, input_df),
callbacks,
column,
type,
copy,
embeddings_column,
kwargs=kwargs,
)
return TableContainer(table=output_df)


def unpack_graph_df(
input_df: pd.DataFrame,
callbacks: VerbCallbacks,
column: str,
type: str, # noqa A002
copy: list[str] | None = None,
embeddings_column: str = "embeddings",
**kwargs,
) -> pd.DataFrame:
"""Unpack nodes or edges from a graphml graph, into a list of nodes or edges."""
if copy is None:
copy = default_copy
input_df = input.get_input()

num_total = len(input_df)
result = []
copy = [col for col in copy if col in input_df.columns]
Expand All @@ -64,8 +87,7 @@ def unpack_graph(
)
])

output_df = pd.DataFrame(result)
return TableContainer(table=output_df)
return pd.DataFrame(result)


def _run_unpack(
Expand Down
149 changes: 1 addition & 148 deletions graphrag/index/workflows/v1/create_final_communities.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,154 +19,7 @@ def build_steps(
"""
return [
{
"id": "graph_nodes",
"verb": "unpack_graph",
"args": {
"column": "clustered_graph",
"type": "nodes",
},
"verb": "create_final_communities",
"input": {"source": "workflow:create_base_entity_graph"},
},
{
"id": "graph_edges",
"verb": "unpack_graph",
"args": {
"column": "clustered_graph",
"type": "edges",
},
"input": {"source": "workflow:create_base_entity_graph"},
},
{
"id": "source_clusters",
"verb": "join",
"args": {
"on": ["label", "source"],
},
"input": {"source": "graph_nodes", "others": ["graph_edges"]},
},
{
"id": "target_clusters",
"verb": "join",
"args": {
"on": ["label", "target"],
},
"input": {"source": "graph_nodes", "others": ["graph_edges"]},
},
{
"id": "concatenated_clusters",
"verb": "concat",
"input": {
"source": "source_clusters",
"others": ["target_clusters"],
},
},
{
"id": "combined_clusters",
"verb": "filter",
"args": {
# level_1 is the left side of the join
# level_2 is the right side of the join
"column": "level_1",
"criteria": [
{"type": "column", "operator": "equals", "value": "level_2"}
],
},
"input": {"source": "concatenated_clusters"},
},
{
"id": "cluster_relationships",
"verb": "aggregate_override",
"args": {
"groupby": [
"cluster",
"level_1", # level_1 is the left side of the join
],
"aggregations": [
{
"column": "id_2", # this is the id of the edge from the join steps above
"to": "relationship_ids",
"operation": "array_agg_distinct",
},
{
"column": "source_id_1",
"to": "text_unit_ids",
"operation": "array_agg_distinct",
},
],
},
"input": {"source": "combined_clusters"},
},
{
"id": "all_clusters",
"verb": "aggregate_override",
"args": {
"groupby": ["cluster", "level"],
"aggregations": [{"column": "cluster", "to": "id", "operation": "any"}],
},
"input": {"source": "graph_nodes"},
},
{
"verb": "join",
"args": {
"on": ["id", "cluster"],
},
"input": {"source": "all_clusters", "others": ["cluster_relationships"]},
},
{
"verb": "filter",
"args": {
# level is the left side of the join
# level_1 is the right side of the join
"column": "level",
"criteria": [
{"type": "column", "operator": "equals", "value": "level_1"}
],
},
},
*create_community_title_wf,
{
# TODO: Rodrigo says "raw_community" is temporary
"verb": "copy",
"args": {
"column": "id",
"to": "raw_community",
},
},
{
"verb": "select",
"args": {
"columns": [
"id",
"title",
"level",
"raw_community",
"relationship_ids",
"text_unit_ids",
],
},
},
]


create_community_title_wf = [
# Hack to string concat "Community " + id
{
"verb": "fill",
"args": {
"to": "__temp",
"value": "Community ",
},
},
{
"verb": "merge",
"args": {
"columns": [
"__temp",
"id",
],
"to": "title",
"strategy": "concat",
"preserveSource": True,
},
},
]
2 changes: 2 additions & 0 deletions graphrag/index/workflows/v1/subflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

"""The Indexing Engine workflows -> subflows package root."""

from .create_final_communities import create_final_communities
from .create_final_text_units_pre_embedding import create_final_text_units_pre_embedding

__all__ = [
"create_final_communities",
"create_final_text_units_pre_embedding",
]
107 changes: 107 additions & 0 deletions graphrag/index/workflows/v1/subflows/create_final_communities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""All the steps to transform final communities."""

from typing import cast

import pandas as pd
from datashaper import (
Table,
VerbCallbacks,
VerbInput,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result

from graphrag.index.verbs.graph.unpack import unpack_graph_df
from graphrag.index.verbs.overrides.aggregate import aggregate_df


@verb(name="create_final_communities", treats_input_tables_as_immutable=True)
def create_final_communities(
    input: VerbInput,
    callbacks: VerbCallbacks,
    **_kwargs: dict,
) -> VerbResult:
    """All the steps to transform final communities.

    Collapses the multi-step `create_final_communities` workflow into a single
    verb: unpacks nodes/edges from the clustered graph, associates each cluster
    with its relationship and text-unit ids, and emits one row per
    (community, level) with a generated title.

    Parameters
    ----------
    input : VerbInput
        Source table; expected to carry a "clustered_graph" graphml column
        (consumed by unpack_graph_df below).
    callbacks : VerbCallbacks
        Progress/error callbacks forwarded to the unpack step.

    Returns
    -------
    VerbResult
        A table with columns: id, title, level, relationship_ids, text_unit_ids.
    """
    table = cast(pd.DataFrame, input.get_input())

    # Explode the graphml column into one row per node and one row per edge.
    graph_nodes = unpack_graph_df(table, callbacks, "clustered_graph", "nodes")
    graph_edges = unpack_graph_df(table, callbacks, "clustered_graph", "edges")

    # Attach edges to nodes from both directions: a node participates in an
    # edge either as its source or as its target.
    source_clusters = graph_nodes.merge(
        graph_edges,
        left_on="label",
        right_on="source",
        how="inner",
    )
    target_clusters = graph_nodes.merge(
        graph_edges,
        left_on="label",
        right_on="target",
        how="inner",
    )

    concatenated_clusters = pd.concat(
        [source_clusters, target_clusters], ignore_index=True
    )

    # NOTE: the _x/_y suffixes below come from pandas' default merge suffixes
    # for columns present in both frames ("level" exists on nodes and edges).
    # level_x is the left side of the join
    # level_y is the right side of the join
    # we only want to keep the clusters that are the same on both sides
    combined_clusters = concatenated_clusters[
        concatenated_clusters["level_x"] == concatenated_clusters["level_y"]
    ].reset_index(drop=True)

    # Per (cluster, level): collect the distinct edge ids and text-unit ids
    # that belong to the community.
    cluster_relationships = aggregate_df(
        cast(Table, combined_clusters),
        aggregations=[
            {
                "column": "id_y",  # this is the id of the edge from the join steps above
                "to": "relationship_ids",
                "operation": "array_agg_distinct",
            },
            {
                "column": "source_id_x",  # node-side text-unit source ids
                "to": "text_unit_ids",
                "operation": "array_agg_distinct",
            },
        ],
        groupby=[
            "cluster",
            "level_x",  # level_x is the left side of the join
        ],
    )

    # One row per (cluster, level); "any" just carries the cluster id through
    # under the output name "id".
    all_clusters = aggregate_df(
        graph_nodes,
        aggregations=[{"column": "cluster", "to": "id", "operation": "any"}],
        groupby=["cluster", "level"],
    )

    joined = all_clusters.merge(
        cluster_relationships,
        left_on="id",
        right_on="cluster",
        how="inner",
    )

    # Keep only rows where the node-side level matches the aggregated level
    # ("level" from all_clusters, "level_x" from cluster_relationships).
    filtered = joined[joined["level"] == joined["level_x"]].reset_index(drop=True)

    # Replaces the old fill+merge "Community " + id title hack.
    filtered["title"] = "Community " + filtered["id"].astype(str)

    # raw_community was dropped from the final selection (it was temporary).
    return create_verb_result(
        cast(
            Table,
            filtered[
                [
                    "id",
                    "title",
                    "level",
                    "relationship_ids",
                    "text_unit_ids",
                ]
            ],
        )
    )
2 changes: 1 addition & 1 deletion tests/fixtures/min-csv/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
1,
2000
],
"subworkflows": 14,
"subworkflows": 1,
"max_runtime": 10
},
"create_final_community_reports": {
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/text/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
1,
2000
],
"subworkflows": 14,
"subworkflows": 1,
"max_runtime": 10
},
"create_final_community_reports": {
Expand Down
35 changes: 35 additions & 0 deletions tests/verbs/test_create_final_communities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from graphrag.index.workflows.v1.create_final_communities import (
build_steps,
workflow_name,
)

from .util import (
compare_outputs,
get_workflow_output,
load_expected,
load_input_tables,
)


async def test_create_final_communities():
    """Verify the collapsed workflow reproduces the expected community table."""
    source_tables = load_input_tables(["workflow:create_base_entity_graph"])
    reference = load_expected(workflow_name)

    workflow_result = await get_workflow_output(
        source_tables,
        {"steps": build_steps({})},
    )

    # we removed the raw_community column, so expect one less in the output
    compare_outputs(
        workflow_result,
        reference,
        ["id", "title", "level", "relationship_ids", "text_unit_ids"],
    )
Loading