In [None]:

import sys
import networkx as nx
import pandas as pd

sys.path.append("../")
from toolkit.risk_networks.config import AttributeColumnType
from toolkit.risk_networks.graph_functions import build_undirected_graph
from toolkit.risk_networks.model import prepare_entity_attribute
from toolkit.risk_networks.text_format import format_data_columns

# Path is relative to the notebook's working directory.
INPUT_CSV_PATH = "./input/rn_test.csv"
input_dataframe = pd.read_csv(INPUT_CSV_PATH)

def build_model_with_attributes(
    input_dataframe,
    entity_id_column,
    columns_to_link,
    attribute_type=AttributeColumnType.ColumnName,
) -> nx.Graph:
    """Build an undirected entity-attribute graph from a tabular input.

    Args:
        input_dataframe: Raw input table (the pandas DataFrame loaded from CSV).
        entity_id_column: Name of the column holding the entity identifier.
        columns_to_link: Columns passed to the toolkit as the attribute columns
            to link entities through.
        attribute_type: Passed through unchanged to ``prepare_entity_attribute``.
            Defaults to ``AttributeColumnType.ColumnName``, the value this
            function previously hard-coded, so existing callers are unaffected.

    Returns:
        The graph built by ``build_undirected_graph`` from the attribute links.
    """
    # Format the entity id and linked columns with the toolkit helper before
    # deriving links from them.
    data_df = format_data_columns(input_dataframe, columns_to_link, entity_id_column)
    # prepare_entity_attribute also returns node types, which are not needed here.
    attribute_links, _node_types = prepare_entity_attribute(
        data_df, entity_id_column, attribute_type, columns_to_link
    )
    return build_undirected_graph(network_attribute_links=attribute_links)

In [None]:
from collections import defaultdict
from toolkit.risk_networks.identify import project_entity_graph, trim_nodeset
from toolkit.risk_networks.node_community import get_community_nodes
from toolkit.risk_networks.network import build_network_from_entities
import polars as pl

def detect_networks(
    graph: nx.Graph,
    ad_trimmed_attributes,
    inferred_links,
    supporting_attributes,
    integrated_flags,
    max_degree=10,
    max_community_size=10,
) -> tuple:
    """Trim high-degree attributes, project entities, and group them into networks.

    Args:
        graph: Entity-attribute graph (e.g. from ``build_model_with_attributes``).
        ad_trimmed_attributes: Attributes passed to ``trim_nodeset`` alongside
            the degree threshold (presumably extra attributes to trim — confirm).
        inferred_links: Passed to ``project_entity_graph`` when projecting the
            entity graph. NOT forwarded to ``build_network_from_entities``.
        supporting_attributes: Passed to ``project_entity_graph``.
        integrated_flags: Flag table passed to ``build_network_from_entities``.
        max_degree: Degree threshold passed to ``trim_nodeset``.
        max_community_size: Size limit passed to ``get_community_nodes``.

    Returns:
        A 3-tuple ``(community_nodes, network, trimmed_attr)``:
        ``community_nodes`` from ``get_community_nodes``; ``network`` from
        ``build_network_from_entities``; ``trimmed_attr`` a DataFrame with
        columns ``["Attribute", "Linked Entities"]`` built from the trimmed
        degrees reported by ``trim_nodeset``.
    """
    trimmed_degrees, trimmed_nodes = trim_nodeset(
        graph, ad_trimmed_attributes, max_degree
    )

    projected_graph = project_entity_graph(
        graph,
        trimmed_nodes,
        inferred_links,
        supporting_attributes,
    )

    trimmed_attr = pd.DataFrame(
        trimmed_degrees,
        columns=["Attribute", "Linked Entities"],
    )

    community_nodes, entity_to_community = get_community_nodes(
        projected_graph, max_community_size
    )

    # NOTE(review): the network builder always receives a fresh, empty
    # defaultdict(set), not the caller's `inferred_links`. This was previously
    # hidden by re-binding the parameter name. Behaviour is kept as-is; confirm
    # whether the caller's links were meant to be passed here instead.
    network_inferred_links = defaultdict(set)
    network = build_network_from_entities(
        graph,
        entity_to_community,
        integrated_flags,
        trimmed_attr,
        network_inferred_links,
    )
    return community_nodes, network, trimmed_attr

