In [31]:
import graspy
import numpy as np
import matplotlib as plt
import networkx as nx
#uncomment when placed into model folder not doc
#from .base import BaseGraphEstimator, _calculate_p

from graspy.simulations import sbm, er_np, er_nm

from graspy.models.base import BaseGraphEstimator 
from graspy.utils.utils import (
    augment_diagonal,
    cartprod,
    import_graph,
    is_unweighted,
    remove_loops,
    symmetrize,
)

%matplotlib inline

In [42]:
class SIEMEstimator(BaseGraphEstimator):
    r"""
    SIEM estimator (work in progress).

    Stores a user-supplied partition of a graph's edges into communities.
    ``fit`` validates the graph and the ``weighted`` option, then records each
    community's edges in ``model`` keyed by community index. No edge
    probabilities are estimated yet.

    NOTE(review): "SIEM" presumably stands for Structured Independent Edge
    Model -- confirm before publishing these docs.

    Parameters
    ----------
    directed : boolean, optional (default=True)
        Whether to treat the input graph as directed. Forwarded to
        ``BaseGraphEstimator``; not otherwise used by ``fit`` yet.

    loops : boolean, optional (default=False)
        Whether to allow entries on the diagonal of the adjacency matrix, i.e.
        loops in the graph where a node connects to itself. Forwarded to
        ``BaseGraphEstimator``; not otherwise used by ``fit`` yet.

    Attributes
    ----------
    model : dict
        Maps the community index ``i`` (the position in the ``edge_comm`` list
        passed to ``fit``) to that community's edge collection. Empty until
        ``fit`` has been called, and rebuilt from scratch on every call.

    See also
    --------
    graspy.simulations.siem
    """

    def __init__(
        self,
        directed=True,
        loops=False,
    ):
        super().__init__(directed=directed, loops=loops)
        self.model = {}

    def fit(self, graph, edge_comm, weighted):
        """
        Fit the SIEM to a graph

        Parameters
        ----------
        graph : array_like or networkx.Graph [nxn]
            Input graph to fit

        edge_comm : list of length k (k_communities)
            One entry per edge community. Each entry holds the edges that
            belong to that community (for example an array of ``(row, col)``
            index pairs). The entries themselves are stored as given and are
            not validated.

        weighted: boolean or float
            Boolean: True - do nothing or False - ensure everything is 0 or 1
            Float: binarize and use float as cutoff (entries strictly greater
            than the cutoff become 1, all others 0)

        Returns
        -------
        self : SIEMEstimator
            The fitted estimator, with ``model`` populated.

        Raises
        ------
        TypeError
            If ``graph`` is neither a numpy array nor a networkx graph, if
            ``edge_comm`` is not a list, if ``weighted`` is neither a boolean
            nor a float, if there are more communities than possible edges
            (``n**2``), or if ``weighted`` is False and the graph is not binary.
        """
        import warnings

        # Validate types before touching any attribute of ``graph``: a networkx
        # graph has no ``.shape``, so the size can only be read after conversion.
        # nx.DiGraph / nx.MultiGraph / nx.MultiDiGraph all subclass nx.Graph.
        if not isinstance(graph, (nx.Graph, np.ndarray)):
            msg = "graph must be a np.array or networkx.Graph"
            raise TypeError(msg)
        if not isinstance(edge_comm, list):
            msg = "Edge_comm must be a list"
            raise TypeError(msg)
        if not isinstance(weighted, (bool, float)):
            msg = "weighted must be a boolean or float"
            raise TypeError(msg)

        graph = import_graph(graph)
        n = graph.shape[0]

        n_comm = len(edge_comm)
        if n_comm > n ** 2:
            msg = "Too many communities for this graph"
            raise TypeError(msg)
        if n_comm >= n:
            msg = "warning more communities than n vertices"
            warnings.warn(msg, UserWarning, stacklevel=2)

        # ``bool`` is not a ``float``, so True/False never reach the threshold
        # branch; identity (not equality) keeps a 0.0 cutoff out of the False
        # branch.
        if isinstance(weighted, float):
            # NOTE: the binarized graph is currently only a local value; nothing
            # below reads it yet.
            graph = 1 * (graph > weighted)
        elif weighted is False:
            if not np.array_equal(graph, graph.astype(bool)):
                msg = "graph of weighted = False must have binary inputs"
                raise TypeError(msg)

        # Rebuild the mapping so a refit never keeps communities from a
        # previous call.
        self.model = {i: comm for i, comm in enumerate(edge_comm)}

        return self
            
        
#         if not is_unweighted(graph):
#             raise NotImplementedError(
#                 "Graph model is currently only implemented for unweighted graphs."
#             )

#         if y is None:
#             self._estimate_assignments(graph)
#             y = self.vertex_assignments_

#             _, counts = np.unique(y, return_counts=True)
#             self.block_weights_ = counts / graph.shape[0]
#         else:
#             check_X_y(graph, y)

#         block_vert_inds, block_inds, block_inv = _get_block_indices(y)

#         if not self.loops:
#             graph = remove_loops(graph)
#         block_p = _calculate_block_p(graph, block_inds, block_vert_inds)

#         if not self.directed:
#             block_p = symmetrize(block_p)
#         self.block_p_ = block_p

#         p_mat = _block_to_full(block_p, block_inv, graph.shape)
#         if not self.loops:
#             p_mat = remove_loops(p_mat)
#         self.p_mat_ = p_mat

#        return self

#     def _n_parameters(self):
#         n_blocks = self.block_p_.shape[0]
#         n_parameters = 0
#         if self.directed:
#             n_parameters += n_blocks ** 2
#         else:
#             n_parameters += n_blocks * (n_blocks + 1) / 2
#         if hasattr(self, "vertex_assignments_"):
#             n_parameters += n_blocks - 1
#         return n_parameters


In [43]:
# Directed, loop-free estimator instance shared by the checks below.
siem_test = SIEMEstimator(directed=True, loops=False)

### Check that SIEM grabs the proper edges as given

In [47]:
# Two-block SBM with 50 nodes per block: dense within blocks, sparse between.
g2 = sbm(n=[50, 50], p=[[0.5, 0.1], [0.1, 0.5]])

# Each edge community is the set of (row, col) indices of the edges that fall
# inside one diagonal block of the adjacency matrix.
edge_comm_1 = np.argwhere(g2[:50, :50] == 1)
# argwhere indexes relative to the slice, so shift back to global node ids.
edge_comm_2 = 50 + np.argwhere(g2[50:, 50:] == 1)

comms = [edge_comm_1, edge_comm_2]
print(comms)

siem_test.fit(g2, edge_comm=comms, weighted=False)
siem_test.model

[array([[ 0,  2],
       [ 0,  6],
       [ 0,  7],
       ...,
       [49, 46],
       [49, 47],
       [49, 48]]), array([[50, 51],
       [50, 52],
       [50, 53],
       ...,
       [99, 93],
       [99, 95],
       [99, 97]])]


{0: array([[ 0,  2],
        [ 0,  6],
        [ 0,  7],
        ...,
        [49, 46],
        [49, 47],
        [49, 48]]), 1: array([[50, 51],
        [50, 52],
        [50, 53],
        ...,
        [99, 93],
        [99, 95],
        [99, 97]])}

### Check dictionary