In [5]:
import estimators
from embeddings import Embedding
from tf_ievm.ievm import TFiEVM

import numpy as np
import networkx as nx
import typing as tp

from sklearn.preprocessing import StandardScaler

In [16]:
def evm_classify(embeddings: np.ndarray, communities: list[set[int]],
                 test_nodes: np.ndarray, anomaly_threshold: float,
                 **evm_kwargs: tp.Any) -> np.ndarray:
    '''
    Helper function. Trains EVM on embeddings with class labels according to communities.
    Returns class labels for test_nodes (-1 for anomalies).
    :param embeddings: embeddings for graph nodes (row i is the embedding of node i)
    :param communities: list of communities (each community is represented by its node indices);
                        the position of a community in the list is used as its training label
    :param test_nodes: indices of nodes to test
    :param anomaly_threshold: probability threshold for anomalies; a node whose highest
                              class probability is below it is labelled -1
    :param evm_kwargs: kwargs for EVM
    :return: integer array with one label per test node (-1 for anomalies)
    '''
    test_nodes = np.asarray(test_nodes)

    # Nothing to classify: skip training entirely. This also avoids feeding an
    # empty matrix to the scaler, which scikit-learn rejects.
    if test_nodes.size == 0:
        return np.empty(0, dtype=np.int64)

    train_nodes: list[int] = []
    train_labels: list[int] = []
    for label, community in enumerate(communities):
        train_nodes.extend(community)
        train_labels.extend([label] * len(community))

    # Fit the scaling on the training nodes only and reuse it for the test nodes.
    scaler = StandardScaler()
    train_embeddings = scaler.fit_transform(embeddings[train_nodes])
    test_embeddings = scaler.transform(embeddings[test_nodes])

    model = TFiEVM(**evm_kwargs).fit(train_embeddings, train_labels)

    test_probs = model.predict_proba(test_embeddings)

    anomalies_mask = test_probs.max(axis=-1) < anomaly_threshold

    # Integer dtype so that callers can use the labels directly as list indices.
    results = np.full(len(test_nodes), -1, dtype=np.int64)

    # Only ask the model for labels when at least one node passed the threshold.
    if not np.all(anomalies_mask):
        results[~anomalies_mask] = model.predict_from_proba(test_probs[~anomalies_mask])

    return results

In [14]:
class OneEVMCommunities:
    '''
    Community search on evolving random graphs using one EVM, trained on all previous snapshots
    '''
    def __init__(self, init_graph: nx.MultiDiGraph,
                 embedding_class: Embedding,
                 anomaly_threshold: float,
                 **evm_kwargs: tp.Any) -> None:
        '''
        Initialises base partition of an initial graph.
        :param init_graph: initial graph (nodes are expected to be integers 0, 1, ..., len(init_graph) - 1)
        :param embedding_class: class of nodes embeddings (it is called with init_graph)
        :param anomaly_threshold: threshold for anomaly detection
        :param evm_kwargs: kwargs for EVM.__init__
        '''
        # Nodes with index >= max_node have not been assigned to a community yet.
        self.max_node = len(init_graph)

        # The Louvain implementation is exposed under `nx.community`.
        self.communities = nx.community.louvain_communities(init_graph)

        self.embeddings = embedding_class(init_graph)
        self.anomaly_threshold = anomaly_threshold
        self.evm_kwargs = evm_kwargs

    def update(self, snapshot: nx.MultiDiGraph) -> None:
        '''
        Updates the partition.
        New nodes are either added to an existing community or, if they are
        classified as anomalies, gathered into one new community.
        :param snapshot: new snapshot of a graph
        '''
        self.embeddings.update(snapshot)

        new_max_node = len(snapshot)
        test_nodes = np.arange(self.max_node, new_max_node)

        # No new nodes in this snapshot: there is nothing to classify.
        if test_nodes.size == 0:
            self.max_node = new_max_node
            return

        embeddings = self.embeddings.to_numpy()

        test_labels = np.asarray(evm_classify(embeddings,
                                              self.communities,
                                              test_nodes,
                                              self.anomaly_threshold,
                                              **self.evm_kwargs))

        anomalies_mask = test_labels == -1

        # Cast to plain ints: labels are used as list indices and the sets
        # are documented to hold ints.
        old_communities_labels = zip(test_nodes[~anomalies_mask].tolist(),
                                     test_labels[~anomalies_mask].tolist())
        for v, community in old_communities_labels:
            self.communities[int(community)].add(v)

        anomalies = test_nodes[anomalies_mask]
        if anomalies.size != 0:
            self.communities.append(set(anomalies.tolist()))

        self.max_node = new_max_node

    def get_communities(self) -> list[set[int]]:
        '''
        Returns the current partition (the internal list, not a copy).
        '''
        return self.communities