In [None]:
# --- Inputs -----------------------------------------------------------------
# TODO: consider loading these values from a JSON or CSV file instead of
# hard-coding them in the notebook.

# Input schema: the entity identifier column and the columns to link on.
entity_id_column = "Subject ID"
columns_to_link = ["Event Description"]
# NOTE: build_model_with_attributes is called below without this value, so
# ColumnName is what is actually used.
attribute_type = AttributeColumnType.ColumnName

# Network detection parameters, passed to detect_networks.
max_degree = 10
max_community_size = 20
ad_trimmed_attributes = []
inferred_links = []
supporting_attributes = []

# Flag table; left empty in this notebook.
integrated_flags = pl.DataFrame()


In [None]:
from toolkit.risk_networks import config

graph = build_model_with_attributes(input_dataframe, entity_id_column, columns_to_link)

# Entity nodes carry the config.entity_label prefix; every other node is
# counted as an attribute.
all_nodes = graph.nodes()
entity_nodes = [node for node in all_nodes if node.startswith(config.entity_label)]

num_entities = len(entity_nodes)
num_attributes = len(all_nodes) - num_entities
num_edges = len(graph.edges())
# NOTE(review): placeholder — no groups are computed in this notebook, so the
# summary below always reports 0 groups.
groups = set()
# Series.sum() yields a scalar; the previous len(...) wrapper raised TypeError
# whenever the flag table was non-empty.
num_flags = integrated_flags["count"].sum() if not integrated_flags.is_empty() else 0

print(
    f"*Number of entities*: {num_entities}\n*Number of attributes*: {num_attributes}\n*Number of links*: {num_edges}\n*Number of flags*: {num_flags}\n*Number of groups*: {len(groups)}"
)

# comma: community nodes; net: the built network; trimmed_attr: trimmed attributes.
(comma, net, trimmed_attr) = detect_networks(
    graph,
    ad_trimmed_attributes,
    inferred_links,
    supporting_attributes,
    integrated_flags,
    max_degree,
    max_community_size,
)

In [None]:
from toolkit.risk_networks.network import generate_final_df

entity_records = generate_final_df(comma, integrated_flags)

# `comma` holds the community_nodes returned by detect_networks; each element is
# sized with len() below.
comm_count = len(comma)
comm_sizes = [len(comm) for comm in comma if len(comm) > 1]
# default=0: plain max() raises ValueError when no community has more than one
# entity (comm_sizes is empty).
max_comm_size = max(comm_sizes, default=0)

print(
    f"*Networks identified: {comm_count} ({len(comm_sizes)} with multiple entities, maximum {max_comm_size})*"
)
print(
    f"*Attributes removed because of high degree*: {trimmed_attr}"
)

In [None]:
# Label the records from generate_final_df. The column order must match the
# field order of each record (assumed — confirm against generate_final_df).
network_entity_df = pd.DataFrame(
    data=entity_records,
    columns=[
        "Entity ID", "Entity Flags",
        "Network ID", "Network Entities", "Network Flags",
        "Flagged", "Flags/Entity", "Flagged/Unflagged",
    ],
)
network_entity_df

In [None]:
# Drop the per-entity columns, then remove duplicate rows. Rows that differed
# only in those columns collapse into one, and the index is renumbered from 0.
per_entity_columns = ["Entity ID", "Entity Flags"]
last_df = network_entity_df.drop(columns=per_entity_columns)
last_df = last_df.drop_duplicates().reset_index(drop=True)
last_df