In [1]:
import sys

sys.path.append("..")
import os
from toolkit.detect_entity_networks.api import DetectEntityNetworks
from toolkit.AI.openai_configuration import OpenAIConfiguration
import polars as pl

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Build the workflow object first, then hand it the OpenAI settings.
den = DetectEntityNetworks()
ai_configuration = OpenAIConfiguration(
    dict(
        api_type="OpenAI",
        # The key is taken from the environment; a missing variable raises KeyError here.
        api_key=os.environ["OPENAI_API_KEY"],
        model="gpt-4o",
    )
)
den.set_ai_configuration(ai_configuration)

# Read the example CSV into a polars DataFrame; later cells pass this frame to `den`.
data_path = "../example_outputs/detect_entity_networks/company_grievances/company_grievances_input.csv"
entity_df = pl.read_csv(data_path)

print("Loaded data")

Loaded data


In [3]:
# Link entities through the attribute columns listed below.
from toolkit.detect_entity_networks.prepare_model import format_data_columns

entity_id_column = "name"
columns_to_link = ["address", "city", "email", "phone", "owner"]

# `entity_df` is rebound to the result of format_data_columns; the following
# cells pass this rebound frame to `den`.
entity_df = format_data_columns(entity_df, columns_to_link, entity_id_column)
den.add_attribute_links(entity_df, entity_id_column, columns_to_link)

# Show the model counts after the attribute links were added.
summary = den.get_model_summary_value()
print("Summary", summary, sep="\n")

Summary
Number of entities: 3602, Number of attributes: 18549, Number of flags: 0, Number of groups: 0, Number of links: 41727


In [4]:
# Register the grievance columns as flags on each entity.
from toolkit.detect_entity_networks.classes import FlagAggregatorType

entity_id_column = "name"
flag_format = FlagAggregatorType.Count
columns_to_link = [
    "safety_grievances",
    "pay_grievances",
    "conditions_grievances",
    "treatment_grievances",
    "workload_grievances",
]
den.add_flag_links(entity_df, entity_id_column, columns_to_link, flag_format)

# Show the model counts after the flags were added.
summary = den.get_model_summary_value()
print("Summary", summary, sep="\n")

Summary
Number of entities: 3602, Number of attributes: 18549, Number of flags: 8108, Number of groups: 0, Number of links: 41727


In [5]:
# Register "sector" and "country" as grouping columns for the entities.
entity_id_column = "name"
columns_to_link = ["sector", "country"]
den.add_groups(entity_df, entity_id_column, columns_to_link)

# Show the model counts after the groups were added.
summary = den.get_model_summary_value()
print("Summary", summary, sep="\n")

Summary
Number of entities: 3602, Number of attributes: 18549, Number of flags: 8108, Number of groups: 634, Number of links: 41727


In [6]:
await den.index_nodes(["ENTITY"])
if len(den.embedded_texts) > 0:
    print(f"Number of nodes indexed: {len(den.embedded_texts)}")

Got 3602 existing texts
Got 0 new texts
Number of nodes indexed: 3602


In [7]:
from toolkit.detect_entity_networks.index_and_infer import build_inferred_df


threshold = 0.03
await den.infer_nodes(threshold)

inferred_links_count = len(den.inferred_links)
if inferred_links_count > 0:
    print(f"Number of links inferred: {inferred_links_count}")
    inferred_df = build_inferred_df(den.inferred_links)
    print(inferred_df)
else:
    print("No inferred links")

Number of links inferred: 2242
shape: (3_547, 2)
┌───────────────────────┬────────────────────────┐
│ text                  ┆ similar                │
│ ---                   ┆ ---                    │
│ str                   ┆ str                    │
╞═══════════════════════╪════════════════════════╡
│ Adventure Gear        ┆ Adventure Gear Co      │
│ Adventure Gear        ┆ Adventure Gear Company │
│ Adventure Gear        ┆ AdventureGear          │
│ Adventure Gear Co     ┆ Adventure Gear Company │
│ Adventure Gear Co     ┆ AdventureGear Co       │
│ …                     ┆ …                      │
│ WindPower Corp        ┆ WindPower Inc          │
│ WindPower Solutions   ┆ WindPower Solutons     │
│ Windy Heights Limited ┆ Windy Heights Ltd      │
│ WindyCity Energy      ┆ WindyCity Power        │
│ Zephyr Energy Co      ┆ Zephyr Energy Inc      │
└───────────────────────┴────────────────────────┘


In [8]:
# Remove attributes (placeholder: this cell contains no code, so nothing is removed)


In [9]:
# Identify
#
# Runs den.identify(), then builds a table from den.entity_records and sorts it
# so rows with the highest "flagged/unflagged" value come first.
#
# NOTE(review): the recorded run of this cell ends with
#   TypeError: unsupported operand type(s) for |: 'NoneType' and 'int'
# and no traceback was saved, so the failing statement cannot be determined from
# this cell alone. Re-run with the full traceback to locate it before changing
# the code below.

# The return value of identify() is discarded; presumably it populates
# den.entity_records as a side effect — confirm against the toolkit.
den.identify()
# This rebinds `entity_df`, replacing the input frame used by the earlier cells,
# so those cells cannot be re-run after this one without reloading the data.
entity_df = pl.DataFrame(
    den.entity_records,
    # Column names applied to the records in order; assumes each record has
    # exactly these eight fields — TODO confirm.
    schema=[
        "entity_id",
        "entity_flags",
        "network_id",
        "network_entities",
        "network_flags",
        "flagged",
        "flags/entity",
        "flagged/unflagged",
    ],
)
entity_df = entity_df.sort(by=["flagged/unflagged"], descending=True)

TypeError: unsupported operand type(s) for |: 'NoneType' and 'int'