In [15]:
class DoubleEVMCommunities:
    '''
    Community search on evolving random graphs using two EVM, trained on normal communities and on anomalous ones
    '''
    def __init__(self, init_graph: nx.MultiDiGraph,
                 embedding_class: Embedding,
                 anomaly_threshold: float,
                 **evm_kwargs: tp.Any) -> None:
        '''
        Initialises base partition of an initial graph.
        :param init_graph: initial graph (nodes are expected to be integers 0, 1, ..., len(init_graph) - 1)
        :param embedding_class: class of nodes embeddings (it is called with init_graph)
        :param anomaly_threshold: threshold for anomaly detection
        :param evm_kwargs: kwargs for EVM.__init__
        '''
        # Nodes with index >= max_node have not been assigned to a community yet.
        self.max_node = len(init_graph)

        # The Louvain implementation is exposed under `nx.community`.
        self.normal_communities = nx.community.louvain_communities(init_graph)
        self.abnormal_communities: list[set[int]] = []

        self.embeddings = embedding_class(init_graph)
        self.anomaly_threshold = anomaly_threshold
        self.evm_kwargs = evm_kwargs

    @staticmethod
    def _assign(communities: list[set[int]], nodes: np.ndarray,
                labels: np.ndarray) -> np.ndarray:
        '''
        Adds every node whose label is not -1 to communities[label].
        :param communities: list of communities to extend in place
        :param nodes: indices of classified nodes
        :param labels: labels of the nodes (-1 for anomalies)
        :return: nodes that were labelled -1
        '''
        labels = np.asarray(labels)
        anomalies_mask = labels == -1

        # Cast to plain ints: labels are used as list indices and the sets
        # are documented to hold ints.
        for v, community in zip(nodes[~anomalies_mask].tolist(),
                                labels[~anomalies_mask].tolist()):
            communities[int(community)].add(v)

        return nodes[anomalies_mask]

    def update(self, snapshot: nx.MultiDiGraph) -> None:
        '''
        Updates the partition.
        New nodes are first classified against the normal communities; the ones
        rejected as anomalies are then classified against the abnormal communities,
        and the ones rejected again form a new abnormal community.
        :param snapshot: new snapshot of a graph
        '''
        self.embeddings.update(snapshot)

        new_max_node = len(snapshot)
        test_nodes = np.arange(self.max_node, new_max_node)

        # No new nodes in this snapshot: there is nothing to classify.
        if test_nodes.size == 0:
            self.max_node = new_max_node
            return

        embeddings = self.embeddings.to_numpy()

        test_labels = evm_classify(embeddings,
                                   self.normal_communities,
                                   test_nodes,
                                   self.anomaly_threshold,
                                   **self.evm_kwargs)

        anomalies = self._assign(self.normal_communities, test_nodes, test_labels)

        # Skip the second stage when there are no anomalies, so that no empty
        # community is stored and no EVM is trained for an empty test set.
        if anomalies.size != 0:
            if len(self.abnormal_communities) == 0:
                self.abnormal_communities.append(set(anomalies.tolist()))
            else:
                anomaly_labels = evm_classify(embeddings,
                                              self.abnormal_communities,
                                              anomalies,
                                              self.anomaly_threshold,
                                              **self.evm_kwargs)

                new_abnormal_community = self._assign(self.abnormal_communities,
                                                      anomalies,
                                                      anomaly_labels)

                if new_abnormal_community.size != 0:
                    self.abnormal_communities.append(set(new_abnormal_community.tolist()))

        self.max_node = new_max_node

    def get_communities(self) -> list[set[int]]:
        '''
        Returns the current partition: normal communities followed by abnormal ones (a new list).
        '''
        return self.normal_communities + self.abnormal_communities